提交 108dd7a4 编写于 作者: O ougongchang

Make sure to record the first-step data in SummaryCollector, and catch the...

Make sure to record the first-step data in SummaryCollector, and catch the ValueError raised when the loss is not a scalar.
上级 0478b7d1
......@@ -73,7 +73,7 @@ class SummaryCollector(Callback):
summary_dir (str): The collected data will be persisted to this directory.
If the directory does not exist, it will be created automatically.
collect_freq (int): Set the frequency of data collection, it should be greater then zero,
and the unit is `step`. Default: 10.
and the unit is `step`. Default: 10. The first step will be recorded at any time.
It is important to note that if the data sink mode is used, the unit will become the `epoch`.
It is not recommended to collect data too frequently, which can affect performance.
collect_specified_data (Union[None, dict]): Perform custom operations on the collected data. Default: None.
......@@ -142,9 +142,6 @@ class SummaryCollector(Callback):
'histogram_regular': None
}
# _OPTIMIZER_FAILED means find optimizer failed, so we will not collect data about optimizer.
_OPTIMIZER_FAILED = 'Failed'
def __init__(self, summary_dir, collect_freq=10, collect_specified_data=None,
keep_default_action=True, custom_lineage_data=None):
super(SummaryCollector, self).__init__()
......@@ -158,7 +155,8 @@ class SummaryCollector(Callback):
self._check_action(keep_default_action)
self._collect_specified_data = self._process_specified_data(collect_specified_data, keep_default_action)
logger.info(f"For `collect_specified_data` the value after processing is: {self._collect_specified_data}.")
msg = f"For 'collect_specified_data' the value after processing is: {self._collect_specified_data}."
logger.info(msg)
self._check_custom_lineage_data(custom_lineage_data)
self._custom_lineage_data = custom_lineage_data
......@@ -167,6 +165,7 @@ class SummaryCollector(Callback):
self._has_saved_train_network = False
self._has_saved_custom_data = False
self._is_parse_loss_success = True
self._first_step = True
def __enter__(self):
self._record = SummaryRecord(log_dir=self._summary_dir)
......@@ -228,7 +227,7 @@ class SummaryCollector(Callback):
if specified_data is None:
if action:
return self._DEFAULT_SPECIFIED_DATA
return None
return dict()
check_value_type('collect_specified_data', specified_data, [dict, type(None)])
......@@ -282,9 +281,13 @@ class SummaryCollector(Callback):
cb_params = run_context.original_args()
if cb_params.mode == ModeEnum.TRAIN.value:
if cb_params.cur_step_num % self._collect_freq:
# Make sure the first step data is recorded
if not self._first_step and cb_params.cur_step_num % self._collect_freq:
return
self._first_step = False
if not self._has_saved_train_network:
self._collect_graphs(cb_params)
......@@ -357,7 +360,8 @@ class SummaryCollector(Callback):
input_data = input_data[0]
try:
self._record.add_value(PluginEnum.IMAGE.value, 'input_data/auto', input_data)
except ValueError:
except ValueError as ex:
logger.warning(str(ex))
self._collect_specified_data['collect_input_data'] = False
return
......@@ -395,7 +399,12 @@ class SummaryCollector(Callback):
loss = self._get_loss(cb_params)
if loss is None:
return
self._record.add_value(PluginEnum.SCALAR.value, 'loss/auto', loss)
try:
self._record.add_value(PluginEnum.SCALAR.value, 'loss/auto', loss)
except ValueError as exc:
logger.warning(str(exc))
self._collect_specified_data['collect_metric'] = False
def _get_loss(self, cb_params):
"""
......@@ -446,7 +455,9 @@ class SummaryCollector(Callback):
Returns:
Union[Optimizer, None], if parse optimizer success, will return a optimizer, else return None.
"""
if self._optimizer == self._OPTIMIZER_FAILED:
# 'optimizer_failed' means find optimizer failed, so we will not collect data about optimizer.
optimizer_failed = 'Failed'
if self._optimizer == optimizer_failed:
return None
if self._optimizer is not None:
......@@ -458,9 +469,11 @@ class SummaryCollector(Callback):
optimizer = self._parse_optimizer_by_network(network)
if optimizer is None or not isinstance(optimizer, Optimizer):
logger.warning("Can not find optimizer in network, or the optimizer does not inherit Mindpore's optimizer, "
"so we will not collect data about optimizer in SummaryCollector.")
optimizer = self._OPTIMIZER_FAILED
logger.warning("Can not find optimizer in network, or the optimizer does not inherit MindSpore's "
"optimizer, so we will not collect data about optimizer in SummaryCollector.")
optimizer = None
self._optimizer = optimizer if optimizer is not None else optimizer_failed
return optimizer
......@@ -469,6 +482,8 @@ class SummaryCollector(Callback):
"""Parse optimizer from network, if parse success will return a optimizer, else return None."""
optimizer = None
for _, cell in network.cells_and_names():
if isinstance(cell, Optimizer):
return cell
try:
optimizer = getattr(cell, 'optimizer')
except AttributeError:
......@@ -489,11 +504,11 @@ class SummaryCollector(Callback):
if 'histogram_regular' not in self._collect_specified_data:
return
self._optimizer = self._get_optimizer(cb_params)
if self._optimizer is None:
optimizer = self._get_optimizer(cb_params)
if optimizer is None:
return
parameters = self._optimizer.parameters
parameters = optimizer.parameters
regular = self._collect_specified_data.get('histogram_regular')
if regular is not None:
for parameter in parameters:
......@@ -538,7 +553,7 @@ class SummaryCollector(Callback):
train_lineage[LineageMetadata.loss] = None
optimizer = self._get_optimizer(cb_params)
learning_rate = self._get_learning_rate(optimizer)
learning_rate = self._get_learning_rate(optimizer) if optimizer is not None else None
if learning_rate is not None:
train_lineage[LineageMetadata.learning_rate] = list(np.atleast_1d(learning_rate.asnumpy()))[0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册