未验证 提交 cf6f28f9 编写于 作者: X xiaoting 提交者: GitHub

[Cherry-pick Release 2.0] Add `nn.interpolate ` (#23434) (#23843)

* Add `nn.interpolate ` (#23434)
上级 5743cb83
......@@ -26,7 +26,8 @@ static void Interpolate2DInferShapeCheck(framework::InferShapeContext* ctx) {
auto interp_method = ctx->Attrs().Get<std::string>("interp_method");
PADDLE_ENFORCE(
"bilinear" == interp_method || "nearest" == interp_method,
"bilinear" == interp_method || "nearest" == interp_method ||
"bicubic" == interp_method,
"Interpolation method can only be \"bilinear\" or \"nearest\" when "
"Input(X) dimension is 4");
const DataLayout data_layout = framework::StringToDataLayout(
......@@ -264,7 +265,8 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
"method, can be \"bilinear\" for "
"bilinear interpolation, \"trilinear\" for trilinear "
"interpolation and \"nearest\" for nearest "
"neighbor interpolation.")
"neighbor interpolation, and \"bicubic\" for bicubic"
"interpolation.")
.SetDefault("bilinear");
AddAttr<bool>(
"align_corners",
......@@ -299,6 +301,11 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
H-direction and W-direction in this op) on a rectilinear 3D grid.
The linear interpolation is performed on three directions.
Bicubic interpolation is an extension of cubic interpolation for interpolating
data points on a two-dimensional regular grid. The interpolated surface is
smoother than corresponding surfaces obtained by bilinear interpolation or
nearest-neighbor interpolation.
Align_corners and align_mode are optional parameters,the calculation method
of interpolation can be selected by them.
......@@ -377,6 +384,19 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
H_out = H_{in} * scale_{factor}
W_out = W_{in} * scale_{factor}
Bicubic interpolation:
if:
align_corners = False
input : (N,C,H_in,W_in)
output: (N,C,H_out,W_out) where:
H_out = (H_{in}+0.5) * scale_{factor} - 0.5
W_out = (W_{in}+0.5) * scale_{factor} - 0.5
else:
input : (N,C,H_in,W_in)
output: (N,C,H_out,W_out) where:
H_out = H_{in} * scale_{factor}
W_out = W_{in} * scale_{factor}
For details of nearest neighbor interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation
......@@ -386,6 +406,9 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
For details of trilinear interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Trilinear_interpolation
For details of bicubic interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Bicubic_interpolation
)DOC");
}
};
......@@ -469,6 +492,11 @@ REGISTER_OPERATOR(trilinear_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
ops::InterpolateGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(trilinear_interp_grad, ops::InterpolateOpGrad,
ops::InterpolateGradNoNeedBufferVarsInference);
REGISTER_OPERATOR(bicubic_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
ops::InterpolateGradMaker<paddle::framework::OpDesc>,
ops::InterpolateGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(bicubic_interp_grad, ops::InterpolateOpGrad,
ops::InterpolateGradNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::InterpolateKernel<float>,
ops::InterpolateKernel<double>,
ops::InterpolateKernel<uint8_t>);
......@@ -484,3 +512,7 @@ REGISTER_OP_CPU_KERNEL(trilinear_interp, ops::InterpolateKernel<float>,
ops::InterpolateKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(trilinear_interp_grad, ops::InterpolateGradKernel<float>,
ops::InterpolateGradKernel<double>);
REGISTER_OP_CPU_KERNEL(bicubic_interp, ops::InterpolateKernel<float>,
ops::InterpolateKernel<double>);
REGISTER_OP_CPU_KERNEL(bicubic_interp_grad, ops::InterpolateGradKernel<float>,
ops::InterpolateGradKernel<double>);
......@@ -506,6 +506,206 @@ __global__ void KeTrilinearInterpBw(
}
}
template <typename T>
__device__ __forceinline__ static T Kecubic_interp(const T x0, const T x1,
const T x2, const T x3,
T t) {
T coeffs[4];
T a = -0.75;
T x_1 = t;
T x_2 = 1.0 - t;
coeffs[0] = cubic_convolution2<T>(x_1 + 1.0, a);
coeffs[1] = cubic_convolution1<T>(x_1, a);
coeffs[2] = cubic_convolution1<T>(x_2, a);
coeffs[3] = cubic_convolution2<T>(x_2 + 1.0, a);
return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
}
template <typename T>
__global__ void KeBicubicInterpFw(
const T* in, const size_t in_img_h, const size_t in_img_w,
const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
const size_t out_img_w, const size_t output_h, const size_t output_w,
const size_t num_channels, const float ratio_h, const float ratio_w,
const bool align_corners, const DataLayout data_layout) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id, out_img_idy, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idy = (out_id_w % out_img_size) / out_img_w;
out_img_idx = tid % out_img_w;
} else {
out_img_idy = out_id_w / (out_img_w * num_channels);
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}
T in_img_idy = align_corners
? static_cast<T>(ratio_h * out_img_idy)
: static_cast<T>(ratio_h * (out_img_idy + 0.5) - 0.5);
int input_y = floorf(in_img_idy);
const T y_t = in_img_idy - input_y;
T in_img_idx = align_corners
? static_cast<T>(ratio_w * out_img_idx)
: static_cast<T>(ratio_w * (out_img_idx + 0.5) - 0.5);
int input_x = floorf(in_img_idx);
const T x_t = in_img_idx - input_x;
T coefficients[4];
const T* in_pos_0;
const T* in_pos_1;
const T* in_pos_2;
const T* in_pos_3;
int access_x_0;
if (data_layout == DataLayout::kNCHW) {
for (int k = 0; k < 4; k++) {
int access_y =
max(min(input_y - 1 + k, static_cast<int>(in_img_h - 1)), 0);
access_x_0 = max(min(input_x - 1, static_cast<int>(in_img_w - 1)), 0);
int access_x_1 =
max(min(input_x + 0, static_cast<int>(in_img_w - 1)), 0);
int access_x_2 =
max(min(input_x + 1, static_cast<int>(in_img_w - 1)), 0);
int access_x_3 =
max(min(input_x + 2, static_cast<int>(in_img_w - 1)), 0);
in_pos_0 = &in[out_id_h * input_w + channel_id * in_img_size +
access_y * in_img_w + access_x_0];
in_pos_1 = &in[out_id_h * input_w + channel_id * in_img_size +
access_y * in_img_w + access_x_1];
in_pos_2 = &in[out_id_h * input_w + channel_id * in_img_size +
access_y * in_img_w + access_x_2];
in_pos_3 = &in[out_id_h * input_w + channel_id * in_img_size +
access_y * in_img_w + access_x_3];
coefficients[k] = Kecubic_interp<T>(in_pos_0[0], in_pos_1[0],
in_pos_2[0], in_pos_3[0], x_t);
}
out[out_id_h * output_w + out_id_w] =
Kecubic_interp<T>(coefficients[0], coefficients[1], coefficients[2],
coefficients[3], y_t);
} else {
for (int k = 0; k < 4; k++) {
int access_y =
max(min(input_y - 1 + k, static_cast<int>((in_img_h - 1))), 0);
int access_x_0 =
max(min(input_x - 1, static_cast<int>((in_img_w - 1))), 0);
int access_x_1 =
max(min(input_x + 0, static_cast<int>((in_img_w - 1))), 0);
int access_x_2 =
max(min(input_x + 1, static_cast<int>((in_img_w - 1))), 0);
int access_x_3 =
max(min(input_x + 2, static_cast<int>((in_img_w - 1))), 0);
const T* in_pos_0 =
&in[out_id_h * input_w + access_y * in_img_w * num_channels +
access_x_0 * num_channels + channel_id];
const T* in_pos_1 =
&in[out_id_h * input_w + access_y * in_img_w * num_channels +
access_x_1 * num_channels + channel_id];
const T* in_pos_2 =
&in[out_id_h * input_w + access_y * in_img_w * num_channels +
access_x_2 * num_channels + channel_id];
const T* in_pos_3 =
&in[out_id_h * input_w + access_y * in_img_w * num_channels +
access_x_3 * num_channels + channel_id];
coefficients[k] = Kecubic_interp(in_pos_0[0], in_pos_1[0], in_pos_2[0],
in_pos_3[0], x_t);
}
out[out_id_h * output_w + out_id_w] =
static_cast<T>(Kecubic_interp(coefficients[0], coefficients[1],
coefficients[2], coefficients[3], y_t));
}
}
}
template <typename T>
__global__ void KeBicubicInterpBw(
T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h,
const size_t input_w, const T* out, const size_t out_img_h,
const size_t out_img_w, const size_t output_h, const size_t output_w,
const size_t num_channels, const float ratio_h, const float ratio_w,
const bool align_corners, const DataLayout data_layout) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id, out_img_idy, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idy = (out_id_w % out_img_size) / out_img_w;
out_img_idx = tid % out_img_w;
} else {
out_img_idy = out_id_w / (out_img_w * num_channels);
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}
T in_img_idy = align_corners
? static_cast<T>(ratio_h * out_img_idy)
: static_cast<T>(ratio_h * (out_img_idy + 0.5) - 0.5);
int input_y = floorf(in_img_idy);
const T y_t = in_img_idy - input_y;
T in_img_idx = align_corners
? static_cast<T>(ratio_w * out_img_idx)
: static_cast<T>(ratio_w * (out_img_idx + 0.5) - 0.5);
int input_x = floorf(in_img_idx);
const T x_t = in_img_idx - input_x;
T x_coeffs[4];
T y_coeffs[4];
get_cubic_upsample_coefficients(x_coeffs, x_t);
get_cubic_upsample_coefficients(y_coeffs, y_t);
const T* out_pos = &out[out_id_h * output_w + out_id_w];
T* in_pos;
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 4; j++) {
int access_y = max(min(static_cast<int>(input_y - 1 + j),
static_cast<int>(in_img_h - 1)),
0);
int access_x = max(min(static_cast<int>(input_x - 1 + i),
static_cast<int>(in_img_w - 1)),
0);
if (data_layout == DataLayout::kNCHW) {
in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
access_y * in_img_w + access_x];
} else {
in_pos = &in[out_id_h * input_w + access_y * in_img_w * num_channels +
access_x * num_channels + channel_id];
}
platform::CudaAtomicAdd(&in_pos[0],
(out_pos[0] * y_coeffs[j] * x_coeffs[i]));
}
}
}
}
template <typename T>
static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx,
const Tensor& input, Tensor* output) {
......@@ -602,6 +802,11 @@ static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx,
ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w, align_corners, align_mode, data_layout);
} else if ("bicubic" == interp_method) {
KeBicubicInterpFw<
T><<<config.blocks, 512, 0, ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
}
}
......@@ -806,6 +1011,11 @@ static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx,
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w,
n, out_chw, c, ratio_h, ratio_w, align_corners, align_mode,
data_layout);
} else if ("bicubic" == interp_method) {
KeBicubicInterpBw<
T><<<config.blocks, 512, 0, ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w,
n, out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
}
}
......@@ -968,3 +1178,9 @@ REGISTER_OP_CUDA_KERNEL(trilinear_interp, ops::InterpolateOpCUDAKernel<float>,
REGISTER_OP_CUDA_KERNEL(trilinear_interp_grad,
ops::InterpolateGradOpCUDAKernel<float>,
ops::InterpolateGradOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(bicubic_interp, ops::InterpolateOpCUDAKernel<float>,
ops::InterpolateOpCUDAKernel<double>,
ops::InterpolateOpCUDAKernel<int>);
REGISTER_OP_CUDA_KERNEL(bicubic_interp_grad,
ops::InterpolateGradOpCUDAKernel<float>,
ops::InterpolateGradOpCUDAKernel<double>);
......@@ -10,10 +10,12 @@
limitations under the License. */
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/hostdevice.h"
namespace paddle {
namespace operators {
......@@ -342,6 +344,106 @@ static void TrilinearInterpolation(
}
}
template <typename T>
HOSTDEVICE inline T cubic_convolution1(T x, T A) {
return ((A + 2) * x - (A + 3)) * x * x + 1;
}
template <typename T>
HOSTDEVICE inline T cubic_convolution2(T x, T A) {
return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A;
}
template <typename T>
HOSTDEVICE inline void get_cubic_upsample_coefficients(T coeffs[4], T t) {
T A = -0.75;
T x1 = t;
coeffs[0] = cubic_convolution2<T>(x1 + 1.0, A);
coeffs[1] = cubic_convolution1<T>(x1, A);
// opposite coefficients
T x2 = 1.0 - t;
coeffs[2] = cubic_convolution1<T>(x2, A);
coeffs[3] = cubic_convolution2<T>(x2 + 1.0, A);
}
template <typename T>
static inline T cubic_interp(T x0, T x1, T x2, T x3, T t) {
T coeffs[4];
get_cubic_upsample_coefficients<T>(coeffs, t);
return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
}
template <typename T>
static void BicubicInterpolation(const Tensor& input, Tensor* output,
const float ratio_h, const float ratio_w,
const int in_h, const int in_w, const int n,
const int c, const int out_h, const int out_w,
const bool align_corners,
const DataLayout data_layout) {
auto input_t = EigenTensor<T, 4>::From(input);
auto output_t = EigenTensor<T, 4>::From(*output);
for (int k = 0; k < out_h; k++) { // loop for images
T y_n = align_corners ? static_cast<T>(ratio_h * k)
: static_cast<T>(ratio_h * (k + 0.5) - 0.5);
int input_y = floorf(y_n);
const T y_t = y_n - input_y;
for (int l = 0; l < out_w; l++) {
T x_n = align_corners ? static_cast<T>(ratio_w * l)
: static_cast<T>(ratio_w * (l + 0.5) - 0.5);
int input_x = floorf(x_n);
const T x_t = x_n - input_x;
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
T coefficients[4];
// interp 4 times in x direction
for (int ii = 0; ii < 4; ii++) {
int access_y = std::max(std::min(input_y - 1 + ii, in_h - 1),
static_cast<int>(0));
int access_x_0 =
std::max(std::min(input_x - 1, in_w - 1), static_cast<int>(0));
int access_x_1 =
std::max(std::min(input_x + 0, in_w - 1), static_cast<int>(0));
int access_x_2 =
std::max(std::min(input_x + 1, in_w - 1), static_cast<int>(0));
int access_x_3 =
std::max(std::min(input_x + 2, in_w - 1), static_cast<int>(0));
if (data_layout == DataLayout::kNCHW) {
coefficients[ii] =
cubic_interp<T>(input_t(i, j, access_y, access_x_0),
input_t(i, j, access_y, access_x_1),
input_t(i, j, access_y, access_x_2),
input_t(i, j, access_y, access_x_3), x_t);
} else {
coefficients[ii] =
cubic_interp<T>(input_t(i, access_y, access_x_0, j),
input_t(i, access_y, access_x_1, j),
input_t(i, access_y, access_x_2, j),
input_t(i, access_y, access_x_3, j), x_t);
}
}
// interp y direction
if (data_layout == DataLayout::kNCHW) {
output_t(i, j, k, l) =
cubic_interp<T>(coefficients[0], coefficients[1],
coefficients[2], coefficients[3], y_t);
} else {
output_t(i, k, l, j) =
cubic_interp<T>(coefficients[0], coefficients[1],
coefficients[2], coefficients[3], y_t);
}
}
}
}
}
}
template <typename T>
static void NearestNeighborInterpolateGrad(
const Tensor& output_grad, Tensor* input_grad, const float ratio_h,
......@@ -509,6 +611,61 @@ static void TrilinearInterpolationGrad(
}
}
template <typename T>
static void BicubicInterpolationGrad(const Tensor& output_grad,
Tensor* input_grad, const float ratio_h,
const float ratio_w, const int in_h,
const int in_w, const int n, const int c,
const int out_h, const int out_w,
const bool align_corners,
const DataLayout data_layout) {
auto input_grad_t = EigenTensor<T, 4>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 4>::From(output_grad);
for (int k = 0; k < out_h; k++) { // loop for images
T y_n = align_corners ? static_cast<T>(ratio_h * k)
: static_cast<T>(ratio_h * (k + 0.5) - 0.5);
int input_y = floorf(y_n);
T y_t = y_n - input_y;
for (int l = 0; l < out_w; l++) {
T x_n = align_corners ? static_cast<T>(ratio_w * l)
: static_cast<T>(ratio_w * (l + 0.5) - 0.5);
int input_x = floorf(x_n);
T x_t = x_n - input_x;
T x_coeffs[4];
T y_coeffs[4];
get_cubic_upsample_coefficients<T>(x_coeffs, x_t);
get_cubic_upsample_coefficients<T>(y_coeffs, y_t);
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
// bicubic interpolation grad
for (int ii = 0; ii < 4; ii++) {
for (int jj = 0; jj < 4; jj++) {
int access_x = std::max(std::min(input_x - 1 + ii, in_w - 1),
static_cast<int>(0));
int access_y = std::max(std::min(input_y - 1 + jj, in_h - 1),
static_cast<int>(0));
if (data_layout == DataLayout::kNCHW) {
T grad = output_grad_t(i, j, k, l);
input_grad_t(i, j, access_y, access_x) +=
grad * y_coeffs[jj] * x_coeffs[ii];
} else {
T grad = output_grad_t(i, k, l, j);
input_grad_t(i, access_y, access_x, j) +=
grad * y_coeffs[jj] * x_coeffs[ii];
}
}
}
}
}
}
}
}
template <typename T>
static void Interpolate2DCPUFwd(const framework::ExecutionContext& ctx,
const Tensor& input, Tensor* output) {
......@@ -587,6 +744,9 @@ static void Interpolate2DCPUFwd(const framework::ExecutionContext& ctx,
} else if ("nearest" == interp_method) {
NearestNeighborInterpolate<T>(input, output, ratio_h, ratio_w, n, c, out_h,
out_w, align_corners, data_layout);
} else if ("bicubic" == interp_method) {
BicubicInterpolation<T>(input, output, ratio_h, ratio_w, in_h, in_w, n, c,
out_h, out_w, align_corners, data_layout);
}
}
......@@ -759,6 +919,10 @@ static void Interpolate2DCPUBwd(const framework::ExecutionContext& ctx,
NearestNeighborInterpolateGrad<T>(output_grad, input_grad, ratio_h, ratio_w,
n, c, out_h, out_w, align_corners,
data_layout);
} else if ("bicubic" == interp_method) {
BicubicInterpolationGrad<T>(output_grad, input_grad, ratio_h, ratio_w, in_h,
in_w, n, c, out_h, out_w, align_corners,
data_layout);
}
}
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
import paddle.fluid as fluid
import paddle
from paddle.fluid import Program, program_guard
from paddle.nn.functional import *
def cubic_1(x, a):
return ((a + 2) * x - (a + 3)) * x * x + 1
def cubic_2(x, a):
return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a
def cubic_interp1d(x0, x1, x2, x3, t):
param = [0, 0, 0, 0]
a = -0.75
x_1 = t
x_2 = 1.0 - t
param[0] = cubic_2(x_1 + 1.0, a)
param[1] = cubic_1(x_1, a)
param[2] = cubic_1(x_2, a)
param[3] = cubic_2(x_2 + 1.0, a)
return x0 * param[0] + x1 * param[1] + x2 * param[2] + x3 * param[3]
def value_bound(input, w, h, x, y):
access_x = int(max(min(x, w - 1), 0))
access_y = int(max(min(y, h - 1), 0))
return input[:, :, access_y, access_x]
def bicubic_interp_np(input,
out_h,
out_w,
out_size=None,
actual_shape=None,
align_corners=True,
data_layout='kNCHW'):
"""trilinear interpolation implement in shape [N, C, H, W]"""
if data_layout == "NHWC":
input = np.transpose(input, (0, 3, 1, 2)) # NHWC => NCHW
if out_size is not None:
out_h = out_size[0]
out_w = out_size[1]
if actual_shape is not None:
out_h = actual_shape[0]
out_w = actual_shape[1]
batch_size, channel, in_h, in_w = input.shape
ratio_h = ratio_w = 0.0
if out_h > 1:
if (align_corners):
ratio_h = (in_h - 1.0) / (out_h - 1.0)
else:
ratio_h = 1.0 * in_h / out_h
if out_w > 1:
if (align_corners):
ratio_w = (in_w - 1.0) / (out_w - 1.0)
else:
ratio_w = 1.0 * in_w / out_w
out = np.zeros((batch_size, channel, out_h, out_w))
for k in range(out_h):
if (align_corners):
h = ratio_h * k
else:
h = ratio_h * (k + 0.5) - 0.5
input_y = np.floor(h)
y_t = h - input_y
for l in range(out_w):
if (align_corners):
w = ratio_w * l
else:
w = ratio_w * (l + 0.5) - 0.5
input_x = np.floor(w)
x_t = w - input_x
for i in range(batch_size):
for j in range(channel):
coefficients = [0, 0, 0, 0]
for ii in range(4):
access_x_0 = int(max(min(input_x - 1, in_w - 1), 0))
access_x_1 = int(max(min(input_x + 0, in_w - 1), 0))
access_x_2 = int(max(min(input_x + 1, in_w - 1), 0))
access_x_3 = int(max(min(input_x + 2, in_w - 1), 0))
access_y = int(max(min(input_y - 1 + ii, in_h - 1), 0))
coefficients[ii] = cubic_interp1d(
input[i, j, access_y, access_x_0],
input[i, j, access_y, access_x_1],
input[i, j, access_y, access_x_2],
input[i, j, access_y, access_x_3], x_t)
out[i, j, k, l] = cubic_interp1d(
coefficients[0], coefficients[1], coefficients[2],
coefficients[3], y_t)
if data_layout == "NHWC":
out = np.transpose(out, (0, 2, 3, 1)) # NCHW => NHWC
return out.astype(input.dtype)
class TestBicubicInterpOp(OpTest):
def setUp(self):
self.out_size = None
self.actual_shape = None
self.data_layout = 'NCHW'
self.init_test_case()
self.op_type = "bicubic_interp"
input_np = np.random.random(self.input_shape).astype("float64")
if self.data_layout == "NCHW":
in_h = self.input_shape[2]
in_w = self.input_shape[3]
else:
in_h = self.input_shape[1]
in_w = self.input_shape[2]
if self.scale > 0:
out_h = int(in_h * self.scale)
out_w = int(in_w * self.scale)
else:
out_h = self.out_h
out_w = self.out_w
output_np = bicubic_interp_np(input_np, out_h, out_w, self.out_size,
self.actual_shape, self.align_corners,
self.data_layout)
self.inputs = {'X': input_np}
if self.out_size is not None:
self.inputs['OutSize'] = self.out_size
if self.actual_shape is not None:
self.inputs['OutSize'] = self.actual_shape
self.attrs = {
'out_h': self.out_h,
'out_w': self.out_w,
'scale': self.scale,
'interp_method': self.interp_method,
'align_corners': self.align_corners,
'data_layout': self.data_layout
}
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out', in_place=True)
def init_test_case(self):
self.interp_method = 'bicubic'
self.input_shape = [2, 3, 5, 5]
self.out_h = 2
self.out_w = 2
self.scale = 0.
self.out_size = np.array([3, 3]).astype("int32")
self.align_corners = True
class TestBicubicInterpCase1(TestBicubicInterpOp):
def init_test_case(self):
self.interp_method = 'bicubic'
self.input_shape = [4, 1, 7, 8]
self.out_h = 1
self.out_w = 1
self.scale = 0.
self.align_corners = True
class TestBicubicInterpCase2(TestBicubicInterpOp):
def init_test_case(self):
self.interp_method = 'bicubic'
self.input_shape = [3, 3, 9, 6]
self.out_h = 10
self.out_w = 8
self.scale = 0.
self.align_corners = True
class TestBicubicInterpCase3(TestBicubicInterpOp):
def init_test_case(self):
self.interp_method = 'bicubic'
self.input_shape = [1, 1, 32, 64]
self.out_h = 64
self.out_w = 32
self.scale = 0.
self.align_corners = False
class TestBicubicInterpCase4(TestBicubicInterpOp):
def init_test_case(self):
self.interp_method = 'bicubic'
self.input_shape = [4, 1, 7, 8]
self.out_h = 1
self.out_w = 1
self.scale = 0.
self.out_size = np.array([2, 2]).astype("int32")
self.align_corners = True
class TestBicubicInterpCase5(TestBicubicInterpOp):
def init_test_case(self):
self.interp_method = 'bicubic'
self.input_shape = [3, 3, 9, 6]
self.out_h = 11
self.out_w = 11
self.scale = 0.
self.out_size = np.array([6, 4]).astype("int32")
self.align_corners = False
class TestBicubicInterpCase6(TestBicubicInterpOp):
def init_test_case(self):
self.interp_method = 'bicubic'
self.input_shape = [1, 1, 32, 64]
self.out_h = 64
self.out_w = 32
self.scale = 0
self.out_size = np.array([64, 32]).astype("int32")
self.align_corners = False
class TestBicubicInterpSame(TestBicubicInterpOp):
def init_test_case(self):
self.interp_method = 'bicubic'
self.input_shape = [2, 3, 32, 64]
self.out_h = 32
self.out_w = 64
self.scale = 0.
self.align_corners = True
class TestBicubicInterpDataLayout(TestBicubicInterpOp):
def init_test_case(self):
self.interp_method = 'bicubic'
self.input_shape = [2, 5, 5, 3]
self.out_h = 2
self.out_w = 2
self.scale = 0.
self.out_size = np.array([3, 3]).astype("int32")
self.align_corners = True
self.data_layout = "NHWC"
class TestBicubicInterpOpAPI(unittest.TestCase):
def test_case(self):
x_data = np.random.random((2, 3, 6, 6)).astype("float32")
dim_data = np.array([12]).astype("int32")
shape_data = np.array([12, 12]).astype("int32")
actual_size_data = np.array([12, 12]).astype("int32")
scale_data = np.array([2.0]).astype("float32")
prog = fluid.Program()
startup_prog = fluid.Program()
place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
) else fluid.CPUPlace()
with fluid.program_guard(prog, startup_prog):
x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32")
dim = fluid.data(name="dim", shape=[1], dtype="int32")
shape_tensor = fluid.data(
name="shape_tensor", shape=[2], dtype="int32")
actual_size = fluid.data(
name="actual_size", shape=[2], dtype="int32")
scale_tensor = fluid.data(
name="scale_tensor", shape=[1], dtype="float32")
out1 = interpolate(
x, out_shape=[12, 12], resample='BICUBIC', align_corners=False)
out2 = interpolate(
x, out_shape=[12, dim], resample='BICUBIC', align_corners=False)
out3 = interpolate(
x,
out_shape=shape_tensor,
resample='BICUBIC',
align_corners=False)
out4 = interpolate(
x,
out_shape=[4, 4],
actual_shape=actual_size,
resample='BICUBIC',
align_corners=False)
out5 = interpolate(
x, scale=scale_tensor, resample='BICUBIC', align_corners=False)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
results = exe.run(fluid.default_main_program(),
feed={
"x": x_data,
"dim": dim_data,
"shape_tensor": shape_data,
"actual_size": actual_size_data,
"scale_tensor": scale_data
},
fetch_list=[out1, out2, out3, out4, out5],
return_numpy=True)
expect_res = bicubic_interp_np(
x_data, out_h=12, out_w=12, align_corners=False)
for res in results:
self.assertTrue(np.allclose(res, expect_res))
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(x_data)
interp = interpolate(
x, out_shape=[12, 12], resample='BICUBIC', align_corners=False)
dy_result = interp.numpy()
expect = bicubic_interp_np(
x_data, out_h=12, out_w=12, align_corners=False)
self.assertTrue(np.allclose(dy_result, expect))
class TestBicubicOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
# the input of interpoalte must be Variable.
x1 = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace())
self.assertRaises(TypeError, interpolate, x1)
def test_mode_type():
# mode must be "BILINEAR" "TRILINEAR" "NEAREST" "BICUBIC"
x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32")
out = interpolate(
x,
out_shape=[12, 12],
resample='UNKONWN',
align_corners=False)
def test_input_shape():
x = fluid.data(name="x", shape=[2], dtype="float32")
out = interpolate(
x,
out_shape=[12, 12],
resample='BICUBIC',
align_corners=False)
def test_align_corcers():
x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32")
interpolate(
x, out_shape=[12, 12], resample='BICUBIC', align_corners=3)
def test_out_shape():
x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32")
out = interpolate(
x, out_shape=[12], resample='BICUBIC', align_corners=False)
def test_attr_data_format():
# for 5-D input, data_format only can be NCDHW or NDHWC
input = fluid.data(
name="input", shape=[2, 3, 6, 9, 4], dtype="float32")
out = interpolate(
input,
out_shape=[4, 8, 4, 5],
resample='TRILINEAR',
data_format='NHWC')
def test_actual_shape():
# the actual_shape must be Variable.
x = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace())
out = interpolate(
x,
out_shape=[12, 12],
resample='BICUBIC',
align_corners=False)
def test_scale_value():
# the scale must be greater than zero.
x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32")
out = interpolate(
x,
out_shape=None,
resample='BICUBIC',
align_corners=False,
scale=-2.0)
def test_attr_5D_input():
# for 5-D input, data_format only can be NCDHW or NDHWC
input = fluid.data(
name="input", shape=[2, 3, 6, 9, 4], dtype="float32")
out = interpolate(
input,
out_shape=[4, 8, 4, 5],
resample='TRILINEAR',
data_format='NDHWC')
def test_scale_type():
# the scale must be greater than zero.
x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32")
scale = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace())
out = interpolate(
x,
out_shape=None,
resample='BICUBIC',
align_corners=False,
scale=scale)
def test_align_mode():
x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32")
out = interpolate(
x,
out_shape=None,
resample='NEAREST',
align_corners=False,
align_mode=2,
scale=1.0)
def test_outshape_and_scale():
x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32")
out = interpolate(
x,
out_shape=None,
resample='BICUBIC',
align_corners=False,
scale=None)
self.assertRaises(ValueError, test_mode_type)
self.assertRaises(ValueError, test_input_shape)
self.assertRaises(TypeError, test_align_corcers)
self.assertRaises(ValueError, test_attr_data_format)
self.assertRaises(TypeError, test_actual_shape)
self.assertRaises(ValueError, test_scale_value)
self.assertRaises(ValueError, test_out_shape)
self.assertRaises(ValueError, test_attr_5D_input)
self.assertRaises(TypeError, test_scale_type)
self.assertRaises(ValueError, test_align_mode)
self.assertRaises(ValueError, test_outshape_and_scale)
if __name__ == "__main__":
unittest.main()
......@@ -19,6 +19,7 @@ import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.nn.functional import *
def trilinear_interp_np(input,
......@@ -586,6 +587,15 @@ class TestTrilinearInterpAPI(unittest.TestCase):
out4 = fluid.layers.resize_trilinear(
x, out_shape=[4, 4, 8], actual_shape=actual_size)
out5 = fluid.layers.resize_trilinear(x, scale=scale_tensor)
out6 = interpolate(
x, scale=scale_tensor, resample='TRILINEAR', data_format="NCDHW")
out7 = interpolate(
x, out_shape=[4, 4, 8], resample='TRILINEAR', data_format="NCDHW")
out8 = interpolate(
x,
out_shape=shape_tensor,
resample='TRILINEAR',
data_format="NCDHW")
x_data = np.random.random((2, 3, 6, 9, 4)).astype("float32")
dim_data = np.array([18]).astype("int32")
......
......@@ -187,4 +187,4 @@ from .extension import row_conv #DEFINE_ALIAS
# from .common import unfold #DEFINE_ALIAS
# from .common import bilinear_tensor_product #DEFINE_ALIAS
# from .common import assign #DEFINE_ALIAS
# from .common import interpolate #DEFINE_ALIAS
from .common import interpolate #DEFINE_ALIAS
......@@ -12,6 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.layers.tensor import Variable, fill_constant
# TODO: define the common functions to build a neural network
# __all__ = ['dropout',
# 'embedding',
......@@ -25,3 +29,395 @@
# 'bilinear_tensor_product',
# 'assign',
# 'interpolate']
__all__ = ['interpolate']
def interpolate(input,
out_shape=None,
scale=None,
name=None,
resample='BILINEAR',
actual_shape=None,
align_corners=True,
align_mode=1,
data_format='NCHW'):
"""
This op resizes a batch of images.
The input must be a 4-D Tensor of the shape (num_batches, channels, in_h, in_w)
or (num_batches, in_h, in_w, channels), or a 5-D Tensor of the shape
(num_batches, channels, in_d, in_h, in_w) or (num_batches, in_d, in_h, in_w, channels),
and the resizing only applies on the three dimensions(depth, height and width).
**Warning:** the parameter :attr:`actual_shape` will be deprecated in the
future and only use :attr:`out_shape` instead.
Supporting resample methods:
'BILINEAR' : Bilinear interpolation
'TRILINEAR' : Trilinear interpolation
'NEAREST' : Nearest neighbor interpolation
'BICUBIC' : Bicubic interpolation
Nearest neighbor interpolation is to perform nearest neighbor interpolation
in both the 3rd dimension(in height direction) and the 4th dimension(in width
direction) on input tensor.
Bilinear interpolation is an extension of linear interpolation for
interpolating functions of two variables (e.g. H-direction and
W-direction in this op) on a rectilinear 2D grid. The key idea is
to perform linear interpolation first in one direction, and then
again in the other direction.
Trilinear interpolation is an extension of linear interpolation for
interpolating functions of three variables (e.g. D-direction,
H-direction and W-direction in this op) on a rectilinear 3D grid.
The linear interpolation is performed on three directions.
Align_corners and align_mode are optional parameters,the calculation method
of interpolation can be selected by them.
Bicubic interpolation is an extension of cubic interpolation for interpolating
data points on a two-dimensional regular grid. The interpolated surface is
smoother than corresponding surfaces obtained by bilinear interpolation or
nearest-neighbor interpolation.
Example:
.. code-block:: text
For scale:
if align_corners = True && out_size > 1 :
scale_factor = (in_size-1.0)/(out_size-1.0)
else:
scale_factor = float(in_size/out_size)
Nearest neighbor interpolation:
if:
align_corners = False
input : (N,C,H_in,W_in)
output: (N,C,H_out,W_out) where:
H_out = floor (H_{in} * scale_{factor})
W_out = floor (W_{in} * scale_{factor})
else:
align_corners = True
input : (N,C,H_in,W_in)
output: (N,C,H_out,W_out) where:
H_out = round(H_{in} * scale_{factor})
W_out = round(W_{in} * scale_{factor})
Bilinear interpolation:
if:
align_corners = False , align_mode = 0
input : (N,C,H_in,W_in)
output: (N,C,H_out,W_out) where:
H_out = (H_{in}+0.5) * scale_{factor} - 0.5
W_out = (W_{in}+0.5) * scale_{factor} - 0.5
else:
input : (N,C,H_in,W_in)
output: (N,C,H_out,W_out) where:
H_out = H_{in} * scale_{factor}
W_out = W_{in} * scale_{factor}
Bicubic interpolation:
if:
align_corners = False
input : (N,C,H_in,W_in)
output: (N,C,H_out,W_out) where:
H_out = (H_{in}+0.5) * scale_{factor} - 0.5
W_out = (W_{in}+0.5) * scale_{factor} - 0.5
else:
input : (N,C,H_in,W_in)
output: (N,C,H_out,W_out) where:
H_out = H_{in} * scale_{factor}
W_out = W_{in} * scale_{factor}
Trilinear interpolation:
if:
align_corners = False , align_mode = 0
input : (N,C,D_in,H_in,W_in)
output: (N,C,D_out,H_out,W_out) where:
D_out = (D_{in}+0.5) * scale_{factor} - 0.5
H_out = (H_{in}+0.5) * scale_{factor} - 0.5
W_out = (W_{in}+0.5) * scale_{factor} - 0.5
else:
input : (N,C,D_in,H_in,W_in)
output: (N,C,D_out,H_out,W_out) where:
D_out = D_{in} * scale_{factor}
H_out = H_{in} * scale_{factor}
W_out = W_{in} * scale_{factor}
For details of nearest neighbor interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation.
For details of bilinear interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Bilinear_interpolation.
For details of trilinear interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Trilinear_interpolation.
For details of bicubic interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Bicubic_interpolation
Parameters:
input (Variable): 4-D or 5-D Tensor, its data type is float32, float64, or uint8,
its data format is specified by :attr:`data_format`.
out_shape(list|tuple|Variable|None): Output shape of image resize
layer, the shape is (out_h, out_w) when input is a 4-D Tensor and is
(out_d, out_h, out_w) when input is a 5-D Tensor. Default: None. If
a list, each element can be an integer or a Tensor Variable of shape: [1].
If a Tensor Variable, its dimensions size should be a 1.
scale(float|Variable|None): The multiplier for the input height or width. At
least one of :attr:`out_shape` or :attr:`scale` must be set.
And :attr:`out_shape` has a higher priority than :attr:`scale`.
Default: None.
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
resample(str): The resample method. It supports 'BILINEAR', 'TRILINEAR' ,
'BICUBIC' and 'NEAREST' currently. Default: 'BILINEAR'
actual_shape(Variable): An optional input to specify output shape
dynamically. If provided, image resize
according to this given shape rather than
:attr:`out_shape` and :attr:`scale` specifying
shape. That is to say actual_shape has the
highest priority. It is recommended to use
:attr:`out_shape` if you want to specify output
shape dynamically, because :attr:`actual_shape`
will be deprecated. When using actual_shape to
specify output shape, one of :attr:`out_shape`
and :attr:`scale` should also be set, otherwise
errors would be occurred in graph constructing stage.
Default: None
align_corners(bool) : An optional bool, If True, the centers of the 4 corner pixels of the
input and output tensors are aligned, preserving the values at the
corner pixels.
Default: True
align_mode(int) : An optional for bilinear interpolation. can be \'0\'
for src_idx = scale*(dst_indx+0.5)-0.5 , can be \'1\' for
src_idx = scale*dst_index.
data_format (str, optional): Specify the data format of the input, and the data format of the output
will be consistent with that of the input. An optional string from: `"NCHW"`, `"NHWC"`, `"NCDHW"`,
`"NDHWC"`. The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of:
`[batch_size, input_channels, input_height, input_width]`. When it is `"NCHW"`, the data is stored
in the order of: `[batch_size, input_channels, input_depth, input_height, input_width]`.
Returns:
A 4-D Tensor of the shape (num_batches, channels, out_h, out_w) or (num_batches, out_h, out_w, channels),
or 5-D Tensor of the shape (num_batches, channels, out_d, out_h, out_w) or (num_batches, out_d, out_h, out_w, channels).
Raises:
TypeError: out_shape should be a list or tuple or Variable.
TypeError: actual_shape should either be Variable or None.
ValueError: The 'resample' of image_resize can only be 'BILINEAR',
'TRILINEAR', 'BICUBIC', or 'NEAREST' currently.
ValueError: 'BILINEAR', 'BICUBIC' and 'NEAREST' only support 4-D tensor.
ValueError: 'TRILINEAR' only support 5-D tensor.
ValueError: One of out_shape and scale must not be None.
ValueError: out_shape length should be 2 for input 4-D tensor.
ValueError: out_shape length should be 3 for input 5-D tensor.
ValueError: scale should be greater than zero.
TypeError: align_corners should be a bool value
ValueError: align_mode can only be '0' or '1'
ValueError: data_format can only be 'NCHW', 'NHWC', 'NCDHW' or 'NDHWC'.
Examples:
.. code-block:: python
#declarative mode
import paddle
import numpy as np
input = fluid.data(name="input", shape=[None,3,6,10])
#1
output = paddle.nn.functional.interpolate(input=input,out_shape=[12,12])
#2
#x = np.array([2]).astype("int32")
#dim1 = fluid.data(name="dim1", shape=[1], dtype="int32")
#fluid.layers.assign(input=x, output=dim1)
#output = paddle.nn.functional.interpolate(input=input,out_shape=[12,dim1])
#3
#x = np.array([3,12]).astype("int32")
#shape_tensor = fluid.data(name="shape_tensor", shape=[2], dtype="int32")
#fluid.layers.assign(input=x, output=shape_tensor)
#output = paddle.nn.functional.interpolate(input=input,out_shape=shape_tensor)
#4
#x = np.array([0.5]).astype("float32")
#scale_tensor = fluid.data(name="scale", shape=[1], dtype="float32")
#fluid.layers.assign(x,scale_tensor)
#output = paddle.nn.functional.interpolate(input=input,scale=scale_tensor)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
input_data = np.random.rand(2,3,6,10).astype("float32")
output_data = exe.run(fluid.default_main_program(),
feed={"input":input_data},
fetch_list=[output],
return_numpy=True)
print(output_data[0].shape)
#1
# (2, 3, 12, 12)
#2
# (2, 3, 12, 2)
#3
# (2, 3, 3, 12)
#4
# (2, 3, 3, 5)
#imperative mode
import paddle.fluid.dygraph as dg
with dg.guard(place) as g:
input = dg.to_variable(input_data)
output = paddle.nn.functional.interpolate(input=input, out_shape=[12,12])
print(output.shape)
# [2L, 3L, 12L, 12L]
"""
resample_methods = {
'BILINEAR': 'bilinear',
'TRILINEAR': 'trilinear',
'NEAREST': 'nearest',
'BICUBIC': 'bicubic',
}
if resample not in resample_methods:
raise ValueError(
"The 'resample' of image_resize can only be 'BILINEAR', 'TRILINEAR', "
" 'BICUBIC' or 'NEAREST' currently.")
resample_type = resample_methods[resample]
if resample in ['BILINEAR', 'NEAREST', 'BICUBIC'] and len(input.shape) != 4:
raise ValueError(
"'BILINEAR', 'BICUBIC' and 'NEAREST' only support 4-D tensor.")
if resample == 'TRILINEAR' and len(input.shape) != 5:
raise ValueError("'TRILINEAR'only support 5-D tensor.")
if not isinstance(align_corners, bool):
raise TypeError("Attr align_corners should be a bool value")
if align_mode != 0 and align_mode != 1:
raise ValueError("align_mode can only be 0 or 1")
if out_shape is None and scale is None:
raise ValueError("One of out_shape and scale must not be None.")
helper = LayerHelper('{}_interp'.format(resample_type), **locals())
dtype = helper.input_dtype()
if len(input.shape) == 4 and data_format not in ['NCHW', 'NHWC']:
raise ValueError(
"Got wrong value for param `data_format`: " + data_format +
" received but only `NCHW` or `NHWC` supported for 4-D input.")
elif len(input.shape) == 5 and data_format not in ['NCDHW', 'NDHWC']:
raise ValueError(
"Got wrong value for param `data_format`: " + data_format +
" received but only `NCDHW` or `NDHWC` supported for 5-D input.")
def _is_list_or_turple_(data):
return (isinstance(data, list) or isinstance(data, tuple))
if data_format == 'NCHW' or data_format == 'NCDHW':
data_layout = 'NCHW'
if data_format == 'NHWC' or data_format == 'NDHWC':
data_layout = 'NHWC'
inputs = {"X": input}
attrs = {
"out_d": -1,
"out_h": -1,
"out_w": -1,
"interp_method": resample_type,
"align_corners": align_corners,
"align_mode": align_mode,
"data_layout": data_layout
}
if out_shape is not None:
if isinstance(out_shape, Variable):
out_shape.stop_gradient = True
inputs['OutSize'] = out_shape
else:
if not (_is_list_or_turple_(out_shape)):
raise TypeError(
"out_shape should be a list or tuple or Variable.")
# Validate the shape
contain_var = False
for dim_idx, dim_size in enumerate(out_shape):
if isinstance(dim_size, Variable):
contain_var = True
continue
assert dim_size > 0, (
"Each dimension size given in out_shape must be greater than 0."
)
if contain_var:
new_size_tensor = []
size_list = []
for dim in out_shape:
if isinstance(dim, Variable):
dim.stop_gradient = True
new_size_tensor.append(dim)
size_list.append(-1)
else:
assert (isinstance(dim, int))
temp_out = helper.create_variable_for_type_inference(
'int32')
fill_constant(
[1], 'int32', dim, force_cpu=True, out=temp_out)
new_size_tensor.append(temp_out)
size_list.append(dim)
inputs['SizeTensor'] = new_size_tensor
if len(input.shape) == 4:
if len(out_shape) != 2:
raise ValueError("out_shape length should be 2 for "
"input 4-D tensor.")
if contain_var:
attrs['out_h'] = size_list[0]
attrs['out_w'] = size_list[1]
else:
out_shape = list(map(int, out_shape))
attrs['out_h'] = out_shape[0]
attrs['out_w'] = out_shape[1]
if len(input.shape) == 5:
if len(out_shape) != 3:
raise ValueError("out_shape length should be 3 for "
"input 5-D tensor.")
if contain_var:
attrs['out_d'] = size_list[0]
attrs['out_h'] = size_list[1]
attrs['out_w'] = size_list[2]
else:
out_shape = list(map(int, out_shape))
attrs['out_d'] = out_shape[0]
attrs['out_h'] = out_shape[1]
attrs['out_w'] = out_shape[2]
else:
if isinstance(scale, Variable):
scale.stop_gradient = True
inputs["Scale"] = scale
elif isinstance(scale, float) or isinstance(scale, int):
if scale <= 0:
raise ValueError("Attr(scale) should be greater than zero.")
attrs['scale'] = float(scale)
else:
raise TypeError(
"Attr(scale)'s type should be float, int or Variable.")
if isinstance(actual_shape, Variable):
warnings.warn(
"actual_shape will be deprecated, it is recommended to use "
"out_shape instead of actual_shape to specify output shape dynamically."
)
actual_shape.stop_gradient = True
inputs["OutSize"] = actual_shape
elif actual_shape is not None:
raise TypeError("actual_shape should either be Variable or None.")
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='{}_interp'.format(resample_type),
inputs=inputs,
outputs={"Out": out},
attrs=attrs)
return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册