提交 cc0add56 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1930 fix validator for ScatterNdUpdate

Merge pull request !1930 from jiangjinsheng/issue_doc
...@@ -2032,7 +2032,7 @@ class ScatterNd(PrimitiveWithInfer): ...@@ -2032,7 +2032,7 @@ class ScatterNd(PrimitiveWithInfer):
Creates an empty tensor, and set values by scattering the update tensor depending on indices. Creates an empty tensor, and set values by scattering the update tensor depending on indices.
Inputs: Inputs:
- **indices** (Tensor) - The index of scattering in the new tensor. - **indices** (Tensor) - The index of scattering in the new tensor. With int32 data type.
- **update** (Tensor) - The source Tensor to be scattered. - **update** (Tensor) - The source Tensor to be scattered.
- **shape** (tuple[int]) - Define the shape of the output tensor. Has the same type as indices. - **shape** (tuple[int]) - Define the shape of the output tensor. Has the same type as indices.
...@@ -2055,7 +2055,7 @@ class ScatterNd(PrimitiveWithInfer): ...@@ -2055,7 +2055,7 @@ class ScatterNd(PrimitiveWithInfer):
def __infer__(self, indices, update, shape): def __infer__(self, indices, update, shape):
shp = shape['value'] shp = shape['value']
validator.check_subclass("update_dtype", update['dtype'], mstype.tensor, self.name) validator.check_subclass("update_dtype", update['dtype'], mstype.tensor, self.name)
validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name) validator.check_tensor_type_same({"indices": indices['dtype']}, [mstype.int32], self.name)
validator.check_value_type("shape", shp, [tuple], self.name) validator.check_value_type("shape", shp, [tuple], self.name)
for i, x in enumerate(shp): for i, x in enumerate(shp):
validator.check_integer("shape[%d]" % i, x, 0, Rel.GT, self.name) validator.check_integer("shape[%d]" % i, x, 0, Rel.GT, self.name)
...@@ -2159,7 +2159,7 @@ class ScatterUpdate(PrimitiveWithInfer): ...@@ -2159,7 +2159,7 @@ class ScatterUpdate(PrimitiveWithInfer):
Inputs: Inputs:
- **input_x** (Parameter) - The target tensor, with data type of Parameter. - **input_x** (Parameter) - The target tensor, with data type of Parameter.
- **indices** (Tensor) - The index of input tensor. - **indices** (Tensor) - The index of input tensor. With int32 data type.
- **update** (Tensor) - The tensor to update the input tensor, has the same type as input, - **update** (Tensor) - The tensor to update the input tensor, has the same type as input,
and update.shape = indices.shape + input_x.shape[1:]. and update.shape = indices.shape + input_x.shape[1:].
...@@ -2167,9 +2167,11 @@ class ScatterUpdate(PrimitiveWithInfer): ...@@ -2167,9 +2167,11 @@ class ScatterUpdate(PrimitiveWithInfer):
Tensor, has the same shape and type as `input_x`. Tensor, has the same shape and type as `input_x`.
Examples: Examples:
>>> input_x = mindspore.Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)) >>> np_x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]])
>>> input_x = mindspore.Parameter(Tensor(np_x, mindspore.float32), name="x")
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32) >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
>>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32) >>> np_update = np.array([[[1.0, 2.2, 1.0], [2.0, 1.2, 1.0]], [[2.0, 2.2, 1.0], [3.0, 1.2, 1.0]]])
>>> update = Tensor(np_update, mindspore.float32)
>>> op = P.ScatterUpdate() >>> op = P.ScatterUpdate()
>>> output = op(input_x, indices, update) >>> output = op(input_x, indices, update)
""" """
...@@ -2181,6 +2183,7 @@ class ScatterUpdate(PrimitiveWithInfer): ...@@ -2181,6 +2183,7 @@ class ScatterUpdate(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, use_locking=True): def __init__(self, use_locking=True):
"""Init ScatterUpdate""" """Init ScatterUpdate"""
validator.check_value_type('use_locking', use_locking, [bool], self.name)
self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y']) self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
def infer_shape(self, x_shape, indices_shape, value_shape): def infer_shape(self, x_shape, indices_shape, value_shape):
...@@ -2189,7 +2192,7 @@ class ScatterUpdate(PrimitiveWithInfer): ...@@ -2189,7 +2192,7 @@ class ScatterUpdate(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype, indices_dtype, value_dtype): def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name) validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
args = {"x": x_dtype, "value": value_dtype} args = {"x": x_dtype, "value": value_dtype}
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
return x_dtype return x_dtype
...@@ -2206,14 +2209,15 @@ class ScatterNdUpdate(PrimitiveWithInfer): ...@@ -2206,14 +2209,15 @@ class ScatterNdUpdate(PrimitiveWithInfer):
Inputs: Inputs:
- **input_x** (Parameter) - The target tensor, with data type of Parameter. - **input_x** (Parameter) - The target tensor, with data type of Parameter.
- **indices** (Tensor) - The index of input tensor. - **indices** (Tensor) - The index of input tensor, with int32 data type.
- **update** (Tensor) - The tensor to add to the input tensor, has the same type as input. - **update** (Tensor) - The tensor to add to the input tensor, has the same type as input.
Outputs: Outputs:
Tensor, has the same shape and type as `input_x`. Tensor, has the same shape and type as `input_x`.
Examples: Examples:
>>> input_x = mindspore.Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)) >>> np_x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]])
>>> input_x = mindspore.Parameter(Tensor(np_x, mindspore.float32), name="x")
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32) >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
>>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32) >>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
>>> op = P.ScatterNdUpdate() >>> op = P.ScatterNdUpdate()
...@@ -2227,6 +2231,7 @@ class ScatterNdUpdate(PrimitiveWithInfer): ...@@ -2227,6 +2231,7 @@ class ScatterNdUpdate(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, use_locking=True): def __init__(self, use_locking=True):
"""Init ScatterNdUpdate""" """Init ScatterNdUpdate"""
validator.check_value_type('use_locking', use_locking, [bool], self.name)
self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y']) self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
def infer_shape(self, x_shape, indices_shape, value_shape): def infer_shape(self, x_shape, indices_shape, value_shape):
...@@ -2237,7 +2242,7 @@ class ScatterNdUpdate(PrimitiveWithInfer): ...@@ -2237,7 +2242,7 @@ class ScatterNdUpdate(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype, indices_dtype, value_dtype): def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name) validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
args = {"x": x_dtype, "value": value_dtype} args = {"x": x_dtype, "value": value_dtype}
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
return x_dtype return x_dtype
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册