未验证 提交 b3066812 编写于 作者: X xiaoting 提交者: GitHub

Multi scale (#9837)

* update for multi scale

* update for multi scale

* update for multi scale

* rm notes
上级 ded37403
......@@ -73,7 +73,8 @@ Metric:
Train:
dataset:
name: SimpleDataSet
name: MultiScaleDataSet
ds_width: false
data_dir: ./train_data/
ext_op_transform_idx: 1
label_file_list:
......@@ -90,8 +91,6 @@ Train:
- RecAug:
- MultiLabelEncode:
gtc_encode: NRTRLabelEncode
- RecResizeImg:
image_shape: [3, 48, 320]
- KeepKeys:
keep_keys:
- image
......@@ -99,11 +98,18 @@ Train:
- label_gtc
- length
- valid_ratio
sampler:
name: MultiScaleSampler
scales: [[320, 32], [320, 48], [320, 64]]
first_bs: &bs 128
fix_bs: false
divided_factor: [8, 16] # w, h
is_training: True
loader:
shuffle: true
batch_size_per_card: 128
batch_size_per_card: *bs
drop_last: true
num_workers: 4
num_workers: 8
Eval:
dataset:
name: SimpleDataSet
......@@ -115,9 +121,13 @@ Eval:
img_mode: BGR
channel_first: false
- MultiLabelEncode:
max_text_length: 100
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
use_space_char: True
gtc_encode: NRTRLabelEncode
- RecResizeImg:
image_shape: [3, 48, 320]
eval_mode: True
- KeepKeys:
keep_keys:
- image
......@@ -128,5 +138,5 @@ Eval:
loader:
shuffle: false
drop_last: false
batch_size_per_card: 128
batch_size_per_card: 1
num_workers: 4
......@@ -72,9 +72,9 @@ Metric:
Train:
dataset:
name: SimpleDataSet
name: MultiScaleDataSet
ds_width: false
data_dir: ./train_data/
ext_op_transform_idx: 1
label_file_list:
- ./train_data/train_list.txt
transforms:
......@@ -89,8 +89,6 @@ Train:
- RecAug:
- MultiLabelEncode:
gtc_encode: NRTRLabelEncode
- RecResizeImg:
image_shape: [3, 48, 320]
- KeepKeys:
keep_keys:
- image
......@@ -98,6 +96,14 @@ Train:
- label_gtc
- length
- valid_ratio
sampler:
name: MultiScaleSampler
scales: [[320, 32], [320, 48], [320, 64]]
# divide_factor: to ensure the width and height dimensions can be devided by downsampling multiple
first_bs: &bs 128
fix_bs: false
divided_factor: [8, 16] # w, h
is_training: True
loader:
shuffle: true
batch_size_per_card: 128
......@@ -114,9 +120,13 @@ Eval:
img_mode: BGR
channel_first: false
- MultiLabelEncode:
max_text_length: 100
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
use_space_char: True
gtc_encode: NRTRLabelEncode
- RecResizeImg:
image_shape: [3, 48, 320]
eval_mode: True
- KeepKeys:
keep_keys:
- image
......
......@@ -33,10 +33,11 @@ from paddle.io import Dataset, DataLoader, BatchSampler, DistributedBatchSampler
import paddle.distributed as dist
from ppocr.data.imaug import transform, create_operators
from ppocr.data.simple_dataset import SimpleDataSet
from ppocr.data.simple_dataset import SimpleDataSet, MultiScaleDataSet
from ppocr.data.lmdb_dataset import LMDBDataSet, LMDBDataSetSR, LMDBDataSetTableMaster
from ppocr.data.pgnet_dataset import PGDataSet
from ppocr.data.pubtab_dataset import PubTabDataSet
from ppocr.data.multi_scale_sampler import MultiScaleSampler
__all__ = ['build_dataloader', 'transform', 'create_operators']
......@@ -55,7 +56,7 @@ def build_dataloader(config, mode, device, logger, seed=None):
support_dict = [
'SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet',
'LMDBDataSetSR', 'LMDBDataSetTableMaster'
'LMDBDataSetSR', 'LMDBDataSetTableMaster', 'MultiScaleDataSet'
]
module_name = config[mode]['dataset']['name']
assert module_name in support_dict, Exception(
......@@ -76,6 +77,11 @@ def build_dataloader(config, mode, device, logger, seed=None):
if mode == "Train":
# Distribute data to multiple cards
if 'sampler' in config[mode]:
config_sampler = config[mode]['sampler']
sampler_name = config_sampler.pop("name")
batch_sampler = eval(sampler_name)(dataset, **config_sampler)
else:
batch_sampler = DistributedBatchSampler(
dataset=dataset,
batch_size=batch_size,
......
......@@ -73,7 +73,7 @@ def create_operators(op_param_list, global_config=None):
dict) and len(operator) == 1, "yaml format error"
op_name = list(operator)[0]
param = {} if operator[op_name] is None else operator[op_name]
if global_config is not None:
if global_config is not None and "max_text_length" not in param:
param.update(global_config)
op = eval(op_name)(**param)
ops.append(op)
......
......@@ -219,17 +219,20 @@ class RecResizeImg(object):
def __init__(self,
image_shape,
infer_mode=False,
eval_mode=False,
character_dict_path='./ppocr/utils/ppocr_keys_v1.txt',
padding=True,
**kwargs):
self.image_shape = image_shape
self.infer_mode = infer_mode
self.eval_mode = eval_mode
self.character_dict_path = character_dict_path
self.padding = padding
def __call__(self, data):
img = data['image']
if self.infer_mode and self.character_dict_path is not None:
if self.eval_mode or (self.infer_mode and
self.character_dict_path is not None):
norm_img, valid_ratio = resize_norm_img_chinese(img,
self.image_shape)
else:
......
from paddle.io import Sampler
import paddle.distributed as dist
import numpy as np
import random
import math
class MultiScaleSampler(Sampler):
def __init__(self,
data_source,
scales,
first_bs=128,
fix_bs=True,
divided_factor=[8, 16],
is_training=True,
ratio_wh=0.8,
max_w=480.,
seed=None):
"""
multi scale samper
Args:
data_source(dataset)
scales(list): several scales for image resolution
first_bs(int): batch size for the first scale in scales
divided_factor(list[w, h]): ImageNet models down-sample images by a factor, ensure that width and height dimensions are multiples are multiple of devided_factor.
is_training(boolean): mode
"""
# min. and max. spatial dimensions
self.data_source = data_source
self.data_idx_order_list = np.array(data_source.data_idx_order_list)
self.ds_width = data_source.ds_width
self.seed = data_source.seed
if self.ds_width:
self.wh_ratio = data_source.wh_ratio
self.wh_ratio_sort = data_source.wh_ratio_sort
self.n_data_samples = len(self.data_source)
self.ratio_wh = ratio_wh
self.max_w = max_w
if isinstance(scales[0], list):
width_dims = [i[0] for i in scales]
height_dims = [i[1] for i in scales]
elif isinstance(scales[0], int):
width_dims = scales
height_dims = scales
base_im_w = width_dims[0]
base_im_h = height_dims[0]
base_batch_size = first_bs
# Get the GPU and node related information
num_replicas = dist.get_world_size()
rank = dist.get_rank()
# adjust the total samples to avoid batch dropping
num_samples_per_replica = int(
math.ceil(self.n_data_samples * 1.0 / num_replicas))
img_indices = [idx for idx in range(self.n_data_samples)]
self.shuffle = False
if is_training:
# compute the spatial dimensions and corresponding batch size
# ImageNet models down-sample images by a factor of 32.
# Ensure that width and height dimensions are multiples are multiple of 32.
width_dims = [
int((w // divided_factor[0]) * divided_factor[0])
for w in width_dims
]
height_dims = [
int((h // divided_factor[1]) * divided_factor[1])
for h in height_dims
]
img_batch_pairs = list()
base_elements = base_im_w * base_im_h * base_batch_size
for (h, w) in zip(height_dims, width_dims):
if fix_bs:
batch_size = base_batch_size
else:
batch_size = int(max(1, (base_elements / (h * w))))
img_batch_pairs.append((w, h, batch_size))
self.img_batch_pairs = img_batch_pairs
self.shuffle = True
else:
self.img_batch_pairs = [(base_im_w, base_im_h, base_batch_size)]
self.img_indices = img_indices
self.n_samples_per_replica = num_samples_per_replica
self.epoch = 0
self.rank = rank
self.num_replicas = num_replicas
self.batch_list = []
self.current = 0
indices_rank_i = self.img_indices[self.rank:len(self.img_indices):
self.num_replicas]
while self.current < self.n_samples_per_replica:
curr_w, curr_h, curr_bsz = random.choice(self.img_batch_pairs)
end_index = min(self.current + curr_bsz, self.n_samples_per_replica)
batch_ids = indices_rank_i[self.current:end_index]
n_batch_samples = len(batch_ids)
if n_batch_samples != curr_bsz:
batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
self.current += curr_bsz
if len(batch_ids) > 0:
batch = [curr_w, curr_h, len(batch_ids)]
self.batch_list.append(batch)
self.length = len(self.batch_list)
self.batchs_in_one_epoch = self.iter()
self.batchs_in_one_epoch_id = [
i for i in range(len(self.batchs_in_one_epoch))
]
def __iter__(self):
if self.seed is None:
random.seed(self.epoch)
self.epoch += 1
else:
random.seed(self.seed)
random.shuffle(self.batchs_in_one_epoch_id)
for batch_tuple_id in self.batchs_in_one_epoch_id:
yield self.batchs_in_one_epoch[batch_tuple_id]
def iter(self):
if self.shuffle:
if self.seed is not None:
random.seed(self.seed)
else:
random.seed(self.epoch)
if not self.ds_width:
random.shuffle(self.img_indices)
random.shuffle(self.img_batch_pairs)
indices_rank_i = self.img_indices[self.rank:len(self.img_indices):
self.num_replicas]
else:
indices_rank_i = self.img_indices[self.rank:len(self.img_indices):
self.num_replicas]
start_index = 0
batchs_in_one_epoch = []
for batch_tuple in self.batch_list:
curr_w, curr_h, curr_bsz = batch_tuple
end_index = min(start_index + curr_bsz, self.n_samples_per_replica)
batch_ids = indices_rank_i[start_index:end_index]
n_batch_samples = len(batch_ids)
if n_batch_samples != curr_bsz:
batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
start_index += curr_bsz
if len(batch_ids) > 0:
if self.ds_width:
wh_ratio_current = self.wh_ratio[self.wh_ratio_sort[
batch_ids]]
ratio_current = wh_ratio_current.mean()
ratio_current = ratio_current if ratio_current * curr_h < self.max_w else self.max_w / curr_h
else:
ratio_current = None
batch = [(curr_w, curr_h, b_id, ratio_current)
for b_id in batch_ids]
# yield batch
batchs_in_one_epoch.append(batch)
return batchs_in_one_epoch
def set_epoch(self, epoch: int):
self.epoch = epoch
def __len__(self):
return self.length
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import cv2
import math
import os
import json
import random
......@@ -163,3 +165,96 @@ class SimpleDataSet(Dataset):
def __len__(self):
return len(self.data_idx_order_list)
class MultiScaleDataSet(SimpleDataSet):
def __init__(self, config, mode, logger, seed=None):
super(MultiScaleDataSet, self).__init__(config, mode, logger, seed)
self.ds_width = config[mode]['dataset'].get('ds_width', False)
if self.ds_width:
self.wh_aware()
def wh_aware(self):
data_line_new = []
wh_ratio = []
for lins in self.data_lines:
data_line_new.append(lins)
lins = lins.decode('utf-8')
name, label, w, h = lins.strip("\n").split(self.delimiter)
wh_ratio.append(float(w) / float(h))
self.data_lines = data_line_new
self.wh_ratio = np.array(wh_ratio)
self.wh_ratio_sort = np.argsort(self.wh_ratio)
self.data_idx_order_list = list(range(len(self.data_lines)))
def resize_norm_img(self, data, imgW, imgH, padding=True):
img = data['image']
h = img.shape[0]
w = img.shape[1]
if not padding:
resized_image = cv2.resize(
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_w = imgW
else:
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
padding_im = np.zeros((3, imgH, imgW), dtype=np.float32)
padding_im[:, :, :resized_w] = resized_image
valid_ratio = min(1.0, float(resized_w / imgW))
data['image'] = padding_im
data['valid_ratio'] = valid_ratio
return data
def __getitem__(self, properties):
# properites is a tuple, contains (width, height, index)
img_height = properties[1]
idx = properties[2]
if self.ds_width and properties[3] is not None:
wh_ratio = properties[3]
img_width = img_height * (1 if int(round(wh_ratio)) == 0 else
int(round(wh_ratio)))
file_idx = self.wh_ratio_sort[idx]
else:
file_idx = self.data_idx_order_list[idx]
img_width = properties[0]
wh_ratio = None
data_line = self.data_lines[file_idx]
try:
data_line = data_line.decode('utf-8')
substr = data_line.strip("\n").split(self.delimiter)
file_name = substr[0]
file_name = self._try_parse_filename_list(file_name)
label = substr[1]
img_path = os.path.join(self.data_dir, file_name)
data = {'img_path': img_path, 'label': label}
if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path))
with open(data['img_path'], 'rb') as f:
img = f.read()
data['image'] = img
data['ext_data'] = self.get_ext_data()
outs = transform(data, self.ops[:-1])
if outs is not None:
outs = self.resize_norm_img(outs, img_width, img_height)
outs = transform(outs, self.ops[-1:])
except:
self.logger.error(
"When parsing line {}, error happened with msg: {}".format(
data_line, traceback.format_exc()))
outs = None
if outs is None:
# during evaluation, we should fix the idx to get same results for many times of evaluation.
rnd_idx = (idx + 1) % self.__len__()
return self.__getitem__([img_width, img_height, rnd_idx, wh_ratio])
return outs
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册