Unverified commit dd67d44a authored by Zhaolong Xing, committed by GitHub

[Paddle-TRT] : (Part1) Dynamic shape support (#22868)

* change the CI TRT version from 5 to 6.0

* paddle-trt dynamic shape support init

* conv+bias or conv+bn dynamic shape support
test=develop

* modify trt engine op converter
test=develop

* fix ci error
test=develop
Parent 62fd3209
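A minimal usage sketch of the new dynamic shape API (the input name "image" and the concrete shapes below are taken from the trt_dynamic_shape_test.cc added in this commit):

AnalysisConfig config;
config.EnableUseGpu(100, 0);
config.SetModel(model_dir);
config.SwitchUseFeedFetchOps(false);
// Per-input min/max/optimal shapes switch Paddle-TRT into dynamic shape mode
// (requires TensorRT >= 6.0).
std::map<std::string, std::vector<int>> min_input_shape = {{"image", {1, 1, 3, 3}}};
std::map<std::string, std::vector<int>> max_input_shape = {{"image", {1, 1, 10, 10}}};
std::map<std::string, std::vector<int>> opt_input_shape = {{"image", {1, 1, 3, 3}}};
config.EnableTensorRtEngine(1 << 30, 1, 1, AnalysisConfig::Precision::kFloat32,
                            false, true, min_input_shape, max_input_shape,
                            opt_input_shape);
auto predictor = CreatePaddlePredictor(config);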
......@@ -137,8 +137,8 @@ RUN curl -s -q https://glide.sh/get | sh
RUN wget -q https://paddlepaddledeps.bj.bcebos.com/TensorRT-4.0.1.6-ubuntu14.04.x86_64-gnu.cuda.8.0.cudnn7.0.tar.gz --no-check-certificate && \
tar -zxf TensorRT-4.0.1.6-ubuntu14.04.x86_64-gnu.cuda.8.0.cudnn7.0.tar.gz -C /usr/local && \
cp -rf /usr/local/TensorRT/include /usr && \
cp -rf /usr/local/TensorRT/lib /usr
cp -rf /usr/local/TensorRT/include/* /usr/include/ && \
cp -rf /usr/local/TensorRT/lib/* /usr/lib/
# git credential to skip password typing
RUN git config --global credential.helper store
......
......@@ -59,6 +59,7 @@ struct Argument {
using unique_ptr_t = std::unique_ptr<void, std::function<void(void*)>>;
using fusion_statis_t = std::unordered_map<std::string, int>;
using input_shape_t = std::map<std::string, std::vector<int>>;
bool Has(const std::string& key) const { return valid_fields_.count(key); }
// If we set the model using config.SetModelBuffer,
......@@ -174,6 +175,12 @@ struct Argument {
DECL_ARGUMENT_FIELD(use_gpu, UseGPU, bool);
DECL_ARGUMENT_FIELD(use_fc_padding, UseFcPadding, bool);
DECL_ARGUMENT_FIELD(gpu_device_id, GPUDeviceId, int);
// Usually used for TRT dynamic shape.
DECL_ARGUMENT_FIELD(min_input_shape, MinInputShape, input_shape_t);
DECL_ARGUMENT_FIELD(max_input_shape, MaxInputShape, input_shape_t);
DECL_ARGUMENT_FIELD(optim_input_shape, OptimInputShape, input_shape_t);
DECL_ARGUMENT_FIELD(use_tensorrt, UseTensorRT, bool);
DECL_ARGUMENT_FIELD(tensorrt_max_batch_size, TensorRtMaxBatchSize, int);
DECL_ARGUMENT_FIELD(tensorrt_workspace_size, TensorRtWorkspaceSize, int);
......
......@@ -123,6 +123,13 @@ void IRPassManager::CreatePasses(Argument *argument,
pass->Set("gpu_device_id", new int(argument->gpu_device_id()));
pass->Set("use_static_engine", new bool(use_static_engine));
pass->Set("model_from_memory", new bool(argument->model_from_memory()));
pass->Set("max_input_shape", new std::map<std::string, std::vector<int>>(
argument->max_input_shape()));
pass->Set("min_input_shape", new std::map<std::string, std::vector<int>>(
argument->min_input_shape()));
pass->Set("optim_input_shape",
new std::map<std::string, std::vector<int>>(
argument->optim_input_shape()));
}
if (pass_name == "ngraph_subgraph_pass") {
pass->Set("program",
......
......@@ -166,6 +166,12 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
auto enable_int8 = Get<bool>("enable_int8");
auto use_calib_mode = Get<bool>("use_calib_mode");
auto &subgraph_nodes = *framework::ir::Agent(node).subgraph();
auto min_input_shape =
Get<std::map<std::string, std::vector<int>>>("min_input_shape");
auto max_input_shape =
Get<std::map<std::string, std::vector<int>>>("max_input_shape");
auto opt_input_shape =
Get<std::map<std::string, std::vector<int>>>("optim_input_shape");
// The following procedure is used to rename all the intermediate
// variables and the output variables of the subgraph.
......@@ -263,11 +269,33 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
std::copy(params_not_shared.begin(), params_not_shared.end(),
std::back_inserter(*repetitive_params));
// Check trt version for dynamic shape input.
if (min_input_shape.size() > 0 && TRT_VERSION < 6000) {
LOG_FIRST_N(WARNING, 1) << "You are using the dynamic shape input mode of "
"Paddle-TRT, but the TensorRT version found is "
"less than 6.0, so the static shape mode is "
"used instead.";
min_input_shape = {};
max_input_shape = {};
opt_input_shape = {};
}
if (min_input_shape.size() > 0 && TRT_VERSION > 6000) {
LOG_FIRST_N(WARNING, 1)
<< "The Paddle lib links against TensorRT " << TRT_VERSION / 1000.
<< "; make sure the TensorRT version at runtime is no lower than this "
"version, otherwise a segfault may occur!";
}
tensorrt::TensorRTEngine *trt_engine =
inference::Singleton<inference::tensorrt::TRTEngineManager>::Global()
.Create(engine_key + std::to_string(predictor_id),
Get<int>("max_batch_size"), Get<int>("workspace_size"),
precision_mode, calibrator.get(), Get<int>("gpu_device_id"));
precision_mode, calibrator.get(), Get<int>("gpu_device_id"),
min_input_shape, max_input_shape, opt_input_shape);
bool need_serialize = (use_static_engine && !load_from_memory);
if (need_serialize) {
......
......@@ -125,6 +125,9 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
// Quantization related.
CP_MEMBER(use_mkldnn_quantizer_);
CP_MEMBER(mkldnn_quantizer_config_);
CP_MEMBER(min_input_shape_);
CP_MEMBER(max_input_shape_);
CP_MEMBER(optim_input_shape_);
CP_MEMBER(use_lite_);
CP_MEMBER(lite_precision_mode_);
......@@ -223,7 +226,10 @@ MkldnnQuantizerConfig *AnalysisConfig::mkldnn_quantizer_config() const {
void AnalysisConfig::EnableTensorRtEngine(
int workspace_size, int max_batch_size, int min_subgraph_size,
AnalysisConfig::Precision precision_mode, bool use_static,
bool use_calib_mode) {
bool use_calib_mode,
std::map<std::string, std::vector<int>> min_input_shape,
std::map<std::string, std::vector<int>> max_input_shape,
std::map<std::string, std::vector<int>> optim_input_shape) {
#ifdef PADDLE_WITH_CUDA
if (!use_gpu()) {
LOG(ERROR) << "To use TensorRT engine, please call EnableGpu() first";
......@@ -237,6 +243,9 @@ void AnalysisConfig::EnableTensorRtEngine(
tensorrt_precision_mode_ = precision_mode;
trt_use_static_engine_ = use_static;
trt_use_calib_mode_ = use_calib_mode;
min_input_shape_ = min_input_shape;
max_input_shape_ = max_input_shape;
optim_input_shape_ = optim_input_shape;
Update();
#else
......
......@@ -425,6 +425,9 @@ void AnalysisPredictor::PrepareArgument() {
argument_.SetTensorRtPrecisionMode(config_.tensorrt_precision_mode_);
argument_.SetTensorRtUseStaticEngine(config_.trt_use_static_engine_);
argument_.SetTensorRtUseCalibMode(config_.trt_use_calib_mode_);
argument_.SetMinInputShape(config_.min_input_shape_);
argument_.SetMaxInputShape(config_.max_input_shape_);
argument_.SetOptimInputShape(config_.optim_input_shape_);
}
if (config_.lite_engine_enabled()) {
......
......@@ -160,11 +160,13 @@ struct AnalysisConfig {
* @param min_subgraph_size the minimum TensorRT subgraph size needed; if a
* subgraph is smaller than this, it will not be converted to a TensorRT engine.
*/
void EnableTensorRtEngine(int workspace_size = 1 << 20,
int max_batch_size = 1, int min_subgraph_size = 3,
Precision precision = Precision::kFloat32,
bool use_static = false,
bool use_calib_mode = true);
void EnableTensorRtEngine(
int workspace_size = 1 << 20, int max_batch_size = 1,
int min_subgraph_size = 3, Precision precision = Precision::kFloat32,
bool use_static = false, bool use_calib_mode = true,
std::map<std::string, std::vector<int>> min_input_shape = {},
std::map<std::string, std::vector<int>> max_input_shape = {},
std::map<std::string, std::vector<int>> optim_input_shape = {});
/** A boolean state telling whether the TensorRT engine is used.
*/
bool tensorrt_engine_enabled() const { return use_tensorrt_; }
......@@ -348,6 +350,9 @@ struct AnalysisConfig {
std::string serialized_info_cache_;
mutable std::unique_ptr<PassStrategy> pass_builder_;
std::map<std::string, std::vector<int>> min_input_shape_;
std::map<std::string, std::vector<int>> max_input_shape_;
std::map<std::string, std::vector<int>> optim_input_shape_;
bool use_lite_{false};
std::vector<std::string> lite_passes_filter_;
......
......@@ -49,8 +49,12 @@ class ElementwiseWeightOpConverter : public OpConverter {
auto* X = engine_->GetITensor(op_desc.Input("X").front());
nvinfer1::Dims dims_x = X->getDimensions();
PADDLE_ENFORCE(dims_x.nbDims >= 3, "x dims expects at least 3, but %d is given.",
dims_x.nbDims);
std::vector<int> no_batch_dims;
int start_index = 0;
if (engine_->with_dynamic_shape()) start_index = 1;
for (; start_index < dims_x.nbDims; start_index++)
no_batch_dims.push_back(dims_x.d[start_index]);
auto* Y_v = scope.FindVar(op_desc.Input("Y").front());
PADDLE_ENFORCE_NOT_NULL(Y_v);
......@@ -62,23 +66,23 @@ class ElementwiseWeightOpConverter : public OpConverter {
auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;
std::vector<int> dims_y = framework::vectorize<int>(Y_t->dims());
if (static_cast<int>(dims_y.size()) == dims_x.nbDims + 1) {
if (dims_y.size() == no_batch_dims.size() + 1) {
if (dims_y[0] == 1) dims_y.erase(dims_y.begin());
}
if (static_cast<int>(dims_y.size()) == 1 && dims_y[0] == dims_x.d[0]) {
if (dims_y.size() == 1 && dims_y[0] == no_batch_dims[0]) {
scale_mode = nvinfer1::ScaleMode::kCHANNEL;
} else if (static_cast<int>(dims_y.size()) == dims_x.nbDims &&
dims_y[0] == dims_x.d[0]) {
} else if (dims_y.size() == no_batch_dims.size() &&
dims_y[0] == no_batch_dims[0]) {
scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;
for (int i = 1; i < dims_x.nbDims; i++) {
if (dims_y[i] != dims_x.d[i]) {
for (size_t i = 1; i < no_batch_dims.size(); i++) {
if (dims_y[i] != no_batch_dims[i]) {
scale_mode = nvinfer1::ScaleMode::kCHANNEL;
break;
}
}
if (scale_mode == nvinfer1::ScaleMode::kCHANNEL) {
for (int i = 1; i < dims_x.nbDims; i++) {
for (size_t i = 1; i < no_batch_dims.size(); i++) {
if (dims_y[i] != 1)
PADDLE_THROW(
"TensorRT unsupported weight shape for Elementwise op!");
......
......@@ -23,52 +23,13 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/utils/singleton.h"
namespace paddle {
namespace inference {
namespace tensorrt {
using FluidDT = framework::proto::VarType_Type;
using TRT_DT = nvinfer1::DataType;
namespace { // NOLINT
TRT_DT FluidDataType2TRT(FluidDT type) {
switch (type) {
case FluidDT::VarType_Type_FP32:
return TRT_DT::kFLOAT;
case FluidDT::VarType_Type_INT32:
return TRT_DT::kINT32;
default:
return TRT_DT::kINT32;
}
PADDLE_THROW(platform::errors::InvalidArgument(
"unknown fluid datatype in TRT op converter"));
return TRT_DT::kINT32;
}
nvinfer1::Dims Vec2TRT_Dims(const std::vector<int64_t>& shape,
std::string input) {
PADDLE_ENFORCE_GT(shape.size(), 1UL,
platform::errors::InvalidArgument(
"TensorRT's tensor input requires at least 2 "
"dimensions, but input %s has %d dims.",
input, shape.size()));
PADDLE_ENFORCE_LE(shape.size(), 4UL,
platform::errors::InvalidArgument(
"TensorRT's tensor input requires at most 4 "
"dimensions, but input %s has %d dims.",
input, shape.size()));
if (shape.size() == 4UL)
return nvinfer1::DimsCHW(shape[1], shape[2], shape[3]);
else if (shape.size() == 3UL)
return nvinfer1::Dims2(shape[1], shape[2]);
return nvinfer1::DimsCHW(shape[1], 1, 1);
}
} // namespace // NOLINT
/*
* Convert Op from Fluid to TensorRT Engine.
*/
......@@ -167,11 +128,37 @@ class OpConverter {
PADDLE_ENFORCE_EQ(var->GetType(), FluidDT::VarType_Type_LOD_TENSOR,
"TensorRT engine only takes LoDTensor as input");
auto var_shape = var->GetShape();
engine->DeclareInput(
input, FluidDataType2TRT(
var->Proto()->type().lod_tensor().tensor().data_type()),
Vec2TRT_Dims(var_shape, input));
if (engine->with_dynamic_shape()) {
#if IS_TRT_VERSION_GE(6000)
auto min_input_shape = engine->min_input_shape()[input];
auto max_input_shape = engine->max_input_shape()[input];
auto optim_input_shape = engine->optim_input_shape()[input];
size_t ranks = min_input_shape.size();
std::vector<int64_t> input_shape;
input_shape.push_back(-1);
for (size_t i = 1; i < ranks; i++) {
if (min_input_shape[i] != max_input_shape[i]) {
input_shape.push_back(-1);
} else {
input_shape.push_back(min_input_shape[i]);
// the i-th dimension should be the same.
PADDLE_ENFORCE_EQ(min_input_shape[i], optim_input_shape[i],
platform::errors::InvalidArgument(
"The dim (%d) of the min_input_shape and "
"optim_input_shape should be same.",
i));
}
}
engine->DeclareInput(
input, FluidDataType2TRT(
var->Proto()->type().lod_tensor().tensor().data_type()),
Vec2TRT_Dims(input_shape, input, true));
#endif
} else {
engine->DeclareInput(
input, FluidDataType2TRT(
var->Proto()->type().lod_tensor().tensor().data_type()),
Vec2TRT_Dims(var_shape, input));
}
}
framework::proto::BlockDesc* block_proto = block_desc->Proto();
ConvertBlock(*block_proto, parameters, scope, engine);
......
......@@ -28,23 +28,35 @@ namespace tensorrt {
int TensorRTEngine::runtime_batch_ = 1;
void TensorRTEngine::Build(const DescType &paddle_model) {
PADDLE_ENFORCE(false, "not implemented");
void TensorRTEngine::InitNetwork() {
freshDeviceId();
infer_builder_.reset(createInferBuilder(&logger_));
if (with_dynamic_shape_) {
#if IS_TRT_VERSION_GE(6000)
infer_networkv2_.reset(infer_builder_->createNetworkV2(
1U << static_cast<int>(
nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH)));
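// Note (added for clarity): kEXPLICIT_BATCH makes the batch dimension an
// explicit part of every tensor's dims in the network, which is why the op
// converter declares dynamic-shape inputs with a leading -1 instead of
// relying on TensorRT's implicit batch size.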
infer_builder_config_.reset(infer_builder_->createBuilderConfig());
optim_profile_.reset(infer_builder_->createOptimizationProfile());
#endif
} else {
infer_network_.reset(infer_builder_->createNetwork());
}
}
void TensorRTEngine::Execute(int batch_size, std::vector<void *> *buffers,
cudaStream_t stream) {
freshDeviceId();
const std::thread::id tid = std::this_thread::get_id();
batch_size_ = batch_size;
if (infer_context_.find(tid) == infer_context_.end()) {
std::unique_lock<std::mutex> lock(mutex_);
PADDLE_ENFORCE_NOT_NULL(
infer_engine_,
"You should build engine first and then set the context.");
infer_context_[tid].reset(infer_engine_->createExecutionContext());
auto infer_context = context();
if (!with_dynamic_shape()) {
infer_context->enqueue(batch_size, buffers->data(), stream, nullptr);
} else {
#if IS_TRT_VERSION_GE(6000)
infer_context->enqueueV2(buffers->data(), stream, nullptr);
#endif
}
infer_context_[tid]->enqueue(batch_size, buffers->data(), stream, nullptr);
SetRuntimeBatch(batch_size);
}
......@@ -53,8 +65,9 @@ void TensorRTEngine::FreezeNetwork() {
VLOG(3) << "TRT to freeze network";
PADDLE_ENFORCE(infer_builder_ != nullptr,
"Call InitNetwork first to initialize network.");
PADDLE_ENFORCE(infer_network_ != nullptr,
"Call InitNetwork first to initialize network.");
PADDLE_ENFORCE_EQ(network() != nullptr, true,
platform::errors::InvalidArgument(
"Call InitNetwork first to initialize network."));
// build engine.
infer_builder_->setMaxBatchSize(max_batch_);
infer_builder_->setMaxWorkspaceSize(max_workspace_);
......@@ -66,6 +79,8 @@ void TensorRTEngine::FreezeNetwork() {
if (!support_fp16) {
LOG(INFO) << "You specified FP16 mode, but the hardware does not support "
"FP16 speedup; FP32 will be used instead.";
} else {
LOG(INFO) << "Run Paddle-TRT FP16 mode";
}
}
#else
......@@ -92,14 +107,14 @@ void TensorRTEngine::FreezeNetwork() {
}
std::unordered_set<nvinfer1::ITensor *> all_t;
for (int i = 0; i < infer_network_->getNbLayers(); i++) {
auto layer = infer_network_->getLayer(i);
for (int i = 0; i < network()->getNbLayers(); i++) {
auto layer = network()->getLayer(i);
for (int j = 0; j < layer->getNbOutputs(); j++) {
all_t.insert(layer->getOutput(j));
}
}
for (int i = 0; i < infer_network_->getNbInputs(); i++) {
all_t.insert(infer_network_->getInput(i));
for (int i = 0; i < network()->getNbInputs(); i++) {
all_t.insert(network()->getInput(i));
}
for (auto &t : all_t) {
......@@ -110,14 +125,14 @@ void TensorRTEngine::FreezeNetwork() {
}
}
std::unordered_set<std::string> all_out_t_name;
for (int i = 0; i < infer_network_->getNbOutputs(); i++) {
auto *temp = infer_network_->getOutput(i);
for (int i = 0; i < network()->getNbOutputs(); i++) {
auto *temp = network()->getOutput(i);
temp->setDynamicRange(-1, 1);
all_out_t_name.insert(temp->getName());
}
for (int i = 0; i < infer_network_->getNbLayers(); i++) {
auto layer = infer_network_->getLayer(i);
for (int i = 0; i < network()->getNbLayers(); i++) {
auto layer = network()->getLayer(i);
for (int j = 0; j < layer->getNbOutputs(); j++) {
auto *temp_out = layer->getOutput(j);
if (std::find(all_out_t_name.begin(), all_out_t_name.end(),
......@@ -127,26 +142,41 @@ void TensorRTEngine::FreezeNetwork() {
}
}
}
#endif
}
}
infer_engine_.reset(infer_builder_->buildCudaEngine(*infer_network_));
if (with_dynamic_shape_) {
#if IS_TRT_VERSION_GE(6000)
for (auto &input : min_input_shape_) {
optim_profile_->setDimensions(
input.first.c_str(), nvinfer1::OptProfileSelector::kMIN,
Vec2TRT_Dims(input.second, input.first, true));
optim_profile_->setDimensions(
input.first.c_str(), nvinfer1::OptProfileSelector::kMAX,
Vec2TRT_Dims(max_input_shape_[input.first], input.first, true));
optim_profile_->setDimensions(
input.first.c_str(), nvinfer1::OptProfileSelector::kOPT,
Vec2TRT_Dims(optim_input_shape_[input.first], input.first, true));
}
infer_builder_config_->addOptimizationProfile(optim_profile_.get());
infer_engine_.reset(infer_builder_->buildEngineWithConfig(
*network(), *infer_builder_config_));
#endif
} else {
infer_engine_.reset(infer_builder_->buildCudaEngine(*network()));
}
PADDLE_ENFORCE(infer_engine_ != nullptr, "build cuda engine failed!");
}
nvinfer1::ITensor *TensorRTEngine::DeclareInput(const std::string &name,
nvinfer1::DataType dtype,
const nvinfer1::Dims &dims) {
PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate input name %s",
name);
PADDLE_ENFORCE(infer_network_ != nullptr, "should initnetwork first");
auto *input = infer_network_->addInput(name.c_str(), dtype, dims);
PADDLE_ENFORCE_EQ(network() != nullptr, true,
platform::errors::InvalidArgument(
"The TRT network should be initialized first."));
auto *input = network()->addInput(name.c_str(), dtype, dims);
PADDLE_ENFORCE(input, "infer network add input %s failed", name);
buffer_sizes_[name] = kDataTypeSize[static_cast<int>(dtype)] *
analysis::AccuDims(dims.d, dims.nbDims) * max_batch_;
PADDLE_ENFORCE(input->isNetworkInput());
TensorRTEngine::SetITensor(name, input);
return input;
......@@ -154,37 +184,21 @@ nvinfer1::ITensor *TensorRTEngine::DeclareInput(const std::string &name,
void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer *layer, int offset,
const std::string &name) {
PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate output name %s",
name);
auto *output = layer->getOutput(offset);
SetITensor(name, output);
PADDLE_ENFORCE(output != nullptr);
output->setName(name.c_str());
PADDLE_ENFORCE(!output->isNetworkInput());
infer_network_->markOutput(*output);
network()->markOutput(*output);
PADDLE_ENFORCE(output->isNetworkOutput());
// output buffers' size can only be decided later, set zero here to mark this
// and will reset later.
buffer_sizes_[name] = 0;
}
bool TensorRTEngine::HasDeclared(const std::string &name) {
return buffer_sizes_.count(name) > 0;
}
void TensorRTEngine::DeclareOutput(const std::string &name) {
PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate output name %s",
name);
auto *output = TensorRTEngine::GetITensor(name);
PADDLE_ENFORCE(output != nullptr);
output->setName(name.c_str());
PADDLE_ENFORCE(!output->isNetworkInput());
infer_network_->markOutput(*output);
// output buffers' size can only be decided later, set zero here to mark this
// and will reset later.
buffer_sizes_[name] = 0;
network()->markOutput(*output);
}
void TensorRTEngine::SetITensor(const std::string &name,
......@@ -253,7 +267,7 @@ nvinfer1::IPluginLayer *TensorRTEngine::AddPlugin(
nvinfer1::ITensor *const *inputs, int num_inputs,
plugin::PluginTensorRT *plugin) {
owned_plugin_.emplace_back(plugin);
return infer_network_.get()->addPluginExt(inputs, num_inputs, *plugin);
return network()->addPluginExt(inputs, num_inputs, *plugin);
}
void TensorRTEngine::freshDeviceId() {
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <NvInfer.h>
#include <map>
#include <memory>
#include <mutex> // NOLINT
#include <string>
#include <unordered_map>
#include <unordered_set>
......@@ -36,6 +37,57 @@ namespace paddle {
namespace inference {
namespace tensorrt {
using FluidDT = framework::proto::VarType_Type;
using TRT_DT = nvinfer1::DataType;
namespace { // NOLINT
TRT_DT FluidDataType2TRT(FluidDT type) {
switch (type) {
case FluidDT::VarType_Type_FP32:
return TRT_DT::kFLOAT;
case FluidDT::VarType_Type_INT32:
return TRT_DT::kINT32;
default:
return TRT_DT::kINT32;
}
PADDLE_THROW(platform::errors::InvalidArgument(
"unknown fluid datatype in TRT op converter"));
return TRT_DT::kINT32;
}
// The T can be int32 or int64 type.
template <typename T>
nvinfer1::Dims Vec2TRT_Dims(const std::vector<T>& shape, std::string input,
bool with_dynamic_shape = false) {
PADDLE_ENFORCE_GT(shape.size(), 1UL,
platform::errors::InvalidArgument(
"TensorRT's tensor input requires at least 2 "
"dimensions, but input %s has %d dims.",
input, shape.size()));
PADDLE_ENFORCE_LE(shape.size(), 4UL,
platform::errors::InvalidArgument(
"TensorRT's tensor input requires at most 4 "
"dimensions, but input %s has %d dims.",
input, shape.size()));
if (!with_dynamic_shape) {
if (shape.size() == 4UL) {
return nvinfer1::DimsCHW(shape[1], shape[2], shape[3]);
} else if (shape.size() == 3UL) {
return nvinfer1::Dims2(shape[1], shape[2]);
}
return nvinfer1::DimsCHW(shape[1], 1, 1);
} else {
if (shape.size() == 4UL) {
return nvinfer1::DimsNCHW(shape[0], shape[1], shape[2], shape[3]);
} else if (shape.size() == 3UL) {
return nvinfer1::Dims3(shape[0], shape[1], shape[2]);
}
return nvinfer1::Dims4(shape[0], shape[1], 1, 1);
}
}
} // NOLINT
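// Illustrative example (editor's sketch, not part of the original source):
// for a 4-D input, Vec2TRT_Dims({8, 3, 224, 224}, "x") drops the batch
// dimension and returns DimsCHW(3, 224, 224), while
// Vec2TRT_Dims({-1, 3, 224, 224}, "x", true) keeps it and returns
// DimsNCHW(-1, 3, 224, 224) for the explicit-batch (dynamic shape) network.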
class TRTInt8Calibrator;
/*
* TensorRT Engine.
......@@ -45,6 +97,7 @@ class TRTInt8Calibrator;
*/
class TensorRTEngine {
using DescType = ::paddle::framework::proto::BlockDesc;
using ShapeMapType = std::map<std::string, std::vector<int>>;
public:
// Weight is model parameter.
......@@ -68,33 +121,44 @@ class TensorRTEngine {
int max_batch, int max_workspace,
AnalysisConfig::Precision precision = AnalysisConfig::Precision::kFloat32,
TRTInt8Calibrator* calibrator = nullptr, int device_id = 0,
const ShapeMapType min_input_shape = {},
const ShapeMapType max_input_shape = {},
const ShapeMapType optim_input_shape = {},
nvinfer1::ILogger& logger = NaiveLogger::Global())
: max_batch_(max_batch),
max_workspace_(max_workspace),
precision_(precision),
calibrator_(calibrator),
device_id_(device_id),
logger_(logger) {}
min_input_shape_(min_input_shape),
max_input_shape_(max_input_shape),
optim_input_shape_(optim_input_shape),
logger_(logger) {
if (min_input_shape_.size() != 0 && max_input_shape_.size() != 0 &&
optim_input_shape_.size() != 0) {
PADDLE_ENFORCE_EQ(
min_input_shape_.size(), max_input_shape_.size(),
platform::errors::InvalidArgument(
"The min_input_shape_'s size(%d) should be equal to the "
"size(%d) of max_input_shape_",
min_input_shape_.size(), max_input_shape_.size()));
PADDLE_ENFORCE_EQ(
min_input_shape_.size(), optim_input_shape_.size(),
platform::errors::InvalidArgument(
"The min_input_shape_'s size(%d) should be equal to the "
"size(%d) of optim_input_shape_",
min_input_shape_.size(), optim_input_shape_.size()));
#if IS_TRT_VERSION_GE(6000)
with_dynamic_shape_ = true;
#else
LOG(WARNING) << "Using TRT dynamic shape requires a TensorRT version "
"of at least 6.";
#endif
}
}
~TensorRTEngine() {}
// TODO(Superjomn) implement it later when graph segmentation is supported.
void Build(const DescType& paddle_model);
void Execute(int batch_size, std::vector<void*>* buffers,
cudaStream_t stream = nullptr);
// Initialize the inference network, so that TensorRT layers can add to this
// network.
void InitNetwork() {
freshDeviceId();
infer_builder_.reset(createInferBuilder(&logger_));
infer_network_.reset(infer_builder_->createNetwork());
}
// After finishing adding ops, freeze this network and creates the execution
// environment.
void FreezeNetwork();
// Add an input and set its name, data type and dimension.
nvinfer1::ITensor* DeclareInput(const std::string& name,
nvinfer1::DataType dtype,
......@@ -105,15 +169,24 @@ class TensorRTEngine {
const std::string& name);
// Set the itensor_map_[name] as the network's output, and set its name.
void DeclareOutput(const std::string& name);
// Check if the ITensor has been declared
bool HasDeclared(const std::string& name);
void SetITensor(const std::string& name, nvinfer1::ITensor* tensor);
// Get an ITensor called name.
nvinfer1::ITensor* GetITensor(const std::string& name);
nvinfer1::ICudaEngine* engine() { return infer_engine_.get(); }
nvinfer1::INetworkDefinition* network() { return infer_network_.get(); }
nvinfer1::IExecutionContext* context() {
std::unique_lock<std::mutex> lock(mutex_);
const std::thread::id tid = std::this_thread::get_id();
if (infer_context_.find(tid) == infer_context_.end()) {
PADDLE_ENFORCE_NOT_NULL(
infer_engine_,
platform::errors::InvalidArgument(
"You should build engine first and then set the context."));
infer_context_[tid].reset(infer_engine_->createExecutionContext());
}
return infer_context_[tid].get();
}
nvinfer1::IHostMemory* Serialize() {
PADDLE_ENFORCE(infer_engine_ != nullptr,
......@@ -170,6 +243,30 @@ class TensorRTEngine {
}
}
// NOTE: The functions below were modified to support dynamic shape.
// Initialize the inference network, so that TensorRT layers can add to this
// network.
void InitNetwork();
// After finishing adding ops, freeze this network and creates the execution
// environment.
void FreezeNetwork();
void Execute(int batch_size, std::vector<void*>* buffers,
cudaStream_t stream = nullptr);
nvinfer1::INetworkDefinition* network() {
if (with_dynamic_shape_) {
return infer_networkv2_.get();
} else {
return infer_network_.get();
}
}
ShapeMapType min_input_shape() { return min_input_shape_; }
ShapeMapType max_input_shape() { return max_input_shape_; }
ShapeMapType optim_input_shape() { return optim_input_shape_; }
bool with_dynamic_shape() { return with_dynamic_shape_; }
private:
// Each ICudaEngine object is bound to a specific GPU when it is instantiated,
// ensure that the thread is associated with the correct device by calling
......@@ -189,10 +286,12 @@ class TensorRTEngine {
int batch_size_{-1};
int device_id_;
ShapeMapType min_input_shape_;
ShapeMapType max_input_shape_;
ShapeMapType optim_input_shape_;
nvinfer1::ILogger& logger_;
// max data size for the buffers.
std::unordered_map<std::string /*name*/, size_t /*max size*/> buffer_sizes_;
std::unordered_map<std::string /*name*/, nvinfer1::ITensor* /*ITensor*/>
itensor_map_;
......@@ -216,13 +315,17 @@ class TensorRTEngine {
infer_context_;
infer_ptr<nvinfer1::IHostMemory> ihost_memory_;
std::unordered_map<nvinfer1::ITensor*, float> quant_dynamic_range_;
// For dynamic shape
bool with_dynamic_shape_{false};
infer_ptr<nvinfer1::INetworkDefinition> infer_networkv2_;
#if IS_TRT_VERSION_GE(6000)
infer_ptr<nvinfer1::IBuilderConfig> infer_builder_config_;
std::unique_ptr<nvinfer1::IOptimizationProfile> optim_profile_;
#endif
std::mutex mutex_;
}; // class TensorRTEngine
#define IS_TRT_VERSION_GE(version) \
((NV_TENSORRT_MAJOR * 1000 + NV_TENSORRT_MINOR * 100 + \
NV_TENSORRT_PATCH * 10 + NV_TENSORRT_BUILD) >= version)
// Add a layer__ into engine__ with args ARGS.
// For example:
//
......@@ -252,9 +355,13 @@ class TRTEngineManager {
std::string name, int max_batch, int max_workspace,
AnalysisConfig::Precision precision = AnalysisConfig::Precision::kFloat32,
TRTInt8Calibrator* calibrator = nullptr, int device_id = 0,
const std::map<std::string, std::vector<int>> min_input_shape = {},
const std::map<std::string, std::vector<int>> max_input_shape = {},
const std::map<std::string, std::vector<int>> optim_input_shape = {},
nvinfer1::ILogger& logger = NaiveLogger::Global()) {
auto* p = new TensorRTEngine(max_batch, max_workspace, precision,
calibrator, device_id, logger);
calibrator, device_id, min_input_shape,
max_input_shape, optim_input_shape, logger);
engines_[name].reset(p);
return p;
}
......
......@@ -27,6 +27,14 @@ namespace paddle {
namespace inference {
namespace tensorrt {
#define IS_TRT_VERSION_GE(version) \
((NV_TENSORRT_MAJOR * 1000 + NV_TENSORRT_MINOR * 100 + \
NV_TENSORRT_PATCH * 10 + NV_TENSORRT_BUILD) >= version)
#define TRT_VERSION                                      \
  (NV_TENSORRT_MAJOR * 1000 + NV_TENSORRT_MINOR * 100 +  \
   NV_TENSORRT_PATCH * 10 + NV_TENSORRT_BUILD)
namespace dy = paddle::platform::dynload;
// TensorRT data type to size
......@@ -103,6 +111,14 @@ class NaiveProfiler : public nvinfer1::IProfiler {
}
};
inline size_t ProductDim(const nvinfer1::Dims& dims) {
size_t v = 1;
for (int i = 0; i < dims.nbDims; i++) {
v *= dims.d[i];
}
return v;
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
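A quick worked example of the two macros above (a sketch assuming the TensorRT 6.0.1.5 package installed by the updated CUDA 9 Dockerfile, i.e. NV_TENSORRT_MAJOR=6, NV_TENSORRT_MINOR=0, NV_TENSORRT_PATCH=1, NV_TENSORRT_BUILD=5):

// TRT_VERSION             = 6 * 1000 + 0 * 100 + 1 * 10 + 5 = 6015
// IS_TRT_VERSION_GE(6000) evaluates to (6015 >= 6000), i.e. true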
......@@ -374,6 +374,14 @@ if(WITH_GPU AND TENSORRT_FOUND)
inference_analysis_test(trt_quant_int8_test SRCS trt_quant_int8_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TRT_MODEL_QUANT_RESNET_DIR})
set(TEST_TRT_DYNAMIC_MODEL "${TRT_MODEL_INSTALL_DIR}/test_trt_dy_conv")
if (NOT EXISTS ${TEST_TRT_DYNAMIC_MODEL})
inference_download_and_uncompress(${TEST_TRT_DYNAMIC_MODEL} ${INFERENCE_URL}/tensorrt_test "test_trt_dy_conv.tar.gz")
endif()
inference_analysis_test(trt_dynamic_shape_test SRCS trt_dynamic_shape_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TEST_TRT_DYNAMIC_MODEL})
endif()
set(LITE_MODEL_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/lite")
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/fluid/inference/tests/api/trt_test_helper.h"
namespace paddle {
namespace inference {
TEST(AnalysisPredictor, use_gpu) {
std::string model_dir = FLAGS_infer_model + "/test_trt_dy_conv";
AnalysisConfig config;
config.EnableUseGpu(100, 0);
config.SetModel(model_dir);
config.SwitchUseFeedFetchOps(false);
// Set the input's min, max, opt shape
std::map<std::string, std::vector<int>> min_input_shape = {
{"image", {1, 1, 3, 3}}};
std::map<std::string, std::vector<int>> max_input_shape = {
{"image", {1, 1, 10, 10}}};
std::map<std::string, std::vector<int>> opt_input_shape = {
{"image", {1, 1, 3, 3}}};
config.EnableTensorRtEngine(
1 << 30, 1, 1, AnalysisConfig::Precision::kFloat32, false, true,
min_input_shape, max_input_shape, opt_input_shape);
auto predictor = CreatePaddlePredictor(config);
auto input_names = predictor->GetInputNames();
int channels = 1;
int height = 3;
int width = 3;
int input_num = channels * height * width * 1;
float *input = new float[input_num];
memset(input, 0, input_num * sizeof(float));
auto input_t = predictor->GetInputTensor(input_names[0]);
input_t->Reshape({1, channels, height, width});
input_t->copy_from_cpu(input);
ASSERT_TRUE(predictor->ZeroCopyRun());
std::vector<float> out_data;
auto output_names = predictor->GetOutputNames();
auto output_t = predictor->GetOutputTensor(output_names[0]);
std::vector<int> output_shape = output_t->shape();
int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
std::multiplies<int>());
out_data.resize(out_num);
output_t->copy_to_cpu(out_data.data());
}
} // namespace inference
} // namespace paddle
......@@ -29,6 +29,7 @@
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
namespace paddle {
......@@ -40,6 +41,29 @@ using inference::tensorrt::TRTInt8Calibrator;
using inference::tensorrt::TRTCalibratorEngine;
using inference::tensorrt::TRTCalibratorEngineManager;
static void RuntimeStaticShapeCheck(std::vector<int64_t> runtime_input_shape,
std::vector<int64_t> model_input_shape) {
auto comma_fold = [](std::string a, int b) {
return std::move(a) + ", " + std::to_string(b);
};
std::string model_input_shape_str = std::accumulate(
std::next(model_input_shape.begin()), model_input_shape.end(),
std::to_string(model_input_shape[0]), comma_fold);
std::string runtime_input_shape_str = std::accumulate(
std::next(runtime_input_shape.begin()), runtime_input_shape.end(),
std::to_string(runtime_input_shape[0]), comma_fold);
PADDLE_ENFORCE_EQ(
model_input_shape == runtime_input_shape, true,
platform::errors::InvalidArgument(
"Input shapes are inconsistent with the model. Expect [%s] in "
"model description, but got [%s] in runtime. TRT 5 "
"or lower version "
"does not support dynamic input shapes. Please check and "
"modify "
"your input shapes.",
model_input_shape_str, runtime_input_shape_str));
}
class TensorRTEngineOp : public framework::OperatorBase {
private:
std::vector<std::string> input_names_;
......@@ -206,39 +230,28 @@ class TensorRTEngineOp : public framework::OperatorBase {
auto &t =
inference::analysis::GetFromScope<framework::LoDTensor>(scope, x);
auto t_shape = framework::vectorize<int64_t>(t.dims());
// check if the input shapes are consistent with model.
if (HasAttr(x + "_shape")) {
std::vector<int64_t> i_shape = Attr<std::vector<int64_t>>(x + "_shape");
std::vector<int64_t> model_input_shape(i_shape.begin() + 1,
i_shape.end());
std::vector<int64_t> runtime_input_shape(t_shape.begin() + 1,
t_shape.end());
auto comma_fold = [](std::string a, int b) {
return std::move(a) + ", " + std::to_string(b);
};
std::string model_input_shape_str = std::accumulate(
std::next(model_input_shape.begin()), model_input_shape.end(),
std::to_string(model_input_shape[0]), comma_fold);
std::string runtime_input_shape_str = std::accumulate(
std::next(runtime_input_shape.begin()), runtime_input_shape.end(),
std::to_string(runtime_input_shape[0]), comma_fold);
PADDLE_ENFORCE_EQ(
model_input_shape == runtime_input_shape, true,
platform::errors::InvalidArgument(
"Input shapes are inconsistent with the model. Expect [%s] in "
"model description, but got [%s] in runtime. TRT 5 "
"or lower version "
"does not support dynamic input shapes. Please check and "
"modify "
"your input shapes.",
model_input_shape_str, runtime_input_shape_str));
}
runtime_batch = t_shape[0];
const int bind_index = engine->engine()->getBindingIndex(x.c_str());
PADDLE_ENFORCE(bind_index < num_bindings,
"The bind index should be less than num_bindings");
if (!engine->with_dynamic_shape()) {
// check if the input shapes are consistent with model.
if (HasAttr(x + "_shape")) {
std::vector<int64_t> i_shape =
Attr<std::vector<int64_t>>(x + "_shape");
std::vector<int64_t> model_input_shape(i_shape.begin() + 1,
i_shape.end());
std::vector<int64_t> runtime_input_shape(t_shape.begin() + 1,
t_shape.end());
RuntimeStaticShapeCheck(runtime_input_shape, model_input_shape);
}
} else {
#if IS_TRT_VERSION_GE(6000)
auto *trt_context = engine->context();
trt_context->setBindingDimensions(
bind_index, inference::tensorrt::Vec2TRT_Dims(t_shape, x, true));
#endif
}
buffers[bind_index] = static_cast<void *>(t.data<float>());
}
......@@ -248,13 +261,20 @@ class TensorRTEngineOp : public framework::OperatorBase {
for (const auto &y : Outputs("Ys")) {
const int bind_index =
engine->engine()->getBindingIndex(output_maps[output_index].c_str());
auto dims = engine->engine()->getBindingDimensions(bind_index);
// Use the output ITensor's dims to reshape the Fluid Tensor.
// The ITensor doesn't contain the batch size dim.
std::vector<int> ddim;
ddim.push_back(runtime_batch);
for (int i = 0; i < dims.nbDims; i++) {
ddim.push_back(dims.d[i]);
if (!engine->with_dynamic_shape()) {
auto dims = engine->engine()->getBindingDimensions(bind_index);
ddim.push_back(runtime_batch);
for (int i = 0; i < dims.nbDims; i++) {
ddim.push_back(dims.d[i]);
}
} else {
#if IS_TRT_VERSION_GE(6000)
auto *trt_context = engine->context();
auto dims = trt_context->getBindingDimensions(bind_index);
for (int i = 0; i < dims.nbDims; i++) ddim.push_back(dims.d[i]);
#endif
}
auto *fluid_v = scope.FindVar(y);
PADDLE_ENFORCE_NOT_NULL(fluid_v, "no output variable called %s", y);
......@@ -289,7 +309,6 @@ class TensorRTEngineOp : public framework::OperatorBase {
runtime_batch, max_batch_size_));
// Execute the engine.
engine->Execute(runtime_batch, &buffers, stream);
cudaStreamSynchronize(stream);
}
TensorRTEngine *GetEngine(const framework::Scope &scope,
......
......@@ -412,7 +412,13 @@ void BindAnalysisConfig(py::module *m) {
py::arg("workspace_size") = 1 << 20, py::arg("max_batch_size") = 1,
py::arg("min_subgraph_size") = 3,
py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32,
py::arg("use_static") = false, py::arg("use_calib_mode") = true)
py::arg("use_static") = false, py::arg("use_calib_mode") = true,
py::arg("min_input_shape") =
std::map<std::string, std::vector<int>>({}),
py::arg("max_input_shape") =
std::map<std::string, std::vector<int>>({}),
py::arg("optim_input_shape") =
std::map<std::string, std::vector<int>>({}))
.def("tensorrt_engine_enabled", &AnalysisConfig::tensorrt_engine_enabled)
.def("switch_ir_debug", &AnalysisConfig::SwitchIrDebug,
py::arg("x") = true)
......
......@@ -6,8 +6,8 @@ REPO="${REPO:-paddlepaddle}"
cp -f ../../python/requirements.txt .
sed 's#FROM nvidia/cuda:8.0-cudnn7-devel-ubuntu16.04#FROM nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04#g' ../../Dockerfile |
sed 's#TensorRT-4.0.1.6-ubuntu14.04.x86_64-gnu.cuda.8.0.cudnn7.0.tar.gz#TensorRT_5.1_ga_cuda9_cudnnv7.5.tar.gz#g' |
sed 's#/usr/local/TensorRT#/usr/local/TensorRT_5.1_ga_cuda9_cudnnv7.5#g' |
sed 's#TensorRT-4.0.1.6-ubuntu14.04.x86_64-gnu.cuda.8.0.cudnn7.0.tar.gz#TensorRT-6.0.1.5.Ubuntu-16.04.x86_64-gnu.cuda-9.0.cudnn7.6.tar.gz#g' |
sed 's#/usr/local/TensorRT#/usr/local/TensorRT-6.0.1.5#g' |
sed 's#libnccl2=2.1.2-1+cuda8.0 libnccl-dev=2.1.2-1+cuda8.0#libnccl2=2.4.7-1+cuda9.0 libnccl-dev=2.4.7-1+cuda9.0#g' |
sed 's#COPY ./paddle/scripts/docker/root/#COPY ./docker/root/#g' |
sed 's#COPY ./python/requirements.txt#COPY ./requirements.txt#' > Dockerfile.cuda9.0-cudnn7
......