{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. ERNIE-M模型简介\n", "\n", "\n", "[ERNIE-M](https://arxiv.org/abs/2012.15674) 是百度提出的一种多语言语言模型。原文提出了一种新的训练方法,让模型能够将多种语言的表示与单语语料库对齐,以克服平行语料库大小对模型性能的限制。原文的主要想法是将回译机制整合到预训练的流程中,在单语语料库上生成伪平行句对,以便学习不同语言之间的语义对齐,从而增强跨语言模型的语义建模。实验结果表明,ERNIE-M 优于现有的跨语言模型,并在各种跨语言下游任务中提供了最新的 SOTA 结果。\n", "\n", "本项目是 ERNIE-M 的 PaddlePaddle 动态图实现, 包含模型训练,模型验证等内容。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. 模型效果及应用场景\n", "\n", "### 2.1 自然语言推断任务\n", "\n", "#### 2.1.1 数据集\n", "\n", "XNLI 是 MNLI 的子集,并且已被翻译成14种不同的语言(包含一些较低资源语言)。与 MNLI 一样,目标是预测文本蕴含(句子 A 是否暗示/矛盾/都不是句子 B )。\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 3. 模型如何使用\n", "\n", "## 3.1 模型训练\n", "\n", "* 安装PaddleNLP包\n", "\n", "```shell\n", "pip install --upgrade paddlenlp\n", "pip install datasets\n", "```\n", "\n", "* 下载\n", "\n", "```shell\n", "# 下载脚本文件(从gitee上更快)\n", "wget https://gitee.com/paddlepaddle/PaddleNLP/blob/develop/model_zoo/ernie-m/run_classifier.py\n", "```\n", "\n", "* 单卡训练\n", "\n", "```shell\n", "python run_classifier.py \\\n", " --task_type cross-lingual-transfer \\\n", " --batch_size 16 \\\n", " --model_name_or_path ernie-m-base \\\n", " --save_steps 12272 \\\n", " --output_dir output\n", "```\n", "\n", "* 多卡训练\n", "\n", "```shell\n", "python -m paddle.distributed.launch --gpus 0,1 --log_dir output run_classifier.py \\\n", " --task_type cross-lingual-transfer \\\n", " --batch_size 16 \\\n", " --model_name_or_path ernie-m-base \\\n", " --save_steps 12272 \\\n", " --output_dir output\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3.2 通用参数释义\n", "\n", "- `task_type` 表示了自然语言推断任务的类型,目前支持的类型为:\"cross-lingual-transfer\", \"translate-train-all\"\n", " ,分别表示在英文数据集上训练并在所有15种语言数据集上测试、在所有15种语言数据集上训练和测试。\n", "- `model_name_or_path` 指示了 Fine-tuning 使用的具体预训练模型以及预训练时使用的tokenizer,目前支持的预训练模型有:\"ernie-m-base\", \"ernie-m-large\"\n", " 。若模型相关内容保存在本地,这里也可以提供相应目录地址,例如:\"./checkpoint/model_xx/\"。\n", "- `output_dir` 表示模型保存路径。\n", "- `max_seq_length` 表示最大句子长度,超过该长度将被截断,不足该长度的将会进行 padding。\n", "- `learning_rate` 表示基础学习率大小,将于 learning rate scheduler 产生的值相乘作为当前学习率。\n", "- `num_train_epochs` 表示训练轮数。\n", "- `logging_steps` 表示日志打印间隔步数。\n", "- `save_steps` 表示模型保存及评估间隔步数。\n", "- `batch_size` 表示每次迭代**每张**卡上的样本数目。\n", "- `weight_decay` 表示AdamW的权重衰减系数。\n", "- `layerwise_decay` 表示 AdamW with Layerwise decay 的逐层衰减系数。\n", "- `warmup_proportion` 表示学习率warmup系数。\n", "- `max_steps` 表示最大训练步数。若训练`num_train_epochs`轮包含的训练步数大于该值,则达到`max_steps`后就提前结束。\n", "- `seed` 表示随机数种子。\n", "- `device` 表示训练使用的设备, 'gpu'表示使用 GPU, 'xpu'表示使用百度昆仑卡, 'cpu'表示使用 CPU。\n", "- `use_amp` 表示是否启用自动混合精度训练。\n", "- `scale_loss` 表示自动混合精度训练的参数。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. 模型原理" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "原文提出两种方法建模各种语言间的对齐关系:\n", "\n", "- **Cross-Attention Masked Language Modeling(CAMLM)**: 该算法在少量双语语料上捕捉语言间的对齐信息。其需要在不利用源句子上下文的情况下,通过目标句子还原被掩盖的词语,使模型初步建模了语言间的对齐关系。\n", "- **Back-Translation masked language modeling(BTMLM)**: 该方法基于回译机制从单语语料中学习语言间的对齐关系。通过CAMLM 生成伪平行语料,然后让模型学习生成的伪平行句子,使模型可以利用单语语料更好地建模语义对齐关系。\n", "\n", "\n", "![](https://foruda.gitee.com/images/1668157409546826003/f78cb949_5218658.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. 
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. How the Model Works"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The paper proposes two methods for modeling alignment across languages:\n",
    "\n",
    "- **Cross-Attention Masked Language Modeling (CAMLM)**: captures cross-lingual alignment on a small amount of bilingual data. The model must restore the masked tokens of a sentence using only the context of its counterpart in the other language, not the sentence's own context, which gives the model an initial notion of cross-lingual alignment.\n",
    "- **Back-Translation Masked Language Modeling (BTMLM)**: learns cross-lingual alignment from monolingual data via back-translation. Pseudo-parallel sentences are first generated with CAMLM, and the model then learns from these generated pairs, allowing it to use monolingual corpora to model semantic alignment more effectively.\n",
    "\n",
    "\n",
    "![](https://foruda.gitee.com/images/1668157409546826003/f78cb949_5218658.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Related Paper and Citation\n",
    "\n",
    "```bibtex\n",
    "@article{Ouyang2021ERNIEMEM,\n",
    "  title={ERNIE-M: Enhanced Multilingual Representation by Aligning Cross-lingual Semantics with Monolingual Corpora},\n",
    "  author={Xuan Ouyang and Shuohuan Wang and Chao Pang and Yu Sun and Hao Tian and Hua Wu and Haifeng Wang},\n",
    "  journal={ArXiv},\n",
    "  year={2021},\n",
    "  volume={abs/2012.15674}\n",
    "}\n",
    "```"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.13 ('model_center')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "de1ffcbce2b3061b5001e2c22f3a27594f323d4a49b789ebdbef6534581834bd"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}