diff --git a/configs/cls/cls_mv3.yml b/configs/cls/cls_mv3.yml index 124eb48263a90bc06926649ddd77f1ba494a9677..57afab507c03c2a32f1665f908170de05d91143a 100755 --- a/configs/cls/cls_mv3.yml +++ b/configs/cls/cls_mv3.yml @@ -1,21 +1,22 @@ Global: algorithm: CLS - use_gpu: false - epoch_num: 30 + use_gpu: False + epoch_num: 100 log_smooth_window: 20 - print_batch_step: 10 - save_model_dir: output/cls_mb3 + print_batch_step: 100 + save_model_dir: output/cls_mv3 save_epoch_step: 3 - eval_batch_step: 100 - train_batch_size_per_card: 256 - test_batch_size_per_card: 256 - image_shape: [3, 32, 100] - label_list: [0,180] + eval_batch_step: 500 + train_batch_size_per_card: 512 + test_batch_size_per_card: 512 + image_shape: [3, 48, 192] + label_list: ['0','180'] + distort: True reader_yml: ./configs/cls/cls_reader.yml pretrain_weights: - checkpoints: /Users/zhoujun20/Desktop/code/class_model/cls_mb3_ultra_small_0.35/best_accuracy + checkpoints: save_inference_dir: - infer_img: /Users/zhoujun20/Desktop/code/PaddleOCR/doc/imgs_words/ch/word_1.jpg + infer_img: Architecture: function: ppocr.modeling.architectures.cls_model,ClsModel @@ -23,7 +24,7 @@ Architecture: Backbone: function: ppocr.modeling.backbones.rec_mobilenet_v3,MobileNetV3 scale: 0.35 - model_name: Ultra_small + model_name: small Head: function: ppocr.modeling.heads.cls_head,ClsHead @@ -38,6 +39,6 @@ Optimizer: beta1: 0.9 beta2: 0.999 decay: - function: piecewise_decay - boundaries: [20,30] - decay_rate: 0.1 + function: cosine_decay + step_each_epoch: 1169 + total_epoch: 100 \ No newline at end of file diff --git a/configs/cls/cls_reader.yml b/configs/cls/cls_reader.yml index 3002fcbdcc6c75c442f8b3e30abc0074a14fd74e..2b1d4c4e75217998f2c489bcd3bfbbb8b8b7f415 100755 --- a/configs/cls/cls_reader.yml +++ b/configs/cls/cls_reader.yml @@ -1,13 +1,13 @@ TrainReader: reader_function: ppocr.data.cls.dataset_traversal,SimpleReader - num_workers: 1 - img_set_dir: / - label_file_path: /Users/zhoujun20/Downloads/direction/rotate_ver/train.txt + num_workers: 8 + img_set_dir: ./train_data/cls + label_file_path: ./train_data/cls/train.txt EvalReader: reader_function: ppocr.data.cls.dataset_traversal,SimpleReader - img_set_dir: / - label_file_path: /Users/zhoujun20/Downloads/direction/rotate_ver/train.txt + img_set_dir: ./train_data/cls + label_file_path: ./train_data/cls/test.txt TestReader: reader_function: ppocr.data.cls.dataset_traversal,SimpleReader diff --git a/deploy/cpp_infer/include/config.h b/deploy/cpp_infer/include/config.h index a5f19c32839a3b3995e690c14ce5bb4c79db161b..27539ea7934dc192e86bca3ea6bfd7999ee229a3 100644 --- a/deploy/cpp_infer/include/config.h +++ b/deploy/cpp_infer/include/config.h @@ -57,6 +57,8 @@ public: this->char_list_file.assign(config_map_["char_list_file"]); + this->use_angle_cls = bool(stoi(config_map_["use_angle_cls"])); + this->cls_model_dir.assign(config_map_["cls_model_dir"]); this->cls_thresh = stod(config_map_["cls_thresh"]); @@ -88,6 +90,8 @@ public: std::string rec_model_dir; + bool use_angle_cls; + std::string char_list_file; std::string cls_model_dir; diff --git a/deploy/cpp_infer/include/ocr_rec.h b/deploy/cpp_infer/include/ocr_rec.h index 68237170beabb1ecd386821d97c2eefb16435345..a8b99a5960ac3e6238dfea2285ec51c9e80e1749 100644 --- a/deploy/cpp_infer/include/ocr_rec.h +++ b/deploy/cpp_infer/include/ocr_rec.h @@ -58,7 +58,7 @@ public: void LoadModel(const std::string &model_dir); void Run(std::vector>> boxes, cv::Mat &img, - Classifier &cls); + Classifier *cls); private: std::shared_ptr predictor_; diff --git a/deploy/cpp_infer/src/main.cpp b/deploy/cpp_infer/src/main.cpp index 989424d0b58bbf6c307dc07d3e461e93ce0ecc10..e708a6e341e6dd5ba66abe46456e2d74a89e0cb5 100644 --- a/deploy/cpp_infer/src/main.cpp +++ b/deploy/cpp_infer/src/main.cpp @@ -53,10 +53,15 @@ int main(int argc, char **argv) { config.cpu_math_library_num_threads, config.use_mkldnn, config.use_zero_copy_run, config.max_side_len, config.det_db_thresh, config.det_db_box_thresh, config.det_db_unclip_ratio, config.visualize); - Classifier cls(config.cls_model_dir, config.use_gpu, config.gpu_id, - config.gpu_mem, config.cpu_math_library_num_threads, - config.use_mkldnn, config.use_zero_copy_run, - config.cls_thresh); + + Classifier *cls = nullptr; + if (config.use_angle_cls == true) { + cls = new Classifier(config.cls_model_dir, config.use_gpu, config.gpu_id, + config.gpu_mem, config.cpu_math_library_num_threads, + config.use_mkldnn, config.use_zero_copy_run, + config.cls_thresh); + } + CRNNRecognizer rec(config.rec_model_dir, config.use_gpu, config.gpu_id, config.gpu_mem, config.cpu_math_library_num_threads, config.use_mkldnn, config.use_zero_copy_run, diff --git a/deploy/cpp_infer/src/ocr_rec.cpp b/deploy/cpp_infer/src/ocr_rec.cpp index 0e06b8b37ae3e5937a80fc138945296c29acdfe5..e37994b562cc4bf593332432a990afe4c6697531 100644 --- a/deploy/cpp_infer/src/ocr_rec.cpp +++ b/deploy/cpp_infer/src/ocr_rec.cpp @@ -17,7 +17,7 @@ namespace PaddleOCR { void CRNNRecognizer::Run(std::vector>> boxes, - cv::Mat &img, Classifier &cls) { + cv::Mat &img, Classifier *cls) { cv::Mat srcimg; img.copyTo(srcimg); cv::Mat crop_img; @@ -27,8 +27,9 @@ void CRNNRecognizer::Run(std::vector>> boxes, int index = 0; for (int i = boxes.size() - 1; i >= 0; i--) { crop_img = GetRotateCropImage(srcimg, boxes[i]); - - crop_img = cls.Run(crop_img); + if (cls != nullptr) { + crop_img = cls->Run(crop_img); + } float wh_ratio = float(crop_img.cols) / float(crop_img.rows); diff --git a/deploy/cpp_infer/tools/config.txt b/deploy/cpp_infer/tools/config.txt index c59e5d55daa3b289210c67ac1d6ae08470218b9e..18360086787f5ee8fc6d32f9d006b5fc3b7b47b9 100644 --- a/deploy/cpp_infer/tools/config.txt +++ b/deploy/cpp_infer/tools/config.txt @@ -4,23 +4,23 @@ gpu_id 0 gpu_mem 4000 cpu_math_library_num_threads 10 use_mkldnn 0 -use_zero_copy_run 1 +use_zero_copy_run 0 # det config max_side_len 960 det_db_thresh 0.3 det_db_box_thresh 0.5 det_db_unclip_ratio 2.0 -det_model_dir ./inference/det_db +det_model_dir ../model/det # cls config -cls_model_dir ./inference/cls +use_angle_cls 1 +cls_model_dir ../model/cls cls_thresh 0.9 # rec config -rec_model_dir ./inference/rec_crnn -char_list_file ../../ppocr/utils/ppocr_keys_v1.txt +rec_model_dir ../model/rec +char_list_file ../model/ppocr_keys_v1.txt # show the detection results -visualize 1 - +visualize 1 \ No newline at end of file diff --git a/doc/doc_ch/angle_class.md b/doc/doc_ch/angle_class.md new file mode 100644 index 0000000000000000000000000000000000000000..e884d5ef48801fb595a422ced95ab5e3b15b627c --- /dev/null +++ b/doc/doc_ch/angle_class.md @@ -0,0 +1,127 @@ +## 文字角度分类 + +### 数据准备 + +请按如下步骤设置数据集: + +训练数据的默认存储路径是 `PaddleOCR/train_data/cls`,如果您的磁盘上已有数据集,只需创建软链接至数据集目录: + +``` +ln -sf /train_data/cls/dataset +``` + +请参考下文组织您的数据。 +- 训练集 + +首先请将训练图片放入同一个文件夹(train_images),并用一个txt文件(cls_gt_train.txt)记录图片路径和标签。 + +**注意:** 默认请将图片路径和图片标签用 `\t` 分割,如用其他方式分割将造成训练报错 + +0和180分别表示图片的角度为0度和180度 + +``` +" 图像文件名 图像标注信息 " + +train_data/cls/word_001.jpg 0 +train_data/cls/word_002.jpg 180 +``` + +最终训练集应有如下文件结构: +``` +|-train_data + |-cls + |- cls_gt_train.txt + |- train + |- word_001.png + |- word_002.jpg + |- word_003.jpg + | ... +``` + +- 测试集 + +同训练集类似,测试集也需要提供一个包含所有图片的文件夹(test)和一个cls_gt_test.txt,测试集的结构如下所示: + +``` +|-train_data + |-cls + |- 和一个cls_gt_test.txt + |- test + |- word_001.jpg + |- word_002.jpg + |- word_003.jpg + | ... +``` + +### 启动训练 + +PaddleOCR提供了训练脚本、评估脚本和预测脚本。 + +开始训练: + +*如果您安装的是cpu版本,请将配置文件中的 `use_gpu` 字段修改为false* + +``` +# 设置PYTHONPATH路径 +export PYTHONPATH=$PYTHONPATH:. +# GPU训练 支持单卡,多卡训练,通过CUDA_VISIBLE_DEVICES指定卡号 +export CUDA_VISIBLE_DEVICES=0,1,2,3 +# 启动训练 +python3 tools/train.py -c configs/cls/cls_mv3.yml +``` + +- 数据增强 + +PaddleOCR提供了多种数据增强方式,如果您希望在训练时加入扰动,请在配置文件中设置 `distort: true`。 + +默认的扰动方式有:颜色空间转换(cvtColor)、模糊(blur)、抖动(jitter)、噪声(Gasuss noise)、随机切割(random crop)、透视(perspective)、颜色反转(reverse),随机数据增强(RandAugment)。 + +训练过程中除随机数据增强外每种扰动方式以50%的概率被选择,具体代码实现请参考: +[randaugment.py.py](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/ppocr/data/cls/randaugment.py) +[img_tools.py](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/ppocr/data/rec/img_tools.py) + +*由于OpenCV的兼容性问题,扰动操作暂时只支持linux* + +### 训练 + +PaddleOCR支持训练和评估交替进行, 可以在 `configs/cls/cls_mv3.yml` 中修改 `eval_batch_step` 设置评估频率,默认每500个iter评估一次。评估过程中默认将最佳acc模型,保存为 `output/cls_mv3/best_accuracy` 。 + +如果验证集很大,测试将会比较耗时,建议减少评估次数,或训练完再进行评估。 + +**注意,预测/评估时的配置文件请务必与训练一致。** + +### 评估 + +评估数据集可以通过`configs/cls/cls_reader.yml` 修改EvalReader中的 `label_file_path` 设置。 + +*注意* 评估时必须确保配置文件中 infer_img 字段为空 +``` +export CUDA_VISIBLE_DEVICES=0 +# GPU 评估, Global.checkpoints 为待测权重 +python3 tools/eval.py -c configs/cls/cls_mv3.yml -o Global.checkpoints={path/to/weights}/best_accuracy +``` + +### 预测 + +* 训练引擎的预测 + +使用 PaddleOCR 训练好的模型,可以通过以下脚本进行快速预测。 + +默认预测图片存储在 `infer_img` 里,通过 `-o Global.checkpoints` 指定权重: + +``` +# 预测分类结果 +python3 tools/infer_cls.py -c configs/cls/cls_mv3.yml -o Global.checkpoints={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png +``` + +预测图片: + +![](../imgs_words/en/word_1.png) + +得到输入图像的预测结果: + +``` +infer_img: doc/imgs_words/en/word_1.png + scores: [[0.93161047 0.06838956]] + label: [0] +``` diff --git a/doc/doc_en/angle_class_en.md b/doc/doc_en/angle_class_en.md new file mode 100644 index 0000000000000000000000000000000000000000..91af20a4cb34330277d8e770459c452614f9b6e0 --- /dev/null +++ b/doc/doc_en/angle_class_en.md @@ -0,0 +1,126 @@ +## TEXT ANGLE CLASSIFICATION + +### DATA PREPARATION + +Please organize the dataset as follows: + +The default storage path for training data is `PaddleOCR/train_data/cls`, if you already have a dataset on your disk, just create a soft link to the dataset directory: + +``` +ln -sf /train_data/cls/dataset +``` + +please refer to the following to organize your data. + +- Training set + +First put the training images in the same folder (train_images), and use a txt file (cls_gt_train.txt) to store the image path and label. + +* Note: by default, the image path and image label are split with `\t`, if you use other methods to split, it will cause training error + +0 and 180 indicate that the angle of the image is 0 degrees and 180 degrees, respectively. + +``` +" Image file name Image annotation " + +train_data/word_001.jpg 0 +train_data/word_002.jpg 180 +``` + +The final training set should have the following file structure: + +``` +|-train_data + |-cls + |- cls_gt_train.txt + |- train + |- word_001.png + |- word_002.jpg + |- word_003.jpg + | ... +``` + +- Test set + +Similar to the training set, the test set also needs to be provided a folder +containing all images (test) and a cls_gt_test.txt. The structure of the test set is as follows: + +``` +|-train_data + |-cls + |- cls_gt_test.txt + |- test + |- word_001.jpg + |- word_002.jpg + |- word_003.jpg + | ... +``` + +### TRAINING + +PaddleOCR provides training scripts, evaluation scripts, and prediction scripts. + +Start training: + +``` +# Set PYTHONPATH path +export PYTHONPATH=$PYTHONPATH:. +# GPU training Support single card and multi-card training, specify the card number through CUDA_VISIBLE_DEVICES +export CUDA_VISIBLE_DEVICES=0,1,2,3 +# Training icdar15 English data +python3 tools/train.py -c configs/cls/cls_mv3.yml +``` + +- Data Augmentation + +PaddleOCR provides a variety of data augmentation methods. If you want to add disturbance during training, please set `distort: true` in the configuration file. + +The default perturbation methods are: cvtColor, blur, jitter, Gasuss noise, random crop, perspective, color reverse, RandAugment. + +Except for RandAugment, each disturbance method is selected with a 50% probability during the training process. For specific code implementation, please refer to: +[randaugment.py.py](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/ppocr/data/cls/randaugment.py) +[img_tools.py](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/ppocr/data/rec/img_tools.py) + + +- Training + +PaddleOCR supports alternating training and evaluation. You can modify `eval_batch_step` in `configs/cls/cls_mv3.yml` to set the evaluation frequency. By default, it is evaluated every 500 iter and the best acc model is saved under `output/cls_mv3/best_accuracy` during the evaluation process. + +If the evaluation set is large, the test will be time-consuming. It is recommended to reduce the number of evaluations, or evaluate after training. + +**Note that the configuration file for prediction/evaluation must be consistent with the training.** + +### EVALUATION + +The evaluation data set can be modified via `configs/cls/cls_reader.yml` setting of `label_file_path` in EvalReader. + +``` +export CUDA_VISIBLE_DEVICES=0 +# GPU evaluation, Global.checkpoints is the weight to be tested +python3 tools/eval.py -c configs/cls/cls_mv3.yml -o Global.checkpoints={path/to/weights}/best_accuracy +``` + +### PREDICTION + +* Training engine prediction + +Using the model trained by paddleocr, you can quickly get prediction through the following script. + +The default prediction picture is stored in `infer_img`, and the weight is specified via `-o Global.checkpoints`: + +``` +# Predict English results +python3 tools/infer_rec.py -c configs/cls/cls_mv3.yml -o Global.checkpoints={path/to/weights}/best_accuracy TestReader.infer_img=doc/imgs_words/en/word_1.jpg +``` + +Input image: + +![](../imgs_words/en/word_1.png) + +Get the prediction result of the input image: + +``` +infer_img: doc/imgs_words/en/word_1.png + scores: [[0.93161047 0.06838956]] + label: [0] +``` diff --git a/ppocr/data/cls/dataset_traversal.py b/ppocr/data/cls/dataset_traversal.py index fa688f46c62046cc622adcdde1be81522840a47e..c465bf9d3bc22aa794dcea47020b8b851f2dfbf8 100755 --- a/ppocr/data/cls/dataset_traversal.py +++ b/ppocr/data/cls/dataset_traversal.py @@ -14,6 +14,7 @@ import os import sys +import math import random import numpy as np import cv2 @@ -23,7 +24,18 @@ from ppocr.utils.utility import get_image_file_list logger = initial_logger() -from ppocr.data.rec.img_tools import warp, resize_norm_img +from ppocr.data.rec.img_tools import resize_norm_img, warp +from ppocr.data.cls.randaugment import RandAugment + + +def random_crop(img): + img_h, img_w = img.shape[:2] + if img_w > img_h * 4: + w = random.randint(img_h * 2, img_w) + i = random.randint(0, img_w - w) + + img = img[:, i:i + w, :] + return img class SimpleReader(object): @@ -39,7 +51,8 @@ class SimpleReader(object): self.image_shape = params['image_shape'] self.mode = params['mode'] self.infer_img = params['infer_img'] - self.use_distort = False + self.use_distort = params['mode'] == 'train' and params['distort'] + self.randaug = RandAugment() self.label_list = params['label_list'] if "distort" in params: self.use_distort = params['distort'] and params['use_gpu'] @@ -76,6 +89,7 @@ class SimpleReader(object): if img.shape[-1] == 1 or len(list(img.shape)) == 2: img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) norm_img = resize_norm_img(img, self.image_shape) + norm_img = norm_img[np.newaxis, :] yield norm_img else: @@ -97,6 +111,8 @@ class SimpleReader(object): for img_id in range(process_id, img_num, self.num_workers): label_infor = label_infor_list[img_id_list[img_id]] substr = label_infor.decode('utf-8').strip("\n").split("\t") + label = self.label_list.index(substr[1]) + img_path = self.img_set_dir + "/" + substr[0] img = cv2.imread(img_path) if img is None: @@ -105,12 +121,14 @@ class SimpleReader(object): if img.shape[-1] == 1 or len(list(img.shape)) == 2: img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - label = substr[1] if self.use_distort: + # if random.randint(1, 100)>= 50: + # img = random_crop(img) img = warp(img, 10) + img = self.randaug(img) norm_img = resize_norm_img(img, self.image_shape) norm_img = norm_img[np.newaxis, :] - yield (norm_img, self.label_list.index(int(label))) + yield (norm_img, label) def batch_iter_reader(): batch_outs = [] diff --git a/ppocr/data/cls/randaugment.py b/ppocr/data/cls/randaugment.py new file mode 100644 index 0000000000000000000000000000000000000000..21345c05be59f6d1c9ae5a8d396ffed2dd9b0ca1 --- /dev/null +++ b/ppocr/data/cls/randaugment.py @@ -0,0 +1,135 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from PIL import Image, ImageEnhance, ImageOps +import numpy as np +import random +import six + + +class RawRandAugment(object): + def __init__(self, num_layers=2, magnitude=5, fillcolor=(128, 128, 128)): + self.num_layers = num_layers + self.magnitude = magnitude + self.max_level = 10 + + abso_level = self.magnitude / self.max_level + self.level_map = { + "shearX": 0.3 * abso_level, + "shearY": 0.3 * abso_level, + "translateX": 150.0 / 331 * abso_level, + "translateY": 150.0 / 331 * abso_level, + "rotate": 30 * abso_level, + "color": 0.9 * abso_level, + "posterize": int(4.0 * abso_level), + "solarize": 256.0 * abso_level, + "contrast": 0.9 * abso_level, + "sharpness": 0.9 * abso_level, + "brightness": 0.9 * abso_level, + "autocontrast": 0, + "equalize": 0, + "invert": 0 + } + + # from https://stackoverflow.com/questions/5252170/ + # specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand + def rotate_with_fill(img, magnitude): + rot = img.convert("RGBA").rotate(magnitude) + return Image.composite(rot, + Image.new("RGBA", rot.size, (128, ) * 4), + rot).convert(img.mode) + + rnd_ch_op = random.choice + + self.func = { + "shearX": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, magnitude * rnd_ch_op([-1, 1]), 0, 0, 1, 0), + Image.BICUBIC, + fillcolor=fillcolor), + "shearY": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, 0, 0, magnitude * rnd_ch_op([-1, 1]), 1, 0), + Image.BICUBIC, + fillcolor=fillcolor), + "translateX": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, 0, magnitude * img.size[0] * rnd_ch_op([-1, 1]), 0, 1, 0), + fillcolor=fillcolor), + "translateY": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, 0, 0, 0, 1, magnitude * img.size[1] * rnd_ch_op([-1, 1])), + fillcolor=fillcolor), + "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude), + "color": lambda img, magnitude: ImageEnhance.Color(img).enhance( + 1 + magnitude * rnd_ch_op([-1, 1])), + "posterize": lambda img, magnitude: + ImageOps.posterize(img, magnitude), + "solarize": lambda img, magnitude: + ImageOps.solarize(img, magnitude), + "contrast": lambda img, magnitude: + ImageEnhance.Contrast(img).enhance( + 1 + magnitude * rnd_ch_op([-1, 1])), + "sharpness": lambda img, magnitude: + ImageEnhance.Sharpness(img).enhance( + 1 + magnitude * rnd_ch_op([-1, 1])), + "brightness": lambda img, magnitude: + ImageEnhance.Brightness(img).enhance( + 1 + magnitude * rnd_ch_op([-1, 1])), + "autocontrast": lambda img, magnitude: + ImageOps.autocontrast(img), + "equalize": lambda img, magnitude: ImageOps.equalize(img), + "invert": lambda img, magnitude: ImageOps.invert(img) + } + + def __call__(self, img): + avaiable_op_names = list(self.level_map.keys()) + for layer_num in range(self.num_layers): + op_name = np.random.choice(avaiable_op_names) + img = self.func[op_name](img, self.level_map[op_name]) + return img + + +class RandAugment(RawRandAugment): + """ RandAugment wrapper to auto fit different img types """ + + def __init__(self, *args, **kwargs): + if six.PY2: + super(RandAugment, self).__init__(*args, **kwargs) + else: + super().__init__(*args, **kwargs) + + def __call__(self, img): + if not isinstance(img, Image.Image): + img = np.ascontiguousarray(img) + img = Image.fromarray(img) + + if six.PY2: + img = super(RandAugment, self).__call__(img) + else: + img = super().__call__(img) + + if isinstance(img, Image.Image): + img = np.asarray(img) + + return img diff --git a/tools/eval_utils/eval_cls_utils.py b/tools/eval_utils/eval_cls_utils.py index 80a131112370d1f707a020824c94daea21340709..9c9b26677dc57b70f7641c26e5ef57ce1d77f1af 100644 --- a/tools/eval_utils/eval_cls_utils.py +++ b/tools/eval_utils/eval_cls_utils.py @@ -16,12 +16,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import logging import numpy as np -import paddle.fluid as fluid - -__all__ = ['eval_class_run'] +__all__ = ['eval_cls_run'] import logging @@ -52,7 +49,8 @@ def eval_cls_run(exe, eval_info_dict): fetch_list=eval_info_dict['fetch_varname_list'], \ return_numpy=False) softmax_outs = np.array(outs[1]) - + if len(softmax_outs.shape) != 1: + softmax_outs = np.array(outs[0]) acc, acc_num = cal_cls_acc(softmax_outs, label_list) total_acc_num += acc_num total_sample_num += len(label_list) diff --git a/tools/infer/predict_cls.py b/tools/infer/predict_cls.py index 54e2dbbba5481e803d29ff16b032fcb57f6446c5..5c54224e6326e83a5d5cde11df1e3047df140953 100755 --- a/tools/infer/predict_cls.py +++ b/tools/infer/predict_cls.py @@ -108,7 +108,7 @@ class TextClassifier(object): score = prob_out[rno][label_idx] label = self.label_list[label_idx] cls_res[indices[beg_img_no + rno]] = [label, score] - if label == 180: + if '180' in label and score > 0.9999: img_list[indices[beg_img_no + rno]] = cv2.rotate( img_list[indices[beg_img_no + rno]], 1) return img_list, cls_res, predict_time @@ -130,12 +130,6 @@ def main(args): img_list.append(img) try: img_list, cls_res, predict_time = text_classifier(img_list) - print(cls_res) - from matplotlib import pyplot as plt - for img, angle in zip(img_list, cls_res): - plt.title(str(angle)) - plt.imshow(img) - plt.show() except Exception as e: print(e) exit() diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index 555c12b1a929662f436e3a9a031b2e480a837622..bb97c8fcf4ec936309f967ca208e59876b051f17 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -40,7 +40,9 @@ class TextSystem(object): def __init__(self, args): self.text_detector = predict_det.TextDetector(args) self.text_recognizer = predict_rec.TextRecognizer(args) - self.text_classifier = predict_cls.TextClassifier(args) + self.use_angle_cls = args.use_angle_cls + if self.use_angle_cls: + self.text_classifier = predict_cls.TextClassifier(args) def get_rotate_crop_image(self, img, points): ''' @@ -95,10 +97,12 @@ class TextSystem(object): tmp_box = copy.deepcopy(dt_boxes[bno]) img_crop = self.get_rotate_crop_image(ori_im, tmp_box) img_crop_list.append(img_crop) - img_rotate_list, angle_list, elapse = self.text_classifier( - img_crop_list) - print("cls num : {}, elapse : {}".format(len(img_rotate_list), elapse)) - rec_res, elapse = self.text_recognizer(img_rotate_list) + if self.use_angle_cls: + img_crop_list, angle_list, elapse = self.text_classifier( + img_crop_list) + print("cls num : {}, elapse : {}".format( + len(img_crop_list), elapse)) + rec_res, elapse = self.text_recognizer(img_crop_list) print("rec_res num : {}, elapse : {}".format(len(rec_res), elapse)) # self.print_draw_crop_rec_res(img_crop_list, rec_res) return dt_boxes, rec_res diff --git a/tools/infer/utility.py b/tools/infer/utility.py index cbbda97b2a60aeba2a592a8d1b5aa1dc294d4067..1aa94f544fdc5d67556d54c0d08644358366e6ee 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -15,6 +15,7 @@ import argparse import os, sys from ppocr.utils.utility import initial_logger + logger = initial_logger() from paddle.fluid.core import PaddleTensor from paddle.fluid.core import AnalysisConfig @@ -31,34 +32,34 @@ def parse_args(): return v.lower() in ("true", "t", "1") parser = argparse.ArgumentParser() - #params for prediction engine + # params for prediction engine parser.add_argument("--use_gpu", type=str2bool, default=True) parser.add_argument("--ir_optim", type=str2bool, default=True) parser.add_argument("--use_tensorrt", type=str2bool, default=False) parser.add_argument("--gpu_mem", type=int, default=8000) - #params for text detector + # params for text detector parser.add_argument("--image_dir", type=str) parser.add_argument("--det_algorithm", type=str, default='DB') parser.add_argument("--det_model_dir", type=str) parser.add_argument("--det_max_side_len", type=float, default=960) - #DB parmas + # DB parmas parser.add_argument("--det_db_thresh", type=float, default=0.3) parser.add_argument("--det_db_box_thresh", type=float, default=0.5) parser.add_argument("--det_db_unclip_ratio", type=float, default=2.0) - #EAST parmas + # EAST parmas parser.add_argument("--det_east_score_thresh", type=float, default=0.8) parser.add_argument("--det_east_cover_thresh", type=float, default=0.1) parser.add_argument("--det_east_nms_thresh", type=float, default=0.2) - #SAST parmas + # SAST parmas parser.add_argument("--det_sast_score_thresh", type=float, default=0.5) parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2) parser.add_argument("--det_sast_polygon", type=bool, default=False) - #params for text recognizer + # params for text recognizer parser.add_argument("--rec_algorithm", type=str, default='CRNN') parser.add_argument("--rec_model_dir", type=str) parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320") @@ -72,13 +73,14 @@ def parse_args(): parser.add_argument("--use_space_char", type=bool, default=True) # params for text classifier + parser.add_argument("--use_angle_cls", type=str2bool, default=True) parser.add_argument("--cls_model_dir", type=str) - parser.add_argument("--cls_image_shape", type=str, default="3, 32, 100") - parser.add_argument("--label_list", type=list, default=[0, 180]) + parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192") + parser.add_argument("--label_list", type=list, default=['0', '180']) parser.add_argument("--cls_batch_num", type=int, default=30) - parser.add_argument("--enable_mkldnn", type=bool, default=False) - parser.add_argument("--use_zero_copy_run", type=bool, default=False) + parser.add_argument("--enable_mkldnn", type=str2bool, default=False) + parser.add_argument("--use_zero_copy_run", type=str2bool, default=False) return parser.parse_args() @@ -112,7 +114,7 @@ def create_predictor(args, mode): if args.enable_mkldnn: config.enable_mkldnn() - #config.enable_memory_optim() + # config.enable_memory_optim() config.disable_glog_info() if args.use_zero_copy_run: diff --git a/tools/infer_cls.py b/tools/infer_cls.py index 443b1e0583376266dc988432c84172be1308adce..1f78cdf930fc506cc716d97e8f93c13b407f48d1 100755 --- a/tools/infer_cls.py +++ b/tools/infer_cls.py @@ -85,9 +85,10 @@ def main(): feed={"image": img}, fetch_list=fetch_varname_list, return_numpy=False) - for k in predict: - k = np.array(k) - print(k) + scores = np.array(predict[0]) + label = np.array(predict[1]) + logger.info('\t scores: {}'.format(scores)) + logger.info('\t label: {}'.format(label)) # save for inference model target_var = [] for key, values in outputs.items():