Commit 0d362bb9 in PaddlePaddle/PaddleSeg
Authored Jun 15, 2020 by chenguowei01
Parent: 9452bf66

    add dataloader

Showing 4 changed files with 181 additions and 29 deletions:

    dygraph/datasets/optic_disc_seg.py   +3    -7
    dygraph/train.py                     +40   -22
    dygraph/utils/__init__.py            +1    -0
    dygraph/utils/distributed.py         +137  -0   (new file)
dygraph/datasets/optic_disc_seg.py

@@ -31,13 +31,11 @@ class OpticDiscSeg(Dataset):
                  train_list=None,
                  val_list=None,
                  test_list=None,
-                 shuffle='False',
+                 transforms=None,
                  mode='train',
-                 transform=None,
                  download=True):
         self.data_dir = data_dir
-        self.shuffle = shuffle
-        self.transform = transform
+        self.transforms = transforms
         self.file_list = list()
         if mode.lower() not in ['train', 'eval', 'test']:

@@ -45,7 +43,7 @@ class OpticDiscSeg(Dataset):
                 "mode should be 'train', 'eval' or 'test', but got {}.".format(mode))
-        if transform is None:
+        if self.transforms is None:
             raise Exception("transform is necessary, but it is None.")
         self.data_dir = data_dir

@@ -83,8 +81,6 @@ class OpticDiscSeg(Dataset):
                 image_path = os.path.join(self.data_dir, items[0])
                 grt_path = os.path.join(self.data_dir, items[1])
                 self.file_list.append([image_path, grt_path])
-        if shuffle:
-            random.shuffle(self.file_list)

     def __getitem__(self, idx):
         print(idx)
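The net effect of this hunk: OpticDiscSeg drops the per-sample transform and the dataset-level shuffle argument in favor of a transforms pipeline plus a mode string, leaving shuffling to the batch sampler introduced below. (The print(idx) left in __getitem__ looks like a stray debug statement.) A minimal construction sketch, mirroring the call that train.py makes after this commit; the transform names are taken from the train.py hunks below:

import transforms as T          # PaddleSeg's dygraph transforms module
from datasets import OpticDiscSeg

# Build the preprocessing pipeline once and hand it to the dataset;
# shuffling is no longer the dataset's job.
train_transforms = T.Compose([
    T.RandomHorizontalFlip(),
    T.Normalize()
])
train_dataset = OpticDiscSeg(transforms=train_transforms, mode='train')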
dygraph/train.py

@@ -18,13 +18,16 @@ import os
 from paddle.fluid.dygraph.base import to_variable
 import numpy as np
 import paddle.fluid as fluid
+from paddle.fluid.dygraph.parallel import ParallelEnv
+from paddle.fluid.io import DataLoader

-from datasets import Dataset
+from datasets import OpticDiscSeg, Dataset
 import transforms as T
 import models
 import utils.logging as logging
 from utils import get_environ_info
 from utils import load_pretrained_model
+from utils import DistributedBatchSampler
 from val import evaluate

@@ -95,12 +98,18 @@ def parse_args():
         help='The directory for saving the model snapshot',
         type=str,
         default='./output')
+    parser.add_argument(
+        '--num_workers',
+        dest='num_workers',
+        help='Num workers for data loader',
+        type=int,
+        default=0)
     return parser.parse_args()

 def train(model,
           train_dataset,
+          places=None,
           eval_dataset=None,
           optimizer=None,
           save_dir='output',

@@ -108,7 +117,8 @@ def train(model,
           batch_size=2,
           pretrained_model=None,
           save_interval_epochs=1,
-          num_classes=None):
+          num_classes=None,
+          num_workers=8):
     if not os.path.isdir(save_dir):
         if os.path.exists(save_dir):
             os.remove(save_dir)

@@ -116,12 +126,22 @@ def train(model,
         load_pretrained_model(model, pretrained_model)

-    data_generator = train_dataset.generator(
-        batch_size=batch_size, drop_last=True)
-    num_steps_each_epoch = train_dataset.num_samples // args.batch_size
+    batch_sampler = DistributedBatchSampler(
+        train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
+    loader = DataLoader(
+        train_dataset,
+        batch_sampler=batch_sampler,
+        places=places,
+        num_workers=num_workers,
+        return_list=True,
+    )
+    num_steps_each_epoch = len(train_dataset) // batch_size

     for epoch in range(num_epochs):
-        for step, data in enumerate(data_generator()):
+        for step, data in enumerate(loader):
             images = np.array([d[0] for d in data])
             labels = np.array([d[2] for d in data]).astype('int64')
             images = to_variable(images)
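Taken together, these hunks swap the old generator() pipeline for a sampler-plus-loader pair: DistributedBatchSampler decides which indices each process reads, and paddle.fluid.io.DataLoader performs the (optionally multi-worker) loading. A self-contained sketch of the same wiring, assuming single-process CPU execution; ToyDataset and all of its shapes are invented for illustration, while the real code passes OpticDiscSeg and the places chosen in main():

import numpy as np
import paddle.fluid as fluid
from paddle.fluid.io import DataLoader

from utils import DistributedBatchSampler  # the module added by this commit

class ToyDataset(object):
    # Stand-in for OpticDiscSeg: anything with __len__/__getitem__ works.
    def __len__(self):
        return 16

    def __getitem__(self, idx):
        image = np.random.rand(3, 32, 32).astype('float32')
        label = np.random.randint(0, 2, (32, 32)).astype('int64')
        return image, label

places = fluid.CPUPlace()
with fluid.dygraph.guard(places):
    dataset = ToyDataset()
    # Shuffling now lives in the sampler, not in the dataset. Note that
    # shuffle=True trips over a bug in this revision; see the note after
    # the distributed.py listing below.
    batch_sampler = DistributedBatchSampler(
        dataset, batch_size=4, shuffle=False, drop_last=True)
    loader = DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        places=places,
        num_workers=0,       # >0 forks worker processes for loading
        return_list=True)
    for step, batch in enumerate(loader):
        # The batch layout follows what __getitem__ returns.
        break

Under a multi-GPU launch, ParallelEnv() reports nranks > 1 and each process iterates a disjoint slice of the indices, so no two ranks see the same sample within an epoch.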
@@ -156,6 +176,11 @@ def train(model,

 def main(args):
     env_info = get_environ_info()
+    places = fluid.CUDAPlace(ParallelEnv().dev_id) \
+        if env_info['place'] == 'gpu' and fluid.is_compiled_with_cuda() \
+        else fluid.CPUPlace()
+
+    with fluid.dygraph.guard(places):
     # Creat dataset reader
     train_transforms = T.Compose([

@@ -163,13 +188,8 @@ def main(args):
         T.RandomHorizontalFlip(),
         T.Normalize()
     ])
-    train_dataset = Dataset(
-        data_dir=args.data_dir,
-        file_list=args.train_list,
-        transforms=train_transforms,
-        num_workers='auto',
-        buffer_size=100,
-        parallel_method='thread',
-        shuffle=True)
+    train_dataset = OpticDiscSeg(transforms=train_transforms, mode='train')

     if args.val_list is not None:
         eval_transforms = T.Compose(
             [T.Resize(args.input_size),

@@ -186,7 +206,7 @@ def main(args):
     model = models.UNet(num_classes=args.num_classes, ignore_index=255)

     # Creat optimizer
-    num_steps_each_epoch = train_dataset.num_samples // args.batch_size
+    num_steps_each_epoch = len(train_dataset) // args.batch_size
     decay_step = args.num_epochs * num_steps_each_epoch
     lr_decay = fluid.layers.polynomial_decay(
         args.learning_rate, decay_step,

@@ -200,21 +220,19 @@ def main(args):
     train(
         model,
         train_dataset,
-        eval_dataset,
-        optimizer,
+        places=places,
+        eval_dataset=eval_dataset,
+        optimizer=optimizer,
         save_dir=args.save_dir,
         num_epochs=args.num_epochs,
         batch_size=args.batch_size,
         pretrained_model=args.pretrained_model,
         save_interval_epochs=args.save_interval_epochs,
-        num_classes=args.num_classes)
+        num_classes=args.num_classes,
+        num_workers=args.num_workers)

 if __name__ == '__main__':
     args = parse_args()
     env_info = get_environ_info()
-    if env_info['place'] == 'cpu':
-        places = fluid.CPUPlace()
-    else:
-        places = fluid.CUDAPlace(0)
     print(args)
     main(args)
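Two details in these hunks go beyond renaming. First, steps per epoch are now computed as len(train_dataset) // args.batch_size rather than train_dataset.num_samples // args.batch_size, matching the Dataset protocol that DataLoader expects; the same change inside train() also stops it from reading the module-level args. Second, the device is chosen per process via fluid.CUDAPlace(ParallelEnv().dev_id), so when the script runs under Paddle's multi-process launcher (something like python -m paddle.distributed.launch train.py; exact flags vary across Paddle versions) each worker binds its own GPU instead of every process landing on CUDAPlace(0).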
dygraph/utils/__init__.py

@@ -16,3 +16,4 @@ from . import logging
 from . import download
 from .metrics import ConfusionMatrix
 from .utils import *
+from .distributed import DistributedBatchSampler
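This re-export is what lets train.py write from utils import DistributedBatchSampler instead of spelling out the utils.distributed path.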
dygraph/utils/distributed.py  (new file, mode 100644)

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import numpy as np
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.dataloader import BatchSampler

_parallel_context_initialized = False


class DistributedBatchSampler(BatchSampler):
    """Sampler that restricts data loading to a subset of the dataset.

    In such case, each process can pass a DistributedBatchSampler instance
    as a DataLoader sampler, and load a subset of the original dataset that
    is exclusive to it.

    .. note::
        Dataset is assumed to be of constant size.

    Args:
        data_source: this could be a `paddle.io.Dataset` implement
                     or other python object which implemented
                     `__len__` for BatchSampler to get sample
                     number of data source.
        batch_size(int): sample indice number in a mini-batch indices.
        shuffle(bool): whther to shuffle indices order before genrating
            batch indices. Default False.
        drop_last(bool): whether drop the last incomplete batch dataset size
            is not divisible by the batch size. Default False

    Examples:
        .. code-block:: python

            import numpy as np

            from hapi.datasets import MNIST
            from hapi.distributed import DistributedBatchSampler

            class MnistDataset(MNIST):
                def __init__(self, mode, return_label=True):
                    super(MnistDataset, self).__init__(mode=mode)
                    self.return_label = return_label

                def __getitem__(self, idx):
                    img = np.reshape(self.images[idx], [1, 28, 28])
                    if self.return_label:
                        return img, np.array(self.labels[idx]).astype('int64')
                    return img,

                def __len__(self):
                    return len(self.images)

            train_dataset = MnistDataset(mode='train')
            dist_train_dataloader = DistributedBatchSampler(train_dataset, batch_size=64)

            for data in dist_train_dataloader:
                # do something
                break
    """

    def __init__(self, dataset, batch_size, shuffle=False, drop_last=False):
        self.dataset = dataset

        assert isinstance(batch_size, int) and batch_size > 0, \
            "batch_size should be a positive integer"
        self.batch_size = batch_size
        assert isinstance(shuffle, bool), \
            "shuffle should be a boolean value"
        self.shuffle = shuffle
        assert isinstance(drop_last, bool), \
            "drop_last should be a boolean number"
        self.drop_last = drop_last

        self.nranks = ParallelEnv().nranks
        self.local_rank = ParallelEnv().local_rank
        self.num_samples = int(
            math.ceil(len(self.dataset) * 1.0 / self.nranks))
        self.total_size = self.num_samples * self.nranks

    def __iter__(self):
        num_samples = len(self.dataset)
        indices = np.arange(num_samples).tolist()
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size
        if self.shuffle:
            np.random.RandomState.shuffle(indices)

        # subsample
        def _get_indices_by_batch_size(indices):
            subsampled_indices = []
            last_batch_size = self.total_size % (self.batch_size * self.nranks)
            assert last_batch_size % self.nranks == 0
            last_local_batch_size = last_batch_size // self.nranks

            for i in range(self.local_rank * self.batch_size,
                           len(indices) - last_batch_size,
                           self.batch_size * self.nranks):
                subsampled_indices.extend(indices[i:i + self.batch_size])

            indices = indices[len(indices) - last_batch_size:]
            subsampled_indices.extend(
                indices[self.local_rank * last_local_batch_size:(
                    self.local_rank + 1) * last_local_batch_size])
            return subsampled_indices

        if self.nranks > 1:
            indices = _get_indices_by_batch_size(indices)

        assert len(indices) == self.num_samples
        _sample_iter = iter(indices)

        batch_indices = []
        for idx in _sample_iter:
            batch_indices.append(idx)
            if len(batch_indices) == self.batch_size:
                yield batch_indices
                batch_indices = []
        if not self.drop_last and len(batch_indices) > 0:
            yield batch_indices

    def __len__(self):
        num_samples = self.num_samples
        num_samples += int(not self.drop_last) * (self.batch_size - 1)
        return num_samples // self.batch_size
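One caveat in the new file: np.random.RandomState.shuffle(indices) calls shuffle as an unbound method on the RandomState class, passing the index list where the instance belongs, so with shuffle=True this raises a TypeError at runtime rather than shuffling. A standalone sketch of the presumable intent, with the call fixed by instantiating the generator first; the shared epoch seed is an assumption (the hapi sampler this docstring points to tracks such a counter), not something this commit implements:

import math

import numpy as np

def padded_shuffled_indices(dataset_len, nranks, epoch, shuffle=True):
    # Mirror of the pad-then-shuffle logic in __iter__. All ranks must
    # seed identically (here: the hypothetical shared epoch counter) so
    # they agree on one permutation before carving out their slices.
    num_samples = int(math.ceil(dataset_len * 1.0 / nranks))
    total_size = num_samples * nranks
    indices = np.arange(dataset_len).tolist()
    indices += indices[:(total_size - len(indices))]  # pad to a multiple of nranks
    if shuffle:
        np.random.RandomState(epoch).shuffle(indices)  # instantiated, unlike the original
    return indices

print(len(padded_shuffled_indices(10, 4, epoch=0)))  # 12

The padding arithmetic: with 10 samples over 4 ranks, num_samples = ceil(10 / 4) = 3 and total_size = 12, so two indices are duplicated and every rank sees exactly 3 samples. __len__ applies the same ceiling per rank: with batch_size = 2, it reports (3 + 1) // 2 = 2 batches when drop_last is False and 3 // 2 = 1 when it is True.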