Commit 99fae95e authored by Megvii Engine Team, committed by Xinran Xu

feat(mge/parampack): add user-defined key to pack params

GitOrigin-RevId: 7d51dcae23734cf6b9ef00710ff3c3989c4e1fe0
Parent: 0668b343
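The new group_func argument lets callers control which parameters may be packed into the same contiguous buffer: it receives each parameter's name and tensor and returns a hashable key, and only parameters with equal keys are grouped together. A minimal usage sketch (the import path and the `model` variable are assumptions for illustration; the constructor arguments match the diff below):

    # Hypothetical usage of the new group_func hook.
    from megengine.module import ParamPack  # import path assumed

    packed = ParamPack(
        model,                  # any Module with named parameters (assumed defined)
        nr_ignore_first=8,      # leave the first 8 parameters unpacked
        max_size_per_group=10,
        max_nr_params_per_group=100,
        # e.g. pack biases separately from weights
        group_func=lambda name, param: "bias" in name,
    )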
@@ -7,7 +7,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import collections
-from typing import Iterable, Optional
+from typing import Callable, Iterable, Optional, Tuple
 
 import numpy as np
 
@@ -35,16 +35,18 @@ class ParamPack(Module):
         nr_ignore_first: int = 8,
         max_size_per_group: int = 10,
         max_nr_params_per_group: int = 100,
+        group_func: Callable = lambda name, param: 0,
     ):
         super().__init__()
         self._model = model
         self._nr_ignore_first = nr_ignore_first
         self._max_size_per_group = max_size_per_group
         self._max_nr_params_per_group = max_nr_params_per_group
+        self._group_func = group_func
         self._grouped_params = []
         self._packed_params = []
 
-        params = model.parameters()
+        params = model.named_parameters()
         self._pack_params(params)
 
     def parameters(self, requires_grad: Optional[bool] = None) -> Iterable[Parameter]:
@@ -52,20 +54,33 @@ class ParamPack(Module):
             if requires_grad is None or param.requires_grad == requires_grad:
                 yield param
 
-    def _pack_params(self, params: Iterable[Parameter]):
+    def named_parameters(
+        self, requires_grad: Optional[bool] = None
+    ) -> Iterable[Tuple[str, Parameter]]:
+        for idx, param in enumerate(self._packed_params):
+            if requires_grad is None or param.requires_grad == requires_grad:
+                yield "packed_param_" + str(idx), param
+
+    def _pack_params(self, params: Iterable[Tuple[str, Parameter]]):
         groups = collections.defaultdict(list)
         ignored = 0
         param_id = 0
-        for param in params:
+        for name, param in params:
             if self._nr_ignore_first > ignored:
                 ignored += 1
                 self._grouped_params.append([{"shape": param.shape, "id": param_id}])
+                param.pack_group_key = self._group_func(name, param)
                 self._packed_params.append(param)
             else:
-                key = (param.dtype, param.device, param.requires_grad)
+                key = (
+                    param.dtype,
+                    param.device,
+                    param.requires_grad,
+                    self._group_func(name, param),
+                )
                 groups[key].append({"tensor": param, "id": param_id})
             param_id += 1
-        for (dtype, device, requires_grad) in groups.keys():
+        for (dtype, device, requires_grad, group_key) in groups.keys():
            dtype_sz = np.dtype(dtype).itemsize
            align = device.mem_align
            if align < dtype_sz:
@@ -74,7 +89,7 @@ class ParamPack(Module):
            assert align % dtype_sz == 0
            align //= dtype_sz
 
-           group = groups[(dtype, device, requires_grad)]
+           group = groups[(dtype, device, requires_grad, group_key)]
            while group:
                aligned_pos = []
                offset = 0
@@ -98,6 +113,7 @@ class ParamPack(Module):
                group = group[idx:]
                if idx == 1:
                    # ignore param packs with only one item
+                   params[0]["tensor"].pack_group_key = group_key
                    self._packed_params.append(params[0]["tensor"])
                    self._grouped_params.append(
                        [{"shape": params[0]["tensor"].shape, "id": params[0]["id"]}]
@@ -114,6 +130,7 @@ class ParamPack(Module):
                    dtype=dtype,
                    requires_grad=requires_grad,
                )
+               new_param.pack_group_key = group_key
                self._packed_params.append(new_param)
                self._grouped_params.append(
                    [{"shape": i["tensor"].shape, "id": i["id"]} for i in params]
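The grouping key above is the tuple (dtype, device, requires_grad, group_func(name, param)): two parameters land in the same pack only when all four components match. A standalone sketch of that bucketing, with plain dicts standing in for Parameter objects (all names here are illustrative):

    import collections

    def bucket(named_params, group_func):
        # Group parameters by the same composite key _pack_params uses.
        groups = collections.defaultdict(list)
        for name, p in named_params:
            key = (p["dtype"], p["device"], p["requires_grad"], group_func(name, p))
            groups[key].append(name)
        return groups

    fake_params = [
        ("fc.weight", {"dtype": "float32", "device": "xpux", "requires_grad": True}),
        ("fc.bias", {"dtype": "float32", "device": "xpux", "requires_grad": True}),
    ]
    # Splitting biases from weights yields two buckets instead of one:
    print(bucket(fake_params, lambda n, p: "bias" in n))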
@@ -257,3 +257,18 @@ def test_correctness_parampack():
     pred1 = infer1(data).numpy()
     pred2 = infer2(data).numpy()
     assert np.allclose(pred1, pred2)
+
+
+def test_parampack_group_func():
+    net = XORNet()
+    net = ParamPack(
+        net,
+        nr_ignore_first=1,
+        max_size_per_group=10,
+        max_nr_params_per_group=100,
+        group_func=lambda n, p: "weight" in n,
+    )
+    for p in net.parameters(requires_grad=True):
+        assert p.pack_group_key is not None
+    for n, p in net.named_parameters(requires_grad=True):
+        assert p.pack_group_key is not None
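The test exercises both accessors: after packing, every parameter yielded by parameters() and named_parameters() carries the pack_group_key attribute that _pack_params assigns from group_func, including the nr_ignore_first parameters that are never actually packed.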