Commit 22511f30 authored by Wangzheee

support preln_ernie

Parent 3b0eaee9
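This commit wires the pre-LayerNorm (preln) Ernie fusion into TensorRT engine detection: it adds a kPrelnEmbEltwiseLayernormPass flag alongside the existing pass flags, fixes a wrong source node in PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion, repairs the PrelnSkipLayerNorm pattern definition (return type and a broken assertion chain), and extends the SetWithErnie check so a preln embedding pass also counts as an Ernie graph.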
...
@@ -47,6 +47,8 @@ constexpr char kPassRecorder[] = "pass_recorder";
 constexpr char kEmbEltwiseLayernormPass[] =
     "embedding_eltwise_layernorm_fuse_pass_flag";
 constexpr char kMultiheadMatmulPass[] = "multihead_matmul_fuse_pass_flag";
+constexpr char kPrelnEmbEltwiseLayernormPass[] =
+    "preln_embedding_eltwise_layernorm_fuse_pass_flag";
 
 class Pass {
  public:
...
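The new constant follows the existing pass-flag naming: a fuse pass records its flag on the graph, and the TensorRT subgraph pass (last hunk below) calls graph->Has(...) on these flags to detect which fusions have run.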
...
@@ -364,7 +364,7 @@ int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(
     IR_NODE_LINK_TO(end_pattern_biases[k], preln_embedding_eltwise_layernorm);
     IR_NODE_LINK_TO(end_pattern_scales[k], preln_embedding_eltwise_layernorm);
     IR_NODE_LINK_TO(preln_embedding_eltwise_layernorm, end_pattern_out[k]);
-    IR_NODE_LINK_TO(embedding_eltwise_layernorm, inner_pattern_out[k]);
+    IR_NODE_LINK_TO(preln_embedding_eltwise_layernorm, inner_pattern_out[k]);
   // Remove unneeded nodes.
   std::unordered_set<const Node*> marked_nodes;
...
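The one-line change fixes a copy-paste slip: the inner-pattern outputs were linked from embedding_eltwise_layernorm (the non-preln fusion's node name) rather than from the preln_embedding_eltwise_layernorm node this pass actually creates. A simplified sketch, assuming a hypothetical Node type rather than Paddle's actual ir::Node, of what such a directed link amounts to:

    // Sketch only: Paddle's IR_NODE_LINK_TO macro maintains both
    // adjacency lists in a similar way, so linking the wrong source
    // node corrupts the graph topology.
    #include <string>
    #include <vector>

    struct Node {
      std::string name;
      std::vector<Node *> inputs, outputs;
    };

    void LinkTo(Node *from, Node *to) {
      from->outputs.push_back(to);  // edge leaves `from`...
      to->inputs.push_back(from);   // ...and enters `to`
    }

    int main() {
      Node fused{"preln_embedding_eltwise_layernorm"};
      Node out{"inner_pattern_out[k]"};
      LinkTo(&fused, &out);  // the fix: link from the fused preln node
      return 0;
    }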
...
@@ -53,7 +53,7 @@ struct PrelnSkipLayerNorm : public PatternBase {
   PATTERN_DECL_NODE(layer_norm_variance);
 };
 
-void *PrelnSkipLayerNorm::operator()(PDNode *x, PDNode *y) {
+void PrelnSkipLayerNorm::operator()(PDNode *x, PDNode *y) {
   // Create nodes for elementwise add op.
   x->assert_is_op_input("elementwise_add", "X");
   y->assert_is_op_input("elementwise_add", "Y");
...
@@ -61,8 +61,8 @@ void *PrelnSkipLayerNorm::operator()(PDNode *x, PDNode *y) {
       pattern->NewNode(elementwise_repr())->assert_is_op("elementwise_add");
   auto *elementwise_out_var = pattern->NewNode(elementwise_out_repr())
                                   ->assert_is_op_output("elementwise_add")
-                                  ->assert_is_op_input("layer_norm", "X");
-                                  ->assert_is_op_input("elementwise_add", "Y");
+                                  ->assert_is_op_input("layer_norm", "X")
+                                  ->assert_is_op_input("elementwise_add", "Y");
 
   // Add links for elementwise_add op.
   elementwise->LinksFrom({x, y}).LinksTo({elementwise_out_var});
...
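Two fixes land in this file: the operator() return type changes from void * to void (the pattern builder is called for its side effects and never returned a value), and the stray semicolon after the layer_norm assertion is removed so the elementwise_add assertion stays on the same call chain. A minimal sketch, assuming a hypothetical Builder type rather than the real PDNode API, of why that semicolon matters:

    #include <iostream>

    struct Builder {
      Builder *tag(const char *name) {
        std::cout << name << '\n';
        return this;  // returning `this` is what makes ->...->... chains work
      }
    };

    int main() {
      Builder b;
      // Fixed form: one statement, both assertions apply to the same node.
      b.tag("op_output")->tag("layer_norm_X")->tag("elementwise_add_Y");
      // Broken form: a ';' mid-chain ends the statement, so the dangling
      // "->tag(...)" on the next line is a compile error:
      //   b.tag("op_output")->tag("layer_norm_X");
      //       ->tag("elementwise_add_Y");  // error: expected expression
      return 0;
    }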
...
@@ -377,8 +377,10 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
   trt_engine->SetDLACore(Get<int>("trt_dla_core"));
   trt_engine->SetWithErnie(
-      graph->Has(framework::ir::kEmbEltwiseLayernormPass) &&
-      graph->Has(framework::ir::kMultiheadMatmulPass));
+      (graph->Has(framework::ir::kEmbEltwiseLayernormPass) &&
+       graph->Has(framework::ir::kMultiheadMatmulPass)) ||
+      (graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass) &&
+       graph->Has(framework::ir::kMultiheadMatmulPass)));
 
   if (use_static_engine) {
     trt_engine_serialized_data = GetTrtEngineSerializedData(
...
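Because graph->Has(framework::ir::kMultiheadMatmulPass) appears in both branches, the new condition is logically equivalent to the factored form graph->Has(kMultiheadMatmulPass) && (graph->Has(kEmbEltwiseLayernormPass) || graph->Has(kPrelnEmbEltwiseLayernormPass)). A standalone check of that equivalence over all eight flag combinations (illustrative only, not part of the commit):

    #include <cassert>

    int main() {
      // A = emb-eltwise-layernorm flag, P = preln variant, M = multihead-matmul
      for (int bits = 0; bits < 8; ++bits) {
        const bool A = bits & 1, P = bits & 2, M = bits & 4;
        assert(((A && M) || (P && M)) == (M && (A || P)));
      }
      return 0;
    }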