From bab39eb24738450dea7be638aa7bc45353cff837 Mon Sep 17 00:00:00 2001 From: wuhuanzhou Date: Thu, 16 Sep 2021 11:11:01 +0800 Subject: [PATCH] Python support register pass via PassDesc (#35602) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PR主要功能:针对fusion等子图替换场景,支持Python侧开发并注册Pass。 背景 Pass是指输入一个深度学习计算图Graph,依照一定条件进行修改,输出修改后的Graph的过程; 当前PaddlePadle框架编写Pass代码存在以下问题: 用户需要手写Graph的条件匹配、在Graph上的修改代码; 对Graph操作需要深入底层框架代码,了解Graph的结构,并且知道相关Pass写法; 我们提出了针对fusion等子图替换类Pass的优化方案以支持用户在Python侧开发注册Pass,提升二次开发体验: 用户只需要输入匹配和替换的子图描述,由深度学习框架编写的代码来生成匹配和替换的逻辑,不需要用户对Graph进行匹配和替换操作; API级别的替换,用户可以通过Paddle的Python API构造子图,从而不需要知道Graph的结构,也能写Paddle的Graph Pass代码 --- paddle/fluid/framework/CMakeLists.txt | 1 + paddle/fluid/framework/pass_desc.proto | 39 +++ python/paddle/fluid/ir.py | 246 +++++++++++++++++- .../unittests/ir/test_ir_generate_pass.py | 225 ++++++++++++++++ 4 files changed, 508 insertions(+), 3 deletions(-) create mode 100644 paddle/fluid/framework/pass_desc.proto create mode 100644 python/paddle/fluid/tests/unittests/ir/test_ir_generate_pass.py diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index f8a4d09924..46d580e325 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -223,6 +223,7 @@ if(WITH_PYTHON) py_proto_compile(framework_py_proto SRCS framework.proto data_feed.proto) py_proto_compile(trainer_py_proto SRCS trainer_desc.proto data_feed.proto) py_proto_compile(distributed_strategy_py_proto SRCS distributed_strategy.proto) + py_proto_compile(pass_desc_py_proto SRCS pass_desc.proto) #Generate an empty \ #__init__.py to make framework_py_proto as a valid python module. add_custom_target(fleet_proto_init ALL diff --git a/paddle/fluid/framework/pass_desc.proto b/paddle/fluid/framework/pass_desc.proto new file mode 100644 index 0000000000..c95e40a1d2 --- /dev/null +++ b/paddle/fluid/framework/pass_desc.proto @@ -0,0 +1,39 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +syntax = "proto2"; + +import "framework.proto"; +package paddle.framework.proto; + +// Describes one subsitute subgraph. +message PassDesc { + message VarMap { + required string pattern_var = 1; + required string replace_var = 2; + } + message AttrMap { + required int32 pattern_op_idx = 1; + required int32 replace_op_idx = 2; + required string pattern_name = 3; + required string replace_name = 4; + } + required ProgramDesc pattern = 1; + required ProgramDesc replace = 2; + repeated VarMap var_maps = 3; + repeated AttrMap attr_maps = 4; +} + +// A series of PassDesc. +message MultiPassDesc { + optional string pass_type = 1; + repeated PassDesc pass_descs = 2; +} diff --git a/python/paddle/fluid/ir.py b/python/paddle/fluid/ir.py index 69775dbdaf..17b7ea1122 100644 --- a/python/paddle/fluid/ir.py +++ b/python/paddle/fluid/ir.py @@ -12,10 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import copy -from . import core -from .framework import _apply_pass +import inspect +from os import path +import paddle +from . import core, unique_name +from .framework import _apply_pass, OpProtoHolder + +try: + from .proto import pass_desc_pb2 +except ModuleNotFoundError: + import sys + sys.path.append(path.join(path.dirname(__file__), 'proto')) + from .proto import pass_desc_pb2 def get_data_vars(program): @@ -115,3 +124,234 @@ def apply_build_strategy(main_program, startup_program, build_strategy, build_strategy.enable_inplace = False build_strategy._clear_finalized() return build_strategy + + +class RegisterPassHelper(object): + def __init__(self, pass_pairs, pass_type=str(), input_specs=dict()): + self._pass_type = pass_type + self._pass_pairs = pass_pairs + if isinstance(input_specs, dict): + self._input_specs = input_specs + + def _get_args_from_func(self, func): + args = list() + arg_specs = inspect.getfullargspec(func) + for arg_name in arg_specs.args: + input_spec = self._input_specs.get(arg_name) + if isinstance(input_spec, paddle.static.InputSpec): + args.append( + paddle.static.data(arg_name, input_spec.shape, + input_spec.dtype)) + elif isinstance(input_spec, paddle.ParamAttr): + args.append(paddle.ParamAttr(arg_name)) + else: + args.append(paddle.static.data(arg_name, [-1])) + return args + + def _func_to_program_desc(self, func, program_desc, is_replace=False): + vars = list() + program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(program, startup_program): + args = self._get_args_from_func(func) + for arg in args: + vars.append(arg.name) + outs = func(*args) + if not isinstance(outs, (list, tuple)): + outs = [outs] + for out in outs: + if isinstance(out, PassDesc.OpHelper): + for out in out.Outputs().values(): + vars.extend(out) + elif isinstance(out, paddle.fluid.framework.Variable): + vars.append(out.name) + program_desc.ParseFromString(program.desc.serialize_to_string()) + if is_replace: + attrs = list() + for op in program.current_block().ops: + if not isinstance(op, PassDesc.OpHelper): + continue + attrs.extend(op._attrs.values()) + return vars, attrs + return vars + + def SerializeMultiPassDesc(self): + switch_static_mode = paddle.in_dynamic_mode() + if switch_static_mode: + paddle.enable_static() + multi_pass_desc = pass_desc_pb2.MultiPassDesc() + multi_pass_desc.pass_type = self._pass_type + for (pattern, replace) in self._pass_pairs: + pass_desc = multi_pass_desc.pass_descs.add() + pattern_vars = self._func_to_program_desc(pattern, + pass_desc.pattern) + replace_vars, attrs = self._func_to_program_desc( + replace, pass_desc.replace, is_replace=True) + for (pattern_var, replace_var) in zip(pattern_vars, replace_vars): + var_map = pass_desc.var_maps.add() + var_map.pattern_var = pattern_var + var_map.replace_var = replace_var + pattern_op_idxs = dict() + for (idx, op) in enumerate(pass_desc.pattern.blocks[0].ops): + op_idxs = pattern_op_idxs.get(op.type) + if op_idxs: + op_idxs.append(idx) + else: + pattern_op_idxs[op.type] = [idx] + for attr in attrs: + attr_map = pass_desc.attr_maps.add() + attr_map.pattern_op_idx = pattern_op_idxs[ + attr._pattern_op_type][attr._pattern_op_idx] + attr_map.replace_op_idx = attr._replace_op_idx + attr_map.pattern_name = attr._pattern_name + attr_map.replace_name = attr._replace_name + if switch_static_mode: + paddle.disable_static() + return multi_pass_desc.SerializeToString() + + +class PassDesc(object): + class AttrHelper(object): + def __init__(self, name, replace_op_idx): + self._pattern_op_type = None + self._pattern_op_idx = -1 + self._replace_op_idx = replace_op_idx + self._pattern_name = name + self._replace_name = name + + def ReusePattern(self, op, index=0, name=None): + if name: + self._pattern_name = name + self._pattern_op_type = op + self._pattern_op_idx = index + + class OpHelper(object): + def __init__(self, type=None): + self._type = type + + def __getattr__(self, name): + if self._type is not None: + raise AttributeError( + "type object 'OpHelper' has no attribute '{}'".format(name)) + op = PassDesc.OpHelper(name) + op.Init() + return op + + def __call__(self, *args, **kwargs): + for (in_name, in_args) in kwargs.items(): + in_arg_names = list() + if isinstance(in_args, (list, tuple)): + if len(in_args) == 0: + raise ValueError( + "Input '{}' of operator '{}' cannot be empty.". + format(in_name, self._type)) + else: + in_args = [in_args] + for in_arg in in_args: + if isinstance(in_arg, PassDesc.OpHelper): + in_arg_names.extend(in_arg.Output()) + else: + in_arg_names.append(in_arg.name) + self._op_desc.set_input(in_name, in_arg_names) + return self + + def Init(self): + block = paddle.static.default_main_program().current_block() + self._attrs = dict() + self._op_idx = len(block.ops) + self._op_desc = block.desc.append_op() + self._op_desc.set_type(self._type) + self._op_proto = OpProtoHolder.instance().get_op_proto(self._type) + block.ops.append(self) + + def Attr(self, name): + attr = self._attrs.get(name) + if attr: + return attr + attr = PassDesc.AttrHelper(name, self._op_idx) + self._attrs[name] = attr + return attr + + def SetAttr(self, name, value): + self._op_desc._set_attr(name, value) + + def Output(self, name=None): + if name: + return self.Outputs()[name] + return list(self.Outputs().values())[0] + + def Outputs(self): + outputs = self._op_desc.outputs() + if len(outputs) > 0: + return outputs + block = paddle.static.default_main_program().current_block() + for output_proto in self._op_proto.outputs: + name = unique_name.generate(self._type) + block.create_var(name=name) + self._op_desc.set_output(output_proto.name, [name]) + return self._op_desc.outputs() + + OP = OpHelper() + + +def RegisterPass(function=None, input_specs=None): + """ + The function decorator of Register Pass. Decorator @RegisterPass handles + the function and register it into a core.Pass instance. Use name of function + as Pass type. + + Args: + function (callable): The function with return of callable pair(s) that + represents the pattern subgraph and the replace subgraph. + input_specs (dict[str, InputSpec]|None): Dict of InputSpec to specific the shape/dtype + information of Tensor. Some operators limit the shape and dtype of datas when + create subgraph with Paddle APIs. So user need specify InputSpec of data to + ensure create a correctly subgraph. Of course, this argument is not limited to + matching subgraph. The default is None. + + Returns: + callables: Callable pair(s). + + Examples: + .. code-block:: python + + import paddle + from paddle.fluid.ir import RegisterPass + + @RegisterPass + def multi_add_to_addn(): + def pattern(x, y, z): + return paddle.add(paddle.add(x, y), z) + def replace(x, y, z): + return paddle.add_n([x, y, z]) + return pattern, replace + """ + + def _is_pass_pair(check_pair): + if isinstance(check_pair, (list, tuple)): + if len(check_pair) == 2: + if all(map(inspect.isfunction, check_pair)): + return True + return False + + def decorated(python_func): + pass_type = python_func.__name__ + signature = inspect.signature(python_func) + if len(signature.parameters) > 0: + raise NotImplementedError( + "Pass function with parameter is not supported now.") + elif len(signature.parameters) == 0: + pass_pairs = python_func() + if _is_pass_pair(pass_pairs): + pass_pairs = [pass_pairs] + elif not all(map(_is_pass_pair, pass_pairs)): + raise ValueError( + "Return value of Pass function must be (callable, callable)." + ) + helper = RegisterPassHelper(pass_pairs, pass_type, input_specs) + return python_func + + if inspect.isfunction(function): + return decorated(function) + + return decorated diff --git a/python/paddle/fluid/tests/unittests/ir/test_ir_generate_pass.py b/python/paddle/fluid/tests/unittests/ir/test_ir_generate_pass.py new file mode 100644 index 0000000000..c8b9d5e573 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/test_ir_generate_pass.py @@ -0,0 +1,225 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle +from paddle.static import InputSpec +from paddle.fluid import ir +import numpy as np + + +# 0: ewadd(X=mul(X=x, Y=w), Y=b) => fc(Input=x, W=w, Bias=b) +# 1: relu(X=ewadd(X=mul(X=x, Y=w), Y=b)) => fc(Input=x, W=w, Bias=b) +@ir.RegisterPass +def generate_fc_fuse(): + def create_pass_pair(with_relu): + def pattern(x, w, b): + mul = ir.PassDesc.OP.mul(X=x, Y=w) + ewadd = ir.PassDesc.OP.elementwise_add(X=mul, Y=b) + if with_relu: + return ir.PassDesc.OP.relu(X=ewadd) + else: + return ewadd + + def replace(x, w, b): + fc = ir.PassDesc.OP.fc + fc.Attr("in_num_col_dims").ReusePattern( + "mul", name="x_num_col_dims") + if with_relu: + fc.SetAttr("activation_type", "relu") + return fc(Input=x, W=w, Bias=b) + + return pattern, replace + + return list(map(create_pass_pair, [True, False])) + + +# add(X=add(x, y), Y=z)z => add_n(X=[x, y, z]) +@ir.RegisterPass +def generate_add_n(): + def pattern(x, y, z): + return paddle.add(paddle.add(x, y), z) + + def replace(x, y, z): + return paddle.add_n([x, y, z]) + + return pattern, replace + + +# mul(x, y1), mul(x, y2) => slice(mul(x, concat(y1, y2))) +@ir.RegisterPass(input_specs={ + 'x': InputSpec([1, 1]), + 'y1': InputSpec([1, 1]), + 'y2': InputSpec([1, 1]) +}) +def generate_combine_mul_v1(): + def pattern(x, y1, y2): + mul1 = paddle.matmul(x, y1) + mul2 = paddle.matmul(x, y2) + return mul1, mul2 + + def replace(x, y1, y2): + concat_out = paddle.concat([y1, y2], axis=-1) + mul_out = paddle.matmul(x, concat_out) + out1 = paddle.slice(mul_out, axes=[1], starts=[0], ends=[1]) + out2 = paddle.slice(mul_out, axes=[1], starts=[1], ends=[2]) + return out1, out2 + + return pattern, replace + + +@ir.RegisterPass +def generate_combine_mul_v2(): + def pattern(x, y1, y2): + mul1 = ir.PassDesc.OP.matmul_v2(x, y1) + mul2 = ir.PassDesc.OP.matmul_v2(x, y2) + return mul1, mul2 + + def replace(x, y1, y2): + concat = ir.PassDesc.OP.concat(X=[y1, y2]) + matmul = ir.PassDesc.OP.matmul_v2(X=x, Y=concat) + out1 = ir.PassDesc.OP.slice(Input=matmul) + out2 = ir.PassDesc.OP.slice(Input=matmul) + return out1, out2 + + return pattern, replace + + +# reshape(reshape(x)) => x +@ir.RegisterPass(input_specs={'x': InputSpec([-1, 16, 16, 16])}) +def generate_simplify_inference(): + def pattern(x): + transpose = paddle.transpose(x, [0, 3, 1, 2]) + return paddle.transpose(transpose, [0, 3, 1, 2]) + + return pattern, lambda x: x + + +def get_multi_pass_desc_from_str(s): + multi_pass_desc = ir.pass_desc_pb2.MultiPassDesc() + multi_pass_desc.ParseFromString(s) + return multi_pass_desc + + +class TestGeneratePass(unittest.TestCase): + def convert_ops_to_op_dicts(self, ops): + op_dicts = dict() + for op in ops: + op_list = op_dicts.get(op.type) + if isinstance(op_list, list): + op_list.append(op) + else: + op_dicts[op.type] = [op] + return op_dicts + + def test_generate_fc_fuse(self): + def _check_fc_fuse_pass(pass_desc, with_relu): + pattern_op_dicts = self.convert_ops_to_op_dicts( + pass_desc.pattern.blocks[0].ops) + replace_op_dicts = self.convert_ops_to_op_dicts( + pass_desc.replace.blocks[0].ops) + self.assertEqual(len(pattern_op_dicts.get("mul", [])), 1) + self.assertEqual( + len(pattern_op_dicts.get("elementwise_add", [])), 1) + if with_relu: + self.assertEqual(len(pattern_op_dicts.get("relu", [])), 1) + pattern_op_num = 3 # relu, ewadd, mul + else: + pattern_op_num = 2 # ewadd, mul + self.assertEqual(len(pass_desc.var_maps), 4) + self.assertEqual( + len(pass_desc.pattern.blocks[0].ops), pattern_op_num) + self.assertEqual(len(pass_desc.replace.blocks[0].ops), 1) + self.assertEqual(len(pass_desc.attr_maps), 1) + + helper = ir.RegisterPassHelper(generate_fc_fuse()) + s = helper.SerializeMultiPassDesc() + multi_pass_desc = get_multi_pass_desc_from_str(s) + self.assertEqual(len(multi_pass_desc.pass_descs), 2) + _check_fc_fuse_pass(multi_pass_desc.pass_descs[0], True) + _check_fc_fuse_pass(multi_pass_desc.pass_descs[1], False) + + def test_generate_add_n(self): + helper = ir.RegisterPassHelper([generate_add_n()]) + s = helper.SerializeMultiPassDesc() + multi_pass_desc = get_multi_pass_desc_from_str(s) + self.assertEqual(len(multi_pass_desc.pass_descs), 1) + pass_desc = multi_pass_desc.pass_descs[0] + self.assertEqual(len(pass_desc.var_maps), 4) + self.assertEqual(len(pass_desc.attr_maps), 0) + self.assertEqual(len(pass_desc.pattern.blocks[0].ops), 2) + self.assertEqual(len(pass_desc.replace.blocks[0].ops), 1) + pattern_op_dicts = self.convert_ops_to_op_dicts( + pass_desc.pattern.blocks[0].ops) + replace_op_dicts = self.convert_ops_to_op_dicts( + pass_desc.replace.blocks[0].ops) + self.assertEqual(len(pattern_op_dicts.get("elementwise_add", [])), 2) + self.assertEqual(len(replace_op_dicts.get("sum", [])), 1) + + def test_generate_combine_mul_v1(self): + input_specs = { + 'x': InputSpec([1, 1]), + 'y1': InputSpec([1, 1]), + 'y2': InputSpec([1, 1]) + } + helper = ir.RegisterPassHelper( + [generate_combine_mul_v1()], input_specs=input_specs) + s = helper.SerializeMultiPassDesc() + multi_pass_desc = get_multi_pass_desc_from_str(s) + self.assertEqual(len(multi_pass_desc.pass_descs), 1) + pass_desc = multi_pass_desc.pass_descs[0] + self.assertEqual(len(pass_desc.var_maps), 5) + self.assertEqual(len(pass_desc.pattern.blocks[0].ops), 2) + self.assertEqual(len(pass_desc.replace.blocks[0].ops), 4) + pattern_op_dicts = self.convert_ops_to_op_dicts( + pass_desc.pattern.blocks[0].ops) + replace_op_dicts = self.convert_ops_to_op_dicts( + pass_desc.replace.blocks[0].ops) + self.assertEqual(len(pattern_op_dicts.get("matmul_v2", [])), 2) + self.assertEqual(len(replace_op_dicts.get("concat", [])), 1) + self.assertEqual(len(replace_op_dicts.get("matmul_v2", [])), 1) + self.assertEqual(len(replace_op_dicts.get("slice", [])), 2) + + def test_generate_combine_mul_v2(self): + helper = ir.RegisterPassHelper([generate_combine_mul_v2()]) + s = helper.SerializeMultiPassDesc() + multi_pass_desc = get_multi_pass_desc_from_str(s) + self.assertEqual(len(multi_pass_desc.pass_descs), 1) + pass_desc = multi_pass_desc.pass_descs[0] + self.assertEqual(len(pass_desc.var_maps), 5) + self.assertEqual(len(pass_desc.pattern.blocks[0].ops), 2) + self.assertEqual(len(pass_desc.replace.blocks[0].ops), 4) + pattern_op_dicts = self.convert_ops_to_op_dicts( + pass_desc.pattern.blocks[0].ops) + replace_op_dicts = self.convert_ops_to_op_dicts( + pass_desc.replace.blocks[0].ops) + self.assertEqual(len(pattern_op_dicts.get("matmul_v2", [])), 2) + self.assertEqual(len(replace_op_dicts.get("concat", [])), 1) + self.assertEqual(len(replace_op_dicts.get("matmul_v2", [])), 1) + self.assertEqual(len(replace_op_dicts.get("slice", [])), 2) + + def test_generate_simplify_inference(self): + input_specs = {'x': InputSpec([-1, 16, 16, 16])} + helper = ir.RegisterPassHelper( + [generate_simplify_inference()], input_specs=input_specs) + s = helper.SerializeMultiPassDesc() + multi_pass_desc = get_multi_pass_desc_from_str(s) + self.assertEqual(len(multi_pass_desc.pass_descs), 1) + pass_desc = multi_pass_desc.pass_descs[0] + self.assertEqual(len(pass_desc.var_maps), 2) + self.assertEqual(len(pass_desc.pattern.blocks[0].ops), 2) + self.assertEqual(len(pass_desc.replace.blocks[0].ops), 0) + pattern_op_dicts = self.convert_ops_to_op_dicts( + pass_desc.pattern.blocks[0].ops) + self.assertEqual(len(pattern_op_dicts.get("transpose2", [])), 2) -- GitLab