matplotlib 热图 - 绘制奇数指数
matplotlib heatmap - plot odd indices
我想创建一个热图来绘制具有不同内核大小的卷积神经网络的分类精度。在我的 CNN 特定实现中,我只使用奇数大小的过滤器,但绘制的热图将绘图元素放置在奇数和偶数位置。忽略偶数位置并仅在奇数位置绘制元素的正确方法是什么?
我首先定义我感兴趣的内核宽度和高度:
search_height = range(1, 5, 2) # [1,3]
search_width = range(1, 5, 2) # [1,3]
预分配一个数组来保存精度值。我认为这可能是问题的一部分,因为它在偶数索引中存储值?
grid_accuracy = np.empty((len(search_height), len(search_width))) # 2x2 array
然后我得到不同内核大小的网络精度并将它们存储在数组中:
for i, h in enumerate(search_height):
for j, w in enumerate(search_width):
cur_test_acc = main(batch_size=200, num_epochs=100, k_height=h, k_width=w)
grid_accuracy[i, j] = cur_test_acc
最后我用存储的精度值绘制了一个热图:
plt.imshow(grid_accuracy, cmap = plt.cm.hot, interpolation='none')
plt.grid(True)
plt.xlabel('Kernel Width')
plt.ylabel('Kernel Height')
plt.xticks(search_width)
plt.yticks(search_height)
问题是我最终得到的情节如下所示:
目前好像是使用grid_accuracy
的水平和垂直索引作为元素的位置。
我真正想要的只是一个 2x2 的网格,其中每个单元格的值是相应内核的精度 width/height。轴刻度应该是我手动定义的高度和宽度(理想情况下,水平轴刻度在图上方):
您需要设置刻度 标签 而不仅仅是它们的位置:
import numpy as np
from matplotlib import pyplot as plt
plt.imshow(np.random.randn(2, 2), interpolation='none')
plt.xticks([0, 1], [1, 3])
plt.yticks([0, 1], [1, 3])
您可以使用 plt.tick_params
获取您要查找的格式:
plt.tick_params(length=0, labelbottom=False, labeltop=True)
我想创建一个热图来绘制具有不同内核大小的卷积神经网络的分类精度。在我的 CNN 特定实现中,我只使用奇数大小的过滤器,但绘制的热图将绘图元素放置在奇数和偶数位置。忽略偶数位置并仅在奇数位置绘制元素的正确方法是什么?
我首先定义我感兴趣的内核宽度和高度:
search_height = range(1, 5, 2) # [1,3]
search_width = range(1, 5, 2) # [1,3]
预分配一个数组来保存精度值。我认为这可能是问题的一部分,因为它在偶数索引中存储值?
grid_accuracy = np.empty((len(search_height), len(search_width))) # 2x2 array
然后我得到不同内核大小的网络精度并将它们存储在数组中:
for i, h in enumerate(search_height):
for j, w in enumerate(search_width):
cur_test_acc = main(batch_size=200, num_epochs=100, k_height=h, k_width=w)
grid_accuracy[i, j] = cur_test_acc
最后我用存储的精度值绘制了一个热图:
plt.imshow(grid_accuracy, cmap = plt.cm.hot, interpolation='none')
plt.grid(True)
plt.xlabel('Kernel Width')
plt.ylabel('Kernel Height')
plt.xticks(search_width)
plt.yticks(search_height)
问题是我最终得到的情节如下所示:
目前好像是使用grid_accuracy
的水平和垂直索引作为元素的位置。
我真正想要的只是一个 2x2 的网格,其中每个单元格的值是相应内核的精度 width/height。轴刻度应该是我手动定义的高度和宽度(理想情况下,水平轴刻度在图上方):
您需要设置刻度 标签 而不仅仅是它们的位置:
import numpy as np
from matplotlib import pyplot as plt
plt.imshow(np.random.randn(2, 2), interpolation='none')
plt.xticks([0, 1], [1, 3])
plt.yticks([0, 1], [1, 3])
您可以使用 plt.tick_params
获取您要查找的格式:
plt.tick_params(length=0, labelbottom=False, labeltop=True)