# finetune_ner.py (committed by chenxuyi)
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import os
import re
import time
from random import random
from functools import reduce, partial

import numpy as np
import multiprocessing
import logging
import six
import re

import paddle
import paddle.fluid as F
import paddle.fluid.layers as L


from model.ernie import ErnieModel
from optimization import optimization
import utils.data

from propeller import log
log.setLevel(logging.DEBUG)
import propeller.paddle as propeller

class SequenceLabelErnieModel(propeller.train.Model):
    """Propeller ``Model`` wrapper that fine-tunes ERNIE for sequence labeling (NER).

    The graph is an ``ErnieModel`` encoder followed by a per-token fully
    connected layer producing one logit per entry of ``hparam['label_list']``.
    """

    def __init__(self, hparam, mode, run_config):
        """Store configuration for later graph construction.

        Args:
            hparam: hyper-parameters. This class reads them both by key
                (``'label_list'``, ``'use_fp16'``, ``'warmup_proportion'``,
                ``'learning_rate'``, ``'weight_decay'``) and by attribute
                (``task_id``); it is also passed to ``ErnieModel`` as config.
            mode: run mode supplied by propeller; stored, not read here.
            run_config: propeller run config; ``max_steps`` is read in
                :meth:`backward`.
        """
        self.hparam = hparam
        self.mode = mode
        self.run_config = run_config
        self.num_label = len(hparam['label_list'])

    def forward(self, features):
        """Build the inference graph.

        Args:
            features: ``(src_ids, sent_ids, input_seqlen)``. The script
                declares these as int64 with shapes ``[-1, max_seqlen, 1]``,
                ``[-1, max_seqlen, 1]`` and ``[-1, 1]``.

        Returns:
            ``(logits, input_seqlen)``: per-token label scores from the
            output fc layer (``num_flatten_dims=2`` keeps the batch and
            sequence axes) and the unchanged sequence lengths.
        """
        src_ids, sent_ids, input_seqlen = features
        zero = L.fill_constant([1], dtype='int64', value=0)
        # Attention mask: 1.0 for real tokens, 0.0 for padding (assumes pad
        # id == 0). ``L.equal`` alone would flag the padding positions as the
        # attendable ones, i.e. the inverse of what the encoder needs.
        # NOTE(review): ERNIE's encoder is expected to turn this into an
        # attention bias of (mask * mask^T - 1) * 10000 — confirm against
        # model/ernie.py.
        input_mask = L.cast(L.logical_not(L.equal(src_ids, zero)), 'float32')
        # src_ids already carries a trailing unit axis, so input_mask does too.

        d_shape = L.shape(src_ids)
        seqlen = d_shape[1]
        batch_size = d_shape[0]
        # Position ids 0..seqlen-1, broadcast over the batch and given the
        # same trailing unit axis as src_ids.
        pos_ids = L.unsqueeze(L.range(0, seqlen, 1, dtype='int32'), axes=[0])
        pos_ids = L.expand(pos_ids, [batch_size, 1])
        pos_ids = L.unsqueeze(pos_ids, axes=[2])
        pos_ids = L.cast(pos_ids, 'int64')
        pos_ids.stop_gradient = True
        input_mask.stop_gradient = True
        # Every token gets the same constant task id; per the original author
        # this input is not used for anything at the moment.
        task_ids = L.zeros_like(src_ids) + self.hparam.task_id
        task_ids.stop_gradient = True

        model = ErnieModel(
            src_ids=src_ids,
            position_ids=pos_ids,
            sentence_ids=sent_ids,
            task_ids=task_ids,
            input_mask=input_mask,
            config=self.hparam,
            use_fp16=self.hparam['use_fp16'])

        enc_out = model.get_sequence_output()
        logits = L.fc(
            input=enc_out,
            size=self.num_label,
            num_flatten_dims=2,
            param_attr=F.ParamAttr(
                name="cls_seq_label_out_w",
                initializer=F.initializer.TruncatedNormal(scale=0.02)),
            bias_attr=F.ParamAttr(
                name="cls_seq_label_out_b",
                initializer=F.initializer.Constant(0.)))

        propeller.summary.histogram('pred', logits)

        return logits, input_seqlen

    def loss(self, predictions, labels):
        """Return the mean token-level softmax cross-entropy.

        Logits and labels are flattened so every token position becomes one
        classification example.

        NOTE(review): padded positions are included in the mean with whatever
        label value the padded batch was filled with (presumably 0, i.e.
        ``label_list[0]``); consider masking them using ``input_seqlen``.
        """
        logits, _ = predictions
        logits = L.flatten(logits, axis=2)
        labels = L.flatten(labels, axis=2)
        ce_loss = L.softmax_with_cross_entropy(logits=logits, label=labels)
        return L.mean(x=ce_loss)

    def backward(self, loss):
        """Attach ERNIE's optimizer and log the scheduled learning rate.

        Uses the ``linear_warmup_decay`` schedule over ``run_config.max_steps``
        steps, warming up for ``max_steps * hparam['warmup_proportion']`` steps.
        """
        scheduled_lr, _ = optimization(
            loss=loss,
            warmup_steps=int(self.run_config.max_steps * self.hparam['warmup_proportion']),
            num_train_steps=self.run_config.max_steps,
            learning_rate=self.hparam['learning_rate'],
            train_program=F.default_main_program(),
            startup_prog=F.default_startup_program(),
            weight_decay=self.hparam['weight_decay'],
            scheduler="linear_warmup_decay",)
        propeller.summary.scalar('lr', scheduled_lr)

    def metrics(self, predictions, label):
        """Return ``{'f1': ChunkF1}`` computed on the argmax label predictions.

        ``seqlen`` is handed to ``ChunkF1`` so it can (presumably) ignore the
        padded positions of each sequence.
        """
        pred, seqlen = predictions
        pred = L.argmax(pred, axis=-1)
        # Restore the trailing unit axis so pred lines up with ``label``.
        pred = L.unsqueeze(pred, axes=[-1])
        f1 = propeller.metrics.ChunkF1(label, pred, seqlen, self.num_label)
        return {'f1': f1}

def make_sequence_label_dataset(name, input_files, label_list, tokenizer, batch_size, max_seqlen, is_train):
    """Build a batched propeller dataset of BIO-labelled sentences read from files.

    Each line of every input file is ``tokens<TAB>labels``. Both columns are
    byte strings whose items are joined by ``\\x02``; lines with a different
    column count, or with mismatched/empty token and label sequences, are
    skipped.

    Args:
        name: value assigned to the returned dataset's ``name``.
        input_files: paths of the data files.
        label_list: label strings; a label's index is its id. Must contain 'O'.
        tokenizer: callable that splits one (bytes) token into sub-tokens.
        batch_size: batch size passed to ``padded_batch``.
        max_seqlen: maximum length including the added [CLS] and [SEP].
        is_train: when true, the file list repeats forever and is shuffled,
            and examples go through a shuffle buffer of 100.

    Returns:
        A dataset of padded batches of
        ``(token_ids, token_type_ids, input_seqlen, label_ids)`` passed through
        ``utils.data.expand_dims``.

    NOTE(review): token ids are looked up in the module-level ``vocab`` dict,
    which only exists when this file is run as a script.
    """
    label_map = {v: i for i, v in enumerate(label_list)}
    no_entity_id = label_map['O']
    # Items inside a column are separated by the \x02 control character
    # (ERNIE's NER data format). An empty separator would make every
    # ``bytes.split`` call raise ValueError.
    delimiter = b'\x02'

    def read_bio_data(filename):
        """Yield ``[tokens, labels]`` (lists of bytes) for each valid line of ``filename``."""
        ds = propeller.data.Dataset.from_file(filename)

        def gen():
            # Iterate with ``for`` rather than ``while 1: next(...)``: since
            # Python 3.7 (PEP 479) a StopIteration escaping a generator is
            # re-raised as RuntimeError instead of ending the iteration.
            for line in ds:
                cols = line.rstrip(b'\n').split(b'\t')
                # Validate the column count before indexing cols[1].
                if len(cols) != 2:
                    continue
                tokens = cols[0].split(delimiter)
                labels = cols[1].split(delimiter)
                if len(tokens) != len(labels) or len(tokens) == 0:
                    continue
                yield [tokens, labels]

        return propeller.data.Dataset.from_generator_func(gen)

    def reseg_token_label(dataset):
        """Re-tokenize every token and spread its label over the sub-tokens.

        The first sub-token keeps the original label; the remaining ones get
        the ``I-`` form of a ``B-`` label (any other label is copied as is).
        Tokens the tokenizer maps to nothing are dropped with their label.
        """
        def gen():
            for tokens, labels in dataset:
                assert len(tokens) == len(labels)
                ret_tokens = []
                ret_labels = []
                for token, label in zip(tokens, labels):
                    sub_token = tokenizer(token)
                    label = label.decode('utf8')
                    if len(sub_token) == 0:
                        continue
                    ret_tokens.extend(sub_token)
                    ret_labels.append(label)
                    if len(sub_token) < 2:
                        continue
                    sub_label = label
                    if label.startswith("B-"):
                        sub_label = "I-" + label[2:]
                    ret_labels.extend([sub_label] * (len(sub_token) - 1))

                assert len(ret_tokens) == len(ret_labels)
                yield ret_tokens, ret_labels

        return propeller.data.Dataset.from_generator_func(gen)

    def convert_to_ids(dataset):
        """Truncate, wrap with [CLS]/[SEP] and map tokens and labels to int64 id arrays."""
        def gen():
            for tokens, labels in dataset:
                # Leave room for the [CLS] and [SEP] markers.
                if len(tokens) > max_seqlen - 2:
                    tokens = tokens[: max_seqlen - 2]
                    labels = labels[: max_seqlen - 2]

                tokens = ['[CLS]'] + tokens + ['[SEP]']
                token_ids = [vocab[t] for t in tokens]
                # [CLS] and [SEP] are labelled as 'O'.
                label_ids = [no_entity_id] + [label_map[x] for x in labels] + [no_entity_id]
                token_type_ids = [0] * len(token_ids)
                input_seqlen = len(token_ids)

                token_ids = np.array(token_ids, dtype=np.int64)
                label_ids = np.array(label_ids, dtype=np.int64)
                token_type_ids = np.array(token_type_ids, dtype=np.int64)
                input_seqlen = np.array(input_seqlen, dtype=np.int64)

                yield token_ids, token_type_ids, input_seqlen, label_ids

        return propeller.data.Dataset.from_generator_func(gen)

    def after(*features):
        return utils.data.expand_dims(*features)

    dataset = propeller.data.Dataset.from_list(input_files)
    if is_train:
        dataset = dataset.repeat().shuffle(buffer_size=len(input_files))
    dataset = dataset.interleave(map_fn=read_bio_data, cycle_length=len(input_files), block_length=1)
    if is_train:
        dataset = dataset.shuffle(buffer_size=100)
    dataset = reseg_token_label(dataset)
    dataset = convert_to_ids(dataset)
    dataset = dataset.padded_batch(batch_size).map(after)
    dataset.name = name
    return dataset


def make_sequence_label_dataset_from_stdin(name, tokenizer, batch_size, max_seqlen):
    """Build a batched prediction dataset from raw lines on standard input.

    Each stdin line is read as bytes and must contain a single column of
    tokens joined by ``\\x02``; lines containing a tab are skipped.

    Args:
        name: value assigned to the returned dataset's ``name``.
        tokenizer: callable that splits one (bytes) token into sub-tokens.
        batch_size: batch size passed to ``padded_batch``.
        max_seqlen: maximum length including the added [CLS] and [SEP].

    Returns:
        A dataset of padded batches of ``(token_ids, token_type_ids,
        input_seqlen)`` passed through ``utils.data.expand_dims``.

    NOTE(review): skipped lines mean prediction output rows no longer line up
    one-to-one with input rows. Token ids are looked up in the module-level
    ``vocab`` dict, which only exists when this file is run as a script.
    """
    # Tokens are separated by the \x02 control character (ERNIE's NER data
    # format). An empty separator would make ``bytes.split`` raise ValueError.
    delimiter = b'\x02'

    def stdin_gen():
        # Read bytes so the same split logic works on Python 2 and 3.
        if six.PY3:
            source = sys.stdin.buffer
        else:
            source = sys.stdin
        while True:
            line = source.readline()
            if len(line) == 0:
                break
            yield line,

    def read_bio_data(ds):
        """Yield ``(tokens,)`` for every single-column line of ``ds``."""
        def gen():
            # ``for`` instead of ``while 1: next(...)`` — see PEP 479: a
            # StopIteration escaping a generator becomes RuntimeError on 3.7+.
            for (line,) in ds:
                cols = line.rstrip(b'\n').split(b'\t')
                if len(cols) != 1:
                    continue
                tokens = cols[0].split(delimiter)
                if len(tokens) == 0:
                    continue
                yield tokens,

        return propeller.data.Dataset.from_generator_func(gen)

    def reseg_token_label(dataset):
        """Re-tokenize every token, dropping tokens the tokenizer maps to nothing."""
        def gen():
            for (tokens,) in dataset:
                ret_tokens = []
                for token in tokens:
                    sub_token = tokenizer(token)
                    if len(sub_token) == 0:
                        continue
                    ret_tokens.extend(sub_token)
                yield ret_tokens,

        return propeller.data.Dataset.from_generator_func(gen)

    def convert_to_ids(dataset):
        """Truncate, wrap with [CLS]/[SEP] and map tokens to int64 id arrays."""
        def gen():
            for (tokens,) in dataset:
                # Leave room for the [CLS] and [SEP] markers.
                if len(tokens) > max_seqlen - 2:
                    tokens = tokens[: max_seqlen - 2]

                tokens = ['[CLS]'] + tokens + ['[SEP]']
                token_ids = [vocab[t] for t in tokens]
                token_type_ids = [0] * len(token_ids)
                input_seqlen = len(token_ids)

                token_ids = np.array(token_ids, dtype=np.int64)
                token_type_ids = np.array(token_type_ids, dtype=np.int64)
                input_seqlen = np.array(input_seqlen, dtype=np.int64)
                yield token_ids, token_type_ids, input_seqlen

        return propeller.data.Dataset.from_generator_func(gen)

    def after(*features):
        return utils.data.expand_dims(*features)

    dataset = propeller.data.Dataset.from_generator_func(stdin_gen)
    dataset = read_bio_data(dataset)
    dataset = reseg_token_label(dataset)
    dataset = convert_to_ids(dataset)
    dataset = dataset.padded_batch(batch_size).map(after)
    dataset.name = name
    return dataset


if __name__ == '__main__':
    # Entry point: train/evaluate on <data_dir>/{train,dev,test}, or, with
    # --do_predict, label lines read from stdin using the latest checkpoint.
    parser = propeller.ArgumentParser('NER model with ERNIE')
    parser.add_argument('--max_seqlen', type=int, default=128)
    parser.add_argument('--data_dir', type=str, required=True)
    parser.add_argument('--vocab_file', type=str, required=True)
    parser.add_argument('--do_predict', action='store_true')
    parser.add_argument('--use_sentence_piece_vocab', action='store_true')
    parser.add_argument('--warm_start_from', type=str)
    args = parser.parse_args()
    run_config = propeller.parse_runconfig(args)
    hparams = propeller.parse_hparam(args)

    # One vocabulary entry per line: the token is the first tab-separated
    # field and its id is the line number. ``vocab`` stays module-level
    # because the dataset builders look it up globally.
    with open(args.vocab_file, 'r', encoding='utf8') as vocab_file:
        vocab = {j.strip().split('\t')[0]: i for i, j in enumerate(vocab_file)}
    tokenizer = utils.data.CharTokenizer(vocab, sentencepiece_style_vocab=args.use_sentence_piece_vocab)
    # These lookups raise KeyError early if a special token is missing.
    sep_id = vocab['[SEP]']
    cls_id = vocab['[CLS]']
    unk_id = vocab['[UNK]']
    pad_id = vocab['[PAD]']

    label_list = ['B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'O']
    hparams['label_list'] = label_list

    if not args.do_predict:
        def list_split_files(split):
            """Return the paths of all files in ``<data_dir>/<split>``."""
            split_dir = os.path.join(args.data_dir, split)
            return [os.path.join(split_dir, filename) for filename in os.listdir(split_dir)]

        train_input_files = list_split_files('train')
        dev_input_files = list_split_files('dev')
        test_input_files = list_split_files('test')

        train_ds = make_sequence_label_dataset(name='train',
                                               input_files=train_input_files,
                                               label_list=label_list,
                                               tokenizer=tokenizer,
                                               batch_size=hparams.batch_size,
                                               max_seqlen=args.max_seqlen,
                                               is_train=True)
        dev_ds = make_sequence_label_dataset(name='dev',
                                             input_files=dev_input_files,
                                             label_list=label_list,
                                             tokenizer=tokenizer,
                                             batch_size=hparams.batch_size,
                                             max_seqlen=args.max_seqlen,
                                             is_train=False)
        test_ds = make_sequence_label_dataset(name='test',
                                              input_files=test_input_files,
                                              label_list=label_list,
                                              tokenizer=tokenizer,
                                              batch_size=hparams.batch_size,
                                              max_seqlen=args.max_seqlen,
                                              is_train=False)

        # (token_ids, token_type_ids, input_seqlen, label_ids)
        shapes = ([-1, args.max_seqlen, 1], [-1, args.max_seqlen, 1], [-1, 1], [-1, args.max_seqlen, 1])
        types = ('int64', 'int64', 'int64', 'int64')
        for ds in (train_ds, dev_ds, test_ds):
            ds.data_shapes = shapes
            ds.data_types = types

        # Only warm-start when a directory was given: without one the
        # predicate would call os.path.join(None, ...) and raise TypeError.
        ws = None
        if args.warm_start_from is not None:
            varname_to_warmstart = re.compile(r'^encoder.*[wb]_0$|^.*embedding$|^.*bias$|^.*scale$|^pooled_fc.[wb]_0$')
            warm_start_dir = args.warm_start_from
            ws = propeller.WarmStartSetting(
                predicate_fn=lambda v: varname_to_warmstart.match(v.name) and os.path.exists(os.path.join(warm_start_dir, v.name)),
                from_dir=warm_start_dir)

        # Export the inference model whenever the dev F1 improves.
        best_exporter = propeller.train.exporter.BestInferenceModelExporter(os.path.join(run_config.model_dir, 'best'), cmp_fn=lambda old, new: new['dev']['f1'] > old['dev']['f1'])
        propeller.train.train_and_eval(
            model_class_or_model_fn=SequenceLabelErnieModel,
            params=hparams,
            run_config=run_config,
            train_dataset=train_ds,
            eval_dataset={'dev': dev_ds, 'test': test_ds},
            warm_start_setting=ws,
            exporters=[best_exporter])

        # Report dev/test metrics of the best checkpoint (skipping losses).
        # NOTE(review): relies on the exporter's private ``_best`` attribute.
        best = getattr(best_exporter, '_best', None)
        if best is not None:
            for k in best['dev'].keys():
                if 'loss' in k:
                    continue
                dev_v = best['dev'][k]
                test_v = best['test'][k]
                print('dev_%s\t%.5f\ntest_%s\t%.5f' % (k, dev_v, k, test_v))
    else:
        predict_ds = make_sequence_label_dataset_from_stdin(name='pred',
                                                            tokenizer=tokenizer,
                                                            batch_size=hparams.batch_size,
                                                            max_seqlen=args.max_seqlen)

        # (token_ids, token_type_ids, input_seqlen)
        shapes = ([-1, args.max_seqlen, 1], [-1, args.max_seqlen, 1], [-1, 1])
        types = ('int64', 'int64', 'int64')

        predict_ds.data_shapes = shapes
        predict_ds.data_types = types

        rev_label_map = dict(enumerate(label_list))
        learner = propeller.Learner(SequenceLabelErnieModel, run_config, hparams)
        # NOTE(review): the sequence length is discarded, so labels are printed
        # for every position of ``pred`` (including [CLS]/[SEP] and any padding).
        for pred, _ in learner.predict(predict_ds, ckpt=-1):
            pred_str = ' '.join([rev_label_map[idx] for idx in np.argmax(pred, 1).tolist()])
            print(pred_str)