Commit 79c6a01c authored by guosheng

Fix dropout in dygraph Transformer.

Parent 1f5d2987
@@ -88,9 +88,9 @@ class PrePostProcessLayer(Layer):
                         bias_attr=fluid.ParamAttr(
                             initializer=fluid.initializer.Constant(0.)))))
             elif cmd == "d":  # add dropout
-                if dropout_rate:
-                    self.functors.append(lambda x: layers.dropout(
-                        x, dropout_prob=dropout_rate, is_test=False))
+                self.functors.append(lambda x: layers.dropout(
+                    x, dropout_prob=dropout_rate, is_test=False)
+                                     if dropout_rate else x)

     def forward(self, x, residual=None):
         for i, cmd in enumerate(self.process_cmd):
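The change moves the `dropout_rate` check from the append site into the functor itself: a functor is now appended for every "d" command, and it degenerates to the identity when the rate is zero. A minimal sketch of this pattern, with hypothetical stand-ins (`build_functors`, `dropout_fn`) in place of the Paddle layer, so the command string and the functor list stay index-aligned:

```python
def build_functors(process_cmd, dropout_rate, dropout_fn):
    """One functor per command in process_cmd (hypothetical helper).

    For 'd', always append a functor, but make it the identity when
    dropout_rate is falsy -- mirroring the fix, instead of skipping
    the append and misaligning the two lists.
    """
    functors = []
    for cmd in process_cmd:
        if cmd == "d":  # add dropout
            functors.append(
                lambda x: dropout_fn(x, dropout_rate) if dropout_rate else x)
        else:
            functors.append(lambda x: x)  # placeholder for other commands

    return functors


# With rate 0, the 'd' functor still exists and is a no-op.
fns = build_functors("nd", 0.0, dropout_fn=lambda x, p: None)
assert len(fns) == 2
assert fns[1](7) == 7
```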