From 24418479413961fd8486b87dd7a09e983cf4b0ad Mon Sep 17 00:00:00 2001 From: wuhuanzhou Date: Wed, 13 Oct 2021 17:12:56 +0800 Subject: [PATCH] Verify the correctness of graph rewrited by GeneratePass (#36116) Check detail PR description at https://github.com/PaddlePaddle/Paddle/pull/36116 --- paddle/fluid/framework/ir/generate_pass.cc | 117 ++++++++++- python/paddle/fluid/ir.py | 43 +++- .../unittests/ir/test_ir_generate_pass.py | 196 ++++++++++++------ 3 files changed, 275 insertions(+), 81 deletions(-) diff --git a/paddle/fluid/framework/ir/generate_pass.cc b/paddle/fluid/framework/ir/generate_pass.cc index 085298314e..b261cbeb08 100644 --- a/paddle/fluid/framework/ir/generate_pass.cc +++ b/paddle/fluid/framework/ir/generate_pass.cc @@ -21,6 +21,16 @@ namespace ir { void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) { const proto::BlockDesc& block = pass_desc.pattern().blocks(0); + for (const proto::VarDesc& var : block.vars()) { + PDNode* var_pdnode = pattern->NewNode(var.name())->AsInput(); + var_pdnode->assert_is_var(); + var_pdnode->assert_more([&](Node* x) { + if (VarDesc(var).GetShape() == x->Var()->GetShape()) { + return true; + } + return false; + }); + } // Traverse all operators to create subgraph. for (int index = 0; index < block.ops_size(); ++index) { const proto::OpDesc& op = block.ops(index); @@ -31,15 +41,32 @@ void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) { pattern->NewNode(std::to_string(index))->assert_is_op(op.type()); // Create PDNodes for inputs of current operator. for (const proto::OpDesc::Var& var : op.inputs()) { - for (const std::string& argument : var.arguments()) { + for (int n = 0; n < var.arguments_size(); ++n) { + const std::string& argument = var.arguments(n); // The input may be the output of other operator. PDNode* var_pdnode = pattern->RetrieveNode(argument); if (nullptr == var_pdnode) { var_pdnode = pattern->NewNode(argument)->AsInput(); + var_pdnode->assert_is_var(); } else if (var_pdnode->IsOutput()) { var_pdnode->AsIntermediate(); } - var_pdnode->assert_is_op_input(op.type()); + var_pdnode->assert_more([&](Node* x) { + for (auto* out : x->outputs) { + if (out->IsOp() && out->Op()->Type() == op.type()) { + const auto& inputs = out->Op()->Inputs(); + const auto& iter = inputs.find(var.parameter()); + if (inputs.end() != iter) { + if (iter->second.end() != std::find(iter->second.begin(), + iter->second.end(), + x->Name())) { + return true; + } + } + } + } + return false; + }); pattern->AddEdge(var_pdnode, op_pdnode); } } @@ -50,6 +77,24 @@ void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) { PDNode* var_pdnode = pattern->RetrieveNode(argument); if (nullptr == var_pdnode) { var_pdnode = pattern->NewNode(argument)->AsOutput(); + var_pdnode->assert_is_var(); + var_pdnode->assert_more([&](Node* x) { + for (Node* input : x->inputs) { + if (input && input->IsOp() && input->Op() && + input->Op()->Type() == op.type()) { + const auto& outputs = input->Op()->Outputs(); + const auto& iter = outputs.find(var.parameter()); + if (outputs.end() != iter) { + if (iter->second.end() != std::find(iter->second.begin(), + iter->second.end(), + x->Name())) { + return true; + } + } + } + } + return false; + }); } else if (var_pdnode->IsInput()) { var_pdnode->AsIntermediate(); } @@ -73,18 +118,64 @@ void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) { } } -GraphPatternDetector::handle_t GetGenerateRewrite( +// There are some duplicate patterns. +bool IsDuplicatePattern(const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + for (auto iter : subgraph) { + if (nullptr == graph->RetrieveNode(iter.second->id())) { + VLOG(3) << "Node [" << iter.second->Name() + << "] of subgraph has been removed. So skip this optimize."; + return true; + } + } + return false; +} + +GraphPatternDetector::handle_t GetGenerateDelete( const PDPattern& pattern, const proto::PassDesc& pass_desc) { GraphPatternDetector::handle_t handler = [&]( - const GraphPatternDetector::subgraph_t subgraph, Graph* graph) { - // There are some duplicate patterns. - for (auto iter : subgraph) { - if (nullptr == graph->RetrieveNode(iter.second->id())) { - VLOG(3) << "Node [" << iter.second->Name() - << "] of subgraph has been removed. So skip this optimize."; - return; + const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { + if (IsDuplicatePattern(subgraph, graph)) { + return; + } + // `var_node_maps` record the mapping of variable to the pattern subgraph. + std::map var_node_maps; + for (const proto::PassDesc::VarMap& var_map : pass_desc.var_maps()) { + Node* node = subgraph.at(pattern.RetrieveNode(var_map.pattern_var())); + const auto& iter = var_node_maps.find(var_map.replace_var()); + if (var_node_maps.end() == iter) { + // first node is input + var_node_maps.insert({var_map.replace_var(), node}); + } else { + // output node + for (Node* s_node : node->outputs) { + iter->second->outputs.push_back(s_node); + std::replace(s_node->inputs.begin(), s_node->inputs.end(), node, + iter->second); + s_node->Op()->RenameInput(node->Name(), iter->second->Name()); + } } } + // Remove nodes that are intermediate. + std::unordered_set remove_nodes; + for (const std::unique_ptr& pdnode : pattern.nodes()) { + remove_nodes.emplace(subgraph.at(pdnode.get())); + } + for (auto iter : var_node_maps) { + remove_nodes.erase(iter.second); + } + GraphSafeRemoveNodes(graph, remove_nodes); + }; + return handler; +} + +GraphPatternDetector::handle_t GetGenerateRewrite( + const PDPattern& pattern, const proto::PassDesc& pass_desc) { + GraphPatternDetector::handle_t handler = [&]( + const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { + if (IsDuplicatePattern(subgraph, graph)) { + return; + } const proto::BlockDesc& block = pass_desc.replace().blocks(0); // `var_node_maps` record the mapping of variable to the pattern subgraph. std::map var_node_maps; @@ -175,7 +266,11 @@ void GeneratePass::ApplyImpl(Graph* graph) const { for (const proto::PassDesc& pass_desc : multi_pass_desc_.pass_descs()) { GraphPatternDetector detector; InitGeneratePattern(pass_desc, detector.mutable_pattern()); - detector(graph, GetGenerateRewrite(detector.pattern(), pass_desc)); + if (pass_desc.replace().blocks(0).ops_size() == 0) { + detector(graph, GetGenerateDelete(detector.pattern(), pass_desc)); + } else { + detector(graph, GetGenerateRewrite(detector.pattern(), pass_desc)); + } // The rewrited graph needs to be verified. Current Pass should be skipped // if validation failed. Rewrite based on the original graph cannot // implement rollback operation. diff --git a/python/paddle/fluid/ir.py b/python/paddle/fluid/ir.py index 7e2d3df1ce..3c7c8879fd 100644 --- a/python/paddle/fluid/ir.py +++ b/python/paddle/fluid/ir.py @@ -127,11 +127,13 @@ def apply_build_strategy(main_program, startup_program, build_strategy, class RegisterPassHelper(object): + _register_helpers = list() + def __init__(self, pass_pairs, pass_type=str(), input_specs=dict()): self._pass_type = pass_type self._pass_pairs = pass_pairs - if isinstance(input_specs, dict): - self._input_specs = input_specs + self._input_specs = input_specs + RegisterPassHelper._register_helpers.append(self) def _get_args_from_func(self, func): args = list() @@ -148,6 +150,35 @@ class RegisterPassHelper(object): args.append(paddle.static.data(arg_name, [-1])) return args + def _prune_program_desc(self, program_desc): + block_desc = program_desc.blocks[0] + # block_desc.ClearField("vars") + for var in [ + var for var in block_desc.vars + if var.name not in self._input_specs + ]: + block_desc.vars.remove(var) + for op_desc in block_desc.ops: + default_attrs = core.get_op_attrs_default_value( + paddle.compat.to_bytes(op_desc.type)) + remove_attrs = list() + for attr in op_desc.attrs: + # attr must not in + if attr.name not in [ + "op_namescope", "op_callstack", "op_device" + ]: + attr_list_fields = attr.ListFields() + # attr format must be: name, type, value + if len(attr_list_fields) == 3: + attr_value = attr.ListFields()[-1][-1] + default_attr_value = default_attrs.get(attr.name) + # value must not default + if default_attr_value != attr_value: + continue + remove_attrs.append(attr) + for attr in remove_attrs: + op_desc.attrs.remove(attr) + def _func_to_program_desc(self, func, program_desc, is_replace=False): vars = list() program = paddle.static.Program() @@ -166,6 +197,7 @@ class RegisterPassHelper(object): elif isinstance(out, paddle.fluid.framework.Variable): vars.append(out.name) program_desc.ParseFromString(program.desc.serialize_to_string()) + self._prune_program_desc(program_desc) if is_replace: attrs = list() for op in program.current_block().ops: @@ -296,7 +328,7 @@ class PassDesc(object): OP = OpHelper() -def RegisterPass(function=None, input_specs=None): +def RegisterPass(function=None, input_specs=dict()): """ The function decorator of Register Pass. Decorator @RegisterPass handles the function and register it into a core.Pass instance. Use name of function @@ -305,11 +337,11 @@ def RegisterPass(function=None, input_specs=None): Args: function (callable): The function with return of callable pair(s) that represents the pattern subgraph and the replace subgraph. - input_specs (dict[str, InputSpec]|None): Dict of InputSpec to specific the shape/dtype + input_specs (dict[str, InputSpec]): Dict of InputSpec to specific the shape/dtype information of Tensor. Some operators limit the shape and dtype of datas when create subgraph with Paddle APIs. So user need specify InputSpec of data to ensure create a correctly subgraph. Of course, this argument is not limited to - matching subgraph. The default is None. + matching subgraph. The default is dict(). Returns: callables: Callable pair(s). @@ -351,6 +383,7 @@ def RegisterPass(function=None, input_specs=None): "Return value of Pass function must be (callable, callable)." ) helper = RegisterPassHelper(pass_pairs, pass_type, input_specs) + core.register_pass(pass_type, helper.SerializeMultiPassDesc) return python_func if inspect.isfunction(function): diff --git a/python/paddle/fluid/tests/unittests/ir/test_ir_generate_pass.py b/python/paddle/fluid/tests/unittests/ir/test_ir_generate_pass.py index 851ae21c38..61bd554ad2 100644 --- a/python/paddle/fluid/tests/unittests/ir/test_ir_generate_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/test_ir_generate_pass.py @@ -15,7 +15,7 @@ import unittest import paddle from paddle.static import InputSpec -from paddle.fluid import ir +from paddle.fluid import core, ir import numpy as np @@ -45,23 +45,37 @@ def generate_fc_fuse(): return list(map(create_pass_pair, [True, False])) -# add(X=add(x, y), Y=z)z => add_n(X=[x, y, z]) +# add(X=add(X=x, Y=y), Y=z) => sum(X=[x, y, z]) @ir.RegisterPass -def generate_add_n(): +def multi_add_to_sum_v1(): + pattern = lambda x, y, z: paddle.add(paddle.add(x, y), z) + replace = lambda x, y, z: paddle.add_n([x, y, z]) + return pattern, replace + + +@ir.RegisterPass +def multi_add_to_sum_v2(): def pattern(x, y, z): - return paddle.add(paddle.add(x, y), z) + ewadd1 = ir.PassDesc.OP.elementwise_add(X=x, Y=y) + ewadd2 = ir.PassDesc.OP.elementwise_add(X=ewadd1, Y=z) + return ewadd2 + + replace = lambda x, y, z: ir.PassDesc.OP.sum(X=[x, y, z]) + return pattern, replace - def replace(x, y, z): - return paddle.add_n([x, y, z]) +@ir.RegisterPass +def multi_add_to_sum_v3(): + pattern = lambda x, y, z: paddle.add(paddle.add(x, y), z) + replace = lambda x, y, z: ir.PassDesc.OP.sum(X=[x, y, z]) return pattern, replace # mul(x, y1), mul(x, y2) => slice(mul(x, concat(y1, y2))) @ir.RegisterPass(input_specs={ - 'x': InputSpec([1, 1]), - 'y1': InputSpec([1, 1]), - 'y2': InputSpec([1, 1]) + 'x': InputSpec([16, 32]), + 'y1': InputSpec([32, 12]), + 'y2': InputSpec([32, 48]) }) def generate_combine_mul_v1(): def pattern(x, y1, y2): @@ -72,8 +86,8 @@ def generate_combine_mul_v1(): def replace(x, y1, y2): concat_out = paddle.concat([y1, y2], axis=-1) mul_out = paddle.matmul(x, concat_out) - out1 = paddle.slice(mul_out, axes=[1], starts=[0], ends=[1]) - out2 = paddle.slice(mul_out, axes=[1], starts=[1], ends=[2]) + out1 = paddle.slice(mul_out, axes=[1], starts=[0], ends=[12]) + out2 = paddle.slice(mul_out, axes=[1], starts=[12], ends=[60]) return out1, out2 return pattern, replace @@ -97,11 +111,22 @@ def generate_combine_mul_v2(): # reshape(reshape(x)) => x -@ir.RegisterPass(input_specs={'x': InputSpec([-1, 16, 16, 16])}) -def generate_simplify_inference(): +@ir.RegisterPass(input_specs={'x': InputSpec([10, 16, 16])}) +def generate_simplify_inference_v1(): def pattern(x): - transpose = paddle.transpose(x, [0, 3, 1, 2]) - return paddle.transpose(transpose, [0, 3, 1, 2]) + transpose = paddle.transpose(x, [0, 2, 1]) + return paddle.transpose(transpose, [0, 2, 1]) + + return pattern, lambda x: x + + +@ir.RegisterPass +def generate_simplify_inference_v2(): + def pattern(x): + op1 = ir.PassDesc.OP.transpose2 + op2 = ir.PassDesc.OP.transpose2 + # op2.Attr("axis").EQ(op1.Attr("axis")) + return op2(X=op1(X=x)) return pattern, lambda x: x @@ -153,46 +178,73 @@ class TestGeneratePass(unittest.TestCase): _check_fc_fuse_pass(multi_pass_desc.pass_descs[0], True) _check_fc_fuse_pass(multi_pass_desc.pass_descs[1], False) - def test_generate_add_n(self): - helper = ir.RegisterPassHelper([generate_add_n()]) - s = helper.SerializeMultiPassDesc() - multi_pass_desc = get_multi_pass_desc_from_str(s) - self.assertEqual(len(multi_pass_desc.pass_descs), 1) - pass_desc = multi_pass_desc.pass_descs[0] - self.assertEqual(len(pass_desc.var_maps), 4) - self.assertEqual(len(pass_desc.attr_maps), 0) - self.assertEqual(len(pass_desc.pattern.blocks[0].ops), 2) - self.assertEqual(len(pass_desc.replace.blocks[0].ops), 1) - pattern_op_dicts = self.convert_ops_to_op_dicts( - pass_desc.pattern.blocks[0].ops) - replace_op_dicts = self.convert_ops_to_op_dicts( - pass_desc.replace.blocks[0].ops) - self.assertEqual(len(pattern_op_dicts.get("elementwise_add", [])), 2) - self.assertEqual(len(replace_op_dicts.get("sum", [])), 1) + def check_multi_add_to_sum(self, pass_type): + program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(program, startup_program): + x = paddle.static.data("x", [10, 10, 10], "float32") + y = paddle.static.data("y", [10, 10, 10], "float32") + z = paddle.static.data("z", [10, 10, 10], "float32") + add_1 = paddle.add(paddle.add(x, y), z) + matmul_1 = paddle.matmul(add_1, z) + add_tmp = paddle.add(x, y) + add_2 = paddle.add(add_tmp, z) + matmul_2 = paddle.matmul(add_2, add_tmp) + out = paddle.add(matmul_1, matmul_2) + graph = core.Graph(program.desc) + before_node_nums = len(graph.nodes()) + core.get_pass(pass_type).apply(graph) + after_node_nums = len(graph.nodes()) + self.assertEqual(after_node_nums, before_node_nums - 2) + after_program = paddle.fluid.framework.IrGraph(graph).to_program() + executor = paddle.static.Executor(paddle.CPUPlace()) + executor.run(startup_program) + feed = { + "x": np.random.random([10, 10, 10]).astype("float32"), + "y": np.random.random([10, 10, 10]).astype("float32"), + "z": np.random.random([10, 10, 10]).astype("float32") + } + before_out = executor.run(program, feed=feed, fetch_list=[out.name]) + after_out = executor.run(after_program, + feed=feed, + fetch_list=[out.name]) + self.assertTrue(np.allclose(before_out, after_out)) + + def test_multi_add_to_sum(self): + paddle.enable_static() + self.check_multi_add_to_sum("multi_add_to_sum_v1") + self.check_multi_add_to_sum("multi_add_to_sum_v2") + self.check_multi_add_to_sum("multi_add_to_sum_v3") def test_generate_combine_mul_v1(self): - input_specs = { - 'x': InputSpec([1, 1]), - 'y1': InputSpec([1, 1]), - 'y2': InputSpec([1, 1]) + paddle.enable_static() + program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(program, startup_program): + x = paddle.static.data("x", [16, 32]) + y = paddle.static.data("y", [32, 12]) + z = paddle.static.data("z", [32, 48]) + out1 = paddle.matmul(x, y) + out2 = paddle.matmul(x, z) + graph = core.Graph(program.desc) + before_node_nums = len(graph.nodes()) + core.get_pass("generate_combine_mul_v1").apply(graph) + after_node_nums = len(graph.nodes()) + self.assertEqual(after_node_nums, before_node_nums + 4) + after_program = paddle.fluid.framework.IrGraph(graph).to_program() + executor = paddle.static.Executor(paddle.CPUPlace()) + executor.run(startup_program) + feed = { + "x": np.random.random([16, 32]).astype("float32"), + "y": np.random.random([32, 12]).astype("float32"), + "z": np.random.random([32, 48]).astype("float32") } - helper = ir.RegisterPassHelper( - [generate_combine_mul_v1()], input_specs=input_specs) - s = helper.SerializeMultiPassDesc() - multi_pass_desc = get_multi_pass_desc_from_str(s) - self.assertEqual(len(multi_pass_desc.pass_descs), 1) - pass_desc = multi_pass_desc.pass_descs[0] - self.assertEqual(len(pass_desc.var_maps), 5) - self.assertEqual(len(pass_desc.pattern.blocks[0].ops), 2) - self.assertEqual(len(pass_desc.replace.blocks[0].ops), 4) - pattern_op_dicts = self.convert_ops_to_op_dicts( - pass_desc.pattern.blocks[0].ops) - replace_op_dicts = self.convert_ops_to_op_dicts( - pass_desc.replace.blocks[0].ops) - self.assertEqual(len(pattern_op_dicts.get("matmul_v2", [])), 2) - self.assertEqual(len(replace_op_dicts.get("concat", [])), 1) - self.assertEqual(len(replace_op_dicts.get("matmul_v2", [])), 1) - self.assertEqual(len(replace_op_dicts.get("slice", [])), 2) + before_out1, before_out2 = executor.run( + program, feed=feed, fetch_list=[out1.name, out2.name]) + after_out1, after_out2 = executor.run( + after_program, feed=feed, fetch_list=[out1.name, out2.name]) + self.assertTrue(np.allclose(before_out1, after_out1)) + self.assertTrue(np.allclose(before_out2, after_out2)) def test_generate_combine_mul_v2(self): helper = ir.RegisterPassHelper([generate_combine_mul_v2()]) @@ -212,17 +264,31 @@ class TestGeneratePass(unittest.TestCase): self.assertEqual(len(replace_op_dicts.get("matmul_v2", [])), 1) self.assertEqual(len(replace_op_dicts.get("slice", [])), 2) + def check_generate_simplify_inference(self, pass_type): + paddle.enable_static() + program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(program, startup_program): + x = paddle.static.data("x", [10, 16, 16], "float32") + x1 = paddle.transpose(paddle.transpose(x, [0, 2, 1]), [0, 2, 1]) + tmp = paddle.transpose(x, [0, 2, 1]) + x2 = paddle.transpose(tmp, [0, 2, 1]) + out = paddle.add(x1, paddle.matmul(x2, tmp)) + graph = core.Graph(program.desc) + before_node_nums = len(graph.nodes()) + core.get_pass(pass_type).apply(graph) + after_node_nums = len(graph.nodes()) + self.assertEqual(after_node_nums, before_node_nums - 6) + after_program = paddle.fluid.framework.IrGraph(graph).to_program() + executor = paddle.static.Executor(paddle.CPUPlace()) + executor.run(startup_program) + feed = {"x": np.random.random([10, 16, 16]).astype("float32")} + before_out = executor.run(program, feed=feed, fetch_list=[out.name]) + after_out = executor.run(after_program, + feed=feed, + fetch_list=[out.name]) + self.assertTrue(np.allclose(before_out, after_out)) + def test_generate_simplify_inference(self): - input_specs = {'x': InputSpec([-1, 16, 16, 16])} - helper = ir.RegisterPassHelper( - [generate_simplify_inference()], input_specs=input_specs) - s = helper.SerializeMultiPassDesc() - multi_pass_desc = get_multi_pass_desc_from_str(s) - self.assertEqual(len(multi_pass_desc.pass_descs), 1) - pass_desc = multi_pass_desc.pass_descs[0] - self.assertEqual(len(pass_desc.var_maps), 2) - self.assertEqual(len(pass_desc.pattern.blocks[0].ops), 2) - self.assertEqual(len(pass_desc.replace.blocks[0].ops), 0) - pattern_op_dicts = self.convert_ops_to_op_dicts( - pass_desc.pattern.blocks[0].ops) - self.assertEqual(len(pattern_op_dicts.get("transpose2", [])), 2) + self.check_generate_simplify_inference("generate_simplify_inference_v1") + self.check_generate_simplify_inference("generate_simplify_inference_v2") -- GitLab