提交 e5e64166 编写于 作者: Z Zeyu Chen

add ernie classification prediction

上级 086f5f2c
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Finetuning on classification task """
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
import argparse
import numpy as np
import paddle
import paddle.fluid as fluid
import paddlehub as hub
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--checkpoint_dir", type=str, default=None, help="Directory to model checkpoint")
parser.add_argument("--max_seq_len", type=int, default=512, help="Number of words of the longest seqence.")
args = parser.parse_args()
# yapf: enable.
if __name__ == '__main__':
# loading Paddlehub ERNIE pretrained model
module = hub.Module(name="ernie")
input_dict, output_dict, program = module.context(
max_seq_len=args.max_seq_len)
# Sentence classification dataset reader
dataset = hub.dataset.ChnSentiCorp()
reader = hub.reader.ClassifyReader(
dataset=dataset,
vocab_path=module.get_vocab_path(),
max_seq_len=args.max_seq_len)
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
with fluid.program_guard(program):
label = fluid.layers.data(name="label", shape=[1], dtype='int64')
# Use "pooled_output" for classification tasks on an entire sentence.
# Use "sequence_outputs" for token-level output.
pooled_output = output_dict["pooled_output"]
# Setup feed list for data feeder
# Must feed all the tensor of ERNIE's module need
# Define a classfication finetune task by PaddleHub's API
cls_task = hub.create_text_classification_task(
feature=pooled_output, label=label, num_classes=dataset.num_labels)
# classificatin probability tensor
probs = cls_task.variable("probs")
# load best model checkpoint
fluid.io.load_persistables(exe, args.checkpoint_dir)
feed_list = [
input_dict["input_ids"].name, input_dict["position_ids"].name,
input_dict["segment_ids"].name, input_dict["input_mask"].name,
label.name
]
data_feeder = fluid.DataFeeder(feed_list=feed_list, place=place)
test_reader = reader.data_generator(phase='test', shuffle=False)
test_examples = dataset.get_test_examples()
for index, batch in enumerate(test_reader()):
probs_v = exe.run(
feed=data_feeder.feed(batch), fetch_list=[probs.name])
print(test_examples[index], probs_v[0][0])
......@@ -6,10 +6,9 @@ module = hub.Module(name="ernie")
inputs, outputs, program = module.context(trainable=True, max_seq_len=128)
# Step2
dataset = hub.dataset.ChnSentiCorp()
reader = hub.reader.ClassifyReader(
dataset=hub.dataset.ChnSentiCorp(),
vocab_path=module.get_vocab_path(),
max_seq_len=128)
dataset=dataset, vocab_path=module.get_vocab_path(), max_seq_len=128)
# Step3
with fluid.program_guard(program):
......@@ -18,7 +17,7 @@ with fluid.program_guard(program):
pooled_output = outputs["pooled_output"]
cls_task = hub.create_text_classification_task(
feature=pooled_output, label=label, num_classes=reader.get_num_labels())
feature=pooled_output, label=label, num_classes=dataset.num_labels)
# Step4
strategy = hub.AdamWeightDecayStrategy(
......
......@@ -37,11 +37,11 @@ if __name__ == '__main__':
trainable=True, max_seq_len=args.max_seq_len)
# Step2: Download dataset and use ClassifyReader to read dataset
dataset = hub.dataset.NLPCC_DBQA()
reader = hub.reader.ClassifyReader(
dataset=hub.dataset.NLPCC_DBQA(),
dataset=dataset,
vocab_path=module.get_vocab_path(),
max_seq_len=args.max_seq_len)
num_labels = len(reader.get_labels())
# Step3: construct transfer learning network
with fluid.program_guard(program):
......@@ -59,7 +59,7 @@ if __name__ == '__main__':
]
# Define a classfication finetune task by PaddleHub's API
cls_task = hub.create_text_classification_task(
pooled_output, label, num_classes=num_labels)
pooled_output, label, num_classes=dataset.num_labels)
# Step4: Select finetune strategy, setup config and finetune
strategy = hub.AdamWeightDecayStrategy(
......
......@@ -37,11 +37,11 @@ if __name__ == '__main__':
trainable=True, max_seq_len=args.max_seq_len)
# Step2: Download dataset and use ClassifyReader to read dataset
dataset = hub.dataset.LCQMC()
reader = hub.reader.ClassifyReader(
dataset=hub.dataset.LCQMC(),
dataset=dataset,
vocab_path=module.get_vocab_path(),
max_seq_len=args.max_seq_len)
num_labels = len(reader.get_labels())
# Step3: construct transfer learning network
with fluid.program_guard(program):
......@@ -59,7 +59,7 @@ if __name__ == '__main__':
]
# Define a classfication finetune task by PaddleHub's API
cls_task = hub.create_text_classification_task(
pooled_output, label, num_classes=num_labels)
pooled_output, label, num_classes=dataset.num_labels)
# Step4: Select finetune strategy, setup config and finetune
strategy = hub.AdamWeightDecayStrategy(
......
export CUDA_VISIBLE_DEVICES=1
CKPT_DIR="./ckpt_sentiment_cls/best_model"
python -u cls_predict.py --checkpoint_dir $CKPT_DIR --max_seq_len 128
......@@ -37,8 +37,9 @@ if __name__ == '__main__':
trainable=True, max_seq_len=args.max_seq_len)
# Step2: Download dataset and use ClassifyReader to read dataset
dataset = hub.dataset.ChnSentiCorp()
reader = hub.reader.ClassifyReader(
dataset=hub.dataset.ChnSentiCorp(),
dataset=dataset,
vocab_path=module.get_vocab_path(),
max_seq_len=args.max_seq_len)
......@@ -58,7 +59,9 @@ if __name__ == '__main__':
]
# Define a classfication finetune task by PaddleHub's API
cls_task = hub.create_text_classification_task(
pooled_output, label, num_classes=reader.get_num_labels())
feature=pooled_output,
label=label,
num_classes=dataset.num_labels())
# Step4: Select finetune strategy, setup config and finetune
strategy = hub.AdamWeightDecayStrategy(
......
......@@ -70,6 +70,13 @@ class ChnSentiCorp(HubDataset):
def get_labels(self):
return ["0", "1"]
@property
def num_labels(self):
"""
Return the number of labels in the dataset.
"""
return len(self.get_labels())
def _read_tsv(self, input_file, quotechar=None):
"""Reads a tab separated value file."""
with open(input_file, "r") as f:
......
......@@ -40,6 +40,13 @@ class InputExample(object):
self.text_b = text_b
self.label = label
def __str__(self):
if self.text_b is None:
return "text={}\tlabel={}".format(self.text_a, self.label)
else:
return "text_a={}\ttext_b{},label={}".format(
self.text_a, self.text_b, label)
class HubDataset(object):
def get_train_examples(self):
......@@ -56,3 +63,6 @@ class HubDataset(object):
def get_labels(self):
raise NotImplementedError()
def num_labels(self):
raise NotImplementedError()
......@@ -66,6 +66,13 @@ class LCQMC(HubDataset):
"""See base class."""
return ["0", "1"]
@property
def num_labels(self):
"""
Return the number of labels in the dataset.
"""
return len(self.get_labels())
def _read_tsv(self, input_file, quotechar=None):
"""Reads a tab separated value file."""
with open(input_file, "r") as f:
......
......@@ -79,6 +79,13 @@ class MSRA_NER(HubDataset):
def get_labels(self):
return ["B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "O"]
@property
def num_labels(self):
"""
Return the number of labels in the dataset.
"""
return len(self.get_labels())
def get_label_map(self):
return self.label_map
......
......@@ -72,6 +72,13 @@ class NLPCC_DBQA(HubDataset):
"""See base class."""
return ["0", "1"]
@property
def num_labels(self):
"""
Return the number of labels in the dataset.
"""
return len(self.get_labels())
def _read_tsv(self, input_file, quotechar=None):
"""Reads a tab separated value file."""
with open(input_file, "r") as f:
......
......@@ -80,9 +80,6 @@ class BaseReader(object):
"""Gets the list of labels for this data set."""
return self.dataset.get_labels()
def get_num_labels(self):
return len(self.dataset.get_labels())
def get_train_progress(self):
"""Gets progress for training phase."""
return self.current_example, self.current_epoch
......@@ -211,7 +208,7 @@ class BaseReader(object):
)
return self.num_examples[phase]
def data_generator(self, batch_size, phase='train', shuffle=True):
def data_generator(self, batch_size=1, phase='train', shuffle=True):
if phase == 'train':
examples = self.get_train_examples()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册