未验证 提交 523916fa 编写于 作者: C cyber-pioneer 提交者: GitHub

[Prim][NewIR] Support forward decomposition in new IR (#55480)

* Support Prim Forward in New IR

* Fix test case

* polish code

* fix code

* polish format

* format code
上级 a5ba0b65
......@@ -73,6 +73,7 @@ import paddle.regularizer # noqa: F401
import paddle.incubate # noqa: F401
import paddle.autograd # noqa: F401
import paddle.device # noqa: F401
import paddle.decomposition # noqa: F401
import paddle.jit # noqa: F401
import paddle.amp # noqa: F401
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .decomp import decompose # noqa: F401
from . import rules # noqa: F401
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import typing
from paddle import ir
from paddle.fluid.libpaddle.ir import Block, Program
from paddle.framework import core
from . import register
def _build_tensor_tuple(xs):
if isinstance(xs, ir.OpResult):
return (xs,)
elif isinstance(xs, typing.Sequence):
return tuple(xs)
return TypeError(f"Type {type(xs)} is not supported")
def _prepare_python_api_arguments(op):
"""
For standard api of operator, its inputs should keep consistent with organization of its inputs and attrs.
Args:
op (Operator): The target operator.
"""
op_inputs = [x.source() for x in op.operands()]
op_attrs_dict = op.attrs()
op_attrs_name = op.get_attr_names()
op_attrs = [op_attrs_dict[x] for x in op_attrs_name]
api_arguments = op_inputs + op_attrs
return tuple(api_arguments)
def _check_op_results(op_name, orig_outs, new_outs):
"""
Check whether the replaced outputs are consistent with origin outputs.
Args:
op_name (str): The name of operator.
orig_outs (tuple): The outputs of original operator.
new_outs (tuple): The outputs of replaced operator.
"""
assert len(orig_outs) == len(new_outs), (
f'when replace origin op {op_name} with composite rule, num of origin outs should be equal to new outs, '
f'but len(orig_outs) = {len(orig_outs)} and len(new_outs) = {len(new_outs)}'
)
for orig_out, new_out in zip(
orig_outs,
new_outs,
):
if (orig_out is None or new_out is None) and (
op_name not in core.ops_contain_none
):
raise ValueError(
f"op {op_name} should not contain any None value. original outs={orig_outs} and its composite rule outs={new_outs}"
)
if orig_out is None:
# to keep same as phi op definition, orig_out may receive None
continue
elif new_out is not None:
orig_dtype = orig_out.dtype
new_dtype = new_out.dtype
orig_shape = orig_out.shape
new_shape = new_out.shape
assert orig_dtype == new_dtype, (
f'when replace origin op {op_name} with composite rule, origin out dtype should be equal to new out dtype, '
f'but orig_out dtype={orig_dtype} and new_out dtype={new_dtype}'
)
assert (
-1 not in new_shape
), f'when replace origin op {op_name} with composite rule, composite out shape has -1.'
assert orig_shape == new_shape, (
f'when replace origin op {op_name} with composite rule, origin out shape should be equal to new out shape, '
f'but orig_out shape={orig_shape} and new_out shape={new_shape}'
)
assert not (orig_out is None) ^ (
new_out is None
), "orig_out and new_out should match."
return
def decompose(
program,
blacklist=frozenset(),
whitelist=frozenset(),
):
"""
Search nonbasic ops which have be registered composite rules and replace them with primitive ops.
The operators in blacklist will be excluded from program when decomposed into primitives, and only the
operators in whitelist will be decomposed. The priority of blacklist is higher than whitelist, it means
an operator both in blacklist and whitelist will not be decomposed.
The finally set that will be decomposed is:
(block.ops & ops have decomposite rule & whitelist) - blacklist
Args:
program (Program): The program to be processed.
blacklist (frozenset): The Operators that will be exclude when decomposed into primitives.
whitelist (frozenset): Only the operators in whitelist will be decomposed into primitives.
"""
if not isinstance(program, Program):
raise TypeError(f"Expect type Program, but got type {type(program)}.")
block = program.block()
if not isinstance(blacklist, (set, frozenset)):
raise TypeError(
f'Expected type of blacklisst is set|frozenset, but got {type(blacklist)}.'
)
if not isinstance(whitelist, (set, frozenset)):
raise TypeError(
f'Expected type of whiltelist is set|frozenset, but got {type(whitelist)}.'
)
blacklist = core.prim_config["forward_blacklist"] | blacklist
logging.debug("Decompose composite forward ops begin...")
if len(blacklist) > 0 and len(whitelist) > 0:
op_filter = (
lambda x: x.name() in whitelist and x.name() not in blacklist
)
elif len(blacklist) > 0 and len(whitelist) == 0:
op_filter = lambda x: x.name() not in blacklist
elif len(blacklist) == 0 and len(whitelist) > 0:
op_filter = lambda x: x.name() in whitelist
else:
op_filter = lambda x: True
with ir.core.program_guard(program):
_decompose_subgraph(
block,
op_filter,
)
logging.debug(
"Decompose composite forward ops finish: {}".format(
core.prim_config["composite_ops_record"]
)
)
def _decompose_subgraph(block, op_filter):
"""
The operators in block wich satisfy the filter conditon will be decomposed into primitives.
Args:
block (Block|Sequence[Block]): The blocks of program to be processed.
op_filter (function): The filter to specify which ops to be processed.
"""
if isinstance(block, Block):
ops_list = block.get_ops()
for op in ops_list:
op_name = op.name()
decom_rule = register.get_decomp_rule(op_name)
lower = decom_rule and op_filter(op)
if lower:
core.prim_config["composite_ops_record"].add(op_name)
input_args = _prepare_python_api_arguments(op)
ir.set_insertion_point(op)
orig_outs = op.results()
new_outs = _build_tensor_tuple(decom_rule(*input_args))
# Todo: To cover such case: some outputs are no longer needed after decomposition.
_check_op_results(op_name, orig_outs, new_outs)
op.replace_all_uses_with(new_outs)
block.remove_op(op)
return
elif isinstance(block, typing.Sequence):
for item in block:
_decompose_subgraph(item, op_filter)
return
raise TypeError(
f"Expect type Block or Sequence of Block, but got type {type(block)}"
)
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.tensor import abs # noqa: F401
from paddle.tensor import acos # noqa: F401
from paddle.tensor import acosh # noqa: F401
from paddle.tensor import add # noqa: F401
from paddle.tensor import asin # noqa: F401
from paddle.tensor import asinh # noqa: F401
from paddle.tensor import atan # noqa: F401
from paddle.tensor import atanh # noqa: F401
from paddle.tensor import broadcast_shape # noqa: F401
from paddle.tensor import broadcast_to # noqa: F401
from paddle.tensor import concat # noqa: F401
from paddle.tensor import cos # noqa: F401
from paddle.tensor import cosh # noqa: F401
from paddle.tensor import cumprod # noqa: F401
from paddle.tensor import cumsum # noqa: F401
from paddle.tensor import digamma # noqa: F401
from paddle.tensor import divide # noqa: F401
from paddle.tensor import erf # noqa: F401
from paddle.tensor import erfinv # noqa: F401
from paddle.tensor import exp # noqa: F401
from paddle.tensor import expm1 # noqa: F401
from paddle.tensor import fill_constant # noqa: F401
from paddle.tensor import full # noqa: F401
from paddle.tensor import gather # noqa: F401
from paddle.tensor import greater_equal # noqa: F401
from paddle.tensor import lgamma # noqa: F401
from paddle.tensor import log # noqa: F401
from paddle.tensor import log1p # noqa: F401
from paddle.tensor import logcumsumexp # noqa: F401
from paddle.tensor import logit # noqa: F401
from paddle.tensor import logsumexp # noqa: F401
from paddle.tensor import max # noqa: F401
from paddle.tensor import min # noqa: F401
from paddle.tensor import multiply # noqa: F401
from paddle.tensor import ones # noqa: F401
from paddle.tensor import pow # noqa: F401
from paddle.tensor import prod # noqa: F401
from paddle.tensor import reshape # noqa: F401
from paddle.tensor import rsqrt # noqa: F401
from paddle.tensor import sign # noqa: F401
from paddle.tensor import sin # noqa: F401
from paddle.tensor import sinh # noqa: F401
from paddle.tensor import sqrt # noqa: F401
from paddle.tensor import subtract # noqa: F401
from paddle.tensor import sum # noqa: F401
from paddle.tensor import tan # noqa: F401
from paddle.tensor import tanh # noqa: F401
from paddle.tensor import tile # noqa: F401
from paddle.tensor import uniform # noqa: F401
from paddle.tensor import zeros # noqa: F401
from paddle.tensor.creation import assign # noqa: F401
from paddle.tensor.creation import zeros_like # noqa: F401
from paddle.tensor.manipulation import cast # noqa: F401
from paddle.tensor.math import maximum # noqa: F401
from paddle.tensor.math import minimum # noqa: F401
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
class Registry:
"""A general registry object."""
__slots__ = ['name', 'rules']
def __init__(self, name):
self.name = name
self.rules = {}
def register(self, op_type, rule):
assert isinstance(op_type, str)
assert inspect.isfunction(rule)
assert (
op_type not in self.rules
), f'name "{op_type}" should not be registered before.'
self.rules[op_type] = rule
def lookup(self, op_type):
return self.rules.get(op_type)
_decomposition_ops = Registry('decomposition')
def register_decomp(op_type):
"""
Decorator for registering the lower function for an original op into sequence of primitive ops.
Args:
op_type(str): The op name
Returns:
wrapper: Inner wrapper function
Examples:
.. code-block:: python
@register_decomp('softmax')
def softmax(x, axis):
molecular = exp(x)
denominator = broadcast_to(sum(molecular, axis=axis, keepdim=True), x.shape)
res = divide(molecular, denominator)
return res
"""
if not isinstance(op_type, str):
raise TypeError(f'op_type must be str, but got {type(op_type)}.')
def wrapper(f):
_decomposition_ops.register(op_type, f)
return f
return wrapper
def get_decomp_rule(op_type):
_lowerrule = _decomposition_ops.lookup(op_type)
return _lowerrule
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .primitives import * # noqa: F403
from .register import register_decomp
@register_decomp('pd.mean')
def mean(x, axis, keepdim):
"""define composite rule of op mean"""
x_shape = x.shape
axes = axis or tuple(range(0, len(x_shape)))
axes = (axes,) if isinstance(axes, int) else axes
sum_x = sum(x, axis=axes, keepdim=keepdim)
value_to_fill = 1
for axis in axes:
value_to_fill *= x_shape[axis]
norm = fill_constant(
shape=[],
value=value_to_fill,
dtype=sum_x.dtype,
)
res = divide(sum_x, norm)
return res
......@@ -499,6 +499,7 @@ packages=['paddle',
'paddle.geometric.message_passing',
'paddle.geometric.sampling',
'paddle.ir',
'paddle.decomposition',
]
with open('@PADDLE_SOURCE_DIR@/python/requirements.txt') as f:
......
......@@ -1497,6 +1497,7 @@ def get_setup_parameters():
'paddle.geometric.message_passing',
'paddle.geometric.sampling',
'paddle.ir',
'paddle.decomposition',
]
paddle_bins = ''
......
......@@ -12,3 +12,4 @@ add_subdirectory(prim)
add_subdirectory(model)
add_subdirectory(composite_ops)
add_subdirectory(process)
add_subdirectory(new_ir_prim)
file(
GLOB TEST_INTERP_CASES
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
"test_*.py")
string(REPLACE ".py" "" TEST_INTERP_CASES "${TEST_INTERP_CASES}")
foreach(target ${TEST_INTERP_CASES})
py_test_modules(${target} MODULES ${target} ENVS GLOG_v=1
FLAGS_enable_new_ir_in_executor=true)
endforeach()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
from paddle import ir
from paddle.decomposition import decompose
paddle.enable_static()
def get_ir_program():
x = paddle.randn([4, 4])
main_program, start_program = (
paddle.static.Program(),
paddle.static.Program(),
)
with paddle.static.program_guard(main_program, start_program):
x_s = paddle.static.data('x', [4, 4], x.dtype)
x_s.stop_gradient = False
y_s = paddle.matmul(x_s, x_s)
y_s = paddle.add(x_s, y_s)
y_s = paddle.mean(y_s)
y_s = paddle.tanh(y_s)
newir_program = ir.translate_to_new_ir(main_program.desc)
return newir_program
class TestBuildOp(unittest.TestCase):
def test_build_op(self):
newir_program = get_ir_program()
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
decompose(newir_program)
op_name_list = [op.name() for op in newir_program.block().get_ops()]
self.assertEqual(
op_name_list,
[
'builtin.get_parameter',
'pd.matmul',
'pd.add',
'pd.full_int_array',
'pd.sum',
'pd.full',
'pd.divide',
'pd.tanh',
],
)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册