Unverified commit 6b10c0e5, authored by Yuanle Liu, committed by GitHub

[Inference] save_optimized_model_pass support tensorrt (#55893)

* fix cudnn 8.7+ bug on cudnnConvolutionBiasActivationForward

* save_optimized_model_pass support tensorrt

* update

* update

* fix compile

* update

* fix ut timeout
Parent 68b0cf92
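For context, the headline change lets the analysis pipeline serialize the TensorRT-optimized program into the optimization cache directory. Below is a minimal, hypothetical usage sketch with the public paddle_infer::Config API (not part of this commit); the model paths are placeholders and the exact option signatures may differ slightly between Paddle versions.

#include "paddle_inference_api.h"

// Sketch only: enable GPU + TensorRT, then ask the analysis passes to dump
// the optimized model into the optimization cache directory.
void BuildPredictorSavingOptimizedModel() {
  paddle_infer::Config config;
  config.SetModel("model_dir/inference.pdmodel",
                  "model_dir/inference.pdiparams");  // placeholder paths
  config.EnableUseGpu(/*memory_pool_init_size_mb=*/512, /*device_id=*/0);
  config.EnableTensorRtEngine(/*workspace_size=*/1 << 30,
                              /*max_batch_size=*/1,
                              /*min_subgraph_size=*/3,
                              paddle_infer::PrecisionType::kFloat32,
                              /*use_static=*/false,
                              /*use_calib_mode=*/false);
  config.SetOptimCacheDir("model_dir/_opt_cache");  // where the pass writes
  config.EnableSaveOptimModel(true);  // triggers save_optimized_model_pass
  auto predictor = paddle_infer::CreatePredictor(config);
}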
......@@ -249,7 +249,7 @@ void AutoMixedPrecisionPass::Init(Graph* graph) const {
subgraphes_[i] = graph->GetSubGraph(i);
all_op_nodes_[i] = TopologySortOperations(*subgraphes_[i]);
VLOG(4) << "subgraph " << i << " has " << all_op_nodes_[i].size()
<< "op nodes";
<< " op nodes";
for (auto* var_node : subgraphes_[i]->Nodes()) {
if (!var_node->IsVar()) continue;
......
......@@ -64,10 +64,6 @@ void NaiveExecutor::Run() {
VLOG(4) << std::this_thread::get_id() << " run "
<< op->DebugStringEx(scope_) << " on scope " << scope_;
op->SetIsCalledByExecutor(false);
#ifdef PADDLE_WITH_NVTX
platform::CudaNvtxRangePush(op->Type() + "|" + op->OutputVars(true).front(),
platform::NvtxRangeColor::Green);
#endif
for (auto &func : input_hookfuncs_) {
func(op.get(), scope_);
......@@ -77,7 +73,14 @@ void NaiveExecutor::Run() {
op->SetOutputHooks(output_hookfuncs_);
}
#ifdef PADDLE_WITH_NVTX
platform::CudaNvtxRangePush(op->Type() + "|" + op->OutputVars(true).front(),
platform::NvtxRangeColor::Green);
#endif
op->Run(*scope_, place_);
#ifdef PADDLE_WITH_NVTX
platform::CudaNvtxRangePop();
#endif
// Update the shared_holder so that only records the max one.
if (reuse_cache_.count(op.get())) {
......@@ -105,9 +108,6 @@ void NaiveExecutor::Run() {
}
}
#ifdef PADDLE_WITH_NVTX
platform::CudaNvtxRangePop();
#endif
for (auto &func : output_hookfuncs_) {
func(op.get(), scope_);
}
......
......@@ -14,8 +14,10 @@
// limitations under the License.
#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h"
#include <fcntl.h>
#include <cstddef>
#include <memory>
#include <string>
#include <unordered_set>
......@@ -32,6 +34,7 @@
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/op_teller.h"
#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"
#include "paddle/fluid/inference/utils/io_utils.h"
#include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h"
......@@ -124,11 +127,6 @@ void analysis::TensorRtSubgraphPass::ApplyImpl(
framework::ir::Graph *graph) const {
framework::ir::FusePassBase::Init("tensorrt_subgraph_pass", graph);
static std::once_flag trt_plugin_registered;
std::call_once(trt_plugin_registered, []() {
tensorrt::plugin::TrtPluginRegistry::Global()->RegistToTrt();
});
auto model_precision =
static_cast<phi::DataType>(Get<int>("model_precision"));
if (model_precision == phi::DataType::BFLOAT16) {
......@@ -291,7 +289,6 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp(
// Add new block for TensorRTEngineOP
const framework::BlockDesc &main_block =
program_desc->Block(framework::kRootBlockIndex);
// const framework::BlockDesc& main_block = program_desc->Block(0);
framework::BlockDesc *new_block = program_desc->AppendBlock(main_block);
// A fake block desc.
......@@ -319,9 +316,9 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp(
// is unique.
std::set<std::string> input_names;
std::set<std::string> input_names_with_id;
std::vector<std::string> params;
// if we delete fluid copy of params shared by more than 1 ops, there will be
// problem, so we filter them out.
std::vector<std::string> parameters;
// if we delete fluid copy of parameters shared by more than 1 ops, there will
// be problem, so we filter them out.
std::vector<std::string> params_not_shared;
auto *scope = param_scope();
......@@ -330,7 +327,7 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp(
input_names.insert(x->Name());
input_names_with_id.insert(x->Name() + std::to_string(x->id()));
if (std::count(graph_params.begin(), graph_params.end(), x->Name()) > 0) {
params.push_back(x->Name());
parameters.push_back(x->Name());
}
if (std::count(graph_params.begin(), graph_params.end(), x->Name()) > 0 &&
x->outputs.size() <= 1) {
......@@ -340,33 +337,15 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp(
// So we reserved a name for later use when casting INT64 -> INT32 or
// FP64->FP32. We must check whether scope has had the same name var!
if (x->Var()->GetDataType() == framework::proto::VarType::INT64) {
std::string tmp_name = x->Name() + "_cast_to_INT32";
LOG(WARNING)
<< "tensorrt_subgraph's input named " << x->Name()
<< " having int64 dtype in pdmodel description, we will cast them to "
"int32 dtype to feed them into paddle-trt.";
/*
PADDLE_ENFORCE_EQ(scope->FindVar(tmp_name),
nullptr,
platform::errors::InvalidArgument(
"The var name %s has exists in scope.",
tmp_name));
*/
scope->Var(tmp_name);
} else if (x->Var()->GetDataType() == framework::proto::VarType::FP64) {
std::string tmp_name = x->Name() + "_cast_to_FP32";
LOG(WARNING) << "tensorrt_subgraph's input named " << x->Name()
<< " having float64 dtype in pdmodel description, we will "
"cast them to "
"float32 dtype to feed them into paddle-trt.";
/*
PADDLE_ENFORCE_EQ(scope->FindVar(tmp_name),
nullptr,
platform::errors::InvalidArgument(
"The var name %s has exists in scope.",
tmp_name));
*/
scope->Var(tmp_name);
}
}
......@@ -412,10 +391,10 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp(
graph_var_map[node->Name()] = node;
}
}
auto precision_mode = Get<int>("trt_precision_mode");
auto precision_mode =
static_cast<phi::DataType>(Get<int>("trt_precision_mode"));
bool enable_fp16 = false;
if (precision_mode == static_cast<int>(phi::DataType::FLOAT16))
enable_fp16 = true;
if (precision_mode == phi::DataType::FLOAT16) enable_fp16 = true;
auto enable_int8 = Get<bool>("enable_int8");
auto use_calib_mode = Get<bool>("use_calib_mode");
auto &subgraph_nodes = *framework::ir::Agent(node).subgraph();
......@@ -423,14 +402,14 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp(
Get<std::map<std::string, std::vector<int>>>("min_input_shape");
auto max_input_shape =
Get<std::map<std::string, std::vector<int>>>("max_input_shape");
auto opt_input_shape =
auto optim_input_shape =
Get<std::map<std::string, std::vector<int>>>("optim_input_shape");
auto min_shape_tensor =
Get<std::map<std::string, std::vector<int>>>("min_shape_tensor");
auto max_shape_tensor =
Get<std::map<std::string, std::vector<int>>>("max_shape_tensor");
auto opt_shape_tensor =
auto optim_shape_tensor =
Get<std::map<std::string, std::vector<int>>>("optim_shape_tensor");
auto allow_build_at_runtime = Get<bool>("trt_allow_build_at_runtime");
......@@ -444,10 +423,10 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp(
inference::DeserializeShapeRangeInfo(shape_range_info_path,
&min_input_shape,
&max_input_shape,
&opt_input_shape,
&optim_input_shape,
&min_shape_tensor,
&max_shape_tensor,
&opt_shape_tensor);
&optim_shape_tensor);
} else {
shape_range_info_path =
Get<std::string>("model_opt_cache_dir") + "shape_range_info.pbtxt";
......@@ -457,10 +436,10 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp(
inference::DeserializeShapeRangeInfo(shape_range_info_path,
&min_input_shape,
&max_input_shape,
&opt_input_shape,
&optim_input_shape,
&min_shape_tensor,
&max_shape_tensor,
&opt_shape_tensor);
&optim_shape_tensor);
} else {
int fd = open(shape_range_info_path.c_str(), O_WRONLY | O_CREAT, 0644);
close(fd);
......@@ -509,32 +488,20 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp(
if (static_cast<framework::proto::VarType_Type>(
map_origin_outputs_dtype[name]) ==
framework::proto::VarType::INT64) {
std::string tmp_name = name + "_cast_to_INT64";
LOG(WARNING) << "tensorrt_subgraph's output named " << name
<< " having int64 dtype in pdmodel description, but in fact "
"it is int32 "
"dtype after executing this tensorrt_subgraph, so we "
"need cast them into int64.";
PADDLE_ENFORCE_EQ(scope->FindVar(tmp_name),
nullptr,
platform::errors::InvalidArgument(
"The var name %s has exists in scope.", tmp_name));
scope->Var(tmp_name);
} else if (static_cast<framework::proto::VarType_Type>(
map_origin_outputs_dtype[name]) ==
framework::proto::VarType::FP64) {
std::string tmp_name = name + "_cast_to_FP64";
LOG(WARNING)
<< "tensorrt_subgraph's output named " << name
<< " having float64 dtype in pdmodel description, but in fact "
"it is float32 "
"dtype after executing this tensorrt_subgraph, so we "
"need cast them into float64.";
PADDLE_ENFORCE_EQ(scope->FindVar(tmp_name),
nullptr,
platform::errors::InvalidArgument(
"The var name %s has exists in scope.", tmp_name));
scope->Var(tmp_name);
}
}
PADDLE_ENFORCE_EQ(output_mapping.empty(),
......@@ -546,30 +513,73 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp(
true,
platform::errors::PreconditionNotMet("the block has no var-desc"));
// Set attrs
// Get pass attrs.
auto use_varseqlen = Get<bool>("use_varseqlen");
auto with_interleaved = Get<bool>("with_interleaved");
auto tensorrt_transformer_posid =
Get<std::string>("tensorrt_transformer_posid");
auto tensorrt_transformer_maskid =
Get<std::string>("tensorrt_transformer_maskid");
auto use_dla = Get<bool>("trt_use_dla");
auto dla_core = Get<int>("trt_dla_core");
auto use_inspector = Get<bool>("use_inspector");
auto disable_trt_plugin_fp16 = Get<bool>("disable_trt_plugin_fp16");
auto context_memory_sharing = Get<bool>("context_memory_sharing");
auto enable_low_precision_io = Get<bool>("enable_low_precision_io");
auto workspace_size = Get<int64_t>("workspace_size");
auto gpu_device_id = Get<int>("gpu_device_id");
// Set op's attrs.
op_desc->SetType("tensorrt_engine");
op_desc->SetInput(
"Xs", std::vector<std::string>(input_names.begin(), input_names.end()));
op_desc->SetOutput(
"Ys", std::vector<std::string>(output_names.begin(), output_names.end()));
op_desc->SetBlockAttr("sub_block", new_block);
op_desc->SetAttr("subgraph", block_desc.Proto()->SerializeAsString());
op_desc->SetAttr("origin_outputs_dtype", origin_outputs_dtype);
op_desc->SetAttr("max_batch_size", max_batch_size);
op_desc->SetAttr("workspace_size", Get<int64_t>("workspace_size"));
op_desc->SetAttr("gpu_id", Get<int>("gpu_device_id"));
op_desc->SetAttr("workspace_size", workspace_size);
op_desc->SetAttr("gpu_device_id", gpu_device_id);
op_desc->SetAttr("output_name_mapping", output_mapping);
op_desc->SetAttr("origin_output_rank", renamed_output_rank);
op_desc->SetAttr("parameters", params);
op_desc->SetAttr("parameters", parameters);
op_desc->SetAttr("allow_build_at_runtime", allow_build_at_runtime);
op_desc->SetAttr("shape_range_info_path", shape_range_info_path);
op_desc->SetAttr("use_inspector", Get<bool>("use_inspector"));
op_desc->SetAttr("model_precision", Get<int>("model_precision"));
op_desc->SetAttr("use_inspector", use_inspector);
op_desc->SetAttr("with_dynamic_shape", with_dynamic_shape);
op_desc->SetAttr("enable_low_precision_io",
Get<bool>("enable_low_precision_io"));
op_desc->SetAttr("enable_low_precision_io", enable_low_precision_io);
if (!trt_tuned_dynamic_shape) {
std::vector<std::string> dynamic_shape_names;
std::vector<int> dynamic_shape_lens;
std::vector<int> min_input_shape_vector;
std::vector<int> max_input_shape_vector;
std::vector<int> opt_input_shape_vector;
for (const auto &it : min_input_shape) {
dynamic_shape_names.push_back(it.first);
dynamic_shape_lens.push_back(it.second.size());
for (const auto &value : it.second) {
min_input_shape_vector.push_back(value);
}
}
for (const auto &it : max_input_shape) {
for (const auto &value : it.second) {
max_input_shape_vector.push_back(value);
}
}
for (const auto &it : optim_input_shape) {
for (const auto &value : it.second) {
opt_input_shape_vector.push_back(value);
}
}
op_desc->SetAttr("dynamic_shape_names", dynamic_shape_names);
op_desc->SetAttr("dynamic_shape_lens", dynamic_shape_lens);
op_desc->SetAttr("min_input_shape_vector", min_input_shape_vector);
op_desc->SetAttr("max_input_shape_vector", max_input_shape_vector);
op_desc->SetAttr("opt_input_shape_vector", opt_input_shape_vector);
}
// we record all inputs' shapes in attr to check if they are consistent
// with the real inputs' shapes retrieved from scope when trt runs.
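The new block above packs the per-input min/max/optim shape maps into parallel flat attributes (dynamic_shape_names, dynamic_shape_lens, and the *_input_shape_vector attrs). As an illustration only, a consumer could unpack those attributes back into maps as sketched below; this helper is hypothetical and not part of the commit.

#include <map>
#include <string>
#include <vector>

// Rebuild a name -> shape map from the flattened attributes; each input owns
// lens[i] consecutive entries of the flat value vector.
static std::map<std::string, std::vector<int>> UnpackShapeAttr(
    const std::vector<std::string> &names,
    const std::vector<int> &lens,
    const std::vector<int> &flat_values) {
  std::map<std::string, std::vector<int>> shapes;
  size_t offset = 0;
  for (size_t i = 0; i < names.size(); ++i) {
    shapes[names[i]] =
        std::vector<int>(flat_values.begin() + offset,
                         flat_values.begin() + offset + lens[i]);
    offset += lens[i];
  }
  return shapes;
}
// e.g. min_input_shape = UnpackShapeAttr(dynamic_shape_names,
//                                        dynamic_shape_lens,
//                                        min_input_shape_vector);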
......@@ -624,14 +634,20 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp(
op_desc->SetAttr("engine_key", engine_key);
op_desc->SetAttr("calibration_engine_key", calibration_engine_key);
op_desc->SetAttr("predictor_id", predictor_id);
std::string trt_engine_serialized_data = "";
op_desc->SetAttr("use_varseqlen", use_varseqlen);
op_desc->SetAttr("with_interleaved", with_interleaved);
op_desc->SetAttr("use_dla", use_dla);
op_desc->SetAttr("dla_core", dla_core);
op_desc->SetAttr("disable_trt_plugin_fp16", disable_trt_plugin_fp16);
op_desc->SetAttr("context_memory_sharing", context_memory_sharing);
std::string trt_engine_serialized_data;
op_desc->SetAttr("engine_serialized_data", trt_engine_serialized_data);
op_desc->Flush();
std::unique_ptr<tensorrt::TRTInt8Calibrator> calibrator;
if (enable_int8 && !calibration_data.empty()) {
calibrator.reset(new tensorrt::TRTInt8Calibrator(calibration_data));
calibrator =
std::make_unique<tensorrt::TRTInt8Calibrator>(calibration_data);
LOG(INFO) << "RUN Paddle TRT int8 calibration mode...";
}
// When in int8 mode and calibration_mode, the program just produce the
......@@ -656,7 +672,7 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp(
"static shape mode instead.";
min_input_shape = {};
max_input_shape = {};
opt_input_shape = {};
optim_input_shape = {};
}
const float trt_compile_version = tensorrt::TrtMajorVersion(TRT_VERSION);
......@@ -677,42 +693,33 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp(
framework::ir::Agent(node).subgraph()->end());
framework::ir::GraphSafeRemoveNodes(graph, nodes2remove);
// Setting the disable_trt_plugin_fp16 to true means that TRT plugin will not
// run fp16.
// When running fp16, the output accuracy of the model will be affected,
// closing the plugin fp16 may bring some improvement on accuracy.
bool disable_trt_plugin_fp16 = Get<bool>("disable_trt_plugin_fp16");
tensorrt::TensorRTEngine::ConstructionParams params;
params.max_batch_size = max_batch_size;
params.max_workspace_size = workspace_size;
params.calibrator = calibrator.get();
params.device_id = gpu_device_id;
params.with_dynamic_shape = with_dynamic_shape;
params.min_input_shape = min_input_shape;
params.max_input_shape = max_input_shape;
params.optim_input_shape = optim_input_shape;
params.min_shape_tensor = min_shape_tensor;
params.max_shape_tensor = max_shape_tensor;
params.optim_shape_tensor = optim_shape_tensor;
params.disable_trt_plugin_fp16 = disable_trt_plugin_fp16;
params.precision = precision_mode;
params.use_varseqlen = use_varseqlen;
params.use_dla = use_dla;
params.dla_core = dla_core;
params.with_interleaved = with_interleaved;
params.tensorrt_transformer_posid = tensorrt_transformer_posid;
params.tensorrt_transformer_maskid = tensorrt_transformer_maskid;
params.context_memory_sharing = context_memory_sharing;
params.use_inspector = use_inspector;
params.enable_low_precision_io = enable_low_precision_io;
tensorrt::TensorRTEngine *trt_engine =
inference::Singleton<inference::tensorrt::TRTEngineManager>::Global()
.Create(engine_key + std::to_string(predictor_id),
max_batch_size,
Get<int64_t>("workspace_size"),
static_cast<phi::DataType>(precision_mode),
calibrator.get(),
Get<int>("gpu_device_id"),
with_dynamic_shape,
min_input_shape,
max_input_shape,
opt_input_shape,
min_shape_tensor,
max_shape_tensor,
opt_shape_tensor,
disable_trt_plugin_fp16,
static_cast<phi::DataType>(Get<int>("model_precision")));
trt_engine->SetUseOSS(Get<bool>("use_varseqlen"));
trt_engine->SetWithInterleaved(Get<bool>("with_interleaved"));
trt_engine->SetTransformerPosid(
Get<std::string>("tensorrt_transformer_posid"));
trt_engine->SetTransformerMaskid(
Get<std::string>("tensorrt_transformer_maskid"));
trt_engine->SetUseDLA(Get<bool>("trt_use_dla"));
trt_engine->SetDLACore(Get<int>("trt_dla_core"));
trt_engine->SetUseInspector(Get<bool>("use_inspector"));
trt_engine->SetWithErnie(
graph->Has(framework::ir::kEmbEltwiseLayernormPass) &&
graph->Has(framework::ir::kMultiheadMatmulPass));
trt_engine->SetContextMemorySharing(Get<bool>("context_memory_sharing"));
trt_engine->SetLowPrecisionIO(Get<bool>("enable_low_precision_io"));
.Create(engine_key + std::to_string(predictor_id), params);
if (use_static_engine) {
trt_engine_serialized_data = GetTrtEngineSerializedData(
......@@ -749,13 +756,14 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp(
"kernel etc). This process may cost a lot of time.";
framework::BlockDesc block_desc_temp(nullptr, block_desc.Proto());
std::unordered_set<std::string> param_set(params.begin(), params.end());
std::unordered_set<std::string> parameters_set(parameters.begin(),
parameters.end());
inference::Singleton<inference::tensorrt::OpConverter>::Global()
.ConvertBlockToTRTEngine(
&block_desc_temp,
*scope,
std::vector<std::string>(input_names.begin(), input_names.end()),
param_set,
parameters_set,
output_mapping,
trt_engine);
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h"
#include <memory>
#include <string>
#include "paddle/fluid/framework/executor.h"
......@@ -63,8 +64,7 @@ void IrGraphBuildPass::RunImpl(Argument *argument) {
"set."));
}
auto graph = std::unique_ptr<framework::ir::Graph>(
new framework::ir::Graph(argument->main_program()));
auto graph = std::make_unique<framework::ir::Graph>(argument->main_program());
argument->SetMainGraph(graph.release());
auto *scope_ptr = argument->scope_ptr();
PADDLE_ENFORCE_NOT_NULL(scope_ptr,
......
......@@ -24,16 +24,6 @@ namespace inference {
namespace analysis {
void SaveOptimizedModelPass::SaveOptimizedModel(Argument* argument) {
if (!argument->save_optimized_model()) {
LOG(WARNING) << "save_optim_cache_model is turned off, skip "
"save_optimized_model_pass";
return;
}
if (!argument->enable_ir_optim()) {
LOG(WARNING) << "ir_optim is turned off, skip save_optimized_model_pass";
return;
}
std::string model_opt_cache_dir = argument->optim_cache_dir();
if (!model_opt_cache_dir.empty()) {
if (!PathExists(model_opt_cache_dir)) {
......@@ -55,9 +45,11 @@ void SaveOptimizedModelPass::SaveOptimizedModel(Argument* argument) {
auto* graph = argument->main_graph_ptr();
framework::ProgramDesc optimized_program_desc;
// NOTE(liuyuanle): If the following line of code is not added, an error
// [SegmentFault] may occur!
optimized_program_desc.CopyFrom(*argument->main_program().Proto());
framework::ir::GraphToProgram(*graph, &optimized_program_desc);
auto IsPersistable = [](const framework::VarDesc* var) {
......@@ -133,11 +125,10 @@ void SaveOptimizedModelPass::SaveOptimizedModel(Argument* argument) {
}
void SaveOptimizedModelPass::RunImpl(Argument* argument) {
// TODO(inference): Support trt.
if (argument->use_xpu() ||
(argument->use_gpu() && !argument->use_tensorrt())) {
SaveOptimizedModel(argument);
if (!argument->save_optimized_model() || !argument->enable_ir_optim()) {
return;
}
SaveOptimizedModel(argument);
}
std::string SaveOptimizedModelPass::repr() const {
......
......@@ -641,7 +641,7 @@ bool AnalysisPredictor::PrepareProgram(
}
bool AnalysisPredictor::CreateExecutor() {
executor_.reset(new paddle::framework::NaiveExecutor(place_));
executor_ = std::make_unique<paddle::framework::NaiveExecutor>(place_);
return true;
}
......@@ -1341,7 +1341,7 @@ bool AnalysisPredictor::GetFetch(std::vector<paddle::Tensor> *outputs,
void AnalysisPredictor::PrepareArgument() {
VLOG(3) << "AnalysisPredictor::PrepareArgument";
// Init std::unique_ptr argument_.
argument_.reset(new Argument);
argument_ = std::make_unique<Argument>();
argument_->SetUseGPU(config_.use_gpu());
argument_->SetUseCutlass(config_.use_cutlass_);
argument_->SetUseFcPadding(config_.use_fc_padding());
......@@ -1570,7 +1570,8 @@ void AnalysisPredictor::PrepareArgument() {
if (!config_.ir_optim()) {
argument_->SetEnableIrOptim(false);
if (config_.enable_gpu_mixed_) {
if (config_.enable_gpu_mixed_ &&
model_precision_ == phi::DataType::FLOAT32) {
argument_->SetEnableIrOptim(true);
pass_builder->ClearPasses();
pass_builder->AppendPass("auto_mixed_precision_pass");
......@@ -1886,6 +1887,10 @@ AnalysisPredictor::GetInputTypes() {
input_type[name] = paddle_infer::DataType::UINT8;
} else if (dtype == paddle::framework::proto::VarType::INT8) {
input_type[name] = paddle_infer::DataType::INT8;
} else if (dtype == paddle::framework::proto::VarType::FP64) {
input_type[name] = paddle_infer::DataType::FLOAT64;
} else if (dtype == paddle::framework::proto::VarType::BOOL) {
input_type[name] = paddle_infer::DataType::BOOL;
} else {
PADDLE_THROW(paddle::platform::errors::Unimplemented(
"Unsupported data type `%s` when get input dtype ", dtype));
......@@ -2609,7 +2614,7 @@ AnalysisPredictor::~AnalysisPredictor() {
#ifdef PADDLE_WITH_TENSORRT
if (config_.trt_engine_memory_sharing()) {
inference::Singleton<inference::tensorrt::TRTEngineManager>::Global()
.releaseContextMemory(predictor_id_);
.ReleaseContextMemory(predictor_id_);
}
#endif
}
......
......@@ -167,7 +167,7 @@ class OpConverter {
op_desc.Type()));
it->SetEngine(engine);
engine->SetScope(scope);
engine->SetScope(&scope);
it->SetBlockDesc(block);
(*it)(op, scope, test_mode);
......@@ -301,7 +301,7 @@ class OpConverter {
nvinfer1::DataType in_dtype = FluidDataType2TRT(var->GetDataType());
if (engine->precision() == phi::DataType::FLOAT16 &&
in_dtype == nvinfer1::DataType::kFLOAT &&
engine->EnableLowPrecisionIO()) {
engine->LowPrecisionIOEnabled()) {
in_dtype = nvinfer1::DataType::kHALF;
}
......@@ -360,7 +360,7 @@ class OpConverter {
nvinfer1::DataType out_dtype = FluidDataType2TRT(var->GetDataType());
if (engine->precision() == phi::DataType::FLOAT16 &&
out_dtype == nvinfer1::DataType::kFLOAT &&
engine->EnableLowPrecisionIO()) {
engine->LowPrecisionIOEnabled()) {
out_dtype = nvinfer1::DataType::kHALF;
}
engine->DeclareOutput(output, out_dtype);
......@@ -470,7 +470,7 @@ class OpConverter {
auto shape = newShape->getDimensions();
shuffle->setReshapeDimensions(shape);
}
if (name != "") {
if (!name.empty()) {
shuffle->setName(name.c_str());
}
return shuffle->getOutput(0);
......@@ -481,7 +481,7 @@ class OpConverter {
const std::string& name = "") {
auto* shuffle = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
shuffle->setReshapeDimensions(shape);
if (name != "") {
if (!name.empty()) {
shuffle->setName(name.c_str());
}
return shuffle->getOutput(0);
......@@ -774,11 +774,6 @@ class OpConverter {
bool test_mode_;
private:
// registered op converter map, whose key is the fluid op type, and value is
// the pointer position of corresponding OpConverter class.
std::unordered_map<std::string, OpConverter*> converters_;
// fluid inference scope
framework::Scope* scope_{nullptr};
std::mutex mut_;
};
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h> // NOLINT
#include <memory>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
......@@ -95,7 +96,11 @@ TEST(CustomPluginCreater, StaticShapePlugin) {
// init trt engine
std::unique_ptr<TensorRTEngine> engine_;
engine_.reset(new TensorRTEngine(5, 1 << 15));
TensorRTEngine::ConstructionParams params;
params.max_batch_size = 5;
params.max_workspace_size = 1 << 15;
engine_ = std::make_unique<TensorRTEngine>(params);
engine_->InitNetwork();
engine_->DeclareInput(
......@@ -173,15 +178,10 @@ TEST(CustomPluginCreater, DynamicShapePlugin) {
std::map<std::string, std::vector<int>> optim_input_shape = {
{"x", {1, 2, 5, 5}}};
engine_.reset(new TensorRTEngine(5,
1 << 15,
phi::DataType::FLOAT32,
nullptr,
0,
true,
min_input_shape,
max_input_shape,
optim_input_shape));
TensorRTEngine::ConstructionParams params;
params.max_batch_size = 5;
params.max_workspace_size = 1 << 15;
engine_ = std::make_unique<TensorRTEngine>(params);
engine_->InitNetwork();
LOG(INFO) << "with_dynamic_shape " << engine_->with_dynamic_shape();
......
......@@ -14,6 +14,8 @@ limitations under the License. */
#include <gtest/gtest.h> // NOLINT
#include <memory>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
......@@ -28,7 +30,10 @@ TEST(OpConverter, ConvertBlock) {
// init trt engine
std::unique_ptr<TensorRTEngine> engine_;
engine_.reset(new TensorRTEngine(5, 1 << 15));
TensorRTEngine::ConstructionParams params;
params.max_batch_size = 5;
params.max_workspace_size = 1 << 15;
engine_ = std::make_unique<TensorRTEngine>(params);
engine_->InitNetwork();
engine_->DeclareInput(
......
......@@ -88,7 +88,10 @@ class TRTConvertValidation {
PADDLE_ENFORCE_EQ(cudaStreamCreate(&stream_),
0,
platform::errors::External("cudaStreamCreate error."));
engine_.reset(new TensorRTEngine(max_batch_size, workspace_size));
TensorRTEngine::ConstructionParams params;
params.max_batch_size = max_batch_size;
params.max_workspace_size = workspace_size;
engine_ = std::make_unique<TensorRTEngine>(params);
engine_->InitNetwork();
}
......@@ -155,7 +158,7 @@ class TRTConvertValidation {
engine_->FreezeNetwork();
// Declare outputs.
op_desc_.reset(new framework::OpDesc(desc, nullptr));
op_desc_ = std::make_unique<framework::OpDesc>(desc, nullptr);
}
// We use the set 'neglected_output' here, because some Ops like batch norm,
......
......@@ -21,17 +21,15 @@ limitations under the License. */
#include "NvInferRuntimeCommon.h"
#include "cuda_runtime_api.h" // NOLINT
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/enforce.h"
namespace paddle {
namespace inference {
namespace tensorrt {
int TensorRTEngine::runtime_batch_ = 1;
thread_local int TensorRTEngine::predictor_id_per_thread = -1;
void TensorRTEngine::Weight::SetDataType(phi::DataType type) {
......@@ -64,10 +62,10 @@ void TensorRTEngine::Weight::SetDataType(phi::DataType type) {
}
void TensorRTEngine::InitNetwork() {
freshDeviceId();
FreshDeviceId();
infer_builder_.reset(createInferBuilder(&logger_));
if (with_dynamic_shape_) {
if (with_dynamic_shape()) {
infer_network_.reset(infer_builder_->createNetworkV2(
1U << static_cast<int>(
nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH)));
......@@ -92,7 +90,7 @@ nvinfer1::IExecutionContext *TensorRTEngine::context() {
// IExecutionContext...
// It's ok. We will set it later.
nvinfer1::IExecutionContext *infer_context{nullptr};
if (context_memory_sharing_) {
if (params_.context_memory_sharing) {
infer_context =
infer_engine_->createExecutionContextWithoutDeviceMemory();
} else {
......@@ -102,7 +100,7 @@ nvinfer1::IExecutionContext *TensorRTEngine::context() {
infer_context,
platform::errors::InvalidArgument(
"TensorRT engine can not build execution context."));
if (with_dynamic_shape_) {
if (with_dynamic_shape()) {
// need new profile if it's not the first
if (cur_profile_num_ > 0) {
infer_context->setOptimizationProfile(cur_profile_num_);
......@@ -118,15 +116,15 @@ nvinfer1::IExecutionContext *TensorRTEngine::context() {
void TensorRTEngine::Execute(int batch_size,
std::vector<void *> *buffers,
cudaStream_t stream) {
freshDeviceId();
FreshDeviceId();
auto infer_context = context();
if (context_memory_sharing_) {
if (params_.context_memory_sharing) {
void *context_memory{nullptr};
context_memory =
inference::Singleton<inference::tensorrt::TRTEngineManager>::Global()
.getContextMemory(
.GetContextMemory(
predictor_id_per_thread,
phi::GPUPlace(device_id_),
phi::GPUPlace(device_id()),
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
infer_context->setDeviceMemory(context_memory);
}
......@@ -182,12 +180,11 @@ bool TensorRTEngine::Enqueue(nvinfer1::IExecutionContext *context,
} else {
ret = context->enqueueV2(buffers->data(), stream, nullptr);
}
SetRuntimeBatch(batch_size);
return ret;
}
void TensorRTEngine::FreezeNetwork() {
freshDeviceId();
FreshDeviceId();
VLOG(3) << "TRT to freeze network";
PADDLE_ENFORCE_NOT_NULL(infer_builder_,
platform::errors::InvalidArgument(
......@@ -197,17 +194,17 @@ void TensorRTEngine::FreezeNetwork() {
platform::errors::InvalidArgument(
"Call InitNetwork first to initialize network."));
// build engine.
if (!with_dynamic_shape_) {
infer_builder_->setMaxBatchSize(max_batch_);
if (!with_dynamic_shape()) {
infer_builder_->setMaxBatchSize(params_.max_batch_size);
}
#if IS_TRT_VERSION_GE(8300)
infer_builder_config_->setMemoryPoolLimit(
nvinfer1::MemoryPoolType::kWORKSPACE, max_workspace_);
nvinfer1::MemoryPoolType::kWORKSPACE, params_.max_workspace_size);
#else
infer_builder_config_->setMaxWorkspaceSize(max_workspace_);
infer_builder_config_->setMaxWorkspaceSize(params_.max_workspace_size);
#endif
bool enable_fp16 = (precision_ == phi::DataType::FLOAT16);
bool enable_fp16 = (precision() == phi::DataType::FLOAT16);
if (enable_fp16) {
bool support_fp16 = infer_builder_->platformHasFastFp16();
infer_builder_config_->setFlag(nvinfer1::BuilderFlag::kFP16);
......@@ -219,15 +216,15 @@ void TensorRTEngine::FreezeNetwork() {
}
}
bool enable_int8 = (precision_ == phi::DataType::INT8);
bool enable_int8 = (precision() == phi::DataType::INT8);
if (enable_int8) {
if (!use_dla_) {
if (!use_dla()) {
infer_builder_config_->setFlag(nvinfer1::BuilderFlag::kFP16);
}
infer_builder_config_->setFlag(nvinfer1::BuilderFlag::kINT8);
if (calibrator_) {
infer_builder_config_->setInt8Calibrator(calibrator_);
if (params_.calibrator) {
infer_builder_config_->setInt8Calibrator(params_.calibrator);
} else {
infer_builder_config_->setInt8Calibrator(nullptr);
......@@ -259,7 +256,7 @@ void TensorRTEngine::FreezeNetwork() {
}
}
if (use_dla_) {
if (use_dla()) {
if (!enable_int8 && !enable_fp16) {
LOG(WARNING) << "TensorRT DLA must be used with int8 or fp16, but you "
"set float32, so DLA is not used.";
......@@ -268,42 +265,43 @@ void TensorRTEngine::FreezeNetwork() {
<< "TensorRT DLA is set by config, but your device does not have "
"DLA, so DLA is not used.";
} else {
if (dla_core_ < 0 || dla_core_ >= infer_builder_->getNbDLACores()) {
dla_core_ = 0;
if (params_.dla_core < 0 ||
params_.dla_core >= infer_builder_->getNbDLACores()) {
params_.dla_core = 0;
LOG(WARNING) << "Invalid DLACore, must be 0 < DLACore < "
<< infer_builder_->getNbDLACores() << ", but got "
<< dla_core_ << ", so use use 0 as default.";
<< params_.dla_core << ", so use use 0 as default.";
}
infer_builder_config_->setDefaultDeviceType(nvinfer1::DeviceType::kDLA);
infer_builder_config_->setDLACore(dla_core_);
infer_builder_config_->setDLACore(params_.dla_core);
infer_builder_config_->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK);
LOG(INFO) << "TensorRT DLA enabled in FreezeNetwork(), DLACore "
<< dla_core_;
<< params_.dla_core;
}
}
if (with_dynamic_shape_) {
if (with_dynamic_shape()) {
LOG(INFO) << "Run Paddle-TRT Dynamic Shape mode.";
for (int i = 0; i < max_profile_num_; i++) {
for (auto &input : min_input_shape_) {
for (auto &input : min_input_shape()) {
#if IS_TRT_VERSION_LT(7100)
// trt6/trt7011 will check all_of input > 0
if (!(std::all_of(input.second.begin(),
input.second.end(),
[](int x) { return x > 0; }) &&
std::all_of(max_input_shape_[input.first].begin(),
max_input_shape_[input.first].end(),
std::all_of(max_input_shape()[input.first].begin(),
max_input_shape()[input.first].end(),
[](int x) { return x > 0; }) &&
std::all_of(optim_input_shape_[input.first].begin(),
optim_input_shape_[input.first].end(),
std::all_of(optim_input_shape()[input.first].begin(),
optim_input_shape()[input.first].end(),
[](int x) { return x > 0; }))) {
continue;
}
#endif
VLOG(4) << "TRT dynamic_shape set " << input.first
<< " min: " << Vec2Str(input.second)
<< ", max: " << Vec2Str(max_input_shape_[input.first])
<< ", opt: " << Vec2Str(optim_input_shape_[input.first]);
<< ", max: " << Vec2Str(max_input_shape()[input.first])
<< ", opt: " << Vec2Str(optim_input_shape()[input.first]);
optim_profiles_[i]->setDimensions(
input.first.c_str(),
......@@ -312,38 +310,39 @@ void TensorRTEngine::FreezeNetwork() {
optim_profiles_[i]->setDimensions(
input.first.c_str(),
nvinfer1::OptProfileSelector::kMAX,
Vec2TRT_Dims(max_input_shape_[input.first], input.first, true));
Vec2TRT_Dims(max_input_shape()[input.first], input.first, true));
optim_profiles_[i]->setDimensions(
input.first.c_str(),
nvinfer1::OptProfileSelector::kOPT,
Vec2TRT_Dims(optim_input_shape_[input.first], input.first, true));
Vec2TRT_Dims(optim_input_shape()[input.first], input.first, true));
}
for (int input_id = 0; input_id < network()->getNbInputs(); input_id++) {
auto input_name = network()->getInput(input_id)->getName();
if (!itensor_map_.count(input_name)) continue;
if (!GetITensor(input_name)->isShapeTensor()) continue;
PADDLE_ENFORCE_EQ(min_shape_tensor_.count(input_name) &&
max_shape_tensor_.count(input_name) &&
optim_shape_tensor_.count(input_name),
PADDLE_ENFORCE_EQ(min_shape_tensor().count(input_name) > 0 &&
max_shape_tensor().count(input_name) > 0 &&
optim_shape_tensor().count(input_name) > 0,
true,
platform::errors::InvalidArgument(
"Fail to find min/max/optim shape value for TRT "
"network's shape tensor input named %s.",
input_name));
auto min_vec = min_shape_tensor_.at(input_name);
auto min_vec = min_shape_tensor().at(input_name);
optim_profiles_[i]->setShapeValues(input_name,
nvinfer1::OptProfileSelector::kMIN,
min_vec.data(),
min_vec.size());
optim_profiles_[i]->setShapeValues(input_name,
nvinfer1::OptProfileSelector::kMAX,
max_shape_tensor_[input_name].data(),
min_vec.size());
optim_profiles_[i]->setShapeValues(
input_name,
nvinfer1::OptProfileSelector::kMAX,
max_shape_tensor()[input_name].data(),
min_vec.size());
optim_profiles_[i]->setShapeValues(
input_name,
nvinfer1::OptProfileSelector::kOPT,
optim_shape_tensor_[input_name].data(),
optim_shape_tensor()[input_name].data(),
min_vec.size());
}
......@@ -358,7 +357,7 @@ void TensorRTEngine::FreezeNetwork() {
}
}
#if IS_TRT_VERSION_GE(8200)
if (use_inspector_) {
if (params_.use_inspector) {
infer_builder_config_->setProfilingVerbosity(
nvinfer1::ProfilingVerbosity::kDETAILED);
}
......@@ -388,12 +387,12 @@ void TensorRTEngine::FreezeNetwork() {
cur_profile_num_ = 0;
}
// for engine context memory sharing
if (context_memory_sharing_) {
if (params_.context_memory_sharing) {
inference::Singleton<inference::tensorrt::TRTEngineManager>::Global()
.updateContextMemorySize(infer_engine_->getDeviceMemorySize(),
.UpdateContextMemorySize(infer_engine_->getDeviceMemorySize(),
predictor_id_per_thread);
}
if (use_inspector_) {
if (params_.use_inspector) {
GetEngineInfo();
}
}
......@@ -539,7 +538,7 @@ nvinfer1::ITensor *TensorRTEngine::ConvertWeight2ITensor(
}
// In fact , this is not always right, because we can't determine if the 0th
// dimension is batch. Just for run chenqu's model
if (!this->with_dynamic_shape()) {
if (!with_dynamic_shape()) {
trt_in_shape.nbDims--;
for (int i = 0; i < trt_in_shape.nbDims; i++) {
trt_in_shape.d[i] = trt_in_shape.d[i + 1];
......@@ -563,12 +562,12 @@ std::unordered_map<std::string, nvinfer1::ITensor *>
}
void TensorRTEngine::Deserialize(const std::string &engine_serialized_data) {
freshDeviceId();
FreshDeviceId();
infer_runtime_.reset(createInferRuntime(&logger_));
if (use_dla_) {
if (precision_ != phi::DataType::INT8 &&
precision_ != phi::DataType::FLOAT16) {
if (use_dla()) {
if (precision() != phi::DataType::INT8 &&
precision() != phi::DataType::FLOAT16) {
LOG(WARNING) << "TensorRT DLA must be used with int8 or fp16, but you "
"set float32, so DLA is not used.";
} else if (infer_runtime_->getNbDLACores() == 0) {
......@@ -576,15 +575,16 @@ void TensorRTEngine::Deserialize(const std::string &engine_serialized_data) {
<< "TensorRT DLA is set by config, but your device does not have "
"DLA, so DLA is not used.";
} else {
if (dla_core_ < 0 || dla_core_ >= infer_runtime_->getNbDLACores()) {
dla_core_ = 0;
if (params_.dla_core < 0 ||
params_.dla_core >= infer_runtime_->getNbDLACores()) {
params_.dla_core = 0;
LOG(WARNING) << "Invalid DLACore, must be 0 < DLACore < "
<< infer_runtime_->getNbDLACores() << ", but got "
<< dla_core_ << ", so use use 0 as default.";
<< params_.dla_core << ", so use use 0 as default.";
}
infer_runtime_->setDLACore(dla_core_);
infer_runtime_->setDLACore(params_.dla_core);
LOG(INFO) << "TensorRT DLA enabled in Deserialize(), DLACore "
<< dla_core_;
<< params_.dla_core;
}
}
......@@ -602,20 +602,16 @@ void TensorRTEngine::Deserialize(const std::string &engine_serialized_data) {
binding_num_ = infer_engine_->getNbBindings();
// for engine context memory sharing
if (context_memory_sharing_) {
if (params_.context_memory_sharing) {
inference::Singleton<inference::tensorrt::TRTEngineManager>::Global()
.updateContextMemorySize(infer_engine_->getDeviceMemorySize(),
.UpdateContextMemorySize(infer_engine_->getDeviceMemorySize(),
predictor_id_per_thread);
}
if (use_inspector_) {
if (params_.use_inspector) {
GetEngineInfo();
}
}
void TensorRTEngine::SetRuntimeBatch(size_t batch_size) {
runtime_batch_ = batch_size;
}
// Note: Only for support plugin.
TensorRTEngine::Weight TensorRTEngine::GetFp16TrtWeight(
const std::string &name, const phi::DenseTensor &weight_tensor) {
......@@ -830,8 +826,6 @@ TensorRTEngine::Weight TensorRTEngine::GetTrtWeight(
return weight;
}
int TensorRTEngine::GetRuntimeBatch() { return runtime_batch_; }
nvinfer1::IPluginV2Layer *TensorRTEngine::AddPlugin(
nvinfer1::ITensor *const *inputs,
int num_inputs,
......@@ -856,16 +850,16 @@ nvinfer1::IPluginV2Layer *TensorRTEngine::AddPluginV2IOExt(
return network()->addPluginV2(inputs, num_inputs, *plugin);
}
void TensorRTEngine::freshDeviceId() {
void TensorRTEngine::FreshDeviceId() {
int count;
cudaGetDeviceCount(&count);
PADDLE_ENFORCE_LT(device_id_,
PADDLE_ENFORCE_LT(device_id(),
count,
platform::errors::OutOfRange(
"Device id %d exceeds the current device count: %d.",
device_id_,
device_id(),
count));
platform::SetDeviceId(device_id_);
platform::SetDeviceId(device_id());
}
void TensorRTEngine::GetEngineInfo() {
......
......@@ -148,7 +148,7 @@ class NaiveProfiler : public nvinfer1::IProfiler {
typedef std::pair<std::string, float> Record;
std::vector<Record> mProfile;
virtual void reportLayerTime(const char* layerName, float ms) TRT_NOEXCEPT {
void reportLayerTime(const char* layerName, float ms) TRT_NOEXCEPT override {
auto record =
std::find_if(mProfile.begin(), mProfile.end(), [&](const Record& r) {
return r.first == layerName;
......@@ -235,6 +235,130 @@ static inline nvinfer1::DataType PhiType2NvType(phi::DataType type) {
return nv_type;
}
using FluidDT = paddle::framework::proto::VarType_Type;
using TRT_DT = nvinfer1::DataType;
static TRT_DT FluidDataType2TRT(FluidDT type) {
switch (type) {
case FluidDT::VarType_Type_FP32:
case FluidDT::VarType_Type_FP64:
return TRT_DT::kFLOAT;
case FluidDT::VarType_Type_INT32:
case FluidDT::VarType_Type_INT64:
return TRT_DT::kINT32;
case FluidDT::VarType_Type_FP16:
return TRT_DT::kHALF;
#if IS_TRT_VERSION_GE(8400)
case FluidDT::VarType_Type_BOOL:
return TRT_DT::kBOOL;
#endif
default:
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"unsupported datatype in TRT op converter, type: %s. "
"Boolean type is supported as TRT input/output "
"using TensorRT v8.4+.",
VarType_Type_Name(type)));
}
return TRT_DT::kINT32;
}
// The T can be int32 or int64 type.
template <typename T>
static nvinfer1::Dims Vec2TRT_Dims(const std::vector<T>& shape,
std::string input,
bool with_dynamic_shape = false) {
PADDLE_ENFORCE_GE(shape.size(),
0UL,
paddle::platform::errors::InvalidArgument(
"TensorRT's tensor input requires at least 0 "
"dimensions, but input %s has %d dims.",
input,
shape.size()));
auto ShapeStr = [](const std::vector<T>& shape) {
std::ostringstream os;
os << "[";
for (size_t i = 0; i < shape.size(); ++i) {
if (i == shape.size() - 1) {
os << shape[i];
} else {
os << shape[i] << ",";
}
}
os << "]";
return os.str();
};
if (!with_dynamic_shape) {
if (shape.size() == 4UL) {
if (shape[2] == -1 || shape[3] == -1) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The input [%s] shape of trt subgraph is %s, please enable "
"trt dynamic_shape mode by SetTRTDynamicShapeInfo.",
input,
ShapeStr(shape)));
}
return nvinfer1::Dims3(shape[1], shape[2], shape[3]);
} else if (shape.size() == 5UL) {
if (shape[2] == -1 || shape[3] == -1 || shape[4] == -1) {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"The input [%s] shape of trt subgraph is %s, please enable "
"trt dynamic_shape mode by SetTRTDynamicShapeInfo.",
input,
ShapeStr(shape)));
}
return nvinfer1::Dims4(shape[1], shape[2], shape[3], shape[4]);
} else if (shape.size() == 3UL) {
if (shape[1] == -1 || shape[2] == -1) {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"The input [%s] shape of trt subgraph is %s, please enable "
"trt dynamic_shape mode by SetTRTDynamicShapeInfo.",
input,
ShapeStr(shape)));
}
return nvinfer1::Dims2(shape[1], shape[2]);
} else if (shape.size() == 2UL) {
if (shape[1] == -1) {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"The input [%s] shape of trt subgraph is %s, please enable "
"trt dynamic_shape mode by SetTRTDynamicShapeInfo.",
input,
ShapeStr(shape)));
}
nvinfer1::Dims dims;
dims.nbDims = 1;
dims.d[0] = shape[1];
return dims;
}
// static shape doesn't support 1D op so far.
PADDLE_ENFORCE_NE(shape.size(),
1UL,
paddle::platform::errors::InvalidArgument(
"The input [%s] shape of trt subgraph is %s."
"it's not supported by trt so far",
input,
ShapeStr(shape)));
nvinfer1::Dims dims;
dims.nbDims = shape.size() - 1;
for (size_t i = 1; i < shape.size(); i++) {
dims.d[i - 1] = shape[i];
}
return dims;
} else {
if (shape.size() == 4UL) {
return nvinfer1::Dims4(shape[0], shape[1], shape[2], shape[3]);
} else if (shape.size() == 3UL) {
return nvinfer1::Dims3(shape[0], shape[1], shape[2]);
}
nvinfer1::Dims dims;
dims.nbDims = shape.size();
for (size_t i = 0; i < shape.size(); i++) {
dims.d[i] = shape[i];
}
return dims;
}
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
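The two helpers added above (FluidDataType2TRT and Vec2TRT_Dims) form the mapping layer from Paddle var descriptions to TensorRT types. A brief usage sketch follows, with a hypothetical input name and shape that are not taken from this commit:

// Sketch: translate a float32 Paddle input with shape [-1, 3, 224, 224]
// into its TensorRT dtype and dims. Assumes the helper.h above and the
// framework.pb.h / NvInfer headers it already pulls in.
void DeclareTrtInputExample() {
  using paddle::framework::proto::VarType;
  namespace trt = paddle::inference::tensorrt;

  std::vector<int64_t> paddle_shape = {-1, 3, 224, 224};  // hypothetical shape
  nvinfer1::DataType trt_dtype = trt::FluidDataType2TRT(VarType::FP32);
  // with_dynamic_shape = true keeps the (dynamic) batch dimension.
  nvinfer1::Dims trt_dims =
      trt::Vec2TRT_Dims(paddle_shape, "x", /*with_dynamic_shape=*/true);
  // trt_dtype == nvinfer1::DataType::kFLOAT; trt_dims.d == {-1, 3, 224, 224}.
  (void)trt_dtype;
  (void)trt_dims;
}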
......@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/layout.h"
#if PADDLE_WITH_CUSPARSELT && IS_TRT_VERSION_GE(8000)
#include "paddle/fluid/inference/tensorrt/plugin/spmm_plugin.h"
#endif
......@@ -35,7 +36,7 @@ namespace tensorrt {
class TensorRTDynamicShapeValueEngineTest : public ::testing::Test {
protected:
void SetUp() override {
ctx_ = new phi::GPUContext(platform::CUDAPlace(0));
ctx_ = std::make_unique<phi::GPUContext>(platform::CUDAPlace(0));
ctx_->SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(platform::CUDAPlace(0), ctx_->stream())
.get());
......@@ -65,29 +66,21 @@ class TensorRTDynamicShapeValueEngineTest : public ::testing::Test {
{"shape", {18, 8, 4}}};
std::map<std::string, std::vector<int>> optim_input_value = {
{"shape", {18, 8, 4}}};
engine_ = new TensorRTEngine(16,
1 << 10,
phi::DataType::FLOAT32,
nullptr,
0,
true,
min_input_shape,
max_input_shape,
optim_input_shape,
min_input_value,
max_input_value,
optim_input_value,
false,
phi::DataType::FLOAT32,
NaiveLogger::Global());
engine_->InitNetwork();
}
void TearDown() override {
if (engine_) {
delete engine_;
engine_ = nullptr;
}
TensorRTEngine::ConstructionParams params;
params.max_batch_size = 16;
params.max_workspace_size = 1 << 10;
params.with_dynamic_shape = true;
params.min_input_shape = min_input_shape;
params.max_input_shape = max_input_shape;
params.optim_input_shape = optim_input_shape;
params.min_shape_tensor = min_input_value;
params.max_shape_tensor = max_input_value;
params.optim_shape_tensor = optim_input_value;
engine_ = std::make_unique<TensorRTEngine>(params, NaiveLogger::Global());
engine_->InitNetwork();
}
void PrepareInputOutput(const std::vector<float> &input,
......@@ -106,8 +99,8 @@ class TensorRTDynamicShapeValueEngineTest : public ::testing::Test {
phi::DenseTensor input_;
phi::DenseTensor shape_;
phi::DenseTensor output_;
TensorRTEngine *engine_;
phi::GPUContext *ctx_;
std::unique_ptr<TensorRTEngine> engine_;
std::unique_ptr<phi::GPUContext> ctx_;
};
TEST_F(TensorRTDynamicShapeValueEngineTest, test_trt_dynamic_shape_value) {
......@@ -167,7 +160,7 @@ TEST_F(TensorRTDynamicShapeValueEngineTest, test_trt_dynamic_shape_value) {
class TensorRTDynamicEngineTest : public ::testing::Test {
protected:
void SetUp() override {
ctx_ = new phi::GPUContext(platform::CUDAPlace(0));
ctx_ = std::make_unique<phi::GPUContext>(platform::CUDAPlace(0));
ctx_->SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(platform::CUDAPlace(0), ctx_->stream())
.get());
......@@ -192,29 +185,18 @@ class TensorRTDynamicEngineTest : public ::testing::Test {
std::map<std::string, std::vector<int>> optim_input_shape = {
{"input", {16, 32, 1, 1}}};
engine_ = new TensorRTEngine(16,
1 << 10,
phi::DataType::FLOAT16,
nullptr,
0,
true,
min_input_shape,
max_input_shape,
optim_input_shape,
std::map<std::string, std::vector<int>>(),
std::map<std::string, std::vector<int>>(),
std::map<std::string, std::vector<int>>(),
false,
phi::DataType::FLOAT32,
NaiveLogger::Global());
engine_->InitNetwork();
}
TensorRTEngine::ConstructionParams params;
params.max_batch_size = 16;
params.max_workspace_size = 1 << 10;
params.with_dynamic_shape = true;
params.precision = phi::DataType::FLOAT16;
params.min_input_shape = min_input_shape;
params.max_input_shape = max_input_shape;
params.optim_input_shape = optim_input_shape;
void TearDown() override {
if (engine_) {
delete engine_;
engine_ = nullptr;
}
engine_ = std::make_unique<TensorRTEngine>(params, NaiveLogger::Global());
engine_->InitNetwork();
}
void PrepareInputOutput(const std::vector<float16> &input,
......@@ -230,8 +212,8 @@ class TensorRTDynamicEngineTest : public ::testing::Test {
protected:
phi::DenseTensor input_;
phi::DenseTensor output_;
TensorRTEngine *engine_;
phi::GPUContext *ctx_;
std::unique_ptr<TensorRTEngine> engine_;
std::unique_ptr<phi::GPUContext> ctx_;
};
TEST_F(TensorRTDynamicEngineTest, test_spmm) {
......@@ -336,7 +318,7 @@ TEST_F(TensorRTDynamicEngineTest, test_spmm) {
class TensorRTDynamicTestFusedTokenPrune : public ::testing::Test {
protected:
void SetUp() override {
ctx_ = new phi::GPUContext(platform::CUDAPlace(0));
ctx_ = std::make_unique<phi::GPUContext>(platform::CUDAPlace(0));
ctx_->SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(platform::CUDAPlace(0), ctx_->stream())
.get());
......@@ -370,29 +352,18 @@ class TensorRTDynamicTestFusedTokenPrune : public ::testing::Test {
{"mask", {4, 1, 4, 4}},
{"new_mask", {4, 1, 2, 2}}};
engine_ = new TensorRTEngine(16,
1 << 10,
phi::DataType::FLOAT32,
nullptr,
0,
true,
min_input_shape,
max_input_shape,
optim_input_shape,
std::map<std::string, std::vector<int>>(),
std::map<std::string, std::vector<int>>(),
std::map<std::string, std::vector<int>>(),
false,
phi::DataType::FLOAT32,
NaiveLogger::Global());
engine_->InitNetwork();
}
TensorRTEngine::ConstructionParams params;
params.max_batch_size = 16;
params.max_workspace_size = 1 << 10;
params.precision = phi::DataType::FLOAT32;
params.with_dynamic_shape = true;
params.min_input_shape = min_input_shape;
params.max_input_shape = max_input_shape;
params.optim_input_shape = optim_input_shape;
void TearDown() override {
if (engine_) {
delete engine_;
engine_ = nullptr;
}
engine_ = std::make_unique<TensorRTEngine>(params, NaiveLogger::Global());
engine_->InitNetwork();
}
void PrepareInputOutput(const std::vector<std::vector<float>> inputs,
......@@ -419,13 +390,12 @@ class TensorRTDynamicTestFusedTokenPrune : public ::testing::Test {
protected:
std::vector<phi::DenseTensor> inputs_;
std::vector<phi::DenseTensor> outputs_;
TensorRTEngine *engine_;
phi::GPUContext *ctx_;
std::unique_ptr<TensorRTEngine> engine_;
std::unique_ptr<phi::GPUContext> ctx_;
};
TEST_F(TensorRTDynamicTestFusedTokenPrune, test_fused_token_prune) {
#if IS_TRT_VERSION_GE(8000)
tensorrt::plugin::TrtPluginRegistry::Global()->RegistToTrt();
auto *attn = engine_->DeclareInput(
"attn", nvinfer1::DataType::kFLOAT, nvinfer1::Dims2{-1, 4});
auto *x = engine_->DeclareInput(
......@@ -545,7 +515,7 @@ TEST_F(TensorRTDynamicTestFusedTokenPrune, test_fused_token_prune) {
class TensorRTDynamicTestFusedTokenPruneHalf : public ::testing::Test {
protected:
void SetUp() override {
ctx_ = new phi::GPUContext(platform::CUDAPlace(0));
ctx_ = std::make_unique<phi::GPUContext>(platform::CUDAPlace(0));
ctx_->SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(platform::CUDAPlace(0), ctx_->stream())
.get());
......@@ -579,29 +549,17 @@ class TensorRTDynamicTestFusedTokenPruneHalf : public ::testing::Test {
{"mask", {4, 1, 4, 4}},
{"new_mask", {4, 1, 2, 2}}};
engine_ = new TensorRTEngine(16,
1 << 10,
phi::DataType::FLOAT16,
nullptr,
0,
true,
min_input_shape,
max_input_shape,
optim_input_shape,
std::map<std::string, std::vector<int>>(),
std::map<std::string, std::vector<int>>(),
std::map<std::string, std::vector<int>>(),
false,
phi::DataType::FLOAT32,
NaiveLogger::Global());
engine_->InitNetwork();
}
TensorRTEngine::ConstructionParams params;
params.max_batch_size = 16;
params.max_workspace_size = 1 << 10;
params.precision = phi::DataType::FLOAT16;
params.with_dynamic_shape = true;
params.min_input_shape = min_input_shape;
params.max_input_shape = max_input_shape;
params.optim_input_shape = optim_input_shape;
void TearDown() override {
if (engine_) {
delete engine_;
engine_ = nullptr;
}
engine_ = std::make_unique<TensorRTEngine>(params, NaiveLogger::Global());
engine_->InitNetwork();
}
void PrepareInputOutput(const std::vector<std::vector<float16>> inputs,
......@@ -628,13 +586,12 @@ class TensorRTDynamicTestFusedTokenPruneHalf : public ::testing::Test {
protected:
std::vector<phi::DenseTensor> inputs_;
std::vector<phi::DenseTensor> outputs_;
TensorRTEngine *engine_;
phi::GPUContext *ctx_;
std::unique_ptr<TensorRTEngine> engine_;
std::unique_ptr<phi::GPUContext> ctx_;
};
TEST_F(TensorRTDynamicTestFusedTokenPruneHalf, test_fused_token_prune) {
#if IS_TRT_VERSION_GE(8000)
tensorrt::plugin::TrtPluginRegistry::Global()->RegistToTrt();
auto *attn = engine_->DeclareInput(
"attn", nvinfer1::DataType::kHALF, nvinfer1::Dims2{-1, 4});
auto *x = engine_->DeclareInput(
......@@ -754,7 +711,7 @@ TEST_F(TensorRTDynamicTestFusedTokenPruneHalf, test_fused_token_prune) {
class TensorRTDynamicShapeGNTest : public ::testing::Test {
protected:
void SetUp() override {
ctx_ = new phi::GPUContext(platform::CUDAPlace(0));
ctx_ = std::make_unique<phi::GPUContext>(platform::CUDAPlace(0));
ctx_->SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(platform::CUDAPlace(0), ctx_->stream())
.get());
......@@ -782,29 +739,18 @@ class TensorRTDynamicShapeGNTest : public ::testing::Test {
std::map<std::string, std::vector<int>> max_input_value = {};
std::map<std::string, std::vector<int>> optim_input_value = {};
engine_ = new TensorRTEngine(16,
1 << 10,
phi::DataType::INT8,
nullptr,
0,
true,
min_input_shape,
max_input_shape,
optim_input_shape,
min_input_value,
max_input_value,
optim_input_value,
false,
phi::DataType::FLOAT32,
NaiveLogger::Global());
engine_->InitNetwork();
}
TensorRTEngine::ConstructionParams params;
params.max_batch_size = 16;
params.max_workspace_size = 1 << 10;
params.precision = phi::DataType::INT8;
params.with_dynamic_shape = true;
params.min_input_shape = min_input_shape;
params.max_input_shape = max_input_shape;
params.optim_input_shape = optim_input_shape;
void TearDown() override {
if (engine_) {
delete engine_;
engine_ = nullptr;
}
engine_ = std::make_unique<TensorRTEngine>(params, NaiveLogger::Global());
engine_->InitNetwork();
}
void PrepareInputOutput(const std::vector<float> &input,
......@@ -923,8 +869,8 @@ class TensorRTDynamicShapeGNTest : public ::testing::Test {
protected:
phi::DenseTensor x_;
phi::DenseTensor y_;
TensorRTEngine *engine_;
phi::GPUContext *ctx_;
std::unique_ptr<TensorRTEngine> engine_;
std::unique_ptr<phi::GPUContext> ctx_;
// case from SD
int n_ = 2;
int c_ = 320;
......@@ -942,8 +888,6 @@ class TensorRTDynamicShapeGNTest : public ::testing::Test {
/*
TEST_F(TensorRTDynamicShapeGNTest, test_trt_dynamic_shape_groupnorm) {
tensorrt::plugin::TrtPluginRegistry::Global()->RegistToTrt();
float *bias = new float[c_];
float *scale = new float[c_];
for (int i = 0; i < c_; i++) {
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <memory>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
......@@ -48,17 +49,13 @@ class TensorRTEngineTest : public ::testing::Test {
.get());
ctx_->PartialInitWithAllocator();
engine_ = new TensorRTEngine(10, 1 << 10);
TensorRTEngine::ConstructionParams params;
params.max_batch_size = 10;
params.max_workspace_size = 1 << 10;
engine_ = std::make_unique<TensorRTEngine>(params);
engine_->InitNetwork();
}
void TearDown() override {
if (engine_) {
delete engine_;
engine_ = nullptr;
}
}
void PrepareInputOutput(const std::vector<float> &input,
std::vector<int> output_shape) {
paddle::framework::TensorFromVector(input, *ctx_, &input_);
......@@ -72,7 +69,7 @@ class TensorRTEngineTest : public ::testing::Test {
protected:
phi::DenseTensor input_;
phi::DenseTensor output_;
TensorRTEngine *engine_;
std::unique_ptr<TensorRTEngine> engine_;
phi::GPUContext *ctx_;
};
......@@ -111,15 +108,6 @@ TEST_F(TensorRTEngineTest, add_layer) {
buffers[0] = reinterpret_cast<void *>(x_v_gpu_data);
buffers[1] = reinterpret_cast<void *>(y_gpu_data);
LOG(INFO) << "Set attr";
engine_->Set("test_attr", new std::string("test_attr"));
if (engine_->Has("test_attr")) {
auto attr_val = engine_->Get<std::string>("test_attr");
engine_->Erase("test_attr");
}
std::string *attr_key = new std::string("attr_key");
engine_->SetNotOwned("attr1", attr_key);
LOG(INFO) << "to execute";
engine_->Execute(1, &buffers, ctx_->stream());
......@@ -128,8 +116,6 @@ TEST_F(TensorRTEngineTest, add_layer) {
LOG(INFO) << "to checkout output";
ASSERT_EQ(y_cpu[0], x_v[0] * 2 + 3);
delete attr_key;
}
TEST_F(TensorRTEngineTest, add_layer_multi_dim) {
......
......@@ -43,7 +43,7 @@ class TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator2 {
const platform::Place place);
explicit TRTInt8Calibrator(const std::string& calibration_data);
~TRTInt8Calibrator();
~TRTInt8Calibrator() override;
int getBatchSize() const TRT_NOEXCEPT override;
......@@ -91,7 +91,7 @@ class TRTCalibratorEngine {
*/
class TRTCalibratorEngineManager {
public:
bool Has() const { return res_.size() > 0; }
bool Has() const { return !res_.empty(); }
bool Has(const std::string& name) const {
if (res_.count(name) == 0) return false;
return res_.at(name).get() != nullptr;
......
......@@ -14,30 +14,34 @@
#pragma once
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/kernels/funcs/data_type_transform.h"
#ifdef PADDLE_WITH_CUDA
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/fluid/framework/data_device_transform.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"
#include "paddle/fluid/inference/utils/io_utils.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/funcs/data_type_transform.h"
#include "paddle/utils/string/string_helper.h"
namespace paddle {
......@@ -171,10 +175,6 @@ class TensorRTEngineOp : public framework::OperatorBase {
std::string model_opt_cache_dir_;
bool use_static_engine_;
phi::DataType precision_mode_;
std::map<std::string, std::vector<int>> min_input_shape_{};
std::map<std::string, std::vector<int>> max_input_shape_{};
std::map<std::string, std::vector<int>> opt_input_shape_{};
phi::DataType model_precision_{phi::DataType::FLOAT32};
public:
TensorRTEngineOp(const std::string &type,
......@@ -185,7 +185,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
input_names_ = Inputs("Xs");
max_batch_size_ = Attr<int>("max_batch_size");
workspace_size_ = Attr<int64_t>("workspace_size");
device_id_ = Attr<int>("gpu_id");
device_id_ = Attr<int>("gpu_device_id");
enable_int8_ = Attr<bool>("enable_int8");
enable_fp16_ = Attr<bool>("enable_fp16");
use_calib_mode_ = Attr<bool>("use_calib_mode");
......@@ -200,43 +200,6 @@ class TensorRTEngineOp : public framework::OperatorBase {
if (use_static_engine_) {
model_opt_cache_dir_ = Attr<std::string>("model_opt_cache_dir");
}
model_precision_ = static_cast<phi::DataType>(Attr<int>("model_precision"));
if (HasAttr("dynamic_shape_names") && HasAttr("min_input_shape") &&
HasAttr("max_input_shape") && HasAttr("opt_input_shape")) {
std::vector<std::string> dynamic_shape_names;
std::vector<std::vector<int>> min_input_shapes;
std::vector<std::vector<int>> max_input_shapes;
std::vector<std::vector<int>> opt_input_shapes;
std::vector<int> dynamic_shape_lens;
dynamic_shape_names =
Attr<std::vector<std::string>>("dynamic_shape_names");
std::vector<int> min_shapes = Attr<std::vector<int>>("min_input_shape");
std::vector<int> max_shapes = Attr<std::vector<int>>("max_input_shape");
std::vector<int> opt_shapes = Attr<std::vector<int>>("opt_input_shape");
dynamic_shape_lens = Attr<std::vector<int>>("dynamic_shape_lens");
int idx = 0;
for (size_t i = 0; i < dynamic_shape_lens.size(); ++i) {
std::vector<int> tmp1, tmp2, tmp3;
for (int j = 0; j < dynamic_shape_lens[i]; ++j) {
tmp1.push_back(min_shapes[idx]);
tmp2.push_back(max_shapes[idx]);
tmp3.push_back(opt_shapes[idx++]);
}
min_input_shapes.emplace_back(tmp1);
max_input_shapes.emplace_back(tmp2);
opt_input_shapes.emplace_back(tmp3);
}
for (size_t i = 0; i < dynamic_shape_names.size(); ++i) {
min_input_shape_.insert(
std::make_pair(dynamic_shape_names[i], min_input_shapes[i]));
max_input_shape_.insert(
std::make_pair(dynamic_shape_names[i], max_input_shapes[i]));
opt_input_shape_.insert(
std::make_pair(dynamic_shape_names[i], opt_input_shapes[i]));
}
}
auto params = Attr<std::vector<std::string>>("parameters");
for (const auto &param : params) {
......@@ -249,11 +212,11 @@ class TensorRTEngineOp : public framework::OperatorBase {
// calibration_mode being true means we need to
// generate the calibration table data.
calibration_mode_ =
(enable_int8_ && calibration_data_.size() == 0 && use_calib_mode_);
(enable_int8_ && calibration_data_.empty() && use_calib_mode_);
VLOG(4) << "calibration_mode: " << calibration_mode_;
if (enable_int8_ && calibration_data_.size()) {
calibrator_.reset(new TRTInt8Calibrator(calibration_data_));
if (enable_int8_ && !calibration_data_.empty()) {
calibrator_ = std::make_unique<TRTInt8Calibrator>(calibration_data_);
}
bool has_engine =
inference::Singleton<inference::tensorrt::TRTEngineManager>::Global()
......@@ -486,36 +449,29 @@ class TensorRTEngineOp : public framework::OperatorBase {
auto t_shape = phi::vectorize(t.dims());
runtime_batch = t_shape[0];
}
calib_res->calib_.reset(new TRTInt8Calibrator(
calib_buffers, runtime_batch, calibration_engine_key_, dev_place));
calib_res->calib_ = std::make_unique<TRTInt8Calibrator>(
calib_buffers, runtime_batch, calibration_engine_key_, dev_place);
calib_res->thr_.reset(new std::thread([&]() {
std::map<std::string, std::vector<int>> min_input_shape;
std::map<std::string, std::vector<int>> max_input_shape;
std::map<std::string, std::vector<int>> opt_input_shape;
std::map<std::string, std::vector<int>> min_shape_tensor;
std::map<std::string, std::vector<int>> max_shape_tensor;
std::map<std::string, std::vector<int>> opt_shape_tensor;
if (shape_range_info_path_.size())
TensorRTEngine::ConstructionParams params;
params.max_batch_size = max_batch_size_;
params.max_workspace_size = workspace_size_;
params.precision = precision_mode_;
params.calibrator = calib_res->calib_.get();
params.device_id = dev_place.device;
params.with_dynamic_shape = with_dynamic_shape_;
if (!shape_range_info_path_.empty()) {
inference::DeserializeShapeRangeInfo(shape_range_info_path_,
&min_input_shape,
&max_input_shape,
&opt_input_shape,
&min_shape_tensor,
&max_shape_tensor,
&opt_shape_tensor);
calib_res->engine_.reset(new TensorRTEngine(max_batch_size_,
workspace_size_,
precision_mode_,
calib_res->calib_.get(),
dev_place.device,
with_dynamic_shape_,
min_input_shape,
max_input_shape,
opt_input_shape,
min_shape_tensor,
max_shape_tensor,
opt_shape_tensor));
&params.min_input_shape,
&params.max_input_shape,
&params.optim_input_shape,
&params.min_shape_tensor,
&params.max_shape_tensor,
&params.optim_shape_tensor);
}
params.context_memory_sharing = Attr<bool>("context_memory_sharing");
params.enable_low_precision_io = Attr<bool>("enable_low_precision_io");
calib_res->engine_ = std::make_unique<TensorRTEngine>(params);
VLOG(3) << "start the calib trt engine thread";
PrepareTRTEngine(scope, calib_res->engine_.get());
}));
......@@ -597,7 +553,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
// This must be a zero-dimensional tensor.
// At present, we convert it to a 1D tensor to feed it into TRT.
if (t_shape.size() == 0) {
if (t_shape.empty()) {
PADDLE_ENFORCE_EQ(
t.numel(),
1UL,
......@@ -669,8 +625,12 @@ class TensorRTEngineOp : public framework::OperatorBase {
t.numel() * sizeof(int),
nullptr);
} else if (t.dtype() == phi::DataType::INT64) {
auto int32_tensor = scope.FindVar(x + "_cast_to_INT32")
->GetMutable<phi::DenseTensor>();
std::string x_t = x + "_cast_to_INT32";
if (scope.FindVar(x_t) == nullptr) {
const_cast<framework::Scope *>(&scope)->Var(x_t);
}
auto int32_tensor =
scope.FindVar(x_t)->GetMutable<phi::DenseTensor>();
*int32_tensor = phi::Cast<int64_t>(
reinterpret_cast<const phi::GPUContext &>(dev_ctx),
t,
......@@ -703,16 +663,22 @@ class TensorRTEngineOp : public framework::OperatorBase {
if (t.dtype() == phi::DataType::FLOAT32) {
buffers[bind_index] = static_cast<void *>(t.data<float>());
} else if (t.dtype() == phi::DataType::FLOAT64) {
auto fp32_tensor =
scope.FindVar(x + "_cast_to_FP32")->GetMutable<phi::DenseTensor>();
std::string x_t = x + "_cast_to_FP32";
if (scope.FindVar(x_t) == nullptr) {
const_cast<framework::Scope *>(&scope)->Var(x_t);
}
auto fp32_tensor = scope.FindVar(x_t)->GetMutable<phi::DenseTensor>();
*fp32_tensor = phi::Cast<double>(
reinterpret_cast<const phi::GPUContext &>(dev_ctx),
t,
phi::DataType::FLOAT32);
buffers[bind_index] = static_cast<void *>(fp32_tensor->data<float>());
} else if (t.dtype() == phi::DataType::INT64) {
auto int32_tensor =
scope.FindVar(x + "_cast_to_INT32")->GetMutable<phi::DenseTensor>();
std::string x_t = x + "_cast_to_INT32";
if (scope.FindVar(x_t) == nullptr) {
const_cast<framework::Scope *>(&scope)->Var(x_t);
}
auto int32_tensor = scope.FindVar(x_t)->GetMutable<phi::DenseTensor>();
*int32_tensor = phi::Cast<int64_t>(
reinterpret_cast<const phi::GPUContext &>(dev_ctx),
t,
......@@ -827,8 +793,11 @@ class TensorRTEngineOp : public framework::OperatorBase {
auto y = Outputs("Ys")[i];
auto *fluid_v = scope.FindVar(y);
auto *fluid_t = fluid_v->GetMutable<phi::DenseTensor>();
auto int32_tensor =
scope.FindVar(y + "_cast_to_INT64")->GetMutable<phi::DenseTensor>();
std::string y_t = y + "_cast_to_INT64";
if (scope.FindVar(y_t) == nullptr) {
const_cast<framework::Scope *>(&scope)->Var(y_t);
}
auto int32_tensor = scope.FindVar(y_t)->GetMutable<phi::DenseTensor>();
int32_tensor->Resize(fluid_t->dims());
dev_ctx.Alloc<int32_t>(int32_tensor);
framework::TensorCopy(*fluid_t, dev_place, dev_ctx, int32_tensor);
......@@ -840,8 +809,11 @@ class TensorRTEngineOp : public framework::OperatorBase {
auto y = Outputs("Ys")[i];
auto *fluid_v = scope.FindVar(y);
auto *fluid_t = fluid_v->GetMutable<phi::DenseTensor>();
auto fp32_tensor =
scope.FindVar(y + "_cast_to_FP64")->GetMutable<phi::DenseTensor>();
std::string y_t = y + "_cast_to_FP64";
if (scope.FindVar(y_t) == nullptr) {
const_cast<framework::Scope *>(&scope)->Var(y_t);
}
auto fp32_tensor = scope.FindVar(y_t)->GetMutable<phi::DenseTensor>();
fp32_tensor->Resize(fluid_t->dims());
dev_ctx.Alloc<float>(fp32_tensor);
framework::TensorCopy(*fluid_t, dev_place, dev_ctx, fp32_tensor);
......@@ -856,20 +828,92 @@ class TensorRTEngineOp : public framework::OperatorBase {
TensorRTEngine *GetEngine(const framework::Scope &scope,
const platform::Place &dev_place) const {
if (!trt_engine_) {
TensorRTEngine::ConstructionParams params;
params.max_batch_size = max_batch_size_;
params.max_workspace_size = workspace_size_;
params.precision = precision_mode_;
params.calibrator = calibrator_.get();
params.device_id = dev_place.device;
params.with_dynamic_shape = with_dynamic_shape_;
params.context_memory_sharing = Attr<bool>("context_memory_sharing");
params.use_dla = Attr<bool>("use_dla");
params.dla_core = Attr<int>("dla_core");
params.disable_trt_plugin_fp16 = Attr<bool>("disable_trt_plugin_fp16");
params.enable_low_precision_io = Attr<bool>("enable_low_precision_io");
params.use_inspector = Attr<bool>("use_inspector");
if (!shape_range_info_path_.empty()) {
inference::DeserializeShapeRangeInfo(shape_range_info_path_,
&params.min_input_shape,
&params.max_input_shape,
&params.optim_input_shape,
&params.min_shape_tensor,
&params.max_shape_tensor,
&params.optim_shape_tensor);
} else {
if (HasAttr("dynamic_shape_names") &&
HasAttr("min_input_shape_vector") &&
HasAttr("max_input_shape_vector") &&
HasAttr("opt_input_shape_vector")) {
std::vector<std::string> dynamic_shape_names;
std::vector<std::vector<int>> min_input_shapes;
std::vector<std::vector<int>> max_input_shapes;
std::vector<std::vector<int>> opt_input_shapes;
std::vector<int> dynamic_shape_lens;
dynamic_shape_names =
Attr<std::vector<std::string>>("dynamic_shape_names");
std::vector<int> min_shapes =
Attr<std::vector<int>>("min_input_shape_vector");
std::vector<int> max_shapes =
Attr<std::vector<int>>("max_input_shape_vector");
std::vector<int> opt_shapes =
Attr<std::vector<int>>("opt_input_shape_vector");
dynamic_shape_lens = Attr<std::vector<int>>("dynamic_shape_lens");
int idx = 0;
for (size_t i = 0; i < dynamic_shape_lens.size(); ++i) {
std::vector<int> tmp1, tmp2, tmp3;
for (int j = 0; j < dynamic_shape_lens[i]; ++j) {
tmp1.push_back(min_shapes[idx]);
tmp2.push_back(max_shapes[idx]);
tmp3.push_back(opt_shapes[idx++]);
}
min_input_shapes.emplace_back(tmp1);
max_input_shapes.emplace_back(tmp2);
opt_input_shapes.emplace_back(tmp3);
}
for (size_t i = 0; i < dynamic_shape_names.size(); ++i) {
params.min_input_shape.insert(
std::make_pair(dynamic_shape_names[i], min_input_shapes[i]));
params.max_input_shape.insert(
std::make_pair(dynamic_shape_names[i], max_input_shapes[i]));
params.optim_input_shape.insert(
std::make_pair(dynamic_shape_names[i], opt_input_shapes[i]));
}
}
}
trt_engine_ =
inference::Singleton<inference::tensorrt::TRTEngineManager>::Global()
.Create(engine_key_ + std::to_string(predictor_id_),
max_batch_size_,
workspace_size_,
precision_mode_,
calibrator_.get(),
device_id_,
with_dynamic_shape_,
min_input_shape_,
max_input_shape_,
opt_input_shape_);
PrepareTRTEngine(scope, trt_engine_);
.Create(engine_key_ + std::to_string(predictor_id_), params);
if (use_static_engine_) {
LOG(INFO) << "Load TRT Optimized Info from "
<< inference::analysis::GetTrtEngineSerializedPath(
model_opt_cache_dir_, engine_key_);
std::string trt_engine_serialized_data =
inference::analysis::GetTrtEngineSerializedData(
model_opt_cache_dir_, engine_key_);
trt_engine_->Deserialize(trt_engine_serialized_data);
} else {
// This branch is mainly used for unit tests.
PrepareTRTEngine(scope, trt_engine_);
}
}
PADDLE_ENFORCE_NOT_NULL(
trt_engine_,
platform::errors::Fatal(
"The pointer to tensorrt engine should not be null."));
return trt_engine_;
}
};
......
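The GetEngine path above rebuilds its per-input dynamic-shape maps from four flattened attributes: "dynamic_shape_names", "dynamic_shape_lens", and the "min/max/opt_input_shape_vector" lists, where the i-th name owns the next dynamic_shape_lens[i] entries of each flat vector. Below is a minimal sketch of that packing convention, written in Python for brevity; the helper names are illustrative only and not part of the Paddle API.

def pack_dynamic_shapes(shapes):
    # {"x": [1, 3, 224, 224], ...} -> (names, lens, flat vector);
    # one such flat vector is produced per min/max/opt shape set.
    names = list(shapes)
    lens = [len(shapes[n]) for n in names]
    flat = [d for n in names for d in shapes[n]]
    return names, lens, flat

def unpack_dynamic_shapes(names, lens, flat):
    # Inverse of pack_dynamic_shapes: rebuild the name -> shape map, the same
    # way the loop over dynamic_shape_lens does in the C++ code above.
    out, idx = {}, 0
    for name, n in zip(names, lens):
        out[name] = flat[idx:idx + n]
        idx += n
    return out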
......@@ -113,7 +113,7 @@ void DynamicShapeTest(bool allow_build_at_runtime) {
ASSERT_EQ(block_->ops_size(), 2);
LOG(INFO) << "create tensorrt desc";
LOG(INFO) << "create tensorrt op desc";
framework::OpDesc engine_op_desc(nullptr);
engine_op_desc.SetType("tensorrt_engine");
engine_op_desc.SetInput("Xs", std::vector<std::string>({"x"}));
......@@ -138,19 +138,18 @@ void DynamicShapeTest(bool allow_build_at_runtime) {
engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString()));
engine_op_desc.SetAttr("engine_serialized_data", std::string(""));
int device_id = 0;
engine_op_desc.SetAttr("gpu_id", device_id);
engine_op_desc.SetAttr("gpu_device_id", device_id);
engine_op_desc.SetAttr("shape_range_info_path", std::string(""));
engine_op_desc.SetAttr("model_opt_cache_dir", std::string(""));
engine_op_desc.SetAttr("allow_build_at_runtime", allow_build_at_runtime);
engine_op_desc.SetAttr("use_static_engine", true);
engine_op_desc.SetAttr("dynamic_shape_names", std::vector<std::string>{"x"});
engine_op_desc.SetAttr("dynamic_shape_lens", std::vector<int>{4});
engine_op_desc.SetAttr("with_dynamic_shape", true);
engine_op_desc.SetAttr("min_input_shape", std::vector<int>{1, 1, 1, 1});
engine_op_desc.SetAttr("max_input_shape", std::vector<int>{16, 16, 16, 16});
engine_op_desc.SetAttr("opt_input_shape", std::vector<int>{2, 4, 4, 4});
engine_op_desc.SetAttr("model_precision",
static_cast<int>(phi::DataType::FLOAT32));
engine_op_desc.SetAttr("use_static_engine", false);
engine_op_desc.SetAttr("with_dynamic_shape", false);
engine_op_desc.SetAttr("context_memory_sharing", true);
engine_op_desc.SetAttr("disable_trt_plugin_fp16", false);
engine_op_desc.SetAttr("enable_low_precision_io", false);
engine_op_desc.SetAttr("use_inspector", false);
engine_op_desc.SetAttr("use_dla", false);
engine_op_desc.SetAttr("dla_core", 0);
LOG(INFO) << "create engine op";
auto engine_op = framework::OpRegistry::CreateOp(engine_op_desc);
......@@ -263,7 +262,7 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
{output_dim, output_dim},
{batch_size, output_dim});
LOG(INFO) << "create tensorrt desc";
LOG(INFO) << "create tensorrt op desc";
framework::OpDesc engine_op_desc(nullptr);
engine_op_desc.SetType("tensorrt_engine");
engine_op_desc.SetInput("Xs", std::vector<std::string>({"x0"}));
......@@ -288,11 +287,18 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString()));
engine_op_desc.SetAttr("engine_serialized_data", std::string(""));
int device_id = 0;
engine_op_desc.SetAttr("gpu_id", device_id);
engine_op_desc.SetAttr("gpu_device_id", device_id);
engine_op_desc.SetAttr("shape_range_info_path", std::string(""));
engine_op_desc.SetAttr("model_opt_cache_dir", std::string(""));
engine_op_desc.SetAttr("allow_build_at_runtime", false);
engine_op_desc.SetAttr("use_static_engine", false);
engine_op_desc.SetAttr("with_dynamic_shape", false);
engine_op_desc.SetAttr("context_memory_sharing", true);
engine_op_desc.SetAttr("disable_trt_plugin_fp16", false);
engine_op_desc.SetAttr("enable_low_precision_io", false);
engine_op_desc.SetAttr("use_inspector", false);
engine_op_desc.SetAttr("use_dla", false);
engine_op_desc.SetAttr("dla_core", 0);
auto engine_op = framework::OpRegistry::CreateOp(engine_op_desc);
......
......@@ -174,6 +174,7 @@ if(WITH_GPU AND TENSORRT_FOUND)
set_tests_properties(test_trt_inspector PROPERTIES TIMEOUT 60)
set_tests_properties(test_trt_inference_predictor PROPERTIES TIMEOUT 60)
set_tests_properties(test_trt_inference_fp16_io PROPERTIES TIMEOUT 300)
set_tests_properties(test_save_optimized_model_pass PROPERTIES TIMEOUT 300)
if(WITH_NV_JETSON)
set_tests_properties(
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import unittest
import numpy as np
import paddle
from paddle.inference import Config, PrecisionType, create_predictor
from paddle.jit import to_static
from paddle.static import InputSpec
from paddle.vision.models import alexnet
class TestSaveOptimizedModelPass:
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
net = alexnet(True)
model = to_static(
net, input_spec=[InputSpec(shape=[None, 3, 224, 224], name='x')]
)
paddle.jit.save(
model, os.path.join(self.temp_dir.name, 'alexnet/inference')
)
def tearDown(self):
self.temp_dir.cleanup()
def get_baseline(self):
predictor = self.init_predictor(save_optimized_model=True)
inputs = [
paddle.to_tensor(0.1 * np.ones([1, 3, 224, 224]).astype(np.float32))
]
outputs = predictor.run(inputs)
return outputs[0]
def get_test_output(self):
predictor = self.init_predictor(save_optimized_model=False)
inputs = [
paddle.to_tensor(0.1 * np.ones([1, 3, 224, 224]).astype(np.float32))
]
outputs = predictor.run(inputs)
return outputs[0]
def test_output(self):
if paddle.is_compiled_with_cuda():
baseline = self.get_baseline()
test_output = self.get_test_output()
np.testing.assert_allclose(
baseline.numpy().flatten(),
test_output.numpy().flatten(),
)
class TestSaveOptimizedModelPassWithGPU(
TestSaveOptimizedModelPass, unittest.TestCase
):
def init_predictor(self, save_optimized_model: bool):
if save_optimized_model is True:
config = Config(
os.path.join(self.temp_dir.name, 'alexnet/inference.pdmodel'),
os.path.join(self.temp_dir.name, 'alexnet/inference.pdiparams'),
)
config.enable_use_gpu(256, 0, PrecisionType.Half)
config.enable_memory_optim()
config.switch_ir_optim(True)
config.set_optim_cache_dir(
os.path.join(self.temp_dir.name, 'alexnet')
)
config.enable_save_optim_model(True)
else:
config = Config(
os.path.join(self.temp_dir.name, 'alexnet/_optimized.pdmodel'),
os.path.join(
self.temp_dir.name, 'alexnet/_optimized.pdiparams'
),
)
config.enable_use_gpu(256, 0, PrecisionType.Half)
config.enable_memory_optim()
config.switch_ir_optim(False)
predictor = create_predictor(config)
return predictor
class TestSaveOptimizedModelPassWithTRT(
TestSaveOptimizedModelPass, unittest.TestCase
):
def init_predictor(self, save_optimized_model: bool):
if save_optimized_model is True:
config = Config(
os.path.join(self.temp_dir.name, 'alexnet/inference.pdmodel'),
os.path.join(self.temp_dir.name, 'alexnet/inference.pdiparams'),
)
config.enable_use_gpu(256, 0)
config.enable_tensorrt_engine(
workspace_size=1 << 30,
max_batch_size=1,
min_subgraph_size=3,
precision_mode=PrecisionType.Half,
use_static=True,
use_calib_mode=False,
)
config.set_trt_dynamic_shape_info(
{"x": [1, 3, 224, 224], "flatten_1.tmp_0": [1, 9216]},
{"x": [1, 3, 224, 224], "flatten_1.tmp_0": [1, 9216]},
{"x": [1, 3, 224, 224], "flatten_1.tmp_0": [1, 9216]},
)
config.exp_disable_tensorrt_ops(["flatten_contiguous_range"])
config.enable_memory_optim()
config.switch_ir_optim(True)
config.set_optim_cache_dir(
os.path.join(self.temp_dir.name, 'alexnet')
)
config.enable_save_optim_model(True)
else:
config = Config(
os.path.join(self.temp_dir.name, 'alexnet/_optimized.pdmodel'),
os.path.join(
self.temp_dir.name, 'alexnet/_optimized.pdiparams'
),
)
config.enable_use_gpu(256, 0)
config.enable_tensorrt_engine(
workspace_size=1 << 30,
max_batch_size=1,
min_subgraph_size=3,
precision_mode=PrecisionType.Half,
use_static=True,
use_calib_mode=False,
)
config.enable_memory_optim()
config.switch_ir_optim(False)
predictor = create_predictor(config)
return predictor
if __name__ == '__main__':
unittest.main()
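Both test classes above exercise the same save-then-reload flow. The following is a minimal sketch of that flow using only the Config calls that appear in this test; the model directory and file names are placeholders taken from the test setup, not fixed conventions.

from paddle.inference import Config, PrecisionType, create_predictor

def build_and_save(model_dir):
    # First run: apply IR passes (and TensorRT) and dump the optimized program
    # into the optim cache dir as _optimized.pdmodel / _optimized.pdiparams.
    config = Config(model_dir + '/inference.pdmodel', model_dir + '/inference.pdiparams')
    config.enable_use_gpu(256, 0)
    config.enable_tensorrt_engine(workspace_size=1 << 30, max_batch_size=1,
                                  min_subgraph_size=3,
                                  precision_mode=PrecisionType.Half,
                                  use_static=True, use_calib_mode=False)
    config.switch_ir_optim(True)
    config.set_optim_cache_dir(model_dir)
    config.enable_save_optim_model(True)
    return create_predictor(config)

def load_optimized(model_dir):
    # Second run: load the saved optimized model and skip IR optimization.
    config = Config(model_dir + '/_optimized.pdmodel', model_dir + '/_optimized.pdiparams')
    config.enable_use_gpu(256, 0)
    config.enable_tensorrt_engine(workspace_size=1 << 30, max_batch_size=1,
                                  min_subgraph_size=3,
                                  precision_mode=PrecisionType.Half,
                                  use_static=True, use_calib_mode=False)
    config.switch_ir_optim(False)
    return create_predictor(config)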