# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import typing

from paddle import ir
from paddle.fluid.libpaddle.ir import Block, Program
from paddle.framework import core

from . import register


def _build_tensor_tuple(xs):
    if isinstance(xs, ir.OpResult):
        return (xs,)
    elif isinstance(xs, typing.Sequence):
        return tuple(xs)
    raise TypeError(f"Type {type(xs)} is not supported.")


def _prepare_python_api_arguments(op):
    """
    For an operator's standard Python API, the arguments must stay consistent
    with the organization of the operator's inputs and attrs.

    Args:
        op (Operator): The target operator.
    """
    op_inputs = [x.source() for x in op.operands()]
    op_attrs_dict = op.attrs()
    op_attrs_name = op.get_attr_names()
    op_attrs = [op_attrs_dict[x] for x in op_attrs_name]
    api_arguments = op_inputs + op_attrs
    return tuple(api_arguments)
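
# A rough illustration of the ordering produced above (the op and attr names
# are hypothetical, not taken from this codebase): for an op whose operand
# sources are (x,) and whose attr names are ["p", "is_test", "mode"], the
# returned tuple would be (x, attrs["p"], attrs["is_test"], attrs["mode"]),
# i.e. inputs first, then attrs in get_attr_names() order.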


def _check_op_results(op_name, orig_outs, new_outs, orig_vars, dst_vars):
    """
    Check whether the replaced outputs are consistent with origin outputs.

    Args:
        op_name (str): The name of the operator.
        orig_outs (tuple): The outputs of the original operator.
        new_outs (tuple): The outputs of the replacement (composite) operator.
        orig_vars (dict): Mapping from origin variables of the original block to their indices in dst_vars.
        dst_vars (list): The replaced variables corresponding to the origin variables.
    """
    assert len(orig_outs) == len(new_outs), (
        f'when replace origin op {op_name} with composite rule, num of origin outs should be equal to new outs, '
        f'but len(orig_outs) = {len(orig_outs)} and len(new_outs) = {len(new_outs)}'
    )

    for orig_out, new_out in zip(
        orig_outs,
        new_outs,
    ):
        if (orig_out is None or new_out is None) and (
            op_name not in core.ops_contain_none
        ):
            raise ValueError(
                f"op {op_name} should not contain any None value. original outs={orig_outs} and its composite rule outs={new_outs}"
            )
        if orig_out is None:
            # to keep same as phi op definition, orig_out may receive None
            continue
        elif new_out is not None:
            if orig_out in orig_vars.keys():
                dst_vars[orig_vars[orig_out]] = new_out
            orig_dtype = orig_out.dtype
            new_dtype = new_out.dtype
            orig_shape = orig_out.shape
            new_shape = new_out.shape
            assert orig_dtype == new_dtype, (
                f'when replace origin op {op_name} with composite rule, origin out dtype should be equal to new out dtype, '
                f'but orig_out dtype={orig_dtype} and new_out dtype={new_dtype}'
            )
            assert (
                -1 not in new_shape
            ), f'when replace origin op {op_name} with composite rule, composite out shape has -1.'
            assert orig_shape == new_shape, (
                f'when replace origin op {op_name} with composite rule, origin out shape should be equal to new out shape, '
                f'but orig_out shape={orig_shape} and new_out shape={new_shape}'
            )
        else:
            assert not (orig_out is None) ^ (
                new_out is None
            ), "orig_out and new_out should match."
    return
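
# Example of the contract enforced above (values are hypothetical): if the
# original op produced an OpResult with dtype float32 and shape [2, 3], the
# composite rule must produce an OpResult with exactly dtype float32 and shape
# [2, 3]; a -1 (dynamic) dimension in the composite output is rejected.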


def decompose(
    program,
    src_vars,
    blacklist=frozenset(),
    whitelist=frozenset(),
):
    """
    Search for nonbasic ops that have registered composite rules and replace them with primitive ops.
    The operators in blacklist will be excluded from the program when decomposed into primitives, and
    only the operators in whitelist will be decomposed. The blacklist takes priority over the whitelist:
    an operator that appears in both will not be decomposed.

    The final set of ops that will be decomposed is:
        (block.ops & ops that have a decomposition rule & whitelist) - blacklist

    Note:
        All variables must be contained inside the given program.

    Args:
        program (Program): The program to be processed.
        src_vars (list[OpResult]): Once an operator in the program is decomposed, its output vars are replaced with new ones. src_vars lists the vars that will be used later; the corresponding (possibly replaced) vars are returned for later usage.
        blacklist (frozenset): The operators that will be excluded when decomposing into primitives.
        whitelist (frozenset): Only the operators in whitelist will be decomposed into primitives.

    Returns:
        dst_vars (list): A list containing the vars that replace the original ones in src_vars.
    """
    if not core._is_fwd_prim_enabled():
        return src_vars
    if not isinstance(program, Program):
        raise TypeError(f"Expect type Program, but got type {type(program)}.")
    block = program.block()

    if not isinstance(blacklist, (set, frozenset)):
        raise TypeError(
            f'Expected type of blacklist is set|frozenset, but got {type(blacklist)}.'
        )
    if not isinstance(whitelist, (set, frozenset)):
        raise TypeError(
            f'Expected type of whitelist is set|frozenset, but got {type(whitelist)}.'
        )

    blacklist = core.prim_config["forward_blacklist"] | blacklist

    logging.debug("Decompose composite forward ops begin...")

    if len(blacklist) > 0 and len(whitelist) > 0:
        op_filter = (
            lambda x: x.name() in whitelist and x.name() not in blacklist
        )
    elif len(blacklist) > 0 and len(whitelist) == 0:
        op_filter = lambda x: x.name() not in blacklist
    elif len(blacklist) == 0 and len(whitelist) > 0:
        op_filter = lambda x: x.name() in whitelist
    else:
        op_filter = lambda x: True
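
    # Illustration of the filter above (op names are hypothetical, for
    # exposition only): with blacklist={"pd.dropout"} and
    # whitelist={"pd.gelu", "pd.dropout"}, op_filter accepts only ops named
    # "pd.gelu"; "pd.dropout" is rejected because the blacklist takes
    # priority over the whitelist.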
    dst_vars = [None] * len(src_vars)
    dst_vars_dct = {}
    for idx, item in enumerate(src_vars):
        if not isinstance(item, ir.OpResult):
            raise TypeError(
                f"Each var in dst_vars should map corresponding var in src_vars, but got type {type(item)} in {src_vars}."
            )
        dst_vars_dct[item] = idx
    with ir.core.program_guard(program):
        _decompose_subgraph(
            block,
            dst_vars_dct,
            dst_vars,
            op_filter,
        )
    for item in dst_vars:
        if not isinstance(item, ir.OpResult):
            raise TypeError(
                f"Each var in dst_vars should map corresponding var in src_vars, but got type {type(item)} in {dst_vars}."
            )
    logging.debug(
        "Decompose composite forward ops finish: {}".format(
            core.prim_config["composite_ops_record"]
        )
    )
    return dst_vars
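
# Illustrative usage of decompose (a sketch with hypothetical names; not
# executed as part of this module):
#
#     # `prog` is a Program built with the new IR and `outs` are the
#     # OpResults that are still needed afterwards.
#     new_outs = decompose(prog, outs, blacklist=frozenset({"pd.dropout"}))
#     # Use `new_outs` in place of `outs`: ops with registered decomposition
#     # rules (except "pd.dropout") have been replaced inside `prog`.
#     # (Assumes the forward prim flag is enabled; otherwise src_vars is
#     # returned unchanged.)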


def _decompose_subgraph(block, orig_vars, dst_vars, op_filter):
    """
    The operators in the block which satisfy the filter condition will be decomposed into primitives.

    Args:
        block (Block|Sequence[Block]): The block(s) of the program to be processed.
        op_filter (function): The filter that specifies which ops are to be processed.
        orig_vars (dict): Mapping from origin variables of the original block to their indices in dst_vars.
        dst_vars (list): The replaced variables corresponding to the origin variables.
    """

    if isinstance(block, Block):
        ops_list = block.ops
        for op in ops_list:
            op_name = op.name()
            decom_rule = register.get_decomp_rule(op_name)
            lower = decom_rule and op_filter(op)

            if lower:
                core.prim_config["composite_ops_record"].add(op_name)
                input_args = _prepare_python_api_arguments(op)
                ir.set_insertion_point(op)
                orig_outs = op.results()
                new_outs = _build_tensor_tuple(decom_rule(*input_args))

                # TODO: cover the case where some outputs are no longer needed after decomposition.
                _check_op_results(
                    op_name, orig_outs, new_outs, orig_vars, dst_vars
                )

                op.replace_all_uses_with(new_outs)
                block.remove_op(op)
        return

    elif isinstance(block, typing.Sequence):
        for item in block:
            _decompose_subgraph(item, orig_vars, dst_vars, op_filter)
        return
    raise TypeError(
        f"Expect type Block or Sequence of Block, but got type {type(block)}"
    )