Can this be done elegantly? Writing and reading a SparseTensor to and from a tfrecord file

Right now the only thing I can think of is to save the SparseTensor's indices (tf.int64), values (tf.float32), and shape (tf.int64) in 3 separate features (the first two as VarLenFeature and the last as FixedLenFeature). This seems cumbersome.
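
For concreteness, a minimal sketch of that three-feature approach might look like the following (the indices/values/shape here are made-up example data, and it assumes the TF 1.x tf.train.Example API):

import tensorflow as tf 
import numpy as np 

# made-up data standing in for SparseTensor.indices / .values / .dense_shape 
indices = np.array([[0, 1], [2, 0]], dtype=np.int64) 
values = np.array([1.5, 2.5], dtype=np.float32) 
dense_shape = np.array([3, 4], dtype=np.int64) 

example = tf.train.Example(features=tf.train.Features(feature={ 
    # flatten the (N, ndims) index matrix so it fits into an Int64List 
    'indices': tf.train.Feature(int64_list=tf.train.Int64List(value=indices.reshape(-1))), 
    'values': tf.train.Feature(float_list=tf.train.FloatList(value=values)), 
    'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=dense_shape)), 
    })) 
# when parsing: 'indices' and 'values' as tf.VarLenFeature, 'shape' as tf.FixedLenFeature([2], tf.int64) 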

Any advice is appreciated!

Update 1

My answer below is not suitable for building a computation graph (because the contents of the sparse tensor have to be extracted via sess.run(), which costs a lot of time if called repeatedly).

Inspired by mrry's answer, I thought maybe we could get the bytes generated by tf.serialize_sparse so that we could later use tf.deserialize_many_sparse to recover the SparseTensor. Unfortunately, tf.serialize_sparse is not implemented in pure Python (it calls the external function SerializeSparse), which means we would still need sess.run() to get the bytes. How can I get a pure Python version of SerializeSparse? Thanks.
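
For reference, a rough sketch (TF 1.x, with made-up values) of the tf.serialize_sparse / tf.deserialize_many_sparse round trip mentioned above; note the sess.run() call, which is exactly the step this update is trying to avoid:

import tensorflow as tf 

sp = tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1.0, 2.0], dense_shape=[3, 4]) 
serialized = tf.serialize_sparse(sp)   # 1-D string tensor of shape [3] 

# stack one or more serialized tensors into shape [N, 3] and recover a batched SparseTensor 
restored = tf.deserialize_many_sparse(tf.stack([serialized]), dtype=tf.float32) 

with tf.Session() as sess: 
    # this sess.run() is the part that cannot be done in pure Python 
    print(sess.run(serialized)) 
    print(sess.run(restored)) 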

Answers


Since TensorFlow currently supports only 3 types in tfrecord (Float, Int64, and Bytes), and a SparseTensor usually involves more than one type, my solution is to convert the SparseTensor to Bytes with pickle.

Here is some sample code:

import tensorflow as tf 
import pickle 
import numpy as np 
from scipy.sparse import csr_matrix 

#---------------------------------# 
# Write to a tfrecord file 

# create two sparse matrices (simulate the values from .eval() of SparseTensor) 
a = csr_matrix(np.arange(12).reshape((4,3))) 
b = csr_matrix(np.random.rand(20).reshape((5,4))) 

# convert them to pickle bytes 
p_a = pickle.dumps(a) 
p_b = pickle.dumps(b) 

# put the bytes in context_list and feature_list 
## save p_a in context_lists 
context_lists = tf.train.Features(feature={ 
    'context_a': tf.train.Feature(bytes_list=tf.train.BytesList(value=[p_a])) 
    }) 
## save p_b as a one element sequence in feature_lists 
p_b_features = [tf.train.Feature(bytes_list=tf.train.BytesList(value=[p_b]))] 
feature_lists = tf.train.FeatureLists(feature_list={ 
    'features_b': tf.train.FeatureList(feature=p_b_features) 
    }) 

# create the SequenceExample 
SeqEx = tf.train.SequenceExample(
    context = context_lists, 
    feature_lists = feature_lists 
    ) 
SeqEx_serialized = SeqEx.SerializeToString() 

# write to a tfrecord file 
tf_FWN = 'test_pickle1.tfrecord' 
tf_writer1 = tf.python_io.TFRecordWriter(tf_FWN) 
tf_writer1.write(SeqEx_serialized) 
tf_writer1.close() 

#---------------------------------# 
# Read from the tfrecord file 

# first, define the parse function 
def _parse_SE_test_pickle1(in_example_proto): 
    context_features = { 
        'context_a': tf.FixedLenFeature([], dtype=tf.string) 
        } 
    sequence_features = { 
        'features_b': tf.FixedLenSequenceFeature([1], dtype=tf.string) 
        } 
    context, sequence = tf.parse_single_sequence_example(
        in_example_proto, 
        context_features=context_features, 
        sequence_features=sequence_features 
        ) 
    p_a_tf = context['context_a'] 
    p_b_tf = sequence['features_b'] 

    return tf.tuple([p_a_tf, p_b_tf]) 

# use the Dataset API to read 
dataset = tf.data.TFRecordDataset(tf_FWN) 
dataset = dataset.map(_parse_SE_test_pickle1) 
dataset = dataset.batch(1) 
iterator = dataset.make_initializable_iterator() 
next_element = iterator.get_next() 

sess = tf.InteractiveSession() 
sess.run(tf.global_variables_initializer()) 
sess.run(iterator.initializer) 

[p_a_bat, p_b_bat] = sess.run(next_element) 

# 1st index refers to the batch; the 2nd and 3rd indices refer to the sequence position (only for b) 
rec_a = pickle.loads(p_a_bat[0]) 
rec_b = pickle.loads(p_b_bat[0][0][0]) 

# check whether the recovered the same as the original ones. 
assert((rec_a - a).nnz == 0) 
assert((rec_b - b).nnz == 0) 

# print the contents 
print("\n------ a -------") 
print(a.todense()) 
print("\n------ rec_a -------") 
print(rec_a.todense()) 
print("\n------ b -------") 
print(b.todense()) 
print("\n------ rec_b -------") 
print(rec_b.todense()) 

Here is what I got:

------ a ------- 
[[ 0  1  2] 
 [ 3  4  5] 
 [ 6  7  8] 
 [ 9 10 11]] 

------ rec_a ------- 
[[ 0  1  2] 
 [ 3  4  5] 
 [ 6  7  8] 
 [ 9 10 11]] 

------ b ------- 
[[ 0.88612402  0.51438017  0.20077887  0.20969243] 
 [ 0.41762425  0.47394715  0.35596051  0.96074408] 
 [ 0.35491739  0.0761953   0.86217511  0.45796474] 
 [ 0.81253723  0.57032448  0.94959189  0.10139615] 
 [ 0.92177499  0.83519464  0.96679833  0.41397829]] 

------ rec_b ------- 
[[ 0.88612402  0.51438017  0.20077887  0.20969243] 
 [ 0.41762425  0.47394715  0.35596051  0.96074408] 
 [ 0.35491739  0.0761953   0.86217511  0.45796474] 
 [ 0.81253723  0.57032448  0.94959189  0.10139615] 
 [ 0.92177499  0.83519464  0.96679833  0.41397829]]