BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle; in sync with the fork source)
Commit e87436a5 (unverified)
Authored Aug 16, 2020 by Kaipeng Deng
Committed by GitHub on Aug 16, 2020
DistributedBatchSampler add num_replicas and rank. test=develop (#26315)
Parent: 241b44db
Showing 2 changed files with 45 additions and 6 deletions
python/paddle/incubate/hapi/distributed.py (+28 -3)
python/paddle/incubate/hapi/tests/test_model.py (+17 -3)
python/paddle/incubate/hapi/distributed.py @ e87436a5
...
...
@@ -49,6 +49,13 @@ class DistributedBatchSampler(BatchSampler):
             `__len__` for BatchSampler to get sample
             number of data source.
         batch_size(int): number of sample indices in a mini-batch.
+        num_replicas(int, optional): process number in distributed training.
+            If :attr:`num_replicas` is None, :attr:`num_replicas` will be
+            retrieved from :code:`paddle.fluid.dygraph.parallel.ParallelEnv`.
+            Default None.
+        rank(int, optional): the rank of the current process among :attr:`num_replicas`
+            processes. If :attr:`rank` is None, :attr:`rank` is retrieved from
+            :code:`paddle.fluid.dygraph.parallel.ParallelEnv`. Default None.
         shuffle(bool): whether to shuffle indices order before generating
             batch indices. Default False.
         drop_last(bool): whether drop the last incomplete batch dataset size
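The fallback rule described in the docstring above can be sketched without Paddle installed. This is a minimal illustration, not the library's code: `StubParallelEnv` and `resolve_replicas_and_rank` are hypothetical names introduced here, and the stub's values (one process, rank 0) are assumptions standing in for whatever `ParallelEnv` would report at runtime.

```python
class StubParallelEnv:
    """Hypothetical stand-in for ParallelEnv; not part of Paddle."""
    nranks = 1       # assumed: a single-process environment
    local_rank = 0   # assumed: this process is rank 0


def resolve_replicas_and_rank(num_replicas=None, rank=None, env=StubParallelEnv):
    """Return (nranks, local_rank), preferring explicit arguments over env."""
    if num_replicas is not None:
        assert isinstance(num_replicas, int) and num_replicas > 0, \
            "num_replicas should be a positive integer"
        nranks = num_replicas
    else:
        # No explicit value: fall back to the environment.
        nranks = env.nranks

    if rank is not None:
        assert isinstance(rank, int) and rank >= 0, \
            "rank should be a non-negative integer"
        local_rank = rank
    else:
        # No explicit value: fall back to the environment.
        local_rank = env.local_rank

    return nranks, local_rank


# Explicit values win; omitted values come from the (stubbed) environment.
explicit = resolve_replicas_and_rank(num_replicas=2, rank=0)
fallback = resolve_replicas_and_rank()
```

Passing both values explicitly is what lets a single process behave as one replica of a larger job, which is how the new tests in this commit call the sampler.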
...
...
@@ -84,7 +91,13 @@ class DistributedBatchSampler(BatchSampler):
                 break
     """

-    def __init__(self, dataset, batch_size, shuffle=False, drop_last=False):
+    def __init__(self,
+                 dataset,
+                 batch_size,
+                 num_replicas=None,
+                 rank=None,
+                 shuffle=False,
+                 drop_last=False):
         self.dataset = dataset

         assert isinstance(batch_size, int) and batch_size > 0, \
...
...
@@ -96,9 +109,21 @@ class DistributedBatchSampler(BatchSampler):
         assert isinstance(drop_last, bool), \
             "drop_last should be a boolean number"

+        if num_replicas is not None:
+            assert isinstance(num_replicas, int) and num_replicas > 0, \
+                "num_replicas should be a positive integer"
+            self.nranks = num_replicas
+        else:
+            self.nranks = ParallelEnv().nranks
+
+        if rank is not None:
+            assert isinstance(rank, int) and rank >= 0, \
+                "rank should be a non-negative integer"
+            self.local_rank = rank
+        else:
+            self.local_rank = ParallelEnv().local_rank
+
         self.drop_last = drop_last
-        self.nranks = ParallelEnv().nranks
-        self.local_rank = ParallelEnv().local_rank
         self.epoch = 0
         self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.nranks))
         self.total_size = self.num_samples * self.nranks
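To make the last two assignments concrete, the arithmetic can be reproduced in isolation. This is a small sketch: `shard_sizes` is a hypothetical helper name, and the dataset length and replica count are made-up example values.

```python
import math


def shard_sizes(dataset_len, nranks):
    """Mirror the num_samples / total_size arithmetic from the constructor."""
    # Each replica gets the ceiling of an even split.
    num_samples = int(math.ceil(dataset_len * 1.0 / nranks))
    # The combined size across replicas can exceed dataset_len.
    total_size = num_samples * nranks
    return num_samples, total_size


# Example values only: 10 samples shared by 3 replicas.
num_samples, total_size = shard_sizes(10, 3)
print(num_samples, total_size)
```

When the dataset length is not divisible by the replica count, `total_size` is larger than the dataset. How the extra index positions are filled is decided in code outside the hunks shown here.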
...
...
python/paddle/incubate/hapi/tests/test_model.py @ e87436a5
...
...
@@ -169,6 +169,12 @@ class TestModel(unittest.TestCase):
     def test_fit_static(self):
         self.fit(False)

+    def test_fit_dynamic_with_rank(self):
+        self.fit(True, 2, 0)
+
+    def test_fit_static_with_rank(self):
+        self.fit(False, 2, 0)
+
     def test_evaluate_dygraph(self):
         self.evaluate(True)
...
...
@@ -184,7 +190,7 @@ class TestModel(unittest.TestCase):
     def test_prepare_context(self):
         prepare_distributed_context()

-    def fit(self, dynamic):
+    def fit(self, dynamic, num_replicas=None, rank=None):
         fluid.enable_dygraph(self.device) if dynamic else None
         seed = 333
         fluid.default_startup_program().random_seed = seed
...
...
@@ -204,9 +210,17 @@ class TestModel(unittest.TestCase):
         np.testing.assert_allclose(result['acc'], self.acc1)

         train_sampler = DistributedBatchSampler(
-            self.train_dataset, batch_size=64, shuffle=False)
+            self.train_dataset,
+            batch_size=64,
+            shuffle=False,
+            num_replicas=num_replicas,
+            rank=rank)
         val_sampler = DistributedBatchSampler(
-            self.val_dataset, batch_size=64, shuffle=False)
+            self.val_dataset,
+            batch_size=64,
+            shuffle=False,
+            num_replicas=num_replicas,
+            rank=rank)

         train_loader = fluid.io.DataLoader(
             self.train_dataset,
...
...