提交 f85eb146 编写于 作者: W wjj19950828

fixed compare bug

上级 807b4385
...@@ -60,7 +60,7 @@ def compare(result, expect, delta=1e-10, rtol=1e-10): ...@@ -60,7 +60,7 @@ def compare(result, expect, delta=1e-10, rtol=1e-10):
result.shape, expect.shape) result.shape, expect.shape)
assert result.dtype == expect.dtype, "result.dtype: {} != expect.dtype: {}".format( assert result.dtype == expect.dtype, "result.dtype: {} != expect.dtype: {}".format(
result.dtype, expect.dtype) result.dtype, expect.dtype)
elif isinstance(result, (list, tuple)) and len(result) > 1: elif isinstance(result, (list, tuple)):
for i in range(len(result)): for i in range(len(result)):
if isinstance(result[i], (np.generic, np.ndarray)): if isinstance(result[i], (np.generic, np.ndarray)):
compare(result[i], expect[i], delta, rtol) compare(result[i], expect[i], delta, rtol)
...@@ -69,6 +69,8 @@ def compare(result, expect, delta=1e-10, rtol=1e-10): ...@@ -69,6 +69,8 @@ def compare(result, expect, delta=1e-10, rtol=1e-10):
# deal with scalar tensor # deal with scalar tensor
elif len(expect) == 1: elif len(expect) == 1:
compare(result, expect[0], delta, rtol) compare(result, expect[0], delta, rtol)
else:
raise Exception("Compare diff wrong!!!!!!")
def randtool(dtype, low, high, shape): def randtool(dtype, low, high, shape):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册