未验证 提交 fba46ea3 编写于 作者: P Pei Yang 提交者: GitHub

[Paddle-TRT] Fix AI-Rank BERT emb_eltwise_layernorm input order (#32482)

* fix AI-Rank BERT emb order

* move input num check to converter

* add input num check

* add unused var check white list
上级 2328921f
...@@ -290,10 +290,20 @@ static int BuildFusion(Graph* graph, const std::string& name_scope ...@@ -290,10 +290,20 @@ static int BuildFusion(Graph* graph, const std::string& name_scope
ids.push_back(inner_pattern_ins[js[iter]].first->Name()); ids.push_back(inner_pattern_ins[js[iter]].first->Name());
embs.push_back(inner_pattern_ins[js[iter]].second->Name()); embs.push_back(inner_pattern_ins[js[iter]].second->Name());
} }
OpDesc new_op_desc; OpDesc new_op_desc;
new_op_desc.SetType("fused_embedding_eltwise_layernorm"); new_op_desc.SetType("fused_embedding_eltwise_layernorm");
new_op_desc.SetInput("Ids", ids); new_op_desc.SetInput("Ids", ids);
new_op_desc.SetInput("Embs", embs); new_op_desc.SetInput("Embs", embs);
new_op_desc.SetInput("WordId", {ids[0]});
new_op_desc.SetInput("PosId", {ids[1]});
new_op_desc.SetInput("SentId", {ids[2]});
new_op_desc.SetInput("WordEmbedding", {embs[0]});
new_op_desc.SetInput("PosEmbedding", {embs[1]});
new_op_desc.SetInput("SentEmbedding", {embs[2]});
new_op_desc.SetInput("Bias", {end_pattern_biases[k]->Name()}); new_op_desc.SetInput("Bias", {end_pattern_biases[k]->Name()});
new_op_desc.SetInput("Scale", {end_pattern_scales[k]->Name()}); new_op_desc.SetInput("Scale", {end_pattern_scales[k]->Name()});
new_op_desc.SetOutput("Out", {end_pattern_out[k]->Name()}); new_op_desc.SetOutput("Out", {end_pattern_out[k]->Name()});
......
...@@ -37,7 +37,7 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc, ...@@ -37,7 +37,7 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
int place_id = section_config.place_id(); int place_id = section_config.place_id();
#if (defined PADDLE_WITH_NCCL) #if (defined PADDLE_WITH_NCCL)
place_ = platform::CUDAPlace(place_id); place_ = platform::CUDAPlace(place_id);
#elif (defined WITH_ASCEND_CL) #elif (defined WITH_ASCEND_CL) // NOLINT
place_ = platform::NPUPlace(place_id); place_ = platform::NPUPlace(place_id);
#endif #endif
worker_ = DeviceWorkerFactory::CreateDeviceWorker( worker_ = DeviceWorkerFactory::CreateDeviceWorker(
......
...@@ -53,27 +53,28 @@ static const std::unordered_set<std::string> &GetOpWithUnusedVarAllowSet() { ...@@ -53,27 +53,28 @@ static const std::unordered_set<std::string> &GetOpWithUnusedVarAllowSet() {
// Use pointer here for safe static deinitialization // Use pointer here for safe static deinitialization
static auto *allow_set = new std::unordered_set<std::string>({ static auto *allow_set = new std::unordered_set<std::string>({
// called once // called once
"batch_norm", // 0 "batch_norm", // 0
"batch_norm_grad", // 0 "batch_norm_grad", // 0
"sync_batch_norm", // 0 "sync_batch_norm", // 0
"sync_batch_norm_grad", // 0 "sync_batch_norm_grad", // 0
"inplace_abn", // 0 "inplace_abn", // 0
"inplace_abn_grad", // 0 "inplace_abn_grad", // 0
"dgc_momentum", // 0 "dgc_momentum", // 0
"fake_quantize_range_abs_max", // 0 "fake_quantize_range_abs_max", // 0
"rmsprop", // 0 "rmsprop", // 0
"sequence_conv_grad", // 0 "sequence_conv_grad", // 0
"roi_perspective_transform_grad", // 0 "roi_perspective_transform_grad", // 0
"fill_zeros_like", // 1 "fill_zeros_like", // 1
"fill_any_like", // 1 "fill_any_like", // 1
"nce_grad", // 1 "nce_grad", // 1
"precision_recall", // 1 "precision_recall", // 1
"fusion_seqpool_cvm_concat", // 2 "fusion_seqpool_cvm_concat", // 2
"fused_batch_norm_act", // 2 "fused_batch_norm_act", // 2
"fused_batch_norm_act_grad", // 2 "fused_batch_norm_act_grad", // 2
"data_norm", // 0 "data_norm", // 0
"data_norm_grad", // 0 "data_norm_grad", // 0
"update_loss_scaling", // 0 "update_loss_scaling", // 0
"fused_embedding_eltwise_layernorm", // 0
}); });
return *allow_set; return *allow_set;
} }
......
...@@ -34,8 +34,17 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { ...@@ -34,8 +34,17 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
VLOG(4) << "convert fluid EmbEltwiseLayerNorm op to tensorrt layer"; VLOG(4) << "convert fluid EmbEltwiseLayerNorm op to tensorrt layer";
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
auto id_names = op_desc.Input("Ids"); auto word_id_name = op_desc.Input("WordId").front();
auto emb_names = op_desc.Input("Embs"); auto pos_id_name = op_desc.Input("PosId").front();
auto sent_id_name = op_desc.Input("SentId").front();
auto word_emb_name = op_desc.Input("WordEmbedding").front();
auto pos_emb_name = op_desc.Input("PosEmbedding").front();
auto sent_emb_name = op_desc.Input("SentEmbedding").front();
std::vector<std::string> id_names = {word_id_name, pos_id_name,
sent_id_name};
std::vector<std::string> emb_names = {word_emb_name, pos_emb_name,
sent_emb_name};
int input_num = id_names.size(); int input_num = id_names.size();
// Declare inputs // Declare inputs
...@@ -91,6 +100,12 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { ...@@ -91,6 +100,12 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
if (enable_int8) { if (enable_int8) {
output_fp16 = 1; output_fp16 = 1;
} }
PADDLE_ENFORCE_EQ(
input_num, 3,
platform::errors::InvalidArgument(
"When using oss and var-len, embedding_eltwise_layernorm op"
"should have 3 inputs only, but got %d.",
input_num));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
output_fp16, 1, output_fp16, 1,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
...@@ -125,15 +140,15 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { ...@@ -125,15 +140,15 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
plugin_ptr->fields = fields.data(); plugin_ptr->fields = fields.data();
std::vector<nvinfer1::ITensor*> plugin_inputs; std::vector<nvinfer1::ITensor*> plugin_inputs;
plugin_inputs.emplace_back(engine_->GetITensor( plugin_inputs.emplace_back(
engine_->network()->getInput(0)->getName())); // word_embedding, engine_->GetITensor(word_id_name)); // word_embedding,
// eval_placeholder_0 // eval_placeholder_0
plugin_inputs.emplace_back(engine_->GetITensor( plugin_inputs.emplace_back(
engine_->network()->getInput(1)->getName())); // sent_embedding, engine_->GetITensor(sent_id_name)); // sent_embedding,
// eval_placeholder_1 // eval_placeholder_1
plugin_inputs.emplace_back(engine_->GetITensor( plugin_inputs.emplace_back(
engine_->network()->getInput(2)->getName())); // cu_seqlens, engine_->GetITensor(pos_id_name)); // cu_seqlens,
// eval_placeholder_2 // eval_placeholder_2
auto max_seqlen_tensor = auto max_seqlen_tensor =
engine_->GetITensor(engine_->network()->getInput(3)->getName()); engine_->GetITensor(engine_->network()->getInput(3)->getName());
auto* shuffle_layer = auto* shuffle_layer =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册