提交 f4560fa8 编写于 作者: Z Zeyu Chen

remove extract embedding reader

上级 4be6ad56
...@@ -5,6 +5,6 @@ CKPT_DIR="./ckpt_${DATASET}" ...@@ -5,6 +5,6 @@ CKPT_DIR="./ckpt_${DATASET}"
python -u senta_finetune.py \ python -u senta_finetune.py \
--batch_size=24 \ --batch_size=24 \
--use_gpu=True \ --use_gpu=False \
--checkpoint_dir=${CKPT_DIR} \ --checkpoint_dir=${CKPT_DIR} \
--num_epoch=3 --num_epoch=3
...@@ -36,7 +36,6 @@ class BaseReader(object): ...@@ -36,7 +36,6 @@ class BaseReader(object):
label_map_config=None, label_map_config=None,
max_seq_len=512, max_seq_len=512,
do_lower_case=True, do_lower_case=True,
in_tokens=False,
random_seed=None): random_seed=None):
self.max_seq_len = max_seq_len self.max_seq_len = max_seq_len
self.tokenizer = tokenization.FullTokenizer( self.tokenizer = tokenization.FullTokenizer(
...@@ -46,7 +45,7 @@ class BaseReader(object): ...@@ -46,7 +45,7 @@ class BaseReader(object):
self.pad_id = self.vocab["[PAD]"] self.pad_id = self.vocab["[PAD]"]
self.cls_id = self.vocab["[CLS]"] self.cls_id = self.vocab["[CLS]"]
self.sep_id = self.vocab["[SEP]"] self.sep_id = self.vocab["[SEP]"]
self.in_tokens = in_tokens self.in_tokens = False
np.random.seed(random_seed) np.random.seed(random_seed)
...@@ -352,36 +351,6 @@ class SequenceLabelReader(BaseReader): ...@@ -352,36 +351,6 @@ class SequenceLabelReader(BaseReader):
return record return record
class ExtractEmbeddingReader(BaseReader):
def _pad_batch_records(self, batch_records):
batch_token_ids = [record.token_ids for record in batch_records]
batch_text_type_ids = [record.text_type_ids for record in batch_records]
batch_position_ids = [record.position_ids for record in batch_records]
# padding
padded_token_ids, input_mask, seq_lens = pad_batch_data(
batch_token_ids,
pad_idx=self.pad_id,
max_seq_len=self.max_seq_len,
return_input_mask=True,
return_seq_lens=True)
padded_text_type_ids = pad_batch_data(
batch_text_type_ids,
pad_idx=self.pad_id,
max_seq_len=self.max_seq_len)
padded_position_ids = pad_batch_data(
batch_position_ids,
pad_idx=self.pad_id,
max_seq_len=self.max_seq_len)
return_list = [
padded_token_ids, padded_text_type_ids, padded_position_ids,
input_mask, seq_lens
]
return return_list
class LACClassifyReader(object): class LACClassifyReader(object):
def __init__(self, dataset, vocab_path): def __init__(self, dataset, vocab_path):
self.dataset = dataset self.dataset = dataset
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册