提交 0ed63b21 编写于 作者: N nhzlx

6. delete useless predictor id

test=develop
上级 1d5ef7c9
......@@ -99,10 +99,6 @@ struct Argument {
private: \
unique_ptr_t field__##_;
// Each predictor has an unique id.
// For now, this attr will help us to get the right
// trt_engine for each trt_engine_op for each predictor when using trt.
DECL_ARGUMENT_FIELD(predictor_id, PredictorID, int);
// Model path
DECL_ARGUMENT_FIELD(model_dir, ModelDir, std::string);
// Model specified with program and parameters files.
......
......@@ -81,7 +81,6 @@ void IRPassManager::CreatePasses(Argument *argument,
pass->Set(
"model_opt_cache_dir",
new std::string(GetOrCreateModelOptCacheDir(model_opt_cache_dir)));
pass->Set("predictor_id", new int(argument->predictor_id()));
pass->Set("gpu_device_id", new int(argument->gpu_device_id()));
}
......
......@@ -209,9 +209,8 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
SetAttr(op_desc->Proto(), "parameters", params);
auto enable_int8 = Get<bool>("enable_int8");
int predictor_id = Get<int>("predictor_id");
auto engine_key = GenerateEngineKey(input_names_with_id, output_names_with_id,
std::to_string(predictor_id));
std::to_string(0));
// Get "" when there is no cached calibration table data.
std::string calibration_data = GetTrtCalibTableData(
......@@ -221,9 +220,6 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
SetAttr(op_desc->Proto(), "enable_int8", enable_int8);
SetAttr(op_desc->Proto(), "engine_key", engine_key);
SetAttr(op_desc->Proto(), "engine_serialized_data", std::string(""));
SetAttr(op_desc->Proto(), "engine_serialized_data_path",
GetTrtEngineSerializedPath(Get<std::string>("model_opt_cache_dir"),
engine_key));
std::unique_ptr<tensorrt::TRTInt8Calibrator> calibrator;
if (enable_int8 && calibration_data.size() != 0) {
......@@ -239,13 +235,13 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
std::string trt_engine_serialized_data = GetTrtEngineSerializedData(
Get<std::string>("model_opt_cache_dir"), engine_key);
tensorrt::TensorRTEngine *trt_engine =
inference::Singleton<tensorrt::TRTEngineManager>::Global().Create(
Get<int>("max_batch_size"), Get<int>("workspace_size"), enable_int8,
calibrator.get(), engine_key, Get<int>("gpu_device_id"));
if (trt_engine_serialized_data.size() == 0) {
LOG(INFO) << "Prepare TRT engine (Optimize model structure, Select OP "
"kernel etc). This process may cost a lot of time.";
std::unique_ptr<tensorrt::TensorRTEngine> trt_engine(
new tensorrt::TensorRTEngine(
Get<int>("max_batch_size"), Get<int>("workspace_size"),
enable_int8, calibrator.get(), Get<int>("gpu_device_id")));
auto *scope = param_scope();
framework::BlockDesc block_desc_temp(nullptr, block_desc.Proto());
std::unordered_set<std::string> param_set(params.begin(), params.end());
......@@ -253,7 +249,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
.ConvertBlockToTRTEngine(
&block_desc_temp, *scope,
std::vector<std::string>(input_names.begin(), input_names.end()),
param_set, output_mapping, trt_engine);
param_set, output_mapping, trt_engine.get());
nvinfer1::IHostMemory *serialized_engine_data = trt_engine->Serialize();
trt_engine_serialized_data =
std::string((const char *)serialized_engine_data->data(),
......@@ -263,11 +259,11 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
engine_key),
trt_engine_serialized_data);
} else {
LOG(INFO) << "Load TRT Engine from optimized serialized data : "
LOG(INFO) << "Load TRT Optimized Info from "
<< GetTrtEngineSerializedPath(
Get<std::string>("model_opt_cache_dir"), engine_key);
trt_engine->Deserialize(trt_engine_serialized_data);
}
SetAttr(op_desc->Proto(), "engine_serialized_data",
trt_engine_serialized_data);
}
......
......@@ -342,7 +342,6 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
config_.static_memory_optim_force_update_);
argument_.SetModelFromMemory(config_.model_from_memory_);
// Analyze inference_program
argument_.SetPredictorID(predictor_id_);
if (!config_.model_dir().empty()) {
argument_.SetModelDir(config_.model_dir());
} else {
......
......@@ -44,9 +44,7 @@ using framework::NaiveExecutor;
*/
class AnalysisPredictor : public PaddlePredictor {
public:
explicit AnalysisPredictor(const AnalysisConfig &config) : config_(config) {
predictor_id_ = inference::GetUniqueId();
}
explicit AnalysisPredictor(const AnalysisConfig &config) : config_(config) {}
~AnalysisPredictor();
bool Init(const std::shared_ptr<framework::Scope> &parent_scope,
......@@ -146,7 +144,6 @@ class AnalysisPredictor : public PaddlePredictor {
const size_t max_shape_collect_count_{1000};
int need_collect_var_shapes_{-1}; // -1 for default, 0 for false, 1 for true.
std::vector<std::map<std::string, std::vector<int>>> batch_var_shapes_;
int predictor_id_;
private:
// Some status here that help to determine the status inside the predictor.
......
......@@ -199,43 +199,6 @@ class TensorRTEngine {
#define TRT_ENGINE_ADD_LAYER(engine__, layer__, ARGS...) \
engine__->network()->add##layer__(ARGS);
/*
* Helper to control the TensorRT engine's creation and deletion.
*/
class TRTEngineManager {
public:
bool HasEngine(const std::string& name) const {
if (engines_.count(name) == 0) return false;
return engines_.at(name).get() != nullptr;
}
// Get an engine called `name`.
TensorRTEngine* Get(const std::string& name) const {
return engines_.at(name).get();
}
// Create or get an engine called `name`
TensorRTEngine* Create(int max_batch, int max_workspace, bool enable_int8,
TRTInt8Calibrator* calibrator,
const std::string& engine_name, int device_id = 0) {
std::unique_lock<std::mutex> lk(mut_);
auto* p = new TensorRTEngine(max_batch, max_workspace, enable_int8,
calibrator, device_id);
engines_[engine_name].reset(p);
return p;
}
void DeleteALL() {
for (auto& item : engines_) {
item.second.reset(nullptr);
}
}
private:
std::unordered_map<std::string, std::unique_ptr<TensorRTEngine>> engines_;
std::mutex mut_;
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -31,7 +31,8 @@ namespace inference {
namespace tensorrt {
namespace plugin {
class PluginFactoryTensorRT : public nvinfer1::IPluginFactory {
class PluginFactoryTensorRT : public nvinfer1::IPluginFactory,
public DeleteHelper {
public:
// Deserialization method
PluginTensorRT* createPlugin(const char* layer_name, const void* serial_data,
......
......@@ -24,6 +24,13 @@ namespace inference {
namespace tensorrt {
namespace plugin {
// Some trt base classes lack of the destructor.
// We use a assisted class to fix this.
struct DeleteHelper {
protected:
virtual ~DeleteHelper() {}
};
template <typename T>
inline void SerializeValue(void** buffer, T const& value);
......
......@@ -41,7 +41,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
private:
std::vector<std::string> input_names_;
std::unordered_set<std::string> param_names_;
mutable TensorRTEngine *trt_engine_;
mutable std::unique_ptr<TensorRTEngine> trt_engine_;
int max_batch_size_;
int workspace_size_;
std::unique_ptr<TRTInt8Calibrator> calibrator_;
......@@ -64,7 +64,6 @@ class TensorRTEngineOp : public framework::OperatorBase {
calibration_data_ = Attr<std::string>("calibration_data");
engine_key_ = Attr<std::string>("engine_key");
engine_serialized_data_ = Attr<std::string>("engine_serialized_data");
trt_engine_ = nullptr;
auto params = Attr<std::vector<std::string>>("parameters");
for (const auto &param : params) {
......@@ -78,16 +77,6 @@ class TensorRTEngineOp : public framework::OperatorBase {
if (enable_int8_ && calibration_data_.size()) {
calibrator_.reset(new TRTInt8Calibrator(calibration_data_));
}
// we will create an engine here.
if (!calibration_mode_) {
if (inference::Singleton<inference::tensorrt::TRTEngineManager>::Global()
.HasEngine(engine_key_)) {
trt_engine_ = inference::Singleton<
inference::tensorrt::TRTEngineManager>::Global()
.Get(engine_key_);
}
}
}
protected:
......@@ -231,15 +220,17 @@ class TensorRTEngineOp : public framework::OperatorBase {
TensorRTEngine *GetEngine(const framework::Scope &scope,
const platform::Place &dev_place) const {
if (trt_engine_ == nullptr) {
trt_engine_ =
inference::Singleton<inference::tensorrt::TRTEngineManager>::Global()
.Create(max_batch_size_, workspace_size_, enable_int8_,
calibrator_.get(), engine_key_,
boost::get<platform::CUDAPlace>(dev_place).device);
PrepareTRTEngine(scope, trt_engine_);
if (trt_engine_.get() == nullptr) {
trt_engine_.reset(new inference::tensorrt::TensorRTEngine(
max_batch_size_, workspace_size_, enable_int8_, calibrator_.get(),
boost::get<platform::CUDAPlace>(dev_place).device));
if (engine_serialized_data_.size() > 0) {
trt_engine_->Deserialize(engine_serialized_data_);
} else {
PrepareTRTEngine(scope, trt_engine_.get());
}
}
return trt_engine_;
return trt_engine_.get();
}
void PrepareTRTEngine(const framework::Scope &scope,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册