提交 1d5ef7c9 编写于 作者: N nhzlx

5. add static trt load model

1). add static trt load model
2). fix bug: when device_id is not 0, the trt will have a bug
test=develop
上级 2070fb24
...@@ -82,6 +82,7 @@ void IRPassManager::CreatePasses(Argument *argument, ...@@ -82,6 +82,7 @@ void IRPassManager::CreatePasses(Argument *argument,
"model_opt_cache_dir", "model_opt_cache_dir",
new std::string(GetOrCreateModelOptCacheDir(model_opt_cache_dir))); new std::string(GetOrCreateModelOptCacheDir(model_opt_cache_dir)));
pass->Set("predictor_id", new int(argument->predictor_id())); pass->Set("predictor_id", new int(argument->predictor_id()));
pass->Set("gpu_device_id", new int(argument->gpu_device_id()));
} }
pre_pass = pass_name; pre_pass = pass_name;
......
...@@ -242,7 +242,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp( ...@@ -242,7 +242,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
tensorrt::TensorRTEngine *trt_engine = tensorrt::TensorRTEngine *trt_engine =
inference::Singleton<tensorrt::TRTEngineManager>::Global().Create( inference::Singleton<tensorrt::TRTEngineManager>::Global().Create(
Get<int>("max_batch_size"), Get<int>("workspace_size"), enable_int8, Get<int>("max_batch_size"), Get<int>("workspace_size"), enable_int8,
calibrator.get(), engine_key); calibrator.get(), engine_key, Get<int>("gpu_device_id"));
if (trt_engine_serialized_data.size() == 0) { if (trt_engine_serialized_data.size() == 0) {
LOG(INFO) << "Prepare TRT engine (Optimize model structure, Select OP " LOG(INFO) << "Prepare TRT engine (Optimize model structure, Select OP "
"kernel etc). This process may cost a lot of time."; "kernel etc). This process may cost a lot of time.";
...@@ -258,13 +258,16 @@ void TensorRtSubgraphPass::CreateTensorRTOp( ...@@ -258,13 +258,16 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
trt_engine_serialized_data = trt_engine_serialized_data =
std::string((const char *)serialized_engine_data->data(), std::string((const char *)serialized_engine_data->data(),
serialized_engine_data->size()); serialized_engine_data->size());
// SaveTrtEngineSerializedDataToFile(GetTrtEngineSerializedPath(Get<std::string>("model_opt_cache_dir"), SaveTrtEngineSerializedDataToFile(
// engine_key), GetTrtEngineSerializedPath(Get<std::string>("model_opt_cache_dir"),
// trt_engine_serialized_data); engine_key),
trt_engine_serialized_data);
} else { } else {
LOG(INFO) << "Load TRT Engine from optimized serialized data : "
<< GetTrtEngineSerializedPath(
Get<std::string>("model_opt_cache_dir"), engine_key);
trt_engine->Deserialize(trt_engine_serialized_data); trt_engine->Deserialize(trt_engine_serialized_data);
} }
SetAttr(op_desc->Proto(), "engine_serialized_data", SetAttr(op_desc->Proto(), "engine_serialized_data",
trt_engine_serialized_data); trt_engine_serialized_data);
} }
......
...@@ -44,7 +44,7 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op, ...@@ -44,7 +44,7 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
weight_tensor->Resize(Y_t->dims()); weight_tensor->Resize(Y_t->dims());
TensorCopySync((*Y_t), cpu_place, weight_tensor.get()); TensorCopySync((*Y_t), cpu_place, weight_tensor.get());
auto* weight_data = weight_tensor->mutable_data<float>(platform::CPUPlace()); auto* weight_data = weight_tensor->mutable_data<float>(cpu_place);
PADDLE_ENFORCE_EQ(weight_tensor->dims().size(), 4UL); PADDLE_ENFORCE_EQ(weight_tensor->dims().size(), 4UL);
const int n_output = weight_tensor->dims()[0]; const int n_output = weight_tensor->dims()[0];
......
...@@ -153,7 +153,6 @@ class ElementwiseTensorOpConverter : public OpConverter { ...@@ -153,7 +153,6 @@ class ElementwiseTensorOpConverter : public OpConverter {
if (CheckDims(dims_x, dims_y)) { if (CheckDims(dims_x, dims_y)) {
// The two input tensor should have the same dims // The two input tensor should have the same dims
VLOG(3) << "Convert a fluid elementwise op to TensorRT IElementWiseLayer"; VLOG(3) << "Convert a fluid elementwise op to TensorRT IElementWiseLayer";
nvinfer1::IElementWiseLayer* layer = TRT_ENGINE_ADD_LAYER( nvinfer1::IElementWiseLayer* layer = TRT_ENGINE_ADD_LAYER(
engine_, ElementWise, *const_cast<nvinfer1::ITensor*>(X), engine_, ElementWise, *const_cast<nvinfer1::ITensor*>(X),
*const_cast<nvinfer1::ITensor*>(Y), op_pair->second); *const_cast<nvinfer1::ITensor*>(Y), op_pair->second);
...@@ -166,7 +165,7 @@ class ElementwiseTensorOpConverter : public OpConverter { ...@@ -166,7 +165,7 @@ class ElementwiseTensorOpConverter : public OpConverter {
"ElementWisePluginLayer"; "ElementWisePluginLayer";
plugin::ElementWisePlugin* plugin = plugin::ElementWisePlugin* plugin =
new plugin::ElementWisePlugin(op_pair->second, dims_x, dims_y, axis); new plugin::ElementWisePlugin(op_type_, dims_x, dims_y, axis);
plugin->AddInput(X); plugin->AddInput(X);
plugin->AddInput(Y); plugin->AddInput(Y);
nvinfer1::IPluginLayer* layer = engine_->AddPlugin( nvinfer1::IPluginLayer* layer = engine_->AddPlugin(
......
...@@ -85,10 +85,10 @@ class FcOpConverter : public OpConverter { ...@@ -85,10 +85,10 @@ class FcOpConverter : public OpConverter {
Y_t->dims()[0] * Y_t->dims()[1] * sizeof(float)); Y_t->dims()[0] * Y_t->dims()[1] * sizeof(float));
TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT, TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT,
static_cast<void*>(weight_data), static_cast<void*>(weight_data),
Y_t->memory_size() / sizeof(float)}; static_cast<size_t>(Y_t->numel())};
TensorRTEngine::Weight tmp_weight(nvinfer1::DataType::kFLOAT, TensorRTEngine::Weight tmp_weight(nvinfer1::DataType::kFLOAT,
static_cast<void*>(tmp->data<float>()), static_cast<void*>(tmp->data<float>()),
Y_t->memory_size() / sizeof(float)); static_cast<size_t>(Y_t->numel()));
weight.dims.assign({Y_t->dims()[0], Y_t->dims()[1]}); weight.dims.assign({Y_t->dims()[0], Y_t->dims()[1]});
tmp_weight.dims = weight.dims; tmp_weight.dims = weight.dims;
......
...@@ -43,23 +43,20 @@ class PReluOpConverter : public OpConverter { ...@@ -43,23 +43,20 @@ class PReluOpConverter : public OpConverter {
PADDLE_ENFORCE_NOT_NULL(alpha_var); PADDLE_ENFORCE_NOT_NULL(alpha_var);
auto* alpha_tensor = alpha_var->GetMutable<framework::LoDTensor>(); auto* alpha_tensor = alpha_var->GetMutable<framework::LoDTensor>();
platform::CUDAPlace place; platform::CPUPlace cpu_place;
std::unique_ptr<framework::LoDTensor> alpha_tensor_device( std::unique_ptr<framework::LoDTensor> alpha_tensor_temp(
new framework::LoDTensor()); new framework::LoDTensor());
alpha_tensor_device->Resize(alpha_tensor->dims()); alpha_tensor_temp->Resize(alpha_tensor->dims());
TensorCopySync(*alpha_tensor, place, alpha_tensor_device.get()); TensorCopySync(*alpha_tensor, cpu_place, alpha_tensor_temp.get());
float* alpha_data = alpha_tensor_device->mutable_data<float>(place); float* alpha_data = alpha_tensor_temp->mutable_data<float>(cpu_place);
// Transform alpha to TensorRTEngine::Weight plugin::PReluPlugin* plugin =
TensorRTEngine::Weight alpha_rt(nvinfer1::DataType::kFLOAT, new plugin::PReluPlugin(alpha_data, alpha_tensor_temp->numel(), mode);
static_cast<void*>(alpha_data),
alpha_tensor_device->numel());
plugin::PReluPlugin* plugin = new plugin::PReluPlugin(alpha_rt, mode);
nvinfer1::IPluginLayer* layer = nvinfer1::IPluginLayer* layer =
engine_->AddPlugin(&input, input_num, plugin); engine_->AddPlugin(&input, input_num, plugin);
// keep alpha tensor to avoid release it's memory // keep alpha tensor to avoid release it's memory
engine_->weight_map[op_desc.Input("Alpha")[0]] = engine_->weight_map[op_desc.Input("Alpha")[0]] =
std::move(alpha_tensor_device); std::move(alpha_tensor_temp);
std::string layer_name = "prelu (Output: "; std::string layer_name = "prelu (Output: ";
auto output_name = op_desc.Output("Out")[0]; auto output_name = op_desc.Output("Out")[0];
......
...@@ -79,7 +79,8 @@ class TRTConvertValidation { ...@@ -79,7 +79,8 @@ class TRTConvertValidation {
if_add_batch_(if_add_batch), if_add_batch_(if_add_batch),
max_batch_size_(max_batch_size) { max_batch_size_(max_batch_size) {
PADDLE_ENFORCE_EQ(cudaStreamCreate(&stream_), 0); PADDLE_ENFORCE_EQ(cudaStreamCreate(&stream_), 0);
engine_.reset(new TensorRTEngine(max_batch_size, workspace_size)); engine_.reset(
new TensorRTEngine(max_batch_size, workspace_size, false, nullptr, 0));
engine_->InitNetwork(); engine_->InitNetwork();
} }
...@@ -114,13 +115,12 @@ class TRTConvertValidation { ...@@ -114,13 +115,12 @@ class TRTConvertValidation {
} }
void DeclVar(const std::string& name, const std::vector<int> dim_vec) { void DeclVar(const std::string& name, const std::vector<int> dim_vec) {
platform::CUDAPlace place; platform::CUDADeviceContext ctx(place_);
platform::CUDADeviceContext ctx(place);
auto* x = scope_.Var(name); auto* x = scope_.Var(name);
auto* x_tensor = x->GetMutable<framework::LoDTensor>(); auto* x_tensor = x->GetMutable<framework::LoDTensor>();
x_tensor->Resize(framework::make_ddim(dim_vec)); x_tensor->Resize(framework::make_ddim(dim_vec));
RandomizeTensor(x_tensor, place, ctx); RandomizeTensor(x_tensor, place_, ctx);
} }
// Declare a variable in a fluid Scope. // Declare a variable in a fluid Scope.
void DeclVar(const std::string& name, const nvinfer1::Dims& dims, void DeclVar(const std::string& name, const nvinfer1::Dims& dims,
...@@ -155,9 +155,8 @@ class TRTConvertValidation { ...@@ -155,9 +155,8 @@ class TRTConvertValidation {
std::unordered_set<std::string> neglected_output = {}) { std::unordered_set<std::string> neglected_output = {}) {
// Execute Fluid Op // Execute Fluid Op
PADDLE_ENFORCE_LE(batch_size, max_batch_size_); PADDLE_ENFORCE_LE(batch_size, max_batch_size_);
platform::CUDAPlace place; platform::CUDADeviceContext ctx(place_);
platform::CUDADeviceContext ctx(place); op_->Run(scope_, place_);
op_->Run(scope_, place);
std::vector<std::string> input_output_names; std::vector<std::string> input_output_names;
...@@ -188,7 +187,7 @@ class TRTConvertValidation { ...@@ -188,7 +187,7 @@ class TRTConvertValidation {
auto* tensor = var->GetMutable<framework::LoDTensor>(); auto* tensor = var->GetMutable<framework::LoDTensor>();
const int bind_index = engine_->engine()->getBindingIndex(name.c_str()); const int bind_index = engine_->engine()->getBindingIndex(name.c_str());
buffers[bind_index] = buffers[bind_index] =
static_cast<void*>(tensor->mutable_data<float>(place)); static_cast<void*>(tensor->mutable_data<float>(place_));
} }
// Execute TRT. // Execute TRT.
...@@ -220,6 +219,7 @@ class TRTConvertValidation { ...@@ -220,6 +219,7 @@ class TRTConvertValidation {
framework::Scope& scope() { return scope_; } framework::Scope& scope() { return scope_; }
private: private:
platform::CUDAPlace place_;
std::unique_ptr<TensorRTEngine> engine_; std::unique_ptr<TensorRTEngine> engine_;
cudaStream_t stream_; cudaStream_t stream_;
std::unique_ptr<framework::OperatorBase> op_; std::unique_ptr<framework::OperatorBase> op_;
......
...@@ -34,6 +34,7 @@ void TensorRTEngine::Build(const DescType &paddle_model) { ...@@ -34,6 +34,7 @@ void TensorRTEngine::Build(const DescType &paddle_model) {
void TensorRTEngine::Execute(int batch_size, std::vector<void *> *buffers, void TensorRTEngine::Execute(int batch_size, std::vector<void *> *buffers,
cudaStream_t stream) { cudaStream_t stream) {
freshDeviceId();
batch_size_ = batch_size; batch_size_ = batch_size;
infer_context_->enqueue(batch_size, buffers->data(), stream, nullptr); infer_context_->enqueue(batch_size, buffers->data(), stream, nullptr);
cudaStreamSynchronize(stream); cudaStreamSynchronize(stream);
...@@ -41,6 +42,7 @@ void TensorRTEngine::Execute(int batch_size, std::vector<void *> *buffers, ...@@ -41,6 +42,7 @@ void TensorRTEngine::Execute(int batch_size, std::vector<void *> *buffers,
} }
void TensorRTEngine::FreezeNetwork() { void TensorRTEngine::FreezeNetwork() {
freshDeviceId();
VLOG(3) << "TRT to freeze network"; VLOG(3) << "TRT to freeze network";
PADDLE_ENFORCE(infer_builder_ != nullptr, PADDLE_ENFORCE(infer_builder_ != nullptr,
"Call InitNetwork first to initialize network."); "Call InitNetwork first to initialize network.");
...@@ -140,6 +142,13 @@ nvinfer1::IPluginLayer *TensorRTEngine::AddPlugin( ...@@ -140,6 +142,13 @@ nvinfer1::IPluginLayer *TensorRTEngine::AddPlugin(
return infer_network_.get()->addPluginExt(inputs, num_inputs, *plugin); return infer_network_.get()->addPluginExt(inputs, num_inputs, *plugin);
} }
void TensorRTEngine::freshDeviceId() {
int count;
cudaGetDeviceCount(&count);
PADDLE_ENFORCE_LT(device_id_, count);
cudaSetDevice(device_id_);
}
} // namespace tensorrt } // namespace tensorrt
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -23,6 +23,7 @@ limitations under the License. */ ...@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/fluid/inference/engine.h" #include "paddle/fluid/inference/engine.h"
#include "paddle/fluid/inference/tensorrt/helper.h" #include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h" #include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"
#include "paddle/fluid/inference/utils/singleton.h" #include "paddle/fluid/inference/utils/singleton.h"
...@@ -59,12 +60,13 @@ class TensorRTEngine { ...@@ -59,12 +60,13 @@ class TensorRTEngine {
}; };
TensorRTEngine(int max_batch, int max_workspace, bool enable_int8 = false, TensorRTEngine(int max_batch, int max_workspace, bool enable_int8 = false,
TRTInt8Calibrator* calibrator = nullptr, TRTInt8Calibrator* calibrator = nullptr, int device_id = 0,
nvinfer1::ILogger& logger = NaiveLogger::Global()) nvinfer1::ILogger& logger = NaiveLogger::Global())
: max_batch_(max_batch), : max_batch_(max_batch),
max_workspace_(max_workspace), max_workspace_(max_workspace),
enable_int8_(enable_int8), enable_int8_(enable_int8),
calibrator_(calibrator), calibrator_(calibrator),
device_id_(device_id),
logger_(logger) {} logger_(logger) {}
~TensorRTEngine() {} ~TensorRTEngine() {}
...@@ -78,6 +80,7 @@ class TensorRTEngine { ...@@ -78,6 +80,7 @@ class TensorRTEngine {
// Initialize the inference network, so that TensorRT layers can add to this // Initialize the inference network, so that TensorRT layers can add to this
// network. // network.
void InitNetwork() { void InitNetwork() {
freshDeviceId();
infer_builder_.reset(createInferBuilder(&logger_)); infer_builder_.reset(createInferBuilder(&logger_));
infer_network_.reset(infer_builder_->createNetwork()); infer_network_.reset(infer_builder_->createNetwork());
} }
...@@ -113,20 +116,11 @@ class TensorRTEngine { ...@@ -113,20 +116,11 @@ class TensorRTEngine {
} }
void Deserialize(const std::string& engine_serialized_data) { void Deserialize(const std::string& engine_serialized_data) {
infer_ptr<nvinfer1::IRuntime> runtime(createInferRuntime(&logger_)); freshDeviceId();
infer_engine_.reset(
runtime->deserializeCudaEngine(engine_serialized_data.c_str(),
engine_serialized_data.size(), nullptr));
PADDLE_ENFORCE(infer_engine_ != nullptr,
"build cuda engine failed when deserialize engine info.!");
infer_context_.reset(infer_engine_->createExecutionContext());
}
void Deserialize(const nvinfer1::IHostMemory* engine_serialized_data) {
infer_ptr<nvinfer1::IRuntime> runtime(createInferRuntime(&logger_)); infer_ptr<nvinfer1::IRuntime> runtime(createInferRuntime(&logger_));
infer_engine_.reset(runtime->deserializeCudaEngine( infer_engine_.reset(runtime->deserializeCudaEngine(
engine_serialized_data->data(), engine_serialized_data->size(), engine_serialized_data.c_str(), engine_serialized_data.size(),
nullptr)); &inference::Singleton<plugin::PluginFactoryTensorRT>::Global()));
PADDLE_ENFORCE(infer_engine_ != nullptr, PADDLE_ENFORCE(infer_engine_ != nullptr,
"build cuda engine failed when deserialize engine info.!"); "build cuda engine failed when deserialize engine info.!");
infer_context_.reset(infer_engine_->createExecutionContext()); infer_context_.reset(infer_engine_->createExecutionContext());
...@@ -134,6 +128,7 @@ class TensorRTEngine { ...@@ -134,6 +128,7 @@ class TensorRTEngine {
void SetRuntimeBatch(size_t batch_size); void SetRuntimeBatch(size_t batch_size);
int GetRuntimeBatch(); int GetRuntimeBatch();
int GetDeviceId() { return device_id_; }
nvinfer1::IPluginLayer* AddPlugin(nvinfer1::ITensor* const* inputs, nvinfer1::IPluginLayer* AddPlugin(nvinfer1::ITensor* const* inputs,
int num_inputs, plugin::PluginTensorRT*); int num_inputs, plugin::PluginTensorRT*);
...@@ -146,6 +141,11 @@ class TensorRTEngine { ...@@ -146,6 +141,11 @@ class TensorRTEngine {
weight_map; weight_map;
private: private:
// Each ICudaEngine object is bound to a specific GPU when it is instantiated,
// ensure that the thread is associated with the correct device by calling
// freshDeviceId().
void freshDeviceId();
// the max batch size // the max batch size
int max_batch_; int max_batch_;
// the runtime batch size // the runtime batch size
...@@ -158,6 +158,7 @@ class TensorRTEngine { ...@@ -158,6 +158,7 @@ class TensorRTEngine {
// batch size of the current data, will be updated each Executation. // batch size of the current data, will be updated each Executation.
int batch_size_{-1}; int batch_size_{-1};
int device_id_;
nvinfer1::ILogger& logger_; nvinfer1::ILogger& logger_;
// max data size for the buffers. // max data size for the buffers.
...@@ -216,10 +217,10 @@ class TRTEngineManager { ...@@ -216,10 +217,10 @@ class TRTEngineManager {
// Create or get an engine called `name` // Create or get an engine called `name`
TensorRTEngine* Create(int max_batch, int max_workspace, bool enable_int8, TensorRTEngine* Create(int max_batch, int max_workspace, bool enable_int8,
TRTInt8Calibrator* calibrator, TRTInt8Calibrator* calibrator,
const std::string& engine_name) { const std::string& engine_name, int device_id = 0) {
std::unique_lock<std::mutex> lk(mut_); std::unique_lock<std::mutex> lk(mut_);
auto* p = auto* p = new TensorRTEngine(max_batch, max_workspace, enable_int8,
new TensorRTEngine(max_batch, max_workspace, enable_int8, calibrator); calibrator, device_id);
engines_[engine_name].reset(p); engines_[engine_name].reset(p);
return p; return p;
} }
......
...@@ -17,6 +17,9 @@ ...@@ -17,6 +17,9 @@
#include <NvInfer.h> #include <NvInfer.h>
#include <cuda.h> #include <cuda.h>
#include <glog/logging.h> #include <glog/logging.h>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/platform/dynload/tensorrt.h" #include "paddle/fluid/platform/dynload/tensorrt.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -74,6 +77,32 @@ class NaiveLogger : public nvinfer1::ILogger { ...@@ -74,6 +77,32 @@ class NaiveLogger : public nvinfer1::ILogger {
~NaiveLogger() override {} ~NaiveLogger() override {}
}; };
class NaiveProfiler : public nvinfer1::IProfiler {
public:
typedef std::pair<std::string, float> Record;
std::vector<Record> mProfile;
virtual void reportLayerTime(const char* layerName, float ms) {
auto record =
std::find_if(mProfile.begin(), mProfile.end(),
[&](const Record& r) { return r.first == layerName; });
if (record == mProfile.end())
mProfile.push_back(std::make_pair(layerName, ms));
else
record->second += ms;
}
void printLayerTimes() {
float totalTime = 0;
for (size_t i = 0; i < mProfile.size(); i++) {
printf("%-40.40s %4.3fms\n", mProfile[i].first.c_str(),
mProfile[i].second);
totalTime += mProfile[i].second;
}
printf("Time over all layers: %4.3f\n", totalTime);
}
};
} // namespace tensorrt } // namespace tensorrt
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
nv_library(tensorrt_plugin nv_library(tensorrt_plugin
SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu prelu_op_plugin.cu SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu
prelu_op_plugin.cu trt_plugin_factory.cc
avg_pool_op_plugin.cu avg_pool_op_plugin.cu
DEPS enforce tensorrt_engine prelu) DEPS enforce tensorrt_engine prelu)
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.h" #include "paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/operators/math/pooling.h" #include "paddle/fluid/operators/math/pooling.h"
namespace paddle { namespace paddle {
...@@ -20,6 +21,12 @@ namespace inference { ...@@ -20,6 +21,12 @@ namespace inference {
namespace tensorrt { namespace tensorrt {
namespace plugin { namespace plugin {
AvgPoolPlugin* CreateAvgPoolPluginDeserialize(const void* buffer,
size_t length) {
return new AvgPoolPlugin(buffer, length);
}
REGISTER_TRT_PLUGIN("avg_pool_plugin", CreateAvgPoolPluginDeserialize);
nvinfer1::Dims AvgPoolPlugin::getOutputDimensions( nvinfer1::Dims AvgPoolPlugin::getOutputDimensions(
int index, const nvinfer1::Dims* inputDims, int nbInputs) { int index, const nvinfer1::Dims* inputDims, int nbInputs) {
assert(nbInputs == 1); assert(nbInputs == 1);
......
...@@ -33,24 +33,27 @@ class AvgPoolPlugin : public PluginTensorRT { ...@@ -33,24 +33,27 @@ class AvgPoolPlugin : public PluginTensorRT {
protected: protected:
size_t getSerializationSize() override { size_t getSerializationSize() override {
return SerializedSize(ceil_mode_) + SerializedSize(ksize_) + return SerializedSize(getPluginType()) + SerializedSize(ceil_mode_) +
SerializedSize(strides_) + SerializedSize(paddings_) + SerializedSize(ksize_) + SerializedSize(strides_) +
SerializedSize(input_shape_) + getBaseSerializationSize(); SerializedSize(paddings_) + SerializedSize(input_shape_) +
SerializedSize(output_shape_) + getBaseSerializationSize();
} }
// TRT will call this func when we need to serialize the configuration of // TRT will call this func when we need to serialize the configuration of
// tensorrt. // tensorrt.
// It should not be called by users.
void serialize(void *buffer) override { void serialize(void *buffer) override {
SerializeValue(&buffer, getPluginType());
serializeBase(buffer); serializeBase(buffer);
SerializeValue(&buffer, ceil_mode_); SerializeValue(&buffer, ceil_mode_);
SerializeValue(&buffer, ksize_); SerializeValue(&buffer, ksize_);
SerializeValue(&buffer, strides_); SerializeValue(&buffer, strides_);
SerializeValue(&buffer, paddings_); SerializeValue(&buffer, paddings_);
SerializeValue(&buffer, input_shape_); SerializeValue(&buffer, input_shape_);
SerializeValue(&buffer, output_shape_);
} }
public: public:
AvgPoolPlugin() {}
AvgPoolPlugin(bool ceil_mode, std::vector<int> ksize, AvgPoolPlugin(bool ceil_mode, std::vector<int> ksize,
std::vector<int> strides, std::vector<int> paddings, std::vector<int> strides, std::vector<int> paddings,
std::vector<int> input_shape) std::vector<int> input_shape)
...@@ -89,6 +92,7 @@ class AvgPoolPlugin : public PluginTensorRT { ...@@ -89,6 +92,7 @@ class AvgPoolPlugin : public PluginTensorRT {
DeserializeValue(&serialData, &serialLength, &strides_); DeserializeValue(&serialData, &serialLength, &strides_);
DeserializeValue(&serialData, &serialLength, &paddings_); DeserializeValue(&serialData, &serialLength, &paddings_);
DeserializeValue(&serialData, &serialLength, &input_shape_); DeserializeValue(&serialData, &serialLength, &input_shape_);
DeserializeValue(&serialData, &serialLength, &output_shape_);
} }
AvgPoolPlugin *clone() const override { AvgPoolPlugin *clone() const override {
...@@ -96,7 +100,7 @@ class AvgPoolPlugin : public PluginTensorRT { ...@@ -96,7 +100,7 @@ class AvgPoolPlugin : public PluginTensorRT {
input_shape_); input_shape_);
} }
const char *getPluginType() const override { return "avg_pool"; } const char *getPluginType() const override { return "avg_pool_plugin"; }
int getNbOutputs() const override { return 1; } int getNbOutputs() const override { return 1; }
nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs, nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs,
int nbInputDims) override; int nbInputDims) override;
......
...@@ -14,12 +14,19 @@ limitations under the License. */ ...@@ -14,12 +14,19 @@ limitations under the License. */
#include <glog/logging.h> #include <glog/logging.h>
#include "paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.h" #include "paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace tensorrt { namespace tensorrt {
namespace plugin { namespace plugin {
ElementWisePlugin* CreateElementWisePluginDeserialize(const void* buffer,
size_t length) {
return new ElementWisePlugin(buffer, length);
}
REGISTER_TRT_PLUGIN("elementwise_plugin", CreateElementWisePluginDeserialize);
namespace details { namespace details {
template <typename T> template <typename T>
...@@ -119,10 +126,10 @@ int ElementWisePlugin::enqueue(int batch_size, const void* const* inputs, ...@@ -119,10 +126,10 @@ int ElementWisePlugin::enqueue(int batch_size, const void* const* inputs,
const float* y = reinterpret_cast<const float*>(inputs[1]); const float* y = reinterpret_cast<const float*>(inputs[1]);
float* out = reinterpret_cast<float*>(outputs[0]); float* out = reinterpret_cast<float*>(outputs[0]);
if (type_ == nvinfer1::ElementWiseOperation::kSUM) { if (type_ == "add") {
details::ElementWise(details::Add<float>(), x, y, out, batch_size, details::ElementWise(details::Add<float>(), x, y, out, batch_size,
prev_size_, midd_size_, post_size_, stream); prev_size_, midd_size_, post_size_, stream);
} else if (type_ == nvinfer1::ElementWiseOperation::kPROD) { } else if (type_ == "mul") {
details::ElementWise(details::Mul<float>(), x, y, out, batch_size, details::ElementWise(details::Mul<float>(), x, y, out, batch_size,
prev_size_, midd_size_, post_size_, stream); prev_size_, midd_size_, post_size_, stream);
} else { } else {
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <string>
#include <vector> #include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
...@@ -24,9 +25,8 @@ namespace plugin { ...@@ -24,9 +25,8 @@ namespace plugin {
class ElementWisePlugin : public PluginTensorRT { class ElementWisePlugin : public PluginTensorRT {
public: public:
ElementWisePlugin(nvinfer1::ElementWiseOperation type, ElementWisePlugin(std::string type, nvinfer1::Dims const &dims_x,
nvinfer1::Dims const &dims_x, nvinfer1::Dims const &dims_y, nvinfer1::Dims const &dims_y, int axis)
int axis)
: type_(type), : type_(type),
dims_x_(dims_x), dims_x_(dims_x),
dims_y_(dims_y), dims_y_(dims_y),
...@@ -37,6 +37,9 @@ class ElementWisePlugin : public PluginTensorRT { ...@@ -37,6 +37,9 @@ class ElementWisePlugin : public PluginTensorRT {
ElementWisePlugin(void const *serial_data, size_t serial_length) { ElementWisePlugin(void const *serial_data, size_t serial_length) {
deserializeBase(serial_data, serial_length); deserializeBase(serial_data, serial_length);
const char *elementwise_type;
DeserializeValue(&serial_data, &serial_length, &elementwise_type);
type_ = std::string(elementwise_type);
DeserializeValue(&serial_data, &serial_length, &axis_); DeserializeValue(&serial_data, &serial_length, &axis_);
DeserializeValue(&serial_data, &serial_length, &dims_x_); DeserializeValue(&serial_data, &serial_length, &dims_x_);
DeserializeValue(&serial_data, &serial_length, &dims_y_); DeserializeValue(&serial_data, &serial_length, &dims_y_);
...@@ -47,7 +50,7 @@ class ElementWisePlugin : public PluginTensorRT { ...@@ -47,7 +50,7 @@ class ElementWisePlugin : public PluginTensorRT {
return nullptr; return nullptr;
} }
const char *getPluginType() const override { return "elementwise"; } const char *getPluginType() const override { return "elementwise_plugin"; }
nvinfer1::Dims getOutputDimensions(int index, nvinfer1::Dims getOutputDimensions(int index,
const nvinfer1::Dims *input_dims, const nvinfer1::Dims *input_dims,
...@@ -61,18 +64,21 @@ class ElementWisePlugin : public PluginTensorRT { ...@@ -61,18 +64,21 @@ class ElementWisePlugin : public PluginTensorRT {
protected: protected:
size_t getSerializationSize() override { size_t getSerializationSize() override {
return SerializedSize(axis_) + SerializedSize(dims_x_) + return SerializedSize(getPluginType()) + SerializedSize(axis_) +
SerializedSize(dims_y_) + getBaseSerializationSize(); SerializedSize(dims_x_) + SerializedSize(dims_y_) +
getBaseSerializationSize();
} }
void serialize(void *buffer) override { void serialize(void *buffer) override {
SerializeValue(&buffer, getPluginType());
serializeBase(buffer); serializeBase(buffer);
SerializeValue(&buffer, type_.c_str());
SerializeValue(&buffer, axis_); SerializeValue(&buffer, axis_);
SerializeValue(&buffer, dims_x_); SerializeValue(&buffer, dims_x_);
SerializeValue(&buffer, dims_y_); SerializeValue(&buffer, dims_y_);
} }
nvinfer1::ElementWiseOperation type_; std::string type_;
nvinfer1::Dims dims_x_; nvinfer1::Dims dims_x_;
nvinfer1::Dims dims_y_; nvinfer1::Dims dims_y_;
int axis_; int axis_;
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <vector> #include <vector>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h" #include "paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/operators/math/prelu.h" #include "paddle/fluid/operators/math/prelu.h"
namespace paddle { namespace paddle {
...@@ -24,6 +25,17 @@ namespace inference { ...@@ -24,6 +25,17 @@ namespace inference {
namespace tensorrt { namespace tensorrt {
namespace plugin { namespace plugin {
PReluPlugin *CreatePreluPluginDeserialize(const void *buffer, size_t length) {
return new PReluPlugin(buffer, length);
}
REGISTER_TRT_PLUGIN("prelu_plugin", CreatePreluPluginDeserialize);
int PReluPlugin::initialize() {
cudaMalloc(&p_gpu_weight_, sizeof(float) * weight_.size());
cudaMemcpy(p_gpu_weight_, weight_.data(), weight_.size() * sizeof(float),
cudaMemcpyHostToDevice);
}
nvinfer1::Dims PReluPlugin::getOutputDimensions(int index, nvinfer1::Dims PReluPlugin::getOutputDimensions(int index,
const nvinfer1::Dims *inputDims, const nvinfer1::Dims *inputDims,
int nbInputs) { int nbInputs) {
...@@ -39,7 +51,8 @@ int PReluPlugin::enqueue(int batch_size, const void *const *inputs, ...@@ -39,7 +51,8 @@ int PReluPlugin::enqueue(int batch_size, const void *const *inputs,
// input dims is CHW. // input dims is CHW.
const auto &input_dims = this->getInputDims(0); const auto &input_dims = this->getInputDims(0);
const float *input = reinterpret_cast<const float *>(inputs[0]); const float *input = reinterpret_cast<const float *>(inputs[0]);
const float *alpha = reinterpret_cast<const float *>(alpha_.get().values); // const float *alpha = reinterpret_cast<const float *>(alpha_.get().values);
const float *alpha = p_gpu_weight_;
float *output = reinterpret_cast<float **>(outputs)[0]; float *output = reinterpret_cast<float **>(outputs)[0];
std::vector<int> input_shape; std::vector<int> input_shape;
......
...@@ -14,7 +14,12 @@ ...@@ -14,7 +14,12 @@
#pragma once #pragma once
#include <algorithm>
#include <string> #include <string>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/engine.h" #include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
...@@ -24,39 +29,51 @@ namespace tensorrt { ...@@ -24,39 +29,51 @@ namespace tensorrt {
namespace plugin { namespace plugin {
class PReluPlugin : public PluginTensorRT { class PReluPlugin : public PluginTensorRT {
TensorRTEngine::Weight alpha_; std::vector<float> weight_;
float *p_gpu_weight_;
std::string mode_; std::string mode_;
protected: protected:
size_t getSerializationSize() override { size_t getSerializationSize() override {
// return getBaseSerializationSize(alpha_) + SerializedSize(mode_); return getBaseSerializationSize() + SerializedSize(mode_.c_str()) +
return 0; SerializedSize(weight_) + SerializedSize(getPluginType());
} }
// TRT will call this func when we need to serialize the configuration of // TRT will call this func when we need to serialize the configuration of
// tensorrt. // tensorrt.
// It should not be called by users. // It should not be called by users.
void serialize(void *buffer) override { void serialize(void *buffer) override {
// serializeBase(buffer); SerializeValue(&buffer, getPluginType());
// SerializeValue(&buffer, alpha_); serializeBase(buffer);
// SerializeValue(&buffer, mode_); SerializeValue(&buffer, weight_);
SerializeValue(&buffer, mode_.c_str());
} }
public: public:
PReluPlugin(TensorRTEngine::Weight const &alpha, std::string const &mode) PReluPlugin(const float *weight, const int weight_num,
: alpha_(alpha), mode_(mode) {} std::string const &mode)
: mode_(mode) {
weight_.resize(weight_num);
std::copy(weight, weight + weight_num, weight_.data());
}
// It was used for tensorrt deserialization. // It was used for tensorrt deserialization.
// It should not be called by users. // It should not be called by users.
PReluPlugin(void const *serialData, size_t serialLength) { PReluPlugin(void const *serialData, size_t serialLength) {
// deserializeBase(serialData, serialLength); deserializeBase(serialData, serialLength);
// DeserializeValue(&serialData, &serialLength, &alpha_); DeserializeValue(&serialData, &serialLength, &weight_);
// DeserializeValue(&serialData, &serialLength, &mode_); const char *prelu_mode;
DeserializeValue(&serialData, &serialLength, &prelu_mode);
mode_ = std::string(prelu_mode);
} }
~PReluPlugin() { cudaFree(p_gpu_weight_); }
int initialize() override;
PReluPlugin *clone() const override { return new PReluPlugin(alpha_, mode_); } PReluPlugin *clone() const override {
return new PReluPlugin(weight_.data(), weight_.size(), mode_);
}
const char *getPluginType() const override { return "prelu"; } const char *getPluginType() const override { return "prelu_plugin"; }
int getNbOutputs() const override { return 1; } int getNbOutputs() const override { return 1; }
nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs, nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs,
int nbInputDims) override; int nbInputDims) override;
......
...@@ -15,12 +15,18 @@ ...@@ -15,12 +15,18 @@
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <algorithm> #include <algorithm>
#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h" #include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace tensorrt { namespace tensorrt {
namespace plugin { namespace plugin {
SplitPlugin* CreateSplitPluginDeserialize(const void* buffer, size_t length) {
return new SplitPlugin(buffer, length);
}
REGISTER_TRT_PLUGIN("split_plugin", CreateSplitPluginDeserialize);
// copied from operators::math::SplitFunctor // copied from operators::math::SplitFunctor
template <typename T> template <typename T>
__global__ void SplitKernel(const T* input_data, const int in_row, __global__ void SplitKernel(const T* input_data, const int in_row,
......
...@@ -25,6 +25,7 @@ namespace plugin { ...@@ -25,6 +25,7 @@ namespace plugin {
class SplitPlugin : public PluginTensorRT { class SplitPlugin : public PluginTensorRT {
public: public:
SplitPlugin() {}
SplitPlugin(int axis, std::vector<int> const &output_lengths) SplitPlugin(int axis, std::vector<int> const &output_lengths)
: axis_(axis), same_shape_(true), output_length_(output_lengths) {} : axis_(axis), same_shape_(true), output_length_(output_lengths) {}
...@@ -38,7 +39,7 @@ class SplitPlugin : public PluginTensorRT { ...@@ -38,7 +39,7 @@ class SplitPlugin : public PluginTensorRT {
return new SplitPlugin(axis_, output_length_); return new SplitPlugin(axis_, output_length_);
} }
const char *getPluginType() const override { return "split"; } const char *getPluginType() const override { return "split_plugin"; }
int getNbOutputs() const override { return output_length_.size(); } int getNbOutputs() const override { return output_length_.size(); }
nvinfer1::Dims getOutputDimensions(int index, nvinfer1::Dims getOutputDimensions(int index,
const nvinfer1::Dims *input_dims, const nvinfer1::Dims *input_dims,
...@@ -50,11 +51,12 @@ class SplitPlugin : public PluginTensorRT { ...@@ -50,11 +51,12 @@ class SplitPlugin : public PluginTensorRT {
protected: protected:
size_t getSerializationSize() override { size_t getSerializationSize() override {
return SerializedSize(axis_) + SerializedSize(output_length_) + return SerializedSize(getPluginType()) + SerializedSize(axis_) +
getBaseSerializationSize(); SerializedSize(output_length_) + getBaseSerializationSize();
} }
void serialize(void *buffer) override { void serialize(void *buffer) override {
SerializeValue(&buffer, getPluginType());
serializeBase(buffer); serializeBase(buffer);
SerializeValue(&buffer, axis_); SerializeValue(&buffer, axis_);
SerializeValue(&buffer, output_length_); SerializeValue(&buffer, output_length_);
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/serialize.h" #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
...@@ -30,6 +30,13 @@ namespace inference { ...@@ -30,6 +30,13 @@ namespace inference {
namespace tensorrt { namespace tensorrt {
namespace plugin { namespace plugin {
class PluginTensorRT;
typedef std::function<PluginTensorRT*(const void*, size_t)>
PluginDeserializeFunc;
typedef std::function<PluginTensorRT*(void)> PluginConstructFunc;
class PluginTensorRT : public nvinfer1::IPluginExt { class PluginTensorRT : public nvinfer1::IPluginExt {
public: public:
PluginTensorRT() {} PluginTensorRT() {}
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layer_name,
const void* serial_data,
size_t serial_length) {
const char* plugin_type;
DeserializeValue(&serial_data, &serial_length, &plugin_type);
PADDLE_ENFORCE(Has(plugin_type),
"trt plugin type %s does not exists, check it.", plugin_type);
auto plugin = plugin_registry_[plugin_type](serial_data, serial_length);
owned_plugins_.emplace_back(plugin);
return plugin;
}
bool PluginFactoryTensorRT::RegisterPlugin(
const std::string& op_name, PluginDeserializeFunc deserialize_func) {
if (Has(op_name)) return false;
auto ret = plugin_registry_.emplace(op_name, deserialize_func);
return ret.second;
}
void PluginFactoryTensorRT::DestroyPlugins() { owned_plugins_.clear(); }
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <NvInfer.h>
#include <cstring>
#include <list>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h"
#include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
class PluginFactoryTensorRT : public nvinfer1::IPluginFactory {
public:
// Deserialization method
PluginTensorRT* createPlugin(const char* layer_name, const void* serial_data,
size_t serial_length) override;
bool RegisterPlugin(const std::string& op_name,
PluginDeserializeFunc deserialize_func);
bool Has(const std::string& op_name) {
return plugin_registry_.find(op_name) != plugin_registry_.end();
}
void DestroyPlugins();
protected:
std::unordered_map<std::string, PluginDeserializeFunc> plugin_registry_;
std::list<std::unique_ptr<PluginTensorRT>> owned_plugins_;
};
class TrtPluginRegistrar {
public:
TrtPluginRegistrar(const std::string& name,
PluginDeserializeFunc deserialize_func) {
inference::Singleton<PluginFactoryTensorRT>::Global().RegisterPlugin(
name, deserialize_func);
}
};
#define REGISTER_TRT_PLUGIN(name, deserialize_func) \
REGISTER_TRT_PLUGIN_UNIQ(__COUNTER__, name, deserialize_func)
#define REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func) \
static paddle::inference::tensorrt::plugin::TrtPluginRegistrar \
trt_plugin_registrar##ctr __attribute__((unused)) = \
paddle::inference::tensorrt::plugin::TrtPluginRegistrar( \
name, deserialize_func)
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <cstring> #include <cstring>
#include <string>
#include <type_traits> #include <type_traits>
#include <vector> #include <vector>
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
......
...@@ -134,9 +134,10 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -134,9 +134,10 @@ class TensorRTEngineOp : public framework::OperatorBase {
calib_res->calib_.reset(new TRTInt8Calibrator( calib_res->calib_.reset(new TRTInt8Calibrator(
calib_buffers, runtime_batch, engine_key_, dev_place)); calib_buffers, runtime_batch, engine_key_, dev_place));
calib_res->thr_.reset(new std::thread([&]() { calib_res->thr_.reset(new std::thread([&]() {
calib_res->engine_.reset( calib_res->engine_.reset(new TensorRTEngine(
new TensorRTEngine(max_batch_size_, workspace_size_, enable_int8_, max_batch_size_, workspace_size_, enable_int8_,
calib_res->calib_.get())); calib_res->calib_.get(),
boost::get<platform::CUDAPlace>(dev_place).device));
VLOG(3) << "start the calib trt engine thread"; VLOG(3) << "start the calib trt engine thread";
PrepareTRTEngine(scope, calib_res->engine_.get()); PrepareTRTEngine(scope, calib_res->engine_.get());
})); }));
...@@ -234,7 +235,8 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -234,7 +235,8 @@ class TensorRTEngineOp : public framework::OperatorBase {
trt_engine_ = trt_engine_ =
inference::Singleton<inference::tensorrt::TRTEngineManager>::Global() inference::Singleton<inference::tensorrt::TRTEngineManager>::Global()
.Create(max_batch_size_, workspace_size_, enable_int8_, .Create(max_batch_size_, workspace_size_, enable_int8_,
calibrator_.get(), engine_key_); calibrator_.get(), engine_key_,
boost::get<platform::CUDAPlace>(dev_place).device);
PrepareTRTEngine(scope, trt_engine_); PrepareTRTEngine(scope, trt_engine_);
} }
return trt_engine_; return trt_engine_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册