未验证 提交 08217003 编写于 作者: G Guo Sheng 提交者: GitHub

Fix the data lod in Transformer prediction. (#4571)

上级 48c280d3
......@@ -92,9 +92,14 @@ def do_predict(args):
input_field_names = desc.encoder_data_input_fields + desc.fast_decoder_data_input_fields
input_descs = desc.get_input_descs(args.args)
input_slots = [{
"name": name,
"shape": input_descs[name][0],
"dtype": input_descs[name][1]
"name":
name,
"shape":
input_descs[name][0],
"dtype":
input_descs[name][1],
"lod_level":
input_descs[name][2] if len(input_descs[name]) > 2 else 0
} for name in input_field_names]
input_field = InputField(input_slots)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册