From 99fae95e02d8979db062c386195be6bc35aec2a5 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 20 Apr 2020 16:58:21 +0800 Subject: [PATCH] feat(mge/parampack): add user-defined key to pack params GitOrigin-RevId: 7d51dcae23734cf6b9ef00710ff3c3989c4e1fe0 --- python_module/megengine/module/parampack.py | 31 ++++++++++++++----- .../test/integration/test_parampack.py | 15 +++++++++ 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/python_module/megengine/module/parampack.py b/python_module/megengine/module/parampack.py index b6025022a..c020a41d5 100644 --- a/python_module/megengine/module/parampack.py +++ b/python_module/megengine/module/parampack.py @@ -7,7 +7,7 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import collections -from typing import Iterable, Optional +from typing import Callable, Iterable, Optional, Tuple import numpy as np @@ -35,16 +35,18 @@ class ParamPack(Module): nr_ignore_first: int = 8, max_size_per_group: int = 10, max_nr_params_per_group: int = 100, + group_func: Callable = lambda name, param: 0, ): super().__init__() self._model = model self._nr_ignore_first = nr_ignore_first self._max_size_per_group = max_size_per_group self._max_nr_params_per_group = max_nr_params_per_group + self._group_func = group_func self._grouped_params = [] self._packed_params = [] - params = model.parameters() + params = model.named_parameters() self._pack_params(params) def parameters(self, requires_grad: Optional[bool] = None) -> Iterable[Parameter]: @@ -52,20 +54,33 @@ class ParamPack(Module): if requires_grad is None or param.requires_grad == requires_grad: yield param - def _pack_params(self, params: Iterable[Parameter]): + def named_parameters( + self, requires_grad: Optional[bool] = None + ) -> Iterable[Tuple[str, Parameter]]: + for idx, param in enumerate(self._packed_params): + if requires_grad is None or param.requires_grad == requires_grad: + yield "packed_param_" + str(idx), param + + def _pack_params(self, params: Iterable[Tuple[str, Parameter]]): groups = collections.defaultdict(list) ignored = 0 param_id = 0 - for param in params: + for name, param in params: if self._nr_ignore_first > ignored: ignored += 1 self._grouped_params.append([{"shape": param.shape, "id": param_id}]) + param.pack_group_key = self._group_func(name, param) self._packed_params.append(param) else: - key = (param.dtype, param.device, param.requires_grad) + key = ( + param.dtype, + param.device, + param.requires_grad, + self._group_func(name, param), + ) groups[key].append({"tensor": param, "id": param_id}) param_id += 1 - for (dtype, device, requires_grad) in groups.keys(): + for (dtype, device, requires_grad, group_key) in groups.keys(): dtype_sz = np.dtype(dtype).itemsize align = device.mem_align if align < dtype_sz: @@ -74,7 +89,7 @@ class ParamPack(Module): assert align % dtype_sz == 0 align //= dtype_sz - group = groups[(dtype, device, requires_grad)] + group = groups[(dtype, device, requires_grad, group_key)] while group: aligned_pos = [] offset = 0 @@ -98,6 +113,7 @@ class ParamPack(Module): group = group[idx:] if idx == 1: # ignore param packs with only one item + params[0]["tensor"].pack_group_key = group_key self._packed_params.append(params[0]["tensor"]) self._grouped_params.append( [{"shape": params[0]["tensor"].shape, "id": params[0]["id"]}] @@ -114,6 +130,7 @@ class ParamPack(Module): dtype=dtype, requires_grad=requires_grad, ) + new_param.pack_group_key = group_key self._packed_params.append(new_param) self._grouped_params.append( [{"shape": i["tensor"].shape, "id": i["id"]} for i in params] diff --git a/python_module/test/integration/test_parampack.py b/python_module/test/integration/test_parampack.py index 6b73c9f88..c9acc47c8 100644 --- a/python_module/test/integration/test_parampack.py +++ b/python_module/test/integration/test_parampack.py @@ -257,3 +257,18 @@ def test_correctness_parampack(): pred1 = infer1(data).numpy() pred2 = infer2(data).numpy() assert np.allclose(pred1, pred2) + + +def test_parampack_group_func(): + net = XORNet() + net = ParamPack( + net, + nr_ignore_first=1, + max_size_per_group=10, + max_nr_params_per_group=100, + group_func=lambda n, p: "weight" in n, + ) + for p in net.parameters(requires_grad=True): + assert p.pack_group_key is not None + for n, p in net.named_parameters(requires_grad=True): + assert p.pack_group_key is not None -- GitLab