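Preprocessing a numpy image dataset for a CNN: MemoryError

I have a dataset (71,094 training images and 17,000 test images) and I need to train a CNN. During preprocessing I tried to build a single numpy array holding all of the training images, and it turned out to be absurdly large (71094 * 100 * 100 * 3 for the training data; all images are 100 × 100 RGB). As a result I get a MemoryError. How can I get around this? Please help. Here is my code: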
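For scale: that array has 71094 * 100 * 100 * 3 ≈ 2.13 billion elements, and `np.zeros` allocates float64 (8 bytes per element) unless a `dtype` is given. The short sketch below computes the footprint for a few dtypes. The helper `array_gib` is only an illustration. It prints roughly 15.9 GiB for float64, 7.9 GiB for float32 and 2.0 GiB for uint8. uint8 is also the dtype `cv2.imread` returns for ordinary 8-bit JPEGs.

```python
import numpy as np


def array_gib(shape, dtype):
    """Memory needed for a dense array of this shape and dtype, in GiB."""
    n_elements = 1
    for dim in shape:
        n_elements *= dim  # plain Python ints cannot overflow
    return n_elements * np.dtype(dtype).itemsize / 1024.0 ** 3


train_shape = (71094, 100, 100, 3)  # shape used in the question
for dtype in (np.float64, np.float32, np.uint8):
    print("%-8s %5.1f GiB" % (np.dtype(dtype).name, array_gib(train_shape, dtype)))
```

The factor of 8 between float64 and uint8 is the whole difference here. The pixel values themselves never need more than one byte each.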
```python
import numpy as np
import cv2
from matplotlib import pyplot as plt

data_dir = './fashion-data/images/'
train_data = './fashion-data/train.txt'
test_data = './fashion-data/test.txt'

# The first path component of each line in train.txt is the integer class label
with open(train_data, 'r') as fh:
    ims = fh.read().split('\n')
print(len(ims))

train = np.zeros((71094, 100, 100, 3))  # this line raises the MemoryError
for ix in range(train.shape[0]):
    i = cv2.imread(data_dir + ims[ix] + '.jpg')
    train[ix, :, :, :] = cv2.resize(i, (100, 100))
print(train[0])

train_labels = np.zeros((71094, 1))
for ix in range(train_labels.shape[0]):
    l = ims[ix].split('/')[0]
    train_labels[ix] = int(l)
print(train_labels[0])

np.save('./data/train', train)
np.save('./data/train_labels', train_labels)
```
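One common way around this, sketched here under assumptions rather than as a definitive fix, has three parts. First, keep the pixels as uint8. Second, write them into a disk-backed `.npy` file with `numpy.lib.format.open_memmap` instead of an in-RAM float64 array. Third, convert to float32 one batch at a time during training. The helper names (`write_images_to_npy`, `load_with_cv2`, `iterate_batches`), the output path and the batch size are hypothetical. The image directory matches the question's code, and the loader assumes every name in `train.txt` points to an existing `.jpg`.

```python
import numpy as np


def write_images_to_npy(out_path, names, load_image, height=100, width=100):
    """Write every image into a uint8 .npy file on disk instead of into RAM.

    load_image(name) must return a (height, width, 3) uint8 array.
    """
    # open_memmap creates a regular .npy file (header + data) and maps it into
    # memory; the OS can write pages back to disk, so the whole array does not
    # need to fit in RAM at once.
    out = np.lib.format.open_memmap(
        out_path, mode='w+', dtype=np.uint8,
        shape=(len(names), height, width, 3))
    for ix, name in enumerate(names):
        out[ix] = load_image(name)
    out.flush()  # push any pending writes to the file
    del out      # release this reference to the mapping
    return out_path


def load_with_cv2(name, data_dir='./fashion-data/images/', size=(100, 100)):
    """Hypothetical loader mirroring the question: read <data_dir><name>.jpg and resize it."""
    import cv2  # imported here so the other helpers work without OpenCV installed
    img = cv2.imread(data_dir + name + '.jpg')  # uint8 BGR array, or None on failure
    if img is None:
        raise IOError('could not read image: ' + name)
    return cv2.resize(img, size)


def iterate_batches(npy_path, labels, batch_size=128):
    """Yield (images, labels) batches, converting only one batch to float32 at a time."""
    data = np.load(npy_path, mmap_mode='r')  # lazy, read-only view of the file
    for start in range(0, len(data), batch_size):
        stop = start + batch_size
        x = np.asarray(data[start:stop], dtype=np.float32) / 255.0
        yield x, labels[start:stop]
```

With the question's variables, usage would look roughly like this: `names = [n for n in ims if n]` (this drops the empty string that a trailing newline leaves in `ims`), then `write_images_to_npy('./data/train.npy', names, load_with_cv2)`. Build the labels as `np.array([int(n.split('/')[0]) for n in names], dtype=np.int64)`. A uint8 array of this size (about 2 GiB) may already fit in RAM on its own. The memory map mainly helps when it does not, and the per-batch conversion avoids ever materializing the 15.9 GiB float64 version.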