未验证 提交 4763e6bc 编写于 作者: T tangwei12 提交者: GitHub

pre padding in dygraph (#30163)

Change-Id: Ia5279b0cbb6a5b3970aff66e9510e0d85efa70ce
上级 198fbdfb
......@@ -26,12 +26,10 @@ paddle.disable_static()
class EmbeddingDygraph(unittest.TestCase):
def test_1(self):
x_data = np.arange(3, 6).reshape((3, 1)).astype(np.int64)
y_data = np.arange(6, 12).reshape((3, 2)).astype(np.float32)
paddle.disable_static(paddle.CPUPlace())
x = paddle.to_tensor(x_data, stop_gradient=False)
y = paddle.to_tensor(y_data, stop_gradient=False)
embedding = paddle.nn.Embedding(10, 3, sparse=True)
embedding = paddle.nn.Embedding(10, 3, sparse=True, padding_idx=9)
w0 = np.full(shape=(10, 3), fill_value=2).astype(np.float32)
embedding.weight.set_value(w0)
......
......@@ -16,6 +16,7 @@
import paddle
from ...fluid.dygraph import Flatten #DEFINE_ALIAS
from ...fluid.dygraph import layers
from ...fluid.framework import in_dygraph_mode
from .. import functional as F
from ...fluid.framework import _dygraph_tracer
......@@ -1352,6 +1353,9 @@ class Embedding(layers.Layer):
dtype=self._dtype,
is_bias=False)
if in_dygraph_mode() and padding_idx != -1:
self.weight[padding_idx] = 0.0
def forward(self, x):
return F.embedding(
x,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册