seaborn 中具有类别值的热图
Heatmap with category values in seaborn
我有以下数据框。
ID Cat V1 V2 V3
1 A 1 1 1
2 B 1 1 1
3 A 1 1 0
4 C 0 0 0
我想创建一个图(类似于热图)来显示是否观察到 V1 到 V3 (1) 或未观察到 (0)。
此外,每个字段都应根据行的类别进行着色。
例如Cat
为A,则为红色;如果Cat
是B,它应该是绿色的;如果Cat
是C,它应该是蓝色的。
因此,在这种情况下,热图第一行中的所有方块应为红色,第二行中的所有方块应为绿色。
我想在 python 中使用 seaborn 或 matplotlib 来创建绘图。
不过不知道剧情是什么类型
可能存在较少涉及的方式。以下方法遍历类别;对于每个类别,热图中都填充了相应的颜色。
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import seaborn as sns
import pandas as pd
import numpy as np
from io import StringIO
data_str = '''ID Cat V1 V2 V3
1 A 1 1 1
2 B 1 1 1
3 A 1 1 0
4 C 0 0 0'''
df = pd.read_csv(StringIO(data_str), delim_whitespace=True)
df = df.set_index('ID')
fig, ax = plt.subplots(figsize=(6, 6))
categories = ['A', 'B', 'C']
colors = ['crimson', 'lime', 'dodgerblue']
for cat, color in zip(categories, colors):
df_cat = df[['V1', 'V2', 'V3']][df['Cat'] == cat].reindex(df.index, fill_value=0)
df_cat[['V1', 'V2', 'V3']] = df_cat[['V1', 'V2', 'V3']].replace({0: np.nan})
if not np.all(df_cat.isna()):
sns.heatmap(data=df_cat, cmap=ListedColormap([color]), cbar=False, lw=1, alpha=1, ax=ax)
plt.show()
我有以下数据框。
ID Cat V1 V2 V3
1 A 1 1 1
2 B 1 1 1
3 A 1 1 0
4 C 0 0 0
我想创建一个图(类似于热图)来显示是否观察到 V1 到 V3 (1) 或未观察到 (0)。
此外,每个字段都应根据行的类别进行着色。
例如Cat
为A,则为红色;如果Cat
是B,它应该是绿色的;如果Cat
是C,它应该是蓝色的。
因此,在这种情况下,热图第一行中的所有方块应为红色,第二行中的所有方块应为绿色。
我想在 python 中使用 seaborn 或 matplotlib 来创建绘图。 不过不知道剧情是什么类型
可能存在较少涉及的方式。以下方法遍历类别;对于每个类别,热图中都填充了相应的颜色。
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import seaborn as sns
import pandas as pd
import numpy as np
from io import StringIO
data_str = '''ID Cat V1 V2 V3
1 A 1 1 1
2 B 1 1 1
3 A 1 1 0
4 C 0 0 0'''
df = pd.read_csv(StringIO(data_str), delim_whitespace=True)
df = df.set_index('ID')
fig, ax = plt.subplots(figsize=(6, 6))
categories = ['A', 'B', 'C']
colors = ['crimson', 'lime', 'dodgerblue']
for cat, color in zip(categories, colors):
df_cat = df[['V1', 'V2', 'V3']][df['Cat'] == cat].reindex(df.index, fill_value=0)
df_cat[['V1', 'V2', 'V3']] = df_cat[['V1', 'V2', 'V3']].replace({0: np.nan})
if not np.all(df_cat.isna()):
sns.heatmap(data=df_cat, cmap=ListedColormap([color]), cbar=False, lw=1, alpha=1, ax=ax)
plt.show()