未验证 提交 1039a9ac 编写于 作者: Z Zeyu Chen 提交者: GitHub

Merge pull request #19 from Steffy-zxf/update-sequence-labeling-predict

Update sequence-labeling predict
......@@ -53,29 +53,25 @@ if __name__ == '__main__':
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
with fluid.program_guard(program):
label = fluid.layers.data(
name="label", shape=[args.max_seq_len, 1], dtype='int64')
seq_len = fluid.layers.data(name="seq_len", shape=[1], dtype='int64')
# Use "sequence_outputs" for token-level output.
sequence_output = output_dict["sequence_output"]
# Define a classfication finetune task by PaddleHub's API
seq_label_task = hub.create_seq_label_task(
feature=sequence_output,
num_classes=dataset.num_labels,
max_seq_len=args.max_seq_len)
# Setup feed list for data feeder
# Must feed all the tensor of ERNIE's module need
# Compared to classification task, we need add seq_len tensor to feedlist
feed_list = [
input_dict["input_ids"].name, input_dict["position_ids"].name,
input_dict["segment_ids"].name, input_dict["input_mask"].name,
label.name, seq_len
seq_label_task.variable('label').name,
seq_label_task.variable('seq_len').name
]
# Define a classfication finetune task by PaddleHub's API
seq_label_task = hub.create_seq_label_task(
feature=sequence_output,
labels=label,
num_classes=dataset.num_labels,
seq_len=seq_len)
fetch_list = [
seq_label_task.variable("labels").name,
seq_label_task.variable("infers").name,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册