0
我正在使用TF.Learn estimator來進行預測。數據到fit method作爲一個返回特徵映射一個輸入函數傳遞 - 一個python字典映射功能名稱的張量從盤出列他們:如何過濾存儲在功能圖中的數據點?
def input_fn():
feature_columns = get_feature_columns()
features = tf.contrib.layers.create_feature_spec_for_parsing(feature_columns=feature_columns)
feature_map = tf.contrib.learn.io.read_batch_features(
file_pattern=data_dir,
batch_size=BATCH_SIZE,
features=features)
target = feature_map.pop("target")
return feature_map, target
我想過濾基於一些數據謂詞P
,以便估計器將批量批次BATCH_SIZE中的點仍然批量化,但只有那些滿足P
的點才能批量化。我怎樣才能輕鬆實現?
(問題是類似於:How to filter tensor from queue based on some predicate in tensorflow?,但你過濾只有一個張量)
不幸的是,[read_batch_features]的輸出(https://github.com/tensorflow/tensorflow/blob/master/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.learn.read_batch_features.md# tfcontriblearnread_batch_featuresfile_pattern-batch_size-features-reader-randomize_inputtrue-num_epochsnone-queue_capacity10000-feature_queue_capacity100-reader_num_threads1-parser_num_threads1 -namenone-read_batch_features)是字典,而不是隊列。 我試着創建一個隊列並排隊'feature_map.values()',但得到了類型錯誤(float32 vs SparseTensor)。 – sygi
我也嘗試過評估feature_map並將其排入隊列,但仍然存在SparseTensors的問題 - 它們的內容是ndarray類型的,但我無法將其指定爲隊列的dtype。 – sygi
要排隊sparsetensor,需要序列化它。 –