提交 eb4e1a0d 编写于 作者: Z zhaozhenlong

ScatterAdd ScatterMax indices limited to int32

上级 bbc64b88
......@@ -2222,7 +2222,7 @@ class ScatterMax(PrimitiveWithInfer):
Inputs:
- **input_x** (Parameter) - The target parameter.
- **indices** (Tensor) - The index to do max operation whose data type should be int.
- **indices** (Tensor) - The index to do max operation whose data type should be mindspore.int32.
- **updates** (Tensor) - The tensor doing the maximum operation with `input_x`,
the data type is same as `input_x`, the shape is `indices_shape + x_shape[1:]`.
......@@ -2249,7 +2249,7 @@ class ScatterMax(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name)
validator.check_tensor_type_same({'indices': indices_dtype}, (mstype.int32,), self.name)
args = {"x": x_dtype, "updates": updates_dtype}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
return x_dtype
......@@ -2266,7 +2266,7 @@ class ScatterAdd(PrimitiveWithInfer):
Inputs:
- **input_x** (Parameter) - The target parameter.
- **indices** (Tensor) - The index to do add operation whose data type should be int.
- **indices** (Tensor) - The index to do add operation whose data type should be mindspore.int32.
- **updates** (Tensor) - The tensor doing the add operation with `input_x`,
the data type is same as `input_x`, the shape is `indices_shape + x_shape[1:]`.
......@@ -2292,7 +2292,7 @@ class ScatterAdd(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name)
validator.check_tensor_type_same({'indices': indices_dtype}, (mstype.int32,), self.name)
args = {'x': x_dtype, 'updates': updates_dtype}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
return x_dtype
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册