From e118edd3c0f1570d61156ea7a17cbfd9c7545bef Mon Sep 17 00:00:00 2001
From: furnace <34057289+windstamp@users.noreply.github.com>
Date: Mon, 18 Apr 2022 11:16:22 +0800
Subject: [PATCH] [NPU] fix conv2d and top_k_v2 fp16 (#41409)

[NPU] fix conv2d and top_k_v2 fp16
---
 paddle/fluid/operators/conv_op_npu.cc     | 52 ++++++++++++++++++++---
 paddle/fluid/operators/top_k_v2_op_npu.cc |  2 +
 2 files changed, 47 insertions(+), 7 deletions(-)

diff --git a/paddle/fluid/operators/conv_op_npu.cc b/paddle/fluid/operators/conv_op_npu.cc
index 86a6ec2c3a..3ace825e7b 100644
--- a/paddle/fluid/operators/conv_op_npu.cc
+++ b/paddle/fluid/operators/conv_op_npu.cc
@@ -20,6 +20,29 @@ namespace operators {
 
 using Tensor = framework::Tensor;
 using NPUDeviceContext = platform::NPUDeviceContext;
+static void CastToFP16(const framework::ExecutionContext& ctx,
+                       const aclrtStream& stream, const Tensor& in,
+                       Tensor* out) {
+  out->mutable_data<paddle::platform::float16>(ctx.GetPlace());
+  NpuOpRunner runner;
+  runner.SetType("Cast")
+      .AddInput(in)
+      .AddOutput(*out)
+      .AddAttr("dst_type", ACL_FLOAT16)
+      .Run(stream);
+}
+
+static void CastToFP32(const framework::ExecutionContext& ctx,
+                       const aclrtStream& stream, const Tensor& in,
+                       Tensor* out) {
+  out->mutable_data<float>(ctx.GetPlace());
+  NpuOpRunner runner;
+  runner.SetType("Cast")
+      .AddInput(in)
+      .AddOutput(*out)
+      .AddAttr("dst_type", ACL_FLOAT)
+      .Run(stream);
+}
 
 template <typename T>
 class DepthwiseConvNPUKernel : public framework::OpKernel<T> {
@@ -356,18 +379,33 @@ class NPUConvGradOpKernel : public framework::OpKernel<T> {
     auto stream = ctx.template device_context<NPUDeviceContext>().stream();
 
     if (filter_grad) {
-      filter_grad->mutable_data<T>(ctx.GetPlace());
+      filter_grad->mutable_data<float>(ctx.GetPlace());
       std::vector<int> filter_shape_vec = phi::vectorize<int>(filter->dims());
 
+      Tensor filter_grad_fp32(experimental::DataType::FLOAT32);
+      filter_grad_fp32.Resize(filter_grad->dims());
+
+      if (framework::TransToProtoVarType(input->dtype()) ==
+          framework::proto::VarType::FP16) {
+        CastToFP32(ctx, stream, *filter_grad, &filter_grad_fp32);
+      } else {
+        filter_grad_fp32.ShareDataWith(*filter_grad);
+      }
+
       const auto& runner = NpuOpRunner(
           "Conv2DBackpropFilterD", {input_tensor, output_grad_tensor},
-          {*filter_grad}, {{"filter_size", filter_shape_vec},
-                           {"strides", strides_vec},
-                           {"pads", paddings},
-                           {"dilations", dilations_vec},
-                           {"groups", groups},
-                           {"data_format", data_format}});
+          {filter_grad_fp32}, {{"filter_size", filter_shape_vec},
+                               {"strides", strides_vec},
+                               {"pads", paddings},
+                               {"dilations", dilations_vec},
+                               {"groups", groups},
+                               {"data_format", data_format}});
       runner.Run(stream);
+
+      if (framework::TransToProtoVarType(input->dtype()) ==
+          framework::proto::VarType::FP16) {
+        CastToFP16(ctx, stream, filter_grad_fp32, filter_grad);
+      }
     }
     if (input_grad) {
       input_grad->mutable_data<T>(ctx.GetPlace());
diff --git a/paddle/fluid/operators/top_k_v2_op_npu.cc b/paddle/fluid/operators/top_k_v2_op_npu.cc
index dff5c2d3f3..04e4d88b00 100644
--- a/paddle/fluid/operators/top_k_v2_op_npu.cc
+++ b/paddle/fluid/operators/top_k_v2_op_npu.cc
@@ -89,7 +89,9 @@ class TopkV2NPUKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
 
 REGISTER_OP_NPU_KERNEL(top_k_v2, ops::TopkV2NPUKernel<float>,
+                       ops::TopkV2NPUKernel<plat::float16>,
                        ops::TopkV2NPUKernel<double>, ops::TopkV2NPUKernel<int32_t>,
                        ops::TopkV2NPUKernel<int64_t>);
--
GitLab