Unverified commit 017af746, authored by W Wilber, committed by GitHub

[Inference] inplace all reshape op (#49146)

Parent 8705a79d
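This commit adds a new inference IR pass, inplace_op_var_pass, which rewrites eligible reshape2 ops (and flatten_contiguous_range ops that can be mapped to reshape2) to run in place: the op's output variable is renamed to its input variable and downstream consumers are updated accordingly, so the reshape reuses its input's variable instead of producing a separate output. A minimal illustrative sketch of the effect on a graph (hypothetical op and variable names, not taken from the diff):

// Before the pass:  scale -> scale_out -> reshape2 -> reshape_out -> relu
// After the pass:   scale -> scale_out -> reshape2 -> scale_out   -> relu
// reshape2's Out is renamed to its X, and relu's input is renamed to match.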
......@@ -88,6 +88,7 @@ pass_library(conv_elementwise_add_act_fuse_pass inference)
pass_library(conv_elementwise_add2_act_fuse_pass inference)
pass_library(conv_elementwise_add_fuse_pass inference)
pass_library(transpose_flatten_concat_fuse_pass inference)
pass_library(inplace_op_var_pass inference)
pass_library(identity_scale_op_clean_pass base)
pass_library(sync_batch_norm_pass base)
pass_library(runtime_context_cache_pass base)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/inplace_op_var_pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
class Graph;
void InplaceOpVarPass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init("inplace_op_var", graph);
int found_subgraph_count = 0;
MapToReshape(graph);
auto nodes = graph->Nodes();
auto is_valid_reshape = [](Node* node) {
// Several corner cases need to be considered here; please refer to
// https://github.com/PaddlePaddle/Paddle/pull/49146 for details.
if (node->IsOp() && node->Op()->Type() == "reshape2") {
auto x_name = node->Op()->Input("X").front();
for (auto* var_node : node->inputs) {
if (var_node->Name() == x_name) {
if (!var_node->Var()->Persistable() && var_node->outputs.size() == 1)
return true;
}
}
}
return false;
};
// Record the input and output names of every valid reshape2 op in block 0.
// If such a name is also used in another block, we can not inplace that reshape op.
std::unordered_set<std::string> var_names, deny_var_names;
for (auto* node : nodes) {
if (is_valid_reshape(node)) {
for (auto n : node->inputs) var_names.insert(n->Name());
for (auto n : node->outputs) var_names.insert(n->Name());
}
}
for (size_t i = 1; i < graph->SubGraphsSize(); ++i) {
auto sub_graph = graph->GetSubGraph(i);
for (auto* node : sub_graph->Nodes()) {
if (node->IsOp()) {
for (auto var_node : node->inputs) {
if (var_names.count(var_node->Name()))
deny_var_names.insert(var_node->Name());
}
for (auto var_node : node->outputs) {
if (var_names.count(var_node->Name()))
deny_var_names.insert(var_node->Name());
}
}
}
}
// Inplace all eligible reshape ops.
auto topo_nodes = TopologySortOperations(*graph);
for (auto* node : topo_nodes) {
if (!is_valid_reshape(node)) continue;
auto* op_node = node->Op();
auto input_name = op_node->Input("X")[0];
auto output_name = op_node->Output("Out")[0];
if (deny_var_names.count(input_name) || deny_var_names.count(output_name)) {
continue;
}
++found_subgraph_count;
for (auto* out_var : node->outputs) {
if (out_var->Name() == output_name) {
out_var->RenameVar(input_name);
for (auto* next_op : out_var->outputs) {
next_op->Op()->RenameInput(output_name, input_name);
next_op->Op()->Flush();
}
}
}
op_node->RenameOutput(output_name, input_name);
op_node->Flush();
}
AddStatis(found_subgraph_count);
}
void InplaceOpVarPass::MapToReshape(ir::Graph* graph) const {
// Map flatten_contiguous_range ops to reshape2 where the two are equivalent.
for (auto* node : graph->Nodes()) {
if (node->IsOp() && node->Op()->Type() == "flatten_contiguous_range") {
auto* op_node = node->Op();
auto start_axis = PADDLE_GET_CONST(int, op_node->GetAttr("start_axis"));
auto stop_axis = PADDLE_GET_CONST(int, op_node->GetAttr("stop_axis"));
auto input_name = op_node->Input("X")[0];
auto* block = op_node->Block();
auto input_shape = block->FindVar(input_name)->GetShape();
if (start_axis == 1 && stop_axis == 3 && input_shape.size() == 4 &&
input_shape[2] == 1 && input_shape[3] == 1) {
op_node->SetType("reshape2");
op_node->SetAttr("shape", std::vector<int>{0, -1});
op_node->Flush();
}
}
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(inplace_op_var_pass, paddle::framework::ir::InplaceOpVarPass);
REGISTER_PASS_CAPABILITY(inplace_op_var_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"reshape2", 0));
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
class Graph;
class InplaceOpVarPass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
virtual ~InplaceOpVarPass() = default;
void MapToReshape(ir::Graph* graph) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -188,7 +188,7 @@ const std::vector<std::string> kGpuLowerPrecisionPasses{
"fc_fuse_pass",
"fc_elementwise_layernorm_fuse_pass",
"embedding_eltwise_layernorm_fuse_pass",
"inplace_op_var_pass"};
const std::vector<std::string> kTrtLowerPrecisionPasses{
"simplify_with_basic_ops_pass",
......@@ -255,8 +255,10 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
#endif //
"transpose_flatten_concat_fuse_pass", //
"constant_folding_pass", //
"auto_mixed_precision_pass", //
"conv2d_fusion_layout_transfer_pass", //
"auto_mixed_precision_pass"
"auto_mixed_precision_pass", //
"inplace_op_var_pass", // should be the last pass.
});
use_gpu_ = true;
......
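Since inplace_op_var_pass is appended to kGpuLowerPrecisionPasses and added as the last pass of GpuPassStrategy, it runs automatically for GPU inference configs. A minimal usage sketch with the C++ inference API (the model path is hypothetical, and the commented delete_pass call is shown only for illustration):

#include "paddle_inference_api.h"  // header name per the inference library install

int main() {
  paddle_infer::Config config("./model_dir");  // hypothetical model directory
  config.EnableUseGpu(100 /*memory pool MB*/, 0 /*device id*/);
  // inplace_op_var_pass is already the last pass in the GPU strategy; to exclude
  // it (e.g. when debugging shape issues) it can be removed explicitly:
  // config.pass_builder()->DeletePass("inplace_op_var_pass");
  auto predictor = paddle_infer::CreatePredictor(config);
  return 0;
}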
......@@ -59,6 +59,17 @@ class ReshapeOp : public framework::OperatorWithKernel {
platform::errors::InvalidArgument(
"Output(Out) of ReshapeOp should not be null."));
if (ctx->IsRuntime()) {
auto *x_var =
PADDLE_GET(framework::Variable *, ctx->GetInputVarPtrs("X")[0]);
auto *out_var =
PADDLE_GET(framework::Variable *, ctx->GetOutputVarPtrs("Out")[0]);
// Inplaced: X and Out are the same variable, so infer shape can not run here.
if (x_var == out_var) {
return;
}
}
if (ctx->HasInputs("ShapeTensor")) {
// top priority shape
auto ShapeTensor = ctx->Inputs("ShapeTensor");
......
......@@ -178,6 +178,7 @@ if(WITH_GPU AND TENSORRT_FOUND)
set_tests_properties(test_fc_fuse_pass PROPERTIES TIMEOUT 240)
set_tests_properties(test_reverse_roll_fuse_pass PROPERTIES TIMEOUT 120)
set_tests_properties(test_inplace_op_pass PROPERTIES TIMEOUT 120)
set_tests_properties(test_simplify_with_basic_ops_pass_autoscan
PROPERTIES TIMEOUT 60)
......
......@@ -282,7 +282,6 @@ def create_fake_model(program_config):
var_desc.set_type(core.VarDesc.VarType.LOD_TENSOR)
var_desc.set_dtype(convert_np_dtype_to_dtype_(tensor_config.dtype))
var_desc.set_shape(tensor_config.shape)
print(f"name: {name}; shape: {tensor_config.shape}")
var_desc.set_need_check_feed(True)
if tensor_config.lod is not None:
var_desc.set_lod_level(len(tensor_config.lod))
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
import hypothesis.strategies as st
import numpy as np
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig
import paddle.fluid.core as core
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestInplaceOpPass(PassAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
return True
def sample_program_config(self, draw):
def generate_input():
return np.random.random(x_shape).astype(np.float32)
def generate_tmp1(val):
return np.array([val]).astype(np.int32)
def generate_tmp2(val):
return np.array([val]).astype(np.int32)
def generate_tmp3(val):
return np.array([val]).astype(np.int32)
def generate_shape(val):
return np.array(val).astype(np.int32)
x_shape = draw(
st.lists(
st.integers(min_value=1, max_value=10), min_size=4, max_size=4
)
)
shape = [0, -1, x_shape[-1]]
scale_op = OpConfig(
"scale",
inputs={"X": ["scale_in"]},
outputs={"Out": ["scale_out"]},
scale=1.3,
bias=0.1,
bias_after_scale=False,
)
test_case = draw(
st.sampled_from(
["simple_reshape", "shape_tensor1", "shape_tensor2"]
)
)
if test_case == "simple_reshape":
reshape_op = OpConfig(
"reshape2",
inputs={"X": ["scale_out"]},
outputs={
"Out": ["reshape_out"],
"XShape": ["reshape_xshape_out"],
},
shape=shape,
)
ops = [scale_op, reshape_op]
program_config = ProgramConfig(
ops=ops,
inputs={
"scale_in": TensorConfig(data_gen=partial(generate_input)),
},
weights={},
outputs=["reshape_out"],
)
return program_config
elif test_case == "shape_tensor1":
shape = [-1, -1, x_shape[-1]]
reshape_op = OpConfig(
"reshape2",
inputs={
"X": ["scale_out"],
"ShapeTensor": ["tmp1", "tmp2", "tmp3"],
},
outputs={
"Out": ["reshape_out"],
"XShape": ["reshape_xshape_out"],
},
shape=shape,
)
ops = [scale_op, reshape_op]
program_config = ProgramConfig(
ops=ops,
inputs={
"scale_in": TensorConfig(data_gen=partial(generate_input)),
"tmp1": TensorConfig(
data_gen=partial(generate_tmp1, x_shape[0])
),
"tmp2": TensorConfig(
data_gen=partial(generate_tmp2, x_shape[1] * x_shape[2])
),
"tmp3": TensorConfig(
data_gen=partial(generate_tmp3, x_shape[-1])
),
},
weights={},
outputs=["reshape_out"],
)
return program_config
else:
shape = [0, -1, x_shape[-1]]
reshape_op = OpConfig(
"reshape2",
inputs={"X": ["scale_out"], "Shape": ["shape"]},
outputs={
"Out": ["reshape_out"],
"XShape": ["reshape_xshape_out"],
},
shape=shape,
)
ops = [scale_op, reshape_op]
program_config = ProgramConfig(
ops=ops,
inputs={
"scale_in": TensorConfig(data_gen=partial(generate_input)),
"shape": TensorConfig(
data_gen=partial(
generate_shape,
[x_shape[0], x_shape[1] * x_shape[2], x_shape[3]],
)
),
},
weights={},
outputs=["reshape_out"],
)
return program_config
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_gpu=True)
yield config, ['scale', 'reshape2'], (1e-5, 1e-5)
def add_ignore_pass_case(self):
pass
def test(self):
self.run_and_statis(
quant=False,
passes=["inplace_op_var_pass"],
)
if __name__ == "__main__":
unittest.main()