未验证 提交 bab39eb2 编写于 作者: W wuhuanzhou 提交者: GitHub

Python support register pass via PassDesc (#35602)

PR主要功能:针对fusion等子图替换场景,支持Python侧开发并注册Pass。

背景
Pass是指输入一个深度学习计算图Graph,依照一定条件进行修改,输出修改后的Graph的过程;
当前PaddlePadle框架编写Pass代码存在以下问题:
用户需要手写Graph的条件匹配、在Graph上的修改代码;
对Graph操作需要深入底层框架代码,了解Graph的结构,并且知道相关Pass写法;
我们提出了针对fusion等子图替换类Pass的优化方案以支持用户在Python侧开发注册Pass,提升二次开发体验:
用户只需要输入匹配和替换的子图描述,由深度学习框架编写的代码来生成匹配和替换的逻辑,不需要用户对Graph进行匹配和替换操作;
API级别的替换,用户可以通过Paddle的Python API构造子图,从而不需要知道Graph的结构,也能写Paddle的Graph Pass代码
上级 29ef7cc9
...@@ -223,6 +223,7 @@ if(WITH_PYTHON) ...@@ -223,6 +223,7 @@ if(WITH_PYTHON)
py_proto_compile(framework_py_proto SRCS framework.proto data_feed.proto) py_proto_compile(framework_py_proto SRCS framework.proto data_feed.proto)
py_proto_compile(trainer_py_proto SRCS trainer_desc.proto data_feed.proto) py_proto_compile(trainer_py_proto SRCS trainer_desc.proto data_feed.proto)
py_proto_compile(distributed_strategy_py_proto SRCS distributed_strategy.proto) py_proto_compile(distributed_strategy_py_proto SRCS distributed_strategy.proto)
py_proto_compile(pass_desc_py_proto SRCS pass_desc.proto)
#Generate an empty \ #Generate an empty \
#__init__.py to make framework_py_proto as a valid python module. #__init__.py to make framework_py_proto as a valid python module.
add_custom_target(fleet_proto_init ALL add_custom_target(fleet_proto_init ALL
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
syntax = "proto2";
import "framework.proto";
package paddle.framework.proto;
// Describes one subsitute subgraph.
message PassDesc {
message VarMap {
required string pattern_var = 1;
required string replace_var = 2;
}
message AttrMap {
required int32 pattern_op_idx = 1;
required int32 replace_op_idx = 2;
required string pattern_name = 3;
required string replace_name = 4;
}
required ProgramDesc pattern = 1;
required ProgramDesc replace = 2;
repeated VarMap var_maps = 3;
repeated AttrMap attr_maps = 4;
}
// A series of PassDesc.
message MultiPassDesc {
optional string pass_type = 1;
repeated PassDesc pass_descs = 2;
}
...@@ -12,10 +12,19 @@ ...@@ -12,10 +12,19 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import copy import copy
from . import core import inspect
from .framework import _apply_pass from os import path
import paddle
from . import core, unique_name
from .framework import _apply_pass, OpProtoHolder
try:
from .proto import pass_desc_pb2
except ModuleNotFoundError:
import sys
sys.path.append(path.join(path.dirname(__file__), 'proto'))
from .proto import pass_desc_pb2
def get_data_vars(program): def get_data_vars(program):
...@@ -115,3 +124,234 @@ def apply_build_strategy(main_program, startup_program, build_strategy, ...@@ -115,3 +124,234 @@ def apply_build_strategy(main_program, startup_program, build_strategy,
build_strategy.enable_inplace = False build_strategy.enable_inplace = False
build_strategy._clear_finalized() build_strategy._clear_finalized()
return build_strategy return build_strategy
class RegisterPassHelper(object):
def __init__(self, pass_pairs, pass_type=str(), input_specs=dict()):
self._pass_type = pass_type
self._pass_pairs = pass_pairs
if isinstance(input_specs, dict):
self._input_specs = input_specs
def _get_args_from_func(self, func):
args = list()
arg_specs = inspect.getfullargspec(func)
for arg_name in arg_specs.args:
input_spec = self._input_specs.get(arg_name)
if isinstance(input_spec, paddle.static.InputSpec):
args.append(
paddle.static.data(arg_name, input_spec.shape,
input_spec.dtype))
elif isinstance(input_spec, paddle.ParamAttr):
args.append(paddle.ParamAttr(arg_name))
else:
args.append(paddle.static.data(arg_name, [-1]))
return args
def _func_to_program_desc(self, func, program_desc, is_replace=False):
vars = list()
program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(program, startup_program):
args = self._get_args_from_func(func)
for arg in args:
vars.append(arg.name)
outs = func(*args)
if not isinstance(outs, (list, tuple)):
outs = [outs]
for out in outs:
if isinstance(out, PassDesc.OpHelper):
for out in out.Outputs().values():
vars.extend(out)
elif isinstance(out, paddle.fluid.framework.Variable):
vars.append(out.name)
program_desc.ParseFromString(program.desc.serialize_to_string())
if is_replace:
attrs = list()
for op in program.current_block().ops:
if not isinstance(op, PassDesc.OpHelper):
continue
attrs.extend(op._attrs.values())
return vars, attrs
return vars
def SerializeMultiPassDesc(self):
switch_static_mode = paddle.in_dynamic_mode()
if switch_static_mode:
paddle.enable_static()
multi_pass_desc = pass_desc_pb2.MultiPassDesc()
multi_pass_desc.pass_type = self._pass_type
for (pattern, replace) in self._pass_pairs:
pass_desc = multi_pass_desc.pass_descs.add()
pattern_vars = self._func_to_program_desc(pattern,
pass_desc.pattern)
replace_vars, attrs = self._func_to_program_desc(
replace, pass_desc.replace, is_replace=True)
for (pattern_var, replace_var) in zip(pattern_vars, replace_vars):
var_map = pass_desc.var_maps.add()
var_map.pattern_var = pattern_var
var_map.replace_var = replace_var
pattern_op_idxs = dict()
for (idx, op) in enumerate(pass_desc.pattern.blocks[0].ops):
op_idxs = pattern_op_idxs.get(op.type)
if op_idxs:
op_idxs.append(idx)
else:
pattern_op_idxs[op.type] = [idx]
for attr in attrs:
attr_map = pass_desc.attr_maps.add()
attr_map.pattern_op_idx = pattern_op_idxs[
attr._pattern_op_type][attr._pattern_op_idx]
attr_map.replace_op_idx = attr._replace_op_idx
attr_map.pattern_name = attr._pattern_name
attr_map.replace_name = attr._replace_name
if switch_static_mode:
paddle.disable_static()
return multi_pass_desc.SerializeToString()
class PassDesc(object):
class AttrHelper(object):
def __init__(self, name, replace_op_idx):
self._pattern_op_type = None
self._pattern_op_idx = -1
self._replace_op_idx = replace_op_idx
self._pattern_name = name
self._replace_name = name
def ReusePattern(self, op, index=0, name=None):
if name:
self._pattern_name = name
self._pattern_op_type = op
self._pattern_op_idx = index
class OpHelper(object):
def __init__(self, type=None):
self._type = type
def __getattr__(self, name):
if self._type is not None:
raise AttributeError(
"type object 'OpHelper' has no attribute '{}'".format(name))
op = PassDesc.OpHelper(name)
op.Init()
return op
def __call__(self, *args, **kwargs):
for (in_name, in_args) in kwargs.items():
in_arg_names = list()
if isinstance(in_args, (list, tuple)):
if len(in_args) == 0:
raise ValueError(
"Input '{}' of operator '{}' cannot be empty.".
format(in_name, self._type))
else:
in_args = [in_args]
for in_arg in in_args:
if isinstance(in_arg, PassDesc.OpHelper):
in_arg_names.extend(in_arg.Output())
else:
in_arg_names.append(in_arg.name)
self._op_desc.set_input(in_name, in_arg_names)
return self
def Init(self):
block = paddle.static.default_main_program().current_block()
self._attrs = dict()
self._op_idx = len(block.ops)
self._op_desc = block.desc.append_op()
self._op_desc.set_type(self._type)
self._op_proto = OpProtoHolder.instance().get_op_proto(self._type)
block.ops.append(self)
def Attr(self, name):
attr = self._attrs.get(name)
if attr:
return attr
attr = PassDesc.AttrHelper(name, self._op_idx)
self._attrs[name] = attr
return attr
def SetAttr(self, name, value):
self._op_desc._set_attr(name, value)
def Output(self, name=None):
if name:
return self.Outputs()[name]
return list(self.Outputs().values())[0]
def Outputs(self):
outputs = self._op_desc.outputs()
if len(outputs) > 0:
return outputs
block = paddle.static.default_main_program().current_block()
for output_proto in self._op_proto.outputs:
name = unique_name.generate(self._type)
block.create_var(name=name)
self._op_desc.set_output(output_proto.name, [name])
return self._op_desc.outputs()
OP = OpHelper()
def RegisterPass(function=None, input_specs=None):
"""
The function decorator of Register Pass. Decorator @RegisterPass handles
the function and register it into a core.Pass instance. Use name of function
as Pass type.
Args:
function (callable): The function with return of callable pair(s) that
represents the pattern subgraph and the replace subgraph.
input_specs (dict[str, InputSpec]|None): Dict of InputSpec to specific the shape/dtype
information of Tensor. Some operators limit the shape and dtype of datas when
create subgraph with Paddle APIs. So user need specify InputSpec of data to
ensure create a correctly subgraph. Of course, this argument is not limited to
matching subgraph. The default is None.
Returns:
callables: Callable pair(s).
Examples:
.. code-block:: python
import paddle
from paddle.fluid.ir import RegisterPass
@RegisterPass
def multi_add_to_addn():
def pattern(x, y, z):
return paddle.add(paddle.add(x, y), z)
def replace(x, y, z):
return paddle.add_n([x, y, z])
return pattern, replace
"""
def _is_pass_pair(check_pair):
if isinstance(check_pair, (list, tuple)):
if len(check_pair) == 2:
if all(map(inspect.isfunction, check_pair)):
return True
return False
def decorated(python_func):
pass_type = python_func.__name__
signature = inspect.signature(python_func)
if len(signature.parameters) > 0:
raise NotImplementedError(
"Pass function with parameter is not supported now.")
elif len(signature.parameters) == 0:
pass_pairs = python_func()
if _is_pass_pair(pass_pairs):
pass_pairs = [pass_pairs]
elif not all(map(_is_pass_pair, pass_pairs)):
raise ValueError(
"Return value of Pass function must be (callable, callable)."
)
helper = RegisterPassHelper(pass_pairs, pass_type, input_specs)
return python_func
if inspect.isfunction(function):
return decorated(function)
return decorated
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
from paddle.static import InputSpec
from paddle.fluid import ir
import numpy as np
# 0: ewadd(X=mul(X=x, Y=w), Y=b) => fc(Input=x, W=w, Bias=b)
# 1: relu(X=ewadd(X=mul(X=x, Y=w), Y=b)) => fc(Input=x, W=w, Bias=b)
@ir.RegisterPass
def generate_fc_fuse():
def create_pass_pair(with_relu):
def pattern(x, w, b):
mul = ir.PassDesc.OP.mul(X=x, Y=w)
ewadd = ir.PassDesc.OP.elementwise_add(X=mul, Y=b)
if with_relu:
return ir.PassDesc.OP.relu(X=ewadd)
else:
return ewadd
def replace(x, w, b):
fc = ir.PassDesc.OP.fc
fc.Attr("in_num_col_dims").ReusePattern(
"mul", name="x_num_col_dims")
if with_relu:
fc.SetAttr("activation_type", "relu")
return fc(Input=x, W=w, Bias=b)
return pattern, replace
return list(map(create_pass_pair, [True, False]))
# add(X=add(x, y), Y=z)z => add_n(X=[x, y, z])
@ir.RegisterPass
def generate_add_n():
def pattern(x, y, z):
return paddle.add(paddle.add(x, y), z)
def replace(x, y, z):
return paddle.add_n([x, y, z])
return pattern, replace
# mul(x, y1), mul(x, y2) => slice(mul(x, concat(y1, y2)))
@ir.RegisterPass(input_specs={
'x': InputSpec([1, 1]),
'y1': InputSpec([1, 1]),
'y2': InputSpec([1, 1])
})
def generate_combine_mul_v1():
def pattern(x, y1, y2):
mul1 = paddle.matmul(x, y1)
mul2 = paddle.matmul(x, y2)
return mul1, mul2
def replace(x, y1, y2):
concat_out = paddle.concat([y1, y2], axis=-1)
mul_out = paddle.matmul(x, concat_out)
out1 = paddle.slice(mul_out, axes=[1], starts=[0], ends=[1])
out2 = paddle.slice(mul_out, axes=[1], starts=[1], ends=[2])
return out1, out2
return pattern, replace
@ir.RegisterPass
def generate_combine_mul_v2():
def pattern(x, y1, y2):
mul1 = ir.PassDesc.OP.matmul_v2(x, y1)
mul2 = ir.PassDesc.OP.matmul_v2(x, y2)
return mul1, mul2
def replace(x, y1, y2):
concat = ir.PassDesc.OP.concat(X=[y1, y2])
matmul = ir.PassDesc.OP.matmul_v2(X=x, Y=concat)
out1 = ir.PassDesc.OP.slice(Input=matmul)
out2 = ir.PassDesc.OP.slice(Input=matmul)
return out1, out2
return pattern, replace
# reshape(reshape(x)) => x
@ir.RegisterPass(input_specs={'x': InputSpec([-1, 16, 16, 16])})
def generate_simplify_inference():
def pattern(x):
transpose = paddle.transpose(x, [0, 3, 1, 2])
return paddle.transpose(transpose, [0, 3, 1, 2])
return pattern, lambda x: x
def get_multi_pass_desc_from_str(s):
multi_pass_desc = ir.pass_desc_pb2.MultiPassDesc()
multi_pass_desc.ParseFromString(s)
return multi_pass_desc
class TestGeneratePass(unittest.TestCase):
def convert_ops_to_op_dicts(self, ops):
op_dicts = dict()
for op in ops:
op_list = op_dicts.get(op.type)
if isinstance(op_list, list):
op_list.append(op)
else:
op_dicts[op.type] = [op]
return op_dicts
def test_generate_fc_fuse(self):
def _check_fc_fuse_pass(pass_desc, with_relu):
pattern_op_dicts = self.convert_ops_to_op_dicts(
pass_desc.pattern.blocks[0].ops)
replace_op_dicts = self.convert_ops_to_op_dicts(
pass_desc.replace.blocks[0].ops)
self.assertEqual(len(pattern_op_dicts.get("mul", [])), 1)
self.assertEqual(
len(pattern_op_dicts.get("elementwise_add", [])), 1)
if with_relu:
self.assertEqual(len(pattern_op_dicts.get("relu", [])), 1)
pattern_op_num = 3 # relu, ewadd, mul
else:
pattern_op_num = 2 # ewadd, mul
self.assertEqual(len(pass_desc.var_maps), 4)
self.assertEqual(
len(pass_desc.pattern.blocks[0].ops), pattern_op_num)
self.assertEqual(len(pass_desc.replace.blocks[0].ops), 1)
self.assertEqual(len(pass_desc.attr_maps), 1)
helper = ir.RegisterPassHelper(generate_fc_fuse())
s = helper.SerializeMultiPassDesc()
multi_pass_desc = get_multi_pass_desc_from_str(s)
self.assertEqual(len(multi_pass_desc.pass_descs), 2)
_check_fc_fuse_pass(multi_pass_desc.pass_descs[0], True)
_check_fc_fuse_pass(multi_pass_desc.pass_descs[1], False)
def test_generate_add_n(self):
helper = ir.RegisterPassHelper([generate_add_n()])
s = helper.SerializeMultiPassDesc()
multi_pass_desc = get_multi_pass_desc_from_str(s)
self.assertEqual(len(multi_pass_desc.pass_descs), 1)
pass_desc = multi_pass_desc.pass_descs[0]
self.assertEqual(len(pass_desc.var_maps), 4)
self.assertEqual(len(pass_desc.attr_maps), 0)
self.assertEqual(len(pass_desc.pattern.blocks[0].ops), 2)
self.assertEqual(len(pass_desc.replace.blocks[0].ops), 1)
pattern_op_dicts = self.convert_ops_to_op_dicts(
pass_desc.pattern.blocks[0].ops)
replace_op_dicts = self.convert_ops_to_op_dicts(
pass_desc.replace.blocks[0].ops)
self.assertEqual(len(pattern_op_dicts.get("elementwise_add", [])), 2)
self.assertEqual(len(replace_op_dicts.get("sum", [])), 1)
def test_generate_combine_mul_v1(self):
input_specs = {
'x': InputSpec([1, 1]),
'y1': InputSpec([1, 1]),
'y2': InputSpec([1, 1])
}
helper = ir.RegisterPassHelper(
[generate_combine_mul_v1()], input_specs=input_specs)
s = helper.SerializeMultiPassDesc()
multi_pass_desc = get_multi_pass_desc_from_str(s)
self.assertEqual(len(multi_pass_desc.pass_descs), 1)
pass_desc = multi_pass_desc.pass_descs[0]
self.assertEqual(len(pass_desc.var_maps), 5)
self.assertEqual(len(pass_desc.pattern.blocks[0].ops), 2)
self.assertEqual(len(pass_desc.replace.blocks[0].ops), 4)
pattern_op_dicts = self.convert_ops_to_op_dicts(
pass_desc.pattern.blocks[0].ops)
replace_op_dicts = self.convert_ops_to_op_dicts(
pass_desc.replace.blocks[0].ops)
self.assertEqual(len(pattern_op_dicts.get("matmul_v2", [])), 2)
self.assertEqual(len(replace_op_dicts.get("concat", [])), 1)
self.assertEqual(len(replace_op_dicts.get("matmul_v2", [])), 1)
self.assertEqual(len(replace_op_dicts.get("slice", [])), 2)
def test_generate_combine_mul_v2(self):
helper = ir.RegisterPassHelper([generate_combine_mul_v2()])
s = helper.SerializeMultiPassDesc()
multi_pass_desc = get_multi_pass_desc_from_str(s)
self.assertEqual(len(multi_pass_desc.pass_descs), 1)
pass_desc = multi_pass_desc.pass_descs[0]
self.assertEqual(len(pass_desc.var_maps), 5)
self.assertEqual(len(pass_desc.pattern.blocks[0].ops), 2)
self.assertEqual(len(pass_desc.replace.blocks[0].ops), 4)
pattern_op_dicts = self.convert_ops_to_op_dicts(
pass_desc.pattern.blocks[0].ops)
replace_op_dicts = self.convert_ops_to_op_dicts(
pass_desc.replace.blocks[0].ops)
self.assertEqual(len(pattern_op_dicts.get("matmul_v2", [])), 2)
self.assertEqual(len(replace_op_dicts.get("concat", [])), 1)
self.assertEqual(len(replace_op_dicts.get("matmul_v2", [])), 1)
self.assertEqual(len(replace_op_dicts.get("slice", [])), 2)
def test_generate_simplify_inference(self):
input_specs = {'x': InputSpec([-1, 16, 16, 16])}
helper = ir.RegisterPassHelper(
[generate_simplify_inference()], input_specs=input_specs)
s = helper.SerializeMultiPassDesc()
multi_pass_desc = get_multi_pass_desc_from_str(s)
self.assertEqual(len(multi_pass_desc.pass_descs), 1)
pass_desc = multi_pass_desc.pass_descs[0]
self.assertEqual(len(pass_desc.var_maps), 2)
self.assertEqual(len(pass_desc.pattern.blocks[0].ops), 2)
self.assertEqual(len(pass_desc.replace.blocks[0].ops), 0)
pattern_op_dicts = self.convert_ops_to_op_dicts(
pass_desc.pattern.blocks[0].ops)
self.assertEqual(len(pattern_op_dicts.get("transpose2", [])), 2)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册