提交 22511f30 编写于 作者: W Wangzheee

support preln_ernie

上级 3b0eaee9
......@@ -47,6 +47,8 @@ constexpr char kPassRecorder[] = "pass_recorder";
constexpr char kEmbEltwiseLayernormPass[] =
"embedding_eltwise_layernorm_fuse_pass_flag";
constexpr char kMultiheadMatmulPass[] = "multihead_matmul_fuse_pass_flag";
constexpr char kPrelnEmbEltwiseLayernormPass[] =
"preln_embedding_eltwise_layernorm_fuse_pass_flag";
class Pass {
public:
......
......@@ -364,7 +364,7 @@ int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(
IR_NODE_LINK_TO(end_pattern_biases[k], preln_embedding_eltwise_layernorm);
IR_NODE_LINK_TO(end_pattern_scales[k], preln_embedding_eltwise_layernorm);
IR_NODE_LINK_TO(preln_embedding_eltwise_layernorm, end_pattern_out[k]);
IR_NODE_LINK_TO(embedding_eltwise_layernorm, inner_pattern_out[k]);
IR_NODE_LINK_TO(preln_embedding_eltwise_layernorm, inner_pattern_out[k]);
// Remove unneeded nodes.
std::unordered_set<const Node*> marked_nodes;
......
......@@ -53,7 +53,7 @@ struct PrelnSkipLayerNorm : public PatternBase {
PATTERN_DECL_NODE(layer_norm_variance);
};
void *PrelnSkipLayerNorm::operator()(PDNode *x, PDNode *y) {
void PrelnSkipLayerNorm::operator()(PDNode *x, PDNode *y) {
// Create nodes for elementwise add op.
x->assert_is_op_input("elementwise_add", "X");
y->assert_is_op_input("elementwise_add", "Y");
......@@ -61,8 +61,8 @@ void *PrelnSkipLayerNorm::operator()(PDNode *x, PDNode *y) {
pattern->NewNode(elementwise_repr())->assert_is_op("elementwise_add");
auto *elementwise_out_var = pattern->NewNode(elementwise_out_repr())
->assert_is_op_output("elementwise_add")
->assert_is_op_input("layer_norm", "X");
->assert_is_op_input("elementwise_add", "Y");
->assert_is_op_input("layer_norm", "X")
->assert_is_op_input("elementwise_add", "Y");
// Add links for elementwise_add op.
elementwise->LinksFrom({x, y}).LinksTo({elementwise_out_var});
......
......@@ -377,8 +377,10 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
trt_engine->SetDLACore(Get<int>("trt_dla_core"));
trt_engine->SetWithErnie(
graph->Has(framework::ir::kEmbEltwiseLayernormPass) &&
graph->Has(framework::ir::kMultiheadMatmulPass));
(graph->Has(framework::ir::kEmbEltwiseLayernormPass) &&
graph->Has(framework::ir::kMultiheadMatmulPass)) ||
(graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass) &&
graph->Has(framework::ir::kMultiheadMatmulPass)));
if (use_static_engine) {
trt_engine_serialized_data = GetTrtEngineSerializedData(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册