向动画散点图添加图例
adding a legend to animated scatterplot
我正在制作一个带有散点图的动画,显示多个组随时间变化的数据。
当我想添加图例时,我能得到的最好的只能显示一组。
示例数据集:
import pandas as pd
df = pd.DataFrame([
[1, 'a', 0.39, 0.73],
[1, 'b', 0.87, 0.94],
[1, 'c', 0.87, 0.23],
[2, 'a', 0.17, 0.37],
[2, 'b', 0.03, 0.12],
[2, 'c', 0.86, 0.22],
[3, 'a', 0.01, 0.15],
[3, 'b', 0.03, 0.1],
[3, 'c', 0.29, 0.19],
columns=['period', 'group', 'x', 'y']
)
我这样制作动画:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
fig, ax = plt.subplots()
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
colors = {
'a': 'r',
'b': 'b',
'c': 'g'
}
scat = ax.scatter([], [],
c=df['group'].map(colors),
)
def init():
scat.set_offsets([])
return scat,
def update(period):
scat.set_offsets(df[df['period'] == period][['x', 'y']])
scat.set_label(df[df['period'] == period]['group'])
ax.legend([scat], df['group'].unique().tolist(), loc=1)
ax.set_title(period)
return scat,
ani = animation.FuncAnimation(fig, update, init_func=init,
frames=[1,2,3,4,5],
interval=500,
repeat=True)
plt.show()
传说中只出现了a组。
如果我只输入 ax.legend(loc=1)
,它会显示如下内容:
6 a
7 b
8 c
Name: group, dtype:object
每一帧的数字都在变化。
我已经检查了这些答案:
:让我到现在的位置。
:我在 legend.remove()
上得到 UnboundLocalError: local variable 'legend' referenced before assignment
:只显示a组。
我找到了解决方案。
我需要为每组创建一个散点图。然后我更新 update()
方法中的每个散点图。
这是我的最终代码:
fig, ax = plt.subplots()
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
colors = {
'a': 'r',
'b': 'b',
'c': 'g'
}
scats = []
groups = df.groupby('group')
for name, grp in groups:
scat = ax.scatter([], [],
color=colors[name],
label=name)
scats.append(scat)
ax.legend(loc=4)
def init():
for scat in scats:
scat.set_offsets([])
return scats,
def update(period):
for scat, (name, data) in zip(scats, groups):
sample = data[data['period'] == period][['x', 'y']]
scat.set_offsets(sample)
return scats,
ani = animation.FuncAnimation(fig, update, init_func=init
frames=[1, 2, 3, 4, 5],
interval=500,
repeat=True)
plt.show()
我正在制作一个带有散点图的动画,显示多个组随时间变化的数据。
当我想添加图例时,我能得到的最好的只能显示一组。
示例数据集:
import pandas as pd
df = pd.DataFrame([
[1, 'a', 0.39, 0.73],
[1, 'b', 0.87, 0.94],
[1, 'c', 0.87, 0.23],
[2, 'a', 0.17, 0.37],
[2, 'b', 0.03, 0.12],
[2, 'c', 0.86, 0.22],
[3, 'a', 0.01, 0.15],
[3, 'b', 0.03, 0.1],
[3, 'c', 0.29, 0.19],
columns=['period', 'group', 'x', 'y']
)
我这样制作动画:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
fig, ax = plt.subplots()
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
colors = {
'a': 'r',
'b': 'b',
'c': 'g'
}
scat = ax.scatter([], [],
c=df['group'].map(colors),
)
def init():
scat.set_offsets([])
return scat,
def update(period):
scat.set_offsets(df[df['period'] == period][['x', 'y']])
scat.set_label(df[df['period'] == period]['group'])
ax.legend([scat], df['group'].unique().tolist(), loc=1)
ax.set_title(period)
return scat,
ani = animation.FuncAnimation(fig, update, init_func=init,
frames=[1,2,3,4,5],
interval=500,
repeat=True)
plt.show()
传说中只出现了a组。
如果我只输入 ax.legend(loc=1)
,它会显示如下内容:
6 a
7 b
8 c
Name: group, dtype:object
每一帧的数字都在变化。
我已经检查了这些答案:
legend.remove()
上得到 UnboundLocalError: local variable 'legend' referenced before assignment
我找到了解决方案。
我需要为每组创建一个散点图。然后我更新 update()
方法中的每个散点图。
这是我的最终代码:
fig, ax = plt.subplots()
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
colors = {
'a': 'r',
'b': 'b',
'c': 'g'
}
scats = []
groups = df.groupby('group')
for name, grp in groups:
scat = ax.scatter([], [],
color=colors[name],
label=name)
scats.append(scat)
ax.legend(loc=4)
def init():
for scat in scats:
scat.set_offsets([])
return scats,
def update(period):
for scat, (name, data) in zip(scats, groups):
sample = data[data['period'] == period][['x', 'y']]
scat.set_offsets(sample)
return scats,
ani = animation.FuncAnimation(fig, update, init_func=init
frames=[1, 2, 3, 4, 5],
interval=500,
repeat=True)
plt.show()