未验证 提交 37455714 编写于 作者: X xiaoxiaohehe001 提交者: GitHub

[Paddle inference] Add conv_fusion_fp16 (#44435)

* convfusionfp16

* convfusionfp16

* convfusionfp16
上级 0243c6ca
......@@ -284,6 +284,27 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
return;
}
// conv_weight fp16 --> fp32 (converted back to fp16 after the fusion below)
auto* conv_weight_tensor =
scope->FindVar(conv_weight->Name())->GetMutable<LoDTensor>();
auto tensor_type = conv_weight_tensor->dtype();
if (tensor_type == paddle::experimental::DataType::FLOAT16) {
framework::Tensor weight_float_tensor;
weight_float_tensor.set_type(paddle::experimental::DataType::FLOAT32);
weight_float_tensor.Resize(conv_weight_tensor->dims());
auto* weight_float_data =
weight_float_tensor.mutable_data<float>(platform::CPUPlace());
auto* data =
conv_weight_tensor->mutable_data<float16>(platform::CPUPlace());
for (int i = 0; i < conv_weight_tensor->numel(); i++) {
weight_float_data[i] = static_cast<float>(data[i]);
}
conv_weight_tensor->clear();
paddle::framework::TensorCopySync(
weight_float_tensor, platform::CPUPlace(), conv_weight_tensor);
}
// Get batch norm bias
auto* bn_bias_tensor =
scope->FindVar(bn_bias->Name())->GetMutable<LoDTensor>();
......@@ -319,6 +340,43 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
epsilon,
conv_type());
if (tensor_type == paddle::experimental::DataType::FLOAT16) {
{
framework::Tensor weight_float16_tensor;
weight_float16_tensor.set_type(paddle::experimental::DataType::FLOAT16);
weight_float16_tensor.Resize(conv_weight_tensor->dims());
auto* weight_float16_data =
weight_float16_tensor.mutable_data<float16>(platform::CPUPlace());
auto* data =
conv_weight_tensor->mutable_data<float>(platform::CPUPlace());
for (int i = 0; i < conv_weight_tensor->numel(); i++) {
weight_float16_data[i] = static_cast<float16>(data[i]);
}
conv_weight_tensor->clear();
paddle::framework::TensorCopySync(
weight_float16_tensor, platform::CPUPlace(), conv_weight_tensor);
}
{
framework::Tensor eltwise_y_in_float16_tensor;
eltwise_y_in_float16_tensor.set_type(
paddle::experimental::DataType::FLOAT16);
eltwise_y_in_float16_tensor.Resize(eltwise_y_in_tensor->dims());
auto* eltwise_y_in_float16_data =
eltwise_y_in_float16_tensor.mutable_data<float16>(
platform::CPUPlace());
auto* data =
eltwise_y_in_tensor->mutable_data<float>(platform::CPUPlace());
for (int i = 0; i < eltwise_y_in_tensor->numel(); i++) {
eltwise_y_in_float16_data[i] = static_cast<float16>(data[i]);
}
eltwise_y_in_tensor->clear();
paddle::framework::TensorCopySync(eltwise_y_in_float16_tensor,
platform::CPUPlace(),
eltwise_y_in_tensor);
}
}
// with MKL-DNN fuse conv+bn into conv with bias
// without MKL-DNN fuse conv+bn into conv+elementwise_add
if (fuse_option == FUSE_MKLDNN) {
......
......@@ -154,9 +154,12 @@ const std::vector<std::string> kLiteSubgraphPasses({
// support fp16/bf16 precision, temporarily use low precision pass to prevent
// running errors. After fusion operator supports low precision, delete this.
const std::vector<std::string> kGpuLowerPrecisionPasses{
// "conv_bn_fuse_pass",
// "conv_eltwiseadd_bn_fuse_pass",
};
"conv_bn_fuse_pass",
"conv_eltwiseadd_bn_fuse_pass",
"conv_elementwise_add_act_fuse_pass",
"conv_elementwise_add2_act_fuse_pass",
"conv_elementwise_add_fuse_pass"};
const std::vector<std::string> kTrtLowerPrecisionPasses{
// "conv_bn_fuse_pass",
// "conv_eltwiseadd_bn_fuse_pass",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册