Commit 352c6faf authored by mindspore-ci-bot, committed by Gitee

!18 enable use of float type learning rate in lars optimizer

Merge pull request !18 from gziyan/master
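
This change lets the LARS wrapper accept a plain float learning rate from the wrapped optimizer, in addition to an iterable or 1-D Tensor schedule: the dynamic-lr check now inspects the learning-rate Parameter's default_input instead of rejecting every non-float value. A minimal usage sketch, adapted from the new test case (Net and the commented list-based schedule are assumptions drawn from the test file, not part of this diff):

    # Sketch based on the added test; Net is assumed to be the small network
    # defined in the test file.
    inputs = Tensor(np.ones([1, 64]).astype(np.float32))
    label = Tensor(np.zeros([1, 10]).astype(np.float32))
    net = Net()
    net.set_train()
    loss = nn.SoftmaxCrossEntropyWithLogits()

    # Newly supported: a plain float learning rate on the wrapped optimizer.
    sgd = Momentum(net.trainable_params(), 0.1, 0.9)
    # Still supported: a per-step schedule, e.g. a list of values (assumed form).
    # sgd = Momentum(net.trainable_params(), [0.1, 0.01, 0.001], 0.9)

    optimizer = LARS(sgd, epsilon=1e-08, hyperpara=0.02,
                     decay_filter=lambda x: 'bn' not in x.name,
                     lars_filter=lambda x: 'bn' not in x.name)
    train_network = TrainOneStepCell(WithLossCell(net, loss), optimizer)
    _executor.compile(train_network, inputs, label)
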
@@ -13,12 +13,14 @@
 # limitations under the License.
 # ============================================================================
 """lars optimizer"""
+from typing import Iterable
 from mindspore.common import dtype as mstype
+from mindspore.common import Tensor
 from mindspore.common.initializer import initializer
+from mindspore.common.parameter import Parameter
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
-from mindspore.common.parameter import Parameter
 from mindspore.nn.cell import Cell
 from .optimizer import grad_scale
@@ -111,7 +113,8 @@ class LARS(Cell):
         self.gather = None
         self.global_step = None
         self.axis = None
-        if not isinstance(self.learning_rate, float):
+        if isinstance(self.learning_rate.default_input, Iterable) or \
+                (isinstance(self.learning_rate.default_input, Tensor) and self.learning_rate.default_input.dim() == 1):
             self.dynamic_lr = True
             self.assignadd = P.AssignAdd()
             self.gather = P.GatherV2()
@@ -124,7 +127,7 @@ class LARS(Cell):
             lr = self.gather(self.learning_rate, self.global_step, self.axis)
             F.control_depend(lr, self.assignadd(self.global_step, 1))
         else:
-            lr = F.scalar_to_array(self.learning_rate)
+            lr = self.learning_rate
         if self.reciprocal_scale != 1.0:
             gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
...
@@ -46,7 +46,7 @@ class Net(nn.Cell):
         return x


-def test_lars():
+def test_lars_multi_step_lr():
     inputs = Tensor(np.ones([1, 64]).astype(np.float32))
     label = Tensor(np.zeros([1, 10]).astype(np.float32))
     net = Net()
@@ -61,3 +61,20 @@ def test_lars():
     net_with_loss = WithLossCell(net, loss)
     train_network = TrainOneStepCell(net_with_loss, optimizer)
     _executor.compile(train_network, inputs, label)
+
+
+def test_lars_float_lr():
+    inputs = Tensor(np.ones([1, 64]).astype(np.float32))
+    label = Tensor(np.zeros([1, 10]).astype(np.float32))
+    net = Net()
+    net.set_train()
+    loss = nn.SoftmaxCrossEntropyWithLogits()
+    lr = 0.1
+    SGD = Momentum(net.trainable_params(), lr, 0.9)
+    optimizer = LARS(SGD, epsilon=1e-08, hyperpara=0.02, decay_filter=lambda x: 'bn' not in x.name,
+                     lars_filter=lambda x: 'bn' not in x.name)
+    net_with_loss = WithLossCell(net, loss)
+    train_network = TrainOneStepCell(net_with_loss, optimizer)
+    _executor.compile(train_network, inputs, label)
\ No newline at end of file