diff --git a/mindspore/nn/layer/math.py b/mindspore/nn/layer/math.py index 1ecb20056e2b73b0850fe2d1972c1e252060a973..cf18d1cf0f4454d474e8c1c0600b333b0591249d 100644 --- a/mindspore/nn/layer/math.py +++ b/mindspore/nn/layer/math.py @@ -132,23 +132,19 @@ class Range(Cell): class LinSpace(Cell): r""" - Generates values in an interval. And return the corresponding interpolation accroding to assist. + Generates values in an interval. Args: - - **start** (Union[int, float]) - The start of interval, With shape of 0-D. - - **stop** (Union[int, float]) - The end of interval, With shape of 0-D. - - **num** (int) - ticks number in the interval, the ticks include start and stop value. - With shape of 0-D. + start (Union[int, float]): The start of interval. With shape of 0-D. + stop (Union[int, float]): The end of interval. With shape of 0-D. + num (int): ticks number in the interval, the ticks include start and stop value. With shape of 0-D. Outputs: Tensor, With type same as `start`. The shape is 1-D with length of `num`. Examples: - >>> linspace = nn.LinSpace() - >>> start = Tensor(1, mindspore.float32) - >>> stop = Tensor(10, mindspore.float32) - >>> num = Tensor(5, mindspore.int32) - >>> output = linspace(start, stop, num) + >>> linspace = nn.LinSpace(1, 10, 5) + >>> output = linspace() [1, 3.25, 5.5, 7.75, 10] """ diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 7117e494e491f52b4d904a593ef24a86515ab825..3c7615ce6e983467a2149fdc990d4d1d1527a0f1 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -2711,7 +2711,7 @@ class ROIAlign(PrimitiveWithInfer): Examples: >>> input_tensor = Tensor(np.array([[[[1., 2.], [3., 4.]]]]), mindspore.float32) >>> rois = Tensor(np.array([[0, 0.2, 0.3, 0.2, 0.3]]), mindspore.float32) - >>> roi_align = P.ROIAlign(1, 1, 0.5, 2) + >>> roi_align = P.ROIAlign(2, 2, 0.5, 2) >>> output_tensor = roi_align(input_tensor, rois) >>> assert output_tensor == Tensor(np.array([[[[2.15]]]]), mindspore.float32) """ @@ -4980,4 +4980,5 @@ class LRN(PrimitiveWithInfer): return x_dtype def infer_shape(self, x_shape): + validator.check_integer("x_shape", len(x_shape), 4, Rel.EQ, self.name) return x_shape