import os
import shutil
import sys
import argparse
import functools
from functools import partial
import numpy as np
import paddle
import paddle.nn as nn
from paddle.io import Dataset, BatchSampler, DataLoader
from paddle.metric import Metric, Accuracy, Precision, Recall

from paddlenlp.transformers import AutoModelForTokenClassification, AutoTokenizer
from paddlenlp.datasets import load_dataset
from paddlenlp.data import Stack, Tuple, Pad
from paddlenlp.data.sampler import SamplerHelper
from paddlenlp.metrics import Mcc, PearsonAndSpearman
from paddleslim.auto_compression.config_helpers import load_config
from paddleslim.auto_compression.compressor import AutoCompression


def argsparser():
    """Build the command-line argument parser for this script.

    Returns:
        argparse.ArgumentParser with --config_path (required),
        --save_dir and --eval options.
    """

    def str2bool(v):
        # argparse's `type=bool` treats ANY non-empty string as True
        # (so `--eval False` would enable eval mode). Parse the common
        # spellings explicitly instead.
        if isinstance(v, bool):
            return v
        if v.lower() in ('yes', 'true', 't', 'y', '1'):
            return True
        if v.lower() in ('no', 'false', 'f', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(
            "Boolean value expected, got '{}'.".format(v))

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        '--config_path',
        type=str,
        default=None,
        help="path of compression strategy config.",
        required=True)
    parser.add_argument(
        '--save_dir',
        type=str,
        default='output',
        help="directory to save compressed model.")
    parser.add_argument(
        '--eval',
        type=str2bool,
        default=False,
        help="whether validate the model only.")
    return parser


# Evaluation metric per supported task: cola uses Matthews correlation,
# sts-b is a regression task scored with Pearson/Spearman, and all the
# remaining GLUE/CLUE tasks use plain accuracy.
METRIC_CLASSES = {
    "cola": Mcc,
    "sst-2": Accuracy,
    "sts-b": PearsonAndSpearman,
    "mnli": Accuracy,
    "qnli": Accuracy,
    "rte": Accuracy,
    "afqmc": Accuracy,
    "tnews": Accuracy,
    "iflytek": Accuracy,
    "ocnli": Accuracy,
    "cmnli": Accuracy,
    "cluewsc2020": Accuracy,
    "csl": Accuracy,
}


def convert_example(example,
                    tokenizer,
                    label_list,
                    max_seq_length=512,
                    is_test=False):
    """Convert a raw GLUE/CLUE example into tokenized model features.

    Args:
        example (dict): one raw record from the dataset.
        tokenizer: PaddleNLP tokenizer. For the CLUE branch, ``None``
            returns the (possibly rewritten) raw example unchanged.
        label_list: label vocabulary; ``None`` marks a regression task.
        max_seq_length (int): max sequence length passed to the tokenizer.
        is_test (bool): when True the label is omitted from the output.

    Returns:
        ``(input_ids, token_type_ids)`` when ``is_test`` is True, else
        ``(input_ids, token_type_ids, label)``.

    NOTE(review): relies on the module-level ``global_config`` being
    initialized (done in ``main()``) before any call.
    """
    assert global_config['dataset'] in [
        'glue', 'clue'
    ], "This demo only supports for dataset glue or clue"
    if global_config['dataset'] == 'glue':
        if not is_test:
            # `label_list == None` is for regression task
            label_dtype = "int64" if label_list else "float32"
            # Get the label
            label = example['labels']
            label = np.array([label], dtype=label_dtype)
        # Convert raw text to feature
        example = tokenizer(example['sentence'], max_seq_len=max_seq_length)

        if not is_test:
            return example['input_ids'], example['token_type_ids'], label
        else:
            return example['input_ids'], example['token_type_ids']

    else:  #if global_config['dataset'] == 'clue':
        if not is_test:
            # `label_list == None` is for regression task
            label_dtype = "int64" if label_list else "float32"
            # Get the label
            example['label'] = np.array(
                example["label"], dtype="int64").reshape((-1, 1))
            label = example['label']
        # Convert raw text to feature
        if 'keyword' in example:  # CSL
            # CSL provides a keyword list; join it into a first sentence
            # paired with the abstract.
            sentence1 = " ".join(example['keyword'])
            example = {
                'sentence1': sentence1,
                'sentence2': example['abst'],
                'label': example['label']
            }
        elif 'target' in example:  # wsc
            text, query, pronoun, query_idx, pronoun_idx = example[
                'text'], example['target']['span1_text'], example['target'][
                    'span2_text'], example['target']['span1_index'], example[
                        'target']['span2_index']
            text_list = list(text)
            # Sanity-check that the recorded span indices really point at
            # the pronoun/query substrings in the text.
            assert text[pronoun_idx:(pronoun_idx + len(
                pronoun))] == pronoun, "pronoun: {}".format(pronoun)
            assert text[query_idx:(query_idx + len(query)
                                   )] == query, "query: {}".format(query)
            # Mark the pronoun span with "[" "]" and the query span with
            # "_" "_", inserting at the earlier span first so the later
            # indices are shifted by the two already-inserted characters.
            if pronoun_idx > query_idx:
                text_list.insert(query_idx, "_")
                text_list.insert(query_idx + len(query) + 1, "_")
                text_list.insert(pronoun_idx + 2, "[")
                text_list.insert(pronoun_idx + len(pronoun) + 2 + 1, "]")
            else:
                text_list.insert(pronoun_idx, "[")
                text_list.insert(pronoun_idx + len(pronoun) + 1, "]")
                text_list.insert(query_idx + 2, "_")
                text_list.insert(query_idx + len(query) + 2 + 1, "_")
            text = "".join(text_list)
            example['sentence'] = text
        if tokenizer is None:
            return example
        # Single-sentence vs sentence-pair tokenization.
        if 'sentence' in example:
            example = tokenizer(example['sentence'], max_seq_len=max_seq_length)
        elif 'sentence1' in example:
            example = tokenizer(
                example['sentence1'],
                text_pair=example['sentence2'],
                max_seq_len=max_seq_length)
        if not is_test:
            return example['input_ids'], example['token_type_ids'], label
        else:
            return example['input_ids'], example['token_type_ids']


def create_data_holder(task_name):
    """Build static-graph input placeholders for a GLUE/CLUE task.

    Args:
        task_name (str): task key; "sts-b" is the regression task.

    Returns:
        list: ``[input_ids, token_type_ids, label]`` placeholders; the
        label holder is float32 for sts-b and int64 otherwise.
    """
    # sts-b is regression (float targets); every other task is
    # classification with integer labels.
    label_dtype = "float32" if task_name == "sts-b" else "int64"

    input_ids = paddle.static.data(
        name="input_ids", shape=[-1, -1], dtype="int64")
    token_type_ids = paddle.static.data(
        name="token_type_ids", shape=[-1, -1], dtype="int64")
    label = paddle.static.data(name="label", shape=[-1, 1], dtype=label_dtype)

    return [input_ids, token_type_ids, label]


def reader():
    """Build train/dev DataLoaders for the configured GLUE/CLUE task.

    Reads model_dir, dataset, task_name, max_seq_length and batch_size
    from the module-level ``global_config`` (set up in ``main()``).

    Returns:
        tuple: ``(train_data_loader, dev_data_loader)``, both feed-list
        style loaders (``return_list=False``).
    """
    # Create the tokenizer and dataset

    tokenizer = AutoTokenizer.from_pretrained(global_config['model_dir'])

    train_ds, dev_ds = load_dataset(
        global_config['dataset'],
        global_config['task_name'],
        splits=('train', 'dev'))

    # is_test=True drops the label from the training features; the
    # train batchify function below only pads the two input fields.
    trans_func = partial(
        convert_example,
        tokenizer=tokenizer,
        label_list=train_ds.label_list,
        max_seq_length=global_config['max_seq_length'],
        is_test=True)

    train_ds = train_ds.map(trans_func, lazy=True)

    # Pad every field of a batch to the longest sample in that batch.
    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # token_type 
    ): fn(samples)

    train_batch_sampler = paddle.io.BatchSampler(
        train_ds, batch_size=global_config['batch_size'], shuffle=True)

    [input_ids, token_type_ids, labels] = create_data_holder(global_config[
        'task_name'])
    feed_list_name = []  # NOTE(review): unused; kept as-is.
    train_data_loader = DataLoader(
        dataset=train_ds,
        feed_list=[input_ids, token_type_ids],
        batch_sampler=train_batch_sampler,
        collate_fn=batchify_fn,
        num_workers=0,
        return_list=False)

    # Dev pipeline keeps the label (is_test defaults to False) so the
    # evaluation loops can compute metrics.
    dev_trans_func = partial(
        convert_example,
        tokenizer=tokenizer,
        label_list=train_ds.label_list,
        max_seq_length=global_config['max_seq_length'])
    dev_batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # token_type 
        Stack(dtype="int64" if train_ds.label_list else "float32")  # label
    ): fn(samples)
    dev_ds = dev_ds.map(dev_trans_func, lazy=True)
    dev_batch_sampler = paddle.io.BatchSampler(
        dev_ds, batch_size=global_config['batch_size'], shuffle=False)
    dev_data_loader = DataLoader(
        dataset=dev_ds,
        batch_sampler=dev_batch_sampler,
        collate_fn=dev_batchify_fn,
        num_workers=0,
        feed_list=[input_ids, token_type_ids, labels],
        return_list=False)

    return train_data_loader, dev_data_loader


def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
    """Evaluation callback handed to AutoCompression.

    Runs ``compiled_test_program`` over the module-level
    ``eval_dataloader`` and returns the accumulated value of the
    module-level ``metric``.

    Args:
        exe: static-graph Executor provided by the compressor.
        compiled_test_program: program to evaluate.
        test_feed_names: feed names, ordered (input_ids, token_type_ids).
        test_fetch_list: fetch targets; logits are expected first.

    Returns:
        The metric's ``accumulate()`` result.
    """
    metric.reset()
    for data in eval_dataloader():
        logits = exe.run(compiled_test_program,
                         feed={
                             test_feed_names[0]: data[0]['input_ids'],
                             test_feed_names[1]: data[0]['token_type_ids']
                         },
                         fetch_list=test_fetch_list)
        # paddle.metric works on dynamic-graph tensors, so switch to
        # dynamic mode just for the metric update, then switch back.
        paddle.disable_static()
        labels_pd = paddle.to_tensor(np.array(data[0]['label']).flatten())
        logits_pd = paddle.to_tensor(logits[0])
        correct = metric.compute(logits_pd, labels_pd)
        metric.update(correct)
        paddle.enable_static()
    res = metric.accumulate()
    return res


def eval():
    """Standalone evaluation: load the saved inference model and score it.

    Relies on the module-level ``global_config``, ``metric`` and
    ``eval_dataloader`` set up in ``main()``.

    NOTE(review): shadows the builtin ``eval``; renaming would require
    updating the call site in ``main()``.

    Returns:
        The metric's ``accumulate()`` result over the dev set.
    """
    # Derive the executor place from the current device string,
    # e.g. "gpu:0" -> "gpu".
    devices = paddle.device.get_device().split(':')[0]
    places = paddle.device._convert_to_place(devices)
    exe = paddle.static.Executor(places)
    val_program, feed_target_names, fetch_targets = paddle.static.load_inference_model(
        global_config['model_dir'],
        exe,
        model_filename=global_config['model_filename'],
        params_filename=global_config['params_filename'])
    print('Loaded model from: {}'.format(global_config['model_dir']))
    metric.reset()
    print('Evaluating...')
    for data in eval_dataloader():
        logits = exe.run(val_program,
                         feed={
                             feed_target_names[0]: data[0]['input_ids'],
                             feed_target_names[1]: data[0]['token_type_ids']
                         },
                         fetch_list=fetch_targets)
        # Metric computation happens in dynamic-graph mode; switch back
        # to static mode afterwards.
        paddle.disable_static()
        labels_pd = paddle.to_tensor(np.array(data[0]['label']).flatten())
        logits_pd = paddle.to_tensor(logits[0])
        correct = metric.compute(logits_pd, labels_pd)
        metric.update(correct)
        paddle.enable_static()
    res = metric.accumulate()
    return res


def apply_decay_param_fun(name):
    """Parameter-name filter installed into the optimizer builder.

    Args:
        name (str): a parameter's name.

    Returns:
        bool: True when the name contains "bias", "b_0" or "norm"
        (bias and normalization parameters), False otherwise.
    """
    return any(token in name for token in ("bias", "b_0", "norm"))


def main():
    """Entry point: load the config, build data, then evaluate or compress.

    Populates the module-level globals (``global_config``,
    ``train_dataloader``/``eval_dataloader``, ``metric``) that the
    reader and evaluation callbacks read. Reads the parsed CLI ``args``
    from module scope.
    """

    all_config = load_config(args.config_path)

    global global_config
    assert "Global" in all_config, "Key Global not found in config file."
    global_config = all_config["Global"]

    # Install the parameter-name filter for weight decay into the
    # optimizer builder section of the config.
    if 'TrainConfig' in all_config:
        all_config['TrainConfig']['optimizer_builder'][
            'apply_decay_param_fun'] = apply_decay_param_fun

    global train_dataloader, eval_dataloader
    train_dataloader, eval_dataloader = reader()

    global metric
    metric_class = METRIC_CLASSES[global_config['task_name']]
    metric = metric_class()

    # --eval: score the original model and exit without compressing.
    if args.eval:
        result = eval()
        print('Eval metric:', result)
        sys.exit(0)

    # HyperParameterOptimization strategies consume the raw dataloader
    # as the eval callback; otherwise pass the metric-computing function.
    ac = AutoCompression(
        model_dir=global_config['model_dir'],
        model_filename=global_config['model_filename'],
        params_filename=global_config['params_filename'],
        save_dir=args.save_dir,
        config=all_config,
        train_dataloader=train_dataloader,
        eval_callback=eval_function
        if 'HyperParameterOptimization' not in all_config else eval_dataloader,
        eval_dataloader=eval_dataloader)

    ac.compress()
    # Copy tokenizer/vocab side files (any name containing "json" or
    # "txt") next to the compressed model so save_dir is self-contained.
    for file_name in os.listdir(global_config['model_dir']):
        if 'json' in file_name or 'txt' in file_name:
            shutil.copy(
                os.path.join(global_config['model_dir'], file_name),
                args.save_dir)


if __name__ == '__main__':
    # The whole pipeline (reader, compression, evaluation) runs in
    # static-graph mode.
    paddle.enable_static()
    parser = argsparser()
    # `args` is read as a module-level global by main() and eval().
    args = parser.parse_args()
    main()