sequence_labeling.py 3.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Z
Zeyu Chen 已提交
14
"""Finetuning on sequence labeling task."""
15

Z
Zeyu Chen 已提交
16 17
import argparse

18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33
import paddle.fluid as fluid
import paddlehub as hub

# Command-line options for the fine-tuning run. yapf formatting is disabled
# for this region so each add_argument call stays on one line.
# yapf: disable
# The module docstring is the parser *description*; passing it positionally
# would make it the program name (`prog`) shown in the usage line.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--num_epoch", type=int, default=3, help="Number of epochs for fine-tuning.")
parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate used to train with warmup.")
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay rate for L2 regularizer.")
parser.add_argument("--checkpoint_dir", type=str, default=None, help="Directory to model checkpoint.")
parser.add_argument("--max_seq_len", type=int, default=512, help="Number of words of the longest sequence.")
parser.add_argument("--batch_size", type=int, default=32, help="Number of examples per training batch.")

# Parsed at import time; the __main__ section below reads these values.
args = parser.parse_args()
# yapf: enable

if __name__ == '__main__':
    # Step 1: load the PaddleHub ERNIE pretrained module.
    module = hub.Module(name="ernie")
    # context() returns the module's input variables, output variables and the
    # program they live in; the new layers below are added to that program.
    inputs, outputs, program = module.context(
        trainable=True, max_seq_len=args.max_seq_len)

    # Step 2: download the MSRA_NER dataset and read it with
    # SequenceLabelReader, using the module's own vocabulary.
    reader = hub.reader.SequenceLabelReader(
        dataset=hub.dataset.MSRA_NER(),
        vocab_path=module.get_vocab_path(),
        max_seq_len=args.max_seq_len)

    num_labels = len(reader.get_labels())

    # Step 3: construct the transfer-learning network inside the module's
    # program so the new data layers and task head share its graph.
    with fluid.program_guard(program):
        # One int64 label id per token position (max_seq_len positions).
        label = fluid.layers.data(
            name="label", shape=[args.max_seq_len, 1], dtype='int64')
        # Presumably the real (unpadded) length of each sequence -- confirm
        # against SequenceLabelReader's output.
        seq_len = fluid.layers.data(name="seq_len", shape=[1], dtype='int64')

        # Use "sequence_output" for token-level output.
        sequence_output = outputs["sequence_output"]

        # Setup feed list for data feeder.
        # Must feed all the tensors the ERNIE module needs; compared to the
        # classification task, the seq_len tensor is fed as well.
        # Every entry is a variable *name* (str): APIs that consume feed names
        # (e.g. saving an inference model) reject Variable objects, so seq_len
        # is referenced by .name like the others.
        feed_list = [
            inputs["input_ids"].name, inputs["position_ids"].name,
            inputs["segment_ids"].name, inputs["input_mask"].name, label.name,
            seq_len.name
        ]

        # Define a sequence labeling finetune task by PaddleHub's API.
        # Here seq_len is the Variable itself, as the task graph needs it.
        seq_label_task = hub.create_seq_labeling_task(
            feature=sequence_output,
            labels=label,
            seq_len=seq_len,
            num_classes=num_labels)

        # Select a finetune strategy: Adam with weight decay and a linear
        # warmup/decay learning-rate schedule.
        strategy = hub.AdamWeightDecayStrategy(
            weight_decay=args.weight_decay,
            learning_rate=args.learning_rate,
            warmup_strategy="linear_warmup_decay",
        )

        # Setup running config for PaddleHub Finetune API.
        # NOTE(review): use_cuda is hard-coded to True, so this presumably
        # needs a CUDA-enabled Paddle build and GPU to run.
        config = hub.RunConfig(
            use_cuda=True,
            num_epoch=args.num_epoch,
            batch_size=args.batch_size,
            checkpoint_dir=args.checkpoint_dir,
            strategy=strategy)

        # Finetune and evaluate the model by PaddleHub's API; this finishes
        # training, evaluation, testing and model saving automatically.
        hub.finetune_and_eval(
            task=seq_label_task,
            data_reader=reader,
            feed_list=feed_list,
            config=config)