// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <string>
#include <vector>

#include "lite/backends/opencl/cl_include.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/opencl/image_helper.h"
#include "lite/operators/op_params.h"
#include "lite/utils/replace_stl/stream.h"
#include "lite/utils/string.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/profiler.h"
#endif
#include "lite/backends/opencl/cl_utility.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace opencl {

// OpenCL buffer-based fully-connected kernel.
//
// Flattens the input into an (m x k) matrix using `in_num_col_dims`, takes the
// weight as (k x n), and dispatches either the "fc_gemv_1x4" kernel (m == 1)
// or the "fc_gemm_4x4" kernel (otherwise) from "buffer/fc_kernel.cl".
// The OpenCL kernel is (re)selected whenever the input dims change.
class FcCompute
    : public KernelLite<TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kNCHW)> {
 public:
  using param_t = operators::FcParam;

  // Nothing to prepare up front; all setup happens in ReInitWhenNeeded().
  void PrepareForRun() override {}

  // Recomputes m/n/k, picks the kernel function, builds/fetches the cl kernel
  // and the global work size. Work is only done on the first call or when the
  // input dims differ from the ones seen on the previous call.
  void ReInitWhenNeeded() override {
    fc_param_ = param_.get_mutable<param_t>();
    const auto x_dims = fc_param_->input->dims();
    const bool need_reinit =
        first_epoch_for_reinit_ || x_dims != last_x_dims_;
    if (!need_reinit) {
      return;
    }
    last_x_dims_ = x_dims;
    first_epoch_for_reinit_ = false;

    // compute m,n,k
    const auto w_dims = fc_param_->w->dims();
    CHECK_GE(x_dims.size(), 2UL);
    CHECK_GE(w_dims.size(), 2UL);
    CHECK_EQ(fc_param_->output->dims().size(), 2UL);

    m_ = x_dims.Slice(0, fc_param_->in_num_col_dims).production();
    k_ = x_dims.Slice(fc_param_->in_num_col_dims, x_dims.size()).production();
    n_ = w_dims[1];
    CHECK_EQ(k_, static_cast<int>(w_dims[0]));

#ifdef LITE_WITH_LOG
    // Only the first two dims are guaranteed to exist (see CHECK_GE above),
    // so iterate over the actual rank instead of indexing [0..3] blindly.
    for (size_t i = 0; i < x_dims.size(); ++i) {
      VLOG(4) << "x_dims[" << static_cast<int>(i) << "]:" << x_dims[i];
    }
    for (size_t i = 0; i < w_dims.size(); ++i) {
      VLOG(4) << "w_dims[" << static_cast<int>(i) << "]:" << w_dims[i];
    }
    VLOG(4) << "m_: " << m_ << " n_: " << n_ << " k_: " << k_;
#endif

    // choose kernel
    kernel_func_name_ = (m_ == 1) ? "fc_gemv_1x4"   // gemv
                                  : "fc_gemm_4x4";  // gemm
#ifdef LITE_WITH_LOG
    VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
#endif

    // Rebuild the options from the base string on every re-init; appending to
    // the previous value would accumulate "-DRELU-DRELU..." across dim changes.
    build_options_ = BaseBuildOptions();
    if (fc_param_->activation_type == "relu") {
      build_options_ += "-DRELU";
    }

    auto& context = ctx_->As<OpenCLContext>();
    context.cl_context()->AddKernel(kernel_func_name_,
                                    "buffer/fc_kernel.cl",
                                    build_options_,
                                    time_stamp_);
    // The lookup key is the concatenation of function name, build options and
    // this instance's time stamp.
    STL::stringstream kernel_key;
    kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
    kernel_ = context.cl_context()->GetKernel(kernel_key.str());

    // compute global work size
    GetGlobalWorkSize();
  }

  // Sets `global_work_size_` from m_/n_: 1-D over ceil(n/4) for the gemv path,
  // 2-D over (ceil(m/4), ceil(n/4)) for the gemm path.
  // NOTE(review): the /4 presumably matches the 1x4 / 4x4 tiling implied by
  // the kernel names -- confirm against fc_kernel.cl.
  void GetGlobalWorkSize() {
    const size_t n_blocks = static_cast<size_t>((n_ + 3) / 4);
    if (m_ == 1) {  // gemv
      global_work_size_ = cl::NDRange{n_blocks};
    } else {  // gemm
      const size_t m_blocks = static_cast<size_t>((m_ + 3) / 4);
      global_work_size_ = cl::NDRange{m_blocks, n_blocks};
    }
  }

  // Binds input/weight/bias/output buffers and m, n, k as kernel arguments
  // 0..6 and enqueues the kernel with the precomputed global work size.
  void Run() override {
    // The bias buffer is bound unconditionally as argument 2, so fail
    // explicitly rather than dereferencing a null tensor pointer.
    CHECK(fc_param_->bias != nullptr);

    auto* x_buf = fc_param_->input->data<float, cl::Buffer>();
    auto* w_buf = fc_param_->w->data<float, cl::Buffer>();
    auto* bias_buf = fc_param_->bias->data<float, cl::Buffer>();
    auto* out_buf =
        fc_param_->output->mutable_data<float, cl::Buffer>(TARGET(kOpenCL));

    auto kernel = kernel_;
    cl_int status;
    status = kernel.setArg(0, *x_buf);
    CL_CHECK_FATAL(status);
    status = kernel.setArg(1, *w_buf);
    CL_CHECK_FATAL(status);
    status = kernel.setArg(2, *bias_buf);
    CL_CHECK_FATAL(status);
    status = kernel.setArg(3, *out_buf);
    CL_CHECK_FATAL(status);
    status = kernel.setArg(4, static_cast<const int>(m_));
    CL_CHECK_FATAL(status);
    status = kernel.setArg(5, static_cast<const int>(n_));
    CL_CHECK_FATAL(status);
    status = kernel.setArg(6, static_cast<const int>(k_));
    CL_CHECK_FATAL(status);

    auto& context = ctx_->As<OpenCLContext>();
    CHECK(context.cl_context() != nullptr);

    status = EnqueueNDRangeKernel(context,
                                  kernel,
                                  cl::NullRange,
                                  global_work_size_,
                                  cl::NullRange,
                                  nullptr,
                                  event_);
    CL_CHECK_FATAL(status);
  }

#ifdef LITE_WITH_PROFILE
  // Exposes the selected kernel function name and the run event to the
  // profiler.
  void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
    ch->kernel_func_name = kernel_func_name_;
    ch->cl_event =
        event_;  // `event_` defined in `kernel.h`, valid after kernel::Run
  }
#endif

 private:
  // Build options common to every variant of the kernel (trailing space is
  // intentional so further "-D..." flags can be appended directly).
  static const char* BaseBuildOptions() { return "-DCL_DTYPE_float "; }

  // GEMM problem sizes: input is (m_ x k_), weight is (k_ x n_).
  int m_{0};
  int n_{0};
  int k_{0};
  param_t* fc_param_{nullptr};
  std::string kernel_func_name_{};
  std::string build_options_{BaseBuildOptions()};
  std::string time_stamp_{GetTimeStamp()};
  // True until the first ReInitWhenNeeded() call has completed its setup.
  bool first_epoch_for_reinit_{true};
  // Input dims seen by the last re-init; used to detect shape changes.
  DDim last_x_dims_;
  cl::NDRange global_work_size_;
  cl::Kernel kernel_;
};

}  // namespace opencl
}  // namespace kernels
}  // namespace lite
}  // namespace paddle

// Registers FcCompute as the kernel for op `fc` with target kOpenCL,
// precision kFloat and layout kNCHW (registered under the name `def`).
// Inputs "Input", "Bias" and "W" and output "Out" are all declared as tensors
// on the kOpenCL target.
REGISTER_LITE_KERNEL(
    fc, kOpenCL, kFloat, kNCHW, paddle::lite::kernels::opencl::FcCompute, def)
    .BindInput("Input", {LiteType::GetTensorTy(TARGET(kOpenCL))})
    .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kOpenCL))})
    .BindInput("W", {LiteType::GetTensorTy(TARGET(kOpenCL))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kOpenCL))})
    .Finalize();