提交 fa675f89 编写于 作者: D dyning

updata structure of dygraph

上级 7d09cd19
...@@ -6,29 +6,19 @@ Global: ...@@ -6,29 +6,19 @@ Global:
save_model_dir: ./output/db_mv3/ save_model_dir: ./output/db_mv3/
save_epoch_step: 1200 save_epoch_step: 1200
# evaluation is run every 5000 iterations after the 4000th iteration # evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step: 8 eval_batch_step: [4000, 5000]
# if pretrained_model is saved in static mode, load_static_weights must set to True # if pretrained_model is saved in static mode, load_static_weights must set to True
load_static_weights: True load_static_weights: True
cal_metric_during_train: False cal_metric_during_train: False
pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
checkpoints: checkpoints: #./output/det_db_0.001_DiceLoss_256_pp_config_2.0b_4gpu/best_accuracy
save_inference_dir: save_inference_dir:
use_visualdl: True use_visualdl: False
infer_img: doc/imgs_en/img_10.jpg infer_img: doc/imgs_en/img_10.jpg
save_res_path: ./output/det_db/predicts_db.txt save_res_path: ./output/det_db/predicts_db.txt
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
learning_rate:
lr: 0.001
regularizer:
name: 'L2'
factor: 0
Architecture: Architecture:
type: det model_type: det
algorithm: DB algorithm: DB
Transform: Transform:
Backbone: Backbone:
...@@ -36,7 +26,7 @@ Architecture: ...@@ -36,7 +26,7 @@ Architecture:
scale: 0.5 scale: 0.5
model_name: large model_name: large
Neck: Neck:
name: FPN name: DBFPN
out_channels: 256 out_channels: 256
Head: Head:
name: DBHead name: DBHead
...@@ -49,6 +39,18 @@ Loss: ...@@ -49,6 +39,18 @@ Loss:
alpha: 5 alpha: 5
beta: 10 beta: 10
ohem_ratio: 3 ohem_ratio: 3
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
learning_rate:
# name: Cosine
lr: 0.001
# warmup_epoch: 0
regularizer:
name: 'L2'
factor: 0
PostProcess: PostProcess:
name: DBPostProcess name: DBPostProcess
...@@ -61,13 +63,13 @@ Metric: ...@@ -61,13 +63,13 @@ Metric:
name: DetMetric name: DetMetric
main_indicator: hmean main_indicator: hmean
TRAIN: Train:
dataset: dataset:
name: SimpleDataSet name: SimpleDataSet
data_dir: ./detection/ data_dir: ./train_data/icdar2015/text_localization/
file_list: label_file_list:
- ./detection/train_icdar2015_label.txt # dataset1 - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
ratio_list: [1.0] ratio_list: [0.5]
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
...@@ -76,10 +78,10 @@ TRAIN: ...@@ -76,10 +78,10 @@ TRAIN:
- IaaAugment: - IaaAugment:
augmenter_args: augmenter_args:
- { 'type': Fliplr, 'args': { 'p': 0.5 } } - { 'type': Fliplr, 'args': { 'p': 0.5 } }
- { 'type': Affine, 'args': { 'rotate': [ -10,10 ] } } - { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
- { 'type': Resize,'args': { 'size': [ 0.5,3 ] } } - { 'type': Resize, 'args': { 'size': [0.5, 3] } }
- EastRandomCropData: - EastRandomCropData:
size: [ 640,640 ] size: [640, 640]
max_tries: 50 max_tries: 50
keep_ratio: true keep_ratio: true
- MakeBorderMap: - MakeBorderMap:
...@@ -91,41 +93,41 @@ TRAIN: ...@@ -91,41 +93,41 @@ TRAIN:
min_text_size: 8 min_text_size: 8
- NormalizeImage: - NormalizeImage:
scale: 1./255. scale: 1./255.
mean: [ 0.485, 0.456, 0.406 ] mean: [0.485, 0.456, 0.406]
std: [ 0.229, 0.224, 0.225 ] std: [0.229, 0.224, 0.225]
order: 'hwc' order: 'hwc'
- ToCHWImage: - ToCHWImage:
- keepKeys: - KeepKeys:
keep_keys: ['image','threshold_map','threshold_mask','shrink_map','shrink_mask'] # dataloader will return list in this order keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
loader: loader:
shuffle: True shuffle: True
drop_last: False drop_last: False
batch_size: 16 batch_size_per_card: 4
num_workers: 8 num_workers: 8
EVAL: Eval:
dataset: dataset:
name: SimpleDataSet name: SimpleDataSet
data_dir: ./detection/ data_dir: ./train_data/icdar2015/text_localization/
file_list: label_file_list:
- ./detection/test_icdar2015_label.txt - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
channel_first: False channel_first: False
- DetLabelEncode: # Class handling label - DetLabelEncode: # Class handling label
- DetResizeForTest: - DetResizeForTest:
image_shape: [736,1280] image_shape: [736, 1280]
- NormalizeImage: - NormalizeImage:
scale: 1./255. scale: 1./255.
mean: [ 0.485, 0.456, 0.406 ] mean: [0.485, 0.456, 0.406]
std: [ 0.229, 0.224, 0.225 ] std: [0.229, 0.224, 0.225]
order: 'hwc' order: 'hwc'
- ToCHWImage: - ToCHWImage:
- keepKeys: - KeepKeys:
keep_keys: ['image','shape','polys','ignore_tags'] keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
loader: loader:
shuffle: False shuffle: False
drop_last: False drop_last: False
batch_size: 1 # must be 1 batch_size_per_card: 1 # must be 1
num_workers: 8 num_workers: 2
\ No newline at end of file \ No newline at end of file
Global: Global:
use_gpu: false use_gpu: true
epoch_num: 500 epoch_num: 72
log_smooth_window: 20 log_smooth_window: 20
print_batch_step: 10 print_batch_step: 10
save_model_dir: ./output/rec/mv3_none_bilstm_ctc/ save_model_dir: ./output/rec/mv3_none_bilstm_ctc/
save_epoch_step: 500 save_epoch_step: 3
# evaluation is run every 5000 iterations after the 4000th iteration # evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step: 127 eval_batch_step: [0, 1000]
# if pretrained_model is saved in static mode, load_static_weights must set to True # if pretrained_model is saved in static mode, load_static_weights must set to True
load_static_weights: True
cal_metric_during_train: True cal_metric_during_train: True
pretrained_model: pretrained_model:
checkpoints: checkpoints:
...@@ -16,12 +15,14 @@ Global: ...@@ -16,12 +15,14 @@ Global:
use_visualdl: False use_visualdl: False
infer_img: doc/imgs_words/ch/word_1.jpg infer_img: doc/imgs_words/ch/word_1.jpg
# for data or label process # for data or label process
max_text_length: 80 character_dict_path:
character_dict_path: ppocr/utils/ppocr_keys_v1.txt character_type: en
character_type: 'ch' max_text_length: 25
use_space_char: False loss_type: ctc
infer_mode: False infer_mode: False
use_tps: False # use_space_char: True
# use_tps: False
Optimizer: Optimizer:
...@@ -29,27 +30,26 @@ Optimizer: ...@@ -29,27 +30,26 @@ Optimizer:
beta1: 0.9 beta1: 0.9
beta2: 0.999 beta2: 0.999
learning_rate: learning_rate:
lr: 0.001 lr: 0.0005
regularizer: regularizer:
name: 'L2' name: 'L2'
factor: 0.00001 factor: 0.00001
Architecture: Architecture:
type: rec model_type: rec
algorithm: CRNN algorithm: CRNN
Transform: Transform:
Backbone: Backbone:
name: MobileNetV3 name: MobileNetV3
scale: 0.5 scale: 0.5
model_name: small model_name: large
small_stride: [ 1, 2, 2, 2 ]
Neck: Neck:
name: SequenceEncoder name: SequenceEncoder
encoder_type: fc encoder_type: rnn
hidden_size: 96 hidden_size: 96
Head: Head:
name: CTC name: CTCHead
fc_decay: 0.00001 fc_decay: 0.0004
Loss: Loss:
name: CTCLoss name: CTCLoss
...@@ -61,46 +61,40 @@ Metric: ...@@ -61,46 +61,40 @@ Metric:
name: RecMetric name: RecMetric
main_indicator: acc main_indicator: acc
TRAIN: Train:
dataset: dataset:
name: SimpleDataSet name: LMDBDateSet
data_dir: ./rec data_dir: ./train_data/data_lmdb_release/training/
file_list:
- ./rec/train.txt # dataset1
ratio_list: [ 0.4,0.6 ]
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
channel_first: False channel_first: False
- CTCLabelEncode: # Class handling label - CTCLabelEncode: # Class handling label
- RecAug:
- RecResizeImg: - RecResizeImg:
image_shape: [ 3,32,320 ] image_shape: [3, 32, 100]
- keepKeys: - KeepKeys:
keep_keys: [ 'image','label','length' ] # dataloader will return list in this order keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader: loader:
batch_size: 256 batch_size_per_card: 256
shuffle: True shuffle: False
drop_last: True drop_last: True
num_workers: 8 num_workers: 8
EVAL: Eval:
dataset: dataset:
name: SimpleDataSet name: LMDBDateSet
data_dir: ./rec data_dir: ./train_data/data_lmdb_release/validation/
file_list:
- ./rec/val.txt
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
channel_first: False channel_first: False
- CTCLabelEncode: # Class handling label - CTCLabelEncode: # Class handling label
- RecResizeImg: - RecResizeImg:
image_shape: [ 3,32,320 ] image_shape: [3, 32, 100]
- keepKeys: - KeepKeys:
keep_keys: [ 'image','label','length' ] # dataloader will return list in this order keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader: loader:
shuffle: False shuffle: False
drop_last: False drop_last: False
batch_size: 256 batch_size_per_card: 256
num_workers: 8 num_workers: 2
Global: Global:
use_gpu: true use_gpu: false
epoch_num: 500 epoch_num: 500
log_smooth_window: 20 log_smooth_window: 20
print_batch_step: 1 print_batch_step: 10
save_model_dir: ./output/rec/mv3_none_bilstm_ctc/ save_model_dir: ./output/rec/mv3_none_bilstm_ctc/
save_epoch_step: 500 save_epoch_step: 500
# evaluation is run every 5000 iterations after the 4000th iteration # evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step: 1016 eval_batch_step: 127
# if pretrained_model is saved in static mode, load_static_weights must set to True # if pretrained_model is saved in static mode, load_static_weights must set to True
load_static_weights: True load_static_weights: True
cal_metric_during_train: True cal_metric_during_train: True
pretrained_model: pretrained_model:
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
use_visualdl: True use_visualdl: False
infer_img: doc/imgs_words/ch/word_1.jpg infer_img: doc/imgs_words/ch/word_1.jpg
# for data or label process # for data or label process
max_text_length: 80 max_text_length: 80
character_dict_path: ppocr/utils/ppocr_keys_v1.txt character_dict_path: ppocr/utils/ppocr_keys_v1.txt
character_type: 'ch' character_type: 'ch'
use_space_char: True use_space_char: False
infer_mode: False infer_mode: False
use_tps: False use_tps: False
...@@ -29,7 +29,7 @@ Optimizer: ...@@ -29,7 +29,7 @@ Optimizer:
beta1: 0.9 beta1: 0.9
beta2: 0.999 beta2: 0.999
learning_rate: learning_rate:
lr: 0.0005 lr: 0.001
regularizer: regularizer:
name: 'L2' name: 'L2'
factor: 0.00001 factor: 0.00001
...@@ -45,8 +45,8 @@ Architecture: ...@@ -45,8 +45,8 @@ Architecture:
small_stride: [ 1, 2, 2, 2 ] small_stride: [ 1, 2, 2, 2 ]
Neck: Neck:
name: SequenceEncoder name: SequenceEncoder
encoder_type: rnn encoder_type: fc
hidden_size: 48 hidden_size: 96
Head: Head:
name: CTC name: CTC
fc_decay: 0.00001 fc_decay: 0.00001
...@@ -63,9 +63,10 @@ Metric: ...@@ -63,9 +63,10 @@ Metric:
TRAIN: TRAIN:
dataset: dataset:
name: LMDBDateSet name: SimpleDataSet
data_dir: ./rec
file_list: file_list:
- ./rec/lmdb/train # dataset1 - ./rec/train.txt # dataset1
ratio_list: [ 0.4,0.6 ] ratio_list: [ 0.4,0.6 ]
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
...@@ -85,9 +86,10 @@ TRAIN: ...@@ -85,9 +86,10 @@ TRAIN:
EVAL: EVAL:
dataset: dataset:
name: LMDBDateSet name: SimpleDataSet
data_dir: ./rec
file_list: file_list:
- ./rec/lmdb/val - ./rec/val.txt
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
......
Global:
use_gpu: false
epoch_num: 72
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec/mv3_none_none_ctc/
save_epoch_step: 500
# evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step: 2000
# if pretrained_model is saved in static mode, load_static_weights must set to True
load_static_weights: True
cal_metric_during_train: True
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: True
infer_img: doc/imgs_words/ch/word_1.jpg
# for data or label process
max_text_length: 25
character_dict_path:
character_type: 'en'
use_space_char: False
infer_mode: False
use_tps: False
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
learning_rate:
lr: 0.0005
regularizer:
name: 'L2'
factor: 0.00001
Architecture:
type: rec
algorithm: CRNN
Transform:
Backbone:
name: MobileNetV3
scale: 0.5
model_name: large
small_stride: [ 1, 2, 2, 2 ]
Neck:
name: SequenceEncoder
encoder_type: reshape
Head:
name: CTC
fc_decay: 0.00001
Loss:
name: CTCLoss
PostProcess:
name: CTCLabelDecode
Metric:
name: RecMetric
main_indicator: acc
TRAIN:
dataset:
name: LMDBDateSet
file_list:
- ./rec/train # dataset1
ratio_list: [ 0.4,0.6 ]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- CTCLabelEncode: # Class handling label
- RecAug:
- RecResizeImg:
image_shape: [ 3,32,100 ]
- keepKeys:
keep_keys: [ 'image','label','length' ] # dataloader will return list in this order
loader:
batch_size: 256
shuffle: True
drop_last: True
num_workers: 8
EVAL:
dataset:
name: LMDBDateSet
file_list:
- ./rec/val/
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- CTCLabelEncode: # Class handling label
- RecResizeImg:
image_shape: [ 3,32,100 ]
- keepKeys:
keep_keys: [ 'image','label','length' ] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size: 256
num_workers: 8
English | [简体中文](README_cn.md)
## Introduction
Many user hopes package the PaddleOCR service into an docker image, so that it can be quickly released and used in the docker or k8s environment.
This page provide some standardized code to achieve this goal. You can quickly publish the PaddleOCR project into a callable Restful API service through the following steps. (At present, the deployment based on the HubServing mode is implemented first, and author plans to increase the deployment of the PaddleServing mode in the futrue)
## 1. Prerequisites
You need to install the following basic components first:
a. Docker
b. Graphics driver and CUDA 10.0+(GPU)
c. NVIDIA Container Toolkit(GPU,Docker 19.03+ can skip this)
d. cuDNN 7.6+(GPU)
## 2. Build Image
a. Download PaddleOCR sourcecode
```
git clone https://github.com/PaddlePaddle/PaddleOCR.git
```
b. Goto Dockerfile directory(ps:Need to distinguish between cpu and gpu version, the following takes cpu as an example, gpu version needs to replace the keyword)
```
cd docker/cpu
```
c. Build image
```
docker build -t paddleocr:cpu .
```
## 3. Start container
a. CPU version
```
sudo docker run -dp 8866:8866 --name paddle_ocr paddleocr:cpu
```
b. GPU version (base on NVIDIA Container Toolkit)
```
sudo nvidia-docker run -dp 8866:8866 --name paddle_ocr paddleocr:gpu
```
c. GPU version (Docker 19.03++)
```
sudo docker run -dp 8866:8866 --gpus all --name paddle_ocr paddleocr:gpu
```
d. Check service status(If you can see the following statement then it means completed:Successfully installed ocr_system && Running on http://0.0.0.0:8866/)
```
docker logs -f paddle_ocr
```
## 4. Test
a. Calculate the Base64 encoding of the picture to be recognized (if you just test, you can use a free online tool, like:https://freeonlinetools24.com/base64-image/)
b. Post a service request(sample request in sample_request.txt)
```
curl -H "Content-Type:application/json" -X POST --data "{\"images\": [\"Input image Base64 encode(need to delete the code 'data:image/jpg;base64,')\"]}" http://localhost:8866/predict/ocr_system
```
c. Get resposne(If the call is successful, the following result will be returned)
```
{"msg":"","results":[[{"confidence":0.8403433561325073,"text":"约定","text_region":[[345,377],[641,390],[634,540],[339,528]]},{"confidence":0.8131805658340454,"text":"最终相遇","text_region":[[356,532],[624,530],[624,596],[356,598]]}]],"status":"0"}
```
[English](README.md) | 简体中文
## Docker化部署服务
在日常项目应用中,相信大家一般都会希望能通过Docker技术,把PaddleOCR服务打包成一个镜像,以便在Docker或k8s环境里,快速发布上线使用。
本文将提供一些标准化的代码来实现这样的目标。大家通过如下步骤可以把PaddleOCR项目快速发布成可调用的Restful API服务。(目前暂时先实现了基于HubServing模式的部署,后续作者计划增加PaddleServing模式的部署)
## 1.实施前提准备
需要先完成如下基本组件的安装:
a. Docker环境
b. 显卡驱动和CUDA 10.0+(GPU)
c. NVIDIA Container Toolkit(GPU,Docker 19.03以上版本可以跳过此步)
d. cuDNN 7.6+(GPU)
## 2.制作镜像
a.下载PaddleOCR项目代码
```
git clone https://github.com/PaddlePaddle/PaddleOCR.git
```
b.切换至Dockerfile目录(注:需要区分cpu或gpu版本,下文以cpu为例,gpu版本需要替换一下关键字即可)
```
cd docker/cpu
```
c.生成镜像
```
docker build -t paddleocr:cpu .
```
## 3.启动Docker容器
a. CPU 版本
```
sudo docker run -dp 8866:8866 --name paddle_ocr paddleocr:cpu
```
b. GPU 版本 (通过NVIDIA Container Toolkit)
```
sudo nvidia-docker run -dp 8866:8866 --name paddle_ocr paddleocr:gpu
```
c. GPU 版本 (Docker 19.03以上版本,可以直接用如下命令)
```
sudo docker run -dp 8866:8866 --gpus all --name paddle_ocr paddleocr:gpu
```
d. 检查服务运行情况(出现:Successfully installed ocr_system和Running on http://0.0.0.0:8866/等信息,表示运行成功)
```
docker logs -f paddle_ocr
```
## 4.测试服务
a. 计算待识别图片的Base64编码(如果只是测试一下效果,可以通过免费的在线工具实现,如:http://tool.chinaz.com/tools/imgtobase/)
b. 发送服务请求(可参见sample_request.txt中的值)
```
curl -H "Content-Type:application/json" -X POST --data "{\"images\": [\"填入图片Base64编码(需要删除'data:image/jpg;base64,')\"]}" http://localhost:8866/predict/ocr_system
```
c. 返回结果(如果调用成功,会返回如下结果)
```
{"msg":"","results":[[{"confidence":0.8403433561325073,"text":"约定","text_region":[[345,377],[641,390],[634,540],[339,528]]},{"confidence":0.8131805658340454,"text":"最终相遇","text_region":[[356,532],[624,530],[624,596],[356,598]]}]],"status":"0"}
```
# Version: 1.0.0
FROM hub.baidubce.com/paddlepaddle/paddle:latest-gpu-cuda9.0-cudnn7-dev
# PaddleOCR base on Python3.7
RUN pip3.7 install --upgrade pip -i https://pypi.tuna.tsinghua.edu.cn/simple
RUN python3.7 -m pip install paddlepaddle==1.7.2 -i https://pypi.tuna.tsinghua.edu.cn/simple
RUN pip3.7 install paddlehub --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple
RUN git clone https://gitee.com/PaddlePaddle/PaddleOCR
WORKDIR /PaddleOCR
RUN pip3.7 install -r requirments.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
RUN mkdir -p /PaddleOCR/inference
# Download orc detect model(light version). if you want to change normal version, you can change ch_det_mv3_db_infer to ch_det_r50_vd_db_infer, also remember change det_model_dir in deploy/hubserving/ocr_system/params.py)
ADD https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db_infer.tar /PaddleOCR/inference
RUN tar xf /PaddleOCR/inference/ch_det_mv3_db_infer.tar -C /PaddleOCR/inference
# Download orc recognition model(light version). If you want to change normal version, you can change ch_rec_mv3_crnn_infer to ch_rec_r34_vd_crnn_enhance_infer, also remember change rec_model_dir in deploy/hubserving/ocr_system/params.py)
ADD https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_infer.tar /PaddleOCR/inference
RUN tar xf /PaddleOCR/inference/ch_rec_mv3_crnn_infer.tar -C /PaddleOCR/inference
EXPOSE 8866
CMD ["/bin/bash","-c","export PYTHONPATH=. && hub install deploy/hubserving/ocr_system/ && hub serving start -m ocr_system"]
\ No newline at end of file
# Version: 1.0.0
FROM hub.baidubce.com/paddlepaddle/paddle:latest-gpu-cuda10.0-cudnn7-dev
# PaddleOCR base on Python3.7
RUN pip3.7 install --upgrade pip -i https://pypi.tuna.tsinghua.edu.cn/simple
RUN python3.7 -m pip install paddlepaddle-gpu==1.7.2.post107 -i https://pypi.tuna.tsinghua.edu.cn/simple
RUN pip3.7 install paddlehub --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple
RUN git clone https://gitee.com/PaddlePaddle/PaddleOCR
WORKDIR /home/PaddleOCR
RUN pip3.7 install -r requirments.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
RUN mkdir -p /PaddleOCR/inference
# Download orc detect model(light version). if you want to change normal version, you can change ch_det_mv3_db_infer to ch_det_r50_vd_db_infer, also remember change det_model_dir in deploy/hubserving/ocr_system/params.py)
ADD https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db_infer.tar /PaddleOCR/inference
RUN tar xf /PaddleOCR/inference/ch_det_mv3_db_infer.tar -C /PaddleOCR/inference
# Download orc recognition model(light version). If you want to change normal version, you can change ch_rec_mv3_crnn_infer to ch_rec_r34_vd_crnn_enhance_infer, also remember change rec_model_dir in deploy/hubserving/ocr_system/params.py)
ADD https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_infer.tar /PaddleOCR/inference
RUN tar xf /PaddleOCR/inference/ch_rec_mv3_crnn_infer.tar -C /PaddleOCR/inference
EXPOSE 8866
CMD ["/bin/bash","-c","export PYTHONPATH=. && hub install deploy/hubserving/ocr_system/ && hub serving start -m ocr_system"]
\ No newline at end of file
English | [简体中文](README_cn.md)
## Introduction
Many user hopes package the PaddleOCR service into an docker image, so that it can be quickly released and used in the docker or k8s environment.
This page provide some standardized code to achieve this goal. You can quickly publish the PaddleOCR project into a callable Restful API service through the following steps. (At present, the deployment based on the HubServing mode is implemented first, and author plans to increase the deployment of the PaddleServing mode in the futrue)
## 1. Prerequisites
You need to install the following basic components first:
a. Docker
b. Graphics driver and CUDA 10.0+(GPU)
c. NVIDIA Container Toolkit(GPU,Docker 19.03+ can skip this)
d. cuDNN 7.6+(GPU)
## 2. Build Image
a. Download PaddleOCR sourcecode
```
git clone https://github.com/PaddlePaddle/PaddleOCR.git
```
b. Goto Dockerfile directory(ps:Need to distinguish between cpu and gpu version, the following takes cpu as an example, gpu version needs to replace the keyword)
```
cd docker/cpu
```
c. Build image
```
docker build -t paddleocr:cpu .
```
## 3. Start container
a. CPU version
```
sudo docker run -dp 8866:8866 --name paddle_ocr paddleocr:cpu
```
b. GPU version (base on NVIDIA Container Toolkit)
```
sudo nvidia-docker run -dp 8866:8866 --name paddle_ocr paddleocr:gpu
```
c. GPU version (Docker 19.03++)
```
sudo docker run -dp 8866:8866 --gpus all --name paddle_ocr paddleocr:gpu
```
d. Check service status(If you can see the following statement then it means completed:Successfully installed ocr_system && Running on http://0.0.0.0:8866/)
```
docker logs -f paddle_ocr
```
## 4. Test
a. Calculate the Base64 encoding of the picture to be recognized (if you just test, you can use a free online tool, like:https://freeonlinetools24.com/base64-image/)
b. Post a service request(sample request in sample_request.txt)
```
curl -H "Content-Type:application/json" -X POST --data "{\"images\": [\"Input image Base64 encode(need to delete the code 'data:image/jpg;base64,')\"]}" http://localhost:8866/predict/ocr_system
```
c. Get resposne(If the call is successful, the following result will be returned)
```
{"msg":"","results":[[{"confidence":0.8403433561325073,"text":"约定","text_region":[[345,377],[641,390],[634,540],[339,528]]},{"confidence":0.8131805658340454,"text":"最终相遇","text_region":[[356,532],[624,530],[624,596],[356,598]]}]],"status":"0"}
```
此差异已折叠。
...@@ -21,104 +21,69 @@ import os ...@@ -21,104 +21,69 @@ import os
import sys import sys
import numpy as np import numpy as np
import paddle import paddle
import signal
import random
__dir__ = os.path.dirname(os.path.abspath(__file__)) __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import copy import copy
from paddle.io import DataLoader, DistributedBatchSampler, BatchSampler from paddle.io import Dataset, DataLoader, BatchSampler, DistributedBatchSampler
import paddle.distributed as dist import paddle.distributed as dist
from ppocr.data.imaug import transform, create_operators from ppocr.data.imaug import transform, create_operators
from ppocr.data.simple_dataset import SimpleDataSet
from ppocr.data.lmdb_dataset import LMDBDateSet
__all__ = ['build_dataloader', 'transform', 'create_operators'] __all__ = ['build_dataloader', 'transform', 'create_operators']
def term_mp(sig_num, frame):
""" kill all child processes
"""
pid = os.getpid()
pgid = os.getpgid(os.getpid())
print("main proc {} exit, kill process group " "{}".format(pid, pgid))
os.killpg(pgid, signal.SIGKILL)
def build_dataset(config, global_config): signal.signal(signal.SIGINT, term_mp)
from ppocr.data.dataset import SimpleDataSet, LMDBDateSet signal.signal(signal.SIGTERM, term_mp)
support_dict = ['SimpleDataSet', 'LMDBDateSet']
module_name = config.pop('name') def build_dataloader(config, mode, device):
config = copy.deepcopy(config)
support_dict = ['SimpleDataSet', 'LMDBDateSet']
module_name = config[mode]['dataset']['name']
assert module_name in support_dict, Exception( assert module_name in support_dict, Exception(
'DataSet only support {}'.format(support_dict)) 'DataSet only support {}'.format(support_dict))
assert mode in ['Train', 'Eval', 'Test'], "Mode should be Train, Eval or Test."
dataset = eval(module_name)(config, global_config)
return dataset dataset = eval(module_name)(config, mode)
loader_config = config[mode]['loader']
batch_size = loader_config['batch_size_per_card']
def build_dataloader(config, device, distributed=False, global_config=None): drop_last = loader_config['drop_last']
from ppocr.data.dataset import BatchBalancedDataLoader num_workers = loader_config['num_workers']
config = copy.deepcopy(config) if mode == "Train":
dataset_config = config['dataset'] #Distribute data to multiple cards
batch_sampler = DistributedBatchSampler(
_dataset_list = [] dataset=dataset,
file_list = dataset_config.pop('file_list') batch_size=batch_size,
if len(file_list) == 1: shuffle=False,
ratio_list = [1.0] drop_last=drop_last)
else: else:
ratio_list = dataset_config.pop('ratio_list') #Distribute data to single card
for file in file_list: batch_sampler = BatchSampler(
dataset_config['file_list'] = file dataset=dataset,
_dataset = build_dataset(dataset_config, global_config) batch_size=batch_size,
_dataset_list.append(_dataset) shuffle=False,
data_loader = BatchBalancedDataLoader(_dataset_list, ratio_list, drop_last=drop_last)
distributed, device, config['loader'])
return data_loader, _dataset.info_dict data_loader = DataLoader(
dataset=dataset,
batch_sampler=batch_sampler,
def test_loader(): places=device,
import time num_workers=num_workers,
from tools.program import load_config, ArgsParser return_list=True)
FLAGS = ArgsParser().parse_args() return data_loader
config = load_config(FLAGS.config) #return data_loader, _dataset.info_dict
\ No newline at end of file
place = paddle.CPUPlace()
paddle.disable_static(place)
import time
data_loader, _ = build_dataloader(
config['TRAIN'], place, global_config=config['Global'])
start = time.time()
print(len(data_loader))
for epoch in range(1):
print('epoch {} ****************'.format(epoch))
for i, batch in enumerate(data_loader):
if i > len(data_loader):
break
t = time.time() - start
start = time.time()
print('{}, batch : {} ,time {}'.format(i, len(batch[0]), t))
continue
import matplotlib.pyplot as plt
from matplotlib import pyplot as plt
import cv2
fig = plt.figure()
# # cv2.imwrite('img.jpg',batch[0].numpy()[0].transpose((1,2,0)))
# # cv2.imwrite('bmap.jpg',batch[1].numpy()[0])
# # cv2.imwrite('bmask.jpg',batch[2].numpy()[0])
# # cv2.imwrite('smap.jpg',batch[3].numpy()[0])
# # cv2.imwrite('smask.jpg',batch[4].numpy()[0])
plt.title('img')
plt.imshow(batch[0].numpy()[0].transpose((1, 2, 0)))
# plt.figure()
# plt.title('bmap')
# plt.imshow(batch[1].numpy()[0],cmap='Greys')
# plt.figure()
# plt.title('bmask')
# plt.imshow(batch[2].numpy()[0],cmap='Greys')
# plt.figure()
# plt.title('smap')
# plt.imshow(batch[3].numpy()[0],cmap='Greys')
# plt.figure()
# plt.title('smask')
# plt.imshow(batch[4].numpy()[0],cmap='Greys')
# plt.show()
# break
if __name__ == '__main__':
test_loader()
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import numpy as np
import os
import lmdb
import random
import signal
import paddle
from paddle.io import Dataset, DataLoader, DistributedBatchSampler, BatchSampler
from .imaug import transform, create_operators
from ppocr.utils.logging import get_logger
def term_mp(sig_num, frame):
""" kill all child processes
"""
pid = os.getpid()
pgid = os.getpgid(os.getpid())
print("main proc {} exit, kill process group " "{}".format(pid, pgid))
os.killpg(pgid, signal.SIGKILL)
signal.signal(signal.SIGINT, term_mp)
signal.signal(signal.SIGTERM, term_mp)
class ModeException(Exception):
"""
ModeException
"""
def __init__(self, message='', mode=''):
message += "\nOnly the following 3 modes are supported: " \
"train, valid, test. Given mode is {}".format(mode)
super(ModeException, self).__init__(message)
class SampleNumException(Exception):
"""
SampleNumException
"""
def __init__(self, message='', sample_num=0, batch_size=1):
message += "\nError: The number of the whole data ({}) " \
"is smaller than the batch_size ({}), and drop_last " \
"is turnning on, so nothing will feed in program, " \
"Terminated now. Please reset batch_size to a smaller " \
"number or feed more data!".format(sample_num, batch_size)
super(SampleNumException, self).__init__(message)
def get_file_list(file_list, data_dir, delimiter='\t'):
"""
read label list from file and shuffle the list
Args:
params(dict):
"""
if isinstance(file_list, str):
file_list = [file_list]
data_source_list = []
for file in file_list:
with open(file) as f:
full_lines = [line.strip() for line in f]
for line in full_lines:
try:
img_path, label = line.split(delimiter)
except:
logger = get_logger()
logger.warning('label error in {}'.format(line))
img_path = os.path.join(data_dir, img_path)
data = {'img_path': img_path, 'label': label}
data_source_list.append(data)
return data_source_list
class LMDBDateSet(Dataset):
def __init__(self, config, global_config):
super(LMDBDateSet, self).__init__()
self.data_list = self.load_lmdb_dataset(
config['file_list'], global_config['max_text_length'])
random.shuffle(self.data_list)
self.ops = create_operators(config['transforms'], global_config)
# for rec
character = ''
for op in self.ops:
if hasattr(op, 'character'):
character = getattr(op, 'character')
self.info_dict = {'character': character}
def load_lmdb_dataset(self, data_dir, max_text_length):
self.env = lmdb.open(
data_dir,
max_readers=32,
readonly=True,
lock=False,
readahead=False,
meminit=False)
if not self.env:
print('cannot create lmdb from %s' % (data_dir))
exit(0)
filtered_index_list = []
with self.env.begin(write=False) as txn:
nSamples = int(txn.get('num-samples'.encode()))
self.nSamples = nSamples
for index in range(self.nSamples):
index += 1 # lmdb starts with 1
label_key = 'label-%09d'.encode() % index
label = txn.get(label_key).decode('utf-8')
if len(label) > max_text_length:
# print(f'The length of the label is longer than max_length: length
# {len(label)}, {label} in dataset {self.root}')
continue
# By default, images containing characters which are not in opt.character are filtered.
# You can add [UNK] token to `opt.character` in utils.py instead of this filtering.
filtered_index_list.append(index)
return filtered_index_list
def print_lmdb_sets_info(self, lmdb_sets):
lmdb_info_strs = []
for dataset_idx in range(len(lmdb_sets)):
tmp_str = " %s:%d," % (lmdb_sets[dataset_idx]['dirpath'],
lmdb_sets[dataset_idx]['num_samples'])
lmdb_info_strs.append(tmp_str)
lmdb_info_strs = ''.join(lmdb_info_strs)
logger = get_logger()
logger.info("DataSummary:" + lmdb_info_strs)
return
def __getitem__(self, idx):
idx = self.data_list[idx]
with self.env.begin(write=False) as txn:
label_key = 'label-%09d'.encode() % idx
label = txn.get(label_key)
if label is not None:
label = label.decode('utf-8')
img_key = 'image-%09d'.encode() % idx
imgbuf = txn.get(img_key)
data = {'image': imgbuf, 'label': label}
outs = transform(data, self.ops)
else:
outs = None
if outs is None:
return self.__getitem__(np.random.randint(self.__len__()))
return outs
def __len__(self):
return len(self.data_list)
class SimpleDataSet(Dataset):
def __init__(self, config, global_config):
super(SimpleDataSet, self).__init__()
delimiter = config.get('delimiter', '\t')
self.data_list = get_file_list(config['file_list'], config['data_dir'],
delimiter)
random.shuffle(self.data_list)
self.ops = create_operators(config['transforms'], global_config)
# for rec
character = ''
for op in self.ops:
if hasattr(op, 'character'):
character = getattr(op, 'character')
self.info_dict = {'character': character}
def __getitem__(self, idx):
data = copy.deepcopy(self.data_list[idx])
with open(data['img_path'], 'rb') as f:
img = f.read()
data['image'] = img
outs = transform(data, self.ops)
if outs is None:
return self.__getitem__(np.random.randint(self.__len__()))
return outs
def __len__(self):
return len(self.data_list)
class BatchBalancedDataLoader(object):
def __init__(self,
dataset_list: list,
ratio_list: list,
distributed,
device,
loader_args: dict):
"""
对datasetlist里的dataset按照ratio_list里对应的比例组合,似的每个batch里的数据按按照比例采样的
:param dataset_list: 数据集列表
:param ratio_list: 比例列表
:param loader_args: dataloader的配置
"""
assert sum(ratio_list) == 1 and len(dataset_list) == len(ratio_list)
self.dataset_len = 0
self.data_loader_list = []
self.dataloader_iter_list = []
all_batch_size = loader_args.pop('batch_size')
batch_size_list = list(
map(int, [max(1.0, all_batch_size * x) for x in ratio_list]))
remain_num = all_batch_size - sum(batch_size_list)
batch_size_list[np.argmax(ratio_list)] += remain_num
for _dataset, _batch_size in zip(dataset_list, batch_size_list):
if distributed:
batch_sampler_class = DistributedBatchSampler
else:
batch_sampler_class = BatchSampler
batch_sampler = batch_sampler_class(
dataset=_dataset,
batch_size=_batch_size,
shuffle=loader_args['shuffle'],
drop_last=loader_args['drop_last'], )
_data_loader = DataLoader(
dataset=_dataset,
batch_sampler=batch_sampler,
places=device,
num_workers=loader_args['num_workers'],
return_list=True, )
self.data_loader_list.append(_data_loader)
self.dataloader_iter_list.append(iter(_data_loader))
self.dataset_len += len(_dataset)
def __iter__(self):
return self
def __len__(self):
return min([len(x) for x in self.data_loader_list])
def __next__(self):
batch = []
for i, data_loader_iter in enumerate(self.dataloader_iter_list):
try:
_batch_i = next(data_loader_iter)
batch.append(_batch_i)
except StopIteration:
self.dataloader_iter_list[i] = iter(self.data_loader_list[i])
_batch_i = next(self.dataloader_iter_list[i])
batch.append(_batch_i)
except ValueError:
pass
if len(batch) > 0:
batch_list = []
batch_item_size = len(batch[0])
for i in range(batch_item_size):
cur_item_list = [batch_i[i] for batch_i in batch]
batch_list.append(paddle.concat(cur_item_list, axis=0))
else:
batch_list = batch[0]
return batch_list
def fill_batch(batch):
"""
2020.09.08: The current paddle version only supports returning data with the same length.
Therefore, fill in the batches with inconsistent lengths.
this method is currently only useful for text detection
"""
keys = list(range(len(batch[0])))
v_max_len_dict = {}
for k in keys:
v_max_len_dict[k] = max([len(item[k]) for item in batch])
for item in batch:
length = []
for k in keys:
v = item[k]
length.append(len(v))
assert isinstance(v, np.ndarray)
if len(v) == v_max_len_dict[k]:
continue
try:
tmp_shape = [v_max_len_dict[k] - len(v)] + list(v[0].shape)
except:
a = 1
tmp_array = np.zeros(tmp_shape, dtype=v[0].dtype)
new_array = np.concatenate([v, tmp_array])
item[k] = new_array
item.append(length)
return batch
...@@ -148,6 +148,8 @@ class CTCLabelEncode(BaseRecLabelEncode): ...@@ -148,6 +148,8 @@ class CTCLabelEncode(BaseRecLabelEncode):
text = self.encode(text) text = self.encode(text)
if text is None: if text is None:
return None return None
if len(text) > self.max_text_len:
return None
data['length'] = np.array(len(text)) data['length'] = np.array(len(text))
text = text + [0] * (self.max_text_len - len(text)) text = text + [0] * (self.max_text_len - len(text))
data['label'] = np.array(text) data['label'] = np.array(text)
......
...@@ -29,7 +29,7 @@ class MakeBorderMap(object): ...@@ -29,7 +29,7 @@ class MakeBorderMap(object):
self.thresh_min = thresh_min self.thresh_min = thresh_min
self.thresh_max = thresh_max self.thresh_max = thresh_max
def __call__(self, data: dict) -> dict: def __call__(self, data):
img = data['image'] img = data['image']
text_polys = data['polys'] text_polys = data['polys']
......
...@@ -99,7 +99,7 @@ class ToCHWImage(object): ...@@ -99,7 +99,7 @@ class ToCHWImage(object):
return data return data
class keepKeys(object): class KeepKeys(object):
def __init__(self, keep_keys, **kwargs): def __init__(self, keep_keys, **kwargs):
self.keep_keys = keep_keys self.keep_keys = keep_keys
......
...@@ -50,16 +50,14 @@ class RecResizeImg(object): ...@@ -50,16 +50,14 @@ class RecResizeImg(object):
image_shape, image_shape,
infer_mode=False, infer_mode=False,
character_type='ch', character_type='ch',
use_tps=False,
**kwargs): **kwargs):
self.image_shape = image_shape self.image_shape = image_shape
self.infer_mode = infer_mode self.infer_mode = infer_mode
self.character_type = character_type self.character_type = character_type
self.use_tps = use_tps
def __call__(self, data): def __call__(self, data):
img = data['image'] img = data['image']
if self.infer_mode and self.character_type == "ch" and not self.use_tps: if self.infer_mode and self.character_type == "ch":
norm_img = resize_norm_img_chinese(img, self.image_shape) norm_img = resize_norm_img_chinese(img, self.image_shape)
else: else:
norm_img = resize_norm_img(img, self.image_shape) norm_img = resize_norm_img(img, self.image_shape)
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import numpy as np
import os
import random
import paddle
from paddle.io import Dataset
import time
import lmdb
import cv2
from .imaug import transform, create_operators
from ppocr.utils.logging import get_logger
logger = get_logger()
class LMDBDateSet(Dataset):
def __init__(self, config, mode):
super(LMDBDateSet, self).__init__()
global_config = config['Global']
dataset_config = config[mode]['dataset']
loader_config = config[mode]['loader']
batch_size = loader_config['batch_size_per_card']
data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle']
self.lmdb_sets = self.load_hierarchical_lmdb_dataset(data_dir)
logger.info("Initialize indexs of datasets:%s" % data_dir)
self.data_idx_order_list = self.dataset_traversal()
if self.do_shuffle:
np.random.shuffle(self.data_idx_order_list)
self.ops = create_operators(dataset_config['transforms'], global_config)
# # for rec
# character = ''
# for op in self.ops:
# if hasattr(op, 'character'):
# character = getattr(op, 'character')
# self.info_dict = {'character': character}
def load_hierarchical_lmdb_dataset(self, data_dir):
lmdb_sets = {}
dataset_idx = 0
for dirpath, dirnames, filenames in os.walk(data_dir + '/'):
if not dirnames:
env = lmdb.open(
dirpath,
max_readers=32,
readonly=True,
lock=False,
readahead=False,
meminit=False)
txn = env.begin(write=False)
num_samples = int(txn.get('num-samples'.encode()))
lmdb_sets[dataset_idx] = {"dirpath":dirpath, "env":env, \
"txn":txn, "num_samples":num_samples}
dataset_idx += 1
return lmdb_sets
def dataset_traversal(self):
lmdb_num = len(self.lmdb_sets)
total_sample_num = 0
for lno in range(lmdb_num):
total_sample_num += self.lmdb_sets[lno]['num_samples']
data_idx_order_list = np.zeros((total_sample_num, 2))
beg_idx = 0
for lno in range(lmdb_num):
tmp_sample_num = self.lmdb_sets[lno]['num_samples']
end_idx = beg_idx + tmp_sample_num
data_idx_order_list[beg_idx:end_idx, 0] = lno
data_idx_order_list[beg_idx:end_idx, 1] \
= list(range(tmp_sample_num))
data_idx_order_list[beg_idx:end_idx, 1] += 1
beg_idx = beg_idx + tmp_sample_num
return data_idx_order_list
def get_img_data(self, value):
"""get_img_data"""
if not value:
return None
imgdata = np.frombuffer(value, dtype='uint8')
if imgdata is None:
return None
imgori = cv2.imdecode(imgdata, 1)
if imgori is None:
return None
return imgori
def get_lmdb_sample_info(self, txn, index):
label_key = 'label-%09d'.encode() % index
label = txn.get(label_key)
if label is None:
return None
label = label.decode('utf-8')
img_key = 'image-%09d'.encode() % index
imgbuf = txn.get(img_key)
return imgbuf, label
def __getitem__(self, idx):
lmdb_idx, file_idx = self.data_idx_order_list[idx]
lmdb_idx = int(lmdb_idx)
file_idx = int(file_idx)
sample_info = self.get_lmdb_sample_info(
self.lmdb_sets[lmdb_idx]['txn'], file_idx)
if sample_info is None:
return self.__getitem__(np.random.randint(self.__len__()))
img, label = sample_info
data = {'image': img, 'label': label}
outs = transform(data, self.ops)
if outs is None:
return self.__getitem__(np.random.randint(self.__len__()))
return outs
def __len__(self):
return self.data_idx_order_list.shape[0]
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import numpy as np
import os
import random
import paddle
from paddle.io import Dataset
import time
from .imaug import transform, create_operators
from ppocr.utils.logging import get_logger
logger = get_logger()
class SimpleDataSet(Dataset):
def __init__(self, config, mode):
super(SimpleDataSet, self).__init__()
global_config = config['Global']
dataset_config = config[mode]['dataset']
loader_config = config[mode]['loader']
batch_size = loader_config['batch_size_per_card']
self.delimiter = dataset_config.get('delimiter', '\t')
label_file_list = dataset_config.pop('label_file_list')
data_source_num = len(label_file_list)
if data_source_num == 1:
ratio_list = [1.0]
else:
ratio_list = dataset_config.pop('ratio_list')
assert sum(ratio_list) == 1, "The sum of the ratio_list should be 1."
assert len(ratio_list) == data_source_num, "The length of ratio_list should be the same as the file_list."
self.data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle']
logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines_list, data_num_list = self.get_image_info_list(
label_file_list)
self.data_idx_order_list = self.dataset_traversal(
data_num_list, ratio_list, batch_size)
self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config)
def get_image_info_list(self, file_list):
if isinstance(file_list, str):
file_list = [file_list]
data_lines_list = []
data_num_list = []
for file in file_list:
with open(file, "rb") as f:
lines = f.readlines()
data_lines_list.append(lines)
data_num_list.append(len(lines))
return data_lines_list, data_num_list
def dataset_traversal(self, data_num_list, ratio_list, batch_size):
select_num_list = []
dataset_num = len(data_num_list)
for dno in range(dataset_num):
select_num = round(batch_size * ratio_list[dno])
select_num = max(select_num, 1)
select_num_list.append(select_num)
data_idx_order_list = []
cur_index_sets = [0] * dataset_num
while True:
finish_read_num = 0
for dataset_idx in range(dataset_num):
cur_index = cur_index_sets[dataset_idx]
if cur_index >= data_num_list[dataset_idx]:
finish_read_num += 1
else:
select_num = select_num_list[dataset_idx]
for sno in range(select_num):
cur_index = cur_index_sets[dataset_idx]
if cur_index >= data_num_list[dataset_idx]:
break
data_idx_order_list.append((
dataset_idx, cur_index))
cur_index_sets[dataset_idx] += 1
if finish_read_num == dataset_num:
break
return data_idx_order_list
def shuffle_data_random(self):
if self.do_shuffle:
for dno in range(len(self.data_lines_list)):
random.shuffle(self.data_lines_list[dno])
return
def __getitem__(self, idx):
dataset_idx, file_idx = self.data_idx_order_list[idx]
data_line = self.data_lines_list[dataset_idx][file_idx]
data_line = data_line.decode('utf-8')
substr = data_line.strip("\n").split(self.delimiter)
file_name = substr[0]
label = substr[1]
img_path = os.path.join(self.data_dir, file_name)
data = {'img_path': img_path, 'label': label}
with open(data['img_path'], 'rb') as f:
img = f.read()
data['image'] = img
outs = transform(data, self.ops)
if outs is None:
return self.__getitem__(np.random.randint(self.__len__()))
return outs
def __len__(self):
return len(self.data_idx_order_list)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from .losses import build_loss
__all__ = ['build_model', 'build_loss']
def build_model(config):
from .architectures import Model
config = copy.deepcopy(config)
module_class = Model(config)
return module_class
...@@ -12,5 +12,13 @@ ...@@ -12,5 +12,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .model import Model import copy
__all__ = ['Model']
\ No newline at end of file __all__ = ['build_model']
def build_model(config):
from .base_model import BaseModel
config = copy.deepcopy(config)
module_class = BaseModel(config)
return module_class
\ No newline at end of file
...@@ -15,38 +15,29 @@ from __future__ import absolute_import ...@@ -15,38 +15,29 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import os, sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append('/home/zhoujun20/PaddleOCR')
from paddle import nn from paddle import nn
from ppocr.modeling.transform import build_transform
from ppocr.modeling.backbones import build_backbone from ppocr.modeling.backbones import build_backbone
from ppocr.modeling.necks import build_neck from ppocr.modeling.necks import build_neck
from ppocr.modeling.heads import build_head from ppocr.modeling.heads import build_head
__all__ = ['Model'] __all__ = ['BaseModel']
class BaseModel(nn.Layer):
class Model(nn.Layer):
def __init__(self, config): def __init__(self, config):
""" """
Detection module for OCR. the module for OCR.
args: args:
config (dict): the super parameters for module. config (dict): the super parameters for module.
""" """
super(Model, self).__init__() super(BaseModel, self).__init__()
algorithm = config['algorithm']
self.type = config['type']
self.model_name = '{}_{}'.format(self.type, algorithm)
in_channels = config.get('in_channels', 3) in_channels = config.get('in_channels', 3)
model_type = config['model_type']
# build transfrom, # build transfrom,
# for rec, transfrom can be TPS,None # for rec, transfrom can be TPS,None
# for det and cls, transfrom shoule to be None, # for det and cls, transfrom shoule to be None,
# if you make model differently, you can use transfrom in det and cls # if you make model differently, you can use transfrom in det and cls
if 'Transform' not in config or config['Transform'] is None: if 'Transform' not in config or config['Transform'] is None:
self.use_transform = False self.use_transform = False
else: else:
...@@ -57,9 +48,9 @@ class Model(nn.Layer): ...@@ -57,9 +48,9 @@ class Model(nn.Layer):
# build backbone, backbone is need for del, rec and cls # build backbone, backbone is need for del, rec and cls
config["Backbone"]['in_channels'] = in_channels config["Backbone"]['in_channels'] = in_channels
self.backbone = build_backbone(config["Backbone"], self.type) self.backbone = build_backbone(config["Backbone"], model_type)
in_channels = self.backbone.out_channels in_channels = self.backbone.out_channels
# build neck # build neck
# for rec, neck can be cnn,rnn or reshape(None) # for rec, neck can be cnn,rnn or reshape(None)
# for det, neck can be FPN, BIFPN and so on. # for det, neck can be FPN, BIFPN and so on.
...@@ -71,6 +62,7 @@ class Model(nn.Layer): ...@@ -71,6 +62,7 @@ class Model(nn.Layer):
config['Neck']['in_channels'] = in_channels config['Neck']['in_channels'] = in_channels
self.neck = build_neck(config['Neck']) self.neck = build_neck(config['Neck'])
in_channels = self.neck.out_channels in_channels = self.neck.out_channels
# # build head, head is need for det, rec and cls # # build head, head is need for det, rec and cls
config["Head"]['in_channels'] = in_channels config["Head"]['in_channels'] = in_channels
self.head = build_head(config["Head"]) self.head = build_head(config["Head"])
......
...@@ -19,7 +19,6 @@ def build_backbone(config, model_type): ...@@ -19,7 +19,6 @@ def build_backbone(config, model_type):
if model_type == 'det': if model_type == 'det':
from .det_mobilenet_v3 import MobileNetV3 from .det_mobilenet_v3 import MobileNetV3
from .det_resnet_vd import ResNet from .det_resnet_vd import ResNet
support_dict = ['MobileNetV3', 'ResNet', 'ResNet_SAST'] support_dict = ['MobileNetV3', 'ResNet', 'ResNet_SAST']
elif model_type == 'rec': elif model_type == 'rec':
from .rec_mobilenet_v3 import MobileNetV3 from .rec_mobilenet_v3 import MobileNetV3
......
...@@ -130,7 +130,6 @@ class MobileNetV3(nn.Layer): ...@@ -130,7 +130,6 @@ class MobileNetV3(nn.Layer):
if_act=True, if_act=True,
act='hard_swish', act='hard_swish',
name='conv_last')) name='conv_last'))
self.stages.append(nn.Sequential(*block_list)) self.stages.append(nn.Sequential(*block_list))
self.out_channels.append(make_divisible(scale * cls_ch_squeeze)) self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
for i, stage in enumerate(self.stages): for i, stage in enumerate(self.stages):
...@@ -275,4 +274,4 @@ class SEModule(nn.Layer): ...@@ -275,4 +274,4 @@ class SEModule(nn.Layer):
outputs = F.relu(outputs) outputs = F.relu(outputs)
outputs = self.conv2(outputs) outputs = self.conv2(outputs)
outputs = F.hard_sigmoid(outputs) outputs = F.hard_sigmoid(outputs)
return inputs * outputs return inputs * outputs
\ No newline at end of file
...@@ -20,8 +20,8 @@ def build_head(config): ...@@ -20,8 +20,8 @@ def build_head(config):
from .det_db_head import DBHead from .det_db_head import DBHead
# rec head # rec head
from .rec_ctc_head import CTC from .rec_ctc_head import CTCHead
support_dict = ['DBHead', 'CTC'] support_dict = ['DBHead', 'CTCHead']
module_name = config.pop('name') module_name = config.pop('name')
assert module_name in support_dict, Exception('head only support {}'.format( assert module_name in support_dict, Exception('head only support {}'.format(
......
...@@ -33,10 +33,9 @@ def get_para_bias_attr(l2_decay, k, name): ...@@ -33,10 +33,9 @@ def get_para_bias_attr(l2_decay, k, name):
regularizer=regularizer, initializer=initializer, name=name + "_b_attr") regularizer=regularizer, initializer=initializer, name=name + "_b_attr")
return [weight_attr, bias_attr] return [weight_attr, bias_attr]
class CTCHead(nn.Layer):
class CTC(nn.Layer): def __init__(self, in_channels, out_channels, fc_decay=0.0004, **kwargs):
def __init__(self, in_channels, out_channels, fc_decay=1e-5, **kwargs): super(CTCHead, self).__init__()
super(CTC, self).__init__()
weight_attr, bias_attr = get_para_bias_attr( weight_attr, bias_attr = get_para_bias_attr(
l2_decay=fc_decay, k=in_channels, name='ctc_fc') l2_decay=fc_decay, k=in_channels, name='ctc_fc')
self.fc = nn.Linear( self.fc = nn.Linear(
......
...@@ -14,11 +14,10 @@ ...@@ -14,11 +14,10 @@
__all__ = ['build_neck'] __all__ = ['build_neck']
def build_neck(config): def build_neck(config):
from .fpn import FPN from .db_fpn import DBFPN
from .rnn import SequenceEncoder from .rnn import SequenceEncoder
support_dict = ['FPN', 'SequenceEncoder'] support_dict = ['DBFPN', 'SequenceEncoder']
module_name = config.pop('name') module_name = config.pop('name')
assert module_name in support_dict, Exception('neck only support {}'.format( assert module_name in support_dict, Exception('neck only support {}'.format(
......
...@@ -22,9 +22,9 @@ import paddle.nn.functional as F ...@@ -22,9 +22,9 @@ import paddle.nn.functional as F
from paddle import ParamAttr from paddle import ParamAttr
class FPN(nn.Layer): class DBFPN(nn.Layer):
def __init__(self, in_channels, out_channels, **kwargs): def __init__(self, in_channels, out_channels, **kwargs):
super(FPN, self).__init__() super(DBFPN, self).__init__()
self.out_channels = out_channels self.out_channels = out_channels
weight_attr = paddle.nn.initializer.MSRA(uniform=False) weight_attr = paddle.nn.initializer.MSRA(uniform=False)
......
...@@ -76,8 +76,7 @@ class SequenceEncoder(nn.Layer): ...@@ -76,8 +76,7 @@ class SequenceEncoder(nn.Layer):
'fc': EncoderWithFC, 'fc': EncoderWithFC,
'rnn': EncoderWithRNN 'rnn': EncoderWithRNN
} }
assert encoder_type in support_encoder_dict, '{} must in {}'.format( assert encoder_type in support_encoder_dict, '{} must in {}'.format(encoder_type, support_encoder_dict.keys())
encoder_type, support_encoder_dict.keys())
self.encoder = support_encoder_dict[encoder_type]( self.encoder = support_encoder_dict[encoder_type](
self.encoder_reshape.out_channels, hidden_size) self.encoder_reshape.out_channels, hidden_size)
......
...@@ -50,6 +50,8 @@ def build_optimizer(config, epochs, step_each_epoch, parameters): ...@@ -50,6 +50,8 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
# step3 build optimizer # step3 build optimizer
optim_name = config.pop('name') optim_name = config.pop('name')
# Regularization is invalid. The bug will be fixed in paddle-rc. The param is
# weight_decay.
optim = getattr(optimizer, optim_name)(learning_rate=lr, optim = getattr(optimizer, optim_name)(learning_rate=lr,
regularization=reg, regularization=reg,
**config) **config)
......
...@@ -40,8 +40,8 @@ class Momentum(object): ...@@ -40,8 +40,8 @@ class Momentum(object):
opt = optim.Momentum( opt = optim.Momentum(
learning_rate=self.learning_rate, learning_rate=self.learning_rate,
momentum=self.momentum, momentum=self.momentum,
parameters=self.weight_decay, parameters=parameters,
weight_decay=parameters) weight_decay=self.weight_decay)
return opt return opt
......
...@@ -24,8 +24,8 @@ __all__ = ['build_post_process'] ...@@ -24,8 +24,8 @@ __all__ = ['build_post_process']
def build_post_process(config, global_config=None): def build_post_process(config, global_config=None):
from .db_postprocess import DBPostProcess from .db_postprocess import DBPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode from .rec_postprocess import CTCLabelDecode, AttnLabelDecode
support_dict = ['DBPostProcess', 'CTCLabelDecode', 'AttnLabelDecode'] support_dict = ['DBPostProcess', 'CTCLabelDecode', 'AttnLabelDecode']
config = copy.deepcopy(config) config = copy.deepcopy(config)
......
...@@ -46,7 +46,7 @@ def load_dygraph_pretrain( ...@@ -46,7 +46,7 @@ def load_dygraph_pretrain(
model, model,
logger, logger,
path=None, path=None,
load_static_weights=False, ): load_static_weights=False):
if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')): if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
raise ValueError("Model pretrain path {} does not " raise ValueError("Model pretrain path {} does not "
"exists.".format(path)) "exists.".format(path))
...@@ -110,21 +110,20 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None): ...@@ -110,21 +110,20 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
logger.info("resume from {}".format(checkpoints)) logger.info("resume from {}".format(checkpoints))
elif pretrained_model: elif pretrained_model:
load_static_weights = gloabl_config.get('load_static_weights', False) load_static_weights = gloabl_config.get('load_static_weights', False)
if pretrained_model: if not isinstance(pretrained_model, list):
if not isinstance(pretrained_model, list): pretrained_model = [pretrained_model]
pretrained_model = [pretrained_model] if not isinstance(load_static_weights, list):
if not isinstance(load_static_weights, list): load_static_weights = [load_static_weights] * len(
load_static_weights = [load_static_weights] * len( pretrained_model)
pretrained_model) for idx, pretrained in enumerate(pretrained_model):
for idx, pretrained in enumerate(pretrained_model): load_static = load_static_weights[idx]
load_static = load_static_weights[idx] load_dygraph_pretrain(
load_dygraph_pretrain( model,
model, logger,
logger, path=pretrained,
path=pretrained, load_static_weights=load_static)
load_static_weights=load_static) logger.info("load pretrained model from {}".format(
logger.info("load pretrained model from {}".format( pretrained_model))
pretrained_model))
else: else:
logger.info('train from scratch') logger.info('train from scratch')
return best_model_dict return best_model_dict
......
...@@ -28,7 +28,10 @@ from argparse import ArgumentParser, RawDescriptionHelpFormatter ...@@ -28,7 +28,10 @@ from argparse import ArgumentParser, RawDescriptionHelpFormatter
from ppocr.utils.stats import TrainingStats from ppocr.utils.stats import TrainingStats
from ppocr.utils.save_load import save_model from ppocr.utils.save_load import save_model
from ppocr.utils.utility import print_dict
from ppocr.utils.logging import get_logger
from ppocr.data import build_dataloader
import numpy as np
class ArgsParser(ArgumentParser): class ArgsParser(ArgumentParser):
def __init__(self): def __init__(self):
...@@ -136,18 +139,18 @@ def check_gpu(use_gpu): ...@@ -136,18 +139,18 @@ def check_gpu(use_gpu):
def train(config, def train(config,
train_dataloader,
valid_dataloader,
device,
model, model,
loss_class, loss_class,
optimizer, optimizer,
lr_scheduler, lr_scheduler,
train_dataloader,
valid_dataloader,
post_process_class, post_process_class,
eval_class, eval_class,
pre_best_model_dict, pre_best_model_dict,
logger, logger,
vdl_writer=None): vdl_writer=None):
global_step = 0
cal_metric_during_train = config['Global'].get('cal_metric_during_train', cal_metric_during_train = config['Global'].get('cal_metric_during_train',
False) False)
...@@ -156,6 +159,7 @@ def train(config, ...@@ -156,6 +159,7 @@ def train(config,
print_batch_step = config['Global']['print_batch_step'] print_batch_step = config['Global']['print_batch_step']
eval_batch_step = config['Global']['eval_batch_step'] eval_batch_step = config['Global']['eval_batch_step']
global_step = 0
start_eval_step = 0 start_eval_step = 0
if type(eval_batch_step) == list and len(eval_batch_step) >= 2: if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
start_eval_step = eval_batch_step[0] start_eval_step = eval_batch_step[0]
...@@ -179,14 +183,15 @@ def train(config, ...@@ -179,14 +183,15 @@ def train(config,
start_epoch = 0 start_epoch = 0
for epoch in range(start_epoch, epoch_num): for epoch in range(start_epoch, epoch_num):
if epoch > 0:
train_loader = build_dataloader(config, 'Train', device)
for idx, batch in enumerate(train_dataloader): for idx, batch in enumerate(train_dataloader):
if idx >= len(train_dataloader): if idx >= len(train_dataloader):
break break
if not isinstance(lr_scheduler, float):
lr_scheduler.step()
lr = optimizer.get_lr() lr = optimizer.get_lr()
t1 = time.time() t1 = time.time()
batch = [paddle.to_variable(x) for x in batch] batch = [paddle.to_tensor(x) for x in batch]
images = batch[0] images = batch[0]
preds = model(images) preds = model(images)
loss = loss_class(preds, batch) loss = loss_class(preds, batch)
...@@ -199,6 +204,8 @@ def train(config, ...@@ -199,6 +204,8 @@ def train(config,
avg_loss.backward() avg_loss.backward()
optimizer.step() optimizer.step()
optimizer.clear_grad() optimizer.clear_grad()
if not isinstance(lr_scheduler, float):
lr_scheduler.step()
# logger and visualdl # logger and visualdl
stats = {k: v.numpy().mean() for k, v in loss.items()} stats = {k: v.numpy().mean() for k, v in loss.items()}
...@@ -228,8 +235,8 @@ def train(config, ...@@ -228,8 +235,8 @@ def train(config,
# eval # eval
if global_step > start_eval_step and \ if global_step > start_eval_step and \
(global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0: (global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
cur_metirc = eval(model, valid_dataloader, post_process_class, cur_metirc = eval(model, valid_dataloader,
eval_class) post_process_class, eval_class, logger, print_batch_step)
cur_metirc_str = 'cur metirc, {}'.format(', '.join( cur_metirc_str = 'cur metirc, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in cur_metirc.items()])) ['{}: {}'.format(k, v) for k, v in cur_metirc.items()]))
logger.info(cur_metirc_str) logger.info(cur_metirc_str)
...@@ -291,12 +298,14 @@ def train(config, ...@@ -291,12 +298,14 @@ def train(config,
return return
def eval(model, valid_dataloader, post_process_class, eval_class): def eval(model, valid_dataloader,
post_process_class, eval_class,
logger, print_batch_step):
model.eval() model.eval()
with paddle.no_grad(): with paddle.no_grad():
total_frame = 0.0 total_frame = 0.0
total_time = 0.0 total_time = 0.0
pbar = tqdm(total=len(valid_dataloader), desc='eval model: ') # pbar = tqdm(total=len(valid_dataloader), desc='eval model:')
for idx, batch in enumerate(valid_dataloader): for idx, batch in enumerate(valid_dataloader):
if idx >= len(valid_dataloader): if idx >= len(valid_dataloader):
break break
...@@ -310,11 +319,14 @@ def eval(model, valid_dataloader, post_process_class, eval_class): ...@@ -310,11 +319,14 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
total_time += time.time() - start total_time += time.time() - start
# Evaluate the results of the current batch # Evaluate the results of the current batch
eval_class(post_result, batch) eval_class(post_result, batch)
pbar.update(1) # pbar.update(1)
total_frame += len(images) total_frame += len(images)
if idx % print_batch_step == 0:
logger.info('tackling images for eval: {}/{}'.format(
idx, len(valid_dataloader)))
# Get final metirc,eg. acc or hmean # Get final metirc,eg. acc or hmean
metirc = eval_class.get_metric() metirc = eval_class.get_metric()
pbar.close() # pbar.close()
model.train() model.train()
metirc['fps'] = total_frame / total_time metirc['fps'] = total_frame / total_time
return metirc return metirc
...@@ -336,4 +348,25 @@ def preprocess(): ...@@ -336,4 +348,25 @@ def preprocess():
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu' device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
device = paddle.set_device(device) device = paddle.set_device(device)
return device, config
config['Global']['distributed'] = dist.get_world_size() != 1
paddle.disable_static(device)
# save_config
save_model_dir = config['Global']['save_model_dir']
os.makedirs(save_model_dir, exist_ok=True)
with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
yaml.dump(dict(config), f, default_flow_style=False, sort_keys=False)
logger = get_logger(log_file='{}/train.log'.format(save_model_dir))
if config['Global']['use_visualdl']:
from visualdl import LogWriter
vdl_writer_path = '{}/vdl/'.format(save_model_dir)
os.makedirs(vdl_writer_path, exist_ok=True)
vdl_writer = LogWriter(logdir=vdl_writer_path)
else:
vdl_writer = None
print_dict(config, logger)
logger.info('train with paddle {} and device {}'.format(paddle.__version__,
device))
return config, device, logger, vdl_writer
...@@ -31,7 +31,8 @@ paddle.manual_seed(2) ...@@ -31,7 +31,8 @@ paddle.manual_seed(2)
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
from ppocr.data import build_dataloader from ppocr.data import build_dataloader
from ppocr.modeling import build_model, build_loss from ppocr.modeling.architectures import build_model
from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer from ppocr.optimizer import build_optimizer
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric from ppocr.metrics import build_metric
...@@ -48,95 +49,76 @@ def main(config, device, logger, vdl_writer): ...@@ -48,95 +49,76 @@ def main(config, device, logger, vdl_writer):
dist.init_parallel_env() dist.init_parallel_env()
global_config = config['Global'] global_config = config['Global']
# build dataloader # build dataloader
train_loader, train_info_dict = build_dataloader( train_dataloader = build_dataloader(config, 'Train', device)
config['TRAIN'], device, global_config['distributed'], global_config) if config['Eval']:
if config['EVAL']: valid_dataloader = build_dataloader(config, 'Eval', device)
eval_loader, _ = build_dataloader(config['EVAL'], device, False,
global_config)
else: else:
eval_loader = None valid_dataloader = None
# build post process # build post process
post_process_class = build_post_process(config['PostProcess'], post_process_class = build_post_process(
global_config) config['PostProcess'], global_config)
# build model # build model
# for rec algorithm #for rec algorithm
if hasattr(post_process_class, 'character'): if hasattr(post_process_class, 'character'):
config['Architecture']["Head"]['out_channels'] = len( char_num = len(getattr(post_process_class, 'character'))
getattr(post_process_class, 'character')) config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
if config['Global']['distributed']: if config['Global']['distributed']:
model = paddle.DataParallel(model) model = paddle.DataParallel(model)
# build loss
loss_class = build_loss(config['Loss'])
# build optim # build optim
optimizer, lr_scheduler = build_optimizer( optimizer, lr_scheduler = build_optimizer(config['Optimizer'],
config['Optimizer'],
epochs=config['Global']['epoch_num'], epochs=config['Global']['epoch_num'],
step_each_epoch=len(train_loader), step_each_epoch=len(train_dataloader),
parameters=model.parameters()) parameters=model.parameters())
best_model_dict = init_model(config, model, logger, optimizer)
# build loss
loss_class = build_loss(config['Loss'])
# build metric # build metric
eval_class = build_metric(config['Metric']) eval_class = build_metric(config['Metric'])
# load pretrain model
pre_best_model_dict = init_model(config, model, logger, optimizer)
# start train # start train
program.train(config, model, loss_class, optimizer, lr_scheduler, program.train(config,
train_loader, eval_loader, post_process_class, eval_class, train_dataloader,
best_model_dict, logger, vdl_writer) valid_dataloader,
device,
model,
def test_reader(config, place, logger, global_config): loss_class,
train_loader, _ = build_dataloader( optimizer,
config['TRAIN'], place, global_config=global_config) lr_scheduler,
post_process_class,
eval_class,
pre_best_model_dict,
logger,
vdl_writer)
def test_reader(config, device, logger):
loader = build_dataloader(config, 'Train', device)
# loader = build_dataloader(config, 'Eval', device)
import time import time
starttime = time.time() starttime = time.time()
count = 0 count = 0
try: try:
for data in train_loader: for data in loader():
count += 1 count += 1
if count % 1 == 0: if count % 1 == 0:
batch_time = time.time() - starttime batch_time = time.time() - starttime
starttime = time.time() starttime = time.time()
logger.info("reader: {}, {}, {}".format( logger.info("reader: {}, {}, {}".format(count, len(data), batch_time))
count, len(data[0]), batch_time))
except Exception as e: except Exception as e:
import traceback
traceback.print_exc()
logger.info(e) logger.info(e)
logger.info("finish reader: {}, Success!".format(count)) logger.info("finish reader: {}, Success!".format(count))
def dis_main():
device, config = program.preprocess()
config['Global']['distributed'] = dist.get_world_size() != 1
paddle.disable_static(device)
# save_config
os.makedirs(config['Global']['save_model_dir'], exist_ok=True)
with open(
os.path.join(config['Global']['save_model_dir'], 'config.yml'),
'w') as f:
yaml.dump(dict(config), f, default_flow_style=False, sort_keys=False)
logger = get_logger(
log_file='{}/train.log'.format(config['Global']['save_model_dir']))
if config['Global']['use_visualdl']:
from visualdl import LogWriter
vdl_writer = LogWriter(logdir=config['Global']['save_model_dir'])
else:
vdl_writer = None
print_dict(config, logger)
logger.info('train with paddle {} and device {}'.format(paddle.__version__,
device))
main(config, device, logger, vdl_writer)
# test_reader(config, device, logger, config['Global'])
if __name__ == '__main__': if __name__ == '__main__':
# main() config, device, logger, vdl_writer = program.preprocess()
# dist.spawn(dis_main, nprocs=2, selelcted_gpus='6,7') main(config, device, logger, vdl_writer)
dis_main() # test_reader(config, device, logger)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册