README.md 5.5 KB
Newer Older
W
webyfdt 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
# 交叉验证

- 交叉验证(Cross Validation) 是在机器学习建立模型和验证模型参数时常用的方法。顾名思义,就是重复的使用数据,把得到的样本数据进行切分,组合为不同的训练集和验证集。用训练集来训练模型,验证集来评估模型的好坏。在此基础上可以得到多组不同的训练集和测试集,某次训练集中的样本,在下次可能成为测试集中的样本,也就是所谓的交叉。
- 交叉验证用在数据量不是很充足的情况(比如数据量小于一万条),能够从有限的数据中获取尽可能多的有效信息。
- 交叉验证用于评估模型的预测性能,尤其是训练好的模型在新数据上的表现,能够一定程度上减小过拟合。

## 交叉验证常见方法

- 留出法交叉验证
  - 留出法 (Hold-Out Cross Validation) 是一种简单交叉验证,即针对原始数据集,通常分为训练集、验证集。训练集用于训练模型、验证集对于模型来说是未知数据,用于评估模型的泛化能力。比如我们随机的将样本数据分为两部分(80%的训练集,20%的验证集),然后用训练集来训练模型,验证集上验证模型及参数,最后选择损失函数评估最优的模型和参数。
- K折交叉验证(文心中使用这种方法),K折交叉验证 (k-fold Cross Validation) 过程如下所示:
  - 不重复抽样将原始数据随机分成k份;
  - 每次挑选其中1份作为验证集,剩余k-1份作为训练集用于训练模型;
  - 重复第2步k次,在每个训练集上训练后得到一个模型。用这个模型在相应的验证集上测试,计算并保存模型的评估指标。

![img](http://rte.weiyun.baidu.com/api/imageDownloadAddress?attachId=8e5371445df84ada880543a641dd8c1f)

## 文心中使用交叉验证

这里我们以BOW分类任务为例介绍文心中的交叉验证的使用方式,主要分为以下几个步骤:

1. 准备数据:数据集的准备与通用的BOW分类任务一致,不再赘述,详细信息请移步[实战演练:使用文心进行模型训练](https://ai.baidu.com/ai-doc/ERNIE-Ultimate/Ekmlrorrp)

W
webyfdt 已提交
24
2. 配置参数:以erniekit_appzoo/tasks/text_classification/examples/cls_bow_ch.json为例,交叉验证仅需要在dataset_reader部分进行配置,其余部分与普通的训练任务一致。dataset_reader中k_fold相关参数为K折交叉验证的配置:
W
webyfdt 已提交
25 26 27 28 29 30 31 32 33

   - num_fold表示需要将训练集拆成K折,num_fold的取值需要>=2;

   - num_use_split表示被拆成K折,并组成K个训练集+验证集的数据集合中,有几个用来训练模型(num_use_split的取值不能大于num_fold);

   - data_path表示需要被拆分的原始数据集路径,必须是目录;

   - data_path_split表示使用K折交叉验证拆分完成之后的数据存放路径,必须是个目录。

T
Tesla 已提交
34
```
W
webyfdt 已提交
35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
    	{
    	"dataset_reader": {
    	"k_fold": {
        "num_fold": 4,    ## 表示需要将训练集拆成K折(num_fold的取值需要>=2)
        "data_path": "./data/train_data",  ## 待拆分的原始数据集路径
        "data_path_split": "./data_tmp",  ## 表示使用K折交叉验证拆分完成之后的数据存放路径,模型训练时真正用到的数据集路径。
        "num_use_split": 4     ## 拆好的K折中,有几个数据集要用来做训练,其取值不能大于num_fold。
    	},
    	"train_reader": {
        ...
        "config": {
            "data_path": "./data/train_data",
            "shuffle": false,
            "batch_size": 8,
            "epoch": 3,
            "sampling_rate": 1.0
        }
    	}
    	 ...
    	},
    	"model": {
    	"type": "BowClassification",
    	"optimization": {
        "learning_rate": 2e-5
    	},
    	"vocab_size": 33261,
    	"num_labels": 2
    	},
    	 ...
    	}
```

W
webyfdt 已提交
67
3. 启动训练:使用交叉验证进行训练的启动脚本与普通训练任务不一样,启动脚本为**run_with_preprocess.py**,该脚本的位置在erniekit_appzoo/tools/run_preprocess/目录下,可以拷贝到erniekit_appzoo/tasks/text_classification目录下使用,入参为上一步骤配置好k_fold参数的json文件,具体如下所示:
W
webyfdt 已提交
68

T
Tesla 已提交
69
```
W
webyfdt 已提交
70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88
# BOW 中文文本分类模型
# 基于json实现预置网络训练。其调用了配置文件./examples/cls_bow_ch.json(添加了交叉验证配置的json)
python run_with_preprocess.py --param_path ./examples/cls_bow_ch.json
```

## 运行过程中产生的文件

运行过程中,使用**run_with_preprocess.py**启动的交叉验证任务会产生一些额外的目录,如下所示:

- data_tmp:通过K折交叉验证处理过后的数据,如图所示就是将原始数据分成了4组,每组包含train_data 和 dev_data,如下图所示:

![img](http://rte.weiyun.baidu.com/api/imageDownloadAddress?attachId=6b520e0e75c94b00bfb22d2a70f9cac6)

- json_tmp:交叉验证处理完成之后生成的新的待运行配置文件,如下图所示:

![img](http://rte.weiyun.baidu.com/api/imageDownloadAddress?attachId=5cfbc77da7a74f429dc5f6f6ffffb268)

- log:基于交run_with_preprocess.py运行的任务都会生成新的json配置,每个json对应一个独立的trainer,各个trainer按照顺序串行训练,所以日志会分别输出到对应编号的log中。如下图就是串行的4个trainer的日志。日志内容和单独运行run_trainer.py输出到test.log中的日志一样,如下图所示:

T
Tesla 已提交
89
![img](http://rte.weiyun.baidu.com/api/imageDownloadAddress?attachId=5c55226559ff4d63aa43720d5f1757bc)