提交 026851ed 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!617 modify parallel checkpoint

Merge pull request !617 from changzherui/mod_parall_ckpt
...@@ -282,8 +282,8 @@ The following uses a specific model parameter as an example. The parameter name ...@@ -282,8 +282,8 @@ The following uses a specific model parameter as an example. The parameter name
3. Modify values of model parameters. 3. Modify values of model parameters.
``` ```
new_param.set_parameter_data(tensor_slice) new_param.set_parameter_data(tensor_slice, True)
new_param_moments.set_parameter_data(tensor_slice_moments) new_param_moments.set_parameter_data(tensor_slice_moments, True)
``` ```
- `set_parameter_data`: sets the value of a model parameter. The API parameter type is Tensor or number. - `set_parameter_data`: sets the value of a model parameter. The API parameter type is Tensor or number.
...@@ -389,7 +389,7 @@ User process: ...@@ -389,7 +389,7 @@ User process:
# merge the parallel parameters of the model # merge the parallel parameters of the model
allgather_net = get_allgather_cell() allgather_net = get_allgather_cell()
param_data = allgather_net(param_data) param_data = allgather_net(param_data)
layerwise_param.set_parameter_data(param_data) layerwise_param.set_parameter_data(param_data, True)
# convert param_dict to list type data # convert param_dict to list type data
param_list = [] param_list = []
...@@ -553,7 +553,7 @@ User process: ...@@ -553,7 +553,7 @@ User process:
rank = get_rank() rank = get_rank()
tensor_slice = Tensor(slice_list[rank]) tensor_slice = Tensor(slice_list[rank])
# modify model parameter data values # modify model parameter data values
new_param.set_parameter_data(tensor_slice) new_param.set_parameter_data(tensor_slice, True)
# load the modified parameter data into the network # load the modified parameter data into the network
weight = np.ones([4, 8]).astype(np.float32) weight = np.ones([4, 8]).astype(np.float32)
......
...@@ -151,13 +151,15 @@ metrics = { ...@@ -151,13 +151,15 @@ metrics = {
net = ResNet() net = ResNet()
loss = CrossEntropyLoss() loss = CrossEntropyLoss()
opt = Momentum() opt = Momentum()
model = Model(net, loss_fn=loss, optimizer=opt, metrics=metrics) model = Model(net, loss_fn=loss, optimizer=opt, metrics=metrics, callbacks=TimeMonitor())
ds_eval = create_dataset() ds_eval = create_dataset()
output = model.eval(ds_eval) output = model.eval(ds_eval)
``` ```
The `model.eval` method returns a dictionary that contains the metrics and results transferred to the metrics. The `model.eval` method returns a dictionary that contains the metrics and results transferred to the metrics.
The callback function can also be used in the eval process, and the user can call the related API or customize the callback method to achieve the desired function.
You can also define your own metrics class by inheriting the `Metric` base class and rewriting the `clear`, `update`, and `eval` methods. You can also define your own metrics class by inheriting the `Metric` base class and rewriting the `clear`, `update`, and `eval` methods.
The `accuracy` operator is used as an example to describe the internal implementation principle. The `accuracy` operator is used as an example to describe the internal implementation principle.
......
...@@ -286,8 +286,8 @@ param_dict = load_checkpoint("./CKP-Integrated_1-4_32.ckpt") ...@@ -286,8 +286,8 @@ param_dict = load_checkpoint("./CKP-Integrated_1-4_32.ckpt")
3. 修改模型参数数据值。 3. 修改模型参数数据值。
``` ```
new_param.set_parameter_data(tensor_slice) new_param.set_parameter_data(tensor_slice, True)
new_param_moments.set_parameter_data(tensor_slice_moments) new_param_moments.set_parameter_data(tensor_slice_moments, True)
``` ```
- `set_parameter_data`:设置模型参数的值,接口参数类型为Tensor 或number。 - `set_parameter_data`:设置模型参数的值,接口参数类型为Tensor 或number。
...@@ -392,7 +392,7 @@ load_param_into_net(opt, param_dict) ...@@ -392,7 +392,7 @@ load_param_into_net(opt, param_dict)
# merge the parallel parameters of the model # merge the parallel parameters of the model
allgather_net = get_allgather_cell() allgather_net = get_allgather_cell()
param_data = allgather_net(param_data) param_data = allgather_net(param_data)
layerwise_param.set_parameter_data(param_data) layerwise_param.set_parameter_data(param_data, True)
# convert param_dict to list type data # convert param_dict to list type data
param_list = [] param_list = []
...@@ -556,7 +556,7 @@ load_param_into_net(opt, param_dict) ...@@ -556,7 +556,7 @@ load_param_into_net(opt, param_dict)
rank = get_rank() rank = get_rank()
tensor_slice = Tensor(slice_list[rank]) tensor_slice = Tensor(slice_list[rank])
# modify model parameter data values # modify model parameter data values
new_param.set_parameter_data(tensor_slice) new_param.set_parameter_data(tensor_slice, True)
# load the modified parameter data into the network # load the modified parameter data into the network
weight = np.ones([4, 8]).astype(np.float32) weight = np.ones([4, 8]).astype(np.float32)
......
...@@ -133,6 +133,7 @@ epoch: 20 step: 32 loss: 2.298344373703003 ...@@ -133,6 +133,7 @@ epoch: 20 step: 32 loss: 2.298344373703003
同时可以对字典内的值进行修改和添加,上述用例中,在`begin`中定义一个`init_time`对象传递给`cb_params`字典。 同时可以对字典内的值进行修改和添加,上述用例中,在`begin`中定义一个`init_time`对象传递给`cb_params`字典。
在每次`step_end`会做出判断,当训练时间大于设置的时间阈值时,会向`run_context`传递终止训练的信号,提前终止训练,并打印当前的`epoch``step``loss`的值。 在每次`step_end`会做出判断,当训练时间大于设置的时间阈值时,会向`run_context`传递终止训练的信号,提前终止训练,并打印当前的`epoch``step``loss`的值。
## MindSpore metrics功能介绍 ## MindSpore metrics功能介绍
当训练结束后,可以使用metrics评估训练结果的好坏。 当训练结束后,可以使用metrics评估训练结果的好坏。
...@@ -152,13 +153,15 @@ metrics = { ...@@ -152,13 +153,15 @@ metrics = {
net = ResNet() net = ResNet()
loss = CrossEntropyLoss() loss = CrossEntropyLoss()
opt = Momentum() opt = Momentum()
model = Model(net, loss_fn=loss, optimizer=opt, metrics=metrics) model = Model(net, loss_fn=loss, optimizer=opt, metrics=metrics, callbacks=TimeMonitor())
ds_eval = create_dataset() ds_eval = create_dataset()
output = model.eval(ds_eval) output = model.eval(ds_eval)
``` ```
`model.eval`方法会返回一个字典,里面是传入metrics的指标和结果。 `model.eval`方法会返回一个字典,里面是传入metrics的指标和结果。
在eval过程中也可以使用callback功能,用户可以调用相关API或自定义callback方法实现想要的功能。
用户也可以定义自己的`metrics`类,通过继承`Metric`基类,并重写`clear``update``eval`三个方法即可实现。 用户也可以定义自己的`metrics`类,通过继承`Metric`基类,并重写`clear``update``eval`三个方法即可实现。
`accuracy`算子举例说明其内部实现原理: `accuracy`算子举例说明其内部实现原理:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册