diff --git a/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc b/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc index 9c984a23e377d749947a61793838956079a3678b..c0ebf6de9de23bf7074bcdb5e6f669a059b4d720 100644 --- a/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc @@ -238,11 +238,11 @@ REGISTER_PASS(conv_eltwiseadd_affine_channel_fuse_pass, REGISTER_PASS_CAPABILITY(conv_affine_channel_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("affine_channel", 0)); REGISTER_PASS_CAPABILITY(conv_eltwiseadd_affine_channel_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("elementwise_add", 0) .EQ("affine_channel", 0)); diff --git a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc index a915015bf55bd8a93fcc8311abc871e11cb9402d..72ac7c3b0e8ab8ff192352ebbcd80cde80f35825 100644 --- a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc @@ -383,11 +383,11 @@ REGISTER_PASS(depthwise_conv_eltwiseadd_bn_fuse_pass, REGISTER_PASS_CAPABILITY(conv_bn_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("batch_norm", 0)); REGISTER_PASS_CAPABILITY(conv_eltwiseadd_bn_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("elementwise_add", 0) .EQ("batch_norm", 0)); diff --git a/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc index ad6af69ae02e4f6262ee8760dbda90e0b5833feb..545beb34e78df521b6469f952063f83c5ee52e33 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc @@ -12,7 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.h" + #include + #include "paddle/fluid/framework/op_version_registry.h" namespace paddle { @@ -119,7 +121,7 @@ REGISTER_PASS(conv_elementwise_add2_act_fuse_pass, REGISTER_PASS_CAPABILITY(conv_elementwise_add2_act_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("elementwise_add", 0) .EQ("relu", 0) .EQ("identity", 0)); diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc index 93e6e13ff7092c80958d5defb11e2e456298c7b7..d01a2f2622347c37d889ed19ad78e5afbd60c007 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.h" + #include #include "paddle/fluid/framework/ir/graph_viz_pass.h" @@ -107,7 +108,7 @@ REGISTER_PASS(conv_elementwise_add_act_fuse_pass, REGISTER_PASS_CAPABILITY(conv_elementwise_add_act_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("elementwise_add", 0) .EQ("relu", 0) .EQ("identity", 0)); diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc index e4396f227f7f5280cfd3057aebfed8d02480d154..e34a2d96581531001678de3dd4e326f70d8e035c 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.h" + #include #include "paddle/fluid/framework/ir/graph_viz_pass.h" @@ -93,5 +94,5 @@ REGISTER_PASS(conv_elementwise_add_fuse_pass, REGISTER_PASS_CAPABILITY(conv_elementwise_add_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("elementwise_add", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc index c33398553ecd2cbe291e9cc605aa23ce318e9efe..d0bdeb9ad8c46004746215d4df43f6359cb69a26 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc @@ -13,7 +13,9 @@ // limitations under the License. #include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h" + #include + #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/platform/enforce.h" @@ -107,7 +109,7 @@ REGISTER_PASS(conv_relu_mkldnn_fuse_pass, REGISTER_PASS_CAPABILITY(conv_relu_mkldnn_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("relu", 0)); REGISTER_PASS(conv_leaky_relu_mkldnn_fuse_pass, @@ -115,7 +117,7 @@ REGISTER_PASS(conv_leaky_relu_mkldnn_fuse_pass, REGISTER_PASS_CAPABILITY(conv_leaky_relu_mkldnn_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .LE("leaky_relu", 1)); REGISTER_PASS(conv_relu6_mkldnn_fuse_pass, @@ -123,7 +125,7 @@ REGISTER_PASS(conv_relu6_mkldnn_fuse_pass, REGISTER_PASS_CAPABILITY(conv_relu6_mkldnn_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("relu6", 0)); REGISTER_PASS(conv_swish_mkldnn_fuse_pass, @@ -131,5 +133,5 @@ REGISTER_PASS(conv_swish_mkldnn_fuse_pass, REGISTER_PASS_CAPABILITY(conv_swish_mkldnn_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("swish", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc index 716c49dcb12d9b432dfddd54d1dc3fa33570f26f..b0849d74b6153ff00689a86f9c2f1c58cbca62f3 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc @@ -13,8 +13,10 @@ // limitations under the License. #include "paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h" + #include #include + #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/platform/enforce.h" @@ -150,7 +152,7 @@ REGISTER_PASS(conv_bias_mkldnn_fuse_pass, REGISTER_PASS_CAPABILITY(conv_bias_mkldnn_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("elementwise_add", 0)); REGISTER_PASS(conv_transpose_bias_mkldnn_fuse_pass, diff --git a/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc index 76e102125501144cbfd06ced2c88b4f1e02e261b..c4d7a12037293e87b84b7395a9981d95fc2ee1e8 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc @@ -13,7 +13,9 @@ // limitations under the License. #include "paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.h" + #include + #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/platform/enforce.h" @@ -128,6 +130,6 @@ REGISTER_PASS(conv_concat_relu_mkldnn_fuse_pass, REGISTER_PASS_CAPABILITY(conv_concat_relu_mkldnn_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("concat", 0) .EQ("relu", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc index 2fb131aceaad28a365e8202dca35cfe53f8f54da..a837b42b3ead48d8f852c09ed97dda1c7b0f08d2 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc @@ -13,11 +13,13 @@ // limitations under the License. #include "paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h" + #include #include #include #include #include + #include "paddle/fluid/framework/ir/graph_traits.h" #include "paddle/fluid/framework/op_version_registry.h" @@ -226,19 +228,20 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX( pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr())); conv_output->AsIntermediate(); - auto get_node_from_elementwise_add = [&elementwise_add_pattern]( - const GraphPatternDetector::subgraph_t& subgraph) + auto get_node_from_elementwise_add = + [&elementwise_add_pattern]( + const GraphPatternDetector::subgraph_t& subgraph) -> std::tuple { - GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, - elementwise_add_pattern); - GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_y, elementwise_add_y, - elementwise_add_pattern); - GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, - elementwise_add_pattern); - - return std::make_tuple(elementwise_add_op, elementwise_add_y, - elementwise_add_out); - }; + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, + elementwise_add_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_y, elementwise_add_y, + elementwise_add_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, + elementwise_add_pattern); + + return std::make_tuple(elementwise_add_op, elementwise_add_y, + elementwise_add_out); + }; return ExecuteHandleOnGraph( &gpd, graph_with_stats, @@ -263,19 +266,20 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY( conv_output); conv_output->AsIntermediate(); - auto get_node_from_elementwise_add = [&elementwise_add_pattern]( - const GraphPatternDetector::subgraph_t& subgraph) + auto get_node_from_elementwise_add = + [&elementwise_add_pattern]( + const GraphPatternDetector::subgraph_t& subgraph) -> std::tuple { - GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, - elementwise_add_pattern); - GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x, - elementwise_add_pattern); - GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, - elementwise_add_pattern); - - return std::make_tuple(elementwise_add_op, elementwise_add_x, - elementwise_add_out); - }; + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, + elementwise_add_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x, + elementwise_add_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, + elementwise_add_pattern); + + return std::make_tuple(elementwise_add_op, elementwise_add_x, + elementwise_add_out); + }; return ExecuteHandleOnGraph( &gpd, graph_with_stats, @@ -302,16 +306,17 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv( conv_x_output->AsIntermediate(); conv_y_output->AsIntermediate(); - auto get_node_from_elementwise_add = [&elementwise_add_pattern]( - const GraphPatternDetector::subgraph_t& subgraph) + auto get_node_from_elementwise_add = + [&elementwise_add_pattern]( + const GraphPatternDetector::subgraph_t& subgraph) -> std::tuple { - GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, - elementwise_add_pattern); - GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, - elementwise_add_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, + elementwise_add_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, + elementwise_add_pattern); - return std::make_tuple(elementwise_add_op, elementwise_add_out); - }; + return std::make_tuple(elementwise_add_op, elementwise_add_out); + }; return ExecuteHandleOnGraph( &gpd, graph_with_stats, @@ -345,5 +350,5 @@ REGISTER_PASS(conv_elementwise_add_mkldnn_fuse_pass, REGISTER_PASS_CAPABILITY(conv_elementwise_add_mkldnn_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("elementwise_add", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.cc b/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.cc index b2c0afdc754fb7aa3b3ffaf09e5b1961c080bcd6..39f47406a77ca9e11f588029678d1ca6c1e48372 100644 --- a/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.cc @@ -63,5 +63,5 @@ REGISTER_PASS(depthwise_conv_mkldnn_pass, paddle::framework::ir::DepthwiseConvMKLDNNPass); REGISTER_PASS_CAPABILITY(depthwise_conv_mkldnn_pass) .AddCombination( - paddle::framework::compatible::OpVersionComparatorCombination().EQ( - "depthwise_conv2d", 0)); + paddle::framework::compatible::OpVersionComparatorCombination().LE( + "depthwise_conv2d", 1)); diff --git a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc index 895c396e1e614fb06c37d519b45c942429bbf9a2..96c5546d21208b2708774a78bb7fe693b9a440a5 100644 --- a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc +++ b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc @@ -12,13 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.h" + #include #include #include #include #include "paddle/fluid/framework/ir/graph_viz_pass.h" -#include "paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.h" #include "paddle/fluid/framework/op_version_registry.h" namespace paddle { @@ -331,7 +332,7 @@ REGISTER_PASS_CAPABILITY(quant_conv2d_dequant_fuse_pass); REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("fc", 0) .LE("conv2d_transpose", 1) .EQ("fake_quantize_abs_max", 0) diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index 08f3d609fa3e6ad32c7751fe9178bc8a83463f43..bf0d87da91f534ef8470636448a698074485be55 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h" + #include #include #include @@ -20,7 +22,6 @@ #include "paddle/fluid/framework/ir/subgraph_detector.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/inference/analysis/helper.h" -#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/engine.h" #include "paddle/fluid/inference/tensorrt/op_teller.h" @@ -309,6 +310,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp( min_input_shape, max_input_shape, opt_input_shape, disable_trt_plugin_fp16); trt_engine->SetUseOSS(Get("use_oss")); + trt_engine->SetWithErnie( graph->Has(framework::ir::kEmbEltwiseLayernormPass) && graph->Has(framework::ir::kMultiheadMatmulPass)); @@ -367,13 +369,13 @@ REGISTER_PASS(tensorrt_subgraph_pass, REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("pool2d", 0) .EQ("relu", 0) .EQ("softmax", 0) .EQ("sigmoid", 0) .EQ("hard_swish", 0) - .EQ("depthwise_conv2d", 0) + .LE("depthwise_conv2d", 1) .EQ("batch_norm", 0) .EQ("concat", 0) .EQ("tanh", 0) diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index ef8a2b38f20b99f0b1e41ddc1976f88dd8d1f5ab..76ff1084fa61b4cc7fec3a59f39b956ec6582998 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -18,6 +18,8 @@ limitations under the License. */ #include #include +#include "paddle/fluid/framework/op_version_registry.h" + #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/operators/conv_cudnn_op_cache.h" #include "paddle/fluid/platform/cudnn_helper.h" @@ -817,3 +819,36 @@ REGISTER_OP_CPU_KERNEL( conv3d_grad_grad, ops::GemmConvDoubleGradKernel, ops::GemmConvDoubleGradKernel); + +REGISTER_OP_VERSION(conv2d) + .AddCheckpoint( + R"ROC( + Upgrade conv2d, add a new attribute [use_addto]. + )ROC", + paddle::framework::compatible::OpVersionDesc().NewAttr( + "use_addto", + "In order to support new feature (inplace addto strategy) for " + "gradient accumulation.", + false)); + +REGISTER_OP_VERSION(depthwise_conv2d) + .AddCheckpoint( + R"ROC( + Upgrade depthwise_conv2d, add a new attribute [use_addto]. + )ROC", + paddle::framework::compatible::OpVersionDesc().NewAttr( + "use_addto", + "In order to support new feature (inplace addto strategy) for " + "gradient accumulation.", + false)); + +REGISTER_OP_VERSION(conv3d) + .AddCheckpoint( + R"ROC( + Upgrade conv3d, add a new attribute [use_addto]. + )ROC", + paddle::framework::compatible::OpVersionDesc().NewAttr( + "use_addto", + "In order to support new feature (inplace addto strategy) for " + "gradient accumulation.", + false));