“ec115140273184e62354c0055d4d0ee3054b61d0”上不存在“git@gitcode.net:paddlepaddle/PaddleDetection.git”
elementwise_add_image_compute.cc 6.7 KB
Newer Older
Y
Yan Chunwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "lite/kernels/opencl/elementwise_add_image_compute.h"
Y
Yan Chunwei 已提交
16
#include <memory>
17
#include "lite/backends/opencl/cl_include.h"
Y
Yan Chunwei 已提交
18 19 20 21 22 23 24 25
#include "lite/core/op_registry.h"
#include "lite/utils/replace_stl/stream.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace opencl {

26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
void ElementwiseAddImageCompute::PrepareForRun() {
  ele_param_ = param_.get_mutable<param_t>();
  auto* x = ele_param_->X;
  auto* y = ele_param_->Y;
  auto axis = ele_param_->axis;

  if (y->dims().size() == 4) {
    kernel_func_name_ = "elementwise_add";  // y: ImageDefault
  } else if (y->dims().size() == 1) {
    if (axis == x->dims().size() - 1) {
      kernel_func_name_ = "width_add";  // y: ImageDefault
    } else if (axis == x->dims().size() - 3) {
      kernel_func_name_ = "channel_add";  // y: ImageFolder
    } else {
      LOG(FATAL) << "ElementwiseAddImage doesn't support axis:" << axis
                 << ", x->dims().size():" << x->dims().size()
                 << ", y->dims.size():" << y->dims().size();
    }
  } else {
    LOG(FATAL) << "ElementwiseAddImage doesn't support axis:" << axis
               << ", x->dims().size():" << x->dims().size()
               << ", y->dims.size():" << y->dims().size();
  }
49
  VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
50 51 52 53 54 55 56 57 58 59 60 61 62 63 64

  auto& context = ctx_->As<OpenCLContext>();
  context.cl_context()->AddKernel(
      kernel_func_name_, "image/elementwise_add_kernel.cl", build_options_);
}

void ElementwiseAddImageCompute::Run() {
  auto& context = ctx_->As<OpenCLContext>();
  CHECK(context.cl_context() != nullptr);

  auto* x = ele_param_->X;
  auto* y = ele_param_->Y;
  auto* out = ele_param_->Out;
  auto axis = ele_param_->axis;

65
#ifndef LITE_SHUTDOWN_LOG
66 67 68 69 70 71 72
  VLOG(4) << "x->target():" << TargetToStr(x->target());
  VLOG(4) << "y->target():" << TargetToStr(y->target());
  VLOG(4) << "out->target():" << TargetToStr(out->target());
  VLOG(4) << "x->dims():" << x->dims();
  VLOG(4) << "y->dims():" << y->dims();
  VLOG(4) << "out->dims():" << out->dims();
  VLOG(4) << "axis:" << axis;
73
#endif
74 75 76 77 78 79 80 81 82

  paddle::lite::CLImageConverterDefault default_convertor;
  auto x_img_shape = default_convertor.InitImageDimInfoWith(x->dims());  // w, h
  auto x_img_width = x_img_shape[0];
  auto x_img_height = x_img_shape[1];
  auto out_img_shape =
      default_convertor.InitImageDimInfoWith(out->dims());  // w, h
  auto y_img_shape = default_convertor.InitImageDimInfoWith(y->dims());

83 84 85 86
  auto* x_img = x->data<half_t, cl::Image2D>();
  auto* y_img = y->data<half_t, cl::Image2D>();
  auto* out_img = out->mutable_data<half_t, cl::Image2D>(out_img_shape[0],
                                                         out_img_shape[1]);
87

88
#ifndef LITE_SHUTDOWN_LOG
89 90 91 92
  VLOG(4) << "x_img_shape[w,h]:" << x_img_width << " " << x_img_height;
  VLOG(4) << "y_img_shape[w,h]:" << y_img_shape[0] << " " << y_img_shape[1];
  VLOG(4) << "out_img_shape[w,h]:" << out_img_shape[0] << " "
          << out_img_shape[1];
93
#endif
94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110

  STL::stringstream kernel_key;
  kernel_key << kernel_func_name_ << build_options_;
  auto kernel = context.cl_context()->GetKernel(kernel_key.str());

  int arg_idx = 0;
  auto y_dims = y->dims();
  if (y_dims.size() == 4) {
    cl_int status = kernel.setArg(arg_idx, *x_img);
    CL_CHECK_FATAL(status);
    status = kernel.setArg(++arg_idx, *y_img);
    CL_CHECK_FATAL(status);
    status = kernel.setArg(++arg_idx, *out_img);
    CL_CHECK_FATAL(status);
  } else if (y_dims.size() == 1) {
    if (axis == x->dims().size() - 1 || axis == x->dims().size() - 3) {
      int tensor_w = x->dims()[x->dims().size() - 1];
111
#ifndef LITE_SHUTDOWN_LOG
112
      VLOG(4) << "tensor_w:" << tensor_w;
113
#endif
114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134
      cl_int status = kernel.setArg(arg_idx, *x_img);
      CL_CHECK_FATAL(status);
      status = kernel.setArg(++arg_idx, *y_img);
      CL_CHECK_FATAL(status);
      status = kernel.setArg(++arg_idx, *out_img);
      CL_CHECK_FATAL(status);
      status = kernel.setArg(++arg_idx, static_cast<const int>(tensor_w));
      CL_CHECK_FATAL(status);
    } else {
      LOG(FATAL) << "ElementwiseAddImage doesn't support axis:" << axis
                 << ", x->dims().size():" << x->dims().size()
                 << ", y->dims.size():" << y->dims().size();
    }
  } else {
    LOG(FATAL) << "ElementwiseAddImage doesn't support axis:" << axis
               << ", x->dims().size():" << x->dims().size()
               << ", y->dims.size():" << y->dims().size();
  }

  auto global_work_size = cl::NDRange{static_cast<cl::size_type>(x_img_width),
                                      static_cast<cl::size_type>(x_img_height)};
135
#ifndef LITE_SHUTDOWN_LOG
136
  VLOG(4) << "global_work_size:[2D]:" << x_img_width << " " << x_img_height;
137
#endif
138 139 140 141 142 143 144 145 146 147
  auto status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
      kernel,
      cl::NullRange,
      global_work_size,
      cl::NullRange,
      nullptr,
      event_.get());
  CL_CHECK_FATAL(status);
  context.cl_wait_list()->emplace(out_img, event_);
}
Y
Yan Chunwei 已提交
148 149 150 151 152 153 154

}  // namespace opencl
}  // namespace kernels
}  // namespace lite
}  // namespace paddle

namespace ocl = paddle::lite::kernels::opencl;
155

156
// TODO(ysh329): May need fix.
157 158 159 160 161 162 163
// "Y" may from constant value like conv bias (kARM, need do cl_image_converter
// on CPU);
//     may from anther branch like "X" (kOpenCL, nothing to do).
// Consider 2 situations have different actions when pass running(pick kernel),
//     set target of "Y" as kOpenCL temporarily.
REGISTER_LITE_KERNEL(elementwise_add,
                     kOpenCL,
164
                     kFP16,
165 166 167 168 169
                     kImageDefault,
                     ocl::ElementwiseAddImageCompute,
                     def)
    .BindInput("X",
               {LiteType::GetTensorTy(TARGET(kOpenCL),
170
                                      PRECISION(kFP16),
171 172 173
                                      DATALAYOUT(kImageDefault))})
    .BindInput("Y",
               {LiteType::GetTensorTy(TARGET(kOpenCL),
174
                                      PRECISION(kFP16),
175 176 177
                                      DATALAYOUT(kImageDefault))})
    .BindOutput("Out",
                {LiteType::GetTensorTy(TARGET(kOpenCL),
178
                                       PRECISION(kFP16),
179
                                       DATALAYOUT(kImageDefault))})
Y
Yan Chunwei 已提交
180
    .Finalize();