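TensorFlow crashes when fitting TensorForestEstimator

I'm trying to use numerical floating-point data representing 7 features and 7 labels in a TensorForestEstimator model. That is, features and labels both have shape (484876, 7). I set num_classes=7 and num_features=7 in ForestHParams accordingly. The data is formatted as follows: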
f1 f2 f3 f4 f5 f6 f7 l1 l2 l3 l4 l5 l6 l7
39000.0 120.0 65.0 1000.0 25.0 0.69 3.94 39000.0 39959.0 42099.0 46153.0 49969.0 54127.0 55911.0
32000.0 185.0 65.0 1000.0 75.0 0.46 2.19 32000.0 37813.0 43074.0 48528.0 54273.0 60885.0 63810.0
30000.0 185.0 65.0 1000.0 25.0 0.41 1.80 30000.0 32481.0 35409.0 39145.0 42750.0 46678.0 48595.0
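One quick way to compare these rows with what a classifier expects is the small sketch below. It uses only NumPy and the three sample rows above. The helper `looks_like_class_ids` is hypothetical. It rests on an assumption that the question does not confirm: in classification mode, each label must be an integer class id in the range [0, num_classes).

```python
import io

import numpy as np

# The three sample rows shown above (header f1..f7, l1..l7).
SAMPLE = """f1 f2 f3 f4 f5 f6 f7 l1 l2 l3 l4 l5 l6 l7
39000.0 120.0 65.0 1000.0 25.0 0.69 3.94 39000.0 39959.0 42099.0 46153.0 49969.0 54127.0 55911.0
32000.0 185.0 65.0 1000.0 75.0 0.46 2.19 32000.0 37813.0 43074.0 48528.0 54273.0 60885.0 63810.0
30000.0 185.0 65.0 1000.0 25.0 0.41 1.80 30000.0 32481.0 35409.0 39145.0 42750.0 46678.0 48595.0
"""

# Skip the header line; the first 7 columns are features, the last 7 are labels.
data = np.loadtxt(io.StringIO(SAMPLE), skiprows=1, dtype=np.float32)
features, labels = data[:, :7], data[:, 7:]


def looks_like_class_ids(y, num_classes):
    """Hypothetical check: True if every label is an integer in [0, num_classes)."""
    return bool(np.all(y == np.round(y)) and y.min() >= 0 and y.max() < num_classes)


print(features.shape, labels.shape)     # (3, 7) (3, 7)
print(looks_like_class_ids(labels, 7))  # False
```

The shapes match the question's description. However, under the stated assumption, label values such as 39000.0 are far outside the range [0, 7).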
When fit() is called, Python crashes with the following message:
Python quit unexpectedly while using the _pywrap_tensorflow_internal.so plug-in.
Here is the output with tf.logging.set_verbosity('INFO'):
INFO:tensorflow:training graph for tree: 0
INFO:tensorflow:training graph for tree: 1
...
INFO:tensorflow:training graph for tree: 9998
INFO:tensorflow:training graph for tree: 9999
INFO:tensorflow:Create CheckpointSaverHook.
2017-07-26 10:25:30.908894: F tensorflow/contrib/tensor_forest/kernels/count_extremely_random_stats_op.cc:404]
Check failed: column < num_classes_ (39001 vs. 8)
Process finished with exit code 134 (interrupted by signal 6: SIGABRT)
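I don't know what this error means, and it doesn't really make sense: num_classes=7, not 8, and the features and labels both have shape (484876, 7). I also don't know where 39001 comes from.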
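The arithmetic below is one hypothesis, not confirmed anywhere in the question, that reproduces both numbers in the failed check. Suppose that in classification mode a label value v is written to column int(v) + 1, and that the stats tensor has num_classes + 1 columns, with column 0 holding a total count. Then the first label of the first sample row gives exactly "39001 vs. 8". Both mappings are assumptions about the op's internals.

```python
import numpy as np

# First sample row of labels (l1..l7) from the data above.
labels_row = np.array([39000.0, 39959.0, 42099.0, 46153.0,
                       49969.0, 54127.0, 55911.0], dtype=np.float32)
num_classes = 7

# ASSUMPTION: classification stats use num_classes + 1 columns
# (column 0 = total count), and label v goes into column int(v) + 1.
num_columns = num_classes + 1
column = int(labels_row[0]) + 1

print(column, num_columns)   # 39001 8
print(column < num_columns)  # False, matching the failed CHECK in the log
```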
Here is the code to reproduce the problem:
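import numpy as np
import pandas as pd
import os

import tensorflow as tf
# TensorFlow 1.x contrib module paths for the TensorForest classes used below.
from tensorflow.contrib.tensor_forest.client.random_forest import (
    TensorForestEstimator, TensorForestLossHook)
from tensorflow.contrib.tensor_forest.python.tensor_forest import (
    ForestHParams, RandomForestGraphs)
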
def get_training_data():
    training_file = "data.txt"
    data = pd.read_csv(training_file, sep='\t')
    X = np.array(data.drop('Result', axis=1), dtype=np.float32)
    y = []
    for e in data.ResultStr:
        y.append(list(np.array(str(e).replace('[', '').replace(']', '').split(','))))
    y = np.array(y, dtype=np.float32)
    features = tf.constant(X)
    labels = tf.constant(y)
    return features, labels

hyperparameters = ForestHParams(
    num_trees=100,
    max_nodes=10000,
    bagging_fraction=1.0,
    num_splits_to_consider=0,
    feature_bagging_fraction=1.0,
    max_fertile_nodes=0,
    split_after_samples=250,
    min_split_samples=5,
    valid_leaf_threshold=1,
    dominate_method='bootstrap',
    dominate_fraction=0.99,
    # All parameters above are default
    num_classes=7,
    num_features=7
)

estimator = TensorForestEstimator(
    params=hyperparameters,
    # All parameters below are default
    device_assigner=None,
    model_dir=None,
    graph_builder_class=RandomForestGraphs,
    config=None,
    weights_name=None,
    keys_name=None,
    feature_engineering_fn=None,
    early_stopping_rounds=100,
    num_trainers=1,
    trainer_id=0,
    report_feature_importances=False,
    local_eval=False
)

estimator.fit(
    input_fn=lambda: get_training_data(),
    max_steps=100,
    monitors=[
        TensorForestLossHook(
            early_stopping_rounds=30
        )
    ]
)
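The label-parsing loop in get_training_data() can also be written as a small standalone function. The sketch below shows one way to do it. `parse_label_string` is a hypothetical helper, and the `result_strs` values are assumed. They follow the bracketed, comma-separated format implied by the replace('[', '') and replace(']', '') calls, since the actual contents of data.txt are not shown.

```python
import numpy as np


def parse_label_string(s):
    """Convert a string like "[1.0, 2.0]" into a float32 NumPy array."""
    # Strip only the outer brackets, then convert each comma-separated piece.
    parts = str(s).strip().strip("[]").split(",")
    return np.array([float(p) for p in parts], dtype=np.float32)


# Hypothetical ResultStr cells in the assumed bracketed format.
result_strs = [
    "[39000.0, 39959.0, 42099.0, 46153.0, 49969.0, 54127.0, 55911.0]",
    "[32000.0, 37813.0, 43074.0, 48528.0, 54273.0, 60885.0, 63810.0]",
]

y = np.stack([parse_label_string(s) for s in result_strs])
print(y.shape, y.dtype)  # (2, 7) float32
```

Converting each piece with float() means a malformed cell raises a ValueError right away, rather than failing later during the array cast.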
It doesn't work if I wrap it in SKCompat either; the same error occurs. What is causing this crash?
Can you provide sample input data? – denfromufa
I edited the question to provide the input data. – jshapy8