未验证 提交 cb0ce132 编写于 作者: X xiebaiyuan 提交者: GitHub

[LITE][OPENCL]develop basic image depthwiseconv,passed loop test,test… (#2788)

* [LITE][OPENCL]develop basic image depthwiseconv,passed loop test,test=develop

* [LITE][OPENCL]log to vlog(4),test=develop

* [LITE][OPENCL]fix depthwise buffer conv kernel name ,test=develop
上级 69ad4b80
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cl_common.h>
__kernel void depth_conv2d(__private const int global_size_dim0,
__private const int global_size_dim1,
__private const int global_size_dim2,
__read_only image2d_t input,
__read_only image2d_t filter,
#if defined(BIASE_CH) || defined(BIASE_ELE)
__read_only image2d_t bias,
#endif
#ifdef BATCH_NORM
__read_only image2d_t new_scale,
__read_only image2d_t new_biase,
#endif
__write_only image2d_t output_image,
__private const int stride,
__private const int offset,
__private const int input_c,
__private const int dilation,
__private const int input_width, /* of one block */
__private const int input_height, /* of one block */
__private const int output_width,
__private const int output_height,
__private const int filter_width,
__private const int filter_height) {
const int out_c = get_global_id(0);
const int out_w = get_global_id(1);
const int out_nh = get_global_id(2);
int2 output_pos = (int2)(out_c * global_size_dim1 + out_w, out_nh);
const sampler_t sampler =
CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
const int batch_index = out_nh / output_height;
const int out_nh_in_one_batch = out_nh % output_height;
int2 stride_xy = (int2)(stride, stride);
int2 ouput_pos_in_one_block = (int2)(out_w, out_nh_in_one_batch);
int2 in_pos_in_one_block =
ouput_pos_in_one_block * stride_xy + (int2)(offset, offset);
#ifdef BIASE_CH
CL_DTYPE4 output =
READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, (int2)(out_c, 0));
#elif defined(BIASE_ELE)
CL_DTYPE4 output = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, output_pos);
#else
CL_DTYPE4 output = 0.0f;
#endif
int2 pos_in_input_block =
(int2)(out_c * input_width, batch_index * input_height);
int2 pos_in_filter_block =
(int2)(out_c * filter_width, batch_index * filter_height);
int filter_x = pos_in_filter_block.x;
int filter_y = pos_in_filter_block.y;
int input_x_base = pos_in_input_block.x + in_pos_in_one_block.x;
int input_y_base = pos_in_input_block.y + in_pos_in_one_block.y;
int2 align = {filter_width / 2, filter_height / 2};
for (int fy = 0; fy < filter_height; ++fy) {
for (int fx = 0; fx < filter_width; ++fx) {
int x_off = fx - align.x;
int y_off = fy - align.y;
CL_DTYPE4 in = select(
READ_IMG_TYPE(CL_DTYPE_CHAR,
input,
sampler,
(int2)(input_x_base + x_off, input_y_base + y_off)),
(CL_DTYPE4)(0.0f),
(ushort4)((in_pos_in_one_block.x + x_off < 0 ||
in_pos_in_one_block.y + y_off < 0 ||
in_pos_in_one_block.x + x_off >= input_width ||
in_pos_in_one_block.y + y_off >= input_height)
<< 15));
CL_DTYPE4 f = READ_IMG_TYPE(
CL_DTYPE_CHAR, filter, sampler, (int2)(filter_x + fx, filter_y + fy));
output += in * f;
}
}
#ifdef BATCH_NORM
output = output * READ_IMG_TYPE(
CL_DTYPE_CHAR, new_scale, sampler, (int2)(out_c, 0)) +
READ_IMG_TYPE(CL_DTYPE_CHAR, new_biase, sampler, (int2)(out_c, 0));
#endif
#ifdef RELU
output = activation_type4(output);
#endif
WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, output_pos, output);
}
\ No newline at end of file
......@@ -49,6 +49,10 @@ lite_cc_test(test_depthwise_conv2d_opencl SRCS depthwise_conv2d_compute_test.cc
DEPS depthwise_conv2d_opencl op_registry program context cl_image_converter
ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl)
lite_cc_test(test_depthwise_conv2d_basic_opencl SRCS depthwise_conv2d_basic_compute_test.cc
DEPS depthwise_conv2d_opencl op_registry program context cl_image_converter
ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl)
#lite_cc_test(test_conv2d_1x1_opencl SRCS conv2d_1x1_compute_test.cc
# DEPS conv2d_1x1_opencl cl_image_converter op_registry program context
# ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <iostream>
#include <random>
#include "lite/backends/opencl/cl_image_converter.h"
#include "lite/backends/opencl/target_wrapper.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
template <typename T, int STRIDE_H = 1, int STRIDE_W = 1>
void depth_conv(const T* input_data,
const lite::DDim& input_dims,
const T* filter_data,
const lite::DDim& filter_dims,
T* output_data,
const lite::DDim& output_dims) {
int stride_h = STRIDE_H, stride_w = STRIDE_W;
int64_t batches = input_dims[0];
int64_t channels = input_dims[1];
int64_t h = input_dims[2];
int64_t w = input_dims[3];
int64_t num_output = output_dims[1];
int64_t outh = output_dims[2];
int64_t outw = output_dims[3];
int64_t filter_h = filter_dims[2];
int64_t filter_w = filter_dims[3];
const int64_t in_batch_size = channels * h * w;
const int64_t out_batch_size = num_output * outh * outw;
auto kernel_offset = std::unique_ptr<int[]>(new int[filter_h * filter_w]);
{
int p = 0;
int offset = 0;
int gap = w - filter_w;
for (int i = 0; i < filter_h; i++) {
for (int j = 0; j < filter_w; j++) {
kernel_offset[p++] = offset;
offset += 1;
}
offset += gap;
}
}
for (int b = 0; b < batches; b++) {
auto* input_batch_start = input_data + b * in_batch_size;
auto* output_batch_start = output_data + b * out_batch_size;
for (int p = 0; p < num_output; p++) {
float* output_ptr = output_batch_start + p * outh * outw;
const float* filter_ptr = filter_data + p * filter_h * filter_w;
const float* input_ptr = input_batch_start + p * h * w;
for (int i = 0; i < outh; i++) {
for (int j = 0; j < outw; j++) {
float sum = 0;
const float* input_ch_start =
input_ptr + i * stride_h * w + j * stride_w;
for (int fh = 0; fh < filter_h; ++fh) {
for (int fw = 0; fw < filter_w; ++fw) {
float val = input_ch_start[kernel_offset[fh * filter_w + fw]];
float w = filter_ptr[fh * filter_w + fw];
sum += val * w;
}
}
output_ptr[j] = sum;
}
output_ptr += outw;
}
}
}
}
int ConvOutputSize(int input_size,
int filter_size,
int dilation,
int pad_left,
int pad_right,
int stride) {
const int dkernel = dilation * (filter_size - 1) + 1;
int output_size =
(input_size + (pad_left + pad_right) - dkernel) / stride + 1;
return output_size;
}
TEST(depthwise_conv2d_basic, compute) {
// conv infos
// const int ksize = 1;
const int stride = 1;
const int pad = 0;
const int group = 1;
const int dilation = 1;
const int fc = 1;
const int batch_size = 1;
const int bias_flag = false;
const bool relu_flag = false;
// int loop_cnt = 0;
#ifdef LOOP_TEST
// for (int batch_size = 1; batch_size < 2; ++batch_size) {
for (int oc = 4; oc < 10; oc += 1) { // oc = ic
for (int fw = 3; fw < 10; fw += 2) { // fh = fw
for (int ih = fw; ih < 15; ih += 1) { // ih
for (int iw = fw; iw < 15; iw += 1) { // iw
#else
const int oc = 32;
const int ih = 112;
const int iw = 112;
const int fw = 5;
#endif
const int fb = oc;
const int ic = oc;
const int fh = fw;
const int oh = ConvOutputSize(ih, fh, dilation, pad, pad, stride);
const int ow = ConvOutputSize(iw, fw, dilation, pad, pad, stride);
VLOG(4) << "to get kernel ...";
auto kernels =
KernelRegistry::Global().Create("depthwise_conv2d_basic",
TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageDefault));
ASSERT_FALSE(kernels.empty());
auto kernel = std::move(kernels.front());
VLOG(4) << "created depthconv2d kernel";
VLOG(4) << "prepare kernel ------";
lite::Tensor input, filter, bias, output;
operators::ConvParam param;
param.x = &input;
param.filter = &filter;
param.output = &output;
if (bias_flag) {
param.bias = &bias;
}
param.fuse_relu = relu_flag;
std::vector<int> paddings = {pad, pad, pad, pad};
std::vector<int> dilations = {dilation, dilation};
param.paddings = std::make_shared<std::vector<int>>(paddings);
param.dilations = std::make_shared<std::vector<int>>(dilations);
param.strides = std::vector<int>{stride, stride};
std::unique_ptr<KernelContext> context(new KernelContext);
context->As<OpenCLContext>().InitOnce();
std::unique_ptr<KernelContext> depth_conv_context(new KernelContext);
context->As<OpenCLContext>().CopySharedTo(
&(depth_conv_context->As<OpenCLContext>()));
kernel->SetContext(std::move(depth_conv_context));
const DDim& input_dim =
lite::DDim{std::vector<int64_t>({batch_size, ic, ih, iw})};
const DDim& filter_dim =
lite::DDim{std::vector<int64_t>({fb, fc, fh, fw})};
const DDim& out_dim =
lite::DDim{std::vector<int64_t>({batch_size, oc, oh, ow})};
// element wise bias
const DDim& bias_dim = lite::DDim{std::vector<int64_t>({oc})};
param.x->Resize(input_dim);
param.filter->Resize(filter_dim);
param.output->Resize(out_dim);
if (bias_flag) {
param.bias->Resize(bias_dim);
}
kernel->SetParam(param);
size_t input_image_width = iw * ((ic + 3) / 4);
size_t input_image_height = ih * batch_size;
size_t out_image_width = ow * ((oc + 3) / 4);
size_t out_image_height = oh * batch_size;
size_t bias_image_width = ow * ((oc + 3) / 4);
size_t bias_image_height = oh * batch_size;
size_t filter_image_width = fw * ((fb + 3) / 4);
size_t filter_image_height = fc * fh;
const size_t cl_image2d_row_pitch{0};
const size_t cl_image2d_slice_pitch{0};
std::default_random_engine engine;
std::uniform_real_distribution<float> gen(-5, 5);
std::vector<float> input_v(batch_size * ic * ih * iw);
std::vector<float> filter_v(fb * fc * fh * fw);
std::vector<float> output_v(batch_size * oc * ih * iw);
std::vector<float> bias_v(oc);
VLOG(4) << "gen input and filter ...";
for (auto& i : input_v) {
i = gen(engine);
}
for (auto& f : filter_v) {
f = gen(engine);
}
VLOG(4) << "after gen input and filter ...";
VLOG(4) << "input_v.size(): " << input_v.size();
VLOG(4) << "filter_v.size(): " << filter_v.size();
VLOG(4) << "output_v.size(): " << output_v.size();
VLOG(4) << "bias_v.size(): " << bias_v.size();
VLOG(4) << "input_dim.production(): " << input_dim.production();
VLOG(4) << "filter_dim.production(): " << filter_dim.production();
VLOG(4) << "out_dim.production(): " << out_dim.production();
VLOG(4) << "bias_dim.production(): " << bias_dim.production();
VLOG(4) << "4 * input_image_height * input_image_width: "
<< 4 * input_image_height * input_image_width;
VLOG(4) << "4 * filter_image_width * filter_image_height: "
<< 4 * filter_image_width * filter_image_height;
CHECK(input_dim.production() == input_v.size());
CHECK_LE(input_dim.production(),
4 * input_image_height * input_image_width);
CHECK(filter_dim.production() == filter_v.size());
CHECK_LE(filter_dim.production(),
4 * filter_image_width * filter_image_height);
paddle::lite::CLImageConverterDefault default_convertor;
VLOG(4) << "set mapped input ...";
std::vector<float> x_image_v(input_image_width * input_image_height *
4); // 4 : RGBA
std::vector<float> filter_image_v(
filter_image_width * filter_image_height * 4); // 4 : RGBA
std::vector<float> bias_image_v(bias_image_width * bias_image_height *
4); // 4 : RGBA
std::vector<float> out_image_v(out_image_width * out_image_height *
4); // 4 : RGBA
default_convertor.NCHWToImage(
input_v.data(), x_image_v.data(), input_dim);
VLOG(4) << "set mapped filter ...";
paddle::lite::CLImageConverterNWBlock nw_convertor;
nw_convertor.NCHWToImage(
filter_v.data(), filter_image_v.data(), filter_dim);
auto* input_image2d = input.mutable_data<float, cl::Image2D>(
input_image_width, input_image_height, x_image_v.data());
auto* filter_image2d = filter.mutable_data<float, cl::Image2D>(
filter_image_width, filter_image_height, filter_image_v.data());
if (bias_flag) {
nw_convertor.NCHWToImage(
filter_v.data(), filter_image_v.data(), filter_dim);
for (int i = 0; i < bias_dim.production(); ++i) {
bias_v[i] = static_cast<int>(gen(engine));
}
CLImageConverterFolder folder_convertor;
folder_convertor.NCHWToImage(
bias_v.data(), bias_image_v.data(), bias_dim);
auto* bias_data = bias.mutable_data<float, cl::Image2D>(
bias_image_width, bias_image_height, bias_image_v.data());
}
VLOG(4) << "resize output ...";
output.Resize(out_dim);
// cpu conv basic calc
lite::Tensor out_ref;
out_ref.Resize(out_dim);
VLOG(4) << "prepare kernel ready";
VLOG(4) << "kernel launch ...";
kernel->Launch();
VLOG(4) << "mutable output ...";
auto* output_image2d = output.mutable_data<float, cl::Image2D>(
out_image_width, out_image_height);
auto* wait_list = context->As<OpenCLContext>().cl_wait_list();
auto* out_ptr = param.output->data<float, cl::Image2D>();
auto it = wait_list->find(out_ptr);
if (it != wait_list->end()) {
VLOG(4) << "--- Find the sync event for the target cl "
"tensor. ---";
auto& event = *(it->second);
event.wait();
} else {
LOG(FATAL) << "Could not find the sync event for the target "
"cl tensor.";
}
TargetWrapperCL::ImgcpySync(out_image_v.data(),
output.data<float, cl::Image2D>(),
out_image_width,
out_image_height,
cl_image2d_row_pitch,
cl_image2d_slice_pitch,
IoDirection::DtoH);
DDim out_image_shape =
default_convertor.InitImageDimInfoWith(output.dims());
default_convertor.ImageToNCHW(out_image_v.data(),
output_v.data(),
out_image_shape,
output.dims());
// for (int j = 0; j < input_v.size(); j += 1) {
// VLOG(4) << "input_v input[" << j
// << "]: " << input_v.data()[j];
// std::cout<< j << " " << input_v.data()[j] << std::endl;
// }
// std::cout << std::endl;
// for (int j = 0; j < output_v.size(); j += 1) {
// VLOG(4) << "output_v output_v[" << j
// << "]:" << output_v.data()[j];
// std::cout << j << " " << output_v.data()[j] <<
// std::endl;
// }
VLOG(4) << "mutable_data out_ref_data: ";
// run cpu ref
auto* out_ref_data = out_ref.mutable_data<float>(TARGET(kARM));
VLOG(4) << " conv_basic beigin ..... ";
depth_conv<float, 1, 1>(input_v.data(),
input.dims(),
filter_v.data(),
filter.dims(),
out_ref_data,
out_dim);
VLOG(4) << " conv_basic end ..... ";
VLOG(4) << " input_dim: " << input_dim;
VLOG(4) << " filter_dim: " << filter_dim;
const DDim& out_image_dims = lite::DDim{
std::vector<int64_t>({static_cast<int64_t>(out_image_width),
static_cast<int64_t>(out_image_height)})};
for (int i = 0; i < out_dim.production(); i++) {
EXPECT_NEAR(output_v[i], out_ref_data[i], 1e-2);
if (abs(output_v[i] - out_ref_data[i]) > 1e-2) {
LOG(FATAL) << "error idx:" << i;
}
}
#ifdef LOOP_TEST
}
}
}
}
#else
// nothing to do.
#endif
}
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(
depthwise_conv2d_basic, kOpenCL, kFloat, kImageDefault, image2d);
......@@ -13,6 +13,7 @@
// limitations under the License.
#include <vector>
#include "lite/backends/opencl/cl_include.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
......@@ -114,7 +115,7 @@ class DepthwiseConv2dCompute
}
private:
std::string kernel_func_name_{"depthwise_conv2d_3x3"};
std::string kernel_func_name_{"depthwise_conv2d"};
std::string build_options_{"-DCL_DTYPE=float"};
std::shared_ptr<cl::Event> event_{new cl::Event};
};
......@@ -341,6 +342,189 @@ class DepthwiseConv2d3x3s1ComputeFP16Image
std::shared_ptr<cl::Event> event_{new cl::Event};
};
class DepthwiseConv2dBasicComputeFP32Image
: public KernelLite<TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageDefault)> {
public:
using param_t = operators::ConvParam;
std::string doc() const override {
return "DepthwiseConv2d basic using cl::Image2D/kImageDefault, kFloat32";
}
void PrepareForRun() override {
const auto& param = *param_.get_mutable<param_t>();
const bool has_bias = param.bias != nullptr;
const bool is_element_wise_bias =
has_bias && param.output->dims() == param.bias->dims();
if (param.fuse_relu) {
build_options_ += " -DRELU";
}
if (has_bias) {
build_options_ += is_element_wise_bias ? " -DBIASE_ELE" : " -DBIASE_CH";
}
auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel(kernel_func_name_,
"image/depthwise_conv2d_basic_kernel.cl",
build_options_);
}
void Run() override {
const auto& param = *param_.get_mutable<param_t>();
auto input_dims = param.x->dims();
auto paddings = *param.paddings;
auto strides = param.strides;
auto* input_image = param.x->data<float, cl::Image2D>();
auto* filter_image = param.filter->data<float, cl::Image2D>();
auto filter_dims = param.filter->dims();
auto output_dims = param.output->dims();
int input_width = input_dims[3];
int input_height = input_dims[2];
int output_width = output_dims[3];
int output_height = output_dims[2];
int filter_width = filter_dims[3];
int filter_height = filter_dims[2];
auto out_image_shape = InitImageDimInfoWith(output_dims);
auto* out_image = param.output->mutable_data<float, cl::Image2D>(
out_image_shape["width"], out_image_shape["height"]);
const bool has_bias = param.bias != nullptr;
const bool is_element_wise_bias =
has_bias && param.output->dims() == param.bias->dims();
int offset = static_cast<int>(param.filter->dims()[2]) / 2 -
static_cast<int>(paddings[0]);
// calc input_c_block
auto input_image_shape = InitImageDimInfoWith(input_dims);
int input_c_block = input_image_shape["width"] / input_dims[3];
int input_c = input_dims[1];
auto dilations = *param.dilations;
const std::vector<size_t>& default_work_size =
DefaultWorkSize(output_dims,
DDim(std::vector<DDim::value_type>{
static_cast<int64_t>(out_image_shape["width"]),
static_cast<int64_t>(out_image_shape["height"])}));
int c_block = default_work_size[0];
int w = default_work_size[1];
int nh = default_work_size[2];
VLOG(4) << "============ depthwise conv2d params ============";
VLOG(4) << "input_image_shape: " << input_image_shape["width"] << ","
<< input_image_shape["height"];
VLOG(4) << "input_c_block: " << input_c_block;
VLOG(4) << "input_c: " << input_c;
VLOG(4) << "input_image: " << input_image;
VLOG(4) << "filter_dims: " << filter_dims;
VLOG(4) << "filter_image: " << filter_image;
VLOG(4) << "output_dims: " << output_dims;
VLOG(4) << "out_image_shape: " << out_image_shape["width"] << ", "
<< out_image_shape["height"];
VLOG(4) << "paddings: " << paddings[0] << "," << paddings[1];
VLOG(4) << "has bias: " << has_bias;
VLOG(4) << "is_element_wise_bias : " << is_element_wise_bias;
VLOG(4) << "strides: " << strides[0] << "," << strides[1];
VLOG(4) << "offset: " << offset;
VLOG(4) << "dilations.size : " << dilations.size();
VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1];
VLOG(4) << "default work size{c_block, w, nh}: "
<< "{" << c_block << ", " << w << ", " << nh << ""
<< "}";
CHECK_GE(dilations.size(), 2);
CHECK(dilations[0] == dilations[1]);
CHECK_GE(input_dims.size(), 4);
CHECK_GE(paddings.size(), 2);
CHECK(paddings[0] == paddings[1]);
CHECK_GE(strides.size(), 2);
CHECK(strides[0] == strides[1]);
// handle bias use buffer for channel wise , use image for element wise
const cl::Buffer* bias_buf = nullptr;
const cl::Image2D* bias_image = nullptr;
if (has_bias) {
bias_image = param.bias->data<float, cl::Image2D>();
}
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_;
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
VLOG(4) << "kernel_key: " << kernel_key.str();
VLOG(4) << "kernel ready ... " << kernel_key.str();
VLOG(4) << "w: " << w;
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, c_block);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, w);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, nh);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *filter_image);
CL_CHECK_FATAL(status);
if (has_bias) {
VLOG(4) << "set bias_image: ";
status = kernel.setArg(++arg_idx, *bias_image);
CL_CHECK_FATAL(status);
}
status = kernel.setArg(++arg_idx, *out_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, strides[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, offset);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_c_block);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, dilations[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_height);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_height);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, filter_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, filter_height);
CL_CHECK_FATAL(status);
auto global_work_size =
cl::NDRange{static_cast<size_t>(default_work_size.data()[0]),
static_cast<size_t>(default_work_size.data()[1]),
static_cast<size_t>(default_work_size.data()[2])};
VLOG(4) << "out_image: " << out_image;
VLOG(4) << "global_work_size[3D]: {" << global_work_size[0] << ","
<< global_work_size[1] << "," << global_work_size[2] << "}";
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
event_.get());
CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_image, event_);
}
private:
std::string kernel_func_name_{"depth_conv2d"};
std::string build_options_{"-DCL_DTYPE_float"};
std::shared_ptr<cl::Event> event_{new cl::Event};
};
} // namespace opencl
} // namespace kernels
} // namespace lite
......@@ -382,3 +566,27 @@ REGISTER_LITE_KERNEL(
PRECISION(kFP16),
DATALAYOUT(kImageDefault))})
.Finalize();
REGISTER_LITE_KERNEL(
depthwise_conv2d_basic,
kOpenCL,
kFloat,
kImageDefault,
paddle::lite::kernels::opencl::DepthwiseConv2dBasicComputeFP32Image,
image2d)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageDefault))})
.BindInput("Bias",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageDefault))})
.BindInput("Filter",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageNW))})
.BindOutput("Output",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageDefault))})
.Finalize();
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册