diff --git a/paddle/fluid/operators/elementwise/elementwise_pow_op_npu.cc b/paddle/fluid/operators/elementwise/elementwise_pow_op_npu.cc
index e0763d769f047a963ea8e4905a9e79e1b583703a..85b247781a40da1220f422252db7ef0ce6446b31 100644
--- a/paddle/fluid/operators/elementwise/elementwise_pow_op_npu.cc
+++ b/paddle/fluid/operators/elementwise/elementwise_pow_op_npu.cc
@@ -15,6 +15,7 @@ limitations under the License. */
 
 #include <memory>
 #include <string>
 
+#include "paddle/fluid/operators/elementwise/elementwise_npu.h"
 #include "paddle/fluid/operators/elementwise/elementwise_pow_op.h"
 #include "paddle/fluid/operators/npu_op_runner.h"
@@ -27,21 +28,198 @@
 template <typename DeviceContext, typename T>
 class ElementwisePowNPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    auto& dev_ctx =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>();
+
     auto* x = ctx.Input<Tensor>("X");
     auto* y = ctx.Input<Tensor>("Y");
-
     auto* out = ctx.Output<Tensor>("Out");
     auto place = ctx.GetPlace();
 
+    int axis = ctx.Attr<int>("axis");
+
     out->mutable_data<T>(place);
 
-    auto stream =
-        ctx.template device_context<paddle::platform::NPUDeviceContext>()
-            .stream();
+    bool direct_compute = false;
+    auto x_dims = x->dims();
+    auto y_dims = y->dims();
+    axis =
+        (axis < 0 ? std::abs(x_dims.size() - y_dims.size()) + axis + 1 : axis);
+    if (x_dims.size() >= y_dims.size()) {
+      direct_compute =
+          y_dims == framework::slice_ddim(x_dims, axis, x_dims.size());
+    } else {
+      direct_compute =
+          x_dims == framework::slice_ddim(y_dims, axis, y_dims.size());
+    }
+
+    auto stream = dev_ctx.stream();
+
+    if (direct_compute) {
+      const auto& runner = NpuOpRunner("Pow", {*x, *y}, {*out}, {});
+      runner.Run(stream);
+    } else {
+      Tensor transformed_x, transformed_y;
+      NpuElementWiseOpBroadcast<T>(dev_ctx, x, y, axis, &transformed_x,
+                                   &transformed_y);
+      const auto& runner =
+          NpuOpRunner("Pow", {transformed_x, transformed_y}, {*out}, {});
+      runner.Run(stream);
+    }
+  }
+};
 
-    const auto& runner = NpuOpRunner("Pow", {*x, *y}, {*out}, {});
-    runner.Run(stream);
+template <typename DeviceContext, typename T>
+class ElementwisePowGradNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto& dev_ctx =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>();
+    auto* x = ctx.Input<Tensor>("X");
+    auto* y = ctx.Input<Tensor>("Y");
+    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
+    int axis = ctx.Attr<int>("axis");
+    auto place = ctx.GetPlace();
+
+    auto x_dims = x->dims();
+    auto y_dims = y->dims();
+    axis =
+        (axis < 0 ? std::abs(x_dims.size() - y_dims.size()) + axis + 1 : axis);
+    Tensor transformed_x, transformed_y;
+    NpuElementWiseOpBroadcast<T>(dev_ctx, x, y, axis, &transformed_x,
+                                 &transformed_y);
+
+    auto dout_dims = dout->dims();
+    auto stream = dev_ctx.stream();
+    // Reshape info vector.
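+    // Axes along which an input was broadcast up to dout's shape; the raw
+    // per-element gradient is summed back over these axes (ReduceSumD below)
+    // so that dx / dy end up with the original input shapes.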
+    std::vector<int> reduce_axes;
+    if (dx) {
+      Tensor zero_tensor(dout->type());
+      zero_tensor.mutable_data<T>(dout_dims, place);
+      FillNpuTensorWithConstant<T>(&zero_tensor, static_cast<T>(0));
+
+      dx->mutable_data<T>(place);
+      Tensor tmp_dx;
+      tmp_dx.mutable_data<T>(dout_dims, place);
+
+      // dx = dout * y * pow(x, y - 1);
+      Tensor PowGrad_dx_temp1(dout->type());
+      PowGrad_dx_temp1.mutable_data<T>(dout->dims(), place);
+      const auto& runner_PowGrad_dx_temp1 =
+          NpuOpRunner("Mul", {*dout, transformed_y}, {PowGrad_dx_temp1}, {});
+      runner_PowGrad_dx_temp1.Run(stream);
+
+      Tensor one_dx(transformed_y.type());
+      one_dx.mutable_data<T>(transformed_y.dims(), place);
+      const auto& runner_one_dx =
+          NpuOpRunner("OnesLike", {transformed_y}, {one_dx}, {});
+      runner_one_dx.Run(stream);
+
+      Tensor sub_dx(transformed_y.type());
+      sub_dx.mutable_data<T>(transformed_y.dims(), place);
+      const auto& runner_sub_dx =
+          NpuOpRunner("Sub", {transformed_y, one_dx}, {sub_dx}, {});
+      runner_sub_dx.Run(stream);
+
+      Tensor PowGrad_dx_temp2(transformed_x.type());
+      PowGrad_dx_temp2.mutable_data<T>(transformed_x.dims(), place);
+      const auto& runner_PowGrad_dx_temp2 =
+          NpuOpRunner("Pow", {transformed_x, sub_dx}, {PowGrad_dx_temp2}, {});
+      runner_PowGrad_dx_temp2.Run(stream);
+
+      const auto& runner_dx = NpuOpRunner(
+          "Mul", {PowGrad_dx_temp1, PowGrad_dx_temp2}, {tmp_dx}, {});
+      runner_dx.Run(stream);
+
+      if (x_dims != dout_dims) {
+        reduce_axes.clear();
+
+        int src_axis = (x_dims.size() < dout_dims.size() ? axis : 0);
+        for (int ax = 0; ax < dout_dims.size(); ++ax) {
+          if ((ax < src_axis || ax >= src_axis + x_dims.size()) ||
+              (dout_dims[ax] > 1 && x_dims[ax - src_axis] == 1)) {
+            reduce_axes.push_back(ax);
+          }
+        }
+        if (!reduce_axes.empty()) {
+          const auto& runner =
+              NpuOpRunner("ReduceSumD", {tmp_dx}, {*dx},
+                          {{"axes", reduce_axes}, {"keep_dims", false}});
+          runner.Run(stream);
+        }
+      } else {
+        framework::TensorCopy(tmp_dx, place, dev_ctx, dx);
+      }
+    }
+    if (dy) {
+      Tensor zero_tensor(dout->type());
+      zero_tensor.mutable_data<T>(dout_dims, place);
+      FillNpuTensorWithConstant<T>(&zero_tensor, static_cast<T>(0));
+
+      dy->mutable_data<T>(place);
+      Tensor tmp_dy;
+      tmp_dy.mutable_data<T>(dout_dims, place);
+
+      // dy = dout * log(x) * pow(x, y)
+      Tensor PowGrad_dy_temp1(transformed_x.type());
+      PowGrad_dy_temp1.mutable_data<T>(transformed_x.dims(), place);
+      const auto& runner_PowGrad_dy_temp1 = NpuOpRunner(
+          "Pow", {transformed_x, transformed_y}, {PowGrad_dy_temp1}, {});
+      runner_PowGrad_dy_temp1.Run(stream);
+
+      Tensor one_dy(transformed_x.type());
+      one_dy.mutable_data<T>(transformed_x.dims(), place);
+      const auto& runner_one_dy =
+          NpuOpRunner("OnesLike", {transformed_x}, {one_dy}, {});
+      runner_one_dy.Run(stream);
+
+      Tensor sub_dy(transformed_x.type());
+      sub_dy.mutable_data<T>(transformed_x.dims(), place);
+      const auto& runner_sub_dy =
+          NpuOpRunner("Sub", {transformed_x, one_dy}, {sub_dy}, {});
+      runner_sub_dy.Run(stream);
+
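+      // log(x) is evaluated as Log1p(x - 1), i.e. log(1 + (x - 1)).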
+      Tensor log_dy(transformed_x.type());
+      log_dy.mutable_data<T>(transformed_x.dims(), place);
+      const auto& runner_log_dy = NpuOpRunner("Log1p", {sub_dy}, {log_dy}, {});
+      runner_log_dy.Run(stream);
+
+      Tensor PowGrad_dy_temp2(transformed_x.type());
+      PowGrad_dy_temp2.mutable_data<T>(transformed_x.dims(), place);
+      const auto& runner_PowGrad_dy_temp2 = NpuOpRunner(
+          "Mul", {log_dy, PowGrad_dy_temp1}, {PowGrad_dy_temp2}, {});
+      runner_PowGrad_dy_temp2.Run(stream);
+
+      const auto& runner_dy =
+          NpuOpRunner("Mul", {*dout, PowGrad_dy_temp2}, {tmp_dy}, {});
+      runner_dy.Run(stream);
+
+      if (y_dims != dout_dims) {
+        reduce_axes.clear();
+
+        int src_axis = (y_dims.size() < dout_dims.size() ? axis : 0);
+        for (int ax = 0; ax < dout_dims.size(); ++ax) {
+          if ((ax < src_axis || ax >= src_axis + y_dims.size()) ||
+              (dout_dims[ax] > 1 && y_dims[ax - src_axis] == 1)) {
+            reduce_axes.push_back(ax);
+          }
+        }
+        if (!reduce_axes.empty()) {
+          const auto& runner =
+              NpuOpRunner("ReduceSumD", {tmp_dy}, {*dy},
+                          {{"axes", reduce_axes}, {"keep_dims", false}});
+          runner.Run(stream);
+        }
+      } else {
+        framework::TensorCopy(tmp_dy, place, dev_ctx, dy);
+      }
+    }
+    if (!dx && !dy) {
+      PADDLE_THROW(platform::errors::Unavailable(
+          "Not support all outputs to be empty."));
+    }
   }
 };
 
@@ -49,9 +227,18 @@ class ElementwisePowNPUKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
 
 REGISTER_OP_NPU_KERNEL(
     elementwise_pow,
-    ops::ElementwisePowNPUKernel<paddle::platform::NPUDeviceContext, float>,
-    ops::ElementwisePowNPUKernel<paddle::platform::NPUDeviceContext,
-                                 paddle::platform::float16>);
+    ops::ElementwisePowNPUKernel<plat::NPUDeviceContext, plat::float16>,
+    ops::ElementwisePowNPUKernel<plat::NPUDeviceContext, float>,
+    ops::ElementwisePowNPUKernel<plat::NPUDeviceContext, double>,
+    ops::ElementwisePowNPUKernel<plat::NPUDeviceContext, int>);
+
+REGISTER_OP_NPU_KERNEL(
+    elementwise_pow_grad,
+    ops::ElementwisePowGradNPUKernel<plat::NPUDeviceContext, plat::float16>,
+    ops::ElementwisePowGradNPUKernel<plat::NPUDeviceContext, float>,
+    ops::ElementwisePowGradNPUKernel<plat::NPUDeviceContext, double>,
+    ops::ElementwisePowGradNPUKernel<plat::NPUDeviceContext, int>);
diff --git a/python/paddle/fluid/tests/unittests/npu/test_elementwise_pow_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_elementwise_pow_op_npu.py
index dea1828a6d75fca3e6a871207e8a746305169a6c..ce645f317d054c264a730c150df42bccbfabbeee 100644
--- a/python/paddle/fluid/tests/unittests/npu/test_elementwise_pow_op_npu.py
+++ b/python/paddle/fluid/tests/unittests/npu/test_elementwise_pow_op_npu.py
@@ -13,19 +13,71 @@
 # limitations under the License.
 
 from __future__ import print_function
+import paddle.fluid as fluid
+import paddle
+from op_test import OpTest
 import numpy as np
 import unittest
 import sys
 sys.path.append("..")
-from op_test import OpTest
-import paddle
-import paddle.fluid as fluid
 
 paddle.enable_static()
 SEED = 2021
 
 
+def ComputeGrad(x, y, out, axis):
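+    # NumPy reference for the NPU kernel gradients; a uniform upstream
+    # gradient of 1 / out.size is assumed for every element of Out.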
+    grad = 1 / out.size
+    shape_x = x.shape
+    shape_y = y.shape
+    shape_out = out.shape
+    reduce_axes_x = []
+    reduce_axes_y = []
+
+    if shape_x != shape_out:
+        if len(shape_x) < len(shape_out):
+            src_axis = axis
+        else:
+            src_axis = 0
+
+        for ax in range(len(shape_out)):
+            if (ax < src_axis or ax >= src_axis + len(shape_x)) or (
+                    shape_out[ax] > 1 and shape_x[ax - src_axis] == 1):
+                reduce_axes_x.append(ax)
+
+    if shape_y != shape_out:
+        if len(shape_y) < len(shape_out):
+            src_axis = axis
+        else:
+            src_axis = 0
+
+        for ax in range(len(shape_out)):
+            if (ax < src_axis or ax >= src_axis + len(shape_y)) or (
+                    shape_out[ax] > 1 and shape_y[ax - src_axis] == 1):
+                reduce_axes_y.append(ax)
+
+    if len(reduce_axes_x) > 0:
+        for i in reduce_axes_x:
+            x = np.expand_dims(x, axis=i)
+
+    if len(reduce_axes_y) > 0:
+        for i in reduce_axes_y:
+            y = np.expand_dims(y, axis=i)
+
+    dx = y * np.power(x, y - 1) * grad
+    dy = np.log(x) * np.power(x, y) * grad
+
+    if len(reduce_axes_x) > 0:
+        for i, element in enumerate(reduce_axes_x):
+            dx = np.add.reduce(dx, element - i)
+
+    if len(reduce_axes_y) > 0:
+        for i, element in enumerate(reduce_axes_y):
+            dy = np.add.reduce(dy, element - i)
+
+    return dx, dy
+
+
 class TestElementwisePow(OpTest):
     def setUp(self):
         self.set_npu()
@@ -33,17 +85,15 @@ class TestElementwisePow(OpTest):
         self.place = paddle.NPUPlace(0)
         self.init_dtype()
 
-        np.random.seed(SEED)
-        x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
-        y = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
-        out = np.power(x, y)
+        self.init_input_output()
+        self.init_axis()
 
         self.inputs = {
-            'X': OpTest.np_dtype_to_fluid_dtype(x),
-            'Y': OpTest.np_dtype_to_fluid_dtype(y)
+            'X': OpTest.np_dtype_to_fluid_dtype(self.x),
+            'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
         }
-        self.attrs = {}
-        self.outputs = {'Out': out}
+        self.attrs = {'axis': self.axis}
+        self.outputs = {'Out': self.out}
 
     def set_npu(self):
         self.__class__.use_npu = True
@@ -54,44 +104,177 @@ class TestElementwisePow(OpTest):
     def test_check_output(self):
         self.check_output_with_place(self.place)
 
-    # TODO(ascendrc): Pow grad test
-    # def test_check_grad(self):
-    #     if self.dtype == np.float16:
-    #         return
-    #     self.check_grad(['X'], 'Out')
-    #
+    def init_axis(self):
+        self.axis = -1
 
+    def init_input_output(self):
+        np.random.seed(SEED)
+        self.x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
+        self.y = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
+        self.out = np.power(self.x, self.y)
+
+    def test_check_grad_normal(self):
+        if self.dtype == np.float16:
+            return
+        dx, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
+        self.check_grad_with_place(
+            self.place, ['X', 'Y'], 'Out', user_defined_grads=[dx, dy])
+
+    def test_check_grad_ingore_x(self):
+        _, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
+        self.check_grad_with_place(
+            self.place, ['Y'],
+            'Out',
+            no_grad_set=set("X"),
+            user_defined_grads=[dy])
+
+    def test_check_grad_ingore_y(self):
+        dx, _ = ComputeGrad(self.x, self.y, self.out, self.axis)
+        self.check_grad_with_place(
+            self.place, ['X'],
+            'Out',
+            no_grad_set=set("Y"),
+            user_defined_grads=[dx])
+
+
+class TestElementwisePowFp16(TestElementwisePow):
+    def init_input_output(self):
+        np.random.seed(SEED)
+        self.x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
+        self.y = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
+        self.out = np.power(self.x, self.y)
 
-class TestElementwisePowFp16(OpTest):
-    def setUp(self):
-        self.set_npu()
-        self.op_type = "elementwise_pow"
-        self.place = paddle.NPUPlace(0)
+    def set_npu(self):
+        self.__class__.use_npu = True
+        self.__class__.no_need_check_grad = True
 
-        self.init_dtype()
-        np.random.seed(SEED)
-        x = np.random.uniform(1, 2, [3, 4]).astype(self.dtype)
-        y = np.random.uniform(1, 2, [3, 4]).astype(self.dtype)
-        out = np.power(x, y)
+    def init_dtype(self):
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, atol=1e-5)
 
-        self.inputs = {
-            'X': OpTest.np_dtype_to_fluid_dtype(x),
-            'Y': OpTest.np_dtype_to_fluid_dtype(y)
-        }
-        self.attrs = {}
-        self.outputs = {'Out': out}
+
+class TestElementwisePowDouble(TestElementwisePow):
+    def init_input_output(self):
+        np.random.seed(SEED)
+        self.x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
+        self.y = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
+        self.out = np.power(self.x, self.y)
 
     def set_npu(self):
         self.__class__.use_npu = True
         self.__class__.no_need_check_grad = True
 
     def init_dtype(self):
-        self.dtype = np.float16
+        self.dtype = np.float64
 
     def test_check_output(self):
         self.check_output_with_place(self.place, atol=1e-5)
 
 
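+# The broadcast cases below feed inputs of different ranks, so the grad
+# kernel has to sum the broadcast axes back (the ReduceSumD path) before
+# writing dx and dy.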
+class TestElementwisePowOp_broadcast_0(TestElementwisePow):
+    def init_axis(self):
+        self.axis = 1
+
+    def init_input_output(self):
+        np.random.seed(SEED)
+        self.x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
+        self.y = np.random.uniform(1, 2, [1, 11, 17]).astype(self.dtype)
+        self.out = np.power(self.x, self.y)
+
+    def test_check_grad_normal(self):
+        if self.dtype == np.float16:
+            return
+        dx, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
+        self.check_grad_with_place(
+            self.place, ['X', 'Y'], 'Out', user_defined_grads=[dx, dy])
+
+    def test_check_grad_ingore_x(self):
+        _, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
+        self.check_grad_with_place(
+            self.place, ['Y'],
+            'Out',
+            no_grad_set=set("X"),
+            user_defined_grads=[dy])
+
+    def test_check_grad_ingore_y(self):
+        dx, _ = ComputeGrad(self.x, self.y, self.out, self.axis)
+        self.check_grad_with_place(
+            self.place, ['X'],
+            'Out',
+            no_grad_set=set("Y"),
+            user_defined_grads=[dx])
+
+
+class TestElementwisePowOp_broadcast_1(TestElementwisePow):
+    def init_axis(self):
+        self.axis = 1
+
+    def init_input_output(self):
+        np.random.seed(SEED)
+        self.x = np.random.uniform(1, 2, [2, 100, 1]).astype(self.dtype)
+        self.y = np.random.uniform(1, 2, [100]).astype(self.dtype)
+        self.out = np.power(self.x, self.y.reshape(1, 100, 1))
+
+    def test_check_grad_normal(self):
+        if self.dtype == np.float16:
+            return
+        dx, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
+        self.check_grad_with_place(
+            self.place, ['X', 'Y'], 'Out', user_defined_grads=[dx, dy])
+
+    def test_check_grad_ingore_x(self):
+        _, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
+        self.check_grad_with_place(
+            self.place, ['Y'],
+            'Out',
+            no_grad_set=set("X"),
+            user_defined_grads=[dy])
+
+    def test_check_grad_ingore_y(self):
+        dx, _ = ComputeGrad(self.x, self.y, self.out, self.axis)
+        self.check_grad_with_place(
+            self.place, ['X'],
+            'Out',
+            no_grad_set=set("Y"),
+            user_defined_grads=[dx])
+
+
+class TestElementwisePowOp_broadcast_2(TestElementwisePow):
+    def init_axis(self):
+        self.axis = 0
+
+    def init_input_output(self):
+        np.random.seed(SEED)
+        self.x = np.random.uniform(0.1, 1, [100, 3, 1]).astype(self.dtype)
+        self.y = np.random.uniform(0.1, 1, [100]).astype(self.dtype)
+        self.out = np.power(self.x, self.y.reshape(100, 1, 1))
+
+    def test_check_grad_normal(self):
+        if self.dtype == np.float16:
+            return
+        dx, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
+        self.check_grad_with_place(
+            self.place, ['X', 'Y'], 'Out', user_defined_grads=[dx, dy])
+
+    def test_check_grad_ingore_x(self):
+        _, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
+        self.check_grad_with_place(
+            self.place, ['Y'],
+            'Out',
+            no_grad_set=set("X"),
+            user_defined_grads=[dy])
+
+    def test_check_grad_ingore_y(self):
+        dx, _ = ComputeGrad(self.x, self.y, self.out, self.axis)
+        self.check_grad_with_place(
+            self.place, ['X'],
+            'Out',
+            no_grad_set=set("Y"),
+            user_defined_grads=[dx])
+
+
 class TestElementwisePowNet(unittest.TestCase):
     def _test(self, run_npu=True):
         main_prog = paddle.static.Program()
diff --git a/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py b/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py
index 15ba331e9de5a384be1a0527b35b49d3afa6d92d..29374a979650404341e39a341415aee64f657288 100644
--- a/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py
+++ b/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py
@@ -41,6 +41,7 @@ NEED_TO_FIX_OP_LIST = [
     'elementwise_min',
     'elementwise_mul',
     'elementwise_sub',
+    'elementwise_pow',
     'filter_by_instag',
     'fused_elemwise_activation',
     'fused_emb_seq_pool',