math-modeling-competition-C.../heatmap.py

27 lines
685 B
Python
Raw Normal View History

2022-07-06 06:40:41 +00:00
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
# Configure plotting once at import time with a single set_theme call.
# sns.set is an alias of sns.set_theme, and any argument it is not given is
# reset to its default. In particular, style is reset to "darkgrid", so the
# previous `sns.set(font=...)` after `set_theme(style="white")` silently
# discarded the white style.
sns.set_theme(
    style="white",
    # CJK-capable font, presumably so Chinese column names/labels render.
    font="WenQuanYi Zen Hei",
    # Render negative tick labels with an ASCII hyphen instead of U+2212.
    # This font presumably lacks that glyph; confirm if switching fonts.
    # Passed via rc so it is applied after the theme's own style settings.
    rc={"axes.unicode_minus": False},
)
def create_heatmap(dataset, *, figsize=(11, 9), vmin=-1.0, vmax=1.0, show=True):
    """Plot the lower triangle of a dataset's pairwise correlation matrix.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Input table. Only numeric and boolean columns take part in the
        correlation; other columns are ignored.
    figsize : tuple of float, optional
        Figure size in inches, passed to ``plt.subplots``.
    vmin, vmax : float, optional
        Colour-scale limits. They default to the full range of a Pearson
        correlation coefficient, [-1, 1], so every coefficient gets a
        distinct colour and plots are comparable across datasets.
    show : bool, optional
        Call ``plt.show()`` after drawing (the original behaviour). Pass
        ``False`` to keep the figure open for further edits or saving.

    Returns
    -------
    matplotlib.axes.Axes
        The axes the heatmap was drawn on.

    Raises
    ------
    ValueError
        If ``dataset`` has no numeric or boolean columns.
    """
    # Select numeric/bool columns explicitly. pandas < 2.0 silently dropped
    # other columns in corr(), but pandas >= 2.0 raises on them instead.
    numeric = dataset.select_dtypes(include=["number", "bool"])
    if numeric.shape[1] == 0:
        raise ValueError("dataset has no numeric columns to correlate")

    # Pairwise correlation matrix of the columns (pandas default: Pearson).
    corr = numeric.corr()

    # Mask the upper triangle and the diagonal (self-correlation). The matrix
    # is symmetric, so the lower triangle alone carries all the information.
    mask = np.triu(np.ones_like(corr, dtype=bool))

    _, ax = plt.subplots(figsize=figsize)

    # Diverging colormap: hue 230 (blue) for negative, hue 20 (red) for
    # positive values, centred on 0 below.
    cmap = sns.diverging_palette(230, 20, as_cmap=True)

    # The previous hard-coded vmax=.3 (from the seaborn gallery example)
    # clipped every correlation above 0.3 to the same colour.
    sns.heatmap(
        corr,
        mask=mask,
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        center=0,
        square=True,
        linewidths=0.5,
        cbar_kws={"shrink": 0.5},
        ax=ax,
    )

    if show:
        plt.show()
    return ax