Unverified Commit e118edd3, authored by furnace, committed by GitHub

[NPU] fix conv2d and top_k_v2 fp16 (#41409)

Parent 7ee9ba2f
@@ -20,6 +20,29 @@ namespace operators {
 using Tensor = framework::Tensor;
 using NPUDeviceContext = platform::NPUDeviceContext;
 
+static void CastToFP16(const framework::ExecutionContext& ctx,
+                       const aclrtStream& stream, const Tensor& in,
+                       Tensor* out) {
+  out->mutable_data<paddle::platform::float16>(ctx.GetPlace());
+  NpuOpRunner runner;
+  runner.SetType("Cast")
+      .AddInput(in)
+      .AddOutput(*out)
+      .AddAttr("dst_type", ACL_FLOAT16)
+      .Run(stream);
+}
+
+static void CastToFP32(const framework::ExecutionContext& ctx,
+                       const aclrtStream& stream, const Tensor& in,
+                       Tensor* out) {
+  out->mutable_data<float>(ctx.GetPlace());
+  NpuOpRunner runner;
+  runner.SetType("Cast")
+      .AddInput(in)
+      .AddOutput(*out)
+      .AddAttr("dst_type", ACL_FLOAT)
+      .Run(stream);
+}
 
 template <typename T>
 class DepthwiseConvNPUKernel : public framework::OpKernel<T> {
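For orientation: both helpers are thin wrappers over the ACL "Cast" operator, added so a kernel can round-trip a tensor through fp32. A minimal sketch of that round-trip as it would sit inside a Paddle NPU kernel; the tensor names grad_fp16 and grad_fp32 are illustrative, not from the patch:

  // Widen fp16 data to fp32, run an fp32-only NPU op, narrow the result back.
  Tensor grad_fp32(experimental::DataType::FLOAT32);
  grad_fp32.Resize(grad_fp16.dims());
  CastToFP32(ctx, stream, grad_fp16, &grad_fp32);
  // ... launch an NPU op that writes FP32 output into grad_fp32 ...
  CastToFP16(ctx, stream, grad_fp32, &grad_fp16);

As the helpers show, only the dst_type attribute is passed to Cast; the source dtype comes from the input tensor itself.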
@@ -356,18 +379,33 @@ class NPUConvGradOpKernel : public framework::OpKernel<T> {
     auto stream = ctx.template device_context<NPUDeviceContext>().stream();
     if (filter_grad) {
-      filter_grad->mutable_data<float>(ctx.GetPlace());
+      filter_grad->mutable_data<T>(ctx.GetPlace());
       std::vector<int> filter_shape_vec = phi::vectorize<int>(filter->dims());
+      Tensor filter_grad_fp32(experimental::DataType::FLOAT32);
+      filter_grad_fp32.Resize(filter_grad->dims());
+      if (framework::TransToProtoVarType(input->dtype()) ==
+          framework::proto::VarType::FP16) {
+        CastToFP32(ctx, stream, *filter_grad, &filter_grad_fp32);
+      } else {
+        filter_grad_fp32.ShareDataWith(*filter_grad);
+      }
       const auto& runner = NpuOpRunner(
           "Conv2DBackpropFilterD", {input_tensor, output_grad_tensor},
-          {*filter_grad}, {{"filter_size", filter_shape_vec},
-                           {"strides", strides_vec},
-                           {"pads", paddings},
-                           {"dilations", dilations_vec},
-                           {"groups", groups},
-                           {"data_format", data_format}});
+          {filter_grad_fp32}, {{"filter_size", filter_shape_vec},
+                               {"strides", strides_vec},
+                               {"pads", paddings},
+                               {"dilations", dilations_vec},
+                               {"groups", groups},
+                               {"data_format", data_format}});
       runner.Run(stream);
+      if (framework::TransToProtoVarType(input->dtype()) ==
+          framework::proto::VarType::FP16) {
+        CastToFP16(ctx, stream, filter_grad_fp32, filter_grad);
+      }
     }
 
     if (input_grad) {
       input_grad->mutable_data<T>(ctx.GetPlace());
......
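The gist of the conv2d change: the filter gradient is now always produced into an fp32 tensor. When the kernel runs in fp16, Conv2DBackpropFilterD writes into the filter_grad_fp32 scratch tensor and the result is cast back to fp16 at the end; when it runs in fp32, ShareDataWith makes the scratch alias filter_grad, so nothing is copied. The standalone C++ sketch below shows why widening an accumulation-heavy computation and narrowing once at the end is the safer pattern; it emulates a 16-bit float by truncating an IEEE fp32 bit pattern (bfloat16-style), purely to make the precision loss visible in a plain program, so the narrowing scheme and all names are illustrative, not Paddle's or CANN's.

#include <cstdint>
#include <cstring>
#include <iostream>

// fp32 -> emulated 16-bit float: keep the top 16 bits of the bit pattern.
uint16_t NarrowTo16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}

// emulated 16-bit float -> fp32: restore the bit pattern, low bits zeroed.
float WidenTo32(uint16_t h) {
  uint32_t bits = static_cast<uint32_t>(h) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  // Accumulate 10000 * 1e-4 in fp32 and narrow once at the end -- the
  // strategy the patched kernel uses for the fp16 filter gradient.
  float acc32 = 0.0f;
  for (int i = 0; i < 10000; ++i) acc32 += 1e-4f;
  std::cout << "widen-compute-narrow: " << WidenTo32(NarrowTo16(acc32))
            << "\n";  // prints a value close to 1.0

  // Narrow after every step instead: the sum stalls once one increment
  // drops below the 16-bit resolution at the running sum's magnitude.
  uint16_t acc16 = NarrowTo16(0.0f);
  for (int i = 0; i < 10000; ++i) {
    acc16 = NarrowTo16(WidenTo32(acc16) + 1e-4f);
  }
  std::cout << "narrow-every-step: " << WidenTo32(acc16)
            << "\n";  // stalls well below 1.0
  return 0;
}

Any C++11 compiler will do: the first sum lands near 1.0 while the per-step-narrowed sum stalls far below it, which is the failure mode the fp32 detour avoids.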
@@ -89,7 +89,9 @@ class TopkV2NPUKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
 REGISTER_OP_NPU_KERNEL(top_k_v2, ops::TopkV2NPUKernel<float>,
+                       ops::TopkV2NPUKernel<plat::float16>,
                        ops::TopkV2NPUKernel<double>,
                        ops::TopkV2NPUKernel<int32_t>,
                        ops::TopkV2NPUKernel<int64_t>);
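Registering ops::TopkV2NPUKernel<plat::float16> is what lets the framework resolve top_k_v2 for fp16 tensors on NPU at all: kernel dispatch is keyed on the op name plus dtype, and before this patch there was simply no fp16 entry. A toy sketch of that dispatch idea, using a hypothetical registry far simpler than Paddle's real one:

#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <utility>

// Toy (op name, dtype) -> kernel table; hypothetical, not Paddle's API.
using Key = std::pair<std::string, std::string>;
std::map<Key, std::function<void()>>& Registry() {
  static std::map<Key, std::function<void()>> table;
  return table;
}

int main() {
  Registry()[{"top_k_v2", "float"}] = [] { std::cout << "fp32 top-k\n"; };
  // In effect, the entry this patch adds: without it an fp16 top_k_v2
  // on NPU has no kernel to dispatch to.
  Registry()[{"top_k_v2", "float16"}] = [] { std::cout << "fp16 top-k\n"; };

  Registry().at({"top_k_v2", "float16"})();  // resolves after the patch
  return 0;
}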