2017-09-26 150 views
-1

當對np.broadcast_to()的簡單調用失敗時,將兩個數組廣播到一起的最佳方式是什麼?Python/NumPy中的多維廣播 - 或者與numpy.squeeze()相反

考慮下面的例子:

import numpy as np 

arr1 = np.arange(2 * 3 * 4 * 5 * 6).reshape((2, 3, 4, 5, 6)) 
arr2 = np.arange(3 * 5).reshape((3, 5)) 

arr1 + arr2 
# ValueError: operands could not be broadcast together with shapes (2,3,4,5,6) (3,5) 

arr2_ = np.broadcast_to(arr2, arr1.shape) 
# ValueError: operands could not be broadcast together with remapped shapes 

arr2_ = arr2.reshape((1, 3, 1, 5, 1)) 
arr1 + arr2 
# now this works because the singletons trigger the automatic broadcast 

這僅工作,如果我手動選擇其自動廣播是去工作的形狀。 自動做這件事最有效的方法是什麼? 除了在巧妙構造的可播放形狀上重塑形狀嗎?

請注意與np.squeeze()的關係:這將通過刪除單例執行逆操作。所以我需要的是某種np.squeeze()的逆。 官方documentation(截至NumPy的1.13.0暗示的np.squeeze()逆是np.expand_dim(),但這是幾乎沒有靈活,我需要它,實際上np.expand_dim()大致相當於np.reshape(array, shape + (1,))array[:, None]

這個問題也涉及到例如通過sum接受keepdims關鍵字:

import numpy as np 

arr1 = np.arange(2 * 3 * 4 * 5 * 6).reshape((2, 3, 4, 5, 6)) 

# not using `keepdims` 
arr2 = np.sum(arr1, (0, 2, 4)) 
arr2.shape 
# : (3, 5) 

arr1 + arr2 
# ValueError: operands could not be broadcast together with shapes (2,3,4,5,6) (3,5) 


# now using `keepdims` 
arr2 = np.sum(arr1, (0, 2, 4), keepdims=True) 
arr2.shape 
# : (1, 3, 1, 5, 1) 

arr1 + arr2 
# now this works because it has the correct shape 

編輯:很顯然,在情況np.newaxiskeepdims機制是合適的選擇,則不需要unsqueeze()函數。

然而,有些用例中這些都不是一個選項。

例如,考慮在指定的任意維數上的numpy.average()中實施的加權平均的情況。 現在參數weights必須與輸入具有相同的形狀。 但是,weights沒有必要指定非縮減維度上的權重,因爲它們只是重複,而NumPy的廣播機制會適當地處理它們。

所以,如果我們想有這樣的功能,我們就需要像(其中一些一致性檢查只是爲了簡化省略),以代碼的東西:

def weighted_average(arr, weights=None, axis=None): 
    if weights is not None and weights.shape != arr.shape: 
     weights = unsqueeze(weights, ...) 
     weights = np.zeros_like(arr) + weights 
    result = np.sum(arr * weights, axis=axis) 
    result /= np.sum(weights, axis=axis) 
    return result 

或等價:

def weighted_average(arr, weights=None, axis=None): 
    if weights is not None and weights.shape != arr.shape: 
     weights = unsqueeze(weights, ...) 
     weights = np.zeros_like(arr) + weights 
    return np.average(arr, weights, axis) 

在兩者中的任何一箇中,都不可能用weights[:, np.newaxis]類似的語句來代替unsqueeze(),因爲我們事先不知道需要新軸的位置,也不能使用的keepdims功能,因爲代碼將在arr * weights處失敗。

如果np.expand_dims()支持其參數axis的整數int,但是NumPy 1.13.0不支持,則可以相對較好地處理這種情況。

+2

這是非常不安全的('np.squeeze'也非常不安全,但這更糟糕)。如果任何輸入碰巧具有相同長度的多個維度,則無法告訴如何對齊這兩個數組的形狀。使用'np.newaxis'將新軸插入到正確的位置會更好,或者使用像'keepdims'這樣的機制來確保形狀首先對齊。 – user2357112

+0

您的擔心很好,但我清楚地看到一些例子都不合適。我添加了一個例子,既不能使用np.squeeze也不能使用keepdims。 – norok2

+0

這仍然是'np.newaxis'的情況。您可能必須手動構建索引元組,或者可以使用'reshape'並手動構造形狀元組,但絕對不是'unsqueeze'的工作。 – user2357112

回答

1

我的方法是通過定義以下unsqueezing()函數來處理可以自動完成並在輸入可能不明確時發出警告的情況(例如,當源形狀的某些源元素可能與目標形狀):

def unsqueezing(
     source_shape, 
     target_shape): 
    """ 
    Generate a broadcasting-compatible shape. 

    The resulting shape contains *singletons* (i.e. `1`) for non-matching dims. 
    Assumes all elements of the source shape are contained in the target shape 
    (excepts for singletons) in the correct order. 

    Warning! The generated shape may not be unique if some of the elements 
    from the source shape are present multiple timesin the target shape. 

    Args: 
     source_shape (Sequence): The source shape. 
     target_shape (Sequence): The target shape. 

    Returns: 
     shape (tuple): The broadcast-safe shape. 

    Raises: 
     ValueError: if elements of `source_shape` are not in `target_shape`. 

    Examples: 
     For non-repeating elements, `unsqueezing()` is always well-defined: 

     >>> unsqueezing((2, 3), (2, 3, 4)) 
     (2, 3, 1) 
     >>> unsqueezing((3, 4), (2, 3, 4)) 
     (1, 3, 4) 
     >>> unsqueezing((3, 5), (2, 3, 4, 5, 6)) 
     (1, 3, 1, 5, 1) 
     >>> unsqueezing((1, 3, 5, 1), (2, 3, 4, 5, 6)) 
     (1, 3, 1, 5, 1) 

     If there is nothing to unsqueeze, the `source_shape` is returned: 

     >>> unsqueezing((1, 3, 1, 5, 1), (2, 3, 4, 5, 6)) 
     (1, 3, 1, 5, 1) 
     >>> unsqueezing((2, 3), (2, 3)) 
     (2, 3) 

     If some elements in `source_shape` are repeating in `target_shape`, 
     a user warning will be issued: 

     >>> unsqueezing((2, 2), (2, 2, 2, 2, 2)) 
     (2, 2, 1, 1, 1) 
     >>> unsqueezing((2, 2), (2, 3, 2, 2, 2)) 
     (2, 1, 2, 1, 1) 

     If some elements of `source_shape` are not presente in `target_shape`, 
     an error is raised. 

     >>> unsqueezing((2, 3), (2, 2, 2, 2, 2)) 
     Traceback (most recent call last): 
      ... 
     ValueError: Target shape must contain all source shape elements\ 
(in correct order). (2, 3) -> (2, 2, 2, 2, 2) 
     >>> unsqueezing((5, 3), (2, 3, 4, 5, 6)) 
     Traceback (most recent call last): 
      ... 
     ValueError: Target shape must contain all source shape elements\ 
(in correct order). (5, 3) -> (2, 3, 4, 5, 6) 

    """ 
    shape = [] 
    j = 0 
    for i, dim in enumerate(target_shape): 
     if j < len(source_shape): 
      shape.append(dim if dim == source_shape[j] else 1) 
      if i + 1 < len(target_shape) and dim == source_shape[j] \ 
        and dim != 1 and dim in target_shape[i + 1:]: 
       text = ('Multiple positions (e.g. {} and {})' 
         ' for source shape element {}.'.format(
        i, target_shape[i + 1:].index(dim) + (i + 1), dim)) 
       warnings.warn(text) 
      if dim == source_shape[j] or source_shape[j] == 1: 
       j += 1 
     else: 
      shape.append(1) 
    if j < len(source_shape): 
     raise ValueError(
      'Target shape must contain all source shape elements' 
      ' (in correct order). {} -> {}'.format(source_shape, target_shape)) 
    return tuple(shape) 

這可以被用來定義unsqueeze()作爲np.squeeze()更靈活的逆相比np.expand_dims()僅可以一次將一行單:

def unsqueeze(
     arr, 
     axis=None, 
     shape=None, 
     reverse=False): 
    """ 
    Add singletons to the shape of an array to broadcast-match a given shape. 

    In some sense, this function implements the inverse of `numpy.squeeze()`. 


    Args: 
     arr (np.ndarray): The input array. 
     axis (int|Iterable|None): Axis or axes in which to operate. 
      If None, a valid set axis is generated from `shape` when this is 
      defined and the shape can be matched by `unsqueezing()`. 
      If int or Iterable, specified how singletons are added. 
      This depends on the value of `reverse`. 
      If `shape` is not None, the `axis` and `shape` parameters must be 
      consistent. 
      Values must be in the range [-(ndim+1), ndim+1] 
      At least one of `axis` and `shape` must be specified. 
     shape (int|Iterable|None): The target shape. 
      If None, no safety checks are performed. 
      If int, this is interpreted as the number of dimensions of the 
      output array. 
      If Iterable, the result must be broadcastable to an array with the 
      specified shape. 
      If `axis` is not None, the `axis` and `shape` parameters must be 
      consistent. 
      At least one of `axis` and `shape` must be specified. 
     reverse (bool): Interpret `axis` parameter as its complementary. 
      If True, the dims of the input array are placed at the positions 
      indicated by `axis`, and singletons are placed everywherelse and 
      the `axis` length must be equal to the number of dimensions of the 
      input array; the `shape` parameter cannot be `None`. 
      If False, the singletons are added at the position(s) specified by 
      `axis`. 
      If `axis` is None, `reverse` has no effect. 

    Returns: 
     arr (np.ndarray): The reshaped array. 

    Raises: 
     ValueError: if the `arr` shape cannot be reshaped correctly. 

    Examples: 
     Let's define some input array `arr`: 

     >>> arr = np.arange(2 * 3 * 4).reshape((2, 3, 4)) 
     >>> arr.shape 
     (2, 3, 4) 

     A call to `unsqueeze()` can be reversed by `np.squeeze()`: 

     >>> arr_ = unsqueeze(arr, (0, 2, 4)) 
     >>> arr_.shape 
     (1, 2, 1, 3, 1, 4) 
     >>> arr = np.squeeze(arr_, (0, 2, 4)) 
     >>> arr.shape 
     (2, 3, 4) 

     The order of the axes does not matter: 

     >>> arr_ = unsqueeze(arr, (0, 4, 2)) 
     >>> arr_.shape 
     (1, 2, 1, 3, 1, 4) 

     If `shape` is an int, `axis` must be consistent with it: 

     >>> arr_ = unsqueeze(arr, (0, 2, 4), 6) 
     >>> arr_.shape 
     (1, 2, 1, 3, 1, 4) 
     >>> arr_ = unsqueeze(arr, (0, 2, 4), 7) 
     Traceback (most recent call last): 
      ... 
     ValueError: Incompatible `[0, 2, 4]` axis and `7` shape for array of\ 
shape (2, 3, 4) 

     It is possible to reverse the meaning to `axis` to add singletons 
     everywhere except where specified (but requires `shape` to be defined 
     and the length of `axis` must match the array dims): 

     >>> arr_ = unsqueeze(arr, (0, 2, 4), 10, True) 
     >>> arr_.shape 
     (2, 1, 3, 1, 4, 1, 1, 1, 1, 1) 
     >>> arr_ = unsqueeze(arr, (0, 2, 4), reverse=True) 
     Traceback (most recent call last): 
      ... 
     ValueError: When `reverse` is True, `shape` cannot be None. 
     >>> arr_ = unsqueeze(arr, (0, 2), 10, True) 
     Traceback (most recent call last): 
      ... 
     ValueError: When `reverse` is True, the length of axis (2) must match\ 
the num of dims of array (3). 

     Axes values must be valid: 

     >>> arr_ = unsqueeze(arr, 0) 
     >>> arr_.shape 
     (1, 2, 3, 4) 
     >>> arr_ = unsqueeze(arr, 3) 
     >>> arr_.shape 
     (2, 3, 4, 1) 
     >>> arr_ = unsqueeze(arr, -1) 
     >>> arr_.shape 
     (2, 3, 4, 1) 
     >>> arr_ = unsqueeze(arr, -4) 
     >>> arr_.shape 
     (1, 2, 3, 4) 
     >>> arr_ = unsqueeze(arr, 10) 
     Traceback (most recent call last): 
      ... 
     ValueError: Axis (10,) out of range. 

     If `shape` is specified, `axis` can be omitted (USE WITH CARE!) or its 
     value is used for addiotional safety checks: 

     >>> arr_ = unsqueeze(arr, shape=(2, 3, 4, 5, 6)) 
     >>> arr_.shape 
     (2, 3, 4, 1, 1) 
     >>> arr_ = unsqueeze(
     ...  arr, (3, 6, 8), (2, 5, 3, 2, 7, 2, 3, 2, 4, 5, 6), True) 
     >>> arr_.shape 
     (1, 1, 1, 2, 1, 1, 3, 1, 4, 1, 1) 
     >>> arr_ = unsqueeze(
     ...  arr, (3, 7, 8), (2, 5, 3, 2, 7, 2, 3, 2, 4, 5, 6), True) 
     Traceback (most recent call last): 
      ... 
     ValueError: New shape [1, 1, 1, 2, 1, 1, 1, 3, 4, 1, 1] cannot be\ 
broadcasted to shape (2, 5, 3, 2, 7, 2, 3, 2, 4, 5, 6) 
     >>> arr = unsqueeze(arr, shape=(2, 5, 3, 7, 2, 4, 5, 6)) 
     >>> arr.shape 
     (2, 1, 3, 1, 1, 4, 1, 1) 
     >>> arr = np.squeeze(arr) 
     >>> arr.shape 
     (2, 3, 4) 
     >>> arr = unsqueeze(arr, shape=(5, 3, 7, 2, 4, 5, 6)) 
     Traceback (most recent call last): 
      ... 
     ValueError: Target shape must contain all source shape elements\ 
(in correct order). (2, 3, 4) -> (5, 3, 7, 2, 4, 5, 6) 

     The behavior is consistent with other NumPy functions and the 
     `keepdims` mechanism: 

     >>> axis = (0, 2, 4) 
     >>> arr1 = np.arange(2 * 3 * 4 * 5 * 6).reshape((2, 3, 4, 5, 6)) 
     >>> arr2 = np.sum(arr1, axis, keepdims=True) 
     >>> arr2.shape 
     (1, 3, 1, 5, 1) 
     >>> arr3 = np.sum(arr1, axis) 
     >>> arr3.shape 
     (3, 5) 
     >>> arr3 = unsqueeze(arr3, axis) 
     >>> arr3.shape 
     (1, 3, 1, 5, 1) 
     >>> np.all(arr2 == arr3) 
     True 
    """ 
    # calculate `new_shape` 
    if axis is None and shape is None: 
     raise ValueError(
      'At least one of `axis` and `shape` parameters must be specified.') 
    elif axis is None and shape is not None: 
     new_shape = unsqueezing(arr.shape, shape) 
    elif axis is not None: 
     if isinstance(axis, int): 
      axis = (axis,) 
     # calculate the dim of the result 
     if shape is not None: 
      if isinstance(shape, int): 
       ndim = shape 
      else: # shape is a sequence 
       ndim = len(shape) 
     elif not reverse: 
      ndim = len(axis) + arr.ndim 
     else: 
      raise ValueError('When `reverse` is True, `shape` cannot be None.') 
     # check that axis is properly constructed 
     if any([ax < -ndim - 1 or ax > ndim + 1 for ax in axis]): 
      raise ValueError('Axis {} out of range.'.format(axis)) 
     # normalize axis using `ndim` 
     axis = sorted([ax % ndim for ax in axis]) 
     # manage reverse mode 
     if reverse: 
      if len(axis) == arr.ndim: 
       axis = [i for i in range(ndim) if i not in axis] 
      else: 
       raise ValueError(
        'When `reverse` is True, the length of axis ({})' 
        ' must match the num of dims of array ({}).'.format(
         len(axis), arr.ndim)) 
     elif len(axis) + arr.ndim != ndim: 
      raise ValueError(
       'Incompatible `{}` axis and `{}` shape' 
       ' for array of shape {}'.format(axis, shape, arr.shape)) 
     # generate the new shape from axis, ndim and shape 
     new_shape = [] 
     i, j = 0, 0 
     for l in range(ndim): 
      if i < len(axis) and l == axis[i] or j >= arr.ndim: 
       new_shape.append(1) 
       i += 1 
      else: 
       new_shape.append(arr.shape[j]) 
       j += 1 

    # check that `new_shape` is consistent with `shape` 
    if shape is not None: 
     if isinstance(shape, int): 
      if len(new_shape) != ndim: 
       raise ValueError(
        'Length of new shape {} does not match ' 
        'expected length ({}).'.format(len(new_shape), ndim)) 
     else: 
      if not all([new_dim == 1 or new_dim == dim 
         for new_dim, dim in zip(new_shape, shape)]): 
       raise ValueError(
        'New shape {} cannot be broadcasted to shape {}'.format(
         new_shape, shape)) 

    return arr.reshape(new_shape) 

利用這些,我們可以寫:

import numpy as np 

arr1 = np.arange(2 * 3 * 4 * 5 * 6).reshape((2, 3, 4, 5, 6)) 
arr2 = np.arange(3 * 5).reshape((3, 5)) 

arr3 = unsqueeze(arr2, (0, 2, 4)) 

arr1 + arr3 
# now this works because it has the correct shape 

arr3 = unsqueeze(arr2, shape=arr1.shape) 
arr1 + arr3 
# this also works because the shape can be expanded unambiguously 

所以動態廣播現在可以發生,這是與keepdims的行爲是一致的:

import numpy as np 

axis = (0, 2, 4) 
arr1 = np.arange(2 * 3 * 4 * 5 * 6).reshape((2, 3, 4, 5, 6)) 
arr2 = np.sum(arr1, axis, keepdims=True) 
arr3 = np.sum(arr1, axis) 
arr3 = unsqueeze(arr3, axis) 
np.all(arr2 == arr3) 
# : True 

實際上,這個擴展np.expand_dims()處理更復雜場景。

對此代碼的改進顯然更值得歡迎。

+0

這是由你來證明,邏輯是正確和明確的,至少對於你自己的用例。 – hpaulj

+0

我擴展了'unsqueeze()'來處理一個'axis'參數來模擬'np.expand_dims()'當它的一個整數序列,並且當'axis'沒有被指定時保持* automagic *擴展,另外警告用戶* automagic *結果不是唯一定義的。 – norok2