提交 4760af1c 编写于 作者: L LielinJiang

clean code

上级 c1bd1a2a
...@@ -23,7 +23,7 @@ class PairedDataset(BaseDataset): ...@@ -23,7 +23,7 @@ class PairedDataset(BaseDataset):
cfg.phase) # get the image directory cfg.phase) # get the image directory
self.AB_paths = sorted(make_dataset( self.AB_paths = sorted(make_dataset(
self.dir_AB, cfg.max_dataset_size)) # get image paths self.dir_AB, cfg.max_dataset_size)) # get image paths
# assert(self.cfg.transform.load_size >= self.cfg.transform.crop_size) # crop_size should be smaller than the size of loaded image
self.input_nc = self.cfg.output_nc if self.cfg.direction == 'BtoA' else self.cfg.input_nc self.input_nc = self.cfg.output_nc if self.cfg.direction == 'BtoA' else self.cfg.input_nc
self.output_nc = self.cfg.input_nc if self.cfg.direction == 'BtoA' else self.cfg.output_nc self.output_nc = self.cfg.input_nc if self.cfg.direction == 'BtoA' else self.cfg.output_nc
self.transforms = build_transforms(cfg.transforms) self.transforms = build_transforms(cfg.transforms)
...@@ -53,15 +53,6 @@ class PairedDataset(BaseDataset): ...@@ -53,15 +53,6 @@ class PairedDataset(BaseDataset):
B = AB[:h, w2:, :] B = AB[:h, w2:, :]
# apply the same transform to both A and B # apply the same transform to both A and B
# transform_params = get_params(self.opt, A.size)
# transform_params = get_params(self.cfg.transform, (w2, h))
# A_transform = get_transform(self.cfg.transform, transform_params, grayscale=(self.input_nc == 1))
# B_transform = get_transform(self.cfg.transform, transform_params, grayscale=(self.output_nc == 1))
# A = A_transform(A)
# B = B_transform(B)
# A, B = self.transforms((A, B))
A, B = self.transforms((A, B)) A, B = self.transforms((A, B))
return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path} return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path}
......
...@@ -25,11 +25,6 @@ class Compose(object): ...@@ -25,11 +25,6 @@ class Compose(object):
def __call__(self, data): def __call__(self, data):
for f in self.transforms: for f in self.transforms:
try: try:
# multi-fileds in a sample
# if isinstance(data, Sequence):
# data = f(*data)
# # single field in a sample, call transform directly
# else:
data = f(data) data = f(data)
except Exception as e: except Exception as e:
stack_info = traceback.format_exc() stack_info = traceback.format_exc()
...@@ -39,10 +34,6 @@ class Compose(object): ...@@ -39,10 +34,6 @@ class Compose(object):
return data return data
def build_transform(cfg):
pass
def build_transforms(cfg): def build_transforms(cfg):
transforms = [] transforms = []
......
import sys import sys
import types
import random import random
import numbers import numbers
import warnings
import traceback
import collections import collections
import numpy as np import numpy as np
from paddle.utils import try_import from paddle.utils import try_import
import paddle.vision.transforms.functional as F import paddle.vision.transforms.functional as F
import paddle.vision.transforms.transforms as T
from .builder import TRANSFORMS from .builder import TRANSFORMS
...@@ -31,7 +27,6 @@ class Transform(): ...@@ -31,7 +27,6 @@ class Transform():
""" """
if args: if args:
for k, v in args.items(): for k, v in args.items():
# print(k, v)
if k != "self" and not k.startswith("_"): if k != "self" and not k.startswith("_"):
setattr(self, k, v) setattr(self, k, v)
...@@ -39,7 +34,6 @@ class Transform(): ...@@ -39,7 +34,6 @@ class Transform():
raise NotImplementedError raise NotImplementedError
def __call__(self, inputs): def __call__(self, inputs):
# print('debug:', type(inputs), type(inputs[0]))
if isinstance(inputs, tuple): if isinstance(inputs, tuple):
inputs = list(inputs) inputs = list(inputs)
if self.keys is not None: if self.keys is not None:
...@@ -177,10 +171,6 @@ class RandomHorizontalFlip(Transform): ...@@ -177,10 +171,6 @@ class RandomHorizontalFlip(Transform):
return img return img
# import paddle
# paddle.vision.transforms.RandomHorizontalFlip
@TRANSFORMS.register() @TRANSFORMS.register()
class PairedRandomHorizontalFlip(RandomHorizontalFlip): class PairedRandomHorizontalFlip(RandomHorizontalFlip):
def __init__(self, prob=0.5, keys=None): def __init__(self, prob=0.5, keys=None):
...@@ -271,11 +261,6 @@ class Permute(Transform): ...@@ -271,11 +261,6 @@ class Permute(Transform):
return img return img
# import paddle
# paddle.vision.transforms.Normalize
# TRANSFORMS.register(T.Normalize)
class Crop(): class Crop():
def __init__(self, pos, size): def __init__(self, pos, size):
self.pos = pos self.pos = pos
......
...@@ -35,8 +35,7 @@ class UnpairedDataset(BaseDataset): ...@@ -35,8 +35,7 @@ class UnpairedDataset(BaseDataset):
btoA = self.cfg.direction == 'BtoA' btoA = self.cfg.direction == 'BtoA'
input_nc = self.cfg.output_nc if btoA else self.cfg.input_nc # get the number of channels of input image input_nc = self.cfg.output_nc if btoA else self.cfg.input_nc # get the number of channels of input image
output_nc = self.cfg.input_nc if btoA else self.cfg.output_nc # get the number of channels of output image output_nc = self.cfg.input_nc if btoA else self.cfg.output_nc # get the number of channels of output image
# self.transform_A = get_transform(self.cfg.transform, grayscale=(input_nc == 1))
# self.transform_B = get_transform(self.cfg.transform, grayscale=(output_nc == 1))
self.transform_A = build_transforms(self.cfg.transforms) self.transform_A = build_transforms(self.cfg.transforms)
self.transform_B = build_transforms(self.cfg.transforms) self.transform_B = build_transforms(self.cfg.transforms)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册