From 5310ceabf55a30f520fb1ad00848f47545e047eb Mon Sep 17 00:00:00 2001
From: baoachun <962571062@qq.com>
Date: Fri, 27 Aug 2021 11:27:16 +0800
Subject: [PATCH] add elementwise max grad op for npu (#34862)

* add elementwise max grad op for npu
* add elementwise max grad op for npu
* add elementwise max grad op for npu
* add elementwise max grad op for npu
* add elementwise max grad op for npu
---
 .../elementwise/elementwise_max_op_npu.cc     | 215 ++++++++++++++-
 .../npu/test_elementwise_max_op_npu.py        | 258 +++++++++++++++---
 2 files changed, 420 insertions(+), 53 deletions(-)

diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op_npu.cc b/paddle/fluid/operators/elementwise/elementwise_max_op_npu.cc
index a616d0bc9d..2929fb040c 100644
--- a/paddle/fluid/operators/elementwise/elementwise_max_op_npu.cc
+++ b/paddle/fluid/operators/elementwise/elementwise_max_op_npu.cc
@@ -12,10 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include <memory>
-#include <string>
-
 #include "paddle/fluid/operators/elementwise/elementwise_max_op.h"
+#include "paddle/fluid/operators/elementwise/elementwise_npu.h"
 #include "paddle/fluid/operators/npu_op_runner.h"
 
 namespace paddle {
@@ -27,21 +25,202 @@ template <typename DeviceContext, typename T>
 class ElementwiseMaxNPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    auto& dev_ctx =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>();
+
     auto* x = ctx.Input<Tensor>("X");
     auto* y = ctx.Input<Tensor>("Y");
-
     auto* out = ctx.Output<Tensor>("Out");
-
-    auto place = ctx.GetPlace();
-
-    out->mutable_data<T>(place);
+    out->mutable_data<T>(ctx.GetPlace());
+
+    int axis = ctx.Attr<int>("axis");
+
+    bool direct_compute = false;
+    auto x_dims = x->dims();
+    auto y_dims = y->dims();
+    axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
+    if (x_dims.size() >= y_dims.size()) {
+      direct_compute =
+          y_dims == framework::slice_ddim(x_dims, axis, x_dims.size());
+    } else {
+      direct_compute =
+          x_dims == framework::slice_ddim(y_dims, axis, y_dims.size());
+    }
 
     auto stream =
         ctx.template device_context<paddle::platform::NPUDeviceContext>()
             .stream();
 
-    const auto& runner = NpuOpRunner("Maximum", {*x, *y}, {*out}, {});
-    runner.Run(stream);
+    if (direct_compute) {
+      const auto& runner = NpuOpRunner("Maximum", {*x, *y}, {*out}, {});
+      runner.Run(stream);
+    } else {
+      Tensor transformed_x, transformed_y;
+      NpuElementWiseOpBroadcast<T>(dev_ctx, x, y, axis, &transformed_x,
+                                   &transformed_y);
+      const auto& runner =
+          NpuOpRunner("Maximum", {transformed_x, transformed_y}, {*out}, {});
+      runner.Run(stream);
+    }
+  }
+};
+
+template <typename DeviceContext, typename T>
+class ElementwiseMaxGradNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto& dev_ctx =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>();
+    auto* x = ctx.Input<Tensor>("X");
+    auto* y = ctx.Input<Tensor>("Y");
+    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
+    int axis = ctx.Attr<int>("axis");
+
+    // The Ascend elementwise_max_grad op only supports broadcast when
+    // axis is -1, and requires all inputs to have the same shape when
+    // axis is not -1. For convenience, first broadcast the original
+    // inputs x and y to transformed_x and transformed_y, then compute
+    // the op output into tmp tensors, and finally reduce the tmp
+    // tensors to match the shape of the paddle output.
+
+    auto x_dims = x->dims();
+    auto y_dims = y->dims();
+    axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
+    Tensor transformed_x, transformed_y;
+    NpuElementWiseOpBroadcast<T>(dev_ctx, x, y, axis, &transformed_x,
+                                 &transformed_y);
+
+    auto dout_dims = dout->dims();
+    auto stream = dev_ctx.stream();
+    framework::NPUAttributeMap attr_input = {{"grad_x", true},
+                                             {"grad_y", true}};
+    // Reshape info vector.
+    std::vector<int> reduce_axes;
+
+    if (dx && dy) {
+      dx->mutable_data<T>(ctx.GetPlace());
+      dy->mutable_data<T>(ctx.GetPlace());
+      Tensor tmp_dx;
+      tmp_dx.mutable_data<T>(dout_dims, ctx.GetPlace());
+      Tensor tmp_dy;
+      tmp_dy.mutable_data<T>(dout_dims, ctx.GetPlace());
+
+      const auto& runner =
+          NpuOpRunner("MaximumGrad", {*dout, transformed_x, transformed_y},
+                      {tmp_dx, tmp_dy}, attr_input);
+      runner.Run(stream);
+
+      if (x_dims != dout_dims) {
+        reduce_axes.clear();
+        int src_axis = (x_dims.size() < dout_dims.size() ? axis : 0);
+        for (int ax = 0; ax < dout_dims.size(); ++ax) {
+          if ((ax < src_axis || ax >= src_axis + x_dims.size()) ||
+              (dout_dims[ax] > 1 && x_dims[ax - src_axis] == 1)) {
+            reduce_axes.push_back(ax);
+          }
+        }
+        if (!reduce_axes.empty()) {
+          const auto& runner =
+              NpuOpRunner("ReduceSumD", {tmp_dx}, {*dx},
+                          {{"axes", reduce_axes}, {"keep_dims", false}});
+          runner.Run(stream);
+        }
+      } else {
+        framework::TensorCopy(tmp_dx, ctx.GetPlace(), dev_ctx, dx);
+      }
+
+      if (y_dims != dout_dims) {
+        reduce_axes.clear();
+        int src_axis = (y_dims.size() < dout_dims.size() ? axis : 0);
+        for (int ax = 0; ax < dout_dims.size(); ++ax) {
+          if ((ax < src_axis || ax >= src_axis + y_dims.size()) ||
+              (dout_dims[ax] > 1 && y_dims[ax - src_axis] == 1)) {
+            reduce_axes.push_back(ax);
+          }
+        }
+        if (!reduce_axes.empty()) {
+          const auto& runner =
+              NpuOpRunner("ReduceSumD", {tmp_dy}, {*dy},
+                          {{"axes", reduce_axes}, {"keep_dims", false}});
+          runner.Run(stream);
+        }
+      } else {
+        framework::TensorCopy(tmp_dy, ctx.GetPlace(), dev_ctx, dy);
+      }
+
+    } else if (dx) {
+      Tensor zero_tensor(dout->type());
+      zero_tensor.mutable_data<T>(dout_dims, ctx.GetPlace());
+      FillNpuTensorWithConstant<T>(&zero_tensor, static_cast<T>(0));
+
+      dx->mutable_data<T>(ctx.GetPlace());
+      Tensor tmp_dx;
+      tmp_dx.mutable_data<T>(dout_dims, ctx.GetPlace());
+
+      const auto& runner =
+          NpuOpRunner("MaximumGrad", {*dout, transformed_x, transformed_y},
+                      {tmp_dx, zero_tensor}, attr_input);
+      runner.Run(stream);
+
+      if (x_dims != dout_dims) {
+        reduce_axes.clear();
+
+        int src_axis = (x_dims.size() < dout_dims.size() ? axis : 0);
+        for (int ax = 0; ax < dout_dims.size(); ++ax) {
+          if ((ax < src_axis || ax >= src_axis + x_dims.size()) ||
+              (dout_dims[ax] > 1 && x_dims[ax - src_axis] == 1)) {
+            reduce_axes.push_back(ax);
+          }
+        }
+        if (!reduce_axes.empty()) {
+          const auto& runner =
+              NpuOpRunner("ReduceSumD", {tmp_dx}, {*dx},
+                          {{"axes", reduce_axes}, {"keep_dims", false}});
+          runner.Run(stream);
+        }
+      } else {
+        framework::TensorCopy(tmp_dx, ctx.GetPlace(), dev_ctx, dx);
+      }
+
+    } else if (dy) {
+      Tensor zero_tensor(dout->type());
+      zero_tensor.mutable_data<T>(dout_dims, ctx.GetPlace());
+      FillNpuTensorWithConstant<T>(&zero_tensor, static_cast<T>(0));
+
+      dy->mutable_data<T>(ctx.GetPlace());
+      Tensor tmp_dy;
+      tmp_dy.mutable_data<T>(dout_dims, ctx.GetPlace());
+
+      const auto& runner =
+          NpuOpRunner("MaximumGrad", {*dout, transformed_x, transformed_y},
+                      {zero_tensor, tmp_dy}, attr_input);
+      runner.Run(stream);
+
+      if (y_dims != dout_dims) {
+        reduce_axes.clear();
+
+        int src_axis = (y_dims.size() < dout_dims.size() ? axis : 0);
+        for (int ax = 0; ax < dout_dims.size(); ++ax) {
+          if ((ax < src_axis || ax >= src_axis + y_dims.size()) ||
+              (dout_dims[ax] > 1 && y_dims[ax - src_axis] == 1)) {
+            reduce_axes.push_back(ax);
+          }
+        }
+        if (!reduce_axes.empty()) {
+          const auto& runner =
+              NpuOpRunner("ReduceSumD", {tmp_dy}, {*dy},
+                          {{"axes", reduce_axes}, {"keep_dims", false}});
+          runner.Run(stream);
+        }
+      } else {
+        framework::TensorCopy(tmp_dy, ctx.GetPlace(), dev_ctx, dy);
+      }
+    } else {
+      PADDLE_THROW(platform::errors::Unavailable(
+          "Do not support all outputs to be empty."));
+    }
   }
 };
 
@@ -49,9 +228,19 @@ class ElementwiseMaxNPUKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
 
 REGISTER_OP_NPU_KERNEL(
     elementwise_max,
-    ops::ElementwiseMaxNPUKernel<paddle::platform::NPUDeviceContext, float>,
-    ops::ElementwiseMaxNPUKernel<paddle::platform::NPUDeviceContext, paddle::platform::float16>);
+    ops::ElementwiseMaxNPUKernel<plat::NPUDeviceContext, float>,
+    ops::ElementwiseMaxNPUKernel<plat::NPUDeviceContext, plat::float16>,
+    ops::ElementwiseMaxNPUKernel<plat::NPUDeviceContext, double>,
+    ops::ElementwiseMaxNPUKernel<plat::NPUDeviceContext, int>,
+    ops::ElementwiseMaxNPUKernel<plat::NPUDeviceContext, int64_t>);
+
+REGISTER_OP_NPU_KERNEL(
+    elementwise_max_grad,
+    ops::ElementwiseMaxGradNPUKernel<plat::NPUDeviceContext, float>,
+    ops::ElementwiseMaxGradNPUKernel<plat::NPUDeviceContext, plat::float16>,
+    ops::ElementwiseMaxGradNPUKernel<plat::NPUDeviceContext, double>,
+    ops::ElementwiseMaxGradNPUKernel<plat::NPUDeviceContext, int>);
diff --git a/python/paddle/fluid/tests/unittests/npu/test_elementwise_max_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_elementwise_max_op_npu.py
index 6c325b0202..461e15352e 100644
--- a/python/paddle/fluid/tests/unittests/npu/test_elementwise_max_op_npu.py
+++ b/python/paddle/fluid/tests/unittests/npu/test_elementwise_max_op_npu.py
@@ -26,70 +26,248 @@ paddle.enable_static()
 SEED = 2021
 
 
-class TestElementwiseMax(OpTest):
-    def setUp(self):
-        self.set_npu()
-        self.op_type = "elementwise_max"
-        self.place = paddle.NPUPlace(0)
+def ComputeGrad(x, y, out, axis):
+    grad = 1 / out.size
+    shape_x = x.shape
+    shape_y = y.shape
+    shape_out = out.shape
+    reduce_axes_x = []
+    reduce_axes_y = []
+
+    if shape_x != shape_out:
+        if len(shape_x) < len(shape_out):
+            src_axis = axis
+        else:
+            src_axis = 0
 
-        self.init_dtype()
-        np.random.seed(SEED)
-        x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
-        y = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
-        out = np.maximum(x, y)
+        for ax in range(len(shape_out)):
+            if (ax < src_axis or ax >= src_axis + len(shape_x)) or (
+                    shape_out[ax] > 1 and shape_x[ax - src_axis] == 1):
+                reduce_axes_x.append(ax)
 
-        self.inputs = {
-            'X': OpTest.np_dtype_to_fluid_dtype(x),
-            'Y': OpTest.np_dtype_to_fluid_dtype(y)
-        }
-        self.attrs = {}
-        self.outputs = {'Out': out}
+    if shape_y != shape_out:
+        if len(shape_y) < len(shape_out):
+            src_axis = axis
+        else:
+            src_axis = 0
 
-    def set_npu(self):
-        self.__class__.use_npu = True
+        for ax in range(len(shape_out)):
+            if (ax < src_axis or ax >= src_axis + len(shape_y)) or (
+                    shape_out[ax] > 1 and shape_y[ax - src_axis] == 1):
+                reduce_axes_y.append(ax)
 
-    def init_dtype(self):
-        self.dtype = np.float32
+    if len(reduce_axes_x) > 0:
+        for i in reduce_axes_x:
+            x = np.expand_dims(x, axis=i)
 
-    def test_check_output(self):
-        self.check_output_with_place(self.place)
+    if len(reduce_axes_y) > 0:
+        for i in reduce_axes_y:
+            y = np.expand_dims(y, axis=i)
+
+    mask = np.sign(np.subtract(x, y))
+    dx = np.maximum(mask, 0) * grad
+    dy = np.abs(np.minimum(mask, 0) * grad)
+
+    if len(reduce_axes_x) > 0:
+        for i, element in enumerate(reduce_axes_x):
+            dx = np.add.reduce(dx, element - i)
+
+    if len(reduce_axes_y) > 0:
+        for i, element in enumerate(reduce_axes_y):
+            dy = np.add.reduce(dy, element - i)
 
-    # TODO(ascendrc): Max grad test
-    # def test_check_grad(self):
-    #     if self.dtype == np.float16:
-    #         return
-    #     self.check_grad(['X'], 'Out')
-    #
+    return dx, dy
 
 
-class TestElementwiseMaxFp16(OpTest):
+class TestElementwiseMaxOp(OpTest):
     def setUp(self):
         self.set_npu()
         self.op_type = "elementwise_max"
         self.place = paddle.NPUPlace(0)
 
         self.init_dtype()
-        np.random.seed(SEED)
-        x = np.random.uniform(1, 2, [3, 4]).astype(self.dtype)
-        y = np.random.uniform(1, 2, [3, 4]).astype(self.dtype)
-        out = np.maximum(x, y)
+        self.init_input_output()
+        self.init_axis()
 
         self.inputs = {
-            'X': OpTest.np_dtype_to_fluid_dtype(x),
-            'Y': OpTest.np_dtype_to_fluid_dtype(y)
+            'X': OpTest.np_dtype_to_fluid_dtype(self.x),
+            'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
         }
-        self.attrs = {}
-        self.outputs = {'Out': out}
+        self.attrs = {'axis': self.axis}
+        self.outputs = {'Out': self.out}
 
     def set_npu(self):
         self.__class__.use_npu = True
-        self.__class__.no_need_check_grad = True
 
     def init_dtype(self):
-        self.dtype = np.float16
+        self.dtype = np.float32
+
+    def init_input_output(self):
+        self.x = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype)
+        sgn = np.random.choice([-1, 1], [13, 17]).astype(self.dtype)
+        self.y = self.x + sgn * np.random.uniform(0.1, 1,
+                                                  [13, 17]).astype(self.dtype)
+        self.out = np.maximum(self.x, self.y)
+
+    def init_axis(self):
+        self.axis = -1
 
     def test_check_output(self):
-        self.check_output_with_place(self.place, atol=1e-5)
+        self.check_output_with_place(self.place)
+
+    def test_check_grad_normal(self):
+        if self.dtype == np.float16:
+            return
+        self.check_grad_with_place(self.place, ['X', 'Y'], 'Out')
+
+    def test_check_grad_ingore_x(self):
+        if self.dtype == np.float16:
+            return
+        self.check_grad_with_place(
+            self.place, ['Y'], 'Out', no_grad_set=set("X"))
+
+    def test_check_grad_ingore_y(self):
+        if self.dtype == np.float16:
+            return
+        self.check_grad_with_place(
+            self.place, ['X'], 'Out', no_grad_set=set("Y"))
+
+
+class TestElementwiseMaxOp_int32(TestElementwiseMaxOp):
+    def init_dtype(self):
+        self.dtype = np.int32
+
+    # CTest does not support check grad for int32.
+    def test_check_grad_normal(self):
+        pass
+
+    def test_check_grad_ingore_x(self):
+        pass
+
+    def test_check_grad_ingore_y(self):
+        pass
+
+
+class TestElementwiseMaxOp_scalar(TestElementwiseMaxOp):
+    def init_input_output(self):
+        self.x = np.random.random_integers(-5, 5, [2, 3, 20]).astype(self.dtype)
+        self.y = np.array([0.5]).astype(self.dtype)
+        self.out = np.maximum(self.x, self.y)
+
+
+class TestElementwiseMaxOp_vector(TestElementwiseMaxOp):
+    def init_input_output(self):
+        self.x = np.random.random((100, )).astype(self.dtype)
+        sgn = np.random.choice([-1, 1], (100, )).astype(self.dtype)
+        self.y = self.x + sgn * np.random.uniform(0.1, 1,
+                                                  (100, )).astype(self.dtype)
+        self.out = np.maximum(self.x, self.y)
+
+
+class TestElementwiseMaxOp_broadcast_0(TestElementwiseMaxOp):
+    def init_input_output(self):
+        self.x = np.random.uniform(0.5, 1, (100, 5, 2)).astype(self.dtype)
+        sgn = np.random.choice([-1, 1], (100, )).astype(self.dtype)
+        self.y = self.x[:, 0, 0] + sgn * \
+            np.random.uniform(1, 2, (100, )).astype(self.dtype)
+        self.out = np.maximum(self.x, self.y.reshape(100, 1, 1))
+
+    def init_axis(self):
+        self.axis = 0
+
+
+class TestElementwiseMaxOp_broadcast_1(TestElementwiseMaxOp):
+    def init_input_output(self):
+        self.x = np.random.uniform(0.5, 1, (2, 100, 3)).astype(self.dtype)
+        sgn = np.random.choice([-1, 1], (100, )).astype(self.dtype)
+        self.y = self.x[0, :, 0] + sgn * \
+            np.random.uniform(1, 2, (100, )).astype(self.dtype)
+        self.out = np.maximum(self.x, self.y.reshape(1, 100, 1))
+
+    def init_axis(self):
+        self.axis = 1
+
+    def test_check_grad_ingore_x(self):
+        _, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
+        self.check_grad_with_place(
+            self.place, ['Y'],
+            'Out',
+            no_grad_set=set("X"),
+            user_defined_grads=[dy])
+
+    def test_check_grad_ingore_y(self):
+        dx, _ = ComputeGrad(self.x, self.y, self.out, self.axis)
+        self.check_grad_with_place(
+            self.place, ['X'],
+            'Out',
+            no_grad_set=set("Y"),
+            user_defined_grads=[dx])
+
+
+class TestElementwiseMaxOp_broadcast_2(TestElementwiseMaxOp):
+    def init_input_output(self):
+        self.x = np.random.uniform(0.5, 1, (2, 3, 100)).astype(self.dtype)
+        sgn = np.random.choice([-1, 1], (100, )).astype(self.dtype)
+        self.y = self.x[0, 0, :] + sgn * \
+            np.random.uniform(1, 2, (100, )).astype(self.dtype)
+        self.out = np.maximum(self.x, self.y.reshape(1, 1, 100))
+
+    def test_check_grad_normal(self):
+        if self.dtype == np.float16:
+            return
+        dx, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
+        self.check_grad_with_place(
+            self.place, ['X', 'Y'], 'Out', user_defined_grads=[dx, dy])
+
+    def test_check_grad_ingore_x(self):
+        if self.dtype == np.float16:
+            return
+        _, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
+        self.check_grad_with_place(
+            self.place, ['Y'],
+            'Out',
+            no_grad_set=set("X"),
+            user_defined_grads=[dy])
+
+    def test_check_grad_ingore_y(self):
+        if self.dtype == np.float16:
+            return
+        dx, _ = ComputeGrad(self.x, self.y, self.out, self.axis)
+        self.check_grad_with_place(
+            self.place, ['X'],
+            'Out',
+            no_grad_set=set("Y"),
+            user_defined_grads=[dx])
+
+
+class TestElementwiseMaxOp_broadcast_3(TestElementwiseMaxOp):
+    def init_input_output(self):
+        self.x = np.random.uniform(0.5, 1, (2, 50, 2, 1)).astype(self.dtype)
+        sgn = np.random.choice([-1, 1], (50, 2)).astype(self.dtype)
+        self.y = self.x[0, :, :, 0] + sgn * \
+            np.random.uniform(1, 2, (50, 2)).astype(self.dtype)
+        self.out = np.maximum(self.x, self.y.reshape(1, 50, 2, 1))
+
+    def init_axis(self):
+        self.axis = 1
+
+
+class TestElementwiseMaxOp_broadcast_4(TestElementwiseMaxOp):
+    def init_input_output(self):
+        self.x = np.random.uniform(0.5, 1, (2, 3, 4, 5)).astype(self.dtype)
+        sgn = np.random.choice([-1, 1], (2, 3, 1, 5)).astype(self.dtype)
+        self.y = self.x + sgn * \
+            np.random.uniform(1, 2, (2, 3, 1, 5)).astype(self.dtype)
+        self.out = np.maximum(self.x, self.y)
+
+
+class TestElementwiseMaxOp_broadcast_5(TestElementwiseMaxOp):
+    def init_input_output(self):
+        self.x = np.random.uniform(0.5, 1, (2, 3, 4, 5)).astype(self.dtype)
+        sgn = np.random.choice([-1, 1], (2, 3, 1, 1)).astype(self.dtype)
+        self.y = self.x + sgn * \
+            np.random.uniform(1, 2, (2, 3, 1, 1)).astype(self.dtype)
+        self.out = np.maximum(self.x, self.y)
 
 
 class TestElementwiseMaxNet(unittest.TestCase):
--
GitLab
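
Appendix: a minimal NumPy sketch (not part of the patch) of the broadcast-then-reduce strategy described in the comment of ElementwiseMaxGradNPUKernel. Gradients are first produced at the broadcasted dout shape (the MaximumGrad step), then summed over the broadcast axes (the ReduceSumD step) to recover the original input shapes. The shapes, names, and the tie-breaking rule below are illustrative assumptions, not taken from the kernel or the tests.

    import numpy as np

    # Assumed shapes: x is full rank, y is broadcast along axis 1
    # (similar in spirit to the broadcast_1 test case).
    x = np.random.rand(2, 100, 3).astype(np.float32)
    y = np.random.rand(100).astype(np.float32)
    dout = np.ones((2, 100, 3), dtype=np.float32)  # upstream gradient

    # Forward broadcast: y participates with shape (1, 100, 1).
    y_b = y.reshape(1, 100, 1)
    mask = x >= y_b  # positions where x wins the max (ties go to x here)

    # Gradients at the broadcasted (dout) shape, analogous to tmp_dx / tmp_dy.
    tmp_dx = dout * mask
    tmp_dy = dout * ~mask

    # Reduce tmp_dy over the axes that were broadcast (0 and 2 here),
    # mirroring the ReduceSumD call in the NPU grad kernel.
    dy = tmp_dy.sum(axis=(0, 2))

    assert tmp_dx.shape == x.shape
    assert dy.shape == y.shape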