Unverified commit 574f3402, authored by Wangzheee, committed by GitHub

[Paddle-Inference] fix pass and convert_op for preln_ernie (#39733)

* fix pass and convert_op for preln_ernie and add preln_ernie flags in pass
Parent 5595fdbb
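The new checks added in both PreLN fuse passes bail out unless enable_int8, use_oss, with_interleaved and with_dynamic_shape are all set by the inference config. As a hedged illustration (not part of this commit), a config under which those flags would all hold might look roughly like the sketch below; model paths, tensor names, shapes and sizes are placeholders, and the interleaved-format flag is an internal setting in this Paddle version, so no public setter is shown for it.

```cpp
#include <map>
#include <string>
#include <vector>

#include "paddle_inference_api.h"

// Sketch only: paths, tensor names and shape values are illustrative.
paddle_infer::Config MakePrelnErnieTrtConfig() {
  paddle_infer::Config config;
  config.SetModel("ernie/model.pdmodel", "ernie/model.pdiparams");
  config.EnableUseGpu(1000 /*MB*/, 0 /*device id*/);

  // enable_int8: run the TensorRT subgraph in INT8 precision.
  config.EnableTensorRtEngine(1 << 30 /*workspace*/, 1 /*max batch*/,
                              5 /*min subgraph size*/,
                              paddle_infer::PrecisionType::kInt8,
                              false /*use_static*/, false /*use_calib_mode*/);

  // use_oss: varseqlen (OSS) plugins for ERNIE-style models.
  config.EnableTensorRtOSS();

  // with_dynamic_shape: explicit min/max/opt shapes ...
  std::map<std::string, std::vector<int>> min_shape{{"read_file_0.tmp_0", {1, 1}}};
  std::map<std::string, std::vector<int>> max_shape{{"read_file_0.tmp_0", {4, 128}}};
  std::map<std::string, std::vector<int>> opt_shape{{"read_file_0.tmp_0", {2, 64}}};
  config.SetTRTDynamicShapeInfo(min_shape, max_shape, opt_shape);
  // ... or, alternatively, tuned dynamic shape collected offline:
  // config.EnableTunedTensorRtDynamicShape("shape_range_info.pbtxt", true);

  // with_interleaved: toggled through an internal/experimental helper in this
  // Paddle version rather than a public Config method, so it is only noted here.
  return config;
}
```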
......@@ -428,6 +428,19 @@ PrelnEmbeddingEltwiseLayerNormFusePass::
void PrelnEmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
bool enable_int8 = Get<bool>("enable_int8");
bool use_oss = Get<bool>("use_oss");
bool with_interleaved = Get<bool>("with_interleaved");
bool with_dynamic_shape = Get<bool>("with_dynamic_shape");
if (!(enable_int8 && use_oss && with_interleaved && with_dynamic_shape)) {
VLOG(4) << "preln_embedding_eltwise_layernorm_fuse_pass need: use_trt, "
"enable_int8, "
"use_oss, with_interleaved, with_dynamic_shape. Stop this pass, "
"please reconfig.";
return;
}
int fusion_count =
PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(graph, name_scope_);
if (fusion_count > 0) {
......
......@@ -39,7 +39,6 @@ struct PrelnSkipLayerNorm : public PatternBase {
void operator()(PDNode *x, PDNode *y);
// declare operator node's name
PATTERN_DECL_NODE(fused_skipe_layernorm);
PATTERN_DECL_NODE(elementwise);
PATTERN_DECL_NODE(layer_norm);
// declare variable node's name
......@@ -62,8 +61,13 @@ void PrelnSkipLayerNorm::operator()(PDNode *x, PDNode *y) {
auto *elementwise_out_var = pattern->NewNode(elementwise_out_repr())
->assert_is_op_output("elementwise_add")
->assert_is_op_input("layer_norm", "X")
->assert_is_op_input("elementwise_add", "Y");
->assert_more([](Node *x) {
if (x->outputs.size() == 2) {
return true;
} else {
return false;
}
});
// Add links for elementwise_add op.
elementwise->LinksFrom({x, y}).LinksTo({elementwise_out_var});
......@@ -104,6 +108,18 @@ void PrelnSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
FusePassBase::Init("preln_skip_layernorm_fuse", graph);
bool enable_int8 = Get<bool>("enable_int8");
bool use_oss = Get<bool>("use_oss");
bool with_interleaved = Get<bool>("with_interleaved");
bool with_dynamic_shape = Get<bool>("with_dynamic_shape");
if (!(enable_int8 && use_oss && with_interleaved && with_dynamic_shape)) {
VLOG(4) << "preln_skip_layernorm_fuse_pass need: use_trt, enable_int8, "
"use_oss, "
"with_interleaved, with_dynamic_shape. Stop this pass, please "
"reconfig. ";
return;
}
int found_subgraph_count = 0;
GraphPatternDetector gpd;
......
......@@ -39,7 +39,6 @@ struct SkipLayerNorm : public PatternBase {
PDNode *operator()(PDNode *x, PDNode *y);
// declare operator node's name
PATTERN_DECL_NODE(fused_skipe_layernorm);
PATTERN_DECL_NODE(elementwise);
PATTERN_DECL_NODE(layer_norm);
// declare variable node's name
......@@ -59,9 +58,10 @@ PDNode *SkipLayerNorm::operator()(PDNode *x, PDNode *y) {
y->assert_is_op_input("elementwise_add", "Y");
auto *elementwise =
pattern->NewNode(elementwise_repr())->assert_is_op("elementwise_add");
auto *elementwise_out_var = pattern->NewNode(elementwise_out_repr())
->AsOutput()
->assert_is_op_output("elementwise_add");
auto *elementwise_out_var =
pattern->NewNode(elementwise_out_repr())
->AsOutput()
->assert_is_only_output_of_op("elementwise_add");
// Add links for elementwise_add op.
elementwise->LinksFrom({x, y}).LinksTo({elementwise_out_var});
......
......@@ -54,6 +54,27 @@ void IRPassManager::CreatePasses(Argument *argument,
int pass_num = 0;
for (const std::string &pass_name : passes) {
auto pass = framework::ir::PassRegistry::Instance().Get(pass_name);
pass->Set("use_oss", new bool(argument->tensorrt_use_oss()));
pass->Set("with_interleaved",
new bool(argument->tensorrt_with_interleaved()));
pass->Set("disable_logs", new bool(argument->disable_logs()));
auto precision_mode = argument->tensorrt_precision_mode();
bool enable_int8 = precision_mode == AnalysisConfig::Precision::kInt8;
pass->Set("enable_int8", new bool(enable_int8));
pass->Set("max_input_shape", new std::map<std::string, std::vector<int>>(
argument->max_input_shape()));
pass->Set("min_input_shape", new std::map<std::string, std::vector<int>>(
argument->min_input_shape()));
pass->Set("optim_input_shape", new std::map<std::string, std::vector<int>>(
argument->optim_input_shape()));
// tuned trt dynamic_shape
pass->Set("trt_tuned_dynamic_shape",
new bool(argument->tensorrt_tuned_dynamic_shape()));
bool with_dynamic_shape = (argument->max_input_shape().size() > 0 &&
argument->min_input_shape().size() > 0 &&
argument->optim_input_shape().size() > 0) ||
argument->tensorrt_tuned_dynamic_shape();
pass->Set("with_dynamic_shape", new bool(with_dynamic_shape));
if (pass_name == "graph_viz_pass") {
std::string optim_cache_dir = argument->optim_cache_dir();
......@@ -99,17 +120,9 @@ void IRPassManager::CreatePasses(Argument *argument,
new int(argument->tensorrt_min_subgraph_size()));
pass->Set("program",
new framework::ProgramDesc *(&argument->main_program()));
auto precision_mode = argument->tensorrt_precision_mode();
bool enable_int8 = precision_mode == AnalysisConfig::Precision::kInt8;
pass->Set("predictor_id", new int(argument->predictor_id()));
bool use_calib_mode = argument->tensorrt_use_calib_mode();
pass->Set("enable_int8", new bool(enable_int8));
pass->Set("use_calib_mode", new bool(use_calib_mode));
pass->Set("use_oss", new bool(argument->tensorrt_use_oss()));
pass->Set("with_interleaved",
new bool(argument->tensorrt_with_interleaved()));
pass->Set("precision_mode",
new AnalysisConfig::Precision(precision_mode));
......@@ -161,22 +174,8 @@ void IRPassManager::CreatePasses(Argument *argument,
// tuned trt dynamic_shape
pass->Set("trt_shape_range_info_path",
new std::string(argument->tensorrt_shape_range_info_path()));
pass->Set("trt_tuned_dynamic_shape",
new bool(argument->tensorrt_tuned_dynamic_shape()));
pass->Set("trt_allow_build_at_runtime",
new bool(argument->tensorrt_allow_build_at_runtime()));
pass->Set("max_input_shape", new std::map<std::string, std::vector<int>>(
argument->max_input_shape()));
pass->Set("min_input_shape", new std::map<std::string, std::vector<int>>(
argument->min_input_shape()));
pass->Set("optim_input_shape",
new std::map<std::string, std::vector<int>>(
argument->optim_input_shape()));
bool with_dynamic_shape = (argument->max_input_shape().size() > 0 &&
argument->min_input_shape().size() > 0 &&
argument->optim_input_shape().size() > 0) ||
argument->tensorrt_tuned_dynamic_shape();
pass->Set("with_dynamic_shape", new bool(with_dynamic_shape));
pass->Set("trt_disabled_ops", new std::vector<std::string>(
argument->tensorrt_disabled_ops()));
pass->Set("trt_use_dla", new bool(argument->tensorrt_use_dla()));
......@@ -192,14 +191,15 @@ void IRPassManager::CreatePasses(Argument *argument,
new framework::ProgramDesc *(&argument->main_program()));
}
if (pass_name == "lite_subgraph_pass") {
bool enable_int8 =
bool lite_enable_int8 =
argument->lite_precision_mode() == AnalysisConfig::Precision::kInt8;
pass->Set("program",
new framework::ProgramDesc *(&argument->main_program()));
pass->Set("lite_ops_filter",
new std::vector<std::string>(argument->lite_ops_filter()));
pass->Set("predictor_id", new int(argument->predictor_id()));
pass->Set("enable_int8", new bool(enable_int8));
pass->Erase("enable_int8");
pass->Set("enable_int8", new bool(lite_enable_int8));
pass->Set("use_gpu", new bool(argument->use_gpu()));
pass->Set("zero_copy", new bool(argument->lite_zero_copy()));
pass->Set("use_xpu", new bool(argument->use_xpu()));
......@@ -236,7 +236,6 @@ void IRPassManager::CreatePasses(Argument *argument,
new std::vector<std::string>(
argument->nnadapter_model_cache_token()));
}
disable_logs_ = argument->disable_logs();
if (pass_name == "fc_fuse_pass") {
pass->Set("use_gpu", new bool(argument->use_gpu()));
bool fc_mkldnn_pass = 0;
......@@ -248,9 +247,6 @@ void IRPassManager::CreatePasses(Argument *argument,
bool use_fc_padding = !fc_mkldnn_pass && argument->use_fc_padding();
pass->Set("use_fc_padding", new bool(use_fc_padding));
}
pass->Set("disable_logs", new bool(disable_logs_));
pre_pass = pass_name;
passes_.emplace_back(std::move(pass));
......
......@@ -592,6 +592,14 @@ void AnalysisPredictor::PrepareArgument() {
argument_.SetModelParamsPath(config_.params_file());
}
argument_.SetTensorRtPrecisionMode(config_.tensorrt_precision_mode_);
argument_.SetTensorRtUseOSS(config_.trt_use_oss_);
argument_.SetTensorRtWithInterleaved(config_.trt_with_interleaved_);
argument_.SetMinInputShape(config_.min_input_shape_);
argument_.SetMaxInputShape(config_.max_input_shape_);
argument_.SetOptimInputShape(config_.optim_input_shape_);
argument_.SetTensorRtTunedDynamicShape(
config_.tuned_tensorrt_dynamic_shape());
if (config_.use_gpu() && config_.tensorrt_engine_enabled()) {
LOG(INFO) << "TensorRT subgraph engine is enabled";
argument_.SetUseTensorRT(true);
......@@ -601,18 +609,10 @@ void AnalysisPredictor::PrepareArgument() {
argument_.SetTensorRtDisabledOPs(config_.trt_disabled_ops_);
argument_.SetTensorRtUseDLA(config_.trt_use_dla_);
argument_.SetTensorRtDLACore(config_.trt_dla_core_);
argument_.SetTensorRtPrecisionMode(config_.tensorrt_precision_mode_);
argument_.SetTensorRtUseStaticEngine(config_.trt_use_static_engine_);
argument_.SetTensorRtUseCalibMode(config_.trt_use_calib_mode_);
argument_.SetTensorRtUseOSS(config_.trt_use_oss_);
argument_.SetTensorRtWithInterleaved(config_.trt_with_interleaved_);
argument_.SetMinInputShape(config_.min_input_shape_);
argument_.SetMaxInputShape(config_.max_input_shape_);
argument_.SetOptimInputShape(config_.optim_input_shape_);
argument_.SetCloseTrtPluginFp16(config_.disable_trt_plugin_fp16_);
argument_.SetTensorRtShapeRangeInfoPath(config_.shape_range_info_path());
argument_.SetTensorRtTunedDynamicShape(
config_.tuned_tensorrt_dynamic_shape());
argument_.SetTensorRtAllowBuildAtRuntime(
config_.trt_allow_build_at_runtime());
argument_.SetTensorRtUseInspector(config_.trt_use_inspector_);
......
......@@ -51,21 +51,11 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
auto pos_emb_name = op_desc.Input("PosEmbedding").front();
auto sent_emb_name = op_desc.Input("SentEmbedding").front();
std::vector<std::string> id_names;
std::vector<std::string> emb_names;
id_names =
std::vector<std::string>{word_id_name, pos_id_name, sent_id_name};
emb_names =
std::vector<std::string>{word_emb_name, pos_emb_name, sent_emb_name};
int input_num = id_names.size();
// Declare inputs
std::vector<nvinfer1::ITensor*> input_ids;
for (int i = 0; i < input_num; i++) {
input_ids.push_back(engine_->GetITensor(id_names[i]));
}
int input_num = emb_names.size();
// input_embs[0]: word_embedding
// input_embs[1]: pos_embedding
......@@ -126,7 +116,7 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
{"bert_embeddings_position_embeddings", input_embs[1],
nvinfer1::PluginFieldType::kFLOAT32,
static_cast<int32_t>(emb_sizes[1])},
{"output_int8", &output_int8, nvinfer1::PluginFieldType::kINT32, 1},
{"output_fp16", &output_int8, nvinfer1::PluginFieldType::kINT32, 1},
};
nvinfer1::PluginFieldCollection* plugin_ptr =
......@@ -156,7 +146,7 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
shuffle_layer->setReshapeDimensions(shape_dim);
shuffle_layer->setName(
("PrelnEmbeltwise_Shuffle_reshape (Output: max_seqlen " +
op_desc.Output("Out")[0] + ")")
op_desc.Output("Out_0")[0] + ")")
.c_str());
engine_->SetTensorDynamicRange(shuffle_layer->getOutput(0), 1.0f);
plugin_inputs.emplace_back(
......@@ -170,7 +160,7 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
auto plugin_layer = engine_->network()->addPluginV2(
plugin_inputs.data(), plugin_inputs.size(), *plugin_obj);
plugin_layer->setName(("CustomPrelnEmbLayerNormPluginDynamic_V3(Output: " +
op_desc.Output("Out")[0] + ")")
op_desc.Output("Out_0")[0] + ")")
.c_str());
free(plugin_ptr);
float out_0_scale =
......
......@@ -92,8 +92,10 @@ class PrelnSkipLayerNormOpConverter : public OpConverter {
"fail to add CustomPrelnSkipLayerNormPluginDynamic layer"));
layer = plugin_layer;
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "preln_skip_layernorm", {output_name},
std::vector<std::string> output_names;
output_names.push_back(op_desc.Output("Out_0")[0]);
output_names.push_back(op_desc.Output("Out_1")[0]);
RreplenishLayerAndOutput(layer, "preln_skip_layernorm", {output_names},
test_mode);
#else
PADDLE_THROW(platform::errors::Fatal(
......
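For context on the converter change above: after this patch the fused PreLN ops expose two outputs (Out_0 and Out_1), which is also why the pattern now asserts that elementwise_out has exactly two consumers. A single-row reference sketch of that two-output semantics, assuming the usual pre-LayerNorm residual form (function name, parameters and epsilon handling are illustrative, not the op's actual attributes):

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Reference semantics sketch (assumed, not taken from the commit):
//   Out_0 = X + Y                // residual sum, consumed by the next block
//   Out_1 = LayerNorm(X + Y)     // normalized value fed into attention/FFN
void PrelnSkipLayerNormRef(const std::vector<float>& x,
                           const std::vector<float>& y,
                           const std::vector<float>& scale,
                           const std::vector<float>& bias, float eps,
                           std::vector<float>* out0, std::vector<float>* out1) {
  const size_t n = x.size();
  out0->resize(n);
  out1->resize(n);
  float mean = 0.f;
  for (size_t i = 0; i < n; ++i) {
    (*out0)[i] = x[i] + y[i];  // Out_0: keep the raw residual sum
    mean += (*out0)[i];
  }
  mean /= n;
  float var = 0.f;
  for (size_t i = 0; i < n; ++i) {
    const float d = (*out0)[i] - mean;
    var += d * d;
  }
  var /= n;
  const float inv_std = 1.f / std::sqrt(var + eps);
  for (size_t i = 0; i < n; ++i) {  // Out_1: LayerNorm over the single row
    (*out1)[i] = scale[i] * ((*out0)[i] - mean) * inv_std + bias[i];
  }
}
```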