From 79c6a01ca3a3b56523151508bcb3e652bb6a0125 Mon Sep 17 00:00:00 2001 From: guosheng Date: Sun, 9 Feb 2020 15:16:19 +0800 Subject: [PATCH] Fix dropout in dygraph Transformer. --- dygraph/transformer/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dygraph/transformer/model.py b/dygraph/transformer/model.py index b4ae428e..0187438d 100644 --- a/dygraph/transformer/model.py +++ b/dygraph/transformer/model.py @@ -88,9 +88,9 @@ class PrePostProcessLayer(Layer): bias_attr=fluid.ParamAttr( initializer=fluid.initializer.Constant(0.))))) elif cmd == "d": # add dropout - if dropout_rate: - self.functors.append(lambda x: layers.dropout( - x, dropout_prob=dropout_rate, is_test=False)) + self.functors.append(lambda x: layers.dropout( + x, dropout_prob=dropout_rate, is_test=False) + if dropout_rate else x) def forward(self, x, residual=None): for i, cmd in enumerate(self.process_cmd): -- GitLab