2016-11-28 76 views
0

我有一個函數diceTensorflow - 應用功能在1D張量

def dice(yPred,yTruth,thresh): 
    smooth = tf.constant(1.0) 
    threshold = tf.constant(thresh) 
    yPredThresh = tf.to_float(tf.greater_equal(yPred,threshold)) 
    mul = tf.mul(yPredThresh,yTruth) 
    intersection = 2*tf.reduce_sum(mul) + smooth 
    union = tf.reduce_sum(yPredThresh) + tf.reduce_sum(yTruth) + smooth 
    dice = intersection/union 
    return dice, yPredThresh 

其中工程。一個例子是這裏

with tf.Session() as sess: 

    thresh = 0.5 
    print("Dice example") 
    yPred = tf.constant([0.1,0.9,0.7,0.3,0.1,0.1,0.9,0.9,0.1],shape=[3,3]) 
    yTruth = tf.constant([0.0,1.0,1.0,0.0,0.0,0.0,1.0,1.0,1.0],shape=[3,3]) 
    diceScore, yPredThresh= dice(yPred=yPred,yTruth=yTruth,thresh= thresh) 

    diceScore_ , yPredThresh_ , yPred_, yTruth_ = sess.run([diceScore,yPredThresh,yPred, yTruth]) 
    print("\nScore = {0}".format(diceScore_)) 

>>> Score = 0.899999976158

我希望能夠遍歷骰子的第三arguement給出脫粒。我不知道這樣做的最佳方式,以便我可以從圖中提取它。因爲我不能環路成,顯然分割一個TF張量大致如下的東西線...

def diceROC(yPred,yTruth,thresholds=np.linspace(0.1,0.9,20)): 
    thresholds = thresholds.astype(np.float32) 
    nThreshs = thresholds.size 
    diceScores = tf.zeros(shape=nThreshs) 

    for i in xrange(nThreshs): 
     score,_ = dice(yPred,yTruth,thresholds[i]) 
     diceScores[i] = score 
    return diceScores 

評估diceScoreROC產生錯誤'Tensor' object does not support item assignment

回答

1

而不是循環,我會鼓勵你使用tensorflow的廣播能力。如果重新定義dice到:

def dice(yPred,yTruth,thresh): 
    smooth = tf.constant(1.0) 
    yPredThresh = tf.to_float(tf.greater_equal(yPred,thresh)) 
    mul = tf.mul(yPredThresh,yTruth) 
    intersection = 2*tf.reduce_sum(mul, [0, 1]) + smooth 
    union = tf.reduce_sum(yPredThresh, [0, 1]) + tf.reduce_sum(yTruth, [0, 1]) + smooth 
    dice = intersection/union 
    return dice, yPredThresh 

您可以通過3維yPredyTruth(假設張量將只是沿着最後一維重複)和1維thresh

with tf.Session() as sess: 

    thresh = [0.1,0.9,20, 0.5] 
    print("Dice example") 
    yPred = tf.constant([0.1,0.9,0.7,0.3,0.1,0.1,0.9,0.9,0.1],shape=[3,3,1]) 
    ypred_tiled = tf.tile(yPred, [1,1,4]) 
    yTruth = tf.constant([0.0,1.0,1.0,0.0,0.0,0.0,1.0,1.0,1.0],shape=[3,3,1]) 
    ytruth_tiled = tf.tile(yTruth, [1,1,4]) 
    diceScore, yPredThresh= dice(yPred=ypred_tiled,yTruth=ytruth_tiled,thresh= thresh) 

    diceScore_ = sess.run(diceScore) 
    print("\nScore = {0}".format(diceScore_)) 

你會得到:

Score = [ 0.73333335 0.77777779 0.16666667 0.89999998] 
+0

真棒,很好的答案謝謝 – mattdns