From f077a5292ced684517837dc945d2d30f7186d79b Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Fri, 3 Apr 2020 18:35:28 +0800
Subject: [PATCH] fix(mge/param_pack): release old parameters

GitOrigin-RevId: 40a1f044e9fb43147b4369049167e19196de0047
---
 python_module/megengine/core/tensor_nn.py   |  6 +++++-
 python_module/megengine/module/parampack.py | 12 ++++++++----
 2 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/python_module/megengine/core/tensor_nn.py b/python_module/megengine/core/tensor_nn.py
index be5d27824..9c25df715 100644
--- a/python_module/megengine/core/tensor_nn.py
+++ b/python_module/megengine/core/tensor_nn.py
@@ -35,4 +35,8 @@ class Parameter(Tensor):
     def shape(self):
         r"""Return shape of parameter.
         """
-        return self._symvar.imm_shape
+        if self._Tensor__val is not None:
+            return self._Tensor__val.shape
+        elif self._Tensor__sym is not None:
+            return self._Tensor__sym.imm_shape
+        return None
diff --git a/python_module/megengine/module/parampack.py b/python_module/megengine/module/parampack.py
index 6c19ea404..01b2c5555 100644
--- a/python_module/megengine/module/parampack.py
+++ b/python_module/megengine/module/parampack.py
@@ -56,7 +56,7 @@ class ParamPack(Module):
         for param in params:
             if self._nr_ignore_first > ignored:
                 ignored += 1
-                self._grouped_params.append([{"tensor": param, "id": param_id}])
+                self._grouped_params.append([{"shape": param.shape, "id": param_id}])
                 self._packed_params.append(param)
             else:
                 key = (param.dtype, param.device, param.requires_grad)
@@ -96,7 +96,9 @@ class ParamPack(Module):
                 if idx == 1:
                     # ignore param packs with only one item
                     self._packed_params.append(params[0]["tensor"])
-                    self._grouped_params.append(params)
+                    self._grouped_params.append(
+                        [{"shape": params[0]["tensor"].shape, "id": params[0]["id"]}]
+                    )
                     continue
 
                 packed_value = np.zeros((offset,), dtype=dtype)
@@ -110,7 +112,9 @@ class ParamPack(Module):
                     requires_grad=requires_grad,
                 )
                 self._packed_params.append(new_param)
-                self._grouped_params.append(params)
+                self._grouped_params.append(
+                    [{"shape": i["tensor"].shape, "id": i["id"]} for i in params]
+                )
 
     def forward(self, *args, **kwargs):
         replace_param = dict()
@@ -120,7 +124,7 @@ class ParamPack(Module):
             if len(grouped_params) == 1:
                 continue
             split = param_pack_split(
-                packed_param._symvar, [i["tensor"].shape for i in grouped_params]
+                packed_param._symvar, [i["shape"] for i in grouped_params]
             )
             split = [
                 Parameter(Tensor(i, requires_grad=packed_param.requires_grad))
-- 
GitLab