You need to sign in or sign up before continuing.
未验证 提交 64cde5d1 编写于 作者: L LiuHao 提交者: GitHub

Update run_ernie_classifier.py (#4790)

上级 eb7eb9cd
...@@ -38,16 +38,16 @@ from utils import init_checkpoint ...@@ -38,16 +38,16 @@ from utils import init_checkpoint
def ernie_pyreader(args, pyreader_name): def ernie_pyreader(args, pyreader_name):
src_ids = fluid.data( src_ids = fluid.layers.data(
name="src_ids", shape=[None, args.max_seq_len, 1], dtype="int64") name="src_ids", shape=[None, args.max_seq_len, 1], dtype="int64")
sent_ids = fluid.data( sent_ids = fluid.layers.data(
name="sent_ids", shape=[None, args.max_seq_len, 1], dtype="int64") name="sent_ids", shape=[None, args.max_seq_len, 1], dtype="int64")
pos_ids = fluid.data( pos_ids = fluid.layers.data(
name="pos_ids", shape=[None, args.max_seq_len, 1], dtype="int64") name="pos_ids", shape=[None, args.max_seq_len, 1], dtype="int64")
input_mask = fluid.data( input_mask = fluid.layers.data(
name="input_mask", shape=[None, args.max_seq_len, 1], dtype="float32") name="input_mask", shape=[None, args.max_seq_len, 1], dtype="float32")
labels = fluid.data(name="labels", shape=[None, 1], dtype="int64") labels = fluid.layers.data(name="labels", shape=[None, 1], dtype="int64")
seq_lens = fluid.data(name="seq_lens", shape=[None], dtype="int64") seq_lens = fluid.layers.data(name="seq_lens", shape=[None], dtype="int64")
pyreader = fluid.io.DataLoader.from_generator( pyreader = fluid.io.DataLoader.from_generator(
feed_list=[src_ids, sent_ids, pos_ids, input_mask, labels, seq_lens], feed_list=[src_ids, sent_ids, pos_ids, input_mask, labels, seq_lens],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册