提交 4760af1c 编写于 作者: L LielinJiang

clean code

上级 c1bd1a2a
......@@ -23,7 +23,7 @@ class PairedDataset(BaseDataset):
cfg.phase) # get the image directory
self.AB_paths = sorted(make_dataset(
self.dir_AB, cfg.max_dataset_size)) # get image paths
# assert(self.cfg.transform.load_size >= self.cfg.transform.crop_size) # crop_size should be smaller than the size of loaded image
self.input_nc = self.cfg.output_nc if self.cfg.direction == 'BtoA' else self.cfg.input_nc
self.output_nc = self.cfg.input_nc if self.cfg.direction == 'BtoA' else self.cfg.output_nc
self.transforms = build_transforms(cfg.transforms)
......@@ -53,15 +53,6 @@ class PairedDataset(BaseDataset):
B = AB[:h, w2:, :]
# apply the same transform to both A and B
# transform_params = get_params(self.opt, A.size)
# transform_params = get_params(self.cfg.transform, (w2, h))
# A_transform = get_transform(self.cfg.transform, transform_params, grayscale=(self.input_nc == 1))
# B_transform = get_transform(self.cfg.transform, transform_params, grayscale=(self.output_nc == 1))
# A = A_transform(A)
# B = B_transform(B)
# A, B = self.transforms((A, B))
A, B = self.transforms((A, B))
return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path}
......
......@@ -25,11 +25,6 @@ class Compose(object):
def __call__(self, data):
for f in self.transforms:
try:
# multi-fileds in a sample
# if isinstance(data, Sequence):
# data = f(*data)
# # single field in a sample, call transform directly
# else:
data = f(data)
except Exception as e:
stack_info = traceback.format_exc()
......@@ -39,10 +34,6 @@ class Compose(object):
return data
def build_transform(cfg):
pass
def build_transforms(cfg):
transforms = []
......
import sys
import types
import random
import numbers
import warnings
import traceback
import collections
import numpy as np
from paddle.utils import try_import
import paddle.vision.transforms.functional as F
import paddle.vision.transforms.transforms as T
from .builder import TRANSFORMS
......@@ -31,7 +27,6 @@ class Transform():
"""
if args:
for k, v in args.items():
# print(k, v)
if k != "self" and not k.startswith("_"):
setattr(self, k, v)
......@@ -39,7 +34,6 @@ class Transform():
raise NotImplementedError
def __call__(self, inputs):
# print('debug:', type(inputs), type(inputs[0]))
if isinstance(inputs, tuple):
inputs = list(inputs)
if self.keys is not None:
......@@ -177,10 +171,6 @@ class RandomHorizontalFlip(Transform):
return img
# import paddle
# paddle.vision.transforms.RandomHorizontalFlip
@TRANSFORMS.register()
class PairedRandomHorizontalFlip(RandomHorizontalFlip):
def __init__(self, prob=0.5, keys=None):
......@@ -271,11 +261,6 @@ class Permute(Transform):
return img
# import paddle
# paddle.vision.transforms.Normalize
# TRANSFORMS.register(T.Normalize)
class Crop():
def __init__(self, pos, size):
self.pos = pos
......
......@@ -35,8 +35,7 @@ class UnpairedDataset(BaseDataset):
btoA = self.cfg.direction == 'BtoA'
input_nc = self.cfg.output_nc if btoA else self.cfg.input_nc # get the number of channels of input image
output_nc = self.cfg.input_nc if btoA else self.cfg.output_nc # get the number of channels of output image
# self.transform_A = get_transform(self.cfg.transform, grayscale=(input_nc == 1))
# self.transform_B = get_transform(self.cfg.transform, grayscale=(output_nc == 1))
self.transform_A = build_transforms(self.cfg.transforms)
self.transform_B = build_transforms(self.cfg.transforms)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册