提交 f077a529 编写于 作者: M Megvii Engine Team

fix(mge/param_pack): release old parameters

GitOrigin-RevId: 40a1f044e9fb43147b4369049167e19196de0047
上级 48822557
......@@ -35,4 +35,8 @@ class Parameter(Tensor):
def shape(self):
r"""Return shape of parameter.
"""
return self._symvar.imm_shape
if self._Tensor__val is not None:
return self._Tensor__val.shape
elif self._Tensor__sym is not None:
return self._Tensor__sym.imm_shape
return None
......@@ -56,7 +56,7 @@ class ParamPack(Module):
for param in params:
if self._nr_ignore_first > ignored:
ignored += 1
self._grouped_params.append([{"tensor": param, "id": param_id}])
self._grouped_params.append([{"shape": param.shape, "id": param_id}])
self._packed_params.append(param)
else:
key = (param.dtype, param.device, param.requires_grad)
......@@ -96,7 +96,9 @@ class ParamPack(Module):
if idx == 1:
# ignore param packs with only one item
self._packed_params.append(params[0]["tensor"])
self._grouped_params.append(params)
self._grouped_params.append(
[{"shape": params[0]["tensor"].shape, "id": params[0]["id"]}]
)
continue
packed_value = np.zeros((offset,), dtype=dtype)
......@@ -110,7 +112,9 @@ class ParamPack(Module):
requires_grad=requires_grad,
)
self._packed_params.append(new_param)
self._grouped_params.append(params)
self._grouped_params.append(
[{"shape": i["tensor"].shape, "id": i["id"]} for i in params]
)
def forward(self, *args, **kwargs):
replace_param = dict()
......@@ -120,7 +124,7 @@ class ParamPack(Module):
if len(grouped_params) == 1:
continue
split = param_pack_split(
packed_param._symvar, [i["tensor"].shape for i in grouped_params]
packed_param._symvar, [i["shape"] for i in grouped_params]
)
split = [
Parameter(Tensor(i, requires_grad=packed_param.requires_grad))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册