提交 79c6a01c 编写于 作者: G guosheng

Fix dropout in dygraph Transformer.

上级 1f5d2987
...@@ -88,9 +88,9 @@ class PrePostProcessLayer(Layer): ...@@ -88,9 +88,9 @@ class PrePostProcessLayer(Layer):
bias_attr=fluid.ParamAttr( bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(0.))))) initializer=fluid.initializer.Constant(0.)))))
elif cmd == "d": # add dropout elif cmd == "d": # add dropout
if dropout_rate: self.functors.append(lambda x: layers.dropout(
self.functors.append(lambda x: layers.dropout( x, dropout_prob=dropout_rate, is_test=False)
x, dropout_prob=dropout_rate, is_test=False)) if dropout_rate else x)
def forward(self, x, residual=None): def forward(self, x, residual=None):
for i, cmd in enumerate(self.process_cmd): for i, cmd in enumerate(self.process_cmd):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册