diff --git a/demo/colorization/train.py b/demo/colorization/train.py index a12223adf07ed7c476d93eddbd3bb6ba3b962f69..d1b18208f0e4d89d74ba7818ae1ea30aeaf7019c 100644 --- a/demo/colorization/train.py +++ b/demo/colorization/train.py @@ -4,23 +4,18 @@ import paddle.nn as nn from paddlehub.finetune.trainer import Trainer from paddlehub.datasets.colorizedataset import Colorizedataset -from paddlehub.process.transforms import Compose, Resize, RandomPaddingCrop, ConvertColorSpace, ColorizePreprocess +import paddlehub.process.transforms as T if __name__ == '__main__': - is_train = True + paddle.disable_static() model = hub.Module(name='user_guided_colorization') - transform = Compose([ - Resize((256, 256), interp='NEAREST'), - RandomPaddingCrop(crop_size=176), - ConvertColorSpace(mode='RGB2LAB'), - ColorizePreprocess(ab_thresh=0, is_train=is_train), - ], - stay_rgb=True, - is_permute=False) + transform = T.Compose([T.Resize((256, 256), interp='NEAREST'), + T.RandomPaddingCrop(crop_size=176), + T.RGB2LAB()], + stay_rgb=True, + is_permute=False) color_set = Colorizedataset(transform=transform, mode='train') - if is_train: - model.train() - optimizer = paddle.optimizer.Adam(learning_rate=0.0001, parameters=model.parameters()) - trainer = Trainer(model, optimizer, checkpoint_dir='test_ckpt_img_cls') - trainer.train(color_set, epochs=101, batch_size=5, eval_dataset=color_set, log_interval=10, save_interval=10) + optimizer = paddle.optimizer.Adam(learning_rate=0.0001, parameters=model.parameters()) + trainer = Trainer(model, optimizer, checkpoint_dir='test_ckpt_img_cls') + trainer.train(color_set, epochs=101, batch_size=2, eval_dataset=color_set, log_interval=10, save_interval=10) diff --git a/demo/detection/yolov3_darknet53_pascalvoc/predict.py b/demo/detection/yolov3_darknet53_pascalvoc/predict.py index 70a985f809582bc0f5d96e6442a7dace6cb3bc54..7e9641a5cd61cae9fb9667da74ce80e6b21f1830 100644 --- a/demo/detection/yolov3_darknet53_pascalvoc/predict.py +++ b/demo/detection/yolov3_darknet53_pascalvoc/predict.py @@ -6,4 +6,4 @@ if __name__ == '__main__': paddle.disable_static() model = hub.Module(name='yolov3_darknet53_pascalvoc', is_train=False) model.eval() - model.predict(imgpath="4026.jpeg", filelist="/PATH/TO/JSON/FILE") + model.predict(imgpath="4026.jpeg", filelist="/PATH/TO/JSON") diff --git a/hub_module/modules/image/colorization/user_guided_colorization/data_feed.py b/hub_module/modules/image/colorization/user_guided_colorization/data_feed.py new file mode 100644 index 0000000000000000000000000000000000000000..984cb45701e6939de650b42d0b9f0046f83860bd --- /dev/null +++ b/hub_module/modules/image/colorization/user_guided_colorization/data_feed.py @@ -0,0 +1,133 @@ +import paddle +import numpy as np + + +class ColorizeHint: + """Get hint and mask images for colorization. + + This method is prepared for user guided colorization tasks. Take the original RGB images as imput, + we will obtain the local hints and correspoding mask to guid colorization process. + + Args: + percent(float): Probability for ignoring hint in an iteration. + num_points(int): Number of selected hints in an iteration. + samp(str): Sample method, default is normal. + use_avg(bool): Whether to use mean in selected hint area. + + Return: + hint(np.ndarray): hint images + mask(np.ndarray): mask images + """ + def __init__(self, percent: float, num_points: int = None, samp: str = 'normal', use_avg: bool = True): + self.percent = percent + self.num_points = num_points + self.samp = samp + self.use_avg = use_avg + + def __call__(self, data: np.ndarray, hint: np.ndarray, mask: np.ndarray): + sample_Ps = [1, 2, 3, 4, 5, 6, 7, 8, 9] + self.data = data + self.hint = hint + self.mask = mask + N, C, H, W = data.shape + for nn in range(N): + pp = 0 + cont_cond = True + while cont_cond: + if self.num_points is None: # draw from geometric + # embed() + cont_cond = np.random.rand() > (1 - self.percent) + else: # add certain number of points + cont_cond = pp < self.num_points + if not cont_cond: # skip out of loop if condition not met + continue + P = np.random.choice(sample_Ps) # patch size + # sample location + if self.samp == 'normal': # geometric distribution + h = int(np.clip(np.random.normal((H - P + 1) / 2., (H - P + 1) / 4.), 0, H - P)) + w = int(np.clip(np.random.normal((W - P + 1) / 2., (W - P + 1) / 4.), 0, W - P)) + else: # uniform distribution + h = np.random.randint(H - P + 1) + w = np.random.randint(W - P + 1) + # add color point + if self.use_avg: + # embed() + hint[nn, :, h:h + P, w:w + P] = np.mean(np.mean(data[nn, :, h:h + P, w:w + P], + axis=2, + keepdims=True), + axis=1, + keepdims=True).reshape(1, C, 1, 1) + else: + hint[nn, :, h:h + P, w:w + P] = data[nn, :, h:h + P, w:w + P] + mask[nn, :, h:h + P, w:w + P] = 1 + # increment counter + pp += 1 + + mask -= 0.5 + return hint, mask + + +class ColorizePreprocess: + """Prepare dataset for image Colorization. + + Args: + ab_thresh(float): Thresh value for setting mask value. + p(float): Probability for ignoring hint in an iteration. + num_points(int): Number of selected hints in an iteration. + samp(str): Sample method, default is normal. + use_avg(bool): Whether to use mean in selected hint area. + is_train(bool): Training process or not. + + Return: + data(dict):The preprocessed data for colorization. + + """ + def __init__(self, + ab_thresh: float = 0., + p: float = 0., + num_points: int = None, + samp: str = 'normal', + use_avg: bool = True): + self.ab_thresh = ab_thresh + self.p = p + self.num_points = num_points + self.samp = samp + self.use_avg = use_avg + self.gethint = ColorizeHint(percent=self.p, num_points=self.num_points, samp=self.samp, use_avg=self.use_avg) + + def __call__(self, data_lab): + """ + This method seperates the L channel and AB channel, obtain hint, mask and real_B_enc as the input for colorization task. + + Args: + img(np.ndarray|paddle.Tensor): LAB image. + + Returns: + data(dict):The preprocessed data for colorization. + """ + if type(data_lab) is not np.ndarray: + data_lab = data_lab.numpy() + data = {} + A = 2 * 110 / 10 + 1 + data['A'] = data_lab[:, [0], :, :] + data['B'] = data_lab[:, 1:, :, :] + if self.ab_thresh > 0: # mask out grayscale images + thresh = 1. * self.ab_thresh / 110 + mask = np.sum(np.abs(np.max(np.max(data['B'], axis=3), axis=2) - np.min(np.min(data['B'], axis=3), axis=2)), + axis=1) + mask = (mask >= thresh) + data['A'] = data['A'][mask, :, :, :] + data['B'] = data['B'][mask, :, :, :] + if np.sum(mask) == 0: + return None + data_ab_rs = np.round((data['B'][:, :, ::4, ::4] * 110. + 110.) / 10.) # normalized bin number + data['real_B_enc'] = data_ab_rs[:, [0], :, :] * A + data_ab_rs[:, [1], :, :] + data['hint_B'] = np.zeros(shape=data['B'].shape) + data['mask_B'] = np.zeros(shape=data['A'].shape) + data['hint_B'], data['mask_B'] = self.gethint(data['B'], data['hint_B'], data['mask_B']) + data['A'] = paddle.to_tensor(data['A'].astype(np.float32)) + data['B'] = paddle.to_tensor(data['B'].astype(np.float32)) + data['real_B_enc'] = paddle.to_tensor(data['real_B_enc'].astype(np.int64)) + data['hint_B'] = paddle.to_tensor(data['hint_B'].astype(np.float32)) + data['mask_B'] = paddle.to_tensor(data['mask_B'].astype(np.float32)) + return data diff --git a/hub_module/modules/image/colorization/user_guided_colorization/module.py b/hub_module/modules/image/colorization/user_guided_colorization/module.py index ebba99331caba72db1dc5be66161ef90b0683587..ed838431b1e47b373196ea6c09d3e429a9479083 100644 --- a/hub_module/modules/image/colorization/user_guided_colorization/module.py +++ b/hub_module/modules/image/colorization/user_guided_colorization/module.py @@ -15,12 +15,12 @@ import os import paddle -import numpy import paddle.nn as nn from paddle.nn import Conv2d, ConvTranspose2d from paddlehub.module.module import moduleinfo -from paddlehub.process.transforms import Compose, Resize, RandomPaddingCrop, ConvertColorSpace, ColorizePreprocess +import paddlehub.process.transforms as T from paddlehub.module.cv_module import ImageColorizeModule +from user_guided_colorization.data_feed import ColorizePreprocess @moduleinfo( @@ -32,7 +32,8 @@ from paddlehub.module.cv_module import ImageColorizeModule version="1.0.0", meta=ImageColorizeModule) class UserGuidedColorization(nn.Layer): - """Userguidedcolorization, see https://github.com/haoyuying/colorization-pytorch + """ + Userguidedcolorization, see https://github.com/haoyuying/colorization-pytorch Args: use_tanh (bool): Whether to use tanh as final activation function. @@ -139,12 +140,7 @@ class UserGuidedColorization(nn.Layer): 1, 1, ), ) - model9 = ( - nn.ReLU(), - Conv2d(128, 128, 3, 1, 1), - nn.ReLU(), - nn.BatchNorm(128), - ) + model9 = (nn.ReLU(), Conv2d(128, 128, 3, 1, 1), nn.ReLU(), nn.BatchNorm(128)) # Conv10 model10up = (ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1), ) @@ -182,30 +178,32 @@ class UserGuidedColorization(nn.Layer): print("load custom checkpoint success") else: checkpoint = os.path.join(self.directory, 'user_guided.pdparams') + if not os.path.exists(checkpoint): + os.system('wget https://paddlehub.bj.bcebos.com/dygraph/image_colorization/user_guided.pdparams -O ' + + checkpoint) model_dict = paddle.load(checkpoint)[0] self.set_dict(model_dict) print("load pretrained checkpoint success") def transforms(self, images: str, is_train: bool = True) -> callable: if is_train: - transform = Compose([ - Resize((256, 256), interp='NEAREST'), - RandomPaddingCrop(crop_size=176), - ConvertColorSpace(mode='RGB2LAB'), - ColorizePreprocess(ab_thresh=0, is_train=is_train) - ], - stay_rgb=True, - is_permute=False) + transform = T.Compose( + [T.Resize((256, 256), interp='NEAREST'), + T.RandomPaddingCrop(crop_size=176), + T.RGB2LAB()], + stay_rgb=True, + is_permute=False) else: - transform = Compose([ - Resize((256, 256), interp='NEAREST'), - ConvertColorSpace(mode='RGB2LAB'), - ColorizePreprocess(ab_thresh=0, is_train=is_train) - ], - stay_rgb=True, - is_permute=False) + transform = T.Compose([T.Resize( + (256, 256), interp='NEAREST'), T.RGB2LAB()], + stay_rgb=True, + is_permute=False) return transform(images) + def preprocess(self, inputs: paddle.Tensor, ab_thresh: float = 0., prob: float = 0.): + self.preprocess = ColorizePreprocess(ab_thresh=ab_thresh, p=prob) + return self.preprocess(inputs) + def forward(self, input_A: paddle.Tensor, input_B: paddle.Tensor, diff --git a/paddlehub/datasets/colorizedataset.py b/paddlehub/datasets/colorizedataset.py index fbcdb4a7d60af39a6155147f16ed94941668d0d2..0ee45c303d792c1dc5b8171997c09651af5102e5 100644 --- a/paddlehub/datasets/colorizedataset.py +++ b/paddlehub/datasets/colorizedataset.py @@ -15,7 +15,7 @@ import os -import numpy +import numpy as np import paddle from paddlehub.process.functional import get_img_file @@ -26,9 +26,11 @@ from typing import Callable class Colorizedataset(paddle.io.Dataset): """ Dataset for colorization. + Args: transform(callmethod) : The method of preprocess images. mode(str): The mode for preparing dataset. + Returns: DataSet: An iterable object for data iterating """ @@ -44,10 +46,10 @@ class Colorizedataset(paddle.io.Dataset): self.file = os.path.join(DATA_HOME, 'canvas', self.file) self.data = get_img_file(self.file) - def __getitem__(self, idx: int) -> numpy.ndarray: + def __getitem__(self, idx: int) -> np.ndarray: img_path = self.data[idx] im = self.transform(img_path) - return im['A'], im['hint_B'], im['mask_B'], im['B'], im['real_B_enc'] + return im def __len__(self): return len(self.data) diff --git a/paddlehub/module/cv_module.py b/paddlehub/module/cv_module.py index e8f8bf01438a0d35fb1b1204832136dc595fe8da..388b51b9a1197d650ec6630fea91c5735ae8b8aa 100644 --- a/paddlehub/module/cv_module.py +++ b/paddlehub/module/cv_module.py @@ -111,7 +111,7 @@ class ImageColorizeModule(RunModule, ImageServing): batch_idx(int): The index of batch. Returns: - results(dict) : The model outputs, such as loss and metrics. + results(dict): The model outputs, such as loss and metrics. ''' return self.validation_step(batch, batch_idx) @@ -126,29 +126,30 @@ class ImageColorizeModule(RunModule, ImageServing): Returns: results(dict) : The model outputs, such as metrics. ''' - out_class, out_reg = self(batch[0], batch[1], batch[2]) + img = self.preprocess(batch[0]) + out_class, out_reg = self(img['A'], img['hint_B'], img['mask_B']) + # loss criterionCE = nn.loss.CrossEntropyLoss() - loss_ce = criterionCE(out_class, batch[4][:, 0, :, :]) - loss_G_L1_reg = paddle.sum(paddle.abs(batch[3] - out_reg), axis=1, keepdim=True) + loss_ce = criterionCE(out_class, img['real_B_enc'][:, 0, :, :]) + loss_G_L1_reg = paddle.sum(paddle.abs(img['B'] - out_reg), axis=1, keepdim=True) loss_G_L1_reg = paddle.mean(loss_G_L1_reg) loss = loss_ce + loss_G_L1_reg + #calculate psnr visual_ret = OrderedDict() psnrs = [] - lab2rgb = T.ConvertColorSpace(mode='LAB2RGB') + lab2rgb = T.LAB2RGB() process = T.ColorPostprocess() - - for i in range(batch[0].numpy().shape[0]): - real = lab2rgb(np.concatenate((batch[0].numpy(), batch[3].numpy()), axis=1))[i] + for i in range(img['A'].numpy().shape[0]): + real = lab2rgb(np.concatenate((img['A'].numpy(), img['B'].numpy()), axis=1))[i] visual_ret['real'] = process(real) - fake = lab2rgb(np.concatenate((batch[0].numpy(), out_reg.numpy()), axis=1))[i] + fake = lab2rgb(np.concatenate((img['A'].numpy(), out_reg.numpy()), axis=1))[i] visual_ret['fake_reg'] = process(fake) mse = np.mean((visual_ret['real'] * 1.0 - visual_ret['fake_reg'] * 1.0)**2) psnr_value = 20 * np.log10(255. / np.sqrt(mse)) psnrs.append(psnr_value) psnr = paddle.to_variable(np.array(psnrs)) - return {'loss': loss, 'metrics': {'psnr': psnr}} def predict(self, images: str, visualization: bool = True, save_path: str = 'result'): @@ -163,23 +164,26 @@ class ImageColorizeModule(RunModule, ImageServing): Returns: results(list[dict]) : The prediction result of each input image ''' - lab2rgb = T.ConvertColorSpace(mode='LAB2RGB') + + lab2rgb = T.LAB2RGB() process = T.ColorPostprocess() resize = T.Resize((256, 256)) - visual_ret = OrderedDict() + im = self.transforms(images, is_train=False) - out_class, out_reg = self(paddle.to_tensor(im['A']), paddle.to_variable(im['hint_B']), - paddle.to_variable(im['mask_B'])) - result = [] + im = im[np.newaxis, :, :, :] + im = self.preprocess(im) + out_class, out_reg = self(im['A'], im['hint_B'], im['mask_B']) + result = [] + visual_ret = OrderedDict() for i in range(im['A'].shape[0]): - gray = lab2rgb(np.concatenate((im['A'], np.zeros(im['B'].shape)), axis=1))[i] + gray = lab2rgb(np.concatenate((im['A'].numpy(), np.zeros(im['B'].shape)), axis=1))[i] visual_ret['gray'] = resize(process(gray)) - hint = lab2rgb(np.concatenate((im['A'], im['hint_B']), axis=1))[i] + hint = lab2rgb(np.concatenate((im['A'].numpy(), im['hint_B'].numpy()), axis=1))[i] visual_ret['hint'] = resize(process(hint)) - real = lab2rgb(np.concatenate((im['A'], im['B']), axis=1))[i] + real = lab2rgb(np.concatenate((im['A'].numpy(), im['B'].numpy()), axis=1))[i] visual_ret['real'] = resize(process(real)) - fake = lab2rgb(np.concatenate((im['A'], out_reg.numpy()), axis=1))[i] + fake = lab2rgb(np.concatenate((im['A'].numpy(), out_reg.numpy()), axis=1))[i] visual_ret['fake_reg'] = resize(process(fake)) if visualization: @@ -232,16 +236,17 @@ class Yolov3Module(RunModule, ImageServing): for i, out in enumerate(outputs): anchor_mask = self.anchor_masks[i] - loss = F.yolov3_loss(x=out, - gt_box=gtbox, - gt_label=gtlabel, - gt_score=gtscore, - anchors=self.anchors, - anchor_mask=anchor_mask, - class_num=self.class_num, - ignore_thresh=self.ignore_thresh, - downsample_ratio=32, - use_label_smooth=False) + loss = F.yolov3_loss( + x=out, + gt_box=gtbox, + gt_label=gtlabel, + gt_score=gtscore, + anchors=self.anchors, + anchor_mask=anchor_mask, + class_num=self.class_num, + ignore_thresh=self.ignore_thresh, + downsample_ratio=32, + use_label_smooth=False) losses.append(paddle.reduce_mean(loss)) self.downsample //= 2 @@ -280,13 +285,14 @@ class Yolov3Module(RunModule, ImageServing): mask_anchors.append((self.anchors[2 * m])) mask_anchors.append(self.anchors[2 * m + 1]) - box, score = F.yolo_box(x=out, - img_size=im_shape, - anchors=mask_anchors, - class_num=self.class_num, - conf_thresh=self.valid_thresh, - downsample_ratio=self.downsample, - name="yolo_box" + str(i)) + box, score = F.yolo_box( + x=out, + img_size=im_shape, + anchors=mask_anchors, + class_num=self.class_num, + conf_thresh=self.valid_thresh, + downsample_ratio=self.downsample, + name="yolo_box" + str(i)) boxes.append(box) scores.append(paddle.transpose(score, perm=[0, 2, 1])) @@ -295,13 +301,14 @@ class Yolov3Module(RunModule, ImageServing): yolo_boxes = paddle.concat(boxes, axis=1) yolo_scores = paddle.concat(scores, axis=2) - pred = F.multiclass_nms(bboxes=yolo_boxes, - scores=yolo_scores, - score_threshold=self.valid_thresh, - nms_top_k=self.nms_topk, - keep_top_k=self.nms_posk, - nms_threshold=self.nms_thresh, - background_label=-1) + pred = F.multiclass_nms( + bboxes=yolo_boxes, + scores=yolo_scores, + score_threshold=self.valid_thresh, + nms_top_k=self.nms_topk, + keep_top_k=self.nms_posk, + nms_threshold=self.nms_thresh, + background_label=-1) bboxes = pred.numpy() labels = bboxes[:, 0].astype('int32') @@ -309,7 +316,9 @@ class Yolov3Module(RunModule, ImageServing): boxes = bboxes[:, 2:].astype('float32') if visualization: - Func.draw_boxes_on_image(imgpath, boxes, scores, labels, label_names, 0.5) + if not os.path.exists(save_path): + os.mkdir(save_path) + Func.draw_boxes_on_image(imgpath, boxes, scores, labels, label_names, 0.5, save_path) return boxes, scores, labels diff --git a/paddlehub/process/functional.py b/paddlehub/process/functional.py index 3a9f41ae611fdb7b75df2edb4d46290fe368d951..ff15f6ac727731e25872077184cb9df171280f1e 100644 --- a/paddlehub/process/functional.py +++ b/paddlehub/process/functional.py @@ -185,7 +185,8 @@ def draw_boxes_on_image(image_path: str, scores: np.ndarray, labels: np.ndarray, label_names: list, - score_thresh: float = 0.5): + score_thresh: float = 0.5, + save_path: str = 'result'): """Draw boxes on images.""" image = np.array(Image.open(image_path)) plt.figure() @@ -206,25 +207,25 @@ def draw_boxes_on_image(image_path: str, x1, y1, x2, y2 = box[0], box[1], box[2], box[3] rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, linewidth=2.0, edgecolor=colors[label]) ax.add_patch(rect) - ax.text(x1, - y1, - '{} {:.4f}'.format(label_names[label], score), - verticalalignment='bottom', - horizontalalignment='left', - bbox={ - 'facecolor': colors[label], - 'alpha': 0.5, - 'pad': 0 - }, - fontsize=8, - color='white') + ax.text( + x1, + y1, + '{} {:.4f}'.format(label_names[label], score), + verticalalignment='bottom', + horizontalalignment='left', + bbox={ + 'facecolor': colors[label], + 'alpha': 0.5, + 'pad': 0 + }, + fontsize=8, + color='white') print("\t {:15s} at {:25} score: {:.5f}".format(label_names[int(label)], str(list(map(int, list(box)))), score)) image_name = image_name.replace('jpg', 'png') plt.axis('off') plt.gca().xaxis.set_major_locator(plt.NullLocator()) plt.gca().yaxis.set_major_locator(plt.NullLocator()) - plt.savefig("./output/{}".format(image_name), bbox_inches='tight', pad_inches=0.0) - print("Detect result save at ./output/{}\n".format(image_name)) + plt.savefig("{}/{}".format(save_path, image_name), bbox_inches='tight', pad_inches=0.0) plt.cla() plt.close('all') diff --git a/paddlehub/process/transforms.py b/paddlehub/process/transforms.py index 6a49cce46ff1bc7835e5fdcdd441c12099d7f8dc..36217669ab623f3d089903ebd1036e3133558613 100644 --- a/paddlehub/process/transforms.py +++ b/paddlehub/process/transforms.py @@ -382,7 +382,7 @@ class RandomDistort: saturation_upper = 1 + self.saturation_range hue_lower = -self.hue_range hue_upper = self.hue_range - ops = [brightness, contrast, saturation, hue] + ops = ['brightness', 'contrast', 'saturation', 'hue'] random.shuffle(ops) params_dict = { 'brightness': { @@ -421,19 +421,10 @@ class RandomDistort: return im -class ConvertColorSpace: +class RGB2LAB: """ - Convert color space from RGB to LAB or from LAB to RGB. - - Args: - mode(str): Color space convert mode, it can be 'RGB2LAB' or 'LAB2RGB'. - - Return: - img(np.ndarray): converted image. + Convert color space from RGB to LAB. """ - def __init__(self, mode: str = 'RGB2LAB'): - self.mode = mode - def rgb2xyz(self, rgb: np.ndarray) -> np.ndarray: """ Convert color space from RGB to XYZ. @@ -448,10 +439,10 @@ class ConvertColorSpace: np.seterr(invalid='ignore') rgb = (((rgb + .055) / 1.055)**2.4) * mask + rgb / 12.92 * (1 - mask) rgb = np.nan_to_num(rgb) - x = .412453 * rgb[:, 0, :, :] + .357580 * rgb[:, 1, :, :] + .180423 * rgb[:, 2, :, :] - y = .212671 * rgb[:, 0, :, :] + .715160 * rgb[:, 1, :, :] + .072169 * rgb[:, 2, :, :] - z = .019334 * rgb[:, 0, :, :] + .119193 * rgb[:, 1, :, :] + .950227 * rgb[:, 2, :, :] - out = np.concatenate((x[:, None, :, :], y[:, None, :, :], z[:, None, :, :]), axis=1) + x = .412453 * rgb[0, :, :] + .357580 * rgb[1, :, :] + .180423 * rgb[2, :, :] + y = .212671 * rgb[0, :, :] + .715160 * rgb[1, :, :] + .072169 * rgb[2, :, :] + z = .019334 * rgb[0, :, :] + .119193 * rgb[1, :, :] + .950227 * rgb[2, :, :] + out = np.concatenate((x[None, :, :], y[None, :, :], z[None, :, :]), axis=0) return out def xyz2lab(self, xyz: np.ndarray) -> np.ndarray: @@ -464,14 +455,14 @@ class ConvertColorSpace: Return: img(np.ndarray): Converted LAB image. """ - sc = np.array((0.95047, 1., 1.08883))[None, :, None, None] + sc = np.array((0.95047, 1., 1.08883))[:, None, None] xyz_scale = xyz / sc mask = (xyz_scale > .008856).astype(np.float32) xyz_int = np.cbrt(xyz_scale) * mask + (7.787 * xyz_scale + 16. / 116.) * (1 - mask) - L = 116. * xyz_int[:, 1, :, :] - 16. - a = 500. * (xyz_int[:, 0, :, :] - xyz_int[:, 1, :, :]) - b = 200. * (xyz_int[:, 1, :, :] - xyz_int[:, 2, :, :]) - out = np.concatenate((L[:, None, :, :], a[:, None, :, :], b[:, None, :, :]), axis=1) + L = 116. * xyz_int[1, :, :] - 16. + a = 500. * (xyz_int[0, :, :] - xyz_int[1, :, :]) + b = 200. * (xyz_int[1, :, :] - xyz_int[2, :, :]) + out = np.concatenate((L[None, :, :], a[None, :, :], b[None, :, :]), axis=0) return out def rgb2lab(self, rgb: np.ndarray) -> np.ndarray: @@ -485,11 +476,24 @@ class ConvertColorSpace: img(np.ndarray): Converted LAB image. """ lab = self.xyz2lab(self.rgb2xyz(rgb)) - l_rs = (lab[:, [0], :, :] - 50) / 100 - ab_rs = lab[:, 1:, :, :] / 110 - out = np.concatenate((l_rs, ab_rs), axis=1) + l_rs = (lab[[0], :, :] - 50) / 100 + ab_rs = lab[1:, :, :] / 110 + out = np.concatenate((l_rs, ab_rs), axis=0) return out + def __call__(self, img: np.ndarray) -> np.ndarray: + img = img / 255 + img = np.array(img).transpose(2, 0, 1) + return self.rgb2lab(img) + + +class LAB2RGB: + """ + Convert color space from LAB to RGB. + """ + def __init__(self, mode: str = 'RGB2LAB'): + self.mode = mode + def xyz2rgb(self, xyz: np.ndarray) -> np.ndarray: """ Convert color space from XYZ to RGB. @@ -551,171 +555,7 @@ class ConvertColorSpace: return out def __call__(self, img: np.ndarray) -> np.ndarray: - if self.mode == 'RGB2LAB': - img = np.expand_dims(img / 255, 0) - img = np.array(img).transpose(0, 3, 1, 2) - return self.rgb2lab(img) - elif self.mode == 'LAB2RGB': - return self.lab2rgb(img) - else: - raise ValueError('The mode should be RGB2LAB or LAB2RGB') - - -class ColorizeHint: - """Get hint and mask images for colorization. - - This method is prepared for user guided colorization tasks. Take the original RGB images as imput, we will obtain the local hints and correspoding mask to guid colorization process. - - Args: - percent(float): Probability for ignoring hint in an iteration. - num_points(int): Number of selected hints in an iteration. - samp(str): Sample method, default is normal. - use_avg(bool): Whether to use mean in selected hint area. - - Return: - hint(np.ndarray): hint images - mask(np.ndarray): mask images - """ - def __init__(self, percent: float, num_points: int = None, samp: str = 'normal', use_avg: bool = True): - self.percent = percent - self.num_points = num_points - self.samp = samp - self.use_avg = use_avg - - def __call__(self, data: np.ndarray, hint: np.ndarray, mask: np.ndarray): - sample_Ps = [1, 2, 3, 4, 5, 6, 7, 8, 9] - self.data = data - self.hint = hint - self.mask = mask - N, C, H, W = data.shape - for nn in range(N): - pp = 0 - cont_cond = True - while cont_cond: - if self.num_points is None: # draw from geometric - # embed() - cont_cond = np.random.rand() > (1 - self.percent) - else: # add certain number of points - cont_cond = pp < self.num_points - if not cont_cond: # skip out of loop if condition not met - continue - P = np.random.choice(sample_Ps) # patch size - # sample location - if self.samp == 'normal': # geometric distribution - h = int(np.clip(np.random.normal((H - P + 1) / 2., (H - P + 1) / 4.), 0, H - P)) - w = int(np.clip(np.random.normal((W - P + 1) / 2., (W - P + 1) / 4.), 0, W - P)) - else: # uniform distribution - h = np.random.randint(H - P + 1) - w = np.random.randint(W - P + 1) - # add color point - if self.use_avg: - # embed() - hint[nn, :, h:h + P, w:w + P] = np.mean(np.mean(data[nn, :, h:h + P, w:w + P], - axis=2, - keepdims=True), - axis=1, - keepdims=True).reshape(1, C, 1, 1) - else: - hint[nn, :, h:h + P, w:w + P] = data[nn, :, h:h + P, w:w + P] - mask[nn, :, h:h + P, w:w + P] = 1 - # increment counter - pp += 1 - - mask -= 0.5 - return hint, mask - - -class SqueezeAxis: - """ - Squeeze the specific axis when it equal to 1. - - Args: - axis(int): Which axis should be squeezed. - - """ - def __init__(self, axis: int): - self.axis = axis - - def __call__(self, data: dict): - if isinstance(data, dict): - for key in data.keys(): - data[key] = np.squeeze(data[key], 0).astype(np.float32) - return data - else: - raise TypeError("Type of data is invalid. Must be Dict or List or tuple, now is {}".format(type(data))) - - -class ColorizePreprocess: - """Prepare dataset for image Colorization. - - Args: - ab_thresh(float): Thresh value for setting mask value. - p(float): Probability for ignoring hint in an iteration. - num_points(int): Number of selected hints in an iteration. - samp(str): Sample method, default is normal. - use_avg(bool): Whether to use mean in selected hint area. - is_train(bool): Training process or not. - - Return: - data(dict):The preprocessed data for colorization. - - """ - def __init__(self, - ab_thresh: float = 0., - p: float = 0., - num_points: int = None, - samp: str = 'normal', - use_avg: bool = True, - is_train: bool = True): - self.ab_thresh = ab_thresh - self.p = p - self.num_points = num_points - self.samp = samp - self.use_avg = use_avg - self.is_train = is_train - self.gethint = ColorizeHint(percent=self.p, num_points=self.num_points, samp=self.samp, use_avg=self.use_avg) - self.squeeze = SqueezeAxis(0) - - def __call__(self, data_lab: np.ndarray): - """ - This method seperates the L channel and AB channel, obtain hint, mask and real_B_enc as the input for colorization task. - - Args: - img(np.ndarray): LAB image. - - Returns: - data(dict):The preprocessed data for colorization. - """ - data = {} - A = 2 * 110 / 10 + 1 - data['A'] = data_lab[:, [ - 0, - ], :, :] - data['B'] = data_lab[:, 1:, :, :] - if self.ab_thresh > 0: # mask out grayscale images - thresh = 1. * self.ab_thresh / 110 - mask = np.sum(np.abs(np.max(np.max(data['B'], axis=3), axis=2) - np.min(np.min(data['B'], axis=3), axis=2)), - axis=1) - mask = (mask >= thresh) - data['A'] = data['A'][mask, :, :, :] - data['B'] = data['B'][mask, :, :, :] - if np.sum(mask) == 0: - return None - data_ab_rs = np.round((data['B'][:, :, ::4, ::4] * 110. + 110.) / 10.) # normalized bin number - data['real_B_enc'] = data_ab_rs[:, [0], :, :] * A + data_ab_rs[:, [1], :, :] - data['hint_B'] = np.zeros(shape=data['B'].shape) - data['mask_B'] = np.zeros(shape=data['A'].shape) - data['hint_B'], data['mask_B'] = self.gethint(data['B'], data['hint_B'], data['mask_B']) - if self.is_train: - data = self.squeeze(data) - data['real_B_enc'] = data['real_B_enc'].astype(np.int64) - else: - data['A'] = data['A'].astype(np.float32) - data['B'] = data['B'].astype(np.float32) - data['real_B_enc'] = data['real_B_enc'].astype(np.int64) - data['hint_B'] = data['hint_B'].astype(np.float32) - data['mask_B'] = data['mask_B'].astype(np.float32) - return data + return self.lab2rgb(img) class ColorPostprocess: