From 492f2e47b7d32617feb9f99e4c6c9a299a3d5b4b Mon Sep 17 00:00:00 2001 From: jiangjiajun Date: Fri, 15 May 2020 18:37:08 +0800 Subject: [PATCH] upgrade visualdl to 2.0.0a2 --- paddlex/cv/models/base.py | 19 +++---------------- setup.py | 2 +- 2 files changed, 4 insertions(+), 17 deletions(-) diff --git a/paddlex/cv/models/base.py b/paddlex/cv/models/base.py index f0a52fa..0fd06b6 100644 --- a/paddlex/cv/models/base.py +++ b/paddlex/cv/models/base.py @@ -428,9 +428,7 @@ class BaseAPI: if use_vdl: # VisualDL component - log_writer = LogWriter(vdl_logdir, sync_cycle=20) - train_step_component = OrderedDict() - eval_component = OrderedDict() + log_writer = LogWriter(vdl_logdir) thresh = 0.0001 if early_stop: @@ -468,13 +466,7 @@ class BaseAPI: if use_vdl: for k, v in step_metrics.items(): - if k not in train_step_component.keys(): - with log_writer.mode('Each_Step_while_Training' - ) as step_logger: - train_step_component[ - k] = step_logger.scalar( - 'Training: {}'.format(k)) - train_step_component[k].add_record(num_steps, v) + log_writer.add_scalar('Metrics/Training(Step): {}'.format(k), v, num_steps) # 估算剩余时间 avg_step_time = np.mean(time_stat) @@ -535,12 +527,7 @@ class BaseAPI: if isinstance(v, np.ndarray): if v.size > 1: continue - if k not in eval_component: - with log_writer.mode('Each_Epoch_on_Eval_Data' - ) as eval_logger: - eval_component[k] = eval_logger.scalar( - 'Evaluation: {}'.format(k)) - eval_component[k].add_record(i + 1, v) + log_writer.add_scalar("Metrics/Eval(Epoch): {}".format(k), v, i+1) self.save_model(save_dir=current_save_dir) time_eval_one_epoch = time.time() - eval_epoch_start_time eval_epoch_start_time = time.time() diff --git a/setup.py b/setup.py index 01e778b..f8d54dc 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,7 @@ setuptools.setup( setup_requires=['cython', 'numpy', 'sklearn'], install_requires=[ "pycocotools;platform_system!='Windows'", 'pyyaml', 'colorama', 'tqdm', - 'visualdl==1.3.0', 'paddleslim==1.0.1' + 'visualdl==1.3.0', 'paddleslim==1.0.1', 'visualdl==2.0.0a2' ], classifiers=[ "Programming Language :: Python :: 3", -- GitLab