#coding:utf-8
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

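"""Fine-tuning driver for PaddleHub.

Implements the training loops for sequence labeling and classification
(image or text) tasks: strategy setup, optional memory optimization,
checkpoint save/restore, VisualDL logging, and periodic evaluation.
"""
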
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time

import paddle
import paddle.fluid as fluid
import numpy as np
from visualdl import LogWriter

from paddlehub.common.logger import logger
from paddlehub.common.utils import mkdir
from paddlehub.finetune.config import RunConfig
from paddlehub.finetune.strategy import AdamWeightDecayStrategy, DefaultStrategy
from paddlehub.finetune.checkpoint import load_checkpoint, save_checkpoint
from paddlehub.finetune.evaluate import evaluate_cls_task, evaluate_seq_label_task
import paddlehub as hub


def _do_memory_optimization(task, config):
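    """Run fluid's memory optimization pass when the config enables it.

    The task's metric variables are excluded from optimization so that
    in-place memory reuse does not clobber values fetched during
    evaluation.
    """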
    if config.enable_memory_optim:
        logger.info("Memory optimization start...")
        task_var_name = task.metric_variable_names()
        logger.info(
            "Skip memory optimization on variables: {}".format(task_var_name))
        optimize_time_begin = time.time()
        fluid.memory_optimize(
            input_program=fluid.default_main_program(),
            # skip memory optimization on task metric variables
            skip_opt_set=task_var_name)
        time_used = time.time() - optimize_time_begin
        logger.info("Memory optimization done! Time elapsed %f sec" % time_used)

    # lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
    #     program=task.main_program(), batch_size=config.batch_size)
    # logger.info("Theoretical memory usage in training: %.2f - %.2f %s" %
    #             (lower_mem, upper_mem, unit)),


def _finetune_seq_label_task(task,
                             data_reader,
                             feed_list,
                             config=None,
                             do_eval=False):
    """
    Finetune sequence labeling task, evaluate metric is F1, precision and recall

    """
    main_program = task.main_program()
    startup_program = task.startup_program()
    loss = task.variable("loss")
    seq_len = task.variable("seq_len")

    num_epoch = config.num_epoch
    batch_size = config.batch_size
    log_writer = LogWriter(
        os.path.join(config.checkpoint_dir, "vdllog"), sync_cycle=1)

    place, dev_count = hub.common.get_running_device_info(config)
    with fluid.program_guard(main_program, startup_program):
        exe = fluid.Executor(place=place)
        data_feeder = fluid.DataFeeder(feed_list=feed_list, place=place)

        # Select strategy
        if isinstance(config.strategy, hub.AdamWeightDecayStrategy):
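            # the strategy builds its optimizer (e.g. warmup plus weight
            # decay) into the program and returns the scheduled learning rate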
            scheduled_lr = config.strategy.execute(loss, main_program,
                                                   data_reader, config)
        elif isinstance(config.strategy, hub.DefaultStrategy):
            config.strategy.execute(loss)
        # TODO: add more fine-tune strategies

        _do_memory_optimization(task, config)

        # Try to restore model training checkpoint
        current_epoch, global_step = load_checkpoint(config.checkpoint_dir, exe)

        best_eval_f1 = 0.0
        train_time_used = 0
        logger.info("PaddleHub finetune start")

        exe.run(fluid.default_startup_program())

        # add VisualDL scalars
        with log_writer.mode("train") as logw:
            train_loss_scalar = logw.scalar(tag="Loss [train]")
        with log_writer.mode("evaluate") as logw:
            eval_f1_scalar = logw.scalar(tag="F1 [eval]")
            eval_precision_scalar = logw.scalar(tag="Precision [eval]")
            eval_recall_scalar = logw.scalar(tag="Recall [eval]")

        # Finetune loop
        for epoch in range(current_epoch, num_epoch + 1):
            train_reader = data_reader.data_generator(
                batch_size=batch_size, phase='train')
            num_trained_examples = loss_sum = 0
            for batch in train_reader():
                num_batch_examples = len(batch)
                train_time_begin = time.time()
                loss_v = exe.run(
                    feed=data_feeder.feed(batch), fetch_list=[loss.name])
                train_time_used += time.time() - train_time_begin
                global_step += 1
                num_trained_examples += num_batch_examples
                loss_sum += loss_v[0] * num_batch_examples

                # log finetune status
                if global_step % config.log_interval == 0:
                    avg_loss = loss_sum / num_trained_examples
                    speed = config.log_interval / train_time_used
                    logger.info("step %d: loss=%.5f [step/sec: %.2f]" %
                                (global_step, avg_loss, speed))
                    train_loss_scalar.add_record(global_step, avg_loss)

                    train_time_used = 0
                    num_trained_examples = 0
                    loss_sum = 0

                if config.save_ckpt_interval and global_step % config.save_ckpt_interval == 0:
                    # NOTE: the current checkpoint mechanism is incomplete;
                    # it cannot restore the dataset's training position
                    save_checkpoint(
                        checkpoint_dir=config.checkpoint_dir,
                        current_epoch=epoch,
                        global_step=global_step,
                        exe=exe)

                if do_eval and global_step % config.eval_interval == 0:
                    f1, precision, recall = evaluate_seq_label_task(
                        task,
                        data_reader,
                        feed_list,
                        phase="dev",
                        config=config)
                    eval_f1_scalar.add_record(global_step, f1)
                    eval_precision_scalar.add_record(global_step, precision)
                    eval_recall_scalar.add_record(global_step, recall)
                    if f1 > best_eval_f1:
                        best_eval_f1 = f1
                        model_saved_dir = os.path.join(config.checkpoint_dir,
                                                       "best_model")
                        logger.info("best model saved to %s [best F1=%.5f]" %
                                    (model_saved_dir, best_eval_f1))
                        fluid.io.save_persistables(exe, dirname=model_saved_dir)

        # NOTE: the current checkpoint mechanism is incomplete; it cannot
        # restore the dataset's training position
        save_checkpoint(
            checkpoint_dir=config.checkpoint_dir,
            current_epoch=num_epoch + 1,
            global_step=global_step,
            exe=exe)

        # Final evaluation
        if do_eval:
            evaluate_seq_label_task(
                task, data_reader, feed_list, phase="dev", config=config)
            evaluate_seq_label_task(
                task, data_reader, feed_list, phase="test", config=config)
        logger.info("PaddleHub finetune finished.")


def _finetune_cls_task(task, data_reader, feed_list, config=None,
                       do_eval=False):
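    """
    Fine-tune a classification task; evaluation metrics are loss and
    accuracy.
    """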
    main_program = task.main_program()
    startup_program = task.startup_program()
    loss = task.variable("loss")
    accuracy = task.variable("accuracy")

    num_epoch = config.num_epoch
    batch_size = config.batch_size
    log_writer = LogWriter(
        os.path.join(config.checkpoint_dir, "vdllog"), sync_cycle=1)

    place, dev_count = hub.common.get_running_device_info(config)
    with fluid.program_guard(main_program, startup_program):
        exe = fluid.Executor(place=place)
        data_feeder = fluid.DataFeeder(feed_list=feed_list, place=place)

        # Select strategy
        if isinstance(config.strategy, hub.AdamWeightDecayStrategy):
            scheduled_lr = config.strategy.execute(loss, main_program,
                                                   data_reader, config)
        elif isinstance(config.strategy, hub.DefaultStrategy):
            config.strategy.execute(loss)
        # TODO: add more fine-tune strategies

        _do_memory_optimization(task, config)

        # Try to restore model training checkpoint
        current_epoch, global_step = load_checkpoint(config.checkpoint_dir, exe)

        best_eval_acc = 0.0
        train_time_used = 0
        logger.info("PaddleHub finetune start")

        # add VisualDL scalars
        with log_writer.mode("train") as logw:
            train_loss_scalar = logw.scalar(tag="Loss [train]")
            train_acc_scalar = logw.scalar(tag="Accuracy [train]")
        with log_writer.mode("evaluate") as logw:
            eval_loss_scalar = logw.scalar(tag="Loss [eval]")
            eval_acc_scalar = logw.scalar(tag="Accuracy [eval]")

        exe.run(fluid.default_startup_program())

        # Finetune loop
        for epoch in range(current_epoch, num_epoch + 1):
            train_reader = data_reader.data_generator(
                batch_size=batch_size, phase='train')
            num_trained_examples = acc_sum = loss_sum = 0
            for batch in train_reader():
                num_batch_examples = len(batch)
                train_time_begin = time.time()
                loss_v, accuracy_v = exe.run(
                    feed=data_feeder.feed(batch),
                    fetch_list=[loss.name, accuracy.name],
                    return_numpy=False)
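                # with return_numpy=False the fetched results are LoDTensors;
                # convert them to numpy arrays before aggregating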
                loss_v = np.array(loss_v)
                accuracy_v = np.array(accuracy_v)
                train_time_used += time.time() - train_time_begin
                global_step += 1
                num_trained_examples += num_batch_examples
                acc_sum += accuracy_v * num_batch_examples
                loss_sum += loss_v * num_batch_examples

                # log finetune status
                if global_step % config.log_interval == 0:
                    avg_loss = loss_sum / num_trained_examples
                    avg_acc = acc_sum / num_trained_examples
                    speed = config.log_interval / train_time_used
                    logger.info("step %d: loss=%.5f acc=%.5f [step/sec: %.2f]" %
                                (global_step, avg_loss, avg_acc, speed))

                    # record visualdl log
                    train_loss_scalar.add_record(global_step, avg_loss)
                    train_acc_scalar.add_record(global_step, avg_acc)

                    train_time_used = 0
                    num_trained_examples = acc_sum = loss_sum = 0

                if config.save_ckpt_interval and global_step % config.save_ckpt_interval == 0:
                    # NOTE: the current checkpoint mechanism is incomplete;
                    # it cannot restore the dataset's training position
                    save_checkpoint(
                        checkpoint_dir=config.checkpoint_dir,
                        current_epoch=epoch,
                        global_step=global_step,
                        exe=exe)

                if do_eval and global_step % config.eval_interval == 0:
                    eval_loss, eval_acc, eval_perf = evaluate_cls_task(
                        task,
                        data_reader,
                        feed_list,
                        phase="val",
                        config=config)
                    eval_loss_scalar.add_record(global_step, eval_loss)
                    eval_acc_scalar.add_record(global_step, eval_acc)
                    if eval_acc > best_eval_acc:
                        best_eval_acc = eval_acc
                        model_saved_dir = os.path.join(config.checkpoint_dir,
                                                       "best_model")
                        logger.info(
                            "best model saved to %s [best accuracy=%.5f]" %
                            (model_saved_dir, best_eval_acc))
                        fluid.io.save_persistables(exe, dirname=model_saved_dir)

        # NOTE: the current checkpoint mechanism is incomplete; it cannot
        # restore the dataset's training position
        save_checkpoint(
            checkpoint_dir=config.checkpoint_dir,
            current_epoch=num_epoch + 1,
            global_step=global_step,
            exe=exe)

        # Final evaluation
        if do_eval:
            evaluate_cls_task(
                task, data_reader, feed_list, phase="dev", config=config)
            evaluate_cls_task(
                task, data_reader, feed_list, phase="test", config=config)
        logger.info("PaddleHub finetune finished.")


def finetune_and_eval(task, data_reader, feed_list, config=None):
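    """
    Fine-tune `task` with data from `data_reader`, evaluating as training
    proceeds. Creates a default RunConfig when none is given and dispatches
    to the loop matching the task type.
    """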
    if config is None:
        config = RunConfig()

    if not os.path.exists(config.checkpoint_dir):
        mkdir(config.checkpoint_dir)

    if task.task_type == "sequence_labeling":
        _finetune_seq_label_task(
            task, data_reader, feed_list, config, do_eval=True)
    elif task.task_type in ("image_classification", "text_classification"):
        _finetune_cls_task(task, data_reader, feed_list, config, do_eval=True)
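

# A minimal usage sketch for finetune_and_eval, left as a comment so that
# importing this module stays side-effect free. It assumes the PaddleHub 1.x
# text-classification demo API; the module name ("ernie"), the dataset, and
# the reader/task helpers below are illustrative and may differ in your
# installation.
#
#     import paddlehub as hub
#
#     module = hub.Module(name="ernie")
#     inputs, outputs, program = module.context(
#         trainable=True, max_seq_len=128)
#     dataset = hub.dataset.ChnSentiCorp()
#     reader = hub.reader.ClassifyReader(
#         dataset=dataset,
#         vocab_path=module.get_vocab_path(),
#         max_seq_len=128)
#     task = hub.create_text_cls_task(
#         feature=outputs["pooled_output"], num_classes=dataset.num_labels)
#     feed_list = [
#         inputs["input_ids"].name, inputs["position_ids"].name,
#         inputs["segment_ids"].name, inputs["input_mask"].name,
#         task.variable("label").name
#     ]
#     strategy = hub.AdamWeightDecayStrategy(learning_rate=5e-5)
#     config = hub.RunConfig(
#         use_cuda=False, num_epoch=3, batch_size=32, strategy=strategy)
#     hub.finetune_and_eval(
#         task=task, data_reader=reader, feed_list=feed_list, config=config)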