diff --git a/doc/doc_en/algorithm_rec_starnet.md b/doc/doc_en/algorithm_rec_starnet.md new file mode 100644 index 0000000000000000000000000000000000000000..dbb53a9c737c16fa249483fa97b0b49cf25b2137 --- /dev/null +++ b/doc/doc_en/algorithm_rec_starnet.md @@ -0,0 +1,139 @@ +# STAR-Net + +- [1. Introduction](#1) +- [2. Environment](#2) +- [3. Model Training / Evaluation / Prediction](#3) + - [3.1 Training](#3-1) + - [3.2 Evaluation](#3-2) + - [3.3 Prediction](#3-3) +- [4. Inference and Deployment](#4) + - [4.1 Python Inference](#4-1) + - [4.2 C++ Inference](#4-2) + - [4.3 Serving](#4-3) + - [4.4 More](#4-4) +- [5. FAQ](#5) + + +## 1. Introduction + +Paper information: +> [STAR-Net: a spatial attention residue network for scene text recognition.](http://www.bmva.org/bmvc/2016/papers/paper043/paper043.pdf) +> Wei Liu, Chaofeng Chen, Kwan-Yee K. Wong, Zhizhong Su and Junyu Han. +> BMVC, pages 43.1-43.13, 2016 + +Refer to [DTRB](https://arxiv.org/abs/1904.01906) text Recognition Training and Evaluation Process . Using MJSynth and SynthText two text recognition datasets for training, and evaluating on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE datasets, the algorithm reproduction effect is as follows: + +|Models|Backbone Networks|Avg Accuracy|Configuration Files|Download Links| +| --- | --- | --- | --- | --- | +|StarNet|Resnet34_vd|84.44%|[configs/rec/rec_r34_vd_tps_bilstm_ctc.yml](../../configs/rec/rec_r34_vd_tps_bilstm_ctc.yml)|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_ctc_v2.0_train.tar)| +|StarNet|MobileNetV3|81.42%|[configs/rec/rec_mv3_tps_bilstm_ctc.yml](../../configs/rec/rec_mv3_tps_bilstm_ctc.yml)|[ trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_ctc_v2.0_train.tar)| + + + +## 2. Environment +Please refer to [Operating Environment Preparation](./environment_en.md) to configure the PaddleOCR operating environment, and refer to [Project Clone](./clone_en.md) to clone the project code. + + +## 3. Model Training / Evaluation / Prediction + +Please refer to [Text Recognition Training Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**. Take the backbone network based on Resnet34_vd as an example: + + +### 3.1 Training +After the data preparation is complete, the training can be started. The training command is as follows: + +```` +#Single card training (long training period, not recommended) +python3 tools/train.py -c configs/rec/rec_r34_vd_tps_bilstm_ctc.yml #Multi-card training, specify the card number through the --gpus parameter +python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c rec_r34_vd_tps_bilstm_ctc.yml + ```` + + +### 3.2 Evaluation + +```` +# GPU evaluation, Global.pretrained_model is the model to be evaluated +python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_r34_vd_tps_bilstm_ctc.yml -o Global.pretrained_model={path/to/weights}/best_accuracy + ```` + + +### 3.3 Prediction + +```` +# The configuration file used for prediction must match the training +python3 tools/infer_rec.py -c configs/rec/rec_r34_vd_tps_bilstm_ctc.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png + ```` + + +## 4. Inference + + +### 4.1 Python Inference +First, convert the model saved during the STAR-Net text recognition training process into an inference model. Take the model trained on the MJSynth and SynthText text recognition datasets based on the Resnet34_vd backbone network as an example [Model download address]( https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_bilstm_ctc_v2.0_train.tar) , which can be converted using the following command: + +```shell +python3 tools/export_model.py -c configs/rec/rec_r34_vd_tps_bilstm_ctc.yml -o Global.pretrained_model=./rec_r34_vd_tps_bilstm_ctc_v2.0_train/best_accuracy Global.save_inference_dir=./inference/rec_starnet + ```` + +STAR-Net text recognition model inference, you can execute the following commands: + +```shell +python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/rec_starnet/" --rec_image_shape="3, 32, 100" --rec_char_dict_path="./ppocr/utils/ic15_dict.txt" + ```` + +![](../imgs_words_en/word_336.png) + +The inference results are as follows: + + +```bash +Predicts of ./doc/imgs_words_en/word_336.png:('super', 0.9999073) +``` + +**Attention** Since the above model refers to the [DTRB](https://arxiv.org/abs/1904.01906) text recognition training and evaluation process, it is different from the ultra-lightweight Chinese recognition model training in two aspects: + +- The image resolutions used during training are different. The image resolutions used for training the above models are [3, 32, 100], while for Chinese model training, in order to ensure the recognition effect of long texts, the image resolutions used during training are [ 3, 32, 320]. The default shape parameter of the predictive inference program is the image resolution used for training Chinese, i.e. [3, 32, 320]. Therefore, when inferring the above English model here, it is necessary to set the shape of the recognized image through the parameter rec_image_shape. + +- Character list, the experiment in the DTRB paper is only for 26 lowercase English letters and 10 numbers, a total of 36 characters. All uppercase and lowercase characters are converted to lowercase characters, and characters not listed above are ignored and considered spaces. Therefore, there is no input character dictionary here, but a dictionary is generated by the following command. Therefore, the parameter rec_char_dict_path needs to be set during inference, which is specified as an English dictionary "./ppocr/utils/ic15_dict.txt". + +``` +self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" +dict_character = list(self.character_str) + + + ``` + + +### 4.2 C++ Inference + +After preparing the inference model, refer to the [cpp infer](../../deploy/cpp_infer/) tutorial to operate. + + +### 4.3 Serving + +After preparing the inference model, refer to the [pdserving](../../deploy/pdserving/) tutorial for Serving deployment, including two modes: Python Serving and C++ Serving. + + +### 4.4 More + +The STAR-Net model also supports the following inference deployment methods: + +- Paddle2ONNX Inference: After preparing the inference model, refer to the [paddle2onnx](../../deploy/paddle2onnx/) tutorial. + + +## 5. FAQ + +## Quote + +```bibtex +@inproceedings{liu2016star, + title={STAR-Net: a spatial attention residue network for scene text recognition.}, + author={Liu, Wei and Chen, Chaofeng and Wong, Kwan-Yee K and Su, Zhizhong and Han, Junyu}, + booktitle={BMVC}, + volume={2}, + pages={7}, + year={2016} +} +``` + + diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 48b16db4a0f2c2c901509d691088d3dc4381fabd..3987d645718be83019cc84d99186c632405f5489 100644 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -34,6 +34,7 @@ def init_args(): parser = argparse.ArgumentParser() # params for prediction engine parser.add_argument("--use_gpu", type=str2bool, default=True) + parser.add_argument("--use_xpu", type=str2bool, default=False) parser.add_argument("--ir_optim", type=str2bool, default=True) parser.add_argument("--use_tensorrt", type=str2bool, default=False) parser.add_argument("--min_subgraph_size", type=int, default=15) @@ -286,6 +287,8 @@ def create_predictor(args, mode, logger): config.set_trt_dynamic_shape_info( min_input_shape, max_input_shape, opt_input_shape) + elif args.use_xpu: + config.enable_xpu(10 * 1024 * 1024) else: config.disable_gpu() if hasattr(args, "cpu_threads"): diff --git a/tools/program.py b/tools/program.py index 7c02dc0149f36085ef05ca378b79d27e92d6dd57..aa0d2698cf66c928f87217996c31c042e1c8aa02 100755 --- a/tools/program.py +++ b/tools/program.py @@ -112,20 +112,25 @@ def merge_config(config, opts): return config -def check_gpu(use_gpu): +def check_device(use_gpu, use_xpu=False): """ Log error and exit when set use_gpu=true in paddlepaddle cpu version. """ - err = "Config use_gpu cannot be set as true while you are " \ - "using paddlepaddle cpu version ! \nPlease try: \n" \ - "\t1. Install paddlepaddle-gpu to run model on GPU \n" \ - "\t2. Set use_gpu as false in config file to run " \ + err = "Config {} cannot be set as true while your paddle " \ + "is not compiled with {} ! \nPlease try: \n" \ + "\t1. Install paddlepaddle to run model on {} \n" \ + "\t2. Set {} as false in config file to run " \ "model on CPU" try: + if use_gpu and use_xpu: + print("use_xpu and use_gpu can not both be ture.") if use_gpu and not paddle.is_compiled_with_cuda(): - print(err) + print(err.format("use_gpu", "cuda", "gpu", "use_gpu")) + sys.exit(1) + if use_xpu and not paddle.device.is_compiled_with_xpu(): + print(err.format("use_xpu", "xpu", "xpu", "use_xpu")) sys.exit(1) except Exception as e: pass @@ -547,7 +552,7 @@ def preprocess(is_train=False): # check if set use_gpu=True in paddlepaddle cpu version use_gpu = config['Global']['use_gpu'] - check_gpu(use_gpu) + use_xpu = config['Global'].get('use_xpu', False) # check if set use_xpu=True in paddlepaddle cpu/gpu version use_xpu = False @@ -562,11 +567,13 @@ def preprocess(is_train=False): 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR' ] - device = 'cpu' - if use_gpu: - device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_xpu: - device = 'xpu' + device = 'xpu:{0}'.format(os.getenv('FLAGS_selected_xpus', 0)) + else: + device = 'gpu:{}'.format(dist.ParallelEnv() + .dev_id) if use_gpu else 'cpu' + check_device(use_gpu, use_xpu) + device = paddle.set_device(device) config['Global']['distributed'] = dist.get_world_size() != 1