在 Python 中使用 seaborn.PairGrid 生成相关矩阵时,如何为对角线上的直方图添加标题

Titles for the histograms on the diagonal when using seaborn.PairGrid in Python to generate a correlation matrix

在花了一些时间弄清 PairGrid 的工作原理之后,我已经接近完成。下面是生成我想要的图的代码,但 histfunc 中还缺少一个小细节:我希望为对角线上绘制的直方图加上标题。如何将数据框的列名传递给 histfunc?如有任何想法,不胜感激。

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import dcor
import random
from scipy.stats import linregress
from matplotlib import rc

# Global matplotlib font settings, applied through rc() to everything
# plotted afterwards.
# 'sans-serif' replaces the former 'normal': 'normal' is a font weight/style
# value, not a font family name, so matplotlib could not resolve it and
# fell back to its default font with a warning.
font = {
    'family': 'sans-serif',
    'weight': 'normal',
    'size': 4,
}
rc('font', **font)

def corrmat(data, dc=False):
    """Draw a pairwise overview grid for the columns of ``data``.

    Layout of the ``seaborn.PairGrid``:

    * upper triangle: scatter plots with a degree-5 polynomial fit in red,
    * diagonal: histograms,
    * lower triangle: a text annotation with the association strength and
      a significance marker.

    Args:
        data: ``pandas.DataFrame`` handed to ``seaborn.PairGrid``.
        dc: If true, the annotation shows the distance correlation
            (``dcor``); otherwise the r2 of a linear regression.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    def cm2inch(value):
        """Helper for plotting: convert centimetres to inches."""
        return value / 2.54

    def dist_corr(X, Y, pval=True, nruns=100):
        """Return the distance correlation of X and Y.

        With ``pval`` true the result is ``(dc, p)``, where ``p`` is
        element 0 of dcor's distance covariance test result (used as the
        p-value) computed with ``nruns`` resamples. With ``pval`` false
        only the distance correlation is returned and the test is skipped.
        """
        dc_value = dcor.distance_correlation(X, Y)
        if not pval:
            return dc_value
        test = dcor.independence.distance_covariance_test(
            X, Y, exponent=1.0, num_resamples=nruns)
        return (dc_value, test[0])

    def linreg(X, Y, pval=True):
        """Return r2 of the linear regression of Y on X, optionally with p.

        The regression is fitted once and both values are read from the
        same result.
        """
        fit = linregress(X, Y)
        r2 = fit.rvalue ** 2
        if pval:
            return (r2, fit.pvalue)
        return r2

    def scatterfunc(x, y, **kws):
        """Scatter plot of x against y with a degree-5 polynomial fit in red."""
        plt.scatter(x, y, linewidths=1, facecolor="k", s=10, alpha=0.5)
        model = np.poly1d(np.polyfit(x, y, 5))
        # Evaluate the fit on sorted x so the line is drawn left to right.
        xs = np.sort(x)
        plt.plot(xs, model(xs), 'r-')

    def histfunc(x, **kws):
        """Histogram of x with 30 bins."""
        plt.hist(x, bins=30, color="black", ec="white")
        # NOTE: this is where the question asks for something like
        # plt.title(label). According to the author, **kws only carries
        # the label as a string and not the column name as a usable
        # parameter, so no title is set here.

    def corrfunc(x, y, dc=False, **kws):
        """Annotate the current axes with the association of x and y.

        Text colour and size grow with the effect size ``d`` (r2, or the
        distance correlation when ``dc`` is true); the second line of the
        annotation is a significance marker derived from ``p``.
        """
        if dc:
            d, p = dist_corr(x, y)
        else:
            d, p = linreg(x, y)

        # Style by effect size. The bands are checked from the top down and
        # end in a plain else, so every value (including NaN) gets a style.
        # The two upper bands formerly tested p instead of d.
        if d >= 0.75:
            pclr, fontsize = 'Red', 30
        elif d >= 0.5:
            pclr, fontsize = 'Orange', 25
        elif d >= 0.25:
            pclr, fontsize = 'Blue', 20
        else:
            pclr, fontsize = 'Black', 16

        # Significance marker. The final else also covers p == 0.05 and
        # NaN, which the previous chain left without a marker.
        if p < 0.001:
            ptext = "***"
        elif p < 0.01:
            ptext = "**"
        elif p < 0.05:
            ptext = "*"
        else:
            ptext = "n.sig"

        prefix = 'DC: ' if dc else 'r2: '
        ax = plt.gca()
        ax.annotate(''.join([prefix, str(np.round(d, 2)), '\n\n    ', ptext]),
                    xy=(0.3, 0.3),
                    xycoords=ax.transAxes,
                    color=pclr,
                    fontsize=fontsize)

        # Text-only cell: hide ticks and frame.
        plt.axis('off')

    # NOTE(review): PairGrid appears to create its own figure, in which case
    # this call only opens an additional empty one - confirm before removing.
    plt.figure(num=None, figsize=(cm2inch(15), cm2inch(10)), dpi=300,
               facecolor='w', edgecolor='k')
    g = sns.PairGrid(data, diag_sharey=False)
    g.map_upper(scatterfunc)
    g.map_diag(histfunc)
    # Extra keyword arguments of map_lower are forwarded to corrfunc.
    g.map_lower(corrfunc, dc=dc)
    plt.tight_layout()
    plt.show()


########

# Demo: a 1000 x 10 frame of uniform random numbers with columns "0".."9".
data = pd.DataFrame(np.random.random([1000, 10]),
                    columns=[str(i) for i in range(10)])
for i, col in enumerate(data):
    # Columns 0 and 1 are left untouched. Every later column is, with
    # probability 0.5, multiplied by a randomly chosen earlier column so
    # that the grid has some dependent pairs to show.
    if i > 1 and np.random.random() > 0.5:
        # random.randrange(i) draws uniformly from 0..i-1. It replaces
        # random.sample(set(...), 1)[0], which raises TypeError on
        # Python 3.11+ because sampling from a set is no longer supported.
        data[col] = data[col] * data.iloc[:, random.randrange(i)]
corrmat(data)

它生成的是

感谢 @ImportanceOfBeingErnest 在这里的评论。下面是更新后的脚本,供可能觉得有用的人参考。我还把散点图移到了下三角("lower"),以便坐标轴标签可见。

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import dcor
import random
from scipy.stats import linregress
from matplotlib import rc

# Global matplotlib font settings, applied through rc() to everything
# plotted afterwards.
# 'sans-serif' replaces the former 'normal': 'normal' is a font weight/style
# value, not a font family name, so matplotlib could not resolve it and
# fell back to its default font with a warning.
font = {
    'family': 'sans-serif',
    'weight': 'normal',
    'size': 16,
}
rc('font', **font)

def corrmat(data, dc=False):
    """Draw a pairwise overview grid for the columns of ``data``.

    Layout of the ``seaborn.PairGrid``:

    * lower triangle: scatter plots with a degree-5 polynomial fit in red,
    * diagonal: histograms, each titled with its column name,
    * upper triangle: a text annotation with the association strength and
      a significance marker.

    Args:
        data: ``pandas.DataFrame`` handed to ``seaborn.PairGrid``; its
            column labels are used as the diagonal titles.
        dc: If true, the annotation shows the distance correlation
            (``dcor``); otherwise the r2 of a linear regression.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    def cm2inch(value):
        """Helper for plotting: convert centimetres to inches."""
        return value / 2.54

    def dist_corr(X, Y, pval=True, nruns=100):
        """Return the distance correlation of X and Y.

        With ``pval`` true the result is ``(dc, p)``, where ``p`` is
        element 0 of dcor's distance covariance test result (used as the
        p-value) computed with ``nruns`` resamples. With ``pval`` false
        only the distance correlation is returned and the test is skipped.
        """
        dc_value = dcor.distance_correlation(X, Y)
        if not pval:
            return dc_value
        test = dcor.independence.distance_covariance_test(
            X, Y, exponent=1.0, num_resamples=nruns)
        return (dc_value, test[0])

    def linreg(X, Y, pval=True):
        """Return r2 of the linear regression of Y on X, optionally with p.

        The regression is fitted once and both values are read from the
        same result.
        """
        fit = linregress(X, Y)
        r2 = fit.rvalue ** 2
        if pval:
            return (r2, fit.pvalue)
        return r2

    def scatterfunc(x, y, **kws):
        """Scatter plot of x against y with a degree-5 polynomial fit in red."""
        plt.scatter(x, y, linewidths=1, facecolor="k", s=10, alpha=0.5)
        model = np.poly1d(np.polyfit(x, y, 5))
        # Evaluate the fit on sorted x so the line is drawn left to right.
        xs = np.sort(x)
        plt.plot(xs, model(xs), 'r-')

    def histfunc(x, **kws):
        """Histogram of x with 30 bins."""
        plt.hist(x, bins=30, color="black", ec="white")

    def corrfunc(x, y, dc=False, **kws):
        """Annotate the current axes with the association of x and y.

        Text colour and size grow with the effect size ``d``; the second
        line of the annotation is a significance marker derived from ``p``.
        The ``dc`` parameter determines whether distance correlation or
        linear regression is applied.
        """
        if dc:
            d, p = dist_corr(x, y)
        else:
            d, p = linreg(x, y)

        # Style by effect size. The bands are checked from the top down and
        # end in a plain else, so every value (including NaN) gets a style.
        # The two upper bands formerly tested p instead of d.
        if d >= 0.75:
            pclr, fontsize = 'Red', 30
        elif d >= 0.5:
            pclr, fontsize = 'Orange', 25
        elif d >= 0.25:
            pclr, fontsize = 'Blue', 20
        else:
            pclr, fontsize = 'Black', 16

        # Significance marker. The final else also covers p == 0.05 and
        # NaN, which the previous chain left without a marker.
        if p < 0.001:
            ptext = "***"
        elif p < 0.01:
            ptext = "**"
        elif p < 0.05:
            ptext = "*"
        else:
            ptext = "n.sig"

        prefix = 'DC: ' if dc else 'r2: '
        ax = plt.gca()
        ax.annotate(''.join([prefix, str(np.round(d, 2)), '\n\n    ', ptext]),
                    xy=(0.3, 0.3),
                    xycoords=ax.transAxes,
                    color=pclr,
                    fontsize=fontsize)

        # Text-only cell: hide ticks and frame.
        plt.axis('off')

    def make_diag_titles(g, titles):
        """Give the i-th diagonal axes of grid ``g`` the title ``titles[i]``.

        Returns ``g`` so the call can be chained or reassigned.
        """
        # NOTE(review): assumes titles[i] belongs to the i-th row/column of
        # the grid, i.e. that every column of the data is plotted - confirm
        # for frames with columns PairGrid may skip.
        for i in range(len(g.axes)):
            g.axes[i][i].set_title(titles[i])
        return g

    # Assemble the plot.
    # NOTE(review): PairGrid appears to create its own figure, in which case
    # this call only opens an additional empty one - confirm before removing.
    plt.figure(num=None, figsize=(cm2inch(15), cm2inch(10)), dpi=300,
               facecolor='w', edgecolor='k')
    g = sns.PairGrid(data, diag_sharey=False)
    g.map_lower(scatterfunc)
    g.map_diag(histfunc)
    # Extra keyword arguments of map_upper are forwarded to corrfunc.
    g.map_upper(corrfunc, dc=dc)
    g = make_diag_titles(g, data.columns)
    plt.tight_layout()
    plt.show()


########

# Demo: a 1000 x 10 frame of uniform random numbers with columns "0".."9".
data = pd.DataFrame(np.random.random([1000, 10]),
                    columns=[str(i) for i in range(10)])
for i, col in enumerate(data):
    # Columns 0 and 1 are left untouched. Every later column is, with
    # probability 0.5, multiplied by a randomly chosen earlier column so
    # that the grid has some dependent pairs to show.
    if i > 1 and np.random.random() > 0.5:
        # random.randrange(i) draws uniformly from 0..i-1. It replaces
        # random.sample(set(...), 1)[0], which raises TypeError on
        # Python 3.11+ because sampling from a set is no longer supported.
        data[col] = data[col] * data.iloc[:, random.randrange(i)]
corrmat(data)