2016-10-02 69 views
1

我正在嘗試使用tensorflow LSTMs進行時間序列預測。我使用的lstm-for-epf.py修改後的版本從repo時間序列的tensorflow lstm模型

import numpy as np 
import pandas as pd 
from matplotlib import pyplot as plt 

from tensorflow.contrib import learn 
from sklearn.metrics import mean_squared_error, mean_absolute_error 
from lstm_predictor import generate_data, load_csvdata, lstm_model 


LOG_DIR = './ops_logs' 
TIMESTEPS = 10 
RNN_LAYERS = [{'steps': TIMESTEPS}] 
DENSE_LAYERS = [10, 10] 
TRAINING_STEPS = 100000 
BATCH_SIZE = 100 
PRINT_STEPS = TRAINING_STEPS/100 

dateparse = lambda dates: pd.datetime.strptime(dates, '%d/%m/%Y %H:%M') 
rawdata = pd.read_csv("RealMarketPriceDataPT.csv", 
        parse_dates={'timeline': ['date', '(UTC)']}, 
        index_col='timeline', date_parser=dateparse) 


X, y = load_csvdata(rawdata, TIMESTEPS, seperate=False) 


regressor = learn.TensorFlowEstimator(model_fn=lstm_model(TIMESTEPS, RNN_LAYERS, DENSE_LAYERS), 
             n_classes=0, 
             verbose=1, 
             steps=TRAINING_STEPS, 
             optimizer='Adagrad', 
             learning_rate=0.03, 
             batch_size=BATCH_SIZE) 




validation_monitor = learn.monitors.ValidationMonitor(X['val'], y['val'], 
                 every_n_steps=PRINT_STEPS, 
                 early_stopping_rounds=1000, 
                 batch_size=BATCH_SIZE) 

regressor.fit(X['train'], y['train'], monitors=[validation_monitor], logdir=LOG_DIR) 


predicted = regressor.predict(X['test']) 
mse = mean_absolute_error(y['test'], predicted) 
print ("Error: %f" % mse) 

# plot_predicted, = plt.plot(predicted, label='predicted') 
# plot_test, = plt.plot(y['test'], label='test') 
# plt.legend(handles=[plot_predicted, plot_test]) 

這是給錯誤。

Traceback (most recent call last): 
    File "lstm-for-epf.py", line 43, in <module> 
    regressor.fit(X['train'], y['train'], monitors=[validation_monitor], logdir=LOG_DIR) 
    File "/home/tensorflow/anaconda3/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/estimators/base.py", line 166, in fit 
    monitors=monitors) 
    File "/home/tensorflow/anaconda3/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 578, in _train_model 
    max_steps=max_steps) 
    File "/home/tensorflow/anaconda3/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/graph_actions.py", line 280, in _supervised_train 
    None) 
    File "/home/tensorflow/anaconda3/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/supervised_session.py", line 270, in run 
    run_metadata=run_metadata) 
    File "/home/tensorflow/anaconda3/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/recoverable_session.py", line 54, in run 
    run_metadata=run_metadata) 
    File "/home/tensorflow/anaconda3/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/coordinated_session.py", line 70, in run 
    self._coord.join(self._coordinated_threads_to_join) 
    File "/home/tensorflow/anaconda3/lib/python3.5/site-packages/tensorflow/python/training/coordinator.py", line 357, in join 
    six.reraise(*self._exc_info_to_raise) 
    File "/home/tensorflow/anaconda3/lib/python3.5/site-packages/six.py", line 686, in reraise 
    raise value 
    File "/home/tensorflow/anaconda3/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/coordinated_session.py", line 66, in run 
    return self._sess.run(*args, **kwargs) 
    File "/home/tensorflow/anaconda3/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/monitored_session.py", line 107, in run 
    induce_stop = monitor.step_end(monitors_step, monitor_outputs) 
    File "/home/tensorflow/anaconda3/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/monitors.py", line 396, in step_end 
    return self.every_n_step_end(step, output) 
    File "/home/tensorflow/anaconda3/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/monitors.py", line 687, in every_n_step_end 
    steps=self.eval_steps, metrics=self.metrics, name=self.name) 
TypeError: evaluate() got an unexpected keyword argument 'batch_size' 

回答