/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "rotated_anchor_generator_op.h"

namespace paddle {
namespace operators {

/**
 * Writes one rotated anchor per (feature cell, aspect ratio, anchor size,
 * angle) combination.
 *
 * Launch: 1-D grid of any size. Each thread walks the box range with a
 * grid-stride loop, so correctness does not depend on the grid dimensions.
 *
 * Box `idx` in [0, height * width * ar_num * as_num * aa_num) decomposes as:
 *   row / col : idx / per_cell, laid out row-major (row = slow axis);
 *   within a cell: aspect ratio is the outermost index, anchor size the
 *   middle one and angle the innermost.
 * out[5 * idx + 0..4] receives {x_ctr, y_ctr, width, height, angle}.
 *
 * stride[0] is used as the horizontal and stride[1] as the vertical stride,
 * so `stride` must hold at least two values. `sd_num` is accepted but not
 * read.
 */
template <typename T>
__global__ void GenRAnchors(T* out,
                            const T* aspect_ratios,
                            const int ar_num,
                            const T* anchor_sizes,
                            const int as_num,
                            const T* angles,
                            const int aa_num,
                            const T* stride,
                            const int sd_num,
                            const int height,
                            const int width,
                            const T offset) {
  const int per_cell = as_num * ar_num * aa_num;
  const int total = height * width * per_cell;
  const int first = blockIdx.x * blockDim.x + threadIdx.x;

  for (int idx = first; idx < total; idx += blockDim.x * gridDim.x) {
    // Which feature-map cell this box belongs to.
    const int row = idx / (per_cell * width);
    const int col = (idx / per_cell) % width;

    const T sw = stride[0];
    const T sh = stride[1];
    // NOTE(review): the -1 is applied after scaling, i.e. not
    // offset * (stride - 1); confirm this matches the CPU kernel.
    const T cx = (col * sw) + offset * sw - 1;
    const T cy = (row * sh) + offset * sh - 1;

    // Which (ratio, size, angle) combination inside the cell.
    const int k = idx % per_cell;
    const T ratio = aspect_ratios[k / (as_num * aa_num)];
    const T anchor_len = anchor_sizes[k / aa_num % as_num];
    const T theta = angles[k % aa_num];

    // Base box covering the stride area at this ratio, rounded to whole
    // units, then rescaled so its side matches the requested anchor size.
    const T base_w = round(sqrt(sw * sh / ratio));
    const T base_h = round(base_w * ratio);
    const T box_w = (anchor_len / sw) * base_w;
    const T box_h = (anchor_len / sh) * base_h;

    T* box = out + idx * 5;
    box[0] = cx;
    box[1] = cy;
    box[2] = box_w;
    box[3] = box_h;
    box[4] = theta;
  }
}

/**
 * Fills out[0, num) by repeating var[0, vnum) cyclically:
 * out[p] = var[p % vnum].
 *
 * Launch: 1-D grid of any size (grid-stride loop). Requires vnum > 0.
 */
template <typename T>
__global__ void SetVariance(T* out,
                            const T* var,
                            const int vnum,
                            const int num) {
  const int first = blockIdx.x * blockDim.x + threadIdx.x;
  for (int pos = first; pos < num; pos += blockDim.x * gridDim.x) {
    const int slot = pos % vnum;
    out[pos] = var[slot];
  }
}

/**
 * CUDA implementation of the `rotated_anchor_generator` operator.
 *
 * With H = Input->dims()[2], W = Input->dims()[3] and
 * box_num = H * W * |aspect_ratios| * |anchor_sizes| * |angles|:
 *   Anchors   - box_num * 5 values written by GenRAnchors, 5 per box
 *               (x_ctr, y_ctr, width, height, angle).
 *   Variances - box_num * 5 values: the `variances` attribute repeated
 *               cyclically by SetVariance.
 * Both outputs are allocated with mutable_data<T>() on ctx.GetPlace(), and
 * both kernels are launched asynchronously on the device context's stream.
 */
template <typename T>
class RotatedAnchorGeneratorOpCUDAKernel : public framework::OpKernel<T> {
public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<paddle::framework::Tensor>("Input");
    auto* anchors = ctx.Output<paddle::framework::Tensor>("Anchors");
    auto* vars = ctx.Output<paddle::framework::Tensor>("Variances");

    auto anchor_sizes = ctx.Attr<std::vector<float>>("anchor_sizes");
    auto aspect_ratios = ctx.Attr<std::vector<float>>("aspect_ratios");
    auto angles = ctx.Attr<std::vector<float>>("angles");
    // NOTE(review): GenRAnchors reads stride[0] and stride[1], so this
    // assumes at least two values (presumably enforced by the op's shape
    // inference -- confirm). It is not checked here.
    auto stride = ctx.Attr<std::vector<float>>("stride");
    auto variances = ctx.Attr<std::vector<float>>("variances");

    T offset = static_cast<T>(ctx.Attr<float>("offset"));

    auto width = input->dims()[3];
    auto height = input->dims()[2];

    int num_anchors =
        aspect_ratios.size() * anchor_sizes.size() * angles.size();

    int box_num = width * height * num_anchors;

    anchors->mutable_data<T>(ctx.GetPlace());
    vars->mutable_data<T>(ctx.GetPlace());

    // Empty feature map or empty attribute list: there is nothing to write,
    // and a launch with a zero-sized grid is an invalid configuration.
    if (box_num <= 0) {
      return;
    }

    auto stream =
        ctx.template device_context<platform::CUDADeviceContext>().stream();
    const int block = 512;

    // The attributes are std::vector<float>, but the kernels read them as
    // const T*. Convert on the host first so every device tensor really
    // holds T. Uploading the float vectors and then asking for data<T>()
    // with T = double is a dtype mismatch. These host copies stay alive
    // until the end of Compute.
    auto to_t = [](const std::vector<float>& src) {
      return std::vector<T>(src.begin(), src.end());
    };
    const std::vector<T> ar_host = to_t(aspect_ratios);
    const std::vector<T> as_host = to_t(anchor_sizes);
    const std::vector<T> aa_host = to_t(angles);
    const std::vector<T> sd_host = to_t(stride);
    const std::vector<T> var_host = to_t(variances);

    framework::Tensor ar;
    framework::TensorFromVector(ar_host, ctx.device_context(), &ar);

    framework::Tensor as;
    framework::TensorFromVector(as_host, ctx.device_context(), &as);

    framework::Tensor aa;
    framework::TensorFromVector(aa_host, ctx.device_context(), &aa);

    framework::Tensor sd;
    framework::TensorFromVector(sd_host, ctx.device_context(), &sd);

    // One thread per box.
    int grid = (box_num + block - 1) / block;
    GenRAnchors<T><<<grid, block, 0, stream>>>(anchors->data<T>(),
                                               ar.data<T>(),
                                               aspect_ratios.size(),
                                               as.data<T>(),
                                               anchor_sizes.size(),
                                               aa.data<T>(),
                                               angles.size(),
                                               sd.data<T>(),
                                               stride.size(),
                                               static_cast<int>(height),
                                               static_cast<int>(width),
                                               offset);

    framework::Tensor v;
    framework::TensorFromVector(var_host, ctx.device_context(), &v);
    // One thread per output element: 5 variance values per box.
    grid = (box_num * 5 + block - 1) / block;
    SetVariance<T><<<grid, block, 0, stream>>>(
        vars->data<T>(), v.data<T>(), variances.size(), box_num * 5);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
// Registers the CUDA kernel of `rotated_anchor_generator` for float and
// double element types.
REGISTER_OP_CUDA_KERNEL(rotated_anchor_generator,
                        ops::RotatedAnchorGeneratorOpCUDAKernel<float>,
                        ops::RotatedAnchorGeneratorOpCUDAKernel<double>);
