未验证 提交 ae60105d 编写于 作者: Y Yuanle Liu 提交者: GitHub

process multiple conv2d_fusion shares weight (#51068)

上级 4652bee4
...@@ -142,6 +142,10 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const { ...@@ -142,6 +142,10 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
auto iter = op_nodes.cbegin(); auto iter = op_nodes.cbegin();
auto *block_desc = (*iter)->Op()->Block(); auto *block_desc = (*iter)->Op()->Block();
// Process multiple conv2d_fusion shares weight.
std::unordered_set<std::string> weights_shape_nhwc;
// Used to control the insertion of transfer_layout op.
std::unordered_set<ir::Node *> vars_shape_nhwc; std::unordered_set<ir::Node *> vars_shape_nhwc;
// Only support conv2d_fusion now. // Only support conv2d_fusion now.
...@@ -157,6 +161,9 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const { ...@@ -157,6 +161,9 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
constexpr int NHWC_ALIGNMENT = 8; constexpr int NHWC_ALIGNMENT = 8;
// If filter's channel is not multiple of 8, conv2d_fusion not run at nhwc. // If filter's channel is not multiple of 8, conv2d_fusion not run at nhwc.
for (const auto &filter_name : filter_names) { for (const auto &filter_name : filter_names) {
if (weights_shape_nhwc.count(filter_name)) {
continue;
}
auto *filter_var = scope->FindLocalVar(filter_name); auto *filter_var = scope->FindLocalVar(filter_name);
const auto &filter_tensor = filter_var->Get<phi::DenseTensor>(); const auto &filter_tensor = filter_var->Get<phi::DenseTensor>();
CHECK_EQ(filter_tensor.dims().size() == 4UL, true); CHECK_EQ(filter_tensor.dims().size() == 4UL, true);
...@@ -206,27 +213,28 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const { ...@@ -206,27 +213,28 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
// transfer weights // transfer weights
auto filter_names = op_desc->Input("Filter"); auto filter_names = op_desc->Input("Filter");
for (const auto &filter_name : filter_names) { for (const auto &filter_name : filter_names) {
auto *filter_var = scope->FindLocalVar(filter_name); if (weights_shape_nhwc.count(filter_name) == 0) {
auto *filter_tensor = filter_var->GetMutable<phi::DenseTensor>(); weights_shape_nhwc.insert(filter_name);
phi::DenseTensor temp_tensor = *filter_tensor; auto *filter_var = scope->FindLocalVar(filter_name);
filter_tensor->clear(); auto *filter_tensor = filter_var->GetMutable<phi::DenseTensor>();
phi::DenseTensor temp_tensor;
framework::TransDataLayout(phi::DataLayout::kNCHW,
phi::DataLayout::kNHWC, framework::TransDataLayout(phi::DataLayout::kNCHW,
phi::CPUPlace{}, phi::DataLayout::kNHWC,
temp_tensor, phi::CPUPlace{},
filter_tensor); *filter_tensor,
} &temp_tensor);
auto op_inputs = op_node->inputs; *filter_tensor = temp_tensor;
for (auto *in_var_node : op_inputs) {
CHECK_EQ(in_var_node->IsVar(), true); auto op_inputs = op_node->inputs;
if (in_var_node->Var()->Persistable()) { for (auto *in_var_node : op_inputs) {
if (std::find(filter_names.cbegin(), CHECK_EQ(in_var_node->IsVar(), true);
filter_names.cend(), if (in_var_node->Var()->Persistable() &&
in_var_node->Var()->Name()) != filter_names.cend()) { in_var_node->Var()->Name() == filter_name) {
auto from_shape = in_var_node->Var()->GetShape(); auto from_shape = in_var_node->Var()->GetShape();
in_var_node->Var()->SetShape( in_var_node->Var()->SetShape(
{from_shape[0], from_shape[2], from_shape[3], from_shape[1]}); {from_shape[0], from_shape[2], from_shape[3], from_shape[1]});
}
} }
} }
} }
......
...@@ -270,10 +270,9 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { ...@@ -270,10 +270,9 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"conv_elementwise_add_fuse_pass", // "conv_elementwise_add_fuse_pass", //
#endif // #endif //
"transpose_flatten_concat_fuse_pass", // "transpose_flatten_concat_fuse_pass", //
// TODO(liuyuanle): rewrite this pass with new logic "conv2d_fusion_layout_transfer_pass", //
// "conv2d_fusion_layout_transfer_pass", // "auto_mixed_precision_pass", //
"auto_mixed_precision_pass", // "inplace_op_var_pass", // should be the last pass.
"inplace_op_var_pass", // should be the last pass.
}); });
use_gpu_ = true; use_gpu_ = true;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册