提交 014f4f40 编写于 作者: Z Zeyu Chen

change bert task reader interface

上级 34868288
......@@ -73,11 +73,11 @@ if __name__ == '__main__':
warmup_proportion=args.warmup_proportion)
# loading paddlehub BERT
# module = hub.Module(
# module_dir="./hub_module/chinese_L-12_H-768_A-12.hub_module")
module = hub.Module(module_dir="./hub_module/ernie-stable.hub_module")
module = hub.Module(
module_dir="./hub_module/chinese_L-12_H-768_A-12.hub_module")
# module = hub.Module(module_dir="./hub_module/ernie-stable.hub_module")
processor = reader.ChnsenticorpProcessor(
processor = reader.BERTClassifyReader(
data_dir=args.data_dir,
vocab_path=module.get_vocab_path(),
max_seq_len=args.max_seq_len)
......
......@@ -109,9 +109,10 @@ class DataProcessor(object):
def get_num_examples(self, phase):
    """Get number of examples for train, validate or test.

    Args:
        phase: Name of the dataset split; must be 'train', 'validate'
            or 'test'.

    Returns:
        The entry stored under ``phase`` in ``self.num_examples``.

    Raises:
        ValueError: If ``phase`` is not one of the three accepted names.
    """
    if phase not in ['train', 'validate', 'test']:
        raise ValueError(
            "Unknown phase, which should be in ['train', 'validate', 'test']."
        )
    # NOTE(review): the lookup assumes the count for this phase has already
    # been stored in self.num_examples by the code that loads the examples.
    return self.num_examples[phase]
def get_train_progress(self):
......@@ -131,9 +132,9 @@ class DataProcessor(object):
if phase == 'train':
examples = self.get_train_examples(self.data_dir)
self.num_examples['train'] = len(examples)
elif phase == 'dev':
elif phase == 'validate':
examples = self.get_dev_examples(self.data_dir)
self.num_examples['dev'] = len(examples)
self.num_examples['validate'] = len(examples)
elif phase == 'test':
examples = self.get_test_examples(self.data_dir)
self.num_examples['test'] = len(examples)
......@@ -190,7 +191,7 @@ class DataProcessor(object):
return_input_mask=True,
return_max_len=True,
return_num_token=False)
yield batch_data
yield [batch_data]
return wrapper
......@@ -473,6 +474,41 @@ class ChnsenticorpProcessor(DataProcessor):
return examples
class BERTClassifyReader(DataProcessor):
    """Reader for a two-label sentence classification data set in TSV form.

    The three splits are read from ``train.tsv``, ``dev.tsv`` and
    ``test.tsv`` inside the directory handed to the ``get_*_examples``
    methods. In every parsed row, column 0 is used as the label and
    column 1 as the sentence text.
    """

    def get_train_examples(self, data_dir):
        """Build the examples of the training split found in ``data_dir``."""
        tsv_file = os.path.join(data_dir, "train.tsv")
        rows = self._read_tsv(tsv_file)
        return self._create_examples(rows, "train")

    def get_dev_examples(self, data_dir):
        """Build the examples of the dev split found in ``data_dir``."""
        tsv_file = os.path.join(data_dir, "dev.tsv")
        rows = self._read_tsv(tsv_file)
        return self._create_examples(rows, "dev")

    def get_test_examples(self, data_dir):
        """Build the examples of the test split found in ``data_dir``."""
        tsv_file = os.path.join(data_dir, "test.tsv")
        rows = self._read_tsv(tsv_file)
        return self._create_examples(rows, "test")

    def get_labels(self):
        """Return the list of label strings used by this task."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Convert parsed rows into a list of ``InputExample`` objects.

        Args:
            lines: Iterable of rows; each row is indexed as
                ``row[0]`` (label) and ``row[1]`` (text).
            set_type: Split name used as the prefix of every example guid.

        Returns:
            One ``InputExample`` per row, in input order, with ``text_b``
            left as None.
        """
        return [
            InputExample(
                # guid is "<split>-<row index>", e.g. "train-0".
                guid="%s-%s" % (set_type, row_index),
                text_a=tokenization.convert_to_unicode(row[1]),
                text_b=None,
                label=tokenization.convert_to_unicode(row[0]))
            for row_index, row in enumerate(lines)
        ]
def convert_single_example_to_unicode(guid, single_example):
text_a = tokenization.convert_to_unicode(single_example[0])
text_b = tokenization.convert_to_unicode(single_example[1])
......
# Launch finetune_with_hub.py on the data stored in ./chnsenticorp_data.

# GPU index the job is restricted to.
export CUDA_VISIBLE_DEVICES=5

DATA_PATH=./chnsenticorp_data

# Remove a previous checkpoint directory. CKPT_PATH is not assigned in the
# visible part of this script, so skip the removal when it is unset or empty
# and quote it so a path containing spaces is handled as one argument.
if [ -n "${CKPT_PATH:-}" ]; then
    rm -rf -- "${CKPT_PATH}"
fi

# Every option line except the last must end with a backslash; a missing
# continuation would make the shell run the following line as a command.
python -u finetune_with_hub.py \
    --use_cuda true \
    --batch_size 32 \
    --in_tokens false \
    --data_dir "${DATA_PATH}" \
    --weight_decay 0.01 \
    --warmup_proportion 0.0 \
    --validation_steps 50 \
    --epoch 3 \
    --max_seq_len 128 \
    --learning_rate 5e-5
......@@ -255,7 +255,6 @@ class Module:
def get_vocab_path(self):
    """Return the first asset path that contains "vocab.txt".

    Returns:
        The first entry of ``self.assets`` in which the substring
        "vocab.txt" occurs, or None when no entry matches.
    """
    for assets_file in self.assets:
        # Substring match: any entry mentioning "vocab.txt" qualifies.
        if "vocab.txt" in assets_file:
            return assets_file
    # Make the "not found" result explicit instead of falling off the end.
    return None
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册