Commit ee4ea7fd
Authored Nov 25, 2020 by Megvii Engine Team

test(distributed/test): make distributed test more stronger

GitOrigin-RevId: 085fd1dcfd3a80467e84ffe463b5dcedf615bd48

Parent commit: 3ecded74
Showing 8 changed files with 77 additions and 96 deletions.
imperative/python/megengine/distributed/launcher.py                      +7  -1
imperative/python/megengine/distributed/server.py                        +16 -8
imperative/python/test/integration/test_dp_correctness.py                +3  -13
imperative/python/test/unit/core/test_autodiff.py                        +8  -3
imperative/python/test/unit/distributed/test_distributed.py              +13 -32
imperative/python/test/unit/functional/test_functional_distributed.py    +22 -22
imperative/python/test/unit/module/test_batchnorm.py                     +1  -1
imperative/python/test/unit/quantization/test_observer.py                +7  -16
imperative/python/megengine/distributed/launcher.py

@@ -45,9 +45,15 @@ def launcher(func):
         while len(ranks) > 0:
             left = []
+            # check all processes in one second
+            time_to_wait = 1.0 / len(ranks)
             for rank in ranks:
-                procs[rank].join(1)
+                procs[rank].join(time_to_wait)
                 code = procs[rank].exitcode
+                # terminate processes if one of them has failed
+                if code != 0 and code != None:
+                    for i in ranks:
+                        procs[i].terminate()
                 assert (
                     code == 0 or code == None
                 ), "subprocess {} exit with code {}".format(rank, code)
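The hunk above changes the watchdog loop in two ways: the per-process join timeout becomes a one-second budget shared across all still-running workers, and the whole group is terminated as soon as any worker exits with a non-zero code. A minimal stand-alone sketch of that polling pattern, using only the standard library (illustrative only, not MegEngine's launcher; the _worker function is made up for the demo):

import multiprocessing as mp
import time


def _worker(rank):
    # stand-in workload; ranks finish at different times so the loop has to re-poll
    time.sleep(0.2 * rank)


if __name__ == "__main__":
    procs = {rank: mp.Process(target=_worker, args=(rank,)) for rank in range(4)}
    for p in procs.values():
        p.start()

    ranks = list(procs)
    while len(ranks) > 0:
        left = []
        # check all processes in roughly one second, however many are still alive
        time_to_wait = 1.0 / len(ranks)
        for rank in ranks:
            procs[rank].join(time_to_wait)
            code = procs[rank].exitcode
            # terminate the rest if one of them has failed, instead of hanging
            if code is not None and code != 0:
                for i in ranks:
                    procs[i].terminate()
            assert code in (0, None), "subprocess {} exit with code {}".format(rank, code)
            if code is None:
                left.append(rank)  # still running, poll it again on the next pass
        ranks = left
    print("all workers exited cleanly")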
imperative/python/megengine/distributed/server.py

@@ -133,18 +133,22 @@ class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):
     pass


-def start_server(py_server_port, mm_server_port, queue):
+def _start_server(py_server_port, mm_server_port, queue):
     """
     Start python distributed server and multiple machine server.

     :param py_server_port: python server port.
     :param mm_server_port: multiple machine server port.
+    :param queue: server port will put in this queue, puts exception when process fails.
     """
-    server = ThreadXMLRPCServer(("0.0.0.0", py_server_port), logRequests=False)
-    server.register_instance(Methods(mm_server_port))
-    _, port = server.server_address
-    queue.put(port)
-    server.serve_forever()
+    try:
+        server = ThreadXMLRPCServer(("0.0.0.0", py_server_port), logRequests=False)
+        server.register_instance(Methods(mm_server_port))
+        _, port = server.server_address
+        queue.put(port)
+        server.serve_forever()
+    except Exception as e:
+        queue.put(e)


 class Server:

@@ -159,10 +163,14 @@ class Server:
         self.mm_server_port = create_mm_server("0.0.0.0", 0)
         q = Queue()
         self.proc = threading.Thread(
-            target=start_server, args=(port, self.mm_server_port, q), daemon=True,
+            target=_start_server, args=(port, self.mm_server_port, q), daemon=True,
         )
         self.proc.start()
-        self.py_server_port = q.get()
+        ret = q.get()
+        if isinstance(ret, Exception):
+            raise ret
+        else:
+            self.py_server_port = ret


 class Client:
浏览文件 @
ee4ea7fd
...
@@ -159,11 +159,9 @@ def run_test(
...
@@ -159,11 +159,9 @@ def run_test(
checkpoint
=
mge
.
load
(
model_path
)
checkpoint
=
mge
.
load
(
model_path
)
data
=
checkpoint
[
"data"
]
data
=
checkpoint
[
"data"
]
label
=
checkpoint
[
"label"
]
label
=
checkpoint
[
"label"
]
port
=
dist
.
get_free_ports
(
1
)[
0
]
server
=
dist
.
Server
(
port
)
def
worker
(
rank
,
max_err
):
@
dist
.
launcher
dist
.
init_process_group
(
"localhost"
,
port
,
p_num
,
rank
,
rank
)
def
worker
(
max_err
):
net
=
MnistNet
(
has_bn
=
True
)
net
=
MnistNet
(
has_bn
=
True
)
net
.
load_state_dict
(
checkpoint
[
"net_init"
])
net
.
load_state_dict
(
checkpoint
[
"net_init"
])
lr
=
checkpoint
[
"sgd_lr"
]
lr
=
checkpoint
[
"sgd_lr"
]
...
@@ -194,15 +192,7 @@ def run_test(
...
@@ -194,15 +192,7 @@ def run_test(
else
:
else
:
np
.
testing
.
assert_allclose
(
param
[
1
],
param_ref
[
1
],
atol
=
max_err
)
np
.
testing
.
assert_allclose
(
param
[
1
],
param_ref
[
1
],
atol
=
max_err
)
procs
=
[]
worker
(
max_err
)
for
rank
in
range
(
p_num
):
p
=
mp
.
Process
(
target
=
worker
,
args
=
(
rank
,
max_err
,))
p
.
start
()
procs
.
append
(
p
)
for
p
in
procs
:
p
.
join
(
20
)
assert
p
.
exitcode
==
0
@
pytest
.
mark
.
skipif
(
get_device_count_by_fork
(
"gpu"
)
<
4
,
reason
=
"need more gpu device"
)
@
pytest
.
mark
.
skipif
(
get_device_count_by_fork
(
"gpu"
)
<
4
,
reason
=
"need more gpu device"
)
...
...
imperative/python/test/unit/core/test_autodiff.py

@@ -23,6 +23,7 @@ from megengine.core.ops.builtin import Elemwise
 from megengine.core.tensor.raw_tensor import as_raw_tensor
 from megengine.core.tensor.tensor import Tensor, apply
 from megengine.core.tensor.tensor_wrapper import TensorWrapper
+from megengine.distributed.helper import get_device_count_by_fork
 from megengine.functional.distributed import remote_recv, remote_send

@@ -53,15 +54,19 @@ def save_to(self, name="grad"):
     return callback


-@pytest.mark.isolated_distributed
+@pytest.mark.skipif(platform.system() == "Darwin", reason="do not imp GPU mode at macos now")
 @pytest.mark.skipif(
     platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
 )
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
+@pytest.mark.isolated_distributed
 def test_dist_grad():
     world_size = 2
     x_np = np.random.rand(10).astype("float32")
-    port = dist.get_free_ports(1)[0]
-    server = dist.Server(port)
+    server = dist.Server()
+    port = server.py_server_port

     def worker0():
         dist.init_process_group("localhost", port, world_size, 0, 0)
imperative/python/test/unit/distributed/test_distributed.py

@@ -47,8 +47,8 @@ def _assert_q_val(q, val):
 @pytest.mark.isolated_distributed
 def test_init_process_group():
     world_size = 2
-    port = dist.get_free_ports(1)[0]
-    server = dist.Server(port)
+    server = dist.Server()
+    port = server.py_server_port

     def worker(rank, backend):
         dist.init_process_group("localhost", port, world_size, rank, rank, backend)

@@ -92,11 +92,10 @@ def test_init_process_group():
 def test_new_group():
     world_size = 3
     ranks = [2, 0]
-    port = dist.get_free_ports(1)[0]
-    server = dist.Server(port)

-    def worker(rank):
-        dist.init_process_group("localhost", port, world_size, rank, rank)
+    @dist.launcher
+    def worker():
+        rank = dist.get_rank()
         if rank in ranks:
             group = dist.new_group(ranks)
             assert group.size == 2

@@ -104,15 +103,7 @@ def test_new_group():
             assert group.rank == ranks.index(rank)
             assert group.comp_node == "gpu{}:2".format(rank)

-    procs = []
-    for rank in range(world_size):
-        p = mp.Process(target=worker, args=(rank,))
-        p.start()
-        procs.append(p)
-
-    for p in procs:
-        p.join(20)
-        assert p.exitcode == 0
+    worker()


 @pytest.mark.skipif(

@@ -125,8 +116,8 @@ def test_new_group():
 @pytest.mark.isolated_distributed
 def test_group_barrier():
     world_size = 2
-    port = dist.get_free_ports(1)[0]
-    server = dist.Server(port)
+    server = dist.Server()
+    port = server.py_server_port

     def worker(rank, q):
         dist.init_process_group("localhost", port, world_size, rank, rank)

@@ -161,8 +152,8 @@ def test_group_barrier():
 @pytest.mark.isolated_distributed
 def test_synchronized():
     world_size = 2
-    port = dist.get_free_ports(1)[0]
-    server = dist.Server(port)
+    server = dist.Server()
+    port = server.py_server_port

     @dist.synchronized
     def func(rank, q):

@@ -205,26 +196,16 @@ def test_synchronized():
 @pytest.mark.isolated_distributed
 def test_user_set_get():
     world_size = 2
-    port = dist.get_free_ports(1)[0]
-    server = dist.Server(port)

-    def worker(rank):
-        dist.init_process_group("localhost", port, world_size, rank, rank)
+    @dist.launcher
+    def worker():
         # set in race condition
         dist.get_client().user_set("foo", 1)
         # get in race condition
         ret = dist.get_client().user_get("foo")
         assert ret == 1

-    procs = []
-    for rank in range(world_size):
-        p = mp.Process(target=worker, args=(rank,))
-        p.start()
-        procs.append(p)
-
-    for p in procs:
-        p.join(20)
-        assert p.exitcode == 0
+    worker()


 def test_oprmm_hashable():
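The recurring test-side change visible above (and in the other test files of this commit) folds the hand-rolled mp.Process spawn/join/exitcode loops into the @dist.launcher decorator, with workers reading their rank from dist.get_rank() instead of a parameter. A toy stand-in for what such a launcher does, in plain Python (WORLD_SIZE and launcher here are assumptions for the sketch, not MegEngine's implementation):

import multiprocessing as mp

WORLD_SIZE = 2  # assumption for the sketch; the real launcher derives this from the GPU count


def launcher(func):
    """Run func once per worker process and fail loudly on a bad exit code."""

    def wrapper(*args, **kwargs):
        procs = [
            mp.Process(target=func, args=args, kwargs=kwargs) for _ in range(WORLD_SIZE)
        ]
        for p in procs:
            p.start()
        for p in procs:
            p.join(20)
            assert p.exitcode == 0, "subprocess exit with code {}".format(p.exitcode)

    return wrapper


def _worker_body(msg):
    print("worker says:", msg)


# the tests apply this as a decorator (@dist.launcher); applying it explicitly here keeps
# _worker_body importable under its own name, so the sketch also works with the spawn start method
worker = launcher(_worker_body)

if __name__ == "__main__":
    worker("hello")  # runs the body in WORLD_SIZE separate processes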
imperative/python/test/unit/functional/test_functional_distributed.py

@@ -41,8 +41,8 @@ from megengine.functional.distributed import (
 @pytest.mark.isolated_distributed
 def test_reduce_sum():
     world_size = 2
-    port = dist.get_free_ports(1)[0]
-    server = dist.Server(port)
+    server = dist.Server()
+    port = server.py_server_port

     def worker(rank, data, expect, port):
         if mge.get_device_count("gpu") < world_size:

@@ -83,8 +83,8 @@ def test_reduce_sum():
 @pytest.mark.isolated_distributed
 def test_broadcast():
     world_size = 2
-    port = dist.get_free_ports(1)[0]
-    server = dist.Server(port)
+    server = dist.Server()
+    port = server.py_server_port

     def worker(rank, data, expect, port):
         if mge.get_device_count("gpu") < world_size:

@@ -121,8 +121,8 @@ def test_broadcast():
 @pytest.mark.isolated_distributed
 def test_all_gather():
     world_size = 2
-    port = dist.get_free_ports(1)[0]
-    server = dist.Server(port)
+    server = dist.Server()
+    port = server.py_server_port

     def worker(rank, data, expect, port):
         if mge.get_device_count("gpu") < world_size:

@@ -160,8 +160,8 @@ def test_all_gather():
 @pytest.mark.isolated_distributed
 def test_reduce_scatter_sum():
     world_size = 2
-    port = dist.get_free_ports(1)[0]
-    server = dist.Server(port)
+    server = dist.Server()
+    port = server.py_server_port

     def worker(rank, data, expect, port):
         if mge.get_device_count("gpu") < world_size:

@@ -199,8 +199,8 @@ def test_reduce_scatter_sum():
 @pytest.mark.isolated_distributed
 def test_all_reduce_sum():
     world_size = 2
-    port = dist.get_free_ports(1)[0]
-    server = dist.Server(port)
+    server = dist.Server()
+    port = server.py_server_port

     def worker(rank, data, expect, port):
         if mge.get_device_count("gpu") < world_size:

@@ -238,8 +238,8 @@ def test_all_reduce_sum():
 @pytest.mark.isolated_distributed
 def test_all_reduce_max():
     world_size = 2
-    port = dist.get_free_ports(1)[0]
-    server = dist.Server(port)
+    server = dist.Server()
+    port = server.py_server_port

     def worker(rank, data, expect, port):
         if mge.get_device_count("gpu") < world_size:

@@ -277,8 +277,8 @@ def test_all_reduce_max():
 @pytest.mark.isolated_distributed
 def test_all_reduce_min():
     world_size = 2
-    port = dist.get_free_ports(1)[0]
-    server = dist.Server(port)
+    server = dist.Server()
+    port = server.py_server_port

     def worker(rank, data, expect, port):
         if mge.get_device_count("gpu") < world_size:

@@ -316,8 +316,8 @@ def test_all_reduce_min():
 @pytest.mark.isolated_distributed
 def test_gather():
     world_size = 2
-    port = dist.get_free_ports(1)[0]
-    server = dist.Server(port)
+    server = dist.Server()
+    port = server.py_server_port

     def worker(rank, data, expect, port):
         if mge.get_device_count("gpu") < world_size:

@@ -358,8 +358,8 @@ def test_gather():
 @pytest.mark.isolated_distributed
 def test_scatter():
     world_size = 2
-    port = dist.get_free_ports(1)[0]
-    server = dist.Server(port)
+    server = dist.Server()
+    port = server.py_server_port

     def worker(rank, data, expect, port):
         if mge.get_device_count("gpu") < world_size:

@@ -396,8 +396,8 @@ def test_scatter():
 @pytest.mark.isolated_distributed
 def test_all_to_all():
     world_size = 2
-    port = dist.get_free_ports(1)[0]
-    server = dist.Server(port)
+    server = dist.Server()
+    port = server.py_server_port

     def worker(rank, data, expect, port):
         if mge.get_device_count("gpu") < world_size:

@@ -436,8 +436,8 @@ def test_all_to_all():
 @pytest.mark.isolated_distributed
 def test_io_remote():
     world_size = 2
-    port = dist.get_free_ports(1)[0]
-    server = dist.Server(port)
+    server = dist.Server()
+    port = server.py_server_port
     val = np.random.rand(4, 5).astype(np.float32)

     def worker(rank):
imperative/python/test/unit/module/test_batchnorm.py

@@ -38,7 +38,7 @@ def test_syncbn():
     running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32)
     steps = 4
     nr_ranks = 2
-    server = dist.Server(0)
+    server = dist.Server()
     port = server.py_server_port

     def worker(rank, data, yv_expect, running_mean, running_var):
imperative/python/test/unit/quantization/test_observer.py

@@ -28,25 +28,16 @@ def test_min_max_observer():
 @pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
 @pytest.mark.isolated_distributed
 def test_sync_min_max_observer():
-    x = np.random.rand(6, 3, 3, 3).astype("float32")
+    word_size = get_device_count_by_fork("gpu")
+    x = np.random.rand(3 * word_size, 3, 3, 3).astype("float32")
     np_min, np_max = x.min(), x.max()
-    world_size = 2
-    port = dist.get_free_ports(1)[0]
-    server = dist.Server(port)

-    def worker(rank, slc):
-        dist.init_process_group("localhost", port, world_size, rank, rank)
+    @dist.launcher
+    def worker():
+        rank = dist.get_rank()
         m = ob.SyncMinMaxObserver()
-        y = mge.tensor(x[slc])
+        y = mge.tensor(x[rank * 3 : (rank + 1) * 3])
         m(y)
         assert m.min_val == np_min and m.max_val == np_max

-    procs = []
-    for rank in range(world_size):
-        slc = slice(rank * 3, (rank + 1) * 3)
-        p = mp.Process(target=worker, args=(rank, slc,), daemon=True)
-        p.start()
-        procs.append(p)
-    for p in procs:
-        p.join(20)
-        assert p.exitcode == 0
+    worker()