提交 f85eb146 编写于 作者: W wjj19950828

fixed compare bug

上级 807b4385
......@@ -60,7 +60,7 @@ def compare(result, expect, delta=1e-10, rtol=1e-10):
result.shape, expect.shape)
assert result.dtype == expect.dtype, "result.dtype: {} != expect.dtype: {}".format(
result.dtype, expect.dtype)
elif isinstance(result, (list, tuple)) and len(result) > 1:
elif isinstance(result, (list, tuple)):
for i in range(len(result)):
if isinstance(result[i], (np.generic, np.ndarray)):
compare(result[i], expect[i], delta, rtol)
......@@ -69,6 +69,8 @@ def compare(result, expect, delta=1e-10, rtol=1e-10):
# deal with scalar tensor
elif len(expect) == 1:
compare(result, expect[0], delta, rtol)
else:
raise Exception("Compare diff wrong!!!!!!")
def randtool(dtype, low, high, shape):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册