diff --git a/paddle/fluid/operators/reduce_ops/reduce_max_op.cc b/paddle/fluid/operators/reduce_ops/reduce_max_op.cc
index 21e16a5cd14c65c40123ab62a381419e7b34d915..0d5320d56346589ab26fa48a48fad32d21f5597e 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_max_op.cc
+++ b/paddle/fluid/operators/reduce_ops/reduce_max_op.cc
@@ -14,6 +14,9 @@
 
 #include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h"
+#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
+#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
+#include "paddle/fluid/prim/utils/static/desc_tensor.h"
 #include "paddle/phi/core/infermeta_utils.h"
 #include "paddle/phi/infermeta/unary.h"
 
@@ -25,6 +28,32 @@ class ReduceMaxOpMaker : public ops::ReduceOpMaker {
   virtual std::string GetOpType() const { return "Reduce reduce_max"; }
 };
 
+namespace paddle {
+namespace operators {
+
+class ReduceMaxCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
+ public:
+  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
+  void Apply() override {
+    paddle::Tensor x = this->GetSingleForwardInput("X");
+    paddle::Tensor out = this->GetSingleForwardOutput("Out");
+    paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
+    std::vector<int> axis = this->Attr<std::vector<int>>("dim");
+    bool keep_dim = this->Attr<bool>("keep_dim");
+    bool reduce_all = this->Attr<bool>("reduce_all");
+    paddle::Tensor x_grad_t = this->GetSingleInputGrad("X");
+    paddle::Tensor* x_grad = this->GetOutputPtr(&x_grad_t);
+    std::string x_grad_name = this->GetOutputName(x_grad_t);
+    VLOG(6) << "Running max_grad composite func";
+    prim::max_grad<prim::DescTensor>(
+        x, out, out_grad, axis, keep_dim, reduce_all, x_grad);
+    this->RecoverOutputName(x_grad_t, x_grad_name);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
 DECLARE_INFER_SHAPE_FUNCTOR(
     reduce_max,
     ReduceMaxInferShapeFunctor,
@@ -36,5 +65,6 @@ REGISTER_OPERATOR(
     ReduceMaxOpMaker,
     paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
     paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>,
+    ops::ReduceMaxCompositeGradOpMaker,
     ReduceMaxInferShapeFunctor);
 REGISTER_OPERATOR(reduce_max_grad, ops::ReduceGradOp)
diff --git a/paddle/fluid/prim/api/api.yaml b/paddle/fluid/prim/api/api.yaml
index 3336acd2926a7b25df7f38ffdbb9fe94aa883553..c5eadec1e079d599ab245c8a7fdae6f9217f03f3 100644
--- a/paddle/fluid/prim/api/api.yaml
+++ b/paddle/fluid/prim/api/api.yaml
@@ -38,6 +38,7 @@
 - sqrt
 - cumsum
 - put_along_axis
+- equal
 - greater_than
 - less_equal
 - sin
diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
index ddddc7c0b75b2dd77caeeca35eb046c2b751cbe3..57203557fb5d118c2203241dfaf9907f39f6d07d 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -1044,6 +1044,57 @@ void gather_nd_grad(const Tensor& x,
   }
 }
 
+template <typename T>
+void max_grad(const Tensor& x,
+              const Tensor& out,
+              const Tensor& out_grad,
+              const IntArray& axis,
+              bool keepdim,
+              bool reduce_all,
+              Tensor* x_grad) {
+  if (!x_grad) {
+    return;
+  }
+  auto zero_tensor = full<T>(phi::vectorize(x.dims()), 0.0, x.dtype());
+  std::vector<int64_t> x_dim = phi::vectorize(x.dims());
+  int64_t axis_size = axis.size();
+  int64_t x_dim_size = x_dim.size();
+  reduce_all = false;
+  if (reduce_all || axis_size == 0 || axis_size == x_dim_size) {
+    reduce_all = true;
+  } else {
+    reduce_all = false;
+  }
+  auto x_grad_tmp = Tensor();
+  if (x_dim_size == 0 || x_dim_size == 1 || keepdim) {
+    auto out_grad_tmp = out_grad.expand(IntArray(x_dim));
+    auto out_tmp = out.expand(IntArray(x_dim));
+    auto mask = equal<T>(x, out_tmp);
+    x_grad_tmp = where<T>(mask, out_grad_tmp, zero_tensor);
+  } else {
+    auto axis_ = std::vector<int64_t>();
+    if (reduce_all) {
+      for (int64_t i = 1; i < x_dim_size; i++) {
+        axis_.push_back(i);
+      }
+    } else {
+      axis_ = axis.GetData();
+      for (int64_t i = 0; i < axis_size; i++) {
+        if (axis[i] < 0) {
+          axis_[i] = axis[i] + x_dim_size;
+        }
+      }
+    }
+    auto out_grad_ = unsqueeze<T>(out_grad, axis_);
+    auto out_ = unsqueeze<T>(out, axis_);
+    auto out_grad_tmp = out_grad_.expand(IntArray(x_dim));
+    auto out_tmp = out_.expand(IntArray(x_dim));
+    auto mask = equal<T>(x, out_tmp);
+    x_grad_tmp = where<T>(mask, out_grad_tmp, zero_tensor);
+  }
+  set_output<T>(x_grad_tmp, x_grad);
+}
+
 template <typename T>
 void assign_grad(const Tensor& out_grad, Tensor* x_grad) {
   if (x_grad) {
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index 38a6ca3155de85706f2ea13f28652b73db957e2c..d669f77929554aaab2469f5e3f4a24435388f2f3 100755
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -727,6 +727,7 @@
     param: [x]
   kernel :
     func : max_grad
+  composite : max_grad(x, out, out_grad, axis, keepdim, reduce_all, x_grad)
 
 - backward_op : max_pool2d_with_index_grad
   forward : max_pool2d_with_index(Tensor x, int[] kernel_size, int[] strides, int[] paddings, bool global_pooling, bool adaptive) -> Tensor(out), Tensor(mask)
diff --git a/python/paddle/fluid/tests/unittests/test_reduce_op.py b/python/paddle/fluid/tests/unittests/test_reduce_op.py
index 3798d8c818da8c14766a2284f902d6996577e33c..718ebbde816ca7346f89a6b9cb8b1c248842ad1b 100644
--- a/python/paddle/fluid/tests/unittests/test_reduce_op.py
+++ b/python/paddle/fluid/tests/unittests/test_reduce_op.py
@@ -240,7 +240,9 @@ class TestMaxOp(OpTest):
 
     def setUp(self):
         self.op_type = "reduce_max"
+        self.prim_op_type = "prim"
         self.python_api = paddle.max
+        self.public_python_api = paddle.max
         self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")}
         self.attrs = {'dim': [-1]}
         self.outputs = {
@@ -250,6 +252,16 @@ class TestMaxOp(OpTest):
     def test_check_output(self):
         self.check_output(check_eager=True)
 
+    def test_check_grad(self):
+        # only the composite op supports gradient check of reduce_max
+        self.check_grad(
+            ['X'],
+            'Out',
+            check_eager=True,
+            check_prim=True,
+            only_check_prim=True,
+        )
+
     def test_raise_error(self):
         if core.is_compiled_with_cuda():
             self.inputs = {'X': np.random.random((5, 6, 10)).astype("float16")}
@@ -268,7 +280,10 @@ class TestMaxOp_ZeroDim(OpTest):
 
     def setUp(self):
         self.op_type = "reduce_max"
+        self.prim_op_type = "prim"
         self.python_api = paddle.max
+        self.public_python_api = paddle.max
+        self.enable_cinn = False
         self.inputs = {'X': np.random.random([]).astype("float64")}
         self.attrs = {'dim': []}
         self.outputs = {
@@ -278,6 +293,46 @@ class TestMaxOp_ZeroDim(OpTest):
     def test_check_output(self):
         self.check_output(check_eager=True)
 
+    def test_check_grad(self):
+        # only the composite op supports gradient check of reduce_max
+        self.check_grad(
+            ['X'],
+            'Out',
+            check_eager=True,
+            check_prim=True,
+            only_check_prim=True,
+        )
+
+
+class TestMaxOp_FP32(OpTest):
+    """Remove Max with subgradient from gradient check to confirm the success of CI."""
+
+    def setUp(self):
+        self.op_type = "reduce_max"
+        self.prim_op_type = "prim"
+        self.python_api = paddle.max
+        self.public_python_api = paddle.max
+        self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")}
10)).astype("float32")} + self.attrs = {'dim': [-1], 'keep_dim': True} + self.outputs = { + 'Out': self.inputs['X'].max( + axis=tuple(self.attrs['dim']), keepdims=True + ) + } + + def test_check_output(self): + self.check_output(check_eager=True) + + def test_check_grad(self): + # only composite op support gradient check of reduce_max + self.check_grad( + ['X'], + 'Out', + check_eager=True, + check_prim=True, + only_check_prim=True, + ) + @skip_check_grad_ci( reason="reduce_min is discontinuous non-derivable function," @@ -829,7 +884,9 @@ class TestReduceMaxOpMultiAxises(OpTest): def setUp(self): self.op_type = "reduce_max" + self.prim_op_type = "prim" self.python_api = paddle.max + self.public_python_api = paddle.max self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")} self.attrs = {'dim': [-2, -1]} self.outputs = { @@ -839,6 +896,16 @@ class TestReduceMaxOpMultiAxises(OpTest): def test_check_output(self): self.check_output(check_eager=True) + def test_check_grad(self): + # only composite op support gradient check of reduce_max + self.check_grad( + ['X'], + 'Out', + check_eager=True, + check_prim=True, + only_check_prim=True, + ) + @skip_check_grad_ci( reason="reduce_min is discontinuous non-derivable function,"