/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef FUSION_CONVADD_OP

#include "operators/kernel/conv_add_kernel.h"

namespace paddle_mobile {
namespace operators {

template <>
bool ConvAddKernel<GPU_CL, float>::Init(FusionConvAddParam<GPU_CL> *param) {
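  // The OpenCL kernels below only handle square filters with symmetric
  // padding, so reject anything else up front.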
  PADDLE_MOBILE_ENFORCE(
      param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
          param->Paddings()[0] == param->Paddings()[1],
      "need equal");
  param->Bias()->InitCLImage(cl_helper_.CLContext(),
                             this->cl_helper_.CLCommandQueue());

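  // Filter-center offset (half the filter size minus the padding); it is
  // passed to the OpenCL kernels later as the `offset` argument.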
  int offset = static_cast<int>(param->Filter()->dims()[2]) / 2 -
               static_cast<int>(param->Paddings()[1]);
  param->SetOffset(offset);

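  // Pick the OpenCL kernel variant by filter shape: 1x1 pointwise
  // convolution, 3x3 depthwise convolution (single filter channel and equal
  // input/output channel counts), or the generic 3x3 convolution. The filter
  // is converted to the image layout the chosen kernel expects.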
  if (param->Filter()->dims()[2] == 1 && param->Filter()->dims()[3] == 1) {
    param->Filter()->InitNImage(cl_helper_.CLContext(),
                                cl_helper_.CLCommandQueue());

    this->cl_helper_.AddKernel("conv_1x1", "conv_add_kernel.cl");
  } else if (param->Filter()->dims()[1] == 1 &&
             param->Input()->dims()[1] == param->Output()->dims()[1] &&
             param->Filter()->dims()[2] == 3) {
    param->Filter()->InitDWImage(cl_helper_.CLContext(),
                                 cl_helper_.CLCommandQueue());
    this->cl_helper_.AddKernel("depth_conv_3x3", "conv_add_kernel.cl");

  } else if (param->Filter()->dims()[2] == 3 &&
             param->Filter()->dims()[3] == 3) {
    param->Filter()->InitCLImage(cl_helper_.CLContext(),
                                 cl_helper_.CLCommandQueue());

    this->cl_helper_.AddKernel("conv_3x3", "conv_add_kernel.cl");

  } else {
    PADDLE_MOBILE_THROW_EXCEPTION("unsupported filter configuration for conv_add");
  }

  return true;
}

template <>
void ConvAddKernel<GPU_CL, float>::Compute(
    const FusionConvAddParam<GPU_CL> &param) {
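  // Decompose the default work size: channel blocks (c_block), output
  // width (w), and combined batch * output height (nh).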
  auto kernel = this->cl_helper_.KernelAt(0);
  auto default_work_size = this->cl_helper_.DefaultWorkSize(*param.Output());
  int c_block = default_work_size[0];
  int w = default_work_size[1];
  int nh = default_work_size[2];
  auto input = param.Input()->GetCLImage();
  auto filter = param.Filter()->GetCLImage();
  DLOG << "---yangfei30---";
  DLOG << *param.Filter();
  DLOG << param.Paddings();
  auto bias = param.Bias()->GetCLImage();
  auto output = param.Output()->GetCLImage();
  int stride = param.Strides()[0];
  int offset = param.Offset();
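  // Number of channel blocks in the input image, taken from its folder-layout
  // image converter.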
  int input_c = reinterpret_cast<framework::CLImageConverterFolder *>(
                    param.Input()->Converter())
                    ->GetCBlock();
  int dilation = param.Dilations()[0];

  int input_width = param.Input()->dims()[3];
  int input_height = param.Input()->dims()[2];
  int output_width = param.Output()->dims()[3];
  int output_height = param.Output()->dims()[2];

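  // Bind the kernel arguments in the order the conv_add kernels expect:
  // the work-size dimensions, the input/filter/bias/output images, then the
  // scalar convolution parameters.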
  cl_int status;

  status = clSetKernelArg(kernel, 0, sizeof(int), &c_block);
  CL_CHECK_ERRORS(status);

  status = clSetKernelArg(kernel, 1, sizeof(int), &w);
  CL_CHECK_ERRORS(status);

  status = clSetKernelArg(kernel, 2, sizeof(int), &nh);
  CL_CHECK_ERRORS(status);

  status = clSetKernelArg(kernel, 3, sizeof(cl_mem), &input);
  CL_CHECK_ERRORS(status);

  status = clSetKernelArg(kernel, 4, sizeof(cl_mem), &filter);
  CL_CHECK_ERRORS(status);

  status = clSetKernelArg(kernel, 5, sizeof(cl_mem), &bias);
  CL_CHECK_ERRORS(status);

  status = clSetKernelArg(kernel, 6, sizeof(cl_mem), &output);
  CL_CHECK_ERRORS(status);

  status = clSetKernelArg(kernel, 7, sizeof(int), &stride);
  CL_CHECK_ERRORS(status);

  status = clSetKernelArg(kernel, 8, sizeof(int), &offset);
  CL_CHECK_ERRORS(status);

  status = clSetKernelArg(kernel, 9, sizeof(int), &input_c);
  CL_CHECK_ERRORS(status);

  status = clSetKernelArg(kernel, 10, sizeof(int), &dilation);
  CL_CHECK_ERRORS(status);

  status = clSetKernelArg(kernel, 11, sizeof(int), &input_width);
  CL_CHECK_ERRORS(status);

  status = clSetKernelArg(kernel, 12, sizeof(int), &input_height);
  CL_CHECK_ERRORS(status);

  status = clSetKernelArg(kernel, 13, sizeof(int), &output_width);
  CL_CHECK_ERRORS(status);

  status = clSetKernelArg(kernel, 14, sizeof(int), &output_height);
  CL_CHECK_ERRORS(status);

  //  cl_event out_event = param.Output()->GetClEvent();
  //  cl_event wait_event = param.Input()->GetClEvent();

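  // Enqueue the kernel over the default global work size; the local work
  // size is left to the OpenCL runtime and no events are attached.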
  status = clEnqueueNDRangeKernel(
      this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(), NULL,
      default_work_size.data(), NULL, 0, NULL, NULL);
  CL_CHECK_ERRORS(status);
}

template class ConvAddKernel<GPU_CL, float>;

}  // namespace operators
}  // namespace paddle_mobile

#endif