提交 08320527 编写于 作者: C cuicheng01

add clarity_assessment code

上级 b3ab418e
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 10
eval_during_train: True
eval_interval: 1
epochs: 20
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
# model architecture
Arch:
name: PPLCNet_x1_0
pretrained: True
use_ssld: True
class_num: 2
use_last_conv: False
# loss function config for training/eval process
Loss:
Train:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: Momentum
momentum: 0.9
lr:
name: Cosine
learning_rate: 0.14
warmup_epoch: 5
regularizer:
name: 'L2'
coeff: 0.00003
# data loader for train and eval
DataLoader:
Train:
dataset:
name: CustomLabelDataset
image_root: ./dataset/
sample_list_path: ./dataset/ImageNet_OCR_det.txt
# Same string as the "blur_image" key in the dict returned by the BlurImage
# op below; presumably this is how the dataset picks up the generated
# label -- confirm against CustomLabelDataset.
label_key: blur_image
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
# NOTE(review): BlurImage runs after NormalizeImage here, but its motion-blur
# path min-max rescales the result to 0..255 and casts it to uint8, while the
# Gaussian and "no blur" paths leave the value range alone. Confirm this
# ordering is intended.
- BlurImage:
sampler:
name: DistributedBatchSampler
batch_size: 256
drop_last: False
shuffle: True
loader:
num_workers: 12
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/blur/
cls_label_path: ./dataset/blur/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: ./test_img/
batch_size: 1
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 1
Metric:
Train:
- TopkAcc:
topk: [1]
Eval:
- TopkAcc:
topk: [1]
...@@ -41,6 +41,7 @@ from ppcls.data.preprocess.ops.operators import RandomCropImage ...@@ -41,6 +41,7 @@ from ppcls.data.preprocess.ops.operators import RandomCropImage
from ppcls.data.preprocess.ops.operators import RandomRotation from ppcls.data.preprocess.ops.operators import RandomRotation
from ppcls.data.preprocess.ops.operators import Padv2 from ppcls.data.preprocess.ops.operators import Padv2
from ppcls.data.preprocess.ops.operators import RandomRot90 from ppcls.data.preprocess.ops.operators import RandomRot90
from ppcls.data.preprocess.ops.operators import BlurImage
from .ops.operators import format_data from .ops.operators import format_data
from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator
......
...@@ -718,8 +718,8 @@ class Pad(object): ...@@ -718,8 +718,8 @@ class Pad(object):
# Process fill color for affine transforms # Process fill color for affine transforms
major_found, minor_found = (int(v) major_found, minor_found = (int(v)
for v in PILLOW_VERSION.split('.')[:2]) for v in PILLOW_VERSION.split('.')[:2])
major_required, minor_required = (int(v) for v in major_required, minor_required = (
min_pil_version.split('.')[:2]) int(v) for v in min_pil_version.split('.')[:2])
if major_found < major_required or (major_found == major_required and if major_found < major_required or (major_found == major_required and
minor_found < minor_required): minor_found < minor_required):
if fill is None: if fill is None:
...@@ -781,3 +781,54 @@ class RandomRot90(object): ...@@ -781,3 +781,54 @@ class RandomRot90(object):
if orientation: if orientation:
img = np.rot90(img, orientation) img = np.rot90(img, orientation)
return {"img": img, "random_rot90_orientation": orientation} return {"img": img, "random_rot90_orientation": orientation}
class BlurImage(object):
    """Randomly blur an image and report whether it was blurred.

    Each call either leaves the image untouched (label ``0``) or applies a
    Gaussian blur or a motion blur, picked at random (label ``1``).

    Args:
        ratio: Threshold compared against ``random.random()``; the image is
            blurred when the draw is less than or equal to ``ratio``.
        motion_max_ksize: Exclusive upper bound of the motion-blur kernel
            size. Sizes are drawn from ``5, 7, 9, ...`` below this bound.
        motion_max_angle: Motion-blur rotation angles are drawn from
            ``np.arange(-motion_max_angle, motion_max_angle)``.
        gaussian_max_ksize: Exclusive upper bound of each Gaussian kernel
            dimension, sampled the same way as the motion kernel size.
    """

    # Smallest kernel size that is ever sampled; candidates step by 2 from
    # here, so every candidate is odd.
    _MIN_KSIZE = 5

    def __init__(self,
                 ratio=0.5,
                 motion_max_ksize=12,
                 motion_max_angle=45,
                 gaussian_max_ksize=12):
        self.ratio = ratio
        self.motion_max_ksize = motion_max_ksize
        self.motion_max_angle = motion_max_angle
        self.gaussian_max_ksize = gaussian_max_ksize

    @staticmethod
    def _sample_ksize(max_ksize):
        """Draw a random kernel size from ``5, 7, ...`` below ``max_ksize``.

        Returns:
            The sampled size as a plain Python ``int``.

        Raises:
            ValueError: If ``max_ksize`` leaves no candidate size, i.e. it
                is not greater than the minimum kernel size.
        """
        candidates = np.arange(BlurImage._MIN_KSIZE, max_ksize, 2)
        if candidates.size == 0:
            # np.random.choice would raise an opaque ValueError on an empty
            # array; fail with a message that names the offending setting.
            raise ValueError(
                "max_ksize must be greater than {} to sample a blur kernel "
                "size, got {}".format(BlurImage._MIN_KSIZE, max_ksize))
        # Cast to a built-in int so OpenCV receives plain integers instead
        # of NumPy scalars for sizes.
        return int(np.random.choice(candidates))

    def _gaussian_blur(self, img, max_ksize=12):
        """Apply a Gaussian blur with independently sampled width and height.

        Args:
            img: Image array accepted by ``cv2.GaussianBlur``.
            max_ksize: Exclusive upper bound of each kernel dimension.

        Returns:
            The blurred image as returned by ``cv2.GaussianBlur``.
        """
        ksize = (self._sample_ksize(max_ksize), self._sample_ksize(max_ksize))
        # A sigma of 0 asks OpenCV to derive sigma from the kernel size.
        return cv2.GaussianBlur(img, ksize, 0)

    def _motion_blur(self, img, max_ksize=12, max_angle=45):
        """Apply a linear motion blur with a random length and direction.

        The kernel starts as a diagonal line of ones, is rotated by a random
        angle, and is divided by its size before filtering.

        Args:
            img: Image array accepted by ``cv2.filter2D``.
            max_ksize: Exclusive upper bound of the kernel size.
            max_angle: Angles are drawn from
                ``np.arange(-max_angle, max_angle)``. When that range is
                empty (``max_angle <= 0``) the kernel is not rotated.

        Returns:
            The blurred image, min-max rescaled to ``0..255`` as ``uint8``.
        """
        degree = self._sample_ksize(max_ksize)

        angles = np.arange(-1 * max_angle, max_angle)
        # An empty angle range used to crash inside np.random.choice; treat
        # it as "no rotation" instead.
        angle = float(np.random.choice(angles)) if angles.size else 0.0

        rotation = cv2.getRotationMatrix2D((degree / 2, degree / 2), angle, 1)
        motion_blur_kernel = np.diag(np.ones(degree))
        motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, rotation,
                                            (degree, degree))
        motion_blur_kernel = motion_blur_kernel / degree

        blurred = cv2.filter2D(img, -1, motion_blur_kernel)
        # NOTE(review): unlike the Gaussian path, this stretches the result
        # to the full 0..255 range and returns uint8 whatever the input
        # dtype/range was. Kept as-is to preserve existing behavior; confirm
        # it is intended when the input is already a normalized float image.
        cv2.normalize(blurred, blurred, 0, 255, cv2.NORM_MINMAX)
        return np.array(blurred, dtype=np.uint8)

    @format_data
    def __call__(self, img):
        """Blur ``img`` at random and return it with its blur label.

        Returns:
            A dict with ``"img"`` (the possibly blurred image) and
            ``"blur_image"`` (``1`` if a blur was applied, otherwise ``0``).
        """
        if random.random() > self.ratio:
            # Image is passed through untouched.
            label = 0
        else:
            method = random.choice(["gaussian", "motion"])
            if method == "gaussian":
                img = self._gaussian_blur(img, self.gaussian_max_ksize)
            else:
                img = self._motion_blur(img, self.motion_max_ksize,
                                        self.motion_max_angle)
            label = 1
        return {"img": img, "blur_image": label}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册