提交 1c1dbba5 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5457 Fix logic in distribution utils common_dtype to support tensor inputs

Merge pull request !5457 from XunDeng/pp_issue_branch
......@@ -379,8 +379,12 @@ def common_dtype(arg_a, name_a, arg_b, name_b, hint_type):
if hasattr(arg_a, 'dtype') and hasattr(arg_b, 'dtype'):
if isinstance(arg_a, np.ndarray):
a_dtype = mstype.pytype_to_dtype(arg_a.dtype)
if isinstance(arg_a, np.ndarray):
else:
a_dtype = arg_a.dtype
if isinstance(arg_b, np.ndarray):
b_dtype = mstype.pytype_to_dtype(arg_b.dtype)
else:
b_dtype = arg_b.dtype
if a_dtype != b_dtype:
raise TypeError(f"{name_a} and {name_b} should have the same dtype.")
int_type = mstype.int_type + mstype.uint_type
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册