2017-08-31 54 views
1

我有一個DataFrame a和系列b。我想找到ab的每一列的條件關聯,條件是b的值。具體來說,我使用pd.cutb分成5組。但是,我沒有使用標準分位數,而是使用高於或低於平均值的b的標準偏差。Groupby + DataFrame和系列之間的相關性

np.random.seed(123) 

a = (pd.DataFrame(np.random.randn(1000,3)) 
    .add_prefix('col')) 
b = pd.Series(np.random.randn(1000)) 

mu, sigma = b.mean(), b.std() 
breakpoints = mu + np.array([-2., -1., 1., 2.]) * sigma 
breakpoints = np.append(np.insert(breakpoints, 0, -np.inf), np.inf) 
# There are now 6 breakpoints to create 5 groupings: 
# array([  -inf, -1.91260048, -0.9230609 , 1.05601827, 2.04555785, 
#    inf]) 

labels = ['[-inf,-2]', '(-2,-1]', '(-1,1]', '(1,2]', '(2,inf]'] 
groups = pd.cut(b, bins=breakpoints, labels=labels) 

在這裏一切都很好。我在最後一行掛了,使用.corrwith.groupby,它拋出一個ValueError

a.groupby(groups).corrwith(b.groupby(groups)) 

任何想法? a.corrwith(b)的結果是一個系列,所以我認爲這裏的結果應該是一個以組/桶爲列的DataFrame。例如,一列是:

print(a[b < breakpoints[1]].corrwith(b[b < breakpoints[1]])) 
# Correlation conditional on that `b` is [-inf, -2 stdev] 
col0 0.43708 
col1 -0.08440 
col2 -0.02923 
dtype: float64 

回答

0

一個解決方案,它的功能,但並不漂亮:

full = a.join(b.to_frame(name='_drop')) 
corrs = (full.groupby(groups) 
     .corr() 
     .loc[(slice(None), a.columns), '_drop'] 
     .unstack() 
     .T) 

print(corrs) 
     [-inf,-2] (-2,-1] (-1,1] (1,2] (2,inf] 
col0 0.43708 0.06716 0.02437 0.01695 0.05384 
col1 -0.08440 0.04208 0.05529 -0.07146 0.14766 
col2 -0.02923 -0.19672 0.01519 -0.02290 -0.17101