未验证 提交 a4e0f666 编写于 作者: C cyber-pioneer 提交者: GitHub

[Prim] fix loss of composite rule (#52120)

* fix_prim

* fix bug

* add note

* fix logic

* fix

* add note

* fix check

* fix bug

* fix bug

* fix bug

* add debug

* fix check

* fix bug

* sync print log

* fix test case

* change default

* change test case time
上级 49461a02
......@@ -217,7 +217,13 @@ def grad(outputs, inputs, grad_outputs=None):
@framework.static_only
def to_prim(blocks, blacklist=frozenset(), whitelist=frozenset()):
def to_prim(
blocks,
blacklist=frozenset(),
whitelist=frozenset(),
start_idx=-1,
backward_length=-1,
):
"""Search nonbasic ops which have be registered composite rules and replace them with primitive ops.
The operators in blacklist will be excluded from program when lowering into primitives, and only the
operators in whitelist will be lowering. The priority of blacklist is higher than whitelist, it means
......@@ -229,6 +235,8 @@ def to_prim(blocks, blacklist=frozenset(), whitelist=frozenset()):
Args:
blacklist(frozenset): The Operators that will be exclude when lowering into primitives.
whitelist(frozenset): Only the operators in whitelist will be lowering into primitives.
start_idx(int): If start_idx exceeds -1, ops[start_idx:] will be processed. Default: -1.
backward_length(int): If backward_length exceeds -1, ops[:-backward_length] will be processed. Default: -1.
"""
if not core._is_fwd_prim_enabled():
return
......@@ -258,7 +266,7 @@ def to_prim(blocks, blacklist=frozenset(), whitelist=frozenset()):
blacklist = prim_config["forward_blacklist"] | blacklist
with framework.program_guard(main_program):
print("Lowering composite forward ops begin...")
print("Lowering composite forward ops begin...", flush=True)
if len(blacklist) > 0 and len(whitelist) > 0:
filter_ = lambda x: x.type in whitelist and x.type not in blacklist
......@@ -268,6 +276,13 @@ def to_prim(blocks, blacklist=frozenset(), whitelist=frozenset()):
filter_ = lambda x: x.type in whitelist
else:
filter_ = lambda x: True
primx._lower_composite(blocks, filter_)
primx._lower_composite(
blocks,
filter_,
start_idx=start_idx,
backward_length=backward_length,
)
replace_ops = prim_config["composite_ops_record"]
print(f"Lowering composite forward ops finish: {replace_ops}")
print(
f"Lowering composite forward ops finish: {replace_ops}", flush=True
)
......@@ -551,7 +551,10 @@ def _lower(block, reverse, blacklist):
def _lower_composite(
block, filter_: typing.Callable[[framework.Operator], bool] = lambda x: True
block,
filter_: typing.Callable[[framework.Operator], bool] = lambda x: True,
start_idx=-1,
backward_length=-1,
):
"""The operators in block wich satisfy the filter conditon will be decomposite into primitives."""
......@@ -602,13 +605,41 @@ def _lower_composite(
none_vars_to_remove = set()
change = None
# Only process required sliced block
# If given start_idx, only ops[start_idx:] will be processed.
# If given backward_length, only ops[:-backward_length] will be processed.
# Note, start_idx and backward_length cannot be both given, because the length of non-processed part must be kept unchanged.
length = len(block.ops)
idx_list = range(length)
assert (
-1 <= backward_length <= length
), f'expect -1 <= backward_length <= {length}, but got backward_length: {backward_length}'
assert (
-1 <= start_idx <= length
), f'expect -1 <= start_idx <= {length}, but got start_idx: {start_idx}'
assert not (
backward_length > -1 and start_idx > -1
), f'got start_idx: {start_idx} and backward_length: {backward_length}'
if backward_length > -1:
idx_list = range(length - backward_length)
if start_idx > -1:
idx_list = range(start_idx, length)
# Step2: Process all ops in the target block
for op_idx in range(len(block.ops)):
for op_idx in range(length):
op = block.ops[op_idx]
ops_to_remove.append(op_idx)
if lookup_fn(op.type) is not None and filter_(op):
op_name = op.type
do_comp = (
(lookup_fn(op_name) is not None)
and filter_(op)
and op_idx in idx_list
)
if do_comp:
change = True
op_name = op.type
prim_config["composite_ops_record"].add(op_name)
input_args = prepare_python_api_arguments(op)
bind(input_args, to_bind, value_table)
......@@ -686,12 +717,22 @@ def _lower_composite(
# composite ops may contain other composite ops, thus, call _lower_composite again.
if change:
_lower_composite(block, filter_)
_lower_composite(
block,
filter_,
start_idx=start_idx,
backward_length=backward_length,
)
return
elif isinstance(block, typing.Sequence):
for item in block:
_lower_composite(item, filter_)
_lower_composite(
item,
filter_,
start_idx=start_idx,
backward_length=backward_length,
)
return
else:
raise TypeError
......
......@@ -1266,8 +1266,12 @@ class PrimHooker(PartialProgramLayerHook):
def after_append_backward(self, whole_program, backward_start_idx):
backward_length = len(whole_program.block(0).ops) - backward_start_idx
if core._is_fwd_prim_enabled() and len(self.custom_vjps) != 0:
_to_prim(whole_program.blocks, whitelist=self.custom_vjps)
# only process backward part of block
_to_prim(whole_program.blocks, backward_length=backward_length)
new_start_index = len(whole_program.block(0).ops) - backward_length
if backward_length > 0:
# only process forward part of block
_to_prim(whole_program.blocks, start_idx=new_start_index)
return whole_program, new_start_index
def after_infer(self, infer_program):
......@@ -1693,9 +1697,21 @@ def enable_to_static(enable_to_static_bool):
@switch_to_static_graph
def _to_prim(blocks, blacklist=frozenset(), whitelist=frozenset()):
def _to_prim(
blocks,
blacklist=frozenset(),
whitelist=frozenset(),
start_idx=-1,
backward_length=-1,
):
"""Swith to static graph and call to_prim."""
# TODO(Aurelius84): Fix this cycle import problem
from paddle.incubate.autograd import primapi
primapi.to_prim(blocks, blacklist=blacklist, whitelist=whitelist)
primapi.to_prim(
blocks,
blacklist=blacklist,
whitelist=whitelist,
start_idx=start_idx,
backward_length=backward_length,
)
......@@ -1477,7 +1477,12 @@ def _out_grad_names(program_desc, fwd_end_op_index, out_size):
min(fwd_end_op_index + out_size, program_desc.block(0).op_size()),
):
op = program_desc.block(0).op(i)
if op.type() == 'fill_any_like':
# If prim forward op, fill_any_like will be decomposite as fill_constant.
if core._is_fwd_prim_enabled():
target = ('fill_any_like', 'fill_constant')
else:
target = 'fill_any_like'
if op.type() in target:
var_name = op.output('Out')[0]
names.append(var_name)
return names
......
......@@ -216,7 +216,10 @@ def cal_static(inputs, running_mean, running_variance, weight, bias, mode=None):
primapi.to_prim(blocks)
fwd_ops_new = [op.type for op in blocks[0].ops]
# Ensure that batch_norm is splitted into small ops
assert 'batch_norm' not in fwd_ops_new
assert (
'batch_norm' not in fwd_ops_new
and 'reduce_mean' not in fwd_ops_new
)
exe = paddle.static.Executor()
exe.run(startup_program)
......
......@@ -8,7 +8,7 @@ foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
endforeach()
set_tests_properties(test_resnet_prim_cinn PROPERTIES TIMEOUT 800)
set_tests_properties(test_resnet_prim_cinn PROPERTIES TIMEOUT 850)
set_tests_properties(test_bert_prim_cinn PROPERTIES TIMEOUT 500)
if(WITH_CINN)
......
......@@ -55,7 +55,7 @@ class TestCustomVJP(unittest.TestCase):
'elementwise_mul',
'scale',
'cast',
'fill_any_like',
'fill_constant',
'cast',
'elementwise_mul',
'scale',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册