finetune.py 11.6 KB
Newer Older
W
wuzewu 已提交
1
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
Z
Zeyu Chen 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

W
wuzewu 已提交
15 16 17 18
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

Z
Zeyu Chen 已提交
19 20 21
import os
import time

W
wuzewu 已提交
22 23
import paddle
import paddle.fluid as fluid
Z
Zeyu Chen 已提交
24
import paddlehub as hub
25
import numpy as np
Z
Zeyu Chen 已提交
26

W
wuzewu 已提交
27 28 29
from paddlehub.common.logger import logger
from paddlehub.finetune.strategy import BERTFinetuneStrategy, DefaultStrategy
from paddlehub.finetune.checkpoint import load_checkpoint, save_checkpoint
Z
Zeyu Chen 已提交
30
from paddlehub.finetune.evaluate import evaluate_cls_task, evaluate_seq_labeling_task
W
wuzewu 已提交
31
from visualdl import LogWriter
32 33


34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49
def _do_memory_optimization(task, config):
    """Optionally run Fluid memory optimization, then log estimated memory use.

    When ``config.enable_memory_optim`` is truthy, ``fluid.memory_optimize`` is
    applied to the default main program. The variables returned by
    ``task.metric_variable_names()`` are passed as ``skip_opt_set``, so they
    are excluded from the optimization. In every case the theoretical
    training memory range computed by ``fluid.contrib.memory_usage`` for
    ``config.batch_size`` is logged.

    Args:
        task: finetune task; must provide ``metric_variable_names()``.
        config: run configuration; ``enable_memory_optim`` and ``batch_size``
            are read.
    """
    if config.enable_memory_optim:
        logger.info("Memory optimization start...")
        task_var_name = task.metric_variable_names()
        logger.info(
            "Skip memory optimization on variables: {}".format(task_var_name))
        optimize_time_begin = time.time()
        fluid.memory_optimize(
            input_program=fluid.default_main_program(),
            # skip memory optimization on task metric variables
            skip_opt_set=task_var_name)
        time_used = time.time() - optimize_time_begin
        logger.info("Memory optimization done! Time elapsed %f sec" % time_used)

    lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
        program=fluid.default_main_program(), batch_size=config.batch_size)
    # No trailing comma here: it would wrap the call in a discarded tuple.
    logger.info("Theoretical memory usage in training: %.2f - %.2f %s" %
                (lower_mem, upper_mem, unit))


54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
def _finetune_seq_label_task(task,
                             data_reader,
                             feed_list,
                             config=None,
                             do_eval=False):
    """Finetune a sequence labeling task; evaluation reports F1, precision and recall.

    Args:
        task: finetune task providing ``main_program()``,
            ``startup_program()`` and ``variable("loss")``.
        data_reader: reader whose ``data_generator(batch_size=..., phase=...)``
            returns a callable that yields training batches.
        feed_list: variables handed to ``fluid.DataFeeder``.
        config: run configuration. It defaults to ``None`` but is
            dereferenced unconditionally, so a real config must be passed.
        do_eval: when True, run ``evaluate_seq_labeling_task`` on the "test"
            and "dev" phases every ``config.eval_interval`` steps, and on
            "dev" and "test" once training finishes.
    """
    main_program = task.main_program()
    startup_program = task.startup_program()
    loss = task.variable("loss")

    num_epoch = config.num_epoch
    batch_size = config.batch_size

    place, _ = hub.common.get_running_device_info(config)
    with fluid.program_guard(main_program, startup_program):
        exe = fluid.Executor(place=place)
        data_feeder = fluid.DataFeeder(feed_list=feed_list, place=place)

        # Select strategy
        if isinstance(config.strategy, hub.BERTFinetuneStrategy):
            # The return value (presumably the scheduled learning rate) is
            # not used by this loop.
            config.strategy.execute(loss, main_program, data_reader, config)
        elif isinstance(config.strategy, hub.DefaultStrategy):
            config.strategy.execute(loss)
        else:
            # TODO: add more finetune strategies
            logger.warning(
                "Unsupported finetune strategy %s; no strategy was executed" %
                type(config.strategy).__name__)

        _do_memory_optimization(task, config)

        # Try to restore model training checkpoint
        current_epoch, global_step = load_checkpoint(config.checkpoint_dir, exe)

        train_time_used = 0
        logger.info("PaddleHub finetune start")

        # Finetune loop
        for epoch in range(current_epoch, num_epoch + 1):
            train_reader = data_reader.data_generator(
                batch_size=batch_size, phase='train')
            num_trained_examples = loss_sum = 0
            for batch in train_reader():
                num_batch_examples = len(batch)
                train_time_begin = time.time()
                loss_v = exe.run(
                    feed=data_feeder.feed(batch), fetch_list=[loss.name])
                train_time_used += time.time() - train_time_begin
                global_step += 1
                num_trained_examples += num_batch_examples
                # exe.run returns one array per fetched variable; reduce the
                # loss array to a Python float so logging never formats an
                # ndarray with "%.5f".
                loss_sum += float(np.mean(loss_v[0])) * num_batch_examples

                # log finetune status
                if global_step % config.log_interval == 0:
                    avg_loss = loss_sum / num_trained_examples
                    speed = config.log_interval / train_time_used
                    logger.info("step %d: loss=%.5f [step/sec: %.2f]" %
                                (global_step, avg_loss, speed))

                    train_time_used = 0
                    num_trained_examples = loss_sum = 0

                if config.save_ckpt_interval and global_step % config.save_ckpt_interval == 0:
                    # NOTE: the current checkpoint mechanism is incomplete;
                    # it cannot restore the correct dataset training status.
                    save_checkpoint(
                        checkpoint_dir=config.checkpoint_dir,
                        current_epoch=epoch,
                        global_step=global_step,
                        exe=exe)

                if do_eval and global_step % config.eval_interval == 0:
                    # evaluate_seq_labeling_task is the name imported from
                    # paddlehub.finetune.evaluate.
                    evaluate_seq_labeling_task(
                        task,
                        data_reader,
                        feed_list,
                        phase="test",
                        config=config)
                    evaluate_seq_labeling_task(
                        task,
                        data_reader,
                        feed_list,
                        phase="dev",
                        config=config)

        # NOTE: the current checkpoint mechanism is incomplete; it cannot
        # restore the dataset training status.
        save_checkpoint(
            checkpoint_dir=config.checkpoint_dir,
            current_epoch=num_epoch + 1,
            global_step=global_step,
            exe=exe)

        if do_eval:
            evaluate_seq_labeling_task(
                task, data_reader, feed_list, phase="dev", config=config)
            evaluate_seq_labeling_task(
                task, data_reader, feed_list, phase="test", config=config)
        logger.info("PaddleHub finetune finished.")


def _finetune_cls_task(task, data_reader, feed_list, config=None,
                       do_eval=False):
    """Finetune a classification task, tracking loss and accuracy.

    Per-``config.log_interval`` averages of training loss and accuracy are
    logged. They are also written as VisualDL scalars under
    ``<config.checkpoint_dir>/vdllog``.

    When ``do_eval`` is True:

    * ``evaluate_cls_task`` runs on the "val" phase every
      ``config.eval_interval`` steps and its loss/accuracy are recorded.
    * Persistables are saved to ``<config.checkpoint_dir>/best_model``
      whenever the "val" accuracy exceeds the best seen so far.
    * The "test" phase is evaluated once training finishes.

    Args:
        task: finetune task providing ``main_program()``,
            ``startup_program()`` and ``variable(name)`` for ``"loss"`` and
            ``"accuracy"``.
        data_reader: reader whose ``data_generator(batch_size=..., phase=...)``
            returns a callable that yields training batches.
        feed_list: variables handed to ``fluid.DataFeeder``.
        config: run configuration. It defaults to ``None`` but is
            dereferenced unconditionally, so a real config must be passed.
        do_eval: enable the periodic and final evaluation described above.
    """
    main_program = task.main_program()
    startup_program = task.startup_program()
    loss = task.variable("loss")
    accuracy = task.variable("accuracy")

    num_epoch = config.num_epoch
    batch_size = config.batch_size
    log_writer = LogWriter(
        os.path.join(config.checkpoint_dir, "vdllog"), sync_cycle=10)

    place, _ = hub.common.get_running_device_info(config)
    with fluid.program_guard(main_program, startup_program):
        exe = fluid.Executor(place=place)
        data_feeder = fluid.DataFeeder(feed_list=feed_list, place=place)

        # select strategy
        if isinstance(config.strategy, hub.BERTFinetuneStrategy):
            # The return value (presumably the scheduled learning rate) is
            # not used by this loop.
            config.strategy.execute(loss, main_program, data_reader, config)
        elif isinstance(config.strategy, hub.DefaultStrategy):
            config.strategy.execute(loss)
        else:
            # TODO: add more finetune strategies
            logger.warning(
                "Unsupported finetune strategy %s; no strategy was executed" %
                type(config.strategy).__name__)

        _do_memory_optimization(task, config)

        # Try to restore model training checkpoint
        current_epoch, global_step = load_checkpoint(config.checkpoint_dir, exe)

        best_eval_acc = 0.0
        train_time_used = 0
        logger.info("PaddleHub finetune start")

        # add visualdl scalar
        with log_writer.mode("train") as logw:
            train_loss_scalar = logw.scalar(tag="loss[train]")
            train_acc_scalar = logw.scalar(tag="accuracy[train]")
        with log_writer.mode("evaluate") as logw:
            eval_loss_scalar = logw.scalar(tag="loss[evaluate]")
            eval_acc_scalar = logw.scalar(tag="accuracy[evaluate]")

        # Finetune loop
        for epoch in range(current_epoch, num_epoch + 1):
            train_reader = data_reader.data_generator(
                batch_size=batch_size, phase='train')
            num_trained_examples = acc_sum = loss_sum = 0
            for batch in train_reader():
                num_batch_examples = len(batch)
                train_time_begin = time.time()
                loss_v, accuracy_v = exe.run(
                    feed=data_feeder.feed(batch),
                    fetch_list=[loss.name, accuracy.name])
                train_time_used += time.time() - train_time_begin
                global_step += 1
                num_trained_examples += num_batch_examples
                # Fetched values are arrays; reduce them to Python floats so
                # that "%.5f" formatting and VisualDL add_record receive
                # scalars rather than ndarrays.
                acc_sum += float(np.mean(accuracy_v)) * num_batch_examples
                loss_sum += float(np.mean(loss_v)) * num_batch_examples

                # log finetune status
                if global_step % config.log_interval == 0:
                    avg_loss = loss_sum / num_trained_examples
                    avg_acc = acc_sum / num_trained_examples
                    speed = config.log_interval / train_time_used
                    logger.info("step %d: loss=%.5f acc=%.5f [step/sec: %.2f]" %
                                (global_step, avg_loss, avg_acc, speed))

                    # record visualdl log
                    train_loss_scalar.add_record(global_step, avg_loss)
                    train_acc_scalar.add_record(global_step, avg_acc)

                    train_time_used = 0
                    num_trained_examples = acc_sum = loss_sum = 0

                if config.save_ckpt_interval and global_step % config.save_ckpt_interval == 0:
                    # NOTE: the current checkpoint mechanism is incomplete;
                    # it cannot restore the dataset training status.
                    save_checkpoint(
                        checkpoint_dir=config.checkpoint_dir,
                        current_epoch=epoch,
                        global_step=global_step,
                        exe=exe)

                if do_eval and global_step % config.eval_interval == 0:
                    eval_loss, eval_acc, _ = evaluate_cls_task(
                        task,
                        data_reader,
                        feed_list,
                        phase="val",
                        config=config)
                    eval_loss_scalar.add_record(global_step, eval_loss)
                    eval_acc_scalar.add_record(global_step, eval_acc)
                    if eval_acc > best_eval_acc:
                        best_eval_acc = eval_acc
                        model_saved_dir = os.path.join(config.checkpoint_dir,
                                                       "best_model")
                        logger.info(
                            "best model saved to %s [best accuracy=%.5f]" %
                            (model_saved_dir, best_eval_acc))
                        fluid.io.save_persistables(exe, dirname=model_saved_dir)

        # NOTE: the current checkpoint mechanism is incomplete; it cannot
        # restore the dataset training status.
        save_checkpoint(
            checkpoint_dir=config.checkpoint_dir,
            current_epoch=num_epoch + 1,
            global_step=global_step,
            exe=exe)

        if do_eval:
            evaluate_cls_task(
                task, data_reader, feed_list, phase="test", config=config)
        logger.info("PaddleHub finetune finished.")
W
wuzewu 已提交
269 270


271
def finetune_and_eval(task, data_reader, feed_list, config=None):
    """Finetune ``task`` with evaluation enabled.

    Tasks whose ``task_type`` is ``"sequence_labeling"`` use the sequence
    labeling loop. Every other task type (e.g. image or text classification)
    uses the classification loop.
    """
    is_seq_labeling = task.task_type == "sequence_labeling"
    run_loop = (_finetune_seq_label_task
                if is_seq_labeling else _finetune_cls_task)
    run_loop(task, data_reader, feed_list, config, do_eval=True)
W
wuzewu 已提交
278 279


280
def finetune(task, data_reader, feed_list, config=None):
    """Finetune ``task`` without any evaluation.

    Tasks whose ``task_type`` is ``"sequence_labeling"`` use the sequence
    labeling loop. Every other task (including objects without a
    ``task_type`` attribute) uses the classification loop.

    Previously every task went to the classification loop. That loop looks up
    an ``"accuracy"`` variable, which a sequence labeling task may not define.
    """
    # getattr keeps objects lacking ``task_type`` on the classification path,
    # matching the old unconditional behavior for them.
    if getattr(task, "task_type", None) == "sequence_labeling":
        _finetune_seq_label_task(
            task, data_reader, feed_list, config, do_eval=False)
    else:
        _finetune_cls_task(task, data_reader, feed_list, config, do_eval=False)