diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index b6a299ba4009ba73afc9673fe008b97ca139c57b..1584bc76a9dd8ddff9d05a8cb693bcbd2e09fcde 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,11 +1,10 @@
-repos:
- repo: https://github.com/PaddlePaddle/mirrors-yapf.git
- rev: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37
+ sha: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37
hooks:
- id: yapf
files: \.py$
- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: a11d9314b22d8f8c7556443875b731ef05965464
+ sha: a11d9314b22d8f8c7556443875b731ef05965464
hooks:
- id: check-merge-conflict
- id: check-symlinks
@@ -16,7 +15,7 @@ repos:
- id: trailing-whitespace
files: \.md$
- repo: https://github.com/Lucas-C/pre-commit-hooks
- rev: v1.0.1
+ sha: v1.0.1
hooks:
- id: forbid-crlf
files: \.md$
diff --git a/configs/rec/rec_d28_can.yml b/configs/rec/rec_d28_can.yml
index aeaccb6b0d245cc4be479d17597d0029647db574..9fe936ae1111e08abd3734c67d29b7d39fe4e9e3 100644
--- a/configs/rec/rec_d28_can.yml
+++ b/configs/rec/rec_d28_can.yml
@@ -5,14 +5,14 @@ Global:
print_batch_step: 10
save_model_dir: ./output/rec/can/
save_epoch_step: 1
- # evaluation is run every 1105 iterations
+ # evaluation is run every 1105 iterations (1 epoch, batch_size = 8)
eval_batch_step: [0, 1105]
cal_metric_during_train: True
- pretrained_model: ./output/rec/can/CAN
- checkpoints: ./output/rec/can/CAN
- save_inference_dir: ./inference/rec_d28_can/
+ pretrained_model:
+ checkpoints:
+ save_inference_dir:
use_visualdl: False
- infer_img: doc/imgs_hme/hme_01.jpeg
+ infer_img: doc/datasets/crohme_demo/hme_00.jpg
# for data or label process
character_dict_path: ppocr/utils/dict/latex_symbol_dict.txt
max_text_length: 36
@@ -75,7 +75,7 @@ Metric:
Train:
dataset:
- name: HMERDataSet
+ name: PGDataSet
data_dir: ./train_data/CROHME/training/images/
transforms:
- DecodeImage:
@@ -83,19 +83,22 @@ Train:
- GrayImageChannelFormat:
normalize: True
inverse: True
+ - SeqLabelEncode:
+ character_dict_path: ppocr/utils/dict/latex_symbol_dict.txt
+ lower: False
- KeepKeys:
keep_keys: ['image', 'label']
- label_file_list: ["./train_data/CROHME/training/labels.json"]
+ label_file_list: ["./train_data/CROHME/training/labels.txt"]
loader:
shuffle: True
- batch_size_per_card: 2
- drop_last: True
- num_workers: 1
+ batch_size_per_card: 8
+ drop_last: False
+ num_workers: 4
collate_fn: DyMaskCollator
Eval:
dataset:
- name: HMERDataSet
+ name: PGDataSet
data_dir: ./train_data/CROHME/evaluation/images/
transforms:
- DecodeImage:
@@ -103,9 +106,12 @@ Eval:
- GrayImageChannelFormat:
normalize: True
inverse: True
+ - SeqLabelEncode:
+ character_dict_path: ppocr/utils/dict/latex_symbol_dict.txt
+ lower: False
- KeepKeys:
keep_keys: ['image', 'label']
- label_file_list: ["./train_data/CROHME/evaluation/labels.json"]
+ label_file_list: ["./train_data/CROHME/evaluation/labels.txt"]
loader:
shuffle: False
drop_last: False
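
Note on the `eval_batch_step` comment above: 1105 corresponds to one pass over the training split at the new batch size. A minimal sketch of the arithmetic, assuming a CROHME training-set size of 8836 samples (an assumed value, chosen because it is consistent with 1105 iterations at batch_size 8):

```python
# Sketch only: how the eval interval above follows from dataset size.
# 8836 is an assumed CROHME training-set size, not read from this repo.
import math

num_train_samples = 8836   # assumption
batch_size = 8             # Train.loader.batch_size_per_card above

iters_per_epoch = math.ceil(num_train_samples / batch_size)
print(iters_per_epoch)     # 1105 -> eval_batch_step: [0, 1105]
```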
diff --git a/doc/imgs_hme/hme_00.jpg b/doc/datasets/crohme_demo/hme_00.jpg
similarity index 100%
rename from doc/imgs_hme/hme_00.jpg
rename to doc/datasets/crohme_demo/hme_00.jpg
diff --git a/doc/imgs_hme/hme_01.jpg b/doc/datasets/crohme_demo/hme_01.jpg
similarity index 100%
rename from doc/imgs_hme/hme_01.jpg
rename to doc/datasets/crohme_demo/hme_01.jpg
diff --git a/doc/imgs_hme/hme_02.jpg b/doc/datasets/crohme_demo/hme_02.jpg
similarity index 100%
rename from doc/imgs_hme/hme_02.jpg
rename to doc/datasets/crohme_demo/hme_02.jpg
diff --git a/doc/doc_ch/algorithm_rec_can.md b/doc/doc_ch/algorithm_rec_can.md
index 9585dae0cf9ecc215b4ff2f6d656345418e197d4..8a012b490458b2d804d2ab69953a3c6dff25347b 100644
--- a/doc/doc_ch/algorithm_rec_can.md
+++ b/doc/doc_ch/algorithm_rec_can.md
@@ -1,4 +1,4 @@
-# Handwritten Mathematical Expression Recognition Algorithm - ABINet
+# Handwritten Mathematical Expression Recognition Algorithm - CAN
- [1. Algorithm Introduction](#1)
- [2. Environment Setup](#2)
@@ -27,7 +27,7 @@
|Model |Backbone|Config file|ExpRate|Download link|
| ----- | ----- | ----- | ----- | ----- |
-|CAN|DenseNet|[rec_d28_can.yml](../../configs/rec/rec_d28_can.yml)|51.72|[trained model](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar)|
+|CAN|DenseNet|[rec_d28_can.yml](../../configs/rec/rec_d28_can.yml)|51.72|[trained model](https://paddleocr.bj.bcebos.com/contribution/can_train.tar)|
## 2. Environment Setup
@@ -60,16 +60,21 @@ python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs
python3 tools/train.py -c configs/rec/rec_d28_can.yml
-o Train.dataset.transforms.GrayImageChannelFormat.inverse=False
```
+- By default, one evaluation is run after each training epoch (1105 iterations). If you change the training batch_size or switch to another dataset, make the following modification at training time:
+```
+python3 tools/train.py -c configs/rec/rec_d28_can.yml
+-o Global.eval_batch_step=[0, {length_of_dataset//batch_size}]
+```
#
### 3.2 Evaluation
-You can download the trained [model file](#model) and evaluate it with the following command:
+You can download the trained [model file](https://paddleocr.bj.bcebos.com/contribution/can_train.tar) and evaluate it with the following command:
```shell
-# Note: set the pretrained_model path to a local path.
-python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_d28_can.yml -o Global.pretrained_model=./rec_d28_can_train/best_accuracy
+# Note: set the pretrained_model path to a local path. If you use a model you trained and saved yourself, change the path and filename to {path/to/weights}/{model_name}.
+python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_d28_can.yml -o Global.pretrained_model=./rec_d28_can_train/CAN
```
@@ -78,9 +83,9 @@ python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec
Use the following command to predict a single image:
```shell
# Note: set the pretrained_model path to a local path.
-python3 tools/infer_rec.py -c configs/rec/rec_d28_can.yml -o Architecture.Head.attdecoder.is_train=False Global.infer_img='./doc/imgs_hme/hme_01.jpg' Global.pretrained_model=./rec_d28_can_train/best_accuracy
+python3 tools/infer_rec.py -c configs/rec/rec_d28_can.yml -o Architecture.Head.attdecoder.is_train=False Global.infer_img='./doc/datasets/crohme_demo/hme_00.jpg' Global.pretrained_model=./rec_d28_can_train/CAN
-# To predict all images in a folder, set infer_img to a folder path, e.g. Global.infer_img='./doc/imgs_hme/'.
+# To predict all images in a folder, set infer_img to a folder path, e.g. Global.infer_img='./doc/datasets/crohme_demo/'.
```
@@ -89,17 +94,16 @@ python3 tools/infer_rec.py -c configs/rec/rec_d28_can.yml -o Architecture.Head.a
### 4.1 Python Inference
-First, convert the best model obtained from training into an inference model. Taking the trained model as an example ([model download link](https://paddleocr.bj.bcebos.com/rec_d28_can_train.tar)), you can use the following command to convert it:
+First, convert the best model obtained from training into an inference model. Taking the trained model as an example ([model download link](https://paddleocr.bj.bcebos.com/contribution/can_train.tar)), you can use the following command to convert it:
```shell
# Note: set the pretrained_model path to a local path.
-python3 tools/export_model.py -c configs/rec/rec_d28_can.yml -o Global.save_inference_dir=./inference/rec_d28_can/ Architecture.Head.attdecoder.is_train=False
+python3 tools/export_model.py -c configs/rec/rec_d28_can.yml -o Global.pretrained_model=./rec_d28_can_train/CAN Global.save_inference_dir=./inference/rec_d28_can/ Architecture.Head.attdecoder.is_train=False
# The current static-graph model has a default maximum output length of 36. If you need to predict longer sequences, specify a suitable output length when exporting the model, e.g. Architecture.Head.max_text_length=72
```
**Note:**
- If you trained the model on your own dataset and adjusted the dictionary file, check that `character_dict_path` in the config file points to the dictionary file you need.
-- If you changed the input size at training time, modify the `infer_shape` corresponding to ABINet in `tools/export_model.py`.
After a successful conversion, there are three files in the directory:
```
@@ -112,18 +116,18 @@ python3 tools/export_model.py -c configs/rec/rec_d28_can.yml -o Global.save_infe
Run the following command for model inference:
```shell
-python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_hme/hme_01.jpg" --rec_algorithm="CAN" --rec_batch_num=1 --rec_model_dir="./inference/rec_d28_can/" --rec_char_dict_path="./ppocr/utils/dict/latex_symbol_dict.txt"
+python3 tools/infer/predict_rec.py --image_dir="./doc/datasets/crohme_demo/hme_00.jpg" --rec_algorithm="CAN" --rec_batch_num=1 --rec_model_dir="./inference/rec_d28_can/" --rec_char_dict_path="./ppocr/utils/dict/latex_symbol_dict.txt"
-# To predict all images in a folder, set image_dir to a folder path, e.g. --image_dir='./doc/imgs_hme/'.
+# To predict all images in a folder, set image_dir to a folder path, e.g. --image_dir='./doc/datasets/crohme_demo/'.
# If you need to predict on images with black text on a white background, set --rec_image_inverse=False
```
-![Sample test image](../imgs_hme/hme_00.jpg)
+![Sample test image](../datasets/crohme_demo/hme_00.jpg)
After running the command, the prediction result (the recognized text) for the image above is printed to the screen, for example:
```shell
-Predicts of ./doc/imgs_hme/hme_03.jpg:['x _ { k } x x _ { k } + y _ { k } y x _ { k }', []]
+Predicts of ./doc/datasets/crohme_demo/hme_00.jpg:['x _ { k } x x _ { k } + y _ { k } y x _ { k }', []]
```
diff --git a/doc/doc_en/algorithm_rec_can_en.md b/doc/doc_en/algorithm_rec_can_en.md
index 4d7a64f994e75c0ca3b2f6238465e3ad418b89a5..da6c9c6096fa7170b108012165b7c69862671e1a 100644
--- a/doc/doc_en/algorithm_rec_can_en.md
+++ b/doc/doc_en/algorithm_rec_can_en.md
@@ -25,7 +25,7 @@ Using CROHME handwritten mathematical expression recognition datasets for traini
|Model|Backbone|config|exprate|Download link|
| --- | --- | --- | --- | --- |
-|CAN|DenseNet|[rec_d28_can.yml](../../configs/rec/rec_d28_can.yml)|51.72|coming soon|
+|CAN|DenseNet|[rec_d28_can.yml](../../configs/rec/rec_d28_can.yml)|51.72|[trained model](https://paddleocr.bj.bcebos.com/contribution/can_train.tar)|
## 2. Environment
@@ -53,14 +53,14 @@ Evaluation:
```
# GPU evaluation
-python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_d28_can.yml -o Global.pretrained_model=./rec_d28_can_train/best_accuracy
+python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_d28_can.yml -o Global.pretrained_model=./rec_d28_can_train/CAN
```
Prediction:
```
# The configuration file used for prediction must match the training
-python3 tools/infer_rec.py -c configs/rec/rec_d28_can.yml -o Architecture.Head.attdecoder.is_train=False Global.infer_img='./doc/imgs_hme/hme_01.jpg' Global.pretrained_model=./rec_d28_can_train/best_accuracy
+python3 tools/infer_rec.py -c configs/rec/rec_d28_can.yml -o Architecture.Head.attdecoder.is_train=False Global.infer_img='./doc/datasets/crohme_demo/hme_00.jpg' Global.pretrained_model=./rec_d28_can_train/CAN
```
@@ -68,16 +68,20 @@ python3 tools/infer_rec.py -c configs/rec/rec_d28_can.yml -o Architecture.Head.a
### 4.1 Python Inference
-First, the model saved during the RobustScanner text recognition training process is converted into an inference model. you can use the following command to convert:
+First, the model saved during the CAN handwritten mathematical expression recognition training process is converted into an inference model. You can use the following command to convert it:
```
python3 tools/export_model.py -c configs/rec/rec_d28_can.yml -o Global.save_inference_dir=./inference/rec_d28_can/ Architecture.Head.attdecoder.is_train=False
+
+# The default maximum output length of the model is 36. If you need to predict a longer sequence, specify an appropriate output length when exporting the model, e.g. Architecture.Head.max_text_length=72
```
-For RobustScanner text recognition model inference, the following commands can be executed:
+For CAN handwritten mathematical expression recognition model inference, the following commands can be executed:
```
-python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_hme/hme_01.jpg" --rec_algorithm="CAN" --rec_batch_num=1 --rec_model_dir="./inference/rec_d28_can/" --rec_image_shape="1, 100, 100" --rec_char_dict_path="./ppocr/utils/dict/latex_symbol_dict.txt"
+python3 tools/infer/predict_rec.py --image_dir="./doc/datasets/crohme_demo/hme_00.jpg" --rec_algorithm="CAN" --rec_batch_num=1 --rec_model_dir="./inference/rec_d28_can/" --rec_char_dict_path="./ppocr/utils/dict/latex_symbol_dict.txt"
+
+# If you need to predict on a picture with black characters on a white background, please set --rec_image_inverse=False
```
diff --git a/ppocr/data/__init__.py b/ppocr/data/__init__.py
index 1f3de63de72f44a7daf8d641b2d5c6bc5568df8f..b602a346dbe4b0d45af287f25f05ead0f62daf44 100644
--- a/ppocr/data/__init__.py
+++ b/ppocr/data/__init__.py
@@ -37,7 +37,6 @@ from ppocr.data.simple_dataset import SimpleDataSet
from ppocr.data.lmdb_dataset import LMDBDataSet, LMDBDataSetSR
from ppocr.data.pgnet_dataset import PGDataSet
from ppocr.data.pubtab_dataset import PubTabDataSet
-from ppocr.data.hmer_dataset import HMERDataSet
__all__ = ['build_dataloader', 'transform', 'create_operators']
@@ -56,7 +55,7 @@ def build_dataloader(config, mode, device, logger, seed=None):
support_dict = [
'SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet',
- 'LMDBDataSetSR', 'HMERDataSet'
+ 'LMDBDataSetSR'
]
module_name = config[mode]['dataset']['name']
assert module_name in support_dict, Exception(
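
Removing `HMERDataSet` from `support_dict` means any config still naming it now fails the assert above; datasets are looked up by name. A stand-in sketch of that gate, with illustrative classes (not PaddleOCR's actual resolution code):

```python
# Illustrative only: a name-gated dataset lookup mirroring the
# build_dataloader validation pattern shown in the hunk above.
class SimpleDataSet:
    pass

class PGDataSet:
    pass

support_dict = {'SimpleDataSet': SimpleDataSet, 'PGDataSet': PGDataSet}

def build_dataset(name):
    # Unknown dataset names fail fast, as in build_dataloader's assert.
    assert name in support_dict, '{} is not in {}'.format(name, list(support_dict))
    return support_dict[name]()

ds = build_dataset('PGDataSet')      # fine after this PR
# build_dataset('HMERDataSet')       # would now raise AssertionError
```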
diff --git a/ppocr/data/collate_fn.py b/ppocr/data/collate_fn.py
index fec1e895ff033cd991ed3a927ad7e158989bab45..067b2158aca183c68c3a09999483c059bb10eb14 100644
--- a/ppocr/data/collate_fn.py
+++ b/ppocr/data/collate_fn.py
@@ -95,8 +95,8 @@ class DyMaskCollator(object):
1] > max_height else max_height
max_width = item[0].shape[2] if item[0].shape[
2] > max_width else max_width
- max_length = item[1].shape[0] if item[1].shape[
- 0] > max_length else max_length
+ max_length = len(item[1]) if len(item[
+ 1]) > max_length else max_length
proper_items.append(item)
images, image_masks = np.zeros(
@@ -111,7 +111,7 @@ class DyMaskCollator(object):
_, h, w = proper_items[i][0].shape
images[i][:, :h, :w] = proper_items[i][0]
image_masks[i][:, :h, :w] = 1
- l = proper_items[i][1].shape[0]
+ l = len(proper_items[i][1])
labels[i][:l] = proper_items[i][1]
label_masks[i][:l] = 1
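
The two replacements above are needed because labels now reach `DyMaskCollator` as plain Python lists (the output of `SeqLabelEncode` below) rather than numpy arrays, and lists have no `.shape`; `len()` covers both. A minimal sketch of the padding step, with made-up toy labels:

```python
# Sketch of the label padding DyMaskCollator performs; len() works for
# both lists and numpy arrays. The label values here are hypothetical.
import numpy as np

batch_labels = [[3, 7, 1], [5, 2, 9, 4, 1]]        # variable-length sequences
max_length = max(len(label) for label in batch_labels)

labels = np.zeros((len(batch_labels), max_length), dtype='int64')
label_masks = np.zeros((len(batch_labels), max_length), dtype='int64')
for i, label in enumerate(batch_labels):
    l = len(label)
    labels[i][:l] = label        # real tokens, zero padding after
    label_masks[i][:l] = 1       # 1 marks valid positions

print(labels)
print(label_masks)
```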
diff --git a/ppocr/data/hmer_dataset.py b/ppocr/data/hmer_dataset.py
deleted file mode 100644
index d5d92f264b2604d68929bf08e8514c1a87b9198b..0000000000000000000000000000000000000000
--- a/ppocr/data/hmer_dataset.py
+++ /dev/null
@@ -1,99 +0,0 @@
-# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import os, json, random, traceback
-import numpy as np
-
-from PIL import Image
-from paddle.io import Dataset
-
-from .imaug import transform, create_operators
-
-
-class HMERDataSet(Dataset):
- def __init__(self, config, mode, logger, seed=None):
- super(HMERDataSet, self).__init__()
-
- self.logger = logger
- self.seed = seed
- self.mode = mode
-
- global_config = config['Global']
- dataset_config = config[mode]['dataset']
- self.data_dir = config[mode]['dataset']['data_dir']
-
- label_file_list = dataset_config['label_file_list']
- data_source_num = len(label_file_list)
- ratio_list = dataset_config.get("ratio_list", [1.0])
-
- self.data_lines, self.labels = self.get_image_info_list(label_file_list,
- ratio_list)
- self.data_idx_order_list = list(range(len(self.data_lines)))
- if self.mode == "train" and self.do_shuffle:
- self.shuffle_data_random()
-
- if isinstance(ratio_list, (float, int)):
- ratio_list = [float(ratio_list)] * int(data_source_num)
-
- assert len(
- ratio_list
- ) == data_source_num, "The length of ratio_list should be the same as the file_list."
-
- self.ops = create_operators(dataset_config['transforms'], global_config)
- self.need_reset = True in [x < 1 for x in ratio_list]
-
- def get_image_info_list(self, file_list, ratio_list):
- if isinstance(file_list, str):
- file_list = [file_list]
- labels = {}
- for idx, file in enumerate(file_list):
- with open(file, "r") as f:
- lines = json.load(f)
- labels.update(lines)
- data_lines = [name for name in labels.keys()]
- return data_lines, labels
-
- def shuffle_data_random(self):
- random.seed(self.seed)
- random.shuffle(self.data_lines)
- return
-
- def __len__(self):
- return len(self.data_idx_order_list)
-
- def __getitem__(self, idx):
- file_idx = self.data_idx_order_list[idx]
- data_name = self.data_lines[file_idx]
- try:
- file_name = data_name + '.jpg'
- img_path = os.path.join(self.data_dir, file_name)
- if not os.path.exists(img_path):
- raise Exception("{} does not exist!".format(img_path))
- with open(img_path, 'rb') as f:
- img = f.read()
-
- label = self.labels.get(data_name).split()
- label = np.array([int(item) for item in label])
-
- data = {'image': img, 'label': label}
- outs = transform(data, self.ops)
- except:
- self.logger.error(
- "When parsing line {}, error happened with msg: {}".format(
- file_name, traceback.format_exc()))
- outs = None
- if outs is None:
- # during evaluation, we should fix the idx to get same results for many times of evaluation.
- rnd_idx = np.random.randint(self.__len__())
- return self.__getitem__(rnd_idx)
- return outs
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index 2a2ac2decd1abf4daf2c5325a8f69fc26f4fc0ef..ae916b2ee7003d9b96933f042c4c756f9105be40 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -1476,4 +1476,33 @@ class CTLabelEncode(object):
data['polys'] = boxes
data['texts'] = txts
- return data
\ No newline at end of file
+ return data
+
+
+class SeqLabelEncode(BaseRecLabelEncode):
+ def __init__(self,
+ character_dict_path,
+ max_text_length=100,
+ use_space_char=False,
+ lower=True,
+ **kwargs):
+ super(SeqLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char, lower)
+
+ def encode(self, text_seq):
+ text_seq_encoded = []
+ for text in text_seq:
+ if text not in self.character:
+ continue
+ text_seq_encoded.append(self.dict.get(text))
+ if len(text_seq_encoded) == 0:
+ return None
+ return text_seq_encoded
+
+ def __call__(self, data):
+ label = data['label']
+ if isinstance(label, str):
+ label = label.strip().split()
+ label.append(self.end_str)
+ data['label'] = self.encode(label)
+ return data
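
For orientation, a hand-rolled mini version of what `SeqLabelEncode.__call__` does to one sample. It assumes the base class provides `self.dict` (symbol to index) and `self.end_str == 'eos'`; the four-symbol dictionary below is made up:

```python
# Illustrative re-implementation of SeqLabelEncode's label path; the
# dictionary is hypothetical, and the real class loads it from
# latex_symbol_dict.txt via BaseRecLabelEncode.
symbol_dict = {'x': 0, '+': 1, 'y': 2, 'eos': 3}
end_str = 'eos'   # assumption: inherited from the base encoder class

def encode_label(label):
    tokens = label.strip().split()   # labels.txt stores space-separated symbols
    tokens.append(end_str)           # terminate the sequence, as __call__ does
    encoded = [symbol_dict[t] for t in tokens if t in symbol_dict]
    return encoded or None           # None drops the sample, like encode()

print(encode_label('x + y'))         # -> [0, 1, 2, 3]
```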
diff --git a/test_tipc/configs/rec_d28_can/rec_d28_can.yml b/test_tipc/configs/rec_d28_can/rec_d28_can.yml
index aeaccb6b0d245cc4be479d17597d0029647db574..ac7b07712621d40aa33e2f0dd53c907b21d2e5f9 100644
--- a/test_tipc/configs/rec_d28_can/rec_d28_can.yml
+++ b/test_tipc/configs/rec_d28_can/rec_d28_can.yml
@@ -5,14 +5,14 @@ Global:
print_batch_step: 10
save_model_dir: ./output/rec/can/
save_epoch_step: 1
- # evaluation is run every 1105 iterations
+ # evaluation is run every 1105 iterations (1 epoch, batch_size = 8)
eval_batch_step: [0, 1105]
cal_metric_during_train: True
- pretrained_model: ./output/rec/can/CAN
- checkpoints: ./output/rec/can/CAN
- save_inference_dir: ./inference/rec_d28_can/
+ pretrained_model:
+ checkpoints:
+ save_inference_dir:
use_visualdl: False
- infer_img: doc/imgs_hme/hme_01.jpeg
+ infer_img: doc/datasets/crohme_demo/hme_00.jpg
# for data or label process
character_dict_path: ppocr/utils/dict/latex_symbol_dict.txt
max_text_length: 36
@@ -75,37 +75,43 @@ Metric:
Train:
dataset:
- name: HMERDataSet
- data_dir: ./train_data/CROHME/training/images/
+ name: PGDataSet
+ data_dir: ./train_data/CROHME_lite/training/images/
transforms:
- DecodeImage:
channel_first: False
- GrayImageChannelFormat:
normalize: True
inverse: True
+ - SeqLabelEncode:
+ character_dict_path: ppocr/utils/dict/latex_symbol_dict.txt
+ lower: False
- KeepKeys:
keep_keys: ['image', 'label']
- label_file_list: ["./train_data/CROHME/training/labels.json"]
+ label_file_list: ["./train_data/CROHME_lite/training/labels.txt"]
loader:
shuffle: True
- batch_size_per_card: 2
- drop_last: True
- num_workers: 1
+ batch_size_per_card: 8
+ drop_last: False
+ num_workers: 4
collate_fn: DyMaskCollator
Eval:
dataset:
- name: HMERDataSet
- data_dir: ./train_data/CROHME/evaluation/images/
+ name: PGDataSet
+ data_dir: ./train_data/CROHME_lite/evaluation/images/
transforms:
- DecodeImage:
channel_first: False
- GrayImageChannelFormat:
normalize: True
inverse: True
+ - SeqLabelEncode:
+ character_dict_path: ppocr/utils/dict/latex_symbol_dict.txt
+ lower: False
- KeepKeys:
keep_keys: ['image', 'label']
- label_file_list: ["./train_data/CROHME/evaluation/labels.json"]
+ label_file_list: ["./train_data/CROHME_lite/evaluation/labels.txt"]
loader:
shuffle: False
drop_last: False
diff --git a/test_tipc/configs/rec_d28_can/train_infer_python.txt b/test_tipc/configs/rec_d28_can/train_infer_python.txt
index be50c59805b2f1b8a0baa5a61b9d5cd4a21b68df..731d327cd085b41a6bade9b7092dda7b2de9d9f9 100644
--- a/test_tipc/configs/rec_d28_can/train_infer_python.txt
+++ b/test_tipc/configs/rec_d28_can/train_infer_python.txt
@@ -1,6 +1,6 @@
===========================train_params===========================
model_name:rec_d28_can
-python:python
+python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:null
@@ -9,7 +9,7 @@ Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=8
Global.pretrained_model:null
train_model_name:latest
-train_infer_img_dir:./doc/imgs_hme
+train_infer_img_dir:./doc/datasets/crohme_demo
null:null
##
trainer:norm_train
@@ -37,15 +37,15 @@ export2:null
train_model:./inference/rec_d28_can_train/best_accuracy
infer_export:tools/export_model.py -c test_tipc/configs/rec_d28_can/rec_d28_can.yml -o
infer_quant:False
-inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/dict/latex_symbol_dict.txt --rec_image_shape="1,100,100" --rec_algorithm="CAN"
+inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/dict/latex_symbol_dict.txt --rec_algorithm="CAN"
--use_gpu:True|False
--enable_mkldnn:False
--cpu_threads:6
--rec_batch_num:1
--use_tensorrt:False
--precision:fp32
---rec_model_dir:./output/
---image_dir:./doc/imgs_hme
+--rec_model_dir:
+--image_dir:./doc/datasets/crohme_demo
--save_log_path:./test/output/
--benchmark:True
null:null
diff --git a/test_tipc/prepare.sh b/test_tipc/prepare.sh
index 5ca426e28cb6c2d499e5bfeca7e997e6175b0ce5..4aab17019f39edd15288b51564177b4b4f4b36c4 100644
--- a/test_tipc/prepare.sh
+++ b/test_tipc/prepare.sh
@@ -257,6 +257,13 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/rec_r32_gaspin_bilstm_att_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf rec_r32_gaspin_bilstm_att_train.tar && cd ../
fi
+ if [ ${model_name} == "rec_d28_can" ]; then
+ wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/contribution/can_train.tar --no-check-certificate
+ cd ./pretrain_models/ && tar xf can_train.tar && cd ../
+ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/CROHME_lite.tar --no-check-certificate
+ cd ./train_data/ && tar xf CROHME_lite.tar && cd ../
+
+ fi
if [ ${model_name} == "layoutxlm_ser" ]; then
${python_name} -m pip install -r ppstructure/kie/requirements.txt
${python_name} -m pip install opencv-python -U
diff --git a/test_tipc/readme.md b/test_tipc/readme.md
index 1442ee1c86a7c1319446a0eb22c08287e1ce689a..9f02c2e3084585618cb1424b6858d16b79494d9b 100644
--- a/test_tipc/readme.md
+++ b/test_tipc/readme.md
@@ -44,6 +44,7 @@
| SAST |det_r50_vd_sast_totaltext_v2.0 | detection | supported | multi-machine multi-GPU <br> mixed precision | - | - |
| Rosetta|rec_mv3_none_none_ctc_v2.0 | recognition | supported | multi-machine multi-GPU <br> mixed precision | - | - |
| Rosetta|rec_r34_vd_none_none_ctc_v2.0 | recognition | supported | multi-machine multi-GPU <br> mixed precision | - | - |
+| CAN |rec_d28_can | recognition | supported | multi-machine multi-GPU <br> mixed precision | - | - |
| CRNN |rec_mv3_none_bilstm_ctc_v2.0 | recognition | supported | multi-machine multi-GPU <br> mixed precision | - | - |
| CRNN |rec_r34_vd_none_bilstm_ctc_v2.0| recognition | supported | multi-machine multi-GPU <br> mixed precision | - | - |
| StarNet|rec_mv3_tps_bilstm_ctc_v2.0 | recognition | supported | multi-machine multi-GPU <br> mixed precision | - | - |
diff --git a/tools/program.py b/tools/program.py
index c491247a64697774cca73e209b714f7560c5fcdd..a0594e950d969c39eb1cb363435897c5f219f0e4 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -544,7 +544,7 @@ def eval(model,
elif model_type in ['sr']:
eval_class(preds, batch_numpy)
elif model_type in ['can']:
- eval_class(preds[0], batch_numpy[2:], epoch_reset=False)
+ eval_class(preds[0], batch_numpy[2:], epoch_reset=(idx == 0))
else:
post_result = post_process_class(preds, batch_numpy[1])
eval_class(post_result, batch_numpy)
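
The one-line fix above passes `epoch_reset=True` only for the first evaluation batch, so the CAN metric's accumulators are cleared once per eval pass and then grow across batches, instead of never being cleared between passes. A toy sketch of the accumulator pattern this relies on (names are illustrative, not the repo's CANMetric):

```python
# Toy metric showing why epoch_reset=(idx == 0) is the right signal:
# reset state on the first batch of an eval pass, accumulate afterwards.
class ToyExpRateMetric:
    def __init__(self):
        self.correct = 0
        self.total = 0

    def __call__(self, preds, labels, epoch_reset=False):
        if epoch_reset:              # True only when idx == 0
            self.correct, self.total = 0, 0
        self.correct += sum(int(p == l) for p, l in zip(preds, labels))
        self.total += len(labels)

    def get_metric(self):
        return {'exp_rate': self.correct / max(self.total, 1)}

metric = ToyExpRateMetric()
batches = [([1, 2], [1, 0]), ([3], [3])]
for idx, (preds, labels) in enumerate(batches):
    metric(preds, labels, epoch_reset=(idx == 0))
print(metric.get_metric())           # {'exp_rate': 0.666...}
```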