未验证 提交 5e140d8f 编写于 作者: X Xiaoxu Chen 提交者: GitHub

Add overflow check in memory efficient attention implementation (#52191) (#52257)

上级 d05b73e4
......@@ -626,19 +626,31 @@ def _lower_composite(
if start_idx > -1:
idx_list = range(start_idx, length)
lower = lower_pre = False # Flag of routing to lower or copy branch
# Step2: Process all ops in the target block
for op_idx in range(length):
op = block.ops[op_idx]
ops_to_remove.append(op_idx)
op_name = op.type
do_comp = (
# NOTE: why _sync_with_cpp is needed here
# Calling _sync_with_cpp after every copied operator is very slow.
# However, _sync_with_cpp currently only supports contiguous blocks.
# The lowering transformation generates a program that is an
# interleaved combination of copy blocks and lower blocks, such as
# op1(copy) -> op2(copy) -> op3(lower) -> op4(lower) -> op5(copy) -> op6(copy)
# This would cause a _sync_with_cpp error.
# So, _sync_with_cpp is executed only once after every contiguous copy block.
lower = (
(lookup_fn(op_name) is not None)
and filter_(op)
and op_idx in idx_list
)
if not lower_pre and lower:
block._sync_with_cpp()
lower_pre = lower
if do_comp:
if lower:
change = True
prim_config["composite_ops_record"].add(op_name)
input_args = prepare_python_api_arguments(op)
......@@ -683,8 +695,8 @@ def _lower_composite(
else:
op_desc = block.desc.append_op()
op_desc.copy_from(op.desc)
block._sync_with_cpp()
block._sync_with_cpp()
# Step3: Do some post-processing work
for op_idx in reversed(ops_to_remove):
block.desc._remove_op(op_idx, op_idx + 1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册