提交 92cf4a4c 编写于 作者: N nhzlx

fix comments

test=develop
上级 36abc964
......@@ -105,7 +105,6 @@ struct Argument {
DECL_ARGUMENT_FIELD(model_program_path, ModelProgramPath, std::string);
DECL_ARGUMENT_FIELD(model_params_path, ModelParamsPath, std::string);
DECL_ARGUMENT_FIELD(model_from_memory, ModelFromMemory, bool);
DECL_ARGUMENT_FIELD(model_path, ModelPath, std::string);
// The overall graph to work on.
DECL_ARGUMENT_UNIQUE_FIELD(main_graph, MainGraph, framework::ir::Graph);
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <sys/stat.h>
#include <cstdio>
#include <fstream>
#include <set>
#include <string>
#include <typeindex>
#include <unordered_map>
......@@ -29,9 +30,14 @@ limitations under the License. */
#include "paddle/fluid/platform/port.h"
#ifdef _WIN32
#include <direct.h>
#include <io.h>
#define GCC_ATTRIBUTE(attr__) ;
#define MKDIR(path) _mkdir(path)
#else
#include <unistd.h>
#define GCC_ATTRIBUTE(attr__) __attribute__((attr__));
#define MKDIR(path) mkdir(path, S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH)
#endif
#define __SHOULD_USE_RESULT__ GCC_ATTRIBUTE(warn_unused_result)
......@@ -163,7 +169,7 @@ static bool PathExists(const std::string &path) {
return false;
}
static std::string GetDirRoot(const std::string path) {
static std::string GetDirRoot(const std::string &path) {
char sep = '/';
#ifdef _WIN32
......@@ -177,11 +183,40 @@ static std::string GetDirRoot(const std::string path) {
return path;
}
static std::string GetOrCreateModelOptCacheDir(const std::string &model_root) {
std::string opt_cache_dir = model_root + "/_opt_cache/";
if (!PathExists(opt_cache_dir)) {
PADDLE_ENFORCE(MKDIR(opt_cache_dir.c_str()) != -1,
"Can not create optimize cache directory: %s, Make sure you "
"have permission to write",
opt_cache_dir);
}
return opt_cache_dir;
}
static std::string GetTrtCalibPath(const std::string &model_root,
const std::string &engine_key) {
return model_root + "/trt_calib_" + engine_key;
}
// If there is no calib table data file in model_opt_cache_dir, return "".
static std::string GetTrtCalibTableData(const std::string &model_opt_cache_dir,
const std::string &engine_key,
bool enable_int8) {
std::string trt_calib_table_path =
GetTrtCalibPath(model_opt_cache_dir, engine_key);
if (enable_int8 && FileExists(trt_calib_table_path)) {
VLOG(3) << "Calibration table file: " << trt_calib_table_path
<< "is found here";
std::ifstream infile(trt_calib_table_path, std::ios::in);
std::stringstream buffer;
buffer << infile.rdbuf();
std::string calibration_data(buffer.str());
return calibration_data;
}
return "";
}
} // namespace analysis
} // namespace inference
} // namespace paddle
......
......@@ -72,14 +72,17 @@ void IRPassManager::CreatePasses(Argument *argument,
new framework::ProgramDesc *(
const_cast<framework::ProgramDesc *>(&argument->main_program())));
bool enable_int8 = false;
if (argument->tensorrt_precision_mode() ==
contrib::AnalysisConfig::Precision::kInt8) {
enable_int8 = true;
}
bool enable_int8 = argument->tensorrt_precision_mode() ==
contrib::AnalysisConfig::Precision::kInt8;
pass->Set("enable_int8", new bool(enable_int8));
pass->Set("model_dir", new std::string(argument->model_path()));
std::string model_opt_cache_dir =
argument->Has("model_dir")
? argument->model_dir()
: GetDirRoot(argument->model_program_path());
pass->Set(
"model_opt_cache_dir",
new std::string(GetOrCreateModelOptCacheDir(model_opt_cache_dir)));
}
// graph_ = pass->Apply(std::move(graph_));
......
......@@ -29,6 +29,7 @@
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/analysis/argument.h"
#include "paddle/fluid/inference/analysis/helper.h"
namespace paddle {
namespace inference {
......
......@@ -68,6 +68,19 @@ std::unique_ptr<framework::ir::Graph> analysis::TensorRtSubgraphPass::ApplyImpl(
return graph;
}
std::string GenerateEngineKey(const std::set<std::string> &engine_inputs,
const std::set<std::string> &engine_outputs) {
std::string engine_hash_key = "";
for (auto name : engine_inputs) {
engine_hash_key += name;
}
for (auto name : engine_outputs) {
engine_hash_key += name;
}
auto engine_key = std::to_string(std::hash<std::string>()(engine_hash_key));
return engine_key;
}
void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
Graph *graph) const {
auto *op_desc = node->Op();
......@@ -97,7 +110,10 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
*op->Proto() = *node->Op()->Proto();
}
// collect inputs
// Then, we will use the input_names_with_id and output_names_with_id to
// generate the eigine key.
// So, We use set instead of unordered_set here to ensure that the engine key
// is unique.
std::set<std::string> input_names;
std::set<std::string> input_names_with_id;
for (auto *x : node->inputs) {
......@@ -217,30 +233,13 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
SetAttr(op_desc->Proto(), "output_name_mapping", output_mapping);
auto enable_int8 = Get<bool>("enable_int8");
SetAttr(op_desc->Proto(), "calibration_data", std::string(""));
auto engine_key =
GenerateEngineKey(input_names_with_id, output_names_with_id);
// we use the subgraph's inputs and outputs to generate the engine key.
std::string engine_hash_key = "";
for (auto name : input_names_with_id) {
engine_hash_key += name;
}
for (auto name : output_names_with_id) {
engine_hash_key += name;
}
auto engine_key = std::to_string(std::hash<std::string>()(engine_hash_key));
std::string calibration_data = GetTrtCalibTableData(
Get<std::string>("model_opt_cache_dir"), engine_key, enable_int8);
SetAttr(op_desc->Proto(), "calibration_data", calibration_data);
auto trt_calib_file =
GetTrtCalibPath(Get<std::string>("model_dir"), engine_key);
VLOG(3) << "engine key: " << engine_key;
if (enable_int8 && FileExists(trt_calib_file)) {
VLOG(3) << "Calibration table file: " << trt_calib_file << "is found here";
std::ifstream infile(trt_calib_file, std::ios::in);
std::stringstream buffer;
buffer << infile.rdbuf();
std::string calibration_data(buffer.str());
SetAttr(op_desc->Proto(), "calibration_data", calibration_data);
}
SetAttr(op_desc->Proto(), "enable_int8", enable_int8);
SetAttr(op_desc->Proto(), "engine_key", engine_key);
}
......
......@@ -40,6 +40,7 @@
#if PADDLE_WITH_TENSORRT
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"
#endif
DECLARE_bool(profile);
......@@ -341,7 +342,6 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
// Analyze inference_program
if (!config_.model_dir().empty()) {
argument_.SetModelDir(config_.model_dir());
argument_.SetModelPath(config_.model_dir());
} else {
PADDLE_ENFORCE(
!config_.params_file().empty(),
......@@ -349,7 +349,6 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
PADDLE_ENFORCE(!config_.prog_file().empty());
std::string dir = inference::analysis::GetDirRoot(config_.prog_file());
argument_.SetModelPath(dir);
argument_.SetModelProgramPath(config_.prog_file());
argument_.SetModelParamsPath(config_.params_file());
}
......@@ -599,7 +598,8 @@ bool AnalysisPredictor::SaveTrtCalibToDisk() {
Singleton<TRTCalibratorEngineManager>::Global().Get(engine_name);
LOG(INFO) << "Wait for calib threads done.";
calib_engine->calib_->waitAndSetDone();
LOG(INFO) << "Finish wait.";
LOG(INFO) << "Generating TRT Calibration table data, this may cost a lot "
"of time...";
calib_engine->thr_->join();
std::string calibration_table_data =
calib_engine->calib_->getCalibrationTableAsString();
......@@ -609,9 +609,16 @@ bool AnalysisPredictor::SaveTrtCalibToDisk() {
return false;
}
std::string model_opt_cache_dir =
argument_.Has("model_dir")
? argument_.model_dir()
: inference::analysis::GetDirRoot(argument_.model_program_path());
std::string calibration_table_data_path =
inference::analysis::GetTrtCalibPath(argument_.model_path(),
engine_name);
inference::analysis::GetTrtCalibPath(
inference::analysis::GetOrCreateModelOptCacheDir(
model_opt_cache_dir),
engine_name);
std::ofstream ofile(calibration_table_data_path, std::ios::out);
LOG(INFO) << "Write Paddle-TRT INT8 calibration table data to file "
......
......@@ -133,7 +133,8 @@ class TensorRTEngineOp : public framework::OperatorBase {
// This process will builds a 32-bit trt engine, runs it on the calibration
// set, and records a histogram for each
// tensor of the distribution of activation values.
LOG(INFO) << "Running calibration trt int8 ...";
LOG_FIRST_N(INFO, 1) << "The TRT engine: " << engine_key_
<< " is running calibration trt int8... ";
int runtime_batch = 1;
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册