// elementwise_mul_image_compute.cc
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <memory>
#include <string>

#include "lite/backends/opencl/cl_half.h"
#include "lite/backends/opencl/cl_image_converter.h"
#include "lite/backends/opencl/cl_include.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/opencl/image_helper.h"
#include "lite/operators/op_params.h"
#include "lite/utils/logging.h"
#include "lite/utils/replace_stl/stream.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace opencl {

// OpenCL image kernel for elementwise_mul with FP16 data in the ImageDefault
// layout. X, Y and Out are all accessed as cl::Image2D objects holding half_t.
//
// Kernel variant (see SelectKernelName):
//   * Y has exactly the same dims as X  -> "elementwise_mul"
//   * otherwise, by the rank of Y (1-4) -> "channel_mul_d1" .. "channel_mul_d4"
//   * any other combination is rejected with LOG(FATAL).
// The variant is registered from image/elementwise_mul_kernel.cl in
// PrepareForRun() and enqueued over the 2D image extent of X in Run().
class ElementwiseMulImageCompute
    : public KernelLite<TARGET(kOpenCL),
                        PRECISION(kFP16),
                        DATALAYOUT(kImageDefault)> {
 public:
  using param_t = operators::ElementwiseParam;

  std::string doc() const override {
    return "ElementwiseMul using cl::Image2D(ImageDefault/RGBA), kFP16";
  }

  // Chooses the kernel variant for the current X/Y shapes and adds it to the
  // OpenCL context together with build_options_ and time_stamp_, which also
  // form the lookup key used in Run().
  void PrepareForRun() override {
    ele_param_ = param_.get_mutable<param_t>();
    const auto& x_dims = ele_param_->X->dims();
    const auto& bias_dims = ele_param_->Y->dims();

    kernel_func_name_ = SelectKernelName(x_dims, bias_dims);
    if (kernel_func_name_.empty()) {
      LOG(FATAL) << "Unsupported ElementwiseMul with x_dims:" << x_dims
                 << " y_dims:" << bias_dims;
    }

    VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
    VLOG(4) << "x_dims:" << x_dims;
    VLOG(4) << "bias_dims:" << bias_dims;
    VLOG(4) << "bias_dims.size():" << bias_dims.size();

    auto& context = ctx_->As<OpenCLContext>();
    context.cl_context()->AddKernel(kernel_func_name_,
                                    "image/elementwise_mul_kernel.cl",
                                    build_options_,
                                    time_stamp_);
  }

  // Binds X, Y and Out to the kernel prepared in PrepareForRun(). The
  // channel_mul_* variants also receive the innermost dimension of X. The
  // kernel is then enqueued over the (width, height) image extent of X.
  void Run() override {
    auto& context = ctx_->As<OpenCLContext>();
    CHECK(context.cl_context() != nullptr);

    auto* x = ele_param_->X;
    auto* y = ele_param_->Y;
    auto* out = ele_param_->Out;
    const auto& x_dims = x->dims();
    const auto& bias_dims = y->dims();

#ifdef LITE_WITH_LOG
    VLOG(4) << "x->target():" << TargetToStr(x->target());
    VLOG(4) << "y->target():" << TargetToStr(y->target());
    VLOG(4) << "out->target():" << TargetToStr(out->target());
    VLOG(4) << "x->dims():" << x->dims();
    VLOG(4) << "y->dims():" << y->dims();
    VLOG(4) << "out->dims():" << out->dims();
#endif

    // Only kernel_func_name_ was added to the context. If the current shapes
    // call for a different variant, the arguments below would not match that
    // kernel, or would silently produce the wrong broadcast. Fail loudly.
    const std::string expected_kernel = SelectKernelName(x_dims, bias_dims);
    if (expected_kernel != kernel_func_name_) {
      LOG(FATAL) << "Unsupported ElementwiseMul with x_dims:" << x_dims
                 << " y_dims:" << bias_dims
                 << " (prepared kernel: " << kernel_func_name_ << ")";
    }

    paddle::lite::CLImageConverterDefault default_convertor;
    auto x_img_shape =
        default_convertor.InitImageDimInfoWith(x->dims());  // w, h
    auto x_img_width = x_img_shape[0];
    auto x_img_height = x_img_shape[1];
    auto out_img_shape =
        default_convertor.InitImageDimInfoWith(out->dims());  // w, h

    auto* x_img = x->data<half_t, cl::Image2D>();
    auto* y_img = y->data<half_t, cl::Image2D>();
    auto* out_img = out->mutable_data<half_t, cl::Image2D>(out_img_shape[0],
                                                           out_img_shape[1]);

#ifdef LITE_WITH_LOG
    // Only needed for logging, so it is computed only in logging builds.
    auto y_img_shape = default_convertor.InitImageDimInfoWith(y->dims());
    VLOG(4) << "x_img_shape[w,h]:" << x_img_width << " " << x_img_height;
    VLOG(4) << "y_img_shape[w,h]:" << y_img_shape[0] << " " << y_img_shape[1];
    VLOG(4) << "out_img_shape[w,h]:" << out_img_shape[0] << " "
            << out_img_shape[1];
#endif

    STL::stringstream kernel_key;
    kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
    auto kernel = context.cl_context()->GetKernel(kernel_key.str());

    // Arguments 0-2 are shared by every variant.
    cl_int status = kernel.setArg(0, *x_img);
    CL_CHECK_FATAL(status);
    status = kernel.setArg(1, *y_img);
    CL_CHECK_FATAL(status);
    status = kernel.setArg(2, *out_img);
    CL_CHECK_FATAL(status);

    // All channel_mul_d* variants take the innermost dimension of X as
    // argument 3; the same-shape kernel does not.
    if (kernel_func_name_ != "elementwise_mul") {
      const int tensor_w = x_dims[x_dims.size() - 1];
      status = kernel.setArg(3, tensor_w);
      CL_CHECK_FATAL(status);
    }

    auto global_work_size =
        cl::NDRange{static_cast<cl::size_type>(x_img_width),
                    static_cast<cl::size_type>(x_img_height)};

    status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
        kernel,
        cl::NullRange,
        global_work_size,
        cl::NullRange,
        nullptr,
        nullptr);
    CL_CHECK_FATAL(status);
#ifdef LITE_WITH_LOG
    VLOG(4) << "global_work_size:[2D]:" << x_img_width << " " << x_img_height;
#endif
  }

 protected:
  param_t* ele_param_{nullptr};
  std::string kernel_func_name_{"elementwise_mul"};
  std::string build_options_{"-DCL_DTYPE_half"};
  std::string time_stamp_{GetTimeStamp()};

 private:
  // Maps the shapes of X and Y to a kernel name in
  // image/elementwise_mul_kernel.cl. Returns "" for unsupported combinations.
  // Templated only so the dims type does not have to be named here.
  template <typename DimsT>
  static std::string SelectKernelName(const DimsT& x_dims,
                                      const DimsT& y_dims) {
    if (y_dims == x_dims) {
      return "elementwise_mul";
    }
    switch (y_dims.size()) {
      case 1:
        return "channel_mul_d1";
      case 2:
        return "channel_mul_d2";
      case 3:
        return "channel_mul_d3";
      case 4:
        return "channel_mul_d4";
      default:
        return "";
    }
  }
};

}  // namespace opencl
}  // namespace kernels
}  // namespace lite
}  // namespace paddle

// Short alias used only by the registration below.
namespace ocl = paddle::lite::kernels::opencl;
// Registers ElementwiseMulImageCompute as the "elementwise_mul" kernel for
// target kOpenCL, precision kFP16 and layout kImageDefault, under alias "def".
// Inputs X and Y and output Out are all declared as OpenCL FP16 ImageDefault
// tensors, matching the half_t cl::Image2D accesses in the kernel's Run().
REGISTER_LITE_KERNEL(elementwise_mul,
                     kOpenCL,
                     kFP16,
                     kImageDefault,
                     ocl::ElementwiseMulImageCompute,
                     def)
    .BindInput("X",
               {LiteType::GetTensorTy(TARGET(kOpenCL),
                                      PRECISION(kFP16),
                                      DATALAYOUT(kImageDefault))})
    .BindInput("Y",
               {LiteType::GetTensorTy(TARGET(kOpenCL),
                                      PRECISION(kFP16),
                                      DATALAYOUT(kImageDefault))})
    .BindOutput("Out",
                {LiteType::GetTensorTy(TARGET(kOpenCL),
                                       PRECISION(kFP16),
                                       DATALAYOUT(kImageDefault))})
    .Finalize();