提交 5863c861 编写于 作者: N nhzlx 提交者: ceci3

6. delete useless predictor id

test=develop
上级 f3d164fa
...@@ -99,10 +99,6 @@ struct Argument { ...@@ -99,10 +99,6 @@ struct Argument {
private: \ private: \
unique_ptr_t field__##_; unique_ptr_t field__##_;
// Each predictor has an unique id.
// For now, this attr will help us to get the right
// trt_engine for each trt_engine_op for each predictor when using trt.
DECL_ARGUMENT_FIELD(predictor_id, PredictorID, int);
// Model path // Model path
DECL_ARGUMENT_FIELD(model_dir, ModelDir, std::string); DECL_ARGUMENT_FIELD(model_dir, ModelDir, std::string);
// Model specified with program and parameters files. // Model specified with program and parameters files.
......
...@@ -81,7 +81,6 @@ void IRPassManager::CreatePasses(Argument *argument, ...@@ -81,7 +81,6 @@ void IRPassManager::CreatePasses(Argument *argument,
pass->Set( pass->Set(
"model_opt_cache_dir", "model_opt_cache_dir",
new std::string(GetOrCreateModelOptCacheDir(model_opt_cache_dir))); new std::string(GetOrCreateModelOptCacheDir(model_opt_cache_dir)));
pass->Set("predictor_id", new int(argument->predictor_id()));
pass->Set("gpu_device_id", new int(argument->gpu_device_id())); pass->Set("gpu_device_id", new int(argument->gpu_device_id()));
} }
......
...@@ -209,9 +209,8 @@ void TensorRtSubgraphPass::CreateTensorRTOp( ...@@ -209,9 +209,8 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
SetAttr(op_desc->Proto(), "parameters", params); SetAttr(op_desc->Proto(), "parameters", params);
auto enable_int8 = Get<bool>("enable_int8"); auto enable_int8 = Get<bool>("enable_int8");
int predictor_id = Get<int>("predictor_id");
auto engine_key = GenerateEngineKey(input_names_with_id, output_names_with_id, auto engine_key = GenerateEngineKey(input_names_with_id, output_names_with_id,
std::to_string(predictor_id)); std::to_string(0));
// Get "" when there is no cached calibration table data. // Get "" when there is no cached calibration table data.
std::string calibration_data = GetTrtCalibTableData( std::string calibration_data = GetTrtCalibTableData(
...@@ -221,9 +220,6 @@ void TensorRtSubgraphPass::CreateTensorRTOp( ...@@ -221,9 +220,6 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
SetAttr(op_desc->Proto(), "enable_int8", enable_int8); SetAttr(op_desc->Proto(), "enable_int8", enable_int8);
SetAttr(op_desc->Proto(), "engine_key", engine_key); SetAttr(op_desc->Proto(), "engine_key", engine_key);
SetAttr(op_desc->Proto(), "engine_serialized_data", std::string("")); SetAttr(op_desc->Proto(), "engine_serialized_data", std::string(""));
SetAttr(op_desc->Proto(), "engine_serialized_data_path",
GetTrtEngineSerializedPath(Get<std::string>("model_opt_cache_dir"),
engine_key));
std::unique_ptr<tensorrt::TRTInt8Calibrator> calibrator; std::unique_ptr<tensorrt::TRTInt8Calibrator> calibrator;
if (enable_int8 && calibration_data.size() != 0) { if (enable_int8 && calibration_data.size() != 0) {
...@@ -239,13 +235,13 @@ void TensorRtSubgraphPass::CreateTensorRTOp( ...@@ -239,13 +235,13 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
std::string trt_engine_serialized_data = GetTrtEngineSerializedData( std::string trt_engine_serialized_data = GetTrtEngineSerializedData(
Get<std::string>("model_opt_cache_dir"), engine_key); Get<std::string>("model_opt_cache_dir"), engine_key);
tensorrt::TensorRTEngine *trt_engine =
inference::Singleton<tensorrt::TRTEngineManager>::Global().Create(
Get<int>("max_batch_size"), Get<int>("workspace_size"), enable_int8,
calibrator.get(), engine_key, Get<int>("gpu_device_id"));
if (trt_engine_serialized_data.size() == 0) { if (trt_engine_serialized_data.size() == 0) {
LOG(INFO) << "Prepare TRT engine (Optimize model structure, Select OP " LOG(INFO) << "Prepare TRT engine (Optimize model structure, Select OP "
"kernel etc). This process may cost a lot of time."; "kernel etc). This process may cost a lot of time.";
std::unique_ptr<tensorrt::TensorRTEngine> trt_engine(
new tensorrt::TensorRTEngine(
Get<int>("max_batch_size"), Get<int>("workspace_size"),
enable_int8, calibrator.get(), Get<int>("gpu_device_id")));
auto *scope = param_scope(); auto *scope = param_scope();
framework::BlockDesc block_desc_temp(nullptr, block_desc.Proto()); framework::BlockDesc block_desc_temp(nullptr, block_desc.Proto());
std::unordered_set<std::string> param_set(params.begin(), params.end()); std::unordered_set<std::string> param_set(params.begin(), params.end());
...@@ -253,7 +249,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp( ...@@ -253,7 +249,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
.ConvertBlockToTRTEngine( .ConvertBlockToTRTEngine(
&block_desc_temp, *scope, &block_desc_temp, *scope,
std::vector<std::string>(input_names.begin(), input_names.end()), std::vector<std::string>(input_names.begin(), input_names.end()),
param_set, output_mapping, trt_engine); param_set, output_mapping, trt_engine.get());
nvinfer1::IHostMemory *serialized_engine_data = trt_engine->Serialize(); nvinfer1::IHostMemory *serialized_engine_data = trt_engine->Serialize();
trt_engine_serialized_data = trt_engine_serialized_data =
std::string((const char *)serialized_engine_data->data(), std::string((const char *)serialized_engine_data->data(),
...@@ -263,11 +259,11 @@ void TensorRtSubgraphPass::CreateTensorRTOp( ...@@ -263,11 +259,11 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
engine_key), engine_key),
trt_engine_serialized_data); trt_engine_serialized_data);
} else { } else {
LOG(INFO) << "Load TRT Engine from optimized serialized data : " LOG(INFO) << "Load TRT Optimized Info from "
<< GetTrtEngineSerializedPath( << GetTrtEngineSerializedPath(
Get<std::string>("model_opt_cache_dir"), engine_key); Get<std::string>("model_opt_cache_dir"), engine_key);
trt_engine->Deserialize(trt_engine_serialized_data);
} }
SetAttr(op_desc->Proto(), "engine_serialized_data", SetAttr(op_desc->Proto(), "engine_serialized_data",
trt_engine_serialized_data); trt_engine_serialized_data);
} }
......
...@@ -345,7 +345,6 @@ void AnalysisPredictor::OptimizeInferenceProgram() { ...@@ -345,7 +345,6 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
config_.static_memory_optim_force_update_); config_.static_memory_optim_force_update_);
argument_.SetModelFromMemory(config_.model_from_memory_); argument_.SetModelFromMemory(config_.model_from_memory_);
// Analyze inference_program // Analyze inference_program
argument_.SetPredictorID(predictor_id_);
if (!config_.model_dir().empty()) { if (!config_.model_dir().empty()) {
argument_.SetModelDir(config_.model_dir()); argument_.SetModelDir(config_.model_dir());
} else { } else {
......
...@@ -44,9 +44,7 @@ using framework::NaiveExecutor; ...@@ -44,9 +44,7 @@ using framework::NaiveExecutor;
*/ */
class AnalysisPredictor : public PaddlePredictor { class AnalysisPredictor : public PaddlePredictor {
public: public:
explicit AnalysisPredictor(const AnalysisConfig &config) : config_(config) { explicit AnalysisPredictor(const AnalysisConfig &config) : config_(config) {}
predictor_id_ = inference::GetUniqueId();
}
~AnalysisPredictor(); ~AnalysisPredictor();
bool Init(const std::shared_ptr<framework::Scope> &parent_scope, bool Init(const std::shared_ptr<framework::Scope> &parent_scope,
...@@ -146,7 +144,6 @@ class AnalysisPredictor : public PaddlePredictor { ...@@ -146,7 +144,6 @@ class AnalysisPredictor : public PaddlePredictor {
const size_t max_shape_collect_count_{1000}; const size_t max_shape_collect_count_{1000};
int need_collect_var_shapes_{-1}; // -1 for default, 0 for false, 1 for true. int need_collect_var_shapes_{-1}; // -1 for default, 0 for false, 1 for true.
std::vector<std::map<std::string, std::vector<int>>> batch_var_shapes_; std::vector<std::map<std::string, std::vector<int>>> batch_var_shapes_;
int predictor_id_;
private: private:
// Some status here that help to determine the status inside the predictor. // Some status here that help to determine the status inside the predictor.
......
...@@ -199,43 +199,6 @@ class TensorRTEngine { ...@@ -199,43 +199,6 @@ class TensorRTEngine {
#define TRT_ENGINE_ADD_LAYER(engine__, layer__, ARGS...) \ #define TRT_ENGINE_ADD_LAYER(engine__, layer__, ARGS...) \
engine__->network()->add##layer__(ARGS); engine__->network()->add##layer__(ARGS);
/*
* Helper to control the TensorRT engine's creation and deletion.
*/
class TRTEngineManager {
public:
bool HasEngine(const std::string& name) const {
if (engines_.count(name) == 0) return false;
return engines_.at(name).get() != nullptr;
}
// Get an engine called `name`.
TensorRTEngine* Get(const std::string& name) const {
return engines_.at(name).get();
}
// Create or get an engine called `name`
TensorRTEngine* Create(int max_batch, int max_workspace, bool enable_int8,
TRTInt8Calibrator* calibrator,
const std::string& engine_name, int device_id = 0) {
std::unique_lock<std::mutex> lk(mut_);
auto* p = new TensorRTEngine(max_batch, max_workspace, enable_int8,
calibrator, device_id);
engines_[engine_name].reset(p);
return p;
}
void DeleteALL() {
for (auto& item : engines_) {
item.second.reset(nullptr);
}
}
private:
std::unordered_map<std::string, std::unique_ptr<TensorRTEngine>> engines_;
std::mutex mut_;
};
} // namespace tensorrt } // namespace tensorrt
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -31,7 +31,8 @@ namespace inference { ...@@ -31,7 +31,8 @@ namespace inference {
namespace tensorrt { namespace tensorrt {
namespace plugin { namespace plugin {
class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { class PluginFactoryTensorRT : public nvinfer1::IPluginFactory,
public DeleteHelper {
public: public:
// Deserialization method // Deserialization method
PluginTensorRT* createPlugin(const char* layer_name, const void* serial_data, PluginTensorRT* createPlugin(const char* layer_name, const void* serial_data,
......
...@@ -24,6 +24,13 @@ namespace inference { ...@@ -24,6 +24,13 @@ namespace inference {
namespace tensorrt { namespace tensorrt {
namespace plugin { namespace plugin {
// Some trt base classes lack of the destructor.
// We use a assisted class to fix this.
struct DeleteHelper {
protected:
virtual ~DeleteHelper() {}
};
template <typename T> template <typename T>
inline void SerializeValue(void** buffer, T const& value); inline void SerializeValue(void** buffer, T const& value);
......
...@@ -41,7 +41,7 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -41,7 +41,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
private: private:
std::vector<std::string> input_names_; std::vector<std::string> input_names_;
std::unordered_set<std::string> param_names_; std::unordered_set<std::string> param_names_;
mutable TensorRTEngine *trt_engine_; mutable std::unique_ptr<TensorRTEngine> trt_engine_;
int max_batch_size_; int max_batch_size_;
int workspace_size_; int workspace_size_;
std::unique_ptr<TRTInt8Calibrator> calibrator_; std::unique_ptr<TRTInt8Calibrator> calibrator_;
...@@ -64,7 +64,6 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -64,7 +64,6 @@ class TensorRTEngineOp : public framework::OperatorBase {
calibration_data_ = Attr<std::string>("calibration_data"); calibration_data_ = Attr<std::string>("calibration_data");
engine_key_ = Attr<std::string>("engine_key"); engine_key_ = Attr<std::string>("engine_key");
engine_serialized_data_ = Attr<std::string>("engine_serialized_data"); engine_serialized_data_ = Attr<std::string>("engine_serialized_data");
trt_engine_ = nullptr;
auto params = Attr<std::vector<std::string>>("parameters"); auto params = Attr<std::vector<std::string>>("parameters");
for (const auto &param : params) { for (const auto &param : params) {
...@@ -78,16 +77,6 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -78,16 +77,6 @@ class TensorRTEngineOp : public framework::OperatorBase {
if (enable_int8_ && calibration_data_.size()) { if (enable_int8_ && calibration_data_.size()) {
calibrator_.reset(new TRTInt8Calibrator(calibration_data_)); calibrator_.reset(new TRTInt8Calibrator(calibration_data_));
} }
// we will create an engine here.
if (!calibration_mode_) {
if (inference::Singleton<inference::tensorrt::TRTEngineManager>::Global()
.HasEngine(engine_key_)) {
trt_engine_ = inference::Singleton<
inference::tensorrt::TRTEngineManager>::Global()
.Get(engine_key_);
}
}
} }
protected: protected:
...@@ -231,15 +220,17 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -231,15 +220,17 @@ class TensorRTEngineOp : public framework::OperatorBase {
TensorRTEngine *GetEngine(const framework::Scope &scope, TensorRTEngine *GetEngine(const framework::Scope &scope,
const platform::Place &dev_place) const { const platform::Place &dev_place) const {
if (trt_engine_ == nullptr) { if (trt_engine_.get() == nullptr) {
trt_engine_ = trt_engine_.reset(new inference::tensorrt::TensorRTEngine(
inference::Singleton<inference::tensorrt::TRTEngineManager>::Global() max_batch_size_, workspace_size_, enable_int8_, calibrator_.get(),
.Create(max_batch_size_, workspace_size_, enable_int8_, boost::get<platform::CUDAPlace>(dev_place).device));
calibrator_.get(), engine_key_, if (engine_serialized_data_.size() > 0) {
boost::get<platform::CUDAPlace>(dev_place).device); trt_engine_->Deserialize(engine_serialized_data_);
PrepareTRTEngine(scope, trt_engine_); } else {
PrepareTRTEngine(scope, trt_engine_.get());
}
} }
return trt_engine_; return trt_engine_.get();
} }
void PrepareTRTEngine(const framework::Scope &scope, void PrepareTRTEngine(const framework::Scope &scope,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册