提交 2b511990 编写于 作者: V VectorSL

gpu floatstatus add type check

上级 f23bfe0d
......@@ -1712,6 +1712,7 @@ class FloatStatus(PrimitiveWithInfer):
return [1]
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32, mstype.float16], self.name)
return x_dtype
class NPUAllocFloatStatus(PrimitiveWithInfer):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册