2013-10-20 58 views
3

我想畫一條直線,把一個 3D 子圖上的點連接到另一個 3D 子圖上的點。在 2D 中,用 ConnectionPatch 就能輕鬆做到。我嘗試仿照這裡(here)的 Arrow3D 類別,但沒有成功。

(問題標題:matplotlib 中用於 3D 子圖的連接補丁)

到了這個地步,即使只是一個變通方法我也很樂意接受。舉例來說,在下面程式碼所產生的圖中,我想把兩個綠點連接起來。

def cylinder(r, n):
    """
    Build ring coordinates for a unit-height "cylinder" whose radius follows r.

    INPUTS: r - a vector of radii; one ring is produced per radius
      n - number of angular segments per ring; n+1 angles from 0 to 2*pi
          are used, so the first and last point of each ring coincide

    OUTPUTS: x,y,z - coordinate arrays; for a 1-D r of length k each has
             shape (k, n+1), with ring heights evenly spaced from 0 to 1
    """
    radii = np.atleast_2d(r)
    n_rows, n_cols = radii.shape
    # force a column vector so it broadcasts against the row of angles
    radii = radii.T if n_cols > n_rows else radii

    theta = np.linspace(0, 2 * np.pi, n + 1)
    x = radii * np.cos(theta)
    y = radii * np.sin(theta)

    # every point of ring i sits at the same height
    heights = np.atleast_2d(np.linspace(0, 1, len(radii)))
    z = heights.T * np.ones((1, n + 1))

    return x, y, z


#---------------------------------------
# 3D example
#---------------------------------------
# This snippet used np and plt without importing them, so running it on its
# own raised NameError. Import them here, before first use. The Axes3D import
# registers projection='3d' on older matplotlib versions.
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

fig = plt.figure()

# top figure: cyan rings whose radius goes from 2 (z=0) to 1 (z=1)
ax = fig.add_subplot(2, 1, 1, projection='3d')
x, y, z = cylinder(np.linspace(2, 1, num=10), 40)
for xs, ys, zs in zip(x, y, z):  # one ring per row
    ax.plot(xs, ys, zs, 'c')
ax.plot([2], [0], [0], 'go')  # first green point to connect

# bottom figure: red rings whose radius goes from 0 (z=0) to 1 (z=1)
ax2 = fig.add_subplot(2, 1, 2, projection='3d')
x, y, z = cylinder(np.linspace(0, 1, num=10), 40)
for xs, ys, zs in zip(x, y, z):
    ax2.plot(xs, ys, zs, 'r')
ax2.plot([1], [0], [1], 'go')  # second green point to connect

plt.show()

回答

4

我今晚正好也在解決一個非常類似的問題!有些程式碼可能是不必要的,但希望它能讓你掌握主要的思路……

靈感來源:http://hackmap.blogspot.com.au/2008/06/pylab-matplotlib-imagemap.html,以及過去兩年來參考過的許多其他資料……

#! /usr/bin/env python 

import numpy as np 
import matplotlib.pyplot as plt 
from mpl_toolkits.mplot3d import Axes3D 
from mpl_toolkits.mplot3d import proj3d 
import matplotlib 

N = 50
x = np.random.rand(N)
y = np.random.rand(N)
z = np.random.rand(N)

# indices of the points to join
p1 = 10
p2 = 20

fig = plt.figure()

# background axis spanning the whole figure (kept from the original layout)
ax0 = plt.axes([0., 0., 1., 1.])
ax0.set_xlim(0, 1)
ax0.set_ylim(0, 1)

# first scatter plot (lower half), with its point of interest enlarged
ax1 = plt.axes([0.05, 0.05, 0.9, 0.425], projection='3d')
ax1.scatter(x, y, z)
ax1.scatter(x[p1], y[p1], z[p1], s=100.)

# second scatter plot (upper half, same data), with another point of interest
ax2 = plt.axes([0.05, 0.475, 0.9, 0.425], projection='3d')
ax2.scatter(x, y, z)
ax2.scatter(x[p2], y[p2], z[p2], s=100.)

# make the figure and 3D-axes background patches transparent
for item in [fig, ax1, ax2]:
    item.patch.set_visible(False)


def _project_to_figure(ax, xs, ys, zs):
    """Return the figure-fraction (x, y) of the first 3D data point in *ax*.

    The result reflects the layout and view at call time only. If the axes
    are rotated or the figure is resized later, the point moves but the
    returned coordinates do not.
    """
    # 3D data -> the axes' 2D projected data coordinates
    x2d, y2d, _ = proj3d.proj_transform(xs, ys, zs, ax.get_proj())
    # projected data -> display pixels -> figure fraction (0..1)
    display = ax.transData.transform(np.column_stack([x2d, y2d]))
    return fig.transFigure.inverted().transform(display)[0]


# Draw once so every axes has its final position and view before the
# transforms are queried. Layout/aspect adjustments may only be applied at
# draw time.
fig.canvas.draw()

x1, y1 = _project_to_figure(ax1, x[p1], y[p1], z[p1])
x2, y2 = _project_to_figure(ax2, x[p2], y[p2], z[p2])

# Draw the connector in figure coordinates. add_artist registers the line
# with the figure. Assigning to fig.lines would bypass that and discard any
# other figure-level lines.
line = matplotlib.lines.Line2D((x1, x2), (y1, y2), transform=fig.transFigure)
fig.add_artist(line)

plt.show()

success http://i42.tinypic.com/5bpxyo.jpg

+0

太棒了!我把我的程式碼貼在下面。它精簡了你的幾行程式碼,但大部分是相同的。謝謝! – benten

0

我的最終程式碼,僅作為一個可運行的完整範例:

#! /usr/bin/env python 

import numpy as np 
import matplotlib.pyplot as plt 
import mpl_toolkits.mplot3d.axes3d as p3 
from mpl_toolkits.mplot3d import Axes3D 
from mpl_toolkits.mplot3d import proj3d 
import matplotlib 



def cylinder(r, n):
    """
    Build ring coordinates for a unit-height "cylinder" whose radius follows r.

    INPUTS: r - a vector of radii; one ring is produced per radius
      n - number of angular segments per ring; n+1 angles from 0 to 2*pi
          are used, so the first and last point of each ring coincide

    OUTPUTS: x,y,z - coordinate arrays; for a 1-D r of length k each has
             shape (k, n+1), with ring heights evenly spaced from 0 to 1
    """
    radii = np.atleast_2d(r)
    n_rows, n_cols = radii.shape
    # force a column vector so it broadcasts against the row of angles
    radii = radii.T if n_cols > n_rows else radii

    theta = np.linspace(0, 2 * np.pi, n + 1)
    x = radii * np.cos(theta)
    y = radii * np.sin(theta)

    # every point of ring i sits at the same height
    heights = np.atleast_2d(np.linspace(0, 1, len(radii)))
    z = heights.T * np.ones((1, n + 1))

    return x, y, z



#---------------------------------------
# 3D example
#---------------------------------------
fig = plt.figure()

# background axis spanning the whole figure (kept from the original layout)
ax0 = plt.axes([0., 0., 1., 1.])
ax0.set_xlim(0, 1)
ax0.set_ylim(0, 1)

# top figure: cyan rings
ax1 = fig.add_subplot(2, 1, 1, projection='3d')
x, y, z = cylinder(np.linspace(2, 1, num=10), 40)
for xs, ys, zs in zip(x, y, z):  # one ring per row
    ax1.plot(xs, ys, zs, 'c')

# bottom figure: red rings
ax2 = fig.add_subplot(2, 1, 2, projection='3d')
x, y, z = cylinder(np.linspace(0, 1, num=10), 40)
for xs, ys, zs in zip(x, y, z):
    ax2.plot(xs, ys, zs, 'r')

# points of interest, one per subplot
p1 = ([2], [0], [0])
ax1.plot(p1[0], p1[1], p1[2], 'go')
p2 = ([1], [0], [1])
ax2.plot(p2[0], p2[1], p2[2], 'go')


def _project_to_figure(ax, xs, ys, zs):
    """Return the figure-fraction (x, y) of the first 3D data point in *ax*.

    The result reflects the layout and view at call time only. If the axes
    are rotated or the figure is resized later, the point moves but the
    returned coordinates do not.
    """
    # 3D data -> the axes' 2D projected data coordinates
    x2d, y2d, _ = proj3d.proj_transform(xs, ys, zs, ax.get_proj())
    # projected data -> display pixels -> figure fraction (0..1)
    display = ax.transData.transform(np.column_stack([x2d, y2d]))
    return fig.transFigure.inverted().transform(display)[0]


# Draw once so both subplots have their final position and view before the
# transforms are queried. Layout/aspect adjustments may only be applied at
# draw time.
fig.canvas.draw()

x1, y1 = _project_to_figure(ax1, *p1)
x2, y2 = _project_to_figure(ax2, *p2)

# Plot the line between the subplots as a single figure-level artist. The
# original created it with ax0.plot and also assigned it to fig.lines, so the
# same Line2D belonged to both ax0 and the figure and was drawn twice.
line = matplotlib.lines.Line2D((x1, x2), (y1, y2),
                               transform=fig.transFigure, linestyle='dashed')
fig.add_artist(line)

plt.show()