Commit ecc842f1 authored by xiongkun, committed by Xiaoxu Chen

Pr 50885 (#7)

* [CINN]Enhance CacheKey hash logic by considering input dtypes (#50557)

* [CINN]Enhance CacheKey hash logic by considering input dtypes

* add unittest

* fix typo

* fix typo

* fix map.at

* fix find

* fix test

* fix cinn cache key structure realize

* using ordered map for attributes

* add test by review advice

---------
Co-authored-by: jiangcheng <thisjiang@qq.com>

* [prim] enable dygraph_to_static to support custom_vjp

* fix code in a dy2static-friendly way.

* [dystatic] add hooker for prim

---------
Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Co-authored-by: jiangcheng <thisjiang@qq.com>
Co-authored-by: cxxly <chenxx_id@163.com>
Parent d0c80f43
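The first item in the log changes CINN's CacheKey hashing so that input dtypes, not just shapes, distinguish cached compilations, and attributes are folded in through an ordered map. That code is C++ and not part of this diff; the snippet below is only an illustrative Python sketch of the idea, with every name (cache_key_fingerprint, graph_repr, attrs) made up for the example.

    import hashlib
    from collections import OrderedDict

    def cache_key_fingerprint(graph_repr, input_shapes, input_dtypes, attrs):
        # Illustrative sketch only, not the real CINN CacheKey: the hash
        # covers input dtypes alongside shapes, and attributes are visited
        # through an ordered mapping so the fingerprint is deterministic.
        h = hashlib.sha256()
        h.update(graph_repr.encode())
        for shape, dtype in zip(input_shapes, input_dtypes):
            h.update(str(shape).encode())
            h.update(str(dtype).encode())
        for name, value in OrderedDict(sorted(attrs.items())).items():
            h.update(name.encode())
            h.update(str(value).encode())
        return h.hexdigest()

With a fingerprint like this, two programs that agree on input shapes but differ in dtypes no longer collide on the same cache entry.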
@@ -143,6 +143,19 @@ class ProgramInfo:
         return self.programs[key], self.op_size[key]
 
 
+class PartialProgramLayerHook:
+    def before_append_backward(self, partial_program_layer, forward_program):
+        ...
+
+    def after_append_backward(
+        self, partial_program_layer, whole_program, backward_start_idx
+    ):
+        ...
+
+    def after_infer(self, partial_program_layer, infer_program):
+        ...
+
+
 class PartialProgramLayer:
     """
     PartialProgramLayer wraps all the ops from layers decorated by `@to_static`
@@ -182,6 +195,7 @@ class PartialProgramLayer:
         # Set default mode to train
         self.training = True
         self._infer_info = ProgramInfo()
+        self._backward_start_index_map = {}
 
         custom_white_list, custom_black_list = None, None
         tracer = framework._dygraph_tracer()
@@ -195,6 +209,7 @@ class PartialProgramLayer:
         # program_id -> list(scope)
         self._scope_cache = {}
+        self._hooker = None
 
     def __call__(self, inputs):
         """
@@ -218,6 +233,9 @@ class PartialProgramLayer:
             restored_nest_out = self._restore_out(out_vars)
             return self._remove_no_value(restored_nest_out)
 
+    def set_hooker(self, hooker):
+        self._hooker = hooker
+
     def _get_scope(self, program_id=None, use_scope_cache=False):
         if use_scope_cache:
             if program_id not in self._scope_cache:
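Taken together, the hook interface and set_hooker let callers observe or rewrite the programs at three points: before backward is appended, after it is appended, and after the inference program is cloned. A minimal sketch of how a hook might be registered; LoggingHook and partial_program_layer are hypothetical stand-ins, only PartialProgramLayerHook and set_hooker come from this commit:

    class LoggingHook(PartialProgramLayerHook):
        # Hypothetical hook: observes the programs without rewriting them.
        def before_append_backward(self, partial_program_layer, forward_program):
            print("forward ops:", len(forward_program.block(0).ops))
            return forward_program

        def after_append_backward(
            self, partial_program_layer, whole_program, backward_start_idx
        ):
            print("backward starts at op index", backward_start_idx)
            return whole_program, backward_start_idx

        def after_infer(self, partial_program_layer, infer_program):
            return infer_program

    # partial_program_layer.set_hooker(LoggingHook())

Each callback returns what the layer will use next, so a hook that rewrites a program must hand the rewritten object back, and after_append_backward must also return an updated backward start index.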
@@ -242,7 +260,12 @@ class PartialProgramLayer:
     @switch_to_static_graph
     def _create_program(self, is_infer_mode=False):
         if is_infer_mode:
-            return self._origin_main_program.clone(for_test=is_infer_mode)
+            infer_program = self._origin_main_program.clone(
+                for_test=is_infer_mode
+            )
+            if self._hooker:
+                infer_program = self._hooker.after_infer(self, infer_program)
+            return infer_program
         else:
             train_program = self._append_backward_desc(
                 self._origin_main_program
@@ -609,6 +632,8 @@ class PartialProgramLayer:
     def _append_backward_desc(self, main_program):
         # make sure all status of is_test are False in train mode.
         program = _change_is_test_status(main_program.clone(), is_test=False)
+        if self._hooker:
+            program = self._hooker.before_append_backward(self, program)
         targets = []
         for out in self._outputs.tolist():
             if isinstance(out, framework.Variable):
@@ -618,10 +643,16 @@ class PartialProgramLayer:
             # TODO(CZ): later when use cinn, set_prim_all_enabled and check_and_set_prim_all_enabled will be set at else branch.
             core.check_and_set_prim_all_enabled()
             backward.gradients(targets=targets, inputs=[])
-        start_idx = len(main_program.block(0).ops) + len(self._outputs.tolist())
+        start_idx = (
+            len(main_program.block(0).ops) + len(self._outputs.tolist()) + 1
+        )
+        if self._hooker:
+            program, start_idx = self._hooker.after_append_backward(
+                self, program, start_idx
+            )
+        # self._backward_start_index_map[self._hash_with_id(program, self)]
+        # TODO: prim make this complicate
         self.prepare_gradient_aggregation(start_idx, main_program, program)
 
         return program
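Here start_idx is the index of the first op that belongs to the backward region of the whole program, and after_append_backward is expected to return an index that still points there even if it rewrites the program, which is exactly what the PrimHooker later in this diff does. A small worked example of that bookkeeping, with made-up op counts:

    # Suppose the whole program has 30 ops and backward starts at index 12.
    whole_ops, start_idx = 30, 12
    backward_length = whole_ops - start_idx        # 18 backward ops

    # If a hook rewrites the forward part and the program grows to 40 ops,
    # then, assuming the backward region keeps its length (the assumption
    # PrimHooker makes), the backward ops are still the last 18:
    new_start_idx = 40 - backward_length           # 22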
@@ -701,6 +732,11 @@ class PartialProgramLayer:
             'program_id',
             self.program_id,
         ]
+
+        print(self.forward_program)
+        print(self.backward_program)
+        print(self.program_id)
+
         if self.training:
             # NOTE: In the case of higher-order gradient, the names of the parameter grads may be like
             # `grad/grad/grad/linear_0.w_0@GRAD` instead of simply `linear_0.w_0@GRAD`, so we get
...
@@ -19,7 +19,6 @@ import threading
 import warnings
 import weakref
 
-from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard
 from paddle.fluid import _non_static_mode, core, framework
 from paddle.fluid.data_feeder import check_type
 from paddle.fluid.dygraph import layers
@@ -39,7 +38,7 @@ from .origin_info import (
     create_and_update_origin_info_map,
     update_op_callstack_with_origin_info,
 )
-from .partial_program import partial_program_from
+from .partial_program import PartialProgramLayerHook, partial_program_from
 from .utils import (
     ALREADY_D2S,
     ast_to_func,
@@ -1182,26 +1181,45 @@ class ProgramCache:
                     )
                 )
 
-        custom_vjps = set()
-        if core._is_fwd_prim_enabled() and core._is_bwd_prim_enabled():
-            custom_vjps = {
-                op.type
-                for op in concrete_program.main_program.block(0).ops
-                if core.has_comp_grad_op_maker(op.type)
-            }
-
-        if core._is_fwd_prim_enabled():
-            if not _in_amp_guard() and not _in_pure_fp16_guard():
-                _to_prim(
-                    concrete_program.main_program.blocks, exclude=custom_vjps
-                )
-
-        partial_program = partial_program_from(concrete_program)
-
-        if core._is_fwd_prim_enabled() and len(custom_vjps) != 0:
-            if not _in_amp_guard() and not _in_pure_fp16_guard():
-                _to_prim(partial_program.forward_program.blocks)
-
+        class PrimHooker(PartialProgramLayerHook):
+            def __init__(self):
+                custom_vjps = set()
+                if core._is_fwd_prim_enabled() and core._is_bwd_prim_enabled():
+                    custom_vjps = {
+                        op.type
+                        for op in concrete_program.main_program.block(0).ops
+                        if core.has_comp_grad_op_maker(op.type)
+                    }
+                self.custom_vjps = custom_vjps
+                self.custom_vjps = {"softmax"}
+
+            def before_append_backward(
+                self, partial_program_layer, forward_program
+            ):
+                if core._is_fwd_prim_enabled():
+                    to_prim(forward_program.block(0), self.custom_vjps)
+                return forward_program
+
+            def after_append_backward(
+                self, partial_program_layer, whole_program, backward_start_idx
+            ):
+                backward_length = (
+                    len(whole_program.block(0).ops) - backward_start_idx
+                )
+                if core._is_fwd_prim_enabled() and len(self.custom_vjps) != 0:
+                    to_prim(whole_program.block(0))
+                new_start_index = (
+                    len(whole_program.block(0).ops) - backward_length
+                )
+                return whole_program, new_start_index
+
+            def after_infer(self, partial_program_layer, infer_program):
+                if core._is_fwd_prim_enabled():
+                    to_prim(infer_program.block(0))
+                return infer_program
+
+        partial_program = partial_program_from(concrete_program)
+        partial_program.set_hooker(PrimHooker())
         return concrete_program, partial_program
@@ -1675,8 +1693,8 @@ def enable_to_static(enable_to_static_bool):
 
 @switch_to_static_graph
-def _to_prim(blocks, exclude=frozenset()):
+def to_prim(blocks, exclude=frozenset()):
     # TODO(Aurelius84): Fix this cycle import problem
     from paddle.incubate.autograd import primapi
 
-    primapi.to_prim(blocks, exclude=exclude)
+    primapi.to_prim(blocks, exclude)
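Read together with PrimHooker above, the renamed to_prim wrapper drives a two-phase decomposition that is what lets custom_vjp work under dy2static. A hedged sketch of that flow; forward_program, whole_program, and the {"softmax"} set are stand-ins taken from the hooker's code, not new API:

    # Phase 1: before backward is appended, lower the forward block to
    # primitive ops but skip the ops that have a composite (prim) grad op
    # maker, so their custom VJP rule is the one backward.gradients uses.
    custom_vjps = {"softmax"}
    to_prim(forward_program.block(0), custom_vjps)

    # Phase 2: once the backward ops exist, lower the whole program,
    # including the ops that were excluded in phase 1.
    to_prim(whole_program.block(0))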