未验证 提交 ee49994f 编写于 作者: G gem5 提交者: GitHub

Unify the pass of the map class (#49568)

上级 5e835d36
...@@ -76,7 +76,7 @@ pass_library(embedding_fc_lstm_fuse_pass inference) ...@@ -76,7 +76,7 @@ pass_library(embedding_fc_lstm_fuse_pass inference)
pass_library(fc_gru_fuse_pass inference) pass_library(fc_gru_fuse_pass inference)
pass_library(seq_concat_fc_fuse_pass inference) pass_library(seq_concat_fc_fuse_pass inference)
pass_library(multi_batch_merge_pass base) pass_library(multi_batch_merge_pass base)
pass_library(map_depthwise_conv_to_conv_pass inference) pass_library(map_op_to_another_pass inference)
pass_library(conv_bn_fuse_pass inference) pass_library(conv_bn_fuse_pass inference)
pass_library(seqconv_eltadd_relu_fuse_pass inference) pass_library(seqconv_eltadd_relu_fuse_pass inference)
pass_library(seqpool_concat_fuse_pass inference) pass_library(seqpool_concat_fuse_pass inference)
......
...@@ -28,7 +28,6 @@ class Graph; ...@@ -28,7 +28,6 @@ class Graph;
void InplaceOpVarPass::ApplyImpl(ir::Graph* graph) const { void InplaceOpVarPass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init("inplace_op_var", graph); FusePassBase::Init("inplace_op_var", graph);
int found_subgraph_count = 0; int found_subgraph_count = 0;
MapToReshape(graph);
auto nodes = graph->Nodes(); auto nodes = graph->Nodes();
auto is_valid_reshape = [](Node* node) { auto is_valid_reshape = [](Node* node) {
...@@ -98,26 +97,6 @@ void InplaceOpVarPass::ApplyImpl(ir::Graph* graph) const { ...@@ -98,26 +97,6 @@ void InplaceOpVarPass::ApplyImpl(ir::Graph* graph) const {
AddStatis(found_subgraph_count); AddStatis(found_subgraph_count);
} }
void InplaceOpVarPass::MapToReshape(ir::Graph* graph) const {
// flatten_contiguous_range op map to reshape.
for (auto* node : graph->Nodes()) {
if (node->IsOp() && node->Op()->Type() == "flatten_contiguous_range") {
auto* op_node = node->Op();
auto start_axis = PADDLE_GET_CONST(int, op_node->GetAttr("start_axis"));
auto stop_axis = PADDLE_GET_CONST(int, op_node->GetAttr("stop_axis"));
auto input_name = op_node->Input("X")[0];
auto* block = op_node->Block();
auto input_shape = block->FindVar(input_name)->GetShape();
if (start_axis == 1 && stop_axis == 3 && input_shape.size() == 4 &&
input_shape[2] == 1 && input_shape[3] == 1) {
op_node->SetType("reshape2");
op_node->SetAttr("shape", std::vector<int>{0, -1});
op_node->Flush();
}
}
}
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
......
...@@ -28,7 +28,6 @@ class InplaceOpVarPass : public FusePassBase { ...@@ -28,7 +28,6 @@ class InplaceOpVarPass : public FusePassBase {
private: private:
virtual ~InplaceOpVarPass() = default; virtual ~InplaceOpVarPass() = default;
void MapToReshape(ir::Graph* graph) const;
}; };
} // namespace ir } // namespace ir
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/map_depthwise_conv_to_conv_pass.h" #include "paddle/fluid/framework/ir/map_op_to_another_pass.h"
#include <string> #include <string>
...@@ -23,14 +23,15 @@ namespace paddle { ...@@ -23,14 +23,15 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
void MapDepthwiseConv2ConvPass::ApplyImpl(ir::Graph* graph) const { void MapOp2AnotherPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("map_depthwise_conv_to_conv_pass", graph); FusePassBase::Init("map_op_to_another_pass", graph);
int found_count = 0; int found_count = 0;
std::unordered_map<std::string, std::string> replaced_map{ std::unordered_map<std::string, std::string> replaced_map{
{"depthwise_conv2d", "conv2d"}, {"depthwise_conv2d", "conv2d"},
{"flatten_contiguous_range", "reshape2"},
}; };
auto nodes = graph->Nodes(); auto nodes = graph->Nodes();
...@@ -40,8 +41,21 @@ void MapDepthwiseConv2ConvPass::ApplyImpl(ir::Graph* graph) const { ...@@ -40,8 +41,21 @@ void MapDepthwiseConv2ConvPass::ApplyImpl(ir::Graph* graph) const {
auto* op_desc = node->Op(); auto* op_desc = node->Op();
std::string op_type = op_desc->Type(); std::string op_type = op_desc->Type();
if (!replaced_map.count(op_type)) continue; if (!replaced_map.count(op_type)) continue;
op_desc->SetType(replaced_map[op_type]); if (op_type == "flatten_contiguous_range") {
op_desc->SetAttr("use_cudnn", true); auto start_axis = PADDLE_GET_CONST(int, op_desc->GetAttr("start_axis"));
auto stop_axis = PADDLE_GET_CONST(int, op_desc->GetAttr("stop_axis"));
auto input_name = op_desc->Input("X")[0];
auto* block = op_desc->Block();
auto input_shape = block->FindVar(input_name)->GetShape();
if (start_axis == 1 && stop_axis == 3 && input_shape.size() == 4 &&
input_shape[2] == 1 && input_shape[3] == 1) {
op_desc->SetType(replaced_map[op_type]);
op_desc->SetAttr("shape", std::vector<int>{0, -1});
}
} else if (op_type == "depthwise_conv2d") {
op_desc->SetType(replaced_map[op_type]);
op_desc->SetAttr("use_cudnn", true);
}
op_desc->Flush(); op_desc->Flush();
++found_count; ++found_count;
} }
...@@ -53,10 +67,11 @@ void MapDepthwiseConv2ConvPass::ApplyImpl(ir::Graph* graph) const { ...@@ -53,10 +67,11 @@ void MapDepthwiseConv2ConvPass::ApplyImpl(ir::Graph* graph) const {
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(map_depthwise_conv_to_conv_pass, REGISTER_PASS(map_op_to_another_pass, paddle::framework::ir::MapOp2AnotherPass);
paddle::framework::ir::MapDepthwiseConv2ConvPass); REGISTER_PASS_CAPABILITY(map_op_to_another_pass)
REGISTER_PASS_CAPABILITY(map_depthwise_conv_to_conv_pass)
.AddCombination( .AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination() paddle::framework::compatible::OpVersionComparatorCombination()
.LE("depthwise_conv2d", 1) .LE("depthwise_conv2d", 1)
.LE("conv2d", 1)); .LE("conv2d", 1)
.EQ("reshape2", 0)
.EQ("flatten_contiguous_range", 0));
...@@ -22,10 +22,10 @@ namespace paddle { ...@@ -22,10 +22,10 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
class MapDepthwiseConv2ConvPass : public FusePassBase { class MapOp2AnotherPass : public FusePassBase {
public: public:
MapDepthwiseConv2ConvPass() = default; MapOp2AnotherPass() = default;
virtual ~MapDepthwiseConv2ConvPass() = default; virtual ~MapOp2AnotherPass() = default;
protected: protected:
void ApplyImpl(Graph* graph) const override; void ApplyImpl(Graph* graph) const override;
......
...@@ -162,12 +162,12 @@ const std::vector<std::string> kLiteSubgraphPasses({ ...@@ -162,12 +162,12 @@ const std::vector<std::string> kLiteSubgraphPasses({
// support fp16/bf16 precision, temporarily use low precision pass to prevent // support fp16/bf16 precision, temporarily use low precision pass to prevent
// running errors. After fusion operator supports low precision, delete this. // running errors. After fusion operator supports low precision, delete this.
const std::vector<std::string> kGpuLowerPrecisionPasses{ const std::vector<std::string> kGpuLowerPrecisionPasses{
"map_op_to_another_pass",
"identity_scale_op_clean_pass", "identity_scale_op_clean_pass",
"simplify_with_basic_ops_pass", "simplify_with_basic_ops_pass",
"silu_fuse_pass", "silu_fuse_pass",
"delete_quant_dequant_linear_op_pass", "delete_quant_dequant_linear_op_pass",
"delete_weight_dequant_linear_op_pass", "delete_weight_dequant_linear_op_pass",
"map_depthwise_conv_to_conv_pass",
"conv_bn_fuse_pass", "conv_bn_fuse_pass",
"conv_eltwiseadd_bn_fuse_pass", "conv_eltwiseadd_bn_fuse_pass",
"conv_elementwise_add_act_fuse_pass", "conv_elementwise_add_act_fuse_pass",
...@@ -212,12 +212,12 @@ const std::vector<std::string> kCINNCompilerPasses{ ...@@ -212,12 +212,12 @@ const std::vector<std::string> kCINNCompilerPasses{
GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
passes_.assign({ passes_.assign({
"identity_scale_op_clean_pass", // "map_op_to_another_pass", //
"identity_scale_op_clean_pass", //
"is_test_pass", // "is_test_pass", //
"simplify_with_basic_ops_pass", // "simplify_with_basic_ops_pass", //
"delete_quant_dequant_linear_op_pass", // "delete_quant_dequant_linear_op_pass", //
"delete_weight_dequant_linear_op_pass", // "delete_weight_dequant_linear_op_pass", //
"map_depthwise_conv_to_conv_pass", //
"constant_folding_pass", // "constant_folding_pass", //
"silu_fuse_pass", // "silu_fuse_pass", //
"conv_bn_fuse_pass", // "conv_bn_fuse_pass", //
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册