Commit ecc842f1 authored by xiongkun, committed by Xiaoxu Chen

Pr 50885 (#7)

* [CINN]Enhance CacheKey hash logic by considering input dtypes (#50557)

* [CINN]Enhance CacheKey hash logic by considering input dtypes

* add unittest

* fix typo

* fix typo

* fix map.at

* fix find

* fix test

* fix cinn cache key structure realize

* using ordered map for attributes

* add test by review advice

---------
Co-authored-by: jiangcheng <thisjiang@qq.com>

* [prim] enable dygraph_to_static to support custom_vjp

* fix code in a dy2static-friendly way.

* [dystatic] add hooker for prim

---------
Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Co-authored-by: jiangcheng <thisjiang@qq.com>
Co-authored-by: cxxly <chenxx_id@163.com>
Parent d0c80f43
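A minimal sketch (not part of this commit) of the hook contract that the diff below introduces: PartialProgramLayer gains a set_hooker() method and invokes three callbacks around program construction, and each callback must hand its program back because the layer keeps the returned value. Only the three method signatures come from the diff; the class name LoggingHooker and the prints are hypothetical.

# Hypothetical hooker; real code would subclass PartialProgramLayerHook,
# which program_translator.py now imports from .partial_program.
class LoggingHooker:
    def before_append_backward(self, partial_program_layer, forward_program):
        # called on the cloned train program before gradients are appended;
        # must return the (possibly rewritten) forward program
        print("before_append_backward")
        return forward_program

    def after_append_backward(
        self, partial_program_layer, whole_program, backward_start_idx
    ):
        # called once backward ops exist; must return the program and the
        # (possibly shifted) index of the first backward op
        print("after_append_backward, backward starts at", backward_start_idx)
        return whole_program, backward_start_idx

    def after_infer(self, partial_program_layer, infer_program):
        # called on the inference clone; must return it
        print("after_infer")
        return infer_program

# Attachment point added by this commit (layer is a PartialProgramLayer):
# layer.set_hooker(LoggingHooker())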
@@ -143,6 +143,19 @@ class ProgramInfo:
return self.programs[key], self.op_size[key]
class PartialProgramLayerHook:
def before_append_backward(self, partial_program_layer, forward_program):
...
def after_append_backward(
self, partial_program_layer, whole_program, backward_start_idx
):
...
def after_infer(self, partial_program_layer, infer_program):
...
class PartialProgramLayer:
"""
PartialProgramLayer wraps all the ops from layers decorated by `@to_static`
@@ -182,6 +195,7 @@ class PartialProgramLayer:
# Set default mode to train
self.training = True
self._infer_info = ProgramInfo()
self._backward_start_index_map = {}
custom_white_list, custom_black_list = None, None
tracer = framework._dygraph_tracer()
@@ -195,6 +209,7 @@ class PartialProgramLayer:
# program_id -> list(scope)
self._scope_cache = {}
self._hooker = None
def __call__(self, inputs):
"""
@@ -218,6 +233,9 @@ class PartialProgramLayer:
restored_nest_out = self._restore_out(out_vars)
return self._remove_no_value(restored_nest_out)
def set_hooker(self, hooker):
self._hooker = hooker
def _get_scope(self, program_id=None, use_scope_cache=False):
if use_scope_cache:
if program_id not in self._scope_cache:
@@ -242,7 +260,12 @@ class PartialProgramLayer:
@switch_to_static_graph
def _create_program(self, is_infer_mode=False):
if is_infer_mode:
return self._origin_main_program.clone(for_test=is_infer_mode)
infer_program = self._origin_main_program.clone(
for_test=is_infer_mode
)
if self._hooker:
infer_program = self._hooker.after_infer(self, infer_program)
return infer_program
else:
train_program = self._append_backward_desc(
self._origin_main_program
@@ -609,6 +632,8 @@ class PartialProgramLayer:
def _append_backward_desc(self, main_program):
# make sure all status of is_test are False in train mode.
program = _change_is_test_status(main_program.clone(), is_test=False)
if self._hooker:
program = self._hooker.before_append_backward(self, program)
targets = []
for out in self._outputs.tolist():
if isinstance(out, framework.Variable):
@@ -618,9 +643,15 @@ class PartialProgramLayer:
# TODO(CZ): later when use cinn, set_prim_all_enabled and check_and_set_prim_all_enabled will be set at else branch.
core.check_and_set_prim_all_enabled()
backward.gradients(targets=targets, inputs=[])
start_idx = len(main_program.block(0).ops) + len(self._outputs.tolist())
start_idx = (
len(main_program.block(0).ops) + len(self._outputs.tolist()) + 1
)
if self._hooker:
program, start_idx = self._hooker.after_append_backward(
self, program, start_idx
)
# self._backward_start_index_map[self._hash_with_id(program, self)]
# TODO: prim make this complicate
self.prepare_gradient_aggregation(start_idx, main_program, program)
return program
@@ -701,6 +732,11 @@ class PartialProgramLayer:
'program_id',
self.program_id,
]
print(self.forward_program)
print(self.backward_program)
print(self.program_id)
if self.training:
# NOTE: In the case of higher-order gradient, the names of the parameter grads may be like
# `grad/grad/grad/linear_0.w_0@GRAD` instead of simply `linear_0.w_0@GRAD`, so we get
......
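To recap where the new calls land in this first file: _create_program(is_infer_mode=True) runs after_infer on the cloned program, while _append_backward_desc runs before_append_backward before backward.gradients and after_append_backward once the backward start index is computed. The toy trace below (plain Python, no Paddle; every name is hypothetical and a list of op names stands in for a Program) mirrors only that control flow:

calls = []

class TraceHooker:
    # Hypothetical hooker that just records when each callback fires.
    def before_append_backward(self, layer, program):
        calls.append("before_append_backward")
        return program

    def after_append_backward(self, layer, program, start_idx):
        calls.append("after_append_backward@%d" % start_idx)
        return program, start_idx

    def after_infer(self, layer, program):
        calls.append("after_infer")
        return program

class ToyLayer:
    # Mimics only the control flow added in this diff.
    def __init__(self):
        self._hooker = None

    def set_hooker(self, hooker):
        self._hooker = hooker

    def create_program(self, ops, is_infer_mode=False):
        if is_infer_mode:
            infer_program = list(ops)  # "clone" for inference
            if self._hooker:
                infer_program = self._hooker.after_infer(self, infer_program)
            return infer_program
        return self.append_backward_desc(list(ops))

    def append_backward_desc(self, program):
        if self._hooker:
            program = self._hooker.before_append_backward(self, program)
        forward_len = len(program)
        program = program + ["matmul_grad", "relu_grad"]  # pretend gradients ran
        start_idx = forward_len  # toy index of the first backward op
        if self._hooker:
            program, start_idx = self._hooker.after_append_backward(
                self, program, start_idx
            )
        return program

layer = ToyLayer()
layer.set_hooker(TraceHooker())
layer.create_program(["matmul", "relu"], is_infer_mode=True)
layer.create_program(["matmul", "relu"])
print(calls)  # ['after_infer', 'before_append_backward', 'after_append_backward@2']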
@@ -19,7 +19,6 @@ import threading
import warnings
import weakref
from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard
from paddle.fluid import _non_static_mode, core, framework
from paddle.fluid.data_feeder import check_type
from paddle.fluid.dygraph import layers
@@ -39,7 +38,7 @@ from .origin_info import (
create_and_update_origin_info_map,
update_op_callstack_with_origin_info,
)
from .partial_program import partial_program_from
from .partial_program import PartialProgramLayerHook, partial_program_from
from .utils import (
ALREADY_D2S,
ast_to_func,
@@ -1182,6 +1181,8 @@ class ProgramCache:
)
)
class PrimHooker(PartialProgramLayerHook):
def __init__(self):
custom_vjps = set()
if core._is_fwd_prim_enabled() and core._is_bwd_prim_enabled():
custom_vjps = {
@@ -1189,19 +1190,36 @@ class ProgramCache:
for op in concrete_program.main_program.block(0).ops
if core.has_comp_grad_op_maker(op.type)
}
self.custom_vjps = custom_vjps
self.custom_vjps = {"softmax"}
def before_append_backward(
self, partial_program_layer, forward_program
):
if core._is_fwd_prim_enabled():
if not _in_amp_guard() and not _in_pure_fp16_guard():
_to_prim(
concrete_program.main_program.blocks, exclude=custom_vjps
)
to_prim(forward_program.block(0), self.custom_vjps)
return forward_program
partial_program = partial_program_from(concrete_program)
def after_append_backward(
self, partial_program_layer, whole_program, backward_start_idx
):
backward_length = (
len(whole_program.block(0).ops) - backward_start_idx
)
if core._is_fwd_prim_enabled() and len(self.custom_vjps) != 0:
to_prim(whole_program.block(0))
new_start_index = (
len(whole_program.block(0).ops) - backward_length
)
return whole_program, new_start_index
if core._is_fwd_prim_enabled() and len(custom_vjps) != 0:
if not _in_amp_guard() and not _in_pure_fp16_guard():
_to_prim(partial_program.forward_program.blocks)
def after_infer(self, partial_program_layer, infer_program):
if core._is_fwd_prim_enabled():
to_prim(infer_program.block(0))
return infer_program
partial_program = partial_program_from(concrete_program)
partial_program.set_hooker(PrimHooker())
return concrete_program, partial_program
@@ -1675,8 +1693,8 @@ def enable_to_static(enable_to_static_bool):
@switch_to_static_graph
def _to_prim(blocks, exclude=frozenset()):
def to_prim(blocks, exclude=frozenset()):
# TODO(Aurelius84): Fix this cycle import problem
from paddle.incubate.autograd import primapi
primapi.to_prim(blocks, exclude=exclude)
primapi.to_prim(blocks, exclude)
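A side note on the index bookkeeping in PrimHooker.after_append_backward above: it records the length of the backward segment, lowers the whole program with to_prim (which can change the total op count), then recomputes the start index so that the recorded backward length is preserved. A toy arithmetic sketch with made-up numbers:

# Toy numbers, purely to illustrate the arithmetic used above.
total_ops_before = 30        # ops in whole_program before lowering
backward_start_idx = 20      # index of the first backward op
backward_length = total_ops_before - backward_start_idx   # 10 backward ops

# Suppose lowering composite ops grows the program.
total_ops_after = 37

# Recompute where the backward segment now starts, assuming its length
# did not change during lowering (the assumption the hook relies on).
new_start_index = total_ops_after - backward_length
print(new_start_index)  # 27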