未验证 提交 3daf5185 编写于 作者: W Wilber 提交者: GitHub

add map_depthwise_conv_to_conv pass (#47955)

上级 67204c18
......@@ -76,6 +76,7 @@ pass_library(embedding_fc_lstm_fuse_pass inference)
pass_library(fc_gru_fuse_pass inference)
pass_library(seq_concat_fc_fuse_pass inference)
pass_library(multi_batch_merge_pass base)
pass_library(map_depthwise_conv_to_conv_pass inference)
pass_library(conv_bn_fuse_pass inference)
pass_library(seqconv_eltadd_relu_fuse_pass inference)
pass_library(seqpool_concat_fuse_pass inference)
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/map_depthwise_conv_to_conv_pass.h"
#include <string>
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
void MapDepthwiseConv2ConvPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("map_depthwise_conv_to_conv_pass", graph);
int found_count = 0;
std::unordered_map<std::string, std::string> replaced_map{
{"depthwise_conv2d", "conv2d"},
};
auto nodes = graph->Nodes();
for (auto& node : nodes) {
if (!node->IsOp()) continue;
auto* op_desc = node->Op();
std::string op_type = op_desc->Type();
if (!replaced_map.count(op_type)) continue;
op_desc->SetType(replaced_map[op_type]);
op_desc->Flush();
++found_count;
}
AddStatis(found_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(map_depthwise_conv_to_conv_pass,
paddle::framework::ir::MapDepthwiseConv2ConvPass);
REGISTER_PASS_CAPABILITY(map_depthwise_conv_to_conv_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("depthwise_conv2d", 1)
.LE("conv2d", 1));
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
class MapDepthwiseConv2ConvPass : public FusePassBase {
public:
MapDepthwiseConv2ConvPass() = default;
virtual ~MapDepthwiseConv2ConvPass() = default;
protected:
void ApplyImpl(Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -165,6 +165,7 @@ const std::vector<std::string> kLiteSubgraphPasses({
// running errors. After fusion operator supports low precision, delete this.
const std::vector<std::string> kGpuLowerPrecisionPasses{
"simplify_with_basic_ops_pass",
"map_depthwise_conv_to_conv_pass",
"conv_bn_fuse_pass",
"conv_eltwiseadd_bn_fuse_pass",
"conv_elementwise_add_act_fuse_pass",
......@@ -202,8 +203,9 @@ const std::vector<std::string> kTrtLowerPrecisionPasses{
GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
passes_.assign({
// "identity_scale_op_clean_pass", //
"is_test_pass", //
"simplify_with_basic_ops_pass", //
"is_test_pass", //
"simplify_with_basic_ops_pass", //
"map_depthwise_conv_to_conv_pass",
"conv_bn_fuse_pass", //
"conv_eltwiseadd_bn_fuse_pass", //
"embedding_eltwise_layernorm_fuse_pass", //
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册