// Copyright 2018 Xiaomi, Inc.  All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mace/kernels/winograd_transform.h"
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/kernels/opencl/helper.h"
#include "mace/utils/tuner.h"

namespace mace {
namespace kernels {

// Runs the OpenCL "winograd_transform_2x2" kernel on `input_tensor` and writes
// the result into `output_tensor`.
//
// The kernel is built on the first call and cached in `kernel_`. The output
// image is resized and the kernel arguments are re-bound only when the input
// shape differs from the shape recorded in `input_shape_` by the previous
// call; the kernel itself is enqueued on every call.
//
// Parameters:
//   input_tensor   source tensor; dims 0..3 are read (presumably NHWC, as the
//                  CalcNHWCPaddingAndOutputSize helper name suggests -- TODO
//                  confirm against callers).
//   output_tensor  destination tensor; resized to
//                  {16, input dim(3), out_width, 1} when the input shape
//                  changes.
//   future         forwarded unchanged to TuningOrRun2DKernel.
template <typename T>
void WinogradTransformFunctor<DeviceType::GPU, T>::operator()(
    const Tensor *input_tensor, Tensor *output_tensor, StatsFuture *future) {

  auto runtime = OpenCLRuntime::Global();

  // Lazily build the kernel the first time this functor is invoked.
  if (kernel_.get() == nullptr) {
    std::string obfuscated_kernel_name =
        MACE_OBFUSCATE_SYMBOL("winograd_transform_2x2");
    std::set<std::string> built_options;
    built_options.emplace("-Dwinograd_transform_2x2=" + obfuscated_kernel_name);
    built_options.emplace("-DDATA_TYPE=" +
                          DtToUpstreamCLDt(DataTypeToEnum<T>::value));
    built_options.emplace("-DCMD_DATA_TYPE=" +
                          DtToUpstreamCLCMDDt(DataTypeToEnum<T>::value));
    if (runtime->IsOutOfRangeCheckEnabled()) {
      built_options.emplace("-DOUT_OF_RANGE_CHECK");
      // One-byte device buffer that is initialised to 0 here and inspected
      // after the kernel has been run (see the check at the end).
      // The temporary unique_ptr is already an rvalue, so no std::move needed.
      kernel_error_ = std::unique_ptr<Buffer>(
          new Buffer(GetDeviceAllocator(DeviceType::GPU), 1));
      kernel_error_->Map(nullptr);
      *(kernel_error_->mutable_data<char>()) = 0;
      kernel_error_->UnMap();
    }
    if (runtime->IsNonUniformWorkgroupsSupported()) {
      built_options.emplace("-DNON_UNIFORM_WORK_GROUP");
    }
    kernel_ = runtime->BuildKernel("winograd_transform", obfuscated_kernel_name,
                                   built_options);

    kwg_size_ =
        static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(kernel_));
  }

  // Derive the padding and the output extent using a filter shape of
  // {1, input dim(3), 3, 3} (the trailing 3, 3 are presumably the 3x3 kernel
  // window -- TODO confirm the layout expected by the helpers).
  std::vector<index_t> output_shape(4);
  std::vector<index_t> filter_shape = {1, input_tensor->dim(3), 3, 3};
  std::vector<int> paddings(2);
  if (paddings_.empty()) {
    // No explicit paddings configured: compute them from padding_type_.
    kernels::CalcNHWCPaddingAndOutputSize(
        input_tensor->shape().data(), filter_shape.data(), dilations_.data(),
        strides_.data(), padding_type_, output_shape.data(), paddings.data());
  } else {
    // Explicit paddings configured: use them as-is.
    paddings = paddings_;
    CalcOutputSize(input_tensor->shape().data(), filter_shape.data(),
                   paddings_.data(), dilations_.data(), strides_.data(),
                   RoundType::FLOOR, output_shape.data());
  }
  // ceil(output dim / 2) along output dims 1 and 2.
  const index_t round_h = (output_shape[1] + 1) / 2;
  const index_t round_w = (output_shape[2] + 1) / 2;
  const index_t out_width = input_tensor->dim(0) * round_h * round_w;
  // Global work size: {out_width, RoundUpDiv4(input dim(3))}.
  const uint32_t gws[2] = {
      static_cast<uint32_t>(out_width),
      static_cast<uint32_t>(RoundUpDiv4(input_tensor->dim(3)))};

  // Only resize the output and re-bind kernel arguments when the input shape
  // has changed since the previous call.
  if (!IsVecEqual(input_shape_, input_tensor->shape())) {
    output_shape = {16, input_tensor->dim(3), out_width, 1};
    std::vector<size_t> image_shape;
    CalImage2DShape(output_shape, BufferType::IN_OUT_HEIGHT, &image_shape);
    output_tensor->ResizeImage(output_shape, image_shape);

    // Argument order must mirror the compile-time options chosen above: the
    // error buffer and the explicit global sizes are only bound when the
    // corresponding runtime conditions hold.
    uint32_t idx = 0;
    if (runtime->IsOutOfRangeCheckEnabled()) {
      kernel_.setArg(idx++,
          *(static_cast<cl::Buffer *>(kernel_error_->buffer())));
    }
    if (!runtime->IsNonUniformWorkgroupsSupported()) {
      kernel_.setArg(idx++, gws[0]);
      kernel_.setArg(idx++, gws[1]);
    }
    kernel_.setArg(idx++, *(input_tensor->opencl_image()));
    kernel_.setArg(idx++, *(output_tensor->opencl_image()));
    kernel_.setArg(idx++, static_cast<uint32_t>(input_tensor->dim(1)));
    kernel_.setArg(idx++, static_cast<uint32_t>(input_tensor->dim(2)));
    kernel_.setArg(idx++, static_cast<uint32_t>(input_tensor->dim(3)));
    kernel_.setArg(idx++, static_cast<uint32_t>(round_h * round_w));
    kernel_.setArg(idx++, static_cast<uint32_t>(round_w));
    kernel_.setArg(idx++, static_cast<uint32_t>(paddings[0] / 2));
    kernel_.setArg(idx++, static_cast<uint32_t>(paddings[1] / 2));

    input_shape_ = input_tensor->shape();
  }

  // Local work-size vector handed to TuningOrRun2DKernel; how the trailing 0
  // is interpreted is defined by that helper (not visible in this file).
  const std::vector<uint32_t> lws = {kwg_size_ / 8, 8, 0};
  std::string tuning_key =
      Concat("winograd_transform_kernel", output_tensor->dim(0),
             output_tensor->dim(1), output_tensor->dim(2),
             output_tensor->dim(3));
  TuningOrRun2DKernel(kernel_, tuning_key, gws, lws, future);

  // Read back the error byte and fail if it is non-zero.
  if (runtime->IsOutOfRangeCheckEnabled()) {
    kernel_error_->Map(nullptr);
    char *kerror_code = kernel_error_->mutable_data<char>();
    // Cast to int so the message shows the numeric code rather than a raw
    // (typically unprintable) character.
    MACE_CHECK(*kerror_code == 0) << "Kernel error code: "
                                  << static_cast<int>(*kerror_code);
    kernel_error_->UnMap();
  }
}

// Runs the OpenCL "winograd_inverse_transform_2x2" kernel on `input_tensor`,
// optionally with a bias image and the activation selected by `activation_`,
// and writes the result into `output_tensor`.
//
// The kernel is built on the first call and cached in `kernel_`; the bias and
// activation compile options are therefore fixed by the first invocation. The
// output image is resized and the kernel arguments are re-bound only when the
// input shape differs from the shape recorded in `input_shape_`.
//
// Parameters:
//   input_tensor   source tensor; dims 1 and 2 are read here.
//   bias           optional bias tensor (may be nullptr); bound as an image
//                  argument when present.
//   output_tensor  destination tensor; resized to
//                  {batch_, height_, width_, input dim(1)} when the input
//                  shape changes.
//   future         forwarded unchanged to TuningOrRun2DKernel.
template <typename T>
void WinogradInverseTransformFunctor<DeviceType::GPU, T>::operator()(
    const Tensor *input_tensor,
    const Tensor *bias,
    Tensor *output_tensor,
    StatsFuture *future) {

  auto runtime = OpenCLRuntime::Global();

  // Lazily build the kernel the first time this functor is invoked.
  if (kernel_.get() == nullptr) {
    std::string obfuscated_kernel_name =
        MACE_OBFUSCATE_SYMBOL("winograd_inverse_transform_2x2");
    std::set<std::string> built_options;
    built_options.emplace("-Dwinograd_inverse_transform_2x2=" +
                          obfuscated_kernel_name);
    built_options.emplace("-DDATA_TYPE=" +
                          DtToUpstreamCLDt(DataTypeToEnum<T>::value));
    built_options.emplace("-DCMD_DATA_TYPE=" +
                          DtToUpstreamCLCMDDt(DataTypeToEnum<T>::value));
    if (runtime->IsOutOfRangeCheckEnabled()) {
      built_options.emplace("-DOUT_OF_RANGE_CHECK");
      // One-byte device buffer that is initialised to 0 here and inspected
      // after the kernel has been run (see the check at the end).
      // The temporary unique_ptr is already an rvalue, so no std::move needed.
      kernel_error_ = std::unique_ptr<Buffer>(
          new Buffer(GetDeviceAllocator(DeviceType::GPU), 1));
      kernel_error_->Map(nullptr);
      *(kernel_error_->mutable_data<char>()) = 0;
      kernel_error_->UnMap();
    }
    if (runtime->IsNonUniformWorkgroupsSupported()) {
      built_options.emplace("-DNON_UNIFORM_WORK_GROUP");
    }
    // NOTE(review): when bias is null this inserts an empty option string.
    // Left unchanged because the option set is passed to BuildKernel and may
    // take part in a program cache key -- confirm before removing it.
    built_options.emplace(bias != nullptr ? "-DBIAS" : "");
    // Translate the configured activation into its compile-time define.
    switch (activation_) {
      case NOOP:
        break;
      case RELU:
        built_options.emplace("-DUSE_RELU");
        break;
      case RELUX:
        built_options.emplace("-DUSE_RELUX");
        break;
      case PRELU:
        built_options.emplace("-DUSE_PRELU");
        break;
      case TANH:
        built_options.emplace("-DUSE_TANH");
        break;
      case SIGMOID:
        built_options.emplace("-DUSE_SIGMOID");
        break;
      default:
        LOG(FATAL) << "Unknown activation type: " << activation_;
    }

    kernel_ = runtime->BuildKernel("winograd_transform", obfuscated_kernel_name,
                                   built_options);

    kwg_size_ =
        static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(kernel_));
  }

  // Global work size: {input dim(2), RoundUpDiv4(input dim(1))}.
  const uint32_t gws[2] = {
      static_cast<uint32_t>(input_tensor->dim(2)),
      static_cast<uint32_t>(RoundUpDiv4(input_tensor->dim(1)))};
  // Only resize the output and re-bind kernel arguments when the input shape
  // has changed since the previous call.
  if (!IsVecEqual(input_shape_, input_tensor->shape())) {
    std::vector<index_t> output_shape = {batch_, height_, width_,
                                         input_tensor->dim(1)};
    std::vector<size_t> image_shape;
    CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, &image_shape);
    output_tensor->ResizeImage(output_shape, image_shape);

    // ceil(height_ / 2) and ceil(width_ / 2).
    const uint32_t round_h = (height_ + 1) / 2;
    const uint32_t round_w = (width_ + 1) / 2;
    // Argument order must mirror the compile-time options chosen above: the
    // error buffer, the explicit global sizes and the bias image are only
    // bound when the corresponding conditions hold.
    uint32_t idx = 0;
    if (runtime->IsOutOfRangeCheckEnabled()) {
      kernel_.setArg(idx++,
          *(static_cast<cl::Buffer *>(kernel_error_->buffer())));
    }
    if (!runtime->IsNonUniformWorkgroupsSupported()) {
      kernel_.setArg(idx++, gws[0]);
      kernel_.setArg(idx++, gws[1]);
    }
    kernel_.setArg(
        idx++,
        *(static_cast<const cl::Image2D *>(input_tensor->opencl_image())));
    if (bias != nullptr) {
      kernel_.setArg(idx++,
                     *(static_cast<const cl::Image2D *>(bias->opencl_image())));
    }
    kernel_.setArg(
        idx++, *(static_cast<cl::Image2D *>(output_tensor->opencl_image())));
    kernel_.setArg(idx++, static_cast<uint32_t>(output_shape[1]));
    kernel_.setArg(idx++, static_cast<uint32_t>(output_shape[2]));
    kernel_.setArg(idx++, static_cast<uint32_t>(round_h * round_w));
    kernel_.setArg(idx++, static_cast<uint32_t>(round_w));
    kernel_.setArg(idx++, relux_max_limit_);

    input_shape_ = input_tensor->shape();
  }

  // Local work-size vector handed to TuningOrRun2DKernel; how the trailing 0
  // is interpreted is defined by that helper (not visible in this file).
  const std::vector<uint32_t> lws = {kwg_size_ / 8, 8, 0};
  std::string tuning_key =
      Concat("winograd_inverse_transform_kernel", output_tensor->dim(0),
             output_tensor->dim(1), output_tensor->dim(2),
             output_tensor->dim(3), input_tensor->dim(2));
  TuningOrRun2DKernel(kernel_, tuning_key, gws, lws, future);

  // Read back the error byte and fail if it is non-zero.
  if (runtime->IsOutOfRangeCheckEnabled()) {
    kernel_error_->Map(nullptr);
    char *kerror_code = kernel_error_->mutable_data<char>();
    // Cast to int so the message shows the numeric code rather than a raw
    // (typically unprintable) character.
    MACE_CHECK(*kerror_code == 0) << "Kernel error code: "
                                  << static_cast<int>(*kerror_code);
    kernel_error_->UnMap();
  }
}

// Explicit instantiations of the GPU functors for the two element types
// handled in this file (float and half). The operator() bodies above are
// defined in this translation unit, so the specializations are emitted here.
template struct WinogradTransformFunctor<DeviceType::GPU, float>;
template struct WinogradTransformFunctor<DeviceType::GPU, half>;

template struct WinogradInverseTransformFunctor<DeviceType::GPU, float>;
template struct WinogradInverseTransformFunctor<DeviceType::GPU, half>;

}  // namespace kernels
}  // namespace mace
