Commit fae95818 authored by kinghuin, committed by Steffy-zxf

Regression task supports ERNIE 2.0 (#156)

* regression task supports ERNIE 2.0

* rename vdl_log_dir to tb_log_dir

* support ERNIE 2.0
Parent b4bf6afb
@@ -34,6 +34,7 @@ parser.add_argument("--max_seq_len", type=int, default=512, help="Number of words of the longest sequence.")
 parser.add_argument("--batch_size", type=int, default=32, help="Total examples' number in batch for training.")
 parser.add_argument("--use_pyreader", type=ast.literal_eval, default=False, help="Whether to use pyreader to feed data.")
 parser.add_argument("--use_data_parallel", type=ast.literal_eval, default=False, help="Whether to use data parallel.")
+parser.add_argument("--use_taskid", type=ast.literal_eval, default=False, help="Whether to use task_id; if True, use ERNIE 2.0.")
 args = parser.parse_args()
 # yapf: enable.
@@ -42,7 +43,10 @@ if __name__ == '__main__':
     # Download dataset and use RegressionReader to read dataset
     if args.dataset.lower() == "sts-b":
         dataset = hub.dataset.GLUE("STS-B")
-        module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
+        if args.use_taskid:
+            module = hub.Module(name="ernie_v2_eng_base")
+        else:
+            module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
     else:
         raise ValueError("%s dataset is not defined" % args.dataset)
@@ -51,7 +55,8 @@ if __name__ == '__main__':
     reader = hub.reader.RegressionReader(
         dataset=dataset,
         vocab_path=module.get_vocab_path(),
-        max_seq_len=args.max_seq_len)
+        max_seq_len=args.max_seq_len,
+        use_task_id=args.use_taskid)

     # Construct transfer learning network
     # Use "pooled_output" for classification tasks on an entire sentence.
@@ -67,6 +72,9 @@ if __name__ == '__main__':
         inputs["input_mask"].name,
     ]
+    if args.use_taskid:
+        feed_list.append(inputs["task_ids"].name)

     # Select finetune strategy, setup config and finetune
     strategy = hub.AdamWeightDecayStrategy(
         weight_decay=args.weight_decay,
......
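Taken together, the regression.py hunks above wire the new flag through module selection, the reader, and the feed list. A minimal consolidated sketch, assuming PaddleHub's 1.x fine-tune API as used in the demo (the module.context call and the stand-in variables replace the elided argparse boilerplate):

    import paddlehub as hub

    use_taskid = True   # stand-in for args.use_taskid from the argparse block above
    max_seq_len = 128   # stand-in for args.max_seq_len

    # Module choice: ERNIE 2.0 consumes an extra task_ids input, BERT does not.
    if use_taskid:
        module = hub.Module(name="ernie_v2_eng_base")
    else:
        module = hub.Module(name="bert_uncased_L-12_H-768_A-12")

    # The reader mirrors the flag so each batch carries a matching task_ids tensor.
    reader = hub.reader.RegressionReader(
        dataset=hub.dataset.GLUE("STS-B"),
        vocab_path=module.get_vocab_path(),
        max_seq_len=max_seq_len,
        use_task_id=use_taskid)

    # inputs comes from the module's context; feed order must match what the reader yields.
    inputs, outputs, program = module.context(
        trainable=True, max_seq_len=max_seq_len)
    feed_list = [
        inputs["input_ids"].name,
        inputs["position_ids"].name,
        inputs["segment_ids"].name,
        inputs["input_mask"].name,
    ]
    if use_taskid:
        feed_list.append(inputs["task_ids"].name)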
@@ -17,3 +17,4 @@ python -u regression.py \
     --num_epoch=3 \
     --use_pyreader=True \
     --use_data_parallel=True \
+    --use_taskid=False \
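(The script ships with the flag off; per the argparse definition earlier in the diff, rerunning with --use_taskid=True is what switches the demo from BERT to ernie_v2_eng_base.)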
@@ -145,7 +145,7 @@ if __name__ == '__main__':
     ]
     if args.use_taskid:
-        feed_list = feed_list.append(inputs["task_ids"].name)
+        feed_list.append(inputs["task_ids"].name)

     # Select finetune strategy, setup config and finetune
     strategy = hub.AdamWeightDecayStrategy(
......
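The one-line change above fixes a real bug: Python's list.append mutates the list in place and returns None, so the old assignment rebound feed_list to None. A standalone illustration:

    feed_list = ["input_ids", "position_ids", "segment_ids", "input_mask"]

    broken = feed_list.append("task_ids")  # append() always returns None
    print(broken)     # None
    print(feed_list)  # [..., 'task_ids']  (the mutation itself still happened)

Dropping the assignment, as the patch does, keeps feed_list bound to the list.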
@@ -55,7 +55,6 @@ class RunConfig(object):
         self._strategy = strategy
         self._enable_memory_optim = enable_memory_optim
         if checkpoint_dir is None:
             now = int(time.time())
             time_str = time.strftime("%Y%m%d%H%M%S", time.localtime(now))
             self._checkpoint_dir = "ckpt_" + time_str
......
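For reference, the fallback above names one checkpoint directory per second. A quick standalone check of the naming scheme (output is illustrative):

    import time

    now = int(time.time())
    time_str = time.strftime("%Y%m%d%H%M%S", time.localtime(now))
    print("ckpt_" + time_str)  # e.g. ckpt_20190805143025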
@@ -140,8 +140,8 @@ class BasicTask(object):
         # log item
         if not os.path.exists(self.config.checkpoint_dir):
             mkdir(self.config.checkpoint_dir)
-        vdl_log_dir = os.path.join(self.config.checkpoint_dir, "visualization")
-        self.tb_writer = SummaryWriter(vdl_log_dir)
+        tb_log_dir = os.path.join(self.config.checkpoint_dir, "visualization")
+        self.tb_writer = SummaryWriter(tb_log_dir)

         # run environment
         self._phases = []
......
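The rename is cosmetic: logs still land in <checkpoint_dir>/visualization. Assuming the SummaryWriter here is tensorboardX's (both the tb_ prefix and the PR's vdl-to-tb rename point that way), typical usage would be:

    import os
    from tensorboardX import SummaryWriter  # assumed backend, per the tb_ naming

    tb_log_dir = os.path.join("ckpt_demo", "visualization")  # hypothetical checkpoint_dir
    writer = SummaryWriter(tb_log_dir)
    writer.add_scalar("train/loss", 0.42, 1)  # hypothetical metric at step 1
    writer.close()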
@@ -761,7 +761,8 @@ class RegressionReader(BaseReader):
                  label_map_config=None,
                  max_seq_len=128,
                  do_lower_case=True,
-                 random_seed=None):
+                 random_seed=None,
+                 use_task_id=False):
         self.max_seq_len = max_seq_len
         self.tokenizer = tokenization.FullTokenizer(
             vocab_file=vocab_path, do_lower_case=do_lower_case)
@@ -771,6 +772,10 @@ class RegressionReader(BaseReader):
         self.cls_id = self.vocab["[CLS]"]
         self.sep_id = self.vocab["[SEP]"]
         self.in_tokens = False
+        self.use_task_id = use_task_id
+
+        if self.use_task_id:
+            self.task_id = 0

         np.random.seed(random_seed)
@@ -811,12 +816,28 @@
                 padded_token_ids, padded_position_ids, padded_text_type_ids,
                 input_mask, batch_labels
             ]
+            if self.use_task_id:
+                padded_task_ids = np.ones_like(
+                    padded_token_ids, dtype="int64") * self.task_id
+                return_list = [
+                    padded_token_ids, padded_position_ids, padded_text_type_ids,
+                    input_mask, padded_task_ids, batch_labels
+                ]
         else:
             return_list = [
                 padded_token_ids, padded_position_ids, padded_text_type_ids,
                 input_mask
             ]
+            if self.use_task_id:
+                padded_task_ids = np.ones_like(
+                    padded_token_ids, dtype="int64") * self.task_id
+                return_list = [
+                    padded_token_ids, padded_position_ids, padded_text_type_ids,
+                    input_mask, padded_task_ids
+                ]

         return return_list

     def data_generator(self,
......
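A small sanity check of the new branch: np.ones_like(...) * task_id broadcasts the single task id over every token slot, so the task tensor always matches the token tensor's shape. Shapes below are hypothetical:

    import numpy as np

    task_id = 0  # the reader pins ERNIE 2.0 regression to task id 0
    padded_token_ids = np.zeros((2, 5, 1), dtype="int64")  # hypothetical [batch, seq_len, 1]
    padded_task_ids = np.ones_like(padded_token_ids, dtype="int64") * task_id

    print(padded_task_ids.shape)        # (2, 5, 1), same shape as the token ids
    print(np.unique(padded_task_ids))   # [0]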