From 930ca3f4950187131aa010f3ae88006a0fbf91ad Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Fri, 18 Jun 2021 21:26:17 +0800 Subject: [PATCH] pass enhance (#33661) --- .../fluid/framework/ir/conv_bn_fuse_pass.cc | 259 +++++++++++++++++- paddle/fluid/framework/ir/conv_bn_fuse_pass.h | 6 + .../fluid/framework/ir/pass_tester_helper.h | 29 +- .../fluid/operators/compat/batch_norm.pbtxt | 4 + paddle/fluid/operators/compat/conv2d.pbtxt | 8 + paddle/fluid/operators/compat/relu.pbtxt | 8 + 6 files changed, 307 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc index 9cc44c941ec..03a78ec3a21 100644 --- a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc @@ -140,6 +140,91 @@ void recompute_bias_and_weights(const Scope* scope, } } +ConvBNFusePass::ConvBNFusePass() { + AddOpCompat(OpCompat("conv2d")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("Filter") + .IsTensor() + .End() + .AddInput("Bias") + .IsOptional() + .End() + .AddInput("ResidualData") + .IsOptional() + .End() + .AddOutput("Output") + .IsTensor() + .End() + .AddAttr("strides") + .End() + .AddAttr("paddings") + .End() + .AddAttr("padding_algorithm") + .IsOptional() + .IsStringIn({"EXPLICIT", "SAME", "VALID"}) + .End() + .AddAttr("groups") + .IsNumGE(1) + .End() + .AddAttr("dilations") + .End() + .AddAttr("data_format") + .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) + .End(); + + AddOpCompat(OpCompat("batch_norm")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Scale") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddInput("Mean") + .IsTensor() + .End() + .AddInput("Variance") + .IsTensor() + .End() + .AddOutput("MeanOut") + .IsTensor() + .End() + .AddOutput("VarianceOut") + .IsTensor() + .End() + .AddOutput("SavedMean") + .IsTensor() + .End() + .AddOutput("SavedVariance") + .IsTensor() + .End() + .AddOutput("Y") + .IsTensor() + .End() + .AddAttr("epsilon") + .IsNumLE(0.001f) + .IsNumGE(0.0f) + .End(); + + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsNumEQ(1) + .End(); +} + void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); @@ -161,8 +246,11 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const { int found_conv_bn_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "Pass in op compat failed."; + return; + } VLOG(4) << "handle " + conv_type() + "BN fuse"; - // conv, batch_norm, // conv_weight, conv_out, // bn_scale, bn_bias, bn_mean, bn_variance, @@ -236,6 +324,10 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const { } conv->Op()->SetOutput("Output", std::vector<std::string>({bn_out->Name()})); + if (!IsCompat(*conv->Op())) { + LOG(WARNING) << "conv_bn fuse pass in out conv op compat failed."; + return; + } GraphSafeRemoveNodes( graph, {conv_out, bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, @@ -251,6 +343,11 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const { desc.SetOutput("Out", std::vector<std::string>({bn_out->Name()})); desc.SetType("elementwise_add"); desc.SetAttr("axis", 1); + if (!IsCompat(desc)) { + LOG(WARNING) + << "conv_bn fuse pass in out elementwise_add op compat failed."; + return; + } auto eltwise_op = g->CreateOpNode(&desc); // OpDesc will be copied. GraphSafeRemoveNodes(graph, {bn_scale, bn_bias, bn_mean, bn_variance, @@ -269,6 +366,91 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const { AddStatis(found_conv_bn_count); } +ConvEltwiseAddBNFusePass::ConvEltwiseAddBNFusePass() { + AddOpCompat(OpCompat("conv2d")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("Filter") + .IsTensor() + .End() + .AddInput("Bias") + .IsOptional() + .End() + .AddInput("ResidualData") + .IsOptional() + .End() + .AddOutput("Output") + .IsTensor() + .End() + .AddAttr("strides") + .End() + .AddAttr("paddings") + .End() + .AddAttr("padding_algorithm") + .IsStringIn({"EXPLICIT", "SAME", "VALID"}) + .IsOptional() + .End() + .AddAttr("groups") + .IsNumGE(1) + .End() + .AddAttr("dilations") + .End() + .AddAttr("data_format") + .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) + .End(); + + AddOpCompat(OpCompat("batch_norm")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Scale") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddInput("Mean") + .IsTensor() + .End() + .AddInput("Variance") + .IsTensor() + .End() + .AddOutput("MeanOut") + .IsTensor() + .End() + .AddOutput("VarianceOut") + .IsTensor() + .End() + .AddOutput("SavedMean") + .IsTensor() + .End() + .AddOutput("SavedVariance") + .IsTensor() + .End() + .AddOutput("Y") + .IsTensor() + .End() + .AddAttr("epsilon") + .IsNumLE(0.001f) + .IsNumGE(0.0f) + .End(); + + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsNumEQ(1) + .End(); +} + void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); @@ -290,8 +472,11 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const { int found_conv_bn_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "Pass in op compat failed."; + return; + } VLOG(4) << "handle " + conv_type() + "BN fuse"; - // conv, batch_norm, // conv_weight, conv_out, // bn_scale, bn_bias, bn_mean, bn_variance, @@ -361,7 +546,11 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const { // Update the elementwise_add node eltwise->Op()->SetAttr("axis", 1); eltwise->Op()->SetOutput("Out", std::vector<std::string>({bn_out->Name()})); - + if (!IsCompat(*eltwise->Op())) { + LOG(WARNING) + << "conv_eltwise_bn fuse pass in out eltwise op compat failed."; + return; + } GraphSafeRemoveNodes( graph, {bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out, @@ -377,6 +566,70 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const { AddStatis(found_conv_bn_count); } +ConvTransposeBNFusePass::ConvTransposeBNFusePass() { + AddOpCompat(OpCompat("conv2d_transpose")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("Filter") + .IsTensor() + .End() + .AddInput("Bias") + .IsOptional() + .End() + .AddOutput("Output") + .IsTensor() + .End() + .AddAttr("strides") + .End() + .AddAttr("paddings") + .End() + .AddAttr("padding_algorithm") + .IsStringIn({"EXPLICIT", "SAME", "VALID"}) + .IsOptional() + .End() + .AddAttr("groups") + .IsNumGE(1) + .End() + .AddAttr("dilations") + .End() + .AddAttr("data_format") + .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) + .End(); +} + +ConvTransposeEltwiseAddBNFusePass::ConvTransposeEltwiseAddBNFusePass() { + AddOpCompat(OpCompat("conv2d_transpose")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("Filter") + .IsTensor() + .End() + .AddInput("Bias") + .IsOptional() + .End() + .AddOutput("Output") + .IsTensor() + .End() + .AddAttr("strides") + .End() + .AddAttr("paddings") + .End() + .AddAttr("padding_algorithm") + .IsStringIn({"EXPLICIT", "SAME", "VALID"}) + .IsOptional() + .End() + .AddAttr("groups") + .IsNumGE(1) + .End() + .AddAttr("dilations") + .End() + .AddAttr("data_format") + .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) + .End(); +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/conv_bn_fuse_pass.h b/paddle/fluid/framework/ir/conv_bn_fuse_pass.h index 342cd8dad5f..c78dfc2a487 100644 --- a/paddle/fluid/framework/ir/conv_bn_fuse_pass.h +++ b/paddle/fluid/framework/ir/conv_bn_fuse_pass.h @@ -31,6 +31,7 @@ class Graph; class ConvBNFusePass : public FusePassBase { public: + ConvBNFusePass(); virtual ~ConvBNFusePass() {} virtual std::string conv_type() const { return "conv2d"; } @@ -41,6 +42,7 @@ class ConvBNFusePass : public FusePassBase { class ConvEltwiseAddBNFusePass : public FusePassBase { public: + ConvEltwiseAddBNFusePass(); virtual ~ConvEltwiseAddBNFusePass() {} virtual std::string conv_type() const { return "conv2d"; } @@ -51,11 +53,15 @@ class ConvEltwiseAddBNFusePass : public FusePassBase { class ConvTransposeBNFusePass : public ConvBNFusePass { public: + ConvTransposeBNFusePass(); + virtual ~ConvTransposeBNFusePass() {} std::string conv_type() const { return "conv2d_transpose"; } }; class ConvTransposeEltwiseAddBNFusePass : public ConvEltwiseAddBNFusePass { public: + ConvTransposeEltwiseAddBNFusePass(); + virtual ~ConvTransposeEltwiseAddBNFusePass() {} std::string conv_type() const { return "conv2d_transpose"; } }; diff --git a/paddle/fluid/framework/ir/pass_tester_helper.h b/paddle/fluid/framework/ir/pass_tester_helper.h index 4b6068d4776..f5639e7bc9a 100644 --- a/paddle/fluid/framework/ir/pass_tester_helper.h +++ b/paddle/fluid/framework/ir/pass_tester_helper.h @@ -39,28 +39,49 @@ struct Layers { } VarDesc* conv2d(VarDesc* input, VarDesc* filter, VarDesc* bias, - bool use_cudnn = false) { + int groups = 1, std::vector<int> strides = {1, 1}, + std::vector<int> paddings = {0, 0}, + std::string padding_algorithm = "EXPLICIT", + std::vector<int> dilations = {1, 1}, + std::string data_format = "NCHW", bool use_cudnn = false) { VarDesc* out = lod_tensor(unique_name()); OpDesc* op = program_.MutableBlock(0)->AppendOp(); op->SetType("conv2d"); op->SetInput("Input", {input->Name()}); op->SetInput("Filter", {filter->Name()}); op->SetInput("Bias", {bias->Name()}); - op->SetOutput("Out", {out->Name()}); + op->SetOutput("Output", {out->Name()}); op->SetAttr("use_cudnn", use_cudnn); + op->SetAttr("groups", groups); + op->SetAttr("strides", strides); + op->SetAttr("paddings", paddings); + op->SetAttr("padding_algorithm", padding_algorithm); + op->SetAttr("dilations", dilations); + op->SetAttr("data_format", data_format); op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), static_cast<int>(OpRole::kForward)); return out; } - VarDesc* conv2d_transpose(VarDesc* input, VarDesc* filter, VarDesc* bias) { + VarDesc* conv2d_transpose(VarDesc* input, VarDesc* filter, VarDesc* bias, + int groups = 1, std::vector<int> strides = {1, 1}, + std::vector<int> paddings = {0, 0}, + std::string padding_algorithm = "EXPLICIT", + std::vector<int> dilations = {1, 1}, + std::string data_format = "NCHW") { VarDesc* out = lod_tensor(unique_name()); OpDesc* op = program_.MutableBlock(0)->AppendOp(); op->SetType("conv2d_transpose"); op->SetInput("Input", {input->Name()}); op->SetInput("Filter", {filter->Name()}); op->SetInput("Bias", {bias->Name()}); - op->SetOutput("Out", {out->Name()}); + op->SetOutput("Output", {out->Name()}); + op->SetAttr("groups", groups); + op->SetAttr("strides", strides); + op->SetAttr("paddings", paddings); + op->SetAttr("padding_algorithm", padding_algorithm); + op->SetAttr("dilations", dilations); + op->SetAttr("data_format", data_format); op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), static_cast<int>(OpRole::kForward)); return out; diff --git a/paddle/fluid/operators/compat/batch_norm.pbtxt b/paddle/fluid/operators/compat/batch_norm.pbtxt index ac2ccc6296c..ed6162fb91c 100644 --- a/paddle/fluid/operators/compat/batch_norm.pbtxt +++ b/paddle/fluid/operators/compat/batch_norm.pbtxt @@ -42,6 +42,10 @@ extra { inputs { name: "MomentumTensor" } + attrs { + name: "@ENABLE_CACHE_RUNTIME_CONTEXT@" + type: BOOLEAN + } attrs { name: "@ENABLE_CACHE_RUNTIME_CONTEXT@" type: BOOLEAN diff --git a/paddle/fluid/operators/compat/conv2d.pbtxt b/paddle/fluid/operators/compat/conv2d.pbtxt index ae4381bbc43..d8a08b6b410 100644 --- a/paddle/fluid/operators/compat/conv2d.pbtxt +++ b/paddle/fluid/operators/compat/conv2d.pbtxt @@ -41,6 +41,14 @@ def { } } extra { + attrs { + name: "@ENABLE_CACHE_RUNTIME_CONTEXT@" + type: BOOLEAN + } + attrs { + name: "skip_quant" + type: BOOLEAN + } attrs { name: "is_test" type: BOOLEAN diff --git a/paddle/fluid/operators/compat/relu.pbtxt b/paddle/fluid/operators/compat/relu.pbtxt index bd0e9988010..271ed91718c 100644 --- a/paddle/fluid/operators/compat/relu.pbtxt +++ b/paddle/fluid/operators/compat/relu.pbtxt @@ -12,6 +12,14 @@ extra { name: "@ENABLE_CACHE_RUNTIME_CONTEXT@" type: BOOLEAN } + attrs { + name: "out_threshold" + type: FLOAT + } + attrs { + name: "Out0_threshold" + type: FLOAT + } attrs { name: "use_mkldnn" type: BOOLEAN -- GitLab