From a4e0f666fa2700cc50435c4b67e2eaef5fd1b3c8 Mon Sep 17 00:00:00 2001
From: cyber-pioneer <116002591+cyber-pioneer@users.noreply.github.com>
Date: Thu, 30 Mar 2023 22:41:33 +0800
Subject: [PATCH] [Prim] fix loss of composite rule (#52120)

* fix_prim
* fix bug
* add note
* fix logic
* fix
* add note
* fix check
* fix bug
* fix bug
* fix bug
* add debug
* fix check
* fix bug
* sync print log
* fix test case
* change default
* change test case time
---
 python/paddle/incubate/autograd/primapi.py    | 23 ++++++--
 python/paddle/incubate/autograd/primx.py      | 53 ++++++++++++++++---
 .../jit/dy2static/program_translator.py       | 22 ++++++--
 python/paddle/jit/dy2static/utils.py          |  7 ++-
 .../test_composite_batch_norm.py              |  5 +-
 test/prim/model/CMakeLists.txt                |  2 +-
 test/prim/test_comp_custom_vjp.py             |  2 +-
 7 files changed, 97 insertions(+), 17 deletions(-)

diff --git a/python/paddle/incubate/autograd/primapi.py b/python/paddle/incubate/autograd/primapi.py
index 38dbd591baf..1ba95c7f5b2 100644
--- a/python/paddle/incubate/autograd/primapi.py
+++ b/python/paddle/incubate/autograd/primapi.py
@@ -217,7 +217,13 @@ def grad(outputs, inputs, grad_outputs=None):
 
 
 @framework.static_only
-def to_prim(blocks, blacklist=frozenset(), whitelist=frozenset()):
+def to_prim(
+    blocks,
+    blacklist=frozenset(),
+    whitelist=frozenset(),
+    start_idx=-1,
+    backward_length=-1,
+):
     """Search nonbasic ops which have been registered with composite rules and replace them with primitive ops.
     The operators in blacklist will be excluded from the program when lowering into primitives, and only the
     operators in whitelist will be lowered. The priority of blacklist is higher than whitelist, which means
@@ -229,6 +235,8 @@ def to_prim(blocks, blacklist=frozenset(), whitelist=frozenset()):
     Args:
         blacklist(frozenset): The operators that will be excluded when lowering into primitives.
         whitelist(frozenset): Only the operators in whitelist will be lowered into primitives.
+        start_idx(int): If start_idx is greater than -1, only ops[start_idx:] will be processed. Default: -1.
+        backward_length(int): If backward_length is greater than -1, only ops[:-backward_length] will be processed. Default: -1.
""" if not core._is_fwd_prim_enabled(): return @@ -258,7 +266,7 @@ def to_prim(blocks, blacklist=frozenset(), whitelist=frozenset()): blacklist = prim_config["forward_blacklist"] | blacklist with framework.program_guard(main_program): - print("Lowering composite forward ops begin...") + print("Lowering composite forward ops begin...", flush=True) if len(blacklist) > 0 and len(whitelist) > 0: filter_ = lambda x: x.type in whitelist and x.type not in blacklist @@ -268,6 +276,13 @@ def to_prim(blocks, blacklist=frozenset(), whitelist=frozenset()): filter_ = lambda x: x.type in whitelist else: filter_ = lambda x: True - primx._lower_composite(blocks, filter_) + primx._lower_composite( + blocks, + filter_, + start_idx=start_idx, + backward_length=backward_length, + ) replace_ops = prim_config["composite_ops_record"] - print(f"Lowering composite forward ops finish: {replace_ops}") + print( + f"Lowering composite forward ops finish: {replace_ops}", flush=True + ) diff --git a/python/paddle/incubate/autograd/primx.py b/python/paddle/incubate/autograd/primx.py index 352101a7fd1..ce3e8914adc 100644 --- a/python/paddle/incubate/autograd/primx.py +++ b/python/paddle/incubate/autograd/primx.py @@ -551,7 +551,10 @@ def _lower(block, reverse, blacklist): def _lower_composite( - block, filter_: typing.Callable[[framework.Operator], bool] = lambda x: True + block, + filter_: typing.Callable[[framework.Operator], bool] = lambda x: True, + start_idx=-1, + backward_length=-1, ): """The operators in block wich satisfy the filter conditon will be decomposite into primitives.""" @@ -602,13 +605,41 @@ def _lower_composite( none_vars_to_remove = set() change = None + + # Only process required sliced block + # If given start_idx, only ops[start_idx:] will be processed. + # If given backward_length, only ops[:-backward_length] will be processed. + # Note, start_idx and backward_length cannot be both given, because the length of non-processed part must be kept unchanged. + length = len(block.ops) + idx_list = range(length) + assert ( + -1 <= backward_length <= length + ), f'expect -1 <= backward_length <= {length}, but got backward_length: {backward_length}' + assert ( + -1 <= start_idx <= length + ), f'expect -1 <= start_idx <= {length}, but got start_idx: {start_idx}' + assert not ( + backward_length > -1 and start_idx > -1 + ), f'got start_idx: {start_idx} and backward_length: {backward_length}' + if backward_length > -1: + idx_list = range(length - backward_length) + if start_idx > -1: + idx_list = range(start_idx, length) + # Step2: Process all ops in the target block - for op_idx in range(len(block.ops)): + for op_idx in range(length): op = block.ops[op_idx] ops_to_remove.append(op_idx) - if lookup_fn(op.type) is not None and filter_(op): + + op_name = op.type + do_comp = ( + (lookup_fn(op_name) is not None) + and filter_(op) + and op_idx in idx_list + ) + + if do_comp: change = True - op_name = op.type prim_config["composite_ops_record"].add(op_name) input_args = prepare_python_api_arguments(op) bind(input_args, to_bind, value_table) @@ -686,12 +717,22 @@ def _lower_composite( # composite ops may contain other composite ops, thus, call _lower_composite again. 
     if change:
-        _lower_composite(block, filter_)
+        _lower_composite(
+            block,
+            filter_,
+            start_idx=start_idx,
+            backward_length=backward_length,
+        )
         return
+
     elif isinstance(block, typing.Sequence):
         for item in block:
-            _lower_composite(item, filter_)
+            _lower_composite(
+                item,
+                filter_,
+                start_idx=start_idx,
+                backward_length=backward_length,
+            )
         return
     else:
         raise TypeError
diff --git a/python/paddle/jit/dy2static/program_translator.py b/python/paddle/jit/dy2static/program_translator.py
index 942f1323c62..86d2075ab10 100644
--- a/python/paddle/jit/dy2static/program_translator.py
+++ b/python/paddle/jit/dy2static/program_translator.py
@@ -1266,8 +1266,12 @@ class PrimHooker(PartialProgramLayerHook):
     def after_append_backward(self, whole_program, backward_start_idx):
         backward_length = len(whole_program.block(0).ops) - backward_start_idx
         if core._is_fwd_prim_enabled() and len(self.custom_vjps) != 0:
-            _to_prim(whole_program.blocks, whitelist=self.custom_vjps)
+            # only process the forward part of the block, so that the length
+            # of the untouched backward part stays unchanged
+            _to_prim(whole_program.blocks, backward_length=backward_length)
         new_start_index = len(whole_program.block(0).ops) - backward_length
+        if backward_length > 0:
+            # only process the backward part of the block
+            _to_prim(whole_program.blocks, start_idx=new_start_index)
         return whole_program, new_start_index
 
     def after_infer(self, infer_program):
@@ -1693,9 +1697,21 @@ def enable_to_static(enable_to_static_bool):
 
 
 @switch_to_static_graph
-def _to_prim(blocks, blacklist=frozenset(), whitelist=frozenset()):
+def _to_prim(
+    blocks,
+    blacklist=frozenset(),
+    whitelist=frozenset(),
+    start_idx=-1,
+    backward_length=-1,
+):
     """Switch to static graph and call to_prim."""
     # TODO(Aurelius84): Fix this cycle import problem
     from paddle.incubate.autograd import primapi
 
-    primapi.to_prim(blocks, blacklist=blacklist, whitelist=whitelist)
+    primapi.to_prim(
+        blocks,
+        blacklist=blacklist,
+        whitelist=whitelist,
+        start_idx=start_idx,
+        backward_length=backward_length,
+    )
diff --git a/python/paddle/jit/dy2static/utils.py b/python/paddle/jit/dy2static/utils.py
index 845e556a22a..2f1a661a2f1 100644
--- a/python/paddle/jit/dy2static/utils.py
+++ b/python/paddle/jit/dy2static/utils.py
@@ -1477,7 +1477,12 @@ def _out_grad_names(program_desc, fwd_end_op_index, out_size):
         min(fwd_end_op_index + out_size, program_desc.block(0).op_size()),
     ):
         op = program_desc.block(0).op(i)
-        if op.type() == 'fill_any_like':
+        # When prim forward is enabled, fill_any_like is decomposed into fill_constant.
+        if core._is_fwd_prim_enabled():
+            target = ('fill_any_like', 'fill_constant')
+        else:
+            target = ('fill_any_like',)
+        if op.type() in target:
             var_name = op.output('Out')[0]
             names.append(var_name)
     return names
diff --git a/test/prim/composite_ops/test_composite_batch_norm.py b/test/prim/composite_ops/test_composite_batch_norm.py
index 450e42a1a7b..39278663540 100644
--- a/test/prim/composite_ops/test_composite_batch_norm.py
+++ b/test/prim/composite_ops/test_composite_batch_norm.py
@@ -216,7 +216,10 @@ def cal_static(inputs, running_mean, running_variance, weight, bias, mode=None):
             primapi.to_prim(blocks)
             fwd_ops_new = [op.type for op in blocks[0].ops]
             # Ensure that batch_norm is split into small ops
-            assert 'batch_norm' not in fwd_ops_new
+            assert (
+                'batch_norm' not in fwd_ops_new
+                and 'reduce_mean' not in fwd_ops_new
+            )
 
         exe = paddle.static.Executor()
         exe.run(startup_program)
diff --git a/test/prim/model/CMakeLists.txt b/test/prim/model/CMakeLists.txt
index 9aab2358377..53a109cf41d 100644
--- a/test/prim/model/CMakeLists.txt
+++ b/test/prim/model/CMakeLists.txt
@@ -8,7 +8,7 @@ foreach(TEST_OP ${TEST_OPS})
   py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
 endforeach()
 
-set_tests_properties(test_resnet_prim_cinn PROPERTIES TIMEOUT 800)
+set_tests_properties(test_resnet_prim_cinn PROPERTIES TIMEOUT 850)
 set_tests_properties(test_bert_prim_cinn PROPERTIES TIMEOUT 500)
 
 if(WITH_CINN)
diff --git a/test/prim/test_comp_custom_vjp.py b/test/prim/test_comp_custom_vjp.py
index 981a41caee3..c658a3cda22 100644
--- a/test/prim/test_comp_custom_vjp.py
+++ b/test/prim/test_comp_custom_vjp.py
@@ -55,7 +55,7 @@ class TestCustomVJP(unittest.TestCase):
             'elementwise_mul',
             'scale',
             'cast',
-            'fill_any_like',
+            'fill_constant',
             'cast',
             'elementwise_mul',
             'scale',
-- 
GitLab
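
Note on the new slicing parameters: how start_idx and backward_length divide
the block in PrimHooker.after_append_backward can be illustrated with a short
self-contained sketch. This is a hypothetical helper, not code from this
patch; select_op_indices and num_ops are illustrative names that mirror the
idx_list computation added to _lower_composite above.

    # A minimal standalone model of the op-index selection rule, assuming
    # num_ops stands in for len(block.ops).
    def select_op_indices(num_ops, start_idx=-1, backward_length=-1):
        assert -1 <= backward_length <= num_ops
        assert -1 <= start_idx <= num_ops
        # The two knobs are mutually exclusive: the non-processed part of the
        # block must keep its op count so later index bookkeeping stays valid.
        assert not (backward_length > -1 and start_idx > -1)
        if backward_length > -1:
            # skip the last backward_length ops (the appended backward part)
            return range(num_ops - backward_length)
        if start_idx > -1:
            # process only the tail starting at start_idx (the backward part)
            return range(start_idx, num_ops)
        return range(num_ops)  # default: process the whole block

    # With 10 ops in total, of which the last 4 were appended by backward:
    print(list(select_op_indices(10, backward_length=4)))  # forward ops [0..5]
    print(list(select_op_indices(10, start_idx=6)))        # backward ops [6..9]

Keeping the two parameters mutually exclusive is what makes the two-pass
lowering in after_append_backward safe: the first pass may change the op count
of the forward part, but the untouched backward part keeps its length, so
new_start_index can still be computed from backward_length.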