#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Finetuning on classification tasks."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time
import argparse
import numpy as np

import paddle
import paddle.fluid as fluid
import paddle_hub as hub

import reader.cls as reader
from utils.args import ArgumentGroup, print_arguments
from paddle_hub.finetune.config import FinetuneConfig

# yapf: disable
parser = argparse.ArgumentParser(__doc__)

train_g = ArgumentGroup(parser, "training", "training options.")
train_g.add_arg("epoch",             int,    3,       "Number of epoches for fine-tuning.")
train_g.add_arg("learning_rate",     float,  5e-5,    "Learning rate used to train with warmup.")
train_g.add_arg("lr_scheduler",      str,    "linear_warmup_decay", "Learning rate scheduler.", choices=['linear_warmup_decay', 'noam_decay'])
train_g.add_arg("weight_decay",      float,  0.01,    "Weight decay rate for L2 regularizer.")
train_g.add_arg("warmup_proportion", float,  0.1,
                "Proportion of training steps to perform linear learning rate warmup for.")

data_g = ArgumentGroup(parser, "data", "Data paths, vocab paths and data processing options")
data_g.add_arg("data_dir",      str,  None,  "Path to training data.")
data_g.add_arg("vocab_path",    str,  None,  "Vocabulary path.")
data_g.add_arg("max_seq_len",   int,  512,   "Number of words of the longest seqence.")
data_g.add_arg("batch_size",    int,  32,    "Total examples' number in batch for training. see also --in_tokens.")
data_g.add_arg("in_tokens",     bool, False,
              "If set, the batch size will be the maximum number of tokens in one batch. "
              "Otherwise, it will be the maximum number of examples in one batch.")

args = parser.parse_args()
# yapf: enable

if __name__ == '__main__':
    print_arguments(args)
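    # Fine-tuning configuration: logging, evaluation and checkpointing
    # intervals, plus the hyperparameters parsed from the command line.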
    config = FinetuneConfig(
        log_interval=10,
        eval_interval=100,
        save_ckpt_interval=200,
        use_cuda=True,
        checkpoint_dir="./bert_cls_ckpt",
        learning_rate=args.learning_rate,
        num_epoch=args.epoch,
        batch_size=args.batch_size,
        max_seq_len=args.max_seq_len,
        weight_decay=args.weight_decay,
        finetune_strategy="bert_finetune",
        with_memory_optimization=True,
        in_tokens=args.in_tokens,
        optimizer=None,
        warmup_proportion=args.warmup_proportion)

    # Load the pretrained BERT module from PaddleHub
    module = hub.Module(
        module_dir="./hub_module/chinese_L-12_H-768_A-12.hub_module")
    # module = hub.Module(module_dir="./hub_module/ernie-stable.hub_module")

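    # Data reader that builds BERT inputs from raw examples, using the
    # vocabulary shipped with the module so tokenization matches pretraining.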
    processor = reader.BERTClassifyReader(
        data_dir=args.data_dir,
        vocab_path=module.get_vocab_path(),
        max_seq_len=args.max_seq_len)

    num_labels = len(processor.get_labels())

    # Retrieve BERT's input tensors, output tensors and forward graph.
    # To fine-tune the pretrained model parameters, set trainable to True.
    input_dict, output_dict, train_program = module.context(
        sign_name="pooled_output", trainable=True)

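    # Build the fine-tuning head inside the module's program so the new
    # layers are appended to BERT's graph and optimized together with it.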
    with fluid.program_guard(train_program):
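        # Placeholder for each example's ground-truth label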
        label = fluid.layers.data(name="label", shape=[1], dtype='int64')

        pooled_output = output_dict["pooled_output"]

        # Set up the feed list for the data feeder.
        # Every tensor the BERT module needs must be fed, plus the label.
        feed_list = [
            input_dict["src_ids"].name, input_dict["pos_ids"].name,
            input_dict["sent_ids"].name, input_dict["input_mask"].name,
            label.name
        ]
        # Define a classification fine-tuning task with PaddleHub's API
        cls_task = hub.append_mlp_classifier(
            pooled_output, label, num_classes=num_labels)

        # Fine-tune and evaluate with PaddleHub's API, which runs training,
        # evaluation and testing, and saves the model automatically.
        hub.finetune_and_eval(
            task=cls_task,
            data_processor=processor,
            feed_list=feed_list,
            config=config)
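
# Example invocation (the data directory below is illustrative; point it at
# your own dataset):
#   python finetune_with_hub.py --data_dir ./dataset --epoch 3 \
#       --batch_size 32 --learning_rate 5e-5 --max_seq_len 128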