From 3a7e470b3e1166a663e52081c797c2c12897723b Mon Sep 17 00:00:00 2001 From: zqw_1997 <118182234+zhengqiwen1997@users.noreply.github.com> Date: Tue, 31 Jan 2023 14:19:51 +0800 Subject: [PATCH] remove fluid.ir.RegisterPassHelper PassDesc and RegisterPass (#49578) * remove fluid.ir.RegisterPassHelper PassDesc and RegisterPass * proto import problems * change import way of pass_desc_pb2 * change sys.path * change the way of import framwork_pb2 * add fluid_path directory from path.dirname * fluid_path changed --- python/paddle/fluid/ir.py | 466 +---------------- .../unittests/ir/test_ir_generate_pass.py | 3 +- .../incubate/passes/fuse_resnet_unit_pass.py | 2 +- python/paddle/incubate/passes/ir.py | 483 ++++++++++++++++++ 4 files changed, 488 insertions(+), 466 deletions(-) create mode 100644 python/paddle/incubate/passes/ir.py diff --git a/python/paddle/fluid/ir.py b/python/paddle/fluid/ir.py index fb077ed8b5..c444b4ceda 100644 --- a/python/paddle/fluid/ir.py +++ b/python/paddle/fluid/ir.py @@ -16,18 +16,8 @@ import copy import inspect from os import path import paddle -from . import core, unique_name -from .framework import _apply_pass, OpProtoHolder - -from .proto import framework_pb2 - -try: - from .proto import pass_desc_pb2 -except ModuleNotFoundError: - import sys - - sys.path.append(path.join(path.dirname(__file__), 'proto')) - from .proto import pass_desc_pb2 +from . import core +from .framework import _apply_pass def get_data_vars(program): @@ -138,455 +128,3 @@ def apply_build_strategy( build_strategy.enable_inplace = False build_strategy._clear_finalized() return build_strategy - - -class RegisterPassHelper: - _register_helpers = list() - - def __init__(self, pass_pairs, pass_type=str(), input_specs=dict()): - self._pass_type = pass_type - self._pass_pairs = pass_pairs - self._input_specs = input_specs - RegisterPassHelper._register_helpers.append(self) - - def _get_args_from_func(self, func): - args = list() - arg_specs = inspect.getfullargspec(func) - for arg_name in arg_specs.args: - input_spec = self._input_specs.get(arg_name) - if isinstance(input_spec, paddle.static.InputSpec): - args.append( - PassDesc.VarHelper( - arg_name, input_spec.shape, input_spec.dtype - ) - ) - elif isinstance(input_spec, paddle.ParamAttr): - args.append(paddle.ParamAttr(arg_name)) - else: - args.append(PassDesc.VarHelper(arg_name, [-1])) - return args - - def _prune_program_desc(self, ops): - for op_desc in ops: - default_attrs = core.get_op_attrs_default_value( - op_desc.type.encode() - ) - remove_attrs = list() - for attr in op_desc.attrs: - # attr must not in - if attr.name not in [ - "op_namescope", - "op_callstack", - "op_device", - ]: - attr_list_fields = attr.ListFields() - # attr format must be: name, type, value - if len(attr_list_fields) == 3: - attr_value = attr.ListFields()[-1][-1] - default_attr_value = default_attrs.get(attr.name) - # value must not default - if default_attr_value != attr_value: - continue - remove_attrs.append(attr) - for attr in remove_attrs: - op_desc.attrs.remove(attr) - - def _func_to_program_desc(self, func, ops): - vars = list() - program = paddle.static.Program() - startup_program = paddle.static.Program() - with paddle.static.program_guard(program, startup_program): - args = self._get_args_from_func(func) - vars.extend(args) - outs = func(*args) - if not isinstance(outs, (list, tuple)): - outs = [outs] - for out in outs: - if isinstance(out, PassDesc.OpHelper): - op_outs = out.Outputs() - if len(op_outs) != 1: - raise ValueError( - "Operator '{}' has multiple outputs, please specify one output variable.".format( - out._type - ) - ) - for op_out in op_outs.values(): - vars.extend(op_out) - else: - vars.append(out) - block_desc = program.current_block().desc - for i in range(block_desc.op_size()): - ops.add().ParseFromString(block_desc.op(i).serialize_to_string()) - self._prune_program_desc(ops) - return vars, program.current_block().ops - - def _convert_vars_to_pass_desc(self, patterns, replaces, desc): - def _add_element_conditions(conditions, elements): - for element in elements: - if element._condition: - conditions.append(element._condition) - _add_element_conditions(conditions, element._elements) - - for (pattern, replace) in zip(patterns, replaces): - # Convert maps of inputs and outputs. - var_map = desc.var_maps.add() - var_map.pattern_var = pattern.name - var_map.replace_var = replace.name - conditions = desc.var_attr_conditions - # Convert shape condition. - if pattern.name in self._input_specs: - condition = conditions.add() - pattern.Attr("shape")._to_pass_desc_attr(condition.attr) - condition.condition_value.name = "" - condition.condition_value.type = framework_pb2.AttrType.LONGS - condition.condition_value.longs.extend(pattern.shape) - condition.type = pass_desc_pb2.PassDesc.ConditionType.kEQ - # Convert attr conditions. - if PassDesc.VarHelper == pattern.__class__: - for attr in pattern._attrs.values(): - _add_element_conditions(conditions, [attr]) - - def _convert_ops_to_pass_desc(self, patterns, replaces, desc): - for replace in replaces: - if isinstance(replace, PassDesc.OpHelper): - for attr in replace._attrs.values(): - # Convert attr maps. - mapped = attr._mapped - if inspect.isfunction(mapped): - mapped = mapped(patterns) - attr_map = desc.op_attr_maps.add() - mapped._to_pass_desc_attr(attr_map.pattern_attr) - attr._to_pass_desc_attr(attr_map.replace_attr) - if mapped._operation is not None: - attr_map.operation.CopyFrom(mapped._operation) - - def SerializeMultiPassDesc(self): - switch_static_mode = paddle.in_dynamic_mode() - if switch_static_mode: - paddle.enable_static() - multi_pass_desc = pass_desc_pb2.MultiPassDesc() - multi_pass_desc.pass_type = self._pass_type - # Traverse all pass pairs and convert them to PassDesc data. - # Here need to add cache in the future. - for (pattern, replace) in self._pass_pairs: - pass_desc = multi_pass_desc.pass_descs.add() - # Convert ProgramDescs of pattern and replace subgraphs. - pattern_vars, pattern_ops = self._func_to_program_desc( - pattern, pass_desc.pattern - ) - replace_vars, replace_ops = self._func_to_program_desc( - replace, pass_desc.replace - ) - self._convert_vars_to_pass_desc( - pattern_vars, replace_vars, pass_desc - ) - self._convert_ops_to_pass_desc(pattern_ops, replace_ops, pass_desc) - if switch_static_mode: - paddle.disable_static() - return multi_pass_desc.SerializeToString() - - -class PassDesc: - class AttrHelper: - def __init__(self, obj, name, element_index=None): - self._obj = obj - self._name = name - self._operation_type = None - self._element_index = element_index - self._elements = list() - self._operation = None - self._condition = None - self._mapped = None - - def __getitem__(self, index): - element = PassDesc.AttrHelper( - self._obj, self._name, element_index=index - ) - self._elements.append(element) - return element - - def _to_pass_desc_attr(self, pass_desc_attr): - if isinstance(self._obj, PassDesc.VarHelper): - pass_desc_attr.role = pass_desc_pb2.PassDesc.RoleType.kVariable - pass_desc_attr.var_name = self._obj.name - else: - pass_desc_attr.role = pass_desc_pb2.PassDesc.RoleType.kOperator - pass_desc_attr.op_index = self._obj._index - pass_desc_attr.name = self._name - if self._operation_type is not None: - pass_desc_attr.operation = self._operation_type - if self._element_index is not None: - pass_desc_attr.element_index = self._element_index - - def _to_op_desc_attr(self, value, op_desc_attr): - op_desc_attr.name = "" - if isinstance(value, int): - op_desc_attr.type = framework_pb2.AttrType.INT - op_desc_attr.i = value - else: - raise NotImplementedError("Unimplemented transform operation.") - - def _clone_with_operation(self, type, value=None): - attr = PassDesc.AttrHelper( - self._obj, self._name, self._element_index - ) - self._elements.append(attr) - if value is None: - attr._operation_type = type - return attr - operation = pass_desc_pb2.PassDesc.Operation() - operation.type = type - if isinstance(value, PassDesc.AttrHelper): - value._to_pass_desc_attr(operation.attr) - else: - self._to_op_desc_attr(value, operation.value) - attr._operation = operation - attr._operation_type = self._operation_type - return attr - - def __sub__(self, value): - return self._clone_with_operation( - pass_desc_pb2.PassDesc.OperationType.kSub, value - ) - - def __add__(self, value): - return self._clone_with_operation( - pass_desc_pb2.PassDesc.OperationType.kAdd, value - ) - - def Mod(self, value): - return self._clone_with_operation( - pass_desc_pb2.PassDesc.OperationType.kMod, value - ) - - def Size(self): - return self._clone_with_operation( - pass_desc_pb2.PassDesc.OperationType.kSize - ) - - def _set_with_condition(self, type, value): - condition = pass_desc_pb2.PassDesc.AttrCondition() - self._to_pass_desc_attr(condition.attr) - condition.type = type - if isinstance(value, PassDesc.AttrHelper): - value._to_pass_desc_attr(condition.condition_attr) - else: - self._to_op_desc_attr(value, condition.condition_value) - if self._operation: - condition.operation.CopyFrom(self._operation) - self._condition = condition - - def EQ(self, value): - self._set_with_condition( - pass_desc_pb2.PassDesc.ConditionType.kEQ, value - ) - - def MappedPattern( - self, var=None, op=None, index=0, name=None, element_index=None - ): - if all([var, op]): - raise ValueError("Only mapped one of which var or op.") - - def mapped_var(pattern_ops): - raise NotImplementedError( - "Mapping to variable is not implemented." - ) - - def mapped_op(pattern_ops): - ops = [o for o in pattern_ops if o._type == op] - if len(ops) <= index: - raise ValueError( - "Index '{}' of operator '{}' is incorrect.".format( - index, op - ) - ) - return PassDesc.AttrHelper( - ops[index], name, element_index=element_index - ) - - self._mapped = mapped_op if var is None else mapped_var - - class VarHelper(paddle.static.Variable): - def __init__(self, *args, **kwargs): - block = paddle.static.default_main_program().current_block() - self._var = paddle.static.data(*args, **kwargs) - self._attrs = dict() - - def __getattr__(self, name): - return getattr(self._var, name) - - def Attr(self, name): - attr = self._attrs.get(name) - if attr is None: - attr = PassDesc.AttrHelper(self, name) - self._attrs[name] = attr - return attr - - class OpHelper: - def __init__(self, type=None): - self._type = type - - def __getattr__(self, name): - op = PassDesc.OpHelper(name) - op.Init() - return op - - def __call__(self, *args, **kwargs): - if len(args) > 0: - raise ValueError( - "Each input argument needs to specify a parameter name." - ) - for (in_name, in_args) in kwargs.items(): - op_input = self._inputs.get(in_name) - if op_input is None: - raise ValueError( - "Operator '{}' does not have input named '{}'.".format( - self._type, in_name - ) - ) - if isinstance(in_args, (list, tuple)): - if len(in_args) == 0: - raise ValueError( - "Input '{}' of operator '{}' cannot be empty.".format( - in_name, self._type - ) - ) - else: - in_args = [in_args] - for in_arg in in_args: - if isinstance(in_arg, PassDesc.OpHelper): - op_outs = in_arg.Outputs() - if len(op_outs) != 1: - raise ValueError( - "The size of outputs of operator '{}' is not equal 1, please specify one output variable.".format( - in_arg._type - ) - ) - for op_out in op_outs.values(): - op_input.extend(op_out) - else: - op_input.append(in_arg) - self._desc.set_input(in_name, [i.name for i in op_input]) - block = paddle.static.default_main_program().current_block() - for out_name, op_output in self._outputs.items(): - op_output_name = unique_name.generate(self._type) - op_output.append(block.create_var(name=op_output_name)) - self._desc.set_output(out_name, [op_output_name]) - return self - - def Init(self): - block = paddle.static.default_main_program().current_block() - self._proto = OpProtoHolder.instance().op_proto_map.get(self._type) - if self._proto is None: - raise AttributeError( - "type object 'OpHelper' has no attribute '{}'".format( - self._type - ) - ) - self._index = len(block.ops) - self._desc = block.desc.append_op() - self._desc.set_type(self._type) - self._attrs = dict() - self._inputs = {i.name: list() for i in self._proto.inputs} - self._outputs = {o.name: list() for o in self._proto.outputs} - block.ops.append(self) - - def Attr(self, name): - attr = self._attrs.get(name) - if attr is None: - attr = PassDesc.AttrHelper(self, name) - self._attrs[name] = attr - return attr - - def SetAttr(self, name, value): - if isinstance(value, PassDesc.AttrHelper): - self.Attr(name)._mapped = value - else: - self._desc._set_attr(name, value) - - def Output(self, name): - output = self._outputs.get(name) - if output is None: - raise ValueError( - "Operator '{}' does not have output named '{}'.".format( - self._type, name - ) - ) - return output - - def Outputs(self): - return self._outputs - - def SetOutputs(self, **kwargs): - for param, arg in kwargs.items(): - if arg is None: - self._desc.remove_output(param) - else: - self._desc.set_output(param, [arg.name]) - - OP = OpHelper() - - -def RegisterPass(function=None, input_specs=dict()): - """ - The function decorator of Register Pass. Decorator @RegisterPass handles - the function and register it into a core.Pass instance. Use name of function - as Pass type. - - Args: - function (callable): The function with return of callable pair(s) that - represents the pattern subgraph and the replace subgraph. - input_specs (dict[str, InputSpec]): Dict of InputSpec to specific the shape/dtype - information of Tensor. Some operators limit the shape and dtype of datas when - create subgraph with Paddle APIs. So user need specify InputSpec of data to - ensure create a correctly subgraph. Of course, this argument is not limited to - matching subgraph. The default is dict(). - - Returns: - callables: Callable pair(s). - - Examples: - .. code-block:: python - - import paddle - from paddle.fluid.ir import RegisterPass - - @RegisterPass - def multi_add_to_addn(): - def pattern(x, y, z): - return paddle.add(paddle.add(x, y), z) - def replace(x, y, z): - return paddle.add_n([x, y, z]) - return pattern, replace - """ - - def _is_pass_pair(check_pair): - if isinstance(check_pair, (list, tuple)): - if len(check_pair) == 2: - if all(map(inspect.isfunction, check_pair)): - return True - return False - - def decorated(python_func): - pass_type = python_func.__name__ - signature = inspect.signature(python_func) - if len(signature.parameters) > 0: - raise NotImplementedError( - "Pass function with parameter is not supported now." - ) - elif len(signature.parameters) == 0: - pass_pairs = python_func() - if _is_pass_pair(pass_pairs): - pass_pairs = [pass_pairs] - elif not all(map(_is_pass_pair, pass_pairs)): - raise ValueError( - "Return value of Pass function must be (callable, callable)." - ) - helper = RegisterPassHelper(pass_pairs, pass_type, input_specs) - core.register_pass(pass_type, helper.SerializeMultiPassDesc) - return python_func - - if inspect.isfunction(function): - return decorated(function) - - return decorated diff --git a/python/paddle/fluid/tests/unittests/ir/test_ir_generate_pass.py b/python/paddle/fluid/tests/unittests/ir/test_ir_generate_pass.py index 2025f94ffd..2f3a2f2d77 100644 --- a/python/paddle/fluid/tests/unittests/ir/test_ir_generate_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/test_ir_generate_pass.py @@ -17,7 +17,8 @@ import unittest import numpy as np import paddle -from paddle.fluid import core, ir +from paddle.fluid import core +from paddle.incubate.passes import ir from paddle.static import InputSpec diff --git a/python/paddle/incubate/passes/fuse_resnet_unit_pass.py b/python/paddle/incubate/passes/fuse_resnet_unit_pass.py index 6441427f46..7acf28eecb 100644 --- a/python/paddle/incubate/passes/fuse_resnet_unit_pass.py +++ b/python/paddle/incubate/passes/fuse_resnet_unit_pass.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import paddle.fluid.ir as ir +import paddle.incubate.passes.ir as ir def set_resnet_unit_attrs(resnet_unit, has_shortcut): diff --git a/python/paddle/incubate/passes/ir.py b/python/paddle/incubate/passes/ir.py new file mode 100644 index 0000000000..cf6568a545 --- /dev/null +++ b/python/paddle/incubate/passes/ir.py @@ -0,0 +1,483 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from os import path + +import paddle +from paddle.fluid.proto import framework_pb2 + +from ...fluid import core, unique_name +from ...fluid.framework import OpProtoHolder + +try: + from paddle.fluid.proto import pass_desc_pb2 +except ModuleNotFoundError: + import sys + + fluid_path = path.dirname(__file__) + '/../../fluid' + sys.path.append(path.join(fluid_path, 'proto')) + from paddle.fluid.proto import pass_desc_pb2 + + +class RegisterPassHelper: + _register_helpers = list() + + def __init__(self, pass_pairs, pass_type=str(), input_specs=dict()): + self._pass_type = pass_type + self._pass_pairs = pass_pairs + self._input_specs = input_specs + RegisterPassHelper._register_helpers.append(self) + + def _get_args_from_func(self, func): + args = list() + arg_specs = inspect.getfullargspec(func) + for arg_name in arg_specs.args: + input_spec = self._input_specs.get(arg_name) + if isinstance(input_spec, paddle.static.InputSpec): + args.append( + PassDesc.VarHelper( + arg_name, input_spec.shape, input_spec.dtype + ) + ) + elif isinstance(input_spec, paddle.ParamAttr): + args.append(paddle.ParamAttr(arg_name)) + else: + args.append(PassDesc.VarHelper(arg_name, [-1])) + return args + + def _prune_program_desc(self, ops): + for op_desc in ops: + default_attrs = core.get_op_attrs_default_value( + op_desc.type.encode() + ) + remove_attrs = list() + for attr in op_desc.attrs: + # attr must not in + if attr.name not in [ + "op_namescope", + "op_callstack", + "op_device", + ]: + attr_list_fields = attr.ListFields() + # attr format must be: name, type, value + if len(attr_list_fields) == 3: + attr_value = attr.ListFields()[-1][-1] + default_attr_value = default_attrs.get(attr.name) + # value must not default + if default_attr_value != attr_value: + continue + remove_attrs.append(attr) + for attr in remove_attrs: + op_desc.attrs.remove(attr) + + def _func_to_program_desc(self, func, ops): + vars = list() + program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(program, startup_program): + args = self._get_args_from_func(func) + vars.extend(args) + outs = func(*args) + if not isinstance(outs, (list, tuple)): + outs = [outs] + for out in outs: + if isinstance(out, PassDesc.OpHelper): + op_outs = out.Outputs() + if len(op_outs) != 1: + raise ValueError( + "Operator '{}' has multiple outputs, please specify one output variable.".format( + out._type + ) + ) + for op_out in op_outs.values(): + vars.extend(op_out) + else: + vars.append(out) + block_desc = program.current_block().desc + for i in range(block_desc.op_size()): + ops.add().ParseFromString(block_desc.op(i).serialize_to_string()) + self._prune_program_desc(ops) + return vars, program.current_block().ops + + def _convert_vars_to_pass_desc(self, patterns, replaces, desc): + def _add_element_conditions(conditions, elements): + for element in elements: + if element._condition: + conditions.append(element._condition) + _add_element_conditions(conditions, element._elements) + + for (pattern, replace) in zip(patterns, replaces): + # Convert maps of inputs and outputs. + var_map = desc.var_maps.add() + var_map.pattern_var = pattern.name + var_map.replace_var = replace.name + conditions = desc.var_attr_conditions + # Convert shape condition. + if pattern.name in self._input_specs: + condition = conditions.add() + pattern.Attr("shape")._to_pass_desc_attr(condition.attr) + condition.condition_value.name = "" + condition.condition_value.type = framework_pb2.AttrType.LONGS + condition.condition_value.longs.extend(pattern.shape) + condition.type = pass_desc_pb2.PassDesc.ConditionType.kEQ + # Convert attr conditions. + if PassDesc.VarHelper == pattern.__class__: + for attr in pattern._attrs.values(): + _add_element_conditions(conditions, [attr]) + + def _convert_ops_to_pass_desc(self, patterns, replaces, desc): + for replace in replaces: + if isinstance(replace, PassDesc.OpHelper): + for attr in replace._attrs.values(): + # Convert attr maps. + mapped = attr._mapped + if inspect.isfunction(mapped): + mapped = mapped(patterns) + attr_map = desc.op_attr_maps.add() + mapped._to_pass_desc_attr(attr_map.pattern_attr) + attr._to_pass_desc_attr(attr_map.replace_attr) + if mapped._operation is not None: + attr_map.operation.CopyFrom(mapped._operation) + + def SerializeMultiPassDesc(self): + switch_static_mode = paddle.in_dynamic_mode() + if switch_static_mode: + paddle.enable_static() + multi_pass_desc = pass_desc_pb2.MultiPassDesc() + multi_pass_desc.pass_type = self._pass_type + # Traverse all pass pairs and convert them to PassDesc data. + # Here need to add cache in the future. + for (pattern, replace) in self._pass_pairs: + pass_desc = multi_pass_desc.pass_descs.add() + # Convert ProgramDescs of pattern and replace subgraphs. + pattern_vars, pattern_ops = self._func_to_program_desc( + pattern, pass_desc.pattern + ) + replace_vars, replace_ops = self._func_to_program_desc( + replace, pass_desc.replace + ) + self._convert_vars_to_pass_desc( + pattern_vars, replace_vars, pass_desc + ) + self._convert_ops_to_pass_desc(pattern_ops, replace_ops, pass_desc) + if switch_static_mode: + paddle.disable_static() + return multi_pass_desc.SerializeToString() + + +class PassDesc: + class AttrHelper: + def __init__(self, obj, name, element_index=None): + self._obj = obj + self._name = name + self._operation_type = None + self._element_index = element_index + self._elements = list() + self._operation = None + self._condition = None + self._mapped = None + + def __getitem__(self, index): + element = PassDesc.AttrHelper( + self._obj, self._name, element_index=index + ) + self._elements.append(element) + return element + + def _to_pass_desc_attr(self, pass_desc_attr): + if isinstance(self._obj, PassDesc.VarHelper): + pass_desc_attr.role = pass_desc_pb2.PassDesc.RoleType.kVariable + pass_desc_attr.var_name = self._obj.name + else: + pass_desc_attr.role = pass_desc_pb2.PassDesc.RoleType.kOperator + pass_desc_attr.op_index = self._obj._index + pass_desc_attr.name = self._name + if self._operation_type is not None: + pass_desc_attr.operation = self._operation_type + if self._element_index is not None: + pass_desc_attr.element_index = self._element_index + + def _to_op_desc_attr(self, value, op_desc_attr): + op_desc_attr.name = "" + if isinstance(value, int): + op_desc_attr.type = framework_pb2.AttrType.INT + op_desc_attr.i = value + else: + raise NotImplementedError("Unimplemented transform operation.") + + def _clone_with_operation(self, type, value=None): + attr = PassDesc.AttrHelper( + self._obj, self._name, self._element_index + ) + self._elements.append(attr) + if value is None: + attr._operation_type = type + return attr + operation = pass_desc_pb2.PassDesc.Operation() + operation.type = type + if isinstance(value, PassDesc.AttrHelper): + value._to_pass_desc_attr(operation.attr) + else: + self._to_op_desc_attr(value, operation.value) + attr._operation = operation + attr._operation_type = self._operation_type + return attr + + def __sub__(self, value): + return self._clone_with_operation( + pass_desc_pb2.PassDesc.OperationType.kSub, value + ) + + def __add__(self, value): + return self._clone_with_operation( + pass_desc_pb2.PassDesc.OperationType.kAdd, value + ) + + def Mod(self, value): + return self._clone_with_operation( + pass_desc_pb2.PassDesc.OperationType.kMod, value + ) + + def Size(self): + return self._clone_with_operation( + pass_desc_pb2.PassDesc.OperationType.kSize + ) + + def _set_with_condition(self, type, value): + condition = pass_desc_pb2.PassDesc.AttrCondition() + self._to_pass_desc_attr(condition.attr) + condition.type = type + if isinstance(value, PassDesc.AttrHelper): + value._to_pass_desc_attr(condition.condition_attr) + else: + self._to_op_desc_attr(value, condition.condition_value) + if self._operation: + condition.operation.CopyFrom(self._operation) + self._condition = condition + + def EQ(self, value): + self._set_with_condition( + pass_desc_pb2.PassDesc.ConditionType.kEQ, value + ) + + def MappedPattern( + self, var=None, op=None, index=0, name=None, element_index=None + ): + if all([var, op]): + raise ValueError("Only mapped one of which var or op.") + + def mapped_var(pattern_ops): + raise NotImplementedError( + "Mapping to variable is not implemented." + ) + + def mapped_op(pattern_ops): + ops = [o for o in pattern_ops if o._type == op] + if len(ops) <= index: + raise ValueError( + "Index '{}' of operator '{}' is incorrect.".format( + index, op + ) + ) + return PassDesc.AttrHelper( + ops[index], name, element_index=element_index + ) + + self._mapped = mapped_op if var is None else mapped_var + + class VarHelper(paddle.static.Variable): + def __init__(self, *args, **kwargs): + block = paddle.static.default_main_program().current_block() + self._var = paddle.static.data(*args, **kwargs) + self._attrs = dict() + + def __getattr__(self, name): + return getattr(self._var, name) + + def Attr(self, name): + attr = self._attrs.get(name) + if attr is None: + attr = PassDesc.AttrHelper(self, name) + self._attrs[name] = attr + return attr + + class OpHelper: + def __init__(self, type=None): + self._type = type + + def __getattr__(self, name): + op = PassDesc.OpHelper(name) + op.Init() + return op + + def __call__(self, *args, **kwargs): + if len(args) > 0: + raise ValueError( + "Each input argument needs to specify a parameter name." + ) + for (in_name, in_args) in kwargs.items(): + op_input = self._inputs.get(in_name) + if op_input is None: + raise ValueError( + "Operator '{}' does not have input named '{}'.".format( + self._type, in_name + ) + ) + if isinstance(in_args, (list, tuple)): + if len(in_args) == 0: + raise ValueError( + "Input '{}' of operator '{}' cannot be empty.".format( + in_name, self._type + ) + ) + else: + in_args = [in_args] + for in_arg in in_args: + if isinstance(in_arg, PassDesc.OpHelper): + op_outs = in_arg.Outputs() + if len(op_outs) != 1: + raise ValueError( + "The size of outputs of operator '{}' is not equal 1, please specify one output variable.".format( + in_arg._type + ) + ) + for op_out in op_outs.values(): + op_input.extend(op_out) + else: + op_input.append(in_arg) + self._desc.set_input(in_name, [i.name for i in op_input]) + block = paddle.static.default_main_program().current_block() + for out_name, op_output in self._outputs.items(): + op_output_name = unique_name.generate(self._type) + op_output.append(block.create_var(name=op_output_name)) + self._desc.set_output(out_name, [op_output_name]) + return self + + def Init(self): + block = paddle.static.default_main_program().current_block() + self._proto = OpProtoHolder.instance().op_proto_map.get(self._type) + if self._proto is None: + raise AttributeError( + "type object 'OpHelper' has no attribute '{}'".format( + self._type + ) + ) + self._index = len(block.ops) + self._desc = block.desc.append_op() + self._desc.set_type(self._type) + self._attrs = dict() + self._inputs = {i.name: list() for i in self._proto.inputs} + self._outputs = {o.name: list() for o in self._proto.outputs} + block.ops.append(self) + + def Attr(self, name): + attr = self._attrs.get(name) + if attr is None: + attr = PassDesc.AttrHelper(self, name) + self._attrs[name] = attr + return attr + + def SetAttr(self, name, value): + if isinstance(value, PassDesc.AttrHelper): + self.Attr(name)._mapped = value + else: + self._desc._set_attr(name, value) + + def Output(self, name): + output = self._outputs.get(name) + if output is None: + raise ValueError( + "Operator '{}' does not have output named '{}'.".format( + self._type, name + ) + ) + return output + + def Outputs(self): + return self._outputs + + def SetOutputs(self, **kwargs): + for param, arg in kwargs.items(): + if arg is None: + self._desc.remove_output(param) + else: + self._desc.set_output(param, [arg.name]) + + OP = OpHelper() + + +def RegisterPass(function=None, input_specs=dict()): + """ + The function decorator of Register Pass. Decorator @RegisterPass handles + the function and register it into a core.Pass instance. Use name of function + as Pass type. + + Args: + function (callable): The function with return of callable pair(s) that + represents the pattern subgraph and the replace subgraph. + input_specs (dict[str, InputSpec]): Dict of InputSpec to specific the shape/dtype + information of Tensor. Some operators limit the shape and dtype of datas when + create subgraph with Paddle APIs. So user need specify InputSpec of data to + ensure create a correctly subgraph. Of course, this argument is not limited to + matching subgraph. The default is dict(). + + Returns: + callables: Callable pair(s). + + Examples: + .. code-block:: python + + import paddle + from paddle.fluid.ir import RegisterPass + + @RegisterPass + def multi_add_to_addn(): + def pattern(x, y, z): + return paddle.add(paddle.add(x, y), z) + def replace(x, y, z): + return paddle.add_n([x, y, z]) + return pattern, replace + """ + + def _is_pass_pair(check_pair): + if isinstance(check_pair, (list, tuple)): + if len(check_pair) == 2: + if all(map(inspect.isfunction, check_pair)): + return True + return False + + def decorated(python_func): + pass_type = python_func.__name__ + signature = inspect.signature(python_func) + if len(signature.parameters) > 0: + raise NotImplementedError( + "Pass function with parameter is not supported now." + ) + elif len(signature.parameters) == 0: + pass_pairs = python_func() + if _is_pass_pair(pass_pairs): + pass_pairs = [pass_pairs] + elif not all(map(_is_pass_pair, pass_pairs)): + raise ValueError( + "Return value of Pass function must be (callable, callable)." + ) + helper = RegisterPassHelper(pass_pairs, pass_type, input_specs) + core.register_pass(pass_type, helper.SerializeMultiPassDesc) + return python_func + + if inspect.isfunction(function): + return decorated(function) + + return decorated -- GitLab