未验证 提交 a9f81534 编写于 作者: H heliqi 提交者: GitHub

add transpose_flatten_concat_fuse_pass test case (#37675)

* add transpose_flatten_concat pass

* modify skip func to ignore_pass_case func

* delete input_shape limit

* modify get node order
上级 59bd1e6f
......@@ -84,13 +84,16 @@ void TransposeFlattenConcatFusePass::RunTransposeFlattenConcatFuse(
LOG(WARNING) << "Pass in op compat failed.";
return;
}
const int kNumFields = 5;
const int kTransOffset = 1;
const int kTransOutOffset = 2;
const int kFlattenOffset = 3;
const int kFlattenOutOffset = 4;
std::vector<Node *> nodes;
std::vector<Node *> nodes;
std::vector<int> trans_axis0;
int flatten_axis0;
for (int i = 0; i < times; i++) {
PADDLE_ENFORCE_NOT_NULL(
subgraph.at(pattern.GetPDNode("transpose" + std::to_string(i))),
......@@ -112,6 +115,33 @@ void TransposeFlattenConcatFusePass::RunTransposeFlattenConcatFuse(
platform::errors::NotFound("Can not find %s in subgraph.",
input_nodes[i]->name()));
if (i == 0) {
trans_axis0 = BOOST_GET_CONST(
std::vector<int>,
subgraph.at(pattern.GetPDNode("transpose" + std::to_string(0)))
->Op()
->GetAttr("axis"));
flatten_axis0 = BOOST_GET_CONST(
int, subgraph.at(pattern.GetPDNode("flatten" + std::to_string(0)))
->Op()
->GetAttr("axis"));
} else {
std::vector<int> trans_axis = BOOST_GET_CONST(
std::vector<int>,
subgraph.at(pattern.GetPDNode("transpose" + std::to_string(i)))
->Op()
->GetAttr("axis"));
// All axis of transpose should be the same
if (trans_axis0 != trans_axis) return;
int flatten_axis = BOOST_GET_CONST(
int, subgraph.at(pattern.GetPDNode("flatten" + std::to_string(0)))
->Op()
->GetAttr("axis"));
// All axis of flatten should be the same
if (flatten_axis0 != flatten_axis) return;
}
nodes.push_back(subgraph.at(input_nodes[i]));
nodes.push_back(
subgraph.at(pattern.GetPDNode("transpose" + std::to_string(i))));
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,72 +12,147 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import PassAutoScanTest, IgnoreReasons
from program_config import TensorConfig, ProgramConfig, OpConfig
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
import unittest
import numpy as np
from inference_pass_test import InferencePassTest
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.core import PassVersionChecker
import hypothesis
from hypothesis import given, settings, seed, example, assume, reproduce_failure
import hypothesis.strategies as st
class TransposeFlattenConcatFusePassTest(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
data1 = fluid.data(name="data1", shape=[5, 5, 5], dtype="float32")
data2 = fluid.data(name="data2", shape=[5, 5, 5], dtype="float32")
trans1 = fluid.layers.transpose(data1, perm=[2, 1, 0])
trans2 = fluid.layers.transpose(data2, perm=[2, 1, 0])
flatt1 = fluid.layers.flatten(trans1)
flatt2 = fluid.layers.flatten(trans2)
concat_out = fluid.layers.concat([flatt1, flatt2])
# There is no parameters for above structure.
# Hence, append a batch_norm to avoid failure caused by load_combined.
out = fluid.layers.batch_norm(concat_out, is_test=True)
self.feeds = {
"data1": np.random.random([5, 5, 5]).astype("float32"),
"data2": np.random.random([5, 5, 5]).astype("float32")
}
self.fetch_list = [out]
class TestTransposeFlattenConcatFusePass(PassAutoScanTest):
"""
x_1_var x_2_var
| |
transpose2 transpose2
| |
flatten2 flatten2
\ /
flatten2_out_var flatten2_out_var
\ /
concat
"""
def test_check_output(self):
# There is no cpu pass for transpose_flatten_concat_fuse
if core.is_compiled_with_cuda():
use_gpu = True
self.check_output_with_option(use_gpu)
def sample_predictor_configs(self, program_config):
# TRT
# after tensorrt_subgraph_pass ,The pass needs to be deleted on TRT
PassVersionChecker.IsCompatible('transpose_flatten_concat_fuse_pass')
# for gpu
config = self.create_inference_config(use_gpu=True)
yield config, ["fusion_transpose_flatten_concat", ], (1e-5, 1e-5)
def is_program_valid(self, prog_config):
concat_axis = prog_config.ops[-1].attrs["axis"]
ops_num = len(prog_config.ops) - 1
if ops_num % 2 != 0:
return False
input_num = ops_num // 2
flatten_shape = 0
x_trans_axis = prog_config.ops[0].attrs["axis"]
x_flatten_axis = prog_config.ops[1].attrs["axis"]
for i in range(input_num):
input_name = "transpose2_x" + str(i)
input_shape = prog_config.inputs[input_name].shape
trans_axis = prog_config.ops[i * 2].attrs["axis"]
if x_trans_axis != trans_axis:
return False
# calculate shape after transpose
input_shape = [input_shape[j] for j in trans_axis]
# calculate shape after flateen
flatten_axis = prog_config.ops[i * 2 + 1].attrs["axis"]
if x_flatten_axis != flatten_axis:
return False
flatten_shape1 = flatten_shape2 = 1
for j in range(len(input_shape)):
if j < flatten_axis:
flatten_shape1 *= input_shape[j]
else:
flatten_shape2 *= input_shape[j]
if concat_axis == 0:
if i == 0:
flatten_shape = flatten_shape2
elif flatten_shape != flatten_shape2:
return False
else:
if i == 0:
flatten_shape = flatten_shape1
elif flatten_shape != flatten_shape1:
return False
return True
class TransposeFlattenConcatFusePassWithAxisTest(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
data1 = fluid.data(name="data1", shape=[5, 5, 5], dtype="float32")
data2 = fluid.data(name="data2", shape=[5, 5, 5], dtype="float32")
trans1 = fluid.layers.transpose(data1, perm=[2, 1, 0])
trans2 = fluid.layers.transpose(data2, perm=[2, 1, 0])
flatt1 = fluid.layers.flatten(trans1, axis=2)
flatt2 = fluid.layers.flatten(trans2, axis=2)
concat_out = fluid.layers.concat([flatt1, flatt2], axis=1)
# There is no parameters for above structure.
# Hence, append a batch_norm to avoid failure caused by load_combined.
out = fluid.layers.batch_norm(concat_out, is_test=True)
def sample_program_config(self, draw):
times = draw(st.integers(min_value=1, max_value=6))
concat_axis = draw(st.integers(min_value=0, max_value=1))
ops = []
concat_input = []
inputs = {}
x_shape_rank = draw(st.integers(min_value=2, max_value=5))
# Generate axis of transpose
trans_axis = [j for j in range(x_shape_rank)]
for j in range(x_shape_rank - 1):
if draw(st.booleans()):
trans_axis[j], trans_axis[-1] = trans_axis[-1], trans_axis[j]
# Generate axis of flatten
flatten_axis = draw(
st.integers(
min_value=0, max_value=x_shape_rank - 1))
for i in range(times):
# Generate x_shape of transpose
x_shape = draw(
st.lists(
st.integers(
min_value=1, max_value=10),
min_size=x_shape_rank,
max_size=x_shape_rank))
self.feeds = {
"data1": np.random.random([5, 5, 5]).astype("float32"),
"data2": np.random.random([5, 5, 5]).astype("float32")
}
self.fetch_list = [out]
str_i = str(i)
transpose_op = OpConfig(
"transpose2",
inputs={"X": ["transpose2_x" + str_i], },
axis=trans_axis,
outputs={
"Out": ["trans_out" + str_i],
"XShape": ["trans_shape" + str_i]
}, )
ops.append(transpose_op)
flatten_op = OpConfig(
"flatten2",
inputs={"X": ["trans_out" + str_i], },
axis=flatten_axis,
outputs={
"Out": ["flatten2_out" + str_i],
"XShape": ["xshape" + str_i]
}, )
concat_input.append("flatten2_out" + str_i)
ops.append(flatten_op)
inputs["transpose2_x" + str_i] = TensorConfig(shape=x_shape)
def test_check_output(self):
# There is no cpu pass for transpose_flatten_concat_fuse
if core.is_compiled_with_cuda():
use_gpu = True
self.check_output_with_option(use_gpu)
concat_op = OpConfig(
"concat",
inputs={
"X": concat_input,
"AxisTensor": [],
},
outputs={"Out": ["concat_out"]},
axis=concat_axis, )
self.assertTrue(
PassVersionChecker.IsCompatible(
'transpose_flatten_concat_fuse_pass'))
ops.append(concat_op)
program_config = ProgramConfig(
ops=ops,
weights={},
inputs=inputs,
outputs=ops[-1].outputs["Out"], )
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=300,
passes=["transpose_flatten_concat_fuse_pass"])
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册