2017-07-27 89 views
2

我已經運行腳本5個多小時了。我有258個CSV文件,我想要轉換爲TF記錄。我寫了下面的腳本,正如我已經說過,我一直在運行它5個多小時已經:將CSV文件轉換爲TF記錄

import argparse 
import os 
import sys 
import standardize_data 
import tensorflow as tf 

FLAGS = None 
PATH = '/home/darth/GitHub Projects/gru_svm/dataset/train' 

def _int64_feature(value): 
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 

def _float_feature(value): 
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value])) 

def convert_to(dataset, name): 
    """Converts a dataset to tfrecords""" 

    filename_queue = tf.train.string_input_producer(dataset) 

    # TF reader 
    reader = tf.TextLineReader() 

    # default values, in case of empty columns 
    record_defaults = [[0.0] for x in range(24)] 

    key, value = reader.read(filename_queue) 

    duration, service, src_bytes, dest_bytes, count, same_srv_rate, \ 
    serror_rate, srv_serror_rate, dst_host_count, dst_host_srv_count, \ 
    dst_host_same_src_port_rate, dst_host_serror_rate, dst_host_srv_serror_rate, \ 
    flag, ids_detection, malware_detection, ashula_detection, label, src_ip_add, \ 
    src_port_num, dst_ip_add, dst_port_num, start_time, protocol = \ 
    tf.decode_csv(value, record_defaults=record_defaults) 

    features = tf.stack([duration, service, src_bytes, dest_bytes, count, same_srv_rate, 
         serror_rate, srv_serror_rate, dst_host_count, dst_host_srv_count, 
         dst_host_same_src_port_rate, dst_host_serror_rate, dst_host_srv_serror_rate, 
         flag, ids_detection, malware_detection, ashula_detection, src_ip_add, 
         src_port_num, dst_ip_add, dst_port_num, start_time, protocol]) 

    filename = os.path.join(FLAGS.directory, name + '.tfrecords') 
    print('Writing {}'.format(filename)) 
    writer = tf.python_io.TFRecordWriter(filename) 
    with tf.Session() as sess: 
     coord = tf.train.Coordinator() 
     threads = tf.train.start_queue_runners(coord=coord) 
     try: 
      while not coord.should_stop(): 
       example, l = sess.run([features, label]) 
       print('Writing {dataset} : {example}, {label}'.format(dataset=sess.run(key), 
         example=example, label=l)) 
       example_to_write = tf.train.Example(features=tf.train.Features(feature={ 
        'duration' : _float_feature(example[0]), 
        'service' : _int64_feature(int(example[1])), 
        'src_bytes' : _float_feature(example[2]), 
        'dest_bytes' : _float_feature(example[3]), 
        'count' : _float_feature(example[4]), 
        'same_srv_rate' : _float_feature(example[5]), 
        'serror_rate' : _float_feature(example[6]), 
        'srv_serror_rate' : _float_feature(example[7]), 
        'dst_host_count' : _float_feature(example[8]), 
        'dst_host_srv_count' : _float_feature(example[9]), 
        'dst_host_same_src_port_rate' : _float_feature(example[10]), 
        'dst_host_serror_rate' : _float_feature(example[11]), 
        'dst_host_srv_serror_rate' : _float_feature(example[12]), 
        'flag' : _int64_feature(int(example[13])), 
        'ids_detection' : _int64_feature(int(example[14])), 
        'malware_detection' : _int64_feature(int(example[15])), 
        'ashula_detection' : _int64_feature(int(example[16])), 
        'label' : _int64_feature(int(l)), 
        'src_ip_add' : _float_feature(example[17]), 
        'src_port_num' : _float_feature(example[18]), 
        'dst_ip_add' : _float_feature(example[19]), 
        'dst_port_num' : _float_feature(example[20]), 
        'start_time' : _float_feature(example[21]), 
        'protocol' : _int64_feature(int(example[22])), 
        })) 
       writer.write(example_to_write.SerializeToString()) 
      writer.close() 
     except tf.errors.OutOfRangeError: 
      print('Done converting -- EOF reached.') 
     finally: 
      coord.request_stop() 

     coord.join(threads) 

def main(unused_argv): 
    files = standardize_data.list_files(path=PATH) 

    convert_to(dataset=files, name='train') 

這已經讓我想,也許是陷入無限循環?我想要做的是讀取每個CSV文件(258個CSV文件)中的所有行,並將這些行寫入TF Record(功能和標籤,當然是這樣)。然後,當沒有更多行可用時停止循環,或者CSV文件已經用盡。

standardize_data.list_files(path)是我在不同模塊中編寫的函數。我剛剛重新使用它的腳本。它所做的是返回在PATH中找到的所有文件的列表。請注意,我的PATH中的文件僅包含CSV文件。

回答

1

設置num_epochs=1string_input_producer。另一個注意事項:將這些csv轉換爲tfrecords可能不會提供您在tfrecords中尋找的任何優勢,但對於此類數據(具有大量單一功能/標籤),開銷非常高。你可能想要試驗這個部分。

+0

換句話說,你是不是建議把它們轉換成TF記錄? –

+0

做這個實驗:只轉換一個文件,然後檢查各自的大小。您的數據對於'tfrecords'表示效率不高。每個功能都與標籤一起保存,所以我認爲它的尺寸會比保存爲csv的尺寸大。 –

+0

示例CSV文件是10.1 MB,其等效的tfrecord是9.6 MB –