module.py 22.6 KB
Newer Older
S
Steffy-zxf 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
# -*- coding:utf-8 -*-
import argparse
import ast
import copy
import math
import os
import time

from paddle.fluid.core import AnalysisConfig, create_paddle_predictor, PaddleTensor
from paddlehub.common.logger import logger
from paddlehub.module.module import moduleinfo, runnable, serving
from PIL import Image
import cv2
import numpy as np
import paddle.fluid as fluid
import paddlehub as hub

S
Steffy-zxf 已提交
18 19
from chinese_ocr_db_crnn_mobile.character import CharacterOps
from chinese_ocr_db_crnn_mobile.utils import base64_to_cv2, draw_ocr, get_image_ext, sorted_boxes
S
Steffy-zxf 已提交
20 21 22


@moduleinfo(
    name="chinese_ocr_db_crnn_mobile",
    version="1.1.0",
    summary=
    "The module can recognize the chinese texts in an image. Firstly, it will detect the text box positions \
        based on the differentiable_binarization_chn module. Then it classifies the text angle and recognizes the chinese texts. ",
    author="paddle-dev",
    author_email="paddle-dev@baidu.com",
    type="cv/text_recognition")
class ChineseOCRDBCRNN(hub.Module):
    """
    Chinese OCR pipeline: text detection -> angle classification -> text recognition.

    Text boxes come from a detector module (``chinese_text_detection_db_mobile``
    by default). Each box is cropped and rectified. The angle classifier decides
    whether the crop is upside down. Finally the recognition model reads the text.
    """

    # Number of text crops fed to the classifier / recognizer per forward pass.
    _BATCH_NUM = 30

    def _initialize(self, text_detector_module=None, enable_mkldnn=False):
        """
        Initialize the character decoder, asset paths and the two predictors.

        Args:
            text_detector_module (hub.Module|None): an already-loaded text detection
                module. When None it is loaded lazily by ``text_detector_module``.
            enable_mkldnn (bool): enable MKL-DNN when the predictors run on CPU.
        """
        self.character_dict_path = os.path.join(self.directory, 'assets',
                                                'ppocr_keys_v1.txt')
        char_ops_params = {
            'character_type': 'ch',
            'character_dict_path': self.character_dict_path,
            'loss_type': 'ctc',
            'max_text_length': 25,
            'use_space_char': True
        }
        self.char_ops = CharacterOps(char_ops_params)
        self.rec_image_shape = [3, 32, 320]
        self._text_detector_module = text_detector_module
        self.font_file = os.path.join(self.directory, 'assets', 'simfang.ttf')
        # Must be set before _set_config(), which reads it.
        self.enable_mkldnn = enable_mkldnn

        self.rec_pretrained_model_path = os.path.join(
            self.directory, 'inference_model', 'character_rec')
        self.cls_pretrained_model_path = os.path.join(
            self.directory, 'inference_model', 'angle_cls')
        self.rec_predictor, self.rec_input_tensor, self.rec_output_tensors = self._set_config(
            self.rec_pretrained_model_path)
        self.cls_predictor, self.cls_input_tensor, self.cls_output_tensors = self._set_config(
            self.cls_pretrained_model_path)

    @staticmethod
    def _cuda_devices_visible():
        """
        Return True when CUDA_VISIBLE_DEVICES is set and its first character is a digit.
        """
        try:
            int(os.environ["CUDA_VISIBLE_DEVICES"][0])
        except (KeyError, IndexError, ValueError):
            # KeyError: variable unset; IndexError: empty; ValueError: not a device id.
            return False
        return True

    def _set_config(self, pretrained_model_path):
        """
        Create a zero-copy predictor for the model stored in ``pretrained_model_path``.

        The directory must contain a ``model`` file and a ``params`` file. When
        CUDA_VISIBLE_DEVICES is set, device 0 is used with an initial 8000 MB
        memory pool. Otherwise the CPU is used, optionally with MKL-DNN.

        Returns:
            tuple: (predictor, input tensor, list of output tensors)
        """
        model_file_path = os.path.join(pretrained_model_path, 'model')
        params_file_path = os.path.join(pretrained_model_path, 'params')

        config = AnalysisConfig(model_file_path, params_file_path)
        if self._cuda_devices_visible():
            config.enable_use_gpu(8000, 0)
        else:
            config.disable_gpu()
            if self.enable_mkldnn:
                # cache 10 different shapes for mkldnn to avoid memory leak
                config.set_mkldnn_cache_capacity(10)
                config.enable_mkldnn()

        config.disable_glog_info()
        config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
        # Feed/fetch ops are disabled so the zero-copy tensor API
        # (copy_from_cpu / zero_copy_run / copy_to_cpu) can be used.
        config.switch_use_feed_fetch_ops(False)

        predictor = create_paddle_predictor(config)

        input_names = predictor.get_input_names()
        input_tensor = predictor.get_input_tensor(input_names[0])
        output_tensors = [
            predictor.get_output_tensor(output_name)
            for output_name in predictor.get_output_names()
        ]

        return predictor, input_tensor, output_tensors

    @property
    def text_detector_module(self):
        """
        text detect module (loaded on first access when not supplied to _initialize)
        """
        if not self._text_detector_module:
            self._text_detector_module = hub.Module(
                name='chinese_text_detection_db_mobile',
                enable_mkldnn=self.enable_mkldnn,
                version='1.0.3')
        return self._text_detector_module

    def read_images(self, paths=None):
        """
        Read images from disk with OpenCV.

        Args:
            paths (list[str]|None): image file paths.

        Returns:
            list[numpy.ndarray]: the decoded (BGR) images. Files that OpenCV cannot
            decode are logged and skipped.

        Raises:
            AssertionError: if a path is not an existing file.
        """
        images = []
        for img_path in paths or []:
            assert os.path.isfile(
                img_path), "The {} isn't a valid file.".format(img_path)
            img = cv2.imread(img_path)
            if img is None:
                logger.info("error in loading image:{}".format(img_path))
                continue
            images.append(img)
        return images

    def get_rotate_crop_image(self, img, points):
        """
        Warp the quadrilateral ``points`` of ``img`` into an axis-aligned crop.

        Args:
            img (numpy.ndarray): the full image, [H, W, C].
            points (numpy.ndarray): float32 array of 4 corners. They are mapped to
                the crop's top-left, top-right, bottom-right and bottom-left
                corners, in that order.

        Returns:
            numpy.ndarray: the rectified crop. It is rotated by 90 degrees when it
            is at least 1.5 times taller than it is wide.
        """
        img_crop_width = int(
            max(
                np.linalg.norm(points[0] - points[1]),
                np.linalg.norm(points[2] - points[3])))
        img_crop_height = int(
            max(
                np.linalg.norm(points[0] - points[3]),
                np.linalg.norm(points[1] - points[2])))
        pts_std = np.float32([[0, 0], [img_crop_width, 0],
                              [img_crop_width, img_crop_height],
                              [0, img_crop_height]])
        M = cv2.getPerspectiveTransform(points, pts_std)
        dst_img = cv2.warpPerspective(
            img,
            M, (img_crop_width, img_crop_height),
            borderMode=cv2.BORDER_REPLICATE,
            flags=cv2.INTER_CUBIC)
        dst_img_height, dst_img_width = dst_img.shape[0:2]
        # Tall, narrow crops are treated as vertical text and turned sideways.
        if dst_img_height * 1.0 / dst_img_width >= 1.5:
            dst_img = np.rot90(dst_img)
        return dst_img

    def resize_norm_img_rec(self, img, max_wh_ratio):
        """
        Resize a crop to the recognizer height, normalize it to [-1, 1] and right-pad it.

        The output width is ``imgH * max_wh_ratio``, so that all crops in one batch
        share a shape.

        Returns:
            numpy.ndarray: float32 array of shape [C, imgH, int(imgH * max_wh_ratio)].
        """
        imgC, imgH, _ = self.rec_image_shape
        assert imgC == img.shape[2]
        # The configured width is replaced by a batch-dependent one.
        imgW = int(imgH * max_wh_ratio)
        h, w = img.shape[:2]
        ratio = w / float(h)
        if math.ceil(imgH * ratio) > imgW:
            resized_w = imgW
        else:
            resized_w = int(math.ceil(imgH * ratio))
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype('float32')
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def resize_norm_img_cls(self, img):
        """
        Resize a crop to the classifier input [3, 48, 192], normalize it to [-1, 1] and right-pad it.
        """
        cls_image_shape = [3, 48, 192]
        imgC, imgH, imgW = cls_image_shape
        h = img.shape[0]
        w = img.shape[1]
        ratio = w / float(h)
        if math.ceil(imgH * ratio) > imgW:
            resized_w = imgW
        else:
            resized_w = int(math.ceil(imgH * ratio))
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype('float32')
        if cls_image_shape[0] == 1:
            resized_image = resized_image / 255
            resized_image = resized_image[np.newaxis, :]
        else:
            resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def recognize_text(self,
                       images=None,
                       paths=None,
                       use_gpu=False,
                       output_dir='ocr_result',
                       visualization=False,
                       box_thresh=0.5,
                       text_thresh=0.5):
        """
        Get the chinese texts in the predicted images.

        Exactly one of ``images`` and ``paths`` must be a non-empty list.

        Args:
            images (list(numpy.ndarray)): images data, shape of each is [H, W, C].
            paths (list[str]): The paths of images.
            use_gpu (bool): Whether to use gpu.
            output_dir (str): The directory to store output images.
            visualization (bool): Whether to save image or not.
            box_thresh(float): the threshold of the detected text box's confidence
            text_thresh(float): the threshold of the recognize chinese texts' confidence

        Returns:
            res (list): one dict per image. ``'data'`` is a list of
            {'text', 'confidence', 'text_box_position'}. ``'save_path'`` is the
            visualization path, or '' when none was saved.

        Raises:
            RuntimeError: use_gpu is True but CUDA_VISIBLE_DEVICES is not usable.
            TypeError: neither / both of images and paths were given.
        """
        if use_gpu and not self._cuda_devices_visible():
            raise RuntimeError(
                "Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES via export CUDA_VISIBLE_DEVICES=cuda_device_id."
            )

        self.use_gpu = use_gpu

        images = [] if images is None else images
        paths = [] if paths is None else paths
        if not isinstance(images, list) or not isinstance(paths, list):
            raise TypeError("The input data is inconsistent with expectations.")
        if images and not paths:
            predicted_data = images
        elif paths and not images:
            predicted_data = self.read_images(paths)
        else:
            raise TypeError("The input data is inconsistent with expectations.")

        assert predicted_data, "There is not any image to be predicted. Please check the input data."

        detection_results = self.text_detector_module.detect_text(
            images=predicted_data, use_gpu=self.use_gpu, box_thresh=box_thresh)

        all_boxes = [
            np.array(item['data']).astype(np.float32)
            for item in detection_results
        ]
        all_results = []
        for img_index, img_boxes in enumerate(all_boxes):
            original_image = predicted_data[img_index].copy()
            result = {'save_path': ''}
            if img_boxes.size == 0:
                result['data'] = []
                all_results.append(result)
                continue

            sorted_img_boxes = sorted_boxes(img_boxes)
            # deepcopy so the crop helper can never alter the stored boxes.
            img_crop_list = [
                self.get_rotate_crop_image(original_image, copy.deepcopy(box))
                for box in sorted_img_boxes
            ]
            img_crop_list, _angle_list = self._classify_text(img_crop_list)
            rec_results = self._recognize_text(img_crop_list)

            # if the recognized text confidence score is lower than text_thresh, then drop it
            rec_res_final = []
            for box, (text, score) in zip(sorted_img_boxes, rec_results):
                if score >= text_thresh:
                    rec_res_final.append({
                        'text': text,
                        'confidence': float(score),
                        # builtin int: np.int was removed in NumPy 1.24
                        'text_box_position': box.astype(int).tolist()
                    })
            result['data'] = rec_res_final

            if visualization and result['data']:
                result['save_path'] = self.save_result_image(
                    original_image, sorted_img_boxes, rec_results, output_dir,
                    text_thresh)
            all_results.append(result)

        return all_results

    @serving
    def serving_method(self, images, **kwargs):
        """
        Run as a service: decode base64 images and forward to recognize_text.
        """
        images_decode = [base64_to_cv2(image) for image in images]
        results = self.recognize_text(images_decode, **kwargs)
        return results

    def save_result_image(self,
                          original_image,
                          detection_boxes,
                          rec_results,
                          output_dir='ocr_result',
                          text_thresh=0.5):
        """
        Draw boxes and recognized texts on the image and write it to ``output_dir``.

        Returns:
            str: path of the written image.
        """
        image = Image.fromarray(cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB))
        txts = [item[0] for item in rec_results]
        scores = [item[1] for item in rec_results]
        draw_img = draw_ocr(
            image,
            detection_boxes,
            txts,
            scores,
            font_file=self.font_file,
            draw_txt=True,
            drop_score=text_thresh)

        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        ext = get_image_ext(original_image)
        saved_name = 'ndarray_{}{}'.format(time.time(), ext)
        save_file_path = os.path.join(output_dir, saved_name)
        # draw_img is RGB; reverse the channels for cv2.imwrite (BGR).
        cv2.imwrite(save_file_path, draw_img[:, :, ::-1])
        return save_file_path

    def _classify_text(self, image_list):
        """
        Classify each crop as '0' or '180' degrees and un-flip confidently upside-down crops.

        Args:
            image_list (list[numpy.ndarray]): text crops. They are not modified,
                because a deep copy is used.

        Returns:
            tuple: (crops, with those scored '180' above 0.9999 rotated by 180
            degrees; one [label, score] pair per crop, in input order)
        """
        img_list = copy.deepcopy(image_list)
        img_num = len(img_list)
        # Sorting by aspect ratio groups similarly shaped crops, which speeds up the cls process
        width_list = [img.shape[1] / float(img.shape[0]) for img in img_list]
        indices = np.argsort(np.array(width_list))

        label_list = ['0', '180']
        cls_res = [['', 0.0]] * img_num
        for beg_img_no in range(0, img_num, self._BATCH_NUM):
            end_img_no = min(img_num, beg_img_no + self._BATCH_NUM)
            norm_img_batch = np.concatenate([
                self.resize_norm_img_cls(img_list[indices[ino]])[np.newaxis, :]
                for ino in range(beg_img_no, end_img_no)
            ])

            self.cls_input_tensor.copy_from_cpu(norm_img_batch)
            self.cls_predictor.zero_copy_run()

            prob_out = self.cls_output_tensors[0].copy_to_cpu()
            label_out = self.cls_output_tensors[1].copy_to_cpu()
            # Expect the label tensor to be the 1-D one; swap if they came back the other way.
            if len(label_out.shape) != 1:
                prob_out, label_out = label_out, prob_out
            for rno, label_idx in enumerate(label_out):
                img_idx = indices[beg_img_no + rno]
                score = prob_out[rno][label_idx]
                label = label_list[label_idx]
                cls_res[img_idx] = [label, score]
                # Only rotate when the classifier is nearly certain.
                if '180' in label and score > 0.9999:
                    img_list[img_idx] = cv2.rotate(img_list[img_idx],
                                                   cv2.ROTATE_180)
        return img_list, cls_res

    def _recognize_text(self, img_list):
        """
        Recognize the text in every crop.

        Returns:
            list: one [text, score] pair per crop, in input order. ``score`` is the
            mean probability of the non-blank argmax predictions. A crop whose
            predictions are all blank keeps ['', 0.0].
        """
        img_num = len(img_list)
        # Sorting by aspect ratio keeps padding small within each batch
        width_list = [img.shape[1] / float(img.shape[0]) for img in img_list]
        indices = np.argsort(np.array(width_list))

        rec_res = [['', 0.0]] * img_num
        for beg_img_no in range(0, img_num, self._BATCH_NUM):
            end_img_no = min(img_num, beg_img_no + self._BATCH_NUM)
            batch_indices = indices[beg_img_no:end_img_no]
            # Every crop in the batch is padded to the widest aspect ratio.
            max_wh_ratio = max(width_list[i] for i in batch_indices)
            norm_img_batch = np.concatenate(
                [
                    self.resize_norm_img_rec(img_list[i],
                                             max_wh_ratio)[np.newaxis, :]
                    for i in batch_indices
                ],
                axis=0)

            self.rec_input_tensor.copy_from_cpu(norm_img_batch)
            self.rec_predictor.zero_copy_run()

            rec_idx_batch = self.rec_output_tensors[0].copy_to_cpu()
            rec_idx_lod = self.rec_output_tensors[0].lod()[0]
            predict_batch = self.rec_output_tensors[1].copy_to_cpu()
            predict_lod = self.rec_output_tensors[1].lod()[0]
            # The LoD offsets delimit each crop's rows inside the flat outputs.
            for rno in range(len(rec_idx_lod) - 1):
                beg = rec_idx_lod[rno]
                end = rec_idx_lod[rno + 1]
                rec_idx_tmp = rec_idx_batch[beg:end, 0]
                preds_text = self.char_ops.decode(rec_idx_tmp)
                beg = predict_lod[rno]
                end = predict_lod[rno + 1]
                probs = predict_batch[beg:end, :]
                ind = np.argmax(probs, axis=1)
                # The last class index is treated as the blank and excluded from the score.
                blank = probs.shape[1]
                valid_ind = np.where(ind != (blank - 1))[0]
                if len(valid_ind) == 0:
                    continue
                score = np.mean(probs[valid_ind, ind[valid_ind]])
                rec_res[indices[beg_img_no + rno]] = [preds_text, score]

        # Return only after ALL batches are processed (previously this sat
        # inside the loop, so only the first batch was ever recognized).
        return rec_res

    def save_inference_model(self,
                             dirname,
                             model_filename=None,
                             params_filename=None,
                             combined=True):
        """
        Save the detector, angle classifier and recognizer under ``dirname``.
        """
        detector_dir = os.path.join(dirname, 'text_detector')
        classifier_dir = os.path.join(dirname, 'angle_classifier')
        recognizer_dir = os.path.join(dirname, 'text_recognizer')
        self._save_detector_model(detector_dir, model_filename, params_filename,
                                  combined)
        self._save_classifier_model(classifier_dir, model_filename,
                                    params_filename, combined)
        self._save_recognizer_model(recognizer_dir, model_filename,
                                    params_filename, combined)
        logger.info("The inference model has been saved in the path {}".format(
            os.path.realpath(dirname)))

    def _save_detector_model(self,
                             dirname,
                             model_filename=None,
                             params_filename=None,
                             combined=True):
        """Delegate saving to the text detector module."""
        self.text_detector_module.save_inference_model(
            dirname, model_filename, params_filename, combined)

    def _save_pretrained_model(self,
                               pretrained_model_path,
                               dirname,
                               model_filename=None,
                               params_filename=None,
                               combined=True):
        """
        Load the ``model``/``params`` inference model in ``pretrained_model_path`` and re-save it to ``dirname``.

        When ``combined`` is True, missing file names default to
        ``__model__`` / ``__params__``.
        """
        if combined:
            model_filename = "__model__" if not model_filename else model_filename
            params_filename = "__params__" if not params_filename else params_filename
        place = fluid.CPUPlace()
        exe = fluid.Executor(place)

        model_file_path = os.path.join(pretrained_model_path, 'model')
        params_file_path = os.path.join(pretrained_model_path, 'params')
        program, feeded_var_names, target_vars = fluid.io.load_inference_model(
            dirname=pretrained_model_path,
            model_filename=model_file_path,
            params_filename=params_file_path,
            executor=exe)

        fluid.io.save_inference_model(
            dirname=dirname,
            main_program=program,
            executor=exe,
            feeded_var_names=feeded_var_names,
            target_vars=target_vars,
            model_filename=model_filename,
            params_filename=params_filename)

    def _save_recognizer_model(self,
                               dirname,
                               model_filename=None,
                               params_filename=None,
                               combined=True):
        """Re-save the text recognition model to ``dirname``."""
        self._save_pretrained_model(self.rec_pretrained_model_path, dirname,
                                    model_filename, params_filename, combined)

    def _save_classifier_model(self,
                               dirname,
                               model_filename=None,
                               params_filename=None,
                               combined=True):
        """Re-save the angle classification model to ``dirname``."""
        self._save_pretrained_model(self.cls_pretrained_model_path, dirname,
                                    model_filename, params_filename, combined)

    @runnable
    def run_cmd(self, argvs):
        """
        Run as a command
        """
        self.parser = argparse.ArgumentParser(
            description="Run the %s module." % self.name,
            prog='hub run %s' % self.name,
            usage='%(prog)s',
            add_help=True)

        self.arg_input_group = self.parser.add_argument_group(
            title="Input options", description="Input data. Required")
        self.arg_config_group = self.parser.add_argument_group(
            title="Config options",
            description=
            "Run configuration for controlling module behavior, not required.")

        self.add_module_config_arg()
        self.add_module_input_arg()

        args = self.parser.parse_args(argvs)
        results = self.recognize_text(
            paths=[args.input_path],
            use_gpu=args.use_gpu,
            output_dir=args.output_dir,
            visualization=args.visualization,
            box_thresh=args.box_thresh,
            text_thresh=args.text_thresh)
        return results

    def add_module_config_arg(self):
        """
        Add the command config options
        """
        self.arg_config_group.add_argument(
            '--use_gpu',
            type=ast.literal_eval,
            default=False,
            help="whether use GPU or not")
        self.arg_config_group.add_argument(
            '--output_dir',
            type=str,
            default='ocr_result',
            help="The directory to save output images.")
        self.arg_config_group.add_argument(
            '--visualization',
            type=ast.literal_eval,
            default=False,
            help="whether to save output as images.")
        self.arg_config_group.add_argument(
            '--box_thresh',
            type=float,
            default=0.5,
            help="The confidence threshold of detected text boxes.")
        self.arg_config_group.add_argument(
            '--text_thresh',
            type=float,
            default=0.5,
            help="The confidence threshold of recognized texts.")

    def add_module_input_arg(self):
        """
        Add the command input options
        """
        self.arg_input_group.add_argument(
            '--input_path', type=str, default=None, help="diretory to image")


if __name__ == '__main__':
    import sys

    ocr = ChineseOCRDBCRNN()
    # Image paths may be passed on the command line; without arguments the
    # original developer sample images are used.
    image_path = sys.argv[1:] or [
        '/mnt/zhangxuefei/PaddleOCR/doc/imgs/2.jpg',
        '/mnt/zhangxuefei/PaddleOCR/doc/imgs/12.jpg',
        '/mnt/zhangxuefei/PaddleOCR/doc/imgs/test_image.jpg'
    ]
    res = ocr.recognize_text(paths=image_path, visualization=True)
    ocr.save_inference_model('save')
    print(res)