From 5efc4146d3d3db1a7789364d1d04d444cabf5368 Mon Sep 17 00:00:00 2001 From: zhupengyang Date: Fri, 20 May 2022 11:35:07 +0800 Subject: [PATCH] add arg_max tensorrt converter, fix identity_scale_op_clean_pass (#42850) --- .../ir/identity_scale_op_clean_pass.cc | 69 ++++++++------- .../ir_passes/tensorrt_subgraph_pass.cc | 5 ++ .../fluid/inference/api/analysis_predictor.cc | 1 + .../inference/api/paddle_pass_builder.cc | 3 +- .../inference/tensorrt/convert/CMakeLists.txt | 1 + .../inference/tensorrt/convert/arg_max_op.cc | 73 ++++++++++++++++ paddle/fluid/inference/tensorrt/op_teller.cc | 12 +++ .../test_identity_scale_clean_pass.py | 64 ++++++++++++++ .../ir/inference/test_trt_convert_arg_max.py | 85 +++++++++++++++++++ .../ir/inference/test_trt_convert_scale.py | 2 +- 10 files changed, 284 insertions(+), 31 deletions(-) create mode 100644 paddle/fluid/inference/tensorrt/convert/arg_max_op.cc create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_identity_scale_clean_pass.py create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_arg_max.py diff --git a/paddle/fluid/framework/ir/identity_scale_op_clean_pass.cc b/paddle/fluid/framework/ir/identity_scale_op_clean_pass.cc index 290fbe3ea13..6b91ea4e360 100644 --- a/paddle/fluid/framework/ir/identity_scale_op_clean_pass.cc +++ b/paddle/fluid/framework/ir/identity_scale_op_clean_pass.cc @@ -13,8 +13,8 @@ // limitations under the License. #include "paddle/fluid/framework/ir/identity_scale_op_clean_pass.h" - #include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/op_version_registry.h" namespace paddle { namespace framework { @@ -29,55 +29,62 @@ void IdentityScaleOpCleanPass::ApplyImpl(ir::Graph* graph) const { // -> // pre_op -> scale_out GraphPatternDetector detector; - auto pre_op = detector.mutable_pattern()->NewNode("pre_op")->assert_is_op(); - auto scale_in = detector.mutable_pattern() - ->NewNode("scale_in") - ->assert_is_op_input("scale") - ->AsIntermediate(); + auto scale_in = + detector.mutable_pattern() + ->NewNode("scale_in") + ->assert_is_op_input("scale") + ->assert_more([](Node* x) { return x->outputs.size() == 1UL; }); auto scale_op = detector.mutable_pattern() ->NewNode("scale_fuse") ->assert_is_op("scale") ->assert_op_attr("scale", 1.) ->assert_op_attr("bias", 0.); - auto scale_out = - detector.mutable_pattern() - ->NewNode("scale_out") - ->assert_is_op_output("scale") - // scale's output var should has only one consumer, or it can't be - // removed. - ->assert_more([](Node* x) { return x->outputs.size() == 1UL; }); + auto scale_out = detector.mutable_pattern() + ->NewNode("scale_out") + ->assert_is_op_output("scale"); - pre_op->LinksTo({scale_in}); scale_op->LinksFrom({scale_in}).LinksTo({scale_out}); + int found_subgraph_count = 0; GraphPatternDetector::handle_t handler = [&]( const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { Node* scale_op_var = subgraph.at(scale_op); Node* scale_in_var = subgraph.at(scale_in); Node* scale_out_var = subgraph.at(scale_out); - Node* pre_op_var = subgraph.at(pre_op); - // Link pre_op directly to scale_out const std::string scale_in_name = scale_in_var->Name(); const std::string scale_out_name = scale_out_var->Name(); // Remove links in graph GraphSafeRemoveNodes(graph, {scale_in_var, scale_op_var}); - // Modify proto message - auto* pre_op_desc = pre_op_var->Op(); - for (auto& parameter : *pre_op_desc->Proto()->mutable_outputs()) { - auto* arguments = parameter.mutable_arguments(); - auto it = std::find(arguments->begin(), arguments->end(), scale_in_name); - PADDLE_ENFORCE_NE( - it, arguments->end(), - platform::errors::NotFound( - "Can not find input variable(%s) from scale op(%s).", - scale_in_name, pre_op_desc->Type())); - *it = scale_out_name; + // Modify pre_op_desc + // Link pre_op directly to scale_out + for (auto& node : graph->Nodes()) { + if (node->IsOp()) { + auto* op_desc = node->Op(); + auto out_vars_map = op_desc->Outputs(); + for (auto out_var_map : out_vars_map) { + auto names = out_var_map.second; + bool reset = false; + for (size_t i = 0; i < names.size(); i++) { + if (names[i] == scale_in_name) { + reset = true; + names[i] = scale_out_name; + break; + } + } + if (reset) { + op_desc->SetOutput(out_var_map.first, names); + op_desc->Flush(); + IR_NODE_LINK_TO(node, scale_out_var); + break; + } + } + } } - - IR_NODE_LINK_TO(pre_op_var, scale_out_var); + found_subgraph_count++; }; detector(graph, handler); + AddStatis(found_subgraph_count); } } // namespace ir @@ -86,3 +93,7 @@ void IdentityScaleOpCleanPass::ApplyImpl(ir::Graph* graph) const { REGISTER_PASS(identity_scale_op_clean_pass, paddle::framework::ir::IdentityScaleOpCleanPass); +REGISTER_PASS_CAPABILITY(identity_scale_op_clean_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination().EQ( + "scale", 0)); diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index bc7dc9704ac..b73eb624db8 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -139,6 +139,11 @@ void TensorRtSubgraphPass::CreateTensorRTOp( block_desc.Proto()->set_parent_idx(-1); block_desc.Proto()->set_idx(0); LOG(INFO) << "--- detect a sub-graph with " << subgraph.size() << " nodes"; + for (auto node : subgraph) { + if (node->NodeType() == Node::Type::kOperation) { + VLOG(5) << "trt subgraph has op: " << (node->Op()->Type()); + } + } for (auto *node : subgraph) { auto *new_block_op = new_block->AppendOp(); diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 13f81059df5..09a5bbddba8 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -1782,6 +1782,7 @@ USE_TRT_CONVERTER(gather); USE_TRT_CONVERTER(anchor_generator); USE_TRT_CONVERTER(yolo_box); USE_TRT_CONVERTER(yolo_box_head); +USE_TRT_CONVERTER(arg_max); USE_TRT_CONVERTER(roi_align); USE_TRT_CONVERTER(affine_channel); USE_TRT_CONVERTER(multiclass_nms); diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index fdb979283f7..f9ec41f6c83 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -82,7 +82,8 @@ void PaddlePassBuilder::AppendAnalysisPass(const std::string &pass) { void PaddlePassBuilder::ClearPasses() { passes_.clear(); } const std::vector kTRTSubgraphPasses({ - "adaptive_pool2d_convert_global_pass", + "identity_scale_op_clean_pass", // + "adaptive_pool2d_convert_global_pass", // "shuffle_channel_detect_pass", // "quant_conv2d_dequant_fuse_pass", // "delete_fill_constant_op_pass", // diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 22610ece34e..1910e2f6eb9 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -37,6 +37,7 @@ nv_library(tensorrt_converter anchor_generator_op.cc yolo_box_op.cc yolo_box_head_op.cc + arg_max_op.cc roi_align_op.cc affine_channel_op.cc multiclass_nms_op.cc diff --git a/paddle/fluid/inference/tensorrt/convert/arg_max_op.cc b/paddle/fluid/inference/tensorrt/convert/arg_max_op.cc new file mode 100644 index 00000000000..14975e48164 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/arg_max_op.cc @@ -0,0 +1,73 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" + +namespace paddle { +namespace framework { +class Scope; +namespace proto { +class OpDesc; +} // namespace proto +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace inference { +namespace tensorrt { + +class ArgMaxOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + VLOG(3) << "convert a fluid arg_max op to tensorrt topk layer"; + framework::OpDesc op_desc(op, nullptr); + // Declare inputs + auto* input = engine_->GetITensor(op_desc.Input("X")[0]); + auto input_dims = input->getDimensions(); + int rank = input_dims.nbDims; + int axis = op_desc.HasAttr("axis") + ? BOOST_GET_CONST(int64_t, op_desc.GetAttr("axis")) + : -1; + if (axis > 0) axis -= 1; + if (axis < 0) axis += rank; + auto* topk_layer = TRT_ENGINE_ADD_LAYER( + engine_, TopK, *input, nvinfer1::TopKOperation::kMAX, 1, 1 << axis); + + auto output_name = op_desc.Output("Out")[0]; + bool keepdims = BOOST_GET_CONST(bool, op_desc.GetAttr("keepdims")); + if (keepdims) { + RreplenishLayerAndOutput(topk_layer, "arg_max", + {output_name + "_value", output_name}, + test_mode); + } else { + auto squeeze_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *topk_layer->getOutput(1)); + auto dims = input_dims; + dims.nbDims -= 1; + for (int i = axis; i < dims.nbDims; i++) { + dims.d[i] = dims.d[i + 1]; + } + squeeze_layer->setReshapeDimensions(dims); + RreplenishLayerAndOutput(squeeze_layer, "arg_max", {output_name}, + test_mode); + } + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(arg_max, ArgMaxOpConverter); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index cbe151294db..690bc173c77 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -102,6 +102,7 @@ struct SimpleOpTypeSetTeller : public Teller { "gather_nd", "yolo_box", "yolo_box_head", + "arg_max", "roi_align", "affine_channel", "nearest_interp", @@ -169,6 +170,7 @@ struct SimpleOpTypeSetTeller : public Teller { "gather_nd", "yolo_box", "yolo_box_head", + "arg_max", "roi_align", "affine_channel", "nearest_interp", @@ -644,6 +646,16 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, if (!has_attrs) return false; } + if (op_type == "arg_max") { + if (with_dynamic_shape) return false; + int axis = desc.HasAttr("axis") + ? BOOST_GET_CONST(int64_t, desc.GetAttr("axis")) + : -1; + bool flatten = BOOST_GET_CONST(bool, desc.GetAttr("flatten")); + int dtype = BOOST_GET_CONST(int, desc.GetAttr("dtype")); + if (axis == 0 || flatten || dtype != 2) return false; + } + if (op_type == "affine_channel") { if (!desc.HasAttr("data_layout")) return false; auto data_layout = framework::StringToDataLayout( diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_identity_scale_clean_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_identity_scale_clean_pass.py new file mode 100644 index 00000000000..8cacb6d29af --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_identity_scale_clean_pass.py @@ -0,0 +1,64 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from auto_scan_test import PassAutoScanTest +from program_config import TensorConfig, ProgramConfig, OpConfig +import paddle.inference as paddle_infer +import unittest +import hypothesis.strategies as st + + +class TestIdentityScaleCleanPass(PassAutoScanTest): + def sample_predictor_configs(self, program_config): + config = self.create_trt_inference_config() + config.enable_tensorrt_engine( + max_batch_size=8, + workspace_size=0, + min_subgraph_size=0, + precision_mode=paddle_infer.PrecisionType.Float32, + use_static=False, + use_calib_mode=False) + yield config, ['relu'], (1e-5, 1e-5) + + def sample_program_config(self, draw): + bias_after_scale = draw(st.booleans()) + n = draw(st.integers(min_value=1, max_value=4)) + c = draw(st.integers(min_value=1, max_value=20)) + h = draw(st.integers(min_value=1, max_value=20)) + w = draw(st.integers(min_value=1, max_value=20)) + + relu_op = OpConfig( + "relu", inputs={"X": ["relu_x"]}, outputs={"Out": ["relu_out"]}) + scale_op = OpConfig( + "scale", + inputs={"X": ["relu_out"]}, + outputs={"Out": ["scale_out"]}, + bias=0., + scale=1., + bias_after_scale=True) + program_config = ProgramConfig( + ops=[relu_op, scale_op], + weights={}, + inputs={"relu_x": TensorConfig(shape=[n, c, h, w])}, + outputs=["scale_out"]) + return program_config + + def test(self): + self.run_and_statis( + max_examples=25, passes=["identity_scale_op_clean_pass"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_arg_max.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_arg_max.py new file mode 100644 index 00000000000..719e4488569 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_arg_max.py @@ -0,0 +1,85 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from trt_layer_auto_scan_test import TrtLayerAutoScanTest +from program_config import TensorConfig, ProgramConfig +import unittest +import numpy as np +import paddle.inference as paddle_infer +from functools import partial +from typing import List + + +class TrtConvertArgMaxTest(TrtLayerAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + input_shape = program_config.inputs["arg_max_input"].shape + axis = program_config.ops[0].attrs["axis"] + if axis < 0: + axis += len(input_shape) + if len(input_shape) <= axis or axis == 0: + return False + return True + + def sample_program_configs(self): + def generate_input(rank, batch): + dims = [batch] + for i in range(rank - 1): + dims.append((i + 1) * 8) + size = np.prod(dims) + return (np.arange(size) % 10 - 5).astype("float32").reshape(dims) + + for rank in [3, 4]: + for batch in [1, 4]: + for axis in [-1, 0, 1, 2, 3]: + for keepdims in [True, False]: + flatten = False + dtype = 2 + ops_config = [{ + "op_type": "arg_max", + "op_inputs": { + "X": ["arg_max_input"] + }, + "op_outputs": { + "Out": ["arg_max_out"] + }, + "op_attrs": { + "axis": axis, + "keepdims": keepdims, + "flatten": flatten, + "dtype": dtype + } + }] + ops = self.generate_op_config(ops_config) + program_config = ProgramConfig( + ops=ops, + weights={}, + inputs={ + "arg_max_input": TensorConfig(data_gen=partial( + generate_input, rank, batch)) + }, + outputs=["arg_max_out"]) + yield program_config + + def sample_predictor_configs( + self, program_config) -> (paddle_infer.Config, List[int], float): + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + self.trt_param.workspace_size = 1024000 + yield self.create_inference_config(), [1, 2], 1e-5 + + def test(self): + self.run_test() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_scale.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_scale.py index d607a43739e..75783450e86 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_scale.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_scale.py @@ -42,7 +42,7 @@ class TrtConvertScaleTest(TrtLayerAutoScanTest): for num_input in [0, 1]: for dims in [1, 2, 3, 4]: for batch in [1, 2]: - for scale in [0.1, 1.0]: + for scale in [0.1, -1.0]: for bias in [0.0, 1.2]: for bias_after_scale in [False, True]: self.num_input = num_input -- GitLab