diff --git a/paddle/fluid/operators/conv_op_npu.cc b/paddle/fluid/operators/conv_op_npu.cc
index 86a6ec2c3a1603c64c14d03ffdbd9821f5719657..3ace825e7b80df6032183505c048a6c0e796aaca 100644
--- a/paddle/fluid/operators/conv_op_npu.cc
+++ b/paddle/fluid/operators/conv_op_npu.cc
@@ -20,6 +20,29 @@ namespace operators {
 using Tensor = framework::Tensor;
 using NPUDeviceContext = platform::NPUDeviceContext;
 
+static void CastToFP16(const framework::ExecutionContext& ctx,
+                       const aclrtStream& stream, const Tensor& in,
+                       Tensor* out) {
+  out->mutable_data<paddle::platform::float16>(ctx.GetPlace());
+  NpuOpRunner runner;
+  runner.SetType("Cast")
+      .AddInput(in)
+      .AddOutput(*out)
+      .AddAttr("dst_type", ACL_FLOAT16)
+      .Run(stream);
+}
+
+static void CastToFP32(const framework::ExecutionContext& ctx,
+                       const aclrtStream& stream, const Tensor& in,
+                       Tensor* out) {
+  out->mutable_data<float>(ctx.GetPlace());
+  NpuOpRunner runner;
+  runner.SetType("Cast")
+      .AddInput(in)
+      .AddOutput(*out)
+      .AddAttr("dst_type", ACL_FLOAT)
+      .Run(stream);
+}
+
 template <typename T>
 class DepthwiseConvNPUKernel : public framework::OpKernel<T> {
@@ -356,18 +379,33 @@ class NPUConvGradOpKernel : public framework::OpKernel<T> {
     auto stream = ctx.template device_context<NPUDeviceContext>().stream();
 
     if (filter_grad) {
-      filter_grad->mutable_data<T>(ctx.GetPlace());
+      filter_grad->mutable_data<T>(ctx.GetPlace());
       std::vector<int> filter_shape_vec = phi::vectorize<int>(filter->dims());
 
+      Tensor filter_grad_fp32(experimental::DataType::FLOAT32);
+      filter_grad_fp32.Resize(filter_grad->dims());
+
+      if (framework::TransToProtoVarType(input->dtype()) ==
+          framework::proto::VarType::FP16) {
+        CastToFP32(ctx, stream, *filter_grad, &filter_grad_fp32);
+      } else {
+        filter_grad_fp32.ShareDataWith(*filter_grad);
+      }
+
       const auto& runner = NpuOpRunner(
           "Conv2DBackpropFilterD", {input_tensor, output_grad_tensor},
-          {*filter_grad}, {{"filter_size", filter_shape_vec},
-                           {"strides", strides_vec},
-                           {"pads", paddings},
-                           {"dilations", dilations_vec},
-                           {"groups", groups},
-                           {"data_format", data_format}});
+          {filter_grad_fp32}, {{"filter_size", filter_shape_vec},
+                               {"strides", strides_vec},
+                               {"pads", paddings},
+                               {"dilations", dilations_vec},
+                               {"groups", groups},
+                               {"data_format", data_format}});
       runner.Run(stream);
+
+      if (framework::TransToProtoVarType(input->dtype()) ==
+          framework::proto::VarType::FP16) {
+        CastToFP16(ctx, stream, filter_grad_fp32, filter_grad);
+      }
     }
     if (input_grad) {
       input_grad->mutable_data<T>(ctx.GetPlace());
diff --git a/paddle/fluid/operators/top_k_v2_op_npu.cc b/paddle/fluid/operators/top_k_v2_op_npu.cc
index dff5c2d3f39378486bb5d2f8010d005d57b20550..04e4d88b008e0c293cf206c63badb12121cca30a 100644
--- a/paddle/fluid/operators/top_k_v2_op_npu.cc
+++ b/paddle/fluid/operators/top_k_v2_op_npu.cc
@@ -89,7 +89,9 @@ class TopkV2NPUKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
 
 REGISTER_OP_NPU_KERNEL(top_k_v2, ops::TopkV2NPUKernel<float>,
+                       ops::TopkV2NPUKernel<plat::float16>,
                        ops::TopkV2NPUKernel<double>,
                        ops::TopkV2NPUKernel<int32_t>,
                        ops::TopkV2NPUKernel<int64_t>);
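
A note on the pattern, not part of the patch itself: even when the kernel's T is fp16, Conv2DBackpropFilterD is run against an fp32 staging tensor, and the result is cast back to fp16 via the Cast op only at the end. One plausible motivation is accumulation precision: a filter gradient sums many small per-element contributions, and summing in the narrow type can silently drop them. The standalone sketch below mirrors the fp16 -> fp32 -> fp16 round trip with float -> double -> float, since standard C++ has no portable half type; WidenToDouble and NarrowToFloat are hypothetical stand-ins for the patch's CastToFP32 and CastToFP16 helpers.

// Standalone sketch of the staging-buffer pattern used in NPUConvGradOpKernel:
// accumulate in a wider type, then narrow back once at the end.
// float/double stand in for fp16/fp32; WidenToDouble/NarrowToFloat are
// hypothetical analogues of the patch's CastToFP32/CastToFP16 helpers.
#include <cstdio>
#include <vector>

static std::vector<double> WidenToDouble(const std::vector<float>& in) {
  return std::vector<double>(in.begin(), in.end());
}

static std::vector<float> NarrowToFloat(const std::vector<double>& in) {
  return std::vector<float>(in.begin(), in.end());
}

int main() {
  const long n = 20000000;  // many identical "gradient" contributions of 1.0

  // Narrow-only accumulation: once the running sum reaches 2^24 = 16777216,
  // adding 1.0f no longer changes it, so the tail of the sum is lost.
  float narrow_sum = 0.0f;
  for (long i = 0; i < n; ++i) narrow_sum += 1.0f;

  // Staged accumulation: widen first, accumulate, narrow back at the end.
  std::vector<double> staged = WidenToDouble({0.0f});
  for (long i = 0; i < n; ++i) staged[0] += 1.0;
  float staged_sum = NarrowToFloat(staged)[0];

  // Prints: narrow: 16777216  staged: 20000000
  std::printf("narrow: %.0f  staged: %.0f\n", narrow_sum, staged_sum);
  return 0;
}

Note also the ShareDataWith in the non-fp16 branch of the patch: when T is already float, the staging tensor aliases filter_grad directly, so the common path pays no extra copy or cast.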