From 5a3de29d152f3dbe14fad62ab4bb34ef224574ae Mon Sep 17 00:00:00 2001 From: cyber-pioneer <116002591+cyber-pioneer@users.noreply.github.com> Date: Mon, 21 Aug 2023 19:51:44 +0800 Subject: [PATCH] [Prim][NewIR] change decomposition to return new vars in New IR (#56391) * change prim forward in New IR * polish code * polish code * test case --- python/paddle/decomposition/decomp.py | 40 ++++++++++++++++++++++--- test/prim/new_ir_prim/test_decomp_op.py | 8 ++++- 2 files changed, 43 insertions(+), 5 deletions(-) diff --git a/python/paddle/decomposition/decomp.py b/python/paddle/decomposition/decomp.py index 9d8fa8a05b2..9bd288dacd4 100644 --- a/python/paddle/decomposition/decomp.py +++ b/python/paddle/decomposition/decomp.py @@ -45,7 +45,7 @@ def _prepare_python_api_arguments(op): return tuple(api_arguments) -def _check_op_results(op_name, orig_outs, new_outs): +def _check_op_results(op_name, orig_outs, new_outs, orig_vars, dst_vars): """ Check whether the replaced outputs are consistent with origin outputs. @@ -53,6 +53,8 @@ def _check_op_results(op_name, orig_outs, new_outs): op_name (str): The name of operator. orig_outs (tuple): The outputs of original operator. new_outs (tuple): The outputs of replaced operator. + orig_vars (dict): Origin variables of original block. + dst_vars (list): Corresponding replaced variables of Origin variables. """ assert len(orig_outs) == len(new_outs), ( f'when replace origin op {op_name} with composite rule, num of origin outs should be equal to new outs, ' @@ -73,6 +75,8 @@ def _check_op_results(op_name, orig_outs, new_outs): # to keep same as phi op definition, orig_out may receive None continue elif new_out is not None: + if orig_out in orig_vars.keys(): + dst_vars[orig_vars[orig_out]] = new_out orig_dtype = orig_out.dtype new_dtype = new_out.dtype orig_shape = orig_out.shape @@ -96,6 +100,7 @@ def _check_op_results(op_name, orig_outs, new_outs): def decompose( program, + src_vars, blacklist=frozenset(), whitelist=frozenset(), ): @@ -108,10 +113,17 @@ def decompose( The finally set that will be decomposed is: (block.ops & ops have decomposite rule & whitelist) - blacklist + Note: + All variables must be contained inside the given program. + Args: program (Program): The program to be processed. + src_vars (list[OpResult]): In program, once some operator is decomposed, its vars will be replaced by new ones. This argument means some vars will be used later and corresponding vars will be returned for later usage. blacklist (frozenset): The Operators that will be exclude when decomposed into primitives. whitelist (frozenset): Only the operators in whitelist will be decomposed into primitives. + + Returns: + dst_vars (list): A list contains all vars which replace origin ones in src_vars. """ if not isinstance(program, Program): raise TypeError(f"Expect type Program, but got type {type(program)}.") @@ -140,25 +152,43 @@ def decompose( op_filter = lambda x: x.name() in whitelist else: op_filter = lambda x: True + dst_vars = [None] * len(src_vars) + dst_vars_dct = {} + for idx, item in enumerate(src_vars): + if not isinstance(item, ir.OpResult): + raise TypeError( + f"Each var in dst_vars should map corresponding var in src_vars, but got type {type(item)} in {src_vars}." + ) + dst_vars_dct[item] = idx with ir.core.program_guard(program): _decompose_subgraph( block, + dst_vars_dct, + dst_vars, op_filter, ) + for item in dst_vars: + if not isinstance(item, ir.OpResult): + raise TypeError( + f"Each var in dst_vars should map corresponding var in src_vars, but got type {type(item)} in {dst_vars}." + ) logging.debug( "Decompose composite forward ops finish: {}".format( core.prim_config["composite_ops_record"] ) ) + return dst_vars -def _decompose_subgraph(block, op_filter): +def _decompose_subgraph(block, orig_vars, dst_vars, op_filter): """ The operators in block wich satisfy the filter conditon will be decomposed into primitives. Args: block (Block|Sequence[Block]): The blocks of program to be processed. op_filter (function): The filter to specify which ops to be processed. + orig_vars (dict): Origin variables of original block. + dst_vars (list): Corresponding replaced variables of Origin variables. """ if isinstance(block, Block): @@ -176,7 +206,9 @@ def _decompose_subgraph(block, op_filter): new_outs = _build_tensor_tuple(decom_rule(*input_args)) # Todo: To cover such case: some outputs are no longer needed after decomposition. - _check_op_results(op_name, orig_outs, new_outs) + _check_op_results( + op_name, orig_outs, new_outs, orig_vars, dst_vars + ) op.replace_all_uses_with(new_outs) block.remove_op(op) @@ -184,7 +216,7 @@ def _decompose_subgraph(block, op_filter): elif isinstance(block, typing.Sequence): for item in block: - _decompose_subgraph(item, op_filter) + _decompose_subgraph(item, orig_vars, dst_vars, op_filter) return raise TypeError( f"Expect type Block or Sequence of Block, but got type {type(block)}" diff --git a/test/prim/new_ir_prim/test_decomp_op.py b/test/prim/new_ir_prim/test_decomp_op.py index ecf9c859079..f90e0fe2439 100644 --- a/test/prim/new_ir_prim/test_decomp_op.py +++ b/test/prim/new_ir_prim/test_decomp_op.py @@ -41,8 +41,14 @@ def get_ir_program(): class TestBuildOp(unittest.TestCase): def test_build_op(self): newir_program = get_ir_program() + y = newir_program.block().ops[-2].results() + orig_shape = y[0].shape paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True}) - decompose(newir_program) + y_new = decompose(newir_program, y) + new_shape = y_new[0].shape + assert ( + orig_shape == new_shape + ), f"Original shape {orig_shape} is not equal to new shape {new_shape}" op_name_list = [op.name() for op in newir_program.block().ops] self.assertEqual( op_name_list, -- GitLab