未验证 提交 4ddd595f 编写于 作者: W Wilber 提交者: GitHub

add compat check for skip_layernorm (#33505)

上级 34b79d94
......@@ -129,6 +129,11 @@ void SkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
return;
}
if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "skip_layernorm pass in op compat failed.";
return;
}
VLOG(4) << "handle SkipLayerNorm fuse";
GET_IR_NODE_FROM_SUBGRAPH(elementwise, elementwise, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern);
......
......@@ -33,6 +33,49 @@ class Graph;
class SkipLayerNormFusePass : public FusePassBase {
public:
SkipLayerNormFusePass() {
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({0, -1})
.End();
AddOpCompat(OpCompat("layer_norm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddOutput("Mean")
.IsTensor()
.End()
.AddOutput("Variance")
.IsTensor()
.End()
.AddAttr("epsilon")
.IsNumGE(0.0f)
.IsNumLE(0.001f)
.End()
.AddAttr("begin_norm_axis")
.IsNumGT(0)
.End();
}
virtual ~SkipLayerNormFusePass() {}
protected:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册