未验证 提交 668a0d3b 编写于 作者: X xiaoting 提交者: GitHub

support int for nearest_interp, test=develop (#32270)

上级 cfdde0ec
......@@ -672,6 +672,8 @@ REGISTER_OP_CPU_KERNEL(bilinear_interp_v2_grad,
ops::InterpolateV2GradKernel<double>);
REGISTER_OP_CPU_KERNEL(nearest_interp_v2, ops::InterpolateV2Kernel<float>,
ops::InterpolateV2Kernel<double>,
ops::InterpolateV2Kernel<int>,
ops::InterpolateV2Kernel<int64_t>,
ops::InterpolateV2Kernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(nearest_interp_v2_grad,
ops::InterpolateV2GradKernel<float>,
......
......@@ -1738,6 +1738,7 @@ REGISTER_OP_CUDA_KERNEL(bilinear_interp_v2_grad,
REGISTER_OP_CUDA_KERNEL(nearest_interp_v2,
ops::InterpolateOpV2CUDAKernel<float>,
ops::InterpolateOpV2CUDAKernel<double>,
ops::InterpolateOpV2CUDAKernel<int64_t>,
ops::InterpolateOpV2CUDAKernel<int>);
REGISTER_OP_CUDA_KERNEL(nearest_interp_v2_grad,
ops::InterpolateV2GradOpCUDAKernel<float>,
......
......@@ -21,6 +21,7 @@ import paddle.fluid.core as core
import paddle.fluid as fluid
import paddle.nn as nn
import paddle
from paddle.nn.functional import interpolate
def nearest_neighbor_interp_np(X,
......@@ -526,6 +527,28 @@ class TestNearestAPI(unittest.TestCase):
self.assertTrue(np.allclose(results[i + 1], expect_res))
class TestNearestInterpOpAPI_dy(unittest.TestCase):
def test_case(self):
import paddle
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
with fluid.dygraph.guard(place):
input_data = np.random.random((2, 3, 6, 6)).astype("int64")
scale_np = np.array([2, 2]).astype("int64")
input_x = paddle.to_tensor(input_data)
scale = paddle.to_tensor(scale_np)
expect_res = nearest_neighbor_interp_np(
input_data, out_h=12, out_w=12, align_corners=False)
out = interpolate(
x=input_x,
scale_factor=scale,
mode="nearest",
align_corners=False)
self.assertTrue(np.allclose(out.numpy(), expect_res))
class TestNearestInterpException(unittest.TestCase):
def test_exception(self):
input = fluid.data(name="input", shape=[1, 3, 6, 6], dtype="float32")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册