diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 96bb7c53271a969048134f95c417de5b8e76adeb..06ea7acb3315e110940e3b06a4edd9d8b9052b60 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -76,6 +76,7 @@ pass_library(embedding_fc_lstm_fuse_pass inference) pass_library(fc_gru_fuse_pass inference) pass_library(seq_concat_fc_fuse_pass inference) pass_library(multi_batch_merge_pass base) +pass_library(map_depthwise_conv_to_conv_pass inference) pass_library(conv_bn_fuse_pass inference) pass_library(seqconv_eltadd_relu_fuse_pass inference) pass_library(seqpool_concat_fuse_pass inference) diff --git a/paddle/fluid/framework/ir/map_depthwise_conv_to_conv_pass.cc b/paddle/fluid/framework/ir/map_depthwise_conv_to_conv_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..341fedcd4bacd56f3e478d31c12827f31e69a8c7 --- /dev/null +++ b/paddle/fluid/framework/ir/map_depthwise_conv_to_conv_pass.cc @@ -0,0 +1,61 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/map_depthwise_conv_to_conv_pass.h" + +#include + +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace ir { + +void MapDepthwiseConv2ConvPass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); + FusePassBase::Init("map_depthwise_conv_to_conv_pass", graph); + + int found_count = 0; + std::unordered_map replaced_map{ + {"depthwise_conv2d", "conv2d"}, + }; + + auto nodes = graph->Nodes(); + + for (auto& node : nodes) { + if (!node->IsOp()) continue; + auto* op_desc = node->Op(); + std::string op_type = op_desc->Type(); + if (!replaced_map.count(op_type)) continue; + op_desc->SetType(replaced_map[op_type]); + op_desc->Flush(); + ++found_count; + } + + AddStatis(found_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(map_depthwise_conv_to_conv_pass, + paddle::framework::ir::MapDepthwiseConv2ConvPass); +REGISTER_PASS_CAPABILITY(map_depthwise_conv_to_conv_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .LE("depthwise_conv2d", 1) + .LE("conv2d", 1)); diff --git a/paddle/fluid/framework/ir/map_depthwise_conv_to_conv_pass.h b/paddle/fluid/framework/ir/map_depthwise_conv_to_conv_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..1691ab697346564ac55d2a1a7b19a55c853a9986 --- /dev/null +++ b/paddle/fluid/framework/ir/map_depthwise_conv_to_conv_pass.h @@ -0,0 +1,36 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" + +namespace paddle { +namespace framework { +namespace ir { + +class MapDepthwiseConv2ConvPass : public FusePassBase { + public: + MapDepthwiseConv2ConvPass() = default; + virtual ~MapDepthwiseConv2ConvPass() = default; + + protected: + void ApplyImpl(Graph* graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc old mode 100755 new mode 100644 index 062264222b255389d13b56042d8ccbd4f5b134b7..c964ce7e4d0d22e6ac733fd3e4be5c789b672c50 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -165,6 +165,7 @@ const std::vector kLiteSubgraphPasses({ // running errors. After fusion operator supports low precision, delete this. const std::vector kGpuLowerPrecisionPasses{ "simplify_with_basic_ops_pass", + "map_depthwise_conv_to_conv_pass", "conv_bn_fuse_pass", "conv_eltwiseadd_bn_fuse_pass", "conv_elementwise_add_act_fuse_pass", @@ -202,8 +203,9 @@ const std::vector kTrtLowerPrecisionPasses{ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { passes_.assign({ // "identity_scale_op_clean_pass", // - "is_test_pass", // - "simplify_with_basic_ops_pass", // + "is_test_pass", // + "simplify_with_basic_ops_pass", // + "map_depthwise_conv_to_conv_pass", "conv_bn_fuse_pass", // "conv_eltwiseadd_bn_fuse_pass", // "embedding_eltwise_layernorm_fuse_pass", //