未验证 提交 54af52b0 编写于 作者: F feng_shuai 提交者: GitHub

Shuffle channel detect pass (#33814)

上级 cc5d4b1a
......@@ -30,6 +30,44 @@ namespace ir {
GET_IR_NODE(reshape2_op); \
GET_IR_NODE(reshape2_out);
ShuffleChannelDetectPass::ShuffleChannelDetectPass() {
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Shape")
.IsOptional()
.IsTensor()
.End()
.AddInput("ShapeTensor")
.IsOptional()
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("shape")
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsType<std::vector<int>>()
.End();
}
void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "shufflechannel_pattern";
FusePassBase::Init(pattern_name, graph);
......@@ -46,7 +84,10 @@ void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const {
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_NODES;
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "The Pass in op compat failed.";
return;
}
PADDLE_ENFORCE_GT(
subgraph.count(x), 0,
platform::errors::NotFound("Detector did not find input X."));
......
......@@ -26,6 +26,7 @@ class Graph;
class ShuffleChannelDetectPass : public FusePassBase {
public:
ShuffleChannelDetectPass();
virtual ~ShuffleChannelDetectPass() {}
protected:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册