From 523916faa0cd4d0c4392e3b4c16fe56e19ffaf89 Mon Sep 17 00:00:00 2001 From: cyber-pioneer <116002591+cyber-pioneer@users.noreply.github.com> Date: Tue, 8 Aug 2023 14:07:59 +0800 Subject: [PATCH] [Prim][NewIR] Support forward decomposition in new IR (#55480) * Support Prim Forward in New IR * Fix test case * polish code * fix code * polish format * format code --- python/paddle/__init__.py | 1 + python/paddle/decomposition/__init__.py | 16 ++ python/paddle/decomposition/decomp.py | 191 ++++++++++++++++++++++ python/paddle/decomposition/primitives.py | 69 ++++++++ python/paddle/decomposition/register.py | 73 +++++++++ python/paddle/decomposition/rules.py | 35 ++++ python/setup.py.in | 1 + setup.py | 1 + test/prim/CMakeLists.txt | 1 + test/prim/new_ir_prim/CMakeLists.txt | 10 ++ test/prim/new_ir_prim/test_decomp_op.py | 63 +++++++ 11 files changed, 461 insertions(+) create mode 100644 python/paddle/decomposition/__init__.py create mode 100644 python/paddle/decomposition/decomp.py create mode 100644 python/paddle/decomposition/primitives.py create mode 100644 python/paddle/decomposition/register.py create mode 100644 python/paddle/decomposition/rules.py create mode 100644 test/prim/new_ir_prim/CMakeLists.txt create mode 100644 test/prim/new_ir_prim/test_decomp_op.py diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index a2f560598d1..103a996443e 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -73,6 +73,7 @@ import paddle.regularizer # noqa: F401 import paddle.incubate # noqa: F401 import paddle.autograd # noqa: F401 import paddle.device # noqa: F401 +import paddle.decomposition # noqa: F401 import paddle.jit # noqa: F401 import paddle.amp # noqa: F401 diff --git a/python/paddle/decomposition/__init__.py b/python/paddle/decomposition/__init__.py new file mode 100644 index 00000000000..2ae777c19eb --- /dev/null +++ b/python/paddle/decomposition/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .decomp import decompose # noqa: F401 +from . import rules # noqa: F401 diff --git a/python/paddle/decomposition/decomp.py b/python/paddle/decomposition/decomp.py new file mode 100644 index 00000000000..47f1d05bbb0 --- /dev/null +++ b/python/paddle/decomposition/decomp.py @@ -0,0 +1,191 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import typing + +from paddle import ir +from paddle.fluid.libpaddle.ir import Block, Program +from paddle.framework import core + +from . import register + + +def _build_tensor_tuple(xs): + if isinstance(xs, ir.OpResult): + return (xs,) + elif isinstance(xs, typing.Sequence): + return tuple(xs) + return TypeError(f"Type {type(xs)} is not supported") + + +def _prepare_python_api_arguments(op): + """ + For standard api of operator, its inputs should keep consistent with organization of its inputs and attrs. + + Args: + op (Operator): The target operator. + """ + op_inputs = [x.source() for x in op.operands()] + op_attrs_dict = op.attrs() + op_attrs_name = op.get_attr_names() + op_attrs = [op_attrs_dict[x] for x in op_attrs_name] + api_arguments = op_inputs + op_attrs + return tuple(api_arguments) + + +def _check_op_results(op_name, orig_outs, new_outs): + """ + Check whether the replaced outputs are consistent with origin outputs. + + Args: + op_name (str): The name of operator. + orig_outs (tuple): The outputs of original operator. + new_outs (tuple): The outputs of replaced operator. + """ + assert len(orig_outs) == len(new_outs), ( + f'when replace origin op {op_name} with composite rule, num of origin outs should be equal to new outs, ' + f'but len(orig_outs) = {len(orig_outs)} and len(new_outs) = {len(new_outs)}' + ) + + for orig_out, new_out in zip( + orig_outs, + new_outs, + ): + if (orig_out is None or new_out is None) and ( + op_name not in core.ops_contain_none + ): + raise ValueError( + f"op {op_name} should not contain any None value. original outs={orig_outs} and its composite rule outs={new_outs}" + ) + if orig_out is None: + # to keep same as phi op definition, orig_out may receive None + continue + elif new_out is not None: + orig_dtype = orig_out.dtype + new_dtype = new_out.dtype + orig_shape = orig_out.shape + new_shape = new_out.shape + assert orig_dtype == new_dtype, ( + f'when replace origin op {op_name} with composite rule, origin out dtype should be equal to new out dtype, ' + f'but orig_out dtype={orig_dtype} and new_out dtype={new_dtype}' + ) + assert ( + -1 not in new_shape + ), f'when replace origin op {op_name} with composite rule, composite out shape has -1.' + assert orig_shape == new_shape, ( + f'when replace origin op {op_name} with composite rule, origin out shape should be equal to new out shape, ' + f'but orig_out shape={orig_shape} and new_out shape={new_shape}' + ) + assert not (orig_out is None) ^ ( + new_out is None + ), "orig_out and new_out should match." + return + + +def decompose( + program, + blacklist=frozenset(), + whitelist=frozenset(), +): + """ + Search nonbasic ops which have be registered composite rules and replace them with primitive ops. + The operators in blacklist will be excluded from program when decomposed into primitives, and only the + operators in whitelist will be decomposed. The priority of blacklist is higher than whitelist, it means + an operator both in blacklist and whitelist will not be decomposed. + + The finally set that will be decomposed is: + (block.ops & ops have decomposite rule & whitelist) - blacklist + + Args: + program (Program): The program to be processed. + blacklist (frozenset): The Operators that will be exclude when decomposed into primitives. + whitelist (frozenset): Only the operators in whitelist will be decomposed into primitives. + """ + if not isinstance(program, Program): + raise TypeError(f"Expect type Program, but got type {type(program)}.") + block = program.block() + + if not isinstance(blacklist, (set, frozenset)): + raise TypeError( + f'Expected type of blacklisst is set|frozenset, but got {type(blacklist)}.' + ) + if not isinstance(whitelist, (set, frozenset)): + raise TypeError( + f'Expected type of whiltelist is set|frozenset, but got {type(whitelist)}.' + ) + + blacklist = core.prim_config["forward_blacklist"] | blacklist + + logging.debug("Decompose composite forward ops begin...") + + if len(blacklist) > 0 and len(whitelist) > 0: + op_filter = ( + lambda x: x.name() in whitelist and x.name() not in blacklist + ) + elif len(blacklist) > 0 and len(whitelist) == 0: + op_filter = lambda x: x.name() not in blacklist + elif len(blacklist) == 0 and len(whitelist) > 0: + op_filter = lambda x: x.name() in whitelist + else: + op_filter = lambda x: True + with ir.core.program_guard(program): + _decompose_subgraph( + block, + op_filter, + ) + logging.debug( + "Decompose composite forward ops finish: {}".format( + core.prim_config["composite_ops_record"] + ) + ) + + +def _decompose_subgraph(block, op_filter): + """ + The operators in block wich satisfy the filter conditon will be decomposed into primitives. + + Args: + block (Block|Sequence[Block]): The blocks of program to be processed. + op_filter (function): The filter to specify which ops to be processed. + """ + + if isinstance(block, Block): + ops_list = block.get_ops() + for op in ops_list: + op_name = op.name() + decom_rule = register.get_decomp_rule(op_name) + lower = decom_rule and op_filter(op) + + if lower: + core.prim_config["composite_ops_record"].add(op_name) + input_args = _prepare_python_api_arguments(op) + ir.set_insertion_point(op) + orig_outs = op.results() + new_outs = _build_tensor_tuple(decom_rule(*input_args)) + + # Todo: To cover such case: some outputs are no longer needed after decomposition. + _check_op_results(op_name, orig_outs, new_outs) + + op.replace_all_uses_with(new_outs) + block.remove_op(op) + return + + elif isinstance(block, typing.Sequence): + for item in block: + _decompose_subgraph(item, op_filter) + return + raise TypeError( + f"Expect type Block or Sequence of Block, but got type {type(block)}" + ) diff --git a/python/paddle/decomposition/primitives.py b/python/paddle/decomposition/primitives.py new file mode 100644 index 00000000000..c2c6fcb08da --- /dev/null +++ b/python/paddle/decomposition/primitives.py @@ -0,0 +1,69 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.tensor import abs # noqa: F401 +from paddle.tensor import acos # noqa: F401 +from paddle.tensor import acosh # noqa: F401 +from paddle.tensor import add # noqa: F401 +from paddle.tensor import asin # noqa: F401 +from paddle.tensor import asinh # noqa: F401 +from paddle.tensor import atan # noqa: F401 +from paddle.tensor import atanh # noqa: F401 +from paddle.tensor import broadcast_shape # noqa: F401 +from paddle.tensor import broadcast_to # noqa: F401 +from paddle.tensor import concat # noqa: F401 +from paddle.tensor import cos # noqa: F401 +from paddle.tensor import cosh # noqa: F401 +from paddle.tensor import cumprod # noqa: F401 +from paddle.tensor import cumsum # noqa: F401 +from paddle.tensor import digamma # noqa: F401 +from paddle.tensor import divide # noqa: F401 +from paddle.tensor import erf # noqa: F401 +from paddle.tensor import erfinv # noqa: F401 +from paddle.tensor import exp # noqa: F401 +from paddle.tensor import expm1 # noqa: F401 +from paddle.tensor import fill_constant # noqa: F401 +from paddle.tensor import full # noqa: F401 +from paddle.tensor import gather # noqa: F401 +from paddle.tensor import greater_equal # noqa: F401 +from paddle.tensor import lgamma # noqa: F401 +from paddle.tensor import log # noqa: F401 +from paddle.tensor import log1p # noqa: F401 +from paddle.tensor import logcumsumexp # noqa: F401 +from paddle.tensor import logit # noqa: F401 +from paddle.tensor import logsumexp # noqa: F401 +from paddle.tensor import max # noqa: F401 +from paddle.tensor import min # noqa: F401 +from paddle.tensor import multiply # noqa: F401 +from paddle.tensor import ones # noqa: F401 +from paddle.tensor import pow # noqa: F401 +from paddle.tensor import prod # noqa: F401 +from paddle.tensor import reshape # noqa: F401 +from paddle.tensor import rsqrt # noqa: F401 +from paddle.tensor import sign # noqa: F401 +from paddle.tensor import sin # noqa: F401 +from paddle.tensor import sinh # noqa: F401 +from paddle.tensor import sqrt # noqa: F401 +from paddle.tensor import subtract # noqa: F401 +from paddle.tensor import sum # noqa: F401 +from paddle.tensor import tan # noqa: F401 +from paddle.tensor import tanh # noqa: F401 +from paddle.tensor import tile # noqa: F401 +from paddle.tensor import uniform # noqa: F401 +from paddle.tensor import zeros # noqa: F401 +from paddle.tensor.creation import assign # noqa: F401 +from paddle.tensor.creation import zeros_like # noqa: F401 +from paddle.tensor.manipulation import cast # noqa: F401 +from paddle.tensor.math import maximum # noqa: F401 +from paddle.tensor.math import minimum # noqa: F401 diff --git a/python/paddle/decomposition/register.py b/python/paddle/decomposition/register.py new file mode 100644 index 00000000000..ba8adc54f65 --- /dev/null +++ b/python/paddle/decomposition/register.py @@ -0,0 +1,73 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect + + +class Registry: + """A general registry object.""" + + __slots__ = ['name', 'rules'] + + def __init__(self, name): + self.name = name + self.rules = {} + + def register(self, op_type, rule): + assert isinstance(op_type, str) + assert inspect.isfunction(rule) + assert ( + op_type not in self.rules + ), f'name "{op_type}" should not be registered before.' + self.rules[op_type] = rule + + def lookup(self, op_type): + return self.rules.get(op_type) + + +_decomposition_ops = Registry('decomposition') + + +def register_decomp(op_type): + """ + Decorator for registering the lower function for an original op into sequence of primitive ops. + + Args: + op_type(str): The op name + + Returns: + wrapper: Inner wrapper function + + Examples: + .. code-block:: python + @register_decomp('softmax') + def softmax(x, axis): + molecular = exp(x) + denominator = broadcast_to(sum(molecular, axis=axis, keepdim=True), x.shape) + res = divide(molecular, denominator) + return res + + """ + if not isinstance(op_type, str): + raise TypeError(f'op_type must be str, but got {type(op_type)}.') + + def wrapper(f): + _decomposition_ops.register(op_type, f) + return f + + return wrapper + + +def get_decomp_rule(op_type): + _lowerrule = _decomposition_ops.lookup(op_type) + return _lowerrule diff --git a/python/paddle/decomposition/rules.py b/python/paddle/decomposition/rules.py new file mode 100644 index 00000000000..ec8959cc960 --- /dev/null +++ b/python/paddle/decomposition/rules.py @@ -0,0 +1,35 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .primitives import * # noqa: F403 +from .register import register_decomp + + +@register_decomp('pd.mean') +def mean(x, axis, keepdim): + """define composite rule of op mean""" + x_shape = x.shape + axes = axis or tuple(range(0, len(x_shape))) + axes = (axes,) if isinstance(axes, int) else axes + sum_x = sum(x, axis=axes, keepdim=keepdim) + value_to_fill = 1 + for axis in axes: + value_to_fill *= x_shape[axis] + norm = fill_constant( + shape=[], + value=value_to_fill, + dtype=sum_x.dtype, + ) + res = divide(sum_x, norm) + return res diff --git a/python/setup.py.in b/python/setup.py.in index d1a6388a976..a0e665526dc 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -499,6 +499,7 @@ packages=['paddle', 'paddle.geometric.message_passing', 'paddle.geometric.sampling', 'paddle.ir', + 'paddle.decomposition', ] with open('@PADDLE_SOURCE_DIR@/python/requirements.txt') as f: diff --git a/setup.py b/setup.py index dda195d8aad..f2b1df02d8c 100644 --- a/setup.py +++ b/setup.py @@ -1497,6 +1497,7 @@ def get_setup_parameters(): 'paddle.geometric.message_passing', 'paddle.geometric.sampling', 'paddle.ir', + 'paddle.decomposition', ] paddle_bins = '' diff --git a/test/prim/CMakeLists.txt b/test/prim/CMakeLists.txt index b0f037f9a14..867a7552763 100644 --- a/test/prim/CMakeLists.txt +++ b/test/prim/CMakeLists.txt @@ -12,3 +12,4 @@ add_subdirectory(prim) add_subdirectory(model) add_subdirectory(composite_ops) add_subdirectory(process) +add_subdirectory(new_ir_prim) diff --git a/test/prim/new_ir_prim/CMakeLists.txt b/test/prim/new_ir_prim/CMakeLists.txt new file mode 100644 index 00000000000..393bc869d9b --- /dev/null +++ b/test/prim/new_ir_prim/CMakeLists.txt @@ -0,0 +1,10 @@ +file( + GLOB TEST_INTERP_CASES + RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" + "test_*.py") +string(REPLACE ".py" "" TEST_INTERP_CASES "${TEST_INTERP_CASES}") + +foreach(target ${TEST_INTERP_CASES}) + py_test_modules(${target} MODULES ${target} ENVS GLOG_v=1 + FLAGS_enable_new_ir_in_executor=true) +endforeach() diff --git a/test/prim/new_ir_prim/test_decomp_op.py b/test/prim/new_ir_prim/test_decomp_op.py new file mode 100644 index 00000000000..f56b68f2317 --- /dev/null +++ b/test/prim/new_ir_prim/test_decomp_op.py @@ -0,0 +1,63 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import paddle +from paddle import ir +from paddle.decomposition import decompose + +paddle.enable_static() + + +def get_ir_program(): + x = paddle.randn([4, 4]) + main_program, start_program = ( + paddle.static.Program(), + paddle.static.Program(), + ) + with paddle.static.program_guard(main_program, start_program): + x_s = paddle.static.data('x', [4, 4], x.dtype) + x_s.stop_gradient = False + y_s = paddle.matmul(x_s, x_s) + y_s = paddle.add(x_s, y_s) + y_s = paddle.mean(y_s) + y_s = paddle.tanh(y_s) + newir_program = ir.translate_to_new_ir(main_program.desc) + return newir_program + + +class TestBuildOp(unittest.TestCase): + def test_build_op(self): + newir_program = get_ir_program() + paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True}) + decompose(newir_program) + op_name_list = [op.name() for op in newir_program.block().get_ops()] + self.assertEqual( + op_name_list, + [ + 'builtin.get_parameter', + 'pd.matmul', + 'pd.add', + 'pd.full_int_array', + 'pd.sum', + 'pd.full', + 'pd.divide', + 'pd.tanh', + ], + ) + + +if __name__ == "__main__": + unittest.main() -- GitLab