2015-01-16 76 views
0

有許多用於拼合嵌套列表的配方。我將複製一個解決方案在這裏僅供參考:拼合並解開numpy數組的嵌套列表

def flatten(x): 
    result = [] 
    for el in x: 
     if hasattr(el, "__iter__") and not isinstance(el, basestring): 
     result.extend(flatten(el)) 
     else: 
     result.append(el) 
    return result 

我感興趣的是反操作,它將列表重建爲其原始格式。例如:

L = [[array([[ 24, -134],[ -67, -207]])], 
    [array([[ 204, -45],[ 99, -118]])], 
    [array([[ 43, -154],[-122, 168]]), array([[ 33, -110],[ 147, -26],[ -49, -122]])]] 

# flattened version 

L_flat = [24, -134, -67, -207, 204, -45, 99, -118, 43, -154, -122, 168, 33, -110, 147, -26, -49, -122] 

是否有一種有效的扁平化方法,可以節省指標並重建爲原始格式?

請注意,該列表可以是任意深度,並且可能不具有規則形狀,並且將包含不同維度的數組。

當然,flattening函數也應該改變,以存儲列表的結構和numpy數組的形狀。

+1

你應該從扁平版本中知道它最初的樣子?你在扁平化過程中失去了信息。 – jonrsharpe

+0

當然,應該更改展平功能以存儲列表的結構。 – memecs

+0

某種程度上,你已經回答了你自己的問題;你需要修改'flatten'來提供關於列表結構和其中數組形狀的保留信息。例如,它可以在平展的「L」旁邊返回'[[(2,2)],[(2,2)],[(2,2),(3,2)]]''。然後,您將不得不相應地切分「L_flat」並對每個切片的陣列進行「重塑」。 – jonrsharpe

回答

1

您正在構建一個悖論:您想要展平該對象,但不想展平該對象,並在對象的某個位置保留其結構信息。

所以Python的方式做,這是扁平化的對象,但寫一個類,將有一個__iter__,讓您順序(即以平坦的方式。)經過底層對象的元素。這將與轉換爲單位事物一樣快(如果每個元素只應用一次),並且不會複製或更改原始非平坦容器。

0

這是我想出來的,結果比迭代嵌套列表和單獨加載要快30倍。

def flatten(nl): 
    l1 = [len(s) for s in itertools.chain.from_iterable(nl)] 
    l2 = [len(s) for s in nl] 

    nl = list(itertools.chain.from_iterable(
     itertools.chain.from_iterable(nl))) 

    return nl,l1,l2 

def reconstruct(nl,l1,l2): 
    return np.split(np.split(nl,np.cumsum(l1)),np.cumsum(l2))[:-1] 

L_flat,l1,l2 = flatten(L) 
L_reconstructed = reconstruct(L_flat,l1,l2) 

更好的解決方案可以對任意數量的嵌套級別進行迭代工作。

+0

爲什麼選擇倒票? – memecs

2

我一直在尋找一個解決方案,以扁平化和numpy的陣列unflatten嵌套列表,但只發現這個沒有答案的問題,所以我想出了這個:

def _flatten(values): 
    if isinstance(values, np.ndarray): 
     yield values.flatten() 
    else: 
     for value in values: 
      yield from _flatten(value) 

def flatten(values): 
    # flatten nested lists of np.ndarray to np.ndarray 
    return np.concatenate(list(_flatten(values))) 

def _unflatten(flat_values, prototype, offset): 
    if isinstance(prototype, np.ndarray): 
     shape = prototype.shape 
     new_offset = offset + np.product(shape) 
     value = flat_values[offset:new_offset].reshape(shape) 
     return value, new_offset 
    else: 
     result = [] 
     for value in prototype: 
      value, offset = _unflatten(flat_values, value, offset) 
      result.append(value) 
     return result, offset 

def unflatten(flat_values, prototype): 
    # unflatten np.ndarray to nested lists with structure of prototype 
    result, offset = _unflatten(flat_values, prototype, 0) 
    assert(offset == len(flat_values)) 
    return result 

例子:

a = [ 
    np.random.rand(1), 
    [ 
     np.random.rand(2, 1), 
     np.random.rand(1, 2, 1), 
    ], 
    [[]], 
] 

b = flatten(a) 

# 'c' will have values of 'b' and structure of 'a' 
c = unflatten(b, a) 

輸出:

a: 
[array([ 0.26453544]), [array([[ 0.88273824], 
     [ 0.63458643]]), array([[[ 0.84252894], 
     [ 0.91414218]]])], [[]]] 
b: 
[ 0.26453544 0.88273824 0.63458643 0.84252894 0.91414218] 
c: 
[array([ 0.26453544]), [array([[ 0.88273824], 
     [ 0.63458643]]), array([[[ 0.84252894], 
     [ 0.91414218]]])], [[]]] 

許可:WTFPL