From a0069278e96a5fe8965c176afe807badfd1c3036 Mon Sep 17 00:00:00 2001
From: chenjian
Date: Fri, 31 Mar 2023 17:02:31 +0800
Subject: [PATCH] [Prim] Add prod backward composite rule (#51238)

* first commit
* add registry
* add unit test
* fix format
* add unit test
* fix bug
* replace unsqueeze with reshape
* fix
* fix unit test
* update test
* update test
* fix unit test
* fix
* fix
---
 .../operators/reduce_ops/reduce_prod_op.cc    | 40 ++++++++++++++
 .../composite_backward_api.h                  | 52 +++++++++++++++++++
 paddle/phi/api/yaml/legacy_backward.yaml      |  1 +
 .../fluid/tests/unittests/test_reduce_op.py   | 41 ++++++++++-----
 4 files changed, 122 insertions(+), 12 deletions(-)

diff --git a/paddle/fluid/operators/reduce_ops/reduce_prod_op.cc b/paddle/fluid/operators/reduce_ops/reduce_prod_op.cc
index 1ba1a1aa628..0a9aebbebac 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_prod_op.cc
+++ b/paddle/fluid/operators/reduce_ops/reduce_prod_op.cc
@@ -13,6 +13,9 @@
 // limitations under the License.
 
 #include "paddle/fluid/operators/reduce_ops/reduce_prod_op.h"
+#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
+#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
+#include "paddle/fluid/prim/utils/static/desc_tensor.h"
 
 #include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/phi/core/infermeta_utils.h"
@@ -27,6 +30,42 @@ class OpBase;
 }  // namespace imperative
 }  // namespace paddle
 
+namespace paddle {
+namespace operators {
+
+class ReduceProdCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
+ public:
+  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
+  void Apply() override {
+    // get inputs
+    paddle::Tensor x = this->GetSingleForwardInput("X");
+    paddle::Tensor out = this->GetSingleForwardOutput("Out");
+    paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
+
+    // get attrs
+    std::vector<int> axis = this->Attr<std::vector<int>>("dim");
+    bool keep_dim = this->Attr<bool>("keep_dim");
+    bool reduce_all = this->Attr<bool>("reduce_all");
+
+    // get output
+    paddle::Tensor x_grad_t = this->GetSingleInputGrad("X");
+
+    // get output ptr
+    auto x_grad = this->GetOutputPtr(&x_grad_t);
+
+    // get output original name
+    std::string x_grad_name = this->GetOutputName(x_grad_t);
+    VLOG(6) << "Running prod_grad composite func";
+    // call composite backward func
+    prim::prod_grad<prim::DescTensor>(
+        x, out, out_grad, axis, keep_dim, reduce_all, x_grad);
+    // recover output name
+    this->RecoverOutputName(x_grad_t, x_grad_name);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
 
 class ReduceProdOpMaker : public ops::ReduceBaseOpMaker {
@@ -46,5 +85,6 @@ REGISTER_OPERATOR(
     ReduceProdOpMaker,
     paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
     paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>,
+    ops::ReduceProdCompositeGradOpMaker,
     ReduceProdInferShapeFunctor);
 REGISTER_OPERATOR(reduce_prod_grad, ops::ReduceGradOp);
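The maker above wires `prim::prod_grad` in as the static-graph composite backward for `reduce_prod`. The rule it dispatches to (next diff) rests on the identity d prod(x)/d x_i = prod(x)/x_i, which holds for nonzero entries. As a sanity check, here is a minimal NumPy sketch (illustrative only, not the Paddle API) comparing that closed form against central finite differences:

```python
import numpy as np

# Identity used by the composite rule: d prod(x) / d x_i = prod(x) / x_i.
rng = np.random.default_rng(0)
x = rng.uniform(0.5, 1.5, size=5)  # keep entries away from zero

analytic = np.prod(x) / x

# Central finite differences as a reference.
eps = 1e-6
numeric = np.empty_like(x)
for i in range(x.size):
    xp, xm = x.copy(), x.copy()
    xp[i] += eps
    xm[i] -= eps
    numeric[i] = (np.prod(xp) - np.prod(xm)) / (2 * eps)

np.testing.assert_allclose(analytic, numeric, rtol=1e-5)
```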
diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
index 8371dfae2f5..98d5ca4845b 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -1069,6 +1069,58 @@ void gather_nd_grad(const Tensor& x,
   }
 }
 
+template <typename T>
+void prod_grad(const Tensor& x,
+               const Tensor& out,
+               const Tensor& out_grad,
+               const IntArray& axis,
+               bool keep_dim,
+               bool reduce_all,
+               Tensor* x_grad) {
+  if (x_grad) {
+    std::vector<int64_t> x_dim = phi::vectorize(x.dims());
+    int64_t axis_size = axis.size();
+    int64_t x_dim_size = x_dim.size();
+    // The incoming reduce_all attr may be stale when dim is also set,
+    // so re-derive it from the axis list.
+    reduce_all = (axis_size == 0 || axis_size == x_dim_size);
+    auto x_grad_tmp = Tensor();
+    auto out_tmp = Tensor();
+    if (x_dim_size == 1) {
+      x_grad_tmp = out_grad.expand(IntArray(x_dim));
+      out_tmp = out.expand(IntArray(x_dim));
+    } else {
+      if (!keep_dim) {
+        auto axis_ = std::vector<int64_t>();
+        if (reduce_all) {
+          // out keeps a leading dim of size 1, so re-insert axes 1..n-1
+          for (int64_t i = 1; i < x_dim_size; i++) {
+            axis_.push_back(i);
+          }
+        } else {
+          axis_ = axis.GetData();
+          // normalize negative axes
+          for (int64_t i = 0; i < axis_size; i++) {
+            if (axis[i] < 0) {
+              axis_[i] = axis[i] + x_dim_size;
+            }
+          }
+        }
+        // re-insert the reduced axes, then broadcast to x's shape
+        auto out_grad_ = unsqueeze<T>(out_grad, axis_);
+        x_grad_tmp = out_grad_.expand(IntArray(x_dim));
+        auto out_ = unsqueeze<T>(out, axis_);
+        out_tmp = out_.expand(IntArray(x_dim));
+      } else {
+        x_grad_tmp = out_grad.expand(IntArray(x_dim));
+        out_tmp = out.expand(IntArray(x_dim));
+      }
+    }
+    // d prod(x) / d x_i = prod(x) / x_i
+    auto x_grad_res = x_grad_tmp * out_tmp * (1 / x);
+    set_output<T>(x_grad_res, x_grad);
+  }
+}
+
 template <typename T>
 void max_grad(const Tensor& x,
               const Tensor& out,

diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index cfe7930c8c3..7bf61e29931 100755
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -875,6 +875,7 @@
     param : [x]
   kernel :
     func : prod_grad
+  composite: prod_grad(x, out, out_grad, dims, keep_dim, reduce_all, x_grad)
 
 - backward_op : psroi_pool_grad
   forward : psroi_pool (Tensor x, Tensor boxes, Tensor boxes_num, int pooled_height, int pooled_width, int output_channels, float spatial_scale) -> Tensor(out)
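The shape handling in `prod_grad` is the subtle part: with `keep_dim=False` the reduced axes are gone from `out` and `out_grad`, so the rule unsqueezes them back as size-1 dims before expanding to `x`'s shape. A NumPy sketch of that branch (assumed semantics for the unsqueeze/expand steps; the axis tuple and shapes are illustrative, not taken from the tests):

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.uniform(0.5, 1.5, size=(5, 6, 10))
axis = (0,)  # reduce over axis 0 with keep_dim=False

out = np.prod(x, axis=axis)        # shape (6, 10)
out_grad = np.ones_like(out)       # upstream gradient

# unsqueeze step: re-insert the reduced axes as size-1 dims
out_ = np.expand_dims(out, axis=axis)            # shape (1, 6, 10)
out_grad_ = np.expand_dims(out_grad, axis=axis)  # shape (1, 6, 10)

# expand step plus the prod identity: x_grad = out_grad * out / x
x_grad = np.broadcast_to(out_grad_ * out_, x.shape) / x

# spot-check one entry against finite differences
eps = 1e-6
xp, xm = x.copy(), x.copy()
xp[2, 3, 4] += eps
xm[2, 3, 4] -= eps
fd = (np.prod(xp, axis=axis) - np.prod(xm, axis=axis))[3, 4] / (2 * eps)
assert np.isclose(x_grad[2, 3, 4], fd)
```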
diff --git a/python/paddle/fluid/tests/unittests/test_reduce_op.py b/python/paddle/fluid/tests/unittests/test_reduce_op.py
index 0312671bfc4..72d1f4f2d69 100644
--- a/python/paddle/fluid/tests/unittests/test_reduce_op.py
+++ b/python/paddle/fluid/tests/unittests/test_reduce_op.py
@@ -407,6 +407,9 @@ class TestProdOp(OpTest):
     def setUp(self):
         self.op_type = "reduce_prod"
         self.python_api = raw_reduce_prod
+        self.public_python_api = raw_reduce_prod
+        self.prim_op_type = "prim"
+
         self.init_data_type()
         self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.data_type)}
         self.outputs = {'Out': self.inputs['X'].prod(axis=0)}
@@ -420,17 +423,27 @@
         self.check_output()
 
     def test_check_grad(self):
-        self.check_grad(['X'], 'Out')
+        self.check_grad(['X'], 'Out', check_prim=True)
+
+
+class TestProdOpFp64(TestProdOp):
+    def init_data_type(self):
+        self.data_type = "float64"
 
 
 class TestProdOp_ZeroDim(OpTest):
     def setUp(self):
-        self.python_api = paddle.prod
+        self.python_api = raw_reduce_prod
+        self.public_python_api = raw_reduce_prod
         self.op_type = "reduce_prod"
+        self.prim_op_type = "prim"
         self.inputs = {'X': np.random.random([]).astype("float64")}
         self.outputs = {'Out': self.inputs['X'].prod()}
         self.attrs = {'dim': [], 'reduce_all': True}
+        # 0-D tensors are not supported by CINN
+        self.enable_cinn = False
+
     def test_check_output(self):
         self.check_output()
@@ -442,6 +455,8 @@ class TestProd6DOp(OpTest):
     def setUp(self):
         self.op_type = "reduce_prod"
         self.python_api = raw_reduce_prod
+        self.public_python_api = raw_reduce_prod
+        self.prim_op_type = "prim"
         self.init_data_type()
         self.inputs = {
             'X': np.random.random((5, 6, 2, 3, 4, 2)).astype(self.data_type)
@@ -460,13 +475,14 @@
         self.check_output()
 
     def test_check_grad(self):
-        self.check_grad(['X'], 'Out')
+        self.check_grad(['X'], 'Out', check_prim=True)
 
 
 class TestProd8DOp(OpTest):
     def setUp(self):
         self.op_type = "reduce_prod"
         self.python_api = raw_reduce_prod
+        self.public_python_api = raw_reduce_prod
         self.init_data_type()
         self.inputs = {
             'X': np.random.random((2, 5, 3, 2, 2, 3, 4, 2)).astype(
@@ -1178,15 +1194,16 @@ class TestReduceWithDtype2(TestReduceWithDtype):
 
 class TestReduceSumOpError(unittest.TestCase):
     def test_errors(self):
-        with program_guard(Program(), Program()):
-            # The input type of reduce_sum_op must be Variable.
-            x1 = fluid.create_lod_tensor(
-                np.array([[-1]]), [[1]], fluid.CPUPlace()
-            )
-            self.assertRaises(TypeError, paddle.sum, x1)
-            # The input dtype of reduce_sum_op must be float32 or float64 or int32 or int64.
-            x2 = paddle.static.data(name='x2', shape=[-1, 4], dtype="uint8")
-            self.assertRaises(TypeError, paddle.sum, x2)
+        with paddle.fluid.framework._static_guard():
+            with program_guard(Program(), Program()):
+                # The input type of reduce_sum_op must be Variable.
+                x1 = fluid.create_lod_tensor(
+                    np.array([[-1]]), [[1]], fluid.CPUPlace()
+                )
+                self.assertRaises(TypeError, paddle.sum, x1)
+                # The input dtype of reduce_sum_op must be float32 or float64 or int32 or int64.
+                x2 = paddle.static.data(name='x2', shape=[-1, 4], dtype="uint8")
+                self.assertRaises(TypeError, paddle.sum, x2)
 
 
 class API_TestSumOp(unittest.TestCase):
-- 
GitLab
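End to end, the new tests exercise the rule through OpTest's prim checks (`prim_op_type = "prim"`, `check_prim=True`). The same gradient can be spot-checked in eager mode; this is a sketch assuming a recent Paddle build, and whether the composite rule or the handwritten `prod_grad` kernel actually runs depends on whether prim is enabled, but both should match the closed form:

```python
import numpy as np
import paddle

np_x = np.random.uniform(0.5, 1.5, size=(5, 6, 10)).astype("float64")
x = paddle.to_tensor(np_x, stop_gradient=False)

y = paddle.prod(x, axis=0)
y.backward(paddle.ones_like(y))

# closed form: grad = prod(x, axis=0) / x, broadcast back over axis 0
expected = np.prod(np_x, axis=0, keepdims=True) / np_x
np.testing.assert_allclose(x.grad.numpy(), expected, rtol=1e-8)
```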