提交 1d5ef7c9 编写于 作者: N nhzlx

5. add static trt load model

1). add static trt load model
2). fix bug: TensorRT did not work correctly when device_id is not 0
test=develop
上级 2070fb24
......@@ -82,6 +82,7 @@ void IRPassManager::CreatePasses(Argument *argument,
"model_opt_cache_dir",
new std::string(GetOrCreateModelOptCacheDir(model_opt_cache_dir)));
pass->Set("predictor_id", new int(argument->predictor_id()));
pass->Set("gpu_device_id", new int(argument->gpu_device_id()));
}
pre_pass = pass_name;
......
......@@ -242,7 +242,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
tensorrt::TensorRTEngine *trt_engine =
inference::Singleton<tensorrt::TRTEngineManager>::Global().Create(
Get<int>("max_batch_size"), Get<int>("workspace_size"), enable_int8,
calibrator.get(), engine_key);
calibrator.get(), engine_key, Get<int>("gpu_device_id"));
if (trt_engine_serialized_data.size() == 0) {
LOG(INFO) << "Prepare TRT engine (Optimize model structure, Select OP "
"kernel etc). This process may cost a lot of time.";
......@@ -258,13 +258,16 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
trt_engine_serialized_data =
std::string((const char *)serialized_engine_data->data(),
serialized_engine_data->size());
// SaveTrtEngineSerializedDataToFile(GetTrtEngineSerializedPath(Get<std::string>("model_opt_cache_dir"),
// engine_key),
// trt_engine_serialized_data);
SaveTrtEngineSerializedDataToFile(
GetTrtEngineSerializedPath(Get<std::string>("model_opt_cache_dir"),
engine_key),
trt_engine_serialized_data);
} else {
LOG(INFO) << "Load TRT Engine from optimized serialized data : "
<< GetTrtEngineSerializedPath(
Get<std::string>("model_opt_cache_dir"), engine_key);
trt_engine->Deserialize(trt_engine_serialized_data);
}
SetAttr(op_desc->Proto(), "engine_serialized_data",
trt_engine_serialized_data);
}
......
......@@ -44,7 +44,7 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
weight_tensor->Resize(Y_t->dims());
TensorCopySync((*Y_t), cpu_place, weight_tensor.get());
auto* weight_data = weight_tensor->mutable_data<float>(platform::CPUPlace());
auto* weight_data = weight_tensor->mutable_data<float>(cpu_place);
PADDLE_ENFORCE_EQ(weight_tensor->dims().size(), 4UL);
const int n_output = weight_tensor->dims()[0];
......
......@@ -153,7 +153,6 @@ class ElementwiseTensorOpConverter : public OpConverter {
if (CheckDims(dims_x, dims_y)) {
// The two input tensor should have the same dims
VLOG(3) << "Convert a fluid elementwise op to TensorRT IElementWiseLayer";
nvinfer1::IElementWiseLayer* layer = TRT_ENGINE_ADD_LAYER(
engine_, ElementWise, *const_cast<nvinfer1::ITensor*>(X),
*const_cast<nvinfer1::ITensor*>(Y), op_pair->second);
......@@ -166,7 +165,7 @@ class ElementwiseTensorOpConverter : public OpConverter {
"ElementWisePluginLayer";
plugin::ElementWisePlugin* plugin =
new plugin::ElementWisePlugin(op_pair->second, dims_x, dims_y, axis);
new plugin::ElementWisePlugin(op_type_, dims_x, dims_y, axis);
plugin->AddInput(X);
plugin->AddInput(Y);
nvinfer1::IPluginLayer* layer = engine_->AddPlugin(
......
......@@ -85,10 +85,10 @@ class FcOpConverter : public OpConverter {
Y_t->dims()[0] * Y_t->dims()[1] * sizeof(float));
TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT,
static_cast<void*>(weight_data),
Y_t->memory_size() / sizeof(float)};
static_cast<size_t>(Y_t->numel())};
TensorRTEngine::Weight tmp_weight(nvinfer1::DataType::kFLOAT,
static_cast<void*>(tmp->data<float>()),
Y_t->memory_size() / sizeof(float));
static_cast<size_t>(Y_t->numel()));
weight.dims.assign({Y_t->dims()[0], Y_t->dims()[1]});
tmp_weight.dims = weight.dims;
......
......@@ -43,23 +43,20 @@ class PReluOpConverter : public OpConverter {
PADDLE_ENFORCE_NOT_NULL(alpha_var);
auto* alpha_tensor = alpha_var->GetMutable<framework::LoDTensor>();
platform::CUDAPlace place;
std::unique_ptr<framework::LoDTensor> alpha_tensor_device(
platform::CPUPlace cpu_place;
std::unique_ptr<framework::LoDTensor> alpha_tensor_temp(
new framework::LoDTensor());
alpha_tensor_device->Resize(alpha_tensor->dims());
TensorCopySync(*alpha_tensor, place, alpha_tensor_device.get());
float* alpha_data = alpha_tensor_device->mutable_data<float>(place);
alpha_tensor_temp->Resize(alpha_tensor->dims());
TensorCopySync(*alpha_tensor, cpu_place, alpha_tensor_temp.get());
float* alpha_data = alpha_tensor_temp->mutable_data<float>(cpu_place);
// Transform alpha to TensorRTEngine::Weight
TensorRTEngine::Weight alpha_rt(nvinfer1::DataType::kFLOAT,
static_cast<void*>(alpha_data),
alpha_tensor_device->numel());
plugin::PReluPlugin* plugin = new plugin::PReluPlugin(alpha_rt, mode);
plugin::PReluPlugin* plugin =
new plugin::PReluPlugin(alpha_data, alpha_tensor_temp->numel(), mode);
nvinfer1::IPluginLayer* layer =
engine_->AddPlugin(&input, input_num, plugin);
// keep the alpha tensor alive so that its memory is not released
engine_->weight_map[op_desc.Input("Alpha")[0]] =
std::move(alpha_tensor_device);
std::move(alpha_tensor_temp);
std::string layer_name = "prelu (Output: ";
auto output_name = op_desc.Output("Out")[0];
......
......@@ -79,7 +79,8 @@ class TRTConvertValidation {
if_add_batch_(if_add_batch),
max_batch_size_(max_batch_size) {
PADDLE_ENFORCE_EQ(cudaStreamCreate(&stream_), 0);
engine_.reset(new TensorRTEngine(max_batch_size, workspace_size));
engine_.reset(
new TensorRTEngine(max_batch_size, workspace_size, false, nullptr, 0));
engine_->InitNetwork();
}
......@@ -114,13 +115,12 @@ class TRTConvertValidation {
}
void DeclVar(const std::string& name, const std::vector<int> dim_vec) {
platform::CUDAPlace place;
platform::CUDADeviceContext ctx(place);
platform::CUDADeviceContext ctx(place_);
auto* x = scope_.Var(name);
auto* x_tensor = x->GetMutable<framework::LoDTensor>();
x_tensor->Resize(framework::make_ddim(dim_vec));
RandomizeTensor(x_tensor, place, ctx);
RandomizeTensor(x_tensor, place_, ctx);
}
// Declare a variable in a fluid Scope.
void DeclVar(const std::string& name, const nvinfer1::Dims& dims,
......@@ -155,9 +155,8 @@ class TRTConvertValidation {
std::unordered_set<std::string> neglected_output = {}) {
// Execute Fluid Op
PADDLE_ENFORCE_LE(batch_size, max_batch_size_);
platform::CUDAPlace place;
platform::CUDADeviceContext ctx(place);
op_->Run(scope_, place);
platform::CUDADeviceContext ctx(place_);
op_->Run(scope_, place_);
std::vector<std::string> input_output_names;
......@@ -188,7 +187,7 @@ class TRTConvertValidation {
auto* tensor = var->GetMutable<framework::LoDTensor>();
const int bind_index = engine_->engine()->getBindingIndex(name.c_str());
buffers[bind_index] =
static_cast<void*>(tensor->mutable_data<float>(place));
static_cast<void*>(tensor->mutable_data<float>(place_));
}
// Execute TRT.
......@@ -220,6 +219,7 @@ class TRTConvertValidation {
framework::Scope& scope() { return scope_; }
private:
platform::CUDAPlace place_;
std::unique_ptr<TensorRTEngine> engine_;
cudaStream_t stream_;
std::unique_ptr<framework::OperatorBase> op_;
......
......@@ -34,6 +34,7 @@ void TensorRTEngine::Build(const DescType &paddle_model) {
void TensorRTEngine::Execute(int batch_size, std::vector<void *> *buffers,
cudaStream_t stream) {
freshDeviceId();
batch_size_ = batch_size;
infer_context_->enqueue(batch_size, buffers->data(), stream, nullptr);
cudaStreamSynchronize(stream);
......@@ -41,6 +42,7 @@ void TensorRTEngine::Execute(int batch_size, std::vector<void *> *buffers,
}
void TensorRTEngine::FreezeNetwork() {
freshDeviceId();
VLOG(3) << "TRT to freeze network";
PADDLE_ENFORCE(infer_builder_ != nullptr,
"Call InitNetwork first to initialize network.");
......@@ -140,6 +142,13 @@ nvinfer1::IPluginLayer *TensorRTEngine::AddPlugin(
return infer_network_.get()->addPluginExt(inputs, num_inputs, *plugin);
}
void TensorRTEngine::freshDeviceId() {
int count;
cudaGetDeviceCount(&count);
PADDLE_ENFORCE_LT(device_id_, count);
cudaSetDevice(device_id_);
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/fluid/inference/engine.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"
#include "paddle/fluid/inference/utils/singleton.h"
......@@ -59,12 +60,13 @@ class TensorRTEngine {
};
TensorRTEngine(int max_batch, int max_workspace, bool enable_int8 = false,
TRTInt8Calibrator* calibrator = nullptr,
TRTInt8Calibrator* calibrator = nullptr, int device_id = 0,
nvinfer1::ILogger& logger = NaiveLogger::Global())
: max_batch_(max_batch),
max_workspace_(max_workspace),
enable_int8_(enable_int8),
calibrator_(calibrator),
device_id_(device_id),
logger_(logger) {}
~TensorRTEngine() {}
......@@ -78,6 +80,7 @@ class TensorRTEngine {
// Initialize the inference network, so that TensorRT layers can add to this
// network.
void InitNetwork() {
freshDeviceId();
infer_builder_.reset(createInferBuilder(&logger_));
infer_network_.reset(infer_builder_->createNetwork());
}
......@@ -113,20 +116,11 @@ class TensorRTEngine {
}
void Deserialize(const std::string& engine_serialized_data) {
infer_ptr<nvinfer1::IRuntime> runtime(createInferRuntime(&logger_));
infer_engine_.reset(
runtime->deserializeCudaEngine(engine_serialized_data.c_str(),
engine_serialized_data.size(), nullptr));
PADDLE_ENFORCE(infer_engine_ != nullptr,
"build cuda engine failed when deserialize engine info.!");
infer_context_.reset(infer_engine_->createExecutionContext());
}
void Deserialize(const nvinfer1::IHostMemory* engine_serialized_data) {
freshDeviceId();
infer_ptr<nvinfer1::IRuntime> runtime(createInferRuntime(&logger_));
infer_engine_.reset(runtime->deserializeCudaEngine(
engine_serialized_data->data(), engine_serialized_data->size(),
nullptr));
engine_serialized_data.c_str(), engine_serialized_data.size(),
&inference::Singleton<plugin::PluginFactoryTensorRT>::Global()));
PADDLE_ENFORCE(infer_engine_ != nullptr,
"build cuda engine failed when deserialize engine info.!");
infer_context_.reset(infer_engine_->createExecutionContext());
......@@ -134,6 +128,7 @@ class TensorRTEngine {
void SetRuntimeBatch(size_t batch_size);
int GetRuntimeBatch();
// Returns the device id stored in device_id_ at construction.
int GetDeviceId() {
  return device_id_;
}
nvinfer1::IPluginLayer* AddPlugin(nvinfer1::ITensor* const* inputs,
int num_inputs, plugin::PluginTensorRT*);
......@@ -146,6 +141,11 @@ class TensorRTEngine {
weight_map;
private:
// Each ICudaEngine object is bound to a specific GPU when it is instantiated,
// ensure that the thread is associated with the correct device by calling
// freshDeviceId().
void freshDeviceId();
// the max batch size
int max_batch_;
// the runtime batch size
......@@ -158,6 +158,7 @@ class TensorRTEngine {
// batch size of the current data; updated on each execution.
int batch_size_{-1};
int device_id_;
nvinfer1::ILogger& logger_;
// max data size for the buffers.
......@@ -216,10 +217,10 @@ class TRTEngineManager {
// Create or get an engine called `name`
TensorRTEngine* Create(int max_batch, int max_workspace, bool enable_int8,
TRTInt8Calibrator* calibrator,
const std::string& engine_name) {
const std::string& engine_name, int device_id = 0) {
std::unique_lock<std::mutex> lk(mut_);
auto* p =
new TensorRTEngine(max_batch, max_workspace, enable_int8, calibrator);
auto* p = new TensorRTEngine(max_batch, max_workspace, enable_int8,
calibrator, device_id);
engines_[engine_name].reset(p);
return p;
}
......
......@@ -17,6 +17,9 @@
#include <NvInfer.h>
#include <cuda.h>
#include <glog/logging.h>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/platform/dynload/tensorrt.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -74,6 +77,32 @@ class NaiveLogger : public nvinfer1::ILogger {
~NaiveLogger() override {}
};
class NaiveProfiler : public nvinfer1::IProfiler {
public:
typedef std::pair<std::string, float> Record;
std::vector<Record> mProfile;
virtual void reportLayerTime(const char* layerName, float ms) {
auto record =
std::find_if(mProfile.begin(), mProfile.end(),
[&](const Record& r) { return r.first == layerName; });
if (record == mProfile.end())
mProfile.push_back(std::make_pair(layerName, ms));
else
record->second += ms;
}
void printLayerTimes() {
float totalTime = 0;
for (size_t i = 0; i < mProfile.size(); i++) {
printf("%-40.40s %4.3fms\n", mProfile[i].first.c_str(),
mProfile[i].second);
totalTime += mProfile[i].second;
}
printf("Time over all layers: %4.3f\n", totalTime);
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
nv_library(tensorrt_plugin
SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu prelu_op_plugin.cu
SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu
prelu_op_plugin.cu trt_plugin_factory.cc
avg_pool_op_plugin.cu
DEPS enforce tensorrt_engine prelu)
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/operators/math/pooling.h"
namespace paddle {
......@@ -20,6 +21,12 @@ namespace inference {
namespace tensorrt {
namespace plugin {
// Deserialization hook for the "avg_pool_plugin" type: builds an
// AvgPoolPlugin from the given serialized bytes.
AvgPoolPlugin* CreateAvgPoolPluginDeserialize(const void* buffer,
                                              size_t length) {
  auto* restored = new AvgPoolPlugin(buffer, length);
  return restored;
}
// Make the hook known to the plugin factory under the plugin's type name.
REGISTER_TRT_PLUGIN("avg_pool_plugin", CreateAvgPoolPluginDeserialize);
nvinfer1::Dims AvgPoolPlugin::getOutputDimensions(
int index, const nvinfer1::Dims* inputDims, int nbInputs) {
assert(nbInputs == 1);
......
......@@ -33,24 +33,27 @@ class AvgPoolPlugin : public PluginTensorRT {
protected:
size_t getSerializationSize() override {
return SerializedSize(ceil_mode_) + SerializedSize(ksize_) +
SerializedSize(strides_) + SerializedSize(paddings_) +
SerializedSize(input_shape_) + getBaseSerializationSize();
return SerializedSize(getPluginType()) + SerializedSize(ceil_mode_) +
SerializedSize(ksize_) + SerializedSize(strides_) +
SerializedSize(paddings_) + SerializedSize(input_shape_) +
SerializedSize(output_shape_) + getBaseSerializationSize();
}
// TRT will call this func when we need to serialize the configuration of
// tensorrt.
// It should not be called by users.
//
// Write order: plugin type string, base-class state, then ceil_mode_, ksize_,
// strides_, paddings_, input_shape_ and output_shape_. The order is the wire
// format, so any reader of these bytes has to consume them in the same
// sequence.
void serialize(void *buffer) override {
// The type string goes first so that the reader can identify the plugin.
SerializeValue(&buffer, getPluginType());
// NOTE(review): presumably advances `buffer` past the base data - confirm
// against serializeBase's signature.
serializeBase(buffer);
SerializeValue(&buffer, ceil_mode_);
SerializeValue(&buffer, ksize_);
SerializeValue(&buffer, strides_);
SerializeValue(&buffer, paddings_);
SerializeValue(&buffer, input_shape_);
SerializeValue(&buffer, output_shape_);
}
public:
AvgPoolPlugin() {}
AvgPoolPlugin(bool ceil_mode, std::vector<int> ksize,
std::vector<int> strides, std::vector<int> paddings,
std::vector<int> input_shape)
......@@ -89,6 +92,7 @@ class AvgPoolPlugin : public PluginTensorRT {
DeserializeValue(&serialData, &serialLength, &strides_);
DeserializeValue(&serialData, &serialLength, &paddings_);
DeserializeValue(&serialData, &serialLength, &input_shape_);
DeserializeValue(&serialData, &serialLength, &output_shape_);
}
AvgPoolPlugin *clone() const override {
......@@ -96,7 +100,7 @@ class AvgPoolPlugin : public PluginTensorRT {
input_shape_);
}
const char *getPluginType() const override { return "avg_pool"; }
const char *getPluginType() const override { return "avg_pool_plugin"; }
int getNbOutputs() const override { return 1; }
nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs,
int nbInputDims) override;
......
......@@ -14,12 +14,19 @@ limitations under the License. */
#include <glog/logging.h>
#include "paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
// Deserialization hook for the "elementwise_plugin" type: builds an
// ElementWisePlugin from the given serialized bytes.
ElementWisePlugin* CreateElementWisePluginDeserialize(const void* buffer,
                                                      size_t length) {
  auto* restored = new ElementWisePlugin(buffer, length);
  return restored;
}
// Make the hook known to the plugin factory under the plugin's type name.
REGISTER_TRT_PLUGIN("elementwise_plugin", CreateElementWisePluginDeserialize);
namespace details {
template <typename T>
......@@ -119,10 +126,10 @@ int ElementWisePlugin::enqueue(int batch_size, const void* const* inputs,
const float* y = reinterpret_cast<const float*>(inputs[1]);
float* out = reinterpret_cast<float*>(outputs[0]);
if (type_ == nvinfer1::ElementWiseOperation::kSUM) {
if (type_ == "add") {
details::ElementWise(details::Add<float>(), x, y, out, batch_size,
prev_size_, midd_size_, post_size_, stream);
} else if (type_ == nvinfer1::ElementWiseOperation::kPROD) {
} else if (type_ == "mul") {
details::ElementWise(details::Mul<float>(), x, y, out, batch_size,
prev_size_, midd_size_, post_size_, stream);
} else {
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
......@@ -24,9 +25,8 @@ namespace plugin {
class ElementWisePlugin : public PluginTensorRT {
public:
ElementWisePlugin(nvinfer1::ElementWiseOperation type,
nvinfer1::Dims const &dims_x, nvinfer1::Dims const &dims_y,
int axis)
ElementWisePlugin(std::string type, nvinfer1::Dims const &dims_x,
nvinfer1::Dims const &dims_y, int axis)
: type_(type),
dims_x_(dims_x),
dims_y_(dims_y),
......@@ -37,6 +37,9 @@ class ElementWisePlugin : public PluginTensorRT {
ElementWisePlugin(void const *serial_data, size_t serial_length) {
deserializeBase(serial_data, serial_length);
const char *elementwise_type;
DeserializeValue(&serial_data, &serial_length, &elementwise_type);
type_ = std::string(elementwise_type);
DeserializeValue(&serial_data, &serial_length, &axis_);
DeserializeValue(&serial_data, &serial_length, &dims_x_);
DeserializeValue(&serial_data, &serial_length, &dims_y_);
......@@ -47,7 +50,7 @@ class ElementWisePlugin : public PluginTensorRT {
return nullptr;
}
const char *getPluginType() const override { return "elementwise"; }
const char *getPluginType() const override { return "elementwise_plugin"; }
nvinfer1::Dims getOutputDimensions(int index,
const nvinfer1::Dims *input_dims,
......@@ -61,18 +64,21 @@ class ElementWisePlugin : public PluginTensorRT {
protected:
size_t getSerializationSize() override {
return SerializedSize(axis_) + SerializedSize(dims_x_) +
SerializedSize(dims_y_) + getBaseSerializationSize();
return SerializedSize(getPluginType()) + SerializedSize(axis_) +
SerializedSize(dims_x_) + SerializedSize(dims_y_) +
getBaseSerializationSize();
}
// Writes this plugin's state into `buffer`.
//
// Write order: plugin type string, base-class state, the op name (type_),
// axis_, dims_x_, dims_y_. The order is the wire format, so any reader of
// these bytes has to consume them in the same sequence.
// NOTE(review): getSerializationSize() has to account for every value written
// here, including the type_ string.
void serialize(void *buffer) override {
// The type string goes first so that the reader can identify the plugin.
SerializeValue(&buffer, getPluginType());
// NOTE(review): presumably advances `buffer` past the base data - confirm
// against serializeBase's signature.
serializeBase(buffer);
SerializeValue(&buffer, type_.c_str());
SerializeValue(&buffer, axis_);
SerializeValue(&buffer, dims_x_);
SerializeValue(&buffer, dims_y_);
}
nvinfer1::ElementWiseOperation type_;
std::string type_;
nvinfer1::Dims dims_x_;
nvinfer1::Dims dims_y_;
int axis_;
......
......@@ -17,6 +17,7 @@
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/operators/math/prelu.h"
namespace paddle {
......@@ -24,6 +25,17 @@ namespace inference {
namespace tensorrt {
namespace plugin {
// Deserialization hook for the "prelu_plugin" type: builds a PReluPlugin from
// the given serialized bytes.
PReluPlugin *CreatePreluPluginDeserialize(const void *buffer, size_t length) {
  auto *restored = new PReluPlugin(buffer, length);
  return restored;
}
// Make the hook known to the plugin factory under the plugin's type name.
REGISTER_TRT_PLUGIN("prelu_plugin", CreatePreluPluginDeserialize);
int PReluPlugin::initialize() {
cudaMalloc(&p_gpu_weight_, sizeof(float) * weight_.size());
cudaMemcpy(p_gpu_weight_, weight_.data(), weight_.size() * sizeof(float),
cudaMemcpyHostToDevice);
}
nvinfer1::Dims PReluPlugin::getOutputDimensions(int index,
const nvinfer1::Dims *inputDims,
int nbInputs) {
......@@ -39,7 +51,8 @@ int PReluPlugin::enqueue(int batch_size, const void *const *inputs,
// input dims is CHW.
const auto &input_dims = this->getInputDims(0);
const float *input = reinterpret_cast<const float *>(inputs[0]);
const float *alpha = reinterpret_cast<const float *>(alpha_.get().values);
// const float *alpha = reinterpret_cast<const float *>(alpha_.get().values);
const float *alpha = p_gpu_weight_;
float *output = reinterpret_cast<float **>(outputs)[0];
std::vector<int> input_shape;
......
......@@ -14,7 +14,12 @@
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
......@@ -24,39 +29,51 @@ namespace tensorrt {
namespace plugin {
class PReluPlugin : public PluginTensorRT {
TensorRTEngine::Weight alpha_;
std::vector<float> weight_;
float *p_gpu_weight_;
std::string mode_;
protected:
size_t getSerializationSize() override {
// return getBaseSerializationSize(alpha_) + SerializedSize(mode_);
return 0;
return getBaseSerializationSize() + SerializedSize(mode_.c_str()) +
SerializedSize(weight_) + SerializedSize(getPluginType());
}
// TRT will call this func when we need to serialize the configuration of
// tensorrt.
// It should not be called by users.
void serialize(void *buffer) override {
// serializeBase(buffer);
// SerializeValue(&buffer, alpha_);
// SerializeValue(&buffer, mode_);
SerializeValue(&buffer, getPluginType());
serializeBase(buffer);
SerializeValue(&buffer, weight_);
SerializeValue(&buffer, mode_.c_str());
}
public:
PReluPlugin(TensorRTEngine::Weight const &alpha, std::string const &mode)
: alpha_(alpha), mode_(mode) {}
// Builds a plugin from `weight_num` host-side floats starting at `weight`
// (copied into weight_) and the PRelu mode string.
//
// p_gpu_weight_ starts as nullptr: the device buffer is only created in
// initialize(), while the destructor always passes the pointer to cudaFree(),
// so an instance that is destroyed without being initialized must not hold an
// indeterminate pointer.
PReluPlugin(const float *weight, const int weight_num,
            std::string const &mode)
    : p_gpu_weight_(nullptr), mode_(mode) {
  weight_.resize(weight_num);
  std::copy(weight, weight + weight_num, weight_.data());
}
// It was used for tensorrt deserialization.
// It should not be called by users.
PReluPlugin(void const *serialData, size_t serialLength) {
// deserializeBase(serialData, serialLength);
// DeserializeValue(&serialData, &serialLength, &alpha_);
// DeserializeValue(&serialData, &serialLength, &mode_);
deserializeBase(serialData, serialLength);
DeserializeValue(&serialData, &serialLength, &weight_);
const char *prelu_mode;
DeserializeValue(&serialData, &serialLength, &prelu_mode);
mode_ = std::string(prelu_mode);
}
~PReluPlugin() { cudaFree(p_gpu_weight_); }
int initialize() override;
PReluPlugin *clone() const override { return new PReluPlugin(alpha_, mode_); }
PReluPlugin *clone() const override {
return new PReluPlugin(weight_.data(), weight_.size(), mode_);
}
const char *getPluginType() const override { return "prelu"; }
const char *getPluginType() const override { return "prelu_plugin"; }
int getNbOutputs() const override { return 1; }
nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs,
int nbInputDims) override;
......
......@@ -15,12 +15,18 @@
#include <cuda_fp16.h>
#include <algorithm>
#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
// Deserialization hook for the "split_plugin" type: builds a SplitPlugin from
// the given serialized bytes.
SplitPlugin* CreateSplitPluginDeserialize(const void* buffer, size_t length) {
  auto* restored = new SplitPlugin(buffer, length);
  return restored;
}
// Make the hook known to the plugin factory under the plugin's type name.
REGISTER_TRT_PLUGIN("split_plugin", CreateSplitPluginDeserialize);
// copied from operators::math::SplitFunctor
template <typename T>
__global__ void SplitKernel(const T* input_data, const int in_row,
......
......@@ -25,6 +25,7 @@ namespace plugin {
class SplitPlugin : public PluginTensorRT {
public:
SplitPlugin() {}
SplitPlugin(int axis, std::vector<int> const &output_lengths)
: axis_(axis), same_shape_(true), output_length_(output_lengths) {}
......@@ -38,7 +39,7 @@ class SplitPlugin : public PluginTensorRT {
return new SplitPlugin(axis_, output_length_);
}
const char *getPluginType() const override { return "split"; }
const char *getPluginType() const override { return "split_plugin"; }
int getNbOutputs() const override { return output_length_.size(); }
nvinfer1::Dims getOutputDimensions(int index,
const nvinfer1::Dims *input_dims,
......@@ -50,11 +51,12 @@ class SplitPlugin : public PluginTensorRT {
protected:
size_t getSerializationSize() override {
return SerializedSize(axis_) + SerializedSize(output_length_) +
getBaseSerializationSize();
return SerializedSize(getPluginType()) + SerializedSize(axis_) +
SerializedSize(output_length_) + getBaseSerializationSize();
}
void serialize(void *buffer) override {
SerializeValue(&buffer, getPluginType());
serializeBase(buffer);
SerializeValue(&buffer, axis_);
SerializeValue(&buffer, output_length_);
......
......@@ -19,7 +19,7 @@
#include <unordered_map>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/serialize.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/profiler.h"
......@@ -30,6 +30,13 @@ namespace inference {
namespace tensorrt {
namespace plugin {
class PluginTensorRT;
typedef std::function<PluginTensorRT*(const void*, size_t)>
PluginDeserializeFunc;
typedef std::function<PluginTensorRT*(void)> PluginConstructFunc;
class PluginTensorRT : public nvinfer1::IPluginExt {
public:
PluginTensorRT() {}
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
// Rebuilds a plugin from serialized bytes.
//
// The leading value in `serial_data` is the plugin type string; it selects
// the registered deserialize function, which is then called with the
// remaining data. The factory keeps ownership of the created plugin in
// owned_plugins_. `layer_name` is not used.
PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layer_name,
                                                    const void* serial_data,
                                                    size_t serial_length) {
  const char* plugin_type;
  DeserializeValue(&serial_data, &serial_length, &plugin_type);

  const auto creator = plugin_registry_.find(plugin_type);
  PADDLE_ENFORCE(creator != plugin_registry_.end(),
                 "trt plugin type %s does not exists, check it.", plugin_type);

  PluginTensorRT* created = creator->second(serial_data, serial_length);
  owned_plugins_.emplace_back(created);
  return created;
}
// Associates `deserialize_func` with the plugin type `op_name`.
// Returns true when a new entry was inserted and false when the name is
// already registered (the existing entry is kept).
bool PluginFactoryTensorRT::RegisterPlugin(
    const std::string& op_name, PluginDeserializeFunc deserialize_func) {
  if (!Has(op_name)) {
    return plugin_registry_.emplace(op_name, deserialize_func).second;
  }
  return false;
}
// Drops every plugin the factory created through createPlugin().
void PluginFactoryTensorRT::DestroyPlugins() {
  owned_plugins_.clear();
}
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <NvInfer.h>
#include <cstring>
#include <list>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h"
#include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
class PluginFactoryTensorRT : public nvinfer1::IPluginFactory {
public:
// Deserialization method
PluginTensorRT* createPlugin(const char* layer_name, const void* serial_data,
size_t serial_length) override;
bool RegisterPlugin(const std::string& op_name,
PluginDeserializeFunc deserialize_func);
bool Has(const std::string& op_name) {
return plugin_registry_.find(op_name) != plugin_registry_.end();
}
void DestroyPlugins();
protected:
std::unordered_map<std::string, PluginDeserializeFunc> plugin_registry_;
std::list<std::unique_ptr<PluginTensorRT>> owned_plugins_;
};
// Helper whose constructor registers a deserialize function with the global
// PluginFactoryTensorRT singleton.
class TrtPluginRegistrar {
 public:
  TrtPluginRegistrar(const std::string& name,
                     PluginDeserializeFunc deserialize_func) {
    auto& factory = inference::Singleton<PluginFactoryTensorRT>::Global();
    factory.RegisterPlugin(name, deserialize_func);
  }
};
// Registers `deserialize_func` as the deserializer for plugin type `name` by
// defining a file-local static TrtPluginRegistrar with a unique identifier.
#define REGISTER_TRT_PLUGIN(name, deserialize_func) \
  REGISTER_TRT_PLUGIN_UNIQ_HELPER(__COUNTER__, name, deserialize_func)

// Extra expansion level: a macro argument that is an operand of `##` is not
// expanded before pasting, so passing __COUNTER__ straight to
// REGISTER_TRT_PLUGIN_UNIQ would produce the identifier
// `trt_plugin_registrar__COUNTER__` for every registration. Routing it
// through this helper expands it to a number first.
#define REGISTER_TRT_PLUGIN_UNIQ_HELPER(ctr, name, deserialize_func) \
  REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func)

// Defines the registrar variable `trt_plugin_registrar<ctr>`.
#define REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func)      \
  static paddle::inference::tensorrt::plugin::TrtPluginRegistrar   \
      trt_plugin_registrar##ctr __attribute__((unused)) =          \
          paddle::inference::tensorrt::plugin::TrtPluginRegistrar( \
              name, deserialize_func)
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -13,8 +13,8 @@
// limitations under the License.
#pragma once
#include <cstring>
#include <string>
#include <type_traits>
#include <vector>
#include "paddle/fluid/platform/enforce.h"
......
......@@ -134,9 +134,10 @@ class TensorRTEngineOp : public framework::OperatorBase {
calib_res->calib_.reset(new TRTInt8Calibrator(
calib_buffers, runtime_batch, engine_key_, dev_place));
calib_res->thr_.reset(new std::thread([&]() {
calib_res->engine_.reset(
new TensorRTEngine(max_batch_size_, workspace_size_, enable_int8_,
calib_res->calib_.get()));
calib_res->engine_.reset(new TensorRTEngine(
max_batch_size_, workspace_size_, enable_int8_,
calib_res->calib_.get(),
boost::get<platform::CUDAPlace>(dev_place).device));
VLOG(3) << "start the calib trt engine thread";
PrepareTRTEngine(scope, calib_res->engine_.get());
}));
......@@ -234,7 +235,8 @@ class TensorRTEngineOp : public framework::OperatorBase {
trt_engine_ =
inference::Singleton<inference::tensorrt::TRTEngineManager>::Global()
.Create(max_batch_size_, workspace_size_, enable_int8_,
calibrator_.get(), engine_key_);
calibrator_.get(), engine_key_,
boost::get<platform::CUDAPlace>(dev_place).device);
PrepareTRTEngine(scope, trt_engine_);
}
return trt_engine_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册