提交 06d3c848 编写于 作者: L LielinJiang

reproduce transforms

上级 46376659
...@@ -36,16 +36,17 @@ dataset: ...@@ -36,16 +36,17 @@ dataset:
output_nc: 3 output_nc: 3
serial_batches: False serial_batches: False
pool_size: 50 pool_size: 50
transform: transforms:
load_size: 286 - name: Resize
crop_size: 256 size: [286, 286]
preprocess: resize_and_crop - name: RandomCrop
no_flip: False output_size: [256, 256]
normalize: - name: RandomHorizontalFlip
mean: prob: 0.5
(127.5, 127.5, 127.5) - name: Permute
std: - name: Normalize
(127.5, 127.5, 127.5) mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
test: test:
name: SingleDataset name: SingleDataset
dataroot: data/cityscapes/testB dataroot: data/cityscapes/testB
...@@ -55,17 +56,13 @@ dataset: ...@@ -55,17 +56,13 @@ dataset:
output_nc: 3 output_nc: 3
serial_batches: False serial_batches: False
pool_size: 50 pool_size: 50
transform: transforms:
load_size: 256 - name: Resize
crop_size: 256 size: [256, 256]
preprocess: resize_and_crop - name: Permute
no_flip: True - name: Normalize
normalize: mean: [127.5, 127.5, 127.5]
mean: std: [127.5, 127.5, 127.5]
(127.5, 127.5, 127.5)
std:
(127.5, 127.5, 127.5)
optimizer: optimizer:
name: Adam name: Adam
......
...@@ -35,16 +35,17 @@ dataset: ...@@ -35,16 +35,17 @@ dataset:
output_nc: 3 output_nc: 3
serial_batches: False serial_batches: False
pool_size: 50 pool_size: 50
transform: transforms:
load_size: 286 - name: Resize
crop_size: 256 size: [286, 286]
preprocess: resize_and_crop - name: RandomCrop
no_flip: False output_size: [256, 256]
normalize: - name: RandomHorizontalFlip
mean: prob: 0.5
(127.5, 127.5, 127.5) - name: Permute
std: - name: Normalize
(127.5, 127.5, 127.5) mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
test: test:
name: SingleDataset name: SingleDataset
dataroot: data/horse2zebra/testA dataroot: data/horse2zebra/testA
...@@ -55,15 +56,13 @@ dataset: ...@@ -55,15 +56,13 @@ dataset:
serial_batches: False serial_batches: False
pool_size: 50 pool_size: 50
transform: transform:
load_size: 256 transform:
crop_size: 256 - name: Resize
preprocess: resize_and_crop size: [256, 256]
no_flip: True - name: Permute
normalize: - name: Normalize
mean: mean: [127.5, 127.5, 127.5]
(127.5, 127.5, 127.5) std: [127.5, 127.5, 127.5]
std:
(127.5, 127.5, 127.5)
optimizer: optimizer:
name: Adam name: Adam
......
...@@ -33,16 +33,22 @@ dataset: ...@@ -33,16 +33,22 @@ dataset:
output_nc: 3 output_nc: 3
serial_batches: False serial_batches: False
pool_size: 0 pool_size: 0
transform: transforms:
load_size: 286 - name: Resize
crop_size: 256 size: [286, 286]
preprocess: resize_and_crop keys: [image, image]
no_flip: False - name: PairedRandomCrop
normalize: output_size: [256, 256]
mean: keys: [image, image]
(127.5, 127.5, 127.5) - name: PairedRandomHorizontalFlip
std: prob: 0.5
(127.5, 127.5, 127.5) keys: [image, image]
- name: Permute
keys: [image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: [image, image]
test: test:
name: PairedDataset name: PairedDataset
dataroot: data/cityscapes/ dataroot: data/cityscapes/
...@@ -53,16 +59,17 @@ dataset: ...@@ -53,16 +59,17 @@ dataset:
output_nc: 3 output_nc: 3
serial_batches: True serial_batches: True
pool_size: 50 pool_size: 50
transform: transforms:
load_size: 256 - name: Resize
crop_size: 256 size: [256, 256]
preprocess: resize_and_crop keys: [image, image]
no_flip: True - name: Permute
normalize: keys: [image, image]
mean: - name: Normalize
(127.5, 127.5, 127.5) mean: [127.5, 127.5, 127.5]
std: std: [127.5, 127.5, 127.5]
(127.5, 127.5, 127.5) keys: [image, image]
optimizer: optimizer:
name: Adam name: Adam
......
...@@ -32,16 +32,22 @@ dataset: ...@@ -32,16 +32,22 @@ dataset:
output_nc: 3 output_nc: 3
serial_batches: False serial_batches: False
pool_size: 0 pool_size: 0
transform: transforms:
load_size: 286 - name: Resize
crop_size: 256 size: [286, 286]
preprocess: resize_and_crop keys: [image, image]
no_flip: False - name: PairedRandomCrop
normalize: output_size: [256, 256]
mean: keys: [image, image]
(127.5, 127.5, 127.5) - name: PairedRandomHorizontalFlip
std: prob: 0.5
(127.5, 127.5, 127.5) keys: [image, image]
- name: Permute
keys: [image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: [image, image]
test: test:
name: PairedDataset name: PairedDataset
dataroot: data/cityscapes/ dataroot: data/cityscapes/
...@@ -52,16 +58,16 @@ dataset: ...@@ -52,16 +58,16 @@ dataset:
output_nc: 3 output_nc: 3
serial_batches: True serial_batches: True
pool_size: 50 pool_size: 50
transform: transforms:
load_size: 256 - name: Resize
crop_size: 256 size: [256, 256]
preprocess: resize_and_crop keys: [image, image]
no_flip: True - name: Permute
normalize: keys: [image, image]
mean: - name: Normalize
(127.5, 127.5, 127.5) mean: [127.5, 127.5, 127.5]
std: std: [127.5, 127.5, 127.5]
(127.5, 127.5, 127.5) keys: [image, image]
optimizer: optimizer:
name: Adam name: Adam
......
...@@ -32,16 +32,22 @@ dataset: ...@@ -32,16 +32,22 @@ dataset:
output_nc: 3 output_nc: 3
serial_batches: False serial_batches: False
pool_size: 0 pool_size: 0
transform: transforms:
load_size: 286 - name: Resize
crop_size: 256 size: [286, 286]
preprocess: resize_and_crop keys: [image, image]
no_flip: False - name: PairedRandomCrop
normalize: output_size: [256, 256]
mean: keys: [image, image]
(127.5, 127.5, 127.5) - name: PairedRandomHorizontalFlip
std: prob: 0.5
(127.5, 127.5, 127.5) keys: [image, image]
- name: Permute
keys: [image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: [image, image]
test: test:
name: PairedDataset name: PairedDataset
dataroot: data/facades/ dataroot: data/facades/
...@@ -52,16 +58,16 @@ dataset: ...@@ -52,16 +58,16 @@ dataset:
output_nc: 3 output_nc: 3
serial_batches: True serial_batches: True
pool_size: 50 pool_size: 50
transform: transforms:
load_size: 256 - name: Resize
crop_size: 256 size: [256, 256]
preprocess: resize_and_crop keys: [image, image]
no_flip: True - name: Permute
normalize: keys: [image, image]
mean: - name: Normalize
(127.5, 127.5, 127.5) mean: [127.5, 127.5, 127.5]
std: std: [127.5, 127.5, 127.5]
(127.5, 127.5, 127.5) keys: [image, image]
optimizer: optimizer:
name: Adam name: Adam
......
...@@ -5,13 +5,13 @@ from .base_dataset import BaseDataset, get_params, get_transform ...@@ -5,13 +5,13 @@ from .base_dataset import BaseDataset, get_params, get_transform
from .image_folder import make_dataset from .image_folder import make_dataset
from .builder import DATASETS from .builder import DATASETS
from .transforms.builder import build_transforms
@DATASETS.register() @DATASETS.register()
class PairedDataset(BaseDataset): class PairedDataset(BaseDataset):
"""A dataset class for paired image dataset. """A dataset class for paired image dataset.
""" """
def __init__(self, cfg): def __init__(self, cfg):
"""Initialize this dataset class. """Initialize this dataset class.
...@@ -19,11 +19,14 @@ class PairedDataset(BaseDataset): ...@@ -19,11 +19,14 @@ class PairedDataset(BaseDataset):
cfg (dict) -- stores all the experiment flags cfg (dict) -- stores all the experiment flags
""" """
BaseDataset.__init__(self, cfg) BaseDataset.__init__(self, cfg)
self.dir_AB = os.path.join(cfg.dataroot, cfg.phase) # get the image directory self.dir_AB = os.path.join(cfg.dataroot,
self.AB_paths = sorted(make_dataset(self.dir_AB, cfg.max_dataset_size)) # get image paths cfg.phase) # get the image directory
assert(self.cfg.transform.load_size >= self.cfg.transform.crop_size) # crop_size should be smaller than the size of loaded image self.AB_paths = sorted(make_dataset(
self.dir_AB, cfg.max_dataset_size)) # get image paths
# assert(self.cfg.transform.load_size >= self.cfg.transform.crop_size) # crop_size should be smaller than the size of loaded image
self.input_nc = self.cfg.output_nc if self.cfg.direction == 'BtoA' else self.cfg.input_nc self.input_nc = self.cfg.output_nc if self.cfg.direction == 'BtoA' else self.cfg.input_nc
self.output_nc = self.cfg.input_nc if self.cfg.direction == 'BtoA' else self.cfg.output_nc self.output_nc = self.cfg.input_nc if self.cfg.direction == 'BtoA' else self.cfg.output_nc
self.transforms = build_transforms(cfg.transforms)
def __getitem__(self, index): def __getitem__(self, index):
"""Return a data point and its metadata information. """Return a data point and its metadata information.
...@@ -49,27 +52,20 @@ class PairedDataset(BaseDataset): ...@@ -49,27 +52,20 @@ class PairedDataset(BaseDataset):
A = AB[:h, :w2, :] A = AB[:h, :w2, :]
B = AB[:h, w2:, :] B = AB[:h, w2:, :]
# apply the same transform to both A and B # apply the same transform to both A and B
# transform_params = get_params(self.opt, A.size) # transform_params = get_params(self.opt, A.size)
transform_params = get_params(self.cfg.transform, (w2, h)) # transform_params = get_params(self.cfg.transform, (w2, h))
A_transform = get_transform(self.cfg.transform, transform_params, grayscale=(self.input_nc == 1)) # A_transform = get_transform(self.cfg.transform, transform_params, grayscale=(self.input_nc == 1))
B_transform = get_transform(self.cfg.transform, transform_params, grayscale=(self.output_nc == 1)) # B_transform = get_transform(self.cfg.transform, transform_params, grayscale=(self.output_nc == 1))
A = A_transform(A) # A = A_transform(A)
B = B_transform(B) # B = B_transform(B)
# A, B = self.transforms((A, B))
A, B = self.transforms((A, B))
return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path} return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path}
def __len__(self): def __len__(self):
"""Return the total number of images in the dataset.""" """Return the total number of images in the dataset."""
return len(self.AB_paths) return len(self.AB_paths)
def get_path_by_indexs(self, indexs):
if isinstance(indexs, paddle.Variable):
indexs = indexs.numpy()
current_paths = []
for index in indexs:
current_paths.append(self.AB_paths[index])
return current_paths
from .transforms import RandomCrop, Resize, RandomHorizontalFlip, PairedRandomCrop, PairedRandomHorizontalFlip, Normalize, Permute
import copy
import traceback
import paddle
from ...utils.registry import Registry
TRANSFORMS = Registry("TRANSFORMS")
class Compose(object):
"""
Composes several transforms together use for composing list of transforms
together for a dataset transform.
Args:
transforms (list): List of transforms to compose.
Returns:
A compose object which is callable, __call__ for this Compose
object will call each given :attr:`transforms` sequencely.
"""
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, data):
for f in self.transforms:
try:
# multi-fileds in a sample
# if isinstance(data, Sequence):
# data = f(*data)
# # single field in a sample, call transform directly
# else:
data = f(data)
except Exception as e:
stack_info = traceback.format_exc()
print("fail to perform transform [{}] with error: "
"{} and stack:\n{}".format(f, e, str(stack_info)))
raise e
return data
def build_transform(cfg):
pass
def build_transforms(cfg):
transforms = []
for trans_cfg in cfg:
temp_trans_cfg = copy.deepcopy(trans_cfg)
name = temp_trans_cfg.pop('name')
transforms.append(TRANSFORMS.get(name)(**temp_trans_cfg))
transforms = Compose(transforms)
return transforms
import sys
import types
import random import random
import numbers
import warnings
import traceback
import collections
import numpy as np
from paddle.utils import try_import
import paddle.vision.transforms.functional as F
import paddle.vision.transforms.transforms as T
class RandomCrop(object): from .builder import TRANSFORMS
def __init__(self, output_size): if sys.version_info < (3, 3):
Sequence = collections.Sequence
Iterable = collections.Iterable
else:
Sequence = collections.abc.Sequence
Iterable = collections.abc.Iterable
class Transform():
def _set_attributes(self, args):
"""
Set attributes from the input list of parameters.
Args:
args (list): list of parameters.
"""
if args:
for k, v in args.items():
# print(k, v)
if k != "self" and not k.startswith("_"):
setattr(self, k, v)
def apply_image(self, input):
raise NotImplementedError
def __call__(self, inputs):
# print('debug:', type(inputs), type(inputs[0]))
if isinstance(inputs, tuple):
inputs = list(inputs)
if self.keys is not None:
for i, key in enumerate(self.keys):
if isinstance(inputs, dict):
inputs[key] = getattr(self, 'apply_' + key)(inputs[key])
elif isinstance(inputs, (list, tuple)):
inputs[i] = getattr(self, 'apply_' + key)(inputs[i])
else:
inputs = self.apply_image(inputs)
if isinstance(inputs, list):
inputs = tuple(inputs)
return inputs
@TRANSFORMS.register()
class Resize(Transform):
"""Resize the input Image to the given size.
Args:
size (int|list|tuple): Desired output size. If size is a sequence like
(h, w), output size will be matched to this. If size is an int,
smaller edge of the image will be matched to this number.
i.e, if height > width, then image will be rescaled to
(size * height / width, size)
interpolation (int, optional): Interpolation mode of resize. Default: 1.
0 : cv2.INTER_NEAREST
1 : cv2.INTER_LINEAR
2 : cv2.INTER_CUBIC
3 : cv2.INTER_AREA
4 : cv2.INTER_LANCZOS4
5 : cv2.INTER_LINEAR_EXACT
7 : cv2.INTER_MAX
8 : cv2.WARP_FILL_OUTLIERS
16: cv2.WARP_INVERSE_MAP
"""
def __init__(self, size, interpolation=1, keys=None):
super().__init__()
assert isinstance(size, int) or (isinstance(size, Iterable)
and len(size) == 2)
self._set_attributes(locals())
if isinstance(self.size, Iterable):
self.size = tuple(size)
def apply_image(self, img):
return F.resize(img, self.size, self.interpolation)
@TRANSFORMS.register()
class RandomCrop(Transform):
def __init__(self, output_size, keys=None):
super().__init__()
self._set_attributes(locals())
if isinstance(output_size, int): if isinstance(output_size, int):
self.output_size = (output_size, output_size) self.output_size = (output_size, output_size)
else: else:
...@@ -19,12 +111,171 @@ class RandomCrop(object): ...@@ -19,12 +111,171 @@ class RandomCrop(object):
j = random.randint(0, w - tw) j = random.randint(0, w - tw)
return i, j, th, tw return i, j, th, tw
def __call__(self, img): def apply_image(self, img):
i, j, h, w = self._get_params(img) i, j, h, w = self._get_params(img)
cropped_img = img[i:i + h, j:j + w] cropped_img = img[i:i + h, j:j + w]
return cropped_img return cropped_img
@TRANSFORMS.register()
class PairedRandomCrop(RandomCrop):
def __init__(self, output_size, keys=None):
super().__init__(output_size, keys)
if isinstance(output_size, int):
self.output_size = (output_size, output_size)
else:
self.output_size = output_size
def apply_image(self, img, crop_prams=None):
if crop_prams is not None:
i, j, h, w = crop_prams
else:
i, j, h, w = self._get_params(img)
cropped_img = img[i:i + h, j:j + w]
return cropped_img
def __call__(self, inputs):
if isinstance(inputs, tuple):
inputs = list(inputs)
if self.keys is not None:
if isinstance(inputs, dict):
crop_params = self._get_params(inputs[self.keys[0]])
elif isinstance(inputs, (list, tuple)):
crop_params = self._get_params(inputs[0])
for i, key in enumerate(self.keys):
if isinstance(inputs, dict):
inputs[key] = getattr(self, 'apply_' + key)(inputs[key],
crop_params)
elif isinstance(inputs, (list, tuple)):
inputs[i] = getattr(self, 'apply_' + key)(inputs[i],
crop_params)
else:
crop_params = self._get_params(inputs)
inputs = self.apply_image(inputs, crop_params)
if isinstance(inputs, list):
inputs = tuple(inputs)
return inputs
@TRANSFORMS.register()
class RandomHorizontalFlip(Transform):
"""Horizontally flip the input data randomly with a given probability.
Args:
prob (float): Probability of the input data being flipped. Default: 0.5
"""
def __init__(self, prob=0.5, keys=None):
super().__init__()
self._set_attributes(locals())
def apply_image(self, img):
if np.random.random() < self.prob:
return F.flip(img, code=1)
return img
# import paddle
# paddle.vision.transforms.RandomHorizontalFlip
@TRANSFORMS.register()
class PairedRandomHorizontalFlip(RandomHorizontalFlip):
def __init__(self, prob=0.5, keys=None):
super().__init__()
self._set_attributes(locals())
def apply_image(self, img, flip):
if flip:
return F.flip(img, code=1)
return img
def __call__(self, inputs):
if isinstance(inputs, tuple):
inputs = list(inputs)
flip = np.random.random() < self.prob
if self.keys is not None:
for i, key in enumerate(self.keys):
if isinstance(inputs, dict):
inputs[key] = getattr(self, 'apply_' + key)(inputs[key],
flip)
elif isinstance(inputs, (list, tuple)):
inputs[i] = getattr(self, 'apply_' + key)(inputs[i], flip)
else:
inputs = self.apply_image(inputs, flip)
if isinstance(inputs, list):
inputs = tuple(inputs)
return inputs
@TRANSFORMS.register()
class Normalize(Transform):
"""Normalize the input data with mean and standard deviation.
Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels,
this transform will normalize each channel of the input data.
``output[channel] = (input[channel] - mean[channel]) / std[channel]``
Args:
mean (int|float|list): Sequence of means for each channel.
std (int|float|list): Sequence of standard deviations for each channel.
"""
def __init__(self, mean=0.0, std=1.0, keys=None):
super().__init__()
self._set_attributes(locals())
if isinstance(mean, numbers.Number):
mean = [mean, mean, mean]
if isinstance(std, numbers.Number):
std = [std, std, std]
self.mean = np.array(mean, dtype=np.float32).reshape(len(mean), 1, 1)
self.std = np.array(std, dtype=np.float32).reshape(len(std), 1, 1)
def apply_image(self, img):
return (img - self.mean) / self.std
@TRANSFORMS.register()
class Permute(Transform):
"""Change input data to a target mode.
For example, most transforms use HWC mode image,
while the Neural Network might use CHW mode input tensor.
Input image should be HWC mode and an instance of numpy.ndarray.
Args:
mode (str): Output mode of input. Default: "CHW".
to_rgb (bool): Convert 'bgr' image to 'rgb'. Default: True.
"""
def __init__(self, mode="CHW", to_rgb=True, keys=None):
super().__init__()
self._set_attributes(locals())
assert mode in [
"CHW"
], "Only support 'CHW' mode, but received mode: {}".format(mode)
self.mode = mode
self.to_rgb = to_rgb
def apply_image(self, img):
if self.to_rgb:
img = img[..., ::-1]
if self.mode == "CHW":
return img.transpose((2, 0, 1))
return img
# import paddle
# paddle.vision.transforms.Normalize
# TRANSFORMS.register(T.Normalize)
class Crop(): class Crop():
def __init__(self, pos, size): def __init__(self, pos, size):
self.pos = pos self.pos = pos
...@@ -35,6 +286,6 @@ class Crop(): ...@@ -35,6 +286,6 @@ class Crop():
x, y = self.pos x, y = self.pos
th = tw = self.size th = tw = self.size
if (ow > tw or oh > th): if (ow > tw or oh > th):
return img[y: y + th, x: x + tw] return img[y:y + th, x:x + tw]
return img return img
\ No newline at end of file
...@@ -5,13 +5,13 @@ from .base_dataset import BaseDataset, get_transform ...@@ -5,13 +5,13 @@ from .base_dataset import BaseDataset, get_transform
from .image_folder import make_dataset from .image_folder import make_dataset
from .builder import DATASETS from .builder import DATASETS
from .transforms.builder import build_transforms
@DATASETS.register() @DATASETS.register()
class UnpairedDataset(BaseDataset): class UnpairedDataset(BaseDataset):
""" """
""" """
def __init__(self, cfg): def __init__(self, cfg):
"""Initialize this dataset class. """Initialize this dataset class.
...@@ -19,18 +19,26 @@ class UnpairedDataset(BaseDataset): ...@@ -19,18 +19,26 @@ class UnpairedDataset(BaseDataset):
cfg (dict) -- stores all the experiment flags cfg (dict) -- stores all the experiment flags
""" """
BaseDataset.__init__(self, cfg) BaseDataset.__init__(self, cfg)
self.dir_A = os.path.join(cfg.dataroot, cfg.phase + 'A') # create a path '/path/to/data/trainA' self.dir_A = os.path.join(cfg.dataroot, cfg.phase +
self.dir_B = os.path.join(cfg.dataroot, cfg.phase + 'B') # create a path '/path/to/data/trainB' 'A') # create a path '/path/to/data/trainA'
self.dir_B = os.path.join(cfg.dataroot, cfg.phase +
'B') # create a path '/path/to/data/trainB'
self.A_paths = sorted(make_dataset(self.dir_A, cfg.max_dataset_size)) # load images from '/path/to/data/trainA' self.A_paths = sorted(make_dataset(
self.B_paths = sorted(make_dataset(self.dir_B, cfg.max_dataset_size)) # load images from '/path/to/data/trainB' self.dir_A,
cfg.max_dataset_size)) # load images from '/path/to/data/trainA'
self.B_paths = sorted(make_dataset(
self.dir_B,
cfg.max_dataset_size)) # load images from '/path/to/data/trainB'
self.A_size = len(self.A_paths) # get the size of dataset A self.A_size = len(self.A_paths) # get the size of dataset A
self.B_size = len(self.B_paths) # get the size of dataset B self.B_size = len(self.B_paths) # get the size of dataset B
btoA = self.cfg.direction == 'BtoA' btoA = self.cfg.direction == 'BtoA'
input_nc = self.cfg.output_nc if btoA else self.cfg.input_nc # get the number of channels of input image input_nc = self.cfg.output_nc if btoA else self.cfg.input_nc # get the number of channels of input image
output_nc = self.cfg.input_nc if btoA else self.cfg.output_nc # get the number of channels of output image output_nc = self.cfg.input_nc if btoA else self.cfg.output_nc # get the number of channels of output image
self.transform_A = get_transform(self.cfg.transform, grayscale=(input_nc == 1)) # self.transform_A = get_transform(self.cfg.transform, grayscale=(input_nc == 1))
self.transform_B = get_transform(self.cfg.transform, grayscale=(output_nc == 1)) # self.transform_B = get_transform(self.cfg.transform, grayscale=(output_nc == 1))
self.transform_A = build_transforms(self.cfg.transforms)
self.transform_B = build_transforms(self.cfg.transforms)
self.reset_paths() self.reset_paths()
...@@ -49,10 +57,11 @@ class UnpairedDataset(BaseDataset): ...@@ -49,10 +57,11 @@ class UnpairedDataset(BaseDataset):
A_paths (str) -- image paths A_paths (str) -- image paths
B_paths (str) -- image paths B_paths (str) -- image paths
""" """
A_path = self.A_paths[index % self.A_size] # make sure index is within then range A_path = self.A_paths[
if self.cfg.serial_batches: # make sure index is within then range index % self.A_size] # make sure index is within then range
if self.cfg.serial_batches: # make sure index is within then range
index_B = index % self.B_size index_B = index % self.B_size
else: # randomize the index for domain B to avoid fixed pairs. else: # randomize the index for domain B to avoid fixed pairs.
index_B = random.randint(0, self.B_size - 1) index_B = random.randint(0, self.B_size - 1)
B_path = self.B_paths[index_B] B_path = self.B_paths[index_B]
......
...@@ -2,18 +2,9 @@ import paddle ...@@ -2,18 +2,9 @@ import paddle
from ..utils.registry import Registry from ..utils.registry import Registry
MODELS = Registry("MODEL") MODELS = Registry("MODEL")
def build_model(cfg): def build_model(cfg):
# dataset = MODELS.get(cfg.MODEL.name)(cfg.MODEL)
# place = paddle.CUDAPlace(0)
# dataloader = paddle.io.DataLoader(dataset,
# batch_size=1, #opt.batch_size,
# places=place,
# shuffle=True, #not opt.serial_batches,
# num_workers=0)#int(opt.num_threads))
model = MODELS.get(cfg.model.name)(cfg) model = MODELS.get(cfg.model.name)(cfg)
return model return model
# pass
\ No newline at end of file
...@@ -77,8 +77,8 @@ class Pix2PixModel(BaseModel): ...@@ -77,8 +77,8 @@ class Pix2PixModel(BaseModel):
""" """
AtoB = self.opt.dataset.train.direction == 'AtoB' AtoB = self.opt.dataset.train.direction == 'AtoB'
self.real_A = paddle.to_tensor(input['A' if AtoB else 'B']) self.real_A = paddle.to_variable(input['A' if AtoB else 'B'])
self.real_B = paddle.to_tensor(input['B' if AtoB else 'A']) self.real_B = paddle.to_variable(input['B' if AtoB else 'A'])
self.image_paths = input['A_paths' if AtoB else 'B_paths'] self.image_paths = input['A_paths' if AtoB else 'B_paths']
def forward(self): def forward(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册