Unverified commit 24418479, authored by wuhuanzhou, committed by GitHub

Verify the correctness of the graph rewritten by GeneratePass (#36116)

See the detailed PR description at https://github.com/PaddlePaddle/Paddle/pull/36116
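For orientation, the end-to-end flow this commit verifies can be sketched as follows. This is a condensed, illustrative sketch assembled from the tests added below; the pass body is multi_add_to_sum_v1 from this diff, while the shapes and the driver code are illustrative:

import numpy as np
import paddle
from paddle.fluid import core, ir

paddle.enable_static()

# Register a generated pass: two chained adds are rewritten into one add_n.
@ir.RegisterPass
def multi_add_to_sum_v1():
    pattern = lambda x, y, z: paddle.add(paddle.add(x, y), z)
    replace = lambda x, y, z: paddle.add_n([x, y, z])
    return pattern, replace

# Build a program containing the pattern, apply the pass on its graph,
# and check that the rewritten program computes the same result.
program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(program, startup_program):
    x = paddle.static.data("x", [10, 10], "float32")
    y = paddle.static.data("y", [10, 10], "float32")
    z = paddle.static.data("z", [10, 10], "float32")
    out = paddle.add(paddle.add(x, y), z)

graph = core.Graph(program.desc)
core.get_pass("multi_add_to_sum_v1").apply(graph)
after_program = paddle.fluid.framework.IrGraph(graph).to_program()

executor = paddle.static.Executor(paddle.CPUPlace())
executor.run(startup_program)
feed = {k: np.random.random([10, 10]).astype("float32") for k in ("x", "y", "z")}
before_out = executor.run(program, feed=feed, fetch_list=[out.name])
after_out = executor.run(after_program, feed=feed, fetch_list=[out.name])
assert np.allclose(before_out, after_out)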
Parent 9a9953d9
@@ -21,6 +21,16 @@ namespace ir {
void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) {
  const proto::BlockDesc& block = pass_desc.pattern().blocks(0);
  for (const proto::VarDesc& var : block.vars()) {
    PDNode* var_pdnode = pattern->NewNode(var.name())->AsInput();
    var_pdnode->assert_is_var();
    // Only match variables whose shape equals the shape recorded in the
    // pattern description.
    var_pdnode->assert_more([&](Node* x) {
      if (VarDesc(var).GetShape() == x->Var()->GetShape()) {
        return true;
      }
      return false;
    });
  }
  // Traverse all operators to create subgraph.
  for (int index = 0; index < block.ops_size(); ++index) {
    const proto::OpDesc& op = block.ops(index);
@@ -31,15 +41,32 @@ void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) {
    pattern->NewNode(std::to_string(index))->assert_is_op(op.type());
    // Create PDNodes for inputs of current operator.
    for (const proto::OpDesc::Var& var : op.inputs()) {
      for (int n = 0; n < var.arguments_size(); ++n) {
        const std::string& argument = var.arguments(n);
        // The input may be the output of another operator.
        PDNode* var_pdnode = pattern->RetrieveNode(argument);
        if (nullptr == var_pdnode) {
          var_pdnode = pattern->NewNode(argument)->AsInput();
          var_pdnode->assert_is_var();
        } else if (var_pdnode->IsOutput()) {
          var_pdnode->AsIntermediate();
        }
        // Match only if some consumer of this variable has the expected op
        // type and receives the variable through the expected parameter.
        var_pdnode->assert_more([&](Node* x) {
          for (auto* out : x->outputs) {
            if (out->IsOp() && out->Op()->Type() == op.type()) {
              const auto& inputs = out->Op()->Inputs();
              const auto& iter = inputs.find(var.parameter());
              if (inputs.end() != iter) {
                if (iter->second.end() != std::find(iter->second.begin(),
                                                    iter->second.end(),
                                                    x->Name())) {
                  return true;
                }
              }
            }
          }
          return false;
        });
        pattern->AddEdge(var_pdnode, op_pdnode);
      }
    }
@@ -50,6 +77,24 @@ void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) {
      PDNode* var_pdnode = pattern->RetrieveNode(argument);
      if (nullptr == var_pdnode) {
        var_pdnode = pattern->NewNode(argument)->AsOutput();
        var_pdnode->assert_is_var();
        // Match only if some producer of this variable has the expected op
        // type and emits the variable through the expected parameter.
        var_pdnode->assert_more([&](Node* x) {
          for (Node* input : x->inputs) {
            if (input && input->IsOp() && input->Op() &&
                input->Op()->Type() == op.type()) {
              const auto& outputs = input->Op()->Outputs();
              const auto& iter = outputs.find(var.parameter());
              if (outputs.end() != iter) {
                if (iter->second.end() != std::find(iter->second.begin(),
                                                    iter->second.end(),
                                                    x->Name())) {
                  return true;
                }
              }
            }
          }
          return false;
        });
      } else if (var_pdnode->IsInput()) {
        var_pdnode->AsIntermediate();
      }
@@ -73,17 +118,63 @@ void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) {
  }
}

// There are some duplicate patterns.
bool IsDuplicatePattern(const GraphPatternDetector::subgraph_t& subgraph,
                        Graph* graph) {
  for (auto iter : subgraph) {
    if (nullptr == graph->RetrieveNode(iter.second->id())) {
      VLOG(3) << "Node [" << iter.second->Name()
              << "] of subgraph has been removed. So skip this optimize.";
      return true;
    }
  }
  return false;
}

GraphPatternDetector::handle_t GetGenerateDelete(
    const PDPattern& pattern, const proto::PassDesc& pass_desc) {
  GraphPatternDetector::handle_t handler = [&](
      const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
    if (IsDuplicatePattern(subgraph, graph)) {
      return;
    }
    // `var_node_maps` record the mapping of variable to the pattern subgraph.
    std::map<std::string, Node*> var_node_maps;
    for (const proto::PassDesc::VarMap& var_map : pass_desc.var_maps()) {
      Node* node = subgraph.at(pattern.RetrieveNode(var_map.pattern_var()));
      const auto& iter = var_node_maps.find(var_map.replace_var());
      if (var_node_maps.end() == iter) {
        // The first mapped node is the input of the deleted subgraph.
        var_node_maps.insert({var_map.replace_var(), node});
      } else {
        // Reconnect consumers of the output node to the input node.
        for (Node* s_node : node->outputs) {
          iter->second->outputs.push_back(s_node);
          std::replace(s_node->inputs.begin(), s_node->inputs.end(), node,
                       iter->second);
          s_node->Op()->RenameInput(node->Name(), iter->second->Name());
        }
      }
    }
    // Remove nodes that are intermediate.
    std::unordered_set<const Node*> remove_nodes;
    for (const std::unique_ptr<PDNode>& pdnode : pattern.nodes()) {
      remove_nodes.emplace(subgraph.at(pdnode.get()));
    }
    for (auto iter : var_node_maps) {
      remove_nodes.erase(iter.second);
    }
    GraphSafeRemoveNodes(graph, remove_nodes);
  };
  return handler;
}

GraphPatternDetector::handle_t GetGenerateRewrite(
    const PDPattern& pattern, const proto::PassDesc& pass_desc) {
  GraphPatternDetector::handle_t handler = [&](
      const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
    if (IsDuplicatePattern(subgraph, graph)) {
      return;
    }
    const proto::BlockDesc& block = pass_desc.replace().blocks(0);
    // `var_node_maps` record the mapping of variable to the pattern subgraph.
@@ -175,7 +266,11 @@ void GeneratePass::ApplyImpl(Graph* graph) const {
  for (const proto::PassDesc& pass_desc : multi_pass_desc_.pass_descs()) {
    GraphPatternDetector detector;
    InitGeneratePattern(pass_desc, detector.mutable_pattern());
    // A replace block without operators means pure deletion.
    if (pass_desc.replace().blocks(0).ops_size() == 0) {
      detector(graph, GetGenerateDelete(detector.pattern(), pass_desc));
    } else {
      detector(graph, GetGenerateRewrite(detector.pattern(), pass_desc));
    }
    // The rewritten graph needs to be verified. The current pass should be
    // skipped if validation fails, because a rewrite performed on the
    // original graph cannot be rolled back.
......
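A note on the new dispatch above: when the replace block serializes with zero operators (a pass whose replace function is the identity), ApplyImpl routes matching to GetGenerateDelete, which reconnects consumers of the matched subgraph's output to its input and removes the matched nodes. A minimal sketch of such a delete-style pass, mirroring generate_simplify_inference_v1 from the tests below; the pass name here is hypothetical:

import paddle
from paddle.fluid import ir

@ir.RegisterPass
def remove_redundant_transposes():  # hypothetical name
    def pattern(x):
        # Two transposes that undo each other.
        t = paddle.transpose(x, [0, 2, 1])
        return paddle.transpose(t, [0, 2, 1])

    # The identity replace function produces a replace block with no ops,
    # so this pass takes the GetGenerateDelete path.
    return pattern, lambda x: x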
@@ -127,11 +127,13 @@ def apply_build_strategy(main_program, startup_program, build_strategy,
class RegisterPassHelper(object):
    _register_helpers = list()

    def __init__(self, pass_pairs, pass_type=str(), input_specs=dict()):
        self._pass_type = pass_type
        self._pass_pairs = pass_pairs
        if isinstance(input_specs, dict):
            self._input_specs = input_specs
        RegisterPassHelper._register_helpers.append(self)

    def _get_args_from_func(self, func):
        args = list()

@@ -148,6 +150,35 @@ class RegisterPassHelper(object):
                args.append(paddle.static.data(arg_name, [-1]))
        return args
    def _prune_program_desc(self, program_desc):
        block_desc = program_desc.blocks[0]
        # block_desc.ClearField("vars")
        for var in [
                var for var in block_desc.vars
                if var.name not in self._input_specs
        ]:
            block_desc.vars.remove(var)
        for op_desc in block_desc.ops:
            default_attrs = core.get_op_attrs_default_value(
                paddle.compat.to_bytes(op_desc.type))
            remove_attrs = list()
            for attr in op_desc.attrs:
                # Framework-injected attributes are always pruned.
                if attr.name not in [
                        "op_namescope", "op_callstack", "op_device"
                ]:
                    attr_list_fields = attr.ListFields()
                    # A fully set attr has three fields: name, type, value.
                    if len(attr_list_fields) == 3:
                        attr_value = attr.ListFields()[-1][-1]
                        default_attr_value = default_attrs.get(attr.name)
                        # Keep attrs whose value differs from the default.
                        if default_attr_value != attr_value:
                            continue
                remove_attrs.append(attr)
            for attr in remove_attrs:
                op_desc.attrs.remove(attr)
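The pruning above keeps the serialized pattern minimal: framework-injected debug attributes are always dropped, and an attribute whose value equals the operator's registered default is dropped too, so defaults do not over-constrain matching. A small probe of the default table the method consults; the op name and expected value are assumptions for illustration:

import paddle
from paddle.fluid import core

# The same lookup _prune_program_desc performs. Attributes whose serialized
# value equals the value recorded here are removed from the pattern/replace
# description. (Assumed example: elementwise_add's axis defaults to -1.)
defaults = core.get_op_attrs_default_value(
    paddle.compat.to_bytes("elementwise_add"))
print(defaults.get("axis"))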
    def _func_to_program_desc(self, func, program_desc, is_replace=False):
        vars = list()
        program = paddle.static.Program()

@@ -166,6 +197,7 @@ class RegisterPassHelper(object):
            elif isinstance(out, paddle.fluid.framework.Variable):
                vars.append(out.name)
        program_desc.ParseFromString(program.desc.serialize_to_string())
        self._prune_program_desc(program_desc)
        if is_replace:
            attrs = list()
            for op in program.current_block().ops:
@@ -296,7 +328,7 @@ class PassDesc(object):
    OP = OpHelper()

def RegisterPass(function=None, input_specs=dict()):
    """
    The function decorator of Register Pass. Decorator @RegisterPass handles
    the function and registers it into a core.Pass instance. Use name of function

@@ -305,11 +337,11 @@ def RegisterPass(function=None, input_specs=None):
    Args:
        function (callable): The function with return of callable pair(s) that
            represents the pattern subgraph and the replace subgraph.
        input_specs (dict[str, InputSpec]): Dict of InputSpec that specifies the
            shape/dtype information of Tensors. Some operators limit the shape
            and dtype of data when creating a subgraph with Paddle APIs, so users
            need to specify the InputSpec of data to ensure the subgraph is
            created correctly. This argument is not limited to matching the
            subgraph. The default is dict().

    Returns:
        callables: Callable pair(s).
@@ -351,6 +383,7 @@ def RegisterPass(function=None, input_specs=None):
                    "Return value of Pass function must be (callable, callable)."
                )
        helper = RegisterPassHelper(pass_pairs, pass_type, input_specs)
        # Registering at decoration time makes the pass retrievable via
        # core.get_pass() without an explicit registration step.
        core.register_pass(pass_type, helper.SerializeMultiPassDesc)
        return python_func
    if inspect.isfunction(function):
......
@@ -15,7 +15,7 @@
import unittest
import paddle
from paddle.static import InputSpec
from paddle.fluid import core, ir
import numpy as np

@@ -45,23 +45,37 @@ def generate_fc_fuse():
    return list(map(create_pass_pair, [True, False]))
# add(X=add(X=x, Y=y), Y=z) => sum(X=[x, y, z])
@ir.RegisterPass
def multi_add_to_sum_v1():
    pattern = lambda x, y, z: paddle.add(paddle.add(x, y), z)
    replace = lambda x, y, z: paddle.add_n([x, y, z])
    return pattern, replace

@ir.RegisterPass
def multi_add_to_sum_v2():
    def pattern(x, y, z):
        ewadd1 = ir.PassDesc.OP.elementwise_add(X=x, Y=y)
        ewadd2 = ir.PassDesc.OP.elementwise_add(X=ewadd1, Y=z)
        return ewadd2

    replace = lambda x, y, z: ir.PassDesc.OP.sum(X=[x, y, z])
    return pattern, replace

@ir.RegisterPass
def multi_add_to_sum_v3():
    pattern = lambda x, y, z: paddle.add(paddle.add(x, y), z)
    replace = lambda x, y, z: ir.PassDesc.OP.sum(X=[x, y, z])
    return pattern, replace
# mul(x, y1), mul(x, y2) => slice(mul(x, concat(y1, y2)))
@ir.RegisterPass(input_specs={
    'x': InputSpec([16, 32]),
    'y1': InputSpec([32, 12]),
    'y2': InputSpec([32, 48])
})
def generate_combine_mul_v1():
    def pattern(x, y1, y2):

@@ -72,8 +86,8 @@ def generate_combine_mul_v1():
    def replace(x, y1, y2):
        concat_out = paddle.concat([y1, y2], axis=-1)
        mul_out = paddle.matmul(x, concat_out)
        out1 = paddle.slice(mul_out, axes=[1], starts=[0], ends=[12])
        out2 = paddle.slice(mul_out, axes=[1], starts=[12], ends=[60])
        return out1, out2

    return pattern, replace
@@ -97,11 +111,22 @@ def generate_combine_mul_v2():

# transpose(transpose(x)) => x
@ir.RegisterPass(input_specs={'x': InputSpec([10, 16, 16])})
def generate_simplify_inference_v1():
    def pattern(x):
        transpose = paddle.transpose(x, [0, 2, 1])
        return paddle.transpose(transpose, [0, 2, 1])

    return pattern, lambda x: x

@ir.RegisterPass
def generate_simplify_inference_v2():
    def pattern(x):
        op1 = ir.PassDesc.OP.transpose2
        op2 = ir.PassDesc.OP.transpose2
        # op2.Attr("axis").EQ(op1.Attr("axis"))
        return op2(X=op1(X=x))

    return pattern, lambda x: x
@@ -153,46 +178,73 @@ class TestGeneratePass(unittest.TestCase):
        _check_fc_fuse_pass(multi_pass_desc.pass_descs[0], True)
        _check_fc_fuse_pass(multi_pass_desc.pass_descs[1], False)

    def check_multi_add_to_sum(self, pass_type):
        program = paddle.static.Program()
        startup_program = paddle.static.Program()
        with paddle.static.program_guard(program, startup_program):
            x = paddle.static.data("x", [10, 10, 10], "float32")
            y = paddle.static.data("y", [10, 10, 10], "float32")
            z = paddle.static.data("z", [10, 10, 10], "float32")
            add_1 = paddle.add(paddle.add(x, y), z)
            matmul_1 = paddle.matmul(add_1, z)
            add_tmp = paddle.add(x, y)
            add_2 = paddle.add(add_tmp, z)
            matmul_2 = paddle.matmul(add_2, add_tmp)
            out = paddle.add(matmul_1, matmul_2)
        graph = core.Graph(program.desc)
        before_node_nums = len(graph.nodes())
        core.get_pass(pass_type).apply(graph)
        after_node_nums = len(graph.nodes())
        self.assertEqual(after_node_nums, before_node_nums - 2)
        after_program = paddle.fluid.framework.IrGraph(graph).to_program()
        executor = paddle.static.Executor(paddle.CPUPlace())
        executor.run(startup_program)
        feed = {
            "x": np.random.random([10, 10, 10]).astype("float32"),
            "y": np.random.random([10, 10, 10]).astype("float32"),
            "z": np.random.random([10, 10, 10]).astype("float32")
        }
        before_out = executor.run(program, feed=feed, fetch_list=[out.name])
        after_out = executor.run(after_program,
                                 feed=feed,
                                 fetch_list=[out.name])
        self.assertTrue(np.allclose(before_out, after_out))

    def test_multi_add_to_sum(self):
        paddle.enable_static()
        self.check_multi_add_to_sum("multi_add_to_sum_v1")
        self.check_multi_add_to_sum("multi_add_to_sum_v2")
        self.check_multi_add_to_sum("multi_add_to_sum_v3")
    def test_generate_combine_mul_v1(self):
        paddle.enable_static()
        program = paddle.static.Program()
        startup_program = paddle.static.Program()
        with paddle.static.program_guard(program, startup_program):
            x = paddle.static.data("x", [16, 32])
            y = paddle.static.data("y", [32, 12])
            z = paddle.static.data("z", [32, 48])
            out1 = paddle.matmul(x, y)
            out2 = paddle.matmul(x, z)
        graph = core.Graph(program.desc)
        before_node_nums = len(graph.nodes())
        core.get_pass("generate_combine_mul_v1").apply(graph)
        after_node_nums = len(graph.nodes())
        self.assertEqual(after_node_nums, before_node_nums + 4)
        after_program = paddle.fluid.framework.IrGraph(graph).to_program()
        executor = paddle.static.Executor(paddle.CPUPlace())
        executor.run(startup_program)
        feed = {
            "x": np.random.random([16, 32]).astype("float32"),
            "y": np.random.random([32, 12]).astype("float32"),
            "z": np.random.random([32, 48]).astype("float32")
        }
        before_out1, before_out2 = executor.run(
            program, feed=feed, fetch_list=[out1.name, out2.name])
        after_out1, after_out2 = executor.run(
            after_program, feed=feed, fetch_list=[out1.name, out2.name])
        self.assertTrue(np.allclose(before_out1, after_out1))
        self.assertTrue(np.allclose(before_out2, after_out2))
    def test_generate_combine_mul_v2(self):
        helper = ir.RegisterPassHelper([generate_combine_mul_v2()])

@@ -212,17 +264,31 @@ class TestGeneratePass(unittest.TestCase):
        self.assertEqual(len(replace_op_dicts.get("matmul_v2", [])), 1)
        self.assertEqual(len(replace_op_dicts.get("slice", [])), 2)
    def check_generate_simplify_inference(self, pass_type):
        paddle.enable_static()
        program = paddle.static.Program()
        startup_program = paddle.static.Program()
        with paddle.static.program_guard(program, startup_program):
            x = paddle.static.data("x", [10, 16, 16], "float32")
            x1 = paddle.transpose(paddle.transpose(x, [0, 2, 1]), [0, 2, 1])
            tmp = paddle.transpose(x, [0, 2, 1])
            x2 = paddle.transpose(tmp, [0, 2, 1])
            out = paddle.add(x1, paddle.matmul(x2, tmp))
        graph = core.Graph(program.desc)
        before_node_nums = len(graph.nodes())
        core.get_pass(pass_type).apply(graph)
        after_node_nums = len(graph.nodes())
        self.assertEqual(after_node_nums, before_node_nums - 6)
        after_program = paddle.fluid.framework.IrGraph(graph).to_program()
        executor = paddle.static.Executor(paddle.CPUPlace())
        executor.run(startup_program)
        feed = {"x": np.random.random([10, 16, 16]).astype("float32")}
        before_out = executor.run(program, feed=feed, fetch_list=[out.name])
        after_out = executor.run(after_program,
                                 feed=feed,
                                 fetch_list=[out.name])
        self.assertTrue(np.allclose(before_out, after_out))

    def test_generate_simplify_inference(self):
        self.check_generate_simplify_inference("generate_simplify_inference_v1")
        self.check_generate_simplify_inference("generate_simplify_inference_v2")