2017-10-19 80 views
0

我修改了&深層教程的代碼,使用tf.contrib.learn.read_batch_examples從文件中讀取大量輸入。爲了加快訓練過程,我設置了read_batch_size並且出現錯誤ValueError:所有形狀必須完全定義:[TensorShape([]),TensorShape([Dimension(None)])] 我的代碼段:在tf.contrib.learn.read_batch_examples中設置read_batch_size時出現錯誤。默認是好的

def input_fn_pre(batch_size, filename): 
    examples_op = tf.contrib.learn.read_batch_examples(
    filename, 
    batch_size=5000, 
    reader=tf.TextLineReader, 
    num_epochs=5, 
    num_threads=5, 
    read_batch_size=2500, 
    parse_fn=lambda x: tf.decode_csv(x, [tf.constant(['0'], dtype=tf.string)] * len(COLUMNS) * 2500, use_quote_delim=False))         
    examples_dict = {} 

    for i, col in enumerate(COLUMNS): 
    examples_dict[col] = examples_op[:, i] 
    feature_cols = {k: tf.string_to_number(examples_dict[k], out_type=tf.float32) for k in CONTINUOUS_COLUMNS} 
    feature_cols.update({k: dense_to_sparse(examples_dict[k]) for k in CATEGORICAL_COLUMNS}) 
    label = tf.string_to_number(examples_dict[LABEL_COLUMN], out_type=tf.int32) 
    return feature_cols, label 

而使用默認參數設置就可以了:

def input_fn_pre(batch_size, filename): 
    examples_op = tf.contrib.learn.read_batch_examples(
    filename, 
    batch_size=5000, 
    reader=tf.TextLineReader, 
    num_epochs=5, 
    num_threads=5, 
    parse_fn=lambda x: tf.decode_csv(x, [tf.constant(['0'], dtype=tf.string)] * len(COLUMNS), use_quote_delim=False))         
    examples_dict = {} 

    for i, col in enumerate(COLUMNS): 
    examples_dict[col] = examples_op[:, i] 
    feature_cols = {k: tf.string_to_number(examples_dict[k], out_type=tf.float32) for k in CONTINUOUS_COLUMNS} 
    feature_cols.update({k: dense_to_sparse(examples_dict[k]) for k in CATEGORICAL_COLUMNS}) 
    label = tf.string_to_number(examples_dict[LABEL_COLUMN], out_type=tf.int32) 
    return feature_cols, label 

沒有在tensorflow DOC足夠的解釋。

回答

0

我沒有看到你的兩個代碼片段之間的任何區別。你能更新你的代碼嗎?

+0

對不起,我更新了我的代碼。 –

+0

通過閱讀代碼,錯誤並不明顯。你能發佈你最小的完整(可運行)代碼嗎? – Mingxing

相關問題