// elementwise_mul_image_compute.cc
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <memory>
#include <string>

#include "lite/backends/opencl/cl_half.h"
#include "lite/backends/opencl/cl_image_converter.h"
#include "lite/backends/opencl/cl_include.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/opencl/image_helper.h"
#include "lite/operators/op_params.h"
#include "lite/utils/logging.h"
#include "lite/utils/replace_stl/stream.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace opencl {

// OpenCL FP16 image kernel for elementwise_mul (Out = X * Y).
//
// The concrete OpenCL kernel is chosen in PrepareForRun() from the shapes of
// X and Y:
//   * identical shapes      -> "elementwise_mul"
//   * Y of rank 1, 2, 3, 4  -> "channel_mul_d1" .. "channel_mul_d4"
// Any other combination is rejected with LOG(FATAL).
// All variants are loaded from "image/elementwise_mul_kernel.cl" and built
// with "-DCL_DTYPE_half".
class ElementwiseMulImageCompute
    : public KernelLite<TARGET(kOpenCL),
                        PRECISION(kFP16),
                        DATALAYOUT(kImageDefault)> {
 public:
  using param_t = operators::ElementwiseParam;

  // The kernel is registered as kFP16 and operates on half_t images.
  std::string doc() const override {
    return "ElementwiseMul using cl::Image2D(ImageDefault/RGBA), kFP16";
  }

  // Picks the OpenCL kernel variant from the X/Y shapes and compiles it.
  // Aborts (LOG(FATAL)) when Y's rank is outside 1..4 and its shape differs
  // from X's.
  void PrepareForRun() override {
    ele_param_ = param_.get_mutable<param_t>();
    auto* y = ele_param_->Y;
    auto* x = ele_param_->X;
    const auto bias_dims = y->dims();
    const auto x_dims = x->dims();

    if (bias_dims == x_dims) {
      kernel_func_name_ = "elementwise_mul";
    } else {
      const int bias_dim_size = static_cast<int>(bias_dims.size());
      switch (bias_dim_size) {
        case 1:
          kernel_func_name_ = "channel_mul_d1";
          break;
        case 2:
          kernel_func_name_ = "channel_mul_d2";
          break;
        case 3:
          kernel_func_name_ = "channel_mul_d3";
          break;
        case 4:
          kernel_func_name_ = "channel_mul_d4";
          break;
        default:
          LOG(FATAL) << "Unsupported ElementwiseMul with x_dims:" << x_dims
                     << " y_dims:" << bias_dims;
      }
    }

    VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
    VLOG(4) << "x_dims:" << x_dims;
    VLOG(4) << "bias_dims:" << bias_dims;
    VLOG(4) << "bias_dims.size():" << bias_dims.size();

    auto& context = ctx_->As<OpenCLContext>();
    kernel_ =
        context.cl_context()->CreateKernel(kernel_func_name_,
                                           "image/elementwise_mul_kernel.cl",
                                           build_options_,
                                           time_stamp_);
  }

  // Binds the X/Y/Out images (and, for the channel_mul_d* variants, the
  // innermost extent of X) and enqueues the kernel over X's 2D image extent.
  // The resulting event is recorded in the context wait list for Out.
  void Run() override {
    auto& context = ctx_->As<OpenCLContext>();
    CHECK(context.cl_context() != nullptr);

    auto* x = ele_param_->X;
    auto* y = ele_param_->Y;
    auto* out = ele_param_->Out;

#ifndef LITE_SHUTDOWN_LOG
    VLOG(4) << "x->target():" << TargetToStr(x->target());
    VLOG(4) << "y->target():" << TargetToStr(y->target());
    VLOG(4) << "out->target():" << TargetToStr(out->target());
    VLOG(4) << "x->dims():" << x->dims();
    VLOG(4) << "y->dims():" << y->dims();
    VLOG(4) << "out->dims():" << out->dims();
#endif

    paddle::lite::CLImageConverterDefault default_convertor;
    auto x_img_shape =
        default_convertor.InitImageDimInfoWith(x->dims());  // w, h
    auto x_img_width = x_img_shape[0];
    auto x_img_height = x_img_shape[1];
    auto out_img_shape =
        default_convertor.InitImageDimInfoWith(out->dims());  // w, h
    auto y_img_shape = default_convertor.InitImageDimInfoWith(y->dims());

    auto* x_img = x->data<half_t, cl::Image2D>();
    auto* y_img = y->data<half_t, cl::Image2D>();
    auto* out_img = out->mutable_data<half_t, cl::Image2D>(out_img_shape[0],
                                                           out_img_shape[1]);

#ifndef LITE_SHUTDOWN_LOG
    VLOG(4) << "x_img_shape[w,h]:" << x_img_width << " " << x_img_height;
    VLOG(4) << "y_img_shape[w,h]:" << y_img_shape[0] << " " << y_img_shape[1];
    VLOG(4) << "out_img_shape[w,h]:" << out_img_shape[0] << " "
            << out_img_shape[1];
#endif

    const auto bias_dims = y->dims();
    const auto x_dims = x->dims();
    const bool same_shape = (bias_dims == x_dims);
    const int bias_dim_size = static_cast<int>(bias_dims.size());

    // Reject unsupported shapes before touching any kernel argument.
    if (!same_shape && (bias_dim_size < 1 || bias_dim_size > 4)) {
      LOG(FATAL) << "Unsupported ElementwiseMul with x_dims:" << x_dims
                 << " y_dims:" << bias_dims;
    }

    // Every kernel variant takes (x, y, out) as its first three arguments.
    cl_int status = kernel_->setArg(0, *x_img);
    CL_CHECK_FATAL(status);
    status = kernel_->setArg(1, *y_img);
    CL_CHECK_FATAL(status);
    status = kernel_->setArg(2, *out_img);
    CL_CHECK_FATAL(status);

    // The channel_mul_d* variants additionally take the innermost extent of
    // X; the previous code repeated this identical setup once per rank.
    if (!same_shape) {
      CHECK(x_dims.size() > 0) << "ElementwiseMul expects a non-scalar X";
      const int tensor_w = static_cast<int>(x_dims[x_dims.size() - 1]);
      status = kernel_->setArg(3, tensor_w);
      CL_CHECK_FATAL(status);
    }

    // NOTE(review): the launch grid is taken from X's image extent, which
    // assumes Out has the same image shape as X -- confirm for broadcasts.
    auto global_work_size =
        cl::NDRange{static_cast<cl::size_type>(x_img_width),
                    static_cast<cl::size_type>(x_img_height)};
    status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
        *(kernel_.get()),
        cl::NullRange,
        global_work_size,
        cl::NullRange,
        nullptr,
        event_.get());
    CL_CHECK_FATAL(status);
    context.cl_wait_list()->emplace(out_img, event_);

#ifndef LITE_SHUTDOWN_LOG
    VLOG(4) << "global_work_size:[2D]:" << x_img_width << " " << x_img_height;
#endif
  }

 protected:
  param_t* ele_param_{nullptr};
  std::string kernel_func_name_{"elementwise_mul"};
  std::string build_options_{"-DCL_DTYPE_half"};
  std::string time_stamp_{GetTimeStamp()};
  std::shared_ptr<cl::Event> event_{std::make_shared<cl::Event>()};
  std::shared_ptr<cl::Kernel> kernel_;
};

}  // namespace opencl
}  // namespace kernels
}  // namespace lite
}  // namespace paddle

// Short alias so the registration below stays readable.
namespace ocl = paddle::lite::kernels::opencl;
// Registers ElementwiseMulImageCompute as the "def" elementwise_mul kernel for
// OpenCL / FP16 / ImageDefault. X, Y and Out are all bound to OpenCL tensors
// with FP16 precision and ImageDefault layout, matching the half_t
// cl::Image2D accesses in Run().
REGISTER_LITE_KERNEL(elementwise_mul,
                     kOpenCL,
                     kFP16,
                     kImageDefault,
                     ocl::ElementwiseMulImageCompute,
                     def)
    .BindInput("X",
               {LiteType::GetTensorTy(TARGET(kOpenCL),
                                      PRECISION(kFP16),
                                      DATALAYOUT(kImageDefault))})
    .BindInput("Y",
               {LiteType::GetTensorTy(TARGET(kOpenCL),
                                      PRECISION(kFP16),
                                      DATALAYOUT(kImageDefault))})
    .BindOutput("Out",
                {LiteType::GetTensorTy(TARGET(kOpenCL),
                                       PRECISION(kFP16),
                                       DATALAYOUT(kImageDefault))})
    .Finalize();