Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
99fae95e
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
99fae95e
编写于
4月 20, 2020
作者:
M
Megvii Engine Team
提交者:
Xinran Xu
5月 06, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge/parampack): add user-defined key to pack params
GitOrigin-RevId: 7d51dcae23734cf6b9ef00710ff3c3989c4e1fe0
上级
0668b343
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
39 addition
and
7 deletion
+39
-7
python_module/megengine/module/parampack.py
python_module/megengine/module/parampack.py
+24
-7
python_module/test/integration/test_parampack.py
python_module/test/integration/test_parampack.py
+15
-0
未找到文件。
python_module/megengine/module/parampack.py
浏览文件 @
99fae95e
...
...
@@ -7,7 +7,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
collections
from
typing
import
Iterable
,
Optional
from
typing
import
Callable
,
Iterable
,
Optional
,
Tuple
import
numpy
as
np
...
...
@@ -35,16 +35,18 @@ class ParamPack(Module):
nr_ignore_first
:
int
=
8
,
max_size_per_group
:
int
=
10
,
max_nr_params_per_group
:
int
=
100
,
group_func
:
Callable
=
lambda
name
,
param
:
0
,
):
super
().
__init__
()
self
.
_model
=
model
self
.
_nr_ignore_first
=
nr_ignore_first
self
.
_max_size_per_group
=
max_size_per_group
self
.
_max_nr_params_per_group
=
max_nr_params_per_group
self
.
_group_func
=
group_func
self
.
_grouped_params
=
[]
self
.
_packed_params
=
[]
params
=
model
.
parameters
()
params
=
model
.
named_
parameters
()
self
.
_pack_params
(
params
)
def
parameters
(
self
,
requires_grad
:
Optional
[
bool
]
=
None
)
->
Iterable
[
Parameter
]:
...
...
@@ -52,20 +54,33 @@ class ParamPack(Module):
if
requires_grad
is
None
or
param
.
requires_grad
==
requires_grad
:
yield
param
def
_pack_params
(
self
,
params
:
Iterable
[
Parameter
]):
def
named_parameters
(
self
,
requires_grad
:
Optional
[
bool
]
=
None
)
->
Iterable
[
Tuple
[
str
,
Parameter
]]:
for
idx
,
param
in
enumerate
(
self
.
_packed_params
):
if
requires_grad
is
None
or
param
.
requires_grad
==
requires_grad
:
yield
"packed_param_"
+
str
(
idx
),
param
def
_pack_params
(
self
,
params
:
Iterable
[
Tuple
[
str
,
Parameter
]]):
groups
=
collections
.
defaultdict
(
list
)
ignored
=
0
param_id
=
0
for
param
in
params
:
for
name
,
param
in
params
:
if
self
.
_nr_ignore_first
>
ignored
:
ignored
+=
1
self
.
_grouped_params
.
append
([{
"shape"
:
param
.
shape
,
"id"
:
param_id
}])
param
.
pack_group_key
=
self
.
_group_func
(
name
,
param
)
self
.
_packed_params
.
append
(
param
)
else
:
key
=
(
param
.
dtype
,
param
.
device
,
param
.
requires_grad
)
key
=
(
param
.
dtype
,
param
.
device
,
param
.
requires_grad
,
self
.
_group_func
(
name
,
param
),
)
groups
[
key
].
append
({
"tensor"
:
param
,
"id"
:
param_id
})
param_id
+=
1
for
(
dtype
,
device
,
requires_grad
)
in
groups
.
keys
():
for
(
dtype
,
device
,
requires_grad
,
group_key
)
in
groups
.
keys
():
dtype_sz
=
np
.
dtype
(
dtype
).
itemsize
align
=
device
.
mem_align
if
align
<
dtype_sz
:
...
...
@@ -74,7 +89,7 @@ class ParamPack(Module):
assert
align
%
dtype_sz
==
0
align
//=
dtype_sz
group
=
groups
[(
dtype
,
device
,
requires_grad
)]
group
=
groups
[(
dtype
,
device
,
requires_grad
,
group_key
)]
while
group
:
aligned_pos
=
[]
offset
=
0
...
...
@@ -98,6 +113,7 @@ class ParamPack(Module):
group
=
group
[
idx
:]
if
idx
==
1
:
# ignore param packs with only one item
params
[
0
][
"tensor"
].
pack_group_key
=
group_key
self
.
_packed_params
.
append
(
params
[
0
][
"tensor"
])
self
.
_grouped_params
.
append
(
[{
"shape"
:
params
[
0
][
"tensor"
].
shape
,
"id"
:
params
[
0
][
"id"
]}]
...
...
@@ -114,6 +130,7 @@ class ParamPack(Module):
dtype
=
dtype
,
requires_grad
=
requires_grad
,
)
new_param
.
pack_group_key
=
group_key
self
.
_packed_params
.
append
(
new_param
)
self
.
_grouped_params
.
append
(
[{
"shape"
:
i
[
"tensor"
].
shape
,
"id"
:
i
[
"id"
]}
for
i
in
params
]
...
...
python_module/test/integration/test_parampack.py
浏览文件 @
99fae95e
...
...
@@ -257,3 +257,18 @@ def test_correctness_parampack():
pred1
=
infer1
(
data
).
numpy
()
pred2
=
infer2
(
data
).
numpy
()
assert
np
.
allclose
(
pred1
,
pred2
)
def
test_parampack_group_func
():
net
=
XORNet
()
net
=
ParamPack
(
net
,
nr_ignore_first
=
1
,
max_size_per_group
=
10
,
max_nr_params_per_group
=
100
,
group_func
=
lambda
n
,
p
:
"weight"
in
n
,
)
for
p
in
net
.
parameters
(
requires_grad
=
True
):
assert
p
.
pack_group_key
is
not
None
for
n
,
p
in
net
.
named_parameters
(
requires_grad
=
True
):
assert
p
.
pack_group_key
is
not
None
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录