Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
98460f58
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
98460f58
编写于
9月 01, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
chore(mge/misc): rename APIs that should not be public
GitOrigin-RevId: b72517bfe9c4b4e3efde5badc95fdf89b0ea48cd
上级
217999b1
变更
22
隐藏空白更改
内联
并排
Showing
22 changed file
with
122 addition
and
123 deletion
+122
-123
imperative/python/megengine/autodiff/grad_manager.py
imperative/python/megengine/autodiff/grad_manager.py
+9
-9
imperative/python/megengine/data/dataloader.py
imperative/python/megengine/data/dataloader.py
+5
-5
imperative/python/megengine/data/dataset/vision/coco.py
imperative/python/megengine/data/dataset/vision/coco.py
+2
-2
imperative/python/megengine/data/dataset/vision/mnist.py
imperative/python/megengine/data/dataset/vision/mnist.py
+6
-6
imperative/python/megengine/data/transform/vision/_functional.py
...ive/python/megengine/data/transform/vision/_functional.py
+0
-0
imperative/python/megengine/data/transform/vision/transform.py
...ative/python/megengine/data/transform/vision/transform.py
+1
-1
imperative/python/megengine/distributed/__init__.py
imperative/python/megengine/distributed/__init__.py
+1
-2
imperative/python/megengine/distributed/group.py
imperative/python/megengine/distributed/group.py
+5
-5
imperative/python/megengine/distributed/helper.py
imperative/python/megengine/distributed/helper.py
+6
-6
imperative/python/megengine/distributed/launcher.py
imperative/python/megengine/distributed/launcher.py
+1
-1
imperative/python/megengine/distributed/server.py
imperative/python/megengine/distributed/server.py
+9
-9
imperative/python/megengine/functional/nn.py
imperative/python/megengine/functional/nn.py
+39
-39
imperative/python/megengine/functional/tensor_cache.py
imperative/python/megengine/functional/tensor_cache.py
+2
-2
imperative/python/megengine/hub/fetcher.py
imperative/python/megengine/hub/fetcher.py
+3
-3
imperative/python/megengine/jit/tracing.py
imperative/python/megengine/jit/tracing.py
+3
-3
imperative/python/megengine/logger.py
imperative/python/megengine/logger.py
+14
-14
imperative/python/megengine/module/rnn.py
imperative/python/megengine/module/rnn.py
+8
-8
imperative/python/megengine/quantization/observer.py
imperative/python/megengine/quantization/observer.py
+1
-1
imperative/python/megengine/serialization.py
imperative/python/megengine/serialization.py
+2
-2
imperative/python/megengine/utils/module_stats.py
imperative/python/megengine/utils/module_stats.py
+2
-2
imperative/python/test/unit/core/test_util.py
imperative/python/test/unit/core/test_util.py
+2
-2
imperative/python/test/unit/distributed/test_distributed.py
imperative/python/test/unit/distributed/test_distributed.py
+1
-1
未找到文件。
imperative/python/megengine/autodiff/grad_manager.py
浏览文件 @
98460f58
...
...
@@ -19,11 +19,11 @@ logger = get_logger(__name__)
backwarding_grad_manager
=
None
def
get_backwarding_grad_manager
():
def
_
get_backwarding_grad_manager
():
return
backwarding_grad_manager
class
AttachSpec
:
class
_
AttachSpec
:
__slots__
=
"tensor"
,
"callbacks"
...
...
@@ -118,7 +118,7 @@ class GradManager:
"""
def
__init__
(
self
):
self
.
_attach_specs
=
{}
# id(Tensor) -> AttachSpec
self
.
_attach_specs
=
{}
# id(Tensor) ->
_
AttachSpec
self
.
_recording
=
False
self
.
_grad
=
None
self
.
_after_backward_callback
=
[]
...
...
@@ -200,7 +200,7 @@ class GradManager:
if
self
is
not
None
:
del
self
.
_attach_specs
[
key
]
spec
=
AttachSpec
()
spec
=
_
AttachSpec
()
spec
.
tensor
=
weakref
.
ref
(
tensor
,
deleter
)
spec
.
callbacks
=
[]
return
spec
...
...
@@ -354,22 +354,22 @@ class GradManager:
def
__or__
(
self
,
other
):
if
isinstance
(
other
,
GradManager
):
return
GradManagerGroup
([
self
,
other
])
return
_
GradManagerGroup
([
self
,
other
])
return
NotImplemented
__ror__
=
__or__
class
GradManagerGroup
:
class
_
GradManagerGroup
:
def
__init__
(
self
,
gms
)
->
None
:
self
.
_gms
=
list
(
gms
)
def
merge_with
(
self
,
other
):
if
isinstance
(
other
,
GradManager
):
other
=
GradManagerGroup
([
other
])
elif
not
isinstance
(
other
,
GradManagerGroup
):
other
=
_
GradManagerGroup
([
other
])
elif
not
isinstance
(
other
,
_
GradManagerGroup
):
return
NotImplemented
return
GradManagerGroup
([
*
self
.
_gms
,
*
other
.
_gms
])
return
_
GradManagerGroup
([
*
self
.
_gms
,
*
other
.
_gms
])
__or__
=
merge_with
__ror__
=
merge_with
...
...
imperative/python/megengine/data/dataloader.py
浏览文件 @
98460f58
...
...
@@ -35,7 +35,7 @@ logger = get_logger(__name__)
GLOBAL_TIMEOUT
=
5
def
raise_timeout_error
():
def
_
raise_timeout_error
():
raise
RuntimeError
(
"dataloader timeout"
)
...
...
@@ -95,7 +95,7 @@ class DataLoader:
collator
:
Collator
=
None
,
num_workers
:
int
=
0
,
timeout
:
int
=
0
,
timeout_event
:
Callable
=
raise_timeout_error
,
timeout_event
:
Callable
=
_
raise_timeout_error
,
divide
:
bool
=
False
,
preload
:
bool
=
False
,
):
...
...
@@ -188,7 +188,7 @@ class DataLoader:
return
len
(
self
.
sampler
)
class
PreLoader
:
class
_
PreLoader
:
def
__init__
(
self
,
preload
):
if
preload
:
self
.
default_device
=
get_default_device
()
...
...
@@ -237,7 +237,7 @@ class PreLoader:
return
out
class
_BaseMapDataLoaderIter
(
PreLoader
):
class
_BaseMapDataLoaderIter
(
_
PreLoader
):
def
__init__
(
self
,
loader
,
preload
):
super
().
__init__
(
preload
)
self
.
dataset
=
loader
.
dataset
...
...
@@ -454,7 +454,7 @@ class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
self
.
_shutdown
()
class
_BaseStreamDataLoaderIter
(
PreLoader
):
class
_BaseStreamDataLoaderIter
(
_
PreLoader
):
def
__init__
(
self
,
loader
,
preload
):
super
().
__init__
(
preload
)
self
.
dataset
=
loader
.
dataset
...
...
imperative/python/megengine/data/dataset/vision/coco.py
浏览文件 @
98460f58
...
...
@@ -21,7 +21,7 @@ def _count_visible_keypoints(anno):
return
sum
(
sum
(
1
for
v
in
ann
[
"keypoints"
][
2
::
3
]
if
v
>
0
)
for
ann
in
anno
)
def
has_valid_annotation
(
anno
,
order
):
def
_
has_valid_annotation
(
anno
,
order
):
# if it"s empty, there is no annotation
if
len
(
anno
)
==
0
:
return
False
...
...
@@ -101,7 +101,7 @@ class COCO(VisionDataset):
anno
=
[
obj
for
obj
in
anno
if
obj
[
"bbox"
][
2
]
>
0
and
obj
[
"bbox"
][
3
]
>
0
]
if
has_valid_annotation
(
anno
,
order
):
if
_
has_valid_annotation
(
anno
,
order
):
ids
.
append
(
img_id
)
self
.
img_to_anns
[
img_id
]
=
anno
else
:
...
...
imperative/python/megengine/data/dataset/vision/mnist.py
浏览文件 @
98460f58
...
...
@@ -140,17 +140,17 @@ class MNIST(VisionDataset):
# load raw files and transform them into meta data and datasets Tuple(np.array)
logger
.
info
(
"process the raw files of %s set..."
,
"train"
if
train
else
"test"
)
if
train
:
meta_data_images
,
images
=
parse_idx3
(
meta_data_images
,
images
=
_
parse_idx3
(
os
.
path
.
join
(
self
.
root
,
self
.
raw_file_name
[
0
])
)
meta_data_labels
,
labels
=
parse_idx1
(
meta_data_labels
,
labels
=
_
parse_idx1
(
os
.
path
.
join
(
self
.
root
,
self
.
raw_file_name
[
1
])
)
else
:
meta_data_images
,
images
=
parse_idx3
(
meta_data_images
,
images
=
_
parse_idx3
(
os
.
path
.
join
(
self
.
root
,
self
.
raw_file_name
[
2
])
)
meta_data_labels
,
labels
=
parse_idx1
(
meta_data_labels
,
labels
=
_
parse_idx1
(
os
.
path
.
join
(
self
.
root
,
self
.
raw_file_name
[
3
])
)
...
...
@@ -161,7 +161,7 @@ class MNIST(VisionDataset):
self
.
arrays
=
(
images
,
labels
.
astype
(
np
.
int32
))
def
parse_idx3
(
idx3_file
):
def
_
parse_idx3
(
idx3_file
):
# parse idx3 file to meta data and data in numpy array (images)
logger
.
debug
(
"parse idx3 file %s ..."
,
idx3_file
)
assert
idx3_file
.
endswith
(
".gz"
)
...
...
@@ -187,7 +187,7 @@ def parse_idx3(idx3_file):
return
meta_data
,
images
def
parse_idx1
(
idx1_file
):
def
_
parse_idx1
(
idx1_file
):
# parse idx1 file to meta data and data in numpy array (labels)
logger
.
debug
(
"parse idx1 file %s ..."
,
idx1_file
)
assert
idx1_file
.
endswith
(
".gz"
)
...
...
imperative/python/megengine/data/transform/vision/functional.py
→
imperative/python/megengine/data/transform/vision/
_
functional.py
浏览文件 @
98460f58
文件已移动
imperative/python/megengine/data/transform/vision/transform.py
浏览文件 @
98460f58
...
...
@@ -7,7 +7,7 @@ import cv2
import
numpy
as
np
from
megengine.data.transform
import
Transform
from
megengine.data.transform.vision
import
functional
as
F
from
megengine.data.transform.vision
import
_
functional
as
F
__all__
=
[
"VisionTransform"
,
...
...
imperative/python/megengine/distributed/__init__.py
浏览文件 @
98460f58
...
...
@@ -2,7 +2,6 @@
from
mprop
import
mproperty
from
..core._imperative_rt.core2
import
group_end
,
group_start
from
.
import
group
from
.group
import
(
WORLD
,
Group
,
...
...
@@ -20,7 +19,7 @@ from .group import (
)
from
.helper
import
bcast_list_
,
make_allreduce_cb
,
synchronized
from
.launcher
import
launcher
from
.server
import
Client
,
Server
from
.server
import
Server
@
mproperty
...
...
imperative/python/megengine/distributed/group.py
浏览文件 @
98460f58
...
...
@@ -7,10 +7,10 @@ from mprop import mproperty
from
..device
import
_sh
,
set_default_device
,
what_is_xpu
from
..random
import
seed
from
.server
import
Client
,
Server
from
.server
import
Server
,
_Client
class
StaticData
:
class
_
StaticData
:
server
=
None
client
=
None
master_ip
=
None
...
...
@@ -139,13 +139,13 @@ def init_process_group(
global
_sd
assert
_sd
is
None
,
"init_process_group should be called only once"
_sd
=
StaticData
()
_sd
=
_
StaticData
()
assert
world_size
>
1
assert
rank
>=
0
and
rank
<
world_size
assert
port
>
0
_sd
.
client
=
Client
(
master_ip
,
port
)
_sd
.
client
=
_
Client
(
master_ip
,
port
)
_sd
.
master_ip
=
master_ip
_sd
.
py_server_port
=
port
_sd
.
mm_server_port
=
_sd
.
client
.
get_mm_server_port
()
...
...
@@ -225,7 +225,7 @@ def get_mm_server_addr() -> Tuple[str, int]:
return
_sd
.
master_ip
,
_sd
.
mm_server_port
def
get_client
()
->
Client
:
def
get_client
()
->
_
Client
:
r
"""Get client of python XML RPC server."""
assert
_sd
is
not
None
,
"please call init_process_group first"
return
_sd
.
client
...
...
imperative/python/megengine/distributed/helper.py
浏览文件 @
98460f58
...
...
@@ -7,7 +7,7 @@ from weakref import WeakSet
import
numpy
as
np
from
megengine.autodiff.grad_manager
import
GradManager
,
get_backwarding_grad_manager
from
megengine.autodiff.grad_manager
import
GradManager
,
_
get_backwarding_grad_manager
from
..core._imperative_rt.core2
import
apply
from
..core.ops.builtin
import
ParamPackConcat
,
ParamPackSplit
...
...
@@ -78,7 +78,7 @@ def param_pack_concat(inps: list, offsets: Tensor, offsets_val: list):
return
apply
(
op
,
*
inps
,
offsets
)[
0
]
def
get_offsets
(
shapes
):
def
_
get_offsets
(
shapes
):
offsets
=
[]
offset
=
0
for
shape
in
shapes
:
...
...
@@ -108,7 +108,7 @@ def _check_enable_p2p():
def
pack_allreduce_split
(
pack_list
,
shapes
,
group
,
reduce_method
):
offsets_val
=
get_offsets
(
shapes
)
offsets_val
=
_
get_offsets
(
shapes
)
offsets
=
Tensor
(
offsets_val
)
packed_grads
=
param_pack_concat
(
pack_list
,
offsets
,
offsets_val
)
...
...
@@ -119,7 +119,7 @@ def pack_allreduce_split(pack_list, shapes, group, reduce_method):
return
grads
class
TensorFuture
(
Future
):
class
_
TensorFuture
(
Future
):
def
device
(
self
):
raise
"Sorry, this tensor is not ready"
...
...
@@ -234,13 +234,13 @@ class AllreduceCallback:
self
.
_packing_size
[
dtype
]
=
0
def
__call__
(
self
,
param
,
grad
):
gm
=
get_backwarding_grad_manager
()
gm
=
_
get_backwarding_grad_manager
()
assert
isinstance
(
gm
,
GradManager
)
if
gm
not
in
self
.
_marked_gm
:
gm
.
_register_after_backward_callback
(
self
.
_flush
)
self
.
_marked_gm
.
add
(
gm
)
self
.
_params
.
append
(
param
)
self
.
_futures_dict
[
param
]
=
TensorFuture
(
ack
=
False
)
self
.
_futures_dict
[
param
]
=
_
TensorFuture
(
ack
=
False
)
self
.
_gradients_dict
[
param
]
=
grad
self
.
_grad_origin_device
[
param
]
=
str
(
grad
.
device
)
...
...
imperative/python/megengine/distributed/launcher.py
浏览文件 @
98460f58
...
...
@@ -10,7 +10,7 @@ from ..device import get_device_count
from
..logger
import
get_logger
from
.group
import
_set_machine_ranks
,
group_barrier
,
init_process_group
from
.helper
import
_check_device_initialized
,
_check_interpreter_status
from
.server
import
Client
,
Server
from
.server
import
Server
WARN_SUBPROCESS_EXIT_WITHOUT_RETURN
=
(
"subprocess exited with code 0 but did not return a value"
...
...
imperative/python/megengine/distributed/server.py
浏览文件 @
98460f58
...
...
@@ -12,7 +12,7 @@ from ..core._imperative_rt.utils import create_mm_server
from
..utils.future
import
Future
class
Methods
:
class
_
Methods
:
r
"""Distributed Server Method.
Used for exchange information between distributed nodes.
...
...
@@ -149,7 +149,7 @@ class Methods:
return
ret
class
ThreadXMLRPCServer
(
ThreadingMixIn
,
SimpleXMLRPCServer
):
class
_
ThreadXMLRPCServer
(
ThreadingMixIn
,
SimpleXMLRPCServer
):
pass
...
...
@@ -163,10 +163,10 @@ def _start_server(py_server_port, queue):
"""
try
:
mm_server_port
=
create_mm_server
(
"0.0.0.0"
,
0
)
server
=
ThreadXMLRPCServer
(
server
=
_
ThreadXMLRPCServer
(
(
"0.0.0.0"
,
py_server_port
),
logRequests
=
False
,
allow_none
=
True
)
server
.
register_instance
(
Methods
(
mm_server_port
))
server
.
register_instance
(
_
Methods
(
mm_server_port
))
_
,
py_server_port
=
server
.
server_address
queue
.
put
((
py_server_port
,
mm_server_port
))
server
.
serve_forever
()
...
...
@@ -196,7 +196,7 @@ class Server:
self
.
proc
.
terminate
()
class
Client
:
class
_
Client
:
r
"""Distributed Client for distributed training.
Args:
...
...
@@ -298,10 +298,10 @@ class Client:
return
self
.
proxy
.
bcast_val
(
val
,
key
,
size
)
def
main
(
port
=
0
,
verbose
=
True
):
def
_
main
(
port
=
0
,
verbose
=
True
):
mm_server_port
=
create_mm_server
(
"0.0.0.0"
,
0
)
server
=
ThreadXMLRPCServer
((
"0.0.0.0"
,
port
),
logRequests
=
verbose
)
server
.
register_instance
(
Methods
(
mm_server_port
))
server
=
_
ThreadXMLRPCServer
((
"0.0.0.0"
,
port
),
logRequests
=
verbose
)
server
.
register_instance
(
_
Methods
(
mm_server_port
))
_
,
port
=
server
.
server_address
print
(
"serving on port"
,
port
)
server
.
serve_forever
()
...
...
@@ -314,4 +314,4 @@ if __name__ == "__main__":
ap
.
add_argument
(
"-p"
,
"--port"
,
type
=
int
,
default
=
0
)
ap
.
add_argument
(
"-v"
,
"--verbose"
,
type
=
bool
,
default
=
True
)
args
=
ap
.
parse_args
()
main
(
port
=
args
.
port
,
verbose
=
args
.
verbose
)
_
main
(
port
=
args
.
port
,
verbose
=
args
.
verbose
)
imperative/python/megengine/functional/nn.py
浏览文件 @
98460f58
...
...
@@ -91,7 +91,7 @@ __all__ = [
]
def
expand_hw
(
x
):
def
_
expand_hw
(
x
):
# judge int is 5 times faster than judge Sequence
if
isinstance
(
x
,
int
):
return
x
,
x
...
...
@@ -100,7 +100,7 @@ def expand_hw(x):
return
int
(
x
),
int
(
x
)
def
expand_dhw
(
x
):
def
_
expand_dhw
(
x
):
if
isinstance
(
x
,
int
):
return
x
,
x
,
x
if
isinstance
(
x
,
Sequence
):
...
...
@@ -242,9 +242,9 @@ def conv2d(
or
conv_mode
.
name
==
"CROSS_CORRELATION"
)
stride_h
,
stride_w
=
expand_hw
(
stride
)
pad_h
,
pad_w
=
expand_hw
(
padding
)
dilate_h
,
dilate_w
=
expand_hw
(
dilation
)
stride_h
,
stride_w
=
_
expand_hw
(
stride
)
pad_h
,
pad_w
=
_
expand_hw
(
padding
)
dilate_h
,
dilate_w
=
_
expand_hw
(
dilation
)
sparse_type
=
"dense"
if
groups
==
1
else
"group"
compute_mode
=
_config
.
_get_actual_op_param
(
compute_mode
,
_config
.
__compute_mode
)
...
...
@@ -304,9 +304,9 @@ def conv3d(
D
,
H
,
W
=
0
,
1
,
2
pad
=
expand_dhw
(
padding
)
stride
=
expand_dhw
(
stride
)
dilate
=
expand_dhw
(
dilation
)
pad
=
_
expand_dhw
(
padding
)
stride
=
_
expand_dhw
(
stride
)
dilate
=
_
expand_dhw
(
dilation
)
sparse_type
=
"dense"
if
groups
==
1
else
"group"
op
=
builtin
.
Convolution3D
(
...
...
@@ -374,10 +374,10 @@ def conv_transpose2d(
or
conv_mode
.
name
==
"CROSS_CORRELATION"
)
stride_h
,
stride_w
=
expand_hw
(
stride
)
pad_h
,
pad_w
=
expand_hw
(
padding
)
output_pad_h
,
output_pad_w
=
expand_hw
(
output_padding
)
dilate_h
,
dilate_w
=
expand_hw
(
dilation
)
stride_h
,
stride_w
=
_
expand_hw
(
stride
)
pad_h
,
pad_w
=
_
expand_hw
(
padding
)
output_pad_h
,
output_pad_w
=
_
expand_hw
(
output_padding
)
dilate_h
,
dilate_w
=
_
expand_hw
(
dilation
)
compute_mode
=
_config
.
_get_actual_op_param
(
compute_mode
,
_config
.
__compute_mode
)
sparse_type
=
"dense"
if
groups
==
1
else
"group"
...
...
@@ -475,9 +475,9 @@ def deformable_conv2d(
offset
=
offset
.
astype
(
"float32"
)
mask
=
mask
.
astype
(
"float32"
)
stride_h
,
stride_w
=
expand_hw
(
stride
)
pad_h
,
pad_w
=
expand_hw
(
padding
)
dilate_h
,
dilate_w
=
expand_hw
(
dilation
)
stride_h
,
stride_w
=
_
expand_hw
(
stride
)
pad_h
,
pad_w
=
_
expand_hw
(
padding
)
dilate_h
,
dilate_w
=
_
expand_hw
(
dilation
)
compute_mode
=
_config
.
_get_actual_op_param
(
compute_mode
,
_config
.
__compute_mode
)
sparse_type
=
"dense"
if
groups
==
1
else
"group"
...
...
@@ -529,9 +529,9 @@ def local_conv2d(
or
conv_mode
.
name
==
"CROSS_CORRELATION"
)
stride_h
,
stride_w
=
expand_hw
(
stride
)
pad_h
,
pad_w
=
expand_hw
(
padding
)
dilate_h
,
dilate_w
=
expand_hw
(
dilation
)
stride_h
,
stride_w
=
_
expand_hw
(
stride
)
pad_h
,
pad_w
=
_
expand_hw
(
padding
)
dilate_h
,
dilate_w
=
_
expand_hw
(
dilation
)
# local conv only support "dense" mode, but weight could contain group dimension.
op
=
builtin
.
GroupLocal
(
...
...
@@ -585,10 +585,10 @@ def conv_transpose3d(
output tensor.
"""
D
,
H
,
W
=
0
,
1
,
2
pad
=
expand_dhw
(
padding
)
stride
=
expand_dhw
(
stride
)
dilate
=
expand_dhw
(
dilation
)
output_padding
=
expand_dhw
(
output_padding
)
pad
=
_
expand_dhw
(
padding
)
stride
=
_
expand_dhw
(
stride
)
dilate
=
_
expand_dhw
(
dilation
)
output_padding
=
_
expand_dhw
(
output_padding
)
sparse_type
=
"dense"
if
groups
==
1
else
"group"
op
=
builtin
.
Convolution3DBackwardData
(
...
...
@@ -667,9 +667,9 @@ def max_pool2d(
"""
if
stride
is
None
:
stride
=
kernel_size
window_h
,
window_w
=
expand_hw
(
kernel_size
)
stride_h
,
stride_w
=
expand_hw
(
stride
)
padding_h
,
padding_w
=
expand_hw
(
padding
)
window_h
,
window_w
=
_
expand_hw
(
kernel_size
)
stride_h
,
stride_w
=
_
expand_hw
(
stride
)
padding_h
,
padding_w
=
_
expand_hw
(
padding
)
op
=
builtin
.
Pooling
(
window_h
=
window_h
,
...
...
@@ -717,9 +717,9 @@ def avg_pool2d(
"""
if
stride
is
None
:
stride
=
kernel_size
window_h
,
window_w
=
expand_hw
(
kernel_size
)
stride_h
,
stride_w
=
expand_hw
(
stride
)
padding_h
,
padding_w
=
expand_hw
(
padding
)
window_h
,
window_w
=
_
expand_hw
(
kernel_size
)
stride_h
,
stride_w
=
_
expand_hw
(
stride
)
padding_h
,
padding_w
=
_
expand_hw
(
padding
)
op
=
builtin
.
Pooling
(
window_h
=
window_h
,
...
...
@@ -1708,10 +1708,10 @@ def sliding_window(
stride: stride of the window. Default: 1
dilation: dilation of the window. Default: 1
"""
padding_h
,
padding_w
=
expand_hw
(
padding
)
stride_h
,
stride_w
=
expand_hw
(
stride
)
dilation_h
,
dilation_w
=
expand_hw
(
dilation
)
window_h
,
window_w
=
expand_hw
(
kernel_size
)
padding_h
,
padding_w
=
_
expand_hw
(
padding
)
stride_h
,
stride_w
=
_
expand_hw
(
stride
)
dilation_h
,
dilation_w
=
_
expand_hw
(
dilation
)
window_h
,
window_w
=
_
expand_hw
(
kernel_size
)
op
=
builtin
.
Images2Neibs
(
pad_h
=
padding_h
,
...
...
@@ -1747,11 +1747,11 @@ def sliding_window_transpose(
stride: stride of the window. Default: 1
dilation: dilation of the window. Default: 1
"""
output_h
,
output_w
=
expand_hw
(
output_size
)
padding_h
,
padding_w
=
expand_hw
(
padding
)
stride_h
,
stride_w
=
expand_hw
(
stride
)
dilation_h
,
dilation_w
=
expand_hw
(
dilation
)
window_h
,
window_w
=
expand_hw
(
kernel_size
)
output_h
,
output_w
=
_
expand_hw
(
output_size
)
padding_h
,
padding_w
=
_
expand_hw
(
padding
)
stride_h
,
stride_w
=
_
expand_hw
(
stride
)
dilation_h
,
dilation_w
=
_
expand_hw
(
dilation
)
window_h
,
window_w
=
_
expand_hw
(
kernel_size
)
expected_h
=
(
output_h
+
2
*
padding_h
-
dilation_h
*
(
window_h
-
1
)
-
1
...
...
@@ -1904,7 +1904,7 @@ def _get_layerPixelShuffle(device, dtype, dim_order):
return
layerPixelShuffle
def
layerPixelShuffle_traceable
(
inp
,
upscale_factor
):
def
_
layerPixelShuffle_traceable
(
inp
,
upscale_factor
):
assert
upscale_factor
>
0
,
"upscale_factor should larger than 0"
assert
inp
.
ndim
>=
3
,
"the input dimension of pixel_shuffle should be larger than 3"
assert
(
...
...
@@ -1955,7 +1955,7 @@ def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor:
:param upscale_factor: upscale factor of pixel_shuffle.
:return: output tensor.
"""
return
pixel_shuffle_cpp
(
inp
,
upscale_factor
,
layerPixelShuffle_traceable
)
return
pixel_shuffle_cpp
(
inp
,
upscale_factor
,
_
layerPixelShuffle_traceable
)
from
.quantized
import
conv_bias_activation
# isort:skip
...
...
imperative/python/megengine/functional/tensor_cache.py
浏览文件 @
98460f58
from
..core._imperative_rt.core2
import
Const
from
..jit.tracing
import
is_tracing
from
..jit.tracing
import
_
is_tracing
small_tensor_cache
=
{}
def
_get_scalar_tensor_with_value
(
value
,
dtype
=
None
,
device
=
None
):
global
small_tensor_cache
if
is_tracing
():
if
_
is_tracing
():
ret
=
Const
(
value
,
dtype
,
device
)
else
:
cache_key
=
(
value
,
dtype
,
device
)
...
...
imperative/python/megengine/hub/fetcher.py
浏览文件 @
98460f58
...
...
@@ -36,7 +36,7 @@ pattern = re.compile(
)
class
RepoFetcherBase
:
class
_
RepoFetcherBase
:
@
classmethod
def
fetch
(
cls
,
...
...
@@ -84,7 +84,7 @@ class RepoFetcherBase:
return
hashlib
.
sha1
(
repo_dir
.
encode
()).
hexdigest
()[:
16
]
class
GitSSHFetcher
(
RepoFetcherBase
):
class
GitSSHFetcher
(
_
RepoFetcherBase
):
@
classmethod
@
synchronized
def
fetch
(
...
...
@@ -193,7 +193,7 @@ class GitSSHFetcher(RepoFetcherBase):
)
class
GitHTTPSFetcher
(
RepoFetcherBase
):
class
GitHTTPSFetcher
(
_
RepoFetcherBase
):
@
classmethod
@
synchronized
def
fetch
(
...
...
imperative/python/megengine/jit/tracing.py
浏览文件 @
98460f58
...
...
@@ -49,7 +49,7 @@ active_trace = None
skip_tracing
=
False
def
is_tracing
():
def
_
is_tracing
():
if
active_trace
is
None
:
return
False
else
:
...
...
@@ -73,7 +73,7 @@ def exclude_from_trace():
skip_tracing
=
False
def
array_comparator
(
lhs
,
rhs
):
def
_
array_comparator
(
lhs
,
rhs
):
return
np
.
all
(
lhs
==
rhs
)
...
...
@@ -184,7 +184,7 @@ class trace:
self
.
_trace
.
no_exec
=
record_only
self
.
_trace
.
options_visitor
=
apply_options
self
.
_trace
.
profile
=
profiling
self
.
_trace
.
array_comparator
=
array_comparator
self
.
_trace
.
array_comparator
=
_
array_comparator
self
.
_trace
.
record_input_shapes
=
_input_node_use_static_shape
()
def
__call__
(
self
,
*
args
,
**
kwargs
):
...
...
imperative/python/megengine/logger.py
浏览文件 @
98460f58
...
...
@@ -18,10 +18,10 @@ def set_log_file(fout, mode="a"):
"""
if
isinstance
(
fout
,
str
):
fout
=
open
(
fout
,
mode
)
MegEngineLogFormatter
.
log_fout
=
fout
_
MegEngineLogFormatter
.
log_fout
=
fout
class
MegEngineLogFormatter
(
logging
.
Formatter
):
class
_
MegEngineLogFormatter
(
logging
.
Formatter
):
log_fout
=
None
date_full
=
"[%(asctime)s %(lineno)d@%(filename)s:%(name)s] "
date
=
"%(asctime)s "
...
...
@@ -71,7 +71,7 @@ class MegEngineLogFormatter(logging.Formatter):
if
self
.
log_fout
:
self
.
__set_fmt
(
self
.
date_full
+
mtxt
+
self
.
msg
)
formatted
=
super
(
MegEngineLogFormatter
,
self
).
format
(
record
)
formatted
=
super
(
_
MegEngineLogFormatter
,
self
).
format
(
record
)
nr_line
=
formatted
.
count
(
"
\n
"
)
+
1
if
nr_line
>=
self
.
max_lines
:
head
,
body
=
formatted
.
split
(
"
\n
"
,
1
)
...
...
@@ -88,7 +88,7 @@ class MegEngineLogFormatter(logging.Formatter):
self
.
log_fout
.
flush
()
self
.
__set_fmt
(
self
.
_color_date
(
self
.
date
)
+
mcl
(
mtxt
+
self
.
msg
))
formatted
=
super
(
MegEngineLogFormatter
,
self
).
format
(
record
)
formatted
=
super
(
_
MegEngineLogFormatter
,
self
).
format
(
record
)
if
record
.
exc_text
or
record
.
exc_info
:
# handle exception format
...
...
@@ -125,7 +125,7 @@ class MegEngineLogFormatter(logging.Formatter):
self
.
_style
.
_fmt
=
fmt
def
get_logger
(
name
=
None
,
formatter
=
MegEngineLogFormatter
):
def
get_logger
(
name
=
None
,
formatter
=
_
MegEngineLogFormatter
):
r
"""Gets megengine logger with given name."""
logger
=
logging
.
getLogger
(
name
)
...
...
@@ -167,16 +167,16 @@ try:
from
.core._imperative_rt.utils
import
Logger
as
_imperative_rt_logger
class
MegBrainLogFormatter
(
MegEngineLogFormatter
):
class
_MegBrainLogFormatter
(
_
MegEngineLogFormatter
):
date
=
"%(asctime)s[mgb] "
def
_color_date
(
self
,
msg
):
return
"
\x1b
[33m{}
\x1b
[0m"
.
format
(
msg
)
_megbrain_logger
=
get_logger
(
"megbrain"
,
MegBrainLogFormatter
)
_megbrain_logger
=
get_logger
(
"megbrain"
,
_
MegBrainLogFormatter
)
_imperative_rt_logger
.
set_log_handler
(
_megbrain_logger
)
def
set_mgb_log_level
(
level
):
def
_
set_mgb_log_level
(
level
):
r
"""Sets megbrain log level
Args:
...
...
@@ -200,30 +200,30 @@ try:
)
return
rst
set_mgb_log_level
(
_default_level
)
_
set_mgb_log_level
(
_default_level
)
except
ImportError
as
exc
:
def
set_mgb_log_level
(
level
):
def
_
set_mgb_log_level
(
level
):
raise
NotImplementedError
(
"imperative_rt has not been imported"
)
@
contextlib
.
contextmanager
def
replace_mgb_log_level
(
level
):
def
_
replace_mgb_log_level
(
level
):
r
"""Replaces megbrain log level in a block and restore after exiting.
Args:
level: new log level
"""
old
=
set_mgb_log_level
(
level
)
old
=
_
set_mgb_log_level
(
level
)
try
:
yield
finally
:
set_mgb_log_level
(
old
)
_
set_mgb_log_level
(
old
)
def
enable_debug_log
():
r
"""Sets logging level to debug for all components."""
set_log_level
(
logging
.
DEBUG
)
set_mgb_log_level
(
logging
.
DEBUG
)
_
set_mgb_log_level
(
logging
.
DEBUG
)
imperative/python/megengine/module/rnn.py
浏览文件 @
98460f58
...
...
@@ -15,12 +15,12 @@ from . import init
from
.module
import
Module
class
RNNCellBase
(
Module
):
class
_
RNNCellBase
(
Module
):
def
__init__
(
self
,
input_size
:
int
,
hidden_size
:
int
,
bias
:
bool
,
num_chunks
:
int
,
)
->
None
:
# num_chunks indicates the number of gates
super
(
RNNCellBase
,
self
).
__init__
()
super
(
_
RNNCellBase
,
self
).
__init__
()
self
.
input_size
=
input_size
self
.
hidden_size
=
hidden_size
...
...
@@ -57,7 +57,7 @@ class RNNCellBase(Module):
raise
NotImplementedError
(
"forward not implemented !"
)
class
RNNCell
(
RNNCellBase
):
class
RNNCell
(
_
RNNCellBase
):
r
"""An Elman RNN cell with tanh or ReLU non-linearity.
...
...
@@ -135,7 +135,7 @@ class RNNCell(RNNCellBase):
)[
0
]
class
LSTMCell
(
RNNCellBase
):
class
LSTMCell
(
_
RNNCellBase
):
r
"""A long short-term memory (LSTM) cell.
...
...
@@ -216,7 +216,7 @@ class LSTMCell(RNNCellBase):
)[:
2
]
class
RNNBase
(
Module
):
class
_
RNNBase
(
Module
):
def
__init__
(
self
,
input_size
:
int
,
...
...
@@ -228,7 +228,7 @@ class RNNBase(Module):
bidirectional
:
bool
=
False
,
proj_size
:
int
=
0
,
)
->
None
:
super
(
RNNBase
,
self
).
__init__
()
super
(
_
RNNBase
,
self
).
__init__
()
self
.
input_size
=
input_size
self
.
hidden_size
=
hidden_size
self
.
num_layers
=
num_layers
...
...
@@ -323,7 +323,7 @@ class RNNBase(Module):
return
output
,
h
class
RNN
(
RNNBase
):
class
RNN
(
_
RNNBase
):
r
"""Applies a multi-layer Elman RNN with :math:`\tanh` or :math:`\text{ReLU}` non-linearity to an
input sequence.
...
...
@@ -453,7 +453,7 @@ class RNN(RNNBase):
return
output
,
h
class
LSTM
(
RNNBase
):
class
LSTM
(
_
RNNBase
):
r
"""Applies a multi-layer long short-term memory LSTM to an input
sequence.
...
...
imperative/python/megengine/quantization/observer.py
浏览文件 @
98460f58
...
...
@@ -7,7 +7,7 @@ import numpy as np
from
..
import
functional
as
F
from
..core.tensor.dtype
import
QuantDtypeMeta
,
_builtin_quant_dtypes
from
..distributed
import
WORLD
,
get_rank
,
is_distributed
from
..distributed
import
WORLD
,
is_distributed
from
..functional.distributed
import
all_reduce_max
,
all_reduce_min
from
..logger
import
get_logger
from
..module
import
Module
...
...
imperative/python/megengine/serialization.py
浏览文件 @
98460f58
...
...
@@ -27,7 +27,7 @@ def save(obj, f, pickle_module=pickle, pickle_protocol=pickle.DEFAULT_PROTOCOL):
pickle_module
.
dump
(
obj
,
f
,
pickle_protocol
)
class
dmap
:
class
_
dmap
:
def
__init__
(
self
,
map_location
):
self
.
map_location
=
map_location
...
...
@@ -101,5 +101,5 @@ def load(f, map_location=None, pickle_module=pickle):
map_location
=
_get_callable_map_location
(
map_location
)
# callable map_location
with
dmap
(
map_location
)
as
dm
:
with
_
dmap
(
map_location
)
as
dm
:
return
pickle_module
.
load
(
f
)
imperative/python/megengine/utils/module_stats.py
浏览文件 @
98460f58
...
...
@@ -11,11 +11,11 @@ from .. import functional as F
from
..
import
get_logger
from
..
import
module
as
M
from
..core.tensor.dtype
import
get_dtype_bit
from
..logger
import
MegEngineLogFormatter
from
..logger
import
_
MegEngineLogFormatter
from
.module_utils
import
set_module_mode_safe
try
:
MegEngineLogFormatter
.
max_lines
=
float
(
"inf"
)
_
MegEngineLogFormatter
.
max_lines
=
float
(
"inf"
)
except
AttributeError
as
e
:
raise
ValueError
(
"set logger max lines failed"
)
...
...
imperative/python/test/unit/core/test_util.py
浏览文件 @
98460f58
...
...
@@ -2,14 +2,14 @@
import
logging
from
megengine.core._imperative_rt
import
Logger
from
megengine.logger
import
_imperative_rt_logger
,
set_mgb_log_level
from
megengine.logger
import
_imperative_rt_logger
,
_
set_mgb_log_level
def
test_logger
():
orig_level
=
Logger
().
set_log_level
(
Logger
.
LogLevel
.
Debug
)
assert
Logger
().
set_log_level
(
Logger
.
LogLevel
.
Debug
)
==
Logger
.
LogLevel
.
Debug
Logger
().
set_log_level
(
orig_level
)
orig_level
=
set_mgb_log_level
(
logging
.
DEBUG
)
orig_level
=
_
set_mgb_log_level
(
logging
.
DEBUG
)
assert
(
_imperative_rt_logger
.
set_log_level
(
Logger
.
LogLevel
.
Debug
)
==
Logger
.
LogLevel
.
Debug
...
...
imperative/python/test/unit/distributed/test_distributed.py
浏览文件 @
98460f58
...
...
@@ -50,7 +50,7 @@ def test_init_process_group(backend):
assert
mm_server_addr
[
0
]
==
"localhost"
assert
mm_server_addr
[
1
]
>
0
assert
isinstance
(
dist
.
get_client
(),
dist
.
Client
)
assert
isinstance
(
dist
.
get_client
(),
dist
.
server
.
_
Client
)
procs
=
[]
for
rank
in
range
(
world_size
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录