Commit cb8f3c03 authored by Zhang Ting, committed by hong

resize Ops support data_layout:channel_last, test=develop, test=document_preview (#19914)

Parent 9901f696
@@ -194,11 +194,11 @@ paddle.fluid.layers.label_smooth (ArgSpec(args=['label', 'prior_dist', 'epsilon'
paddle.fluid.layers.roi_pool (ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1, 1, 1.0)), ('document', '49368d724023a66b41b0071be41c0ba5'))
paddle.fluid.layers.roi_align (ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale', 'sampling_ratio', 'name'], varargs=None, keywords=None, defaults=(1, 1, 1.0, -1, None)), ('document', '9a7a3b88a4fae41d58d3ca9b10ba0591'))
paddle.fluid.layers.dice_loss (ArgSpec(args=['input', 'label', 'epsilon'], varargs=None, keywords=None, defaults=(1e-05,)), ('document', '7e8e4bf1f0f8612961ed113e8af8f0c5'))
paddle.fluid.layers.image_resize (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'resample', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, 'BILINEAR', None, True, 1)), ('document', '0e8567334d72a214c2e3ce0ce19e4d37'))
paddle.fluid.layers.image_resize (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'resample', 'actual_shape', 'align_corners', 'align_mode', 'data_format'], varargs=None, keywords=None, defaults=(None, None, None, 'BILINEAR', None, True, 1, 'NCHW')), ('document', 'd29d829607b5ff12924197a3ba296c89'))
paddle.fluid.layers.image_resize_short (ArgSpec(args=['input', 'out_short_len', 'resample'], varargs=None, keywords=None, defaults=('BILINEAR',)), ('document', 'bd97ebfe4bdf5110a5fcb8ecb626a447'))
paddle.fluid.layers.resize_bilinear (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, None, True, 1)), ('document', '0a7b98e57eb74bab6e3c2a95e41298a7'))
paddle.fluid.layers.resize_trilinear (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, None, True, 1)), ('document', '6baf2ddf375d3059e5aa74d7fde76517'))
paddle.fluid.layers.resize_nearest (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners'], varargs=None, keywords=None, defaults=(None, None, None, None, True)), ('document', '699bf1de6af91235367e9c7a9a6e252c'))
paddle.fluid.layers.resize_bilinear (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners', 'align_mode', 'data_format'], varargs=None, keywords=None, defaults=(None, None, None, None, True, 1, 'NCHW')), ('document', '44da7890c8a362a83a1c0902a1dc1e4d'))
paddle.fluid.layers.resize_trilinear (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners', 'align_mode', 'data_format'], varargs=None, keywords=None, defaults=(None, None, None, None, True, 1, 'NCDHW')), ('document', '5b4d0f823f94c260fe5e6f7eec60a797'))
paddle.fluid.layers.resize_nearest (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners', 'data_format'], varargs=None, keywords=None, defaults=(None, None, None, None, True, 'NCHW')), ('document', '0107a5cbae1aef3f381d3d769a6068eb'))
paddle.fluid.layers.gather (ArgSpec(args=['input', 'index', 'overwrite'], varargs=None, keywords=None, defaults=(True,)), ('document', 'f985c9b66e3aec96fa753a8eb44c991c'))
paddle.fluid.layers.gather_nd (ArgSpec(args=['input', 'index', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '3cc24f9cf135770aa6263dba25b457f9'))
paddle.fluid.layers.scatter (ArgSpec(args=['input', 'index', 'updates', 'name', 'overwrite'], varargs=None, keywords=None, defaults=(None, True)), ('document', '69b22affd4a6326502af166f04c095ab'))
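The new `data_format` argument is appended to each resize API above (default 'NCHW', or 'NCDHW' for resize_trilinear). A minimal usage sketch of the channel_last path (shapes assume fluid's implicit batch dimension):

    import paddle.fluid as fluid

    # NHWC input: (num_batches, in_h, in_w, channels)
    x = fluid.layers.data(name="x", shape=[6, 6, 3], dtype="float32")
    out = fluid.layers.image_resize(
        x, out_shape=[12, 12], resample='BILINEAR', data_format='NHWC')
    # out.shape = [-1, 12, 12, 3]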
@@ -19,6 +19,7 @@ namespace paddle {
namespace operators {
using framework::Tensor;
using DataLayout = framework::DataLayout;
static void Interpolate2DInferShapeCheck(framework::InferShapeContext* ctx) {
auto dim_x = ctx->GetInputDim("X");
@@ -28,6 +29,8 @@ static void Interpolate2DInferShapeCheck(framework::InferShapeContext* ctx) {
"bilinear" == interp_method || "nearest" == interp_method,
"Interpolation method can only be \"bilinear\" or \"nearest\" when "
"Input(X) dimension is 4");
const DataLayout data_layout = framework::StringToDataLayout(
ctx->Attrs().Get<std::string>("data_layout"));
if (ctx->HasInputs("SizeTensor")) {
// top priority size
@@ -38,8 +41,13 @@ static void Interpolate2DInferShapeCheck(framework::InferShapeContext* ctx) {
"Attr(out_shape)'s length must be 2 for 4-D input tensor.");
int out_h = ctx->Attrs().Get<int>("out_h");
int out_w = ctx->Attrs().Get<int>("out_w");
std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w});
ctx->SetOutputDim("Out", framework::make_ddim(dim_out));
framework::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {dim_x[0], dim_x[1], out_h, out_w};
} else {
dim_out = {dim_x[0], out_h, out_w, dim_x[3]};
}
ctx->SetOutputDim("Out", dim_out);
return;
}
@@ -55,8 +63,12 @@ static void Interpolate2DInferShapeCheck(framework::InferShapeContext* ctx) {
float scale = ctx->Attrs().Get<float>("scale");
if (scale > 0) {
// round down
out_h = static_cast<int>(dim_x[2] * scale);
out_w = static_cast<int>(dim_x[3] * scale);
out_h = (data_layout == DataLayout::kNCHW
? static_cast<int>(dim_x[2] * scale)
: static_cast<int>(dim_x[1] * scale));
out_w = (data_layout == DataLayout::kNCHW
? static_cast<int>(dim_x[3] * scale)
: static_cast<int>(dim_x[2] * scale));
// protect when input shape is -1
out_h = out_h > 0 ? out_h : -1;
out_w = out_w > 0 ? out_w : -1;
@@ -75,8 +87,13 @@ static void Interpolate2DInferShapeCheck(framework::InferShapeContext* ctx) {
return;
}
std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w});
ctx->SetOutputDim("Out", framework::make_ddim(dim_out));
framework::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {dim_x[0], dim_x[1], out_h, out_w};
} else {
dim_out = {dim_x[0], out_h, out_w, dim_x[3]};
}
ctx->SetOutputDim("Out", dim_out);
}
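The inferred output shape keeps the channel axis where the layout puts it; a hedged Python rendering of the 4-D branch above (illustrative helper, not from the patch):

    def infer_dim_out_2d(dim_x, out_h, out_w, data_layout='NCHW'):
        # channel stays at axis 1 for NCHW, axis 3 for NHWC
        if data_layout == 'NCHW':
            return [dim_x[0], dim_x[1], out_h, out_w]
        return [dim_x[0], out_h, out_w, dim_x[3]]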
static void Interpolate3DInferShapeCheck(framework::InferShapeContext* ctx) {
@@ -86,6 +103,8 @@ static void Interpolate3DInferShapeCheck(framework::InferShapeContext* ctx) {
PADDLE_ENFORCE("trilinear" == interp_method,
"Interpolation method can only be \"trilinear\" when Input(X) "
"dimension is 5");
const DataLayout data_layout = framework::StringToDataLayout(
ctx->Attrs().Get<std::string>("data_layout"));
if (ctx->HasInputs("SizeTensor")) {
// top priority size
@@ -97,8 +116,13 @@ static void Interpolate3DInferShapeCheck(framework::InferShapeContext* ctx) {
int out_d = ctx->Attrs().Get<int>("out_d");
int out_h = ctx->Attrs().Get<int>("out_h");
int out_w = ctx->Attrs().Get<int>("out_w");
std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_d, out_h, out_w});
ctx->SetOutputDim("Out", framework::make_ddim(dim_out));
framework::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {dim_x[0], dim_x[1], out_d, out_h, out_w};
} else {
dim_out = {dim_x[0], out_d, out_h, out_w, dim_x[4]};
}
ctx->SetOutputDim("Out", dim_out);
return;
}
@@ -115,9 +139,15 @@ static void Interpolate3DInferShapeCheck(framework::InferShapeContext* ctx) {
float scale = ctx->Attrs().Get<float>("scale");
if (scale > 0) {
// round down
out_d = static_cast<int>(dim_x[2] * scale);
out_h = static_cast<int>(dim_x[3] * scale);
out_w = static_cast<int>(dim_x[4] * scale);
out_d = (data_layout == DataLayout::kNCHW
? static_cast<int>(dim_x[2] * scale)
: static_cast<int>(dim_x[1] * scale));
out_h = (data_layout == DataLayout::kNCHW
? static_cast<int>(dim_x[3] * scale)
: static_cast<int>(dim_x[2] * scale));
out_w = (data_layout == DataLayout::kNCHW
? static_cast<int>(dim_x[4] * scale)
: static_cast<int>(dim_x[3] * scale));
// protect when input shape is -1
out_d = out_d > 0 ? out_d : -1;
out_h = out_h > 0 ? out_h : -1;
@@ -138,8 +168,13 @@ static void Interpolate3DInferShapeCheck(framework::InferShapeContext* ctx) {
return;
}
std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_d, out_h, out_w});
ctx->SetOutputDim("Out", framework::make_ddim(dim_out));
framework::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {dim_x[0], dim_x[1], out_d, out_h, out_w};
} else {
dim_out = {dim_x[0], out_d, out_h, out_w, dim_x[4]};
}
ctx->SetOutputDim("Out", dim_out);
}
class InterpolateOp : public framework::OperatorWithKernel {
@@ -213,6 +248,13 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
"The output tensor of interpolate operator, "
"This is a tensor in same rank with Input(X).");
AddAttr<std::string>(
"data_layout",
"(string, default NCHW) Only used in "
"an optional string from: \"NHWC\", \"NCHW\". "
"Specify that the data format of the input and output data is "
"channel_first or channel_last.")
.SetDefault("NCHW");
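// Note: this attribute only distinguishes channel_first from channel_last.
// The Python API maps NCDHW to "NCHW" and NDHWC to "NHWC" before setting it;
// the input rank then selects 2-D vs 3-D interpolation.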
AddAttr<int>("out_d", "output depth of interpolate op.").SetDefault(0);
AddAttr<int>("out_h", "output height of interpolate op.").SetDefault(0);
AddAttr<int>("out_w", "output width of interpolate op.").SetDefault(0);
@@ -22,6 +22,7 @@ template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using Tensor = framework::Tensor;
using DataLayout = framework::DataLayout;
inline std::vector<int> get_new_shape(
const std::vector<const Tensor*>& list_new_shape_tensor) {
@@ -57,12 +58,30 @@ inline std::vector<T> get_new_data_from_tensor(const Tensor* new_data_tensor) {
return vec_new_data;
}
inline void ExtractNCDWH(const framework::DDim& dims,
const DataLayout& data_layout, int* N, int* C, int* D,
int* H, int* W) {
*N = dims[0];
if (dims.size() == 4) {
*C = data_layout == DataLayout::kNCHW ? dims[1] : dims[3];
*D = 1;
*H = data_layout == DataLayout::kNCHW ? dims[2] : dims[1];
*W = data_layout == DataLayout::kNCHW ? dims[3] : dims[2];
} else {
*C = data_layout == DataLayout::kNCHW ? dims[1] : dims[4];
*D = data_layout == DataLayout::kNCHW ? dims[2] : dims[1];
*H = data_layout == DataLayout::kNCHW ? dims[3] : dims[2];
*W = data_layout == DataLayout::kNCHW ? dims[4] : dims[3];
}
}
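For reference, a hedged Python mirror of the index mapping ExtractNCDWH performs (illustrative only; 'NCHW' here means channel_first for both 4-D and 5-D shapes):

    def extract_ncdhw(dims, data_layout='NCHW'):
        # returns (N, C, D, H, W); D is 1 for 4-D inputs
        n = dims[0]
        if len(dims) == 4:
            if data_layout == 'NCHW':
                c, d, h, w = dims[1], 1, dims[2], dims[3]
            else:  # NHWC
                c, d, h, w = dims[3], 1, dims[1], dims[2]
        else:
            if data_layout == 'NCHW':
                c, d, h, w = dims[1], dims[2], dims[3], dims[4]
            else:  # NDHWC
                c, d, h, w = dims[4], dims[1], dims[2], dims[3]
        return n, c, d, h, w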
template <typename T>
static void NearestNeighborInterpolate(const Tensor& input, Tensor* output,
const float ratio_h, const float ratio_w,
const int n, const int c,
const int out_h, const int out_w,
const bool align_corners) {
const bool align_corners,
const DataLayout& data_layout) {
auto input_t = EigenTensor<T, 4>::From(input);
auto output_t = EigenTensor<T, 4>::From(*output);
for (int k = 0; k < out_h; k++) { // loop for images
@@ -75,7 +94,11 @@ static void NearestNeighborInterpolate(const Tensor& input, Tensor* output,
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
if (data_layout == DataLayout::kNCHW) {
output_t(i, j, k, l) = input_t(i, j, in_k, in_l);
} else {
output_t(i, k, l, j) = input_t(i, in_k, in_l, j);
}
}
}
}
@@ -88,7 +111,8 @@ static void BilinearInterpolation(const Tensor& input, Tensor* output,
const int in_h, const int in_w, const int n,
const int c, const int out_h, const int out_w,
const bool align_corners,
const bool align_mode) {
const bool align_mode,
const DataLayout data_layout) {
auto input_t = EigenTensor<T, 4>::From(input);
auto output_t = EigenTensor<T, 4>::From(*output);
bool align_flag = (align_mode == 0 && !align_corners);
@@ -154,11 +178,21 @@ static void BilinearInterpolation(const Tensor& input, Tensor* output,
for (int k = 0; k < out_h; k++) { // loop for images
for (int l = 0; l < out_w; l++) {
// bilinear interpolation
T out_t = input_t(i, j, vy_n[k], vx_w[l]) * vd_s[k] * vd_e[l] +
T out_t;
if (data_layout == DataLayout::kNCHW) {
out_t = input_t(i, j, vy_n[k], vx_w[l]) * vd_s[k] * vd_e[l] +
input_t(i, j, vy_s[k], vx_w[l]) * vd_n[k] * vd_e[l] +
input_t(i, j, vy_n[k], vx_e[l]) * vd_s[k] * vd_w[l] +
input_t(i, j, vy_s[k], vx_e[l]) * vd_n[k] * vd_w[l];
output_t(i, j, k, l) = out_t;
} else {
out_t = input_t(i, vy_n[k], vx_w[l], j) * vd_s[k] * vd_e[l] +
input_t(i, vy_s[k], vx_w[l], j) * vd_n[k] * vd_e[l] +
input_t(i, vy_n[k], vx_e[l], j) * vd_s[k] * vd_w[l] +
input_t(i, vy_s[k], vx_e[l], j) * vd_n[k] * vd_w[l];
output_t(i, k, l, j) = out_t;
}
}
}
}
@@ -170,7 +204,8 @@ static void TrilinearInterpolation(
const Tensor& input, Tensor* output, const float ratio_d,
const float ratio_h, const float ratio_w, const int in_d, const int in_h,
const int in_w, const int n, const int c, const int out_d, const int out_h,
const int out_w, const bool align_corners, const bool align_mode) {
const int out_w, const bool align_corners, const bool align_mode,
const DataLayout& data_layout) {
auto input_t = EigenTensor<T, 5>::From(input);
auto output_t = EigenTensor<T, 5>::From(*output);
bool align_flag = (align_mode == 0 && !align_corners);
@@ -263,6 +298,7 @@ static void TrilinearInterpolation(
for (int k = 0; k < out_h; k++) {
for (int l = 0; l < out_w; l++) {
// trilinear interpolation
if (data_layout == DataLayout::kNCHW) {
T out_t = input_t(b, i, vt_f[j], vy_n[k], vx_w[l]) * vd_b[j] *
vd_s[k] * vd_e[l] +
input_t(b, i, vt_f[j], vy_n[k], vx_e[l]) * vd_b[j] *
@@ -280,6 +316,25 @@
input_t(b, i, vt_b[j], vy_s[k], vx_e[l]) * vd_f[j] *
vd_n[k] * vd_w[l];
output_t(b, i, j, k, l) = out_t;
} else {
T out_t = input_t(b, vt_f[j], vy_n[k], vx_w[l], i) * vd_b[j] *
vd_s[k] * vd_e[l] +
input_t(b, vt_f[j], vy_n[k], vx_e[l], i) * vd_b[j] *
vd_s[k] * vd_w[l] +
input_t(b, vt_f[j], vy_s[k], vx_w[l], i) * vd_b[j] *
vd_n[k] * vd_e[l] +
input_t(b, vt_f[j], vy_s[k], vx_e[l], i) * vd_b[j] *
vd_n[k] * vd_w[l] +
input_t(b, vt_b[j], vy_n[k], vx_w[l], i) * vd_f[j] *
vd_s[k] * vd_e[l] +
input_t(b, vt_b[j], vy_n[k], vx_e[l], i) * vd_f[j] *
vd_s[k] * vd_w[l] +
input_t(b, vt_b[j], vy_s[k], vx_w[l], i) * vd_f[j] *
vd_n[k] * vd_e[l] +
input_t(b, vt_b[j], vy_s[k], vx_e[l], i) * vd_f[j] *
vd_n[k] * vd_w[l];
output_t(b, j, k, l, i) = out_t;
}
}
}
}
@@ -291,7 +346,7 @@ template <typename T>
static void NearestNeighborInterpolateGrad(
const Tensor& output_grad, Tensor* input_grad, const float ratio_h,
const float ratio_w, const int n, const int c, const int out_h,
const int out_w, const bool align_corners) {
const int out_w, const bool align_corners, const DataLayout data_layout) {
auto input_grad_t = EigenTensor<T, 4>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 4>::From(output_grad);
@@ -305,7 +360,11 @@ static void NearestNeighborInterpolateGrad(
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
if (data_layout == DataLayout::kNCHW) {
input_grad_t(i, j, in_k, in_l) += output_grad_t(i, j, k, l);
} else {
input_grad_t(i, in_k, in_l, j) += output_grad_t(i, k, l, j);
}
}
}
}
@@ -313,13 +372,11 @@ static void NearestNeighborInterpolateGrad(
}
template <typename T>
static void BilinearInterpolationGrad(const Tensor& output_grad,
Tensor* input_grad, const float ratio_h,
const float ratio_w, const int in_h,
const int in_w, const int n, const int c,
const int out_h, const int out_w,
const bool align_corners,
const int align_mode) {
static void BilinearInterpolationGrad(
const Tensor& output_grad, Tensor* input_grad, const float ratio_h,
const float ratio_w, const int in_h, const int in_w, const int n,
const int c, const int out_h, const int out_w, const bool align_corners,
const int align_mode, const DataLayout data_layout) {
auto input_grad_t = EigenTensor<T, 4>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 4>::From(output_grad);
bool align_flag = (align_mode == 0 && !align_corners);
@@ -346,11 +403,19 @@ static void BilinearInterpolationGrad(const Tensor& output_grad,
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
// bilinear interpolation grad
if (data_layout == DataLayout::kNCHW) {
const T grad = output_grad_t(i, j, k, l);
input_grad_t(i, j, y_n, x_w) += static_cast<T>(grad * d_s * d_e);
input_grad_t(i, j, y_s, x_w) += static_cast<T>(grad * d_n * d_e);
input_grad_t(i, j, y_n, x_e) += static_cast<T>(grad * d_s * d_w);
input_grad_t(i, j, y_s, x_e) += static_cast<T>(grad * d_n * d_w);
} else {
const T grad = output_grad_t(i, k, l, j);
input_grad_t(i, y_n, x_w, j) += static_cast<T>(grad * d_s * d_e);
input_grad_t(i, y_s, x_w, j) += static_cast<T>(grad * d_n * d_e);
input_grad_t(i, y_n, x_e, j) += static_cast<T>(grad * d_s * d_w);
input_grad_t(i, y_s, x_e, j) += static_cast<T>(grad * d_n * d_w);
}
}
}
}
@@ -362,7 +427,8 @@ static void TrilinearInterpolationGrad(
const Tensor& output_grad, Tensor* input_grad, const float ratio_d,
const float ratio_h, const float ratio_w, const int in_d, const int in_h,
const int in_w, const int n, const int c, const int out_d, const int out_h,
const int out_w, const bool align_corners, const int align_mode) {
const int out_w, const bool align_corners, const int align_mode,
const DataLayout data_layout) {
auto input_grad_t = EigenTensor<T, 5>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 5>::From(output_grad);
bool align_flag = (align_mode == 0 && !align_corners);
@@ -399,6 +465,7 @@ static void TrilinearInterpolationGrad(
for (int b = 0; b < n; b++) { // loop for batches
for (int i = 0; i < c; i++) { // loop for channels
// trilinear interpolation grad
if (data_layout == DataLayout::kNCHW) {
const T grad = output_grad_t(b, i, j, k, l);
input_grad_t(b, i, t_f, y_n, x_w) +=
static_cast<T>(grad * d_b * d_s * d_e);
@@ -416,6 +483,25 @@
static_cast<T>(grad * d_f * d_n * d_e);
input_grad_t(b, i, t_b, y_s, x_e) +=
static_cast<T>(grad * d_f * d_n * d_w);
} else {
const T grad = output_grad_t(b, j, k, l, i);
input_grad_t(b, t_f, y_n, x_w, i) +=
static_cast<T>(grad * d_b * d_s * d_e);
input_grad_t(b, t_f, y_n, x_e, i) +=
static_cast<T>(grad * d_b * d_s * d_w);
input_grad_t(b, t_f, y_s, x_w, i) +=
static_cast<T>(grad * d_b * d_n * d_e);
input_grad_t(b, t_f, y_s, x_e, i) +=
static_cast<T>(grad * d_b * d_n * d_w);
input_grad_t(b, t_b, y_n, x_w, i) +=
static_cast<T>(grad * d_f * d_s * d_e);
input_grad_t(b, t_b, y_n, x_e, i) +=
static_cast<T>(grad * d_f * d_s * d_w);
input_grad_t(b, t_b, y_s, x_w, i) +=
static_cast<T>(grad * d_f * d_n * d_e);
input_grad_t(b, t_b, y_s, x_e, i) +=
static_cast<T>(grad * d_f * d_n * d_w);
}
}
}
}
@@ -426,10 +512,10 @@ static void TrilinearInterpolationGrad(
template <typename T>
static void Interpolate2DCPUFwd(const framework::ExecutionContext& ctx,
const Tensor& input, Tensor* output) {
const int n = input.dims()[0];
const int c = input.dims()[1];
const int in_h = input.dims()[2];
const int in_w = input.dims()[3];
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
ExtractNCDWH(input.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
@@ -470,7 +556,13 @@ static void Interpolate2DCPUFwd(const framework::ExecutionContext& ctx,
PADDLE_ENFORCE_GT(
out_w, 0,
"out_w in Attr(out_shape) of Op(interpolate) should be greater than 0.");
output->mutable_data<T>({n, c, out_h, out_w}, ctx.GetPlace());
framework::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {n, c, out_h, out_w};
} else {
dim_out = {n, out_h, out_w, c};
}
output->mutable_data<T>(dim_out, ctx.GetPlace());
if (in_h == out_h && in_w == out_w) {
framework::TensorCopy(input, ctx.GetPlace(), output);
@@ -490,21 +582,21 @@ static void Interpolate2DCPUFwd(const framework::ExecutionContext& ctx,
if ("bilinear" == interp_method) {
BilinearInterpolation<T>(input, output, ratio_h, ratio_w, in_h, in_w, n, c,
out_h, out_w, align_corners, align_mode);
out_h, out_w, align_corners, align_mode,
data_layout);
} else if ("nearest" == interp_method) {
NearestNeighborInterpolate<T>(input, output, ratio_h, ratio_w, n, c, out_h,
out_w, align_corners);
out_w, align_corners, data_layout);
}
}
template <typename T>
static void Interpolate3DCPUFwd(const framework::ExecutionContext& ctx,
const Tensor& input, Tensor* output) {
const int n = input.dims()[0];
const int c = input.dims()[1];
const int in_d = input.dims()[2];
const int in_h = input.dims()[3];
const int in_w = input.dims()[4];
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
ExtractNCDWH(input.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
@@ -552,7 +644,15 @@ static void Interpolate3DCPUFwd(const framework::ExecutionContext& ctx,
PADDLE_ENFORCE_GT(
out_w, 0,
"out_w in Attr(out_shape) of Op(interpolate) should be greater than 0.");
output->mutable_data<T>({n, c, out_d, out_h, out_w}, ctx.GetPlace());
framework::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {n, c, out_d, out_h, out_w};
} else {
dim_out = {n, out_d, out_h, out_w, c};
}
output->mutable_data<T>(dim_out, ctx.GetPlace());
if (in_d == out_d && in_h == out_h && in_w == out_w) {
framework::TensorCopy(input, ctx.GetPlace(), output);
@@ -578,7 +678,7 @@ static void Interpolate3DCPUFwd(const framework::ExecutionContext& ctx,
if ("trilinear" == interp_method) {
TrilinearInterpolation<T>(input, output, ratio_d, ratio_h, ratio_w, in_d,
in_h, in_w, n, c, out_d, out_h, out_w,
align_corners, align_mode);
align_corners, align_mode, data_layout);
}
}
@@ -586,10 +686,10 @@ template <typename T>
static void Interpolate2DCPUBwd(const framework::ExecutionContext& ctx,
Tensor* input_grad, const Tensor& output_grad) {
auto* input = ctx.Input<Tensor>("X");
const int n = input->dims()[0];
const int c = input->dims()[1];
const int in_h = input->dims()[2];
const int in_w = input->dims()[3];
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
@@ -623,7 +723,14 @@ static void Interpolate2DCPUBwd(const framework::ExecutionContext& ctx,
out_w = new_size[1];
}
input_grad->mutable_data<T>({n, c, in_h, in_w}, ctx.GetPlace());
framework::DDim dim_grad;
if (data_layout == DataLayout::kNCHW) {
dim_grad = {n, c, in_h, in_w};
} else {
dim_grad = {n, in_h, in_w, c};
}
input_grad->mutable_data<T>(dim_grad, ctx.GetPlace());
auto& device_ctx = ctx.template device_context<platform::CPUDeviceContext>();
math::SetConstant<platform::CPUDeviceContext, T> zero;
zero(device_ctx, input_grad, static_cast<T>(0.0));
@@ -647,10 +754,11 @@ static void Interpolate2DCPUBwd(const framework::ExecutionContext& ctx,
if ("bilinear" == interp_method) {
BilinearInterpolationGrad<T>(output_grad, input_grad, ratio_h, ratio_w,
in_h, in_w, n, c, out_h, out_w, align_corners,
align_mode);
align_mode, data_layout);
} else if ("nearest" == interp_method) {
NearestNeighborInterpolateGrad<T>(output_grad, input_grad, ratio_h, ratio_w,
n, c, out_h, out_w, align_corners);
n, c, out_h, out_w, align_corners,
data_layout);
}
}
@@ -658,11 +766,10 @@ template <typename T>
static void Interpolate3DCPUBwd(const framework::ExecutionContext& ctx,
Tensor* input_grad, const Tensor& output_grad) {
auto* input = ctx.Input<Tensor>("X");
const int n = input->dims()[0];
const int c = input->dims()[1];
const int in_d = input->dims()[2];
const int in_h = input->dims()[3];
const int in_w = input->dims()[4];
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
@@ -700,7 +807,13 @@ static void Interpolate3DCPUBwd(const framework::ExecutionContext& ctx,
out_w = new_size[2];
}
input_grad->mutable_data<T>({n, c, in_d, in_h, in_w}, ctx.GetPlace());
framework::DDim dim_grad;
if (data_layout == DataLayout::kNCHW) {
dim_grad = {n, c, in_d, in_h, in_w};
} else {
dim_grad = {n, in_d, in_h, in_w, c};
}
input_grad->mutable_data<T>(dim_grad, ctx.GetPlace());
auto& device_ctx = ctx.template device_context<platform::CPUDeviceContext>();
math::SetConstant<platform::CPUDeviceContext, T> zero;
zero(device_ctx, input_grad, static_cast<T>(0.0));
@@ -727,9 +840,9 @@ static void Interpolate3DCPUBwd(const framework::ExecutionContext& ctx,
}
if ("trilinear" == interp_method) {
TrilinearInterpolationGrad<T>(output_grad, input_grad, ratio_d, ratio_h,
ratio_w, in_d, in_h, in_w, n, c, out_d, out_h,
out_w, align_corners, align_mode);
TrilinearInterpolationGrad<T>(
output_grad, input_grad, ratio_d, ratio_h, ratio_w, in_d, in_h, in_w, n,
c, out_d, out_h, out_w, align_corners, align_mode, data_layout);
}
}
@@ -8019,13 +8019,15 @@ def image_resize(input,
resample='BILINEAR',
actual_shape=None,
align_corners=True,
align_mode=1):
align_mode=1,
data_format='NCHW'):
"""
**Resize a Batch of Images**
The input must be a tensor of the shape (num_batches, channels, in_h, in_w)
or (num_batches, channels, in_d, in_h, in_w), and the resizing only applies
on the last two/three dimensions(depth, hight and width).
The input must be a 4-D Tensor of the shape (num_batches, channels, in_h, in_w)
or (num_batches, in_h, in_w, channels), or a 5-D Tensor of the shape
(num_batches, channels, in_d, in_h, in_w) or (num_batches, in_d, in_h, in_w, channels),
and the resizing applies only to the depth (for 5-D input), height and width dimensions.
**Warning:** the parameter :attr:`actual_shape` will be deprecated in the
future and only use :attr:`out_shape` instead.
@@ -8144,16 +8146,13 @@
Args:
input (Variable): The input tensor of image resize layer,
This is a 4-D tensor of the shape
(num_batches, channels, in_h, in_w) or a
5-D tensor of the shape
(num_batches, channls, in_d, in_h, in_w).
input (Variable): 4-D or 5-D Tensor, its data type is float32, float64, or uint8,
its data format is specified by :attr:`data_format`.
out_shape(list|tuple|Variable|None): Output shape of image resize
layer, the shape is (out_h, out_w) when input is a 4-D tensor and is
(out_d, out_h, out_w) when input is a 5-D tensor. Default: None. If
a list, each element can be an integer or a tensor Variable of shape: [1].
If a tesnosr Variable, its dimensions size should be a 1.
layer, the shape is (out_h, out_w) when input is a 4-D Tensor and is
(out_d, out_h, out_w) when input is a 5-D Tensor. Default: None. If
a list, each element can be an integer or a Tensor Variable of shape: [1].
If a Tensor Variable, its dimension size should be 1.
scale(float|Variable|None): The multiplier for the input height or width. At
least one of :attr:`out_shape` or :attr:`scale` must be set.
And :attr:`out_shape` has a higher priority than :attr:`scale`.
@@ -8181,12 +8180,16 @@
Default: True
align_mode(int) : An optional flag for bilinear interpolation. Can be \'0\'
for src_idx = scale*(dst_index+0.5)-0.5 , can be \'1\' for
src_idx = scale*dst_index .
src_idx = scale*dst_index.
data_format(str, optional): NCHW(num_batches, channels, height, width) or
NHWC(num_batches, height, width, channels) for 4-D Tensor,
NCDHW(num_batches, channels, depth, height, width) or
NDHWC(num_batches, depth, height, width, channels) for 5-D Tensor.
Default: 'NCHW'.
Returns:
Variable: The output is a 4-D tensor of the shape
(num_batches, channls, out_h, out_w) or a 5-D tensor of the shape
(num_batches, channels, out_d, out_h, out_w).
A 4-D Tensor of the shape (num_batches, channels, out_h, out_w) or (num_batches, out_h, out_w, channels),
or a 5-D Tensor of the shape (num_batches, channels, out_d, out_h, out_w) or (num_batches, out_d, out_h, out_w, channels).
Raises:
TypeError: out_shape should be a list or tuple or Variable.
@@ -8201,6 +8204,7 @@ def image_resize(input,
ValueError: scale should be greater than zero.
TypeError: align_corners should be a bool value
ValueError: align_mode can only be '0' or '1'
ValueError: data_format can only be 'NCHW', 'NHWC', 'NCDHW' or 'NDHWC'.
Examples:
.. code-block:: python
@@ -8259,9 +8263,23 @@
helper = LayerHelper('{}_interp'.format(resample_type), **locals())
dtype = helper.input_dtype()
if len(input.shape) == 4 and data_format not in ['NCHW', 'NHWC']:
raise ValueError(
"Got wrong value for param `data_format`: " + data_format +
" received but only `NCHW` or `NHWC` supported for 4-D input.")
elif len(input.shape) == 5 and data_format not in ['NCDHW', 'NDHWC']:
raise ValueError(
"Got wrong value for param `data_format`: " + data_format +
" received but only `NCDHW` or `NDHWC` supported for 5-D input.")
def _is_list_or_turple_(data):
return (isinstance(data, list) or isinstance(data, tuple))
if data_format == 'NCHW' or data_format == 'NCDHW':
data_layout = 'NCHW'
if data_format == 'NHWC' or data_format == 'NDHWC':
data_layout = 'NHWC'
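    # NHWC and NDHWC collapse to the same attribute value: the C++ operator
    # only distinguishes channel_first from channel_last, and the input rank
    # already selects 2-D vs 3-D interpolation.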
inputs = {"X": input}
attrs = {
"out_d": -1,
@@ -8269,7 +8287,8 @@ def image_resize(input,
"out_w": -1,
"interp_method": resample_type,
"align_corners": align_corners,
"align_mode": align_mode
"align_mode": align_mode,
"data_layout": data_layout
}
if out_shape is not None:
@@ -8368,7 +8387,8 @@ def resize_bilinear(input,
name=None,
actual_shape=None,
align_corners=True,
align_mode=1):
align_mode=1,
data_format='NCHW'):
"""
Resize input by performing bilinear interpolation based on given
output shape which specified by actual_shape, out_shape and scale
@@ -8414,31 +8434,24 @@
H_out = (H_{in}+0.5) * scale_{factor} - 0.5
W_out = (W_{in}+0.5) * scale_{factor} - 0.5
else:
input : (N,C,H_in,W_in)
output: (N,C,H_out,W_out) where:
H_out = H_{in} * scale_{factor}
W_out = W_{in} * scale_{factor}
Args:
input(${x_type}): input should be a 4-D tensor of shape
(num_batches, channels, in_h, in_w).
input(${x_type}): 4-D Tensor, its data type is float32, float64, or uint8,
its data format is specified by :attr:`data_format`.
out_shape(list|tuple|Variable|None): Output shape of resize bilinear
layer, the shape is (out_h, out_w). Default: None. If a list, each
element can be an integer or a tensor Variable with shape: [1]. If a
tensor Variable, its dimension size should be 1.
element can be an integer or a Tensor Variable with shape: [1]. If a
Tensor Variable, its dimension size should be 1.
scale(float|Variable|None): The multiplier for the input height or width. At
least one of :attr:`out_shape` or :attr:`scale` must be set.
And :attr:`out_shape` has a higher priority than :attr:`scale`.
Default: None.
name(str|None): The output variable name.
actual_shape(Variable): An optional input to specify output shape
dynamically. If provided, image resize
@@ -8455,9 +8468,12 @@ def resize_bilinear(input,
Default: None
align_corners(bool): ${align_corners_comment}
align_mode(bool): ${align_mode_comment}
data_format(str, optional): NCHW(num_batches, channels, height, width) or
NHWC(num_batches, height, width, channels). Default: 'NCHW'.
Returns:
A 4-D tensor in shape of (num_batches, channels, out_h, out_w)
A 4-D Tensor in shape of (num_batches, channels, out_h, out_w) or
(num_batches, out_h, out_w, channels).
Examples:
.. code-block:: python
@@ -8491,7 +8507,7 @@ def resize_bilinear(input,
"""
return image_resize(input, out_shape, scale, name, 'BILINEAR', actual_shape,
align_corners, align_mode)
align_corners, align_mode, data_format)
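A short usage sketch of the channel_last path (assumed shapes, using fluid's implicit batch dimension):

    import paddle.fluid as fluid

    # NHWC input: (num_batches, in_h, in_w, channels)
    x = fluid.layers.data(name="x", shape=[6, 6, 3], dtype="float32")
    out = fluid.layers.resize_bilinear(x, scale=2., data_format='NHWC')
    # out.shape = [-1, 12, 12, 3]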
@templatedoc(op_type="trilinear_interp")
@@ -8501,7 +8517,8 @@ def resize_trilinear(input,
name=None,
actual_shape=None,
align_corners=True,
align_mode=1):
align_mode=1,
data_format='NCDHW'):
"""
Resize input by performing trilinear interpolation based on given
output shape which specified by actual_shape, out_shape and scale
@@ -8538,6 +8555,7 @@ def resize_trilinear(input,
Trilinear interpolation:
if:
align_corners = False , align_mode = 0
input : (N,C,D_in,H_in,W_in)
@@ -8547,7 +8565,6 @@ def resize_trilinear(input,
H_out = (H_{in}+0.5) * scale_{factor} - 0.5
W_out = (W_{in}+0.5) * scale_{factor} - 0.5
else:
input : (N,C,D_in,H_in,W_in)
@@ -8557,22 +8574,17 @@
H_out = H_{in} * scale_{factor}
W_out = W_{in} * scale_{factor}
Args:
input(${x_type}): input should be a 5-D tensor of shape
(num_batches, channls, in_d, in_h, in_w).
input(${x_type}): 5-D Tensor, its data type is float32, float64, or uint8,
its data format is specified by :attr:`data_format`.
out_shape(list|tuple|Variable|None): Output shape of resize trilinear
layer, the shape is (out_d, out_h, out_w). Default: None. If a list,
each element can be an integer or a tensor Variable with shape: [1]. If
a tensor Variable, its dimension size should be 1.
each element can be an integer or a Tensor Variable with shape: [1]. If
a Tensor Variable, its dimension size should be 1.
scale(float|Variable|None): The multiplier for the input depth, height or width.
At least one of :attr:`out_shape` or :attr:`scale` must be set.
And :attr:`out_shape` has a higher priority than :attr:`scale`.
Default: None.
name(str|None): The output variable name.
actual_shape(Variable): An optional input to specify output shape
dynamically. If provided, image resize
@@ -8589,9 +8601,13 @@
Default: None
align_corners(bool): ${align_corners_comment}
align_mode(bool): ${align_mode_comment}
data_format(str, optional): NCDHW(num_batches, channels, depth, height, width) or
NDHWC(num_batches, depth, height, width, channels).
Default: 'NCDHW'.
Returns:
A 5-D tensor in shape (num_batches, channels, out_d, out_h, out_w)
A 5-D Tensor in shape of (num_batches, channels, out_d, out_h, out_w) or
(num_batches, out_d, out_h, out_w, channels).
Examples:
.. code-block:: python
@@ -8622,11 +8638,10 @@ def resize_trilinear(input,
scale_tensor = fluid.layers.data(name="scale", shape=[1], dtype="float32", append_batch_size=False)
out4 = fluid.layers.resize_trilinear(input, scale=scale_tensor)
# out4.shape = [-1, 3, -1, -1, -1]
"""
return image_resize(input, out_shape, scale, name, 'TRILINEAR',
actual_shape, align_corners, align_mode)
actual_shape, align_corners, align_mode, data_format)
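A short usage sketch of the channel_last path (assumed shapes, mirroring the test cases further below):

    import paddle.fluid as fluid

    # NDHWC input: (num_batches, in_d, in_h, in_w, channels)
    y = fluid.layers.data(name="y", shape=[6, 9, 4, 3], dtype="float32")
    out = fluid.layers.resize_trilinear(
        y, out_shape=[12, 18, 8], data_format='NDHWC')
    # out.shape = [-1, 12, 18, 8, 3]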
@templatedoc(op_type="nearest_interp")
@@ -8635,12 +8650,12 @@ def resize_nearest(input,
scale=None,
name=None,
actual_shape=None,
align_corners=True):
align_corners=True,
data_format='NCHW'):
"""
Resize input by performing nearest neighbor interpolation in both the
3rd dimension(in height direction) and the 4th dimension(in width
direction) based on given output shape which is specified by actual_shape,
out_shape and scale in priority order.
height direction and the width direction based on given output shape
which is specified by actual_shape, out_shape and scale in priority order.
**Warning:** the parameter :attr:`actual_shape` will be deprecated in the
future and only use :attr:`out_shape` instead.
@@ -8652,14 +8667,12 @@ def resize_nearest(input,
For scale:
if align_corners = True && out_size > 1 :
scale_factor = (in_size-1.0)/(out_size-1.0)
else:
scale_factor = float(in_size/out_size)
Nearest neighbor interpolation:
if:
@@ -8685,19 +8698,16 @@
https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation
Args:
input(${x_type}): input should be a 4-D tensor of shape
(num_batches, channls, in_h, in_w).
input(${x_type}): 4-D Tensor, its data type is float32, float64, or uint8,
its data format is specified by :attr:`data_format`.
out_shape(list|tuple|Variable|None): Output shape of resize nearest
layer, the shape is (out_h, out_w). Default: None. If a list, each
element can be an integer or a Tensor Variable with shape: [1]. If a
Tensor Variable, its dimension size should be 1.
scale(float|Variable|None): The multiplier for the input height or width. At
least one of :attr:`out_shape` or :attr:`scale` must be set.
And :attr:`out_shape` has a higher priority than :attr:`scale`.
Default: None.
name(str|None): The output variable name.
actual_shape(Variable): An optional input to specify output shape
dynamically. If provided, image resize
@@ -8713,9 +8723,13 @@
errors would occur in the graph constructing stage.
Default: None
align_corners(bool): ${align_corners_comment}
data_format(str, optional): NCHW(num_batches, channels, height, width) or
NHWC(num_batches, height, width, channels).
Default: 'NCHW'.
Returns:
A 4-D tensor in shape of (num_batches, channels, out_h, out_w)
A 4-D Tensor in shape of (num_batches, channels, out_h, out_w) or
(num_batches, out_h, out_w, channels).
Examples:
.. code-block:: python
@@ -8746,11 +8760,18 @@
scale_tensor = fluid.layers.data(name="scale", shape=[1], dtype="float32", append_batch_size=False)
out4 = fluid.layers.resize_nearest(input, scale=scale_tensor)
# out4.shape = [-1, 3, -1, -1]
"""
return image_resize(input, out_shape, scale, name, 'NEAREST', actual_shape,
align_corners)
return image_resize(
input,
out_shape,
scale,
name,
'NEAREST',
actual_shape,
align_corners,
align_mode=1,
data_format=data_format)
def image_resize_short(input, out_short_len, resample='BILINEAR'):
@@ -27,8 +27,11 @@ def bilinear_interp_np(input,
out_size=None,
actual_shape=None,
align_corners=True,
align_mode=0):
align_mode=0,
data_layout='NCHW'):
"""bilinear interpolation implement in shape [N, C, H, W]"""
if data_layout == "NHWC":
input = np.transpose(input, (0, 3, 1, 2)) # NHWC => NCHW
if out_size is not None:
out_h = out_size[0]
out_w = out_size[1]
@@ -83,6 +86,10 @@ def bilinear_interp_np(input,
w1lambda*input[:, :, h, w+wid]) + \
h1lambda*(w2lambda*input[:, :, h+hid, w] +
w1lambda*input[:, :, h+hid, w+wid])
if data_layout == "NHWC":
out = np.transpose(out, (0, 2, 3, 1)) # NCHW => NHWC
return out.astype(input.dtype)
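Since the NHWC branch just transposes to NCHW, reuses the NCHW math, and transposes back, the two layouts can be cross-checked directly (a small hedged sanity sketch):

    import numpy as np

    x_nchw = np.random.random((2, 3, 4, 4)).astype("float32")
    x_nhwc = np.transpose(x_nchw, (0, 2, 3, 1))
    ref = bilinear_interp_np(x_nchw, out_h=8, out_w=8)
    got = bilinear_interp_np(x_nhwc, out_h=8, out_w=8, data_layout='NHWC')
    assert np.allclose(got, np.transpose(ref, (0, 2, 3, 1)))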
@@ -90,20 +97,28 @@ class TestBilinearInterpOp(OpTest):
def setUp(self):
self.out_size = None
self.actual_shape = None
self.data_layout = 'NCHW'
self.init_test_case()
self.op_type = "bilinear_interp"
input_np = np.random.random(self.input_shape).astype("float32")
if self.data_layout == "NCHW":
in_h = self.input_shape[2]
in_w = self.input_shape[3]
else:
in_h = self.input_shape[1]
in_w = self.input_shape[2]
if self.scale > 0:
out_h = int(self.input_shape[2] * self.scale)
out_w = int(self.input_shape[3] * self.scale)
out_h = int(in_h * self.scale)
out_w = int(in_w * self.scale)
else:
out_h = self.out_h
out_w = self.out_w
output_np = bilinear_interp_np(input_np, out_h, out_w, self.out_size,
self.actual_shape, self.align_corners,
self.align_mode)
self.align_mode, self.data_layout)
self.inputs = {'X': input_np}
if self.out_size is not None:
self.inputs['OutSize'] = self.out_size
@@ -116,7 +131,8 @@ class TestBilinearInterpOp(OpTest):
'scale': self.scale,
'interp_method': self.interp_method,
'align_corners': self.align_corners,
'align_mode': self.align_mode
'align_mode': self.align_mode,
'data_layout': self.data_layout
}
self.outputs = {'Out': output_np}
@@ -229,6 +245,19 @@ class TestBilinearInterpActualShape(TestBilinearInterpOp):
self.align_mode = 1
class TestBilinearInterpDataLayout(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [2, 4, 4, 3]
self.out_h = 2
self.out_w = 2
self.scale = 0.
self.out_size = np.array([3, 3]).astype("int32")
self.align_corners = True
self.align_mode = 1
self.data_layout = "NHWC"
class TestBilinearInterpOpUint8(OpTest):
def setUp(self):
self.out_size = None
@@ -26,8 +26,11 @@ def nearest_neighbor_interp_np(X,
out_w,
out_size=None,
actual_shape=None,
align_corners=True):
align_corners=True,
data_layout='NCHW'):
"""nearest neighbor interpolation implement in shape [N, C, H, W]"""
if data_layout == "NHWC":
X = np.transpose(X, (0, 3, 1, 2)) # NHWC => NCHW
if out_size is not None:
out_h = out_size[0]
out_w = out_size[1]
@@ -63,6 +66,9 @@
in_j = int(ratio_w * j)
out[:, :, i, j] = X[:, :, in_i, in_j]
if data_layout == "NHWC":
out = np.transpose(out, (0, 2, 3, 1)) # NCHW => NHWC
return out.astype(X.dtype)
@@ -70,20 +76,28 @@ class TestNearestInterpOp(OpTest):
def setUp(self):
self.out_size = None
self.actual_shape = None
self.data_layout = 'NCHW'
self.init_test_case()
self.op_type = "nearest_interp"
input_np = np.random.random(self.input_shape).astype("float32")
if self.data_layout == "NCHW":
in_h = self.input_shape[2]
in_w = self.input_shape[3]
else:
in_h = self.input_shape[1]
in_w = self.input_shape[2]
if self.scale > 0:
out_h = int(self.input_shape[2] * self.scale)
out_w = int(self.input_shape[3] * self.scale)
out_h = int(in_h * self.scale)
out_w = int(in_w * self.scale)
else:
out_h = self.out_h
out_w = self.out_w
output_np = nearest_neighbor_interp_np(input_np, out_h, out_w,
self.out_size, self.actual_shape,
self.align_corners)
output_np = nearest_neighbor_interp_np(
input_np, out_h, out_w, self.out_size, self.actual_shape,
self.align_corners, self.data_layout)
self.inputs = {'X': input_np}
if self.out_size is not None:
self.inputs['OutSize'] = self.out_size
@@ -95,6 +109,7 @@ class TestNearestInterpOp(OpTest):
'scale': self.scale,
'interp_method': self.interp_method,
'align_corners': self.align_corners,
'data_layout': self.data_layout
}
self.outputs = {'Out': output_np}
@@ -198,6 +213,18 @@ class TestNearestNeighborInterpActualShape(TestNearestInterpOp):
self.align_corners = True
class TestNearestNeighborInterpDataLayout(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [2, 4, 4, 5]
self.out_h = 2
self.out_w = 2
self.scale = 0.
self.out_size = np.array([3, 8]).astype("int32")
self.align_corners = True
self.data_layout = "NHWC"
class TestNearestInterpOpUint8(OpTest):
def setUp(self):
self.out_size = None
@@ -399,6 +426,7 @@ class TestNearestInterp_attr_tensor_Case3(TestNearestInterpOp_attr_tensor):
class TestNearestAPI(OpTest):
def test_case(self):
x = fluid.layers.data(name="x", shape=[3, 6, 6], dtype="float32")
y = fluid.layers.data(name="y", shape=[6, 6, 3], dtype="float32")
dim = fluid.layers.data(
name="dim", shape=[1], dtype="int32", append_batch_size=False)
@@ -418,7 +446,8 @@
dtype="float32",
append_batch_size=False)
out1 = fluid.layers.resize_nearest(x, out_shape=[12, 12])
out1 = fluid.layers.resize_nearest(
y, out_shape=[12, 12], data_format='NHWC')
out2 = fluid.layers.resize_nearest(x, out_shape=[12, dim])
out3 = fluid.layers.resize_nearest(x, out_shape=shape_tensor)
out4 = fluid.layers.resize_nearest(
@@ -436,6 +465,7 @@
results = exe.run(fluid.default_main_program(),
feed={
"x": x_data,
"y": np.transpose(x_data, (0, 2, 3, 1)),
"dim": dim_data,
"shape_tensor": shape_data,
"actual_size": actual_size_data,
@@ -446,8 +476,20 @@
expect_res = nearest_neighbor_interp_np(
x_data, out_h=12, out_w=12, align_corners=True)
for res in results:
self.assertTrue(np.allclose(res, expect_res))
self.assertTrue(
np.allclose(results[0], np.transpose(expect_res, (0, 2, 3, 1))))
for i in range(len(results) - 1):
self.assertTrue(np.allclose(results[i + 1], expect_res))
    def test_exception(self):
        # for 4-D input, data_format can only be NCHW or NHWC
        input = fluid.layers.data(
            name="input", shape=[3, 6, 6], dtype="float32")

        def attr_data_format():
            out = fluid.layers.resize_nearest(
                input, out_shape=[4, 8], data_format='NDHWC')

        self.assertRaises(ValueError, attr_data_format)
if __name__ == "__main__":
@@ -28,8 +28,11 @@ def trilinear_interp_np(input,
out_size=None,
actual_shape=None,
align_corners=True,
align_mode=0):
align_mode=0,
data_layout='NCDHW'):
"""trilinear interpolation implement in shape [N, C, D, H, W]"""
if data_layout == "NDHWC":
input = np.transpose(input, (0, 4, 1, 2, 3)) # NDHWC => NCDHW
if out_size is not None:
out_d = out_size[0]
out_h = out_size[1]
@@ -114,6 +117,9 @@
w1lambda * input[:, :, d+did, h, w+wid]) + \
h1lambda * (w2lambda * input[:, :, d+did, h+hid, w] + \
w1lambda * input[:, :, d+did, h+hid, w+wid]))
if data_layout == "NDHWC":
out = np.transpose(out, (0, 2, 3, 4, 1)) # NCDHW => NDHWC
return out.astype(input.dtype)
@@ -121,28 +127,42 @@ class TestTrilinearInterpOp(OpTest):
def setUp(self):
self.out_size = None
self.actual_shape = None
self.data_layout = 'NCDHW'
self.init_test_case()
self.op_type = "trilinear_interp"
input_np = np.random.random(self.input_shape).astype("float32")
if self.data_layout == "NCDHW":
in_d = self.input_shape[2]
in_h = self.input_shape[3]
in_w = self.input_shape[4]
else:
in_d = self.input_shape[1]
in_h = self.input_shape[2]
in_w = self.input_shape[3]
if self.scale > 0:
out_d = int(self.input_shape[2] * self.scale)
out_h = int(self.input_shape[3] * self.scale)
out_w = int(self.input_shape[4] * self.scale)
out_d = int(in_d * self.scale)
out_h = int(in_h * self.scale)
out_w = int(in_w * self.scale)
else:
out_d = self.out_d
out_h = self.out_h
out_w = self.out_w
output_np = trilinear_interp_np(input_np, out_d, out_h, out_w,
self.out_size, self.actual_shape,
self.align_corners, self.align_mode)
output_np = trilinear_interp_np(
input_np, out_d, out_h, out_w, self.out_size, self.actual_shape,
self.align_corners, self.align_mode, self.data_layout)
self.inputs = {'X': input_np}
if self.out_size is not None:
self.inputs['OutSize'] = self.out_size
if self.actual_shape is not None:
self.inputs['OutSize'] = self.actual_shape
# the C++ end treats NCDHW the same way as NCHW
if self.data_layout == 'NCDHW':
data_layout = 'NCHW'
else:
data_layout = 'NHWC'
self.attrs = {
'out_d': self.out_d,
'out_h': self.out_h,
@@ -150,7 +170,8 @@ class TestTrilinearInterpOp(OpTest):
'scale': self.scale,
'interp_method': self.interp_method,
'align_corners': self.align_corners,
'align_mode': self.align_mode
'align_mode': self.align_mode,
'data_layout': data_layout
}
self.outputs = {'Out': output_np}
@@ -284,6 +305,20 @@ class TestTrilinearInterpActualShape(TestTrilinearInterpOp):
self.align_mode = 1
class TestTrilinearInterpDatalayout(TestTrilinearInterpOp):
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [2, 4, 4, 4, 3]
self.out_d = 2
self.out_h = 2
self.out_w = 2
self.scale = 0.
self.out_size = np.array([3, 3, 3]).astype("int32")
self.align_corners = True
self.align_mode = 1
self.data_layout = "NDHWC"
class TestTrilinearInterpOpUint8(OpTest):
def setUp(self):
self.out_size = None
@@ -536,6 +571,7 @@ class TestTrilinearInterp_attr_tensor_Case3(TestTrilinearInterpOp_attr_tensor):
class TestTrilinearInterpAPI(OpTest):
def test_case(self):
x = fluid.layers.data(name="x", shape=[3, 6, 9, 4], dtype="float32")
y = fluid.layers.data(name="y", shape=[6, 9, 4, 3], dtype="float32")
dim = fluid.layers.data(name="dim", shape=[1], dtype="int32")
shape_tensor = fluid.layers.data(
@@ -554,7 +590,8 @@
dtype="float32",
append_batch_size=False)
out1 = fluid.layers.resize_trilinear(x, out_shape=[12, 18, 8])
out1 = fluid.layers.resize_trilinear(
y, out_shape=[12, 18, 8], data_format='NDHWC')
out2 = fluid.layers.resize_trilinear(x, out_shape=[12, dim, 8])
out3 = fluid.layers.resize_trilinear(x, out_shape=shape_tensor)
out4 = fluid.layers.resize_trilinear(
@@ -572,6 +609,7 @@
results = exe.run(fluid.default_main_program(),
feed={
"x": x_data,
"y": np.transpose(x_data, (0, 2, 3, 4, 1)),
"dim": dim_data,
"shape_tensor": shape_data,
"actual_size": actual_size_data,
@@ -582,8 +620,20 @@
expect_res = trilinear_interp_np(
x_data, out_d=12, out_h=18, out_w=8, align_mode=1)
for res in results:
self.assertTrue(np.allclose(res, expect_res))
self.assertTrue(
np.allclose(results[0], np.transpose(expect_res, (0, 2, 3, 4, 1))))
for i in range(len(results) - 1):
self.assertTrue(np.allclose(results[i + 1], expect_res))
    def test_exception(self):
        # for 5-D input, data_format can only be NCDHW or NDHWC
        input = fluid.layers.data(
            name="input", shape=[3, 6, 9, 4], dtype="float32")

        def attr_data_format():
            out = fluid.layers.resize_trilinear(
                input, out_shape=[4, 8, 4], data_format='NHWC')

        self.assertRaises(ValueError, attr_data_format)
if __name__ == "__main__":