diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc
index 048424e306ee00b045fdc3ad52f606a575a8320b..26bca9b1e54ecb854b944e7b1311d4877f2c6796 100644
--- a/paddle/fluid/inference/analysis/ir_pass_manager.cc
+++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc
@@ -141,6 +141,10 @@ void IRPassManager::CreatePasses(Argument *argument,
       pass->Set("optim_input_shape",
                 new std::map<std::string, std::vector<int>>(
                     argument->optim_input_shape()));
+      bool with_dynamic_shape = argument->max_input_shape().size() > 0 &&
+                                argument->min_input_shape().size() > 0 &&
+                                argument->optim_input_shape().size() > 0;
+      pass->Set("with_dynamic_shape", new bool(with_dynamic_shape));
       pass->Set("trt_disabled_ops", new std::vector<std::string>(
                                         argument->tensorrt_disabled_ops()));
       pass->Set("trt_use_dla", new bool(argument->tensorrt_use_dla()));
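For context on the hunk above: the three shape maps are only populated when the caller supplies TensorRT dynamic-shape information through the inference config, so `with_dynamic_shape` is purely a function of what the user configured. A minimal C++ sketch of that caller side, assuming the public AnalysisConfig API; the model path and the tensor name "data" are placeholders:

#include <map>
#include <string>
#include <vector>
#include "paddle_inference_api.h"

paddle::AnalysisConfig MakeTrtDynamicShapeConfig() {
  paddle::AnalysisConfig config;
  config.SetModel("./model_dir");  // placeholder model directory
  config.EnableUseGpu(100 /*initial pool, MB*/, 0 /*device id*/);
  config.EnableTensorRtEngine(1 << 30 /*workspace*/, 1 /*max batch*/,
                              3 /*min subgraph size*/,
                              paddle::AnalysisConfig::Precision::kFloat32,
                              false /*use_static*/, false /*use_calib_mode*/);
  // Supplying all three maps is exactly what makes with_dynamic_shape true
  // in IRPassManager::CreatePasses above.
  std::map<std::string, std::vector<int>> min_shape{{"data", {1, 6, 64, 64}}};
  std::map<std::string, std::vector<int>> max_shape{{"data", {8, 6, 64, 64}}};
  std::map<std::string, std::vector<int>> opt_shape{{"data", {4, 6, 64, 64}}};
  config.SetTRTDynamicShapeInfo(min_shape, max_shape, opt_shape);
  return config;
}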
diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
index 535f082dccd279ada2da8729dfd25275ebaf94aa..791cd15169a7bfad91d3fc1e6b52da810281346d 100644
--- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
+++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
@@ -40,6 +40,7 @@ void analysis::TensorRtSubgraphPass::ApplyImpl(
   auto use_calib_mode = Get<bool>("use_calib_mode");
   bool no_calib_int8 = enable_int8 && !(use_calib_mode);
   auto trt_disabled_ops = Get<std::vector<std::string>>("trt_disabled_ops");
+  auto with_dynamic_shape = Get<bool>("with_dynamic_shape");
   auto teller = [&](const framework::ir::Node *node) {
     if (!node->IsOp() || !node->Op()) return false;
     if (find(trt_disabled_ops.begin(), trt_disabled_ops.end(),
@@ -48,8 +49,8 @@ void analysis::TensorRtSubgraphPass::ApplyImpl(
               << " is disabled by config in TensorRT";
       return false;
     }
-    return tensorrt::OpTeller::Global().Tell(node->Op()->Type(), *node->Op(),
-                                             no_calib_int8);
+    return tensorrt::OpTeller::Global().Tell(node, no_calib_int8,
+                                             with_dynamic_shape);
   };

   framework::ir::SubGraphFuser fuser(
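Two things change in the teller above: Tell() now receives the whole IR node instead of just the op type and OpDesc, which gives tellers room to inspect graph context later, and it takes a dynamic-shape flag so an op can be accepted in one shape mode and rejected in the other (flatten, below, is static-shape only). Outside the lambda the call reduces to the following shape; this is a sketch restating the new signature from this diff, not additional behavior:

// node is a const framework::ir::Node*, as produced during graph traversal.
bool supported = tensorrt::OpTeller::Global().Tell(
    node, /*use_no_calib_int8=*/no_calib_int8,
    /*with_dynamic_shape=*/with_dynamic_shape);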
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 536f0b7407fc3239d37a8df21a8820f2052478e6..81c68a65576ca4cbfd58e915cae58f796115d1fe 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -1151,6 +1151,8 @@ USE_TRT_CONVERTER(elementwise_mul_tensor);
 USE_TRT_CONVERTER(elementwise_max_tensor);
 USE_TRT_CONVERTER(elementwise_min_tensor);
 USE_TRT_CONVERTER(elementwise_pow_tensor);
+USE_TRT_CONVERTER(transpose);
+USE_TRT_CONVERTER(flatten);
 USE_TRT_CONVERTER(matmul);
 USE_TRT_CONVERTER(conv2d);
 USE_TRT_CONVERTER(relu);
diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
index f80b2274d411393340c654975d197c1da85ab174..26d6b9c9015c2e2926b0b43e91e19b583bb7f6e8 100644
--- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
@@ -3,50 +3,9 @@ nv_library(tensorrt_converter
   SRCS matmul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
   batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
   pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc gelu_op.cc layer_norm_op.cc multihead_matmul_op.cc
-  shuffle_channel_op.cc swish_op.cc instance_norm_op.cc stack_op.cc
+  shuffle_channel_op.cc swish_op.cc instance_norm_op.cc stack_op.cc transpose_op.cc flatten_op.cc
   emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc slice_op.cc
   hard_sigmoid_op.cc hard_swish_op.cc clip_op.cc
   DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)

 nv_test(test_op_converter SRCS test_op_converter.cc DEPS
         paddle_framework ${GLOB_OPERATOR_DEPS} tensorrt_engine tensorrt_converter)
-
-# TODO(xingzhaolong): fix the the following ci ut error.
-
-#nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor)
-#nv_test(test_trt_mul_op SRCS test_mul_op.cc mul_op.cc
-#        DEPS paddle_framework ${GLOB_OPERATOR_DEPS} tensorrt_engine mul_op)
-#nv_test(test_trt_fc_op SRCS test_fc_op.cc fc_op.cc
-#        DEPS paddle_framework ${GLOB_OPERATOR_DEPS} tensorrt_engine mul_op)
-#nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc
-#        DEPS paddle_framework ${GLOB_OPERATOR_DEPS} tensorrt_engine activation_op)
-#nv_test(test_trt_conv_op SRCS test_conv2d_op.cc conv2d_op.cc
-#        DEPS paddle_framework ${GLOB_OPERATOR_DEPS} tensorrt_engine conv_op conv_transpose_op)
-#nv_test(test_trt_pool2d_op SRCS test_pool2d_op.cc pool2d_op.cc
-#        DEPS paddle_framework ${GLOB_OPERATOR_DEPS} tensorrt_engine pool_op tensorrt_plugin)
-#nv_test(test_trt_elementwise_op SRCS test_elementwise_op.cc elementwise_op.cc
-#        DEPS paddle_framework ${GLOB_OPERATOR_DEPS} tensorrt_engine tensorrt_plugin
-#        elementwise_add_op elementwise_mul_op)
-#nv_test(test_trt_softmax_op SRCS test_softmax_op.cc softmax_op.cc
-#        DEPS paddle_framework ${GLOB_OPERATOR_DEPS} tensorrt_engine softmax_op)
-#nv_test(test_trt_batch_norm_op SRCS test_batch_norm_op.cc batch_norm_op.cc
-#        DEPS paddle_framework ${GLOB_OPERATOR_DEPS} tensorrt_engine batch_norm_op)
-#nv_test(test_trt_concat_op SRCS test_concat_op.cc concat_op.cc
-#        DEPS paddle_framework ${GLOB_OPERATOR_DEPS} tensorrt_engine concat_op)
-#nv_test(test_trt_dropout_op SRCS test_dropout_op.cc dropout_op.cc
-#        DEPS paddle_framework ${GLOB_OPERATOR_DEPS} tensorrt_engine dropout_op)
-#nv_test(test_trt_pad_op SRCS test_pad_op.cc pad_op.cc
-#        DEPS paddle_framework ${GLOB_OPERATOR_DEPS} tensorrt_engine pad_op)
-#nv_test(test_trt_split_op SRCS test_split_op.cc split_op.cc
-#        DEPS paddle_framework ${GLOB_OPERATOR_DEPS} tensorrt_engine tensorrt_plugin
-#        split_op concat_op)
-#nv_test(test_trt_prelu_op SRCS test_prelu_op.cc prelu_op.cc
-#        DEPS paddle_framework ${GLOB_OPERATOR_DEPS} tensorrt_engine tensorrt_plugin
-#        prelu_op)
-#nv_test(test_trt_leaky_relu_op SRCS test_leaky_relu_op.cc leaky_relu_op.cc
-#        DEPS paddle_framework ${GLOB_OPERATOR_DEPS} tensorrt_engine activation_op)
-
-#nv_test(test_shuffle_channel_op SRCS test_shuffle_channel_op.cc shuffle_channel_op.cc
-#        DEPS paddle_framework ${GLOB_OPERATOR_DEPS} tensorrt_engine shuffle_channel_op)
-
-#nv_test(test_swish_op SRCS test_swish_op.cc swish_op.cc
-#        DEPS paddle_framework ${GLOB_OPERATOR_DEPS} tensorrt_engine activation_op tensorrt_plugin)
diff --git a/paddle/fluid/inference/tensorrt/convert/flatten_op.cc b/paddle/fluid/inference/tensorrt/convert/flatten_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..03a1c1672469eca959dc08800b248f96ef165b13
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/flatten_op.cc
@@ -0,0 +1,62 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+
+namespace paddle {
+namespace framework {
+class Scope;
+namespace proto {
+class OpDesc;
+}  // namespace proto
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+/*
+ * FlattenOp, only supports static shape mode currently.
+ */
+class FlattenOpConverter : public OpConverter {
+ public:
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope, bool test_mode) override {
+    framework::OpDesc op_desc(op, nullptr);
+    // Declare inputs
+    auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
+    int dims = input->getDimensions().nbDims;
+
+    int dim_prod = 1;
+    for (int i = 0; i < dims; i++) {
+      int dim_i = input->getDimensions().d[i];
+      PADDLE_ENFORCE_GT(
+          dim_i, 0, platform::errors::InvalidArgument(
+                        "flatten input dim should be > 0, but got %d.", dim_i));
+      dim_prod *= dim_i;
+    }
+    nvinfer1::Dims flatten_dim;
+    flatten_dim.nbDims = 1;
+    flatten_dim.d[0] = dim_prod;
+    auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
+    layer->setReshapeDimensions(flatten_dim);
+
+    auto output_name = op_desc.Output("Out")[0];
+    RreplenishLayerAndOutput(layer, "flatten", {output_name}, test_mode);
+  }
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+REGISTER_TRT_OP_CONVERTER(flatten, FlattenOpConverter);
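The converter above folds every (static) input dimension into one. Note that in static-shape mode the TRT tensor dims exclude the batch dimension, so a Paddle input of shape [-1, 6, 64, 64] reaches the converter as (6, 64, 64) and flattens to 6*64*64 = 24576. A standalone sketch of that computation (the helper name is hypothetical; it mirrors the dim_prod loop):

#include <cassert>
#include <vector>

// Mirrors the dim_prod loop in FlattenOpConverter: all dims must be known
// and positive, and they collapse into a single dimension.
int FlattenedDim(const std::vector<int>& trt_dims) {
  int dim_prod = 1;
  for (int d : trt_dims) {
    assert(d > 0 && "flatten converter requires static, positive dims");
    dim_prod *= d;
  }
  return dim_prod;
}

// FlattenedDim({6, 64, 64}) == 24576 for the [-1, 6, 64, 64] test input below.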
diff --git a/paddle/fluid/inference/tensorrt/convert/op_converter.h b/paddle/fluid/inference/tensorrt/convert/op_converter.h
index 4a386ac1d81c538f29e47296afce3127d0395465..8de16df0a2f610b30da389bc73e122074d66471e 100644
--- a/paddle/fluid/inference/tensorrt/convert/op_converter.h
+++ b/paddle/fluid/inference/tensorrt/convert/op_converter.h
@@ -109,7 +109,18 @@ class OpConverter {
           it, platform::errors::Unimplemented("no OpConverter for optype [%s]",
                                               op_desc.Type()));
     }
-
+    if (op_desc.Type() == "transpose2") {
+      it = Registry<OpConverter>::Global().Lookup("transpose");
+      PADDLE_ENFORCE_NOT_NULL(
+          it, platform::errors::Unimplemented("no OpConverter for optype [%s]",
+                                              op_desc.Type()));
+    }
+    if (op_desc.Type() == "flatten2") {
+      it = Registry<OpConverter>::Global().Lookup("flatten");
+      PADDLE_ENFORCE_NOT_NULL(
+          it, platform::errors::Unimplemented("no OpConverter for optype [%s]",
+                                              op_desc.Type()));
+    }
     if (!it) {
       it = Registry<OpConverter>::Global().Lookup(op_desc.Type());
     }
diff --git a/paddle/fluid/inference/tensorrt/convert/transpose_op.cc b/paddle/fluid/inference/tensorrt/convert/transpose_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..c6f2d0174eac83c5f8530e019ebd9e239f41f87d
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/transpose_op.cc
@@ -0,0 +1,84 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <bitset>
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+
+namespace paddle {
+namespace framework {
+class Scope;
+namespace proto {
+class OpDesc;
+}  // namespace proto
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+/*
+ * TransposeOp
+ */
+class TransposeOpConverter : public OpConverter {
+ public:
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope, bool test_mode) override {
+    framework::OpDesc op_desc(op, nullptr);
+    // Declare inputs
+    auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
+    int dims = input->getDimensions().nbDims;
+    std::vector<int> axis =
+        BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("axis"));
+    if (!engine_->with_dynamic_shape()) {
+      for (size_t i = 1; i < axis.size(); i++) {
+        axis[i]--;
+      }
+    }
+
+    nvinfer1::Permutation perm;
+    for (int i = 0; i < dims; i++) {
+      int j = engine_->with_dynamic_shape() ? i : i + 1;
+      perm.order[i] = axis[j];
+    }
+
+    // Permutation is valid if it has nbDims unique values from range [0,
+    // nbDims-1]
+    auto is_valid_permutation = [&](int dims,
+                                    const nvinfer1::Permutation& permutation) {
+      std::bitset<nvinfer1::Dims::MAX_DIMS> found;
+      for (int i = 0; i < dims; ++i) {
+        const int x = permutation.order[i];
+        if ((x < 0) || (x >= dims) || found[x])
+          return false;  // Out of bounds or duplicate
+        found.set(x);
+      }
+      return true;
+    };
+
+    PADDLE_ENFORCE_EQ(is_valid_permutation(dims, perm), true,
+                      platform::errors::InvalidArgument(
+                          "Invalid permutation dimensions for trt transpose op "
+                          "converter: duplicate or out of bound."));
+
+    auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
+    layer->setFirstTranspose(perm);
+
+    auto output_name = op_desc.Output("Out")[0];
+    RreplenishLayerAndOutput(layer, "transpose", {output_name}, test_mode);
+  }
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+REGISTER_TRT_OP_CONVERTER(transpose, TransposeOpConverter);
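The axis bookkeeping above is the subtle part. In static-shape mode the TRT tensor has no batch dimension, so the Paddle axis attribute (which covers the batch dimension at position 0) is converted by decrementing every entry after the first and then skipping the first entry; the teller change below guarantees axis[0] == 0 in this mode. A worked sketch (hypothetical helper name, mirroring the two loops in the converter):

#include <cstddef>
#include <vector>

// Static-shape mapping from a Paddle transpose axis to a TRT permutation.
std::vector<int> ToTrtPermStaticShape(std::vector<int> axis) {
  for (size_t i = 1; i < axis.size(); i++) axis[i]--;      // shift past batch
  return std::vector<int>(axis.begin() + 1, axis.end());   // drop batch entry
}

// Example: Paddle axis [0, 3, 1, 2] on an NCHW tensor yields TRT perm
// (2, 0, 1), i.e. CHW -> WCH, which is what the converter feeds to
// setFirstTranspose. In dynamic-shape mode the axis is used unmodified.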
diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc
index 821fdeddc98531bf4ae805cf2a1521644ed2cdc0..d0c9d01872ced546a065a20f490fa65bf4f0b3d9 100644
--- a/paddle/fluid/inference/tensorrt/op_teller.cc
+++ b/paddle/fluid/inference/tensorrt/op_teller.cc
@@ -103,11 +103,17 @@ struct SimpleOpTypeSetTeller : public Teller {
       "layer_norm",
       "scale",
       "stack",
+      "transpose2",
+      "transpose",
+      "flatten2",
+      "flatten",
   };
 };

-bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc,
-                    bool use_no_calib_int8) {
+bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
+                    bool with_dynamic_shape) {
+  const std::string op_type = node->Op()->Type();
+  const framework::OpDesc desc = *node->Op();
   // do not support the op which is labeled the `skip_quant`
   if ((desc.HasAttr("namescope") &&
        BOOST_GET_CONST(std::string, desc.GetAttr("op_namescope")) ==
@@ -144,6 +150,26 @@ bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc,
       }
     }
   }
+  if (op_type == "transpose2" || op_type == "transpose") {
+    if (!desc.HasAttr("axis")) {
+      return false;
+    } else {
+      std::vector<int> axis =
+          BOOST_GET_CONST(std::vector<int>, desc.GetAttr("axis"));
+      if (!with_dynamic_shape && axis[0] != 0) return false;
+      if (axis.size() >= nvinfer1::Dims::MAX_DIMS) return false;
+    }
+  }
+  if (op_type == "flatten2" || op_type == "flatten") {
+    // flatten doesn't support dynamic shape currently
+    if (!desc.HasAttr("axis")) {
+      return false;
+    } else {
+      if (with_dynamic_shape) return false;
+      int axis = BOOST_GET_CONST(int, desc.GetAttr("axis"));
+      if (axis != 1) return false;
+    }
+  }
   if ((*teller)(op_type, desc, use_no_calib_int8)) return true;
   }
   return false;
diff --git a/paddle/fluid/inference/tensorrt/op_teller.h b/paddle/fluid/inference/tensorrt/op_teller.h
index 9113525a5c94fda633f08188687e822822bb7bce..0a0cbeae51b021430301fb03528031b18ff7b31d 100644
--- a/paddle/fluid/inference/tensorrt/op_teller.h
+++ b/paddle/fluid/inference/tensorrt/op_teller.h
@@ -17,7 +17,7 @@
 #include <memory>
 #include <string>
 #include <vector>
-
+#include "paddle/fluid/framework/ir/node.h"
 #include "paddle/fluid/framework/op_desc.h"
 #include "paddle/fluid/inference/tensorrt/engine.h"

@@ -65,8 +65,8 @@ class OpTeller {
     return *x;
   }

-  bool Tell(const std::string& op_type, const framework::OpDesc& desc,
-            bool use_no_calib_int8 = false);
+  bool Tell(const framework::ir::Node* node, bool use_no_calib_int8 = false,
+            bool with_dynamic_shape = false);

  private:
   OpTeller();
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_subgraph_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_subgraph_pass.py
index e4a7305f70faf7677c1390b335682f1b8b9dc536..2c77ce1723129471ae71dbef3e9acb69699ea0df 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_subgraph_pass.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_subgraph_pass.py
@@ -287,6 +287,59 @@ class TensorRTSubgraphPassInstanceNormTest(InferencePassTest):
             PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))


+class TensorRTSubgraphPassTransposeTest(InferencePassTest):
+    def setUp(self):
+        with fluid.program_guard(self.main_program, self.startup_program):
+            data = fluid.data(
+                name="data", shape=[-1, 6, 64, 64], dtype="float32")
+            transpose_out = self.append_transpose(data)
+            out = fluid.layers.batch_norm(transpose_out, is_test=True)
+        self.feeds = {
+            "data": np.random.random([1, 6, 64, 64]).astype("float32"),
+        }
+        self.enable_trt = True
+        self.trt_parameters = TensorRTSubgraphPassTransposeTest.TensorRTParam(
+            1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False)
+        self.fetch_list = [out]
+
+    def append_transpose(self, data):
+        return fluid.layers.transpose(data, [0, 3, 1, 2])
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            use_gpu = True
+            self.check_output_with_option(use_gpu)
+            self.assertTrue(
+                PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
+
+
+class TensorRTSubgraphPassFlattenTest(InferencePassTest):
+    def setUp(self):
+        with fluid.program_guard(self.main_program, self.startup_program):
+            data = fluid.data(
+                name="data", shape=[-1, 6, 64, 64], dtype="float32")
+            flatten_out = self.append_flatten(data)
+            reshape_out = fluid.layers.reshape(flatten_out, [-1, 0, 1, 1])
+            out = fluid.layers.batch_norm(reshape_out, is_test=True)
+        self.feeds = {
+            "data": np.random.random([1, 6, 64, 64]).astype("float32"),
+        }
+        self.enable_trt = True
+        self.trt_parameters = TensorRTSubgraphPassFlattenTest.TensorRTParam(
+            1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False)
+        self.fetch_list = [out]
+
+    def append_flatten(self, data):
+        return fluid.layers.flatten(data, axis=1)
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            use_gpu = True
+            self.check_output_with_option(use_gpu)
+            self.assertTrue(
+                PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
+
+
 class TensorRTSubgraphPassLayerNormTest(InferencePassTest):
     def setUp(self):
         self.set_params()
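These teller rules explain the test change below: under static shape, a transpose may not move the batch dimension (axis[0] must stay 0) and its rank is capped at nvinfer1::Dims::MAX_DIMS, while flatten is rejected outright in dynamic-shape mode and only axis == 1 is supported. The fuse-pass test therefore switches perm=[2, 1, 0] (which moves the batch dimension and is now rejected) to perm=[0, 2, 1], and drops min_subgraph_size from 3 to 0 so the small remaining subgraph still reaches TensorRT. A compact sketch of the transpose rule (hypothetical helper name, mirroring the check added to OpTeller::Tell above):

#include <vector>
#include "NvInfer.h"  // nvinfer1::Dims::MAX_DIMS

bool TransposeTellStaticShape(const std::vector<int>& axis) {
  if (axis.empty() || axis[0] != 0) return false;  // batch dim must stay put
  return axis.size() < nvinfer1::Dims::MAX_DIMS;   // rank limit
}

// TransposeTellStaticShape({2, 1, 0}) == false  // old test perm, rejected
// TransposeTellStaticShape({0, 2, 1}) == true   // new test perm, accepted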
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_transpose_flatten_concat_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_transpose_flatten_concat_fuse_pass.py
index 4661333ffeca10b7026c68a47b44fc3be83ff093..b15035c3c4dbad5d2b81880a478fcbf92ba6f097 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_transpose_flatten_concat_fuse_pass.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_transpose_flatten_concat_fuse_pass.py
@@ -27,14 +27,15 @@ class TransposeFlattenConcatFusePassTRTTest(InferencePassTest):
                 name="data1", shape=[8, 32, 128], dtype="float32")
             data2 = fluid.data(
                 name="data2", shape=[8, 32, 128], dtype="float32")
-            trans1 = fluid.layers.transpose(data1, perm=[2, 1, 0])
-            trans2 = fluid.layers.transpose(data2, perm=[2, 1, 0])
+            trans1 = fluid.layers.transpose(data1, perm=[0, 2, 1])
+            trans2 = fluid.layers.transpose(data2, perm=[0, 2, 1])
             flatt1 = fluid.layers.flatten(trans1)
             flatt2 = fluid.layers.flatten(trans2)
-            concat_out = fluid.layers.concat([flatt1, flatt2])
+            concat_out = fluid.layers.concat([flatt1, flatt2], axis=1)
             # There is no parameters for above structure.
             # Hence, append a batch_norm to avoid failure caused by load_combined.
-            out = fluid.layers.batch_norm(concat_out, is_test=True)
+            reshape_out = fluid.layers.reshape(concat_out, [-1, 0, 1, 1])
+            out = fluid.layers.batch_norm(reshape_out, is_test=True)

         self.feeds = {
             "data1": np.random.random([8, 32, 128]).astype("float32"),
@@ -42,7 +43,7 @@ class TransposeFlattenConcatFusePassTRTTest(InferencePassTest):
         }
         self.enable_trt = True
         self.trt_parameters = TransposeFlattenConcatFusePassTRTTest.TensorRTParam(
-            1 << 20, 8, 3, AnalysisConfig.Precision.Float32, False, False)
+            1 << 20, 8, 0, AnalysisConfig.Precision.Float32, False, False)
         self.fetch_list = [out]

     def test_check_output(self):