Plot linear plot and log plot next to each other in Python. Similar to mfrow=c(2,1) in R

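I am trying to plot two graphs next to each other in Python. One shows the linear results of an experiment and the other shows a log transformation. The goal is to place the plots side by side, similar to par(mfrow=c(1,2)) in R.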

if __name__ == '__main__':
  c_1 = run_experiment(1.0, 2.0, 3.0, 0.1, 100000) # run_experiment(m1, m2, m3, eps, N):
  c_05 = run_experiment(1.0, 2.0, 3.0, 0.05, 100000)
  c_01 = run_experiment(1.0, 2.0, 3.0, 0.01, 100000)

  # log scale plot
  plt.plot(c_1, label='eps = 0.10')
  plt.plot(c_05, label='eps = 0.05')
  plt.plot(c_01, label='eps = 0.01')
  plt.legend()
  plt.xscale('log')
  plt.title(label="Log Multi-Arm Bandit")
  plt.show()


  # linear plot
  plt.plot(c_1, label='eps = 0.10')
  plt.plot(c_05, label='eps = 0.05')
  plt.plot(c_01, label='eps = 0.01')
  plt.legend()
  plt.show()

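I have tried many approaches, but I always seem to get an error. Could someone show me how to do this in code? I am relatively new to Python and my experience is mostly in R, so any help would mean a lot. Below is my code with some of the changes I tried.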

  # log scale plot
  fig, axes = plt.subplots(122)
  ax1, ax2 = axes[0], axes[1]
  ax1.plot(c_1, label='eps = 0.10')
  ax1.plot(c_05,label='eps = 0.05')
  ax1.plot(c_01,label='eps = 0.01')
  ax1.legend()
  ax1.xscale('log')
  #plt.title(label="Log Multi-Arm Bandit")
  #plt.show()

  # linear plot
  ax2.plot(c_1, label='eps = 0.10')
  ax2.plot(c_05,label='eps = 0.05')
  ax2.plot(c_01, label='eps = 0.01')
  ax2.legend()
  plt.show()

But I get an error.

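The full code I use to run the experiment, from start to finish, is below. It is the multi-armed bandit problem from reinforcement learning.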

# Preamble
from __future__ import print_function, division
from builtins import range
import numpy as np
import matplotlib.pyplot as plt


class Bandit:
  def __init__(self, m):  # m is the true mean
    self.m = m
    self.mean = 0
    self.N = 0

  def pull(self): # simulated pulling bandits arm
    return np.random.randn() + self.m

  def update(self, x):
    self.N += 1
    # look at the derivation above of the mean
    self.mean = (1 - 1.0/self.N)*self.mean + 1.0/self.N*x  


def run_experiment(m1, m2, m3, eps, N):
  bandits = [Bandit(m1), Bandit(m2), Bandit(m3)]

  data = np.empty(N)

  for i in range(N): # Implement epsilon greedy shown above
    # epsilon greedy
    p = np.random.random()
    if p < eps:
      j = np.random.choice(3) # Explore
    else:
      j = np.argmax([b.mean for b in bandits]) # Exploit
    x = bandits[j].pull()  # Pull and update
    bandits[j].update(x)

    # Results for the plot
    data[i] = x  # Store the results in an array called data of size N
  # Calculate cumulative average
  cumulative_average = np.cumsum(data) / (np.arange(N) + 1)


  # plot moving average ctr
  plt.plot(cumulative_average) # plot cumulative average
  # Plot horizontal lines at each of the true means so we can see where the
  # cumulative averages are relative to the means
  plt.plot(np.ones(N)*m1)
  plt.plot(np.ones(N)*m2)
  plt.plot(np.ones(N)*m3)
  plt.title('Slot Machine ')
  # We do this on a log scale so that you can see the 
  # fluctuations in earlier rounds more clearly
  plt.xscale('log') 
  plt.show()

  for b in bandits:
    print(b.mean)

  return cumulative_average

if __name__ == '__main__':
  c_1 = run_experiment(1.0, 2.0, 3.0, 0.1, 100000) # run_experiment(m1, m2, m3, eps, N):
  c_05 = run_experiment(1.0, 2.0, 3.0, 0.05, 100000)
  c_01 = run_experiment(1.0, 2.0, 3.0, 0.01, 100000)

  # log scale plot
  plt.plot(c_1, label='eps = 0.10')
  plt.plot(c_05, label='eps = 0.05')
  plt.plot(c_01, label='eps = 0.01')
  plt.legend()
  plt.xscale('log')
  plt.title(label="Log Multi-Arm Bandit")
  plt.show()


  # linear plot
  plt.plot(c_1, label='eps = 0.10')
  plt.plot(c_05, label='eps = 0.05')
  plt.plot(c_01, label='eps = 0.01')
  plt.legend()
  plt.show()
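Two formulas in this listing can be sanity-checked on their own: the incremental mean in Bandit.update and the cumsum-based cumulative average in run_experiment. The following is a minimal sketch that assumes only NumPy; incremental_mean and cumulative_average are hypothetical helper names introduced here for illustration, not part of the original code.

```python
import numpy as np


def incremental_mean(samples):
    """Running mean using the same update rule as Bandit.update:
    mean_N = (1 - 1/N) * mean_{N-1} + (1/N) * x_N
    """
    mean = 0.0
    for n, x in enumerate(samples, start=1):
        mean = (1 - 1.0 / n) * mean + 1.0 / n * x
    return mean


def cumulative_average(data):
    """Element i is the mean of data[0..i], computed as in run_experiment."""
    data = np.asarray(data, dtype=float)
    return np.cumsum(data) / (np.arange(len(data)) + 1)
```

For example, cumulative_average([1.0, 2.0, 3.0]) gives [1.0, 1.5, 2.0], and the last element of cumulative_average(x) should agree with incremental_mean(x) up to floating-point rounding. The incremental form only needs the previous mean and the count, which is why Bandit can update its estimate without storing every pull.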

As @JohanC pointed out in the comments, you are confusing the syntax of plt.subplots() and plt.subplot(). The line

fig, axes = plt.subplots(122)

creates 122 subplots in a single column. It should be

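fig, axes = plt.subplots(1, 2)

There is a second problem in the same snippet: ax1.xscale('log') fails because Axes objects have no xscale method. Use ax1.set_xscale('log') instead; plt.xscale() is the pyplot-level function that acts on the current axes.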
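Putting the fix together, a minimal sketch of the side-by-side figure might look like this. It assumes the curves are the 1-D arrays returned by run_experiment (c_1, c_05, c_01 in the question). The helper name plot_side_by_side and the figsize value are illustrative choices, not part of the original code. As in the asker's attempt, the log-scaled panel is on the left.

```python
import matplotlib.pyplot as plt
import numpy as np


def plot_side_by_side(curves, log_title="Log Multi-Arm Bandit"):
    """Plot the same curves twice: log-scaled x axis on the left, linear on the right.

    `curves` maps a legend label to a 1-D array (e.g. a cumulative average).
    """
    # plt.subplots(nrows, ncols): one row, two columns, like par(mfrow=c(1, 2)) in R
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

    for label, y in curves.items():
        ax1.plot(np.asarray(y), label=label)
        ax2.plot(np.asarray(y), label=label)

    # Axes methods use the set_ prefix; ax1.xscale('log') would raise AttributeError
    ax1.set_xscale('log')
    ax1.set_title(log_title)
    ax1.legend()

    # Right-hand panel keeps the default linear x scale
    ax2.legend()

    fig.tight_layout()
    return fig, (ax1, ax2)
```

With the question's data, this would be called as plot_side_by_side({'eps = 0.10': c_1, 'eps = 0.05': c_05, 'eps = 0.01': c_01}) followed by plt.show(). Unpacking the returned array directly into (ax1, ax2) avoids the axes[0], axes[1] indexing. The pyplot-style alternative is plt.subplot(1, 2, 1) and then plt.subplot(1, 2, 2), which is the function that the three-digit shorthand 122 actually belongs to (plt.subplot(122) means row count 1, column count 2, panel 2).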