提交 fa675f89 编写于 作者: D dyning

updata structure of dygraph

上级 7d09cd19
......@@ -6,29 +6,19 @@ Global:
save_model_dir: ./output/db_mv3/
save_epoch_step: 1200
# evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step: 8
eval_batch_step: [4000, 5000]
# if pretrained_model is saved in static mode, load_static_weights must set to True
load_static_weights: True
cal_metric_during_train: False
pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
checkpoints:
checkpoints: #./output/det_db_0.001_DiceLoss_256_pp_config_2.0b_4gpu/best_accuracy
save_inference_dir:
use_visualdl: True
use_visualdl: False
infer_img: doc/imgs_en/img_10.jpg
save_res_path: ./output/det_db/predicts_db.txt
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
learning_rate:
lr: 0.001
regularizer:
name: 'L2'
factor: 0
Architecture:
type: det
model_type: det
algorithm: DB
Transform:
Backbone:
......@@ -36,7 +26,7 @@ Architecture:
scale: 0.5
model_name: large
Neck:
name: FPN
name: DBFPN
out_channels: 256
Head:
name: DBHead
......@@ -49,6 +39,18 @@ Loss:
alpha: 5
beta: 10
ohem_ratio: 3
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
learning_rate:
# name: Cosine
lr: 0.001
# warmup_epoch: 0
regularizer:
name: 'L2'
factor: 0
PostProcess:
name: DBPostProcess
......@@ -61,13 +63,13 @@ Metric:
name: DetMetric
main_indicator: hmean
TRAIN:
Train:
dataset:
name: SimpleDataSet
data_dir: ./detection/
file_list:
- ./detection/train_icdar2015_label.txt # dataset1
ratio_list: [1.0]
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
ratio_list: [0.5]
transforms:
- DecodeImage: # load image
img_mode: BGR
......@@ -76,10 +78,10 @@ TRAIN:
- IaaAugment:
augmenter_args:
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
- { 'type': Affine, 'args': { 'rotate': [ -10,10 ] } }
- { 'type': Resize,'args': { 'size': [ 0.5,3 ] } }
- { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
- { 'type': Resize, 'args': { 'size': [0.5, 3] } }
- EastRandomCropData:
size: [ 640,640 ]
size: [640, 640]
max_tries: 50
keep_ratio: true
- MakeBorderMap:
......@@ -91,41 +93,41 @@ TRAIN:
min_text_size: 8
- NormalizeImage:
scale: 1./255.
mean: [ 0.485, 0.456, 0.406 ]
std: [ 0.229, 0.224, 0.225 ]
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- keepKeys:
keep_keys: ['image','threshold_map','threshold_mask','shrink_map','shrink_mask'] # dataloader will return list in this order
- KeepKeys:
keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
loader:
shuffle: True
drop_last: False
batch_size: 16
batch_size_per_card: 4
num_workers: 8
EVAL:
Eval:
dataset:
name: SimpleDataSet
data_dir: ./detection/
file_list:
- ./detection/test_icdar2015_label.txt
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- DetLabelEncode: # Class handling label
- DetResizeForTest:
image_shape: [736,1280]
image_shape: [736, 1280]
- NormalizeImage:
scale: 1./255.
mean: [ 0.485, 0.456, 0.406 ]
std: [ 0.229, 0.224, 0.225 ]
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- keepKeys:
keep_keys: ['image','shape','polys','ignore_tags']
- KeepKeys:
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
loader:
shuffle: False
drop_last: False
batch_size: 1 # must be 1
num_workers: 8
\ No newline at end of file
batch_size_per_card: 1 # must be 1
num_workers: 2
\ No newline at end of file
Global:
use_gpu: false
epoch_num: 500
use_gpu: true
epoch_num: 72
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec/mv3_none_bilstm_ctc/
save_epoch_step: 500
save_epoch_step: 3
# evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step: 127
eval_batch_step: [0, 1000]
# if pretrained_model is saved in static mode, load_static_weights must set to True
load_static_weights: True
cal_metric_during_train: True
pretrained_model:
checkpoints:
......@@ -16,12 +15,14 @@ Global:
use_visualdl: False
infer_img: doc/imgs_words/ch/word_1.jpg
# for data or label process
max_text_length: 80
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
character_type: 'ch'
use_space_char: False
character_dict_path:
character_type: en
max_text_length: 25
loss_type: ctc
infer_mode: False
use_tps: False
# use_space_char: True
# use_tps: False
Optimizer:
......@@ -29,27 +30,26 @@ Optimizer:
beta1: 0.9
beta2: 0.999
learning_rate:
lr: 0.001
lr: 0.0005
regularizer:
name: 'L2'
factor: 0.00001
Architecture:
type: rec
model_type: rec
algorithm: CRNN
Transform:
Backbone:
name: MobileNetV3
scale: 0.5
model_name: small
small_stride: [ 1, 2, 2, 2 ]
model_name: large
Neck:
name: SequenceEncoder
encoder_type: fc
encoder_type: rnn
hidden_size: 96
Head:
name: CTC
fc_decay: 0.00001
name: CTCHead
fc_decay: 0.0004
Loss:
name: CTCLoss
......@@ -61,46 +61,40 @@ Metric:
name: RecMetric
main_indicator: acc
TRAIN:
Train:
dataset:
name: SimpleDataSet
data_dir: ./rec
file_list:
- ./rec/train.txt # dataset1
ratio_list: [ 0.4,0.6 ]
name: LMDBDateSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- CTCLabelEncode: # Class handling label
- RecAug:
- RecResizeImg:
image_shape: [ 3,32,320 ]
- keepKeys:
keep_keys: [ 'image','label','length' ] # dataloader will return list in this order
image_shape: [3, 32, 100]
- KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader:
batch_size: 256
shuffle: True
batch_size_per_card: 256
shuffle: False
drop_last: True
num_workers: 8
EVAL:
Eval:
dataset:
name: SimpleDataSet
data_dir: ./rec
file_list:
- ./rec/val.txt
name: LMDBDateSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- CTCLabelEncode: # Class handling label
- RecResizeImg:
image_shape: [ 3,32,320 ]
- keepKeys:
keep_keys: [ 'image','label','length' ] # dataloader will return list in this order
image_shape: [3, 32, 100]
- KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size: 256
num_workers: 8
batch_size_per_card: 256
num_workers: 2
Global:
use_gpu: true
use_gpu: false
epoch_num: 500
log_smooth_window: 20
print_batch_step: 1
print_batch_step: 10
save_model_dir: ./output/rec/mv3_none_bilstm_ctc/
save_epoch_step: 500
# evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step: 1016
eval_batch_step: 127
# if pretrained_model is saved in static mode, load_static_weights must set to True
load_static_weights: True
cal_metric_during_train: True
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: True
use_visualdl: False
infer_img: doc/imgs_words/ch/word_1.jpg
# for data or label process
max_text_length: 80
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
character_type: 'ch'
use_space_char: True
use_space_char: False
infer_mode: False
use_tps: False
......@@ -29,7 +29,7 @@ Optimizer:
beta1: 0.9
beta2: 0.999
learning_rate:
lr: 0.0005
lr: 0.001
regularizer:
name: 'L2'
factor: 0.00001
......@@ -45,8 +45,8 @@ Architecture:
small_stride: [ 1, 2, 2, 2 ]
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 48
encoder_type: fc
hidden_size: 96
Head:
name: CTC
fc_decay: 0.00001
......@@ -63,9 +63,10 @@ Metric:
TRAIN:
dataset:
name: LMDBDateSet
name: SimpleDataSet
data_dir: ./rec
file_list:
- ./rec/lmdb/train # dataset1
- ./rec/train.txt # dataset1
ratio_list: [ 0.4,0.6 ]
transforms:
- DecodeImage: # load image
......@@ -85,9 +86,10 @@ TRAIN:
EVAL:
dataset:
name: LMDBDateSet
name: SimpleDataSet
data_dir: ./rec
file_list:
- ./rec/lmdb/val
- ./rec/val.txt
transforms:
- DecodeImage: # load image
img_mode: BGR
......
Global:
use_gpu: false
epoch_num: 72
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec/mv3_none_none_ctc/
save_epoch_step: 500
# evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step: 2000
# if pretrained_model is saved in static mode, load_static_weights must set to True
load_static_weights: True
cal_metric_during_train: True
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: True
infer_img: doc/imgs_words/ch/word_1.jpg
# for data or label process
max_text_length: 25
character_dict_path:
character_type: 'en'
use_space_char: False
infer_mode: False
use_tps: False
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
learning_rate:
lr: 0.0005
regularizer:
name: 'L2'
factor: 0.00001
Architecture:
type: rec
algorithm: CRNN
Transform:
Backbone:
name: MobileNetV3
scale: 0.5
model_name: large
small_stride: [ 1, 2, 2, 2 ]
Neck:
name: SequenceEncoder
encoder_type: reshape
Head:
name: CTC
fc_decay: 0.00001
Loss:
name: CTCLoss
PostProcess:
name: CTCLabelDecode
Metric:
name: RecMetric
main_indicator: acc
TRAIN:
dataset:
name: LMDBDateSet
file_list:
- ./rec/train # dataset1
ratio_list: [ 0.4,0.6 ]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- CTCLabelEncode: # Class handling label
- RecAug:
- RecResizeImg:
image_shape: [ 3,32,100 ]
- keepKeys:
keep_keys: [ 'image','label','length' ] # dataloader will return list in this order
loader:
batch_size: 256
shuffle: True
drop_last: True
num_workers: 8
EVAL:
dataset:
name: LMDBDateSet
file_list:
- ./rec/val/
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- CTCLabelEncode: # Class handling label
- RecResizeImg:
image_shape: [ 3,32,100 ]
- keepKeys:
keep_keys: [ 'image','label','length' ] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size: 256
num_workers: 8
English | [简体中文](README_cn.md)
## Introduction
Many user hopes package the PaddleOCR service into an docker image, so that it can be quickly released and used in the docker or k8s environment.
This page provide some standardized code to achieve this goal. You can quickly publish the PaddleOCR project into a callable Restful API service through the following steps. (At present, the deployment based on the HubServing mode is implemented first, and author plans to increase the deployment of the PaddleServing mode in the futrue)
## 1. Prerequisites
You need to install the following basic components first:
a. Docker
b. Graphics driver and CUDA 10.0+(GPU)
c. NVIDIA Container Toolkit(GPU,Docker 19.03+ can skip this)
d. cuDNN 7.6+(GPU)
## 2. Build Image
a. Download PaddleOCR sourcecode
```
git clone https://github.com/PaddlePaddle/PaddleOCR.git
```
b. Goto Dockerfile directory(ps:Need to distinguish between cpu and gpu version, the following takes cpu as an example, gpu version needs to replace the keyword)
```
cd docker/cpu
```
c. Build image
```
docker build -t paddleocr:cpu .
```
## 3. Start container
a. CPU version
```
sudo docker run -dp 8866:8866 --name paddle_ocr paddleocr:cpu
```
b. GPU version (base on NVIDIA Container Toolkit)
```
sudo nvidia-docker run -dp 8866:8866 --name paddle_ocr paddleocr:gpu
```
c. GPU version (Docker 19.03++)
```
sudo docker run -dp 8866:8866 --gpus all --name paddle_ocr paddleocr:gpu
```
d. Check service status(If you can see the following statement then it means completed:Successfully installed ocr_system && Running on http://0.0.0.0:8866/)
```
docker logs -f paddle_ocr
```
## 4. Test
a. Calculate the Base64 encoding of the picture to be recognized (if you just test, you can use a free online tool, like:https://freeonlinetools24.com/base64-image/)
b. Post a service request(sample request in sample_request.txt)
```
curl -H "Content-Type:application/json" -X POST --data "{\"images\": [\"Input image Base64 encode(need to delete the code 'data:image/jpg;base64,')\"]}" http://localhost:8866/predict/ocr_system
```
c. Get resposne(If the call is successful, the following result will be returned)
```
{"msg":"","results":[[{"confidence":0.8403433561325073,"text":"约定","text_region":[[345,377],[641,390],[634,540],[339,528]]},{"confidence":0.8131805658340454,"text":"最终相遇","text_region":[[356,532],[624,530],[624,596],[356,598]]}]],"status":"0"}
```
[English](README.md) | 简体中文
## Docker化部署服务
在日常项目应用中,相信大家一般都会希望能通过Docker技术,把PaddleOCR服务打包成一个镜像,以便在Docker或k8s环境里,快速发布上线使用。
本文将提供一些标准化的代码来实现这样的目标。大家通过如下步骤可以把PaddleOCR项目快速发布成可调用的Restful API服务。(目前暂时先实现了基于HubServing模式的部署,后续作者计划增加PaddleServing模式的部署)
## 1.实施前提准备
需要先完成如下基本组件的安装:
a. Docker环境
b. 显卡驱动和CUDA 10.0+(GPU)
c. NVIDIA Container Toolkit(GPU,Docker 19.03以上版本可以跳过此步)
d. cuDNN 7.6+(GPU)
## 2.制作镜像
a.下载PaddleOCR项目代码
```
git clone https://github.com/PaddlePaddle/PaddleOCR.git
```
b.切换至Dockerfile目录(注:需要区分cpu或gpu版本,下文以cpu为例,gpu版本需要替换一下关键字即可)
```
cd docker/cpu
```
c.生成镜像
```
docker build -t paddleocr:cpu .
```
## 3.启动Docker容器
a. CPU 版本
```
sudo docker run -dp 8866:8866 --name paddle_ocr paddleocr:cpu
```
b. GPU 版本 (通过NVIDIA Container Toolkit)
```
sudo nvidia-docker run -dp 8866:8866 --name paddle_ocr paddleocr:gpu
```
c. GPU 版本 (Docker 19.03以上版本,可以直接用如下命令)
```
sudo docker run -dp 8866:8866 --gpus all --name paddle_ocr paddleocr:gpu
```
d. 检查服务运行情况(出现:Successfully installed ocr_system和Running on http://0.0.0.0:8866/等信息,表示运行成功)
```
docker logs -f paddle_ocr
```
## 4.测试服务
a. 计算待识别图片的Base64编码(如果只是测试一下效果,可以通过免费的在线工具实现,如:http://tool.chinaz.com/tools/imgtobase/)
b. 发送服务请求(可参见sample_request.txt中的值)
```
curl -H "Content-Type:application/json" -X POST --data "{\"images\": [\"填入图片Base64编码(需要删除'data:image/jpg;base64,')\"]}" http://localhost:8866/predict/ocr_system
```
c. 返回结果(如果调用成功,会返回如下结果)
```
{"msg":"","results":[[{"confidence":0.8403433561325073,"text":"约定","text_region":[[345,377],[641,390],[634,540],[339,528]]},{"confidence":0.8131805658340454,"text":"最终相遇","text_region":[[356,532],[624,530],[624,596],[356,598]]}]],"status":"0"}
```
# Version: 1.0.0
FROM hub.baidubce.com/paddlepaddle/paddle:latest-gpu-cuda9.0-cudnn7-dev
# PaddleOCR base on Python3.7
RUN pip3.7 install --upgrade pip -i https://pypi.tuna.tsinghua.edu.cn/simple
RUN python3.7 -m pip install paddlepaddle==1.7.2 -i https://pypi.tuna.tsinghua.edu.cn/simple
RUN pip3.7 install paddlehub --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple
RUN git clone https://gitee.com/PaddlePaddle/PaddleOCR
WORKDIR /PaddleOCR
RUN pip3.7 install -r requirments.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
RUN mkdir -p /PaddleOCR/inference
# Download orc detect model(light version). if you want to change normal version, you can change ch_det_mv3_db_infer to ch_det_r50_vd_db_infer, also remember change det_model_dir in deploy/hubserving/ocr_system/params.py)
ADD https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db_infer.tar /PaddleOCR/inference
RUN tar xf /PaddleOCR/inference/ch_det_mv3_db_infer.tar -C /PaddleOCR/inference
# Download orc recognition model(light version). If you want to change normal version, you can change ch_rec_mv3_crnn_infer to ch_rec_r34_vd_crnn_enhance_infer, also remember change rec_model_dir in deploy/hubserving/ocr_system/params.py)
ADD https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_infer.tar /PaddleOCR/inference
RUN tar xf /PaddleOCR/inference/ch_rec_mv3_crnn_infer.tar -C /PaddleOCR/inference
EXPOSE 8866
CMD ["/bin/bash","-c","export PYTHONPATH=. && hub install deploy/hubserving/ocr_system/ && hub serving start -m ocr_system"]
\ No newline at end of file
# Version: 1.0.0
FROM hub.baidubce.com/paddlepaddle/paddle:latest-gpu-cuda10.0-cudnn7-dev
# PaddleOCR base on Python3.7
RUN pip3.7 install --upgrade pip -i https://pypi.tuna.tsinghua.edu.cn/simple
RUN python3.7 -m pip install paddlepaddle-gpu==1.7.2.post107 -i https://pypi.tuna.tsinghua.edu.cn/simple
RUN pip3.7 install paddlehub --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple
RUN git clone https://gitee.com/PaddlePaddle/PaddleOCR
WORKDIR /home/PaddleOCR
RUN pip3.7 install -r requirments.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
RUN mkdir -p /PaddleOCR/inference
# Download orc detect model(light version). if you want to change normal version, you can change ch_det_mv3_db_infer to ch_det_r50_vd_db_infer, also remember change det_model_dir in deploy/hubserving/ocr_system/params.py)
ADD https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db_infer.tar /PaddleOCR/inference
RUN tar xf /PaddleOCR/inference/ch_det_mv3_db_infer.tar -C /PaddleOCR/inference
# Download orc recognition model(light version). If you want to change normal version, you can change ch_rec_mv3_crnn_infer to ch_rec_r34_vd_crnn_enhance_infer, also remember change rec_model_dir in deploy/hubserving/ocr_system/params.py)
ADD https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_infer.tar /PaddleOCR/inference
RUN tar xf /PaddleOCR/inference/ch_rec_mv3_crnn_infer.tar -C /PaddleOCR/inference
EXPOSE 8866
CMD ["/bin/bash","-c","export PYTHONPATH=. && hub install deploy/hubserving/ocr_system/ && hub serving start -m ocr_system"]
\ No newline at end of file
English | [简体中文](README_cn.md)
## Introduction
Many user hopes package the PaddleOCR service into an docker image, so that it can be quickly released and used in the docker or k8s environment.
This page provide some standardized code to achieve this goal. You can quickly publish the PaddleOCR project into a callable Restful API service through the following steps. (At present, the deployment based on the HubServing mode is implemented first, and author plans to increase the deployment of the PaddleServing mode in the futrue)
## 1. Prerequisites
You need to install the following basic components first:
a. Docker
b. Graphics driver and CUDA 10.0+(GPU)
c. NVIDIA Container Toolkit(GPU,Docker 19.03+ can skip this)
d. cuDNN 7.6+(GPU)
## 2. Build Image
a. Download PaddleOCR sourcecode
```
git clone https://github.com/PaddlePaddle/PaddleOCR.git
```
b. Goto Dockerfile directory(ps:Need to distinguish between cpu and gpu version, the following takes cpu as an example, gpu version needs to replace the keyword)
```
cd docker/cpu
```
c. Build image
```
docker build -t paddleocr:cpu .
```
## 3. Start container
a. CPU version
```
sudo docker run -dp 8866:8866 --name paddle_ocr paddleocr:cpu
```
b. GPU version (base on NVIDIA Container Toolkit)
```
sudo nvidia-docker run -dp 8866:8866 --name paddle_ocr paddleocr:gpu
```
c. GPU version (Docker 19.03++)
```
sudo docker run -dp 8866:8866 --gpus all --name paddle_ocr paddleocr:gpu
```
d. Check service status(If you can see the following statement then it means completed:Successfully installed ocr_system && Running on http://0.0.0.0:8866/)
```
docker logs -f paddle_ocr
```
## 4. Test
a. Calculate the Base64 encoding of the picture to be recognized (if you just test, you can use a free online tool, like:https://freeonlinetools24.com/base64-image/)
b. Post a service request(sample request in sample_request.txt)
```
curl -H "Content-Type:application/json" -X POST --data "{\"images\": [\"Input image Base64 encode(need to delete the code 'data:image/jpg;base64,')\"]}" http://localhost:8866/predict/ocr_system
```
c. Get resposne(If the call is successful, the following result will be returned)
```
{"msg":"","results":[[{"confidence":0.8403433561325073,"text":"约定","text_region":[[345,377],[641,390],[634,540],[339,528]]},{"confidence":0.8131805658340454,"text":"最终相遇","text_region":[[356,532],[624,530],[624,596],[356,598]]}]],"status":"0"}
```
此差异已折叠。
......@@ -21,104 +21,69 @@ import os
import sys
import numpy as np
import paddle
import signal
import random
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import copy
from paddle.io import DataLoader, DistributedBatchSampler, BatchSampler
from paddle.io import Dataset, DataLoader, BatchSampler, DistributedBatchSampler
import paddle.distributed as dist
from ppocr.data.imaug import transform, create_operators
from ppocr.data.simple_dataset import SimpleDataSet
from ppocr.data.lmdb_dataset import LMDBDateSet
__all__ = ['build_dataloader', 'transform', 'create_operators']
def term_mp(sig_num, frame):
""" kill all child processes
"""
pid = os.getpid()
pgid = os.getpgid(os.getpid())
print("main proc {} exit, kill process group " "{}".format(pid, pgid))
os.killpg(pgid, signal.SIGKILL)
def build_dataset(config, global_config):
from ppocr.data.dataset import SimpleDataSet, LMDBDateSet
support_dict = ['SimpleDataSet', 'LMDBDateSet']
signal.signal(signal.SIGINT, term_mp)
signal.signal(signal.SIGTERM, term_mp)
module_name = config.pop('name')
def build_dataloader(config, mode, device):
config = copy.deepcopy(config)
support_dict = ['SimpleDataSet', 'LMDBDateSet']
module_name = config[mode]['dataset']['name']
assert module_name in support_dict, Exception(
'DataSet only support {}'.format(support_dict))
dataset = eval(module_name)(config, global_config)
return dataset
def build_dataloader(config, device, distributed=False, global_config=None):
from ppocr.data.dataset import BatchBalancedDataLoader
config = copy.deepcopy(config)
dataset_config = config['dataset']
_dataset_list = []
file_list = dataset_config.pop('file_list')
if len(file_list) == 1:
ratio_list = [1.0]
assert mode in ['Train', 'Eval', 'Test'], "Mode should be Train, Eval or Test."
dataset = eval(module_name)(config, mode)
loader_config = config[mode]['loader']
batch_size = loader_config['batch_size_per_card']
drop_last = loader_config['drop_last']
num_workers = loader_config['num_workers']
if mode == "Train":
#Distribute data to multiple cards
batch_sampler = DistributedBatchSampler(
dataset=dataset,
batch_size=batch_size,
shuffle=False,
drop_last=drop_last)
else:
ratio_list = dataset_config.pop('ratio_list')
for file in file_list:
dataset_config['file_list'] = file
_dataset = build_dataset(dataset_config, global_config)
_dataset_list.append(_dataset)
data_loader = BatchBalancedDataLoader(_dataset_list, ratio_list,
distributed, device, config['loader'])
return data_loader, _dataset.info_dict
def test_loader():
import time
from tools.program import load_config, ArgsParser
FLAGS = ArgsParser().parse_args()
config = load_config(FLAGS.config)
place = paddle.CPUPlace()
paddle.disable_static(place)
import time
data_loader, _ = build_dataloader(
config['TRAIN'], place, global_config=config['Global'])
start = time.time()
print(len(data_loader))
for epoch in range(1):
print('epoch {} ****************'.format(epoch))
for i, batch in enumerate(data_loader):
if i > len(data_loader):
break
t = time.time() - start
start = time.time()
print('{}, batch : {} ,time {}'.format(i, len(batch[0]), t))
continue
import matplotlib.pyplot as plt
from matplotlib import pyplot as plt
import cv2
fig = plt.figure()
# # cv2.imwrite('img.jpg',batch[0].numpy()[0].transpose((1,2,0)))
# # cv2.imwrite('bmap.jpg',batch[1].numpy()[0])
# # cv2.imwrite('bmask.jpg',batch[2].numpy()[0])
# # cv2.imwrite('smap.jpg',batch[3].numpy()[0])
# # cv2.imwrite('smask.jpg',batch[4].numpy()[0])
plt.title('img')
plt.imshow(batch[0].numpy()[0].transpose((1, 2, 0)))
# plt.figure()
# plt.title('bmap')
# plt.imshow(batch[1].numpy()[0],cmap='Greys')
# plt.figure()
# plt.title('bmask')
# plt.imshow(batch[2].numpy()[0],cmap='Greys')
# plt.figure()
# plt.title('smap')
# plt.imshow(batch[3].numpy()[0],cmap='Greys')
# plt.figure()
# plt.title('smask')
# plt.imshow(batch[4].numpy()[0],cmap='Greys')
# plt.show()
# break
if __name__ == '__main__':
test_loader()
#Distribute data to single card
batch_sampler = BatchSampler(
dataset=dataset,
batch_size=batch_size,
shuffle=False,
drop_last=drop_last)
data_loader = DataLoader(
dataset=dataset,
batch_sampler=batch_sampler,
places=device,
num_workers=num_workers,
return_list=True)
return data_loader
#return data_loader, _dataset.info_dict
\ No newline at end of file
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import numpy as np
import os
import lmdb
import random
import signal
import paddle
from paddle.io import Dataset, DataLoader, DistributedBatchSampler, BatchSampler
from .imaug import transform, create_operators
from ppocr.utils.logging import get_logger
def term_mp(sig_num, frame):
""" kill all child processes
"""
pid = os.getpid()
pgid = os.getpgid(os.getpid())
print("main proc {} exit, kill process group " "{}".format(pid, pgid))
os.killpg(pgid, signal.SIGKILL)
signal.signal(signal.SIGINT, term_mp)
signal.signal(signal.SIGTERM, term_mp)
class ModeException(Exception):
"""
ModeException
"""
def __init__(self, message='', mode=''):
message += "\nOnly the following 3 modes are supported: " \
"train, valid, test. Given mode is {}".format(mode)
super(ModeException, self).__init__(message)
class SampleNumException(Exception):
"""
SampleNumException
"""
def __init__(self, message='', sample_num=0, batch_size=1):
message += "\nError: The number of the whole data ({}) " \
"is smaller than the batch_size ({}), and drop_last " \
"is turnning on, so nothing will feed in program, " \
"Terminated now. Please reset batch_size to a smaller " \
"number or feed more data!".format(sample_num, batch_size)
super(SampleNumException, self).__init__(message)
def get_file_list(file_list, data_dir, delimiter='\t'):
"""
read label list from file and shuffle the list
Args:
params(dict):
"""
if isinstance(file_list, str):
file_list = [file_list]
data_source_list = []
for file in file_list:
with open(file) as f:
full_lines = [line.strip() for line in f]
for line in full_lines:
try:
img_path, label = line.split(delimiter)
except:
logger = get_logger()
logger.warning('label error in {}'.format(line))
img_path = os.path.join(data_dir, img_path)
data = {'img_path': img_path, 'label': label}
data_source_list.append(data)
return data_source_list
class LMDBDateSet(Dataset):
def __init__(self, config, global_config):
super(LMDBDateSet, self).__init__()
self.data_list = self.load_lmdb_dataset(
config['file_list'], global_config['max_text_length'])
random.shuffle(self.data_list)
self.ops = create_operators(config['transforms'], global_config)
# for rec
character = ''
for op in self.ops:
if hasattr(op, 'character'):
character = getattr(op, 'character')
self.info_dict = {'character': character}
def load_lmdb_dataset(self, data_dir, max_text_length):
self.env = lmdb.open(
data_dir,
max_readers=32,
readonly=True,
lock=False,
readahead=False,
meminit=False)
if not self.env:
print('cannot create lmdb from %s' % (data_dir))
exit(0)
filtered_index_list = []
with self.env.begin(write=False) as txn:
nSamples = int(txn.get('num-samples'.encode()))
self.nSamples = nSamples
for index in range(self.nSamples):
index += 1 # lmdb starts with 1
label_key = 'label-%09d'.encode() % index
label = txn.get(label_key).decode('utf-8')
if len(label) > max_text_length:
# print(f'The length of the label is longer than max_length: length
# {len(label)}, {label} in dataset {self.root}')
continue
# By default, images containing characters which are not in opt.character are filtered.
# You can add [UNK] token to `opt.character` in utils.py instead of this filtering.
filtered_index_list.append(index)
return filtered_index_list
def print_lmdb_sets_info(self, lmdb_sets):
lmdb_info_strs = []
for dataset_idx in range(len(lmdb_sets)):
tmp_str = " %s:%d," % (lmdb_sets[dataset_idx]['dirpath'],
lmdb_sets[dataset_idx]['num_samples'])
lmdb_info_strs.append(tmp_str)
lmdb_info_strs = ''.join(lmdb_info_strs)
logger = get_logger()
logger.info("DataSummary:" + lmdb_info_strs)
return
def __getitem__(self, idx):
idx = self.data_list[idx]
with self.env.begin(write=False) as txn:
label_key = 'label-%09d'.encode() % idx
label = txn.get(label_key)
if label is not None:
label = label.decode('utf-8')
img_key = 'image-%09d'.encode() % idx
imgbuf = txn.get(img_key)
data = {'image': imgbuf, 'label': label}
outs = transform(data, self.ops)
else:
outs = None
if outs is None:
return self.__getitem__(np.random.randint(self.__len__()))
return outs
def __len__(self):
return len(self.data_list)
class SimpleDataSet(Dataset):
def __init__(self, config, global_config):
super(SimpleDataSet, self).__init__()
delimiter = config.get('delimiter', '\t')
self.data_list = get_file_list(config['file_list'], config['data_dir'],
delimiter)
random.shuffle(self.data_list)
self.ops = create_operators(config['transforms'], global_config)
# for rec
character = ''
for op in self.ops:
if hasattr(op, 'character'):
character = getattr(op, 'character')
self.info_dict = {'character': character}
def __getitem__(self, idx):
data = copy.deepcopy(self.data_list[idx])
with open(data['img_path'], 'rb') as f:
img = f.read()
data['image'] = img
outs = transform(data, self.ops)
if outs is None:
return self.__getitem__(np.random.randint(self.__len__()))
return outs
def __len__(self):
return len(self.data_list)
class BatchBalancedDataLoader(object):
def __init__(self,
dataset_list: list,
ratio_list: list,
distributed,
device,
loader_args: dict):
"""
对datasetlist里的dataset按照ratio_list里对应的比例组合,似的每个batch里的数据按按照比例采样的
:param dataset_list: 数据集列表
:param ratio_list: 比例列表
:param loader_args: dataloader的配置
"""
assert sum(ratio_list) == 1 and len(dataset_list) == len(ratio_list)
self.dataset_len = 0
self.data_loader_list = []
self.dataloader_iter_list = []
all_batch_size = loader_args.pop('batch_size')
batch_size_list = list(
map(int, [max(1.0, all_batch_size * x) for x in ratio_list]))
remain_num = all_batch_size - sum(batch_size_list)
batch_size_list[np.argmax(ratio_list)] += remain_num
for _dataset, _batch_size in zip(dataset_list, batch_size_list):
if distributed:
batch_sampler_class = DistributedBatchSampler
else:
batch_sampler_class = BatchSampler
batch_sampler = batch_sampler_class(
dataset=_dataset,
batch_size=_batch_size,
shuffle=loader_args['shuffle'],
drop_last=loader_args['drop_last'], )
_data_loader = DataLoader(
dataset=_dataset,
batch_sampler=batch_sampler,
places=device,
num_workers=loader_args['num_workers'],
return_list=True, )
self.data_loader_list.append(_data_loader)
self.dataloader_iter_list.append(iter(_data_loader))
self.dataset_len += len(_dataset)
def __iter__(self):
return self
def __len__(self):
return min([len(x) for x in self.data_loader_list])
def __next__(self):
batch = []
for i, data_loader_iter in enumerate(self.dataloader_iter_list):
try:
_batch_i = next(data_loader_iter)
batch.append(_batch_i)
except StopIteration:
self.dataloader_iter_list[i] = iter(self.data_loader_list[i])
_batch_i = next(self.dataloader_iter_list[i])
batch.append(_batch_i)
except ValueError:
pass
if len(batch) > 0:
batch_list = []
batch_item_size = len(batch[0])
for i in range(batch_item_size):
cur_item_list = [batch_i[i] for batch_i in batch]
batch_list.append(paddle.concat(cur_item_list, axis=0))
else:
batch_list = batch[0]
return batch_list
def fill_batch(batch):
"""
2020.09.08: The current paddle version only supports returning data with the same length.
Therefore, fill in the batches with inconsistent lengths.
this method is currently only useful for text detection
"""
keys = list(range(len(batch[0])))
v_max_len_dict = {}
for k in keys:
v_max_len_dict[k] = max([len(item[k]) for item in batch])
for item in batch:
length = []
for k in keys:
v = item[k]
length.append(len(v))
assert isinstance(v, np.ndarray)
if len(v) == v_max_len_dict[k]:
continue
try:
tmp_shape = [v_max_len_dict[k] - len(v)] + list(v[0].shape)
except:
a = 1
tmp_array = np.zeros(tmp_shape, dtype=v[0].dtype)
new_array = np.concatenate([v, tmp_array])
item[k] = new_array
item.append(length)
return batch
......@@ -148,6 +148,8 @@ class CTCLabelEncode(BaseRecLabelEncode):
text = self.encode(text)
if text is None:
return None
if len(text) > self.max_text_len:
return None
data['length'] = np.array(len(text))
text = text + [0] * (self.max_text_len - len(text))
data['label'] = np.array(text)
......
......@@ -29,7 +29,7 @@ class MakeBorderMap(object):
self.thresh_min = thresh_min
self.thresh_max = thresh_max
def __call__(self, data: dict) -> dict:
def __call__(self, data):
img = data['image']
text_polys = data['polys']
......
......@@ -99,7 +99,7 @@ class ToCHWImage(object):
return data
class keepKeys(object):
class KeepKeys(object):
def __init__(self, keep_keys, **kwargs):
self.keep_keys = keep_keys
......
......@@ -50,16 +50,14 @@ class RecResizeImg(object):
image_shape,
infer_mode=False,
character_type='ch',
use_tps=False,
**kwargs):
self.image_shape = image_shape
self.infer_mode = infer_mode
self.character_type = character_type
self.use_tps = use_tps
def __call__(self, data):
img = data['image']
if self.infer_mode and self.character_type == "ch" and not self.use_tps:
if self.infer_mode and self.character_type == "ch":
norm_img = resize_norm_img_chinese(img, self.image_shape)
else:
norm_img = resize_norm_img(img, self.image_shape)
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import numpy as np
import os
import random
import paddle
from paddle.io import Dataset
import time
import lmdb
import cv2
from .imaug import transform, create_operators
from ppocr.utils.logging import get_logger
logger = get_logger()
class LMDBDateSet(Dataset):
def __init__(self, config, mode):
super(LMDBDateSet, self).__init__()
global_config = config['Global']
dataset_config = config[mode]['dataset']
loader_config = config[mode]['loader']
batch_size = loader_config['batch_size_per_card']
data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle']
self.lmdb_sets = self.load_hierarchical_lmdb_dataset(data_dir)
logger.info("Initialize indexs of datasets:%s" % data_dir)
self.data_idx_order_list = self.dataset_traversal()
if self.do_shuffle:
np.random.shuffle(self.data_idx_order_list)
self.ops = create_operators(dataset_config['transforms'], global_config)
# # for rec
# character = ''
# for op in self.ops:
# if hasattr(op, 'character'):
# character = getattr(op, 'character')
# self.info_dict = {'character': character}
def load_hierarchical_lmdb_dataset(self, data_dir):
lmdb_sets = {}
dataset_idx = 0
for dirpath, dirnames, filenames in os.walk(data_dir + '/'):
if not dirnames:
env = lmdb.open(
dirpath,
max_readers=32,
readonly=True,
lock=False,
readahead=False,
meminit=False)
txn = env.begin(write=False)
num_samples = int(txn.get('num-samples'.encode()))
lmdb_sets[dataset_idx] = {"dirpath":dirpath, "env":env, \
"txn":txn, "num_samples":num_samples}
dataset_idx += 1
return lmdb_sets
def dataset_traversal(self):
lmdb_num = len(self.lmdb_sets)
total_sample_num = 0
for lno in range(lmdb_num):
total_sample_num += self.lmdb_sets[lno]['num_samples']
data_idx_order_list = np.zeros((total_sample_num, 2))
beg_idx = 0
for lno in range(lmdb_num):
tmp_sample_num = self.lmdb_sets[lno]['num_samples']
end_idx = beg_idx + tmp_sample_num
data_idx_order_list[beg_idx:end_idx, 0] = lno
data_idx_order_list[beg_idx:end_idx, 1] \
= list(range(tmp_sample_num))
data_idx_order_list[beg_idx:end_idx, 1] += 1
beg_idx = beg_idx + tmp_sample_num
return data_idx_order_list
def get_img_data(self, value):
"""get_img_data"""
if not value:
return None
imgdata = np.frombuffer(value, dtype='uint8')
if imgdata is None:
return None
imgori = cv2.imdecode(imgdata, 1)
if imgori is None:
return None
return imgori
def get_lmdb_sample_info(self, txn, index):
label_key = 'label-%09d'.encode() % index
label = txn.get(label_key)
if label is None:
return None
label = label.decode('utf-8')
img_key = 'image-%09d'.encode() % index
imgbuf = txn.get(img_key)
return imgbuf, label
def __getitem__(self, idx):
lmdb_idx, file_idx = self.data_idx_order_list[idx]
lmdb_idx = int(lmdb_idx)
file_idx = int(file_idx)
sample_info = self.get_lmdb_sample_info(
self.lmdb_sets[lmdb_idx]['txn'], file_idx)
if sample_info is None:
return self.__getitem__(np.random.randint(self.__len__()))
img, label = sample_info
data = {'image': img, 'label': label}
outs = transform(data, self.ops)
if outs is None:
return self.__getitem__(np.random.randint(self.__len__()))
return outs
def __len__(self):
return self.data_idx_order_list.shape[0]
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import numpy as np
import os
import random
import paddle
from paddle.io import Dataset
import time
from .imaug import transform, create_operators
from ppocr.utils.logging import get_logger
logger = get_logger()
class SimpleDataSet(Dataset):
def __init__(self, config, mode):
super(SimpleDataSet, self).__init__()
global_config = config['Global']
dataset_config = config[mode]['dataset']
loader_config = config[mode]['loader']
batch_size = loader_config['batch_size_per_card']
self.delimiter = dataset_config.get('delimiter', '\t')
label_file_list = dataset_config.pop('label_file_list')
data_source_num = len(label_file_list)
if data_source_num == 1:
ratio_list = [1.0]
else:
ratio_list = dataset_config.pop('ratio_list')
assert sum(ratio_list) == 1, "The sum of the ratio_list should be 1."
assert len(ratio_list) == data_source_num, "The length of ratio_list should be the same as the file_list."
self.data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle']
logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines_list, data_num_list = self.get_image_info_list(
label_file_list)
self.data_idx_order_list = self.dataset_traversal(
data_num_list, ratio_list, batch_size)
self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config)
def get_image_info_list(self, file_list):
if isinstance(file_list, str):
file_list = [file_list]
data_lines_list = []
data_num_list = []
for file in file_list:
with open(file, "rb") as f:
lines = f.readlines()
data_lines_list.append(lines)
data_num_list.append(len(lines))
return data_lines_list, data_num_list
def dataset_traversal(self, data_num_list, ratio_list, batch_size):
select_num_list = []
dataset_num = len(data_num_list)
for dno in range(dataset_num):
select_num = round(batch_size * ratio_list[dno])
select_num = max(select_num, 1)
select_num_list.append(select_num)
data_idx_order_list = []
cur_index_sets = [0] * dataset_num
while True:
finish_read_num = 0
for dataset_idx in range(dataset_num):
cur_index = cur_index_sets[dataset_idx]
if cur_index >= data_num_list[dataset_idx]:
finish_read_num += 1
else:
select_num = select_num_list[dataset_idx]
for sno in range(select_num):
cur_index = cur_index_sets[dataset_idx]
if cur_index >= data_num_list[dataset_idx]:
break
data_idx_order_list.append((
dataset_idx, cur_index))
cur_index_sets[dataset_idx] += 1
if finish_read_num == dataset_num:
break
return data_idx_order_list
def shuffle_data_random(self):
if self.do_shuffle:
for dno in range(len(self.data_lines_list)):
random.shuffle(self.data_lines_list[dno])
return
def __getitem__(self, idx):
dataset_idx, file_idx = self.data_idx_order_list[idx]
data_line = self.data_lines_list[dataset_idx][file_idx]
data_line = data_line.decode('utf-8')
substr = data_line.strip("\n").split(self.delimiter)
file_name = substr[0]
label = substr[1]
img_path = os.path.join(self.data_dir, file_name)
data = {'img_path': img_path, 'label': label}
with open(data['img_path'], 'rb') as f:
img = f.read()
data['image'] = img
outs = transform(data, self.ops)
if outs is None:
return self.__getitem__(np.random.randint(self.__len__()))
return outs
def __len__(self):
return len(self.data_idx_order_list)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from .losses import build_loss
__all__ = ['build_model', 'build_loss']
def build_model(config):
from .architectures import Model
config = copy.deepcopy(config)
module_class = Model(config)
return module_class
......@@ -12,5 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .model import Model
__all__ = ['Model']
\ No newline at end of file
import copy
__all__ = ['build_model']
def build_model(config):
from .base_model import BaseModel
config = copy.deepcopy(config)
module_class = BaseModel(config)
return module_class
\ No newline at end of file
......@@ -15,38 +15,29 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os, sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append('/home/zhoujun20/PaddleOCR')
from paddle import nn
from ppocr.modeling.transform import build_transform
from ppocr.modeling.backbones import build_backbone
from ppocr.modeling.necks import build_neck
from ppocr.modeling.heads import build_head
__all__ = ['Model']
__all__ = ['BaseModel']
class Model(nn.Layer):
class BaseModel(nn.Layer):
def __init__(self, config):
"""
Detection module for OCR.
the module for OCR.
args:
config (dict): the super parameters for module.
"""
super(Model, self).__init__()
algorithm = config['algorithm']
self.type = config['type']
self.model_name = '{}_{}'.format(self.type, algorithm)
super(BaseModel, self).__init__()
in_channels = config.get('in_channels', 3)
model_type = config['model_type']
# build transfrom,
# for rec, transfrom can be TPS,None
# for det and cls, transfrom shoule to be None,
# if you make model differently, you can use transfrom in det and cls
# if you make model differently, you can use transfrom in det and cls
if 'Transform' not in config or config['Transform'] is None:
self.use_transform = False
else:
......@@ -57,9 +48,9 @@ class Model(nn.Layer):
# build backbone, backbone is need for del, rec and cls
config["Backbone"]['in_channels'] = in_channels
self.backbone = build_backbone(config["Backbone"], self.type)
self.backbone = build_backbone(config["Backbone"], model_type)
in_channels = self.backbone.out_channels
# build neck
# for rec, neck can be cnn,rnn or reshape(None)
# for det, neck can be FPN, BIFPN and so on.
......@@ -71,6 +62,7 @@ class Model(nn.Layer):
config['Neck']['in_channels'] = in_channels
self.neck = build_neck(config['Neck'])
in_channels = self.neck.out_channels
# # build head, head is need for det, rec and cls
config["Head"]['in_channels'] = in_channels
self.head = build_head(config["Head"])
......
......@@ -19,7 +19,6 @@ def build_backbone(config, model_type):
if model_type == 'det':
from .det_mobilenet_v3 import MobileNetV3
from .det_resnet_vd import ResNet
support_dict = ['MobileNetV3', 'ResNet', 'ResNet_SAST']
elif model_type == 'rec':
from .rec_mobilenet_v3 import MobileNetV3
......
......@@ -130,7 +130,6 @@ class MobileNetV3(nn.Layer):
if_act=True,
act='hard_swish',
name='conv_last'))
self.stages.append(nn.Sequential(*block_list))
self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
for i, stage in enumerate(self.stages):
......@@ -275,4 +274,4 @@ class SEModule(nn.Layer):
outputs = F.relu(outputs)
outputs = self.conv2(outputs)
outputs = F.hard_sigmoid(outputs)
return inputs * outputs
return inputs * outputs
\ No newline at end of file
......@@ -20,8 +20,8 @@ def build_head(config):
from .det_db_head import DBHead
# rec head
from .rec_ctc_head import CTC
support_dict = ['DBHead', 'CTC']
from .rec_ctc_head import CTCHead
support_dict = ['DBHead', 'CTCHead']
module_name = config.pop('name')
assert module_name in support_dict, Exception('head only support {}'.format(
......
......@@ -33,10 +33,9 @@ def get_para_bias_attr(l2_decay, k, name):
regularizer=regularizer, initializer=initializer, name=name + "_b_attr")
return [weight_attr, bias_attr]
class CTC(nn.Layer):
def __init__(self, in_channels, out_channels, fc_decay=1e-5, **kwargs):
super(CTC, self).__init__()
class CTCHead(nn.Layer):
def __init__(self, in_channels, out_channels, fc_decay=0.0004, **kwargs):
super(CTCHead, self).__init__()
weight_attr, bias_attr = get_para_bias_attr(
l2_decay=fc_decay, k=in_channels, name='ctc_fc')
self.fc = nn.Linear(
......
......@@ -14,11 +14,10 @@
__all__ = ['build_neck']
def build_neck(config):
from .fpn import FPN
from .db_fpn import DBFPN
from .rnn import SequenceEncoder
support_dict = ['FPN', 'SequenceEncoder']
support_dict = ['DBFPN', 'SequenceEncoder']
module_name = config.pop('name')
assert module_name in support_dict, Exception('neck only support {}'.format(
......
......@@ -22,9 +22,9 @@ import paddle.nn.functional as F
from paddle import ParamAttr
class FPN(nn.Layer):
class DBFPN(nn.Layer):
def __init__(self, in_channels, out_channels, **kwargs):
super(FPN, self).__init__()
super(DBFPN, self).__init__()
self.out_channels = out_channels
weight_attr = paddle.nn.initializer.MSRA(uniform=False)
......
......@@ -76,8 +76,7 @@ class SequenceEncoder(nn.Layer):
'fc': EncoderWithFC,
'rnn': EncoderWithRNN
}
assert encoder_type in support_encoder_dict, '{} must in {}'.format(
encoder_type, support_encoder_dict.keys())
assert encoder_type in support_encoder_dict, '{} must in {}'.format(encoder_type, support_encoder_dict.keys())
self.encoder = support_encoder_dict[encoder_type](
self.encoder_reshape.out_channels, hidden_size)
......
......@@ -50,6 +50,8 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
# step3 build optimizer
optim_name = config.pop('name')
# Regularization is invalid. The bug will be fixed in paddle-rc. The param is
# weight_decay.
optim = getattr(optimizer, optim_name)(learning_rate=lr,
regularization=reg,
**config)
......
......@@ -40,8 +40,8 @@ class Momentum(object):
opt = optim.Momentum(
learning_rate=self.learning_rate,
momentum=self.momentum,
parameters=self.weight_decay,
weight_decay=parameters)
parameters=parameters,
weight_decay=self.weight_decay)
return opt
......
......@@ -24,8 +24,8 @@ __all__ = ['build_post_process']
def build_post_process(config, global_config=None):
from .db_postprocess import DBPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode
support_dict = ['DBPostProcess', 'CTCLabelDecode', 'AttnLabelDecode']
config = copy.deepcopy(config)
......
......@@ -46,7 +46,7 @@ def load_dygraph_pretrain(
model,
logger,
path=None,
load_static_weights=False, ):
load_static_weights=False):
if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
raise ValueError("Model pretrain path {} does not "
"exists.".format(path))
......@@ -110,21 +110,20 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
logger.info("resume from {}".format(checkpoints))
elif pretrained_model:
load_static_weights = gloabl_config.get('load_static_weights', False)
if pretrained_model:
if not isinstance(pretrained_model, list):
pretrained_model = [pretrained_model]
if not isinstance(load_static_weights, list):
load_static_weights = [load_static_weights] * len(
pretrained_model)
for idx, pretrained in enumerate(pretrained_model):
load_static = load_static_weights[idx]
load_dygraph_pretrain(
model,
logger,
path=pretrained,
load_static_weights=load_static)
logger.info("load pretrained model from {}".format(
pretrained_model))
if not isinstance(pretrained_model, list):
pretrained_model = [pretrained_model]
if not isinstance(load_static_weights, list):
load_static_weights = [load_static_weights] * len(
pretrained_model)
for idx, pretrained in enumerate(pretrained_model):
load_static = load_static_weights[idx]
load_dygraph_pretrain(
model,
logger,
path=pretrained,
load_static_weights=load_static)
logger.info("load pretrained model from {}".format(
pretrained_model))
else:
logger.info('train from scratch')
return best_model_dict
......
......@@ -28,7 +28,10 @@ from argparse import ArgumentParser, RawDescriptionHelpFormatter
from ppocr.utils.stats import TrainingStats
from ppocr.utils.save_load import save_model
from ppocr.utils.utility import print_dict
from ppocr.utils.logging import get_logger
from ppocr.data import build_dataloader
import numpy as np
class ArgsParser(ArgumentParser):
def __init__(self):
......@@ -136,18 +139,18 @@ def check_gpu(use_gpu):
def train(config,
train_dataloader,
valid_dataloader,
device,
model,
loss_class,
optimizer,
lr_scheduler,
train_dataloader,
valid_dataloader,
post_process_class,
eval_class,
pre_best_model_dict,
logger,
vdl_writer=None):
global_step = 0
cal_metric_during_train = config['Global'].get('cal_metric_during_train',
False)
......@@ -156,6 +159,7 @@ def train(config,
print_batch_step = config['Global']['print_batch_step']
eval_batch_step = config['Global']['eval_batch_step']
global_step = 0
start_eval_step = 0
if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
start_eval_step = eval_batch_step[0]
......@@ -179,14 +183,15 @@ def train(config,
start_epoch = 0
for epoch in range(start_epoch, epoch_num):
if epoch > 0:
train_loader = build_dataloader(config, 'Train', device)
for idx, batch in enumerate(train_dataloader):
if idx >= len(train_dataloader):
break
if not isinstance(lr_scheduler, float):
lr_scheduler.step()
lr = optimizer.get_lr()
t1 = time.time()
batch = [paddle.to_variable(x) for x in batch]
batch = [paddle.to_tensor(x) for x in batch]
images = batch[0]
preds = model(images)
loss = loss_class(preds, batch)
......@@ -199,6 +204,8 @@ def train(config,
avg_loss.backward()
optimizer.step()
optimizer.clear_grad()
if not isinstance(lr_scheduler, float):
lr_scheduler.step()
# logger and visualdl
stats = {k: v.numpy().mean() for k, v in loss.items()}
......@@ -228,8 +235,8 @@ def train(config,
# eval
if global_step > start_eval_step and \
(global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
cur_metirc = eval(model, valid_dataloader, post_process_class,
eval_class)
cur_metirc = eval(model, valid_dataloader,
post_process_class, eval_class, logger, print_batch_step)
cur_metirc_str = 'cur metirc, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in cur_metirc.items()]))
logger.info(cur_metirc_str)
......@@ -291,12 +298,14 @@ def train(config,
return
def eval(model, valid_dataloader, post_process_class, eval_class):
def eval(model, valid_dataloader,
post_process_class, eval_class,
logger, print_batch_step):
model.eval()
with paddle.no_grad():
total_frame = 0.0
total_time = 0.0
pbar = tqdm(total=len(valid_dataloader), desc='eval model: ')
# pbar = tqdm(total=len(valid_dataloader), desc='eval model:')
for idx, batch in enumerate(valid_dataloader):
if idx >= len(valid_dataloader):
break
......@@ -310,11 +319,14 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
total_time += time.time() - start
# Evaluate the results of the current batch
eval_class(post_result, batch)
pbar.update(1)
# pbar.update(1)
total_frame += len(images)
if idx % print_batch_step == 0:
logger.info('tackling images for eval: {}/{}'.format(
idx, len(valid_dataloader)))
# Get final metirc,eg. acc or hmean
metirc = eval_class.get_metric()
pbar.close()
# pbar.close()
model.train()
metirc['fps'] = total_frame / total_time
return metirc
......@@ -336,4 +348,25 @@ def preprocess():
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
device = paddle.set_device(device)
return device, config
config['Global']['distributed'] = dist.get_world_size() != 1
paddle.disable_static(device)
# save_config
save_model_dir = config['Global']['save_model_dir']
os.makedirs(save_model_dir, exist_ok=True)
with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
yaml.dump(dict(config), f, default_flow_style=False, sort_keys=False)
logger = get_logger(log_file='{}/train.log'.format(save_model_dir))
if config['Global']['use_visualdl']:
from visualdl import LogWriter
vdl_writer_path = '{}/vdl/'.format(save_model_dir)
os.makedirs(vdl_writer_path, exist_ok=True)
vdl_writer = LogWriter(logdir=vdl_writer_path)
else:
vdl_writer = None
print_dict(config, logger)
logger.info('train with paddle {} and device {}'.format(paddle.__version__,
device))
return config, device, logger, vdl_writer
......@@ -31,7 +31,8 @@ paddle.manual_seed(2)
from ppocr.utils.logging import get_logger
from ppocr.data import build_dataloader
from ppocr.modeling import build_model, build_loss
from ppocr.modeling.architectures import build_model
from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
......@@ -48,95 +49,76 @@ def main(config, device, logger, vdl_writer):
dist.init_parallel_env()
global_config = config['Global']
# build dataloader
train_loader, train_info_dict = build_dataloader(
config['TRAIN'], device, global_config['distributed'], global_config)
if config['EVAL']:
eval_loader, _ = build_dataloader(config['EVAL'], device, False,
global_config)
train_dataloader = build_dataloader(config, 'Train', device)
if config['Eval']:
valid_dataloader = build_dataloader(config, 'Eval', device)
else:
eval_loader = None
valid_dataloader = None
# build post process
post_process_class = build_post_process(config['PostProcess'],
global_config)
post_process_class = build_post_process(
config['PostProcess'], global_config)
# build model
# for rec algorithm
#for rec algorithm
if hasattr(post_process_class, 'character'):
config['Architecture']["Head"]['out_channels'] = len(
getattr(post_process_class, 'character'))
char_num = len(getattr(post_process_class, 'character'))
config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture'])
if config['Global']['distributed']:
model = paddle.DataParallel(model)
# build loss
loss_class = build_loss(config['Loss'])
# build optim
optimizer, lr_scheduler = build_optimizer(
config['Optimizer'],
optimizer, lr_scheduler = build_optimizer(config['Optimizer'],
epochs=config['Global']['epoch_num'],
step_each_epoch=len(train_loader),
step_each_epoch=len(train_dataloader),
parameters=model.parameters())
best_model_dict = init_model(config, model, logger, optimizer)
# build loss
loss_class = build_loss(config['Loss'])
# build metric
eval_class = build_metric(config['Metric'])
# load pretrain model
pre_best_model_dict = init_model(config, model, logger, optimizer)
# start train
program.train(config, model, loss_class, optimizer, lr_scheduler,
train_loader, eval_loader, post_process_class, eval_class,
best_model_dict, logger, vdl_writer)
def test_reader(config, place, logger, global_config):
train_loader, _ = build_dataloader(
config['TRAIN'], place, global_config=global_config)
program.train(config,
train_dataloader,
valid_dataloader,
device,
model,
loss_class,
optimizer,
lr_scheduler,
post_process_class,
eval_class,
pre_best_model_dict,
logger,
vdl_writer)
def test_reader(config, device, logger):
loader = build_dataloader(config, 'Train', device)
# loader = build_dataloader(config, 'Eval', device)
import time
starttime = time.time()
count = 0
try:
for data in train_loader:
for data in loader():
count += 1
if count % 1 == 0:
batch_time = time.time() - starttime
starttime = time.time()
logger.info("reader: {}, {}, {}".format(
count, len(data[0]), batch_time))
logger.info("reader: {}, {}, {}".format(count, len(data), batch_time))
except Exception as e:
import traceback
traceback.print_exc()
logger.info(e)
logger.info("finish reader: {}, Success!".format(count))
def dis_main():
device, config = program.preprocess()
config['Global']['distributed'] = dist.get_world_size() != 1
paddle.disable_static(device)
# save_config
os.makedirs(config['Global']['save_model_dir'], exist_ok=True)
with open(
os.path.join(config['Global']['save_model_dir'], 'config.yml'),
'w') as f:
yaml.dump(dict(config), f, default_flow_style=False, sort_keys=False)
logger = get_logger(
log_file='{}/train.log'.format(config['Global']['save_model_dir']))
if config['Global']['use_visualdl']:
from visualdl import LogWriter
vdl_writer = LogWriter(logdir=config['Global']['save_model_dir'])
else:
vdl_writer = None
print_dict(config, logger)
logger.info('train with paddle {} and device {}'.format(paddle.__version__,
device))
main(config, device, logger, vdl_writer)
# test_reader(config, device, logger, config['Global'])
if __name__ == '__main__':
# main()
# dist.spawn(dis_main, nprocs=2, selelcted_gpus='6,7')
dis_main()
config, device, logger, vdl_writer = program.preprocess()
main(config, device, logger, vdl_writer)
# test_reader(config, device, logger)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册