提交 352c6faf 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!18 enable use float type learning rate in lars optimizer

Merge pull request !18 from gziyan/master
......@@ -13,12 +13,14 @@
# limitations under the License.
# ============================================================================
"""lars optimizer"""
from typing import Iterable
from mindspore.common import dtype as mstype
from mindspore.common import Tensor
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.nn.cell import Cell
from .optimizer import grad_scale
......@@ -111,7 +113,8 @@ class LARS(Cell):
self.gather = None
self.global_step = None
self.axis = None
if not isinstance(self.learning_rate, float):
if isinstance(self.learning_rate.default_input, Iterable) or \
(isinstance(self.learning_rate.default_input, Tensor) and self.learning_rate.default_input.dim() == 1):
self.dynamic_lr = True
self.assignadd = P.AssignAdd()
self.gather = P.GatherV2()
......@@ -124,7 +127,7 @@ class LARS(Cell):
lr = self.gather(self.learning_rate, self.global_step, self.axis)
F.control_depend(lr, self.assignadd(self.global_step, 1))
else:
lr = F.scalar_to_array(self.learning_rate)
lr = self.learning_rate
if self.reciprocal_scale != 1.0:
gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
......
......@@ -46,7 +46,7 @@ class Net(nn.Cell):
return x
def test_lars():
def test_lars_multi_step_lr():
inputs = Tensor(np.ones([1, 64]).astype(np.float32))
label = Tensor(np.zeros([1, 10]).astype(np.float32))
net = Net()
......@@ -61,3 +61,20 @@ def test_lars():
net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, optimizer)
_executor.compile(train_network, inputs, label)
def test_lars_float_lr():
inputs = Tensor(np.ones([1, 64]).astype(np.float32))
label = Tensor(np.zeros([1, 10]).astype(np.float32))
net = Net()
net.set_train()
loss = nn.SoftmaxCrossEntropyWithLogits()
lr = 0.1
SGD = Momentum(net.trainable_params(), lr, 0.9)
optimizer = LARS(SGD, epsilon=1e-08, hyperpara=0.02, decay_filter=lambda x: 'bn' not in x.name,
lars_filter=lambda x: 'bn' not in x.name)
net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, optimizer)
_executor.compile(train_network, inputs, label)
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册