Unverified commit 44044d80 authored by zhoutianzi666, committed by GitHub

[Paddle Inference] Move down the transfer_layout (#52997)

* add transfer_elim
* transfer layout elimination
Parent commit: f2ed4011
......@@ -107,6 +107,7 @@ pass_library(preln_residual_bias_fuse_pass inference)
pass_library(constant_folding_pass inference)
pass_library(auto_mixed_precision_pass inference)
pass_library(conv2d_fusion_layout_transfer_pass inference)
pass_library(transfer_layout_elim_pass inference)
pass_library(silu_fuse_pass inference)
pass_library(simplify_with_basic_ops_pass base)
pass_library(fc_elementwise_layernorm_fuse_pass base)
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/transfer_layout_elim_pass.h"
#include <string>
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
// (D) means deleted nodes
// (G) means generated node
// var0 var0' var0 var0'
// | | | |
// transfer_layout0(D) transfer_layout0'(D) | |
// | | | |
// var1(D) var1'(D) -> | |
// \ / \ /
// op_node -> op_node
// | |
// | var2
// | |
// | transfer_layout(G)
// | |
// var2 var2'(var2 + suffix)(G)
// | |
// other ops other ops
// Move the transfer_layout below op_node.
// transfer_info records the direction of the moved transfer_layout
// ("nchw_nhwc" or "nhwc_nchw") when the caller needs to know it.
void TransferLayoutElimPass::PutTranferlayoutAfterOp(
Node *op_node, ir::Graph *graph, std::string *transfer_info) const {
std::unordered_set<const Node *> remove_nodes;
// Ensure op_node has only one output!
int op_node_useful_output = 0;
  Node *var2 = nullptr;
for (auto ele : op_node->outputs) {
if (ele->outputs.size() >= 1) {
op_node_useful_output++;
var2 = ele;
}
}
CHECK_EQ(op_node_useful_output == 1, true);
  // group_norm has 3 inputs, but we do not require a transfer_layout before
  // its Bias and Scale inputs, so we extract useful_var1s from
  // op_node->inputs.
std::vector<Node *> useful_var1s;
for (auto var1 : op_node->inputs) {
// if (var1->inputs.size() >= 1 &&
// var1->inputs[0]->Op()->Type() == "transfer_layout") {
// useful_var1s.push_back(var1);
// }
useful_var1s.push_back(var1);
}
CHECK_EQ(useful_var1s.size() >= 1L, true);
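  // Clone the op desc of the first input's transfer_layout; the clone will
  // become the transfer_layout inserted after op_node.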
auto transfer_layout_opdesc = *useful_var1s[0]->inputs[0]->Op()->Proto();
auto block = useful_var1s[0]->inputs[0]->Op()->Block();
framework::OpDesc new_transfer_layout_desc(transfer_layout_opdesc, block);
new_transfer_layout_desc.SetInput("X", {var2->Name()});
  // Do not use the line below: it may cause SetShape to fail in the netron
  // display.
// auto *var2_desc = block->Var(var2->Name());
auto *var2_desc = var2->Var();
auto var2_shape = var2_desc->GetShape();
CHECK_EQ(var2_shape.size() >= 4L, true);
auto new_var2_shape = var2_shape;
std::string suffix = "_nchw_to_nhwc";
auto dst_layout = static_cast<DataLayout>(
new_transfer_layout_desc.GetAttrIfExists<int>("dst_layout"));
auto src_layout = static_cast<DataLayout>(
new_transfer_layout_desc.GetAttrIfExists<int>("src_layout"));
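  // Moving the transfer_layout below op_node changes the layout in which
  // var2 is produced, so its recorded shape must be permuted accordingly
  // (e.g. NCHW [N, C, H, W] becomes NHWC [N, H, W, C]).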
if (dst_layout == DataLayout::NCHW && src_layout == DataLayout::NHWC) {
suffix = "_nhwc_to_nchw";
if (transfer_info) *transfer_info = "nhwc_nchw";
new_var2_shape[1] = var2_shape[2];
new_var2_shape[2] = var2_shape[3];
new_var2_shape[3] = var2_shape[1];
} else if (dst_layout == DataLayout::NHWC && src_layout == DataLayout::NCHW) {
suffix = "_nchw_to_nhwc";
if (transfer_info) *transfer_info = "nchw_nhwc";
new_var2_shape[1] = var2_shape[3];
new_var2_shape[2] = var2_shape[1];
new_var2_shape[3] = var2_shape[2];
}
var2_desc->SetShape(new_var2_shape);
std::string var2_dot_name = var2->Name() + suffix;
new_transfer_layout_desc.SetOutput("Out", {var2_dot_name});
new_transfer_layout_desc.Flush();
auto *var2_dot_desc = block->Var(var2_dot_name);
var2_dot_desc->SetPersistable(false);
  // Set var2_dot_desc's shape to the original var2_shape.
var2_dot_desc->SetShape(var2_shape);
var2_dot_desc->SetDataType(var2->Var()->GetDataType());
auto var2_dot = graph->CreateVarNode(var2_dot_desc);
auto *new_transfer_layout_node =
graph->CreateOpNode(&new_transfer_layout_desc);
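  // Redirect every original consumer of var2 to var2_dot, the output of the
  // newly inserted transfer_layout, so downstream ops still receive data in
  // the original layout.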
for (auto other_op : var2->outputs) {
IR_NODE_UNLINK(var2, other_op);
other_op->Op()->RenameInput(var2->Name(), var2_dot_name);
IR_NODE_LINK_TO(var2_dot, other_op);
}
IR_NODE_LINK_TO(var2, new_transfer_layout_node);
IR_NODE_LINK_TO(new_transfer_layout_node, var2_dot);
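  // Bypass the old transfer_layouts in front of op_node: feed each var0
  // directly into op_node and mark transfer_layout0 and var1 for removal.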
for (auto var1 : useful_var1s) {
auto transfer_layout0_op = var1->inputs[0];
auto var0 = transfer_layout0_op->inputs[0];
IR_NODE_UNLINK(var0, transfer_layout0_op);
// IR_NODE_UNLINK(var1, op_node);
IR_NODE_LINK_TO(var0, op_node);
op_node->Op()->RenameInput(var1->Name(), var0->Name());
remove_nodes.emplace(transfer_layout0_op);
remove_nodes.emplace(var1);
}
GraphSafeRemoveNodes(graph, remove_nodes);
}
bool TransferLayoutElimPass::AllInputIsTransferlayout(
const ir::Node *op_node) const {
std::set<int> dst_layouts;
std::set<int> src_layouts;
auto *scope = param_scope();
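  // Every input var must be produced by exactly one transfer_layout and
  // consumed only by op_node; otherwise the transfer_layout cannot be moved
  // down safely.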
for (auto var : op_node->inputs) {
    // If this input is a 1D persistable tensor, we allow it to have no
    // transfer_layout in front of it, but this branch is temporarily disabled.
if (var->Var()->Persistable() && 0) {
auto var_dims =
scope->FindVar(var->Name())->GetMutable<phi::DenseTensor>()->dims();
if (var_dims.size() == 1) {
continue;
}
}
if (var->inputs.size() != 1L) {
return false;
}
if (var->outputs.size() != 1L) {
return false;
}
if (var->inputs[0]->Name() != "transfer_layout") {
return false;
}
auto transfer_layout_desc = var->inputs[0]->Op();
dst_layouts.insert(
transfer_layout_desc->GetAttrIfExists<int>("dst_layout"));
src_layouts.insert(
transfer_layout_desc->GetAttrIfExists<int>("src_layout"));
}
  // Make sure the dst_layout and src_layout attributes are the same for all
  // inputs so that these transfer_layouts can be moved down together.
return dst_layouts.size() == 1 && src_layouts.size() == 1;
}
// (D) means deleted nodes
// (G) means generated node
// var0
// |
// transfer_layout0(D)
// |
// var1
// |
// transfer_layout1(D ,op_node)
// |
// var2
// | | |
// op0 op1 op2
void TransferLayoutElimPass::ElimTwoTranferlayout(Node *op_node,
ir::Graph *graph,
bool *modify) const {
std::unordered_set<const Node *> remove_nodes;
auto var1 = op_node->inputs[0];
auto transfer_layout0 = var1->inputs[0];
auto var0 = transfer_layout0->inputs[0];
auto var2 = op_node->outputs[0];
CHECK_EQ(transfer_layout0->Name() == "transfer_layout", true);
CHECK_EQ(op_node->Name() == "transfer_layout", true);
int dst0 = transfer_layout0->Op()->GetAttrIfExists<int>("dst_layout");
int src0 = transfer_layout0->Op()->GetAttrIfExists<int>("src_layout");
int dst1 = op_node->Op()->GetAttrIfExists<int>("dst_layout");
int src1 = op_node->Op()->GetAttrIfExists<int>("src_layout");
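  // The two transfer_layouts cancel out only if they are exact inverses of
  // each other.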
if (!(dst0 == src1 && dst1 == src0)) {
    // We cannot eliminate these two transfer_layouts.
*modify = false;
return;
}
*modify = true;
remove_nodes.emplace(transfer_layout0);
remove_nodes.emplace(var1);
remove_nodes.emplace(op_node);
remove_nodes.emplace(var2);
for (auto next_op : var2->outputs) {
IR_NODE_LINK_TO(var0, next_op);
next_op->Op()->RenameInput(var2->Name(), var0->Name());
}
GraphSafeRemoveNodes(graph, remove_nodes);
}
void TransferLayoutElimPass::ApplyImpl(ir::Graph *graph) const {
const std::string pattern_name = "transfer_layout_elim_pass";
FusePassBase::Init(pattern_name, graph);
auto transfer_format = [&](std::string data_format) -> std::string {
if (data_format == "NCHW") {
return "NHWC";
} else if (data_format == "NHWC") {
return "NCHW";
}
return "";
};
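  // Iterate until the graph reaches a fixed point: each sweep over the
  // topologically sorted nodes applies at most one rewrite and then
  // re-sorts; stop when a full sweep changes nothing.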
while (true) {
auto op_node_sorted = framework::ir::TopologyVarientSort(
*graph, static_cast<framework::ir::SortKind>(0));
bool modify = false;
for (auto *op_node : op_node_sorted) {
if (!op_node->IsOp()) continue;
// For these Ops, you can move down the transfer_layout without changing
// any attribute!
std::vector<std::string> act_like_ops = {
"elementwise_add",
"hard_swish",
"silu",
};
bool is_act_like_op =
find(act_like_ops.begin(), act_like_ops.end(), op_node->Name()) !=
act_like_ops.end();
// For these Ops, you can move down the transfer_layout, but MUST change
// the data_format attribute!
std::vector<std::string> pool_like_ops = {
// "pool2d",
// "group_norm",
};
bool is_pool_like_op =
find(pool_like_ops.begin(), pool_like_ops.end(), op_node->Name()) !=
pool_like_ops.end();
// For these Ops, you can move down the transfer_layout, but MUST change
// the axis attribute!
std::vector<std::string> concat_like_ops = {
"concat",
};
bool is_concat_like_op = find(concat_like_ops.begin(),
concat_like_ops.end(),
op_node->Name()) != concat_like_ops.end();
bool is_elim_op = op_node->Name() == "transfer_layout";
if (!(is_act_like_op || is_concat_like_op || is_pool_like_op ||
is_elim_op))
continue;
if (AllInputIsTransferlayout(op_node)) {
if (is_concat_like_op) {
std::string transfer_info;
PutTranferlayoutAfterOp(op_node, graph, &transfer_info);
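          // concat's axis refers to the layout op_node originally ran in;
          // remap it to the layout op_node now receives (e.g. the NCHW
          // channel axis 1 becomes axis 3 in NHWC).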
int axis = op_node->Op()->GetAttrIfExists<int>("axis");
int modify_axis = axis;
if (transfer_info == "nhwc_nchw") {
if (axis == 1) {
modify_axis = 3;
} else if (axis == 2) {
modify_axis = 1;
} else if (axis == 3) {
modify_axis = 2;
}
} else if (transfer_info == "nchw_nhwc") {
if (axis == 1) {
modify_axis = 2;
} else if (axis == 2) {
modify_axis = 3;
} else if (axis == 3) {
modify_axis = 1;
}
}
op_node->Op()->SetAttr("axis", modify_axis);
modify = true;
break;
}
if (is_pool_like_op) {
PutTranferlayoutAfterOp(op_node, graph, nullptr);
op_node->Op()->SetAttr(
"data_format",
transfer_format(
op_node->Op()->GetAttrIfExists<std::string>("data_format")));
modify = true;
break;
}
if (is_act_like_op) {
PutTranferlayoutAfterOp(op_node, graph, nullptr);
modify = true;
break;
}
if (is_elim_op) {
ElimTwoTranferlayout(op_node, graph, &modify);
break;
}
}
}
if (!modify) break;
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(transfer_layout_elim_pass,
paddle::framework::ir::TransferLayoutElimPass);
// The capability registration below is required for test_transfer_elim_pass to pass.
REGISTER_PASS_CAPABILITY(transfer_layout_elim_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination());
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
class Graph;
class TransferLayoutElimPass : public FusePassBase {
public:
virtual ~TransferLayoutElimPass() {}
protected:
void ApplyImpl(ir::Graph *graph) const override;
bool AllInputIsTransferlayout(const Node *op_node) const;
void PutTranferlayoutAfterOp(Node *op_node,
ir::Graph *graph,
std::string *transfer_info) const;
void ElimTwoTranferlayout(Node *op_node,
ir::Graph *graph,
bool *modify) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -264,6 +264,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
#endif //
"transpose_flatten_concat_fuse_pass", //
"conv2d_fusion_layout_transfer_pass", //
"transfer_layout_elim_pass",
"auto_mixed_precision_pass", //
"inplace_op_var_pass", // should be the last pass.
});
......
......@@ -216,6 +216,7 @@ if(WITH_GPU AND TENSORRT_FOUND)
set_tests_properties(test_fc_fuse_pass PROPERTIES TIMEOUT 240)
set_tests_properties(test_reverse_roll_fuse_pass PROPERTIES TIMEOUT 120)
set_tests_properties(test_inplace_op_pass PROPERTIES TIMEOUT 120)
set_tests_properties(test_transfer_layout_elim_pass PROPERTIES TIMEOUT 300)
set_tests_properties(test_simplify_with_basic_ops_pass_autoscan
PROPERTIES TIMEOUT 60)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
from functools import partial
import hypothesis.strategies as st
import numpy as np
from auto_scan_test import CutlassAutoScanTest, PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig
os.environ['NVIDIA_TF32_OVERRIDE'] = '0'
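# In the op configs below, dst_layout=1 corresponds to NHWC and
# src_layout=2 to NCHW.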
class TestTransferElimPass0(PassAutoScanTest):
r"""input0 input1
| |
transfer_layout transfer_layout
| |
transfer_layout_out0 transfer_layout_out1
\ /
elementwise_add
|
elementwise_add_out
"""
def sample_predictor_configs(self, program_config):
# for gpu
config = self.create_inference_config(use_gpu=True)
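        # After the pass, the two transfer_layouts in front of elementwise_add
        # are replaced by a single transfer_layout after it, hence the
        # expected op list below.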
yield config, ["elementwise_add", "transfer_layout"], (1e-4, 1e-5)
def is_program_valid(self, prog_config):
return True
def sample_program_config(self, draw):
transfer_layout0 = OpConfig(
"transfer_layout",
inputs={"X": ["input0"]},
outputs={"Out": ["transfer_layout_out0"]},
dst_layout=1,
src_layout=2,
)
transfer_layout1 = OpConfig(
"transfer_layout",
inputs={"X": ["input1"]},
outputs={"Out": ["transfer_layout_out1"]},
dst_layout=1,
src_layout=2,
)
add_op = OpConfig(
"elementwise_add",
inputs={
"X": ["transfer_layout_out0"],
"Y": ["transfer_layout_out1"],
},
outputs={"Out": ["elementwise_add_out"]},
axis=-1,
)
ops = [transfer_layout0, transfer_layout1, add_op]
x_shape = draw(
st.lists(
st.integers(min_value=10, max_value=100), min_size=4, max_size=4
)
)
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"input0": TensorConfig(shape=x_shape),
"input1": TensorConfig(shape=x_shape),
},
outputs=["elementwise_add_out"],
)
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=30,
passes=["transfer_layout_elim_pass"],
)
class TestTransferElimPass1(PassAutoScanTest):
r"""input0 input1
| |
transfer_layout transfer_layout
| |
transfer_layout_out0 transfer_layout_out1
\ /
elementwise_add
|
elementwise_add_out
|
transfer_layout
|
transfer_layout_out2
"""
def sample_predictor_configs(self, program_config):
# for gpu
config = self.create_inference_config(use_gpu=True)
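        # The moved-down transfer_layout cancels with the trailing
        # transfer_layout2, so only elementwise_add remains after the pass.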
yield config, ["elementwise_add"], (1e-4, 1e-5)
def is_program_valid(self, prog_config):
return True
def sample_program_config(self, draw):
transfer_layout0 = OpConfig(
"transfer_layout",
inputs={"X": ["input0"]},
outputs={"Out": ["transfer_layout_out0"]},
dst_layout=1,
src_layout=2,
)
transfer_layout1 = OpConfig(
"transfer_layout",
inputs={"X": ["input1"]},
outputs={"Out": ["transfer_layout_out1"]},
dst_layout=1,
src_layout=2,
)
add_op = OpConfig(
"elementwise_add",
inputs={
"X": ["transfer_layout_out0"],
"Y": ["transfer_layout_out1"],
},
outputs={"Out": ["elementwise_add_out"]},
axis=-1,
)
transfer_layout2 = OpConfig(
"transfer_layout",
inputs={"X": ["elementwise_add_out"]},
outputs={"Out": ["transfer_layout_out2"]},
dst_layout=2,
src_layout=1,
)
ops = [transfer_layout0, transfer_layout1, add_op, transfer_layout2]
x_shape = draw(
st.lists(
st.integers(min_value=10, max_value=100), min_size=4, max_size=4
)
)
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"input0": TensorConfig(shape=x_shape),
"input1": TensorConfig(shape=x_shape),
},
outputs=["transfer_layout_out2"],
)
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=30,
passes=["transfer_layout_elim_pass"],
)
class TestTransferElimPass2(PassAutoScanTest):
r"""input0 input1
| |
transfer_layout transfer_layout
| |
transfer_layout_out0 transfer_layout_out1
\ /
concat
|
concat_out
"""
def sample_predictor_configs(self, program_config):
# for gpu
config = self.create_inference_config(use_gpu=True)
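        # The transfer_layouts are moved below concat (one remains) and
        # concat's axis is remapped to the new layout.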
yield config, ["concat", "transfer_layout"], (1e-4, 1e-5)
def is_program_valid(self, prog_config):
return True
def sample_program_config(self, draw):
        # nchw -> nhwc (dst_layout=1 is NHWC, src_layout=2 is NCHW)
transfer_layout0 = OpConfig(
"transfer_layout",
inputs={"X": ["input0"]},
outputs={"Out": ["transfer_layout_out0"]},
dst_layout=1,
src_layout=2,
)
transfer_layout1 = OpConfig(
"transfer_layout",
inputs={"X": ["input1"]},
outputs={"Out": ["transfer_layout_out1"]},
dst_layout=1,
src_layout=2,
)
concat_op = OpConfig(
"concat",
inputs={"X": ["transfer_layout_out0", "transfer_layout_out1"]},
outputs={"Out": ["concat_out"]},
axis=1,
)
ops = [transfer_layout0, transfer_layout1, concat_op]
x_shape = draw(
st.lists(
st.integers(min_value=10, max_value=100), min_size=4, max_size=4
)
)
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"input0": TensorConfig(shape=x_shape),
"input1": TensorConfig(shape=x_shape),
},
outputs=["concat_out"],
)
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=30,
passes=["transfer_layout_elim_pass"],
)
class TestTransferElimPass3(CutlassAutoScanTest):
def sample_program_configs(self, *args, **kwargs):
def generate_input(input_shape):
return (np.random.random(input_shape) - 0.5).astype(np.float32)
        # src_layout should be NCHW, because that is the layout of the model's input
for dst_layout, src_layout in [[1, 2]]:
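            # Sweep every concat axis so the pass's axis remapping is exercised.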
for axis in [0, 1, 2, 3]:
ops_config = [
{
"op_type": "transfer_layout",
"op_inputs": {"X": ["input0"]},
"op_outputs": {"Out": ["transfer_layout_out0"]},
"op_attrs": {
"dst_layout": dst_layout,
"src_layout": src_layout,
},
},
{
"op_type": "transfer_layout",
"op_inputs": {"X": ["input1"]},
"op_outputs": {"Out": ["transfer_layout_out1"]},
"op_attrs": {
"dst_layout": dst_layout,
"src_layout": src_layout,
},
# nchw -> nhwc
},
{
"op_type": "concat",
"op_inputs": {
"X": [
"transfer_layout_out0",
"transfer_layout_out1",
]
},
"op_outputs": {"Out": ["concat_out0"]},
"op_attrs": {"axis": axis},
},
]
ops = self.generate_op_config(ops_config)
input_shape = [12, 13, 14, 15]
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"input0": TensorConfig(
data_gen=partial(generate_input, input_shape)
),
"input1": TensorConfig(
data_gen=partial(generate_input, input_shape)
),
},
outputs=["concat_out0"],
)
yield program_config
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_gpu=True)
config.enable_use_gpu(256, 0)
yield config, (1e-2, 1e-2)
def test(self, *args, **kwargs):
self.run_test(quant=False, *args, **kwargs)
if __name__ == "__main__":
unittest.main()