Commit 99fae95e authored by Megvii Engine Team, committed by Xinran Xu

feat(mge/parampack): add user-defined key to pack params

GitOrigin-RevId: 7d51dcae23734cf6b9ef00710ff3c3989c4e1fe0
Parent 0668b343
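In brief, the new group_func argument lets the caller attach a custom key to each parameter; parameters are packed together only when their keys match, in addition to the existing dtype/device/requires_grad split. A minimal usage sketch follows; the import path and the toy network here are assumptions for illustration, not part of this commit.

import megengine.module as M  # import path assumed for illustration

class ToyNet(M.Module):  # hypothetical two-layer model, not from this commit
    def __init__(self):
        super().__init__()
        self.fc1 = M.Linear(4, 8)
        self.fc2 = M.Linear(8, 2)

    def forward(self, x):
        return self.fc2(self.fc1(x))

# Pack weights and biases into separate groups by keying on the parameter name.
packed = M.ParamPack(  # ParamPack assumed to be exposed under megengine.module
    ToyNet(),
    nr_ignore_first=0,
    group_func=lambda name, param: "weight" in name,
)

for name, param in packed.named_parameters(requires_grad=True):
    print(name, param.pack_group_key)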
@@ -7,7 +7,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
from typing import Iterable, Optional
from typing import Callable, Iterable, Optional, Tuple
import numpy as np
@@ -35,16 +35,18 @@ class ParamPack(Module):
nr_ignore_first: int = 8,
max_size_per_group: int = 10,
max_nr_params_per_group: int = 100,
group_func: Callable = lambda name, param: 0,
):
super().__init__()
self._model = model
self._nr_ignore_first = nr_ignore_first
self._max_size_per_group = max_size_per_group
self._max_nr_params_per_group = max_nr_params_per_group
self._group_func = group_func
self._grouped_params = []
self._packed_params = []
params = model.parameters()
params = model.named_parameters()
self._pack_params(params)
def parameters(self, requires_grad: Optional[bool] = None) -> Iterable[Parameter]:
@@ -52,20 +52,33 @@ class ParamPack(Module):
if requires_grad is None or param.requires_grad == requires_grad:
yield param
def _pack_params(self, params: Iterable[Parameter]):
def named_parameters(
self, requires_grad: Optional[bool] = None
) -> Iterable[Tuple[str, Parameter]]:
for idx, param in enumerate(self._packed_params):
if requires_grad is None or param.requires_grad == requires_grad:
yield "packed_param_" + str(idx), param
def _pack_params(self, params: Iterable[Tuple[str, Parameter]]):
groups = collections.defaultdict(list)
ignored = 0
param_id = 0
for param in params:
for name, param in params:
if self._nr_ignore_first > ignored:
ignored += 1
self._grouped_params.append([{"shape": param.shape, "id": param_id}])
param.pack_group_key = self._group_func(name, param)
self._packed_params.append(param)
else:
key = (param.dtype, param.device, param.requires_grad)
key = (
param.dtype,
param.device,
param.requires_grad,
self._group_func(name, param),
)
groups[key].append({"tensor": param, "id": param_id})
param_id += 1
for (dtype, device, requires_grad) in groups.keys():
for (dtype, device, requires_grad, group_key) in groups.keys():
dtype_sz = np.dtype(dtype).itemsize
align = device.mem_align
if align < dtype_sz:
@@ -74,7 +89,7 @@ class ParamPack(Module):
assert align % dtype_sz == 0
align //= dtype_sz
group = groups[(dtype, device, requires_grad)]
group = groups[(dtype, device, requires_grad, group_key)]
while group:
aligned_pos = []
offset = 0
@@ -98,6 +113,7 @@ class ParamPack(Module):
group = group[idx:]
if idx == 1:
# ignore param packs with only one item
params[0]["tensor"].pack_group_key = group_key
self._packed_params.append(params[0]["tensor"])
self._grouped_params.append(
[{"shape": params[0]["tensor"].shape, "id": params[0]["id"]}]
@@ -114,6 +130,7 @@ class ParamPack(Module):
dtype=dtype,
requires_grad=requires_grad,
)
new_param.pack_group_key = group_key
self._packed_params.append(new_param)
self._grouped_params.append(
[{"shape": i["tensor"].shape, "id": i["id"]} for i in params]
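To illustrate the grouping change in isolation: the bucketing key that decides which parameters may share a pack now carries the group_func result as a fourth element, mirroring the key construction in _pack_params above. The following self-contained sketch uses a stand-in Param tuple rather than real megengine tensors.

import collections
from typing import NamedTuple

class Param(NamedTuple):  # stand-in for a real Parameter, for illustration only
    name: str
    dtype: str
    device: str
    requires_grad: bool

def bucket(params, group_func=lambda name, param: 0):
    # Same dtype/device/requires_grad key as before, plus the user-defined group key.
    groups = collections.defaultdict(list)
    for p in params:
        key = (p.dtype, p.device, p.requires_grad, group_func(p.name, p))
        groups[key].append(p)
    return groups

params = [
    Param("fc.weight", "float32", "xpux", True),
    Param("fc.bias", "float32", "xpux", True),
]
print(len(bucket(params)))                              # 1 group: identical keys
print(len(bucket(params, lambda n, p: "weight" in n)))  # 2 groups: weight vs. bias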
@@ -257,3 +257,18 @@ def test_correctness_parampack():
pred1 = infer1(data).numpy()
pred2 = infer2(data).numpy()
assert np.allclose(pred1, pred2)
def test_parampack_group_func():
net = XORNet()
net = ParamPack(
net,
nr_ignore_first=1,
max_size_per_group=10,
max_nr_params_per_group=100,
group_func=lambda n, p: "weight" in n,
)
for p in net.parameters(requires_grad=True):
assert p.pack_group_key is not None
for n, p in net.named_parameters(requires_grad=True):
assert p.pack_group_key is not None