未验证 提交 4cac91eb 编写于 作者: D dyning 提交者: GitHub

Merge pull request #132 from tink2123/add_rec_score

Add rec score
......@@ -36,6 +36,9 @@ PaddleOCR旨在打造一套丰富、领先、且实用的OCR工具库,助力
#### 2.inference模型下载
*windows 环境下如果没有安装wget,下载模型时可将链接复制到浏览器中下载,并解压放置在相应目录下*
#### (1)超轻量级中文OCR模型下载
```
mkdir inference && cd inference
......@@ -63,6 +66,9 @@ cd ..
# 设置PYTHONPATH环境变量
export PYTHONPATH=.
# windows下设置环境变量
SET PYTHONPATH=.
# 预测image_dir指定的单张图像
python3 tools/infer/predict_system.py --image_dir="./doc/imgs/11.jpg" --det_model_dir="./inference/ch_det_mv3_db/" --rec_model_dir="./inference/ch_rec_mv3_crnn/"
......
......@@ -10,4 +10,3 @@ EvalReader:
TestReader:
reader_function: ppocr.data.rec.dataset_traversal,LMDBReader
lmdb_sets_dir: ./train_data/data_lmdb_release/evaluation/
infer_img: ./infer_img
......@@ -18,6 +18,8 @@ Global:
pretrain_weights:
checkpoints:
save_inference_dir:
infer_img:
Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel
......
......@@ -11,4 +11,3 @@ EvalReader:
TestReader:
reader_function: ppocr.data.rec.dataset_traversal,SimpleReader
infer_img: ./infer_img
......@@ -11,4 +11,3 @@ EvalReader:
TestReader:
reader_function: ppocr.data.rec.dataset_traversal,SimpleReader
infer_img: ./infer_img
......@@ -17,6 +17,8 @@ Global:
pretrain_weights: ./pretrain_models/rec_mv3_none_bilstm_ctc/best_accuracy
checkpoints:
save_inference_dir:
infer_img:
Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel
......
......@@ -17,6 +17,7 @@ Global:
pretrain_weights:
checkpoints:
save_inference_dir:
infer_img:
Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel
......
......@@ -17,6 +17,7 @@ Global:
pretrain_weights:
checkpoints:
save_inference_dir:
infer_img:
Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel
......
......@@ -13,10 +13,13 @@ Global:
max_text_length: 25
character_type: en
loss_type: attention
tps: true
reader_yml: ./configs/rec/rec_benchmark_reader.yml
pretrain_weights:
checkpoints:
save_inference_dir:
infer_img:
Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel
......
......@@ -13,10 +13,12 @@ Global:
max_text_length: 25
character_type: en
loss_type: ctc
tps: true
reader_yml: ./configs/rec/rec_benchmark_reader.yml
pretrain_weights:
checkpoints:
save_inference_dir:
infer_img:
Architecture:
......
......@@ -17,6 +17,8 @@ Global:
pretrain_weights:
checkpoints:
save_inference_dir:
infer_img:
Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel
......
......@@ -17,6 +17,7 @@ Global:
pretrain_weights:
checkpoints:
save_inference_dir:
infer_img:
Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel
......
......@@ -17,6 +17,8 @@ Global:
pretrain_weights:
checkpoints:
save_inference_dir:
infer_img:
Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel
......
......@@ -17,6 +17,8 @@ Global:
pretrain_weights:
checkpoints:
save_inference_dir:
infer_img:
Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel
......
......@@ -46,6 +46,9 @@ wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/Res
```
**启动训练**
*如果您安装的是cpu版本,请将配置文件中的 `use_gpu` 字段修改为false*
```
python3 tools/train.py -c configs/det/det_mv3_db.yml
```
......
......@@ -165,6 +165,16 @@ STAR-Net文本识别模型推理,可以执行如下命令:
```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_type="en"
```
### 3.基于Attention损失的识别模型推理
基于Attention损失的识别模型与ctc不同,需要额外设置识别算法参数 --rec_algorithm="RARE"
RARE 文本识别模型推理,可以执行如下命令:
```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/sare/" --rec_image_shape="3, 32, 100" --rec_char_type="en" --rec_algorithm="RARE"
```
![](imgs_words_en/word_336.png)
执行命令后,上面图像的识别结果如下:
......
......@@ -8,6 +8,8 @@ PaddleOCR 工作环境
建议使用我们提供的docker运行PaddleOCR,有关docker使用请参考[链接](https://docs.docker.com/get-started/)
*如您希望使用 mac 或 windows直接运行预测代码,可以从第2步开始执行。*
1. (建议)准备docker环境。第一次使用这个镜像,会自动下载该镜像,请耐心等待。
```
# 切换到工作目录下
......@@ -54,6 +56,10 @@ python3 -m pip install paddlepaddle-gpu==1.7.2.post97 -i https://pypi.tuna.tsing
如果您的机器安装的是CUDA10,请运行以下命令安装
python3 -m pip install paddlepaddle-gpu==1.7.2.post107 -i https://pypi.tuna.tsinghua.edu.cn/simple
如果您的机器是CPU,请运行以下命令安装
python3 -m pip install paddlepaddle==1.7.2 -i https://pypi.tuna.tsinghua.edu.cn/simple
更多的版本需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。
```
......
......@@ -41,6 +41,8 @@ PaddleOCR 提供了一份用于训练 icdar2015 数据集的标签文件,通
wget -P ./train_data/ic15_data https://paddleocr.bj.bcebos.com/dataset/rec_gt_train.txt
# 测试集标签
wget -P ./train_data/ic15_data https://paddleocr.bj.bcebos.com/dataset/rec_gt_test.txt
```
最终训练集应有如下文件结构:
......@@ -111,6 +113,8 @@ tar -xf rec_mv3_none_bilstm_ctc.tar && rm -rf rec_mv3_none_bilstm_ctc.tar
开始训练:
*如果您安装的是cpu版本,请将配置文件中的 `use_gpu` 字段修改为false*
```
# 设置PYTHONPATH路径
export PYTHONPATH=$PYTHONPATH:.
......@@ -168,10 +172,11 @@ Global:
评估数据集可以通过 `configs/rec/rec_icdar15_reader.yml` 修改EvalReader中的 `label_file_path` 设置。
*注意* 评估时必须确保配置文件中 infer_img 字段为空
```
export CUDA_VISIBLE_DEVICES=0
# GPU 评估, Global.checkpoints 为待测权重
python3 tools/eval.py -c configs/rec/rec_chinese_lite_train.yml -o Global.checkpoints={path/to/weights}/best_accuracy
python3 tools/eval.py -c configs/rec/rec_icdar15_train.yml -o Global.checkpoints={path/to/weights}/best_accuracy
```
### 预测
......@@ -184,7 +189,7 @@ python3 tools/eval.py -c configs/rec/rec_chinese_lite_train.yml -o Global.checkp
```
# 预测英文结果
python3 tools/infer_rec.py -c configs/rec/rec_chinese_lite_train.yml -o Global.checkpoints={path/to/weights}/best_accuracy TestReader.infer_img=doc/imgs_words/en/word_1.jpg
python3 tools/infer_rec.py -c configs/rec/rec_icdar15_train.yml -o Global.checkpoints={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png
```
预测图片:
......
......@@ -61,8 +61,6 @@ class TrainReader(object):
if len(batch_outs) == self.batch_size:
yield batch_outs
batch_outs = []
if len(batch_outs) != 0:
yield batch_outs
return batch_iter_reader
......
......@@ -17,6 +17,8 @@ import cv2
import numpy as np
import json
import sys
from ppocr.utils.utility import initial_logger
logger = initial_logger()
from .data_augment import AugmentData
from .random_crop_data import RandomCropData
......@@ -100,6 +102,7 @@ class DBProcessTrain(object):
img_path, gt_label = self.convert_label_infor(label_infor)
imgvalue = cv2.imread(img_path)
if imgvalue is None:
logger.info("{} does not exist!".format(img_path))
return None
data = self.make_data_dict(imgvalue, gt_label)
data = AugmentData(data)
......
......@@ -41,13 +41,18 @@ class LMDBReader(object):
self.loss_type = params['loss_type']
self.max_text_length = params['max_text_length']
self.mode = params['mode']
self.drop_last = False
self.use_tps = False
if "tps" in params:
self.ues_tps = True
if params['mode'] == 'train':
self.batch_size = params['train_batch_size_per_card']
elif params['mode'] == "eval":
self.drop_last = True
else:
self.batch_size = params['test_batch_size_per_card']
elif params['mode'] == "test":
self.batch_size = 1
self.infer_img = params["infer_img"]
self.drop_last = False
self.infer_img = params['infer_img']
def load_hierarchical_lmdb_dataset(self):
lmdb_sets = {}
dataset_idx = 0
......@@ -100,13 +105,18 @@ class LMDBReader(object):
process_id = 0
def sample_iter_reader():
if self.mode == 'test':
if self.mode != 'train' and self.infer_img is not None:
image_file_list = get_image_file_list(self.infer_img)
for single_img in image_file_list:
img = cv2.imread(single_img)
if img.shape[-1]==1 or len(list(img.shape))==2:
if img.shape[-1] == 1 or len(list(img.shape)) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
norm_img = process_image(img, self.image_shape)
norm_img = process_image(
img=img,
image_shape=self.image_shape,
char_ops=self.char_ops,
tps=self.use_tps,
infer_mode=True)
yield norm_img
else:
lmdb_sets = self.load_hierarchical_lmdb_dataset()
......@@ -126,9 +136,13 @@ class LMDBReader(object):
if sample_info is None:
continue
img, label = sample_info
outs = process_image(img, self.image_shape, label,
self.char_ops, self.loss_type,
self.max_text_length)
outs = process_image(
img=img,
image_shape=self.image_shape,
label=label,
char_ops=self.char_ops,
loss_type=self.loss_type,
max_text_length=self.max_text_length)
if outs is None:
continue
yield outs
......@@ -136,6 +150,7 @@ class LMDBReader(object):
if finish_read_num == len(lmdb_sets):
break
self.close_lmdb_dataset(lmdb_sets)
def batch_iter_reader():
batch_outs = []
for outs in sample_iter_reader():
......@@ -143,10 +158,11 @@ class LMDBReader(object):
if len(batch_outs) == self.batch_size:
yield batch_outs
batch_outs = []
if not self.drop_last:
if len(batch_outs) != 0:
yield batch_outs
if self.mode != 'test':
if self.infer_img is None:
return batch_iter_reader
return sample_iter_reader
......@@ -165,26 +181,34 @@ class SimpleReader(object):
self.loss_type = params['loss_type']
self.max_text_length = params['max_text_length']
self.mode = params['mode']
self.infer_img = params['infer_img']
self.use_tps = False
if "tps" in params:
self.ues_tps = True
if params['mode'] == 'train':
self.batch_size = params['train_batch_size_per_card']
elif params['mode'] == 'eval':
self.batch_size = params['test_batch_size_per_card']
self.drop_last = True
else:
self.batch_size = 1
self.infer_img = params['infer_img']
self.batch_size = params['test_batch_size_per_card']
self.drop_last = False
def __call__(self, process_id):
if self.mode != 'train':
process_id = 0
def sample_iter_reader():
if self.mode == 'test':
if self.mode != 'train' and self.infer_img is not None:
image_file_list = get_image_file_list(self.infer_img)
for single_img in image_file_list:
img = cv2.imread(single_img)
if img.shape[-1]==1 or len(list(img.shape))==2:
if img.shape[-1] == 1 or len(list(img.shape)) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
norm_img = process_image(img, self.image_shape)
norm_img = process_image(
img=img,
image_shape=self.image_shape,
char_ops=self.char_ops,
tps=self.use_tps,
infer_mode=True)
yield norm_img
else:
with open(self.label_file_path, "rb") as fin:
......@@ -192,7 +216,7 @@ class SimpleReader(object):
img_num = len(label_infor_list)
img_id_list = list(range(img_num))
random.shuffle(img_id_list)
if sys.platform=="win32":
if sys.platform == "win32":
print("multiprocess is not fully compatible with Windows."
"num_workers will be 1.")
self.num_workers = 1
......@@ -204,7 +228,7 @@ class SimpleReader(object):
if img is None:
logger.info("{} does not exist!".format(img_path))
continue
if img.shape[-1]==1 or len(list(img.shape))==2:
if img.shape[-1] == 1 or len(list(img.shape)) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
label = substr[1]
......@@ -222,9 +246,10 @@ class SimpleReader(object):
if len(batch_outs) == self.batch_size:
yield batch_outs
batch_outs = []
if not self.drop_last:
if len(batch_outs) != 0:
yield batch_outs
if self.mode != 'test':
if self.infer_img is None:
return batch_iter_reader
return sample_iter_reader
......@@ -48,6 +48,32 @@ def resize_norm_img(img, image_shape):
return padding_im
def resize_norm_img_chinese(img, image_shape):
imgC, imgH, imgW = image_shape
# todo: change to 0 and modified image shape
max_wh_ratio = 0
h, w = img.shape[0], img.shape[1]
ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, ratio)
imgW = int(32 * max_wh_ratio)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
if image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
else:
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
def get_img_data(value):
"""get_img_data"""
if not value:
......@@ -66,8 +92,13 @@ def process_image(img,
label=None,
char_ops=None,
loss_type=None,
max_text_length=None):
max_text_length=None,
tps=None,
infer_mode=False):
if not infer_mode or char_ops.character_type == "en" or tps != None:
norm_img = resize_norm_img(img, image_shape)
else:
norm_img = resize_norm_img_chinese(img, image_shape)
norm_img = norm_img[np.newaxis, :]
if label is not None:
char_num = char_ops.get_char_num()
......
......@@ -30,6 +30,8 @@ class RecModel(object):
global_params = params['Global']
char_num = global_params['char_ops'].get_char_num()
global_params['char_num'] = char_num
self.char_type = global_params['character_type']
self.infer_img = global_params['infer_img']
if "TPS" in params:
tps_params = deepcopy(params["TPS"])
tps_params.update(global_params)
......@@ -60,8 +62,8 @@ class RecModel(object):
def create_feed(self, mode):
image_shape = deepcopy(self.image_shape)
image_shape.insert(0, -1)
image = fluid.data(name='image', shape=image_shape, dtype='float32')
if mode == "train":
image = fluid.data(name='image', shape=image_shape, dtype='float32')
if self.loss_type == "attention":
label_in = fluid.data(
name='label_in',
......@@ -86,6 +88,16 @@ class RecModel(object):
use_double_buffer=True,
iterable=False)
else:
if self.char_type == "ch" and self.infer_img:
image_shape[-1] = -1
if self.tps != None:
logger.info(
"WARNRNG!!!\n"
"TPS does not support variable shape in chinese!"
"We set img_shape to be the same , it may affect the inference effect"
)
image_shape = deepcopy(self.image_shape)
image = fluid.data(name='image', shape=image_shape, dtype='float32')
labels = None
loader = None
return image, labels, loader
......@@ -110,7 +122,11 @@ class RecModel(object):
return loader, outputs
elif mode == "export":
predict = predicts['predict']
if self.loss_type == "ctc":
predict = fluid.layers.softmax(predict)
return [image, {'decoded_out': decoded_out, 'predicts': predict}]
else:
return loader, {'decoded_out': decoded_out}
predict = predicts['predict']
if self.loss_type == "ctc":
predict = fluid.layers.softmax(predict)
return loader, {'decoded_out': decoded_out, 'predicts': predict}
......@@ -123,6 +123,8 @@ class AttentionPredict(object):
full_ids = fluid.layers.fill_constant_batch_size_like(
input=init_state, shape=[-1, 1], dtype='int64', value=1)
full_scores = fluid.layers.fill_constant_batch_size_like(
input=init_state, shape=[-1, 1], dtype='float32', value=1)
cond = layers.less_than(x=counter, y=array_len)
while_op = layers.While(cond=cond)
......@@ -171,6 +173,9 @@ class AttentionPredict(object):
new_ids = fluid.layers.concat([full_ids, topk_indices], axis=1)
fluid.layers.assign(new_ids, full_ids)
new_scores = fluid.layers.concat([full_scores, topk_scores], axis=1)
fluid.layers.assign(new_scores, full_scores)
layers.increment(x=counter, value=1, in_place=True)
# update the memories
......@@ -184,7 +189,7 @@ class AttentionPredict(object):
length_cond = layers.less_than(x=counter, y=array_len)
finish_cond = layers.logical_not(layers.is_empty(x=topk_indices))
layers.logical_and(x=length_cond, y=finish_cond, out=cond)
return full_ids
return full_ids, full_scores
def __call__(self, inputs, labels=None, mode=None):
encoder_features = self.encoder(inputs)
......@@ -223,10 +228,10 @@ class AttentionPredict(object):
decoder_size, char_num)
_, decoded_out = layers.topk(input=predict, k=1)
decoded_out = layers.lod_reset(decoded_out, y=label_out)
predicts = {'predict': predict, 'decoded_out': decoded_out}
predicts = {'predict':predict, 'decoded_out':decoded_out}
else:
ids = self.gru_attention_infer(
ids, predict = self.gru_attention_infer(
decoder_boot, self.max_length, char_num, word_vector_dim,
encoded_vector, encoded_proj, decoder_size)
predicts = {'decoded_out': ids}
predicts = {'predict':predict, 'decoded_out':ids}
return predicts
......@@ -48,7 +48,7 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
total_sample_num = 0
total_acc_num = 0
total_batch_num = 0
if mode == "test":
if mode == "eval":
is_remove_duplicate = False
else:
is_remove_duplicate = True
......@@ -91,11 +91,11 @@ def test_rec_benchmark(exe, config, eval_info_dict):
total_correct_number = 0
eval_data_acc_info = {}
for eval_data in eval_data_list:
config['EvalReader']['lmdb_sets_dir'] = \
config['TestReader']['lmdb_sets_dir'] = \
eval_data_dir + "/" + eval_data
eval_reader = reader_main(config=config, mode="eval")
eval_reader = reader_main(config=config, mode="test")
eval_info_dict['reader'] = eval_reader
metrics = eval_rec_run(exe, config, eval_info_dict, "eval")
metrics = eval_rec_run(exe, config, eval_info_dict, "test")
total_evaluation_data_number += metrics['total_sample_num']
total_correct_number += metrics['total_acc_num']
eval_data_acc_info[eval_data] = metrics
......
......@@ -32,10 +32,16 @@ class TextRecognizer(object):
self.rec_image_shape = image_shape
self.character_type = args.rec_char_type
self.rec_batch_num = args.rec_batch_num
self.rec_algorithm = args.rec_algorithm
char_ops_params = {}
char_ops_params["character_type"] = args.rec_char_type
char_ops_params["character_dict_path"] = args.rec_char_dict_path
if self.rec_algorithm != "RARE":
char_ops_params['loss_type'] = 'ctc'
self.loss_type = 'ctc'
else:
char_ops_params['loss_type'] = 'attention'
self.loss_type = 'attention'
self.char_ops = CharacterOps(char_ops_params)
def resize_norm_img(self, img, max_wh_ratio):
......@@ -80,13 +86,14 @@ class TextRecognizer(object):
starttime = time.time()
self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.zero_copy_run()
if self.loss_type == "ctc":
rec_idx_batch = self.output_tensors[0].copy_to_cpu()
rec_idx_lod = self.output_tensors[0].lod()[0]
predict_batch = self.output_tensors[1].copy_to_cpu()
predict_lod = self.output_tensors[1].lod()[0]
elapse = time.time() - starttime
predict_time += elapse
starttime = time.time()
for rno in range(len(rec_idx_lod) - 1):
beg = rec_idx_lod[rno]
end = rec_idx_lod[rno + 1]
......@@ -100,6 +107,22 @@ class TextRecognizer(object):
valid_ind = np.where(ind != (blank - 1))[0]
score = np.mean(probs[valid_ind, ind[valid_ind]])
rec_res.append([preds_text, score])
else:
rec_idx_batch = self.output_tensors[0].copy_to_cpu()
predict_batch = self.output_tensors[1].copy_to_cpu()
elapse = time.time() - starttime
predict_time += elapse
for rno in range(len(rec_idx_batch)):
end_pos = np.where(rec_idx_batch[rno, :] == 1)[0]
if len(end_pos) <= 1:
preds = rec_idx_batch[rno, 1:]
score = np.mean(predict_batch[rno, 1:])
else:
preds = rec_idx_batch[rno, 1:end_pos[1]]
score = np.mean(predict_batch[rno, 1:end_pos[1]])
preds_text = self.char_ops.decode(preds)
rec_res.append([preds_text, score])
return rec_res, predict_time
......@@ -116,7 +139,17 @@ if __name__ == "__main__":
continue
valid_image_file_list.append(image_file)
img_list.append(img)
try:
rec_res, predict_time = text_recognizer(img_list)
except Exception as e:
print(e)
logger.info(
"ERROR!!!! \n"
"Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
"If your model has tps module: "
"TPS does not support variable shape.\n"
"Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
exit()
for ino in range(len(img_list)):
print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino]))
print("Total predict time for %d images:%.3f" %
......
......@@ -21,6 +21,7 @@ import time
import multiprocessing
import numpy as np
def set_paddle_flags(**kwargs):
for key, value in kwargs.items():
if os.environ.get(key, None) is None:
......@@ -54,6 +55,7 @@ def main():
program.merge_config(FLAGS.opt)
logger.info(config)
char_ops = CharacterOps(config['Global'])
loss_type = config['Global']['loss_type']
config['Global']['char_ops'] = char_ops
# check if set use_gpu=True in paddlepaddle cpu version
......@@ -78,35 +80,44 @@ def main():
init_model(config, eval_prog, exe)
blobs = reader_main(config, 'test')()
infer_img = config['TestReader']['infer_img']
infer_img = config['Global']['infer_img']
infer_list = get_image_file_list(infer_img)
max_img_num = len(infer_list)
if len(infer_list) == 0:
logger.info("Can not find img in infer_img dir.")
for i in range(max_img_num):
print("infer_img:",infer_list[i])
print("infer_img:%s" % infer_list[i])
img = next(blobs)
predict = exe.run(program=eval_prog,
feed={"image": img},
fetch_list=fetch_varname_list,
return_numpy=False)
if loss_type == "ctc":
preds = np.array(predict[0])
if preds.shape[1] == 1:
preds = preds.reshape(-1)
preds_lod = predict[0].lod()[0]
preds_text = char_ops.decode(preds)
else:
probs = np.array(predict[1])
ind = np.argmax(probs, axis=1)
blank = probs.shape[1]
valid_ind = np.where(ind != (blank - 1))[0]
score = np.mean(probs[valid_ind, ind[valid_ind]])
elif loss_type == "attention":
preds = np.array(predict[0])
probs = np.array(predict[1])
end_pos = np.where(preds[0, :] == 1)[0]
if len(end_pos) <= 1:
preds_text = preds[0, 1:]
preds = preds[0, 1:]
score = np.mean(probs[0, 1:])
else:
preds_text = preds[0, 1:end_pos[1]]
preds_text = preds_text.reshape(-1)
preds_text = char_ops.decode(preds_text)
preds = preds[0, 1:end_pos[1]]
score = np.mean(probs[0, 1:end_pos[1]])
preds = preds.reshape(-1)
preds_text = char_ops.decode(preds)
print("\t index:",preds)
print("\t word :",preds_text)
print("\t index:", preds)
print("\t word :", preds_text)
print("\t score :", score)
# save for inference model
target_var = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册