未验证 提交 97b76c94 编写于 作者: Z Zhaolong Xing 提交者: GitHub

Merge pull request #15242 from NHZlX/trt_int8_ultimate_version

add trt int8 support
......@@ -14,6 +14,7 @@
#include "paddle/fluid/framework/ir/graph_traits.h"
#include <set>
#include <vector>
namespace paddle {
......@@ -79,7 +80,7 @@ NodesTSIterator::NodesTSIterator(const std::vector<Node *> &source) {
}
std::unordered_set<Node *> visited;
std::unordered_set<Node *> to_visit{source.begin(), source.end()};
std::set<Node *> to_visit{source.begin(), source.end()};
std::vector<Node *> inlink_visited;
while (!to_visit.empty()) {
......
......@@ -28,6 +28,7 @@
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/api/paddle_analysis_config.h"
#include "paddle/fluid/platform/variant.h"
namespace paddle {
......@@ -130,6 +131,8 @@ struct Argument {
DECL_ARGUMENT_FIELD(tensorrt_max_batch_size, TensorRtMaxBatchSize, int);
DECL_ARGUMENT_FIELD(tensorrt_workspace_size, TensorRtWorkspaceSize, int);
DECL_ARGUMENT_FIELD(tensorrt_min_subgraph_size, TensorRtMinSubgraphSize, int);
DECL_ARGUMENT_FIELD(tensorrt_precision_mode, TensorRtPrecisionMode,
contrib::AnalysisConfig::Precision);
// Memory optimized related.
DECL_ARGUMENT_FIELD(enable_memory_optim, EnableMemoryOptim, bool);
......
......@@ -36,6 +36,14 @@ void SetAttr<int>(framework::proto::OpDesc *op, const std::string &name,
attr->set_i(data);
}
template <>
void SetAttr<bool>(framework::proto::OpDesc *op, const std::string &name,
const bool &data) {
auto *attr = op->add_attrs();
attr->set_name(name);
attr->set_type(paddle::framework::proto::AttrType::BOOLEAN);
attr->set_b(data);
}
template <>
void SetAttr<int64_t>(framework::proto::OpDesc *op, const std::string &name,
const int64_t &data) {
auto *attr = op->add_attrs();
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <sys/stat.h>
#include <cstdio>
#include <fstream>
#include <set>
#include <string>
#include <typeindex>
#include <unordered_map>
......@@ -29,9 +30,14 @@ limitations under the License. */
#include "paddle/fluid/platform/port.h"
#ifdef _WIN32
#include <direct.h>
#include <io.h>
#define GCC_ATTRIBUTE(attr__) ;
#define MKDIR(path) _mkdir(path)
#else
#include <unistd.h>
#define GCC_ATTRIBUTE(attr__) __attribute__((attr__));
#define MKDIR(path) mkdir(path, S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH)
#endif
#define __SHOULD_USE_RESULT__ GCC_ATTRIBUTE(warn_unused_result)
......@@ -163,6 +169,54 @@ static bool PathExists(const std::string &path) {
return false;
}
static std::string GetDirRoot(const std::string &path) {
char sep = '/';
#ifdef _WIN32
sep = '\\';
#endif
size_t i = path.rfind(sep, path.length());
if (i != std::string::npos) {
return (path.substr(0, i));
}
return path;
}
static std::string GetOrCreateModelOptCacheDir(const std::string &model_root) {
std::string opt_cache_dir = model_root + "/_opt_cache/";
if (!PathExists(opt_cache_dir)) {
PADDLE_ENFORCE(MKDIR(opt_cache_dir.c_str()) != -1,
"Can not create optimize cache directory: %s, Make sure you "
"have permission to write",
opt_cache_dir);
}
return opt_cache_dir;
}
static std::string GetTrtCalibPath(const std::string &model_root,
const std::string &engine_key) {
return model_root + "/trt_calib_" + engine_key;
}
// If there is no calib table data file in model_opt_cache_dir, return "".
static std::string GetTrtCalibTableData(const std::string &model_opt_cache_dir,
const std::string &engine_key,
bool enable_int8) {
std::string trt_calib_table_path =
GetTrtCalibPath(model_opt_cache_dir, engine_key);
if (enable_int8 && FileExists(trt_calib_table_path)) {
VLOG(3) << "Calibration table file: " << trt_calib_table_path
<< "is found here";
std::ifstream infile(trt_calib_table_path, std::ios::in);
std::stringstream buffer;
buffer << infile.rdbuf();
std::string calibration_data(buffer.str());
return calibration_data;
}
return "";
}
} // namespace analysis
} // namespace inference
} // namespace paddle
......
......@@ -67,6 +67,20 @@ void IRPassManager::CreatePasses(Argument *argument,
pass->Set("max_batch_size", new int(argument->tensorrt_max_batch_size()));
pass->Set("min_subgraph_size",
new int(argument->tensorrt_min_subgraph_size()));
pass->Set("program",
new framework::ProgramDesc *(&argument->main_program()));
bool enable_int8 = argument->tensorrt_precision_mode() ==
contrib::AnalysisConfig::Precision::kInt8;
pass->Set("enable_int8", new bool(enable_int8));
std::string model_opt_cache_dir =
argument->Has("model_dir")
? argument->model_dir()
: GetDirRoot(argument->model_program_path());
pass->Set(
"model_opt_cache_dir",
new std::string(GetOrCreateModelOptCacheDir(model_opt_cache_dir)));
}
// graph_ = pass->Apply(std::move(graph_));
......@@ -91,11 +105,14 @@ std::unique_ptr<Graph> IRPassManager::Apply(std::unique_ptr<Graph> graph) {
}
framework::proto::ProgramDesc IRPassManager::AcquireProgram(
std::unique_ptr<Graph> *graph, const ProgramDesc &program) const {
std::unique_ptr<Graph> *graph, ProgramDesc *program) const {
auto pass =
framework::ir::PassRegistry::Instance().Get("graph_to_program_pass");
ProgramDesc desc(program);
// Direct using ProgramDesc desc(argument->main_program()) may cause
// incomplete copies of information.
ProgramDesc desc;
desc.CopyFrom(*program->Proto());
pass->SetNotOwned("program", &desc);
auto *the_graph = graph->release();
*graph = pass->Apply(std::unique_ptr<Graph>(the_graph));
......
......@@ -29,6 +29,7 @@
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/analysis/argument.h"
#include "paddle/fluid/inference/analysis/helper.h"
namespace paddle {
namespace inference {
......@@ -42,8 +43,8 @@ class IRPassManager final {
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph);
framework::proto::ProgramDesc AcquireProgram(
std::unique_ptr<Graph> *graph, const ProgramDesc &program) const;
framework::proto::ProgramDesc AcquireProgram(std::unique_ptr<Graph> *graph,
ProgramDesc *program) const;
framework::ir::Graph &graph() const { return *graph_; }
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include <algorithm>
#include <set>
#include <string>
#include <vector>
......@@ -67,12 +68,33 @@ std::unique_ptr<framework::ir::Graph> analysis::TensorRtSubgraphPass::ApplyImpl(
return graph;
}
std::string GenerateEngineKey(const std::set<std::string> &engine_inputs,
const std::set<std::string> &engine_outputs) {
std::string engine_hash_key = "";
for (auto name : engine_inputs) {
engine_hash_key += name;
}
for (auto name : engine_outputs) {
engine_hash_key += name;
}
auto engine_key = std::to_string(std::hash<std::string>()(engine_hash_key));
return engine_key;
}
void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
Graph *graph) const {
auto *op_desc = node->Op();
auto &subgraph = *Agent(node).subgraph();
PADDLE_ENFORCE(!subgraph.empty());
framework::ProgramDesc *program_desc =
Get<framework::ProgramDesc *>("program");
// Add new block for TensorRTEngineOP
const framework::BlockDesc &main_block =
program_desc->Block(framework::kRootBlockIndex);
// const framework::BlockDesc& main_block = program_desc->Block(0);
framework::BlockDesc *new_block = program_desc->AppendBlock(main_block);
// An fake block desc.
framework::proto::BlockDesc block_proto;
framework::BlockDesc block_desc(nullptr, &block_proto);
......@@ -82,13 +104,18 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
subgraph.size());
for (auto *node : subgraph) {
auto *new_block_op = new_block->AppendOp();
auto *op = block_desc.AppendOp();
*new_block_op->Proto() = *node->Op()->Proto();
*op->Proto() = *node->Op()->Proto();
}
// collect inputs
std::unordered_set<std::string> input_names;
std::unordered_set<std::string> input_names_with_id;
// Then, we will use the input_names_with_id and output_names_with_id to
// generate the eigine key.
// So, We use set instead of unordered_set here to ensure that the engine key
// is unique.
std::set<std::string> input_names;
std::set<std::string> input_names_with_id;
for (auto *x : node->inputs) {
input_names.insert(x->Name());
input_names_with_id.insert(x->Name() + std::to_string(x->id()));
......@@ -96,8 +123,8 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
op_desc->SetInput(
"Xs", std::vector<std::string>(input_names.begin(), input_names.end()));
std::unordered_set<std::string> output_names;
std::unordered_set<std::string> output_names_with_id;
std::set<std::string> output_names;
std::set<std::string> output_names_with_id;
for (auto *x : node->outputs) {
output_names.insert(x->Name());
output_names_with_id.insert(x->Name() + std::to_string(x->id()));
......@@ -182,7 +209,6 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
// to Tensor.
std::vector<std::string> output_mapping;
for (auto name : output_names) {
// LOG(INFO) << name << " " << output_name_map.size();
PADDLE_ENFORCE(output_name_map.count(name) != 0);
output_mapping.push_back(output_name_map[name]);
}
......@@ -193,16 +219,29 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
*vars->Add() = *node->Var()->Proto();
}
}
PADDLE_ENFORCE(!block_desc.Proto()->vars().empty(),
"the block has no var-desc");
PADDLE_ENFORCE(!output_mapping.empty());
// Set attrs
op_desc->SetBlockAttr("sub_block", new_block);
SetAttr(op_desc->Proto(), "subgraph",
block_desc.Proto()->SerializeAsString());
// Set attrs
SetAttr(op_desc->Proto(), "max_batch_size", Get<int>("max_batch_size"));
SetAttr(op_desc->Proto(), "workspace_size", Get<int>("workspace_size"));
SetAttr(op_desc->Proto(), "parameters", ExtractParameters(graph->Nodes()));
SetAttr(op_desc->Proto(), "output_name_mapping", output_mapping);
auto enable_int8 = Get<bool>("enable_int8");
auto engine_key =
GenerateEngineKey(input_names_with_id, output_names_with_id);
std::string calibration_data = GetTrtCalibTableData(
Get<std::string>("model_opt_cache_dir"), engine_key, enable_int8);
SetAttr(op_desc->Proto(), "calibration_data", calibration_data);
SetAttr(op_desc->Proto(), "enable_int8", enable_int8);
SetAttr(op_desc->Proto(), "engine_key", engine_key);
}
std::vector<std::string> ExtractParameters(
......
......@@ -31,7 +31,11 @@ void IrGraphToProgramPass::RunImpl(Argument *argument) {
}
std::unique_ptr<Graph> graph(argument->main_graph_ptr());
framework::ProgramDesc desc(argument->main_program());
// Direct using ProgramDesc desc(argument->main_program()) may cause
// incomplete copies of information.
framework::ProgramDesc desc;
desc.CopyFrom(*argument->main_program().Proto());
pass->SetNotOwned("program", &desc);
auto thegraph = pass->Apply(std::move(graph));
thegraph.release(); // the argument still own the graph.
......
......@@ -102,6 +102,7 @@ contrib::AnalysisConfig::AnalysisConfig(const contrib::AnalysisConfig &other) {
CP_MEMBER(tensorrt_workspace_size_);
CP_MEMBER(tensorrt_max_batchsize_);
CP_MEMBER(tensorrt_min_subgraph_size_);
CP_MEMBER(tensorrt_precision_mode_);
// MKLDNN releated.
CP_MEMBER(use_mkldnn_);
CP_MEMBER(mkldnn_enabled_op_types_);
......@@ -141,9 +142,9 @@ void contrib::AnalysisConfig::EnableMKLDNN() {
Update();
}
void contrib::AnalysisConfig::EnableTensorRtEngine(int workspace_size,
int max_batch_size,
int min_subgraph_size) {
void contrib::AnalysisConfig::EnableTensorRtEngine(
int workspace_size, int max_batch_size, int min_subgraph_size,
contrib::AnalysisConfig::Precision precision_mode) {
#ifdef PADDLE_WITH_CUDA
if (!use_gpu()) {
LOG(ERROR) << "To use TensorRT engine, please call EnableGpu() first";
......@@ -154,6 +155,7 @@ void contrib::AnalysisConfig::EnableTensorRtEngine(int workspace_size,
tensorrt_workspace_size_ = workspace_size;
tensorrt_max_batchsize_ = max_batch_size;
tensorrt_min_subgraph_size_ = min_subgraph_size;
tensorrt_precision_mode_ = precision_mode;
Update();
#else
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/inference/api/analysis_predictor.h"
#include <glog/logging.h>
#include <algorithm>
#include <fstream>
#include <memory>
#include <string>
#include <vector>
......@@ -25,6 +26,7 @@
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/passes/memory_optimize_pass.h"
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
......@@ -37,6 +39,8 @@
#if PADDLE_WITH_TENSORRT
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"
#endif
DECLARE_bool(profile);
......@@ -44,6 +48,12 @@ DECLARE_bool(profile);
namespace paddle {
using contrib::AnalysisConfig;
using inference::Singleton;
#if PADDLE_WITH_TENSORRT
using inference::tensorrt::TRTInt8Calibrator;
using inference::tensorrt::TRTCalibratorEngine;
using inference::tensorrt::TRTCalibratorEngineManager;
#endif
namespace {
bool IsPersistable(const framework::VarDesc *var) {
......@@ -339,6 +349,8 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
!config_.params_file().empty(),
"Either model_dir or (param_file, prog_file) should be set.");
PADDLE_ENFORCE(!config_.prog_file().empty());
std::string dir = inference::analysis::GetDirRoot(config_.prog_file());
argument_.SetModelProgramPath(config_.prog_file());
argument_.SetModelParamsPath(config_.params_file());
}
......@@ -349,6 +361,7 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
argument_.SetTensorRtWorkspaceSize(config_.tensorrt_workspace_size_);
argument_.SetTensorRtMaxBatchSize(config_.tensorrt_max_batchsize_);
argument_.SetTensorRtMinSubgraphSize(config_.tensorrt_min_subgraph_size_);
argument_.SetTensorRtPrecisionMode(config_.tensorrt_precision_mode_);
}
if (config_.use_mkldnn_) {
......@@ -569,7 +582,67 @@ bool AnalysisPredictor::LoadParameters() {
return true;
}
#if PADDLE_WITH_TENSORRT
bool AnalysisPredictor::SaveTrtCalibToDisk() {
PADDLE_ENFORCE(config_.tensorrt_engine_enabled(),
"This func can be invoked only in trt mode");
auto &block = inference_program_->Block(0);
for (auto &op_desc : block.AllOps()) {
if (op_desc->Type() == "tensorrt_engine") {
std::string engine_name =
boost::get<std::string>(op_desc->GetAttr("engine_key"));
if (!Singleton<TRTCalibratorEngineManager>::Global().Has(engine_name)) {
LOG(ERROR) << "You should run the predictor(with trt) on the real data "
"to generate calibration info";
return false;
}
TRTCalibratorEngine *calib_engine =
Singleton<TRTCalibratorEngineManager>::Global().Get(engine_name);
LOG(INFO) << "Wait for calib threads done.";
calib_engine->calib_->waitAndSetDone();
LOG(INFO) << "Generating TRT Calibration table data, this may cost a lot "
"of time...";
calib_engine->thr_->join();
std::string calibration_table_data =
calib_engine->calib_->getCalibrationTableAsString();
if (calibration_table_data.empty()) {
LOG(ERROR) << "the calibration table is empty.";
return false;
}
std::string model_opt_cache_dir =
argument_.Has("model_dir")
? argument_.model_dir()
: inference::analysis::GetDirRoot(argument_.model_program_path());
std::string calibration_table_data_path =
inference::analysis::GetTrtCalibPath(
inference::analysis::GetOrCreateModelOptCacheDir(
model_opt_cache_dir),
engine_name);
std::ofstream ofile(calibration_table_data_path, std::ios::out);
LOG(INFO) << "Write Paddle-TRT INT8 calibration table data to file "
<< calibration_table_data_path;
ofile << calibration_table_data;
ofile.close();
}
}
// Free all calibrator resources.
Singleton<TRTCalibratorEngineManager>::Global().DeleteALL();
return true;
}
#endif
AnalysisPredictor::~AnalysisPredictor() {
#if PADDLE_WITH_TENSORRT
if (config_.tensorrt_engine_enabled() &&
config_.tensorrt_precision_mode_ == AnalysisConfig::Precision::kInt8 &&
Singleton<TRTCalibratorEngineManager>::Global().Has()) {
SaveTrtCalibToDisk();
}
#endif
if (FLAGS_profile) {
platform::DisableProfiler(platform::EventSortingKey::kTotal,
"./profile.log");
......
......@@ -97,6 +97,21 @@ class AnalysisPredictor : public PaddlePredictor {
void GetFetchOne(const framework::LoDTensor &fetchs,
PaddleTensor *output_data);
#if PADDLE_WITH_TENSORRT
// When we use Paddle-TRT INT8 engine, we need to generate calibration table
// data first,
// the calibration table contains the range for each op's input and output,
// this whole process can be divided into several steps:
//
// 1. Builds a 32-bit engine, runs it on the calibration set, and records a
// histogram for each
// tensor of the distribution of activation values.
// 2. Builds a calibration table from the histograms.
//
// After step 2, we need to store the calibration table on disk
bool SaveTrtCalibToDisk();
#endif
// Some more detailed tests, they are made the friends of the predictor, so that
// the all the details can be tested.
#if PADDLE_WITH_TESTING
......
......@@ -42,6 +42,10 @@ struct AnalysisConfig {
explicit AnalysisConfig(const std::string& model_dir);
explicit AnalysisConfig(const std::string& prog_file,
const std::string& params_file);
enum class Precision {
kFloat32 = 0,
kInt8,
};
/** Set model with a directory.
*/
......@@ -135,7 +139,8 @@ struct AnalysisConfig {
* subgraph is less than this, it will not transfer to TensorRT engine.
*/
void EnableTensorRtEngine(int workspace_size = 1 << 20,
int max_batch_size = 1, int min_subgraph_size = 3);
int max_batch_size = 1, int min_subgraph_size = 3,
Precision precision = Precision::kFloat32);
/** A boolean state telling whether the TensorRT engine is used.
*/
bool tensorrt_engine_enabled() const { return use_tensorrt_; }
......@@ -229,6 +234,7 @@ struct AnalysisConfig {
// We set this variable to control the minimum number of nodes in the
// subgraph, 3 as default value.
int tensorrt_min_subgraph_size_{3};
Precision tensorrt_precision_mode_;
// memory reuse related.
bool enable_memory_optim_{false};
......
nv_library(tensorrt_engine SRCS engine.cc DEPS ${GLOB_OPERATOR_DEPS} framework_proto device_context)
nv_library(tensorrt_engine SRCS engine.cc trt_int8_calibrator.cc DEPS ${GLOB_OPERATOR_DEPS} framework_proto device_context)
nv_library(tensorrt_op_teller SRCS op_teller.cc DEPS framework_proto)
nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
nv_test(test_tensorrt_engine SRCS test_engine.cc DEPS dynload_cuda tensorrt_engine)
......
......@@ -69,6 +69,13 @@ void TensorRTEngine::FreezeNetwork() {
// build engine.
infer_builder_->setMaxBatchSize(max_batch_);
infer_builder_->setMaxWorkspaceSize(max_workspace_);
if (enable_int8_) {
infer_builder_->setInt8Mode(true);
PADDLE_ENFORCE(
calibrator_ != nullptr,
"The precision mode is 'INT8', the calibrator should not be nullptr");
infer_builder_->setInt8Calibrator(calibrator_);
}
infer_engine_.reset(infer_builder_->buildCudaEngine(*infer_network_));
PADDLE_ENFORCE(infer_engine_ != nullptr, "build cuda engine failed!");
......
......@@ -23,12 +23,14 @@ limitations under the License. */
#include "paddle/fluid/inference/engine.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"
#include "paddle/fluid/inference/utils/singleton.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class TRTInt8Calibrator;
/*
* TensorRT Engine.
*
......@@ -55,13 +57,16 @@ class TensorRTEngine : public EngineBase {
};
TensorRTEngine(int max_batch, int max_workspace, cudaStream_t stream,
int device = 0,
int device = 0, bool enable_int8 = false,
TRTInt8Calibrator* calibrator = nullptr,
nvinfer1::ILogger& logger = NaiveLogger::Global())
: max_batch_(max_batch),
max_workspace_(max_workspace),
stream_(stream),
logger_(logger),
device_(device) {}
device_(device),
enable_int8_(enable_int8),
calibrator_(calibrator),
logger_(logger) {}
virtual ~TensorRTEngine();
......@@ -139,8 +144,8 @@ class TensorRTEngine : public EngineBase {
// In the normal case, the paddle-trt exists bug when runing the googlenet.
// When there are more than two convolutions of 1 * 1 with the same input, the
// paddle-tensorrt will do the merging optimization, which fuse those conv
// into
// one conv, and then trigger bug. So, We should use strategy to avoid this
// into one conv, and then trigger bug. So, We should use strategy to avoid
// this
// optimization for the time being. This bug will be fixed in the future.
std::unordered_map<std::string /*name*/, int /*ITensor_quote_num*/>
itensor_quote_num;
......@@ -153,9 +158,14 @@ class TensorRTEngine : public EngineBase {
// the max memory size the engine uses
int max_workspace_;
cudaStream_t stream_;
// The specific GPU id that the TensorRTEngine bounded to.
int device_;
bool enable_int8_;
TRTInt8Calibrator* calibrator_;
// batch size of the current data, will be updated each Executation.
int batch_size_{-1};
cudaStream_t stream_;
nvinfer1::ILogger& logger_;
......@@ -165,8 +175,6 @@ class TensorRTEngine : public EngineBase {
std::unordered_map<std::string /*name*/, nvinfer1::ITensor* /*ITensor*/>
itensor_map_;
// The specific GPU id that the TensorRTEngine bounded to.
int device_;
std::vector<std::unique_ptr<plugin::PluginTensorRT>> owned_plugin_;
// TensorRT related internal members
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"
#include "glog/logging.h"
namespace paddle {
namespace inference {
namespace tensorrt {
// set the batch size before constructing the thread to execute engine
int TRTInt8Calibrator::getBatchSize() const { return batch_size_; }
TRTInt8Calibrator::TRTInt8Calibrator(
const std::unordered_map<std::string, size_t>& buffers, int batch_size,
std::string engine_name, const platform::Place place)
: batch_size_(batch_size), engine_name_(engine_name) {
int i = 0;
VLOG(4) << "Init a new calibrator: " << engine_name_;
for (const auto it : buffers) {
framework::Tensor temp_tensor;
std::string input_name = it.first;
int data_size = it.second;
int num_ele = data_size / sizeof(int16_t);
framework::DDim data_shape = framework::make_ddim({num_ele});
temp_tensor.Resize(data_shape);
data_tensors_.push_back(temp_tensor);
data_buffers_[input_name] = std::pair<void*, size_t>(
static_cast<void*>(temp_tensor.mutable_data<int16_t>(place)), num_ele);
i += 1;
}
}
TRTInt8Calibrator::TRTInt8Calibrator(const std::string& calib_data)
: batch_size_(0),
calib_running_(false),
data_is_set_(false),
done_(true),
calibration_table_(calib_data) {}
void TRTInt8Calibrator::waitAndSetDone() {
std::unique_lock<std::mutex> lk(mut_);
while ((calib_running_ || data_is_set_) && !done_) cond_.wait(lk);
if (!done_) {
done_ = true;
cond_.notify_all();
}
}
// There might be more than one input for trt subgraph,
// So, we use a map to store input information.
bool TRTInt8Calibrator::setBatch(
const std::unordered_map<std::string, void*>& data) {
VLOG(3) << "set batch: " << engine_name_;
std::unique_lock<std::mutex> lk(mut_);
// There is a producer and a consumer. The producer set the batch data and
// the consumer get the batch data. The size of the data pool is one.
// So, the producer has to wait for the consumer to finish processing before
// they can set the data.
while ((calib_running_ || data_is_set_) && (!done_)) cond_.wait(lk);
// The done_ is set to true using waitAndSetDone, When all calibration data
// are processed.
if (done_) return false;
// Sets the batch.
for (const auto& it : data) {
auto dataptr = data_buffers_.find(it.first);
if (dataptr == data_buffers_.end()) {
LOG(FATAL) << "FATAL " << engine_name_ << " input name '" << it.first
<< "' does not match with the buffer names";
}
const auto& d = dataptr->second;
PADDLE_ENFORCE(
cudaMemcpy(d.first, it.second, d.second, cudaMemcpyDeviceToDevice),
"Fail to cudaMemcpy %s for %s", engine_name_, it.first);
}
data_is_set_ = true;
cond_.notify_all();
return true;
}
bool TRTInt8Calibrator::getBatch(void** bindings, const char** names,
int num_bindings) {
VLOG(4) << "get batch: " << engine_name_;
std::unique_lock<std::mutex> lk(mut_);
// The consumer has just finished processing a data.
// The producer can set the data again.
calib_running_ = false;
cond_.notify_all();
// As long as there is data in the pool, the consumer can get it.
while (!data_is_set_ && !done_) cond_.wait(lk);
if (done_) return false;
// Gets the batch
for (int i = 0; i < num_bindings; i++) {
auto it = data_buffers_.find(names[i]);
if (it == data_buffers_.end()) {
LOG(FATAL) << "Calibration engine asked for unknown tensor name '"
<< names[i] << "' at position " << i;
}
bindings[i] = it->second.first;
}
data_is_set_ = false;
calib_running_ = true;
VLOG(4) << "get batch done: " << engine_name_;
return true;
}
void TRTInt8Calibrator::setDone() {
std::unique_lock<std::mutex> lk(mut_);
done_ = true;
cond_.notify_all();
}
const void* TRTInt8Calibrator::readCalibrationCache(size_t& length) {
if (calibration_table_.empty()) return nullptr;
length = calibration_table_.size();
return calibration_table_.data();
}
void TRTInt8Calibrator::writeCalibrationCache(const void* ptr,
std::size_t length) {
calibration_table_ = std::string((const char*)ptr, length);
VLOG(4) << "Got calibration data for " << engine_name_ << " " << ptr
<< " length=" << length;
}
TRTInt8Calibrator::~TRTInt8Calibrator() {
VLOG(4) << "Destroying calibrator for " << engine_name_;
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <atomic>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include <NvInfer.h>
#include <cuda_runtime_api.h>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class TensorRTEngine;
struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator {
public:
TRTInt8Calibrator(const std::unordered_map<std::string, size_t>& buffers,
int batch_size, std::string engine_name,
const platform::Place place);
explicit TRTInt8Calibrator(const std::string& calibration_data);
~TRTInt8Calibrator();
int getBatchSize() const override;
bool getBatch(void* bindings[], const char* names[],
int num_bindings) override;
bool setBatch(const std::unordered_map<std::string, void*>& data);
void setDone();
void waitAndSetDone();
const void* readCalibrationCache(std::size_t& length) override;
void writeCalibrationCache(const void* ptr, std::size_t length) override;
const std::string& getCalibrationTableAsString() {
return calibration_table_;
}
private:
const int batch_size_;
bool calib_running_{true};
bool data_is_set_{false};
bool done_{false};
std::mutex mut_;
std::condition_variable cond_;
std::unordered_map<std::string, std::pair<void*, size_t>> data_buffers_;
std::vector<framework::Tensor> data_tensors_;
std::string engine_name_;
std::string calibration_table_;
};
class TRTCalibratorEngine {
public:
TRTCalibratorEngine() {}
std::unique_ptr<TRTInt8Calibrator> calib_;
std::unique_ptr<std::thread> thr_;
std::unique_ptr<TensorRTEngine> engine_;
};
/*
* Manager to control the TensorRT Int8 calibration creation and deltetion.
*/
class TRTCalibratorEngineManager {
public:
bool Has() const { return res_.size() > 0; }
bool Has(const std::string& name) const {
if (res_.count(name) == 0) return false;
return res_.at(name).get() != nullptr;
}
// Get Int8Calibrator via name
TRTCalibratorEngine* Get(const std::string& name) const {
return res_.at(name).get();
}
// Look up or create a calibrator.
TRTCalibratorEngine* LookupOrCreate(const std::string& engine_name) {
if (res_.count(engine_name) == 0) {
auto* p = new TRTCalibratorEngine;
res_[engine_name].reset(p);
}
return res_.at(engine_name).get();
}
// Create an Int8Calibrator
TRTCalibratorEngine* Create(const std::string& engine_name) {
auto* p = new TRTCalibratorEngine;
res_[engine_name].reset(p);
return p;
}
void DeleteALL() {
for (auto& item : res_) {
item.second.reset(nullptr);
}
}
private:
std::unordered_map<std::string, std::unique_ptr<TRTCalibratorEngine>> res_;
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -29,8 +29,14 @@ class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Xs", "A list of inputs.").AsDuplicable();
AddOutput("Ys", "A list of outputs").AsDuplicable();
AddAttr<std::string>("subgraph", "the subgraph.");
AddAttr<std::string>("calibration_data", "the calibration data for int8");
AddAttr<std::string>(
"engine_key",
"The engine_key here is used to distinguish different TRT Engines");
AddAttr<int>("max_batch_size", "the maximum batch size.");
AddAttr<int>("workspace_size", "the workspace size.");
AddAttr<framework::BlockDesc *>("sub_block", "the trt block");
AddAttr<bool>("enable_int8", "whether swith to int8 mode");
AddComment("TensorRT engine operator.");
}
};
......@@ -47,6 +53,6 @@ class TensorRTEngineInferVarType : public framework::VarTypeInference {
namespace ops = paddle::operators;
REGISTER_OPERATOR(tensorrt_engine, ops::TensorRTEngineOp,
ops::TensorRTEngineOpMaker);
ops::TensorRTEngineOpMaker, ops::TensorRTEngineOpMaker);
#endif // PADDLE_WITH_CUDA
......@@ -17,8 +17,10 @@
#ifdef PADDLE_WITH_CUDA
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/inference/analysis/helper.h"
......@@ -62,6 +64,9 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector<int64_t> &shape) {
using inference::Singleton;
using inference::tensorrt::TensorRTEngine;
using inference::tensorrt::TRTInt8Calibrator;
using inference::tensorrt::TRTCalibratorEngine;
using inference::tensorrt::TRTCalibratorEngineManager;
class TensorRTEngineOp : public framework::OperatorBase {
private:
......@@ -70,6 +75,11 @@ class TensorRTEngineOp : public framework::OperatorBase {
mutable std::unique_ptr<TensorRTEngine> trt_engine_;
int max_batch_size_;
int workspace_size_;
std::unique_ptr<TRTInt8Calibrator> calibrator_;
bool enable_int8_;
std::string calibration_data_;
std::string engine_key_;
bool calibration_mode_;
public:
TensorRTEngineOp(const std::string &type,
......@@ -80,19 +90,96 @@ class TensorRTEngineOp : public framework::OperatorBase {
input_names_ = Inputs("Xs");
max_batch_size_ = Attr<int>("max_batch_size");
workspace_size_ = Attr<int>("workspace_size");
enable_int8_ = Attr<bool>("enable_int8");
calibration_data_ = Attr<std::string>("calibration_data");
engine_key_ = Attr<std::string>("engine_key");
auto params = Attr<std::vector<std::string>>("parameters");
for (const auto &param : params) {
param_names_.insert(param);
}
// calibration_mode is ture represents we need to
// generate the calibration table data.
calibration_mode_ = (enable_int8_ && calibration_data_.size() == 0);
VLOG(4) << "calibration_mode: " << calibration_mode_;
if (enable_int8_ && calibration_data_.size()) {
calibrator_.reset(new TRTInt8Calibrator(calibration_data_));
}
}
protected:
void RunNativeImpl(const framework::Scope &scope,
const platform::Place &dev_place) const {
framework::Executor executor(dev_place);
auto *block = Attr<framework::BlockDesc *>("sub_block");
auto *program = block->Program();
auto &current_scope = scope.NewScope();
auto ctx = executor.Prepare(*program, block->ID());
executor.RunPreparedContext(ctx.get(), &current_scope, false, true, true);
}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
if (calibration_mode_ == true) {
RunCalibration(scope, dev_place);
return;
}
RunTrt(scope, dev_place);
}
void RunCalibration(const framework::Scope &scope,
const platform::Place &dev_place) const {
// This process will builds a 32-bit trt engine, runs it on the calibration
// set, and records a histogram for each
// tensor of the distribution of activation values.
LOG_FIRST_N(INFO, 1) << "The TRT engine: " << engine_key_
<< " is running calibration trt int8... ";
int runtime_batch = 1;
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
auto stream =
reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx).stream();
if (!Singleton<TRTCalibratorEngineManager>::Global().Has(engine_key_)) {
TRTCalibratorEngine *calib_res =
Singleton<TRTCalibratorEngineManager>::Global().Create(engine_key_);
std::unordered_map<std::string, size_t> calib_buffers;
for (auto &x : input_names_) {
if (param_names_.count(x)) continue;
auto &t =
inference::analysis::GetFromScope<framework::LoDTensor>(scope, x);
calib_buffers[x] = t.memory_size();
auto t_shape = framework::vectorize(t.dims());
runtime_batch = t_shape[0];
}
calib_res->calib_.reset(new TRTInt8Calibrator(
calib_buffers, runtime_batch, engine_key_, dev_place));
calib_res->thr_.reset(new std::thread([&]() {
calib_res->engine_.reset(new TensorRTEngine(
max_batch_size_, workspace_size_, stream,
boost::get<platform::CUDAPlace>(dev_place).device, enable_int8_,
calib_res->calib_.get()));
VLOG(3) << "start the calib trt engine thread";
Prepare(scope, dev_place, calib_res->engine_.get());
}));
}
TRTInt8Calibrator *temp_calibrator =
Singleton<TRTCalibratorEngineManager>::Global()
.Get(engine_key_)
->calib_.get();
std::unordered_map<std::string, void *> calib_data;
for (auto &x : Inputs("Xs")) {
if (param_names_.count(x)) continue;
auto &t =
inference::analysis::GetFromScope<framework::LoDTensor>(scope, x);
calib_data.emplace(x, t.data<void>());
}
temp_calibrator->setBatch(calib_data);
RunNativeImpl(scope, dev_place);
}
void RunTrt(const framework::Scope &scope,
const platform::Place &dev_place) const {
int runtime_batch = 1;
......@@ -101,9 +188,10 @@ class TensorRTEngineOp : public framework::OperatorBase {
auto stream =
reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx).stream();
if (trt_engine_.get() == nullptr) {
trt_engine_.reset(new TensorRTEngine(
max_batch_size_, workspace_size_, stream,
boost::get<platform::CUDAPlace>(dev_place).device));
trt_engine_.reset(
new TensorRTEngine(max_batch_size_, workspace_size_, stream,
boost::get<platform::CUDAPlace>(dev_place).device,
enable_int8_, calibrator_.get()));
Prepare(scope, dev_place, trt_engine_.get());
}
......@@ -173,7 +261,8 @@ class TensorRTEngineOp : public framework::OperatorBase {
void Prepare(const framework::Scope &scope, const platform::Place &dev_place,
TensorRTEngine *engine) const {
VLOG(4) << "Prepare engine";
LOG(INFO) << "Prepare TRT engine (Optimize model structure, Select OP "
"kernel etc). This process may cost a lot of time.";
framework::proto::BlockDesc block_desc;
block_desc.ParseFromString(Attr<std::string>("subgraph"));
......
......@@ -96,19 +96,20 @@ TEST(TensorRTEngineOp, manual) {
engine_op_desc.SetType("tensorrt_engine");
engine_op_desc.SetInput("Xs", std::vector<std::string>({"x"}));
engine_op_desc.SetOutput("Ys", std::vector<std::string>({"z0"}));
SetAttr<std::string>(engine_op_desc.Proto(), "subgraph",
block_->SerializeAsString());
SetAttr<int>(engine_op_desc.Proto(), "max_batch_size", 2);
SetAttr<int>(engine_op_desc.Proto(), "workspace_size", 1 << 20);
SetAttr<std::string>(engine_op_desc.Proto(), "engine_uniq_key", "a_engine");
SetAttr<std::vector<std::string>>(engine_op_desc.Proto(), "parameters",
std::vector<std::string>({}));
SetAttr<std::vector<std::string>>(engine_op_desc.Proto(),
"output_name_mapping",
std::vector<std::string>({"z0"}));
engine_op_desc.SetBlockAttr("sub_block", &block_desc);
engine_op_desc.SetAttr("max_batch_size", static_cast<int>(2));
engine_op_desc.SetAttr("workspace_size", static_cast<int>(1 << 20));
engine_op_desc.SetAttr("parameters", std::vector<std::string>({}));
engine_op_desc.SetAttr("engine_key", std::string("a_engine"));
engine_op_desc.SetAttr("calibration_data", std::string(""));
engine_op_desc.SetAttr("enable_int8", static_cast<bool>(false));
engine_op_desc.SetAttr("output_name_mapping",
std::vector<std::string>({"z0"}));
engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString()));
LOG(INFO) << "create engine op";
auto engine_op = framework::OpRegistry::CreateOp(*engine_op_desc.Proto());
auto engine_op = framework::OpRegistry::CreateOp(engine_op_desc);
LOG(INFO) << "engine_op " << engine_op.get();
framework::Scope scope;
......@@ -190,20 +191,19 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
engine_op_desc.SetInput("Xs", std::vector<std::string>({"x0"}));
engine_op_desc.SetOutput("Ys", std::vector<std::string>({"z3"}));
SetAttr<std::string>(engine_op_desc.Proto(), "subgraph",
block_->SerializeAsString());
SetAttr<int>(engine_op_desc.Proto(), "max_batch_size", batch_size);
SetAttr<int>(engine_op_desc.Proto(), "workspace_size", 1 << 20);
SetAttr<std::vector<std::string>>(
engine_op_desc.Proto(), "parameters",
std::vector<std::string>({"y0", "y1", "y2", "y3"}));
SetAttr<std::string>(engine_op_desc.Proto(), "engine_uniq_key", "b_engine");
SetAttr<std::vector<std::string>>(engine_op_desc.Proto(),
"output_name_mapping",
std::vector<std::string>({"z3"}));
auto engine_op = framework::OpRegistry::CreateOp(*engine_op_desc.Proto());
engine_op_desc.SetBlockAttr("sub_block", &block_desc);
engine_op_desc.SetAttr("max_batch_size", static_cast<int>(batch_size));
engine_op_desc.SetAttr("workspace_size", static_cast<int>(1 << 20));
engine_op_desc.SetAttr("parameters",
std::vector<std::string>({"y0", "y1", "y2", "y3"}));
engine_op_desc.SetAttr("engine_key", std::string("b_engine"));
engine_op_desc.SetAttr("calibration_data", std::string(""));
engine_op_desc.SetAttr("enable_int8", static_cast<bool>(false));
engine_op_desc.SetAttr("output_name_mapping",
std::vector<std::string>({"z3"}));
engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString()));
auto engine_op = framework::OpRegistry::CreateOp(engine_op_desc);
// Execute them.
engine_op->Run(scope, place);
......
......@@ -180,8 +180,14 @@ void BindNativePredictor(py::module *m) {
}
void BindAnalysisConfig(py::module *m) {
py::class_<AnalysisConfig>(*m, "AnalysisConfig")
.def(py::init<const AnalysisConfig &>())
py::class_<AnalysisConfig> analysis_config(*m, "AnalysisConfig");
py::enum_<AnalysisConfig::Precision>(analysis_config, "Precision")
.value("Float32", AnalysisConfig::Precision::kFloat32)
.value("Int8", AnalysisConfig::Precision::kInt8)
.export_values();
analysis_config.def(py::init<const AnalysisConfig &>())
.def(py::init<const std::string &>())
.def(py::init<const std::string &, const std::string &>())
.def("set_model", (void (AnalysisConfig::*)(const std::string &)) &
......@@ -215,7 +221,8 @@ void BindAnalysisConfig(py::module *m) {
.def("specify_input_name", &AnalysisConfig::specify_input_name)
.def("enable_tensorrt_engine", &AnalysisConfig::EnableTensorRtEngine,
py::arg("workspace_size") = 1 << 20, py::arg("max_batch_size") = 1,
py::arg("min_subgraph_size") = 3)
py::arg("min_subgraph_size") = 3,
py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32)
.def("tensorrt_engine_enabled", &AnalysisConfig::tensorrt_engine_enabled)
.def("switch_ir_debug", &AnalysisConfig::SwitchIrDebug,
py::arg("x") = true)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册