Unverified commit 8832a3fa, authored by Tingquan Gao, committed by GitHub

Support Visual DL (#650)

* Support Visual DL

* Fix VDL

* Add doc of VDL, test=document_fix

* Add the en doc of VDL, test=document_fix
Co-authored-by: Nlittletomatodonkey <2120160898@bit.edu.cn>
Parent 1b920e8a
# Use VisualDL to visualize the training
## Preface
VisualDL, the visualization analysis tool of PaddlePaddle, provides a variety of charts that show parameter trends and visualize model structures, data samples, tensor histograms, PR curves, ROC curves, and high-dimensional data distributions. It helps users understand the training process and the model structure more clearly and intuitively so that they can optimize models efficiently. For more information, please refer to [VisualDL](https://github.com/PaddlePaddle/VisualDL/).
## Use VisualDL in PaddleClas
PaddleClas now supports using VisualDL to visualize changes in learning rate, loss, and accuracy during training.
### Set config and start training
You only need to set the `vdl_dir` field in train config:
```yaml
# config.txt
vdl_dir: "./vdl.log"
```
`vdl_dir` specifies the directory where VisualDL stores its logs.
Then start training as usual:
```shell
python3 tools/train.py -c config.txt
```
### Start VisualDL
After starting the training program, you can start the VisualDL service in a new terminal session:
```shell
visualdl --logdir ./vdl.log
```
In the command above, `--logdir` specifies the log directory. VisualDL recursively searches the subdirectories of the specified directory and visualizes all the experiment results it finds. You can also use the following parameters to set the IP address and port of the VisualDL service:
* `--host`: IP address, default is 127.0.0.1
* `--port`: port, default is 8040
For more information about the command, please refer to [VisualDL](https://github.com/PaddlePaddle/VisualDL/blob/develop/README.md#2-launch-panel).
Then enter the address `127.0.0.1:8040` in the browser to view the training process:
<div align="center">
<img src="../../images/VisualDL/train_loss.png" width="400">
</div>
......@@ -38,7 +38,7 @@ Of course, you can also directly modify the configuration file to update the con
epoch:0 train step:13 loss:7.9561 top1:0.0156 top5:0.1094 lr:0.100000 elapse:0.193s
```
During training, you can view loss changes in real time through `VisualDL`, see [VisualDL](https://github.com/PaddlePaddle/VisualDL) for details.
During training, you can view loss changes in real time through `VisualDL`, see [VisualDL](../extension/VisualDL.md) for details.
### 1.2 Model finetuning
......
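# Use VisualDL to visualize the training process
## Preface
VisualDL is PaddlePaddle's visualization analysis tool. It uses a rich set of charts to present trends in training parameters, model structures, data samples, high-dimensional data distributions, and more. It helps users understand the training process and the structure of deep learning models more clearly and intuitively, and thus optimize models efficiently. For more details, see [VisualDL](https://github.com/PaddlePaddle/VisualDL/).
## Use VisualDL in PaddleClas
PaddleClas now supports using VisualDL during training to view changes in learning rate, loss, and accuracy.
### Set the config file and start training
To use VisualDL in PaddleClas, you only need to add the following field to the training configuration file (config file):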
```yaml
# config.txt
vdl_dir: "./vdl.log"
```
`vdl_dir` specifies the directory where VisualDL saves log information.
Then start training as usual:
```shell
python3 tools/train.py -c config.txt
```
### Start VisualDL
After starting the training program, you can start the VisualDL service in a new terminal session:
```shell
visualdl --logdir ./vdl.log
```
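In the command above, the `--logdir` parameter specifies the log directory. VisualDL recursively searches the subdirectories of the specified directory and visualizes all the experiment results it finds. You can also use the following parameters to set the IP address and port of the VisualDL service:
* `--host`: sets the IP address, default is 127.0.0.1
* `--port`: sets the port, default is 8040
For more information about the parameters, see [VisualDL](https://github.com/PaddlePaddle/VisualDL/blob/develop/README_CN.md#2-%E5%90%AF%E5%8A%A8%E9%9D%A2%E6%9D%BF).
After starting VisualDL, enter the address `127.0.0.1:8040` in the browser to view the training process: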
<div align="center">
<img src="../../images/VisualDL/train_loss.png" width="400">
</div>
......@@ -46,7 +46,7 @@ python tools/train.py \
epoch:0 train step:13 loss:7.9561 top1:0.0156 top5:0.1094 lr:0.100000 elapse:0.193s
```
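During training, you can also observe loss changes in real time through VisualDL; see [VisualDL](https://github.com/PaddlePaddle/VisualDL) for details.
During training, you can also observe loss changes in real time through VisualDL; see [VisualDL](../extension/VisualDL.md) for details.
### 1.2 Model finetuning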
......
......@@ -86,7 +86,7 @@ def scaler(name, value, step, writer):
visualdl --logdir ./scalar --host 0.0.0.0 --port 8830
to preview loss curve in real time.
"""
writer.add_scalar(name, value, step)
writer.add_scalar(tag=name, step=step, value=value)
def advertise():
......
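The change to `scaler` above switches the `add_scalar` call from positional to keyword arguments. A minimal sketch of why that matters, using a hypothetical `RecordingWriter` stand-in rather than the real VisualDL `LogWriter`; the positional order `(tag, step, value)` used here is an assumption and may depend on the installed VisualDL version:

```python
class RecordingWriter:
    """Hypothetical stand-in for a VisualDL LogWriter; it only records calls."""

    def __init__(self):
        self.records = []

    # Assumed positional order (tag, step, value); the real order depends on
    # the installed VisualDL version.
    def add_scalar(self, tag, step, value):
        self.records.append({"tag": tag, "step": step, "value": value})


def scaler(name, value, step, writer):
    # Keyword arguments bind by name, so the call stays correct even though
    # this wrapper lists `value` before `step`.
    writer.add_scalar(tag=name, step=step, value=value)


writer = RecordingWriter()
scaler(name="train_loss", value=7.9561, step=13, writer=writer)

# A positional call in the wrapper's own order would swap the two fields.
swapped = RecordingWriter()
swapped.add_scalar("train_loss", 7.9561, 13)
```

With the keyword form, `writer.records[0]` holds step 13 and value 7.9561; the positional call stores them the other way round.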
......@@ -253,13 +253,17 @@ def create_feeds(batch, use_mix):
return feeds
total_step = 0
def run(dataloader,
config,
net,
optimizer=None,
lr_scheduler=None,
epoch=0,
mode='train'):
mode='train',
vdl_writer=None):
"""
Feed data to the model and fetch the measures and loss
......@@ -314,8 +318,8 @@ def run(dataloader,
optimizer.step()
optimizer.clear_grad()
metric_list['lr'].update(
optimizer._global_learning_rate().numpy()[0], batch_size)
lr_value = optimizer._global_learning_rate().numpy()[0]
metric_list['lr'].update(lr_value, batch_size)
if lr_scheduler is not None:
if lr_scheduler.update_specified:
......@@ -333,6 +337,18 @@ def run(dataloader,
metric_list["batch_time"].update(time.time() - tic)
tic = time.time()
        if vdl_writer and mode == "train":
            global total_step
            logger.scaler(
                name="lr", value=lr_value, step=total_step, writer=vdl_writer)
            for name, fetch in fetchs.items():
                logger.scaler(
                    name="train_{}".format(name),
                    value=fetch.numpy()[0],
                    step=total_step,
                    writer=vdl_writer)
            total_step += 1
fetchs_str = ' '.join([
str(metric_list[key].mean)
if "time" in key else str(metric_list[key].value)
......@@ -366,7 +382,6 @@ def run(dataloader,
logger.info("END {:s} {:s} {:s}".format(mode, end_str, ips_info))
else:
end_epoch_str = "END epoch:{:<3d}".format(epoch)
logger.info("{:s} {:s} {:s} {:s}".format(end_epoch_str, mode, end_str,
ips_info))
......
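The hunks above log each scalar against a module-level `total_step` counter rather than a per-epoch batch index. A plausible reason is that `run` is called once per epoch, so a local index would restart at 0 each time and place points from different epochs at the same x-position. A simplified sketch of the pattern, where `run_epoch` and `records` are hypothetical names used only for illustration:

```python
total_step = 0  # module-level counter shared by every call to run_epoch


def run_epoch(losses, records):
    """Hypothetical, simplified stand-in for one training epoch."""
    global total_step
    for loss in losses:
        # Each logged point gets a step that is unique across all epochs.
        records.append((total_step, loss))
        total_step += 1


records = []
run_epoch([7.9, 7.5], records)  # epoch 0
run_epoch([7.1, 6.8], records)  # epoch 1
steps = [step for step, _ in records]
```

After the two calls, the steps continue from one epoch into the next instead of restarting.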
......@@ -83,33 +83,45 @@ def main(args):
last_epoch_id = config.get("last_epoch", -1)
best_top1_acc = 0.0 # best top1 acc record
best_top1_epoch = last_epoch_id
for epoch_id in range(last_epoch_id + 1, config.epochs):
net.train()
# 1. train with train dataset
program.run(train_dataloader, config, net, optimizer, lr_scheduler,
epoch_id, 'train')
# 2. validate with validate dataset
if config.validate and epoch_id % config.valid_interval == 0:
net.eval()
with paddle.no_grad():
top1_acc = program.run(valid_dataloader, config, net, None,
None, epoch_id, 'valid')
if top1_acc > best_top1_acc:
best_top1_acc = top1_acc
best_top1_epoch = epoch_id
vdl_writer_path = config.get("vdl_dir", None)
vdl_writer = None
if vdl_writer_path:
from visualdl import LogWriter
vdl_writer = LogWriter(vdl_writer_path)
# Ensure that the vdl log file can be closed normally
try:
for epoch_id in range(last_epoch_id + 1, config.epochs):
net.train()
# 1. train with train dataset
program.run(train_dataloader, config, net, optimizer, lr_scheduler,
epoch_id, 'train', vdl_writer)
# 2. validate with validate dataset
if config.validate and epoch_id % config.valid_interval == 0:
net.eval()
with paddle.no_grad():
top1_acc = program.run(valid_dataloader, config, net, None,
None, epoch_id, 'valid', vdl_writer)
if top1_acc > best_top1_acc:
best_top1_acc = top1_acc
best_top1_epoch = epoch_id
model_path = os.path.join(config.model_save_dir,
config.ARCHITECTURE["name"])
save_model(net, optimizer, model_path, "best_model")
message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
best_top1_acc, best_top1_epoch)
logger.info(message)
# 3. save the persistable model
if epoch_id % config.save_interval == 0:
model_path = os.path.join(config.model_save_dir,
config.ARCHITECTURE["name"])
save_model(net, optimizer, model_path, "best_model")
message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
best_top1_acc, best_top1_epoch)
logger.info(message)
# 3. save the persistable model
if epoch_id % config.save_interval == 0:
model_path = os.path.join(config.model_save_dir,
config.ARCHITECTURE["name"])
save_model(net, optimizer, model_path, epoch_id)
save_model(net, optimizer, model_path, epoch_id)
except Exception as e:
logger.error(e)
finally:
vdl_writer.close() if vdl_writer else None
if __name__ == '__main__':
......
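The `try`/`finally` wrapper above exists so that the VisualDL log file is closed even if training fails. Note that the `except Exception` branch in the commit logs the error and then suppresses it. The sketch below shows a variant that lets the error propagate while still closing the writer; it is an alternative for illustration, not what the commit does, and `StubWriter` and `train` are hypothetical names:

```python
class StubWriter:
    """Hypothetical stand-in for visualdl.LogWriter; tracks only close()."""

    def __init__(self):
        self.closed = False

    def close(self):
        self.closed = True


def train(writer, fail_at=None, epochs=3):
    """Simplified loop that raises at epoch `fail_at` to simulate a crash."""
    try:
        for epoch_id in range(epochs):
            if epoch_id == fail_at:
                raise RuntimeError(
                    "simulated failure in epoch {}".format(epoch_id))
    finally:
        # Runs on success and on error, so the log file is always closed.
        if writer is not None:
            writer.close()


ok_writer = StubWriter()
train(ok_writer)

bad_writer = StubWriter()
try:
    train(bad_writer, fail_at=1)
    error_seen = False
except RuntimeError:
    error_seen = True
```

Both writers end up closed, and the caller still sees the failure in the second run.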