2015-11-20 69 views
6

如何更新可滾動 matplotlib multiplot 中的藝術家(artist)?

我想以這個問題的答案爲基礎,創建一個可滾動的 multiplot:Creating a scrollable multiplot with python's pylab

使用 ax.plot() 創建的線條可以正確更新,但我不知道如何更新使用 vlines() 和 fill_between() 創建的藝術家。

import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt 
import matplotlib.gridspec as gridspec 
from matplotlib.widgets import Slider 

# Create 100 random 30-row frames, keyed by frame number 0..99.
dfs = {}
for x in range(100):
    # Noisy signal around 10.
    col1 = np.random.normal(10, 0.5, 30)
    # Split the rows into three consecutive runs (levels 5, 8, 7) of random
    # length. np.round returns floats, and np.repeat only accepts integer
    # repeat counts, so cast explicitly before repeating.
    run_lengths = np.round(np.random.dirichlet(np.ones(3), size=1) * 31)[0].astype(int)
    col2 = np.repeat([5, 8, 7], run_lengths)[:30]
    # Integer heights in 0..3 for the vertical stems.
    col3 = np.random.randint(4, size=30)
    dfs[x] = pd.DataFrame({'col1': col1, 'col2': col2, 'col3': col3})

# Figure with a single axes laid out through a 1x1 GridSpec.
fig = plt.figure()
gs = gridspec.GridSpec(1, 1, hspace=0, wspace=0, left=0.1, bottom=0.1)
ax = plt.subplot(gs[0])
ax.set_ylim([0, 12])

# Slider (range 0..99) used to pick the frame to display.
frame = 0
axframe = plt.axes([0.13, 0.02, 0.75, 0.03])
sframe = Slider(axframe, 'frame', 0, 99, valinit=0, valfmt='%d')

# Everything below is drawn from frame 0.
_first = dfs[0]
_band_center = _first['col2']

# Line artists; the handles are kept so their y-data can be replaced later.
(ln1,) = ax.plot(_first.index, _first['col1'])
(ln2,) = ax.plot(_first.index, _band_center, c='black')

# One translucent band of height 1 around 'col2' per level, each in its own
# colour and only where 'col2' equals that level.
for _level, _shade in ((5, 'r'), (8, 'b'), (7, 'g')):
    ax.fill_between(
        _first.index,
        y1=_band_center - 0.5,
        y2=_band_center + 0.5,
        where=_band_center == _level,
        facecolor=_shade,
        edgecolors='none',
        alpha=0.5,
    )

# Vertical stems from 0 up to 'col3'.
ax.vlines(x=_first['col3'].index, ymin=0, ymax=_first['col3'], color='black')

# Slider callback and hookup.
def update(val):
    """Show the frame currently selected on the slider.

    Only the two line artists and the title are refreshed; the
    fill_between() and vlines() artists are not touched here.

    val: value passed in by the slider (unused; the slider is read directly).
    """
    frame = np.floor(sframe.val)
    shown = dfs[frame]
    for line, column in ((ln1, 'col1'), (ln2, 'col2')):
        line.set_ydata(shown[column])
    ax.set_title('Frame {}'.format(int(frame)))
    plt.draw()


# Run update() on every slider change, then start the GUI loop.
sframe.on_changed(update)
plt.show()

目前的效果如下圖所示: enter image description here

我無法對 fill_between() 採用與 plot() 相同的做法(用逗號解包返回值),因爲這樣會產生下面的錯誤信息:

ln3,=ax.fill_between(dfs[0].index,y1=dfs[0]['col2']-0.5,y2=dfs[0]['col2']+0.5,where=dfs[0]['col2']==5,facecolor='r',edgecolors='none',alpha=0.5) 
TypeError: 'PolyCollection' object is not iterable 

這是我希望每一幀呈現的效果: enter image description here

回答

1

fill_between 返回一個 PolyCollection,它在創建時需要一個(或多個)頂點列表。不幸的是,我還沒有找到辦法取回用於創建給定 PolyCollection 的頂點;但在您的情況下,直接創建 PolyCollection(從而避免使用 fill_between),然後在切換幀時更新其頂點,是很容易的。

下面這個版本的代碼可以實現你想要的效果:

import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt 
import matplotlib.gridspec as gridspec 
from matplotlib.widgets import Slider 

from matplotlib.collections import PolyCollection 

# Create 100 random 30-row frames, keyed by frame number 0..99.
dfs = {}
for x in range(100):
    # Noisy signal around 10.
    col1 = np.random.normal(10, 0.5, 30)
    # Split the rows into three consecutive runs (levels 5, 8, 7) of random
    # length. np.round returns floats, and np.repeat only accepts integer
    # repeat counts, so cast explicitly before repeating.
    run_lengths = np.round(np.random.dirichlet(np.ones(3), size=1) * 31)[0].astype(int)
    col2 = np.repeat([5, 8, 7], run_lengths)[:30]
    # Integer heights in 0..3 for the vertical stems.
    col3 = np.random.randint(4, size=30)
    dfs[x] = pd.DataFrame({'col1': col1, 'col2': col2, 'col3': col3})

# Figure with a single axes laid out through a 1x1 GridSpec.
fig = plt.figure()
gs = gridspec.GridSpec(1, 1, hspace=0, wspace=0, left=0.1, bottom=0.1)
ax = plt.subplot(gs[0])
ax.set_ylim([0, 12])

# Slider (range 0..99) used to pick the frame to display.
frame = 0
axframe = plt.axes([0.13, 0.02, 0.75, 0.03])
sframe = Slider(axframe, 'frame', 0, 99, valinit=0, valfmt='%d')

# Line artists drawn from frame 0; the handles are kept so their y-data can
# be replaced when the slider moves.
_first = dfs[0]
(ln1,) = ax.plot(_first.index, _first['col1'])
(ln2,) = ax.plot(_first.index, _first['col2'], c='black')

# Levels of 'col2' highlighted by the red, blue and green bands.
val_r = 5
val_b = 8
val_g = 7


def update_collection(collection, value, frame=0):
    """Make `collection` highlight the rows of a frame where 'col2' == value.

    One rectangle of height 1 (value - 0.5 to value + 0.5) is produced for
    every contiguous run of matching rows, spanning from the index of the
    first row of the run to the index of the last one. Gaps between runs are
    left unpainted. If no row matches, the collection is emptied.

    collection: object with a ``set_verts`` method (e.g. a PolyCollection).
    value: level of 'col2' to highlight.
    frame: key into the module-level ``dfs`` dict.
    """
    xs = np.asarray(dfs[frame].index)
    matches = np.asarray(dfs[frame]['col2']) == value

    # Pad with False on both sides so runs touching either end of the data
    # still yield a rising and a falling edge.
    padded = np.concatenate(([False], matches, [False]))
    edges = np.flatnonzero(padded[1:] != padded[:-1])
    # Edges alternate rising/falling: a rising edge is the first row of a
    # run, a falling edge is one past the last row.
    firsts, lasts = edges[::2], edges[1::2] - 1

    low, high = value - 0.5, value + 0.5
    rectangles = [
        np.array([[xs[a], low], [xs[b], low], [xs[b], high], [xs[a], high]])
        for a, b in zip(firsts, lasts)
    ]
    collection.set_verts(rectangles)

# Artists: one PolyCollection per highlighted level, replacing the three
# fill_between() calls so the vertices can be updated on every frame.
def _add_band(color, value):
    """Create a translucent band collection, add it to `ax`, fill it for frame 0.

    edgecolors='none' and alpha=0.5 mirror the fill_between() calls these
    collections stand in for.
    """
    band = PolyCollection([], facecolors=[color], edgecolors='none', alpha=0.5)
    ax.add_collection(band)
    update_collection(band, value)
    return band


reds = _add_band('r', val_r)
blues = _add_band('b', val_b)
greens = _add_band('g', val_g)

# Vertical stems from 0 up to 'col3' for frame 0. The returned LineCollection
# is kept so update() can move its segments when the frame changes.
stems = ax.vlines(x=dfs[0]['col3'].index, ymin=0, ymax=dfs[0]['col3'], color='black')


def _stem_segments(frame):
    """Return one [(x, 0), (x, height)] segment per row of the frame's 'col3'."""
    heights = dfs[frame]['col3']
    return [[(x, 0), (x, top)] for x, top in zip(heights.index, heights)]


# update plots
def update(val):
    """Slider callback: redraw every artist for the selected frame.

    Refreshes the two lines, the title, the three band collections and the
    vertical stems.

    val: value passed in by the slider (unused; the slider is read directly).
    """
    # Slider values are floats; use a real int as the key into dfs.
    frame = int(np.floor(sframe.val))
    ln1.set_ydata(dfs[frame]['col1'])
    ln2.set_ydata(dfs[frame]['col2'])
    ax.set_title('Frame ' + str(frame))

    # updating the PolyCollections:
    update_collection(reds, val_r, frame)
    update_collection(blues, val_b, frame)
    update_collection(greens, val_g, frame)

    # updating the vlines() stems:
    stems.set_segments(_stem_segments(frame))

    plt.draw()


# connect callback to slider
sframe.on_changed(update)
plt.show()

對於這裏的數據,三個 PolyCollection(reds、blues 和 greens)在對應的值存在時各自只包含一個由四個頂點(矩形的四個角)構成的矩形,這些頂點根據給定的數據確定(在 update_collection 中完成)。結果是這樣的:

example result of given code

已在 Python 3.5 上測試。

相關問題