2016-10-15 22 views
1
image_size = 28 
num_labels = 10 

def reformat(dataset, labels): 
    dataset = dataset.reshape((-1, image_size * image_size)).astype(np.float32) 
    # Map 0 to [1.0, 0.0, 0.0 ...], 1 to [0.0, 1.0, 0.0 ...] 
    labels = (np.arange(num_labels) == labels[:,None]).astype(np.float32) 
    return dataset, labels 
train_dataset, train_labels = reformat(train_dataset, train_labels) 
valid_dataset, valid_labels = reformat(valid_dataset, valid_labels) 
test_dataset, test_labels = reformat(test_dataset, test_labels) 
print('Training set', train_dataset.shape, train_labels.shape) 
print('Validation set', valid_dataset.shape, valid_labels.shape) 
print('Test set', test_dataset.shape, test_labels.shape) 

這條線是什麼意思?做numpy覆蓋==運算符,因爲我不明白遵循python代碼

labels = (np.arange(num_labels) == labels[:,None]).astype(np.float32) 

代碼是從https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/udacity/2_fullyconnected.ipynb

回答

3

在numpy的,所述==操作者比較兩個numpy的陣列時(如在該行註釋被完成),因此是的,它在這個意義上說是重載意味着什麼不同。它將兩個numpy數組按元素進行比較,並返回一個與兩個輸入相同大小的布爾numpy數組。其他比較也是如此,例如>=,<等。

例如,

import numpy as np 
print(np.array([5,8,2]) == np.array([5,3,2])) 
# [True False True] 
print((np.array([5,8,2]) == np.array([5,3,2])).astype(np.float32)) 
# [1. 0. 1.] 
1

對於numpy的陣列==操作者是逐元素操作返回boolean陣列。如註釋中所述,astype函數將布爾值True轉換爲1.0False0.0

0

https://docs.python.org/3/reference/expressions.html#value-comparisons描述了值比較,如==。雖然默認比較是identityx is y,但它首先檢查兩個參數是否實施__eq__方法。數字,列表和字典實現他們自己的版本。 numpy也是如此。

關於numpy__eq__的獨特之處在於,它在可能的情況下逐元素地比較,並返回相同大小的布爾數組。

In [426]: [1,2,3]==[1,2,3] 
Out[426]: True 
In [427]: z1=np.array([1,2,3]); z2=np.array([1,2,3]) 
In [428]: z1==z2 
Out[428]: array([ True, True, True], dtype=bool) 
In [432]: z1=np.array([1,2,3]); z2=np.array([1,2,4]) 
In [433]: z1==z2 
Out[433]: array([ True, True, False], dtype=bool) 
In [434]: (z1==z2).astype(float)  # change bool to float 
Out[434]: array([ 1., 1., 0.]) 

一個常見的SO問題是'爲什麼我得到這個ValueError?'

In [435]: if z1==z2: print('yes') 
... 
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all() 

這是因爲比較產生的數組有多個True/False值。

花車比較也是一個常見問題。檢查出iscloseallclose它出現問題。