未验证 提交 202dce6b 编写于 作者: 郑启航 提交者: GitHub

Integrate with visualdl (#115)

* Integrate with visualdl
* visualdl: add doc and control paramter
* visual dl: move import, add install doc
上级 6c81cc35
...@@ -46,6 +46,12 @@ output_dir ...@@ -46,6 +46,12 @@ output_dir
├── epoch002_rec_A.png ├── epoch002_rec_A.png
└── epoch002_rec_B.png └── epoch002_rec_B.png
``` ```
Also, you can add the parameter ```enable_visualdl: true``` in the configuration file, use [PaddlePaddle VisualDL](https://github.com/PaddlePaddle/VisualDL) record the metrics or images generated in the training process, and run the command to monitor the training process:
```
visualdl --logdir output_dir/CycleGANModel-2020-10-29-09-21/
```
#### Recovery of training #### Recovery of training
The checkpoint of the previous epoch will be saved by default during the training process to facilitate the recovery of training The checkpoint of the previous epoch will be saved by default during the training process to facilitate the recovery of training
......
...@@ -59,3 +59,10 @@ If you need to use ppgan to handle video-related tasks, you need to install ffmp ...@@ -59,3 +59,10 @@ If you need to use ppgan to handle video-related tasks, you need to install ffmp
``` ```
conda install x264=='1!152.20180717' ffmpeg=4.0.2 -c conda-forge conda install x264=='1!152.20180717' ffmpeg=4.0.2 -c conda-forge
``` ```
#### 4.2 Visual DL
If you want to use [PaddlePaddle VisualDL](https://github.com/PaddlePaddle/VisualDL) to monitor the training process, Please install `VisualDL`(For more detail refer [here](./get_started.md)):
```
python -m pip install visualdl -i https://mirror.baidu.com/pypi/simple
```
...@@ -45,6 +45,10 @@ output_dir ...@@ -45,6 +45,10 @@ output_dir
├── epoch002_rec_A.png ├── epoch002_rec_A.png
└── epoch002_rec_B.png └── epoch002_rec_B.png
``` ```
同时可以通过在配置文件中添加参数```enable_visualdl: true```使用[飞桨VisualDL](https://github.com/PaddlePaddle/VisualDL)对训练过程产生的指标或生成的图像进行记录,并运行相应命令对训练过程进行实时监控:
```
visualdl --logdir output_dir/CycleGANModel-2020-10-29-09-21/
```
#### 恢复训练 #### 恢复训练
......
...@@ -61,3 +61,11 @@ pip install -v -e . # or "python setup.py develop" ...@@ -61,3 +61,11 @@ pip install -v -e . # or "python setup.py develop"
``` ```
conda install x264=='1!152.20180717' ffmpeg=4.0.2 -c conda-forge conda install x264=='1!152.20180717' ffmpeg=4.0.2 -c conda-forge
``` ```
#### 4.2 Visual DL
如果需要使用[飞桨VisualDL](https://github.com/PaddlePaddle/VisualDL)对训练过程进行可视化监控,请安装`VisualDL`(使用方法请参考[这里](./get_started.md)):
```
python -m pip install visualdl -i https://mirror.baidu.com/pypi/simple
```
...@@ -47,6 +47,10 @@ class Trainer: ...@@ -47,6 +47,10 @@ class Trainer:
self.distributed_data_parallel() self.distributed_data_parallel()
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
self.enable_visualdl = cfg.get('enable_visualdl', False)
if self.enable_visualdl:
import visualdl
self.vdl_logger = visualdl.LogWriter(logdir=cfg.output_dir)
# base config # base config
self.output_dir = cfg.output_dir self.output_dir = cfg.output_dir
...@@ -54,6 +58,7 @@ class Trainer: ...@@ -54,6 +58,7 @@ class Trainer:
self.start_epoch = 1 self.start_epoch = 1
self.current_epoch = 1 self.current_epoch = 1
self.batch_id = 0 self.batch_id = 0
self.global_steps = 0
self.weight_interval = cfg.snapshot_config.interval self.weight_interval = cfg.snapshot_config.interval
self.log_interval = cfg.log_config.interval self.log_interval = cfg.log_config.interval
self.visual_interval = cfg.log_config.visiual_interval self.visual_interval = cfg.log_config.visiual_interval
...@@ -106,7 +111,7 @@ class Trainer: ...@@ -106,7 +111,7 @@ class Trainer:
if i % self.visual_interval == 0: if i % self.visual_interval == 0:
self.visual('visual_train') self.visual('visual_train')
self.global_steps += 1
step_start_time = time.time() step_start_time = time.time()
self.logger.info( self.logger.info(
...@@ -165,7 +170,9 @@ class Trainer: ...@@ -165,7 +170,9 @@ class Trainer:
tensor2img(current_visuals['gt'][j], (0., 1.)), tensor2img(current_visuals['gt'][j], (0., 1.)),
**self.cfg.validate.metrics.ssim) **self.cfg.validate.metrics.ssim)
self.visual('visual_val', visual_results=visual_results) self.visual('visual_val',
visual_results=visual_results,
step=self.batch_id)
if i % self.log_interval == 0: if i % self.log_interval == 0:
self.logger.info('val iter: [%d/%d]' % self.logger.info('val iter: [%d/%d]' %
...@@ -201,7 +208,10 @@ class Trainer: ...@@ -201,7 +208,10 @@ class Trainer:
name = '%s_%s' % (basename, k) name = '%s_%s' % (basename, k)
visual_results.update({name: img_tensor[j]}) visual_results.update({name: img_tensor[j]})
self.visual('visual_test', visual_results=visual_results) self.visual('visual_test',
visual_results=visual_results,
step=self.batch_id,
is_save_image=True)
if i % self.log_interval == 0: if i % self.log_interval == 0:
self.logger.info('Test iter: [%d/%d]' % self.logger.info('Test iter: [%d/%d]' %
...@@ -215,6 +225,8 @@ class Trainer: ...@@ -215,6 +225,8 @@ class Trainer:
for k, v in losses.items(): for k, v in losses.items():
message += '%s: %.3f ' % (k, v) message += '%s: %.3f ' % (k, v)
if self.enable_visualdl:
self.vdl_logger.add_scalar(k, v, step=self.global_steps)
if hasattr(self, 'step_time'): if hasattr(self, 'step_time'):
message += 'batch_cost: %.5f sec ' % self.step_time message += 'batch_cost: %.5f sec ' % self.step_time
...@@ -240,23 +252,45 @@ class Trainer: ...@@ -240,23 +252,45 @@ class Trainer:
for optimizer in self.model.optimizers.values(): for optimizer in self.model.optimizers.values():
return optimizer.get_lr() return optimizer.get_lr()
def visual(self, results_dir, visual_results=None): def visual(self,
results_dir,
visual_results=None,
step=None,
is_save_image=False):
"""
visual the images, use visualdl or directly write to the directory
Parameters:
results_dir (str) -- directory name which contains saved images
visual_results (dict) -- the results images dict
step (int) -- global steps, used in visualdl
is_save_image (bool) -- weather write to the directory or visualdl
"""
self.model.compute_visuals() self.model.compute_visuals()
if visual_results is None: if visual_results is None:
visual_results = self.model.get_current_visuals() visual_results = self.model.get_current_visuals()
min_max = self.cfg.get('min_max', None)
if min_max is None:
min_max = (-1., 1.)
image_num = self.cfg.get('image_num', None)
if (image_num is None) or (not self.enable_visualdl):
image_num = 1
for label, image in visual_results.items():
image_numpy = tensor2img(image, min_max, image_num)
if (not is_save_image) and self.enable_visualdl:
self.vdl_logger.add_image(
results_dir + '/' + label,
image_numpy,
step=step if step else self.global_steps,
dataformats="HWC" if image_num == 1 else "NCHW")
else:
if self.cfg.is_train: if self.cfg.is_train:
msg = 'epoch%.3d_' % self.current_epoch msg = 'epoch%.3d_' % self.current_epoch
else: else:
msg = '' msg = ''
makedirs(os.path.join(self.output_dir, results_dir)) makedirs(os.path.join(self.output_dir, results_dir))
min_max = self.cfg.get('min_max', None)
if min_max is None:
min_max = (-1., 1.)
for label, image in visual_results.items():
image_numpy = tensor2img(image, min_max)
img_path = os.path.join(self.output_dir, results_dir, img_path = os.path.join(self.output_dir, results_dir,
msg + '%s.png' % (label)) msg + '%s.png' % (label))
save_image(image_numpy, img_path) save_image(image_numpy, img_path)
...@@ -299,6 +333,7 @@ class Trainer: ...@@ -299,6 +333,7 @@ class Trainer:
state_dicts = load(checkpoint_path) state_dicts = load(checkpoint_path)
if state_dicts.get('epoch', None) is not None: if state_dicts.get('epoch', None) is not None:
self.start_epoch = state_dicts['epoch'] + 1 self.start_epoch = state_dicts['epoch'] + 1
self.global_steps = self.steps_per_epoch * state_dicts['epoch']
for net_name, net in self.model.nets.items(): for net_name, net in self.model.nets.items():
net.set_state_dict(state_dicts[net_name]) net.set_state_dict(state_dicts[net_name])
...@@ -311,3 +346,11 @@ class Trainer: ...@@ -311,3 +346,11 @@ class Trainer:
for net_name, net in self.model.nets.items(): for net_name, net in self.model.nets.items():
net.set_state_dict(state_dicts[net_name]) net.set_state_dict(state_dicts[net_name])
def close(self):
"""
when finish the training need close file handler or other.
"""
if self.enable_visualdl:
self.vdl_logger.close()
...@@ -18,6 +18,8 @@ import numpy as np ...@@ -18,6 +18,8 @@ import numpy as np
from PIL import Image from PIL import Image
irange = range irange = range
def make_grid(tensor, nrow=8, normalize=False, range=None, scale_each=False): def make_grid(tensor, nrow=8, normalize=False, range=None, scale_each=False):
"""Make a grid of images. """Make a grid of images.
Args: Args:
...@@ -82,35 +84,63 @@ def make_grid(tensor, nrow=8, normalize=False, range=None, scale_each=False): ...@@ -82,35 +84,63 @@ def make_grid(tensor, nrow=8, normalize=False, range=None, scale_each=False):
ymaps = int(math.ceil(float(nmaps) / xmaps)) ymaps = int(math.ceil(float(nmaps) / xmaps))
height, width = int(tensor.shape[2]), int(tensor.shape[3]) height, width = int(tensor.shape[2]), int(tensor.shape[3])
num_channels = tensor.shape[1] num_channels = tensor.shape[1]
canvas = paddle.zeros((num_channels, height * ymaps, width * xmaps), dtype=tensor.dtype) canvas = paddle.zeros((num_channels, height * ymaps, width * xmaps),
dtype=tensor.dtype)
k = 0 k = 0
for y in irange(ymaps): for y in irange(ymaps):
for x in irange(xmaps): for x in irange(xmaps):
if k >= nmaps: if k >= nmaps:
break break
canvas[:, y * height:(y + 1) * height, x * width:(x + 1) * width] = tensor[k] canvas[:, y * height:(y + 1) * height,
x * width:(x + 1) * width] = tensor[k]
k = k + 1 k = k + 1
return canvas return canvas
def tensor2img(input_image, min_max=(-1., 1.), imtype=np.uint8): def tensor2img(input_image, min_max=(-1., 1.), image_num=1, imtype=np.uint8):
""""Converts a Tensor array into a numpy image array. """"Converts a Tensor array into a numpy image array.
Parameters: Parameters:
input_image (tensor) -- the input image tensor array input_image (tensor) -- the input image tensor array
image_num (int) -- the convert iamge numbers
imtype (type) -- the desired type of the converted numpy array imtype (type) -- the desired type of the converted numpy array
""" """
def processing(im, transpose=True):
""""processing one numpy image.
Parameters:
im (tensor) -- the input image numpy array
"""
if im.shape[0] == 1: # grayscale to RGB
im = np.tile(im, (3, 1, 1))
im = im.clip(min_max[0], min_max[1])
im = (im - min_max[0]) / (min_max[1] - min_max[0])
im = im * 255.0 # scaling
im = np.transpose(im, (1, 2, 0)) if transpose else im # tranpose
return im
if not isinstance(input_image, np.ndarray): if not isinstance(input_image, np.ndarray):
image_numpy = input_image.numpy() # convert it into a numpy array image_numpy = input_image.numpy() # convert it into a numpy array
if len(image_numpy.shape) == 4: ndim = image_numpy.ndim
image_numpy = image_numpy[0] if ndim == 4:
if image_numpy.shape[0] == 1: # grayscale to RGB image_numpy = image_numpy[0:image_num]
image_numpy = np.tile(image_numpy, (3, 1, 1)) elif ndim == 3:
image_numpy = image_numpy.clip(min_max[0], min_max[1]) # NOTE for eval mode, need add dim
image_numpy = (image_numpy - min_max[0]) / (min_max[1] - min_max[0]) image_numpy = np.expand_dims(image_numpy, 0)
image_numpy = (np.transpose( image_num = 1
image_numpy, else:
(1, 2, 0))) * 255.0 # post-processing: tranpose and scaling raise ValueError(
"Image numpy ndim is {} not 3 or 4, Please check data".format(
ndim))
if image_num == 1:
# for one image, log HWC image
image_numpy = processing(image_numpy[0])
else:
# for more image, log NCHW image
image_numpy = np.stack(
[processing(im, transpose=False) for im in image_numpy])
else: # if it is a numpy array, do nothing else: # if it is a numpy array, do nothing
image_numpy = input_image image_numpy = input_image
return image_numpy.astype(imtype) return image_numpy.astype(imtype)
......
...@@ -42,8 +42,12 @@ def main(args, cfg): ...@@ -42,8 +42,12 @@ def main(args, cfg):
if args.evaluate_only: if args.evaluate_only:
trainer.test() trainer.test()
return return
# training, when keyboard interrupt save weights
try:
trainer.train() trainer.train()
except KeyboardInterrupt as e:
trainer.save(trainer.current_epoch)
trainer.close()
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册