Unverified · Commit ac0a2e50, authored by Aurelius84, committed by GitHub

[LazyInit]Support LazyGuard and Polish interface code (#45462)

Parent edc9952c
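In short: the LazyInit(Class)(*args) wrapper is replaced by a LazyGuard context manager, and the startup program is no longer cached globally but is rebuilt per Layer through the new Layer._startup_program(). A before/after sketch distilled from the hunks below:

    from paddle import LazyGuard
    from paddle.nn import Linear

    # before this commit:
    #   fc = LazyInit(Linear)(10, 10)
    #   program = fc.startup_program

    # after this commit:
    with LazyGuard():
        fc = Linear(10, 10)  # parameters are declared, but no memory is allocated

    for param in fc.parameters():
        param.initialize()  # materialize on demand in dygraph mode

    program = fc._startup_program()  # init ops re-collected per Layer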
@@ -36,7 +36,7 @@
 from .framework import disable_static # noqa: F401
 from .framework import enable_static # noqa: F401
 from .framework import in_dynamic_mode # noqa: F401
 from .fluid.dataset import * # noqa: F401
-from .fluid.lazy_init import LazyInit # noqa: F401
+from .fluid.lazy_init import LazyGuard # noqa: F401
 from .framework.dtype import dtype as dtype # noqa: F401
 from .framework.dtype import uint8 # noqa: F401
@@ -417,7 +417,7 @@ __all__ = [ # noqa
     'cumprod',
     'logcumsumexp',
     'logit',
-    'LazyInit',
+    'LazyGuard',
     'sign',
     'is_empty',
     'equal',
...
@@ -38,6 +38,7 @@ from paddle.fluid import framework
 from ..param_attr import ParamAttr
 from paddle.fluid.executor import Executor, global_scope
 from paddle.fluid.framework import _non_static_mode, convert_np_dtype_to_dtype_, in_dygraph_mode
+from paddle.fluid.framework import Program, program_guard
 from paddle.fluid.framework import _current_expected_place as _get_device
 from paddle.fluid.core import VarDesc
 from paddle.fluid.dygraph import no_grad
@@ -1731,6 +1732,18 @@ class Layer(object):
         self._dtype = dtype
         return self

+    def _startup_program(self):
+        """
+        Return a startup program containing the initialization operations of
+        all parameters.
+
+        NOTE(dev): This is a very low-level API and is only for internal
+        developers.
+        """
+        startup_program = Program()
+        for param in self.parameters():
+            param._create_init_op(startup_program.global_block())
+        return startup_program
+
     # [aliases] Compatible with old method names
     set_dict = set_state_dict
     load_dict = set_state_dict
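A minimal usage sketch for the new helper, mirroring the unit tests at the end of this diff: every call builds a fresh Program and asks each parameter to append its recorded initializer op.

    from paddle import LazyGuard
    from paddle.nn import Linear

    with LazyGuard():
        fc = Linear(10, 10)

    startup = fc._startup_program()  # a fresh Program on every call
    print(startup)  # e.g. one uniform_random op and one fill_constant op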
@@ -6794,18 +6794,19 @@ class EagerParamBase(_core_eager_eagertensor):
         self.need_clip = kwargs.get('need_clip', True)
         self.is_distributed = kwargs.get('is_distributed', False)
-        # self.block = default_main_program().global_block()
-        self.init_func = None
+        # hook functions for lazy initialization
+        self._init_func = None
+        self._init_op_creator = None

     def set_init_func(self, obj):
-        self.init_func = obj
+        self._init_func = obj

     @dygraph_only
     def initialize(self):
-        assert self.init_func is not None, "Required self.init_func is not None, but received None."
-        self.init_func()
+        assert self._init_func is not None, "Required self._init_func is not None, but received None."
+        self._init_func()
         # clear function handle to release resource
-        self.init_func = None
+        self._init_func = None

     @property
     def trainable(self):
@@ -6820,6 +6821,13 @@ class EagerParamBase(_core_eager_eagertensor):
             "The type of trainable MUST be bool, but the type is ",
             type(trainable))

+    def _create_init_op(self, block):
+        """
+        Call the init_op_creator function to create an initializer operation
+        in the given block.
+        """
+        assert self._init_op_creator is not None, "Required self._init_op_creator is not None, but received None."
+        self._init_op_creator(block)
+
     def __str__(self):
         """
         Convert a EagerParamBase object to a readable string.
...
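Taken together, a lazily created parameter now carries two deferred callables: _init_func materializes the tensor eagerly in dygraph mode, while _init_op_creator replays the same initializer into a static block. A sketch of the lifecycle, assuming a Linear built under LazyGuard as above:

    import paddle
    from paddle import LazyGuard
    from paddle.nn import Linear

    with LazyGuard():
        fc = Linear(10, 10)
    param = fc.weight

    param.initialize()  # runs _init_func, then clears it to release the closure

    # _init_op_creator is kept, so the init op can still be recorded afterwards
    block = paddle.static.Program().global_block()
    param._create_init_op(block)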
@@ -19,7 +19,7 @@ import functools
 from . import framework
 from . import core
 from .framework import _non_static_mode, in_dygraph_mode, _in_legacy_dygraph, default_main_program, _current_expected_place
-from .lazy_init import lazy_guard
+from .lazy_init import lazy_init_helper
 from .framework import program_guard
 import numpy as np
 from .core import VarDesc
@@ -52,7 +52,7 @@ class Initializer(object):
         pass

     def __call__(self, param, block=None):
-        if not lazy_guard().state:
+        if not lazy_init_helper().state:
             return self.forward(param, block)
         return self._lazy_init(param, block)
@@ -63,18 +63,21 @@
         raise NotImplementedError()

     def _lazy_init(self, param, block=None):
-        # Apply lazy initialization
+        """
+        Apply lazy initialization.
+        """
         assert in_dygraph_mode()
-        new_block = lazy_guard().startup_program.global_block()
-        new_var = param._to_static_var(True, block=new_block)
-        # Record initializer operator
-        with lazy_guard():
-            self.forward(new_var, new_block)
-        lazy_guard().enable(clear_cache=False)
+
+        def init_op_creator(forward, param, block):
+            new_var = param._to_static_var(True, block=block)
+            # Record initializer operator
+            with lazy_init_helper():
+                forward(new_var, block)

         # Add hook function for initializing param in dygraph mode
-        func = functools.partial(self.forward, param, block)
-        param.set_init_func(func)
+        param.set_init_func(functools.partial(self.forward, param, block))
+        param._init_op_creator = functools.partial(init_op_creator,
+                                                   self.forward, param)

         return param
...
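The wiring above is plain functools.partial currying: both hooks close over the initializer's forward and the parameter, and only the static hook leaves the target block to be supplied later by Layer._startup_program(). A stripped-down, framework-free sketch of the same pattern (all names hypothetical):

    import functools

    class FakeInitializer:
        def forward(self, var, block):
            print("record init op for %s in %s" % (var, block))

    def wire_hooks(initializer, param):
        # dygraph hook: fully bound now, called later with no arguments
        init_func = functools.partial(initializer.forward, param, None)

        # static hook: the block stays unbound until _startup_program() time
        def init_op_creator(forward, param, block):
            forward(param, block)

        op_creator = functools.partial(init_op_creator, initializer.forward,
                                       param)
        return init_func, op_creator

    init_func, op_creator = wire_hooks(FakeInitializer(), "weight")
    init_func()                  # materialize eagerly
    op_creator("startup_block")  # or replay into a chosen block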
@@ -14,23 +14,21 @@
 from . import framework

-__all__ = ["LazyInit"]
+__all__ = ["LazyGuard"]


-class LazyGuard(object):
+class LazyInitHelper(object):
     """
-    Guard Context to trigger switching mode between dygraph and static mode,
+    A Helper Context to trigger switching mode between dygraph and static mode,
     and holds the startup program resource.
     """

     def __init__(self):
-        self._init_program()
         self._state = False
         self._tracer = None
         self._in_guard = False

-    def enable(self, clear_cache=True):
+    def enable(self):
         """
         Switch into lazy mode.
@@ -42,9 +40,6 @@
         ), "LazyInit.enable() is only available in dygraph mode."
         self._state = True

-        if clear_cache:
-            self._init_program()
-
     def disable(self):
         """
         Exit from lazy mode.
@@ -55,15 +50,12 @@
             return
         self._state = False

-    def _init_program(self):
-        self.startup_program = framework.Program()
-
     def __enter__(self):
         """
         Switch into lazy mode and set _dygraph_tracer_ with None to convert
         dygraph mode into static mode.
         """
-        self.enable(clear_cache=True)
+        self.enable()
         if self._in_guard: return
         self._tracer = framework._dygraph_tracer_
         framework._dygraph_tracer_ = None
@@ -85,26 +77,22 @@
         return self._state


-_lazy_guard = LazyGuard()
+_lazy_init_helper = LazyInitHelper()


-def lazy_guard():
-    global _lazy_guard
-    return _lazy_guard
+def lazy_init_helper():
+    global _lazy_init_helper
+    return _lazy_init_helper
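The module-level singleton plus accessor is what lets Initializer.__call__ cheaply test whether construction is happening inside a guard: every importer of lazy_init_helper() observes the same switch. The bare pattern, framework-free:

    class _Helper:
        def __init__(self):
            self._state = False

        @property
        def state(self):
            return self._state

    _helper = _Helper()

    def helper():
        return _helper

    # all call sites share one on/off switch
    assert helper() is helper()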
-class LazyInit(object):
+class LazyGuard(object):
     """
-    LazyInit is a wrapper interface for nn.Layer, it forwards the construct
+    LazyGuard is a wrapper interface for nn.Layer, it forwards the construct
     process of user defined Layer. Meanwhile, it provides necessary API to
     trigger EagerParamBase Lazy Initialization and get startup Program.
     """

-    def __init__(self, class_obj=None):
-        self.class_obj = class_obj
-        self.clear_cache = True
-
-    def __call__(self, *args, **kwargs):
+    def __enter__(self):
         """
         Construct instance from class_obj by Lazy Initializing parameters.
@@ -112,43 +100,16 @@ class LazyInit(object):
         .. code-block:: python

-            from paddle import LazyInit
+            from paddle import LazyGuard
             from paddle.nn import Linear

-            fc = LazyInit(Linear)(10, 10)
+            with LazyGuard():
+                fc = Linear(10, 10)

             for param in fc.parameters():
                 param.initialize()
         """
-        assert isinstance(
-            self.class_obj, type
-        ), "Required class_obj must be a class type, but received %s." % self.class_obj
-        global _lazy_guard
-        _lazy_guard.enable(self.clear_cache)
-        # construct Layer instance
-        with framework.program_guard(framework.Program()):
-            instance = self.class_obj(*args, **kwargs)
-        _lazy_guard.disable()
-        # set @property dynamically to visit startup_program
-        instance.startup_program = _lazy_guard.startup_program
-        return instance
-
-    @staticmethod
-    def startup_program():
-        """
-        A static method to get startup program for the latest Layer.
-
-        Examples:
-
-        .. code-block:: python
-
-            from paddle import LazyInit
-            from paddle.nn import Linear
-
-            fc = LazyInit(Linear)(10, 10)
-            print(LazyInit.startup_program())
-        """
-        return _lazy_guard.startup_program
+        lazy_init_helper().enable()
+
+    def __exit__(self, *args, **kwargs):
+        lazy_init_helper().disable()
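With LazyInit gone, the public surface is just this context manager: entering flips the shared helper on, exiting flips it off, and no startup program is cached globally. As a result the guard composes with nested Layer construction, which the new NestModel test below exercises; a sketch:

    from paddle import LazyGuard
    from paddle.nn import Linear, Layer

    class Stack(Layer):  # hypothetical nested model
        def __init__(self):
            super(Stack, self).__init__()
            self.a = Linear(10, 10)
            self.b = Linear(10, 10)

    with LazyGuard():
        m = Stack()  # every sub-layer's params are recorded, none allocated

    # each sub-layer can emit its own startup program independently
    print(m.a._startup_program())  # init ops for m.a's params only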
@@ -15,8 +15,8 @@
 import paddle
 import unittest
 import numpy as np
-from paddle import LazyInit
-from paddle.nn import Linear
+from paddle import LazyGuard
+from paddle.nn import Linear, Layer
 from paddle.nn.initializer import *
 from paddle.fluid import unique_name
@@ -47,11 +47,13 @@ class TestInitializerBase(unittest.TestCase):
         unique_name.dygraph_parameter_name_checker._name_set = set()

     def test_wrapper(self):
-        fc = LazyInit(Linear)(10,
-                              10,
-                              weight_attr=self.weight_attr,
-                              bias_attr=self.bias_attr)
-        program = fc.startup_program
+        with LazyGuard():
+            fc = Linear(10,
+                        10,
+                        weight_attr=self.weight_attr,
+                        bias_attr=self.bias_attr)
+        program = fc._startup_program()
+        print(program)
         self.check_program(program)

     def check_program(self, program):
@@ -64,10 +66,11 @@ class TestInitializerBase(unittest.TestCase):
 class TestDygraphLazy(TestInitializerBase):

     def test_wrapper(self):
-        fc = LazyInit(Linear)(10,
-                              10,
-                              weight_attr=self.weight_attr,
-                              bias_attr=self.bias_attr)
+        with LazyGuard():
+            fc = Linear(10,
+                        10,
+                        weight_attr=self.weight_attr,
+                        bias_attr=self.bias_attr)

         self.check_data(fc)
@@ -89,6 +92,66 @@ class TestDygraphLazy(TestInitializerBase):
                                    np.ones([10], dtype=np.float32) * 0.3)

+
+class NestModel(Layer):
+
+    def __init__(self, base_model):
+        super(NestModel, self).__init__()
+        self.base_model = base_model
+        self.fc = Linear(10, 10)
+
+    def forward(self, x):
+        x = self.base_model(x)
+        x = self.fc(x)
+        return x
+
+
+class TestNestModelLazy(TestInitializerBase):
+
+    def test_wrapper(self):
+        with LazyGuard():
+            base_model = Linear(10,
+                                10,
+                                weight_attr=self.weight_attr,
+                                bias_attr=self.bias_attr)
+            nest_model = NestModel(base_model)
+
+        self.check_data(nest_model)
+        self.check_program(nest_model)
+
+    def check_data(self, model):
+        x = paddle.randn([2, 10])
+        # weight and bias have no memory
+        with self.assertRaises(RuntimeError):
+            out = model(x)
+
+        for param in model.parameters():
+            param.initialize()
+
+        out = model(x)
+        self.assertEqual(out.shape, [2, 10])
+
+        np.testing.assert_allclose(model.base_model.weight.numpy(),
+                                   np.ones([10, 10], dtype=np.float32) * 0.6)
+        np.testing.assert_allclose(model.base_model.bias.numpy(),
+                                   np.ones([10], dtype=np.float32) * 0.3)
+
+    def check_program(self, model):
+        # verify nest_model startup_program
+        whole_program = model._startup_program()
+        self.assertEqual(whole_program.block(0).var("weight").shape, (10, 10))
+        self.assertEqual(whole_program.block(0).var("bias").shape, (10, ))
+        ops = [op.type for op in whole_program.block(0).ops]
+        init_ops = self.init_ops + ['uniform_random', 'fill_constant']
+        self.assertEqual(ops, init_ops)
+
+        # verify base_model startup_program
+        sub_program = model.base_model._startup_program()
+        self.assertEqual(sub_program.block(0).var("weight").shape, (10, 10))
+        self.assertEqual(sub_program.block(0).var("bias").shape, (10, ))
+        ops = [op.type for op in sub_program.block(0).ops]
+        self.assertEqual(ops, self.init_ops)
+
+
 class TestUniform(TestInitializerBase):

     def set_initializer(self):
...
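A reading note on check_program: init ops are appended in construction order, so the whole-model program lists base_model's configured init ops first, followed by the ops of the inner Linear(10, 10) built with default attributes, whose weight and bias record a uniform_random and a fill_constant op respectively (this is exactly what the assertion encodes):

    # base_model's params -> the ops configured via set_initializer
    # nest_model.fc's defaults -> 'uniform_random' (weight), 'fill_constant' (bias)
    init_ops = self.init_ops + ['uniform_random', 'fill_constant']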