事實證明,這是棘手近鄰因爲TF沒有Numpy切片普遍性(github issue #206),並且gather
僅適用於第一維。但這裏有一個方法,通過使用gather-> transpose-> gather->提取對角線來解決它
def identity_matrix(n):
"""Returns nxn identity matrix."""
# note, if n is a constant node, this assert node won't be executed,
# this error will be caught during shape analysis
assert_op = tf.Assert(tf.greater(n, 0), ["Matrix size must be positive"])
with tf.control_dependencies([assert_op]):
ones = tf.fill(n, 1)
diag = tf.diag(ones)
return diag
def extract_diagonal(tensor):
"""Extract diagonal of a square matrix."""
shape = tf.shape(tensor)
n = shape[0]
assert_op = tf.Assert(tf.equal(shape[0], shape[1]), ["Can't get diagonal of "
"a non-square matrix"])
with tf.control_dependencies([assert_op]):
return tf.reduce_sum(tf.mul(tensor, identity_matrix(n)), [0])
# create sample matrix
size=4
I0=np.zeros((size,size), dtype=np.int32)
for i in range(size):
for j in range(size):
I0[i, j] = 10*i+j
I = tf.placeholder(dtype=np.int32, shape=(size,size))
C = tf.placeholder(np.int32, shape=[None, 2])
C0 = np.array([[0, 1], [1, 2], [2, 3]])
row_indices = C[:, 0]
col_indices = C[:, 1]
# since gather only supports dim0, have to transpose
I1 = tf.gather(I, row_indices)
I2 = tf.gather(tf.transpose(I1), col_indices)
I3 = extract_diagonal(tf.transpose(I2))
sess = create_session()
print sess.run([I3], feed_dict={I:I0, C:C0})
所以用這樣的矩陣開始:
array([[ 0, 1, 2, 3],
[10, 11, 12, 13],
[20, 21, 22, 23],
[30, 31, 32, 33]], dtype=int32)
此代碼提取一個對角線上方主要
[array([ 1, 12, 23], dtype=int32)]
有一些魔術[]運營商發生越來越變成Squeeze
和Slice
有效率的方式使用tf.gather做近鄰插值...不知道關於線性插值 –
有'tf.image.resize_bilinear'。這不是你想要的嗎? – Albert
@Albert號這隻會做我想要的,如果我想在網格上的所有點上採樣源圖像。但是C的行可以是源圖像中的任何座標。 – CliffordVienna