From b7039a57975b74f10bba950edfbea8a937d87826 Mon Sep 17 00:00:00 2001 From: zhouyuanshen Date: Tue, 12 May 2020 00:05:34 +0800 Subject: [PATCH] fix check dtype in comm_ops.py --- mindspore/ops/operations/comm_ops.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/mindspore/ops/operations/comm_ops.py b/mindspore/ops/operations/comm_ops.py index 16d987a45..f5c005e81 100644 --- a/mindspore/ops/operations/comm_ops.py +++ b/mindspore/ops/operations/comm_ops.py @@ -39,6 +39,8 @@ class ReduceOp: PROD = "prod" +target_dtypes = (mstype.int8, mstype.int32, mstype.float16, mstype.float32) + class AllReduce(PrimitiveWithInfer): """ Reduces the tensor data across all devices in such a way that all devices will get the same final result. @@ -102,8 +104,7 @@ class AllReduce(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - if x_dtype.element_type() == mstype.bool_: - raise TypeError("AllReduce does not support 'Bool' as the dtype of input!") + validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name) return x_dtype @@ -161,8 +162,7 @@ class AllGather(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - if x_dtype.element_type() == mstype.bool_: - raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") + validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name) return x_dtype def __call__(self, tensor): @@ -219,8 +219,7 @@ class ReduceScatter(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - if x_dtype.element_type() == mstype.bool_: - raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") + validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name) return x_dtype def __call__(self, tensor): @@ -279,8 +278,7 @@ class Broadcast(PrimitiveWithInfer): if not isinstance(x_dtype, tuple): raise TypeError(f"{self.name}'s input should be a tuple!") for _ele in x_dtype: - if _ele.element_type() == mstype.bool_: - raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") + validator.check_tensor_type_same({'x': _ele}, target_dtypes, self.name) return x_dtype @@ -322,8 +320,7 @@ class _AlltoAll(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - if x_dtype.element_type() == mstype.bool_: - raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") + validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name) return x_dtype def __call__(self, tensor): -- GitLab