提交 3d354d76 编写于 作者: W wanyiming

mod_callback

上级 1b63c76c
......@@ -75,9 +75,9 @@ class Callback:
"""
Abstract base class used to build a callback class. Callbacks are context managers
which will be entered and exited when passing into the Model.
You can leverage this mechanism to init and release resources automatically.
You can use this mechanism to initialize and release resources automatically.
Callback function will execution some operating to the current step or epoch.
Callback function will execute some operations in the current step or epoch.
Examples:
>>> class Print_info(Callback):
......@@ -229,11 +229,11 @@ class RunContext:
"""
Provides information about the model.
Run call being made. Provides information about original request to model function.
callback objects can stop the loop by calling request_stop() of run_context.
Provides information about original request to model function.
Callback objects can stop the loop by calling request_stop() of run_context.
Args:
original_args (dict): Holding the related information of model etc.
original_args (dict): Holding the related information of model.
"""
def __init__(self, original_args):
if not isinstance(original_args, dict):
......@@ -246,13 +246,13 @@ class RunContext:
Get the _original_args object.
Returns:
Dict, a object holding the original arguments of model.
Dict, an object that holds the original arguments of model.
"""
return self._original_args
def request_stop(self):
"""
Sets stop requested during training.
Sets stop requirement during training.
Callbacks can use this function to request stop of iterations.
model.train() checks whether this is called or not.
......
......@@ -70,23 +70,24 @@ def _chg_ckpt_file_name_if_same_exist(directory, prefix):
class CheckpointConfig:
"""
The config for model checkpoint.
The configuration of model checkpoint.
Note:
During the training process, if dataset is transmitted through the data channel,
suggest set save_checkpoint_steps be an integer multiple of loop_size.
Otherwise there may be deviation in the timing of saving checkpoint.
It is suggested to set 'save_checkpoint_steps' to an integer multiple of loop_size.
Otherwise, the time to save the checkpoint may be biased.
Args:
save_checkpoint_steps (int): Steps to save checkpoint. Default: 1.
save_checkpoint_seconds (int): Seconds to save checkpoint. Default: 0.
Can't be used with save_checkpoint_steps at the same time.
keep_checkpoint_max (int): Maximum step to save checkpoint. Default: 5.
keep_checkpoint_max (int): Maximum number of checkpoint files can be saved. Default: 5.
keep_checkpoint_per_n_minutes (int): Keep one checkpoint every n minutes. Default: 0.
Can't be used with keep_checkpoint_max at the same time.
integrated_save (bool): Whether to intergrated save in automatic model parallel scene. Default: True.
Integrated save function is only supported in automatic parallel scene, not supported in manual parallel.
async_save (bool): Whether asynchronous execute save checkpoint into file. Default: False
integrated_save (bool): Whether to perform integrated save function in automatic model parallel scene.
Default: True. Integrated save function is only supported in automatic parallel scene, not supported
in manual parallel.
async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False
Raises:
ValueError: If the input_param is None or 0.
......@@ -180,9 +181,9 @@ class ModelCheckpoint(Callback):
It is called to combine with train process and save the model and network parameters after traning.
Args:
prefix (str): Checkpoint files names prefix. Default: "CKP".
directory (str): Folder path into which checkpoint files will be saved. Default: None.
config (CheckpointConfig): Checkpoint strategy config. Default: None.
prefix (str): The prefix name of checkpoint files. Default: "CKP".
directory (str): The path of the folder which will be saved in the checkpoint file. Default: None.
config (CheckpointConfig): Checkpoint strategy configuration. Default: None.
Raises:
ValueError: If the prefix is invalid.
......
......@@ -27,13 +27,13 @@ class LossMonitor(Callback):
If the loss is NAN or INF, it will terminate training.
Note:
If per_print_times is 0 do not print loss.
If per_print_times is 0, do not print loss.
Args:
per_print_times (int): Print loss every times. Default: 1.
per_print_times (int): Print the loss each every time. Default: 1.
Raises:
ValueError: If print_step is not int or less than zero.
ValueError: If print_step is not an integer or less than zero.
"""
def __init__(self, per_print_times=1):
......
......@@ -62,7 +62,7 @@ class SummaryCollector(Callback):
SummaryCollector can help you to collect some common information.
It can help you to collect loss, learning late, computational graph and so on.
SummaryCollector also persists data collected by the summary operator into a summary file.
SummaryCollector also enables the summary operator to collect data from a summary file.
Note:
1. Multiple SummaryCollector instances in callback list are not allowed.
......@@ -74,51 +74,51 @@ class SummaryCollector(Callback):
If the directory does not exist, it will be created automatically.
collect_freq (int): Set the frequency of data collection, it should be greater then zero,
and the unit is `step`. Default: 10. If a frequency is set, we will collect data
at (current steps % freq) == 0, and the first step will be collected at any time.
when (current steps % freq) equals to 0, and the first step will be collected at any time.
It is important to note that if the data sink mode is used, the unit will become the `epoch`.
It is not recommended to collect data too frequently, which can affect performance.
collect_specified_data (Union[None, dict]): Perform custom operations on the collected data. Default: None.
By default, if set to None, all data is collected as the default behavior.
If you want to customize the data collected, you can do so with a dictionary.
Examples,you can set {'collect_metric': False} to control not collecting metrics.
You can customize the collected data with a dictionary.
For example, you can set {'collect_metric': False} to control not collecting metrics.
The data that supports control is shown below.
- collect_metric: Whether to collect training metrics, currently only loss is collected.
The first output will be treated as loss, and it will be averaged.
- collect_metric: Whether to collect training metrics, currently only the loss is collected.
The first output will be treated as the loss and it will be averaged.
Optional: True/False. Default: True.
- collect_graph: Whether to collect computational graph, currently only
- collect_graph: Whether to collect the computational graph. Currently, only
training computational graph is collected. Optional: True/False. Default: True.
- collect_train_lineage: Whether to collect lineage data for the training phase,
this field will be displayed on the lineage page of Mindinsight. Optional: True/False. Default: True.
- collect_eval_lineage: Whether to collect lineage data for the eval phase,
- collect_eval_lineage: Whether to collect lineage data for the evaluation phase,
this field will be displayed on the lineage page of Mindinsight. Optional: True/False. Default: True.
- collect_input_data: Whether to collect dataset for each training. Currently only image data is supported.
Optional: True/False. Default: True.
- collect_dataset_graph: Whether to collect dataset graph for the training phase.
Optional: True/False. Default: True.
- histogram_regular: Collect weight and bias for parameter distribution page display in MindInsight.
- histogram_regular: Collect weight and bias for parameter distribution page and displayed in MindInsight.
This field allows regular strings to control which parameters to collect.
Default: None, it means only the first five parameters are collected.
It is not recommended to collect too many parameters at once, as it can affect performance.
Note that if you collect too many parameters and run out of memory, the training will fail.
keep_default_action (bool): This field affects the collection behavior of the 'collect_specified_data' field.
Optional: True/False, Default: True.
True: means that after specified data is set, non-specified data is collected as the default behavior.
False: means that after specified data is set, only the specified data is collected,
True: it means that after specified data is set, non-specified data is collected as the default behavior.
False: it means that after specified data is set, only the specified data is collected,
and the others are not collected.
custom_lineage_data (Union[dict, None]): Allows you to customize the data and present it on the MingInsight
lineage page. In the custom data, the key type support str, and the value type support str/int/float.
Default: None, it means there is no custom data.
collect_tensor_freq (Optional[int]): Same semantic as the `collect_freq`, but controls TensorSummary only.
Because TensorSummary data is too large compared to other summary data, this parameter is used to reduce
its collection. By default, TensorSummary data will be collected at most 20 steps, but not more than how
many steps other summary data will be collected.
lineage page. In the custom data, the type of the key supports str, and the type of value supports str, int
and float. Default: None, it means there is no custom data.
collect_tensor_freq (Optional[int]): The same semantics as the `collect_freq`, but controls TensorSummary only.
Because TensorSummary data is too large to be compared with other summary data, this parameter is used to
reduce its collection. By default, The maximum number of steps for collecting TensorSummary data is 21,
but it will not exceed the number of steps for collecting other summary data.
Default: None, which means to follow the behavior as described above. For example, given `collect_freq=10`,
when the total steps is 600, TensorSummary will be collected 20 steps, while other summary data 61 steps,
but when the total steps is 20, both TensorSummary and other summary will be collected 3 steps.
Also note that when in parallel mode, the total steps will be splitted evenly, which will
affect how many steps TensorSummary will be collected.
max_file_size (Optional[int]): The maximum size in bytes each file can be written to the disk.
affect the number of steps TensorSummary will be collected.
max_file_size (Optional[int]): The maximum size in bytes of each file that can be written to the disk.
Default: None, which means no limit. For example, to write not larger than 4GB,
specify `max_file_size=4 * 1024**3`.
......
......@@ -41,7 +41,7 @@ class FixedLossScaleManager(LossScaleManager):
Args:
loss_scale (float): Loss scale. Default: 128.0.
drop_overflow_update (bool): whether to do optimizer if there is overflow. Default: True.
drop_overflow_update (bool): whether to execute optimizer if there is an overflow. Default: True.
Examples:
>>> loss_scale_manager = FixedLossScaleManager()
......@@ -59,7 +59,7 @@ class FixedLossScaleManager(LossScaleManager):
return self._loss_scale
def get_drop_overflow_update(self):
"""Get the flag whether to drop optimizer update when there is overflow happened"""
"""Get the flag whether to drop optimizer update when there is an overflow."""
return self._drop_overflow_update
def update_loss_scale(self, overflow):
......@@ -82,7 +82,7 @@ class DynamicLossScaleManager(LossScaleManager):
Dynamic loss-scale manager.
Args:
init_loss_scale (float): Init loss scale. Default: 2**24.
init_loss_scale (float): Initialize loss scale. Default: 2**24.
scale_factor (int): Coefficient of increase and decrease. Default: 2.
scale_window (int): Maximum continuous normal steps when there is no overflow. Default: 2000.
......@@ -135,7 +135,7 @@ class DynamicLossScaleManager(LossScaleManager):
self.cur_iter += 1
def get_drop_overflow_update(self):
"""Get the flag whether to drop optimizer update when there is overflow happened"""
"""Get the flag whether to drop optimizer update when there is an overflow."""
return True
def get_update_cell(self):
......
......@@ -13,11 +13,11 @@
# limitations under the License.
# ============================================================================
"""
quantization.
Quantization.
User can use quantization aware to train a model. MindSpore supports quantization aware training,
which models quantization errors in both the forward and backward passes using fake-quantization
ops. Note that the entire computation is carried out in floating point. At the end of quantization
operations. Note that the entire computation is carried out in floating point. At the end of quantization
aware training, MindSpore provides conversion functions to convert the trained model into lower precision.
"""
......
......@@ -474,8 +474,8 @@ def export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format='
def convert_quant_network(network,
bn_fold=False,
freeze_bn=10000,
bn_fold=True,
freeze_bn=10000000,
quant_delay=(0, 0),
num_bits=(8, 8),
per_channel=(False, False),
......@@ -487,21 +487,20 @@ def convert_quant_network(network,
Args:
network (Cell): Obtain a pipeline through network for saving graph summary.
bn_fold (bool): Flag to used bn fold ops for simulation inference operation. Default: False.
freeze_bn (int): Number of steps after which BatchNorm OP parameters used total mean and variance. Default: 0.
bn_fold (bool): Flag to used bn fold ops for simulation inference operation. Default: True.
freeze_bn (int): Number of steps after which BatchNorm OP parameters used total mean and variance. Default: 1e7.
quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized during
eval. The first element represent weights and second element represent data flow. Default: (0, 0)
num_bits (int, list or tuple): Number of bits to use for quantizing weights and activations. The first
num_bits (int, list or tuple): Number of bits to use for quantize weights and activations. The first
element represent weights and second element represent data flow. Default: (8, 8)
per_channel (bool, list or tuple): Quantization granularity based on layer or on channel. If `True`
then base on per channel otherwise base on per layer. The first element represent weights
and second element represent data flow. Default: (False, False)
symmetric (bool, list or tuple): Quantization algorithm use symmetric or not. If `True` then base on
symmetric (bool, list or tuple): Whether the quantization algorithm is symmetric or not. If `True` then base on
symmetric otherwise base on asymmetric. The first element represent weights and second
element represent data flow. Default: (False, False)
narrow_range (bool, list or tuple): Quantization algorithm use narrow range or not. If `True` then base
on narrow range otherwise base on off narrow range. The first element represent weights and
second element represent data flow. Default: (False, False)
narrow_range (bool, list or tuple): Whether the quantization algorithm uses narrow range or not.
The first element represents weights and the second element represents data flow. Default: (False, False)
Returns:
Cell, Network which has change to quantization aware training network cell.
......
......@@ -144,10 +144,10 @@ def save_checkpoint(parameter_list, ckpt_file_name, async_save=False):
Saves checkpoint info to a specified file.
Args:
parameter_list (list): Parameters list, each element is a dict
parameter_list (list): Parameters list, each element is a dictionary
like {"name":xx, "type":xx, "shape":xx, "data":xx}.
ckpt_file_name (str): Checkpoint file name.
async_save (bool): Whether asynchronous execute save checkpoint into file. Default: False
async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False
Raises:
RuntimeError: Failed to save the Checkpoint file.
......@@ -270,10 +270,10 @@ def load_param_into_net(net, parameter_dict):
Args:
net (Cell): Cell network.
parameter_dict (dict): Parameter dict.
parameter_dict (dict): Parameter dictionary.
Raises:
TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dict.
TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary.
"""
if not isinstance(net, nn.Cell):
logger.error("Failed to combine the net and the parameters.")
......@@ -447,12 +447,12 @@ def _fill_param_into_net(net, parameter_list):
def export(net, *inputs, file_name, file_format='AIR'):
"""
Exports MindSpore predict model to file in specified format.
Export the MindSpore prediction model to a file in the specified format.
Args:
net (Cell): MindSpore network.
inputs (Tensor): Inputs of the `net`.
file_name (str): File name of model to export.
file_name (str): File name of the model to be exported.
file_format (str): MindSpore currently supports 'AIR', 'ONNX' and 'MINDIR' format for exported model.
- AIR: Ascend Intermidiate Representation. An intermidiate representation format of Ascend model.
......@@ -507,7 +507,7 @@ def parse_print(print_file_name):
Loads Print data from a specified file.
Args:
print_file_name (str): The file name of save print data.
print_file_name (str): The file name of saved print data.
Returns:
List, element of list is Tensor.
......
......@@ -64,29 +64,29 @@ class SummaryRecord:
SummaryRecord is used to record the summary data and lineage data.
The API will create a summary file and lineage files lazily in a given directory and writes data to them.
It writes the data to files by executing the 'record' method. In addition to record the data bubbled up from
It writes the data to files by executing the 'record' method. In addition to recording the data bubbled up from
the network by defining the summary operators, SummaryRecord also supports to record extra data which
can be added by calling add_value.
Note:
1. Make sure to close the SummaryRecord at the end, or the process will not exit.
Please see the Example section below on how to properly close with two ways.
2. The SummaryRecord instance can only allow one at a time, otherwise it will cause problems with data writes.
1. Make sure to close the SummaryRecord at the end, otherwise the process will not exit.
Please see the Example section below to learn how to close properly in two ways.
2. Only one SummaryRecord instance is allowed at a time, otherwise it will cause data writing problems.
Args:
log_dir (str): The log_dir is a directory location to save the summary.
queue_max_size (int): Deprecated. The capacity of event queue.(reserved). Default: 0.
flush_time (int): Deprecated. Frequency to flush the summaries to disk, the unit is second. Default: 120.
flush_time (int): Deprecated. Frequency of flush the summary file to disk. The unit is second. Default: 120.
file_prefix (str): The prefix of file. Default: "events".
file_suffix (str): The suffix of file. Default: "_MS".
network (Cell): Obtain a pipeline through network for saving graph summary. Default: None.
max_file_size (Optional[int]): The maximum size in bytes each file can be written to the disk. \
max_file_size (Optional[int]): The maximum size of each file that can be written to disk (in bytes). \
Unlimited by default. For example, to write not larger than 4GB, specify `max_file_size=4 * 1024**3`.
Raises:
TypeError: If `max_file_size`, `queue_max_size` or `flush_time` is not int, \
or `file_prefix` and `file_suffix` is not str.
RuntimeError: If the log_dir can not be resolved to a canonicalized absolute pathname.
TypeError: If the data type of `max_file_size`, `queue_max_size` or `flush_time` is not int, \
or the data type of `file_prefix` and `file_suffix` is not str.
RuntimeError: If the log_dir is not a normalized absolute path name.
Examples:
>>> # use in with statement to auto close
......@@ -171,10 +171,10 @@ class SummaryRecord:
def set_mode(self, mode):
"""
Set the mode for the recorder to be aware. The mode is set 'train' by default.
Set the mode for the recorder to be aware. The mode is set to 'train' by default.
Args:
mode (str): The mode to set, which should be 'train' or 'eval'.
mode (str): The mode to be set, which should be 'train' or 'eval'.
Raises:
ValueError: When the mode is not recognized.
......@@ -190,29 +190,30 @@ class SummaryRecord:
def add_value(self, plugin, name, value):
"""
Add value to be record later on.
Add value to be recorded later.
When the plugin is 'tensor', 'scalar', 'image' or 'histogram',
the name should be the tag name, and the value should be a Tensor.
When the plugin plugin is 'graph', the value should be a GraphProto.
When the plugin is 'graph', the value should be a GraphProto.
When the plugin 'dataset_graph', 'train_lineage', 'eval_lineage',
When the plugin is 'dataset_graph', 'train_lineage', 'eval_lineage',
or 'custom_lineage_data', the value should be a proto message.
Args:
plugin (str): The plugin for the value.
name (str): The name for the value.
plugin (str): The value of the plugin.
name (str): The value of the name.
value (Union[Tensor, GraphProto, TrainLineage, EvaluationLineage, DatasetGraph, UserDefinedInfo]): \
The value to store.
- GraphProto: The 'value' should be a serialized string this type when the plugin is 'graph'.
- Tensor: The 'value' should be this type when the plugin is 'scalar', 'image', 'tensor' or 'histogram'.
- TrainLineage: The 'value' should be this type when the plugin is 'train_lineage'.
- EvaluationLineage: The 'value' should be this type when the plugin is 'eval_lineage'.
- DatasetGraph: The 'value' should be this type when the plugin is 'dataset_graph'.
- UserDefinedInfo: The 'value' should be this type when the plugin is 'custom_lineage_data'.
- The data type of value should be 'GraphProto' when the plugin is 'graph'.
- The data type of value should be 'Tensor' when the plugin is 'scalar', 'image', 'tensor'
or 'histogram'.
- The data type of value should be 'TrainLineage' when the plugin is 'train_lineage'.
- The data type of value should be 'EvaluationLineage' when the plugin is 'eval_lineage'.
- The data type of value should be 'DatasetGraph' when the plugin is 'dataset_graph'.
- The data type of value should be 'UserDefinedInfo' when the plugin is 'custom_lineage_data'.
Raises:
ValueError: When the name is not valid.
......@@ -248,9 +249,9 @@ class SummaryRecord:
Args:
step (int): Represents training step number.
train_network (Cell): The network that called the callback.
train_network (Cell): The network to call the callback.
plugin_filter (Optional[Callable[[str], bool]]): The filter function, \
which is used to filter out plugins from being written by return False.
which is used to filter out plugins from being written by returning False.
Returns:
bool, whether the record process is successful or not.
......@@ -342,7 +343,7 @@ class SummaryRecord:
def close(self):
"""
Flush all events and close summary records. Please use with statement to autoclose.
Flush all events and close summary records. Please use the statement to autoclose.
Examples:
>>> try:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册