提交 f4e8bca7 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!787 Fix dtype judge sentence in infer_dtype function of hcom operations

Merge pull request !787 from zhouyuanshen/r0.2
...@@ -45,7 +45,6 @@ class AllReduce(PrimitiveWithInfer): ...@@ -45,7 +45,6 @@ class AllReduce(PrimitiveWithInfer):
Note: Note:
The operation of AllReduce does not support "prod" currently. The operation of AllReduce does not support "prod" currently.
The input of AllReduce does not support dtype "Bool".
Tensor must have same shape and format in all processes participating in the collective. Tensor must have same shape and format in all processes participating in the collective.
Args: Args:
...@@ -103,7 +102,7 @@ class AllReduce(PrimitiveWithInfer): ...@@ -103,7 +102,7 @@ class AllReduce(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
if x_dtype == mstype.bool_: if x_dtype.element_type() == mstype.bool_:
raise TypeError("AllReduce does not support 'Bool' as the dtype of input!") raise TypeError("AllReduce does not support 'Bool' as the dtype of input!")
return x_dtype return x_dtype
...@@ -161,7 +160,7 @@ class AllGather(PrimitiveWithInfer): ...@@ -161,7 +160,7 @@ class AllGather(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
if x_dtype == mstype.bool_: if x_dtype.element_type() == mstype.bool_:
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
return x_dtype return x_dtype
...@@ -218,7 +217,7 @@ class ReduceScatter(PrimitiveWithInfer): ...@@ -218,7 +217,7 @@ class ReduceScatter(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
if x_dtype == mstype.bool_: if x_dtype.element_type() == mstype.bool_:
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
return x_dtype return x_dtype
...@@ -275,11 +274,13 @@ class Broadcast(PrimitiveWithInfer): ...@@ -275,11 +274,13 @@ class Broadcast(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
if x_dtype == mstype.bool_: if not isinstance(x_dtype, tuple):
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") raise TypeError(f"{self.name}'s input should be a tuple!")
for _ele in x_dtype:
if _ele.element_type() == mstype.bool_:
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
return x_dtype return x_dtype
class _AlltoAll(PrimitiveWithInfer): class _AlltoAll(PrimitiveWithInfer):
""" """
AlltoAll is a collective operation. AlltoAll is a collective operation.
...@@ -318,7 +319,7 @@ class _AlltoAll(PrimitiveWithInfer): ...@@ -318,7 +319,7 @@ class _AlltoAll(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
if x_dtype == mstype.bool_: if x_dtype.element_type() == mstype.bool_:
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
return x_dtype return x_dtype
......
...@@ -55,7 +55,7 @@ class BroadCastNet(nn.Cell): ...@@ -55,7 +55,7 @@ class BroadCastNet(nn.Cell):
self.broadcast = Broadcast(0) self.broadcast = Broadcast(0)
def construct(self, x): def construct(self, x):
x = self.broadcast((x)) x, = self.broadcast((x,))
x = self.dense(x) x = self.dense(x)
return x return x
......
...@@ -52,7 +52,7 @@ class CommonNet(nn.Cell): ...@@ -52,7 +52,7 @@ class CommonNet(nn.Cell):
def __init__(self): def __init__(self):
super(CommonNet, self).__init__() super(CommonNet, self).__init__()
self.weight = Parameter(Tensor(np.ones([256, 64]), dtype=ms.float32), name="mul_weight") self.weight = Parameter(Tensor(np.ones([256, 64]), dtype=ms.float32), name="mul_weight")
self.logicalnot = P.LogicalNot().set_strategy(((4,1),)) self.logicalnot = P.LogicalNot().set_strategy(((4,2),))
self.equal = P.Equal().set_strategy(((4,2),(4,2))) self.equal = P.Equal().set_strategy(((4,2),(4,2)))
def construct(self, x, label): def construct(self, x, label):
...@@ -78,4 +78,5 @@ def common_net(): ...@@ -78,4 +78,5 @@ def common_net():
def test_bool_grad(): def test_bool_grad():
common_net() common_net()
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册