未验证 提交 40ce7f4a 编写于 作者: M Manan Goel 提交者: GitHub

Wandb callback (#46918)

上级 436115cf
......@@ -19,6 +19,7 @@ from .hapi.callbacks import VisualDL # noqa: F401
from .hapi.callbacks import LRScheduler # noqa: F401
from .hapi.callbacks import EarlyStopping # noqa: F401
from .hapi.callbacks import ReduceLROnPlateau # noqa: F401
from .hapi.callbacks import WandbCallback # noqa: F401
__all__ = [ # noqa
'Callback',
......@@ -28,4 +29,5 @@ __all__ = [ # noqa
'LRScheduler',
'EarlyStopping',
'ReduceLROnPlateau',
'WandbCallback',
]
......@@ -993,6 +993,179 @@ class VisualDL(Callback):
delattr(self, 'writer')
class WandbCallback(Callback):
"""Track your training and system metrics using `Weights and Biases <https://docs.wandb.ai>`_.
**Installation and set-up**
Install with pip and log in to your W&B account:
.. code-block:: bash
pip install wandb
wandb login
Args:
project(str, optional): Name of the project. Default: uncategorized
entity(str, optional): Name of the team/user creating the run. Default: Logged in user
name(str, optional): Name of the run. Default: randomly generated by wandb
dir(str, optional): Directory in which all the metadata is stored. Default: `wandb`
mode(str, optional): Can be "online", "offline" or "disabled". Default: "online".
job_type(str, optional): the type of run, for grouping runs together. Default: None
Examples:
.. code-block:: python
import paddle
import paddle.vision.transforms as T
from paddle.static import InputSpec
inputs = [InputSpec([-1, 1, 28, 28], 'float32', 'image')]
labels = [InputSpec([None, 1], 'int64', 'label')]
transform = T.Compose([
T.Transpose(),
T.Normalize([127.5], [127.5])
])
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
eval_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
net = paddle.vision.models.LeNet()
model = paddle.Model(net, inputs, labels)
optim = paddle.optimizer.Adam(0.001, parameters=net.parameters())
model.prepare(optimizer=optim,
loss=paddle.nn.CrossEntropyLoss(),
metrics=paddle.metric.Accuracy())
## uncomment following lines to fit model with wandb callback function
# callback = paddle.callbacks.WandbCallback(project='paddle_mnist')
# model.fit(train_dataset, eval_dataset, batch_size=64, callbacks=callback)
"""
def __init__(
self,
project=None,
entity=None,
name=None,
dir=None,
mode=None,
job_type=None,
**kwargs
):
self.wandb = try_import(
"wandb",
"You want to use `wandb` which is not installed yet install it with `pip install wandb`",
)
self.wandb_args = dict(
project=project,
name=name,
entity=entity,
dir=dir,
mode=mode,
job_type=job_type,
)
self._run = None
self.wandb_args.update(**kwargs)
_ = self.run
def _is_write(self):
return ParallelEnv().local_rank == 0
@property
def run(self):
if self._is_write():
if self._run is None:
if self.wandb.run is not None:
warnings.warn(
"There is a wandb run already in progress and newly created instances"
" of `WandbCallback` will reuse this run. If this is not desired"
" , call `wandb.finish()` before instantiating `WandbCallback`."
)
self._run = self.wandb.run
else:
self._run = self.wandb.init(**self.wandb_args)
return self._run
def on_train_begin(self, logs=None):
self.epochs = self.params['epochs']
assert self.epochs
self.train_metrics = self.params['metrics']
assert self.train_metrics
self._is_fit = True
self.train_step = 0
if self._is_write():
self.run.define_metric("train/step")
self.run.define_metric("train/*", step_metric="train/step")
self.run.define_metric("epoch")
self.run.define_metric("eval/*", step_metric="epoch")
def on_epoch_begin(self, epoch, logs=None):
self.steps = self.params['steps']
self.epoch = epoch
def _updates(self, logs, mode):
if not self._is_write():
return
metrics = getattr(self, '%s_metrics' % (mode))
current_step = getattr(self, '%s_step' % (mode))
_metrics = dict()
if mode == 'train':
total_step = current_step
_metrics.update({'train/step': total_step})
else:
total_step = self.epoch
_metrics.update({'epoch': total_step})
for k in metrics:
if k in logs:
temp_tag = mode + '/' + k
if isinstance(logs[k], (list, tuple)):
_metrics.update({temp_tag: logs[k][0]})
elif isinstance(logs[k], numbers.Number):
_metrics.update({temp_tag: logs[k]})
else:
continue
self.run.log(_metrics)
def on_train_batch_end(self, step, logs=None):
logs = logs or {}
self.train_step += 1
if self._is_write():
self._updates(logs, 'train')
def on_eval_begin(self, logs=None):
self.eval_steps = logs.get('steps', None)
self.eval_metrics = logs.get('metrics', [])
self.eval_step = 0
self.evaled_samples = 0
def on_train_end(self, logs=None):
if self._is_write():
self.run.finish()
def on_eval_end(self, logs=None):
if self._is_write():
self._updates(logs, 'eval')
if (not hasattr(self, '_is_fit')) and hasattr(self, 'run'):
self.run.finish()
delattr(self, 'run')
class ReduceLROnPlateau(Callback):
"""Reduce learning rate when a metric of evaluation has stopped improving.
Models often benefit from reducing the learning rate by a factor
......
......@@ -66,6 +66,7 @@ set_tests_properties(test_vision_models PROPERTIES TIMEOUT 120)
set_tests_properties(test_dataset_uci_housing PROPERTIES TIMEOUT 120)
set_tests_properties(test_dataset_imdb PROPERTIES TIMEOUT 300)
set_tests_properties(test_pretrained_model PROPERTIES TIMEOUT 600)
set_tests_properties(test_callback_wandb PROPERTIES TIMEOUT 60)
if(WITH_COVERAGE)
set_tests_properties(test_hapi_hub PROPERTIES TIMEOUT 300)
endif()
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import paddle.vision.transforms as T
from paddle.static import InputSpec
from paddle.vision.datasets import MNIST
import paddle
from paddle.fluid.framework import _test_eager_guard
class MnistDataset(MNIST):
def __len__(self):
return 512
class TestWandbCallbacks(unittest.TestCase):
def setUp(self):
self.save_dir = tempfile.mkdtemp()
def func_wandb_callback(self):
inputs = [InputSpec([-1, 1, 28, 28], 'float32', 'image')]
labels = [InputSpec([None, 1], 'int64', 'label')]
transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
train_dataset = MnistDataset(mode='train', transform=transform)
eval_dataset = MnistDataset(mode='test', transform=transform)
net = paddle.vision.models.LeNet()
model = paddle.Model(net, inputs, labels)
optim = paddle.optimizer.Adam(0.001, parameters=net.parameters())
model.prepare(
optimizer=optim,
loss=paddle.nn.CrossEntropyLoss(),
metrics=paddle.metric.Accuracy(),
)
callback = paddle.callbacks.WandbCallback(
project='random',
dir=self.save_dir,
anonymous='must',
mode='offline',
)
model.fit(
train_dataset, eval_dataset, batch_size=64, callbacks=callback
)
def test_wandb_callback(self):
with _test_eager_guard():
self.func_wandb_callback()
self.func_wandb_callback()
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from paddle.utils.lazy_import import try_import
class TestUtilsLazyImport(unittest.TestCase):
def setup(self):
pass
def func_test_lazy_import(self):
paddle = try_import('paddle')
self.assertTrue(paddle.__version__ is not None)
with self.assertRaises(ImportError) as context:
paddle2 = try_import('paddle2')
self.assertTrue(
'require additional dependencies that have to be'
in str(context.exception)
)
with self.assertRaises(ImportError) as context:
paddle2 = try_import('paddle2', 'paddle2 is not installed')
self.assertTrue('paddle2 is not installed' in str(context.exception))
def test_lazy_import(self):
self.func_test_lazy_import()
if __name__ == "__main__":
unittest.main()
......@@ -18,7 +18,7 @@ import importlib
__all__ = []
def try_import(module_name):
def try_import(module_name, err_msg=None):
"""Try importing a module, with an informative error message on failure."""
install_name = module_name
......@@ -32,6 +32,7 @@ def try_import(module_name):
mod = importlib.import_module(module_name)
return mod
except ImportError:
if err_msg is None:
err_msg = (
"Failed importing {}. This likely means that some paddle modules "
"require additional dependencies that have to be "
......
......@@ -17,3 +17,4 @@ numpy>=1.20,<1.22; python_version >= "3.7"
autograd==1.4
librosa==0.8.1
parameterized
wandb>=0.13
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册