6cc6540c
编写于
11月 24, 2022
作者:
H
HydrogenSulfate
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add different seed for workers and replicas
上级
b542416d
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
43 addition
and
5 deletion
+43
-5
ppcls/data/__init__.py                 +33 −2
ppcls/data/dataloader/pk_sampler.py     +7 −3
ppcls/engine/engine.py                  +3 −0
ppcls/data/__init__.py

```diff
@@ -14,8 +14,11 @@
 import inspect
 import copy
+import random
 import paddle
 import numpy as np
+import paddle.distributed as dist
+from functools import partial
 from paddle.io import DistributedBatchSampler, BatchSampler, DataLoader
 from ppcls.utils import logger
@@ -66,6 +69,22 @@ def create_operators(params, class_num=None):
     return ops
 
 
+def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
+    """callback function on each worker subprocess after seeding and before data loading.
+
+    Args:
+        worker_id (int): Worker id in [0, num_workers - 1]
+        num_workers (int): Number of subprocesses to use for data loading.
+        rank (int): Rank of process in distributed environment. If in non-distributed environment, it is a constant number `0`.
+        seed (int): Random seed
+    """
+    # The seed of each worker equals to
+    # num_worker * rank + worker_id + user_seed
+    worker_seed = num_workers * rank + worker_id + seed
+    np.random.seed(worker_seed)
+    random.seed(worker_seed)
+
+
 def build_dataloader(config, mode, device, use_dali=False, seed=None):
     assert mode in [
         'Train', 'Eval', 'Test', 'Gallery', 'Query', 'UnLabelTrain'
```
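The seed formula in `worker_init_fn` gives every data-loading worker on every replica its own seed derived from the user seed. A standalone illustration of that property (the replica/worker counts and user seed below are made up for the example):

```python
# Check of the per-worker seed formula used above:
#   worker_seed = num_workers * rank + worker_id + seed
# With 2 replicas (ranks) and 4 workers each, all 8 seeds are distinct.
num_workers, user_seed = 4, 42

seeds = {
    (rank, worker_id): num_workers * rank + worker_id + user_seed
    for rank in range(2) for worker_id in range(num_workers)
}
print(seeds)
# {(0, 0): 42, (0, 1): 43, (0, 2): 44, (0, 3): 45,
#  (1, 0): 46, (1, 1): 47, (1, 2): 48, (1, 3): 49}
assert len(set(seeds.values())) == len(seeds)  # no collisions
```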
ppcls/data/__init__.py (continued)

```diff
@@ -82,6 +101,7 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
             seed=seed)
 
     class_num = config.get("class_num", None)
+    epochs = config.get("epochs", None)
     config_dataset = config[mode]['dataset']
     config_dataset = copy.deepcopy(config_dataset)
     dataset_name = config_dataset.pop('name')
@@ -103,6 +123,9 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
         shuffle = config_sampler["shuffle"]
     else:
         sampler_name = config_sampler.pop("name")
+        sampler_argspec = inspect.getargspec(eval(sampler_name).__init__).args
+        if "total_epochs" in sampler_argspec:
+            config_sampler.update({"total_epochs": epochs})
         batch_sampler = eval(sampler_name)(dataset, **config_sampler)
         logger.debug("build batch_sampler({}) success...".format(batch_sampler))
```
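The `inspect.getargspec(...)` check above is a generic "forward a kwarg only if the constructor accepts it" pattern, so samplers without a `total_epochs` argument keep working unchanged. A minimal standalone sketch of the same idea (the sampler classes here are hypothetical, and `inspect.getfullargspec` is the non-deprecated Python 3 spelling of the same call):

```python
import inspect


class SamplerWithEpochs:
    def __init__(self, dataset, batch_size, total_epochs=0):
        self.total_epochs = total_epochs


class PlainSampler:
    def __init__(self, dataset, batch_size):
        pass


def build_sampler(sampler_cls, dataset, batch_size, epochs):
    kwargs = {"batch_size": batch_size}
    # Forward total_epochs only if the sampler's __init__ accepts it,
    # mirroring the argspec check in build_dataloader above.
    if "total_epochs" in inspect.getfullargspec(sampler_cls.__init__).args:
        kwargs["total_epochs"] = epochs
    return sampler_cls(dataset, **kwargs)


s1 = build_sampler(SamplerWithEpochs, dataset=[], batch_size=8, epochs=100)
s2 = build_sampler(PlainSampler, dataset=[], batch_size=8, epochs=100)
assert s1.total_epochs == 100
```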
ppcls/data/__init__.py (continued)

```diff
@@ -131,6 +154,12 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
     num_workers = config_loader["num_workers"]
     use_shared_memory = config_loader["use_shared_memory"]
 
+    init_fn = partial(
+        worker_init_fn,
+        num_workers=num_workers,
+        rank=dist.get_rank(),
+        seed=seed) if seed is not None else None
+
     if batch_sampler is None:
         data_loader = DataLoader(
             dataset=dataset,
@@ -141,7 +170,8 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
             batch_size=batch_size,
             shuffle=shuffle,
             drop_last=drop_last,
-            collate_fn=batch_collate_fn)
+            collate_fn=batch_collate_fn,
+            worker_init_fn=init_fn)
     else:
         data_loader = DataLoader(
             dataset=dataset,
@@ -150,7 +180,8 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
             return_list=True,
             use_shared_memory=use_shared_memory,
             batch_sampler=batch_sampler,
-            collate_fn=batch_collate_fn)
+            collate_fn=batch_collate_fn,
+            worker_init_fn=init_fn)
 
     logger.debug("build data_loader({}) success...".format(data_loader))
     return data_loader
```
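When a seed is given, the `partial` object built above is what the DataLoader invokes in each worker subprocess, passing only the `worker_id`; the other arguments are frozen up front. A minimal single-machine sketch of that wiring (the toy dataset and all values are made up, and `rank` is fixed to 0 here instead of being taken from `paddle.distributed`):

```python
import random
from functools import partial

import numpy as np
from paddle.io import DataLoader, Dataset


def worker_init_fn(worker_id, num_workers, rank, seed):
    # Same formula as in ppcls/data/__init__.py above.
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)


class NoisyDataset(Dataset):
    """Toy dataset whose samples come from numpy's RNG, standing in
    for random data augmentation."""

    def __getitem__(self, idx):
        return np.random.rand(3).astype("float32")

    def __len__(self):
        return 16


if __name__ == "__main__":
    num_workers = 2
    # The DataLoader only passes worker_id to worker_init_fn, so the
    # remaining arguments are frozen with functools.partial, as in
    # build_dataloader.
    init_fn = partial(
        worker_init_fn, num_workers=num_workers, rank=0, seed=42)
    loader = DataLoader(
        NoisyDataset(),
        batch_size=4,
        num_workers=num_workers,
        worker_init_fn=init_fn)
    for batch in loader:
        print(batch)
```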
ppcls/data/dataloader/pk_sampler.py

```diff
@@ -38,6 +38,7 @@ class PKSampler(DistributedBatchSampler):
         ratio(list): list of (ratio1, ratio2..) the duplication number for ids in id_list.
         drop_last (bool, optional): whether to discard the data at the end. Defaults to True.
         sample_method (str, optional): sample method when generating prob_list. Defaults to "sample_avg_prob".
+        total_epochs (int, optional): total epochs. Defaults to 0.
     """
 
     def __init__(self,
@@ -48,7 +49,8 @@ class PKSampler(DistributedBatchSampler):
                  drop_last=True,
                  id_list=None,
                  ratio=None,
-                 sample_method="sample_avg_prob"):
+                 sample_method="sample_avg_prob",
+                 total_epochs=0):
         super().__init__(
             dataset, batch_size, shuffle=shuffle, drop_last=drop_last)
         assert batch_size % sample_per_id == 0, \
@@ -58,6 +60,7 @@ class PKSampler(DistributedBatchSampler):
         self.sample_per_id = sample_per_id
         self.label_dict = defaultdict(list)
         self.sample_method = sample_method
+        self.total_epochs = total_epochs
         for idx, label in enumerate(self.dataset.labels):
             self.label_dict[label].append(idx)
         self.label_list = list(self.label_dict)
@@ -98,8 +101,9 @@ class PKSampler(DistributedBatchSampler):
     def __iter__(self):
         # shuffle manually, same as DistributedBatchSampler.__iter__
         if self.shuffle:
-            np.random.RandomState(self.epoch + dist.get_rank()).shuffle(
-                self.label_list)
+            rank = dist.get_rank()
+            np.random.RandomState(rank * self.total_epochs +
+                                  self.epoch).shuffle(self.label_list)
             self.epoch += 1
 
         label_per_batch = self.batch_size // self.sample_per_id
```
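The old shuffle seed `self.epoch + dist.get_rank()` can collide across replicas (rank 0 at epoch 1 and rank 1 at epoch 0 shuffle identically), whereas `rank * self.total_epochs + self.epoch` is unique for every (rank, epoch) pair as long as `epoch < total_epochs`. A standalone check with made-up values:

```python
# Compare the two shuffle-seed schemes from PKSampler.__iter__
# (4 replicas and 10 epochs are made-up values).
total_epochs, num_ranks = 10, 4

old = [epoch + rank
       for rank in range(num_ranks) for epoch in range(total_epochs)]
new = [rank * total_epochs + epoch
       for rank in range(num_ranks) for epoch in range(total_epochs)]

print(len(old), len(set(old)))  # 40 13 -> many (rank, epoch) pairs collide
print(len(new), len(set(new)))  # 40 40 -> every pair gets its own seed
```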
ppcls/engine/engine.py

```diff
@@ -119,6 +119,9 @@ class Engine(object):
         #TODO(gaotingquan): support rec
         class_num = config["Arch"].get("class_num", None)
         self.config["DataLoader"].update({"class_num": class_num})
+        self.config["DataLoader"].update({
+            "epochs": self.config["Global"]["epochs"]
+        })
 
         # build dataloader
         if self.mode == 'train':
```
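For context, the updated `DataLoader` section is where `build_dataloader` later reads `epochs` from. A minimal sketch of that flow with a made-up config (the keys follow the PaddleClas YAML layout, the values are purely illustrative):

```python
# The engine copies Global.epochs into the DataLoader config so that
# build_dataloader can forward it as total_epochs to samplers that
# accept it (e.g. PKSampler). Values below are illustrative only.
config = {
    "Global": {"epochs": 120},
    "Arch": {"class_num": 751},
    "DataLoader": {"Train": {"sampler": {"name": "PKSampler"}}},
}

config["DataLoader"].update({"class_num": config["Arch"].get("class_num")})
config["DataLoader"].update({"epochs": config["Global"]["epochs"]})

assert config["DataLoader"]["epochs"] == 120
```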