2017-07-26 70 views

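TensorFlow crashes when fitting TensorForestEstimator

I am trying to fit a TensorForestEstimator model with numerical floating-point data representing 7 features and 7 labels. That is, both features and labels have shape (484876, 7). I set num_classes=7 and num_features=7 in ForestHParams accordingly. The data is formatted as follows: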

f1  f2  f3 f4  f5 f6 f7 l1  l2  l3  l4  l5  l6  l7 
39000.0 120.0 65.0 1000.0 25.0 0.69 3.94 39000.0 39959.0 42099.0 46153.0 49969.0 54127.0 55911.0 
32000.0 185.0 65.0 1000.0 75.0 0.46 2.19 32000.0 37813.0 43074.0 48528.0 54273.0 60885.0 63810.0 
30000.0 185.0 65.0 1000.0 25.0 0.41 1.80 30000.0 32481.0 35409.0 39145.0 42750.0 46678.0 48595.0 

When fit() is called, Python crashes with the following message:

Python quit unexpectedly while using the _pywrap_tensorflow_internal.so plug-in.

Here is the output with tf.logging.set_verbosity('INFO') enabled:

INFO:tensorflow:training graph for tree: 0 
INFO:tensorflow:training graph for tree: 1 
... 
INFO:tensorflow:training graph for tree: 9998 
INFO:tensorflow:training graph for tree: 9999 
INFO:tensorflow:Create CheckpointSaverHook. 
2017-07-26 10:25:30.908894: F tensorflow/contrib/tensor_forest/kernels/count_extremely_random_stats_op.cc:404] 
Check failed: column < num_classes_ (39001 vs. 8) 

Process finished with exit code 134 (interrupted by signal 6: SIGABRT) 

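I'm not sure what this error means. It doesn't really make sense: num_classes=7, not 8, and since features and labels both have shape (484876, 7), I don't know where the 39001 comes from.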

Here is the code to reproduce it:

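import os
import numpy as np
import pandas as pd
import tensorflow as tf  # TensorFlow 1.x; tf.contrib is required
from tensorflow.contrib.tensor_forest.client.random_forest import (
    TensorForestEstimator, TensorForestLossHook)
from tensorflow.contrib.tensor_forest.python.tensor_forest import (
    ForestHParams, RandomForestGraphs)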

def get_training_data(): 
    training_file = "data.txt" 
    data = pd.read_csv(training_file, sep='\t') 

    X = np.array(data.drop('Result', axis=1), dtype=np.float32) 

    y = [] 
    for e in data.ResultStr:
        y.append(list(np.array(str(e).replace('[', '').replace(']', '').split(','))))

    y = np.array(y, dtype=np.float32) 

    features = tf.constant(X) 
    labels = tf.constant(y) 

    return features, labels 

hyperparameters = ForestHParams(
    num_trees=100, 
    max_nodes=10000, 
    bagging_fraction=1.0, 
    num_splits_to_consider=0, 
    feature_bagging_fraction=1.0, 
    max_fertile_nodes=0, 
    split_after_samples=250, 
    min_split_samples=5, 
    valid_leaf_threshold=1, 
    dominate_method='bootstrap', 
    dominate_fraction=0.99, 
    # All parameters above are default 
    num_classes=7, 
    num_features=7 
) 

estimator = TensorForestEstimator(
    params=hyperparameters, 
    # All parameters below are default 
    device_assigner=None, 
    model_dir=None, 
    graph_builder_class=RandomForestGraphs, 
    config=None, 
    weights_name=None, 
    keys_name=None, 
    feature_engineering_fn=None, 
    early_stopping_rounds=100, 
    num_trainers=1, 
    trainer_id=0, 
    report_feature_importances=False, 
    local_eval=False 
) 

estimator.fit(
    input_fn=lambda: get_training_data(), 
    max_steps=100, 
    monitors=[
        TensorForestLossHook(
            early_stopping_rounds=30
        )
    ]
) 

It also fails if I wrap it in SKCompat; the same error occurs. What is the cause of this crash?

Can you provide sample input data? – denfromufa

I edited the question to provide input data. – jshapy8

Answer


regression=True needs to be specified in ForestHParams, because by default TensorForestEstimator assumes it is solving a classification problem, which can only output a single value.
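As a minimal sketch of that fix, assuming TensorFlow 1.x with tf.contrib available and the same import paths used by the contrib tensor_forest package: the only change from the question's hyperparameters is the added regression=True. The names FOREST_KWARGS and build_regression_estimator are hypothetical helpers introduced here for illustration; the remaining hyperparameters are left at their defaults.

```python
# Hyperparameters from the question, plus the one change the answer calls for.
FOREST_KWARGS = dict(
    num_trees=100,
    max_nodes=10000,
    num_classes=7,    # in regression mode this is the number of output values
    num_features=7,
    regression=True,  # the fix: treat the 7 label columns as regression targets
)


def build_regression_estimator(model_dir=None):
    """Build the estimator; TensorFlow is imported lazily (requires TF 1.x contrib)."""
    from tensorflow.contrib.tensor_forest.client.random_forest import (
        TensorForestEstimator)
    from tensorflow.contrib.tensor_forest.python.tensor_forest import (
        ForestHParams)

    params = ForestHParams(**FOREST_KWARGS)
    return TensorForestEstimator(params=params, model_dir=model_dir)
```

The estimator returned by build_regression_estimator() can then be passed the same input_fn as in the question, for example estimator.fit(input_fn=get_training_data, max_steps=100).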

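An implicit num_outputs variable is created when the estimator is initialized, and it is set to 1 if regression is not specified. If regression is specified, then num_outputs = num_classes and checkpoints are saved normally.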
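The numbers in the failed check are consistent with this explanation. The following is only an illustrative pure-Python sketch, not the TensorFlow kernel: it assumes that in classification mode each label is cast to an integer class index and shifted by one column, which is inferred from 8 = 7 + 1 and 39001 = 39000 + 1 in the log. The function name first_out_of_range_column is hypothetical.

```python
def first_out_of_range_column(label_rows, num_classes):
    """Return (column, num_columns) for the first label that fails the range check.

    Assumption: classification mode keeps num_classes + 1 columns per node and
    maps a label value v to column int(v) + 1.
    """
    num_columns = num_classes + 1
    for row in label_rows:
        for value in row:
            column = int(value) + 1
            if column >= num_columns:
                return column, num_columns
    return None, num_columns


# First label row from the question's sample data.
labels = [[39000.0, 39959.0, 42099.0, 46153.0, 49969.0, 54127.0, 55911.0]]
column, num_columns = first_out_of_range_column(labels, num_classes=7)
print("Check failed: column < num_classes_ (%d vs. %d)" % (column, num_columns))
```

Under these assumptions the first label, 39000.0, is read as class index 39000, far outside the 7 declared classes, which would explain why the crash goes away once the labels are treated as regression targets.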