From f2dde17623a8b05af83851272f34d7beaafdd20b Mon Sep 17 00:00:00 2001
From: sibo2rr <1415419833@qq.com>
Date: Thu, 16 Dec 2021 15:25:23 +0800
Subject: [PATCH] multi scale sampler and dataset

---
 .../multi_scale/MobileNetV1_multi_scale.yaml  | 134 ++++++++++++++++++
 ppcls/data/__init__.py                        |   2 +
 ppcls/data/dataloader/__init__.py             |   2 +
 ppcls/data/dataloader/multi_scale_dataset.py  | 119 ++++++++++++++++
 ppcls/data/dataloader/multi_scale_sampler.py  | 105 ++++++++++++++
 5 files changed, 362 insertions(+)
 create mode 100644 ppcls/configs/multi_scale/MobileNetV1_multi_scale.yaml
 create mode 100644 ppcls/data/dataloader/multi_scale_dataset.py
 create mode 100644 ppcls/data/dataloader/multi_scale_sampler.py

diff --git a/ppcls/configs/multi_scale/MobileNetV1_multi_scale.yaml b/ppcls/configs/multi_scale/MobileNetV1_multi_scale.yaml
new file mode 100644
index 00000000..6623e4ce
--- /dev/null
+++ b/ppcls/configs/multi_scale/MobileNetV1_multi_scale.yaml
@@ -0,0 +1,134 @@
+# global configs
+Global:
+  checkpoints: null
+  pretrained_model: null
+  output_dir: ./output/
+  device: gpu
+  save_interval: 1
+  eval_during_train: True
+  eval_interval: 1
+  epochs: 120
+  print_batch_step: 10
+  use_visualdl: False
+  # used for static mode and model export
+  image_shape: [3, 224, 224]
+  save_inference_dir: ./inference
+  # training model under @to_static
+  to_static: False
+  use_dali: True
+
+# model architecture
+Arch:
+  name: MobileNetV1
+  class_num: 100
+
+# loss function config for training/eval process
+Loss:
+  Train:
+    - CELoss:
+        weight: 1.0
+  Eval:
+    - CELoss:
+        weight: 1.0
+
+
+Optimizer:
+  name: Momentum
+  momentum: 0.9
+  lr:
+    name: Piecewise
+    learning_rate: 0.1
+    decay_epochs: [30, 60, 90]
+    values: [0.1, 0.01, 0.001, 0.0001]
+  regularizer:
+    name: 'L2'
+    coeff: 0.00003
+
+
+# data loader for train and eval
+DataLoader:
+  Train:
+    dataset:
+      name: MultiScaleDataset
+      image_root: ./dataset/ILSVRC2012/
+      cls_label_path: ./dataset/ILSVRC2012/train_list.txt
+      transform_ops:
+        - DecodeImage:
+            to_rgb: True
+            channel_first: False
+        - RandCropImage:
+            size: 224
+        - RandFlipImage:
+            flip_code: 1
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+
+    sampler:
+      name: MultiScaleSamplerDDP
+      scales: [224, 256]
+      first_bs: 4
+      is_training: True
+
+    loader:
+      num_workers: 4
+      use_shared_memory: True
+
+  Eval:
+    dataset:
+      name: ImageNetDataset
+      image_root: ./dataset/ILSVRC2012/
+      cls_label_path: ./dataset/ILSVRC2012/val_list.txt
+      transform_ops:
+        - DecodeImage:
+            to_rgb: True
+            channel_first: False
+        - ResizeImage:
+            resize_short: 256
+        - CropImage:
+            size: 224
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+    sampler:
+      name: DistributedBatchSampler
+      batch_size: 64
+      drop_last: False
+      shuffle: False
+    loader:
+      num_workers: 4
+      use_shared_memory: True
+
+Infer:
+  infer_imgs: docs/images/whl/demo.jpg
+  batch_size: 10
+  transforms:
+    - DecodeImage:
+        to_rgb: True
+        channel_first: False
+    - ResizeImage:
+        resize_short: 256
+    - CropImage:
+        size: 224
+    - NormalizeImage:
+        scale: 1.0/255.0
+        mean: [0.485, 0.456, 0.406]
+        std: [0.229, 0.224, 0.225]
+        order: ''
+    - ToCHWImage:
+  PostProcess:
+    name: Topk
+    topk: 5
+    class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
+
+Metric:
+  Train:
+    - TopkAcc:
+        topk: [1, 5]
+  Eval:
+    - TopkAcc:
+        topk: [1, 5]
diff --git a/ppcls/data/__init__.py b/ppcls/data/__init__.py
index cffac812..50921368 100644
--- a/ppcls/data/__init__.py
+++ b/ppcls/data/__init__.py
@@ -28,11 +28,13 @@ from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild
 from ppcls.data.dataloader.logo_dataset import LogoDataset
 from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
 from ppcls.data.dataloader.mix_dataset import MixDataset
+from ppcls.data.dataloader.multi_scale_dataset import MultiScaleDataset
 
 # sampler
 from ppcls.data.dataloader.DistributedRandomIdentitySampler import DistributedRandomIdentitySampler
 from ppcls.data.dataloader.pk_sampler import PKSampler
 from ppcls.data.dataloader.mix_sampler import MixSampler
+from ppcls.data.dataloader.multi_scale_sampler import MultiScaleSamplerDDP
 
 from ppcls.data import preprocess
 from ppcls.data.preprocess import transform
diff --git a/ppcls/data/dataloader/__init__.py b/ppcls/data/dataloader/__init__.py
index 8f819210..bfbeb40f 100644
--- a/ppcls/data/dataloader/__init__.py
+++ b/ppcls/data/dataloader/__init__.py
@@ -5,5 +5,7 @@ from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild
 from ppcls.data.dataloader.logo_dataset import LogoDataset
 from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
 from ppcls.data.dataloader.mix_dataset import MixDataset
+from ppcls.data.dataloader.multi_scale_dataset import MultiScaleDataset
 from ppcls.data.dataloader.mix_sampler import MixSampler
+from ppcls.data.dataloader.multi_scale_sampler import MultiScaleSamplerDDP
 from ppcls.data.dataloader.pk_sampler import PKSampler
diff --git a/ppcls/data/dataloader/multi_scale_dataset.py b/ppcls/data/dataloader/multi_scale_dataset.py
new file mode 100644
index 00000000..9a8809f5
--- /dev/null
+++ b/ppcls/data/dataloader/multi_scale_dataset.py
@@ -0,0 +1,119 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import numpy as np
+import os
+
+from paddle.io import Dataset
+from paddle.vision import transforms
+import cv2
+import warnings
+
+from ppcls.data import preprocess
+from ppcls.data.preprocess import transform
+from ppcls.data.preprocess.ops.operators import DecodeImage
+from ppcls.utils import logger
+
+
+def create_operators(params):
+    """
+    create operators based on the config
+
+    Args:
+        params(list): a dict list, used to create some operators
+    """
+    assert isinstance(params, list), ('operator config should be a list')
+    ops = []
+    for operator in params:
+        assert isinstance(operator,
+                          dict) and len(operator) == 1, "yaml format error"
+        op_name = list(operator)[0]
+        param = {} if operator[op_name] is None else operator[op_name]
+        op = getattr(preprocess, op_name)(**param)
+        ops.append(op)
+
+    return ops
+
+
+class MultiScaleDataset(Dataset):
+    def __init__(
+            self,
+            image_root,
+            cls_label_path,
+            transform_ops=None, ):
+        self._img_root = image_root
+        self._cls_path = cls_label_path
+        self.transform_ops = transform_ops
+        # if transform_ops:
+        #     self._transform_ops = create_operators(transform_ops)
+
+        self.images = []
+        self.labels = []
+        self._load_anno()
+
+    def _load_anno(self, seed=None):
+        assert os.path.exists(self._cls_path)
+        assert os.path.exists(self._img_root)
+        self.images = []
+        self.labels = []
+
+        with open(self._cls_path) as fd:
+            lines = fd.readlines()
+            if seed is not None:
+                np.random.RandomState(seed).shuffle(lines)
+            for l in lines:
+                l = l.strip().split(" ")
+                self.images.append(os.path.join(self._img_root, l[0]))
+                self.labels.append(np.int64(l[1]))
+                assert os.path.exists(self.images[-1])
+
+    def __getitem__(self, properties):
+        # properties is a tuple of (width, height, index) produced by
+        # MultiScaleSamplerDDP
+        img_width = properties[0]
+        img_height = properties[1]
+        index = properties[2]
+        has_crop = False
+        if self.transform_ops:
+            for i in range(len(self.transform_ops)):
+                op = self.transform_ops[i]
+                if 'RandCropImage' in op:
+                    warnings.warn(
+                        "Multi scale dataset will crop the image according to the sampled resolution"
+                    )
+                    # overwrite the crop size with the resolution chosen by the sampler
+                    self.transform_ops[i]['RandCropImage'] = {
+                        'size': img_width
+                    }
+                    has_crop = True
+        if not has_crop:
+            raise RuntimeError(
+                "Multi scale dataset requires a RandCropImage transform")
+        self._transform_ops = create_operators(self.transform_ops)
+
+        try:
+            with open(self.images[index], 'rb') as f:
+                img = f.read()
+            if self._transform_ops:
+                img = transform(img, self._transform_ops)
+            img = img.transpose((2, 0, 1))
+            return (img, self.labels[index])
+        except Exception as ex:
+            logger.error("Exception occurred when parsing line: {} with msg: {}".
+                         format(self.images[index], ex))
+            rnd_idx = np.random.randint(self.__len__())
+            # keep the sampled resolution when retrying with a random image
+            return self.__getitem__((img_width, img_height, rnd_idx))
+
+    def __len__(self):
+        return len(self.images)
+
+    @property
+    def class_num(self):
+        return len(set(self.labels))
diff --git a/ppcls/data/dataloader/multi_scale_sampler.py b/ppcls/data/dataloader/multi_scale_sampler.py
new file mode 100644
index 00000000..68011378
--- /dev/null
+++ b/ppcls/data/dataloader/multi_scale_sampler.py
@@ -0,0 +1,105 @@
+from paddle.io import Sampler
+import paddle.distributed as dist
+
+import math
+import random
+import numpy as np
+
+from ppcls import data
+
+
+class MultiScaleSamplerDDP(Sampler):
+    def __init__(self, data_source, scales, first_bs, is_training=True):
+        # min. and max. spatial dimensions
+        self.data_source = data_source
+        self.n_data_samples = len(self.data_source)
+
+        if isinstance(scales[0], tuple):
+            width_dims = [i[0] for i in scales]
+            height_dims = [i[1] for i in scales]
+        elif isinstance(scales[0], int):
+            width_dims = scales
+            height_dims = scales
+        base_im_w = width_dims[0]
+        base_im_h = height_dims[0]
+        base_batch_size = first_bs
+
+        # Get the GPU and node related information
+        num_replicas = dist.get_world_size()
+        rank = dist.get_rank()
+        # adjust the total samples to avoid batch dropping
+        num_samples_per_replica = int(
+            math.ceil(self.n_data_samples * 1.0 / num_replicas))
+        img_indices = [idx for idx in range(self.n_data_samples)]
+
+        self.shuffle = False
+        if is_training:
+            # compute the spatial dimensions and corresponding batch size
+            # ImageNet models down-sample images by a factor of 32.
+            # Ensure that width and height dimensions are multiples of 32.
+            width_dims = [int((w // 32) * 32) for w in width_dims]
+            height_dims = [int((h // 32) * 32) for h in height_dims]
+
+            # keep h * w * batch_size roughly constant across scales
+            img_batch_pairs = list()
+            base_elements = base_im_w * base_im_h * base_batch_size
+            for (h, w) in zip(height_dims, width_dims):
+                batch_size = int(max(1, (base_elements / (h * w))))
+                img_batch_pairs.append((h, w, batch_size))
+            self.img_batch_pairs = img_batch_pairs
+            self.shuffle = True
+        else:
+            self.img_batch_pairs = [(base_im_h, base_im_w, base_batch_size)]
+
+        self.img_indices = img_indices
+        self.n_samples_per_replica = num_samples_per_replica
+        self.epoch = 0
+        self.rank = rank
+        self.num_replicas = num_replicas
+
+        # pre-compute the (height, width, batch_size) of every batch so that
+        # __len__ can report the number of batches per replica
+        self.batch_list = []
+        self.current = 0
+        indices_rank_i = self.img_indices[self.rank:len(self.img_indices):
+                                          self.num_replicas]
+        while self.current < self.n_samples_per_replica:
+            curr_h, curr_w, curr_bsz = random.choice(self.img_batch_pairs)
+
+            end_index = min(self.current + curr_bsz,
+                            self.n_samples_per_replica)
+
+            batch_ids = indices_rank_i[self.current:end_index]
+            n_batch_samples = len(batch_ids)
+            if n_batch_samples != curr_bsz:
+                batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
+            self.current += curr_bsz
+
+            if len(batch_ids) > 0:
+                batch = [curr_h, curr_w, len(batch_ids)]
+                self.batch_list.append(batch)
+        self.length = len(self.batch_list)
+
+    def __iter__(self):
+        if self.shuffle:
+            random.seed(self.epoch)
+            random.shuffle(self.img_indices)
+            random.shuffle(self.img_batch_pairs)
+        indices_rank_i = self.img_indices[self.rank:len(self.img_indices):
+                                          self.num_replicas]
+
+        start_index = 0
+        for batch_tuple in self.batch_list:
+            curr_h, curr_w, curr_bsz = batch_tuple
+            end_index = min(start_index + curr_bsz,
+                            self.n_samples_per_replica)
+            batch_ids = indices_rank_i[start_index:end_index]
+            n_batch_samples = len(batch_ids)
+            if n_batch_samples != curr_bsz:
+                batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
+            start_index += curr_bsz
+
+            if len(batch_ids) > 0:
+                # each sample in the batch carries its target resolution
+                batch = [(curr_h, curr_w, b_id) for b_id in batch_ids]
+                yield batch
+
+    def set_epoch(self, epoch: int):
+        self.epoch = epoch
+
+    def __len__(self):
+        return self.length
-- 
GitLab
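
Usage note (not part of the patch): a minimal sketch of how the new sampler and dataset are expected to be wired together, assuming the ILSVRC2012 layout referenced in the config above is present locally and that ppcls is importable. The sampler yields batches of (height, width, index) tuples, so batches at larger resolutions automatically contain fewer samples; the sequence of resolutions is random per epoch.

    import paddle

    from ppcls.data.dataloader.multi_scale_dataset import MultiScaleDataset
    from ppcls.data.dataloader.multi_scale_sampler import MultiScaleSamplerDDP

    # Transform list in the same single-key-dict format the YAML config produces.
    transform_ops = [
        {"DecodeImage": {"to_rgb": True, "channel_first": False}},
        {"RandCropImage": {"size": 224}},  # size is overwritten per batch
        {"RandFlipImage": {"flip_code": 1}},
        {"NormalizeImage": {"scale": 1.0 / 255.0,
                            "mean": [0.485, 0.456, 0.406],
                            "std": [0.229, 0.224, 0.225],
                            "order": ""}},
    ]

    dataset = MultiScaleDataset(
        image_root="./dataset/ILSVRC2012/",
        cls_label_path="./dataset/ILSVRC2012/train_list.txt",
        transform_ops=transform_ops)

    # first_bs is the batch size at the first (base) scale; other scales get
    # roughly first_bs * base_area / scale_area samples per batch.
    batch_sampler = MultiScaleSamplerDDP(
        data_source=dataset, scales=[224, 256], first_bs=4, is_training=True)

    # The sampler is passed as batch_sampler, so each yielded tuple reaches
    # MultiScaleDataset.__getitem__ unchanged.
    loader = paddle.io.DataLoader(
        dataset, batch_sampler=batch_sampler, num_workers=4)

    for images, labels in loader:
        # e.g. [4, 3, 224, 224] for a 224 batch, roughly [3, 3, 256, 256] for 256
        print(images.shape, labels.shape)
        break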