/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "operators/kernel/fetch_kernel.h"
#include "framework/cl/cl_tensor.h"

namespace paddle_mobile {
namespace operators {

template <>
bool FetchKernel<GPU_CL, float>::Init(FetchParam<GPU_CL> *param) {
  // Load the "fetch" kernel from fetch_kernel.cl; it converts the
  // GPU image-layout tensor back into a plain float buffer in Compute().
  this->cl_helper_.AddKernel("fetch", "fetch_kernel.cl");
  return true;
}

template <>
28
void FetchKernel<GPU_CL, float>::Compute(const FetchParam<GPU_CL> &param) {
Z
zhaojiaying01 已提交
29 30 31
  auto kernel = this->cl_helper_.KernelAt(0);
  auto default_work_size = this->cl_helper_.DefaultWorkSize(*param.InputX());

H
hjchen2 已提交
32
  const int col = param.Col();
Z
zhaojiaying01 已提交
33
  auto input = param.InputX()->GetCLImage();
H
hjchen2 已提交
34
  auto *out = &param.Out()->at(col);
Y
yangfei 已提交
35
  out->Resize(param.InputX()->dims());
36
  out->mutable_data<float>();
37 38 39 40 41

  DLOG << "fetch kernel out dims = " << out->dims();
  DLOG << "fetch kernel out memory size = " << out->memory_size();

  auto dim = param.InputX()->dims();
Z
zhaojiaying01 已提交
42 43 44 45 46 47
  size_t new_dims[] = {1, 1, 1, 1};

  for (int j = 0; j < dim.size(); ++j) {
    new_dims[4 - dim.size() + j] = dim[j];
  }

48
  size_t in_ch, in_height, in_width;
Z
zhaojiaying01 已提交
49

50
  in_ch = new_dims[1];
Z
zhaojiaying01 已提交
51
  in_height = new_dims[2];
Y
yangfei 已提交
52
  in_width = new_dims[3];
53 54 55
  int size_ch = in_height * in_width;
  int size_block = size_ch * 4;
  int size_batch = size_ch * in_ch;
Z
zhaojiaying01 已提交
56

57 58
  framework::CLTensor out_cl_tensor(this->cl_helper_.CLContext(),
                                    this->cl_helper_.CLCommandQueue());
Z
zhaojiaying01 已提交
59 60 61
  out_cl_tensor.Resize(out->dims());
  cl_mem outBuffer = out_cl_tensor.mutable_data<float>();

62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
  cl_int status;
  status = clSetKernelArg(kernel, 0, sizeof(int), &in_height);
  CL_CHECK_ERRORS(status);
  status = clSetKernelArg(kernel, 1, sizeof(int), &in_width);
  CL_CHECK_ERRORS(status);
  status = clSetKernelArg(kernel, 2, sizeof(cl_mem), &input);
  CL_CHECK_ERRORS(status);
  status = clSetKernelArg(kernel, 3, sizeof(cl_mem), &outBuffer);
  CL_CHECK_ERRORS(status);
  status = clSetKernelArg(kernel, 4, sizeof(int), &size_ch);
  CL_CHECK_ERRORS(status);
  status = clSetKernelArg(kernel, 5, sizeof(int), &size_block);
  CL_CHECK_ERRORS(status);
  status = clSetKernelArg(kernel, 6, sizeof(int), &size_batch);
  CL_CHECK_ERRORS(status);
  status = clSetKernelArg(kernel, 7, sizeof(int), &in_ch);
  CL_CHECK_ERRORS(status);
Z
zhaojiaying01 已提交
79

L
liuruilong 已提交
80
  //  cl_event wait_event = param.InpdutX()->GetClEvent();
81 82 83 84
  status =
      clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL,
                             default_work_size.data(), NULL, 0, NULL, NULL);
  CL_CHECK_ERRORS(status);
L
liuruilong 已提交
85

L
liuruilong 已提交
86
  clFinish(this->cl_helper_.CLCommandQueue());
Z
zhaojiaying01 已提交
87

88 89
  DLOG << "fetch kernel out dims = " << out->dims();
  DLOG << "fetch kernel out memory size = " << out->memory_size();
L
liuruilong 已提交
90

91 92 93 94 95
  DLOG << "fetch kernel out_cl_tensor dims = " << out_cl_tensor.dims();
  DLOG << "fetch kernel out_cl_tensor memery size = "
       << out_cl_tensor.memory_size();
  memcpy(out->data<float>(), out_cl_tensor.Data<float>(),
         sizeof(float) * out->numel());
96
}
Z
zhaojiaying01 已提交
97 98 99 100 101

template class FetchKernel<GPU_CL, float>;

}  // namespace operators
}  // namespace paddle_mobile