未验证 提交 5a3de29d 编写于 作者: C cyber-pioneer 提交者: GitHub

[Prim][NewIR] change decomposition to return new vars in New IR (#56391)

* change prim forward in New IR

* polish code

* polish code

* test case
上级 d38dde68
......@@ -45,7 +45,7 @@ def _prepare_python_api_arguments(op):
return tuple(api_arguments)
def _check_op_results(op_name, orig_outs, new_outs):
def _check_op_results(op_name, orig_outs, new_outs, orig_vars, dst_vars):
"""
Check whether the replaced outputs are consistent with origin outputs.
......@@ -53,6 +53,8 @@ def _check_op_results(op_name, orig_outs, new_outs):
op_name (str): The name of operator.
orig_outs (tuple): The outputs of original operator.
new_outs (tuple): The outputs of replaced operator.
orig_vars (dict): Origin variables of original block.
dst_vars (list): Corresponding replaced variables of Origin variables.
"""
assert len(orig_outs) == len(new_outs), (
f'when replace origin op {op_name} with composite rule, num of origin outs should be equal to new outs, '
......@@ -73,6 +75,8 @@ def _check_op_results(op_name, orig_outs, new_outs):
# to keep same as phi op definition, orig_out may receive None
continue
elif new_out is not None:
if orig_out in orig_vars.keys():
dst_vars[orig_vars[orig_out]] = new_out
orig_dtype = orig_out.dtype
new_dtype = new_out.dtype
orig_shape = orig_out.shape
......@@ -96,6 +100,7 @@ def _check_op_results(op_name, orig_outs, new_outs):
def decompose(
program,
src_vars,
blacklist=frozenset(),
whitelist=frozenset(),
):
......@@ -108,10 +113,17 @@ def decompose(
The finally set that will be decomposed is:
(block.ops & ops have decomposite rule & whitelist) - blacklist
Note:
All variables must be contained inside the given program.
Args:
program (Program): The program to be processed.
src_vars (list[OpResult]): In program, once some operator is decomposed, its vars will be replaced by new ones. This argument means some vars will be used later and corresponding vars will be returned for later usage.
blacklist (frozenset): The Operators that will be exclude when decomposed into primitives.
whitelist (frozenset): Only the operators in whitelist will be decomposed into primitives.
Returns:
dst_vars (list): A list contains all vars which replace origin ones in src_vars.
"""
if not isinstance(program, Program):
raise TypeError(f"Expect type Program, but got type {type(program)}.")
......@@ -140,25 +152,43 @@ def decompose(
op_filter = lambda x: x.name() in whitelist
else:
op_filter = lambda x: True
dst_vars = [None] * len(src_vars)
dst_vars_dct = {}
for idx, item in enumerate(src_vars):
if not isinstance(item, ir.OpResult):
raise TypeError(
f"Each var in dst_vars should map corresponding var in src_vars, but got type {type(item)} in {src_vars}."
)
dst_vars_dct[item] = idx
with ir.core.program_guard(program):
_decompose_subgraph(
block,
dst_vars_dct,
dst_vars,
op_filter,
)
for item in dst_vars:
if not isinstance(item, ir.OpResult):
raise TypeError(
f"Each var in dst_vars should map corresponding var in src_vars, but got type {type(item)} in {dst_vars}."
)
logging.debug(
"Decompose composite forward ops finish: {}".format(
core.prim_config["composite_ops_record"]
)
)
return dst_vars
def _decompose_subgraph(block, op_filter):
def _decompose_subgraph(block, orig_vars, dst_vars, op_filter):
"""
The operators in block wich satisfy the filter conditon will be decomposed into primitives.
Args:
block (Block|Sequence[Block]): The blocks of program to be processed.
op_filter (function): The filter to specify which ops to be processed.
orig_vars (dict): Origin variables of original block.
dst_vars (list): Corresponding replaced variables of Origin variables.
"""
if isinstance(block, Block):
......@@ -176,7 +206,9 @@ def _decompose_subgraph(block, op_filter):
new_outs = _build_tensor_tuple(decom_rule(*input_args))
# Todo: To cover such case: some outputs are no longer needed after decomposition.
_check_op_results(op_name, orig_outs, new_outs)
_check_op_results(
op_name, orig_outs, new_outs, orig_vars, dst_vars
)
op.replace_all_uses_with(new_outs)
block.remove_op(op)
......@@ -184,7 +216,7 @@ def _decompose_subgraph(block, op_filter):
elif isinstance(block, typing.Sequence):
for item in block:
_decompose_subgraph(item, op_filter)
_decompose_subgraph(item, orig_vars, dst_vars, op_filter)
return
raise TypeError(
f"Expect type Block or Sequence of Block, but got type {type(block)}"
......
......@@ -41,8 +41,14 @@ def get_ir_program():
class TestBuildOp(unittest.TestCase):
def test_build_op(self):
newir_program = get_ir_program()
y = newir_program.block().ops[-2].results()
orig_shape = y[0].shape
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
decompose(newir_program)
y_new = decompose(newir_program, y)
new_shape = y_new[0].shape
assert (
orig_shape == new_shape
), f"Original shape {orig_shape} is not equal to new shape {new_shape}"
op_name_list = [op.name() for op in newir_program.block().ops]
self.assertEqual(
op_name_list,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册