diff --git a/README.md b/README.md index c0835af327fddfaa877329329f5e3164f86059cc..30d2ae20052c2f084d43128a974cad1d76c0fc84 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,7 @@ * [Base64格式图像数据预处理](#Base64格式图像数据预处理) * [混合精度训练](#混合精度训练) * [自定义模型](#自定义模型) + * [自定义训练数据](#自定义训练数据) * [预训练模型和性能](#预训练模型和性能) * [预训练模型](#预训练模型) * [训练性能](#训练性能) @@ -625,6 +626,98 @@ build_network方法的输入如下: build_network方法返回用户自定义组网的输出变量。 +### 自定义训练数据 + +默认地,我们假设用户的训练数据目录组织如下: + +```shell script +train_data/ +|-- images +`-- label.txt +``` + +其中,images目录中存放用户训练数据,label.txt文件记录用户训练数据中每幅图像的地址和对应的类别标签。 + +当用户的训练数据按照其它自定义格式组织时,可以按照下面的步骤使用自定义训练数据: + +1. 定义reader函数(生成器),该函数对用户数据进行预处理(如裁剪),并使用yield生成数据样本; + * 数据样本的格式为形如(data, label)的元组,其中data为解码和预处理后的图像数据,label为该图像的类别标签。 +2. 使用paddle.batch封装reader生成器,得到新的生成器batched_reader; +3. 将batched_reader赋值给plsc.Entry类示例的train_reader成员。 + +为了便于描述,我们仍然假设用户训练数据组织结构如下: + +```shell script +train_data/ +|-- images +`-- label.txt +``` + +定义样本生成器的代码如下所示(reader.py): + +```python +import random +import os +from PIL import Image + +def arc_train(data_dir): + label_file = os.path.join(data_dir, 'label.txt') + train_image_list = None + with open(label_file, 'r') as f: + train_image_list = f.readlines() + train_image_list = get_train_image_list(data_dir) + + def reader(): + for j in range(len(train_image_list)): + path, label = train_image_list[j] + path = os.path.join(data_dir, path) + img = Image.open(path) + if random.randint(0, 1) == 1: + img = img.transpose(Image.FLIP_LEFT_RIGHT) + if img.mode != 'RGB': + img = img.convert('RGB') + img = np.array(img).astype('float32').transpose((2, 0, 1)) + yield img, label + + return reader +``` + +使用用户自定义训练数据的训练代码如下: + +```python +import argparse +import paddle +from plsc import Entry +import reader + +parser = argparse.ArgumentParser() +parser.add_argument("--data_dir", + type=str, + default="./data", + help="Directory for datasets.") +args = parser.parse_args() + + +def main(): + global args + ins = Entry() + ins.set_dataset_dir(args.data_dir) + train_reader = reader.arc_train(args.data_dir) + # Batch the above samples; + batched_train_reader = paddle.batch(train_reader, + ins.train_batch_size) + # Set the reader to use during training to the above batch reader. + ins.train_reader = batched_train_reader + + ins.train() + + +if __name__ == "__main__": + main() +``` + +更多详情请参考[示例代码](./demo/custom_reader.py) + ## 预训练模型和性能 ### 预训练模型 diff --git a/demo/custom_reader.py b/demo/custom_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..1f7f5447d86acdc8526f432ffb4eb062a4e620fb --- /dev/null +++ b/demo/custom_reader.py @@ -0,0 +1,55 @@ +# This demo shows how to use user-defined training dataset. +# The following steps are needed to use user-defined training datasets: +# 1. Build a reader, which preprocess images and yield a sample in the +# format (data, label) each time, where data is the decoded image data; +# 2. Batch the above samples; +# 3. Set the reader to use during training to the above batch reader. + +import argparse + +import paddle +from plsc import Entry +from plsc.utils import jpeg_reader as reader + +parser = argparse.ArgumentParser() +parser.add_argument("--model_save_dir", + type=str, + default="./saved_model", + help="Directory to save models.") +parser.add_argument("--data_dir", + type=str, + default="./data", + help="Directory for datasets.") +parser.add_argument("--num_epochs", + type=int, + default=2, + help="Number of epochs to run.") +parser.add_argument("--loss_type", + type=str, + default='arcface', + help="Loss type to use.") +args = parser.parse_args() + + +def main(): + global args + ins = Entry() + ins.set_model_save_dir(args.model_save_dir) + ins.set_dataset_dir(args.data_dir) + ins.set_train_epochs(args.num_epochs) + ins.set_loss_type(args.loss_type) + # 1. Build a reader, which yield a sample in the format (data, label) + # each time, where data is the decoded image data; + train_reader = reader.arc_train(args.data_dir, + ins.num_classes) + # 2. Batch the above samples; + batched_train_reader = paddle.batch(train_reader, + ins.train_batch_size) + # 3. Set the reader to use during training to the above batch reader. + ins.train_reader = batched_train_reader + + ins.train() + + +if __name__ == "__main__": + main()