diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index dc2b51d57242e50c6fb76867a10fbf8ec2f5e52a..a5bd3cd922070deb2c91e1485874c9b3d4e8d278 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -173,10 +173,11 @@ paddle.fluid.layers.label_smooth (ArgSpec(args=['label', 'prior_dist', 'epsilon' paddle.fluid.layers.roi_pool (ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1, 1, 1.0)), ('document', '49368d724023a66b41b0071be41c0ba5')) paddle.fluid.layers.roi_align (ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale', 'sampling_ratio', 'name'], varargs=None, keywords=None, defaults=(1, 1, 1.0, -1, None)), ('document', '9a7a3b88a4fae41d58d3ca9b10ba0591')) paddle.fluid.layers.dice_loss (ArgSpec(args=['input', 'label', 'epsilon'], varargs=None, keywords=None, defaults=(1e-05,)), ('document', '7e8e4bf1f0f8612961ed113e8af8f0c5')) -paddle.fluid.layers.image_resize (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'resample', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, 'BILINEAR', None, True, 1)), ('document', 'a29488d94d9a4bc4434d8a3529b4c6fe')) +paddle.fluid.layers.image_resize (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'resample', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, 'BILINEAR', None, True, 1)), ('document', '8cfc4f69dbbedb687b6c20732aa8f09e')) paddle.fluid.layers.image_resize_short (ArgSpec(args=['input', 'out_short_len', 'resample'], varargs=None, keywords=None, defaults=('BILINEAR',)), ('document', 'bd97ebfe4bdf5110a5fcb8ecb626a447')) -paddle.fluid.layers.resize_bilinear (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, None, True, 1)), ('document', '548c7c2ead5771d15abbaad505f901e9')) -paddle.fluid.layers.resize_nearest (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners'], varargs=None, keywords=None, defaults=(None, None, None, None, True)), ('document', 'b7d810d1e251c5957c1efa6aa699d2d0')) +paddle.fluid.layers.resize_bilinear (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, None, True, 1)), ('document', '832b2412652d84a6631b1012c6e2d18b')) +paddle.fluid.layers.resize_trilinear (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, None, True, 1)), ('document', '4836e98a634f6fbea26d0cdaa303f867')) +paddle.fluid.layers.resize_nearest (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners'], varargs=None, keywords=None, defaults=(None, None, None, None, True)), ('document', '32ffc0e8818d7319ed1bf63a791e985d')) paddle.fluid.layers.gather (ArgSpec(args=['input', 'index', 'overwrite'], varargs=None, keywords=None, defaults=(True,)), ('document', 'f985c9b66e3aec96fa753a8eb44c991c')) paddle.fluid.layers.scatter (ArgSpec(args=['input', 'index', 'updates', 'name', 'overwrite'], varargs=None, keywords=None, defaults=(None, True)), ('document', '69b22affd4a6326502af166f04c095ab')) paddle.fluid.layers.sequence_scatter (ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'abe3f714120117a5a3d3e639853932bf')) diff --git a/paddle/fluid/operators/interpolate_op.cc b/paddle/fluid/operators/interpolate_op.cc index 900b0c636ddafc8c033560adf58d596eb696621f..cd3fdc79acf2c364bdc39e9bdb3192683c8fd4e9 100644 --- a/paddle/fluid/operators/interpolate_op.cc +++ b/paddle/fluid/operators/interpolate_op.cc @@ -20,6 +20,85 @@ namespace operators { using framework::Tensor; +static void Interpolate2DInferShapeCheck(framework::InferShapeContext* ctx) { + auto dim_x = ctx->GetInputDim("X"); + auto interp_method = ctx->Attrs().Get("interp_method"); + + PADDLE_ENFORCE( + "bilinear" == interp_method || "nearest" == interp_method, + "Interpolation method can only be \"bilinear\" or \"nearest\" when " + "Input(X) dimension is 4"); + + int out_h, out_w; + float scale = ctx->Attrs().Get("scale"); + if (scale > 0) { + // round down + out_h = static_cast(dim_x[2] * scale); + out_w = static_cast(dim_x[3] * scale); + // protect when input shape is -1 + out_h = out_h > 0 ? out_h : -1; + out_w = out_w > 0 ? out_w : -1; + } else { + out_h = ctx->Attrs().Get("out_h"); + out_w = ctx->Attrs().Get("out_w"); + PADDLE_ENFORCE_GT(out_h, 0, "out_h should be greater than 0."); + PADDLE_ENFORCE_GT(out_w, 0, "out_w should be greater than 0."); + } + + if (ctx->HasInput("OutSize") && ctx->IsRuntime()) { + auto out_size_dim = ctx->GetInputDim("OutSize"); + PADDLE_ENFORCE_EQ(out_size_dim.size(), 1, + "OutSize's dimension size must be 1"); + PADDLE_ENFORCE_EQ(out_size_dim[0], 2, "OutSize's dim[0] must be 2"); + ctx->ShareLoD("X", "Out"); + return; + } + + std::vector dim_out({dim_x[0], dim_x[1], out_h, out_w}); + ctx->SetOutputDim("Out", framework::make_ddim(dim_out)); +} + +static void Interpolate3DInferShapeCheck(framework::InferShapeContext* ctx) { + auto dim_x = ctx->GetInputDim("X"); + auto interp_method = ctx->Attrs().Get("interp_method"); + + PADDLE_ENFORCE("trilinear" == interp_method, + "Interpolation method can only be \"trilinear\" when Input(X) " + "dimension is 5"); + + int out_d, out_h, out_w; + float scale = ctx->Attrs().Get("scale"); + if (scale > 0) { + // round down + out_d = static_cast(dim_x[2] * scale); + out_h = static_cast(dim_x[3] * scale); + out_w = static_cast(dim_x[4] * scale); + // protect when input shape is -1 + out_d = out_d > 0 ? out_d : -1; + out_h = out_h > 0 ? out_h : -1; + out_w = out_w > 0 ? out_w : -1; + } else { + out_d = ctx->Attrs().Get("out_d"); + out_h = ctx->Attrs().Get("out_h"); + out_w = ctx->Attrs().Get("out_w"); + PADDLE_ENFORCE_GT(out_d, 0, "out_d should be greater than 0."); + PADDLE_ENFORCE_GT(out_h, 0, "out_h should be greater than 0."); + PADDLE_ENFORCE_GT(out_w, 0, "out_w should be greater than 0."); + } + + if (ctx->HasInput("OutSize") && ctx->IsRuntime()) { + auto out_size_dim = ctx->GetInputDim("OutSize"); + PADDLE_ENFORCE_EQ(out_size_dim.size(), 1, + "OutSize's dimension size must be 1"); + PADDLE_ENFORCE_EQ(out_size_dim[0], 3, "OutSize's dim[0] must be 3"); + ctx->ShareLoD("X", "Out"); + return; + } + + std::vector dim_out({dim_x[0], dim_x[1], out_d, out_h, out_w}); + ctx->SetOutputDim("Out", framework::make_ddim(dim_out)); +} + class InterpolateOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -31,41 +110,17 @@ class InterpolateOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of InterpolationOp should not be null."); - auto interp_method = ctx->Attrs().Get("interp_method"); - PADDLE_ENFORCE( - "bilinear" == interp_method || "nearest" == interp_method, - "Interpolation method can only be \"bilinear\" or \"nearest\"."); - auto dim_x = ctx->GetInputDim("X"); // NCHW format - PADDLE_ENFORCE_EQ(dim_x.size(), 4, "X's dimension must be 4"); - - int out_h, out_w; - float scale = ctx->Attrs().Get("scale"); - if (scale > 0) { - // round down - out_h = static_cast(dim_x[2] * scale); - out_w = static_cast(dim_x[3] * scale); - // protect when input shape is -1 - out_h = out_h > 0 ? out_h : -1; - out_w = out_w > 0 ? out_w : -1; - } else { - out_h = ctx->Attrs().Get("out_h"); - out_w = ctx->Attrs().Get("out_w"); - PADDLE_ENFORCE_GT(out_h, 0, "out_h should be greater than 0."); - PADDLE_ENFORCE_GT(out_w, 0, "out_w should be greater than 0."); - } - - if (ctx->HasInput("OutSize") && ctx->IsRuntime()) { - auto out_size_dim = ctx->GetInputDim("OutSize"); - PADDLE_ENFORCE_EQ(out_size_dim.size(), 1, - "OutSize's dimension size must be 1"); - PADDLE_ENFORCE_EQ(out_size_dim[0], 2, "OutSize's dim[0] must be 2"); - ctx->ShareLoD("X", "Out"); - return; + PADDLE_ENFORCE(dim_x.size() == 4 || dim_x.size() == 5, + "Input(X) dimension must be 4 or 5"); + + if (dim_x.size() == 4) { + // shape check for 2D interpolate for input tensor shape NCHW + Interpolate2DInferShapeCheck(ctx); + } else { // dim_x.size() == 5 + // shape check for 3D interpolate for input tensor shape NCDHW + Interpolate3DInferShapeCheck(ctx); } - - std::vector dim_out({dim_x[0], dim_x[1], out_h, out_w}); - ctx->SetOutputDim("Out", framework::make_ddim(dim_out)); } protected: @@ -81,22 +136,27 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker { void Make() override { AddInput("X", "The input tensor of interpolate operator, " - "This is a 4-D tensor with shape of [N, C, H, w]."); + "This is a 4-D tensor with shape of [N, C, H, W] or a " + "5-D tensor with shape of [N, C, D, H, W]."); AddInput("OutSize", "This is a 1-D tensor with two numbers to specify output size. " - "The first number is height and the second number is width.") + "It should be [output_height, output_width] when input is a 4-D " + "tensor and should be [output_depth, output_height, output_width] " + "when input is a 5-D tensor.") .AsDispensable(); AddOutput("Out", "The output tensor of interpolate operator, " - "This is a 4-D tensor with shape of [N, C, H, W]."); + "This is a tensor in same rank with Input(X)."); - AddAttr("out_h", "output height of interpolate op."); - AddAttr("out_w", "output width of interpolate op."); + AddAttr("out_d", "output depth of interpolate op.").SetDefault(0); + AddAttr("out_h", "output height of interpolate op.").SetDefault(0); + AddAttr("out_w", "output width of interpolate op.").SetDefault(0); AddAttr("scale", "scale factor of interpolate op.").SetDefault(0.); AddAttr("interp_method", "(string, default \"bilinear\"), interpolation " "method, can be \"bilinear\" for " - "bilinear interpolation and \"nearest\" for nearest " + "bilinear interpolation, \"trilinear\" for trilinear " + "interpolation and \"nearest\" for nearest " "neighbor interpolation.") .SetDefault("bilinear"); AddAttr( @@ -127,6 +187,11 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker { to perform linear interpolation first in one direction, and then again in the other direction. + Trilinear interpolation is an extension of linear interpolation for + interpolating functions of three variables (e.g. D-direction, + H-direction and W-direction in this op) on a rectilinear 3D grid. + The linear interpolation is performed on three directions. + Align_corners and align_mode are optinal parameters,the calculation method of interpolation can be selected by them. @@ -183,6 +248,27 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker { H_out = H_{in} * scale_{factor} W_out = W_{in} * scale_{factor} + Trilinear interpolation: + + if: + align_corners = False , align_mode = 0 + + input : (N,C,D_in,H_in,W_in) + output: (N,C,D_out,H_out,W_out) where: + + D_out = (D_{in}+0.5) * scale_{factor} - 0.5 + H_out = (H_{in}+0.5) * scale_{factor} - 0.5 + W_out = (W_{in}+0.5) * scale_{factor} - 0.5 + + + else: + + input : (N,C,D_in,H_in,W_in) + output: (N,C,D_out,H_out,W_out) where: + + D_out = D_{in} * scale_{factor} + H_out = H_{in} * scale_{factor} + W_out = W_{in} * scale_{factor} For details of nearest neighbor interpolation, please refer to Wikipedia: @@ -190,6 +276,9 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker { For details of bilinear interpolation, please refer to Wikipedia: https://en.wikipedia.org/wiki/Bilinear_interpolation + + For details of trilinear interpolation, please refer to Wikipedia: + https://en.wikipedia.org/wiki/Trilinear_interpolation )DOC"); } }; @@ -251,6 +340,10 @@ REGISTER_OPERATOR(nearest_interp, ops::InterpolateOp, ops::InterpolateOpMaker, ops::InterpolateGradDescMaker); REGISTER_OPERATOR(nearest_interp_grad, ops::InterpolateOpGrad, ops::InterpolateGradNoNeedBufferVarsInference); +REGISTER_OPERATOR(trilinear_interp, ops::InterpolateOp, ops::InterpolateOpMaker, + ops::InterpolateGradDescMaker); +REGISTER_OPERATOR(trilinear_interp_grad, ops::InterpolateOpGrad, + ops::InterpolateGradNoNeedBufferVarsInference); REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::InterpolateKernel, ops::InterpolateKernel, ops::InterpolateKernel); @@ -261,3 +354,8 @@ REGISTER_OP_CPU_KERNEL(nearest_interp, ops::InterpolateKernel, ops::InterpolateKernel); REGISTER_OP_CPU_KERNEL(nearest_interp_grad, ops::InterpolateGradKernel, ops::InterpolateGradKernel); +REGISTER_OP_CPU_KERNEL(trilinear_interp, ops::InterpolateKernel, + ops::InterpolateKernel, + ops::InterpolateKernel); +REGISTER_OP_CPU_KERNEL(trilinear_interp_grad, ops::InterpolateGradKernel, + ops::InterpolateGradKernel); diff --git a/paddle/fluid/operators/interpolate_op.cu b/paddle/fluid/operators/interpolate_op.cu index 1cdda4cfe90c459b74fe9436654c88206e498b50..cfe441f6c192b5a2cb33bf685cb0cb95b8abe3a7 100644 --- a/paddle/fluid/operators/interpolate_op.cu +++ b/paddle/fluid/operators/interpolate_op.cu @@ -191,80 +191,483 @@ __global__ void KeBilinearInterpBw( } template -class InterpolateOpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), - "This kernel only runs on GPU device."); - auto* input = ctx.Input("X"); - auto* output = ctx.Output("Out"); - auto* input_data = input->data(); +__global__ void KeTrilinearInterpFw( + const T* in, const size_t in_img_d, const size_t in_img_h, + const size_t in_img_w, const size_t input_h, const size_t input_w, T* out, + const size_t out_img_d, const size_t out_img_h, const size_t out_img_w, + const size_t output_h, const size_t output_w, const size_t num_channels, + const float ratio_d, const float ratio_h, const float ratio_w, + const bool align_corners, const int align_mode) { + int nthreads = output_h * output_w; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + bool align_flag = (align_mode == 0 && !align_corners); + for (; tid < nthreads; tid += stride) { + int out_id_h = tid / output_w; + int out_id_w = tid % output_w; + int in_img_size = input_w / num_channels; + int out_img_size = output_w / num_channels; + int channel_id = out_id_w / out_img_size; - int n = input->dims()[0]; - int c = input->dims()[1]; - int in_h = input->dims()[2]; - int in_w = input->dims()[3]; + int out_img_idt = (out_id_w % out_img_size) / out_img_h / out_img_w; + int in_img_idt = align_flag + ? static_cast(ratio_d * (out_img_idt + 0.5) - 0.5) + : static_cast(ratio_d * out_img_idt); + in_img_idt = (in_img_idt > 0) ? in_img_idt : 0; + int d_id = (in_img_idt < in_img_d - 1) ? 1 : 0; + T src_d = ratio_d * (out_img_idt + 0.5) - 0.5; + src_d = (src_d > 0) ? src_d : 0; + T d1lambda = + align_flag ? src_d - in_img_idt : ratio_d * out_img_idt - in_img_idt; + T d2lambda = 1.f - d1lambda; + + int out_img_idy = ((out_id_w % out_img_size) / out_img_w) % out_img_h; + int in_img_idy = align_flag + ? static_cast(ratio_h * (out_img_idy + 0.5) - 0.5) + : static_cast(ratio_h * out_img_idy); + in_img_idy = (in_img_idy > 0) ? in_img_idy : 0; + int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0; + T src_h = ratio_h * (out_img_idy + 0.5) - 0.5; + src_h = (src_h > 0) ? src_h : 0; + T h1lambda = + align_flag ? src_h - in_img_idy : ratio_h * out_img_idy - in_img_idy; + T h2lambda = 1.f - h1lambda; - auto interp_method = ctx.Attr("interp_method"); - int out_h = ctx.Attr("out_h"); - int out_w = ctx.Attr("out_w"); + int out_img_idx = tid % out_img_w; + int in_img_idx = align_flag + ? static_cast(ratio_w * (out_img_idx + 0.5) - 0.5) + : static_cast(ratio_w * out_img_idx); + in_img_idx = (in_img_idx > 0) ? in_img_idx : 0; + int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; + T src_w = ratio_w * (out_img_idx + 0.5) - 0.5; + src_w = (src_w > 0) ? src_w : 0; + T w1lambda = + align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx; + T w2lambda = 1.f - w1lambda; - float scale = ctx.Attr("scale"); - if (scale > 0) { - out_h = in_h * scale; - out_w = in_w * scale; - } + int in_pos1_idx = out_id_h * input_w + channel_id * in_img_size + + (in_img_idt * in_img_h + in_img_idy) * in_img_w + + in_img_idx; + const T* in_pos1 = &in[in_pos1_idx]; + int in_pos2_idx = in_pos1_idx + d_id * in_img_h * in_img_w; + const T* in_pos2 = &in[in_pos2_idx]; - auto out_size = ctx.Input("OutSize"); - if (out_size != nullptr) { - Tensor sizes; - framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes); - auto size_data = sizes.data(); - out_h = size_data[0]; - out_w = size_data[1]; - } + // trilinear interpolation + out[out_id_h * output_w + out_id_w] = + d2lambda * + (h2lambda * (w2lambda * in_pos1[0] + w1lambda * in_pos1[w_id]) + + h1lambda * (w2lambda * in_pos1[h_id * in_img_w] + + w1lambda * in_pos1[h_id * in_img_w + w_id])) + + d1lambda * + (h2lambda * (w2lambda * in_pos2[0] + w1lambda * in_pos2[w_id]) + + h1lambda * (w2lambda * in_pos2[h_id * in_img_w] + + w1lambda * in_pos2[h_id * in_img_w + w_id])); + } +} - bool align_corners = ctx.Attr("align_corners"); - int align_mode = ctx.Attr("align_mode"); +template +__global__ void KeTrilinearInterpBw( + T* in, const size_t in_img_d, const size_t in_img_h, const size_t in_img_w, + const size_t input_h, const size_t input_w, const T* out, + const size_t out_img_d, const size_t out_img_h, const size_t out_img_w, + const size_t output_h, const size_t output_w, const size_t num_channels, + const T ratio_d, const T ratio_h, const T ratio_w, const bool align_corners, + const int align_mode) { + int nthreads = output_h * output_w; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + bool align_flag = (align_mode == 0 && !align_corners); + for (; tid < nthreads; tid += stride) { + int out_id_h = tid / output_w; + int out_id_w = tid % output_w; + int in_img_size = input_w / num_channels; + int out_img_size = output_w / num_channels; + int channel_id = out_id_w / out_img_size; - auto* output_data = - output->mutable_data({n, c, out_h, out_w}, ctx.GetPlace()); + int out_img_idt = (out_id_w % out_img_size) / out_img_h / out_img_w; + int in_img_idt = align_flag + ? static_cast(ratio_d * (out_img_idt + 0.5) - 0.5) + : static_cast(ratio_d * out_img_idt); + in_img_idt = (in_img_idt > 0) ? in_img_idt : 0; + int d_id = (in_img_idt < in_img_d - 1) ? 1 : 0; + T src_d = ratio_d * (out_img_idt + 0.5) - 0.5; + src_d = (src_d > 0) ? src_d : 0; + T d1lambda = + align_flag ? src_d - in_img_idt : ratio_d * out_img_idt - in_img_idt; + T d2lambda = 1.f - d1lambda; + + int out_img_idy = ((out_id_w % out_img_size) / out_img_w) % out_img_h; + int in_img_idy = align_flag + ? static_cast(ratio_h * (out_img_idy + 0.5) - 0.5) + : static_cast(ratio_h * out_img_idy); + in_img_idy = (in_img_idy > 0) ? in_img_idy : 0; + int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0; + T src_h = ratio_h * (out_img_idy + 0.5) - 0.5; + src_h = (src_h > 0) ? src_h : 0; + T h1lambda = + align_flag ? src_h - in_img_idy : ratio_h * out_img_idy - in_img_idy; + T h2lambda = 1.f - h1lambda; + + int out_img_idx = tid % out_img_w; + int in_img_idx = align_flag + ? static_cast(ratio_w * (out_img_idx + 0.5) - 0.5) + : static_cast(ratio_w * out_img_idx); + in_img_idx = (in_img_idx > 0) ? in_img_idx : 0; + int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; + T src_w = ratio_w * (out_img_idx + 0.5) - 0.5; + src_w = (src_w > 0) ? src_w : 0; + T w1lambda = + align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx; + T w2lambda = 1.f - w1lambda; - int in_hw = in_h * in_w; - int out_hw = out_h * out_w; - int in_chw = c * in_hw; - int out_chw = c * out_hw; + int in_pos1_idx = out_id_h * input_w + channel_id * in_img_size + + (in_img_idt * in_img_h + in_img_idy) * in_img_w + + in_img_idx; + T* in_pos1 = &in[in_pos1_idx]; + int in_pos2_idx = in_pos1_idx + d_id * in_img_h * in_img_w; + T* in_pos2 = &in[in_pos2_idx]; - float ratio_h = 0.f; - float ratio_w = 0.f; - if (out_h > 1) { - ratio_h = (align_corners) ? static_cast(in_h - 1) / (out_h - 1) - : static_cast(in_h) / out_h; - } - if (out_w > 1) { - ratio_w = (align_corners) ? static_cast(in_w - 1) / (out_w - 1) - : static_cast(in_w) / out_w; - } + const T* out_pos = &out[out_id_h * output_w + out_id_w]; - if (in_h == out_h && in_w == out_w) { - framework::TensorCopy(*input, ctx.GetPlace(), output); - return; - } + // trilinear interpolation grad + platform::CudaAtomicAdd(&in_pos1[0], + d2lambda * h2lambda * w2lambda * out_pos[0]); + platform::CudaAtomicAdd(&in_pos1[w_id], + d2lambda * h2lambda * w1lambda * out_pos[0]); + platform::CudaAtomicAdd(&in_pos1[h_id * in_img_w], + d2lambda * h1lambda * w2lambda * out_pos[0]); + platform::CudaAtomicAdd(&in_pos1[h_id * in_img_w + w_id], + d2lambda * h1lambda * w1lambda * out_pos[0]); + platform::CudaAtomicAdd(&in_pos2[0], + d1lambda * h2lambda * w2lambda * out_pos[0]); + platform::CudaAtomicAdd(&in_pos2[w_id], + d1lambda * h2lambda * w1lambda * out_pos[0]); + platform::CudaAtomicAdd(&in_pos2[h_id * in_img_w], + d1lambda * h1lambda * w2lambda * out_pos[0]); + platform::CudaAtomicAdd(&in_pos2[h_id * in_img_w + w_id], + d1lambda * h1lambda * w1lambda * out_pos[0]); + } +} - int pixelNum = n * out_chw; - int grid_dim = (pixelNum + 512 - 1) / 512; - grid_dim = grid_dim > 8 ? 8 : grid_dim; - - if ("nearest" == interp_method) { - KeNearestNeighborInterpFw< - T><<>>( - input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n, - out_chw, c, ratio_h, ratio_w, align_corners); - } else if ("bilinear" == interp_method) { - KeBilinearInterpFw< - T><<>>( - input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n, - out_chw, c, ratio_h, ratio_w, align_corners, align_mode); +template +static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx, + const Tensor& input, Tensor* output) { + auto* input_data = input.data(); + + const int n = input.dims()[0]; + const int c = input.dims()[1]; + const int in_h = input.dims()[2]; + const int in_w = input.dims()[3]; + + auto interp_method = ctx.Attr("interp_method"); + bool align_corners = ctx.Attr("align_corners"); + int align_mode = ctx.Attr("align_mode"); + + int out_h = ctx.Attr("out_h"); + int out_w = ctx.Attr("out_w"); + float scale = ctx.Attr("scale"); + if (scale > 0) { + out_h = static_cast(in_h * scale); + out_w = static_cast(in_w * scale); + } + + auto out_size = ctx.Input("OutSize"); + if (out_size != nullptr) { + Tensor sizes; + framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes); + auto size_data = sizes.data(); + out_h = size_data[0]; + out_w = size_data[1]; + } + + auto output_data = + output->mutable_data({n, c, out_h, out_w}, ctx.GetPlace()); + + if (in_h == out_h && in_w == out_w) { + framework::TensorCopy(input, ctx.GetPlace(), output); + return; + } + + float ratio_h = 0.f; + float ratio_w = 0.f; + if (out_h > 1) { + ratio_h = (align_corners) ? static_cast(in_h - 1) / (out_h - 1) + : static_cast(in_h) / out_h; + } + if (out_w > 1) { + ratio_w = (align_corners) ? static_cast(in_w - 1) / (out_w - 1) + : static_cast(in_w) / out_w; + } + + int in_hw = in_h * in_w; + int out_hw = out_h * out_w; + int in_chw = c * in_hw; + int out_chw = c * out_hw; + + int pixelNum = n * out_chw; + int grid_dim = (pixelNum + 512 - 1) / 512; + grid_dim = grid_dim > 8 ? 8 : grid_dim; + + if ("nearest" == interp_method) { + KeNearestNeighborInterpFw< + T><<>>( + input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n, + out_chw, c, ratio_h, ratio_w, align_corners); + } else if ("bilinear" == interp_method) { + KeBilinearInterpFw< + T><<>>( + input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n, + out_chw, c, ratio_h, ratio_w, align_corners, align_mode); + } +} + +template +static void Interpolate3DCUDAFwd(const framework::ExecutionContext& ctx, + const Tensor& input, Tensor* output) { + auto* input_data = input.data(); + + const int n = input.dims()[0]; + const int c = input.dims()[1]; + const int in_d = input.dims()[2]; + const int in_h = input.dims()[3]; + const int in_w = input.dims()[4]; + + auto interp_method = ctx.Attr("interp_method"); + bool align_corners = ctx.Attr("align_corners"); + int align_mode = ctx.Attr("align_mode"); + + int out_d = ctx.Attr("out_d"); + int out_h = ctx.Attr("out_h"); + int out_w = ctx.Attr("out_w"); + float scale = ctx.Attr("scale"); + if (scale > 0) { + out_d = static_cast(in_d * scale); + out_h = static_cast(in_h * scale); + out_w = static_cast(in_w * scale); + } + + auto out_size = ctx.Input("OutSize"); + if (out_size != nullptr) { + Tensor sizes; + framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes); + auto size_data = sizes.data(); + out_d = size_data[0]; + out_h = size_data[1]; + out_w = size_data[2]; + } + + auto output_data = + output->mutable_data({n, c, out_d, out_h, out_w}, ctx.GetPlace()); + + if (in_d == out_d && in_h == out_h && in_w == out_w) { + framework::TensorCopy(input, ctx.GetPlace(), output); + return; + } + + float ratio_d = 0.f; + float ratio_h = 0.f; + float ratio_w = 0.f; + if (out_d > 1) { + ratio_d = (align_corners) ? static_cast(in_d - 1) / (out_d - 1) + : static_cast(in_d) / out_d; + } + if (out_h > 1) { + ratio_h = (align_corners) ? static_cast(in_h - 1) / (out_h - 1) + : static_cast(in_h) / out_h; + } + if (out_w > 1) { + ratio_w = (align_corners) ? static_cast(in_w - 1) / (out_w - 1) + : static_cast(in_w) / out_w; + } + + int in_dhw = in_d * in_h * in_w; + int out_dhw = out_d * out_h * out_w; + int in_cdhw = c * in_dhw; + int out_cdhw = c * out_dhw; + + int pixelNum = n * out_cdhw; + int grid_dim = (pixelNum + 512 - 1) / 512; + grid_dim = grid_dim > 8 ? 8 : grid_dim; + + if ("trilinear" == interp_method) { + KeTrilinearInterpFw< + T><<>>( + input_data, in_d, in_h, in_w, n, in_cdhw, output_data, out_d, out_h, + out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners, + align_mode); + } +} + +template +static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx, + Tensor* input_grad, const Tensor output_grad) { + auto* input = ctx.Input("X"); + const int n = input->dims()[0]; + const int c = input->dims()[1]; + const int in_h = input->dims()[2]; + const int in_w = input->dims()[3]; + + auto interp_method = ctx.Attr("interp_method"); + bool align_corners = ctx.Attr("align_corners"); + int align_mode = ctx.Attr("align_mode"); + + int out_h = ctx.Attr("out_h"); + int out_w = ctx.Attr("out_w"); + float scale = ctx.Attr("scale"); + if (scale > 0) { + out_h = static_cast(in_h * scale); + out_w = static_cast(in_w * scale); + } + + auto out_size = ctx.Input("OutSize"); + if (out_size != nullptr) { + Tensor sizes; + framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes); + auto size_data = sizes.data(); + out_h = size_data[0]; + out_w = size_data[1]; + } + + auto* output_grad_data = output_grad.data(); + auto* input_grad_data = + input_grad->mutable_data({n, c, in_h, in_w}, ctx.GetPlace()); + auto& device_ctx = ctx.template device_context(); + math::SetConstant zero; + zero(device_ctx, input_grad, static_cast(0.0)); + + if (in_h == out_h && in_w == out_w) { + framework::TensorCopy(output_grad, ctx.GetPlace(), input_grad); + return; + } + + float ratio_h = 0.f; + float ratio_w = 0.f; + if (out_h > 1) { + ratio_h = (align_corners) ? static_cast(in_h - 1) / (out_h - 1) + : static_cast(in_h) / out_h; + } + if (out_w > 1) { + ratio_w = (align_corners) ? static_cast(in_w - 1) / (out_w - 1) + : static_cast(in_w) / out_w; + } + + int in_hw = in_h * in_w; + int out_hw = out_h * out_w; + int in_chw = c * in_hw; + int out_chw = c * out_hw; + + int pixelNum = n * out_chw; + int grid_dim = (pixelNum + 512 - 1) / 512; + grid_dim = grid_dim > 8 ? 8 : grid_dim; + + if ("nearest" == interp_method) { + KeNearestNeighborInterpBw< + T><<>>( + input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w, + n, out_chw, c, ratio_h, ratio_w, align_corners); + } else if ("bilinear" == interp_method) { + KeBilinearInterpBw< + T><<>>( + input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w, + n, out_chw, c, ratio_h, ratio_w, align_corners, align_mode); + } +} + +template +static void Interpolate3DCUDABwd(const framework::ExecutionContext& ctx, + Tensor* input_grad, + const Tensor& output_grad) { + auto* input = ctx.Input("X"); + const int n = input->dims()[0]; + const int c = input->dims()[1]; + const int in_d = input->dims()[2]; + const int in_h = input->dims()[3]; + const int in_w = input->dims()[4]; + + auto interp_method = ctx.Attr("interp_method"); + bool align_corners = ctx.Attr("align_corners"); + int align_mode = ctx.Attr("align_mode"); + + int out_d = ctx.Attr("out_d"); + int out_h = ctx.Attr("out_h"); + int out_w = ctx.Attr("out_w"); + float scale = ctx.Attr("scale"); + if (scale > 0) { + out_d = static_cast(in_d * scale); + out_h = static_cast(in_h * scale); + out_w = static_cast(in_w * scale); + } + + auto out_size = ctx.Input("OutSize"); + if (out_size != nullptr) { + Tensor sizes; + framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes); + auto size_data = sizes.data(); + out_d = size_data[0]; + out_h = size_data[1]; + out_w = size_data[2]; + } + + auto* output_grad_data = output_grad.data(); + auto* input_grad_data = + input_grad->mutable_data({n, c, in_d, in_h, in_w}, ctx.GetPlace()); + auto& device_ctx = ctx.template device_context(); + math::SetConstant zero; + zero(device_ctx, input_grad, static_cast(0.0)); + + if (in_d == out_d && in_h == out_h && in_w == out_w) { + framework::TensorCopy(output_grad, ctx.GetPlace(), input_grad); + return; + } + + float ratio_d = 0.f; + float ratio_h = 0.f; + float ratio_w = 0.f; + if (out_d > 1) { + ratio_d = (align_corners) ? static_cast(in_d - 1) / (out_d - 1) + : static_cast(in_d) / out_d; + } + if (out_h > 1) { + ratio_h = (align_corners) ? static_cast(in_h - 1) / (out_h - 1) + : static_cast(in_h) / out_h; + } + if (out_w > 1) { + ratio_w = (align_corners) ? static_cast(in_w - 1) / (out_w - 1) + : static_cast(in_w) / out_w; + } + + int in_dhw = in_d * in_h * in_w; + int out_dhw = out_d * out_h * out_w; + int in_cdhw = c * in_dhw; + int out_cdhw = c * out_dhw; + + int pixelNum = n * out_cdhw; + int grid_dim = (pixelNum + 512 - 1) / 512; + grid_dim = grid_dim > 8 ? 8 : grid_dim; + + if ("trilinear" == interp_method) { + KeTrilinearInterpBw< + T><<>>( + input_grad_data, in_d, in_h, in_w, n, in_cdhw, output_grad_data, out_d, + out_h, out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners, + align_mode); + } +} + +template +class InterpolateOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "This kernel only runs on GPU device."); + auto* input = ctx.Input("X"); + auto* output = ctx.Output("Out"); + + auto input_dims = input->dims(); + if (input_dims.size() == 4) { // 2D interpolation + Interpolate2DCUDAFwd(ctx, *input, output); + } else if (input_dims.size() == 5) { // 3D interpolation + Interpolate3DCUDAFwd(ctx, *input, output); } } }; @@ -273,76 +676,16 @@ template class InterpolateGradOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "This kernel only runs on GPU device."); auto* input_grad = ctx.Output(framework::GradVarName("X")); auto* output_grad = ctx.Input(framework::GradVarName("Out")); - auto* output_grad_data = output_grad->data(); - auto* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); - - auto& device_ctx = - ctx.template device_context(); - math::SetConstant zero; - zero(device_ctx, input_grad, static_cast(0.0)); - - int n = input_grad->dims()[0]; - int c = input_grad->dims()[1]; - int in_h = input_grad->dims()[2]; - int in_w = input_grad->dims()[3]; - - auto interp_method = ctx.Attr("interp_method"); - int out_h = ctx.Attr("out_h"); - int out_w = ctx.Attr("out_w"); - float scale = ctx.Attr("scale"); - if (scale > 0) { - out_h = in_h * scale; - out_w = in_w * scale; - } - auto out_size = ctx.Input("OutSize"); - if (out_size != nullptr) { - Tensor sizes; - framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes); - auto size_data = sizes.data(); - out_h = size_data[0]; - out_w = size_data[1]; - } - - bool align_corners = ctx.Attr("align_corners"); - int align_mode = ctx.Attr("align_mode"); - int in_hw = in_h * in_w; - int out_hw = out_h * out_w; - int in_chw = c * in_hw; - int out_chw = c * out_hw; - - float ratio_h = 0.f; - float ratio_w = 0.f; - if (out_h > 1) { - ratio_h = (align_corners) ? static_cast(in_h - 1) / (out_h - 1) - : static_cast(in_h) / out_h; - } - if (out_w > 1) { - ratio_w = (align_corners) ? static_cast(in_w - 1) / (out_w - 1) - : static_cast(in_w) / out_w; - } - - if (in_h == out_h && in_w == out_w) { - framework::TensorCopy(*output_grad, ctx.GetPlace(), input_grad); - return; - } - - int pixelNum = n * out_chw; - int grid_dim = (pixelNum + 512 - 1) / 512; - grid_dim = grid_dim > 8 ? 8 : grid_dim; - - if ("nearest" == interp_method) { - KeNearestNeighborInterpBw< - T><<>>( - input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, - out_w, n, out_chw, c, ratio_h, ratio_w, align_corners); - } else if ("bilinear" == interp_method) { - KeBilinearInterpBw< - T><<>>( - input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, - out_w, n, out_chw, c, ratio_h, ratio_w, align_corners, align_mode); + auto output_grad_dims = output_grad->dims(); + if (output_grad_dims.size() == 4) { // 2D interpolation + Interpolate2DCUDABwd(ctx, input_grad, *output_grad); + } else if (output_grad_dims.size() == 5) { // 3D interpolation + Interpolate3DCUDABwd(ctx, input_grad, *output_grad); } } }; @@ -363,3 +706,9 @@ REGISTER_OP_CUDA_KERNEL(nearest_interp, ops::InterpolateOpCUDAKernel, REGISTER_OP_CUDA_KERNEL(nearest_interp_grad, ops::InterpolateGradOpCUDAKernel, ops::InterpolateGradOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL(trilinear_interp, ops::InterpolateOpCUDAKernel, + ops::InterpolateOpCUDAKernel, + ops::InterpolateOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL(trilinear_interp_grad, + ops::InterpolateGradOpCUDAKernel, + ops::InterpolateGradOpCUDAKernel); diff --git a/paddle/fluid/operators/interpolate_op.h b/paddle/fluid/operators/interpolate_op.h index bd33abb98f2f1a6ad75b64e37ca14b411a4a168e..8fffe1ca48ef0f4fed20c7b1108bec755c1dc64f 100644 --- a/paddle/fluid/operators/interpolate_op.h +++ b/paddle/fluid/operators/interpolate_op.h @@ -131,6 +131,128 @@ static void BilinearInterpolation(const Tensor& input, Tensor* output, } } +template +static void TrilinearInterpolation( + const Tensor& input, Tensor* output, const float ratio_d, + const float ratio_h, const float ratio_w, const int in_d, const int in_h, + const int in_w, const int n, const int c, const int out_d, const int out_h, + const int out_w, const bool align_corners, const bool align_mode) { + auto input_t = EigenTensor::From(input); + auto output_t = EigenTensor::From(*output); + bool align_flag = (align_mode == 0 && !align_corners); + + std::vector vt_f, vt_b; + std::vector vd_f, vd_b; + vt_f.reserve(out_d); + vt_b.reserve(out_d); + vd_f.reserve(out_d); + vd_b.reserve(out_d); +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (int j = 0; j < out_d; j++) { + int t_f = align_flag ? static_cast(ratio_d * (j + 0.5) - 0.5) + : static_cast(ratio_d * j); + t_f = (t_f > 0) ? t_f : 0; + int t_b = (t_f + 1) < (in_d - 1) ? (t_f + 1) : (in_d - 1); + float idx_src_t = ratio_d * (j + 0.5) - 0.5; + idx_src_t = (idx_src_t > 0) ? idx_src_t : 0; + float d_f = align_flag ? idx_src_t - t_f : ratio_d * j - t_f; + float d_b = 1.f - d_f; + { + vt_f[j] = t_f; + vt_b[j] = t_b; + vd_f[j] = d_f; + vd_b[j] = d_b; + } + } + + std::vector vy_n, vy_s; + std::vector vd_n, vd_s; + vy_n.reserve(out_h); + vy_s.reserve(out_h); + vd_n.reserve(out_h); + vd_s.reserve(out_h); +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (int k = 0; k < out_h; k++) { + int y_n = align_flag ? static_cast(ratio_h * (k + 0.5) - 0.5) + : static_cast(ratio_h * k); + y_n = (y_n > 0) ? y_n : 0; + int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1); + float idx_src_y = ratio_h * (k + 0.5) - 0.5; + idx_src_y = (idx_src_y > 0) ? idx_src_y : 0; + float d_n = align_flag ? idx_src_y - y_n : ratio_h * k - y_n; + float d_s = 1.f - d_n; + { + vy_n[k] = y_n; + vy_s[k] = y_s; + vd_n[k] = d_n; + vd_s[k] = d_s; + } + } + + std::vector vx_w, vx_e; + std::vector vd_w, vd_e; + vx_w.reserve(out_w); + vx_e.reserve(out_w); + vd_w.reserve(out_w); + vd_e.reserve(out_w); +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (int l = 0; l < out_w; l++) { + int x_w = (align_mode == 0 && !align_corners) + ? static_cast(ratio_w * (l + 0.5) - 0.5) + : static_cast(ratio_w * l); + x_w = (x_w > 0) ? x_w : 0; + int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1); + float idx_src_x = ratio_w * (l + 0.5) - 0.5; + idx_src_x = (idx_src_x > 0) ? idx_src_x : 0; + float d_w = align_flag ? idx_src_x - x_w : ratio_w * l - x_w; + float d_e = 1.f - d_w; + { + vx_w[l] = x_w; + vx_e[l] = x_e; + vd_w[l] = d_w; + vd_e[l] = d_e; + } + } + +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for collapse(5) +#endif + for (int b = 0; b < n; b++) { // loop for batches + for (int i = 0; i < c; i++) { // loop for channels + for (int j = 0; j < out_d; j++) { // loop for D, H, W + for (int k = 0; k < out_h; k++) { + for (int l = 0; l < out_w; l++) { + // trilinear interpolation + T out_t = input_t(b, i, vt_f[j], vy_n[k], vx_w[l]) * vd_b[j] * + vd_s[k] * vd_e[l] + + input_t(b, i, vt_f[j], vy_n[k], vx_e[l]) * vd_b[j] * + vd_s[k] * vd_w[l] + + input_t(b, i, vt_f[j], vy_s[k], vx_w[l]) * vd_b[j] * + vd_n[k] * vd_e[l] + + input_t(b, i, vt_f[j], vy_s[k], vx_e[l]) * vd_b[j] * + vd_n[k] * vd_w[l] + + input_t(b, i, vt_b[j], vy_n[k], vx_w[l]) * vd_f[j] * + vd_s[k] * vd_e[l] + + input_t(b, i, vt_b[j], vy_n[k], vx_e[l]) * vd_f[j] * + vd_s[k] * vd_w[l] + + input_t(b, i, vt_b[j], vy_s[k], vx_w[l]) * vd_f[j] * + vd_n[k] * vd_e[l] + + input_t(b, i, vt_b[j], vy_s[k], vx_e[l]) * vd_f[j] * + vd_n[k] * vd_w[l]; + output_t(b, i, j, k, l) = out_t; + } + } + } + } + } +} + template static void NearestNeighborInterpolateGrad( const Tensor& output_grad, Tensor* input_grad, const float ratio_h, @@ -200,134 +322,340 @@ static void BilinearInterpolationGrad(const Tensor& output_grad, } } } + template -class InterpolateKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* input = ctx.Input("X"); - auto* output = ctx.Output("Out"); +static void TrilinearInterpolationGrad( + const Tensor& output_grad, Tensor* input_grad, const float ratio_d, + const float ratio_h, const float ratio_w, const int in_d, const int in_h, + const int in_w, const int n, const int c, const int out_d, const int out_h, + const int out_w, const bool align_corners, const int align_mode) { + auto input_grad_t = EigenTensor::From(*input_grad); + auto output_grad_t = EigenTensor::From(output_grad); + bool align_flag = (align_mode == 0 && !align_corners); + for (int j = 0; j < out_d; j++) { // loop for D + int t_f = align_flag ? static_cast(ratio_d * (j + 0.5) - 0.5) + : static_cast(ratio_d * j); + t_f = (t_f > 0) ? t_f : 0; + int t_b = (t_f + 1) < (in_d - 1) ? (t_f + 1) : (in_d - 1); + float idx_src_t = ratio_d * (j + 0.5) - 0.5; + idx_src_t = (idx_src_t > 0) ? idx_src_t : 0; + float d_f = align_flag ? idx_src_t - t_f : ratio_d * j - t_f; + float d_b = 1.f - d_f; + + for (int k = 0; k < out_h; k++) { // loop for H + int y_n = align_flag ? static_cast(ratio_h * (k + 0.5) - 0.5) + : static_cast(ratio_h * k); + y_n = (y_n > 0) ? y_n : 0; + int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1); + float idx_src_y = ratio_h * (k + 0.5) - 0.5; + idx_src_y = (idx_src_y > 0) ? idx_src_y : 0; + float d_n = align_flag ? idx_src_y - y_n : ratio_h * k - y_n; + float d_s = 1.f - d_n; + + for (int l = 0; l < out_w; l++) { // loop for W + int x_w = align_flag ? static_cast(ratio_w * (l + 0.5) - 0.5) + : static_cast(ratio_w * l); + x_w = (x_w > 0) ? x_w : 0; + int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1); + float idx_src_x = ratio_w * (l + 0.5) - 0.5; + idx_src_x = (idx_src_x > 0) ? idx_src_x : 0; + float d_w = align_flag ? idx_src_x - x_w : ratio_w * l - x_w; + float d_e = 1.f - d_w; + + for (int b = 0; b < n; b++) { // loop for batches + for (int i = 0; i < c; i++) { // loop for channels + // trilinear interpolation grad + const T grad = output_grad_t(b, i, j, k, l); + input_grad_t(b, i, t_f, y_n, x_w) += + static_cast(grad * d_b * d_s * d_e); + input_grad_t(b, i, t_f, y_n, x_e) += + static_cast(grad * d_b * d_s * d_w); + input_grad_t(b, i, t_f, y_s, x_w) += + static_cast(grad * d_b * d_n * d_e); + input_grad_t(b, i, t_f, y_s, x_e) += + static_cast(grad * d_b * d_n * d_w); + input_grad_t(b, i, t_b, y_n, x_w) += + static_cast(grad * d_f * d_s * d_e); + input_grad_t(b, i, t_b, y_n, x_e) += + static_cast(grad * d_f * d_s * d_w); + input_grad_t(b, i, t_b, y_s, x_w) += + static_cast(grad * d_f * d_n * d_e); + input_grad_t(b, i, t_b, y_s, x_e) += + static_cast(grad * d_f * d_n * d_w); + } + } + } + } + } +} - const int n = input->dims()[0]; - const int c = input->dims()[1]; - const int in_h = input->dims()[2]; - const int in_w = input->dims()[3]; +template +static void Interpolate2DCPUFwd(const framework::ExecutionContext& ctx, + const Tensor& input, Tensor* output) { + const int n = input.dims()[0]; + const int c = input.dims()[1]; + const int in_h = input.dims()[2]; + const int in_w = input.dims()[3]; + + auto interp_method = ctx.Attr("interp_method"); + bool align_corners = ctx.Attr("align_corners"); + int align_mode = ctx.Attr("align_mode"); + + int out_h = ctx.Attr("out_h"); + int out_w = ctx.Attr("out_w"); + float scale = ctx.Attr("scale"); + if (scale > 0) { + out_h = static_cast(in_h * scale); + out_w = static_cast(in_w * scale); + } - std::string interp_method = ctx.Attr("interp_method"); - int out_h = ctx.Attr("out_h"); - int out_w = ctx.Attr("out_w"); + auto out_size = ctx.Input("OutSize"); + if (out_size != nullptr) { + auto out_size_data = out_size->data(); + out_h = out_size_data[0]; + out_w = out_size_data[1]; + } - float scale = ctx.Attr("scale"); - if (scale > 0) { - out_h = static_cast(in_h * scale); - out_w = static_cast(in_w * scale); - } + output->mutable_data({n, c, out_h, out_w}, ctx.GetPlace()); - auto out_size = ctx.Input("OutSize"); - if (out_size != nullptr) { - auto out_size_data = out_size->data(); - out_h = out_size_data[0]; - out_w = out_size_data[1]; - } - bool align_corners = ctx.Attr("align_corners"); - int align_mode = ctx.Attr("align_mode"); - - output->mutable_data({n, c, out_h, out_w}, ctx.GetPlace()); - auto& device_ctx = - ctx.template device_context(); - math::SetConstant zero; - zero(device_ctx, output, static_cast(0.0)); - - if (in_h == out_h && in_w == out_w) { - framework::TensorCopy(*input, ctx.GetPlace(), output); - return; - } + if (in_h == out_h && in_w == out_w) { + framework::TensorCopy(input, ctx.GetPlace(), output); + return; + } - float ratio_h = 0.f; - float ratio_w = 0.f; + float ratio_h = 0.f; + float ratio_w = 0.f; + if (out_h > 1) { + ratio_h = (align_corners) ? static_cast(in_h - 1) / (out_h - 1) + : static_cast(in_h) / out_h; + } + if (out_w > 1) { + ratio_w = (align_corners) ? static_cast(in_w - 1) / (out_w - 1) + : static_cast(in_w) / out_w; + } - if (out_h > 1) { - ratio_h = (align_corners) ? static_cast(in_h - 1) / (out_h - 1) - : static_cast(in_h) / out_h; - } - if (out_w > 1) { - ratio_w = (align_corners) ? static_cast(in_w - 1) / (out_w - 1) - : static_cast(in_w) / out_w; - } + if ("bilinear" == interp_method) { + BilinearInterpolation(input, output, ratio_h, ratio_w, in_h, in_w, n, c, + out_h, out_w, align_corners, align_mode); + } else if ("nearest" == interp_method) { + NearestNeighborInterpolate(input, output, ratio_h, ratio_w, n, c, out_h, + out_w, align_corners); + } +} - if ("bilinear" == interp_method) { - BilinearInterpolation(*input, output, ratio_h, ratio_w, in_h, in_w, n, - c, out_h, out_w, align_corners, align_mode); - } else if ("nearest" == interp_method) { - NearestNeighborInterpolate(*input, output, ratio_h, ratio_w, n, c, - out_h, out_w, align_corners); - } +template +static void Interpolate3DCPUFwd(const framework::ExecutionContext& ctx, + const Tensor& input, Tensor* output) { + const int n = input.dims()[0]; + const int c = input.dims()[1]; + const int in_d = input.dims()[2]; + const int in_h = input.dims()[3]; + const int in_w = input.dims()[4]; + + auto interp_method = ctx.Attr("interp_method"); + bool align_corners = ctx.Attr("align_corners"); + int align_mode = ctx.Attr("align_mode"); + + int out_d = ctx.Attr("out_d"); + int out_h = ctx.Attr("out_h"); + int out_w = ctx.Attr("out_w"); + float scale = ctx.Attr("scale"); + if (scale > 0) { + out_d = static_cast(in_d * scale); + out_h = static_cast(in_h * scale); + out_w = static_cast(in_w * scale); + } + + auto out_size = ctx.Input("OutSize"); + if (out_size != nullptr) { + auto out_size_data = out_size->data(); + out_d = out_size_data[0]; + out_h = out_size_data[1]; + out_w = out_size_data[2]; + } + + output->mutable_data({n, c, out_d, out_h, out_w}, ctx.GetPlace()); + + if (in_d == out_d && in_h == out_h && in_w == out_w) { + framework::TensorCopy(input, ctx.GetPlace(), output); + return; + } + + float ratio_d = 0.f; + float ratio_h = 0.f; + float ratio_w = 0.f; + if (out_d > 1) { + ratio_d = (align_corners) ? static_cast(in_d - 1) / (out_d - 1) + : static_cast(in_d) / out_d; + } + if (out_h > 1) { + ratio_h = (align_corners) ? static_cast(in_h - 1) / (out_h - 1) + : static_cast(in_h) / out_h; + } + if (out_w > 1) { + ratio_w = (align_corners) ? static_cast(in_w - 1) / (out_w - 1) + : static_cast(in_w) / out_w; } -}; + + if ("trilinear" == interp_method) { + TrilinearInterpolation(input, output, ratio_d, ratio_h, ratio_w, in_d, + in_h, in_w, n, c, out_d, out_h, out_w, + align_corners, align_mode); + } +} template -class InterpolateGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* input = ctx.Input("X"); - auto* input_grad = ctx.Output(framework::GradVarName("X")); - auto* output_grad = ctx.Input(framework::GradVarName("Out")); +static void Interpolate2DCPUBwd(const framework::ExecutionContext& ctx, + Tensor* input_grad, const Tensor& output_grad) { + auto* input = ctx.Input("X"); + const int n = input->dims()[0]; + const int c = input->dims()[1]; + const int in_h = input->dims()[2]; + const int in_w = input->dims()[3]; + + auto interp_method = ctx.Attr("interp_method"); + bool align_corners = ctx.Attr("align_corners"); + int align_mode = ctx.Attr("align_mode"); + + int out_h = ctx.Attr("out_h"); + int out_w = ctx.Attr("out_w"); + float scale = ctx.Attr("scale"); + if (scale > 0) { + out_h = static_cast(in_h * scale); + out_w = static_cast(in_w * scale); + } - const int n = input->dims()[0]; - const int c = input->dims()[1]; - const int in_h = input->dims()[2]; - const int in_w = input->dims()[3]; + auto out_size = ctx.Input("OutSize"); + if (out_size != nullptr) { + auto out_size_data = out_size->data(); + out_h = out_size_data[0]; + out_w = out_size_data[1]; + } - std::string interp_method = ctx.Attr("interp_method"); - int out_h = ctx.Attr("out_h"); - int out_w = ctx.Attr("out_w"); + input_grad->mutable_data({n, c, in_h, in_w}, ctx.GetPlace()); + auto& device_ctx = ctx.template device_context(); + math::SetConstant zero; + zero(device_ctx, input_grad, static_cast(0.0)); - float scale = ctx.Attr("scale"); - if (scale > 0) { - out_h = static_cast(in_h * scale); - out_w = static_cast(in_w * scale); - } + if (in_h == out_h && in_w == out_w) { + framework::TensorCopy(output_grad, ctx.GetPlace(), input_grad); + return; + } - auto out_size = ctx.Input("OutSize"); - if (out_size != nullptr) { - auto out_size_data = out_size->data(); - out_h = out_size_data[0]; - out_w = out_size_data[1]; - } + float ratio_h = 0.f; + float ratio_w = 0.f; + if (out_h > 1) { + ratio_h = (align_corners) ? static_cast(in_h - 1) / (out_h - 1) + : static_cast(in_h) / out_h; + } + if (out_w > 1) { + ratio_w = (align_corners) ? static_cast(in_w - 1) / (out_w - 1) + : static_cast(in_w) / out_w; + } + + if ("bilinear" == interp_method) { + BilinearInterpolationGrad(output_grad, input_grad, ratio_h, ratio_w, + in_h, in_w, n, c, out_h, out_w, align_corners, + align_mode); + } else if ("nearest" == interp_method) { + NearestNeighborInterpolateGrad(output_grad, input_grad, ratio_h, ratio_w, + n, c, out_h, out_w, align_corners); + } +} - bool align_corners = ctx.Attr("align_corners"); - int align_mode = ctx.Attr("align_mode"); +template +static void Interpolate3DCPUBwd(const framework::ExecutionContext& ctx, + Tensor* input_grad, const Tensor output_grad) { + auto* input = ctx.Input("X"); + const int n = input->dims()[0]; + const int c = input->dims()[1]; + const int in_d = input->dims()[2]; + const int in_h = input->dims()[3]; + const int in_w = input->dims()[4]; + + auto interp_method = ctx.Attr("interp_method"); + bool align_corners = ctx.Attr("align_corners"); + int align_mode = ctx.Attr("align_mode"); + + int out_d = ctx.Attr("out_d"); + int out_h = ctx.Attr("out_h"); + int out_w = ctx.Attr("out_w"); + float scale = ctx.Attr("scale"); + if (scale > 0) { + out_d = static_cast(in_d * scale); + out_h = static_cast(in_h * scale); + out_w = static_cast(in_w * scale); + } - input_grad->mutable_data({n, c, in_h, in_w}, ctx.GetPlace()); - auto& device_ctx = - ctx.template device_context(); - math::SetConstant zero; - zero(device_ctx, input_grad, static_cast(0.0)); + auto out_size = ctx.Input("OutSize"); + if (out_size != nullptr) { + auto out_size_data = out_size->data(); + out_d = out_size_data[0]; + out_h = out_size_data[1]; + out_w = out_size_data[2]; + } - if (in_h == out_h && in_w == out_w) { - framework::TensorCopy(*output_grad, ctx.GetPlace(), input_grad); - return; - } + input_grad->mutable_data({n, c, in_d, in_h, in_w}, ctx.GetPlace()); + auto& device_ctx = ctx.template device_context(); + math::SetConstant zero; + zero(device_ctx, input_grad, static_cast(0.0)); + + if (in_d == out_d && in_h == out_h && in_w == out_w) { + framework::TensorCopy(output_grad, ctx.GetPlace(), input_grad); + return; + } - float ratio_h = 0.f; - float ratio_w = 0.f; + float ratio_d = 0.f; + float ratio_h = 0.f; + float ratio_w = 0.f; + if (out_d > 1) { + ratio_d = (align_corners) ? static_cast(in_d - 1) / (out_d - 1) + : static_cast(in_d) / out_d; + } + if (out_h > 1) { + ratio_h = (align_corners) ? static_cast(in_h - 1) / (out_h - 1) + : static_cast(in_h) / out_h; + } + if (out_w > 1) { + ratio_w = (align_corners) ? static_cast(in_w - 1) / (out_w - 1) + : static_cast(in_w) / out_w; + } - if (out_h > 1) { - ratio_h = (align_corners) ? static_cast(in_h - 1) / (out_h - 1) - : static_cast(in_h) / out_h; - } - if (out_w > 1) { - ratio_w = (align_corners) ? static_cast(in_w - 1) / (out_w - 1) - : static_cast(in_w) / out_w; + if ("trilinear" == interp_method) { + TrilinearInterpolationGrad(output_grad, input_grad, ratio_d, ratio_h, + ratio_w, in_d, in_h, in_w, n, c, out_d, out_h, + out_w, align_corners, align_mode); + } +} + +template +class InterpolateKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("X"); + auto* output = ctx.Output("Out"); + + auto input_dims = input->dims(); + if (input_dims.size() == 4) { // 2D interpolation + Interpolate2DCPUFwd(ctx, *input, output); + } else if (input_dims.size() == 5) { // 3D interpolation + Interpolate3DCPUFwd(ctx, *input, output); } + } +}; + +template +class InterpolateGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input_grad = ctx.Output(framework::GradVarName("X")); + auto* output_grad = ctx.Input(framework::GradVarName("Out")); - if ("bilinear" == interp_method) { - BilinearInterpolationGrad(*output_grad, input_grad, ratio_h, ratio_w, - in_h, in_w, n, c, out_h, out_w, - align_corners, align_mode); - } else if ("nearest" == interp_method) { - NearestNeighborInterpolateGrad(*output_grad, input_grad, ratio_h, - ratio_w, n, c, out_h, out_w, - align_corners); + auto output_grad_dims = output_grad->dims(); + if (output_grad_dims.size() == 4) { // 2D interpolation grad + Interpolate2DCPUBwd(ctx, input_grad, *output_grad); + } else if (output_grad_dims.size() == 5) { // 3D interpolation grad + Interpolate3DCPUBwd(ctx, input_grad, *output_grad); } } }; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 21427e6ca634c8ebe09747b431e199c57a00df7d..babf5fd64e720a66334ae3953e0d248a1e7c8bf8 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -119,6 +119,7 @@ __all__ = [ 'image_resize', 'image_resize_short', 'resize_bilinear', + 'resize_trilinear', 'resize_nearest', 'gather', 'scatter', @@ -7672,13 +7673,16 @@ def image_resize(input, """ **Resize a Batch of Images** - The input must be a tensor of the shape (num_batches, channels, in_h, in_w), - and the resizing only applies on the last two dimensions(hight and width). + The input must be a tensor of the shape (num_batches, channels, in_h, in_w) + or (num_batches, channels, in_d, in_h, in_w), and the resizing only applies + on the last two/three dimensions(depth, hight and width). Supporting resample methods: 'BILINEAR' : Bilinear interpolation + 'TRILINEAR' : Trilinear interpolation + 'NEAREST' : Nearest neighbor interpolation Nearest neighbor interpolation is to perform nearest neighbor interpolation @@ -7691,6 +7695,11 @@ def image_resize(input, to perform linear interpolation first in one direction, and then again in the other direction. + Trilinear interpolation is an extension of linear interpolation for + interpolating functions of three variables (e.g. D-direction, + H-direction and W-direction in this op) on a rectilinear 3D grid. + The linear interpolation is performed on three directions. + Align_corners and align_mode are optinal parameters,the calculation method of interpolation can be selected by them. @@ -7748,30 +7757,58 @@ def image_resize(input, H_out = H_{in} * scale_{factor} W_out = W_{in} * scale_{factor} + Trilinear interpolation: + + if: + align_corners = False , align_mode = 0 + + input : (N,C,D_in,H_in,W_in) + output: (N,C,D_out,H_out,W_out) where: + + D_out = (D_{in}+0.5) * scale_{factor} - 0.5 + H_out = (H_{in}+0.5) * scale_{factor} - 0.5 + W_out = (W_{in}+0.5) * scale_{factor} - 0.5 + + + else: + + input : (N,C,D_in,H_in,W_in) + output: (N,C,D_out,H_out,W_out) where: + + D_out = D_{in} * scale_{factor} + H_out = H_{in} * scale_{factor} + W_out = W_{in} * scale_{factor} + For details of nearest neighbor interpolation, please refer to Wikipedia: https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation. For details of bilinear interpolation, please refer to Wikipedia: https://en.wikipedia.org/wiki/Bilinear_interpolation. + For details of trilinear interpolation, please refer to Wikipedia: + https://en.wikipedia.org/wiki/Trilinear_interpolation. + Args: input (Variable): The input tensor of image resize layer, This is a 4-D tensor of the shape - (num_batches, channels, in_h, in_w). + (num_batches, channels, in_h, in_w) or a + 5-D tensor of the shape + (num_batches, channls, in_d, in_h, in_w). out_shape(list|tuple|Variable|None): Output shape of image resize - layer, the shape is (out_h, out_w). - Default: None + layer, the shape is (out_h, out_w) when + input is a 4-D tensor and is + (out_d, out_h, out_w) when input is a + 5-D tensor. Default: None scale(float|None): The multiplier for the input height or width. At least one of :attr:`out_shape` or :attr:`scale` must be set. And :attr:`out_shape` has a higher priority than :attr:`scale`. Default: None. name(str|None): A name for this layer(optional). If set None, the layer will be named automatically. - resample(str): The resample method. It supports 'BILINEAR' and 'NEAREST' - currently. - Default: 'BILINEAR' + resample(str): The resample method. It supports 'BILINEAR', 'TRILINEAR' + and 'NEAREST' currently. Default: 'BILINEAR' actual_shape(Variable): An optional input to specify output shape dynamically. If provided, image resize according to this given shape rather than @@ -7795,15 +7832,19 @@ def image_resize(input, Returns: Variable: The output is a 4-D tensor of the shape - (num_batches, channls, out_h, out_w). + (num_batches, channls, out_h, out_w) or a 5-D tensor of the shape + (num_batches, channels, out_d, out_h, out_w). Raises: TypeError: out_shape should be a list or tuple or Variable. TypeError: actual_shape should either be Variable or None. - ValueError: The 'resample' of image_resize can only be 'BILINEAR' - or 'NEAREST' currently. + ValueError: The 'resample' of image_resize can only be 'BILINEAR', + 'TRILINEAR' or 'NEAREST' currently. + ValueError: 'BILINEAR' and 'NEAREST' only support 4-D tensor. + ValueError: 'TRILINEAR' only support 5-D tensor. ValueError: One of out_shape and scale must not be None. - ValueError: out_shape length should be 2. + ValueError: out_shape length should be 2 for input 4-D tensor. + ValueError: out_shape length should be 3 for input 5-D tensor. ValueError: scale should be greater than zero. TypeError: align_corners shoule be a bool value ValueError: align_mode can only be '0' or '1' @@ -7817,14 +7858,20 @@ def image_resize(input, """ resample_methods = { 'BILINEAR': 'bilinear', + 'TRILINEAR': 'trilinear', 'NEAREST': 'nearest', } if resample not in resample_methods: raise ValueError( - "The 'resample' of image_resize can only be 'BILINEAR' or 'NEAREST' currently." - ) + "The 'resample' of image_resize can only be 'BILINEAR', 'TRILINEAR' " + "or 'NEAREST' currently.") resample_type = resample_methods[resample] + if resample in ['BILINEAR', 'NEAREST'] and len(input.shape) != 4: + raise ValueError("'BILINEAR' and 'NEAREST' only support 4-D tensor.") + if resample == 'TRILINEAR' and len(input.shape) != 5: + raise ValueError("'TRILINEAR'only support 5-D tensor.") + if not isinstance(align_corners, bool): raise TypeError("Attr align_corners should be a bool value") if align_mode != 0 and align_mode != 1: @@ -7840,6 +7887,7 @@ def image_resize(input, inputs = {"X": input} attrs = { + "out_d": 0, "out_h": 0, "out_w": 0, "interp_method": resample_type, @@ -7857,12 +7905,21 @@ def image_resize(input, if not (_is_list_or_turple_(out_shape)): raise TypeError( "out_shape should be a list or tuple or Variable.") - if len(out_shape) != 2: - raise ValueError("out_shape length should be 2.") - - out_shape = list(map(int, out_shape)) - attrs['out_h'] = out_shape[0] - attrs['out_w'] = out_shape[1] + if len(input.shape) == 4: + if len(out_shape) != 2: + raise ValueError("out_shape length should be 2 for " + "input 4-D tensor.") + out_shape = list(map(int, out_shape)) + attrs['out_h'] = out_shape[0] + attrs['out_w'] = out_shape[1] + if len(input.shape) == 5: + if len(out_shape) != 3: + raise ValueError("out_shape length should be 3 for " + "input 5-D tensor.") + out_shape = list(map(int, out_shape)) + attrs['out_d'] = out_shape[0] + attrs['out_h'] = out_shape[1] + attrs['out_w'] = out_shape[2] else: if scale <= 0: @@ -7945,7 +8002,7 @@ def resize_bilinear(input, Args: - input(${x_type}): ${x_comment}. + input(${x_type}): input should be a 4-D tensor. out_shape(list|tuple|Variable|None): Output shape of resize bilinear layer, the shape is (out_h, out_w). @@ -7974,7 +8031,7 @@ def resize_bilinear(input, align_mode(bool): ${align_mode_comment} Returns: - ${out_comment}. + A 4-D tensor in shape of (num_batches, channels, out_h, out_w) Examples: .. code-block:: python @@ -7988,6 +8045,112 @@ def resize_bilinear(input, align_corners, align_mode) +@templatedoc(op_type="trilinear_interp") +def resize_trilinear(input, + out_shape=None, + scale=None, + name=None, + actual_shape=None, + align_corners=True, + align_mode=1): + """ + Resize input by performing trilinear interpolation based on given + output shape which specified by actual_shape, out_shape and scale + in priority order. + + Trilinear interpolation is an extension of linear interpolation for + interpolating functions of three variables (e.g. D-direction, + H-direction and W-direction in this op) on a rectilinear 3D grid. + The linear interpolation is performed on three directions. + + For details of trilinear interpolation, please refer to Wikipedia: + https://en.wikipedia.org/wiki/Trilinear_interpolation + + Align_corners and align_mode are optinal parameters,the calculation + method of interpolation can be selected by them. + + Example: + + .. code-block:: text + + For scale: + + if align_corners = True && out_size > 1 : + + scale_factor = (in_size-1.0)/(out_size-1.0) + + else: + + scale_factor = float(in_size/out_size) + + Bilinear interpolation: + + if: + align_corners = False , align_mode = 0 + + input : (N,C,D_in,H_in,W_in) + output: (N,C,D_out,H_out,W_out) where: + + D_out = (D_{in}+0.5) * scale_{factor} - 0.5 + H_out = (H_{in}+0.5) * scale_{factor} - 0.5 + W_out = (W_{in}+0.5) * scale_{factor} - 0.5 + + + else: + + input : (N,C,D_in,H_in,W_in) + output: (N,C,D_out,H_out,W_out) where: + + D_out = D_{in} * scale_{factor} + H_out = H_{in} * scale_{factor} + W_out = W_{in} * scale_{factor} + + + + Args: + input(${x_type}): input should be a 4-D tensor. + + out_shape(list|tuple|Variable|None): Output shape of resize bilinear + layer, the shape is (out_d, out_h, out_w). + Default: None + + scale(float|None): The multiplier for the input depth, height or width. + At least one of :attr:`out_shape` or :attr:`scale` must be set. + And :attr:`out_shape` has a higher priority than :attr:`scale`. + Default: None. + + name(str|None): The output variable name. + actual_shape(Variable): An optional input to specify output shape + dynamically. If provided, image resize + according to this given shape rather than + :attr:`out_shape` and :attr:`scale` specifying + shape. That is to say actual_shape has the + highest priority. It is recommended to use + actual_shape instead of :attr:`out_shape` if you + want to specify output shape dynamically. When + using actual_shape to specify output shape, one of + :attr:`out_shape` and :attr:`scale` should also be + set, otherwise errors would be occured in graph + constructing stage. + Default: None + align_corners(bool): ${align_corners_comment} + align_mode(bool): ${align_mode_comment} + + Returns: + A 5-D tensor in shape (num_batches, channels, out_d, out_h, out_w) + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + input = fluid.layers.data(name="input", shape=[3,6,9,11], dtype="float32") + out = fluid.layers.resize_trilinear(input, out_shape=[12, 12, 12]) + """ + + return image_resize(input, out_shape, scale, name, 'TRILINEAR', + actual_shape, align_corners, align_mode) + + @templatedoc(op_type="nearest_interp") def resize_nearest(input, out_shape=None, @@ -8041,7 +8204,7 @@ def resize_nearest(input, https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation Args: - input(${x_type}): ${x_comment}. + input(${x_type}): input should be a 4-D tensor. out_shape(list|tuple|Variable|None): Output shape of resize nearest layer, the shape is (out_h, out_w). @@ -8069,7 +8232,7 @@ def resize_nearest(input, align_corners(bool): ${align_corners_comment} Returns: - ${out_comment}. + A 4-D tensor in shape of (num_batches, channels, out_h, out_w) Examples: .. code-block:: python diff --git a/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py b/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py index 7e577229777e256b15a02232229f4127b0f877f5..199a446a11a64fe1627ec5a80e340bd6073a0a30 100644 --- a/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py +++ b/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py @@ -205,6 +205,17 @@ class TestBilinearInterpCase6(TestBilinearInterpOp): self.align_mode = 1 +class TestBilinearInterpSame(TestBilinearInterpOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [2, 3, 128, 64] + self.out_h = 128 + self.out_w = 64 + self.scale = 0. + self.align_corners = True + self.align_mode = 1 + + class TestBilinearInterpActualShape(TestBilinearInterpOp): def init_test_case(self): self.interp_method = 'bilinear' diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index b071ce0a757cd70f7b83d379c463c01c6d6047d0..a4e51d6cfea1c3dd20516f4f9f1d76ff6492f91c 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -1295,16 +1295,74 @@ class TestBook(LayerTest): x = self._get_data(name='x', shape=[3, 9, 6], dtype="float32") output = layers.resize_bilinear(x, out_shape=[12, 12]) return (output) - output = layers.resize_bilinear(x, scale=3) + + def make_resize_bilinear_by_scale(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + x = self._get_data(name='x', shape=[3, 9, 6], dtype="float32") + output = layers.resize_bilinear(x, scale=1.5) return (output) def make_resize_nearest(self): + try: + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + x = self._get_data(name='x1', shape=[3, 9, 6], dtype="float32") + output = layers.resize_nearest(x, out_shape=[12, 12]) + except ValueError: + pass + + try: + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + x = self._get_data( + name='x2', shape=[3, 9, 6, 7], dtype="float32") + output = layers.resize_nearest(x, out_shape=[12, 12, 12]) + except ValueError: + pass + with program_guard(fluid.default_main_program(), fluid.default_startup_program()): x = self._get_data(name='x', shape=[3, 9, 6], dtype="float32") output = layers.resize_nearest(x, out_shape=[12, 12]) return (output) - output = layers.resize_nearest(x, scale=3) + + def make_resize_nearest_by_scale(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + x = self._get_data(name='x1', shape=[3, 9, 6], dtype="float32") + output = layers.resize_nearest(x, scale=1.8) + return (output) + + def make_resize_trilinear(self): + try: + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + x = self._get_data(name='x2', shape=[3, 9, 6], dtype="float32") + output = layers.resize_trilinear(x, out_shape=[12, 12, 12]) + except ValueError: + pass + + try: + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + x = self._get_data( + name='x', shape=[3, 9, 6, 7], dtype="float32") + output = layers.resize_trilinear(x, out_shape=[12, 12]) + except ValueError: + pass + + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + x = self._get_data(name='x', shape=[3, 9, 6, 7], dtype="float32") + output = layers.resize_trilinear(x, out_shape=[12, 12, 12]) + return (output) + + def make_resize_trilinear_by_scale(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + x = self._get_data(name='x', shape=[3, 9, 6, 7], dtype="float32") + output = layers.resize_trilinear(x, scale=2.1) return (output) def make_polygon_box_transform(self): diff --git a/python/paddle/fluid/tests/unittests/test_nearest_interp_op.py b/python/paddle/fluid/tests/unittests/test_nearest_interp_op.py index 1feb2aefda4d18255db13f657a79f0bd05d1b0a3..163293621f9f64e3290ff964e068b63603b91c42 100644 --- a/python/paddle/fluid/tests/unittests/test_nearest_interp_op.py +++ b/python/paddle/fluid/tests/unittests/test_nearest_interp_op.py @@ -176,6 +176,16 @@ class TestNearestNeighborInterpCase6(TestNearestInterpOp): self.align_corners = True +class TestNearestNeighborInterpSame(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [2, 3, 128, 64] + self.out_h = 128 + self.out_w = 64 + self.scale = 0. + self.align_corners = True + + class TestNearestNeighborInterpActualShape(TestNearestInterpOp): def init_test_case(self): self.interp_method = 'nearest' diff --git a/python/paddle/fluid/tests/unittests/test_trilinear_interp_op.py b/python/paddle/fluid/tests/unittests/test_trilinear_interp_op.py new file mode 100644 index 0000000000000000000000000000000000000000..1d712e8485aa9a048ca75f94fe48cd5652adc102 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_trilinear_interp_op.py @@ -0,0 +1,428 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +import paddle.fluid.core as core + + +def trilinear_interp_np(input, + out_d, + out_h, + out_w, + out_size=None, + actual_shape=None, + align_corners=True, + align_mode=0): + """trilinear interpolation implement in shape [N, C, D, H, W]""" + if out_size is not None: + out_d = out_size[0] + out_h = out_size[1] + out_w = out_size[2] + if actual_shape is not None: + out_d = actual_shape[0] + out_h = actual_shape[1] + out_w = actual_shape[2] + batch_size, channel, in_d, in_h, in_w = input.shape + + ratio_d = ratio_h = ratio_w = 0.0 + if out_d > 1: + if (align_corners): + ratio_d = (in_d - 1.0) / (out_d - 1.0) + else: + ratio_d = 1.0 * in_d / out_d + if out_h > 1: + if (align_corners): + ratio_h = (in_h - 1.0) / (out_h - 1.0) + else: + ratio_h = 1.0 * in_h / out_h + if out_w > 1: + if (align_corners): + ratio_w = (in_w - 1.0) / (out_w - 1.0) + else: + ratio_w = 1.0 * in_w / out_w + + out = np.zeros((batch_size, channel, out_d, out_h, out_w)) + + for i in range(out_d): + if (align_mode == 0 and not align_corners): + d = int(ratio_d * (i + 0.5) - 0.5) + else: + d = int(ratio_d * i) + + d = max(0, d) + did = 1 if d < in_d - 1 else 0 + if (align_mode == 0 and not align_corners): + idx_src_d = max(ratio_d * (i + 0.5) - 0.5, 0) + d1lambda = idx_src_d - d + else: + d1lambda = ratio_d * i - d + d2lambda = 1.0 - d1lambda + + for j in range(out_h): + if (align_mode == 0 and not align_corners): + h = int(ratio_h * (j + 0.5) - 0.5) + else: + h = int(ratio_h * j) + + h = max(0, h) + hid = 1 if h < in_h - 1 else 0 + if (align_mode == 0 and not align_corners): + idx_src_h = max(ratio_h * (j + 0.5) - 0.5, 0) + h1lambda = idx_src_h - h + else: + h1lambda = ratio_h * j - h + h2lambda = 1.0 - h1lambda + + for k in range(out_w): + if (align_mode == 0 and not align_corners): + w = int(ratio_w * (k + 0.5) - 0.5) + else: + w = int(ratio_w * k) + w = max(0, w) + wid = 1 if w < in_w - 1 else 0 + if (align_mode == 0 and not align_corners): + idx_src_w = max(ratio_w * (k + 0.5) - 0.5, 0) + w1lambda = idx_src_w - w + else: + w1lambda = ratio_w * k - w + w2lambda = 1.0 - w1lambda + + out[:, :, i, j, k] = \ + d2lambda * \ + (h2lambda * (w2lambda * input[:, :, d, h, w] + \ + w1lambda * input[:, :, d, h, w+wid]) + \ + h1lambda * (w2lambda * input[:, :, d, h+hid, w] + \ + w1lambda * input[:, :, d, h+hid, w+wid])) + \ + d1lambda * \ + (h2lambda * (w2lambda * input[:, :, d+did, h, w] + \ + w1lambda * input[:, :, d+did, h, w+wid]) + \ + h1lambda * (w2lambda * input[:, :, d+did, h+hid, w] + \ + w1lambda * input[:, :, d+did, h+hid, w+wid])) + return out.astype(input.dtype) + + +class TestTrilinearInterpOp(OpTest): + def setUp(self): + self.out_size = None + self.actual_shape = None + self.init_test_case() + self.op_type = "trilinear_interp" + input_np = np.random.random(self.input_shape).astype("float32") + + if self.scale > 0: + out_d = int(self.input_shape[2] * self.scale) + out_h = int(self.input_shape[3] * self.scale) + out_w = int(self.input_shape[4] * self.scale) + else: + out_d = self.out_d + out_h = self.out_h + out_w = self.out_w + + output_np = trilinear_interp_np(input_np, out_d, out_h, out_w, + self.out_size, self.actual_shape, + self.align_corners, self.align_mode) + self.inputs = {'X': input_np} + if self.out_size is not None: + self.inputs['OutSize'] = self.out_size + if self.actual_shape is not None: + self.inputs['OutSize'] = self.actual_shape + + self.attrs = { + 'out_d': self.out_d, + 'out_h': self.out_h, + 'out_w': self.out_w, + 'scale': self.scale, + 'interp_method': self.interp_method, + 'align_corners': self.align_corners, + 'align_mode': self.align_mode + } + self.outputs = {'Out': output_np} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out', in_place=True) + + def init_test_case(self): + self.interp_method = 'trilinear' + self.input_shape = [2, 3, 4, 4, 4] + self.out_d = 2 + self.out_h = 2 + self.out_w = 2 + self.scale = 0. + self.out_size = np.array([3, 3, 3]).astype("int32") + self.align_corners = True + self.align_mode = 1 + + +class TestTrilinearInterpCase1(TestTrilinearInterpOp): + def init_test_case(self): + self.interp_method = 'trilinear' + self.input_shape = [2, 1, 7, 8, 9] + self.out_d = 1 + self.out_h = 1 + self.out_w = 1 + self.scale = 0. + self.align_corners = True + self.align_mode = 1 + + +class TestTrilinearInterpCase2(TestTrilinearInterpOp): + def init_test_case(self): + self.interp_method = 'trilinear' + self.input_shape = [2, 3, 9, 6, 8] + self.out_d = 12 + self.out_h = 12 + self.out_w = 12 + self.scale = 0. + self.align_corners = True + self.align_mode = 1 + + +class TestTrilinearInterpCase3(TestTrilinearInterpOp): + def init_test_case(self): + self.interp_method = 'trilinear' + self.input_shape = [3, 2, 16, 8, 4] + self.out_d = 32 + self.out_h = 16 + self.out_w = 8 + self.scale = 0. + self.align_corners = True + self.align_mode = 1 + + +class TestTrilinearInterpCase4(TestTrilinearInterpOp): + def init_test_case(self): + self.interp_method = 'trilinear' + self.input_shape = [4, 1, 7, 8, 9] + self.out_d = 1 + self.out_h = 1 + self.out_w = 1 + self.scale = 0. + self.out_size = np.array([2, 2, 2]).astype("int32") + self.align_corners = True + self.align_mode = 1 + + +class TestTrilinearInterpCase5(TestTrilinearInterpOp): + def init_test_case(self): + self.interp_method = 'trilinear' + self.input_shape = [3, 3, 9, 6, 8] + self.out_d = 12 + self.out_h = 12 + self.out_w = 12 + self.scale = 0. + self.out_size = np.array([11, 11, 11]).astype("int32") + self.align_corners = True + self.align_mode = 1 + + +class TestTrilinearInterpCase6(TestTrilinearInterpOp): + def init_test_case(self): + self.interp_method = 'trilinear' + self.input_shape = [1, 1, 16, 8, 4] + self.out_d = 8 + self.out_h = 32 + self.out_w = 16 + self.scale = 0. + self.out_size = np.array([17, 9, 5]).astype("int32") + self.align_corners = True + self.align_mode = 1 + + +class TestTrilinearInterpSame(TestTrilinearInterpOp): + def init_test_case(self): + self.interp_method = 'trilinear' + self.input_shape = [1, 1, 16, 8, 4] + self.out_d = 16 + self.out_h = 8 + self.out_w = 4 + self.scale = 0. + self.align_corners = True + self.align_mode = 1 + + +class TestTrilinearInterpSameHW(TestTrilinearInterpOp): + def init_test_case(self): + self.interp_method = 'trilinear' + self.input_shape = [1, 1, 16, 8, 4] + self.out_d = 8 + self.out_h = 8 + self.out_w = 4 + self.scale = 0. + self.align_corners = True + self.align_mode = 1 + + +class TestTrilinearInterpActualShape(TestTrilinearInterpOp): + def init_test_case(self): + self.interp_method = 'trilinear' + self.input_shape = [3, 2, 16, 8, 4] + self.out_d = 64 + self.out_h = 32 + self.out_w = 16 + self.scale = 0. + self.out_size = np.array([33, 19, 7]).astype("int32") + self.align_corners = True + self.align_mode = 1 + + +class TestTrilinearInterpOpUint8(OpTest): + def setUp(self): + self.out_size = None + self.actual_shape = None + self.init_test_case() + self.op_type = "trilinear_interp" + input_np = np.random.randint( + low=0, high=256, size=self.input_shape).astype("uint8") + + if self.scale > 0: + out_d = int(self.input_shape[2] * self.scale) + out_h = int(self.input_shape[3] * self.scale) + out_w = int(self.input_shape[4] * self.scale) + else: + out_d = self.out_d + out_h = self.out_h + out_w = self.out_w + + output_np = trilinear_interp_np(input_np, out_d, out_h, out_w, + self.out_size, self.actual_shape, + self.align_corners, self.align_mode) + self.inputs = {'X': input_np} + if self.out_size is not None: + self.inputs['OutSize'] = self.out_size + + self.attrs = { + 'out_d': self.out_d, + 'out_h': self.out_h, + 'out_w': self.out_w, + 'scale': self.scale, + 'interp_method': self.interp_method, + 'align_corners': self.align_corners, + 'align_mode': self.align_mode + } + self.outputs = {'Out': output_np} + + def test_check_output(self): + self.check_output_with_place(place=core.CPUPlace(), atol=1) + + def init_test_case(self): + self.interp_method = 'trilinear' + self.input_shape = [1, 3, 9, 6, 8] + self.out_d = 13 + self.out_h = 10 + self.out_w = 9 + self.scale = 0. + self.align_corners = True + self.align_mode = 1 + + +class TestTrilinearInterpCase1Uint8(TestTrilinearInterpOpUint8): + def init_test_case(self): + self.interp_method = 'trilinear' + self.input_shape = [2, 3, 16, 8, 4] + self.out_d = 13 + self.out_h = 7 + self.out_w = 2 + self.scale = 0. + self.align_corners = True + self.align_mode = 1 + + +class TestTrilinearInterpCase2Uint8(TestTrilinearInterpOpUint8): + def init_test_case(self): + self.interp_method = 'trilinear' + self.input_shape = [4, 1, 7, 8, 9] + self.out_d = 3 + self.out_h = 5 + self.out_w = 13 + self.scale = 0. + self.out_size = np.array([6, 15, 21]).astype("int32") + self.align_corners = True + self.align_mode = 1 + + +class TestTrilinearInterpOtherMethod1(TestTrilinearInterpOp): + def set_align_mode(self): + self.align_corners = False + self.align_mode = 1 + + +class TestTrilinearInterpWithMethod2(TestTrilinearInterpOp): + def set_align_mode(self): + self.align_corners = False + self.align_mode = 0 + + +class TestTrilinearInterpWithMethod3(TestTrilinearInterpOp): + def set_align_mode(self): + self.align_corners = True + self.align_mode = 0 + + +class TestTrilinearInterpScale1(TestTrilinearInterpOp): + def init_test_case(self): + self.interp_method = 'trilinear' + self.input_shape = [2, 3, 5, 7, 9] + self.out_d = 82 + self.out_h = 60 + self.out_w = 25 + self.scale = 2. + self.align_corners = True + self.align_mode = 1 + + +class TestTrilinearInterpScale2(TestTrilinearInterpOp): + def init_test_case(self): + self.interp_method = 'trilinear' + self.input_shape = [2, 3, 5, 7, 9] + self.out_d = 82 + self.out_h = 60 + self.out_w = 25 + self.scale = 1. + self.align_corners = True + self.align_mode = 1 + + +class TestTrilinearInterpScale3(TestTrilinearInterpOp): + def init_test_case(self): + self.interp_method = 'trilinear' + self.input_shape = [2, 3, 5, 7, 9] + self.out_d = 82 + self.out_h = 60 + self.out_w = 25 + self.scale = 1.5 + self.align_corners = True + self.align_mode = 1 + + +class TestTrilinearInterpZero(TestTrilinearInterpOp): + def init_test_case(self): + self.interp_method = 'trilinear' + self.input_shape = [2, 3, 5, 7, 11] + self.out_d = 82 + self.out_h = 60 + self.out_w = 25 + self.scale = 0.2 + self.align_corners = False + self.align_mode = 0 + + +if __name__ == "__main__": + unittest.main()