Unverified · Commit ba6d677a authored by Xiaoyu Zhang, committed by GitHub

Add autotest (#5899)

* restructure logsoftmax and abs tests

* add hardtanh test

* refine batchnorm autotest

* add meshgrid autotest

* add pow autotest

* add stack autotest

* delete useless prelu code

* change Stack Module Test Api

* fix comments

* fix softmax bug

* fix sign bug

* fix sign

* auto format by CI
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Parent c534c5e9
......@@ -31,8 +31,6 @@ REGISTER_USER_OP("prelu")
})
.SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0);
const user_op::TensorDesc& alpha_tensor =
ctx->LogicalTensorDesc4InputArgNameAndIndex("alpha", 0);
ctx->NewBuilder()
.Split(user_op::OpArg("x", 0), 1)
.Split(user_op::OpArg("alpha", 0), 0)
......
......@@ -499,7 +499,7 @@ class Hardsigmoid(Module):
class Softmax(Module):
def __init__(self, dim: Optional[int] = None):
super().__init__()
-        self.axis = -1 if dim is None else dim
+        self.axis = 1 if dim is None else dim
def forward(self, x):
(need_transpose, permute) = _softmax_need_transpose(x, self.axis)
......
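With this change, dim=None defaults to axis 1, which lines up with PyTorch's deprecated dim=None behavior for the 4-d inputs the autotest below generates. For reference, a minimal numpy sketch of the stable softmax the module computes (same form as the numpy_softmax helper elsewhere in this diff):

    import numpy as np

    def numpy_softmax(x, axis):
        x = x - x.max(axis=axis, keepdims=True)  # subtract the max so exp cannot overflow
        y = np.exp(x)
        return y / y.sum(axis=axis, keepdims=True)

    x = np.random.randn(2, 3, 4, 5)
    y = numpy_softmax(x, axis=1)
    assert np.allclose(y.sum(axis=1), 1.0)  # slices along the chosen axis are distributions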
......@@ -46,7 +46,7 @@ class MeshGrid(Module):
return outputs
- def meshgrid_op(*inputs):
+ def meshgrid_op(*tensors):
"""The interface is consistent with PyTorch.
The documentation is referenced from:
https://pytorch.org/docs/stable/_modules/torch/functional.html#meshgrid
......@@ -83,7 +83,7 @@ def meshgrid_op(*inputs):
[4., 5., 6.],
[4., 5., 6.]], dtype=oneflow.float32)
"""
-    return MeshGrid()(inputs)
+    return MeshGrid()(tensors)
if __name__ == "__main__":
......
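For reference, a hedged usage sketch of the renamed parameter (based on the docstring above; assumes meshgrid is exported as flow.meshgrid):

    import numpy as np
    import oneflow as flow

    x = flow.Tensor(np.array([1.0, 2.0, 3.0]), dtype=flow.float32)
    y = flow.Tensor(np.array([4.0, 5.0, 6.0]), dtype=flow.float32)
    grid_x, grid_y = flow.meshgrid(x, y)  # "ij" indexing, like np.meshgrid(..., indexing="ij")
    # grid_y repeats [4., 5., 6.] along each row, matching the docstring output above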
......@@ -15,77 +15,18 @@ limitations under the License.
"""
import unittest
from collections import OrderedDict
import numpy as np
from test_util import GenArgList
import oneflow as flow
import oneflow.unittest
from automated_test_util import *
def _test_abs_forward(test_case, device):
input = flow.Tensor(np.random.randn(2, 3).astype(np.float32))
of_out = flow.abs(input)
np_out = np.abs(input.numpy())
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))
test_case.assertTrue(np.allclose(input.abs().numpy(), np_out, 1e-05, 1e-05))
def _test_abs_tensor_function_forward(test_case, device):
x = np.random.randn(2, 3).astype(np.float32)
input = flow.Tensor(x, dtype=flow.float32)
np_out = np.abs(x)
of_out = input.abs()
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))
def _test_abs_backward(test_case, device):
np_input = np.random.randn(2, 3).astype(np.float32)
input = flow.Tensor(np_input, dtype=flow.float32, requires_grad=True)
of_out = flow.abs(input).sum()
of_out.backward()
np_grad = np.where(np_input > 0, 1, -1)
test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05))
def _test_abs_tensor_function_backward(test_case, device):
np_input = np.random.randn(2, 3).astype(np.float32)
input = flow.Tensor(np_input, dtype=flow.float32, requires_grad=True)
of_out = input.abs().sum()
of_out.backward()
np_grad = np.where(np_input > 0, 1, -1)
test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-05, 1e-05))
@flow.unittest.skip_unless_1n1d()
class TestAbs(flow.unittest.TestCase):
def test_abs(test_case):
arg_dict = OrderedDict()
arg_dict["test_fun"] = [
_test_abs_forward,
_test_abs_tensor_function_forward,
_test_abs_backward,
_test_abs_tensor_function_backward,
]
arg_dict["device"] = ["cpu", "cuda"]
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])
def test_flow_abs_with_random_data(test_case):
for device in ["cpu", "cuda"]:
test_flow_against_pytorch(test_case, "abs", device=device)
def test_flow_tensor_abs_with_random_data(test_case):
for device in ["cpu", "cuda"]:
test_tensor_against_pytorch(test_case, "abs", device=device)
class TestAbsModule(flow.unittest.TestCase):
@autotest()
def test_abs_with_0shape_data(test_case):
device = random_device()
-        x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device)
+        x = random_pytorch_tensor().to(device)
y = torch.abs(x)
return y
......
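The @autotest decorator (from automated_test_util) replaces the manual checks deleted above: it builds the same random inputs for OneFlow and PyTorch, runs the test body under both, and compares outputs and, unless auto_backward is disabled, gradients within rtol/atol — a reading of the parameters that appear throughout this diff. The subgradient convention the old abs tests encoded, checked against a finite difference as a numpy sketch:

    import numpy as np

    x = np.random.randn(2, 3).astype(np.float32)
    h = 1e-3
    fd = (np.abs(x + h) - np.abs(x - h)) / (2 * h)  # finite difference of |x|
    an = np.where(x > 0, 1, -1)                     # gradient the deleted test expected
    mask = np.abs(x) > h                            # stay away from the kink at 0
    assert np.allclose(fd[mask], an[mask], atol=1e-3)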
......@@ -145,33 +145,8 @@ class TestGelu(flow.unittest.TestCase):
return y
def numpy_softmax(x, axis):
x = x - x.max(axis=axis, keepdims=True)
y = np.exp(x)
return y / y.sum(axis=axis, keepdims=True)
def numpy_logsoftmax(x, dim):
e_x = np.exp(x - np.max(x, axis=dim, keepdims=True))
return np.log(e_x / e_x.sum(axis=dim, keepdims=True))
def numpy_softplus(x, beta, threshold):
return np.where(
x * beta > threshold, x, 1.0 / beta * np.log(1.0 + np.exp(beta * x))
)
def numpy_mish_grad(x):
f = 1 + np.exp(x)
y_grad = (f * f - 1) / (f * f + 1) + x * (4 * f * (f - 1)) / (
(f * f + 1) * (f * f + 1)
)
return y_grad
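These helpers use the standard stability tricks: numpy_softmax subtracts the per-axis max before exponentiating, and numpy_logsoftmax is algebraically x - logsumexp(x). A quick numpy check of that identity:

    import numpy as np

    def numpy_logsoftmax(x, dim):
        e_x = np.exp(x - np.max(x, axis=dim, keepdims=True))
        return np.log(e_x / e_x.sum(axis=dim, keepdims=True))

    def logsumexp(x, axis):
        m = np.max(x, axis=axis, keepdims=True)
        return m + np.log(np.exp(x - m).sum(axis=axis, keepdims=True))

    x = np.random.randn(4, 7)
    assert np.allclose(numpy_logsoftmax(x, 1), x - logsumexp(x, 1))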
@flow.unittest.skip_unless_1n1d()
- class TestSigmoid(flow.unittest.TestCase):
+ class TestSigmoidModule(flow.unittest.TestCase):
@autotest()
def test_sigmoid_module_with_random_data(test_case):
m = torch.nn.Sigmoid()
......@@ -197,96 +172,17 @@ class TestSigmoid(flow.unittest.TestCase):
return y
def _test_softmax(test_case, device):
axis = 0
m = flow.nn.Softmax(dim=axis)
arr = np.random.randn(2, 3, 4, 5)
x = flow.Tensor(arr, device=flow.device(device))
y = m(x)
output = numpy_softmax(arr, axis)
test_case.assertTrue(np.allclose(y.numpy(), output, 1e-05, 1e-05))
def _test_softmax_dim_1(test_case, device):
axis = 1
m = flow.nn.Softmax(dim=axis)
arr = np.random.randn(9, 7, 8, 16)
x = flow.Tensor(arr, device=flow.device(device))
y = m(x)
output = numpy_softmax(arr, axis)
test_case.assertTrue(np.allclose(y.numpy(), output, 1e-05, 1e-05))
def _test_softmax_dim_2(test_case, device):
axis = 2
m = flow.nn.Softmax(dim=axis)
arr = np.random.randn(2, 5, 6, 3)
x = flow.Tensor(arr, device=flow.device(device))
y = m(x)
output = numpy_softmax(arr, axis)
test_case.assertTrue(np.allclose(y.numpy(), output, 1e-05, 1e-05))
def _test_softmax_dim_3(test_case, device):
axis = 3
m = flow.nn.Softmax(dim=axis)
arr = np.random.randn(1, 3, 4, 7)
x = flow.Tensor(arr, device=flow.device(device))
y = m(x)
output = numpy_softmax(arr, axis)
test_case.assertTrue(np.allclose(y.numpy(), output, 1e-05, 1e-05))
axis2 = -1
    m2 = flow.nn.Softmax(dim=axis2)
    y2 = m2(x)
    output2 = numpy_softmax(arr, axis2)
test_case.assertTrue(np.allclose(y2.numpy(), output2, 1e-05, 1e-05))
def _test_softmax_backward_normal(test_case, device):
x_grad = np.zeros((2, 3, 4, 5))
axis = 0
m = flow.nn.Softmax(dim=axis)
x = flow.Tensor(
np.random.randn(2, 3, 4, 5),
requires_grad=True,
device=flow.device(device),
dtype=flow.float64,
)
y = m(x).sum()
y.backward()
test_case.assertTrue(np.allclose(x.grad.numpy(), x_grad, 1e-05, 1e-05))
def _test_softmax_backward_1_dim(test_case, device):
a = flow.tensor(
[1, 2], dtype=flow.float64, device=flow.device(device), requires_grad=True
)
b = flow.tensor(
[3, 4], dtype=flow.float64, device=flow.device(device), requires_grad=True
)
c = a * b
m = flow.nn.Softmax(dim=None)
d = m(c)
d[0].backward()
a_grad = np.array([0.01994417, -0.0265922267])
test_case.assertTrue(np.allclose(a.grad.numpy(), a_grad, 1e-05, 1e-05))
@flow.unittest.skip_unless_1n1d()
class TestSoftmax(flow.unittest.TestCase):
def test_softmax(test_case):
arg_dict = OrderedDict()
arg_dict["fun"] = [
_test_softmax,
_test_softmax_dim_1,
_test_softmax_dim_2,
_test_softmax_dim_3,
_test_softmax_backward_normal,
_test_softmax_backward_1_dim,
]
arg_dict["device"] = ["cpu", "cuda"]
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])
@autotest()
def test_softmax_module_with_random_data(test_case):
m = torch.nn.Softmax(dim=random(low=1, high=4).to(int) | nothing())
m.train(random())
device = random_device()
m.to(device)
x = random_pytorch_tensor(ndim=4).to(device)
y = m(x)
return y
@flow.unittest.skip_unless_1n1d()
......@@ -302,148 +198,17 @@ class TestHardsigmoidModule(flow.unittest.TestCase):
return y
def _test_logsoftmax(test_case, device):
dim = 1
m = flow.nn.LogSoftmax(dim)
input_arr = np.random.randn(4, 7)
x = flow.Tensor(input_arr, device=flow.device(device))
y = m(x)
output = numpy_logsoftmax(input_arr, dim)
test_case.assertTrue(np.allclose(y.numpy(), output, 1e-05, 1e-05))
def _test_logsoftmax_dim_2(test_case, device):
dim = 2
m = flow.nn.LogSoftmax(dim)
input_arr = np.random.randn(3, 4, 5)
x = flow.Tensor(input_arr, device=flow.device(device))
y = m(x)
output = numpy_logsoftmax(input_arr, dim)
test_case.assertTrue(np.allclose(y.numpy(), output, 1e-05, 1e-05))
def _test_logsoftmax_dim_3(test_case, device):
dim = 3
m = flow.nn.LogSoftmax(dim)
input_arr = np.random.randn(8, 9, 7, 3)
x = flow.Tensor(input_arr, device=flow.device(device))
y = m(x)
output = numpy_logsoftmax(input_arr, dim)
test_case.assertTrue(np.allclose(y.numpy(), output, 1e-05, 1e-05))
def _test_logsoftmax_backward(test_case, device):
axis = 0
m = flow.nn.LogSoftmax(axis)
input_arr = np.array(
[
[
[
[2.0, 1.0, 9.0, 3.0, 4.0],
[1.0, 6.0, 7.0, 1.0, 4.0],
[4.0, 7.0, 5.0, 8.0, 1.0],
[9.0, 5.0, 7.0, 8.0, 5.0],
],
[
[1.0, 1.0, 5.0, 3.0, 5.0],
[3.0, 6.0, 3.0, 7.0, 8.0],
[8.0, 8.0, 1.0, 2.0, 6.0],
[3.0, 5.0, 6.0, 1.0, 1.0],
],
[
[8.0, 3.0, 6.0, 3.0, 7.0],
[8.0, 5.0, 1.0, 2.0, 7.0],
[3.0, 9.0, 4.0, 6.0, 5.0],
[5.0, 1.0, 2.0, 3.0, 6.0],
],
],
[
[
[3.0, 5.0, 3.0, 1.0, 7.0],
[5.0, 2.0, 6.0, 3.0, 5.0],
[5.0, 1.0, 8.0, 6.0, 9.0],
[9.0, 8.0, 4.0, 5.0, 1.0],
],
[
[7.0, 5.0, 7.0, 1.0, 6.0],
[3.0, 3.0, 6.0, 6.0, 7.0],
[9.0, 4.0, 1.0, 5.0, 7.0],
[7.0, 6.0, 9.0, 8.0, 6.0],
],
[
[6.0, 7.0, 5.0, 3.0, 9.0],
[4.0, 1.0, 2.0, 3.0, 2.0],
[4.0, 3.0, 8.0, 7.0, 8.0],
[1.0, 3.0, 8.0, 6.0, 2.0],
],
],
]
)
x = flow.Tensor(
input_arr, requires_grad=True, device=flow.device(device), dtype=flow.float64
)
x_grad = np.array(
[
[
[
[0.46211716, 0.96402758, -0.99505475, -0.76159416, 0.90514825],
[0.96402758, -0.96402758, -0.46211716, 0.76159416, 0.46211716],
[0.46211716, -0.99505475, 0.90514825, -0.76159416, 0.9993293],
[0.0, 0.90514825, -0.90514825, -0.90514825, -0.96402758],
],
[
[0.99505475, 0.96402758, 0.76159416, -0.76159416, 0.46211716],
[0.0, -0.90514825, 0.90514825, -0.46211716, -0.46211716],
[0.46211716, -0.96402758, 0.0, 0.90514825, 0.46211716],
[0.96402758, 0.46211716, 0.90514825, 0.9981779, 0.9866143],
],
[
[-0.76159416, 0.96402758, -0.46211716, 0.0, 0.76159416],
[-0.96402758, -0.96402758, 0.46211716, 0.46211716, -0.9866143],
[0.46211716, -0.99505475, 0.96402758, 0.46211716, 0.90514825],
[-0.96402758, 0.76159416, 0.99505475, 0.90514825, -0.96402758],
],
],
[
[
[-0.46211716, -0.96402758, 0.99505475, 0.76159416, -0.90514825],
[-0.96402758, 0.96402758, 0.46211716, -0.76159416, -0.46211716],
[-0.46211716, 0.99505475, -0.90514825, 0.76159416, -0.9993293],
[0.0, -0.90514825, 0.90514825, 0.90514825, 0.96402758],
],
[
[-0.99505475, -0.96402758, -0.76159416, 0.76159416, -0.46211716],
[0.0, 0.90514825, -0.90514825, 0.46211716, 0.46211716],
[-0.46211716, 0.96402758, 0.0, -0.90514825, -0.46211716],
[-0.96402758, -0.46211716, -0.90514825, -0.9981779, -0.9866143],
],
[
[0.76159416, -0.96402758, 0.46211716, 0.0, -0.76159416],
[0.96402758, 0.96402758, -0.46211716, -0.46211716, 0.9866143],
[-0.46211716, 0.99505475, -0.96402758, -0.46211716, -0.90514825],
[0.96402758, -0.76159416, -0.99505475, -0.90514825, 0.96402758],
],
],
]
)
y = m(x).sum()
y.backward()
test_case.assertTrue(np.allclose(x.grad.numpy(), x_grad, 1e-05, 1e-05))
@flow.unittest.skip_unless_1n1d()
class TestLogSoftmax(flow.unittest.TestCase):
def test_log_softmax(test_case):
arg_dict = OrderedDict()
arg_dict["fun"] = [
_test_logsoftmax,
_test_logsoftmax_dim_2,
_test_logsoftmax_dim_3,
_test_logsoftmax_backward,
]
arg_dict["device"] = ["cpu", "cuda"]
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])
class TestLogSoftmaxModule(flow.unittest.TestCase):
@autotest()
def test_logsoftmax_module_with_random_data(test_case):
m = torch.nn.LogSoftmax(dim=random(low=1, high=4).to(int) | nothing())
m.train(random())
device = random_device()
m.to(device)
x = random_pytorch_tensor(ndim=4).to(device)
y = m(x)
return y
@flow.unittest.skip_unless_1n1d()
......@@ -459,6 +224,12 @@ class TestLogSigmoidModule(flow.unittest.TestCase):
return y
def numpy_softplus(x, beta, threshold):
return np.where(
x * beta > threshold, x, 1.0 / beta * np.log(1.0 + np.exp(beta * x))
)
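The threshold branch is the usual overflow guard: once x * beta exceeds threshold, softplus(x) is numerically indistinguishable from x, so the linear branch is returned directly. A sketch of both regimes:

    import numpy as np

    def numpy_softplus(x, beta, threshold):
        return np.where(
            x * beta > threshold, x, 1.0 / beta * np.log(1.0 + np.exp(beta * x))
        )

    assert np.allclose(numpy_softplus(np.array([50.0]), 1.0, 20.0), 50.0)      # linear regime
    assert np.allclose(numpy_softplus(np.array([0.0]), 1.0, 20.0), np.log(2))  # softplus(0) = ln 2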
def _test_softplus(test_case, device):
m = flow.nn.Softplus()
arr = np.random.randn(2, 3, 4, 5)
......@@ -539,38 +310,20 @@ class TestHardswishModule(flow.unittest.TestCase):
return y
def _np_hardtanh_grad(x):
return np.where(x <= -2.0, 0.0, np.where(x >= 2.3, 0.0, 1.0))
def _test_hardtanh_impl(test_case, shape, device):
m = flow.nn.Hardtanh()
arr = np.random.randn(*shape)
np_out = np.maximum(-1, np.minimum(1, arr))
x = flow.Tensor(arr, device=flow.device(device))
of_out = m(x)
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))
m = flow.nn.Hardtanh(min_val=-2.0, max_val=2.3)
arr = np.random.randn(*shape)
np_out = np.maximum(-2.0, np.minimum(2.3, arr))
x = flow.Tensor(arr, device=flow.device(device), requires_grad=True)
of_out = m(x)
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))
of_out = of_out.sum()
of_out.backward()
test_case.assertTrue(
np.allclose(x.grad.numpy(), _np_hardtanh_grad(np_out), 1e-05, 1e-05)
)
@flow.unittest.skip_unless_1n1d()
class TestHardtanhModule(flow.unittest.TestCase):
def test_hardtanh(test_case):
arg_dict = OrderedDict()
arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 4, 5, 6)]
arg_dict["device"] = ["cpu", "cuda"]
for arg in GenArgList(arg_dict):
_test_hardtanh_impl(test_case, *arg)
@autotest()
def test_hardtanh_module_with_random_data(test_case):
m = torch.nn.Hardtanh(
min_val=random().to(float) | nothing(),
max_val=random().to(float) | nothing(),
)
m.train(random())
device = random_device()
m.to(device)
x = random_pytorch_tensor(ndim=4).to(device)
y = m(x)
return y
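Hardtanh clamps to [min_val, max_val], and its gradient is 1 strictly inside the interval and 0 outside, which is what the deleted _np_hardtanh_grad encoded for min_val=-2.0, max_val=2.3. Minimal numpy sketch:

    import numpy as np

    x = np.random.randn(2, 4, 5)
    out = np.clip(x, -2.0, 2.3)                        # forward: max(-2.0, min(2.3, x))
    grad = np.where((x > -2.0) & (x < 2.3), 1.0, 0.0)  # backward mask
    assert np.allclose(out, np.maximum(-2.0, np.minimum(2.3, x)))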
@flow.unittest.skip_unless_1n1d()
......
......@@ -25,469 +25,8 @@ import oneflow as flow
import oneflow.unittest
def _test_batchnorm1d_2d_input(test_case, device):
input_arr = np.array(
[
[0.1438, 1.1229, -0.048, -1.6834, -0.8262],
[0.5836, 0.135, -0.886, -1.7878, 1.0592],
[0.7252, -1.1488, -0.0274, 1.4051, 0.1018],
[-0.3595, -0.1801, 0.1146, -1.5712, -1.9291],
],
dtype=np.float32,
)
output_arr = np.array(
[
[-0.3056, 1.4066, 0.4151, -0.5783, -0.3864],
[0.7326, 0.1884, -1.71, -0.6563, 1.317],
[1.0668, -1.3949, 0.4674, 1.7292, 0.4521],
[-1.4938, -0.2002, 0.8275, -0.4945, -1.3827],
],
dtype=np.float32,
)
m = flow.nn.BatchNorm1d(num_features=5, eps=1e-05, momentum=0.1).to(
device=flow.device(device)
)
x = flow.Tensor(input_arr, device=flow.device(device))
y = m(x)
test_case.assertTrue(np.allclose(y.numpy(), output_arr, rtol=0.0001, atol=0.0001))
def _test_batchnorm1d_3d_input(test_case, device):
input_arr = np.array(
[
[
[-0.1091, 2.0041, 0.885, -0.0412],
[-1.2055, 0.7442, 2.33, 1.2411],
[-1.2466, 0.3667, 1.2267, 0.3043],
],
[
[-0.2484, -1.1407, 0.3352, 0.6687],
[-0.2975, -0.0227, -0.2302, -0.3762],
[-0.7759, -0.6789, 1.1444, 1.8077],
],
],
dtype=np.float32,
)
output_arr = np.array(
[
[
[-0.464, 1.9673, 0.6798, -0.3859],
[-1.4207, 0.4529, 1.9767, 0.9303],
[-1.4831, 0.096, 0.9379, 0.035],
],
[
[-0.6243, -1.651, 0.0471, 0.4309],
[-0.5481, -0.284, -0.4834, -0.6237],
[-1.0224, -0.9274, 0.8573, 1.5066],
],
],
dtype=np.float32,
)
m = flow.nn.BatchNorm1d(num_features=3, eps=1e-05, momentum=0.1).to(
device=flow.device(device)
)
x = flow.Tensor(input_arr, device=flow.device(device))
y = m(x)
test_case.assertTrue(np.allclose(y.numpy(), output_arr, rtol=0.0001, atol=0.0001))
def _test_batchnorm2d(test_case, device):
input_arr = np.array(
[
[
[
[-0.8791, 0.2553, 0.7403, -0.2859],
[0.8006, -1.7701, -0.9617, 0.1705],
[0.2842, 1.7825, 0.3365, -0.8525],
],
[
[0.7332, -0.0737, 0.7245, -0.6551],
[1.4461, -0.1827, 0.9737, -2.1571],
[0.4657, 0.7244, 0.3378, 0.1775],
],
],
[
[
[1.8896, 1.8686, 0.1896, 0.9817],
[-0.0671, 1.5569, 1.1449, 0.0086],
[-0.9468, -0.0124, 1.3227, -0.6567],
],
[
[-0.8472, 1.3012, -1.1065, 0.9348],
[1.0346, 1.5703, 0.2419, -0.7048],
[0.6957, -0.4523, -0.8819, 1.0164],
],
],
],
dtype=np.float32,
)
output = np.array(
[
[
[
[-1.1868, -0.0328, 0.4606, -0.5833],
[0.522, -2.0933, -1.2709, -0.119],
[-0.0034, 1.5209, 0.0498, -1.1598],
],
[
[0.5601, -0.3231, 0.5505, -0.9595],
[1.3404, -0.4424, 0.8233, -2.6035],
[0.2673, 0.5504, 0.1273, -0.0482],
],
],
[
[
[1.6299, 1.6085, -0.0996, 0.7062],
[-0.3608, 1.2914, 0.8723, -0.2837],
[-1.2557, -0.3051, 1.0531, -0.9606],
],
[
[-1.1698, 1.1818, -1.4536, 0.7807],
[0.89, 1.4763, 0.0223, -1.0139],
[0.519, -0.7375, -1.2078, 0.87],
],
],
],
dtype=np.float32,
)
m = flow.nn.BatchNorm2d(num_features=2, eps=1e-05, momentum=0.1).to(
device=flow.device(device)
)
x = flow.Tensor(input_arr, device=flow.device(device), dtype=flow.float32)
y = m(x)
test_case.assertTrue(np.allclose(y.numpy(), output, 0.0001, 0.0001))
def _test_batchnorm2d_track_running_stats(test_case, device):
input_arr = np.array(
[
[
[
[-0.8791, 0.2553, 0.7403, -0.2859],
[0.8006, -1.7701, -0.9617, 0.1705],
[0.2842, 1.7825, 0.3365, -0.8525],
],
[
[0.7332, -0.0737, 0.7245, -0.6551],
[1.4461, -0.1827, 0.9737, -2.1571],
[0.4657, 0.7244, 0.3378, 0.1775],
],
],
[
[
[1.8896, 1.8686, 0.1896, 0.9817],
[-0.0671, 1.5569, 1.1449, 0.0086],
[-0.9468, -0.0124, 1.3227, -0.6567],
],
[
[-0.8472, 1.3012, -1.1065, 0.9348],
[1.0346, 1.5703, 0.2419, -0.7048],
[0.6957, -0.4523, -0.8819, 1.0164],
],
],
],
dtype=np.float32,
)
output = np.array(
[
[
[
[-1.1868, -0.0328, 0.4606, -0.5833],
[0.522, -2.0933, -1.2709, -0.119],
[-0.0034, 1.5209, 0.0498, -1.1598],
],
[
[0.5601, -0.3231, 0.5505, -0.9595],
[1.3404, -0.4424, 0.8233, -2.6035],
[0.2673, 0.5504, 0.1273, -0.0482],
],
],
[
[
[1.6299, 1.6085, -0.0996, 0.7062],
[-0.3608, 1.2914, 0.8723, -0.2837],
[-1.2557, -0.3051, 1.0531, -0.9606],
],
[
[-1.1698, 1.1818, -1.4536, 0.7807],
[0.89, 1.4763, 0.0223, -1.0139],
[0.519, -0.7375, -1.2078, 0.87],
],
],
],
dtype=np.float32,
)
m = flow.nn.BatchNorm2d(
num_features=2, eps=1e-05, momentum=0.1, track_running_stats=False
).to(device=flow.device(device))
x = flow.Tensor(input_arr, device=flow.device(device))
y = m(x)
test_case.assertTrue(np.allclose(y.numpy(), output, 0.0001, 0.0001))
def _test_batchnorm2d_4d_input(test_case, device):
input_arr = np.array(
[
[
[
[-0.8791, 0.2553, 0.7403, -0.2859],
[0.8006, -1.7701, -0.9617, 0.1705],
[0.2842, 1.7825, 0.3365, -0.8525],
],
[
[0.7332, -0.0737, 0.7245, -0.6551],
[1.4461, -0.1827, 0.9737, -2.1571],
[0.4657, 0.7244, 0.3378, 0.1775],
],
],
[
[
[1.8896, 1.8686, 0.1896, 0.9817],
[-0.0671, 1.5569, 1.1449, 0.0086],
[-0.9468, -0.0124, 1.3227, -0.6567],
],
[
[-0.8472, 1.3012, -1.1065, 0.9348],
[1.0346, 1.5703, 0.2419, -0.7048],
[0.6957, -0.4523, -0.8819, 1.0164],
],
],
],
dtype=np.float32,
)
output = np.array(
[
[
[
[-1.1868, -0.0328, 0.4606, -0.5833],
[0.522, -2.0933, -1.2709, -0.119],
[-0.0034, 1.5209, 0.0498, -1.1598],
],
[
[0.5601, -0.3231, 0.5505, -0.9595],
[1.3404, -0.4424, 0.8233, -2.6035],
[0.2673, 0.5504, 0.1273, -0.0482],
],
],
[
[
[1.6299, 1.6085, -0.0996, 0.7062],
[-0.3608, 1.2914, 0.8723, -0.2837],
[-1.2557, -0.3051, 1.0531, -0.9606],
],
[
[-1.1698, 1.1818, -1.4536, 0.7807],
[0.89, 1.4763, 0.0223, -1.0139],
[0.519, -0.7375, -1.2078, 0.87],
],
],
],
dtype=np.float32,
)
m = flow.nn.BatchNorm2d(num_features=2, eps=1e-05, momentum=0.1).to(
device=flow.device(device)
)
x = flow.Tensor(input_arr, device=flow.device(device))
y = m(x)
test_case.assertTrue(np.allclose(y.numpy(), output, 0.0001, 0.0001))
def test_batchnorm2d_infer(test_case, device):
input_arr = np.array(
[
[
[
[-0.8791, 0.2553, 0.7403, -0.2859],
[0.8006, -1.7701, -0.9617, 0.1705],
[0.2842, 1.7825, 0.3365, -0.8525],
],
[
[0.7332, -0.0737, 0.7245, -0.6551],
[1.4461, -0.1827, 0.9737, -2.1571],
[0.4657, 0.7244, 0.3378, 0.1775],
],
],
[
[
[1.8896, 1.8686, 0.1896, 0.9817],
[-0.0671, 1.5569, 1.1449, 0.0086],
[-0.9468, -0.0124, 1.3227, -0.6567],
],
[
[-0.8472, 1.3012, -1.1065, 0.9348],
[1.0346, 1.5703, 0.2419, -0.7048],
[0.6957, -0.4523, -0.8819, 1.0164],
],
],
],
dtype=np.float32,
)
output_arr = np.array(
[
[
[
[-0.8790956, 0.2552987, 0.7402963, -0.28589857],
[0.800596, -1.7700912, -0.9616952, 0.17049915],
[0.28419858, 1.7824911, 0.3364983, -0.85249573],
],
[
[0.7331963, -0.07369963, 0.72449636, -0.6550967],
[1.4460927, -0.18269908, 0.9736951, -2.1570892],
[0.46569768, 0.72439635, 0.3377983, 0.1774991],
],
],
[
[
[1.8895906, 1.8685907, 0.18959905, 0.9816951],
[-0.06709967, 1.5568923, 1.1448942, 0.00859996],
[-0.9467952, -0.01239994, 1.3226933, -0.65669674],
],
[
[-0.84719574, 1.3011935, -1.1064945, 0.9347953],
[1.0345949, 1.5702921, 0.24189879, -0.7047965],
[0.69569653, -0.45229775, -0.8818956, 1.0163949],
],
],
],
dtype=np.float32,
)
m = flow.nn.BatchNorm2d(num_features=2, eps=1e-05, momentum=0.1).to(
device=flow.device(device)
)
m.eval()
x = flow.Tensor(input_arr, device=flow.device(device))
y = m(x)
test_case.assertTrue(np.allclose(y.numpy(), output_arr, 0.0001, 0.0001))
def test_batchnorm2d_infer_4d_input(test_case, device):
input_arr = np.array(
[
[
[
[-0.8791, 0.2553, 0.7403, -0.2859],
[0.8006, -1.7701, -0.9617, 0.1705],
[0.2842, 1.7825, 0.3365, -0.8525],
],
[
[0.7332, -0.0737, 0.7245, -0.6551],
[1.4461, -0.1827, 0.9737, -2.1571],
[0.4657, 0.7244, 0.3378, 0.1775],
],
],
[
[
[1.8896, 1.8686, 0.1896, 0.9817],
[-0.0671, 1.5569, 1.1449, 0.0086],
[-0.9468, -0.0124, 1.3227, -0.6567],
],
[
[-0.8472, 1.3012, -1.1065, 0.9348],
[1.0346, 1.5703, 0.2419, -0.7048],
[0.6957, -0.4523, -0.8819, 1.0164],
],
],
],
dtype=np.float32,
)
output_arr = np.array(
[
[
[
[-0.8790956, 0.2552987, 0.7402963, -0.28589857],
[0.800596, -1.7700912, -0.9616952, 0.17049915],
[0.28419858, 1.7824911, 0.3364983, -0.85249573],
],
[
[0.7331963, -0.07369963, 0.72449636, -0.6550967],
[1.4460927, -0.18269908, 0.9736951, -2.1570892],
[0.46569768, 0.72439635, 0.3377983, 0.1774991],
],
],
[
[
[1.8895906, 1.8685907, 0.18959905, 0.9816951],
[-0.06709967, 1.5568923, 1.1448942, 0.00859996],
[-0.9467952, -0.01239994, 1.3226933, -0.65669674],
],
[
[-0.84719574, 1.3011935, -1.1064945, 0.9347953],
[1.0345949, 1.5702921, 0.24189879, -0.7047965],
[0.69569653, -0.45229775, -0.8818956, 1.0163949],
],
],
],
dtype=np.float32,
)
m = flow.nn.BatchNorm2d(num_features=2, eps=1e-05, momentum=0.1).to(
device=flow.device(device)
)
m.eval()
x = flow.Tensor(input_arr, device=flow.device(device))
y = m(x)
test_case.assertTrue(np.allclose(y.numpy(), output_arr, 0.0001, 0.0001))
def _test_batchnorm2d_backward(test_case, device):
input_arr = np.array(
[
[
[
[-0.8791, 0.2553, 0.7403, -0.2859],
[0.8006, -1.7701, -0.9617, 0.1705],
[0.2842, 1.7825, 0.3365, -0.8525],
],
[
[0.7332, -0.0737, 0.7245, -0.6551],
[1.4461, -0.1827, 0.9737, -2.1571],
[0.4657, 0.7244, 0.3378, 0.1775],
],
],
[
[
[1.8896, 1.8686, 0.1896, 0.9817],
[-0.0671, 1.5569, 1.1449, 0.0086],
[-0.9468, -0.0124, 1.3227, -0.6567],
],
[
[-0.8472, 1.3012, -1.1065, 0.9348],
[1.0346, 1.5703, 0.2419, -0.7048],
[0.6957, -0.4523, -0.8819, 1.0164],
],
],
],
dtype=np.float32,
)
m = flow.nn.BatchNorm2d(num_features=2, eps=1e-05, momentum=0.1).to(
device=flow.device(device)
)
x = flow.Tensor(input_arr, device=flow.device(device), requires_grad=True)
y = m(x)
z = y.sum()
z.backward()
test_case.assertTrue(
np.allclose(x.grad.numpy(), np.zeros(shape=input_arr.shape), 1e-05, 1e-05)
)
@flow.unittest.skip_unless_1n1d()
class TestBatchNorm(flow.unittest.TestCase):
def test_batchnorm(test_case):
arg_dict = OrderedDict()
arg_dict["test_fun"] = [
_test_batchnorm2d,
_test_batchnorm1d_2d_input,
_test_batchnorm1d_3d_input,
_test_batchnorm2d_4d_input,
_test_batchnorm2d_track_running_stats,
test_batchnorm2d_infer,
test_batchnorm2d_infer_4d_input,
_test_batchnorm2d_backward,
]
arg_dict["device"] = ["cpu", "cuda"]
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])
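The expected arrays in the tests above follow the training-mode formula y = (x - batch_mean) / sqrt(batch_var + eps) with the default affine parameters (weight 1, bias 0). A numpy sketch for the BatchNorm1d 2-d-input case, assuming biased variance over the batch axis:

    import numpy as np

    x = np.random.randn(4, 5).astype(np.float32)
    mean = x.mean(axis=0, keepdims=True)
    var = x.var(axis=0, keepdims=True)  # biased: divides by N, not N - 1
    y = (x - mean) / np.sqrt(var + 1e-05)
    assert np.allclose(y.mean(axis=0), 0.0, atol=1e-4)
    assert np.allclose(y.std(axis=0), 1.0, atol=1e-2)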
class TestBatchNormModule(flow.unittest.TestCase):
@autotest(n=20, auto_backward=True, rtol=1e-3, atol=1e-3)
def test_batchnorm2d_module_with_random_data(test_case):
device = random_device()
......@@ -513,35 +52,7 @@ class TestBatchNorm(flow.unittest.TestCase):
y = m(x)
return y
@unittest.skip("batchnorm module has a bug")
def test_with_random_data(test_case):
for device in ["cpu", "cuda"]:
for training in [True, False]:
test_module_against_pytorch(
test_case,
"nn.BatchNorm2d",
extra_annotations={
"num_features": int,
"eps": float,
"momentum": float,
"affine": bool,
"track_running_stats": bool,
"dtype": str,
"device": flow.device,
},
extra_generators={
"input": random_tensor(ndim=4, dim1=8),
"num_features": constant(8),
"eps": random(1e-06, 1),
"momentum": random(0, 1),
"track_running_stats": constant(True),
},
device=device,
training=training,
n=10,
)
-    @autotest(n=1, auto_backward=False)
+    @autotest(n=20, auto_backward=False, rtol=1e-3, atol=1e-3)
def test_batchnorm3d_module_with_random_data(test_case):
channel = random().to(int)
m = torch.nn.BatchNorm2d(num_features=channel, track_running_stats=False)
......
......@@ -17,36 +17,13 @@ limitations under the License.
import unittest
from collections import OrderedDict
import numpy as np
from test_util import GenArgList
import oneflow as flow
import oneflow.unittest
from automated_test_util import *
def _test_ceil_impl(test_case, device, shape):
x = flow.Tensor(
np.random.randn(*shape), device=flow.device(device), requires_grad=True
)
of_out = flow.ceil(x)
np_out = np.ceil(x.numpy())
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))
of_out = of_out.sum()
of_out.backward()
test_case.assertTrue(np.allclose(x.grad.numpy(), np.zeros(shape), 0.0001, 0.0001))
@flow.unittest.skip_unless_1n1d()
class TestCeilModule(flow.unittest.TestCase):
def test_ceil(test_case):
arg_dict = OrderedDict()
arg_dict["test_fun"] = [_test_ceil_impl]
arg_dict["device"] = ["cpu", "cuda"]
arg_dict["shape"] = [(1,), (2, 3), (2, 3, 4), (2, 3, 4, 5)]
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])
@autotest()
def test_ceil_flow_with_random_data(test_case):
device = random_device()
......
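Ceil is piecewise constant, so its derivative is 0 almost everywhere; that is why the deleted test asserted np.zeros(shape) for x.grad, and the autotest backward pass should agree. Quick numpy reminder:

    import numpy as np

    x = np.random.randn(3)
    h = 1e-6
    # the finite difference vanishes unless x sits within h of an integer (vanishingly rare)
    assert np.all(np.ceil(x + h) == np.ceil(x - h))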
......@@ -22,6 +22,7 @@ from test_util import GenArgList
import oneflow as flow
import oneflow.unittest
from automated_test_util import *
def _test_meshgrid_forawd(test_case, device):
......@@ -37,7 +38,7 @@ def _test_meshgrid_forawd(test_case, device):
test_case.assertTrue(np.allclose(of_y.numpy(), np_y, 0.0001, 0.0001))
- def _test_meshgrid_forawd_scalr(test_case, device):
+ def _test_meshgrid_forawd_scalar(test_case, device):
input1 = flow.Tensor(np.array(1.0), dtype=flow.float32, device=flow.device(device))
input2 = flow.Tensor(np.array(2.0), dtype=flow.float32, device=flow.device(device))
(np_x, np_y) = np.meshgrid(input1.numpy(), input2.numpy(), indexing="ij")
......@@ -66,18 +67,26 @@ def _test_meshgrid_forawd_3tensor(test_case, device):
@flow.unittest.skip_unless_1n1d()
- class TestMeshGrid(flow.unittest.TestCase):
+ class TestMeshGridModule(flow.unittest.TestCase):
def test_meshgrid(test_case):
arg_dict = OrderedDict()
arg_dict["test_fun"] = [
_test_meshgrid_forawd,
-            _test_meshgrid_forawd_scalr,
+            _test_meshgrid_forawd_scalar,
_test_meshgrid_forawd_3tensor,
]
arg_dict["device"] = ["cpu", "cuda"]
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])
@autotest(auto_backward=False)
def test_meshgrid_with_random_data(test_case):
device = random_device()
x = random_pytorch_tensor(ndim=1, dim0=3, requires_grad=False).to(device)
y = random_pytorch_tensor(ndim=1, dim0=3, requires_grad=False).to(device)
res = torch.meshgrid(x, y)
return res[0], res[1]
if __name__ == "__main__":
unittest.main()
......@@ -22,92 +22,32 @@ from test_util import GenArgList
import oneflow as flow
import oneflow.unittest
def _test_pow_scalar_impl(test_case, shape, scalar, device):
np_input = 10 * np.random.rand(*shape)
of_input = flow.Tensor(np_input, dtype=flow.float32, device=flow.device(device))
of_out = flow.pow(of_input, scalar)
np_out = np.power(np_input, scalar)
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))
def _test_pow_elementwise_impl(test_case, shape, scalar, device):
np_input_x = 10 * np.random.rand(*shape)
np_input_y = np.random.randint(1, 3, shape) + np.random.randn(*shape)
of_input_x = flow.Tensor(np_input_x, dtype=flow.float32, device=flow.device(device))
of_input_y = flow.Tensor(np_input_y, dtype=flow.float32, device=flow.device(device))
of_out = flow.pow(of_input_x, of_input_y)
np_out = np.power(np_input_x, np_input_y)
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))
def _test_pow_backward_impl(test_case, device):
shape = (2, 3)
np_input_x = 10 * np.random.rand(*shape)
np_input_y = np.random.randint(1, 3, shape) + np.random.randn(*shape)
np_input_y_scalar = (np.random.randint(1, 3, (1,)) + np.random.randn(1))[0]
np_x_grad = np_input_y * np.power(np_input_x, np_input_y - 1)
np_y_grad = np.power(np_input_x, np_input_y) * np.log(np_input_x)
np_x_grad_scalar = np_input_y_scalar * np.power(np_input_x, np_input_y_scalar - 1)
def test_x_y_grad():
of_input_x = flow.Tensor(
np_input_x,
dtype=flow.float32,
device=flow.device(device),
requires_grad=True,
)
of_input_y = flow.Tensor(
np_input_y,
dtype=flow.float32,
device=flow.device(device),
requires_grad=True,
)
of_out = flow.pow(of_input_x, of_input_y)
of_out_sum = of_out.sum()
of_out_sum.backward()
test_case.assertTrue(
np.allclose(of_input_x.grad.numpy(), np_x_grad, 0.0001, 0.0001)
)
test_case.assertTrue(
np.allclose(of_input_y.grad.numpy(), np_y_grad, 0.0001, 0.0001)
)
def test_x_grad_scalar():
of_input_x = flow.Tensor(
np_input_x,
dtype=flow.float32,
device=flow.device(device),
requires_grad=True,
)
of_out = flow.pow(of_input_x, np_input_y_scalar)
of_out_sum = of_out.sum()
of_out_sum.backward()
test_case.assertTrue(
np.allclose(of_input_x.grad.numpy(), np_x_grad_scalar, 0.0001, 0.0001)
)
test_x_y_grad()
test_x_grad_scalar()
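The analytic gradients used above are d(x**y)/dx = y * x**(y - 1) and d(x**y)/dy = x**y * log(x). A finite-difference spot check of both, as a sketch:

    import numpy as np

    x, y, h = 2.5, 1.7, 1e-6
    dx_fd = ((x + h) ** y - (x - h) ** y) / (2 * h)
    dy_fd = (x ** (y + h) - x ** (y - h)) / (2 * h)
    assert np.allclose(dx_fd, y * x ** (y - 1), rtol=1e-4)
    assert np.allclose(dy_fd, x ** y * np.log(x), rtol=1e-4)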
from automated_test_util import *
@flow.unittest.skip_unless_1n1d()
class TestPow(flow.unittest.TestCase):
def test_pow_forward(test_case):
arg_dict = OrderedDict()
arg_dict["shape"] = [(2, 3), (2, 3, 4, 5), (2, 3, 0, 5)]
arg_dict["scalar"] = [2.1, 0.8]
arg_dict["device"] = ["cpu", "cuda"]
for arg in GenArgList(arg_dict):
_test_pow_scalar_impl(test_case, *arg)
_test_pow_elementwise_impl(test_case, *arg)
def test_pow_backward(test_case):
arg_dict = OrderedDict()
arg_dict["device"] = ["cpu", "cuda"]
for arg in GenArgList(arg_dict):
_test_pow_backward_impl(test_case, *arg)
class TestPowModule(flow.unittest.TestCase):
@autotest()
def test_pow_scalar_with_random_data(test_case):
device = random_device()
x = random_pytorch_tensor().to(device)
y = random().to(float)
return torch.pow(x, y)
@autotest()
def test_pow_elementwise_with_random_data(test_case):
device = random_device()
x = random_pytorch_tensor(ndim=2, dim1=2).to(device)
y = random_pytorch_tensor(ndim=2, dim1=2).to(device)
return torch.pow(x, y)
@unittest.skip("not support for broadcast currently")
@autotest()
def test_pow_broadcast_with_random_data(test_case):
device = random_device()
x = random_pytorch_tensor(ndim=2, dim1=2).to(device)
y = random_pytorch_tensor(ndim=2, dim1=1).to(device)
return torch.pow(x, y)
if __name__ == "__main__":
......
......@@ -14,94 +14,22 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import random
import unittest
from collections import OrderedDict
import numpy as np
from test_util import GenArgList
import oneflow as flow
import oneflow.unittest
def _test_stack(test_case, device, shape):
x = np.random.rand(*shape)
y = np.random.rand(*shape)
x_tensor = flow.Tensor(x, dtype=flow.float32, device=flow.device(device))
y_tensor = flow.Tensor(y, dtype=flow.float32, device=flow.device(device))
out_np = np.stack([x, y], axis=1)
out_of = flow.stack([x_tensor, y_tensor], dim=1).numpy()
test_case.assertTrue(np.allclose(out_np, out_of, 1e-05, 1e-05))
def _test_stack_tuple_input(test_case, device, shape):
x = np.random.rand(*shape)
y = np.random.rand(*shape)
x_tensor = flow.Tensor(x, dtype=flow.float32, device=flow.device(device))
y_tensor = flow.Tensor(y, dtype=flow.float32, device=flow.device(device))
out_np = np.stack([x, y], axis=0)
out_of = flow.stack((x_tensor, y_tensor), dim=0).numpy()
test_case.assertTrue(np.allclose(out_np, out_of, 1e-05, 1e-05))
def _test_stack_backward(test_case, device, shape):
x = np.random.rand(*shape)
y = np.random.rand(*shape)
x_tensor = flow.Tensor(x, device=flow.device(device), requires_grad=True)
y_tensor = flow.Tensor(y, device=flow.device(device), requires_grad=True)
out_of = flow.stack([x_tensor, y_tensor]).sum()
out_of.backward()
test_case.assertTrue(
np.allclose(x_tensor.grad.numpy(), np.ones(x_tensor.shape), 1e-05, 1e-05)
)
test_case.assertTrue(
np.allclose(y_tensor.grad.numpy(), np.ones(y_tensor.shape), 1e-05, 1e-05)
)
def _test_stack_different_dim(test_case, device, shape):
x = np.random.rand(*shape)
y = np.random.rand(*shape)
x_tensor = flow.Tensor(x, device=flow.device(device))
y_tensor = flow.Tensor(y, device=flow.device(device))
for axis in range(-len(x.shape) - 1, len(x.shape) + 1):
out_of = flow.stack([x_tensor, y_tensor], dim=axis)
out_np = np.stack([x, y], axis=axis)
test_case.assertTrue(np.allclose(out_np, out_of.numpy(), 1e-05, 1e-05))
def _test_stack_multi_input(test_case, device, shape):
max_input_num = 10
for i in range(2, max_input_num):
x = []
x_tensor = []
for _ in range(0, i):
tmp = np.random.rand(*shape)
x.append(tmp)
x_tensor.append(flow.Tensor(tmp, device=flow.device(device)))
out_of = flow.stack(x_tensor, dim=-1)
out_np = np.stack(x, axis=-1)
test_case.assertTrue(np.allclose(out_np, out_of.numpy(), 1e-05, 1e-05))
from automated_test_util import *
@flow.unittest.skip_unless_1n1d()
class TestStack(flow.unittest.TestCase):
def test_stack(test_case):
arg_dict = OrderedDict()
arg_dict["test_fun"] = [
_test_stack,
_test_stack_tuple_input,
_test_stack_backward,
_test_stack_different_dim,
_test_stack_multi_input,
]
arg_dict["device"] = ["cpu", "cuda"]
arg_dict["shape"] = [
tuple((random.randrange(1, 10) for _ in range(i))) for i in range(3, 6)
]
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])
class TestStackModule(flow.unittest.TestCase):
@autotest()
def test_stack_with_random_data(test_case):
device = random_device()
x = random_pytorch_tensor(ndim=4, dim1=3, dim2=4, dim3=5).to(device)
y = random_pytorch_tensor(ndim=4, dim1=3, dim2=4, dim3=5).to(device)
out = torch.stack((x, y), dim=random(low=1, high=4).to(int))
return out
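stack joins same-shaped tensors along a new axis: two inputs of shape (2, 3, 4, 5) stacked at dim=1 give (2, 2, 3, 4, 5), the behavior the random dim above exercises. Shape sketch in numpy:

    import numpy as np

    a = np.random.rand(2, 3, 4, 5)
    b = np.random.rand(2, 3, 4, 5)
    assert np.stack((a, b), axis=1).shape == (2, 2, 3, 4, 5)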
if __name__ == "__main__":
......