提交 c05203ef 编写于 作者: J jiangjinsheng

add validate for geswitch and merge

上级 a3110549
...@@ -128,6 +128,7 @@ class GeSwitch(PrimitiveWithInfer): ...@@ -128,6 +128,7 @@ class GeSwitch(PrimitiveWithInfer):
return (data, data) return (data, data)
def infer_dtype(self, data_type, pred_type): def infer_dtype(self, data_type, pred_type):
validator.check_subclass("data", data_type, (mstype.tensor,) + mstype.number_type, self.name)
validator.check_tensor_type_same({"pred": pred_type}, [mstype.bool_], self.name) validator.check_tensor_type_same({"pred": pred_type}, [mstype.bool_], self.name)
return (data_type, data_type) return (data_type, data_type)
...@@ -153,7 +154,20 @@ class Merge(PrimitiveWithInfer): ...@@ -153,7 +154,20 @@ class Merge(PrimitiveWithInfer):
raise NotImplementedError raise NotImplementedError
def infer_shape(self, inputs): def infer_shape(self, inputs):
validator.check_integer('inputs len', len(inputs), 0, Rel.GT, self.name)
input_0 = inputs[0]
for i in range(1, len(inputs)):
if inputs[i] != input_0:
raise ValueError(f"For \'{self.name}\', the shape of {i}th input should be same as "
f"first input {input_0}, but got {inputs[i]}.")
return (inputs[0], [1]) return (inputs[0], [1])
def infer_dtype(self, inputs): def infer_dtype(self, inputs):
args = {}
for i, item in enumerate(inputs):
args['inputs[%d]' % i] = item
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
return (inputs[0], mstype.int32) return (inputs[0], mstype.int32)
...@@ -2084,7 +2084,7 @@ class GetNext(PrimitiveWithInfer): ...@@ -2084,7 +2084,7 @@ class GetNext(PrimitiveWithInfer):
Note: Note:
GetNext op needs to be associated with network and also depends on the init_dataset interface, GetNext op needs to be associated with network and also depends on the init_dataset interface,
it can't be used directly as a single op. it can't be used directly as a single op.
For details, please refer to `nn.cell_wrapper.DataWrapper` source code. For details, please refer to `nn.DataWrapper` source code.
Args: Args:
types (list[:class:`mindspore.dtype`]): The type of the outputs. types (list[:class:`mindspore.dtype`]): The type of the outputs.
......
...@@ -33,7 +33,7 @@ def cond_data_test(x_init, y_init): ...@@ -33,7 +33,7 @@ def cond_data_test(x_init, y_init):
super(Net, self).__init__() super(Net, self).__init__()
self.square = P.Square() self.square = P.Square()
self.add = P.TensorAdd() self.add = P.TensorAdd()
self.value = Tensor(np.full((1), 3, dtype=np.float32)) self.value = Tensor(3, dtype=ms.float32)
self.switch = P.GeSwitch() self.switch = P.GeSwitch()
self.merge = P.Merge() self.merge = P.Merge()
self.less = P.Less() self.less = P.Less()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册