2017-06-13 62 views
0

我期待使用1.2中可用的new Dataset API,但在應用簡單的map轉換時會遇到問題,該轉換會在index table中查找單詞。新的dataset.map轉換和查找表:不兼容的字符串類型

考慮一個簡單的例子:

import tensorflow as tf 

mapping_strings = tf.constant(["emerson", "lake", "palmer"]) 
table = tf.contrib.lookup.index_table_from_tensor(
    mapping=mapping_strings, num_oov_buckets=1) 

dataset = tf.contrib.data.Dataset.from_tensor_slices(
    tf.constant(["emerson", "lake"])) 

# Here is the map operation that generates an error. 
dataset = dataset.map(lambda x: table.lookup(x)) 

iterator = dataset.make_one_shot_iterator() 
next_element = iterator.get_next() 

with tf.Session() as sess: 
    sess.run(tf.tables_initializer()) 
    sess.run(next_element) 

隨着1.2.0-rc2,它會生成以下錯誤:

TypeError: In op 'string_to_index_Lookup/hash_table_Lookup', input types ([tf.string, tf.string, tf.int64]) are not compatible with expected types ([tf.string_ref, tf.string, tf.int64]) 

查找表需要一個tf.string_ref這個規定並沒有得到滿足。

由於我是TensorFlow的新手,我不認爲這是一個錯誤,而是一個糟糕的用法。我的錯誤是什麼?

謝謝!

編輯2017年6月15日:隨着nightly版本,但是,它拋出另一個錯誤:

ValueError: Cannot capture a stateful node (name:string_to_index/hash_table, type:HashTableV2) by value. 

回答

3

您可能需要使用Dataset.make_initializable_iterator(),而不是Dataset.make_one_shot_iterator()因爲哈希表狀態。

下面的代碼爲我工作:

import tensorflow as tf 

mapping_strings = tf.constant(["emerson", "lake", "palmer"]) 
table = tf.contrib.lookup.index_table_from_tensor(
    mapping=mapping_strings, num_oov_buckets=1) 

dataset = tf.contrib.data.Dataset.from_tensor_slices(
    tf.constant(["emerson", "lake"])) 

# Here is the map operation that generates an error. 
dataset = dataset.map(lambda x: table.lookup(x)) 

iterator = dataset.make_initializable_iterator() 
init_op = iterator.initializer 

with tf.Session() as sess: 
    sess.run(tf.tables_initializer()) 
    sess.run(init_op) 
+0

我相信當你升級到TensorFlow的最新(夜間)版本的工作原理長。在1.2.0的最新版本候選版本中存在一個bug,其中一些廢棄的內核(使用舊式'tf.string_ref'而不是'tf.resource')繼續在'tf.contrib.lookup.index_table_from_tensor )',現在已經修復了這些問題,以便在主分支中使用新版本。 – mrry

+0

謝謝你們兩位。所以解決方法是升級TensorFlow並使用迭代器初始值設定項。我更新了我的問題,以反映用TensorFlow版本得到的錯誤,而沒有提到錯誤@mrry,我接受了Satoshi的答案。 – guillaumekln

+0

tf.string_ref錯誤仍然存​​在於1.2.0 final中。 – guillaumekln

相關問題