提交 7f1a2c0b 编写于 作者: Z zhangxuefei

Fix the tokenize encoding bug and Update cv dataset (add train dev test examples)

上级 f0accf42
......@@ -35,6 +35,10 @@ class ImageClassificationDataset(object):
self.num_labels = 0
self.label_list = []
self.train_examples = []
self.dev_examples = []
self.test_examples = []
def _download_dataset(self, dataset_path, url):
if not os.path.exists(dataset_path):
result, tips, dataset_path = default_downloader.download_file_and_uncompress(
......@@ -47,7 +51,7 @@ class ImageClassificationDataset(object):
exit()
return dataset_path
def _parse_data(self, data_path, shuffle=False):
def _parse_data(self, data_path, shuffle=False, phase=None):
def _base_reader():
data = []
with open(data_path, "r") as file:
......@@ -68,6 +72,13 @@ class ImageClassificationDataset(object):
label = items[-1]
data.append((image_path, items[-1]))
if phase == 'train':
self.train_examples = data
elif phase == 'dev':
self.dev_examples = data
elif phase == 'test':
self.test_examples = data
if shuffle:
np.random.shuffle(data)
......@@ -85,13 +96,22 @@ class ImageClassificationDataset(object):
def train_data(self, shuffle=True):
train_data_path = os.path.join(self.base_path, self.train_list_file)
return self._parse_data(train_data_path, shuffle)
return self._parse_data(train_data_path, shuffle, phase='train')
def test_data(self, shuffle=False):
test_data_path = os.path.join(self.base_path, self.test_list_file)
return self._parse_data(test_data_path, shuffle)
return self._parse_data(test_data_path, shuffle, phase='dev')
def validate_data(self, shuffle=False):
validate_data_path = os.path.join(self.base_path,
self.validate_list_file)
return self._parse_data(validate_data_path, shuffle)
return self._parse_data(validate_data_path, shuffle, phase='test')
def get_train_examples(self):
return self.train_examples
def get_dev_examples(self):
return self.dev_examples
def get_test_examples(self):
return self.test_examples
......@@ -520,7 +520,9 @@ class BasicTask(object):
self.save_checkpoint()
# Final evaluation
if self._base_data_reader.get_dev_examples() != []:
self.eval(phase="dev")
if self._base_data_reader.get_test_examples() != []:
self.eval(phase="test")
self._finetune_end_event(run_states)
......
......@@ -122,3 +122,12 @@ class ImageClassificationReader(object):
yield (image, label)
return paddle.batch(_data_reader, batch_size=batch_size)
def get_train_examples(self):
return self.dataset.train_examples
def get_dev_examples(self):
return self.dataset.dev_examples
def get_test_examples(self):
return self.dataset.test_examples
......@@ -42,7 +42,8 @@ class BaseReader(object):
label_map_config=None,
max_seq_len=512,
do_lower_case=True,
random_seed=None):
random_seed=None,
use_task_id=False):
self.max_seq_len = max_seq_len
self.tokenizer = tokenization.FullTokenizer(
vocab_file=vocab_path, do_lower_case=do_lower_case)
......@@ -52,6 +53,10 @@ class BaseReader(object):
self.cls_id = self.vocab["[CLS]"]
self.sep_id = self.vocab["[SEP]"]
self.in_tokens = False
self.use_task_id = use_task_id
if self.use_task_id:
self.task_id = 0
np.random.seed(random_seed)
......@@ -232,7 +237,6 @@ class BaseReader(object):
phase='train',
shuffle=True,
data=None):
if phase == 'train':
shuffle = True
examples = self.get_train_examples()
......@@ -313,12 +317,25 @@ class ClassifyReader(BaseReader):
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, batch_labels
]
if self.use_task_id:
padded_task_ids = np.ones_like(
padded_token_ids, dtype="int64") * self.task_id
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, padded_task_ids, batch_labels
]
else:
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask
]
if self.use_task_id:
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, padded_task_ids
]
return return_list
......@@ -355,11 +372,30 @@ class SequenceLabelReader(BaseReader):
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, padded_label_ids, batch_seq_lens
]
if self.use_task_id:
padded_task_ids = np.ones_like(
padded_token_ids, dtype="int64") * self.task_id
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, padded_task_ids, padded_label_ids,
batch_seq_lens
]
else:
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, batch_seq_lens
]
if self.use_task_id:
padded_task_ids = np.ones_like(
padded_token_ids, dtype="int64") * self.task_id
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, padded_task_ids, batch_seq_lens
]
return return_list
def _reseg_token_label(self, tokens, tokenizer, phase, labels=None):
......@@ -585,11 +621,27 @@ class MultiLabelClassifyReader(BaseReader):
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, batch_labels
]
if self.use_task_id:
padded_task_ids = np.ones_like(
padded_token_ids, dtype="int64") * self.task_id
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, padded_task_ids, batch_labels
]
else:
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask
]
if self.use_task_id:
padded_task_ids = np.ones_like(
padded_token_ids, dtype="int64") * self.task_id
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, padded_task_ids
]
return return_list
def _convert_example_to_record(self,
......
......@@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import codecs
import collections
import io
import unicodedata
import six
......@@ -71,7 +71,7 @@ def printable_text(text):
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
fin = codecs.open(vocab_file, "r", "UTF-8")
fin = io.open(vocab_file, "r", "UTF-8")
for num, line in enumerate(fin):
items = convert_to_unicode(line.strip()).split("\t")
if len(items) > 2:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册