未验证 提交 d923f7c8 编写于 作者: M Michael Wyatt 提交者: GitHub

Refactor/Pydantify monitoring config (#2640)

* pydantify monitoring configs

---------
Co-authored-by: NOlatunji Ruwase <olruwase@microsoft.com>
上级 86477538
......@@ -3,46 +3,84 @@ Copyright (c) Microsoft Corporation
Licensed under the MIT license.
"""
from pydantic import BaseModel
from .constants import *
from pydantic import root_validator
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
class MonitorConfig(BaseModel):
class Config:
validate_all = True
validate_assignment = True
use_enum_values = True
extra = 'forbid'
def get_monitor_config(param_dict):
monitor_dict = {
key: param_dict.get(key,
{})
for key in ("tensorboard",
"wandb",
"csv_monitor")
}
return DeepSpeedMonitorConfig(**monitor_dict)
class TensorBoardConfig(MonitorConfig):
enabled: bool = TENSORBOARD_ENABLED_DEFAULT
output_path: str = TENSORBOARD_OUTPUT_PATH_DEFAULT
job_name: str = TENSORBOARD_JOB_NAME_DEFAULT
class TensorBoardConfig(DeepSpeedConfigModel):
"""Sets parameters for TensorBoard monitor."""
enabled: bool = False
""" Whether logging to Tensorboard is enabled. Requires `tensorboard` package is installed. """
class WandbConfig(MonitorConfig):
enabled: bool = WANDB_ENABLED_DEFAULT
group: str = WANDB_GROUP_NAME_DEFAULT
team: str = WANDB_TEAM_NAME_DEFAULT
project: str = WANDB_PROJECT_NAME_DEFAULT
output_path: str = ""
"""
Path to where the Tensorboard logs will be written. If not provided, the
output path is set under the training script’s launching path.
"""
job_name: str = "DeepSpeedJobName"
""" Name for the current job. This will become a new directory inside `output_path`. """
class CSVConfig(MonitorConfig):
enabled: bool = CSV_MONITOR_ENABLED_DEFAULT
output_path: str = CSV_MONITOR_OUTPUT_PATH_DEFAULT
job_name: str = CSV_MONITOR_JOB_NAME_DEFAULT
class WandbConfig(DeepSpeedConfigModel):
"""Sets parameters for WandB monitor."""
class DeepSpeedMonitorConfig:
def __init__(self, ds_config):
self.tensorboard_enabled = 'tensorboard' in ds_config
self.wandb_enabled = 'wandb' in ds_config
self.csv_monitor_enabled = 'csv_monitor' in ds_config
enabled: bool = False
""" Whether logging to WandB is enabled. Requires `wandb` package is installed. """
if self.tensorboard_enabled:
self.tensorboard_config = TensorBoardConfig(**ds_config['tensorboard'])
if self.wandb_enabled:
self.wandb_config = WandbConfig(**ds_config['wandb'])
if self.csv_monitor_enabled:
self.csv_monitor_config = CSVConfig(**ds_config['csv_monitor'])
group: str = None
""" Name for the WandB group. This can be used to group together runs. """
team: str = None
""" Name for the WandB team. """
project: str = "deepspeed"
""" Name for the WandB project. """
class CSVConfig(DeepSpeedConfigModel):
"""Sets parameters for CSV monitor."""
enabled: bool = False
""" Whether logging to local CSV files is enabled. """
output_path: str = ""
"""
Path to where the csv files will be written. If not provided, the output
path is set under the training script’s launching path.
"""
job_name: str = "DeepSpeedJobName"
""" Name for the current job. This will become a new directory inside `output_path`. """
class DeepSpeedMonitorConfig(DeepSpeedConfigModel):
"""Sets parameters for various monitoring methods."""
tensorboard: TensorBoardConfig = {}
""" TensorBoard monitor, requires `tensorboard` package is installed. """
wandb: WandbConfig = {}
""" WandB monitor, requires `wandb` package is installed. """
csv_monitor: CSVConfig = {}
""" Local CSV output of monitoring data. """
@root_validator
def check_enabled(cls, values):
values["enabled"] = False
if (values.get("tensorboard").enabled or values.get("wandb").enabled
or values.get("csv_monitor").enabled):
values["enabled"] = True
return values
#########################################
# Tensorboard
#########################################
# Tensorboard. By default, this feature is not enabled.
# Users can configure in ds_config.json as below example:
TENSORBOARD_FORMAT = '''
Tensorboard can be specified as:
"tensorboard": {
"enabled": true,
"output_path": "/home/myname/foo",
"job_name": "model_lr2e-5_epoch3_seed2_seq64"
}
'''
TENSORBOARD = "tensorboard"
# Tensorboard enable signal
TENSORBOARD_ENABLED = "enabled"
TENSORBOARD_ENABLED_DEFAULT = False
# Tensorboard output path
TENSORBOARD_OUTPUT_PATH = "output_path"
TENSORBOARD_OUTPUT_PATH_DEFAULT = ""
# Tensorboard job name
TENSORBOARD_JOB_NAME = "job_name"
TENSORBOARD_JOB_NAME_DEFAULT = "DeepSpeedJobName"
#########################################
# Wandb
#########################################
# Wandb. By default, this feature is not enabled.
# Users can configure in ds_config.json as below example:
WANDB_FORMAT = '''
Wandb can be specified as:
"wandb": {
"enabled": true,
"team_name": "deepspeed"
"project_name": "zero"
"group_name": "zero: stage 3",
}
'''
WANDB = "wandb"
# Wandb enable signal
WANDB_ENABLED = "enabled"
WANDB_ENABLED_DEFAULT = False
# Wandb team
WANDB_TEAM_NAME = "team"
WANDB_TEAM_NAME_DEFAULT = None
# Wandb project
WANDB_PROJECT_NAME = "project"
WANDB_PROJECT_NAME_DEFAULT = "deepspeed"
# Wandb group
WANDB_GROUP_NAME = "group"
WANDB_GROUP_NAME_DEFAULT = None
#########################################
# csv monitor
#########################################
# Basic CSV monitor. By default, this feature is not enabled.
# Users can configure in ds_config.json as below example:
CSV_FORMAT = '''
The basic csv monitor can be specified as:
"csv_monitor": {
"enabled": true,
"output_path": "/home/myname/foo",
"job_name": "model_lr2e-5_epoch3_seed2_seq64"
}
'''
CSV_MONITOR = "csv_monitor"
# csv monitor enable signal
CSV_MONITOR_ENABLED = "enabled"
CSV_MONITOR_ENABLED_DEFAULT = False
# csv monitor output path
CSV_MONITOR_OUTPUT_PATH = "output_path"
CSV_MONITOR_OUTPUT_PATH_DEFAULT = ""
# csv_monitor job name
CSV_MONITOR_JOB_NAME = "job_name"
CSV_MONITOR_JOB_NAME_DEFAULT = "DeepSpeedJobName"
......@@ -5,12 +5,12 @@ import deepspeed.comm as dist
class csvMonitor(Monitor):
def __init__(self, monitor_config):
super().__init__(monitor_config)
def __init__(self, csv_config):
super().__init__(csv_config)
self.filenames = []
self.enabled = monitor_config.csv_monitor_config.enabled
self.output_path = monitor_config.csv_monitor_config.output_path
self.job_name = monitor_config.csv_monitor_config.job_name
self.enabled = csv_config.enabled
self.output_path = csv_config.output_path
self.job_name = csv_config.job_name
self.log_dir = self.setup_log_dir()
def setup_log_dir(self, base=os.path.join(os.path.expanduser("~"), "csv_monitor")):
......
......@@ -27,15 +27,15 @@ class MonitorMaster(Monitor):
self.tb_monitor = None
self.wandb_monitor = None
self.csv_monitor = None
self.enabled = monitor_config.tensorboard_enabled or monitor_config.csv_monitor_enabled or monitor_config.wandb_enabled
self.enabled = monitor_config.enabled
if dist.get_rank() == 0:
if monitor_config.tensorboard_enabled:
self.tb_monitor = TensorBoardMonitor(monitor_config)
if monitor_config.wandb_enabled:
self.wandb_monitor = WandbMonitor(monitor_config)
if monitor_config.csv_monitor_enabled:
self.csv_monitor = csvMonitor(monitor_config)
if monitor_config.tensorboard.enabled:
self.tb_monitor = TensorBoardMonitor(monitor_config.tensorboard)
if monitor_config.wandb.enabled:
self.wandb_monitor = WandbMonitor(monitor_config.csv_monitor)
if monitor_config.csv_monitor.enabled:
self.csv_monitor = csvMonitor(monitor_config.wandb)
def write_events(self, event_list):
if dist.get_rank() == 0:
......
......@@ -6,14 +6,14 @@ import deepspeed.comm as dist
class TensorBoardMonitor(Monitor):
def __init__(self, monitor_config):
super().__init__(monitor_config)
def __init__(self, tensorboard_config):
super().__init__(tensorboard_config)
check_tb_availability()
self.summary_writer = None
self.enabled = monitor_config.tensorboard_config.enabled
self.output_path = monitor_config.tensorboard_config.output_path
self.job_name = monitor_config.tensorboard_config.job_name
self.enabled = tensorboard_config.enabled
self.output_path = tensorboard_config.output_path
self.job_name = tensorboard_config.job_name
if self.enabled and dist.get_rank() == 0:
self.get_summary_writer()
......
......@@ -5,15 +5,15 @@ import deepspeed.comm as dist
class WandbMonitor(Monitor):
def __init__(self, monitor_config):
super().__init__(monitor_config)
def __init__(self, wandb_config):
super().__init__(wandb_config)
check_wandb_availability()
import wandb
self.enabled = monitor_config.wandb_config.enabled
self.group = monitor_config.wandb_config.group
self.team = monitor_config.wandb_config.team
self.project = monitor_config.wandb_config.project
self.enabled = wandb_config.enabled
self.group = wandb_config.group
self.team = wandb_config.team
self.project = wandb_config.project
if self.enabled and dist.get_rank() == 0:
wandb.init(project=self.project, group=self.group, entity=self.team)
......
......@@ -25,7 +25,7 @@ from .config_utils import (
from .zero.config import get_zero_config, ZeroStageEnum
from .activation_checkpointing.config import DeepSpeedActivationCheckpointingConfig
from ..comm.config import DeepSpeedCommsConfig
from ..monitor.config import DeepSpeedMonitorConfig
from ..monitor.config import get_monitor_config
from deepspeed import comm as dist
......@@ -829,7 +829,7 @@ class DeepSpeedConfig(object):
param_dict)
self.comms_config = DeepSpeedCommsConfig(param_dict)
self.monitor_config = DeepSpeedMonitorConfig(param_dict)
self.monitor_config = get_monitor_config(param_dict)
self.gradient_clipping = get_gradient_clipping(param_dict)
self.fp16_enabled = get_fp16_enabled(param_dict)
......
......@@ -98,6 +98,13 @@ Memory Usage
memory
Monitoring
----------
.. toctree::
:maxdepth: 2
monitor
Indices and tables
------------------
......
Monitoring
==========
Deepspeed’s Monitor module can log training details into a
Tensorboard-compatible file, to WandB, or to simple CSV files. Below is an
overview of what DeepSpeed will log automatically.
.. csv-table:: Automatically Logged Data
:header: "Field", "Description", "Condition"
:widths: 20, 20, 10
`Train/Samples/train_loss`,The training loss.,None
`Train/Samples/lr`,The learning rate during training.,None
`Train/Samples/loss_scale`,The loss scale when training using `fp16`.,`fp16` must be enabled.
`Train/Eigenvalues/ModelBlockParam_{i}`,Eigen values per param block.,`eigenvalue` must be enabled.
`Train/Samples/elapsed_time_ms_forward`,The global duration of the forward pass.,`flops_profiler.enabled` or `wall_clock_breakdown`.
`Train/Samples/elapsed_time_ms_backward`,The global duration of the forward pass.,`flops_profiler.enabled` or `wall_clock_breakdown`.
`Train/Samples/elapsed_time_ms_backward_inner`,The backward time that does not include the the gradient reduction time. Only in cases where the gradient reduction is not overlapped, if it is overlapped then the inner time should be about the same as the entire backward time.,`flops_profiler.enabled` or `wall_clock_breakdown`.
`Train/Samples/elapsed_time_ms_backward_allreduce`,The global duration of the allreduce operation.,`flops_profiler.enabled` or `wall_clock_breakdown`.
`Train/Samples/elapsed_time_ms_step`,The optimizer step time,`flops_profiler.enabled` or `wall_clock_breakdown`.
TensorBoard
-----------
.. _TensorBoardConfig:
.. autopydantic_model:: deepspeed.monitor.config.TensorBoardConfig
WandB
-----
.. _WandbConfig:
.. autopydantic_model:: deepspeed.monitor.config.WandbConfig
CSV Monitor
-----------
.. _CSVConfig:
.. autopydantic_model:: deepspeed.monitor.config.CSVConfig
from deepspeed.monitor.constants import *
from deepspeed.monitor.tensorboard import TensorBoardMonitor
from deepspeed.monitor.wandb import WandbMonitor
from deepspeed.monitor.csv_monitor import csvMonitor
from deepspeed.monitor.config import DeepSpeedMonitorConfig
from unit.common import DistributedTest
from deepspeed.runtime.config import DeepSpeedConfig
......@@ -21,7 +20,7 @@ class TestTensorBoard(DistributedTest):
}
}
ds_config = DeepSpeedConfig(config_dict)
tb_monitor = TensorBoardMonitor(ds_config.monitor_config)
tb_monitor = TensorBoardMonitor(ds_config.monitor_config.tensorboard)
assert tb_monitor.enabled == True
assert tb_monitor.output_path == "test_output/ds_logs/"
assert tb_monitor.job_name == "test"
......@@ -29,10 +28,11 @@ class TestTensorBoard(DistributedTest):
def test_empty_tensorboard(self):
config_dict = {"train_batch_size": 2, "tensorboard": {}}
ds_config = DeepSpeedConfig(config_dict)
tb_monitor = TensorBoardMonitor(ds_config.monitor_config)
assert tb_monitor.enabled == TENSORBOARD_ENABLED_DEFAULT
assert tb_monitor.output_path == TENSORBOARD_OUTPUT_PATH_DEFAULT
assert tb_monitor.job_name == TENSORBOARD_JOB_NAME_DEFAULT
tb_monitor = TensorBoardMonitor(ds_config.monitor_config.tensorboard)
defaults = DeepSpeedMonitorConfig().tensorboard
assert tb_monitor.enabled == defaults.enabled
assert tb_monitor.output_path == defaults.output_path
assert tb_monitor.job_name == defaults.job_name
class TestWandB(DistributedTest):
......@@ -49,7 +49,7 @@ class TestWandB(DistributedTest):
}
}
ds_config = DeepSpeedConfig(config_dict)
wandb_monitor = WandbMonitor(ds_config.monitor_config)
wandb_monitor = WandbMonitor(ds_config.monitor_config.wandb)
assert wandb_monitor.enabled == False
assert wandb_monitor.group == "my_group"
assert wandb_monitor.team == "my_team"
......@@ -58,11 +58,12 @@ class TestWandB(DistributedTest):
def test_empty_wandb(self):
config_dict = {"train_batch_size": 2, "wandb": {}}
ds_config = DeepSpeedConfig(config_dict)
wandb_monitor = WandbMonitor(ds_config.monitor_config)
assert wandb_monitor.enabled == WANDB_ENABLED_DEFAULT
assert wandb_monitor.group == WANDB_GROUP_NAME_DEFAULT
assert wandb_monitor.team == WANDB_TEAM_NAME_DEFAULT
assert wandb_monitor.project == WANDB_PROJECT_NAME_DEFAULT
wandb_monitor = WandbMonitor(ds_config.monitor_config.wandb)
defaults = DeepSpeedMonitorConfig().wandb
assert wandb_monitor.enabled == defaults.enabled
assert wandb_monitor.group == defaults.group
assert wandb_monitor.team == defaults.team
assert wandb_monitor.project == defaults.project
class TestCSVMonitor(DistributedTest):
......@@ -78,7 +79,7 @@ class TestCSVMonitor(DistributedTest):
}
}
ds_config = DeepSpeedConfig(config_dict)
csv_monitor = csvMonitor(ds_config.monitor_config)
csv_monitor = csvMonitor(ds_config.monitor_config.csv_monitor)
assert csv_monitor.enabled == True
assert csv_monitor.output_path == "test_output/ds_logs/"
assert csv_monitor.job_name == "test"
......@@ -86,7 +87,8 @@ class TestCSVMonitor(DistributedTest):
def test_empty_csv_monitor(self):
config_dict = {"train_batch_size": 2, "csv_monitor": {}}
ds_config = DeepSpeedConfig(config_dict)
csv_monitor = csvMonitor(ds_config.monitor_config)
assert csv_monitor.enabled == CSV_MONITOR_ENABLED_DEFAULT
assert csv_monitor.output_path == CSV_MONITOR_OUTPUT_PATH_DEFAULT
assert csv_monitor.job_name == CSV_MONITOR_JOB_NAME_DEFAULT
csv_monitor = csvMonitor(ds_config.monitor_config.csv_monitor)
defaults = DeepSpeedMonitorConfig().csv_monitor
assert csv_monitor.enabled == defaults.enabled
assert csv_monitor.output_path == defaults.output_path
assert csv_monitor.job_name == defaults.job_name
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册