2012-05-21 23 views
7

我有點尷尬地承認這一點,但我似乎很難被一個簡單的編程問題困住。我正在構建一個決策樹實現,並一直使用遞歸來獲取標記樣本的列表,遞歸地將列表分成兩半,然後將它變成一棵樹。使用while循環+堆棧編碼遞歸樹創建

不幸的是,對於深度樹,我遇到了堆棧溢出錯誤(ha!),所以我的第一個想法是使用continuation將其變爲尾遞歸。不幸的是,Scala不支持這種TCO,所以唯一的解決方案是使用蹦牀。蹦牀看起來效率不高,我希望有一些簡單的基於堆棧的命令式解決方案來解決這個問題,但是我很難找到它。

遞歸版本看起來有點像(簡體):

private def trainTree(samples: Seq[Sample], usedFeatures: Set[Int]): DTree = { 
    if (shouldStop(samples)) { 
    DTLeaf(makeProportions(samples)) 
    } else { 
    val featureIdx = getSplittingFeature(samples, usedFeatures) 
    val (statsWithFeature, statsWithoutFeature) = samples.partition(hasFeature(featureIdx, _)) 
    DTBranch(
     trainTree(statsWithFeature, usedFeatures + featureIdx), 
     trainTree(statsWithoutFeature, usedFeatures + featureIdx), 
     featureIdx) 
    } 
} 

所以基本上我根據一些數據的功能遞歸細分名單一分爲二,並通過所使用的特徵列表,所以傳遞我不重複 - 這些都是在「getSplittingFeature」函數中處理的,所以我們可以忽略它。代碼非常簡單!不過,我很難找出一個基於堆棧的解決方案,它不僅僅使用閉包並且有效地成爲蹦牀。我知道我們至少必須在堆棧中保留很少的參數「框架」,但我想避免關閉調用。

我得到的是,我應該明確地寫出了callstack和程序計數器在遞歸解決方案中隱含地處理了什麼,但是如果沒有延續,我會遇到麻煩。在這一點上,它甚至不是效率,我只是好奇。所以,請不要提醒我,過早優化是萬惡之源,而基於蹦牀的解決方案可能會工作得很好。我知道它可能會 - 它本身就是一個難題。

任何人都可以告訴我什麼規範的while-loop-and-stack-based解決方案是這樣的嗎?

更新:基於Thipor Kong出色的解決方案,我編寫了一個基於while-loops/stacks/hashtable的算法實現,該算法應該是遞歸版本的直接轉換。這正是我正在尋找的:

最終更新:我已經使用了順序整數索引,以及將所有內容都放回到數組中,而不是用於性能的映射,增加了maxDepth支持,並最終獲得了相同的解決方案性能遞歸版本(不知道內存使用情況,但我猜以下):

private def trainTreeNoMaxDepth(startingSamples: Seq[Sample], startingMaxDepth: Int): DTree = { 
    // Use arraybuffer as dense mutable int-indexed map - no IndexOutOfBoundsException, just expand to fit 
    type DenseIntMap[T] = ArrayBuffer[T] 
    def updateIntMap[@specialized T](ab: DenseIntMap[T], idx: Int, item: T, dfault: T = null.asInstanceOf[T]) = { 
    if (ab.length <= idx) {ab.insertAll(ab.length, Iterable.fill(idx - ab.length + 1)(dfault)) } 
    ab.update(idx, item) 
    } 
    var currentChildId = 0 // get childIdx or create one if it's not there already 
    def child(childMap: DenseIntMap[Int], heapIdx: Int) = 
    if (childMap.length > heapIdx && childMap(heapIdx) != -1) childMap(heapIdx) 
    else {currentChildId += 1; updateIntMap(childMap, heapIdx, currentChildId, -1); currentChildId } 
    // go down 
    val leftChildren, rightChildren = new DenseIntMap[Int]() // heapIdx -> childHeapIdx 
    val todo = Stack((startingSamples, Set.empty[Int], startingMaxDepth, 0)) // samples, usedFeatures, maxDepth, heapIdx 
    val branches = new Stack[(Int, Int)]() // heapIdx, featureIdx 
    val nodes = new DenseIntMap[DTree]() // heapIdx -> node 
    while (!todo.isEmpty) { 
    val (samples, usedFeatures, maxDepth, heapIdx) = todo.pop() 
    if (shouldStop(samples) || maxDepth == 0) { 
     updateIntMap(nodes, heapIdx, DTLeaf(makeProportions(samples))) 
    } else { 
     val featureIdx = getSplittingFeature(samples, usedFeatures) 
     val (statsWithFeature, statsWithoutFeature) = samples.partition(hasFeature(featureIdx, _)) 
     todo.push((statsWithFeature, usedFeatures + featureIdx, maxDepth - 1, child(leftChildren, heapIdx))) 
     todo.push((statsWithoutFeature, usedFeatures + featureIdx, maxDepth - 1, child(rightChildren, heapIdx))) 
     branches.push((heapIdx, featureIdx)) 
    } 
    } 
    // go up 
    while (!branches.isEmpty) { 
    val (heapIdx, featureIdx) = branches.pop() 
    updateIntMap(nodes, heapIdx, DTBranch(nodes(child(leftChildren, heapIdx)), nodes(child(rightChildren, heapIdx)), featureIdx)) 
    } 
    nodes(0) 
} 
+0

不是卸載到基於堆棧的實現(堆棧在堆上)在概念上與蹦牀相同嗎? – ron

+0

排序,但蹦牀意味着你正在保持堆棧充滿關閉,我希望有一個解決方案,只使用堆棧的數據。也許標記爲StepOne(a,b,c),StepTwo(a,b,c)或多個堆棧或類似的數據,但不涉及函數調用。 – lvilnis

+0

對我的代碼進行了另一項更改。節點id的名稱空間更經濟,您可以插入自己的節點id類型(或BigInt,如果您喜歡)。 –

回答

3

二叉樹只是存儲在陣列中,作爲Wikipedia描述:對於節點i,左子進入2*i+1和正確的孩子在2*i+2。當「下」時,你會收集一些待辦事項,但仍然需要分解才能到達一片葉子。一旦你只有葉子,再往上走(右陣列中左)建立決策節點:

更新:一個清理版本,還支持存儲INT分支的特徵(類型參數B),這是更多的功能/完全純,並支持與ron建議的地圖稀疏樹。

Update2-3:經濟地使用節點標識符的名稱空間和對ID類型的抽象以允許大樹。從Stream中獲取節點ID。

sealed trait DTree[A, B] 
case class DTLeaf[A, B](a: A, b: B) extends DTree[A, B] 
case class DTBranch[A, B](left: DTree[A, B], right: DTree[A, B], b: B) extends DTree[A, B] 

def mktree[A, B, Id](a: A, b: B, split: (A, B) => Option[(A, A, B)], ids: Stream[Id]) = { 
    @tailrec 
    def goDown(todo: Seq[(A, B, Id)], branches: Seq[(Id, B, Id, Id)], leafs: Map[Id, DTree[A, B]], ids: Stream[Id]): (Seq[(Id, B, Id, Id)], Map[Id, DTree[A, B]]) = 
    todo match { 
     case Nil => (branches, leafs) 
     case (a, b, id) :: rest => 
     split(a, b) match { 
      case None => 
      goDown(rest, branches, leafs + (id -> DTLeaf(a, b)), ids) 
      case Some((left, right, b2)) => 
      val leftId #:: rightId #:: idRest = ids 
      goDown((right, b2, rightId) +: (left, b2, leftId) +: rest, (id, b2, leftId, rightId) +: branches, leafs, idRest) 
     } 
    } 

    @tailrec 
    def goUp[A, B](branches: Seq[(Id, B, Id, Id)], nodes: Map[Id, DTree[A, B]]): Map[Id, DTree[A, B]] = 
    branches match { 
     case Nil => nodes 
     case (id, b, leftId, rightId) :: rest => 
     goUp(rest, nodes + (id -> DTBranch(nodes(leftId), nodes(rightId), b))) 
    } 

    val rootId #:: restIds = ids 
    val (branches, leafs) = goDown(Seq((a, b, rootId)), Seq(), Map(), restIds) 
    goUp(branches, leafs)(rootId) 
} 

// try it out 

def split(xs: Seq[Int], b: Int) = 
    if (xs.size > 1) { 
    val (left, right) = xs.splitAt(xs.size/2) 
    Some((left, right, b + 1)) 
    } else { 
    None 
    } 

val tree = mktree(0 to 1000, 0, split _, Stream.from(0)) 
println(tree) 
+0

每個DTBranch需要「featureIndex」這個事實呢?這使得它變得更加棘手,因爲將所有的葉子轉換爲分支,我們需要他們的featureIndex,然後將這些分支組合在一起,我們需要他們的featureIndexes,等等。我認爲這是正確的想法,但我會玩弄它。 – lvilnis

+0

當下降時(而不是無),您可以將featureIndices放入堆中,以便在再次向上時可用於創建DTBranch。 –

+0

太棒了!我會嘗試一下,並在一小時內將你的答案標記爲答案。 – lvilnis