diff --git a/docs/tutorial/how_to_load_data.md b/docs/tutorial/how_to_load_data.md index b11065807da2175f557e755f3ddb6ca3108aef8f..9ea7b0e2380a3b3d44b58caf05739696c3340c41 100644 --- a/docs/tutorial/how_to_load_data.md +++ b/docs/tutorial/how_to_load_data.md @@ -30,7 +30,38 @@ text_a label 1.接电源没有几分钟,电源适配器热的不行. 2.摄像头用不起来. 3.机盖的钢琴漆,手不能摸,一摸一个印. 4.硬盘分区不好办. 0 ``` - +### 自定义数据加载 +加载文本类自定义数据集,用户仅需要继承基类BaseNLPDatast,修改数据集存放地址以及类别即可。具体使用如下: + +**NOTE:** +* 数据集文件编码格式建议为utf8格式。 +* 如果相应的数据集文件没有上述的列说明,如train.tsv文件没有第一行的`text_a label`,则train_file_with_header=False。 +* 如果您还有预测数据(没有文本类别),可以将预测数据存放在predict.tsv文件,文件格式和train.tsv类似。去掉label一列即可。 +* 分类任务中,数据集的label必须从0开始计数 + + +```python +from paddlehub.dataset.base_nlp_dataset import BaseNLPDataset +class DemoDataset(BaseNLPDataset): + """DemoDataset""" + def __init__(self): + # 数据集存放位置 + self.dataset_dir = "path/to/dataset" + super(DemoDataset, self).__init__( + base_path=self.dataset_dir, + train_file="train.tsv", + dev_file="dev.tsv", + test_file="test.tsv", + # 如果还有预测数据(不需要文本类别label),可以放在predict.tsv + predict_file="predict.tsv", + train_file_with_header=True, + dev_file_with_header=True, + test_file_with_header=True, + predict_file_with_header=True, + # 数据集类别集合 + label_list=["0", "1"]) +dataset = DemoDataset() +``` ## 二、CV类任务如何自定义数据 @@ -71,3 +102,32 @@ label_list.txt内容如下: cat dog ``` + +### 自定义数据加载 + +加载图像类自定义数据集,用户仅需要继承基类BaseCVDatast,修改数据集存放地址即可。具体使用如下: + +**NOTE:** +* 数据集文件编码格式建议为utf8格式。 +* dataset_dir为数据集实际路径,需要填写全路径,以下示例以`/test/data`为例。 +* 训练/验证/测试集的数据列表文件中的图片路径需要相对于dataset_dir的相对路径,例如图片的实际位置为`/test/data/dog/dog1.jpg`。base_path为`/test/data`,则文件中填写的路径应该为`dog/dog1.jpg`。 +* 如果您还有预测数据(没有文本类别),可以将预测数据存放在predict_list.txt文件,文件格式和train_list.txt类似。去掉label一列即可 +* 如果您的数据集类别较少,可以不用定义label_list.txt,可以选择定义label_list=["数据集所有类别"]。 +* 分类任务中,数据集的label必须从0开始计数 + + ```python +from paddlehub.dataset.base_cv_dataset import BaseCVDataset +class DemoDataset(BaseCVDataset): + def __init__(self): + # 数据集存放位置 + self.dataset_dir = "/test/data" + super(DemoDataset, self).__init__( + base_path=self.dataset_dir, + train_list_file="train_list.txt", + validate_list_file="validate_list.txt", + test_list_file="test_list.txt", + predict_file="predict_list.txt", + label_list_file="label_list.txt", + # label_list=["数据集所有类别"]) +dataset = DemoDataset() +```