diff --git a/README.md b/README.md index 649c62dfba48b2636a38f1bceefa8f4679aa0111..ce3081897e506544e3620e9e0c34f2b5ef5be10b 100644 --- a/README.md +++ b/README.md @@ -362,21 +362,31 @@ cls3: inference model saved at output_model/thirdrun/infer_model ## 进阶篇 本章节更深入的对paddlepalm的使用方法展开介绍,并提供一些提高使用效率的小技巧。 -### 训练终止条件与各类任务的预期训练步数 +### 训练终止条件与预期训练步数 -在默认情况下,每个训练step的各个任务被采样到的概率均等,若用户希望更改其中某些任务的采样概率(比如某些任务的训练集较小,希望减少对其采样的次数;或某些任务较难,希望被更多的训练),可以在全局配置文件中通过`mix_ratio`字段控制各个任务的采样概率。例如 +#### 训练终止条件 +在训练开始前,`Controller`会为所有每个目标任务计算出预期的训练步数。当某个目标任务的完成预期的训练步数后,`Controller`保存该任务的预测模型,而后继续按照设定的各任务的采样概率进行多任务训练。当所有目标任务均达到预期的训练步数后,多任务学习终止。需要注意的是,`Controller`不会为辅助任务计算预期训练步数,也不会为其保存预测模型,其仅仅起到“陪同目标任务训练”的作用,不会影响到多任务学习的终止与否。 + +#### 任务采样概率与预期训练步数 +此外,在默认情况下,每个训练step的各个任务被采样到的概率均等,若用户希望更改其中某些任务的采样概率(比如某些任务的训练集较小,希望减少对其采样的次数;或某些任务较难,希望被更多的训练),可以在全局配置文件中通过`mix_ratio`字段控制各个任务的采样概率。例如,我们有三个任务,其中mrqa任务为目标任务,其余为辅助任务,我们对其`mix_ratio`进行如下设定: ```yaml task_instance: mrqa, match4mrqa, mlm4mrqa mix_ratio: 1.0, 0.5, 0.5 ``` -上述设置表示`match4mrqa`和`mlm4mrqa`任务的期望被采样次数均为`mrqa`任务的一半。此时,在mrqa任务被设置为主任务的情况下(第一个目标任务即为主任务),若mrqa任务训练一个epoch要经历5000 steps,且全局配置文件中设置了num_epochs为2,则根据上述`mix_ratio`的设置,mrqa任务将被训练5000\*2\*1.0=10000个steps,而`match4mrqa`任务和`mlm4mrqa`任务都会被训练5000个steps**左右**。 +上述设置表示`match4mrqa`和`mlm4mrqa`任务的期望被采样次数均为`mrqa`任务的一半。此时,在mrqa任务被设置为目标任务的情况下,若mrqa任务训练一个epoch要经历5000 steps,且全局配置文件中设置了num_epochs为2,则根据上述`mix_ratio`的设置,mrqa任务将被训练5000\*2\*1.0=10000个steps,而`match4mrqa`任务和`mlm4mrqa`任务都会被训练5000个steps**左右**。 > 注意:若match4mrqa, mlm4mrqa被设置为辅助任务,则实际训练步数可能略多或略少于5000个steps。对于目标任务,则是精确的5000 steps。 +#### 多个目标任务时预期训练步数的计算 + +当存在多个目标任务时,`num_epochs`仅作用于**第一个设定的目标任务(称为“主任务(main task)”)**,而后根据`mix_ratio`的设定为其余目标任务和辅助任务计算出预期的训练步数。 + ### 模型保存与预测机制 +`Controller`会在训练过程 + ### 分布式训练与推理 @@ -391,17 +401,17 @@ for_cn: True 所有的内置reader,均支持以下字段 ```yaml -- vocab_path(REQUIRED): str类型。字典文件路径。 -- max_seq_len(REQUIRED): int类型。切词后的序列最大长度(即token ids的最大长度)。注意经过分词后,token ids的数量往往多于原始的单词数(e.g., 使用wordpiece tokenizer时)。 -- batch_size(REQUIRED): int类型。训练或预测时的批大小(每个step喂入神经网络的样本数)。 -- train_file(REQUIRED): str类型。训练集文件所在路径。仅进行预测时,该字段可不设置。 -- pred_file(REQUIRED): str类型。测试集文件所在路径。仅进行训练时,该字段可不设置。 - -- do_lower_case(OPTIONAL): bool类型,默认为False。是否将大写英文字母转换成小写。 -- shuffle(OPTIONAL): bool类型,默认为True。训练阶段打乱数据集样本的标志位,当置为True时,对数据集的样本进行全局打乱。注意,该标志位的设置不会影响预测阶段(预测阶段不会shuffle数据集)。 -- seed(OPTIONAL): int类型,默认为。 -- pred_batch_size(OPTIONAL): int类型。预测阶段的批大小,当该参数未设置时,预测阶段的批大小取决于`batch_size`字段的值。 -- print_first_n(OPTIONAL): int类型。打印数据集的前n条样本和对应的reader输出,默认为0。 +vocab_path(REQUIRED): str类型。字典文件路径。 +max_seq_len(REQUIRED): int类型。切词后的序列最大长度(即token ids的最大长度)。注意经过分词后,token ids的数量往往多于原始的单词数(e.g., 使用wordpiece tokenizer时)。 +batch_size(REQUIRED): int类型。训练或预测时的批大小(每个step喂入神经网络的样本数)。 +train_file(REQUIRED): str类型。训练集文件所在路径。仅进行预测时,该字段可不设置。 +pred_file(REQUIRED): str类型。测试集文件所在路径。仅进行训练时,该字段可不设置。 + +do_lower_case(OPTIONAL): bool类型,默认为False。是否将大写英文字母转换成小写。 +shuffle(OPTIONAL): bool类型,默认为True。训练阶段打乱数据集样本的标志位,当置为True时,对数据集的样本进行全局打乱。注意,该标志位的设置不会影响预测阶段(预测阶段不会shuffle数据集)。 +seed(OPTIONAL): int类型,默认为。 +pred_batch_size(OPTIONAL): int类型。预测阶段的批大小,当该参数未设置时,预测阶段的批大小取决于`batch_size`字段的值。 +print_first_n(OPTIONAL): int类型。打印数据集的前n条样本和对应的reader输出,默认为0。 ``` #### 文本分类数据集reader工具:cls