Commit f182edfd authored by Z Ziyan

Fix LARS base class type

Parent 7c06d292
@@ -21,8 +21,7 @@ from mindspore.common.parameter import Parameter
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
-from mindspore.nn.cell import Cell
-from .optimizer import grad_scale
+from .optimizer import grad_scale, Optimizer
 
 lars_opt = C.MultitypeFuncGraph("lars_opt")
@@ -61,7 +60,7 @@ def _tensor_run_opt_v2(lars, weight_decay, learning_rate, gradient, weight, deca
     return gradient
 
 
-class LARS(Cell):
+class LARS(Optimizer):
     """
     Implements the LARS algorithm with LARSUpdate Operator.
@@ -98,7 +97,7 @@ class LARS(Cell):
     def __init__(self, optimizer, epsilon=1e-05, hyperpara=0.001, weight_decay=0.0, use_clip=False,
                  decay_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name,
                  lars_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name, loss_scale=1.0):
-        super(LARS, self).__init__(auto_prefix=False)
+        super(LARS, self).__init__(0.0, [Parameter(Tensor(0.0), name="trivial")])
        self.opt = optimizer
        self.parameters = optimizer.parameters
        self.learning_rate = optimizer.learning_rate
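With this change LARS derives from Optimizer instead of Cell, so the wrapper is accepted anywhere a MindSpore Optimizer is expected. Below is a minimal usage sketch, not part of the commit: it assumes LARS and Momentum are exported from mindspore.nn at this revision, and MyNet is a hypothetical network with trainable parameters.

import mindspore.nn as nn

net = MyNet()  # hypothetical network with trainable parameters
inner = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
# Keyword arguments match the __init__ shown in the hunk above; LARS rescales
# each layer's update using the trust ratio computed by the LARSUpdate operator.
opt = nn.LARS(inner, epsilon=1e-05, hyperpara=0.001, weight_decay=0.0)
# Because LARS is now an Optimizer subclass, training wrappers that expect an
# Optimizer (e.g. nn.TrainOneStepCell) accept it directly.
loss = nn.SoftmaxCrossEntropyWithLogits()
train_net = nn.TrainOneStepCell(nn.WithLossCell(net, loss), opt)

The hunk below touches the Optimizer base class itself, in the same package.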
@@ -57,7 +57,7 @@ class Optimizer(Cell):
     def __init__(self, learning_rate, parameters, weight_decay=0.0, loss_scale=1.0,
                  decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
-        super(Optimizer, self).__init__()
+        super(Optimizer, self).__init__(auto_prefix=False)
         if isinstance(learning_rate, float):
             self.dynamic_lr = False
             self.gather = None
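In Cell, auto_prefix controls whether the cell prepends its own name scope to the names of the Parameters it holds; moving auto_prefix=False into Optimizer (it was previously passed only by LARS, as removed above) keeps wrapped network parameters under their original names. A hedged sketch of that intent, not code from the commit, assuming Parameter and nn.Momentum behave as in this code base:

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter

w = Parameter(Tensor(np.ones((2, 2), np.float32)), name="conv1.weight")
opt = nn.Momentum([w], learning_rate=0.1, momentum=0.9)
# With auto_prefix=False in Optimizer.__init__, the wrapped parameter is
# expected to keep its original name rather than gaining an optimizer prefix,
# so optimizer state and checkpoints still line up with the network's names.
print(opt.parameters[0].name)  # expected to remain "conv1.weight"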