提交 e42679b5 编写于 作者: M Megvii Engine Team

feat(mge): do not export F.nn.interpolate

GitOrigin-RevId: 031c6555c0a4c190594d96e47d989da5cd5df62f
上级 fd802e97
......@@ -36,7 +36,6 @@ __all__ = [
"dot",
"dropout",
"indexing_one_hot",
"interpolate",
"leaky_relu",
"linear",
"local_conv2d",
......@@ -1112,9 +1111,9 @@ def interpolate(
import megengine.functional as F
x = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2))
out = F.interpolate(x, [4, 4], align_corners=False)
out = F.nn.interpolate(x, [4, 4], align_corners=False)
print(out.numpy())
out2 = F.interpolate(x, scale_factor=2.)
out2 = F.nn.interpolate(x, scale_factor=2.)
np.testing.assert_allclose(out.numpy(), out2.numpy())
Outputs:
......
......@@ -101,8 +101,8 @@ def test_interpolate():
def linear_interpolate():
inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2))
out = F.interpolate(inp, scale_factor=2.0, mode="LINEAR")
out2 = F.interpolate(inp, 4, mode="LINEAR")
out = F.nn.interpolate(inp, scale_factor=2.0, mode="LINEAR")
out2 = F.nn.interpolate(inp, 4, mode="LINEAR")
np.testing.assert_allclose(
out.numpy(), np.array([[[1.0, 1.25, 1.75, 2.0]]], dtype=np.float32)
......@@ -114,16 +114,16 @@ def test_interpolate():
def many_batch_interpolate():
inp = tensor(np.arange(1, 9, dtype=np.float32).reshape(2, 1, 2, 2))
out = F.interpolate(inp, [4, 4])
out2 = F.interpolate(inp, scale_factor=2.0)
out = F.nn.interpolate(inp, [4, 4])
out2 = F.nn.interpolate(inp, scale_factor=2.0)
np.testing.assert_allclose(out.numpy(), out2.numpy())
def assign_corner_interpolate():
inp = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2))
out = F.interpolate(inp, [4, 4], align_corners=True)
out2 = F.interpolate(inp, scale_factor=2.0, align_corners=True)
out = F.nn.interpolate(inp, [4, 4], align_corners=True)
out2 = F.nn.interpolate(inp, scale_factor=2.0, align_corners=True)
np.testing.assert_allclose(out.numpy(), out2.numpy())
......@@ -131,13 +131,13 @@ def test_interpolate():
inp = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2))
with pytest.raises(ValueError):
F.interpolate(inp, scale_factor=2.0, mode="LINEAR")
F.nn.interpolate(inp, scale_factor=2.0, mode="LINEAR")
def inappropriate_scale_linear_interpolate():
inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2))
with pytest.raises(ValueError):
F.interpolate(inp, scale_factor=[2.0, 3.0], mode="LINEAR")
F.nn.interpolate(inp, scale_factor=[2.0, 3.0], mode="LINEAR")
linear_interpolate()
many_batch_interpolate()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册