finetune.py 13.2 KB
Newer Older
W
wuzewu 已提交
1
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
Z
Zeyu Chen 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

W
wuzewu 已提交
15 16 17 18
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

Z
Zeyu Chen 已提交
19 20 21
import os
import time

W
wuzewu 已提交
22 23
import paddle
import paddle.fluid as fluid
24
import numpy as np
25
from visualdl import LogWriter
Z
Zeyu Chen 已提交
26

W
wuzewu 已提交
27
from paddlehub.common.logger import logger
28 29
from paddlehub.common.utils import mkdir
from paddlehub.finetune.config import RunConfig
30
from paddlehub.finetune.strategy import AdamWeightDecayStrategy, DefaultStrategy
W
wuzewu 已提交
31
from paddlehub.finetune.checkpoint import load_checkpoint, save_checkpoint
Z
Zeyu Chen 已提交
32
from paddlehub.finetune.evaluate import evaluate_cls_task, evaluate_seq_label_task
33
import paddlehub as hub
34 35


36 37 38 39 40 41 42 43 44 45 46 47 48 49
def _do_memory_optimization(task, config):
    if config.enable_memory_optim:
        logger.info("Memory optimization start...")
        task_var_name = task.metric_variable_names()
        logger.info(
            "Skip memory optimization on variables: {}".format(task_var_name))
        optimize_time_begin = time.time()
        fluid.memory_optimize(
            input_program=fluid.default_main_program(),
            # skip memory optimization on task metric variables
            skip_opt_set=task_var_name)
        time_used = time.time() - optimize_time_begin
        logger.info("Memory optimization done! Time elapsed %f sec" % time_used)

50 51 52 53
    # lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
    #     program=task.main_program(), batch_size=config.batch_size)
    # logger.info("Theoretical memory usage in training: %.2f - %.2f %s" %
    #             (lower_mem, upper_mem, unit)),
54 55


56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
def _finetune_seq_label_task(task,
                             data_reader,
                             feed_list,
                             config=None,
                             do_eval=False):
    """
    Finetune a sequence labeling task; evaluation metrics are F1, precision
    and recall.

    Args:
        task: task object providing ``main_program()``, ``startup_program()``
            and ``variable(name)``; its ``"loss"`` variable is handed to
            ``config.strategy``.
        data_reader: reader whose ``data_generator(batch_size=..., phase=...)``
            returns a callable yielding training batches.
        feed_list: feed variables passed to ``fluid.DataFeeder``.
        config: RunConfig with ``num_epoch``, ``batch_size``,
            ``log_interval``, ``save_ckpt_interval``, ``eval_interval``,
            ``checkpoint_dir`` and ``strategy``. It is dereferenced right
            away, so despite the default it must not be None.
        do_eval: when True, evaluate on phase "dev" every
            ``config.eval_interval`` steps, save the best model (by F1) to
            ``<checkpoint_dir>/best_model``, and after training run a final
            evaluation on phases "dev" and "test".
    """
    main_program = task.main_program()
    startup_program = task.startup_program()
    loss = task.variable("loss")

    num_epoch = config.num_epoch
    batch_size = config.batch_size

    # VisualDL logs are written next to the checkpoints.
    log_writer = LogWriter(
        os.path.join(config.checkpoint_dir, "vdllog"), sync_cycle=1)

    place, _ = hub.common.get_running_device_info(config)
    with fluid.program_guard(main_program, startup_program):
        exe = fluid.Executor(place=place)
        data_feeder = fluid.DataFeeder(feed_list=feed_list, place=place)

        # Select strategy (the scheduled learning rate it may return is not
        # needed here).
        if isinstance(config.strategy, hub.AdamWeightDecayStrategy):
            config.strategy.execute(loss, main_program, data_reader, config)
        elif isinstance(config.strategy, hub.DefaultStrategy):
            config.strategy.execute(loss)
        # TODO: add more finetune strategy

        _do_memory_optimization(task, config)

        # Initialize variables BEFORE restoring the checkpoint. Running the
        # startup program after load_checkpoint would re-initialize whatever
        # the checkpoint had just restored, silently discarding it.
        exe.run(fluid.default_startup_program())

        # Try to restore model training checkpoint
        current_epoch, global_step = load_checkpoint(config.checkpoint_dir, exe)

        best_eval_f1 = 0.0
        train_time_used = 0
        logger.info("PaddleHub finetune start")

        # add visualdl scalar
        with log_writer.mode("train") as logw:
            train_loss_scalar = logw.scalar(tag="Loss [train]")
        with log_writer.mode("evaluate") as logw:
            eval_f1_scalar = logw.scalar(tag="F1 [eval]")
            eval_precision_scalar = logw.scalar(tag="Precision [eval]")
            eval_recall_scalar = logw.scalar(tag="Recall [eval]")

        # Finetune loop: epochs current_epoch..num_epoch inclusive.
        for epoch in range(current_epoch, num_epoch + 1):
            train_reader = data_reader.data_generator(
                batch_size=batch_size, phase='train')
            num_trained_examples = loss_sum = 0
            for batch in train_reader():
                num_batch_examples = len(batch)
                train_time_begin = time.time()
                loss_v = exe.run(
                    feed=data_feeder.feed(batch), fetch_list=[loss.name])
                train_time_used += time.time() - train_time_begin
                global_step += 1
                num_trained_examples += num_batch_examples
                # Weight by batch size so avg_loss below is a per-example
                # average (assumes the fetched loss is a batch mean).
                loss_sum += loss_v[0] * num_batch_examples

                # log finetune status
                if global_step % config.log_interval == 0:
                    avg_loss = loss_sum / num_trained_examples
                    speed = config.log_interval / train_time_used
                    logger.info("step %d: loss=%.5f [step/sec: %.2f]" %
                                (global_step, avg_loss, speed))
                    train_loss_scalar.add_record(global_step, avg_loss)

                    # Statistics are reported per logging window.
                    train_time_used = 0
                    num_trained_examples = loss_sum = 0

                if config.save_ckpt_interval and global_step % config.save_ckpt_interval == 0:
                    # NOTE: the checkpoint mechanism is incomplete: it does
                    # not record the reader position, so a resumed run
                    # restarts the saved epoch from its first batch.
                    save_checkpoint(
                        checkpoint_dir=config.checkpoint_dir,
                        current_epoch=epoch,
                        global_step=global_step,
                        exe=exe)

                if do_eval and global_step % config.eval_interval == 0:
                    f1, precision, recall = evaluate_seq_label_task(
                        task,
                        data_reader,
                        feed_list,
                        phase="dev",
                        config=config)
                    eval_f1_scalar.add_record(global_step, f1)
                    eval_precision_scalar.add_record(global_step, precision)
                    eval_recall_scalar.add_record(global_step, recall)
                    if f1 > best_eval_f1:
                        best_eval_f1 = f1
                        model_saved_dir = os.path.join(config.checkpoint_dir,
                                                       "best_model")
                        logger.info("best model saved to %s [best F1=%.5f]" %
                                    (model_saved_dir, best_eval_f1))
                        fluid.io.save_persistables(exe, dirname=model_saved_dir)

        # NOTE: the checkpoint mechanism is incomplete and cannot restore the
        # dataset reading position. Saving epoch num_epoch + 1 makes a later
        # resume skip the finished training loop.
        save_checkpoint(
            checkpoint_dir=config.checkpoint_dir,
            current_epoch=num_epoch + 1,
            global_step=global_step,
            exe=exe)

        # Final evaluation
        if do_eval:
            evaluate_seq_label_task(
                task, data_reader, feed_list, phase="dev", config=config)
            evaluate_seq_label_task(
                task, data_reader, feed_list, phase="test", config=config)
        logger.info("PaddleHub finetune finished.")


def _finetune_cls_task(task, data_reader, feed_list, config=None,
                       do_eval=False):
    """
    Finetune a classification task; metrics are loss and accuracy.

    Args:
        task: task object providing ``main_program()``, ``startup_program()``
            and ``variable(name)``; its ``"loss"`` variable is handed to
            ``config.strategy`` and ``"accuracy"`` is fetched for logging.
        data_reader: reader whose ``data_generator(batch_size=..., phase=...)``
            returns a callable yielding training batches.
        feed_list: feed variables passed to ``fluid.DataFeeder``.
        config: RunConfig with ``num_epoch``, ``batch_size``,
            ``log_interval``, ``save_ckpt_interval``, ``eval_interval``,
            ``checkpoint_dir`` and ``strategy``. It is dereferenced right
            away, so despite the default it must not be None.
        do_eval: when True, evaluate every ``config.eval_interval`` steps,
            save the best model (by accuracy) to ``<checkpoint_dir>/best_model``,
            and after training run a final evaluation on phases "dev" and
            "test".
    """
    main_program = task.main_program()
    startup_program = task.startup_program()
    loss = task.variable("loss")
    accuracy = task.variable("accuracy")

    num_epoch = config.num_epoch
    batch_size = config.batch_size

    # VisualDL logs are written next to the checkpoints.
    log_writer = LogWriter(
        os.path.join(config.checkpoint_dir, "vdllog"), sync_cycle=1)

    place, _ = hub.common.get_running_device_info(config)
    with fluid.program_guard(main_program, startup_program):
        exe = fluid.Executor(place=place)
        data_feeder = fluid.DataFeeder(feed_list=feed_list, place=place)

        # Select strategy (the scheduled learning rate it may return is not
        # needed here).
        if isinstance(config.strategy, hub.AdamWeightDecayStrategy):
            config.strategy.execute(loss, main_program, data_reader, config)
        elif isinstance(config.strategy, hub.DefaultStrategy):
            config.strategy.execute(loss)
        # TODO: add more finetune strategy

        _do_memory_optimization(task, config)

        # Initialize variables BEFORE restoring the checkpoint. Running the
        # startup program after load_checkpoint would re-initialize whatever
        # the checkpoint had just restored, silently discarding it.
        exe.run(fluid.default_startup_program())

        # Try to restore model training checkpoint
        current_epoch, global_step = load_checkpoint(config.checkpoint_dir, exe)

        best_eval_acc = 0.0
        train_time_used = 0
        logger.info("PaddleHub finetune start")

        # add visualdl scalar
        with log_writer.mode("train") as logw:
            train_loss_scalar = logw.scalar(tag="Loss [train]")
            train_acc_scalar = logw.scalar(tag="Accuracy [train]")
        with log_writer.mode("evaluate") as logw:
            eval_loss_scalar = logw.scalar(tag="Loss [eval]")
            eval_acc_scalar = logw.scalar(tag="Accuracy [eval]")

        # Finetune loop: epochs current_epoch..num_epoch inclusive.
        for epoch in range(current_epoch, num_epoch + 1):
            train_reader = data_reader.data_generator(
                batch_size=batch_size, phase='train')
            num_trained_examples = acc_sum = loss_sum = 0
            for batch in train_reader():
                num_batch_examples = len(batch)
                train_time_begin = time.time()
                # Fetch without numpy conversion, then convert explicitly.
                loss_v, accuracy_v = exe.run(
                    feed=data_feeder.feed(batch),
                    fetch_list=[loss.name, accuracy.name],
                    return_numpy=False)
                loss_v = np.array(loss_v)
                accuracy_v = np.array(accuracy_v)
                train_time_used += time.time() - train_time_begin
                global_step += 1
                num_trained_examples += num_batch_examples
                # Weight by batch size so the averages below are per-example
                # (assumes the fetched values are batch means).
                acc_sum += accuracy_v * num_batch_examples
                loss_sum += loss_v * num_batch_examples

                # log finetune status
                if global_step % config.log_interval == 0:
                    avg_loss = loss_sum / num_trained_examples
                    avg_acc = acc_sum / num_trained_examples
                    speed = config.log_interval / train_time_used
                    logger.info("step %d: loss=%.5f acc=%.5f [step/sec: %.2f]" %
                                (global_step, avg_loss, avg_acc, speed))

                    # record visualdl log
                    train_loss_scalar.add_record(global_step, avg_loss)
                    train_acc_scalar.add_record(global_step, avg_acc)

                    # Statistics are reported per logging window.
                    train_time_used = 0
                    num_trained_examples = acc_sum = loss_sum = 0

                if config.save_ckpt_interval and global_step % config.save_ckpt_interval == 0:
                    # NOTE: the checkpoint mechanism is incomplete: it does
                    # not record the reader position, so a resumed run
                    # restarts the saved epoch from its first batch.
                    save_checkpoint(
                        checkpoint_dir=config.checkpoint_dir,
                        current_epoch=epoch,
                        global_step=global_step,
                        exe=exe)

                if do_eval and global_step % config.eval_interval == 0:
                    # NOTE(review): periodic evaluation uses phase "val" while
                    # the final evaluation below uses "dev"; confirm the
                    # reader treats both as the same split.
                    eval_loss, eval_acc, _ = evaluate_cls_task(
                        task,
                        data_reader,
                        feed_list,
                        phase="val",
                        config=config)
                    eval_loss_scalar.add_record(global_step, eval_loss)
                    eval_acc_scalar.add_record(global_step, eval_acc)
                    if eval_acc > best_eval_acc:
                        best_eval_acc = eval_acc
                        model_saved_dir = os.path.join(config.checkpoint_dir,
                                                       "best_model")
                        logger.info(
                            "best model saved to %s [best accuracy=%.5f]" %
                            (model_saved_dir, best_eval_acc))
                        fluid.io.save_persistables(exe, dirname=model_saved_dir)

        # NOTE: the checkpoint mechanism is incomplete and cannot restore the
        # dataset reading position. Saving epoch num_epoch + 1 makes a later
        # resume skip the finished training loop.
        save_checkpoint(
            checkpoint_dir=config.checkpoint_dir,
            current_epoch=num_epoch + 1,
            global_step=global_step,
            exe=exe)

        # Final evaluation
        if do_eval:
            evaluate_cls_task(
                task, data_reader, feed_list, phase="dev", config=config)
            evaluate_cls_task(
                task, data_reader, feed_list, phase="test", config=config)
        logger.info("PaddleHub finetune finished.")
W
wuzewu 已提交
299 300


301
def finetune_and_eval(task, data_reader, feed_list, config=None):
    """
    Finetune ``task`` with evaluation enabled, dispatching on its task type.

    Args:
        task: task object whose ``task_type`` selects the finetune routine.
        data_reader: reader supplying train/dev/test batches.
        feed_list: feed variables for the task's programs.
        config: RunConfig; a default ``RunConfig()`` is used when None.
            ``config.checkpoint_dir`` is created if it does not exist.

    Raises:
        ValueError: if ``task.task_type`` is not "sequence_labeling",
            "image_classification" or "text_classification". Previously such
            tasks returned silently without any training.
    """
    if config is None:
        config = RunConfig()

    task_type = task.task_type
    if task_type == "sequence_labeling":
        finetune_fn = _finetune_seq_label_task
    elif task_type in ("image_classification", "text_classification"):
        finetune_fn = _finetune_cls_task
    else:
        # Validate before touching the filesystem.
        raise ValueError(
            "Unsupported task_type '{}' for finetune_and_eval; expected one "
            "of sequence_labeling, image_classification, "
            "text_classification".format(task_type))

    if not os.path.exists(config.checkpoint_dir):
        mkdir(config.checkpoint_dir)

    finetune_fn(task, data_reader, feed_list, config, do_eval=True)