diff --git a/python/paddle/fluid/tests/unittests/test_nn_functional_embedding_dygraph.py b/python/paddle/fluid/tests/unittests/test_nn_functional_embedding_dygraph.py index 43a0d481b28fdc47dec52fe9763dd920fd5a76a2..acff7daadeb33114646398a0a78250535e73f3aa 100644 --- a/python/paddle/fluid/tests/unittests/test_nn_functional_embedding_dygraph.py +++ b/python/paddle/fluid/tests/unittests/test_nn_functional_embedding_dygraph.py @@ -26,12 +26,10 @@ paddle.disable_static() class EmbeddingDygraph(unittest.TestCase): def test_1(self): x_data = np.arange(3, 6).reshape((3, 1)).astype(np.int64) - y_data = np.arange(6, 12).reshape((3, 2)).astype(np.float32) paddle.disable_static(paddle.CPUPlace()) x = paddle.to_tensor(x_data, stop_gradient=False) - y = paddle.to_tensor(y_data, stop_gradient=False) - embedding = paddle.nn.Embedding(10, 3, sparse=True) + embedding = paddle.nn.Embedding(10, 3, sparse=True, padding_idx=9) w0 = np.full(shape=(10, 3), fill_value=2).astype(np.float32) embedding.weight.set_value(w0) diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py index 5ae6e3ed770c978ea079596f9669e1bdd81d5b7b..9675524f938e544d0e6dbcc94bf8e7a979b070ae 100644 --- a/python/paddle/nn/layer/common.py +++ b/python/paddle/nn/layer/common.py @@ -16,6 +16,7 @@ import paddle from ...fluid.dygraph import Flatten #DEFINE_ALIAS from ...fluid.dygraph import layers +from ...fluid.framework import in_dygraph_mode from .. import functional as F from ...fluid.framework import _dygraph_tracer @@ -1277,6 +1278,9 @@ class Embedding(layers.Layer): dtype=self._dtype, is_bias=False) + if in_dygraph_mode() and padding_idx != -1: + self.weight[padding_idx] = 0.0 + def forward(self, x): return F.embedding( x,