Commit 312fe0ec authored by nhzlx

add trt int8 calibration support

fix comments

test=develop
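For context, the user-facing switch is the new Precision enum on contrib::AnalysisConfig (see the paddle_analysis_config.h hunk below). A minimal usage sketch; the model directory, GPU memory pool size, and device id are placeholder values:

  #include "paddle/fluid/inference/api/paddle_inference_api.h"

  int main() {
    paddle::contrib::AnalysisConfig config("./mobilenet");  // placeholder model dir
    config.EnableUseGpu(100 /* memory pool MB */, 0 /* device id */);
    // Ask for INT8; until a calibration table exists next to the model,
    // the TRT subgraph runs in calibration mode (see the op changes below).
    config.EnableTensorRtEngine(1 << 20 /* workspace_size */,
                                1 /* max_batch_size */, 3 /* min_subgraph_size */,
                                paddle::contrib::AnalysisConfig::Precision::kInt8);
    auto predictor = paddle::CreatePaddlePredictor(config);
    return 0;
  }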
Parent c1264e99
@@ -28,6 +28,7 @@
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/inference/api/paddle_analysis_config.h"
 #include "paddle/fluid/platform/variant.h"

 namespace paddle {
@@ -128,7 +129,7 @@ struct Argument {
   DECL_ARGUMENT_FIELD(tensorrt_workspace_size, TensorRtWorkspaceSize, int);
   DECL_ARGUMENT_FIELD(tensorrt_min_subgraph_size, TensorRtMinSubgraphSize, int);
   DECL_ARGUMENT_FIELD(tensorrt_precision_mode, TensorRtPrecisionMode,
-                      std::string);
+                      contrib::AnalysisConfig::Precision);

   // The program transformed by IR analysis phase.
   DECL_ARGUMENT_UNIQUE_FIELD(ir_analyzed_program, IrAnalyzedProgram,
......
@@ -36,6 +36,14 @@ void SetAttr<int>(framework::proto::OpDesc *op, const std::string &name,
   attr->set_i(data);
 }

 template <>
+void SetAttr<bool>(framework::proto::OpDesc *op, const std::string &name,
+                   const bool &data) {
+  auto *attr = op->add_attrs();
+  attr->set_name(name);
+  attr->set_type(paddle::framework::proto::AttrType::BOOLEAN);
+  attr->set_b(data);
+}
+
+template <>
 void SetAttr<int64_t>(framework::proto::OpDesc *op, const std::string &name,
                       const int64_t &data) {
   auto *attr = op->add_attrs();
......
@@ -156,7 +156,7 @@ static bool PathExists(const std::string &path) {
   return false;
 }

-static std::string SplitPath(const std::string path) {
+static std::string GetDirRoot(const std::string path) {
   char sep = '/';

 #ifdef _WIN32
@@ -167,10 +167,14 @@ static std::string SplitPath(const std::string path) {
   if (i != std::string::npos) {
     return (path.substr(0, i));
   }
   return path;
 }

+static std::string GetTrtCalibPath(const std::string &model_root,
+                                   const std::string &engine_key) {
+  return model_root + "/trt_calib_" + engine_key;
+}
+
 }  // namespace analysis
 }  // namespace inference
 }  // namespace paddle
......
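A worked example of the helper above, with hypothetical values: GetTrtCalibPath("/models/resnet", "13579") returns "/models/resnet/trt_calib_13579". SaveTrtCalibToDisk() (below) writes the calibration table to exactly this path, and tensorrt_subgraph_pass reads it back from the same path on the next run.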
@@ -71,13 +71,17 @@ void IRPassManager::CreatePasses(Argument *argument,
           "program",
           new framework::ProgramDesc *(
               const_cast<framework::ProgramDesc *>(&argument->main_program())));
-      pass->Set("precision_mode",
-                new std::string(argument->tensorrt_precision_mode()));
+      bool enable_int8 = false;
+      if (argument->tensorrt_precision_mode() ==
+          contrib::AnalysisConfig::Precision::kInt8)
+        enable_int8 = true;
+
+      pass->Set("enable_int8", new bool(enable_int8));
       pass->Set("model_dir", new std::string(argument->model_path()));
     }

     // graph_ = pass->Apply(std::move(graph_));
     pre_pass = pass_name;
     passes_.emplace_back(std::move(pass));
......
@@ -13,6 +13,7 @@
 // limitations under the License.

 #include <algorithm>
+#include <set>
 #include <string>
 #include <vector>
@@ -93,8 +94,8 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
   }

   // collect inputs
-  std::unordered_set<std::string> input_names;
-  std::unordered_set<std::string> input_names_with_id;
+  std::set<std::string> input_names;
+  std::set<std::string> input_names_with_id;
   for (auto *x : node->inputs) {
     input_names.insert(x->Name());
     input_names_with_id.insert(x->Name() + std::to_string(x->id()));
@@ -102,8 +103,8 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
   op_desc->SetInput(
       "Xs", std::vector<std::string>(input_names.begin(), input_names.end()));

-  std::unordered_set<std::string> output_names;
-  std::unordered_set<std::string> output_names_with_id;
+  std::set<std::string> output_names;
+  std::set<std::string> output_names_with_id;
   for (auto *x : node->outputs) {
     output_names.insert(x->Name());
     output_names_with_id.insert(x->Name() + std::to_string(x->id()));
@@ -203,28 +204,40 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
                  "the block has no var-desc");
   PADDLE_ENFORCE(!output_mapping.empty());
   op_desc->SetBlockAttr("sub_block", new_block);
-  // Set attrs
   SetAttr(op_desc->Proto(), "subgraph",
           block_desc.Proto()->SerializeAsString());
+  // Set attrs
   SetAttr(op_desc->Proto(), "max_batch_size", Get<int>("max_batch_size"));
   SetAttr(op_desc->Proto(), "workspace_size", Get<int>("workspace_size"));
   SetAttr(op_desc->Proto(), "parameters", ExtractParameters(graph->Nodes()));
   SetAttr(op_desc->Proto(), "output_name_mapping", output_mapping);

-  std::string engine_key = std::to_string(
-      std::hash<std::string>()(block_desc.Proto()->SerializeAsString()));
-  std::string precision_mode = Get<std::string>("precision_mode");
+  auto enable_int8 = Get<bool>("enable_int8");
   SetAttr(op_desc->Proto(), "calibration_data", std::string(""));
-  std::string trt_calib_file =
-      Get<std::string>("model_dir") + "/trt_calib_" + engine_key;
-  if (precision_mode == "INT8" && FileExists(trt_calib_file)) {
+
+  // We use the subgraph's inputs and outputs to generate the engine key.
+  std::string engine_hash_key = "";
+  for (auto name : input_names_with_id) {
+    engine_hash_key += name;
+  }
+  for (auto name : output_names_with_id) {
+    engine_hash_key += name;
+  }
+  auto engine_key = std::to_string(std::hash<std::string>()(engine_hash_key));
+
+  auto trt_calib_file =
+      GetTrtCalibPath(Get<std::string>("model_dir"), engine_key);
+  VLOG(3) << "engine key: " << engine_key;
+  if (enable_int8 && FileExists(trt_calib_file)) {
+    VLOG(3) << "Calibration table file: " << trt_calib_file
+            << " is found here";
     std::ifstream infile(trt_calib_file, std::ios::in);
     std::stringstream buffer;
     buffer << infile.rdbuf();
     std::string calibration_data(buffer.str());
     SetAttr(op_desc->Proto(), "calibration_data", calibration_data);
   }
-  SetAttr(op_desc->Proto(), "precision_mode", precision_mode);
+  SetAttr(op_desc->Proto(), "enable_int8", enable_int8);
   SetAttr(op_desc->Proto(), "engine_key", engine_key);
 }
......
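Note on this hunk: the engine key is now hashed from the subgraph's input and output names (with node ids) instead of the serialized block, and the calibration file path is built through the shared GetTrtCalibPath helper, so the key and file location logged by the VLOG line can be matched against what SaveTrtCalibToDisk() writes.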
@@ -122,13 +122,13 @@ void contrib::AnalysisConfig::EnableMKLDNN() {
 #endif
 }

-void contrib::AnalysisConfig::EnableTensorRtEngine(int workspace_size,
-                                                   int max_batch_size,
-                                                   int min_subgraph_size,
-                                                   std::string precision_mode) {
+void contrib::AnalysisConfig::EnableTensorRtEngine(
+    int workspace_size, int max_batch_size, int min_subgraph_size,
+    contrib::AnalysisConfig::Precision precision_mode) {
   use_tensorrt_ = true;
   tensorrt_workspace_size_ = workspace_size;
   tensorrt_max_batchsize_ = max_batch_size;
+  tensorrt_min_subgraph_size_ = min_subgraph_size;
   tensorrt_precision_mode_ = precision_mode;

   Update();
 }
@@ -149,7 +149,7 @@ void contrib::AnalysisConfig::Update() {
         << "TensorRT engine is not available when EnableGpu() not actived.";
   } else {
     // Append after the infer_clean pass.
-    pass_builder()->InsertPass(1, "tensorrt_subgraph_pass");
+    pass_builder()->InsertPass(3, "tensorrt_subgraph_pass");
   }
 }
@@ -180,7 +180,7 @@ std::string contrib::AnalysisConfig::SerializeInfoCache() {
   ss << use_tensorrt_;
   ss << tensorrt_workspace_size_;
   ss << tensorrt_max_batchsize_;
-  ss << tensorrt_precision_mode_;
+  ss << tensorrt_min_subgraph_size_;
   ss << use_mkldnn_;
   ss << enable_ir_optim_;
......
@@ -30,9 +30,9 @@
 #include "paddle/fluid/inference/api/paddle_inference_pass.h"
 #if PADDLE_WITH_TENSORRT
 #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"
 #endif
 #include "paddle/fluid/inference/analysis/helper.h"
-#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"
 #include "paddle/fluid/inference/utils/singleton.h"
 #include "paddle/fluid/memory/memcpy.h"
 #include "paddle/fluid/platform/cpu_helper.h"
@@ -46,8 +46,8 @@ namespace paddle {
 using contrib::AnalysisConfig;
 using inference::Singleton;
 using inference::tensorrt::TRTInt8Calibrator;
-using inference::tensorrt::TRTCalibratorRes;
-using inference::tensorrt::TRTCalibratorResManager;
+using inference::tensorrt::TRTCalibratorEngine;
+using inference::tensorrt::TRTCalibratorEngineManager;

 namespace {
 bool IsPersistable(const framework::VarDesc *var) {
@@ -334,7 +334,7 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
                        !config_.params_file().empty(),
                    "Either model_dir or (param_file, prog_file) should be set.");
     PADDLE_ENFORCE(!config_.prog_file().empty());
-    std::string dir = inference::analysis::SplitPath(config_.prog_file());
+    std::string dir = inference::analysis::GetDirRoot(config_.prog_file());
     argument_.SetModelPath(dir);
     argument_.SetModelProgramPath(config_.prog_file());
@@ -562,6 +562,7 @@ bool AnalysisPredictor::LoadParameters() {
   return true;
 }

+#if PADDLE_WITH_TENSORRT
 bool AnalysisPredictor::SaveTrtCalibToDisk() {
   PADDLE_ENFORCE(config_.tensorrt_engine_enabled(),
                  "This func can be invoked only in trt mode");
@@ -570,44 +571,50 @@ bool AnalysisPredictor::SaveTrtCalibToDisk() {
     if (op_desc->Type() == "tensorrt_engine") {
       std::string engine_name =
           boost::get<std::string>(op_desc->GetAttr("engine_key"));
-      if (!Singleton<TRTCalibratorResManager>::Global().Has(engine_name)) {
+      if (!Singleton<TRTCalibratorEngineManager>::Global().Has(engine_name)) {
        LOG(ERROR) << "You should run the predictor(with trt) on the real data "
                      "to generate calibration info";
         return false;
       }
-      TRTCalibratorRes *calib_res =
-          Singleton<TRTCalibratorResManager>::Global().Get(engine_name);
+      TRTCalibratorEngine *calib_engine =
+          Singleton<TRTCalibratorEngineManager>::Global().Get(engine_name);
       LOG(INFO) << "Wait for calib threads done.";
-      calib_res->calib_->waitAndSetDone();
+      calib_engine->calib_->waitAndSetDone();
       LOG(INFO) << "Finish wait.";
-      calib_res->thr_->join();
-      std::string calibration_data =
-          calib_res->calib_->getCalibrationTableAsString();
-      if (calibration_data.size() == 0) {
+      calib_engine->thr_->join();
+      std::string calibration_table_data =
+          calib_engine->calib_->getCalibrationTableAsString();
+
+      if (calibration_table_data.empty()) {
         LOG(ERROR) << "the calibration table is empty.";
         return false;
       }
-      std::string calibration_data_path =
-          argument_.model_path() + "/trt_calib_" + engine_name;
-      std::ofstream ofile(calibration_data_path, std::ios::out);
-      LOG(INFO) << "Write Paddle-TRT INT8 calibration data to file "
-                << calibration_data_path;
-      ofile << calibration_data;
+
+      std::string calibration_table_data_path =
+          inference::analysis::GetTrtCalibPath(argument_.model_path(),
+                                               engine_name);
+
+      std::ofstream ofile(calibration_table_data_path, std::ios::out);
+      LOG(INFO) << "Write Paddle-TRT INT8 calibration table data to file "
+                << calibration_table_data_path;
+      ofile << calibration_table_data;
       ofile.close();
     }
   }
   // Free all calibrator resources.
-  Singleton<TRTCalibratorResManager>::Global().DeleteALL();
+  Singleton<TRTCalibratorEngineManager>::Global().DeleteALL();
   return true;
 }
+#endif

 AnalysisPredictor::~AnalysisPredictor() {
+#if PADDLE_WITH_TENSORRT
   if (config_.tensorrt_engine_enabled() &&
-      config_.tensorrt_precision_mode_ == "INT8" &&
-      Singleton<TRTCalibratorResManager>::Global().Has()) {
+      config_.tensorrt_precision_mode_ == AnalysisConfig::Precision::kInt8 &&
+      Singleton<TRTCalibratorEngineManager>::Global().Has()) {
     SaveTrtCalibToDisk();
   }
+#endif
   if (FLAGS_profile) {
     platform::DisableProfiler(platform::EventSortingKey::kTotal,
                               "./profile.log");
......
@@ -91,7 +91,20 @@ class AnalysisPredictor : public PaddlePredictor {
   void GetFetchOne(const framework::LoDTensor &fetchs,
                    PaddleTensor *output_data);

+#if PADDLE_WITH_TENSORRT
+  // When we use the Paddle-TRT INT8 engine, we need to generate calibration
+  // table data first. The calibration table contains the value range of each
+  // op's input and output. The whole process can be divided into several
+  // steps:
+  //
+  // 1. Build a 32-bit engine, run it on the calibration set, and record a
+  //    histogram of the distribution of activation values for each tensor.
+  // 2. Build a calibration table from the histograms.
+  //
+  // After step 2, we need to store the calibration table on disk.
   bool SaveTrtCalibToDisk();
+#endif

   ~AnalysisPredictor();
......
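The comment block above covers the offline half of the flow. A sketch of the corresponding calibration run, assuming a predictor created with Precision::kInt8 as in the example near the top of this page (num_calib_batches and PrepareBatch are hypothetical placeholders):

  std::vector<paddle::PaddleTensor> outputs;
  for (int i = 0; i < num_calib_batches; ++i) {
    // Feed representative data; in calibration mode the engine op records
    // activation histograms through TRTInt8Calibrator.
    predictor->Run(PrepareBatch(i), &outputs);
  }
  // Destroying the predictor triggers SaveTrtCalibToDisk(), which writes
  // trt_calib_<engine_key> into the model directory.
  predictor.reset();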
@@ -42,6 +42,10 @@ struct AnalysisConfig {
   explicit AnalysisConfig(const std::string& model_dir);
   explicit AnalysisConfig(const std::string& prog_file,
                           const std::string& params_file);
+
+  enum class Precision {
+    kFloat32 = 0,
+    kInt8,
+  };

   /** Set model with a directory.
    */
@@ -136,7 +140,7 @@ struct AnalysisConfig {
    */
   void EnableTensorRtEngine(int workspace_size = 1 << 20,
                             int max_batch_size = 1, int min_subgraph_size = 3,
-                            std::string precision = "FP32");
+                            Precision precision = Precision::kFloat32);

   /** A boolean state telling whether the TensorRT engine is used.
    */
   bool tensorrt_engine_enabled() const { return use_tensorrt_; }
@@ -232,7 +236,7 @@ struct AnalysisConfig {
   // We set this variable to control the minimum number of nodes in the
   // subgraph, 3 as default value.
   int tensorrt_min_subgraph_size_{3};
-  std::string tensorrt_precision_mode_;
+  Precision tensorrt_precision_mode_;
   bool use_mkldnn_{false};
   std::unordered_set<std::string> mkldnn_enabled_op_types_;
......
@@ -70,7 +70,7 @@ void TensorRTEngine::FreezeNetwork() {
   // build engine.
   infer_builder_->setMaxBatchSize(max_batch_);
   infer_builder_->setMaxWorkspaceSize(max_workspace_);
-  if (precision_mode_ == "INT8") {
+  if (enable_int8_) {
     infer_builder_->setInt8Mode(true);
     PADDLE_ENFORCE(
         calibrator_ != nullptr,
......
@@ -58,14 +58,14 @@ class TensorRTEngine : public EngineBase {
   TensorRTEngine(int max_batch, int max_workspace,
                  cudaStream_t* stream = nullptr, int device = 0,
-                 std::string precision_mode = "FP32",
+                 bool enable_int8 = false,
                  TRTInt8Calibrator* calibrator = nullptr,
                  nvinfer1::ILogger& logger = NaiveLogger::Global())
       : max_batch_(max_batch),
         max_workspace_(max_workspace),
         stream_(stream ? stream : &default_stream_),
         device_(device),
-        precision_mode_(precision_mode),
+        enable_int8_(enable_int8),
         calibrator_(calibrator),
         logger_(logger) {
     freshDeviceId();
@@ -168,7 +168,7 @@ class TensorRTEngine : public EngineBase {
   // The specific GPU id that the TensorRTEngine bounded to.
   int device_;
-  std::string precision_mode_;
+  bool enable_int8_;
   TRTInt8Calibrator* calibrator_;
   // batch size of the current data, will be updated each Execution.
   int batch_size_{-1};
......
@@ -25,11 +25,7 @@ int TRTInt8Calibrator::getBatchSize() const { return batch_size_; }
 TRTInt8Calibrator::TRTInt8Calibrator(
     const std::unordered_map<std::string, size_t>& buffers, int batch_size,
     std::string engine_name, const platform::Place place)
-    : batch_size_(batch_size),
-      calib_running_(true),
-      data_is_set_(false),
-      done_(false),
-      engine_name_(engine_name) {
+    : batch_size_(batch_size), engine_name_(engine_name) {
   int i = 0;
   VLOG(4) << "Init a new calibrator: " << engine_name_;
   for (const auto it : buffers) {
@@ -62,28 +58,32 @@ void TRTInt8Calibrator::waitAndSetDone() {
   }
 }

+// There might be more than one input for the trt subgraph,
+// so we use a map to store the input information.
 bool TRTInt8Calibrator::setBatch(
     const std::unordered_map<std::string, void*>& data) {
   VLOG(3) << "set batch: " << engine_name_;
   std::unique_lock<std::mutex> lk(mut_);
+  // There is a producer and a consumer. The producer sets the batch data and
+  // the consumer gets the batch data. The size of the data pool is one.
+  // So the producer has to wait for the consumer to finish processing before
+  // it can set the data again.
   while ((calib_running_ || data_is_set_) && (!done_)) cond_.wait(lk);
+  // done_ is set to true by waitAndSetDone when all the calibration data
+  // has been processed.
   if (done_) return false;

   // Sets the batch.
-  for (const auto it : data) {
+  for (const auto& it : data) {
     auto dataptr = data_buffers_.find(it.first);
     if (dataptr == data_buffers_.end()) {
       LOG(FATAL) << "FATAL " << engine_name_ << " input name '" << it.first
                  << "' does not match with the buffer names";
     }
     const auto& d = dataptr->second;
-    auto status =
-        cudaMemcpy(d.first, it.second, d.second, cudaMemcpyDeviceToDevice);
-    if (status != cudaSuccess) {
-      LOG(FATAL) << "cudaMemcpy " << engine_name_ << " for '" << it.first
-                 << "' failed with " << status;
-    }
+    PADDLE_ENFORCE(
+        cudaMemcpy(d.first, it.second, d.second, cudaMemcpyDeviceToDevice),
+        "Fail to cudaMemcpy %s for %s", engine_name_, it.first);
   }

   data_is_set_ = true;
@@ -95,9 +95,12 @@ bool TRTInt8Calibrator::getBatch(void** bindings, const char** names,
                                 int num_bindings) {
   VLOG(4) << "get batch: " << engine_name_;
   std::unique_lock<std::mutex> lk(mut_);
+  // The consumer has just finished processing a batch of data.
+  // The producer can set the data again.
   calib_running_ = false;
   cond_.notify_all();

+  // As long as there is data in the pool, the consumer can get it.
   while (!data_is_set_ && !done_) cond_.wait(lk);
   if (done_) return false;
@@ -123,7 +126,7 @@ void TRTInt8Calibrator::setDone() {
   cond_.notify_all();
 }

-const void* TRTInt8Calibrator::readCalibrationCache(std::size_t& length) {
+const void* TRTInt8Calibrator::readCalibrationCache(size_t& length) {
   if (calibration_table_.empty()) return nullptr;
   length = calibration_table_.size();
   return calibration_table_.data();
......
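The comments added above describe a one-slot producer/consumer handshake between setBatch (the feeding thread) and getBatch (TensorRT's calibration thread). A self-contained sketch of the same pattern, stripped of the TensorRT specifics:

  #include <condition_variable>
  #include <iostream>
  #include <mutex>
  #include <thread>

  // One-slot pool: the producer waits until the consumer has drained the
  // previous item; finish() drains the pool and then unblocks both sides,
  // mirroring waitAndSetDone().
  struct OneSlot {
    std::mutex mut;
    std::condition_variable cond;
    bool data_is_set{false};
    bool calib_running{true};
    bool done{false};
    int slot{0};

    bool set(int v) {  // producer, cf. setBatch()
      std::unique_lock<std::mutex> lk(mut);
      cond.wait(lk, [&] { return (!calib_running && !data_is_set) || done; });
      if (done) return false;
      slot = v;
      data_is_set = true;
      cond.notify_all();
      return true;
    }

    bool get(int* v) {  // consumer, cf. getBatch()
      std::unique_lock<std::mutex> lk(mut);
      calib_running = false;  // previous item fully processed
      cond.notify_all();
      cond.wait(lk, [&] { return data_is_set || done; });
      if (done) return false;
      *v = slot;
      data_is_set = false;
      calib_running = true;  // busy with this item until the next get()
      return true;
    }

    void finish() {  // cf. waitAndSetDone(): drain the pool, then mark done
      std::unique_lock<std::mutex> lk(mut);
      cond.wait(lk, [&] { return !data_is_set && !calib_running; });
      done = true;
      cond.notify_all();
    }
  };

  int main() {
    OneSlot s;
    std::thread consumer([&] {
      int v;
      while (s.get(&v)) std::cout << "consumed " << v << "\n";
    });
    for (int i = 0; i < 3; ++i) s.set(i);
    s.finish();
    consumer.join();
  }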
@@ -21,8 +21,8 @@
 #include <utility>
 #include <vector>

-#include "NvInfer.h"
-#include "cuda_runtime_api.h"
+#include <NvInfer.h>
+#include <cuda_runtime_api.h>
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/inference/tensorrt/engine.h"
 #include "paddle/fluid/platform/place.h"
@@ -60,9 +60,9 @@ struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator {
  private:
   const int batch_size_;

-  bool calib_running_;
-  bool data_is_set_;
-  bool done_;
+  bool calib_running_{true};
+  bool data_is_set_{false};
+  bool done_{false};

   std::mutex mut_;
   std::condition_variable cond_;
@@ -74,9 +74,9 @@ struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator {
   std::string calibration_table_;
 };

-class TRTCalibratorRes {
+class TRTCalibratorEngine {
  public:
-  TRTCalibratorRes() {}
+  TRTCalibratorEngine() {}
   std::unique_ptr<TRTInt8Calibrator> calib_;
   std::unique_ptr<std::thread> thr_;
   std::unique_ptr<TensorRTEngine> engine_;
@@ -84,7 +84,7 @@ class TRTCalibratorRes {
 /*
  * Manager to control the TensorRT Int8 calibration creation and deletion.
  */
-class TRTCalibratorResManager {
+class TRTCalibratorEngineManager {
  public:
   bool Has() const { return res_.size() > 0; }
   bool Has(const std::string& name) const {
@@ -93,22 +93,22 @@ class TRTCalibratorResManager {
   }

   // Get Int8Calibrator via name
-  TRTCalibratorRes* Get(const std::string& name) const {
+  TRTCalibratorEngine* Get(const std::string& name) const {
     return res_.at(name).get();
   }

   // Look up or create a calibrator.
-  TRTCalibratorRes* LookupOrCreate(const std::string& engine_name) {
+  TRTCalibratorEngine* LookupOrCreate(const std::string& engine_name) {
     if (res_.count(engine_name) == 0) {
-      auto* p = new TRTCalibratorRes();
+      auto* p = new TRTCalibratorEngine;
       res_[engine_name].reset(p);
     }
     return res_.at(engine_name).get();
   }

   // Create an Int8Calibrator
-  TRTCalibratorRes* Create(const std::string& engine_name) {
-    auto* p = new TRTCalibratorRes();
+  TRTCalibratorEngine* Create(const std::string& engine_name) {
+    auto* p = new TRTCalibratorEngine;
     res_[engine_name].reset(p);
     return p;
   }
@@ -120,7 +120,7 @@ class TRTCalibratorResManager {
   }

  private:
-  std::unordered_map<std::string, std::unique_ptr<TRTCalibratorRes>> res_;
+  std::unordered_map<std::string, std::unique_ptr<TRTCalibratorEngine>> res_;
 };

 }  // namespace tensorrt
......
@@ -36,8 +36,7 @@ class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<int>("max_batch_size", "the maximum batch size.");
     AddAttr<int>("workspace_size", "the workspace size.");
     AddAttr<framework::BlockDesc *>("sub_block", "the trt block");
-    AddAttr<std::string>("precision_mode",
-                         "the precision mode: 'FP32', 'INT8' ");
+    AddAttr<bool>("enable_int8", "whether to switch to int8 mode");
     AddComment("TensorRT engine operator.");
   }
 };
......
@@ -65,8 +65,8 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector<int64_t> &shape) {
 using inference::Singleton;
 using inference::tensorrt::TensorRTEngine;
 using inference::tensorrt::TRTInt8Calibrator;
-using inference::tensorrt::TRTCalibratorRes;
-using inference::tensorrt::TRTCalibratorResManager;
+using inference::tensorrt::TRTCalibratorEngine;
+using inference::tensorrt::TRTCalibratorEngineManager;

 class TensorRTEngineOp : public framework::OperatorBase {
  private:
@@ -76,7 +76,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
   int max_batch_size_;
   int workspace_size_;
   std::unique_ptr<TRTInt8Calibrator> calibrator_;
-  std::string precision_mode_;
+  bool enable_int8_;
   std::string calibration_data_;
   std::string engine_key_;
   bool calibration_mode_;
@@ -90,7 +90,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
     input_names_ = Inputs("Xs");
     max_batch_size_ = Attr<int>("max_batch_size");
     workspace_size_ = Attr<int>("workspace_size");
-    precision_mode_ = Attr<std::string>("precision_mode");
+    enable_int8_ = Attr<bool>("enable_int8");
     calibration_data_ = Attr<std::string>("calibration_data");
     engine_key_ = Attr<std::string>("engine_key");
@@ -98,17 +98,19 @@ class TensorRTEngineOp : public framework::OperatorBase {
     for (const auto &param : params) {
       param_names_.insert(param);
     }
-    calibration_mode_ =
-        (precision_mode_ == "INT8" && calibration_data_.size() == 0);
+    // calibration_mode_ == true means we need to generate
+    // the calibration table data.
+    calibration_mode_ = (enable_int8_ && calibration_data_.size() == 0);
+    VLOG(4) << "calibration_mode: " << calibration_mode_;

-    if (precision_mode_ == "INT8" && calibration_data_.size()) {
+    if (enable_int8_ && calibration_data_.size()) {
       calibrator_.reset(new TRTInt8Calibrator(calibration_data_));
     }
   }

  protected:
-  void RunNative(const framework::Scope &scope,
-                 const platform::Place &dev_place) const {
+  void RunNativeImpl(const framework::Scope &scope,
+                     const platform::Place &dev_place) const {
     framework::Executor executor(dev_place);
     auto *block = Attr<framework::BlockDesc *>("sub_block");
     auto *program = block->Program();
@@ -128,12 +130,14 @@ class TensorRTEngineOp : public framework::OperatorBase {
   void RunCalibration(const framework::Scope &scope,
                       const platform::Place &dev_place) const {
-    // Create calibrator here.
+    // This process builds a 32-bit trt engine, runs it on the calibration
+    // set, and records a histogram of the distribution of activation values
+    // for each tensor.
     LOG(INFO) << "Running calibration trt int8 ...";
     int runtime_batch = 1;
-    if (!Singleton<TRTCalibratorResManager>::Global().Has(engine_key_)) {
-      TRTCalibratorRes *calib_res =
-          Singleton<TRTCalibratorResManager>::Global().Create(engine_key_);
+    if (!Singleton<TRTCalibratorEngineManager>::Global().Has(engine_key_)) {
+      TRTCalibratorEngine *calib_res =
+          Singleton<TRTCalibratorEngineManager>::Global().Create(engine_key_);
       std::unordered_map<std::string, size_t> calib_buffers;
       for (auto &x : input_names_) {
         if (param_names_.count(x)) continue;
@@ -148,7 +152,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
       calib_res->thr_.reset(new std::thread([&]() {
         calib_res->engine_.reset(new TensorRTEngine(
             max_batch_size_, workspace_size_, nullptr,
-            boost::get<platform::CUDAPlace>(dev_place).device, precision_mode_,
+            boost::get<platform::CUDAPlace>(dev_place).device, enable_int8_,
             calib_res->calib_.get()));
         VLOG(3) << "start the calib trt engine thread";
         Prepare(scope, dev_place, calib_res->engine_.get());
@@ -156,7 +160,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
     }

     TRTInt8Calibrator *temp_calibrator =
-        Singleton<TRTCalibratorResManager>::Global()
+        Singleton<TRTCalibratorEngineManager>::Global()
             .Get(engine_key_)
             ->calib_.get();
     std::unordered_map<std::string, void *> calib_data;
@@ -168,7 +172,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
       calib_data.emplace(x, t.data<void>());
     }
     temp_calibrator->setBatch(calib_data);
-    RunNative(scope, dev_place);
+    RunNativeImpl(scope, dev_place);
   }

   void RunTrt(const framework::Scope &scope,
@@ -178,7 +182,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
       trt_engine_.reset(
           new TensorRTEngine(max_batch_size_, workspace_size_, nullptr,
                              boost::get<platform::CUDAPlace>(dev_place).device,
-                             precision_mode_, calibrator_.get()));
+                             enable_int8_, calibrator_.get()));
       Prepare(scope, dev_place, trt_engine_.get());
     }
......
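Taken together, the op has three behaviors: calibration mode (enable_int8 with no calibration_data yet), INT8 inference seeded by the stored table, and plain FP32. The surrounding RunImpl, which is elided from this diff, presumably dispatches along these lines (an illustrative sketch, not the actual method body):

  if (calibration_mode_) {
    // Builds the FP32 engine on a background thread, feeds it the current
    // batch for histogram collection, and computes this step's real output
    // through RunNativeImpl.
    RunCalibration(scope, dev_place);
    return;
  }
  RunTrt(scope, dev_place);  // FP32, or INT8 once calibration_data is set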