From 4763e6bc4e59b78ac52d02e3b4f4b6fe80a2a91e Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Thu, 7 Jan 2021 10:00:14 +0800 Subject: [PATCH] pre padding in dygraph (#30163) Change-Id: Ia5279b0cbb6a5b3970aff66e9510e0d85efa70ce --- .../tests/unittests/test_nn_functional_embedding_dygraph.py | 4 +--- python/paddle/nn/layer/common.py | 4 ++++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_nn_functional_embedding_dygraph.py b/python/paddle/fluid/tests/unittests/test_nn_functional_embedding_dygraph.py index 43a0d481b28..acff7daadeb 100644 --- a/python/paddle/fluid/tests/unittests/test_nn_functional_embedding_dygraph.py +++ b/python/paddle/fluid/tests/unittests/test_nn_functional_embedding_dygraph.py @@ -26,12 +26,10 @@ paddle.disable_static() class EmbeddingDygraph(unittest.TestCase): def test_1(self): x_data = np.arange(3, 6).reshape((3, 1)).astype(np.int64) - y_data = np.arange(6, 12).reshape((3, 2)).astype(np.float32) paddle.disable_static(paddle.CPUPlace()) x = paddle.to_tensor(x_data, stop_gradient=False) - y = paddle.to_tensor(y_data, stop_gradient=False) - embedding = paddle.nn.Embedding(10, 3, sparse=True) + embedding = paddle.nn.Embedding(10, 3, sparse=True, padding_idx=9) w0 = np.full(shape=(10, 3), fill_value=2).astype(np.float32) embedding.weight.set_value(w0) diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py index 25e6d5b320f..05d619bd729 100644 --- a/python/paddle/nn/layer/common.py +++ b/python/paddle/nn/layer/common.py @@ -16,6 +16,7 @@ import paddle from ...fluid.dygraph import Flatten #DEFINE_ALIAS from ...fluid.dygraph import layers +from ...fluid.framework import in_dygraph_mode from .. import functional as F from ...fluid.framework import _dygraph_tracer @@ -1352,6 +1353,9 @@ class Embedding(layers.Layer): dtype=self._dtype, is_bias=False) + if in_dygraph_mode() and padding_idx != -1: + self.weight[padding_idx] = 0.0 + def forward(self, x): return F.embedding( x, -- GitLab