2016-09-06 58 views
0

我正在使用TF.Learn estimator來進行預測。數據到fit method作爲一個返回特徵映射一個輸入函數傳遞 - 一個python字典映射功能名稱的張量從盤出列他們:如何過濾存儲在功能圖中的數據點?

def input_fn(): 
    feature_columns = get_feature_columns() 
    features = tf.contrib.layers.create_feature_spec_for_parsing(feature_columns=feature_columns) 
    feature_map = tf.contrib.learn.io.read_batch_features(
     file_pattern=data_dir, 
     batch_size=BATCH_SIZE, 
     features=features) 
    target = feature_map.pop("target") 
    return feature_map, target 

我想過濾基於一些數據謂詞P,以便估計器將批量批次BATCH_SIZE中的點仍然批量化,但只有那些滿足P的點才能批量化。我怎樣才能輕鬆實現?

(問題是類似於:How to filter tensor from queue based on some predicate in tensorflow?,但你過濾只有一個張量)

回答

1

使用過濾隊列和具有從read_batch_features的結果取出一個單個元素的queuerunner和有條件地排入它的濾波基於你的謂詞的隊列應該工作。

+0

不幸的是,[read_batch_features]的輸出(https://github.com/tensorflow/tensorflow/blob/master/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.learn.read_batch_features.md# tfcontriblearnread_batch_featuresfile_pattern-batch_size-features-reader-randomize_inputtrue-num_epochsnone-queue_capacity10000-feature_queue_capacity100-reader_num_threads1-parser_num_threads1 -namenone-read_batch_features)是字典,而不是隊列。 我試着創建一個隊列並排隊'feature_map.values()',但得到了類型錯誤(float32 vs SparseTensor)。 – sygi

+0

我也嘗試過評估feature_map並將其排入隊列,但仍然存在SparseTensors的問題 - 它們的內容是ndarray類型的,但我無法將其指定爲隊列的dtype。 – sygi

+0

要排隊sparsetensor,需要序列化它。 –

相關問題