From d9de3ef672595156106572baf97b4f7a1b52b883 Mon Sep 17 00:00:00 2001
From: wangzhen38 <41941775+wangzhen38@users.noreply.github.com>
Date: Thu, 9 Mar 2023 09:19:10 +0800
Subject: [PATCH] [prim] add elementwise_pow backward (#51230)

* [cinn] add elementwise_pow backward
* [cinn] update unittest
* [cinn] update by comments
* [cinn] for ci
* [cinn] for ci
* [cinn] for ci
* [cinn] for ci
* [cinn] for ci
---
 .../elementwise/elementwise_pow_op.cc         | 37 +++++++++++-
 .../composite_backward_api.h                  | 50 ++++++++++++++++
 paddle/phi/api/yaml/legacy_backward.yaml      |  1 +
 .../fluid/tests/unittests/CMakeLists.txt      |  1 +
 .../unittests/test_elementwise_pow_op.py      | 57 ++++++++++++++++++-
 5 files changed, 143 insertions(+), 3 deletions(-)

diff --git a/paddle/fluid/operators/elementwise/elementwise_pow_op.cc b/paddle/fluid/operators/elementwise/elementwise_pow_op.cc
index fcfee9b4fca..0273743c95a 100644
--- a/paddle/fluid/operators/elementwise/elementwise_pow_op.cc
+++ b/paddle/fluid/operators/elementwise/elementwise_pow_op.cc
@@ -12,6 +12,9 @@ limitations under the License. */
 #include <string>
 
 #include "paddle/fluid/operators/elementwise/elementwise_op.h"
+#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
+#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
+#include "paddle/fluid/prim/utils/static/desc_tensor.h"
 
 namespace paddle {
 namespace framework {
@@ -41,6 +44,36 @@ class ElementwisePowOpGradMaker : public framework::SingleGradOpMaker<T> {
     op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
   }
 };
+
+class ElementwisePowCompositeGradOpMaker
+    : public prim::CompositeGradOpMakerBase {
+  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
+
+ public:
+  void Apply() override {
+    paddle::Tensor x = this->GetSingleForwardInput("X");
+    paddle::Tensor y = this->GetSingleForwardInput("Y");
+    paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
+    paddle::Tensor dx = this->GetSingleInputGrad("X");
+    auto dx_ptr = this->GetOutputPtr(&dx);
+    std::string dx_name = this->GetOutputName(dx);
+    paddle::Tensor dy = this->GetSingleInputGrad("Y");
+    auto dy_ptr = this->GetOutputPtr(&dy);
+    std::string dy_name = this->GetOutputName(dy);
+    int axis = static_cast<int>(this->Attr<int>("axis"));
+    PADDLE_ENFORCE_EQ(
+        axis,
+        -1,
+        phi::errors::InvalidArgument(
+            "We only support axis = -1 in composite pow but we got: ", axis));
+    VLOG(6) << "Running pow_grad composite func";
+    prim::elementwise_pow_grad<prim::DescTensor>(
+        x, y, out_grad, axis, dx_ptr, dy_ptr);
+    this->RecoverOutputName(dx, dx_name);
+    this->RecoverOutputName(dy, dy_name);
+  }
+};
+
 class ElementwisePowOpMaker : public ElementwiseOpMaker {
  protected:
   std::string GetName() const override { return "Pow"; }
@@ -65,7 +98,9 @@ REGISTER_OPERATOR(elementwise_pow,
                   ops::ElementwiseOpInferVarType,
                   ops::ElementwisePowOpGradMaker<paddle::framework::OpDesc>,
                   ops::ElementwisePowOpGradMaker<paddle::imperative::OpBase>);
-REGISTER_OPERATOR(elementwise_pow_grad, ops::ElementwiseOpGrad);
+REGISTER_OPERATOR(elementwise_pow_grad,
+                  ops::ElementwiseOpGrad,
+                  ops::ElementwisePowCompositeGradOpMaker);
 
 REGISTER_OP_VERSION(elementwise_pow)
     .AddCheckpoint(
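Editor's note: the composite maker above lowers pow's backward to primitive ops using the closed forms dL/dx = dout * y * x^(y-1) and dL/dy = dout * x^y * ln(x). A minimal NumPy sketch (independent of Paddle, not part of the patch) that checks both forms against central finite differences:

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.uniform(1.0, 2.0, (20, 5))
    y = rng.uniform(1.0, 2.0, (20, 5))
    dout = np.ones_like(x)  # OpTest differentiates sum(Out), so dout is all ones

    # closed forms used by the composite rule
    dx = dout * y * np.power(x, y - 1.0)
    dy = dout * np.power(x, y) * np.log(x)

    # central finite differences of out = x**y
    eps = 1e-6
    dx_num = (np.power(x + eps, y) - np.power(x - eps, y)) / (2 * eps)
    dy_num = (np.power(x, y + eps) - np.power(x, y - eps)) / (2 * eps)
    np.testing.assert_allclose(dx, dx_num, rtol=1e-4)
    np.testing.assert_allclose(dy, dy_num, rtol=1e-4)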
diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
index ffbeff8cade..3fa35c034a0 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -278,6 +278,56 @@ void divide_grad(const Tensor& x,
   }  // indicate we will compute dx
 }
 
+template <typename T>
+void elementwise_pow_grad(const Tensor& x,
+                          const Tensor& y,
+                          const Tensor& out_grad,
+                          int axis,
+                          Tensor* dx,
+                          Tensor* dy) {
+  if (dy) {
+    // dy = out_grad * x^y * ln(x)
+    auto lnx = log<T>(x);
+    auto x_pow_y = elementwise_pow<T>(x, y);
+    auto dy_res = lnx * x_pow_y * out_grad;
+    if (x.dims() != y.dims()) {
+      // y was broadcast, so reduce the gradient back to y's shape
+      phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
+      if (!reduce_dim.size()) {
+        set_output<T>(dy_res, dy);
+      } else {
+        auto dy_reduce_res =
+            dy_res.sum(phi::vectorize(reduce_dim), y.dtype(), false);
+        auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
+        set_output<T>(dy_tmp, dy);
+      }
+    } else {
+      set_output<T>(dy_res, dy);
+    }
+  }  // indicate we will compute dy
+  if (dx) {
+    // dx = out_grad * y * x^(y-1)
+    auto tmp_z = y - 1.0;
+    auto x_pow_z = elementwise_pow<T>(x, tmp_z);
+    auto dx_res = y * x_pow_z * out_grad;
+    if (y.dims() != x.dims()) {
+      // x was broadcast, so reduce the gradient back to x's shape
+      auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
+      if (!reduce_dim.size()) {
+        set_output<T>(dx_res, dx);
+      } else {
+        auto dx_reduce_res =
+            dx_res.sum(phi::vectorize(reduce_dim), x.dtype(), false);
+        auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
+        set_output<T>(dx_tmp, dx);
+      }
+
+    } else {
+      set_output<T>(dx_res, dx);
+    }
+  }  // indicate we will compute dx
+}
+
 template <typename T>
 void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
   if (x_grad) {

diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index 22106521f29..9129bc803c0 100755
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -423,6 +423,7 @@
   infer_meta :
     func : GeneralBinaryGradInferMeta
     param: [x, y]
+  composite : elementwise_pow_grad(x, y, out_grad, axis)
   kernel :
     func : elementwise_pow_grad

diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 5229884fb24..2f079ea88ac 100755
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -1218,6 +1218,7 @@ set(TEST_CINN_OPS
     test_elementwise_div_op
     test_elementwise_mul_op
     test_gather_nd_op
+    test_elementwise_pow_op
     test_transpose_op)
 
 foreach(TEST_CINN_OPS ${TEST_CINN_OPS})
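Editor's note: when x and y have different shapes, dx_res/dy_res above are computed in the broadcast (output) shape and must be summed back to the operand's own shape; that is what the get_reduce_dims / sum / reshape sequence does. A hedged NumPy sketch of that step follows; reduce_dims is a simplified stand-in written for this note, not Paddle's get_reduce_dims:

    import numpy as np

    def reduce_dims(small, big):
        # Hypothetical helper: axes of `big` along which a tensor of shape
        # `small` was broadcast (shapes right-aligned, as in NumPy/Paddle).
        offset = len(big) - len(small)
        axes = list(range(offset))  # leading axes that `small` lacks entirely
        axes += [
            i + offset
            for i, d in enumerate(small)
            if d == 1 and big[i + offset] != 1
        ]
        return tuple(axes)

    x = np.random.uniform(1, 2, (2, 10, 3, 5))
    y = np.random.uniform(1, 2, (2, 10, 1, 5))  # broadcast against x
    out_grad = np.ones(np.broadcast_shapes(x.shape, y.shape))

    dy_full = out_grad * np.power(x, y) * np.log(x)  # has the broadcast shape
    axes = reduce_dims(y.shape, dy_full.shape)       # -> (2,)
    dy = dy_full.sum(axis=axes).reshape(y.shape)     # summed back to y's shape
    assert dy.shape == y.shape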
diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_pow_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_pow_op.py
index 12a42b780d5..d0e97679ab2 100644
--- a/python/paddle/fluid/tests/unittests/test_elementwise_pow_op.py
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_pow_op.py
@@ -31,6 +31,7 @@ class TestElementwisePowOp(OpTest):
     def setUp(self):
         self.op_type = "elementwise_pow"
         self.python_api = paddle.pow
+        self.prim_op_type = "prim"
         self.inputs = {
             'X': np.random.uniform(1, 2, [20, 5]).astype("float64"),
             'Y': np.random.uniform(1, 2, [20, 5]).astype("float64"),
@@ -45,15 +46,22 @@ class TestElementwisePowOp(OpTest):
 
     def test_check_grad_normal(self):
         if hasattr(self, 'attrs'):
-            self.check_grad(['X', 'Y'], 'Out', check_eager=False)
+            self.check_grad(
+                ['X', 'Y'], 'Out', check_eager=False, check_prim=True
+            )
         else:
-            self.check_grad(['X', 'Y'], 'Out', check_eager=True)
+            self.check_grad(
+                ['X', 'Y'], 'Out', check_eager=True, check_prim=True
+            )
 
 
 class TestElementwisePowOp_ZeroDim1(TestElementwisePowOp):
     def setUp(self):
         self.op_type = "elementwise_pow"
         self.python_api = paddle.pow
+        self.enable_cinn = False
+        self.prim_op_type = "prim"
+
         self.inputs = {
             'X': np.random.uniform(1, 2, []).astype("float64"),
             'Y': np.random.uniform(1, 2, []).astype("float64"),
         }
@@ -65,6 +73,9 @@ class TestElementwisePowOp_ZeroDim2(TestElementwisePowOp):
     def setUp(self):
         self.op_type = "elementwise_pow"
         self.python_api = paddle.pow
+        self.enable_cinn = False
+        self.prim_op_type = "prim"
+
         self.inputs = {
             'X': np.random.uniform(1, 2, [20, 5]).astype("float64"),
             'Y': np.random.uniform(1, 2, []).astype("float64"),
         }
@@ -76,6 +87,9 @@ class TestElementwisePowOp_ZeroDim3(TestElementwisePowOp):
     def setUp(self):
         self.op_type = "elementwise_pow"
         self.python_api = paddle.pow
+        self.enable_cinn = False
+        self.prim_op_type = "prim"
+
         self.inputs = {
             'X': np.random.uniform(1, 2, []).astype("float64"),
             'Y': np.random.uniform(1, 2, [20, 5]).astype("float64"),
         }
@@ -87,6 +101,8 @@ class TestElementwisePowOp_big_shape_1(TestElementwisePowOp):
     def setUp(self):
         self.op_type = "elementwise_pow"
         self.python_api = paddle.pow
+        self.prim_op_type = "prim"
+
         self.inputs = {
             'X': np.random.uniform(1, 2, [10, 10]).astype("float64"),
             'Y': np.random.uniform(0.1, 1, [10, 10]).astype("float64"),
         }
@@ -98,6 +114,8 @@ class TestElementwisePowOp_big_shape_2(TestElementwisePowOp):
     def setUp(self):
         self.op_type = "elementwise_pow"
         self.python_api = paddle.pow
+        self.prim_op_type = "prim"
+
         self.inputs = {
             'X': np.random.uniform(1, 2, [10, 10]).astype("float64"),
             'Y': np.random.uniform(0.2, 2, [10, 10]).astype("float64"),
         }
@@ -112,6 +130,8 @@ class TestElementwisePowOp_scalar(TestElementwisePowOp):
     def setUp(self):
         self.op_type = "elementwise_pow"
         self.python_api = paddle.pow
+        self.prim_op_type = "prim"
+
         self.inputs = {
             'X': np.random.uniform(0.1, 1, [3, 3, 4]).astype(np.float64),
             'Y': np.random.uniform(0.1, 1, [1]).astype(np.float64),
         }
@@ -123,6 +143,9 @@ class TestElementwisePowOp_tensor(TestElementwisePowOp):
     def setUp(self):
         self.op_type = "elementwise_pow"
         self.python_api = paddle.pow
+
+        self.prim_op_type = "prim"
+
         self.inputs = {
             'X': np.random.uniform(0.1, 1, [100]).astype("float64"),
             'Y': np.random.uniform(1, 3, [100]).astype("float64"),
         }
@@ -134,6 +157,8 @@ class TestElementwisePowOp_broadcast_0(TestElementwisePowOp):
     def setUp(self):
         self.op_type = "elementwise_pow"
         self.python_api = paddle.pow
+        self.prim_op_type = "prim"
+
         self.inputs = {
             'X': np.random.uniform(0.1, 1, [2, 1, 100]).astype("float64"),
             'Y': np.random.uniform(0.1, 1, [100]).astype("float64"),
@@ -145,6 +170,7 @@ class TestElementwisePowOp_broadcast_1(TestElementwisePowOp):
     def setUp(self):
         self.op_type = "elementwise_pow"
         self.python_api = paddle.pow
+
         self.inputs = {
             'X': np.random.uniform(0.1, 1, [2, 100, 1]).astype("float64"),
             'Y': np.random.uniform(0.1, 1, [100]).astype("float64"),
@@ -154,11 +180,18 @@ class TestElementwisePowOp_broadcast_1(TestElementwisePowOp):
             'Out': np.power(self.inputs['X'], self.inputs['Y'].reshape(100, 1))
         }
 
+    def test_check_grad_normal(self):
+        if hasattr(self, 'attrs'):
+            self.check_grad(['X', 'Y'], 'Out', check_eager=False)
+        else:
+            self.check_grad(['X', 'Y'], 'Out', check_eager=True)
+
 
 class TestElementwisePowOp_broadcast_2(TestElementwisePowOp):
     def setUp(self):
         self.op_type = "elementwise_pow"
         self.python_api = paddle.pow
+
         self.inputs = {
             'X': np.random.uniform(0.1, 1, [100, 3, 1]).astype("float64"),
             'Y': np.random.uniform(0.1, 1, [100]).astype("float64"),
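Editor's note: the test_check_grad_normal overrides added for the broadcast_1/2/3 cases keep the legacy check and deliberately omit check_prim=True. These tests set the legacy `axis` attribute (outside the hunks shown) to align Y with an inner dimension of X, and the composite maker enforces axis == -1 via the PADDLE_ENFORCE_EQ above, so the composite path cannot run for them. A NumPy equivalent of the axis = 1 alignment that broadcast_1 expects:

    import numpy as np

    # With attrs = {'axis': 1}, Y of shape [100] is matched against
    # dimension 1 of X ([2, 100, 1]) before the elementwise pow.
    x = np.random.uniform(0.1, 1, (2, 100, 1))
    y = np.random.uniform(0.1, 1, (100,))
    out = np.power(x, y.reshape(1, 100, 1))  # same effect as Y.reshape(100, 1)
    assert out.shape == (2, 100, 1)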
@@ -170,11 +203,18 @@ class TestElementwisePowOp_broadcast_2(TestElementwisePowOp):
             )
         }
 
+    def test_check_grad_normal(self):
+        if hasattr(self, 'attrs'):
+            self.check_grad(['X', 'Y'], 'Out', check_eager=False)
+        else:
+            self.check_grad(['X', 'Y'], 'Out', check_eager=True)
+
 
 class TestElementwisePowOp_broadcast_3(TestElementwisePowOp):
     def setUp(self):
         self.op_type = "elementwise_pow"
         self.python_api = paddle.pow
+
         self.inputs = {
             'X': np.random.uniform(0.1, 1, [2, 20, 5, 1]).astype("float64"),
             'Y': np.random.uniform(0.1, 1, [20, 5]).astype("float64"),
@@ -186,11 +226,19 @@ class TestElementwisePowOp_broadcast_3(TestElementwisePowOp):
             )
         }
 
+    def test_check_grad_normal(self):
+        if hasattr(self, 'attrs'):
+            self.check_grad(['X', 'Y'], 'Out', check_eager=False)
+        else:
+            self.check_grad(['X', 'Y'], 'Out', check_eager=True)
+
 
 class TestElementwisePowOp_broadcast_4(TestElementwisePowOp):
     def setUp(self):
         self.op_type = "elementwise_pow"
         self.python_api = paddle.pow
+        self.prim_op_type = "prim"
+
         self.inputs = {
             'X': np.random.uniform(0.1, 1, [2, 10, 3, 5]).astype("float64"),
             'Y': np.random.uniform(0.1, 1, [2, 10, 1, 5]).astype("float64"),
@@ -202,6 +250,7 @@ class TestElementwisePowOpInt(OpTest):
     def setUp(self):
         self.op_type = "elementwise_pow"
         self.python_api = paddle.pow
+
         self.inputs = {'X': np.asarray([1, 3, 6]), 'Y': np.asarray([1, 1, 1])}
         self.outputs = {'Out': np.power(self.inputs['X'], self.inputs['Y'])}
 
@@ -217,6 +266,7 @@ class TestElementwisePowGradOpInt(unittest.TestCase):
         self.x = np.asarray([1, 3, 6])
         self.y = np.asarray([1, 1, 1])
         self.res = self.x**self.y
+
         # dout = 1
         self.grad_res = np.asarray([1, 1, 1])
         # dx = dout * y * pow(x, y-1)
@@ -250,6 +300,8 @@ class TestElementwisePowOpFP16(OpTest):
     def setUp(self):
         self.op_type = "elementwise_pow"
         self.python_api = paddle.pow
+        self.prim_op_type = "prim"
+
         self.inputs = {
             'X': np.random.uniform(1, 2, [20, 5]).astype("float16"),
             'Y': np.random.uniform(1, 2, [20, 5]).astype("float16"),
@@ -270,6 +322,7 @@ class TestElementwisePowOpFP16(OpTest):
                 self.inputs['X'], self.inputs['Y'], 1 / self.inputs['X'].size
             ),
             check_eager=True,
+            check_prim=True,
         )
-- 
GitLab
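Editor's note: a quick end-to-end sanity check of the backward formulas in eager mode, to run against a PaddlePaddle 2.x build with this patch. It exercises the regular dygraph backward rather than the static composite rule, so it only validates the math; the values mirror TestElementwisePowGradOpInt with float inputs and dout = 1:

    import numpy as np
    import paddle

    x = paddle.to_tensor([1.0, 3.0, 6.0], stop_gradient=False)
    y = paddle.to_tensor([1.0, 1.0, 1.0], stop_gradient=False)
    out = paddle.pow(x, y)
    out.sum().backward()  # differentiating sum(out) makes dout all ones

    xv, yv = x.numpy(), y.numpy()
    np.testing.assert_allclose(x.grad.numpy(), yv * xv ** (yv - 1), rtol=1e-5)
    np.testing.assert_allclose(y.grad.numpy(), np.log(xv) * xv**yv, rtol=1e-5)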