diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/rmsprop_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/rmsprop_impl.cu
index 31a4d97dff5cc3192462c79cab0deb81a6a272af..913aaa3b8d3026f6d7e3ab14b6bace113b386646 100644
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/rmsprop_impl.cu
+++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/rmsprop_impl.cu
@@ -19,17 +19,17 @@
 #include "device/gpu/cuda_common.h"
 
 template <typename T>
-__global__ void RmsPropKernel(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon, T* variable,
+__global__ void RmsPropKernel(const T* learning_rate, const T decay, const T momentum, const T epsilon, T* variable,
                               T* mean_square, T*moment, T* gradients, const size_t size) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) {
-    mean_square[i] = decay[0] * mean_square[i] + (1.0 - decay[0]) * gradients[i] * gradients[i];
-    moment[i] = momentum[0] * moment[i] + learning_rate[0] * rsqrt(mean_square[i] + epsilon[0]) * gradients[i];
+    mean_square[i] = decay * mean_square[i] + (1.0 - decay) * gradients[i] * gradients[i];
+    moment[i] = momentum * moment[i] + learning_rate[0] * rsqrt(mean_square[i] + epsilon) * gradients[i];
     variable[i] -= moment[i];
   }
 }
 
 template <typename T>
-void RmsProp(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon,
+void RmsProp(const T* learning_rate, const T decay, const T momentum, const T epsilon,
              T* variable, T* mean_square, T* moment, T* gradients, const size_t size, cudaStream_t cuda_stream) {
   RmsPropKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(learning_rate, decay, momentum, epsilon, variable,
                                                                    mean_square, moment, gradients, size);
@@ -58,7 +58,7 @@ void RmsPropCenter(const T* learning_rate, const T* decay, const T* momentum, co
 }
 
 template
-void RmsProp(const float* learning_rate, const float* decay, const float* momentum, const float* epsilon,
+void RmsProp(const float* learning_rate, const float decay, const float momentum, const float epsilon,
              float* variable, float* mean_square, float* moment, float* gradients, const size_t size,
              cudaStream_t cuda_stream);
 
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/rmsprop_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/rmsprop_impl.cuh
index 62d7e19ba2233cefabb4f5bea0e4c2ea9f2de8b7..b5802dbb67fb18fea2b414e49dafe944390f8762 100644
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/rmsprop_impl.cuh
+++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/rmsprop_impl.cuh
@@ -19,7 +19,7 @@
 #include "device/gpu/cuda_common.h"
 
 template <typename T>
-void RmsProp(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon, T* variable, T* mean_square,
+void RmsProp(const T* learning_rate, const T decay, const T momentum, const T epsilon, T* variable, T* mean_square,
              T* moment, T* gradients, const size_t size, cudaStream_t cuda_stream);
 
 template <typename T>
diff --git a/mindspore/ccsrc/kernel/gpu/nn/rmsprop_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/rmsprop_gpu_kernel.cc
index 707aa7764709198bb457a892e290e2bf62f18426..032e8eeec4b95ae6f14c389a3df60f966c72d770 100644
--- a/mindspore/ccsrc/kernel/gpu/nn/rmsprop_gpu_kernel.cc
+++ b/mindspore/ccsrc/kernel/gpu/nn/rmsprop_gpu_kernel.cc
@@ -25,9 +25,6 @@ MS_REG_GPU_KERNEL_ONE(ApplyRMSProp,
                         .AddInputAttr(kNumberTypeFloat32)
                         .AddInputAttr(kNumberTypeFloat32)
                         .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
                         .AddOutputAttr(kNumberTypeFloat32),
                       RMSPropGpuKernel, float)
 
diff --git a/mindspore/ccsrc/kernel/gpu/nn/rmsprop_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/rmsprop_gpu_kernel.h
index 7eaedfba520df4d4b9c4d8320e274cd2521664ef..9e148b690d02fbe8dcc02e33626cedc750a8b11d 100644
--- a/mindspore/ccsrc/kernel/gpu/nn/rmsprop_gpu_kernel.h
+++ b/mindspore/ccsrc/kernel/gpu/nn/rmsprop_gpu_kernel.h
@@ -27,7 +27,7 @@ namespace kernel {
 template <typename T>
 class RMSPropGpuKernel : public GpuKernel {
  public:
-  RMSPropGpuKernel() : size_(1), use_center_(false) {}
+  RMSPropGpuKernel() : size_(1), use_center_(false), decay_(0.0), momentum_(0.9), epsilon_(1e-12) {}
   ~RMSPropGpuKernel() override = default;
 
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -40,13 +40,10 @@
       T *variable = GetDeviceAddress<T>(inputs, 0);
       T *mean_square = GetDeviceAddress<T>(inputs, 1);
       T *moment = GetDeviceAddress<T>(inputs, 2);
-      T *gradients = GetDeviceAddress<T>(inputs, 3);
-      T *learning_rate = GetDeviceAddress<T>(inputs, 4);
-      T *decay = GetDeviceAddress<T>(inputs, 5);
-      T *momentum = GetDeviceAddress<T>(inputs, 6);
-      T *epsilon = GetDeviceAddress<T>(inputs, 7);
+      T *learning_rate = GetDeviceAddress<T>(inputs, 3);
+      T *gradients = GetDeviceAddress<T>(inputs, 4);
 
-      RmsProp(learning_rate, decay, momentum, epsilon, variable, mean_square, moment, gradients, size_,
+      RmsProp(learning_rate, decay_, momentum_, epsilon_, variable, mean_square, moment, gradients, size_,
               reinterpret_cast<cudaStream_t>(stream));
     } else {
       T *variable = GetDeviceAddress<T>(inputs, 0);
@@ -70,6 +67,11 @@
       use_center_ = true;
     }
+    if (node_name == "ApplyRMSProp") {
+      decay_ = GetAttr<float>(kernel_node, "rho");
+      momentum_ = GetAttr<float>(kernel_node, "momentum");
+      epsilon_ = GetAttr<float>(kernel_node, "epsilon");
+    }
     auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
     for (auto &dim : input_shape) {
       size_ *= dim;
     }
@@ -81,24 +83,33 @@
  protected:
   void InitSizeLists() override {
     size_t input_size = size_ * sizeof(T);
-    input_size_list_.push_back(input_size);
-    if (use_center_) {
+    if (!use_center_) {
+      input_size_list_.push_back(input_size);
       input_size_list_.push_back(input_size);
+      input_size_list_.push_back(input_size);
+      input_size_list_.push_back(sizeof(T));
+      input_size_list_.push_back(input_size);
+      output_size_list_.push_back(input_size);
+    } else {
+      input_size_list_.push_back(input_size);
+      input_size_list_.push_back(input_size);
+      input_size_list_.push_back(input_size);
+      input_size_list_.push_back(input_size);
+      input_size_list_.push_back(input_size);
+      input_size_list_.push_back(sizeof(T));
+      input_size_list_.push_back(sizeof(T));
+      input_size_list_.push_back(sizeof(T));
+      input_size_list_.push_back(sizeof(T));
+      output_size_list_.push_back(input_size);
     }
-
-    input_size_list_.push_back(input_size);
-    input_size_list_.push_back(input_size);
-    input_size_list_.push_back(input_size);
-    input_size_list_.push_back(sizeof(T));
-    input_size_list_.push_back(sizeof(T));
-    input_size_list_.push_back(sizeof(T));
-    input_size_list_.push_back(sizeof(T));
-    output_size_list_.push_back(0);
   }
 
  private:
   size_t size_;
   bool use_center_;
+  float decay_;
+  float momentum_;
+  float epsilon_;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py
index 77fda2162e8e67c1e474ec8a573109c657b790cf..a0b2e5bdb2aa9e4cb0076ae5f8742b6952509e5f 100644
--- a/mindspore/nn/layer/quant.py
+++ b/mindspore/nn/layer/quant.py
@@ -175,7 +175,7 @@ class FakeQuantWithMinMaxAscend(Cell):
         else:
             quant_fun = P.FakeQuantPerLayer
             ema_fun = P.FakeQuantMinMaxPerLayerUpdate
-
+
         self.fake_quant = quant_fun(num_bits=self.num_bits,
                                    ema=self.ema,
                                    ema_decay=self.ema_decay,
@@ -272,7 +272,7 @@ class FakeQuantWithMinMaxGPU(Cell):
                                                0, self.out_channels)]).astype(np.float32)
         self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
         self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
-
+
         if per_channel:
             quant_fun = partial(P.FakeQuantPerChannel, channel_axis=self.channel_axis)
         else:
diff --git a/mindspore/nn/optim/adam.py b/mindspore/nn/optim/adam.py
index ee5970cc7644854b82300c89f3ab1032434190eb..fa0257be75e66dae0688597ce1a277f2c439c7d2 100755
--- a/mindspore/nn/optim/adam.py
+++ b/mindspore/nn/optim/adam.py
@@ -175,8 +175,7 @@ class Adam(Optimizer):
             If True, updates the gradients using NAG.
             If False, updates the gradients without using NAG. Default: False.
         weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
-        loss_scale (float): A floating point value for the loss scale. Should be equal to or greater than 1. Default:
-            1.0.
+        loss_scale (float): A floating point value for the loss scale. Should be greater than 0. Default: 1.0.
 
     Inputs:
         - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
@@ -210,7 +209,7 @@
         validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
         validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name)
         validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name)
-        validator.check_number_range("loss_scale", loss_scale, 1.0, float("inf"), Rel.INC_LEFT, self.cls_name)
+        validator.check_number_range("loss_scale", loss_scale, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
 
         self.beta1 = Tensor(beta1, mstype.float32)
         self.beta2 = Tensor(beta2, mstype.float32)
diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py
index 694c4f1d74b011a7e4df5b7cea8d45871d024104..aa52c7c3e66e49678f01a06311d26dee85b1e02c 100644
--- a/mindspore/ops/operations/array_ops.py
+++ b/mindspore/ops/operations/array_ops.py
@@ -122,7 +122,8 @@ class SameTypeShape(PrimitiveWithInfer):
     Checks whether data type and shape of two tensors are the same.
 
     Raises:
-        TypeError or ValueError: If not the same.
+        TypeError - If data type not the same.
+        ValueError - If shape of two tensors not the same.
 
     Inputs:
         - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
@@ -1031,7 +1032,7 @@ class InvertPermutation(PrimitiveWithInfer):
        - **input_x** (Union(tuple[int], Tensor[int])) - The input tuple is constructed by multiple
          integers, i.e., :math:`(y_1, y_2, ..., y_S)` representing the indices.
          The values must include 0. There can be no duplicate values or negative values.
-          If the input is Tensor, it must be 1-d and the dtype is int.
+          If the input is Tensor, it must be 1-d and the dtype is int. Only constant value is allowed.
 
     Outputs:
@@ -1061,7 +1062,9 @@
         z = [x_value[i] for i in range(len(x_value))]
         z.sort()
 
-        validator.check(f'value length', len(x_value), f'unique value length', len(set(x_value)), Rel.EQ, self.name)
+        for i in range(1, len(z)):
+            if z[i-1] == z[i]:
+                raise ValueError(f"For {self.name}, {z[i]} is duplicated in the input.")
         validator.check(f'value min', min(x_value), '', 0, Rel.EQ, self.name)
         validator.check(f'value max', max(x_value), '', len(x_value)-1, Rel.EQ, self.name)
 
diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py
index ec61132eefb0fb94b456e277221fd776f63b649f..176e91b48b9b746ed593e2ee4bf2c4e650dc3826 100644
--- a/mindspore/ops/operations/math_ops.py
+++ b/mindspore/ops/operations/math_ops.py
@@ -258,6 +258,8 @@ class _Reduce(PrimitiveWithInfer):
         args = {'input_x': input_x['dtype']}
         validator.check_tensor_type_same(args, valid_dtype, self.name)
 
+        if axis_v is None:
+            raise ValueError(f"For {self.name}, axis must be const.")
         input_shp = _infer_shape_reduce(input_shp, axis_v, self.keep_dims, self.name)
         return {'shape': input_shp,
                 'dtype': input_x['dtype'],
@@ -445,8 +447,9 @@ class ReduceProd(_Reduce):
                      Default : False, don't keep these reduced dimensions.
 
     Inputs:
-        - **input_x** (Tensor[Number]) - The input tensor.
-        - **axis** (Union[int, tuple(int), list(int)]) - The dimensions to reduce. Default: (), reduce all dimensions.
+        - **input_x** (Tensor[Number]) - The input tensor.
+        - **axis** (Union[int, tuple(int), list(int)]) - The dimensions to reduce. Default: (), reduce all dimensions.
+          Only constant value is allowed.
 
     Outputs:
         Tensor, has the same dtype as the 'input_x'.
@@ -474,8 +477,9 @@ class CumProd(PrimitiveWithInfer):
         reverse (bool): If True, reverse the result along axis. Default: False
 
     Inputs:
-        - **input_x** (Tensor[Number]) - The input tensor.
-        - **axis** (int) - The dimensions to compute the cumulative product.
+        - **input_x** (Tensor[Number]) - The input tensor.
+        - **axis** (int) - The dimensions to compute the cumulative product.
+          Only constant value is allowed.
 
     Outputs:
         Tensor, has the same shape and dtype as the 'input_x'.
@@ -507,6 +511,10 @@
         validator.check_subclass("axis", axis_type, mstype.int_, cls_name)
         return x_type
 
+    def infer_value(self, x, axis):
+        if axis is None:
+            raise ValueError(f"For {self.name}, axis must be const.")
+
 
 class MatMul(PrimitiveWithInfer):
     """
@@ -669,6 +677,10 @@
                 'dtype': x['dtype'],
                 'value': None}
 
+    def infer_value(self, x, axis):
+        if axis is None:
+            raise ValueError(f"For {self.name}, axis must be const.")
+
 
 class AddN(PrimitiveWithInfer):
     """
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 39876cabcac676b1ac6fa7482b07c5fc668d0d6a..c5472297ccff382ff79c355ffe63c9ff40e87a6e 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -1707,9 +1707,9 @@ class ApplyRMSProp(PrimitiveWithInfer):
         - **moment** (Tensor) - Delta of `var`, must have the same type as `var`.
         - **learning_rate** (Union[Number, Tensor]) - Learning rate.
         - **grad** (Tensor) - Gradients, must have the same type as `var`.
-        - **decay** (float) - Decay rate.
-        - **momentum** (float) - Momentum.
-        - **epsilon** (float) - Ridge term.
+        - **decay** (float) - Decay rate. Only constant value is allowed.
+        - **momentum** (float) - Momentum. Only constant value is allowed.
+        - **epsilon** (float) - Ridge term. Only constant value is allowed.
 
     Outputs:
         Tensor, parameters to be update.
@@ -1759,6 +1759,13 @@
             return var_dtype, var_dtype, var_dtype
         return var_dtype
 
+    def infer_value(self, var, mean_square, moment, learning_rate, grad, decay, momentum, epsilon):
+        if decay is None or momentum is None or epsilon is None:
+            raise ValueError(f"For {self.name}, decay, momentum, epsilon must be const.")
+        if not self.is_ge and self.is_d:
+            return None, None, None
+        return None
+
 
 class ApplyCenteredRMSProp(PrimitiveWithInfer):
     """
diff --git a/mindspore/ops/operations/other_ops.py b/mindspore/ops/operations/other_ops.py
index 714cad681fddaaadd4303b1835fdbd7f97ff48c6..6205f6b275addcd0f3436c332ee9f9529aa28862 100644
--- a/mindspore/ops/operations/other_ops.py
+++ b/mindspore/ops/operations/other_ops.py
@@ -379,7 +379,7 @@ class ConfusionMatrix(PrimitiveWithInfer):
     Inputs:
         - **labels** (Tensor) - real labels, tensor of 1-D. the dtype must be non-negative Integer.
         - **predictions** (Tensor) - the labels from prediction, tensor of 1-D.
-          the shape same as `labels` and the dtype must be non-negative Integer.
+          the shape same as `labels` and the dtype must be non-negative Integer.
        - **weights** (Tensor) - tensor of 1-D. the shape same as `predictions`.
 
     Outputs:
diff --git a/tests/st/ops/test_rmsprop.py b/tests/st/ops/gpu/test_rmsprop.py
similarity index 80%
rename from tests/st/ops/test_rmsprop.py
rename to tests/st/ops/gpu/test_rmsprop.py
index d0b65d627f1a91e7f1f66eabab500a75a31435fc..aa27fcf2638da55536faa439159c21958868c43d 100644
--- a/tests/st/ops/test_rmsprop.py
+++ b/tests/st/ops/gpu/test_rmsprop.py
@@ -24,19 +24,25 @@ from mindspore.ops import operations as P
 context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
 
 
+class NetCenteredRMSProp(nn.Cell):
+    def __init__(self):
+        super(NetCenteredRMSProp, self).__init__()
+        self.rms_opt = P.ApplyCenteredRMSProp()
+
+    def construct(self, var, g, mg, rms, mom, lr, decay, momentum, epsilon):
+        return self.rms_opt(var, mg, rms, mom, g, lr, decay, momentum, epsilon)
+
+
 class NetRMSProp(nn.Cell):
-    def __init__(self, use_centered):
+    def __init__(self, decay, momentum, epsilon):
         super(NetRMSProp, self).__init__()
-        self.use_centered = use_centered
-        if use_centered:
-            self.rms_opt = P.ApplyCenteredRMSProp()
-        else:
-            self.rms_opt = P.ApplyRMSProp()
+        self.decay = decay
+        self.momentum = momentum
+        self.epsilon = epsilon
+        self.rms_opt = P.ApplyRMSProp()
 
-    def construct(self, var, g, mg, rms, mom, lr, decay, momentum, epsilon):
-        if self.use_centered:
-            return self.rms_opt(var, mg, rms, mom, g, lr, decay, momentum, epsilon)
-        return self.rms_opt(var, rms, mom, lr, g, decay, momentum, epsilon)
+    def construct(self, var, g, mg, rms, mom, lr):
+        return self.rms_opt(var, rms, mom, lr, g, self.decay, self.momentum, self.epsilon)
 
 
 def rmsprop_numpy(variable, gradients, mean_square, moment,
@@ -76,13 +82,16 @@
     if centered:
         rmspropcented_numpy(variable_np, gradients_np, mean_gradients_np, mean_square_np, moment_np,
                             learning_rate, decay, momentum, epsilon)
+        net = NetCenteredRMSProp()
+        _ = net(variable_ms, gradients_ms, mean_gradients_ms, mean_square_ms,
+                moment_ms, learning_rate, decay, momentum, epsilon)
+
     else:
         rmsprop_numpy(variable_np, gradients_np, mean_square_np, moment_np,
                       learning_rate, decay, momentum, epsilon)
-
-    net = NetRMSProp(centered)
-    _ = net(variable_ms, gradients_ms, mean_gradients_ms, mean_square_ms,
-            moment_ms, learning_rate, decay, momentum, epsilon)
+        net = NetRMSProp(decay, momentum, epsilon)
+        _ = net(variable_ms, gradients_ms, mean_gradients_ms, mean_square_ms,
+                moment_ms, learning_rate)
 
     error = np.ones(shape=variable_np.shape) * 10e-6
     diff = variable_ms.asnumpy() - variable_np
@@ -126,13 +135,15 @@
     if centered:
         rmspropcented_numpy(variable_np, gradients_np, mean_gradients_np, mean_square_np, moment_np,
                             learning_rate, decay, momentum, epsilon)
+        net = NetCenteredRMSProp()
+        _ = net(variable_ms, gradients_ms, mean_gradients_ms, mean_square_ms, moment_ms,
+                learning_rate, decay, momentum, epsilon)
     else:
         rmsprop_numpy(variable_np, gradients_np, mean_square_np, moment_np,
                       learning_rate, decay, momentum, epsilon)
-
-    net = NetRMSProp(centered)
-    _ = net(variable_ms, gradients_ms, mean_gradients_ms, mean_square_ms, moment_ms,
-            learning_rate, decay, momentum, epsilon)
+        net = NetRMSProp(decay, momentum, epsilon)
+        _ = net(variable_ms, gradients_ms, mean_gradients_ms, mean_square_ms, moment_ms,
+                learning_rate)
 
     error = np.ones(shape=variable_np.shape) * 10e-6
     diff = variable_ms.asnumpy() - variable_np