2017-10-15 62 views
0

我想看看我在我的網絡中使用的圖片都是OK的,所以我使用下面的代碼保存的一羣人:torchvision MNIST加載程序無法正常工作,或者我做錯了什麼?

train_set = dset.MNIST(root=root, train=True, transform=transforms.ToTensor(), download=download) 

for it, (img, target) in enumerate(train_loader): 
    X = Variable(img) 
    tar = Variable(target) 
    X = X.view(batch_size, -1) 
    cur_img_batch = X.data.numpy() 
    cur_tar_batch = tar.data.numpy() 
    for i in range(batch_size): 
     cur_img = cur_img_batch[i] 
     im = Image.fromarray(cur_img.reshape((28, 28)).astype('uint8') * 255) 
     if cur_tar_batch[i] == 8: 
      im.save(test_img_dir + 'iter_' + str(it) + '_sample_' + str(i) + '.png') 

這不是最乾淨的代碼,但它只是節省了一堆所有標記爲「8」的圖像。打開它們後,我發現其中大部分看起來像this,儘管它們中的一小部分完全是fine

我做錯了什麼?

+0

此行'cur_img.reshape((28,28))。astype('uint8')* 255'您是否將數據轉換爲整數後再乘以255? –

+0

當然!這是它 - 非常感謝:) –

+0

正確的行應該是:im = Image.fromarray((cur_img.reshape((28,28))* 255).astype('uint8')) –

回答

0

從評論:

的問題是在此行cur_img.reshape((28, 28)).astype('uint8') * 255,由255相乘,從而導致圖像與0或255

更新的代碼之前的歸一化圖像轉換爲整數:

​​3210
相關問題