diff --git a/python/paddle/incubate/autograd/primapi.py b/python/paddle/incubate/autograd/primapi.py index 38dbd591baf9d7127f50aee28600522864e623b3..1ba95c7f5b28b33382515661c3e839a0d18b552f 100644 --- a/python/paddle/incubate/autograd/primapi.py +++ b/python/paddle/incubate/autograd/primapi.py @@ -217,7 +217,13 @@ def grad(outputs, inputs, grad_outputs=None): @framework.static_only -def to_prim(blocks, blacklist=frozenset(), whitelist=frozenset()): +def to_prim( + blocks, + blacklist=frozenset(), + whitelist=frozenset(), + start_idx=-1, + backward_length=-1, +): """Search nonbasic ops which have be registered composite rules and replace them with primitive ops. The operators in blacklist will be excluded from program when lowering into primitives, and only the operators in whitelist will be lowering. The priority of blacklist is higher than whitelist, it means @@ -229,6 +235,8 @@ def to_prim(blocks, blacklist=frozenset(), whitelist=frozenset()): Args: blacklist(frozenset): The Operators that will be exclude when lowering into primitives. whitelist(frozenset): Only the operators in whitelist will be lowering into primitives. + start_idx(int): If start_idx exceeds -1, ops[start_idx:] will be processed. Default: -1. + backward_length(int): If backward_length exceeds -1, ops[:-backward_length] will be processed. Default: -1. 
""" if not core._is_fwd_prim_enabled(): return @@ -258,7 +266,7 @@ def to_prim(blocks, blacklist=frozenset(), whitelist=frozenset()): blacklist = prim_config["forward_blacklist"] | blacklist with framework.program_guard(main_program): - print("Lowering composite forward ops begin...") + print("Lowering composite forward ops begin...", flush=True) if len(blacklist) > 0 and len(whitelist) > 0: filter_ = lambda x: x.type in whitelist and x.type not in blacklist @@ -268,6 +276,13 @@ def to_prim(blocks, blacklist=frozenset(), whitelist=frozenset()): filter_ = lambda x: x.type in whitelist else: filter_ = lambda x: True - primx._lower_composite(blocks, filter_) + primx._lower_composite( + blocks, + filter_, + start_idx=start_idx, + backward_length=backward_length, + ) replace_ops = prim_config["composite_ops_record"] - print(f"Lowering composite forward ops finish: {replace_ops}") + print( + f"Lowering composite forward ops finish: {replace_ops}", flush=True + ) diff --git a/python/paddle/incubate/autograd/primx.py b/python/paddle/incubate/autograd/primx.py index 352101a7fd1f6ce33aac4a48226d8254ad47e87d..ce3e8914adc43a039d8afce3b96fedbf6b345ff3 100644 --- a/python/paddle/incubate/autograd/primx.py +++ b/python/paddle/incubate/autograd/primx.py @@ -551,7 +551,10 @@ def _lower(block, reverse, blacklist): def _lower_composite( - block, filter_: typing.Callable[[framework.Operator], bool] = lambda x: True + block, + filter_: typing.Callable[[framework.Operator], bool] = lambda x: True, + start_idx=-1, + backward_length=-1, ): """The operators in block wich satisfy the filter conditon will be decomposite into primitives.""" @@ -602,13 +605,41 @@ def _lower_composite( none_vars_to_remove = set() change = None + + # Only process required sliced block + # If given start_idx, only ops[start_idx:] will be processed. + # If given backward_length, only ops[:-backward_length] will be processed. 
+ # Note, start_idx and backward_length cannot be both given, because the length of non-processed part must be kept unchanged. + length = len(block.ops) + idx_list = range(length) + assert ( + -1 <= backward_length <= length + ), f'expect -1 <= backward_length <= {length}, but got backward_length: {backward_length}' + assert ( + -1 <= start_idx <= length + ), f'expect -1 <= start_idx <= {length}, but got start_idx: {start_idx}' + assert not ( + backward_length > -1 and start_idx > -1 + ), f'got start_idx: {start_idx} and backward_length: {backward_length}' + if backward_length > -1: + idx_list = range(length - backward_length) + if start_idx > -1: + idx_list = range(start_idx, length) + # Step2: Process all ops in the target block - for op_idx in range(len(block.ops)): + for op_idx in range(length): op = block.ops[op_idx] ops_to_remove.append(op_idx) - if lookup_fn(op.type) is not None and filter_(op): + + op_name = op.type + do_comp = ( + (lookup_fn(op_name) is not None) + and filter_(op) + and op_idx in idx_list + ) + + if do_comp: change = True - op_name = op.type prim_config["composite_ops_record"].add(op_name) input_args = prepare_python_api_arguments(op) bind(input_args, to_bind, value_table) @@ -686,12 +717,22 @@ def _lower_composite( # composite ops may contain other composite ops, thus, call _lower_composite again. 
if change: - _lower_composite(block, filter_) + _lower_composite( + block, + filter_, + start_idx=start_idx, + backward_length=backward_length, + ) return elif isinstance(block, typing.Sequence): for item in block: - _lower_composite(item, filter_) + _lower_composite( + item, + filter_, + start_idx=start_idx, + backward_length=backward_length, + ) return else: raise TypeError diff --git a/python/paddle/jit/dy2static/program_translator.py b/python/paddle/jit/dy2static/program_translator.py index 942f1323c62da49006a8eba3700a7fe8dcf9d447..86d2075ab101d88de0e7990263a68ebcfb373d14 100644 --- a/python/paddle/jit/dy2static/program_translator.py +++ b/python/paddle/jit/dy2static/program_translator.py @@ -1266,8 +1266,12 @@ class PrimHooker(PartialProgramLayerHook): def after_append_backward(self, whole_program, backward_start_idx): backward_length = len(whole_program.block(0).ops) - backward_start_idx if core._is_fwd_prim_enabled() and len(self.custom_vjps) != 0: - _to_prim(whole_program.blocks, whitelist=self.custom_vjps) + # only process backward part of block + _to_prim(whole_program.blocks, backward_length=backward_length) new_start_index = len(whole_program.block(0).ops) - backward_length + if backward_length > 0: + # only process forward part of block + _to_prim(whole_program.blocks, start_idx=new_start_index) return whole_program, new_start_index def after_infer(self, infer_program): @@ -1693,9 +1697,21 @@ def enable_to_static(enable_to_static_bool): @switch_to_static_graph -def _to_prim(blocks, blacklist=frozenset(), whitelist=frozenset()): +def _to_prim( + blocks, + blacklist=frozenset(), + whitelist=frozenset(), + start_idx=-1, + backward_length=-1, +): """Swith to static graph and call to_prim.""" # TODO(Aurelius84): Fix this cycle import problem from paddle.incubate.autograd import primapi - primapi.to_prim(blocks, blacklist=blacklist, whitelist=whitelist) + primapi.to_prim( + blocks, + blacklist=blacklist, + whitelist=whitelist, + start_idx=start_idx, + 
backward_length=backward_length, + ) diff --git a/python/paddle/jit/dy2static/utils.py b/python/paddle/jit/dy2static/utils.py index 845e556a22a12609a488bc39ee849462b89fc85e..2f1a661a2f1d072fed1ba8be47a4ad5de4f57cac 100644 --- a/python/paddle/jit/dy2static/utils.py +++ b/python/paddle/jit/dy2static/utils.py @@ -1477,7 +1477,12 @@ def _out_grad_names(program_desc, fwd_end_op_index, out_size): min(fwd_end_op_index + out_size, program_desc.block(0).op_size()), ): op = program_desc.block(0).op(i) - if op.type() == 'fill_any_like': + # If prim forward op, fill_any_like will be decomposed into fill_constant. + if core._is_fwd_prim_enabled(): + target = ('fill_any_like', 'fill_constant') + else: + target = 'fill_any_like' + if op.type() in target: var_name = op.output('Out')[0] names.append(var_name) return names diff --git a/test/prim/composite_ops/test_composite_batch_norm.py b/test/prim/composite_ops/test_composite_batch_norm.py index 450e42a1a7b85854f2d20d94e76bfc2a7df1999e..39278663540fbf0594edbebaf745c836194903bd 100644 --- a/test/prim/composite_ops/test_composite_batch_norm.py +++ b/test/prim/composite_ops/test_composite_batch_norm.py @@ -216,7 +216,10 @@ def cal_static(inputs, running_mean, running_variance, weight, bias, mode=None): primapi.to_prim(blocks) fwd_ops_new = [op.type for op in blocks[0].ops] # Ensure that batch_norm is splitted into small ops - assert 'batch_norm' not in fwd_ops_new + assert ( + 'batch_norm' not in fwd_ops_new + and 'reduce_mean' not in fwd_ops_new + ) exe = paddle.static.Executor() exe.run(startup_program) diff --git a/test/prim/model/CMakeLists.txt b/test/prim/model/CMakeLists.txt index 9aab235837718921ca5c91e7c3e1e4992aae2277..53a109cf41d7a58a819fc33462d5ea12a9044c4a 100644 --- a/test/prim/model/CMakeLists.txt +++ b/test/prim/model/CMakeLists.txt @@ -8,7 +8,7 @@ foreach(TEST_OP ${TEST_OPS}) py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS}) endforeach() -set_tests_properties(test_resnet_prim_cinn PROPERTIES TIMEOUT 800) 
+set_tests_properties(test_resnet_prim_cinn PROPERTIES TIMEOUT 850) set_tests_properties(test_bert_prim_cinn PROPERTIES TIMEOUT 500) if(WITH_CINN) diff --git a/test/prim/test_comp_custom_vjp.py b/test/prim/test_comp_custom_vjp.py index 981a41caee3b26a97e1047e1c215271a6b8d4320..c658a3cda224b89a1b88c0c443c37c7363c48176 100644 --- a/test/prim/test_comp_custom_vjp.py +++ b/test/prim/test_comp_custom_vjp.py @@ -55,7 +55,7 @@ class TestCustomVJP(unittest.TestCase): 'elementwise_mul', 'scale', 'cast', - 'fill_any_like', + 'fill_constant', 'cast', 'elementwise_mul', 'scale',