From 5c569aefa7cbb2a622462b31f22bddba2c41c9ae Mon Sep 17 00:00:00 2001 From: wuhuanzhou Date: Wed, 27 Oct 2021 20:01:52 +0800 Subject: [PATCH] GeneratePass support attr condition and mapping (#36747) * GeneratePass support attr condition and mapping, test=develop * fix coverage, test=develop --- paddle/fluid/framework/ir/generate_pass.cc | 175 +++++++--- paddle/fluid/framework/pass_desc.proto | 62 +++- python/paddle/fluid/ir.py | 319 +++++++++++++----- .../unittests/ir/test_ir_generate_pass.py | 141 ++++++-- 4 files changed, 529 insertions(+), 168 deletions(-) diff --git a/paddle/fluid/framework/ir/generate_pass.cc b/paddle/fluid/framework/ir/generate_pass.cc index b261cbeb08e..3f9ad5b2c52 100644 --- a/paddle/fluid/framework/ir/generate_pass.cc +++ b/paddle/fluid/framework/ir/generate_pass.cc @@ -19,21 +19,63 @@ namespace paddle { namespace framework { namespace ir { -void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) { - const proto::BlockDesc& block = pass_desc.pattern().blocks(0); - for (const proto::VarDesc& var : block.vars()) { - PDNode* var_pdnode = pattern->NewNode(var.name())->AsInput(); - var_pdnode->assert_is_var(); - var_pdnode->assert_more([&](Node* x) { - if (VarDesc(var).GetShape() == x->Var()->GetShape()) { - return true; +class operation_visitor : public boost::static_visitor { + public: + explicit operation_visitor(const proto::PassDesc::OperationType& type) + : type_(type) {} + + template + Attribute operator()(const T1& attr, const T2& operation) const { + PADDLE_THROW(platform::errors::Unimplemented("Unimplemented operand.")); + } + + template ::value || + std::is_floating_point::value>* = nullptr> + Attribute operator()(const T& attr, const T& operation) const { + switch (type_) { + case proto::PassDesc_OperationType_kSub: { + return attr - operation; + } + + default: + PADDLE_THROW( + platform::errors::Unimplemented("Unimplemented operation type.")); + } + } + + private: + proto::PassDesc::OperationType type_; +}; + +Attribute GetVarAttrValue(const VarDesc* desc, + const proto::PassDesc::Attr& attr) { + if ("shape" == attr.name()) { + std::vector shape = desc->GetShape(); + if (attr.has_operation()) { + if (attr.operation() == proto::PassDesc_OperationType_kSize) { + return static_cast(shape.size()); + } + } else if (attr.has_element_index()) { + int element_index = attr.element_index(); + if (attr.element_index() < 0) { + element_index += shape.size(); } - return false; - }); + if (element_index >= 0 && + static_cast(element_index) < shape.size()) { + return static_cast(shape[element_index]); + } + } else { + return shape; + } } + return boost::blank(); +} + +void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) { // Traverse all operators to create subgraph. - for (int index = 0; index < block.ops_size(); ++index) { - const proto::OpDesc& op = block.ops(index); + for (int index = 0; index < pass_desc.pattern_size(); ++index) { + const proto::OpDesc& op = pass_desc.pattern(index); // Create a PDNode for current operator. Use the index as name to avoid // multiple operators with same type. Get a PDNode from pattern subgraph // through index in rewrite phase. @@ -116,6 +158,23 @@ void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) { }); } } + for (const auto& condition : pass_desc.var_attr_conditions()) { + if (condition.has_condition_value()) { + PDNode* pdnode = pattern->RetrieveNode(condition.attr().var_name()); + pdnode->assert_more([&](Node* x) { + Attribute attr = GetVarAttrValue(x->Var(), condition.attr()); + switch (condition.type()) { + case proto::PassDesc_ConditionType_kEQ: { + return attr == GetAttrValue(condition.condition_value()); + } + + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Unimplemented condition type.")); + } + }); + } + } } // There are some duplicate patterns. @@ -176,7 +235,33 @@ GraphPatternDetector::handle_t GetGenerateRewrite( if (IsDuplicatePattern(subgraph, graph)) { return; } - const proto::BlockDesc& block = pass_desc.replace().blocks(0); + for (const auto& condition : pass_desc.var_attr_conditions()) { + if (condition.has_condition_attr()) { + Node* node = + subgraph.at(pattern.RetrieveNode(condition.attr().var_name())); + Attribute node_attr = GetVarAttrValue(node->Var(), condition.attr()); + Attribute condition_attr; + if (condition.condition_attr().role() == + proto::PassDesc_RoleType_kVariable) { + Node* condition_node = + subgraph.at(pattern.RetrieveNode(condition.attr().var_name())); + condition_attr = GetVarAttrValue(condition_node->Var(), + condition.condition_attr()); + } else { + PADDLE_THROW( + platform::errors::Unimplemented("Unimplemented for operation.")); + } + bool check_failed = false; + if (condition.type() == proto::PassDesc_ConditionType_kEQ) { + check_failed = !(node_attr == condition_attr); + } + if (check_failed) { + VLOG(3) << "Check var [" << node->Name() << "] with attr [" + << condition.attr().name() << "] failed, skip this pattern."; + return; + } + } + } // `var_node_maps` record the mapping of variable to the pattern subgraph. std::map var_node_maps; for (const proto::PassDesc::VarMap& var_map : pass_desc.var_maps()) { @@ -184,7 +269,8 @@ GraphPatternDetector::handle_t GetGenerateRewrite( var_node_maps.insert({var_map.replace_var(), node}); } // Traverse all operators to create subgraph. - for (const proto::OpDesc& op : block.ops()) { + for (int index = 0; index < pass_desc.replace_size(); ++index) { + const proto::OpDesc& op = pass_desc.replace(index); OpDesc op_desc; std::vector in_nodes, out_nodes; op_desc.SetType(op.type()); @@ -230,6 +316,30 @@ GraphPatternDetector::handle_t GetGenerateRewrite( for (const proto::OpDesc::Attr& attr : op.attrs()) { op_desc.SetAttr(attr.name(), GetAttrValue(attr)); } + for (const auto& attr_map : pass_desc.op_attr_maps()) { + if (attr_map.replace_attr().op_index() == index) { + Attribute attr; + if (attr_map.pattern_attr().role() == + proto::PassDesc_RoleType_kVariable) { + Node* condition_node = subgraph.at( + pattern.RetrieveNode(attr_map.pattern_attr().var_name())); + attr = + GetVarAttrValue(condition_node->Var(), attr_map.pattern_attr()); + } else { + Node* condition_node = subgraph.at(pattern.RetrieveNode( + std::to_string(attr_map.pattern_attr().op_index()))); + attr = + condition_node->Op()->GetAttr(attr_map.pattern_attr().name()); + } + if (attr_map.has_operation()) { + Attribute operation = GetAttrValue(attr_map.operation().value()); + attr = boost::apply_visitor( + operation_visitor(attr_map.operation().type()), attr, + operation); + } + op_desc.SetAttr(attr_map.replace_attr().name(), attr); + } + } // Create a Node for current operator. Node* op_node = graph->CreateOpNode(&op_desc); for (Node* node : in_nodes) { @@ -266,7 +376,7 @@ void GeneratePass::ApplyImpl(Graph* graph) const { for (const proto::PassDesc& pass_desc : multi_pass_desc_.pass_descs()) { GraphPatternDetector detector; InitGeneratePattern(pass_desc, detector.mutable_pattern()); - if (pass_desc.replace().blocks(0).ops_size() == 0) { + if (pass_desc.replace_size() == 0) { detector(graph, GetGenerateDelete(detector.pattern(), pass_desc)); } else { detector(graph, GetGenerateRewrite(detector.pattern(), pass_desc)); @@ -282,37 +392,6 @@ void GeneratePass::VerifyDesc() const { PADDLE_ENFORCE_NE(multi_pass_desc_.pass_descs_size(), 0, platform::errors::InvalidArgument( "Size of PassDesc should not be empty.")); - for (const proto::PassDesc& pass_desc : multi_pass_desc_.pass_descs()) { - // Check inputs/outputs of subgraph should in `var_maps`. - std::set pattern_var_sets, replace_var_sets; - for (const proto::PassDesc::VarMap& var_map : pass_desc.var_maps()) { - pattern_var_sets.emplace(var_map.pattern_var()); - replace_var_sets.emplace(var_map.replace_var()); - } - auto check_vars = [=](std::set* var_sets, - const proto::BlockDesc& block) { - for (const proto::OpDesc& op : block.ops()) { - for (const proto::OpDesc::Var& var : op.outputs()) { - for (const std::string& argument : var.arguments()) { - var_sets->emplace(argument); - } - } - } - for (const proto::OpDesc& op : block.ops()) { - for (const proto::OpDesc::Var& var : op.inputs()) { - for (const std::string& argument : var.arguments()) { - PADDLE_ENFORCE_NE( - var_sets->find(argument), var_sets->end(), - platform::errors::InvalidArgument( - "Subgraph of PassDesc has argument [%s] not in `var_maps`.", - argument)); - } - } - } - }; - check_vars(&pattern_var_sets, pass_desc.pattern().blocks(0)); - check_vars(&replace_var_sets, pass_desc.replace().blocks(0)); - } } bool GeneratePass::VerifyGraph(const Graph& graph) { @@ -403,8 +482,8 @@ PassPairs::PassPairs(const SubgraphType& pattern, const SubgraphType& replace) { void PassPairs::AddPassDesc(const SubgraphType& pattern, const SubgraphType& replace) { proto::PassDesc* pass_desc = multi_pass_desc_.add_pass_descs(); - pass_desc->mutable_pattern()->CopyFrom(pattern.ProgramDesc()); - pass_desc->mutable_replace()->CopyFrom(replace.ProgramDesc()); + pass_desc->mutable_pattern()->CopyFrom(pattern.ProgramDesc().blocks(0).ops()); + pass_desc->mutable_replace()->CopyFrom(replace.ProgramDesc().blocks(0).ops()); PADDLE_ENFORCE_EQ(pattern.InputVars().size(), replace.InputVars().size(), platform::errors::InvalidArgument( "Size of lambda expression arguments is not equal " diff --git a/paddle/fluid/framework/pass_desc.proto b/paddle/fluid/framework/pass_desc.proto index c95e40a1d25..86a1effb289 100644 --- a/paddle/fluid/framework/pass_desc.proto +++ b/paddle/fluid/framework/pass_desc.proto @@ -16,20 +16,68 @@ package paddle.framework.proto; // Describes one subsitute subgraph. message PassDesc { + enum RoleType { + kVariable = 0; + kOperator = 1; + } + enum OperationType { + kAdd = 0; + kSub = 1; + kMul = 2; + kDiv = 3; + kSize = 4; + } + enum ConditionType { + kEQ = 0; + kNE = 1; + kGT = 2; + kGE = 3; + kLT = 4; + kLE = 5; + } + // Representation of attr in var or operator. + message Attr { + required RoleType role = 1; + optional string var_name = 2; + optional int32 op_index = 3; + required string name = 4; + optional string element_name = 5; + optional int32 element_index = 6; + optional OperationType operation = 7; + } + // The operation to be performed. + message Operation { + required OperationType type = 1; + optional Attr attr = 2; + optional OpDesc.Attr value = 3; + } message VarMap { required string pattern_var = 1; required string replace_var = 2; } message AttrMap { - required int32 pattern_op_idx = 1; - required int32 replace_op_idx = 2; - required string pattern_name = 3; - required string replace_name = 4; + required Attr pattern_attr = 1; + required Attr replace_attr = 2; + optional Operation operation = 3; + } + message AttrCondition { + required Attr attr = 1; + required ConditionType type = 2; + optional Attr condition_attr = 3; + optional OpDesc.Attr condition_value = 4; + optional Operation operation = 5; } - required ProgramDesc pattern = 1; - required ProgramDesc replace = 2; + // A pair of subgraphs for matching and rewriting. + repeated OpDesc pattern = 1; + repeated OpDesc replace = 2; + // Mapping vars between pattern and replace subgraphs. repeated VarMap var_maps = 3; - repeated AttrMap attr_maps = 4; + // Mapping attrs of vars and ops between pattern and replace subgraphs. + repeated AttrMap var_attr_maps = 4; + repeated AttrMap op_attr_maps = 5; + // Limit the attrs of vars and ops in pattern subgraph. + repeated AttrCondition var_attr_conditions = 6; + repeated AttrCondition op_attr_conditions = 7; } // A series of PassDesc. diff --git a/python/paddle/fluid/ir.py b/python/paddle/fluid/ir.py index 3c7c8879fd4..adeab721fc2 100644 --- a/python/paddle/fluid/ir.py +++ b/python/paddle/fluid/ir.py @@ -19,6 +19,7 @@ import paddle from . import core, unique_name from .framework import _apply_pass, OpProtoHolder +from .proto import framework_pb2 try: from .proto import pass_desc_pb2 except ModuleNotFoundError: @@ -142,28 +143,21 @@ class RegisterPassHelper(object): input_spec = self._input_specs.get(arg_name) if isinstance(input_spec, paddle.static.InputSpec): args.append( - paddle.static.data(arg_name, input_spec.shape, + PassDesc.VarHelper(arg_name, input_spec.shape, input_spec.dtype)) elif isinstance(input_spec, paddle.ParamAttr): args.append(paddle.ParamAttr(arg_name)) else: - args.append(paddle.static.data(arg_name, [-1])) + args.append(PassDesc.VarHelper(arg_name, [-1])) return args - def _prune_program_desc(self, program_desc): - block_desc = program_desc.blocks[0] - # block_desc.ClearField("vars") - for var in [ - var for var in block_desc.vars - if var.name not in self._input_specs - ]: - block_desc.vars.remove(var) - for op_desc in block_desc.ops: + def _prune_program_desc(self, ops): + for op_desc in ops: default_attrs = core.get_op_attrs_default_value( paddle.compat.to_bytes(op_desc.type)) remove_attrs = list() for attr in op_desc.attrs: - # attr must not in + # attr must not in if attr.name not in [ "op_namescope", "op_callstack", "op_device" ]: @@ -179,33 +173,69 @@ class RegisterPassHelper(object): for attr in remove_attrs: op_desc.attrs.remove(attr) - def _func_to_program_desc(self, func, program_desc, is_replace=False): + def _func_to_program_desc(self, func, ops): vars = list() program = paddle.static.Program() startup_program = paddle.static.Program() with paddle.static.program_guard(program, startup_program): args = self._get_args_from_func(func) - for arg in args: - vars.append(arg.name) + vars.extend(args) outs = func(*args) if not isinstance(outs, (list, tuple)): outs = [outs] for out in outs: if isinstance(out, PassDesc.OpHelper): - for out in out.Outputs().values(): - vars.extend(out) - elif isinstance(out, paddle.fluid.framework.Variable): - vars.append(out.name) - program_desc.ParseFromString(program.desc.serialize_to_string()) - self._prune_program_desc(program_desc) - if is_replace: - attrs = list() - for op in program.current_block().ops: - if not isinstance(op, PassDesc.OpHelper): - continue - attrs.extend(op._attrs.values()) - return vars, attrs - return vars + op_outs = out.Outputs() + if len(op_outs) != 1: + raise ValueError( + "Operator '{}' has multiple outputs, please specify one output variable.". + format(out._type)) + for op_out in op_outs.values(): + vars.extend(op_out) + else: + vars.append(out) + block_desc = program.current_block().desc + for i in range(block_desc.op_size()): + ops.add().ParseFromString(block_desc.op(i).serialize_to_string()) + self._prune_program_desc(ops) + return vars, program.current_block().ops + + def _convert_vars_to_pass_desc(self, patterns, replaces, desc): + for (pattern, replace) in zip(patterns, replaces): + # Convert maps of inputs and outputs. + var_map = desc.var_maps.add() + var_map.pattern_var = pattern.name + var_map.replace_var = replace.name + conditions = desc.var_attr_conditions + # Convert shape condition. + if pattern.name in self._input_specs: + condition = conditions.add() + pattern.Attr("shape")._to_pass_desc_attr(condition.attr) + condition.condition_value.name = "" + condition.condition_value.type = framework_pb2.AttrType.LONGS + condition.condition_value.longs.extend(pattern.shape) + condition.type = pass_desc_pb2.PassDesc.ConditionType.kEQ + # Convert attr conditions. + if PassDesc.VarHelper == pattern.__class__: + for attr in pattern._attrs.values(): + if attr._condition is not None: + conditions.append(attr._condition) + conditions.extend( + [e._condition for e in attr._elements if e._condition]) + + def _convert_ops_to_pass_desc(self, patterns, replaces, desc): + for replace in replaces: + if isinstance(replace, PassDesc.OpHelper): + for attr in replace._attrs.values(): + # Convert attr maps. + mapped = attr._mapped + if inspect.isfunction(mapped): + mapped = mapped(patterns) + attr_map = desc.op_attr_maps.add() + mapped._to_pass_desc_attr(attr_map.pattern_attr) + attr._to_pass_desc_attr(attr_map.replace_attr) + if mapped._operation is not None: + attr_map.operation.CopyFrom(mapped._operation) def SerializeMultiPassDesc(self): switch_static_mode = paddle.in_dynamic_mode() @@ -213,30 +243,18 @@ class RegisterPassHelper(object): paddle.enable_static() multi_pass_desc = pass_desc_pb2.MultiPassDesc() multi_pass_desc.pass_type = self._pass_type + # Traverse all pass pairs and convert them to PassDesc data. + # Here need to add cache in the future. for (pattern, replace) in self._pass_pairs: pass_desc = multi_pass_desc.pass_descs.add() - pattern_vars = self._func_to_program_desc(pattern, - pass_desc.pattern) - replace_vars, attrs = self._func_to_program_desc( - replace, pass_desc.replace, is_replace=True) - for (pattern_var, replace_var) in zip(pattern_vars, replace_vars): - var_map = pass_desc.var_maps.add() - var_map.pattern_var = pattern_var - var_map.replace_var = replace_var - pattern_op_idxs = dict() - for (idx, op) in enumerate(pass_desc.pattern.blocks[0].ops): - op_idxs = pattern_op_idxs.get(op.type) - if op_idxs: - op_idxs.append(idx) - else: - pattern_op_idxs[op.type] = [idx] - for attr in attrs: - attr_map = pass_desc.attr_maps.add() - attr_map.pattern_op_idx = pattern_op_idxs[ - attr._pattern_op_type][attr._pattern_op_idx] - attr_map.replace_op_idx = attr._replace_op_idx - attr_map.pattern_name = attr._pattern_name - attr_map.replace_name = attr._replace_name + # Convert ProgramDescs of pattern and replace subgraphs. + pattern_vars, pattern_ops = self._func_to_program_desc( + pattern, pass_desc.pattern) + replace_vars, replace_ops = self._func_to_program_desc( + replace, pass_desc.replace) + self._convert_vars_to_pass_desc(pattern_vars, replace_vars, + pass_desc) + self._convert_ops_to_pass_desc(pattern_ops, replace_ops, pass_desc) if switch_static_mode: paddle.disable_static() return multi_pass_desc.SerializeToString() @@ -244,18 +262,119 @@ class RegisterPassHelper(object): class PassDesc(object): class AttrHelper(object): - def __init__(self, name, replace_op_idx): - self._pattern_op_type = None - self._pattern_op_idx = -1 - self._replace_op_idx = replace_op_idx - self._pattern_name = name - self._replace_name = name - - def ReusePattern(self, op, index=0, name=None): - if name: - self._pattern_name = name - self._pattern_op_type = op - self._pattern_op_idx = index + def __init__(self, obj, name, element_index=None): + self._obj = obj + self._name = name + self._operation_type = None + self._element_index = element_index + self._elements = list() + self._operation = None + self._condition = None + self._mapped = None + + def __getitem__(self, index): + element = PassDesc.AttrHelper( + self._obj, self._name, element_index=index) + self._elements.append(element) + return element + + def _to_pass_desc_attr(self, pass_desc_attr): + if isinstance(self._obj, PassDesc.VarHelper): + pass_desc_attr.role = pass_desc_pb2.PassDesc.RoleType.kVariable + pass_desc_attr.var_name = self._obj.name + else: + pass_desc_attr.role = pass_desc_pb2.PassDesc.RoleType.kOperator + pass_desc_attr.op_index = self._obj._index + pass_desc_attr.name = self._name + if self._operation_type is not None: + pass_desc_attr.operation = self._operation_type + if self._element_index is not None: + pass_desc_attr.element_index = self._element_index + + def _to_op_desc_attr(self, value, op_desc_attr): + op_desc_attr.name = "" + if isinstance(value, int): + op_desc_attr.type = framework_pb2.AttrType.INT + op_desc_attr.i = value + else: + raise NotImplementedError("Unimplemented transform operation.") + + def _clone_with_operation(self, type, value=None): + attr = PassDesc.AttrHelper(self._obj, self._name, + self._element_index) + self._elements.append(attr) + if value is None: + attr._operation_type = type + return attr + operation = pass_desc_pb2.PassDesc.Operation() + operation.type = type + if isinstance(value, PassDesc.AttrHelper): + value._to_pass_desc_attr(operation.attr) + else: + self._to_op_desc_attr(value, operation.value) + attr._operation = operation + attr._operation_type = self._operation_type + return attr + + def __sub__(self, value): + return self._clone_with_operation( + pass_desc_pb2.PassDesc.OperationType.kSub, value) + + def __add__(self, value): + return self._clone_with_operation( + pass_desc_pb2.PassDesc.OperationType.kAdd, value) + + def Size(self): + return self._clone_with_operation( + pass_desc_pb2.PassDesc.OperationType.kSize) + + def _set_with_condition(self, type, value): + condition = pass_desc_pb2.PassDesc.AttrCondition() + self._to_pass_desc_attr(condition.attr) + condition.type = type + if isinstance(value, PassDesc.AttrHelper): + value._to_pass_desc_attr(condition.condition_attr) + else: + self._to_op_desc_attr(value, condition.condition_value) + self._condition = condition + + def EQ(self, value): + self._set_with_condition(pass_desc_pb2.PassDesc.ConditionType.kEQ, + value) + + def MappedPattern(self, var=None, op=None, index=0, name=None): + if all([var, op]): + raise ValueError("Only mapped one of which var or op.") + + def mapped_var(pattern_ops): + raise NotImplementedError( + "Mapping to variable is not implemented.") + + def mapped_op(pattern_ops): + ops = [o for o in pattern_ops if o._type == op] + if len(ops) <= index: + raise ValueError( + "Index '{}' of operator '{}' is incorrect.".format( + index, op)) + return PassDesc.AttrHelper(ops[index], name) + + self._mapped = mapped_op if var is None else mapped_var + + class VarHelper(paddle.static.Variable): + def __init__(self, *args, **kwargs): + block = paddle.static.default_main_program().current_block() + self._var = paddle.static.data(*args, **kwargs) + self._attrs = dict() + + def __getattr__(self, name): + return getattr(self._var, name) + + def Attr(self, name): + attr = self._attrs.get(name) + if attr is None: + attr = PassDesc.AttrHelper(self, name) + self._attrs[name] = attr + return attr class OpHelper(object): def __init__(self, type=None): @@ -267,8 +386,15 @@ class PassDesc(object): return op def __call__(self, *args, **kwargs): + if len(args) > 0: + raise ValueError( + "Each input argument needs to specify a parameter name.") for (in_name, in_args) in kwargs.items(): - in_arg_names = list() + op_input = self._inputs.get(in_name) + if op_input is None: + raise ValueError( + "Operator '{}' does not have input named '{}'.".format( + self._type, in_name)) if isinstance(in_args, (list, tuple)): if len(in_args) == 0: raise ValueError( @@ -278,52 +404,61 @@ class PassDesc(object): in_args = [in_args] for in_arg in in_args: if isinstance(in_arg, PassDesc.OpHelper): - in_arg_names.extend(in_arg.Output()) + op_outs = in_arg.Outputs() + if len(op_outs) != 1: + raise ValueError( + "The size of outputs of operator '{}' is not equal 1, please specify one output variable.". + format(in_arg._type)) + for op_out in op_outs.values(): + op_input.extend(op_out) else: - in_arg_names.append(in_arg.name) - self._op_desc.set_input(in_name, in_arg_names) + op_input.append(in_arg) + self._desc.set_input(in_name, [i.name for i in op_input]) + block = paddle.static.default_main_program().current_block() + for out_name, op_output in self._outputs.items(): + op_output_name = unique_name.generate(self._type) + op_output.append(block.create_var(name=op_output_name)) + self._desc.set_output(out_name, [op_output_name]) return self def Init(self): block = paddle.static.default_main_program().current_block() - self._attrs = dict() - self._op_idx = len(block.ops) - self._op_desc = block.desc.append_op() - self._op_desc.set_type(self._type) - self._op_proto = OpProtoHolder.instance().op_proto_map.get( - self._type) - if self._op_proto is None: + self._proto = OpProtoHolder.instance().op_proto_map.get(self._type) + if self._proto is None: raise AttributeError( "type object 'OpHelper' has no attribute '{}'".format( self._type)) + self._index = len(block.ops) + self._desc = block.desc.append_op() + self._desc.set_type(self._type) + self._attrs = dict() + self._inputs = {i.name: list() for i in self._proto.inputs} + self._outputs = {o.name: list() for o in self._proto.outputs} block.ops.append(self) def Attr(self, name): attr = self._attrs.get(name) - if attr: - return attr - attr = PassDesc.AttrHelper(name, self._op_idx) - self._attrs[name] = attr + if attr is None: + attr = PassDesc.AttrHelper(self, name) + self._attrs[name] = attr return attr def SetAttr(self, name, value): - self._op_desc._set_attr(name, value) + if isinstance(value, PassDesc.AttrHelper): + self.Attr(name)._mapped = value + else: + self._desc._set_attr(name, value) - def Output(self, name=None): - if name: - return self.Outputs()[name] - return list(self.Outputs().values())[0] + def Output(self, name): + output = self._outputs.get(name) + if output is None: + raise ValueError( + "Operator '{}' does not have output named '{}'.".format( + self._type, name)) + return output def Outputs(self): - outputs = self._op_desc.outputs() - if len(outputs) > 0: - return outputs - block = paddle.static.default_main_program().current_block() - for output_proto in self._op_proto.outputs: - name = unique_name.generate(self._type) - block.create_var(name=name) - self._op_desc.set_output(output_proto.name, [name]) - return self._op_desc.outputs() + return self._outputs OP = OpHelper() diff --git a/python/paddle/fluid/tests/unittests/ir/test_ir_generate_pass.py b/python/paddle/fluid/tests/unittests/ir/test_ir_generate_pass.py index 61bd554ad26..2a7c2768e27 100644 --- a/python/paddle/fluid/tests/unittests/ir/test_ir_generate_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/test_ir_generate_pass.py @@ -33,12 +33,12 @@ def generate_fc_fuse(): return ewadd def replace(x, w, b): - fc = ir.PassDesc.OP.fc - fc.Attr("in_num_col_dims").ReusePattern( - "mul", name="x_num_col_dims") + fc = ir.PassDesc.OP.fc(Input=x, W=w, Bias=b) + fc.Attr("in_num_col_dims").MappedPattern( + op="mul", name="x_num_col_dims") if with_relu: fc.SetAttr("activation_type", "relu") - return fc(Input=x, W=w, Bias=b) + return fc return pattern, replace @@ -96,8 +96,8 @@ def generate_combine_mul_v1(): @ir.RegisterPass def generate_combine_mul_v2(): def pattern(x, y1, y2): - mul1 = ir.PassDesc.OP.matmul_v2(x, y1) - mul2 = ir.PassDesc.OP.matmul_v2(x, y2) + mul1 = ir.PassDesc.OP.matmul_v2(X=x, Y=y1) + mul2 = ir.PassDesc.OP.matmul_v2(X=x, Y=y2) return mul1, mul2 def replace(x, y1, y2): @@ -126,11 +126,71 @@ def generate_simplify_inference_v2(): op1 = ir.PassDesc.OP.transpose2 op2 = ir.PassDesc.OP.transpose2 # op2.Attr("axis").EQ(op1.Attr("axis")) - return op2(X=op1(X=x)) + return op2(X=op1(X=x).Output("Out")).Output("Out") return pattern, lambda x: x +@ir.RegisterPass +def generate_layer_norm_fuse_pass(): + def pattern(x, gamma, beta): + gamma.Attr("shape").Size().EQ(1) + gamma.Attr("shape")[0].EQ(x.Attr("shape")[-1]) + beta.Attr("shape").EQ(gamma.Attr("shape")) + + mean1 = ir.PassDesc.OP.reduce_mean(X=x) + mean1.SetAttr("dim", [-1]) + mean1.SetAttr("reduce_all", False) + mean1.SetAttr("keep_dim", True) + ewsub = ir.PassDesc.OP.elementwise_sub(X=x, Y=mean1) + pow = ir.PassDesc.OP.pow(X=ewsub) + pow.SetAttr("factor", 2.0) + mean2 = ir.PassDesc.OP.reduce_mean(X=pow) + mean2.SetAttr("dim", [-1]) + mean2.SetAttr("reduce_all", False) + mean2.SetAttr("keep_dim", True) + scale = ir.PassDesc.OP.scale(X=mean2) + sqrt = ir.PassDesc.OP.sqrt(X=scale) + ewdiv = ir.PassDesc.OP.elementwise_sub(X=ewsub, Y=sqrt) + ewmul = ir.PassDesc.OP.elementwise_mul(X=ewdiv, Y=gamma) + return ir.PassDesc.OP.elementwise_add(X=ewmul, Y=beta) + + def replace(x, gamma, beta): + layer_norm = ir.PassDesc.OP.layer_norm(X=x, Scale=gamma, Bias=beta) + layer_norm.SetAttr("begin_norm_axis", x.Attr("shape").Size() - 1) + layer_norm.Attr("epsilon").MappedPattern(op="scale", name="bias") + layer_norm.SetAttr("is_test", True) + return layer_norm.Output("Y") + + return pattern, replace + + +@ir.RegisterPass +def unimplemented_operand_exception(): + def pattern(x, y): + return ir.PassDesc.OP.elementwise_add(X=x, Y=y) + + def replace(x, y): + out = ir.PassDesc.OP.elementwise_add(X=x, Y=y) + out.SetAttr("axis", x.Attr("shape") - 1) + return out + + return pattern, replace + + +@ir.RegisterPass +def unimplemented_operation_exception(): + def pattern(x, y): + return ir.PassDesc.OP.elementwise_add(X=x, Y=y) + + def replace(x, y): + out = ir.PassDesc.OP.elementwise_add(X=x, Y=y) + out.SetAttr("axis", x.Attr("shape").Size() + 1) + return out + + return pattern, replace + + def get_multi_pass_desc_from_str(s): multi_pass_desc = ir.pass_desc_pb2.MultiPassDesc() multi_pass_desc.ParseFromString(s) @@ -151,12 +211,24 @@ class TestGeneratePass(unittest.TestCase): def test_has_attr(self): self.assertFalse(hasattr(ir.PassDesc.OP, '__name__')) + def test_exception(self): + paddle.enable_static() + program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(program, startup_program): + x = paddle.static.data("x", [10, 10], "float32") + y = paddle.static.data("y", [10, 10], "float32") + paddle.add(x, y) + graph = core.Graph(program.desc) + with self.assertRaises(NotImplementedError): + core.get_pass("unimplemented_operand_exception").apply(graph) + with self.assertRaises(NotImplementedError): + core.get_pass("unimplemented_operation_exception").apply(graph) + def test_generate_fc_fuse(self): def _check_fc_fuse_pass(pass_desc, with_relu): - pattern_op_dicts = self.convert_ops_to_op_dicts( - pass_desc.pattern.blocks[0].ops) - replace_op_dicts = self.convert_ops_to_op_dicts( - pass_desc.replace.blocks[0].ops) + pattern_op_dicts = self.convert_ops_to_op_dicts(pass_desc.pattern) + replace_op_dicts = self.convert_ops_to_op_dicts(pass_desc.replace) self.assertEqual(len(pattern_op_dicts.get("mul", [])), 1) self.assertEqual( len(pattern_op_dicts.get("elementwise_add", [])), 1) @@ -166,10 +238,9 @@ class TestGeneratePass(unittest.TestCase): else: pattern_op_num = 2 # ewadd, mul self.assertEqual(len(pass_desc.var_maps), 4) - self.assertEqual( - len(pass_desc.pattern.blocks[0].ops), pattern_op_num) - self.assertEqual(len(pass_desc.replace.blocks[0].ops), 1) - self.assertEqual(len(pass_desc.attr_maps), 1) + self.assertEqual(len(pass_desc.pattern), pattern_op_num) + self.assertEqual(len(pass_desc.replace), 1) + self.assertEqual(len(pass_desc.op_attr_maps), 1) helper = ir.RegisterPassHelper(generate_fc_fuse()) s = helper.SerializeMultiPassDesc() @@ -253,12 +324,10 @@ class TestGeneratePass(unittest.TestCase): self.assertEqual(len(multi_pass_desc.pass_descs), 1) pass_desc = multi_pass_desc.pass_descs[0] self.assertEqual(len(pass_desc.var_maps), 5) - self.assertEqual(len(pass_desc.pattern.blocks[0].ops), 2) - self.assertEqual(len(pass_desc.replace.blocks[0].ops), 4) - pattern_op_dicts = self.convert_ops_to_op_dicts( - pass_desc.pattern.blocks[0].ops) - replace_op_dicts = self.convert_ops_to_op_dicts( - pass_desc.replace.blocks[0].ops) + self.assertEqual(len(pass_desc.pattern), 2) + self.assertEqual(len(pass_desc.replace), 4) + pattern_op_dicts = self.convert_ops_to_op_dicts(pass_desc.pattern) + replace_op_dicts = self.convert_ops_to_op_dicts(pass_desc.replace) self.assertEqual(len(pattern_op_dicts.get("matmul_v2", [])), 2) self.assertEqual(len(replace_op_dicts.get("concat", [])), 1) self.assertEqual(len(replace_op_dicts.get("matmul_v2", [])), 1) @@ -292,3 +361,33 @@ class TestGeneratePass(unittest.TestCase): def test_generate_simplify_inference(self): self.check_generate_simplify_inference("generate_simplify_inference_v1") self.check_generate_simplify_inference("generate_simplify_inference_v2") + + def test_generate_layer_norm_fuse_pass(self): + paddle.enable_static() + program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(program, startup_program): + x = paddle.static.data("x", [3, 64, 120], "float32") + gamma = paddle.static.create_parameter( + shape=[120], dtype="float32", is_bias=True) + beta = paddle.static.create_parameter( + shape=[120], dtype="float32", is_bias=True) + + x_sub_mean = x - paddle.mean(x, axis=-1, keepdim=True) + std_dev = paddle.mean(x_sub_mean.pow(2), axis=-1, keepdim=True) + lnorm = x_sub_mean - (std_dev + 1e-5).sqrt() + out = lnorm * gamma + beta + graph = core.Graph(program.desc) + before_node_nums = len(graph.nodes()) + core.get_pass("generate_layer_norm_fuse_pass").apply(graph) + after_node_nums = len(graph.nodes()) + self.assertEqual(after_node_nums, before_node_nums - 14) + after_program = paddle.fluid.framework.IrGraph(graph).to_program() + executor = paddle.static.Executor(paddle.CPUPlace()) + executor.run(startup_program) + feed = {"x": np.random.random([3, 64, 120]).astype("float32")} + before_out = executor.run(program, feed=feed, fetch_list=[out.name]) + after_out = executor.run(after_program, + feed=feed, + fetch_list=[out.name]) + self.assertTrue(np.allclose(before_out, after_out)) -- GitLab