Unverified · Commit ae60105d authored by Yuanle Liu, committed by GitHub

process multiple conv2d_fusion shares weight (#51068)

Parent 4652bee4
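Context for the change: conv2d_fusion_layout_transfer_pass converts eligible conv2d_fusion weights from NCHW to NHWC. Before this commit, a filter variable shared by several conv2d_fusion ops was transposed once per op, so every op after the first saw an already-converted (i.e., corrupted) weight. The fix tracks converted filter names in an `unordered_set` and converts each one exactly once. Below is a minimal standalone sketch of that guard pattern; the `Tensor` struct and transpose helper are illustrative stand-ins, not Paddle APIs.

```cpp
#include <cassert>
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Illustrative stand-in for a weight tensor; only the layout tag matters here.
struct Tensor {
  std::string layout = "NCHW";
};

// Stand-in for framework::TransDataLayout. A second call on the same tensor
// would trip the assert, which is exactly the shared-weight bug being fixed.
void TransposeNchwToNhwc(Tensor *t) {
  assert(t->layout == "NCHW" && "weight was already transposed");
  t->layout = "NHWC";
}

int main() {
  std::unordered_map<std::string, Tensor> scope{{"conv_w", Tensor{}}};
  // Two conv2d_fusion ops list the same filter variable as input.
  const std::vector<std::string> filter_inputs{"conv_w", "conv_w"};

  std::unordered_set<std::string> weights_shape_nhwc;  // guard set, as in the pass
  for (const auto &filter_name : filter_inputs) {
    if (weights_shape_nhwc.count(filter_name)) continue;  // already converted
    weights_shape_nhwc.insert(filter_name);
    TransposeNchwToNhwc(&scope[filter_name]);
  }
  std::cout << scope["conv_w"].layout << "\n";  // prints NHWC; converted once
}
```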
@@ -142,6 +142,10 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
   auto iter = op_nodes.cbegin();
   auto *block_desc = (*iter)->Op()->Block();
+  // Handle multiple conv2d_fusion ops sharing the same weight.
+  std::unordered_set<std::string> weights_shape_nhwc;
+  // Used to control the insertion of the transfer_layout op.
+  std::unordered_set<ir::Node *> vars_shape_nhwc;
   // Only support conv2d_fusion now.
@@ -157,6 +161,9 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
   constexpr int NHWC_ALIGNMENT = 8;
   // If the filter's channel count is not a multiple of 8, conv2d_fusion does
   // not run in NHWC layout.
   for (const auto &filter_name : filter_names) {
+    if (weights_shape_nhwc.count(filter_name)) {
+      continue;
+    }
     auto *filter_var = scope->FindLocalVar(filter_name);
     const auto &filter_tensor = filter_var->Get<phi::DenseTensor>();
     CHECK_EQ(filter_tensor.dims().size() == 4UL, true);
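The NHWC_ALIGNMENT constant above gates the conversion: tensor-core kernels perform best when channel counts are multiples of 8, so unaligned filters stay in NCHW. A hedged sketch of that kind of eligibility test follows; the exact dimensions the pass inspects may differ.

```cpp
#include <cstdint>
#include <vector>

constexpr int kNhwcAlignment = 8;  // mirrors NHWC_ALIGNMENT in the pass

// Sketch: a 4-D NCHW filter [out_c, in_c, h, w] is eligible for the NHWC
// fast path only when its channel counts are multiples of 8.
bool EligibleForNhwc(const std::vector<int64_t> &filter_dims) {
  if (filter_dims.size() != 4) return false;
  return filter_dims[0] % kNhwcAlignment == 0 &&
         filter_dims[1] % kNhwcAlignment == 0;
}
```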
@@ -206,27 +213,28 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
     // transfer weights
     auto filter_names = op_desc->Input("Filter");
     for (const auto &filter_name : filter_names) {
-      auto *filter_var = scope->FindLocalVar(filter_name);
-      auto *filter_tensor = filter_var->GetMutable<phi::DenseTensor>();
-      phi::DenseTensor temp_tensor = *filter_tensor;
-      filter_tensor->clear();
-      framework::TransDataLayout(phi::DataLayout::kNCHW,
-                                 phi::DataLayout::kNHWC,
-                                 phi::CPUPlace{},
-                                 temp_tensor,
-                                 filter_tensor);
-    }
-    auto op_inputs = op_node->inputs;
-    for (auto *in_var_node : op_inputs) {
-      CHECK_EQ(in_var_node->IsVar(), true);
-      if (in_var_node->Var()->Persistable()) {
-        if (std::find(filter_names.cbegin(),
-                      filter_names.cend(),
-                      in_var_node->Var()->Name()) != filter_names.cend()) {
-          auto from_shape = in_var_node->Var()->GetShape();
-          in_var_node->Var()->SetShape(
-              {from_shape[0], from_shape[2], from_shape[3], from_shape[1]});
-        }
+      if (weights_shape_nhwc.count(filter_name) == 0) {
+        weights_shape_nhwc.insert(filter_name);
+        auto *filter_var = scope->FindLocalVar(filter_name);
+        auto *filter_tensor = filter_var->GetMutable<phi::DenseTensor>();
+        phi::DenseTensor temp_tensor;
+        framework::TransDataLayout(phi::DataLayout::kNCHW,
+                                   phi::DataLayout::kNHWC,
+                                   phi::CPUPlace{},
+                                   *filter_tensor,
+                                   &temp_tensor);
+        *filter_tensor = temp_tensor;
+        auto op_inputs = op_node->inputs;
+        for (auto *in_var_node : op_inputs) {
+          CHECK_EQ(in_var_node->IsVar(), true);
+          if (in_var_node->Var()->Persistable() &&
+              in_var_node->Var()->Name() == filter_name) {
+            auto from_shape = in_var_node->Var()->GetShape();
+            in_var_node->Var()->SetShape(
+                {from_shape[0], from_shape[2], from_shape[3], from_shape[1]});
+          }
+        }
       }
     }
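For reference, the `SetShape` call in the new code applies the NCHW-to-NHWC axis permutation {0, 2, 3, 1} to each converted filter variable's recorded shape. A self-contained illustration:

```cpp
#include <array>
#include <cstdint>
#include <iostream>

// NCHW -> NHWC is the axis permutation {0, 2, 3, 1}, exactly what the
// SetShape call in the hunk above applies to each converted filter var.
std::array<int64_t, 4> NchwToNhwc(const std::array<int64_t, 4> &d) {
  return {d[0], d[2], d[3], d[1]};
}

int main() {
  // A 64-out-channel, 32-in-channel 3x3 filter stored as NCHW.
  const auto nhwc = NchwToNhwc({64, 32, 3, 3});
  std::cout << nhwc[0] << " " << nhwc[1] << " " << nhwc[2] << " " << nhwc[3]
            << "\n";  // 64 3 3 32
}
```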
@@ -270,10 +270,9 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
         "conv_elementwise_add_fuse_pass",      //
 #endif                                         //
         "transpose_flatten_concat_fuse_pass",  //
-        // TODO(liuyuanle): rewrite this pass with new logic
-        // "conv2d_fusion_layout_transfer_pass", //
-        "auto_mixed_precision_pass",  //
-        "inplace_op_var_pass",        // should be the last pass.
+        "conv2d_fusion_layout_transfer_pass",  //
+        "auto_mixed_precision_pass",           //
+        "inplace_op_var_pass",                 // should be the last pass.
       });
   use_gpu_ = true;
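With the pass re-enabled by default in GpuPassStrategy, a user who hits problems with a specific model can still opt out at configuration time. A usage sketch based on the public AnalysisConfig/PaddlePassBuilder API; verify `DeletePass` against your Paddle version.

```cpp
#include "paddle/fluid/inference/api/paddle_inference_api.h"

// Sketch: drop the re-enabled pass for one predictor if it misbehaves.
void DisableLayoutTransfer(paddle::AnalysisConfig *config) {
  config->EnableUseGpu(/*memory_pool_init_size_mb=*/100, /*device_id=*/0);
  config->pass_builder()->DeletePass("conv2d_fusion_layout_transfer_pass");
}
```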