未验证 提交 9d8d5317 编写于 作者: F feng_shuai 提交者: GitHub

fc_elementwise_layer_fuse_pass (#33467)

* fc_elementwise_layer_fuse_pass

* fc_ele_layernorm_pass

* fc_elementwise_layernorm_pass

* fc_elementwise_layernorm_pass_amend
上级 5cca9e4c
...@@ -136,6 +136,70 @@ static bool IsEqual(const std::vector<T> &x, const std::vector<T> &y) { ...@@ -136,6 +136,70 @@ static bool IsEqual(const std::vector<T> &x, const std::vector<T> &y) {
return true; return true;
} }
FCElementwiseLayerNormFusePass::FCElementwiseLayerNormFusePass() {
AddOpCompat(OpCompat("fc"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("W")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("in_num_col_dims")
.IsNumGE(1)
.End()
.AddAttr("activation_type")
.IsStringIn({"relu", ""})
.End();
AddOpCompat(OpCompat("layer_norm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddOutput("Mean")
.IsOptional()
.End()
.AddOutput("Variance")
.IsOptional()
.End()
.AddAttr("epsilon")
.IsNumGE(0.0f)
.IsNumLE(0.001f)
.End()
.AddAttr("begin_norm_axis")
.IsNumGT(0)
.End();
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumEQ(-1)
.End();
}
void FCElementwiseLayerNormFusePass::ApplyImpl(ir::Graph *graph) const { void FCElementwiseLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL(graph, PADDLE_ENFORCE_NOT_NULL(graph,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
...@@ -159,6 +223,11 @@ void FCElementwiseLayerNormFusePass::ApplyImpl(ir::Graph *graph) const { ...@@ -159,6 +223,11 @@ void FCElementwiseLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
return; return;
} }
if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
VLOG(4) << "handle FCElementwiseLayerNorm fuse"; VLOG(4) << "handle FCElementwiseLayerNorm fuse";
GET_IR_NODE_FROM_SUBGRAPH(fc, fc, fused_pattern); GET_IR_NODE_FROM_SUBGRAPH(fc, fc, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fc_w, fc_w, fused_pattern); GET_IR_NODE_FROM_SUBGRAPH(fc_w, fc_w, fused_pattern);
......
...@@ -24,6 +24,7 @@ class Graph; ...@@ -24,6 +24,7 @@ class Graph;
class FCElementwiseLayerNormFusePass : public FusePassBase { class FCElementwiseLayerNormFusePass : public FusePassBase {
public: public:
FCElementwiseLayerNormFusePass();
virtual ~FCElementwiseLayerNormFusePass() {} virtual ~FCElementwiseLayerNormFusePass() {}
protected: protected:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册