“8a76baf02006c945fa4a2a01a58848cb38777697”上不存在“arch/powerpc/boot/Makefile”
未验证 提交 34c95eaf 编写于 作者: F feng_shuai 提交者: GitHub

batch_norm_act_fuse_pass_init (#33636)

* batch_norm_act_fuse_pass_init

* repair the unittest of batch_norm_act
上级 6da6ff6a
...@@ -29,6 +29,55 @@ void FuseBatchNormActOneDNNPass::ApplyImpl(Graph *graph) const { ...@@ -29,6 +29,55 @@ void FuseBatchNormActOneDNNPass::ApplyImpl(Graph *graph) const {
FuseBatchNormAct(graph, act_type); FuseBatchNormAct(graph, act_type);
} }
FuseBatchNormActOneDNNPass::FuseBatchNormActOneDNNPass() {
AddOpCompat(OpCompat("batch_norm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddInput("Mean")
.IsTensor()
.End()
.AddInput("Variance")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddOutput("MeanOut")
.IsOptional()
.End()
.AddOutput("VarianceOut")
.IsOptional()
.End()
.AddOutput("SavedMean")
.IsOptional()
.End()
.AddOutput("SavedVariance")
.IsOptional()
.End()
.AddOutput("ReserveSpace")
.IsOptional()
.End()
.AddAttr("epsilon")
.IsNumGE(0.0f)
.IsNumLE(0.001f)
.End();
AddOpCompat(OpCompat("relu"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
}
void FuseBatchNormActOneDNNPass::FuseBatchNormAct( void FuseBatchNormActOneDNNPass::FuseBatchNormAct(
Graph *graph, const std::string &act_type) const { Graph *graph, const std::string &act_type) const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
...@@ -45,6 +94,11 @@ void FuseBatchNormActOneDNNPass::FuseBatchNormAct( ...@@ -45,6 +94,11 @@ void FuseBatchNormActOneDNNPass::FuseBatchNormAct(
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) { Graph *g) {
VLOG(4) << "Fuse BatchNorm with ReLU activation op."; VLOG(4) << "Fuse BatchNorm with ReLU activation op.";
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
// BN output // BN output
GET_IR_NODE_FROM_SUBGRAPH(bn_out, bn_out, bn_act_pattern); GET_IR_NODE_FROM_SUBGRAPH(bn_out, bn_out, bn_act_pattern);
// ACT output // ACT output
...@@ -84,6 +138,11 @@ void FuseBatchNormActOneDNNPass::FuseBatchNormAct( ...@@ -84,6 +138,11 @@ void FuseBatchNormActOneDNNPass::FuseBatchNormAct(
bn_op->SetAttr("trainable_statistics", false); bn_op->SetAttr("trainable_statistics", false);
bn_op->SetOutput("Y", {act_out->Name()}); bn_op->SetOutput("Y", {act_out->Name()});
if (!IsCompat(*bn_op)) {
LOG(WARNING) << "Fc fuse pass in out fc op compat failed.";
return;
}
IR_OP_VAR_LINK(batch_norm, act_out); IR_OP_VAR_LINK(batch_norm, act_out);
GraphSafeRemoveNodes(g, {act, bn_out}); GraphSafeRemoveNodes(g, {act, bn_out});
found_bn_act_count++; found_bn_act_count++;
......
...@@ -31,6 +31,7 @@ namespace ir { ...@@ -31,6 +31,7 @@ namespace ir {
*/ */
class FuseBatchNormActOneDNNPass : public FusePassBase { class FuseBatchNormActOneDNNPass : public FusePassBase {
public: public:
FuseBatchNormActOneDNNPass();
virtual ~FuseBatchNormActOneDNNPass() {} virtual ~FuseBatchNormActOneDNNPass() {}
protected: protected:
......
...@@ -32,6 +32,7 @@ void SetBatchNormAttrs(OpDesc* bn_op, bool is_test = true, ...@@ -32,6 +32,7 @@ void SetBatchNormAttrs(OpDesc* bn_op, bool is_test = true,
bn_op->SetAttr("is_test", is_test); bn_op->SetAttr("is_test", is_test);
bn_op->SetAttr("trainable_statistics", trainable_stats); bn_op->SetAttr("trainable_statistics", trainable_stats);
bn_op->SetAttr("fuse_with_relu", false); bn_op->SetAttr("fuse_with_relu", false);
bn_op->SetAttr("epsilon", 0.001f);
} }
} }
......
...@@ -42,6 +42,10 @@ extra { ...@@ -42,6 +42,10 @@ extra {
inputs { inputs {
name: "MomentumTensor" name: "MomentumTensor"
} }
attrs {
name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
type: BOOLEAN
}
attrs { attrs {
name: "is_test" name: "is_test"
type: BOOLEAN type: BOOLEAN
......
...@@ -8,6 +8,10 @@ def { ...@@ -8,6 +8,10 @@ def {
} }
} }
extra { extra {
attrs {
name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
type: BOOLEAN
}
attrs { attrs {
name: "use_mkldnn" name: "use_mkldnn"
type: BOOLEAN type: BOOLEAN
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册