牧羊zove / fcos

Commit 5f2a8263
Authored on Jan 25, 2019 by wat3rBro; committed by Francisco Massa on Jan 25, 2019.

use all_gather to gather results from all gpus (#383)

Parent: 9b53d15c
Showing 2 changed files with 82 additions and 109 deletions:

    maskrcnn_benchmark/engine/inference.py    +2   -2
    maskrcnn_benchmark/utils/comm.py          +80  -107
maskrcnn_benchmark/engine/inference.py

...
@@ -9,7 +9,7 @@ from tqdm import tqdm
 from maskrcnn_benchmark.data.datasets.evaluation import evaluate
 from ..utils.comm import is_main_process
-from ..utils.comm import scatter_gather
+from ..utils.comm import all_gather
 from ..utils.comm import synchronize
...
@@ -30,7 +30,7 @@ def compute_on_dataset(model, data_loader, device):
 def _accumulate_predictions_from_multiple_gpus(predictions_per_gpu):
-    all_predictions = scatter_gather(predictions_per_gpu)
+    all_predictions = all_gather(predictions_per_gpu)
     if not is_main_process():
         return
     # merge the list of dicts
...
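For context, the function touched above goes on to merge the per-rank prediction dicts on the main process. Below is a hedged sketch of that surrounding logic, reconstructed from the "# merge the list of dicts" comment shown in the hunk; the exact body is outside this diff.

# Sketch only: how the gathered per-rank dicts are typically merged on rank 0.
from maskrcnn_benchmark.utils.comm import all_gather, is_main_process

def _accumulate_predictions_from_multiple_gpus(predictions_per_gpu):
    # predictions_per_gpu maps dataset image index -> prediction, for this rank only
    all_predictions = all_gather(predictions_per_gpu)  # one dict per rank
    if not is_main_process():
        return
    # merge the list of dicts
    predictions = {}
    for p in all_predictions:
        predictions.update(p)
    # put the predictions back into dataset order before evaluation
    image_ids = sorted(predictions.keys())
    return [predictions[i] for i in image_ids]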
maskrcnn_benchmark/utils/comm.py

 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 """
 This file contains primitives for multi-gpu communication.
 This is useful when doing distributed training.
 """
-import os
 import pickle
-import tempfile
 import time

 import torch
 import torch.distributed as dist


 def get_world_size():
-    if not torch.distributed.is_available():
+    if not dist.is_available():
         return 1
-    if not torch.distributed.is_initialized():
+    if not dist.is_initialized():
         return 1
-    return torch.distributed.get_world_size()
+    return dist.get_world_size()


 def get_rank():
-    if not torch.distributed.is_available():
+    if not dist.is_available():
         return 0
-    if not torch.distributed.is_initialized():
+    if not dist.is_initialized():
         return 0
-    return torch.distributed.get_rank()
+    return dist.get_rank()


 def is_main_process():
-    if not torch.distributed.is_available():
-        return True
-    if not torch.distributed.is_initialized():
-        return True
-    return torch.distributed.get_rank() == 0
+    return get_rank() == 0


 def synchronize():
     """
-    Helper function to synchronize between multiple processes when
+    Helper function to synchronize (barrier) among all processes when
     using distributed training
     """
-    if not torch.distributed.is_available():
+    if not dist.is_available():
         return
-    if not torch.distributed.is_initialized():
+    if not dist.is_initialized():
         return
-    world_size = torch.distributed.get_world_size()
-    rank = torch.distributed.get_rank()
+    world_size = dist.get_world_size()
+    rank = dist.get_rank()
     if world_size == 1:
         return
...
@@ -55,7 +49,7 @@ def synchronize():
             tensor = torch.tensor(0, device="cuda")
         else:
             tensor = torch.tensor(1, device="cuda")
-        torch.distributed.broadcast(tensor, r)
+        dist.broadcast(tensor, r)
         while tensor.item() == 1:
             time.sleep(1)
...
@@ -64,94 +58,73 @@ def synchronize():
     _send_and_wait(1)


-def _encode(encoded_data, data):
-    # gets a byte representation for the data
-    encoded_bytes = pickle.dumps(data)
-    # convert this byte string into a byte tensor
-    storage = torch.ByteStorage.from_buffer(encoded_bytes)
-    tensor = torch.ByteTensor(storage).to("cuda")
-    # encoding: first byte is the size and then rest is the data
-    s = tensor.numel()
-    assert s <= 255, "Can't encode data greater than 255 bytes"
-    # put the encoded data in encoded_data
-    encoded_data[0] = s
-    encoded_data[1 : (s + 1)] = tensor
-
-
-def _decode(encoded_data):
-    size = encoded_data[0]
-    encoded_tensor = encoded_data[1 : (size + 1)].to("cpu")
-    return pickle.loads(bytearray(encoded_tensor.tolist()))
-
-
-# TODO try to use tensor in shared-memory instead of serializing to disk
-# this involves getting the all_gather to work
-def scatter_gather(data):
+def all_gather(data):
     """
-    This function gathers data from multiple processes, and returns them
-    in a list, as they were obtained from each process.
-
-    This function is useful for retrieving data from multiple processes,
-    when launching the code with torch.distributed.launch
-
-    Note: this function is slow and should not be used in tight loops, i.e.,
-    do not use it in the training loop.
-
-    Arguments:
-        data: the object to be gathered from multiple processes.
-            It must be serializable
-
-    Returns:
-        result (list): a list with as many elements as there are processes,
-            where each element i in the list corresponds to the data that was
-            gathered from the process of rank i.
+    Run all_gather on arbitrary picklable data (not necessarily tensors)
+    Args:
+        data: any picklable object
+    Returns:
+        list[data]: list of data gathered from each rank
     """
-    # strategy: the main process creates a temporary directory, and communicates
-    # the location of the temporary directory to all other processes.
-    # each process will then serialize the data to the folder defined by
-    # the main process, and then the main process reads all of the serialized
-    # files and returns them in a list
-    if not torch.distributed.is_available():
-        return [data]
-    if not torch.distributed.is_initialized():
-        return [data]
-    synchronize()
-    # get rank of the current process
-    rank = torch.distributed.get_rank()
-
-    # the data to communicate should be small
-    data_to_communicate = torch.empty(256, dtype=torch.uint8, device="cuda")
-    if rank == 0:
-        # manually creates a temporary directory, that needs to be cleaned
-        # afterwards
-        tmp_dir = tempfile.mkdtemp()
-        _encode(data_to_communicate, tmp_dir)
-
-    synchronize()
-    # the main process (rank=0) communicates the data to all processes
-    torch.distributed.broadcast(data_to_communicate, 0)
-
-    # get the data that was communicated
-    tmp_dir = _decode(data_to_communicate)
-
-    # each process serializes to a different file
-    file_template = "file{}.pth"
-    tmp_file = os.path.join(tmp_dir, file_template.format(rank))
-    torch.save(data, tmp_file)
-
-    # synchronize before loading the data
-    synchronize()
-
-    # only the master process returns the data
-    if rank == 0:
-        data_list = []
-        world_size = torch.distributed.get_world_size()
-        for r in range(world_size):
-            file_path = os.path.join(tmp_dir, file_template.format(r))
-            d = torch.load(file_path)
-            data_list.append(d)
-            # cleanup
-            os.remove(file_path)
-        # cleanup
-        os.rmdir(tmp_dir)
-        return data_list
+    world_size = get_world_size()
+    if world_size == 1:
+        return [data]
+
+    # serialized to a Tensor
+    buffer = pickle.dumps(data)
+    storage = torch.ByteStorage.from_buffer(buffer)
+    tensor = torch.ByteTensor(storage).to("cuda")
+
+    # obtain Tensor size of each rank
+    local_size = torch.IntTensor([tensor.numel()]).to("cuda")
+    size_list = [torch.IntTensor([0]).to("cuda") for _ in range(world_size)]
+    dist.all_gather(size_list, local_size)
+    size_list = [int(size.item()) for size in size_list]
+    max_size = max(size_list)
+
+    # receiving Tensor from all ranks
+    # we pad the tensor because torch all_gather does not support
+    # gathering tensors of different shapes
+    tensor_list = []
+    for _ in size_list:
+        tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
+    if local_size != max_size:
+        padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
+        tensor = torch.cat((tensor, padding), dim=0)
+    dist.all_gather(tensor_list, tensor)
+
+    data_list = []
+    for size, tensor in zip(size_list, tensor_list):
+        buffer = tensor.cpu().numpy().tobytes()[:size]
+        data_list.append(pickle.loads(buffer))
+
+    return data_list


 def reduce_dict(input_dict, average=True):
     """
     Args:
         input_dict (dict): all the values will be reduced
         average (bool): whether to do average or sum
     Reduce the values in the dictionary from all processes so that process with rank
     0 has the averaged results. Returns a dict with the same fields as
     input_dict, after reduction.
     """
     world_size = get_world_size()
     if world_size < 2:
         return input_dict
     with torch.no_grad():
         names = []
         values = []
         # sort the keys so that they are consistent across processes
         for k in sorted(input_dict.keys()):
             names.append(k)
             values.append(input_dict[k])
         values = torch.stack(values, dim=0)
         dist.reduce(values, dst=0)
         if dist.get_rank() == 0 and average:
             # only main process gets accumulated, so only divide by
             # world_size in this case
             values /= world_size
         reduced_dict = {k: v for k, v in zip(names, values)}
     return reduced_dict
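The core trick in the new all_gather is that torch.distributed.all_gather only accepts tensors of identical shape, so each rank first exchanges its payload size and then pads its pickled byte tensor up to the maximum. The following is a self-contained sketch of the same idea, runnable on CPU with the gloo backend; the file name, backend, launch command, and gather_any helper are choices made for this sketch, not part of the commit.

# toy_all_gather.py -- illustrative demo of the pad-to-max-size gather used above.
# Run with:  python -m torch.distributed.launch --nproc_per_node=2 toy_all_gather.py
import argparse
import pickle

import torch
import torch.distributed as dist


def gather_any(data):
    """Gather an arbitrary picklable object from every rank (same idea as comm.all_gather)."""
    world_size = dist.get_world_size()
    if world_size == 1:
        return [data]

    # serialize to a 1-D uint8 tensor; payload length generally differs per rank
    payload = torch.tensor(list(pickle.dumps(data)), dtype=torch.uint8)

    # step 1: exchange payload sizes so every rank knows the maximum
    local_size = torch.tensor([payload.numel()], dtype=torch.long)
    size_list = [torch.zeros(1, dtype=torch.long) for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    sizes = [int(s.item()) for s in size_list]
    max_size = max(sizes)

    # step 2: pad to max_size, because all_gather needs equal-shaped tensors
    if payload.numel() < max_size:
        pad = torch.zeros(max_size - payload.numel(), dtype=torch.uint8)
        payload = torch.cat((payload, pad))
    gathered = [torch.zeros(max_size, dtype=torch.uint8) for _ in range(world_size)]
    dist.all_gather(gathered, payload)

    # step 3: strip the padding and unpickle each rank's object
    return [pickle.loads(bytes(t[:n].tolist())) for t, n in zip(gathered, sizes)]


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=0)  # injected by the launcher
    parser.parse_args()
    dist.init_process_group(backend="gloo", init_method="env://")
    result = gather_any({"rank": dist.get_rank(), "msg": "x" * (dist.get_rank() + 1)})
    if dist.get_rank() == 0:
        print(result)  # one object per rank, in rank order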
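reduce_dict, shown unchanged at the end of the file above, is the training-side counterpart: it averages per-rank loss values onto rank 0 for logging. A minimal usage sketch under the same per-process launch; the loss names and values here are made up, and with a GPU backend such as NCCL the tensors would live on CUDA.

# Illustrative use of reduce_dict for logging averaged losses on the main process.
import torch
from maskrcnn_benchmark.utils.comm import is_main_process, reduce_dict

loss_dict = {
    "loss_classifier": torch.tensor(0.7),  # hypothetical values; in training they come from the model
    "loss_box_reg": torch.tensor(0.3),
}
loss_dict_reduced = reduce_dict(loss_dict)  # rank 0 ends up with values averaged over all ranks
if is_main_process():
    total = sum(loss for loss in loss_dict_reduced.values())
    print({k: v.item() for k, v in loss_dict_reduced.items()}, total.item())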