Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSlim
提交
b6ba314d
P
PaddleSlim
项目概览
PaddlePaddle
/
PaddleSlim
大约 1 年 前同步成功
通知
51
Star
1434
Fork
344
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
16
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSlim
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
16
合并请求
16
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b6ba314d
编写于
5月 09, 2020
作者:
B
Bai Yifan
提交者:
GitHub
5月 09, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Cherry-Pick]Refine reader config (#267)
* refine reader config
上级
a82c9df4
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
24 addition
and
50 deletion
+24
-50
demo/darts/README.md
demo/darts/README.md
+1
-0
demo/darts/reader.py
demo/darts/reader.py
+5
-37
demo/darts/search.py
demo/darts/search.py
+2
-2
demo/darts/train.py
demo/darts/train.py
+5
-6
docs/zh_cn/api_cn/darts.rst
docs/zh_cn/api_cn/darts.rst
+4
-1
paddleslim/nas/darts/train_search.py
paddleslim/nas/darts/train_search.py
+7
-4
未找到文件。
demo/darts/README.md
浏览文件 @
b6ba314d
...
...
@@ -42,6 +42,7 @@ python search.py # DARTS一阶近似搜索方法
python search.py
--unrolled
=
True
# DARTS的二阶近似搜索方法
python search.py
--method
=
'PC-DARTS'
--batch_size
=
256
--learning_rate
=
0.1
--arch_learning_rate
=
6e-4
--epochs_no_archopt
=
15
# PC-DARTS搜索方法
```
如果您使用的是docker环境,请确保共享内存足够使用多进程的dataloader,如果碰到共享内存问题,请设置
`--use_multiprocess=False`
也可以使用多卡进行模型结构搜索,以4卡为例(GPU id: 0-3), 启动命令如下:
...
...
demo/darts/reader.py
浏览文件 @
b6ba314d
...
...
@@ -140,32 +140,10 @@ def train_search(batch_size, train_portion, is_shuffle, args):
split_point
=
int
(
np
.
floor
(
train_portion
*
len
(
datasets
)))
train_datasets
=
datasets
[:
split_point
]
val_datasets
=
datasets
[
split_point
:]
train_readers
=
[]
val_readers
=
[]
n
=
int
(
math
.
ceil
(
len
(
train_datasets
)
//
args
.
num_workers
)
)
if
args
.
use_multiprocess
else
len
(
train_datasets
)
train_datasets_lists
=
[
train_datasets
[
i
:
i
+
n
]
for
i
in
range
(
0
,
len
(
train_datasets
),
n
)
reader
=
[
reader_generator
(
train_datasets
,
batch_size
,
True
,
True
,
args
),
reader_generator
(
val_datasets
,
batch_size
,
True
,
True
,
args
)
]
val_datasets_lists
=
[
val_datasets
[
i
:
i
+
n
]
for
i
in
range
(
0
,
len
(
val_datasets
),
n
)
]
for
pid
in
range
(
len
(
train_datasets_lists
)):
train_readers
.
append
(
reader_generator
(
train_datasets_lists
[
pid
],
batch_size
,
True
,
True
,
args
))
val_readers
.
append
(
reader_generator
(
val_datasets_lists
[
pid
],
batch_size
,
True
,
True
,
args
))
if
args
.
use_multiprocess
:
reader
=
[
paddle
.
reader
.
multiprocess_reader
(
train_readers
,
False
),
paddle
.
reader
.
multiprocess_reader
(
val_readers
,
False
)
]
else
:
reader
=
[
train_readers
[
0
],
val_readers
[
0
]]
return
reader
...
...
@@ -174,18 +152,8 @@ def train_valid(batch_size, is_train, is_shuffle, args):
datasets
=
cifar10_reader
(
paddle
.
dataset
.
common
.
download
(
CIFAR10_URL
,
'cifar'
,
CIFAR10_MD5
),
name
,
is_shuffle
,
args
)
n
=
int
(
math
.
ceil
(
len
(
datasets
)
//
args
.
num_workers
))
if
args
.
use_multiprocess
else
len
(
datasets
)
datasets_lists
=
[
datasets
[
i
:
i
+
n
]
for
i
in
range
(
0
,
len
(
datasets
),
n
)]
multi_readers
=
[]
for
pid
in
range
(
len
(
datasets_lists
)):
multi_readers
.
append
(
reader_generator
(
datasets_lists
[
pid
],
batch_size
,
is_train
,
is_shuffle
,
args
))
if
args
.
use_multiprocess
:
reader
=
paddle
.
reader
.
multiprocess_reader
(
multi_readers
,
False
)
else
:
reader
=
multi_readers
[
0
]
reader
=
reader_generator
(
datasets
,
batch_size
,
is_train
,
is_shuffle
,
args
)
return
reader
...
...
demo/darts/search.py
浏览文件 @
b6ba314d
...
...
@@ -35,8 +35,7 @@ add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg
(
'log_freq'
,
int
,
50
,
"Log frequency."
)
add_arg
(
'use_multiprocess'
,
bool
,
False
,
"Whether use multiprocess reader."
)
add_arg
(
'num_workers'
,
int
,
4
,
"The multiprocess reader number."
)
add_arg
(
'use_multiprocess'
,
bool
,
True
,
"Whether use multiprocess reader."
)
add_arg
(
'data'
,
str
,
'dataset/cifar10'
,
"The dir of dataset."
)
add_arg
(
'batch_size'
,
int
,
64
,
"Minibatch size."
)
add_arg
(
'learning_rate'
,
float
,
0.025
,
"The start learning rate."
)
...
...
@@ -88,6 +87,7 @@ def main(args):
unrolled
=
args
.
unrolled
,
num_epochs
=
args
.
epochs
,
epochs_no_archopt
=
args
.
epochs_no_archopt
,
use_multiprocess
=
args
.
use_multiprocess
,
use_data_parallel
=
args
.
use_data_parallel
,
save_dir
=
args
.
model_save_dir
,
log_freq
=
args
.
log_freq
)
...
...
demo/darts/train.py
浏览文件 @
b6ba314d
...
...
@@ -39,8 +39,7 @@ parser = argparse.ArgumentParser(description=__doc__)
add_arg
=
functools
.
partial
(
add_arguments
,
argparser
=
parser
)
# yapf: disable
add_arg
(
'use_multiprocess'
,
bool
,
False
,
"Whether use multiprocess reader."
)
add_arg
(
'num_workers'
,
int
,
4
,
"The multiprocess reader number."
)
add_arg
(
'use_multiprocess'
,
bool
,
True
,
"Whether use multiprocess reader."
)
add_arg
(
'data'
,
str
,
'dataset/cifar10'
,
"The dir of dataset."
)
add_arg
(
'batch_size'
,
int
,
96
,
"Minibatch size."
)
add_arg
(
'learning_rate'
,
float
,
0.025
,
"The start learning rate."
)
...
...
@@ -170,17 +169,17 @@ def main(args):
model
=
fluid
.
dygraph
.
parallel
.
DataParallel
(
model
,
strategy
)
train_loader
=
fluid
.
io
.
DataLoader
.
from_generator
(
capacity
=
102
4
,
capacity
=
6
4
,
use_double_buffer
=
True
,
iterable
=
True
,
return_list
=
True
,
use_multiprocess
=
True
)
use_multiprocess
=
args
.
use_multiprocess
)
valid_loader
=
fluid
.
io
.
DataLoader
.
from_generator
(
capacity
=
102
4
,
capacity
=
6
4
,
use_double_buffer
=
True
,
iterable
=
True
,
return_list
=
True
,
use_multiprocess
=
True
)
use_multiprocess
=
args
.
use_multiprocess
)
train_reader
=
reader
.
train_valid
(
batch_size
=
args
.
batch_size
,
...
...
docs/zh_cn/api_cn/darts.rst
浏览文件 @
b6ba314d
...
...
@@ -4,7 +4,7 @@
DARTSearch
---------
..
py
:
class
::
paddleslim
.
nas
.
DARTSearch
(
model
,
train_reader
,
valid_reader
,
place
,
learning_rate
=
0.025
,
batchsize
=
64
,
num_imgs
=
50000
,
arch_learning_rate
=
3e-4
,
unrolled
=
False
,
num_epochs
=
50
,
epochs_no_archopt
=
0
,
use_data_parallel
=
False
,
save_dir
=
'./'
,
log_freq
=
50
)
..
py
:
class
::
paddleslim
.
nas
.
DARTSearch
(
model
,
train_reader
,
valid_reader
,
place
,
learning_rate
=
0.025
,
batchsize
=
64
,
num_imgs
=
50000
,
arch_learning_rate
=
3e-4
,
unrolled
=
False
,
num_epochs
=
50
,
epochs_no_archopt
=
0
,
use_
multiprocess
=
False
,
use_
data_parallel
=
False
,
save_dir
=
'./'
,
log_freq
=
50
)
`
源代码
<
https
://
github
.
com
/
PaddlePaddle
/
PaddleSlim
/
blob
/
release
/
1.1.0
/
paddleslim
/
nas
/
darts
/
train_search
.
py
>`
_
...
...
@@ -18,11 +18,14 @@ DARTSearch
-
**
place
**
(
fluid
.
CPUPlace
()|
fluid
.
CUDAPlace
(
N
))-
该参数表示程序运行在何种设备上,这里的
N
为
GPU
对应的
ID
-
**
learning_rate
**
(
float
)-
模型参数的初始学习率。默认值:
0.025
。
-
**
batchsize
**
(
int
)-
搜索过程数据的批大小。默认值:
64
。
-
**
num_imgs
**
(
int
)-
数据集总样本数。默认值:
50000
。
-
**
arch_learning_rate
**
(
float
)-
架构参数的学习率。默认值:
3e-4
。
-
**
unrolled
**
(
bool
)-
是否使用二阶搜索算法。默认值:
False
。
-
**
num_epochs
**
(
int
)-
搜索训练的轮数。默认值:
50
。
-
**
epochs_no_archopt
**
(
int
)-
跳过前若干轮的模型架构参数优化。默认值:
0
。
-
**
use_multiprocess
**
(
bool
)-
是否使用多进程的
dataloader
。默认值:
False
。
-
**
use_data_parallel
**
(
bool
)-
是否使用数据并行的多卡训练。默认值:
False
。
-
**
save_dir
**
(
str
)-
模型参数保存目录。默认值:
'./'
。
-
**
log_freq
**
(
int
)-
每多少步输出一条
log
。默认值:
50
。
...
...
paddleslim/nas/darts/train_search.py
浏览文件 @
b6ba314d
...
...
@@ -59,6 +59,7 @@ class DARTSearch(object):
unrolled(bool): Use one-step unrolled validation loss. Default: False.
num_epochs(int): Epoch number. Default: 50.
epochs_no_archopt(int): Epochs skip architecture optimize at begining. Default: 0.
use_multiprocess(bool): Whether to use multiprocess in dataloader. Default: False.
use_data_parallel(bool): Whether to use data parallel mode. Default: False.
log_freq(int): Log frequency. Default: 50.
...
...
@@ -76,6 +77,7 @@ class DARTSearch(object):
unrolled
=
False
,
num_epochs
=
50
,
epochs_no_archopt
=
0
,
use_multiprocess
=
False
,
use_data_parallel
=
False
,
save_dir
=
'./'
,
log_freq
=
50
):
...
...
@@ -90,6 +92,7 @@ class DARTSearch(object):
self
.
unrolled
=
unrolled
self
.
epochs_no_archopt
=
epochs_no_archopt
self
.
num_epochs
=
num_epochs
self
.
use_multiprocess
=
use_multiprocess
self
.
use_data_parallel
=
use_data_parallel
self
.
save_dir
=
save_dir
self
.
log_freq
=
log_freq
...
...
@@ -207,17 +210,17 @@ class DARTSearch(object):
self
.
valid_reader
)
train_loader
=
fluid
.
io
.
DataLoader
.
from_generator
(
capacity
=
102
4
,
capacity
=
6
4
,
use_double_buffer
=
True
,
iterable
=
True
,
return_list
=
True
,
use_multiprocess
=
True
)
use_multiprocess
=
self
.
use_multiprocess
)
valid_loader
=
fluid
.
io
.
DataLoader
.
from_generator
(
capacity
=
102
4
,
capacity
=
6
4
,
use_double_buffer
=
True
,
iterable
=
True
,
return_list
=
True
,
use_multiprocess
=
True
)
use_multiprocess
=
self
.
use_multiprocess
)
train_loader
.
set_batch_generator
(
self
.
train_reader
,
places
=
self
.
place
)
valid_loader
.
set_batch_generator
(
self
.
valid_reader
,
places
=
self
.
place
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录