diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h
index 2d8980b1d15d89cdf9c243a57188a0acb354940d..b06ff63a741c48be3719f069d5d45d957195f979 100644
--- a/paddle/fluid/inference/analysis/argument.h
+++ b/paddle/fluid/inference/analysis/argument.h
@@ -104,6 +104,7 @@ struct Argument {
   DECL_ARGUMENT_FIELD(model_program_path, ModelProgramPath, std::string);
   DECL_ARGUMENT_FIELD(model_params_path, ModelParamsPath, std::string);
   DECL_ARGUMENT_FIELD(model_from_memory, ModelFromMemory, bool);
+  DECL_ARGUMENT_FIELD(model_path, ModelPath, std::string);
 
   // The overall graph to work on.
   DECL_ARGUMENT_UNIQUE_FIELD(main_graph, MainGraph, framework::ir::Graph);
@@ -126,6 +127,8 @@ struct Argument {
   DECL_ARGUMENT_FIELD(tensorrt_max_batch_size, TensorRtMaxBatchSize, int);
   DECL_ARGUMENT_FIELD(tensorrt_workspace_size, TensorRtWorkspaceSize, int);
   DECL_ARGUMENT_FIELD(tensorrt_min_subgraph_size, TensorRtMinSubgraphSize, int);
+  DECL_ARGUMENT_FIELD(tensorrt_precision_mode, TensorRtPrecisionMode,
+                      std::string);
 
   // The program transformed by IR analysis phase.
   DECL_ARGUMENT_UNIQUE_FIELD(ir_analyzed_program, IrAnalyzedProgram,
diff --git a/paddle/fluid/inference/analysis/helper.h b/paddle/fluid/inference/analysis/helper.h
index 269a0da9f9378601373e42d741f519843b111ec6..5df3aacc3f27a998984be4679a4ef531c501dd66 100644
--- a/paddle/fluid/inference/analysis/helper.h
+++ b/paddle/fluid/inference/analysis/helper.h
@@ -156,6 +156,21 @@ static bool PathExists(const std::string &path) {
   return false;
 }
 
+static std::string SplitPath(const std::string &path) {
+  char sep = '/';
+
+#ifdef _WIN32
+  sep = '\\';
+#endif
+
+  size_t i = path.rfind(sep, path.length());
+  if (i != std::string::npos) {
+    return path.substr(0, i);
+  }
+
+  return path;
+}
+
 }  // namespace analysis
 }  // namespace inference
 }  // namespace paddle
diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc
index e37fea38bcb2b1f514347ecbfe7072abb6f07455..a99605577434c5d8a0f5a67d6fd0515d9a024c3e 100644
--- a/paddle/fluid/inference/analysis/ir_pass_manager.cc
+++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc
@@ -67,9 +67,17 @@ void IRPassManager::CreatePasses(Argument *argument,
       pass->Set("max_batch_size", new int(argument->tensorrt_max_batch_size()));
       pass->Set("min_subgraph_size",
                 new int(argument->tensorrt_min_subgraph_size()));
+      pass->Set(
+          "program",
+          new framework::ProgramDesc *(
+              const_cast<framework::ProgramDesc *>(&argument->main_program())));
+      pass->Set("precision_mode",
+                new std::string(argument->tensorrt_precision_mode()));
+      pass->Set("model_dir", new std::string(argument->model_path()));
     }
 
     // graph_ = pass->Apply(std::move(graph_));
+
     pre_pass = pass_name;
 
     passes_.emplace_back(std::move(pass));
@@ -94,7 +102,8 @@ framework::proto::ProgramDesc IRPassManager::AcquireProgram(
   auto pass =
      framework::ir::PassRegistry::Instance().Get("graph_to_program_pass");
 
-  ProgramDesc desc(program);
+  ProgramDesc desc;
+  desc.CopyFrom(*const_cast<ProgramDesc &>(program).Proto());
   pass->SetNotOwned("program", &desc);
   auto *the_graph = graph->release();
   *graph = pass->Apply(std::unique_ptr<Graph>(the_graph));
diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
index bc06e78ae6997b0d4d0456c15d6e4158efdad300..634c5ead0a30512699c3b2eaa225e06bce66b7ff 100644
--- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
+++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
@@ -72,13 +72,23 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
   auto &subgraph = *Agent(node).subgraph();
   PADDLE_ENFORCE(!subgraph.empty());
 
+  framework::ProgramDesc *program_desc =
+      Get<framework::ProgramDesc *>("program");
+  // Add new block for TensorRTEngineOP
+  const framework::BlockDesc &main_block =
+      program_desc->Block(framework::kRootBlockIndex);
+  // const framework::BlockDesc& main_block = program_desc->Block(0);
+  framework::BlockDesc *new_block = program_desc->AppendBlock(main_block);
+
   // An fake block desc.
   framework::proto::BlockDesc block_proto;
   framework::BlockDesc block_desc(nullptr, &block_proto);
   block_desc.Proto()->set_parent_idx(-1);
   block_desc.Proto()->set_idx(0);
   for (auto *node : subgraph) {
+    auto *new_block_op = new_block->AppendOp();
     auto *op = block_desc.AppendOp();
+    *new_block_op->Proto() = *node->Op()->Proto();
     *op->Proto() = *node->Op()->Proto();
   }
 
@@ -178,7 +188,6 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
   // to Tensor.
   std::vector<std::string> output_mapping;
   for (auto name : output_names) {
-    // LOG(INFO) << name << " " << output_name_map.size();
     PADDLE_ENFORCE(output_name_map.count(name) != 0);
     output_mapping.push_back(output_name_map[name]);
   }
@@ -189,9 +198,11 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
       *vars->Add() = *node->Var()->Proto();
     }
   }
+
   PADDLE_ENFORCE(!block_desc.Proto()->vars().empty(),
                  "the block has no var-desc");
   PADDLE_ENFORCE(!output_mapping.empty());
+  op_desc->SetBlockAttr("sub_block", new_block);
   // Set attrs
   SetAttr(op_desc->Proto(), "subgraph",
           block_desc.Proto()->SerializeAsString());
@@ -199,6 +210,22 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
   SetAttr(op_desc->Proto(), "workspace_size", Get<int>("workspace_size"));
   SetAttr(op_desc->Proto(), "parameters", ExtractParameters(graph->Nodes()));
   SetAttr(op_desc->Proto(), "output_name_mapping", output_mapping);
+
+  std::string engine_key = std::to_string(
+      std::hash<std::string>()(block_desc.Proto()->SerializeAsString()));
+  std::string precision_mode = Get<std::string>("precision_mode");
+  SetAttr(op_desc->Proto(), "calibration_data", std::string(""));
+  std::string trt_calib_file =
+      Get<std::string>("model_dir") + "/trt_calib_" + engine_key;
+  if (precision_mode == "INT8" && FileExists(trt_calib_file)) {
+    std::ifstream infile(trt_calib_file, std::ios::in);
+    std::stringstream buffer;
+    buffer << infile.rdbuf();
+    std::string calibration_data(buffer.str());
+    SetAttr(op_desc->Proto(), "calibration_data", calibration_data);
+  }
+  SetAttr(op_desc->Proto(), "precision_mode", precision_mode);
+  SetAttr(op_desc->Proto(), "engine_key", engine_key);
 }
 
 std::vector<std::string> ExtractParameters(
diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc
index 211c691504de2c0bd8ff50f34b92cbc01397d5c9..399db291fd1d1695e164b6d9985f1be8ab4551ef 100644
--- a/paddle/fluid/inference/api/analysis_config.cc
+++ b/paddle/fluid/inference/api/analysis_config.cc
@@ -86,6 +86,7 @@ contrib::AnalysisConfig::AnalysisConfig(const contrib::AnalysisConfig &other) {
   CP_MEMBER(tensorrt_workspace_size_);
   CP_MEMBER(tensorrt_max_batchsize_);
   CP_MEMBER(tensorrt_min_subgraph_size_);
+  CP_MEMBER(tensorrt_precision_mode_);
   // MKLDNN releated.
   CP_MEMBER(use_mkldnn_);
   CP_MEMBER(mkldnn_enabled_op_types_);
@@ -123,10 +124,13 @@ void contrib::AnalysisConfig::EnableMKLDNN() {
 
 void contrib::AnalysisConfig::EnableTensorRtEngine(int workspace_size,
                                                    int max_batch_size,
-                                                   int min_subgraph_size) {
+                                                   int min_subgraph_size,
+                                                   std::string precision_mode) {
   use_tensorrt_ = true;
   tensorrt_workspace_size_ = workspace_size;
   tensorrt_max_batchsize_ = max_batch_size;
+  tensorrt_precision_mode_ = precision_mode;
+  Update();
 }
 
 void contrib::AnalysisConfig::Update() {
@@ -176,6 +180,7 @@ std::string contrib::AnalysisConfig::SerializeInfoCache() {
   ss << use_tensorrt_;
   ss << tensorrt_workspace_size_;
   ss << tensorrt_max_batchsize_;
+  ss << tensorrt_precision_mode_;
 
   ss << use_mkldnn_;
   ss << enable_ir_optim_;
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 585634fae9c85f77cc77d774ac166891014a025c..75c62bb98cb7c09eb8ad81942a6e7e110f04abd6 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -15,6 +15,7 @@
 #include "paddle/fluid/inference/api/analysis_predictor.h"
 #include <glog/logging.h>
 #include <algorithm>
+#include <fstream>
 #include <memory>
 #include <string>
 #include <vector>
@@ -30,6 +31,8 @@
 #if PADDLE_WITH_TENSORRT
 #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
 #endif
+#include "paddle/fluid/inference/analysis/helper.h"
+#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"
 #include "paddle/fluid/inference/utils/singleton.h"
 #include "paddle/fluid/memory/memcpy.h"
 #include "paddle/fluid/platform/cpu_helper.h"
@@ -41,6 +44,10 @@ DECLARE_bool(profile);
 namespace paddle {
 
 using contrib::AnalysisConfig;
+using inference::Singleton;
+using inference::tensorrt::TRTInt8Calibrator;
+using inference::tensorrt::TRTCalibratorRes;
+using inference::tensorrt::TRTCalibratorResManager;
 
 namespace {
 bool IsPersistable(const framework::VarDesc *var) {
@@ -321,11 +328,15 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
   // Analyze inference_program
   if (!config_.model_dir().empty()) {
     argument_.SetModelDir(config_.model_dir());
+    argument_.SetModelPath(config_.model_dir());
   } else {
     PADDLE_ENFORCE(
         !config_.params_file().empty(),
         "Either model_dir or (param_file, prog_file) should be set.");
     PADDLE_ENFORCE(!config_.prog_file().empty());
+    std::string dir = inference::analysis::SplitPath(config_.prog_file());
+
+    argument_.SetModelPath(dir);
     argument_.SetModelProgramPath(config_.prog_file());
     argument_.SetModelParamsPath(config_.params_file());
   }
@@ -335,6 +346,7 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
     argument_.SetTensorRtWorkspaceSize(config_.tensorrt_workspace_size_);
     argument_.SetTensorRtMaxBatchSize(config_.tensorrt_max_batchsize_);
     argument_.SetTensorRtMinSubgraphSize(config_.tensorrt_min_subgraph_size_);
+    argument_.SetTensorRtPrecisionMode(config_.tensorrt_precision_mode_);
   }
 
   if (config_.use_mkldnn_) {
@@ -550,7 +562,52 @@ bool AnalysisPredictor::LoadParameters() {
   return true;
 }
 
+bool AnalysisPredictor::SaveTrtCalibToDisk() {
+  PADDLE_ENFORCE(config_.tensorrt_engine_enabled(),
+                 "This func can be invoked only in trt mode");
+  auto &block = inference_program_->Block(0);
+  for (auto &op_desc : block.AllOps()) {
+    if (op_desc->Type() == "tensorrt_engine") {
+      std::string engine_name =
+          boost::get<std::string>(op_desc->GetAttr("engine_key"));
+      if (!Singleton<TRTCalibratorResManager>::Global().Has(engine_name)) {
+        LOG(ERROR) << "You should run the predictor(with trt) on the real data "
+                      "to generate calibration info";
+        return false;
+      }
+      TRTCalibratorRes *calib_res =
+          Singleton<TRTCalibratorResManager>::Global().Get(engine_name);
+      LOG(INFO) << "Wait for calib threads done.";
+      calib_res->calib_->waitAndSetDone();
+      LOG(INFO) << "Finish wait.";
+      calib_res->thr_->join();
+      std::string calibration_data =
+          calib_res->calib_->getCalibrationTableAsString();
+
+      if (calibration_data.size() == 0) {
+        LOG(ERROR) << "the calibration table is empty.";
+        return false;
+      }
+      std::string calibration_data_path =
+          argument_.model_path() + "/trt_calib_" + engine_name;
+      std::ofstream ofile(calibration_data_path, std::ios::out);
+      LOG(INFO) << "Write Paddle-TRT INT8 calibration data to file "
+                << calibration_data_path;
+      ofile << calibration_data;
+      ofile.close();
+    }
+  }
+  // Free all calibrator resources.
+  Singleton<TRTCalibratorResManager>::Global().DeleteALL();
+  return true;
+}
+
 AnalysisPredictor::~AnalysisPredictor() {
+  if (config_.tensorrt_engine_enabled() &&
+      config_.tensorrt_precision_mode_ == "INT8" &&
+      Singleton<TRTCalibratorResManager>::Global().Has()) {
+    SaveTrtCalibToDisk();
+  }
   if (FLAGS_profile) {
     platform::DisableProfiler(platform::EventSortingKey::kTotal,
                               "./profile.log");
diff --git a/paddle/fluid/inference/api/analysis_predictor.h b/paddle/fluid/inference/api/analysis_predictor.h
index a6e126c5d533f4299ccc3deed7d116cabc71f75b..cec36a0d3a92defac82b982ac688db4a4080d552 100644
--- a/paddle/fluid/inference/api/analysis_predictor.h
+++ b/paddle/fluid/inference/api/analysis_predictor.h
@@ -90,6 +90,9 @@ class AnalysisPredictor : public PaddlePredictor {
   template <typename T>
   void GetFetchOne(const framework::LoDTensor &fetchs,
                    PaddleTensor *output_data);
+
+  bool SaveTrtCalibToDisk();
+
   ~AnalysisPredictor();
 
   // Some more detailed tests, they are made the friends of the predictor, so that
diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h
index ae6ac69854d91d44567ccd985791de5fd2b16f26..14b16d08b3444c020b920feb0e5d9cc79e45d43f 100644
--- a/paddle/fluid/inference/api/paddle_analysis_config.h
+++ b/paddle/fluid/inference/api/paddle_analysis_config.h
@@ -135,7 +135,8 @@ struct AnalysisConfig {
    * subgraph is less than this, it will not transfer to TensorRT engine.
    */
   void EnableTensorRtEngine(int workspace_size = 1 << 20,
-                            int max_batch_size = 1, int min_subgraph_size = 3);
+                            int max_batch_size = 1, int min_subgraph_size = 3,
+                            std::string precision = "FP32");
 
   /** A boolean state telling whether the TensorRT engine is used.
    */
   bool tensorrt_engine_enabled() const { return use_tensorrt_; }
@@ -231,6 +232,7 @@ struct AnalysisConfig {
   // We set this variable to control the minimum number of nodes in the
   // subgraph, 3 as default value.
   int tensorrt_min_subgraph_size_{3};
+  std::string tensorrt_precision_mode_;
 
   bool use_mkldnn_{false};
   std::unordered_set<std::string> mkldnn_enabled_op_types_;
diff --git a/paddle/fluid/inference/tensorrt/CMakeLists.txt b/paddle/fluid/inference/tensorrt/CMakeLists.txt
index 9afeafd176c70bc03166ec7732ae5e2faf67ea54..f4977d08c4d051b8a528e122c47948c3c81d153c 100644
--- a/paddle/fluid/inference/tensorrt/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/CMakeLists.txt
@@ -1,4 +1,4 @@
-nv_library(tensorrt_engine SRCS engine.cc DEPS ${GLOB_OPERATOR_DEPS} framework_proto device_context)
+nv_library(tensorrt_engine SRCS engine.cc trt_int8_calibrator.cc DEPS ${GLOB_OPERATOR_DEPS} framework_proto device_context)
 nv_library(tensorrt_op_teller SRCS op_teller.cc DEPS framework_proto)
 nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
 nv_test(test_tensorrt_engine SRCS test_engine.cc DEPS dynload_cuda tensorrt_engine)
diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc
index f739752cbc44805cb0fb3246385609cf16ba744a..43f99df4637275a2df85f3ce89318a43f4d0c9fa 100644
--- a/paddle/fluid/inference/tensorrt/engine.cc
+++ b/paddle/fluid/inference/tensorrt/engine.cc
@@ -70,6 +70,13 @@ void TensorRTEngine::FreezeNetwork() {
   // build engine.
   infer_builder_->setMaxBatchSize(max_batch_);
   infer_builder_->setMaxWorkspaceSize(max_workspace_);
+  if (precision_mode_ == "INT8") {
+    infer_builder_->setInt8Mode(true);
+    PADDLE_ENFORCE(
+        calibrator_ != nullptr,
+        "The precision mode is 'INT8', the calibrator should not be nullptr");
+    infer_builder_->setInt8Calibrator(calibrator_);
+  }
 
   infer_engine_.reset(infer_builder_->buildCudaEngine(*infer_network_));
   PADDLE_ENFORCE(infer_engine_ != nullptr, "build cuda engine failed!");
diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h
index f5b2c28ba9e6fefc1d6c14640d696c3bf3ac8249..9aed374dce4c5cd455151a0f1089a8cf30db16b9 100644
--- a/paddle/fluid/inference/tensorrt/engine.h
+++ b/paddle/fluid/inference/tensorrt/engine.h
@@ -23,12 +23,14 @@ limitations under the License. */
 #include "paddle/fluid/inference/engine.h"
 #include "paddle/fluid/inference/tensorrt/helper.h"
 #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
+#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"
 #include "paddle/fluid/inference/utils/singleton.h"
 
 namespace paddle {
 namespace inference {
 namespace tensorrt {
 
+class TRTInt8Calibrator;
 /*
  * TensorRT Engine.
  *
@@ -56,12 +58,16 @@ class TensorRTEngine : public EngineBase {
 
   TensorRTEngine(int max_batch, int max_workspace,
                  cudaStream_t* stream = nullptr, int device = 0,
+                 std::string precision_mode = "FP32",
+                 TRTInt8Calibrator* calibrator = nullptr,
                  nvinfer1::ILogger& logger = NaiveLogger::Global())
      : max_batch_(max_batch),
        max_workspace_(max_workspace),
        stream_(stream ? stream : &default_stream_),
-        logger_(logger),
-        device_(device) {
+        device_(device),
+        precision_mode_(precision_mode),
+        calibrator_(calibrator),
+        logger_(logger) {
     freshDeviceId();
     cudaStreamCreate(stream_);
   }
@@ -142,8 +148,8 @@ class TensorRTEngine : public EngineBase {
   // In the normal case, the paddle-trt exists bug when runing the googlenet.
   // When there are more than two convolutions of 1 * 1 with the same input, the
   // paddle-tensorrt will do the merging optimization, which fuse those conv
-  // into
-  // one conv, and then trigger bug. So, We should use strategy to avoid this
+  // into one conv, and then trigger bug. So, We should use strategy to avoid
+  // this
   // optimization for the time being. This bug will be fixed in the future.
   std::unordered_map<std::string, int> itensor_quote_num;
 
@@ -156,11 +162,16 @@ class TensorRTEngine : public EngineBase {
   // the max memory size the engine uses
   int max_workspace_;
 
-  // batch size of the current data, will be updated each Executation.
-  int batch_size_{-1};
   cudaStream_t* stream_;
   // If stream_ is not set from outside, hold its own stream.
   cudaStream_t default_stream_;
+  // The specific GPU id that the TensorRTEngine bounded to.
+  int device_;
+
+  std::string precision_mode_;
+  TRTInt8Calibrator* calibrator_;
+  // batch size of the current data, will be updated each Executation.
+  int batch_size_{-1};
 
   nvinfer1::ILogger& logger_;
 
   std::vector<void*> buffers_;
@@ -169,8 +180,6 @@ class TensorRTEngine : public EngineBase {
 
   std::unordered_map<std::string, nvinfer1::ITensor*> itensor_map_;
 
-  // The specific GPU id that the TensorRTEngine bounded to.
-  int device_;
   std::vector<std::unique_ptr<plugin::PluginTensorRT>> owned_plugin_;
 
   // TensorRT related internal members
@@ -208,38 +217,6 @@ class TensorRTEngine : public EngineBase {
 #define TRT_ENGINE_ADD_LAYER(engine__, layer__, ARGS...) \
   engine__->network()->add##layer__(ARGS);
 
-/*
- * Helper to control the TensorRT engine's creation and deletion.
- */
-class TRT_EngineManager {
- public:
-  bool HasEngine(const std::string& name) const {
-    return engines_.count(name) != 0;
-  }
-
-  // Get an engine called `name`.
-  TensorRTEngine* Get(const std::string& name) const {
-    return engines_.at(name).get();
-  }
-
-  // Create or get an engine called `name`
-  TensorRTEngine* Create(int max_batch, int max_workspace, cudaStream_t* stream,
-                         const std::string& name, int gpu_device = 0) {
-    auto* p = new TensorRTEngine(max_batch, max_workspace, stream, gpu_device);
-    engines_[name].reset(p);
-    return p;
-  }
-
-  void DeleteALl() {
-    for (auto& item : engines_) {
-      item.second.reset(nullptr);
-    }
-  }
-
- private:
-  std::unordered_map<std::string, std::unique_ptr<TensorRTEngine>> engines_;
-};
-
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/trt_int8_calibrator.cc b/paddle/fluid/inference/tensorrt/trt_int8_calibrator.cc
new file mode 100644
index 0000000000000000000000000000000000000000..f935620020406d1b31e0a2a9addbfabf7c596f87
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/trt_int8_calibrator.cc
@@ -0,0 +1,144 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"
+#include "glog/logging.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+// set the batch size before constructing the thread to execute engine
+int TRTInt8Calibrator::getBatchSize() const { return batch_size_; }
+
+TRTInt8Calibrator::TRTInt8Calibrator(
+    const std::unordered_map<std::string, size_t>& buffers, int batch_size,
+    std::string engine_name, const platform::Place place)
+    : batch_size_(batch_size),
+      calib_running_(true),
+      data_is_set_(false),
+      done_(false),
+      engine_name_(engine_name) {
+  int i = 0;
+  VLOG(4) << "Init a new calibrator: " << engine_name_;
+  for (const auto it : buffers) {
+    framework::Tensor temp_tensor;
+    std::string input_name = it.first;
+    int data_size = it.second;
+    int num_ele = data_size / sizeof(int16_t);
+    framework::DDim data_shape = framework::make_ddim({num_ele});
+    temp_tensor.Resize(data_shape);
+    data_tensors_.push_back(temp_tensor);
+    data_buffers_[input_name] = std::pair<void*, size_t>(
+        static_cast<void*>(temp_tensor.mutable_data<int16_t>(place)), num_ele);
+    i += 1;
+  }
+}
+
+TRTInt8Calibrator::TRTInt8Calibrator(const std::string& calib_data)
+    : batch_size_(0),
+      calib_running_(false),
+      data_is_set_(false),
+      done_(true),
+      calibration_table_(calib_data) {}
+
+void TRTInt8Calibrator::waitAndSetDone() {
+  std::unique_lock<std::mutex> lk(mut_);
+  while ((calib_running_ || data_is_set_) && !done_) cond_.wait(lk);
+  if (!done_) {
+    done_ = true;
+    cond_.notify_all();
+  }
+}
+
+bool TRTInt8Calibrator::setBatch(
+    const std::unordered_map<std::string, void*>& data) {
+  VLOG(3) << "set batch: " << engine_name_;
+  std::unique_lock<std::mutex> lk(mut_);
+  while ((calib_running_ || data_is_set_) && (!done_)) cond_.wait(lk);
+  if (done_) return false;
+
+  // Sets the batch.
+  for (const auto it : data) {
+    auto dataptr = data_buffers_.find(it.first);
+    if (dataptr == data_buffers_.end()) {
+      LOG(FATAL) << "FATAL " << engine_name_ << " input name '" << it.first
+                 << "' does not match with the buffer names";
+    }
+
+    const auto& d = dataptr->second;
+    auto status =
+        cudaMemcpy(d.first, it.second, d.second, cudaMemcpyDeviceToDevice);
+    if (status != cudaSuccess) {
+      LOG(FATAL) << "cudaMemcpy " << engine_name_ << " for '" << it.first
+                 << "' failed with " << status;
+    }
+  }
+
+  data_is_set_ = true;
+  cond_.notify_all();
+  return true;
+}
+
+bool TRTInt8Calibrator::getBatch(void** bindings, const char** names,
+                                 int num_bindings) {
+  VLOG(4) << "get batch: " << engine_name_;
+  std::unique_lock<std::mutex> lk(mut_);
+  calib_running_ = false;
+  cond_.notify_all();
+
+  while (!data_is_set_ && !done_) cond_.wait(lk);
+  if (done_) return false;
+
+  // Gets the batch
+  for (int i = 0; i < num_bindings; i++) {
+    auto it = data_buffers_.find(names[i]);
+    if (it == data_buffers_.end()) {
+      LOG(FATAL) << "Calibration engine asked for unknown tensor name '"
+                 << names[i] << "' at position " << i;
+    }
+    bindings[i] = it->second.first;
+  }
+
+  data_is_set_ = false;
+  calib_running_ = true;
+  VLOG(4) << "get batch done: " << engine_name_;
+  return true;
+}
+
+void TRTInt8Calibrator::setDone() {
+  std::unique_lock<std::mutex> lk(mut_);
+  done_ = true;
+  cond_.notify_all();
+}
+
+const void* TRTInt8Calibrator::readCalibrationCache(std::size_t& length) {
+  if (calibration_table_.empty()) return nullptr;
+  length = calibration_table_.size();
+  return calibration_table_.data();
+}
+
+void TRTInt8Calibrator::writeCalibrationCache(const void* ptr,
+                                              std::size_t length) {
+  calibration_table_ = std::string((const char*)ptr, length);
+  VLOG(4) << "Got calibration data for " << engine_name_ << " " << ptr
+          << " length=" << length;
+}
+
+TRTInt8Calibrator::~TRTInt8Calibrator() {
+  VLOG(4) << "Destroying calibrator for " << engine_name_;
+}
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/trt_int8_calibrator.h b/paddle/fluid/inference/tensorrt/trt_int8_calibrator.h
new file mode 100644
index 0000000000000000000000000000000000000000..81ba9c7032c2022c6fc3c4421692564d7f6b8740
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/trt_int8_calibrator.h
@@ -0,0 +1,128 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <condition_variable>
+#include <memory>
+#include <mutex>
+#include <string>
+#include <thread>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "NvInfer.h"
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/inference/tensorrt/engine.h"
+#include "paddle/fluid/platform/place.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+class TensorRTEngine;
+
+struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator {
+ public:
+  TRTInt8Calibrator(const std::unordered_map<std::string, size_t>& buffers,
+                    int batch_size, std::string engine_name,
+                    const platform::Place place);
+
+  explicit TRTInt8Calibrator(const std::string& calibration_data);
+  ~TRTInt8Calibrator();
+
+  int getBatchSize() const override;
+
+  bool getBatch(void* bindings[], const char* names[],
+                int num_bindings) override;
+
+  bool setBatch(const std::unordered_map<std::string, void*>& data);
+  void setDone();
+  void waitAndSetDone();
+
+  const void* readCalibrationCache(std::size_t& length) override;
+  void writeCalibrationCache(const void* ptr, std::size_t length) override;
+  const std::string& getCalibrationTableAsString() {
+    return calibration_table_;
+  }
+
+ private:
+  const int batch_size_;
+
+  bool calib_running_;
+  bool data_is_set_;
+  bool done_;
+
+  std::mutex mut_;
+  std::condition_variable cond_;
+
+  std::unordered_map<std::string, std::pair<void*, size_t>> data_buffers_;
+  std::vector<framework::Tensor> data_tensors_;
+
+  std::string engine_name_;
+  std::string calibration_table_;
+};
+
+class TRTCalibratorRes {
+ public:
+  TRTCalibratorRes() {}
+  std::unique_ptr<TRTInt8Calibrator> calib_;
+  std::unique_ptr<std::thread> thr_;
+  std::unique_ptr<TensorRTEngine> engine_;
+};
+
+/*
+ * Manager to control the TensorRT Int8 calibration creation and deletion.
+ */
+class TRTCalibratorResManager {
+ public:
+  bool Has() const { return res_.size() > 0; }
+  bool Has(const std::string& name) const {
+    if (res_.count(name) == 0) return false;
+    return res_.at(name).get() != nullptr;
+  }
+
+  // Get Int8Calibrator via name
+  TRTCalibratorRes* Get(const std::string& name) const {
+    return res_.at(name).get();
+  }
+
+  // Look up or create a calibrator.
+  TRTCalibratorRes* LookupOrCreate(const std::string& engine_name) {
+    if (res_.count(engine_name) == 0) {
+      auto* p = new TRTCalibratorRes();
+      res_[engine_name].reset(p);
+    }
+    return res_.at(engine_name).get();
+  }
+
+  // Create an Int8Calibrator
+  TRTCalibratorRes* Create(const std::string& engine_name) {
+    auto* p = new TRTCalibratorRes();
+    res_[engine_name].reset(p);
+    return p;
+  }
+
+  void DeleteALL() {
+    for (auto& item : res_) {
+      item.second.reset(nullptr);
+    }
+  }
+
+ private:
+  std::unordered_map<std::string, std::unique_ptr<TRTCalibratorRes>> res_;
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.cc b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.cc
index b993c55fad13e892efd51648b78704bec83bf2b4..ed177eb18f827efdcafb669ce5c5a93bbfaebacc 100644
--- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.cc
+++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.cc
@@ -29,8 +29,15 @@ class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("Xs", "A list of inputs.").AsDuplicable();
     AddOutput("Ys", "A list of outputs").AsDuplicable();
     AddAttr<std::string>("subgraph", "the subgraph.");
+    AddAttr<std::string>("calibration_data", "the calibration data for int8");
+    AddAttr<std::string>(
+        "engine_key",
+        "The engine_key here is used to distinguish different TRT Engines");
     AddAttr<int>("max_batch_size", "the maximum batch size.");
     AddAttr<int>("workspace_size", "the workspace size.");
+    AddAttr<framework::BlockDesc *>("sub_block", "the trt block");
+    AddAttr<std::string>("precision_mode",
+                         "the precision mode: 'FP32', 'INT8' ");
     AddComment("TensorRT engine operator.");
   }
 };
@@ -47,6 +54,6 @@ class TensorRTEngineInferVarType : public framework::VarTypeInference {
 namespace ops = paddle::operators;
 
 REGISTER_OPERATOR(tensorrt_engine, ops::TensorRTEngineOp,
-                  ops::TensorRTEngineOpMaker);
+                  ops::TensorRTEngineOpMaker, ops::TensorRTEngineOpMaker);
 
 #endif  // PADDLE_WITH_CUDA
diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
index 88c4f508474e66953b79fb92ff1eb0b53a539f07..57747faec853ba5405c5984b99e0307a9217901f 100644
--- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
+++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
@@ -17,8 +17,10 @@
 #ifdef PADDLE_WITH_CUDA
 
 #include <string>
+#include <thread>
 #include <vector>
 
+#include "paddle/fluid/framework/executor.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/inference/analysis/helper.h"
@@ -62,6 +64,9 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector<int64_t> &shape) {
 
 using inference::Singleton;
 using inference::tensorrt::TensorRTEngine;
+using inference::tensorrt::TRTInt8Calibrator;
+using inference::tensorrt::TRTCalibratorRes;
+using inference::tensorrt::TRTCalibratorResManager;
 
 class TensorRTEngineOp : public framework::OperatorBase {
  private:
@@ -70,6 +75,11 @@ class TensorRTEngineOp : public framework::OperatorBase {
   mutable std::unique_ptr<TensorRTEngine> trt_engine_;
   int max_batch_size_;
   int workspace_size_;
+  std::unique_ptr<TRTInt8Calibrator> calibrator_;
+  std::string precision_mode_;
+  std::string calibration_data_;
+  std::string engine_key_;
+  bool calibration_mode_;
 
  public:
   TensorRTEngineOp(const std::string &type,
@@ -80,26 +90,95 @@ class TensorRTEngineOp : public framework::OperatorBase {
     input_names_ = Inputs("Xs");
     max_batch_size_ = Attr<int>("max_batch_size");
     workspace_size_ = Attr<int>("workspace_size");
+    precision_mode_ = Attr<std::string>("precision_mode");
+    calibration_data_ = Attr<std::string>("calibration_data");
+    engine_key_ = Attr<std::string>("engine_key");
 
     auto params = Attr<std::vector<std::string>>("parameters");
     for (const auto &param : params) {
       param_names_.insert(param);
     }
+    calibration_mode_ =
+        (precision_mode_ == "INT8" && calibration_data_.size() == 0);
+
+    if (precision_mode_ == "INT8" && calibration_data_.size()) {
+      calibrator_.reset(new TRTInt8Calibrator(calibration_data_));
+    }
   }
 
 protected:
+  void RunNative(const framework::Scope &scope,
+                 const platform::Place &dev_place) const {
+    framework::Executor executor(dev_place);
+    auto *block = Attr<framework::BlockDesc *>("sub_block");
+    auto *program = block->Program();
+    auto *scope_ptr = const_cast<framework::Scope *>(&scope);
+    auto ctx = executor.Prepare(*program, block->ID());
+    executor.RunPreparedContext(ctx.get(), scope_ptr, false, true, true);
+  }
+
   void RunImpl(const framework::Scope &scope,
               const platform::Place &dev_place) const override {
+    if (calibration_mode_ == true) {
+      RunCalibration(scope, dev_place);
+      return;
+    }
     RunTrt(scope, dev_place);
   }
 
+  void RunCalibration(const framework::Scope &scope,
+                      const platform::Place &dev_place) const {
+    // Create calibrator here.
+    LOG(INFO) << "Running calibration trt int8 ...";
+    int runtime_batch = 1;
+    if (!Singleton<TRTCalibratorResManager>::Global().Has(engine_key_)) {
+      TRTCalibratorRes *calib_res =
+          Singleton<TRTCalibratorResManager>::Global().Create(engine_key_);
+      std::unordered_map<std::string, size_t> calib_buffers;
+      for (auto &x : input_names_) {
+        if (param_names_.count(x)) continue;
+        auto &t =
+            inference::analysis::GetFromScope<framework::LoDTensor>(scope, x);
+        calib_buffers[x] = t.memory_size();
+        auto t_shape = framework::vectorize(t.dims());
+        runtime_batch = t_shape[0];
+      }
+      calib_res->calib_.reset(new TRTInt8Calibrator(
+          calib_buffers, runtime_batch, engine_key_, dev_place));
+      calib_res->thr_.reset(new std::thread([&]() {
+        calib_res->engine_.reset(new TensorRTEngine(
+            max_batch_size_, workspace_size_, nullptr,
+            boost::get<platform::CUDAPlace>(dev_place).device, precision_mode_,
+            calib_res->calib_.get()));
+        VLOG(3) << "start the calib trt engine thread";
+        Prepare(scope, dev_place, calib_res->engine_.get());
+      }));
+    }
+
+    TRTInt8Calibrator *temp_calibrator =
+        Singleton<TRTCalibratorResManager>::Global()
+            .Get(engine_key_)
+            ->calib_.get();
+    std::unordered_map<std::string, void *> calib_data;
+
+    for (auto &x : Inputs("Xs")) {
+      if (param_names_.count(x)) continue;
+      auto &t =
+          inference::analysis::GetFromScope<framework::LoDTensor>(scope, x);
+      calib_data.emplace(x, t.data<void>());
+    }
+    temp_calibrator->setBatch(calib_data);
+    RunNative(scope, dev_place);
+  }
+
   void RunTrt(const framework::Scope &scope,
              const platform::Place &dev_place) const {
     int runtime_batch = 1;
     if (trt_engine_.get() == nullptr) {
-      trt_engine_.reset(new TensorRTEngine(
-          max_batch_size_, workspace_size_, nullptr,
-          boost::get<platform::CUDAPlace>(dev_place).device));
+      trt_engine_.reset(
+          new TensorRTEngine(max_batch_size_, workspace_size_, nullptr,
+                             boost::get<platform::CUDAPlace>(dev_place).device,
+                             precision_mode_, calibrator_.get()));
       Prepare(scope, dev_place, trt_engine_.get());
     }
 
@@ -168,7 +247,8 @@ class TensorRTEngineOp : public framework::OperatorBase {
   void Prepare(const framework::Scope &scope, const platform::Place &dev_place,
               TensorRTEngine *engine) const {
-    VLOG(4) << "Prepare engine";
+    LOG(INFO) << "Prepare TRT engine (Optimize model structure, Select OP "
+                 "kernel etc). This process may cost a lot of time.";
     framework::proto::BlockDesc block_desc;
     block_desc.ParseFromString(Attr<std::string>("subgraph"));