From a2b0357d79cfb2bbea832ef26f6821491b6c7484 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Thu, 7 Jan 2021 14:29:45 +0800 Subject: [PATCH] pre padding in dygraph (#30179) Change-Id: Ia5279b0cbb6a5b3970aff66e9510e0d85efa70ce --- .../tests/unittests/test_nn_functional_embedding_dygraph.py | 4 +--- python/paddle/nn/layer/common.py | 4 ++++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_nn_functional_embedding_dygraph.py b/python/paddle/fluid/tests/unittests/test_nn_functional_embedding_dygraph.py index 43a0d481b2..acff7daade 100644 --- a/python/paddle/fluid/tests/unittests/test_nn_functional_embedding_dygraph.py +++ b/python/paddle/fluid/tests/unittests/test_nn_functional_embedding_dygraph.py @@ -26,12 +26,10 @@ paddle.disable_static() class EmbeddingDygraph(unittest.TestCase): def test_1(self): x_data = np.arange(3, 6).reshape((3, 1)).astype(np.int64) - y_data = np.arange(6, 12).reshape((3, 2)).astype(np.float32) paddle.disable_static(paddle.CPUPlace()) x = paddle.to_tensor(x_data, stop_gradient=False) - y = paddle.to_tensor(y_data, stop_gradient=False) - embedding = paddle.nn.Embedding(10, 3, sparse=True) + embedding = paddle.nn.Embedding(10, 3, sparse=True, padding_idx=9) w0 = np.full(shape=(10, 3), fill_value=2).astype(np.float32) embedding.weight.set_value(w0) diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py index 5ae6e3ed77..9675524f93 100644 --- a/python/paddle/nn/layer/common.py +++ b/python/paddle/nn/layer/common.py @@ -16,6 +16,7 @@ import paddle from ...fluid.dygraph import Flatten #DEFINE_ALIAS from ...fluid.dygraph import layers +from ...fluid.framework import in_dygraph_mode from .. import functional as F from ...fluid.framework import _dygraph_tracer @@ -1277,6 +1278,9 @@ class Embedding(layers.Layer): dtype=self._dtype, is_bias=False) + if in_dygraph_mode() and padding_idx != -1: + self.weight[padding_idx] = 0.0 + def forward(self, x): return F.embedding( x, -- GitLab