未验证 提交 2e97faf1 编写于 作者: W Wangzheee 提交者: GitHub

[ pass_enhance ]transpose_flatten_concat_fuse_pass (#33744)

上级 07eeb36e
...@@ -19,7 +19,50 @@ namespace paddle { ...@@ -19,7 +19,50 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
void RunTransposeFlattenConcatFuse(ir::Graph *graph, int times) { TransposeFlattenConcatFusePass::TransposeFlattenConcatFusePass() {
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("axis")
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("flatten2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumGE(0)
.End();
AddOpCompat(OpCompat("concat"))
.AddInput("X") // Input("X"): vector<tensors>
.End()
.AddInput("AxisTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({0, 1})
.End();
}
void TransposeFlattenConcatFusePass::RunTransposeFlattenConcatFuse(
ir::Graph *graph, int times) const {
const std::string pattern_name = const std::string pattern_name =
"transpose_flatten" + std::to_string(times) + "_concat_fuse"; "transpose_flatten" + std::to_string(times) + "_concat_fuse";
...@@ -37,6 +80,10 @@ void RunTransposeFlattenConcatFuse(ir::Graph *graph, int times) { ...@@ -37,6 +80,10 @@ void RunTransposeFlattenConcatFuse(ir::Graph *graph, int times) {
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) { Graph *g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
const int kNumFields = 5; const int kNumFields = 5;
const int kTransOffset = 1; const int kTransOffset = 1;
const int kTransOutOffset = 2; const int kTransOutOffset = 2;
......
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
#include <memory> #include <memory>
#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -28,10 +27,14 @@ namespace ir { ...@@ -28,10 +27,14 @@ namespace ir {
// structure. // structure.
class TransposeFlattenConcatFusePass : public FusePassBase { class TransposeFlattenConcatFusePass : public FusePassBase {
public: public:
TransposeFlattenConcatFusePass();
virtual ~TransposeFlattenConcatFusePass() {} virtual ~TransposeFlattenConcatFusePass() {}
protected: protected:
void ApplyImpl(ir::Graph* graph) const override; void ApplyImpl(ir::Graph* graph) const override;
private:
void RunTransposeFlattenConcatFuse(ir::Graph* graph, int times) const;
}; };
} // namespace ir } // namespace ir
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册