提交 e823f178 编写于 作者: G gaotingquan 提交者: Tingquan Gao

feat: support training image orientation model

add the config of orientation model
add the preprocess op RandomRot90 that can rotate the img and return the orientation
add the CustomLabelDataset that support getting label by preprocess
refactor some preprocess ops to support dict parameter and return dict
上级 59a6cfc3
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 10
eval_during_train: True
eval_interval: 10
epochs: 20
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
# model architecture
Arch:
name: PPLCNet_x1_0
pretrained: True
use_ssld: True
class_num: 4
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: Momentum
momentum: 0.9
lr:
name: Cosine
learning_rate: 0.56
warmup_epoch: 5
regularizer:
name: 'L2'
coeff: 0.00003
# data loader for train and eval
DataLoader:
Train:
dataset:
name: CustomLabelDataset
image_root: ./dataset/OrientationDataset/
sample_list_path: ./dataset/OrientationDataset/train_list.txt
label_key: random_rot90_orientation
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 224
- TimmAutoAugment:
prob: 0.0
config_str: rand-m9-mstd0.5-inc1
interpolation: bicubic
img_size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- RandomErasing:
EPSILON: 0.0
sl: 0.02
sh: 1.0/3.0
r1: 0.3
attempt: 10
use_log_aspect: True
mode: pixel
- RandomRot90:
sampler:
name: DistributedBatchSampler
batch_size: 512
drop_last: False
shuffle: True
loader:
num_workers: 16
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/OrientationDataset/
cls_label_path: ./dataset/OrientationDataset/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: ./test_img/
batch_size: 1
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 1
Metric:
Train:
- TopkAcc:
topk: [1]
Eval:
- TopkAcc:
topk: [1]
......@@ -31,7 +31,7 @@ from ppcls.data.dataloader.mix_dataset import MixDataset
from ppcls.data.dataloader.multi_scale_dataset import MultiScaleDataset
from ppcls.data.dataloader.person_dataset import Market1501, MSMT17
from ppcls.data.dataloader.face_dataset import FiveValidationDataset, AdaFaceDataset
from ppcls.data.dataloader.custom_label_dataset import CustomLabelDataset
# sampler
from ppcls.data.dataloader.DistributedRandomIdentitySampler import DistributedRandomIdentitySampler
......
......@@ -11,3 +11,4 @@ from ppcls.data.dataloader.multi_scale_sampler import MultiScaleSampler
from ppcls.data.dataloader.pk_sampler import PKSampler
from ppcls.data.dataloader.person_dataset import Market1501, MSMT17
from ppcls.data.dataloader.face_dataset import AdaFaceDataset, FiveValidationDataset
from ppcls.data.dataloader.custom_label_dataset import CustomLabelDataset
......@@ -51,8 +51,7 @@ class CommonDataset(Dataset):
label_ratio=False):
self._img_root = image_root
self._cls_path = cls_label_path
if transform_ops:
self._transform_ops = create_operators(transform_ops)
self._transform_ops = create_operators(transform_ops)
self.images = []
self.labels = []
......@@ -84,4 +83,4 @@ class CommonDataset(Dataset):
@property
def class_num(self):
return len(set(self.labels))
return len(set(self.images))
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import numpy as np
from ppcls.data.preprocess import transform
from ppcls.utils import logger
from .common_dataset import CommonDataset
class CustomLabelDataset(CommonDataset):
"""CustomLabelDataset
Args:
image_root (str): image root, path to `ILSVRC2012`
sample_list_path (str): path to the file with samples listed.
transform_ops (list, optional): list of transform op(s). Defaults to None.
label_key (str, optional): Defaults to None.
delimiter (str, optional): delimiter. Defaults to None.
"""
def __init__(self,
image_root,
sample_list_path,
transform_ops=None,
label_key=None,
delimiter=None):
self.delimiter = delimiter
super().__init__(image_root, sample_list_path, transform_ops)
if self._transform_ops is None and label_key is not None:
label_key = None
msg = ""
logger.warning(msg)
self.label_key = label_key
def _load_anno(self, seed=None):
assert os.path.exists(
self._cls_path), f"path {self._cls_path} does not exist."
assert os.path.exists(
self._img_root), f"path {self._img_root} does not exist."
self.images = []
with open(self._cls_path) as fd:
lines = fd.readlines()
if seed is not None:
np.random.RandomState(seed).shuffle(lines)
for line in lines:
line = line.strip()
if self.delimiter is not None:
line = line.split(self.delimiter)[0]
self.images.append(os.path.join(self._img_root, line))
assert os.path.exists(self.images[
-1]), f"path {self.images[-1]} does not exist."
def __getitem__(self, idx):
try:
with open(self.images[idx], 'rb') as f:
img = f.read()
if self._transform_ops:
processed_sample = transform({"img": img}, self._transform_ops)
img = processed_sample["img"].transpose((2, 0, 1))
if self.label_key is not None:
label = processed_sample[self.label_key]
sample = (img, label)
return sample
return (img)
except Exception as ex:
logger.error("Exception occured when parse line: {} with msg: {}".
format(self.images[idx], ex))
rnd_idx = np.random.randint(self.__len__())
return self.__getitem__(rnd_idx)
......@@ -40,6 +40,7 @@ from ppcls.data.preprocess.ops.operators import ColorJitter
from ppcls.data.preprocess.ops.operators import RandomCropImage
from ppcls.data.preprocess.ops.operators import RandomRotation
from ppcls.data.preprocess.ops.operators import Padv2
from ppcls.data.preprocess.ops.operators import RandomRot90
from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator
from ppcls.data.preprocess.batch_ops.batch_operators import MixupCutmixHybrid
......@@ -101,7 +102,8 @@ class TimmAutoAugment(RawTimmAutoAugment):
super().__init__(*args, **kwargs)
self.prob = prob
def __call__(self, img):
def __call__(self, ori_data):
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data
if not isinstance(img, Image.Image):
img = np.ascontiguousarray(img)
img = Image.fromarray(img)
......@@ -109,5 +111,9 @@ class TimmAutoAugment(RawTimmAutoAugment):
img = super().__call__(img)
if isinstance(img, Image.Image):
img = np.asarray(img)
return img
processed_data = {
**
ori_data,
"img": img
} if isinstance(ori_data, dict) else img
return processed_data
......@@ -161,7 +161,8 @@ class DecodeImage(object):
f"\"to_rgb\" and \"channel_first\" are only enabled when to_np is True. \"to_np\" is now {to_np}."
)
def __call__(self, img):
def __call__(self, ori_data):
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data
if isinstance(img, Image.Image):
assert self.backend == "pil", "invalid input 'img' in DecodeImage"
elif isinstance(img, np.ndarray):
......@@ -188,8 +189,12 @@ class DecodeImage(object):
if self.channel_first:
img = img.transpose((2, 0, 1))
return img
processed_data = {
**
ori_data,
"img": img
} if isinstance(ori_data, dict) else img
return processed_data
class ResizeImage(object):
......@@ -416,7 +421,8 @@ class RandCropImage(object):
self._resize_func = UnifiedResize(
interpolation=interpolation, backend=backend)
def __call__(self, img):
def __call__(self, ori_data):
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data
size = self.size
scale = self.scale
ratio = self.ratio
......@@ -440,9 +446,13 @@ class RandCropImage(object):
i = random.randint(0, img_w - w)
j = random.randint(0, img_h - h)
img = img[j:j + h, i:i + w, :]
return self._resize_func(img, size)
img = self._resize_func(img[j:j + h, i:i + w, :], size)
processed_data = {
**
ori_data,
"img": img
} if isinstance(ori_data, dict) else img
return processed_data
class RandCropImageV2(object):
......@@ -547,7 +557,8 @@ class NormalizeImage(object):
self.mean = np.array(mean).reshape(shape).astype('float32')
self.std = np.array(std).reshape(shape).astype('float32')
def __call__(self, img):
def __call__(self, ori_data):
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data
from PIL import Image
if isinstance(img, Image.Image):
img = np.array(img)
......@@ -567,7 +578,14 @@ class NormalizeImage(object):
(img, pad_zeros), axis=0)
if self.order == 'chw' else np.concatenate(
(img, pad_zeros), axis=2))
return img.astype(self.output_dtype)
img = img.astype(self.output_dtype)
processed_data = {
**
ori_data,
"img": img
} if isinstance(ori_data, dict) else img
return processed_data
class ToCHWImage(object):
......@@ -745,3 +763,24 @@ class Pad(object):
cv2.BORDER_CONSTANT,
value=(self.fill, self.fill, self.fill))
return img
class RandomRot90(object):
"""RandomRot90
"""
def __init__(self):
pass
def __call__(self, ori_data):
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data
orientation = random.choice([0, 1, 2, 3])
if orientation:
img = np.rot90(img, orientation)
processed_data = {
**
ori_data,
"img": img,
"random_rot90_orientation": orientation
} if isinstance(ori_data, dict) else img
return processed_data
......@@ -70,9 +70,11 @@ class RandomErasing(object):
self.attempt = attempt
self.get_pixels = Pixels(mode, mean)
def __call__(self, img):
def __call__(self, ori_data):
if random.random() > self.EPSILON:
return img
return ori_data
img = ori_data["img"] if isinstance(ori_data, dict) else ori_data
for _ in range(self.attempt):
if isinstance(img, np.ndarray):
......@@ -105,5 +107,16 @@ class RandomErasing(object):
img[0, x1:x1 + h, y1:y1 + w] = pixels[0]
else:
img[x1:x1 + h, y1:y1 + w, 0] = pixels[:, :, 0]
return img
return img
processed_data = {
**
ori_data,
"img": img
} if isinstance(ori_data, dict) else img
return processed_data
processed_data = {
**
ori_data,
"img": img
} if isinstance(ori_data, dict) else img
return processed_data
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册