diff --git a/docs/en/extension/VisualDL.md b/docs/en/extension/VisualDL.md
new file mode 100644
index 0000000000000000000000000000000000000000..28650b29d396ad7272eb31e2e47f3b187e4288e9
--- /dev/null
+++ b/docs/en/extension/VisualDL.md
@@ -0,0 +1,43 @@
+# Use VisualDL to visualize the training
+
+## Preface
+VisualDL, the visualization and analysis tool of PaddlePaddle, provides a variety of charts to show parameter trends, and visualizes model structures, data samples, tensor histograms, PR curves, ROC curves, and high-dimensional data distributions. It helps users understand the training process and the model structure more clearly and intuitively so that they can optimize models efficiently. For more information, please refer to [VisualDL](https://github.com/PaddlePaddle/VisualDL/).
+
+## Use VisualDL in PaddleClas
+PaddleClas now supports using VisualDL to visualize how the learning rate, loss, and accuracy change during training.
+
+### Set config and start training
+You only need to set the `vdl_dir` field in the training config file:
+
+```yaml
+# config.txt
+vdl_dir: "./vdl.log"
+```
+
+`vdl_dir`: specifies the directory where VisualDL stores its logs.
+
+Then start training as usual:
+
+```shell
+python3 tools/train.py -c config.txt
+```
+
+### Start VisualDL
+After starting the training program, you can start the VisualDL service in a new terminal session:
+
+```shell
+visualdl --logdir ./vdl.log
+```
+
+In the above command, `--logdir` specifies the log directory. VisualDL recursively searches the subdirectories of the specified directory and visualizes all the experiment results it finds. You can also use the following parameters to set the IP address and port number of the VisualDL service:
+
+* `--host`: IP address, default is 127.0.0.1
+* `--port`: port, default is 8040
+
+For more information about the command, please refer to [VisualDL](https://github.com/PaddlePaddle/VisualDL/blob/develop/README.md#2-launch-panel).
+
+Then you can enter the address `127.0.0.1:8040` in the browser and view the training process:
+
+
+
+![](../../images/VisualDL/train_loss.png)
+
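Note on the launch options documented above: the address opened in the browser has to use the same host and port that the service was started with. The following is a minimal Python sketch of that relationship, not part of the patch. The helper names `visualdl_command` and `visualdl_url` are hypothetical and exist only for illustration; the flags `--logdir`, `--host`, and `--port` and their defaults are the ones listed in the document.

```python
import shlex


def visualdl_command(logdir, host="127.0.0.1", port=8040):
    """Build the argv list for starting the VisualDL service (hypothetical helper)."""
    return ["visualdl", "--logdir", logdir, "--host", host, "--port", str(port)]


def visualdl_url(host="127.0.0.1", port=8040):
    """Build the browser address for the same host and port (hypothetical helper)."""
    return "http://{}:{}".format(host, port)


# With no overrides, both helpers fall back to the documented defaults.
cmd = visualdl_command("./vdl.log")
print(" ".join(shlex.quote(arg) for arg in cmd))
print(visualdl_url())
```

If a different `--port` is passed on the command line, the same value has to be used in the address, for example `visualdl_url(port=8080)`.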
diff --git a/docs/en/tutorials/getting_started_en.md b/docs/en/tutorials/getting_started_en.md
index f393f8587a4f748e3add46de583d4b3daeba72c3..1712dc3729106d5591356a5970d4aa9b1f1b744a 100644
--- a/docs/en/tutorials/getting_started_en.md
+++ b/docs/en/tutorials/getting_started_en.md
@@ -38,7 +38,7 @@ Of course, you can also directly modify the configuration file to update the con
epoch:0 train step:13 loss:7.9561 top1:0.0156 top5:0.1094 lr:0.100000 elapse:0.193s
```
-During training, you can view loss changes in real time through `VisualDL`, see [VisualDL](https://github.com/PaddlePaddle/VisualDL) for details.
+During training, you can view loss changes in real time through `VisualDL`; see [VisualDL](../extension/VisualDL.md) for details.
### 1.2 Model finetuning
diff --git a/docs/images/VisualDL/train_loss.png b/docs/images/VisualDL/train_loss.png
new file mode 100644
index 0000000000000000000000000000000000000000..6697707bf56abaed8b1d4f47faad312039c75fbf
Binary files /dev/null and b/docs/images/VisualDL/train_loss.png differ
diff --git a/docs/zh_CN/extension/VisualDL.md b/docs/zh_CN/extension/VisualDL.md
new file mode 100644
index 0000000000000000000000000000000000000000..ea21c5b29e328d23de301b0955cdeb2c4d294852
--- /dev/null
+++ b/docs/zh_CN/extension/VisualDL.md
@@ -0,0 +1,41 @@
+# 使用VisualDL可视化训练过程
+
+## 前言
+VisualDL是飞桨可视化分析工具,以丰富的图表呈现训练参数变化趋势、模型结构、数据样本、高维数据分布等。可帮助用户更清晰直观地理解深度学习模型训练过程及模型结构,进而实现高效的模型优化。更多细节请查看[VisualDL](https://github.com/PaddlePaddle/VisualDL/)。
+
+## 在PaddleClas中使用VisualDL
+现在PaddleClas支持在训练阶段使用VisualDL查看训练过程中学习率(learning rate)、损失值(loss)以及准确率(accuracy)的变化情况。
+
+### 设置config文件并启动训练
+在PaddleClas中使用VisualDL,只需在训练配置文件(config文件)添加如下字段:
+
+```yaml
+# config.txt
+vdl_dir: "./vdl.log"
+```
+`vdl_dir` 用于指定VisualDL用于保存log信息的目录。
+
+然后正常启动训练即可:
+
+```shell
+python3 tools/train.py -c config.txt
+```
+
+### 启动VisualDL
+在启动训练程序后,可以在新的终端session中启动VisualDL服务:
+
+```shell
+visualdl --logdir ./vdl.log
+```
+
+上述命令中,参数`--logdir`用于指定日志目录,VisualDL将遍历并且迭代寻找指定目录的子目录,将所有实验结果进行可视化。也同样可以使用下述参数设定VisualDL服务的ip及端口号:
+* `--host`:设定IP,默认为127.0.0.1
+* `--port`:设定端口,默认为8040
+
+更多参数信息,请查看[VisualDL](https://github.com/PaddlePaddle/VisualDL/blob/develop/README_CN.md#2-%E5%90%AF%E5%8A%A8%E9%9D%A2%E6%9D%BF)。
+
+在启动VisualDL后,即可在浏览器中查看训练过程,输入地址`127.0.0.1:8040`:
+
+
+
+![](../../images/VisualDL/train_loss.png)
+
diff --git a/docs/zh_CN/tutorials/getting_started.md b/docs/zh_CN/tutorials/getting_started.md
index cf953f7ff41b6340008c773cf235f3b251b19220..d7274fc7dead92e7e5c308b1e5912fc78b1c1406 100644
--- a/docs/zh_CN/tutorials/getting_started.md
+++ b/docs/zh_CN/tutorials/getting_started.md
@@ -46,7 +46,7 @@ python tools/train.py \
epoch:0 train step:13 loss:7.9561 top1:0.0156 top5:0.1094 lr:0.100000 elapse:0.193s
```
-训练期间也可以通过VisualDL实时观察loss变化,详见[VisualDL](https://github.com/PaddlePaddle/VisualDL)。
+训练期间也可以通过VisualDL实时观察loss变化,详见[VisualDL](../extension/VisualDL.md)。
### 1.2 模型微调
diff --git a/ppcls/utils/logger.py b/ppcls/utils/logger.py
index 3b032ededcfc6a077211422d6e8f00a379288e79..ece85262446d899a425ac62a0bb1d7a8ff754a50 100644
--- a/ppcls/utils/logger.py
+++ b/ppcls/utils/logger.py
@@ -86,7 +86,7 @@ def scaler(name, value, step, writer):
visualdl --logdir ./scalar --host 0.0.0.0 --port 8830
to preview loss corve in real time.
"""
- writer.add_scalar(name, value, step)
+ writer.add_scalar(tag=name, step=step, value=value)
def advertise():
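Note on the `add_scalar` change above: passing `tag`, `step`, and `value` by keyword makes the call independent of the writer's positional parameter order. The sketch below illustrates this with two hypothetical stand-in writer classes. They are assumptions made for the example and do not reproduce VisualDL's `LogWriter`; only the keyword names `tag`, `step`, and `value` are taken from the diff.

```python
class ValueFirstWriter:
    """Hypothetical writer whose positional order is (tag, value, step)."""

    def __init__(self):
        self.records = []

    def add_scalar(self, tag, value, step):
        self.records.append({"tag": tag, "step": step, "value": value})


class StepFirstWriter:
    """Hypothetical writer whose positional order is (tag, step, value)."""

    def __init__(self):
        self.records = []

    def add_scalar(self, tag, step, value):
        self.records.append({"tag": tag, "step": step, "value": value})


def scaler_keyword(name, value, step, writer):
    # Same call shape as the patched logger: every argument is named.
    writer.add_scalar(tag=name, step=step, value=value)


def scaler_positional(name, value, step, writer):
    # Call shape before the patch: correctness depends on parameter order.
    writer.add_scalar(name, value, step)


keyword_results = []
positional_results = []
for writer_cls in (ValueFirstWriter, StepFirstWriter):
    kw_writer, pos_writer = writer_cls(), writer_cls()
    scaler_keyword("loss", 0.25, 7, kw_writer)
    scaler_positional("loss", 0.25, 7, pos_writer)
    keyword_results.append(kw_writer.records[0])
    positional_results.append(pos_writer.records[0])

# Keyword calls record the same thing for both writers.
print(keyword_results[0] == keyword_results[1])
# Positional calls swap step and value for the step-first writer.
print(positional_results[1])
```

With the keyword form, both stand-ins record step 7 and value 0.25; with the positional form, the step-first stand-in silently records them swapped.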
diff --git a/tools/program.py b/tools/program.py
index 1910475ebe19b0e35e43f8d067abf138150e6eaf..e09789a337305aae47350a82918f7aa932d90028 100644
--- a/tools/program.py
+++ b/tools/program.py
@@ -253,13 +253,17 @@ def create_feeds(batch, use_mix):
return feeds
+total_step = 0
+
+
def run(dataloader,
config,
net,
optimizer=None,
lr_scheduler=None,
epoch=0,
- mode='train'):
+ mode='train',
+ vdl_writer=None):
"""
Feed data to the model and fetch the measures and loss
@@ -314,8 +318,8 @@ def run(dataloader,
optimizer.step()
optimizer.clear_grad()
- metric_list['lr'].update(
- optimizer._global_learning_rate().numpy()[0], batch_size)
+ lr_value = optimizer._global_learning_rate().numpy()[0]
+ metric_list['lr'].update(lr_value, batch_size)
if lr_scheduler is not None:
if lr_scheduler.update_specified:
@@ -333,6 +337,18 @@ def run(dataloader,
         metric_list["batch_time"].update(time.time() - tic)
         tic = time.time()
+        if vdl_writer and mode == "train":
+            global total_step
+            logger.scaler(
+                name="lr", value=lr_value, step=total_step, writer=vdl_writer)
+            for name, fetch in fetchs.items():
+                logger.scaler(
+                    name="train_{}".format(name),
+                    value=fetch.numpy()[0],
+                    step=total_step,
+                    writer=vdl_writer)
+            total_step += 1
+
         fetchs_str = ' '.join([
             str(metric_list[key].mean)
             if "time" in key else str(metric_list[key].value)
@@ -366,7 +382,6 @@ def run(dataloader,
logger.info("END {:s} {:s} {:s}".format(mode, end_str, ips_info))
else:
end_epoch_str = "END epoch:{:<3d}".format(epoch)
-
logger.info("{:s} {:s} {:s} {:s}".format(end_epoch_str, mode, end_str,
ips_info))
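Note on the module-level `total_step` above: because the counter lives outside `run`, the step axis keeps increasing across epochs instead of restarting at zero, and it only advances in train mode. A minimal sketch of that behaviour follows. `RecordingWriter` and `run_epoch` are hypothetical stand-ins for the VisualDL writer and for `program.run`; they are not the real implementations.

```python
class RecordingWriter:
    """Hypothetical stand-in for a VisualDL writer; it only stores what it is given."""

    def __init__(self):
        self.records = []

    def add_scalar(self, tag, step, value):
        self.records.append((tag, step, value))


# Module-level counter shared by every call, as in the patch.
total_step = 0


def run_epoch(losses, writer=None, mode="train"):
    """Simplified stand-in for program.run: one loss value per batch."""
    global total_step
    for loss in losses:
        if writer and mode == "train":
            writer.add_scalar(tag="train_loss", step=total_step, value=loss)
            total_step += 1


writer = RecordingWriter()
run_epoch([0.9, 0.8, 0.7], writer)          # epoch 0, three batches
run_epoch([0.6, 0.5], writer)               # epoch 1, two batches
run_epoch([0.4], writer, mode="valid")      # validation is not logged

steps = [step for _, step, _ in writer.records]
print(steps)
```

The recorded steps run from 0 to 4 without a reset between the two training epochs, and the validation pass adds nothing.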
diff --git a/tools/train.py b/tools/train.py
index 45ac37f9e71a873d0e5bbe1da5f51fd03d8c76e9..0da14bf0ddb85439b0c820950af33055e55265f2 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -83,33 +83,45 @@ def main(args):
     last_epoch_id = config.get("last_epoch", -1)
     best_top1_acc = 0.0  # best top1 acc record
     best_top1_epoch = last_epoch_id
-    for epoch_id in range(last_epoch_id + 1, config.epochs):
-        net.train()
-        # 1. train with train dataset
-        program.run(train_dataloader, config, net, optimizer, lr_scheduler,
-                    epoch_id, 'train')
-
-        # 2. validate with validate dataset
-        if config.validate and epoch_id % config.valid_interval == 0:
-            net.eval()
-            with paddle.no_grad():
-                top1_acc = program.run(valid_dataloader, config, net, None,
-                                       None, epoch_id, 'valid')
-            if top1_acc > best_top1_acc:
-                best_top1_acc = top1_acc
-                best_top1_epoch = epoch_id
+
+    vdl_writer_path = config.get("vdl_dir", None)
+    vdl_writer = None
+    if vdl_writer_path:
+        from visualdl import LogWriter
+        vdl_writer = LogWriter(vdl_writer_path)
+    # Ensure that the vdl log file can be closed normally
+    try:
+        for epoch_id in range(last_epoch_id + 1, config.epochs):
+            net.train()
+            # 1. train with train dataset
+            program.run(train_dataloader, config, net, optimizer, lr_scheduler,
+                        epoch_id, 'train', vdl_writer)
+
+            # 2. validate with validate dataset
+            if config.validate and epoch_id % config.valid_interval == 0:
+                net.eval()
+                with paddle.no_grad():
+                    top1_acc = program.run(valid_dataloader, config, net, None,
+                                           None, epoch_id, 'valid', vdl_writer)
+                if top1_acc > best_top1_acc:
+                    best_top1_acc = top1_acc
+                    best_top1_epoch = epoch_id
+                    model_path = os.path.join(config.model_save_dir,
+                                              config.ARCHITECTURE["name"])
+                    save_model(net, optimizer, model_path, "best_model")
+                message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
+                    best_top1_acc, best_top1_epoch)
+                logger.info(message)
+
+            # 3. save the persistable model
+            if epoch_id % config.save_interval == 0:
                 model_path = os.path.join(config.model_save_dir,
                                           config.ARCHITECTURE["name"])
-                save_model(net, optimizer, model_path, "best_model")
-            message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
-                best_top1_acc, best_top1_epoch)
-            logger.info(message)
-
-        # 3. save the persistable model
-        if epoch_id % config.save_interval == 0:
-            model_path = os.path.join(config.model_save_dir,
-                                      config.ARCHITECTURE["name"])
-            save_model(net, optimizer, model_path, epoch_id)
+                save_model(net, optimizer, model_path, epoch_id)
+    except Exception as e:
+        logger.error(e)
+    finally:
+        vdl_writer.close() if vdl_writer else None
if __name__ == '__main__':
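Note on the `try`/`except`/`finally` block above: the `finally` clause is what guarantees the writer is closed, while the `except` clause logs the error and lets `main` return normally. The sketch below shows one possible variation that keeps the cleanup but re-raises the error so the caller still sees the failure. This is an assumption about what may be desirable, not what the patch does. `FakeWriter` and `train` are hypothetical stand-ins.

```python
import logging


class FakeWriter:
    """Hypothetical stand-in for a log writer that tracks whether it was closed."""

    def __init__(self):
        self.closed = False

    def close(self):
        self.closed = True


def train(epochs, writer=None, fail_at=None):
    """Simplified training loop; raises at epoch `fail_at` to simulate a crash."""
    try:
        for epoch in range(epochs):
            if epoch == fail_at:
                raise RuntimeError("simulated failure in epoch {}".format(epoch))
    except Exception:
        # Log with traceback, then propagate so the process can exit non-zero.
        logging.exception("training failed")
        raise
    finally:
        # Runs on success and on failure, so the log file is always closed.
        if writer is not None:
            writer.close()


ok_writer = FakeWriter()
train(3, writer=ok_writer)

failed_writer = FakeWriter()
try:
    train(3, writer=failed_writer, fail_at=1)
except RuntimeError:
    pass

print(ok_writer.closed, failed_writer.closed)
```

In both the successful and the failing run the writer ends up closed; the only difference from the patch is that the failure remains visible to the caller.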