From ae60105d3076a063ba4958bb51ea695da0ccef62 Mon Sep 17 00:00:00 2001
From: Yuanle Liu
Date: Thu, 2 Mar 2023 11:02:09 +0800
Subject: [PATCH] process multiple conv2d_fusion shares weight (#51068)

---
 .../ir/conv2d_fusion_layout_transfer_pass.cc  | 50 +++++++++++--------
 .../inference/api/paddle_pass_builder.cc      |  7 ++-
 2 files changed, 32 insertions(+), 25 deletions(-)

diff --git a/paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc b/paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc
index 4058547e70c..31041754b0b 100644
--- a/paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc
+++ b/paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc
@@ -142,6 +142,10 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
   auto iter = op_nodes.cbegin();
   auto *block_desc = (*iter)->Op()->Block();
 
+  // Process multiple conv2d_fusion ops that share the same weight.
+  std::unordered_set<std::string> weights_shape_nhwc;
+
+  // Used to control the insertion of transfer_layout op.
   std::unordered_set<std::string> vars_shape_nhwc;
 
   // Only support conv2d_fusion now.
@@ -157,6 +161,9 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
     constexpr int NHWC_ALIGNMENT = 8;
     // If filter's channel is not multiple of 8, conv2d_fusion not run at nhwc.
     for (const auto &filter_name : filter_names) {
+      if (weights_shape_nhwc.count(filter_name)) {
+        continue;
+      }
       auto *filter_var = scope->FindLocalVar(filter_name);
       const auto &filter_tensor = filter_var->Get<phi::DenseTensor>();
       CHECK_EQ(filter_tensor.dims().size() == 4UL, true);
@@ -206,27 +213,28 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
       // transfer weights
       auto filter_names = op_desc->Input("Filter");
       for (const auto &filter_name : filter_names) {
-        auto *filter_var = scope->FindLocalVar(filter_name);
-        auto *filter_tensor = filter_var->GetMutable<phi::DenseTensor>();
-        phi::DenseTensor temp_tensor = *filter_tensor;
-        filter_tensor->clear();
-
-        framework::TransDataLayout(phi::DataLayout::kNCHW,
-                                   phi::DataLayout::kNHWC,
-                                   phi::CPUPlace{},
-                                   temp_tensor,
-                                   filter_tensor);
-      }
-      auto op_inputs = op_node->inputs;
-      for (auto *in_var_node : op_inputs) {
-        CHECK_EQ(in_var_node->IsVar(), true);
-        if (in_var_node->Var()->Persistable()) {
-          if (std::find(filter_names.cbegin(),
-                        filter_names.cend(),
-                        in_var_node->Var()->Name()) != filter_names.cend()) {
-            auto from_shape = in_var_node->Var()->GetShape();
-            in_var_node->Var()->SetShape(
-                {from_shape[0], from_shape[2], from_shape[3], from_shape[1]});
+        if (weights_shape_nhwc.count(filter_name) == 0) {
+          weights_shape_nhwc.insert(filter_name);
+          auto *filter_var = scope->FindLocalVar(filter_name);
+          auto *filter_tensor = filter_var->GetMutable<phi::DenseTensor>();
+          phi::DenseTensor temp_tensor;
+
+          framework::TransDataLayout(phi::DataLayout::kNCHW,
+                                     phi::DataLayout::kNHWC,
+                                     phi::CPUPlace{},
+                                     *filter_tensor,
+                                     &temp_tensor);
+          *filter_tensor = temp_tensor;
+
+          auto op_inputs = op_node->inputs;
+          for (auto *in_var_node : op_inputs) {
+            CHECK_EQ(in_var_node->IsVar(), true);
+            if (in_var_node->Var()->Persistable() &&
+                in_var_node->Var()->Name() == filter_name) {
+              auto from_shape = in_var_node->Var()->GetShape();
+              in_var_node->Var()->SetShape(
+                  {from_shape[0], from_shape[2], from_shape[3], from_shape[1]});
+            }
           }
         }
       }
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index fa422470975..23fdaf3ddff 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -270,10 +270,9 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
         "conv_elementwise_add_fuse_pass",      //
 #endif                                         //
         "transpose_flatten_concat_fuse_pass",  //
-        // TODO(liuyuanle): rewrite this pass with new logic
-        // "conv2d_fusion_layout_transfer_pass", //
-        "auto_mixed_precision_pass",  //
-        "inplace_op_var_pass",        // should be the last pass.
+        "conv2d_fusion_layout_transfer_pass",  //
+        "auto_mixed_precision_pass",           //
+        "inplace_op_var_pass",  // should be the last pass.
   });
 
   use_gpu_ = true;
-- 
GitLab
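
Note on the fix, for reviewers: the pass previously ran the NCHW -> NHWC weight
conversion once per conv2d_fusion op, so a filter variable referenced by several
fused convs was transposed repeatedly and ended up corrupted. The new
weights_shape_nhwc set records filter names that are already in NHWC layout, and
both the alignment check and the conversion skip them. Below is a minimal
standalone sketch of that convert-once pattern. It is plain C++, not Paddle's
API: Tensor, ToNHWC, and the filter name "conv_w" are invented stand-ins for
phi::DenseTensor, framework::TransDataLayout, and the graph's variable names.

// Sketch: guard a shared weight so the layout permutation runs exactly once.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

struct Tensor {
  std::vector<int64_t> dims;  // {O, C, H, W} for a conv filter
};

// Hypothetical stand-in for the real layout transfer: permute dims to NHWC.
void ToNHWC(Tensor *t) {
  assert(t->dims.size() == 4);
  t->dims = {t->dims[0], t->dims[2], t->dims[3], t->dims[1]};
}

int main() {
  Tensor shared_filter{{64, 3, 7, 7}};  // one weight referenced by two convs
  std::unordered_set<std::string> weights_shape_nhwc;

  // Two conv2d_fusion ops both list the same filter name "conv_w".
  const std::vector<std::string> filters_per_op = {"conv_w", "conv_w"};
  for (const auto &name : filters_per_op) {
    if (weights_shape_nhwc.count(name)) {
      continue;  // already converted by an earlier op: skip
    }
    weights_shape_nhwc.insert(name);
    ToNHWC(&shared_filter);  // runs exactly once per unique weight
  }

  // Without the guard, the permutation would apply twice, yielding
  // {64, 7, 3, 7} instead of the correct NHWC shape {64, 7, 7, 3}.
  assert((shared_filter.dims == std::vector<int64_t>{64, 7, 7, 3}));
  std::cout << "shared filter converted once: OK\n";
  return 0;
}

The same reasoning explains why the patch moves the per-input VarDesc shape
rewrite inside the guarded block: the recorded shape, like the tensor data,
must be permuted only once per unique filter.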