未验证 提交 cc449652 编写于 作者: W wuhuanzhou 提交者: GitHub

[cherry-pick]Verify the correctness of graph rewrited by GeneratePass (#36453)

* [WIP]Verify the correctness of graph rewrited by GeneratePass, test=develop

* add delete subgraph and unittest, test=develop

* check simple pass, test=develop

* fix coverage, test=develop

* limit with input_spec via Paddle API, test=develop
上级 fc429fea
......@@ -21,6 +21,16 @@ namespace ir {
void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) {
const proto::BlockDesc& block = pass_desc.pattern().blocks(0);
for (const proto::VarDesc& var : block.vars()) {
PDNode* var_pdnode = pattern->NewNode(var.name())->AsInput();
var_pdnode->assert_is_var();
var_pdnode->assert_more([&](Node* x) {
if (VarDesc(var).GetShape() == x->Var()->GetShape()) {
return true;
}
return false;
});
}
// Traverse all operators to create subgraph.
for (int index = 0; index < block.ops_size(); ++index) {
const proto::OpDesc& op = block.ops(index);
......@@ -31,15 +41,32 @@ void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) {
pattern->NewNode(std::to_string(index))->assert_is_op(op.type());
// Create PDNodes for inputs of current operator.
for (const proto::OpDesc::Var& var : op.inputs()) {
for (const std::string& argument : var.arguments()) {
for (int n = 0; n < var.arguments_size(); ++n) {
const std::string& argument = var.arguments(n);
// The input may be the output of other operator.
PDNode* var_pdnode = pattern->RetrieveNode(argument);
if (nullptr == var_pdnode) {
var_pdnode = pattern->NewNode(argument)->AsInput();
var_pdnode->assert_is_var();
} else if (var_pdnode->IsOutput()) {
var_pdnode->AsIntermediate();
}
var_pdnode->assert_is_op_input(op.type());
var_pdnode->assert_more([&](Node* x) {
for (auto* out : x->outputs) {
if (out->IsOp() && out->Op()->Type() == op.type()) {
const auto& inputs = out->Op()->Inputs();
const auto& iter = inputs.find(var.parameter());
if (inputs.end() != iter) {
if (iter->second.end() != std::find(iter->second.begin(),
iter->second.end(),
x->Name())) {
return true;
}
}
}
}
return false;
});
pattern->AddEdge(var_pdnode, op_pdnode);
}
}
......@@ -50,6 +77,24 @@ void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) {
PDNode* var_pdnode = pattern->RetrieveNode(argument);
if (nullptr == var_pdnode) {
var_pdnode = pattern->NewNode(argument)->AsOutput();
var_pdnode->assert_is_var();
var_pdnode->assert_more([&](Node* x) {
for (Node* input : x->inputs) {
if (input && input->IsOp() && input->Op() &&
input->Op()->Type() == op.type()) {
const auto& outputs = input->Op()->Outputs();
const auto& iter = outputs.find(var.parameter());
if (outputs.end() != iter) {
if (iter->second.end() != std::find(iter->second.begin(),
iter->second.end(),
x->Name())) {
return true;
}
}
}
}
return false;
});
} else if (var_pdnode->IsInput()) {
var_pdnode->AsIntermediate();
}
......@@ -73,18 +118,64 @@ void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) {
}
}
GraphPatternDetector::handle_t GetGenerateRewrite(
// There are some duplicate patterns.
bool IsDuplicatePattern(const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
for (auto iter : subgraph) {
if (nullptr == graph->RetrieveNode(iter.second->id())) {
VLOG(3) << "Node [" << iter.second->Name()
<< "] of subgraph has been removed. So skip this optimize.";
return true;
}
}
return false;
}
GraphPatternDetector::handle_t GetGenerateDelete(
const PDPattern& pattern, const proto::PassDesc& pass_desc) {
GraphPatternDetector::handle_t handler = [&](
const GraphPatternDetector::subgraph_t subgraph, Graph* graph) {
// There are some duplicate patterns.
for (auto iter : subgraph) {
if (nullptr == graph->RetrieveNode(iter.second->id())) {
VLOG(3) << "Node [" << iter.second->Name()
<< "] of subgraph has been removed. So skip this optimize.";
return;
const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
if (IsDuplicatePattern(subgraph, graph)) {
return;
}
// `var_node_maps` record the mapping of variable to the pattern subgraph.
std::map<std::string, Node*> var_node_maps;
for (const proto::PassDesc::VarMap& var_map : pass_desc.var_maps()) {
Node* node = subgraph.at(pattern.RetrieveNode(var_map.pattern_var()));
const auto& iter = var_node_maps.find(var_map.replace_var());
if (var_node_maps.end() == iter) {
// first node is input
var_node_maps.insert({var_map.replace_var(), node});
} else {
// output node
for (Node* s_node : node->outputs) {
iter->second->outputs.push_back(s_node);
std::replace(s_node->inputs.begin(), s_node->inputs.end(), node,
iter->second);
s_node->Op()->RenameInput(node->Name(), iter->second->Name());
}
}
}
// Remove nodes that are intermediate.
std::unordered_set<const Node*> remove_nodes;
for (const std::unique_ptr<PDNode>& pdnode : pattern.nodes()) {
remove_nodes.emplace(subgraph.at(pdnode.get()));
}
for (auto iter : var_node_maps) {
remove_nodes.erase(iter.second);
}
GraphSafeRemoveNodes(graph, remove_nodes);
};
return handler;
}
GraphPatternDetector::handle_t GetGenerateRewrite(
const PDPattern& pattern, const proto::PassDesc& pass_desc) {
GraphPatternDetector::handle_t handler = [&](
const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
if (IsDuplicatePattern(subgraph, graph)) {
return;
}
const proto::BlockDesc& block = pass_desc.replace().blocks(0);
// `var_node_maps` record the mapping of variable to the pattern subgraph.
std::map<std::string, Node*> var_node_maps;
......@@ -175,7 +266,11 @@ void GeneratePass::ApplyImpl(Graph* graph) const {
for (const proto::PassDesc& pass_desc : multi_pass_desc_.pass_descs()) {
GraphPatternDetector detector;
InitGeneratePattern(pass_desc, detector.mutable_pattern());
detector(graph, GetGenerateRewrite(detector.pattern(), pass_desc));
if (pass_desc.replace().blocks(0).ops_size() == 0) {
detector(graph, GetGenerateDelete(detector.pattern(), pass_desc));
} else {
detector(graph, GetGenerateRewrite(detector.pattern(), pass_desc));
}
// The rewrited graph needs to be verified. Current Pass should be skipped
// if validation failed. Rewrite based on the original graph cannot
// implement rollback operation.
......
......@@ -127,11 +127,13 @@ def apply_build_strategy(main_program, startup_program, build_strategy,
class RegisterPassHelper(object):
_register_helpers = list()
def __init__(self, pass_pairs, pass_type=str(), input_specs=dict()):
self._pass_type = pass_type
self._pass_pairs = pass_pairs
if isinstance(input_specs, dict):
self._input_specs = input_specs
self._input_specs = input_specs
RegisterPassHelper._register_helpers.append(self)
def _get_args_from_func(self, func):
args = list()
......@@ -148,6 +150,35 @@ class RegisterPassHelper(object):
args.append(paddle.static.data(arg_name, [-1]))
return args
def _prune_program_desc(self, program_desc):
block_desc = program_desc.blocks[0]
# block_desc.ClearField("vars")
for var in [
var for var in block_desc.vars
if var.name not in self._input_specs
]:
block_desc.vars.remove(var)
for op_desc in block_desc.ops:
default_attrs = core.get_op_attrs_default_value(
paddle.compat.to_bytes(op_desc.type))
remove_attrs = list()
for attr in op_desc.attrs:
# attr must not in
if attr.name not in [
"op_namescope", "op_callstack", "op_device"
]:
attr_list_fields = attr.ListFields()
# attr format must be: name, type, value
if len(attr_list_fields) == 3:
attr_value = attr.ListFields()[-1][-1]
default_attr_value = default_attrs.get(attr.name)
# value must not default
if default_attr_value != attr_value:
continue
remove_attrs.append(attr)
for attr in remove_attrs:
op_desc.attrs.remove(attr)
def _func_to_program_desc(self, func, program_desc, is_replace=False):
vars = list()
program = paddle.static.Program()
......@@ -166,6 +197,7 @@ class RegisterPassHelper(object):
elif isinstance(out, paddle.fluid.framework.Variable):
vars.append(out.name)
program_desc.ParseFromString(program.desc.serialize_to_string())
self._prune_program_desc(program_desc)
if is_replace:
attrs = list()
for op in program.current_block().ops:
......@@ -296,7 +328,7 @@ class PassDesc(object):
OP = OpHelper()
def RegisterPass(function=None, input_specs=None):
def RegisterPass(function=None, input_specs=dict()):
"""
The function decorator of Register Pass. Decorator @RegisterPass handles
the function and register it into a core.Pass instance. Use name of function
......@@ -305,11 +337,11 @@ def RegisterPass(function=None, input_specs=None):
Args:
function (callable): The function with return of callable pair(s) that
represents the pattern subgraph and the replace subgraph.
input_specs (dict[str, InputSpec]|None): Dict of InputSpec to specific the shape/dtype
input_specs (dict[str, InputSpec]): Dict of InputSpec to specific the shape/dtype
information of Tensor. Some operators limit the shape and dtype of datas when
create subgraph with Paddle APIs. So user need specify InputSpec of data to
ensure create a correctly subgraph. Of course, this argument is not limited to
matching subgraph. The default is None.
matching subgraph. The default is dict().
Returns:
callables: Callable pair(s).
......@@ -351,6 +383,7 @@ def RegisterPass(function=None, input_specs=None):
"Return value of Pass function must be (callable, callable)."
)
helper = RegisterPassHelper(pass_pairs, pass_type, input_specs)
core.register_pass(pass_type, helper.SerializeMultiPassDesc)
return python_func
if inspect.isfunction(function):
......
......@@ -15,7 +15,7 @@
import unittest
import paddle
from paddle.static import InputSpec
from paddle.fluid import ir
from paddle.fluid import core, ir
import numpy as np
......@@ -45,23 +45,37 @@ def generate_fc_fuse():
return list(map(create_pass_pair, [True, False]))
# add(X=add(x, y), Y=z)z => add_n(X=[x, y, z])
# add(X=add(X=x, Y=y), Y=z) => sum(X=[x, y, z])
@ir.RegisterPass
def generate_add_n():
def multi_add_to_sum_v1():
pattern = lambda x, y, z: paddle.add(paddle.add(x, y), z)
replace = lambda x, y, z: paddle.add_n([x, y, z])
return pattern, replace
@ir.RegisterPass
def multi_add_to_sum_v2():
def pattern(x, y, z):
return paddle.add(paddle.add(x, y), z)
ewadd1 = ir.PassDesc.OP.elementwise_add(X=x, Y=y)
ewadd2 = ir.PassDesc.OP.elementwise_add(X=ewadd1, Y=z)
return ewadd2
replace = lambda x, y, z: ir.PassDesc.OP.sum(X=[x, y, z])
return pattern, replace
def replace(x, y, z):
return paddle.add_n([x, y, z])
@ir.RegisterPass
def multi_add_to_sum_v3():
pattern = lambda x, y, z: paddle.add(paddle.add(x, y), z)
replace = lambda x, y, z: ir.PassDesc.OP.sum(X=[x, y, z])
return pattern, replace
# mul(x, y1), mul(x, y2) => slice(mul(x, concat(y1, y2)))
@ir.RegisterPass(input_specs={
'x': InputSpec([1, 1]),
'y1': InputSpec([1, 1]),
'y2': InputSpec([1, 1])
'x': InputSpec([16, 32]),
'y1': InputSpec([32, 12]),
'y2': InputSpec([32, 48])
})
def generate_combine_mul_v1():
def pattern(x, y1, y2):
......@@ -72,8 +86,8 @@ def generate_combine_mul_v1():
def replace(x, y1, y2):
concat_out = paddle.concat([y1, y2], axis=-1)
mul_out = paddle.matmul(x, concat_out)
out1 = paddle.slice(mul_out, axes=[1], starts=[0], ends=[1])
out2 = paddle.slice(mul_out, axes=[1], starts=[1], ends=[2])
out1 = paddle.slice(mul_out, axes=[1], starts=[0], ends=[12])
out2 = paddle.slice(mul_out, axes=[1], starts=[12], ends=[60])
return out1, out2
return pattern, replace
......@@ -97,11 +111,22 @@ def generate_combine_mul_v2():
# reshape(reshape(x)) => x
@ir.RegisterPass(input_specs={'x': InputSpec([-1, 16, 16, 16])})
def generate_simplify_inference():
@ir.RegisterPass(input_specs={'x': InputSpec([10, 16, 16])})
def generate_simplify_inference_v1():
def pattern(x):
transpose = paddle.transpose(x, [0, 3, 1, 2])
return paddle.transpose(transpose, [0, 3, 1, 2])
transpose = paddle.transpose(x, [0, 2, 1])
return paddle.transpose(transpose, [0, 2, 1])
return pattern, lambda x: x
@ir.RegisterPass
def generate_simplify_inference_v2():
def pattern(x):
op1 = ir.PassDesc.OP.transpose2
op2 = ir.PassDesc.OP.transpose2
# op2.Attr("axis").EQ(op1.Attr("axis"))
return op2(X=op1(X=x))
return pattern, lambda x: x
......@@ -153,46 +178,73 @@ class TestGeneratePass(unittest.TestCase):
_check_fc_fuse_pass(multi_pass_desc.pass_descs[0], True)
_check_fc_fuse_pass(multi_pass_desc.pass_descs[1], False)
def test_generate_add_n(self):
helper = ir.RegisterPassHelper([generate_add_n()])
s = helper.SerializeMultiPassDesc()
multi_pass_desc = get_multi_pass_desc_from_str(s)
self.assertEqual(len(multi_pass_desc.pass_descs), 1)
pass_desc = multi_pass_desc.pass_descs[0]
self.assertEqual(len(pass_desc.var_maps), 4)
self.assertEqual(len(pass_desc.attr_maps), 0)
self.assertEqual(len(pass_desc.pattern.blocks[0].ops), 2)
self.assertEqual(len(pass_desc.replace.blocks[0].ops), 1)
pattern_op_dicts = self.convert_ops_to_op_dicts(
pass_desc.pattern.blocks[0].ops)
replace_op_dicts = self.convert_ops_to_op_dicts(
pass_desc.replace.blocks[0].ops)
self.assertEqual(len(pattern_op_dicts.get("elementwise_add", [])), 2)
self.assertEqual(len(replace_op_dicts.get("sum", [])), 1)
def check_multi_add_to_sum(self, pass_type):
program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(program, startup_program):
x = paddle.static.data("x", [10, 10, 10], "float32")
y = paddle.static.data("y", [10, 10, 10], "float32")
z = paddle.static.data("z", [10, 10, 10], "float32")
add_1 = paddle.add(paddle.add(x, y), z)
matmul_1 = paddle.matmul(add_1, z)
add_tmp = paddle.add(x, y)
add_2 = paddle.add(add_tmp, z)
matmul_2 = paddle.matmul(add_2, add_tmp)
out = paddle.add(matmul_1, matmul_2)
graph = core.Graph(program.desc)
before_node_nums = len(graph.nodes())
core.get_pass(pass_type).apply(graph)
after_node_nums = len(graph.nodes())
self.assertEqual(after_node_nums, before_node_nums - 2)
after_program = paddle.fluid.framework.IrGraph(graph).to_program()
executor = paddle.static.Executor(paddle.CPUPlace())
executor.run(startup_program)
feed = {
"x": np.random.random([10, 10, 10]).astype("float32"),
"y": np.random.random([10, 10, 10]).astype("float32"),
"z": np.random.random([10, 10, 10]).astype("float32")
}
before_out = executor.run(program, feed=feed, fetch_list=[out.name])
after_out = executor.run(after_program,
feed=feed,
fetch_list=[out.name])
self.assertTrue(np.allclose(before_out, after_out))
def test_multi_add_to_sum(self):
paddle.enable_static()
self.check_multi_add_to_sum("multi_add_to_sum_v1")
self.check_multi_add_to_sum("multi_add_to_sum_v2")
self.check_multi_add_to_sum("multi_add_to_sum_v3")
def test_generate_combine_mul_v1(self):
input_specs = {
'x': InputSpec([1, 1]),
'y1': InputSpec([1, 1]),
'y2': InputSpec([1, 1])
paddle.enable_static()
program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(program, startup_program):
x = paddle.static.data("x", [16, 32])
y = paddle.static.data("y", [32, 12])
z = paddle.static.data("z", [32, 48])
out1 = paddle.matmul(x, y)
out2 = paddle.matmul(x, z)
graph = core.Graph(program.desc)
before_node_nums = len(graph.nodes())
core.get_pass("generate_combine_mul_v1").apply(graph)
after_node_nums = len(graph.nodes())
self.assertEqual(after_node_nums, before_node_nums + 4)
after_program = paddle.fluid.framework.IrGraph(graph).to_program()
executor = paddle.static.Executor(paddle.CPUPlace())
executor.run(startup_program)
feed = {
"x": np.random.random([16, 32]).astype("float32"),
"y": np.random.random([32, 12]).astype("float32"),
"z": np.random.random([32, 48]).astype("float32")
}
helper = ir.RegisterPassHelper(
[generate_combine_mul_v1()], input_specs=input_specs)
s = helper.SerializeMultiPassDesc()
multi_pass_desc = get_multi_pass_desc_from_str(s)
self.assertEqual(len(multi_pass_desc.pass_descs), 1)
pass_desc = multi_pass_desc.pass_descs[0]
self.assertEqual(len(pass_desc.var_maps), 5)
self.assertEqual(len(pass_desc.pattern.blocks[0].ops), 2)
self.assertEqual(len(pass_desc.replace.blocks[0].ops), 4)
pattern_op_dicts = self.convert_ops_to_op_dicts(
pass_desc.pattern.blocks[0].ops)
replace_op_dicts = self.convert_ops_to_op_dicts(
pass_desc.replace.blocks[0].ops)
self.assertEqual(len(pattern_op_dicts.get("matmul_v2", [])), 2)
self.assertEqual(len(replace_op_dicts.get("concat", [])), 1)
self.assertEqual(len(replace_op_dicts.get("matmul_v2", [])), 1)
self.assertEqual(len(replace_op_dicts.get("slice", [])), 2)
before_out1, before_out2 = executor.run(
program, feed=feed, fetch_list=[out1.name, out2.name])
after_out1, after_out2 = executor.run(
after_program, feed=feed, fetch_list=[out1.name, out2.name])
self.assertTrue(np.allclose(before_out1, after_out1))
self.assertTrue(np.allclose(before_out2, after_out2))
def test_generate_combine_mul_v2(self):
helper = ir.RegisterPassHelper([generate_combine_mul_v2()])
......@@ -212,17 +264,31 @@ class TestGeneratePass(unittest.TestCase):
self.assertEqual(len(replace_op_dicts.get("matmul_v2", [])), 1)
self.assertEqual(len(replace_op_dicts.get("slice", [])), 2)
def check_generate_simplify_inference(self, pass_type):
paddle.enable_static()
program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(program, startup_program):
x = paddle.static.data("x", [10, 16, 16], "float32")
x1 = paddle.transpose(paddle.transpose(x, [0, 2, 1]), [0, 2, 1])
tmp = paddle.transpose(x, [0, 2, 1])
x2 = paddle.transpose(tmp, [0, 2, 1])
out = paddle.add(x1, paddle.matmul(x2, tmp))
graph = core.Graph(program.desc)
before_node_nums = len(graph.nodes())
core.get_pass(pass_type).apply(graph)
after_node_nums = len(graph.nodes())
self.assertEqual(after_node_nums, before_node_nums - 6)
after_program = paddle.fluid.framework.IrGraph(graph).to_program()
executor = paddle.static.Executor(paddle.CPUPlace())
executor.run(startup_program)
feed = {"x": np.random.random([10, 16, 16]).astype("float32")}
before_out = executor.run(program, feed=feed, fetch_list=[out.name])
after_out = executor.run(after_program,
feed=feed,
fetch_list=[out.name])
self.assertTrue(np.allclose(before_out, after_out))
def test_generate_simplify_inference(self):
input_specs = {'x': InputSpec([-1, 16, 16, 16])}
helper = ir.RegisterPassHelper(
[generate_simplify_inference()], input_specs=input_specs)
s = helper.SerializeMultiPassDesc()
multi_pass_desc = get_multi_pass_desc_from_str(s)
self.assertEqual(len(multi_pass_desc.pass_descs), 1)
pass_desc = multi_pass_desc.pass_descs[0]
self.assertEqual(len(pass_desc.var_maps), 2)
self.assertEqual(len(pass_desc.pattern.blocks[0].ops), 2)
self.assertEqual(len(pass_desc.replace.blocks[0].ops), 0)
pattern_op_dicts = self.convert_ops_to_op_dicts(
pass_desc.pattern.blocks[0].ops)
self.assertEqual(len(pattern_op_dicts.get("transpose2", [])), 2)
self.check_generate_simplify_inference("generate_simplify_inference_v1")
self.check_generate_simplify_inference("generate_simplify_inference_v2")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册