提交 08320527 编写于 作者: C cuicheng01

add clarity_assessment code

上级 b3ab418e
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 10
eval_during_train: True
eval_interval: 1
epochs: 20
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
# model architecture
Arch:
name: PPLCNet_x1_0
pretrained: True
use_ssld: True
class_num: 2
use_last_conv: False
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: Momentum
momentum: 0.9
lr:
name: Cosine
learning_rate: 0.14
warmup_epoch: 5
regularizer:
name: 'L2'
coeff: 0.00003
# data loader for train and eval
DataLoader:
Train:
dataset:
name: CustomLabelDataset
image_root: ./dataset/
sample_list_path: ./dataset/ImageNet_OCR_det.txt
label_key: blur_image
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- BlurImage:
sampler:
name: DistributedBatchSampler
batch_size: 256
drop_last: False
shuffle: True
loader:
num_workers: 12
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/blur/
cls_label_path: ./dataset/blur/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: ./test_img/
batch_size: 1
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 1
Metric:
Train:
- TopkAcc:
topk: [1]
Eval:
- TopkAcc:
topk: [1]
......@@ -41,6 +41,7 @@ from ppcls.data.preprocess.ops.operators import RandomCropImage
from ppcls.data.preprocess.ops.operators import RandomRotation
from ppcls.data.preprocess.ops.operators import Padv2
from ppcls.data.preprocess.ops.operators import RandomRot90
from ppcls.data.preprocess.ops.operators import BlurImage
from .ops.operators import format_data
from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator
......
......@@ -718,8 +718,8 @@ class Pad(object):
# Process fill color for affine transforms
major_found, minor_found = (int(v)
for v in PILLOW_VERSION.split('.')[:2])
major_required, minor_required = (int(v) for v in
min_pil_version.split('.')[:2])
major_required, minor_required = (
int(v) for v in min_pil_version.split('.')[:2])
if major_found < major_required or (major_found == major_required and
minor_found < minor_required):
if fill is None:
......@@ -781,3 +781,54 @@ class RandomRot90(object):
if orientation:
img = np.rot90(img, orientation)
return {"img": img, "random_rot90_orientation": orientation}
class BlurImage(object):
"""BlurImage
"""
def __init__(self,
ratio=0.5,
motion_max_ksize=12,
motion_max_angle=45,
gaussian_max_ksize=12):
self.ratio = ratio
self.motion_max_ksize = motion_max_ksize
self.motion_max_angle = motion_max_angle
self.gaussian_max_ksize = gaussian_max_ksize
def _gaussian_blur(self, img, max_ksize=12):
ksize = (np.random.choice(np.arange(5, max_ksize, 2)),
np.random.choice(np.arange(5, max_ksize, 2)))
img = cv2.GaussianBlur(img, ksize, 0)
return img
def _motion_blur(self, img, max_ksize=12, max_angle=45):
degree = np.random.choice(np.arange(5, max_ksize, 2))
angle = np.random.choice(np.arange(-1 * max_angle, max_angle))
M = cv2.getRotationMatrix2D((degree / 2, degree / 2), angle, 1)
motion_blur_kernel = np.diag(np.ones(degree))
motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, M,
(degree, degree))
motion_blur_kernel = motion_blur_kernel / degree
blurred = cv2.filter2D(img, -1, motion_blur_kernel)
cv2.normalize(blurred, blurred, 0, 255, cv2.NORM_MINMAX)
img = np.array(blurred, dtype=np.uint8)
return img
@format_data
def __call__(self, img):
if random.random() > self.ratio:
label = 0
else:
method = random.choice(["gaussian", "motion"])
if method == "gaussian":
img = self._gaussian_blur(img, self.gaussian_max_ksize)
else:
img = self._motion_blur(img, self.motion_max_ksize,
self.motion_max_angle)
label = 1
return {"img": img, "blur_image": label}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册