// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <vector>
#include "lite/backends/opencl/cl_include.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/opencl/image_helper.h"
#include "lite/operators/op_params.h"
#include "lite/utils/replace_stl/stream.h"
#include "lite/utils/string.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace opencl {

class FcCompute
    : public KernelLite<TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kNCHW)> {
 public:
  using param_t = operators::FcParam;

  void PrepareForRun() override {}

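  // All of the setup here is shape-dependent, so PrepareForRun() stays empty
  // and the work happens below: m/n/k, the kernel variant, and the global
  // work size are recomputed only on the first run or when the input dims
  // change between runs.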
  void ReInitWhenNeeded() override {
    fc_param_ = param_.get_mutable<param_t>();
    const auto x_dims = fc_param_->input->dims();
    if ((!first_epoch_for_reinit_ && x_dims != last_x_dims_) ||
        first_epoch_for_reinit_) {
      last_x_dims_ = x_dims;
      first_epoch_for_reinit_ = false;

      // Compute m, n, k: flatten the input to an [m, k] matrix at
      // in_num_col_dims and treat the weight as [k, n], so that
      // Out[m, n] = X[m, k] * W[k, n] (+ bias).
      const auto w_dims = fc_param_->w->dims();
      CHECK_GE(x_dims.size(), 2UL);
      CHECK_GE(w_dims.size(), 2UL);
      CHECK_EQ(fc_param_->output->dims().size(), 2UL);

      m_ = x_dims.Slice(0, fc_param_->in_num_col_dims).production();
      k_ = x_dims.Slice(fc_param_->in_num_col_dims, x_dims.size()).production();
      n_ = w_dims[1];
      CHECK_EQ(k_, static_cast<int>(w_dims[0]));

#ifdef LITE_WITH_LOG
      // Log the full DDims; x and w may be rank-2 (only CHECK_GE(..., 2UL)
      // above), so indexing dims [2]/[3] unconditionally would read out of
      // bounds.
      VLOG(4) << "x_dims: " << x_dims;
      VLOG(4) << "w_dims: " << w_dims;
      VLOG(4) << "m_: " << m_ << " n_: " << n_ << " k_: " << k_;
#endif

      // Choose the kernel variant: GEMV when the flattened input is a single
      // row, otherwise the 4x4 GEMM kernel.
      if (m_ == 1) {  // gemv
        kernel_func_name_ = "fc_gemv_1x4";
      } else {  // gemm
        kernel_func_name_ = "fc_gemm_4x4";
      }
#ifdef LITE_WITH_LOG
      VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
#endif

      // Fuse an optional trailing ReLU into the kernel at compile time.
      if (fc_param_->activation_type == "relu") {
        build_options_ += "-DRELU";
      }

      // Register the kernel source with the context, then fetch the compiled
      // kernel under its composite cache key (function name + build options
      // + timestamp).
      auto& context = ctx_->As<OpenCLContext>();
      context.cl_context()->AddKernel(kernel_func_name_,
                                      "buffer/fc_kernel.cl",
                                      build_options_,
                                      time_stamp_);
      STL::stringstream kernel_key;
      kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
      kernel_ = context.cl_context()->GetKernel(kernel_key.str());

      // compute global work size
      GetGlobalWorkSize();
    }
  }

  void GetGlobalWorkSize() {
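    // Each work-item is assumed to cover a 4-wide slice of the output
    // (fc_gemv_1x4: 4 output columns; fc_gemm_4x4: a 4x4 output tile), so
    // the NDRange is the output extent divided by 4, rounded up:
    // e.g. n_ = 10 -> (10 + 3) / 4 = 3 work-items.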
    if (m_ == 1) {  // gemv
      global_work_size_ = cl::NDRange{static_cast<size_t>((n_ + 3) / 4)};
    } else {  // gemm
      global_work_size_ = cl::NDRange{static_cast<size_t>((m_ + 3) / 4),
                                      static_cast<size_t>((n_ + 3) / 4)};
    }
  }

  void Run() override {
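    // Input, weight, and bias are already device-resident cl::Buffers (see
    // the kernel registration below); the output buffer is allocated on the
    // kOpenCL target.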
    auto* x_buf = fc_param_->input->data<float, cl::Buffer>();
    auto* w_buf = fc_param_->w->data<float, cl::Buffer>();
    auto* bias_buf = fc_param_->bias->data<float, cl::Buffer>();
    auto* out_buf =
        fc_param_->output->mutable_data<float, cl::Buffer>(TARGET(kOpenCL));

    auto kernel = kernel_;
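    // The argument order must match the kernel signature in
    // buffer/fc_kernel.cl; judging from the setArg calls below it is roughly:
    //   __kernel void fc_gemm_4x4(__global const float* x,    // [M, K]
    //                             __global const float* w,    // [K, N]
    //                             __global const float* bias, // [N]
    //                             __global float* out,        // [M, N]
    //                             const int M, const int N, const int K);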
    cl_int status;
    status = kernel.setArg(0, *x_buf);
    CL_CHECK_FATAL(status);
    status = kernel.setArg(1, *w_buf);
    CL_CHECK_FATAL(status);
    status = kernel.setArg(2, *bias_buf);
    CL_CHECK_FATAL(status);
    status = kernel.setArg(3, *out_buf);
    CL_CHECK_FATAL(status);
    status = kernel.setArg(4, static_cast<const int>(m_));
    CL_CHECK_FATAL(status);
    status = kernel.setArg(5, static_cast<const int>(n_));
    CL_CHECK_FATAL(status);
    status = kernel.setArg(6, static_cast<const int>(k_));
    CL_CHECK_FATAL(status);

    auto& context = ctx_->As<OpenCLContext>();
    CHECK(context.cl_context() != nullptr);

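    // Launch over global_work_size_ with a driver-chosen local size
    // (cl::NullRange) and no event list; only the enqueue status is checked,
    // so the kernel itself executes asynchronously.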
    status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
        kernel,
        cl::NullRange,
        global_work_size_,
        cl::NullRange,
        nullptr,
        nullptr);
    CL_CHECK_FATAL(status);
  }

 private:
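  // Flattened GEMM dimensions: X is [m_, k_], W is [k_, n_], Out is [m_, n_].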
  int m_, n_, k_;
  param_t* fc_param_{nullptr};
  std::string kernel_func_name_{};
  std::string build_options_{"-DCL_DTYPE_float "};
  std::string time_stamp_{GetTimeStamp()};
  bool first_epoch_for_reinit_{true};
  DDim last_x_dims_;
  cl::NDRange global_work_size_;
  cl::Kernel kernel_;
};

}  // namespace opencl
}  // namespace kernels
}  // namespace lite
}  // namespace paddle

// Register FcCompute as the float/NCHW OpenCL-buffer implementation of the
// fc op; every bound tensor lives in OpenCL device memory.
REGISTER_LITE_KERNEL(
    fc, kOpenCL, kFloat, kNCHW, paddle::lite::kernels::opencl::FcCompute, def)
    .BindInput("Input", {LiteType::GetTensorTy(TARGET(kOpenCL))})
    .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kOpenCL))})
    .BindInput("W", {LiteType::GetTensorTy(TARGET(kOpenCL))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kOpenCL))})
    .Finalize();