Unverified commit 202dce6b, authored by 郑启航, committed by GitHub

Integrate with visualdl (#115)

* Integrate with visualdl
* visualdl: add doc and control parameter
* visual dl: move import, add install doc
Parent 6c81cc35
......@@ -46,6 +46,12 @@ output_dir
├── epoch002_rec_A.png
└── epoch002_rec_B.png
```
You can also add the parameter ```enable_visualdl: true``` to the configuration file to use [PaddlePaddle VisualDL](https://github.com/PaddlePaddle/VisualDL) to record the metrics and images generated during training, and then run the following command to monitor the training process:
```
visualdl --logdir output_dir/CycleGANModel-2020-10-29-09-21/
```
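Under the hood, setting ```enable_visualdl: true``` simply makes the trainer create a `visualdl.LogWriter` and mirror scalar losses and intermediate images to it. The sketch below shows that logging flow in isolation; the log directory, tag names, and dummy loss values are illustrative only:
```
import numpy as np
from visualdl import LogWriter

# write event files under the run's output directory so `visualdl --logdir ...` can find them
writer = LogWriter(logdir='output_dir/CycleGANModel-2020-10-29-09-21')

for step in range(100):
    g_loss = float(np.exp(-step / 50.0))            # placeholder metric for the sketch
    writer.add_scalar('g_loss', g_loss, step=step)  # one scalar curve per loss name

# images are logged as HWC uint8 arrays
fake_a = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
writer.add_image('visual_train/fake_A', fake_a, step=0, dataformats='HWC')

writer.close()
```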
#### Recovery of training
The checkpoint of the previous epoch is saved by default during training to make it easy to resume training.
......
......@@ -59,3 +59,10 @@ If you need to use ppgan to handle video-related tasks, you need to install ffmp
```
conda install x264=='1!152.20180717' ffmpeg=4.0.2 -c conda-forge
```
#### 4.2 VisualDL
If you want to use [PaddlePaddle VisualDL](https://github.com/PaddlePaddle/VisualDL) to monitor the training process, please install `VisualDL` first (for usage details, see [here](./get_started.md)):
```
python -m pip install visualdl -i https://mirror.baidu.com/pypi/simple
```
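As an optional sanity check (not part of the required steps), you can verify the installation from Python, assuming the package exposes a `__version__` attribute:
```
import visualdl
print(visualdl.__version__)
```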
......@@ -45,6 +45,10 @@ output_dir
├── epoch002_rec_A.png
└── epoch002_rec_B.png
```
You can also add the parameter ```enable_visualdl: true``` to the configuration file to use [PaddlePaddle VisualDL](https://github.com/PaddlePaddle/VisualDL) to record the metrics and images generated during training, and then run the following command to monitor the training process in real time:
```
visualdl --logdir output_dir/CycleGANModel-2020-10-29-09-21/
```
#### Recovery of training
......
......@@ -61,3 +61,11 @@ pip install -v -e . # or "python setup.py develop"
```
conda install x264=='1!152.20180717' ffmpeg=4.0.2 -c conda-forge
```
#### 4.2 VisualDL
If you want to use [PaddlePaddle VisualDL](https://github.com/PaddlePaddle/VisualDL) to monitor the training process, please install `VisualDL` (for usage details, see [here](./get_started.md)):
```
python -m pip install visualdl -i https://mirror.baidu.com/pypi/simple
```
......@@ -47,6 +47,10 @@ class Trainer:
self.distributed_data_parallel()
self.logger = logging.getLogger(__name__)
self.enable_visualdl = cfg.get('enable_visualdl', False)
if self.enable_visualdl:
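# import VisualDL lazily so it remains an optional dependency when logging is disabled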
import visualdl
self.vdl_logger = visualdl.LogWriter(logdir=cfg.output_dir)
# base config
self.output_dir = cfg.output_dir
......@@ -54,6 +58,7 @@ class Trainer:
self.start_epoch = 1
self.current_epoch = 1
self.batch_id = 0
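# global_steps counts training iterations across epochs and serves as the VisualDL step axis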
self.global_steps = 0
self.weight_interval = cfg.snapshot_config.interval
self.log_interval = cfg.log_config.interval
self.visual_interval = cfg.log_config.visiual_interval
......@@ -106,7 +111,7 @@ class Trainer:
if i % self.visual_interval == 0:
self.visual('visual_train')
self.global_steps += 1
step_start_time = time.time()
self.logger.info(
......@@ -165,7 +170,9 @@ class Trainer:
tensor2img(current_visuals['gt'][j], (0., 1.)),
**self.cfg.validate.metrics.ssim)
self.visual('visual_val', visual_results=visual_results)
self.visual('visual_val',
visual_results=visual_results,
step=self.batch_id)
if i % self.log_interval == 0:
self.logger.info('val iter: [%d/%d]' %
......@@ -201,7 +208,10 @@ class Trainer:
name = '%s_%s' % (basename, k)
visual_results.update({name: img_tensor[j]})
self.visual('visual_test', visual_results=visual_results)
self.visual('visual_test',
visual_results=visual_results,
step=self.batch_id,
is_save_image=True)
if i % self.log_interval == 0:
self.logger.info('Test iter: [%d/%d]' %
......@@ -215,6 +225,8 @@ class Trainer:
for k, v in losses.items():
message += '%s: %.3f ' % (k, v)
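# mirror each scalar loss to VisualDL under its loss name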
if self.enable_visualdl:
self.vdl_logger.add_scalar(k, v, step=self.global_steps)
if hasattr(self, 'step_time'):
message += 'batch_cost: %.5f sec ' % self.step_time
......@@ -240,23 +252,45 @@ class Trainer:
for optimizer in self.model.optimizers.values():
return optimizer.get_lr()
def visual(self, results_dir, visual_results=None):
def visual(self,
results_dir,
visual_results=None,
step=None,
is_save_image=False):
"""
visual the images, use visualdl or directly write to the directory
Parameters:
results_dir (str) -- directory name which contains saved images
visual_results (dict) -- the results images dict
step (int) -- global steps, used in visualdl
is_save_image (bool) -- weather write to the directory or visualdl
"""
self.model.compute_visuals()
if visual_results is None:
visual_results = self.model.get_current_visuals()
min_max = self.cfg.get('min_max', None)
if min_max is None:
min_max = (-1., 1.)
image_num = self.cfg.get('image_num', None)
if (image_num is None) or (not self.enable_visualdl):
image_num = 1
for label, image in visual_results.items():
image_numpy = tensor2img(image, min_max, image_num)
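# log to VisualDL unless is_save_image is set or VisualDL is disabled; single images go as HWC, batches as NCHW grids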
if (not is_save_image) and self.enable_visualdl:
self.vdl_logger.add_image(
results_dir + '/' + label,
image_numpy,
step=step if step else self.global_steps,
dataformats="HWC" if image_num == 1 else "NCHW")
else:
if self.cfg.is_train:
msg = 'epoch%.3d_' % self.current_epoch
else:
msg = ''
makedirs(os.path.join(self.output_dir, results_dir))
min_max = self.cfg.get('min_max', None)
if min_max is None:
min_max = (-1., 1.)
for label, image in visual_results.items():
image_numpy = tensor2img(image, min_max)
img_path = os.path.join(self.output_dir, results_dir,
msg + '%s.png' % (label))
save_image(image_numpy, img_path)
......@@ -299,6 +333,7 @@ class Trainer:
state_dicts = load(checkpoint_path)
if state_dicts.get('epoch', None) is not None:
self.start_epoch = state_dicts['epoch'] + 1
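# recompute the global step counter from the restored epoch so VisualDL logging resumes at the correct step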
self.global_steps = self.steps_per_epoch * state_dicts['epoch']
for net_name, net in self.model.nets.items():
net.set_state_dict(state_dicts[net_name])
......@@ -311,3 +346,11 @@ class Trainer:
for net_name, net in self.model.nets.items():
net.set_state_dict(state_dicts[net_name])
def close(self):
"""
Close file handlers or other resources when training finishes.
"""
if self.enable_visualdl:
self.vdl_logger.close()
......@@ -18,6 +18,8 @@ import numpy as np
from PIL import Image
irange = range
def make_grid(tensor, nrow=8, normalize=False, range=None, scale_each=False):
"""Make a grid of images.
Args:
......@@ -82,35 +84,63 @@ def make_grid(tensor, nrow=8, normalize=False, range=None, scale_each=False):
ymaps = int(math.ceil(float(nmaps) / xmaps))
height, width = int(tensor.shape[2]), int(tensor.shape[3])
num_channels = tensor.shape[1]
canvas = paddle.zeros((num_channels, height * ymaps, width * xmaps), dtype=tensor.dtype)
canvas = paddle.zeros((num_channels, height * ymaps, width * xmaps),
dtype=tensor.dtype)
k = 0
for y in irange(ymaps):
for x in irange(xmaps):
if k >= nmaps:
break
canvas[:, y * height:(y + 1) * height, x * width:(x + 1) * width] = tensor[k]
canvas[:, y * height:(y + 1) * height,
x * width:(x + 1) * width] = tensor[k]
k = k + 1
return canvas
def tensor2img(input_image, min_max=(-1., 1.), imtype=np.uint8):
def tensor2img(input_image, min_max=(-1., 1.), image_num=1, imtype=np.uint8):
""""Converts a Tensor array into a numpy image array.
Parameters:
input_image (tensor) -- the input image tensor array
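min_max (tuple) -- the (min, max) range used to clip the tensor and rescale it to [0, 255]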
image_num (int) -- the number of images to convert
imtype (type) -- the desired type of the converted numpy array
"""
def processing(im, transpose=True):
""""processing one numpy image.
Parameters:
im (tensor) -- the input image numpy array
"""
if im.shape[0] == 1: # grayscale to RGB
im = np.tile(im, (3, 1, 1))
im = im.clip(min_max[0], min_max[1])
im = (im - min_max[0]) / (min_max[1] - min_max[0])
im = im * 255.0 # scaling
im = np.transpose(im, (1, 2, 0)) if transpose else im # transpose CHW -> HWC
return im
if not isinstance(input_image, np.ndarray):
image_numpy = input_image.numpy() # convert it into a numpy array
if len(image_numpy.shape) == 4:
image_numpy = image_numpy[0]
if image_numpy.shape[0] == 1: # grayscale to RGB
image_numpy = np.tile(image_numpy, (3, 1, 1))
image_numpy = image_numpy.clip(min_max[0], min_max[1])
image_numpy = (image_numpy - min_max[0]) / (min_max[1] - min_max[0])
image_numpy = (np.transpose(
image_numpy,
(1, 2, 0))) * 255.0 # post-processing: transpose and scaling
ndim = image_numpy.ndim
if ndim == 4:
image_numpy = image_numpy[0:image_num]
elif ndim == 3:
# NOTE: in eval mode the tensor has no batch dim, so add one
image_numpy = np.expand_dims(image_numpy, 0)
image_num = 1
else:
raise ValueError(
"Image numpy ndim is {} not 3 or 4, Please check data".format(
ndim))
if image_num == 1:
# for a single image, log an HWC image
image_numpy = processing(image_numpy[0])
else:
# for multiple images, log an NCHW batch
image_numpy = np.stack(
[processing(im, transpose=False) for im in image_numpy])
else: # if it is a numpy array, do nothing
image_numpy = input_image
return image_numpy.astype(imtype)
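# Illustrative usage sketch (not part of the original file; the values below are examples only):
# convert a generated NCHW batch in [-1, 1] for direct saving vs. VisualDL logging.
if __name__ == '__main__':
    import paddle
    fake = paddle.rand([4, 3, 64, 64]) * 2 - 1       # NCHW batch roughly in [-1, 1]
    hwc = tensor2img(fake, (-1., 1.))                # image_num defaults to 1 -> first image as HWC uint8
    grid = tensor2img(fake, (-1., 1.), image_num=4)  # keep all four images as NCHW uint8
    print(hwc.shape, grid.shape)                     # (64, 64, 3) (4, 3, 64, 64)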
......
......@@ -42,8 +42,12 @@ def main(args, cfg):
if args.evaluate_only:
trainer.test()
return
# train; on keyboard interrupt, save the current weights before exiting
try:
trainer.train()
except KeyboardInterrupt as e:
trainer.save(trainer.current_epoch)
trainer.close()
if __name__ == '__main__':
......