提交 813123f5 编写于 作者: J Jiaying Zhao 提交者: GitHub

[Lite][OpenCL]Add conv2d_3x3_opt_kernel. test=develop (#3170)

上级 c1bfa65c
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cl_common.h>
__kernel void conv2d_3x3(__private const int item_ch,
__private const int item_w,
__private const int item_h,
__read_only image2d_t input_image,
__read_only image2d_t filter_image,
#if defined(BIASE_CH) || defined(BIASE_ELE)
__read_only image2d_t bias,
#endif
__write_only image2d_t output_image,
__private const int stride,
__private const int pad,
__private const int dilation,
__private const int in_ch,
__private const int in_w,
__private const int in_h,
__private const int out_w,
__private const int out_h) {
const sampler_t sampler =
CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
// item_id
const int item_ch_id = get_global_id(0);
const int item_w_id = get_global_id(1);
const int item_h_id = get_global_id(2);
// out_width_id_per_blk and out_batch_id
int out_batch_id = item_h_id / in_h;
int out_w_base_id = item_ch_id * out_w;
int out_w_id0 = item_w_id;
int out_w_id1 = out_w_id0 + item_w;
int out_w_id2 = out_w_id1 + item_w;
int out_w_id3 = out_w_id2 + item_w;
int out_w_id4 = out_w_id3 + item_w;
// in_width_id_per_blk and in_height_id_per_batch
int in_h_id = (item_h_id % out_h) * stride - pad;
int in_w_id0 = item_w_id * stride - pad;
int in_w_id1 = in_w_id0 + item_w * stride;
int in_w_id2 = in_w_id1 + item_w * stride;
int in_w_id3 = in_w_id2 + item_w * stride;
int in_w_id4 = in_w_id3 + item_w * stride;
#ifdef BIASE_CH
CL_DTYPE4 output[5];
output[0] = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, (int2)(item_ch_id, 0));
output[1] = output[0];
output[2] = output[0];
output[3] = output[0];
output[4] = output[0];
#elif defined(BIASE_ELE)
CL_DTYPE4 output[5];
output[0] =
READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, (int2)(out_w_base_id + out_w_id0, item_h_id));
if (out_w_id1 < out_w) {
output[1] = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler,
(int2)(out_w_base_id + out_w_id1, item_h_id));
}
if (out_w_id2 < out_w) {
output[2] = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler,
(int2)(out_w_base_id + out_w_id2, item_h_id));
}
if (out_w_id3 < out_w) {
output[3] = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler,
(int2)(out_w_base_id + out_w_id3, item_h_id));
}
if (out_w_id4 < out_w) {
output[4] = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler,
(int2)(out_w_base_id + out_w_id4, item_h_id));
}
#else
CL_DTYPE4 output[5] = {0.0f};
#endif
CL_DTYPE4 filter[4] = {0.0f};
CL_DTYPE4 filter_trans[4] = {0.0f};
CL_DTYPE4 input[5] = {0.0f};
int filter_h_val0 = item_ch_id * 4 * 3;
int filter_h_val1 = filter_h_val0 + 3;
int filter_h_val2 = filter_h_val1 + 3;
int filter_h_val3 = filter_h_val2 + 3;
for (int ch = 0; ch < (in_ch + 3) / 4; ch++) {
int ch_surplus = (ch + 1) * 4 - in_ch > 0 ? (ch + 1) * 4 - in_ch : 0;
const int in_w_base_id = mul24(ch, in_w);
int filter_w_val = ch * 3;
for (int h = 0; h < 3; h++) {
int in_h_val = select(out_batch_id * in_h + in_h_id + h, -1,
(out_batch_id * in_h + in_h_id + h < 0 ||
out_batch_id * in_h + in_h_id + h >= in_h));
for (int w = 0; w < 3; w++) {
int in_w_val0 = select(in_w_base_id + in_w_id0 + w, -1,
(in_w_id0 + w < 0 || in_w_id0 + w >= in_w));
int in_w_val1 = select(in_w_base_id + in_w_id1 + w, -1,
(in_w_id1 + w < 0 || in_w_id1 + w >= in_w));
int in_w_val2 = select(in_w_base_id + in_w_id2 + w, -1,
(in_w_id2 + w < 0 || in_w_id2 + w >= in_w));
int in_w_val3 = select(in_w_base_id + in_w_id3 + w, -1,
(in_w_id3 + w < 0 || in_w_id3 + w >= in_w));
int in_w_val4 = select(in_w_base_id + in_w_id4 + w, -1,
(in_w_id4 + w < 0 || in_w_id4 + w >= in_w));
filter[0] = READ_IMG_TYPE(CL_DTYPE_CHAR,
filter_image, sampler,
(int2)(filter_w_val + w, filter_h_val0 + h)); // in_ch:0-3,out_ch:0
filter[1] = READ_IMG_TYPE(CL_DTYPE_CHAR,
filter_image, sampler,
(int2)(filter_w_val + w, filter_h_val1 + h)); // in_ch:0-3,out_ch:1
filter[2] = READ_IMG_TYPE(CL_DTYPE_CHAR,
filter_image, sampler,
(int2)(filter_w_val + w, filter_h_val2 + h)); // in_ch:0-3,out_ch:2
filter[3] = READ_IMG_TYPE(CL_DTYPE_CHAR,
filter_image, sampler,
(int2)(filter_w_val + w, filter_h_val3 + h)); // in_ch:0-3,out_ch:3
filter_trans[0] = (CL_DTYPE4)(filter[0].x, filter[1].x, filter[2].x,
filter[3].x); // in_ch:0,out_ch:0-3
filter_trans[1] = (CL_DTYPE4)(filter[0].y, filter[1].y, filter[2].y,
filter[3].y); // in_ch:1,out_ch:0-3
filter_trans[2] = (CL_DTYPE4)(filter[0].z, filter[1].z, filter[2].z,
filter[3].z); // in_ch:2,out_ch:0-3
filter_trans[3] = (CL_DTYPE4)(filter[0].w, filter[1].w, filter[2].w,
filter[3].w); // in_ch:3,out_ch:0-3
input[0] =
READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val0, in_h_val));
input[1] =
READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val1, in_h_val));
input[2] =
READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val2, in_h_val));
input[3] =
READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val3, in_h_val));
input[4] =
READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, (int2)(in_w_val4, in_h_val));
output[0] = mad(input[0].x, filter_trans[0], output[0]);
output[1] = mad(input[1].x, filter_trans[0], output[1]);
output[2] = mad(input[2].x, filter_trans[0], output[2]);
output[3] = mad(input[3].x, filter_trans[0], output[3]);
output[4] = mad(input[4].x, filter_trans[0], output[4]);
if (ch_surplus < 3) {
output[0] = mad(input[0].y, filter_trans[1], output[0]);
output[1] = mad(input[1].y, filter_trans[1], output[1]);
output[2] = mad(input[2].y, filter_trans[1], output[2]);
output[3] = mad(input[3].y, filter_trans[1], output[3]);
output[4] = mad(input[4].y, filter_trans[1], output[4]);
}
if (ch_surplus < 2) {
output[0] = mad(input[0].z, filter_trans[2], output[0]);
output[1] = mad(input[1].z, filter_trans[2], output[1]);
output[2] = mad(input[2].z, filter_trans[2], output[2]);
output[3] = mad(input[3].z, filter_trans[2], output[3]);
output[4] = mad(input[4].z, filter_trans[2], output[4]);
}
if (ch_surplus < 1) {
output[0] = mad(input[0].w, filter_trans[3], output[0]);
output[1] = mad(input[1].w, filter_trans[3], output[1]);
output[2] = mad(input[2].w, filter_trans[3], output[2]);
output[3] = mad(input[3].w, filter_trans[3], output[3]);
output[4] = mad(input[4].w, filter_trans[3], output[4]);
}
}
}
}
#ifdef RELU
output[0] = activation_type4(output[0]);
output[1] = activation_type4(output[1]);
output[2] = activation_type4(output[2]);
output[3] = activation_type4(output[3]);
output[4] = activation_type4(output[4]);
#endif
WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, (int2)(out_w_base_id + out_w_id0, item_h_id),
output[0]);
if (out_w_id1 < out_w) {
WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, (int2)(out_w_base_id + out_w_id1, item_h_id),
output[1]);
}
if (out_w_id2 < out_w) {
WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, (int2)(out_w_base_id + out_w_id2, item_h_id),
output[2]);
}
if (out_w_id3 < out_w) {
WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, (int2)(out_w_base_id + out_w_id3, item_h_id),
output[3]);
}
if (out_w_id4 < out_w) {
WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, (int2)(out_w_base_id + out_w_id4, item_h_id),
output[4]);
}
}
......@@ -143,7 +143,7 @@ void ConvImageCompute::PrepareForRun() {
} else if (kernel_h == 3 && kernel_h == 3) {
// conv2d_3x3
kernel_func_names_.push_back("conv2d_3x3");
kernel_func_paths_.push_back("image/conv2d_3x3_kernel.cl");
kernel_func_paths_.push_back("image/conv2d_3x3_opt_kernel.cl");
CLImageConverterFolder converter;
const DDim& filter_image_dims = converter.InitImageDimInfoWith(filter_dims);
......@@ -153,7 +153,7 @@ void ConvImageCompute::PrepareForRun() {
filter_gpu_image_.mutable_data<half_t, cl::Image2D>(
filter_image_dims[0], filter_image_dims[1], filter_image_v.data());
impl_ = &ConvImageCompute::Conv2d3x3;
impl_ = &ConvImageCompute::Conv2d3x3opt;
} else if (kernel_h == 5 && kernel_w == 5) {
// conv2d_5x5
kernel_func_names_.push_back("conv2d_5x5");
......@@ -554,6 +554,150 @@ void ConvImageCompute::Conv2d3x3() {
context.cl_wait_list()->emplace(out_image, event_);
}
void ConvImageCompute::Conv2d3x3opt() {
const auto& param = *param_.get_mutable<param_t>();
auto input_dims = param.x->dims();
auto paddings = *param.paddings;
auto strides = param.strides;
auto dilations = *param.dilations;
auto* input_image = param.x->data<half_t, cl::Image2D>();
auto* filter_image = filter_gpu_image_.data<half_t, cl::Image2D>();
auto filter_dims = param.filter->dims();
auto output_dims = param.output->dims();
int input_width = input_dims[3];
int input_height = input_dims[2];
int input_channel = input_dims[1];
int output_width = output_dims[3];
int output_height = output_dims[2];
int output_channel = output_dims[1];
auto out_image_shape = InitImageDimInfoWith(output_dims);
auto* out_image = param.output->mutable_data<half_t, cl::Image2D>(
out_image_shape["width"], out_image_shape["height"]);
const bool has_bias = param.bias != nullptr;
const bool is_element_wise_bias =
has_bias && param.output->dims() == param.bias->dims();
const std::vector<size_t>& default_work_size =
DefaultWorkSize(output_dims,
DDim(std::vector<DDim::value_type>{
static_cast<int64_t>(out_image_shape["width"]),
static_cast<int64_t>(out_image_shape["height"])}));
int c_block = default_work_size[0];
int w = default_work_size[1];
int nh = default_work_size[2];
int w_blk_size = 5;
int w_blk = (w + w_blk_size - 1) / w_blk_size;
// default_work_size[1] = w_blk;
int h_blk_size = 1;
int h_blk = (nh + h_blk_size - 1) / h_blk_size;
// default_work_size[2] = h_blk;
VLOG(4) << "============ conv2d params ============";
// VLOG(4) << "input_image_shape: " << input_image_shape["width"] << ","
// << input_image_shape["height"];
// VLOG(4) << "input_image: " << input_image;
VLOG(4) << "input_dims: " << input_dims;
VLOG(4) << "filter_dims: " << filter_dims;
// VLOG(4) << "filter_image: " << filter_image;
VLOG(4) << "output_dims: " << output_dims;
VLOG(4) << "out_image_shape: " << out_image_shape["width"] << ", "
<< out_image_shape["height"];
VLOG(4) << "paddings: " << paddings[0] << "," << paddings[1];
VLOG(4) << "has bias: " << has_bias;
VLOG(4) << "is_element_wise_bias : " << is_element_wise_bias;
VLOG(4) << "strides: " << strides[0] << "," << strides[1];
VLOG(4) << "dilations.size : " << dilations.size();
VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1];
VLOG(4) << "default work size{c_block, w, nh}: "
<< "{" << c_block << ", " << w << ", " << nh << ""
<< "}";
CHECK_GE(dilations.size(), 2);
CHECK(dilations[0] == dilations[1]);
CHECK_GE(input_dims.size(), 4);
CHECK_GE(paddings.size(), 2);
CHECK(paddings[0] == paddings[1]);
CHECK_GE(strides.size(), 2);
CHECK(strides[0] == strides[1]);
const cl::Image2D* bias_image = nullptr;
if (has_bias) {
bias_image = bias_gpu_image_.data<half_t, cl::Image2D>();
}
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
STL::stringstream kernel_key;
kernel_key << kernel_func_names_[0] << build_options_[0];
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
VLOG(4) << "kernel_key: " << kernel_key.str();
VLOG(4) << "kernel ready ... " << kernel_key.str();
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, c_block);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, w_blk);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, h_blk);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *filter_image);
CL_CHECK_FATAL(status);
if (has_bias) {
VLOG(4) << "set bias_image: ";
status = kernel.setArg(++arg_idx, *bias_image);
CL_CHECK_FATAL(status);
}
status = kernel.setArg(++arg_idx, *out_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, strides[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, paddings[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, dilations[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_channel);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_height);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_height);
CL_CHECK_FATAL(status);
auto global_work_size =
cl::NDRange{static_cast<size_t>(default_work_size.data()[0]),
static_cast<size_t>(w_blk),
static_cast<size_t>(h_blk)};
// VLOG(4) << "out_image: " << out_image;
VLOG(4) << "global_work_size[3D]: {" << global_work_size[0] << ","
<< global_work_size[1] << "," << global_work_size[2] << "}";
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
event_.get());
CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_image, event_);
}
void ConvImageCompute::Conv2d5x5() {
const auto& param = *param_.get_mutable<param_t>();
auto input_dims = param.x->dims();
......
......@@ -43,6 +43,7 @@ class ConvImageCompute : public KernelLite<TARGET(kOpenCL),
private:
void Conv2d1x1();
void Conv2d3x3();
void Conv2d3x3opt();
void Conv2d5x5();
void Conv2d7x7();
void DepthwiseConv2d3x3s1();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册