#coding:utf-8
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Finetuning on classification task """

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import ast
import os
import time

import numpy as np
import paddle
import paddle.fluid as fluid
import paddlehub as hub

# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--checkpoint_dir", type=str, default=None, help="Directory to model checkpoint")
33
parser.add_argument("--batch_size",     type=int,   default=1, help="Total examples' number in batch for training.")
Z
Zeyu Chen 已提交
34
parser.add_argument("--max_seq_len", type=int, default=512, help="Number of words of the longest seqence.")
35
parser.add_argument("--use_gpu", type=ast.literal_eval, default=False, help="Whether use GPU for finetuning, input should be True or False")
W
wuzewu 已提交
36
parser.add_argument("--use_pyreader", type=ast.literal_eval, default=False, help="Whether use pyreader to feed data.")
K
kinghuin 已提交
37
parser.add_argument("--dataset", type=str, default="chnsenticorp", help="Directory to dataset")
Z
Zeyu Chen 已提交
38 39 40 41
args = parser.parse_args()
# yapf: enable.
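
# Example invocation (the checkpoint path is illustrative; point it at the
# directory produced by your fine-tuning run):
#   python predict.py --checkpoint_dir=./ckpt_chnsenticorp --dataset=chnsenticorp --use_gpu=False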

if __name__ == '__main__':
    dataset = None
    # Download the dataset and pick a pre-trained module that matches it
    if args.dataset.lower() == "chnsenticorp":
        dataset = hub.dataset.ChnSentiCorp()
        module = hub.Module(name="ernie")
    elif args.dataset.lower() == "nlpcc_dbqa":
        dataset = hub.dataset.NLPCC_DBQA()
        module = hub.Module(name="ernie")
    elif args.dataset.lower() == "lcqmc":
        dataset = hub.dataset.LCQMC()
        module = hub.Module(name="ernie")
    elif args.dataset.lower() == "mrpc":
        dataset = hub.dataset.GLUE("MRPC")
        module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
    elif args.dataset.lower() == "qqp":
        dataset = hub.dataset.GLUE("QQP")
        module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
    elif args.dataset.lower() == "sst-2":
        dataset = hub.dataset.GLUE("SST-2")
        module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
    elif args.dataset.lower() == "cola":
        dataset = hub.dataset.GLUE("CoLA")
        module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
    elif args.dataset.lower() == "qnli":
        dataset = hub.dataset.GLUE("QNLI")
        module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
    elif args.dataset.lower() == "rte":
        dataset = hub.dataset.GLUE("RTE")
        module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
    elif args.dataset.lower() == "mnli":
        dataset = hub.dataset.GLUE("MNLI")
        module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
    elif args.dataset.lower().startswith("xnli"):
        dataset = hub.dataset.XNLI(language=args.dataset.lower()[-2:])
        module = hub.Module(name="bert_multi_cased_L-12_H-768_A-12")
    else:
        raise ValueError("%s dataset is not defined" % args.dataset)

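    # Load the module's pre-trained program along with its input placeholders
    # and output tensors.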
    inputs, outputs, program = module.context(
        trainable=True, max_seq_len=args.max_seq_len)
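    # ClassifyReader tokenizes each example with the module's own vocab and
    # pads every sequence to max_seq_len, producing the input fields fed below.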
    reader = hub.reader.ClassifyReader(
        dataset=dataset,
        vocab_path=module.get_vocab_path(),
        max_seq_len=args.max_seq_len)

    # Construct transfer learning network
    # Use "pooled_output" for classification tasks on an entire sentence.
    # Use "sequence_output" for token-level output.
    pooled_output = outputs["pooled_output"]
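    # pooled_output is the sentence-level embedding with shape
    # [batch_size, hidden_size] (768 for the 12-layer modules above).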

    # Set up the feed list for the data feeder.
    # Every input tensor the module expects must be fed.
    feed_list = [
        inputs["input_ids"].name,
        inputs["position_ids"].name,
        inputs["segment_ids"].name,
        inputs["input_mask"].name,
    ]
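    # The order of this list must match the field order emitted by ClassifyReader.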

    # Set up the running config for the PaddleHub Finetune API
    config = hub.RunConfig(
        use_data_parallel=False,
        use_pyreader=args.use_pyreader,
        use_cuda=args.use_gpu,
        batch_size=args.batch_size,
        enable_memory_optim=False,
        checkpoint_dir=args.checkpoint_dir,
        strategy=hub.finetune.strategy.DefaultFinetuneStrategy())
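    # The strategy only takes effect during fine-tuning; prediction runs the
    # network forward without updating parameters.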

    # Define a classification fine-tuning task with PaddleHub's API
    cls_task = hub.TextClassifierTask(
        data_reader=reader,
        feature=pooled_output,
        feed_list=feed_list,
        num_classes=dataset.num_labels,
        config=config)
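    # predict() restores the parameters saved under --checkpoint_dir, typically
    # the best model from a previous fine-tuning run.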

    # Data to be predicted
    data = [[d.text_a, d.text_b] for d in dataset.get_dev_examples()[:3]]
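    # Each item is a [text_a, text_b] pair; text_b is None for single-sentence
    # datasets such as chnsenticorp.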

    index = 0
    run_states = cls_task.predict(data=data)
    results = [run_state.run_results for run_state in run_states]
    for batch_result in results:
        # run_results holds one output array of shape (batch_size, num_classes);
        # take the argmax over the class axis to get predicted label indices
        batch_result = np.argmax(batch_result, axis=2)[0]
        for result in batch_result:
            print("%s\tpredict=%s" % (data[index][0], result))
            index += 1
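
    # Note: the printed value is a raw label index; for chnsenticorp,
    # 0 = negative and 1 = positive sentiment.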