未验证 提交 490b85cf 编写于 作者: W whs 提交者: GitHub

Fix eval.py in ocr ctc demo. (#1970)

上级 6137a6b0

运行本目录下的程序示例需要使用PaddlePaddle develop最新版本。如果您的PaddlePaddle安装版本低于此要求,请按照[安装文档](http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/pip_install_cn.html)中的说明更新PaddlePaddle安装版本。
......@@ -156,12 +155,13 @@ env CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --parallel=True
通过以下命令调用评估脚本用指定数据集对模型进行评估:
```
env CUDA_VISIBLE_DEVICE=0 python eval.py \
env CUDA_VISIBLE_DEVICES=0 python eval.py \
--model_path="./models/model_0" \
--input_images_dir="./eval_data/images/" \
--input_images_list="./eval_data/eval_list\" \
--input_images_list="./eval_data/eval_list"
```
执行`python train.py --help`可查看参数详细说明。
......@@ -170,7 +170,7 @@ env CUDA_VISIBLE_DEVICE=0 python eval.py \
从标准输入读取一张图片的路径,并对齐进行预测:
```
env CUDA_VISIBLE_DEVICE=0 python infer.py \
env CUDA_VISIBLE_DEVICES=0 python infer.py \
--model_path="models/model_00044_15000"
```
......@@ -193,7 +193,7 @@ result: [2067 2067 8187 8477 5027 7191 2431 1462]
从文件中批量读取图片路径,并对其进行预测:
```
env CUDA_VISIBLE_DEVICE=0 python infer.py \
env CUDA_VISIBLE_DEVICES=0 python infer.py \
--model_path="models/model_00044_15000" \
--input_images_list="data/test.list"
```
......@@ -204,3 +204,5 @@ env CUDA_VISIBLE_DEVICE=0 python infer.py \
|- |:-: |
|[ocr_ctc_params](https://paddle-ocr-models.bj.bcebos.com/ocr_ctc.zip) | 22.3% |
|[ocr_attention_params](https://paddle-ocr-models.bj.bcebos.com/ocr_attention.zip) | 15.8%|
>在本文示例中,均可通过修改`CUDA_VISIBLE_DEVICES`改变当前任务使用的显卡号。
......@@ -10,6 +10,11 @@ from os import path
from paddle.dataset.image import load_image
import paddle
try:
input = raw_input
except NameError:
pass
SOS = 0
EOS = 1
NUM_CLASSES = 95
......@@ -175,7 +180,7 @@ class DataGenerator(object):
yield img, label
else:
while True:
img_path = raw_input("Please input the path of image: ")
img_path = input("Please input the path of image: ")
img = Image.open(img_path).convert('L')
img = np.array(img) - 127.5
img = img[np.newaxis, ...]
......
......@@ -31,7 +31,8 @@ def evaluate(args):
num_classes = data_reader.num_classes()
data_shape = data_reader.data_shape()
# define network
evaluator, cost = eval(data_shape, num_classes)
evaluator, cost = eval(
data_shape, num_classes, use_cudnn=True if args.use_gpu else False)
# data reader
test_reader = data_reader.test(
......@@ -62,8 +63,8 @@ def evaluate(args):
count += 1
exe.run(fluid.default_main_program(), feed=get_feeder_data(data, place))
avg_distance, avg_seq_error = evaluator.eval(exe)
print("Read %d samples; avg_distance: %s; avg_seq_error: %s" % (
count, avg_distance, avg_seq_error))
print("Read %d samples; avg_distance: %s; avg_seq_error: %s" %
(count, avg_distance, avg_seq_error))
def main():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册