Pandas Seaborn 热图错误

Pandas Seaborn Heatmap Error

我有一个 DataFrame 在未堆叠时看起来像这样。

Start Date  2016-07-11  2016-07-12  2016-07-13
Period
0             1.000000    1.000000         1.0
1             0.684211    0.738095         NaN
2             0.592105         NaN         NaN

我试图在 Seaborn 中将其绘制为热图,但它给了我意想不到的结果。

这是我的代码:

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
df = pd.DataFrame(np.array(data), columns=['Start Date', 'Period', 'Users'])
df = df.fillna(0)
df = df.set_index(['Start Date', 'Period'])
sizes = df['Users'].groupby(level=0).first()
df = df['Users'].unstack(0).divide(sizes, axis=1)
plt.title("Test")
sns.heatmap(df.T, mask=df.T.isnull(), annot=True, fmt='.0%')
plt.tight_layout()
plt.savefig(table._v_name + "fig.png")

我想要它,这样文本就不会重叠,并且旁边没有 6 个热图例。另外,如果可能,我该如何修正日期,使其只显示 %Y-%m-%d?

虽然无法获得准确的可重现数据,但请考虑使用以下已发布的代码段数据。此示例运行 pivot_table() 以实现使用 StartDates 跨列发布的结构。总的来说,由于 unstack() 处理,你的热图可能会输出多个颜色条和重叠的图形,你似乎在按用户划分(查看 seaborn.FacetGrid 进行拆分)。所以下面通过热图按原样运行 df。此外,apply() 根据指定需要重新格式化日期时间。

from io import StringIO
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

data = '''Period,StartDate,Value
       0,2016-07-11,1.000000
       0,2016-07-12,1.000000
       0,2016-07-13,1.0
       1,2016-07-11,0.684211
       1,2016-07-12,0.738095
       1,2016-07-13
       2,2016-07-11,0.592105
       2,2016-07-12
       2,2016-07-13'''

df = pd.read_csv(StringIO(data))
df['StartDate'] = pd.to_datetime(df['StartDate'])
df['StartDate'] = df['StartDate'].apply(lambda x: x.strftime('%Y-%m-%d'))

pvtdf = df.pivot_table(values='Value', index=['Period'],
                       columns='StartDate', aggfunc=sum)
print(pvtdf)
# StartDate  2016-07-11  2016-07-12  2016-07-13
# Period                                       
# 0            1.000000    1.000000         1.0
# 1            0.684211    0.738095         NaN
# 2            0.592105         NaN         NaN

sns.set()    
plt.title("Test")
ax = sns.heatmap(pvtdf.T, mask=pvtdf.T.isnull(), annot=True, fmt='.0%')
plt.tight_layout()
plt.show()