diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 0a9a86a62d3b5b6111407030f890f99f25493fa6..c3fc6667581acca6804d95fb57f157f2bb289cca 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -2460,6 +2460,7 @@ USE_TRT_CONVERTER(conv2d_transpose);
 USE_TRT_CONVERTER(leaky_relu);
 USE_TRT_CONVERTER(shuffle_channel);
 USE_TRT_CONVERTER(where);
+USE_TRT_CONVERTER(bitwise_not);
 USE_TRT_CONVERTER(one_hot);
 USE_TRT_CONVERTER(one_hot_v2);
 USE_TRT_CONVERTER(swish);
diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
index ea51e0136192b8938764dc9b843085eabca9f329..a47267ac3a562c1093fe917830ece17484c9ebf7 100755
--- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
@@ -32,6 +32,7 @@ list(
   shuffle_channel_op.cc
   fill_any_like_op.cc
   where_op.cc
+  bitwise_not_op.cc
   one_hot_op.cc
   swish_op.cc
   silu_op.cc
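A note for reviewers before the new converter: Paddle's bitwise_not acts as logical NOT on bool tensors and follows the two's-complement identity ~x = -x - 1 on integer tensors, which is exactly the split the converter below implements. A minimal NumPy sketch of those semantics (illustrative only, not part of the patch):

    import numpy as np

    # bool input: bitwise NOT degenerates to logical negation
    print(np.bitwise_not(np.array([True, False])))  # [False  True]

    # integer input: two's-complement identity ~x == -x - 1
    x = np.array([0, 1, -3], dtype=np.int32)
    print(np.bitwise_not(x))  # [-1 -2  2]
    print(-x - 1)             # [-1 -2  2]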
diff --git a/paddle/fluid/inference/tensorrt/convert/bitwise_not_op.cc b/paddle/fluid/inference/tensorrt/convert/bitwise_not_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..08324bb5d3003e3e67f73a12a243a56e81c5c866
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/bitwise_not_op.cc
@@ -0,0 +1,86 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <NvInfer.h>
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+class BitwiseNotConverter : public OpConverter {
+ public:
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope,
+                  bool test_mode) override {
+    VLOG(4) << "convert bitwise_not op to tensorrt layer";
+
+    framework::OpDesc op_desc(op, nullptr);
+    nvinfer1::ILayer* layer = nullptr;
+
+    auto* input_tensor = engine_->GetITensor(op_desc.Input("X")[0]);
+    nvinfer1::DataType data_type = input_tensor->getType();
+
+    // For bool inputs use UnaryOperation::kNOT; for integer inputs rely on
+    // the two's-complement identity ~x = -x - 1.
+    if (data_type == nvinfer1::DataType::kBOOL) {
+      layer = TRT_ENGINE_ADD_LAYER(
+          engine_, Unary, *input_tensor, nvinfer1::UnaryOperation::kNOT);
+    } else {
+      nvinfer1::Dims input_dims = input_tensor->getDimensions();
+
+      // Build a rank-matched constant tensor filled with -1. Under dynamic
+      // shape the concrete dims are unknown, so set every dimension to 1
+      // and let TensorRT broadcast the constant against the input.
+      nvinfer1::Dims neg_one_tensor_dims;
+      neg_one_tensor_dims.nbDims = input_dims.nbDims;
+      for (int i = 0; i < input_dims.nbDims; ++i) {
+        neg_one_tensor_dims.d[i] = 1;
+      }
+
+      // The weight buffer must stay alive until the engine is built, hence
+      // the heap allocation.
+      nvinfer1::Weights weights{nvinfer1::DataType::kINT32, new int(-1), 1};
+      auto neg_one_tensor =
+          TRT_ENGINE_ADD_LAYER(engine_, Constant, neg_one_tensor_dims, weights)
+              ->getOutput(0);
+
+      // ~x = x * (-1) + (-1)
+      auto mul_neg_one =
+          TRT_ENGINE_ADD_LAYER(engine_,
+                               ElementWise,
+                               *input_tensor,
+                               *neg_one_tensor,
+                               nvinfer1::ElementWiseOperation::kPROD);
+
+      layer = TRT_ENGINE_ADD_LAYER(engine_,
+                                   ElementWise,
+                                   *(mul_neg_one->getOutput(0)),
+                                   *neg_one_tensor,
+                                   nvinfer1::ElementWiseOperation::kSUM);
+    }
+
+    auto output_name = op_desc.Output("Out")[0];
+    RreplenishLayerAndOutput(layer, "bitwise_not", {output_name}, test_mode);
+  }
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+REGISTER_TRT_OP_CONVERTER(bitwise_not, BitwiseNotConverter);
diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc
index 1d1448b4166af03056e61586aa42b4ad18dbc401..f4d45e77256bf6e9a6500b7fc9ea14a41dca719f 100644
--- a/paddle/fluid/inference/tensorrt/op_teller.cc
+++ b/paddle/fluid/inference/tensorrt/op_teller.cc
@@ -1954,6 +1954,21 @@ struct SimpleOpTypeSetTeller : public Teller {
       }
     }
 
+    if (op_type == "bitwise_not") {
+#if !IS_TRT_VERSION_GE(8400)
+      auto* block = desc.Block();
+      auto x_var_name = desc.Input("X")[0];
+      auto* x_var_desc = block->FindVar(x_var_name);
+      auto dtype = x_var_desc->GetDataType();
+      if (dtype == framework::proto::VarType::BOOL ||
+          dtype == framework::proto::VarType::INT8 ||
+          dtype == framework::proto::VarType::UINT8) {
+        VLOG(3) << "bitwise_not with BOOL/INT8/UINT8 input requires TRT 8.4";
+        return false;
+      }
+#endif
+    }
+
     if (op_type == "one_hot" || op_type == "one_hot_v2") {
 #if IS_TRT_VERSION_LT(8510)
       VLOG(3) << "one_hot/one_hot_v2 is not supported when TensorRT < 8.5.1";
@@ -2778,6 +2793,7 @@ struct SimpleOpTypeSetTeller : public Teller {
       "fc",
       "shuffle_channel",
       "where",
+      "bitwise_not",
       "one_hot",
       "one_hot_v2",
       "swish",
@@ -2935,6 +2951,7 @@ struct SimpleOpTypeSetTeller : public Teller {
       "fc",
       "shuffle_channel",
       "where",
+      "bitwise_not",
       "one_hot",
       "one_hot_v2",
       "swish",
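The integer branch of the converter above lowers ~x into two elementwise TensorRT layers: kPROD by a constant -1 followed by kSUM with the same constant. The constant is built with a rank-matched all-ones shape so it broadcasts under dynamic shape, where the concrete input dims are unknown. A NumPy sketch of that decomposition (illustrative only, not part of the patch):

    import numpy as np

    def trt_style_bitwise_not(x):
        # Rank-matched constant of -1 (shape [1, 1, ...]), broadcast
        # against x just like the Constant layer in the converter.
        neg_one = np.full((1,) * x.ndim, -1, dtype=x.dtype)
        return x * neg_one + neg_one  # kPROD, then kSUM

    x = np.arange(-3, 4, dtype=np.int32).reshape(1, 7)
    assert (trt_style_bitwise_not(x) == np.bitwise_not(x)).all()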
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_bitwise_not.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_bitwise_not.py
new file mode 100644
index 0000000000000000000000000000000000000000..d779e1fce556714ee75b3d17885f6470b091aaad
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_bitwise_not.py
@@ -0,0 +1,148 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from functools import partial
+from typing import Any, Dict, List
+
+import numpy as np
+from program_config import ProgramConfig, TensorConfig
+from trt_layer_auto_scan_test import TrtLayerAutoScanTest
+
+import paddle.inference as paddle_infer
+
+
+class TrtConvertBitwiseNotTest(TrtLayerAutoScanTest):
+    def is_program_valid(self, program_config: ProgramConfig) -> bool:
+        return True
+
+    def sample_program_configs(self):
+        self.trt_param.workspace_size = 1073741824
+
+        def generate_input1(dims, batch, attrs: List[Dict[str, Any]]):
+            if dims == 1:
+                return np.random.random([32]).astype(np.bool_)
+            elif dims == 2:
+                return np.random.random([3, 32]).astype(np.int8)
+            elif dims == 3:
+                return np.random.random([3, 32, 32]).astype(np.int32)
+            else:
+                return np.random.random([batch, 3, 32, 32]).astype(np.int64)
+
+        for dims in [1, 2, 3, 4]:
+            for batch in [1, 4]:
+                self.dims = dims
+                dics = [{}]
+
+                ops_config = [
+                    {
+                        "op_type": "bitwise_not",
+                        "op_inputs": {"X": ["input_data"]},
+                        "op_outputs": {"Out": ["output_data"]},
+                        "op_attrs": dics[0],
+                    }
+                ]
+                ops = self.generate_op_config(ops_config)
+
+                program_config = ProgramConfig(
+                    ops=ops,
+                    weights={},
+                    inputs={
+                        "input_data": TensorConfig(
+                            data_gen=partial(generate_input1, dims, batch, dics)
+                        )
+                    },
+                    outputs=["output_data"],
+                )
+                yield program_config
+
+    def sample_predictor_configs(
+        self, program_config
+    ) -> (paddle_infer.Config, List[int], float):
+        def generate_dynamic_shape(attrs):
+            if self.dims == 1:
+                self.dynamic_shape.min_input_shape = {"input_data": [1]}
+                self.dynamic_shape.max_input_shape = {"input_data": [64]}
+                self.dynamic_shape.opt_input_shape = {"input_data": [32]}
+            elif self.dims == 2:
+                self.dynamic_shape.min_input_shape = {"input_data": [1, 16]}
+                self.dynamic_shape.max_input_shape = {"input_data": [4, 32]}
+                self.dynamic_shape.opt_input_shape = {"input_data": [3, 32]}
+            elif self.dims == 3:
+                self.dynamic_shape.min_input_shape = {"input_data": [1, 16, 16]}
+                self.dynamic_shape.max_input_shape = {"input_data": [4, 32, 32]}
+                self.dynamic_shape.opt_input_shape = {"input_data": [3, 32, 32]}
+            else:
+                self.dynamic_shape.min_input_shape = {
+                    "input_data": [1, 3, 16, 16]
+                }
+                self.dynamic_shape.max_input_shape = {
+                    "input_data": [4, 3, 32, 32]
+                }
+                self.dynamic_shape.opt_input_shape = {
+                    "input_data": [1, 3, 32, 32]
+                }
+
+        def clear_dynamic_shape():
+            self.dynamic_shape.min_input_shape = {}
+            self.dynamic_shape.max_input_shape = {}
+            self.dynamic_shape.opt_input_shape = {}
+
+        def generate_trt_nodes_num(attrs, dynamic_shape):
+            ver = paddle_infer.get_trt_compile_version()
+            trt_version = ver[0] * 1000 + ver[1] * 100 + ver[2] * 10
+            if trt_version >= 8400:
+                if self.dims == 1 and not dynamic_shape:
+                    return 0, 3
+                return 1, 2
+            else:
+                if (self.dims == 1 and not dynamic_shape) or (
+                    program_config.inputs["input_data"].dtype
+                    in ["bool", "int8", "uint8"]
+                ):
+                    return 0, 3
+                return 1, 2
+
+        attrs = [
+            program_config.ops[i].attrs for i in range(len(program_config.ops))
+        ]
+
+        # for static_shape
+        clear_dynamic_shape()
+        self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, False
+        ), 1e-5
+        self.trt_param.precision = paddle_infer.PrecisionType.Half
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, False
+        ), 1e-5
+
+        # for dynamic_shape
+        generate_dynamic_shape(attrs)
+        self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, True
+        ), 1e-5
+        self.trt_param.precision = paddle_infer.PrecisionType.Half
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, True
+        ), 1e-5
+
+    def test(self):
+        self.run_test()
+
+
+if __name__ == "__main__":
+    unittest.main()
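The test's generate_trt_nodes_num mirrors the C++ IS_TRT_VERSION_GE(8400) gate in op_teller.cc by packing the (major, minor, patch) tuple from get_trt_compile_version() into a single integer. A small sketch of that encoding (the helper name encode_trt_version is made up for illustration):

    def encode_trt_version(major, minor, patch):
        # Same scheme as the test above: 8400 <=> TensorRT 8.4.0.
        return major * 1000 + minor * 100 + patch * 10

    assert encode_trt_version(8, 4, 0) == 8400  # bitwise_not dtype gate
    assert encode_trt_version(8, 5, 1) == 8510  # one_hot gate in op_teller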