提交 014f4f40 编写于 作者: Z Zeyu Chen

change bert task reader interface

上级 34868288
...@@ -73,11 +73,11 @@ if __name__ == '__main__': ...@@ -73,11 +73,11 @@ if __name__ == '__main__':
warmup_proportion=args.warmup_proportion) warmup_proportion=args.warmup_proportion)
# loading paddlehub BERT # loading paddlehub BERT
# module = hub.Module( module = hub.Module(
# module_dir="./hub_module/chinese_L-12_H-768_A-12.hub_module") module_dir="./hub_module/chinese_L-12_H-768_A-12.hub_module")
module = hub.Module(module_dir="./hub_module/ernie-stable.hub_module") # module = hub.Module(module_dir="./hub_module/ernie-stable.hub_module")
processor = reader.ChnsenticorpProcessor( processor = reader.BERTClassifyReader(
data_dir=args.data_dir, data_dir=args.data_dir,
vocab_path=module.get_vocab_path(), vocab_path=module.get_vocab_path(),
max_seq_len=args.max_seq_len) max_seq_len=args.max_seq_len)
......
...@@ -109,9 +109,10 @@ class DataProcessor(object): ...@@ -109,9 +109,10 @@ class DataProcessor(object):
def get_num_examples(self, phase): def get_num_examples(self, phase):
"""Get number of examples for train, dev or test.""" """Get number of examples for train, dev or test."""
if phase not in ['train', 'dev', 'test']: if phase not in ['train', 'validate', 'test']:
raise ValueError( raise ValueError(
"Unknown phase, which should be in ['train', 'dev', 'test'].") "Unknown phase, which should be in ['train', 'validate, 'test']."
)
return self.num_examples[phase] return self.num_examples[phase]
def get_train_progress(self): def get_train_progress(self):
...@@ -131,9 +132,9 @@ class DataProcessor(object): ...@@ -131,9 +132,9 @@ class DataProcessor(object):
if phase == 'train': if phase == 'train':
examples = self.get_train_examples(self.data_dir) examples = self.get_train_examples(self.data_dir)
self.num_examples['train'] = len(examples) self.num_examples['train'] = len(examples)
elif phase == 'dev': elif phase == 'validate':
examples = self.get_dev_examples(self.data_dir) examples = self.get_dev_examples(self.data_dir)
self.num_examples['dev'] = len(examples) self.num_examples['validate'] = len(examples)
elif phase == 'test': elif phase == 'test':
examples = self.get_test_examples(self.data_dir) examples = self.get_test_examples(self.data_dir)
self.num_examples['test'] = len(examples) self.num_examples['test'] = len(examples)
...@@ -190,7 +191,7 @@ class DataProcessor(object): ...@@ -190,7 +191,7 @@ class DataProcessor(object):
return_input_mask=True, return_input_mask=True,
return_max_len=True, return_max_len=True,
return_num_token=False) return_num_token=False)
yield batch_data yield [batch_data]
return wrapper return wrapper
...@@ -473,6 +474,41 @@ class ChnsenticorpProcessor(DataProcessor): ...@@ -473,6 +474,41 @@ class ChnsenticorpProcessor(DataProcessor):
return examples return examples
class BERTClassifyReader(DataProcessor):
"""Processor for the Chnsenticorp data set."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["0", "1"]
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
guid = "%s-%s" % (set_type, i)
text_a = tokenization.convert_to_unicode(line[1])
label = tokenization.convert_to_unicode(line[0])
examples.append(
InputExample(
guid=guid, text_a=text_a, text_b=None, label=label))
return examples
def convert_single_example_to_unicode(guid, single_example): def convert_single_example_to_unicode(guid, single_example):
text_a = tokenization.convert_to_unicode(single_example[0]) text_a = tokenization.convert_to_unicode(single_example[0])
text_b = tokenization.convert_to_unicode(single_example[1]) text_b = tokenization.convert_to_unicode(single_example[1])
......
export CUDA_VISIBLE_DEVICES=2 export CUDA_VISIBLE_DEVICES=5
DATA_PATH=./chnsenticorp_data DATA_PATH=./chnsenticorp_data
rm -rf $CKPT_PATH rm -rf $CKPT_PATH
python -u finetune_with_hub.py \ python -u finetune_with_hub.py \
--use_cuda true \
--batch_size 32 \ --batch_size 32 \
--in_tokens false \ --in_tokens false \
--data_dir ${DATA_PATH} \ --data_dir ${DATA_PATH} \
--weight_decay 0.01 \ --weight_decay 0.01 \
--warmup_proportion 0.0 \ --warmup_proportion 0.0 \
--validation_steps 50 \
--epoch 3 \ --epoch 3 \
--max_seq_len 128 \ --max_seq_len 128 \
--learning_rate 5e-5 \ --learning_rate 5e-5
--skip_steps 10
...@@ -255,7 +255,6 @@ class Module: ...@@ -255,7 +255,6 @@ class Module:
def get_vocab_path(self): def get_vocab_path(self):
for assets_file in self.assets: for assets_file in self.assets:
print(assets_file)
if "vocab.txt" in assets_file: if "vocab.txt" in assets_file:
return assets_file return assets_file
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册