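Answer to the question in the title:
The HDF5 file should have two datasets at its root, named "data" and "label", respectively. The shape is (`data amount`, `dimension`). I only use one-dimensional data, so I am not sure what the order of `channel`, `width`, and `height` is. Maybe it does not matter. `dtype` should be float or double.
Sample code that creates the training set with `h5py`: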
import h5py
import numpy as np

f = h5py.File('train.h5', 'w')
# 1200 data, each is a 128-dim vector
f.create_dataset('data', (1200, 128), dtype='f8')
# Data's labels, each is a 4-dim vector
f.create_dataset('label', (1200, 4), dtype='f4')

# Fill in something with fixed pattern
# Regularize values to between 0 and 1, or SigmoidCrossEntropyLoss will not work
for i in range(1200):
    a = np.empty(128)
    if i % 4 == 0:
        for j in range(128):
            a[j] = j / 128.0
        l = [1, 0, 0, 0]
    elif i % 4 == 1:
        for j in range(128):
            a[j] = (128 - j) / 128.0
        l = [1, 0, 1, 0]
    elif i % 4 == 2:
        for j in range(128):
            a[j] = (j % 6) / 128.0
        l = [0, 1, 1, 0]
    elif i % 4 == 3:
        for j in range(128):
            a[j] = (j % 4) * 4 / 128.0
        l = [1, 0, 1, 1]
    f['data'][i] = a
    f['label'][i] = l

f.close()
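The same four patterns can also be built without per-element loops. The following is only a sketch of that idea: `make_dataset` and `write_hdf5` are hypothetical helper names that are not part of the original code, and `write_hdf5` assumes `h5py` is installed.

```python
import numpy as np


def make_dataset(n, dim=128):
    """Return (data, labels) with the same four fixed patterns as the loop above."""
    j = np.arange(dim)
    scale = float(dim)
    # One row per pattern; all values stay within [0, 1].
    patterns = np.stack([
        j / scale,
        (dim - j) / scale,
        (j % 6) / scale,
        (j % 4) * 4 / scale,
    ])
    labels = np.array([[1, 0, 0, 0],
                       [1, 0, 1, 0],
                       [0, 1, 1, 0],
                       [1, 0, 1, 1]], dtype='f4')
    idx = np.arange(n) % 4  # sample i uses pattern i % 4
    return patterns[idx].astype('f8'), labels[idx]


def write_hdf5(path, data, labels):
    """Hypothetical helper: store both arrays under the names Caffe reads."""
    import h5py  # imported here so make_dataset can be used without h5py
    with h5py.File(path, 'w') as f:
        f.create_dataset('data', data=data)
        f.create_dataset('label', data=labels)
```

A test set could be produced the same way, for example `write_hdf5('test.h5', *make_dataset(200))`, where the size 200 is an arbitrary assumption.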
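Also, the accuracy layer is not needed; simply removing it is fine. The next problem is the loss layer. Since `SoftmaxWithLoss` has only one output (the index of the dimension with the maximum value), it cannot be used for a multi-label problem. Thanks to Adian and Shai, I found that `SigmoidCrossEntropyLoss` works well in this case.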
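To illustrate why a sigmoid cross-entropy loss suits multi-label targets, here is a small NumPy sketch that scores every label independently. It is not Caffe's implementation: the function name is hypothetical, and dividing by the batch size is an assumption about the normalization, which may differ between Caffe versions.

```python
import numpy as np


def sigmoid_cross_entropy(logits, targets):
    """Sum of per-label binary cross-entropies, divided by the batch size (assumed)."""
    x = np.asarray(logits, dtype=np.float64)
    t = np.asarray(targets, dtype=np.float64)
    # Numerically stable form of -[t*log(s) + (1-t)*log(1-s)], s = sigmoid(x).
    per_label = np.maximum(x, 0) - x * t + np.log1p(np.exp(-np.abs(x)))
    return per_label.sum() / x.shape[0]
```

Each of the four outputs contributes its own term, so a sample may have any number of labels set to 1.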
Below is the full code, from data creation and training the network to getting the test results:
main.py (modified from the Caffe LeNet example)
import os, sys

PROJECT_HOME = '.../project/'
CAFFE_HOME = '.../caffe/'
os.chdir(PROJECT_HOME)
sys.path.insert(0, CAFFE_HOME + 'caffe/python')

import caffe, h5py
import numpy as np
from pylab import *
from caffe import layers as L


def net(hdf5, batch_size):
    # Three fully connected layers; the last one has one output per label
    n = caffe.NetSpec()
    n.data, n.label = L.HDF5Data(batch_size=batch_size, source=hdf5, ntop=2)
    n.ip1 = L.InnerProduct(n.data, num_output=50, weight_filler=dict(type='xavier'))
    n.relu1 = L.ReLU(n.ip1, in_place=True)
    n.ip2 = L.InnerProduct(n.relu1, num_output=50, weight_filler=dict(type='xavier'))
    n.relu2 = L.ReLU(n.ip2, in_place=True)
    n.ip3 = L.InnerProduct(n.relu2, num_output=4, weight_filler=dict(type='xavier'))
    n.loss = L.SigmoidCrossEntropyLoss(n.ip3, n.label)
    return n.to_proto()


with open(PROJECT_HOME + 'auto_train.prototxt', 'w') as f:
    f.write(str(net(PROJECT_HOME + 'train.h5list', 50)))
with open(PROJECT_HOME + 'auto_test.prototxt', 'w') as f:
    f.write(str(net(PROJECT_HOME + 'test.h5list', 20)))

caffe.set_device(0)
caffe.set_mode_gpu()
solver = caffe.SGDSolver(PROJECT_HOME + 'auto_solver.prototxt')

solver.net.forward()
solver.test_nets[0].forward()
solver.step(1)

niter = 200
test_interval = 10
train_loss = zeros(niter)
test_acc = zeros(int(np.ceil(niter * 1.0 / test_interval)))
print(len(test_acc))
output = zeros((niter, 8, 4))

# The main solver loop
for it in range(niter):
    solver.step(1)  # SGD by Caffe
    train_loss[it] = solver.net.blobs['loss'].data
    solver.test_nets[0].forward(start='data')
    output[it] = solver.test_nets[0].blobs['ip3'].data[:8]

    if it % test_interval == 0:
        print('Iteration %d testing...' % it)
        correct = 0
        # These arrays refer to the blob memory, so each forward() refreshes them
        data = solver.test_nets[0].blobs['ip3'].data
        label = solver.test_nets[0].blobs['label'].data
        for test_it in range(100):
            solver.test_nets[0].forward()
            # Positive values map to label 1, while negative values map to label 0
            for i in range(len(data)):
                for j in range(len(data[i])):
                    if data[i][j] > 0 and label[i][j] == 1:
                        correct += 1
                    elif data[i][j] <= 0 and label[i][j] == 0:
                        correct += 1
        test_acc[it // test_interval] = correct * 1.0 / (len(data) * len(data[0]) * 100)

# Train and test done, outputting the convergence graph
_, ax1 = subplots()
ax2 = ax1.twinx()
ax1.plot(arange(niter), train_loss)
ax2.plot(test_interval * arange(len(test_acc)), test_acc, 'r')
ax1.set_xlabel('iteration')
ax1.set_ylabel('train loss')
ax2.set_ylabel('test accuracy')
_.savefig('converge.png')

# Check the result of last batch
print(solver.test_nets[0].blobs['ip3'].data)
print(solver.test_nets[0].blobs['label'].data)
The h5list files simply contain the path of an h5 file on each line:
train.h5list
/home/foo/bar/project/train.h5
test.h5list
/home/foo/bar/project/test.h5
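These list files can also be generated from Python instead of being written by hand. A minimal sketch, assuming a hypothetical helper named `write_h5list` and mirroring the absolute paths used in the files above:

```python
import os


def write_h5list(list_path, h5_paths):
    """Write one HDF5 file path per line, converted to an absolute path."""
    with open(list_path, 'w') as f:
        for p in h5_paths:
            f.write(os.path.abspath(p) + '\n')
```

For example, `write_h5list('train.h5list', ['train.h5'])` would produce a one-line list file pointing at `train.h5` in the current directory.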
And the solver:
auto_solver.prototxt
train_net: "auto_train.prototxt"
test_net: "auto_test.prototxt"
test_iter: 10
test_interval: 20
base_lr: 0.01
momentum: 0.9
weight_decay: 0.0005
lr_policy: "inv"
gamma: 0.0001
power: 0.75
display: 100
max_iter: 10000
snapshot: 5000
snapshot_prefix: "sed"
solver_mode: GPU
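For reference, Caffe's "inv" policy computes the learning rate as base_lr * (1 + gamma * iter) ^ (-power). The sketch below only evaluates that formula with the values from this solver file; the function name is hypothetical.

```python
def inv_learning_rate(iteration, base_lr=0.01, gamma=0.0001, power=0.75):
    """Learning rate under the "inv" policy, using this solver's settings as defaults."""
    return base_lr * (1.0 + gamma * iteration) ** (-power)


# Show how slowly the rate decays over the configured max_iter.
for it in (0, 1000, 5000, 10000):
    print('%5d %.6f' % (it, inv_learning_rate(it)))
```

With these settings the rate decays gently, so over the 200 iterations run in main.py it stays very close to `base_lr`.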
Convergence graph: ![Converge graph](https://i.stack.imgur.com/i0pHF.png)
Last batch result:
[[ 35.91593933 -37.46276474 -6.2579031 -6.30313492]
[ 42.69248581 -43.00864792 13.19664764 -3.35134125]
[ -1.36403108 1.38531208 2.77786589 -0.34310576]
[ 2.91686511 -2.88944006 4.34043217 0.32656598]
...
[ 35.91593933 -37.46276474 -6.2579031 -6.30313492]
[ 42.69248581 -43.00864792 13.19664764 -3.35134125]
[ -1.36403108 1.38531208 2.77786589 -0.34310576]
[ 2.91686511 -2.88944006 4.34043217 0.32656598]]
[[ 1. 0. 0. 0.]
[ 1. 0. 1. 0.]
[ 0. 1. 1. 0.]
[ 1. 0. 1. 1.]
...
[ 1. 0. 0. 0.]
[ 1. 0. 1. 0.]
[ 0. 1. 1. 0.]
[ 1. 0. 1. 1.]]
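The thresholding rule used in the test loop (positive score means label 1, otherwise label 0) can be checked on the four distinct rows printed above. This is a vectorized sketch with a hypothetical function name, not the code that produced the graph:

```python
import numpy as np


def multilabel_accuracy(scores, labels):
    """Fraction of individual labels predicted correctly with a threshold at 0."""
    predicted = np.asarray(scores) > 0
    expected = np.asarray(labels) == 1
    return float(np.mean(predicted == expected))


scores = np.array([[35.91593933, -37.46276474, -6.2579031, -6.30313492],
                   [42.69248581, -43.00864792, 13.19664764, -3.35134125],
                   [-1.36403108, 1.38531208, 2.77786589, -0.34310576],
                   [2.91686511, -2.88944006, 4.34043217, 0.32656598]])
labels = np.array([[1, 0, 0, 0],
                   [1, 0, 1, 0],
                   [0, 1, 1, 0],
                   [1, 0, 1, 1]])
print(multilabel_accuracy(scores, labels))  # prints 1.0
```

All 16 signs agree with the labels, so these rows are classified perfectly, although the last two rows have much smaller margins than the first two.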
I think there is still a lot in this code that could be improved. Any suggestions are appreciated.
Shouldn't the dtype of 'data' be 'f4' as well? – Shai
Changing it to f4 does not change the error.
This might be a valuable resource: http://stackoverflow.com/questions/33112941/multiple-category-classification-in-caffe