2015-12-15 84 views
3

根據下面的代碼生成熱圖,如何獲得列「D」(總列) 以顯示爲熱圖右側的列沒有顏色,只對齊每個單元的總值?我也試圖將標籤移到頂部。我不介意左邊的標籤是水平的,因爲這不會出現在我的實際數據中。如何在Seaborn熱圖旁邊添加列

import matplotlib.pyplot as plt 
import seaborn as sns 
import pandas as pd 
%matplotlib inline 
df = pd.DataFrame(
     {'A' : ['A', 'A', 'B', 'B','C', 'C', 'D', 'D'], 
     'B' : ['A', 'B', 'A', 'B','A', 'B', 'A', 'B'], 
     'C' : [2, 4, 5, 2, 0, 3, 9, 1], 
     'D' : [6, 6, 7, 7, 3, 3, 10, 10]}) 

df=df.pivot('A','B','C') 
fig, ax = plt.subplots(1, 1, figsize =(4,6)) 

sns.heatmap(df, annot=True, linewidths=0, cbar=False) 
plt.show() 

這裏是理想的結果:

Desired Result

提前感謝!

回答

3

我認爲最清潔的方法(儘管可能不是最短的方法)將繪製Total作爲其中一列,然後訪問熱圖的各個面的顏色並將其中一些顏色改爲白色。

負責熱圖上顏色的元素是matplotlib.collections.QuadMesh。它包含用於熱圖的每個方面的所有facecolors,從左到右,從下到上。

你可以修改一些顏色,並在你plt.show()之前將它們傳回QuadMesh

seaborn有一個小問題,seaborn更改了某些註釋的文本顏色,使它們在黑暗背景下可見,並且當您更改爲白色時它們變得不可見。所以現在我把所有文字的顏色設置爲黑色,你需要弄清楚什麼是最適合你的情節。

最後,把x軸蜱和頂部,使用標籤:

ax.xaxis.tick_top() 
ax.xaxis.set_label_position('top') 

代碼的最終版本:

import matplotlib.pyplot as plt 
from matplotlib.collections import QuadMesh 
from matplotlib.text import Text 

import seaborn as sns 
import pandas as pd 
import numpy as np 
%matplotlib inline 

df = pd.DataFrame(
     {'A' : ['A', 'A', 'B', 'B','C', 'C', 'D', 'D'], 
     'B' : ['A', 'B', 'A', 'B','A', 'B', 'A', 'B'], 
     'C' : [2, 4, 5, 2, 0, 3, 9, 1], 
     'D' : [6, 6, 7, 7, 3, 3, 10, 10]}) 

df=df.pivot('A','B','C') 

# create "Total" column 
df['Total'] = df['A'] + df['B'] 

fig, ax = plt.subplots(1, 1, figsize =(4,6)) 

sns.heatmap(df, annot=True, linewidths=0, cbar=False) 

# find your QuadMesh object and get array of colors 
quadmesh = ax.findobj(QuadMesh)[0] 
facecolors = quadmesh.get_facecolors() 

# make colors of the last column white 
facecolors[np.arange(2,12,3)] = np.array([1,1,1,1]) 

# set modified colors 
quadmesh.set_facecolors = facecolors 

# set color of all text to black 
for i in ax.findobj(Text): 
    i.set_color('black') 

# move x ticks and label to the top 
ax.xaxis.tick_top() 
ax.xaxis.set_label_position('top') 

plt.show() 

final figure

附:我在Python 2.7中,可能需要一些語法調整,儘管我想不出任何。