提交 9be17e2a 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2668 support weight decay for sparse optimizer

Merge pull request !2668 from wangnan39/support_weight_decay_for_sparse_optimizer
......@@ -164,7 +164,7 @@ class Adam(Optimizer):
The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
`sparse_grad` of `Parameter` being set. The sparse feature is under continuous development. The sparse
behavior is currently performed on the CPU, weight decay is not supported.
behavior is currently performed on the CPU.
Args:
params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
......
......@@ -73,7 +73,7 @@ class FTRL(Optimizer):
Note:
The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
`sparse_grad` of `Parameter` being set. The sparse feature is under continuous development. The sparse
behavior is currently performed on the CPU, weight decay is not supported.
behavior is currently performed on the CPU.
Args:
params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
......@@ -124,7 +124,7 @@ class FTRL(Optimizer):
linear = self.linear
lr = self.learning_rate
if self.weight_decay > 0.0:
grads = self.hyper_map(F.partial(_apply_decay, self.weight_decay), self.decay_tf, params, grads)
grads = self.map_(F.partial(_apply_decay, self.weight_decay), self.decay_tf, params, grads)
grads = self.scale_grad(grads)
success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, lr, self.l1, self.l2, self.lr_power),
......
......@@ -94,8 +94,7 @@ class LazyAdam(Optimizer):
The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
`sparse_grad` of `Parameter` being set. The sparse behavior, to be notice, is not equivalent to the
original Adam algorithm, as only the current indices parames will be updated. The sparse feature is under
continuous development. The sparse behavior is currently performed on the CPU, weight decay is
not supported.
continuous development. The sparse behavior is currently performed on the CPU.
Args:
params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
......
......@@ -195,12 +195,12 @@ class Optimizer(Cell):
params = self.parameters
if self.is_group:
if self.exec_weight_decay:
gradients = self.hyper_map(F.partial(_apply_decay), self.weight_decay, self.decay_flags,
params, gradients)
gradients = self.map_(F.partial(_apply_decay), self.weight_decay, self.decay_flags,
params, gradients)
else:
if self.weight_decay > 0:
gradients = self.hyper_map(F.partial(_apply_decay, self.weight_decay), self.decay_flags,
params, gradients)
gradients = self.map_(F.partial(_apply_decay, self.weight_decay), self.decay_flags,
params, gradients)
return gradients
......@@ -479,10 +479,20 @@ class Optimizer(Cell):
op_add = P.AddN()
op_gather = P.GatherV2()
_apply_decay = C.MultitypeFuncGraph("apply_decay")
@_apply_decay.register("Number", "Bool", "Tensor", "Tuple")
def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient):
"""Get grad with weight_decay."""
if if_apply:
weight = op_gather(weight, gradient[0], 0)
return gradient[0], op_add((weight * weight_decay, gradient[1])), gradient[2]
return gradient
@_apply_decay.register("Number", "Bool", "Tensor", "Tensor")
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
"""Get grad with weight_decay."""
......
......@@ -60,7 +60,7 @@ class ProximalAdagrad(Optimizer):
Note:
The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
`sparse_grad` of `Parameter` being set as True. The sparse feature is under continuous development. The sparse
behavior is currently performed on the CPU, weight decay is not supported.
behavior is currently performed on the CPU.
Args:
params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
......
......@@ -107,7 +107,7 @@ def test_sparse_adam_compile():
net = NetWithSparseGatherV2()
net.set_train()
optimizer = Adam(net.trainable_params(), learning_rate=0.1, loss_scale=1024.0)
optimizer = Adam(net.trainable_params(), learning_rate=0.1, loss_scale=1024.0, weight_decay=0.9)
train_network = TrainOneStepCell(net, optimizer)
_executor.compile(train_network, indices, label)
......
......@@ -71,6 +71,6 @@ def test_spares_ftrl_compile():
net = NetWithSparseGatherV2()
net.set_train()
optimizer = FTRL(net.trainable_params(), loss_scale=2.0)
optimizer = FTRL(net.trainable_params(), weight_decay=0.9, loss_scale=2.0)
train_network = TrainOneStepCell(net, optimizer)
_executor.compile(train_network, indices, label)
......@@ -75,7 +75,7 @@ def test_spares_lazy_adam_compile():
net = NetWithSparseGatherV2()
net.set_train()
optimizer = LazyAdam(net.trainable_params(), learning_rate=0.1, loss_scale=2.0)
optimizer = LazyAdam(net.trainable_params(), learning_rate=0.1, weight_decay=0.9, loss_scale=2.0)
train_network = TrainOneStepCell(net, optimizer)
_executor.compile(train_network, indices, label)
......
......@@ -57,7 +57,7 @@ def test_proximal_ada_grad():
net = Net()
net.set_train()
loss = nn.SoftmaxCrossEntropyWithLogits()
optimizer = ProximalAdagrad(net.trainable_params())
optimizer = ProximalAdagrad(net.trainable_params(), weight_decay=0.9, loss_scale=1024.0)
net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, optimizer)
_executor.compile(train_network, inputs, label)
......@@ -70,6 +70,6 @@ def test_spares_proximal_ada_grad_compile():
net = NetWithSparseGatherV2()
net.set_train()
optimizer = ProximalAdagrad(net.trainable_params(), loss_scale=2.0)
optimizer = ProximalAdagrad(net.trainable_params(), weight_decay=0.9, loss_scale=1024.0)
train_network = TrainOneStepCell(net, optimizer)
_executor.compile(train_network, indices, label)
......@@ -57,7 +57,7 @@ def test_rmsprop_compile():
def test_rmsprop_e():
net = Net()
with pytest.raises(ValueError):
RMSProp(net.get_parameters(), momentum=-0.1, learning_rate=0.1)
RMSProp(net.get_parameters(), momentum=-0.1, learning_rate=0.1, weight_decay=0.9)
with pytest.raises(TypeError):
RMSProp(net.get_parameters(), momentum=1, learning_rate=0.1)
RMSProp(net.get_parameters(), momentum=1, learning_rate=0.1, weight_decay=0.9)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册