diff --git a/mindspore/ops/operations/control_ops.py b/mindspore/ops/operations/control_ops.py index 9743f9e3fd266a1058dbc525d6888cc67db2c163..3f8d50857a51ade22670c4969aaca0512ce96e03 100644 --- a/mindspore/ops/operations/control_ops.py +++ b/mindspore/ops/operations/control_ops.py @@ -128,6 +128,7 @@ class GeSwitch(PrimitiveWithInfer): return (data, data) def infer_dtype(self, data_type, pred_type): + validator.check_subclass("data", data_type, (mstype.tensor,) + mstype.number_type, self.name) validator.check_tensor_type_same({"pred": pred_type}, [mstype.bool_], self.name) return (data_type, data_type) @@ -153,7 +154,20 @@ class Merge(PrimitiveWithInfer): raise NotImplementedError def infer_shape(self, inputs): + validator.check_integer('inputs len', len(inputs), 0, Rel.GT, self.name) + input_0 = inputs[0] + + for i in range(1, len(inputs)): + if inputs[i] != input_0: + raise ValueError(f"For \'{self.name}\', the shape of {i}th input should be same as " + f"first input {input_0}, but got {inputs[i]}.") + return (inputs[0], [1]) def infer_dtype(self, inputs): + args = {} + for i, item in enumerate(inputs): + args['inputs[%d]' % i] = item + + validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) return (inputs[0], mstype.int32) diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index a80500c0e6414fc649eb3794e05914aea53b2200..0787720ecd35bf0b79fa74c3d249d3b02deb3d10 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -2084,7 +2084,7 @@ class GetNext(PrimitiveWithInfer): Note: GetNext op needs to be associated with network and also depends on the init_dataset interface, it can't be used directly as a single op. - For details, please refer to `nn.cell_wrapper.DataWrapper` source code. + For details, please refer to `nn.DataWrapper` source code. Args: types (list[:class:`mindspore.dtype`]): The type of the outputs. diff --git a/tests/ut/python/ops/test_control_ops.py b/tests/ut/python/ops/test_control_ops.py index b182396e4fedb20b90dc5cb72a7e4d8aa96e0fe3..6204bdbabb12e71c88c36d98126daf56a6650d55 100644 --- a/tests/ut/python/ops/test_control_ops.py +++ b/tests/ut/python/ops/test_control_ops.py @@ -33,7 +33,7 @@ def cond_data_test(x_init, y_init): super(Net, self).__init__() self.square = P.Square() self.add = P.TensorAdd() - self.value = Tensor(np.full((1), 3, dtype=np.float32)) + self.value = Tensor(3, dtype=ms.float32) self.switch = P.GeSwitch() self.merge = P.Merge() self.less = P.Less()