讓我們說我有一個numpy的陣列a
與形狀(10, 10, 4, 5, 3, 3)
,和指數的兩個列表,b
和c
,形狀(1000, 6)
和(1000, 5)
的分別代表指數和部分指數陣列。我想使用索引來訪問數組,分別產生形狀爲(1000,)
和(1000, 3)
的數組。索引多維numpy的數組索引
我知道一些方法來做到這一點,但它們都很笨重,而且非pythonic,例如將索引轉換爲元組或索引每個軸分開。
a = np.random.random((10, 10, 4, 5, 3, 3))
b = np.random.randint(3, size=(1000, 6))
c = np.random.randint(3, size=(1000, 5))
# method one
tuple_index_b = [tuple(row) for row in b]
tuple_index_c = [tuple(row) for row in c]
output_b = np.array([a[row] for row in tuple_index_b])
output_c = np.array([a[row] for row in tuple_index_c])
# method two
output_b = a[b[:, 0], b[:, 1], b[:, 2], b[:, 3], b[:, 4], b[:, 5]]
output_c = a[c[:, 0], c[:, 1], c[:, 2], c[:, 3], c[:, 4]]
顯然,這兩種方法都不是很優雅,或者很容易擴展到更高維度。第一個也很慢,有兩個列表解析,第二個需要你分別寫出每個軸。直觀的語法a[b]
由於某種原因返回形狀(1000, 6, 10, 4, 5, 3, 3)
的數組,可能與廣播有關。
那麼,有沒有辦法在Numpy中做到這一點,不涉及這麼多的手工勞動/時間?
編輯:不是重複的,因爲這個問題涉及多維索引列表,而不僅僅是單個索引,並且已經產生了一些有用的新方法。