Unverified commit 35c96a7d authored by haoyuying, committed by GitHub

reconstruct colorization transform

Parent 48363091
......@@ -4,23 +4,18 @@ import paddle.nn as nn
from paddlehub.finetune.trainer import Trainer
from paddlehub.datasets.colorizedataset import Colorizedataset
from paddlehub.process.transforms import Compose, Resize, RandomPaddingCrop, ConvertColorSpace, ColorizePreprocess
import paddlehub.process.transforms as T
if __name__ == '__main__':
is_train = True
paddle.disable_static()
model = hub.Module(name='user_guided_colorization')
transform = Compose([
Resize((256, 256), interp='NEAREST'),
RandomPaddingCrop(crop_size=176),
ConvertColorSpace(mode='RGB2LAB'),
ColorizePreprocess(ab_thresh=0, is_train=is_train),
],
stay_rgb=True,
is_permute=False)
transform = T.Compose([T.Resize((256, 256), interp='NEAREST'),
T.RandomPaddingCrop(crop_size=176),
T.RGB2LAB()],
stay_rgb=True,
is_permute=False)
color_set = Colorizedataset(transform=transform, mode='train')
if is_train:
model.train()
optimizer = paddle.optimizer.Adam(learning_rate=0.0001, parameters=model.parameters())
trainer = Trainer(model, optimizer, checkpoint_dir='test_ckpt_img_cls')
trainer.train(color_set, epochs=101, batch_size=5, eval_dataset=color_set, log_interval=10, save_interval=10)
optimizer = paddle.optimizer.Adam(learning_rate=0.0001, parameters=model.parameters())
trainer = Trainer(model, optimizer, checkpoint_dir='test_ckpt_img_cls')
trainer.train(color_set, epochs=101, batch_size=2, eval_dataset=color_set, log_interval=10, save_interval=10)
......@@ -6,4 +6,4 @@ if __name__ == '__main__':
paddle.disable_static()
model = hub.Module(name='yolov3_darknet53_pascalvoc', is_train=False)
model.eval()
model.predict(imgpath="4026.jpeg", filelist="/PATH/TO/JSON/FILE")
model.predict(imgpath="4026.jpeg", filelist="/PATH/TO/JSON")
import paddle
import numpy as np
class ColorizeHint:
"""Get hint and mask images for colorization.
This class is prepared for user-guided colorization tasks. Taking the original RGB images as input,
it produces the local hints and the corresponding mask that guide the colorization process.
Args:
percent(float): Probability for ignoring hint in an iteration.
num_points(int): Number of selected hints in an iteration.
samp(str): Sampling method, either 'normal' or 'uniform'. Default is 'normal'.
use_avg(bool): Whether to fill the hint area with the mean value of the selected patch.
Return:
hint(np.ndarray): hint images
mask(np.ndarray): mask images
"""
def __init__(self, percent: float, num_points: int = None, samp: str = 'normal', use_avg: bool = True):
self.percent = percent
self.num_points = num_points
self.samp = samp
self.use_avg = use_avg
def __call__(self, data: np.ndarray, hint: np.ndarray, mask: np.ndarray):
sample_Ps = [1, 2, 3, 4, 5, 6, 7, 8, 9]
self.data = data
self.hint = hint
self.mask = mask
N, C, H, W = data.shape
for nn in range(N):
pp = 0
cont_cond = True
while cont_cond:
if self.num_points is None:  # number of hints per image follows a geometric distribution
cont_cond = np.random.rand() > (1 - self.percent)
else: # add certain number of points
cont_cond = pp < self.num_points
if not cont_cond: # skip out of loop if condition not met
continue
P = np.random.choice(sample_Ps) # patch size
# sample location
if self.samp == 'normal':  # sample the patch location from a normal distribution centered on the image
h = int(np.clip(np.random.normal((H - P + 1) / 2., (H - P + 1) / 4.), 0, H - P))
w = int(np.clip(np.random.normal((W - P + 1) / 2., (W - P + 1) / 4.), 0, W - P))
else: # uniform distribution
h = np.random.randint(H - P + 1)
w = np.random.randint(W - P + 1)
# add color point
if self.use_avg:
hint[nn, :, h:h + P, w:w + P] = np.mean(np.mean(data[nn, :, h:h + P, w:w + P],
axis=2,
keepdims=True),
axis=1,
keepdims=True).reshape(1, C, 1, 1)
else:
hint[nn, :, h:h + P, w:w + P] = data[nn, :, h:h + P, w:w + P]
mask[nn, :, h:h + P, w:w + P] = 1
# increment counter
pp += 1
mask -= 0.5
return hint, mask
class ColorizePreprocess:
"""Prepare dataset for image Colorization.
Args:
ab_thresh(float): Threshold on the per-image AB range; images whose AB channels vary less than this are treated as grayscale and masked out.
p(float): Probability for ignoring hint in an iteration.
num_points(int): Number of selected hints in an iteration.
samp(str): Sampling method, either 'normal' or 'uniform'. Default is 'normal'.
use_avg(bool): Whether to use mean in selected hint area.
is_train(bool): Training process or not.
Return:
data(dict): The preprocessed data for colorization.
"""
def __init__(self,
ab_thresh: float = 0.,
p: float = 0.,
num_points: int = None,
samp: str = 'normal',
use_avg: bool = True):
self.ab_thresh = ab_thresh
self.p = p
self.num_points = num_points
self.samp = samp
self.use_avg = use_avg
self.gethint = ColorizeHint(percent=self.p, num_points=self.num_points, samp=self.samp, use_avg=self.use_avg)
def __call__(self, data_lab):
"""
This method separates the L and AB channels and produces the hint, mask and real_B_enc inputs for the colorization task.
Args:
data_lab(np.ndarray|paddle.Tensor): LAB image.
Returns:
data(dict): The preprocessed data for colorization.
"""
if type(data_lab) is not np.ndarray:
data_lab = data_lab.numpy()
data = {}
A = 2 * 110 / 10 + 1  # 23 quantized bins per AB axis: AB values in [-110, 110], bin width 10
data['A'] = data_lab[:, [0], :, :]
data['B'] = data_lab[:, 1:, :, :]
if self.ab_thresh > 0: # mask out grayscale images
thresh = 1. * self.ab_thresh / 110
mask = np.sum(np.abs(np.max(np.max(data['B'], axis=3), axis=2) - np.min(np.min(data['B'], axis=3), axis=2)),
axis=1)
mask = (mask >= thresh)
data['A'] = data['A'][mask, :, :, :]
data['B'] = data['B'][mask, :, :, :]
if np.sum(mask) == 0:
return None
data_ab_rs = np.round((data['B'][:, :, ::4, ::4] * 110. + 110.) / 10.) # normalized bin number
data['real_B_enc'] = data_ab_rs[:, [0], :, :] * A + data_ab_rs[:, [1], :, :]
data['hint_B'] = np.zeros(shape=data['B'].shape)
data['mask_B'] = np.zeros(shape=data['A'].shape)
data['hint_B'], data['mask_B'] = self.gethint(data['B'], data['hint_B'], data['mask_B'])
data['A'] = paddle.to_tensor(data['A'].astype(np.float32))
data['B'] = paddle.to_tensor(data['B'].astype(np.float32))
data['real_B_enc'] = paddle.to_tensor(data['real_B_enc'].astype(np.int64))
data['hint_B'] = paddle.to_tensor(data['hint_B'].astype(np.float32))
data['mask_B'] = paddle.to_tensor(data['mask_B'].astype(np.float32))
return data
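# A minimal, self-contained smoke test (a hedged sketch: the random LAB-like
# batch below is an illustrative assumption, not data produced elsewhere in
# this commit). It shows the dictionary of tensors that ColorizePreprocess
# hands to the colorization network.
if __name__ == '__main__':
    paddle.disable_static()
    # NCHW batch with the L channel rescaled to [-0.5, 0.5] and AB divided by 110,
    # i.e. the layout produced by the RGB2LAB transform.
    fake_lab = np.random.uniform(-0.5, 0.5, size=(1, 3, 176, 176)).astype(np.float32)
    preprocess = ColorizePreprocess(ab_thresh=0., p=0.125)
    sample = preprocess(fake_lab)
    for key in ('A', 'B', 'real_B_enc', 'hint_B', 'mask_B'):
        print(key, sample[key].shape)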
......@@ -15,12 +15,12 @@
import os
import paddle
import numpy
import paddle.nn as nn
from paddle.nn import Conv2d, ConvTranspose2d
from paddlehub.module.module import moduleinfo
from paddlehub.process.transforms import Compose, Resize, RandomPaddingCrop, ConvertColorSpace, ColorizePreprocess
import paddlehub.process.transforms as T
from paddlehub.module.cv_module import ImageColorizeModule
from user_guided_colorization.data_feed import ColorizePreprocess
@moduleinfo(
......@@ -32,7 +32,8 @@ from paddlehub.module.cv_module import ImageColorizeModule
version="1.0.0",
meta=ImageColorizeModule)
class UserGuidedColorization(nn.Layer):
"""Userguidedcolorization, see https://github.com/haoyuying/colorization-pytorch
"""
Userguidedcolorization, see https://github.com/haoyuying/colorization-pytorch
Args:
use_tanh (bool): Whether to use tanh as final activation function.
......@@ -139,12 +140,7 @@ class UserGuidedColorization(nn.Layer):
1,
1,
), )
model9 = (
nn.ReLU(),
Conv2d(128, 128, 3, 1, 1),
nn.ReLU(),
nn.BatchNorm(128),
)
model9 = (nn.ReLU(), Conv2d(128, 128, 3, 1, 1), nn.ReLU(), nn.BatchNorm(128))
# Conv10
model10up = (ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1), )
......@@ -182,30 +178,32 @@ class UserGuidedColorization(nn.Layer):
print("load custom checkpoint success")
else:
checkpoint = os.path.join(self.directory, 'user_guided.pdparams')
if not os.path.exists(checkpoint):
os.system('wget https://paddlehub.bj.bcebos.com/dygraph/image_colorization/user_guided.pdparams -O ' +
checkpoint)
model_dict = paddle.load(checkpoint)[0]
self.set_dict(model_dict)
print("load pretrained checkpoint success")
def transforms(self, images: str, is_train: bool = True) -> callable:
if is_train:
transform = Compose([
Resize((256, 256), interp='NEAREST'),
RandomPaddingCrop(crop_size=176),
ConvertColorSpace(mode='RGB2LAB'),
ColorizePreprocess(ab_thresh=0, is_train=is_train)
],
stay_rgb=True,
is_permute=False)
transform = T.Compose(
[T.Resize((256, 256), interp='NEAREST'),
T.RandomPaddingCrop(crop_size=176),
T.RGB2LAB()],
stay_rgb=True,
is_permute=False)
else:
transform = Compose([
Resize((256, 256), interp='NEAREST'),
ConvertColorSpace(mode='RGB2LAB'),
ColorizePreprocess(ab_thresh=0, is_train=is_train)
],
stay_rgb=True,
is_permute=False)
transform = T.Compose([T.Resize(
(256, 256), interp='NEAREST'), T.RGB2LAB()],
stay_rgb=True,
is_permute=False)
return transform(images)
def preprocess(self, inputs: paddle.Tensor, ab_thresh: float = 0., prob: float = 0.):
self.preprocess = ColorizePreprocess(ab_thresh=ab_thresh, p=prob)
return self.preprocess(inputs)
def forward(self,
input_A: paddle.Tensor,
input_B: paddle.Tensor,
......
......@@ -15,7 +15,7 @@
import os
import numpy
import numpy as np
import paddle
from paddlehub.process.functional import get_img_file
......@@ -26,9 +26,11 @@ from typing import Callable
class Colorizedataset(paddle.io.Dataset):
"""
Dataset for colorization.
Args:
transform(callmethod) : The method of preprocess images.
mode(str): The mode for preparing dataset.
Returns:
DataSet: An iterable object for data iterating
"""
......@@ -44,10 +46,10 @@ class Colorizedataset(paddle.io.Dataset):
self.file = os.path.join(DATA_HOME, 'canvas', self.file)
self.data = get_img_file(self.file)
def __getitem__(self, idx: int) -> numpy.ndarray:
def __getitem__(self, idx: int) -> np.ndarray:
img_path = self.data[idx]
im = self.transform(img_path)
return im['A'], im['hint_B'], im['mask_B'], im['B'], im['real_B_enc']
return im
def __len__(self):
return len(self.data)
......@@ -111,7 +111,7 @@ class ImageColorizeModule(RunModule, ImageServing):
batch_idx(int): The index of batch.
Returns:
results(dict) : The model outputs, such as loss and metrics.
results(dict): The model outputs, such as loss and metrics.
'''
return self.validation_step(batch, batch_idx)
......@@ -126,29 +126,30 @@ class ImageColorizeModule(RunModule, ImageServing):
Returns:
results(dict) : The model outputs, such as metrics.
'''
out_class, out_reg = self(batch[0], batch[1], batch[2])
img = self.preprocess(batch[0])
out_class, out_reg = self(img['A'], img['hint_B'], img['mask_B'])
# loss
criterionCE = nn.loss.CrossEntropyLoss()
loss_ce = criterionCE(out_class, batch[4][:, 0, :, :])
loss_G_L1_reg = paddle.sum(paddle.abs(batch[3] - out_reg), axis=1, keepdim=True)
loss_ce = criterionCE(out_class, img['real_B_enc'][:, 0, :, :])
loss_G_L1_reg = paddle.sum(paddle.abs(img['B'] - out_reg), axis=1, keepdim=True)
loss_G_L1_reg = paddle.mean(loss_G_L1_reg)
loss = loss_ce + loss_G_L1_reg
#calculate psnr
visual_ret = OrderedDict()
psnrs = []
lab2rgb = T.ConvertColorSpace(mode='LAB2RGB')
lab2rgb = T.LAB2RGB()
process = T.ColorPostprocess()
for i in range(batch[0].numpy().shape[0]):
real = lab2rgb(np.concatenate((batch[0].numpy(), batch[3].numpy()), axis=1))[i]
for i in range(img['A'].numpy().shape[0]):
real = lab2rgb(np.concatenate((img['A'].numpy(), img['B'].numpy()), axis=1))[i]
visual_ret['real'] = process(real)
fake = lab2rgb(np.concatenate((batch[0].numpy(), out_reg.numpy()), axis=1))[i]
fake = lab2rgb(np.concatenate((img['A'].numpy(), out_reg.numpy()), axis=1))[i]
visual_ret['fake_reg'] = process(fake)
mse = np.mean((visual_ret['real'] * 1.0 - visual_ret['fake_reg'] * 1.0)**2)
psnr_value = 20 * np.log10(255. / np.sqrt(mse))
psnrs.append(psnr_value)
psnr = paddle.to_variable(np.array(psnrs))
return {'loss': loss, 'metrics': {'psnr': psnr}}
def predict(self, images: str, visualization: bool = True, save_path: str = 'result'):
......@@ -163,23 +164,26 @@ class ImageColorizeModule(RunModule, ImageServing):
Returns:
results(list[dict]) : The prediction result of each input image
'''
lab2rgb = T.ConvertColorSpace(mode='LAB2RGB')
lab2rgb = T.LAB2RGB()
process = T.ColorPostprocess()
resize = T.Resize((256, 256))
visual_ret = OrderedDict()
im = self.transforms(images, is_train=False)
out_class, out_reg = self(paddle.to_tensor(im['A']), paddle.to_variable(im['hint_B']),
paddle.to_variable(im['mask_B']))
result = []
im = im[np.newaxis, :, :, :]
im = self.preprocess(im)
out_class, out_reg = self(im['A'], im['hint_B'], im['mask_B'])
result = []
visual_ret = OrderedDict()
for i in range(im['A'].shape[0]):
gray = lab2rgb(np.concatenate((im['A'], np.zeros(im['B'].shape)), axis=1))[i]
gray = lab2rgb(np.concatenate((im['A'].numpy(), np.zeros(im['B'].shape)), axis=1))[i]
visual_ret['gray'] = resize(process(gray))
hint = lab2rgb(np.concatenate((im['A'], im['hint_B']), axis=1))[i]
hint = lab2rgb(np.concatenate((im['A'].numpy(), im['hint_B'].numpy()), axis=1))[i]
visual_ret['hint'] = resize(process(hint))
real = lab2rgb(np.concatenate((im['A'], im['B']), axis=1))[i]
real = lab2rgb(np.concatenate((im['A'].numpy(), im['B'].numpy()), axis=1))[i]
visual_ret['real'] = resize(process(real))
fake = lab2rgb(np.concatenate((im['A'], out_reg.numpy()), axis=1))[i]
fake = lab2rgb(np.concatenate((im['A'].numpy(), out_reg.numpy()), axis=1))[i]
visual_ret['fake_reg'] = resize(process(fake))
if visualization:
......@@ -232,16 +236,17 @@ class Yolov3Module(RunModule, ImageServing):
for i, out in enumerate(outputs):
anchor_mask = self.anchor_masks[i]
loss = F.yolov3_loss(x=out,
gt_box=gtbox,
gt_label=gtlabel,
gt_score=gtscore,
anchors=self.anchors,
anchor_mask=anchor_mask,
class_num=self.class_num,
ignore_thresh=self.ignore_thresh,
downsample_ratio=32,
use_label_smooth=False)
loss = F.yolov3_loss(
x=out,
gt_box=gtbox,
gt_label=gtlabel,
gt_score=gtscore,
anchors=self.anchors,
anchor_mask=anchor_mask,
class_num=self.class_num,
ignore_thresh=self.ignore_thresh,
downsample_ratio=32,
use_label_smooth=False)
losses.append(paddle.reduce_mean(loss))
self.downsample //= 2
......@@ -280,13 +285,14 @@ class Yolov3Module(RunModule, ImageServing):
mask_anchors.append((self.anchors[2 * m]))
mask_anchors.append(self.anchors[2 * m + 1])
box, score = F.yolo_box(x=out,
img_size=im_shape,
anchors=mask_anchors,
class_num=self.class_num,
conf_thresh=self.valid_thresh,
downsample_ratio=self.downsample,
name="yolo_box" + str(i))
box, score = F.yolo_box(
x=out,
img_size=im_shape,
anchors=mask_anchors,
class_num=self.class_num,
conf_thresh=self.valid_thresh,
downsample_ratio=self.downsample,
name="yolo_box" + str(i))
boxes.append(box)
scores.append(paddle.transpose(score, perm=[0, 2, 1]))
......@@ -295,13 +301,14 @@ class Yolov3Module(RunModule, ImageServing):
yolo_boxes = paddle.concat(boxes, axis=1)
yolo_scores = paddle.concat(scores, axis=2)
pred = F.multiclass_nms(bboxes=yolo_boxes,
scores=yolo_scores,
score_threshold=self.valid_thresh,
nms_top_k=self.nms_topk,
keep_top_k=self.nms_posk,
nms_threshold=self.nms_thresh,
background_label=-1)
pred = F.multiclass_nms(
bboxes=yolo_boxes,
scores=yolo_scores,
score_threshold=self.valid_thresh,
nms_top_k=self.nms_topk,
keep_top_k=self.nms_posk,
nms_threshold=self.nms_thresh,
background_label=-1)
bboxes = pred.numpy()
labels = bboxes[:, 0].astype('int32')
......@@ -309,7 +316,9 @@ class Yolov3Module(RunModule, ImageServing):
boxes = bboxes[:, 2:].astype('float32')
if visualization:
Func.draw_boxes_on_image(imgpath, boxes, scores, labels, label_names, 0.5)
if not os.path.exists(save_path):
os.mkdir(save_path)
Func.draw_boxes_on_image(imgpath, boxes, scores, labels, label_names, 0.5, save_path)
return boxes, scores, labels
......
......@@ -185,7 +185,8 @@ def draw_boxes_on_image(image_path: str,
scores: np.ndarray,
labels: np.ndarray,
label_names: list,
score_thresh: float = 0.5):
score_thresh: float = 0.5,
save_path: str = 'result'):
"""Draw boxes on images."""
image = np.array(Image.open(image_path))
plt.figure()
......@@ -206,25 +207,25 @@ def draw_boxes_on_image(image_path: str,
x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, linewidth=2.0, edgecolor=colors[label])
ax.add_patch(rect)
ax.text(x1,
y1,
'{} {:.4f}'.format(label_names[label], score),
verticalalignment='bottom',
horizontalalignment='left',
bbox={
'facecolor': colors[label],
'alpha': 0.5,
'pad': 0
},
fontsize=8,
color='white')
ax.text(
x1,
y1,
'{} {:.4f}'.format(label_names[label], score),
verticalalignment='bottom',
horizontalalignment='left',
bbox={
'facecolor': colors[label],
'alpha': 0.5,
'pad': 0
},
fontsize=8,
color='white')
print("\t {:15s} at {:25} score: {:.5f}".format(label_names[int(label)], str(list(map(int, list(box)))), score))
image_name = image_name.replace('jpg', 'png')
plt.axis('off')
plt.gca().xaxis.set_major_locator(plt.NullLocator())
plt.gca().yaxis.set_major_locator(plt.NullLocator())
plt.savefig("./output/{}".format(image_name), bbox_inches='tight', pad_inches=0.0)
print("Detect result save at ./output/{}\n".format(image_name))
plt.savefig("{}/{}".format(save_path, image_name), bbox_inches='tight', pad_inches=0.0)
plt.cla()
plt.close('all')
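# A hedged usage sketch of the updated signature (the image file and the box,
# score and label values below are illustrative assumptions). Note that the
# caller is expected to create save_path beforehand, as Yolov3Module.predict does.
#
#   boxes = np.array([[40., 60., 200., 220.]])   # one box as x1, y1, x2, y2
#   scores = np.array([0.9])
#   labels = np.array([0])
#   draw_boxes_on_image('4026.jpeg', boxes, scores, labels, ['person'], 0.5, save_path='result')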
......
......@@ -382,7 +382,7 @@ class RandomDistort:
saturation_upper = 1 + self.saturation_range
hue_lower = -self.hue_range
hue_upper = self.hue_range
ops = [brightness, contrast, saturation, hue]
ops = ['brightness', 'contrast', 'saturation', 'hue']
random.shuffle(ops)
params_dict = {
'brightness': {
......@@ -421,19 +421,10 @@ class RandomDistort:
return im
class ConvertColorSpace:
class RGB2LAB:
"""
Convert color space from RGB to LAB or from LAB to RGB.
Args:
mode(str): Color space convert mode, it can be 'RGB2LAB' or 'LAB2RGB'.
Return:
img(np.ndarray): converted image.
Convert color space from RGB to LAB.
"""
def __init__(self, mode: str = 'RGB2LAB'):
self.mode = mode
def rgb2xyz(self, rgb: np.ndarray) -> np.ndarray:
"""
Convert color space from RGB to XYZ.
......@@ -448,10 +439,10 @@ class ConvertColorSpace:
np.seterr(invalid='ignore')
rgb = (((rgb + .055) / 1.055)**2.4) * mask + rgb / 12.92 * (1 - mask)
rgb = np.nan_to_num(rgb)
x = .412453 * rgb[:, 0, :, :] + .357580 * rgb[:, 1, :, :] + .180423 * rgb[:, 2, :, :]
y = .212671 * rgb[:, 0, :, :] + .715160 * rgb[:, 1, :, :] + .072169 * rgb[:, 2, :, :]
z = .019334 * rgb[:, 0, :, :] + .119193 * rgb[:, 1, :, :] + .950227 * rgb[:, 2, :, :]
out = np.concatenate((x[:, None, :, :], y[:, None, :, :], z[:, None, :, :]), axis=1)
x = .412453 * rgb[0, :, :] + .357580 * rgb[1, :, :] + .180423 * rgb[2, :, :]
y = .212671 * rgb[0, :, :] + .715160 * rgb[1, :, :] + .072169 * rgb[2, :, :]
z = .019334 * rgb[0, :, :] + .119193 * rgb[1, :, :] + .950227 * rgb[2, :, :]
out = np.concatenate((x[None, :, :], y[None, :, :], z[None, :, :]), axis=0)
return out
def xyz2lab(self, xyz: np.ndarray) -> np.ndarray:
......@@ -464,14 +455,14 @@ class ConvertColorSpace:
Return:
img(np.ndarray): Converted LAB image.
"""
sc = np.array((0.95047, 1., 1.08883))[None, :, None, None]
sc = np.array((0.95047, 1., 1.08883))[:, None, None]
xyz_scale = xyz / sc
mask = (xyz_scale > .008856).astype(np.float32)
xyz_int = np.cbrt(xyz_scale) * mask + (7.787 * xyz_scale + 16. / 116.) * (1 - mask)
L = 116. * xyz_int[:, 1, :, :] - 16.
a = 500. * (xyz_int[:, 0, :, :] - xyz_int[:, 1, :, :])
b = 200. * (xyz_int[:, 1, :, :] - xyz_int[:, 2, :, :])
out = np.concatenate((L[:, None, :, :], a[:, None, :, :], b[:, None, :, :]), axis=1)
L = 116. * xyz_int[1, :, :] - 16.
a = 500. * (xyz_int[0, :, :] - xyz_int[1, :, :])
b = 200. * (xyz_int[1, :, :] - xyz_int[2, :, :])
out = np.concatenate((L[None, :, :], a[None, :, :], b[None, :, :]), axis=0)
return out
def rgb2lab(self, rgb: np.ndarray) -> np.ndarray:
......@@ -485,11 +476,24 @@ class ConvertColorSpace:
img(np.ndarray): Converted LAB image.
"""
lab = self.xyz2lab(self.rgb2xyz(rgb))
l_rs = (lab[:, [0], :, :] - 50) / 100
ab_rs = lab[:, 1:, :, :] / 110
out = np.concatenate((l_rs, ab_rs), axis=1)
l_rs = (lab[[0], :, :] - 50) / 100
ab_rs = lab[1:, :, :] / 110
out = np.concatenate((l_rs, ab_rs), axis=0)
return out
def __call__(self, img: np.ndarray) -> np.ndarray:
img = img / 255
img = np.array(img).transpose(2, 0, 1)
return self.rgb2lab(img)
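# A minimal usage sketch (the input array below is an illustrative assumption):
# RGB2LAB takes an HWC RGB image with values in [0, 255] and returns a CHW array
# whose L channel is rescaled to roughly [-0.5, 0.5] and whose AB channels are
# divided by 110, which is the layout ColorizePreprocess expects.
#
#   rgb = np.random.randint(0, 256, size=(256, 256, 3)).astype(np.float32)
#   lab = RGB2LAB()(rgb)   # shape: (3, 256, 256)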
class LAB2RGB:
"""
Convert color space from LAB to RGB.
"""
def __init__(self, mode: str = 'RGB2LAB'):
self.mode = mode
def xyz2rgb(self, xyz: np.ndarray) -> np.ndarray:
"""
Convert color space from XYZ to RGB.
......@@ -551,171 +555,7 @@ class ConvertColorSpace:
return out
def __call__(self, img: np.ndarray) -> np.ndarray:
if self.mode == 'RGB2LAB':
img = np.expand_dims(img / 255, 0)
img = np.array(img).transpose(0, 3, 1, 2)
return self.rgb2lab(img)
elif self.mode == 'LAB2RGB':
return self.lab2rgb(img)
else:
raise ValueError('The mode should be RGB2LAB or LAB2RGB')
class ColorizeHint:
"""Get hint and mask images for colorization.
This class is prepared for user-guided colorization tasks. Taking the original RGB images as input, it produces the local hints and the corresponding mask that guide the colorization process.
Args:
percent(float): Probability for ignoring hint in an iteration.
num_points(int): Number of selected hints in an iteration.
samp(str): Sample method, default is normal.
use_avg(bool): Whether to use mean in selected hint area.
Return:
hint(np.ndarray): hint images
mask(np.ndarray): mask images
"""
def __init__(self, percent: float, num_points: int = None, samp: str = 'normal', use_avg: bool = True):
self.percent = percent
self.num_points = num_points
self.samp = samp
self.use_avg = use_avg
def __call__(self, data: np.ndarray, hint: np.ndarray, mask: np.ndarray):
sample_Ps = [1, 2, 3, 4, 5, 6, 7, 8, 9]
self.data = data
self.hint = hint
self.mask = mask
N, C, H, W = data.shape
for nn in range(N):
pp = 0
cont_cond = True
while cont_cond:
if self.num_points is None: # draw from geometric
# embed()
cont_cond = np.random.rand() > (1 - self.percent)
else: # add certain number of points
cont_cond = pp < self.num_points
if not cont_cond: # skip out of loop if condition not met
continue
P = np.random.choice(sample_Ps) # patch size
# sample location
if self.samp == 'normal': # geometric distribution
h = int(np.clip(np.random.normal((H - P + 1) / 2., (H - P + 1) / 4.), 0, H - P))
w = int(np.clip(np.random.normal((W - P + 1) / 2., (W - P + 1) / 4.), 0, W - P))
else: # uniform distribution
h = np.random.randint(H - P + 1)
w = np.random.randint(W - P + 1)
# add color point
if self.use_avg:
# embed()
hint[nn, :, h:h + P, w:w + P] = np.mean(np.mean(data[nn, :, h:h + P, w:w + P],
axis=2,
keepdims=True),
axis=1,
keepdims=True).reshape(1, C, 1, 1)
else:
hint[nn, :, h:h + P, w:w + P] = data[nn, :, h:h + P, w:w + P]
mask[nn, :, h:h + P, w:w + P] = 1
# increment counter
pp += 1
mask -= 0.5
return hint, mask
class SqueezeAxis:
"""
Squeeze the specific axis when it equal to 1.
Args:
axis(int): Which axis should be squeezed.
"""
def __init__(self, axis: int):
self.axis = axis
def __call__(self, data: dict):
if isinstance(data, dict):
for key in data.keys():
data[key] = np.squeeze(data[key], 0).astype(np.float32)
return data
else:
raise TypeError("Type of data is invalid. Must be Dict or List or tuple, now is {}".format(type(data)))
class ColorizePreprocess:
"""Prepare dataset for image Colorization.
Args:
ab_thresh(float): Thresh value for setting mask value.
p(float): Probability for ignoring hint in an iteration.
num_points(int): Number of selected hints in an iteration.
samp(str): Sample method, default is normal.
use_avg(bool): Whether to use mean in selected hint area.
is_train(bool): Training process or not.
Return:
data(dict):The preprocessed data for colorization.
"""
def __init__(self,
ab_thresh: float = 0.,
p: float = 0.,
num_points: int = None,
samp: str = 'normal',
use_avg: bool = True,
is_train: bool = True):
self.ab_thresh = ab_thresh
self.p = p
self.num_points = num_points
self.samp = samp
self.use_avg = use_avg
self.is_train = is_train
self.gethint = ColorizeHint(percent=self.p, num_points=self.num_points, samp=self.samp, use_avg=self.use_avg)
self.squeeze = SqueezeAxis(0)
def __call__(self, data_lab: np.ndarray):
"""
This method separates the L and AB channels and produces the hint, mask and real_B_enc inputs for the colorization task.
Args:
img(np.ndarray): LAB image.
Returns:
data(dict):The preprocessed data for colorization.
"""
data = {}
A = 2 * 110 / 10 + 1
data['A'] = data_lab[:, [
0,
], :, :]
data['B'] = data_lab[:, 1:, :, :]
if self.ab_thresh > 0: # mask out grayscale images
thresh = 1. * self.ab_thresh / 110
mask = np.sum(np.abs(np.max(np.max(data['B'], axis=3), axis=2) - np.min(np.min(data['B'], axis=3), axis=2)),
axis=1)
mask = (mask >= thresh)
data['A'] = data['A'][mask, :, :, :]
data['B'] = data['B'][mask, :, :, :]
if np.sum(mask) == 0:
return None
data_ab_rs = np.round((data['B'][:, :, ::4, ::4] * 110. + 110.) / 10.) # normalized bin number
data['real_B_enc'] = data_ab_rs[:, [0], :, :] * A + data_ab_rs[:, [1], :, :]
data['hint_B'] = np.zeros(shape=data['B'].shape)
data['mask_B'] = np.zeros(shape=data['A'].shape)
data['hint_B'], data['mask_B'] = self.gethint(data['B'], data['hint_B'], data['mask_B'])
if self.is_train:
data = self.squeeze(data)
data['real_B_enc'] = data['real_B_enc'].astype(np.int64)
else:
data['A'] = data['A'].astype(np.float32)
data['B'] = data['B'].astype(np.float32)
data['real_B_enc'] = data['real_B_enc'].astype(np.int64)
data['hint_B'] = data['hint_B'].astype(np.float32)
data['mask_B'] = data['mask_B'].astype(np.float32)
return data
return self.lab2rgb(img)
class ColorPostprocess:
......