From 5e140d8fbcaf63840c089fcd1430a2851a925d69 Mon Sep 17 00:00:00 2001 From: Xiaoxu Chen Date: Fri, 31 Mar 2023 19:51:06 +0800 Subject: [PATCH] Add overflow check in memory efficient attention implementation (#52191) (#52257) --- python/paddle/incubate/autograd/primx.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/python/paddle/incubate/autograd/primx.py b/python/paddle/incubate/autograd/primx.py index 5bc8e45260d..b9b7c90b17a 100644 --- a/python/paddle/incubate/autograd/primx.py +++ b/python/paddle/incubate/autograd/primx.py @@ -626,19 +626,31 @@ def _lower_composite( if start_idx > -1: idx_list = range(start_idx, length) + lower = lower_pre = False # Flag of routing to lower or copy branch # Step2: Process all ops in the target block for op_idx in range(length): op = block.ops[op_idx] ops_to_remove.append(op_idx) - op_name = op.type - do_comp = ( + + # NOTE: why need _sync_with_cpp here + # _sync_wich_cpp after every copied operator is very slow. + # However, _sync_wich_cpp only support continuous block currently. + # The lowering transformation will generate program which is + # crossed combination of copy block and lower block, such as + # op1(copy) -> op2(copy) -> op3(lower) -> op4(lower) -> op5(copy) -> op6(copy) + # It will cause _sync_wich_cpp error. + # So, _sync_with_cpp will be executed only once after every continuous copy block. + lower = ( (lookup_fn(op_name) is not None) and filter_(op) and op_idx in idx_list ) + if not lower_pre and lower: + block._sync_with_cpp() + lower_pre = lower - if do_comp: + if lower: change = True prim_config["composite_ops_record"].add(op_name) input_args = prepare_python_api_arguments(op) @@ -683,8 +695,8 @@ def _lower_composite( else: op_desc = block.desc.append_op() op_desc.copy_from(op.desc) - block._sync_with_cpp() + block._sync_with_cpp() # Step3: Do some post-processing work for op_idx in reversed(ops_to_remove): block.desc._remove_op(op_idx, op_idx + 1) -- GitLab