From aa8877174be3e5f1d09bb6e51825e229b1908f72 Mon Sep 17 00:00:00 2001
From: xiaoguoguo626807 <100397923+xiaoguoguo626807@users.noreply.github.com>
Date: Fri, 5 May 2023 10:30:48 +0800
Subject: [PATCH] 【prim】modify assign api setOutput in by_pass (#53417)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* modify concat_grad add sum comp rule

* cast and by_pass modify

* only modify by_pass

* modify by_pass
---
 .../composite_backward_api.h                 |  1 +
 .../api/manual_prim/utils/static_utils.cc    | 10 +--
 .../fluid/tests/unittests/test_softmax_op.py | 83 ++++++++++---------
 3 files changed, 49 insertions(+), 45 deletions(-)

diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
index 78c4ee61f3d..8ad54079ca0 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -116,6 +116,7 @@ void cast_grad(const Tensor& out_grad, DataType dtype, Tensor* x_grad) {
     set_output<T>(res, x_grad);
   }
 }
+
 template <typename T>
 void gather_grad(const Tensor& x,
                  const Tensor& index,
diff --git a/paddle/fluid/prim/api/manual_prim/utils/static_utils.cc b/paddle/fluid/prim/api/manual_prim/utils/static_utils.cc
index f6c4f6b2e8b..d76a8ad5523 100644
--- a/paddle/fluid/prim/api/manual_prim/utils/static_utils.cc
+++ b/paddle/fluid/prim/api/manual_prim/utils/static_utils.cc
@@ -53,20 +53,20 @@ void set_output<DescTensor>(const paddle::Tensor& x_tmp, paddle::Tensor* x) {
 }
 
 template <>
-void by_pass<DescTensor>(const paddle::Tensor& x, paddle::Tensor* out) {
-  Tensor new_out =
-      empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
+void by_pass<DescTensor>(const paddle::Tensor& x, paddle::Tensor* real_out) {
   framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
   framework::OpDesc* op = block->AppendOp();
   op->SetType("assign");
   op->SetInput("X",
                {std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
+  auto out = empty<DescTensor>({}, x.dtype(), paddle::Place());
   op->SetOutput(
-      "Out", {std::static_pointer_cast<prim::DescTensor>(out->impl())->Name()});
+      "Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
   op->CheckAttrs();
   op->InferVarType(block);
   op->InferShape(*block);
-  set_output<DescTensor>(new_out, out);
+
+  set_output<DescTensor>(out, real_out);
 }
 
 }  // namespace prim
diff --git a/python/paddle/fluid/tests/unittests/test_softmax_op.py b/python/paddle/fluid/tests/unittests/test_softmax_op.py
index 61494bc16f9..9dc1f8d5408 100644
--- a/python/paddle/fluid/tests/unittests/test_softmax_op.py
+++ b/python/paddle/fluid/tests/unittests/test_softmax_op.py
@@ -462,16 +462,17 @@ class TestSoftmaxAPI(unittest.TestCase):
         self.softmax = F.softmax
 
     def test_static_check(self):
-        with paddle.static.program_guard(paddle.static.Program()):
-            x = paddle.static.data('X', self.x_np.shape, 'float32')
-            out1 = self.softmax(x)
-            m = paddle.nn.Softmax()
-            out2 = m(x)
-            exe = paddle.static.Executor(self.place)
-            res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2])
-            out_ref = ref_softmax(self.x_np, axis=-1, dtype=None)
-            for r in res:
-                np.testing.assert_allclose(out_ref, r, rtol=1e-05)
+        with paddle.fluid.framework._static_guard():
+            with paddle.static.program_guard(paddle.static.Program()):
+                x = paddle.static.data('X', self.x_np.shape, 'float32')
+                out1 = self.softmax(x)
+                m = paddle.nn.Softmax()
+                out2 = m(x)
+                exe = paddle.static.Executor(self.place)
+                res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2])
+                out_ref = ref_softmax(self.x_np, axis=-1, dtype=None)
+                for r in res:
+                    np.testing.assert_allclose(out_ref, r, rtol=1e-05)
 
     def test_dygraph_check(self):
         paddle.disable_static(self.place)
@@ -505,19 +506,20 @@ class TestSoftmaxAPI(unittest.TestCase):
         paddle.enable_static()
 
     def test_error(self):
-        with paddle.static.program_guard(paddle.static.Program()):
-            # The input type must be Variable.
-            self.assertRaises(TypeError, self.softmax, 1)
-            # The input dtype must be float16, float32, float64.
-            x_int32 = paddle.static.data(
-                name='x_int32', shape=[2, 3], dtype='int32'
-            )
-            self.assertRaises(TypeError, self.softmax, x_int32)
-            # support the input dtype is float16
-            x_fp16 = paddle.static.data(
-                name='x_fp16', shape=[2, 3], dtype='float16'
-            )
-            self.softmax(x_fp16)
+        with paddle.fluid.framework._static_guard():
+            with paddle.static.program_guard(paddle.static.Program()):
+                # The input type must be Variable.
+                self.assertRaises(TypeError, self.softmax, 1)
+                # The input dtype must be float16, float32, float64.
+                x_int32 = paddle.static.data(
+                    name='x_int32', shape=[2, 3], dtype='int32'
+                )
+                self.assertRaises(TypeError, self.softmax, x_int32)
+                # support the input dtype is float16
+                x_fp16 = paddle.static.data(
+                    name='x_fp16', shape=[2, 3], dtype='float16'
+                )
+                self.softmax(x_fp16)
 
 
 class TestSoftmaxAPI_ZeroDim(unittest.TestCase):
@@ -538,23 +540,24 @@ class TestSoftmaxAPI_ZeroDim(unittest.TestCase):
         paddle.enable_static()
 
     def test_static(self):
-        main_prog = fluid.Program()
-        with fluid.program_guard(main_prog, fluid.Program()):
-            x = paddle.rand([])
-            x.stop_gradient = False
-            out = paddle.nn.functional.softmax(x)
-            fluid.backward.append_backward(out)
-
-            # Test compile shape
-            self.assertEqual(x.shape, ())
-            self.assertEqual(out.shape, ())
-
-            exe = fluid.Executor()
-            result = exe.run(main_prog, fetch_list=[x, out])
-
-            # Test runtime shape
-            self.assertEqual(result[0].shape, ())
-            self.assertEqual(result[1].shape, ())
+        with paddle.fluid.framework._static_guard():
+            main_prog = fluid.Program()
+            with fluid.program_guard(main_prog, fluid.Program()):
+                x = paddle.rand([])
+                x.stop_gradient = False
+                out = paddle.nn.functional.softmax(x)
+                fluid.backward.append_backward(out)
+
+                # Test compile shape
+                self.assertEqual(x.shape, ())
+                self.assertEqual(out.shape, ())
+
+                exe = fluid.Executor()
+                result = exe.run(main_prog, fetch_list=[x, out])
+
+                # Test runtime shape
+                self.assertEqual(result[0].shape, ())
+                self.assertEqual(result[1].shape, ())
 
 
 class TestSoftmaxInplaceAPI(TestSoftmaxAPI):
-- 
GitLab
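Beyond the C++ change to `by_pass`, every test update above follows one pattern: each static-graph check is wrapped in `paddle.fluid.framework._static_guard()` so it runs in static mode regardless of the process-wide default. The snippet below is a minimal, standalone illustration of that pattern, not part of the patch: the NumPy `ref_softmax` helper and the CPU place are assumptions made for the example, and `_static_guard` is the private helper the patched tests rely on, so the script only works on Paddle builds that still ship it.

```python
import numpy as np
import paddle
import paddle.nn.functional as F


def ref_softmax(x, axis=-1):
    # Plain NumPy softmax used as the reference result.
    x_max = np.max(x, axis=axis, keepdims=True)
    e = np.exp(x - x_max)
    return e / np.sum(e, axis=axis, keepdims=True)


x_np = np.random.uniform(-1.0, 1.0, [2, 3, 4, 5]).astype('float32')

# _static_guard() temporarily switches Paddle into static-graph mode and
# restores the previous mode on exit (assumed behaviour of the private
# helper used by the patched tests).
with paddle.fluid.framework._static_guard():
    with paddle.static.program_guard(paddle.static.Program()):
        x = paddle.static.data('X', x_np.shape, 'float32')
        out = F.softmax(x)
        exe = paddle.static.Executor(paddle.CPUPlace())
        (res,) = exe.run(feed={'X': x_np}, fetch_list=[out])

np.testing.assert_allclose(ref_softmax(x_np), res, rtol=1e-05)
```

With such a build installed, the script should exit silently; an `AssertionError` from `assert_allclose` would indicate the static-graph softmax diverged from the NumPy reference.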