未验证 提交 50c1302b 编写于 作者: C cuicheng01 提交者: GitHub

Merge pull request #1918 from TingquanGao/dev/multi-scale

Dev/multi scale
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 120
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
# training model under @to_static
to_static: False
# model architecture
Arch:
name: MobileNetV1
class_num: 1000
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
weight: 1.0
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: Momentum
momentum: 0.9
lr:
name: Piecewise
learning_rate: 0.1
decay_epochs: [30, 60, 90]
values: [0.1, 0.01, 0.001, 0.0001]
regularizer:
name: 'L2'
coeff: 0.00003
# data loader for train and eval
DataLoader:
Train:
dataset:
name: MultiScaleDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 224
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
# support to specify width and height respectively:
# scales: [(160,160), (192,192), (224,224) (288,288) (320,320)]
sampler:
name: MultiScaleSampler
scales: [160, 192, 224, 288, 320]
# first_bs: batch size for the first image resolution in the scales list
# divide_factor: to ensure the width and height dimensions can be devided by downsampling multiple
first_bs: 64
divided_factor: 32
is_training: True
loader:
num_workers: 4
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: docs/images/whl/demo.jpg
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 5
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
Metric:
Train:
- TopkAcc:
topk: [1, 5]
Eval:
- TopkAcc:
topk: [1, 5]
......@@ -28,12 +28,15 @@ from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild
from ppcls.data.dataloader.logo_dataset import LogoDataset
from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
from ppcls.data.dataloader.mix_dataset import MixDataset
from ppcls.data.dataloader.multi_scale_dataset import MultiScaleDataset
from ppcls.data.dataloader.person_dataset import Market1501, MSMT17
# sampler
from ppcls.data.dataloader.DistributedRandomIdentitySampler import DistributedRandomIdentitySampler
from ppcls.data.dataloader.pk_sampler import PKSampler
from ppcls.data.dataloader.mix_sampler import MixSampler
from ppcls.data.dataloader.multi_scale_sampler import MultiScaleSampler
from ppcls.data import preprocess
from ppcls.data.preprocess import transform
......
......@@ -5,6 +5,8 @@ from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild
from ppcls.data.dataloader.logo_dataset import LogoDataset
from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
from ppcls.data.dataloader.mix_dataset import MixDataset
from ppcls.data.dataloader.multi_scale_dataset import MultiScaleDataset
from ppcls.data.dataloader.mix_sampler import MixSampler
from ppcls.data.dataloader.multi_scale_sampler import MultiScaleSampler
from ppcls.data.dataloader.pk_sampler import PKSampler
from ppcls.data.dataloader.person_dataset import Market1501, MSMT17
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import os
from paddle.io import Dataset
from paddle.vision import transforms
import cv2
import warnings
from ppcls.data import preprocess
from ppcls.data.preprocess import transform
from ppcls.data.preprocess.ops.operators import DecodeImage
from ppcls.utils import logger
from ppcls.data.dataloader.common_dataset import create_operators
class MultiScaleDataset(Dataset):
def __init__(
self,
image_root,
cls_label_path,
transform_ops=None, ):
self._img_root = image_root
self._cls_path = cls_label_path
self.transform_ops = transform_ops
self.images = []
self.labels = []
self._load_anno()
self.has_crop_flag = 1
def _load_anno(self, seed=None):
assert os.path.exists(self._cls_path)
assert os.path.exists(self._img_root)
self.images = []
self.labels = []
with open(self._cls_path) as fd:
lines = fd.readlines()
if seed is not None:
np.random.RandomState(seed).shuffle(lines)
for l in lines:
l = l.strip().split(" ")
self.images.append(os.path.join(self._img_root, l[0]))
self.labels.append(np.int64(l[1]))
assert os.path.exists(self.images[-1])
def __getitem__(self, properties):
# properites is a tuple, contains (width, height, index)
img_width = properties[0]
img_height = properties[1]
index = properties[2]
has_crop = False
if self.transform_ops:
for i in range(len(self.transform_ops)):
op = self.transform_ops[i]
resize_op = ['RandCropImage', 'ResizeImage', 'CropImage']
for resize in resize_op:
if resize in op:
if self.has_crop_flag:
logger.warning(
"Multi scale dataset will crop image according to the multi scale resolution"
)
self.transform_ops[i][resize] = {
'size': (img_width, img_height)
}
has_crop = True
self.has_crop_flag = 0
if has_crop == False:
logger.error("Multi scale dateset requests RandCropImage")
raise RuntimeError("Multi scale dateset requests RandCropImage")
self._transform_ops = create_operators(self.transform_ops)
try:
with open(self.images[index], 'rb') as f:
img = f.read()
if self._transform_ops:
img = transform(img, self._transform_ops)
img = img.transpose((2, 0, 1))
return (img, self.labels[index])
except Exception as ex:
logger.error("Exception occured when parse line: {} with msg: {}".
format(self.images[index], ex))
rnd_idx = np.random.randint(self.__len__())
return self.__getitem__(rnd_idx)
def __len__(self):
return len(self.images)
@property
def class_num(self):
return len(set(self.labels))
from paddle.io import Sampler
import paddle.distributed as dist
import math
import random
import numpy as np
from ppcls import data
class MultiScaleSampler(Sampler):
def __init__(self,
data_source,
scales,
first_bs,
divided_factor=32,
is_training=True,
seed=None):
"""
multi scale samper
Args:
data_source(dataset)
scales(list): several scales for image resolution
first_bs(int): batch size for the first scale in scales
divided_factor(int): ImageNet models down-sample images by a factor, ensure that width and height dimensions are multiples are multiple of devided_factor.
is_training(boolean): mode
"""
# min. and max. spatial dimensions
self.data_source = data_source
self.n_data_samples = len(self.data_source)
if isinstance(scales[0], tuple):
width_dims = [i[0] for i in scales]
height_dims = [i[1] for i in scales]
elif isinstance(scales[0], int):
width_dims = scales
height_dims = scales
base_im_w = width_dims[0]
base_im_h = height_dims[0]
base_batch_size = first_bs
# Get the GPU and node related information
num_replicas = dist.get_world_size()
rank = dist.get_rank()
# adjust the total samples to avoid batch dropping
num_samples_per_replica = int(
math.ceil(self.n_data_samples * 1.0 / num_replicas))
img_indices = [idx for idx in range(self.n_data_samples)]
self.shuffle = False
if is_training:
# compute the spatial dimensions and corresponding batch size
# ImageNet models down-sample images by a factor of 32.
# Ensure that width and height dimensions are multiples are multiple of 32.
width_dims = [
int((w // divided_factor) * divided_factor) for w in width_dims
]
height_dims = [
int((h // divided_factor) * divided_factor)
for h in height_dims
]
img_batch_pairs = list()
base_elements = base_im_w * base_im_h * base_batch_size
for (h, w) in zip(height_dims, width_dims):
batch_size = int(max(1, (base_elements / (h * w))))
img_batch_pairs.append((w, h, batch_size))
self.img_batch_pairs = img_batch_pairs
self.shuffle = True
else:
self.img_batch_pairs = [(base_im_w, base_im_h, base_batch_size)]
self.img_indices = img_indices
self.n_samples_per_replica = num_samples_per_replica
self.epoch = 0
self.rank = rank
self.num_replicas = num_replicas
self.seed = seed
self.batch_list = []
self.current = 0
indices_rank_i = self.img_indices[self.rank:len(self.img_indices):
self.num_replicas]
while self.current < self.n_samples_per_replica:
curr_w, curr_h, curr_bsz = random.choice(self.img_batch_pairs)
end_index = min(self.current + curr_bsz,
self.n_samples_per_replica)
batch_ids = indices_rank_i[self.current:end_index]
n_batch_samples = len(batch_ids)
if n_batch_samples != curr_bsz:
batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
self.current += curr_bsz
if len(batch_ids) > 0:
batch = [curr_w, curr_h, len(batch_ids)]
self.batch_list.append(batch)
self.length = len(self.batch_list)
def __iter__(self):
if self.shuffle:
if self.seed is not None:
random.seed(self.seed)
else:
random.seed(self.epoch)
random.shuffle(self.img_indices)
random.shuffle(self.img_batch_pairs)
indices_rank_i = self.img_indices[self.rank:len(self.img_indices):
self.num_replicas]
else:
indices_rank_i = self.img_indices[self.rank:len(self.img_indices):
self.num_replicas]
start_index = 0
for batch_tuple in self.batch_list:
curr_w, curr_h, curr_bsz = batch_tuple
end_index = min(start_index + curr_bsz, self.n_samples_per_replica)
batch_ids = indices_rank_i[start_index:end_index]
n_batch_samples = len(batch_ids)
if n_batch_samples != curr_bsz:
batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
start_index += curr_bsz
if len(batch_ids) > 0:
batch = [(curr_w, curr_h, b_id) for b_id in batch_ids]
yield batch
def set_epoch(self, epoch: int):
self.epoch = epoch
def __len__(self):
return self.length
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册