dataset.py 3.9 KB
Newer Older
G
gengdongjie 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset preprocessing."""
import os
import math as m
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.c_transforms as c
E
Eric 已提交
22
import mindspore.dataset.vision.c_transforms as vc
G
gengdongjie 已提交
23 24 25 26
from PIL import Image
from src.config import config as cf


Y
yangyongjie 已提交
27
class _CaptchaDataset:
    """
    Random-access source of captcha images and digit labels for warpctc.

    Every ``.png`` file found directly in ``img_root_dir`` is one sample. The
    label is taken from the file name: the text after the first ``-`` of the
    name without its extension, each character being one digit
    (e.g. ``12-4079.png`` -> ``4079``).

    Args:
        img_root_dir(str): root path of images
        max_captcha_digits(int): max number of digits in images.
        device_target(str): platform of training, support Ascend and GPU.

    Raises:
        RuntimeError: if ``img_root_dir`` is not an existing directory.
    """

    def __init__(self, img_root_dir, max_captcha_digits, device_target='Ascend'):
        # isdir (rather than exists) so that a path pointing at a regular file is
        # reported with the same error instead of failing later inside os.listdir.
        if not os.path.isdir(img_root_dir):
            raise RuntimeError("the input image dir {} is invalid!".format(img_root_dir))
        self.img_root_dir = img_root_dir
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> sample mapping identical across runs and across processes.
        self.img_names = sorted(i for i in os.listdir(img_root_dir) if i.endswith('.png'))
        self.max_captcha_digits = max_captcha_digits
        self.target = device_target
        # Padding value for labels: 10 on Ascend, 0 on any other target.
        self.blank = 10 if self.target == 'Ascend' else 0
        # Number of label characters per image, in the same order as img_names.
        # NOTE(review): this uses the text after the LAST '-', while __getitem__
        # uses the text after the FIRST '-'; they agree only for names with at
        # most one '-'.
        self.label_length = [len(os.path.splitext(n)[0].split('-')[-1]) for n in self.img_names]

    def __len__(self):
        """Return the number of ``.png`` samples found in the image directory."""
        return len(self.img_names)

    def __getitem__(self, item):
        """
        Load one sample.

        Args:
            item(int): index into the list of image names.

        Returns:
            tuple of (image, label) numpy arrays. ``image`` is the decoded picture
            with its first and third colour channels swapped. ``label`` holds the
            digits padded with the blank value up to ``max_captcha_digits``; for
            non-Ascend targets each digit is shifted by one and the real label
            length is appended as the last element.
        """
        img_name = self.img_names[item]
        # Image.open keeps the file handle open until the image is closed; the
        # context manager releases it as soon as the channels have been read.
        with Image.open(os.path.join(self.img_root_dir, img_name)) as im:
            r, g, b = im.split()
        # Re-assemble the bands in reversed order (blue first, red last).
        image = np.array(Image.merge("RGB", (b, g, r)))
        label_str = os.path.splitext(img_name)[0]
        # find() returns -1 when there is no '-', so the whole stem is used then.
        label_str = label_str[label_str.find('-') + 1:]
        padding = [int(self.blank)]
        if self.target == 'Ascend':
            label = [int(i) for i in label_str]
            label.extend(padding * (self.max_captcha_digits - len(label)))
        else:
            # Digits are shifted by one because 0 is the blank value on this path.
            label = [int(i) + 1 for i in label_str]
            length = len(label)
            label.extend(padding * (self.max_captcha_digits - length))
            # The unpadded length travels with the label as its final element.
            label.append(length)
        label = np.array(label)
        return image, label


Y
yangyongjie 已提交
70
def create_dataset(dataset_path, batch_size=1, num_shards=1, shard_id=0, device_target='Ascend'):
    """
    Build the preprocessed, batched captcha pipeline used to train or evaluate warpctc.

    Args:
        dataset_path(str): directory that holds the captcha images
        batch_size(int): number of samples per batch, default is 1
        num_shards(int): number of devices the data is split across, default is 1
        shard_id(int): rank id of the current device, default is 0
        device_target(str): platform of training, support Ascend and GPU

    Returns:
        the dataset object returned by the final ``batch`` call, carrying the
        columns "image" and "label".
    """
    captcha_source = _CaptchaDataset(dataset_path, cf.max_captcha_digits, device_target)
    pipeline = de.GeneratorDataset(
        captcha_source,
        ["image", "label"],
        shuffle=True,
        num_shards=num_shards,
        shard_id=shard_id,
    )
    # Each shard is told to hold an equal share of the samples, rounded up.
    samples_per_shard = m.ceil(len(captcha_source) / num_shards)
    pipeline.set_dataset_size(samples_per_shard)

    # Image column: scale to [0, 1], normalise per channel, resize, then HWC -> CHW.
    image_ops = [
        vc.Rescale(1.0 / 255.0, 0.0),
        vc.Normalize([0.9010, 0.9049, 0.9025], std=[0.1521, 0.1347, 0.1458]),
    ]
    # The target height is the configured height rounded up to a multiple of 16.
    aligned_height = m.ceil(cf.captcha_height / 16) * 16
    image_ops.append(vc.Resize((aligned_height, cf.captcha_width)))
    image_ops.append(vc.HWC2CHW())

    # Label column: cast to int32.
    label_ops = [c.TypeCast(mstype.int32)]

    for column, ops in (("image", image_ops), ("label", label_ops)):
        pipeline = pipeline.map(input_columns=[column], num_parallel_workers=8, operations=ops)

    # Incomplete trailing batches are dropped.
    return pipeline.batch(batch_size, drop_remainder=True)