2017-08-23 53 views
TensorFlow: tf.argmax and slicing

I want to design a loss function like this:

sum((y[argmax(y_)] - y_[argmax(y_)])²) 
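To pin down what this formula is meant to compute, here is a small plain-Python sketch. It assumes y and y_ are lists of rows and that the argmax is taken per row; the helper names argmax and reference_loss are made up for illustration and are not TensorFlow functions.

```python
def argmax(row):
    # Index of the largest entry in a row (the first one wins on ties).
    return max(range(len(row)), key=lambda j: row[j])


def reference_loss(y, y_):
    # For each row, compare y and y_ at the column where y_ is largest,
    # then sum the squared differences over all rows.
    total = 0.0
    for row_y, row_t in zip(y, y_):
        k = argmax(row_t)
        total += (row_y[k] - row_t[k]) ** 2
    return total


y = [[0.5, 2.0, 1.0], [3.0, 0.0, 1.0]]
y_ = [[0.0, 1.0, 0.0], [0.0, 0.0, 4.0]]
print(reference_loss(y, y_))  # (2.0 - 1.0)**2 + (1.0 - 4.0)**2 = 10.0
```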

I couldn't find a way to do y[argmax(y_)]. I tried y[k], y[:,k], and y[None,k]; none of them work. Here is my code:

import tensorflow as tf

Na = 3
x = tf.placeholder(tf.float32, [None, 2])
W = tf.Variable(tf.zeros([2, Na]))
b = tf.Variable(tf.zeros([Na]))
y = tf.nn.relu(tf.matmul(x, W) + b)
y_ = tf.placeholder(tf.float32, [None, 3])
# Column index of the largest entry in each row of y_
k = tf.argmax(y_, 1)
# This is the line that fails
diff = y[k] - y_[k]
loss = tf.reduce_sum(tf.square(diff))

And the error:

File "/home/ncarrara/phd/code/cython/robotnavigation/ftq/cftq19.py", line 156, in <module> 
    diff = y[k] - y_[k] 
    File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", line 499, in _SliceHelper 
    name=name) 
    File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", line 663, in strided_slice 
    shrink_axis_mask=shrink_axis_mask) 
    File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 3515, in strided_slice 
    shrink_axis_mask=shrink_axis_mask, name=name) 
    File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 767, in apply_op 
    op_def=op_def) 
    File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2508, in create_op 
    set_shapes_for_outputs(ret) 
    File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1873, in set_shapes_for_outputs 
    shapes = shape_func(op) 
    File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1823, in call_with_requiring 
    return call_cpp_shape_fn(op, require_shape_fn=True) 
    File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/framework/common_shapes.py", line 610, in call_cpp_shape_fn 
    debug_python_shape_fn, require_shape_fn) 
    File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/framework/common_shapes.py", line 676, in _call_cpp_shape_fn_impl 
    raise ValueError(err.message) 
ValueError: Shape must be rank 1 but is rank 2 for 'strided_slice' (op: 'StridedSlice') with input shapes: [?,3], [1,?], [1,?], [1]. 

Answer

This can be done using tf.gather_nd:

import tensorflow as tf 

Na = 3 
x = tf.placeholder(tf.float32, [None, 2]) 
W = tf.Variable(tf.zeros([2, Na])) 
b = tf.Variable(tf.zeros([Na])) 
y = tf.nn.relu(tf.matmul(x, W) + b) 
y_ = tf.placeholder(tf.float32, [None, 3]) 
k = tf.argmax(y_, 1) 
# Make index tensor with row and column indices 
num_examples = tf.cast(tf.shape(x)[0], dtype=k.dtype) 
idx = tf.stack([tf.range(num_examples), k], axis=-1) 
diff = tf.gather_nd(y, idx) - tf.gather_nd(y_, idx) 
loss = tf.reduce_sum(tf.square(diff)) 

Explanation:

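The idea with tf.gather_nd in this case is to build a matrix (a two-dimensional tensor) in which each row holds the row and column index of one element to put in the output. For example, if I have a matrix a containing:

| 1 2 3 |
| 4 5 6 |
| 7 8 9 |

and a matrix i containing:

| 1 2 |
| 0 1 |
| 2 2 |
| 1 0 |

then the result of tf.gather_nd(a, i) would be the vector (one-dimensional tensor) containing:

| 6 |
| 2 |
| 9 |
| 4 |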
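
The same lookup can be imitated in plain Python to check the numbers by hand. This is only a sketch of the behaviour for a 2-D input with [row, column] index pairs, not TensorFlow's implementation; gather_nd_2d is a made-up helper name.

```python
def gather_nd_2d(a, indices):
    # For each [row, col] pair, pick the single element a[row][col].
    return [a[r][c] for r, c in indices]


a = [[1, 2, 3],
     [4, 5, 6],
     [7, 8, 9]]
i = [[1, 2],
     [0, 1],
     [2, 2],
     [1, 0]]
print(gather_nd_2d(a, i))  # [6, 2, 9, 4]
```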

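In this case, the column indices are given by tf.argmax in k; for each row, it tells you which column holds the highest value. Now you just need to pair each of these with its row index. The first element of k is the index of the maximum-value column in row 0, the next element is the one for row 1, and so on. num_examples is just the number of rows in x, and tf.range(num_examples) then gives you a vector running from 0 to the number of rows in x minus 1 (that is, the sequence of all row indices). Putting that together with k is what tf.stack does, and the result, idx, is the argument for tf.gather_nd.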
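
As a rough plain-Python analogue of these steps (lists stand in for tensors, and the sample values of y_ are invented for illustration), the index matrix can be built and used like this:

```python
y_ = [[0.0, 1.0, 0.0],
      [5.0, 0.0, 0.0],
      [0.0, 0.0, 4.0]]

# Column of the largest value in each row, playing the role of tf.argmax(y_, 1).
k = [max(range(len(row)), key=row.__getitem__) for row in y_]

# All row indices, playing the role of tf.range(num_examples).
row_ids = list(range(len(y_)))

# Pair each row index with its column, playing the role of tf.stack(..., axis=-1).
idx = [list(pair) for pair in zip(row_ids, k)]

# One element per row, playing the role of tf.gather_nd(y_, idx).
picked = [y_[r][c] for r, c in idx]

print(k)       # [1, 0, 2]
print(idx)     # [[0, 1], [1, 0], [2, 2]]
print(picked)  # [1.0, 5.0, 4.0]
```

Each row of idx is a [row, column] pair, so the gather returns exactly one value per example.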

Looks good, but right now I'm not sure I know enough to validate your answer. Thanks anyway! –

@nicolascarrara I've added some explanation. – jdehesa

Thanks a lot, it's clear now! –