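# train_search.py
# DARTS-style architecture search for an AdaBERT student distilled from a
# BERT teacher on the MNLI task (PaddlePaddle dygraph mode).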
import numpy as np
import paddle.fluid as fluid
from paddleslim.teachers.bert.reader.cls import *
from paddleslim.nas.darts.search_space import AdaBERTClassifier
from paddle.fluid.dygraph.base import to_variable
from tqdm import tqdm
import os
import pickle

import logging
from paddleslim.common import AvgrageMeter, get_logger
logger = get_logger(__name__, level=logging.INFO)


def valid_one_epoch(model, valid_loader, epoch, log_freq):
    accs = AvgrageMeter()
    ce_losses = AvgrageMeter()
    # `model` may be a DataParallel wrapper; unwrap it so the student
    # sub-layer and the loss function are reachable either way.
    bare_model = model._layers if hasattr(model, "_layers") else model
    bare_model.student.eval()

    step_id = 0
    for valid_data in valid_loader():
        loss, acc, ce_loss, _, _ = bare_model.loss(valid_data, epoch)

        batch_size = valid_data[0].shape[0]
        ce_losses.update(ce_loss.numpy(), batch_size)
        accs.update(acc.numpy(), batch_size)
        step_id += 1
    return ce_losses.avg[0], accs.avg[0]


def train_one_epoch(model, train_loader, valid_loader, optimizer,
                    arch_optimizer, epoch, use_data_parallel, log_freq):
    total_losses = AvgrageMeter()
    accs = AvgrageMeter()
    ce_losses = AvgrageMeter()
    kd_losses = AvgrageMeter()
    val_accs = AvgrageMeter()
    # Unwrap DataParallel (if used) so the bare model is reachable.
    bare_model = model._layers if use_data_parallel else model
    bare_model.student.train()

    step_id = 0
    for train_data, valid_data in zip(train_loader(), valid_loader()):
        batch_size = train_data[0].shape[0]
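
        # Weight step of the DARTS bi-level update: train the student
        # parameters on a batch from the train split.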
        # make sure the sampled arch is the same on every GPU, otherwise an error will occur
        np.random.seed(step_id * 2 * (epoch + 1))
        total_loss, acc, ce_loss, kd_loss, _ = bare_model.loss(
            train_data, epoch)

        if use_data_parallel:
            total_loss = model.scale_loss(total_loss)
            total_loss.backward()
            model.apply_collective_grads()
        else:
            total_loss.backward()
        optimizer.minimize(total_loss)
        model.clear_gradients()
        total_losses.update(total_loss.numpy(), batch_size)
        accs.update(acc.numpy(), batch_size)
        ce_losses.update(ce_loss.numpy(), batch_size)
        kd_losses.update(kd_loss.numpy(), batch_size)

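        # Architecture step of the DARTS bi-level update: train the
        # architecture parameters (alphas) on a batch from the valid split.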
        # make sure the sampled arch is the same on every GPU, otherwise an error will occur
        np.random.seed(step_id * 2 * (epoch + 1) + 1)
        arch_loss, _, _, _, arch_logits = bare_model.loss(valid_data, epoch)

        if use_data_parallel:
            arch_loss = model.scale_loss(arch_loss)
            arch_loss.backward()
            model.apply_collective_grads()
        else:
            arch_loss.backward()
        arch_optimizer.minimize(arch_loss)
        model.clear_gradients()
        probs = fluid.layers.softmax(arch_logits[-1])
        val_acc = fluid.layers.accuracy(input=probs, label=valid_data[4])
        val_accs.update(val_acc.numpy(), batch_size)

        if step_id % log_freq == 0:
            logger.info(
                "Train Epoch {}, Step {}, Lr {:.6f} total_loss {:.6f}; ce_loss {:.6f}, kd_loss {:.6f}, train_acc {:.6f}, search_valid_acc {:.6f};".
                format(epoch, step_id,
                       optimizer.current_step_lr(), total_losses.avg[0],
                       ce_losses.avg[0], kd_losses.avg[0], accs.avg[0],
                       val_accs.avg[0]))

        step_id += 1


def main():
    # whether use multi-gpus
    use_data_parallel = False
    place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
        if use_data_parallel else fluid.CUDAPlace(0)

    BERT_BASE_PATH = "./data/pretrained_models/uncased_L-12_H-768_A-12"
    vocab_path = BERT_BASE_PATH + "/vocab.txt"
    data_dir = "./data/glue_data/MNLI/"
    teacher_model_dir = "./data/teacher_model/steps_23000"
    do_lower_case = True
    num_samples = 392702
    # number of samples in the augmented dataset:
    # num_samples = 8016987
    max_seq_len = 128
    batch_size = 128
    hidden_size = 768
    emb_size = 768
    max_layer = 8
    epoch = 80
    log_freq = 10
    device_num = fluid.dygraph.parallel.Env().nranks

    use_fixed_gumbel = False
    train_phase = "search_train"
    val_phase = "search_valid"
    # Half of the training set is used for weight updates ("search_train");
    # the other half for architecture updates ("search_valid").
    step_per_epoch = int(num_samples * 0.5 / (batch_size * device_num))

    with fluid.dygraph.guard(place):
        model = AdaBERTClassifier(
            3,
            n_layer=max_layer,
            hidden_size=hidden_size,
            emb_size=emb_size,
            teacher_model=teacher_model_dir,
            data_dir=data_dir,
            use_fixed_gumbel=use_fixed_gumbel)

        learning_rate = fluid.dygraph.CosineDecay(2e-2, step_per_epoch, epoch)

        # Student weights only: exclude the architecture parameters and the
        # teacher's parameters from the weight optimizer.
        arch_param_names = {a.name for a in model.arch_parameters()}
        teacher_param_names = {p.name for p in model.teacher.parameters()}
        model_parameters = [
            p for p in model.parameters()
            if p.name not in arch_param_names and
            p.name not in teacher_param_names
        ]

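        # Two optimizers, following DARTS: Momentum with L2 decay for the
        # student weights, Adam for the architecture parameters.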
        optimizer = fluid.optimizer.MomentumOptimizer(
            learning_rate,
            0.9,
            regularization=fluid.regularizer.L2DecayRegularizer(3e-4),
            parameter_list=model_parameters)

        arch_optimizer = fluid.optimizer.Adam(
            3e-4,
            0.5,
            0.999,
            regularization=fluid.regularizer.L2Decay(1e-3),
            parameter_list=model.arch_parameters())

        processor = MnliProcessor(
            data_dir=data_dir,
            vocab_path=vocab_path,
            max_seq_len=max_seq_len,
            do_lower_case=do_lower_case,
            in_tokens=False)

        train_reader = processor.data_generator(
            batch_size=batch_size,
            phase=train_phase,
            epoch=1,
            dev_count=1,
            shuffle=True)
        valid_reader = processor.data_generator(
            batch_size=batch_size,
            phase=val_phase,
            epoch=1,
            dev_count=1,
            shuffle=True)
        dev_reader = processor.data_generator(
            batch_size=batch_size,
            phase="dev",
            epoch=1,
            dev_count=1,
            shuffle=False)
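
        # Three splits: "search_train" drives weight updates, "search_valid"
        # drives architecture updates, and "dev" is held out for evaluation.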

        if use_data_parallel:
            train_reader = fluid.contrib.reader.distributed_batch_reader(
                train_reader)
            valid_reader = fluid.contrib.reader.distributed_batch_reader(
                valid_reader)

        train_loader = fluid.io.DataLoader.from_generator(
            capacity=128,
            use_double_buffer=True,
            iterable=True,
            return_list=True)
        valid_loader = fluid.io.DataLoader.from_generator(
            capacity=128,
            use_double_buffer=True,
            iterable=True,
            return_list=True)
        dev_loader = fluid.io.DataLoader.from_generator(
            capacity=128,
            use_double_buffer=True,
            iterable=True,
            return_list=True)

        train_loader.set_batch_generator(train_reader, places=place)
        valid_loader.set_batch_generator(valid_reader, places=place)
        dev_loader.set_batch_generator(dev_reader, places=place)

        if use_data_parallel:
            strategy = fluid.dygraph.parallel.prepare_context()
            model = fluid.dygraph.parallel.DataParallel(model, strategy)

        for epoch_id in range(epoch):
            train_one_epoch(model, train_loader, valid_loader, optimizer,
                            arch_optimizer, epoch_id, use_data_parallel,
                            log_freq)
            loss, acc = valid_one_epoch(model, dev_loader, epoch_id, log_freq)
            logger.info("dev set, ce_loss {:.6f}; acc: {:.6f};".format(loss,
                                                                       acc))

            if use_data_parallel:
                print(model._layers.student._encoder.alphas.numpy())
            else:
                print(model.student._encoder.alphas.numpy())
            print("=" * 100)


if __name__ == '__main__':
    main()
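
# Single-GPU usage: python train_search.py
# Multi-GPU search (a sketch; the launcher flags vary across Paddle releases):
# set use_data_parallel = True above, then launch with something like
#   python -m paddle.distributed.launch --selected_gpus=0,1,2,3 train_search.py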