提交 bd57d123 编写于 作者: C c00425699

add_bool_type_check_in_comm_op

上级 aaa8d9ed
......@@ -162,6 +162,8 @@ class AllGather(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
if x_dtype == mstype.bool_:
raise TypeError("AllGather does not support 'Bool' as the dtype of input!")
return x_dtype
def __call__(self, tensor):
......@@ -219,6 +221,8 @@ class ReduceScatter(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
if x_dtype == mstype.bool_:
raise TypeError("ReduceScatter does not support 'Bool' as the dtype of input!")
return x_dtype
def __call__(self, tensor):
......@@ -276,6 +280,8 @@ class Broadcast(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
if x_dtype == mstype.bool_:
raise TypeError("Broadcast does not support 'Bool' as the dtype of input!")
return x_dtype
......@@ -318,6 +324,8 @@ class _AlltoAll(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
if x_dtype == mstype.bool_:
raise TypeError("AlltoAll does not support 'Bool' as the dtype of input!")
return x_dtype
def __call__(self, tensor):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册