未验证 提交 c549c6b9 编写于 作者: X xiongkun 提交者: GitHub

[ Dy2static ] Add ParameterRecorder to support science training cases. (#49459)

* [Dy2static] Add ParameterRecorder

* filter by shape(tensor)==0

* fix code by review

* fix random failed in CI. (especially coverage)

* fix bugs

* remove API changes to avoid static CI approval
上级 04e24e58
......@@ -154,6 +154,18 @@ def _convert_into_variable(tensor):
new_var = tensor._to_static_var(
to_parameter=False, persistable=is_persistable
)
# add param into parameter recorder to collect all the params used in this program.
if new_var.persistable is True:
# TODO(@xiongkun): 0d-tensor may be affected at present,
# but there is no particularly good method to identify whether 0d-tensor
# is used as buffer or "drop_out_state" in LSTM buffer variable.
from paddle.jit.dy2static.program_translator import (
ProgramTranslator,
)
ProgramTranslator.get_instance()._params_recorder.add(
tensor.block.program, tensor
)
return new_var
else:
return tensor
......
......@@ -420,42 +420,6 @@ class TestJitSaveInCompiletime(TestErrorBase):
self._test_raise_new_exception()
# # Situation 4: NotImplementedError
class TestSuggestionErrorInRuntime(TestErrorBase):
def set_func(self):
self.func = func_suggestion_error_in_runtime
def set_input(self):
self.input = paddle.to_tensor([2.0])
def set_exception_type(self):
self.exception_type = ValueError
def set_message(self):
self.expected_message = [
'File "{}", line 118, in forward'.format(self.filepath),
'return self.inner_net.forward(x)',
'File "{}", line 127, in forward'.format(self.filepath),
'def forward(self, x):',
'out = paddle.matmul(self.w, x)',
'<--- HERE',
'return out',
'Revise suggestion:',
'Please ensure all your sublayers are inheritted from nn.Layer.',
'Please ensure there is no tensor created explicitly depended on external data, we suggest to register it as buffer tensor. See',
]
def set_func_call(self):
# NOTE: self.func(self.input) is the StaticLayer().__call__(self.input)
self.func_call = lambda: self.func(self.input)
def test_error(self):
for disable_new_error in [0, 1]:
self._test_raise_new_exception(disable_new_error)
@paddle.jit.to_static
def func_ker_error(x):
d = {'x': x}
......
......@@ -442,5 +442,41 @@ class TestRemoveCommentInDy2St(unittest.TestCase):
self.assertEqual('#' not in code_string, True)
class Obj:
def __init__(self):
pass
def func(self, x):
return x + 1
obj = Obj()
class Net2:
def __init__(self):
super(Net2, self).__init__()
self.layer1 = paddle.nn.Linear(10, 10)
def forward(self, data):
@paddle.jit.to_static
def func(ins, x, loss_fn):
x = ins.layer1(x)
return loss_fn(x)
def func1(x):
return func(self, x, obj.func)
return func1(data)
class TestParameterRecorder(unittest.TestCase):
def test_recorder(self):
"""function calls nn.Layer case."""
net = Net()
x = paddle.randn([5, 10])
out = net.forward(x)
if __name__ == '__main__':
unittest.main()
......@@ -308,6 +308,10 @@ class PartialProgramLayer:
program = self._create_forward_backward_train_amp_program()
return program
@LazyInitialized
def _empty_backward_program_for_eval(self):
return paddle.static.Program()
@LazyInitialized
def _train_pure_fp16_forward_backward_program(self):
program = self._create_forward_backward_train_pure_fp16_program()
......@@ -363,7 +367,16 @@ class PartialProgramLayer:
program = self._train_forward_backward_program
return program[1]
else:
return paddle.static.Program()
"""
Can't just return paddle.static.Program(), because self.backward_program is a property,
whenever we call this method, a tmp Program() object is created and is gc immediatly
after executed the following line in PartialProgramLayer.__call__.
>>> self.backward_program.desc.block(0),
When we access RunProgramAPI, it's possible to get an invalid backward_program address.
"""
return self._empty_backward_program_for_eval
@LazyInitialized
def _train_program_id(self):
......
......@@ -57,6 +57,16 @@ __all__ = []
MAX_TRACED_PROGRAM_COUNT = 10
def synchronized(func):
func.__lock__ = threading.Lock()
def lock_func(*args, **kwargs):
with func.__lock__:
return func(*args, **kwargs)
return lock_func
class FunctionCache:
"""
Caches the transformed functions to avoid redundant conversions of the same function.
......@@ -969,12 +979,7 @@ class ConcreteProgram:
[class_instance] + list(static_inputs)
)
# 2. Gets all ParamBases and buffered VarBases in the function
all_parameters_and_buffers = _extract_indeed_params_buffers(
class_instance
)
# 3. Builds program only once and returns the output Variables.
# 2. Builds program only once and returns the output Variables.
with param_guard(
get_parameters(class_instance, False)
), param_guard(get_buffers(class_instance, False)):
......@@ -994,6 +999,17 @@ class ConcreteProgram:
error_data.raise_new_exception()
raise
from paddle.jit.dy2static.program_translator import (
ProgramTranslator,
)
# 3. Gets all ParamBases and buffered VarBases in the function
all_parameters_and_buffers = (
ProgramTranslator.get_instance()._params_recorder.pop(
main_program
)
)
if outputs is not None:
need_wrap_into_list = (
not isinstance(outputs, (tuple, list))
......@@ -1026,6 +1042,34 @@ def _extract_indeed_params_buffers(class_instance):
return params + buffers
class ParametersRecorder:
def __init__(self):
self.params_dict = {}
@synchronized
def add(self, program, param):
"""use the default_program as key, append param the parameter list."""
key = self._program_hash(program)
if key not in self.params_dict:
self.params_dict[key] = set()
params = self.params_dict[key]
params.add(param)
def pop(self, program):
params = self.params_dict.get(self._program_hash(program))
if params is None:
return []
del self.params_dict[self._program_hash(program)]
return list(params)
def _program_hash(self, program):
"""
because program is not deleted while calling from_func_spec.
so it's ok to use id(program)
"""
return id(program)
class ProgramCache:
"""
Wrapper class for the program functions defined by dygraph function.
......@@ -1098,16 +1142,6 @@ class ProgramCache:
return [cp for key, (cp, _) in self._caches.items()]
def synchronized(func):
func.__lock__ = threading.Lock()
def lock_func(*args, **kwargs):
with func.__lock__:
return func(*args, **kwargs)
return lock_func
class ProgramTranslator:
"""
Class to translate dygraph function into static graph function. The object
......@@ -1159,6 +1193,7 @@ class ProgramTranslator:
return
self._initialized = True
self._program_cache = ProgramCache()
self._params_recorder = ParametersRecorder()
self.enable_to_static = True
def enable(self, enable_to_static):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册