Unverified commit d585017b, authored by lilong12 and committed by GitHub

add a demo to show how to use user-defined data for training (#30)

* add custom_reader.py

* update README.md
Parent 4810142b
README.md
@@ -31,6 +31,7 @@
* [Preprocessing Base64-format image data](#Base64格式图像数据预处理)
* [Mixed-precision training](#混合精度训练)
* [Custom models](#自定义模型)
* [Custom training data](#自定义训练数据)
* [Pretrained models and performance](#预训练模型和性能)
* [Pretrained models](#预训练模型)
* [Training performance](#训练性能)
@@ -625,6 +626,98 @@ The inputs to the build_network method are as follows:
The build_network method returns the output variables of the user-defined network.
### Custom training data
By default, we assume that the user's training data directory is organized as follows:
```shell script
train_data/
|-- images
`-- label.txt
```
Here, the images directory stores the user's training images, and the label.txt file records the path and class label of each image in the training set.
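For example, assuming one whitespace-separated "<image path> <class label>" pair per line (the file names and labels below are only illustrative), label.txt might look like:
```shell script
images/00001.jpg 0
images/00002.jpg 0
images/00003.jpg 1
```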
When the training data is organized in some other, custom format, the following steps can be used to train with custom data:
1. Define a reader function (a generator) that preprocesses the user data (e.g., cropping) and uses yield to produce data samples;
   * Each sample is a tuple of the form (data, label), where data is the decoded and preprocessed image data and label is the class label of the image.
2. Wrap the reader generator with paddle.batch to obtain a new generator, batched_reader;
3. Assign batched_reader to the train_reader member of the plsc.Entry instance.
For ease of description, we again assume that the user's training data is organized as follows:
```shell script
train_data/
|-- images
`-- label.txt
```
The code defining the sample generator (reader.py) is shown below:
```python
import os
import random

import numpy as np
from PIL import Image


def arc_train(data_dir):
    # Each line of label.txt is assumed to hold "<relative image path> <class label>".
    label_file = os.path.join(data_dir, 'label.txt')
    with open(label_file, 'r') as f:
        train_image_list = [line.strip().split() for line in f if line.strip()]

    def reader():
        for path, label in train_image_list:
            path = os.path.join(data_dir, path)
            img = Image.open(path)
            # Random horizontal flip for data augmentation.
            if random.randint(0, 1) == 1:
                img = img.transpose(Image.FLIP_LEFT_RIGHT)
            if img.mode != 'RGB':
                img = img.convert('RGB')
            # Convert to float32 and change layout from HWC to CHW.
            img = np.array(img).astype('float32').transpose((2, 0, 1))
            yield img, int(label)

    return reader
```
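As a quick sanity check, the generator can be iterated directly before handing it to PLSC. The sketch below assumes the code above is saved as reader.py and that the data lives in ./train_data (a placeholder path):
```python
# Minimal smoke test for the custom reader; './train_data' is a placeholder path.
from reader import arc_train

train_reader = arc_train('./train_data')
img, label = next(train_reader())
print(img.shape, img.dtype, label)  # expect a (3, H, W) float32 array and an int label
```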
The training code that uses the user-defined training data is as follows:
```python
import argparse

import paddle
from plsc import Entry

import reader

parser = argparse.ArgumentParser()
parser.add_argument("--data_dir",
                    type=str,
                    default="./data",
                    help="Directory for datasets.")
args = parser.parse_args()


def main():
    global args
    ins = Entry()
    ins.set_dataset_dir(args.data_dir)
    train_reader = reader.arc_train(args.data_dir)
    # Batch the above samples;
    batched_train_reader = paddle.batch(train_reader,
                                        ins.train_batch_size)
    # Set the reader to use during training to the above batch reader.
    ins.train_reader = batched_train_reader
    ins.train()


if __name__ == "__main__":
    main()
```
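Assuming the snippet above is saved as train.py (the file name is only an example) and the data lives in ./train_data, training could then be launched with:
```shell script
python train.py --data_dir=./train_data
```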
For more details, please refer to the [example code](./demo/custom_reader.py).
## Pretrained models and performance
### Pretrained models
demo/custom_reader.py
# This demo shows how to use a user-defined training dataset.
# The following steps are needed to use a user-defined training dataset:
# 1. Build a reader, which preprocesses images and yields a sample in the
#    format (data, label) each time, where data is the decoded image data;
# 2. Batch the above samples;
# 3. Set the reader to use during training to the above batch reader.
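#
# Example invocation (the paths and values below are only placeholders; they
# match the argparse defaults defined in this script):
#     python custom_reader.py --data_dir=./data --num_epochs=2 --loss_type=arcface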
import argparse

import paddle
from plsc import Entry
from plsc.utils import jpeg_reader as reader

parser = argparse.ArgumentParser()
parser.add_argument("--model_save_dir",
                    type=str,
                    default="./saved_model",
                    help="Directory to save models.")
parser.add_argument("--data_dir",
                    type=str,
                    default="./data",
                    help="Directory for datasets.")
parser.add_argument("--num_epochs",
                    type=int,
                    default=2,
                    help="Number of epochs to run.")
parser.add_argument("--loss_type",
                    type=str,
                    default='arcface',
                    help="Loss type to use.")
args = parser.parse_args()


def main():
    global args
    ins = Entry()
    ins.set_model_save_dir(args.model_save_dir)
    ins.set_dataset_dir(args.data_dir)
    ins.set_train_epochs(args.num_epochs)
    ins.set_loss_type(args.loss_type)
    # 1. Build a reader, which yields a sample in the format (data, label)
    #    each time, where data is the decoded image data;
    train_reader = reader.arc_train(args.data_dir,
                                    ins.num_classes)
    # 2. Batch the above samples;
    batched_train_reader = paddle.batch(train_reader,
                                        ins.train_batch_size)
    # 3. Set the reader to use during training to the above batch reader.
    ins.train_reader = batched_train_reader
    ins.train()


if __name__ == "__main__":
    main()