Unverified · Commit d1bbd900 · Authored by Yuanle Liu · Committed by GitHub

[Inference] add config.enable_low_precision_io api and remove reliance on AnalysisConfig::Precision in trt (#52485)
Parent 5ac8c040
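A minimal usage sketch of the new API from the Python side (model paths are placeholders; the calls mirror the pybind binding and the test_trt_inference_fp16_io test added by this PR):

import numpy as np
import paddle
from paddle.inference import Config, PrecisionType, create_predictor

# Load a saved inference model (placeholder paths).
config = Config("alexnet/inference.pdmodel", "alexnet/inference.pdiparams")
# Run GPU inference in mixed fp16 precision.
config.enable_use_gpu(256, 0, PrecisionType.Half)
# New in this PR: feed/fetch tensors directly in low precision,
# replacing the old keep_io_types behavior.
config.enable_low_precision_io(True)
predictor = create_predictor(config)

# With low precision io enabled, inputs can be passed as fp16.
x = paddle.to_tensor(np.ones([1, 3, 224, 224], dtype=np.float16))
outputs = predictor.run([x])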
......@@ -215,9 +215,16 @@ void AutoMixedPrecisionPass::Init(Graph* graph) const {
"Cannot enable custom_device_mixed."));
#endif
}
skip_pass_ = backend_ == phi::Backend::UNDEFINED;
low_precision_ = static_cast<phi::DataType>(Get<int>("mixed_precision_mode"));
if (Has("mixed_precision_mode")) {
low_precision_ =
static_cast<phi::DataType>(Get<int>("mixed_precision_mode"));
}
skip_pass_ = (backend_ == phi::Backend::UNDEFINED) ||
(low_precision_ == phi::DataType::UNDEFINED);
if (skip_pass_) return;
black_list_ = Get<std::unordered_set<std::string>>("mixed_black_list");
SetDefaultBlacklist();
......@@ -226,8 +233,8 @@ void AutoMixedPrecisionPass::Init(Graph* graph) const {
VLOG(4) << " - " << name;
}
if (Has("keep_io_types")) {
keep_io_types_ = Get<bool>("keep_io_types");
if (Has("enable_low_precision_io")) {
enable_low_precision_io_ = Get<bool>("enable_low_precision_io");
}
auto graph_size = graph->SubGraphsSize();
......@@ -290,8 +297,8 @@ void AutoMixedPrecisionPass::ApplyImpl(Graph* graph) const {
RestoreOpOriginType();
VLOG(4) << "RestoreOpOriginType done";
LOG(INFO) << "The number of ops run at low precision ["
<< op_run_low_precision_.size() << "/" << op_original_type_.size()
<< "]";
<< op_run_low_precision_.size() << "/"
<< op_original_type_.size() + 2 << "]";
}
void AutoMixedPrecisionPass::SetOpUniqueType() const {
......@@ -385,61 +392,68 @@ void AutoMixedPrecisionPass::GetOpPrecision() const {
bool support_low_precision = true;
if (GetOpOriginalType(op_type) == "feed" ||
GetOpOriginalType(op_type) == "fetch") {
support_low_precision = !keep_io_types_;
support_low_precision = enable_low_precision_io_;
} else if (GetOpOriginalType(op_type) == "tensorrt_engine") {
auto enable_fp16 = op_node->Op()->GetAttrIfExists<bool>("enable_fp16");
auto enable_int8 = op_node->Op()->GetAttrIfExists<bool>("enable_int8");
auto low_precision_io =
op_node->Op()->GetAttrIfExists<bool>("enable_low_precision_io");
support_low_precision = enable_fp16 && !enable_int8 && low_precision_io;
} else {
support_low_precision = OpSupportPrecision(
GetOpOriginalType(op_type), backend_, low_precision_, black_list_);
}
if (op_node->Op()->HasAttr("dtype")) {
auto dtype = op_node->Op()->GetAttrIfExists<int>("dtype");
support_low_precision =
support_low_precision &&
IsFP32AndFP64(static_cast<VarType::Type>(dtype));
} else if (op_node->Op()->HasAttr("out_dtype")) {
auto out_dtype = op_node->Op()->GetAttrIfExists<int>("out_dtype");
support_low_precision =
support_low_precision &&
IsFP32AndFP64(static_cast<VarType::Type>(out_dtype));
}
// If scale op's "scale" and "bias" attr value exceed the range of fp16
// and bf16, it cannot run at low precision.
if (GetOpOriginalType(op_node->Op()->Type()) == "scale") {
auto scale = op_node->Op()->GetAttrIfExists<float>("scale");
auto bias = op_node->Op()->GetAttrIfExists<float>("bias");
if (low_precision_ == phi::DataType::FLOAT16) {
if (op_node->Op()->HasAttr("dtype")) {
auto dtype = op_node->Op()->GetAttrIfExists<int>("dtype");
support_low_precision =
support_low_precision &&
phi::dtype::isfinite(static_cast<phi::dtype::float16>(scale)) &&
phi::dtype::isfinite(static_cast<phi::dtype::float16>(bias));
} else if (low_precision_ == phi::DataType::BFLOAT16) {
IsFP32AndFP64(static_cast<VarType::Type>(dtype));
} else if (op_node->Op()->HasAttr("out_dtype")) {
auto out_dtype = op_node->Op()->GetAttrIfExists<int>("out_dtype");
support_low_precision =
support_low_precision &&
phi::dtype::isfinite(static_cast<phi::dtype::bfloat16>(scale)) &&
phi::dtype::isfinite(static_cast<phi::dtype::bfloat16>(bias));
IsFP32AndFP64(static_cast<VarType::Type>(out_dtype));
}
}
// if op's input var and output var is not dense tensor, the op should
// not run at low precision.
for (auto* in_var_node : op_node->inputs) {
CHECK_EQ(in_var_node->IsVar(), true);
auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
if (real_in_var_node->Var()->Persistable()) continue;
// If scale op's "scale" and "bias" attr value exceed the range of fp16
// and bf16, it cannot run at low precision.
if (GetOpOriginalType(op_node->Op()->Type()) == "scale") {
auto scale = op_node->Op()->GetAttrIfExists<float>("scale");
auto bias = op_node->Op()->GetAttrIfExists<float>("bias");
if (low_precision_ == phi::DataType::FLOAT16) {
support_low_precision =
support_low_precision &&
phi::dtype::isfinite(static_cast<phi::dtype::float16>(scale)) &&
phi::dtype::isfinite(static_cast<phi::dtype::float16>(bias));
} else if (low_precision_ == phi::DataType::BFLOAT16) {
support_low_precision =
support_low_precision &&
phi::dtype::isfinite(
static_cast<phi::dtype::bfloat16>(scale)) &&
phi::dtype::isfinite(static_cast<phi::dtype::bfloat16>(bias));
}
}
support_low_precision =
support_low_precision &&
(real_in_var_node->Var()->GetType() == VarType::LOD_TENSOR);
}
for (auto* out_var_node : op_node->outputs) {
CHECK_EQ(out_var_node->IsVar(), true);
auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
if (real_out_var_node->Var()->Persistable()) continue;
support_low_precision =
support_low_precision &&
(real_out_var_node->Var()->GetType() == VarType::LOD_TENSOR);
// if op's input var and output var is not dense tensor, the op should
// not run at low precision.
for (auto* in_var_node : op_node->inputs) {
CHECK_EQ(in_var_node->IsVar(), true);
auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
if (real_in_var_node->Var()->Persistable()) continue;
support_low_precision =
support_low_precision &&
(real_in_var_node->Var()->GetType() == VarType::LOD_TENSOR);
}
for (auto* out_var_node : op_node->outputs) {
CHECK_EQ(out_var_node->IsVar(), true);
auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
if (real_out_var_node->Var()->Persistable()) continue;
support_low_precision =
support_low_precision &&
(real_out_var_node->Var()->GetType() == VarType::LOD_TENSOR);
}
}
if (support_low_precision) {
......@@ -572,7 +586,12 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const {
bool AutoMixedPrecisionPass::InputVarsNotConvert(
Node* op_node, const std::string& var_name) const {
auto* op_desc = op_node->Op();
if (GetOpOriginalType(op_desc->Type()) == "batch_norm") {
if (GetOpOriginalType(op_desc->Type()) == "tensorrt_engine") {
auto vecs = op_desc->Input("Xs");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
} else if (GetOpOriginalType(op_desc->Type()) == "batch_norm") {
auto vecs = op_desc->Input("Bias");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
......@@ -589,6 +608,15 @@ bool AutoMixedPrecisionPass::InputVarsNotConvert(
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
} else if (GetOpOriginalType(op_desc->Type()) == "instance_norm") {
auto vecs = op_desc->Input("Bias");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("Scale");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
} else if (GetOpOriginalType(op_desc->Type()) == "fused_multi_transformer") {
auto vecs = op_desc->Input("LnScale");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
......@@ -606,6 +634,16 @@ bool AutoMixedPrecisionPass::InputVarsNotConvert(
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
} else if (GetOpOriginalType(op_desc->Type()) ==
"fused_bias_dropout_residual_layer_norm") {
auto vecs = op_desc->Input("LnScale");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("LnBias");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
}
if (backend_ == phi::Backend::XPU) {
......@@ -805,7 +843,9 @@ void AutoMixedPrecisionPass::InsertCastOp() const {
auto op_type = op_node->Op()->Type();
if (GetOpOriginalType(op_type) == "feed") continue;
if (op_node->Op()->HasAttr("sub_block")) continue;
if (op_node->Op()->HasAttr("sub_block") &&
GetOpOriginalType(op_type) != "tensorrt_engine")
continue;
VLOG(4) << "process op: " << op_type
<< " run low precision: " << op_run_low_precision_.count(op_type);
......
......@@ -68,9 +68,9 @@ class AutoMixedPrecisionPass : public FusePassBase {
private:
mutable bool skip_pass_{false};
mutable bool keep_io_types_{true};
mutable bool enable_low_precision_io_{false};
// float16 or bfloat16 now
mutable phi::DataType low_precision_{phi::DataType::FLOAT16};
mutable phi::DataType low_precision_{phi::DataType::UNDEFINED};
mutable phi::Backend backend_{phi::Backend::UNDEFINED};
......
......@@ -19,7 +19,6 @@
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/inference/api/paddle_analysis_config.h"
namespace paddle {
namespace framework {
......
......@@ -19,7 +19,6 @@
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/inference/api/paddle_analysis_config.h"
namespace paddle {
namespace framework {
......
......@@ -34,7 +34,6 @@
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/api/paddle_analysis_config.h"
#include "paddle/phi/common/data_type.h"
......@@ -225,9 +224,7 @@ struct Argument {
DECL_ARGUMENT_FIELD(tensorrt_disabled_ops,
TensorRtDisabledOPs,
std::vector<std::string>);
DECL_ARGUMENT_FIELD(tensorrt_precision_mode,
TensorRtPrecisionMode,
AnalysisConfig::Precision);
DECL_ARGUMENT_FIELD(tensorrt_precision_mode, TensorRtPrecisionMode, int);
DECL_ARGUMENT_FIELD(tensorrt_use_static_engine,
TensorRtUseStaticEngine,
bool);
......@@ -263,9 +260,7 @@ struct Argument {
DlnneDisableNodesByOutputs,
std::unordered_set<std::string>);
DECL_ARGUMENT_FIELD(dlnne_use_calib_mode, DlnneUseCalibMode, bool);
DECL_ARGUMENT_FIELD(dlnne_precision_mode,
DlnnePrecisionMode,
AnalysisConfig::Precision);
DECL_ARGUMENT_FIELD(dlnne_precision_mode, DlnnePrecisionMode, int);
using dlnne_input_shape_type = std::map<std::string, std::vector<int64_t>>;
DECL_ARGUMENT_FIELD(dlnne_input_shape_dict,
......@@ -277,9 +272,7 @@ struct Argument {
LitePassesFilter,
std::vector<std::string>);
DECL_ARGUMENT_FIELD(lite_ops_filter, LiteOpsFilter, std::vector<std::string>);
DECL_ARGUMENT_FIELD(lite_precision_mode,
LitePrecisionMode,
AnalysisConfig::Precision);
DECL_ARGUMENT_FIELD(lite_precision_mode, LitePrecisionMode, int);
DECL_ARGUMENT_FIELD(lite_zero_copy, LiteZeroCopy, bool);
DECL_ARGUMENT_FIELD(use_xpu, UseXpu, bool);
......@@ -372,6 +365,7 @@ struct Argument {
std::unordered_set<std::string>);
DECL_ARGUMENT_FIELD(enable_gpu_mixed, EnableGPUMixed, bool);
DECL_ARGUMENT_FIELD(mixed_precision_mode, MixedPrecisionMode, int);
DECL_ARGUMENT_FIELD(enable_low_precision_io, EnableLowPrecisionIO, bool);
// cinn compiler related
DECL_ARGUMENT_FIELD(use_cinn_compiler, UseCinnCompiler, bool);
......
......@@ -27,6 +27,7 @@
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/analysis/argument.h"
#include "paddle/fluid/string/pretty_log.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/errors.h"
namespace paddle {
......@@ -60,8 +61,9 @@ void IRPassManager::CreatePasses(Argument *argument,
pass->Set("tensorrt_transformer_maskid",
new std::string(argument->tensorrt_transformer_maskid()));
pass->Set("disable_logs", new bool(argument->disable_logs()));
auto precision_mode = argument->tensorrt_precision_mode();
bool enable_int8 = precision_mode == AnalysisConfig::Precision::kInt8;
auto trt_precision_mode = argument->tensorrt_precision_mode();
bool enable_int8 =
trt_precision_mode == static_cast<int>(phi::DataType::INT8);
pass->Set("enable_int8", new bool(enable_int8));
pass->Set("max_input_shape",
new std::map<std::string, std::vector<int>>(
......@@ -104,6 +106,8 @@ void IRPassManager::CreatePasses(Argument *argument,
pass->Set("mixed_precision_mode",
new int(argument->mixed_precision_mode()));
pass->Set("model_precision", new int(argument->model_precision()));
pass->Set("enable_low_precision_io",
new bool(argument->enable_low_precision_io()));
// "use_xpu" is used for passes in subgraphs.
pass->Set("use_xpu", new bool(argument->use_xpu()));
......@@ -161,8 +165,7 @@ void IRPassManager::CreatePasses(Argument *argument,
pass->Set("predictor_id", new int(argument->predictor_id()));
bool use_calib_mode = argument->tensorrt_use_calib_mode();
pass->Set("use_calib_mode", new bool(use_calib_mode));
pass->Set("precision_mode",
new AnalysisConfig::Precision(precision_mode));
pass->Set("trt_precision_mode", new int(trt_precision_mode));
pass->Set("context_memory_sharing",
new bool(argument->trt_engine_memory_sharing()));
pass->Set("use_cuda_graph",
......@@ -242,8 +245,7 @@ void IRPassManager::CreatePasses(Argument *argument,
new std::unordered_set<std::string>(
argument->dlnne_disable_nodes_by_outputs()));
pass->Set("use_calib_mode", new bool(argument->dlnne_use_calib_mode()));
pass->Set("precision_mode",
new AnalysisConfig::Precision(precision_mode));
pass->Set("dlnne_precision_mode", new int(precision_mode));
pass->Set("input_shape_dict",
new std::map<std::string, std::vector<int64_t>>(
argument->dlnne_input_shape_dict()));
......@@ -254,8 +256,8 @@ void IRPassManager::CreatePasses(Argument *argument,
} else if (pass_name == "build_cinn_pass") {
pass->Set("is_inference_stage", new bool(argument->use_cinn_compiler()));
} else if (pass_name == "lite_subgraph_pass") {
bool lite_enable_int8 =
argument->lite_precision_mode() == AnalysisConfig::Precision::kInt8;
bool lite_enable_int8 = argument->lite_precision_mode() ==
static_cast<int>(phi::DataType::INT8);
pass->Set("program",
new framework::ProgramDesc *(&argument->main_program()));
pass->Set("lite_ops_filter",
......
......@@ -572,9 +572,9 @@ void DlnneSubgraphPass::CreateDlnneOp(
// is unique.
auto engine_key = GenerateEngineKey(
input_names_with_id, output_names_with_id, std::to_string(0));
auto precision_mode = Get<AnalysisConfig::Precision>("precision_mode");
auto precision_mode = Get<int>("dlnne_precision_mode");
bool enable_int8 = false;
if (precision_mode == AnalysisConfig::Precision::kInt8) {
if (precision_mode == static_cast<int>(phi::DataType::INT8)) {
enable_int8 = true;
}
auto use_calib_mode = Get<bool>("use_calib_mode");
......
......@@ -21,7 +21,6 @@
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/inference/analysis/ir_passes/subgraph_util.h"
#include "paddle/fluid/inference/api/paddle_analysis_config.h"
namespace paddle {
namespace framework {
......
......@@ -20,10 +20,8 @@
#include <unordered_set>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/ir/subgraph_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
......@@ -386,9 +384,10 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp(
graph_var_map[node->Name()] = node;
}
}
auto precision_mode = Get<AnalysisConfig::Precision>("precision_mode");
auto precision_mode = Get<int>("trt_precision_mode");
bool enable_fp16 = false;
if (precision_mode == AnalysisConfig::Precision::kHalf) enable_fp16 = true;
if (precision_mode == static_cast<int>(phi::DataType::FLOAT16))
enable_fp16 = true;
auto enable_int8 = Get<bool>("enable_int8");
auto use_calib_mode = Get<bool>("use_calib_mode");
auto &subgraph_nodes = *framework::ir::Agent(node).subgraph();
......@@ -526,6 +525,8 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp(
op_desc->SetAttr("use_inspector", Get<bool>("use_inspector"));
op_desc->SetAttr("model_precision", Get<int>("model_precision"));
op_desc->SetAttr("with_dynamic_shape", with_dynamic_shape);
op_desc->SetAttr("enable_low_precision_io",
Get<bool>("enable_low_precision_io"));
// we record all inputs' shapes in attr to check if they are consistent
// with the real inputs' shapes retrieved from scope when trt runs.
......@@ -643,7 +644,7 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp(
.Create(engine_key + std::to_string(predictor_id),
max_batch_size,
Get<int64_t>("workspace_size"),
precision_mode,
static_cast<phi::DataType>(precision_mode),
calibrator.get(),
Get<int>("gpu_device_id"),
with_dynamic_shape,
......@@ -668,6 +669,7 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp(
graph->Has(framework::ir::kEmbEltwiseLayernormPass) &&
graph->Has(framework::ir::kMultiheadMatmulPass));
trt_engine->SetContextMemorySharing(Get<bool>("context_memory_sharing"));
trt_engine->SetLowPrecisionIO(Get<bool>("enable_low_precision_io"));
if (use_static_engine) {
trt_engine_serialized_data = GetTrtEngineSerializedData(
......
......@@ -21,7 +21,7 @@
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/inference/api/paddle_analysis_config.h"
#include "paddle/fluid/inference/analysis/ir_passes/subgraph_util.h"
namespace paddle {
namespace framework {
......
......@@ -100,7 +100,7 @@ void ConvertToMixedPrecisionPass::Run() {
}
pass.Set("mixed_black_list",
new std::unordered_set<std::string>{black_list_});
pass.Set("keep_io_types", new bool{keep_io_types_});
pass.Set("enable_low_precision_io", new bool{!keep_io_types_});
pass.Apply(main_graph_.get());
SaveMixedModel();
......
......@@ -105,17 +105,17 @@ void AnalysisConfig::EnableUseGpu(uint64_t memory_pool_init_size_mb,
memory_pool_init_size_mb_ = memory_pool_init_size_mb;
FLAGS_initial_gpu_memory_in_mb = memory_pool_init_size_mb_;
gpu_device_id_ = device_id;
mixed_precision_mode_ = precision_mode;
if (precision_mode == Precision::kFloat32) {
// default
mixed_precision_mode_ = precision_mode;
} else if (precision_mode == Precision::kHalf ||
precision_mode == Precision::kBf16) {
enable_gpu_mixed_ = true;
mixed_precision_mode_ = precision_mode;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The Paddle-GPU inference currently only supports "
"float32/float16/bfloat16 precision. Please check the parameters "
"you specified in EnableUseGpu or enable_use_gpu function."));
"The GPU inference currently only supports float32/float16/bfloat16 "
"precision. Please check the parameters you specified in EnableUseGpu "
"or enable_use_gpu function."));
}
#else
LOG(ERROR) << "Please use PaddlePaddle with GPU version.";
......@@ -428,6 +428,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
CP_MEMBER(mixed_black_list_);
CP_MEMBER(enable_gpu_mixed_);
CP_MEMBER(mixed_precision_mode_);
CP_MEMBER(enable_low_precision_io_);
CP_MEMBER(enable_memory_optim_);
// TensorRT related.
......@@ -708,14 +709,13 @@ MkldnnQuantizerConfig *AnalysisConfig::mkldnn_quantizer_config() const {
return mkldnn_quantizer_config_.get();
}
void AnalysisConfig::EnableTensorRtEngine(
int64_t workspace_size,
int max_batch_size,
int min_subgraph_size,
AnalysisConfig::Precision precision_mode,
bool use_static,
bool use_calib_mode,
bool use_cuda_graph) {
void AnalysisConfig::EnableTensorRtEngine(int64_t workspace_size,
int max_batch_size,
int min_subgraph_size,
Precision precision_mode,
bool use_static,
bool use_calib_mode,
bool use_cuda_graph) {
#ifdef PADDLE_WITH_TENSORRT
if (!use_gpu()) {
LOG(ERROR) << "To use TensorRT engine, please call EnableUseGpu() first";
......@@ -766,6 +766,16 @@ void AnalysisConfig::EnableTensorRTMemoryOptim(bool engine_memory_sharing,
trt_engine_memory_sharing_identifier_ = sharing_identifier;
}
void AnalysisConfig::EnableLowPrecisionIO(bool x) {
PADDLE_ENFORCE_EQ(
enable_gpu_mixed_ || !x,
true,
platform::errors::InvalidArgument(
"To enable low precision io, please call EnableUseGPU() to specify "
"precision mode as low precision."));
enable_low_precision_io_ = x;
}
void AnalysisConfig::EnableDlnne(
int min_subgraph_size,
int max_batch_size,
......@@ -774,7 +784,7 @@ void AnalysisConfig::EnableDlnne(
std::unordered_set<std::string> disable_nodes_by_ouputs,
std::map<std::string, std::vector<int64_t>> dlnne_input_shape_dict,
bool use_calib_mode,
AnalysisConfig::Precision precision_mode) {
Precision precision_mode) {
use_dlnne_ = true;
dlnne_min_subgraph_size_ = min_subgraph_size;
dlnne_max_batchsize_ = max_batch_size;
......@@ -877,7 +887,7 @@ void AnalysisConfig::Update() {
if (use_tensorrt_) {
pass_builder()->ClearPasses();
for (const auto &pass : kTRTSubgraphPasses) {
if (tensorrt_precision_mode_ == AnalysisConfig::Precision::kInt8 &&
if (tensorrt_precision_mode_ == Precision::kInt8 &&
(pass == "conv_bn_fuse_pass")) {
continue;
}
......@@ -1191,7 +1201,7 @@ void AnalysisConfig::DisableGlogInfo() {
}
void AnalysisConfig::EnableLiteEngine(
AnalysisConfig::Precision precision_mode,
Precision precision_mode,
bool zero_copy,
const std::vector<std::string> &passes_filter,
const std::vector<std::string> &ops_filter) {
......@@ -1258,8 +1268,7 @@ std::string AnalysisConfig::Summary() {
os.InsertRow({"use_tensorrt", use_tensorrt_ ? "true" : "false"});
if (use_tensorrt_) {
#ifdef PADDLE_WITH_TENSORRT
auto Precision2String =
[](paddle::AnalysisConfig::Precision prec) -> std::string {
auto Precision2String = [](Precision prec) -> std::string {
if (prec == Precision::kFloat32)
return "fp32";
else if (prec == Precision::kHalf)
......
......@@ -1371,7 +1371,8 @@ void AnalysisPredictor::PrepareArgument() {
// For JITLayer
argument_->SetSkipLoadParams(config_.skip_load_params_);
argument_->SetTensorRtPrecisionMode(config_.tensorrt_precision_mode_);
argument_->SetTensorRtPrecisionMode(static_cast<int>(
paddle::ConvertPrecision(config_.tensorrt_precision_mode_)));
argument_->SetTensorRtUseOSS(config_.trt_use_varseqlen_);
argument_->SetTensorRtWithInterleaved(config_.trt_with_interleaved_);
argument_->SetTensorRtTransformerPosid(config_.tensorrt_transformer_posid_);
......@@ -1412,14 +1413,16 @@ void AnalysisPredictor::PrepareArgument() {
config_.dlnne_disable_nodes_by_outputs_);
argument_->SetDlnneInputShapeDict(config_.dlnne_input_shape_dict_);
argument_->SetDlnneUseCalibMode(config_.dlnne_use_calib_mode_);
argument_->SetDlnnePrecisionMode(config_.dlnne_precision_mode_);
argument_->SetDlnnePrecisionMode(static_cast<int>(
paddle::ConvertPrecision(config_.dlnne_precision_mode_)));
}
argument_->SetUseXpu(config_.use_xpu_);
if (config_.lite_engine_enabled()) {
argument_->SetCpuMathLibraryNumThreads(
config_.cpu_math_library_num_threads());
argument_->SetLitePrecisionMode(config_.lite_precision_mode_);
argument_->SetLitePrecisionMode(static_cast<int>(
paddle::ConvertPrecision(config_.lite_precision_mode_)));
argument_->SetLitePassesFilter(config_.lite_passes_filter_);
argument_->SetLiteOpsFilter(config_.lite_ops_filter_);
argument_->SetLiteZeroCopy(config_.lite_zero_copy_);
......@@ -1561,18 +1564,18 @@ void AnalysisPredictor::PrepareArgument() {
argument_->SetEnableIrOptim(true);
pass_builder->ClearPasses();
pass_builder->AppendPass("auto_mixed_precision_pass");
LOG(INFO)
<< "This model run in Paddle-GPU mixed precision mode with no ir "
"optimization.";
LOG(INFO) << "This model run in GPU mixed precision mode with no ir "
"optimization.";
} else {
LOG(INFO) << "ir_optim is turned off, no IR pass will be executed.";
LOG(INFO)
<< "Ir optimization is turned off, no ir pass will be executed.";
}
} else {
if (config_.ir_debug_) {
pass_builder->TurnOnDebug();
}
if (config_.enable_gpu_mixed_) {
LOG(INFO) << "This model run in Paddle-GPU mixed precision mode.";
LOG(INFO) << "This model run in GPU mixed precision mode.";
}
}
......@@ -1595,6 +1598,7 @@ void AnalysisPredictor::PrepareArgument() {
argument_->SetEnableGPUMixed(config_.enable_gpu_mixed_);
argument_->SetMixedPrecisionMode(static_cast<int>(
paddle::ConvertPrecision(config_.mixed_precision_mode_)));
argument_->SetEnableLowPrecisionIO(config_.enable_low_precision_io_);
}
// NOTE All the members in AnalysisConfig should be copied to Argument.
......
......@@ -544,6 +544,13 @@ struct PD_INFER_DECL AnalysisConfig {
///
bool use_feed_fetch_ops_enabled() const { return use_feed_fetch_ops_; }
///
/// \brief Turn on the feed and fetch data with low precision.
///
/// \param x Whether to enable feed and fetch data with low precision.
///
void EnableLowPrecisionIO(bool x = true);
///
/// \brief Control whether to specify the inputs' names.
/// The ZeroCopyTensor type has a name member, assign it with the
......@@ -748,6 +755,7 @@ struct PD_INFER_DECL AnalysisConfig {
bool tensorrt_dla_enabled() { return trt_use_dla_; }
void EnableTensorRtInspector();
bool tensorrt_inspector_enabled() { return trt_use_inspector_; }
void EnableDlnne(
......@@ -758,7 +766,8 @@ struct PD_INFER_DECL AnalysisConfig {
std::unordered_set<std::string> disable_nodes_by_outputs = {},
std::map<std::string, std::vector<int64_t>> input_dict = {},
bool use_calib_mode = false,
AnalysisConfig::Precision precision_mode = Precision::kFloat32);
Precision precision_mode = Precision::kFloat32);
bool dlnne_enabled() const { return use_dlnne_; }
///
......@@ -768,11 +777,10 @@ struct PD_INFER_DECL AnalysisConfig {
/// \param passes_filter Set the passes used in Lite sub-graph engine.
/// \param ops_filter Operators not supported by Lite.
///
void EnableLiteEngine(
AnalysisConfig::Precision precision_mode = Precision::kFloat32,
bool zero_copy = false,
const std::vector<std::string>& passes_filter = {},
const std::vector<std::string>& ops_filter = {});
void EnableLiteEngine(Precision precision_mode = Precision::kFloat32,
bool zero_copy = false,
const std::vector<std::string>& passes_filter = {},
const std::vector<std::string>& ops_filter = {});
///
/// \brief Turn on the usage of Lite sub-graph engine with opencl.
......@@ -1066,6 +1074,7 @@ struct PD_INFER_DECL AnalysisConfig {
// Mixed precision related.
Precision mixed_precision_mode_{Precision::kFloat32};
std::unordered_set<std::string> mixed_black_list_;
bool enable_low_precision_io_{false};
// GPU related.
bool use_gpu_{false};
......
......@@ -150,7 +150,8 @@ const std::vector<std::string> kTRTSubgraphPasses({
"conv_elementwise_add2_act_fuse_pass", //
#endif
#endif
"transpose_flatten_concat_fuse_pass",
"transpose_flatten_concat_fuse_pass", //
"auto_mixed_precision_pass",
});
const std::vector<std::string> kDlnneSubgraphPasses({
......
......@@ -278,7 +278,7 @@ void PD_EnableDlnne(
std::unordered_set<std::string> disable_nodes_by_ouputs,
std::map<std::string, std::vector<int64_t>> dlnne_input_shape_dict,
bool use_calib_mode,
AnalysisConfig::Precision precision_mode) {
PD_ACPrecision precision_mode) {
PADDLE_ENFORCE_NOT_NULL(
config,
paddle::platform::errors::InvalidArgument(
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.h"
#include "paddle/phi/common/data_type.h"
namespace paddle {
namespace inference {
......@@ -68,7 +69,7 @@ class CAllReduceOpConverter : public OpConverter {
#if IS_TRT_VERSION_GE(6000)
bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
if (engine_->precision() == AnalysisConfig::Precision::kInt8) {
if (engine_->precision() == phi::DataType::INT8) {
with_fp16 = true;
}
......
......@@ -13,6 +13,7 @@ the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/phi/common/data_type.h"
namespace paddle {
namespace inference {
......@@ -26,7 +27,7 @@ class CrossMultiheadMatMulOpConverter : public OpConverter {
VLOG(3) << "convert a cross_multihead_mamul op to a corresponding tensorrt "
"network structure";
bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
if (engine_->precision() == AnalysisConfig::Precision::kInt8) {
if (engine_->precision() == phi::DataType::INT8) {
with_fp16 = true;
}
PADDLE_ENFORCE_EQ(
......
......@@ -13,6 +13,7 @@ the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/phi/common/data_type.h"
namespace paddle {
namespace inference {
......@@ -27,7 +28,7 @@ class FlashMultiheadMatMulOpConverter : public OpConverter {
"network structure";
bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
if (engine_->precision() == AnalysisConfig::Precision::kInt8) {
if (engine_->precision() == phi::DataType::INT8) {
with_fp16 = true;
}
PADDLE_ENFORCE_EQ(
......
......@@ -52,7 +52,7 @@ class FusedTokenPruneOpConverter : public OpConverter {
bool with_fp16 =
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
if (engine_->precision() == AnalysisConfig::Precision::kInt8) {
if (engine_->precision() == phi::DataType::INT8) {
with_fp16 = true;
}
bool flag_varseqlen = engine_->use_varseqlen();
......
......@@ -10,7 +10,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/matmul_op_int8_plugin.h"
#include "paddle/phi/common/data_type.h"
namespace paddle {
namespace inference {
......@@ -41,8 +41,7 @@ class MatrixMultiplyOpConverter : public OpConverter {
auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
auto* input2 = engine_->GetITensor(op_desc.Input("Y")[0]);
bool enable_int8 =
(engine_->precision() == AnalysisConfig::Precision::kInt8);
bool enable_int8 = (engine_->precision() == phi::DataType::INT8);
float x_scale = 0;
float y_scale = 0;
float out_scale = 0;
......
......@@ -81,7 +81,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
engine_->tensorrt_transformer_maskid() != "";
if (engine_->with_dynamic_shape()) {
if (engine_->tensorrt_transformer_maskid() != "" &&
engine_->precision() != AnalysisConfig::Precision::kFloat32 &&
engine_->precision() != phi::DataType::FLOAT32 &&
platform::GetGPUComputeCapability(platform::GetCurrentDeviceId()) >=
75) {
nvinfer1::Weights weight{nvinfer1::DataType::kFLOAT,
......@@ -406,7 +406,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
assert(creator != nullptr);
int type = static_cast<int>(nvinfer1::DataType::kHALF);
if (qkv2context_plugin_int8 &&
(engine_->precision() == AnalysisConfig::Precision::kInt8)) {
(engine_->precision() == phi::DataType::INT8)) {
type = static_cast<int>(nvinfer1::DataType::kINT8);
}
bool has_mask = true;
......@@ -488,7 +488,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
};
tranpose_weight(weight_data_tmp.data(), weight_data, m, n);
if (input_dims.d[1] <= 384 && !bias_qk_attr &&
engine_->precision() != AnalysisConfig::Precision::kFloat32 &&
engine_->precision() != phi::DataType::FLOAT32 &&
platform::GetGPUComputeCapability(platform::GetCurrentDeviceId()) >=
75) {
/*
......@@ -860,7 +860,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
bool with_fp16 =
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
if (engine_->precision() == AnalysisConfig::Precision::kInt8) {
if (engine_->precision() == phi::DataType::INT8) {
with_fp16 = true;
}
plugin::DynamicPluginTensorRT* plugin =
......
......@@ -178,7 +178,7 @@ class MultiheadMatMulRoformerOpConverter : public OpConverter {
bool with_fp16 =
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
if (engine_->precision() == AnalysisConfig::Precision::kInt8) {
if (engine_->precision() == phi::DataType::INT8) {
with_fp16 = true;
}
plugin::DynamicPluginTensorRT* plugin =
......
......@@ -277,7 +277,7 @@ class OpConverter {
}
}
// The scope here should be inited with the parameter vars.
// The scope here should be inited with the parameter vars.
void ConvertBlockToTRTEngine(
framework::BlockDesc* block_desc,
const framework::Scope& scope,
......@@ -286,9 +286,14 @@ class OpConverter {
const std::vector<std::string>& outputs,
TensorRTEngine* engine) {
engine->InitNetwork();
bool all_dynamic_shape_set = true;
for (auto& input : inputs) {
for (auto input : inputs) {
if (parameters.count(input)) continue;
// NOTE(liuyuanle): It is a trick. If you need a name [input], then you
// need to use [input.substr(0, idx)].
// Maybe we insert suffix of "_cast.tmp_" in auto_mixed_precision_pass.
auto idx = input.find("_cast.tmp_");
input = input.substr(0, idx);
auto* var = block_desc->FindVar(input);
PADDLE_ENFORCE_NOT_NULL(
var,
......@@ -299,6 +304,13 @@ class OpConverter {
FluidDT::VarType_Type_LOD_TENSOR,
platform::errors::InvalidArgument("TensorRT engine only takes "
"LoDTensor as input"));
nvinfer1::DataType in_dtype = FluidDataType2TRT(var->GetDataType());
if (engine->WithFp16() && !engine->WithInt8() &&
in_dtype == nvinfer1::DataType::kFLOAT &&
engine->EnableLowPrecisionIO()) {
in_dtype = nvinfer1::DataType::kHALF;
}
auto var_shape = var->GetShape();
if (engine->with_dynamic_shape()) {
#if IS_TRT_VERSION_GE(6000)
......@@ -306,15 +318,7 @@ class OpConverter {
auto max_input_shape = engine->max_input_shape()[input];
auto optim_input_shape = engine->optim_input_shape()[input];
size_t ranks = min_input_shape.size();
// allow 0 dim for dynamic shape input
// if (ranks == 0) {
// all_dynamic_shape_set = false;
// LOG(INFO) << "trt input [" << input.c_str()
// << "] dynamic shape info not set, please check and
// retry.";
// // check other input
// continue;
// }
std::vector<int64_t> input_shape;
// input_shape.push_back(-1);
for (size_t i = 0; i < ranks; i++) {
......@@ -331,26 +335,14 @@ class OpConverter {
}
}
engine->DeclareInput(
input,
FluidDataType2TRT(
var->Proto()->type().lod_tensor().tensor().data_type()),
Vec2TRT_Dims(input_shape, input, true));
input, in_dtype, Vec2TRT_Dims(input_shape, input, true));
#endif
} else {
engine->DeclareInput(
input,
FluidDataType2TRT(
var->Proto()->type().lod_tensor().tensor().data_type()),
Vec2TRT_Dims(var_shape, input));
VLOG(1) << "Set trt input [" << input << "] type is "
<< var->Proto()->type().lod_tensor().tensor().data_type();
engine->DeclareInput(input, in_dtype, Vec2TRT_Dims(var_shape, input));
}
VLOG(1) << "set trt engine input dtype " << static_cast<int>(in_dtype);
}
PADDLE_ENFORCE_EQ(all_dynamic_shape_set,
true,
platform::errors::InvalidArgument(
"some trt inputs dynamic shape info not set, "
"check the INFO log above for more details."));
framework::proto::BlockDesc* block_proto = block_desc->Proto();
ConvertBlock(*block_proto, parameters, scope, engine);
......@@ -365,12 +357,14 @@ class OpConverter {
FluidDT::VarType_Type_LOD_TENSOR,
platform::errors::InvalidArgument(
"The output tensor in TensorRT subgraph should be LoDTensor"));
engine->DeclareOutput(
output,
FluidDataType2TRT(
var->Proto()->type().lod_tensor().tensor().data_type()));
VLOG(6) << "DeclareOutput(name: " << output << ", dtype: "
<< var->Proto()->type().lod_tensor().tensor().data_type() << ")";
nvinfer1::DataType out_dtype = FluidDataType2TRT(var->GetDataType());
if (engine->WithFp16() && !engine->WithInt8() &&
out_dtype == nvinfer1::DataType::kFLOAT &&
engine->EnableLowPrecisionIO()) {
out_dtype = nvinfer1::DataType::kHALF;
}
engine->DeclareOutput(output, out_dtype);
VLOG(1) << "set trt engine output dtype " << static_cast<int>(out_dtype);
}
engine->FreezeNetwork();
......
......@@ -26,7 +26,7 @@ class PadOpConverter : public OpConverter {
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope,
bool test_mode) override {
VLOG(3) << "convert a transpose op to tensorrt tranpose layer";
VLOG(3) << "convert pad op to tensorrt IPaddingLayer";
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
......
......@@ -59,7 +59,7 @@ class PrelnResidualBiasOpConverter : public OpConverter {
int ele_bias_size = has_bias ? phi::product(ele_bias_dims) : 0;
float epsilon = PADDLE_GET_CONST(float, op_desc.GetAttr("ln_epsilon"));
bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
if (engine_->precision() == AnalysisConfig::Precision::kInt8) {
if (engine_->precision() == phi::DataType::INT8) {
with_fp16 = true;
}
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/utils.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/phi/common/data_type.h"
namespace paddle {
namespace inference {
......@@ -44,8 +45,7 @@ class SkipLayerNormOpConverter : public OpConverter {
auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
auto* input2 = engine_->GetITensor(op_desc.Input("Y")[0]);
bool enable_int8 =
(engine_->precision() == AnalysisConfig::Precision::kInt8);
bool enable_int8 = (engine_->precision() == phi::DataType::INT8);
float x_scale = 0;
float y_scale = 0;
......
......@@ -101,7 +101,7 @@ class SparseMultiheadMatMulOpConverter : public OpConverter {
engine_->tensorrt_transformer_maskid() != "";
if (engine_->with_dynamic_shape()) {
if (flag_varseqlen) {
if (engine_->precision() == AnalysisConfig::Precision::kFloat32) {
if (engine_->precision() == phi::DataType::FLOAT32) {
PADDLE_THROW(platform::errors::Fatal(
"use use_varseqlen must be int8 or half, not float32."));
}
......@@ -258,7 +258,7 @@ class SparseMultiheadMatMulOpConverter : public OpConverter {
assert(creator != nullptr);
int type = static_cast<int>(nvinfer1::DataType::kHALF);
if (qkv2context_plugin_int8 &&
(engine_->precision() == AnalysisConfig::Precision::kInt8)) {
(engine_->precision() == phi::DataType::INT8)) {
type = static_cast<int>(nvinfer1::DataType::kINT8);
}
bool has_mask = true;
......@@ -416,7 +416,7 @@ class SparseMultiheadMatMulOpConverter : public OpConverter {
plugin_inputs.push_back(fc_layer->getOutput(0));
plugin_inputs.push_back(input_bias_qk);
if (engine_->precision() == AnalysisConfig::Precision::kInt8) {
if (engine_->precision() == phi::DataType::INT8) {
with_fp16 = true;
}
plugin::DynamicPluginTensorRT* plugin =
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/test_custom_op_plugin.h"
#include "paddle/phi/api/all.h"
#include "paddle/phi/common/data_type.h"
PD_BUILD_OP(custom_op)
.Inputs({"Input"})
......@@ -174,7 +175,7 @@ TEST(CustomPluginCreater, DynamicShapePlugin) {
engine_.reset(new TensorRTEngine(5,
1 << 15,
AnalysisConfig::Precision::kFloat32,
phi::DataType::FLOAT32,
nullptr,
0,
true,
......
......@@ -49,7 +49,7 @@ class UnaryOpConverter : public OpConverter {
org_type == nvinfer1::DataType::kINT32;
if (cast) {
layer = TRT_ENGINE_ADD_LAYER(engine_, Identity, *input_tensor);
if (engine_->precision() == AnalysisConfig::Precision::kFloat32) {
if (engine_->precision() == phi::DataType::FLOAT32) {
layer->setOutputType(0, nvinfer1::DataType::kFLOAT);
} else {
layer->setOutputType(0, nvinfer1::DataType::kHALF);
......
......@@ -207,7 +207,7 @@ void TensorRTEngine::FreezeNetwork() {
infer_builder_config_->setMaxWorkspaceSize(max_workspace_);
#endif
bool enable_fp16 = (precision_ == AnalysisConfig::Precision::kHalf);
bool enable_fp16 = (precision_ == phi::DataType::FLOAT16);
if (enable_fp16) {
bool support_fp16 = infer_builder_->platformHasFastFp16();
infer_builder_config_->setFlag(nvinfer1::BuilderFlag::kFP16);
......@@ -219,7 +219,7 @@ void TensorRTEngine::FreezeNetwork() {
}
}
bool enable_int8 = (precision_ == AnalysisConfig::Precision::kInt8);
bool enable_int8 = (precision_ == phi::DataType::INT8);
if (enable_int8) {
if (!use_dla_) {
infer_builder_config_->setFlag(nvinfer1::BuilderFlag::kFP16);
......@@ -562,8 +562,8 @@ void TensorRTEngine::Deserialize(const std::string &engine_serialized_data) {
infer_ptr<nvinfer1::IRuntime> runtime(createInferRuntime(&logger_));
if (use_dla_) {
if (precision_ != AnalysisConfig::Precision::kInt8 &&
precision_ != AnalysisConfig::Precision::kHalf) {
if (precision_ != phi::DataType::INT8 &&
precision_ != phi::DataType::FLOAT16) {
LOG(WARNING) << "TensorRT DLA must be used with int8 or fp16, but you "
"set float32, so DLA is not used.";
} else if (runtime->getNbDLACores() == 0) {
......
......@@ -30,7 +30,6 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/api/paddle_analysis_config.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"
......@@ -276,22 +275,21 @@ class TensorRTEngine {
nvinfer1::Weights w_;
};
TensorRTEngine(
int max_batch,
int64_t max_workspace,
AnalysisConfig::Precision precision = AnalysisConfig::Precision::kFloat32,
TRTInt8Calibrator* calibrator = nullptr,
int device_id = 0,
bool with_dynamic_shape = false,
const ShapeMapType min_input_shape = {},
const ShapeMapType max_input_shape = {},
const ShapeMapType optim_input_shape = {},
const ShapeMapType min_shape_tensor = {},
const ShapeMapType max_shape_tensor = {},
const ShapeMapType optim_shape_tensor = {},
bool disable_trt_plugin_fp16 = false,
phi::DataType model_precision = phi::DataType::FLOAT32,
nvinfer1::ILogger& logger = NaiveLogger::Global())
TensorRTEngine(int max_batch,
int64_t max_workspace,
phi::DataType precision = phi::DataType::FLOAT32,
TRTInt8Calibrator* calibrator = nullptr,
int device_id = 0,
bool with_dynamic_shape = false,
const ShapeMapType& min_input_shape = {},
const ShapeMapType& max_input_shape = {},
const ShapeMapType& optim_input_shape = {},
const ShapeMapType& min_shape_tensor = {},
const ShapeMapType& max_shape_tensor = {},
const ShapeMapType& optim_shape_tensor = {},
bool disable_trt_plugin_fp16 = false,
phi::DataType model_precision = phi::DataType::FLOAT32,
nvinfer1::ILogger& logger = NaiveLogger::Global())
: max_batch_(max_batch),
max_workspace_(max_workspace),
precision_(precision),
......@@ -395,7 +393,7 @@ class TensorRTEngine {
int GetRuntimeBatch();
bool WithFp16() {
bool enable_fp16 = (precision_ == AnalysisConfig::Precision::kHalf);
bool enable_fp16 = (precision_ == phi::DataType::FLOAT16);
bool support_fp16 = infer_builder_->platformHasFastFp16();
// below is consistent with setFlag in engine.cc
bool fall_back_fp16 = WithInt8() && !use_dla_;
......@@ -403,7 +401,7 @@ class TensorRTEngine {
}
bool WithInt8() {
bool enable_int8 = (precision_ == AnalysisConfig::Precision::kInt8);
bool enable_int8 = (precision_ == phi::DataType::INT8);
bool support_int8 = infer_builder_->platformHasFastInt8();
return enable_int8 && support_int8;
}
......@@ -509,12 +507,12 @@ class TensorRTEngine {
nvinfer1::INetworkDefinition* network() { return infer_network_.get(); }
ShapeMapType min_input_shape() { return min_input_shape_; }
ShapeMapType max_input_shape() { return max_input_shape_; }
ShapeMapType optim_input_shape() { return optim_input_shape_; }
ShapeMapType min_shape_tensor() { return min_shape_tensor_; }
ShapeMapType max_shape_tensor() { return max_shape_tensor_; }
ShapeMapType optim_shape_tensor() { return optim_shape_tensor_; }
ShapeMapType& min_input_shape() { return min_input_shape_; }
ShapeMapType& max_input_shape() { return max_input_shape_; }
ShapeMapType& optim_input_shape() { return optim_input_shape_; }
ShapeMapType& min_shape_tensor() { return min_shape_tensor_; }
ShapeMapType& max_shape_tensor() { return max_shape_tensor_; }
ShapeMapType& optim_shape_tensor() { return optim_shape_tensor_; }
bool AdjustDynamicShapeRange(const ShapeMapType& runtime_input_shape,
const ShapeMapType& runtime_shape_tensor,
......@@ -641,7 +639,7 @@ class TensorRTEngine {
}
bool disable_trt_plugin_fp16() { return disable_trt_plugin_fp16_; }
bool with_dynamic_shape() { return with_dynamic_shape_; }
AnalysisConfig::Precision precision() { return precision_; }
phi::DataType precision() { return precision_; }
#if IS_TRT_VERSION_GE(6000)
nvinfer1::IPluginV2Layer* AddDynamicPlugin(
......@@ -744,6 +742,12 @@ class TensorRTEngine {
context_memory_sharing_ = context_memory_sharing;
}
void SetLowPrecisionIO(bool low_precision_io) {
low_precision_io_ = low_precision_io;
}
bool EnableLowPrecisionIO() const { return low_precision_io_; }
void SetAllNodesLowerToTrt(bool all_nodes_offload_to_trt) {
// all nodes are in trt, so we can use cudaGraph to optimize runtime.
startup_with_cudagraph_ = all_nodes_offload_to_trt;
......@@ -764,7 +768,7 @@ class TensorRTEngine {
// the max memory size the engine uses
int64_t max_workspace_;
AnalysisConfig::Precision precision_;
phi::DataType precision_;
TRTInt8Calibrator* calibrator_;
// batch size of the current data, will be updated each Executation.
int batch_size_{-1};
......@@ -772,6 +776,8 @@ class TensorRTEngine {
// use for engine context memory sharing
bool context_memory_sharing_{false};
bool low_precision_io_{false};
int device_id_;
int max_profile_num_{1};
int cur_profile_num_{0};
......@@ -878,7 +884,7 @@ class TRTEngineManager {
std::string name,
int max_batch,
int64_t max_workspace,
AnalysisConfig::Precision precision = AnalysisConfig::Precision::kFloat32,
phi::DataType precision = phi::DataType::FLOAT32,
TRTInt8Calibrator* calibrator = nullptr,
int device_id = 0,
bool with_dynamic_shape = false,
......
......@@ -354,10 +354,10 @@ nvinfer1::DataType Pool3DPluginDynamic::getOutputDataType(
"The Pool3D Plugin only has one input, so the "
"index value should be 0, but get %d.",
index));
PADDLE_ENFORCE_EQ((input_types[0] == nvinfer1::DataType::kFLOAT),
true,
platform::errors::InvalidArgument(
"The input type should be half or float"));
PADDLE_ENFORCE_EQ(
(input_types[0] == nvinfer1::DataType::kFLOAT),
true,
platform::errors::InvalidArgument("The input type should be float"));
return input_types[0];
}
......
......@@ -285,10 +285,10 @@ nvinfer1::DataType PoolPluginDynamic::getOutputDataType(
"The Pool Plugin only has one input, so the "
"index value should be 0, but get %d.",
index));
PADDLE_ENFORCE_EQ((input_types[0] == nvinfer1::DataType::kFLOAT),
true,
platform::errors::InvalidArgument(
"The input type should be half or float"));
PADDLE_ENFORCE_EQ(
(input_types[0] == nvinfer1::DataType::kFLOAT),
true,
platform::errors::InvalidArgument("The input type should be float"));
return input_types[0];
}
......
......@@ -67,7 +67,7 @@ class TensorRTDynamicShapeValueEngineTest : public ::testing::Test {
{"shape", {18, 8, 4}}};
engine_ = new TensorRTEngine(16,
1 << 10,
AnalysisConfig::Precision::kFloat32,
phi::DataType::FLOAT32,
nullptr,
0,
true,
......@@ -194,7 +194,7 @@ class TensorRTDynamicEngineTest : public ::testing::Test {
engine_ = new TensorRTEngine(16,
1 << 10,
AnalysisConfig::Precision::kHalf,
phi::DataType::FLOAT16,
nullptr,
0,
true,
......@@ -372,7 +372,7 @@ class TensorRTDynamicTestFusedTokenPrune : public ::testing::Test {
engine_ = new TensorRTEngine(16,
1 << 10,
AnalysisConfig::Precision::kFloat32,
phi::DataType::FLOAT32,
nullptr,
0,
true,
......@@ -581,7 +581,7 @@ class TensorRTDynamicTestFusedTokenPruneHalf : public ::testing::Test {
engine_ = new TensorRTEngine(16,
1 << 10,
AnalysisConfig::Precision::kHalf,
phi::DataType::FLOAT16,
nullptr,
0,
true,
......@@ -784,7 +784,7 @@ class TensorRTDynamicShapeGNTest : public ::testing::Test {
engine_ = new TensorRTEngine(16,
1 << 10,
AnalysisConfig::Precision::kInt8,
phi::DataType::INT8,
nullptr,
0,
true,
......
......@@ -35,7 +35,7 @@ namespace tensorrt {
class TensorRTEngine;
struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator2 {
class TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator2 {
public:
TRTInt8Calibrator(const std::unordered_map<std::string, size_t>& buffers,
int batch_size,
......
......@@ -14,7 +14,6 @@
#pragma once
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/place.h"
......@@ -171,7 +170,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
std::string shape_range_info_path_;
std::string model_opt_cache_dir_;
bool use_static_engine_;
AnalysisConfig::Precision precision_mode_;
phi::DataType precision_mode_;
std::map<std::string, std::vector<int>> min_input_shape_{};
std::map<std::string, std::vector<int>> max_input_shape_{};
std::map<std::string, std::vector<int>> opt_input_shape_{};
......@@ -265,12 +264,12 @@ class TensorRTEngineOp : public framework::OperatorBase {
inference::Singleton<inference::tensorrt::TRTEngineManager>::Global()
.Get(engine_key_ + std::to_string(predictor_id_));
}
precision_mode_ = AnalysisConfig::Precision::kFloat32;
precision_mode_ = phi::DataType::FLOAT32;
if (enable_int8_) {
precision_mode_ = AnalysisConfig::Precision::kInt8;
precision_mode_ = phi::DataType::INT8;
}
if (enable_fp16_) {
precision_mode_ = AnalysisConfig::Precision::kHalf;
precision_mode_ = phi::DataType::FLOAT16;
}
}
......@@ -314,8 +313,15 @@ class TensorRTEngineOp : public framework::OperatorBase {
std::map<std::string, std::vector<int32_t>> runtime_input_shape;
std::map<std::string, std::vector<int32_t>> runtime_shape_tensor;
for (auto name : runtime_input_names_) {
auto &t =
inference::analysis::GetFromScope<phi::DenseTensor>(scope, name);
// NOTE(liuyuanle): It is a trick. If you need a [name], then you need
// to use [name.substr(0, idx)].
// Maybe we insert suffix of "_cast.tmp_" in auto_mixed_precision_pass.
std::string name_real = name;
auto idx = name.find("_cast.tmp_");
name = name.substr(0, idx);
auto &t = inference::analysis::GetFromScope<phi::DenseTensor>(
scope, name_real);
VLOG(4) << "trt engine runtime input name(" << name << "), dims("
<< t.dims() << ")";
auto t_shape = phi::vectorize<int32_t>(t.dims());
......@@ -378,7 +384,14 @@ class TensorRTEngineOp : public framework::OperatorBase {
trt_engine->min_input_shape();
std::map<std::string, std::vector<int>> max_input_shape =
trt_engine->max_input_shape();
for (auto &x : runtime_input_names_) {
for (auto x : runtime_input_names_) {
// NOTE(liuyuanle): It is a trick. If you need a [x], then you need
// to use [x.substr(0, idx)].
// Maybe we insert suffix of "_cast.tmp_" in
// auto_mixed_precision_pass.
auto idx = x.find("_cast.tmp_");
x = x.substr(0, idx);
PADDLE_ENFORCE_EQ(
min_input_shape.count(x),
true,
......@@ -544,14 +557,22 @@ class TensorRTEngineOp : public framework::OperatorBase {
binding_offset = engine->GetBindingsOffset();
}
// Bind input tensor to TRT.
for (const auto &x : runtime_input_names_) {
for (auto x : runtime_input_names_) {
// NOTE(liuyuanle): It is a trick. If you need a [x], then you need
// to use [x.substr(0, idx)].
// Maybe we insert suffix of "_cast.tmp_" in auto_mixed_precision_pass.
std::string x_real = x;
auto idx = x.find("_cast.tmp_");
x = x.substr(0, idx);
#if IS_TRT_VERSION_LT(8000)
// trt may remove input tensor if it's unused or used only at compile-time
if (engine->engine()->getBindingIndex(x.c_str()) < 0) continue;
#endif
// convert input and copy to TRT engine's buffer
auto &t = inference::analysis::GetFromScope<phi::DenseTensor>(scope, x);
auto &t =
inference::analysis::GetFromScope<phi::DenseTensor>(scope, x_real);
PADDLE_ENFORCE_GT(
t.numel(),
0,
......@@ -571,7 +592,6 @@ class TensorRTEngineOp : public framework::OperatorBase {
t.ShareDataWith(out);
}
auto t_shape = phi::vectorize<int64_t>(t.dims());
// const int bind_index = engine->engine()->getBindingIndex(x.c_str());
// Get index of profile 0 first, then plus binding offset
const int bind_index =
engine->engine()->getBindingIndex(x.c_str()) + binding_offset;
......@@ -585,7 +605,6 @@ class TensorRTEngineOp : public framework::OperatorBase {
"index=%d >= total inputs and outputs=%d",
bind_index,
num_bindings));
auto type = framework::TransToProtoVarType(t.dtype());
if (!engine->with_dynamic_shape()) {
// check if the input shapes are consistent with model.
if (HasAttr(x + "_shape")) {
......@@ -628,14 +647,14 @@ class TensorRTEngineOp : public framework::OperatorBase {
if (engine->engine()->isShapeBinding(bind_index) &&
engine->engine()->bindingIsInput(bind_index)) {
std::vector<int> shape_v(t.numel());
if (type == framework::proto::VarType::INT32) {
if (t.dtype() == phi::DataType::INT32) {
paddle::memory::Copy(platform::CPUPlace(),
shape_v.data(),
t.place(),
t.data<int32_t>(),
t.numel() * sizeof(int),
nullptr);
} else if (type == framework::proto::VarType::INT64) {
} else if (t.dtype() == phi::DataType::INT64) {
auto int32_tensor = scope.FindVar(x + "_cast_to_INT32")
->GetMutable<phi::DenseTensor>();
*int32_tensor = phi::Cast<int64_t>(
......@@ -662,12 +681,14 @@ class TensorRTEngineOp : public framework::OperatorBase {
PADDLE_ENFORCE_EQ(indata_type,
intrt_type,
platform::errors::InvalidArgument(
"The TRT Engine OP's input type should equal "
"to the input data type"));
"The TRT Engine OP's input type [%d] should equal "
"to the input data type [%d].",
static_cast<int>(intrt_type),
static_cast<int>(indata_type)));
if (type == framework::proto::VarType::FP32) {
if (t.dtype() == phi::DataType::FLOAT32) {
buffers[bind_index] = static_cast<void *>(t.data<float>());
} else if (type == framework::proto::VarType::INT64) {
} else if (t.dtype() == phi::DataType::INT64) {
auto int32_tensor =
scope.FindVar(x + "_cast_to_INT32")->GetMutable<phi::DenseTensor>();
*int32_tensor = phi::Cast<int64_t>(
......@@ -676,12 +697,12 @@ class TensorRTEngineOp : public framework::OperatorBase {
phi::DataType::INT32);
buffers[bind_index] =
static_cast<void *>(int32_tensor->data<int32_t>());
} else if (type == framework::proto::VarType::INT32) {
} else if (t.dtype() == phi::DataType::INT32) {
buffers[bind_index] = static_cast<void *>(t.data<int32_t>());
} else if (type == framework::proto::VarType::FP16) {
} else if (t.dtype() == phi::DataType::FLOAT16) {
buffers[bind_index] = static_cast<void *>(t.data<float16>());
#if IS_TRT_VERSION_GE(8400)
} else if (type == framework::proto::VarType::BOOL) {
} else if (t.dtype() == phi::DataType::BOOL) {
buffers[bind_index] = static_cast<void *>(t.data<bool>());
#endif
} else {
......
......@@ -833,6 +833,9 @@ void BindAnalysisConfig(py::module *m) {
&AnalysisConfig::SwitchSpecifyInputNames,
py::arg("x") = true)
.def("specify_input_name", &AnalysisConfig::specify_input_name)
.def("enable_low_precision_io",
&AnalysisConfig::EnableLowPrecisionIO,
py::arg("x") = true)
.def("enable_tensorrt_engine",
&AnalysisConfig::EnableTensorRtEngine,
py::arg("workspace_size") = 1 << 30,
......
......@@ -174,6 +174,7 @@ if(WITH_GPU AND TENSORRT_FOUND)
set_tests_properties(test_trt_dynamic_shape PROPERTIES TIMEOUT 120)
set_tests_properties(test_trt_inspector PROPERTIES TIMEOUT 60)
set_tests_properties(test_trt_inference_predictor PROPERTIES TIMEOUT 60)
set_tests_properties(test_trt_inference_fp16_io PROPERTIES TIMEOUT 300)
if(WITH_NV_JETSON)
set_tests_properties(
......
......@@ -100,6 +100,7 @@ class TrtConvertFlattenTest_dim_2(TrtLayerAutoScanTest):
# for static_shape
clear_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False
), 1e-5
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import unittest
import numpy as np
import paddle
from paddle.inference import Config, PrecisionType, create_predictor
from paddle.jit import to_static
from paddle.static import InputSpec
from paddle.vision.models import alexnet
class TestEnableLowPrecisionIO:
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
net = alexnet(True)
model = to_static(
net, input_spec=[InputSpec(shape=[None, 3, 224, 224], name='x')]
)
paddle.jit.save(
model, os.path.join(self.temp_dir.name, 'alexnet/inference')
)
def tearDown(self):
self.temp_dir.cleanup()
def get_fp32_output(self):
predictor = self.init_predictor(low_precision_io=False)
inputs = [
paddle.to_tensor(0.1 * np.ones([1, 3, 224, 224]).astype(np.float32))
]
outputs = predictor.run(inputs)
return outputs[0]
def get_fp16_output(self):
predictor = self.init_predictor(low_precision_io=True)
inputs = [
paddle.to_tensor(0.1 * np.ones([1, 3, 224, 224]).astype(np.float16))
]
outputs = predictor.run(inputs)
return outputs[0]
def test_output(self):
if paddle.is_compiled_with_cuda():
fp32_output = self.get_fp32_output()
fp16_output = self.get_fp16_output()
# np.testing.assert_allclose(
# fp32_output.numpy().flatten(),
# fp16_output.numpy().flatten(),
# )
class TestEnableLowPrecisionIOWithGPU(
TestEnableLowPrecisionIO, unittest.TestCase
):
def init_predictor(self, low_precision_io: bool):
config = Config(
os.path.join(self.temp_dir.name, 'alexnet/inference.pdmodel'),
os.path.join(self.temp_dir.name, 'alexnet/inference.pdiparams'),
)
config.enable_use_gpu(256, 0, PrecisionType.Half)
config.enable_memory_optim()
config.enable_low_precision_io(low_precision_io)
config.disable_glog_info()
predictor = create_predictor(config)
return predictor
class TestEnableLowPrecisionIOWithTRTAllGraph(
TestEnableLowPrecisionIO, unittest.TestCase
):
def init_predictor(self, low_precision_io: bool):
config = Config(
os.path.join(self.temp_dir.name, 'alexnet/inference.pdmodel'),
os.path.join(self.temp_dir.name, 'alexnet/inference.pdiparams'),
)
config.enable_use_gpu(256, 0, PrecisionType.Half)
config.enable_tensorrt_engine(
workspace_size=1 << 30,
max_batch_size=1,
min_subgraph_size=3,
precision_mode=PrecisionType.Half,
use_static=False,
use_calib_mode=False,
)
config.enable_memory_optim()
config.enable_low_precision_io(low_precision_io)
config.disable_glog_info()
predictor = create_predictor(config)
return predictor
class TestEnableLowPrecisionIOWithTRTSubGraph(
TestEnableLowPrecisionIO, unittest.TestCase
):
def init_predictor(self, low_precision_io: bool):
config = Config(
os.path.join(self.temp_dir.name, 'alexnet/inference.pdmodel'),
os.path.join(self.temp_dir.name, 'alexnet/inference.pdiparams'),
)
config.enable_use_gpu(256, 0, PrecisionType.Half)
config.enable_tensorrt_engine(
workspace_size=1 << 30,
max_batch_size=1,
min_subgraph_size=3,
precision_mode=PrecisionType.Half,
use_static=False,
use_calib_mode=False,
)
config.enable_memory_optim()
config.enable_low_precision_io(low_precision_io)
config.exp_disable_tensorrt_ops(["flatten_contiguous_range"])
config.disable_glog_info()
predictor = create_predictor(config)
return predictor
if __name__ == '__main__':
unittest.main()