未验证 提交 57d4288a 编写于 作者: L liym27 提交者: GitHub

[dynamic setitem] Fix bug of dynamic setitem: Decerease axes to do right broadcast (#31960)

上级 0fa6c8a3
...@@ -611,15 +611,17 @@ void BindImperative(py::module *m_ptr) { ...@@ -611,15 +611,17 @@ void BindImperative(py::module *m_ptr) {
// TODO(liym27): Try not to call TensorToPyArray because it always // TODO(liym27): Try not to call TensorToPyArray because it always
// copys data to cpu place, which reduces performance. // copys data to cpu place, which reduces performance.
if (parse_index && value_is_tensor) { if (parse_index && value_is_tensor) {
std::vector<int> axes, starts, ends, steps, decrease_axis, std::vector<int> axes, starts, ends, steps, decrease_axes,
infer_flags; infer_flags;
ParseIndexingSlice(self_tensor, index_ptr, &axes, &starts, &ends, ParseIndexingSlice(self_tensor, index_ptr, &axes, &starts, &ends,
&steps, &decrease_axis, &infer_flags); &steps, &decrease_axes, &infer_flags);
framework::AttributeMap attrs = {{"axes", axes}, framework::AttributeMap attrs = {
{"starts", starts}, {"axes", axes},
{"ends", ends}, {"starts", starts},
{"steps", steps}}; {"ends", ends},
{"steps", steps},
{"decrease_axes", decrease_axes}};
imperative::NameVarBaseMap ins = {{"Input", {self}}}; imperative::NameVarBaseMap ins = {{"Input", {self}}};
imperative::NameVarBaseMap outs = {{"Out", {self}}}; imperative::NameVarBaseMap outs = {{"Out", {self}}};
......
...@@ -48,18 +48,37 @@ class TestSetValueBase(unittest.TestCase): ...@@ -48,18 +48,37 @@ class TestSetValueBase(unittest.TestCase):
class TestSetValueApi(TestSetValueBase): class TestSetValueApi(TestSetValueBase):
def test_api(self): def _run_static(self):
paddle.enable_static()
with paddle.static.program_guard(self.program): with paddle.static.program_guard(self.program):
x = paddle.ones(shape=self.shape, dtype=self.dtype) x = paddle.ones(shape=self.shape, dtype=self.dtype)
self._call_setitem(x) self._call_setitem(x)
exe = paddle.static.Executor(paddle.CPUPlace()) exe = paddle.static.Executor(paddle.CPUPlace())
out = exe.run(self.program, fetch_list=[x]) out = exe.run(self.program, fetch_list=[x])
paddle.disable_static()
return out
def _run_dynamic(self):
paddle.disable_static()
x = paddle.ones(shape=self.shape, dtype=self.dtype)
self._call_setitem(x)
out = x.numpy()
paddle.enable_static()
return out
def test_api(self):
static_out = self._run_static()
dynamic_out = self._run_dynamic()
self._get_answer() self._get_answer()
error_msg = "\nIn {} mode: \nExpected res = \n{}, \n\nbut received : \n{}"
self.assertTrue( self.assertTrue(
(self.data == out).all(), (self.data == static_out).all(),
msg="\nExpected res = \n{}, \n\nbut received : \n{}".format( msg=error_msg.format("static", self.data, static_out))
self.data, out)) self.assertTrue(
(self.data == dynamic_out).all(),
msg=error_msg.format("dynamic", self.data, dynamic_out))
# 1. Test different type of item: int, Python slice, Paddle Tensor # 1. Test different type of item: int, Python slice, Paddle Tensor
...@@ -748,6 +767,7 @@ class TestError(TestSetValueBase): ...@@ -748,6 +767,7 @@ class TestError(TestSetValueBase):
exe.run(program) exe.run(program)
def test_error(self): def test_error(self):
paddle.enable_static()
with paddle.static.program_guard(self.program): with paddle.static.program_guard(self.program):
self._value_type_error() self._value_type_error()
self._dtype_error() self._dtype_error()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册