提交 92cf4a4c 编写于 作者: N nhzlx

fix comments

test=develop
上级 36abc964
...@@ -105,7 +105,6 @@ struct Argument { ...@@ -105,7 +105,6 @@ struct Argument {
DECL_ARGUMENT_FIELD(model_program_path, ModelProgramPath, std::string); DECL_ARGUMENT_FIELD(model_program_path, ModelProgramPath, std::string);
DECL_ARGUMENT_FIELD(model_params_path, ModelParamsPath, std::string); DECL_ARGUMENT_FIELD(model_params_path, ModelParamsPath, std::string);
DECL_ARGUMENT_FIELD(model_from_memory, ModelFromMemory, bool); DECL_ARGUMENT_FIELD(model_from_memory, ModelFromMemory, bool);
DECL_ARGUMENT_FIELD(model_path, ModelPath, std::string);
// The overall graph to work on. // The overall graph to work on.
DECL_ARGUMENT_UNIQUE_FIELD(main_graph, MainGraph, framework::ir::Graph); DECL_ARGUMENT_UNIQUE_FIELD(main_graph, MainGraph, framework::ir::Graph);
......
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include <sys/stat.h> #include <sys/stat.h>
#include <cstdio> #include <cstdio>
#include <fstream> #include <fstream>
#include <set>
#include <string> #include <string>
#include <typeindex> #include <typeindex>
#include <unordered_map> #include <unordered_map>
...@@ -29,9 +30,14 @@ limitations under the License. */ ...@@ -29,9 +30,14 @@ limitations under the License. */
#include "paddle/fluid/platform/port.h" #include "paddle/fluid/platform/port.h"
#ifdef _WIN32 #ifdef _WIN32
#include <direct.h>
#include <io.h>
#define GCC_ATTRIBUTE(attr__) ; #define GCC_ATTRIBUTE(attr__) ;
#define MKDIR(path) _mkdir(path)
#else #else
#include <unistd.h>
#define GCC_ATTRIBUTE(attr__) __attribute__((attr__)); #define GCC_ATTRIBUTE(attr__) __attribute__((attr__));
#define MKDIR(path) mkdir(path, S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH)
#endif #endif
#define __SHOULD_USE_RESULT__ GCC_ATTRIBUTE(warn_unused_result) #define __SHOULD_USE_RESULT__ GCC_ATTRIBUTE(warn_unused_result)
...@@ -163,7 +169,7 @@ static bool PathExists(const std::string &path) { ...@@ -163,7 +169,7 @@ static bool PathExists(const std::string &path) {
return false; return false;
} }
static std::string GetDirRoot(const std::string path) { static std::string GetDirRoot(const std::string &path) {
char sep = '/'; char sep = '/';
#ifdef _WIN32 #ifdef _WIN32
...@@ -177,11 +183,40 @@ static std::string GetDirRoot(const std::string path) { ...@@ -177,11 +183,40 @@ static std::string GetDirRoot(const std::string path) {
return path; return path;
} }
// Returns "<model_root>/_opt_cache/", creating that directory first when
// PathExists reports it missing.
//
// model_root: directory under which the optimization cache lives. Only the
//             final "_opt_cache" component is created (one MKDIR call), so
//             model_root is presumably expected to exist already.
// Fails via PADDLE_ENFORCE when the directory can neither be created nor
// found afterwards.
static std::string GetOrCreateModelOptCacheDir(const std::string &model_root) {
  std::string opt_cache_dir = model_root + "/_opt_cache/";
  if (!PathExists(opt_cache_dir)) {
    // Another predictor or process may create the directory between the check
    // above and this MKDIR call. MKDIR then fails although the directory is
    // perfectly usable, so re-check existence before reporting an error.
    const bool created = MKDIR(opt_cache_dir.c_str()) != -1;
    PADDLE_ENFORCE(created || PathExists(opt_cache_dir),
                   "Can not create optimize cache directory: %s, Make sure you "
                   "have permission to write",
                   opt_cache_dir);
  }
  return opt_cache_dir;
}
static std::string GetTrtCalibPath(const std::string &model_root, static std::string GetTrtCalibPath(const std::string &model_root,
const std::string &engine_key) { const std::string &engine_key) {
return model_root + "/trt_calib_" + engine_key; return model_root + "/trt_calib_" + engine_key;
} }
// Loads the TensorRT INT8 calibration table stored for `engine_key`.
//
// The table file is looked up at
// GetTrtCalibPath(model_opt_cache_dir, engine_key).
//
// model_opt_cache_dir: directory that holds the calibration table files.
// engine_key:          key identifying the TensorRT engine.
// enable_int8:         when false, no file is looked up or read.
//
// Returns the whole file content when `enable_int8` is true and the file
// exists; otherwise returns "".
static std::string GetTrtCalibTableData(const std::string &model_opt_cache_dir,
                                        const std::string &engine_key,
                                        bool enable_int8) {
  const std::string trt_calib_table_path =
      GetTrtCalibPath(model_opt_cache_dir, engine_key);
  if (!enable_int8 || !FileExists(trt_calib_table_path)) {
    return "";
  }

  // The leading space keeps the path and the message apart in the log line.
  VLOG(3) << "Calibration table file: " << trt_calib_table_path
          << " is found here";

  // Slurp the entire file into memory through the stream buffer.
  std::ifstream infile(trt_calib_table_path, std::ios::in);
  std::stringstream buffer;
  buffer << infile.rdbuf();
  return buffer.str();
}
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
......
...@@ -72,14 +72,17 @@ void IRPassManager::CreatePasses(Argument *argument, ...@@ -72,14 +72,17 @@ void IRPassManager::CreatePasses(Argument *argument,
new framework::ProgramDesc *( new framework::ProgramDesc *(
const_cast<framework::ProgramDesc *>(&argument->main_program()))); const_cast<framework::ProgramDesc *>(&argument->main_program())));
bool enable_int8 = false; bool enable_int8 = argument->tensorrt_precision_mode() ==
if (argument->tensorrt_precision_mode() == contrib::AnalysisConfig::Precision::kInt8;
contrib::AnalysisConfig::Precision::kInt8) {
enable_int8 = true;
}
pass->Set("enable_int8", new bool(enable_int8)); pass->Set("enable_int8", new bool(enable_int8));
pass->Set("model_dir", new std::string(argument->model_path())); std::string model_opt_cache_dir =
argument->Has("model_dir")
? argument->model_dir()
: GetDirRoot(argument->model_program_path());
pass->Set(
"model_opt_cache_dir",
new std::string(GetOrCreateModelOptCacheDir(model_opt_cache_dir)));
} }
// graph_ = pass->Apply(std::move(graph_)); // graph_ = pass->Apply(std::move(graph_));
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/analysis/argument.h" #include "paddle/fluid/inference/analysis/argument.h"
#include "paddle/fluid/inference/analysis/helper.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
......
...@@ -68,6 +68,19 @@ std::unique_ptr<framework::ir::Graph> analysis::TensorRtSubgraphPass::ApplyImpl( ...@@ -68,6 +68,19 @@ std::unique_ptr<framework::ir::Graph> analysis::TensorRtSubgraphPass::ApplyImpl(
return graph; return graph;
} }
// Builds the key that identifies a TensorRT engine from the names of the
// subgraph's inputs and outputs.
//
// All input names, followed by all output names, are concatenated in the
// sets' iteration order (sorted, since these are std::set) and the result is
// hashed with std::hash<std::string>.
//
// engine_inputs:  names of the subgraph inputs.
// engine_outputs: names of the subgraph outputs.
//
// Returns the decimal string form of the hash value.
std::string GenerateEngineKey(const std::set<std::string> &engine_inputs,
                              const std::set<std::string> &engine_outputs) {
  std::string engine_hash_key;
  // Iterate by const reference: the names are only appended, so copying each
  // string per iteration is unnecessary.
  for (const auto &name : engine_inputs) {
    engine_hash_key += name;
  }
  for (const auto &name : engine_outputs) {
    engine_hash_key += name;
  }
  return std::to_string(std::hash<std::string>()(engine_hash_key));
}
void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node, void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
Graph *graph) const { Graph *graph) const {
auto *op_desc = node->Op(); auto *op_desc = node->Op();
...@@ -97,7 +110,10 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node, ...@@ -97,7 +110,10 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
*op->Proto() = *node->Op()->Proto(); *op->Proto() = *node->Op()->Proto();
} }
// collect inputs // Then, we will use the input_names_with_id and output_names_with_id to
// generate the engine key.
// So, we use set instead of unordered_set here: its sorted iteration order
// ensures that the engine key is deterministic.
std::set<std::string> input_names; std::set<std::string> input_names;
std::set<std::string> input_names_with_id; std::set<std::string> input_names_with_id;
for (auto *x : node->inputs) { for (auto *x : node->inputs) {
...@@ -217,30 +233,13 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node, ...@@ -217,30 +233,13 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
SetAttr(op_desc->Proto(), "output_name_mapping", output_mapping); SetAttr(op_desc->Proto(), "output_name_mapping", output_mapping);
auto enable_int8 = Get<bool>("enable_int8"); auto enable_int8 = Get<bool>("enable_int8");
SetAttr(op_desc->Proto(), "calibration_data", std::string("")); auto engine_key =
GenerateEngineKey(input_names_with_id, output_names_with_id);
// we use the subgraph's inputs and outputs to generate the engine key. std::string calibration_data = GetTrtCalibTableData(
std::string engine_hash_key = ""; Get<std::string>("model_opt_cache_dir"), engine_key, enable_int8);
for (auto name : input_names_with_id) { SetAttr(op_desc->Proto(), "calibration_data", calibration_data);
engine_hash_key += name;
}
for (auto name : output_names_with_id) {
engine_hash_key += name;
}
auto engine_key = std::to_string(std::hash<std::string>()(engine_hash_key));
auto trt_calib_file =
GetTrtCalibPath(Get<std::string>("model_dir"), engine_key);
VLOG(3) << "engine key: " << engine_key;
if (enable_int8 && FileExists(trt_calib_file)) {
VLOG(3) << "Calibration table file: " << trt_calib_file << "is found here";
std::ifstream infile(trt_calib_file, std::ios::in);
std::stringstream buffer;
buffer << infile.rdbuf();
std::string calibration_data(buffer.str());
SetAttr(op_desc->Proto(), "calibration_data", calibration_data);
}
SetAttr(op_desc->Proto(), "enable_int8", enable_int8); SetAttr(op_desc->Proto(), "enable_int8", enable_int8);
SetAttr(op_desc->Proto(), "engine_key", engine_key); SetAttr(op_desc->Proto(), "engine_key", engine_key);
} }
......
...@@ -40,6 +40,7 @@ ...@@ -40,6 +40,7 @@
#if PADDLE_WITH_TENSORRT #if PADDLE_WITH_TENSORRT
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h" #include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"
#endif #endif
DECLARE_bool(profile); DECLARE_bool(profile);
...@@ -341,7 +342,6 @@ void AnalysisPredictor::OptimizeInferenceProgram() { ...@@ -341,7 +342,6 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
// Analyze inference_program // Analyze inference_program
if (!config_.model_dir().empty()) { if (!config_.model_dir().empty()) {
argument_.SetModelDir(config_.model_dir()); argument_.SetModelDir(config_.model_dir());
argument_.SetModelPath(config_.model_dir());
} else { } else {
PADDLE_ENFORCE( PADDLE_ENFORCE(
!config_.params_file().empty(), !config_.params_file().empty(),
...@@ -349,7 +349,6 @@ void AnalysisPredictor::OptimizeInferenceProgram() { ...@@ -349,7 +349,6 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
PADDLE_ENFORCE(!config_.prog_file().empty()); PADDLE_ENFORCE(!config_.prog_file().empty());
std::string dir = inference::analysis::GetDirRoot(config_.prog_file()); std::string dir = inference::analysis::GetDirRoot(config_.prog_file());
argument_.SetModelPath(dir);
argument_.SetModelProgramPath(config_.prog_file()); argument_.SetModelProgramPath(config_.prog_file());
argument_.SetModelParamsPath(config_.params_file()); argument_.SetModelParamsPath(config_.params_file());
} }
...@@ -599,7 +598,8 @@ bool AnalysisPredictor::SaveTrtCalibToDisk() { ...@@ -599,7 +598,8 @@ bool AnalysisPredictor::SaveTrtCalibToDisk() {
Singleton<TRTCalibratorEngineManager>::Global().Get(engine_name); Singleton<TRTCalibratorEngineManager>::Global().Get(engine_name);
LOG(INFO) << "Wait for calib threads done."; LOG(INFO) << "Wait for calib threads done.";
calib_engine->calib_->waitAndSetDone(); calib_engine->calib_->waitAndSetDone();
LOG(INFO) << "Finish wait."; LOG(INFO) << "Generating TRT Calibration table data, this may cost a lot "
"of time...";
calib_engine->thr_->join(); calib_engine->thr_->join();
std::string calibration_table_data = std::string calibration_table_data =
calib_engine->calib_->getCalibrationTableAsString(); calib_engine->calib_->getCalibrationTableAsString();
...@@ -609,9 +609,16 @@ bool AnalysisPredictor::SaveTrtCalibToDisk() { ...@@ -609,9 +609,16 @@ bool AnalysisPredictor::SaveTrtCalibToDisk() {
return false; return false;
} }
std::string model_opt_cache_dir =
argument_.Has("model_dir")
? argument_.model_dir()
: inference::analysis::GetDirRoot(argument_.model_program_path());
std::string calibration_table_data_path = std::string calibration_table_data_path =
inference::analysis::GetTrtCalibPath(argument_.model_path(), inference::analysis::GetTrtCalibPath(
engine_name); inference::analysis::GetOrCreateModelOptCacheDir(
model_opt_cache_dir),
engine_name);
std::ofstream ofile(calibration_table_data_path, std::ios::out); std::ofstream ofile(calibration_table_data_path, std::ios::out);
LOG(INFO) << "Write Paddle-TRT INT8 calibration table data to file " LOG(INFO) << "Write Paddle-TRT INT8 calibration table data to file "
......
...@@ -133,7 +133,8 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -133,7 +133,8 @@ class TensorRTEngineOp : public framework::OperatorBase {
// This process will builds a 32-bit trt engine, runs it on the calibration // This process will builds a 32-bit trt engine, runs it on the calibration
// set, and records a histogram for each // set, and records a histogram for each
// tensor of the distribution of activation values. // tensor of the distribution of activation values.
LOG(INFO) << "Running calibration trt int8 ..."; LOG_FIRST_N(INFO, 1) << "The TRT engine: " << engine_key_
<< " is running calibration trt int8... ";
int runtime_batch = 1; int runtime_batch = 1;
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place); auto &dev_ctx = *pool.Get(dev_place);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册