Commit bc362fe9 authored by kinghuin, committed by wuzewu

discard task_id

Parent f2a7e2a2
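In short, this commit removes the `use_task_id` switch from the demo scripts and docs: the Reader classes are no longer constructed with `use_task_id`, and `inputs["task_ids"].name` is no longer appended to the feed list. A minimal sketch of the post-commit usage, assuming a classification demo in the style of the hunks below (module name, dataset, and `max_seq_len` are illustrative, not prescribed by this diff):

```python
import paddlehub as hub

# Illustrative module/dataset choices; the diff below touches several demos.
module = hub.Module(name="ernie")
dataset = hub.dataset.ChnSentiCorp()

inputs, outputs, program = module.context(
    trainable=True, max_seq_len=128)

# After this commit the Reader no longer takes a use_task_id argument ...
reader = hub.reader.ClassifyReader(
    dataset=dataset,
    vocab_path=module.get_vocab_path(),
    max_seq_len=128)

# ... and the feed list no longer carries inputs["task_ids"].name.
feed_list = [
    inputs["input_ids"].name,
    inputs["position_ids"].name,
    inputs["segment_ids"].name,
    inputs["input_mask"].name,
]
```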
......@@ -45,8 +45,7 @@ if __name__ == '__main__':
# Setup feed list for data feeder
feed_list = [
inputs["input_ids"].name, inputs["position_ids"].name,
inputs["segment_ids"].name, inputs["input_mask"].name,
inputs["task_ids"].name
inputs["segment_ids"].name, inputs["input_mask"].name
]
else:
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
......@@ -68,8 +67,7 @@ if __name__ == '__main__':
reader = hub.reader.MultiLabelClassifyReader(
dataset=dataset,
vocab_path=module.get_vocab_path(),
max_seq_len=args.max_seq_len,
use_task_id=args.use_taskid)
max_seq_len=args.max_seq_len)
# Construct transfer learning network
# Use "pooled_output" for classification tasks on an entire sentence.
......
......@@ -50,9 +50,10 @@ if __name__ == '__main__':
# Setup feed list for data feeder
feed_list = [
inputs["input_ids"].name, inputs["position_ids"].name,
inputs["segment_ids"].name, inputs["input_mask"].name,
inputs["task_ids"].name
inputs["input_ids"].name,
inputs["position_ids"].name,
inputs["segment_ids"].name,
inputs["input_mask"].name,
]
else:
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
......@@ -74,8 +75,7 @@ if __name__ == '__main__':
reader = hub.reader.MultiLabelClassifyReader(
dataset=dataset,
vocab_path=module.get_vocab_path(),
max_seq_len=args.max_seq_len,
use_task_id=args.use_taskid)
max_seq_len=args.max_seq_len)
# Construct transfer learning network
# Use "pooled_output" for classification tasks on an entire sentence.
......
......@@ -55,8 +55,7 @@ if __name__ == '__main__':
reader = hub.reader.RegressionReader(
dataset=dataset,
vocab_path=module.get_vocab_path(),
max_seq_len=args.max_seq_len,
use_task_id=args.use_taskid)
max_seq_len=args.max_seq_len)
# Construct transfer learning network
# Use "pooled_output" for classification tasks on an entire sentence.
......@@ -72,9 +71,6 @@ if __name__ == '__main__':
inputs["input_mask"].name,
]
if args.use_taskid:
feed_list.append(inputs["task_ids"].name)
# Select finetune strategy, setup config and finetune
strategy = hub.AdamWeightDecayStrategy(
weight_decay=args.weight_decay,
......
......@@ -51,7 +51,6 @@ if __name__ == '__main__':
dataset=dataset,
vocab_path=module.get_vocab_path(),
max_seq_len=args.max_seq_len,
use_task_id=use_taskid,
sp_model_path=module.get_spm_path(),
word_dict_path=module.get_word_dict_path())
......
......@@ -40,7 +40,6 @@
--max_seq_len: the maximum sequence length used by the ERNIE/BERT model; it cannot exceed 512. If GPU memory runs out, lower this value accordingly.
--use_data_parallel: whether to use parallel computation. Defaults to False; enabling it requires the NCCL library.
--use_pyreader: whether to use PyReader. Defaults to False.
--use_taskid: whether to use task_id. task_id is specific to ERNIE 2.0; use_taskid=True means ERNIE 2.0 is used. For ERNIE 1.0, BERT, or other modules, use_taskid should be set to False.
# Task-related
--checkpoint_dir: directory where the model is saved; PaddleHub automatically keeps the model that performs best on the validation set.
......@@ -86,8 +85,7 @@ dataset = hub.dataset.ChnSentiCorp()
reader = hub.reader.ClassifyReader(
dataset=dataset,
vocab_path=module.get_vocab_path(),
max_seq_len=128,
use_task_id=False)
max_seq_len=128)
metrics_choices = ["acc"]
```
......@@ -99,8 +97,6 @@ metrics_choices = ["acc"]
`max_seq_len` must match the sequence length passed to the context interface in Step 1.
`use_task_id` indicates whether an ERNIE 2.0 module is used.
The `data_generator` in ClassifyReader automatically tokenizes the data with the model's vocabulary and returns, as an iterator, the Tensor format required by ERNIE/BERT, including `input_ids`, `position_ids`, `segment_id`, and the sequence mask `input_mask`.
**NOTE**: The order of the tensors returned by the Reader is fixed; by default they are returned in the order input_ids, position_ids, segment_id, input_mask.
......
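A minimal sketch of that fixed ordering, assuming the `reader` and `inputs` set up earlier in this demo (batch size and phase are illustrative values):

```python
# Build the feed list in the same fixed order the Reader returns tensors.
feed_list = [
    inputs["input_ids"].name,
    inputs["position_ids"].name,
    inputs["segment_ids"].name,
    inputs["input_mask"].name,
]

# data_generator is the reader interface mentioned above; values are illustrative.
train_data = reader.data_generator(batch_size=32, phase="train")
```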
......@@ -130,8 +130,7 @@ if __name__ == '__main__':
reader = hub.reader.ClassifyReader(
dataset=dataset,
vocab_path=module.get_vocab_path(),
max_seq_len=args.max_seq_len,
use_task_id=args.use_taskid)
max_seq_len=args.max_seq_len)
# Construct transfer learning network
# Use "pooled_output" for classification tasks on an entire sentence.
......@@ -147,9 +146,6 @@ if __name__ == '__main__':
inputs["input_mask"].name,
]
if args.use_taskid:
feed_list.append(inputs["task_ids"].name)
# Setup running config for PaddleHub Finetune API
config = hub.RunConfig(
use_data_parallel=False,
......
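For context, a hedged sketch of how the pieces in the hunks above are typically wired together to finish the finetune setup; the hyperparameter values are illustrative, and the API names (`AdamWeightDecayStrategy`, `RunConfig`, `TextClassifierTask`) follow the standard PaddleHub demo pattern rather than this diff:

```python
# reader, pooled_output, feed_list and dataset come from the setup above;
# numeric values here are illustrative only.
strategy = hub.AdamWeightDecayStrategy(
    weight_decay=0.01,
    learning_rate=5e-5)

config = hub.RunConfig(
    use_data_parallel=False,
    checkpoint_dir="ckpt_text_classification",
    num_epoch=3,
    batch_size=32,
    strategy=strategy)

# Note that feed_list no longer includes inputs["task_ids"].name.
cls_task = hub.TextClassifierTask(
    data_reader=reader,
    feature=pooled_output,
    feed_list=feed_list,
    num_classes=dataset.num_labels,
    config=config)

cls_task.finetune_and_eval()
```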
......@@ -121,27 +121,11 @@ if __name__ == '__main__':
# Start preparing parameters for reader and task according to module
# For ernie_v2, it has an additional embedding named task_id
# For ernie_v2_chinese_tiny, it uses an additional sentence-piece vocab to tokenize
if module.name.startswith("ernie_v2"):
use_taskid = True
else:
use_taskid = False
inputs, outputs, program = module.context(
trainable=True, max_seq_len=args.max_seq_len)
# Construct transfer learning network
# Use "pooled_output" for classification tasks on an entire sentence.
# Use "sequence_output" for token-level output.
check = [inputs["task_ids"].name]
global_block = program.global_block()
for op in global_block.ops:
for input_arg in op.input_arg_names:
for ch in check:
if ch in input_arg:
print(op)
check.append(input_arg)
break
exit(0)
pooled_output = outputs["pooled_output"]
# Setup feed list for data feeder
......@@ -152,8 +136,6 @@ if __name__ == '__main__':
inputs["segment_ids"].name,
inputs["input_mask"].name,
]
if use_taskid:
feed_list.append(inputs["task_ids"].name)
# Finish preparing parameters for reader and task according to module
# Define reader
......@@ -161,7 +143,6 @@ if __name__ == '__main__':
dataset=dataset,
vocab_path=module.get_vocab_path(),
max_seq_len=args.max_seq_len,
use_task_id=use_taskid,
sp_model_path=module.get_spm_path(),
word_dict_path=module.get_word_dict_path())
......
......@@ -606,15 +606,11 @@ class Module(object):
"task_ids"
]
logger.warning(
"%s will exploite task_id, the arguement use_taskid of Reader class must be True."
% self.name)
"For %s, it's no necessary to feed task_ids." % self.name)
else:
feed_list = [
"input_ids", "position_ids", "segment_ids", "input_mask"
]
logger.warning(
"%s has no task_id, the arguement use_taskid of Reader class must be False."
% self.name)
for tensor_name in feed_list:
seq_tensor_shape = [-1, max_seq_len, 1]
logger.info("The shape of input tensor[{}] set to {}".format(
......