Unverified commit e118edd3, authored by furnace, committed via GitHub

[NPU] fix conv2d and top_k_v2 fp16 (#41409)

[NPU] fix conv2d and top_k_v2 fp16
Parent commit: 7ee9ba2f
...@@ -20,6 +20,29 @@ namespace operators { ...@@ -20,6 +20,29 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using NPUDeviceContext = platform::NPUDeviceContext; using NPUDeviceContext = platform::NPUDeviceContext;
// Casts `in` to float16 into `out` using the NPU "Cast" operator.
// `out` is allocated (as fp16) on the context's place before the cast runs
// asynchronously on `stream`.
static void CastToFP16(const framework::ExecutionContext& ctx,
                       const aclrtStream& stream, const Tensor& in,
                       Tensor* out) {
  // Allocate the destination buffer with fp16 element type.
  out->mutable_data<paddle::platform::float16>(ctx.GetPlace());

  // Configure and launch the Cast op (dst_type selects the ACL fp16 enum).
  NpuOpRunner cast_runner;
  cast_runner.SetType("Cast");
  cast_runner.AddInput(in);
  cast_runner.AddOutput(*out);
  cast_runner.AddAttr("dst_type", ACL_FLOAT16);
  cast_runner.Run(stream);
}
// Casts `in` to float32 into `out` using the NPU "Cast" operator.
// `out` is allocated (as fp32) on the context's place before the cast runs
// asynchronously on `stream`.
static void CastToFP32(const framework::ExecutionContext& ctx,
                       const aclrtStream& stream, const Tensor& in,
                       Tensor* out) {
  // Allocate the destination buffer with fp32 element type.
  out->mutable_data<float>(ctx.GetPlace());

  // Configure and launch the Cast op (dst_type selects the ACL fp32 enum).
  NpuOpRunner cast_runner;
  cast_runner.SetType("Cast");
  cast_runner.AddInput(in);
  cast_runner.AddOutput(*out);
  cast_runner.AddAttr("dst_type", ACL_FLOAT);
  cast_runner.Run(stream);
}
template <typename T> template <typename T>
class DepthwiseConvNPUKernel : public framework::OpKernel<T> { class DepthwiseConvNPUKernel : public framework::OpKernel<T> {
...@@ -356,18 +379,33 @@ class NPUConvGradOpKernel : public framework::OpKernel<T> { ...@@ -356,18 +379,33 @@ class NPUConvGradOpKernel : public framework::OpKernel<T> {
auto stream = ctx.template device_context<NPUDeviceContext>().stream(); auto stream = ctx.template device_context<NPUDeviceContext>().stream();
if (filter_grad) { if (filter_grad) {
filter_grad->mutable_data<float>(ctx.GetPlace()); filter_grad->mutable_data<T>(ctx.GetPlace());
std::vector<int> filter_shape_vec = phi::vectorize<int>(filter->dims()); std::vector<int> filter_shape_vec = phi::vectorize<int>(filter->dims());
Tensor filter_grad_fp32(experimental::DataType::FLOAT32);
filter_grad_fp32.Resize(filter_grad->dims());
if (framework::TransToProtoVarType(input->dtype()) ==
framework::proto::VarType::FP16) {
CastToFP32(ctx, stream, *filter_grad, &filter_grad_fp32);
} else {
filter_grad_fp32.ShareDataWith(*filter_grad);
}
const auto& runner = NpuOpRunner( const auto& runner = NpuOpRunner(
"Conv2DBackpropFilterD", {input_tensor, output_grad_tensor}, "Conv2DBackpropFilterD", {input_tensor, output_grad_tensor},
{*filter_grad}, {{"filter_size", filter_shape_vec}, {filter_grad_fp32}, {{"filter_size", filter_shape_vec},
{"strides", strides_vec}, {"strides", strides_vec},
{"pads", paddings}, {"pads", paddings},
{"dilations", dilations_vec}, {"dilations", dilations_vec},
{"groups", groups}, {"groups", groups},
{"data_format", data_format}}); {"data_format", data_format}});
runner.Run(stream); runner.Run(stream);
if (framework::TransToProtoVarType(input->dtype()) ==
framework::proto::VarType::FP16) {
CastToFP16(ctx, stream, filter_grad_fp32, filter_grad);
}
} }
if (input_grad) { if (input_grad) {
input_grad->mutable_data<T>(ctx.GetPlace()); input_grad->mutable_data<T>(ctx.GetPlace());
......
...@@ -89,7 +89,9 @@ class TopkV2NPUKernel : public framework::OpKernel<T> { ...@@ -89,7 +89,9 @@ class TopkV2NPUKernel : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL(top_k_v2, ops::TopkV2NPUKernel<float>, REGISTER_OP_NPU_KERNEL(top_k_v2, ops::TopkV2NPUKernel<float>,
ops::TopkV2NPUKernel<plat::float16>,
ops::TopkV2NPUKernel<double>, ops::TopkV2NPUKernel<double>,
ops::TopkV2NPUKernel<int32_t>, ops::TopkV2NPUKernel<int32_t>,
ops::TopkV2NPUKernel<int64_t>); ops::TopkV2NPUKernel<int64_t>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.