未验证 提交 96096612 编写于 作者: H haoyuying 提交者: GitHub

Colorize (#897)

上级 5889f7cf
...@@ -2,9 +2,8 @@ import paddle ...@@ -2,9 +2,8 @@ import paddle
import paddlehub as hub import paddlehub as hub
import paddle.nn as nn import paddle.nn as nn
if __name__ == '__main__': if __name__ == '__main__':
paddle.disable_static() paddle.disable_static()
model = hub.Module(directory='user_guided_colorization') model = hub.Module(name='user_guided_colorization')
model.eval() model.eval()
result = model.predict(images='sea.jpg') result = model.predict(images='house.png')
\ No newline at end of file
...@@ -3,18 +3,24 @@ import paddlehub as hub ...@@ -3,18 +3,24 @@ import paddlehub as hub
import paddle.nn as nn import paddle.nn as nn
from paddlehub.finetune.trainer import Trainer from paddlehub.finetune.trainer import Trainer
from paddlehub.datasets.colorizedataset import Colorizedataset from paddlehub.datasets.colorizedataset import Colorizedataset
from paddlehub.process.transforms import Compose, Resize, RandomPaddingCrop, ConvertColorSpace, ColorizePreprocess from paddlehub.process.transforms import Compose, Resize, RandomPaddingCrop, ConvertColorSpace, ColorizePreprocess
if __name__ == '__main__': if __name__ == '__main__':
is_train = True is_train = True
paddle.disable_static() paddle.disable_static()
model = hub.Module(directory='user_guided_colorization') model = hub.Module(name='user_guided_colorization')
transform = Compose([Resize((256,256),interp="RANDOM"),RandomPaddingCrop(crop_size=176), ConvertColorSpace(mode='RGB2LAB'), ColorizePreprocess(ab_thresh=0, p=1)], stay_rgb=True) transform = Compose([
color_set = Colorizedataset(transform=transform, mode=is_train) Resize((256, 256), interp='NEAREST'),
RandomPaddingCrop(crop_size=176),
ConvertColorSpace(mode='RGB2LAB'),
ColorizePreprocess(ab_thresh=0, is_train=is_train),
],
stay_rgb=True,
is_permute=False)
color_set = Colorizedataset(transform=transform, mode='train')
if is_train: if is_train:
model.train() model.train()
optimizer = paddle.optimizer.Adam(learning_rate=0.0001, parameters=model.parameters()) optimizer = paddle.optimizer.Adam(learning_rate=0.0001, parameters=model.parameters())
trainer = Trainer(model, optimizer, checkpoint_dir='test_ckpt_img_cls') trainer = Trainer(model, optimizer, checkpoint_dir='test_ckpt_img_cls')
trainer.train(color_set, epochs=3, batch_size=1, eval_dataset=color_set, save_interval=1) trainer.train(color_set, epochs=101, batch_size=5, eval_dataset=color_set, log_interval=10, save_interval=10)
...@@ -12,12 +12,13 @@ ...@@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import paddle import paddle
import numpy import numpy
import paddle.nn as nn import paddle.nn as nn
from paddle.nn import Conv2d, ConvTranspose2d from paddle.nn import Conv2d, ConvTranspose2d
from paddlehub.module.module import moduleinfo from paddlehub.module.module import moduleinfo
from paddlehub.process.transforms import Compose, Resize, RandomPaddingCrop, ConvertColorSpace, ColorizePreprocess from paddlehub.process.transforms import Compose, Resize, RandomPaddingCrop, ConvertColorSpace, ColorizePreprocess
from paddlehub.module.cv_module import ImageColorizeModule from paddlehub.module.cv_module import ImageColorizeModule
...@@ -178,24 +179,31 @@ class UserGuidedColorization(nn.Layer): ...@@ -178,24 +179,31 @@ class UserGuidedColorization(nn.Layer):
if load_checkpoint is not None: if load_checkpoint is not None:
model_dict = paddle.load(load_checkpoint)[0] model_dict = paddle.load(load_checkpoint)[0]
self.set_dict(model_dict) self.set_dict(model_dict)
print("load pretrained model success") print("load custom checkpoint success")
else:
checkpoint = os.path.join(self.directory, 'user_guided.pdparams')
model_dict = paddle.load(checkpoint)[0]
self.set_dict(model_dict)
print("load pretrained checkpoint success")
def transforms(self, images: str, is_train: bool = True) -> callable: def transforms(self, images: str, is_train: bool = True) -> callable:
if is_train: if is_train:
transform = Compose([ transform = Compose([
Resize((256, 256), interp="RANDOM"), Resize((256, 256), interp='NEAREST'),
RandomPaddingCrop(crop_size=176), RandomPaddingCrop(crop_size=176),
ConvertColorSpace(mode='RGB2LAB'), ConvertColorSpace(mode='RGB2LAB'),
ColorizePreprocess(ab_thresh=0, is_train=is_train) ColorizePreprocess(ab_thresh=0, is_train=is_train)
], ],
stay_rgb=True) stay_rgb=True,
is_permute=False)
else: else:
transform = Compose([ transform = Compose([
Resize((256, 256), interp="RANDOM"), Resize((256, 256), interp='NEAREST'),
ConvertColorSpace(mode='RGB2LAB'), ConvertColorSpace(mode='RGB2LAB'),
ColorizePreprocess(ab_thresh=0, is_train=is_train) ColorizePreprocess(ab_thresh=0, is_train=is_train)
], ],
stay_rgb=True) stay_rgb=True,
is_permute=False)
return transform(images) return transform(images)
def forward(self, def forward(self,
......
...@@ -22,6 +22,7 @@ from paddlehub.process.functional import get_img_file ...@@ -22,6 +22,7 @@ from paddlehub.process.functional import get_img_file
from paddlehub.env import DATA_HOME from paddlehub.env import DATA_HOME
from typing import Callable from typing import Callable
class Colorizedataset(paddle.io.Dataset): class Colorizedataset(paddle.io.Dataset):
""" """
Dataset for colorization. Dataset for colorization.
...@@ -34,14 +35,12 @@ class Colorizedataset(paddle.io.Dataset): ...@@ -34,14 +35,12 @@ class Colorizedataset(paddle.io.Dataset):
def __init__(self, transform: Callable, mode: str = 'train'): def __init__(self, transform: Callable, mode: str = 'train'):
self.mode = mode self.mode = mode
self.transform = transform self.transform = transform
if self.mode == 'train': if self.mode == 'train':
self.file = 'train' self.file = 'train'
elif self.mode == 'test': elif self.mode == 'test':
self.file = 'test' self.file = 'test'
else:
self.file = 'validation'
self.file = os.path.join(DATA_HOME, 'canvas', self.file) self.file = os.path.join(DATA_HOME, 'canvas', self.file)
self.data = get_img_file(self.file) self.data = get_img_file(self.file)
...@@ -51,4 +50,4 @@ class Colorizedataset(paddle.io.Dataset): ...@@ -51,4 +50,4 @@ class Colorizedataset(paddle.io.Dataset):
return im['A'], im['hint_B'], im['mask_B'], im['B'], im['real_B_enc'] return im['A'], im['hint_B'], im['mask_B'], im['B'], im['real_B_enc']
def __len__(self): def __len__(self):
return len(self.data) return len(self.data)
\ No newline at end of file
...@@ -18,6 +18,7 @@ import os ...@@ -18,6 +18,7 @@ import os
from typing import List from typing import List
from collections import OrderedDict from collections import OrderedDict
import cv2
import numpy as np import numpy as np
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
...@@ -27,6 +28,7 @@ from PIL import Image ...@@ -27,6 +28,7 @@ from PIL import Image
from paddlehub.module.module import serving, RunModule from paddlehub.module.module import serving, RunModule
from paddlehub.utils.utils import base64_to_cv2 from paddlehub.utils.utils import base64_to_cv2
from paddlehub.process.transforms import ConvertColorSpace, ColorPostprocess, Resize from paddlehub.process.transforms import ConvertColorSpace, ColorPostprocess, Resize
from paddlehub.process.functional import subtract_imagenet_mean_batch, gram_matrix
class ImageServing(object): class ImageServing(object):
...@@ -192,3 +194,87 @@ class ImageColorizeModule(RunModule, ImageServing): ...@@ -192,3 +194,87 @@ class ImageColorizeModule(RunModule, ImageServing):
psnr_value = 20 * np.log10(255. / np.sqrt(mse)) psnr_value = 20 * np.log10(255. / np.sqrt(mse))
result.append(visual_ret) result.append(visual_ret)
return result return result
class StyleTransferModule(RunModule, ImageServing):
    '''
    Base class providing the training, validation and prediction steps for style transfer.

    The methods call ``self(...)``, ``self.setTarget``, ``self.getFeature`` and
    ``self.transform``; none of them is defined in this class, so the concrete
    module is expected to supply them.
    '''

    def training_step(self, batch: List[paddle.Tensor], batch_idx: int) -> dict:
        '''
        One step for training, which should be called as forward computation.

        Args:
            batch(list[paddle.Tensor]): The one batch data, which contains content images and style images.
            batch_idx(int): The index of batch.

        Returns:
            results(dict) : The model outputs, such as loss and metrics.
        '''
        # Training and validation share exactly the same loss computation.
        return self.validation_step(batch, batch_idx)

    def validation_step(self, batch: List[paddle.Tensor], batch_idx: int) -> dict:
        '''
        One step for validation, which should be called as forward computation.

        Note that ``batch[1]`` is overwritten in place (see the comments below).

        Args:
            batch(list[paddle.Tensor]): The one batch data; batch[0] is used as the content
                images and batch[1] as the style images.
            batch_idx(int): The index of batch.

        Returns:
            results(dict) : The model outputs: the total loss under 'loss' and the content and
                style terms under 'metrics'.
        '''
        mse_loss = nn.MSELoss()
        # Unpacking four values requires batch[0] to be 4-D; only the batch size is used below.
        N, _, _, _ = batch[0].shape

        # Only the first style image of the batch is kept, re-expanded to a batch of one.
        batch[1] = batch[1][0].unsqueeze(0)
        self.setTarget(batch[1])

        y = self(batch[0])
        # Round trip through numpy to get an independent copy of the content images.
        xc = paddle.to_tensor(batch[0].numpy().copy())
        y = subtract_imagenet_mean_batch(y)
        xc = subtract_imagenet_mean_batch(xc)
        features_y = self.getFeature(y)
        features_xc = self.getFeature(xc)

        # Content loss compares feature map 1 only; the target is excluded from backprop.
        f_xc_c = paddle.to_tensor(features_xc[1].numpy(), stop_gradient=True)
        content_loss = mse_loss(features_y[1], f_xc_c)

        batch[1] = subtract_imagenet_mean_batch(batch[1])
        features_style = self.getFeature(batch[1])
        gram_style = [gram_matrix(feature) for feature in features_style]

        # Style loss accumulates the gram-matrix distance over every feature level.
        style_loss = 0.
        for m, feature_y in enumerate(features_y):
            gram_y = gram_matrix(feature_y)
            # The single style gram matrix is repeated N times to pair with each output.
            gram_s = paddle.to_tensor(np.tile(gram_style[m].numpy(), (N, 1, 1, 1)))
            style_loss += mse_loss(gram_y, gram_s[:N, :, :])

        loss = content_loss + style_loss

        return {'loss': loss, 'metrics': {'content gap': content_loss, 'style gap': style_loss}}

    def predict(self, origin_path: str, style_path: str, visualization: bool = True, save_path: str = 'result'):
        '''
        Transfer the style of one image onto another image.

        Args:
            origin_path(str): Content image path .
            style_path(str): Style image path.
            visualization(bool): Whether to save the stylized image.
            save_path(str) : Directory in which the stylized image is saved.

        Returns:
            output(np.ndarray) : The style transformed image, clipped to [0, 255]. It is cast to
                uint8 only when ``visualization`` is True. The original docstring describes it
                as BGR; that cannot be confirmed from this class alone.
        '''
        content = paddle.to_tensor(self.transform(origin_path))
        style = paddle.to_tensor(self.transform(style_path))
        # Add the leading batch axis expected by the network.
        content = content.unsqueeze(0)
        style = style.unsqueeze(0)

        self.setTarget(style)
        output = self(content)
        # First result of the batch, moved to channel-last layout and clipped to pixel range.
        output = paddle.clip(output[0].transpose((1, 2, 0)), 0, 255).numpy()

        if visualization:
            output = output.astype(np.uint8)
            # The timestamp keeps successive results from overwriting each other.
            style_name = "style_" + str(time.time()) + ".png"
            # makedirs handles nested paths and tolerates an already existing directory.
            os.makedirs(save_path, exist_ok=True)
            path = os.path.join(save_path, style_name)
            cv2.imwrite(path, output)
        return output
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import os import os
import cv2 import cv2
import paddle
import numpy as np import numpy as np
from PIL import Image, ImageEnhance from PIL import Image, ImageEnhance
...@@ -114,7 +115,25 @@ def get_img_file(dir_name: str) -> list: ...@@ -114,7 +115,25 @@ def get_img_file(dir_name: str) -> list:
if not is_image_file(filename): if not is_image_file(filename):
continue continue
img_path = os.path.join(parent, filename) img_path = os.path.join(parent, filename)
print(img_path)
images.append(img_path) images.append(img_path)
images.sort() images.sort()
return images return images
\ No newline at end of file
def subtract_imagenet_mean_batch(batch: paddle.Tensor) -> paddle.Tensor:
    """Subtract the per-channel ImageNet mean from every pixel of a batch.

    Args:
        batch(paddle.Tensor): Input indexed as [:, channel, :, :]; channels 0, 1 and 2 receive
            the means 103.939, 116.779 and 123.680 (BGR order according to the original docstring).

    Returns:
        paddle.Tensor: ``batch`` minus a float32 tensor of the same shape holding the means.
    """
    channel_means = (103.939, 116.779, 123.680)
    # Full-size offset so the subtraction is a plain element-wise operation.
    offset = np.zeros(shape=batch.shape, dtype='float32')
    for channel, value in enumerate(channel_means):
        offset[:, channel, :, :] = value
    return batch - paddle.to_tensor(offset)
def gram_matrix(data: paddle.Tensor) -> paddle.Tensor:
    """Get the gram matrix of a 4-D feature tensor.

    Args:
        data(paddle.Tensor): Tensor whose shape unpacks to (b, ch, h, w).

    Returns:
        paddle.Tensor: Batched product of the flattened features with their transpose,
            divided by ch * h * w.
    """
    batch_size, channels, height, width = data.shape
    # Collapse the two spatial axes so each channel becomes one row vector.
    flat = data.reshape((batch_size, channels, width * height))
    return flat.bmm(flat.transpose((0, 2, 1))) / (channels * height * width)
...@@ -24,15 +24,16 @@ from paddlehub.process.functional import * ...@@ -24,15 +24,16 @@ from paddlehub.process.functional import *
class Compose: class Compose:
def __init__(self, transforms, to_rgb=True, stay_rgb=False): def __init__(self, transforms, to_rgb=True, stay_rgb=False, is_permute=True):
if not isinstance(transforms, list): if not isinstance(transforms, list):
raise TypeError('The transforms must be a list!') raise TypeError('The transforms must be a list!')
if len(transforms) < 1: if len(transforms) < 1:
raise ValueError('The length of transforms ' + \ raise ValueError('The length of transforms ' + \
'must be equal or larger than 1!') 'must be equal or larger than 1!')
self.transforms = transforms self.transforms = transforms
self.to_rgb = to_rgb self.to_rgb = to_rgb
self.stay_rgb = stay_rgb self.stay_rgb = stay_rgb
self.is_permute = is_permute
def __call__(self, im): def __call__(self, im):
if isinstance(im, str): if isinstance(im, str):
...@@ -45,15 +46,16 @@ class Compose: ...@@ -45,15 +46,16 @@ class Compose:
for op in self.transforms: for op in self.transforms:
im = op(im) im = op(im)
if not self.stay_rgb: if not self.stay_rgb:
im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
if self.is_permute:
im = permute(im) im = permute(im)
return im return im
class RandomHorizontalFlip: class RandomHorizontalFlip:
def __init__(self, prob=0.5): def __init__(self, prob=0.5):
self.prob = prob self.prob = prob
...@@ -239,8 +241,13 @@ class RandomPaddingCrop: ...@@ -239,8 +241,13 @@ class RandomPaddingCrop:
pad_height = max(crop_height - img_height, 0) pad_height = max(crop_height - img_height, 0)
pad_width = max(crop_width - img_width, 0) pad_width = max(crop_width - img_width, 0)
if (pad_height > 0 or pad_width > 0): if (pad_height > 0 or pad_width > 0):
im = cv2.copyMakeBorder( im = cv2.copyMakeBorder(im,
im, 0, pad_height, 0, pad_width, cv2.BORDER_CONSTANT, value=self.im_padding_value) 0,
pad_height,
0,
pad_width,
cv2.BORDER_CONSTANT,
value=self.im_padding_value)
if crop_height > 0 and crop_width > 0: if crop_height > 0 and crop_width > 0:
h_off = np.random.randint(img_height - crop_height + 1) h_off = np.random.randint(img_height - crop_height + 1)
...@@ -295,13 +302,12 @@ class RandomRotation: ...@@ -295,13 +302,12 @@ class RandomRotation:
r[0, 2] += (nw / 2) - cx r[0, 2] += (nw / 2) - cx
r[1, 2] += (nh / 2) - cy r[1, 2] += (nh / 2) - cy
dsize = (nw, nh) dsize = (nw, nh)
im = cv2.warpAffine( im = cv2.warpAffine(im,
im, r,
r, dsize=dsize,
dsize=dsize, flags=cv2.INTER_LINEAR,
flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT,
borderMode=cv2.BORDER_CONSTANT, borderValue=self.im_padding_value)
borderValue=self.im_padding_value)
return im return im
...@@ -403,14 +409,14 @@ class RandomDistort: ...@@ -403,14 +409,14 @@ class RandomDistort:
return im return im
class ConvertColorSpace: class ConvertColorSpace:
""" """
Convert color space from RGB to LAB or from LAB to RGB. Convert color space from RGB to LAB or from LAB to RGB.
Args: Args:
mode(str): Color space convert mode, it can be 'RGB2LAB' or 'LAB2RGB'. mode(str): Color space convert mode, it can be 'RGB2LAB' or 'LAB2RGB'.
Return: Return:
img(np.ndarray): converted image. img(np.ndarray): converted image.
""" """
...@@ -429,7 +435,7 @@ class ConvertColorSpace: ...@@ -429,7 +435,7 @@ class ConvertColorSpace:
""" """
mask = (rgb > 0.04045) mask = (rgb > 0.04045)
np.seterr(invalid='ignore') np.seterr(invalid='ignore')
rgb = (((rgb + .055) / 1.055) ** 2.4) * mask + rgb / 12.92 * (1 - mask) rgb = (((rgb + .055) / 1.055)**2.4) * mask + rgb / 12.92 * (1 - mask)
rgb = np.nan_to_num(rgb) rgb = np.nan_to_num(rgb)
x = .412453 * rgb[:, 0, :, :] + .357580 * rgb[:, 1, :, :] + .180423 * rgb[:, 2, :, :] x = .412453 * rgb[:, 0, :, :] + .357580 * rgb[:, 1, :, :] + .180423 * rgb[:, 2, :, :]
y = .212671 * rgb[:, 0, :, :] + .715160 * rgb[:, 1, :, :] + .072169 * rgb[:, 2, :, :] y = .212671 * rgb[:, 0, :, :] + .715160 * rgb[:, 1, :, :] + .072169 * rgb[:, 2, :, :]
...@@ -490,7 +496,7 @@ class ConvertColorSpace: ...@@ -490,7 +496,7 @@ class ConvertColorSpace:
rgb = np.maximum(rgb, 0) # sometimes reaches a small negative number, which causes NaNs rgb = np.maximum(rgb, 0) # sometimes reaches a small negative number, which causes NaNs
mask = (rgb > .0031308).astype(np.float32) mask = (rgb > .0031308).astype(np.float32)
np.seterr(invalid='ignore') np.seterr(invalid='ignore')
out = (1.055 * (rgb ** (1. / 2.4)) - 0.055) * mask + 12.92 * rgb * (1 - mask) out = (1.055 * (rgb**(1. / 2.4)) - 0.055) * mask + 12.92 * rgb * (1 - mask)
out = np.nan_to_num(out) out = np.nan_to_num(out)
return out return out
...@@ -511,7 +517,7 @@ class ConvertColorSpace: ...@@ -511,7 +517,7 @@ class ConvertColorSpace:
out = np.concatenate((x_int[:, None, :, :], y_int[:, None, :, :], z_int[:, None, :, :]), axis=1) out = np.concatenate((x_int[:, None, :, :], y_int[:, None, :, :], z_int[:, None, :, :]), axis=1)
mask = (out > .2068966).astype(np.float32) mask = (out > .2068966).astype(np.float32)
np.seterr(invalid='ignore') np.seterr(invalid='ignore')
out = (out ** 3.) * mask + (out - 16. / 116.) / 7.787 * (1 - mask) out = (out**3.) * mask + (out - 16. / 116.) / 7.787 * (1 - mask)
out = np.nan_to_num(out) out = np.nan_to_num(out)
sc = np.array((0.95047, 1., 1.08883))[None, :, None, None] sc = np.array((0.95047, 1., 1.08883))[None, :, None, None]
out = out * sc out = out * sc
...@@ -546,27 +552,27 @@ class ConvertColorSpace: ...@@ -546,27 +552,27 @@ class ConvertColorSpace:
class ColorizeHint: class ColorizeHint:
"""Get hint and mask images for colorization. """Get hint and mask images for colorization.
This method is prepared for user guided colorization tasks. Taking the original RGB images as input, we obtain the local hints and the corresponding mask to guide the colorization process. This method is prepared for user guided colorization tasks. Taking the original RGB images as input, we obtain the local hints and the corresponding mask to guide the colorization process.
Args: Args:
percent(float): Probability for ignoring hint in an iteration. percent(float): Probability for ignoring hint in an iteration.
num_points(int): Number of selected hints in an iteration. num_points(int): Number of selected hints in an iteration.
samp(str): Sample method, default is normal. samp(str): Sample method, default is normal.
use_avg(bool): Whether to use mean in selected hint area. use_avg(bool): Whether to use mean in selected hint area.
Return: Return:
hint(np.ndarray): hint images hint(np.ndarray): hint images
mask(np.ndarray): mask images mask(np.ndarray): mask images
""" """
def __init__(self, percent: float, num_points: int = None, samp: str = 'normal', use_avg: bool = True): def __init__(self, percent: float, num_points: int = None, samp: str = 'normal', use_avg: bool = True):
self.percent = percent self.percent = percent
self.num_points = num_points self.num_points = num_points
self.samp = samp self.samp = samp
self.use_avg = use_avg self.use_avg = use_avg
def __call__(self, data: np.ndarray, hint: np.ndarray, mask: np.ndarray): def __call__(self, data: np.ndarray, hint: np.ndarray, mask: np.ndarray):
sample_Ps = [1, 2, 3, 4, 5, 6, 7, 8, 9, ] sample_Ps = [1, 2, 3, 4, 5, 6, 7, 8, 9]
self.data = data self.data = data
self.hint = hint self.hint = hint
self.mask = mask self.mask = mask
...@@ -577,7 +583,7 @@ class ColorizeHint: ...@@ -577,7 +583,7 @@ class ColorizeHint:
while cont_cond: while cont_cond:
if self.num_points is None: # draw from geometric if self.num_points is None: # draw from geometric
# embed() # embed()
cont_cond = np.random.rand() < (1 - self.percent) cont_cond = np.random.rand() > (1 - self.percent)
else: # add certain number of points else: # add certain number of points
cont_cond = pp < self.num_points cont_cond = pp < self.num_points
if not cont_cond: # skip out of loop if condition not met if not cont_cond: # skip out of loop if condition not met
...@@ -593,9 +599,11 @@ class ColorizeHint: ...@@ -593,9 +599,11 @@ class ColorizeHint:
# add color point # add color point
if self.use_avg: if self.use_avg:
# embed() # embed()
hint[nn, :, h:h + P, w:w + P] = np.mean( hint[nn, :, h:h + P, w:w + P] = np.mean(np.mean(data[nn, :, h:h + P, w:w + P],
np.mean(data[nn, :, h:h + P, w:w + P], axis=2, keepdims=True), axis=1, keepdims=True).reshape( axis=2,
1, C, 1, 1) keepdims=True),
axis=1,
keepdims=True).reshape(1, C, 1, 1)
else: else:
hint[nn, :, h:h + P, w:w + P] = data[nn, :, h:h + P, w:w + P] hint[nn, :, h:h + P, w:w + P] = data[nn, :, h:h + P, w:w + P]
mask[nn, :, h:h + P, w:w + P] = 1 mask[nn, :, h:h + P, w:w + P] = 1
...@@ -609,10 +617,10 @@ class ColorizeHint: ...@@ -609,10 +617,10 @@ class ColorizeHint:
class SqueezeAxis: class SqueezeAxis:
""" """
Squeeze the specific axis when it is equal to 1. Squeeze the specific axis when it is equal to 1.
Args: Args:
axis(int): Which axis should be squeezed. axis(int): Which axis should be squeezed.
""" """
def __init__(self, axis: int): def __init__(self, axis: int):
self.axis = axis self.axis = axis
...@@ -628,7 +636,7 @@ class SqueezeAxis: ...@@ -628,7 +636,7 @@ class SqueezeAxis:
class ColorizePreprocess: class ColorizePreprocess:
"""Prepare dataset for image Colorization. """Prepare dataset for image Colorization.
Args: Args:
ab_thresh(float): Thresh value for setting mask value. ab_thresh(float): Thresh value for setting mask value.
p(float): Probability for ignoring hint in an iteration. p(float): Probability for ignoring hint in an iteration.
...@@ -636,13 +644,14 @@ class ColorizePreprocess: ...@@ -636,13 +644,14 @@ class ColorizePreprocess:
samp(str): Sample method, default is normal. samp(str): Sample method, default is normal.
use_avg(bool): Whether to use mean in selected hint area. use_avg(bool): Whether to use mean in selected hint area.
is_train(bool): Training process or not. is_train(bool): Training process or not.
Return: Return:
data(dict):The preprocessed data for colorization. data(dict):The preprocessed data for colorization.
""" """
def __init__(self, ab_thresh: float = 0., def __init__(self,
p: float = .125, ab_thresh: float = 0.,
p: float = 0.,
num_points: int = None, num_points: int = None,
samp: str = 'normal', samp: str = 'normal',
use_avg: bool = True, use_avg: bool = True,
...@@ -668,11 +677,14 @@ class ColorizePreprocess: ...@@ -668,11 +677,14 @@ class ColorizePreprocess:
""" """
data = {} data = {}
A = 2 * 110 / 10 + 1 A = 2 * 110 / 10 + 1
data['A'] = data_lab[:, [0, ], :, :] data['A'] = data_lab[:, [
0,
], :, :]
data['B'] = data_lab[:, 1:, :, :] data['B'] = data_lab[:, 1:, :, :]
if self.ab_thresh > 0: # mask out grayscale images if self.ab_thresh > 0: # mask out grayscale images
thresh = 1. * self.ab_thresh / 110 thresh = 1. * self.ab_thresh / 110
mask = np.sum(np.abs(np.max(np.max(data['B'], axis=3), axis=2) - np.min(np.min(data['B'], axis=3), axis=2)),axis=1) mask = np.sum(np.abs(np.max(np.max(data['B'], axis=3), axis=2) - np.min(np.min(data['B'], axis=3), axis=2)),
axis=1)
mask = (mask >= thresh) mask = (mask >= thresh)
data['A'] = data['A'][mask, :, :, :] data['A'] = data['A'][mask, :, :, :]
data['B'] = data['B'][mask, :, :, :] data['B'] = data['B'][mask, :, :, :]
...@@ -698,10 +710,10 @@ class ColorizePreprocess: ...@@ -698,10 +710,10 @@ class ColorizePreprocess:
class ColorPostprocess: class ColorPostprocess:
""" """
Transform images from [0, 1] to [0, 255] Transform images from [0, 1] to [0, 255]
Args: Args:
type(type): Type of Image value. type(type): Type of Image value.
Return: Return:
img(np.ndarray): Image in range of 0-255. img(np.ndarray): Image in range of 0-255.
""" """
...@@ -713,3 +725,41 @@ class ColorPostprocess: ...@@ -713,3 +725,41 @@ class ColorPostprocess:
img = np.clip(img, 0, 1) * 255 img = np.clip(img, 0, 1) * 255
img = img.astype(self.type) img = img.astype(self.type)
return img return img
class CenterCrop:
    """
    Crop the middle part of the image to the specified size.

    The crop is taken along the first two axes of the array; any trailing axes
    (for example a channel axis) are kept untouched, so both 2-D and 3-D arrays work.
    If an axis is shorter than ``crop_size`` it is returned in full.

    Args:
        crop_size(int): Side length of the square crop.

    Return:
        img(np.ndarray): Cropped image.
    """

    def __init__(self, crop_size: int):
        self.crop_size = crop_size

    def __call__(self, img: np.ndarray):
        # Axis 0 indexes rows and axis 1 indexes columns.
        rows, cols = img.shape[:2]
        # Clamp at 0: a negative offset would slice from the end and return a stray sliver.
        row_start = max(int((rows - self.crop_size) / 2.), 0)
        col_start = max(int((cols - self.crop_size) / 2.), 0)
        return img[row_start:row_start + self.crop_size, col_start:col_start + self.crop_size, ...]
class SetType:
    """
    Set image type.

    Args:
        datatype(str|type): Target dtype handed to ``astype``, given either as a dtype name
            such as 'float32' or as a type such as ``np.float32``. Default is 'float32'.

    Return:
        img(np.ndarray): Transformed image.
    """

    def __init__(self, datatype: 'str | type' = 'float32'):
        # Stored under ``type`` because that is the attribute name existing code reads.
        self.type = datatype

    def __call__(self, img: np.ndarray):
        # asanyarray leaves ndarrays (and subclasses) as they are and also accepts array-likes.
        return np.asanyarray(img).astype(self.type)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册