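2012-12-06

Preserving numpy views when pickling

By default, pickling numpy view arrays loses the view relationship, even if the base array is pickled as well. In my case, I have some complex container objects that get pickled, and in some of them part of the contained data are views into other contained data. Saving an independent array for every view wastes space, and the reloaded data also lose the view relationship.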

A simple example (in my case the containers are more complex than a dict):

import numpy as np 
import cPickle 

tmp = np.zeros(2) 
d1 = dict(a=tmp,b=tmp[:]) # d1 to be saved: b is a view on a 

pickled = cPickle.dumps(d1) 
d2 = cPickle.loads(pickled) # d2 reloaded copy of d1 container 

print 'd1 before:', d1 
d1['b'][:] = 1 
print 'd1 after: ', d1 

print 'd2 before:', d2 
d2['b'][:] = 1 
print 'd2 after: ', d2 

This prints:

d1 before: {'a': array([ 0., 0.]), 'b': array([ 0., 0.])} 
d1 after: {'a': array([ 1., 1.]), 'b': array([ 1., 1.])} 
d2 before: {'a': array([ 0., 0.]), 'b': array([ 0., 0.])} 
d2 after: {'a': array([ 0., 0.]), 'b': array([ 1., 1.])} # not a view anymore 
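Why `a` and `b` come apart: pickle's memo only recognizes the same object referenced twice and restores it as one object. A view is a separate ndarray object, so it is serialized as an independent copy. A small Python 3 check of this (using `np.shares_memory`, which the question does not use):

```python
import pickle

import numpy as np

tmp = np.zeros(2)

# Same object referenced twice: the pickle memo restores a single object.
same = pickle.loads(pickle.dumps(dict(a=tmp, b=tmp)))
# A view is a different object: it is pickled as an independent copy.
view = pickle.loads(pickle.dumps(dict(a=tmp, b=tmp[:])))

print(same['a'] is same['b'])                  # True
print(np.shares_memory(view['a'], view['b']))  # False
```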

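My questions:

(1) Is there a way to preserve it? (2) (Even better) Is there a way to do this only when the base is pickled?

For (1), I think there may be a way by changing __setstate__, __reduce__, __reduce_ex__, etc. of the view arrays, but I'm not confident about it yet. For (2), I have no idea.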
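For idea (1), one possible route that avoids patching ndarray itself is a Pickler with a private `dispatch_table` (a standard `pickle` feature) mapping `np.ndarray` to a custom reducer. This is a minimal sketch, not NumPy's own mechanism: `rebuild_view`, `reduce_ndarray` and `dumps_keep_views` are hypothetical names, and it only handles views whose base owns a C-contiguous block of memory. Because the base is pickled by reference, pickle's memo shares it with any other reference to the same base. It does not answer (2): the base is always pickled along with the view, even if nothing else refers to it.

```python
import copyreg
import io
import pickle

import numpy as np


def rebuild_view(base, offset, shape, strides, dtype):
    # Hypothetical helper: re-create a view sharing the unpickled base's memory.
    return np.ndarray(shape, dtype, buffer=base.data, offset=offset, strides=strides)


def reduce_ndarray(arr):
    # Hypothetical reducer: pickle a view as (base, offset, shape, strides, dtype).
    base = arr.base
    if (isinstance(base, np.ndarray) and base.base is None
            and base.flags.c_contiguous):
        offset = (arr.__array_interface__['data'][0]
                  - base.__array_interface__['data'][0])
        return rebuild_view, (base, offset, arr.shape, arr.strides, arr.dtype)
    # Anything else: NumPy's normal (copying) reduction.
    return arr.__reduce__()


def dumps_keep_views(obj):
    # Hypothetical helper: pickle with a per-Pickler dispatch table.
    buf = io.BytesIO()
    pickler = pickle.Pickler(buf)
    pickler.dispatch_table = copyreg.dispatch_table.copy()
    pickler.dispatch_table[np.ndarray] = reduce_ndarray
    pickler.dump(obj)
    return buf.getvalue()


tmp = np.zeros(2)
d1 = dict(a=tmp, b=tmp[:])
d2 = pickle.loads(dumps_keep_views(d1))  # plain pickle.loads is enough
d2['b'][:] = 1
print(d2['a'])  # [1. 1.]
```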

Answer

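This isn't done in NumPy itself, because it doesn't always make sense to pickle the base array, and pickle does not expose, as part of its API, the ability to check whether another object is also being pickled.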
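To see why always pickling the base would be a poor default: a small view of a large array is pickled with only its own elements, whereas dragging the base along would serialize all of it. A quick size comparison (the array sizes are arbitrary):

```python
import pickle

import numpy as np

big = np.zeros(1_000_000)  # 8,000,000 bytes of float64
small = big[:10]           # a 10-element view into it

small_bytes = pickle.dumps(small)  # only the view's 80 bytes of data
big_bytes = pickle.dumps(big)      # all of the base's data
print(len(small_bytes), len(big_bytes))
```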

But this check can be done in a custom container for NumPy arrays. For example:

import numpy as np
import pickle

try:
    from numpy.lib.array_utils import byte_bounds  # NumPy >= 2.0
except ImportError:
    byte_bounds = np.byte_bounds  # NumPy 1.x

def byte_offset(array, source):
    # Bytes between the start of `source`'s memory and the first element
    # of `array` (a view into `source`).
    return array.__array_interface__['data'][0] - byte_bounds(source)[0]

class SharedPickleList(object):
    def __init__(self, arrays):
        self.arrays = list(arrays)

    def __getstate__(self):
        unique_ids = {id(array) for array in self.arrays}
        source_arrays = {}
        view_tuples = {}
        for array in self.arrays:
            if array.base is None or id(array.base) not in unique_ids:
                # only use views if the base is also being pickled
                source_arrays[id(array)] = array
            else:
                # store just enough to rebuild the view on top of its base
                view_tuples[id(array)] = (array.shape,
                                          array.dtype,
                                          id(array.base),
                                          byte_offset(array, array.base),
                                          array.strides)
        # ids are only used as labels to restore the original order
        order = [id(array) for array in self.arrays]
        return (source_arrays, view_tuples, order)

    def __setstate__(self, state):
        source_arrays, view_tuples, order = state
        view_arrays = {}
        for k, view_state in view_tuples.items():
            (shape, dtype, source_id, offset, strides) = view_state
            # new ndarray sharing the memory of the unpickled base array
            buffer = source_arrays[source_id].data
            array = np.ndarray(shape, dtype, buffer, offset, strides)
            view_arrays[k] = array
        self.arrays = [source_arrays[i]
                       if i in source_arrays
                       else view_arrays[i]
                       for i in order]

# unit tests 
def check_roundtrip(arrays): 
    unpickled_arrays = pickle.loads(pickle.dumps(
     SharedPickleList(arrays))).arrays 
    assert all(a.shape == b.shape and (a == b).all() 
       for a, b in zip(arrays, unpickled_arrays)) 

indexers = [0, None, slice(None), slice(2), slice(None, -1), 
      slice(None, None, -1), slice(None, 6, 2)] 

source0 = np.random.randint(100, size=10) 
arrays0 = [np.asarray(source0[k1]) for k1 in indexers] 
check_roundtrip([source0] + arrays0) 

source1 = np.random.randint(100, size=(8, 10)) 
arrays1 = [np.asarray(source1[k1, k2]) for k1 in indexers for k2 in indexers] 
check_roundtrip([source1] + arrays1) 
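The core of __setstate__ is the np.ndarray(shape, dtype, buffer, offset, strides) constructor: given the base's buffer and the byte offset and strides recorded by __getstate__, it produces a new array that shares memory with the base. A standalone illustration of just that step, using a negative-stride view (the arrays here are arbitrary examples):

```python
import numpy as np

base = np.arange(10.0)
view = base[8:1:-3]  # negative-stride view onto elements 8, 5, 2

# What __getstate__ records: shape, dtype, byte offset, strides.
offset = view.__array_interface__['data'][0] - base.__array_interface__['data'][0]

# What __setstate__ does: rebuild the view from the base's buffer.
rebuilt = np.ndarray(view.shape, view.dtype, base.data, offset, view.strides)
print(rebuilt)  # [8. 5. 2.]

rebuilt[0] = -1.0  # writes go through to the base
print(base[8])     # -1.0
```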

This results in significant space savings:

source = np.random.rand(1000) 
arrays = [source] + [source[n:] for n in range(99)] 
print(len(pickle.dumps(arrays, protocol=-1))) 
# 766372 
print(len(pickle.dumps(SharedPickleList(arrays), protocol=-1))) 
# 11833
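Most of that difference is raw float64 data. Pickled as independent arrays, the list stores 1000 + 1000 + 999 + ... + 902 elements, while the shared version stores the 1000-element source once plus a small tuple per view. The raw byte counts can be checked with `nbytes`:

```python
import numpy as np

source = np.random.rand(1000)
arrays = [source] + [source[n:] for n in range(99)]

# 8 bytes * (1000 + sum(1000 - n for n in range(99))) = 8 * 95149
raw_bytes = sum(a.nbytes for a in arrays)
print(raw_bytes)      # 761192
print(source.nbytes)  # 8000
```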