未验证 提交 a2b0357d 编写于 作者: T tangwei12 提交者: GitHub

pre padding in dygraph (#30179)

Change-Id: Ia5279b0cbb6a5b3970aff66e9510e0d85efa70ce
上级 85545bbc
......@@ -26,12 +26,10 @@ paddle.disable_static()
class EmbeddingDygraph(unittest.TestCase):
def test_1(self):
x_data = np.arange(3, 6).reshape((3, 1)).astype(np.int64)
y_data = np.arange(6, 12).reshape((3, 2)).astype(np.float32)
paddle.disable_static(paddle.CPUPlace())
x = paddle.to_tensor(x_data, stop_gradient=False)
y = paddle.to_tensor(y_data, stop_gradient=False)
embedding = paddle.nn.Embedding(10, 3, sparse=True)
embedding = paddle.nn.Embedding(10, 3, sparse=True, padding_idx=9)
w0 = np.full(shape=(10, 3), fill_value=2).astype(np.float32)
embedding.weight.set_value(w0)
......
......@@ -16,6 +16,7 @@
import paddle
from ...fluid.dygraph import Flatten #DEFINE_ALIAS
from ...fluid.dygraph import layers
from ...fluid.framework import in_dygraph_mode
from .. import functional as F
from ...fluid.framework import _dygraph_tracer
......@@ -1277,6 +1278,9 @@ class Embedding(layers.Layer):
dtype=self._dtype,
is_bias=False)
if in_dygraph_mode() and padding_idx != -1:
self.weight[padding_idx] = 0.0
def forward(self, x):
return F.embedding(
x,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册