From f3d164faad585bc7eeff582cba6b035d054e16c7 Mon Sep 17 00:00:00 2001 From: nhzlx Date: Fri, 22 Feb 2019 06:54:58 +0000 Subject: [PATCH] 5. add static trt load model 1). add static trt load model 2). fix bug: when device_id is not 0, the trt will have a bug test=develop --- .../inference/analysis/ir_pass_manager.cc | 1 + .../ir_passes/tensorrt_subgraph_pass.cc | 13 ++-- .../inference/tensorrt/convert/conv2d_op.cc | 2 +- .../tensorrt/convert/elementwise_op.cc | 3 +- .../fluid/inference/tensorrt/convert/fc_op.cc | 4 +- .../inference/tensorrt/convert/prelu_op.cc | 19 ++--- .../inference/tensorrt/convert/ut_helper.h | 16 ++-- paddle/fluid/inference/tensorrt/engine.cc | 9 +++ paddle/fluid/inference/tensorrt/engine.h | 33 ++++---- paddle/fluid/inference/tensorrt/helper.h | 29 +++++++ .../inference/tensorrt/plugin/CMakeLists.txt | 3 +- .../tensorrt/plugin/avg_pool_op_plugin.cu | 7 ++ .../tensorrt/plugin/avg_pool_op_plugin.h | 14 ++-- .../tensorrt/plugin/elementwise_op_plugin.cu | 11 ++- .../tensorrt/plugin/elementwise_op_plugin.h | 20 +++-- .../tensorrt/plugin/prelu_op_plugin.cu | 15 +++- .../tensorrt/plugin/prelu_op_plugin.h | 43 +++++++---- .../tensorrt/plugin/split_op_plugin.cu | 6 ++ .../tensorrt/plugin/split_op_plugin.h | 8 +- .../inference/tensorrt/plugin/trt_plugin.h | 9 ++- .../tensorrt/plugin/trt_plugin_factory.cc | 48 ++++++++++++ .../tensorrt/plugin/trt_plugin_factory.h | 76 +++++++++++++++++++ .../{serialize.h => trt_plugin_utils.h} | 2 +- .../operators/tensorrt/tensorrt_engine_op.h | 10 ++- 24 files changed, 318 insertions(+), 83 deletions(-) create mode 100644 paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.cc create mode 100644 paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h rename paddle/fluid/inference/tensorrt/plugin/{serialize.h => trt_plugin_utils.h} (99%) diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc index 768dd00bc..3e5525b1e 100644 --- a/paddle/fluid/inference/analysis/ir_pass_manager.cc +++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc @@ -82,6 +82,7 @@ void IRPassManager::CreatePasses(Argument *argument, "model_opt_cache_dir", new std::string(GetOrCreateModelOptCacheDir(model_opt_cache_dir))); pass->Set("predictor_id", new int(argument->predictor_id())); + pass->Set("gpu_device_id", new int(argument->gpu_device_id())); } pre_pass = pass_name; diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index 7f564f321..6f23330d6 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -242,7 +242,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp( tensorrt::TensorRTEngine *trt_engine = inference::Singleton::Global().Create( Get("max_batch_size"), Get("workspace_size"), enable_int8, - calibrator.get(), engine_key); + calibrator.get(), engine_key, Get("gpu_device_id")); if (trt_engine_serialized_data.size() == 0) { LOG(INFO) << "Prepare TRT engine (Optimize model structure, Select OP " "kernel etc). This process may cost a lot of time."; @@ -258,13 +258,16 @@ void TensorRtSubgraphPass::CreateTensorRTOp( trt_engine_serialized_data = std::string((const char *)serialized_engine_data->data(), serialized_engine_data->size()); - // SaveTrtEngineSerializedDataToFile(GetTrtEngineSerializedPath(Get("model_opt_cache_dir"), - // engine_key), - // trt_engine_serialized_data); + SaveTrtEngineSerializedDataToFile( + GetTrtEngineSerializedPath(Get("model_opt_cache_dir"), + engine_key), + trt_engine_serialized_data); } else { + LOG(INFO) << "Load TRT Engine from optimized serialized data : " + << GetTrtEngineSerializedPath( + Get("model_opt_cache_dir"), engine_key); trt_engine->Deserialize(trt_engine_serialized_data); } - SetAttr(op_desc->Proto(), "engine_serialized_data", trt_engine_serialized_data); } diff --git a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc index ae1849f43..39a99a21e 100644 --- a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc @@ -44,7 +44,7 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op, weight_tensor->Resize(Y_t->dims()); TensorCopySync((*Y_t), cpu_place, weight_tensor.get()); - auto* weight_data = weight_tensor->mutable_data(platform::CPUPlace()); + auto* weight_data = weight_tensor->mutable_data(cpu_place); PADDLE_ENFORCE_EQ(weight_tensor->dims().size(), 4UL); const int n_output = weight_tensor->dims()[0]; diff --git a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc index 79362f967..0c5a1a6ef 100644 --- a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc @@ -153,7 +153,6 @@ class ElementwiseTensorOpConverter : public OpConverter { if (CheckDims(dims_x, dims_y)) { // The two input tensor should have the same dims VLOG(3) << "Convert a fluid elementwise op to TensorRT IElementWiseLayer"; - nvinfer1::IElementWiseLayer* layer = TRT_ENGINE_ADD_LAYER( engine_, ElementWise, *const_cast(X), *const_cast(Y), op_pair->second); @@ -166,7 +165,7 @@ class ElementwiseTensorOpConverter : public OpConverter { "ElementWisePluginLayer"; plugin::ElementWisePlugin* plugin = - new plugin::ElementWisePlugin(op_pair->second, dims_x, dims_y, axis); + new plugin::ElementWisePlugin(op_type_, dims_x, dims_y, axis); plugin->AddInput(X); plugin->AddInput(Y); nvinfer1::IPluginLayer* layer = engine_->AddPlugin( diff --git a/paddle/fluid/inference/tensorrt/convert/fc_op.cc b/paddle/fluid/inference/tensorrt/convert/fc_op.cc index eef4fab4e..42dcd68e4 100644 --- a/paddle/fluid/inference/tensorrt/convert/fc_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/fc_op.cc @@ -85,10 +85,10 @@ class FcOpConverter : public OpConverter { Y_t->dims()[0] * Y_t->dims()[1] * sizeof(float)); TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT, static_cast(weight_data), - Y_t->memory_size() / sizeof(float)}; + static_cast(Y_t->numel())}; TensorRTEngine::Weight tmp_weight(nvinfer1::DataType::kFLOAT, static_cast(tmp->data()), - Y_t->memory_size() / sizeof(float)); + static_cast(Y_t->numel())); weight.dims.assign({Y_t->dims()[0], Y_t->dims()[1]}); tmp_weight.dims = weight.dims; diff --git a/paddle/fluid/inference/tensorrt/convert/prelu_op.cc b/paddle/fluid/inference/tensorrt/convert/prelu_op.cc index dbdff85dd..2ae804106 100644 --- a/paddle/fluid/inference/tensorrt/convert/prelu_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/prelu_op.cc @@ -43,23 +43,20 @@ class PReluOpConverter : public OpConverter { PADDLE_ENFORCE_NOT_NULL(alpha_var); auto* alpha_tensor = alpha_var->GetMutable(); - platform::CUDAPlace place; - std::unique_ptr alpha_tensor_device( + platform::CPUPlace cpu_place; + std::unique_ptr alpha_tensor_temp( new framework::LoDTensor()); - alpha_tensor_device->Resize(alpha_tensor->dims()); - TensorCopySync(*alpha_tensor, place, alpha_tensor_device.get()); - float* alpha_data = alpha_tensor_device->mutable_data(place); + alpha_tensor_temp->Resize(alpha_tensor->dims()); + TensorCopySync(*alpha_tensor, cpu_place, alpha_tensor_temp.get()); + float* alpha_data = alpha_tensor_temp->mutable_data(cpu_place); - // Transform alpha to TensorRTEngine::Weight - TensorRTEngine::Weight alpha_rt(nvinfer1::DataType::kFLOAT, - static_cast(alpha_data), - alpha_tensor_device->numel()); - plugin::PReluPlugin* plugin = new plugin::PReluPlugin(alpha_rt, mode); + plugin::PReluPlugin* plugin = + new plugin::PReluPlugin(alpha_data, alpha_tensor_temp->numel(), mode); nvinfer1::IPluginLayer* layer = engine_->AddPlugin(&input, input_num, plugin); // keep alpha tensor to avoid release it's memory engine_->weight_map[op_desc.Input("Alpha")[0]] = - std::move(alpha_tensor_device); + std::move(alpha_tensor_temp); std::string layer_name = "prelu (Output: "; auto output_name = op_desc.Output("Out")[0]; diff --git a/paddle/fluid/inference/tensorrt/convert/ut_helper.h b/paddle/fluid/inference/tensorrt/convert/ut_helper.h index c02a6d8da..d7cca0e45 100644 --- a/paddle/fluid/inference/tensorrt/convert/ut_helper.h +++ b/paddle/fluid/inference/tensorrt/convert/ut_helper.h @@ -79,7 +79,8 @@ class TRTConvertValidation { if_add_batch_(if_add_batch), max_batch_size_(max_batch_size) { PADDLE_ENFORCE_EQ(cudaStreamCreate(&stream_), 0); - engine_.reset(new TensorRTEngine(max_batch_size, workspace_size)); + engine_.reset( + new TensorRTEngine(max_batch_size, workspace_size, false, nullptr, 0)); engine_->InitNetwork(); } @@ -114,13 +115,12 @@ class TRTConvertValidation { } void DeclVar(const std::string& name, const std::vector dim_vec) { - platform::CUDAPlace place; - platform::CUDADeviceContext ctx(place); + platform::CUDADeviceContext ctx(place_); auto* x = scope_.Var(name); auto* x_tensor = x->GetMutable(); x_tensor->Resize(framework::make_ddim(dim_vec)); - RandomizeTensor(x_tensor, place, ctx); + RandomizeTensor(x_tensor, place_, ctx); } // Declare a variable in a fluid Scope. void DeclVar(const std::string& name, const nvinfer1::Dims& dims, @@ -155,9 +155,8 @@ class TRTConvertValidation { std::unordered_set neglected_output = {}) { // Execute Fluid Op PADDLE_ENFORCE_LE(batch_size, max_batch_size_); - platform::CUDAPlace place; - platform::CUDADeviceContext ctx(place); - op_->Run(scope_, place); + platform::CUDADeviceContext ctx(place_); + op_->Run(scope_, place_); std::vector input_output_names; @@ -188,7 +187,7 @@ class TRTConvertValidation { auto* tensor = var->GetMutable(); const int bind_index = engine_->engine()->getBindingIndex(name.c_str()); buffers[bind_index] = - static_cast(tensor->mutable_data(place)); + static_cast(tensor->mutable_data(place_)); } // Execute TRT. @@ -220,6 +219,7 @@ class TRTConvertValidation { framework::Scope& scope() { return scope_; } private: + platform::CUDAPlace place_; std::unique_ptr engine_; cudaStream_t stream_; std::unique_ptr op_; diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc index 805f047c9..fddf5f11c 100644 --- a/paddle/fluid/inference/tensorrt/engine.cc +++ b/paddle/fluid/inference/tensorrt/engine.cc @@ -34,6 +34,7 @@ void TensorRTEngine::Build(const DescType &paddle_model) { void TensorRTEngine::Execute(int batch_size, std::vector *buffers, cudaStream_t stream) { + freshDeviceId(); batch_size_ = batch_size; infer_context_->enqueue(batch_size, buffers->data(), stream, nullptr); cudaStreamSynchronize(stream); @@ -41,6 +42,7 @@ void TensorRTEngine::Execute(int batch_size, std::vector *buffers, } void TensorRTEngine::FreezeNetwork() { + freshDeviceId(); VLOG(3) << "TRT to freeze network"; PADDLE_ENFORCE(infer_builder_ != nullptr, "Call InitNetwork first to initialize network."); @@ -140,6 +142,13 @@ nvinfer1::IPluginLayer *TensorRTEngine::AddPlugin( return infer_network_.get()->addPluginExt(inputs, num_inputs, *plugin); } +void TensorRTEngine::freshDeviceId() { + int count; + cudaGetDeviceCount(&count); + PADDLE_ENFORCE_LT(device_id_, count); + cudaSetDevice(device_id_); +} + } // namespace tensorrt } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h index cc378f4ab..6abc9a1f0 100644 --- a/paddle/fluid/inference/tensorrt/engine.h +++ b/paddle/fluid/inference/tensorrt/engine.h @@ -23,6 +23,7 @@ limitations under the License. */ #include "paddle/fluid/inference/engine.h" #include "paddle/fluid/inference/tensorrt/helper.h" #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h" #include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h" #include "paddle/fluid/inference/utils/singleton.h" @@ -59,12 +60,13 @@ class TensorRTEngine { }; TensorRTEngine(int max_batch, int max_workspace, bool enable_int8 = false, - TRTInt8Calibrator* calibrator = nullptr, + TRTInt8Calibrator* calibrator = nullptr, int device_id = 0, nvinfer1::ILogger& logger = NaiveLogger::Global()) : max_batch_(max_batch), max_workspace_(max_workspace), enable_int8_(enable_int8), calibrator_(calibrator), + device_id_(device_id), logger_(logger) {} ~TensorRTEngine() {} @@ -78,6 +80,7 @@ class TensorRTEngine { // Initialize the inference network, so that TensorRT layers can add to this // network. void InitNetwork() { + freshDeviceId(); infer_builder_.reset(createInferBuilder(&logger_)); infer_network_.reset(infer_builder_->createNetwork()); } @@ -113,20 +116,11 @@ class TensorRTEngine { } void Deserialize(const std::string& engine_serialized_data) { - infer_ptr runtime(createInferRuntime(&logger_)); - infer_engine_.reset( - runtime->deserializeCudaEngine(engine_serialized_data.c_str(), - engine_serialized_data.size(), nullptr)); - PADDLE_ENFORCE(infer_engine_ != nullptr, - "build cuda engine failed when deserialize engine info.!"); - infer_context_.reset(infer_engine_->createExecutionContext()); - } - - void Deserialize(const nvinfer1::IHostMemory* engine_serialized_data) { + freshDeviceId(); infer_ptr runtime(createInferRuntime(&logger_)); infer_engine_.reset(runtime->deserializeCudaEngine( - engine_serialized_data->data(), engine_serialized_data->size(), - nullptr)); + engine_serialized_data.c_str(), engine_serialized_data.size(), + &inference::Singleton::Global())); PADDLE_ENFORCE(infer_engine_ != nullptr, "build cuda engine failed when deserialize engine info.!"); infer_context_.reset(infer_engine_->createExecutionContext()); @@ -134,6 +128,7 @@ class TensorRTEngine { void SetRuntimeBatch(size_t batch_size); int GetRuntimeBatch(); + int GetDeviceId() { return device_id_; } nvinfer1::IPluginLayer* AddPlugin(nvinfer1::ITensor* const* inputs, int num_inputs, plugin::PluginTensorRT*); @@ -146,6 +141,11 @@ class TensorRTEngine { weight_map; private: + // Each ICudaEngine object is bound to a specific GPU when it is instantiated, + // ensure that the thread is associated with the correct device by calling + // freshDeviceId(). + void freshDeviceId(); + // the max batch size int max_batch_; // the runtime batch size @@ -158,6 +158,7 @@ class TensorRTEngine { // batch size of the current data, will be updated each Executation. int batch_size_{-1}; + int device_id_; nvinfer1::ILogger& logger_; // max data size for the buffers. @@ -216,10 +217,10 @@ class TRTEngineManager { // Create or get an engine called `name` TensorRTEngine* Create(int max_batch, int max_workspace, bool enable_int8, TRTInt8Calibrator* calibrator, - const std::string& engine_name) { + const std::string& engine_name, int device_id = 0) { std::unique_lock lk(mut_); - auto* p = - new TensorRTEngine(max_batch, max_workspace, enable_int8, calibrator); + auto* p = new TensorRTEngine(max_batch, max_workspace, enable_int8, + calibrator, device_id); engines_[engine_name].reset(p); return p; } diff --git a/paddle/fluid/inference/tensorrt/helper.h b/paddle/fluid/inference/tensorrt/helper.h index fc7ca7714..010942a06 100644 --- a/paddle/fluid/inference/tensorrt/helper.h +++ b/paddle/fluid/inference/tensorrt/helper.h @@ -17,6 +17,9 @@ #include #include #include +#include +#include +#include #include "paddle/fluid/platform/dynload/tensorrt.h" #include "paddle/fluid/platform/enforce.h" @@ -74,6 +77,32 @@ class NaiveLogger : public nvinfer1::ILogger { ~NaiveLogger() override {} }; +class NaiveProfiler : public nvinfer1::IProfiler { + public: + typedef std::pair Record; + std::vector mProfile; + + virtual void reportLayerTime(const char* layerName, float ms) { + auto record = + std::find_if(mProfile.begin(), mProfile.end(), + [&](const Record& r) { return r.first == layerName; }); + if (record == mProfile.end()) + mProfile.push_back(std::make_pair(layerName, ms)); + else + record->second += ms; + } + + void printLayerTimes() { + float totalTime = 0; + for (size_t i = 0; i < mProfile.size(); i++) { + printf("%-40.40s %4.3fms\n", mProfile[i].first.c_str(), + mProfile[i].second); + totalTime += mProfile[i].second; + } + printf("Time over all layers: %4.3f\n", totalTime); + } +}; + } // namespace tensorrt } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt index 95443e813..709aa103d 100644 --- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt @@ -1,4 +1,5 @@ nv_library(tensorrt_plugin - SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu prelu_op_plugin.cu + SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu + prelu_op_plugin.cu trt_plugin_factory.cc avg_pool_op_plugin.cu DEPS enforce tensorrt_engine prelu) diff --git a/paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.cu index 5d747af8c..f27a83816 100644 --- a/paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.cu @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.h" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h" #include "paddle/fluid/operators/math/pooling.h" namespace paddle { @@ -20,6 +21,12 @@ namespace inference { namespace tensorrt { namespace plugin { +AvgPoolPlugin* CreateAvgPoolPluginDeserialize(const void* buffer, + size_t length) { + return new AvgPoolPlugin(buffer, length); +} +REGISTER_TRT_PLUGIN("avg_pool_plugin", CreateAvgPoolPluginDeserialize); + nvinfer1::Dims AvgPoolPlugin::getOutputDimensions( int index, const nvinfer1::Dims* inputDims, int nbInputs) { assert(nbInputs == 1); diff --git a/paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.h index b5e4ece0f..a7c0aa579 100644 --- a/paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.h @@ -33,24 +33,27 @@ class AvgPoolPlugin : public PluginTensorRT { protected: size_t getSerializationSize() override { - return SerializedSize(ceil_mode_) + SerializedSize(ksize_) + - SerializedSize(strides_) + SerializedSize(paddings_) + - SerializedSize(input_shape_) + getBaseSerializationSize(); + return SerializedSize(getPluginType()) + SerializedSize(ceil_mode_) + + SerializedSize(ksize_) + SerializedSize(strides_) + + SerializedSize(paddings_) + SerializedSize(input_shape_) + + SerializedSize(output_shape_) + getBaseSerializationSize(); } // TRT will call this func when we need to serialize the configuration of // tensorrt. - // It should not be called by users. void serialize(void *buffer) override { + SerializeValue(&buffer, getPluginType()); serializeBase(buffer); SerializeValue(&buffer, ceil_mode_); SerializeValue(&buffer, ksize_); SerializeValue(&buffer, strides_); SerializeValue(&buffer, paddings_); SerializeValue(&buffer, input_shape_); + SerializeValue(&buffer, output_shape_); } public: + AvgPoolPlugin() {} AvgPoolPlugin(bool ceil_mode, std::vector ksize, std::vector strides, std::vector paddings, std::vector input_shape) @@ -89,6 +92,7 @@ class AvgPoolPlugin : public PluginTensorRT { DeserializeValue(&serialData, &serialLength, &strides_); DeserializeValue(&serialData, &serialLength, &paddings_); DeserializeValue(&serialData, &serialLength, &input_shape_); + DeserializeValue(&serialData, &serialLength, &output_shape_); } AvgPoolPlugin *clone() const override { @@ -96,7 +100,7 @@ class AvgPoolPlugin : public PluginTensorRT { input_shape_); } - const char *getPluginType() const override { return "avg_pool"; } + const char *getPluginType() const override { return "avg_pool_plugin"; } int getNbOutputs() const override { return 1; } nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs, int nbInputDims) override; diff --git a/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.cu index 9cd9026b7..9aed3ddab 100644 --- a/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.cu @@ -14,12 +14,19 @@ limitations under the License. */ #include #include "paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.h" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h" namespace paddle { namespace inference { namespace tensorrt { namespace plugin { +ElementWisePlugin* CreateElementWisePluginDeserialize(const void* buffer, + size_t length) { + return new ElementWisePlugin(buffer, length); +} +REGISTER_TRT_PLUGIN("elementwise_plugin", CreateElementWisePluginDeserialize); + namespace details { template @@ -119,10 +126,10 @@ int ElementWisePlugin::enqueue(int batch_size, const void* const* inputs, const float* y = reinterpret_cast(inputs[1]); float* out = reinterpret_cast(outputs[0]); - if (type_ == nvinfer1::ElementWiseOperation::kSUM) { + if (type_ == "add") { details::ElementWise(details::Add(), x, y, out, batch_size, prev_size_, midd_size_, post_size_, stream); - } else if (type_ == nvinfer1::ElementWiseOperation::kPROD) { + } else if (type_ == "mul") { details::ElementWise(details::Mul(), x, y, out, batch_size, prev_size_, midd_size_, post_size_, stream); } else { diff --git a/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.h index 9c461f7a5..3b040f14c 100644 --- a/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" @@ -24,9 +25,8 @@ namespace plugin { class ElementWisePlugin : public PluginTensorRT { public: - ElementWisePlugin(nvinfer1::ElementWiseOperation type, - nvinfer1::Dims const &dims_x, nvinfer1::Dims const &dims_y, - int axis) + ElementWisePlugin(std::string type, nvinfer1::Dims const &dims_x, + nvinfer1::Dims const &dims_y, int axis) : type_(type), dims_x_(dims_x), dims_y_(dims_y), @@ -37,6 +37,9 @@ class ElementWisePlugin : public PluginTensorRT { ElementWisePlugin(void const *serial_data, size_t serial_length) { deserializeBase(serial_data, serial_length); + const char *elementwise_type; + DeserializeValue(&serial_data, &serial_length, &elementwise_type); + type_ = std::string(elementwise_type); DeserializeValue(&serial_data, &serial_length, &axis_); DeserializeValue(&serial_data, &serial_length, &dims_x_); DeserializeValue(&serial_data, &serial_length, &dims_y_); @@ -47,7 +50,7 @@ class ElementWisePlugin : public PluginTensorRT { return nullptr; } - const char *getPluginType() const override { return "elementwise"; } + const char *getPluginType() const override { return "elementwise_plugin"; } nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *input_dims, @@ -61,18 +64,21 @@ class ElementWisePlugin : public PluginTensorRT { protected: size_t getSerializationSize() override { - return SerializedSize(axis_) + SerializedSize(dims_x_) + - SerializedSize(dims_y_) + getBaseSerializationSize(); + return SerializedSize(getPluginType()) + SerializedSize(axis_) + + SerializedSize(dims_x_) + SerializedSize(dims_y_) + + getBaseSerializationSize(); } void serialize(void *buffer) override { + SerializeValue(&buffer, getPluginType()); serializeBase(buffer); + SerializeValue(&buffer, type_.c_str()); SerializeValue(&buffer, axis_); SerializeValue(&buffer, dims_x_); SerializeValue(&buffer, dims_y_); } - nvinfer1::ElementWiseOperation type_; + std::string type_; nvinfer1::Dims dims_x_; nvinfer1::Dims dims_y_; int axis_; diff --git a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu index 3075e87ea..b8a044fe9 100644 --- a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu @@ -17,6 +17,7 @@ #include #include "glog/logging.h" #include "paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h" #include "paddle/fluid/operators/math/prelu.h" namespace paddle { @@ -24,6 +25,17 @@ namespace inference { namespace tensorrt { namespace plugin { +PReluPlugin *CreatePreluPluginDeserialize(const void *buffer, size_t length) { + return new PReluPlugin(buffer, length); +} +REGISTER_TRT_PLUGIN("prelu_plugin", CreatePreluPluginDeserialize); + +int PReluPlugin::initialize() { + cudaMalloc(&p_gpu_weight_, sizeof(float) * weight_.size()); + cudaMemcpy(p_gpu_weight_, weight_.data(), weight_.size() * sizeof(float), + cudaMemcpyHostToDevice); +} + nvinfer1::Dims PReluPlugin::getOutputDimensions(int index, const nvinfer1::Dims *inputDims, int nbInputs) { @@ -39,7 +51,8 @@ int PReluPlugin::enqueue(int batch_size, const void *const *inputs, // input dims is CHW. const auto &input_dims = this->getInputDims(0); const float *input = reinterpret_cast(inputs[0]); - const float *alpha = reinterpret_cast(alpha_.get().values); + // const float *alpha = reinterpret_cast(alpha_.get().values); + const float *alpha = p_gpu_weight_; float *output = reinterpret_cast(outputs)[0]; std::vector input_shape; diff --git a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h index 0db56a310..a96649503 100644 --- a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h @@ -14,7 +14,12 @@ #pragma once +#include #include +#include +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/tensor_util.h" + #include "paddle/fluid/inference/tensorrt/engine.h" #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" @@ -24,39 +29,51 @@ namespace tensorrt { namespace plugin { class PReluPlugin : public PluginTensorRT { - TensorRTEngine::Weight alpha_; + std::vector weight_; + float *p_gpu_weight_; std::string mode_; protected: size_t getSerializationSize() override { - // return getBaseSerializationSize(alpha_) + SerializedSize(mode_); - return 0; + return getBaseSerializationSize() + SerializedSize(mode_.c_str()) + + SerializedSize(weight_) + SerializedSize(getPluginType()); } // TRT will call this func when we need to serialize the configuration of // tensorrt. // It should not be called by users. void serialize(void *buffer) override { - // serializeBase(buffer); - // SerializeValue(&buffer, alpha_); - // SerializeValue(&buffer, mode_); + SerializeValue(&buffer, getPluginType()); + serializeBase(buffer); + SerializeValue(&buffer, weight_); + SerializeValue(&buffer, mode_.c_str()); } public: - PReluPlugin(TensorRTEngine::Weight const &alpha, std::string const &mode) - : alpha_(alpha), mode_(mode) {} + PReluPlugin(const float *weight, const int weight_num, + std::string const &mode) + : mode_(mode) { + weight_.resize(weight_num); + std::copy(weight, weight + weight_num, weight_.data()); + } // It was used for tensorrt deserialization. // It should not be called by users. PReluPlugin(void const *serialData, size_t serialLength) { - // deserializeBase(serialData, serialLength); - // DeserializeValue(&serialData, &serialLength, &alpha_); - // DeserializeValue(&serialData, &serialLength, &mode_); + deserializeBase(serialData, serialLength); + DeserializeValue(&serialData, &serialLength, &weight_); + const char *prelu_mode; + DeserializeValue(&serialData, &serialLength, &prelu_mode); + mode_ = std::string(prelu_mode); } + ~PReluPlugin() { cudaFree(p_gpu_weight_); } + int initialize() override; - PReluPlugin *clone() const override { return new PReluPlugin(alpha_, mode_); } + PReluPlugin *clone() const override { + return new PReluPlugin(weight_.data(), weight_.size(), mode_); + } - const char *getPluginType() const override { return "prelu"; } + const char *getPluginType() const override { return "prelu_plugin"; } int getNbOutputs() const override { return 1; } nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs, int nbInputDims) override; diff --git a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu index de61ace59..b5503c3b9 100644 --- a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu @@ -15,12 +15,18 @@ #include #include #include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h" namespace paddle { namespace inference { namespace tensorrt { namespace plugin { +SplitPlugin* CreateSplitPluginDeserialize(const void* buffer, size_t length) { + return new SplitPlugin(buffer, length); +} +REGISTER_TRT_PLUGIN("split_plugin", CreateSplitPluginDeserialize); + // copied from operators::math::SplitFunctor template __global__ void SplitKernel(const T* input_data, const int in_row, diff --git a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h index 6f028d3d7..16553d44a 100644 --- a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h @@ -25,6 +25,7 @@ namespace plugin { class SplitPlugin : public PluginTensorRT { public: + SplitPlugin() {} SplitPlugin(int axis, std::vector const &output_lengths) : axis_(axis), same_shape_(true), output_length_(output_lengths) {} @@ -38,7 +39,7 @@ class SplitPlugin : public PluginTensorRT { return new SplitPlugin(axis_, output_length_); } - const char *getPluginType() const override { return "split"; } + const char *getPluginType() const override { return "split_plugin"; } int getNbOutputs() const override { return output_length_.size(); } nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *input_dims, @@ -50,11 +51,12 @@ class SplitPlugin : public PluginTensorRT { protected: size_t getSerializationSize() override { - return SerializedSize(axis_) + SerializedSize(output_length_) + - getBaseSerializationSize(); + return SerializedSize(getPluginType()) + SerializedSize(axis_) + + SerializedSize(output_length_) + getBaseSerializationSize(); } void serialize(void *buffer) override { + SerializeValue(&buffer, getPluginType()); serializeBase(buffer); SerializeValue(&buffer, axis_); SerializeValue(&buffer, output_length_); diff --git a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h index 86084829e..735504136 100644 --- a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h @@ -19,7 +19,7 @@ #include #include -#include "paddle/fluid/inference/tensorrt/plugin/serialize.h" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/profiler.h" @@ -30,6 +30,13 @@ namespace inference { namespace tensorrt { namespace plugin { +class PluginTensorRT; + +typedef std::function + PluginDeserializeFunc; + +typedef std::function PluginConstructFunc; + class PluginTensorRT : public nvinfer1::IPluginExt { public: PluginTensorRT() {} diff --git a/paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.cc b/paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.cc new file mode 100644 index 000000000..3c20b6d1e --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.cc @@ -0,0 +1,48 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layer_name, + const void* serial_data, + size_t serial_length) { + const char* plugin_type; + DeserializeValue(&serial_data, &serial_length, &plugin_type); + + PADDLE_ENFORCE(Has(plugin_type), + "trt plugin type %s does not exists, check it.", plugin_type); + auto plugin = plugin_registry_[plugin_type](serial_data, serial_length); + owned_plugins_.emplace_back(plugin); + + return plugin; +} + +bool PluginFactoryTensorRT::RegisterPlugin( + const std::string& op_name, PluginDeserializeFunc deserialize_func) { + if (Has(op_name)) return false; + auto ret = plugin_registry_.emplace(op_name, deserialize_func); + return ret.second; +} + +void PluginFactoryTensorRT::DestroyPlugins() { owned_plugins_.clear(); } + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h b/paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h new file mode 100644 index 000000000..03992f88b --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h @@ -0,0 +1,76 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h" +#include "paddle/fluid/inference/utils/singleton.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { + public: + // Deserialization method + PluginTensorRT* createPlugin(const char* layer_name, const void* serial_data, + size_t serial_length) override; + + bool RegisterPlugin(const std::string& op_name, + PluginDeserializeFunc deserialize_func); + + bool Has(const std::string& op_name) { + return plugin_registry_.find(op_name) != plugin_registry_.end(); + } + + void DestroyPlugins(); + + protected: + std::unordered_map plugin_registry_; + + std::list> owned_plugins_; +}; + +class TrtPluginRegistrar { + public: + TrtPluginRegistrar(const std::string& name, + PluginDeserializeFunc deserialize_func) { + inference::Singleton::Global().RegisterPlugin( + name, deserialize_func); + } +}; + +#define REGISTER_TRT_PLUGIN(name, deserialize_func) \ + REGISTER_TRT_PLUGIN_UNIQ(__COUNTER__, name, deserialize_func) + +#define REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func) \ + static paddle::inference::tensorrt::plugin::TrtPluginRegistrar \ + trt_plugin_registrar##ctr __attribute__((unused)) = \ + paddle::inference::tensorrt::plugin::TrtPluginRegistrar( \ + name, deserialize_func) + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/serialize.h b/paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h similarity index 99% rename from paddle/fluid/inference/tensorrt/plugin/serialize.h rename to paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h index ce859f16f..55ca681c7 100644 --- a/paddle/fluid/inference/tensorrt/plugin/serialize.h +++ b/paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h @@ -13,8 +13,8 @@ // limitations under the License. #pragma once - #include +#include #include #include #include "paddle/fluid/platform/enforce.h" diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h index ab6f403ce..cb6412115 100644 --- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h @@ -134,9 +134,10 @@ class TensorRTEngineOp : public framework::OperatorBase { calib_res->calib_.reset(new TRTInt8Calibrator( calib_buffers, runtime_batch, engine_key_, dev_place)); calib_res->thr_.reset(new std::thread([&]() { - calib_res->engine_.reset( - new TensorRTEngine(max_batch_size_, workspace_size_, enable_int8_, - calib_res->calib_.get())); + calib_res->engine_.reset(new TensorRTEngine( + max_batch_size_, workspace_size_, enable_int8_, + calib_res->calib_.get(), + boost::get(dev_place).device)); VLOG(3) << "start the calib trt engine thread"; PrepareTRTEngine(scope, calib_res->engine_.get()); })); @@ -234,7 +235,8 @@ class TensorRTEngineOp : public framework::OperatorBase { trt_engine_ = inference::Singleton::Global() .Create(max_batch_size_, workspace_size_, enable_int8_, - calibrator_.get(), engine_key_); + calibrator_.get(), engine_key_, + boost::get(dev_place).device); PrepareTRTEngine(scope, trt_engine_); } return trt_engine_; -- GitLab