Unverified · Commit 90805e2d authored by Leo Chen, committed by GitHub

Register op_version for new attribute use_addto (#28463)

* register op_version for addto

* upgrade pass capability

* change eq to le

* change eq to le

* fix merge
Parent 1c3eef4c
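
Why "change eq to le": this commit bumps the op version of conv2d, depthwise_conv2d, and conv3d from 0 to 1 by registering the new use_addto attribute, so every fuse pass that previously demanded an exact version match must now accept both versions. Below is a minimal standalone sketch — using a hypothetical VersionPredicate type, not Paddle's actual OpVersionComparatorCombination — of why .EQ("conv2d", 0) had to become .LE("conv2d", 1):

#include <cassert>

// Hypothetical stand-in for a single comparator in a pass capability rule.
struct VersionPredicate {
  enum class Cmp { kEQ, kLE } cmp;
  int target;
  bool Match(int op_version) const {
    return cmp == Cmp::kEQ ? op_version == target : op_version <= target;
  }
};

int main() {
  VersionPredicate before{VersionPredicate::Cmp::kEQ, 0};  // .EQ("conv2d", 0)
  VersionPredicate after{VersionPredicate::Cmp::kLE, 1};   // .LE("conv2d", 1)

  // conv2d programs now exist at version 0 (old) and version 1 (upgraded).
  assert(before.Match(0) && !before.Match(1));  // old rule rejects v1 models
  assert(after.Match(0) && after.Match(1));     // new rule accepts both
  return 0;
}

An exact EQ(0) check would silently disable these fusions the moment the version bump landed; the LE(1) bound keeps them applicable to both old and upgraded programs.
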
@@ -238,11 +238,11 @@ REGISTER_PASS(conv_eltwiseadd_affine_channel_fuse_pass,
 REGISTER_PASS_CAPABILITY(conv_affine_channel_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
-            .EQ("conv2d", 0)
+            .LE("conv2d", 1)
             .EQ("affine_channel", 0));
 REGISTER_PASS_CAPABILITY(conv_eltwiseadd_affine_channel_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
-            .EQ("conv2d", 0)
+            .LE("conv2d", 1)
             .EQ("elementwise_add", 0)
             .EQ("affine_channel", 0));

@@ -383,11 +383,11 @@ REGISTER_PASS(depthwise_conv_eltwiseadd_bn_fuse_pass,
 REGISTER_PASS_CAPABILITY(conv_bn_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
-            .EQ("conv2d", 0)
+            .LE("conv2d", 1)
             .EQ("batch_norm", 0));
 REGISTER_PASS_CAPABILITY(conv_eltwiseadd_bn_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
-            .EQ("conv2d", 0)
+            .LE("conv2d", 1)
             .EQ("elementwise_add", 0)
             .EQ("batch_norm", 0));

@@ -12,7 +12,9 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 #include "paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.h"
+
 #include <string>
+
 #include "paddle/fluid/framework/op_version_registry.h"
 
 namespace paddle {

@@ -119,7 +121,7 @@ REGISTER_PASS(conv_elementwise_add2_act_fuse_pass,
 REGISTER_PASS_CAPABILITY(conv_elementwise_add2_act_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
-            .EQ("conv2d", 0)
+            .LE("conv2d", 1)
             .EQ("elementwise_add", 0)
             .EQ("relu", 0)
             .EQ("identity", 0));

@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.h"
+
 #include <string>
 
 #include "paddle/fluid/framework/ir/graph_viz_pass.h"

@@ -107,7 +108,7 @@ REGISTER_PASS(conv_elementwise_add_act_fuse_pass,
 REGISTER_PASS_CAPABILITY(conv_elementwise_add_act_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
-            .EQ("conv2d", 0)
+            .LE("conv2d", 1)
             .EQ("elementwise_add", 0)
             .EQ("relu", 0)
             .EQ("identity", 0));

@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.h"
+
 #include <string>
 
 #include "paddle/fluid/framework/ir/graph_viz_pass.h"

@@ -93,5 +94,5 @@ REGISTER_PASS(conv_elementwise_add_fuse_pass,
 REGISTER_PASS_CAPABILITY(conv_elementwise_add_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
-            .EQ("conv2d", 0)
+            .LE("conv2d", 1)
             .EQ("elementwise_add", 0));

@@ -13,7 +13,9 @@
 // limitations under the License.
 
 #include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h"
+
 #include <vector>
+
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/platform/enforce.h"
 

@@ -107,7 +109,7 @@ REGISTER_PASS(conv_relu_mkldnn_fuse_pass,
 REGISTER_PASS_CAPABILITY(conv_relu_mkldnn_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
-            .EQ("conv2d", 0)
+            .LE("conv2d", 1)
             .EQ("relu", 0));
 
 REGISTER_PASS(conv_leaky_relu_mkldnn_fuse_pass,

@@ -115,7 +117,7 @@ REGISTER_PASS(conv_leaky_relu_mkldnn_fuse_pass,
 REGISTER_PASS_CAPABILITY(conv_leaky_relu_mkldnn_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
-            .EQ("conv2d", 0)
+            .LE("conv2d", 1)
             .LE("leaky_relu", 1));
 
 REGISTER_PASS(conv_relu6_mkldnn_fuse_pass,

@@ -123,7 +125,7 @@ REGISTER_PASS(conv_relu6_mkldnn_fuse_pass,
 REGISTER_PASS_CAPABILITY(conv_relu6_mkldnn_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
-            .EQ("conv2d", 0)
+            .LE("conv2d", 1)
             .EQ("relu6", 0));
 
 REGISTER_PASS(conv_swish_mkldnn_fuse_pass,

@@ -131,5 +133,5 @@ REGISTER_PASS(conv_swish_mkldnn_fuse_pass,
 REGISTER_PASS_CAPABILITY(conv_swish_mkldnn_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
-            .EQ("conv2d", 0)
+            .LE("conv2d", 1)
             .EQ("swish", 0));

@@ -13,8 +13,10 @@
 // limitations under the License.
 
 #include "paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h"
+
 #include <functional>
 #include <vector>
+
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/platform/enforce.h"

@@ -150,7 +152,7 @@ REGISTER_PASS(conv_bias_mkldnn_fuse_pass,
 REGISTER_PASS_CAPABILITY(conv_bias_mkldnn_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
-            .EQ("conv2d", 0)
+            .LE("conv2d", 1)
             .EQ("elementwise_add", 0));
 
 REGISTER_PASS(conv_transpose_bias_mkldnn_fuse_pass,
...

@@ -13,7 +13,9 @@
 // limitations under the License.
 
 #include "paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.h"
+
 #include <vector>
+
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/platform/enforce.h"
 

@@ -128,6 +130,6 @@ REGISTER_PASS(conv_concat_relu_mkldnn_fuse_pass,
 REGISTER_PASS_CAPABILITY(conv_concat_relu_mkldnn_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
-            .EQ("conv2d", 0)
+            .LE("conv2d", 1)
             .EQ("concat", 0)
             .EQ("relu", 0));

@@ -13,11 +13,13 @@
 // limitations under the License.
 
 #include "paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h"
+
 #include <functional>
 #include <list>
 #include <map>
 #include <memory>
 #include <tuple>
+
 #include "paddle/fluid/framework/ir/graph_traits.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 

@@ -226,19 +228,20 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
       pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr()));
   conv_output->AsIntermediate();
 
-  auto get_node_from_elementwise_add = [&elementwise_add_pattern](
-      const GraphPatternDetector::subgraph_t& subgraph)
+  auto get_node_from_elementwise_add =
+      [&elementwise_add_pattern](
+          const GraphPatternDetector::subgraph_t& subgraph)
       -> std::tuple<Node*, Node*, Node*> {
     GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
                               elementwise_add_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_y, elementwise_add_y,
                               elementwise_add_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
                               elementwise_add_pattern);
 
     return std::make_tuple(elementwise_add_op, elementwise_add_y,
                            elementwise_add_out);
   };
 
   return ExecuteHandleOnGraph<IdentityFuseHandle>(
       &gpd, graph_with_stats,

@@ -263,19 +266,20 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
       conv_output);
   conv_output->AsIntermediate();
 
-  auto get_node_from_elementwise_add = [&elementwise_add_pattern](
-      const GraphPatternDetector::subgraph_t& subgraph)
+  auto get_node_from_elementwise_add =
+      [&elementwise_add_pattern](
+          const GraphPatternDetector::subgraph_t& subgraph)
       -> std::tuple<Node*, Node*, Node*> {
     GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
                               elementwise_add_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x,
                               elementwise_add_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
                               elementwise_add_pattern);
 
     return std::make_tuple(elementwise_add_op, elementwise_add_x,
                            elementwise_add_out);
   };
 
   return ExecuteHandleOnGraph<IdentityFuseHandle>(
       &gpd, graph_with_stats,

@@ -302,16 +306,17 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
   conv_x_output->AsIntermediate();
   conv_y_output->AsIntermediate();
 
-  auto get_node_from_elementwise_add = [&elementwise_add_pattern](
-      const GraphPatternDetector::subgraph_t& subgraph)
+  auto get_node_from_elementwise_add =
+      [&elementwise_add_pattern](
+          const GraphPatternDetector::subgraph_t& subgraph)
       -> std::tuple<Node*, Node*> {
     GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
                               elementwise_add_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
                               elementwise_add_pattern);
 
     return std::make_tuple(elementwise_add_op, elementwise_add_out);
   };
 
   return ExecuteHandleOnGraph<ProjectionFuseHandle>(
       &gpd, graph_with_stats,

@@ -345,5 +350,5 @@ REGISTER_PASS(conv_elementwise_add_mkldnn_fuse_pass,
 REGISTER_PASS_CAPABILITY(conv_elementwise_add_mkldnn_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
-            .EQ("conv2d", 0)
+            .LE("conv2d", 1)
             .EQ("elementwise_add", 0));

@@ -63,5 +63,5 @@ REGISTER_PASS(depthwise_conv_mkldnn_pass,
               paddle::framework::ir::DepthwiseConvMKLDNNPass);
 REGISTER_PASS_CAPABILITY(depthwise_conv_mkldnn_pass)
     .AddCombination(
-        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
-            "depthwise_conv2d", 0));
+        paddle::framework::compatible::OpVersionComparatorCombination().LE(
+            "depthwise_conv2d", 1));

@@ -12,13 +12,14 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.h"
+
 #include <memory>
 #include <string>
 #include <unordered_set>
 #include <vector>
 
 #include "paddle/fluid/framework/ir/graph_viz_pass.h"
-#include "paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 
 namespace paddle {

@@ -331,7 +332,7 @@ REGISTER_PASS_CAPABILITY(quant_conv2d_dequant_fuse_pass);
 REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
-            .EQ("conv2d", 0)
+            .LE("conv2d", 1)
             .EQ("fc", 0)
             .LE("conv2d_transpose", 1)
             .EQ("fake_quantize_abs_max", 0)
...

@@ -12,6 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h"
+
 #include <algorithm>
 #include <map>
 #include <set>
@@ -20,7 +22,6 @@
 #include "paddle/fluid/framework/ir/subgraph_detector.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/inference/analysis/helper.h"
-#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h"
 #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
 #include "paddle/fluid/inference/tensorrt/engine.h"
 #include "paddle/fluid/inference/tensorrt/op_teller.h"

@@ -309,6 +310,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
       min_input_shape, max_input_shape, opt_input_shape,
       disable_trt_plugin_fp16);
   trt_engine->SetUseOSS(Get<bool>("use_oss"));
+
   trt_engine->SetWithErnie(
       graph->Has(framework::ir::kEmbEltwiseLayernormPass) &&
       graph->Has(framework::ir::kMultiheadMatmulPass));

@@ -367,13 +369,13 @@ REGISTER_PASS(tensorrt_subgraph_pass,
 REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
-            .EQ("conv2d", 0)
+            .LE("conv2d", 1)
             .EQ("pool2d", 0)
             .EQ("relu", 0)
             .EQ("softmax", 0)
             .EQ("sigmoid", 0)
             .EQ("hard_swish", 0)
-            .EQ("depthwise_conv2d", 0)
+            .LE("depthwise_conv2d", 1)
             .EQ("batch_norm", 0)
             .EQ("concat", 0)
             .EQ("tanh", 0)
...

@@ -18,6 +18,8 @@ limitations under the License. */
 #include <string>
 #include <vector>
 
+#include "paddle/fluid/framework/op_version_registry.h"
+
 #ifdef PADDLE_WITH_CUDA
 #include "paddle/fluid/operators/conv_cudnn_op_cache.h"
 #include "paddle/fluid/platform/cudnn_helper.h"

@@ -817,3 +819,36 @@ REGISTER_OP_CPU_KERNEL(
     conv3d_grad_grad,
     ops::GemmConvDoubleGradKernel<paddle::platform::CPUDeviceContext, float>,
     ops::GemmConvDoubleGradKernel<paddle::platform::CPUDeviceContext, double>);
+
+REGISTER_OP_VERSION(conv2d)
+    .AddCheckpoint(
+        R"ROC(
+      Upgrade conv2d, add a new attribute [use_addto].
+    )ROC",
+        paddle::framework::compatible::OpVersionDesc().NewAttr(
+            "use_addto",
+            "In order to support new feature (inplace addto strategy) for "
+            "gradient accumulation.",
+            false));
+
+REGISTER_OP_VERSION(depthwise_conv2d)
+    .AddCheckpoint(
+        R"ROC(
+      Upgrade depthwise_conv2d, add a new attribute [use_addto].
+    )ROC",
+        paddle::framework::compatible::OpVersionDesc().NewAttr(
+            "use_addto",
+            "In order to support new feature (inplace addto strategy) for "
+            "gradient accumulation.",
+            false));
+
+REGISTER_OP_VERSION(conv3d)
+    .AddCheckpoint(
+        R"ROC(
+      Upgrade conv3d, add a new attribute [use_addto].
+    )ROC",
+        paddle::framework::compatible::OpVersionDesc().NewAttr(
+            "use_addto",
+            "In order to support new feature (inplace addto strategy) for "
+            "gradient accumulation.",
+            false));
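
For context on what the checkpoints above buy: registering NewAttr with a default of false is what keeps version-0 programs loadable — the new attribute is absent in old serialized ops, so it must be treated as having the registered default, leaving the old (non-addto) behavior intact. A simplified, hypothetical illustration of that back-fill idea (ApplyNewAttrDefault is illustrative only, not Paddle's real upgrade machinery):

#include <iostream>
#include <map>
#include <string>

// Back-fill a newly introduced attribute when an old op description lacks it.
void ApplyNewAttrDefault(std::map<std::string, bool>* attrs,
                         const std::string& name, bool default_value) {
  attrs->emplace(name, default_value);  // no-op if the attribute already exists
}

int main() {
  // Attributes of a conv2d op deserialized from a version-0 program.
  std::map<std::string, bool> old_op_attrs = {{"fuse_relu", false}};
  ApplyNewAttrDefault(&old_op_attrs, "use_addto", false);
  std::cout << "use_addto = " << std::boolalpha
            << old_op_attrs.at("use_addto") << "\n";  // prints false
  return 0;
}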