提交 ade18e13 编写于 作者: D dyning

add score in rec_infer

上级 78d90511
...@@ -48,6 +48,7 @@ class LMDBReader(object): ...@@ -48,6 +48,7 @@ class LMDBReader(object):
elif params['mode'] == "test": elif params['mode'] == "test":
self.batch_size = 1 self.batch_size = 1
self.infer_img = params["infer_img"] self.infer_img = params["infer_img"]
def load_hierarchical_lmdb_dataset(self): def load_hierarchical_lmdb_dataset(self):
lmdb_sets = {} lmdb_sets = {}
dataset_idx = 0 dataset_idx = 0
......
...@@ -110,7 +110,11 @@ class RecModel(object): ...@@ -110,7 +110,11 @@ class RecModel(object):
return loader, outputs return loader, outputs
elif mode == "export": elif mode == "export":
predict = predicts['predict'] predict = predicts['predict']
if self.loss_type == "ctc":
predict = fluid.layers.softmax(predict) predict = fluid.layers.softmax(predict)
return [image, {'decoded_out': decoded_out, 'predicts': predict}] return [image, {'decoded_out': decoded_out, 'predicts': predict}]
else: else:
return loader, {'decoded_out': decoded_out} predict = predicts['predict']
if self.loss_type == "ctc":
predict = fluid.layers.softmax(predict)
return loader, {'decoded_out': decoded_out, 'predicts': predict}
...@@ -123,6 +123,8 @@ class AttentionPredict(object): ...@@ -123,6 +123,8 @@ class AttentionPredict(object):
full_ids = fluid.layers.fill_constant_batch_size_like( full_ids = fluid.layers.fill_constant_batch_size_like(
input=init_state, shape=[-1, 1], dtype='int64', value=1) input=init_state, shape=[-1, 1], dtype='int64', value=1)
full_scores = fluid.layers.fill_constant_batch_size_like(
input=init_state, shape=[-1, 1], dtype='float32', value=1)
cond = layers.less_than(x=counter, y=array_len) cond = layers.less_than(x=counter, y=array_len)
while_op = layers.While(cond=cond) while_op = layers.While(cond=cond)
...@@ -171,6 +173,9 @@ class AttentionPredict(object): ...@@ -171,6 +173,9 @@ class AttentionPredict(object):
new_ids = fluid.layers.concat([full_ids, topk_indices], axis=1) new_ids = fluid.layers.concat([full_ids, topk_indices], axis=1)
fluid.layers.assign(new_ids, full_ids) fluid.layers.assign(new_ids, full_ids)
new_scores = fluid.layers.concat([full_scores, topk_scores], axis=1)
fluid.layers.assign(new_scores, full_scores)
layers.increment(x=counter, value=1, in_place=True) layers.increment(x=counter, value=1, in_place=True)
# update the memories # update the memories
...@@ -184,7 +189,7 @@ class AttentionPredict(object): ...@@ -184,7 +189,7 @@ class AttentionPredict(object):
length_cond = layers.less_than(x=counter, y=array_len) length_cond = layers.less_than(x=counter, y=array_len)
finish_cond = layers.logical_not(layers.is_empty(x=topk_indices)) finish_cond = layers.logical_not(layers.is_empty(x=topk_indices))
layers.logical_and(x=length_cond, y=finish_cond, out=cond) layers.logical_and(x=length_cond, y=finish_cond, out=cond)
return full_ids return full_ids, full_scores
def __call__(self, inputs, labels=None, mode=None): def __call__(self, inputs, labels=None, mode=None):
encoder_features = self.encoder(inputs) encoder_features = self.encoder(inputs)
...@@ -223,10 +228,10 @@ class AttentionPredict(object): ...@@ -223,10 +228,10 @@ class AttentionPredict(object):
decoder_size, char_num) decoder_size, char_num)
_, decoded_out = layers.topk(input=predict, k=1) _, decoded_out = layers.topk(input=predict, k=1)
decoded_out = layers.lod_reset(decoded_out, y=label_out) decoded_out = layers.lod_reset(decoded_out, y=label_out)
predicts = {'predict': predict, 'decoded_out': decoded_out} predicts = {'predict':predict, 'decoded_out':decoded_out}
else: else:
ids = self.gru_attention_infer( ids, predict = self.gru_attention_infer(
decoder_boot, self.max_length, char_num, word_vector_dim, decoder_boot, self.max_length, char_num, word_vector_dim,
encoded_vector, encoded_proj, decoder_size) encoded_vector, encoded_proj, decoder_size)
predicts = {'decoded_out': ids} predicts = {'predict':predict, 'decoded_out':ids}
return predicts return predicts
...@@ -79,34 +79,44 @@ def main(): ...@@ -79,34 +79,44 @@ def main():
blobs = reader_main(config, 'test')() blobs = reader_main(config, 'test')()
infer_img = config['TestReader']['infer_img'] infer_img = config['TestReader']['infer_img']
loss_type = config['Global']['loss_type']
infer_list = get_image_file_list(infer_img) infer_list = get_image_file_list(infer_img)
max_img_num = len(infer_list) max_img_num = len(infer_list)
if len(infer_list) == 0: if len(infer_list) == 0:
logger.info("Can not find img in infer_img dir.") logger.info("Can not find img in infer_img dir.")
for i in range(max_img_num): for i in range(max_img_num):
print("infer_img:",infer_list[i]) logger.info("infer_img:%s" % infer_list[i])
img = next(blobs) img = next(blobs)
predict = exe.run(program=eval_prog, predict = exe.run(program=eval_prog,
feed={"image": img}, feed={"image": img},
fetch_list=fetch_varname_list, fetch_list=fetch_varname_list,
return_numpy=False) return_numpy=False)
if loss_type == "ctc":
preds = np.array(predict[0]) preds = np.array(predict[0])
if preds.shape[1] == 1:
preds = preds.reshape(-1) preds = preds.reshape(-1)
preds_lod = predict[0].lod()[0] preds_lod = predict[0].lod()[0]
preds_text = char_ops.decode(preds) preds_text = char_ops.decode(preds)
else: probs = np.array(predict[1])
ind = np.argmax(probs, axis=1)
blank = probs.shape[1]
valid_ind = np.where(ind != (blank - 1))[0]
score = np.mean(probs[valid_ind, ind[valid_ind]])
elif loss_type == "attention":
preds = np.array(predict[0])
probs = np.array(predict[1])
end_pos = np.where(preds[0, :] == 1)[0] end_pos = np.where(preds[0, :] == 1)[0]
if len(end_pos) <= 1: if len(end_pos) <= 1:
preds_text = preds[0, 1:] preds = preds[0, 1:]
score = np.mean(probs[0, 1:])
else: else:
preds_text = preds[0, 1:end_pos[1]] preds = preds[0, 1:end_pos[1]]
preds_text = preds_text.reshape(-1) score = np.mean(probs[0, 1:end_pos[1]])
preds_text = char_ops.decode(preds_text) preds = preds.reshape(-1)
preds_text = char_ops.decode(preds)
print("\t index:",preds) print("\t index:", preds)
print("\t word :",preds_text) print("\t word :", preds_text)
print("\t score :", score)
# save for inference model # save for inference model
target_var = [] target_var = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册