Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PLSC
提交
d585017b
P
PLSC
项目概览
PaddlePaddle
/
PLSC
通知
10
Star
3
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
5
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PLSC
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
5
Issue
5
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d585017b
编写于
1月 13, 2020
作者:
L
lilong12
提交者:
GitHub
1月 13, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add a demo to show how to use user-defined data for training (#30)
* add custome_reader.py * update README.md
上级
4810142b
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
148 addition
and
0 deletion
+148
-0
README.md
README.md
+93
-0
demo/custom_reader.py
demo/custom_reader.py
+55
-0
未找到文件。
README.md
浏览文件 @
d585017b
...
...
@@ -31,6 +31,7 @@
*
[
Base64格式图像数据预处理
](
#Base64格式图像数据预处理
)
*
[
混合精度训练
](
#混合精度训练
)
*
[
自定义模型
](
#自定义模型
)
*
[
自定义训练数据
](
#自定义训练数据
)
*
[
预训练模型和性能
](
#预训练模型和性能
)
*
[
预训练模型
](
#预训练模型
)
*
[
训练性能
](
#训练性能
)
...
...
@@ -625,6 +626,98 @@ build_network方法的输入如下:
build_network方法返回用户自定义组网的输出变量。
### 自定义训练数据
默认地,我们假设用户的训练数据目录组织如下:
```
shell script
train_data/
|-- images
`-- label.txt
```
其中,images目录中存放用户训练数据,label.txt文件记录用户训练数据中每幅图像的地址和对应的类别标签。
当用户的训练数据按照其它自定义格式组织时,可以按照下面的步骤使用自定义训练数据:
1.
定义reader函数(生成器),该函数对用户数据进行预处理(如裁剪),并使用yield生成数据样本;
*
数据样本的格式为形如(data, label)的元组,其中data为解码和预处理后的图像数据,label为该图像的类别标签。
2.
使用paddle.batch封装reader生成器,得到新的生成器batched_reader;
3.
将batched_reader赋值给plsc.Entry类示例的train_reader成员。
为了便于描述,我们仍然假设用户训练数据组织结构如下:
```
shell script
train_data/
|-- images
`-- label.txt
```
定义样本生成器的代码如下所示(reader.py):
```
python
import
random
import
os
from
PIL
import
Image
def
arc_train
(
data_dir
):
label_file
=
os
.
path
.
join
(
data_dir
,
'label.txt'
)
train_image_list
=
None
with
open
(
label_file
,
'r'
)
as
f
:
train_image_list
=
f
.
readlines
()
train_image_list
=
get_train_image_list
(
data_dir
)
def
reader
():
for
j
in
range
(
len
(
train_image_list
)):
path
,
label
=
train_image_list
[
j
]
path
=
os
.
path
.
join
(
data_dir
,
path
)
img
=
Image
.
open
(
path
)
if
random
.
randint
(
0
,
1
)
==
1
:
img
=
img
.
transpose
(
Image
.
FLIP_LEFT_RIGHT
)
if
img
.
mode
!=
'RGB'
:
img
=
img
.
convert
(
'RGB'
)
img
=
np
.
array
(
img
).
astype
(
'float32'
).
transpose
((
2
,
0
,
1
))
yield
img
,
label
return
reader
```
使用用户自定义训练数据的训练代码如下:
```
python
import
argparse
import
paddle
from
plsc
import
Entry
import
reader
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--data_dir"
,
type
=
str
,
default
=
"./data"
,
help
=
"Directory for datasets."
)
args
=
parser
.
parse_args
()
def
main
():
global
args
ins
=
Entry
()
ins
.
set_dataset_dir
(
args
.
data_dir
)
train_reader
=
reader
.
arc_train
(
args
.
data_dir
)
# Batch the above samples;
batched_train_reader
=
paddle
.
batch
(
train_reader
,
ins
.
train_batch_size
)
# Set the reader to use during training to the above batch reader.
ins
.
train_reader
=
batched_train_reader
ins
.
train
()
if
__name__
==
"__main__"
:
main
()
```
更多详情请参考
[
示例代码
](
./demo/custom_reader.py
)
## 预训练模型和性能
### 预训练模型
...
...
demo/custom_reader.py
0 → 100644
浏览文件 @
d585017b
# This demo shows how to use user-defined training dataset.
# The following steps are needed to use user-defined training datasets:
# 1. Build a reader, which preprocess images and yield a sample in the
# format (data, label) each time, where data is the decoded image data;
# 2. Batch the above samples;
# 3. Set the reader to use during training to the above batch reader.
import
argparse
import
paddle
from
plsc
import
Entry
from
plsc.utils
import
jpeg_reader
as
reader
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--model_save_dir"
,
type
=
str
,
default
=
"./saved_model"
,
help
=
"Directory to save models."
)
parser
.
add_argument
(
"--data_dir"
,
type
=
str
,
default
=
"./data"
,
help
=
"Directory for datasets."
)
parser
.
add_argument
(
"--num_epochs"
,
type
=
int
,
default
=
2
,
help
=
"Number of epochs to run."
)
parser
.
add_argument
(
"--loss_type"
,
type
=
str
,
default
=
'arcface'
,
help
=
"Loss type to use."
)
args
=
parser
.
parse_args
()
def
main
():
global
args
ins
=
Entry
()
ins
.
set_model_save_dir
(
args
.
model_save_dir
)
ins
.
set_dataset_dir
(
args
.
data_dir
)
ins
.
set_train_epochs
(
args
.
num_epochs
)
ins
.
set_loss_type
(
args
.
loss_type
)
# 1. Build a reader, which yield a sample in the format (data, label)
# each time, where data is the decoded image data;
train_reader
=
reader
.
arc_train
(
args
.
data_dir
,
ins
.
num_classes
)
# 2. Batch the above samples;
batched_train_reader
=
paddle
.
batch
(
train_reader
,
ins
.
train_batch_size
)
# 3. Set the reader to use during training to the above batch reader.
ins
.
train_reader
=
batched_train_reader
ins
.
train
()
if
__name__
==
"__main__"
:
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录