未验证 提交 44a32fbd 编写于 作者: X xiongkun 提交者: GitHub

[dy2static] bug fix: Lazy initialize bugs (#50785)

上级 016f5ecb
...@@ -7230,7 +7230,7 @@ class EagerParamBase(_core_eager_eagertensor): ...@@ -7230,7 +7230,7 @@ class EagerParamBase(_core_eager_eagertensor):
assert ( assert (
self._init_func is not None self._init_func is not None
), "Required self._init_func is not None, but received None." ), "Required self._init_func is not None, but received None."
self._init_func() self._init_func(self, None)
# clear function handle to release resource # clear function handle to release resource
self._init_func = None self._init_func = None
...@@ -7255,7 +7255,7 @@ class EagerParamBase(_core_eager_eagertensor): ...@@ -7255,7 +7255,7 @@ class EagerParamBase(_core_eager_eagertensor):
assert ( assert (
self._init_op_creator is not None self._init_op_creator is not None
), "Required self._init_op_creator is not None, but received None." ), "Required self._init_op_creator is not None, but received None."
self._init_op_creator(block) self._init_op_creator(self, block)
def __str__(self): def __str__(self):
""" """
...@@ -7307,6 +7307,8 @@ class EagerParamBase(_core_eager_eagertensor): ...@@ -7307,6 +7307,8 @@ class EagerParamBase(_core_eager_eagertensor):
new_param = EagerParamBase(self.shape, self.dtype, **state) new_param = EagerParamBase(self.shape, self.dtype, **state)
memo[id(self)] = new_param memo[id(self)] = new_param
new_param.copy_(self, True) new_param.copy_(self, True)
new_param._init_func = self._init_func
new_param._init_op_creator = self._init_op_creator
return new_param return new_param
def _copy_to(self, device, blocking): def _copy_to(self, device, blocking):
......
...@@ -58,9 +58,9 @@ class Initializer: ...@@ -58,9 +58,9 @@ class Initializer:
forward(new_var, block) forward(new_var, block)
# Add hook function for initializing param in dygraph mode # Add hook function for initializing param in dygraph mode
param.set_init_func(functools.partial(self.forward, param, block)) param.set_init_func(functools.partial(self.forward))
param._init_op_creator = functools.partial( param._init_op_creator = functools.partial(
init_op_creator, self.forward, param init_op_creator, self.forward
) )
return param return param
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册