Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
62b0c6cd
MegEngine
项目概览
MegEngine 天元
/
MegEngine
接近 2 年 前同步成功
通知
414
Star
4708
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
62b0c6cd
编写于
3月 30, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
style(mge): apply format.sh
GitOrigin-RevId: a900b1bb6e8b6dde7a6737bfb8df9db6e79b45ce
上级
fc6aa12e
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
113 addition
and
88 deletion
+113
-88
python_module/megengine/core/tensor_nn.py
python_module/megengine/core/tensor_nn.py
+1
-1
python_module/megengine/data/dataset/vision/coco.py
python_module/megengine/data/dataset/vision/coco.py
+1
-1
python_module/megengine/data/dataset/vision/imagenet.py
python_module/megengine/data/dataset/vision/imagenet.py
+24
-13
python_module/megengine/data/dataset/vision/utils.py
python_module/megengine/data/dataset/vision/utils.py
+1
-1
python_module/megengine/functional/elemwise.py
python_module/megengine/functional/elemwise.py
+5
-5
python_module/megengine/module/__init__.py
python_module/megengine/module/__init__.py
+1
-1
python_module/megengine/module/init.py
python_module/megengine/module/init.py
+1
-1
python_module/megengine/module/module.py
python_module/megengine/module/module.py
+7
-7
python_module/megengine/module/parampack.py
python_module/megengine/module/parampack.py
+32
-24
python_module/test/integration/test_parampack.py
python_module/test/integration/test_parampack.py
+13
-15
python_module/test/unit/functional/test_elemwise.py
python_module/test/unit/functional/test_elemwise.py
+19
-13
python_module/test/unit/jit/test_jit.py
python_module/test/unit/jit/test_jit.py
+8
-6
未找到文件。
python_module/megengine/core/tensor_nn.py
浏览文件 @
62b0c6cd
...
@@ -30,7 +30,7 @@ class Parameter(Tensor):
...
@@ -30,7 +30,7 @@ class Parameter(Tensor):
else
:
else
:
t
=
tensor
(
value
,
dtype
=
dtype
,
device
=
device
,
requires_grad
=
requires_grad
)
t
=
tensor
(
value
,
dtype
=
dtype
,
device
=
device
,
requires_grad
=
requires_grad
)
self
.
__dict__
.
update
(
t
.
__dict__
)
self
.
__dict__
.
update
(
t
.
__dict__
)
@
property
@
property
def
shape
(
self
):
def
shape
(
self
):
r
"""Return shape of parameter.
r
"""Return shape of parameter.
...
...
python_module/megengine/data/dataset/vision/coco.py
浏览文件 @
62b0c6cd
...
@@ -12,9 +12,9 @@
...
@@ -12,9 +12,9 @@
#
#
# Copyright (c) 2018 Facebook
# Copyright (c) 2018 Facebook
# ---------------------------------------------------------------------
# ---------------------------------------------------------------------
from
collections
import
OrderedDict
,
defaultdict
import
json
import
json
import
os
import
os
from
collections
import
OrderedDict
,
defaultdict
import
cv2
import
cv2
import
numpy
as
np
import
numpy
as
np
...
...
python_module/megengine/data/dataset/vision/imagenet.py
浏览文件 @
62b0c6cd
...
@@ -87,7 +87,7 @@ class ImageNet(ImageFolder):
...
@@ -87,7 +87,7 @@ class ImageNet(ImageFolder):
if
not
os
.
path
.
exists
(
self
.
root
):
if
not
os
.
path
.
exists
(
self
.
root
):
raise
FileNotFoundError
(
"dir %s does not exist"
%
self
.
root
)
raise
FileNotFoundError
(
"dir %s does not exist"
%
self
.
root
)
self
.
devkit_dir
=
os
.
path
.
join
(
self
.
root
,
self
.
default_devkit_dir
)
self
.
devkit_dir
=
os
.
path
.
join
(
self
.
root
,
self
.
default_devkit_dir
)
if
not
os
.
path
.
exists
(
self
.
devkit_dir
):
if
not
os
.
path
.
exists
(
self
.
devkit_dir
):
...
@@ -159,8 +159,14 @@ class ImageNet(ImageFolder):
...
@@ -159,8 +159,14 @@ class ImageNet(ImageFolder):
classes
=
[
tuple
(
clss
.
split
(
", "
))
for
clss
in
classes
]
classes
=
[
tuple
(
clss
.
split
(
", "
))
for
clss
in
classes
]
idx_to_wnid
=
{
idx
:
wnid
for
idx
,
wnid
in
zip
(
idcs
,
wnids
)}
idx_to_wnid
=
{
idx
:
wnid
for
idx
,
wnid
in
zip
(
idcs
,
wnids
)}
wnid_to_classes
=
{
wnid
:
clss
for
wnid
,
clss
in
zip
(
wnids
,
classes
)}
wnid_to_classes
=
{
wnid
:
clss
for
wnid
,
clss
in
zip
(
wnids
,
classes
)}
logger
.
info
(
"saving cached meta file to %s"
,
os
.
path
.
join
(
self
.
devkit_dir
,
"meta.pkl"
))
logger
.
info
(
save
((
idx_to_wnid
,
wnid_to_classes
),
os
.
path
.
join
(
self
.
devkit_dir
,
"meta.pkl"
))
"saving cached meta file to %s"
,
os
.
path
.
join
(
self
.
devkit_dir
,
"meta.pkl"
),
)
save
(
(
idx_to_wnid
,
wnid_to_classes
),
os
.
path
.
join
(
self
.
devkit_dir
,
"meta.pkl"
),
)
return
idx_to_wnid
,
wnid_to_classes
return
idx_to_wnid
,
wnid_to_classes
def
check_raw_file
(
self
)
->
bool
:
def
check_raw_file
(
self
)
->
bool
:
...
@@ -177,7 +183,10 @@ class ImageNet(ImageFolder):
...
@@ -177,7 +183,10 @@ class ImageNet(ImageFolder):
val_wnids
=
[
id2wnid
[
idx
]
for
idx
in
val_idcs
]
val_wnids
=
[
id2wnid
[
idx
]
for
idx
in
val_idcs
]
val_images
=
sorted
(
val_images
=
sorted
(
[
os
.
path
.
join
(
self
.
target_folder
,
image
)
for
image
in
os
.
listdir
(
self
.
target_folder
)]
[
os
.
path
.
join
(
self
.
target_folder
,
image
)
for
image
in
os
.
listdir
(
self
.
target_folder
)
]
)
)
logger
.
debug
(
"mkdir for val set wnids"
)
logger
.
debug
(
"mkdir for val set wnids"
)
...
@@ -198,23 +207,24 @@ class ImageNet(ImageFolder):
...
@@ -198,23 +207,24 @@ class ImageNet(ImageFolder):
raw_filename
,
checksum
=
self
.
raw_file_meta
[
"val"
]
raw_filename
,
checksum
=
self
.
raw_file_meta
[
"val"
]
raw_file
=
os
.
path
.
join
(
self
.
root
,
raw_filename
)
raw_file
=
os
.
path
.
join
(
self
.
root
,
raw_filename
)
logger
.
info
(
"checksum valid tar file {} .."
.
format
(
raw_file
))
logger
.
info
(
"checksum valid tar file {} .."
.
format
(
raw_file
))
assert
calculate_md5
(
raw_file
)
==
checksum
,
\
assert
(
"checksum mismatch, {} may be damaged"
.
format
(
raw_file
)
calculate_md5
(
raw_file
)
==
checksum
),
"checksum mismatch, {} may be damaged"
.
format
(
raw_file
)
logger
.
info
(
"extract valid tar file.. this may take 10-20 minutes"
)
logger
.
info
(
"extract valid tar file.. this may take 10-20 minutes"
)
untar
(
os
.
path
.
join
(
self
.
root
,
raw_file
),
self
.
target_folder
)
untar
(
os
.
path
.
join
(
self
.
root
,
raw_file
),
self
.
target_folder
)
self
.
_organize_val_data
()
self
.
_organize_val_data
()
def
_prepare_train
(
self
):
def
_prepare_train
(
self
):
assert
self
.
train
assert
self
.
train
raw_filename
,
checksum
=
self
.
raw_file_meta
[
"train"
]
raw_filename
,
checksum
=
self
.
raw_file_meta
[
"train"
]
raw_file
=
os
.
path
.
join
(
self
.
root
,
raw_filename
)
raw_file
=
os
.
path
.
join
(
self
.
root
,
raw_filename
)
logger
.
info
(
"checksum train tar file {} .."
.
format
(
raw_file
))
logger
.
info
(
"checksum train tar file {} .."
.
format
(
raw_file
))
assert
calculate_md5
(
raw_file
)
==
checksum
,
\
assert
(
"checksum mismatch, {} may be damaged"
.
format
(
raw_file
)
calculate_md5
(
raw_file
)
==
checksum
),
"checksum mismatch, {} may be damaged"
.
format
(
raw_file
)
logger
.
info
(
"extract train tar file.. this may take several hours"
)
logger
.
info
(
"extract train tar file.. this may take several hours"
)
untar
(
untar
(
os
.
path
.
join
(
self
.
root
,
raw_file
),
os
.
path
.
join
(
self
.
root
,
raw_file
),
self
.
target_folder
,
self
.
target_folder
,
)
)
paths
=
[
paths
=
[
os
.
path
.
join
(
self
.
target_folder
,
child_dir
)
os
.
path
.
join
(
self
.
target_folder
,
child_dir
)
...
@@ -227,7 +237,8 @@ class ImageNet(ImageFolder):
...
@@ -227,7 +237,8 @@ class ImageNet(ImageFolder):
raw_filename
,
checksum
=
self
.
raw_file_meta
[
"devkit"
]
raw_filename
,
checksum
=
self
.
raw_file_meta
[
"devkit"
]
raw_file
=
os
.
path
.
join
(
self
.
root
,
raw_filename
)
raw_file
=
os
.
path
.
join
(
self
.
root
,
raw_filename
)
logger
.
info
(
"checksum devkit tar file {} .."
.
format
(
raw_file
))
logger
.
info
(
"checksum devkit tar file {} .."
.
format
(
raw_file
))
assert
calculate_md5
(
raw_file
)
==
checksum
,
\
assert
(
"checksum mismatch, {} may be damaged"
.
format
(
raw_file
)
calculate_md5
(
raw_file
)
==
checksum
),
"checksum mismatch, {} may be damaged"
.
format
(
raw_file
)
logger
.
info
(
"extract devkit file.."
)
logger
.
info
(
"extract devkit file.."
)
untargz
(
os
.
path
.
join
(
self
.
root
,
self
.
raw_file_meta
[
"devkit"
][
0
]))
untargz
(
os
.
path
.
join
(
self
.
root
,
self
.
raw_file_meta
[
"devkit"
][
0
]))
python_module/megengine/data/dataset/vision/utils.py
浏览文件 @
62b0c6cd
...
@@ -7,8 +7,8 @@
...
@@ -7,8 +7,8 @@
# software distributed under the License is distributed on an
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
hashlib
import
hashlib
import
tarfile
import
os
import
os
import
tarfile
from
....distributed.util
import
is_distributed
from
....distributed.util
import
is_distributed
from
....logger
import
get_logger
from
....logger
import
get_logger
...
...
python_module/megengine/functional/elemwise.py
浏览文件 @
62b0c6cd
...
@@ -46,16 +46,16 @@ __all__ = [
...
@@ -46,16 +46,16 @@ __all__ = [
def
_elemwise
(
mode
):
# DONT export
def
_elemwise
(
mode
):
# DONT export
"""Decorator helps to wrap megbrain element-wise oprs"""
"""Decorator helps to wrap megbrain element-wise oprs"""
def
elemwise_decorator
(
func
):
def
elemwise_decorator
(
func
):
@
functools
.
wraps
(
func
)
@
functools
.
wraps
(
func
)
@
wrap_io_tensor
@
wrap_io_tensor
def
elemwise_func
(
*
inputs
)
->
Tensor
:
def
elemwise_func
(
*
inputs
)
->
Tensor
:
if
all
(
isinstance
(
i
,
(
int
,
float
))
for
i
in
inputs
):
if
all
(
isinstance
(
i
,
(
int
,
float
))
for
i
in
inputs
):
device
,
comp_graph
=
_use_default_if_none
(
None
,
None
)
device
,
comp_graph
=
_use_default_if_none
(
None
,
None
)
ret
=
mgb
.
opr
.
elemwise
(
*
inputs
,
ret
=
mgb
.
opr
.
elemwise
(
mode
=
mode
,
*
inputs
,
mode
=
mode
,
comp_node
=
device
,
comp_graph
=
comp_graph
comp_node
=
device
,
)
comp_graph
=
comp_graph
)
return
ret
.
inferred_value
[
0
]
return
ret
.
inferred_value
[
0
]
return
mgb
.
opr
.
elemwise
(
*
inputs
,
mode
=
mode
)
return
mgb
.
opr
.
elemwise
(
*
inputs
,
mode
=
mode
)
...
...
python_module/megengine/module/__init__.py
浏览文件 @
62b0c6cd
...
@@ -14,6 +14,6 @@ from .embedding import Embedding
...
@@ -14,6 +14,6 @@ from .embedding import Embedding
from
.identity
import
Identity
from
.identity
import
Identity
from
.linear
import
Linear
from
.linear
import
Linear
from
.module
import
Module
from
.module
import
Module
from
.parampack
import
ParamPack
from
.pooling
import
AvgPool2d
,
MaxPool2d
from
.pooling
import
AvgPool2d
,
MaxPool2d
from
.sequential
import
Sequential
from
.sequential
import
Sequential
from
.parampack
import
ParamPack
python_module/megengine/module/init.py
浏览文件 @
62b0c6cd
...
@@ -12,7 +12,7 @@ from typing import Optional, Tuple, Union
...
@@ -12,7 +12,7 @@ from typing import Optional, Tuple, Union
import
numpy
as
np
import
numpy
as
np
from
..core
import
Tensor
,
Graph
from
..core
import
Graph
,
Tensor
from
..random
import
gaussian
,
uniform
from
..random
import
gaussian
,
uniform
...
...
python_module/megengine/module/module.py
浏览文件 @
62b0c6cd
...
@@ -168,10 +168,9 @@ class Module(metaclass=ABCMeta):
...
@@ -168,10 +168,9 @@ class Module(metaclass=ABCMeta):
"""
"""
yield
from
self
.
_flatten
(
predicate
=
_is_buffer
,
recursive
=
recursive
)
yield
from
self
.
_flatten
(
predicate
=
_is_buffer
,
recursive
=
recursive
)
def
replace_param
(
self
,
def
replace_param
(
params
:
dict
,
self
,
params
:
dict
,
start_pos
:
int
,
seen
:
Optional
[
Set
[
int
]]
=
None
start_pos
:
int
,
):
seen
:
Optional
[
Set
[
int
]]
=
None
):
offset
=
0
offset
=
0
if
seen
is
None
:
if
seen
is
None
:
seen
=
set
([
id
(
self
)])
seen
=
set
([
id
(
self
)])
...
@@ -183,12 +182,13 @@ class Module(metaclass=ABCMeta):
...
@@ -183,12 +182,13 @@ class Module(metaclass=ABCMeta):
seen
.
add
(
hash_id
)
seen
.
add
(
hash_id
)
if
isinstance
(
module_dict
[
key
],
Parameter
):
if
isinstance
(
module_dict
[
key
],
Parameter
):
if
start_pos
+
offset
in
params
:
if
start_pos
+
offset
in
params
:
assert
module_dict
[
key
].
shape
==
params
[
start_pos
+
assert
module_dict
[
key
].
shape
==
params
[
start_pos
+
offset
].
shape
offset
].
shape
module_dict
[
key
]
=
params
[
start_pos
+
offset
]
module_dict
[
key
]
=
params
[
start_pos
+
offset
]
offset
+=
1
offset
+=
1
if
isinstance
(
module_dict
[
key
],
Module
):
if
isinstance
(
module_dict
[
key
],
Module
):
offset
+=
module_dict
[
key
].
replace_param
(
params
,
start_pos
+
offset
,
seen
)
offset
+=
module_dict
[
key
].
replace_param
(
params
,
start_pos
+
offset
,
seen
)
return
offset
return
offset
def
named_buffers
(
def
named_buffers
(
...
...
python_module/megengine/module/parampack.py
浏览文件 @
62b0c6cd
...
@@ -8,11 +8,12 @@
...
@@ -8,11 +8,12 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
collections
import
collections
from
typing
import
Iterable
,
Optional
from
typing
import
Iterable
,
Optional
import
numpy
as
np
import
numpy
as
np
from
.._internal.opr
import
param_pack_split
from
..core
import
Parameter
,
Tensor
from
..core
import
Parameter
,
Tensor
from
.module
import
Module
from
.module
import
Module
from
.._internal.opr
import
param_pack_split
class
ParamPack
(
Module
):
class
ParamPack
(
Module
):
...
@@ -24,11 +25,14 @@ class ParamPack(Module):
...
@@ -24,11 +25,14 @@ class ParamPack(Module):
:param max_nr_params_per_group: upper bound of the number of parameters of each group.
:param max_nr_params_per_group: upper bound of the number of parameters of each group.
"""
"""
def
__init__
(
self
,
model
:
Module
,
def
__init__
(
nr_ignore_first
:
int
=
8
,
self
,
max_size_per_group
:
int
=
10
,
model
:
Module
,
max_nr_params_per_group
:
int
=
100
):
nr_ignore_first
:
int
=
8
,
max_size_per_group
:
int
=
10
,
max_nr_params_per_group
:
int
=
100
,
):
super
().
__init__
()
super
().
__init__
()
self
.
_model
=
model
self
.
_model
=
model
self
.
_nr_ignore_first
=
nr_ignore_first
self
.
_nr_ignore_first
=
nr_ignore_first
...
@@ -52,11 +56,11 @@ class ParamPack(Module):
...
@@ -52,11 +56,11 @@ class ParamPack(Module):
for
param
in
params
:
for
param
in
params
:
if
self
.
_nr_ignore_first
>
ignored
:
if
self
.
_nr_ignore_first
>
ignored
:
ignored
+=
1
ignored
+=
1
self
.
_grouped_params
.
append
([{
'tensor'
:
param
,
'id'
:
param_id
}])
self
.
_grouped_params
.
append
([{
"tensor"
:
param
,
"id"
:
param_id
}])
self
.
_packed_params
.
append
(
param
)
self
.
_packed_params
.
append
(
param
)
else
:
else
:
key
=
(
param
.
dtype
,
param
.
device
,
param
.
requires_grad
)
key
=
(
param
.
dtype
,
param
.
device
,
param
.
requires_grad
)
groups
[
key
].
append
({
'tensor'
:
param
,
'id'
:
param_id
})
groups
[
key
].
append
({
"tensor"
:
param
,
"id"
:
param_id
})
param_id
+=
1
param_id
+=
1
for
(
dtype
,
device
,
requires_grad
)
in
groups
.
keys
():
for
(
dtype
,
device
,
requires_grad
)
in
groups
.
keys
():
dtype_sz
=
np
.
dtype
(
dtype
).
itemsize
dtype_sz
=
np
.
dtype
(
dtype
).
itemsize
...
@@ -75,33 +79,36 @@ class ParamPack(Module):
...
@@ -75,33 +79,36 @@ class ParamPack(Module):
idx
=
0
idx
=
0
while
idx
<
len
(
group
):
while
idx
<
len
(
group
):
param
=
group
[
idx
]
param
=
group
[
idx
]
assert
param
[
'tensor'
].
device
==
device
assert
param
[
"tensor"
].
device
==
device
padding
=
(
align
-
(
offset
&
(
align
-
1
)))
&
(
align
-
1
)
padding
=
(
align
-
(
offset
&
(
align
-
1
)))
&
(
align
-
1
)
offset
+=
padding
offset
+=
padding
aligned_pos
.
append
(
offset
)
aligned_pos
.
append
(
offset
)
params
.
append
(
param
)
params
.
append
(
param
)
offset
+=
int
(
np
.
prod
(
param
[
'tensor'
].
shape
))
offset
+=
int
(
np
.
prod
(
param
[
"tensor"
].
shape
))
idx
+=
1
idx
+=
1
if
(
offset
*
dtype_sz
>=
if
(
self
.
_max_size_per_group
*
1024
*
1024
offset
*
dtype_sz
>=
self
.
_max_size_per_group
*
1024
*
1024
or
idx
>=
self
.
_max_nr_params_per_group
):
or
idx
>=
self
.
_max_nr_params_per_group
):
break
break
group
=
group
[
idx
:]
group
=
group
[
idx
:]
if
idx
==
1
:
if
idx
==
1
:
# ignore param packs with only one item
# ignore param packs with only one item
self
.
_packed_params
.
append
(
params
[
0
][
'tensor'
])
self
.
_packed_params
.
append
(
params
[
0
][
"tensor"
])
self
.
_grouped_params
.
append
(
params
)
self
.
_grouped_params
.
append
(
params
)
continue
continue
packed_value
=
np
.
zeros
((
offset
,
),
dtype
=
dtype
)
packed_value
=
np
.
zeros
((
offset
,),
dtype
=
dtype
)
for
param
,
pos
in
zip
(
params
,
aligned_pos
):
for
param
,
pos
in
zip
(
params
,
aligned_pos
):
val
=
param
[
'tensor'
].
numpy
()
val
=
param
[
"tensor"
].
numpy
()
packed_value
[
pos
:
pos
+
val
.
size
]
=
val
.
flatten
()
packed_value
[
pos
:
pos
+
val
.
size
]
=
val
.
flatten
()
new_param
=
Parameter
(
value
=
packed_value
,
new_param
=
Parameter
(
device
=
device
,
value
=
packed_value
,
dtype
=
dtype
,
device
=
device
,
requires_grad
=
requires_grad
)
dtype
=
dtype
,
requires_grad
=
requires_grad
,
)
self
.
_packed_params
.
append
(
new_param
)
self
.
_packed_params
.
append
(
new_param
)
self
.
_grouped_params
.
append
(
params
)
self
.
_grouped_params
.
append
(
params
)
...
@@ -112,14 +119,15 @@ class ParamPack(Module):
...
@@ -112,14 +119,15 @@ class ParamPack(Module):
grouped_params
=
self
.
_grouped_params
[
i
]
grouped_params
=
self
.
_grouped_params
[
i
]
if
len
(
grouped_params
)
==
1
:
if
len
(
grouped_params
)
==
1
:
continue
continue
split
=
param_pack_split
(
packed_param
.
_symvar
,
split
=
param_pack_split
(
[
i
[
'tensor'
].
shape
for
i
in
grouped_params
])
packed_param
.
_symvar
,
[
i
[
"tensor"
].
shape
for
i
in
grouped_params
]
)
split
=
[
split
=
[
Parameter
(
Tensor
(
i
,
requires_grad
=
packed_param
.
requires_grad
))
Parameter
(
Tensor
(
i
,
requires_grad
=
packed_param
.
requires_grad
))
for
i
in
split
for
i
in
split
]
]
for
j
in
range
(
len
(
split
)):
for
j
in
range
(
len
(
split
)):
replace_param
[
grouped_params
[
j
][
'id'
]]
=
split
[
j
]
replace_param
[
grouped_params
[
j
][
"id"
]]
=
split
[
j
]
self
.
_model
.
replace_param
(
replace_param
,
0
)
self
.
_model
.
replace_param
(
replace_param
,
0
)
return
self
.
_model
.
forward
(
*
args
,
**
kwargs
)
return
self
.
_model
.
forward
(
*
args
,
**
kwargs
)
python_module/test/integration/test_parampack.py
浏览文件 @
62b0c6cd
...
@@ -75,10 +75,9 @@ class XORNet(Module):
...
@@ -75,10 +75,9 @@ class XORNet(Module):
@
pytest
.
mark
.
slow
@
pytest
.
mark
.
slow
def
test_static_graph_parampack
():
def
test_static_graph_parampack
():
net
=
XORNet
()
net
=
XORNet
()
net
=
ParamPack
(
net
,
net
=
ParamPack
(
nr_ignore_first
=
0
,
net
,
nr_ignore_first
=
0
,
max_size_per_group
=
10
,
max_nr_params_per_group
=
100
max_size_per_group
=
10
,
)
max_nr_params_per_group
=
100
)
opt
=
SGD
(
opt
=
SGD
(
net
.
parameters
(
requires_grad
=
True
),
lr
=
0.01
,
momentum
=
0.9
,
weight_decay
=
5e-4
net
.
parameters
(
requires_grad
=
True
),
lr
=
0.01
,
momentum
=
0.9
,
weight_decay
=
5e-4
)
)
...
@@ -110,12 +109,11 @@ def test_static_graph_parampack():
...
@@ -110,12 +109,11 @@ def test_static_graph_parampack():
pred
=
infer
(
data
).
numpy
()
pred
=
infer
(
data
).
numpy
()
assert
calculate_precision
(
data
,
pred
)
>
0.95
,
"Test precision must be high enough"
assert
calculate_precision
(
data
,
pred
)
>
0.95
,
"Test precision must be high enough"
@
pytest
.
mark
.
slow
@
pytest
.
mark
.
slow
def
test_nopack_parampack
():
def
test_nopack_parampack
():
net
=
XORNet
()
net
=
XORNet
()
net
=
ParamPack
(
net
,
net
=
ParamPack
(
net
,
max_size_per_group
=
0
,
max_nr_params_per_group
=
0
)
max_size_per_group
=
0
,
max_nr_params_per_group
=
0
)
opt
=
SGD
(
opt
=
SGD
(
net
.
parameters
(
requires_grad
=
True
),
lr
=
0.01
,
momentum
=
0.9
,
weight_decay
=
5e-4
net
.
parameters
(
requires_grad
=
True
),
lr
=
0.01
,
momentum
=
0.9
,
weight_decay
=
5e-4
)
)
...
@@ -146,13 +144,13 @@ def test_nopack_parampack():
...
@@ -146,13 +144,13 @@ def test_nopack_parampack():
pred
=
infer
(
data
).
numpy
()
pred
=
infer
(
data
).
numpy
()
assert
calculate_precision
(
data
,
pred
)
>
0.95
,
"Test precision must be high enough"
assert
calculate_precision
(
data
,
pred
)
>
0.95
,
"Test precision must be high enough"
@
pytest
.
mark
.
slow
@
pytest
.
mark
.
slow
def
test_dynamic_graph_parampack
():
def
test_dynamic_graph_parampack
():
net
=
XORNet
()
net
=
XORNet
()
net
=
ParamPack
(
net
,
net
=
ParamPack
(
nr_ignore_first
=
0
,
net
,
nr_ignore_first
=
0
,
max_size_per_group
=
10
,
max_nr_params_per_group
=
100
max_size_per_group
=
10
,
)
max_nr_params_per_group
=
100
)
opt
=
SGD
(
opt
=
SGD
(
net
.
parameters
(
requires_grad
=
True
),
lr
=
0.01
,
momentum
=
0.9
,
weight_decay
=
5e-4
net
.
parameters
(
requires_grad
=
True
),
lr
=
0.01
,
momentum
=
0.9
,
weight_decay
=
5e-4
)
)
...
@@ -184,6 +182,7 @@ def test_dynamic_graph_parampack():
...
@@ -184,6 +182,7 @@ def test_dynamic_graph_parampack():
pred
=
infer
(
data
).
numpy
()
pred
=
infer
(
data
).
numpy
()
assert
calculate_precision
(
data
,
pred
)
>
0.95
,
"Test precision must be high enough"
assert
calculate_precision
(
data
,
pred
)
>
0.95
,
"Test precision must be high enough"
@
pytest
.
mark
.
slow
@
pytest
.
mark
.
slow
def
test_correctness_parampack
():
def
test_correctness_parampack
():
net1
=
XORNet
()
net1
=
XORNet
()
...
@@ -192,10 +191,9 @@ def test_correctness_parampack():
...
@@ -192,10 +191,9 @@ def test_correctness_parampack():
params2
=
net2
.
parameters
()
params2
=
net2
.
parameters
()
for
param1
,
param2
in
zip
(
params1
,
params2
):
for
param1
,
param2
in
zip
(
params1
,
params2
):
param1
.
set_value
(
param2
.
numpy
())
param1
.
set_value
(
param2
.
numpy
())
net1
=
ParamPack
(
net1
,
net1
=
ParamPack
(
nr_ignore_first
=
0
,
net1
,
nr_ignore_first
=
0
,
max_size_per_group
=
10
,
max_nr_params_per_group
=
100
max_size_per_group
=
10
,
)
max_nr_params_per_group
=
100
)
opt1
=
SGD
(
opt1
=
SGD
(
net1
.
parameters
(
requires_grad
=
True
),
lr
=
0.01
,
momentum
=
0.9
,
weight_decay
=
5e-4
net1
.
parameters
(
requires_grad
=
True
),
lr
=
0.01
,
momentum
=
0.9
,
weight_decay
=
5e-4
)
)
...
...
python_module/test/unit/functional/test_elemwise.py
浏览文件 @
62b0c6cd
...
@@ -10,31 +10,37 @@ import numpy as np
...
@@ -10,31 +10,37 @@ import numpy as np
import
megengine.functional
as
F
import
megengine.functional
as
F
from
megengine
import
tensor
from
megengine
import
tensor
from
megengine.test
import
assertTensorClose
from
megengine.test
import
assertTensorClose
def
test_abs
():
def
test_abs
():
assertTensorClose
(
assertTensorClose
(
F
.
abs
(
tensor
([
-
3.
,
-
4.
,
-
5.
])).
numpy
(),
F
.
abs
(
tensor
([
-
3.0
,
-
4.0
,
-
5.0
])).
numpy
(),
np
.
abs
(
np
.
array
([
-
3.
,
-
4.
,
-
5.
],
dtype
=
np
.
float32
)))
np
.
abs
(
np
.
array
([
-
3.0
,
-
4.0
,
-
5.0
],
dtype
=
np
.
float32
)),
)
assertTensorClose
(
F
.
abs
(
-
3.
),
np
.
abs
(
np
.
float32
(
-
3.
)))
assertTensorClose
(
F
.
abs
(
-
3.
0
),
np
.
abs
(
np
.
float32
(
-
3.0
)))
def
test_multiply
():
def
test_multiply
():
assertTensorClose
(
F
.
multiply
(
-
3.
,
-
4.
),
assertTensorClose
(
np
.
multiply
(
np
.
float32
(
-
3.
),
np
.
float32
(
-
4.
)))
F
.
multiply
(
-
3.0
,
-
4.0
),
np
.
multiply
(
np
.
float32
(
-
3.0
),
np
.
float32
(
-
4.0
))
)
assertTensorClose
(
assertTensorClose
(
F
.
multiply
(
tensor
([
3.
,
4.
]),
4.
).
numpy
(),
F
.
multiply
(
tensor
([
3.0
,
4.0
]),
4.0
).
numpy
(),
np
.
multiply
(
np
.
array
([
3.
,
4.
],
dtype
=
np
.
float32
),
4.
))
np
.
multiply
(
np
.
array
([
3.0
,
4.0
],
dtype
=
np
.
float32
),
4.0
),
)
assertTensorClose
(
assertTensorClose
(
F
.
multiply
(
4.
,
tensor
([
3.
,
4.
])).
numpy
(),
F
.
multiply
(
4.0
,
tensor
([
3.0
,
4.0
])).
numpy
(),
np
.
multiply
(
4.
,
np
.
array
([
3.
,
4.
],
dtype
=
np
.
float32
)))
np
.
multiply
(
4.0
,
np
.
array
([
3.0
,
4.0
],
dtype
=
np
.
float32
)),
)
assertTensorClose
(
assertTensorClose
(
F
.
multiply
(
tensor
([
3.
,
4.
]),
tensor
([
3.
,
4.
])).
numpy
(),
F
.
multiply
(
tensor
([
3.0
,
4.0
]),
tensor
([
3.0
,
4.0
])).
numpy
(),
np
.
multiply
(
np
.
array
([
3.
,
4.
],
dtype
=
np
.
float32
),
np
.
multiply
(
np
.
array
([
3.
,
4.
],
dtype
=
np
.
float32
)))
np
.
array
([
3.0
,
4.0
],
dtype
=
np
.
float32
),
np
.
array
([
3.0
,
4.0
],
dtype
=
np
.
float32
),
),
)
python_module/test/unit/jit/test_jit.py
浏览文件 @
62b0c6cd
...
@@ -15,10 +15,10 @@ import pytest
...
@@ -15,10 +15,10 @@ import pytest
import
megengine
as
mge
import
megengine
as
mge
import
megengine._internal
as
mgb
import
megengine._internal
as
mgb
import
megengine.module
as
M
from
megengine
import
jit
,
tensor
from
megengine
import
jit
,
tensor
from
megengine.core.tensor
import
Tensor
from
megengine.core.tensor
import
Tensor
from
megengine.test
import
assertTensorClose
from
megengine.test
import
assertTensorClose
import
megengine.module
as
M
@
contextlib
.
contextmanager
@
contextlib
.
contextmanager
...
@@ -158,13 +158,14 @@ def test_shape_infer():
...
@@ -158,13 +158,14 @@ def test_shape_infer():
def
test_dump_bn_fused
():
def
test_dump_bn_fused
():
class
ConvBNReLU
(
M
.
Sequential
):
class
ConvBNReLU
(
M
.
Sequential
):
def
__init__
(
self
):
def
__init__
(
self
):
super
(
ConvBNReLU
,
self
).
__init__
(
super
(
ConvBNReLU
,
self
).
__init__
(
M
.
Conv2d
(
3
,
4
,
3
,
1
,
1
,
groups
=
1
,
bias
=
False
),
M
.
Conv2d
(
3
,
4
,
3
,
1
,
1
,
groups
=
1
,
bias
=
False
),
M
.
BatchNorm2d
(
4
),
M
.
BatchNorm2d
(
4
),
M
.
ReLU
())
M
.
ReLU
(),
)
net
=
ConvBNReLU
()
net
=
ConvBNReLU
()
net
.
eval
()
net
.
eval
()
...
@@ -178,8 +179,9 @@ def test_dump_bn_fused():
...
@@ -178,8 +179,9 @@ def test_dump_bn_fused():
fun
.
dump
(
out
,
optimize_for_inference
=
True
)
fun
.
dump
(
out
,
optimize_for_inference
=
True
)
cg
,
_
,
outputs
=
mgb
.
load_comp_graph_from_file
(
out
)
cg
,
_
,
outputs
=
mgb
.
load_comp_graph_from_file
(
out
)
out
,
=
outputs
(
out
,)
=
outputs
inputs
=
mgb
.
cgtools
.
get_inputs
(
out
)
inputs
=
mgb
.
cgtools
.
get_inputs
(
out
)
assert
len
(
inputs
)
==
2
and
(
assert
len
(
inputs
)
==
2
and
(
mgb
.
cgtools
.
get_type
(
inputs
[
0
])
==
'MultipleDeviceTensorHolder'
and
mgb
.
cgtools
.
get_type
(
inputs
[
0
])
==
"MultipleDeviceTensorHolder"
mgb
.
cgtools
.
get_type
(
inputs
[
1
])
==
'ConvolutionForward'
)
and
mgb
.
cgtools
.
get_type
(
inputs
[
1
])
==
"ConvolutionForward"
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录