未验证 提交 35c96a7d 编写于 作者: H haoyuying 提交者: GitHub

reconstruct colorization transform

上级 48363091
...@@ -4,23 +4,18 @@ import paddle.nn as nn ...@@ -4,23 +4,18 @@ import paddle.nn as nn
from paddlehub.finetune.trainer import Trainer from paddlehub.finetune.trainer import Trainer
from paddlehub.datasets.colorizedataset import Colorizedataset from paddlehub.datasets.colorizedataset import Colorizedataset
from paddlehub.process.transforms import Compose, Resize, RandomPaddingCrop, ConvertColorSpace, ColorizePreprocess import paddlehub.process.transforms as T
if __name__ == '__main__': if __name__ == '__main__':
is_train = True
paddle.disable_static() paddle.disable_static()
model = hub.Module(name='user_guided_colorization') model = hub.Module(name='user_guided_colorization')
transform = Compose([ transform = T.Compose([T.Resize((256, 256), interp='NEAREST'),
Resize((256, 256), interp='NEAREST'), T.RandomPaddingCrop(crop_size=176),
RandomPaddingCrop(crop_size=176), T.RGB2LAB()],
ConvertColorSpace(mode='RGB2LAB'),
ColorizePreprocess(ab_thresh=0, is_train=is_train),
],
stay_rgb=True, stay_rgb=True,
is_permute=False) is_permute=False)
color_set = Colorizedataset(transform=transform, mode='train') color_set = Colorizedataset(transform=transform, mode='train')
if is_train:
model.train()
optimizer = paddle.optimizer.Adam(learning_rate=0.0001, parameters=model.parameters()) optimizer = paddle.optimizer.Adam(learning_rate=0.0001, parameters=model.parameters())
trainer = Trainer(model, optimizer, checkpoint_dir='test_ckpt_img_cls') trainer = Trainer(model, optimizer, checkpoint_dir='test_ckpt_img_cls')
trainer.train(color_set, epochs=101, batch_size=5, eval_dataset=color_set, log_interval=10, save_interval=10) trainer.train(color_set, epochs=101, batch_size=2, eval_dataset=color_set, log_interval=10, save_interval=10)
...@@ -6,4 +6,4 @@ if __name__ == '__main__': ...@@ -6,4 +6,4 @@ if __name__ == '__main__':
paddle.disable_static() paddle.disable_static()
model = hub.Module(name='yolov3_darknet53_pascalvoc', is_train=False) model = hub.Module(name='yolov3_darknet53_pascalvoc', is_train=False)
model.eval() model.eval()
model.predict(imgpath="4026.jpeg", filelist="/PATH/TO/JSON/FILE") model.predict(imgpath="4026.jpeg", filelist="/PATH/TO/JSON")
import paddle
import numpy as np
class ColorizeHint:
    """Produce hint and mask images for user-guided colorization.

    Square patches are sampled at random positions from the input color
    channels and copied (either verbatim or as their mean color) into the
    hint image; the sampled region is flagged in the mask image.

    Args:
        percent(float): Probability for ignoring hint in an iteration.
        num_points(int): Number of selected hints in an iteration.
        samp(str): Sample method, default is normal.
        use_avg(bool): Whether to use mean in selected hint area.

    Return:
        hint(np.ndarray): hint images
        mask(np.ndarray): mask images
    """

    def __init__(self, percent: float, num_points: int = None, samp: str = 'normal', use_avg: bool = True):
        self.percent = percent
        self.num_points = num_points
        self.samp = samp
        self.use_avg = use_avg

    def __call__(self, data: np.ndarray, hint: np.ndarray, mask: np.ndarray):
        """Fill *hint* and *mask* in place from *data* and return them.

        Args:
            data(np.ndarray): Source color channels, shape (N, C, H, W).
            hint(np.ndarray): Hint buffer, same shape as data.
            mask(np.ndarray): Mask buffer, shape (N, 1, H, W); returned shifted to {-0.5, 0.5}.
        """
        patch_sizes = [1, 2, 3, 4, 5, 6, 7, 8, 9]
        self.data = data
        self.hint = hint
        self.mask = mask
        N, C, H, W = data.shape
        for img_idx in range(N):
            placed = 0  # hints placed so far for this image
            while True:
                # Decide whether another hint patch should be drawn.
                if self.num_points is None:
                    keep_going = np.random.rand() > (1 - self.percent)
                else:
                    keep_going = placed < self.num_points
                if not keep_going:
                    break
                P = np.random.choice(patch_sizes)  # side length of the square patch
                # Sample the patch's top-left corner.
                if self.samp == 'normal':
                    h = int(np.clip(np.random.normal((H - P + 1) / 2., (H - P + 1) / 4.), 0, H - P))
                    w = int(np.clip(np.random.normal((W - P + 1) / 2., (W - P + 1) / 4.), 0, W - P))
                else:
                    h = np.random.randint(H - P + 1)
                    w = np.random.randint(W - P + 1)
                # Write the color hint.
                if self.use_avg:
                    patch = data[img_idx, :, h:h + P, w:w + P]
                    avg = np.mean(np.mean(patch, axis=2, keepdims=True), axis=1, keepdims=True)
                    hint[img_idx, :, h:h + P, w:w + P] = avg.reshape(1, C, 1, 1)
                else:
                    hint[img_idx, :, h:h + P, w:w + P] = data[img_idx, :, h:h + P, w:w + P]
                mask[img_idx, :, h:h + P, w:w + P] = 1
                placed += 1
        # Shift the mask range from {0, 1} to {-0.5, 0.5}.
        mask -= 0.5
        return hint, mask
class ColorizePreprocess:
    """Prepare dataset for image Colorization.

    Args:
        ab_thresh(float): Thresh value for setting mask value.
        p(float): Probability for ignoring hint in an iteration.
        num_points(int): Number of selected hints in an iteration.
        samp(str): Sample method, default is normal.
        use_avg(bool): Whether to use mean in selected hint area.

    Return:
        data(dict): The preprocessed data for colorization.
    """

    def __init__(self,
                 ab_thresh: float = 0.,
                 p: float = 0.,
                 num_points: int = None,
                 samp: str = 'normal',
                 use_avg: bool = True):
        self.ab_thresh = ab_thresh
        self.p = p
        self.num_points = num_points
        self.samp = samp
        self.use_avg = use_avg
        # Shared sampler that fills the hint/mask buffers on every call.
        self.gethint = ColorizeHint(percent=self.p, num_points=self.num_points, samp=self.samp, use_avg=self.use_avg)

    def __call__(self, data_lab) -> dict:
        """
        This method separates the L channel and AB channel, obtains hint, mask and real_B_enc as the input for the colorization task.

        Args:
            data_lab(np.ndarray|paddle.Tensor): Batch of LAB images, shape (N, 3, H, W) — TODO confirm with callers.

        Returns:
            data(dict): The preprocessed data for colorization, or None when
                every image in the batch is filtered out as grayscale.
        """
        if not isinstance(data_lab, np.ndarray):
            # Accept paddle tensors too; all preprocessing runs in numpy.
            data_lab = data_lab.numpy()
        data = {}
        A = 2 * 110 / 10 + 1  # number of quantized ab bins per axis (= 23)
        data['A'] = data_lab[:, [0], :, :]  # L channel
        data['B'] = data_lab[:, 1:, :, :]  # ab channels
        if self.ab_thresh > 0:  # mask out grayscale images
            thresh = 1. * self.ab_thresh / 110
            # Per-image ab dynamic range, summed over the two ab channels.
            mask = np.sum(np.abs(np.max(np.max(data['B'], axis=3), axis=2) - np.min(np.min(data['B'], axis=3), axis=2)),
                          axis=1)
            mask = (mask >= thresh)
            data['A'] = data['A'][mask, :, :, :]
            data['B'] = data['B'][mask, :, :, :]
            if np.sum(mask) == 0:
                return None
        data_ab_rs = np.round((data['B'][:, :, ::4, ::4] * 110. + 110.) / 10.)  # normalized bin number
        data['real_B_enc'] = data_ab_rs[:, [0], :, :] * A + data_ab_rs[:, [1], :, :]
        data['hint_B'] = np.zeros(shape=data['B'].shape)
        data['mask_B'] = np.zeros(shape=data['A'].shape)
        data['hint_B'], data['mask_B'] = self.gethint(data['B'], data['hint_B'], data['mask_B'])
        # Convert everything to paddle tensors with the dtypes the network expects.
        data['A'] = paddle.to_tensor(data['A'].astype(np.float32))
        data['B'] = paddle.to_tensor(data['B'].astype(np.float32))
        data['real_B_enc'] = paddle.to_tensor(data['real_B_enc'].astype(np.int64))
        data['hint_B'] = paddle.to_tensor(data['hint_B'].astype(np.float32))
        data['mask_B'] = paddle.to_tensor(data['mask_B'].astype(np.float32))
        return data
...@@ -15,12 +15,12 @@ ...@@ -15,12 +15,12 @@
import os import os
import paddle import paddle
import numpy
import paddle.nn as nn import paddle.nn as nn
from paddle.nn import Conv2d, ConvTranspose2d from paddle.nn import Conv2d, ConvTranspose2d
from paddlehub.module.module import moduleinfo from paddlehub.module.module import moduleinfo
from paddlehub.process.transforms import Compose, Resize, RandomPaddingCrop, ConvertColorSpace, ColorizePreprocess import paddlehub.process.transforms as T
from paddlehub.module.cv_module import ImageColorizeModule from paddlehub.module.cv_module import ImageColorizeModule
from user_guided_colorization.data_feed import ColorizePreprocess
@moduleinfo( @moduleinfo(
...@@ -32,7 +32,8 @@ from paddlehub.module.cv_module import ImageColorizeModule ...@@ -32,7 +32,8 @@ from paddlehub.module.cv_module import ImageColorizeModule
version="1.0.0", version="1.0.0",
meta=ImageColorizeModule) meta=ImageColorizeModule)
class UserGuidedColorization(nn.Layer): class UserGuidedColorization(nn.Layer):
"""Userguidedcolorization, see https://github.com/haoyuying/colorization-pytorch """
Userguidedcolorization, see https://github.com/haoyuying/colorization-pytorch
Args: Args:
use_tanh (bool): Whether to use tanh as final activation function. use_tanh (bool): Whether to use tanh as final activation function.
...@@ -139,12 +140,7 @@ class UserGuidedColorization(nn.Layer): ...@@ -139,12 +140,7 @@ class UserGuidedColorization(nn.Layer):
1, 1,
1, 1,
), ) ), )
model9 = ( model9 = (nn.ReLU(), Conv2d(128, 128, 3, 1, 1), nn.ReLU(), nn.BatchNorm(128))
nn.ReLU(),
Conv2d(128, 128, 3, 1, 1),
nn.ReLU(),
nn.BatchNorm(128),
)
# Conv10 # Conv10
model10up = (ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1), ) model10up = (ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1), )
...@@ -182,30 +178,32 @@ class UserGuidedColorization(nn.Layer): ...@@ -182,30 +178,32 @@ class UserGuidedColorization(nn.Layer):
print("load custom checkpoint success") print("load custom checkpoint success")
else: else:
checkpoint = os.path.join(self.directory, 'user_guided.pdparams') checkpoint = os.path.join(self.directory, 'user_guided.pdparams')
if not os.path.exists(checkpoint):
os.system('wget https://paddlehub.bj.bcebos.com/dygraph/image_colorization/user_guided.pdparams -O ' +
checkpoint)
model_dict = paddle.load(checkpoint)[0] model_dict = paddle.load(checkpoint)[0]
self.set_dict(model_dict) self.set_dict(model_dict)
print("load pretrained checkpoint success") print("load pretrained checkpoint success")
def transforms(self, images: str, is_train: bool = True) -> callable: def transforms(self, images: str, is_train: bool = True) -> callable:
if is_train: if is_train:
transform = Compose([ transform = T.Compose(
Resize((256, 256), interp='NEAREST'), [T.Resize((256, 256), interp='NEAREST'),
RandomPaddingCrop(crop_size=176), T.RandomPaddingCrop(crop_size=176),
ConvertColorSpace(mode='RGB2LAB'), T.RGB2LAB()],
ColorizePreprocess(ab_thresh=0, is_train=is_train)
],
stay_rgb=True, stay_rgb=True,
is_permute=False) is_permute=False)
else: else:
transform = Compose([ transform = T.Compose([T.Resize(
Resize((256, 256), interp='NEAREST'), (256, 256), interp='NEAREST'), T.RGB2LAB()],
ConvertColorSpace(mode='RGB2LAB'),
ColorizePreprocess(ab_thresh=0, is_train=is_train)
],
stay_rgb=True, stay_rgb=True,
is_permute=False) is_permute=False)
return transform(images) return transform(images)
def preprocess(self, inputs: paddle.Tensor, ab_thresh: float = 0., prob: float = 0.):
self.preprocess = ColorizePreprocess(ab_thresh=ab_thresh, p=prob)
return self.preprocess(inputs)
def forward(self, def forward(self,
input_A: paddle.Tensor, input_A: paddle.Tensor,
input_B: paddle.Tensor, input_B: paddle.Tensor,
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import os import os
import numpy import numpy as np
import paddle import paddle
from paddlehub.process.functional import get_img_file from paddlehub.process.functional import get_img_file
...@@ -26,9 +26,11 @@ from typing import Callable ...@@ -26,9 +26,11 @@ from typing import Callable
class Colorizedataset(paddle.io.Dataset): class Colorizedataset(paddle.io.Dataset):
""" """
Dataset for colorization. Dataset for colorization.
Args: Args:
transform(callmethod) : The method of preprocess images. transform(callmethod) : The method of preprocess images.
mode(str): The mode for preparing dataset. mode(str): The mode for preparing dataset.
Returns: Returns:
DataSet: An iterable object for data iterating DataSet: An iterable object for data iterating
""" """
...@@ -44,10 +46,10 @@ class Colorizedataset(paddle.io.Dataset): ...@@ -44,10 +46,10 @@ class Colorizedataset(paddle.io.Dataset):
self.file = os.path.join(DATA_HOME, 'canvas', self.file) self.file = os.path.join(DATA_HOME, 'canvas', self.file)
self.data = get_img_file(self.file) self.data = get_img_file(self.file)
def __getitem__(self, idx: int) -> numpy.ndarray: def __getitem__(self, idx: int) -> np.ndarray:
img_path = self.data[idx] img_path = self.data[idx]
im = self.transform(img_path) im = self.transform(img_path)
return im['A'], im['hint_B'], im['mask_B'], im['B'], im['real_B_enc'] return im
def __len__(self): def __len__(self):
return len(self.data) return len(self.data)
...@@ -111,7 +111,7 @@ class ImageColorizeModule(RunModule, ImageServing): ...@@ -111,7 +111,7 @@ class ImageColorizeModule(RunModule, ImageServing):
batch_idx(int): The index of batch. batch_idx(int): The index of batch.
Returns: Returns:
results(dict) : The model outputs, such as loss and metrics. results(dict): The model outputs, such as loss and metrics.
''' '''
return self.validation_step(batch, batch_idx) return self.validation_step(batch, batch_idx)
...@@ -126,29 +126,30 @@ class ImageColorizeModule(RunModule, ImageServing): ...@@ -126,29 +126,30 @@ class ImageColorizeModule(RunModule, ImageServing):
Returns: Returns:
results(dict) : The model outputs, such as metrics. results(dict) : The model outputs, such as metrics.
''' '''
out_class, out_reg = self(batch[0], batch[1], batch[2]) img = self.preprocess(batch[0])
out_class, out_reg = self(img['A'], img['hint_B'], img['mask_B'])
# loss
criterionCE = nn.loss.CrossEntropyLoss() criterionCE = nn.loss.CrossEntropyLoss()
loss_ce = criterionCE(out_class, batch[4][:, 0, :, :]) loss_ce = criterionCE(out_class, img['real_B_enc'][:, 0, :, :])
loss_G_L1_reg = paddle.sum(paddle.abs(batch[3] - out_reg), axis=1, keepdim=True) loss_G_L1_reg = paddle.sum(paddle.abs(img['B'] - out_reg), axis=1, keepdim=True)
loss_G_L1_reg = paddle.mean(loss_G_L1_reg) loss_G_L1_reg = paddle.mean(loss_G_L1_reg)
loss = loss_ce + loss_G_L1_reg loss = loss_ce + loss_G_L1_reg
#calculate psnr
visual_ret = OrderedDict() visual_ret = OrderedDict()
psnrs = [] psnrs = []
lab2rgb = T.ConvertColorSpace(mode='LAB2RGB') lab2rgb = T.LAB2RGB()
process = T.ColorPostprocess() process = T.ColorPostprocess()
for i in range(img['A'].numpy().shape[0]):
for i in range(batch[0].numpy().shape[0]): real = lab2rgb(np.concatenate((img['A'].numpy(), img['B'].numpy()), axis=1))[i]
real = lab2rgb(np.concatenate((batch[0].numpy(), batch[3].numpy()), axis=1))[i]
visual_ret['real'] = process(real) visual_ret['real'] = process(real)
fake = lab2rgb(np.concatenate((batch[0].numpy(), out_reg.numpy()), axis=1))[i] fake = lab2rgb(np.concatenate((img['A'].numpy(), out_reg.numpy()), axis=1))[i]
visual_ret['fake_reg'] = process(fake) visual_ret['fake_reg'] = process(fake)
mse = np.mean((visual_ret['real'] * 1.0 - visual_ret['fake_reg'] * 1.0)**2) mse = np.mean((visual_ret['real'] * 1.0 - visual_ret['fake_reg'] * 1.0)**2)
psnr_value = 20 * np.log10(255. / np.sqrt(mse)) psnr_value = 20 * np.log10(255. / np.sqrt(mse))
psnrs.append(psnr_value) psnrs.append(psnr_value)
psnr = paddle.to_variable(np.array(psnrs)) psnr = paddle.to_variable(np.array(psnrs))
return {'loss': loss, 'metrics': {'psnr': psnr}} return {'loss': loss, 'metrics': {'psnr': psnr}}
def predict(self, images: str, visualization: bool = True, save_path: str = 'result'): def predict(self, images: str, visualization: bool = True, save_path: str = 'result'):
...@@ -163,23 +164,26 @@ class ImageColorizeModule(RunModule, ImageServing): ...@@ -163,23 +164,26 @@ class ImageColorizeModule(RunModule, ImageServing):
Returns: Returns:
results(list[dict]) : The prediction result of each input image results(list[dict]) : The prediction result of each input image
''' '''
lab2rgb = T.ConvertColorSpace(mode='LAB2RGB')
lab2rgb = T.LAB2RGB()
process = T.ColorPostprocess() process = T.ColorPostprocess()
resize = T.Resize((256, 256)) resize = T.Resize((256, 256))
visual_ret = OrderedDict()
im = self.transforms(images, is_train=False) im = self.transforms(images, is_train=False)
out_class, out_reg = self(paddle.to_tensor(im['A']), paddle.to_variable(im['hint_B']), im = im[np.newaxis, :, :, :]
paddle.to_variable(im['mask_B'])) im = self.preprocess(im)
result = [] out_class, out_reg = self(im['A'], im['hint_B'], im['mask_B'])
result = []
visual_ret = OrderedDict()
for i in range(im['A'].shape[0]): for i in range(im['A'].shape[0]):
gray = lab2rgb(np.concatenate((im['A'], np.zeros(im['B'].shape)), axis=1))[i] gray = lab2rgb(np.concatenate((im['A'].numpy(), np.zeros(im['B'].shape)), axis=1))[i]
visual_ret['gray'] = resize(process(gray)) visual_ret['gray'] = resize(process(gray))
hint = lab2rgb(np.concatenate((im['A'], im['hint_B']), axis=1))[i] hint = lab2rgb(np.concatenate((im['A'].numpy(), im['hint_B'].numpy()), axis=1))[i]
visual_ret['hint'] = resize(process(hint)) visual_ret['hint'] = resize(process(hint))
real = lab2rgb(np.concatenate((im['A'], im['B']), axis=1))[i] real = lab2rgb(np.concatenate((im['A'].numpy(), im['B'].numpy()), axis=1))[i]
visual_ret['real'] = resize(process(real)) visual_ret['real'] = resize(process(real))
fake = lab2rgb(np.concatenate((im['A'], out_reg.numpy()), axis=1))[i] fake = lab2rgb(np.concatenate((im['A'].numpy(), out_reg.numpy()), axis=1))[i]
visual_ret['fake_reg'] = resize(process(fake)) visual_ret['fake_reg'] = resize(process(fake))
if visualization: if visualization:
...@@ -232,7 +236,8 @@ class Yolov3Module(RunModule, ImageServing): ...@@ -232,7 +236,8 @@ class Yolov3Module(RunModule, ImageServing):
for i, out in enumerate(outputs): for i, out in enumerate(outputs):
anchor_mask = self.anchor_masks[i] anchor_mask = self.anchor_masks[i]
loss = F.yolov3_loss(x=out, loss = F.yolov3_loss(
x=out,
gt_box=gtbox, gt_box=gtbox,
gt_label=gtlabel, gt_label=gtlabel,
gt_score=gtscore, gt_score=gtscore,
...@@ -280,7 +285,8 @@ class Yolov3Module(RunModule, ImageServing): ...@@ -280,7 +285,8 @@ class Yolov3Module(RunModule, ImageServing):
mask_anchors.append((self.anchors[2 * m])) mask_anchors.append((self.anchors[2 * m]))
mask_anchors.append(self.anchors[2 * m + 1]) mask_anchors.append(self.anchors[2 * m + 1])
box, score = F.yolo_box(x=out, box, score = F.yolo_box(
x=out,
img_size=im_shape, img_size=im_shape,
anchors=mask_anchors, anchors=mask_anchors,
class_num=self.class_num, class_num=self.class_num,
...@@ -295,7 +301,8 @@ class Yolov3Module(RunModule, ImageServing): ...@@ -295,7 +301,8 @@ class Yolov3Module(RunModule, ImageServing):
yolo_boxes = paddle.concat(boxes, axis=1) yolo_boxes = paddle.concat(boxes, axis=1)
yolo_scores = paddle.concat(scores, axis=2) yolo_scores = paddle.concat(scores, axis=2)
pred = F.multiclass_nms(bboxes=yolo_boxes, pred = F.multiclass_nms(
bboxes=yolo_boxes,
scores=yolo_scores, scores=yolo_scores,
score_threshold=self.valid_thresh, score_threshold=self.valid_thresh,
nms_top_k=self.nms_topk, nms_top_k=self.nms_topk,
...@@ -309,7 +316,9 @@ class Yolov3Module(RunModule, ImageServing): ...@@ -309,7 +316,9 @@ class Yolov3Module(RunModule, ImageServing):
boxes = bboxes[:, 2:].astype('float32') boxes = bboxes[:, 2:].astype('float32')
if visualization: if visualization:
Func.draw_boxes_on_image(imgpath, boxes, scores, labels, label_names, 0.5) if not os.path.exists(save_path):
os.mkdir(save_path)
Func.draw_boxes_on_image(imgpath, boxes, scores, labels, label_names, 0.5, save_path)
return boxes, scores, labels return boxes, scores, labels
......
...@@ -185,7 +185,8 @@ def draw_boxes_on_image(image_path: str, ...@@ -185,7 +185,8 @@ def draw_boxes_on_image(image_path: str,
scores: np.ndarray, scores: np.ndarray,
labels: np.ndarray, labels: np.ndarray,
label_names: list, label_names: list,
score_thresh: float = 0.5): score_thresh: float = 0.5,
save_path: str = 'result'):
"""Draw boxes on images.""" """Draw boxes on images."""
image = np.array(Image.open(image_path)) image = np.array(Image.open(image_path))
plt.figure() plt.figure()
...@@ -206,7 +207,8 @@ def draw_boxes_on_image(image_path: str, ...@@ -206,7 +207,8 @@ def draw_boxes_on_image(image_path: str,
x1, y1, x2, y2 = box[0], box[1], box[2], box[3] x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, linewidth=2.0, edgecolor=colors[label]) rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, linewidth=2.0, edgecolor=colors[label])
ax.add_patch(rect) ax.add_patch(rect)
ax.text(x1, ax.text(
x1,
y1, y1,
'{} {:.4f}'.format(label_names[label], score), '{} {:.4f}'.format(label_names[label], score),
verticalalignment='bottom', verticalalignment='bottom',
...@@ -223,8 +225,7 @@ def draw_boxes_on_image(image_path: str, ...@@ -223,8 +225,7 @@ def draw_boxes_on_image(image_path: str,
plt.axis('off') plt.axis('off')
plt.gca().xaxis.set_major_locator(plt.NullLocator()) plt.gca().xaxis.set_major_locator(plt.NullLocator())
plt.gca().yaxis.set_major_locator(plt.NullLocator()) plt.gca().yaxis.set_major_locator(plt.NullLocator())
plt.savefig("./output/{}".format(image_name), bbox_inches='tight', pad_inches=0.0) plt.savefig("{}/{}".format(save_path, image_name), bbox_inches='tight', pad_inches=0.0)
print("Detect result save at ./output/{}\n".format(image_name))
plt.cla() plt.cla()
plt.close('all') plt.close('all')
......
...@@ -382,7 +382,7 @@ class RandomDistort: ...@@ -382,7 +382,7 @@ class RandomDistort:
saturation_upper = 1 + self.saturation_range saturation_upper = 1 + self.saturation_range
hue_lower = -self.hue_range hue_lower = -self.hue_range
hue_upper = self.hue_range hue_upper = self.hue_range
ops = [brightness, contrast, saturation, hue] ops = ['brightness', 'contrast', 'saturation', 'hue']
random.shuffle(ops) random.shuffle(ops)
params_dict = { params_dict = {
'brightness': { 'brightness': {
...@@ -421,19 +421,10 @@ class RandomDistort: ...@@ -421,19 +421,10 @@ class RandomDistort:
return im return im
class ConvertColorSpace: class RGB2LAB:
""" """
Convert color space from RGB to LAB or from LAB to RGB. Convert color space from RGB to LAB.
Args:
mode(str): Color space convert mode, it can be 'RGB2LAB' or 'LAB2RGB'.
Return:
img(np.ndarray): converted image.
""" """
def __init__(self, mode: str = 'RGB2LAB'):
self.mode = mode
def rgb2xyz(self, rgb: np.ndarray) -> np.ndarray: def rgb2xyz(self, rgb: np.ndarray) -> np.ndarray:
""" """
Convert color space from RGB to XYZ. Convert color space from RGB to XYZ.
...@@ -448,10 +439,10 @@ class ConvertColorSpace: ...@@ -448,10 +439,10 @@ class ConvertColorSpace:
np.seterr(invalid='ignore') np.seterr(invalid='ignore')
rgb = (((rgb + .055) / 1.055)**2.4) * mask + rgb / 12.92 * (1 - mask) rgb = (((rgb + .055) / 1.055)**2.4) * mask + rgb / 12.92 * (1 - mask)
rgb = np.nan_to_num(rgb) rgb = np.nan_to_num(rgb)
x = .412453 * rgb[:, 0, :, :] + .357580 * rgb[:, 1, :, :] + .180423 * rgb[:, 2, :, :] x = .412453 * rgb[0, :, :] + .357580 * rgb[1, :, :] + .180423 * rgb[2, :, :]
y = .212671 * rgb[:, 0, :, :] + .715160 * rgb[:, 1, :, :] + .072169 * rgb[:, 2, :, :] y = .212671 * rgb[0, :, :] + .715160 * rgb[1, :, :] + .072169 * rgb[2, :, :]
z = .019334 * rgb[:, 0, :, :] + .119193 * rgb[:, 1, :, :] + .950227 * rgb[:, 2, :, :] z = .019334 * rgb[0, :, :] + .119193 * rgb[1, :, :] + .950227 * rgb[2, :, :]
out = np.concatenate((x[:, None, :, :], y[:, None, :, :], z[:, None, :, :]), axis=1) out = np.concatenate((x[None, :, :], y[None, :, :], z[None, :, :]), axis=0)
return out return out
def xyz2lab(self, xyz: np.ndarray) -> np.ndarray: def xyz2lab(self, xyz: np.ndarray) -> np.ndarray:
...@@ -464,14 +455,14 @@ class ConvertColorSpace: ...@@ -464,14 +455,14 @@ class ConvertColorSpace:
Return: Return:
img(np.ndarray): Converted LAB image. img(np.ndarray): Converted LAB image.
""" """
sc = np.array((0.95047, 1., 1.08883))[None, :, None, None] sc = np.array((0.95047, 1., 1.08883))[:, None, None]
xyz_scale = xyz / sc xyz_scale = xyz / sc
mask = (xyz_scale > .008856).astype(np.float32) mask = (xyz_scale > .008856).astype(np.float32)
xyz_int = np.cbrt(xyz_scale) * mask + (7.787 * xyz_scale + 16. / 116.) * (1 - mask) xyz_int = np.cbrt(xyz_scale) * mask + (7.787 * xyz_scale + 16. / 116.) * (1 - mask)
L = 116. * xyz_int[:, 1, :, :] - 16. L = 116. * xyz_int[1, :, :] - 16.
a = 500. * (xyz_int[:, 0, :, :] - xyz_int[:, 1, :, :]) a = 500. * (xyz_int[0, :, :] - xyz_int[1, :, :])
b = 200. * (xyz_int[:, 1, :, :] - xyz_int[:, 2, :, :]) b = 200. * (xyz_int[1, :, :] - xyz_int[2, :, :])
out = np.concatenate((L[:, None, :, :], a[:, None, :, :], b[:, None, :, :]), axis=1) out = np.concatenate((L[None, :, :], a[None, :, :], b[None, :, :]), axis=0)
return out return out
def rgb2lab(self, rgb: np.ndarray) -> np.ndarray: def rgb2lab(self, rgb: np.ndarray) -> np.ndarray:
...@@ -485,11 +476,24 @@ class ConvertColorSpace: ...@@ -485,11 +476,24 @@ class ConvertColorSpace:
img(np.ndarray): Converted LAB image. img(np.ndarray): Converted LAB image.
""" """
lab = self.xyz2lab(self.rgb2xyz(rgb)) lab = self.xyz2lab(self.rgb2xyz(rgb))
l_rs = (lab[:, [0], :, :] - 50) / 100 l_rs = (lab[[0], :, :] - 50) / 100
ab_rs = lab[:, 1:, :, :] / 110 ab_rs = lab[1:, :, :] / 110
out = np.concatenate((l_rs, ab_rs), axis=1) out = np.concatenate((l_rs, ab_rs), axis=0)
return out return out
def __call__(self, img: np.ndarray) -> np.ndarray:
img = img / 255
img = np.array(img).transpose(2, 0, 1)
return self.rgb2lab(img)
class LAB2RGB:
"""
Convert color space from LAB to RGB.
"""
def __init__(self, mode: str = 'RGB2LAB'):
self.mode = mode
def xyz2rgb(self, xyz: np.ndarray) -> np.ndarray: def xyz2rgb(self, xyz: np.ndarray) -> np.ndarray:
""" """
Convert color space from XYZ to RGB. Convert color space from XYZ to RGB.
...@@ -551,171 +555,7 @@ class ConvertColorSpace: ...@@ -551,171 +555,7 @@ class ConvertColorSpace:
return out return out
def __call__(self, img: np.ndarray) -> np.ndarray: def __call__(self, img: np.ndarray) -> np.ndarray:
if self.mode == 'RGB2LAB':
img = np.expand_dims(img / 255, 0)
img = np.array(img).transpose(0, 3, 1, 2)
return self.rgb2lab(img)
elif self.mode == 'LAB2RGB':
return self.lab2rgb(img) return self.lab2rgb(img)
else:
raise ValueError('The mode should be RGB2LAB or LAB2RGB')
class ColorizeHint:
    """Get hint and mask images for colorization.

    This method is prepared for user guided colorization tasks. Take the
    original RGB images as input, we will obtain the local hints and
    corresponding mask to guide the colorization process.

    Args:
        percent(float): Probability for ignoring hint in an iteration.
        num_points(int): Number of selected hints in an iteration.
        samp(str): Sample method, default is normal.
        use_avg(bool): Whether to use mean in selected hint area.

    Return:
        hint(np.ndarray): hint images
        mask(np.ndarray): mask images
    """

    def __init__(self, percent: float, num_points: int = None, samp: str = 'normal', use_avg: bool = True):
        self.percent = percent
        self.num_points = num_points
        self.samp = samp
        self.use_avg = use_avg

    def __call__(self, data: np.ndarray, hint: np.ndarray, mask: np.ndarray):
        """Fill *hint* and *mask* in place with randomly sampled color patches from *data*.

        Args:
            data(np.ndarray): Source color channels, shape (N, C, H, W).
            hint(np.ndarray): Hint buffer, same shape as data; written in place.
            mask(np.ndarray): Mask buffer, shape (N, 1, H, W); returned shifted to {-0.5, 0.5}.
        """
        sample_Ps = [1, 2, 3, 4, 5, 6, 7, 8, 9]  # candidate square patch sizes
        self.data = data
        self.hint = hint
        self.mask = mask
        N, C, H, W = data.shape
        for nn in range(N):
            pp = 0  # number of hint points placed for this image
            cont_cond = True
            while cont_cond:
                if self.num_points is None:  # draw from geometric
                    # Continue with probability `percent` each iteration.
                    cont_cond = np.random.rand() > (1 - self.percent)
                else:  # add certain number of points
                    cont_cond = pp < self.num_points
                if not cont_cond:  # skip out of loop if condition not met
                    continue
                P = np.random.choice(sample_Ps)  # patch size
                # sample location
                if self.samp == 'normal':  # geometric distribution
                    h = int(np.clip(np.random.normal((H - P + 1) / 2., (H - P + 1) / 4.), 0, H - P))
                    w = int(np.clip(np.random.normal((W - P + 1) / 2., (W - P + 1) / 4.), 0, W - P))
                else:  # uniform distribution
                    h = np.random.randint(H - P + 1)
                    w = np.random.randint(W - P + 1)
                # add color point
                if self.use_avg:
                    # Mean color of the patch, broadcast back over the patch area.
                    hint[nn, :, h:h + P, w:w + P] = np.mean(np.mean(data[nn, :, h:h + P, w:w + P],
                                                                    axis=2,
                                                                    keepdims=True),
                                                            axis=1,
                                                            keepdims=True).reshape(1, C, 1, 1)
                else:
                    hint[nn, :, h:h + P, w:w + P] = data[nn, :, h:h + P, w:w + P]
                mask[nn, :, h:h + P, w:w + P] = 1
                # increment counter
                pp += 1
        # Shift mask range from {0, 1} to {-0.5, 0.5}.
        mask -= 0.5
        return hint, mask
class SqueezeAxis:
    """
    Squeeze the specific axis when it equal to 1.

    Args:
        axis(int): Which axis should be squeezed.
    """

    def __init__(self, axis: int):
        self.axis = axis

    def __call__(self, data: dict) -> dict:
        """Squeeze ``self.axis`` out of every array in *data* and cast to float32.

        Args:
            data(dict): Mapping of names to np.ndarray whose ``self.axis`` has length 1.

        Returns:
            dict: The same mapping with the axis removed from every value.

        Raises:
            TypeError: If *data* is not a dict.
        """
        if isinstance(data, dict):
            for key in data.keys():
                # Bug fix: the configured axis was previously ignored and
                # axis 0 was always squeezed regardless of `self.axis`.
                data[key] = np.squeeze(data[key], self.axis).astype(np.float32)
            return data
        else:
            # Bug fix: the old message claimed List/tuple were accepted, but
            # only dict input is handled.
            raise TypeError("Type of data is invalid. Must be a dict, now is {}".format(type(data)))
class ColorizePreprocess:
    """Prepare dataset for image Colorization.

    Args:
        ab_thresh(float): Thresh value for setting mask value.
        p(float): Probability for ignoring hint in an iteration.
        num_points(int): Number of selected hints in an iteration.
        samp(str): Sample method, default is normal.
        use_avg(bool): Whether to use mean in selected hint area.
        is_train(bool): Training process or not.

    Return:
        data(dict): The preprocessed data for colorization.
    """

    def __init__(self,
                 ab_thresh: float = 0.,
                 p: float = 0.,
                 num_points: int = None,
                 samp: str = 'normal',
                 use_avg: bool = True,
                 is_train: bool = True):
        self.ab_thresh = ab_thresh
        self.p = p
        self.num_points = num_points
        self.samp = samp
        self.use_avg = use_avg
        self.is_train = is_train
        # Hint/mask sampler shared by every call.
        self.gethint = ColorizeHint(percent=self.p, num_points=self.num_points, samp=self.samp, use_avg=self.use_avg)
        # Drops the leading batch axis (and casts to float32) for training samples.
        self.squeeze = SqueezeAxis(0)

    def __call__(self, data_lab: np.ndarray):
        """
        This method separates the L channel and AB channel, obtains hint, mask and real_B_enc as the input for the colorization task.

        Args:
            img(np.ndarray): LAB image, assumed shape (N, 3, H, W) — TODO confirm with callers.

        Returns:
            data(dict): The preprocessed data for colorization; None when every
                image in the batch is filtered out as grayscale.
        """
        data = {}
        A = 2 * 110 / 10 + 1  # number of quantized ab bins per axis (= 23)
        # L channel
        data['A'] = data_lab[:, [
            0,
        ], :, :]
        # ab channels
        data['B'] = data_lab[:, 1:, :, :]
        if self.ab_thresh > 0:  # mask out grayscale images
            thresh = 1. * self.ab_thresh / 110
            # Per-image ab dynamic range, summed over the two ab channels.
            mask = np.sum(np.abs(np.max(np.max(data['B'], axis=3), axis=2) - np.min(np.min(data['B'], axis=3), axis=2)),
                          axis=1)
            mask = (mask >= thresh)
            data['A'] = data['A'][mask, :, :, :]
            data['B'] = data['B'][mask, :, :, :]
            if np.sum(mask) == 0:
                return None
        data_ab_rs = np.round((data['B'][:, :, ::4, ::4] * 110. + 110.) / 10.)  # normalized bin number
        data['real_B_enc'] = data_ab_rs[:, [0], :, :] * A + data_ab_rs[:, [1], :, :]
        data['hint_B'] = np.zeros(shape=data['B'].shape)
        data['mask_B'] = np.zeros(shape=data['A'].shape)
        data['hint_B'], data['mask_B'] = self.gethint(data['B'], data['hint_B'], data['mask_B'])
        if self.is_train:
            # Training: squeeze the batch axis; SqueezeAxis also casts to float32.
            data = self.squeeze(data)
            data['real_B_enc'] = data['real_B_enc'].astype(np.int64)
        else:
            # Inference: keep the batch axis, just fix dtypes.
            data['A'] = data['A'].astype(np.float32)
            data['B'] = data['B'].astype(np.float32)
            data['real_B_enc'] = data['real_B_enc'].astype(np.int64)
            data['hint_B'] = data['hint_B'].astype(np.float32)
            data['mask_B'] = data['mask_B'].astype(np.float32)
        return data
class ColorPostprocess: class ColorPostprocess:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册