提交 872b1c88 编写于 作者: Y yangyaming

Stop gradient when pool_type=='max'

上级 25af35d8
...@@ -759,6 +759,11 @@ def sequence_pool(input, pool_type, **kwargs): ...@@ -759,6 +759,11 @@ def sequence_pool(input, pool_type, **kwargs):
"MaxIndex": max_index}, "MaxIndex": max_index},
attrs={"pooltype": pool_type.upper()}) attrs={"pooltype": pool_type.upper()})
# when pool_type is max, variable max_index is initialized,
# so we stop the gradient explicitly here
if pool_type == 'max':
max_index.stop_gradient = True
return pool_out return pool_out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册