From d38fd6a0fcd754907ff17fe896651c5274c7f672 Mon Sep 17 00:00:00 2001
From: nhzlx
Date: Tue, 13 Nov 2018 08:23:26 +0000
Subject: [PATCH] add plugin support and offer a simple split sample

---
 paddle/fluid/inference/analysis/analyzer.cc   |  2 +-
 .../api/api_tensorrt_subgraph_engine.cc       |  1 +
 .../inference/tensorrt/convert/CMakeLists.txt |  7 +-
 .../inference/tensorrt/convert/split_op.cc    | 73 +++++++++++++++
 .../tensorrt/convert/test_split_op.cc         | 53 +++++++++++
 paddle/fluid/inference/tensorrt/engine.cc     |  6 ++
 paddle/fluid/inference/tensorrt/engine.h      |  5 +
 .../inference/tensorrt/plugin/CMakeLists.txt  |  3 +-
 .../tensorrt/plugin/plugin_factory.cc         | 64 -------------
 .../tensorrt/plugin/plugin_factory.h          | 91 -------------------
 .../inference/tensorrt/plugin/plugin_utils.cc | 37 --------
 .../inference/tensorrt/plugin/plugin_utils.h  | 34 -------
 .../plugin/{serialize.hpp => serialize.h}     |  0
 .../tensorrt/plugin/split_op_plugin.cu        | 70 ++++----------
 .../tensorrt/plugin/split_op_plugin.h         | 61 ++++++++-----
 .../inference/tensorrt/plugin/trt_plugin.cc   |  4 +-
 .../inference/tensorrt/plugin/trt_plugin.h    |  8 +-
 17 files changed, 208 insertions(+), 311 deletions(-)
 create mode 100644 paddle/fluid/inference/tensorrt/convert/split_op.cc
 create mode 100644 paddle/fluid/inference/tensorrt/convert/test_split_op.cc
 delete mode 100644 paddle/fluid/inference/tensorrt/plugin/plugin_factory.cc
 delete mode 100644 paddle/fluid/inference/tensorrt/plugin/plugin_factory.h
 delete mode 100644 paddle/fluid/inference/tensorrt/plugin/plugin_utils.cc
 delete mode 100644 paddle/fluid/inference/tensorrt/plugin/plugin_utils.h
 rename paddle/fluid/inference/tensorrt/plugin/{serialize.hpp => serialize.h} (100%)

diff --git a/paddle/fluid/inference/analysis/analyzer.cc b/paddle/fluid/inference/analysis/analyzer.cc
index a3440cfc7..cd6636a7e 100644
--- a/paddle/fluid/inference/analysis/analyzer.cc
+++ b/paddle/fluid/inference/analysis/analyzer.cc
@@ -71,7 +71,7 @@ class DfgPassManagerImpl final : public DfgPassManager {
     std::unordered_set<std::string> teller_set(
         {"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid",
          "depthwise_conv2d", "batch_norm", "concat", "tanh", "pad",
-         "elementwise_add", "dropout"});
+         "elementwise_add", "dropout", "split"});

     if (!node->IsFunction()) return false;
     const auto* func = static_cast(node);
diff --git a/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc b/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc
index 94b393349..eceab6e2b 100644
--- a/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc
+++ b/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc
@@ -186,3 +186,4 @@ USE_TRT_CONVERTER(batch_norm);
 USE_TRT_CONVERTER(concat);
 USE_TRT_CONVERTER(dropout);
 USE_TRT_CONVERTER(pad);
+USE_TRT_CONVERTER(split);
diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
index e34d5db6b..ed4c398ce 100644
--- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
@@ -1,7 +1,8 @@
 # Add TRT tests
 nv_library(tensorrt_converter
 SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
-batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc pad_op.cc
+batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
+pad_op.cc split_op.cc
 DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)

 nv_test(test_op_converter SRCS test_op_converter.cc DEPS
@@ -28,6 +29,8 @@ nv_test(test_trt_concat_op SRCS test_concat_op.cc concat_op.cc
 DEPS ${FLUID_CORE_MODULES} tensorrt_engine concat_op SERIAL)
 nv_test(test_trt_dropout_op SRCS test_dropout_op.cc dropout_op.cc
 DEPS ${FLUID_CORE_MODULES} tensorrt_engine dropout_op SERIAL)
-
 nv_test(test_trt_pad_op SRCS test_pad_op.cc pad_op.cc
 DEPS ${FLUID_CORE_MODULES} tensorrt_engine pad_op SERIAL)
+nv_test(test_trt_split_op SRCS test_split_op.cc split_op.cc
+        DEPS ${FLUID_CORE_MODULES} tensorrt_engine tensorrt_plugin
+split_op concat_op SERIAL)
diff --git a/paddle/fluid/inference/tensorrt/convert/split_op.cc b/paddle/fluid/inference/tensorrt/convert/split_op.cc
new file mode 100644
index 000000000..60d07859f
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/split_op.cc
@@ -0,0 +1,73 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+/*
+ * SplitOp.
+ */
+class SplitOpConverter : public OpConverter {
+ public:
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope, bool test_mode) override {
+    VLOG(40) << "convert a fluid split op to tensorrt split layer";
+
+    framework::OpDesc op_desc(op, nullptr);
+    // Declare inputs
+    auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
+    auto input_dims = input->getDimensions();
+    int input_num = op_desc.Input("X").size();
+    size_t output_num = op_desc.Output("Out").size();
+
+    PADDLE_ENFORCE(input_num == 1);
+    int axis = boost::get<int>(op_desc.GetAttr("axis"));
+    std::vector<int> output_lengths =
+        boost::get<std::vector<int>>(op_desc.GetAttr("sections"));
+    PADDLE_ENFORCE(axis != 0);
+    if (axis < 0) {
+      axis += input_dims.nbDims;
+    } else {
+      axis -= 1;
+    }
+
+    PADDLE_ENFORCE(output_lengths.size() == output_num);
+
+    SplitPlugin* plugin = new SplitPlugin(axis, output_lengths);
+    nvinfer1::IPluginLayer* layer =
+        engine_->addPlugin(&input, input_num, plugin);
+
+    std::string layer_name = "split (Output: ";
+    for (size_t i = 0; i < output_num; i++) {
+      auto output_name = op_desc.Output("Out")[i];
+      layer->getOutput(i)->setName(output_name.c_str());
+      engine_->SetITensor(output_name, layer->getOutput(i));
+      layer_name += output_name;
+      if (test_mode) {
+        engine_->DeclareOutput(output_name);
+      }
+    }
+    layer->setName((layer_name + ")").c_str());
+  }
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+REGISTER_TRT_OP_CONVERTER(split, SplitOpConverter);
diff --git a/paddle/fluid/inference/tensorrt/convert/test_split_op.cc b/paddle/fluid/inference/tensorrt/convert/test_split_op.cc
new file mode 100644
index 000000000..f81d01155
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/test_split_op.cc
@@ -0,0 +1,53 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +TEST(split_op, test) { + std::unordered_set parameters({""}); + framework::Scope scope; + TRTConvertValidation validator(10, parameters, scope, 1000); + validator.DeclInputVar("split_input", nvinfer1::DimsCHW(3, 2, 2)); + validator.DeclOutputVar("split_out1", nvinfer1::DimsCHW(2, 2, 2)); + validator.DeclOutputVar("split_out2", nvinfer1::DimsCHW(1, 2, 2)); + + // Prepare Op description + framework::OpDesc desc; + desc.SetType("split"); + desc.SetInput("X", {"split_input"}); + desc.SetOutput("Out", {"split_out1", "split_out2"}); + + int num = 0; + int axis = 1; + std::vector output_lengths = {2, 1}; + desc.SetAttr("axis", axis); + desc.SetAttr("num", num); + desc.SetAttr("sections", output_lengths); + + validator.SetOp(*desc.Proto()); + + validator.Execute(1); +} + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +USE_OP(split); diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc index 9e0f95844..426bf169b 100644 --- a/paddle/fluid/inference/tensorrt/engine.cc +++ b/paddle/fluid/inference/tensorrt/engine.cc @@ -254,6 +254,12 @@ void TensorRTEngine::freshDeviceId() { cudaSetDevice(device_); } +nvinfer1::IPluginLayer *TensorRTEngine::addPlugin( + nvinfer1::ITensor *const *inputs, int nbInputs, PluginTensorRT *plugin) { + owned_plugin_.emplace_back(plugin); + return infer_network_.get()->addPluginExt(inputs, nbInputs, *plugin); +} + } // namespace tensorrt } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h index 828181200..216606a29 100644 --- a/paddle/fluid/inference/tensorrt/engine.h +++ b/paddle/fluid/inference/tensorrt/engine.h @@ -22,6 +22,7 @@ limitations under the License. */ #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/inference/engine.h" #include "paddle/fluid/inference/tensorrt/helper.h" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" #include "paddle/fluid/inference/utils/singleton.h" namespace paddle { @@ -125,6 +126,8 @@ class TensorRTEngine : public EngineBase { void SetRuntimeBatch(size_t batch_size); int GetRuntimeBatch(); int GetDevice() { return device_; } + nvinfer1::IPluginLayer* addPlugin(nvinfer1::ITensor* const* inputs, + int nbInputs, PluginTensorRT*); // A pointer to CPU memory is needed of the TRT weight. // Before TRT runs, fluid loads weight into GPU storage. @@ -164,8 +167,10 @@ class TensorRTEngine : public EngineBase { std::unordered_map buffer_sizes_; std::unordered_map itensor_map_; + // The specific GPU id that the TensorRTEngine bounded to. 
int device_; + std::vector> owned_plugin_; // TensorRT related internal members template diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt index 1b91c864c..71b7a5516 100644 --- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt @@ -1,2 +1 @@ -nv_library(tensorrt_plugin SRCS plugin_factory.cc plugin_utils.cc -trt_plugin.cc split_op_plugin.cu DEPS enforce) +nv_library(tensorrt_plugin SRCS trt_plugin.cc split_op_plugin.cu DEPS enforce) diff --git a/paddle/fluid/inference/tensorrt/plugin/plugin_factory.cc b/paddle/fluid/inference/tensorrt/plugin/plugin_factory.cc deleted file mode 100644 index 5ebcd4461..000000000 --- a/paddle/fluid/inference/tensorrt/plugin/plugin_factory.cc +++ /dev/null @@ -1,64 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/inference/tensorrt/plugin/plugin_factory.h" - -namespace paddle { -namespace inference { -namespace tensorrt { - -PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layer_name, - const void* serial_data, - size_t serial_length) { - size_t parsed_byte = 0; - std::string encoded_op_name = - ExtractOpName(serial_data, serial_length, &parsed_byte); - - if (!IsPlugin(encoded_op_name)) { - return nullptr; - } - - auto plugin_ptr = - plugin_registry_[encoded_op_name].first(serial_data, serial_length); - owned_plugins_.emplace_back(plugin_ptr); - - return plugin_ptr; -} - -PluginTensorRT* PluginFactoryTensorRT::CreatePlugin( - const std::string& op_name) { - if (!IsPlugin(op_name)) return nullptr; - - auto plugin_ptr = plugin_registry_[op_name].second(); - owned_plugins_.emplace_back(plugin_ptr); - - return plugin_ptr; -} - -bool PluginFactoryTensorRT::RegisterPlugin( - const std::string& op_name, PluginDeserializeFunc deserialize_func, - PluginConstructFunc construct_func) { - if (IsPlugin(op_name)) return false; - - auto ret = plugin_registry_.emplace( - op_name, std::make_pair(deserialize_func, construct_func)); - - return ret.second; -} - -void PluginFactoryTensorRT::DestroyPlugins() { owned_plugins_.clear(); } - -} // namespace tensorrt -} // namespace inference -} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/plugin_factory.h b/paddle/fluid/inference/tensorrt/plugin/plugin_factory.h deleted file mode 100644 index 00435766f..000000000 --- a/paddle/fluid/inference/tensorrt/plugin/plugin_factory.h +++ /dev/null @@ -1,91 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include - -#include "NvInfer.h" -#include "paddle/fluid/inference/tensorrt/plugin/plugin_utils.h" -#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" -#include "paddle/fluid/platform/enforce.h" - -namespace paddle { -namespace inference { -namespace tensorrt { - -class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { - public: - static PluginFactoryTensorRT* GetInstance() { - static PluginFactoryTensorRT* factory_instance = - new PluginFactoryTensorRT(); - return factory_instance; - } - - // Deserialization method - PluginTensorRT* createPlugin(const char* layer_name, const void* serial_data, - size_t serial_length) override; - - // Plugin construction, PluginFactoryTensorRT owns the plugin. - PluginTensorRT* CreatePlugin(const std::string& op_name); - - bool RegisterPlugin(const std::string& op_name, - PluginDeserializeFunc deserialize_func, - PluginConstructFunc construct_func); - - bool IsPlugin(const std::string& op_name) { - return plugin_registry_.find(op_name) != plugin_registry_.end(); - } - - size_t CountOwnedPlugins() { return owned_plugins_.size(); } - - void DestroyPlugins(); - - protected: - std::unordered_map> - plugin_registry_; - std::vector> owned_plugins_; -}; - -class TrtPluginRegistrar { - public: - TrtPluginRegistrar(const std::string& name, - PluginDeserializeFunc deserialize_func, - PluginConstructFunc construct_func) { - auto factory = PluginFactoryTensorRT::GetInstance(); - // platform::PADDLE_ENFORCE(factory->RegisterPlugin(name, deserialize_func, - // construct_func), "Falied to register plugin [%s]", name); - // platform::PADDLE_ENFORCE(factory->RegisterPlugin(name, deserialize_func, - // construct_func)); - factory->RegisterPlugin(name, deserialize_func, construct_func); - } -}; - -#define REGISTER_TRT_PLUGIN(name, deserialize_func, construct_func) \ - REGISTER_TRT_PLUGIN_UNIQ_HELPER(__COUNTER__, name, deserialize_func, \ - construct_func) -#define REGISTER_TRT_PLUGIN_UNIQ_HELPER(ctr, name, deserialize_func, \ - construct_func) \ - REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func, construct_func) -#define REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func, construct_func) \ - static ::paddle::inference::tensorrt::TrtPluginRegistrar \ - trt_plugin_registrar##ctr __attribute__((unused)) = \ - ::paddle::inference::tensorrt::TrtPluginRegistrar( \ - name, deserialize_func, construct_func) - -} // namespace tensorrt -} // namespace inference -} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/plugin_utils.cc b/paddle/fluid/inference/tensorrt/plugin/plugin_utils.cc deleted file mode 100644 index 2cc4162aa..000000000 --- a/paddle/fluid/inference/tensorrt/plugin/plugin_utils.cc +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/inference/tensorrt/plugin/plugin_utils.h" -#include - -namespace paddle { -namespace inference { -namespace tensorrt { - -std::string ExtractOpName(const void* serial_data, size_t serial_length, - size_t* incremental) { - size_t op_name_char_count = *static_cast(serial_data); - *incremental = sizeof(size_t) + op_name_char_count; - - assert(serial_length >= *incremental); - - const char* buffer = static_cast(serial_data) + sizeof(size_t); - std::string op_name(buffer, op_name_char_count); - - return op_name; -} - -} // namespace tensorrt -} // namespace inference -} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/plugin_utils.h b/paddle/fluid/inference/tensorrt/plugin/plugin_utils.h deleted file mode 100644 index fb6608c12..000000000 --- a/paddle/fluid/inference/tensorrt/plugin/plugin_utils.h +++ /dev/null @@ -1,34 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include - -#include "NvInfer.h" -#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" - -namespace paddle { -namespace inference { -namespace tensorrt { - -typedef std::function - PluginDeserializeFunc; -typedef std::function PluginConstructFunc; - -std::string ExtractOpName(const void* serial_data, size_t serial_length, - size_t* incremental); - -} // namespace tensorrt -} // namespace inference -} // namespze paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/serialize.hpp b/paddle/fluid/inference/tensorrt/plugin/serialize.h similarity index 100% rename from paddle/fluid/inference/tensorrt/plugin/serialize.hpp rename to paddle/fluid/inference/tensorrt/plugin/serialize.h diff --git a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu index 044c229b5..ed43c4d43 100644 --- a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include #include #include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h" @@ -19,8 +20,6 @@ namespace paddle { namespace inference { namespace tensorrt { -SplitPlugin* CreateSplitPlugin() { return new SplitPlugin(); }; - nvinfer1::Dims SplitPlugin::getOutputDimensions(int index, const nvinfer1::Dims* inputDims, int nbInputs) { @@ -28,15 +27,16 @@ nvinfer1::Dims SplitPlugin::getOutputDimensions(int index, assert(index < this->getNbOutputs()); nvinfer1::Dims const& input_dims = inputDims[0]; nvinfer1::Dims output_dims = input_dims; - output_dims.d[axis_] = output_lenght_.at(index); + output_dims.d[axis_] = output_length_.at(index); return output_dims; } int SplitPlugin::initialize() { std::vector segment_offsets(1, 0); for (int i = 0; i < this->getNbOutputs(); ++i) { - segment_offsets.push_back(segment_offsets.back() + output_lenght_[i]); + segment_offsets.push_back(segment_offsets.back() + output_length_[i]); } + segment_offsets_ = segment_offsets; d_segment_offsets_ = segment_offsets; nvinfer1::Dims dims = this->getInputDims(0); nx_ = 1; @@ -51,60 +51,30 @@ int SplitPlugin::initialize() { return 0; } -template -__device__ int upper_bound(T const* vals, int n, T const& key) { - int i = 0; - while (n > 0) { - int m = n / 2; - int j = i + m; - if (!(key < vals[j])) { - i = j + 1; - n -= m + 1; - } else { - n = m; - } - } - return i; -} - -template -__global__ void split_kernel(int nsegment, - int const* __restrict__ segment_offsets, - T const* __restrict__ idata, T* const* odatas, - int nx, int srcny_, int nz) { - int x0 = threadIdx.x + blockIdx.x * blockDim.x; - int src_y0 = threadIdx.y + blockIdx.y * blockDim.y; - int z0 = threadIdx.z + blockIdx.z * blockDim.z; - for (int z = z0; z < nz; z += blockDim.z * gridDim.z) { - for (int src_y = src_y0; src_y < srcny_; src_y += blockDim.y * gridDim.y) { - for (int x = x0; x < nx; x += blockDim.x * gridDim.x) { - int segment = upper_bound(segment_offsets, nsegment, src_y) - 1; - int dst_y = src_y - segment_offsets[segment]; - int dstny_ = segment_offsets[segment + 1] - segment_offsets[segment]; - odatas[segment][x + nx * (dst_y + dstny_ * z)] = - idata[x + nx * (src_y + srcny_ * z)]; - } - } - } -} - int SplitPlugin::enqueue(int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) { auto const& input_dims = this->getInputDims(0); + int input_size = 0; int const* d_segment_offsets_ptr = thrust::raw_pointer_cast(&d_segment_offsets_[0]); float const* idata = reinterpret_cast(inputs[0]); float** odatas = reinterpret_cast(outputs); - int nz = nz_ * batchSize; - dim3 block(32, 16); - dim3 grid(std::min((nx_ - 1) / block.x + 1, 65535u), - std::min((ny_ - 1) / block.y + 1, 65535u), - std::min((nz_ - 1) / block.z + 1, 65535u)); - - split_kernel<<>>(d_segment_offsets_.size(), - d_segment_offsets_ptr, idata, odatas, - nx_, ny_, nz); + // kernel impl here. 
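+  // Each output i owns rows [segment_offsets_[i], segment_offsets_[i + 1]) of
+  // the split axis; every such row holds nx_ contiguous floats, so the loop
+  // below issues one device-to-device cudaMemcpyAsync per (output, batch item)
+  // pair in place of the removed split_kernel launch.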
+ int inputBatchOffset = nx_ * ny_ * nz_; + for (size_t i = 0; i < this->getNbOutputs(); i++) { + for (size_t j = 0; j < batchSize; j++) { + cudaMemcpyAsync( + odatas[i] + + j * (segment_offsets_[i + 1] - segment_offsets_[i]) * nx_ * + sizeof(float), + inputs[0] + + (inputBatchOffset * j + segment_offsets_[i] * nx_) * + sizeof(float), + (segment_offsets_[i + 1] - segment_offsets_[i]) * nx_ * sizeof(float), + cudaMemcpyDeviceToDevice, stream); + } + } return cudaGetLastError() != cudaSuccess; } diff --git a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h index 406c822bb..59be60911 100644 --- a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h @@ -1,8 +1,21 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once -#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" #include +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" namespace paddle { namespace inference { @@ -10,53 +23,55 @@ namespace tensorrt { class SplitPlugin : public PluginTensorRT { int axis_; - std::vector output_lenght_; + std::vector output_length_; int nx_, ny_, nz_; thrust::device_vector d_segment_offsets_; + std::vector segment_offsets_; protected: virtual size_t getSerializationSize() override { - return serialized_size(axis_) + serialized_size(output_lenght_) - + getBaseSerializationSize(); + return serialized_size(axis_) + serialized_size(output_length_) + + getBaseSerializationSize(); } virtual void serialize(void *buffer) override { serializeBase(buffer); serialize_value(&buffer, axis_); - serialize_value(&buffer, output_lenght_); + serialize_value(&buffer, output_length_); } public: - Split() {} - SplitPlugin(void const* serialData, size_t serialLength) { + SplitPlugin(int axis, std::vector const &output_lengths) + : axis_(axis), output_length_(output_lengths) { + assert(axis <= nvinfer1::Dims::MAX_DIMS); + } + + SplitPlugin(void const *serialData, size_t serialLength) { deserializeBase(serialData, serialLength); deserialize_value(&serialData, &serialLength, &axis_); - deserialize_value(&serialData, &serialLength, &output_lenght_); + deserialize_value(&serialData, &serialLength, &output_length_); } - SplitPlugin* clone() const override { - return new SplitPlugin(axis_, output_lenght_); + SplitPlugin *clone() const override { + return new SplitPlugin(axis_, output_length_); } - virtual const char* getPluginType() const override { return "split"; } - virtual int getNbOutputs() const override { return output_lenght_.size(); } + virtual const char *getPluginType() const override { return "split"; } + virtual int getNbOutputs() const override { return output_length_.size(); } virtual nvinfer1::Dims getOutputDimensions(int index, - const nvinfer1::Dims *inputs, int nbInputDims) override; + const nvinfer1::Dims *inputs, + int nbInputDims) override; virtual int initialize() override; 
- virtual int enqueue(int batchSize, - const void *const *inputs, void **outputs, + virtual int enqueue(int batchSize, const void *const *inputs, void **outputs, void *workspace, cudaStream_t stream) override; - void setAxis(int axis) { - axis_ = axis; - } + void setAxis(int axis) { axis_ = axis; } - void setOutputLengths(const std::vector & output_lengths) { + void setOutputLengths(const std::vector &output_lengths) { output_length_ = output_lengths; } - }; -} // tensorrt -} // inference -} // paddle +} // tensorrt +} // inference +} // paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc index 4eff6665d..975a5ed16 100644 --- a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc +++ b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc @@ -13,7 +13,6 @@ // limitations under the License. #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" -#include "paddle/fluid/inference/tensorrt/plugin/plugin_utils.h" namespace paddle { namespace inference { @@ -41,8 +40,7 @@ size_t PluginTensorRT::getBaseSerializationSize() { bool PluginTensorRT::supportsFormat(nvinfer1::DataType type, nvinfer1::PluginFormat format) const { - return ((type == nvinfer1::DataType::kFLOAT || - type == nvinfer1::DataType::kHALF) && + return ((type == nvinfer1::DataType::kFLOAT) && (format == nvinfer1::PluginFormat::kNCHW)); } diff --git a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h index 8168646bd..44869b390 100644 --- a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h @@ -14,14 +14,14 @@ #pragma once -#include #include #include #include #include #include +#include "NvInfer.h" -#include "paddle/fluid/inference/tensorrt/plugin/serialize.hpp" +#include "paddle/fluid/inference/tensorrt/plugin/serialize.h" namespace paddle { namespace inference { @@ -53,8 +53,8 @@ class PluginTensorRT : public nvinfer1::IPluginExt { nvinfer1::DataType type, nvinfer1::PluginFormat format, int maxBatchSize) override; - virtual void serialize(void* buffer) override; - virtual size_t getSerializationSize() override; + virtual void serialize(void* buffer) = 0; + virtual size_t getSerializationSize() = 0; protected: void deserializeBase(void const*& serialData, size_t& serialLength); -- GitLab
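Reviewer note (not part of the patch): the converter maps Fluid's split attributes onto TensorRT's batch-less shape. Axis 0, the implicit batch dimension, is rejected; negative axes are wrapped by the TensorRT rank; positive axes are shifted down by one; and the "sections" attribute becomes the per-output lengths handed to SplitPlugin. Below is a minimal host-side sketch of that mapping and of the offset arithmetic SplitPlugin::enqueue relies on. remap_axis and reference_split are illustrative names, not functions from the patch.

#include <cassert>
#include <vector>

// Map a Fluid split axis (batch dimension included) onto a TensorRT axis
// (batch dimension excluded), mirroring SplitOpConverter above.
int remap_axis(int fluid_axis, int trt_nb_dims) {
  assert(fluid_axis != 0);  // splitting the implicit batch axis is rejected
  return fluid_axis < 0 ? fluid_axis + trt_nb_dims : fluid_axis - 1;
}

// CPU reference of the copy performed in SplitPlugin::enqueue for the common
// case where nothing precedes the split axis: one batch item is laid out as
// [ny][nx], and output i keeps rows [offsets[i], offsets[i + 1]).
std::vector<std::vector<float>> reference_split(const std::vector<float>& input,
                                                int batch, int ny, int nx,
                                                const std::vector<int>& sections) {
  std::vector<int> offsets(1, 0);
  for (int s : sections) offsets.push_back(offsets.back() + s);
  assert(offsets.back() == ny);

  std::vector<std::vector<float>> outputs(sections.size());
  for (size_t i = 0; i < sections.size(); ++i) {
    for (int b = 0; b < batch; ++b) {
      const float* src = input.data() + (b * ny + offsets[i]) * nx;
      outputs[i].insert(outputs[i].end(), src, src + sections[i] * nx);
    }
  }
  return outputs;
}

For the unit test above (input DimsCHW(3, 2, 2), axis = 1, sections = {2, 1}), remap_axis(1, 3) yields 0, and reference_split returns slices of 2*2*2 and 1*2*2 elements per batch item, matching the declared dimensions of split_out1 and split_out2.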
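A second note, on serialization: SplitPlugin writes its base fields via serializeBase, then axis_, then output_length_, and the (serialData, serialLength) constructor reads them back in the same order. The serialize.h helpers are only renamed by this patch, not shown. Assuming they copy POD values verbatim and store a vector as an element count followed by its elements, the plugin-specific tail of the buffer looks like the sketch below; append_bytes and serialize_split_fields are illustrative stand-ins, not the real helpers.

#include <vector>

// Illustrative stand-in for the (unshown) serialize.h helpers: append a POD
// value's raw bytes to a byte buffer.
template <typename T>
void append_bytes(std::vector<char>* buf, const T& value) {
  const char* p = reinterpret_cast<const char*>(&value);
  buf->insert(buf->end(), p, p + sizeof(T));
}

// Plugin-specific tail of the buffer SplitPlugin::serialize() would produce
// under the assumption above: axis_ first, then the section count and values.
// The base-class fields written by serializeBase() would precede this.
std::vector<char> serialize_split_fields(int axis,
                                         const std::vector<int>& output_length) {
  std::vector<char> buf;
  append_bytes(&buf, axis);
  append_bytes(&buf, output_length.size());
  for (int len : output_length) append_bytes(&buf, len);
  return buf;
}

Whatever the helpers' real encoding is, getSerializationSize() must report exactly the number of bytes serialize() writes, which is why it sums getBaseSerializationSize() with serialized_size(axis_) and serialized_size(output_length_).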