From 90805e2df7b6fcd0bf78e8fa10fcbe98ef74c936 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Mon, 16 Nov 2020 11:28:52 +0800 Subject: [PATCH] Register op_version for new attribute use_addto (#28463) * register op_version for addto * upgrade pass capability * change eq to le * change eq to le * fix merge --- .../ir/conv_affine_channel_fuse_pass.cc | 4 +- .../fluid/framework/ir/conv_bn_fuse_pass.cc | 4 +- .../ir/conv_elementwise_add2_act_fuse_pass.cc | 4 +- .../ir/conv_elementwise_add_act_fuse_pass.cc | 3 +- .../ir/conv_elementwise_add_fuse_pass.cc | 3 +- .../conv_activation_mkldnn_fuse_pass.cc | 10 +-- .../ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc | 4 +- .../conv_concat_relu_mkldnn_fuse_pass.cc | 4 +- .../conv_elementwise_add_mkldnn_fuse_pass.cc | 71 ++++++++++--------- .../ir/mkldnn/depthwise_conv_mkldnn_pass.cc | 4 +- .../ir/quant_conv2d_dequant_fuse_pass.cc | 5 +- .../ir_passes/tensorrt_subgraph_pass.cc | 8 ++- paddle/fluid/operators/conv_op.cc | 35 +++++++++ 13 files changed, 106 insertions(+), 53 deletions(-) diff --git a/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc b/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc index 9c984a23e37..c0ebf6de9de 100644 --- a/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc @@ -238,11 +238,11 @@ REGISTER_PASS(conv_eltwiseadd_affine_channel_fuse_pass, REGISTER_PASS_CAPABILITY(conv_affine_channel_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("affine_channel", 0)); REGISTER_PASS_CAPABILITY(conv_eltwiseadd_affine_channel_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("elementwise_add", 0) .EQ("affine_channel", 0)); diff --git a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc index a915015bf55..72ac7c3b0e8 100644 --- a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc @@ -383,11 +383,11 @@ REGISTER_PASS(depthwise_conv_eltwiseadd_bn_fuse_pass, REGISTER_PASS_CAPABILITY(conv_bn_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("batch_norm", 0)); REGISTER_PASS_CAPABILITY(conv_eltwiseadd_bn_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("elementwise_add", 0) .EQ("batch_norm", 0)); diff --git a/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc index ad6af69ae02..545beb34e78 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc @@ -12,7 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.h" + #include + #include "paddle/fluid/framework/op_version_registry.h" namespace paddle { @@ -119,7 +121,7 @@ REGISTER_PASS(conv_elementwise_add2_act_fuse_pass, REGISTER_PASS_CAPABILITY(conv_elementwise_add2_act_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("elementwise_add", 0) .EQ("relu", 0) .EQ("identity", 0)); diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc index 93e6e13ff70..d01a2f26223 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.h" + #include #include "paddle/fluid/framework/ir/graph_viz_pass.h" @@ -107,7 +108,7 @@ REGISTER_PASS(conv_elementwise_add_act_fuse_pass, REGISTER_PASS_CAPABILITY(conv_elementwise_add_act_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("elementwise_add", 0) .EQ("relu", 0) .EQ("identity", 0)); diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc index e4396f227f7..e34a2d96581 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.h" + #include #include "paddle/fluid/framework/ir/graph_viz_pass.h" @@ -93,5 +94,5 @@ REGISTER_PASS(conv_elementwise_add_fuse_pass, REGISTER_PASS_CAPABILITY(conv_elementwise_add_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("elementwise_add", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc index c33398553ec..d0bdeb9ad8c 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc @@ -13,7 +13,9 @@ // limitations under the License. #include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h" + #include + #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/platform/enforce.h" @@ -107,7 +109,7 @@ REGISTER_PASS(conv_relu_mkldnn_fuse_pass, REGISTER_PASS_CAPABILITY(conv_relu_mkldnn_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("relu", 0)); REGISTER_PASS(conv_leaky_relu_mkldnn_fuse_pass, @@ -115,7 +117,7 @@ REGISTER_PASS(conv_leaky_relu_mkldnn_fuse_pass, REGISTER_PASS_CAPABILITY(conv_leaky_relu_mkldnn_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .LE("leaky_relu", 1)); REGISTER_PASS(conv_relu6_mkldnn_fuse_pass, @@ -123,7 +125,7 @@ REGISTER_PASS(conv_relu6_mkldnn_fuse_pass, REGISTER_PASS_CAPABILITY(conv_relu6_mkldnn_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("relu6", 0)); REGISTER_PASS(conv_swish_mkldnn_fuse_pass, @@ -131,5 +133,5 @@ REGISTER_PASS(conv_swish_mkldnn_fuse_pass, REGISTER_PASS_CAPABILITY(conv_swish_mkldnn_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("swish", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc index 716c49dcb12..b0849d74b61 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc @@ -13,8 +13,10 @@ // limitations under the License. #include "paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h" + #include #include + #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/platform/enforce.h" @@ -150,7 +152,7 @@ REGISTER_PASS(conv_bias_mkldnn_fuse_pass, REGISTER_PASS_CAPABILITY(conv_bias_mkldnn_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("elementwise_add", 0)); REGISTER_PASS(conv_transpose_bias_mkldnn_fuse_pass, diff --git a/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc index 76e10212550..c4d7a120372 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc @@ -13,7 +13,9 @@ // limitations under the License. #include "paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.h" + #include + #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/platform/enforce.h" @@ -128,6 +130,6 @@ REGISTER_PASS(conv_concat_relu_mkldnn_fuse_pass, REGISTER_PASS_CAPABILITY(conv_concat_relu_mkldnn_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("concat", 0) .EQ("relu", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc index 2fb131aceaa..a837b42b3ea 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc @@ -13,11 +13,13 @@ // limitations under the License. #include "paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h" + #include #include #include #include #include + #include "paddle/fluid/framework/ir/graph_traits.h" #include "paddle/fluid/framework/op_version_registry.h" @@ -226,19 +228,20 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX( pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr())); conv_output->AsIntermediate(); - auto get_node_from_elementwise_add = [&elementwise_add_pattern]( - const GraphPatternDetector::subgraph_t& subgraph) + auto get_node_from_elementwise_add = + [&elementwise_add_pattern]( + const GraphPatternDetector::subgraph_t& subgraph) -> std::tuple { - GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, - elementwise_add_pattern); - GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_y, elementwise_add_y, - elementwise_add_pattern); - GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, - elementwise_add_pattern); - - return std::make_tuple(elementwise_add_op, elementwise_add_y, - elementwise_add_out); - }; + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, + elementwise_add_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_y, elementwise_add_y, + elementwise_add_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, + elementwise_add_pattern); + + return std::make_tuple(elementwise_add_op, elementwise_add_y, + elementwise_add_out); + }; return ExecuteHandleOnGraph( &gpd, graph_with_stats, @@ -263,19 +266,20 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY( conv_output); conv_output->AsIntermediate(); - auto get_node_from_elementwise_add = [&elementwise_add_pattern]( - const GraphPatternDetector::subgraph_t& subgraph) + auto get_node_from_elementwise_add = + [&elementwise_add_pattern]( + const GraphPatternDetector::subgraph_t& subgraph) -> std::tuple { - GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, - elementwise_add_pattern); - GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x, - elementwise_add_pattern); - GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, - elementwise_add_pattern); - - return std::make_tuple(elementwise_add_op, elementwise_add_x, - elementwise_add_out); - }; + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, + elementwise_add_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x, + elementwise_add_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, + elementwise_add_pattern); + + return std::make_tuple(elementwise_add_op, elementwise_add_x, + elementwise_add_out); + }; return ExecuteHandleOnGraph( &gpd, graph_with_stats, @@ -302,16 +306,17 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv( conv_x_output->AsIntermediate(); conv_y_output->AsIntermediate(); - auto get_node_from_elementwise_add = [&elementwise_add_pattern]( - const GraphPatternDetector::subgraph_t& subgraph) + auto get_node_from_elementwise_add = + [&elementwise_add_pattern]( + const GraphPatternDetector::subgraph_t& subgraph) -> std::tuple { - GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, - elementwise_add_pattern); - GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, - elementwise_add_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, + elementwise_add_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, + elementwise_add_pattern); - return std::make_tuple(elementwise_add_op, elementwise_add_out); - }; + return std::make_tuple(elementwise_add_op, elementwise_add_out); + }; return ExecuteHandleOnGraph( &gpd, graph_with_stats, @@ -345,5 +350,5 @@ REGISTER_PASS(conv_elementwise_add_mkldnn_fuse_pass, REGISTER_PASS_CAPABILITY(conv_elementwise_add_mkldnn_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("elementwise_add", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.cc b/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.cc index b2c0afdc754..39f47406a77 100644 --- a/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.cc @@ -63,5 +63,5 @@ REGISTER_PASS(depthwise_conv_mkldnn_pass, paddle::framework::ir::DepthwiseConvMKLDNNPass); REGISTER_PASS_CAPABILITY(depthwise_conv_mkldnn_pass) .AddCombination( - paddle::framework::compatible::OpVersionComparatorCombination().EQ( - "depthwise_conv2d", 0)); + paddle::framework::compatible::OpVersionComparatorCombination().LE( + "depthwise_conv2d", 1)); diff --git a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc index 895c396e1e6..96c5546d212 100644 --- a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc +++ b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc @@ -12,13 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.h" + #include #include #include #include #include "paddle/fluid/framework/ir/graph_viz_pass.h" -#include "paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.h" #include "paddle/fluid/framework/op_version_registry.h" namespace paddle { @@ -331,7 +332,7 @@ REGISTER_PASS_CAPABILITY(quant_conv2d_dequant_fuse_pass); REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("fc", 0) .LE("conv2d_transpose", 1) .EQ("fake_quantize_abs_max", 0) diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index 08f3d609fa3..bf0d87da91f 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h" + #include #include #include @@ -20,7 +22,6 @@ #include "paddle/fluid/framework/ir/subgraph_detector.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/inference/analysis/helper.h" -#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/engine.h" #include "paddle/fluid/inference/tensorrt/op_teller.h" @@ -309,6 +310,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp( min_input_shape, max_input_shape, opt_input_shape, disable_trt_plugin_fp16); trt_engine->SetUseOSS(Get("use_oss")); + trt_engine->SetWithErnie( graph->Has(framework::ir::kEmbEltwiseLayernormPass) && graph->Has(framework::ir::kMultiheadMatmulPass)); @@ -367,13 +369,13 @@ REGISTER_PASS(tensorrt_subgraph_pass, REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("pool2d", 0) .EQ("relu", 0) .EQ("softmax", 0) .EQ("sigmoid", 0) .EQ("hard_swish", 0) - .EQ("depthwise_conv2d", 0) + .LE("depthwise_conv2d", 1) .EQ("batch_norm", 0) .EQ("concat", 0) .EQ("tanh", 0) diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index ef8a2b38f20..76ff1084fa6 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -18,6 +18,8 @@ limitations under the License. */ #include #include +#include "paddle/fluid/framework/op_version_registry.h" + #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/operators/conv_cudnn_op_cache.h" #include "paddle/fluid/platform/cudnn_helper.h" @@ -817,3 +819,36 @@ REGISTER_OP_CPU_KERNEL( conv3d_grad_grad, ops::GemmConvDoubleGradKernel, ops::GemmConvDoubleGradKernel); + +REGISTER_OP_VERSION(conv2d) + .AddCheckpoint( + R"ROC( + Upgrade conv2d, add a new attribute [use_addto]. + )ROC", + paddle::framework::compatible::OpVersionDesc().NewAttr( + "use_addto", + "In order to support new feature (inplace addto strategy) for " + "gradient accumulation.", + false)); + +REGISTER_OP_VERSION(depthwise_conv2d) + .AddCheckpoint( + R"ROC( + Upgrade depthwise_conv2d, add a new attribute [use_addto]. + )ROC", + paddle::framework::compatible::OpVersionDesc().NewAttr( + "use_addto", + "In order to support new feature (inplace addto strategy) for " + "gradient accumulation.", + false)); + +REGISTER_OP_VERSION(conv3d) + .AddCheckpoint( + R"ROC( + Upgrade conv3d, add a new attribute [use_addto]. + )ROC", + paddle::framework::compatible::OpVersionDesc().NewAttr( + "use_addto", + "In order to support new feature (inplace addto strategy) for " + "gradient accumulation.", + false)); -- GitLab