From 202dce6b79f8361944e0ffc3feb04ac6bf329ae9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com> Date: Wed, 16 Dec 2020 10:29:51 +0800 Subject: [PATCH] Integrate with visualdl (#115) * Integrate with visualdl * visualdl: add doc and control parameter * visual dl: move import, add install doc --- docs/en_US/get_started.md | 6 ++++ docs/en_US/install.md | 9 ++++- docs/zh_CN/get_started.md | 4 +++ docs/zh_CN/install.md | 8 +++++ ppgan/engine/trainer.py | 71 +++++++++++++++++++++++++++++++-------- ppgan/utils/visual.py | 54 ++++++++++++++++++++++------- tools/main.py | 8 +++-- 7 files changed, 131 insertions(+), 29 deletions(-) diff --git a/docs/en_US/get_started.md b/docs/en_US/get_started.md index 474a8fd..5a3fe7c 100644 --- a/docs/en_US/get_started.md +++ b/docs/en_US/get_started.md @@ -46,6 +46,12 @@ output_dir ├── epoch002_rec_A.png └── epoch002_rec_B.png ``` + +Also, you can add the parameter ```enable_visualdl: true``` to the configuration file to use [PaddlePaddle VisualDL](https://github.com/PaddlePaddle/VisualDL) to record the metrics or images generated in the training process, and then run the following command to monitor the training process: +``` +visualdl --logdir output_dir/CycleGANModel-2020-10-29-09-21/ +``` + #### Recovery of training The checkpoint of the previous epoch will be saved by default during the training process to facilitate the recovery of training diff --git a/docs/en_US/install.md b/docs/en_US/install.md index 94de4f6..7b2abc3 100644 --- a/docs/en_US/install.md +++ b/docs/en_US/install.md @@ -26,7 +26,7 @@ Note: command above will install paddle with cuda10.2,if your installed cuda i Visit home page of [paddlepaddle](https://www.paddlepaddle.org.cn/install/quick) for support of other systems, such as Windows10. -### 2. Install paddleGAN +### 2. 
Install paddleGAN #### 2.1 Install through pip @@ -59,3 +59,10 @@ If you need to use ppgan to handle video-related tasks, you need to install ffmp ``` conda install x264=='1!152.20180717' ffmpeg=4.0.2 -c conda-forge ``` + +#### 4.2 Visual DL +If you want to use [PaddlePaddle VisualDL](https://github.com/PaddlePaddle/VisualDL) to monitor the training process, please install `VisualDL` (for more details, refer to [here](./get_started.md)): + +``` +python -m pip install visualdl -i https://mirror.baidu.com/pypi/simple +``` diff --git a/docs/zh_CN/get_started.md b/docs/zh_CN/get_started.md index 6692682..69d6ec8 100644 --- a/docs/zh_CN/get_started.md +++ b/docs/zh_CN/get_started.md @@ -45,6 +45,10 @@ output_dir ├── epoch002_rec_A.png └── epoch002_rec_B.png ``` +同时可以通过在配置文件中添加参数```enable_visualdl: true```使用[飞桨VisualDL](https://github.com/PaddlePaddle/VisualDL)对训练过程产生的指标或生成的图像进行记录,并运行相应命令对训练过程进行实时监控: +``` +visualdl --logdir output_dir/CycleGANModel-2020-10-29-09-21/ +``` #### 恢复训练 diff --git a/docs/zh_CN/install.md b/docs/zh_CN/install.md index 70c4205..2321f63 100644 --- a/docs/zh_CN/install.md +++ b/docs/zh_CN/install.md @@ -61,3 +61,11 @@ pip install -v -e . 
# or "python setup.py develop" ``` conda install x264=='1!152.20180717' ffmpeg=4.0.2 -c conda-forge ``` + +#### 4.2 Visual DL + +如果需要使用[飞桨VisualDL](https://github.com/PaddlePaddle/VisualDL)对训练过程进行可视化监控,请安装`VisualDL`(使用方法请参考[这里](./get_started.md)): + +``` +python -m pip install visualdl -i https://mirror.baidu.com/pypi/simple +``` diff --git a/ppgan/engine/trainer.py b/ppgan/engine/trainer.py index 375af05..7c8a472 100644 --- a/ppgan/engine/trainer.py +++ b/ppgan/engine/trainer.py @@ -47,6 +47,10 @@ class Trainer: self.distributed_data_parallel() self.logger = logging.getLogger(__name__) + self.enable_visualdl = cfg.get('enable_visualdl', False) + if self.enable_visualdl: + import visualdl + self.vdl_logger = visualdl.LogWriter(logdir=cfg.output_dir) # base config self.output_dir = cfg.output_dir @@ -54,6 +58,7 @@ class Trainer: self.start_epoch = 1 self.current_epoch = 1 self.batch_id = 0 + self.global_steps = 0 self.weight_interval = cfg.snapshot_config.interval self.log_interval = cfg.log_config.interval self.visual_interval = cfg.log_config.visiual_interval @@ -106,7 +111,7 @@ class Trainer: if i % self.visual_interval == 0: self.visual('visual_train') - + self.global_steps += 1 step_start_time = time.time() self.logger.info( @@ -165,7 +170,9 @@ class Trainer: tensor2img(current_visuals['gt'][j], (0., 1.)), **self.cfg.validate.metrics.ssim) - self.visual('visual_val', visual_results=visual_results) + self.visual('visual_val', + visual_results=visual_results, + step=self.batch_id) if i % self.log_interval == 0: self.logger.info('val iter: [%d/%d]' % @@ -201,7 +208,10 @@ class Trainer: name = '%s_%s' % (basename, k) visual_results.update({name: img_tensor[j]}) - self.visual('visual_test', visual_results=visual_results) + self.visual('visual_test', + visual_results=visual_results, + step=self.batch_id, + is_save_image=True) if i % self.log_interval == 0: self.logger.info('Test iter: [%d/%d]' % @@ -215,6 +225,8 @@ class Trainer: for k, v in losses.items(): message 
+= '%s: %.3f ' % (k, v) + if self.enable_visualdl: + self.vdl_logger.add_scalar(k, v, step=self.global_steps) if hasattr(self, 'step_time'): message += 'batch_cost: %.5f sec ' % self.step_time @@ -240,26 +252,48 @@ class Trainer: for optimizer in self.model.optimizers.values(): return optimizer.get_lr() - def visual(self, results_dir, visual_results=None): + def visual(self, + results_dir, + visual_results=None, + step=None, + is_save_image=False): + """ + visual the images, use visualdl or directly write to the directory + + Parameters: + results_dir (str) -- directory name which contains saved images + visual_results (dict) -- the results images dict + step (int) -- global steps, used in visualdl + is_save_image (bool) -- weather write to the directory or visualdl + """ self.model.compute_visuals() if visual_results is None: visual_results = self.model.get_current_visuals() - if self.cfg.is_train: - msg = 'epoch%.3d_' % self.current_epoch - else: - msg = '' - - makedirs(os.path.join(self.output_dir, results_dir)) min_max = self.cfg.get('min_max', None) if min_max is None: min_max = (-1., 1.) 
+ image_num = self.cfg.get('image_num', None) + if (image_num is None) or (not self.enable_visualdl): + image_num = 1 for label, image in visual_results.items(): - image_numpy = tensor2img(image, min_max) - img_path = os.path.join(self.output_dir, results_dir, - msg + '%s.png' % (label)) - save_image(image_numpy, img_path) + image_numpy = tensor2img(image, min_max, image_num) + if (not is_save_image) and self.enable_visualdl: + self.vdl_logger.add_image( + results_dir + '/' + label, + image_numpy, + step=step if step else self.global_steps, + dataformats="HWC" if image_num == 1 else "NCHW") + else: + if self.cfg.is_train: + msg = 'epoch%.3d_' % self.current_epoch + else: + msg = '' + makedirs(os.path.join(self.output_dir, results_dir)) + img_path = os.path.join(self.output_dir, results_dir, + msg + '%s.png' % (label)) + save_image(image_numpy, img_path) def save(self, epoch, name='checkpoint', keep=1): if self.local_rank != 0: @@ -299,6 +333,7 @@ class Trainer: state_dicts = load(checkpoint_path) if state_dicts.get('epoch', None) is not None: self.start_epoch = state_dicts['epoch'] + 1 + self.global_steps = self.steps_per_epoch * state_dicts['epoch'] for net_name, net in self.model.nets.items(): net.set_state_dict(state_dicts[net_name]) @@ -311,3 +346,11 @@ class Trainer: for net_name, net in self.model.nets.items(): net.set_state_dict(state_dicts[net_name]) + + def close(self): + """ + when finish the training need close file handler or other. + + """ + if self.enable_visualdl: + self.vdl_logger.close() diff --git a/ppgan/utils/visual.py b/ppgan/utils/visual.py index ee20739..49d9c30 100644 --- a/ppgan/utils/visual.py +++ b/ppgan/utils/visual.py @@ -18,6 +18,8 @@ import numpy as np from PIL import Image irange = range + + def make_grid(tensor, nrow=8, normalize=False, range=None, scale_each=False): """Make a grid of images. 
Args: @@ -82,35 +84,63 @@ def make_grid(tensor, nrow=8, normalize=False, range=None, scale_each=False): ymaps = int(math.ceil(float(nmaps) / xmaps)) height, width = int(tensor.shape[2]), int(tensor.shape[3]) num_channels = tensor.shape[1] - canvas = paddle.zeros((num_channels, height * ymaps, width * xmaps), dtype=tensor.dtype) + canvas = paddle.zeros((num_channels, height * ymaps, width * xmaps), + dtype=tensor.dtype) k = 0 for y in irange(ymaps): for x in irange(xmaps): if k >= nmaps: break - canvas[:, y * height:(y + 1) * height, x * width:(x + 1) * width] = tensor[k] + canvas[:, y * height:(y + 1) * height, + x * width:(x + 1) * width] = tensor[k] k = k + 1 return canvas -def tensor2img(input_image, min_max=(-1., 1.), imtype=np.uint8): +def tensor2img(input_image, min_max=(-1., 1.), image_num=1, imtype=np.uint8): """"Converts a Tensor array into a numpy image array. Parameters: input_image (tensor) -- the input image tensor array + image_num (int) -- the convert iamge numbers imtype (type) -- the desired type of the converted numpy array """ + def processing(im, transpose=True): + """"processing one numpy image. 
+ + Parameters: + im (tensor) -- the input image numpy array + """ + if im.shape[0] == 1: # grayscale to RGB + im = np.tile(im, (3, 1, 1)) + im = im.clip(min_max[0], min_max[1]) + im = (im - min_max[0]) / (min_max[1] - min_max[0]) + im = im * 255.0 # scaling + im = np.transpose(im, (1, 2, 0)) if transpose else im # tranpose + return im + if not isinstance(input_image, np.ndarray): image_numpy = input_image.numpy() # convert it into a numpy array - if len(image_numpy.shape) == 4: - image_numpy = image_numpy[0] - if image_numpy.shape[0] == 1: # grayscale to RGB - image_numpy = np.tile(image_numpy, (3, 1, 1)) - image_numpy = image_numpy.clip(min_max[0], min_max[1]) - image_numpy = (image_numpy - min_max[0]) / (min_max[1] - min_max[0]) - image_numpy = (np.transpose( - image_numpy, - (1, 2, 0))) * 255.0 # post-processing: tranpose and scaling + ndim = image_numpy.ndim + if ndim == 4: + image_numpy = image_numpy[0:image_num] + elif ndim == 3: + # NOTE for eval mode, need add dim + image_numpy = np.expand_dims(image_numpy, 0) + image_num = 1 + else: + raise ValueError( + "Image numpy ndim is {} not 3 or 4, Please check data".format( + ndim)) + + if image_num == 1: + # for one image, log HWC image + image_numpy = processing(image_numpy[0]) + else: + # for more image, log NCHW image + image_numpy = np.stack( + [processing(im, transpose=False) for im in image_numpy]) + else: # if it is a numpy array, do nothing image_numpy = input_image return image_numpy.astype(imtype) diff --git a/tools/main.py b/tools/main.py index 68edba4..0918d52 100644 --- a/tools/main.py +++ b/tools/main.py @@ -42,8 +42,12 @@ def main(args, cfg): if args.evaluate_only: trainer.test() return - - trainer.train() + # training, when keyboard interrupt save weights + try: + trainer.train() + except KeyboardInterrupt as e: + trainer.save(trainer.current_epoch) + trainer.close() if __name__ == '__main__': -- GitLab