Unverified commit 5c569aef, authored by wuhuanzhou, committed by GitHub

GeneratePass support attr condition and mapping (#36747)

* GeneratePass support attr condition and mapping, test=develop

* fix coverage, test=develop
Parent: d65f41db
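This change lets a registered pass constrain variable attributes in the pattern subgraph and map or transform operator attributes into the replace subgraph. A minimal sketch of how the new Python API reads (the pass name and operators here are illustrative, modeled on the tests added in this commit, and are not part of the diff):

from paddle.fluid import ir

# Illustrative registration only; generate_fc_fuse in the tests below is the real example.
@ir.RegisterPass
def illustrative_fc_fuse():
    def pattern(x, w, b):
        mul = ir.PassDesc.OP.mul(X=x, Y=w)
        return ir.PassDesc.OP.elementwise_add(X=mul, Y=b)

    def replace(x, w, b):
        fc = ir.PassDesc.OP.fc(Input=x, W=w, Bias=b)
        # New in this commit: copy an attr from a matched pattern op into the replace op.
        fc.Attr("in_num_col_dims").MappedPattern(op="mul", name="x_num_col_dims")
        return fc

    return pattern, replace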
@@ -19,21 +19,63 @@ namespace paddle {
namespace framework {
namespace ir {
class operation_visitor : public boost::static_visitor<Attribute> {
public:
explicit operation_visitor(const proto::PassDesc::OperationType& type)
: type_(type) {}
template <typename T1, typename T2>
Attribute operator()(const T1& attr, const T2& operation) const {
PADDLE_THROW(platform::errors::Unimplemented("Unimplemented operand."));
}
template <typename T,
std::enable_if_t<std::is_integral<T>::value ||
std::is_floating_point<T>::value>* = nullptr>
Attribute operator()(const T& attr, const T& operation) const {
switch (type_) {
case proto::PassDesc_OperationType_kSub: {
return attr - operation;
}
default:
PADDLE_THROW(
platform::errors::Unimplemented("Unimplemented operation type."));
}
}
private:
proto::PassDesc::OperationType type_;
};
Attribute GetVarAttrValue(const VarDesc* desc,
const proto::PassDesc::Attr& attr) {
if ("shape" == attr.name()) {
std::vector<int64_t> shape = desc->GetShape();
if (attr.has_operation()) {
if (attr.operation() == proto::PassDesc_OperationType_kSize) {
return static_cast<int>(shape.size());
}
} else if (attr.has_element_index()) {
int element_index = attr.element_index();
if (attr.element_index() < 0) {
element_index += shape.size();
}
if (element_index >= 0 &&
static_cast<size_t>(element_index) < shape.size()) {
return static_cast<int>(shape[element_index]);
}
} else {
return shape;
}
}
return boost::blank();
}
void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) {
// Traverse all operators to create subgraph.
for (int index = 0; index < pass_desc.pattern_size(); ++index) {
const proto::OpDesc& op = pass_desc.pattern(index);
// Create a PDNode for current operator. Use the index as name to avoid
// multiple operators with same type. Get a PDNode from pattern subgraph
// through index in rewrite phase.
@@ -116,6 +158,23 @@ void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) {
});
}
}
for (const auto& condition : pass_desc.var_attr_conditions()) {
if (condition.has_condition_value()) {
PDNode* pdnode = pattern->RetrieveNode(condition.attr().var_name());
pdnode->assert_more([&](Node* x) {
Attribute attr = GetVarAttrValue(x->Var(), condition.attr());
switch (condition.type()) {
case proto::PassDesc_ConditionType_kEQ: {
return attr == GetAttrValue(condition.condition_value());
}
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unimplemented condition type."));
}
});
}
}
}
// There are some duplicate patterns.
@@ -176,7 +235,33 @@ GraphPatternDetector::handle_t GetGenerateRewrite(
if (IsDuplicatePattern(subgraph, graph)) {
return;
}
for (const auto& condition : pass_desc.var_attr_conditions()) {
if (condition.has_condition_attr()) {
Node* node =
subgraph.at(pattern.RetrieveNode(condition.attr().var_name()));
Attribute node_attr = GetVarAttrValue(node->Var(), condition.attr());
Attribute condition_attr;
if (condition.condition_attr().role() ==
proto::PassDesc_RoleType_kVariable) {
Node* condition_node =
subgraph.at(pattern.RetrieveNode(condition.attr().var_name()));
condition_attr = GetVarAttrValue(condition_node->Var(),
condition.condition_attr());
} else {
PADDLE_THROW(
platform::errors::Unimplemented("Unimplemented for operation."));
}
bool check_failed = false;
if (condition.type() == proto::PassDesc_ConditionType_kEQ) {
check_failed = !(node_attr == condition_attr);
}
if (check_failed) {
VLOG(3) << "Check var [" << node->Name() << "] with attr ["
<< condition.attr().name() << "] failed, skip this pattern.";
return;
}
}
}
// `var_node_maps` record the mapping of variable to the pattern subgraph.
std::map<std::string, Node*> var_node_maps;
for (const proto::PassDesc::VarMap& var_map : pass_desc.var_maps()) {
@@ -184,7 +269,8 @@ GraphPatternDetector::handle_t GetGenerateRewrite(
var_node_maps.insert({var_map.replace_var(), node});
}
// Traverse all operators to create subgraph.
for (int index = 0; index < pass_desc.replace_size(); ++index) {
const proto::OpDesc& op = pass_desc.replace(index);
OpDesc op_desc;
std::vector<Node *> in_nodes, out_nodes;
op_desc.SetType(op.type());
@@ -230,6 +316,30 @@ GraphPatternDetector::handle_t GetGenerateRewrite(
for (const proto::OpDesc::Attr& attr : op.attrs()) {
op_desc.SetAttr(attr.name(), GetAttrValue(attr));
}
for (const auto& attr_map : pass_desc.op_attr_maps()) {
if (attr_map.replace_attr().op_index() == index) {
Attribute attr;
if (attr_map.pattern_attr().role() ==
proto::PassDesc_RoleType_kVariable) {
Node* condition_node = subgraph.at(
pattern.RetrieveNode(attr_map.pattern_attr().var_name()));
attr =
GetVarAttrValue(condition_node->Var(), attr_map.pattern_attr());
} else {
Node* condition_node = subgraph.at(pattern.RetrieveNode(
std::to_string(attr_map.pattern_attr().op_index())));
attr =
condition_node->Op()->GetAttr(attr_map.pattern_attr().name());
}
if (attr_map.has_operation()) {
Attribute operation = GetAttrValue(attr_map.operation().value());
attr = boost::apply_visitor(
operation_visitor(attr_map.operation().type()), attr,
operation);
}
op_desc.SetAttr(attr_map.replace_attr().name(), attr);
}
}
// Create a Node for current operator.
Node* op_node = graph->CreateOpNode(&op_desc);
for (Node* node : in_nodes) {
@@ -266,7 +376,7 @@ void GeneratePass::ApplyImpl(Graph* graph) const {
for (const proto::PassDesc& pass_desc : multi_pass_desc_.pass_descs()) {
GraphPatternDetector detector;
InitGeneratePattern(pass_desc, detector.mutable_pattern());
if (pass_desc.replace_size() == 0) {
detector(graph, GetGenerateDelete(detector.pattern(), pass_desc));
} else {
detector(graph, GetGenerateRewrite(detector.pattern(), pass_desc));
@@ -282,37 +392,6 @@ void GeneratePass::VerifyDesc() const {
PADDLE_ENFORCE_NE(multi_pass_desc_.pass_descs_size(), 0,
platform::errors::InvalidArgument(
"Size of PassDesc should not be empty."));
for (const proto::PassDesc& pass_desc : multi_pass_desc_.pass_descs()) {
// Check inputs/outputs of subgraph should in `var_maps`.
std::set<std::string> pattern_var_sets, replace_var_sets;
for (const proto::PassDesc::VarMap& var_map : pass_desc.var_maps()) {
pattern_var_sets.emplace(var_map.pattern_var());
replace_var_sets.emplace(var_map.replace_var());
}
auto check_vars = [=](std::set<std::string>* var_sets,
const proto::BlockDesc& block) {
for (const proto::OpDesc& op : block.ops()) {
for (const proto::OpDesc::Var& var : op.outputs()) {
for (const std::string& argument : var.arguments()) {
var_sets->emplace(argument);
}
}
}
for (const proto::OpDesc& op : block.ops()) {
for (const proto::OpDesc::Var& var : op.inputs()) {
for (const std::string& argument : var.arguments()) {
PADDLE_ENFORCE_NE(
var_sets->find(argument), var_sets->end(),
platform::errors::InvalidArgument(
"Subgraph of PassDesc has argument [%s] not in `var_maps`.",
argument));
}
}
}
};
check_vars(&pattern_var_sets, pass_desc.pattern().blocks(0));
check_vars(&replace_var_sets, pass_desc.replace().blocks(0));
}
}
bool GeneratePass::VerifyGraph(const Graph& graph) {
@@ -403,8 +482,8 @@ PassPairs::PassPairs(const SubgraphType& pattern, const SubgraphType& replace) {
void PassPairs::AddPassDesc(const SubgraphType& pattern,
const SubgraphType& replace) {
proto::PassDesc* pass_desc = multi_pass_desc_.add_pass_descs();
pass_desc->mutable_pattern()->CopyFrom(pattern.ProgramDesc().blocks(0).ops());
pass_desc->mutable_replace()->CopyFrom(replace.ProgramDesc().blocks(0).ops());
PADDLE_ENFORCE_EQ(pattern.InputVars().size(), replace.InputVars().size(),
platform::errors::InvalidArgument(
"Size of lambda expression arguments is not equal "
......
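The operation_visitor and op_attr_maps handling above evaluate arithmetic that the Python front end records on pattern attributes. A rough sketch of that correspondence, abbreviated from the layer_norm test added later in this commit:

def replace(x, gamma, beta):
    layer_norm = ir.PassDesc.OP.layer_norm(X=x, Scale=gamma, Bias=beta)
    # x.Attr("shape").Size() records OperationType.kSize on the pattern attr;
    # the trailing "- 1" records an Operation of type kSub with an integer operand,
    # which operation_visitor applies when the replace op's attr is filled in.
    layer_norm.SetAttr("begin_norm_axis", x.Attr("shape").Size() - 1)
    layer_norm.Attr("epsilon").MappedPattern(op="scale", name="bias")
    return layer_norm.Output("Y")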
@@ -16,20 +16,68 @@ package paddle.framework.proto;
// Describes one substitute subgraph.
message PassDesc {
enum RoleType {
kVariable = 0;
kOperator = 1;
}
enum OperationType {
kAdd = 0;
kSub = 1;
kMul = 2;
kDiv = 3;
kSize = 4;
}
enum ConditionType {
kEQ = 0;
kNE = 1;
kGT = 2;
kGE = 3;
kLT = 4;
kLE = 5;
}
// Representation of attr in var or operator.
message Attr {
required RoleType role = 1;
optional string var_name = 2;
optional int32 op_index = 3;
required string name = 4;
optional string element_name = 5;
optional int32 element_index = 6;
optional OperationType operation = 7;
}
// The operation to be performed.
message Operation {
required OperationType type = 1;
optional Attr attr = 2;
optional OpDesc.Attr value = 3;
}
message VarMap {
required string pattern_var = 1;
required string replace_var = 2;
}
message AttrMap {
required Attr pattern_attr = 1;
required Attr replace_attr = 2;
optional Operation operation = 3;
}
message AttrCondition {
required Attr attr = 1;
required ConditionType type = 2;
optional Attr condition_attr = 3;
optional OpDesc.Attr condition_value = 4;
optional Operation operation = 5;
}
// A pair of subgraphs for matching and rewriting.
repeated OpDesc pattern = 1;
repeated OpDesc replace = 2;
// Mapping vars between pattern and replace subgraphs.
repeated VarMap var_maps = 3;
// Mapping attrs of vars and ops between pattern and replace subgraphs.
repeated AttrMap var_attr_maps = 4;
repeated AttrMap op_attr_maps = 5;
// Limit the attrs of vars and ops in pattern subgraph.
repeated AttrCondition var_attr_conditions = 6;
repeated AttrCondition op_attr_conditions = 7;
}
// A series of PassDesc.
......
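For reference, the Python helpers below fill these messages field by field. A hand-written sketch of a single var_attr_conditions entry, assuming the generated pass_desc_pb2 module is available; the variable name and value are illustrative, not taken from the diff:

from paddle.fluid.proto import framework_pb2, pass_desc_pb2

condition = pass_desc_pb2.PassDesc.AttrCondition()
condition.attr.role = pass_desc_pb2.PassDesc.RoleType.kVariable
condition.attr.var_name = "gamma"  # illustrative variable name
condition.attr.name = "shape"
condition.attr.operation = pass_desc_pb2.PassDesc.OperationType.kSize
condition.type = pass_desc_pb2.PassDesc.ConditionType.kEQ
condition.condition_value.name = ""
condition.condition_value.type = framework_pb2.AttrType.INT
condition.condition_value.i = 1  # require the variable to be 1-D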
@@ -19,6 +19,7 @@ import paddle
from . import core, unique_name
from .framework import _apply_pass, OpProtoHolder
from .proto import framework_pb2
try:
from .proto import pass_desc_pb2
except ModuleNotFoundError:
@@ -142,23 +143,16 @@ class RegisterPassHelper(object):
input_spec = self._input_specs.get(arg_name)
if isinstance(input_spec, paddle.static.InputSpec):
args.append(
PassDesc.VarHelper(arg_name, input_spec.shape,
input_spec.dtype))
elif isinstance(input_spec, paddle.ParamAttr):
args.append(paddle.ParamAttr(arg_name))
else:
args.append(PassDesc.VarHelper(arg_name, [-1]))
return args
def _prune_program_desc(self, ops):
for op_desc in ops:
default_attrs = core.get_op_attrs_default_value(
paddle.compat.to_bytes(op_desc.type))
remove_attrs = list()
@@ -179,33 +173,69 @@ class RegisterPassHelper(object):
for attr in remove_attrs:
op_desc.attrs.remove(attr)

def _func_to_program_desc(self, func, ops):
vars = list()
program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(program, startup_program):
args = self._get_args_from_func(func)
vars.extend(args)
outs = func(*args)
if not isinstance(outs, (list, tuple)):
outs = [outs]
for out in outs:
if isinstance(out, PassDesc.OpHelper):
op_outs = out.Outputs()
if len(op_outs) != 1:
raise ValueError(
"Operator '{}' has multiple outputs, please specify one output variable.".
format(out._type))
for op_out in op_outs.values():
vars.extend(op_out)
else:
vars.append(out)
block_desc = program.current_block().desc
for i in range(block_desc.op_size()):
ops.add().ParseFromString(block_desc.op(i).serialize_to_string())
self._prune_program_desc(ops)
return vars, program.current_block().ops
def _convert_vars_to_pass_desc(self, patterns, replaces, desc):
for (pattern, replace) in zip(patterns, replaces):
# Convert maps of inputs and outputs.
var_map = desc.var_maps.add()
var_map.pattern_var = pattern.name
var_map.replace_var = replace.name
conditions = desc.var_attr_conditions
# Convert shape condition.
if pattern.name in self._input_specs:
condition = conditions.add()
pattern.Attr("shape")._to_pass_desc_attr(condition.attr)
condition.condition_value.name = ""
condition.condition_value.type = framework_pb2.AttrType.LONGS
condition.condition_value.longs.extend(pattern.shape)
condition.type = pass_desc_pb2.PassDesc.ConditionType.kEQ
# Convert attr conditions.
if PassDesc.VarHelper == pattern.__class__:
for attr in pattern._attrs.values():
if attr._condition is not None:
conditions.append(attr._condition)
conditions.extend(
[e._condition for e in attr._elements if e._condition])
def _convert_ops_to_pass_desc(self, patterns, replaces, desc):
for replace in replaces:
if isinstance(replace, PassDesc.OpHelper):
for attr in replace._attrs.values():
# Convert attr maps.
mapped = attr._mapped
if inspect.isfunction(mapped):
mapped = mapped(patterns)
attr_map = desc.op_attr_maps.add()
mapped._to_pass_desc_attr(attr_map.pattern_attr)
attr._to_pass_desc_attr(attr_map.replace_attr)
if mapped._operation is not None:
attr_map.operation.CopyFrom(mapped._operation)
def SerializeMultiPassDesc(self):
switch_static_mode = paddle.in_dynamic_mode()
@@ -213,30 +243,18 @@ class RegisterPassHelper(object):
paddle.enable_static()
multi_pass_desc = pass_desc_pb2.MultiPassDesc()
multi_pass_desc.pass_type = self._pass_type
# Traverse all pass pairs and convert them to PassDesc data.
# Here need to add cache in the future.
for (pattern, replace) in self._pass_pairs:
pass_desc = multi_pass_desc.pass_descs.add()
# Convert ProgramDescs of pattern and replace subgraphs.
pattern_vars, pattern_ops = self._func_to_program_desc(
pattern, pass_desc.pattern)
replace_vars, replace_ops = self._func_to_program_desc(
replace, pass_desc.replace)
self._convert_vars_to_pass_desc(pattern_vars, replace_vars,
pass_desc)
self._convert_ops_to_pass_desc(pattern_ops, replace_ops, pass_desc)
if switch_static_mode:
paddle.disable_static()
return multi_pass_desc.SerializeToString()
@@ -244,18 +262,119 @@ class RegisterPassHelper(object):
class PassDesc(object):
class AttrHelper(object):
def __init__(self, obj, name, element_index=None):
self._obj = obj
self._name = name
self._operation_type = None
self._element_index = element_index
self._elements = list()
self._operation = None
self._condition = None
self._mapped = None

def __getitem__(self, index):
element = PassDesc.AttrHelper(
self._obj, self._name, element_index=index)
self._elements.append(element)
return element
def _to_pass_desc_attr(self, pass_desc_attr):
if isinstance(self._obj, PassDesc.VarHelper):
pass_desc_attr.role = pass_desc_pb2.PassDesc.RoleType.kVariable
pass_desc_attr.var_name = self._obj.name
else:
pass_desc_attr.role = pass_desc_pb2.PassDesc.RoleType.kOperator
pass_desc_attr.op_index = self._obj._index
pass_desc_attr.name = self._name
if self._operation_type is not None:
pass_desc_attr.operation = self._operation_type
if self._element_index is not None:
pass_desc_attr.element_index = self._element_index
def _to_op_desc_attr(self, value, op_desc_attr):
op_desc_attr.name = ""
if isinstance(value, int):
op_desc_attr.type = framework_pb2.AttrType.INT
op_desc_attr.i = value
else:
raise NotImplementedError("Unimplemented transform operation.")
def _clone_with_operation(self, type, value=None):
attr = PassDesc.AttrHelper(self._obj, self._name,
self._element_index)
self._elements.append(attr)
if value is None:
attr._operation_type = type
return attr
operation = pass_desc_pb2.PassDesc.Operation()
operation.type = type
if isinstance(value, PassDesc.AttrHelper):
value._to_pass_desc_attr(operation.attr)
else:
self._to_op_desc_attr(value, operation.value)
attr._operation = operation
attr._operation_type = self._operation_type
return attr
def __sub__(self, value):
return self._clone_with_operation(
pass_desc_pb2.PassDesc.OperationType.kSub, value)
def __add__(self, value):
return self._clone_with_operation(
pass_desc_pb2.PassDesc.OperationType.kAdd, value)
def Size(self):
return self._clone_with_operation(
pass_desc_pb2.PassDesc.OperationType.kSize)
def _set_with_condition(self, type, value):
condition = pass_desc_pb2.PassDesc.AttrCondition()
self._to_pass_desc_attr(condition.attr)
condition.type = type
if isinstance(value, PassDesc.AttrHelper):
value._to_pass_desc_attr(condition.condition_attr)
else:
self._to_op_desc_attr(value, condition.condition_value)
self._condition = condition
def EQ(self, value):
self._set_with_condition(pass_desc_pb2.PassDesc.ConditionType.kEQ,
value)
def MappedPattern(self, var=None, op=None, index=0, name=None):
if all([var, op]):
raise ValueError("Only mapped one of which var or op.")
def mapped_var(pattern_ops):
raise NotImplementedError(
"Mapping to variable is not implemented.")
def mapped_op(pattern_ops):
ops = [o for o in pattern_ops if o._type == op]
if len(ops) <= index:
raise ValueError(
"Index '{}' of operator '{}' is incorrect.".format(
index, op))
return PassDesc.AttrHelper(ops[index], name)
self._mapped = mapped_op if var is None else mapped_var
class VarHelper(paddle.static.Variable):
def __init__(self, *args, **kwargs):
block = paddle.static.default_main_program().current_block()
self._var = paddle.static.data(*args, **kwargs)
self._attrs = dict()
def __getattr__(self, name):
return getattr(self._var, name)
def Attr(self, name):
attr = self._attrs.get(name)
if attr is None:
attr = PassDesc.AttrHelper(self, name)
self._attrs[name] = attr
return attr
class OpHelper(object):
def __init__(self, type=None):
@@ -267,8 +386,15 @@ class PassDesc(object):
return op

def __call__(self, *args, **kwargs):
if len(args) > 0:
raise ValueError(
"Each input argument needs to specify a parameter name.")
for (in_name, in_args) in kwargs.items():
op_input = self._inputs.get(in_name)
if op_input is None:
raise ValueError(
"Operator '{}' does not have input named '{}'.".format(
self._type, in_name))
if isinstance(in_args, (list, tuple)):
if len(in_args) == 0:
raise ValueError(
@@ -278,52 +404,61 @@ class PassDesc(object):
in_args = [in_args]
for in_arg in in_args:
if isinstance(in_arg, PassDesc.OpHelper):
op_outs = in_arg.Outputs()
if len(op_outs) != 1:
raise ValueError(
"The size of outputs of operator '{}' is not equal 1, please specify one output variable.".
format(in_arg._type))
for op_out in op_outs.values():
op_input.extend(op_out)
else:
op_input.append(in_arg)
self._desc.set_input(in_name, [i.name for i in op_input])
block = paddle.static.default_main_program().current_block()
for out_name, op_output in self._outputs.items():
op_output_name = unique_name.generate(self._type)
op_output.append(block.create_var(name=op_output_name))
self._desc.set_output(out_name, [op_output_name])
return self

def Init(self):
block = paddle.static.default_main_program().current_block()
self._proto = OpProtoHolder.instance().op_proto_map.get(self._type)
if self._proto is None:
raise AttributeError(
"type object 'OpHelper' has no attribute '{}'".format(
self._type))
self._index = len(block.ops)
self._desc = block.desc.append_op()
self._desc.set_type(self._type)
self._attrs = dict()
self._inputs = {i.name: list() for i in self._proto.inputs}
self._outputs = {o.name: list() for o in self._proto.outputs}
block.ops.append(self)

def Attr(self, name):
attr = self._attrs.get(name)
if attr is None:
attr = PassDesc.AttrHelper(self, name)
self._attrs[name] = attr
return attr

def SetAttr(self, name, value):
if isinstance(value, PassDesc.AttrHelper):
self.Attr(name)._mapped = value
else:
self._desc._set_attr(name, value)

def Output(self, name):
output = self._outputs.get(name)
if output is None:
raise ValueError(
"Operator '{}' does not have output named '{}'.".format(
self._type, name))
return output

def Outputs(self):
return self._outputs

OP = OpHelper()
......
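End to end, the tests below serialize registered pass pairs with RegisterPassHelper(...).SerializeMultiPassDesc() and apply a registered generate pass to an IR graph through core.get_pass. A condensed sketch of the apply path, assuming a pass named generate_layer_norm_fuse_pass has been registered via @ir.RegisterPass as in the tests; the program contents are illustrative:

import paddle
from paddle.fluid import core

paddle.enable_static()
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
    x = paddle.static.data("x", [3, 64, 120], "float32")
    paddle.mean(x, axis=-1, keepdim=True)

# Apply a registered generate pass to the IR graph built from the program.
graph = core.Graph(main_program.desc)
core.get_pass("generate_layer_norm_fuse_pass").apply(graph)
after_program = paddle.fluid.framework.IrGraph(graph).to_program()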
@@ -33,12 +33,12 @@ def generate_fc_fuse():
return ewadd

def replace(x, w, b):
fc = ir.PassDesc.OP.fc(Input=x, W=w, Bias=b)
fc.Attr("in_num_col_dims").MappedPattern(
op="mul", name="x_num_col_dims")
if with_relu:
fc.SetAttr("activation_type", "relu")
return fc

return pattern, replace
@@ -96,8 +96,8 @@ def generate_combine_mul_v1():
@ir.RegisterPass
def generate_combine_mul_v2():
def pattern(x, y1, y2):
mul1 = ir.PassDesc.OP.matmul_v2(X=x, Y=y1)
mul2 = ir.PassDesc.OP.matmul_v2(X=x, Y=y2)
return mul1, mul2

def replace(x, y1, y2):
@@ -126,11 +126,71 @@ def generate_simplify_inference_v2():
op1 = ir.PassDesc.OP.transpose2
op2 = ir.PassDesc.OP.transpose2
# op2.Attr("axis").EQ(op1.Attr("axis"))
return op2(X=op1(X=x).Output("Out")).Output("Out")

return pattern, lambda x: x
@ir.RegisterPass
def generate_layer_norm_fuse_pass():
def pattern(x, gamma, beta):
gamma.Attr("shape").Size().EQ(1)
gamma.Attr("shape")[0].EQ(x.Attr("shape")[-1])
beta.Attr("shape").EQ(gamma.Attr("shape"))
mean1 = ir.PassDesc.OP.reduce_mean(X=x)
mean1.SetAttr("dim", [-1])
mean1.SetAttr("reduce_all", False)
mean1.SetAttr("keep_dim", True)
ewsub = ir.PassDesc.OP.elementwise_sub(X=x, Y=mean1)
pow = ir.PassDesc.OP.pow(X=ewsub)
pow.SetAttr("factor", 2.0)
mean2 = ir.PassDesc.OP.reduce_mean(X=pow)
mean2.SetAttr("dim", [-1])
mean2.SetAttr("reduce_all", False)
mean2.SetAttr("keep_dim", True)
scale = ir.PassDesc.OP.scale(X=mean2)
sqrt = ir.PassDesc.OP.sqrt(X=scale)
ewdiv = ir.PassDesc.OP.elementwise_sub(X=ewsub, Y=sqrt)
ewmul = ir.PassDesc.OP.elementwise_mul(X=ewdiv, Y=gamma)
return ir.PassDesc.OP.elementwise_add(X=ewmul, Y=beta)
def replace(x, gamma, beta):
layer_norm = ir.PassDesc.OP.layer_norm(X=x, Scale=gamma, Bias=beta)
layer_norm.SetAttr("begin_norm_axis", x.Attr("shape").Size() - 1)
layer_norm.Attr("epsilon").MappedPattern(op="scale", name="bias")
layer_norm.SetAttr("is_test", True)
return layer_norm.Output("Y")
return pattern, replace
@ir.RegisterPass
def unimplemented_operand_exception():
def pattern(x, y):
return ir.PassDesc.OP.elementwise_add(X=x, Y=y)
def replace(x, y):
out = ir.PassDesc.OP.elementwise_add(X=x, Y=y)
out.SetAttr("axis", x.Attr("shape") - 1)
return out
return pattern, replace
@ir.RegisterPass
def unimplemented_operation_exception():
def pattern(x, y):
return ir.PassDesc.OP.elementwise_add(X=x, Y=y)
def replace(x, y):
out = ir.PassDesc.OP.elementwise_add(X=x, Y=y)
out.SetAttr("axis", x.Attr("shape").Size() + 1)
return out
return pattern, replace
def get_multi_pass_desc_from_str(s):
multi_pass_desc = ir.pass_desc_pb2.MultiPassDesc()
multi_pass_desc.ParseFromString(s)
@@ -151,12 +211,24 @@ class TestGeneratePass(unittest.TestCase):
def test_has_attr(self):
self.assertFalse(hasattr(ir.PassDesc.OP, '__name__'))
def test_exception(self):
paddle.enable_static()
program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(program, startup_program):
x = paddle.static.data("x", [10, 10], "float32")
y = paddle.static.data("y", [10, 10], "float32")
paddle.add(x, y)
graph = core.Graph(program.desc)
with self.assertRaises(NotImplementedError):
core.get_pass("unimplemented_operand_exception").apply(graph)
with self.assertRaises(NotImplementedError):
core.get_pass("unimplemented_operation_exception").apply(graph)
def test_generate_fc_fuse(self):
def _check_fc_fuse_pass(pass_desc, with_relu):
pattern_op_dicts = self.convert_ops_to_op_dicts(pass_desc.pattern)
replace_op_dicts = self.convert_ops_to_op_dicts(pass_desc.replace)
self.assertEqual(len(pattern_op_dicts.get("mul", [])), 1)
self.assertEqual(
len(pattern_op_dicts.get("elementwise_add", [])), 1)
@@ -166,10 +238,9 @@ class TestGeneratePass(unittest.TestCase):
else:
pattern_op_num = 2  # ewadd, mul
self.assertEqual(len(pass_desc.var_maps), 4)
self.assertEqual(len(pass_desc.pattern), pattern_op_num)
self.assertEqual(len(pass_desc.replace), 1)
self.assertEqual(len(pass_desc.op_attr_maps), 1)

helper = ir.RegisterPassHelper(generate_fc_fuse())
s = helper.SerializeMultiPassDesc()
@@ -253,12 +324,10 @@ class TestGeneratePass(unittest.TestCase):
self.assertEqual(len(multi_pass_desc.pass_descs), 1)
pass_desc = multi_pass_desc.pass_descs[0]
self.assertEqual(len(pass_desc.var_maps), 5)
self.assertEqual(len(pass_desc.pattern), 2)
self.assertEqual(len(pass_desc.replace), 4)
pattern_op_dicts = self.convert_ops_to_op_dicts(pass_desc.pattern)
replace_op_dicts = self.convert_ops_to_op_dicts(pass_desc.replace)
self.assertEqual(len(pattern_op_dicts.get("matmul_v2", [])), 2)
self.assertEqual(len(replace_op_dicts.get("concat", [])), 1)
self.assertEqual(len(replace_op_dicts.get("matmul_v2", [])), 1)
@@ -292,3 +361,33 @@ class TestGeneratePass(unittest.TestCase):
def test_generate_simplify_inference(self):
self.check_generate_simplify_inference("generate_simplify_inference_v1")
self.check_generate_simplify_inference("generate_simplify_inference_v2")
def test_generate_layer_norm_fuse_pass(self):
paddle.enable_static()
program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(program, startup_program):
x = paddle.static.data("x", [3, 64, 120], "float32")
gamma = paddle.static.create_parameter(
shape=[120], dtype="float32", is_bias=True)
beta = paddle.static.create_parameter(
shape=[120], dtype="float32", is_bias=True)
x_sub_mean = x - paddle.mean(x, axis=-1, keepdim=True)
std_dev = paddle.mean(x_sub_mean.pow(2), axis=-1, keepdim=True)
lnorm = x_sub_mean - (std_dev + 1e-5).sqrt()
out = lnorm * gamma + beta
graph = core.Graph(program.desc)
before_node_nums = len(graph.nodes())
core.get_pass("generate_layer_norm_fuse_pass").apply(graph)
after_node_nums = len(graph.nodes())
self.assertEqual(after_node_nums, before_node_nums - 14)
after_program = paddle.fluid.framework.IrGraph(graph).to_program()
executor = paddle.static.Executor(paddle.CPUPlace())
executor.run(startup_program)
feed = {"x": np.random.random([3, 64, 120]).astype("float32")}
before_out = executor.run(program, feed=feed, fetch_list=[out.name])
after_out = executor.run(after_program,
feed=feed,
fetch_list=[out.name])
self.assertTrue(np.allclose(before_out, after_out))