我正在實現CNN體系結構(FCN-8s模型,具有預訓練VGG16模型),用於在我自己的數據上進行語義分割(2個類,因此,每個像素的二進制分類)Tensorflow - 傳輸學習實現(語義分割)
我怎麼打算去這就是:
- 負載配重塊的預訓練模型
- 添加/刪除其他高層轉換爲FCN
- 凍結的下層預先訓練的模型(不要你在訓練階段PDATE)
- 列車上的特定數據集
網絡假設這是正確的,我怎麼去凍結我tensorflow模型較低層? (我正在尋找具體的實現細節)我看了一下關於TensorFlow教程的Inception再訓練,但我還不確定。
這是工作流程我心目中:
通過現有的預訓練模式運行我的數據,並提取特徵輸出,沒有訓練它。 (如何?)
將這些功能輸出饋送到包含更高層的另一個網絡 - 並開始訓練它。
任何建議將是有幫助的!
否則,如果我錯了,我應該怎麼想這個?
UPDATE:
我拿起chasep255的建議之下,並試圖用tf.stop_gradient,以便在我的模型‘凍結’的較低層。很顯然,我的實施有些問題。可能的選擇/建議?
該模型基於FCN(用於語義分割)論文構建。我從模型體系結構(即我的特徵)中提取logits
,我最初直接將其輸入loss
函數,以使用softmax分類器將其最小化。 (每像素分類)deconv_1
是我logits張量,形狀的[batch, h, w, num_classes] = [1, 750, 750, 2]
實現:
logits = vgg_fcn.deconv_1
stopper = tf.stop_gradient(logits, 'stop_gradients')
loss = train_func.loss(stopper, labels_placeholder, 2)
with tf.name_scope('Optimizer'):
train_op = train_func.training(loss, FLAGS.learning_rate)
with tf.name_scope('Accuracy'):
eval_correct = train_func.accuracy_eval(logits, labels_placeholder)
accuracy_summary = tf.scalar_summary('Accuracy', eval_correct)
我然後如下運行這些圖形操作:
_, acc, loss_value = sess.run([train_op,eval_correct, loss], feed_dict=feed_dict)
當我從而運行該訓練週期,有沒有優化損失值,絕對是因爲我已經介紹了tf.stop_gradient
Op。
有關詳細信息,下面我損失函數:
def loss(logits, labels, num_classes):
logits = tf.reshape(logits, [-1, num_classes])
#epsilon = tf.constant(value=1e-4)
#logits = logits + epsilon
labels = tf.to_int64(tf.reshape(labels, [-1]))
print ('shape of logits: %s' % str(logits.get_shape()))
print ('shape of labels: %s' % str(labels.get_shape()))
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels, name='Cross_Entropy')
cross_entropy_mean = tf.reduce_mean(cross_entropy, name='xentropy_mean')
tf.add_to_collection('losses', cross_entropy_mean)
loss = tf.add_n(tf.get_collection('losses'), name='total_loss')
return loss
嗨murushiv,關於您的代碼logits = vgg_fcn.deconv_1,您是否使用MarvinTeichmann的fcn實現?我讀他的代碼,並沒有找到deconv_1。你想分享更多信息嗎? – user288609
@ user288609它是一個稍微修改過的版本。 deconv_1與其中一種比較方法相同。 (或模塊?)但是我意識到這裏有一個錯誤,因爲我應該在高分之前進行攻擊,使用它作爲logits並訓練上採樣層,如果這樣做合理的話。 – mshiv
嗨murushiv,在他們的fcn實現中,在「self.upscore2」之前有「pred」層,你的意思是直接在損失函數中使用logits(pred)嗎?順便說一下,你說有錯誤。你能詳細解釋一下嗎?我試圖瞭解他們的實施。感謝您的幫助。準確地說是 – user288609