From 4a4849736f91def1e01b75e78e7e54ddaa6b13ef Mon Sep 17 00:00:00 2001
From: heyanru <81976792+heyanru01@users.noreply.github.com>
Date: Mon, 13 Mar 2023 17:11:05 +0800
Subject: [PATCH] 【prim】Maximum grad (#51006)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* refresh
* compat
* register
* testop
* fix
* fix
* fox
* cast
* cast
* fix
* type
* fix
* out
* cast
* fix
* fix
* fix
* broad
* broad
* broad
* fix
* fix
* fix
* fix
* fix
* broad
* broad
* numel
* fix
* fix
* fix
* fix
* cinn
* fix
* fix
* fix
* fix
---
 .../elementwise/elementwise_max_op.cc         | 35 +++++++++-
 paddle/fluid/prim/api/api.yaml                |  2 +
 .../composite_backward_api.h                  | 46 +++++++++++++
 paddle/fluid/prim/utils/static/desc_tensor.h  |  2 +
 paddle/phi/api/yaml/legacy_backward.yaml      |  1 +
 .../fluid/tests/unittests/CMakeLists.txt      |  1 +
 .../unittests/test_elementwise_max_op.py      | 67 ++++++++++++++++---
 7 files changed, 145 insertions(+), 9 deletions(-)

diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op.cc b/paddle/fluid/operators/elementwise/elementwise_max_op.cc
index 1911b5c2de6..0da6e495dc5 100644
--- a/paddle/fluid/operators/elementwise/elementwise_max_op.cc
+++ b/paddle/fluid/operators/elementwise/elementwise_max_op.cc
@@ -15,6 +15,9 @@ limitations under the License. */
 #include <string>
 
 #include "paddle/fluid/operators/elementwise/elementwise_op.h"
+#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
+#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
+#include "paddle/fluid/prim/utils/static/desc_tensor.h"
 
 namespace paddle {
 namespace framework {
@@ -68,6 +71,35 @@ class ElementwiseFMaxOpMaker : public ElementwiseOpMaker {
   }
 };
 
+class ElementwiseMaxCompositeGradOpMaker
+    : public prim::CompositeGradOpMakerBase {
+  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
+
+ public:
+  void Apply() override {
+    paddle::Tensor x = this->GetSingleForwardInput("X");
+    paddle::Tensor y = this->GetSingleForwardInput("Y");
+    paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
+    paddle::Tensor dx = this->GetSingleInputGrad("X");
+    auto* dx_ptr = this->GetOutputPtr(&dx);
+    std::string dx_name = this->GetOutputName(dx);
+    paddle::Tensor dy = this->GetSingleInputGrad("Y");
+    auto* dy_ptr = this->GetOutputPtr(&dy);
+    std::string dy_name = this->GetOutputName(dy);
+    int axis = static_cast<int>(this->Attr<int>("axis"));
+    PADDLE_ENFORCE_EQ(
+        axis,
+        -1,
+        phi::errors::InvalidArgument(
+            "We only support axis = -1 in composite maximum_grad, but got: ",
+            axis));
+    VLOG(6) << "Running maximum_grad composite func";
+    prim::maximum_grad<prim::DescTensor>(x, y, out_grad, axis, dx_ptr, dy_ptr);
+    this->RecoverOutputName(dx, dx_name);
+    this->RecoverOutputName(dy, dy_name);
+  }
+};
+
 template <typename T>
 class ElementwiseMaxGradOpMaker : public framework::SingleGradOpMaker<T> {
  public:
@@ -112,7 +144,8 @@ REGISTER_OPERATOR(elementwise_max,
                   ops::ElementwiseMaxOpMaker,
                   ops::ElementwiseOpInferVarType,
                   ops::ElementwiseMaxGradOpMaker<paddle::framework::OpDesc>,
-                  ops::ElementwiseMaxGradOpMaker<paddle::imperative::OpBase>);
+                  ops::ElementwiseMaxGradOpMaker<paddle::imperative::OpBase>,
+                  ops::ElementwiseMaxCompositeGradOpMaker);
 
 REGISTER_OPERATOR(elementwise_max_grad, ops::ElementwiseOpGrad);
 
diff --git a/paddle/fluid/prim/api/api.yaml b/paddle/fluid/prim/api/api.yaml
index 8d56ab4d462..529d024b8b8 100644
--- a/paddle/fluid/prim/api/api.yaml
+++ b/paddle/fluid/prim/api/api.yaml
@@ -31,3 +31,5 @@
 - pad
 - cumsum
 - put_along_axis
+- greater_than
+- less_equal
diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
index 2fe33c92ff8..248afe1e0d8 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -898,5 +898,51 @@ void erf_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
   }
 }
 
+template <typename T>
+void maximum_grad(const Tensor& x,
+                  const Tensor& y,
+                  const Tensor& out_grad,
+                  int axis,
+                  Tensor* x_grad,
+                  Tensor* y_grad) {
+  if (x_grad) {
+    auto x_tmp = cast<T>(greater_than<T>(x, y), out_grad.dtype());
+    auto dx_res = out_grad * x_tmp;
+    if (y.dims() != x.dims()) {
+      // Maybe need reduce here
+      auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
+      if (!reduce_dim.size()) {
+        set_output<T>(dx_res, x_grad);
+      } else {
+        auto dx_reduce_res =
+            dx_res.sum(phi::vectorize(reduce_dim), x.dtype(), false);
+        auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
+        set_output<T>(dx_tmp, x_grad);
+      }
+    } else {
+      set_output<T>(dx_res, x_grad);
+    }
+  }
+
+  if (y_grad) {
+    auto y_tmp = cast<T>(less_equal<T>(x, y), out_grad.dtype());
+    auto dy_res = out_grad * y_tmp;
+    if (x.dims() != y.dims()) {
+      // Maybe need reduce here
+      phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
+      if (!reduce_dim.size()) {
+        set_output<T>(dy_res, y_grad);
+      } else {
+        auto dy_reduce_res =
+            dy_res.sum(phi::vectorize(reduce_dim), y.dtype(), false);
+        auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
+        set_output<T>(dy_tmp, y_grad);
+      }
+    } else {
+      set_output<T>(dy_res, y_grad);
+    }
+  }
+}
+
 }  // namespace prim
 }  // namespace paddle
diff --git a/paddle/fluid/prim/utils/static/desc_tensor.h b/paddle/fluid/prim/utils/static/desc_tensor.h
index 6ac6caf4b02..7d8c939fec1 100644
--- a/paddle/fluid/prim/utils/static/desc_tensor.h
+++ b/paddle/fluid/prim/utils/static/desc_tensor.h
@@ -39,6 +39,8 @@ class DescTensor : public phi::ExtendedTensor,
     return dims_;
   }
 
+  int64_t numel() const override { return product(dims()); }
+
   DataType dtype() const override {
     return paddle::framework::TransToPhiDataType(desc_ptr_->GetDataType());
   }
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index d882665f156..da66d7486e3 100755
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -752,6 +752,7 @@
     param: [x, y]
   kernel :
     func : maximum_grad
+  composite : maximum_grad(x, y, out_grad, axis, x_grad, y_grad)
 
 - backward_op : mean_all_grad
   forward : mean_all(Tensor x) -> Tensor(out)
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 6e17016571b..a355b2545dd 100755
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -1213,6 +1213,7 @@ set(TEST_CINN_OPS
     test_elementwise_mul_op
     test_gather_nd_op
     test_elementwise_pow_op
+    test_elementwise_max_op
     test_transpose_op
     test_reshape_op
     test_mean_op
diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_max_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_max_op.py
index f6d17740687..a2220f20465 100644
--- a/python/paddle/fluid/tests/unittests/test_elementwise_max_op.py
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_max_op.py
@@ -25,6 +25,8 @@ class TestElementwiseOp(OpTest):
     def setUp(self):
         self.op_type = "elementwise_max"
         self.python_api = paddle.maximum
+        self.prim_op_type = "prim"
+        self.enable_cinn = False
         # If x and y have the same value, the max() is not differentiable.
         # So we generate test data by the following method
         # to avoid them being too close to each other.
@@ -42,25 +44,58 @@ class TestElementwiseOp(OpTest):
 
     def test_check_grad_normal(self):
         if hasattr(self, 'attrs'):
-            self.check_grad(['X', 'Y'], 'Out', check_eager=False)
+            if self.attrs['axis'] == -1:
+                self.check_grad(
+                    ['X', 'Y'], 'Out', check_eager=False, check_prim=True
+                )
+            else:
+                self.check_grad(['X', 'Y'], 'Out', check_eager=False)
         else:
-            self.check_grad(['X', 'Y'], 'Out', check_eager=True)
+            self.check_grad(
+                ['X', 'Y'], 'Out', check_eager=True, check_prim=True
+            )
 
     def test_check_grad_ingore_x(self):
-        self.check_grad(
-            ['Y'], 'Out', max_relative_error=0.005, no_grad_set=set("X")
-        )
+        if hasattr(self, 'attrs') and self.attrs['axis'] != -1:
+            self.check_grad(
+                ['Y'],
+                'Out',
+                max_relative_error=0.005,
+                no_grad_set=set("X"),
+            )
+        else:
+            self.check_grad(
+                ['Y'],
+                'Out',
+                max_relative_error=0.005,
+                no_grad_set=set("X"),
+                check_prim=True,
+            )
 
     def test_check_grad_ingore_y(self):
-        self.check_grad(
-            ['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y')
-        )
+        if hasattr(self, 'attrs') and self.attrs['axis'] != -1:
+            self.check_grad(
+                ['X'],
+                'Out',
+                max_relative_error=0.005,
+                no_grad_set=set('Y'),
+            )
+        else:
+            self.check_grad(
+                ['X'],
+                'Out',
+                max_relative_error=0.005,
+                no_grad_set=set('Y'),
+                check_prim=True,
+            )
 
 
 class TestElementwiseMaxOp_ZeroDim1(TestElementwiseOp):
     def setUp(self):
         self.op_type = "elementwise_max"
         self.python_api = paddle.maximum
+        self.prim_op_type = "prim"
+        self.enable_cinn = False
         x = np.random.uniform(0.1, 1, []).astype("float64")
         y = np.random.uniform(0.1, 1, []).astype("float64")
         self.inputs = {'X': x, 'Y': y}
@@ -71,6 +106,8 @@ class TestElementwiseMaxOp_ZeroDim2(TestElementwiseOp):
     def setUp(self):
         self.op_type = "elementwise_max"
         self.python_api = paddle.maximum
+        self.prim_op_type = "prim"
+        self.enable_cinn = False
         x = np.random.uniform(0.1, 1, [13, 17]).astype("float64")
         y = np.random.uniform(0.1, 1, []).astype("float64")
         self.inputs = {'X': x, 'Y': y}
@@ -81,6 +118,8 @@ class TestElementwiseMaxOp_ZeroDim3(TestElementwiseOp):
     def setUp(self):
         self.op_type = "elementwise_max"
         self.python_api = paddle.maximum
+        self.prim_op_type = "prim"
+        self.enable_cinn = False
         x = np.random.uniform(0.1, 1, []).astype("float64")
         y = np.random.uniform(0.1, 1, [13, 17]).astype("float64")
         self.inputs = {'X': x, 'Y': y}
@@ -99,6 +138,8 @@ class TestElementwiseBF16Op(OpTest):
     def setUp(self):
         self.op_type = "elementwise_max"
         self.python_api = paddle.maximum
+        self.prim_op_type = "prim"
+        self.enable_cinn = False
         self.dtype = np.uint16
         # If x and y have the same value, the max() is not differentiable.
         # So we generate test data by the following method
@@ -120,6 +161,7 @@ class TestElementwiseBF16Op(OpTest):
 
     def test_check_grad_normal(self):
         if hasattr(self, 'attrs'):
+            # check_prim=False: bfloat16 is not supported in `less_equal`
             self.check_grad(['X', 'Y'], 'Out', check_eager=False)
         else:
             self.check_grad(['X', 'Y'], 'Out', check_eager=True)
@@ -138,6 +180,8 @@ class TestElementwiseMaxOp_scalar(TestElementwiseOp):
     def setUp(self):
         self.op_type = "elementwise_max"
         self.python_api = paddle.maximum
+        self.prim_op_type = "prim"
+        self.enable_cinn = False
         x = np.random.random_integers(-5, 5, [2, 3, 20]).astype("float64")
         y = np.array([0.5]).astype("float64")
         self.inputs = {'X': x, 'Y': y}
@@ -148,6 +192,8 @@ class TestElementwiseMaxOp_Vector(TestElementwiseOp):
     def setUp(self):
         self.op_type = "elementwise_max"
         self.python_api = paddle.maximum
+        self.prim_op_type = "prim"
+        self.enable_cinn = False
         x = np.random.random((100,)).astype("float64")
         sgn = np.random.choice([-1, 1], (100,)).astype("float64")
         y = x + sgn * np.random.uniform(0.1, 1, (100,)).astype("float64")
@@ -159,6 +205,7 @@ class TestElementwiseMaxOp_broadcast_0(TestElementwiseOp):
     def setUp(self):
         self.op_type = "elementwise_max"
         self.python_api = paddle.maximum
+        self.prim_op_type = "prim"
         x = np.random.uniform(0.5, 1, (100, 5, 2)).astype(np.float64)
         sgn = np.random.choice([-1, 1], (100,)).astype(np.float64)
         y = x[:, 0, 0] + sgn * np.random.uniform(1, 2, (100,)).astype(
@@ -178,6 +225,7 @@ class TestElementwiseMaxOp_broadcast_1(TestElementwiseOp):
     def setUp(self):
         self.op_type = "elementwise_max"
         self.python_api = paddle.maximum
+        self.prim_op_type = "prim"
         x = np.random.uniform(0.5, 1, (2, 100, 3)).astype(np.float64)
         sgn = np.random.choice([-1, 1], (100,)).astype(np.float64)
         y = x[0, :, 0] + sgn * np.random.uniform(1, 2, (100,)).astype(
@@ -197,6 +245,7 @@ class TestElementwiseMaxOp_broadcast_2(TestElementwiseOp):
     def setUp(self):
         self.op_type = "elementwise_max"
         self.python_api = paddle.maximum
+        self.prim_op_type = "prim"
         x = np.random.uniform(0.5, 1, (1, 3, 100)).astype(np.float64)
         sgn = np.random.choice([-1, 1], (100,)).astype(np.float64)
         y = x[0, 0, :] + sgn * np.random.uniform(1, 2, (100,)).astype(
@@ -215,6 +264,7 @@ class TestElementwiseMaxOp_broadcast_3(TestElementwiseOp):
     def setUp(self):
         self.op_type = "elementwise_max"
         self.python_api = paddle.maximum
+        self.prim_op_type = "prim"
         x = np.random.uniform(0.5, 1, (2, 50, 2, 1)).astype(np.float64)
         sgn = np.random.choice([-1, 1], (50, 2)).astype(np.float64)
         y = x[0, :, :, 0] + sgn * np.random.uniform(1, 2, (50, 2)).astype(
@@ -234,6 +284,7 @@ class TestElementwiseMaxOp_broadcast_4(TestElementwiseOp):
     def setUp(self):
         self.op_type = "elementwise_max"
         self.python_api = paddle.maximum
+        self.prim_op_type = "prim"
         x = np.random.uniform(0.5, 1, (2, 3, 4, 5)).astype(np.float64)
         sgn = np.random.choice([-1, 1], (2, 3, 1, 5)).astype(np.float64)
         y = x + sgn * np.random.uniform(1, 2, (2, 3, 1, 5)).astype(np.float64)
-- 
GitLab
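
Note for reviewers: the composite rule added in composite_backward_api.h reduces to a masking-and-reduction scheme. The sketch below is illustrative only; `maximum_grad_ref` and `reduce_to_shape` are hypothetical names, not part of this patch or of Paddle's API. The real implementation builds the same computation from prim ops (cast, greater_than, less_equal, sum, reshape) on the static graph, with the reduce axes chosen by get_reduce_dims.

# Illustrative NumPy sketch (not part of the patch): the gradient of
# z = maximum(x, y) routes out_grad to x where x > y and to y where x <= y,
# then sums away broadcast axes so each gradient matches its input's shape.
import numpy as np


def maximum_grad_ref(x, y, out_grad):
    # Ties go to y, mirroring greater_than / less_equal in the composite rule.
    dx = out_grad * (x > y)
    dy = out_grad * (x <= y)

    def reduce_to_shape(grad, shape):
        # Sum over the leading axes added by broadcasting, then over the axes
        # where the input had size 1, and reshape back to the input shape.
        if grad.shape == shape:
            return grad
        grad = grad.sum(axis=tuple(range(grad.ndim - len(shape))))
        keep_dims = tuple(i for i, s in enumerate(shape) if s == 1)
        return grad.sum(axis=keep_dims, keepdims=True).reshape(shape)

    return reduce_to_shape(dx, x.shape), reduce_to_shape(dy, y.shape)


if __name__ == "__main__":
    x = np.random.uniform(0.5, 1.0, (2, 3, 4))
    y = np.random.uniform(0.5, 1.0, (3, 1))  # broadcast against x
    dx, dy = maximum_grad_ref(x, y, np.ones((2, 3, 4)))
    assert dx.shape == x.shape and dy.shape == y.shape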