finetune.py 11.7 KB
Newer Older
W
wuzewu 已提交
1
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
Z
Zeyu Chen 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

W
wuzewu 已提交
15 16 17 18
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

Z
Zeyu Chen 已提交
19 20 21
import os
import time

W
wuzewu 已提交
22 23
import paddle
import paddle.fluid as fluid
24
import numpy as np
25
from visualdl import LogWriter
Z
Zeyu Chen 已提交
26

W
wuzewu 已提交
27
from paddlehub.common.logger import logger
28
from paddlehub.finetune.strategy import AdamWeightDecayStrategy, DefaultStrategy
W
wuzewu 已提交
29
from paddlehub.finetune.checkpoint import load_checkpoint, save_checkpoint
Z
Zeyu Chen 已提交
30
from paddlehub.finetune.evaluate import evaluate_cls_task, evaluate_seq_labeling_task
31
import paddlehub as hub
32 33


34 35 36 37 38 39 40 41 42 43 44 45 46 47
def _do_memory_optimization(task, config):
    if config.enable_memory_optim:
        logger.info("Memory optimization start...")
        task_var_name = task.metric_variable_names()
        logger.info(
            "Skip memory optimization on variables: {}".format(task_var_name))
        optimize_time_begin = time.time()
        fluid.memory_optimize(
            input_program=fluid.default_main_program(),
            # skip memory optimization on task metric variables
            skip_opt_set=task_var_name)
        time_used = time.time() - optimize_time_begin
        logger.info("Memory optimization done! Time elapsed %f sec" % time_used)

48 49 50 51
    # lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
    #     program=task.main_program(), batch_size=config.batch_size)
    # logger.info("Theoretical memory usage in training: %.2f - %.2f %s" %
    #             (lower_mem, upper_mem, unit)),
52 53


54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
def _finetune_seq_label_task(task,
                             data_reader,
                             feed_list,
                             config=None,
                             do_eval=False):
    """
    Finetune a sequence labeling task.

    Evaluation is delegated to ``evaluate_seq_labeling_task``. It is
    described as reporting F1, precision and recall; that is not verified
    here.

    Args:
        task: task object exposing ``main_program()``, ``startup_program()``
            and ``variable(name)``; must define a "loss" variable.
        data_reader: reader whose ``data_generator(batch_size=..., phase=...)``
            returns a callable yielding training batches.
        feed_list: variables fed through ``fluid.DataFeeder``.
        config: run configuration (``num_epoch``, ``batch_size``,
            ``strategy``, ``checkpoint_dir``, ``log_interval``,
            ``save_ckpt_interval``, ``eval_interval``, ...). Required despite
            the ``None`` default.
        do_eval: if True, evaluate on "test" and "dev" every
            ``config.eval_interval`` steps and on "dev"/"test" after training.
    """
    main_program = task.main_program()
    startup_program = task.startup_program()
    loss = task.variable("loss")

    num_epoch = config.num_epoch
    batch_size = config.batch_size

    place, _ = hub.common.get_running_device_info(config)
    with fluid.program_guard(main_program, startup_program):
        exe = fluid.Executor(place=place)
        data_feeder = fluid.DataFeeder(feed_list=feed_list, place=place)

        # Select strategy (the scheduled LR returned by
        # AdamWeightDecayStrategy is not used in this loop).
        if isinstance(config.strategy, hub.AdamWeightDecayStrategy):
            config.strategy.execute(loss, main_program, data_reader, config)
        elif isinstance(config.strategy, hub.DefaultStrategy):
            config.strategy.execute(loss)
        # TODO: add more finetune strategy

        _do_memory_optimization(task, config)

        # Initialize parameters BEFORE restoring the checkpoint: running the
        # startup program after load_checkpoint would overwrite every
        # restored parameter with a fresh initialization.
        exe.run(fluid.default_startup_program())

        # Try to restore model training checkpoint
        current_epoch, global_step = load_checkpoint(config.checkpoint_dir,
                                                     exe)

        train_time_used = 0
        logger.info("PaddleHub finetune start")

        # Finetune loop
        for epoch in range(current_epoch, num_epoch + 1):
            train_reader = data_reader.data_generator(
                batch_size=batch_size, phase='train')
            num_trained_examples = loss_sum = 0
            for batch in train_reader():
                num_batch_examples = len(batch)
                train_time_begin = time.time()
                loss_v = exe.run(
                    feed=data_feeder.feed(batch), fetch_list=[loss.name])
                train_time_used += time.time() - train_time_begin
                global_step += 1
                num_trained_examples += num_batch_examples
                # exe.run returns ndarrays; reduce to a scalar so the "%.5f"
                # formatting below does not depend on array->float coercion.
                loss_sum += np.mean(loss_v[0]) * num_batch_examples

                # log finetune status
                if global_step % config.log_interval == 0:
                    avg_loss = loss_sum / num_trained_examples
                    speed = config.log_interval / train_time_used
                    logger.info("step %d: loss=%.5f [step/sec: %.2f]" %
                                (global_step, avg_loss, speed))

                    train_time_used = 0
                    num_trained_examples = loss_sum = 0

                if (config.save_ckpt_interval
                        and global_step % config.save_ckpt_interval == 0):
                    # NOTE: the current checkpoint mechanism is incomplete;
                    # it can't restore the dataset's training position.
                    save_checkpoint(
                        checkpoint_dir=config.checkpoint_dir,
                        current_epoch=epoch,
                        global_step=global_step,
                        exe=exe)

                if do_eval and global_step % config.eval_interval == 0:
                    evaluate_seq_labeling_task(
                        task,
                        data_reader,
                        feed_list,
                        phase="test",
                        config=config)
                    evaluate_seq_labeling_task(
                        task,
                        data_reader,
                        feed_list,
                        phase="dev",
                        config=config)

        # NOTE: the current checkpoint mechanism is incomplete; it can't
        # restore the dataset's training position.
        save_checkpoint(
            checkpoint_dir=config.checkpoint_dir,
            current_epoch=num_epoch + 1,
            global_step=global_step,
            exe=exe)

        if do_eval:
            evaluate_seq_labeling_task(
                task, data_reader, feed_list, phase="dev", config=config)
            evaluate_seq_labeling_task(
                task, data_reader, feed_list, phase="test", config=config)
        logger.info("PaddleHub finetune finished.")


def _finetune_cls_task(task, data_reader, feed_list, config=None,
                       do_eval=False):
    """
    Finetune a classification task, tracking loss and accuracy.

    Training loss/accuracy are logged every ``config.log_interval`` steps,
    both to the logger and to VisualDL scalars written under
    ``<checkpoint_dir>/vdllog``. Checkpoints are saved every
    ``config.save_ckpt_interval`` steps (if set) and once after training.
    When ``do_eval`` is True:
    - the "val" phase is evaluated every ``config.eval_interval`` steps;
    - the persistables with the best validation accuracy seen in this run
      are saved to ``<checkpoint_dir>/best_model``;
    - a final "test" evaluation runs after training.

    Args:
        task: task object exposing ``main_program()``, ``startup_program()``
            and ``variable(name)``; must define "loss" and "accuracy".
        data_reader: reader whose ``data_generator(batch_size=..., phase=...)``
            returns a callable yielding training batches.
        feed_list: variables fed through ``fluid.DataFeeder``.
        config: run configuration; required despite the ``None`` default.
        do_eval: whether to run periodic and final evaluation.
    """
    main_program = task.main_program()
    startup_program = task.startup_program()
    loss = task.variable("loss")
    accuracy = task.variable("accuracy")

    num_epoch = config.num_epoch
    batch_size = config.batch_size
    log_writer = LogWriter(
        os.path.join(config.checkpoint_dir, "vdllog"), sync_cycle=10)

    place, _ = hub.common.get_running_device_info(config)
    with fluid.program_guard(main_program, startup_program):
        exe = fluid.Executor(place=place)
        data_feeder = fluid.DataFeeder(feed_list=feed_list, place=place)

        # Select strategy (the scheduled LR returned by
        # AdamWeightDecayStrategy is not used in this loop).
        if isinstance(config.strategy, hub.AdamWeightDecayStrategy):
            config.strategy.execute(loss, main_program, data_reader, config)
        elif isinstance(config.strategy, hub.DefaultStrategy):
            config.strategy.execute(loss)
        # TODO: add more finetune strategy

        _do_memory_optimization(task, config)

        # Initialize parameters BEFORE restoring the checkpoint: running the
        # startup program after load_checkpoint would overwrite every
        # restored parameter with a fresh initialization.
        exe.run(fluid.default_startup_program())

        # Try to restore model training checkpoint
        current_epoch, global_step = load_checkpoint(config.checkpoint_dir,
                                                     exe)

        best_eval_acc = 0.0
        train_time_used = 0
        logger.info("PaddleHub finetune start")

        # add visualdl scalar
        with log_writer.mode("train") as logw:
            train_loss_scalar = logw.scalar(tag="loss[train]")
            train_acc_scalar = logw.scalar(tag="accuracy[train]")
        with log_writer.mode("evaluate") as logw:
            eval_loss_scalar = logw.scalar(tag="loss[evaluate]")
            eval_acc_scalar = logw.scalar(tag="accuracy[evaluate]")

        # Finetune loop
        for epoch in range(current_epoch, num_epoch + 1):
            train_reader = data_reader.data_generator(
                batch_size=batch_size, phase='train')
            num_trained_examples = acc_sum = loss_sum = 0
            for batch in train_reader():
                num_batch_examples = len(batch)
                train_time_begin = time.time()
                loss_v, accuracy_v = exe.run(
                    feed=data_feeder.feed(batch),
                    fetch_list=[loss.name, accuracy.name])
                train_time_used += time.time() - train_time_begin
                global_step += 1
                num_trained_examples += num_batch_examples
                # exe.run returns ndarrays; reduce to scalars so "%.5f"
                # formatting and VisualDL add_record get plain numbers.
                acc_sum += np.mean(accuracy_v) * num_batch_examples
                loss_sum += np.mean(loss_v) * num_batch_examples

                # log finetune status
                if global_step % config.log_interval == 0:
                    avg_loss = loss_sum / num_trained_examples
                    avg_acc = acc_sum / num_trained_examples
                    speed = config.log_interval / train_time_used
                    logger.info("step %d: loss=%.5f acc=%.5f [step/sec: %.2f]" %
                                (global_step, avg_loss, avg_acc, speed))

                    # record visualdl log
                    train_loss_scalar.add_record(global_step, avg_loss)
                    train_acc_scalar.add_record(global_step, avg_acc)

                    train_time_used = 0
                    num_trained_examples = acc_sum = loss_sum = 0

                if (config.save_ckpt_interval
                        and global_step % config.save_ckpt_interval == 0):
                    # NOTE: the current checkpoint mechanism is incomplete;
                    # it can't restore the dataset's training position.
                    save_checkpoint(
                        checkpoint_dir=config.checkpoint_dir,
                        current_epoch=epoch,
                        global_step=global_step,
                        exe=exe)

                if do_eval and global_step % config.eval_interval == 0:
                    eval_loss, eval_acc, _ = evaluate_cls_task(
                        task,
                        data_reader,
                        feed_list,
                        phase="val",
                        config=config)
                    eval_loss_scalar.add_record(global_step, eval_loss)
                    eval_acc_scalar.add_record(global_step, eval_acc)
                    if eval_acc > best_eval_acc:
                        best_eval_acc = eval_acc
                        model_saved_dir = os.path.join(config.checkpoint_dir,
                                                       "best_model")
                        logger.info(
                            "best model saved to %s [best accuracy=%.5f]" %
                            (model_saved_dir, best_eval_acc))
                        fluid.io.save_persistables(exe, dirname=model_saved_dir)

        # NOTE: the current checkpoint mechanism is incomplete; it can't
        # restore the dataset's training position.
        save_checkpoint(
            checkpoint_dir=config.checkpoint_dir,
            current_epoch=num_epoch + 1,
            global_step=global_step,
            exe=exe)

        if do_eval:
            evaluate_cls_task(
                task, data_reader, feed_list, phase="test", config=config)
        logger.info("PaddleHub finetune finished.")
W
wuzewu 已提交
273 274


275
def finetune_and_eval(task, data_reader, feed_list, config=None):
    """
    Finetune ``task`` with evaluation enabled.

    Sequence labeling tasks (``task.task_type == "sequence_labeling"``) use
    the sequence labeling loop. Every other task type, e.g. image or text
    classification, uses the classification loop.
    """
    if task.task_type != "sequence_labeling":
        # image classification and text classification
        _finetune_cls_task(task, data_reader, feed_list, config, do_eval=True)
        return

    _finetune_seq_label_task(
        task, data_reader, feed_list, config, do_eval=True)
W
wuzewu 已提交
282 283


284
def finetune(task, data_reader, feed_list, config=None):
    """
    Finetune ``task`` without evaluation.

    Dispatches on ``task.task_type`` the same way ``finetune_and_eval`` does.
    Previously every task went through the classification loop. That loop
    fetches an "accuracy" variable, which a sequence labeling task may not
    define. Tasks without a ``task_type`` attribute still use the
    classification loop, as before.
    """
    if getattr(task, "task_type", None) == "sequence_labeling":
        _finetune_seq_label_task(
            task, data_reader, feed_list, config, do_eval=False)
    else:
        _finetune_cls_task(task, data_reader, feed_list, config, do_eval=False)