未验证 提交 3a7e470b 编写于 作者: Z zqw_1997 提交者: GitHub

remove fluid.ir.RegisterPassHelper PassDesc and RegisterPass (#49578)

* remove fluid.ir.RegisterPassHelper PassDesc and RegisterPass

* proto import problems

* change import way of pass_desc_pb2

* change sys.path

* change the way of import framwork_pb2

* add fluid_path directory from path.dirname

* fluid_path changed
上级 0f173d5a
...@@ -16,18 +16,8 @@ import copy ...@@ -16,18 +16,8 @@ import copy
import inspect import inspect
from os import path from os import path
import paddle import paddle
from . import core, unique_name from . import core
from .framework import _apply_pass, OpProtoHolder from .framework import _apply_pass
from .proto import framework_pb2
try:
from .proto import pass_desc_pb2
except ModuleNotFoundError:
import sys
sys.path.append(path.join(path.dirname(__file__), 'proto'))
from .proto import pass_desc_pb2
def get_data_vars(program): def get_data_vars(program):
...@@ -138,455 +128,3 @@ def apply_build_strategy( ...@@ -138,455 +128,3 @@ def apply_build_strategy(
build_strategy.enable_inplace = False build_strategy.enable_inplace = False
build_strategy._clear_finalized() build_strategy._clear_finalized()
return build_strategy return build_strategy
class RegisterPassHelper:
_register_helpers = list()
def __init__(self, pass_pairs, pass_type=str(), input_specs=dict()):
self._pass_type = pass_type
self._pass_pairs = pass_pairs
self._input_specs = input_specs
RegisterPassHelper._register_helpers.append(self)
def _get_args_from_func(self, func):
args = list()
arg_specs = inspect.getfullargspec(func)
for arg_name in arg_specs.args:
input_spec = self._input_specs.get(arg_name)
if isinstance(input_spec, paddle.static.InputSpec):
args.append(
PassDesc.VarHelper(
arg_name, input_spec.shape, input_spec.dtype
)
)
elif isinstance(input_spec, paddle.ParamAttr):
args.append(paddle.ParamAttr(arg_name))
else:
args.append(PassDesc.VarHelper(arg_name, [-1]))
return args
def _prune_program_desc(self, ops):
for op_desc in ops:
default_attrs = core.get_op_attrs_default_value(
op_desc.type.encode()
)
remove_attrs = list()
for attr in op_desc.attrs:
# attr must not in
if attr.name not in [
"op_namescope",
"op_callstack",
"op_device",
]:
attr_list_fields = attr.ListFields()
# attr format must be: name, type, value
if len(attr_list_fields) == 3:
attr_value = attr.ListFields()[-1][-1]
default_attr_value = default_attrs.get(attr.name)
# value must not default
if default_attr_value != attr_value:
continue
remove_attrs.append(attr)
for attr in remove_attrs:
op_desc.attrs.remove(attr)
def _func_to_program_desc(self, func, ops):
vars = list()
program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(program, startup_program):
args = self._get_args_from_func(func)
vars.extend(args)
outs = func(*args)
if not isinstance(outs, (list, tuple)):
outs = [outs]
for out in outs:
if isinstance(out, PassDesc.OpHelper):
op_outs = out.Outputs()
if len(op_outs) != 1:
raise ValueError(
"Operator '{}' has multiple outputs, please specify one output variable.".format(
out._type
)
)
for op_out in op_outs.values():
vars.extend(op_out)
else:
vars.append(out)
block_desc = program.current_block().desc
for i in range(block_desc.op_size()):
ops.add().ParseFromString(block_desc.op(i).serialize_to_string())
self._prune_program_desc(ops)
return vars, program.current_block().ops
def _convert_vars_to_pass_desc(self, patterns, replaces, desc):
def _add_element_conditions(conditions, elements):
for element in elements:
if element._condition:
conditions.append(element._condition)
_add_element_conditions(conditions, element._elements)
for (pattern, replace) in zip(patterns, replaces):
# Convert maps of inputs and outputs.
var_map = desc.var_maps.add()
var_map.pattern_var = pattern.name
var_map.replace_var = replace.name
conditions = desc.var_attr_conditions
# Convert shape condition.
if pattern.name in self._input_specs:
condition = conditions.add()
pattern.Attr("shape")._to_pass_desc_attr(condition.attr)
condition.condition_value.name = ""
condition.condition_value.type = framework_pb2.AttrType.LONGS
condition.condition_value.longs.extend(pattern.shape)
condition.type = pass_desc_pb2.PassDesc.ConditionType.kEQ
# Convert attr conditions.
if PassDesc.VarHelper == pattern.__class__:
for attr in pattern._attrs.values():
_add_element_conditions(conditions, [attr])
def _convert_ops_to_pass_desc(self, patterns, replaces, desc):
for replace in replaces:
if isinstance(replace, PassDesc.OpHelper):
for attr in replace._attrs.values():
# Convert attr maps.
mapped = attr._mapped
if inspect.isfunction(mapped):
mapped = mapped(patterns)
attr_map = desc.op_attr_maps.add()
mapped._to_pass_desc_attr(attr_map.pattern_attr)
attr._to_pass_desc_attr(attr_map.replace_attr)
if mapped._operation is not None:
attr_map.operation.CopyFrom(mapped._operation)
def SerializeMultiPassDesc(self):
switch_static_mode = paddle.in_dynamic_mode()
if switch_static_mode:
paddle.enable_static()
multi_pass_desc = pass_desc_pb2.MultiPassDesc()
multi_pass_desc.pass_type = self._pass_type
# Traverse all pass pairs and convert them to PassDesc data.
# Here need to add cache in the future.
for (pattern, replace) in self._pass_pairs:
pass_desc = multi_pass_desc.pass_descs.add()
# Convert ProgramDescs of pattern and replace subgraphs.
pattern_vars, pattern_ops = self._func_to_program_desc(
pattern, pass_desc.pattern
)
replace_vars, replace_ops = self._func_to_program_desc(
replace, pass_desc.replace
)
self._convert_vars_to_pass_desc(
pattern_vars, replace_vars, pass_desc
)
self._convert_ops_to_pass_desc(pattern_ops, replace_ops, pass_desc)
if switch_static_mode:
paddle.disable_static()
return multi_pass_desc.SerializeToString()
class PassDesc:
class AttrHelper:
def __init__(self, obj, name, element_index=None):
self._obj = obj
self._name = name
self._operation_type = None
self._element_index = element_index
self._elements = list()
self._operation = None
self._condition = None
self._mapped = None
def __getitem__(self, index):
element = PassDesc.AttrHelper(
self._obj, self._name, element_index=index
)
self._elements.append(element)
return element
def _to_pass_desc_attr(self, pass_desc_attr):
if isinstance(self._obj, PassDesc.VarHelper):
pass_desc_attr.role = pass_desc_pb2.PassDesc.RoleType.kVariable
pass_desc_attr.var_name = self._obj.name
else:
pass_desc_attr.role = pass_desc_pb2.PassDesc.RoleType.kOperator
pass_desc_attr.op_index = self._obj._index
pass_desc_attr.name = self._name
if self._operation_type is not None:
pass_desc_attr.operation = self._operation_type
if self._element_index is not None:
pass_desc_attr.element_index = self._element_index
def _to_op_desc_attr(self, value, op_desc_attr):
op_desc_attr.name = ""
if isinstance(value, int):
op_desc_attr.type = framework_pb2.AttrType.INT
op_desc_attr.i = value
else:
raise NotImplementedError("Unimplemented transform operation.")
def _clone_with_operation(self, type, value=None):
attr = PassDesc.AttrHelper(
self._obj, self._name, self._element_index
)
self._elements.append(attr)
if value is None:
attr._operation_type = type
return attr
operation = pass_desc_pb2.PassDesc.Operation()
operation.type = type
if isinstance(value, PassDesc.AttrHelper):
value._to_pass_desc_attr(operation.attr)
else:
self._to_op_desc_attr(value, operation.value)
attr._operation = operation
attr._operation_type = self._operation_type
return attr
def __sub__(self, value):
return self._clone_with_operation(
pass_desc_pb2.PassDesc.OperationType.kSub, value
)
def __add__(self, value):
return self._clone_with_operation(
pass_desc_pb2.PassDesc.OperationType.kAdd, value
)
def Mod(self, value):
return self._clone_with_operation(
pass_desc_pb2.PassDesc.OperationType.kMod, value
)
def Size(self):
return self._clone_with_operation(
pass_desc_pb2.PassDesc.OperationType.kSize
)
def _set_with_condition(self, type, value):
condition = pass_desc_pb2.PassDesc.AttrCondition()
self._to_pass_desc_attr(condition.attr)
condition.type = type
if isinstance(value, PassDesc.AttrHelper):
value._to_pass_desc_attr(condition.condition_attr)
else:
self._to_op_desc_attr(value, condition.condition_value)
if self._operation:
condition.operation.CopyFrom(self._operation)
self._condition = condition
def EQ(self, value):
self._set_with_condition(
pass_desc_pb2.PassDesc.ConditionType.kEQ, value
)
def MappedPattern(
self, var=None, op=None, index=0, name=None, element_index=None
):
if all([var, op]):
raise ValueError("Only mapped one of which var or op.")
def mapped_var(pattern_ops):
raise NotImplementedError(
"Mapping to variable is not implemented."
)
def mapped_op(pattern_ops):
ops = [o for o in pattern_ops if o._type == op]
if len(ops) <= index:
raise ValueError(
"Index '{}' of operator '{}' is incorrect.".format(
index, op
)
)
return PassDesc.AttrHelper(
ops[index], name, element_index=element_index
)
self._mapped = mapped_op if var is None else mapped_var
class VarHelper(paddle.static.Variable):
def __init__(self, *args, **kwargs):
block = paddle.static.default_main_program().current_block()
self._var = paddle.static.data(*args, **kwargs)
self._attrs = dict()
def __getattr__(self, name):
return getattr(self._var, name)
def Attr(self, name):
attr = self._attrs.get(name)
if attr is None:
attr = PassDesc.AttrHelper(self, name)
self._attrs[name] = attr
return attr
class OpHelper:
def __init__(self, type=None):
self._type = type
def __getattr__(self, name):
op = PassDesc.OpHelper(name)
op.Init()
return op
def __call__(self, *args, **kwargs):
if len(args) > 0:
raise ValueError(
"Each input argument needs to specify a parameter name."
)
for (in_name, in_args) in kwargs.items():
op_input = self._inputs.get(in_name)
if op_input is None:
raise ValueError(
"Operator '{}' does not have input named '{}'.".format(
self._type, in_name
)
)
if isinstance(in_args, (list, tuple)):
if len(in_args) == 0:
raise ValueError(
"Input '{}' of operator '{}' cannot be empty.".format(
in_name, self._type
)
)
else:
in_args = [in_args]
for in_arg in in_args:
if isinstance(in_arg, PassDesc.OpHelper):
op_outs = in_arg.Outputs()
if len(op_outs) != 1:
raise ValueError(
"The size of outputs of operator '{}' is not equal 1, please specify one output variable.".format(
in_arg._type
)
)
for op_out in op_outs.values():
op_input.extend(op_out)
else:
op_input.append(in_arg)
self._desc.set_input(in_name, [i.name for i in op_input])
block = paddle.static.default_main_program().current_block()
for out_name, op_output in self._outputs.items():
op_output_name = unique_name.generate(self._type)
op_output.append(block.create_var(name=op_output_name))
self._desc.set_output(out_name, [op_output_name])
return self
def Init(self):
block = paddle.static.default_main_program().current_block()
self._proto = OpProtoHolder.instance().op_proto_map.get(self._type)
if self._proto is None:
raise AttributeError(
"type object 'OpHelper' has no attribute '{}'".format(
self._type
)
)
self._index = len(block.ops)
self._desc = block.desc.append_op()
self._desc.set_type(self._type)
self._attrs = dict()
self._inputs = {i.name: list() for i in self._proto.inputs}
self._outputs = {o.name: list() for o in self._proto.outputs}
block.ops.append(self)
def Attr(self, name):
attr = self._attrs.get(name)
if attr is None:
attr = PassDesc.AttrHelper(self, name)
self._attrs[name] = attr
return attr
def SetAttr(self, name, value):
if isinstance(value, PassDesc.AttrHelper):
self.Attr(name)._mapped = value
else:
self._desc._set_attr(name, value)
def Output(self, name):
output = self._outputs.get(name)
if output is None:
raise ValueError(
"Operator '{}' does not have output named '{}'.".format(
self._type, name
)
)
return output
def Outputs(self):
return self._outputs
def SetOutputs(self, **kwargs):
for param, arg in kwargs.items():
if arg is None:
self._desc.remove_output(param)
else:
self._desc.set_output(param, [arg.name])
OP = OpHelper()
def RegisterPass(function=None, input_specs=dict()):
"""
The function decorator of Register Pass. Decorator @RegisterPass handles
the function and register it into a core.Pass instance. Use name of function
as Pass type.
Args:
function (callable): The function with return of callable pair(s) that
represents the pattern subgraph and the replace subgraph.
input_specs (dict[str, InputSpec]): Dict of InputSpec to specific the shape/dtype
information of Tensor. Some operators limit the shape and dtype of datas when
create subgraph with Paddle APIs. So user need specify InputSpec of data to
ensure create a correctly subgraph. Of course, this argument is not limited to
matching subgraph. The default is dict().
Returns:
callables: Callable pair(s).
Examples:
.. code-block:: python
import paddle
from paddle.fluid.ir import RegisterPass
@RegisterPass
def multi_add_to_addn():
def pattern(x, y, z):
return paddle.add(paddle.add(x, y), z)
def replace(x, y, z):
return paddle.add_n([x, y, z])
return pattern, replace
"""
def _is_pass_pair(check_pair):
if isinstance(check_pair, (list, tuple)):
if len(check_pair) == 2:
if all(map(inspect.isfunction, check_pair)):
return True
return False
def decorated(python_func):
pass_type = python_func.__name__
signature = inspect.signature(python_func)
if len(signature.parameters) > 0:
raise NotImplementedError(
"Pass function with parameter is not supported now."
)
elif len(signature.parameters) == 0:
pass_pairs = python_func()
if _is_pass_pair(pass_pairs):
pass_pairs = [pass_pairs]
elif not all(map(_is_pass_pair, pass_pairs)):
raise ValueError(
"Return value of Pass function must be (callable, callable)."
)
helper = RegisterPassHelper(pass_pairs, pass_type, input_specs)
core.register_pass(pass_type, helper.SerializeMultiPassDesc)
return python_func
if inspect.isfunction(function):
return decorated(function)
return decorated
...@@ -17,7 +17,8 @@ import unittest ...@@ -17,7 +17,8 @@ import unittest
import numpy as np import numpy as np
import paddle import paddle
from paddle.fluid import core, ir from paddle.fluid import core
from paddle.incubate.passes import ir
from paddle.static import InputSpec from paddle.static import InputSpec
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import paddle.fluid.ir as ir import paddle.incubate.passes.ir as ir
def set_resnet_unit_attrs(resnet_unit, has_shortcut): def set_resnet_unit_attrs(resnet_unit, has_shortcut):
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from os import path
import paddle
from paddle.fluid.proto import framework_pb2
from ...fluid import core, unique_name
from ...fluid.framework import OpProtoHolder
try:
from paddle.fluid.proto import pass_desc_pb2
except ModuleNotFoundError:
import sys
fluid_path = path.dirname(__file__) + '/../../fluid'
sys.path.append(path.join(fluid_path, 'proto'))
from paddle.fluid.proto import pass_desc_pb2
class RegisterPassHelper:
_register_helpers = list()
def __init__(self, pass_pairs, pass_type=str(), input_specs=dict()):
self._pass_type = pass_type
self._pass_pairs = pass_pairs
self._input_specs = input_specs
RegisterPassHelper._register_helpers.append(self)
def _get_args_from_func(self, func):
args = list()
arg_specs = inspect.getfullargspec(func)
for arg_name in arg_specs.args:
input_spec = self._input_specs.get(arg_name)
if isinstance(input_spec, paddle.static.InputSpec):
args.append(
PassDesc.VarHelper(
arg_name, input_spec.shape, input_spec.dtype
)
)
elif isinstance(input_spec, paddle.ParamAttr):
args.append(paddle.ParamAttr(arg_name))
else:
args.append(PassDesc.VarHelper(arg_name, [-1]))
return args
def _prune_program_desc(self, ops):
for op_desc in ops:
default_attrs = core.get_op_attrs_default_value(
op_desc.type.encode()
)
remove_attrs = list()
for attr in op_desc.attrs:
# attr must not in
if attr.name not in [
"op_namescope",
"op_callstack",
"op_device",
]:
attr_list_fields = attr.ListFields()
# attr format must be: name, type, value
if len(attr_list_fields) == 3:
attr_value = attr.ListFields()[-1][-1]
default_attr_value = default_attrs.get(attr.name)
# value must not default
if default_attr_value != attr_value:
continue
remove_attrs.append(attr)
for attr in remove_attrs:
op_desc.attrs.remove(attr)
def _func_to_program_desc(self, func, ops):
vars = list()
program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(program, startup_program):
args = self._get_args_from_func(func)
vars.extend(args)
outs = func(*args)
if not isinstance(outs, (list, tuple)):
outs = [outs]
for out in outs:
if isinstance(out, PassDesc.OpHelper):
op_outs = out.Outputs()
if len(op_outs) != 1:
raise ValueError(
"Operator '{}' has multiple outputs, please specify one output variable.".format(
out._type
)
)
for op_out in op_outs.values():
vars.extend(op_out)
else:
vars.append(out)
block_desc = program.current_block().desc
for i in range(block_desc.op_size()):
ops.add().ParseFromString(block_desc.op(i).serialize_to_string())
self._prune_program_desc(ops)
return vars, program.current_block().ops
def _convert_vars_to_pass_desc(self, patterns, replaces, desc):
def _add_element_conditions(conditions, elements):
for element in elements:
if element._condition:
conditions.append(element._condition)
_add_element_conditions(conditions, element._elements)
for (pattern, replace) in zip(patterns, replaces):
# Convert maps of inputs and outputs.
var_map = desc.var_maps.add()
var_map.pattern_var = pattern.name
var_map.replace_var = replace.name
conditions = desc.var_attr_conditions
# Convert shape condition.
if pattern.name in self._input_specs:
condition = conditions.add()
pattern.Attr("shape")._to_pass_desc_attr(condition.attr)
condition.condition_value.name = ""
condition.condition_value.type = framework_pb2.AttrType.LONGS
condition.condition_value.longs.extend(pattern.shape)
condition.type = pass_desc_pb2.PassDesc.ConditionType.kEQ
# Convert attr conditions.
if PassDesc.VarHelper == pattern.__class__:
for attr in pattern._attrs.values():
_add_element_conditions(conditions, [attr])
def _convert_ops_to_pass_desc(self, patterns, replaces, desc):
for replace in replaces:
if isinstance(replace, PassDesc.OpHelper):
for attr in replace._attrs.values():
# Convert attr maps.
mapped = attr._mapped
if inspect.isfunction(mapped):
mapped = mapped(patterns)
attr_map = desc.op_attr_maps.add()
mapped._to_pass_desc_attr(attr_map.pattern_attr)
attr._to_pass_desc_attr(attr_map.replace_attr)
if mapped._operation is not None:
attr_map.operation.CopyFrom(mapped._operation)
def SerializeMultiPassDesc(self):
switch_static_mode = paddle.in_dynamic_mode()
if switch_static_mode:
paddle.enable_static()
multi_pass_desc = pass_desc_pb2.MultiPassDesc()
multi_pass_desc.pass_type = self._pass_type
# Traverse all pass pairs and convert them to PassDesc data.
# Here need to add cache in the future.
for (pattern, replace) in self._pass_pairs:
pass_desc = multi_pass_desc.pass_descs.add()
# Convert ProgramDescs of pattern and replace subgraphs.
pattern_vars, pattern_ops = self._func_to_program_desc(
pattern, pass_desc.pattern
)
replace_vars, replace_ops = self._func_to_program_desc(
replace, pass_desc.replace
)
self._convert_vars_to_pass_desc(
pattern_vars, replace_vars, pass_desc
)
self._convert_ops_to_pass_desc(pattern_ops, replace_ops, pass_desc)
if switch_static_mode:
paddle.disable_static()
return multi_pass_desc.SerializeToString()
class PassDesc:
class AttrHelper:
def __init__(self, obj, name, element_index=None):
self._obj = obj
self._name = name
self._operation_type = None
self._element_index = element_index
self._elements = list()
self._operation = None
self._condition = None
self._mapped = None
def __getitem__(self, index):
element = PassDesc.AttrHelper(
self._obj, self._name, element_index=index
)
self._elements.append(element)
return element
def _to_pass_desc_attr(self, pass_desc_attr):
if isinstance(self._obj, PassDesc.VarHelper):
pass_desc_attr.role = pass_desc_pb2.PassDesc.RoleType.kVariable
pass_desc_attr.var_name = self._obj.name
else:
pass_desc_attr.role = pass_desc_pb2.PassDesc.RoleType.kOperator
pass_desc_attr.op_index = self._obj._index
pass_desc_attr.name = self._name
if self._operation_type is not None:
pass_desc_attr.operation = self._operation_type
if self._element_index is not None:
pass_desc_attr.element_index = self._element_index
def _to_op_desc_attr(self, value, op_desc_attr):
op_desc_attr.name = ""
if isinstance(value, int):
op_desc_attr.type = framework_pb2.AttrType.INT
op_desc_attr.i = value
else:
raise NotImplementedError("Unimplemented transform operation.")
def _clone_with_operation(self, type, value=None):
attr = PassDesc.AttrHelper(
self._obj, self._name, self._element_index
)
self._elements.append(attr)
if value is None:
attr._operation_type = type
return attr
operation = pass_desc_pb2.PassDesc.Operation()
operation.type = type
if isinstance(value, PassDesc.AttrHelper):
value._to_pass_desc_attr(operation.attr)
else:
self._to_op_desc_attr(value, operation.value)
attr._operation = operation
attr._operation_type = self._operation_type
return attr
def __sub__(self, value):
return self._clone_with_operation(
pass_desc_pb2.PassDesc.OperationType.kSub, value
)
def __add__(self, value):
return self._clone_with_operation(
pass_desc_pb2.PassDesc.OperationType.kAdd, value
)
def Mod(self, value):
return self._clone_with_operation(
pass_desc_pb2.PassDesc.OperationType.kMod, value
)
def Size(self):
return self._clone_with_operation(
pass_desc_pb2.PassDesc.OperationType.kSize
)
def _set_with_condition(self, type, value):
condition = pass_desc_pb2.PassDesc.AttrCondition()
self._to_pass_desc_attr(condition.attr)
condition.type = type
if isinstance(value, PassDesc.AttrHelper):
value._to_pass_desc_attr(condition.condition_attr)
else:
self._to_op_desc_attr(value, condition.condition_value)
if self._operation:
condition.operation.CopyFrom(self._operation)
self._condition = condition
def EQ(self, value):
self._set_with_condition(
pass_desc_pb2.PassDesc.ConditionType.kEQ, value
)
def MappedPattern(
self, var=None, op=None, index=0, name=None, element_index=None
):
if all([var, op]):
raise ValueError("Only mapped one of which var or op.")
def mapped_var(pattern_ops):
raise NotImplementedError(
"Mapping to variable is not implemented."
)
def mapped_op(pattern_ops):
ops = [o for o in pattern_ops if o._type == op]
if len(ops) <= index:
raise ValueError(
"Index '{}' of operator '{}' is incorrect.".format(
index, op
)
)
return PassDesc.AttrHelper(
ops[index], name, element_index=element_index
)
self._mapped = mapped_op if var is None else mapped_var
class VarHelper(paddle.static.Variable):
def __init__(self, *args, **kwargs):
block = paddle.static.default_main_program().current_block()
self._var = paddle.static.data(*args, **kwargs)
self._attrs = dict()
def __getattr__(self, name):
return getattr(self._var, name)
def Attr(self, name):
attr = self._attrs.get(name)
if attr is None:
attr = PassDesc.AttrHelper(self, name)
self._attrs[name] = attr
return attr
class OpHelper:
def __init__(self, type=None):
self._type = type
def __getattr__(self, name):
op = PassDesc.OpHelper(name)
op.Init()
return op
def __call__(self, *args, **kwargs):
if len(args) > 0:
raise ValueError(
"Each input argument needs to specify a parameter name."
)
for (in_name, in_args) in kwargs.items():
op_input = self._inputs.get(in_name)
if op_input is None:
raise ValueError(
"Operator '{}' does not have input named '{}'.".format(
self._type, in_name
)
)
if isinstance(in_args, (list, tuple)):
if len(in_args) == 0:
raise ValueError(
"Input '{}' of operator '{}' cannot be empty.".format(
in_name, self._type
)
)
else:
in_args = [in_args]
for in_arg in in_args:
if isinstance(in_arg, PassDesc.OpHelper):
op_outs = in_arg.Outputs()
if len(op_outs) != 1:
raise ValueError(
"The size of outputs of operator '{}' is not equal 1, please specify one output variable.".format(
in_arg._type
)
)
for op_out in op_outs.values():
op_input.extend(op_out)
else:
op_input.append(in_arg)
self._desc.set_input(in_name, [i.name for i in op_input])
block = paddle.static.default_main_program().current_block()
for out_name, op_output in self._outputs.items():
op_output_name = unique_name.generate(self._type)
op_output.append(block.create_var(name=op_output_name))
self._desc.set_output(out_name, [op_output_name])
return self
def Init(self):
block = paddle.static.default_main_program().current_block()
self._proto = OpProtoHolder.instance().op_proto_map.get(self._type)
if self._proto is None:
raise AttributeError(
"type object 'OpHelper' has no attribute '{}'".format(
self._type
)
)
self._index = len(block.ops)
self._desc = block.desc.append_op()
self._desc.set_type(self._type)
self._attrs = dict()
self._inputs = {i.name: list() for i in self._proto.inputs}
self._outputs = {o.name: list() for o in self._proto.outputs}
block.ops.append(self)
def Attr(self, name):
attr = self._attrs.get(name)
if attr is None:
attr = PassDesc.AttrHelper(self, name)
self._attrs[name] = attr
return attr
def SetAttr(self, name, value):
if isinstance(value, PassDesc.AttrHelper):
self.Attr(name)._mapped = value
else:
self._desc._set_attr(name, value)
def Output(self, name):
output = self._outputs.get(name)
if output is None:
raise ValueError(
"Operator '{}' does not have output named '{}'.".format(
self._type, name
)
)
return output
def Outputs(self):
return self._outputs
def SetOutputs(self, **kwargs):
for param, arg in kwargs.items():
if arg is None:
self._desc.remove_output(param)
else:
self._desc.set_output(param, [arg.name])
OP = OpHelper()
def RegisterPass(function=None, input_specs=dict()):
"""
The function decorator of Register Pass. Decorator @RegisterPass handles
the function and register it into a core.Pass instance. Use name of function
as Pass type.
Args:
function (callable): The function with return of callable pair(s) that
represents the pattern subgraph and the replace subgraph.
input_specs (dict[str, InputSpec]): Dict of InputSpec to specific the shape/dtype
information of Tensor. Some operators limit the shape and dtype of datas when
create subgraph with Paddle APIs. So user need specify InputSpec of data to
ensure create a correctly subgraph. Of course, this argument is not limited to
matching subgraph. The default is dict().
Returns:
callables: Callable pair(s).
Examples:
.. code-block:: python
import paddle
from paddle.fluid.ir import RegisterPass
@RegisterPass
def multi_add_to_addn():
def pattern(x, y, z):
return paddle.add(paddle.add(x, y), z)
def replace(x, y, z):
return paddle.add_n([x, y, z])
return pattern, replace
"""
def _is_pass_pair(check_pair):
if isinstance(check_pair, (list, tuple)):
if len(check_pair) == 2:
if all(map(inspect.isfunction, check_pair)):
return True
return False
def decorated(python_func):
pass_type = python_func.__name__
signature = inspect.signature(python_func)
if len(signature.parameters) > 0:
raise NotImplementedError(
"Pass function with parameter is not supported now."
)
elif len(signature.parameters) == 0:
pass_pairs = python_func()
if _is_pass_pair(pass_pairs):
pass_pairs = [pass_pairs]
elif not all(map(_is_pass_pair, pass_pairs)):
raise ValueError(
"Return value of Pass function must be (callable, callable)."
)
helper = RegisterPassHelper(pass_pairs, pass_type, input_specs)
core.register_pass(pass_type, helper.SerializeMultiPassDesc)
return python_func
if inspect.isfunction(function):
return decorated(function)
return decorated
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册