提交 67a2c5b7 编写于 作者: J jiangjinsheng

fix InvertPermutation

上级 2711a628
...@@ -986,11 +986,10 @@ class InvertPermutation(PrimitiveWithInfer): ...@@ -986,11 +986,10 @@ class InvertPermutation(PrimitiveWithInfer):
values can not be negative. values can not be negative.
Inputs: Inputs:
- **input_x** (Union(tuple[int], Tensor[int])) - The input tuple is constructed by multiple - **input_x** (Union(tuple[int]) - The input tuple is constructed by multiple
integers, i.e., :math:`(y_1, y_2, ..., y_S)` representing the indices. integers, i.e., :math:`(y_1, y_2, ..., y_S)` representing the indices.
The values must include 0. There can be no duplicate values or negative values. The values must include 0. There can be no duplicate values or negative values.
If the input is Tensor, it must be 1-d and the dtype is int. Only constant value is allowed. Only constant value is allowed.
Outputs: Outputs:
tuple[int]. the lenth is same as input. tuple[int]. the lenth is same as input.
...@@ -1014,9 +1013,7 @@ class InvertPermutation(PrimitiveWithInfer): ...@@ -1014,9 +1013,7 @@ class InvertPermutation(PrimitiveWithInfer):
raise ValueError(f'For \'{self.name}\' the input value must be const.') raise ValueError(f'For \'{self.name}\' the input value must be const.')
validator.check_value_type("shape", x_shp, [tuple, list], self.name) validator.check_value_type("shape", x_shp, [tuple, list], self.name)
if mstype.issubclass_(x['dtype'], mstype.tensor): if mstype.issubclass_(x['dtype'], mstype.tensor):
validator.check('x dimension', len(x_shp), '', 1, Rel.EQ, self.name) raise ValueError(f'For \'{self.name}\' the input value must be non-Tensor.')
validator.check_tensor_type_same({'x dtype': x['dtype']}, mstype.int_type, self.name)
x_value = [int(i) for i in x_value.asnumpy()]
z = [x_value[i] for i in range(len(x_value))] z = [x_value[i] for i in range(len(x_value))]
z.sort() z.sort()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册