{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine-tuning ERNIE 1.0 on XNLI with propeller\n",
    "\n",
    "This notebook downloads the pretrained ERNIE 1.0 model and the XNLI (zh) data, fine-tunes a 3-way classifier with `propeller.train_and_eval`, evaluates the best exported checkpoint on the test set, and shows how to serve the exported inference model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "import numpy as np\n",
    "import re\n",
    "import logging\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sys.path.append('../ernie')\n",
    "sys.path.append('../')\n",
    "%env CUDA_VISIBLE_DEVICES=7\n",
    "# if CUDA_VISIBLE_DEVICES is changed, relaunch the jupyter kernel so that paddle picks it up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import propeller.paddle as propeller\n",
    "import paddle\n",
    "import paddle.fluid as F\n",
    "import paddle.fluid.layers as L\n",
    "# import model definition from original ERNIE\n",
    "from model.ernie import ErnieModel\n",
    "from tokenization import FullTokenizer\n",
    "from optimization import optimization\n",
    "from propeller import log\n",
    "log.setLevel(logging.DEBUG)\n",
    "\n",
    "if paddle.__version__ not in ['1.5.1', '1.5.2']:\n",
    "    raise RuntimeError('propeller works in paddle 1.5.1/1.5.2, got %s' % paddle.__version__)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash\n",
    "# download pretrained model & config (ernie1.0) and xnli data\n",
    "mkdir -p ernie1.0_pretrained\n",
    "if [ ! -f ernie1.0_pretrained/ERNIE_stable-1.0.1.tar.gz ]\n",
    "then\n",
    "    echo \"download model\"\n",
    "    wget --no-check-certificate https://baidu-nlp.bj.bcebos.com/ERNIE_stable-1.0.1.tar.gz -P ernie1.0_pretrained\n",
    "fi\n",
    "\n",
    "if [ ! -f task_data_zh.tgz ]\n",
    "then\n",
    "    echo \"download data\"\n",
    "    wget --no-check-certificate https://ernie.bj.bcebos.com/task_data_zh.tgz\n",
    "fi\n",
    "\n",
    "tar xzf ernie1.0_pretrained/ERNIE_stable-1.0.1.tar.gz -C ernie1.0_pretrained\n",
    "tar xzf task_data_zh.tgz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define basic training settings\n",
    "EPOCH = 3\n",
    "BATCH = 16           # used both for batching the data and in train_config\n",
    "LR = 5e-5            # peak learning rate, passed to the optimizer via train_config\n",
    "MAX_SEQLEN = 128\n",
    "TASK_DATA = './task_data/'\n",
    "MODEL = './ernie1.0_pretrained/'\n",
    "OUTPUT_DIR = './output'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!rm -rf {OUTPUT_DIR}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# skip the header, and reorganize the data into ./xnli_data\n",
    "!mkdir -p xnli_data/train xnli_data/test xnli_data/dev\n",
    "\n",
    "def remove_header_and_save(fname_in, fname_out):\n",
    "    \"\"\"Copy `fname_in` to `fname_out` without its first (header) line.\n",
    "\n",
    "    Both files are read/written as UTF-8 so the result does not depend on the locale.\n",
    "    Returns the number of data lines written.\n",
    "    \"\"\"\n",
    "    with open(fname_in, encoding='utf8') as fin:\n",
    "        lines = fin.readlines()[1:]\n",
    "    with open(fname_out, 'w', encoding='utf8') as fout:\n",
    "        fout.writelines(lines)\n",
    "    return len(lines)\n",
    "\n",
    "train_data_size = remove_header_and_save(TASK_DATA + '/xnli/train.tsv', './xnli_data/train/part.0')\n",
    "dev_data_size = remove_header_and_save(TASK_DATA + '/xnli/dev.tsv', './xnli_data/dev/part.0')\n",
    "test_data_size = remove_header_and_save(TASK_DATA + '/xnli/test.tsv', './xnli_data/test/part.0')\n",
    "print(train_data_size)\n",
    "print(dev_data_size)\n",
    "print(test_data_size)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = FullTokenizer(MODEL + 'vocab.txt')\n",
    "with open(MODEL + 'vocab.txt', encoding='utf8') as vocab_file:\n",
    "    # token -> id, where the id is the line number in vocab.txt\n",
    "    vocab = {line.strip().split('\\t')[0]: i for i, line in enumerate(vocab_file)}\n",
    "\n",
    "print(tokenizer.tokenize('今天很热'))\n",
    "print(tokenizer.tokenize('coding in paddle is cool'))\n",
    "print(tokenizer.tokenize('[CLS]i have an pen'))  # note: special tokens like [CLS] get segmented too, so add their ids after tokenization.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`propeller.data.FeatureColumns` defines the data schema of every data file.\n",
    "\n",
    "Our data consists of 3 columns: seg_a, seg_b, label (named `title`, `comment` and `label` in the code below), separated by \"\\t\".\n",
    "\n",
    "`TextColumn` will do 3 things for you:\n",
    "\n",
    "1. tokenize the input sentence with the user-defined `tokenizer_func`\n",
    "2. vocab lookup\n",
    "3. serialize to a protobuf bin file (optional)\n",
    "\n",
    "Data files are organized in the following layout:\n",
    "\n",
    "```script\n",
    "./xnli_data\n",
    "|-- dev\n",
    "|   `-- part.0\n",
    "|-- test\n",
    "|   `-- part.0\n",
    "`-- train\n",
    "    `-- part.0\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "sep_id = vocab['[SEP]']\n",
    "cls_id = vocab['[CLS]']\n",
    "unk_id = vocab['[UNK]']\n",
    "\n",
    "label_map = {\n",
    "    b\"contradictory\": 0,\n",
    "    b\"contradiction\": 0,\n",
    "    b\"entailment\": 1,\n",
    "    b\"neutral\": 2,\n",
    "}\n",
    "def tokenizer_func(inputs):\n",
    "    \"\"\"Tokenize one raw text field; `tokenize` converts bytes to str, so the vocab is keyed by str.\"\"\"\n",
    "    ret = tokenizer.tokenize(inputs)\n",
    "    return ret\n",
    "\n",
    "feature_column = propeller.data.FeatureColumns([\n",
    "    propeller.data.TextColumn('title', unk_id=unk_id, vocab_dict=vocab, tokenizer=tokenizer_func),\n",
    "    propeller.data.TextColumn('comment', unk_id=unk_id, vocab_dict=vocab, tokenizer=tokenizer_func),\n",
    "    propeller.data.LabelColumn('label', vocab_dict=label_map),  # be careful, Columns deal with python3 bytes directly.\n",
    "])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training a model with propeller\n",
    "A model in propeller can be defined in 2 ways:\n",
    "1. a subclass of `propeller.train.Model` which implements:\n",
    "    1. `__init__` (hyper_param, mode, run_config)\n",
    "    2. `forward` (features) => (predictions)\n",
    "    3. `backward` (loss) => None\n",
    "    4. `loss` (predictions, labels) => (loss)\n",
    "    5. `metrics` (optional) (predictions, labels) => (dict of propeller.Metrics)\n",
    "\n",
    "2. a callable that takes the following args:\n",
    "    1. features\n",
    "    2. param\n",
    "    3. mode\n",
    "    4. run_config (optional)\n",
    "\n",
    "    and returns a `propeller.ModelSpec`\n",
    "\n",
    "We use the subclass approach here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ClassificationErnieModel(propeller.train.Model):\n",
    "    \"\"\"ERNIE encoder with a dropout + linear classification head on the pooled output.\n",
    "\n",
    "    Implements the `propeller.train.Model` protocol (forward / loss / backward / metrics).\n",
    "    \"\"\"\n",
    "    def __init__(self, hparam, mode, run_config):\n",
    "        self.hparam = hparam\n",
    "        self.mode = mode\n",
    "        self.run_config = run_config\n",
    "\n",
    "    def forward(self, features):\n",
    "        \"\"\"Build the network for `features = (src_ids, sent_ids)`.\n",
    "\n",
    "        Returns class probabilities in PREDICT mode, raw logits otherwise.\n",
    "        \"\"\"\n",
    "        src_ids, sent_ids = features\n",
    "        dtype = 'float16' if self.hparam['use_fp16'] else 'float32'\n",
    "        zero = L.fill_constant([1], dtype='int64', value=0)\n",
    "        # 1.0 on real tokens, 0.0 on padding (assumes pad id == 0),\n",
    "        # matching the input_mask convention of ERNIE's own data reader\n",
    "        input_mask = L.cast(L.logical_not(L.equal(src_ids, zero)), dtype)\n",
    "        d_shape = L.shape(src_ids)\n",
    "        seqlen = d_shape[1]\n",
    "        batch_size = d_shape[0]\n",
    "        # position ids 0..seqlen-1, broadcast over the batch, shaped like src_ids\n",
    "        pos_ids = L.unsqueeze(L.range(0, seqlen, 1, dtype='int32'), axes=[0])\n",
    "        pos_ids = L.expand(pos_ids, [batch_size, 1])\n",
    "        pos_ids = L.unsqueeze(pos_ids, axes=[2])\n",
    "        pos_ids = L.cast(pos_ids, 'int64')\n",
    "        pos_ids.stop_gradient = True\n",
    "        input_mask.stop_gradient = True\n",
    "        # task ids are not used at the moment; a constant task_id is fed for every token\n",
    "        task_ids = L.zeros_like(src_ids) + self.hparam.task_id\n",
    "        task_ids.stop_gradient = True\n",
    "\n",
    "        ernie = ErnieModel(\n",
    "            src_ids=src_ids,\n",
    "            position_ids=pos_ids,\n",
    "            sentence_ids=sent_ids,\n",
    "            task_ids=task_ids,\n",
    "            input_mask=input_mask,\n",
    "            config=self.hparam,\n",
    "            use_fp16=self.hparam['use_fp16']\n",
    "        )\n",
    "\n",
    "        cls_feats = ernie.get_pooled_output()\n",
    "\n",
    "        cls_feats = L.dropout(\n",
    "            x=cls_feats,\n",
    "            dropout_prob=0.1,\n",
    "            dropout_implementation=\"upscale_in_train\"\n",
    "        )\n",
    "\n",
    "        logits = L.fc(\n",
    "            input=cls_feats,\n",
    "            size=self.hparam['num_label'],\n",
    "            param_attr=F.ParamAttr(\n",
    "                name=\"cls_out_w\",\n",
    "                initializer=F.initializer.TruncatedNormal(scale=0.02)),\n",
    "            bias_attr=F.ParamAttr(\n",
    "                name=\"cls_out_b\", initializer=F.initializer.Constant(0.))\n",
    "        )\n",
    "\n",
    "        propeller.summary.histogram('pred', logits)\n",
    "\n",
    "        if self.mode is propeller.RunMode.PREDICT:\n",
    "            probs = L.softmax(logits)\n",
    "            return probs\n",
    "        else:\n",
    "            return logits\n",
    "\n",
    "    def loss(self, predictions, labels):\n",
    "        \"\"\"Mean softmax cross-entropy between the logits and the integer labels.\"\"\"\n",
    "        ce_loss = L.softmax_with_cross_entropy(logits=predictions, label=labels)\n",
    "        return L.mean(x=ce_loss)\n",
    "\n",
    "    def backward(self, loss):\n",
    "        \"\"\"Attach ERNIE's optimizer (linear warmup + decay, weight decay) to `loss`.\"\"\"\n",
    "        scheduled_lr, _ = optimization(\n",
    "            loss=loss,\n",
    "            warmup_steps=int(self.run_config.max_steps * self.hparam['warmup_proportion']),\n",
    "            num_train_steps=self.run_config.max_steps,\n",
    "            learning_rate=self.hparam['learning_rate'],\n",
    "            train_program=F.default_main_program(),\n",
    "            startup_prog=F.default_startup_program(),\n",
    "            weight_decay=self.hparam['weight_decay'],\n",
    "            scheduler=\"linear_warmup_decay\",)\n",
    "        propeller.summary.scalar('lr', scheduled_lr)\n",
    "\n",
    "    def metrics(self, predictions, label):\n",
    "        \"\"\"Return {'acc': accuracy} of the argmax prediction against `label`.\"\"\"\n",
    "        predictions = L.argmax(predictions, axis=1)\n",
    "        predictions = L.unsqueeze(predictions, axes=[1])\n",
    "        acc = propeller.metrics.Acc(label, predictions)\n",
    "        return {'acc': acc}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define some utility functions.\n",
    "\n",
    "def build_2_pair(seg_a, seg_b):\n",
    "    \"\"\"Build `[CLS] seg_a [SEP] seg_b [SEP]` token ids and the matching sentence-type ids.\n",
    "\n",
    "    Both outputs are truncated from the end to MAX_SEQLEN (deterministic truncation).\n",
    "    \"\"\"\n",
    "    token_type_a = np.zeros_like(seg_a, dtype=np.int64)\n",
    "    token_type_b = np.ones_like(seg_b, dtype=np.int64)\n",
    "    sen_emb = np.concatenate([[cls_id], seg_a, [sep_id], seg_b, [sep_id]], 0)\n",
    "    token_type_emb = np.concatenate([[0], token_type_a, [0], token_type_b, [1]], 0)\n",
    "    sen_emb = sen_emb[:MAX_SEQLEN]\n",
    "    token_type_emb = token_type_emb[:MAX_SEQLEN]\n",
    "    return sen_emb, token_type_emb\n",
    "\n",
    "def expand_dims(*args):\n",
    "    \"\"\"Append a trailing axis of size 1 to every array in `args`.\"\"\"\n",
    "    return [np.expand_dims(i, -1) for i in args]\n",
    "\n",
    "def before_pad(seg_a, seg_b, label):\n",
    "    \"\"\"Per-example transform applied before batching: build the sentence pair.\"\"\"\n",
    "    sentence, segments = build_2_pair(seg_a, seg_b)\n",
    "    return sentence, segments, label\n",
    "\n",
    "def after_pad(sentence, segments, label):\n",
    "    \"\"\"Per-batch transform applied after padding: add a trailing unit axis.\"\"\"\n",
    "    sentence, segments, label = expand_dims(sentence, segments, label)\n",
    "    return sentence, segments, label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# a `propeller.paddle.data.Dataset` is built from FeatureColumns\n",
    "\n",
    "train_ds = feature_column.build_dataset('train', use_gz=False, data_dir='./xnli_data/train', shuffle=True, repeat=True) \\\n",
    "    .map(before_pad) \\\n",
    "    .padded_batch(BATCH, (0, 0, 0)) \\\n",
    "    .map(after_pad)\n",
    "\n",
    "dev_ds = feature_column.build_dataset('dev', use_gz=False, data_dir='./xnli_data/dev', shuffle=False, repeat=False) \\\n",
    "    .map(before_pad) \\\n",
    "    .padded_batch(BATCH, (0, 0, 0)) \\\n",
    "    .map(after_pad)\n",
    "\n",
    "shapes = ([-1, MAX_SEQLEN, 1], [-1, MAX_SEQLEN, 1], [-1, 1])\n",
    "types = ('int64', 'int64', 'int64')\n",
    "train_ds.data_shapes = shapes\n",
    "train_ds.data_types = types\n",
    "dev_ds.data_shapes = shapes\n",
    "dev_ds.data_types = types\n",
    "\n",
    "warm_start_dir = MODEL + '/params'\n",
    "# only the encoder, embedding, bias and scale variables are loaded from the pretrained model\n",
    "varname_to_warmstart = re.compile('^encoder.*w_0$|^encoder.*b_0$|^.*embedding$|^.*bias$|^.*scale$')\n",
    "ws = propeller.WarmStartSetting(\n",
    "    predicate_fn=lambda v: varname_to_warmstart.match(v.name) and os.path.exists(os.path.join(warm_start_dir, v.name)),\n",
    "    from_dir=warm_start_dir\n",
    ")\n",
    "\n",
    "# propeller will export the best-performing model; the criterion is up to you.\n",
    "# here we pick the model with the highest evaluation accuracy.\n",
    "# `BestInferenceModelExporter` is used to export servable models\n",
    "best_inference_exporter = propeller.train.exporter.BestInferenceModelExporter(\n",
    "    os.path.join(OUTPUT_DIR, 'best'),\n",
    "    cmp_fn=lambda old, new: new['eval']['acc'] > old['eval']['acc'])\n",
    "# `BestExporter` is used to export a restartable checkpoint, so that we can restore from it and check test-set accuracy.\n",
    "best_exporter = propeller.train.exporter.BestExporter(\n",
    "    os.path.join(OUTPUT_DIR, 'best_model'),\n",
    "    cmp_fn=lambda old, new: new['eval']['acc'] > old['eval']['acc'])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ERNIE 1.0 config\n",
    "with open(MODEL + '/ernie_config.json', encoding='utf8') as config_file:\n",
    "    ernie_config = propeller.HParams(**json.load(config_file))\n",
    "\n",
    "# default values for terms of the official (v2) config\n",
    "ernie_v2_config = propeller.HParams(**{\n",
    "    \"sent_type_vocab_size\": None,\n",
    "    \"use_task_id\": False,\n",
    "    \"task_id\": 0,\n",
    "})\n",
    "\n",
    "# train schema\n",
    "train_config = propeller.HParams(**{\n",
    "    \"warmup_proportion\": 0.1,\n",
    "    \"weight_decay\": 0.01,\n",
    "    \"use_fp16\": 0,\n",
    "    \"learning_rate\": LR,\n",
    "    \"num_label\": 3,\n",
    "    \"batch_size\": BATCH,\n",
    "})\n",
    "\n",
    "config = ernie_config.join(ernie_v2_config).join(train_config)\n",
    "\n",
    "run_config = propeller.RunConfig(\n",
    "    model_dir=OUTPUT_DIR,\n",
    "    max_steps=EPOCH * train_data_size // BATCH,  # integer number of training steps\n",
    "    skip_steps=10,\n",
    "    eval_steps=1000,\n",
    "    save_steps=1000,\n",
    "    log_steps=10,\n",
    "    max_ckpt=3\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Finetune and Eval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `train_and_eval` takes keyword args only\n",
    "# we are now ready to train\n",
    "hooks = [propeller.train.TqdmNotebookProgressBarHook(run_config.max_steps)]  # to show the progress bar, you need to `pip install tqdm ipywidgets`\n",
    "propeller.train_and_eval(\n",
    "    model_class_or_model_fn=ClassificationErnieModel,  # **careful**: pass the class itself, propeller will instantiate it.\n",
    "    params=config,\n",
    "    run_config=run_config,\n",
    "    train_dataset=train_ds,\n",
    "    eval_dataset=dev_ds,\n",
    "    warm_start_setting=ws,\n",
    "    exporters=[best_exporter, best_inference_exporter],\n",
    "    train_hooks=hooks,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# after training you might want to check your model performance on the test set.\n",
    "# let's do this via `propeller.Learner.predict`.\n",
    "# keep in mind that the best-performing checkpoint was exported to OUTPUT_DIR/best_model during the `train_and_eval` phase.\n",
    "\n",
    "best_model_dir = os.path.join(OUTPUT_DIR, 'best_model')\n",
    "best_ckpts = [file for file in os.listdir(best_model_dir) if 'model' in file]\n",
    "assert best_ckpts, 'no checkpoint found in %s' % best_model_dir\n",
    "best_model_path = os.path.join(best_model_dir, best_ckpts[0])\n",
    "print('best checkpoint: %s' % best_model_path)\n",
    "\n",
    "with open('./xnli_data/test/part.0', 'rb') as test_file:\n",
    "    true_label = [label_map[line.strip().split(b'\\t')[-1]] for line in test_file]\n",
    "\n",
    "def drop_label(sentence, segments, label):\n",
    "    \"\"\"Drop the label column: the predictor is only fed (sentence, segments).\"\"\"\n",
    "    return sentence, segments\n",
    "\n",
    "test_ds = feature_column.build_dataset('test', use_gz=False, data_dir='./xnli_data/test', shuffle=False, repeat=False) \\\n",
    "    .map(before_pad) \\\n",
    "    .padded_batch(BATCH, (0, 0, 0)) \\\n",
    "    .map(after_pad) \\\n",
    "    .map(drop_label)\n",
    "test_ds.data_shapes = shapes[:2]\n",
    "test_ds.data_types = types[:2]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# restore from the exported best checkpoint instead of the latest one in OUTPUT_DIR\n",
    "best_run_config = propeller.RunConfig(\n",
    "    model_dir=best_model_dir,\n",
    "    max_steps=run_config.max_steps,\n",
    "    skip_steps=run_config.skip_steps,\n",
    "    eval_steps=run_config.eval_steps,\n",
    "    save_steps=run_config.save_steps,\n",
    "    log_steps=run_config.log_steps,\n",
    "    max_ckpt=run_config.max_ckpt\n",
    ")\n",
    "learner = propeller.Learner(ClassificationErnieModel, best_run_config, params=config)\n",
    "pred_labels = np.array([np.argmax(pred) for pred in learner.predict(test_ds, ckpt=-1)])\n",
    "true_labels = np.array(true_label)\n",
    "assert len(pred_labels) == len(true_labels), 'prediction count does not match label count'\n",
    "\n",
    "test_acc = (pred_labels == true_labels).sum() / len(true_labels)\n",
    "print('test accuracy:%.5f' % test_acc)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Serving\n",
    "Your model is now ready to serve!\n",
    "You can start a server with propeller, pointing it at a saved inference model (e.g. the one exported to `./output/best`):\n",
    "```script\n",
    "python -m propeller.tools.start_server -m /path/to/saved/model -p 8888\n",
    "```\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}