From d296456c1af37cb4b1a87e2b684488f811318f58 Mon Sep 17 00:00:00 2001 From: xiaoting <31891223+tink2123@users.noreply.github.com> Date: Wed, 12 Jan 2022 17:16:51 +0800 Subject: [PATCH] support 5d for nearest interp (#38868) * support 5d for nearest * update nearest3d unittest, test=develop * fix approve ci, test=develop * fix approve ci, test=develop --- paddle/fluid/operators/interpolate_v2_op.cc | 12 +- paddle/fluid/operators/interpolate_v2_op.cu | 126 ++++++++++++ paddle/fluid/operators/interpolate_v2_op.h | 77 +++++++ .../unittests/test_nearest_interp_v2_op.py | 192 ++++++++++++++++-- python/paddle/nn/functional/common.py | 11 +- python/paddle/nn/layer/common.py | 3 +- 6 files changed, 394 insertions(+), 27 deletions(-) diff --git a/paddle/fluid/operators/interpolate_v2_op.cc b/paddle/fluid/operators/interpolate_v2_op.cc index de276cfa31..7783303785 100644 --- a/paddle/fluid/operators/interpolate_v2_op.cc +++ b/paddle/fluid/operators/interpolate_v2_op.cc @@ -249,12 +249,12 @@ static void Interpolate3DInferShapeCheck(framework::InferShapeContext* ctx) { auto dim_x = ctx->GetInputDim("X"); auto interp_method = ctx->Attrs().Get("interp_method"); - PADDLE_ENFORCE_EQ( - "trilinear", interp_method, - platform::errors::InvalidArgument( - "Interpolation method can only be \"trilinear\" when Input(X) " - "dimension is 5, but got method = %s .", - interp_method)); + PADDLE_ENFORCE("nearest" == interp_method || "trilinear" == interp_method, + platform::errors::InvalidArgument( + "Interpolation method can only be \"trilinear\" or " + "\"nearest\" when Input(X) " + "dimension is 5, but got method = %s .", + interp_method)); const DataLayout data_layout = framework::StringToDataLayout( ctx->Attrs().Get("data_layout")); diff --git a/paddle/fluid/operators/interpolate_v2_op.cu b/paddle/fluid/operators/interpolate_v2_op.cu index bc1ab704aa..3db0fdf5e6 100644 --- a/paddle/fluid/operators/interpolate_v2_op.cu +++ b/paddle/fluid/operators/interpolate_v2_op.cu @@ -67,6 +67,61 @@ __global__ void KeNearestNeighborInterpFw( } } +template +__global__ void KeNearestNeighbor3DInterpFw( + const T* in, const size_t in_img_d, const size_t in_img_h, + const size_t in_img_w, const size_t input_h, const size_t input_w, T* out, + const size_t out_img_d, const size_t out_img_h, const size_t out_img_w, + const size_t output_h, const size_t output_w, const size_t num_channels, + const float ratio_d, const float ratio_h, const float ratio_w, + const bool align_corners, const DataLayout data_layout) { + int nthreads = output_h * output_w; // ncdhw + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + for (; tid < nthreads; tid += stride) { + int out_id_h = tid / output_w; + int out_id_w = tid % output_w; + int in_img_size = input_w / num_channels; + int out_img_size = output_w / num_channels; + + int channel_id, out_img_idt, out_img_idy, out_img_idx; + if (data_layout == DataLayout::kNCHW) { + channel_id = out_id_w / out_img_size; + out_img_idt = (out_id_w % out_img_size) / out_img_h / out_img_w; + out_img_idy = ((out_id_w % out_img_size) / out_img_w) % out_img_h; + out_img_idx = tid % out_img_w; + } else { + out_img_idt = out_id_w / (out_img_h * out_img_w * num_channels); + out_img_idy = out_id_w % (out_img_h * out_img_w * num_channels) / + (out_img_w * num_channels); + out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels; + channel_id = tid % num_channels; + } + + int in_img_idt = (align_corners) + ? static_cast(ratio_d * out_img_idt + 0.5) + : static_cast(ratio_d * out_img_idt); + + int in_img_idy = (align_corners) + ? static_cast(ratio_h * out_img_idy + 0.5) + : static_cast(ratio_h * out_img_idy); + int in_img_idx = (align_corners) + ? static_cast(ratio_w * out_img_idx + 0.5) + : static_cast(ratio_w * out_img_idx); + + if (data_layout == DataLayout::kNCHW) { + out[tid] = in[out_id_h * input_w + channel_id * in_img_size + + in_img_idt * in_img_h * in_img_w + in_img_idy * in_img_w + + in_img_idx]; + } else { + out[tid] = in[out_id_h * input_w + + in_img_idt * in_img_h * in_img_w * num_channels + + in_img_idy * in_img_w * num_channels + + in_img_idx * num_channels + channel_id]; + } + } +} + template __global__ void KeNearestNeighborInterpBw( T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h, @@ -114,6 +169,63 @@ __global__ void KeNearestNeighborInterpBw( } } +template +__global__ void KeNearestNeighbor3DInterpBw( + T* in, const size_t in_img_d, const size_t in_img_h, const size_t in_img_w, + const size_t input_h, const size_t input_w, const T* out, + const size_t out_img_d, const size_t out_img_h, const size_t out_img_w, + const size_t output_h, const size_t output_w, const size_t num_channels, + const float ratio_d, const float ratio_h, const float ratio_w, + const bool align_corners, const DataLayout data_layout) { + int nthreads = output_h * output_w; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + for (; tid < nthreads; tid += stride) { + int out_id_h = tid / output_w; + int out_id_w = tid % output_w; + int in_img_size = input_w / num_channels; + int out_img_size = output_w / num_channels; + + int channel_id, out_img_idt, out_img_idy, out_img_idx; + if (data_layout == DataLayout::kNCHW) { + channel_id = out_id_w / out_img_size; + out_img_idt = (out_id_w % out_img_size) / out_img_h / out_img_w; + out_img_idy = ((out_id_w % out_img_size) / out_img_w) % out_img_h; + out_img_idx = tid % out_img_w; + } else { + out_img_idt = out_id_w / (out_img_h * out_img_w * num_channels); + out_img_idy = out_id_w % (out_img_h * out_img_w * num_channels) / + (out_img_w * num_channels); + out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels; + channel_id = tid % num_channels; + } + + int in_img_idt = (align_corners) + ? static_cast(ratio_d * out_img_idt + 0.5) + : static_cast(ratio_d * out_img_idt); + int in_img_idy = (align_corners) + ? static_cast(ratio_h * out_img_idy + 0.5) + : static_cast(ratio_h * out_img_idy); + int in_img_idx = (align_corners) + ? static_cast(ratio_w * out_img_idx + 0.5) + : static_cast(ratio_w * out_img_idx); + + T* in_pos; + if (data_layout == DataLayout::kNCHW) { + in_pos = &in[out_id_h * input_w + channel_id * in_img_size + + in_img_idt * in_img_h * in_img_w + in_img_idy * in_img_w + + in_img_idx]; + } else { + in_pos = &in[out_id_h * input_w + + in_img_idt * in_img_h * in_img_w * num_channels + + in_img_idy * in_img_w * num_channels + + in_img_idx * num_channels + channel_id]; + } + const T out_pos = out[out_id_h * output_w + out_id_w]; + platform::CudaAtomicAdd(in_pos, out_pos); + } +} + template __global__ void KeLinearInterpFw(const T* in, const size_t in_img_w, const size_t input_w, T* out, @@ -1376,6 +1488,13 @@ static void Interpolate3DCUDAFwd(const framework::ExecutionContext& ctx, input_data, in_d, in_h, in_w, n, in_cdhw, output_data, out_d, out_h, out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners, align_mode, data_layout); + } else if ("nearest" == interp_method) { + KeNearestNeighbor3DInterpFw< + T><<>>( + input_data, in_d, in_h, in_w, n, in_cdhw, output_data, out_d, out_h, + out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners, + data_layout); } } @@ -1801,6 +1920,13 @@ static void Interpolate3DCUDABwd(const framework::ExecutionContext& ctx, input_grad_data, in_d, in_h, in_w, n, in_cdhw, output_grad_data, out_d, out_h, out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners, align_mode, data_layout); + } else if ("nearest" == interp_method) { + KeNearestNeighbor3DInterpBw< + T><<>>( + input_grad_data, in_d, in_h, in_w, n, in_cdhw, output_grad_data, out_d, + out_h, out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners, + data_layout); } } diff --git a/paddle/fluid/operators/interpolate_v2_op.h b/paddle/fluid/operators/interpolate_v2_op.h index 8daf440f60..0af799eca0 100644 --- a/paddle/fluid/operators/interpolate_v2_op.h +++ b/paddle/fluid/operators/interpolate_v2_op.h @@ -121,6 +121,39 @@ static void NearestNeighborInterpolate(const Tensor& input, Tensor* output, } } +template +static void NearestNeighbor3DInterpolate( + const Tensor& input, Tensor* output, const float ratio_d, + const float ratio_h, const float ratio_w, const int n, const int c, + const int out_d, const int out_h, const int out_w, const bool align_corners, + const DataLayout& data_layout) { + auto input_t = EigenTensor::From(input); + auto output_t = EigenTensor::From(*output); + for (int d = 0; d < out_d; d++) { // loop for images + int in_d = (align_corners) ? static_cast(ratio_d * d + 0.5) + : static_cast(ratio_d * d); + for (int k = 0; k < out_h; k++) { + int in_k = (align_corners) ? static_cast(ratio_h * k + 0.5) + : static_cast(ratio_h * k); + + for (int l = 0; l < out_w; l++) { + int in_l = (align_corners) ? static_cast(ratio_w * l + 0.5) + : static_cast(ratio_w * l); + + for (int i = 0; i < n; i++) { // loop for batches + for (int j = 0; j < c; j++) { // loop for channels + if (data_layout == DataLayout::kNCHW) { + output_t(i, j, d, k, l) = input_t(i, j, in_d, in_k, in_l); + } else { // NDHWC + output_t(i, d, k, l, j) = input_t(i, in_d, in_k, in_l, j); + } + } + } + } + } + } +} + template static void LinearInterpolation(const Tensor& input, Tensor* output, const float ratio_w, const int in_w, @@ -584,6 +617,42 @@ static void NearestNeighborInterpolateGrad( } } +template +static void NearestNeighbor3DInterpolateGrad( + const Tensor& output_grad, Tensor* input_grad, const float ratio_d, + const float ratio_h, const float ratio_w, const int n, const int c, + const int out_d, const int out_h, const int out_w, const bool align_corners, + const DataLayout data_layout) { + auto input_grad_t = EigenTensor::From(*input_grad); + auto output_grad_t = EigenTensor::From(output_grad); + + for (int d = 0; d < out_d; d++) { + int in_d = (align_corners) ? static_cast(ratio_d * d + 0.5) + : static_cast(ratio_d * d); + for (int k = 0; k < out_h; k++) { // loop for images + int in_k = (align_corners) ? static_cast(ratio_h * k + 0.5) + : static_cast(ratio_h * k); + + for (int l = 0; l < out_w; l++) { + int in_l = (align_corners) ? static_cast(ratio_w * l + 0.5) + : static_cast(ratio_w * l); + + for (int i = 0; i < n; i++) { // loop for batches + for (int j = 0; j < c; j++) { // loop for channels + if (data_layout == DataLayout::kNCHW) { + input_grad_t(i, j, in_d, in_k, in_l) += + output_grad_t(i, j, d, k, l); + } else { + input_grad_t(i, in_d, in_k, in_l, j) += + output_grad_t(i, d, k, l, j); + } + } + } + } + } + } +} + template static void BilinearInterpolationGrad( const Tensor& output_grad, Tensor* input_grad, const float ratio_h, @@ -1137,6 +1206,10 @@ static void Interpolate3DCPUFwd(const framework::ExecutionContext& ctx, TrilinearInterpolation(input, output, ratio_d, ratio_h, ratio_w, in_d, in_h, in_w, n, c, out_d, out_h, out_w, align_corners, align_mode, data_layout); + } else if ("nearest" == interp_method) { + NearestNeighbor3DInterpolate(input, output, ratio_d, ratio_h, ratio_w, n, + c, out_d, out_h, out_w, align_corners, + data_layout); } } @@ -1489,6 +1562,10 @@ static void Interpolate3DCPUBwd(const framework::ExecutionContext& ctx, TrilinearInterpolationGrad( output_grad, input_grad, ratio_d, ratio_h, ratio_w, in_d, in_h, in_w, n, c, out_d, out_h, out_w, align_corners, align_mode, data_layout); + } else if ("nearest" == interp_method) { + NearestNeighbor3DInterpolateGrad(output_grad, input_grad, ratio_d, + ratio_h, ratio_w, n, c, out_d, out_h, + out_w, align_corners, data_layout); } } diff --git a/python/paddle/fluid/tests/unittests/test_nearest_interp_v2_op.py b/python/paddle/fluid/tests/unittests/test_nearest_interp_v2_op.py index 04962a93c1..e2ac98f7c9 100755 --- a/python/paddle/fluid/tests/unittests/test_nearest_interp_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_nearest_interp_v2_op.py @@ -23,6 +23,8 @@ import paddle.nn as nn import paddle from paddle.nn.functional import interpolate +paddle.enable_static() + def nearest_neighbor_interp_np(X, out_h, @@ -78,7 +80,80 @@ def nearest_neighbor_interp_np(X, if data_layout == "NHWC": out = np.transpose(out, (0, 2, 3, 1)) # NCHW => NHWC + # out = np.expand_dims(out, 2) + return out.astype(X.dtype) + + +def nearest_neighbor_interp3d_np(X, + out_d, + out_h, + out_w, + scale_d=0, + scale_h=0, + scale_w=0, + out_size=None, + actual_shape=None, + align_corners=True, + data_layout='NCHW'): + """nearest neighbor interpolation implement in shape [N, C, H, W]""" + if data_layout == "NHWC": + X = np.transpose(X, (0, 4, 1, 2, 3)) # NDHWC => NCDHW + if out_size is not None: + out_d = out_size[0] + out_h = out_size[1] + out_w = out_size[2] + if actual_shape is not None: + out_d = actual_shape[0] + out_h = actual_shape[1] + out_w = actual_shape[2] + n, c, in_d, in_h, in_w = X.shape + ratio_d = ratio_h = ratio_w = 0.0 + if (out_d > 1): + if (align_corners): + ratio_d = (in_d - 1.0) / (out_d - 1.0) + else: + if scale_d > 0: + ratio_d = 1.0 / scale_d + else: + ratio_d = 1.0 * in_d / out_d + if (out_h > 1): + if (align_corners): + ratio_h = (in_h - 1.0) / (out_h - 1.0) + else: + if scale_h > 0: + ratio_h = 1.0 / scale_h + else: + ratio_h = 1.0 * in_h / out_h + if (out_w > 1): + if (align_corners): + ratio_w = (in_w - 1.0) / (out_w - 1.0) + else: + if scale_w > 0: + ratio_w = 1.0 / scale_w + else: + ratio_w = 1.0 * in_w / out_w + out = np.zeros((n, c, out_d, out_h, out_w)) + + if align_corners: + for d in range(out_d): + in_d = int(ratio_d * d + 0.5) + for i in range(out_h): + in_i = int(ratio_h * i + 0.5) + for j in range(out_w): + in_j = int(ratio_w * j + 0.5) + out[:, :, d, i, j] = X[:, :, in_d, in_i, in_j] + else: + for d in range(out_d): + in_d = int(ratio_d * d) + for i in range(out_h): + in_i = int(ratio_h * i) + for j in range(out_w): + in_j = int(ratio_w * j) + out[:, :, d, i, j] = X[:, :, in_d, in_i, in_j] + + if data_layout == "NDHWC": + out = np.transpose(out, (0, 2, 3, 4, 1)) # NCDHW => NDHWC return out.astype(X.dtype) @@ -91,44 +166,81 @@ class TestNearestInterpOp(OpTest): self.op_type = "nearest_interp_v2" input_np = np.random.random(self.input_shape).astype("float64") - if self.data_layout == "NCHW": + if self.data_layout == "NCHW" and len(self.input_shape) == 4: + in_d = 1 in_h = self.input_shape[2] in_w = self.input_shape[3] else: + in_d = 1 in_h = self.input_shape[1] in_w = self.input_shape[2] + + if self.data_layout == "NCDHW" and len(self.input_shape) == 5: + in_d = self.input_shape[2] + in_h = self.input_shape[3] + in_w = self.input_shape[4] + else: + in_d = self.input_shape[1] + in_h = self.input_shape[2] + in_w = self.input_shape[3] + scale_d = 0 scale_h = 0 scale_w = 0 if self.scale: if isinstance(self.scale, float) or isinstance(self.scale, int): if self.scale > 0: - scale_h = scale_w = float(self.scale) + scale_d = scale_h = scale_w = float(self.scale) if isinstance(self.scale, list) and len(self.scale) == 1: - scale_w = scale_h = self.scale[0] + scale_d = scale_w = scale_h = self.scale[0] elif isinstance(self.scale, list) and len(self.scale) > 1: - scale_w = self.scale[1] - scale_h = self.scale[0] + if len(self.scale) == 5: + scale_w = self.scale[2] + scale_h = self.scale[1] + scale_d = self.scale[0] + else: + scale_w = self.scale[1] + scale_h = self.scale[0] + out_h = int(in_h * scale_h) out_w = int(in_w * scale_w) + out_d = int(in_d * scale_d) else: + if len(self.input_shape) == 5: + out_d = self.out_d out_h = self.out_h out_w = self.out_w - output_np = nearest_neighbor_interp_np( - input_np, out_h, out_w, scale_h, scale_w, self.out_size, - self.actual_shape, self.align_corners, self.data_layout) + if len(self.input_shape) == 4: + output_np = nearest_neighbor_interp_np( + input_np, out_h, out_w, scale_h, scale_w, self.out_size, + self.actual_shape, self.align_corners, self.data_layout) + elif len(self.input_shape) == 5: + output_np = nearest_neighbor_interp3d_np( + input_np, out_d, out_h, out_w, scale_d, scale_h, scale_w, + self.out_size, self.actual_shape, self.align_corners, + self.data_layout) self.inputs = {'X': input_np} if self.out_size is not None: self.inputs['OutSize'] = self.out_size if self.actual_shape is not None: self.inputs['OutSize'] = self.actual_shape - self.attrs = { - 'out_h': self.out_h, - 'out_w': self.out_w, - 'interp_method': self.interp_method, - 'align_corners': self.align_corners, - 'data_layout': self.data_layout - } + if len(self.input_shape) == 5: + self.attrs = { + 'out_d': self.out_d, + 'out_h': self.out_h, + 'out_w': self.out_w, + 'interp_method': self.interp_method, + 'align_corners': self.align_corners, + 'data_layout': self.data_layout + } + else: + self.attrs = { + 'out_h': self.out_h, + 'out_w': self.out_w, + 'interp_method': self.interp_method, + 'align_corners': self.align_corners, + 'data_layout': self.data_layout + } if self.scale: if isinstance(self.scale, float) or isinstance(self.scale, int): if self.scale > 0: @@ -157,7 +269,8 @@ class TestNearestInterpOp(OpTest): class TestNearestNeighborInterpCase1(TestNearestInterpOp): def init_test_case(self): self.interp_method = 'nearest' - self.input_shape = [4, 1, 7, 8] + self.input_shape = [4, 1, 1, 7, 8] + self.out_d = 1 self.out_h = 1 self.out_w = 1 self.scale = 0. @@ -366,6 +479,18 @@ class TestNearestNeighborInterpScale3(TestNearestInterpOp): self.align_corners = True +class TestNearestNeighbor3DInterp(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [3, 2, 4, 7, 5] + self.out_d = 8 + self.out_h = 64 + self.out_w = 32 + self.scale = [4.0, 2.0, 3.0] + self.out_size = np.array([8, 66, 40]).astype("int32") + self.align_corners = True + + class TestNearestInterpOp_attr_tensor(OpTest): def setUp(self): self.out_size = None @@ -549,8 +674,32 @@ class TestNearestInterpOpAPI_dy(unittest.TestCase): self.assertTrue(np.allclose(out.numpy(), expect_res)) +class TestNearestInterp3DOpAPI_dy(unittest.TestCase): + def test_case(self): + import paddle + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + else: + place = core.CPUPlace() + with fluid.dygraph.guard(place): + input_data = np.random.random((2, 2, 6, 6, 6)).astype("int64") + scale_np = np.array([2, 2, 2]).astype("int64") + input_x = paddle.to_tensor(input_data) + scale = paddle.to_tensor(scale_np) + expect_res = nearest_neighbor_interp3d_np( + input_data, out_d=12, out_h=12, out_w=12, align_corners=False) + out = interpolate( + x=input_x, + scale_factor=scale, + mode="nearest", + align_corners=False, + data_format="NCDHW") + self.assertTrue(np.allclose(out.numpy(), expect_res)) + + class TestNearestInterpException(unittest.TestCase): def test_exception(self): + import paddle input = fluid.data(name="input", shape=[1, 3, 6, 6], dtype="float32") def attr_data_format(): @@ -564,9 +713,20 @@ class TestNearestInterpException(unittest.TestCase): def attr_scale_value(): out = fluid.layers.resize_nearest(input, scale=-0.3) + def input_shape_error(): + x = paddle.randn([1, 3]) + out = paddle.nn.functional.interpolate(x, scale_factor='scale') + + def mode_error(): + x = paddle.randn([1, 3]) + out = paddle.nn.functional.interpolate( + x, scale_factor='scale', mode="BILINEAR") + self.assertRaises(ValueError, attr_data_format) self.assertRaises(TypeError, attr_scale_type) self.assertRaises(ValueError, attr_scale_value) + self.assertRaises(ValueError, input_shape_error) + self.assertRaises(ValueError, mode_error) if __name__ == "__main__": diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 3dba9505e9..5a010ad2f2 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -221,7 +221,8 @@ def interpolate(x, ValueError: The 'mode' of image_resize can only be 'linear', 'bilinear', 'trilinear', 'bicubic', 'area' or 'nearest' currently. ValueError: 'linear' only support 3-D tensor. - ValueError: 'bilinear', 'bicubic' and 'nearest' only support 4-D tensor. + ValueError: 'bilinear' and 'bicubic' only support 4-D tensor. + ValueError: 'nearest' only support 4-D or 5-D tensor. ValueError: 'trilinear' only support 5-D tensor. ValueError: One of size and scale_factor must not be None. ValueError: size length should be 1 for input 3-D tensor. @@ -276,9 +277,11 @@ def interpolate(x, if resample in ['LINEAR'] and len(x.shape) != 3: raise ValueError("'linear' only support 3-D tensor.") - if resample in ['BILINEAR', 'NEAREST', 'BICUBIC'] and len(x.shape) != 4: - raise ValueError( - "'bilinear', 'bicubic' and 'nearest' only support 4-D tensor.") + if resample in ['NEAREST'] and len(x.shape) != 4 and len(x.shape) != 5: + raise ValueError("'NEAREST' only support 4-D or 5-D tensor.") + + if resample in ['BILINEAR', 'BICUBIC'] and len(x.shape) != 4: + raise ValueError("'bilinear' and 'bicubic' only support 4-D tensor.") if resample == 'TRILINEAR' and len(x.shape) != 5: raise ValueError("'trilinear'only support 5-D tensor.") diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py index 22f7f79837..89ff156bde 100644 --- a/python/paddle/nn/layer/common.py +++ b/python/paddle/nn/layer/common.py @@ -359,8 +359,9 @@ class Upsample(Layer): ValueError: The 'mode' of image_resize can only be 'linear', 'bilinear', 'trilinear', 'bicubic', or 'nearest' currently. ValueError: 'linear' only support 3-D tensor. - ValueError: 'bilinear', 'bicubic' and 'nearest' only support 4-D tensor. + ValueError: 'bilinear' and 'bicubic' only support 4-D tensor. ValueError: 'trilinear' only support 5-D tensor. + ValueError: 'nearest' only support 4-D or 5-D tensor. ValueError: One of size and scale_factor must not be None. ValueError: size length should be 1 for input 3-D tensor. ValueError: size length should be 2 for input 4-D tensor. -- GitLab