PaddlePaddle / PaddleSeg

Commit 8c22455e
Authored Jun 17, 2020 by chenguowei01
Parent: 394a34e0

update train.py
Showing 2 changed files with 1 addition and 111 deletions (+1 / -111)

    dygraph/train.py                 +1    -3
    dygraph/utils/distributed.py     +0    -108
dygraph/train.py (view file @ 8c22455e)
@@ -15,11 +15,10 @@
import argparse
import os
from paddle.fluid.dygraph.base import to_variable
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import DataLoader
from paddle.incubate.hapi.distributed import DistributedBatchSampler
from datasets import OpticDiscSeg
import transforms as T
@@ -27,7 +26,6 @@ import models
import utils.logging as logging
from utils import get_environ_info
from utils import load_pretrained_model
from utils import DistributedBatchSampler
from val import evaluate
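Note: the diff's added/removed line markers were not preserved in this page capture. A hedged reading, based on the per-file counts (+1 / -3 for dygraph/train.py) and the deletion of dygraph/utils/distributed.py below, is that the commit replaces the project-local sampler import with the framework-provided one:

# Assumed net effect on dygraph/train.py's imports (reconstruction, not a verbatim diff):
# removed: from utils import DistributedBatchSampler
# added:   from paddle.incubate.hapi.distributed import DistributedBatchSampler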
dygraph/utils/distributed.py (deleted, file mode 100644 → 0, view file @ 394a34e0)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import numpy as np
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.dataloader import BatchSampler

_parallel_context_initialized = False


class DistributedBatchSampler(BatchSampler):
    """Sampler that restricts data loading to a subset of the dataset.

    In such case, each process can pass a DistributedBatchSampler instance
    as a DataLoader sampler, and load a subset of the original dataset that
    is exclusive to it.

    .. note::
        Dataset is assumed to be of constant size.

    Args:
        data_source: this could be a `paddle.io.Dataset` implement
            or other python object which implemented
            `__len__` for BatchSampler to get sample
            number of data source.
        batch_size(int): sample indice number in a mini-batch indices.
        shuffle(bool): whether to shuffle indices order before generating
            batch indices. Default False.
        drop_last(bool): whether to drop the last incomplete batch when the
            dataset size is not divisible by the batch size. Default False.
    """

    def __init__(self, dataset, batch_size, shuffle=False, drop_last=False):
        self.dataset = dataset

        assert isinstance(batch_size, int) and batch_size > 0, \
            "batch_size should be a positive integer"
        self.batch_size = batch_size
        assert isinstance(shuffle, bool), \
            "shuffle should be a boolean value"
        self.shuffle = shuffle
        assert isinstance(drop_last, bool), \
            "drop_last should be a boolean value"
        self.drop_last = drop_last

        self.nranks = ParallelEnv().nranks
        self.local_rank = ParallelEnv().local_rank
        # Each rank handles ceil(len(dataset) / nranks) samples; the index
        # list is padded so every rank sees the same number of samples.
        self.num_samples = int(
            math.ceil(len(self.dataset) * 1.0 / self.nranks))
        self.total_size = self.num_samples * self.nranks

    def __iter__(self):
        num_samples = len(self.dataset)
        indices = np.arange(num_samples).tolist()
        # Pad with leading indices so that len(indices) == total_size.
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size
        if self.shuffle:
            np.random.shuffle(indices)

        # subsample
        def _get_indices_by_batch_size(indices):
            subsampled_indices = []
            last_batch_size = self.total_size % (self.batch_size * self.nranks)
            assert last_batch_size % self.nranks == 0
            last_local_batch_size = last_batch_size // self.nranks

            for i in range(self.local_rank * self.batch_size,
                           len(indices) - last_batch_size,
                           self.batch_size * self.nranks):
                subsampled_indices.extend(indices[i:i + self.batch_size])

            indices = indices[len(indices) - last_batch_size:]
            subsampled_indices.extend(
                indices[self.local_rank * last_local_batch_size:(
                    self.local_rank + 1) * last_local_batch_size])
            return subsampled_indices

        if self.nranks > 1:
            indices = _get_indices_by_batch_size(indices)

        assert len(indices) == self.num_samples
        _sample_iter = iter(indices)

        batch_indices = []
        for idx in _sample_iter:
            batch_indices.append(idx)
            if len(batch_indices) == self.batch_size:
                yield batch_indices
                batch_indices = []
        if not self.drop_last and len(batch_indices) > 0:
            yield batch_indices

    def __len__(self):
        num_samples = self.num_samples
        num_samples += int(not self.drop_last) * (self.batch_size - 1)
        return num_samples // self.batch_size
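For reference, a minimal sketch (not part of the commit) of how this sampler yields per-rank index batches on a single-rank run; ToyDataset is a hypothetical stand-in, since the sampler only needs __len__ from its dataset argument:

class ToyDataset:
    """Hypothetical dataset; the sampler only calls len() on it."""
    def __len__(self):
        return 10

# With no distributed environment set up, ParallelEnv() reports nranks == 1,
# so no per-rank subsampling takes place.
sampler = DistributedBatchSampler(ToyDataset(), batch_size=4, shuffle=False)
for batch in sampler:
    print(batch)  # [0, 1, 2, 3], then [4, 5, 6, 7], then [8, 9] (drop_last=False)

# In dygraph/train.py this sampler is presumably handed to DataLoader so that
# each process loads only its own shard of the dataset.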