提交 e823f178 编写于 作者: G gaotingquan 提交者: Tingquan Gao

feat: support training image orientation model

add the config of the orientation model
add the preprocess op RandomRot90, which rotates the image and returns the orientation
add the CustomLabelDataset, which supports getting the label from preprocessing
refactor some preprocess ops to accept a dict parameter and return a dict
上级 59a6cfc3
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 10
eval_during_train: True
eval_interval: 10
epochs: 20
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
# model architecture
Arch:
name: PPLCNet_x1_0
pretrained: True
use_ssld: True
class_num: 4
# loss function config for training/eval process
Loss:
Train:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: Momentum
momentum: 0.9
lr:
name: Cosine
learning_rate: 0.56
warmup_epoch: 5
regularizer:
name: 'L2'
coeff: 0.00003
# data loader for train and eval
DataLoader:
Train:
dataset:
name: CustomLabelDataset
image_root: ./dataset/OrientationDataset/
sample_list_path: ./dataset/OrientationDataset/train_list.txt
label_key: random_rot90_orientation
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 224
- TimmAutoAugment:
prob: 0.0
config_str: rand-m9-mstd0.5-inc1
interpolation: bicubic
img_size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- RandomErasing:
EPSILON: 0.0
sl: 0.02
sh: 1.0/3.0
r1: 0.3
attempt: 10
use_log_aspect: True
mode: pixel
- RandomRot90:
sampler:
name: DistributedBatchSampler
batch_size: 512
drop_last: False
shuffle: True
loader:
num_workers: 16
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/OrientationDataset/
cls_label_path: ./dataset/OrientationDataset/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: ./test_img/
batch_size: 1
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 1
Metric:
Train:
- TopkAcc:
topk: [1]
Eval:
- TopkAcc:
topk: [1]
...@@ -31,7 +31,7 @@ from ppcls.data.dataloader.mix_dataset import MixDataset ...@@ -31,7 +31,7 @@ from ppcls.data.dataloader.mix_dataset import MixDataset
from ppcls.data.dataloader.multi_scale_dataset import MultiScaleDataset from ppcls.data.dataloader.multi_scale_dataset import MultiScaleDataset
from ppcls.data.dataloader.person_dataset import Market1501, MSMT17 from ppcls.data.dataloader.person_dataset import Market1501, MSMT17
from ppcls.data.dataloader.face_dataset import FiveValidationDataset, AdaFaceDataset from ppcls.data.dataloader.face_dataset import FiveValidationDataset, AdaFaceDataset
from ppcls.data.dataloader.custom_label_dataset import CustomLabelDataset
# sampler # sampler
from ppcls.data.dataloader.DistributedRandomIdentitySampler import DistributedRandomIdentitySampler from ppcls.data.dataloader.DistributedRandomIdentitySampler import DistributedRandomIdentitySampler
......
...@@ -11,3 +11,4 @@ from ppcls.data.dataloader.multi_scale_sampler import MultiScaleSampler ...@@ -11,3 +11,4 @@ from ppcls.data.dataloader.multi_scale_sampler import MultiScaleSampler
from ppcls.data.dataloader.pk_sampler import PKSampler from ppcls.data.dataloader.pk_sampler import PKSampler
from ppcls.data.dataloader.person_dataset import Market1501, MSMT17 from ppcls.data.dataloader.person_dataset import Market1501, MSMT17
from ppcls.data.dataloader.face_dataset import AdaFaceDataset, FiveValidationDataset from ppcls.data.dataloader.face_dataset import AdaFaceDataset, FiveValidationDataset
from ppcls.data.dataloader.custom_label_dataset import CustomLabelDataset
...@@ -51,7 +51,6 @@ class CommonDataset(Dataset): ...@@ -51,7 +51,6 @@ class CommonDataset(Dataset):
label_ratio=False): label_ratio=False):
self._img_root = image_root self._img_root = image_root
self._cls_path = cls_label_path self._cls_path = cls_label_path
if transform_ops:
self._transform_ops = create_operators(transform_ops) self._transform_ops = create_operators(transform_ops)
self.images = [] self.images = []
...@@ -84,4 +83,4 @@ class CommonDataset(Dataset): ...@@ -84,4 +83,4 @@ class CommonDataset(Dataset):
@property @property
def class_num(self): def class_num(self):
return len(set(self.labels)) return len(set(self.images))
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import numpy as np
from ppcls.data.preprocess import transform
from ppcls.utils import logger
from .common_dataset import CommonDataset
class CustomLabelDataset(CommonDataset):
    """Dataset whose label is produced by the preprocessing pipeline.

    The sample list only names images; the label is not read from the list.
    Instead it is looked up, under ``label_key``, in the dict returned by the
    transform ops (for example an op that rotates the image and records the
    applied rotation).

    Args:
        image_root (str): root directory the listed image paths are joined to.
        sample_list_path (str): path to the text file listing one sample per
            line.
        transform_ops (list, optional): list of transform op(s). Defaults to
            None.
        label_key (str, optional): key of the processed sample dict that holds
            the label. When None, ``__getitem__`` returns only the image.
            Defaults to None.
        delimiter (str, optional): if given, each line is split on it and only
            the first field is used as the image path. Defaults to None.
    """

    def __init__(self,
                 image_root,
                 sample_list_path,
                 transform_ops=None,
                 label_key=None,
                 delimiter=None):
        # `delimiter` must be set before the base constructor runs because
        # `_load_anno` reads it.
        # NOTE(review): assumes CommonDataset.__init__ calls `_load_anno` --
        # confirm against the base class.
        self.delimiter = delimiter
        super().__init__(image_root, sample_list_path, transform_ops)

        # Labels come out of the transform pipeline, so without any transform
        # op there is nothing to read `label_key` from. Truthiness is used
        # (rather than `is None`) to agree with the check in `__getitem__`.
        if not self._transform_ops and label_key is not None:
            label_key = None
            msg = ("Unable to get label by \"label_key\" when "
                   "\"transform_ops\" is not set. \"label_key\" has been "
                   "reset to None.")
            logger.warning(msg)
        self.label_key = label_key

    def _load_anno(self, seed=None):
        """Read the sample list and fill ``self.images`` with full paths.

        Args:
            seed (int, optional): when given, the lines are shuffled with a
                ``numpy.random.RandomState`` seeded by it. Defaults to None.
        """
        assert os.path.exists(
            self._cls_path), f"path {self._cls_path} does not exist."
        assert os.path.exists(
            self._img_root), f"path {self._img_root} does not exist."

        self.images = []
        with open(self._cls_path) as fd:
            lines = fd.readlines()

        if seed is not None:
            np.random.RandomState(seed).shuffle(lines)

        for line in lines:
            line = line.strip()
            if not line:
                # A blank line would otherwise resolve to the image root
                # directory itself and be registered as a (broken) sample.
                continue
            if self.delimiter is not None:
                line = line.split(self.delimiter)[0]
            self.images.append(os.path.join(self._img_root, line))
            assert os.path.exists(self.images[
                -1]), f"path {self.images[-1]} does not exist."

    def __getitem__(self, idx):
        """Load and preprocess the sample at ``idx``.

        Returns:
            ``(img, label)`` when transform ops are set and ``label_key`` is
            not None; otherwise only ``img`` (the raw file bytes when no
            transform op is configured).
        """
        try:
            with open(self.images[idx], 'rb') as f:
                img = f.read()
            if self._transform_ops:
                processed_sample = transform({"img": img}, self._transform_ops)
                # The processed image is transposed with (2, 0, 1) before it
                # is returned.
                # NOTE(review): presumably HWC -> CHW -- confirm the layout
                # produced by the configured ops.
                img = processed_sample["img"].transpose((2, 0, 1))
                if self.label_key is not None:
                    label = processed_sample[self.label_key]
                    return (img, label)
            return img
        except Exception as ex:
            # Best effort: log the failure and fall back to a random sample.
            logger.error("Exception occured when parse line: {} with msg: {}".
                         format(self.images[idx], ex))
            rnd_idx = np.random.randint(self.__len__())
            return self.__getitem__(rnd_idx)
...@@ -40,6 +40,7 @@ from ppcls.data.preprocess.ops.operators import ColorJitter ...@@ -40,6 +40,7 @@ from ppcls.data.preprocess.ops.operators import ColorJitter
from ppcls.data.preprocess.ops.operators import RandomCropImage from ppcls.data.preprocess.ops.operators import RandomCropImage
from ppcls.data.preprocess.ops.operators import RandomRotation from ppcls.data.preprocess.ops.operators import RandomRotation
from ppcls.data.preprocess.ops.operators import Padv2 from ppcls.data.preprocess.ops.operators import Padv2
from ppcls.data.preprocess.ops.operators import RandomRot90
from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator
from ppcls.data.preprocess.batch_ops.batch_operators import MixupCutmixHybrid from ppcls.data.preprocess.batch_ops.batch_operators import MixupCutmixHybrid
...@@ -101,7 +102,8 @@ class TimmAutoAugment(RawTimmAutoAugment): ...@@ -101,7 +102,8 @@ class TimmAutoAugment(RawTimmAutoAugment):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.prob = prob self.prob = prob
def __call__(self, img): def __call__(self, ori_data):
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data
if not isinstance(img, Image.Image): if not isinstance(img, Image.Image):
img = np.ascontiguousarray(img) img = np.ascontiguousarray(img)
img = Image.fromarray(img) img = Image.fromarray(img)
...@@ -109,5 +111,9 @@ class TimmAutoAugment(RawTimmAutoAugment): ...@@ -109,5 +111,9 @@ class TimmAutoAugment(RawTimmAutoAugment):
img = super().__call__(img) img = super().__call__(img)
if isinstance(img, Image.Image): if isinstance(img, Image.Image):
img = np.asarray(img) img = np.asarray(img)
processed_data = {
return img **
ori_data,
"img": img
} if isinstance(ori_data, dict) else img
return processed_data
...@@ -161,7 +161,8 @@ class DecodeImage(object): ...@@ -161,7 +161,8 @@ class DecodeImage(object):
f"\"to_rgb\" and \"channel_first\" are only enabled when to_np is True. \"to_np\" is now {to_np}." f"\"to_rgb\" and \"channel_first\" are only enabled when to_np is True. \"to_np\" is now {to_np}."
) )
def __call__(self, img): def __call__(self, ori_data):
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data
if isinstance(img, Image.Image): if isinstance(img, Image.Image):
assert self.backend == "pil", "invalid input 'img' in DecodeImage" assert self.backend == "pil", "invalid input 'img' in DecodeImage"
elif isinstance(img, np.ndarray): elif isinstance(img, np.ndarray):
...@@ -188,8 +189,12 @@ class DecodeImage(object): ...@@ -188,8 +189,12 @@ class DecodeImage(object):
if self.channel_first: if self.channel_first:
img = img.transpose((2, 0, 1)) img = img.transpose((2, 0, 1))
processed_data = {
return img **
ori_data,
"img": img
} if isinstance(ori_data, dict) else img
return processed_data
class ResizeImage(object): class ResizeImage(object):
...@@ -416,7 +421,8 @@ class RandCropImage(object): ...@@ -416,7 +421,8 @@ class RandCropImage(object):
self._resize_func = UnifiedResize( self._resize_func = UnifiedResize(
interpolation=interpolation, backend=backend) interpolation=interpolation, backend=backend)
def __call__(self, img): def __call__(self, ori_data):
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data
size = self.size size = self.size
scale = self.scale scale = self.scale
ratio = self.ratio ratio = self.ratio
...@@ -440,9 +446,13 @@ class RandCropImage(object): ...@@ -440,9 +446,13 @@ class RandCropImage(object):
i = random.randint(0, img_w - w) i = random.randint(0, img_w - w)
j = random.randint(0, img_h - h) j = random.randint(0, img_h - h)
img = img[j:j + h, i:i + w, :] img = self._resize_func(img[j:j + h, i:i + w, :], size)
processed_data = {
return self._resize_func(img, size) **
ori_data,
"img": img
} if isinstance(ori_data, dict) else img
return processed_data
class RandCropImageV2(object): class RandCropImageV2(object):
...@@ -547,7 +557,8 @@ class NormalizeImage(object): ...@@ -547,7 +557,8 @@ class NormalizeImage(object):
self.mean = np.array(mean).reshape(shape).astype('float32') self.mean = np.array(mean).reshape(shape).astype('float32')
self.std = np.array(std).reshape(shape).astype('float32') self.std = np.array(std).reshape(shape).astype('float32')
def __call__(self, img): def __call__(self, ori_data):
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data
from PIL import Image from PIL import Image
if isinstance(img, Image.Image): if isinstance(img, Image.Image):
img = np.array(img) img = np.array(img)
...@@ -567,7 +578,14 @@ class NormalizeImage(object): ...@@ -567,7 +578,14 @@ class NormalizeImage(object):
(img, pad_zeros), axis=0) (img, pad_zeros), axis=0)
if self.order == 'chw' else np.concatenate( if self.order == 'chw' else np.concatenate(
(img, pad_zeros), axis=2)) (img, pad_zeros), axis=2))
return img.astype(self.output_dtype)
img = img.astype(self.output_dtype)
processed_data = {
**
ori_data,
"img": img
} if isinstance(ori_data, dict) else img
return processed_data
class ToCHWImage(object): class ToCHWImage(object):
...@@ -745,3 +763,24 @@ class Pad(object): ...@@ -745,3 +763,24 @@ class Pad(object):
cv2.BORDER_CONSTANT, cv2.BORDER_CONSTANT,
value=(self.fill, self.fill, self.fill)) value=(self.fill, self.fill, self.fill))
return img return img
class RandomRot90(object):
    """Rotate an image by a random multiple of 90 degrees.

    The number of quarter turns is drawn uniformly from ``{0, 1, 2, 3}`` and
    applied with ``np.rot90``. When the input is a dict, the drawn value is
    stored under the key ``"random_rot90_orientation"``.
    """

    def __init__(self):
        pass

    def __call__(self, ori_data):
        """Apply the random rotation.

        Args:
            ori_data (dict | array-like): either a sample dict holding the
                image under ``"img"``, or the image itself.

        Returns:
            For dict input, a new dict with every original entry, the rotated
            image under ``"img"`` and the number of quarter turns under
            ``"random_rot90_orientation"``. For any other input, only the
            rotated image (the orientation is not returned in that case).
        """
        is_dict = isinstance(ori_data, dict)
        img = ori_data["img"] if is_dict else ori_data

        orientation = random.choice([0, 1, 2, 3])
        if orientation:
            # np.rot90 returns a strided view that shares memory with the
            # input. Materialise it so later in-place ops cannot write through
            # to the caller's array and so consumers that need contiguous
            # memory accept the result.
            img = np.ascontiguousarray(np.rot90(img, orientation))

        if not is_dict:
            return img
        return {
            **ori_data,
            "img": img,
            "random_rot90_orientation": orientation,
        }
...@@ -70,9 +70,11 @@ class RandomErasing(object): ...@@ -70,9 +70,11 @@ class RandomErasing(object):
self.attempt = attempt self.attempt = attempt
self.get_pixels = Pixels(mode, mean) self.get_pixels = Pixels(mode, mean)
def __call__(self, img): def __call__(self, ori_data):
if random.random() > self.EPSILON: if random.random() > self.EPSILON:
return img return ori_data
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data
for _ in range(self.attempt): for _ in range(self.attempt):
if isinstance(img, np.ndarray): if isinstance(img, np.ndarray):
...@@ -105,5 +107,16 @@ class RandomErasing(object): ...@@ -105,5 +107,16 @@ class RandomErasing(object):
img[0, x1:x1 + h, y1:y1 + w] = pixels[0] img[0, x1:x1 + h, y1:y1 + w] = pixels[0]
else: else:
img[x1:x1 + h, y1:y1 + w, 0] = pixels[:, :, 0] img[x1:x1 + h, y1:y1 + w, 0] = pixels[:, :, 0]
return img processed_data = {
return img **
ori_data,
"img": img
} if isinstance(ori_data, dict) else img
return processed_data
processed_data = {
**
ori_data,
"img": img
} if isinstance(ori_data, dict) else img
return processed_data
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册