未验证 提交 89dbb63f 编写于 作者: L LielinJiang 提交者: GitHub

Fix some bugs (#140)

* fix some bugs

* update configs
上级 cd642c08
......@@ -67,7 +67,7 @@ dataset:
batch_size: 1
max_size: inf
is_train: False
load_pipeline:
preprocess:
- name: LoadImageFromFile
key: A
- name: LoadImageFromFile
......
......@@ -35,7 +35,7 @@ dataset:
batch_size: 1
is_train: True
max_size: inf
load_pipeline:
preprocess:
- name: LoadImageFromFile
key: A
- name: LoadImageFromFile
......@@ -67,7 +67,7 @@ dataset:
batch_size: 1
max_size: inf
is_train: False
load_pipeline:
preprocess:
- name: LoadImageFromFile
key: A
- name: LoadImageFromFile
......
......@@ -61,7 +61,7 @@ dataset:
dataroot: data/cityscapes/test
num_workers: 4
batch_size: 1
load_pipeline:
preprocess:
- name: LoadImageFromFile
key: pair
- name: SplitPairedImage
......
......@@ -61,7 +61,7 @@ dataset:
dataroot: data/cityscapes/test
num_workers: 4
batch_size: 1
load_pipeline:
preprocess:
- name: LoadImageFromFile
key: pair
- name: Transforms
......
......@@ -61,7 +61,7 @@ dataset:
dataroot: data/facades/test
num_workers: 4
batch_size: 1
load_pipeline:
preprocess:
- name: LoadImageFromFile
key: pair
- name: Transforms
......
......@@ -60,6 +60,7 @@ class RealSRPredictor(BasePredictor):
img = self.norm(ori_img)
x = paddle.to_tensor(img[np.newaxis, ...])
with paddle.no_grad():
out = self.model(x)
pred_img = self.denorm(out.numpy()[0])
......
......@@ -124,6 +124,9 @@ class Trainer:
self.weight_interval = cfg.snapshot_config.interval
self.log_interval = cfg.log_config.interval
self.visual_interval = cfg.log_config.visiual_interval
if self.by_epoch:
self.weight_interval *= self.iters_per_epoch
self.validate_interval = -1
if cfg.get('validate', None) is not None:
self.validate_interval = cfg.validate.get('interval', -1)
......@@ -177,16 +180,12 @@ class Trainer:
self.model.lr_scheduler.step()
if self.by_epoch:
temp = self.current_epoch
else:
temp = self.current_iter
if self.validate_interval > -1 and temp % self.validate_interval == 0:
if self.validate_interval > -1 and self.current_iter % self.validate_interval == 0:
self.test()
if temp % self.weight_interval == 0:
self.save(temp, 'weight', keep=-1)
self.save(temp)
if self.current_iter % self.weight_interval == 0:
self.save(self.current_iter, 'weight', keep=-1)
self.save(self.current_iter)
self.current_iter += 1
......@@ -335,7 +334,12 @@ class Trainer:
assert name in ['checkpoint', 'weight']
state_dicts = {}
save_filename = 'epoch_%s_%s.pdparams' % (epoch, name)
if self.by_epoch:
save_filename = 'epoch_%s_%s.pdparams' % (
epoch // self.iters_per_epoch, name)
else:
save_filename = 'iter_%s_%s.pdparams' % (epoch, name)
save_path = os.path.join(self.output_dir, save_filename)
for net_name, net in self.model.nets.items():
state_dicts[net_name] = net.state_dict()
......@@ -353,9 +357,16 @@ class Trainer:
if keep > 0:
try:
if self.by_epoch:
checkpoint_name_to_be_removed = os.path.join(
self.output_dir,
'epoch_%s_%s.pdparams' % (epoch - keep, name))
self.output_dir, 'epoch_%s_%s.pdparams' %
((epoch - keep * self.weight_interval) //
self.iters_per_epoch, name))
else:
checkpoint_name_to_be_removed = os.path.join(
self.output_dir, 'iter_%s_%s.pdparams' %
(epoch - keep * self.weight_interval, name))
if os.path.exists(checkpoint_name_to_be_removed):
os.remove(checkpoint_name_to_be_removed)
......@@ -366,7 +377,7 @@ class Trainer:
state_dicts = load(checkpoint_path)
if state_dicts.get('epoch', None) is not None:
self.start_epoch = state_dicts['epoch'] + 1
self.global_steps = self.steps_per_epoch * state_dicts['epoch']
self.global_steps = self.iters_per_epoch * state_dicts['epoch']
for net_name, net in self.model.nets.items():
net.set_state_dict(state_dicts[net_name])
......
......@@ -17,7 +17,8 @@ import argparse
def parse_args():
parser = argparse.ArgumentParser(description='PaddleGAN')
parser.add_argument('--config-file',
parser.add_argument('-c',
'--config-file',
metavar="FILE",
help='config file path')
# cuda setting
......
......@@ -26,8 +26,10 @@ def setup(args, cfg):
cfg.is_train = True
cfg.timestamp = time.strftime('-%Y-%m-%d-%H-%M', time.localtime())
cfg.output_dir = os.path.join(cfg.output_dir,
str(cfg.model.name) + cfg.timestamp)
cfg.output_dir = os.path.join(
cfg.output_dir,
os.path.splitext(os.path.basename(str(args.config_file)))[0] +
cfg.timestamp)
logger = setup_logger(cfg.output_dir)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册