有人可以在TensorFlow中解釋我的gradient_override_map
函數嗎? 我無法準確理解它的用法。Tensorflow的gradient_override_map函數
我看到的代碼用法:
with G.gradient_override_map({"Floor": "Identity"}):
return tf.reduce_mean(SomeVals) * SomeOtherVal
正是這裏發生了什麼?什麼是Identity
?
有人可以在TensorFlow中解釋我的gradient_override_map
函數嗎? 我無法準確理解它的用法。Tensorflow的gradient_override_map函數
我看到的代碼用法:
with G.gradient_override_map({"Floor": "Identity"}):
return tf.reduce_mean(SomeVals) * SomeOtherVal
正是這裏發生了什麼?什麼是Identity
?
盡我所知,gradient_override_map允許你說「在這種情況下,任何時候你會使用X的漸變,而不是使用Y的漸變」。這意味着您仍然需要Y的梯度作爲您要使用的漸變。
這是我見過的漂浮在尋找如何工作的一個例子:
@tf.RegisterGradient("CustomGrad")
def _const_mul_grad(unused_op, grad):
return 5.0 * grad
g = tf.get_default_graph()
with g.gradient_override_map({"Identity": "CustomGrad"}):
output = tf.identity(input, name="Identity")
舉:https://stackoverflow.com/a/43948872/1102705
RegisterGradient()
允許你註冊你定義一個新的運算的梯度,從而允許你有一個你想要的梯度的運算符,然後你可以在梯度覆蓋映射中使用該運算符。它有點笨重 - 你正在定義一個沒有前鋒的傳球。
我不明白的是名稱=「身份」是否真的有必要。
兩個「樓」和「身份」是操作的類型的字符串,前者對應於tf.floor而後者tf.identity。 所以我猜你的代碼的功能是用圖G中的tf.floor運算代替tf.identity的BPD計算機制的後向傳播梯度(BPG)計算機制,同時通過前向輸出tf.reduce_mean。看起來有點奇怪,因爲在我找到的gradient_override_map
的所有應用程序中,op_type_map的關鍵字始終與用於在上下文中生成輸出的操作的類型字符串相同。通過這個我的意思是我更熟悉與tf.floor(SomeVals)
返回的場景,而不是tf.reduce_mean(SomeVals)
。
gradient_override_map({op_A_type: op_B_type})
做的是用op_B替換op_A的BPG計算機制,同時保留op_A_type的正向傳播計算機制。 lahwran的答案中顯示了gradient_override_map的常見應用。
@tf.RegisterGradient("CustomGrad")
def _const_mul_grad(unused_op, grad):
return 5.0 * grad
g = tf.get_default_graph()
with g.gradient_override_map({"Identity": "CustomGrad"}):
output = tf.identity(input, name="Identity")
通過
@tf.RegisterGradient("CustomGrad")
def _const_mul_grad(unused_op, grad):
return 5.0 * grad
裝飾,tf.RegisterGradient("CustomGrad")
登記由_const_mul_grad(unused_op, grad)
用於定製的運算式定義的梯度函數 - 「CustomGrad」,
而
g = tf.get_default_graph()
with g.gradient_override_map({"Identity": "CustomGrad"}):
output = tf.identity(input, name="Identity")
保證所有操作的輸出(圖g),字符串類型爲「Id」實體「(tf.identity)與它們相同,而BPG計算機制tf。身份用字符串類型「CustomGrad」替換爲BPG計算操作機制。
P.S.
運算的類型字符串對應於OpDef.name
字段定義該操作的原。爲了找到一個運算的OpDef.name
,請參照明星的回答下this question
這是沒有必要的,因爲在tf.identity的ARG「名稱」是可選操作申報tf.identity的名稱。