未验证 提交 663eca45 编写于 作者: G Guo Sheng 提交者: GitHub

Fix dygraph dropout seed. test=develop (#24177)

上级 9b851ba2
...@@ -983,12 +983,11 @@ def dropout(x, ...@@ -983,12 +983,11 @@ def dropout(x,
if (seed is None or if (seed is None or
seed == 0) and default_main_program().random_seed != 0: seed == 0) and default_main_program().random_seed != 0:
seed = default_main_program().random_seed seed = default_main_program().random_seed
seed = seed if seed is not None else 0
_is_test = not _dygraph_tracer()._train_mode _is_test = not _dygraph_tracer()._train_mode
out, mask = core.ops.dropout(x, 'dropout_prob', dropout_prob, 'is_test', out, mask = core.ops.dropout(
_is_test, 'fix_seed', seed is not None, x, 'dropout_prob', dropout_prob, 'is_test', _is_test, 'fix_seed',
'seed', seed, 'dropout_implementation', seed is not None, 'seed', seed if seed is not None else 0,
dropout_implementation) 'dropout_implementation', dropout_implementation)
return out return out
helper = LayerHelper('dropout', **locals()) helper = LayerHelper('dropout', **locals())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册