未验证 提交 34c95eaf 编写于 作者: F feng_shuai 提交者: GitHub

batch_norm_act_fuse_pass_init (#33636)

* batch_norm_act_fuse_pass_init

* repair the unittest of batch_norm_act
上级 6da6ff6a
......@@ -29,6 +29,55 @@ void FuseBatchNormActOneDNNPass::ApplyImpl(Graph *graph) const {
FuseBatchNormAct(graph, act_type);
}
FuseBatchNormActOneDNNPass::FuseBatchNormActOneDNNPass() {
AddOpCompat(OpCompat("batch_norm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddInput("Mean")
.IsTensor()
.End()
.AddInput("Variance")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddOutput("MeanOut")
.IsOptional()
.End()
.AddOutput("VarianceOut")
.IsOptional()
.End()
.AddOutput("SavedMean")
.IsOptional()
.End()
.AddOutput("SavedVariance")
.IsOptional()
.End()
.AddOutput("ReserveSpace")
.IsOptional()
.End()
.AddAttr("epsilon")
.IsNumGE(0.0f)
.IsNumLE(0.001f)
.End();
AddOpCompat(OpCompat("relu"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
}
void FuseBatchNormActOneDNNPass::FuseBatchNormAct(
Graph *graph, const std::string &act_type) const {
PADDLE_ENFORCE_NOT_NULL(
......@@ -45,6 +94,11 @@ void FuseBatchNormActOneDNNPass::FuseBatchNormAct(
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
VLOG(4) << "Fuse BatchNorm with ReLU activation op.";
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
// BN output
GET_IR_NODE_FROM_SUBGRAPH(bn_out, bn_out, bn_act_pattern);
// ACT output
......@@ -84,6 +138,11 @@ void FuseBatchNormActOneDNNPass::FuseBatchNormAct(
bn_op->SetAttr("trainable_statistics", false);
bn_op->SetOutput("Y", {act_out->Name()});
if (!IsCompat(*bn_op)) {
LOG(WARNING) << "Fc fuse pass in out fc op compat failed.";
return;
}
IR_OP_VAR_LINK(batch_norm, act_out);
GraphSafeRemoveNodes(g, {act, bn_out});
found_bn_act_count++;
......
......@@ -31,6 +31,7 @@ namespace ir {
*/
class FuseBatchNormActOneDNNPass : public FusePassBase {
public:
FuseBatchNormActOneDNNPass();
virtual ~FuseBatchNormActOneDNNPass() {}
protected:
......
......@@ -32,6 +32,7 @@ void SetBatchNormAttrs(OpDesc* bn_op, bool is_test = true,
bn_op->SetAttr("is_test", is_test);
bn_op->SetAttr("trainable_statistics", trainable_stats);
bn_op->SetAttr("fuse_with_relu", false);
bn_op->SetAttr("epsilon", 0.001f);
}
}
......
......@@ -42,6 +42,10 @@ extra {
inputs {
name: "MomentumTensor"
}
attrs {
name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
type: BOOLEAN
}
attrs {
name: "is_test"
type: BOOLEAN
......
......@@ -8,6 +8,10 @@ def {
}
}
extra {
attrs {
name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
type: BOOLEAN
}
attrs {
name: "use_mkldnn"
type: BOOLEAN
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册