未验证 提交 a4e0f666 编写于 作者: C cyber-pioneer 提交者: GitHub

[Prim] fix loss of composite rule (#52120)

* fix_prim

* fix bug

* add note

* fix logic

* fix

* add note

* fix check

* fix bug

* fix bug

* fix bug

* add debug

* fix check

* fix bug

* sync print log

* fix test case

* change default

* change test case time
上级 49461a02
...@@ -217,7 +217,13 @@ def grad(outputs, inputs, grad_outputs=None): ...@@ -217,7 +217,13 @@ def grad(outputs, inputs, grad_outputs=None):
@framework.static_only @framework.static_only
def to_prim(blocks, blacklist=frozenset(), whitelist=frozenset()): def to_prim(
blocks,
blacklist=frozenset(),
whitelist=frozenset(),
start_idx=-1,
backward_length=-1,
):
"""Search nonbasic ops that have registered composite rules and replace them with primitive ops. """Search nonbasic ops that have registered composite rules and replace them with primitive ops.
The operators in blacklist will be excluded from program when lowering into primitives, and only the The operators in blacklist will be excluded from program when lowering into primitives, and only the
operators in whitelist will be lowered. The priority of blacklist is higher than whitelist, it means operators in whitelist will be lowered. The priority of blacklist is higher than whitelist, it means
...@@ -229,6 +235,8 @@ def to_prim(blocks, blacklist=frozenset(), whitelist=frozenset()): ...@@ -229,6 +235,8 @@ def to_prim(blocks, blacklist=frozenset(), whitelist=frozenset()):
Args: Args:
blacklist(frozenset): The operators that will be excluded when lowering into primitives. blacklist(frozenset): The operators that will be excluded when lowering into primitives.
whitelist(frozenset): Only the operators in whitelist will be lowered into primitives. whitelist(frozenset): Only the operators in whitelist will be lowered into primitives.
start_idx(int): If start_idx exceeds -1, ops[start_idx:] will be processed. Default: -1.
backward_length(int): If backward_length exceeds -1, ops[:-backward_length] will be processed. Default: -1.
""" """
if not core._is_fwd_prim_enabled(): if not core._is_fwd_prim_enabled():
return return
...@@ -258,7 +266,7 @@ def to_prim(blocks, blacklist=frozenset(), whitelist=frozenset()): ...@@ -258,7 +266,7 @@ def to_prim(blocks, blacklist=frozenset(), whitelist=frozenset()):
blacklist = prim_config["forward_blacklist"] | blacklist blacklist = prim_config["forward_blacklist"] | blacklist
with framework.program_guard(main_program): with framework.program_guard(main_program):
print("Lowering composite forward ops begin...") print("Lowering composite forward ops begin...", flush=True)
if len(blacklist) > 0 and len(whitelist) > 0: if len(blacklist) > 0 and len(whitelist) > 0:
filter_ = lambda x: x.type in whitelist and x.type not in blacklist filter_ = lambda x: x.type in whitelist and x.type not in blacklist
...@@ -268,6 +276,13 @@ def to_prim(blocks, blacklist=frozenset(), whitelist=frozenset()): ...@@ -268,6 +276,13 @@ def to_prim(blocks, blacklist=frozenset(), whitelist=frozenset()):
filter_ = lambda x: x.type in whitelist filter_ = lambda x: x.type in whitelist
else: else:
filter_ = lambda x: True filter_ = lambda x: True
primx._lower_composite(blocks, filter_) primx._lower_composite(
blocks,
filter_,
start_idx=start_idx,
backward_length=backward_length,
)
replace_ops = prim_config["composite_ops_record"] replace_ops = prim_config["composite_ops_record"]
print(f"Lowering composite forward ops finish: {replace_ops}") print(
f"Lowering composite forward ops finish: {replace_ops}", flush=True
)
...@@ -551,7 +551,10 @@ def _lower(block, reverse, blacklist): ...@@ -551,7 +551,10 @@ def _lower(block, reverse, blacklist):
def _lower_composite( def _lower_composite(
block, filter_: typing.Callable[[framework.Operator], bool] = lambda x: True block,
filter_: typing.Callable[[framework.Operator], bool] = lambda x: True,
start_idx=-1,
backward_length=-1,
): ):
"""The operators in block which satisfy the filter condition will be decomposed into primitives.""" """The operators in block which satisfy the filter condition will be decomposed into primitives."""
...@@ -602,13 +605,41 @@ def _lower_composite( ...@@ -602,13 +605,41 @@ def _lower_composite(
none_vars_to_remove = set() none_vars_to_remove = set()
change = None change = None
# Only process required sliced block
# If given start_idx, only ops[start_idx:] will be processed.
# If given backward_length, only ops[:-backward_length] will be processed.
# Note, start_idx and backward_length cannot both be given, because the length of the non-processed part must be kept unchanged.
length = len(block.ops)
idx_list = range(length)
assert (
-1 <= backward_length <= length
), f'expect -1 <= backward_length <= {length}, but got backward_length: {backward_length}'
assert (
-1 <= start_idx <= length
), f'expect -1 <= start_idx <= {length}, but got start_idx: {start_idx}'
assert not (
backward_length > -1 and start_idx > -1
), f'got start_idx: {start_idx} and backward_length: {backward_length}'
if backward_length > -1:
idx_list = range(length - backward_length)
if start_idx > -1:
idx_list = range(start_idx, length)
# Step2: Process all ops in the target block # Step2: Process all ops in the target block
for op_idx in range(len(block.ops)): for op_idx in range(length):
op = block.ops[op_idx] op = block.ops[op_idx]
ops_to_remove.append(op_idx) ops_to_remove.append(op_idx)
if lookup_fn(op.type) is not None and filter_(op):
change = True
op_name = op.type op_name = op.type
do_comp = (
(lookup_fn(op_name) is not None)
and filter_(op)
and op_idx in idx_list
)
if do_comp:
change = True
prim_config["composite_ops_record"].add(op_name) prim_config["composite_ops_record"].add(op_name)
input_args = prepare_python_api_arguments(op) input_args = prepare_python_api_arguments(op)
bind(input_args, to_bind, value_table) bind(input_args, to_bind, value_table)
...@@ -686,12 +717,22 @@ def _lower_composite( ...@@ -686,12 +717,22 @@ def _lower_composite(
# composite ops may contain other composite ops, thus, call _lower_composite again. # composite ops may contain other composite ops, thus, call _lower_composite again.
if change: if change:
_lower_composite(block, filter_) _lower_composite(
block,
filter_,
start_idx=start_idx,
backward_length=backward_length,
)
return return
elif isinstance(block, typing.Sequence): elif isinstance(block, typing.Sequence):
for item in block: for item in block:
_lower_composite(item, filter_) _lower_composite(
item,
filter_,
start_idx=start_idx,
backward_length=backward_length,
)
return return
else: else:
raise TypeError raise TypeError
......
...@@ -1266,8 +1266,12 @@ class PrimHooker(PartialProgramLayerHook): ...@@ -1266,8 +1266,12 @@ class PrimHooker(PartialProgramLayerHook):
def after_append_backward(self, whole_program, backward_start_idx): def after_append_backward(self, whole_program, backward_start_idx):
backward_length = len(whole_program.block(0).ops) - backward_start_idx backward_length = len(whole_program.block(0).ops) - backward_start_idx
if core._is_fwd_prim_enabled() and len(self.custom_vjps) != 0: if core._is_fwd_prim_enabled() and len(self.custom_vjps) != 0:
_to_prim(whole_program.blocks, whitelist=self.custom_vjps) # only process backward part of block
_to_prim(whole_program.blocks, backward_length=backward_length)
new_start_index = len(whole_program.block(0).ops) - backward_length new_start_index = len(whole_program.block(0).ops) - backward_length
if backward_length > 0:
# only process forward part of block
_to_prim(whole_program.blocks, start_idx=new_start_index)
return whole_program, new_start_index return whole_program, new_start_index
def after_infer(self, infer_program): def after_infer(self, infer_program):
...@@ -1693,9 +1697,21 @@ def enable_to_static(enable_to_static_bool): ...@@ -1693,9 +1697,21 @@ def enable_to_static(enable_to_static_bool):
@switch_to_static_graph @switch_to_static_graph
def _to_prim(blocks, blacklist=frozenset(), whitelist=frozenset()): def _to_prim(
blocks,
blacklist=frozenset(),
whitelist=frozenset(),
start_idx=-1,
backward_length=-1,
):
"""Switch to static graph and call to_prim.""" """Switch to static graph and call to_prim."""
# TODO(Aurelius84): Fix this cycle import problem # TODO(Aurelius84): Fix this cycle import problem
from paddle.incubate.autograd import primapi from paddle.incubate.autograd import primapi
primapi.to_prim(blocks, blacklist=blacklist, whitelist=whitelist) primapi.to_prim(
blocks,
blacklist=blacklist,
whitelist=whitelist,
start_idx=start_idx,
backward_length=backward_length,
)
...@@ -1477,7 +1477,12 @@ def _out_grad_names(program_desc, fwd_end_op_index, out_size): ...@@ -1477,7 +1477,12 @@ def _out_grad_names(program_desc, fwd_end_op_index, out_size):
min(fwd_end_op_index + out_size, program_desc.block(0).op_size()), min(fwd_end_op_index + out_size, program_desc.block(0).op_size()),
): ):
op = program_desc.block(0).op(i) op = program_desc.block(0).op(i)
if op.type() == 'fill_any_like': # If prim forward op, fill_any_like will be decomposed into fill_constant.
if core._is_fwd_prim_enabled():
target = ('fill_any_like', 'fill_constant')
else:
target = 'fill_any_like'
if op.type() in target:
var_name = op.output('Out')[0] var_name = op.output('Out')[0]
names.append(var_name) names.append(var_name)
return names return names
......
...@@ -216,7 +216,10 @@ def cal_static(inputs, running_mean, running_variance, weight, bias, mode=None): ...@@ -216,7 +216,10 @@ def cal_static(inputs, running_mean, running_variance, weight, bias, mode=None):
primapi.to_prim(blocks) primapi.to_prim(blocks)
fwd_ops_new = [op.type for op in blocks[0].ops] fwd_ops_new = [op.type for op in blocks[0].ops]
# Ensure that batch_norm is split into small ops # Ensure that batch_norm is split into small ops
assert 'batch_norm' not in fwd_ops_new assert (
'batch_norm' not in fwd_ops_new
and 'reduce_mean' not in fwd_ops_new
)
exe = paddle.static.Executor() exe = paddle.static.Executor()
exe.run(startup_program) exe.run(startup_program)
......
...@@ -8,7 +8,7 @@ foreach(TEST_OP ${TEST_OPS}) ...@@ -8,7 +8,7 @@ foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS}) py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
endforeach() endforeach()
set_tests_properties(test_resnet_prim_cinn PROPERTIES TIMEOUT 800) set_tests_properties(test_resnet_prim_cinn PROPERTIES TIMEOUT 850)
set_tests_properties(test_bert_prim_cinn PROPERTIES TIMEOUT 500) set_tests_properties(test_bert_prim_cinn PROPERTIES TIMEOUT 500)
if(WITH_CINN) if(WITH_CINN)
......
...@@ -55,7 +55,7 @@ class TestCustomVJP(unittest.TestCase): ...@@ -55,7 +55,7 @@ class TestCustomVJP(unittest.TestCase):
'elementwise_mul', 'elementwise_mul',
'scale', 'scale',
'cast', 'cast',
'fill_any_like', 'fill_constant',
'cast', 'cast',
'elementwise_mul', 'elementwise_mul',
'scale', 'scale',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册