未验证 提交 f86fead6 编写于 作者: K Kaipeng Deng 提交者: GitHub

Add trilinear_interp OP (#18711)

* add trilinear interp. test=develop

* fix unittest. test=develop

* add python api and test_layers. test=develop

* refine API.spec. test=develop

* fix format. test=develop

* add python API test. test=develop

* format code. test=develop

* refine code structure. test=develop

* fix format

* fix doc. test=develop

* fix coverage. test=develop

* fix format. test=develop
上级 c2063217
......@@ -173,10 +173,11 @@ paddle.fluid.layers.label_smooth (ArgSpec(args=['label', 'prior_dist', 'epsilon'
paddle.fluid.layers.roi_pool (ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1, 1, 1.0)), ('document', '49368d724023a66b41b0071be41c0ba5'))
paddle.fluid.layers.roi_align (ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale', 'sampling_ratio', 'name'], varargs=None, keywords=None, defaults=(1, 1, 1.0, -1, None)), ('document', '9a7a3b88a4fae41d58d3ca9b10ba0591'))
paddle.fluid.layers.dice_loss (ArgSpec(args=['input', 'label', 'epsilon'], varargs=None, keywords=None, defaults=(1e-05,)), ('document', '7e8e4bf1f0f8612961ed113e8af8f0c5'))
paddle.fluid.layers.image_resize (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'resample', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, 'BILINEAR', None, True, 1)), ('document', 'a29488d94d9a4bc4434d8a3529b4c6fe'))
paddle.fluid.layers.image_resize (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'resample', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, 'BILINEAR', None, True, 1)), ('document', '8cfc4f69dbbedb687b6c20732aa8f09e'))
paddle.fluid.layers.image_resize_short (ArgSpec(args=['input', 'out_short_len', 'resample'], varargs=None, keywords=None, defaults=('BILINEAR',)), ('document', 'bd97ebfe4bdf5110a5fcb8ecb626a447'))
paddle.fluid.layers.resize_bilinear (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, None, True, 1)), ('document', '548c7c2ead5771d15abbaad505f901e9'))
paddle.fluid.layers.resize_nearest (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners'], varargs=None, keywords=None, defaults=(None, None, None, None, True)), ('document', 'b7d810d1e251c5957c1efa6aa699d2d0'))
paddle.fluid.layers.resize_bilinear (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, None, True, 1)), ('document', '832b2412652d84a6631b1012c6e2d18b'))
paddle.fluid.layers.resize_trilinear (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, None, True, 1)), ('document', '4836e98a634f6fbea26d0cdaa303f867'))
paddle.fluid.layers.resize_nearest (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners'], varargs=None, keywords=None, defaults=(None, None, None, None, True)), ('document', '32ffc0e8818d7319ed1bf63a791e985d'))
paddle.fluid.layers.gather (ArgSpec(args=['input', 'index', 'overwrite'], varargs=None, keywords=None, defaults=(True,)), ('document', 'f985c9b66e3aec96fa753a8eb44c991c'))
paddle.fluid.layers.scatter (ArgSpec(args=['input', 'index', 'updates', 'name', 'overwrite'], varargs=None, keywords=None, defaults=(None, True)), ('document', '69b22affd4a6326502af166f04c095ab'))
paddle.fluid.layers.sequence_scatter (ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'abe3f714120117a5a3d3e639853932bf'))
......
......@@ -20,6 +20,85 @@ namespace operators {
using framework::Tensor;
// Validates the attributes of a 2-D interpolation (Input(X) has 4 dims) and
// infers the shape of Out.
//
// The output height/width are derived from the "scale" attribute when it is
// positive, otherwise they are read from the "out_h"/"out_w" attributes, which
// must then be greater than 0. When an OutSize input is present at runtime,
// only OutSize's own shape is validated and X's LoD is shared with Out; Out's
// dims are not set by this function in that case.
static void Interpolate2DInferShapeCheck(framework::InferShapeContext* ctx) {
  auto x_dims = ctx->GetInputDim("X");
  auto method = ctx->Attrs().Get<std::string>("interp_method");
  PADDLE_ENFORCE(
      "bilinear" == method || "nearest" == method,
      "Interpolation method can only be \"bilinear\" or \"nearest\" when "
      "Input(X) dimension is 4");

  int out_h, out_w;
  float scale = ctx->Attrs().Get<float>("scale");
  if (scale > 0) {
    // Round down; any non-positive result (e.g. when the input extent is -1)
    // is reported as -1.
    auto scaled_extent = [scale](int64_t extent) {
      const int scaled = static_cast<int>(extent * scale);
      return scaled > 0 ? scaled : -1;
    };
    out_h = scaled_extent(x_dims[2]);
    out_w = scaled_extent(x_dims[3]);
  } else {
    out_h = ctx->Attrs().Get<int>("out_h");
    out_w = ctx->Attrs().Get<int>("out_w");
    PADDLE_ENFORCE_GT(out_h, 0, "out_h should be greater than 0.");
    PADDLE_ENFORCE_GT(out_w, 0, "out_w should be greater than 0.");
  }

  const bool has_runtime_out_size =
      ctx->HasInput("OutSize") && ctx->IsRuntime();
  if (!has_runtime_out_size) {
    std::vector<int64_t> out_dims({x_dims[0], x_dims[1], out_h, out_w});
    ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
    return;
  }

  // OutSize must be a 1-D tensor holding exactly two entries.
  auto out_size_dims = ctx->GetInputDim("OutSize");
  PADDLE_ENFORCE_EQ(out_size_dims.size(), 1,
                    "OutSize's dimension size must be 1");
  PADDLE_ENFORCE_EQ(out_size_dims[0], 2, "OutSize's dim[0] must be 2");
  ctx->ShareLoD("X", "Out");
}
// Validates the attributes of a 3-D interpolation (Input(X) has 5 dims) and
// infers the shape of Out.
//
// The output depth/height/width are derived from the "scale" attribute when it
// is positive, otherwise they are read from the "out_d"/"out_h"/"out_w"
// attributes, which must then be greater than 0. When an OutSize input is
// present at runtime, only OutSize's own shape is validated and X's LoD is
// shared with Out; Out's dims are not set by this function in that case.
static void Interpolate3DInferShapeCheck(framework::InferShapeContext* ctx) {
  auto x_dims = ctx->GetInputDim("X");
  auto method = ctx->Attrs().Get<std::string>("interp_method");
  PADDLE_ENFORCE("trilinear" == method,
                 "Interpolation method can only be \"trilinear\" when Input(X) "
                 "dimension is 5");

  int out_d, out_h, out_w;
  float scale = ctx->Attrs().Get<float>("scale");
  if (scale > 0) {
    // Round down; any non-positive result (e.g. when the input extent is -1)
    // is reported as -1.
    auto scaled_extent = [scale](int64_t extent) {
      const int scaled = static_cast<int>(extent * scale);
      return scaled > 0 ? scaled : -1;
    };
    out_d = scaled_extent(x_dims[2]);
    out_h = scaled_extent(x_dims[3]);
    out_w = scaled_extent(x_dims[4]);
  } else {
    out_d = ctx->Attrs().Get<int>("out_d");
    out_h = ctx->Attrs().Get<int>("out_h");
    out_w = ctx->Attrs().Get<int>("out_w");
    PADDLE_ENFORCE_GT(out_d, 0, "out_d should be greater than 0.");
    PADDLE_ENFORCE_GT(out_h, 0, "out_h should be greater than 0.");
    PADDLE_ENFORCE_GT(out_w, 0, "out_w should be greater than 0.");
  }

  const bool has_runtime_out_size =
      ctx->HasInput("OutSize") && ctx->IsRuntime();
  if (!has_runtime_out_size) {
    std::vector<int64_t> out_dims(
        {x_dims[0], x_dims[1], out_d, out_h, out_w});
    ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
    return;
  }

  // OutSize must be a 1-D tensor holding exactly three entries.
  auto out_size_dims = ctx->GetInputDim("OutSize");
  PADDLE_ENFORCE_EQ(out_size_dims.size(), 1,
                    "OutSize's dimension size must be 1");
  PADDLE_ENFORCE_EQ(out_size_dims[0], 3, "OutSize's dim[0] must be 3");
  ctx->ShareLoD("X", "Out");
}
class InterpolateOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -31,41 +110,17 @@ class InterpolateOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of InterpolationOp should not be null.");
auto interp_method = ctx->Attrs().Get<std::string>("interp_method");
PADDLE_ENFORCE(
"bilinear" == interp_method || "nearest" == interp_method,
"Interpolation method can only be \"bilinear\" or \"nearest\".");
auto dim_x = ctx->GetInputDim("X"); // NCHW format
PADDLE_ENFORCE_EQ(dim_x.size(), 4, "X's dimension must be 4");
int out_h, out_w;
float scale = ctx->Attrs().Get<float>("scale");
if (scale > 0) {
// round down
out_h = static_cast<int>(dim_x[2] * scale);
out_w = static_cast<int>(dim_x[3] * scale);
// protect when input shape is -1
out_h = out_h > 0 ? out_h : -1;
out_w = out_w > 0 ? out_w : -1;
} else {
out_h = ctx->Attrs().Get<int>("out_h");
out_w = ctx->Attrs().Get<int>("out_w");
PADDLE_ENFORCE_GT(out_h, 0, "out_h should be greater than 0.");
PADDLE_ENFORCE_GT(out_w, 0, "out_w should be greater than 0.");
}
if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
auto out_size_dim = ctx->GetInputDim("OutSize");
PADDLE_ENFORCE_EQ(out_size_dim.size(), 1,
"OutSize's dimension size must be 1");
PADDLE_ENFORCE_EQ(out_size_dim[0], 2, "OutSize's dim[0] must be 2");
ctx->ShareLoD("X", "Out");
return;
PADDLE_ENFORCE(dim_x.size() == 4 || dim_x.size() == 5,
"Input(X) dimension must be 4 or 5");
if (dim_x.size() == 4) {
// shape check for 2D interpolate for input tensor shape NCHW
Interpolate2DInferShapeCheck(ctx);
} else { // dim_x.size() == 5
// shape check for 3D interpolate for input tensor shape NCDHW
Interpolate3DInferShapeCheck(ctx);
}
std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w});
ctx->SetOutputDim("Out", framework::make_ddim(dim_out));
}
protected:
......@@ -81,22 +136,27 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override {
AddInput("X",
"The input tensor of interpolate operator, "
"This is a 4-D tensor with shape of [N, C, H, w].");
"This is a 4-D tensor with shape of [N, C, H, W] or a "
"5-D tensor with shape of [N, C, D, H, W].");
AddInput("OutSize",
"This is a 1-D tensor with two numbers to specify output size. "
"The first number is height and the second number is width.")
"It should be [output_height, output_width] when input is a 4-D "
"tensor and should be [output_depth, output_height, output_width] "
"when input is a 5-D tensor.")
.AsDispensable();
AddOutput("Out",
"The output tensor of interpolate operator, "
"This is a 4-D tensor with shape of [N, C, H, W].");
"This is a tensor in same rank with Input(X).");
AddAttr<int>("out_h", "output height of interpolate op.");
AddAttr<int>("out_w", "output width of interpolate op.");
AddAttr<int>("out_d", "output depth of interpolate op.").SetDefault(0);
AddAttr<int>("out_h", "output height of interpolate op.").SetDefault(0);
AddAttr<int>("out_w", "output width of interpolate op.").SetDefault(0);
AddAttr<float>("scale", "scale factor of interpolate op.").SetDefault(0.);
AddAttr<std::string>("interp_method",
"(string, default \"bilinear\"), interpolation "
"method, can be \"bilinear\" for "
"bilinear interpolation and \"nearest\" for nearest "
"bilinear interpolation, \"trilinear\" for trilinear "
"interpolation and \"nearest\" for nearest "
"neighbor interpolation.")
.SetDefault("bilinear");
AddAttr<bool>(
......@@ -127,6 +187,11 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
to perform linear interpolation first in one direction, and then
again in the other direction.
Trilinear interpolation is an extension of linear interpolation for
interpolating functions of three variables (e.g. D-direction,
H-direction and W-direction in this op) on a rectilinear 3D grid.
The linear interpolation is performed on three directions.
Align_corners and align_mode are optinal parameters,the calculation method
of interpolation can be selected by them.
......@@ -183,6 +248,27 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
H_out = H_{in} * scale_{factor}
W_out = W_{in} * scale_{factor}
Trilinear interpolation:
if:
align_corners = False , align_mode = 0
input : (N,C,D_in,H_in,W_in)
output: (N,C,D_out,H_out,W_out) where:
D_out = (D_{in}+0.5) * scale_{factor} - 0.5
H_out = (H_{in}+0.5) * scale_{factor} - 0.5
W_out = (W_{in}+0.5) * scale_{factor} - 0.5
else:
input : (N,C,D_in,H_in,W_in)
output: (N,C,D_out,H_out,W_out) where:
D_out = D_{in} * scale_{factor}
H_out = H_{in} * scale_{factor}
W_out = W_{in} * scale_{factor}
For details of nearest neighbor interpolation, please refer to Wikipedia:
......@@ -190,6 +276,9 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
For details of bilinear interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Bilinear_interpolation
For details of trilinear interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Trilinear_interpolation
)DOC");
}
};
......@@ -251,6 +340,10 @@ REGISTER_OPERATOR(nearest_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
ops::InterpolateGradDescMaker);
REGISTER_OPERATOR(nearest_interp_grad, ops::InterpolateOpGrad,
ops::InterpolateGradNoNeedBufferVarsInference);
REGISTER_OPERATOR(trilinear_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
ops::InterpolateGradDescMaker);
REGISTER_OPERATOR(trilinear_interp_grad, ops::InterpolateOpGrad,
ops::InterpolateGradNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::InterpolateKernel<float>,
ops::InterpolateKernel<double>,
ops::InterpolateKernel<uint8_t>);
......@@ -261,3 +354,8 @@ REGISTER_OP_CPU_KERNEL(nearest_interp, ops::InterpolateKernel<float>,
ops::InterpolateKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(nearest_interp_grad, ops::InterpolateGradKernel<float>,
ops::InterpolateGradKernel<double>);
REGISTER_OP_CPU_KERNEL(trilinear_interp, ops::InterpolateKernel<float>,
ops::InterpolateKernel<double>,
ops::InterpolateKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(trilinear_interp_grad, ops::InterpolateGradKernel<float>,
ops::InterpolateGradKernel<double>);
......@@ -191,80 +191,483 @@ __global__ void KeBilinearInterpBw(
}
template <typename T>
class InterpolateOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device.");
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
auto* input_data = input->data<T>();
__global__ void KeTrilinearInterpFw(
const T* in, const size_t in_img_d, const size_t in_img_h,
const size_t in_img_w, const size_t input_h, const size_t input_w, T* out,
const size_t out_img_d, const size_t out_img_h, const size_t out_img_w,
const size_t output_h, const size_t output_w, const size_t num_channels,
const float ratio_d, const float ratio_h, const float ratio_w,
const bool align_corners, const int align_mode) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
bool align_flag = (align_mode == 0 && !align_corners);
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id = out_id_w / out_img_size;
int n = input->dims()[0];
int c = input->dims()[1];
int in_h = input->dims()[2];
int in_w = input->dims()[3];
int out_img_idt = (out_id_w % out_img_size) / out_img_h / out_img_w;
int in_img_idt = align_flag
? static_cast<int>(ratio_d * (out_img_idt + 0.5) - 0.5)
: static_cast<int>(ratio_d * out_img_idt);
in_img_idt = (in_img_idt > 0) ? in_img_idt : 0;
int d_id = (in_img_idt < in_img_d - 1) ? 1 : 0;
T src_d = ratio_d * (out_img_idt + 0.5) - 0.5;
src_d = (src_d > 0) ? src_d : 0;
T d1lambda =
align_flag ? src_d - in_img_idt : ratio_d * out_img_idt - in_img_idt;
T d2lambda = 1.f - d1lambda;
int out_img_idy = ((out_id_w % out_img_size) / out_img_w) % out_img_h;
int in_img_idy = align_flag
? static_cast<int>(ratio_h * (out_img_idy + 0.5) - 0.5)
: static_cast<int>(ratio_h * out_img_idy);
in_img_idy = (in_img_idy > 0) ? in_img_idy : 0;
int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
T src_h = ratio_h * (out_img_idy + 0.5) - 0.5;
src_h = (src_h > 0) ? src_h : 0;
T h1lambda =
align_flag ? src_h - in_img_idy : ratio_h * out_img_idy - in_img_idy;
T h2lambda = 1.f - h1lambda;
auto interp_method = ctx.Attr<std::string>("interp_method");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
int out_img_idx = tid % out_img_w;
int in_img_idx = align_flag
? static_cast<int>(ratio_w * (out_img_idx + 0.5) - 0.5)
: static_cast<int>(ratio_w * out_img_idx);
in_img_idx = (in_img_idx > 0) ? in_img_idx : 0;
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
T src_w = ratio_w * (out_img_idx + 0.5) - 0.5;
src_w = (src_w > 0) ? src_w : 0;
T w1lambda =
align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx;
T w2lambda = 1.f - w1lambda;
float scale = ctx.Attr<float>("scale");
if (scale > 0) {
out_h = in_h * scale;
out_w = in_w * scale;
}
int in_pos1_idx = out_id_h * input_w + channel_id * in_img_size +
(in_img_idt * in_img_h + in_img_idy) * in_img_w +
in_img_idx;
const T* in_pos1 = &in[in_pos1_idx];
int in_pos2_idx = in_pos1_idx + d_id * in_img_h * in_img_w;
const T* in_pos2 = &in[in_pos2_idx];
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
Tensor sizes;
framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_h = size_data[0];
out_w = size_data[1];
}
// trilinear interpolation
out[out_id_h * output_w + out_id_w] =
d2lambda *
(h2lambda * (w2lambda * in_pos1[0] + w1lambda * in_pos1[w_id]) +
h1lambda * (w2lambda * in_pos1[h_id * in_img_w] +
w1lambda * in_pos1[h_id * in_img_w + w_id])) +
d1lambda *
(h2lambda * (w2lambda * in_pos2[0] + w1lambda * in_pos2[w_id]) +
h1lambda * (w2lambda * in_pos2[h_id * in_img_w] +
w1lambda * in_pos2[h_id * in_img_w + w_id]));
}
}
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
template <typename T>
__global__ void KeTrilinearInterpBw(
T* in, const size_t in_img_d, const size_t in_img_h, const size_t in_img_w,
const size_t input_h, const size_t input_w, const T* out,
const size_t out_img_d, const size_t out_img_h, const size_t out_img_w,
const size_t output_h, const size_t output_w, const size_t num_channels,
const T ratio_d, const T ratio_h, const T ratio_w, const bool align_corners,
const int align_mode) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
bool align_flag = (align_mode == 0 && !align_corners);
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id = out_id_w / out_img_size;
auto* output_data =
output->mutable_data<T>({n, c, out_h, out_w}, ctx.GetPlace());
int out_img_idt = (out_id_w % out_img_size) / out_img_h / out_img_w;
int in_img_idt = align_flag
? static_cast<int>(ratio_d * (out_img_idt + 0.5) - 0.5)
: static_cast<int>(ratio_d * out_img_idt);
in_img_idt = (in_img_idt > 0) ? in_img_idt : 0;
int d_id = (in_img_idt < in_img_d - 1) ? 1 : 0;
T src_d = ratio_d * (out_img_idt + 0.5) - 0.5;
src_d = (src_d > 0) ? src_d : 0;
T d1lambda =
align_flag ? src_d - in_img_idt : ratio_d * out_img_idt - in_img_idt;
T d2lambda = 1.f - d1lambda;
int out_img_idy = ((out_id_w % out_img_size) / out_img_w) % out_img_h;
int in_img_idy = align_flag
? static_cast<int>(ratio_h * (out_img_idy + 0.5) - 0.5)
: static_cast<int>(ratio_h * out_img_idy);
in_img_idy = (in_img_idy > 0) ? in_img_idy : 0;
int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
T src_h = ratio_h * (out_img_idy + 0.5) - 0.5;
src_h = (src_h > 0) ? src_h : 0;
T h1lambda =
align_flag ? src_h - in_img_idy : ratio_h * out_img_idy - in_img_idy;
T h2lambda = 1.f - h1lambda;
int out_img_idx = tid % out_img_w;
int in_img_idx = align_flag
? static_cast<int>(ratio_w * (out_img_idx + 0.5) - 0.5)
: static_cast<int>(ratio_w * out_img_idx);
in_img_idx = (in_img_idx > 0) ? in_img_idx : 0;
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
T src_w = ratio_w * (out_img_idx + 0.5) - 0.5;
src_w = (src_w > 0) ? src_w : 0;
T w1lambda =
align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx;
T w2lambda = 1.f - w1lambda;
int in_hw = in_h * in_w;
int out_hw = out_h * out_w;
int in_chw = c * in_hw;
int out_chw = c * out_hw;
int in_pos1_idx = out_id_h * input_w + channel_id * in_img_size +
(in_img_idt * in_img_h + in_img_idy) * in_img_w +
in_img_idx;
T* in_pos1 = &in[in_pos1_idx];
int in_pos2_idx = in_pos1_idx + d_id * in_img_h * in_img_w;
T* in_pos2 = &in[in_pos2_idx];
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_h > 1) {
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(in_h) / out_h;
}
if (out_w > 1) {
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
}
const T* out_pos = &out[out_id_h * output_w + out_id_w];
if (in_h == out_h && in_w == out_w) {
framework::TensorCopy(*input, ctx.GetPlace(), output);
return;
}
// trilinear interpolation grad
platform::CudaAtomicAdd(&in_pos1[0],
d2lambda * h2lambda * w2lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos1[w_id],
d2lambda * h2lambda * w1lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos1[h_id * in_img_w],
d2lambda * h1lambda * w2lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos1[h_id * in_img_w + w_id],
d2lambda * h1lambda * w1lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos2[0],
d1lambda * h2lambda * w2lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos2[w_id],
d1lambda * h2lambda * w1lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos2[h_id * in_img_w],
d1lambda * h1lambda * w2lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos2[h_id * in_img_w + w_id],
d1lambda * h1lambda * w1lambda * out_pos[0]);
}
}
int pixelNum = n * out_chw;
int grid_dim = (pixelNum + 512 - 1) / 512;
grid_dim = grid_dim > 8 ? 8 : grid_dim;
if ("nearest" == interp_method) {
KeNearestNeighborInterpFw<
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w, align_corners);
} else if ("bilinear" == interp_method) {
KeBilinearInterpFw<
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w, align_corners, align_mode);
// Forward pass of the interpolate op for a 4-D input.
//
// The target height/width are resolved in increasing priority from the
// "out_h"/"out_w" attributes, a positive "scale" attribute, and the optional
// OutSize input tensor. The output is allocated as {n, c, out_h, out_w}. When
// the spatial size is unchanged the input is copied to the output; otherwise
// the nearest-neighbor or bilinear kernel is launched depending on
// "interp_method" (no kernel is launched for any other method).
template <typename T>
static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx,
                                 const Tensor& input, Tensor* output) {
  auto* in_ptr = input.data<T>();

  const auto& x_dims = input.dims();
  const int n = x_dims[0];
  const int c = x_dims[1];
  const int in_h = x_dims[2];
  const int in_w = x_dims[3];

  auto interp_method = ctx.Attr<std::string>("interp_method");
  bool align_corners = ctx.Attr<bool>("align_corners");
  int align_mode = ctx.Attr<int>("align_mode");

  int out_h = ctx.Attr<int>("out_h");
  int out_w = ctx.Attr<int>("out_w");

  // A positive scale overrides the explicit output size attributes.
  float scale = ctx.Attr<float>("scale");
  if (scale > 0) {
    out_h = static_cast<int>(in_h * scale);
    out_w = static_cast<int>(in_w * scale);
  }

  // An OutSize tensor, when given, overrides everything else.
  const Tensor* out_size = ctx.Input<Tensor>("OutSize");
  if (out_size != nullptr) {
    Tensor host_sizes;
    framework::TensorCopy(*out_size, platform::CPUPlace(), &host_sizes);
    const int* size_values = host_sizes.data<int>();
    out_h = size_values[0];
    out_w = size_values[1];
  }

  auto out_ptr = output->mutable_data<T>({n, c, out_h, out_w}, ctx.GetPlace());

  // Same spatial size: interpolation degenerates to a plain copy.
  if (in_h == out_h && in_w == out_w) {
    framework::TensorCopy(input, ctx.GetPlace(), output);
    return;
  }

  // Input/output step ratio along one axis; stays 0 when the output extent
  // is not greater than 1.
  auto axis_ratio = [align_corners](int in_len, int out_len) -> float {
    if (out_len <= 1) return 0.f;
    return align_corners ? static_cast<float>(in_len - 1) / (out_len - 1)
                         : static_cast<float>(in_len) / out_len;
  };
  const float ratio_h = axis_ratio(in_h, out_h);
  const float ratio_w = axis_ratio(in_w, out_w);

  const int in_chw = c * (in_h * in_w);
  const int out_chw = c * (out_h * out_w);

  // One thread per output element, capped at 8 blocks of 512 threads.
  const int total_pixels = n * out_chw;
  int grid_dim = (total_pixels + 512 - 1) / 512;
  if (grid_dim > 8) {
    grid_dim = 8;
  }

  if ("nearest" == interp_method) {
    KeNearestNeighborInterpFw<
        T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
        in_ptr, in_h, in_w, n, in_chw, out_ptr, out_h, out_w, n, out_chw, c,
        ratio_h, ratio_w, align_corners);
  } else if ("bilinear" == interp_method) {
    KeBilinearInterpFw<
        T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
        in_ptr, in_h, in_w, n, in_chw, out_ptr, out_h, out_w, n, out_chw, c,
        ratio_h, ratio_w, align_corners, align_mode);
  }
}
// Forward pass of the interpolate op for a 5-D input.
//
// The target depth/height/width are resolved in increasing priority from the
// "out_d"/"out_h"/"out_w" attributes, a positive "scale" attribute, and the
// optional OutSize input tensor. The output is allocated as
// {n, c, out_d, out_h, out_w}. When the spatial size is unchanged the input is
// copied to the output; otherwise the trilinear kernel is launched when
// "interp_method" is "trilinear" (no kernel is launched for any other method).
template <typename T>
static void Interpolate3DCUDAFwd(const framework::ExecutionContext& ctx,
                                 const Tensor& input, Tensor* output) {
  auto* in_ptr = input.data<T>();

  const auto& x_dims = input.dims();
  const int n = x_dims[0];
  const int c = x_dims[1];
  const int in_d = x_dims[2];
  const int in_h = x_dims[3];
  const int in_w = x_dims[4];

  auto interp_method = ctx.Attr<std::string>("interp_method");
  bool align_corners = ctx.Attr<bool>("align_corners");
  int align_mode = ctx.Attr<int>("align_mode");

  int out_d = ctx.Attr<int>("out_d");
  int out_h = ctx.Attr<int>("out_h");
  int out_w = ctx.Attr<int>("out_w");

  // A positive scale overrides the explicit output size attributes.
  float scale = ctx.Attr<float>("scale");
  if (scale > 0) {
    out_d = static_cast<int>(in_d * scale);
    out_h = static_cast<int>(in_h * scale);
    out_w = static_cast<int>(in_w * scale);
  }

  // An OutSize tensor, when given, overrides everything else.
  const Tensor* out_size = ctx.Input<Tensor>("OutSize");
  if (out_size != nullptr) {
    Tensor host_sizes;
    framework::TensorCopy(*out_size, platform::CPUPlace(), &host_sizes);
    const int* size_values = host_sizes.data<int>();
    out_d = size_values[0];
    out_h = size_values[1];
    out_w = size_values[2];
  }

  auto out_ptr =
      output->mutable_data<T>({n, c, out_d, out_h, out_w}, ctx.GetPlace());

  // Same spatial size: interpolation degenerates to a plain copy.
  if (in_d == out_d && in_h == out_h && in_w == out_w) {
    framework::TensorCopy(input, ctx.GetPlace(), output);
    return;
  }

  // Input/output step ratio along one axis; stays 0 when the output extent
  // is not greater than 1.
  auto axis_ratio = [align_corners](int in_len, int out_len) -> float {
    if (out_len <= 1) return 0.f;
    return align_corners ? static_cast<float>(in_len - 1) / (out_len - 1)
                         : static_cast<float>(in_len) / out_len;
  };
  const float ratio_d = axis_ratio(in_d, out_d);
  const float ratio_h = axis_ratio(in_h, out_h);
  const float ratio_w = axis_ratio(in_w, out_w);

  const int in_cdhw = c * (in_d * in_h * in_w);
  const int out_cdhw = c * (out_d * out_h * out_w);

  // One thread per output element, capped at 8 blocks of 512 threads.
  const int total_pixels = n * out_cdhw;
  int grid_dim = (total_pixels + 512 - 1) / 512;
  if (grid_dim > 8) {
    grid_dim = 8;
  }

  if ("trilinear" == interp_method) {
    KeTrilinearInterpFw<
        T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
        in_ptr, in_d, in_h, in_w, n, in_cdhw, out_ptr, out_d, out_h, out_w, n,
        out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners, align_mode);
  }
}
// Backward pass of the interpolate op for a 4-D input.
//
// Recomputes the forward output size (from "out_h"/"out_w", a positive
// "scale", or the optional OutSize input, in increasing priority), allocates
// the input gradient as {n, c, in_h, in_w} and fills it from `output_grad`.
// When the spatial size is unchanged the output gradient is copied straight
// into the input gradient; otherwise the input gradient is zero-initialized
// and the nearest-neighbor or bilinear backward kernel is launched according
// to "interp_method" (no kernel is launched for any other method, leaving the
// gradient at zero).
//
// `output_grad` is taken by const reference: the function only reads it, and
// this matches the signature of Interpolate3DCUDABwd.
template <typename T>
static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx,
                                 Tensor* input_grad,
                                 const Tensor& output_grad) {
  auto* input = ctx.Input<Tensor>("X");
  const int n = input->dims()[0];
  const int c = input->dims()[1];
  const int in_h = input->dims()[2];
  const int in_w = input->dims()[3];

  auto interp_method = ctx.Attr<std::string>("interp_method");
  bool align_corners = ctx.Attr<bool>("align_corners");
  int align_mode = ctx.Attr<int>("align_mode");

  int out_h = ctx.Attr<int>("out_h");
  int out_w = ctx.Attr<int>("out_w");

  // A positive scale overrides the explicit output size attributes.
  float scale = ctx.Attr<float>("scale");
  if (scale > 0) {
    out_h = static_cast<int>(in_h * scale);
    out_w = static_cast<int>(in_w * scale);
  }

  // An OutSize tensor, when given, overrides everything else.
  // NOTE(review): if OutSize lives on the GPU, confirm that TensorCopy has
  // completed before size_data is read on the host.
  auto out_size = ctx.Input<Tensor>("OutSize");
  if (out_size != nullptr) {
    Tensor sizes;
    framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes);
    auto size_data = sizes.data<int>();
    out_h = size_data[0];
    out_w = size_data[1];
  }

  auto* output_grad_data = output_grad.data<T>();
  auto* input_grad_data =
      input_grad->mutable_data<T>({n, c, in_h, in_w}, ctx.GetPlace());

  // Same spatial size: the gradient passes through unchanged. The copy
  // overwrites the whole buffer, so no zero-fill is needed on this path.
  if (in_h == out_h && in_w == out_w) {
    framework::TensorCopy(output_grad, ctx.GetPlace(), input_grad);
    return;
  }

  // The backward kernels accumulate into the input gradient, so it has to
  // start from zero.
  auto& device_ctx = ctx.template device_context<platform::CUDADeviceContext>();
  math::SetConstant<platform::CUDADeviceContext, T> zero;
  zero(device_ctx, input_grad, static_cast<T>(0.0));

  // Input/output step ratios; they stay 0 when the output extent is not
  // greater than 1.
  float ratio_h = 0.f;
  float ratio_w = 0.f;
  if (out_h > 1) {
    ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
                              : static_cast<float>(in_h) / out_h;
  }
  if (out_w > 1) {
    ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
                              : static_cast<float>(in_w) / out_w;
  }

  int in_hw = in_h * in_w;
  int out_hw = out_h * out_w;
  int in_chw = c * in_hw;
  int out_chw = c * out_hw;

  // One thread per output-gradient element, capped at 8 blocks of 512 threads.
  int pixelNum = n * out_chw;
  int grid_dim = (pixelNum + 512 - 1) / 512;
  grid_dim = grid_dim > 8 ? 8 : grid_dim;

  if ("nearest" == interp_method) {
    KeNearestNeighborInterpBw<
        T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
        input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w,
        n, out_chw, c, ratio_h, ratio_w, align_corners);
  } else if ("bilinear" == interp_method) {
    KeBilinearInterpBw<
        T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
        input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w,
        n, out_chw, c, ratio_h, ratio_w, align_corners, align_mode);
  }
}
// Backward pass of the interpolate op for a 5-D input.
//
// Recomputes the forward output size (from "out_d"/"out_h"/"out_w", a positive
// "scale", or the optional OutSize input, in increasing priority), allocates
// the input gradient as {n, c, in_d, in_h, in_w} and fills it from
// `output_grad`. When the spatial size is unchanged the output gradient is
// copied straight into the input gradient; otherwise the input gradient is
// zero-initialized and the trilinear backward kernel is launched when
// "interp_method" is "trilinear" (no kernel is launched for any other method,
// leaving the gradient at zero).
template <typename T>
static void Interpolate3DCUDABwd(const framework::ExecutionContext& ctx,
                                 Tensor* input_grad,
                                 const Tensor& output_grad) {
  auto* input = ctx.Input<Tensor>("X");
  const int n = input->dims()[0];
  const int c = input->dims()[1];
  const int in_d = input->dims()[2];
  const int in_h = input->dims()[3];
  const int in_w = input->dims()[4];

  auto interp_method = ctx.Attr<std::string>("interp_method");
  bool align_corners = ctx.Attr<bool>("align_corners");
  int align_mode = ctx.Attr<int>("align_mode");

  int out_d = ctx.Attr<int>("out_d");
  int out_h = ctx.Attr<int>("out_h");
  int out_w = ctx.Attr<int>("out_w");

  // A positive scale overrides the explicit output size attributes.
  float scale = ctx.Attr<float>("scale");
  if (scale > 0) {
    out_d = static_cast<int>(in_d * scale);
    out_h = static_cast<int>(in_h * scale);
    out_w = static_cast<int>(in_w * scale);
  }

  // An OutSize tensor, when given, overrides everything else.
  // NOTE(review): if OutSize lives on the GPU, confirm that TensorCopy has
  // completed before size_data is read on the host.
  auto out_size = ctx.Input<Tensor>("OutSize");
  if (out_size != nullptr) {
    Tensor sizes;
    framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes);
    auto size_data = sizes.data<int>();
    out_d = size_data[0];
    out_h = size_data[1];
    out_w = size_data[2];
  }

  auto* output_grad_data = output_grad.data<T>();
  auto* input_grad_data =
      input_grad->mutable_data<T>({n, c, in_d, in_h, in_w}, ctx.GetPlace());

  // Same spatial size: the gradient passes through unchanged. The copy
  // overwrites the whole buffer, so no zero-fill is needed on this path.
  if (in_d == out_d && in_h == out_h && in_w == out_w) {
    framework::TensorCopy(output_grad, ctx.GetPlace(), input_grad);
    return;
  }

  // KeTrilinearInterpBw accumulates into the input gradient with atomic adds,
  // so the buffer has to start from zero.
  auto& device_ctx = ctx.template device_context<platform::CUDADeviceContext>();
  math::SetConstant<platform::CUDADeviceContext, T> zero;
  zero(device_ctx, input_grad, static_cast<T>(0.0));

  // Input/output step ratios; they stay 0 when the output extent is not
  // greater than 1.
  float ratio_d = 0.f;
  float ratio_h = 0.f;
  float ratio_w = 0.f;
  if (out_d > 1) {
    ratio_d = (align_corners) ? static_cast<float>(in_d - 1) / (out_d - 1)
                              : static_cast<float>(in_d) / out_d;
  }
  if (out_h > 1) {
    ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
                              : static_cast<float>(in_h) / out_h;
  }
  if (out_w > 1) {
    ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
                              : static_cast<float>(in_w) / out_w;
  }

  int in_dhw = in_d * in_h * in_w;
  int out_dhw = out_d * out_h * out_w;
  int in_cdhw = c * in_dhw;
  int out_cdhw = c * out_dhw;

  // One thread per output-gradient element, capped at 8 blocks of 512 threads.
  int pixelNum = n * out_cdhw;
  int grid_dim = (pixelNum + 512 - 1) / 512;
  grid_dim = grid_dim > 8 ? 8 : grid_dim;

  if ("trilinear" == interp_method) {
    KeTrilinearInterpBw<
        T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
        input_grad_data, in_d, in_h, in_w, n, in_cdhw, output_grad_data, out_d,
        out_h, out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners,
        align_mode);
  }
}
// CUDA forward kernel of the interpolate op. Dispatches on the rank of
// Input(X): rank 4 goes to the 2-D implementation and rank 5 to the 3-D
// implementation; any other rank is silently ignored here.
template <typename T>
class InterpolateOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "This kernel only runs on GPU device.");
    auto* x = ctx.Input<Tensor>("X");
    auto* out = ctx.Output<Tensor>("Out");

    auto x_dims = x->dims();
    switch (x_dims.size()) {
      case 4:  // 2D interpolation
        Interpolate2DCUDAFwd<T>(ctx, *x, out);
        break;
      case 5:  // 3D interpolation
        Interpolate3DCUDAFwd<T>(ctx, *x, out);
        break;
      default:
        break;
    }
  }
};
......@@ -273,76 +676,16 @@ template <typename T>
class InterpolateGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device.");
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* output_grad_data = output_grad->data<T>();
auto* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
auto& device_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
math::SetConstant<platform::CUDADeviceContext, T> zero;
zero(device_ctx, input_grad, static_cast<T>(0.0));
int n = input_grad->dims()[0];
int c = input_grad->dims()[1];
int in_h = input_grad->dims()[2];
int in_w = input_grad->dims()[3];
auto interp_method = ctx.Attr<std::string>("interp_method");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
float scale = ctx.Attr<float>("scale");
if (scale > 0) {
out_h = in_h * scale;
out_w = in_w * scale;
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
Tensor sizes;
framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_h = size_data[0];
out_w = size_data[1];
}
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
int in_hw = in_h * in_w;
int out_hw = out_h * out_w;
int in_chw = c * in_hw;
int out_chw = c * out_hw;
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_h > 1) {
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(in_h) / out_h;
}
if (out_w > 1) {
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
}
if (in_h == out_h && in_w == out_w) {
framework::TensorCopy(*output_grad, ctx.GetPlace(), input_grad);
return;
}
int pixelNum = n * out_chw;
int grid_dim = (pixelNum + 512 - 1) / 512;
grid_dim = grid_dim > 8 ? 8 : grid_dim;
if ("nearest" == interp_method) {
KeNearestNeighborInterpBw<
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h,
out_w, n, out_chw, c, ratio_h, ratio_w, align_corners);
} else if ("bilinear" == interp_method) {
KeBilinearInterpBw<
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h,
out_w, n, out_chw, c, ratio_h, ratio_w, align_corners, align_mode);
auto output_grad_dims = output_grad->dims();
if (output_grad_dims.size() == 4) { // 2D interpolation
Interpolate2DCUDABwd<T>(ctx, input_grad, *output_grad);
} else if (output_grad_dims.size() == 5) { // 3D interpolation
Interpolate3DCUDABwd<T>(ctx, input_grad, *output_grad);
}
}
};
......@@ -363,3 +706,9 @@ REGISTER_OP_CUDA_KERNEL(nearest_interp, ops::InterpolateOpCUDAKernel<float>,
// Backward kernel of nearest_interp: instantiated for float and double.
REGISTER_OP_CUDA_KERNEL(nearest_interp_grad,
                        ops::InterpolateGradOpCUDAKernel<float>,
                        ops::InterpolateGradOpCUDAKernel<double>);
// Forward kernel of trilinear_interp: instantiated for float, double and int.
REGISTER_OP_CUDA_KERNEL(trilinear_interp, ops::InterpolateOpCUDAKernel<float>,
                        ops::InterpolateOpCUDAKernel<double>,
                        ops::InterpolateOpCUDAKernel<int>);
// Backward kernel of trilinear_interp: float and double only (no int
// instantiation, unlike the forward registration above).
REGISTER_OP_CUDA_KERNEL(trilinear_interp_grad,
                        ops::InterpolateGradOpCUDAKernel<float>,
                        ops::InterpolateGradOpCUDAKernel<double>);
......@@ -131,6 +131,128 @@ static void BilinearInterpolation(const Tensor& input, Tensor* output,
}
}
// CPU trilinear resize of a 5-D tensor indexed as (batch, channel, d, h, w).
//
// For every output index along D, H and W the two neighbouring source
// indices and their linear weights are tabulated first; the main loop then
// blends the eight surrounding input values for each output element.
//
//   input / output      source tensor and pre-allocated destination tensor.
//   ratio_d/h/w         source step per output step along each axis.
//   in_d/h/w, out_d/h/w source and destination extents.
//   n, c                batch and channel counts.
//   align_corners,
//   align_mode          when align_mode == 0 and align_corners is false the
//                       source coordinate is ratio * (i + 0.5) - 0.5 (clamped
//                       at 0); otherwise it is ratio * i.
//
// align_mode is taken as int to match TrilinearInterpolationGrad and the
// int attribute passed by the caller; only the comparison with 0 is used.
template <typename T>
static void TrilinearInterpolation(
    const Tensor& input, Tensor* output, const float ratio_d,
    const float ratio_h, const float ratio_w, const int in_d, const int in_h,
    const int in_w, const int n, const int c, const int out_d, const int out_h,
    const int out_w, const bool align_corners, const int align_mode) {
  auto input_t = EigenTensor<T, 5>::From(input);
  auto output_t = EigenTensor<T, 5>::From(*output);
  bool align_flag = (align_mode == 0 && !align_corners);

  // The tables are created with their final size. reserve() alone leaves
  // size() == 0, so indexing them with operator[] would be undefined
  // behaviour.
  std::vector<int> vt_f(out_d), vt_b(out_d);
  std::vector<float> vd_f(out_d), vd_b(out_d);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
  for (int j = 0; j < out_d; j++) {
    // front (lower) source index, clamped at 0
    int t_f = align_flag ? static_cast<int>(ratio_d * (j + 0.5) - 0.5)
                         : static_cast<int>(ratio_d * j);
    t_f = (t_f > 0) ? t_f : 0;
    // back (upper) source index, clamped at the last valid index
    int t_b = (t_f + 1) < (in_d - 1) ? (t_f + 1) : (in_d - 1);
    float idx_src_t = ratio_d * (j + 0.5) - 0.5;
    idx_src_t = (idx_src_t > 0) ? idx_src_t : 0;
    // d_f: distance from the front index; d_b is its complement
    float d_f = align_flag ? idx_src_t - t_f : ratio_d * j - t_f;
    float d_b = 1.f - d_f;

    vt_f[j] = t_f;
    vt_b[j] = t_b;
    vd_f[j] = d_f;
    vd_b[j] = d_b;
  }

  std::vector<int> vy_n(out_h), vy_s(out_h);
  std::vector<float> vd_n(out_h), vd_s(out_h);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
  for (int k = 0; k < out_h; k++) {
    int y_n = align_flag ? static_cast<int>(ratio_h * (k + 0.5) - 0.5)
                         : static_cast<int>(ratio_h * k);
    y_n = (y_n > 0) ? y_n : 0;
    int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1);
    float idx_src_y = ratio_h * (k + 0.5) - 0.5;
    idx_src_y = (idx_src_y > 0) ? idx_src_y : 0;
    float d_n = align_flag ? idx_src_y - y_n : ratio_h * k - y_n;
    float d_s = 1.f - d_n;

    vy_n[k] = y_n;
    vy_s[k] = y_s;
    vd_n[k] = d_n;
    vd_s[k] = d_s;
  }

  std::vector<int> vx_w(out_w), vx_e(out_w);
  std::vector<float> vd_w(out_w), vd_e(out_w);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
  for (int l = 0; l < out_w; l++) {
    int x_w = align_flag ? static_cast<int>(ratio_w * (l + 0.5) - 0.5)
                         : static_cast<int>(ratio_w * l);
    x_w = (x_w > 0) ? x_w : 0;
    int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1);
    float idx_src_x = ratio_w * (l + 0.5) - 0.5;
    idx_src_x = (idx_src_x > 0) ? idx_src_x : 0;
    float d_w = align_flag ? idx_src_x - x_w : ratio_w * l - x_w;
    float d_e = 1.f - d_w;

    vx_w[l] = x_w;
    vx_e[l] = x_e;
    vd_w[l] = d_w;
    vd_e[l] = d_e;
  }

  // The five loops must stay perfectly nested for collapse(5).
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(5)
#endif
  for (int b = 0; b < n; b++) {          // loop for batches
    for (int i = 0; i < c; i++) {        // loop for channels
      for (int j = 0; j < out_d; j++) {  // loop for D, H, W
        for (int k = 0; k < out_h; k++) {
          for (int l = 0; l < out_w; l++) {
            // trilinear interpolation: each corner is weighted by the
            // distances to the opposite corner along every axis
            T out_t = input_t(b, i, vt_f[j], vy_n[k], vx_w[l]) * vd_b[j] *
                          vd_s[k] * vd_e[l] +
                      input_t(b, i, vt_f[j], vy_n[k], vx_e[l]) * vd_b[j] *
                          vd_s[k] * vd_w[l] +
                      input_t(b, i, vt_f[j], vy_s[k], vx_w[l]) * vd_b[j] *
                          vd_n[k] * vd_e[l] +
                      input_t(b, i, vt_f[j], vy_s[k], vx_e[l]) * vd_b[j] *
                          vd_n[k] * vd_w[l] +
                      input_t(b, i, vt_b[j], vy_n[k], vx_w[l]) * vd_f[j] *
                          vd_s[k] * vd_e[l] +
                      input_t(b, i, vt_b[j], vy_n[k], vx_e[l]) * vd_f[j] *
                          vd_s[k] * vd_w[l] +
                      input_t(b, i, vt_b[j], vy_s[k], vx_w[l]) * vd_f[j] *
                          vd_n[k] * vd_e[l] +
                      input_t(b, i, vt_b[j], vy_s[k], vx_e[l]) * vd_f[j] *
                          vd_n[k] * vd_w[l];
            output_t(b, i, j, k, l) = out_t;
          }
        }
      }
    }
  }
}
template <typename T>
static void NearestNeighborInterpolateGrad(
const Tensor& output_grad, Tensor* input_grad, const float ratio_h,
......@@ -200,134 +322,340 @@ static void BilinearInterpolationGrad(const Tensor& output_grad,
}
}
}
template <typename T>
class InterpolateKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
// Backward pass of CPU trilinear interpolation.
//
// For every output-gradient element the eight neighbouring source positions
// and their weights are recomputed with the same index/weight formulas as
// the forward pass, and the weighted gradient is accumulated into
// input_grad with +=.
//
// Because values are accumulated, input_grad is expected to hold zeros on
// entry (the visible CPU caller zero-fills it before calling this function).
// Unlike the forward pass there is no OpenMP pragma here: the loops run
// serially, and several output positions can add into the same input cell.
static void TrilinearInterpolationGrad(
    const Tensor& output_grad, Tensor* input_grad, const float ratio_d,
    const float ratio_h, const float ratio_w, const int in_d, const int in_h,
    const int in_w, const int n, const int c, const int out_d, const int out_h,
    const int out_w, const bool align_corners, const int align_mode) {
  auto input_grad_t = EigenTensor<T, 5>::From(*input_grad);
  auto output_grad_t = EigenTensor<T, 5>::From(output_grad);
  // Half-pixel source coordinates are used only for align_mode 0 without
  // corner alignment.
  bool align_flag = (align_mode == 0 && !align_corners);
  for (int j = 0; j < out_d; j++) {  // loop for D
    // t_f / t_b: front and back source depth indices, clamped to
    // [0, in_d - 1]
    int t_f = align_flag ? static_cast<int>(ratio_d * (j + 0.5) - 0.5)
                         : static_cast<int>(ratio_d * j);
    t_f = (t_f > 0) ? t_f : 0;
    int t_b = (t_f + 1) < (in_d - 1) ? (t_f + 1) : (in_d - 1);
    float idx_src_t = ratio_d * (j + 0.5) - 0.5;
    idx_src_t = (idx_src_t > 0) ? idx_src_t : 0;
    // d_f: fractional distance from t_f; d_b is its complement
    float d_f = align_flag ? idx_src_t - t_f : ratio_d * j - t_f;
    float d_b = 1.f - d_f;

    for (int k = 0; k < out_h; k++) {  // loop for H
      // y_n / y_s: the two source rows, clamped to [0, in_h - 1]
      int y_n = align_flag ? static_cast<int>(ratio_h * (k + 0.5) - 0.5)
                           : static_cast<int>(ratio_h * k);
      y_n = (y_n > 0) ? y_n : 0;
      int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1);
      float idx_src_y = ratio_h * (k + 0.5) - 0.5;
      idx_src_y = (idx_src_y > 0) ? idx_src_y : 0;
      float d_n = align_flag ? idx_src_y - y_n : ratio_h * k - y_n;
      float d_s = 1.f - d_n;

      for (int l = 0; l < out_w; l++) {  // loop for W
        // x_w / x_e: the two source columns, clamped to [0, in_w - 1]
        int x_w = align_flag ? static_cast<int>(ratio_w * (l + 0.5) - 0.5)
                             : static_cast<int>(ratio_w * l);
        x_w = (x_w > 0) ? x_w : 0;
        int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1);
        float idx_src_x = ratio_w * (l + 0.5) - 0.5;
        idx_src_x = (idx_src_x > 0) ? idx_src_x : 0;
        float d_w = align_flag ? idx_src_x - x_w : ratio_w * l - x_w;
        float d_e = 1.f - d_w;

        for (int b = 0; b < n; b++) {    // loop for batches
          for (int i = 0; i < c; i++) {  // loop for channels
            // trilinear interpolation grad: scatter the output gradient to
            // the eight corners, weighted by the distance to the opposite
            // corner along each axis
            const T grad = output_grad_t(b, i, j, k, l);
            input_grad_t(b, i, t_f, y_n, x_w) +=
                static_cast<T>(grad * d_b * d_s * d_e);
            input_grad_t(b, i, t_f, y_n, x_e) +=
                static_cast<T>(grad * d_b * d_s * d_w);
            input_grad_t(b, i, t_f, y_s, x_w) +=
                static_cast<T>(grad * d_b * d_n * d_e);
            input_grad_t(b, i, t_f, y_s, x_e) +=
                static_cast<T>(grad * d_b * d_n * d_w);
            input_grad_t(b, i, t_b, y_n, x_w) +=
                static_cast<T>(grad * d_f * d_s * d_e);
            input_grad_t(b, i, t_b, y_n, x_e) +=
                static_cast<T>(grad * d_f * d_s * d_w);
            input_grad_t(b, i, t_b, y_s, x_w) +=
                static_cast<T>(grad * d_f * d_n * d_e);
            input_grad_t(b, i, t_b, y_s, x_e) +=
                static_cast<T>(grad * d_f * d_n * d_w);
          }
        }
      }
    }
  }
}
const int n = input->dims()[0];
const int c = input->dims()[1];
const int in_h = input->dims()[2];
const int in_w = input->dims()[3];
template <typename T>
static void Interpolate2DCPUFwd(const framework::ExecutionContext& ctx,
const Tensor& input, Tensor* output) {
const int n = input.dims()[0];
const int c = input.dims()[1];
const int in_h = input.dims()[2];
const int in_w = input.dims()[3];
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
float scale = ctx.Attr<float>("scale");
if (scale > 0) {
out_h = static_cast<int>(in_h * scale);
out_w = static_cast<int>(in_w * scale);
}
std::string interp_method = ctx.Attr<std::string>("interp_method");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
auto out_size_data = out_size->data<int>();
out_h = out_size_data[0];
out_w = out_size_data[1];
}
float scale = ctx.Attr<float>("scale");
if (scale > 0) {
out_h = static_cast<int>(in_h * scale);
out_w = static_cast<int>(in_w * scale);
}
output->mutable_data<T>({n, c, out_h, out_w}, ctx.GetPlace());
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
auto out_size_data = out_size->data<int>();
out_h = out_size_data[0];
out_w = out_size_data[1];
}
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
output->mutable_data<T>({n, c, out_h, out_w}, ctx.GetPlace());
auto& device_ctx =
ctx.template device_context<platform::CPUDeviceContext>();
math::SetConstant<platform::CPUDeviceContext, T> zero;
zero(device_ctx, output, static_cast<T>(0.0));
if (in_h == out_h && in_w == out_w) {
framework::TensorCopy(*input, ctx.GetPlace(), output);
return;
}
if (in_h == out_h && in_w == out_w) {
framework::TensorCopy(input, ctx.GetPlace(), output);
return;
}
float ratio_h = 0.f;
float ratio_w = 0.f;
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_h > 1) {
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(in_h) / out_h;
}
if (out_w > 1) {
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
}
if (out_h > 1) {
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(in_h) / out_h;
}
if (out_w > 1) {
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
}
if ("bilinear" == interp_method) {
BilinearInterpolation<T>(input, output, ratio_h, ratio_w, in_h, in_w, n, c,
out_h, out_w, align_corners, align_mode);
} else if ("nearest" == interp_method) {
NearestNeighborInterpolate<T>(input, output, ratio_h, ratio_w, n, c, out_h,
out_w, align_corners);
}
}
if ("bilinear" == interp_method) {
BilinearInterpolation<T>(*input, output, ratio_h, ratio_w, in_h, in_w, n,
c, out_h, out_w, align_corners, align_mode);
} else if ("nearest" == interp_method) {
NearestNeighborInterpolate<T>(*input, output, ratio_h, ratio_w, n, c,
out_h, out_w, align_corners);
}
// CPU forward driver for 3-D (rank-5) interpolation.
//
// Output extents are resolved in increasing priority:
//   1. the "out_d" / "out_h" / "out_w" attributes,
//   2. the "scale" attribute when it is > 0 (extent = int(in * scale)),
//   3. the optional "OutSize" input tensor (three ints: d, h, w).
// The output is allocated as {n, c, out_d, out_h, out_w}. When the extents
// equal the input extents the input is copied through unchanged. Otherwise
// the per-axis ratios are computed and the kernel is invoked; only the
// "trilinear" interp_method is handled here.
template <typename T>
static void Interpolate3DCPUFwd(const framework::ExecutionContext& ctx,
                                const Tensor& input, Tensor* output) {
  const int n = input.dims()[0];
  const int c = input.dims()[1];
  const int in_d = input.dims()[2];
  const int in_h = input.dims()[3];
  const int in_w = input.dims()[4];

  auto interp_method = ctx.Attr<std::string>("interp_method");
  bool align_corners = ctx.Attr<bool>("align_corners");
  int align_mode = ctx.Attr<int>("align_mode");

  int out_d = ctx.Attr<int>("out_d");
  int out_h = ctx.Attr<int>("out_h");
  int out_w = ctx.Attr<int>("out_w");

  float scale = ctx.Attr<float>("scale");
  if (scale > 0) {
    out_d = static_cast<int>(in_d * scale);
    out_h = static_cast<int>(in_h * scale);
    out_w = static_cast<int>(in_w * scale);
  }

  auto out_size = ctx.Input<Tensor>("OutSize");
  if (out_size != nullptr) {
    auto out_size_data = out_size->data<int>();
    out_d = out_size_data[0];
    out_h = out_size_data[1];
    out_w = out_size_data[2];
  }

  output->mutable_data<T>({n, c, out_d, out_h, out_w}, ctx.GetPlace());

  // Same-size resize degenerates to a plain copy.
  if (in_d == out_d && in_h == out_h && in_w == out_w) {
    framework::TensorCopy(input, ctx.GetPlace(), output);
    return;
  }

  // Source step per output step along one axis; 0 when the output extent
  // along that axis is <= 1.
  const auto axis_ratio = [align_corners](int in_size, int out_size) -> float {
    if (out_size <= 1) return 0.f;
    return align_corners ? static_cast<float>(in_size - 1) / (out_size - 1)
                         : static_cast<float>(in_size) / out_size;
  };
  const float ratio_d = axis_ratio(in_d, out_d);
  const float ratio_h = axis_ratio(in_h, out_h);
  const float ratio_w = axis_ratio(in_w, out_w);

  if ("trilinear" == interp_method) {
    TrilinearInterpolation<T>(input, output, ratio_d, ratio_h, ratio_w, in_d,
                              in_h, in_w, n, c, out_d, out_h, out_w,
                              align_corners, align_mode);
  }
}
template <typename T>
class InterpolateGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("X");
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
static void Interpolate2DCPUBwd(const framework::ExecutionContext& ctx,
Tensor* input_grad, const Tensor& output_grad) {
auto* input = ctx.Input<Tensor>("X");
const int n = input->dims()[0];
const int c = input->dims()[1];
const int in_h = input->dims()[2];
const int in_w = input->dims()[3];
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
float scale = ctx.Attr<float>("scale");
if (scale > 0) {
out_h = static_cast<int>(in_h * scale);
out_w = static_cast<int>(in_w * scale);
}
const int n = input->dims()[0];
const int c = input->dims()[1];
const int in_h = input->dims()[2];
const int in_w = input->dims()[3];
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
auto out_size_data = out_size->data<int>();
out_h = out_size_data[0];
out_w = out_size_data[1];
}
std::string interp_method = ctx.Attr<std::string>("interp_method");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
input_grad->mutable_data<T>({n, c, in_h, in_w}, ctx.GetPlace());
auto& device_ctx = ctx.template device_context<platform::CPUDeviceContext>();
math::SetConstant<platform::CPUDeviceContext, T> zero;
zero(device_ctx, input_grad, static_cast<T>(0.0));
float scale = ctx.Attr<float>("scale");
if (scale > 0) {
out_h = static_cast<int>(in_h * scale);
out_w = static_cast<int>(in_w * scale);
}
if (in_h == out_h && in_w == out_w) {
framework::TensorCopy(output_grad, ctx.GetPlace(), input_grad);
return;
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
auto out_size_data = out_size->data<int>();
out_h = out_size_data[0];
out_w = out_size_data[1];
}
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_h > 1) {
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(in_h) / out_h;
}
if (out_w > 1) {
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
}
if ("bilinear" == interp_method) {
BilinearInterpolationGrad<T>(output_grad, input_grad, ratio_h, ratio_w,
in_h, in_w, n, c, out_h, out_w, align_corners,
align_mode);
} else if ("nearest" == interp_method) {
NearestNeighborInterpolateGrad<T>(output_grad, input_grad, ratio_h, ratio_w,
n, c, out_h, out_w, align_corners);
}
}
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
// CPU backward driver for 3-D (rank-5) interpolation.
//
// The input extents are read from the forward input "X". Output extents are
// resolved exactly as in the forward pass: "out_d"/"out_h"/"out_w"
// attributes, overridden by "scale" when > 0, overridden by the optional
// "OutSize" input. input_grad is allocated as {n, c, in_d, in_h, in_w} and
// zero-filled, because the gradient kernel accumulates into it. When input
// and output extents match, the output gradient is copied through. Only the
// "trilinear" interp_method is handled here.
//
// output_grad is taken by const reference: it is only read, so a by-value
// Tensor parameter would just add a needless copy on every call.
template <typename T>
static void Interpolate3DCPUBwd(const framework::ExecutionContext& ctx,
                                Tensor* input_grad,
                                const Tensor& output_grad) {
  auto* input = ctx.Input<Tensor>("X");
  const int n = input->dims()[0];
  const int c = input->dims()[1];
  const int in_d = input->dims()[2];
  const int in_h = input->dims()[3];
  const int in_w = input->dims()[4];

  auto interp_method = ctx.Attr<std::string>("interp_method");
  bool align_corners = ctx.Attr<bool>("align_corners");
  int align_mode = ctx.Attr<int>("align_mode");

  int out_d = ctx.Attr<int>("out_d");
  int out_h = ctx.Attr<int>("out_h");
  int out_w = ctx.Attr<int>("out_w");

  float scale = ctx.Attr<float>("scale");
  if (scale > 0) {
    out_d = static_cast<int>(in_d * scale);
    out_h = static_cast<int>(in_h * scale);
    out_w = static_cast<int>(in_w * scale);
  }

  auto out_size = ctx.Input<Tensor>("OutSize");
  if (out_size != nullptr) {
    auto out_size_data = out_size->data<int>();
    out_d = out_size_data[0];
    out_h = out_size_data[1];
    out_w = out_size_data[2];
  }

  // Allocate and clear the gradient buffer before anything is accumulated.
  input_grad->mutable_data<T>({n, c, in_d, in_h, in_w}, ctx.GetPlace());
  auto& device_ctx = ctx.template device_context<platform::CPUDeviceContext>();
  math::SetConstant<platform::CPUDeviceContext, T> zero;
  zero(device_ctx, input_grad, static_cast<T>(0.0));

  // Same-size resize: the gradient passes through unchanged.
  if (in_d == out_d && in_h == out_h && in_w == out_w) {
    framework::TensorCopy(output_grad, ctx.GetPlace(), input_grad);
    return;
  }

  // Source step per output step along one axis; 0 when the output extent
  // along that axis is <= 1.
  const auto axis_ratio = [align_corners](int in_size, int out_size) -> float {
    if (out_size <= 1) return 0.f;
    return align_corners ? static_cast<float>(in_size - 1) / (out_size - 1)
                         : static_cast<float>(in_size) / out_size;
  };
  const float ratio_d = axis_ratio(in_d, out_d);
  const float ratio_h = axis_ratio(in_h, out_h);
  const float ratio_w = axis_ratio(in_w, out_w);

  if ("trilinear" == interp_method) {
    TrilinearInterpolationGrad<T>(output_grad, input_grad, ratio_d, ratio_h,
                                  ratio_w, in_d, in_h, in_w, n, c, out_d, out_h,
                                  out_w, align_corners, align_mode);
  }
}
// Forward CPU kernel shared by the interpolate ops. It only dispatches:
// the rank of input "X" selects the 2-D (rank 4) or 3-D (rank 5) code path.
template <typename T>
class InterpolateKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto* src = ctx.Input<Tensor>("X");
    auto* dst = ctx.Output<Tensor>("Out");

    switch (src->dims().size()) {
      case 4:  // 2D interpolation
        Interpolate2DCPUFwd<T>(ctx, *src, dst);
        break;
      case 5:  // 3D interpolation
        Interpolate3DCPUFwd<T>(ctx, *src, dst);
        break;
      default:  // any other rank: nothing is computed
        break;
    }
  }
};
// Backward CPU kernel shared by the interpolate ops. It only dispatches:
// the rank of the output gradient selects the 2-D (rank 4) or 3-D (rank 5)
// backward driver. All attribute handling and the actual gradient math live
// in those drivers, so no ratio/shape computation belongs in this method.
template <typename T>
class InterpolateGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));

    auto output_grad_dims = output_grad->dims();
    if (output_grad_dims.size() == 4) {  // 2D interpolation grad
      Interpolate2DCPUBwd<T>(ctx, input_grad, *output_grad);
    } else if (output_grad_dims.size() == 5) {  // 3D interpolation grad
      Interpolate3DCPUBwd<T>(ctx, input_grad, *output_grad);
    }
  }
};
......
......@@ -119,6 +119,7 @@ __all__ = [
'image_resize',
'image_resize_short',
'resize_bilinear',
'resize_trilinear',
'resize_nearest',
'gather',
'scatter',
......@@ -7672,13 +7673,16 @@ def image_resize(input,
"""
**Resize a Batch of Images**
The input must be a tensor of the shape (num_batches, channels, in_h, in_w),
and the resizing only applies on the last two dimensions(hight and width).
The input must be a tensor of the shape (num_batches, channels, in_h, in_w)
or (num_batches, channels, in_d, in_h, in_w), and the resizing only applies
on the last two/three dimensions (depth, height and width).
Supporting resample methods:
'BILINEAR' : Bilinear interpolation
'TRILINEAR' : Trilinear interpolation
'NEAREST' : Nearest neighbor interpolation
Nearest neighbor interpolation is to perform nearest neighbor interpolation
......@@ -7691,6 +7695,11 @@ def image_resize(input,
to perform linear interpolation first in one direction, and then
again in the other direction.
Trilinear interpolation is an extension of linear interpolation for
interpolating functions of three variables (e.g. D-direction,
H-direction and W-direction in this op) on a rectilinear 3D grid.
The linear interpolation is performed on three directions.
Align_corners and align_mode are optional parameters; the calculation method
of interpolation can be selected by them.
......@@ -7748,30 +7757,58 @@ def image_resize(input,
H_out = H_{in} * scale_{factor}
W_out = W_{in} * scale_{factor}
Trilinear interpolation:
if:
align_corners = False , align_mode = 0
input : (N,C,D_in,H_in,W_in)
output: (N,C,D_out,H_out,W_out) where:
D_out = (D_{in}+0.5) * scale_{factor} - 0.5
H_out = (H_{in}+0.5) * scale_{factor} - 0.5
W_out = (W_{in}+0.5) * scale_{factor} - 0.5
else:
input : (N,C,D_in,H_in,W_in)
output: (N,C,D_out,H_out,W_out) where:
D_out = D_{in} * scale_{factor}
H_out = H_{in} * scale_{factor}
W_out = W_{in} * scale_{factor}
For details of nearest neighbor interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation.
For details of bilinear interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Bilinear_interpolation.
For details of trilinear interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Trilinear_interpolation.
Args:
input (Variable): The input tensor of image resize layer,
This is a 4-D tensor of the shape
(num_batches, channels, in_h, in_w).
(num_batches, channels, in_h, in_w) or a
5-D tensor of the shape
(num_batches, channels, in_d, in_h, in_w).
out_shape(list|tuple|Variable|None): Output shape of image resize
layer, the shape is (out_h, out_w).
Default: None
layer, the shape is (out_h, out_w) when
input is a 4-D tensor and is
(out_d, out_h, out_w) when input is a
5-D tensor. Default: None
scale(float|None): The multiplier for the input height or width. At
least one of :attr:`out_shape` or :attr:`scale` must be set.
And :attr:`out_shape` has a higher priority than :attr:`scale`.
Default: None.
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
resample(str): The resample method. It supports 'BILINEAR' and 'NEAREST'
currently.
Default: 'BILINEAR'
resample(str): The resample method. It supports 'BILINEAR', 'TRILINEAR'
and 'NEAREST' currently. Default: 'BILINEAR'
actual_shape(Variable): An optional input to specify output shape
dynamically. If provided, image resize
according to this given shape rather than
......@@ -7795,15 +7832,19 @@ def image_resize(input,
Returns:
Variable: The output is a 4-D tensor of the shape
(num_batches, channls, out_h, out_w).
(num_batches, channels, out_h, out_w) or a 5-D tensor of the shape
(num_batches, channels, out_d, out_h, out_w).
Raises:
TypeError: out_shape should be a list or tuple or Variable.
TypeError: actual_shape should either be Variable or None.
ValueError: The 'resample' of image_resize can only be 'BILINEAR'
or 'NEAREST' currently.
ValueError: The 'resample' of image_resize can only be 'BILINEAR',
'TRILINEAR' or 'NEAREST' currently.
ValueError: 'BILINEAR' and 'NEAREST' only support 4-D tensor.
ValueError: 'TRILINEAR' only support 5-D tensor.
ValueError: One of out_shape and scale must not be None.
ValueError: out_shape length should be 2.
ValueError: out_shape length should be 2 for input 4-D tensor.
ValueError: out_shape length should be 3 for input 5-D tensor.
ValueError: scale should be greater than zero.
TypeError: align_corners shoule be a bool value
ValueError: align_mode can only be '0' or '1'
......@@ -7817,14 +7858,20 @@ def image_resize(input,
"""
resample_methods = {
'BILINEAR': 'bilinear',
'TRILINEAR': 'trilinear',
'NEAREST': 'nearest',
}
if resample not in resample_methods:
raise ValueError(
"The 'resample' of image_resize can only be 'BILINEAR' or 'NEAREST' currently."
)
"The 'resample' of image_resize can only be 'BILINEAR', 'TRILINEAR' "
"or 'NEAREST' currently.")
resample_type = resample_methods[resample]
if resample in ['BILINEAR', 'NEAREST'] and len(input.shape) != 4:
raise ValueError("'BILINEAR' and 'NEAREST' only support 4-D tensor.")
if resample == 'TRILINEAR' and len(input.shape) != 5:
raise ValueError("'TRILINEAR'only support 5-D tensor.")
if not isinstance(align_corners, bool):
raise TypeError("Attr align_corners should be a bool value")
if align_mode != 0 and align_mode != 1:
......@@ -7840,6 +7887,7 @@ def image_resize(input,
inputs = {"X": input}
attrs = {
"out_d": 0,
"out_h": 0,
"out_w": 0,
"interp_method": resample_type,
......@@ -7857,12 +7905,21 @@ def image_resize(input,
if not (_is_list_or_turple_(out_shape)):
raise TypeError(
"out_shape should be a list or tuple or Variable.")
if len(out_shape) != 2:
raise ValueError("out_shape length should be 2.")
out_shape = list(map(int, out_shape))
attrs['out_h'] = out_shape[0]
attrs['out_w'] = out_shape[1]
if len(input.shape) == 4:
if len(out_shape) != 2:
raise ValueError("out_shape length should be 2 for "
"input 4-D tensor.")
out_shape = list(map(int, out_shape))
attrs['out_h'] = out_shape[0]
attrs['out_w'] = out_shape[1]
if len(input.shape) == 5:
if len(out_shape) != 3:
raise ValueError("out_shape length should be 3 for "
"input 5-D tensor.")
out_shape = list(map(int, out_shape))
attrs['out_d'] = out_shape[0]
attrs['out_h'] = out_shape[1]
attrs['out_w'] = out_shape[2]
else:
if scale <= 0:
......@@ -7945,7 +8002,7 @@ def resize_bilinear(input,
Args:
input(${x_type}): ${x_comment}.
input(${x_type}): input should be a 4-D tensor.
out_shape(list|tuple|Variable|None): Output shape of resize bilinear
layer, the shape is (out_h, out_w).
......@@ -7974,7 +8031,7 @@ def resize_bilinear(input,
align_mode(bool): ${align_mode_comment}
Returns:
${out_comment}.
A 4-D tensor in shape of (num_batches, channels, out_h, out_w)
Examples:
.. code-block:: python
......@@ -7988,6 +8045,112 @@ def resize_bilinear(input,
align_corners, align_mode)
@templatedoc(op_type="trilinear_interp")
def resize_trilinear(input,
                     out_shape=None,
                     scale=None,
                     name=None,
                     actual_shape=None,
                     align_corners=True,
                     align_mode=1):
    """
    Resize input by performing trilinear interpolation based on given
    output shape which is specified by actual_shape, out_shape and scale
    in priority order.

    Trilinear interpolation is an extension of linear interpolation for
    interpolating functions of three variables (e.g. D-direction,
    H-direction and W-direction in this op) on a rectilinear 3D grid.
    The linear interpolation is performed on three directions.

    For details of trilinear interpolation, please refer to Wikipedia:
    https://en.wikipedia.org/wiki/Trilinear_interpolation

    Align_corners and align_mode are optional parameters; the calculation
    method of interpolation can be selected by them.

    Example:

    .. code-block:: text

        For scale:

            if align_corners = True && out_size > 1 :

              scale_factor = (in_size-1.0)/(out_size-1.0)

            else:

              scale_factor = float(in_size/out_size)

        Trilinear interpolation:

          if:
              align_corners = False , align_mode = 0

              input : (N,C,D_in,H_in,W_in)
              output: (N,C,D_out,H_out,W_out) where:

              D_out = (D_{in}+0.5) * scale_{factor} - 0.5
              H_out = (H_{in}+0.5) * scale_{factor} - 0.5
              W_out = (W_{in}+0.5) * scale_{factor} - 0.5

          else:

              input : (N,C,D_in,H_in,W_in)
              output: (N,C,D_out,H_out,W_out) where:

              D_out = D_{in} * scale_{factor}
              H_out = H_{in} * scale_{factor}
              W_out = W_{in} * scale_{factor}

    Args:
        input(${x_type}): input should be a 5-D tensor.

        out_shape(list|tuple|Variable|None): Output shape of resize trilinear
            layer, the shape is (out_d, out_h, out_w).
            Default: None

        scale(float|None): The multiplier for the input depth, height or width.
             At least one of :attr:`out_shape` or :attr:`scale` must be set.
             And :attr:`out_shape` has a higher priority than :attr:`scale`.
             Default: None.

        name(str|None): The output variable name.
        actual_shape(Variable): An optional input to specify output shape
                                dynamically. If provided, image resize
                                according to this given shape rather than
                                :attr:`out_shape` and :attr:`scale` specifying
                                shape. That is to say actual_shape has the
                                highest priority. It is recommended to use
                                actual_shape instead of :attr:`out_shape` if you
                                want to specify output shape dynamically. When
                                using actual_shape to specify output shape, one of
                                :attr:`out_shape` and :attr:`scale` should also be
                                set, otherwise errors would occur in graph
                                constructing stage.
                                Default: None
        align_corners(bool): ${align_corners_comment}
        align_mode(int): ${align_mode_comment}

    Returns:
        A 5-D tensor in shape (num_batches, channels, out_d, out_h, out_w)

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            input = fluid.layers.data(name="input", shape=[3,6,9,11], dtype="float32")
            out = fluid.layers.resize_trilinear(input, out_shape=[12, 12, 12])
    """

    return image_resize(input, out_shape, scale, name, 'TRILINEAR',
                        actual_shape, align_corners, align_mode)
@templatedoc(op_type="nearest_interp")
def resize_nearest(input,
out_shape=None,
......@@ -8041,7 +8204,7 @@ def resize_nearest(input,
https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation
Args:
input(${x_type}): ${x_comment}.
input(${x_type}): input should be a 4-D tensor.
out_shape(list|tuple|Variable|None): Output shape of resize nearest
layer, the shape is (out_h, out_w).
......@@ -8069,7 +8232,7 @@ def resize_nearest(input,
align_corners(bool): ${align_corners_comment}
Returns:
${out_comment}.
A 4-D tensor in shape of (num_batches, channels, out_h, out_w)
Examples:
.. code-block:: python
......
......@@ -205,6 +205,17 @@ class TestBilinearInterpCase6(TestBilinearInterpOp):
self.align_mode = 1
class TestBilinearInterpSame(TestBilinearInterpOp):
    # Bilinear case whose requested output size (128 x 64) equals the spatial
    # size of the input, with scale disabled (0.).
    # NOTE(review): presumably meant to cover the same-size copy path of the
    # op; confirm against the kernel.
    def init_test_case(self):
        self.interp_method = 'bilinear'
        self.input_shape = [2, 3, 128, 64]
        self.out_h = 128
        self.out_w = 64
        self.scale = 0.
        self.align_corners = True
        self.align_mode = 1
class TestBilinearInterpActualShape(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
......
......@@ -1295,16 +1295,74 @@ class TestBook(LayerTest):
x = self._get_data(name='x', shape=[3, 9, 6], dtype="float32")
output = layers.resize_bilinear(x, out_shape=[12, 12])
return (output)
output = layers.resize_bilinear(x, scale=3)
def make_resize_bilinear_by_scale(self):
    # Builds a resize_bilinear layer sized by a scale factor (1.5) instead of
    # an explicit out_shape, and returns the layer output.
    with program_guard(fluid.default_main_program(),
                       fluid.default_startup_program()):
        x = self._get_data(name='x', shape=[3, 9, 6], dtype="float32")
        output = layers.resize_bilinear(x, scale=1.5)
        return (output)
def make_resize_nearest(self):
    """Build resize_nearest layers and return the output of the last one.

    Two preliminary builds are attempted first; a ValueError raised by
    either of them is deliberately ignored.
    """
    tolerated_attempts = (
        ('x1', [3, 9, 6], [12, 12]),
        ('x2', [3, 9, 6, 7], [12, 12, 12]),
    )
    for data_name, data_shape, target_shape in tolerated_attempts:
        try:
            with program_guard(fluid.default_main_program(),
                               fluid.default_startup_program()):
                probe = self._get_data(
                    name=data_name, shape=data_shape, dtype="float32")
                layers.resize_nearest(probe, out_shape=target_shape)
        except ValueError:
            pass

    with program_guard(fluid.default_main_program(),
                       fluid.default_startup_program()):
        feat = self._get_data(name='x', shape=[3, 9, 6], dtype="float32")
        resized = layers.resize_nearest(feat, out_shape=[12, 12])
        return resized
output = layers.resize_nearest(x, scale=3)
def make_resize_nearest_by_scale(self):
    """Build a resize_nearest layer driven by scale=1.8 and return its output."""
    main_prog = fluid.default_main_program()
    startup_prog = fluid.default_startup_program()
    with program_guard(main_prog, startup_prog):
        feat = self._get_data(name='x1', shape=[3, 9, 6], dtype="float32")
        resized = layers.resize_nearest(feat, scale=1.8)
        return resized
def make_resize_trilinear(self):
    """Build resize_trilinear layers and return the output of the last one.

    Two preliminary builds are attempted first; a ValueError raised by
    either of them is deliberately ignored.
    """
    tolerated_attempts = (
        ('x2', [3, 9, 6], [12, 12, 12]),
        ('x', [3, 9, 6, 7], [12, 12]),
    )
    for data_name, data_shape, target_shape in tolerated_attempts:
        try:
            with program_guard(fluid.default_main_program(),
                               fluid.default_startup_program()):
                probe = self._get_data(
                    name=data_name, shape=data_shape, dtype="float32")
                layers.resize_trilinear(probe, out_shape=target_shape)
        except ValueError:
            pass

    with program_guard(fluid.default_main_program(),
                       fluid.default_startup_program()):
        feat = self._get_data(name='x', shape=[3, 9, 6, 7], dtype="float32")
        resized = layers.resize_trilinear(feat, out_shape=[12, 12, 12])
        return resized
def make_resize_trilinear_by_scale(self):
    """Build a resize_trilinear layer driven by scale=2.1 and return its output."""
    main_prog = fluid.default_main_program()
    startup_prog = fluid.default_startup_program()
    with program_guard(main_prog, startup_prog):
        feat = self._get_data(name='x', shape=[3, 9, 6, 7], dtype="float32")
        resized = layers.resize_trilinear(feat, scale=2.1)
        return resized
def make_polygon_box_transform(self):
......
......@@ -176,6 +176,16 @@ class TestNearestNeighborInterpCase6(TestNearestInterpOp):
self.align_corners = True
class TestNearestNeighborInterpSame(TestNearestInterpOp):
    """Nearest-neighbor case whose out_h/out_w repeat the last two input dims."""

    def init_test_case(self):
        self.input_shape = [2, 3, 128, 64]
        self.interp_method = 'nearest'
        # Requested size equals input_shape[2] x input_shape[3].
        self.out_h, self.out_w = 128, 64
        self.scale = 0.
        self.align_corners = True
class TestNearestNeighborInterpActualShape(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
def trilinear_interp_np(input,
                        out_d,
                        out_h,
                        out_w,
                        out_size=None,
                        actual_shape=None,
                        align_corners=True,
                        align_mode=0):
    """NumPy reference for trilinear interpolation on an [N, C, D, H, W] array.

    Args:
        input: 5-D array unpacked as (batch, channel, depth, height, width).
        out_d, out_h, out_w: requested output depth, height and width.
        out_size: optional 3-element sequence replacing out_d/out_h/out_w.
        actual_shape: optional 3-element sequence; applied after out_size,
            so it wins when both are given.
        align_corners: when truthy, the source/target ratio is
            (in - 1) / (out - 1); otherwise it is in / out.
        align_mode: when 0 and align_corners is falsy, source coordinates
            are computed as ratio * (i + 0.5) - 0.5; otherwise as ratio * i.

    Returns:
        Array of shape (batch, channel, out_d, out_h, out_w) cast back to
        input.dtype.
    """
    # Apply the overrides in order so the later one takes precedence.
    for override in (out_size, actual_shape):
        if override is not None:
            out_d, out_h, out_w = override[0], override[1], override[2]

    batch_size, channel, in_d, in_h, in_w = input.shape
    half_pixel = align_mode == 0 and not align_corners

    def _ratio(src_len, dst_len):
        # A target length of 1 (or less) keeps the ratio at zero.
        if dst_len > 1:
            if align_corners:
                return (src_len - 1.0) / (dst_len - 1.0)
            return 1.0 * src_len / dst_len
        return 0.0

    def _taps(src_len, dst_len):
        # For every output position: (low index, step to the upper
        # neighbour, weight of the upper neighbour, weight of the low one).
        step = _ratio(src_len, dst_len)
        table = []
        for pos in range(dst_len):
            if half_pixel:
                coord = step * (pos + 0.5) - 0.5
            else:
                coord = step * pos
            low = max(0, int(coord))
            # Do not step past the last source element.
            bump = 1 if low < src_len - 1 else 0
            if half_pixel:
                upper_w = max(coord, 0) - low
            else:
                upper_w = coord - low
            table.append((low, bump, upper_w, 1.0 - upper_w))
        return table

    depth_taps = _taps(in_d, out_d)
    height_taps = _taps(in_h, out_h)
    width_taps = _taps(in_w, out_w)

    out = np.zeros((batch_size, channel, out_d, out_h, out_w))
    for i, (d, did, d1lambda, d2lambda) in enumerate(depth_taps):
        for j, (h, hid, h1lambda, h2lambda) in enumerate(height_taps):
            for k, (w, wid, w1lambda, w2lambda) in enumerate(width_taps):
                # Bilinear blend inside the low and the upper depth slice.
                slices = []
                for z in (d, d + did):
                    slices.append(
                        h2lambda * (w2lambda * input[:, :, z, h, w] +
                                    w1lambda * input[:, :, z, h, w + wid]) +
                        h1lambda * (w2lambda * input[:, :, z, h + hid, w] +
                                    w1lambda * input[:, :, z, h + hid, w + wid]))
                out[:, :, i, j, k] = d2lambda * slices[0] + d1lambda * slices[1]
    return out.astype(input.dtype)
class TestTrilinearInterpOp(OpTest):
    """Checks the trilinear_interp op against trilinear_interp_np on float32 data.

    Subclasses customise the scenario by overriding init_test_case.
    """

    def setUp(self):
        # Defaults are set before init_test_case so a case may override them.
        self.out_size = None
        self.actual_shape = None
        self.init_test_case()
        self.op_type = "trilinear_interp"
        input_np = np.random.random(self.input_shape).astype("float32")

        # A positive scale takes precedence over the explicit out_* values
        # when computing the expected output size.
        if self.scale > 0:
            out_d, out_h, out_w = [
                int(self.input_shape[axis] * self.scale) for axis in (2, 3, 4)
            ]
        else:
            out_d, out_h, out_w = self.out_d, self.out_h, self.out_w

        output_np = trilinear_interp_np(input_np, out_d, out_h, out_w,
                                        self.out_size, self.actual_shape,
                                        self.align_corners, self.align_mode)

        op_inputs = {'X': input_np}
        # actual_shape is applied last, so it wins when both are present.
        for size_override in (self.out_size, self.actual_shape):
            if size_override is not None:
                op_inputs['OutSize'] = size_override
        self.inputs = op_inputs

        self.attrs = {
            'out_d': self.out_d,
            'out_h': self.out_h,
            'out_w': self.out_w,
            'scale': self.scale,
            'interp_method': self.interp_method,
            'align_corners': self.align_corners,
            'align_mode': self.align_mode,
        }
        self.outputs = {'Out': output_np}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', in_place=True)

    def init_test_case(self):
        """Default scenario: 4x4x4 volume, out_* of 2 and an out_size of 3x3x3."""
        self.interp_method = 'trilinear'
        self.input_shape = [2, 3, 4, 4, 4]
        self.out_d, self.out_h, self.out_w = 2, 2, 2
        self.scale = 0.
        self.out_size = np.array([3, 3, 3]).astype("int32")
        self.align_mode = 1
        self.align_corners = True
class TestTrilinearInterpCase1(TestTrilinearInterpOp):
    """A 7x8x9 volume resized to a single 1x1x1 output element."""

    def init_test_case(self):
        self.input_shape = [2, 1, 7, 8, 9]
        self.interp_method = 'trilinear'
        self.out_d, self.out_h, self.out_w = 1, 1, 1
        self.scale = 0.
        self.align_mode = 1
        self.align_corners = True
class TestTrilinearInterpCase2(TestTrilinearInterpOp):
    """A 9x6x8 volume resized to 12x12x12."""

    def init_test_case(self):
        self.input_shape = [2, 3, 9, 6, 8]
        self.interp_method = 'trilinear'
        self.out_d, self.out_h, self.out_w = 12, 12, 12
        self.scale = 0.
        self.align_mode = 1
        self.align_corners = True
class TestTrilinearInterpCase3(TestTrilinearInterpOp):
    """A 16x8x4 volume resized to 32x16x8 (each spatial dim doubled)."""

    def init_test_case(self):
        self.input_shape = [3, 2, 16, 8, 4]
        self.interp_method = 'trilinear'
        self.out_d, self.out_h, self.out_w = 32, 16, 8
        self.scale = 0.
        self.align_mode = 1
        self.align_corners = True
class TestTrilinearInterpCase4(TestTrilinearInterpOp):
    """out_* attrs of 1 combined with an out_size of 2x2x2.

    The parent setUp forwards out_size as the OutSize input, and the NumPy
    reference lets it replace out_d/out_h/out_w.
    """

    def init_test_case(self):
        self.input_shape = [4, 1, 7, 8, 9]
        self.interp_method = 'trilinear'
        self.out_d, self.out_h, self.out_w = 1, 1, 1
        self.out_size = np.array([2, 2, 2]).astype("int32")
        self.scale = 0.
        self.align_mode = 1
        self.align_corners = True
class TestTrilinearInterpCase5(TestTrilinearInterpOp):
    """out_* attrs of 12 combined with an out_size of 11x11x11."""

    def init_test_case(self):
        self.input_shape = [3, 3, 9, 6, 8]
        self.interp_method = 'trilinear'
        self.out_d, self.out_h, self.out_w = 12, 12, 12
        self.out_size = np.array([11, 11, 11]).astype("int32")
        self.scale = 0.
        self.align_mode = 1
        self.align_corners = True
class TestTrilinearInterpCase6(TestTrilinearInterpOp):
    """out_* attrs of 8x32x16 combined with an out_size of 17x9x5."""

    def init_test_case(self):
        self.input_shape = [1, 1, 16, 8, 4]
        self.interp_method = 'trilinear'
        self.out_d, self.out_h, self.out_w = 8, 32, 16
        self.out_size = np.array([17, 9, 5]).astype("int32")
        self.scale = 0.
        self.align_mode = 1
        self.align_corners = True
class TestTrilinearInterpSame(TestTrilinearInterpOp):
    """Requested output size equals the input's 16x8x4 spatial dims."""

    def init_test_case(self):
        self.input_shape = [1, 1, 16, 8, 4]
        self.interp_method = 'trilinear'
        self.out_d, self.out_h, self.out_w = 16, 8, 4
        self.scale = 0.
        self.align_mode = 1
        self.align_corners = True
class TestTrilinearInterpSameHW(TestTrilinearInterpOp):
    """Depth goes from 16 to 8 while out_h/out_w keep the input's 8x4."""

    def init_test_case(self):
        self.input_shape = [1, 1, 16, 8, 4]
        self.interp_method = 'trilinear'
        self.out_d, self.out_h, self.out_w = 8, 8, 4
        self.scale = 0.
        self.align_mode = 1
        self.align_corners = True
class TestTrilinearInterpActualShape(TestTrilinearInterpOp):
    """out_* attrs of 64x32x16 combined with an out_size of 33x19x7."""

    def init_test_case(self):
        self.input_shape = [3, 2, 16, 8, 4]
        self.interp_method = 'trilinear'
        self.out_d, self.out_h, self.out_w = 64, 32, 16
        # NOTE(review): despite the class name this sets out_size, not
        # actual_shape; both end up as the OutSize input in the parent setUp.
        self.out_size = np.array([33, 19, 7]).astype("int32")
        self.scale = 0.
        self.align_mode = 1
        self.align_corners = True
class TestTrilinearInterpOpUint8(OpTest):
    """Checks the trilinear_interp op against trilinear_interp_np on uint8 data.

    Only the forward output is checked, on CPUPlace with atol=1.
    """

    def setUp(self):
        # Defaults are set before init_test_case so a case may override them.
        self.out_size = None
        self.actual_shape = None
        self.init_test_case()
        self.op_type = "trilinear_interp"
        input_np = np.random.randint(
            low=0, high=256, size=self.input_shape).astype("uint8")

        # A positive scale takes precedence over the explicit out_* values
        # when computing the expected output size.
        if self.scale > 0:
            out_d, out_h, out_w = [
                int(self.input_shape[axis] * self.scale) for axis in (2, 3, 4)
            ]
        else:
            out_d, out_h, out_w = self.out_d, self.out_h, self.out_w

        output_np = trilinear_interp_np(input_np, out_d, out_h, out_w,
                                        self.out_size, self.actual_shape,
                                        self.align_corners, self.align_mode)

        op_inputs = {'X': input_np}
        if self.out_size is not None:
            op_inputs['OutSize'] = self.out_size
        self.inputs = op_inputs

        self.attrs = {
            'out_d': self.out_d,
            'out_h': self.out_h,
            'out_w': self.out_w,
            'scale': self.scale,
            'interp_method': self.interp_method,
            'align_corners': self.align_corners,
            'align_mode': self.align_mode,
        }
        self.outputs = {'Out': output_np}

    def test_check_output(self):
        self.check_output_with_place(place=core.CPUPlace(), atol=1)

    def init_test_case(self):
        """Default scenario: a 9x6x8 volume resized to 13x10x9."""
        self.interp_method = 'trilinear'
        self.input_shape = [1, 3, 9, 6, 8]
        self.out_d, self.out_h, self.out_w = 13, 10, 9
        self.scale = 0.
        self.align_mode = 1
        self.align_corners = True
class TestTrilinearInterpCase1Uint8(TestTrilinearInterpOpUint8):
    """uint8 case: a 16x8x4 volume resized down to 13x7x2."""

    def init_test_case(self):
        self.input_shape = [2, 3, 16, 8, 4]
        self.interp_method = 'trilinear'
        self.out_d, self.out_h, self.out_w = 13, 7, 2
        self.scale = 0.
        self.align_mode = 1
        self.align_corners = True
class TestTrilinearInterpCase2Uint8(TestTrilinearInterpOpUint8):
    """uint8 case: out_* attrs of 3x5x13 combined with an out_size of 6x15x21."""

    def init_test_case(self):
        self.input_shape = [4, 1, 7, 8, 9]
        self.interp_method = 'trilinear'
        self.out_d, self.out_h, self.out_w = 3, 5, 13
        self.out_size = np.array([6, 15, 21]).astype("int32")
        self.scale = 0.
        self.align_mode = 1
        self.align_corners = True
class TestTrilinearInterpOtherMethod1(TestTrilinearInterpOp):
    """Base trilinear scenario re-run with align_corners=False, align_mode=1."""

    def init_test_case(self):
        # The parent setUp only calls init_test_case, so set_align_mode was
        # never executed and this class silently repeated the base case.
        # Apply the alignment override on top of the parent's scenario.
        super(TestTrilinearInterpOtherMethod1, self).init_test_case()
        self.set_align_mode()

    def set_align_mode(self):
        """Override the alignment settings chosen by the base scenario."""
        self.align_corners = False
        self.align_mode = 1
class TestTrilinearInterpWithMethod2(TestTrilinearInterpOp):
    """Base trilinear scenario re-run with align_corners=False, align_mode=0."""

    def init_test_case(self):
        # The parent setUp only calls init_test_case, so set_align_mode was
        # never executed and this class silently repeated the base case.
        # Apply the alignment override on top of the parent's scenario.
        super(TestTrilinearInterpWithMethod2, self).init_test_case()
        self.set_align_mode()

    def set_align_mode(self):
        """Override the alignment settings chosen by the base scenario."""
        self.align_corners = False
        self.align_mode = 0
class TestTrilinearInterpWithMethod3(TestTrilinearInterpOp):
    """Base trilinear scenario re-run with align_corners=True, align_mode=0."""

    def init_test_case(self):
        # The parent setUp only calls init_test_case, so set_align_mode was
        # never executed and this class silently repeated the base case.
        # Apply the alignment override on top of the parent's scenario.
        super(TestTrilinearInterpWithMethod3, self).init_test_case()
        self.set_align_mode()

    def set_align_mode(self):
        """Override the alignment settings chosen by the base scenario."""
        self.align_corners = True
        self.align_mode = 0
class TestTrilinearInterpScale1(TestTrilinearInterpOp):
    """Scale-driven case with scale=2.

    Because scale is positive, the parent setUp derives the expected size
    from input_shape * scale rather than from out_d/out_h/out_w.
    """

    def init_test_case(self):
        self.input_shape = [2, 3, 5, 7, 9]
        self.interp_method = 'trilinear'
        self.out_d, self.out_h, self.out_w = 82, 60, 25
        self.scale = 2.
        self.align_mode = 1
        self.align_corners = True
class TestTrilinearInterpScale2(TestTrilinearInterpOp):
    """Scale-driven case with scale=1.

    Because scale is positive, the parent setUp derives the expected size
    from input_shape * scale rather than from out_d/out_h/out_w.
    """

    def init_test_case(self):
        self.input_shape = [2, 3, 5, 7, 9]
        self.interp_method = 'trilinear'
        self.out_d, self.out_h, self.out_w = 82, 60, 25
        self.scale = 1.
        self.align_mode = 1
        self.align_corners = True
class TestTrilinearInterpScale3(TestTrilinearInterpOp):
    """Scale-driven case with a fractional scale of 1.5.

    Because scale is positive, the parent setUp derives the expected size
    from int(input_shape * scale) rather than from out_d/out_h/out_w.
    """

    def init_test_case(self):
        self.input_shape = [2, 3, 5, 7, 9]
        self.interp_method = 'trilinear'
        self.out_d, self.out_h, self.out_w = 82, 60, 25
        self.scale = 1.5
        self.align_mode = 1
        self.align_corners = True
class TestTrilinearInterpZero(TestTrilinearInterpOp):
    """Downscaling case: scale=0.2 with align_corners=False and align_mode=0."""

    def init_test_case(self):
        self.input_shape = [2, 3, 5, 7, 11]
        self.interp_method = 'trilinear'
        self.out_d, self.out_h, self.out_w = 82, 60, 25
        self.scale = 0.2
        self.align_mode = 0
        self.align_corners = False
# Run every test case in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册