未验证 提交 0ed92b2d 编写于 作者: Y Yibing Liu 提交者: GitHub

Merge pull request #172 from tianxin1860/develop

 fix pyreader return empty array as None in paddle 1.5
...@@ -9,7 +9,6 @@ import paddle.fluid as fluid ...@@ -9,7 +9,6 @@ import paddle.fluid as fluid
from paddle.fluid.initializer import NormalInitializer from paddle.fluid.initializer import NormalInitializer
import paddle.fluid.layers as layers import paddle.fluid.layers as layers
from bilm import elmo_encoder from bilm import elmo_encoder
import ipdb
def lex_net(args, word_dict_len, label_dict_len): def lex_net(args, word_dict_len, label_dict_len):
......
...@@ -177,6 +177,8 @@ def evaluate(exe, test_program, test_pyreader, graph_vars, eval_phase): ...@@ -177,6 +177,8 @@ def evaluate(exe, test_program, test_pyreader, graph_vars, eval_phase):
total_acc += np.sum(np_acc * np_num_seqs) total_acc += np.sum(np_acc * np_num_seqs)
total_num_seqs += np.sum(np_num_seqs) total_num_seqs += np.sum(np_num_seqs)
labels.extend(np_labels.reshape((-1)).tolist()) labels.extend(np_labels.reshape((-1)).tolist())
if np_qids is None:
np_qids = np.array([])
qids.extend(np_qids.reshape(-1).tolist()) qids.extend(np_qids.reshape(-1).tolist())
scores.extend(np_probs[:, 1].reshape(-1).tolist()) scores.extend(np_probs[:, 1].reshape(-1).tolist())
np_preds = np.argmax(np_probs, axis=1).astype(np.float32) np_preds = np.argmax(np_probs, axis=1).astype(np.float32)
......
...@@ -238,7 +238,7 @@ class ClassifyReader(BaseReader): ...@@ -238,7 +238,7 @@ class ClassifyReader(BaseReader):
batch_labels = [record.label_id for record in batch_records] batch_labels = [record.label_id for record in batch_records]
batch_labels = np.array(batch_labels).astype("int64").reshape([-1, 1]) batch_labels = np.array(batch_labels).astype("int64").reshape([-1, 1])
if batch_records[0].qid: if batch_records[0].qid is not None:
batch_qids = [record.qid for record in batch_records] batch_qids = [record.qid for record in batch_records]
batch_qids = np.array(batch_qids).astype("int64").reshape([-1, 1]) batch_qids = np.array(batch_qids).astype("int64").reshape([-1, 1])
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册