未验证 提交 d296456c 编写于 作者: X xiaoting 提交者: GitHub

support 5d for nearest interp (#38868)

* support 5d for nearest

* update nearest3d unittest, test=develop

* fix approve ci, test=develop

* fix approve ci, test=develop
上级 cc24427e
...@@ -249,12 +249,12 @@ static void Interpolate3DInferShapeCheck(framework::InferShapeContext* ctx) { ...@@ -249,12 +249,12 @@ static void Interpolate3DInferShapeCheck(framework::InferShapeContext* ctx) {
auto dim_x = ctx->GetInputDim("X"); auto dim_x = ctx->GetInputDim("X");
auto interp_method = ctx->Attrs().Get<std::string>("interp_method"); auto interp_method = ctx->Attrs().Get<std::string>("interp_method");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE("nearest" == interp_method || "trilinear" == interp_method,
"trilinear", interp_method, platform::errors::InvalidArgument(
platform::errors::InvalidArgument( "Interpolation method can only be \"trilinear\" or "
"Interpolation method can only be \"trilinear\" when Input(X) " "\"nearest\" when Input(X) "
"dimension is 5, but got method = %s .", "dimension is 5, but got method = %s .",
interp_method)); interp_method));
const DataLayout data_layout = framework::StringToDataLayout( const DataLayout data_layout = framework::StringToDataLayout(
ctx->Attrs().Get<std::string>("data_layout")); ctx->Attrs().Get<std::string>("data_layout"));
......
...@@ -67,6 +67,61 @@ __global__ void KeNearestNeighborInterpFw( ...@@ -67,6 +67,61 @@ __global__ void KeNearestNeighborInterpFw(
} }
} }
template <typename T>
__global__ void KeNearestNeighbor3DInterpFw(
const T* in, const size_t in_img_d, const size_t in_img_h,
const size_t in_img_w, const size_t input_h, const size_t input_w, T* out,
const size_t out_img_d, const size_t out_img_h, const size_t out_img_w,
const size_t output_h, const size_t output_w, const size_t num_channels,
const float ratio_d, const float ratio_h, const float ratio_w,
const bool align_corners, const DataLayout data_layout) {
int nthreads = output_h * output_w; // ncdhw
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id, out_img_idt, out_img_idy, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idt = (out_id_w % out_img_size) / out_img_h / out_img_w;
out_img_idy = ((out_id_w % out_img_size) / out_img_w) % out_img_h;
out_img_idx = tid % out_img_w;
} else {
out_img_idt = out_id_w / (out_img_h * out_img_w * num_channels);
out_img_idy = out_id_w % (out_img_h * out_img_w * num_channels) /
(out_img_w * num_channels);
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}
int in_img_idt = (align_corners)
? static_cast<int>(ratio_d * out_img_idt + 0.5)
: static_cast<int>(ratio_d * out_img_idt);
int in_img_idy = (align_corners)
? static_cast<int>(ratio_h * out_img_idy + 0.5)
: static_cast<int>(ratio_h * out_img_idy);
int in_img_idx = (align_corners)
? static_cast<int>(ratio_w * out_img_idx + 0.5)
: static_cast<int>(ratio_w * out_img_idx);
if (data_layout == DataLayout::kNCHW) {
out[tid] = in[out_id_h * input_w + channel_id * in_img_size +
in_img_idt * in_img_h * in_img_w + in_img_idy * in_img_w +
in_img_idx];
} else {
out[tid] = in[out_id_h * input_w +
in_img_idt * in_img_h * in_img_w * num_channels +
in_img_idy * in_img_w * num_channels +
in_img_idx * num_channels + channel_id];
}
}
}
template <typename T> template <typename T>
__global__ void KeNearestNeighborInterpBw( __global__ void KeNearestNeighborInterpBw(
T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h, T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h,
...@@ -114,6 +169,63 @@ __global__ void KeNearestNeighborInterpBw( ...@@ -114,6 +169,63 @@ __global__ void KeNearestNeighborInterpBw(
} }
} }
template <typename T>
__global__ void KeNearestNeighbor3DInterpBw(
T* in, const size_t in_img_d, const size_t in_img_h, const size_t in_img_w,
const size_t input_h, const size_t input_w, const T* out,
const size_t out_img_d, const size_t out_img_h, const size_t out_img_w,
const size_t output_h, const size_t output_w, const size_t num_channels,
const float ratio_d, const float ratio_h, const float ratio_w,
const bool align_corners, const DataLayout data_layout) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id, out_img_idt, out_img_idy, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idt = (out_id_w % out_img_size) / out_img_h / out_img_w;
out_img_idy = ((out_id_w % out_img_size) / out_img_w) % out_img_h;
out_img_idx = tid % out_img_w;
} else {
out_img_idt = out_id_w / (out_img_h * out_img_w * num_channels);
out_img_idy = out_id_w % (out_img_h * out_img_w * num_channels) /
(out_img_w * num_channels);
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}
int in_img_idt = (align_corners)
? static_cast<int>(ratio_d * out_img_idt + 0.5)
: static_cast<int>(ratio_d * out_img_idt);
int in_img_idy = (align_corners)
? static_cast<int>(ratio_h * out_img_idy + 0.5)
: static_cast<int>(ratio_h * out_img_idy);
int in_img_idx = (align_corners)
? static_cast<int>(ratio_w * out_img_idx + 0.5)
: static_cast<int>(ratio_w * out_img_idx);
T* in_pos;
if (data_layout == DataLayout::kNCHW) {
in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
in_img_idt * in_img_h * in_img_w + in_img_idy * in_img_w +
in_img_idx];
} else {
in_pos = &in[out_id_h * input_w +
in_img_idt * in_img_h * in_img_w * num_channels +
in_img_idy * in_img_w * num_channels +
in_img_idx * num_channels + channel_id];
}
const T out_pos = out[out_id_h * output_w + out_id_w];
platform::CudaAtomicAdd(in_pos, out_pos);
}
}
template <typename T> template <typename T>
__global__ void KeLinearInterpFw(const T* in, const size_t in_img_w, __global__ void KeLinearInterpFw(const T* in, const size_t in_img_w,
const size_t input_w, T* out, const size_t input_w, T* out,
...@@ -1376,6 +1488,13 @@ static void Interpolate3DCUDAFwd(const framework::ExecutionContext& ctx, ...@@ -1376,6 +1488,13 @@ static void Interpolate3DCUDAFwd(const framework::ExecutionContext& ctx,
input_data, in_d, in_h, in_w, n, in_cdhw, output_data, out_d, out_h, input_data, in_d, in_h, in_w, n, in_cdhw, output_data, out_d, out_h,
out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners, out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners,
align_mode, data_layout); align_mode, data_layout);
} else if ("nearest" == interp_method) {
KeNearestNeighbor3DInterpFw<
T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_data, in_d, in_h, in_w, n, in_cdhw, output_data, out_d, out_h,
out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners,
data_layout);
} }
} }
...@@ -1801,6 +1920,13 @@ static void Interpolate3DCUDABwd(const framework::ExecutionContext& ctx, ...@@ -1801,6 +1920,13 @@ static void Interpolate3DCUDABwd(const framework::ExecutionContext& ctx,
input_grad_data, in_d, in_h, in_w, n, in_cdhw, output_grad_data, out_d, input_grad_data, in_d, in_h, in_w, n, in_cdhw, output_grad_data, out_d,
out_h, out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners, out_h, out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners,
align_mode, data_layout); align_mode, data_layout);
} else if ("nearest" == interp_method) {
KeNearestNeighbor3DInterpBw<
T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_grad_data, in_d, in_h, in_w, n, in_cdhw, output_grad_data, out_d,
out_h, out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners,
data_layout);
} }
} }
......
...@@ -121,6 +121,39 @@ static void NearestNeighborInterpolate(const Tensor& input, Tensor* output, ...@@ -121,6 +121,39 @@ static void NearestNeighborInterpolate(const Tensor& input, Tensor* output,
} }
} }
template <typename T>
static void NearestNeighbor3DInterpolate(
const Tensor& input, Tensor* output, const float ratio_d,
const float ratio_h, const float ratio_w, const int n, const int c,
const int out_d, const int out_h, const int out_w, const bool align_corners,
const DataLayout& data_layout) {
auto input_t = EigenTensor<T, 5>::From(input);
auto output_t = EigenTensor<T, 5>::From(*output);
for (int d = 0; d < out_d; d++) { // loop for images
int in_d = (align_corners) ? static_cast<int>(ratio_d * d + 0.5)
: static_cast<int>(ratio_d * d);
for (int k = 0; k < out_h; k++) {
int in_k = (align_corners) ? static_cast<int>(ratio_h * k + 0.5)
: static_cast<int>(ratio_h * k);
for (int l = 0; l < out_w; l++) {
int in_l = (align_corners) ? static_cast<int>(ratio_w * l + 0.5)
: static_cast<int>(ratio_w * l);
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
if (data_layout == DataLayout::kNCHW) {
output_t(i, j, d, k, l) = input_t(i, j, in_d, in_k, in_l);
} else { // NDHWC
output_t(i, d, k, l, j) = input_t(i, in_d, in_k, in_l, j);
}
}
}
}
}
}
}
template <typename T> template <typename T>
static void LinearInterpolation(const Tensor& input, Tensor* output, static void LinearInterpolation(const Tensor& input, Tensor* output,
const float ratio_w, const int in_w, const float ratio_w, const int in_w,
...@@ -584,6 +617,42 @@ static void NearestNeighborInterpolateGrad( ...@@ -584,6 +617,42 @@ static void NearestNeighborInterpolateGrad(
} }
} }
template <typename T>
static void NearestNeighbor3DInterpolateGrad(
const Tensor& output_grad, Tensor* input_grad, const float ratio_d,
const float ratio_h, const float ratio_w, const int n, const int c,
const int out_d, const int out_h, const int out_w, const bool align_corners,
const DataLayout data_layout) {
auto input_grad_t = EigenTensor<T, 5>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 5>::From(output_grad);
for (int d = 0; d < out_d; d++) {
int in_d = (align_corners) ? static_cast<int>(ratio_d * d + 0.5)
: static_cast<int>(ratio_d * d);
for (int k = 0; k < out_h; k++) { // loop for images
int in_k = (align_corners) ? static_cast<int>(ratio_h * k + 0.5)
: static_cast<int>(ratio_h * k);
for (int l = 0; l < out_w; l++) {
int in_l = (align_corners) ? static_cast<int>(ratio_w * l + 0.5)
: static_cast<int>(ratio_w * l);
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
if (data_layout == DataLayout::kNCHW) {
input_grad_t(i, j, in_d, in_k, in_l) +=
output_grad_t(i, j, d, k, l);
} else {
input_grad_t(i, in_d, in_k, in_l, j) +=
output_grad_t(i, d, k, l, j);
}
}
}
}
}
}
}
template <typename T> template <typename T>
static void BilinearInterpolationGrad( static void BilinearInterpolationGrad(
const Tensor& output_grad, Tensor* input_grad, const float ratio_h, const Tensor& output_grad, Tensor* input_grad, const float ratio_h,
...@@ -1137,6 +1206,10 @@ static void Interpolate3DCPUFwd(const framework::ExecutionContext& ctx, ...@@ -1137,6 +1206,10 @@ static void Interpolate3DCPUFwd(const framework::ExecutionContext& ctx,
TrilinearInterpolation<T>(input, output, ratio_d, ratio_h, ratio_w, in_d, TrilinearInterpolation<T>(input, output, ratio_d, ratio_h, ratio_w, in_d,
in_h, in_w, n, c, out_d, out_h, out_w, in_h, in_w, n, c, out_d, out_h, out_w,
align_corners, align_mode, data_layout); align_corners, align_mode, data_layout);
} else if ("nearest" == interp_method) {
NearestNeighbor3DInterpolate<T>(input, output, ratio_d, ratio_h, ratio_w, n,
c, out_d, out_h, out_w, align_corners,
data_layout);
} }
} }
...@@ -1489,6 +1562,10 @@ static void Interpolate3DCPUBwd(const framework::ExecutionContext& ctx, ...@@ -1489,6 +1562,10 @@ static void Interpolate3DCPUBwd(const framework::ExecutionContext& ctx,
TrilinearInterpolationGrad<T>( TrilinearInterpolationGrad<T>(
output_grad, input_grad, ratio_d, ratio_h, ratio_w, in_d, in_h, in_w, n, output_grad, input_grad, ratio_d, ratio_h, ratio_w, in_d, in_h, in_w, n,
c, out_d, out_h, out_w, align_corners, align_mode, data_layout); c, out_d, out_h, out_w, align_corners, align_mode, data_layout);
} else if ("nearest" == interp_method) {
NearestNeighbor3DInterpolateGrad<T>(output_grad, input_grad, ratio_d,
ratio_h, ratio_w, n, c, out_d, out_h,
out_w, align_corners, data_layout);
} }
} }
......
...@@ -23,6 +23,8 @@ import paddle.nn as nn ...@@ -23,6 +23,8 @@ import paddle.nn as nn
import paddle import paddle
from paddle.nn.functional import interpolate from paddle.nn.functional import interpolate
paddle.enable_static()
def nearest_neighbor_interp_np(X, def nearest_neighbor_interp_np(X,
out_h, out_h,
...@@ -78,7 +80,80 @@ def nearest_neighbor_interp_np(X, ...@@ -78,7 +80,80 @@ def nearest_neighbor_interp_np(X,
if data_layout == "NHWC": if data_layout == "NHWC":
out = np.transpose(out, (0, 2, 3, 1)) # NCHW => NHWC out = np.transpose(out, (0, 2, 3, 1)) # NCHW => NHWC
# out = np.expand_dims(out, 2)
return out.astype(X.dtype)
def nearest_neighbor_interp3d_np(X,
out_d,
out_h,
out_w,
scale_d=0,
scale_h=0,
scale_w=0,
out_size=None,
actual_shape=None,
align_corners=True,
data_layout='NCHW'):
"""nearest neighbor interpolation implement in shape [N, C, H, W]"""
if data_layout == "NHWC":
X = np.transpose(X, (0, 4, 1, 2, 3)) # NDHWC => NCDHW
if out_size is not None:
out_d = out_size[0]
out_h = out_size[1]
out_w = out_size[2]
if actual_shape is not None:
out_d = actual_shape[0]
out_h = actual_shape[1]
out_w = actual_shape[2]
n, c, in_d, in_h, in_w = X.shape
ratio_d = ratio_h = ratio_w = 0.0
if (out_d > 1):
if (align_corners):
ratio_d = (in_d - 1.0) / (out_d - 1.0)
else:
if scale_d > 0:
ratio_d = 1.0 / scale_d
else:
ratio_d = 1.0 * in_d / out_d
if (out_h > 1):
if (align_corners):
ratio_h = (in_h - 1.0) / (out_h - 1.0)
else:
if scale_h > 0:
ratio_h = 1.0 / scale_h
else:
ratio_h = 1.0 * in_h / out_h
if (out_w > 1):
if (align_corners):
ratio_w = (in_w - 1.0) / (out_w - 1.0)
else:
if scale_w > 0:
ratio_w = 1.0 / scale_w
else:
ratio_w = 1.0 * in_w / out_w
out = np.zeros((n, c, out_d, out_h, out_w))
if align_corners:
for d in range(out_d):
in_d = int(ratio_d * d + 0.5)
for i in range(out_h):
in_i = int(ratio_h * i + 0.5)
for j in range(out_w):
in_j = int(ratio_w * j + 0.5)
out[:, :, d, i, j] = X[:, :, in_d, in_i, in_j]
else:
for d in range(out_d):
in_d = int(ratio_d * d)
for i in range(out_h):
in_i = int(ratio_h * i)
for j in range(out_w):
in_j = int(ratio_w * j)
out[:, :, d, i, j] = X[:, :, in_d, in_i, in_j]
if data_layout == "NDHWC":
out = np.transpose(out, (0, 2, 3, 4, 1)) # NCDHW => NDHWC
return out.astype(X.dtype) return out.astype(X.dtype)
...@@ -91,44 +166,81 @@ class TestNearestInterpOp(OpTest): ...@@ -91,44 +166,81 @@ class TestNearestInterpOp(OpTest):
self.op_type = "nearest_interp_v2" self.op_type = "nearest_interp_v2"
input_np = np.random.random(self.input_shape).astype("float64") input_np = np.random.random(self.input_shape).astype("float64")
if self.data_layout == "NCHW": if self.data_layout == "NCHW" and len(self.input_shape) == 4:
in_d = 1
in_h = self.input_shape[2] in_h = self.input_shape[2]
in_w = self.input_shape[3] in_w = self.input_shape[3]
else: else:
in_d = 1
in_h = self.input_shape[1] in_h = self.input_shape[1]
in_w = self.input_shape[2] in_w = self.input_shape[2]
if self.data_layout == "NCDHW" and len(self.input_shape) == 5:
in_d = self.input_shape[2]
in_h = self.input_shape[3]
in_w = self.input_shape[4]
else:
in_d = self.input_shape[1]
in_h = self.input_shape[2]
in_w = self.input_shape[3]
scale_d = 0
scale_h = 0 scale_h = 0
scale_w = 0 scale_w = 0
if self.scale: if self.scale:
if isinstance(self.scale, float) or isinstance(self.scale, int): if isinstance(self.scale, float) or isinstance(self.scale, int):
if self.scale > 0: if self.scale > 0:
scale_h = scale_w = float(self.scale) scale_d = scale_h = scale_w = float(self.scale)
if isinstance(self.scale, list) and len(self.scale) == 1: if isinstance(self.scale, list) and len(self.scale) == 1:
scale_w = scale_h = self.scale[0] scale_d = scale_w = scale_h = self.scale[0]
elif isinstance(self.scale, list) and len(self.scale) > 1: elif isinstance(self.scale, list) and len(self.scale) > 1:
scale_w = self.scale[1] if len(self.scale) == 5:
scale_h = self.scale[0] scale_w = self.scale[2]
scale_h = self.scale[1]
scale_d = self.scale[0]
else:
scale_w = self.scale[1]
scale_h = self.scale[0]
out_h = int(in_h * scale_h) out_h = int(in_h * scale_h)
out_w = int(in_w * scale_w) out_w = int(in_w * scale_w)
out_d = int(in_d * scale_d)
else: else:
if len(self.input_shape) == 5:
out_d = self.out_d
out_h = self.out_h out_h = self.out_h
out_w = self.out_w out_w = self.out_w
output_np = nearest_neighbor_interp_np( if len(self.input_shape) == 4:
input_np, out_h, out_w, scale_h, scale_w, self.out_size, output_np = nearest_neighbor_interp_np(
self.actual_shape, self.align_corners, self.data_layout) input_np, out_h, out_w, scale_h, scale_w, self.out_size,
self.actual_shape, self.align_corners, self.data_layout)
elif len(self.input_shape) == 5:
output_np = nearest_neighbor_interp3d_np(
input_np, out_d, out_h, out_w, scale_d, scale_h, scale_w,
self.out_size, self.actual_shape, self.align_corners,
self.data_layout)
self.inputs = {'X': input_np} self.inputs = {'X': input_np}
if self.out_size is not None: if self.out_size is not None:
self.inputs['OutSize'] = self.out_size self.inputs['OutSize'] = self.out_size
if self.actual_shape is not None: if self.actual_shape is not None:
self.inputs['OutSize'] = self.actual_shape self.inputs['OutSize'] = self.actual_shape
self.attrs = { if len(self.input_shape) == 5:
'out_h': self.out_h, self.attrs = {
'out_w': self.out_w, 'out_d': self.out_d,
'interp_method': self.interp_method, 'out_h': self.out_h,
'align_corners': self.align_corners, 'out_w': self.out_w,
'data_layout': self.data_layout 'interp_method': self.interp_method,
} 'align_corners': self.align_corners,
'data_layout': self.data_layout
}
else:
self.attrs = {
'out_h': self.out_h,
'out_w': self.out_w,
'interp_method': self.interp_method,
'align_corners': self.align_corners,
'data_layout': self.data_layout
}
if self.scale: if self.scale:
if isinstance(self.scale, float) or isinstance(self.scale, int): if isinstance(self.scale, float) or isinstance(self.scale, int):
if self.scale > 0: if self.scale > 0:
...@@ -157,7 +269,8 @@ class TestNearestInterpOp(OpTest): ...@@ -157,7 +269,8 @@ class TestNearestInterpOp(OpTest):
class TestNearestNeighborInterpCase1(TestNearestInterpOp): class TestNearestNeighborInterpCase1(TestNearestInterpOp):
def init_test_case(self): def init_test_case(self):
self.interp_method = 'nearest' self.interp_method = 'nearest'
self.input_shape = [4, 1, 7, 8] self.input_shape = [4, 1, 1, 7, 8]
self.out_d = 1
self.out_h = 1 self.out_h = 1
self.out_w = 1 self.out_w = 1
self.scale = 0. self.scale = 0.
...@@ -366,6 +479,18 @@ class TestNearestNeighborInterpScale3(TestNearestInterpOp): ...@@ -366,6 +479,18 @@ class TestNearestNeighborInterpScale3(TestNearestInterpOp):
self.align_corners = True self.align_corners = True
class TestNearestNeighbor3DInterp(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 2, 4, 7, 5]
self.out_d = 8
self.out_h = 64
self.out_w = 32
self.scale = [4.0, 2.0, 3.0]
self.out_size = np.array([8, 66, 40]).astype("int32")
self.align_corners = True
class TestNearestInterpOp_attr_tensor(OpTest): class TestNearestInterpOp_attr_tensor(OpTest):
def setUp(self): def setUp(self):
self.out_size = None self.out_size = None
...@@ -549,8 +674,32 @@ class TestNearestInterpOpAPI_dy(unittest.TestCase): ...@@ -549,8 +674,32 @@ class TestNearestInterpOpAPI_dy(unittest.TestCase):
self.assertTrue(np.allclose(out.numpy(), expect_res)) self.assertTrue(np.allclose(out.numpy(), expect_res))
class TestNearestInterp3DOpAPI_dy(unittest.TestCase):
def test_case(self):
import paddle
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
with fluid.dygraph.guard(place):
input_data = np.random.random((2, 2, 6, 6, 6)).astype("int64")
scale_np = np.array([2, 2, 2]).astype("int64")
input_x = paddle.to_tensor(input_data)
scale = paddle.to_tensor(scale_np)
expect_res = nearest_neighbor_interp3d_np(
input_data, out_d=12, out_h=12, out_w=12, align_corners=False)
out = interpolate(
x=input_x,
scale_factor=scale,
mode="nearest",
align_corners=False,
data_format="NCDHW")
self.assertTrue(np.allclose(out.numpy(), expect_res))
class TestNearestInterpException(unittest.TestCase): class TestNearestInterpException(unittest.TestCase):
def test_exception(self): def test_exception(self):
import paddle
input = fluid.data(name="input", shape=[1, 3, 6, 6], dtype="float32") input = fluid.data(name="input", shape=[1, 3, 6, 6], dtype="float32")
def attr_data_format(): def attr_data_format():
...@@ -564,9 +713,20 @@ class TestNearestInterpException(unittest.TestCase): ...@@ -564,9 +713,20 @@ class TestNearestInterpException(unittest.TestCase):
def attr_scale_value(): def attr_scale_value():
out = fluid.layers.resize_nearest(input, scale=-0.3) out = fluid.layers.resize_nearest(input, scale=-0.3)
def input_shape_error():
x = paddle.randn([1, 3])
out = paddle.nn.functional.interpolate(x, scale_factor='scale')
def mode_error():
x = paddle.randn([1, 3])
out = paddle.nn.functional.interpolate(
x, scale_factor='scale', mode="BILINEAR")
self.assertRaises(ValueError, attr_data_format) self.assertRaises(ValueError, attr_data_format)
self.assertRaises(TypeError, attr_scale_type) self.assertRaises(TypeError, attr_scale_type)
self.assertRaises(ValueError, attr_scale_value) self.assertRaises(ValueError, attr_scale_value)
self.assertRaises(ValueError, input_shape_error)
self.assertRaises(ValueError, mode_error)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -221,7 +221,8 @@ def interpolate(x, ...@@ -221,7 +221,8 @@ def interpolate(x,
ValueError: The 'mode' of image_resize can only be 'linear', 'bilinear', ValueError: The 'mode' of image_resize can only be 'linear', 'bilinear',
'trilinear', 'bicubic', 'area' or 'nearest' currently. 'trilinear', 'bicubic', 'area' or 'nearest' currently.
ValueError: 'linear' only support 3-D tensor. ValueError: 'linear' only support 3-D tensor.
ValueError: 'bilinear', 'bicubic' and 'nearest' only support 4-D tensor. ValueError: 'bilinear' and 'bicubic' only support 4-D tensor.
ValueError: 'nearest' only support 4-D or 5-D tensor.
ValueError: 'trilinear' only support 5-D tensor. ValueError: 'trilinear' only support 5-D tensor.
ValueError: One of size and scale_factor must not be None. ValueError: One of size and scale_factor must not be None.
ValueError: size length should be 1 for input 3-D tensor. ValueError: size length should be 1 for input 3-D tensor.
...@@ -276,9 +277,11 @@ def interpolate(x, ...@@ -276,9 +277,11 @@ def interpolate(x,
if resample in ['LINEAR'] and len(x.shape) != 3: if resample in ['LINEAR'] and len(x.shape) != 3:
raise ValueError("'linear' only support 3-D tensor.") raise ValueError("'linear' only support 3-D tensor.")
if resample in ['BILINEAR', 'NEAREST', 'BICUBIC'] and len(x.shape) != 4: if resample in ['NEAREST'] and len(x.shape) != 4 and len(x.shape) != 5:
raise ValueError( raise ValueError("'NEAREST' only support 4-D or 5-D tensor.")
"'bilinear', 'bicubic' and 'nearest' only support 4-D tensor.")
if resample in ['BILINEAR', 'BICUBIC'] and len(x.shape) != 4:
raise ValueError("'bilinear' and 'bicubic' only support 4-D tensor.")
if resample == 'TRILINEAR' and len(x.shape) != 5: if resample == 'TRILINEAR' and len(x.shape) != 5:
raise ValueError("'trilinear'only support 5-D tensor.") raise ValueError("'trilinear'only support 5-D tensor.")
......
...@@ -359,8 +359,9 @@ class Upsample(Layer): ...@@ -359,8 +359,9 @@ class Upsample(Layer):
ValueError: The 'mode' of image_resize can only be 'linear', 'bilinear', ValueError: The 'mode' of image_resize can only be 'linear', 'bilinear',
'trilinear', 'bicubic', or 'nearest' currently. 'trilinear', 'bicubic', or 'nearest' currently.
ValueError: 'linear' only support 3-D tensor. ValueError: 'linear' only support 3-D tensor.
ValueError: 'bilinear', 'bicubic' and 'nearest' only support 4-D tensor. ValueError: 'bilinear' and 'bicubic' only support 4-D tensor.
ValueError: 'trilinear' only support 5-D tensor. ValueError: 'trilinear' only support 5-D tensor.
ValueError: 'nearest' only support 4-D or 5-D tensor.
ValueError: One of size and scale_factor must not be None. ValueError: One of size and scale_factor must not be None.
ValueError: size length should be 1 for input 3-D tensor. ValueError: size length should be 1 for input 3-D tensor.
ValueError: size length should be 2 for input 4-D tensor. ValueError: size length should be 2 for input 4-D tensor.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册