如何在 Seaborn/Python 中的 clustermap 中将标签添加到侧面颜色条
How to add labels to the side color bar in clustermap in Seaborn/Python
我已经编写了一个 Python 脚本来绘制 clustermap。
import sys
import importlib
import matplotlib.pyplot as plt
# import PRCC function
import PRCC as prcc
import QSP_analysis as qa
#%%
import numpy as np
from pyDOE2 import lhs
# Reading data
# Expected layout of the input table; the slices below rely on these counts.
num_samples = 20  # not referenced again in the visible part of this script
num_param = 15
num_readout = 9
# read_csv hands back the header row plus the remaining rows as an array
# (string dtype by default, hence the astype(float) conversions below).
header, data = qa.read_csv('3.csv')
# Column 0 is skipped; columns 1..num_param are taken as parameters and every
# column after that as a readout.
param_names = header[1:num_param+1]
read_names = header[num_param+1:]
lhd = data[:,1:num_param+1].astype(float)
readout = data[:,num_param+1:].astype(float)
# Partial correlation of the parameters against the readouts, requesting
# Spearman correlation with Bonferroni multiple-testing correction.
# NOTE(review): the meaning of the 1e-14 argument and the shapes of the four
# results are defined inside PRCC, which is not shown here -- confirm.
Rho, Pval, Sig, Pval_correct = prcc.partial_corr(lhd, readout, 1e-14,Type = 'Spearman', MTC='Bonferroni')
# Significance markers used to annotate the heat map. The assignments go from
# the loosest to the tightest threshold, so smaller p-values overwrite '*'
# with '**' and then '***'; cells that meet no threshold keep the empty string.
# Assumes Pval_correct has shape (num_param, num_readout) -- TODO confirm.
sig_txt = np.zeros((num_param, num_readout), dtype='U8')
sig_txt[Pval_correct<5e-2] = '*'
sig_txt[Pval_correct<1e-6] = '**'
sig_txt[Pval_correct<1e-9] = '***'
# Colour strips: the first 10 parameters and the first 2 readouts get one
# colour each, the remaining entries a second colour.
param_group = ["beige"]*10 + ["khaki"]*(num_param-10)
readout_group = ["#B4B4FF"]*2+ ["mediumslateblue"]*(num_readout-2)
# Re-import QSP_analysis before plotting; presumably so that edits made to the
# module during an interactive session take effect -- confirm.
importlib.reload(qa)
# Readouts are passed as row labels and parameters as column labels, hence the
# transposes. Clustering is switched off so rows/columns keep their input order.
# NOTE(review): the plotted values are Pval, while cluster_map pins the colour
# range to [-1, 1]; confirm that Pval (and not Rho) is what should be drawn.
cm = qa.cluster_map(np.transpose(Pval), read_names,param_names,
(10,6), cmap="bwr",
annot=np.transpose(sig_txt),
row_colors = readout_group,
col_colors = param_group,
col_cluster=False, row_cluster=False,
show_dendrogram = [False, False])
# Write the figure to heat.png at 600 dpi with a tight bounding box.
cm.savefig('heat.png',format='png', dpi=600,bbox_inches='tight')
代码中的函数如下:
def read_csv(filename, header_line=1, dtype=str):
    """Read a comma-separated file into a header row and a NumPy array.

    Fields are split on ',' and may be quoted with '|'.

    Parameters
    ----------
    filename : str or path-like
        File to read.
    header_line : int, optional
        Number of leading rows to consume as header rows. Only the last of
        them is returned. With 0 no row is consumed and the returned header
        is the empty string.
    dtype : data-type, optional
        dtype handed to ``np.asarray`` for the data rows (default ``str``).

    Returns
    -------
    header : list of str
        The last header row (or ``''`` when ``header_line`` is 0).
    data : numpy.ndarray
        All rows after the header rows.

    Raises
    ------
    StopIteration
        If the file has fewer rows than ``header_line``.
    """
    # newline='' leaves line-ending handling to the csv module, as its
    # documentation requires; otherwise '\r\n' inside quoted fields is
    # silently rewritten by universal-newline translation.
    with open(filename, newline='') as csvfile:
        reader = csv.reader(csvfile, delimiter=',', quotechar='|')
        # Skip over the header rows, remembering only the last one.
        header = ''
        for _ in range(header_line):
            header = next(reader)
        # Everything that is left is data.
        data = np.asarray(list(reader), dtype=dtype)
    return header, data
def cluster_map(data, row_label, col_label, fig_size, annot=None,
                show_dendrogram=(True, True), **kwarg):
    """Draw a seaborn clustermap with a colour scale fixed to [-1, 1].

    Parameters
    ----------
    data : 2-D array-like
        Values to plot; wrapped in a DataFrame indexed by ``row_label`` with
        columns ``col_label``.
    row_label, col_label : sequence
        Labels for the rows and columns of ``data``.
    fig_size : sequence of two numbers
        Figure width and height in inches, applied after the grid is built.
    annot : optional
        Forwarded to ``sns.clustermap`` as ``annot``; drawn verbatim because
        ``fmt`` is set to the empty string.
    show_dendrogram : sequence of two bools, optional
        Visibility of the row dendrogram (index 0) and the column dendrogram
        (index 1).
    **kwarg
        Any further keyword arguments are forwarded to ``sns.clustermap``.

    Returns
    -------
    The grid object returned by ``sns.clustermap``.
    """
    df = pd.DataFrame(data=data, index=row_label, columns=col_label)
    g = sns.clustermap(df, annot=annot, fmt='',
                       vmin=-1, vmax=1,
                       cbar_kws={"ticks": [-1, -.5, 0, .5, 1]}, **kwarg)

    # Horizontal row labels; column labels slanted and anchored on the left.
    heatmap = g.ax_heatmap
    heatmap.set_yticklabels(heatmap.get_yticklabels(), rotation=0)
    heatmap.set_xticklabels(heatmap.get_xticklabels(), rotation=-55, ha='left')

    # Only the visibility of the dendrogram axes is toggled here; whether
    # clustering happens at all is decided by the forwarded keyword arguments.
    g.ax_row_dendrogram.set_visible(show_dendrogram[0])
    g.ax_col_dendrogram.set_visible(show_dendrogram[1])

    g.fig.set_size_inches(*fig_size)
    return g
PRCC 函数是一个调用 MATLAB 并计算 PRCC 值的脚本。主脚本读取一个包含 20 行 24 列、且各列表头(header)互不相同的 csv 文件。代码的输出是一个基于其中某些列(垂直方向:read_names,水平方向:param_names)的 clustermap。
我添加了颜色条来对水平轴和垂直轴上的变量进行分类。输出图如下。如何为这些颜色条(color-bar)添加标签:水平标签("ABM" 和 "QSP")和垂直标签("端点" 和 "预处理")?
Cluster-map from the script
以下示例代码假设您在调用 sns.clustermap 时使用了 row_cluster=False, col_cluster=False,因此行和列保持其原始顺序(如果它们被重新排序,原来的分组就会被打散)。
可以使用 itertools 中的 groupby 来统计颜色列表中连续相同颜色的个数。这些个数的累积和给出了各颜色之间的边界;取相邻两个边界的平均值,即可得到放置标签的合适位置。
from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from itertools import groupby
def cluster_map(data, row_label, col_label, fig_size, annot=None, row_color_labels=None, col_color_labels=None,
                show_dendrogram=(True, True), **kwarg):
    """Draw a seaborn clustermap and label the groups of its side colour bars.

    The group labels are positioned from the order of the colours as passed
    in, so they only line up when rows/columns are not reordered (i.e. the
    caller passes ``row_cluster=False`` / ``col_cluster=False``).

    Parameters
    ----------
    data : 2-D array-like
        Values to plot; wrapped in a DataFrame indexed by ``row_label`` with
        columns ``col_label``.
    row_label, col_label : sequence
        Labels for the rows and columns of ``data``.
    fig_size : sequence of two numbers
        Figure width and height in inches, applied after the grid is built.
    annot : optional
        Forwarded to ``sns.clustermap`` as ``annot`` (``fmt`` is '').
    row_color_labels, col_color_labels : sequence of str, optional
        One label per run of consecutive equal colours in ``row_colors`` /
        ``col_colors``. ``None`` (default) adds no labels on that side.
    show_dendrogram : sequence of two bools, optional
        Visibility of the row dendrogram (index 0) and the column dendrogram
        (index 1).
    **kwarg
        Forwarded to ``sns.clustermap``. Must contain ``row_colors`` /
        ``col_colors`` when the corresponding labels are given.

    Returns
    -------
    The grid object returned by ``sns.clustermap``.

    Raises
    ------
    KeyError
        If labels are given for a side whose colours were not passed.
    """

    def group_midpoints(colors):
        # Lengths of the runs of consecutive equal colours ...
        run_lengths = [sum(1 for _ in run) for _, run in groupby(colors)]
        # ... their cumulative sum gives the borders between colour groups ...
        borders = np.cumsum([0] + run_lengths)
        # ... and the centre of each group is halfway between its two borders.
        return (borders[:-1] + borders[1:]) / 2

    df = pd.DataFrame(data=data, index=row_label, columns=col_label)
    g = sns.clustermap(df, annot=annot, fmt='',
                       vmin=-1, vmax=1,
                       cbar_kws={"ticks": [-1, -.5, 0, .5, 1]}, **kwarg)
    g.ax_heatmap.set_yticklabels(g.ax_heatmap.get_yticklabels(), va='center')

    if row_color_labels is not None:
        # x is in axes coordinates (just left of the strip), y in data units.
        for y_mid, label in zip(group_midpoints(kwarg['row_colors']), row_color_labels):
            g.ax_row_colors.text(-0.06, y_mid, label, color='black', ha='right', va='center', rotation=90,
                                 transform=g.ax_row_colors.get_yaxis_transform())

    if col_color_labels is not None:
        # x is in data units, y in axes coordinates (just above the strip).
        for x_mid, label in zip(group_midpoints(kwarg['col_colors']), col_color_labels):
            g.ax_col_colors.text(x_mid, 1.06, label, color='black', ha='center', va='bottom',
                                 transform=g.ax_col_colors.get_xaxis_transform())

    # Honour the two parameters that were previously accepted but ignored.
    g.ax_row_dendrogram.set_visible(show_dendrogram[0])
    g.ax_col_dendrogram.set_visible(show_dendrogram[1])
    g.fig.set_size_inches(*fig_size)

    # Hand the grid back so callers can keep working with it (e.g. savefig).
    return g
# Demo: a 7 x 12 matrix of uniform random values in [-1, 1), with two colour
# groups on each axis and one label per group. Clustering is disabled so the
# rows and columns stay in the order given here.
demo_values = np.random.uniform(-1, 1, size=(7, 12))
demo_rows = ['Alkaid', 'Mizar', 'Alioth', 'Megrez', 'Phecda', 'Merak', 'Dubhe']
demo_cols = list('ABCDEFGHIJKL')
demo_row_colors = ["#B4B4FF"] * 2 + ["mediumslateblue"] * (7 - 2)
demo_col_colors = ["beige"] * 10 + ["khaki"] * (12 - 10)

cluster_map(
    demo_values,
    row_label=demo_rows,
    col_label=demo_cols,
    fig_size=(12, 12),
    row_colors=demo_row_colors,
    col_colors=demo_col_colors,
    row_color_labels=["endpoint", "pretreatment"],
    col_color_labels=["ABM", "QSP"],
    row_cluster=False,
    col_cluster=False,
)
我已经编写了一个 Python 脚本来绘制 clustermap。
import sys
import importlib
import matplotlib.pyplot as plt
# import PRCC function
import PRCC as prcc
import QSP_analysis as qa
#%%
import numpy as np
from pyDOE2 import lhs
# Reading data
# Expected layout of the input table; the slices below rely on these counts.
num_samples = 20  # not referenced again in the visible part of this script
num_param = 15
num_readout = 9
# read_csv hands back the header row plus the remaining rows as an array
# (string dtype by default, hence the astype(float) conversions below).
header, data = qa.read_csv('3.csv')
# Column 0 is skipped; columns 1..num_param are taken as parameters and every
# column after that as a readout.
param_names = header[1:num_param+1]
read_names = header[num_param+1:]
lhd = data[:,1:num_param+1].astype(float)
readout = data[:,num_param+1:].astype(float)
# Partial correlation of the parameters against the readouts, requesting
# Spearman correlation with Bonferroni multiple-testing correction.
# NOTE(review): the meaning of the 1e-14 argument and the shapes of the four
# results are defined inside PRCC, which is not shown here -- confirm.
Rho, Pval, Sig, Pval_correct = prcc.partial_corr(lhd, readout, 1e-14,Type = 'Spearman', MTC='Bonferroni')
# Significance markers used to annotate the heat map. The assignments go from
# the loosest to the tightest threshold, so smaller p-values overwrite '*'
# with '**' and then '***'; cells that meet no threshold keep the empty string.
# Assumes Pval_correct has shape (num_param, num_readout) -- TODO confirm.
sig_txt = np.zeros((num_param, num_readout), dtype='U8')
sig_txt[Pval_correct<5e-2] = '*'
sig_txt[Pval_correct<1e-6] = '**'
sig_txt[Pval_correct<1e-9] = '***'
# Colour strips: the first 10 parameters and the first 2 readouts get one
# colour each, the remaining entries a second colour.
param_group = ["beige"]*10 + ["khaki"]*(num_param-10)
readout_group = ["#B4B4FF"]*2+ ["mediumslateblue"]*(num_readout-2)
# Re-import QSP_analysis before plotting; presumably so that edits made to the
# module during an interactive session take effect -- confirm.
importlib.reload(qa)
# Readouts are passed as row labels and parameters as column labels, hence the
# transposes. Clustering is switched off so rows/columns keep their input order.
# NOTE(review): the plotted values are Pval, while cluster_map pins the colour
# range to [-1, 1]; confirm that Pval (and not Rho) is what should be drawn.
cm = qa.cluster_map(np.transpose(Pval), read_names,param_names,
(10,6), cmap="bwr",
annot=np.transpose(sig_txt),
row_colors = readout_group,
col_colors = param_group,
col_cluster=False, row_cluster=False,
show_dendrogram = [False, False])
# Write the figure to heat.png at 600 dpi with a tight bounding box.
cm.savefig('heat.png',format='png', dpi=600,bbox_inches='tight')
代码中的函数如下:
def read_csv(filename, header_line=1, dtype=str):
    """Read a comma-separated file into a header row and a NumPy array.

    Fields are split on ',' and may be quoted with '|'.

    Parameters
    ----------
    filename : str or path-like
        File to read.
    header_line : int, optional
        Number of leading rows to consume as header rows. Only the last of
        them is returned. With 0 no row is consumed and the returned header
        is the empty string.
    dtype : data-type, optional
        dtype handed to ``np.asarray`` for the data rows (default ``str``).

    Returns
    -------
    header : list of str
        The last header row (or ``''`` when ``header_line`` is 0).
    data : numpy.ndarray
        All rows after the header rows.

    Raises
    ------
    StopIteration
        If the file has fewer rows than ``header_line``.
    """
    # newline='' leaves line-ending handling to the csv module, as its
    # documentation requires; otherwise '\r\n' inside quoted fields is
    # silently rewritten by universal-newline translation.
    with open(filename, newline='') as csvfile:
        reader = csv.reader(csvfile, delimiter=',', quotechar='|')
        # Skip over the header rows, remembering only the last one.
        header = ''
        for _ in range(header_line):
            header = next(reader)
        # Everything that is left is data.
        data = np.asarray(list(reader), dtype=dtype)
    return header, data
def cluster_map(data, row_label, col_label, fig_size, annot=None,
                show_dendrogram=(True, True), **kwarg):
    """Draw a seaborn clustermap with a colour scale fixed to [-1, 1].

    Parameters
    ----------
    data : 2-D array-like
        Values to plot; wrapped in a DataFrame indexed by ``row_label`` with
        columns ``col_label``.
    row_label, col_label : sequence
        Labels for the rows and columns of ``data``.
    fig_size : sequence of two numbers
        Figure width and height in inches, applied after the grid is built.
    annot : optional
        Forwarded to ``sns.clustermap`` as ``annot``; drawn verbatim because
        ``fmt`` is set to the empty string.
    show_dendrogram : sequence of two bools, optional
        Visibility of the row dendrogram (index 0) and the column dendrogram
        (index 1).
    **kwarg
        Any further keyword arguments are forwarded to ``sns.clustermap``.

    Returns
    -------
    The grid object returned by ``sns.clustermap``.
    """
    df = pd.DataFrame(data=data, index=row_label, columns=col_label)
    g = sns.clustermap(df, annot=annot, fmt='',
                       vmin=-1, vmax=1,
                       cbar_kws={"ticks": [-1, -.5, 0, .5, 1]}, **kwarg)

    # Horizontal row labels; column labels slanted and anchored on the left.
    heatmap = g.ax_heatmap
    heatmap.set_yticklabels(heatmap.get_yticklabels(), rotation=0)
    heatmap.set_xticklabels(heatmap.get_xticklabels(), rotation=-55, ha='left')

    # Only the visibility of the dendrogram axes is toggled here; whether
    # clustering happens at all is decided by the forwarded keyword arguments.
    g.ax_row_dendrogram.set_visible(show_dendrogram[0])
    g.ax_col_dendrogram.set_visible(show_dendrogram[1])

    g.fig.set_size_inches(*fig_size)
    return g
PRCC 函数是一个调用 MATLAB 并计算 PRCC 值的脚本。主脚本读取一个包含 20 行 24 列、且各列表头(header)互不相同的 csv 文件。代码的输出是一个基于其中某些列(垂直方向:read_names,水平方向:param_names)的 clustermap。
我添加了颜色条来对水平轴和垂直轴上的变量进行分类。输出图如下。如何为这些颜色条(color-bar)添加标签:水平标签("ABM" 和 "QSP")和垂直标签("端点" 和 "预处理")?
Cluster-map from the script
以下示例代码假设您在调用 sns.clustermap 时使用了 row_cluster=False, col_cluster=False,因此行和列保持其原始顺序(如果它们被重新排序,原来的分组就会被打散)。
可以使用 itertools 中的 groupby 来统计颜色列表中连续相同颜色的个数。这些个数的累积和给出了各颜色之间的边界;取相邻两个边界的平均值,即可得到放置标签的合适位置。
from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from itertools import groupby
def cluster_map(data, row_label, col_label, fig_size, annot=None, row_color_labels=None, col_color_labels=None,
                show_dendrogram=(True, True), **kwarg):
    """Draw a seaborn clustermap and label the groups of its side colour bars.

    The group labels are positioned from the order of the colours as passed
    in, so they only line up when rows/columns are not reordered (i.e. the
    caller passes ``row_cluster=False`` / ``col_cluster=False``).

    Parameters
    ----------
    data : 2-D array-like
        Values to plot; wrapped in a DataFrame indexed by ``row_label`` with
        columns ``col_label``.
    row_label, col_label : sequence
        Labels for the rows and columns of ``data``.
    fig_size : sequence of two numbers
        Figure width and height in inches, applied after the grid is built.
    annot : optional
        Forwarded to ``sns.clustermap`` as ``annot`` (``fmt`` is '').
    row_color_labels, col_color_labels : sequence of str, optional
        One label per run of consecutive equal colours in ``row_colors`` /
        ``col_colors``. ``None`` (default) adds no labels on that side.
    show_dendrogram : sequence of two bools, optional
        Visibility of the row dendrogram (index 0) and the column dendrogram
        (index 1).
    **kwarg
        Forwarded to ``sns.clustermap``. Must contain ``row_colors`` /
        ``col_colors`` when the corresponding labels are given.

    Returns
    -------
    The grid object returned by ``sns.clustermap``.

    Raises
    ------
    KeyError
        If labels are given for a side whose colours were not passed.
    """

    def group_midpoints(colors):
        # Lengths of the runs of consecutive equal colours ...
        run_lengths = [sum(1 for _ in run) for _, run in groupby(colors)]
        # ... their cumulative sum gives the borders between colour groups ...
        borders = np.cumsum([0] + run_lengths)
        # ... and the centre of each group is halfway between its two borders.
        return (borders[:-1] + borders[1:]) / 2

    df = pd.DataFrame(data=data, index=row_label, columns=col_label)
    g = sns.clustermap(df, annot=annot, fmt='',
                       vmin=-1, vmax=1,
                       cbar_kws={"ticks": [-1, -.5, 0, .5, 1]}, **kwarg)
    g.ax_heatmap.set_yticklabels(g.ax_heatmap.get_yticklabels(), va='center')

    if row_color_labels is not None:
        # x is in axes coordinates (just left of the strip), y in data units.
        for y_mid, label in zip(group_midpoints(kwarg['row_colors']), row_color_labels):
            g.ax_row_colors.text(-0.06, y_mid, label, color='black', ha='right', va='center', rotation=90,
                                 transform=g.ax_row_colors.get_yaxis_transform())

    if col_color_labels is not None:
        # x is in data units, y in axes coordinates (just above the strip).
        for x_mid, label in zip(group_midpoints(kwarg['col_colors']), col_color_labels):
            g.ax_col_colors.text(x_mid, 1.06, label, color='black', ha='center', va='bottom',
                                 transform=g.ax_col_colors.get_xaxis_transform())

    # Honour the two parameters that were previously accepted but ignored.
    g.ax_row_dendrogram.set_visible(show_dendrogram[0])
    g.ax_col_dendrogram.set_visible(show_dendrogram[1])
    g.fig.set_size_inches(*fig_size)

    # Hand the grid back so callers can keep working with it (e.g. savefig).
    return g
# Demo: a 7 x 12 matrix of uniform random values in [-1, 1), with two colour
# groups on each axis and one label per group. Clustering is disabled so the
# rows and columns stay in the order given here.
demo_values = np.random.uniform(-1, 1, size=(7, 12))
demo_rows = ['Alkaid', 'Mizar', 'Alioth', 'Megrez', 'Phecda', 'Merak', 'Dubhe']
demo_cols = list('ABCDEFGHIJKL')
demo_row_colors = ["#B4B4FF"] * 2 + ["mediumslateblue"] * (7 - 2)
demo_col_colors = ["beige"] * 10 + ["khaki"] * (12 - 10)

cluster_map(
    demo_values,
    row_label=demo_rows,
    col_label=demo_cols,
    fig_size=(12, 12),
    row_colors=demo_row_colors,
    col_colors=demo_col_colors,
    row_color_labels=["endpoint", "pretreatment"],
    col_color_labels=["ABM", "QSP"],
    row_cluster=False,
    col_cluster=False,
)