Commit b812b18c authored by wangnan39@huawei.com

support update parameter for vm

Parent 7fbaf2f6
@@ -15,7 +15,6 @@
 """Parameter for cell."""
 from copy import copy, deepcopy
-import numpy as np
 from .initializer import initializer
 from .tensor import Tensor
 from .._checkparam import _check_str_by_regular
@@ -176,14 +175,15 @@ class Parameter:
         return res

     def set_parameter_data(self, data):
-        if isinstance(data, (Tensor, list, int, float,
-                             np.float16, np.float32, np.int32, np.int16, np.ndarray)) and not isinstance(data, bool):
-            if isinstance(data, Tensor):
-                # make a copy of Tensor to init the parameter
-                data = Tensor(data.asnumpy().copy())
-            self.default_input = data
-        else:
-            raise ValueError("Parameter data must be tensor or number.")
+        """Set `default_input` of current `Parameter`."""
+        if isinstance(data, bool):
+            raise ValueError('Parameter data can not be `bool`')
+        if isinstance(data, Tensor):
+            # make a copy of Tensor to init the parameter
+            data = Tensor(data.asnumpy().copy())
+        else:
+            data = Tensor(data)
+        self.default_input = data

 class ParameterTuple(tuple):
...
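With this change, set_parameter_data normalizes every non-Tensor input to a Tensor and rejects bool explicitly, so default_input always carries a single type. A minimal usage sketch, not part of the commit (parameter name and values are illustrative):

import numpy as np
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter

# A Parameter created from a Tensor keeps a private copy of the data.
weight = Parameter(Tensor(np.ones((2, 3), np.float32)), name="weight")

# A numpy array or a plain float is now wrapped into a Tensor internally,
# while a bool raises ValueError instead of slipping through the old isinstance check.
weight.set_parameter_data(np.zeros((2, 3), np.float32))
weight.set_parameter_data(0.5)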
@@ -101,17 +101,6 @@ def _run_opt_with_one_number(opt, lr, beta1_power, beta2_power, beta1, beta2, ep
     return success

-@adam_opt.register("Function", "Number", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
-                   "Tensor")
-def _run_opt_with_two_number(opt, lr, beta1_power, beta2_power, beta1, beta2, eps, gradient, params, moment1,
-                             moment2):
-    """Apply adam optimizer to the weight parameter using Tensor."""
-    success = True
-    success = F.depend(success, opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
-                                    eps, gradient))
-    return success

 class Adam(Optimizer):
     r"""
     Updates gradients by Adaptive Moment Estimation (Adam) algorithm.
@@ -183,7 +172,6 @@ class Adam(Optimizer):
         self.moment1 = self.parameters.clone(prefix="moment1", init='zeros')
         self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')
-        self.decay_tf = tuple(decay_filter(x) for x in self.parameters)
         self.hyper_map = C.HyperMap()
         self.opt = P.Adam(use_locking, use_nesterov)
...
@@ -23,7 +23,7 @@ from mindspore._checkparam import Rel
 from .optimizer import Optimizer, apply_decay, grad_scale

 ftrl_opt = C.MultitypeFuncGraph("ftrl_opt")

-@ftrl_opt.register("Function", "Number", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor")
+@ftrl_opt.register("Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor")
 def _tensor_run_opt(opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment):
     """Apply ftrl optimizer to the weight parameter."""
     success = True
...
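The re-registration matters because the Optimizer base class now hands the learning rate to these inner update functions as a Tensor (see the optimizer.py hunk below), and C.MultitypeFuncGraph selects an overload by the types of its arguments. A minimal sketch of that dispatch mechanism, not taken from this commit (the graph and function names are illustrative):

from mindspore.ops import composite as C

scale_by_lr = C.MultitypeFuncGraph("scale_by_lr")

@scale_by_lr.register("Tensor", "Tensor")
def _scale_tensor_lr(lr, grad):
    """Overload used when the learning rate arrives as a Tensor."""
    return lr * grad

@scale_by_lr.register("Number", "Tensor")
def _scale_number_lr(lr, grad):
    """Overload used when the learning rate is still a plain Python number."""
    return lr * grad

Inside an optimizer these graphs are applied through C.HyperMap, which picks the overload matching the actual argument types; registering the learning-rate slot as "Tensor" keeps the ftrl/momentum/sgd/rmsprop kernels reachable after the conversion.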
@@ -43,23 +43,6 @@ def _tensor_run_opt(lars, weight_decay, learning_rate, gradient, weight, decay_f
     return gradient

-@lars_opt.register("Function", "Number", "Number", "Tensor", "Tensor", "Bool", "Bool")
-def _tensor_run_opt_v2(lars, weight_decay, learning_rate, gradient, weight, decay_flag, lars_flag):
-    """Apply lars optimizer to the weight parameter."""
-    if lars_flag:
-        op_reduce = P.ReduceSum()
-        w_square_sum = op_reduce(F.square(weight))
-        grad_square_sum = op_reduce(F.square(gradient))
-        if decay_flag:
-            grad_t = lars(weight, gradient, w_square_sum, grad_square_sum, weight_decay, learning_rate)
-        else:
-            num_zero = 0.0
-            grad_t = lars(weight, gradient, w_square_sum, grad_square_sum, num_zero, learning_rate)
-        return grad_t
-    return gradient

 class LARS(Optimizer):
     """
     Implements the LARS algorithm with LARSUpdate Operator.
...
@@ -15,19 +15,13 @@
 """momentum"""
 from mindspore.ops import functional as F, composite as C, operations as P
 from mindspore.common.parameter import Parameter
+from mindspore.common.tensor import Tensor
+import mindspore.common.dtype as mstype
 from .optimizer import Optimizer

 momentum_opt = C.MultitypeFuncGraph("momentum_opt")

-@momentum_opt.register("Function", "Number", "Number", "Tensor", "Tensor", "Tensor")
-def _tensor_run_opt(opt, learning_rate, momentum, gradient, weight, moment):
-    """Apply momentum optimizer to the weight parameter."""
-    success = True
-    success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum))
-    return success

 @momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
 def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, moment):
     """Apply momentum optimizer to the weight parameter using Tensor."""
@@ -36,14 +30,6 @@ def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, moment):
     return success

-@momentum_opt.register("Function", "Tensor", "Number", "Tensor", "Tensor", "Tensor")
-def _tensor_run_opt_dyn(opt, learning_rate, momentum, gradient, weight, moment):
-    """Apply momentum optimizer to the weight parameter using dynamic learning rate."""
-    success = True
-    success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum))
-    return success

 class Momentum(Optimizer):
     """
     Implements the Momentum algorithm.
@@ -86,7 +72,7 @@ class Momentum(Optimizer):
         super(Momentum, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter)
         if isinstance(momentum, float) and momentum < 0.0:
             raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
-        self.momentum = Parameter(momentum, name="momentum")
+        self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
         self.params = self.parameters
         self.moments = self.params.clone(prefix="moments", init='zeros')
         self.hyper_map = C.HyperMap()
...
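From the caller's point of view nothing changes: momentum is still passed as a Python float and is wrapped into a float32 Tensor Parameter inside the constructor. A hedged usage sketch, assuming this version's mindspore.nn exports (the Dense layer and values are illustrative):

import mindspore.nn as nn

net = nn.Dense(3, 2)
# momentum=0.9 is now stored internally as Parameter(Tensor(0.9, mstype.float32), name="momentum")
opt = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)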
@@ -22,6 +22,7 @@ from mindspore.ops import functional as F, composite as C, operations as P
 from mindspore.nn.cell import Cell
 from mindspore.common.parameter import Parameter, ParameterTuple
 from mindspore.common.initializer import initializer
+import mindspore.common.dtype as mstype
 from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
 from mindspore.common.tensor import Tensor
@@ -64,6 +65,7 @@ class Optimizer(Cell):
             self.assignadd = None
             self.global_step = None
             validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
+            learning_rate = Tensor(learning_rate, mstype.float32)
         else:
             self.dynamic_lr = True
             self.gather = P.GatherV2()
...
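This is the single point where a static float learning rate becomes a float32 Tensor, which is why the Number-typed learning-rate overloads in adam, momentum, rmsprop and sgd could be deleted. A standalone restatement of the new static-lr branch, with only that branch shown (dynamic schedules still take the GatherV2 path):

from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype

learning_rate = 0.01
if isinstance(learning_rate, float):
    # same conversion the Optimizer base class now performs after range checking
    learning_rate = Tensor(learning_rate, mstype.float32)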
@@ -21,34 +21,17 @@ rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
 centered_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")

-@rmsprop_opt.register("Function", "Number", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor")
-def _rmsprop_opt(opt, learning_rate, decay, epsilon, momentum, weight, ms, mom, grad):
-    """Apply rmsprop optimizer to the weight parameter."""
-    success = True
-    success = F.depend(success, opt(weight, ms, mom, grad, learning_rate, decay, momentum, epsilon))
-    return success

 @rmsprop_opt.register("Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor")
-def _rmsprop_opt_dynamic_lr(opt, learning_rate, decay, epsilon, momentum, weight, ms, mom, grad):
+def _rmsprop_opt(opt, learning_rate, decay, epsilon, momentum, weight, ms, mom, grad):
     """Apply rmsprop optimizer to the weight parameter using dynamic learning rate."""
     success = True
     success = F.depend(success, opt(weight, ms, mom, grad, learning_rate, decay, momentum, epsilon))
     return success

-@centered_rmsprop_opt.register("Function", "Number", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor",
-                               "Tensor", "Tensor")
-def _centered_rmsprop_opt(opt, learning_rate, decay, epsilon, momentum, weight, mg, ms, mom, grad):
-    """Apply centered rmsprop optimizer to the weight parameter."""
-    success = True
-    success = F.depend(success, opt(weight, mg, ms, mom, grad, learning_rate, decay, momentum, epsilon))
-    return success

 @centered_rmsprop_opt.register("Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor",
                                "Tensor", "Tensor")
-def _centered_rmsprop_opt_dynamic_lr(opt, learning_rate, decay, epsilon, momentum, weight, mg, ms, mom, grad):
+def _centered_rmsprop_opt(opt, learning_rate, decay, epsilon, momentum, weight, mg, ms, mom, grad):
     """Apply centered rmsprop optimizer to the weight parameter using dynamic learning rate."""
     success = True
     success = F.depend(success, opt(weight, mg, ms, mom, grad, learning_rate, decay, momentum, epsilon))
...
@@ -15,20 +15,14 @@
 """sgd"""
 from mindspore.ops import functional as F, composite as C, operations as P
 from mindspore.common.parameter import Parameter
+from mindspore.common.tensor import Tensor
+import mindspore.common.dtype as mstype
 from mindspore._checkparam import Validator as validator
 from .optimizer import Optimizer

 sgd_opt = C.MultitypeFuncGraph("sgd_opt")

-@sgd_opt.register("Function", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor")
-def _tensor_run_opt(opt, learning_rate, momentum, gradient, weight, accum, stat):
-    """Apply sgd optimizer to the weight parameter."""
-    success = True
-    success = F.depend(success, opt(weight, gradient, learning_rate, accum, momentum, stat))
-    return success

 @sgd_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
 def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, accum, stat):
     """Apply sgd optimizer to the weight parameter using Tensor."""
@@ -37,14 +31,6 @@ def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, accum, s
     return success

-@sgd_opt.register("Function", "Tensor", "Number", "Tensor", "Tensor", "Tensor", "Tensor")
-def _tensor_run_opt_dyn(opt, learning_rate, momentum, gradient, weight, accum, stat):
-    """Apply sgd optimizer to the weight parameter using dynamic learning rate."""
-    success = True
-    success = F.depend(success, opt(weight, gradient, learning_rate, accum, momentum, stat))
-    return success

 class SGD(Optimizer):
     """
     Implements stochastic gradient descent (optionally with momentum).
@@ -105,7 +91,7 @@ class SGD(Optimizer):
         self.opt = P.SGD(dampening, weight_decay, nesterov)
-        self.momentum = Parameter(momentum, name="momentum")
+        self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
         self.accum = self.parameters.clone(prefix="accum", init='zeros')
         self.stat = self.parameters.clone(prefix="stat", init='ones')
         self.hyper_map = C.HyperMap()
...
@@ -13,17 +13,10 @@
 # limitations under the License.
 # ============================================================================
 """Cell_wrapper."""
-import copy
-import numpy as np
 from mindspore.parallel._utils import (_get_device_num, _get_mirror_mean,
                                        _get_parallel_mode)
 from mindspore.train.parallel_utils import ParallelMode
-from ...common import Tensor
 from ...common import dtype as mstype
-from ...common.initializer import initializer
 from ...common.parameter import Parameter, ParameterTuple
 from ...ops import composite as C
 from ...ops import functional as F
@@ -348,25 +341,8 @@ class ParameterUpdate(Cell):
         super(ParameterUpdate, self).__init__(auto_prefix=False)
         if not isinstance(param, Parameter):
             raise TypeError("`param` must be `Parameter`, but got {}".format(param))
-
-        default_input = param.default_input
-        if isinstance(default_input, Tensor):
-            shape = default_input.shape()
-            zero_dtype = default_input.dtype()
-        elif isinstance(default_input, float):
-            shape = [1]
-            zero_dtype = mstype.float32
-        elif isinstance(default_input, int):
-            shape = [1]
-            zero_dtype = mstype.int32
-        else:
-            raise TypeError("`default_input` in `param` must be Tensor, float or int, but got {}".format(default_input))
-        self._param = Parameter(initializer(copy.deepcopy(default_input), shape), param.name)
-        self._param.is_init = True
-        self._zero = Tensor(np.zeros(shape), zero_dtype)
+        self._param = param

     def construct(self, x):
-        zero = self._param + self._zero
-        F.control_depend(zero, F.assign(self._param, x))
-        return zero
+        self._param = x
+        return x
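The rewritten ParameterUpdate no longer builds a shadow parameter and a zero tensor; it wraps the existing Parameter and assigns the new value in construct, which is the "update parameter" path this commit enables on the VM. A hedged usage sketch, assuming the cell is importable from this version's wrapper module (the layer, shapes and values are illustrative):

import numpy as np
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.nn.wrap.cell_wrapper import ParameterUpdate  # assumed import path

net = nn.Dense(3, 2)
update = ParameterUpdate(net.weight)
# Overwrites net.weight with the supplied value; the shape must match the wrapped parameter.
update(Tensor(np.zeros((2, 3)).astype(np.float32)))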
@@ -36,7 +36,6 @@ tensor_to_ms_type = {"Int8": mstype.int8, "Int16": mstype.int16, "Int32": mstype
 tensor_to_np_type = {"Int8": np.int8, "Int16": np.int16, "Int32": np.int32, "Int64": np.int64,
                      "Float16": np.float16, "Float32": np.float32, "Float64": np.float64}

 def _special_process_par(par, new_par):
     """
     Processes the special condition.
@@ -182,8 +181,14 @@ def load_checkpoint(ckpoint_file_name, net=None):
             param_data = np.fromstring(data, np_type)
             dims = element.tensor.dims
-            if dims in [[0], [1]]:
-                parameter_dict[element.tag] = Parameter(param_data[0], name=element.tag)
+            if dims == [0]:
+                if 'Float' in data_type:
+                    param_data = float(param_data[0])
+                elif 'Int' in data_type:
+                    param_data = int(param_data[0])
+                parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag)
+            elif dims == [1]:
+                parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag)
             else:
                 param_dim = []
                 for dim in dims:
...
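The effect is that scalar checkpoint entries (dims == [0]) are restored as Tensor-backed Parameters instead of bare numpy scalars, matching how the optimizers above now store their hyper-parameters. A rough standalone sketch of the new scalar branch, with the checkpoint plumbing stubbed out (the values and the name "momentum" are illustrative):

import numpy as np
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
import mindspore.common.dtype as mstype

data_type = "Float32"
param_data = np.array([0.9], dtype=np.float32)   # stands in for np.fromstring(data, np_type)
dims = [0]

if dims == [0]:
    value = float(param_data[0]) if 'Float' in data_type else int(param_data[0])
    restored = Parameter(Tensor(value, mstype.float32), name="momentum")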
@@ -94,10 +94,6 @@ def test_parameter_update_float32():
 def test_parameter_update_error():
     """ test_parameter_update """
     input_np = np.array([1])
-    input_parameter = Parameter(np.array([1]), 'input_parameter')

     with pytest.raises(TypeError):
         ParameterUpdate(input_np)
-    with pytest.raises(TypeError):
-        ParameterUpdate(input_parameter)
@@ -52,86 +52,12 @@ def test_parameter_tuple_illegal():
 def test_parameter_init_illegal():
-    import numpy as np
-    dat = np.array([[1, 2, 3], [2, 3, 4]])
-    tensor = Tensor(dat)
-    data_none = None
     data_bool = True
     data_str = "nicai"
-    data_int = 3
-    data_list = [1, "2", True]
-    data_tuple = (1, 2, 3)
-    np_arr_int16 = np.ones([1,1], dtype=np.int16)
-    np_arr_int32 = np.ones([1,1], dtype=np.int32)
-    np_arr_float16 = np.ones([1,1], dtype=np.float16)
-    np_arr_float32 = np.ones([1,1], dtype=np.float32)
-    # with pytest.raises(ValueError):
-    #     Parameter(np_arr_int16[0][0], name=data_str)
-    Parameter(np_arr_int32[0], name=data_str)
-    Parameter(np_arr_float16[0], name=data_str)
-    Parameter(np_arr_float32[0], name=data_str)
-    Parameter(np_arr_float32, name=data_str)
-    Parameter(tensor, name=data_str)
-    Parameter(data_int, name=data_str)
-    Parameter(dat, name=data_str)
-    with pytest.raises(ValueError):
-        Parameter(data_none, name=data_str)
     with pytest.raises(ValueError):
         Parameter(data_bool, name=data_str)
-    with pytest.raises(ValueError):
-        Parameter(data_str, name=data_str)
-    Parameter(data_list, name=data_str)
-    with pytest.raises(ValueError):
-        Parameter(data_tuple, name=data_str)
-    Parameter(tensor, name=data_str)
-    Parameter(tensor, name=data_none)
-    with pytest.raises(ValueError):
-        Parameter(tensor, name=dat)
-    with pytest.raises(ValueError):
-        Parameter(tensor, name=tensor)
-    with pytest.raises(ValueError):
-        Parameter(tensor, name=data_bool)
-    with pytest.raises(ValueError):
-        Parameter(tensor, name=data_int)
-    with pytest.raises(ValueError):
-        Parameter(tensor, name=data_list)
-    with pytest.raises(ValueError):
-        Parameter(tensor, name=data_tuple)
-    Parameter(tensor, name=data_str, requires_grad=data_bool)
-    with pytest.raises(TypeError):
-        Parameter(tensor, name=data_str, requires_grad=data_none)
-    with pytest.raises(TypeError):
-        Parameter(tensor, name=data_str, requires_grad=dat)
-    with pytest.raises(TypeError):
-        Parameter(tensor, name=data_str, requires_grad=tensor)
-    with pytest.raises(TypeError):
-        Parameter(tensor, name=data_str, requires_grad=data_str)
-    with pytest.raises(TypeError):
-        Parameter(tensor, name=data_str, requires_grad=data_int)
-    with pytest.raises(TypeError):
-        Parameter(tensor, name=data_str, requires_grad=data_list)
-    with pytest.raises(TypeError):
-        Parameter(tensor, name=data_str, requires_grad=data_tuple)
-    Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_bool)
-    with pytest.raises(TypeError):
-        Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=dat)
-    with pytest.raises(TypeError):
-        Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=tensor)
-    with pytest.raises(TypeError):
-        Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_none)
-    with pytest.raises(TypeError):
-        Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_str)
-    with pytest.raises(TypeError):
-        Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_int)
-    with pytest.raises(TypeError):
-        Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_list)
-    with pytest.raises(TypeError):
-        Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_tuple)

 def test_check_str_by_regular():
...
@@ -31,7 +31,7 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \
 run_opt = C.MultitypeFuncGraph("run_opt")

-@run_opt.register("Function", "Int", "Number", "Number",
+@run_opt.register("Function", "Tensor", "Tensor", "Tensor",
                   "Tensor", "Tensor",
                   "Tensor")
 def tensor_run_opt(opt, iters, learning_rate, momentum,
...
@@ -51,7 +51,7 @@ class InlineMulADD(nn.Cell):
     def __init__(self):
         super(InlineMulADD, self).__init__()
         self.mul_add = MulAdd()
-        self.param = Parameter(2, 'param')
+        self.param = 2

     def construct(self, x, y):
         return self.mul_add(x, y) + x + self.param * y
...
@@ -377,8 +377,8 @@ def vm_impl_momentum(self):
         accumulation = accumulation.asnumpy()
         variable = variable.asnumpy()
         shape = accumulation.shape
-        learning_rate = np.full(shape, learning_rate)
-        momentum = np.full(shape, momentum)
+        learning_rate = np.full(shape, learning_rate.asnumpy())
+        momentum = np.full(shape, momentum.asnumpy())
         accumulation = accumulation * momentum + gradient
         if use_nesterov is True:
             variable -= gradient * learning_rate + accumulation * momentum * learning_rate
...