2017-09-06
-1

I want to implement TensorFlow's or PyTorch's scatter and gather operations in NumPy. I have been scratching my head for a while. Any pointers are greatly appreciated! How do I do scatter and gather operations in NumPy?

+0

I suspect the code in question is open source... –

+1

It looks like these methods are Python front ends to C++ methods. If you want help from `numpy` experts, you need to explain how they work. In other words, describe what you want to do in `numpy` terms (with examples). Without `pytorch` experience I can't easily make sense of the documentation. – hpaulj

+0

@MadPhysicist Yes, the code is open source. You can check it out here. It's a really cool project: http://openmined.org/ –

Answers

0

The scatter method turned out to be a lot more work than I anticipated. I did not find any ready-made function in NumPy for it. I'm sharing it here in the interest of anyone who may need to implement it with NumPy. (P.S. self is the destination or output of the method.)

import numpy as np

def scatter_numpy(self, dim, index, src):
    """
    Writes all values from the Tensor src into self at the indices specified in the index Tensor.

    :param dim: The axis along which to index
    :param index: The indices of elements to scatter
    :param src: The source element(s) to scatter
    :return: self
    """
    if index.dtype != np.dtype('int_'):
        raise TypeError("The values of index must be integers")
    if self.ndim != index.ndim:
        raise ValueError("Index should have the same number of dimensions as output")
    if dim >= self.ndim or dim < -self.ndim:
        raise IndexError("dim is out of range")
    if dim < 0:
        # Not sure why scatter should accept dim < 0, but that is the behavior in PyTorch's scatter
        dim = self.ndim + dim
    idx_xsection_shape = index.shape[:dim] + index.shape[dim + 1:]
    self_xsection_shape = self.shape[:dim] + self.shape[dim + 1:]
    if idx_xsection_shape != self_xsection_shape:
        raise ValueError("Except for dimension " + str(dim) +
                         ", all dimensions of index and output should be the same size")
    if (index >= self.shape[dim]).any() or (index < 0).any():
        raise IndexError("The values of index must be between 0 and (self.shape[dim] - 1)")

    def make_slice(arr, dim, i):
        slc = [slice(None)] * arr.ndim
        slc[dim] = i
        return tuple(slc)

    # We use the index and dim parameters to create idx.
    # idx is in a form that can be used as a NumPy advanced index to scatter src into self.
    idx = [[*np.indices(idx_xsection_shape).reshape(index.ndim - 1, -1),
            index[make_slice(index, dim, i)].reshape(1, -1)[0]] for i in range(index.shape[dim])]
    idx = list(np.concatenate(idx, axis=1))
    idx.insert(dim, idx.pop())

    if not np.isscalar(src):
        if index.shape[dim] > src.shape[dim]:
            raise IndexError("Dimension " + str(dim) + " of index can not be bigger than that of src")
        src_xsection_shape = src.shape[:dim] + src.shape[dim + 1:]
        if idx_xsection_shape != src_xsection_shape:
            raise ValueError("Except for dimension " +
                             str(dim) + ", all dimensions of index and src should be the same size")
        # src_idx is a NumPy advanced index for selecting the elements of src to scatter
        src_idx = list(idx)
        src_idx.pop(dim)
        src_idx.insert(dim, np.repeat(np.arange(index.shape[dim]), np.prod(idx_xsection_shape)))
        # Index with tuples: indexing with a plain list of arrays is deprecated in NumPy
        self[tuple(idx)] = src[tuple(src_idx)]
    else:
        self[tuple(idx)] = src

    return self
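As a side note for readers on newer versions: NumPy 1.15 later added np.put_along_axis, which performs this same per-element scatter directly. A minimal sketch (the array names here are just for illustration):

```python
import numpy as np

# Scatter src into dest along axis 0, PyTorch-style:
# dest[index[i, j], j] = src[i, j]
dest = np.zeros((3, 5), dtype=int)
index = np.array([[0, 1, 2, 0, 0],
                  [2, 0, 0, 1, 2]])
src = np.array([[1, 2, 3, 4, 5],
                [6, 7, 8, 9, 10]])

np.put_along_axis(dest, index, src, axis=0)
# dest == [[1, 7, 8, 4, 5],
#          [0, 2, 0, 9, 0],
#          [6, 0, 3, 0, 10]]
```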

There may be a simpler solution for gather, but this is what I settled on:
(Here self is the ndarray that the values are gathered from.)

def gather_numpy(self, dim, index):
    """
    Gathers values along an axis specified by dim.
    For a 3-D tensor the output is specified by:
        out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
        out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
        out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

    :param dim: The axis along which to index
    :param index: A tensor of indices of elements to gather
    :return: tensor of gathered values
    """
    idx_xsection_shape = index.shape[:dim] + index.shape[dim + 1:]
    self_xsection_shape = self.shape[:dim] + self.shape[dim + 1:]
    if idx_xsection_shape != self_xsection_shape:
        raise ValueError("Except for dimension " + str(dim) +
                         ", all dimensions of index and self should be the same size")
    if index.dtype != np.dtype('int_'):
        raise TypeError("The values of index must be integers")
    # Note: np.choose supports at most 32 choices along the chosen axis
    data_swapped = np.swapaxes(self, 0, dim)
    index_swapped = np.swapaxes(index, 0, dim)
    gathered = np.choose(index_swapped, data_swapped)
    return np.swapaxes(gathered, 0, dim)
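For completeness, NumPy 1.15 later added np.take_along_axis, which does this gather directly and avoids np.choose's 32-choice limit. A small sketch with illustrative arrays:

```python
import numpy as np

data = np.array([[10, 11, 12],
                 [13, 14, 15]])
index = np.array([[0, 1, 0],
                  [1, 0, 1]])

# Gather along axis 0, PyTorch-style: out[i, j] = data[index[i, j], j]
out = np.take_along_axis(data, index, axis=0)
# out == [[10, 14, 12],
#         [13, 11, 15]]
```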
0

Assuming ref and indices are NumPy arrays:

Scatter updates:

ref[indices] = updates   # tf.scatter_update(ref, indices, updates) 
ref[:, indices] = updates  # tf.scatter_update(ref, indices, updates, axis=1) 
ref[..., indices, :] = updates # tf.scatter_update(ref, indices, updates, axis=-2) 
ref[..., indices] = updates  # tf.scatter_update(ref, indices, updates, axis=-1) 

Gather:

ref[indices]   # tf.gather(ref, indices) 
ref[:, indices]  # tf.gather(ref, indices, axis=1) 
ref[..., indices, :] # tf.gather(ref, indices, axis=-2) 
ref[..., indices]  # tf.gather(ref, indices, axis=-1) 

See the numpy docs on indexing for more.
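The axis-specific gather forms above can also be written with np.take, which takes an explicit axis argument instead of encoding the axis in the slice. A quick check that the two agree:

```python
import numpy as np

ref = np.arange(24).reshape(2, 3, 4)
indices = np.array([3, 0])

# Fancy indexing on the last axis and np.take with axis=-1 are equivalent.
a = ref[..., indices]
b = np.take(ref, indices, axis=-1)
assert np.array_equal(a, b)
```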

+0

In your solution, how do you specify the dimension along which you want to scatter src? –

+0

Updated the answer. – DomJack

0

For scattering, instead of using slice assignment as suggested by @DomJack, it is often better to use np.add.at; since, unlike slice assignment, this has well-defined behavior in the presence of duplicate indices.
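A minimal sketch of the difference when indices repeat:

```python
import numpy as np

idx = np.array([0, 0, 1])
vals = np.array([1.0, 2.0, 3.0])

a = np.zeros(4)
a[idx] += vals           # buffered: for the duplicate index 0, only the last update survives
# a == [2., 3., 0., 0.]

b = np.zeros(4)
np.add.at(b, idx, vals)  # unbuffered: both updates to index 0 accumulate
# b == [3., 3., 0., 0.]
```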

+0

What do you mean by well-defined? My understanding is that in PyTorch and TensorFlow, duplicate indices result in values being overwritten. In TF's case they specifically warn that the order of updates is non-deterministic. I looked at np.add.at and it seems right for a scatter_add operation (no?), but that is not the behavior I want. –