Commit f6d7d977 authored by: M mindspore-ci-bot, committed by: Gitee

!3281 Fix some API description of ops.

Merge pull request !3281 from liuxiao93/fix-api-bug
@@ -448,7 +448,7 @@ class Squeeze(PrimitiveWithInfer):
         ValueError: If the corresponding dimension of the specified axis does not equal to 1.
 
     Args:
-        axis (int): Specifies the dimension indexes of shape to be removed, which will remove
+        axis (Union[int, tuple(int)]): Specifies the dimension indexes of shape to be removed, which will remove
             all the dimensions that are equal to 1. If specified, it must be int32 or int64.
             Default: (), an empty tuple.
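The hunk above widens `axis` from a bare int to `Union[int, tuple(int)]`. A minimal NumPy sketch of that relaxed signature (a hypothetical stand-in, not MindSpore's implementation) shows the behavior for both argument forms:

```python
import numpy as np

def squeeze(x, axis=()):
    """Remove size-1 dimensions; `axis` may be an int, a tuple of ints, or ()."""
    if axis == ():  # default: drop every dimension that equals 1
        return np.squeeze(x)
    axes = (axis,) if isinstance(axis, int) else tuple(axis)
    for a in axes:
        if x.shape[a] != 1:
            # mirrors the documented ValueError for a non-unit dimension
            raise ValueError(f"dimension {a} of shape {x.shape} does not equal 1")
    return np.squeeze(x, axis=axes)
```

With `x` of shape `(3, 1, 4, 1)`, `squeeze(x, 1)` yields shape `(3, 4, 1)` while `squeeze(x, (1, 3))` yields `(3, 4)`, matching the "int or tuple(int)" description.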
@@ -1440,7 +1440,8 @@ class UnsortedSegmentProd(PrimitiveWithInfer):
 
     Inputs:
         - **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`.
           With float16, float32 or int32 data type.
-        - **segment_ids** (Tensor) - A `1-D` tensor whose shape is :math:`(x_1)`. Data type must be int32.
+        - **segment_ids** (Tensor) - A `1-D` tensor whose shape is :math:`(x_1)`, the value should be >= 0.
+          Data type must be int32.
         - **num_segments** (int) - The value specifies the number of distinct `segment_ids`,
           should be greater than 0.
......
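The hunk above adds the constraint that `segment_ids` values must be >= 0. A hedged NumPy sketch of the operator's semantics (an illustrative reimplementation, not MindSpore's kernel) makes the constraint concrete: each output row is the product of the input rows sharing a segment id, and empty segments default to the multiplicative identity 1:

```python
import numpy as np

def unsorted_segment_prod(x, segment_ids, num_segments):
    """Product of rows of `x` grouped by `segment_ids` (values must be >= 0)."""
    if segment_ids.ndim != 1 or segment_ids.min() < 0:
        raise ValueError("segment_ids must be a 1-D tensor with values >= 0")
    out = np.ones((num_segments,) + x.shape[1:], dtype=x.dtype)
    for i, seg in enumerate(segment_ids):
        out[seg] *= x[i]  # accumulate the product per segment
    return out
```

For example, `unsorted_segment_prod([1, 2, 3, 4], [0, 0, 1, 1], 3)` gives `[2, 12, 1]`: segment 2 has no members, so it stays at 1.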
@@ -3760,12 +3760,12 @@ class ApplyAdagradV2(PrimitiveWithInfer):
         update_slots (bool): If `True`, `accum` will be updated. Default: True.
 
     Inputs:
-        - **var** (Parameter) - Variable to be updated. With float32 or float16 data type.
+        - **var** (Parameter) - Variable to be updated. With float32 data type.
         - **accum** (Parameter) - Accum to be updated. The shape and dtype should be the same as `var`.
-          With float32 or float16 data type.
-        - **lr** (Union[Number, Tensor]) - The learning rate value, should be scalar. With float32 or float16 data type.
+          With float32 data type.
+        - **lr** (Union[Number, Tensor]) - The learning rate value, should be scalar. With float32 data type.
         - **grad** (Tensor) - A tensor for gradient. The shape and dtype should be the same as `var`.
-          With float32 or float16 data type.
+          With float32 data type.
 
     Outputs:
         Tuple of 2 Tensor, the updated parameters.
@@ -3817,9 +3817,8 @@ class ApplyAdagradV2(PrimitiveWithInfer):
     def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, grad_dtype):
         args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}
-        valid_types = [mstype.float16, mstype.float32]
-        validator.check_tensor_type_same(args, valid_types, self.name)
-        validator.check_scalar_or_tensor_type_same({'lr': lr_dtype}, valid_types, self.name)
+        validator.check_tensor_type_same(args, [mstype.float32], self.name)
+        validator.check_scalar_or_tensor_type_same({'lr': lr_dtype}, [mstype.float32], self.name)
         return var_dtype, accum_dtype
......
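The `infer_dtype` change above narrows the accepted dtypes for ApplyAdagradV2 from {float16, float32} to float32 only. The following is a simplified, self-contained sketch of that validation pattern; the helper below is a hypothetical stand-in for MindSpore's `validator.check_tensor_type_same`, using plain strings in place of `mstype` objects:

```python
FLOAT16, FLOAT32 = "float16", "float32"

def check_tensor_type_same(args, valid_types, prim_name):
    """Raise TypeError unless every dtype in `args` is one of `valid_types`."""
    for name, dtype in args.items():
        if dtype not in valid_types:
            raise TypeError(f"For '{prim_name}', the dtype of '{name}' must be "
                            f"in {valid_types}, but got {dtype}.")

# After this change, var/accum/grad (and lr) pass only as float32:
args = {"var": FLOAT32, "accum": FLOAT32, "grad": FLOAT32}
check_tensor_type_same(args, [FLOAT32], "ApplyAdagradV2")  # passes silently
```

A float16 input, previously accepted, now fails the same check with a TypeError naming the offending argument.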