diff --git a/paddle/fluid/operators/bilinear_interp_op.cc b/paddle/fluid/operators/bilinear_interp_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..896ef7bed258b25e39bb6cb2f6c0d35da0eef22e --- /dev/null +++ b/paddle/fluid/operators/bilinear_interp_op.cc @@ -0,0 +1,86 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/fluid/operators/bilinear_interp_op.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +class BilinearInterpOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of BilinearInterOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of BilinearInterOp should not be null."); + + auto dim_x = ctx->GetInputDim("Input"); // NCHW format + int out_h = ctx->Attrs().Get("out_h"); + int out_w = ctx->Attrs().Get("out_w"); + PADDLE_ENFORCE_EQ(dim_x.size(), 4, "X's dimension must be 4"); + + std::vector dim_out({dim_x[0], dim_x[1], out_h, out_w}); + ctx->SetOutputDim("Output", framework::make_ddim(dim_out)); + } +}; + +class BilinearInterpOpMaker : public framework::OpProtoAndCheckerMaker { + public: + BilinearInterpOpMaker(OpProto* proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "The input tensor of bilinear interpolation, 4-D with NCHW shape"); + AddOutput("Out", "The output tensor with the same shape as X"); + AddAttr("out_h", "output height of bilinear interpolation op."); + AddAttr("out_w", "output weight of bilinear interpolation op."); + AddComment(R"DOC( + Bilinear interpolation is an extension of linear interpolation for + interpolating functions of two variables (e.g. H-direction and W-direction + in this op) on a rectilinear 2D grid. + + The key idea is to perform linear interpolation first in one direction, + and then again in the other direction. + + For details, please refer to Wikipedia: + https://en.wikipedia.org/wiki/Bilinear_interpolation + )DOC"); + } +}; + +class BilinearInterpOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + auto dim_x = ctx->GetInputDim("X"); + if (ctx->HasOutput(framework::GradVarName("X"))) { + ctx->SetOutputDim(framework::GradVarName("X"), dim_x); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(bilinear_interp, ops::BilinearInterpOp, ops::BilinearInterpOpMaker, + bilinear_interp_grad, ops::BilinearInterpOpGrad); +REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::BilinearInterpKernel); +REGISTER_OP_CPU_KERNEL(bilinear_interp_grad, ops::BilinearInterpKernel); diff --git a/paddle/fluid/operators/bilinear_interp_op.h b/paddle/fluid/operators/bilinear_interp_op.h new file mode 100644 index 0000000000000000000000000000000000000000..9571d8699c76fef145cd360bdf9a7b4e1f34bea7 --- /dev/null +++ b/paddle/fluid/operators/bilinear_interp_op.h @@ -0,0 +1,141 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +template +using EigenVector = framework::EigenVector; + +template +class BilinearInterpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input_t = ctx.Input("X"); // float tensor + auto* output_t = ctx.Output("Out"); // float tensor + auto* input = input_t->data(); + auto* output = output_t->mutable_data(ctx.GetPlace()); + + int out_h = ctx.Attr("out_h"); + int out_w = ctx.Attr("out_w"); + int batch_size = input_t->dims()[0]; + int channels = input_t->dims()[1]; + int in_h = input_t->dims()[2]; + int in_w = input_t->dims()[3]; + + int in_hw = in_h * in_w; + int out_hw = out_h * out_w; + int in_chw = channels * in_hw; + int out_chw = channels * out_hw; + + T ratio_h = (out_h > 1) ? static_cast(in_h - 1) / (out_h - 1) : 0.f; + T ratio_w = (out_w > 1) ? static_cast(in_w - 1) / (out_w - 1) : 0.f; + + if (in_h == out_h && in_w == out_w) { + memcpy(output, input, product(input_t->dims()) * sizeof(T)); + } else { + for (int k = 0; k < batch_size; ++k) { // loop for batches + for (int i = 0; i < out_h; ++i) { // loop for images + int h = ratio_h * i; + int hid = (h < in_h - 1) ? 1 : 0; + T h1lambda = ratio_h * i - h; + T h2lambda = 1 - h1lambda; + + for (int j = 0; j < out_w; ++j) { + int w = ratio_w * j; + int wid = (w < in_w - 1) ? 1 : 0; + T w1lambda = ratio_w * j - w; + T w2lambda = 1 - w1lambda; + // calculate four position for bilinear interpolation + const T* in_pos = &input[k * in_chw + h * in_w + w]; + T* out_pos = &output[k * out_chw + i * out_w + j]; + + for (int c = 0; c < channels; ++c) { // loop for channels + // bilinear interpolation + out_pos[0] = + h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[wid]) + + h1lambda * (w2lambda * in_pos[hid * in_w] + + w1lambda * in_pos[hid * in_w + wid]); + in_pos += in_hw; + out_pos += out_hw; + } + } + } + } + } + } +}; + +template +class BilinearInterpGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* d_input_t = ctx.Output(framework::GradVarName("X")); + auto* d_output_t = ctx.Input(framework::GradVarName("Out")); + auto* d_input = d_input_t->mutable_data(ctx.GetPlace()); + auto* d_output = d_output_t->data(); + + int out_h = ctx.Attr("out_h"); + int out_w = ctx.Attr("out_w"); + int batch_size = d_input_t->dims()[0]; + int channels = d_input_t->dims()[1]; + int in_h = d_input_t->dims()[2]; + int in_w = d_input_t->dims()[3]; + + int in_hw = in_h * in_w; + int out_hw = out_h * out_w; + int in_chw = channels * in_hw; + int out_chw = channels * out_hw; + + T ratio_h = (out_h > 1) ? static_cast(in_h - 1) / (out_h - 1) : 0.f; + T ratio_w = (out_w > 1) ? static_cast(in_w - 1) / (out_w - 1) : 0.f; + + if (in_h == out_h && in_w == out_w) { + memcpy(d_input, d_output, product(d_input_t->dims()) * sizeof(T)); + } else { + for (int k = 0; k < batch_size; ++k) { // loop for batches + for (int i = 0; i < out_h; ++i) { // loop for images + int h = ratio_h * i; + int hid = (h < in_h - 1) ? 1 : 0; + T h1lambda = ratio_h * i - h; + T h2lambda = 1 - h1lambda; + + for (int j = 0; j < out_w; ++j) { + int w = ratio_w * j; + int wid = (w < in_w - 1) ? 1 : 0; + T w1lambda = ratio_w * j - w; + T w2lambda = 1 - w1lambda; + T* in_pos = &d_input[k * in_chw + h * in_w + w]; + const T* out_pos = &d_output[k * out_chw + i * out_w + j]; + + for (int c = 0; c < channels; ++c) { // loop for channels + in_pos[0] = h2lambda * w2lambda * out_pos[0]; + in_pos[wid] = h2lambda * w1lambda * out_pos[0]; + in_pos[hid * in_w] = h1lambda * w2lambda * out_pos[0]; + in_pos[hid * in_w + wid] = h1lambda * w1lambda * out_pos[0]; + in_pos += in_hw; + out_pos += out_hw; + } + } + } + } + } + } +}; + +} // namespace operators +} // namespace paddle