未验证 提交 41307c73 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #7284 from pkuyym/fix-7211

Stop gradient when pool_type=='max'
...@@ -816,6 +816,11 @@ def sequence_pool(input, pool_type, **kwargs): ...@@ -816,6 +816,11 @@ def sequence_pool(input, pool_type, **kwargs):
"MaxIndex": max_index}, "MaxIndex": max_index},
attrs={"pooltype": pool_type.upper()}) attrs={"pooltype": pool_type.upper()})
# when pool_type is max, variable max_index is initialized,
# so we stop the gradient explicitly here
if pool_type == 'max':
max_index.stop_gradient = True
return pool_out return pool_out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册