未验证 提交 e65bac28 编写于 作者: Y Yuang Liu 提交者: GitHub

Update for scatter support fake 2d index (#47946)

上级 a00aebe1
...@@ -987,13 +987,22 @@ void ScatterInferMeta(const MetaTensor& x, ...@@ -987,13 +987,22 @@ void ScatterInferMeta(const MetaTensor& x,
const auto& updates_dims = updates.dims(); const auto& updates_dims = updates.dims();
const auto& ref_dims = x.dims(); const auto& ref_dims = x.dims();
const auto& index_dims = index.dims(); const auto& index_dims = index.dims();
PADDLE_ENFORCE_EQ(
index_dims.size(), if (index_dims.size() == 2) {
1, PADDLE_ENFORCE_EQ(index_dims[1],
phi::errors::InvalidArgument( 1,
"The size of Input(Ids)'s shape should be equal to 1, but " phi::errors::InvalidArgument(
"received the rank of Input(Ids) is %d.", "The last dim of the index should be 1 when the "
index_dims.size())); "index is a 2D tensor, but we get %d.",
index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
index_dims.size(),
1,
phi::errors::InvalidArgument("The index should be a 1D tensor when the "
"index is not a 2D tensor, but we get %d.",
index_dims.size()));
}
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
ref_dims.size(), ref_dims.size(),
updates_dims.size(), updates_dims.size(),
......
...@@ -191,6 +191,25 @@ class TestScatterOp5(OpTest): ...@@ -191,6 +191,25 @@ class TestScatterOp5(OpTest):
) )
class TestScatterOp6(OpTest):
def setUp(self):
self.op_type = "scatter"
self.python_api = paddle.scatter
ref_np = np.ones((3, 50)).astype("float32")
index_np = np.array([[1], [2]]).astype("int32")
updates_np = np.random.random((2, 50)).astype("float32")
output_np = np.copy(ref_np)
output_np[np.array([1, 2]).astype("int32")] = updates_np
self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output(check_eager=False)
def test_check_grad(self):
self.check_grad(["X", "Updates"], "Out", check_eager=False)
class TestScatterAPI(unittest.TestCase): class TestScatterAPI(unittest.TestCase):
def setUp(self): def setUp(self):
self.places = [fluid.CPUPlace()] self.places = [fluid.CPUPlace()]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册