2014-02-09 42 views
47

我想在pyplot中使用Pandas DataFrame對象製作一個簡單的散點圖,但希望繪製兩個變量的有效方式,但需要第三列指定的符號(鍵)。我嘗試過使用df.groupby的各種方式,但沒有成功。下面是一個示例df腳本。這根據'key1'標記顏色,但是id喜歡看'key1'類別的圖例。我關門了嗎?謝謝。Pandas中的散點圖/ Pyplot:如何按類別繪圖

import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt 
df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three')) 
df['key1'] = (4,4,4,6,6,6,8,8,8,8) 
fig1 = plt.figure(1) 
ax1 = fig1.add_subplot(111) 
ax1.scatter(df['one'], df['two'], marker = 'o', c = df['key1'], alpha = 0.8) 
plt.show() 

回答

64

您可以使用scatter此,但這需要您的數值爲key1,並且您將不會看到圖例,正如您注意到的那樣。

對於像這樣的離散類別,最好使用plot。例如:

import matplotlib.pyplot as plt 
import numpy as np 
import pandas as pd 
np.random.seed(1974) 

# Generate Data 
num = 20 
x, y = np.random.random((2, num)) 
labels = np.random.choice(['a', 'b', 'c'], num) 
df = pd.DataFrame(dict(x=x, y=y, label=labels)) 

groups = df.groupby('label') 

# Plot 
fig, ax = plt.subplots() 
ax.margins(0.05) # Optional, just adds 5% padding to the autoscaling 
for name, group in groups: 
    ax.plot(group.x, group.y, marker='o', linestyle='', ms=12, label=name) 
ax.legend() 

plt.show() 

enter image description here

如果你想的東西看起來像默認pandas風格,那麼就更新rcParams與大熊貓樣式表,並使用它的顏色生成。 (我還扭捏傳說略):

import matplotlib.pyplot as plt 
import numpy as np 
import pandas as pd 
np.random.seed(1974) 

# Generate Data 
num = 20 
x, y = np.random.random((2, num)) 
labels = np.random.choice(['a', 'b', 'c'], num) 
df = pd.DataFrame(dict(x=x, y=y, label=labels)) 

groups = df.groupby('label') 

# Plot 
plt.rcParams.update(pd.tools.plotting.mpl_stylesheet) 
colors = pd.tools.plotting._get_standard_colors(len(groups), color_type='random') 

fig, ax = plt.subplots() 
ax.set_color_cycle(colors) 
ax.margins(0.05) 
for name, group in groups: 
    ax.plot(group.x, group.y, marker='o', linestyle='', ms=12, label=name) 
ax.legend(numpoints=1, loc='upper left') 

plt.show() 

enter image description here

+0

是的,看起來它會爲我工作。非常感謝。 – user2989613

+0

爲什麼在上面的RGB例子中,符號在圖例中顯示兩次?如何只顯示一次? –

+1

@SteveSchulist - 使用'ax.legend(numpoints = 1)'只顯示一個標記。有兩個,就像'Line2D'一樣,經常有一條線連接兩個標記。 –

15

隨着plt.scatter,我只能想到一個:使用代理藝術家:

df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three')) 
df['key1'] = (4,4,4,6,6,6,8,8,8,8) 
fig1 = plt.figure(1) 
ax1 = fig1.add_subplot(111) 
x=ax1.scatter(df['one'], df['two'], marker = 'o', c = df['key1'], alpha = 0.8) 

ccm=x.get_cmap() 
circles=[Line2D(range(1), range(1), color='w', marker='o', markersize=10, markerfacecolor=item) for item in ccm((array([4,6,8])-4.0)/4)] 
leg = plt.legend(circles, ['4','6','8'], loc = "center left", bbox_to_anchor = (1, 0.5), numpoints = 1) 

,其結果是:

enter image description here

+0

另一個好建議。我也會試試這個。謝謝。 – user2989613

20

這就是簡單的用Seabornpip install seaborn)作爲oneliner

sns.pairplot(x_vars=["one"], y_vars=["two"], data=df, hue="key1", size=5) 做:

import seaborn as sns 
import pandas as pd 
import numpy as np 
np.random.seed(1974) 

df = pd.DataFrame(
    np.random.normal(10, 1, 30).reshape(10, 3), 
    index=pd.date_range('2010-01-01', freq='M', periods=10), 
    columns=('one', 'two', 'three')) 
df['key1'] = (4, 4, 4, 6, 6, 6, 8, 8, 8, 8) 

sns.pairplot(x_vars=["one"], y_vars=["two"], data=df, hue="key1", size=5) 

enter image description here

下面是引用數據框:

enter image description here

既然你有你的數據的三個變量列,你可能要繪製所有成對尺寸有:

sns.pairplot(vars=["one","two","three"], data=df, hue="key1", size=5) 

enter image description here

https://rasbt.github.io/mlxtend/user_guide/plotting/category_scatter/是另一種選擇。

2

您也可以嘗試Altairggpot這些重點是聲明性可視化。

import numpy as np 
import pandas as pd 
np.random.seed(1974) 

# Generate Data 
num = 20 
x, y = np.random.random((2, num)) 
labels = np.random.choice(['a', 'b', 'c'], num) 
df = pd.DataFrame(dict(x=x, y=y, label=labels)) 

牽牛星代碼

from altair import Chart 
c = Chart(df) 
c.mark_circle().encode(x='x', y='y', color='label') 

enter image description here

ggplot代碼

from ggplot import * 
ggplot(aes(x='x', y='y', color='label'), data=df) +\ 
geom_point(size=50) +\ 
theme_bw() 

enter image description here

4

您可以使用df.plot.scatter,並傳遞一個AR射線到c =參數定義每個點的顏色:

import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt 
df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three')) 
df['key1'] = (4,4,4,6,6,6,8,8,8,8) 
colors = np.where(df["key1"]==4,'r','-') 
colors[df["key1"]==6] = 'g' 
colors[df["key1"]==8] = 'b' 
print(colors) 
df.plot.scatter(x="one",y="two",c=colors) 
plt.show() 

enter image description here

0

這是相當哈克,但你可以使用one1Float64Index一氣呵成做的一切:

df.set_index('one').sort_index().groupby('key1')['two'].plot(style='--o', legend=True) 

enter image description here

請注意,截至0.20.3,sorting the index is necessary,圖例是a bit wonky