diff --git a/paddle/fluid/framework/ir/is_test_pass.cc b/paddle/fluid/framework/ir/is_test_pass.cc
index 47a3e46d076c318cc709d1403c22cb4920a8648b..38137d18db6bffbdcf93b05d23461fa08c82627d 100644
--- a/paddle/fluid/framework/ir/is_test_pass.cc
+++ b/paddle/fluid/framework/ir/is_test_pass.cc
@@ -25,18 +25,18 @@ class Graph;
 void IsTestPass::ApplyImpl(ir::Graph* graph) const {
   VLOG(3) << "Sets is_test attrbiute to true and if it is missing, inserts it "
              "for activations and pooling.";
-  auto op_list = {"pool2d",      "sigmoid",      "logsigmoid",
-                  "softshrink",  "exp",          "brelu",
-                  "pow",         "leaky_relu",   "stanh",
-                  "relu",        "tanh",         "tanh_shrink",
-                  "sqrt",        "abs",          "ceil",
-                  "elu",         "floor",        "cos",
-                  "sin",         "round",        "reciprocal",
-                  "hard_shrink", "hard_sigmoid", "relu6",
-                  "soft_relu",   "swish",        "thresholded_relu",
-                  "log",         "square",       "softplus",
-                  "softsign",    "silu",         "mish",
-                  "gumbel_softmax"};
+  auto op_list = {"pool2d",      "sigmoid",      "logsigmoid",
+                  "softshrink",  "exp",          "brelu",
+                  "pow",         "leaky_relu",   "stanh",
+                  "relu",        "tanh",         "tanh_shrink",
+                  "sqrt",        "abs",          "ceil",
+                  "elu",         "floor",        "cos",
+                  "sin",         "round",        "reciprocal",
+                  "hard_shrink", "hard_sigmoid", "relu6",
+                  "soft_relu",   "swish",        "thresholded_relu",
+                  "log",         "square",       "softplus",
+                  "softsign",    "silu",         "mish",
+                  "gumbel_softmax", "celu"};
 for (const Node* n : graph->Nodes()) {
     if (n->IsOp()) {
       auto* op = n->Op();
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 280427cb4c8f38ac0013fceabda4132c85bb4ab4..345e8e64011cd9ba238fa493093744365f9a6bf9 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -2318,11 +2318,14 @@ USE_TRT_CONVERTER(sum)
 USE_TRT_CONVERTER(shape)
 USE_TRT_CONVERTER(fill_constant)
 USE_TRT_CONVERTER(fused_token_prune)
+USE_TRT_CONVERTER(celu)
 USE_TRT_CONVERTER(layernorm_shift_partition)
 USE_TRT_CONVERTER(preln_layernorm_shift_partition)
 USE_TRT_CONVERTER(merge_layernorm)
 USE_TRT_CONVERTER(generic_plugin_creater)
 USE_TRT_CONVERTER(custom_plugin_creater)
+USE_TRT_CONVERTER(tanh_shrink)
+USE_TRT_CONVERTER(logsigmoid)
 USE_TRT_CONVERTER(lookup_table)
 USE_TRT_CONVERTER(expand_v2)
 #if PADDLE_WITH_CUSPARSELT && IS_TRT_VERSION_GE(8000)
diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
index a6c0a42de37abeea4de2d6b1228d90e49d18bb61..72e2e1c7f0c9511067795ef77a3fe5073027fd88 100644
--- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
@@ -76,7 +76,10 @@ list(
   shape_op.cc
   fill_constant_op.cc
   fused_token_prune_op.cc
+  celu_op.cc
   layernorm_shift_partition_op.cc
+  tanhshrink_op.cc
+  logsigmoid_op.cc
   preln_layernorm_shift_partition_op.cc
   merge_layernorm_op.cc
   generic_and_custom_plugin_creater.cc
diff --git a/paddle/fluid/inference/tensorrt/convert/celu_op.cc b/paddle/fluid/inference/tensorrt/convert/celu_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..c357b95a2b0c86451c640d1eeb76954225689ce6
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/celu_op.cc
@@ -0,0 +1,105 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+
+namespace nvinfer1 {
+class ILayer;
+}  // namespace nvinfer1
+namespace paddle {
+namespace framework {
+class Scope;
+
+namespace proto {
+class OpDesc;
+}  // namespace proto
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+class CeluOpConverter : public OpConverter {
+ public:
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope,
+                  bool test_mode) override {
+    VLOG(4) << "convert fluid celu op to tensorrt layer";
+
+    framework::OpDesc op_desc(op, nullptr);
+    // Declare inputs
+    int input_num = op_desc.Input("X").size();
+    PADDLE_ENFORCE_EQ(input_num,
+                      1,
+                      platform::errors::InvalidArgument(
+                          "The input X's size must equal to 1 in TRT celu op."
+                          " But received X's size %d.",
+                          input_num));
+    auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
+    // Get output
+    size_t output_num = op_desc.Output("Out").size();
+    PADDLE_ENFORCE_EQ(
+        output_num,
+        1UL,
+        platform::errors::InvalidArgument(
+            "The output Out's size must equal to 1 in TRT celu op. "
+            "But received Out's size %u.",
+            output_num));
+    // Get attrs
+    float alpha = PADDLE_GET_CONST(float, op_desc.GetAttr("alpha"));
+
+    nvinfer1::ILayer* layer = nullptr;
+
+    // celu(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1)) is built
+    // from elementwise/unary/activation layers below; the scalar constants
+    // alpha, 0 and 1 are broadcast against the input's rank.
+    int32_t rank = input->getDimensions().nbDims;
+    nvinfer1::Dims constant_shape;
+    constant_shape.nbDims = rank;
+    std::fill(constant_shape.d, constant_shape.d + rank, 1);
+    std::vector<float> weight_alpha_data{alpha};
+    std::vector<float> weight_zero_data{0.f};
+    std::vector<float> weight_one_data{1.f};
+    auto* alpha_data =
+        AddConstantLayer(weight_alpha_data.data(), constant_shape);
+    auto* constant_zero_data =
+        AddConstantLayer(weight_zero_data.data(), constant_shape);
+    auto* constant_one_data =
+        AddConstantLayer(weight_one_data.data(), constant_shape);
+
+    auto* input_div_with_alpha = Div(input, alpha_data);
+    auto* input_exp = TRT_ENGINE_ADD_LAYER(
+        engine_, Unary, *input_div_with_alpha, nvinfer1::UnaryOperation::kEXP);
+    auto* input_sub_with_one = Sub(input_exp->getOutput(0), constant_one_data);
+    auto* input_prod_with_alpha = Prod(input_sub_with_one, alpha_data);
+    auto* min_input = Min(input_prod_with_alpha, constant_zero_data);
+    auto* relu = TRT_ENGINE_ADD_LAYER(
+        engine_, Activation, *input, nvinfer1::ActivationType::kRELU);
+    layer = TRT_ENGINE_ADD_LAYER(engine_,
+                                 ElementWise,
+                                 *relu->getOutput(0),
+                                 *min_input,
+                                 nvinfer1::ElementWiseOperation::kSUM);
+
+    auto output_name = op_desc.Output("Out")[0];
+    RreplenishLayerAndOutput(layer, "celu", {output_name}, test_mode);
+  }
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+REGISTER_TRT_OP_CONVERTER(celu, CeluOpConverter);
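Note: TensorRT has no native celu activation, so the converter above relies on the decomposition celu(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1)). A quick NumPy sketch (illustrative only, not part of the patch) checking that this equals the piecewise definition for the alphas the unit test exercises:

```python
import numpy as np

def celu_decomposed(x, alpha):
    # Mirrors the TRT layer graph: relu(x) + min(0, alpha * (exp(x / alpha) - 1)).
    return np.maximum(0.0, x) + np.minimum(0.0, alpha * (np.exp(x / alpha) - 1.0))

def celu_piecewise(x, alpha):
    # Piecewise reference definition of celu (alpha > 0).
    return np.where(x > 0, x, alpha * (np.exp(x / alpha) - 1.0))

x = np.linspace(-5.0, 5.0, 101).astype(np.float32)
for alpha in (1.0, 2.0, 3.0):
    assert np.allclose(celu_decomposed(x, alpha), celu_piecewise(x, alpha), atol=1e-6)
```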
diff --git a/paddle/fluid/inference/tensorrt/convert/logsigmoid_op.cc b/paddle/fluid/inference/tensorrt/convert/logsigmoid_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..74f8d21eab702c0cc9efc1be6e04190873030e92
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/logsigmoid_op.cc
@@ -0,0 +1,80 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+
+namespace nvinfer1 {
+class ILayer;
+}  // namespace nvinfer1
+namespace paddle {
+namespace framework {
+class Scope;
+
+namespace proto {
+class OpDesc;
+}  // namespace proto
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+class LogSigmoidOpConverter : public OpConverter {
+ public:
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope,
+                  bool test_mode) override {
+    VLOG(4) << "convert fluid LogSigmoid op to tensorrt layer";
+
+    framework::OpDesc op_desc(op, nullptr);
+    // Declare inputs
+    int input_num = op_desc.Input("X").size();
+    PADDLE_ENFORCE_EQ(
+        input_num,
+        1,
+        platform::errors::InvalidArgument(
+            "The input X's size must equal to 1 in TRT LogSigmoid op."
+            " But received X's size %d.",
+            input_num));
+    auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
+    // Get output
+    size_t output_num = op_desc.Output("Out").size();
+    PADDLE_ENFORCE_EQ(
+        output_num,
+        1UL,
+        platform::errors::InvalidArgument(
+            "The output Out's size must equal to 1 in TRT LogSigmoid op. "
+            "But received Out's size %u.",
+            output_num));
+
+    nvinfer1::ILayer* layer = nullptr;
+    // logsigmoid(x) = log(sigmoid(x)): chain TensorRT's sigmoid activation
+    // into a log unary layer.
+    auto* sigmoid = TRT_ENGINE_ADD_LAYER(
+        engine_, Activation, *input, nvinfer1::ActivationType::kSIGMOID);
+    layer = TRT_ENGINE_ADD_LAYER(engine_,
+                                 Unary,
+                                 *(sigmoid->getOutput(0)),
+                                 nvinfer1::UnaryOperation::kLOG);
+    auto output_name = op_desc.Output("Out")[0];
+    RreplenishLayerAndOutput(layer, "logsigmoid", {output_name}, test_mode);
+  }
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+REGISTER_TRT_OP_CONVERTER(logsigmoid, LogSigmoidOpConverter);
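The logsigmoid converter is a direct composition, logsigmoid(x) = log(sigmoid(x)). A NumPy sketch (illustrative only) comparing that composition against a numerically stable formulation over a moderate input range:

```python
import numpy as np

def logsigmoid_composed(x):
    # What the TRT graph computes: a sigmoid activation followed by a log unary op.
    return np.log(1.0 / (1.0 + np.exp(-x)))

def logsigmoid_stable(x):
    # Stable reference: logsigmoid(x) = -log(1 + exp(-x)) = -logaddexp(0, -x).
    return -np.logaddexp(0.0, -x)

x = np.linspace(-10.0, 10.0, 201).astype(np.float32)
assert np.allclose(logsigmoid_composed(x), logsigmoid_stable(x), atol=1e-5)
```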
diff --git a/paddle/fluid/inference/tensorrt/convert/tanhshrink_op.cc b/paddle/fluid/inference/tensorrt/convert/tanhshrink_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..f31b2a50655e7d344753c80e7586b7f55fefa236
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/tanhshrink_op.cc
@@ -0,0 +1,81 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+
+namespace nvinfer1 {
+class ILayer;
+}  // namespace nvinfer1
+namespace paddle {
+namespace framework {
+class Scope;
+
+namespace proto {
+class OpDesc;
+}  // namespace proto
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+class TanhshrinkOpConverter : public OpConverter {
+ public:
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope,
+                  bool test_mode) override {
+    VLOG(4) << "convert fluid Tanhshrink op to tensorrt layer";
+
+    framework::OpDesc op_desc(op, nullptr);
+    // Declare inputs
+    int input_num = op_desc.Input("X").size();
+    PADDLE_ENFORCE_EQ(
+        input_num,
+        1,
+        platform::errors::InvalidArgument(
+            "The input X's size must equal to 1 in TRT Tanhshrink op."
+            " But received X's size %d.",
+            input_num));
+    auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
+    // Get output
+    size_t output_num = op_desc.Output("Out").size();
+    PADDLE_ENFORCE_EQ(
+        output_num,
+        1UL,
+        platform::errors::InvalidArgument(
+            "The output Out's size must equal to 1 in TRT Tanhshrink op. "
+            "But received Out's size %u.",
+            output_num));
+
+    nvinfer1::ILayer* layer = nullptr;
+    // tanh_shrink(x) = x - tanh(x): a tanh activation followed by an
+    // elementwise SUB from the original input.
+    auto* tanh = TRT_ENGINE_ADD_LAYER(
+        engine_, Activation, *input, nvinfer1::ActivationType::kTANH);
+    layer = TRT_ENGINE_ADD_LAYER(engine_,
+                                 ElementWise,
+                                 *input,
+                                 *(tanh->getOutput(0)),
+                                 nvinfer1::ElementWiseOperation::kSUB);
+    auto output_name = op_desc.Output("Out")[0];
+    RreplenishLayerAndOutput(layer, "tanh_shrink", {output_name}, test_mode);
+  }
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+REGISTER_TRT_OP_CONVERTER(tanh_shrink, TanhshrinkOpConverter);
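Likewise, tanh_shrink(x) = x - tanh(x) is realized as a tanh activation plus an elementwise SUB. A small NumPy sketch (illustrative only) of the function's expected behavior at both extremes:

```python
import numpy as np

def tanhshrink(x):
    # Same graph as the converter: x minus tanh(x).
    return x - np.tanh(x)

# Near zero, tanh(x) ~ x - x**3 / 3, so tanhshrink(x) ~ x**3 / 3.
x_small = np.float32(1e-2)
assert np.isclose(tanhshrink(x_small), x_small**3 / 3, rtol=1e-2)

# For large |x|, tanh saturates at +-1, so tanhshrink(x) ~ x - sign(x).
assert np.isclose(tanhshrink(np.float32(20.0)), 19.0)
```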
diff --git a/paddle/fluid/inference/tensorrt/convert/test_celu_op.cc b/paddle/fluid/inference/tensorrt/convert/test_celu_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..73b89ebb758077d465734f4c6b791650b84d168c
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/test_celu_op.cc
@@ -0,0 +1,48 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <gtest/gtest.h>
+
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+TEST(celu_op, test_celu) {
+  std::unordered_set<std::string> parameters;
+  framework::Scope scope;
+  TRTConvertValidation validator(10, parameters, scope, 1000);
+  validator.DeclInputVar("celu_input", nvinfer1::Dims3(3, 2, 2));
+  validator.DeclOutputVar("celu_out", nvinfer1::Dims3(3, 2, 2));
+
+  // Prepare Op description
+  framework::OpDesc desc;
+  desc.SetType("celu");
+  desc.SetInput("X", {"celu_input"});
+  desc.SetOutput("Out", {"celu_out"});
+
+  desc.SetAttr("alpha", 2.0f);
+
+  validator.SetOp(*desc.Proto());
+
+  validator.Execute(1);
+}
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+USE_OP(celu);
diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc
index 3e6f5779c6fa87e79f93998817df51748f1e32b5..6680b16a356b9900a20961c155f3aaafa054e8d5 100644
--- a/paddle/fluid/inference/tensorrt/op_teller.cc
+++ b/paddle/fluid/inference/tensorrt/op_teller.cc
@@ -77,16 +77,17 @@ struct SimpleOpTypeSetTeller : public Teller {
         desc.HasAttr("skip_quant"))
       return false;
     std::unordered_set<std::string> act_op_list = {
-        "relu",     "relu6",    "sigmoid",
-        "elu",      "selu",     "softsign",
-        "softplus", "stanh",    "thresholded_relu",
-        "exp",      "log",      "sqrt",
-        "abs",      "sin",      "cos",
-        "tan",      "tanh",     "sinh",
-        "cosh",     "asin",     "acos",
-        "atan",     "asinh",    "atanh",
-        "ceil",     "floor",    "erf",
-        "silu"};
+        "relu",     "relu6",    "sigmoid",
+        "elu",      "selu",     "softsign",
+        "softplus", "stanh",    "thresholded_relu",
+        "exp",      "log",      "sqrt",
+        "abs",      "sin",      "cos",
+        "tan",      "tanh",     "sinh",
+        "cosh",     "asin",     "acos",
+        "atan",     "asinh",    "atanh",
+        "ceil",     "floor",    "erf",
+        "silu",     "celu",     "tanh_shrink",
+        "logsigmoid"};
     if (act_op_list.find(op_type) != act_op_list.end()) {
       auto* block = desc.Block();
       if (block == nullptr) {
@@ -2212,6 +2213,7 @@ struct SimpleOpTypeSetTeller : public Teller {
       "shuffle_channel",
       "swish",
       "silu",
+      "celu",
       "split",
       "instance_norm",
       "gelu",
@@ -2268,6 +2270,8 @@
       "squeeze2",
       "unsqueeze2",
       "layernorm_shift_partition",
+      "tanh_shrink",
+      "logsigmoid",
       "preln_layernorm_shift_partition",
       "lookup_table",
       "lookup_table_v2",
@@ -2330,6 +2334,7 @@
      "shuffle_channel",
       "swish",
       "silu",
+      "celu",
       "split",
       "instance_norm",
       "gelu",
@@ -2387,6 +2392,8 @@
       "unsqueeze2",
       "fused_token_prune",
       "layernorm_shift_partition",
+      "tanh_shrink",
+      "logsigmoid",
       "preln_layernorm_shift_partition",
       "merge_layernorm",
       "lookup_table",
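With the converters registered and the new ops whitelisted in the teller, a celu (or logsigmoid / tanh_shrink) subgraph can be offloaded once TensorRT is enabled on the predictor. A sketch under the Paddle 2.x inference API (paths, shapes, and the tiny network are placeholders, not part of this patch):

```python
import paddle
import paddle.inference as paddle_infer

class Net(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.conv = paddle.nn.Conv2D(3, 3, 3, padding=1)

    def forward(self, x):
        return paddle.nn.functional.celu(self.conv(x), alpha=2.0)

# Export a small network to the static format the predictor consumes.
paddle.jit.save(
    Net(),
    "./celu_net",
    input_spec=[paddle.static.InputSpec(shape=[1, 3, 64, 64], dtype="float32")],
)

config = paddle_infer.Config("./celu_net.pdmodel", "./celu_net.pdiparams")
config.enable_use_gpu(256, 0)
config.enable_tensorrt_engine(
    workspace_size=1 << 30,
    max_batch_size=1,
    min_subgraph_size=1,
    precision_mode=paddle_infer.PrecisionType.Float32,
    use_static=False,
    use_calib_mode=False,
)
predictor = paddle_infer.create_predictor(config)
```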
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_celu.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_celu.py
new file mode 100644
index 0000000000000000000000000000000000000000..58a8fecfdeb9609c8b031833fc36ea6153804f3b
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_celu.py
@@ -0,0 +1,140 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from trt_layer_auto_scan_test import TrtLayerAutoScanTest
+from program_config import TensorConfig, ProgramConfig
+import numpy as np
+import paddle.inference as paddle_infer
+from functools import partial
+from typing import List, Dict, Any
+import unittest
+
+
+class TrtConvertCeluTest(TrtLayerAutoScanTest):
+    def is_program_valid(self, program_config: ProgramConfig) -> bool:
+        return True
+
+    def sample_program_configs(self):
+        def generate_input1(dims, attrs: List[Dict[str, Any]]):
+            if dims == 1:
+                return np.ones([3]).astype(np.float32)
+            elif dims == 2:
+                return np.ones([3, 64]).astype(np.float32)
+            elif dims == 3:
+                return np.ones([3, 64, 64]).astype(np.float32)
+            else:
+                return np.ones([1, 3, 64, 64]).astype(np.float32)
+
+        for dims in [1, 2, 3, 4]:
+            for alpha in [1.0, 2.0, 3.0]:
+                self.dims = dims
+
+                dics = [{"alpha": alpha}]
+
+                ops_config = [
+                    {
+                        "op_type": "celu",
+                        "op_inputs": {
+                            "X": ["input_data"],
+                        },
+                        "op_outputs": {"Out": ["output_data"]},
+                        "op_attrs": dics[0],
+                    }
+                ]
+                ops = self.generate_op_config(ops_config)
+
+                program_config = ProgramConfig(
+                    ops=ops,
+                    weights={},
+                    inputs={
+                        "input_data": TensorConfig(
+                            data_gen=partial(generate_input1, dims, dics)
+                        )
+                    },
+                    outputs=["output_data"],
+                )
+
+                yield program_config
+
+    def sample_predictor_configs(
+        self, program_config
+    ) -> (paddle_infer.Config, List[int], float):
+        def generate_dynamic_shape(attrs):
+            if self.dims == 1:
+                self.dynamic_shape.min_input_shape = {"input_data": [1]}
+                self.dynamic_shape.max_input_shape = {"input_data": [128]}
+                self.dynamic_shape.opt_input_shape = {"input_data": [64]}
+            elif self.dims == 2:
+                self.dynamic_shape.min_input_shape = {"input_data": [1, 32]}
+                self.dynamic_shape.max_input_shape = {"input_data": [4, 64]}
+                self.dynamic_shape.opt_input_shape = {"input_data": [3, 64]}
+            elif self.dims == 3:
+                self.dynamic_shape.min_input_shape = {"input_data": [1, 32, 32]}
+                self.dynamic_shape.max_input_shape = {
+                    "input_data": [10, 64, 64]
+                }
+                self.dynamic_shape.opt_input_shape = {"input_data": [3, 64, 64]}
+            else:
+                self.dynamic_shape.min_input_shape = {
+                    "input_data": [1, 3, 32, 32]
+                }
+                self.dynamic_shape.max_input_shape = {
+                    "input_data": [4, 3, 64, 64]
+                }
+                self.dynamic_shape.opt_input_shape = {
+                    "input_data": [1, 3, 64, 64]
+                }
+
+        def clear_dynamic_shape():
+            self.dynamic_shape.min_input_shape = {}
+            self.dynamic_shape.max_input_shape = {}
+            self.dynamic_shape.opt_input_shape = {}
+
+        def generate_trt_nodes_num(attrs, dynamic_shape):
+            if self.dims == 1:
+                return 0, 3
+            return 1, 2
+
+        attrs = [
+            program_config.ops[i].attrs for i in range(len(program_config.ops))
+        ]
+
+        # for static_shape
+        clear_dynamic_shape()
+        self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, False
+        ), 1e-5
+        self.trt_param.precision = paddle_infer.PrecisionType.Half
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, False
+        ), (1e-3, 1e-3)
+
+        # for dynamic_shape
+        generate_dynamic_shape(attrs)
+        self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, True
+        ), 1e-5
+        self.trt_param.precision = paddle_infer.PrecisionType.Half
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, True
+        ), (1e-3, 1e-3)
+
+    def test(self):
+        self.run_test()
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_logsigmoid.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_logsigmoid.py
new file mode 100755
index 0000000000000000000000000000000000000000..f9ba462e669cd0a26686d468fe8a6a8fcf4df19b
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_logsigmoid.py
@@ -0,0 +1,137 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from trt_layer_auto_scan_test import TrtLayerAutoScanTest
+from program_config import TensorConfig, ProgramConfig
+import numpy as np
+import paddle.inference as paddle_infer
+from functools import partial
+from typing import List, Dict, Any
+import unittest
+
+
+class TrtConvertLogSigmoidTest(TrtLayerAutoScanTest):
+    def is_program_valid(self, program_config: ProgramConfig) -> bool:
+        return True
+
+    def sample_program_configs(self):
+        def generate_input1(dims, attrs: List[Dict[str, Any]]):
+            if dims == 1:
+                return np.ones([3]).astype(np.float32)
+            elif dims == 2:
+                return np.ones([3, 64]).astype(np.float32)
+            elif dims == 3:
+                return np.ones([3, 64, 64]).astype(np.float32)
+            else:
+                return np.ones([1, 3, 64, 64]).astype(np.float32)
+
+        for dims in [1, 2, 3, 4]:
+            self.dims = dims
+
+            ops_config = [
+                {
+                    "op_type": "logsigmoid",
+                    "op_inputs": {
+                        "X": ["input_data"],
+                    },
+                    "op_outputs": {"Out": ["output_data"]},
+                    "op_attrs": {},
+                }
+            ]
+            ops = self.generate_op_config(ops_config)
+
+            program_config = ProgramConfig(
+                ops=ops,
+                weights={},
+                inputs={
+                    "input_data": TensorConfig(
+                        data_gen=partial(generate_input1, dims, {})
+                    )
+                },
+                outputs=["output_data"],
+            )
+
+            yield program_config
+
+    def sample_predictor_configs(
+        self, program_config
+    ) -> (paddle_infer.Config, List[int], float):
+        def generate_dynamic_shape(attrs):
+            if self.dims == 1:
+                self.dynamic_shape.min_input_shape = {"input_data": [1]}
+                self.dynamic_shape.max_input_shape = {"input_data": [128]}
+                self.dynamic_shape.opt_input_shape = {"input_data": [64]}
+            elif self.dims == 2:
+                self.dynamic_shape.min_input_shape = {"input_data": [1, 32]}
+                self.dynamic_shape.max_input_shape = {"input_data": [4, 64]}
+                self.dynamic_shape.opt_input_shape = {"input_data": [3, 64]}
+            elif self.dims == 3:
+                self.dynamic_shape.min_input_shape = {"input_data": [1, 32, 32]}
+                self.dynamic_shape.max_input_shape = {
+                    "input_data": [10, 64, 64]
+                }
+                self.dynamic_shape.opt_input_shape = {"input_data": [3, 64, 64]}
+            else:
+                self.dynamic_shape.min_input_shape = {
+                    "input_data": [1, 3, 32, 32]
+                }
+                self.dynamic_shape.max_input_shape = {
+                    "input_data": [4, 3, 64, 64]
+                }
+                self.dynamic_shape.opt_input_shape = {
+                    "input_data": [1, 3, 64, 64]
+                }
+
+        def clear_dynamic_shape():
+            self.dynamic_shape.min_input_shape = {}
+            self.dynamic_shape.max_input_shape = {}
+            self.dynamic_shape.opt_input_shape = {}
+
+        def generate_trt_nodes_num(attrs, dynamic_shape):
+            if self.dims == 1:
+                return 0, 3
+            return 1, 2
+
+        attrs = [
+            program_config.ops[i].attrs for i in range(len(program_config.ops))
+        ]
+
+        # for static_shape
+        clear_dynamic_shape()
+        self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, False
+        ), 1e-5
+        self.trt_param.precision = paddle_infer.PrecisionType.Half
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, False
+        ), (1e-3, 1e-3)
+
+        # for dynamic_shape
+        generate_dynamic_shape(attrs)
+        self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, True
+        ), 1e-5
+        self.trt_param.precision = paddle_infer.PrecisionType.Half
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, True
+        ), (1e-3, 1e-3)
+
+    def test(self):
+        self.run_test()
+
+
+if __name__ == "__main__":
+    unittest.main()
"input_data": [1, 3, 64, 64] + } + + def clear_dynamic_shape(): + self.dynamic_shape.min_input_shape = {} + self.dynamic_shape.max_input_shape = {} + self.dynamic_shape.opt_input_shape = {} + + def generate_trt_nodes_num(attrs, dynamic_shape): + if self.dims == 1: + return 0, 3 + return 1, 2 + + attrs = [ + program_config.ops[i].attrs for i in range(len(program_config.ops)) + ] + + # for static_shape + clear_dynamic_shape() + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False + ), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False + ), (1e-3, 1e-3) + + # for dynamic_shape + generate_dynamic_shape(attrs) + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True + ), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True + ), (1e-3, 1e-3) + + def test(self): + self.run_test() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_silu.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_silu.py index 73df326fb01ad2ef9d6490686126783ab9fab0ce..9ed3720a79f48e070a92eba0d9f73dd2ed98376a 100755 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_silu.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_silu.py @@ -21,7 +21,7 @@ from typing import Any, Dict, List import unittest -class TrtConvertSwishTest(TrtLayerAutoScanTest): +class TrtConvertSiluTest(TrtLayerAutoScanTest): def is_program_valid(self, program_config: ProgramConfig) -> bool: return True @@ -37,33 +37,32 @@ class TrtConvertSwishTest(TrtLayerAutoScanTest): return np.ones([1, 3, 64, 64]).astype(np.float32) for dims in [1, 2, 3, 4]: - for beta in [1.0, 2.0, 3.0]: - self.dims = dims - - ops_config = [ - { - "op_type": "silu", - "op_inputs": { - "X": ["input_data"], - }, - "op_outputs": {"Out": ["output_data"]}, - "op_attrs": {}, - } - ] - ops = self.generate_op_config(ops_config) - - program_config = ProgramConfig( - ops=ops, - weights={}, - inputs={ - "input_data": TensorConfig( - data_gen=partial(generate_input1, dims, {}) - ) - }, - outputs=["output_data"], - ) + self.dims = dims - yield program_config + ops_config = [ + { + "op_type": "silu", + "op_inputs": { + "X": ["input_data"], + }, + "op_outputs": {"Out": ["output_data"]}, + "op_attrs": {}, + } + ] + ops = self.generate_op_config(ops_config) + + program_config = ProgramConfig( + ops=ops, + weights={}, + inputs={ + "input_data": TensorConfig( + data_gen=partial(generate_input1, dims, {}) + ) + }, + outputs=["output_data"], + ) + + yield program_config def sample_predictor_configs( self, program_config diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_tanhshrink.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_tanhshrink.py new file mode 100755 index 0000000000000000000000000000000000000000..9fbdbb89f2398e6a141f2bd76aadddc1c3b6e1e4 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_tanhshrink.py @@ -0,0 +1,137 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_tanhshrink.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_tanhshrink.py
new file mode 100755
index 0000000000000000000000000000000000000000..9fbdbb89f2398e6a141f2bd76aadddc1c3b6e1e4
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_tanhshrink.py
@@ -0,0 +1,137 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from trt_layer_auto_scan_test import TrtLayerAutoScanTest
+from program_config import TensorConfig, ProgramConfig
+import numpy as np
+import paddle.inference as paddle_infer
+from functools import partial
+from typing import List, Dict, Any
+import unittest
+
+
+class TrtConvertTanhshrinkTest(TrtLayerAutoScanTest):
+    def is_program_valid(self, program_config: ProgramConfig) -> bool:
+        return True
+
+    def sample_program_configs(self):
+        def generate_input1(dims, attrs: List[Dict[str, Any]]):
+            if dims == 1:
+                return np.ones([3]).astype(np.float32)
+            elif dims == 2:
+                return np.ones([3, 64]).astype(np.float32)
+            elif dims == 3:
+                return np.ones([3, 64, 64]).astype(np.float32)
+            else:
+                return np.ones([1, 3, 64, 64]).astype(np.float32)
+
+        for dims in [1, 2, 3, 4]:
+            self.dims = dims
+
+            ops_config = [
+                {
+                    "op_type": "tanh_shrink",
+                    "op_inputs": {
+                        "X": ["input_data"],
+                    },
+                    "op_outputs": {"Out": ["output_data"]},
+                    "op_attrs": {},
+                }
+            ]
+            ops = self.generate_op_config(ops_config)
+
+            program_config = ProgramConfig(
+                ops=ops,
+                weights={},
+                inputs={
+                    "input_data": TensorConfig(
+                        data_gen=partial(generate_input1, dims, {})
+                    )
+                },
+                outputs=["output_data"],
+            )
+
+            yield program_config
+
+    def sample_predictor_configs(
+        self, program_config
+    ) -> (paddle_infer.Config, List[int], float):
+        def generate_dynamic_shape(attrs):
+            if self.dims == 1:
+                self.dynamic_shape.min_input_shape = {"input_data": [1]}
+                self.dynamic_shape.max_input_shape = {"input_data": [128]}
+                self.dynamic_shape.opt_input_shape = {"input_data": [64]}
+            elif self.dims == 2:
+                self.dynamic_shape.min_input_shape = {"input_data": [1, 32]}
+                self.dynamic_shape.max_input_shape = {"input_data": [4, 64]}
+                self.dynamic_shape.opt_input_shape = {"input_data": [3, 64]}
+            elif self.dims == 3:
+                self.dynamic_shape.min_input_shape = {"input_data": [1, 32, 32]}
+                self.dynamic_shape.max_input_shape = {
+                    "input_data": [10, 64, 64]
+                }
+                self.dynamic_shape.opt_input_shape = {"input_data": [3, 64, 64]}
+            else:
+                self.dynamic_shape.min_input_shape = {
+                    "input_data": [1, 3, 32, 32]
+                }
+                self.dynamic_shape.max_input_shape = {
+                    "input_data": [4, 3, 64, 64]
+                }
+                self.dynamic_shape.opt_input_shape = {
+                    "input_data": [1, 3, 64, 64]
+                }
+
+        def clear_dynamic_shape():
+            self.dynamic_shape.min_input_shape = {}
+            self.dynamic_shape.max_input_shape = {}
+            self.dynamic_shape.opt_input_shape = {}
+
+        def generate_trt_nodes_num(attrs, dynamic_shape):
+            if self.dims == 1:
+                return 0, 3
+            return 1, 2
+
+        attrs = [
+            program_config.ops[i].attrs for i in range(len(program_config.ops))
+        ]
+
+        # for static_shape
+        clear_dynamic_shape()
+        self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, False
+        ), 1e-5
+        self.trt_param.precision = paddle_infer.PrecisionType.Half
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, False
+        ), (1e-3, 1e-3)
+
+        # for dynamic_shape
+        generate_dynamic_shape(attrs)
+        self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, True
+        ), 1e-5
+        self.trt_param.precision = paddle_infer.PrecisionType.Half
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, True
+        ), (1e-3, 1e-3)
+
+    def test(self):
+        self.run_test()
+
+
+if __name__ == "__main__":
+    unittest.main()