Commit 12344396 authored by mindspore-ci-bot, committed by Gitee

!1174 Make optimizer parameters the same as the gradients

Merge pull request !1174 from ghzl/fix-beg-group-parameters
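
The change itself is mechanical: instead of rebuilding their weight list with ParameterTuple(network.trainable_params()), the training wrappers now reuse optimizer.parameters, so the gradients returned by C.GradOperation(..., get_by_list=True) are paired index for index with the parameters the optimizer actually updates. That matters once the optimizer is built from grouped parameters, whose order and selection can differ from network.trainable_params(). Below is a minimal sketch of the situation, assuming the grouped-parameter form of nn.Momentum; the nn.Dense network and all variable names are illustrative, not part of this commit.

import mindspore.nn as nn
from mindspore.ops import composite as C

# Illustrative network (hypothetical, not taken from this commit).
net = nn.Dense(16, 10)

# Grouped parameters: apply weight decay to the weight, not the bias.
decayed = [p for p in net.trainable_params() if 'weight' in p.name]
no_decay = [p for p in net.trainable_params() if 'bias' in p.name]
group_params = [{'params': decayed, 'weight_decay': 0.01},
                {'params': no_decay}]
optimizer = nn.Momentum(group_params, learning_rate=0.1, momentum=0.9)

# Before this commit a wrapper did:
#     self.weights = ParameterTuple(network.trainable_params())
# which need not match the order or content of the optimizer's own tuple.
# After this commit it reuses the optimizer's tuple directly:
weights = optimizer.parameters
grad_fn = C.GradOperation('grad', get_by_list=True)
# grad_fn(net, weights)(inputs) now yields gradients aligned one-to-one
# with the parameters that optimizer(grads) will update.

The hunks below make exactly this swap, once in the DistributedGradReducer docstring example and once in TrainOneStepWithLossScaleCell, and drop the now-unused ParameterTuple import.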
@@ -141,7 +141,7 @@ class DistributedGradReducer(Cell):
 >>> super(TrainingWrapper, self).__init__(auto_prefix=False)
 >>> self.network = network
 >>> self.network.add_flags(defer_inline=True)
->>> self.weights = ParameterTuple(network.trainable_params())
+>>> self.weights = optimizer.parameters
 >>> self.optimizer = optimizer
 >>> self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
 >>> self.sens = sens
......
@@ -18,7 +18,7 @@ from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.train.parallel_utils import ParallelMode
 from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
 from ..cell import Cell
-from ...common import Tensor, ParameterTuple
+from ...common import Tensor
 from ...common.parameter import Parameter
 from ...ops import functional as F
 from ...ops import composite as C
@@ -201,7 +201,7 @@ class TrainOneStepWithLossScaleCell(Cell):
 super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
 self.network = network
 self.network.add_flags(defer_inline=True)
-self.weights = ParameterTuple(network.trainable_params())
+self.weights = optimizer.parameters
 self.optimizer = optimizer
 self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
 self.hyper_map = C.HyperMap()
......