Commit fae95818 authored by kinghuin, committed by Steffy-zxf

regression support ERNIE2 (#156)

* regression support ERNIE2

* modify vdl_log_dir to tb_log_dir

* support ernie2
Parent b4bf6afb
@@ -34,6 +34,7 @@
 parser.add_argument("--max_seq_len", type=int, default=512, help="Number of words of the longest sequence.")
 parser.add_argument("--batch_size", type=int, default=32, help="Total examples' number in batch for training.")
 parser.add_argument("--use_pyreader", type=ast.literal_eval, default=False, help="Whether use pyreader to feed data.")
 parser.add_argument("--use_data_parallel", type=ast.literal_eval, default=False, help="Whether use data parallel.")
+parser.add_argument("--use_taskid", type=ast.literal_eval, default=False, help="Whether to use task id; if True, use ERNIE 2.0.")
 args = parser.parse_args()
 # yapf: enable.
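Note: the new --use_taskid flag follows the surrounding pattern of parsing booleans with type=ast.literal_eval rather than type=bool. That choice matters: argparse passes the raw string to the type callable, and bool("False") is True. A minimal standalone sketch of the difference:

```python
import argparse
import ast

parser = argparse.ArgumentParser()
# type=bool is a trap: bool("False") == True, so this flag could never be turned off.
parser.add_argument("--broken", type=bool, default=False)
# ast.literal_eval("False") -> False, which is why these demo scripts use it.
parser.add_argument("--use_taskid", type=ast.literal_eval, default=False)

args = parser.parse_args(["--broken=False", "--use_taskid=False"])
print(args.broken)      # True  (surprising)
print(args.use_taskid)  # False (as intended)
```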
@@ -42,7 +43,10 @@ if __name__ == '__main__':
     # Download dataset and use ClassifyReader to read dataset
     if args.dataset.lower() == "sts-b":
         dataset = hub.dataset.GLUE("STS-B")
-        module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
+        if args.use_taskid:
+            module = hub.Module(name="ernie_v2_eng_base")
+        else:
+            module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
     else:
         raise ValueError("%s dataset is not defined" % args.dataset)
@@ -51,7 +55,8 @@ if __name__ == '__main__':
     reader = hub.reader.RegressionReader(
         dataset=dataset,
         vocab_path=module.get_vocab_path(),
-        max_seq_len=args.max_seq_len)
+        max_seq_len=args.max_seq_len,
+        use_task_id=args.use_taskid)

     # Construct transfer learning network
     # Use "pooled_output" for classification tasks on an entire sentence.
@@ -67,6 +72,9 @@ if __name__ == '__main__':
         inputs["input_mask"].name,
     ]

+    if args.use_taskid:
+        feed_list.append(inputs["task_ids"].name)
+
     # Select finetune strategy, setup config and finetune
     strategy = hub.AdamWeightDecayStrategy(
         weight_decay=args.weight_decay,
...
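Note: the position of the append matters, because Paddle matches feed_list entries to the reader's outputs by order. With use_taskid on, the reader (see the RegressionReader changes below) emits task ids after input_mask and before the labels. A hypothetical, self-contained sketch of the resulting order; "inputs" is mocked here, and the first four key names are what these demos typically use, not something this diff shows:

```python
from types import SimpleNamespace

# Mocked "inputs"; in the demo it comes from the module's context.
inputs = {k: SimpleNamespace(name=k) for k in
          ["input_ids", "position_ids", "segment_ids", "input_mask", "task_ids"]}

use_taskid = True
feed_list = [
    inputs["input_ids"].name,     # padded_token_ids
    inputs["position_ids"].name,  # padded_position_ids
    inputs["segment_ids"].name,   # padded_text_type_ids
    inputs["input_mask"].name,    # input_mask
]
if use_taskid:
    feed_list.append(inputs["task_ids"].name)  # padded_task_ids, before the label

print(feed_list)  # order mirrors _pad_batch_records' return_list below
```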
@@ -17,3 +17,4 @@ python -u regression.py \
     --num_epoch=3 \
     --use_pyreader=True \
     --use_data_parallel=True \
+    --use_taskid=False \
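Note: the run script defaults to --use_taskid=False, which preserves the original BERT behavior; passing --use_taskid=True would presumably switch the demo to ernie_v2_eng_base, as the regression.py change above shows.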
@@ -145,7 +145,7 @@ if __name__ == '__main__':
     ]

     if args.use_taskid:
-        feed_list = feed_list.append(inputs["task_ids"].name)
+        feed_list.append(inputs["task_ids"].name)

     # Select finetune strategy, setup config and finetune
     strategy = hub.AdamWeightDecayStrategy(
...
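Note: this one-line change is a genuine bug fix, not a cleanup. list.append() mutates the list in place and returns None, so the old assignment replaced feed_list with None. A standalone demonstration:

```python
names = ["input_ids", "position_ids", "segment_ids", "input_mask"]

# Old, buggy pattern: append() returns None, so the assignment
# destroys the list that the later feed/run step needs.
broken = list(names)
broken = broken.append("task_ids")
print(broken)  # None

# Fixed pattern: call append() for its side effect only.
fixed = list(names)
fixed.append("task_ids")
print(fixed)   # ['input_ids', 'position_ids', 'segment_ids', 'input_mask', 'task_ids']
```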
@@ -55,7 +55,6 @@ class RunConfig(object):
         self._strategy = strategy
         self._enable_memory_optim = enable_memory_optim
         if checkpoint_dir is None:
             now = int(time.time())
             time_str = time.strftime("%Y%m%d%H%M%S", time.localtime(now))
             self._checkpoint_dir = "ckpt_" + time_str
...
@@ -140,8 +140,8 @@ class BasicTask(object):
         # log item
         if not os.path.exists(self.config.checkpoint_dir):
             mkdir(self.config.checkpoint_dir)
-        vdl_log_dir = os.path.join(self.config.checkpoint_dir, "visualization")
-        self.tb_writer = SummaryWriter(vdl_log_dir)
+        tb_log_dir = os.path.join(self.config.checkpoint_dir, "visualization")
+        self.tb_writer = SummaryWriter(tb_log_dir)

         # run environment
         self._phases = []
...
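Note: the rename from vdl_log_dir to tb_log_dir (VisualDL to TensorBoard naming) makes the variable match the writer it feeds: self.tb_writer is a TensorBoard-style SummaryWriter. A minimal usage sketch, assuming the tensorboardX package that PaddleHub relied on at the time:

```python
import os
from tensorboardX import SummaryWriter  # assumed backend for self.tb_writer

checkpoint_dir = "ckpt_20190801000000"  # illustrative checkpoint dir
tb_log_dir = os.path.join(checkpoint_dir, "visualization")

writer = SummaryWriter(tb_log_dir)
writer.add_scalar("train/loss", 0.42, global_step=1)  # logged per step while finetuning
writer.close()
# Inspect with: tensorboard --logdir ckpt_20190801000000/visualization
```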
@@ -761,7 +761,8 @@ class RegressionReader(BaseReader):
                  label_map_config=None,
                  max_seq_len=128,
                  do_lower_case=True,
-                 random_seed=None):
+                 random_seed=None,
+                 use_task_id=False):
         self.max_seq_len = max_seq_len
         self.tokenizer = tokenization.FullTokenizer(
             vocab_file=vocab_path, do_lower_case=do_lower_case)
@@ -771,6 +772,10 @@ class RegressionReader(BaseReader):
         self.cls_id = self.vocab["[CLS]"]
         self.sep_id = self.vocab["[SEP]"]
         self.in_tokens = False
+        self.use_task_id = use_task_id
+
+        if self.use_task_id:
+            self.task_id = 0

         np.random.seed(random_seed)
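Note: the task id is pinned to 0. ERNIE 2.0 was pretrained with per-task embeddings for continual multi-task learning, so at fine-tuning time the reader simply feeds one fixed id for every token; using 0 appears to match the released ERNIE fine-tuning code.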
@@ -811,12 +816,28 @@ class RegressionReader(BaseReader):
                 padded_token_ids, padded_position_ids, padded_text_type_ids,
                 input_mask, batch_labels
             ]
+            if self.use_task_id:
+                padded_task_ids = np.ones_like(
+                    padded_token_ids, dtype="int64") * self.task_id
+                return_list = [
+                    padded_token_ids, padded_position_ids, padded_text_type_ids,
+                    input_mask, padded_task_ids, batch_labels
+                ]
         else:
             return_list = [
                 padded_token_ids, padded_position_ids, padded_text_type_ids,
                 input_mask
             ]
+            if self.use_task_id:
+                padded_task_ids = np.ones_like(
+                    padded_token_ids, dtype="int64") * self.task_id
+                return_list = [
+                    padded_token_ids, padded_position_ids, padded_text_type_ids,
+                    input_mask, padded_task_ids
+                ]

         return return_list

     def data_generator(self,
...
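Note: the only new tensor here is padded_task_ids, a constant array with the same shape as the padded token ids. A standalone numpy illustration (the shapes are made up for the example):

```python
import numpy as np

# Toy padded batch: 2 examples, max_seq_len 5, trailing dim 1,
# mimicking the shape of the reader's padded outputs.
padded_token_ids = np.zeros((2, 5, 1), dtype="int64")

task_id = 0  # the fixed id set in __init__ when use_task_id is True
padded_task_ids = np.ones_like(padded_token_ids, dtype="int64") * task_id

print(padded_task_ids.shape)       # (2, 5, 1)
print(np.unique(padded_task_ids))  # [0] -- every position carries task id 0
```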