# utils.py
"""
Arguments for configuration
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six
import argparse
import io
import sys
import random
import numpy as np
import os

import paddle
import paddle.fluid as fluid


def str2bool(v):
    """
    Interpret a command-line string as a boolean.

    Returns True when the lower-cased input is "true", "t" or "1",
    and False for every other string.
    """
    # argparse cannot turn strings such as "true" / "False" into Python
    # booleans by itself, so the comparison is done explicitly here.
    truthy_spellings = ("true", "t", "1")
    lowered = v.lower()
    return lowered in truthy_spellings


class ArgumentGroup(object):
    """
    Thin wrapper around an argparse argument group.

    The group is created on the given parser with the supplied title and
    description; options are then registered through ``add_arg``.
    """

    def __init__(self, parser, title, des):
        group = parser.add_argument_group(title=title, description=des)
        self._group = group

    def add_arg(self, name, type, default, help, **kwargs):
        """
        Register the option ``--<name>`` on the wrapped group.

        ``type`` is the converter applied by argparse, ``default`` the value
        used when the option is absent, and ``help`` the help text, to which
        the default value is appended. Any extra keyword arguments are
        forwarded unchanged to ``add_argument``.
        """
        # A plain `bool` converter is replaced by str2bool, which parses the
        # textual spellings of true/false.
        if type == bool:
            type = str2bool
        flag = "--" + name
        help_text = help + ' Default: %(default)s.'
        self._group.add_argument(
            flag, default=default, type=type, help=help_text, **kwargs)


def print_arguments(args):
    """
    Print every attribute of ``args`` to stdout as ``name: value``.

    ``args`` is any object supported by ``vars()`` (for example an
    ``argparse.Namespace``). Entries are printed in sorted order of
    attribute name, framed by a header line and a footer line.
    """
    print('-----------  Configuration Arguments -----------')
    # dict.items() is available on both Python 2 and 3, so no compatibility
    # shim is needed just to iterate the attribute dictionary.
    for arg, value in sorted(vars(args).items()):
        print('%s: %s' % (arg, value))
    print('------------------------------------------------')


def init_checkpoint(exe, init_checkpoint_path, main_program):
    """
    Load persistable variables of ``main_program`` from a checkpoint directory.

    Asserts that ``init_checkpoint_path`` exists, then asks
    ``fluid.io.load_vars`` to load every variable that is persistable and
    has a file named after it inside that directory. A confirmation line
    is printed afterwards.
    """
    ckpt_dir = init_checkpoint_path
    assert os.path.exists(ckpt_dir), "[%s] cann't be found." % ckpt_dir

    def _is_saved_persistable(var):
        """
        True only for persistable variables with a file in the checkpoint dir
        """
        if fluid.io.is_persistable(var):
            saved_file = os.path.join(ckpt_dir, var.name)
            return os.path.exists(saved_file)
        return False

    fluid.io.load_vars(
        exe,
        ckpt_dir,
        main_program=main_program,
        predicate=_is_saved_persistable)
    print("Load model from {}".format(ckpt_dir))

def data_reader(file_path, word_dict, num_examples, phrase, epoch, max_seq_len):
    """
    Convert a tab-separated text/label file into fixed-length id sequences.

    Each line of ``file_path`` (read as UTF-8) is expected to hold two
    tab-separated columns: space-separated words and an integer label.
    Lines starting with ``text_a`` are skipped, and lines that do not have
    exactly two columns are reported on stderr and skipped.

    Args:
        file_path: path of the data file.
        word_dict: mapping from word to id; words missing from it are
            mapped to ``len(word_dict)``.
        num_examples: dict updated in place; ``num_examples[phrase]`` is set
            to the number of examples read.
        phrase: name of the data split; the examples are shuffled once when
            it equals ``"train"``.
        epoch: number of passes the returned reader makes over the data.
        max_seq_len: every id sequence is padded with 0 or truncated to
            exactly this length.

    Returns:
        A zero-argument generator function yielding
        ``(word_ids, label, seq_len)`` tuples, where ``seq_len`` is the
        number of real (non-padding) ids, capped at ``max_seq_len``.

    Raises:
        ValueError: if the label column of a two-column line is not an
            integer.
    """
    # NOTE(review): the unknown-word id is one past the ids stored in
    # word_dict, and the padding id 0 may coincide with a real word id --
    # confirm both against the vocabulary and embedding size in use.
    unk_id = len(word_dict)
    pad_id = 0
    all_data = []
    with io.open(file_path, "r", encoding='utf8') as fin:
        for line in fin:
            if line.startswith('text_a'):
                continue
            cols = line.strip().split("\t")
            if len(cols) != 2:
                # Terminate the notice so consecutive notices do not run
                # together on one stderr line.
                sys.stderr.write("[NOTICE] Error Format Line!\n")
                continue
            label = int(cols[1])
            wids = [word_dict.get(x, unk_id) for x in cols[0].split(" ")]
            seq_len = len(wids)
            if seq_len < max_seq_len:
                wids.extend([pad_id] * (max_seq_len - seq_len))
            else:
                wids = wids[:max_seq_len]
                seq_len = max_seq_len
            all_data.append((wids, label, seq_len))

    if phrase == "train":
        random.shuffle(all_data)

    num_examples[phrase] = len(all_data)

    def reader():
        """
        Yield every example once per epoch, in the stored order
        """
        for _ in range(epoch):
            for doc, label, seq_len in all_data:
                yield doc, label, seq_len

    return reader

def load_vocab(file_path):
    """
    Build a word-to-id dictionary from a vocabulary file.

    The file is read as UTF-8, one entry per line. Each distinct stripped
    line receives the next free id in order of first appearance, starting
    at 0. Finally the key "<unk>" is set to the size of the dictionary at
    that point.
    """
    token_to_id = {}
    with io.open(file_path, 'r', encoding='utf8') as vocab_file:
        for raw_line in vocab_file:
            token = raw_line.strip()
            if token in token_to_id:
                continue
            # Ids are handed out densely, so the next id is the current size.
            token_to_id[token] = len(token_to_id)
    token_to_id["<unk>"] = len(token_to_id)
    return token_to_id


def init_pretraining_params(exe,
                            pretraining_params_path,
                            main_program,
                            use_fp16=False):
    """
    Load parameters of a pretrained model (NOT moment, learning_rate, etc.).

    Asserts that ``pretraining_params_path`` exists, then asks
    ``fluid.io.load_vars`` to load every variable of ``main_program`` that
    is a ``fluid.framework.Parameter`` and has a file named after it inside
    that directory. ``use_fp16`` is accepted but not used by this function.
    """
    params_dir = pretraining_params_path
    assert os.path.exists(params_dir), "[%s] cann't be found." % params_dir

    def _is_saved_parameter(var):
        # Only true Parameters qualify, and only when a saved file exists.
        if isinstance(var, fluid.framework.Parameter):
            saved_file = os.path.join(params_dir, var.name)
            return os.path.exists(saved_file)
        return False

    fluid.io.load_vars(
        exe,
        params_dir,
        main_program=main_program,
        predicate=_is_saved_parameter)
    print("Load pretraining parameters from {}.".format(params_dir))