提交 2070fb24 编写于 作者: N nhzlx

4. do the trt_engine optim during init.

add simple static mode loading
test=develop
上级 ecc12fb4
...@@ -99,6 +99,10 @@ struct Argument { ...@@ -99,6 +99,10 @@ struct Argument {
private: \ private: \
unique_ptr_t field__##_; unique_ptr_t field__##_;
// Each predictor has an unique id.
// For now, this attr will help us to get the right
// trt_engine for each trt_engine_op for each predictor when using trt.
DECL_ARGUMENT_FIELD(predictor_id, PredictorID, int);
// Model path // Model path
DECL_ARGUMENT_FIELD(model_dir, ModelDir, std::string); DECL_ARGUMENT_FIELD(model_dir, ModelDir, std::string);
// Model specified with program and parameters files. // Model specified with program and parameters files.
......
...@@ -217,6 +217,35 @@ static std::string GetTrtCalibTableData(const std::string &model_opt_cache_dir, ...@@ -217,6 +217,35 @@ static std::string GetTrtCalibTableData(const std::string &model_opt_cache_dir,
return ""; return "";
} }
static std::string GetTrtEngineSerializedPath(const std::string &model_root,
const std::string &engine_key) {
return model_root + "/trt_serialized_" + engine_key;
}
static std::string GetTrtEngineSerializedData(
const std::string &model_opt_cache_dir, const std::string &engine_key) {
std::string trt_serialized_path =
GetTrtEngineSerializedPath(model_opt_cache_dir, engine_key);
if (FileExists(trt_serialized_path)) {
VLOG(3) << "Trt serialized file: " << trt_serialized_path
<< "is found here";
std::ifstream infile(trt_serialized_path, std::ios::in);
std::stringstream buffer;
buffer << infile.rdbuf();
std::string trt_engine_serialized_data(buffer.str());
return trt_engine_serialized_data;
}
return "";
}
static void SaveTrtEngineSerializedDataToFile(
const std::string &trt_serialized_path,
const std::string &engine_serialized_data) {
std::ofstream outfile(trt_serialized_path);
outfile << engine_serialized_data;
outfile.close();
}
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
......
...@@ -81,6 +81,7 @@ void IRPassManager::CreatePasses(Argument *argument, ...@@ -81,6 +81,7 @@ void IRPassManager::CreatePasses(Argument *argument,
pass->Set( pass->Set(
"model_opt_cache_dir", "model_opt_cache_dir",
new std::string(GetOrCreateModelOptCacheDir(model_opt_cache_dir))); new std::string(GetOrCreateModelOptCacheDir(model_opt_cache_dir)));
pass->Set("predictor_id", new int(argument->predictor_id()));
} }
pre_pass = pass_name; pre_pass = pass_name;
......
...@@ -19,6 +19,8 @@ ...@@ -19,6 +19,8 @@
#include "paddle/fluid/inference/analysis/helper.h" #include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h" #include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h"
#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h" #include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/op_teller.h" #include "paddle/fluid/inference/tensorrt/op_teller.h"
#include "paddle/fluid/string/pretty_log.h" #include "paddle/fluid/string/pretty_log.h"
...@@ -83,7 +85,8 @@ std::unique_ptr<framework::ir::Graph> analysis::TensorRtSubgraphPass::ApplyImpl( ...@@ -83,7 +85,8 @@ std::unique_ptr<framework::ir::Graph> analysis::TensorRtSubgraphPass::ApplyImpl(
} }
std::string GenerateEngineKey(const std::set<std::string> &engine_inputs, std::string GenerateEngineKey(const std::set<std::string> &engine_inputs,
const std::set<std::string> &engine_outputs) { const std::set<std::string> &engine_outputs,
const std::string &predictor_id) {
std::string engine_hash_key = ""; std::string engine_hash_key = "";
for (auto name : engine_inputs) { for (auto name : engine_inputs) {
engine_hash_key += name; engine_hash_key += name;
...@@ -91,6 +94,7 @@ std::string GenerateEngineKey(const std::set<std::string> &engine_inputs, ...@@ -91,6 +94,7 @@ std::string GenerateEngineKey(const std::set<std::string> &engine_inputs,
for (auto name : engine_outputs) { for (auto name : engine_outputs) {
engine_hash_key += name; engine_hash_key += name;
} }
engine_hash_key += predictor_id;
auto engine_key = std::to_string(std::hash<std::string>()(engine_hash_key)); auto engine_key = std::to_string(std::hash<std::string>()(engine_hash_key));
return engine_key; return engine_key;
} }
...@@ -205,8 +209,9 @@ void TensorRtSubgraphPass::CreateTensorRTOp( ...@@ -205,8 +209,9 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
SetAttr(op_desc->Proto(), "parameters", params); SetAttr(op_desc->Proto(), "parameters", params);
auto enable_int8 = Get<bool>("enable_int8"); auto enable_int8 = Get<bool>("enable_int8");
auto engine_key = int predictor_id = Get<int>("predictor_id");
GenerateEngineKey(input_names_with_id, output_names_with_id); auto engine_key = GenerateEngineKey(input_names_with_id, output_names_with_id,
std::to_string(predictor_id));
// Get "" when there is no cached calibration table data. // Get "" when there is no cached calibration table data.
std::string calibration_data = GetTrtCalibTableData( std::string calibration_data = GetTrtCalibTableData(
...@@ -215,10 +220,53 @@ void TensorRtSubgraphPass::CreateTensorRTOp( ...@@ -215,10 +220,53 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
SetAttr(op_desc->Proto(), "enable_int8", enable_int8); SetAttr(op_desc->Proto(), "enable_int8", enable_int8);
SetAttr(op_desc->Proto(), "engine_key", engine_key); SetAttr(op_desc->Proto(), "engine_key", engine_key);
SetAttr(op_desc->Proto(), "engine_serialized_data", std::string(""));
SetAttr(op_desc->Proto(), "engine_serialized_data_path",
GetTrtEngineSerializedPath(Get<std::string>("model_opt_cache_dir"),
engine_key));
std::unique_ptr<tensorrt::TRTInt8Calibrator> calibrator;
if (enable_int8 && calibration_data.size() != 0) {
calibrator.reset(new tensorrt::TRTInt8Calibrator(calibration_data));
}
if (!(enable_int8 && calibration_data.size() == 0)) { // When in int8 mode and calibration_mode, the program just produce the
// calibration table data.
bool calibration_mode = (enable_int8 && calibration_data.size() == 0);
if (!calibration_mode) {
std::copy(params.begin(), params.end(), std::copy(params.begin(), params.end(),
std::back_inserter(*repetitive_params)); std::back_inserter(*repetitive_params));
std::string trt_engine_serialized_data = GetTrtEngineSerializedData(
Get<std::string>("model_opt_cache_dir"), engine_key);
tensorrt::TensorRTEngine *trt_engine =
inference::Singleton<tensorrt::TRTEngineManager>::Global().Create(
Get<int>("max_batch_size"), Get<int>("workspace_size"), enable_int8,
calibrator.get(), engine_key);
if (trt_engine_serialized_data.size() == 0) {
LOG(INFO) << "Prepare TRT engine (Optimize model structure, Select OP "
"kernel etc). This process may cost a lot of time.";
auto *scope = param_scope();
framework::BlockDesc block_desc_temp(nullptr, block_desc.Proto());
std::unordered_set<std::string> param_set(params.begin(), params.end());
inference::Singleton<inference::tensorrt::OpConverter>::Global()
.ConvertBlockToTRTEngine(
&block_desc_temp, *scope,
std::vector<std::string>(input_names.begin(), input_names.end()),
param_set, output_mapping, trt_engine);
nvinfer1::IHostMemory *serialized_engine_data = trt_engine->Serialize();
trt_engine_serialized_data =
std::string((const char *)serialized_engine_data->data(),
serialized_engine_data->size());
// SaveTrtEngineSerializedDataToFile(GetTrtEngineSerializedPath(Get<std::string>("model_opt_cache_dir"),
// engine_key),
// trt_engine_serialized_data);
} else {
trt_engine->Deserialize(trt_engine_serialized_data);
}
SetAttr(op_desc->Proto(), "engine_serialized_data",
trt_engine_serialized_data);
} }
} }
......
...@@ -342,6 +342,7 @@ void AnalysisPredictor::OptimizeInferenceProgram() { ...@@ -342,6 +342,7 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
config_.static_memory_optim_force_update_); config_.static_memory_optim_force_update_);
argument_.SetModelFromMemory(config_.model_from_memory_); argument_.SetModelFromMemory(config_.model_from_memory_);
// Analyze inference_program // Analyze inference_program
argument_.SetPredictorID(predictor_id_);
if (!config_.model_dir().empty()) { if (!config_.model_dir().empty()) {
argument_.SetModelDir(config_.model_dir()); argument_.SetModelDir(config_.model_dir());
} else { } else {
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "paddle/fluid/inference/analysis/analyzer.h" #include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/api/api_impl.h" #include "paddle/fluid/inference/api/api_impl.h"
#include "paddle/fluid/inference/api/details/reset_tensor_array.h" #include "paddle/fluid/inference/api/details/reset_tensor_array.h"
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
#ifdef PADDLE_WITH_TESTING #ifdef PADDLE_WITH_TESTING
...@@ -43,7 +44,9 @@ using framework::NaiveExecutor; ...@@ -43,7 +44,9 @@ using framework::NaiveExecutor;
*/ */
class AnalysisPredictor : public PaddlePredictor { class AnalysisPredictor : public PaddlePredictor {
public: public:
explicit AnalysisPredictor(const AnalysisConfig &config) : config_(config) {} explicit AnalysisPredictor(const AnalysisConfig &config) : config_(config) {
predictor_id_ = inference::GetUniqueId();
}
~AnalysisPredictor(); ~AnalysisPredictor();
bool Init(const std::shared_ptr<framework::Scope> &parent_scope, bool Init(const std::shared_ptr<framework::Scope> &parent_scope,
...@@ -143,6 +146,7 @@ class AnalysisPredictor : public PaddlePredictor { ...@@ -143,6 +146,7 @@ class AnalysisPredictor : public PaddlePredictor {
const size_t max_shape_collect_count_{1000}; const size_t max_shape_collect_count_{1000};
int need_collect_var_shapes_{-1}; // -1 for default, 0 for false, 1 for true. int need_collect_var_shapes_{-1}; // -1 for default, 0 for false, 1 for true.
std::vector<std::map<std::string, std::vector<int>>> batch_var_shapes_; std::vector<std::map<std::string, std::vector<int>>> batch_var_shapes_;
int predictor_id_;
private: private:
// Some status here that help to determine the status inside the predictor. // Some status here that help to determine the status inside the predictor.
......
...@@ -50,6 +50,11 @@ class Timer { ...@@ -50,6 +50,11 @@ class Timer {
} }
}; };
static int GetUniqueId() {
static int id = 0;
return id++;
}
static void split(const std::string &str, char sep, static void split(const std::string &str, char sep,
std::vector<std::string> *pieces) { std::vector<std::string> *pieces) {
pieces->clear(); pieces->clear();
......
...@@ -143,6 +143,7 @@ class OpConverter { ...@@ -143,6 +143,7 @@ class OpConverter {
} }
} }
// The scope here should be inited with the parameter vars.
void ConvertBlockToTRTEngine( void ConvertBlockToTRTEngine(
framework::BlockDesc* block_desc, const framework::Scope& scope, framework::BlockDesc* block_desc, const framework::Scope& scope,
const std::vector<std::string>& inputs, const std::vector<std::string>& inputs,
...@@ -151,18 +152,16 @@ class OpConverter { ...@@ -151,18 +152,16 @@ class OpConverter {
engine->InitNetwork(); engine->InitNetwork();
for (auto& input : inputs) { for (auto& input : inputs) {
if (parameters.count(input)) continue; if (parameters.count(input)) continue;
auto& t =
inference::analysis::GetFromScope<framework::LoDTensor>(scope, input);
auto t_shape = framework::vectorize(t.dims());
auto* var = block_desc->FindVar(input); auto* var = block_desc->FindVar(input);
PADDLE_ENFORCE(var, "no variable called %s", input); PADDLE_ENFORCE(var, "no variable called %s", input);
PADDLE_ENFORCE_EQ(var->GetType(), FluidDT::VarType_Type_LOD_TENSOR, PADDLE_ENFORCE_EQ(var->GetType(), FluidDT::VarType_Type_LOD_TENSOR,
"TensorRT engine only takes LoDTensor as input"); "TensorRT engine only takes LoDTensor as input");
auto var_shape = var->GetShape();
engine->DeclareInput( engine->DeclareInput(
input, FluidDataType2TRT( input, FluidDataType2TRT(
var->Proto()->type().lod_tensor().tensor().data_type()), var->Proto()->type().lod_tensor().tensor().data_type()),
Vec2TRT_Dims(t_shape)); Vec2TRT_Dims(var_shape));
} }
framework::proto::BlockDesc* block_proto = block_desc->Proto(); framework::proto::BlockDesc* block_proto = block_desc->Proto();
ConvertBlock(*block_proto, parameters, scope, engine); ConvertBlock(*block_proto, parameters, scope, engine);
......
...@@ -104,6 +104,34 @@ class TensorRTEngine { ...@@ -104,6 +104,34 @@ class TensorRTEngine {
nvinfer1::ICudaEngine* engine() { return infer_engine_.get(); } nvinfer1::ICudaEngine* engine() { return infer_engine_.get(); }
nvinfer1::INetworkDefinition* network() { return infer_network_.get(); } nvinfer1::INetworkDefinition* network() { return infer_network_.get(); }
nvinfer1::IHostMemory* Serialize() {
PADDLE_ENFORCE(infer_engine_ != nullptr,
"You should build engine first and then serialize");
ihost_memory_.reset(infer_engine_->serialize());
return ihost_memory_.get();
}
void Deserialize(const std::string& engine_serialized_data) {
infer_ptr<nvinfer1::IRuntime> runtime(createInferRuntime(&logger_));
infer_engine_.reset(
runtime->deserializeCudaEngine(engine_serialized_data.c_str(),
engine_serialized_data.size(), nullptr));
PADDLE_ENFORCE(infer_engine_ != nullptr,
"build cuda engine failed when deserialize engine info.!");
infer_context_.reset(infer_engine_->createExecutionContext());
}
void Deserialize(const nvinfer1::IHostMemory* engine_serialized_data) {
infer_ptr<nvinfer1::IRuntime> runtime(createInferRuntime(&logger_));
infer_engine_.reset(runtime->deserializeCudaEngine(
engine_serialized_data->data(), engine_serialized_data->size(),
nullptr));
PADDLE_ENFORCE(infer_engine_ != nullptr,
"build cuda engine failed when deserialize engine info.!");
infer_context_.reset(infer_engine_->createExecutionContext());
}
void SetRuntimeBatch(size_t batch_size); void SetRuntimeBatch(size_t batch_size);
int GetRuntimeBatch(); int GetRuntimeBatch();
nvinfer1::IPluginLayer* AddPlugin(nvinfer1::ITensor* const* inputs, nvinfer1::IPluginLayer* AddPlugin(nvinfer1::ITensor* const* inputs,
...@@ -154,11 +182,11 @@ class TensorRTEngine { ...@@ -154,11 +182,11 @@ class TensorRTEngine {
infer_ptr<nvinfer1::INetworkDefinition> infer_network_; infer_ptr<nvinfer1::INetworkDefinition> infer_network_;
infer_ptr<nvinfer1::ICudaEngine> infer_engine_; infer_ptr<nvinfer1::ICudaEngine> infer_engine_;
infer_ptr<nvinfer1::IExecutionContext> infer_context_; infer_ptr<nvinfer1::IExecutionContext> infer_context_;
infer_ptr<nvinfer1::IHostMemory> ihost_memory_;
}; // class TensorRTEngine }; // class TensorRTEngine
// Add an layer__ into engine__ with args ARGS. // Add an layer__ into engine__ with args ARGS.
// For example: // For example:
// TRT_ENGINE_ADD_LAYER(xxx, FullyConnected, input, dim, weights, bias)
// //
// Reference // Reference
// https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html#charRNN_define_network // https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html#charRNN_define_network
...@@ -170,6 +198,43 @@ class TensorRTEngine { ...@@ -170,6 +198,43 @@ class TensorRTEngine {
#define TRT_ENGINE_ADD_LAYER(engine__, layer__, ARGS...) \ #define TRT_ENGINE_ADD_LAYER(engine__, layer__, ARGS...) \
engine__->network()->add##layer__(ARGS); engine__->network()->add##layer__(ARGS);
/*
* Helper to control the TensorRT engine's creation and deletion.
*/
class TRTEngineManager {
public:
bool HasEngine(const std::string& name) const {
if (engines_.count(name) == 0) return false;
return engines_.at(name).get() != nullptr;
}
// Get an engine called `name`.
TensorRTEngine* Get(const std::string& name) const {
return engines_.at(name).get();
}
// Create or get an engine called `name`
TensorRTEngine* Create(int max_batch, int max_workspace, bool enable_int8,
TRTInt8Calibrator* calibrator,
const std::string& engine_name) {
std::unique_lock<std::mutex> lk(mut_);
auto* p =
new TensorRTEngine(max_batch, max_workspace, enable_int8, calibrator);
engines_[engine_name].reset(p);
return p;
}
void DeleteALL() {
for (auto& item : engines_) {
item.second.reset(nullptr);
}
}
private:
std::unordered_map<std::string, std::unique_ptr<TensorRTEngine>> engines_;
std::mutex mut_;
};
} // namespace tensorrt } // namespace tensorrt
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -191,9 +191,8 @@ TEST_F(TensorRTEngineTest, test_pool2d) { ...@@ -191,9 +191,8 @@ TEST_F(TensorRTEngineTest, test_pool2d) {
std::vector<void *> buffers(2); // TRT binded inputs std::vector<void *> buffers(2); // TRT binded inputs
nvinfer1::PoolingType pool_t = nvinfer1::PoolingType::kAVERAGE; nvinfer1::PoolingType pool_t = nvinfer1::PoolingType::kAVERAGE;
auto *pool_layer = TRT_ENGINE_ADD_LAYER(engine_, Pooling, auto *pool_layer = TRT_ENGINE_ADD_LAYER(engine_, Pooling, *x, pool_t,
*const_cast<nvinfer1::ITensor *>(x), nvinfer1::DimsHW{2, 2});
pool_t, nvinfer1::DimsHW{2, 2});
PADDLE_ENFORCE(pool_layer != nullptr); PADDLE_ENFORCE(pool_layer != nullptr);
pool_layer->setStride(nvinfer1::DimsHW{1, 1}); pool_layer->setStride(nvinfer1::DimsHW{1, 1});
......
...@@ -30,6 +30,9 @@ class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -30,6 +30,9 @@ class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Ys", "A list of outputs").AsDuplicable(); AddOutput("Ys", "A list of outputs").AsDuplicable();
AddAttr<std::string>("subgraph", "the subgraph."); AddAttr<std::string>("subgraph", "the subgraph.");
AddAttr<std::string>("calibration_data", "the calibration data for int8"); AddAttr<std::string>("calibration_data", "the calibration data for int8");
AddAttr<std::string>(
"engine_serialized_data",
"the serialized data contains the all info of the ICUDAEngine");
AddAttr<std::string>( AddAttr<std::string>(
"engine_key", "engine_key",
"The engine_key here is used to distinguish different TRT Engines"); "The engine_key here is used to distinguish different TRT Engines");
......
...@@ -41,13 +41,14 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -41,13 +41,14 @@ class TensorRTEngineOp : public framework::OperatorBase {
private: private:
std::vector<std::string> input_names_; std::vector<std::string> input_names_;
std::unordered_set<std::string> param_names_; std::unordered_set<std::string> param_names_;
mutable std::unique_ptr<TensorRTEngine> trt_engine_; mutable TensorRTEngine *trt_engine_;
int max_batch_size_; int max_batch_size_;
int workspace_size_; int workspace_size_;
std::unique_ptr<TRTInt8Calibrator> calibrator_; std::unique_ptr<TRTInt8Calibrator> calibrator_;
bool enable_int8_; bool enable_int8_;
std::string calibration_data_; std::string calibration_data_;
std::string engine_key_; std::string engine_key_;
std::string engine_serialized_data_;
bool calibration_mode_; bool calibration_mode_;
public: public:
...@@ -62,6 +63,8 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -62,6 +63,8 @@ class TensorRTEngineOp : public framework::OperatorBase {
enable_int8_ = Attr<bool>("enable_int8"); enable_int8_ = Attr<bool>("enable_int8");
calibration_data_ = Attr<std::string>("calibration_data"); calibration_data_ = Attr<std::string>("calibration_data");
engine_key_ = Attr<std::string>("engine_key"); engine_key_ = Attr<std::string>("engine_key");
engine_serialized_data_ = Attr<std::string>("engine_serialized_data");
trt_engine_ = nullptr;
auto params = Attr<std::vector<std::string>>("parameters"); auto params = Attr<std::vector<std::string>>("parameters");
for (const auto &param : params) { for (const auto &param : params) {
...@@ -78,7 +81,12 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -78,7 +81,12 @@ class TensorRTEngineOp : public framework::OperatorBase {
// we will create an engine here. // we will create an engine here.
if (!calibration_mode_) { if (!calibration_mode_) {
// trt_engine_.reset(); if (inference::Singleton<inference::tensorrt::TRTEngineManager>::Global()
.HasEngine(engine_key_)) {
trt_engine_ = inference::Singleton<
inference::tensorrt::TRTEngineManager>::Global()
.Get(engine_key_);
}
} }
} }
...@@ -99,7 +107,7 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -99,7 +107,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
RunCalibration(scope, dev_place); RunCalibration(scope, dev_place);
return; return;
} }
auto trt_engine = GetEngine(scope, dev_place); auto *trt_engine = GetEngine(scope, dev_place);
RunTrt(scope, dev_place, trt_engine); RunTrt(scope, dev_place, trt_engine);
} }
...@@ -158,7 +166,6 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -158,7 +166,6 @@ class TensorRTEngineOp : public framework::OperatorBase {
auto stream = auto stream =
reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx).stream(); reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx).stream();
// auto *engine = trt_engine_.get();
PADDLE_ENFORCE(!input_names_.empty(), "should pass more than one inputs"); PADDLE_ENFORCE(!input_names_.empty(), "should pass more than one inputs");
std::vector<std::string> output_maps = std::vector<std::string> output_maps =
...@@ -192,8 +199,9 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -192,8 +199,9 @@ class TensorRTEngineOp : public framework::OperatorBase {
int output_index = 0; int output_index = 0;
VLOG(4) << "TensorRT Engine Op Outputs:"; VLOG(4) << "TensorRT Engine Op Outputs:";
for (const auto &y : Outputs("Ys")) { for (const auto &y : Outputs("Ys")) {
nvinfer1::ITensor *trt_t = engine->GetITensor(output_maps[output_index]); const int bind_index =
auto dims = trt_t->getDimensions(); engine->engine()->getBindingIndex(output_maps[output_index].c_str());
auto dims = engine->engine()->getBindingDimensions(bind_index);
// Use the output ITensor's dims to reshape the Fluid Tensor. // Use the output ITensor's dims to reshape the Fluid Tensor.
// The ITensor doesn't contain the batch size dim. // The ITensor doesn't contain the batch size dim.
std::vector<int> ddim; std::vector<int> ddim;
...@@ -206,8 +214,6 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -206,8 +214,6 @@ class TensorRTEngineOp : public framework::OperatorBase {
auto *fluid_t = fluid_v->GetMutable<framework::LoDTensor>(); auto *fluid_t = fluid_v->GetMutable<framework::LoDTensor>();
fluid_t->Resize(framework::make_ddim(ddim)); fluid_t->Resize(framework::make_ddim(ddim));
const int bind_index =
engine->engine()->getBindingIndex(output_maps[output_index].c_str());
PADDLE_ENFORCE(bind_index < num_bindings, PADDLE_ENFORCE(bind_index < num_bindings,
"The bind index should be less than num_bindings"); "The bind index should be less than num_bindings");
buffers[bind_index] = static_cast<void *>(fluid_t->mutable_data<float>( buffers[bind_index] = static_cast<void *>(fluid_t->mutable_data<float>(
...@@ -224,16 +230,14 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -224,16 +230,14 @@ class TensorRTEngineOp : public framework::OperatorBase {
TensorRTEngine *GetEngine(const framework::Scope &scope, TensorRTEngine *GetEngine(const framework::Scope &scope,
const platform::Place &dev_place) const { const platform::Place &dev_place) const {
if (trt_engine_.get() == nullptr) { if (trt_engine_ == nullptr) {
trt_engine_.reset(new TensorRTEngine(max_batch_size_, workspace_size_, trt_engine_ =
enable_int8_, calibrator_.get())); inference::Singleton<inference::tensorrt::TRTEngineManager>::Global()
if (true) { .Create(max_batch_size_, workspace_size_, enable_int8_,
PrepareTRTEngine(scope, trt_engine_.get()); calibrator_.get(), engine_key_);
} else { PrepareTRTEngine(scope, trt_engine_);
// create static engine
}
} }
return trt_engine_.get(); return trt_engine_;
} }
void PrepareTRTEngine(const framework::Scope &scope, void PrepareTRTEngine(const framework::Scope &scope,
......
...@@ -107,6 +107,7 @@ TEST(TensorRTEngineOp, manual) { ...@@ -107,6 +107,7 @@ TEST(TensorRTEngineOp, manual) {
engine_op_desc.SetAttr("output_name_mapping", engine_op_desc.SetAttr("output_name_mapping",
std::vector<std::string>({"z0"})); std::vector<std::string>({"z0"}));
engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString())); engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString()));
engine_op_desc.SetAttr("engine_serialized_data", std::string(""));
LOG(INFO) << "create engine op"; LOG(INFO) << "create engine op";
auto engine_op = framework::OpRegistry::CreateOp(engine_op_desc); auto engine_op = framework::OpRegistry::CreateOp(engine_op_desc);
...@@ -202,6 +203,7 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) { ...@@ -202,6 +203,7 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
engine_op_desc.SetAttr("output_name_mapping", engine_op_desc.SetAttr("output_name_mapping",
std::vector<std::string>({"z3"})); std::vector<std::string>({"z3"}));
engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString())); engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString()));
engine_op_desc.SetAttr("engine_serialized_data", std::string(""));
auto engine_op = framework::OpRegistry::CreateOp(engine_op_desc); auto engine_op = framework::OpRegistry::CreateOp(engine_op_desc);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册