From 57d4288ad4c45ca83e25f900f3aacd90626d3202 Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Tue, 30 Mar 2021 21:01:20 +0800 Subject: [PATCH] [dynamic setitem] Fix bug of dynamic setitem: Decerease axes to do right broadcast (#31960) --- paddle/fluid/pybind/imperative.cc | 14 ++++++---- .../tests/unittests/test_set_value_op.py | 28 ++++++++++++++++--- 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 58ef1778630..eed3b3b7691 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -611,15 +611,17 @@ void BindImperative(py::module *m_ptr) { // TODO(liym27): Try not to call TensorToPyArray because it always // copys data to cpu place, which reduces performance. if (parse_index && value_is_tensor) { - std::vector axes, starts, ends, steps, decrease_axis, + std::vector axes, starts, ends, steps, decrease_axes, infer_flags; ParseIndexingSlice(self_tensor, index_ptr, &axes, &starts, &ends, - &steps, &decrease_axis, &infer_flags); + &steps, &decrease_axes, &infer_flags); - framework::AttributeMap attrs = {{"axes", axes}, - {"starts", starts}, - {"ends", ends}, - {"steps", steps}}; + framework::AttributeMap attrs = { + {"axes", axes}, + {"starts", starts}, + {"ends", ends}, + {"steps", steps}, + {"decrease_axes", decrease_axes}}; imperative::NameVarBaseMap ins = {{"Input", {self}}}; imperative::NameVarBaseMap outs = {{"Out", {self}}}; diff --git a/python/paddle/fluid/tests/unittests/test_set_value_op.py b/python/paddle/fluid/tests/unittests/test_set_value_op.py index 808d77d4761..0885891cdbe 100644 --- a/python/paddle/fluid/tests/unittests/test_set_value_op.py +++ b/python/paddle/fluid/tests/unittests/test_set_value_op.py @@ -48,18 +48,37 @@ class TestSetValueBase(unittest.TestCase): class TestSetValueApi(TestSetValueBase): - def test_api(self): + def _run_static(self): + paddle.enable_static() with paddle.static.program_guard(self.program): x = paddle.ones(shape=self.shape, dtype=self.dtype) self._call_setitem(x) exe = paddle.static.Executor(paddle.CPUPlace()) out = exe.run(self.program, fetch_list=[x]) + paddle.disable_static() + return out + + def _run_dynamic(self): + paddle.disable_static() + x = paddle.ones(shape=self.shape, dtype=self.dtype) + self._call_setitem(x) + out = x.numpy() + paddle.enable_static() + return out + + def test_api(self): + static_out = self._run_static() + dynamic_out = self._run_dynamic() self._get_answer() + + error_msg = "\nIn {} mode: \nExpected res = \n{}, \n\nbut received : \n{}" self.assertTrue( - (self.data == out).all(), - msg="\nExpected res = \n{}, \n\nbut received : \n{}".format( - self.data, out)) + (self.data == static_out).all(), + msg=error_msg.format("static", self.data, static_out)) + self.assertTrue( + (self.data == dynamic_out).all(), + msg=error_msg.format("dynamic", self.data, dynamic_out)) # 1. Test different type of item: int, Python slice, Paddle Tensor @@ -748,6 +767,7 @@ class TestError(TestSetValueBase): exe.run(program) def test_error(self): + paddle.enable_static() with paddle.static.program_guard(self.program): self._value_type_error() self._dtype_error() -- GitLab