提交 88dcd908 编写于 作者: L Li Hongzhang

limit summary of exhausting the disk

上级 0aaa2d47
......@@ -108,6 +108,10 @@ class SummaryCollector(Callback):
custom_lineage_data (Union[dict, None]): Allows you to customize the data and present it on the MingInsight
lineage page. In the custom data, the key type support str, and the value type support str/int/float.
Default: None, it means there is no custom data.
collect_tensor_freq (Optional[int]): Same as the `collect_freq`, but controls TensorSummary specifically.
Default: None, which means the frequency is auto-calculated just to collect at most 50 steps TensorSummary.
max_file_size (Optional[int]): The maximum size in bytes each file can be written to the disk.
Default: None, which means no limit.
Raises:
ValueError: If the parameter value is not expected.
......@@ -145,16 +149,28 @@ class SummaryCollector(Callback):
'histogram_regular': None
}
def __init__(self, summary_dir, collect_freq=10, collect_specified_data=None,
keep_default_action=True, custom_lineage_data=None):
def __init__(self,
summary_dir,
collect_freq=10,
collect_specified_data=None,
keep_default_action=True,
custom_lineage_data=None,
collect_tensor_freq=None,
max_file_size=None):
super(SummaryCollector, self).__init__()
self._summary_dir = self._process_summary_dir(summary_dir)
self._record = None
self._check_collect_freq(collect_freq)
self._check_positive('collect_freq', collect_freq)
self._collect_freq = collect_freq
self._check_positive('collect_tensor_freq', collect_tensor_freq, allow_none=True)
self._collect_tensor_freq = collect_tensor_freq
self._check_positive('max_file_size', max_file_size, allow_none=True)
self._max_file_size = max_file_size
self._check_action(keep_default_action)
self._collect_specified_data = self._process_specified_data(collect_specified_data, keep_default_action)
......@@ -165,16 +181,14 @@ class SummaryCollector(Callback):
self._custom_lineage_data = custom_lineage_data
self._temp_optimizer = None
self._has_saved_train_network = False
self._has_saved_custom_data = False
self._is_parse_loss_success = True
self._first_step = True
self._dataset_sink_mode = True
def __enter__(self):
self._first_step = True
self._dataset_sink_mode = True
self._record = SummaryRecord(log_dir=self._summary_dir)
self._record = SummaryRecord(log_dir=self._summary_dir, max_file_size=self._max_file_size)
self._first_step, self._dataset_sink_mode = True, True
return self
def __exit__(self, *err):
......@@ -198,11 +212,13 @@ class SummaryCollector(Callback):
return summary_dir
@staticmethod
def _check_collect_freq(freq):
"""Check collect freq type and value."""
check_value_type('collect_freq', freq, int)
if freq <= 0:
raise ValueError(f'For `collect_freq` the value should be greater than 0, but got `{freq}`.')
def _check_positive(name, value, allow_none=False):
"""Check if the value to be int type and positive."""
if allow_none:
return
check_value_type(name, value, int)
if value <= 0:
raise ValueError(f'For `{name}` the value should be greater than 0, but got `{value}`.')
@staticmethod
def _check_custom_lineage_data(custom_lineage_data):
......@@ -276,6 +292,9 @@ class SummaryCollector(Callback):
self._collect_graphs(cb_params)
self._collect_dataset_graph(cb_params)
if self._collect_tensor_freq is None:
total_step = cb_params.epoch_num * cb_params.batch_num
self._collect_tensor_freq = max(self._collect_freq, total_step // 50)
if self._custom_lineage_data and not self._has_saved_custom_data:
packaged_custom_data = self._package_custom_lineage_data(self._custom_lineage_data)
......@@ -287,24 +306,29 @@ class SummaryCollector(Callback):
def step_end(self, run_context):
cb_params = run_context.original_args()
if cb_params.mode != ModeEnum.TRAIN.value:
return
if self._first_step:
# Notice: This way of determining whether dataset sink mode is True does not work in the eval scenario
self._dataset_sink_mode = bool(cb_params.cur_step_num == cb_params.batch_num)
if cb_params.mode == ModeEnum.TRAIN.value:
if not self._is_collect_this_step(cb_params):
return
self._dataset_sink_mode = cb_params.cur_step_num == cb_params.batch_num
self._collect_at_step_end(cb_params, plugin_filter=None)
self._first_step = False
else:
current = cb_params.cur_epoch_num if self._dataset_sink_mode else cb_params.cur_step_num
if current % self._collect_freq == 0 and current % self._collect_tensor_freq == 0:
self._collect_at_step_end(cb_params, plugin_filter=None)
elif current % self._collect_tensor_freq == 0:
self._collect_at_step_end(cb_params, lambda plugin: plugin == PluginEnum.TENSOR.value)
elif current % self._collect_freq == 0:
self._collect_at_step_end(cb_params, lambda plugin: plugin != PluginEnum.TENSOR.value)
if not self._has_saved_train_network:
self._collect_graphs(cb_params)
self._collect_input_data(cb_params)
self._collect_metric(cb_params)
self._collect_histogram(cb_params)
def _collect_at_step_end(self, cb_params, plugin_filter):
self._collect_input_data(cb_params)
self._collect_metric(cb_params)
self._collect_histogram(cb_params)
self._record.record(cb_params.cur_step_num, plugin_filter=plugin_filter)
self._first_step = False
self._record.record(cb_params.cur_step_num)
def end(self, run_context):
cb_params = run_context.original_args()
......@@ -331,18 +355,6 @@ class SummaryCollector(Callback):
raise ValueError(f"There are more than one {self.__class__.__name__} instance in callback list,"
f"but expected only one {self.__class__.__name__} instance.")
def _is_collect_this_step(self, cb_params):
"""Decide whether to collect data for the current step."""
# Make sure the first step data is recorded
if not self._first_step:
if self._dataset_sink_mode:
if cb_params.cur_epoch_num % self._collect_freq:
return False
else:
if cb_params.cur_step_num % self._collect_freq:
return False
return True
@staticmethod
def _package_custom_lineage_data(custom_lineage_data):
"""
......@@ -411,7 +423,6 @@ class SummaryCollector(Callback):
if graph_proto is None:
return
self._has_saved_train_network = True
self._record.add_value(PluginEnum.GRAPH.value, 'train_network/auto', graph_proto)
def _collect_metric(self, cb_params):
......
......@@ -24,8 +24,8 @@ from ._summary_adapter import package_init_event
class BaseWriter:
"""BaseWriter to be subclass."""
def __init__(self, filepath) -> None:
self._filepath = filepath
def __init__(self, filepath, max_file_size=None) -> None:
self._filepath, self._max_file_size = filepath, max_file_size
self._writer: EventWriter_ = None
def init_writer(self):
......@@ -46,8 +46,15 @@ class BaseWriter:
def write(self, plugin, data):
"""Write data to file."""
if self.writer and disk_usage(self._filepath).free < len(data) * 32:
raise RuntimeError('The disk space may be soon exhausted.')
self.writer.Write(data)
raise RuntimeError(f'The disk space may be soon exhausted by the {type(self).__name__}.')
if self._max_file_size is None:
self.writer.Write(data)
elif self._max_file_size > 0:
self._max_file_size -= len(data)
self.writer.Write(data)
else:
raise RuntimeError(f"The file written by the {type(self).__name__} "
f"has exceeded the specified max file size.")
def flush(self):
"""Flush the writer."""
......
......@@ -51,10 +51,11 @@ class WriterPool(Process):
filelist (str): The mapping from short name to long filename.
"""
def __init__(self, base_dir, **filedict) -> None:
def __init__(self, base_dir, max_file_size, **filedict) -> None:
super().__init__()
self._base_dir, self._filedict = base_dir, filedict
self._queue, self._writers_ = Queue(cpu_count() * 2), None
self._max_file_size = max_file_size
self.start()
def run(self):
......@@ -88,9 +89,9 @@ class WriterPool(Process):
for plugin, filename in self._filedict.items():
filepath = os.path.join(self._base_dir, filename)
if plugin == 'summary':
self._writers_.append(SummaryWriter(filepath))
self._writers_.append(SummaryWriter(filepath, self._max_file_size))
elif plugin == 'lineage':
self._writers_.append(LineageWriter(filepath))
self._writers_.append(LineageWriter(filepath, self._max_file_size))
return self._writers_
def _write(self, plugin, data):
......@@ -98,9 +99,8 @@ class WriterPool(Process):
for writer in self._writers[:]:
try:
writer.write(plugin, data)
except RuntimeError:
logger.warning(f'The disk space may be soon exhausted by this {type(writer).__name__}, '
'so the writer will be closed and not for further writing.')
except RuntimeError as e:
logger.warning(e.args[0])
self._writers.remove(writer)
writer.close()
......
......@@ -75,14 +75,17 @@ class SummaryRecord:
Args:
log_dir (str): The log_dir is a directory location to save the summary.
queue_max_size (int): The capacity of event queue.(reserved). Default: 0.
flush_time (int): Frequency to flush the summaries to disk, the unit is second. Default: 120.
queue_max_size (int): Deprecated. The capacity of event queue.(reserved). Default: 0.
flush_time (int): Deprecated. Frequency to flush the summaries to disk, the unit is second. Default: 120.
file_prefix (str): The prefix of file. Default: "events".
file_suffix (str): The suffix of file. Default: "_MS".
network (Cell): Obtain a pipeline through network for saving graph summary. Default: None.
max_file_size (Optional[int]): The maximum size in bytes each file can be written to the disk. \
Unlimited by default.
Raises:
TypeError: If `queue_max_size` and `flush_time` is not int, or `file_prefix` and `file_suffix` is not str.
TypeError: If `max_file_size`, `queue_max_size` or `flush_time` is not int, \
or `file_prefix` and `file_suffix` is not str.
RuntimeError: If the log_dir can not be resolved to a canonicalized absolute pathname.
Examples:
......@@ -103,7 +106,8 @@ class SummaryRecord:
flush_time=120,
file_prefix="events",
file_suffix="_MS",
network=None):
network=None,
max_file_size=None):
self._closed, self._event_writer = False, None
self._mode, self._data_pool = 'train', _dictlist()
......@@ -113,11 +117,18 @@ class SummaryRecord:
self.log_path = _make_directory(log_dir)
if not isinstance(max_file_size, (int, type(None))):
raise TypeError("The 'max_file_size' should be int type.")
if not isinstance(queue_max_size, int) or not isinstance(flush_time, int):
raise TypeError("`queue_max_size` and `flush_time` should be int")
if not isinstance(file_prefix, str) or not isinstance(file_suffix, str):
raise TypeError("`file_prefix` and `file_suffix` should be str.")
if max_file_size is not None and max_file_size < 0:
logger.warning("The 'max_file_size' should be greater than 0.")
max_file_size = None
self.queue_max_size = queue_max_size
if queue_max_size < 0:
# 0 is not limit
......@@ -142,6 +153,7 @@ class SummaryRecord:
raise RuntimeError(ex)
self._event_writer = WriterPool(log_dir,
max_file_size,
summary=self.full_file_name,
lineage=get_event_file_name('events', '_lineage'))
atexit.register(self.close)
......@@ -152,7 +164,7 @@ class SummaryRecord:
raise ValueError('SummaryRecord has been closed.')
return self
def __exit__(self, extype, exvalue, traceback):
def __exit__(self, *err):
"""Exit the context manager."""
self.close()
......@@ -229,13 +241,15 @@ class SummaryRecord:
else:
raise ValueError(f'No such plugin of {repr(plugin)}')
def record(self, step, train_network=None):
def record(self, step, train_network=None, plugin_filter=None):
"""
Record the summary.
Args:
step (int): Represents training step number.
train_network (Cell): The network that called the callback.
plugin_filter (Optional[Callable[[str], bool]]): The filter function, \
which is used to filter out plugins from being written by return False.
Returns:
bool, whether the record process is successful or not.
......@@ -266,7 +280,14 @@ class SummaryRecord:
if self._mode == 'train':
self._add_summary_tensor_data()
self._event_writer.write(self._consume_data_pool(step))
if not plugin_filter:
self._event_writer.write(self._consume_data_pool(step))
else:
filtered = {}
for plugin, datalist in self._consume_data_pool(step).items():
if plugin_filter(plugin):
filtered[plugin] = datalist
self._event_writer.write(filtered)
return True
def _add_summary_tensor_data(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册