我已經創建了一個PR here並且它可以幫助你處理簡單的案件
讓我簡單介紹一下我的實現,所以你可以,如果你需要編寫自己的版本。主要的部分是_time_step
功能的修改:
def _time_step(time, output_ta_t, state, *args):
的參數保持不變,除了額外的*args
傳入但是爲什麼args
?這是因爲我想支持張量流傳統的習慣行爲。您可以只通過簡單地忽略args
參數返回最終狀態:
if states_ta is not None:
# If you want to return all states, set `args` to be `states_ta`
loop_vars = (time, output_ta, state, states_ta)
else:
# If you want the final state only, ignore `args`
loop_vars = (time, output_ta, state)
如何使用它?
if args:
args = tuple(
ta.write(time, out) for ta, out in zip(args[0], [new_state])
)
其實這只是一個以下(原件)代碼修改:
output_ta_t = tuple(
ta.write(time, out) for ta, out in zip(output_ta_t, output)
)
現在args
應該包含所有你想要的狀態。
之後所有的作品上面做,你可以用以下代碼拿起狀態(或最終狀態):
_, output_final_ta, *state_info = control_flow_ops.while_loop(...
和
if states_ta is not None:
final_state, states_final_ta = state_info
else:
final_state, states_final_ta = state_info[0], None
雖然我還沒有測試它它應該在'簡單'條件下工作(here's我的測試用例)
我創建了一個PR [這裏](https://github.com/tensorflow/tensorflow/pull/9995),它可能會幫助你處理簡單的案例 – Carefree0910