Unverified commit ee49994f, authored by gem5, committed by GitHub

Unify the pass of the map class (#49568)

Parent 5e835d36
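This commit renames map_depthwise_conv_to_conv_pass to map_op_to_another_pass and turns it into a single op-mapping pass: besides rewriting depthwise_conv2d to conv2d, it now also carries the flatten_contiguous_range-to-reshape2 rewrite that previously lived in InplaceOpVarPass::MapToReshape.

The flatten rewrite only fires when axes 1..3 of a 4-D input with trailing dims of 1 are flattened, because only then is it shape-equivalent to reshape2 with shape = {0, -1} (0 copies the corresponding input dim, -1 is inferred). Below is a minimal standalone sketch of that equivalence; the helper names (FlattenRange, ReshapeZeroNeg1) and the example shape are illustrative, not part of the patch:

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// flatten_contiguous_range: collapse dims [start_axis, stop_axis] into one.
std::vector<int64_t> FlattenRange(const std::vector<int64_t>& in,
                                  int start_axis, int stop_axis) {
  std::vector<int64_t> out(in.begin(), in.begin() + start_axis);
  out.push_back(std::accumulate(in.begin() + start_axis,
                                in.begin() + stop_axis + 1, int64_t{1},
                                std::multiplies<int64_t>()));
  out.insert(out.end(), in.begin() + stop_axis + 1, in.end());
  return out;
}

// reshape2 with attribute shape = {0, -1}: 0 keeps the input dim as-is,
// -1 is inferred from the remaining element count.
std::vector<int64_t> ReshapeZeroNeg1(const std::vector<int64_t>& in) {
  const int64_t total = std::accumulate(in.begin(), in.end(), int64_t{1},
                                        std::multiplies<int64_t>());
  return {in[0], total / in[0]};
}

int main() {
  const std::vector<int64_t> nchw{8, 256, 1, 1};  // N x C x 1 x 1
  // Both rewrites produce {8, 256}, which is why the pass requires
  // start_axis == 1, stop_axis == 3, rank 4, and trailing dims of 1.
  assert(FlattenRange(nchw, 1, 3) == ReshapeZeroNeg1(nchw));
  return 0;
}
```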
paddle/fluid/framework/ir/CMakeLists.txt
@@ -76,7 +76,7 @@ pass_library(embedding_fc_lstm_fuse_pass inference)
 pass_library(fc_gru_fuse_pass inference)
 pass_library(seq_concat_fc_fuse_pass inference)
 pass_library(multi_batch_merge_pass base)
-pass_library(map_depthwise_conv_to_conv_pass inference)
+pass_library(map_op_to_another_pass inference)
 pass_library(conv_bn_fuse_pass inference)
 pass_library(seqconv_eltadd_relu_fuse_pass inference)
 pass_library(seqpool_concat_fuse_pass inference)
paddle/fluid/framework/ir/inplace_op_var_pass.cc
@@ -28,7 +28,6 @@ class Graph;
 void InplaceOpVarPass::ApplyImpl(ir::Graph* graph) const {
   FusePassBase::Init("inplace_op_var", graph);
   int found_subgraph_count = 0;
-  MapToReshape(graph);
   auto nodes = graph->Nodes();
   auto is_valid_reshape = [](Node* node) {
@@ -98,26 +97,6 @@ void InplaceOpVarPass::ApplyImpl(ir::Graph* graph) const {
   AddStatis(found_subgraph_count);
 }
-void InplaceOpVarPass::MapToReshape(ir::Graph* graph) const {
-  // flatten_contiguous_range op map to reshape.
-  for (auto* node : graph->Nodes()) {
-    if (node->IsOp() && node->Op()->Type() == "flatten_contiguous_range") {
-      auto* op_node = node->Op();
-      auto start_axis = PADDLE_GET_CONST(int, op_node->GetAttr("start_axis"));
-      auto stop_axis = PADDLE_GET_CONST(int, op_node->GetAttr("stop_axis"));
-      auto input_name = op_node->Input("X")[0];
-      auto* block = op_node->Block();
-      auto input_shape = block->FindVar(input_name)->GetShape();
-      if (start_axis == 1 && stop_axis == 3 && input_shape.size() == 4 &&
-          input_shape[2] == 1 && input_shape[3] == 1) {
-        op_node->SetType("reshape2");
-        op_node->SetAttr("shape", std::vector<int>{0, -1});
-        op_node->Flush();
-      }
-    }
-  }
-}
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
paddle/fluid/framework/ir/inplace_op_var_pass.h
@@ -28,7 +28,6 @@ class InplaceOpVarPass : public FusePassBase {
  private:
   virtual ~InplaceOpVarPass() = default;
-  void MapToReshape(ir::Graph* graph) const;
 };
 }  // namespace ir
paddle/fluid/framework/ir/map_op_to_another_pass.cc (renamed from map_depthwise_conv_to_conv_pass.cc)
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "paddle/fluid/framework/ir/map_depthwise_conv_to_conv_pass.h"
+#include "paddle/fluid/framework/ir/map_op_to_another_pass.h"
 #include <string>
@@ -23,14 +23,15 @@ namespace paddle {
 namespace framework {
 namespace ir {
-void MapDepthwiseConv2ConvPass::ApplyImpl(ir::Graph* graph) const {
+void MapOp2AnotherPass::ApplyImpl(ir::Graph* graph) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
-  FusePassBase::Init("map_depthwise_conv_to_conv_pass", graph);
+  FusePassBase::Init("map_op_to_another_pass", graph);
   int found_count = 0;
   std::unordered_map<std::string, std::string> replaced_map{
       {"depthwise_conv2d", "conv2d"},
+      {"flatten_contiguous_range", "reshape2"},
   };
   auto nodes = graph->Nodes();
@@ -40,8 +41,21 @@ void MapDepthwiseConv2ConvPass::ApplyImpl(ir::Graph* graph) const {
     auto* op_desc = node->Op();
     std::string op_type = op_desc->Type();
     if (!replaced_map.count(op_type)) continue;
-    op_desc->SetType(replaced_map[op_type]);
-    op_desc->SetAttr("use_cudnn", true);
+    if (op_type == "flatten_contiguous_range") {
+      auto start_axis = PADDLE_GET_CONST(int, op_desc->GetAttr("start_axis"));
+      auto stop_axis = PADDLE_GET_CONST(int, op_desc->GetAttr("stop_axis"));
+      auto input_name = op_desc->Input("X")[0];
+      auto* block = op_desc->Block();
+      auto input_shape = block->FindVar(input_name)->GetShape();
+      if (start_axis == 1 && stop_axis == 3 && input_shape.size() == 4 &&
+          input_shape[2] == 1 && input_shape[3] == 1) {
+        op_desc->SetType(replaced_map[op_type]);
+        op_desc->SetAttr("shape", std::vector<int>{0, -1});
+      }
+    } else if (op_type == "depthwise_conv2d") {
+      op_desc->SetType(replaced_map[op_type]);
+      op_desc->SetAttr("use_cudnn", true);
+    }
     op_desc->Flush();
     ++found_count;
   }
@@ -53,10 +67,11 @@ void MapDepthwiseConv2ConvPass::ApplyImpl(ir::Graph* graph) const {
 }  // namespace framework
 }  // namespace paddle
-REGISTER_PASS(map_depthwise_conv_to_conv_pass,
-              paddle::framework::ir::MapDepthwiseConv2ConvPass);
-REGISTER_PASS_CAPABILITY(map_depthwise_conv_to_conv_pass)
+REGISTER_PASS(map_op_to_another_pass, paddle::framework::ir::MapOp2AnotherPass);
+REGISTER_PASS_CAPABILITY(map_op_to_another_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
             .LE("depthwise_conv2d", 1)
-            .LE("conv2d", 1));
+            .LE("conv2d", 1)
+            .EQ("reshape2", 0)
+            .EQ("flatten_contiguous_range", 0));
paddle/fluid/framework/ir/map_op_to_another_pass.h (renamed from map_depthwise_conv_to_conv_pass.h)
@@ -22,10 +22,10 @@ namespace paddle {
 namespace framework {
 namespace ir {
-class MapDepthwiseConv2ConvPass : public FusePassBase {
+class MapOp2AnotherPass : public FusePassBase {
  public:
-  MapDepthwiseConv2ConvPass() = default;
-  virtual ~MapDepthwiseConv2ConvPass() = default;
+  MapOp2AnotherPass() = default;
+  virtual ~MapOp2AnotherPass() = default;
  protected:
   void ApplyImpl(Graph* graph) const override;
paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -162,12 +162,12 @@ const std::vector<std::string> kLiteSubgraphPasses({
 // support fp16/bf16 precision, temporarily use low precision pass to prevent
 // running errors. After fusion operator supports low precision, delete this.
 const std::vector<std::string> kGpuLowerPrecisionPasses{
+    "map_op_to_another_pass",
     "identity_scale_op_clean_pass",
     "simplify_with_basic_ops_pass",
    "silu_fuse_pass",
     "delete_quant_dequant_linear_op_pass",
     "delete_weight_dequant_linear_op_pass",
-    "map_depthwise_conv_to_conv_pass",
     "conv_bn_fuse_pass",
     "conv_eltwiseadd_bn_fuse_pass",
     "conv_elementwise_add_act_fuse_pass",
@@ -212,12 +212,12 @@ const std::vector<std::string> kCINNCompilerPasses{
 GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
   passes_.assign({
-      "identity_scale_op_clean_pass",           //
+      "map_op_to_another_pass",                 //
+      "identity_scale_op_clean_pass",           //
       "is_test_pass",                           //
       "simplify_with_basic_ops_pass",           //
       "delete_quant_dequant_linear_op_pass",    //
       "delete_weight_dequant_linear_op_pass",   //
-      "map_depthwise_conv_to_conv_pass",        //
       "constant_folding_pass",                  //
       "silu_fuse_pass",                         //
       "conv_bn_fuse_pass",                      //
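The pass keeps running out of the box for inference users, but anything that referenced the old pass name must switch to the new one. A minimal sketch of disabling the renamed pass through the standard paddle_infer pass builder; the model path and GPU memory size are placeholders:

```cpp
#include "paddle_inference_api.h"

int main() {
  // "./model_dir" is a placeholder for a real inference model directory.
  paddle_infer::Config config("./model_dir");
  config.EnableUseGpu(/*memory_pool_init_size_mb=*/100, /*gpu_device_id=*/0);
  // GpuPassStrategy now schedules map_op_to_another_pass by default; remove
  // it (by its new name) to skip the op rewrite for a given model.
  config.pass_builder()->DeletePass("map_op_to_another_pass");
  auto predictor = paddle_infer::CreatePredictor(config);
  return 0;
}
```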