未验证 提交 7bddf2e8 编写于 作者: A Aganlengzi 提交者: GitHub

[NPU] mod for model bert (#36165)

* merge conflict of paddle_gtest_main.cc

* modify FLAGS_npu_precision_mode and default not to call aclSetCompileopt
上级 bec9fc9a
......@@ -166,9 +166,11 @@ class ElementwiseSubGradNPUKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL(elementwise_sub, ops::ElementwiseSubNPUKernel<float>,
REGISTER_OP_NPU_KERNEL(elementwise_sub, ops::ElementwiseSubNPUKernel<int>,
ops::ElementwiseSubNPUKernel<float>,
ops::ElementwiseSubNPUKernel<plat::float16>);
REGISTER_OP_NPU_KERNEL(elementwise_sub_grad,
ops::ElementwiseSubGradNPUKernel<int>,
ops::ElementwiseSubGradNPUKernel<float>,
ops::ElementwiseSubGradNPUKernel<plat::float16>);
......@@ -63,9 +63,12 @@ class FillAnyLikeNPUKernel : public framework::OpKernel<T> {
.stream();
auto shape = out->dims();
const auto& runner = NpuOpRunner("FillD", {tensor_tmp}, {*out},
{{"dims", framework::vectorize(shape)}});
runner.Run(stream);
NpuOpRunner runner;
runner.SetType("Fill")
.AddInput(framework::vectorize(shape))
.AddInput(tensor_tmp)
.AddOutput(*out)
.Run(stream);
}
};
......@@ -75,5 +78,8 @@ class FillAnyLikeNPUKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
REGISTER_OP_NPU_KERNEL(fill_any_like, ops::FillAnyLikeNPUKernel<int>,
#ifdef PADDLE_WITH_ASCEND_INT64
ops::FillAnyLikeNPUKernel<int64_t>,
#endif
ops::FillAnyLikeNPUKernel<float>,
ops::FillAnyLikeNPUKernel<paddle::platform::float16>);
......@@ -26,6 +26,8 @@ limitations under the License. */
#include "paddle/fluid/framework/framework.pb.h"
DECLARE_string(npu_precision_mode);
namespace paddle {
namespace operators {
......@@ -404,6 +406,12 @@ void NpuOpRunner::Run(aclrtStream stream) const {
VLOG(4) << "attr: " << attr_;
VLOG(4) << "stream: " << stream;
if (!FLAGS_npu_precision_mode.empty()) {
PADDLE_ENFORCE_NPU_SUCCESS(
aclSetCompileopt(ACL_PRECISION_MODE, FLAGS_npu_precision_mode.c_str()));
VLOG(4) << "set ACL_PRECISION_MODE: " << FLAGS_npu_precision_mode;
}
aclError ret = aclopCompileAndExecute(
op_type_.c_str(), input_descs_.size(), input_descs_.data(),
input_buffers_.data(), output_descs_.size(), output_descs_.data(),
......
......@@ -181,12 +181,37 @@ class SliceGradNPUKernel : public framework::OpKernel<T> {
paddings[i][1] = static_cast<int64_t>(in_dims[i] - size[i] - offsets[i]);
}
Tensor tmp_dout;
tmp_dout.ShareDataWith(*dout);
auto out_dims = dout->dims();
auto decrease_axis = ctx.Attr<std::vector<int>>("decrease_axis");
auto decrease_size = decrease_axis.size();
if (decrease_size > 0) {
if (decrease_size == static_cast<size_t>(in_dims.size())) {
out_dims = framework::make_ddim(std::vector<int>(decrease_size, 1));
} else {
std::vector<int> origin_out_shape(out_dims.size() + decrease_size, -1);
for (size_t i = 0; i < decrease_size; ++i) {
origin_out_shape[decrease_axis[i]] = 1;
}
int index = 0;
for (size_t i = 0; i < origin_out_shape.size(); ++i) {
if (origin_out_shape[i] == -1) {
origin_out_shape[i] = out_dims[index];
++index;
}
}
out_dims = framework::make_ddim(origin_out_shape);
}
tmp_dout.Resize(out_dims);
}
dinput->mutable_data<T>(ctx.GetPlace());
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
const auto& runner =
NpuOpRunner("PadD", {*dout}, {*dinput}, {{"paddings", paddings}});
NpuOpRunner("PadD", {tmp_dout}, {*dinput}, {{"paddings", paddings}});
runner.Run(stream);
}
};
......
......@@ -121,6 +121,13 @@ PADDLE_DEFINE_EXPORTED_string(
"If proveided, it will be passed to aclInit().");
PADDLE_DEFINE_EXPORTED_int32(min_loss_scaling, 1,
"set minmum loss scaling value!");
PADDLE_DEFINE_EXPORTED_string(
npu_precision_mode, "",
"NPU operator precision mode, options are 'force_fp32', 'force_fp16', "
"'allow_fp32_to_fp16', 'must_keep_origin_dtype' and "
"'allow_mix_precision'. If you want to use the default mode ("
"allow_fp32_to_fp16), set this to empty string. For more details, "
"please refer to the documents");
#endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
......
......@@ -90,6 +90,11 @@ class TestElementwiseSubOp(OpTest):
# max_relative_error=0.006,)
class TestElementwiseSubOpInt32(TestElementwiseSubOp):
def init_dtype(self):
self.dtype = np.int32
class TestSubtractAPI(unittest.TestCase):
def test_name(self):
with paddle.static.program_guard(paddle.static.Program()):
......
......@@ -57,6 +57,12 @@ class TestFillAnyLikeNPUOpInt32(TestFillAnyLikeNPUOp):
self.value = -1
class TestFillAnyLikeNPUOpInt64(TestFillAnyLikeNPUOp):
def init(self):
self.dtype = np.int64
self.value = -1
class TestFillAnyLikeNPUOpFloat32(TestFillAnyLikeNPUOp):
def init(self):
self.dtype = np.float32
......
......@@ -301,5 +301,231 @@ class TestSliceNet(unittest.TestCase):
self.assertTrue(np.allclose(npu_loss, cpu_loss))
class TestSliceOpDecsDim(OpTest):
def setUp(self):
self.op_type = "slice"
self.set_npu()
self.init_dtype()
self.config()
self.set_inputs()
self.set_outputs()
self.set_attrs()
def set_inputs(self):
self.inputs = {'Input': self.input}
def set_outputs(self):
self.outputs = {'Out': self.out}
def set_attrs(self):
self.attrs = {
'axes': self.axes,
'starts': self.starts,
'ends': self.ends,
'infer_flags': self.infer_flags,
'decrease_axis': self.decrease_axis,
}
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype(self.dtype)
self.starts = [1, 0, 2]
self.ends = [2, 3, 4]
self.axes = [0, 1, 2]
self.decrease_axis = [0]
self.infer_flags = [1, 1, 1]
self.out = self.input[1, 0:3, 2:4, :]
def init_dtype(self):
self.dtype = np.float32
def set_npu(self):
self.__class__.use_npu = True
self.place = paddle.NPUPlace(0)
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad_normal(self):
if self.dtype == np.float16:
return
self.check_grad_with_place(self.place, ['Input'], 'Out')
class TestSliceOpDecsDimFp16(TestSliceOpDecsDim):
def init_dtype(self):
self.dtype = np.float16
class TestSliceOpDecsDim2(TestSliceOpDecsDim):
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype(self.dtype)
self.starts = [1, 0, 2]
self.ends = [2, 1, 4]
self.axes = [0, 1, 2]
self.decrease_axis = [0, 1]
self.infer_flags = [1, 1, 1]
self.out = self.input[1, 0, 2:4, :]
class TestSliceOpDecsDim3(TestSliceOpDecsDim):
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype(self.dtype)
self.starts = [-1, 0, 2]
self.ends = [1000000, 1, 4]
self.axes = [0, 1, 2]
self.decrease_axis = [0, 1]
self.infer_flags = [1, 1, 1]
self.out = self.input[-1, 0, 2:4, :]
class TestSliceOpDecsDim4(TestSliceOpDecsDim):
def config(self):
self.input = np.random.random([3, 4, 5, 7]).astype(self.dtype)
self.starts = [0, 1, 2, 3]
self.ends = [1, 2, 3, 4]
self.axes = [0, 1, 2, 3]
self.decrease_axis = [0, 1, 2, 3]
self.infer_flags = [1, 1, 1]
self.out = self.input[0, 1, 2, 3:4]
class TestSliceOpDecsDim5(TestSliceOpDecsDim):
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype(self.dtype)
self.starts = [-1]
self.ends = [1000000]
self.axes = [3]
self.decrease_axis = [3]
self.infer_flags = [1, 1, 1]
self.out = self.input[:, :, :, -1]
class TestSliceOpDecsDim6(TestSliceOpDecsDim):
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype(self.dtype)
self.starts = [0, 1, 2, 3]
self.ends = [1, 2, 3, 4]
self.axes = [0, 1, 2, 3]
self.decrease_axis = [0, 1, 2, 3]
self.infer_flags = [1, 1, 1]
self.out = self.input[0, 1, 2, 3:4]
class TestSliceOpDecsDimStartsTensor(TestSliceOpDecsDim):
def set_inputs(self):
self.inputs = {
'Input': self.input,
"StartsTensor": np.array(
self.starts, dtype='int32')
}
def set_attrs(self):
self.attrs = {
'axes': self.axes,
#'starts': self.starts,
'ends': self.ends,
'infer_flags': self.infer_flags,
'decrease_axis': self.decrease_axis,
}
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype(self.dtype)
self.starts = [1, 0, 2]
self.ends = [2, 3, 4]
self.axes = [0, 1, 2]
self.decrease_axis = [0]
self.infer_flags = [-1, -1, -1]
self.out = self.input[1, 0:3, 2:4, :]
class TestSliceOpDecsDimStartsTensorFP16(TestSliceOpDecsDimStartsTensor):
def init_dtype(self):
self.dtype = np.float16
class TestSliceOpDecsDimStartsTensorStartsAndEndsTensor(TestSliceOpDecsDim):
def set_inputs(self):
self.inputs = {
'Input': self.input,
"StartsTensor": np.array(
self.starts, dtype='int64'),
"EndsTensor": np.array(
self.ends, dtype='int32')
}
def set_attrs(self):
self.attrs = {
'axes': self.axes,
#'starts': self.starts,
#'ends': self.ends,
'infer_flags': self.infer_flags,
'decrease_axis': self.decrease_axis,
}
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype(self.dtype)
self.starts = [1, 0, 2]
self.ends = [2, 1, 4]
self.axes = [0, 1, 2]
self.decrease_axis = [0, 1]
self.infer_flags = [-1, -1, -1]
self.out = self.input[1, 0, 2:4, :]
class TestSliceOpDecsDimStartsTensorStartsAndEndsTensorFP16(
TestSliceOpDecsDimStartsTensorStartsAndEndsTensor):
def init_dtype(self):
self.dtype = np.float16
class TestSliceOpDecsDimStartsListTensor(TestSliceOpDecsDim):
def set_inputs(self):
starts_tensor = []
for index, ele in enumerate(self.starts):
starts_tensor.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele))
self.inputs = {'Input': self.input, 'StartsTensorList': starts_tensor}
def set_attrs(self):
self.attrs = {
'axes': self.axes,
'starts': self.starts_infer,
'ends': self.ends,
'infer_flags': self.infer_flags,
'decrease_axis': self.decrease_axis,
}
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype(self.dtype)
self.starts = [1, 0, 2]
self.ends = [2, 3, 4]
self.axes = [0, 1, 2]
self.decrease_axis = [0]
self.infer_flags = [1, -1, 1]
self.out = self.input[1, 0:3, 2:4, :]
self.starts_infer = [1, -1, 2]
class TestSliceOpDecsDimStartsListTensor2(TestSliceOpDecsDimStartsListTensor):
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype(self.dtype)
self.starts = [-1]
self.ends = [1000000]
self.axes = [3]
self.decrease_axis = [3]
self.infer_flags = [-1]
self.out = self.input[:, :, :, -1]
self.starts_infer = [-1]
class TestSliceOpDecsDimStartsListTensorFP16(
TestSliceOpDecsDimStartsListTensor):
def init_dtype(self):
self.dtype = np.float16
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册