diff --git a/README.md b/README.md
index 259ccb5aa02352ca2a2b81bf81d858cec2b47081..7ff799db308c72530b825bafc53973f707be389b 100644
--- a/README.md
+++ b/README.md
@@ -33,6 +33,8 @@ PaddleOCR aims to create multilingual, awesome, leading, and practical OCR tools
PaddleOCR support a variety of cutting-edge algorithms related to OCR, and developed industrial featured models/solution [PP-OCR](./doc/doc_en/ppocr_introduction_en.md) and [PP-Structure](./ppstructure/README.md) on this basis, and get through the whole process of data production, model training, compression, inference and deployment.
+PaddleOCR also supports metric and model logging during training to [VisualDL](https://www.paddlepaddle.org.cn/documentation/docs/en/guides/03_VisualDL/visualdl_usage_en.html) and [Weights & Biases](https://docs.wandb.ai/).
+
![](./doc/features_en.png)
> It is recommended to start with the “quick experience” in the document tutorial
diff --git a/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml b/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml
index 773a3649d8378cb39373b5b90837f17f9ecba335..e7cbae59a14af73639e1a74a14021b9b2ef60057 100644
--- a/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml
+++ b/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml
@@ -129,7 +129,7 @@ Loss:
key: head_out
multi_head: True
- DistillationSARLoss:
- weight: 0.5
+ weight: 1.0
model_name_list: ["Student", "Teacher"]
key: head_out
multi_head: True
diff --git a/doc/doc_ch/models_list.md b/doc/doc_ch/models_list.md
index 05f78cf59d049c220bebeab36e5b78949cd8c1da..14b2543cdc86886a6597f96295708a77018f0fc1 100644
--- a/doc/doc_ch/models_list.md
+++ b/doc/doc_ch/models_list.md
@@ -119,6 +119,7 @@ PaddleOCR提供的可下载模型包括`推理模型`、`训练模型`、`预训
| devanagari_PP-OCRv3_rec | ppocr/utils/dict/devanagari_dict.txt |梵文字母 | [devanagari_PP-OCRv3_rec.yml](../../configs/rec/PP-OCRv3/multi_language/devanagari_PP-OCRv3_rec.yml) |9.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_PP-OCRv3_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_PP-OCRv3_rec_train.tar) |
+
更多支持语种请参考: [多语言模型](./multi_languages.md)
diff --git a/doc/doc_en/config_en.md b/doc/doc_en/config_en.md
index 68c2b5f0c14f0c9b09d854f5a8b33ca86cc4bdf7..d467a7f918ed57eb80754483715f3671fd2552c7 100644
--- a/doc/doc_en/config_en.md
+++ b/doc/doc_en/config_en.md
@@ -36,6 +36,7 @@ Take rec_chinese_lite_train_v2.0.yml as an example
| pretrained_model | Set the path of the pre-trained model | ./pretrain_models/CRNN/best_accuracy | \ |
| checkpoints | set model parameter path | None | Used to load parameters after interruption to continue training|
| use_visualdl | Set whether to enable visualdl for visual log display | False | [Tutorial](https://www.paddlepaddle.org.cn/paddle/visualdl) |
+| use_wandb | Set whether to enable W&B for visual log display | False | [Documentation](https://docs.wandb.ai/)
| infer_img | Set inference image path or folder path | ./infer_img | \||
| character_dict_path | Set dictionary path | ./ppocr/utils/ppocr_keys_v1.txt | If the character_dict_path is None, model can only recognize number and lower letters |
| max_text_length | Set the maximum length of text | 25 | \ |
@@ -66,7 +67,7 @@ In PaddleOCR, the network is divided into four stages: Transform, Backbone, Neck
| :---------------------: | :---------------------: | :--------------: | :--------------------: |
| model_type | Network Type | rec | Currently support`rec`,`det`,`cls` |
| algorithm | Model name | CRNN | See [algorithm_overview](./algorithm_overview_en.md) for the support list |
-| **Transform** | Set the transformation method | - | Currently only recognition algorithms are supported, see [ppocr/modeling/transforms](../../ppocr/modeling/transforms) for details |
+| **Transform** | Set the transformation method | - | Currently only recognition algorithms are supported, see [ppocr/modeling/transform](../../ppocr/modeling/transforms) for details |
| name | Transformation class name | TPS | Currently supports `TPS` |
| num_fiducial | Number of TPS control points | 20 | Ten on the top and bottom |
| loc_lr | Localization network learning rate | 0.1 | \ |
@@ -130,6 +131,17 @@ In PaddleOCR, the network is divided into four stages: Transform, Backbone, Neck
| drop_last | Whether to discard the last incomplete mini-batch because the number of samples in the data set cannot be divisible by batch_size | True | \ |
| num_workers | The number of sub-processes used to load data, if it is 0, the sub-process is not started, and the data is loaded in the main process | 8 | \ |
+### Weights & Biases ([W&B](../../ppocr/utils/loggers/wandb_logger.py))
+| Parameter | Use | Defaults | Note |
+| :---------------------: | :---------------------: | :--------------: | :--------------------: |
+| project | Project to which the run is to be logged | uncategorized | \
+| name | Alias/Name of the run | Randomly generated by wandb | \
+| id | ID of the run | Randomly generated by wandb | \
+| entity | User or team to which the run is being logged | The logged in user | \
+| save_dir | local directory in which all the models and other data is saved | wandb | \
+| config | model configuration | None | \
+
+
## 3. Multilingual Config File Generation
@@ -233,4 +245,4 @@ For more supported languages, please refer to : [Multi-language model](https://g
The multi-language model training method is the same as the Chinese model. The training data set is 100w synthetic data. A small amount of fonts and test data can be downloaded using the following two methods.
* [Baidu Netdisk](https://pan.baidu.com/s/1bS_u207Rm7YbY33wOECKDA),Extraction code:frgi.
-* [Google drive](https://drive.google.com/file/d/18cSWX7wXSy4G0tbKJ0d9PuIaiwRLHpjA/view)
+* [Google drive](https://drive.google.com/file/d/18cSWX7wXSy4G0tbKJ0d9PuIaiwRLHpjA/view)
\ No newline at end of file
diff --git a/doc/doc_en/logging_en.md b/doc/doc_en/logging_en.md
new file mode 100644
index 0000000000000000000000000000000000000000..d00ab8bd561c1bb7e489642298e74180e0c66886
--- /dev/null
+++ b/doc/doc_en/logging_en.md
@@ -0,0 +1,61 @@
+## Logging metrics and models
+
+PaddleOCR comes with two metric logging tools integrated directly into the training API: [VisualDL](https://readthedocs.org/projects/visualdl/) and [Weights & Biases](https://docs.wandb.ai/).
+
+### VisualDL
+VisualDL is a visualization analysis tool of PaddlePaddle. The integration allows all training metrics to be logged to a VisualDL dashboard. To use it, add the following line to the `Global` section of the config yaml file -
+
+```
+Global:
+ use_visualdl: True
+```
+
+To see the visualizations run the following command in your terminal
+
+```shell
+visualdl --logdir
+```
+
+Now open `localhost:8040` in your browser of choice!
+
+### Weights & Biases
+W&B is a MLOps tool that can be used for experiment tracking, dataset/model versioning, visualizing results and collaborating with colleagues. A W&B logger is integrated directly into PaddleOCR and to use it, first you need to install the `wandb` sdk and login to your wandb account.
+
+```shell
+pip install wandb
+wandb login
+```
+
+If you do not have a wandb account, you can make one [here](https://wandb.ai/site).
+
+To visualize and track your model training add the following flag to your config yaml file under the `Global` section -
+
+```
+Global:
+ use_wandb: True
+```
+
+To add more arguments to the `WandbLogger` listed [here](./config_en.md) add the header `wandb` to the yaml file and add the arguments under it -
+
+```
+wandb:
+ project: my_project
+ entity: my_team
+```
+
+These config variables from the yaml file are used to instantiate the `WandbLogger` object with the project name, entity name (the logged in user by default), directory to store metadata (`./wandb` by default) and more. During the training process, the `log_metrics` function is called to log training and evaluation metrics at the training and evaluation steps respectively from the rank 0 process only.
+
+At every model saving step, the WandbLogger, logs the model using the `log_model` function along with relavant metadata and tags showing the epoch in which the model is saved, the model is best or not and so on.
+
+All the logging mentioned above is integrated into the `program.train` function and will generate dashboards like this -
+
+![W&B Dashboard](../imgs_en/wandb_metrics.png)
+
+![W&B Models](../imgs_en/wandb_models.png)
+
+For more advanced usage to log images, audios, videos or any other form of data, you can use `WandbLogger().run.log`. More examples on how to log different kinds of data are available [here](https://docs.wandb.ai/examples).
+
+To view the dashboard, the link to the dashboard is printed to the console at the beginning and end of every training job and you can also access it by logging into your W&B account on your browser.
+
+### Using Multiple Loggers
+Both VisualDL and W&B can also be used simultaneously by just setting both the aforementioned flags to True.
\ No newline at end of file
diff --git a/doc/imgs_en/wandb_metrics.png b/doc/imgs_en/wandb_metrics.png
new file mode 100644
index 0000000000000000000000000000000000000000..45f0041ae4d3819c2bf9c9fababcceb3ff20a115
Binary files /dev/null and b/doc/imgs_en/wandb_metrics.png differ
diff --git a/doc/imgs_en/wandb_models.png b/doc/imgs_en/wandb_models.png
new file mode 100644
index 0000000000000000000000000000000000000000..f9a7042bd59fa16179bd8a1f1e0eb49031300e4f
Binary files /dev/null and b/doc/imgs_en/wandb_models.png differ
diff --git a/ppocr/utils/loggers/__init__.py b/ppocr/utils/loggers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1e92f734e84b7e0278f8e7940ef3baf137c159e
--- /dev/null
+++ b/ppocr/utils/loggers/__init__.py
@@ -0,0 +1,3 @@
+from .vdl_logger import VDLLogger
+from .wandb_logger import WandbLogger
+from .loggers import Loggers
diff --git a/ppocr/utils/loggers/base_logger.py b/ppocr/utils/loggers/base_logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a7fc3593ba8e69fdd5bed386c7ae4ff0d459988
--- /dev/null
+++ b/ppocr/utils/loggers/base_logger.py
@@ -0,0 +1,15 @@
+import os
+from abc import ABC, abstractmethod
+
+class BaseLogger(ABC):
+ def __init__(self, save_dir):
+ self.save_dir = save_dir
+ os.makedirs(self.save_dir, exist_ok=True)
+
+ @abstractmethod
+ def log_metrics(self, metrics, prefix=None):
+ pass
+
+ @abstractmethod
+ def close(self):
+ pass
\ No newline at end of file
diff --git a/ppocr/utils/loggers/loggers.py b/ppocr/utils/loggers/loggers.py
new file mode 100644
index 0000000000000000000000000000000000000000..260146620811c8e72da66e9f2c7bbcbaef90b90d
--- /dev/null
+++ b/ppocr/utils/loggers/loggers.py
@@ -0,0 +1,18 @@
+from .wandb_logger import WandbLogger
+
+class Loggers(object):
+ def __init__(self, loggers):
+ super().__init__()
+ self.loggers = loggers
+
+ def log_metrics(self, metrics, prefix=None, step=None):
+ for logger in self.loggers:
+ logger.log_metrics(metrics, prefix=prefix, step=step)
+
+ def log_model(self, is_best, prefix, metadata=None):
+ for logger in self.loggers:
+ logger.log_model(is_best=is_best, prefix=prefix, metadata=metadata)
+
+ def close(self):
+ for logger in self.loggers:
+ logger.close()
\ No newline at end of file
diff --git a/ppocr/utils/loggers/vdl_logger.py b/ppocr/utils/loggers/vdl_logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..c345f93235b239873f0ddcd49c8b1b8966877a03
--- /dev/null
+++ b/ppocr/utils/loggers/vdl_logger.py
@@ -0,0 +1,21 @@
+from .base_logger import BaseLogger
+from visualdl import LogWriter
+
+class VDLLogger(BaseLogger):
+ def __init__(self, save_dir):
+ super().__init__(save_dir)
+ self.vdl_writer = LogWriter(logdir=save_dir)
+
+ def log_metrics(self, metrics, prefix=None, step=None):
+ if not prefix:
+ prefix = ""
+ updated_metrics = {prefix + "/" + k: v for k, v in metrics.items()}
+
+ for k, v in updated_metrics.items():
+ self.vdl_writer.add_scalar(k, v, step)
+
+ def log_model(self, is_best, prefix, metadata=None):
+ pass
+
+ def close(self):
+ self.vdl_writer.close()
\ No newline at end of file
diff --git a/ppocr/utils/loggers/wandb_logger.py b/ppocr/utils/loggers/wandb_logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9c6711696569e825638e0a27394071020b29cb5
--- /dev/null
+++ b/ppocr/utils/loggers/wandb_logger.py
@@ -0,0 +1,78 @@
+import os
+from .base_logger import BaseLogger
+
+class WandbLogger(BaseLogger):
+ def __init__(self,
+ project=None,
+ name=None,
+ id=None,
+ entity=None,
+ save_dir=None,
+ config=None,
+ **kwargs):
+ try:
+ import wandb
+ self.wandb = wandb
+ except ModuleNotFoundError:
+ raise ModuleNotFoundError(
+ "Please install wandb using `pip install wandb`"
+ )
+
+ self.project = project
+ self.name = name
+ self.id = id
+ self.save_dir = save_dir
+ self.config = config
+ self.kwargs = kwargs
+ self.entity = entity
+ self._run = None
+ self._wandb_init = dict(
+ project=self.project,
+ name=self.name,
+ id=self.id,
+ entity=self.entity,
+ dir=self.save_dir,
+ resume="allow"
+ )
+ self._wandb_init.update(**kwargs)
+
+ _ = self.run
+
+ if self.config:
+ self.run.config.update(self.config)
+
+ @property
+ def run(self):
+ if self._run is None:
+ if self.wandb.run is not None:
+ logger.info(
+ "There is a wandb run already in progress "
+ "and newly created instances of `WandbLogger` will reuse"
+ " this run. If this is not desired, call `wandb.finish()`"
+ "before instantiating `WandbLogger`."
+ )
+ self._run = self.wandb.run
+ else:
+ self._run = self.wandb.init(**self._wandb_init)
+ return self._run
+
+ def log_metrics(self, metrics, prefix=None, step=None):
+ if not prefix:
+ prefix = ""
+ updated_metrics = {prefix.lower() + "/" + k: v for k, v in metrics.items()}
+
+ self.run.log(updated_metrics, step=step)
+
+ def log_model(self, is_best, prefix, metadata=None):
+ model_path = os.path.join(self.save_dir, prefix + '.pdparams')
+ artifact = self.wandb.Artifact('model-{}'.format(self.run.id), type='model', metadata=metadata)
+ artifact.add_file(model_path, name="model_ckpt.pdparams")
+
+ aliases = [prefix]
+ if is_best:
+ aliases.append("best")
+
+ self.run.log_artifact(artifact, aliases=aliases)
+
+ def close(self):
+ self.run.finish()
\ No newline at end of file
diff --git a/tools/program.py b/tools/program.py
index 90fd309ae9e1ae23723d8e67c62a905e79a073d3..7c02dc0149f36085ef05ca378b79d27e92d6dd57 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -31,6 +31,7 @@ from ppocr.utils.stats import TrainingStats
from ppocr.utils.save_load import save_model
from ppocr.utils.utility import print_dict, AverageMeter
from ppocr.utils.logging import get_logger
+from ppocr.utils.loggers import VDLLogger, WandbLogger, Loggers
from ppocr.utils import profiler
from ppocr.data import build_dataloader
@@ -161,7 +162,7 @@ def train(config,
eval_class,
pre_best_model_dict,
logger,
- vdl_writer=None,
+ log_writer=None,
scaler=None):
cal_metric_during_train = config['Global'].get('cal_metric_during_train',
False)
@@ -300,10 +301,8 @@ def train(config,
stats['lr'] = lr
train_stats.update(stats)
- if vdl_writer is not None and dist.get_rank() == 0:
- for k, v in train_stats.get().items():
- vdl_writer.add_scalar('TRAIN/{}'.format(k), v, global_step)
- vdl_writer.add_scalar('TRAIN/lr', lr, global_step)
+ if log_writer is not None and dist.get_rank() == 0:
+ log_writer.log_metrics(metrics=train_stats.get(), prefix="TRAIN", step=global_step)
if dist.get_rank() == 0 and (
(global_step > 0 and global_step % print_batch_step == 0) or
@@ -349,11 +348,9 @@ def train(config,
logger.info(cur_metric_str)
# logger metric
- if vdl_writer is not None:
- for k, v in cur_metric.items():
- if isinstance(v, (float, int)):
- vdl_writer.add_scalar('EVAL/{}'.format(k),
- cur_metric[k], global_step)
+ if log_writer is not None:
+ log_writer.log_metrics(metrics=cur_metric, prefix="EVAL", step=global_step)
+
if cur_metric[main_indicator] >= best_model_dict[
main_indicator]:
best_model_dict.update(cur_metric)
@@ -374,10 +371,12 @@ def train(config,
]))
logger.info(best_str)
# logger best metric
- if vdl_writer is not None:
- vdl_writer.add_scalar('EVAL/best_{}'.format(main_indicator),
- best_model_dict[main_indicator],
- global_step)
+ if log_writer is not None:
+ log_writer.log_metrics(metrics={
+ "best_{}".format(main_indicator): best_model_dict[main_indicator]
+ }, prefix="EVAL", step=global_step)
+
+ log_writer.log_model(is_best=True, prefix="best_accuracy", metadata=best_model_dict)
reader_start = time.time()
if dist.get_rank() == 0:
@@ -392,6 +391,10 @@ def train(config,
best_model_dict=best_model_dict,
epoch=epoch,
global_step=global_step)
+
+ if log_writer is not None:
+ log_writer.log_model(is_best=False, prefix="latest")
+
if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0:
save_model(
model,
@@ -404,11 +407,14 @@ def train(config,
best_model_dict=best_model_dict,
epoch=epoch,
global_step=global_step)
+ if log_writer is not None:
+ log_writer.log_model(is_best=False, prefix='iter_epoch_{}'.format(epoch))
+
best_str = 'best metric, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
logger.info(best_str)
- if dist.get_rank() == 0 and vdl_writer is not None:
- vdl_writer.close()
+ if dist.get_rank() == 0 and log_writer is not None:
+ log_writer.close()
return
@@ -565,15 +571,32 @@ def preprocess(is_train=False):
config['Global']['distributed'] = dist.get_world_size() != 1
- if config['Global']['use_visualdl'] and dist.get_rank() == 0:
- from visualdl import LogWriter
+ loggers = []
+
+ if 'use_visualdl' in config['Global'] and config['Global']['use_visualdl']:
save_model_dir = config['Global']['save_model_dir']
vdl_writer_path = '{}/vdl/'.format(save_model_dir)
- os.makedirs(vdl_writer_path, exist_ok=True)
- vdl_writer = LogWriter(logdir=vdl_writer_path)
+ log_writer = VDLLogger(save_model_dir)
+ loggers.append(log_writer)
+ if ('use_wandb' in config['Global'] and config['Global']['use_wandb']) or 'wandb' in config:
+ save_dir = config['Global']['save_model_dir']
+ wandb_writer_path = "{}/wandb".format(save_dir)
+ if "wandb" in config:
+ wandb_params = config['wandb']
+ else:
+ wandb_params = dict()
+ wandb_params.update({'save_dir': save_model_dir})
+ log_writer = WandbLogger(**wandb_params, config=config)
+ loggers.append(log_writer)
else:
- vdl_writer = None
+ log_writer = None
print_dict(config, logger)
+
+ if loggers:
+ log_writer = Loggers(loggers)
+ else:
+ log_writer = None
+
logger.info('train with paddle {} and device {}'.format(paddle.__version__,
device))
- return config, device, logger, vdl_writer
+ return config, device, logger, log_writer