Commit c05203ef authored by: J jiangjinsheng

add validate for geswitch and merge

Parent a3110549
@@ -128,6 +128,7 @@ class GeSwitch(PrimitiveWithInfer):
         return (data, data)

     def infer_dtype(self, data_type, pred_type):
         validator.check_subclass("data", data_type, (mstype.tensor,) + mstype.number_type, self.name)
+        validator.check_tensor_type_same({"pred": pred_type}, [mstype.bool_], self.name)
         return (data_type, data_type)
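
The hunk above adds a check that GeSwitch's predicate input must be a bool tensor. As a rough sketch of what the new validation accepts and rejects (not part of the commit; it assumes MindSpore is installed, that `mstype.tensor_type(...)` objects can be passed to `infer_dtype` directly, and that the validator raises on a mismatch):

```python
from mindspore.ops import operations as P
from mindspore import dtype as mstype

geswitch = P.GeSwitch()

# A numeric data tensor with a bool predicate passes both checks.
geswitch.infer_dtype(mstype.tensor_type(mstype.float32),
                     mstype.tensor_type(mstype.bool_))

# A non-bool predicate (e.g. int32) is now rejected by the added
# check_tensor_type_same call rather than slipping through.
try:
    geswitch.infer_dtype(mstype.tensor_type(mstype.float32),
                         mstype.tensor_type(mstype.int32))
except (TypeError, ValueError) as err:
    print(err)
```
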
@@ -153,7 +154,20 @@ class Merge(PrimitiveWithInfer):
         raise NotImplementedError

     def infer_shape(self, inputs):
+        validator.check_integer('inputs len', len(inputs), 0, Rel.GT, self.name)
+        input_0 = inputs[0]
+        for i in range(1, len(inputs)):
+            if inputs[i] != input_0:
+                raise ValueError(f"For \'{self.name}\', the shape of {i}th input should be same as "
+                                 f"first input {input_0}, but got {inputs[i]}.")
         return (inputs[0], [1])

     def infer_dtype(self, inputs):
+        args = {}
+        for i, item in enumerate(inputs):
+            args['inputs[%d]' % i] = item
+        validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
         return (inputs[0], mstype.int32)
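
The Merge changes require a non-empty input tuple, identical shapes across the inputs, and tensor inputs whose dtype is bool or one of the numeric types. A minimal sketch of the new behaviour, under the same assumptions as the previous example (calling the infer methods directly, outside of graph compilation):

```python
from mindspore.ops import operations as P
from mindspore import dtype as mstype

merge = P.Merge()

# Matching shapes: the result pairs the data shape with the [1] shape
# of the branch index.
print(merge.infer_shape(([2, 3], [2, 3])))    # ([2, 3], [1])

# Mismatched shapes now raise ValueError with a descriptive message.
try:
    merge.infer_shape(([2, 3], [4, 5]))
except ValueError as err:
    print(err)

# Inputs must be bool or numeric tensors; the inferred dtype pairs the
# data dtype with an int32 branch index.
print(merge.infer_dtype((mstype.tensor_type(mstype.float32),
                         mstype.tensor_type(mstype.float32))))
```
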
@@ -2084,7 +2084,7 @@ class GetNext(PrimitiveWithInfer):
     Note:
         GetNext op needs to be associated with network and also depends on the init_dataset interface,
         it can't be used directly as a single op.
-        For details, please refer to `nn.cell_wrapper.DataWrapper` source code.
+        For details, please refer to `nn.DataWrapper` source code.

     Args:
         types (list[:class:`mindspore.dtype`]): The type of the outputs.
@@ -33,7 +33,7 @@ def cond_data_test(x_init, y_init):
             super(Net, self).__init__()
             self.square = P.Square()
             self.add = P.TensorAdd()
-            self.value = Tensor(np.full((1), 3, dtype=np.float32))
+            self.value = Tensor(3, dtype=ms.float32)
             self.switch = P.GeSwitch()
             self.merge = P.Merge()
             self.less = P.Less()
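
The test now builds the constant as a scalar Tensor instead of wrapping a one-element NumPy array. A small side-by-side, for illustration only (assuming `mindspore` and `numpy` are importable; the exact shape of the scalar form is my reading, not something the diff states):

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor

old_value = Tensor(np.full((1), 3, dtype=np.float32))  # shape (1,) float32 tensor
new_value = Tensor(3, dtype=ms.float32)                # scalar (0-d) float32 tensor
```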