Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSeg
提交
8c22455e
P
PaddleSeg
项目概览
PaddlePaddle
/
PaddleSeg
通知
285
Star
8
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSeg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8c22455e
编写于
6月 17, 2020
作者:
C
chenguowei01
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update train.py
上级
394a34e0
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
1 addition
and
111 deletion
+1
-111
dygraph/train.py
dygraph/train.py
+1
-3
dygraph/utils/distributed.py
dygraph/utils/distributed.py
+0
-108
未找到文件。
dygraph/train.py
浏览文件 @
8c22455e
...
@@ -15,11 +15,10 @@
...
@@ -15,11 +15,10 @@
import
argparse
import
argparse
import
os
import
os
from
paddle.fluid.dygraph.base
import
to_variable
import
numpy
as
np
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph.parallel
import
ParallelEnv
from
paddle.fluid.dygraph.parallel
import
ParallelEnv
from
paddle.fluid.io
import
DataLoader
from
paddle.fluid.io
import
DataLoader
from
paddle.incubate.hapi.distributed
import
DistributedBatchSampler
from
datasets
import
OpticDiscSeg
from
datasets
import
OpticDiscSeg
import
transforms
as
T
import
transforms
as
T
...
@@ -27,7 +26,6 @@ import models
...
@@ -27,7 +26,6 @@ import models
import
utils.logging
as
logging
import
utils.logging
as
logging
from
utils
import
get_environ_info
from
utils
import
get_environ_info
from
utils
import
load_pretrained_model
from
utils
import
load_pretrained_model
from
utils
import
DistributedBatchSampler
from
val
import
evaluate
from
val
import
evaluate
...
...
dygraph/utils/distributed.py
已删除
100644 → 0
浏览文件 @
394a34e0
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
math
import
numpy
as
np
from
paddle.fluid.dygraph.parallel
import
ParallelEnv
from
paddle.fluid.dataloader
import
BatchSampler
_parallel_context_initialized
=
False
class
DistributedBatchSampler
(
BatchSampler
):
"""Sampler that restricts data loading to a subset of the dataset.
In such case, each process can pass a DistributedBatchSampler instance
as a DataLoader sampler, and load a subset of the original dataset that
is exclusive to it.
.. note::
Dataset is assumed to be of constant size.
Args:
data_source: this could be a `paddle.io.Dataset` implement
or other python object which implemented
`__len__` for BatchSampler to get sample
number of data source.
batch_size(int): sample indice number in a mini-batch indices.
shuffle(bool): whther to shuffle indices order before genrating
batch indices. Default False.
drop_last(bool): whether drop the last incomplete batch dataset size
is not divisible by the batch size. Default False
"""
def
__init__
(
self
,
dataset
,
batch_size
,
shuffle
=
False
,
drop_last
=
False
):
self
.
dataset
=
dataset
assert
isinstance
(
batch_size
,
int
)
and
batch_size
>
0
,
\
"batch_size should be a positive integer"
self
.
batch_size
=
batch_size
assert
isinstance
(
shuffle
,
bool
),
\
"shuffle should be a boolean value"
self
.
shuffle
=
shuffle
assert
isinstance
(
drop_last
,
bool
),
\
"drop_last should be a boolean number"
self
.
drop_last
=
drop_last
self
.
nranks
=
ParallelEnv
().
nranks
self
.
local_rank
=
ParallelEnv
().
local_rank
self
.
num_samples
=
int
(
math
.
ceil
(
len
(
self
.
dataset
)
*
1.0
/
self
.
nranks
))
self
.
total_size
=
self
.
num_samples
*
self
.
nranks
def
__iter__
(
self
):
num_samples
=
len
(
self
.
dataset
)
indices
=
np
.
arange
(
num_samples
).
tolist
()
indices
+=
indices
[:(
self
.
total_size
-
len
(
indices
))]
assert
len
(
indices
)
==
self
.
total_size
if
self
.
shuffle
:
np
.
random
.
shuffle
(
indices
)
# subsample
def
_get_indices_by_batch_size
(
indices
):
subsampled_indices
=
[]
last_batch_size
=
self
.
total_size
%
(
self
.
batch_size
*
self
.
nranks
)
assert
last_batch_size
%
self
.
nranks
==
0
last_local_batch_size
=
last_batch_size
//
self
.
nranks
for
i
in
range
(
self
.
local_rank
*
self
.
batch_size
,
len
(
indices
)
-
last_batch_size
,
self
.
batch_size
*
self
.
nranks
):
subsampled_indices
.
extend
(
indices
[
i
:
i
+
self
.
batch_size
])
indices
=
indices
[
len
(
indices
)
-
last_batch_size
:]
subsampled_indices
.
extend
(
indices
[
self
.
local_rank
*
last_local_batch_size
:
(
self
.
local_rank
+
1
)
*
last_local_batch_size
])
return
subsampled_indices
if
self
.
nranks
>
1
:
indices
=
_get_indices_by_batch_size
(
indices
)
assert
len
(
indices
)
==
self
.
num_samples
_sample_iter
=
iter
(
indices
)
batch_indices
=
[]
for
idx
in
_sample_iter
:
batch_indices
.
append
(
idx
)
if
len
(
batch_indices
)
==
self
.
batch_size
:
yield
batch_indices
batch_indices
=
[]
if
not
self
.
drop_last
and
len
(
batch_indices
)
>
0
:
yield
batch_indices
def
__len__
(
self
):
num_samples
=
self
.
num_samples
num_samples
+=
int
(
not
self
.
drop_last
)
*
(
self
.
batch_size
-
1
)
return
num_samples
//
self
.
batch_size
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录