未验证 提交 a2b0357d 编写于 作者: T tangwei12 提交者: GitHub

pre padding in dygraph (#30179)

Change-Id: Ia5279b0cbb6a5b3970aff66e9510e0d85efa70ce
上级 85545bbc
...@@ -26,12 +26,10 @@ paddle.disable_static() ...@@ -26,12 +26,10 @@ paddle.disable_static()
class EmbeddingDygraph(unittest.TestCase): class EmbeddingDygraph(unittest.TestCase):
def test_1(self): def test_1(self):
x_data = np.arange(3, 6).reshape((3, 1)).astype(np.int64) x_data = np.arange(3, 6).reshape((3, 1)).astype(np.int64)
y_data = np.arange(6, 12).reshape((3, 2)).astype(np.float32)
paddle.disable_static(paddle.CPUPlace()) paddle.disable_static(paddle.CPUPlace())
x = paddle.to_tensor(x_data, stop_gradient=False) x = paddle.to_tensor(x_data, stop_gradient=False)
y = paddle.to_tensor(y_data, stop_gradient=False)
embedding = paddle.nn.Embedding(10, 3, sparse=True) embedding = paddle.nn.Embedding(10, 3, sparse=True, padding_idx=9)
w0 = np.full(shape=(10, 3), fill_value=2).astype(np.float32) w0 = np.full(shape=(10, 3), fill_value=2).astype(np.float32)
embedding.weight.set_value(w0) embedding.weight.set_value(w0)
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
import paddle import paddle
from ...fluid.dygraph import Flatten #DEFINE_ALIAS from ...fluid.dygraph import Flatten #DEFINE_ALIAS
from ...fluid.dygraph import layers from ...fluid.dygraph import layers
from ...fluid.framework import in_dygraph_mode
from .. import functional as F from .. import functional as F
from ...fluid.framework import _dygraph_tracer from ...fluid.framework import _dygraph_tracer
...@@ -1277,6 +1278,9 @@ class Embedding(layers.Layer): ...@@ -1277,6 +1278,9 @@ class Embedding(layers.Layer):
dtype=self._dtype, dtype=self._dtype,
is_bias=False) is_bias=False)
if in_dygraph_mode() and padding_idx != -1:
self.weight[padding_idx] = 0.0
def forward(self, x): def forward(self, x):
return F.embedding( return F.embedding(
x, x,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册