Unverified commit 738480bb authored by Xiaoyu Zhang, committed by GitHub

Record the code of failing autotest cases (#5923)

* add log1p autotest

* add log autotest

* add autotest fake_code_gen

* refine code

* add code color

* add clear_note_fake_program

* code format

* fix clear list bug

* fix clear list bug

* code format

* delete useless file

* refine tensor method test

* fix comments
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Parent c2931ec5
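This PR teaches the autotest framework to record the PyTorch side of every generated call and to print that recorded "fake program" (in green) whenever the OneFlow and PyTorch results diverge. A minimal usage sketch, modeled directly on the tests added below (autotest, random_device, random_pytorch_tensor, and the dual-dispatch torch wrapper all come from automated_test_util):

import oneflow as flow
from automated_test_util import *

@flow.unittest.skip_unless_1n1d()
class TestLog1pModule(flow.unittest.TestCase):
    @autotest()
    def test_log1p_with_random_data(test_case):
        device = random_device()
        x = random_pytorch_tensor().to(device)
        # On a mismatch, the recorded PyTorch program for this call is
        # printed so the failing case can be reproduced by hand.
        return torch.log1p(x)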
......@@ -18,57 +18,19 @@ import unittest
from collections import OrderedDict
import numpy as np
from test_util import GenArgList
import oneflow as flow
import oneflow.unittest
def _test_log1p(test_case, shape, device):
input_arr = np.exp(np.random.randn(*shape)) - 1
np_out = np.log1p(input_arr)
x = flow.Tensor(
input_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True
)
of_out = flow.log1p(x)
test_case.assertTrue(
np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05, equal_nan=True)
)
of_out = of_out.sum()
of_out.backward()
np_out_grad = 1.0 / (1 + input_arr)
test_case.assertTrue(
np.allclose(x.grad.numpy(), np_out_grad, 0.0001, 0.0001, equal_nan=True)
)
def _test_log1p_tensor_function(test_case, shape, device):
input_arr = np.exp(np.random.randn(*shape)) - 1
np_out = np.log1p(input_arr)
x = flow.Tensor(
input_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True
)
of_out = x.log1p()
test_case.assertTrue(
np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05, equal_nan=True)
)
of_out = of_out.sum()
of_out.backward()
np_out_grad = 1.0 / (1 + input_arr)
test_case.assertTrue(
np.allclose(x.grad.numpy(), np_out_grad, 0.0001, 0.0001, equal_nan=True)
)
from automated_test_util import *
@flow.unittest.skip_unless_1n1d()
class TestLog1p(flow.unittest.TestCase):
def test_log1p(test_case):
arg_dict = OrderedDict()
arg_dict["test_fun"] = [_test_log1p, _test_log1p_tensor_function]
arg_dict["shape"] = [(2,), (2, 3), (2, 3, 4, 5)]
arg_dict["device"] = ["cpu", "cuda"]
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])
class TestLog1pModule(flow.unittest.TestCase):
@autotest()
def test_log1p_with_random_data(test_case):
device = random_device()
x = random_pytorch_tensor().to(device)
return torch.log1p(x)
if __name__ == "__main__":
......
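Each deleted helper above followed the same handwritten pattern: build a numpy reference, run the OneFlow op, and compare with np.allclose. The @autotest cases generate that comparison (against PyTorch rather than numpy) automatically. For contrast, the manual equivalent being replaced looks roughly like this (condensed from the deleted _test_log1p):

import numpy as np
import oneflow as flow

def manual_log1p_check(test_case, shape=(2, 3), device="cpu"):
    # Handwritten reference test of the kind this PR removes.
    arr = np.exp(np.random.randn(*shape)) - 1
    x = flow.Tensor(arr, dtype=flow.float32, device=flow.device(device))
    test_case.assertTrue(
        np.allclose(flow.log1p(x).numpy(), np.log1p(arr), 1e-05, 1e-05)
    )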
......@@ -127,51 +127,13 @@ class TestCos(flow.unittest.TestCase):
arg[0](test_case, *arg[1:])
def _test_log(test_case, shape, device):
np_arr = np.abs(np.random.randn(*shape))
input = flow.Tensor(np_arr, dtype=flow.float32, device=flow.device(device))
of_out = flow.log(input)
np_out = np.log(np_arr)
test_case.assertTrue(
np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05, equal_nan=True)
)
def _test_log_nan_value(test_case, shape, device):
arr = np.array([-0.7168, -0.5471, -0.8933, -1.4428, -0.119])
input = flow.Tensor(arr, dtype=flow.float32, device=flow.device(device))
np_out = np.full((5,), np.nan)
of_out = flow.log(input)
test_case.assertTrue(
np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05, equal_nan=True)
)
def _test_log_backward(test_case, shape, device):
x = flow.Tensor(
np.random.randn(*shape),
dtype=flow.float32,
device=flow.device(device),
requires_grad=True,
)
y = flow.log(x)
z = y.sum()
z.backward()
np_grad = 1 / x.numpy()
test_case.assertTrue(
np.allclose(x.grad.numpy(), np_grad, 1e-05, 1e-05, equal_nan=True)
)
@flow.unittest.skip_unless_1n1d()
class TestLog(flow.unittest.TestCase):
def test_log(test_case):
arg_dict = OrderedDict()
arg_dict["test_fun"] = [_test_log, _test_log_nan_value, _test_log_backward]
arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5)]
arg_dict["device"] = ["cpu", "cuda"]
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])
class TestLogModule(flow.unittest.TestCase):
@autotest()
def test_log_with_random_data(test_case):
device = random_device()
x = random_pytorch_tensor().to(device)
return torch.log(x)
def _test_std(test_case, shape, device):
......
......@@ -558,35 +558,43 @@ class TestTensor(flow.unittest.TestCase):
np_out = np.mean(input.numpy(), axis=0)
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))
@flow.unittest.skip_unless_1n1d()
def test_neg(test_case):
input = flow.Tensor(np.random.randn(2, 3), dtype=flow.float32)
of_out = -input
np_out = -input.numpy()
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))
@autotest()
def test_log_tensor_with_random_data(test_case):
device = random_device()
x = random_pytorch_tensor().to(device)
return x.log()
@flow.unittest.skip_unless_1n1d()
def test_negative(test_case):
input = flow.Tensor(np.random.randn(2, 3), dtype=flow.float32)
of_out = input.negative()
np_out = -input.numpy()
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))
@autotest()
def test_log1p_tensor_with_random_data(test_case):
device = random_device()
x = random_pytorch_tensor().to(device)
return x.log1p()
@flow.unittest.skip_unless_1n1d()
def test_greater(test_case):
input1 = flow.Tensor(np.array([1, 1, 4]).astype(np.float32), dtype=flow.float32)
input2 = flow.Tensor(np.array([1, 2, 3]).astype(np.float32), dtype=flow.float32)
of_out = input1.gt(input2)
np_out = np.greater(input1.numpy(), input2.numpy())
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))
@autotest()
def test_neg_tensor_with_random_data(test_case):
device = random_device()
x = random_pytorch_tensor().to(device)
return -x
@flow.unittest.skip_unless_1n1d()
def test_less(test_case):
input1 = flow.Tensor(np.random.randn(2, 6, 5, 3), dtype=flow.float32)
input2 = flow.Tensor(np.random.randn(2, 6, 5, 3), dtype=flow.float32)
of_out = input1.lt(input2)
np_out = np.less(input1.numpy(), input2.numpy())
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))
@autotest()
def test_negative_tensor_with_random_data(test_case):
device = random_device()
x = random_pytorch_tensor().to(device)
return x.negative()
@autotest(auto_backward=False)
def test_greater_tensor_with_random_data(test_case):
device = random_device()
x = random_pytorch_tensor(ndim=3, dim1=2, dim2=3).to(device)
y = random_pytorch_tensor(ndim=3, dim1=2, dim2=3).to(device)
return x.gt(y)
@autotest(auto_backward=False)
def test_less_tensor_with_random_data(test_case):
device = random_device()
x = random_pytorch_tensor(ndim=3, dim1=2, dim2=3).to(device)
y = random_pytorch_tensor(ndim=3, dim1=2, dim2=3).to(device)
return x.lt(y)
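The two comparison tests above pass auto_backward=False because gt and lt return boolean tensors, which carry no gradient, so the usual automatic backward check cannot run. A standalone PyTorch illustration (not part of the diff):

import torch

x = torch.randn(3, requires_grad=True)
mask = x.gt(torch.zeros(3))
print(mask.dtype, mask.requires_grad)  # torch.bool False
# mask.sum().backward() would raise: the boolean result does not require grad.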
@flow.unittest.skip_unless_1n1d()
def test_tensor_slice(test_case):
......@@ -650,28 +658,19 @@ class TestTensor(flow.unittest.TestCase):
np.allclose(tensor.numpy(), np.array(scalar), 0.0001, 0.0001)
)
@flow.unittest.skip_unless_1n1d()
def test_floor(test_case):
input = flow.Tensor(np.random.randn(4, 5, 6), dtype=flow.float32)
of_out = input.floor()
np_out = np.floor(input.numpy())
test_case.assertTrue(
np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05, equal_nan=True)
)
@autotest()
def test_tensor_floor_with_random_data(test_case):
device = random_device()
x = random_pytorch_tensor().to(device)
y = x.floor()
return y
@flow.unittest.skip_unless_1n1d()
def test_tensor_round(test_case):
shape = (2, 3)
np_input = np.random.randn(*shape)
of_input = flow.Tensor(np_input, dtype=flow.float32, requires_grad=True)
of_out = flow.round(of_input)
np_out = np.round(np_input)
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))
of_out = of_out.sum()
of_out.backward()
test_case.assertTrue(
np.allclose(of_input.grad.numpy(), np.zeros(shape), 0.0001, 0.0001)
)
@autotest()
def test_tensor_round_with_random_data(test_case):
device = random_device()
x = random_pytorch_tensor().to(device)
y = x.round()
return y
def _test_tensor_reshape(test_case):
x = np.array(
......@@ -682,7 +681,6 @@ class TestTensor(flow.unittest.TestCase):
np_shape = (2, 2, 2, 2)
test_case.assertTrue(np.array_equal(of_shape, np_shape))
@flow.unittest.skip_unless_1n1d()
@autotest()
def test_reshape_tensor_with_random_data(test_case):
device = random_device()
......@@ -690,7 +688,6 @@ class TestTensor(flow.unittest.TestCase):
y = x.reshape(-1,)
return y
@flow.unittest.skip_unless_1n1d()
@autotest()
def test_tensor_squeeze_with_random_data(test_case):
device = random_device()
......@@ -698,7 +695,6 @@ class TestTensor(flow.unittest.TestCase):
y = x.squeeze(random().to(int))
return y
@flow.unittest.skip_unless_1n1d()
@autotest()
def test_flow_unsqueeze_with_random_data(test_case):
device = random_device()
......@@ -706,7 +702,6 @@ class TestTensor(flow.unittest.TestCase):
y = x.unsqueeze(random(1, 3).to(int))
return y
@flow.unittest.skip_unless_1n1d()
@autotest()
def test_permute_flow_with_random_data(test_case):
device = random_device()
......@@ -719,7 +714,6 @@ class TestTensor(flow.unittest.TestCase):
)
return y
@flow.unittest.skip_unless_1n1d()
@autotest()
def test_transpose_tensor_with_random_data(test_case):
device = random_device()
......@@ -749,44 +743,6 @@ class TestTensor(flow.unittest.TestCase):
np_out = np.equal(arr1, arr2)
test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))
def _test_tensor_atan(test_case, shape, device):
np_input = np.random.randn(*shape)
of_input = flow.Tensor(
np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True
)
of_out = of_input.atan()
np_out = np.arctan(np_input)
test_case.assertTrue(
np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05, equal_nan=True)
)
of_out = of_out.sum()
of_out.backward()
np_out_grad = 1 / (1 + np_input ** 2)
test_case.assertTrue(
np.allclose(
of_input.grad.numpy(), np_out_grad, 1e-05, 1e-05, equal_nan=True
)
)
def _test_tensor_arctan(test_case, shape, device):
np_input = np.random.randn(*shape)
of_input = flow.Tensor(
np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True
)
of_out = of_input.arctan()
np_out = np.arctan(np_input)
test_case.assertTrue(
np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05, equal_nan=True)
)
of_out = of_out.sum()
of_out.backward()
np_out_grad = 1 / (1 + np_input ** 2)
test_case.assertTrue(
np.allclose(
of_input.grad.numpy(), np_out_grad, 1e-05, 1e-05, equal_nan=True
)
)
@flow.unittest.skip_unless_1n1d()
def test_tensor_detach(test_case):
shape = (2, 3, 4, 5)
......@@ -798,20 +754,6 @@ class TestTensor(flow.unittest.TestCase):
test_case.assertEqual(z.is_leaf, True)
test_case.assertEqual(z.grad_fn, None)
@flow.unittest.skip_unless_1n1d()
def test_tensor_clamp_(test_case):
input = flow.Tensor(np.random.randn(2, 6, 5, 3), dtype=flow.float32)
of_out = input.clamp(0.1, 0.5)
np_out = np.clip(input.numpy(), 0.1, 0.5)
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))
@flow.unittest.skip_unless_1n1d()
def test_tensor_clip_(test_case):
input = flow.Tensor(np.random.randn(2, 6, 5, 3), dtype=flow.float32)
of_out = input.clip(0.1, 0.5)
np_out = np.clip(input.numpy(), 0.1, 0.5)
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))
def _test_cast_tensor_function(test_case):
shape = (2, 3, 4, 5)
np_arr = np.random.randn(*shape).astype(np.float32)
......
......@@ -34,6 +34,11 @@ def torch_tensor_to_flow(x):
return flow.tensor(x.cpu().numpy())
# Global buffers recording the PyTorch side of each tested call, so that a
# reproducible "fake program" can be printed when a comparison fails.
note_pytorch_method_names = []
note_pytorch_args = []
note_pytorch_kwargs = []
class PyTorchDoesNotSupportError(Exception):
def __init__(self, exc):
self.exc = exc
......@@ -100,6 +105,22 @@ def get_args(callable, *args, **kwargs):
continue
pytorch_kwargs[key] = get_pytorch_value(value)
oneflow_kwargs[key] = get_oneflow_value(value)
    if not isinstance(callable, (torch_original.nn.Module)):
        # Record a printable copy of this call for the fake-program note.
        # Concrete tensors are skipped: unlike ints, floats, and strings,
        # they cannot be reproduced verbatim from a printout.
        new_pytorch_args = []
        new_pytorch_kwargs = {}
        for x in pytorch_args:
            if type(x) is torch_original.Tensor:
                continue
            new_pytorch_args.append(x)
        for key, value in pytorch_kwargs.items():
            if type(value) is torch_original.Tensor:
                continue
            new_pytorch_kwargs[key] = value
        note_pytorch_method_names.append(callable.__name__)
        note_pytorch_args.append(new_pytorch_args)
        note_pytorch_kwargs.append(new_pytorch_kwargs)
    return (pytorch_args, pytorch_kwargs, oneflow_args, oneflow_kwargs)
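With tensors filtered out, only reproducible literals reach the note buffers. A toy, self-contained illustration of the recording step (record_call and the buffer names here are shortened stand-ins, not the real API):

import torch

note_names, note_args, note_kwargs = [], [], []

def record_call(fn, args, kwargs):
    # Drop concrete tensors: they cannot be rebuilt from a printout.
    printable_args = [a for a in args if type(a) is not torch.Tensor]
    printable_kwargs = {k: v for k, v in kwargs.items() if type(v) is not torch.Tensor}
    note_names.append(fn.__name__)
    note_args.append(printable_args)
    note_kwargs.append(printable_kwargs)

record_call(torch.log1p, (torch.randn(3),), {})
print(note_names, note_args, note_kwargs)  # ['log1p'] [[]] [{}]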
......@@ -179,6 +200,61 @@ def GetDualObject(name, pytorch, oneflow):
return Cls(name, pytorch, oneflow)
def note_print_args(x, end=True):
    # Print one recorded positional arg in green. `end=True` means more items
    # follow, so a trailing ", " separator is emitted outside the quotes.
    if end:
        if isinstance(x, str):
            print(f"\033[32m'{x}', \033[0m", end="")
        else:
            print(f"\033[32m{x}, \033[0m", end="")
    else:
        if isinstance(x, str):
            print(f"\033[32m'{x}'\033[0m", end="")
        else:
            print(f"\033[32m{x}\033[0m", end="")
def note_print_kwargs(x, y, end=True):
    # Print one recorded keyword arg as key=value in green, with a trailing
    # ", " separator (outside the quotes) when more items follow.
    if end:
        if isinstance(y, str):
            print(f"\033[32m{x}='{y}', \033[0m", end="")
        else:
            print(f"\033[32m{x}={y}, \033[0m", end="")
    else:
        if isinstance(y, str):
            print(f"\033[32m{x}='{y}'\033[0m", end="")
        else:
            print(f"\033[32m{x}={y}\033[0m", end="")
def print_note_fake_program():
    code_len = len(note_pytorch_method_names)
    for i in range(code_len):
        note_pytorch_args_len = len(note_pytorch_args[i])
        note_pytorch_kwargs_len = len(note_pytorch_kwargs[i])
        print(f"\033[32m{note_pytorch_method_names[i]}\033[0m", end="")
        print("\033[32m(\033[0m", end="")
        if note_pytorch_args[i]:
            index = 0
            for x in note_pytorch_args[i]:
                index += 1
                # Keep the ", " separator after the last positional arg
                # when keyword args follow it.
                note_print_args(
                    x, index < note_pytorch_args_len or note_pytorch_kwargs_len > 0
                )
        if note_pytorch_kwargs[i]:
            index = 0
            for x in note_pytorch_kwargs[i].keys():
                index += 1
                note_print_kwargs(
                    x, note_pytorch_kwargs[i][x], index < note_pytorch_kwargs_len
                )
        print("\033[32m)\033[0m")
def clear_note_fake_program():
note_pytorch_method_names.clear()
note_pytorch_args.clear()
note_pytorch_kwargs.clear()
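Together these helpers give every autotest trial a clean slate: the buffers are cleared at the start of a trial, each wrapped call appends to them, and a failed comparison triggers the printout. A condensed, hypothetical sketch of that lifecycle (run_one_trial is an illustration, not a function in this file):

def run_one_trial(test_fn):
    clear_note_fake_program()      # empty the note buffers for this trial
    ok = test_fn()                 # wrapped calls append to the buffers
    if not ok:
        print_note_fake_program()  # dump the recorded PyTorch program
    return ok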
class DualObject:
def __init__(self, name, pytorch, oneflow):
self.name = name
......@@ -251,13 +327,16 @@ def check_tensor_equality(torch_tensor, flow_tensor, rtol=0.0001, atol=1e-05):
        f"Grads are not equal. PyTorch grad: \n{torch_grad}\n, OneFlow grad: \n{flow_grad}"
    )
return False
return np.allclose(
equality_res = np.allclose(
torch_tensor.detach().cpu().numpy(),
flow_tensor.numpy(),
rtol=rtol,
atol=atol,
equal_nan=True,
)
    if not equality_res:
        # A mismatch was found: dump the recorded PyTorch program so the
        # failing case can be reproduced by hand.
        print_note_fake_program()
    return equality_res
@equality_checker(type(None), type(None))
......@@ -275,6 +354,7 @@ def autotest(n=20, auto_backward=True, rtol=0.0001, atol=1e-05):
loop_limit = n * 20
loop = 0
while n > 0:
clear_note_fake_program()
if loop > loop_limit:
raise ValueError("autotest stuck in an endless loop!")
dual_modules_to_test.clear()
......
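The loop above runs until n trials succeed, with a hard cap of loop_limit = n * 20 attempts, and now calls clear_note_fake_program() at the top of each attempt so a later failure prints only its own recorded program. A sketch of the retry skeleton under those assumptions (the loop/n bookkeeping is inferred from the visible fragment):

def autotest_loop(run_trial, n=20):
    loop_limit = n * 20
    loop = 0
    while n > 0:
        clear_note_fake_program()  # forget notes from earlier attempts
        if loop > loop_limit:
            raise ValueError("autotest stuck in an endless loop!")
        loop += 1
        if run_trial():            # one generated-input comparison
            n -= 1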