2017-03-22 150 views
1

我有許多樹的xgboost.dump文本文件。 我想查找所有路徑以獲取每個路徑的值。 這是一棵樹。找到來自xgboost.dump的二叉樹的所有路徑

tree[0]: 
0:[a<0.966398] yes=1,no=2,missing=1 
    1:[b<0.323071] yes=3,no=4,missing=3 
     3:[c<0.461248] yes=7,no=8,missing=7 
      7:leaf=0.00972768 
      8:leaf=-0.0179376 
     4:[a<0.379082] yes=9,no=10,missing=9 
      9:leaf=0.0146003 
      10:leaf=0.0454369 
    2:[b<0.322352] yes=5,no=6,missing=5 
     5:[c<0.674868] yes=11,no=12,missing=11 
      11:leaf=0.0497964 
      12:leaf=0.00953781 
     6:[f<0.598267] yes=13,no=14,missing=13 
      13:leaf=0.0504545 
      14:leaf=0.0867654 

我想所有的路徑轉換成

path1, a<0.966398, b<0.323071, c<0.461248, leaf = 0.00097268 
path2, a<0.966398, b<0.323071, c>0.461248, leaf = -0.0179376 
path3, a<0.966398, b>0.323071, a<0.379082, leaf = 0.0146003 
path4, a<0.966398, b>0.323071, a>0.379082, leaf = 0.0454369 
path5, a>0.966398, b<0.322352, c<0.674868, leaf = 0.0497964 
path6, a>0.966398, b<0.322352, c>0.674868, leaf = 0.00953781 
path7, a>0.966398, b>0.322352, f<0.598267, leaf = 0.0504545 
path8, a>0.966398, b>0.322352, f>0.598267, leaf = 0.0864654 

我已經嘗試列出像

array([[ 0, 1, 3, 7], 
     [ 0, 1, 3, 8], 
     [ 0, 1, 4, 9], 
     [ 0, 1, 4, 10], 
     [ 0, 2, 5, 11], 
     [ 0, 2, 5, 12], 
     [ 0, 2, 6, 13], 
     [ 0, 2, 6, 14]]) 

所有可能的路徑,但一旦MAX_DEPTH較高這樣會導致錯誤,一些分支將停止增長,路徑將錯誤。 所以我需要解析文本文件中的yes,no來生成真實的,正確的路徑。 有什麼建議嗎? 謝謝!

回答

0

下面是我用R實現來解決這個問題的方法。其他語言的用戶可以遵循邏輯和實物複製。

首先,我從由xgb.model.dt.tree()生成的模型轉儲文件開始。

然後,我寫了一個函數來解析從任意節點到轉儲模型的單個樹內的最終父節點的有效路徑。

後來,我使用purrr :: by_row()將該函數應用於模型轉儲的所有終端節點「葉」記錄,並將結果轉換用於目的。

該函數有兩個參數,一個用於正在測試的樹,另一個用於終端節點的標識。它遵循以下一般步驟:

  1. 與每個樹爲單位的目標(終端)節點開始,查找具有目標節點爲有效孩子的C中(行「是」,「否」 ,「失蹤」)決定分裂。
  2. 將此有效父節點ID連接到一個向量中,該向量將用於跟蹤從目標節點到最終父節點的路徑的每個步驟。該向量在函數完成時返回。
  3. 接下來,爲鏈上的每個節點重複「誰是我的父母」步驟,直到路徑碰到最終父母爲止(此節點ID始終以「-0」結尾),同時更新每個新步驟的路徑向量連鎖,鏈條。
  4. 一旦函數命中終端節點,返回()該路徑。

在我的情況下,我將這個函數應用到模型轉儲中的所有「Leaf」節點上,使用purrr :: by_row()while .collat​​ing =「rows」來表示通道作爲輸出中的附加行。

這也很可能不是最快的方式。

xgb.booster模型中nrounds或max_depth的增加將導致此過程的運行時間增加。您可以使用樹的子集(xgb.model.dt.tree()的參數n_first_tree = N)來開發您的方法,以便您可以估算解析整個最終模型中的終端節點路徑所需的時間。在我的情況下,在max_depth = 5時有500棵樹的模型可能需要30分鐘以上。