27 lines
685 B
Python
27 lines
685 B
Python
|
import numpy as np
|
||
|
import seaborn as sns
|
||
|
import matplotlib.pyplot as plt
|
||
|
|
||
|
# Configure the global plot theme in a single call. `sns.set` is an alias of
# `sns.set_theme`, which resets every option not passed to its default; the
# former `sns.set(font=...)` issued after `set_theme(style="white")` therefore
# silently switched the style back to "darkgrid".
# "WenQuanYi Zen Hei" is presumably chosen so Chinese labels render — confirm.
sns.set_theme(style="white", font="WenQuanYi Zen Hei")

# Draw negative numbers with an ASCII hyphen instead of the Unicode minus sign
# (presumably because the chosen font lacks that glyph). Set after the theme
# call so no later theme reset can undo it.
plt.rc('axes', unicode_minus=False)
def create_heatmap(dataset, *, vmax=.3, figsize=(11, 9), show=True):
    """Draw a lower-triangle heatmap of the correlation matrix of *dataset*.

    Parameters
    ----------
    dataset :
        Object whose ``corr()`` method returns a square correlation matrix
        (presumably a pandas ``DataFrame`` — TODO confirm with callers).
    vmax : float, optional
        Upper limit of the colour scale; values above it are drawn with the
        most saturated colour. Defaults to ``0.3`` (the original hard-coded
        value); pass ``1`` to cover the full range of correlation
        coefficients without saturation.
    figsize : tuple of float, optional
        Figure size in inches passed to ``plt.subplots``; default ``(11, 9)``.
    show : bool, optional
        Whether to call ``plt.show()`` after drawing (default ``True``, as
        before).

    Returns
    -------
    matplotlib.axes.Axes
        The axes the heatmap was drawn on.
    """
    # Compute the correlation matrix.
    corr = dataset.corr()

    # Mask the upper triangle, diagonal included (np.triu with the default
    # k=0): it mirrors the lower triangle, so only the lower half is drawn.
    mask = np.triu(np.ones_like(corr, dtype=bool))

    # Create the figure; only the axes object is needed.
    _, ax = plt.subplots(figsize=figsize)

    # Diverging colormap so negative and positive correlations get
    # contrasting hues.
    cmap = sns.diverging_palette(230, 20, as_cmap=True)

    # Draw with the mask; square=True keeps cells square and center=0 centres
    # the diverging colormap on zero correlation. Pass ax explicitly rather
    # than relying on the implicit "current axes".
    sns.heatmap(corr, mask=mask, cmap=cmap, vmax=vmax, center=0,
                square=True, linewidths=.5, cbar_kws={"shrink": .5}, ax=ax)

    if show:
        plt.show()
    return ax