2017-08-01

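More Pythonic way to compute the derivatives of a broadcast addition in NumPy?

Say c = a + b, where a and b are ndarrays whose shapes are not necessarily the same. That is, they can be any two arrays that follow the general broadcasting rules.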

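I have the derivative dl/dc of some output l, and I want to compute dl/da. If a and b have the same shape, then dl/da = dl/db = dl/dc. However, I might have an addition like a.shape == (3,) and b.shape == (2, 3), so that c[i][j] = a[j] + b[i][j]. This means dl/da[j] = sum_i dl/dc[i][j]. In general, dl/da is the sum of dl/dc over all axes along which a was broadcast.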
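The (3,) and (2, 3) case above can be checked by hand. This is a minimal sketch; the values of `dl_dc` are made up for illustration and do not come from a real loss function.

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0])        # shape (3,)
b = np.arange(6.0).reshape(2, 3)     # shape (2, 3)
c = a + b                            # a is broadcast along axis 0

# Hypothetical upstream gradient dl/dc, same shape as c.
dl_dc = np.array([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])

dl_da = dl_dc.sum(axis=0)   # sum over the axis along which a was broadcast
dl_db = dl_dc               # b was not broadcast, so its gradient is dl/dc itself

print(dl_da)  # [5. 7. 9.]
```

Each a[j] contributes to both c[0][j] and c[1][j], which is why the two rows of dl/dc are added together.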

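To compute the chain-rule derivatives for general a and b, I wrote the following function, but I feel it is not very Pythonic and could perhaps be done more efficiently: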

import numpy as np

def addition_derivatives(x, y, d):
    flip = False
    if x.ndim < y.ndim:  # x should have higher ndim
        flip = True
        x, y = y, x

    S = x.shape  # shape of array with higher ndim
    s = y.shape  # shape of array with lower ndim

    # figure out which axes will be broadcast in which arrays
    n = len(S)
    # impute missing ones in the shape of the smaller array as per:
    # https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules
    s = (1,) * (n - len(s)) + tuple(s)
    axis_x = []
    axis_y = []
    for i in range(n):
        assert s[i] == S[i] or s[i] == 1 or S[i] == 1
        if S[i] == 1 and s[i] != 1:
            axis_x.append(i)  # x is broadcast along axis i
        if s[i] == 1 and S[i] != 1:
            axis_y.append(i)  # y is broadcast along axis i
    axis_x, axis_y = map(tuple, (axis_x, axis_y))

    # compute the derivatives: sum d over the broadcast axes,
    # then restore the original shape of each operand
    dx = np.sum(d, axis=axis_x).reshape(x.shape)
    dy = np.sum(d, axis=axis_y).reshape(y.shape)
    if flip:
        dx, dy = dy, dx

    return dx, dy

Answer

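I actually ended up finding a way to hack this using np.broadcast_arrays and the strides attribute of the resulting arrays. I'm not sure it works in all cases, but it has worked so far, because the arrays returned by np.broadcast_arrays have a stride of 0 along every axis that was broadcast.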
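The zero-stride behaviour can be seen in a small check. This is only an illustrative sketch: the arrays `x`, `y` and `z` are made-up inputs, and the check assumes ordinary freshly allocated arrays rather than views that already carry zero strides.

```python
import numpy as np

x = np.zeros((2, 3))   # already has the broadcast shape
y = np.zeros(3)        # gains a new leading axis when broadcast
z = np.zeros((2, 1))   # its size-1 axis is stretched to length 3

bx, by, bz = np.broadcast_arrays(x, y, z)

# A stride of 0 means every index along that axis refers to the same memory.
print(bx.strides)  # no zero entries: x was not broadcast
print(by.strides)  # zero along axis 0
print(bz.strides)  # zero along axis 1
```

An input that is itself a broadcast view (for example the result of np.broadcast_to) already has zero strides, so a stride-based test could misclassify its axes.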

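import numpy as np

def addition_derivatives(x, y, d):
    # broadcast both operands to the common shape; broadcast axes get stride 0
    bx, by = np.broadcast_arrays(x, y)
    strides = list(zip(bx.strides, by.strides))
    # axes along which only x (resp. only y) was broadcast
    ax = tuple(i for i, (sx, sy) in enumerate(strides) if sx == 0 and sy != 0)
    ay = tuple(i for i, (sx, sy) in enumerate(strides) if sx != 0 and sy == 0)
    # sum d over those axes, then restore each operand's original shape
    dx = np.sum(d, axis=ax).reshape(x.shape)
    dy = np.sum(d, axis=ay).reshape(y.shape)
    return dx, dy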
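
For comparison, here is a sketch of an alternative that works from shapes alone and does not inspect strides. It is not the method from the answer above. `unbroadcast` is a hypothetical helper name, not a NumPy function, and the sketch assumes d already has the broadcast shape of x + y.

```python
import numpy as np

def unbroadcast(grad, shape):
    """Sum grad over the axes that broadcasting added or stretched.

    Hypothetical helper: `shape` is the original operand shape and
    `grad` is assumed to have the broadcast result shape.
    """
    # Leading axes that broadcasting prepended: sum them away entirely.
    extra = grad.ndim - len(shape)
    grad = grad.sum(axis=tuple(range(extra)))
    # Size-1 axes that were stretched: sum over them but keep the axis.
    stretched = tuple(i for i, n in enumerate(shape)
                      if n == 1 and grad.shape[i] != 1)
    return grad.sum(axis=stretched, keepdims=True) if stretched else grad

def addition_derivatives(x, y, d):
    # Same signature as the functions above, for easy comparison.
    return unbroadcast(d, x.shape), unbroadcast(d, y.shape)

x = np.ones((2, 1, 3))
y = np.ones((4, 1))
d = np.arange(24.0).reshape(2, 4, 3)   # made-up gradient with the shape of x + y
dx, dy = addition_derivatives(x, y, d)
print(dx.shape, dy.shape)  # (2, 1, 3) (4, 1)
```

Because it only compares shapes, this version gives the same result whether or not the inputs are views with unusual strides.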