diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index 969166a31ab8aab6fdd0110fca2b901146d7bef1..c54e805e2669fb583460d3eb900d77cab0357b3e 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -1880,6 +1880,9 @@ void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input,
 void patterns::ShuffleChannelPattern::operator()(PDNode *reshape1_in) {
   auto reshape1_op =
       pattern->NewNode(reshape1_op_repr())->assert_is_op("reshape2");
+  reshape1_op->assert_more([&](Node *x) {
+    return boost::get<std::vector<int>>(x->Op()->GetAttr("shape")).size() == 5;
+  });
 
   auto reshape1_out = pattern->NewNode(reshape1_out_repr())
                           ->assert_is_op_output("reshape2", "Out")
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 7650b2e90a0bd4d89f5fd3f71830cf5ca612a55d..a5e8821c1a0cd7340fe47e2db5b9643473d9d58a 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -968,6 +968,8 @@ USE_TRT_CONVERTER(split);
 USE_TRT_CONVERTER(prelu);
 USE_TRT_CONVERTER(conv2d_transpose);
 USE_TRT_CONVERTER(leaky_relu);
+USE_TRT_CONVERTER(shuffle_channel);
+USE_TRT_CONVERTER(swish);
 #endif
 
 #if PADDLE_WITH_ANAKIN
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index bc2c0914728f30fe45dc4ece6477d03a244e8b40..539f8f06023666512b8049bfbffa16049610817e 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -74,6 +74,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
   "infer_clean_graph_pass",                    //
   "conv_affine_channel_fuse_pass",             //
   "conv_eltwiseadd_affine_channel_fuse_pass",  //
+  "shuffle_channel_detect_pass",               //
   "quant_conv2d_dequant_fuse_pass",            //
   "delete_quant_dequant_op_pass",              //
   // "fc_fuse_pass",                           //
diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
index 854007ce801e4ccc853d6186df2651e95ff4fa5d..b63b75f78901d3f3df38aea911417b697f540dd4 100644
--- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
@@ -3,6 +3,7 @@ nv_library(tensorrt_converter
   SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
   batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
   pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc
+  shuffle_channel_op.cc swish_op.cc
   DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
 
 nv_test(test_op_converter SRCS test_op_converter.cc DEPS
@@ -42,3 +43,9 @@ nv_test(test_op_converter SRCS test_op_converter.cc DEPS
 #          prelu_op)
 #nv_test(test_trt_leaky_relu_op SRCS test_leaky_relu_op.cc leaky_relu_op.cc
 #  DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine activation_op)
+
+#nv_test(test_shuffle_channel_op SRCS test_shuffle_channel_op.cc shuffle_channel_op.cc
+#  DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine shuffle_channel_op)
+
+#nv_test(test_swish_op SRCS test_swish_op.cc swish_op.cc
+#  DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine activation_op tensorrt_plugin)
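Note: the new assert_more check pins ShuffleChannelPattern to the reshape-transpose-reshape idiom, requiring the first reshape2 to produce a 5-D shape. A rough sketch of the subgraph the shuffle_channel_detect_pass is meant to fold; the attribute values are illustrative, not taken from this patch:
// x: [n, c, h, w], group g
//   reshape2(shape=[n, g, c/g, h, w]) -> transpose2(axis=[0, 2, 1, 3, 4]) -> reshape2(shape=[n, c, h, w])
// is collapsed into a single shuffle_channel(group=g) op before the TensorRT subgraph is formed.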
diff --git a/paddle/fluid/inference/tensorrt/convert/activation_op.cc b/paddle/fluid/inference/tensorrt/convert/activation_op.cc
index 5c2454fa9a35eb7b70b11750592f012ed4ff690a..18de448690534656cdfe851c74a2b390264b1b6b 100644
--- a/paddle/fluid/inference/tensorrt/convert/activation_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/activation_op.cc
@@ -42,11 +42,20 @@ class ActivationOpConverter : public OpConverter {
     nvinfer1::IActivationLayer* layer = TRT_ENGINE_ADD_LAYER(
         engine_, Activation, *const_cast<nvinfer1::ITensor*>(input_tensor),
         op_pair->second);
+
+#if IS_TRT_VERSION_GE(5130)
+    // max(alpha, min(beta, x))
+    if (op_type_ == "relu6") {
+      layer->setAlpha(0.);
+      layer->setBeta(6.);
+    }
+#endif
+
     auto output_name = op_desc.Output("Out")[0];
 
     RreplenishLayerAndOutput(layer, op_type_, {output_name}, test_mode);
     if (op_desc.HasAttr("out_scale")) {
-#if IS_TRT_VERSION_GE(5000)
+#if IS_TRT_VERSION_GE(5130)
       float out_scale = boost::get<float>(op_desc.GetAttr("out_scale"));
       engine_->SetTensorDynamicRange(layer->getOutput(0), out_scale);
 #endif
@@ -63,6 +72,9 @@ const std::unordered_map<std::string, nvinfer1::ActivationType>
     {"relu", nvinfer1::ActivationType::kRELU},
     {"sigmoid", nvinfer1::ActivationType::kSIGMOID},
     {"tanh", nvinfer1::ActivationType::kTANH},
+#if IS_TRT_VERSION_GE(5130)
+    {"relu6", nvinfer1::ActivationType::kCLIP},
+#endif
 };
 
 class ReluOpConverter : public ActivationOpConverter {
@@ -80,6 +92,11 @@ class TanhOpConverter : public ActivationOpConverter {
   TanhOpConverter() { op_type_ = "tanh"; }
 };
 
+class Relu6OpConverter : public ActivationOpConverter {
+ public:
+  Relu6OpConverter() { op_type_ = "relu6"; }
+};
+
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
@@ -87,3 +104,4 @@ class TanhOpConverter : public ActivationOpConverter {
 REGISTER_TRT_OP_CONVERTER(relu, ReluOpConverter);
 REGISTER_TRT_OP_CONVERTER(sigmoid, SigmoidOpConverter);
 REGISTER_TRT_OP_CONVERTER(tanh, TanhOpConverter);
+REGISTER_TRT_OP_CONVERTER(relu6, Relu6OpConverter);
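Note: relu6 is mapped onto TensorRT's clip activation (available from 5.1.3), which is why the converter only sets alpha and beta. The equivalence it relies on, as a short sketch:
// relu6(x) = min(max(x, 0), 6)
// kCLIP computes max(alpha, min(beta, x)), so alpha = 0 and beta = 6 reproduce relu6 exactly.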
diff --git a/paddle/fluid/inference/tensorrt/convert/dropout_op.cc b/paddle/fluid/inference/tensorrt/convert/dropout_op.cc
index 71177e5e66dcc52afc8bc4f4a6ade802c0f136a7..510b622f46fed13bd0fd07d37d15fb7047864fea 100644
--- a/paddle/fluid/inference/tensorrt/convert/dropout_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/dropout_op.cc
@@ -31,6 +31,20 @@ class DropoutOpConverter : public OpConverter {
     auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
     float dropout_prob = boost::get<float>(op_desc.GetAttr("dropout_prob"));
 
+    std::string downgrade_in_infer = "";
+    if (op_desc.HasAttr("dropout_implementation")) {
+      downgrade_in_infer =
+          boost::get<std::string>(op_desc.GetAttr("dropout_implementation"));
+    }
+
+    if (!downgrade_in_infer.empty() &&
+        downgrade_in_infer == "upscale_in_train") {
+      auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input1);
+      auto output_name = op_desc.Output("Out")[0];
+      RreplenishLayerAndOutput(layer, "dropout", {output_name}, test_mode);
+      return;
+    }
+
     platform::CPUPlace cpu_place;
     std::unique_ptr<framework::LoDTensor> weight_tensor(
         new framework::LoDTensor());
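Note: the early-return branch covers the "upscale_in_train" implementation, where dropout is a no-op at inference time, so a bare Shuffle layer (no reshape or transpose set) is used purely as an identity pass-through. The inference-time semantics this distinguishes, as a sketch:
//   downgrade_in_infer:  out = x * (1 - dropout_prob)   (handled by the existing Scale path below)
//   upscale_in_train:    out = x                         (identity, handled by the Shuffle layer above)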
diff --git a/paddle/fluid/inference/tensorrt/convert/leaky_relu_op.cc b/paddle/fluid/inference/tensorrt/convert/leaky_relu_op.cc
index 7753fda06cfb3cacc75c008efb5c4b16f7def0f9..2a46938cb1090a4a0dc0f5d1c3d588097d281552 100644
--- a/paddle/fluid/inference/tensorrt/convert/leaky_relu_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/leaky_relu_op.cc
@@ -35,7 +35,14 @@ class LeakyReluOpConverter : public OpConverter {
     PADDLE_ENFORCE(output_num == 1);
     // Get attrs
     float alpha = boost::get<float>(op_desc.GetAttr("alpha"));
+    nvinfer1::ILayer* output_layer = nullptr;
 
+#if IS_TRT_VERSION_GE(5100)
+    nvinfer1::IActivationLayer* layer = TRT_ENGINE_ADD_LAYER(
+        engine_, Activation, *input, nvinfer1::ActivationType::kLEAKY_RELU);
+    layer->setAlpha(alpha);
+    output_layer = layer;
+#else
     platform::CPUPlace place;
     std::unique_ptr<framework::LoDTensor> alpha_tensor(
         new framework::LoDTensor());
@@ -65,7 +72,7 @@ class LeakyReluOpConverter : public OpConverter {
                              nvinfer1::ScaleMode::kUNIFORM, shift.get(),
                              sub_scale.get(), power.get());
     PADDLE_ENFORCE(nullptr != scale_relu_layer);
-    auto* output_layer =
+    output_layer =
         TRT_ENGINE_ADD_LAYER(engine_, ElementWise, *(scale_layer->getOutput(0)),
                              *(scale_relu_layer->getOutput(0)),
                              nvinfer1::ElementWiseOperation::kSUM);
@@ -75,7 +82,7 @@ class LeakyReluOpConverter : public OpConverter {
     PADDLE_ENFORCE(engine_->weight_map.find(alpha_name) ==
                    engine_->weight_map.end());
     engine_->weight_map[alpha_name] = std::move(alpha_tensor);
-
+#endif
     auto output_name = op_desc.Output("Out")[0];
     RreplenishLayerAndOutput(output_layer, "leaky_relu", {output_name},
                              test_mode);
diff --git a/paddle/fluid/inference/tensorrt/convert/shuffle_channel_op.cc b/paddle/fluid/inference/tensorrt/convert/shuffle_channel_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..0f891e0f9f4e6731199e4a6884ec74a1265b3fef
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/shuffle_channel_op.cc
@@ -0,0 +1,57 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+/*
+ * ShuffleChannelOp, built from two TensorRT Shuffle layers.
+ */
+class ShuffleChannelOpConverter : public OpConverter {
+ public:
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope, bool test_mode) override {
+    framework::OpDesc op_desc(op, nullptr);
+    // Declare inputs
+    auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
+    auto input_dims = input->getDimensions();
+    PADDLE_ENFORCE(input_dims.nbDims == 3);
+    int c = input_dims.d[0];
+    int h = input_dims.d[1];
+    int w = input_dims.d[2];
+    int group = boost::get<int>(op_desc.GetAttr("group"));
+
+    auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
+    nvinfer1::Dims4 reshape_dim(group, c / group, h, w);
+    layer->setReshapeDimensions(reshape_dim);
+    layer->setSecondTranspose({1, 0, 2, 3});
+    auto* output = layer->getOutput(0);
+
+    auto* reshape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *output);
+    nvinfer1::DimsCHW reshape_dim2(c, h, w);
+    reshape_layer->setReshapeDimensions(reshape_dim2);
+
+    auto output_name = op_desc.Output("Out")[0];
+    RreplenishLayerAndOutput(reshape_layer, "shuffle_channel", {output_name},
+                             test_mode);
+  }
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+REGISTER_TRT_OP_CONVERTER(shuffle_channel, ShuffleChannelOpConverter);
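Note: the converter realizes channel shuffle as reshape to [g, c/g, h, w], a transpose of the first two axes, and a reshape back to [c, h, w] (the batch dimension is implicit in TensorRT here). A rough per-element sketch of what that composition computes, for illustration only:
// For an input of shape [c, h, w] with g groups (c % g == 0):
//   out[c_out][y][x] = in[c_in][y][x], where c_in = (c_out % g) * (c / g) + c_out / g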
diff --git a/paddle/fluid/inference/tensorrt/convert/swish_op.cc b/paddle/fluid/inference/tensorrt/convert/swish_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..42f2008afa16c305561db9b27f472819fe4cec17
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/swish_op.cc
@@ -0,0 +1,53 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+#include "paddle/fluid/inference/tensorrt/plugin/swish_op_plugin.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+class SwishOpConverter : public OpConverter {
+ public:
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope, bool test_mode) override {
+    VLOG(4) << "convert fluid swish op to tensorrt layer";
+
+    framework::OpDesc op_desc(op, nullptr);
+    // Declare inputs
+    int input_num = op_desc.Input("X").size();
+    PADDLE_ENFORCE(input_num == 1);
+    auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
+    // Get output
+    size_t output_num = op_desc.Output("Out").size();
+    PADDLE_ENFORCE(output_num == 1);
+    // Get attrs
+    float beta = boost::get<float>(op_desc.GetAttr("beta"));
+
+    plugin::SwishPlugin* plugin = new plugin::SwishPlugin(beta);
+
+    nvinfer1::IPluginLayer* layer =
+        engine_->AddPlugin(&input, input_num, plugin);
+
+    auto output_name = op_desc.Output("Out")[0];
+    RreplenishLayerAndOutput(layer, "swish", {output_name}, test_mode);
+  }
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+REGISTER_TRT_OP_CONVERTER(swish, SwishOpConverter);
diff --git a/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc b/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc
index dd3dfb0bc7b609e28462954835a0d40e0a63b6cd..f2dc5ba1c7c2c832e0239f6a30760c354aaf4699 100644
--- a/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc
@@ -46,6 +46,8 @@ TEST(SigmoidOpConverter, main) { test_activation("sigmoid"); }
 
 TEST(TanhOpConverter, main) { test_activation("tanh"); }
 
+TEST(Relu6OpConverter, main) { test_activation("relu6"); }
+
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
@@ -53,3 +55,4 @@ TEST(TanhOpConverter, main) { test_activation("tanh"); }
 USE_OP(relu);
 USE_OP(sigmoid);
 USE_OP(tanh);
+USE_OP(relu6);
diff --git a/paddle/fluid/inference/tensorrt/convert/test_dropout_op.cc b/paddle/fluid/inference/tensorrt/convert/test_dropout_op.cc
index 6b8e621b702d977f5868766a6eafb98c8522c3cd..81e905b975327125fddc8a33d871cc97290e4ac1 100644
--- a/paddle/fluid/inference/tensorrt/convert/test_dropout_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/test_dropout_op.cc
@@ -34,6 +34,7 @@ TEST(DropoutOpConverter, main) {
   framework::OpDesc desc;
   int is_test = 1;
   float dropout_prob = 0.4;
+  std::string dropout_implementation = "upscale_in_train";
 
   desc.SetType("dropout");
   desc.SetInput("X", {"dropout-X"});
@@ -42,6 +43,8 @@
   desc.SetAttr("is_test", is_test);
   desc.SetAttr("dropout_prob", dropout_prob);
+  desc.SetAttr("dropout_implementation", dropout_implementation);
+
   LOG(INFO) << "set OP";
   validator.SetOp(*desc.Proto());
   LOG(INFO) << "execute";
diff --git a/paddle/fluid/inference/tensorrt/convert/test_shuffle_channel_op.cc b/paddle/fluid/inference/tensorrt/convert/test_shuffle_channel_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e3cc5273734e02ecc4ed6453e6cd47052463c8b2
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/test_shuffle_channel_op.cc
@@ -0,0 +1,48 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <gtest/gtest.h>
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+TEST(shuffle_channel_op, test_shuffle_channel) {
+  std::unordered_set<std::string> parameters;
+  framework::Scope scope;
+  TRTConvertValidation validator(10, parameters, scope, 1000);
+  validator.DeclInputVar("sc_input", nvinfer1::DimsCHW(4, 2, 2));
+  validator.DeclOutputVar("sc_out", nvinfer1::DimsCHW(4, 2, 2));
+
+  // Prepare Op description
+  framework::OpDesc desc;
+  desc.SetType("shuffle_channel");
+  desc.SetInput("X", {"sc_input"});
+  desc.SetOutput("Out", {"sc_out"});
+  int group = 2;
+  desc.SetAttr("group", group);
+
+  validator.SetOp(*desc.Proto());
+
+  validator.Execute(1);
+}
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+USE_OP(shuffle_channel);
diff --git a/paddle/fluid/inference/tensorrt/convert/test_swish_op.cc b/paddle/fluid/inference/tensorrt/convert/test_swish_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..c15c79bb13fad4233775482dc1b8b4841e61a23a
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/test_swish_op.cc
@@ -0,0 +1,47 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <gtest/gtest.h>
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+TEST(swish_op, test_swish) {
+  std::unordered_set<std::string> parameters;
+  framework::Scope scope;
+  TRTConvertValidation validator(10, parameters, scope, 1000);
+  validator.DeclInputVar("sw_input", nvinfer1::DimsCHW(3, 2, 2));
+  validator.DeclOutputVar("sw_out", nvinfer1::DimsCHW(3, 2, 2));
+
+  // Prepare Op description
+  framework::OpDesc desc;
+  desc.SetType("swish");
+  desc.SetInput("X", {"sw_input"});
+  desc.SetOutput("Out", {"sw_out"});
+
+  desc.SetAttr("beta", 2.0f);
+
+  validator.SetOp(*desc.Proto());
+
+  validator.Execute(1);
+}
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+USE_OP(swish);
diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc
index 170ca40d659efad226cef44c89b5491f81abedec..292f5e1d4b928e81bb1a3020ae212791ac60d45b 100644
--- a/paddle/fluid/inference/tensorrt/op_teller.cc
+++ b/paddle/fluid/inference/tensorrt/op_teller.cc
@@ -20,7 +20,11 @@ namespace tensorrt {
 
 // Just tell by the op_types.
 struct SimpleOpTypeSetTeller : public Teller {
-  SimpleOpTypeSetTeller() {}
+  SimpleOpTypeSetTeller() {
+#if IS_TRT_VERSION_GE(5130)
+    teller_set.insert("relu6");
+#endif
+  }
 
   bool operator()(const std::string& op_type,
                   const framework::OpDesc& desc) override {
@@ -28,11 +32,27 @@ struct SimpleOpTypeSetTeller : public Teller {
   }
 
  private:
-  std::unordered_set<std::string> teller_set{
-      {"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid",
-       "depthwise_conv2d", "batch_norm", "concat", "tanh", "pad",
-       "elementwise_add", "elementwise_mul", "dropout", "prelu",
-       "conv2d_transpose", "leaky_relu", "fc"}};
+  std::unordered_set<std::string> teller_set{{"mul",
+                                              "conv2d",
+                                              "pool2d",
+                                              "relu",
+                                              "softmax",
+                                              "sigmoid",
+                                              "depthwise_conv2d",
+                                              "batch_norm",
+                                              "concat",
+                                              "tanh",
+                                              "pad",
+                                              "elementwise_add",
+                                              "elementwise_mul",
+                                              "dropout",
+                                              "prelu",
+                                              "conv2d_transpose",
+                                              "leaky_relu",
+                                              "fc",
+                                              "shuffle_channel",
+                                              "swish",
+                                              "split"}};
 };
 
 bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc) {
diff --git a/paddle/fluid/inference/tensorrt/op_teller.h b/paddle/fluid/inference/tensorrt/op_teller.h
index 3363d77af84f767a83ea6695a4423af71f34256c..7ff1d4746a1817493774d653982b345cf6948f74 100644
--- a/paddle/fluid/inference/tensorrt/op_teller.h
+++ b/paddle/fluid/inference/tensorrt/op_teller.h
@@ -18,6 +18,7 @@
 #include <memory>
 #include <string>
 #include <vector>
 #include "paddle/fluid/framework/op_desc.h"
+#include "paddle/fluid/inference/tensorrt/engine.h"
 
 namespace paddle {
 namespace inference {
diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
index 709aa103d1b6681221328b180d65e90f08d3368e..d01c5c823b51d204f1e507b55edb127737a18be4 100644
--- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
@@ -1,5 +1,5 @@
 nv_library(tensorrt_plugin
   SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu
   prelu_op_plugin.cu trt_plugin_factory.cc
-  avg_pool_op_plugin.cu
+  avg_pool_op_plugin.cu swish_op_plugin.cu
   DEPS enforce tensorrt_engine prelu)
diff --git a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu
index b8a044fe99b91893c8c9ef661b4f46ebaa6db8c7..84f938eeb5fa50421a819978cd84c968919c96b3 100644
--- a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu
@@ -34,6 +34,7 @@ int PReluPlugin::initialize() {
   cudaMalloc(&p_gpu_weight_, sizeof(float) * weight_.size());
   cudaMemcpy(p_gpu_weight_, weight_.data(), weight_.size() * sizeof(float),
              cudaMemcpyHostToDevice);
+  return 0;
 }
 
 nvinfer1::Dims PReluPlugin::getOutputDimensions(int index,
diff --git a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu
index b5503c3b95ee2429dd865fd6de416a04aafbccf0..49420a7a667309eaf8a97c480ab361cec4c29e3e 100644
--- a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu
@@ -27,50 +27,20 @@ SplitPlugin* CreateSplitPluginDeserialize(const void* buffer, size_t length) {
 }
 REGISTER_TRT_PLUGIN("split_plugin", CreateSplitPluginDeserialize);
 
-// copied from operators::math::SplitFunctor
 template <typename T>
-__global__ void SplitKernel(const T* input_data, const int in_row,
-                            const int in_col, const int* out_cols,
-                            int out_cols_size, T** outputs_data) {
-  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
-  int curr_segment = 0;
-  int curr_offset = out_cols[0];
-  for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
-    int curr_col_offset = out_cols[curr_segment + 1];
-    while (curr_col_offset <= tid_x) {
-      curr_offset = curr_col_offset;
-      ++curr_segment;
-      curr_col_offset = out_cols[curr_segment + 1];
-    }
-
-    int local_col = tid_x - curr_offset;
-    int segment_width = curr_col_offset - curr_offset;
-    T* output_ptr = outputs_data[curr_segment];
-    if (output_ptr != nullptr) {
-      int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
-      for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
-        output_ptr[tid_y * segment_width + local_col] =
-            input_data[tid_y * in_col + tid_x];
-    }
-  }
-}
-
-template <typename T>
-__global__ void SplitKernel(const T* input_data, const int in_row,
-                            const int in_col, const int fixed_out_col,
-                            T** outputs_data) {
-  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
-  for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
-    int split = tid_x / fixed_out_col;
-    int in_offset = tid_x - split * fixed_out_col;
-    T* output_ptr = outputs_data[split];
-    if (output_ptr != nullptr) {
-      int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
-      for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
-        output_ptr[tid_y * fixed_out_col + in_offset] =
-            input_data[tid_y * in_col + tid_x];
+__device__ int upper_bound(T const* vals, int n, T const& key) {
+  int i = 0;
+  while (n > 0) {
+    int m = n / 2;
+    int j = i + m;
+    if (!(key < vals[j])) {
+      i = j + 1;
+      n -= m + 1;
+    } else {
+      n = m;
     }
   }
+  return i;
 }
 
 nvinfer1::Dims SplitPlugin::getOutputDimensions(
@@ -101,80 +71,61 @@ int SplitPlugin::initialize() {
     if (output_length_[i] != output_length_[0]) {
       same_shape_ = false;
     }
-    segment_offsets.push_back(segment_offsets.back() +
-                              output_length_[i] * inner_cols_);
+    segment_offsets.push_back(segment_offsets.back() + output_length_[i]);
   }
-  inner_cols_ *= dims.d[axis_];
+  axis_shape_ = dims.d[axis_];
   d_segment_offsets_ = segment_offsets;
   segment_offsets_ = std::move(segment_offsets);
   d_output_ptrs_.resize(this->getNbOutputs(), nullptr);
   return 0;
 }
 
+// The following part of the code refers to onnx-tensorrt
+// https://github.com/onnx/onnx-tensorrt/blob/master/Split.cu
 template <typename T>
-inline void Split(cudaStream_t stream, const bool same_shape,
-                  const int outer_rows, const int inner_cols,
-                  const std::vector<int>& segment_offsets,
-                  const int* d_segment_offsets, const T* input, T** outputs) {
-  const int kThreadsPerBlock = 1024;
-  const int kMaxBlocks = 65535;
-  int block_cols = kThreadsPerBlock;
-  if (inner_cols < kThreadsPerBlock) {  // block_cols is aligned by 32.
-    block_cols = ((inner_cols + 31) >> 5) << 5;
-  }
-  int block_rows = kThreadsPerBlock / block_cols;
-  dim3 block_size = dim3(block_cols, block_rows, 1);
-
-  int grid_cols =
-      std::min((inner_cols + block_cols - 1) / block_cols, kMaxBlocks);
-  int grid_rows =
-      std::min(kMaxBlocks / grid_cols, std::max(outer_rows / block_rows, 1));
-  dim3 grid_size = dim3(grid_cols, grid_rows, 1);
-
-  if (same_shape) {
-    SplitKernel<<<grid_size, block_size, 0, stream>>>(
-        input, outer_rows, inner_cols, segment_offsets[1], outputs);
-  } else {
-    SplitKernel<<<grid_size, block_size, 0, stream>>>(
-        input, outer_rows, inner_cols, d_segment_offsets,
-        static_cast<int>(segment_offsets.size()), outputs);
+__global__ void split_kernel(int nsegment,
+                             int const* __restrict__ segment_offsets,
+                             T const* __restrict__ idata, T* const* odatas,
+                             int inner_cols, int axis_shape, int outer_rows) {
+  int x0 = threadIdx.x + blockIdx.x * blockDim.x;
+  int src_y0 = threadIdx.y + blockIdx.y * blockDim.y;
+  int z0 = threadIdx.z + blockIdx.z * blockDim.z;
+  for (int z = z0; z < outer_rows; z += blockDim.z * gridDim.z) {
+    for (int src_y = src_y0; src_y < axis_shape;
+         src_y += blockDim.y * gridDim.y) {
+      for (int x = x0; x < inner_cols; x += blockDim.x * gridDim.x) {
+        int segment = upper_bound(segment_offsets, nsegment, src_y) - 1;
+        int dst_y = src_y - segment_offsets[segment];
+        int dst_ny = segment_offsets[segment + 1] - segment_offsets[segment];
+        odatas[segment][x + inner_cols * (dst_y + dst_ny * z)] =
+            idata[x + inner_cols * (src_y + axis_shape * z)];
+      }
+    }
   }
 }
 
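Note: the rewritten kernel (adapted from onnx-tensorrt, as the comment above states) views the input as [outer_rows, axis_shape, inner_cols] and covers it with a 3-D grid-stride loop; upper_bound on the prefix-sum segment_offsets tells which output each position along the split axis belongs to. A short sketch of the index math, derived from the kernel body:
//   input  index: x + inner_cols * (src_y + axis_shape * z)            (z = outer row, src_y = split-axis position)
//   segment s   : largest s with segment_offsets[s] <= src_y
//   output index: x + inner_cols * (dst_y + dst_ny * z), dst_y = src_y - segment_offsets[s]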
 int SplitPlugin::enqueue(int batchSize, const void* const* inputs,
                          void** outputs, void* workspace, cudaStream_t stream) {
+  const int* d_segment_offsets_ptr =
+      thrust::raw_pointer_cast(&d_segment_offsets_[0]);
   float const* input_ptr = reinterpret_cast<float const*>(inputs[0]);
-  if (((batchSize == 1 && axis_ == 0) || axis_ == -1) &&
-      this->getNbOutputs() < 10) {
-    float** output_ptrs = reinterpret_cast<float**>(outputs);
-    int data_type_size = (this->getDataType() == nvinfer1::DataType::kFLOAT)
-                             ? sizeof(float)
-                             : sizeof(__half);
-    for (int i = 0; i < this->getNbOutputs(); ++i) {
-      PADDLE_ENFORCE(
-          cudaMemcpyAsync(
-              output_ptrs[i], input_ptr + segment_offsets_[i],
-              (segment_offsets_[i + 1] - segment_offsets_[i]) * data_type_size,
-              cudaMemcpyDeviceToDevice, stream) == cudaSuccess);
-    }
-  } else {
-    outer_rows_ *= batchSize;
-    const int* d_segment_offsets_ptr =
-        thrust::raw_pointer_cast(&d_segment_offsets_[0]);
-    float** output_ptrs = thrust::raw_pointer_cast(&d_output_ptrs_[0]);
-    PADDLE_ENFORCE(cudaMemcpyAsync(output_ptrs, outputs,
-                                   this->getNbOutputs() * sizeof(float*),
-                                   cudaMemcpyHostToDevice,
-                                   stream) == cudaSuccess);
-    if (this->getDataType() == nvinfer1::DataType::kFLOAT) {
-      Split(stream, same_shape_, outer_rows_, inner_cols_, segment_offsets_,
-            d_segment_offsets_ptr, input_ptr, output_ptrs);
-    } else {
-      Split(stream, same_shape_, outer_rows_, inner_cols_, segment_offsets_,
-            d_segment_offsets_ptr, (__half*)input_ptr,  // NOLINT
-            (__half**)output_ptrs);                     // NOLINT
-    }
-  }
+  float* const* h_odatas = reinterpret_cast<float* const*>(outputs);
+  float** output_ptrs = thrust::raw_pointer_cast(&d_output_ptrs_[0]);
+  PADDLE_ENFORCE(cudaMemcpyAsync(output_ptrs, h_odatas,
+                                 d_output_ptrs_.size() * sizeof(float*),
+                                 cudaMemcpyHostToDevice,
+                                 stream) == cudaSuccess);
+
+  int outer_rows = outer_rows_ * batchSize;
+
+  dim3 block(32, 16);
+  dim3 grid(std::min((inner_cols_ - 1) / block.x + 1, 65535u),
+            std::min((axis_shape_ - 1) / block.y + 1, 65535u),
+            std::min((outer_rows_ - 1) / block.z + 1, 65535u));
+
+  split_kernel<<<grid, block, 0, stream>>>(
+      d_segment_offsets_.size(), d_segment_offsets_ptr, input_ptr, output_ptrs,
+      inner_cols_, axis_shape_, outer_rows);
   return cudaGetLastError() != cudaSuccess;
 }
diff --git a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h
index cbb72590567a35bee29387d4c00518b437913508..b2a7bc3bdaa2543e83ab024548c3c10ffd7212be 100644
--- a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h
+++ b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h
@@ -66,6 +66,7 @@ class SplitPlugin : public PluginTensorRT {
   int axis_;
   int outer_rows_;
   int inner_cols_;
+  int axis_shape_;
   bool same_shape_;
   std::vector<int> output_length_;
   std::vector<int> segment_offsets_;
diff --git a/paddle/fluid/inference/tensorrt/plugin/swish_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/swish_op_plugin.cu
new file mode 100644
index 0000000000000000000000000000000000000000..864ca5f080f95d56191b0e9895654068edb8d0ee
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/plugin/swish_op_plugin.cu
@@ -0,0 +1,76 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdio.h>
+#include <cassert>
+#include <vector>
+#include "glog/logging.h"
+#include "paddle/fluid/inference/tensorrt/plugin/swish_op_plugin.h"
+#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+namespace plugin {
+
+SwishPlugin *CreateSwishPluginDeserialize(const void *buffer, size_t length) {
+  return new SwishPlugin(buffer, length);
+}
+REGISTER_TRT_PLUGIN("swish_plugin", CreateSwishPluginDeserialize);
+
+int SwishPlugin::initialize() { return 0; }
+
+nvinfer1::Dims SwishPlugin::getOutputDimensions(int index,
+                                                const nvinfer1::Dims *inputDims,
+                                                int nbInputs) {
+  assert(nbInputs == 1);
+  assert(index < this->getNbOutputs());
+  nvinfer1::Dims const &input_dims = inputDims[0];
+  nvinfer1::Dims output_dims = input_dims;
+  return output_dims;
+}
+
+__global__ void swish_kernel(int num, const float *input, float *output,
+                             float beta) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < num) {
+#if __CUDA_ARCH__ >= 350
+    output[index] =
+        __ldg(input + index) / (1.0f + expf(-beta * __ldg(input + index)));
+#else
+    output[index] = input[index] / (1.0f + expf(-beta * input[index]));
+#endif
+  }
+}
+
+int SwishPlugin::enqueue(int batch_size, const void *const *inputs,
+                         void **outputs, void *workspace,
+                         cudaStream_t stream) {
+  // input dims is CHW.
+  const auto &input_dims = this->getInputDims(0);
+  const float *input = reinterpret_cast<const float *>(inputs[0]);
+  float *output = reinterpret_cast<float **>(outputs)[0];
+  int num = batch_size;
+  for (int i = 0; i < input_dims.nbDims; i++) {
+    num *= input_dims.d[i];
+  }
+  int threads = 1024;
+  int blocks = (num + threads - 1) / threads;
+  swish_kernel<<<blocks, threads, 0, stream>>>(num, input, output, beta_);
+
+  return cudaGetLastError() != cudaSuccess;
+}
+
+}  // namespace plugin
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
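Note: the kernel above evaluates the swish activation element-wise; for reference the formula is
//   swish(x) = x * sigmoid(beta * x) = x / (1 + exp(-beta * x))
// __ldg() on sm_35+ is only a read-through-texture-cache hint; the #else branch computes the same value.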
diff --git a/paddle/fluid/inference/tensorrt/plugin/swish_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/swish_op_plugin.h
new file mode 100644
index 0000000000000000000000000000000000000000..6c3cd038884bf6482edd49fe27901888b2e93bdd
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/plugin/swish_op_plugin.h
@@ -0,0 +1,72 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <stdio.h>
+#include <cassert>
+#include <vector>
+
+#include "paddle/fluid/inference/tensorrt/engine.h"
+#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+namespace plugin {
+
+class SwishPlugin : public PluginTensorRT {
+ private:
+  float beta_;
+
+ protected:
+  size_t getSerializationSize() override {
+    return getBaseSerializationSize() + SerializedSize(beta_);
+  }
+
+  // TRT will call this func when we need to serialize the configuration of
+  // tensorrt.
+  // It should not be called by users.
+  void serialize(void *buffer) override {
+    SerializeValue(&buffer, getPluginType());
+    serializeBase(buffer);
+    SerializeValue(&buffer, beta_);
+  }
+
+ public:
+  explicit SwishPlugin(const float beta) : beta_(beta) {}
+
+  // It was used for tensorrt deserialization.
+  // It should not be called by users.
+  SwishPlugin(void const *serialData, size_t serialLength) {
+    deserializeBase(serialData, serialLength);
+    DeserializeValue(&serialData, &serialLength, &beta_);
+  }
+  ~SwishPlugin() {}
+  int initialize() override;
+
+  SwishPlugin *clone() const override { return new SwishPlugin(beta_); }
+
+  const char *getPluginType() const override { return "swish_plugin"; }
+  int getNbOutputs() const override { return 1; }
+  nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs,
+                                     int nbInputDims) override;
+  int enqueue(int batchSize, const void *const *inputs, void **outputs,
+              void *workspace, cudaStream_t stream) override;
+};
+
+}  // namespace plugin
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle