From 7bddf2e88fe1ee64cf695b4198cc398504cf90b5 Mon Sep 17 00:00:00 2001 From: Aganlengzi Date: Wed, 29 Sep 2021 14:42:51 +0800 Subject: [PATCH] [NPU] mod for model bert (#36165) * merge conflict of paddle_gtest_main.cc * modify FLAGS_npu_precision_mode and default not to call aclSetCompileopt --- .../elementwise/elementwise_sub_op_npu.cc | 4 +- .../fluid/operators/fill_any_like_op_npu.cc | 12 +- paddle/fluid/operators/npu_op_runner.cc | 8 + paddle/fluid/operators/slice_op_npu.cc | 27 ++- paddle/fluid/platform/flags.cc | 7 + .../npu/test_elementwise_sub_op_npu.py | 5 + .../npu/test_fill_any_like_op_npu.py | 6 + .../tests/unittests/npu/test_slice_op_npu.py | 226 ++++++++++++++++++ 8 files changed, 290 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op_npu.cc b/paddle/fluid/operators/elementwise/elementwise_sub_op_npu.cc index 94e78defbb..48b98dafc7 100644 --- a/paddle/fluid/operators/elementwise/elementwise_sub_op_npu.cc +++ b/paddle/fluid/operators/elementwise/elementwise_sub_op_npu.cc @@ -166,9 +166,11 @@ class ElementwiseSubGradNPUKernel : public framework::OpKernel { namespace ops = paddle::operators; namespace plat = paddle::platform; -REGISTER_OP_NPU_KERNEL(elementwise_sub, ops::ElementwiseSubNPUKernel, +REGISTER_OP_NPU_KERNEL(elementwise_sub, ops::ElementwiseSubNPUKernel, + ops::ElementwiseSubNPUKernel, ops::ElementwiseSubNPUKernel); REGISTER_OP_NPU_KERNEL(elementwise_sub_grad, + ops::ElementwiseSubGradNPUKernel, ops::ElementwiseSubGradNPUKernel, ops::ElementwiseSubGradNPUKernel); diff --git a/paddle/fluid/operators/fill_any_like_op_npu.cc b/paddle/fluid/operators/fill_any_like_op_npu.cc index d5204f5cac..566b265bfd 100644 --- a/paddle/fluid/operators/fill_any_like_op_npu.cc +++ b/paddle/fluid/operators/fill_any_like_op_npu.cc @@ -63,9 +63,12 @@ class FillAnyLikeNPUKernel : public framework::OpKernel { .stream(); auto shape = out->dims(); - const auto& runner = NpuOpRunner("FillD", {tensor_tmp}, {*out}, - {{"dims", framework::vectorize(shape)}}); - runner.Run(stream); + NpuOpRunner runner; + runner.SetType("Fill") + .AddInput(framework::vectorize(shape)) + .AddInput(tensor_tmp) + .AddOutput(*out) + .Run(stream); } }; @@ -75,5 +78,8 @@ class FillAnyLikeNPUKernel : public framework::OpKernel { namespace ops = paddle::operators; REGISTER_OP_NPU_KERNEL(fill_any_like, ops::FillAnyLikeNPUKernel, +#ifdef PADDLE_WITH_ASCEND_INT64 + ops::FillAnyLikeNPUKernel, +#endif ops::FillAnyLikeNPUKernel, ops::FillAnyLikeNPUKernel); diff --git a/paddle/fluid/operators/npu_op_runner.cc b/paddle/fluid/operators/npu_op_runner.cc index bb6549c111..d10e94962d 100644 --- a/paddle/fluid/operators/npu_op_runner.cc +++ b/paddle/fluid/operators/npu_op_runner.cc @@ -26,6 +26,8 @@ limitations under the License. */ #include "paddle/fluid/framework/framework.pb.h" +DECLARE_string(npu_precision_mode); + namespace paddle { namespace operators { @@ -404,6 +406,12 @@ void NpuOpRunner::Run(aclrtStream stream) const { VLOG(4) << "attr: " << attr_; VLOG(4) << "stream: " << stream; + if (!FLAGS_npu_precision_mode.empty()) { + PADDLE_ENFORCE_NPU_SUCCESS( + aclSetCompileopt(ACL_PRECISION_MODE, FLAGS_npu_precision_mode.c_str())); + VLOG(4) << "set ACL_PRECISION_MODE: " << FLAGS_npu_precision_mode; + } + aclError ret = aclopCompileAndExecute( op_type_.c_str(), input_descs_.size(), input_descs_.data(), input_buffers_.data(), output_descs_.size(), output_descs_.data(), diff --git a/paddle/fluid/operators/slice_op_npu.cc b/paddle/fluid/operators/slice_op_npu.cc index 1084eadc55..f8bf46da4a 100644 --- a/paddle/fluid/operators/slice_op_npu.cc +++ b/paddle/fluid/operators/slice_op_npu.cc @@ -181,12 +181,37 @@ class SliceGradNPUKernel : public framework::OpKernel { paddings[i][1] = static_cast(in_dims[i] - size[i] - offsets[i]); } + Tensor tmp_dout; + tmp_dout.ShareDataWith(*dout); + auto out_dims = dout->dims(); + auto decrease_axis = ctx.Attr>("decrease_axis"); + auto decrease_size = decrease_axis.size(); + if (decrease_size > 0) { + if (decrease_size == static_cast(in_dims.size())) { + out_dims = framework::make_ddim(std::vector(decrease_size, 1)); + } else { + std::vector origin_out_shape(out_dims.size() + decrease_size, -1); + for (size_t i = 0; i < decrease_size; ++i) { + origin_out_shape[decrease_axis[i]] = 1; + } + int index = 0; + for (size_t i = 0; i < origin_out_shape.size(); ++i) { + if (origin_out_shape[i] == -1) { + origin_out_shape[i] = out_dims[index]; + ++index; + } + } + out_dims = framework::make_ddim(origin_out_shape); + } + tmp_dout.Resize(out_dims); + } + dinput->mutable_data(ctx.GetPlace()); auto stream = ctx.template device_context() .stream(); const auto& runner = - NpuOpRunner("PadD", {*dout}, {*dinput}, {{"paddings", paddings}}); + NpuOpRunner("PadD", {tmp_dout}, {*dinput}, {{"paddings", paddings}}); runner.Run(stream); } }; diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc index b97c310643..89a829f949 100644 --- a/paddle/fluid/platform/flags.cc +++ b/paddle/fluid/platform/flags.cc @@ -121,6 +121,13 @@ PADDLE_DEFINE_EXPORTED_string( "If proveided, it will be passed to aclInit()."); PADDLE_DEFINE_EXPORTED_int32(min_loss_scaling, 1, "set minmum loss scaling value!"); +PADDLE_DEFINE_EXPORTED_string( + npu_precision_mode, "", + "NPU operator precision mode, options are 'force_fp32', 'force_fp16', " + "'allow_fp32_to_fp16', 'must_keep_origin_dtype' and " + "'allow_mix_precision'. If you want to use the default mode (" + "allow_fp32_to_fp16), set this to empty string. For more details, " + "please refer to the documents"); #endif #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) diff --git a/python/paddle/fluid/tests/unittests/npu/test_elementwise_sub_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_elementwise_sub_op_npu.py index 6faa77b460..7c8710fd42 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_elementwise_sub_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_elementwise_sub_op_npu.py @@ -90,6 +90,11 @@ class TestElementwiseSubOp(OpTest): # max_relative_error=0.006,) +class TestElementwiseSubOpInt32(TestElementwiseSubOp): + def init_dtype(self): + self.dtype = np.int32 + + class TestSubtractAPI(unittest.TestCase): def test_name(self): with paddle.static.program_guard(paddle.static.Program()): diff --git a/python/paddle/fluid/tests/unittests/npu/test_fill_any_like_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_fill_any_like_op_npu.py index a687509e6a..c3074db1aa 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_fill_any_like_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_fill_any_like_op_npu.py @@ -57,6 +57,12 @@ class TestFillAnyLikeNPUOpInt32(TestFillAnyLikeNPUOp): self.value = -1 +class TestFillAnyLikeNPUOpInt64(TestFillAnyLikeNPUOp): + def init(self): + self.dtype = np.int64 + self.value = -1 + + class TestFillAnyLikeNPUOpFloat32(TestFillAnyLikeNPUOp): def init(self): self.dtype = np.float32 diff --git a/python/paddle/fluid/tests/unittests/npu/test_slice_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_slice_op_npu.py index 5a38f14868..055c3015f8 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_slice_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_slice_op_npu.py @@ -301,5 +301,231 @@ class TestSliceNet(unittest.TestCase): self.assertTrue(np.allclose(npu_loss, cpu_loss)) +class TestSliceOpDecsDim(OpTest): + def setUp(self): + self.op_type = "slice" + self.set_npu() + self.init_dtype() + self.config() + self.set_inputs() + self.set_outputs() + self.set_attrs() + + def set_inputs(self): + self.inputs = {'Input': self.input} + + def set_outputs(self): + self.outputs = {'Out': self.out} + + def set_attrs(self): + self.attrs = { + 'axes': self.axes, + 'starts': self.starts, + 'ends': self.ends, + 'infer_flags': self.infer_flags, + 'decrease_axis': self.decrease_axis, + } + + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype(self.dtype) + self.starts = [1, 0, 2] + self.ends = [2, 3, 4] + self.axes = [0, 1, 2] + self.decrease_axis = [0] + self.infer_flags = [1, 1, 1] + self.out = self.input[1, 0:3, 2:4, :] + + def init_dtype(self): + self.dtype = np.float32 + + def set_npu(self): + self.__class__.use_npu = True + self.place = paddle.NPUPlace(0) + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad_normal(self): + if self.dtype == np.float16: + return + self.check_grad_with_place(self.place, ['Input'], 'Out') + + +class TestSliceOpDecsDimFp16(TestSliceOpDecsDim): + def init_dtype(self): + self.dtype = np.float16 + + +class TestSliceOpDecsDim2(TestSliceOpDecsDim): + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype(self.dtype) + self.starts = [1, 0, 2] + self.ends = [2, 1, 4] + self.axes = [0, 1, 2] + self.decrease_axis = [0, 1] + self.infer_flags = [1, 1, 1] + self.out = self.input[1, 0, 2:4, :] + + +class TestSliceOpDecsDim3(TestSliceOpDecsDim): + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype(self.dtype) + self.starts = [-1, 0, 2] + self.ends = [1000000, 1, 4] + self.axes = [0, 1, 2] + self.decrease_axis = [0, 1] + self.infer_flags = [1, 1, 1] + self.out = self.input[-1, 0, 2:4, :] + + +class TestSliceOpDecsDim4(TestSliceOpDecsDim): + def config(self): + self.input = np.random.random([3, 4, 5, 7]).astype(self.dtype) + self.starts = [0, 1, 2, 3] + self.ends = [1, 2, 3, 4] + self.axes = [0, 1, 2, 3] + self.decrease_axis = [0, 1, 2, 3] + self.infer_flags = [1, 1, 1] + self.out = self.input[0, 1, 2, 3:4] + + +class TestSliceOpDecsDim5(TestSliceOpDecsDim): + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype(self.dtype) + self.starts = [-1] + self.ends = [1000000] + self.axes = [3] + self.decrease_axis = [3] + self.infer_flags = [1, 1, 1] + self.out = self.input[:, :, :, -1] + + +class TestSliceOpDecsDim6(TestSliceOpDecsDim): + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype(self.dtype) + self.starts = [0, 1, 2, 3] + self.ends = [1, 2, 3, 4] + self.axes = [0, 1, 2, 3] + self.decrease_axis = [0, 1, 2, 3] + self.infer_flags = [1, 1, 1] + self.out = self.input[0, 1, 2, 3:4] + + +class TestSliceOpDecsDimStartsTensor(TestSliceOpDecsDim): + def set_inputs(self): + self.inputs = { + 'Input': self.input, + "StartsTensor": np.array( + self.starts, dtype='int32') + } + + def set_attrs(self): + self.attrs = { + 'axes': self.axes, + #'starts': self.starts, + 'ends': self.ends, + 'infer_flags': self.infer_flags, + 'decrease_axis': self.decrease_axis, + } + + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype(self.dtype) + self.starts = [1, 0, 2] + self.ends = [2, 3, 4] + self.axes = [0, 1, 2] + self.decrease_axis = [0] + self.infer_flags = [-1, -1, -1] + self.out = self.input[1, 0:3, 2:4, :] + + +class TestSliceOpDecsDimStartsTensorFP16(TestSliceOpDecsDimStartsTensor): + def init_dtype(self): + self.dtype = np.float16 + + +class TestSliceOpDecsDimStartsTensorStartsAndEndsTensor(TestSliceOpDecsDim): + def set_inputs(self): + self.inputs = { + 'Input': self.input, + "StartsTensor": np.array( + self.starts, dtype='int64'), + "EndsTensor": np.array( + self.ends, dtype='int32') + } + + def set_attrs(self): + self.attrs = { + 'axes': self.axes, + #'starts': self.starts, + #'ends': self.ends, + 'infer_flags': self.infer_flags, + 'decrease_axis': self.decrease_axis, + } + + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype(self.dtype) + self.starts = [1, 0, 2] + self.ends = [2, 1, 4] + self.axes = [0, 1, 2] + self.decrease_axis = [0, 1] + self.infer_flags = [-1, -1, -1] + self.out = self.input[1, 0, 2:4, :] + + +class TestSliceOpDecsDimStartsTensorStartsAndEndsTensorFP16( + TestSliceOpDecsDimStartsTensorStartsAndEndsTensor): + def init_dtype(self): + self.dtype = np.float16 + + +class TestSliceOpDecsDimStartsListTensor(TestSliceOpDecsDim): + def set_inputs(self): + starts_tensor = [] + for index, ele in enumerate(self.starts): + starts_tensor.append(("x" + str(index), np.ones( + (1)).astype('int32') * ele)) + + self.inputs = {'Input': self.input, 'StartsTensorList': starts_tensor} + + def set_attrs(self): + self.attrs = { + 'axes': self.axes, + 'starts': self.starts_infer, + 'ends': self.ends, + 'infer_flags': self.infer_flags, + 'decrease_axis': self.decrease_axis, + } + + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype(self.dtype) + self.starts = [1, 0, 2] + self.ends = [2, 3, 4] + self.axes = [0, 1, 2] + self.decrease_axis = [0] + self.infer_flags = [1, -1, 1] + self.out = self.input[1, 0:3, 2:4, :] + + self.starts_infer = [1, -1, 2] + + +class TestSliceOpDecsDimStartsListTensor2(TestSliceOpDecsDimStartsListTensor): + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype(self.dtype) + self.starts = [-1] + self.ends = [1000000] + self.axes = [3] + self.decrease_axis = [3] + self.infer_flags = [-1] + self.out = self.input[:, :, :, -1] + + self.starts_infer = [-1] + + +class TestSliceOpDecsDimStartsListTensorFP16( + TestSliceOpDecsDimStartsListTensor): + def init_dtype(self): + self.dtype = np.float16 + + if __name__ == '__main__': unittest.main() -- GitLab