未验证 提交 370b50f6 编写于 作者: S seemingwang 提交者: GitHub

【Zero-Dim】Support Zero dim for embedding and one-hot (#49562)

* zero-tensor

* remove unused

* zero_dim_xpu

* relocate

* add value test

* fix syntax
上级 1e8976e8
......@@ -2257,8 +2257,8 @@ void OneHotRawInferMeta(const MetaTensor& x,
auto x_dims = x.dims();
PADDLE_ENFORCE_GE(
x_dims.size(),
1,
phi::errors::InvalidArgument("Rank of Input(X) should be at least 1."));
0,
phi::errors::InvalidArgument("Rank of Input(X) should be at least 0."));
auto out_dims_vec = phi::vectorize(x_dims);
out_dims_vec.push_back(depth.to<int>());
auto out_dims = phi::make_ddim(out_dims_vec);
......@@ -2273,8 +2273,8 @@ void OneHotInferMeta(const MetaTensor& x,
auto x_dims = x.dims();
PADDLE_ENFORCE_GE(
x_dims.size(),
1,
phi::errors::InvalidArgument("Rank of Input(X) should be at least 1."));
0,
phi::errors::InvalidArgument("Rank of Input(X) should be at least 0."));
int depth = depth_t.to<int>();
auto out_dims_vec = phi::vectorize(x_dims);
......
......@@ -1421,6 +1421,24 @@ class TestNoBackwardAPI(unittest.TestCase):
out = paddle.zeros(self.shape)
self.assertEqual(out.shape, [2, 3, 4])
def test_embedding(self):
    # A 0-d (scalar) index into a (3, 2) embedding table must yield a
    # 1-d output holding exactly the selected row.
    scalar_ids = paddle.full(shape=[], fill_value=1, dtype='int64')
    table = paddle.arange(3, 9).reshape((3, 2)).astype(paddle.float32)
    weight = paddle.to_tensor(table, stop_gradient=False)
    emb = paddle.nn.functional.embedding(
        x=scalar_ids, weight=weight, sparse=True, name="embedding"
    )
    self.assertEqual(emb.shape, [2])
    # Row 1 of arange(3, 9).reshape((3, 2)) is [5.0, 6.0].
    expected = [5.0, 6.0]
    emb_np = emb.numpy()
    for idx, val in enumerate(expected):
        self.assertEqual(emb_np[idx], val)
def test_one_hot_label(self):
    # One-hot of a 0-d label produces a 1-d vector of length
    # num_classes with a single 1 at the label position.
    scalar_label = paddle.full(shape=[], fill_value=2, dtype='int64')
    encoded = paddle.nn.functional.one_hot(scalar_label, num_classes=4)
    self.assertEqual(encoded.shape, [4])
    self.assertEqual(encoded.numpy()[2], 1)
class TestNoBackwardAPIStatic(unittest.TestCase):
def setUp(self):
......@@ -1590,6 +1608,39 @@ class TestNoBackwardAPIStatic(unittest.TestCase):
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, (2, 3, 4))
def test_embedding(self):
    """Static-graph embedding lookup with a 0-d (scalar) index.

    A scalar id selects one row of the (3, 2) weight table, so the
    fetched result must be 1-d with shape (2,) and hold row 1 of the
    table, i.e. [5.0, 6.0].
    """
    ids = paddle.full(shape=[], fill_value=1, dtype='int64')
    w0 = paddle.arange(3, 9).reshape((3, 2)).astype(paddle.float32)
    w = paddle.to_tensor(w0, stop_gradient=False)
    emb = paddle.nn.functional.embedding(
        x=ids, weight=w, sparse=True, name="embedding"
    )
    prog = paddle.static.default_main_program()
    res = self.exe.run(prog, fetch_list=[emb])
    self.assertEqual(res[0].shape, (2,))
    result = [5.0, 6.0]
    # BUG FIX: iterate over the expected values (len == 2), not over
    # the fetch list `res` (len == 1) as before — the old loop only
    # ever checked the first element of the embedded row.
    for i in range(len(result)):
        self.assertEqual(res[0][i], result[i])
def test_static_embedding(self):
    # Legacy paddle.static.nn.embedding with a 0-d index: the output
    # keeps only the embedding dimension, giving shape (3,).
    scalar_ids = paddle.full(shape=[], fill_value=1, dtype='int64')
    lookup = paddle.static.nn.embedding(scalar_ids, (20, 3))
    main_prog = paddle.static.default_main_program()
    self.exe.run(paddle.fluid.default_startup_program())
    fetched = self.exe.run(main_prog, fetch_list=[lookup])
    self.assertEqual(fetched[0].shape, (3,))
def test_one_hot_label(self):
    # Static-graph one_hot of a 0-d label: result is 1-d of length 4
    # with a 1 at index 2 (the label value).
    scalar_label = paddle.full(shape=[], fill_value=2, dtype='int64')
    encoded = paddle.nn.functional.one_hot(scalar_label, num_classes=4)
    main_prog = paddle.static.default_main_program()
    self.exe.run(paddle.fluid.default_startup_program())
    fetched = self.exe.run(main_prog, fetch_list=[encoded])
    self.assertEqual(fetched[0].shape, (4,))
    self.assertEqual(fetched[0][2], 1)
# Run this test module directly with the standard unittest runner.
if __name__ == "__main__":
    unittest.main()
......@@ -833,6 +833,24 @@ class TestNoBackwardAPI(unittest.TestCase):
out = paddle.zeros(self.shape)
self.assertEqual(out.shape, [2, 3, 4])
def test_embedding(self):
    # Scalar (0-d) id lookup in a (3, 2) table: output is the single
    # selected row, shape [2].
    scalar_ids = paddle.full(shape=[], fill_value=1, dtype='int64')
    table = paddle.arange(3, 9).reshape((3, 2)).astype(paddle.float32)
    weight = paddle.to_tensor(table, stop_gradient=False)
    emb = paddle.nn.functional.embedding(
        x=scalar_ids, weight=weight, sparse=True, name="embedding"
    )
    self.assertEqual(emb.shape, [2])
    # Row 1 of arange(3, 9).reshape((3, 2)) is [5.0, 6.0].
    expected = [5.0, 6.0]
    emb_np = emb.numpy()
    for idx, val in enumerate(expected):
        self.assertEqual(emb_np[idx], val)
def test_one_hot_label(self):
    # One-hot of a 0-d label yields a 1-d vector of num_classes
    # entries, hot at the label index.
    scalar_label = paddle.full(shape=[], fill_value=2, dtype='int64')
    encoded = paddle.nn.functional.one_hot(scalar_label, num_classes=4)
    self.assertEqual(encoded.shape, [4])
    self.assertEqual(encoded.numpy()[2], 1)
# Run this test module directly with the standard unittest runner.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册