Unverified · Commit ac0a2e50 authored by Aurelius84, committed by GitHub

[LazyInit]Support LazyGuard and Polish interface code (#45462)

Parent edc9952c
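For orientation, a minimal sketch of the usage this commit enables, assembled from the diff and tests below (import paths and method names as they appear there):

from paddle import LazyGuard
from paddle.nn import Linear

# Construct the layer inside LazyGuard: parameters are declared,
# but no initializer runs and no memory is allocated yet.
with LazyGuard():
    fc = Linear(10, 10)

# Dygraph path: materialize each parameter on demand.
for param in fc.parameters():
    param.initialize()

# Static path (low-level, for framework developers): collect the
# recorded initializer ops into a startup Program.
print(fc._startup_program())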
......@@ -36,7 +36,7 @@ from .framework import disable_static # noqa: F401
from .framework import enable_static # noqa: F401
from .framework import in_dynamic_mode # noqa: F401
from .fluid.dataset import * # noqa: F401
from .fluid.lazy_init import LazyInit # noqa: F401
from .fluid.lazy_init import LazyGuard # noqa: F401
from .framework.dtype import dtype as dtype # noqa: F401
from .framework.dtype import uint8 # noqa: F401
......@@ -417,7 +417,7 @@ __all__ = [ # noqa
'cumprod',
'logcumsumexp',
'logit',
'LazyInit',
'LazyGuard',
'sign',
'is_empty',
'equal',
......
......@@ -38,6 +38,7 @@ from paddle.fluid import framework
from ..param_attr import ParamAttr
from paddle.fluid.executor import Executor, global_scope
from paddle.fluid.framework import _non_static_mode, convert_np_dtype_to_dtype_, in_dygraph_mode
from paddle.fluid.framework import Program, program_guard
from paddle.fluid.framework import _current_expected_place as _get_device
from paddle.fluid.core import VarDesc
from paddle.fluid.dygraph import no_grad
......@@ -1731,6 +1732,18 @@ class Layer(object):
self._dtype = dtype
return self
def _startup_program(self):
"""
Return the startup program containing the initialization operations of all parameters.
NOTE(dev): This is a very low-level API and only for internal developers.
"""
startup_program = Program()
for param in self.parameters():
param._create_init_op(startup_program.global_block())
return startup_program
# [aliases] Compatible with old method names
set_dict = set_state_dict
load_dict = set_state_dict
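As the test below does, the startup program returned here can be inspected directly; each parameter contributes one initializer op to block 0. A small sketch, with a hypothetical layer using the default initializers:

from paddle import LazyGuard
from paddle.nn import Linear

with LazyGuard():
    model = Linear(4, 4)   # hypothetical shapes

startup = model._startup_program()
# One op per parameter, in parameter order,
# e.g. ['uniform_random', 'fill_constant'] for the default weight/bias initializers.
print([op.type for op in startup.block(0).ops])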
......@@ -6794,18 +6794,19 @@ class EagerParamBase(_core_eager_eagertensor):
self.need_clip = kwargs.get('need_clip', True)
self.is_distributed = kwargs.get('is_distributed', False)
# self.block = default_main_program().global_block()
self.init_func = None
# hook functions for lazy initialization
self._init_func = None
self._init_op_creator = None
def set_init_func(self, obj):
self.init_func = obj
self._init_func = obj
@dygraph_only
def initialize(self):
assert self.init_func is not None, "Required self.init_func is not None, but received None."
self.init_func()
assert self._init_func is not None, "Required self._init_func is not None, but received None."
self._init_func()
# clear function handle to release resource
self.init_func = None
self._init_func = None
@property
def trainable(self):
......@@ -6820,6 +6821,13 @@ class EagerParamBase(_core_eager_eagertensor):
"The type of trainable MUST be bool, but the type is ",
type(trainable))
def _create_init_op(self, block):
"""
Call the _init_op_creator function to create the initializer operation in the given block.
"""
assert self._init_op_creator is not None, "Required self._init_op_creator is not None, but received None."
self._init_op_creator(block)
def __str__(self):
"""
Convert an EagerParamBase object to a readable string.
......
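The two hooks added above split the lazy-init contract between execution modes: _init_func re-runs the initializer eagerly when initialize() is called in dygraph mode, while _init_op_creator appends the equivalent initializer op into a caller-supplied block for the startup program. An illustrative stand-in (not the real class) showing that contract:

class _LazyParamSketch(object):
    """Illustrative stand-in for EagerParamBase's two lazy-init hooks."""

    def __init__(self):
        self._init_func = None        # eager (dygraph) initialization callback
        self._init_op_creator = None  # static-graph op recorder

    def set_init_func(self, func):
        self._init_func = func

    def initialize(self):
        # Dygraph: run the recorded initializer once, then release it.
        assert self._init_func is not None
        self._init_func()
        self._init_func = None

    def _create_init_op(self, block):
        # Static graph: let the initializer append its op into `block`.
        assert self._init_op_creator is not None
        self._init_op_creator(block)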
......@@ -19,7 +19,7 @@ import functools
from . import framework
from . import core
from .framework import _non_static_mode, in_dygraph_mode, _in_legacy_dygraph, default_main_program, _current_expected_place
from .lazy_init import lazy_guard
from .lazy_init import lazy_init_helper
from .framework import program_guard
import numpy as np
from .core import VarDesc
......@@ -52,7 +52,7 @@ class Initializer(object):
pass
def __call__(self, param, block=None):
if not lazy_guard().state:
if not lazy_init_helper().state:
return self.forward(param, block)
return self._lazy_init(param, block)
......@@ -63,18 +63,21 @@ class Initializer(object):
raise NotImplementedError()
def _lazy_init(self, param, block=None):
# Apply lazy initialization
"""
Apply lazy initialization
"""
assert in_dygraph_mode()
new_block = lazy_guard().startup_program.global_block()
new_var = param._to_static_var(True, block=new_block)
def init_op_creator(forward, param, block):
new_var = param._to_static_var(True, block=block)
# Record initializer operator
with lazy_guard():
self.forward(new_var, new_block)
lazy_guard().enable(clear_cache=False)
with lazy_init_helper():
forward(new_var, block)
# Add hook function for initializing param in dygraph mode
func = functools.partial(self.forward, param, block)
param.set_init_func(func)
param.set_init_func(functools.partial(self.forward, param, block))
param._init_op_creator = functools.partial(init_op_creator,
self.forward, param)
return param
......
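Putting the two hunks above together: under lazy mode, Initializer.__call__ no longer runs forward() immediately; it registers both deferred paths on the parameter. A condensed sketch of the resulting control flow (simplified, helper names as in the diff):

import functools
from paddle.fluid.lazy_init import lazy_init_helper

def apply_initializer(initializer, param, block=None):
    # Outside lazy mode: initialize immediately, exactly as before.
    if not lazy_init_helper().state:
        return initializer.forward(param, block)

    # Lazy mode: defer the dygraph path ...
    param.set_init_func(functools.partial(initializer.forward, param, block))

    # ... and the static-graph path, which replays the initializer into
    # whatever block the startup program hands in later.
    def init_op_creator(forward, param, block):
        new_var = param._to_static_var(True, block=block)
        with lazy_init_helper():
            forward(new_var, block)

    param._init_op_creator = functools.partial(init_op_creator,
                                               initializer.forward, param)
    return param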
......@@ -14,23 +14,21 @@
from . import framework
__all__ = ["LazyInit"]
__all__ = ["LazyGuard"]
class LazyGuard(object):
class LazyInitHelper(object):
"""
Guard Context to trigger switching mode between dygraph and static mode,
A Helper Context to trigger switching mode between dygraph and static mode,
and holds the startup program resource.
"""
def __init__(self):
self._init_program()
self._state = False
self._tracer = None
self._in_guard = False
def enable(self, clear_cache=True):
def enable(self):
"""
Switch into lazy mode.
......@@ -42,9 +40,6 @@ class LazyGuard(object):
), "LazyInit.enable() is only available in dygraph mode."
self._state = True
if clear_cache:
self._init_program()
def disable(self):
"""
Exit from lazy mode.
......@@ -55,15 +50,12 @@ class LazyGuard(object):
return
self._state = False
def _init_program(self):
self.startup_program = framework.Program()
def __enter__(self):
"""
Switch into lazy mode and set _dygraph_tracer_ with None to convert
dygraph mode into static mode.
"""
self.enable(clear_cache=True)
self.enable()
if self._in_guard: return
self._tracer = framework._dygraph_tracer_
framework._dygraph_tracer_ = None
......@@ -85,26 +77,22 @@ class LazyGuard(object):
return self._state
_lazy_guard = LazyGuard()
_lazy_init_helper = LazyInitHelper()
def lazy_guard():
global _lazy_guard
return _lazy_guard
def lazy_init_helper():
global _lazy_init_helper
return _lazy_init_helper
class LazyInit(object):
class LazyGuard(object):
"""
LazyInit is a wrapper interface for nn.Layer, it forwards the construct
LazyGuard is a wrapper interface for nn.Layer; it forwards the construction
process of a user-defined Layer. Meanwhile, it provides the necessary API to
trigger EagerParamBase lazy initialization and to get the startup Program.
"""
def __init__(self, class_obj=None):
self.class_obj = class_obj
self.clear_cache = True
def __call__(self, *args, **kwargs):
def __enter__(self):
"""
Construct instance from class_obj by Lazy Initializing parameters.
......@@ -112,43 +100,16 @@ class LazyInit(object):
.. code-block:: python
from paddle import LazyInit
from paddle import LazyGuard
from paddle.nn import Linear
with LazyGuard():
fc = Linear(10, 10)
for param in fc.parameters():
param.initialize()
"""
assert isinstance(
self.class_obj, type
), "Required class_obj must be a class type, but received %s." % self.class_obj
global _lazy_guard
_lazy_guard.enable(self.clear_cache)
# construct Layer instance
with framework.program_guard(framework.Program()):
instance = self.class_obj(*args, **kwargs)
_lazy_guard.disable()
# set @property dynamically to visit startup_program
instance.startup_program = _lazy_guard.startup_program
return instance
@staticmethod
def startup_program():
"""
A static method to get startup program for the latest Layer.
Examples:
lazy_init_helper().enable()
.. code-block:: python
from paddle import LazyInit
from paddle.nn import Linear
fc = LazyInit(Linear)(10, 10)
print(LazyInit.startup_program())
"""
return _lazy_guard.startup_program
def __exit__(self, *args, **kwargs):
lazy_init_helper().disable()
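Because the helper is a module-level singleton, all layers built inside one guard share the same lazy-init session, and the _in_guard flag keeps re-entrant guards from swapping the dygraph tracer twice. A small sketch with a hypothetical nested model (the NestModel test below exercises the same pattern):

from paddle import LazyGuard
from paddle.nn import Layer, Linear

class TwoLinear(Layer):   # hypothetical nested model
    def __init__(self):
        super(TwoLinear, self).__init__()
        self.fc1 = Linear(10, 10)
        self.fc2 = Linear(10, 10)

    def forward(self, x):
        return self.fc2(self.fc1(x))

with LazyGuard():
    model = TwoLinear()   # no parameter memory allocated here

# Either initialize the parameters eagerly in dygraph mode ...
for p in model.parameters():
    p.initialize()
# ... or export every recorded initializer op as one startup Program:
# print(model._startup_program())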
......@@ -15,8 +15,8 @@
import paddle
import unittest
import numpy as np
from paddle import LazyInit
from paddle.nn import Linear
from paddle import LazyGuard
from paddle.nn import Linear, Layer
from paddle.nn.initializer import *
from paddle.fluid import unique_name
......@@ -47,11 +47,13 @@ class TestInitializerBase(unittest.TestCase):
unique_name.dygraph_parameter_name_checker._name_set = set()
def test_wrapper(self):
fc = LazyInit(Linear)(10,
with LazyGuard():
fc = Linear(10,
10,
weight_attr=self.weight_attr,
bias_attr=self.bias_attr)
program = fc.startup_program
program = fc._startup_program()
print(program)
self.check_program(program)
def check_program(self, program):
......@@ -64,7 +66,8 @@ class TestInitializerBase(unittest.TestCase):
class TestDygraphLazy(TestInitializerBase):
def test_wrapper(self):
fc = LazyInit(Linear)(10,
with LazyGuard():
fc = Linear(10,
10,
weight_attr=self.weight_attr,
bias_attr=self.bias_attr)
......@@ -89,6 +92,66 @@ class TestDygraphLazy(TestInitializerBase):
np.ones([10], dtype=np.float32) * 0.3)
class NestModel(Layer):
def __init__(self, base_model):
super(NestModel, self).__init__()
self.base_model = base_model
self.fc = Linear(10, 10)
def forward(self, x):
x = self.base_model(x)
x = self.fc(x)
return x
class TestNestModelLazy(TestInitializerBase):
def test_wrapper(self):
with LazyGuard():
base_model = Linear(10,
10,
weight_attr=self.weight_attr,
bias_attr=self.bias_attr)
nest_model = NestModel(base_model)
self.check_data(nest_model)
self.check_program(nest_model)
def check_data(self, model):
x = paddle.randn([2, 10])
# weight and bias have no memory
with self.assertRaises(RuntimeError):
out = model(x)
for param in model.parameters():
param.initialize()
out = model(x)
self.assertEqual(out.shape, [2, 10])
np.testing.assert_allclose(model.base_model.weight.numpy(),
np.ones([10, 10], dtype=np.float32) * 0.6)
np.testing.assert_allclose(model.base_model.bias.numpy(),
np.ones([10], dtype=np.float32) * 0.3)
def check_program(self, model):
# verify nest_model startup_program
whole_program = model._startup_program()
self.assertEqual(whole_program.block(0).var("weight").shape, (10, 10))
self.assertEqual(whole_program.block(0).var("bias").shape, (10, ))
ops = [op.type for op in whole_program.block(0).ops]
init_ops = self.init_ops + ['uniform_random', 'fill_constant']
self.assertEqual(ops, init_ops)
# verify base_model startup_program
sub_program = model.base_model._startup_program()
self.assertEqual(sub_program.block(0).var("weight").shape, (10, 10))
self.assertEqual(sub_program.block(0).var("bias").shape, (10, ))
ops = [op.type for op in sub_program.block(0).ops]
self.assertEqual(ops, self.init_ops)
class TestUniform(TestInitializerBase):
def set_initializer(self):
......