From 1dbbe20ec38d208035bfb037496672957d732c2b Mon Sep 17 00:00:00 2001 From: ccrrong <101700995+ccrrong@users.noreply.github.com> Date: Wed, 29 Jun 2022 14:00:06 +0800 Subject: [PATCH] add equal trt converter (#43461) * add comparisons trt converter --- .../fluid/inference/api/analysis_predictor.cc | 1 + .../inference/tensorrt/convert/CMakeLists.txt | 1 + .../inference/tensorrt/convert/equal_op.cc | 94 ++++++++++ paddle/fluid/inference/tensorrt/op_teller.cc | 21 +++ .../ir/inference/test_trt_convert_equal.py | 166 ++++++++++++++++++ 5 files changed, 283 insertions(+) create mode 100644 paddle/fluid/inference/tensorrt/convert/equal_op.cc create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_equal.py diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 7bdd1c957b7..8a2083ea226 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -2078,6 +2078,7 @@ USE_TRT_CONVERTER(transformer_input_convert) USE_TRT_CONVERTER(cast) USE_TRT_CONVERTER(recover_padding) USE_TRT_CONVERTER(remove_padding) +USE_TRT_CONVERTER(equal); USE_TRT_CONVERTER(top_k) USE_TRT_CONVERTER(top_k_v2) USE_TRT_CONVERTER(squeeze2) diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 4e728dc74f7..c999c009605 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -62,6 +62,7 @@ list( transformer_input_convert_op.cc cast_op.cc remove_padding_op.cc + equal_op.cc recover_padding_op.cc preln_residual_bias.cc c_allreduce_op.cc diff --git a/paddle/fluid/inference/tensorrt/convert/equal_op.cc b/paddle/fluid/inference/tensorrt/convert/equal_op.cc new file mode 100644 index 00000000000..2e29c0f7007 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/equal_op.cc @@ -0,0 +1,94 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.h" + +namespace paddle { +namespace framework { +class Scope; + +namespace proto { +class OpDesc; +} // namespace proto +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace inference { +namespace tensorrt { + +class EqualOpConverter : public OpConverter { + public: + EqualOpConverter() {} + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, + bool test_mode) override { +#if IS_TRT_VERSION_GE(8000) + framework::OpDesc op_desc(op, nullptr); + nvinfer1::ILayer* layer = nullptr; + + auto* X = engine_->GetITensor(op_desc.Input("X").front()); + auto* Y = engine_->GetITensor(op_desc.Input("Y").front()); + nvinfer1::Dims dims_x = X->getDimensions(); + nvinfer1::Dims dims_y = Y->getDimensions(); + + int axis = BOOST_GET_CONST(int, op_desc.GetAttr("axis")); + if (axis < 0) { + axis = std::abs(dims_x.nbDims - dims_y.nbDims); + } + auto output_name = op_desc.Output("Out")[0]; + nvinfer1::IShuffleLayer* expand_layer = nullptr; + if (dims_x.nbDims > dims_y.nbDims) { + nvinfer1::Dims expand_shape; + expand_shape.nbDims = dims_x.nbDims; + for (int i = 0; i < expand_shape.nbDims; i++) { + expand_shape.d[i] = 1; + } + for (int i = 0; i < dims_y.nbDims; i++) { + expand_shape.d[i + axis] = dims_y.d[i]; + } + expand_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *Y); + expand_layer->setReshapeDimensions(expand_shape); + Y = expand_layer->getOutput(0); + } else if (dims_x.nbDims < dims_y.nbDims) { + nvinfer1::Dims expand_shape; + expand_shape.nbDims = dims_y.nbDims; + for (int i = 0; i < expand_shape.nbDims; i++) { + expand_shape.d[i] = 1; + } + for (int i = 0; i < dims_x.nbDims; i++) { + expand_shape.d[i + axis] = dims_x.d[i]; + } + expand_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *X); + expand_layer->setReshapeDimensions(expand_shape); + X = expand_layer->getOutput(0); + } + + layer = TRT_ENGINE_ADD_LAYER( + engine_, ElementWise, *X, *Y, nvinfer1::ElementWiseOperation::kEQUAL); + RreplenishLayerAndOutput(layer, "equal", {output_name}, test_mode); +#else + PADDLE_THROW( + platform::errors::Fatal("ElementWise Equal Operation is only supported " + "on TRT 8 or higher version.")); +#endif + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(equal, EqualOpConverter); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 7465c24e310..1ee748afe50 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -110,6 +110,7 @@ struct SimpleOpTypeSetTeller : public Teller { "elementwise_mul", "elementwise_div", "elementwise_pow", + "equal", "dropout", "prelu", "conv2d_transpose", @@ -213,6 +214,7 @@ struct SimpleOpTypeSetTeller : public Teller { "elementwise_mul", "elementwise_div", "elementwise_pow", + "equal", "dropout", "prelu", "conv2d_transpose", @@ -2049,6 +2051,25 @@ bool OpTeller::Tell(const framework::ir::Node* node, } #endif + if (op_type == "equal") { +#if !IS_TRT_VERSION_GE(8000) + VLOG(3) << "compare is not supported when TensorRT < 8.0"; + return false; +#else + int axis = BOOST_GET_CONST(int, desc.GetAttr("axis")); + if (axis == 0) { + return false; + } + auto* block = desc.Block(); + if (block == nullptr) { + VLOG(3) << "The block desc is nullptr, we can't continue to analyze. " + "Developers need to check whether block_desc is passed in " + "the pass."; + return false; + } +#endif + } + if ((*teller)(op_type, desc, use_no_calib_int8)) return true; } diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_equal.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_equal.py new file mode 100644 index 00000000000..285ac3f2202 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_equal.py @@ -0,0 +1,166 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons +from program_config import TensorConfig, ProgramConfig +import unittest +import numpy as np +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set + + +class TrtConvertElementwiseTest_one_input_corner_case(TrtLayerAutoScanTest): + + def is_program_valid(self, program_config: ProgramConfig) -> bool: + attrs = [ + program_config.ops[i].attrs for i in range(len(program_config.ops)) + ] + if attrs[0]['axis'] == 0: + return false + ver = paddle_infer.get_trt_compile_version() + if ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 < 8415: + return False + return True + + def sample_program_configs(self): + + def generate_input(shape): + return np.random.random(shape).astype(np.float32) + + for batch in [1, 2, 4]: + for shape in [[batch, 1], [batch, 1, 32], [batch, 1, 16, 32]]: + for axis in [-1 if len(shape) == 1 else 1]: + self.dims = len(shape) + dics = [{"axis": axis}, {"in_dtype": 0, "out_dtype": 5}] + ops_config = [{ + "op_type": "equal", + "op_inputs": { + "X": ["input_data1"], + "Y": ["input_data2"] + }, + "op_outputs": { + "Out": ["compare_output_data"] + }, + "op_attrs": dics[0] + }, { + "op_type": "cast", + "op_inputs": { + "X": ["compare_output_data"] + }, + "op_outputs": { + "Out": ["output_data"] + }, + "op_attrs": dics[1] + }] + ops = self.generate_op_config(ops_config) + + program_config = ProgramConfig( + ops=ops, + weights={}, + inputs={ + "input_data1": + TensorConfig( + data_gen=partial(generate_input, shape)), + "input_data2": + TensorConfig( + data_gen=partial(generate_input, shape)) + }, + outputs=["output_data"]) + + yield program_config + + def sample_predictor_configs( + self, program_config) -> (paddle_infer.Config, List[int], float): + + def generate_dynamic_shape(attrs): + # The input.dims[1] must be equal to the weight's length. + if self.dims == 2: + self.dynamic_shape.min_input_shape = { + "input_data1": [1, 1], + "input_data2": [1, 1] + } + self.dynamic_shape.max_input_shape = { + "input_data1": [4, 1], + "input_data2": [4, 1] + } + self.dynamic_shape.opt_input_shape = { + "input_data1": [2, 1], + "input_data2": [2, 1] + } + elif self.dims == 3: + self.dynamic_shape.min_input_shape = { + "input_data1": [1, 1, 4], + "input_data2": [1, 1, 4] + } + self.dynamic_shape.max_input_shape = { + "input_data1": [4, 1, 256], + "input_data2": [1, 1, 256] + } + self.dynamic_shape.opt_input_shape = { + "input_data1": [2, 1, 16], + "input_data2": [2, 1, 16] + } + elif self.dims == 4: + self.dynamic_shape.min_input_shape = { + "input_data1": [1, 1, 4, 4], + "input_data2": [1, 1, 4, 4] + } + self.dynamic_shape.max_input_shape = { + "input_data1": [4, 1, 128, 256], + "input_data2": [4, 1, 128, 256] + } + self.dynamic_shape.opt_input_shape = { + "input_data1": [2, 1, 32, 16], + "input_data2": [2, 1, 32, 16] + } + + def clear_dynamic_shape(): + self.dynamic_shape.max_input_shape = {} + self.dynamic_shape.min_input_shape = {} + self.dynamic_shape.opt_input_shape = {} + + def generate_trt_nodes_num(attrs, dynamic_shape): + if self.dims == 1: + return 0, 3 + return 1, 2 + + attrs = [ + program_config.ops[i].attrs for i in range(len(program_config.ops)) + ] + + # for static_shape + clear_dynamic_shape() + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-5 + + # for dynamic_shape + generate_dynamic_shape(attrs) + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True), 1e-5 + + def test(self): + self.run_test() + + +if __name__ == "__main__": + unittest.main() -- GitLab