0
我想看看我在我的網絡中使用的圖片都是OK的,所以我使用下面的代碼保存的一羣人:torchvision MNIST加載程序無法正常工作,或者我做錯了什麼?
train_set = dset.MNIST(root=root, train=True, transform=transforms.ToTensor(), download=download)
for it, (img, target) in enumerate(train_loader):
X = Variable(img)
tar = Variable(target)
X = X.view(batch_size, -1)
cur_img_batch = X.data.numpy()
cur_tar_batch = tar.data.numpy()
for i in range(batch_size):
cur_img = cur_img_batch[i]
im = Image.fromarray(cur_img.reshape((28, 28)).astype('uint8') * 255)
if cur_tar_batch[i] == 8:
im.save(test_img_dir + 'iter_' + str(it) + '_sample_' + str(i) + '.png')
這不是最乾淨的代碼,但它只是節省了一堆所有標記爲「8」的圖像。打開它們後,我發現其中大部分看起來像this,儘管它們中的一小部分完全是fine。
我做錯了什麼?
此行'cur_img.reshape((28,28))。astype('uint8')* 255'您是否將數據轉換爲整數後再乘以255? –
當然!這是它 - 非常感謝:) –
正確的行應該是:im = Image.fromarray((cur_img.reshape((28,28))* 255).astype('uint8')) –