未验证 提交 89dbb63f 编写于 作者: L LielinJiang 提交者: GitHub

Fix some bugs (#140)

* fix some bugs

* update configs
上级 cd642c08
...@@ -67,7 +67,7 @@ dataset: ...@@ -67,7 +67,7 @@ dataset:
batch_size: 1 batch_size: 1
max_size: inf max_size: inf
is_train: False is_train: False
load_pipeline: preprocess:
- name: LoadImageFromFile - name: LoadImageFromFile
key: A key: A
- name: LoadImageFromFile - name: LoadImageFromFile
......
...@@ -35,7 +35,7 @@ dataset: ...@@ -35,7 +35,7 @@ dataset:
batch_size: 1 batch_size: 1
is_train: True is_train: True
max_size: inf max_size: inf
load_pipeline: preprocess:
- name: LoadImageFromFile - name: LoadImageFromFile
key: A key: A
- name: LoadImageFromFile - name: LoadImageFromFile
...@@ -67,7 +67,7 @@ dataset: ...@@ -67,7 +67,7 @@ dataset:
batch_size: 1 batch_size: 1
max_size: inf max_size: inf
is_train: False is_train: False
load_pipeline: preprocess:
- name: LoadImageFromFile - name: LoadImageFromFile
key: A key: A
- name: LoadImageFromFile - name: LoadImageFromFile
......
...@@ -61,7 +61,7 @@ dataset: ...@@ -61,7 +61,7 @@ dataset:
dataroot: data/cityscapes/test dataroot: data/cityscapes/test
num_workers: 4 num_workers: 4
batch_size: 1 batch_size: 1
load_pipeline: preprocess:
- name: LoadImageFromFile - name: LoadImageFromFile
key: pair key: pair
- name: SplitPairedImage - name: SplitPairedImage
......
...@@ -61,7 +61,7 @@ dataset: ...@@ -61,7 +61,7 @@ dataset:
dataroot: data/cityscapes/test dataroot: data/cityscapes/test
num_workers: 4 num_workers: 4
batch_size: 1 batch_size: 1
load_pipeline: preprocess:
- name: LoadImageFromFile - name: LoadImageFromFile
key: pair key: pair
- name: Transforms - name: Transforms
......
...@@ -61,7 +61,7 @@ dataset: ...@@ -61,7 +61,7 @@ dataset:
dataroot: data/facades/test dataroot: data/facades/test
num_workers: 4 num_workers: 4
batch_size: 1 batch_size: 1
load_pipeline: preprocess:
- name: LoadImageFromFile - name: LoadImageFromFile
key: pair key: pair
- name: Transforms - name: Transforms
......
...@@ -60,7 +60,8 @@ class RealSRPredictor(BasePredictor): ...@@ -60,7 +60,8 @@ class RealSRPredictor(BasePredictor):
img = self.norm(ori_img) img = self.norm(ori_img)
x = paddle.to_tensor(img[np.newaxis, ...]) x = paddle.to_tensor(img[np.newaxis, ...])
out = self.model(x) with paddle.no_grad():
out = self.model(x)
pred_img = self.denorm(out.numpy()[0]) pred_img = self.denorm(out.numpy()[0])
pred_img = Image.fromarray(pred_img) pred_img = Image.fromarray(pred_img)
......
...@@ -124,6 +124,9 @@ class Trainer: ...@@ -124,6 +124,9 @@ class Trainer:
self.weight_interval = cfg.snapshot_config.interval self.weight_interval = cfg.snapshot_config.interval
self.log_interval = cfg.log_config.interval self.log_interval = cfg.log_config.interval
self.visual_interval = cfg.log_config.visiual_interval self.visual_interval = cfg.log_config.visiual_interval
if self.by_epoch:
self.weight_interval *= self.iters_per_epoch
self.validate_interval = -1 self.validate_interval = -1
if cfg.get('validate', None) is not None: if cfg.get('validate', None) is not None:
self.validate_interval = cfg.validate.get('interval', -1) self.validate_interval = cfg.validate.get('interval', -1)
...@@ -177,16 +180,12 @@ class Trainer: ...@@ -177,16 +180,12 @@ class Trainer:
self.model.lr_scheduler.step() self.model.lr_scheduler.step()
if self.by_epoch: if self.validate_interval > -1 and self.current_iter % self.validate_interval == 0:
temp = self.current_epoch
else:
temp = self.current_iter
if self.validate_interval > -1 and temp % self.validate_interval == 0:
self.test() self.test()
if temp % self.weight_interval == 0: if self.current_iter % self.weight_interval == 0:
self.save(temp, 'weight', keep=-1) self.save(self.current_iter, 'weight', keep=-1)
self.save(temp) self.save(self.current_iter)
self.current_iter += 1 self.current_iter += 1
...@@ -335,7 +334,12 @@ class Trainer: ...@@ -335,7 +334,12 @@ class Trainer:
assert name in ['checkpoint', 'weight'] assert name in ['checkpoint', 'weight']
state_dicts = {} state_dicts = {}
save_filename = 'epoch_%s_%s.pdparams' % (epoch, name) if self.by_epoch:
save_filename = 'epoch_%s_%s.pdparams' % (
epoch // self.iters_per_epoch, name)
else:
save_filename = 'iter_%s_%s.pdparams' % (epoch, name)
save_path = os.path.join(self.output_dir, save_filename) save_path = os.path.join(self.output_dir, save_filename)
for net_name, net in self.model.nets.items(): for net_name, net in self.model.nets.items():
state_dicts[net_name] = net.state_dict() state_dicts[net_name] = net.state_dict()
...@@ -353,9 +357,16 @@ class Trainer: ...@@ -353,9 +357,16 @@ class Trainer:
if keep > 0: if keep > 0:
try: try:
checkpoint_name_to_be_removed = os.path.join( if self.by_epoch:
self.output_dir, checkpoint_name_to_be_removed = os.path.join(
'epoch_%s_%s.pdparams' % (epoch - keep, name)) self.output_dir, 'epoch_%s_%s.pdparams' %
((epoch - keep * self.weight_interval) //
self.iters_per_epoch, name))
else:
checkpoint_name_to_be_removed = os.path.join(
self.output_dir, 'iter_%s_%s.pdparams' %
(epoch - keep * self.weight_interval, name))
if os.path.exists(checkpoint_name_to_be_removed): if os.path.exists(checkpoint_name_to_be_removed):
os.remove(checkpoint_name_to_be_removed) os.remove(checkpoint_name_to_be_removed)
...@@ -366,7 +377,7 @@ class Trainer: ...@@ -366,7 +377,7 @@ class Trainer:
state_dicts = load(checkpoint_path) state_dicts = load(checkpoint_path)
if state_dicts.get('epoch', None) is not None: if state_dicts.get('epoch', None) is not None:
self.start_epoch = state_dicts['epoch'] + 1 self.start_epoch = state_dicts['epoch'] + 1
self.global_steps = self.steps_per_epoch * state_dicts['epoch'] self.global_steps = self.iters_per_epoch * state_dicts['epoch']
for net_name, net in self.model.nets.items(): for net_name, net in self.model.nets.items():
net.set_state_dict(state_dicts[net_name]) net.set_state_dict(state_dicts[net_name])
......
...@@ -17,7 +17,8 @@ import argparse ...@@ -17,7 +17,8 @@ import argparse
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description='PaddleGAN') parser = argparse.ArgumentParser(description='PaddleGAN')
parser.add_argument('--config-file', parser.add_argument('-c',
'--config-file',
metavar="FILE", metavar="FILE",
help='config file path') help='config file path')
# cuda setting # cuda setting
......
...@@ -26,8 +26,10 @@ def setup(args, cfg): ...@@ -26,8 +26,10 @@ def setup(args, cfg):
cfg.is_train = True cfg.is_train = True
cfg.timestamp = time.strftime('-%Y-%m-%d-%H-%M', time.localtime()) cfg.timestamp = time.strftime('-%Y-%m-%d-%H-%M', time.localtime())
cfg.output_dir = os.path.join(cfg.output_dir, cfg.output_dir = os.path.join(
str(cfg.model.name) + cfg.timestamp) cfg.output_dir,
os.path.splitext(os.path.basename(str(args.config_file)))[0] +
cfg.timestamp)
logger = setup_logger(cfg.output_dir) logger = setup_logger(cfg.output_dir)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册