elementwise_mul_image_compute.cc 8.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <memory>
#include <string>
17
#include "lite/backends/opencl/cl_half.h"
18 19 20 21 22 23 24 25
#include "lite/backends/opencl/cl_image_converter.h"
#include "lite/backends/opencl/cl_include.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/opencl/image_helper.h"
#include "lite/operators/op_params.h"
#include "lite/utils/logging.h"
#include "lite/utils/replace_stl/stream.h"
26 27 28 29
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/profiler.h"
#endif
#include "lite/backends/opencl/cl_utility.h"
30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50

namespace paddle {
namespace lite {
namespace kernels {
namespace opencl {

// OpenCL implementation of `elementwise_mul` for FP16 tensors stored as
// cl::Image2D in the default (RGBA) image layout.
//
// One of the following kernels from "image/elementwise_mul_kernel.cl" is used:
//   * "elementwise_mul"  when X and Y have identical dims
//                        (args: x, y, out);
//   * "channel_mul_dN"   when Y's dims differ from X's and Y has rank N in 1..4
//                        (args: x, y, out, tensor_w = innermost extent of X).
// The kernel is chosen once in PrepareForRun(). Run() binds arguments for
// exactly that kernel, so the argument count always matches the kernel that
// was built.
class ElementwiseMulImageCompute
    : public KernelLite<TARGET(kOpenCL),
                        PRECISION(kFP16),
                        DATALAYOUT(kImageDefault)> {
 public:
  using param_t = operators::ElementwiseParam;

  std::string doc() const override {
    return "ElementwiseMul using cl::Image2D(ImageDefault/RGBA), kFP16";
  }

  // Selects the OpenCL kernel from the dims of X and Y and registers it with
  // the OpenCL context. Aborts via LOG(FATAL) when the dims differ and Y's
  // rank is not in 1..4.
  void PrepareForRun() override {
    ele_param_ = param_.get_mutable<param_t>();
    auto* y = ele_param_->Y;
    auto* x = ele_param_->X;
    auto bias_dims = y->dims();
    auto x_dims = x->dims();

    if (bias_dims == x_dims) {
      kernel_func_name_ = "elementwise_mul";
      pass_tensor_w_ = false;
    } else {
      const int bias_dim_size = bias_dims.size();
      if (bias_dim_size == 1) {
        kernel_func_name_ = "channel_mul_d1";
      } else if (bias_dim_size == 2) {
        kernel_func_name_ = "channel_mul_d2";
      } else if (bias_dim_size == 3) {
        kernel_func_name_ = "channel_mul_d3";
      } else if (bias_dim_size == 4) {
        kernel_func_name_ = "channel_mul_d4";
      } else {
        LOG(FATAL) << "Unsupported ElementwiseMul with x_dims:" << x_dims
                   << " y_dims:" << bias_dims;
      }
      // Every channel_mul_d* kernel takes the innermost extent of X as arg 3.
      pass_tensor_w_ = true;
    }

    VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
    VLOG(4) << "x_dims:" << x_dims;
    VLOG(4) << "bias_dims:" << bias_dims;
    VLOG(4) << "bias_dims.size():" << bias_dims.size();

    auto& context = ctx_->As<OpenCLContext>();
    context.cl_context()->AddKernel(kernel_func_name_,
                                    "image/elementwise_mul_kernel.cl",
                                    build_options_,
                                    time_stamp_);
  }

  // Binds X, Y and Out (plus tensor_w for channel_mul_d*) to the kernel
  // selected in PrepareForRun() and enqueues it over a 2-D range equal to
  // X's image (width, height).
  void Run() override {
    auto& context = ctx_->As<OpenCLContext>();
    CHECK(context.cl_context() != nullptr);

    auto* x = ele_param_->X;
    auto* y = ele_param_->Y;
    auto* out = ele_param_->Out;

#ifdef LITE_WITH_LOG
    VLOG(4) << "x->target():" << TargetToStr(x->target());
    VLOG(4) << "y->target():" << TargetToStr(y->target());
    VLOG(4) << "out->target():" << TargetToStr(out->target());
    VLOG(4) << "x->dims():" << x->dims();
    VLOG(4) << "y->dims():" << y->dims();
    VLOG(4) << "out->dims():" << out->dims();
#endif

    paddle::lite::CLImageConverterDefault default_convertor;
    auto x_img_shape =
        default_convertor.InitImageDimInfoWith(x->dims());  // w, h
    auto x_img_width = x_img_shape[0];
    auto x_img_height = x_img_shape[1];
    auto out_img_shape =
        default_convertor.InitImageDimInfoWith(out->dims());  // w, h

    auto* x_img = x->data<half_t, cl::Image2D>();
    auto* y_img = y->data<half_t, cl::Image2D>();
    auto* out_img = out->mutable_data<half_t, cl::Image2D>(out_img_shape[0],
                                                           out_img_shape[1]);

#ifdef LITE_WITH_LOG
    // Y's image shape is only needed for diagnostics.
    auto y_img_shape = default_convertor.InitImageDimInfoWith(y->dims());
    VLOG(4) << "x_img_shape[w,h]:" << x_img_width << " " << x_img_height;
    VLOG(4) << "y_img_shape[w,h]:" << y_img_shape[0] << " " << y_img_shape[1];
    VLOG(4) << "out_img_shape[w,h]:" << out_img_shape[0] << " "
            << out_img_shape[1];
#endif

    STL::stringstream kernel_key;
    kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
    auto kernel = context.cl_context()->GetKernel(kernel_key.str());

    // Arguments shared by every kernel variant.
    cl_int status = kernel.setArg(0, *x_img);
    CL_CHECK_FATAL(status);
    status = kernel.setArg(1, *y_img);
    CL_CHECK_FATAL(status);
    status = kernel.setArg(2, *out_img);
    CL_CHECK_FATAL(status);

    // Decided from the kernel actually built in PrepareForRun(), not from
    // re-inspecting dims, so the argument list always matches the kernel.
    if (pass_tensor_w_) {
      auto x_dims = x->dims();
      // Guard the `size() - 1` index below against a rank-0 X.
      CHECK(x_dims.size() > 0) << "ElementwiseMul expects X with rank >= 1";
      const int tensor_w = static_cast<int>(x_dims[x_dims.size() - 1]);
      status = kernel.setArg(3, tensor_w);
      CL_CHECK_FATAL(status);
    }

    auto global_work_size =
        cl::NDRange{static_cast<cl::size_type>(x_img_width),
                    static_cast<cl::size_type>(x_img_height)};

    status = EnqueueNDRangeKernel(context,
                                  kernel,
                                  cl::NullRange,
                                  global_work_size,
                                  cl::NullRange,
                                  nullptr,
                                  event_);
    CL_CHECK_FATAL(status);
#ifdef LITE_WITH_LOG
    VLOG(4) << "global_work_size:[2D]:" << x_img_width << " " << x_img_height;
#endif
  }

 protected:
  param_t* ele_param_{nullptr};
  std::string kernel_func_name_{"elementwise_mul"};
  // True when the selected kernel is a channel_mul_d* variant, which takes
  // X's innermost extent (tensor_w) as kernel argument 3.
  bool pass_tensor_w_{false};
  std::string build_options_{"-DCL_DTYPE_half"};
  std::string time_stamp_{GetTimeStamp()};
};

}  // namespace opencl
}  // namespace kernels
}  // namespace lite
}  // namespace paddle

// Short alias so the registration below can name the kernel class concisely.
namespace ocl = paddle::lite::kernels::opencl;
// Registers ElementwiseMulImageCompute under the alias "def" as the
// `elementwise_mul` kernel for target kOpenCL, precision kFP16, layout
// kImageDefault. Inputs X and Y and output Out are all declared as
// OpenCL / FP16 / ImageDefault tensors, matching the KernelLite template
// arguments of the class.
REGISTER_LITE_KERNEL(elementwise_mul,
                     kOpenCL,
                     kFP16,
                     kImageDefault,
                     ocl::ElementwiseMulImageCompute,
                     def)
    .BindInput("X",
               {LiteType::GetTensorTy(TARGET(kOpenCL),
                                      PRECISION(kFP16),
                                      DATALAYOUT(kImageDefault))})
    .BindInput("Y",
               {LiteType::GetTensorTy(TARGET(kOpenCL),
                                      PRECISION(kFP16),
                                      DATALAYOUT(kImageDefault))})
    .BindOutput("Out",
                {LiteType::GetTensorTy(TARGET(kOpenCL),
                                       PRECISION(kFP16),
                                       DATALAYOUT(kImageDefault))})
    .Finalize();