diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py
index bfbde78ffff1c62a17edf14f2e1c5da19bad78dc..6c6d14ed7abfabdedbd2fac11f0615c39ade361e 100755
--- a/mindspore/nn/optim/optimizer.py
+++ b/mindspore/nn/optim/optimizer.py
@@ -81,8 +81,22 @@ class Optimizer(Cell):
         else:
             raise TypeError("Learning rate should be float, Tensor or Iterable.")
 
+        if isinstance(weight_decay, int):
+            weight_decay = float(weight_decay)
+
+        if not isinstance(weight_decay, float):
+            raise TypeError("weight_decay should be a float number!")
+
+        if isinstance(loss_scale, int):
+            loss_scale = float(loss_scale)
+
+        if not isinstance(loss_scale, float):
+            raise TypeError("loss_scale should be a float number!")
+
         if loss_scale <= 0.0:
             raise ValueError("Loss scale should be greater than 0, but got {}".format(loss_scale))
+        self.loss_scale = loss_scale
+
 
         if weight_decay < 0.0:
             raise ValueError("Weight decay should be equal or greater than 0, but got {}".format(weight_decay))
diff --git a/mindspore/nn/optim/sgd.py b/mindspore/nn/optim/sgd.py
index a18adb81846e5c9ac41d1f7fc7d4e970c300d13b..983be4bf80b5923509001783e6ce6f7d544576dd 100755
--- a/mindspore/nn/optim/sgd.py
+++ b/mindspore/nn/optim/sgd.py
@@ -61,7 +61,8 @@ class SGD(Optimizer):
         dampening (float): A floating point value of dampening for momentum. Default: 0.
         weight_decay (float): Weight decay (L2 penalty). Default: 0.
         nesterov (bool): Enables the Nesterov momentum. Default: False.
-        loss_scale (float): A floating point value for the loss scale. Default: 1.0.
+        loss_scale (float): A floating point value for the loss scale, which should be larger
+            than 0.0. Default: 1.0.
 
     Inputs:
         - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
@@ -83,9 +84,18 @@ class SGD(Optimizer):
         super(SGD, self).__init__(learning_rate, params, weight_decay, loss_scale)
 
+        if not isinstance(momentum, float):
+            raise TypeError("momentum should be float number!")
+
         if isinstance(momentum, float) and momentum < 0.0:
             raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
 
+        if isinstance(dampening, int):
+            dampening = float(dampening)
+
+        if not isinstance(dampening, float):
+            raise TypeError("dampening should be float number")
+
         if dampening < 0.0:
             raise ValueError("dampening should be at least 0.0, but got dampening {}".format(dampening))
         self.dampening = dampening
 
diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py
index a7c3f5044048909580eacecba4b0d717b3fd2055..499f5d4f57a6ab54f08716ac11b652f7b940ab00 100644
--- a/mindspore/ops/operations/array_ops.py
+++ b/mindspore/ops/operations/array_ops.py
@@ -1008,6 +1008,7 @@ class Argmax(PrimitiveWithInfer):
 
     def infer_dtype(self, x_dtype):
         validator.check_subclass("input_x", x_dtype, mstype.tensor)
+        validator.check_typename('input_x', x_dtype, [mstype.float32, mstype.float16])
         return mstype.tensor_type(self.output_type)
 
 
@@ -1500,7 +1501,9 @@ class Slice(PrimitiveWithInfer):
         Tensor.
 
     Examples:
-        >>> data = Tensor(np.array([3,2,3]).astype(np.int32))
+        >>> data = Tensor(np.array([[[1, 1, 1], [2, 2, 2]],
+        ...                         [[3, 3, 3], [4, 4, 4]],
+        ...                         [[5, 5, 5], [6, 6, 6]]]).astype(np.int32))
         >>> type = P.Slice()(data, (1, 0, 0), (1, 1, 3))
     """
 
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index f5037882f17adf798e0e63292d7c7e988b2e131d..180e4cfe330cdde58a7dc803c5ea0b05ffdcd58e 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -1436,9 +1436,9 @@ class SGD(PrimitiveWithInfer):
         nesterov (bool): Enable Nesterov momentum. Default: False.
 
     Inputs:
-        - **parameters** (Tensor) - Parameters to be updated.
+        - **parameters** (Tensor) - Parameters to be updated. Their data type can be float16 or float32.
         - **gradient** (Tensor) - Gradients.
-        - **learning_rate** (Tensor) - Learning rate. e.g. Tensor(0.1, mindspore.float32).
+        - **learning_rate** (Tensor) - Learning rate. Must be float value. e.g. Tensor(0.1, mindspore.float32).
         - **accum** (Tensor) - Accum(velocity) to be updated.
         - **momentum** (Tensor) - Momentum. e.g. Tensor(0.1, mindspore.float32).
         - **stat** (Tensor) - States to be updated with the same shape as gradient.
@@ -1449,6 +1449,7 @@ class SGD(PrimitiveWithInfer):
 
     @prim_attr_register
     def __init__(self, dampening=0.0, weight_decay=0.0, nesterov=False):
+        validator.check_type("nesterov", nesterov, [bool])
         self.init_prim_io_names(inputs=['parameters', 'gradient', 'learning_rate', 'accum', 'momentum', 'stat'],
                                 outputs=['output'])