# coding:utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import shutil
from copy import deepcopy

import numpy as np
import paddle
import paddle.nn as nn
from paddle.io import DataLoader
import paddlehub as hub
from paddlehub.common.logger import logger
from paddlehub.module.module import moduleinfo
from paddlenlp.datasets import MapDataset
from paddlenlp.data import Stack, Tuple, Pad
from paddlenlp.metrics import Rouge1, Rouge2
from paddlenlp.transformers import ErnieTokenizer, ErnieForGeneration, LinearDecayWithWarmup

from .encode import convert_example, after_padding
from .decode import post_process, beam_search_infilling
from .model import StackModel


@moduleinfo(
    name="ernie_gen",
    version="1.1.0",
    summary="ERNIE-GEN is a multi-flow language generation framework for both pre-training and fine-tuning.",
    author="baidu",
    author_email="",
    type="nlp/text_generation",
)
class ErnieGen(hub.Module):
    def __init__(self):
        """
        initialize with the necessary elements
        """
        self.tokenizer = ErnieTokenizer.from_pretrained("ernie-1.0")
        # Reverse vocabulary: token id -> token string.
        self.rev_dict = self.tokenizer.vocab.idx_to_token
        # Vectorized id-to-token lookup so whole numpy arrays can be decoded at once.
        self.rev_lookup = np.vectorize(lambda token_id: self.rev_dict[token_id])
        # The network itself is built lazily via the `model` property.
        self._model = None

    @property
    def model(self):
        """Lazily build and cache the ERNIE generation network.

        The pretrained weights are only downloaded/loaded on first access,
        which keeps constructing the module cheap.
        """
        # Explicit `is None` check: relying on the truthiness of a model
        # object is less clear than testing the sentinel directly.
        if self._model is None:
            self._model = ErnieForGeneration.from_pretrained("ernie-1.0")
        return self._model

    def finetune(
            self,
            train_path,
            dev_path=None,
            save_dir="ernie_gen_result",
            init_ckpt_path=None,
            use_gpu=True,
            max_steps=500,
            batch_size=8,
            max_encode_len=50,
            max_decode_len=50,
            learning_rate=5e-5,
            warmup_proportion=0.1,
            weight_decay=0.1,
            noise_prob=0,
            label_smooth=0,
            beam_width=5,
            length_penalty=1.0,
            log_interval=100,
            save_interval=200,
    ):
        """
        finetune with the specified dataset.

        Args:
            train_path(str): the train dataset path.
            dev_path(str): the dev dataset path.
            save_dir(str): the model params and dev dataset predict result save path.
            init_ckpt_path(str): incremental training load path.
            use_gpu(bool): use gpu or not.
            max_steps(int): max training steps.
            batch_size(int): the batch size.
            max_encode_len(int): the max encode length.
            max_decode_len(int): the max decode length.
            learning_rate(float): the learning rate.
            warmup_proportion(float): the warmup proportion.
            weight_decay(float): the weight decay magnitude.
            noise_prob(float): the noise probability. see the ernie gen paper for details.
            label_smooth(float): the label smooth magnitude.
            beam_width(int): the beam size during evaluating the dev dataset.
            length_penalty(float): the length penalty during evaluating the dev dataset.
            log_interval(int): the log interval.
            save_interval(int): the save interval. dev set will be evaluated after saving.

        Return:
            result(dict): A Dictionary of shape::
                {
                    last_save_path(str): last model save path.
                    last_ppl(float): last model ppl.
                }
        """
        paddle.disable_static()
        paddle.set_device('gpu' if use_gpu else 'cpu')

        if init_ckpt_path is not None:
            logger.info('loading checkpoint from %s' % init_ckpt_path)
            sd = paddle.load(init_ckpt_path)
            self.model.set_state_dict(sd)

        train_dataset = self._load_dataset(train_path)
        # '[MASK]' marks the positions to be in-filled during generation.
        attn_id = self.tokenizer.vocab['[MASK]']
        trans_func = convert_example(
            tokenizer=self.tokenizer,
            attn_id=attn_id,
            tgt_type_id=1,
            max_encode_len=max_encode_len,
            max_decode_len=max_decode_len,
            noise_prob=noise_prob)

        train_dataset = train_dataset.map(trans_func)
        train_batch_sampler = paddle.io.BatchSampler(train_dataset, batch_size=batch_size, shuffle=True)
        batchify_fn = lambda samples, fn=Tuple(
            Pad(axis=0, pad_val=self.tokenizer.pad_token_id),  # src_ids
            Pad(axis=0, pad_val=self.tokenizer.pad_token_id),  # src_pids
            Pad(axis=0, pad_val=self.tokenizer.pad_token_type_id),  # src_tids
            Pad(axis=0, pad_val=self.tokenizer.pad_token_id),  # tgt_ids
            Pad(axis=0, pad_val=self.tokenizer.pad_token_id),  # tgt_pids
            Pad(axis=0, pad_val=self.tokenizer.pad_token_type_id),  # tgt_tids
            Pad(axis=0, pad_val=self.tokenizer.pad_token_id),  # attn_ids
            Pad(axis=0, pad_val=self.tokenizer.pad_token_id),  # tgt_labels
        ): after_padding(fn(samples))
        train_data_loader = DataLoader(
            dataset=train_dataset,
            batch_sampler=train_batch_sampler,
            collate_fn=batchify_fn,
            num_workers=0,
            return_list=True)

        if dev_path:
            dev_dataset = self._load_dataset(dev_path)
            dev_dataset = dev_dataset.map(trans_func)
            dev_data_loader = DataLoader(
                dataset=dev_dataset, batch_size=batch_size, collate_fn=batchify_fn, num_workers=0, return_list=True)

        label_num = self.model.word_emb.weight.shape[0]
        train_model = StackModel(self.model)
        lr_scheduler = LinearDecayWithWarmup(learning_rate, max_steps, warmup_proportion)
        # Generate parameter names needed to perform weight decay.
        # All bias and LayerNorm parameters are excluded.
        decay_params = [p.name for n, p in self.model.named_parameters() if not any(nd in n for nd in ["bias", "norm"])]
        optimizer = paddle.optimizer.AdamW(
            learning_rate=lr_scheduler,
            parameters=self.model.parameters(),
            weight_decay=weight_decay,
            grad_clip=nn.ClipGradByGlobalNorm(1.0),
            apply_decay_param_fun=lambda x: x in decay_params)

        rouge1 = Rouge1()
        rouge2 = Rouge2()
        global_step = 1
        # Stays None when save_dir is falsy; previously `save_path` could be
        # referenced unbound in the result dict and raise NameError.
        save_path = None
        if save_dir and not os.path.exists(save_dir):
            os.makedirs(save_dir)
        while True:
            for batch in train_data_loader:
                (src_ids, src_tids, src_pids, tgt_ids, tgt_tids, tgt_pids, attn_ids, mask_src_2_src, mask_tgt_2_srctgt,
                 mask_attn_2_srctgtattn, tgt_labels, _) = batch
                if label_smooth > 0.:
                    tgt_labels = nn.functional.label_smooth(
                        nn.functional.one_hot(tgt_labels, label_num), epsilon=label_smooth)
                tgt_pos = paddle.nonzero(attn_ids == attn_id)
                loss = train_model(src_ids, src_tids, src_pids, tgt_ids, tgt_tids, tgt_pids, attn_ids, mask_src_2_src,
                                   mask_tgt_2_srctgt, mask_attn_2_srctgtattn, tgt_labels, tgt_pos)
                loss.backward()
                optimizer.step()
                lr_scheduler.step()
                optimizer.clear_grad()

                if global_step % log_interval == 0 and paddle.distributed.get_rank() == 0:
                    loss_np = loss.numpy()
                    ppl = np.exp(loss_np)
                    logger.info('[step %d / %d]train loss %.5f, ppl %.5f, elr %.3e' % (global_step, max_steps, loss_np,
                                                                                       ppl, lr_scheduler.get_lr()))
                # `global_step > 0` dropped: global_step starts at 1.
                if save_dir and global_step % save_interval == 0:
                    loss_np = loss.numpy()
                    ppl = np.exp(loss_np)
                    # Use the ".pdparams" suffix consistently (the final
                    # checkpoint below already used it; this one said ".params").
                    save_name = "step_%s_ppl_%.5f.pdparams" % (global_step, ppl)
                    save_path = os.path.join(save_dir, save_name)
                    logger.info("save the model in %s" % save_path)
                    paddle.save(self.model.state_dict(), save_path)

                    if dev_path:
                        self._evaluate(self.model, dev_data_loader, self.tokenizer, rouge1, rouge2, attn_id,
                                       max_decode_len, max_encode_len, beam_width, length_penalty)

                if global_step >= max_steps:
                    break
                global_step += 1

            if global_step >= max_steps:
                break

        # Perplexity of the very last batch; computed unconditionally so it is
        # always defined for the returned summary (previously it could be
        # unbound when save_dir was falsy).
        loss_np = loss.numpy()
        ppl = np.exp(loss_np)
        if global_step % save_interval != 0:
            logger.info('[final step %d]train loss %.5f, ppl %.5f, elr %.3e' % (global_step, loss_np, ppl,
                                                                                lr_scheduler.get_lr()))
            if save_dir:
                save_name = "step_%s_ppl_%.5f.pdparams" % (global_step, ppl)
                save_path = os.path.join(save_dir, save_name)
                logger.info("save the model in %s" % save_path)
                paddle.save(self.model.state_dict(), save_path)

                if dev_path:
                    self._evaluate(self.model, dev_data_loader, self.tokenizer, rouge1, rouge2, attn_id, max_decode_len,
                                   max_encode_len, beam_width, length_penalty)

        result = {
            "last_save_path": "%s" % save_path,
            "last_ppl": ppl[0],
        }
        return result

    def export(self,
               params_path,
               module_name,
               author,
               max_encode_len=50,
               max_decode_len=50,
               version="1.0.0",
               summary="",
               author_email="",
               export_path="."):
        """
        export the model saved in the params_path to a hub module.

        Args:
            params_path(str): the model params save path.
            module_name(str): the module name.
            author(str): the author name.
            max_encode_len(int): the max encode length.
            max_decode_len(int): the max decode length.
            version(str): the version information.
            summary(str): the module brief introduction.
            author_email(str): the author email address.
            export_path(str): the module export path.

        Raises:
            FileNotFoundError: if `params_path` does not exist.
        """
        if not os.path.exists(params_path):
            raise FileNotFoundError("The path %s does not exist." % params_path)
        export_module_path = os.path.join(export_path, module_name)
        os.makedirs(export_module_path, exist_ok=True)
        logger.info("Begin export the model save in %s ..." % params_path)

        # Template files shipped with this module.
        init_path = os.path.join(self.directory, "template", "__init__.py")
        decode_path = os.path.join(self.directory, "template", "decode.py")
        module_temp_path = os.path.join(self.directory, "template", "module.temp")

        # Destination layout of the exported module.
        export_assets_path = os.path.join(export_module_path, "assets")
        export_params_path = os.path.join(export_module_path, "assets", "ernie_gen.pdparams")
        export_init_path = os.path.join(export_module_path, "__init__.py")
        export_decode_path = os.path.join(export_module_path, "decode.py")

        os.makedirs(export_assets_path, exist_ok=True)
        shutil.copyfile(init_path, export_init_path)
        shutil.copyfile(params_path, export_params_path)
        shutil.copyfile(decode_path, export_decode_path)

        module_path = os.path.join(export_module_path, "module.py")
        # Write with the same encoding the template is read with; otherwise a
        # non-ASCII summary/author crashes on platforms whose default file
        # encoding is not UTF-8.
        with open(module_temp_path, encoding="utf8") as ftemp, open(module_path, "w", encoding="utf8") as fmodule:
            content = ftemp.read().replace(r"{module_name}", module_name).replace(r"{author}", author).replace(
                r"{version}", version).replace(r"{summary}", summary).replace(r"{author_email}", author_email).replace(
                    r"{max_encode_len}", str(max_encode_len)).replace(r"{max_decode_len}", str(max_decode_len))
            fmodule.write(content)

        logger.info("The module has exported to %s" % os.path.abspath(export_module_path))

    def _evaluate(self, model, data_loader, tokenizer, rouge1, rouge2, attn_id, max_decode_len, max_encode_len,
                  beam_width, length_penalty):
        """Beam-search decode the dev set and log Rouge-1/Rouge-2 scores."""
        model.eval()

        vocab = tokenizer.vocab
        eos_id = vocab[tokenizer.sep_token]
        sos_id = vocab[tokenizer.cls_token]
        pad_id = vocab[tokenizer.pad_token]
        unk_id = vocab[tokenizer.unk_token]
        vocab_size = len(vocab)

        predicted_ids = []
        reference_ids = []
        logger.info("Evaluating...")
        for batch in data_loader:
            # Only the source side feeds decoding; the target fields are
            # ignored except for the raw labels kept as reference.
            src_ids, src_tids = batch[0], batch[1]
            raw_tgt_labels = batch[-1]
            # Use greedy_search_infilling or beam_search_infilling to get predictions
            output_ids = beam_search_infilling(
                model,
                src_ids,
                src_tids,
                eos_id=eos_id,
                sos_id=sos_id,
                attn_id=attn_id,
                pad_id=pad_id,
                unk_id=unk_id,
                vocab_size=vocab_size,
                max_decode_len=max_decode_len,
                max_encode_len=max_encode_len,
                beam_width=beam_width,
                length_penalty=length_penalty,
                tgt_type_id=1)

            for ids in output_ids.tolist():
                if eos_id in ids:
                    ids = ids[:ids.index(eos_id)]
                predicted_ids.append(ids[0])

            for ids in raw_tgt_labels.numpy().tolist():
                reference_ids.append(ids[:ids.index(eos_id)])

        score1 = rouge1.score(predicted_ids, reference_ids)
        score2 = rouge2.score(predicted_ids, reference_ids)
        logger.info("Rouge-1: %.5f ,Rouge-2: %.5f" % (score1 * 100, score2 * 100))

        # Decode a handful of examples for eyeballing in the debug log.
        reference_sentences = [''.join(map(post_process, vocab.to_tokens(ids))) for ids in reference_ids[:3]]
        evaluated_sentences = [''.join(map(post_process, vocab.to_tokens(ids))) for ids in predicted_ids[:3]]
        logger.debug(reference_sentences)
        logger.debug(evaluated_sentences)

        model.train()

    def _load_dataset(self, datafiles):
        """Load tab-separated data file(s) into `MapDataset`(s).

        Each line is expected to hold three tab-separated fields:
        an order/id column (ignored), the source tokens and the labels.

        Args:
            datafiles(str|list|tuple): a single path or a collection of paths.

        Returns:
            MapDataset for a single path, or a list of MapDataset for a
            list/tuple of paths.

        Raises:
            TypeError: for any other `datafiles` type (previously this fell
                through and silently returned None).
        """

        def read(data_path):
            with open(data_path, 'r', encoding='utf-8') as fp:
                # Stream line by line instead of readlines(): no need to pull
                # the whole file into memory first.
                for line in fp:
                    _, words, labels = line.strip('\n').split('\t')
                    yield {'tokens': words, 'labels': labels}

        if isinstance(datafiles, str):
            return MapDataset(list(read(datafiles)))
        if isinstance(datafiles, (list, tuple)):
            return [MapDataset(list(read(datafile))) for datafile in datafiles]
        raise TypeError("datafiles must be str, list or tuple, got %s" % type(datafiles).__name__)


if __name__ == "__main__":
    # Smoke test: finetune briefly on the bundled sample data, then export
    # the resulting checkpoint as a standalone hub module.
    ernie_gen = ErnieGen()
    finetune_result = ernie_gen.finetune(
        train_path='test_data/train.txt',
        dev_path='test_data/dev.txt',
        max_steps=30,
        batch_size=2,
        log_interval=10,
        save_interval=20)
    ernie_gen.export(
        params_path=finetune_result['last_save_path'], module_name="ernie_gen_test", author="test")