Commit 7bba07bf, authored by Yuan Shuai, committed by GitHub

[LITE][OPENCL] Support video-sr feature using OpenCL FP16 Image (#3049)

* [LITE][OPENCL] Support video-sr feature using OpenCL FP16 Image. test=develop

* optimize image2d_to_buffer_with_post255. test=develop

* add def debug in cl kernel. test=develop

* remove conv image code in conv buffer. test=develop
Parent 05b5ef29
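
The changes below add a uint8 input/output path so raw video frames can be fed to an OpenCL FP16 image model without a separate float staging step. A minimal host-side sketch of how that path might be driven is shown here; it is not part of the commit, the model path, input shape, and valid-place list are assumptions, while the uint8 Tensor accessors, the CopyFromCpu<uint8_t, ...> overloads, and the "OPENCL_PRE_PRECESS" model-dir trigger come from the diff itself.

#include <cstdint>
#include <vector>
#include "paddle_api.h"  // Paddle-Lite public C++ API

void RunVideoSr(const std::vector<uint8_t>& rgb_frame) {
  paddle::lite_api::CxxConfig config;
  // Hypothetical model dir; the "OPENCL_PRE_PRECESS" substring is what turns on
  // type_layout_cast_preprocess_pass in CxxPaddleApiImpl::Init below.
  config.set_model_dir("/data/local/tmp/video_sr_OPENCL_PRE_PRECESS");
  config.set_valid_places(
      {paddle::lite_api::Place{TARGET(kOpenCL), PRECISION(kFP16), DATALAYOUT(kImageDefault)},
       paddle::lite_api::Place{TARGET(kARM), PRECISION(kFloat)}});
  auto predictor = paddle::lite_api::CreatePaddlePredictor(config);

  auto input = predictor->GetInput(0);
  input->Resize({1, 3, 270, 480});  // assumed NCHW frame shape
  // New in this commit: uint8 CopyFromCpu, so the frame needs no float copy.
  input->CopyFromCpu<uint8_t, paddle::lite_api::TargetType::kARM>(rgb_frame.data());

  predictor->Run();

  auto output = predictor->GetOutput(0);
  const uint8_t* sr_frame = output->data<uint8_t>();  // new uint8 accessor
  (void)sr_frame;  // hand the upscaled frame back to the video pipeline
}
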
......@@ -38,11 +38,13 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
std::vector<std::string> passes{};
auto use_layout_preprocess_pass =
config.model_dir().find("OPENCL_PRE_PRECESS");
if (use_layout_preprocess_pass != std::string::npos) {
VLOG(1) << "use_layout_preprocess_pass:" << use_layout_preprocess_pass;
if (places[0].target == TARGET(kOpenCL) &&
use_layout_preprocess_pass != std::string::npos) {
passes = {"type_layout_cast_preprocess_pass"};
VLOG(1) << "add pass:" << passes[0];
}
raw_predictor_.Build(config, places, passes);
mode_ = config.power_mode();
threads_ = config.threads();
......
......@@ -38,6 +38,7 @@ void Tensor::Resize(const shape_t &shape) {
tensor(raw_tensor_)->Resize(shape);
}
// Tensor::data
template <>
const float *Tensor::data() const {
return ctensor(raw_tensor_)->data<float>();
......@@ -47,15 +48,19 @@ const int8_t *Tensor::data() const {
return ctensor(raw_tensor_)->data<int8_t>();
}
template <>
const uint8_t *Tensor::data() const {
return ctensor(raw_tensor_)->data<uint8_t>();
}
template <>
const int64_t *Tensor::data() const {
return ctensor(raw_tensor_)->data<int64_t>();
}
template <>
const int32_t *Tensor::data() const {
return ctensor(raw_tensor_)->data<int32_t>();
}
// Tensor::mutable_data
template <>
int *Tensor::mutable_data(TargetType type) const {
return tensor(raw_tensor_)->mutable_data<int>(type);
......@@ -69,6 +74,10 @@ int8_t *Tensor::mutable_data(TargetType type) const {
return tensor(raw_tensor_)->mutable_data<int8_t>(type);
}
template <>
uint8_t *Tensor::mutable_data(TargetType type) const {
return tensor(raw_tensor_)->mutable_data<uint8_t>(type);
}
template <>
int64_t *Tensor::mutable_data(TargetType type) const {
return tensor(raw_tensor_)->mutable_data<int64_t>(type);
}
......@@ -116,18 +125,22 @@ void Tensor::CopyToCpu(T *data) const {
template void Tensor::CopyFromCpu<int, TargetType::kHost>(const int *);
template void Tensor::CopyFromCpu<float, TargetType::kHost>(const float *);
template void Tensor::CopyFromCpu<int8_t, TargetType::kHost>(const int8_t *);
template void Tensor::CopyFromCpu<uint8_t, TargetType::kHost>(const uint8_t *);
template void Tensor::CopyFromCpu<int, TargetType::kARM>(const int *);
template void Tensor::CopyFromCpu<float, TargetType::kARM>(const float *);
template void Tensor::CopyFromCpu<int8_t, TargetType::kARM>(const int8_t *);
template void Tensor::CopyFromCpu<uint8_t, TargetType::kARM>(const uint8_t *);
template void Tensor::CopyFromCpu<int, TargetType::kCUDA>(const int *);
template void Tensor::CopyFromCpu<int64_t, TargetType::kCUDA>(const int64_t *);
template void Tensor::CopyFromCpu<float, TargetType::kCUDA>(const float *);
template void Tensor::CopyFromCpu<int8_t, TargetType::kCUDA>(const int8_t *);
template void Tensor::CopyToCpu(int8_t *) const;
template void Tensor::CopyToCpu(float *) const;
template void Tensor::CopyToCpu(int *) const;
template void Tensor::CopyToCpu(int8_t *) const;
template void Tensor::CopyToCpu(uint8_t *) const;
shape_t Tensor::shape() const {
return ctensor(raw_tensor_)->dims().Vectorize();
......
......@@ -15,7 +15,9 @@ limitations under the License. */
#include <cl_common.h>
// #define DEBUG
////////////////////////////////////////////////////////
// buffer -> image2d
////////////////////////////////////////////////////////
__kernel void buffer_to_image2d(__global CL_DTYPE *in,
__write_only image2d_t output_image,
__private const int out_H,
......@@ -80,8 +82,9 @@ __kernel void buffer_to_image2d(__global CL_DTYPE *in,
WRITE_IMG_TYPE(CL_COMPUTE_DTYPE_CHAR, output_image, output_pos, output);
}
////////////////////////////////////////////////////////
// image2d -> buffer
////////////////////////////////////////////////////////
__kernel void image2d_to_buffer(__read_only image2d_t input,
__private const int in_width,
__private const int in_height,
......@@ -125,8 +128,10 @@ __kernel void image2d_to_buffer(__read_only image2d_t input,
}
#if 0
#if 0 // NOTE(ysh329): keep, un-used from paddle-mobile
////////////////////////////////////////////////////////
// buffer -> image2d_nw
////////////////////////////////////////////////////////
__kernel void buffer_to_image2d_nw(__global CL_DTYPE* in,
__write_only image2d_t output_image,
__private const int out_H,
......@@ -178,7 +183,7 @@ __kernel void buffer_to_image2d_nw(__global CL_DTYPE* in,
#endif
#if 0
#if 0 // NOTE(ysh329): keep, un-used from paddle-mobile
// image2d -> buffer
__kernel void image2d_to_buffer_2d(__private const int in_height,
__private const int in_width,
......@@ -200,7 +205,9 @@ __kernel void image2d_to_buffer_2d(__private const int in_height,
}
#endif
////////////////////////////////////////////////////////
// buffer -> image2d (divide by 255 to normalize)
////////////////////////////////////////////////////////
__kernel void buffer_to_image2d_with_pre255(__global uchar *in,
__write_only image2d_t output_image,
__private const int out_H,
......@@ -248,7 +255,10 @@ __kernel void buffer_to_image2d_with_pre255(__global uchar *in,
WRITE_IMG_TYPE(CL_COMPUTE_DTYPE_CHAR, output_image, output_pos, output);
}
////////////////////////////////////////////////////////
// image2d -> buffer (multiply by 255 to de-normalize)
////////////////////////////////////////////////////////
__kernel void image2d_to_buffer_with_post255(__read_only image2d_t input,
__private const int in_width,
__private const int in_height,
......@@ -267,17 +277,22 @@ __kernel void image2d_to_buffer_with_post255(__read_only image2d_t input,
CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
const int pos_x = mad24(in_c, in_width, in_w);
CL_COMPUTE_DTYPE4 in = READ_IMG_TYPE(CL_COMPUTE_DTYPE_CHAR, input, sampler, (int2)(pos_x, in_nh));
CL_COMPUTE_DTYPE4 in = READ_IMG_TYPE(CL_COMPUTE_DTYPE_CHAR, input, sampler, (int2)(pos_x, in_nh)) * 255;
#ifdef DEBUG
printf("in_c:%d, in_w:%d, in_nh:%d ===> in(%d,%d): %.2f %.2f %.2f %.2f\n",
in_c, in_w, in_nh, pos_x, in_nh, in.x, in.y, in.z, in.w);
#endif
const int index = in_n * size_batch + in_c * size_block + in_h * in_width + in_w;
out[index] = convert_uchar_sat(in.x * 255);
out[index] = convert_uchar_sat(in.x);
if(C - 4 * in_c>=2){
out[index + size_ch] = convert_uchar_sat(in.y * 255);
out[index + size_ch] = convert_uchar_sat(in.y);
}
if(C - 4 * in_c>=3){
out[index + size_ch * 2] = convert_uchar_sat(in.z * 255);
out[index + size_ch * 2] = convert_uchar_sat(in.z);
}
if(C - 4 * in_c>=4){
out[index + size_ch * 3] = convert_uchar_sat(in.w * 255);
out[index + size_ch * 3] = convert_uchar_sat(in.w);
}
}
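
For reference, the pre255/post255 pair above is plain uchar-to-unit-float normalization and its inverse; the change in image2d_to_buffer_with_post255 folds the *255 into the single vector read instead of repeating it per channel. A small host-side sketch of the per-element math follows (an approximation, not code from this commit; the rounding of convert_uchar_sat is modeled as clamp-then-truncate).

#include <algorithm>
#include <cstdint>

// buffer_to_image2d_with_pre255: uchar pixel -> normalized value in [0, 1].
inline float Pre255(uint8_t u) { return static_cast<float>(u) / 255.0f; }

// image2d_to_buffer_with_post255: normalized value -> uchar with saturation,
// i.e. roughly convert_uchar_sat(v * 255) after the single *255 on the read.
inline uint8_t Post255(float v) {
  float scaled = v * 255.0f;
  return static_cast<uint8_t>(std::min(255.0f, std::max(0.0f, scaled)));
}
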
......@@ -217,7 +217,9 @@ void OpenCLTypeLayoutTransformPass::Apply(
for (auto& node : nodes) {
VLOG(4) << "!node->IsStmt():" << !node->IsStmt();
if (!node->IsStmt() || node->AsStmt().op_type() == "while") continue;
if (node->AsStmt().op_type() == "layout") {
VLOG(1) << "node->AsStmt().op_type():" << node->AsStmt().op_type();
if (node->AsStmt().op_type() == "layout" ||
node->AsStmt().op_type() == "io_copy") {
auto new_op = node->AsStmt().mutable_op_info();
int process_type = 1;
new_op->SetAttr("process_type", process_type);
......
......@@ -297,1135 +297,6 @@ void ConvCompute::GemmBatched(cl::Kernel& kernel,
void ConvCompute::Run() { (this->*impl_)(); }
/* image kernel*/
void ConvImageCompute::PrepareForRun() {
const auto& param = this->Param<param_t>();
auto x_dims = param.x->dims();
auto filter_dims = param.filter->dims();
auto output_dims = param.output->dims();
float* filter_cpu = param.filter->mutable_data<float>();
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
int bs = x_dims[0];
int c_in = x_dims[1];
int h_out = output_dims[2];
int w_out = output_dims[3];
int kernel_h = filter_dims[2]; // oihw
int kernel_w = filter_dims[3];
auto paddings = *param.paddings;
auto dilations = *param.dilations;
int stride_h = param.strides[0];
int stride_w = param.strides[1];
int pad_h = paddings[0];
int pad_w = paddings[2];
int groups = param.groups;
bool relu_fused = param.fuse_relu;
bool no_dilation = (dilations[0] == 1) && (dilations[1] == 1);
bool zero_pad = (pad_h == 0) && (pad_w == 0);
bool pad_equal =
((paddings[0] == paddings[1]) && (paddings[1] == paddings[2]) &&
(paddings[2] == paddings[3]));
bool stride_equal = stride_h == stride_w;
bool dilation_equal = dilations[0] == dilations[1];
CHECK(pad_equal && stride_equal && dilation_equal);
VLOG(3) << "Is relu fused? / " << (relu_fused ? "Yes" : "No");
VLOG(3) << "groups:" << groups << " stride_h:" << stride_h
<< " stride_w:" << stride_w << " pad_h:" << pad_h
<< " pad_w:" << pad_w << " kernel_h:" << kernel_h
<< " kernel_h:" << kernel_h;
VLOG(3) << "x_dims:" << x_dims[0] << " " << x_dims[1] << " " << x_dims[2]
<< " " << x_dims[3];
VLOG(3) << "output_dims:" << output_dims[0] << " " << output_dims[1] << " "
<< output_dims[2] << " " << output_dims[3];
VLOG(3) << "filter_dims:" << filter_dims[0] << " " << filter_dims[1] << " "
<< filter_dims[2] << " " << filter_dims[3];
if (kernel_h == 1 && kernel_w == 1) {
// conv2d_1x1
if (param.x->dims()[1] % 4 == 0) {
kernel_func_names_.push_back("conv2d_1x1_simple");
} else {
kernel_func_names_.push_back("conv2d_1x1");
}
kernel_func_paths_.push_back("image/conv2d_1x1_kernel.cl");
CLImageConverterNWBlock converter;
const DDim& filter_image_dims = converter.InitImageDimInfoWith(filter_dims);
std::vector<float> filter_image_v(filter_image_dims[0] *
filter_image_dims[1] * 4); // 4 : RGBA
converter.NCHWToImage(filter_cpu, filter_image_v.data(), filter_dims);
filter_gpu_image_.mutable_data<float, cl::Image2D>(
filter_image_dims[0], filter_image_dims[1], filter_image_v.data());
impl_ = &ConvImageCompute::Conv2d1x1;
#if 1 // TODO(ysh329): enable general dwconv
} else if (filter_dims[1] == 1 && x_dims[1] == output_dims[1]) {
#else // TODO(ysh329): remove dwconv3x3s1 and dwconv3x3 temporarily, need fix
} else if (filter_dims[1] == 1 && x_dims[1] == output_dims[1] &&
kernel_h == 3 && kernel_w == 3 && groups > 1) {
// depth_conv2d_3x3s1, depth_conv2d_3x3
if (stride_h == 1 && dilations[0] == 1) {
kernel_func_names_.push_back("depth_conv2d_3x3s1");
impl_ = &ConvImageCompute::DepthwiseConv2d3x3s1;
} else {
kernel_func_names_.push_back("depth_conv2d_3x3");
impl_ = &ConvImageCompute::DepthwiseConv2d3x3;
}
kernel_func_paths_.push_back("image/depthwise_conv2d_kernel.cl");
CLImageConverterNWBlock converter;
const DDim& filter_image_dims = converter.InitImageDimInfoWith(filter_dims);
std::vector<float> filter_image_v(filter_image_dims[0] *
filter_image_dims[1] * 4); // 4 : RGBA
converter.NCHWToImage(filter_cpu, filter_image_v.data(), filter_dims);
filter_gpu_image_.mutable_data<float, cl::Image2D>(
filter_image_dims[0], filter_image_dims[1], filter_image_v.data());
} else if (filter_dims[1] == 1 && x_dims[1] == output_dims[1] &&
kernel_h != 3) {
#endif
// depth_conv2d
kernel_func_names_.push_back("depth_conv2d");
kernel_func_paths_.push_back("image/depthwise_conv2d_basic_kernel.cl");
CLImageConverterNWBlock converter;
const DDim& filter_image_dims = converter.InitImageDimInfoWith(filter_dims);
std::vector<float> filter_image_v(filter_image_dims[0] *
filter_image_dims[1] * 4); // 4 : RGBA
converter.NCHWToImage(filter_cpu, filter_image_v.data(), filter_dims);
filter_gpu_image_.mutable_data<float, cl::Image2D>(
filter_image_dims[0], filter_image_dims[1], filter_image_v.data());
impl_ = &ConvImageCompute::DepthwiseConv2d;
} else if (kernel_h == 3 && kernel_h == 3) {
// conv2d_3x3
kernel_func_names_.push_back("conv2d_3x3");
kernel_func_paths_.push_back("image/conv2d_3x3_kernel.cl");
CLImageConverterFolder converter;
const DDim& filter_image_dims = converter.InitImageDimInfoWith(filter_dims);
std::vector<float> filter_image_v(filter_image_dims[0] *
filter_image_dims[1] * 4); // 4 : RGBA
converter.NCHWToImage(filter_cpu, filter_image_v.data(), filter_dims);
filter_gpu_image_.mutable_data<float, cl::Image2D>(
filter_image_dims[0], filter_image_dims[1], filter_image_v.data());
impl_ = &ConvImageCompute::Conv2d3x3;
} else if (kernel_h == 5 && kernel_w == 5) {
// conv2d_5x5
kernel_func_names_.push_back("conv2d_5x5");
kernel_func_paths_.push_back("image/conv2d_5x5_kernel.cl");
CLImageConverterFolder converter;
const DDim& filter_image_dims = converter.InitImageDimInfoWith(filter_dims);
std::vector<float> filter_image_v(filter_image_dims[0] *
filter_image_dims[1] * 4); // 4 : RGBA
converter.NCHWToImage(filter_cpu, filter_image_v.data(), filter_dims);
filter_gpu_image_.mutable_data<float, cl::Image2D>(
filter_image_dims[0], filter_image_dims[1], filter_image_v.data());
impl_ = &ConvImageCompute::Conv2d5x5;
} else if (kernel_h == 7 && kernel_w == 7) {
// conv2d_7x7
kernel_func_names_.push_back("conv2d_7x7");
kernel_func_paths_.push_back("image/conv2d_7x7_kernel.cl");
CLImageConverterFolder converter;
const DDim& filter_image_dims = converter.InitImageDimInfoWith(filter_dims);
std::vector<float> filter_image_v(filter_image_dims[0] *
filter_image_dims[1] * 4); // 4 : RGBA
converter.NCHWToImage(filter_cpu, filter_image_v.data(), filter_dims);
this->filter_gpu_image_.mutable_data<float, cl::Image2D>(
filter_image_dims[0], filter_image_dims[1], filter_image_v.data());
impl_ = &ConvImageCompute::Conv2d7x7;
} else {
LOG(FATAL) << "conv image compute not support this condition yet! ";
}
VLOG(1) << "kernel_func_names_[0]:" << kernel_func_names_[0]
<< " kernel_func_paths_[0]:" << kernel_func_paths_[0];
std::string build_options_single(" -DCL_DTYPE_float");
// relu options
if (relu_fused) {
build_options_single += " -DRELU";
} else if (param.activation_param.active_type ==
lite_api::ActivationType::kRelu6) {
build_options_single += " -DRELU6";
} else {
// do nothing
}
// bias options
const bool has_bias = param.bias != nullptr;
const bool is_element_wise_bias =
has_bias && param.output->dims() == param.bias->dims();
if (has_bias) {
build_options_single +=
is_element_wise_bias ? " -DBIASE_ELE" : " -DBIASE_CH";
// convert cpu buffer bias --> gpu image
CLImageConverterFolder bias_converter;
const DDim& bias_image_dims =
bias_converter.InitImageDimInfoWith(param.bias->dims());
std::vector<float> bias_image_v(bias_image_dims[0] * bias_image_dims[1] *
4);
float* bias_cpu_data = param.bias->mutable_data<float>();
bias_converter.NCHWToImage(
bias_cpu_data, bias_image_v.data(), param.bias->dims());
this->bias_gpu_image_.mutable_data<float, cl::Image2D>(
bias_image_dims[0], bias_image_dims[1], bias_image_v.data());
// convert cpu buffer bias --> gpu image --- end ----
}
build_options_.push_back(build_options_single);
for (size_t i = 0; i < kernel_func_names_.size(); i++) {
context.cl_context()->AddKernel(
kernel_func_names_[i], kernel_func_paths_[i], build_options_[i]);
}
}
void ConvImageCompute::Conv2d1x1() {
const auto& param = *param_.get_mutable<param_t>();
auto input_dims = param.x->dims();
auto paddings = *param.paddings;
auto strides = param.strides;
auto* input_image = param.x->data<float, cl::Image2D>();
auto* filter_image = filter_gpu_image_.data<float, cl::Image2D>();
auto filter_dims = param.filter->dims();
auto output_dims = param.output->dims();
int input_width = input_dims[3];
int input_height = input_dims[2];
int output_width = output_dims[3];
int output_height = output_dims[2];
auto out_image_shape = InitImageDimInfoWith(output_dims);
auto* out_image = param.output->mutable_data<float, cl::Image2D>(
out_image_shape["width"], out_image_shape["height"]);
const bool has_bias = param.bias != nullptr;
const bool is_element_wise_bias =
has_bias && param.output->dims() == param.bias->dims();
int offset = static_cast<int>(param.filter->dims()[2]) / 2 -
static_cast<int>(paddings[0]);
// calc input_c_block
auto input_image_shape = InitImageDimInfoWith(input_dims);
int input_c_block = input_image_shape["width"] / input_dims[3];
int input_c = input_dims[1];
auto dilations = *param.dilations;
const std::vector<size_t>& default_work_size =
DefaultWorkSize(output_dims,
DDim(std::vector<DDim::value_type>{
static_cast<int64_t>(out_image_shape["width"]),
static_cast<int64_t>(out_image_shape["height"])}));
int c_block = default_work_size[0];
int w = default_work_size[1];
int nh = default_work_size[2];
VLOG(4) << "============ conv2d_1x1 params ============";
VLOG(4) << "input_image_shape: " << input_image_shape["width"] << ","
<< input_image_shape["height"];
VLOG(4) << "input_c_block: " << input_c_block;
VLOG(4) << "input_c: " << input_c;
VLOG(4) << "input_image: " << input_image;
VLOG(4) << "filter_dims: " << filter_dims;
VLOG(4) << "filter_image: " << filter_image;
VLOG(4) << "output_dims: " << output_dims;
VLOG(4) << "out_image_shape: " << out_image_shape["width"] << ", "
<< out_image_shape["height"];
VLOG(4) << "paddings: " << paddings[0] << "," << paddings[1];
VLOG(4) << "has bias: " << has_bias;
VLOG(4) << "is_element_wise_bias : " << is_element_wise_bias;
VLOG(4) << "strides: " << strides[0] << "," << strides[1];
VLOG(4) << "offset: " << offset;
VLOG(4) << "dilations.size : " << dilations.size();
VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1];
VLOG(4) << "default work size{c_block, w, nh}: "
<< "{" << c_block << ", " << w << ", " << nh << ""
<< "}";
CHECK_GE(dilations.size(), 2);
CHECK(dilations[0] == dilations[1]);
CHECK_GE(input_dims.size(), 4);
CHECK_GE(paddings.size(), 2);
CHECK(paddings[0] == paddings[1]);
CHECK_GE(strides.size(), 2);
CHECK(strides[0] == strides[1]);
// handle bias use buffer for channel wise , use image for element wise
const cl::Buffer* bias_buf = nullptr;
const cl::Image2D* bias_image = nullptr;
if (has_bias) {
bias_image = bias_gpu_image_.data<float, cl::Image2D>();
}
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
std::stringstream kernel_key;
kernel_key << kernel_func_names_[0] << build_options_[0];
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
int maped_w = maptofactor(w, 4);
VLOG(4) << "kernel_key: " << kernel_key.str();
VLOG(4) << "kernel ready ... " << kernel_key.str();
VLOG(4) << "maped_w: " << maped_w;
VLOG(4) << "hasbias: " << has_bias;
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, c_block);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, maped_w);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, nh);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *filter_image);
CL_CHECK_FATAL(status);
if (has_bias) {
status = kernel.setArg(++arg_idx, *bias_image);
CL_CHECK_FATAL(status);
}
status = kernel.setArg(++arg_idx, *out_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, strides[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, offset);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_c_block);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_c);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, dilations[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_height);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_height);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, w);
CL_CHECK_FATAL(status);
auto global_work_size =
cl::NDRange{static_cast<size_t>(default_work_size.data()[0]),
static_cast<size_t>(maped_w),
static_cast<size_t>(default_work_size.data()[2])};
VLOG(4) << "out_image: " << out_image;
VLOG(4) << "global_work_size[3D]: {" << global_work_size[0] << ","
<< global_work_size[1] << "," << global_work_size[2] << "}";
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
event_.get());
CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_image, event_);
}
void ConvImageCompute::Conv2d3x3() {
const auto& param = *param_.get_mutable<param_t>();
auto input_dims = param.x->dims();
auto paddings = *param.paddings;
auto strides = param.strides;
auto* input_image = param.x->data<float, cl::Image2D>();
auto* filter_image = filter_gpu_image_.data<float, cl::Image2D>();
auto filter_dims = param.filter->dims();
auto output_dims = param.output->dims();
int input_width = input_dims[3];
int input_height = input_dims[2];
int input_channel = input_dims[1];
int output_width = output_dims[3];
int output_height = output_dims[2];
int output_channel = output_dims[1];
int filter_width = filter_dims[3];
int filter_height = filter_dims[2];
int filter_channel = filter_dims[1];
auto out_image_shape = InitImageDimInfoWith(output_dims);
auto* out_image = param.output->mutable_data<float, cl::Image2D>(
out_image_shape["width"], out_image_shape["height"]);
const bool has_bias = param.bias != nullptr;
const bool is_element_wise_bias =
has_bias && param.output->dims() == param.bias->dims();
int offset = static_cast<int>(param.filter->dims()[2]) / 2 -
static_cast<int>(paddings[0]);
// calc input_c_block
auto input_image_shape = InitImageDimInfoWith(input_dims);
int input_c_block = input_image_shape["width"] / input_dims[3];
int input_c = input_dims[1];
auto dilations = *param.dilations;
// re-calc group
int new_groups{param.groups};
if (filter_dims[0] == output_dims[1] && filter_dims[1] == input_dims[1]) {
new_groups = 1;
} else if (!(filter_dims[0] == input_dims[1] && filter_dims[1] == 1)) {
new_groups = input_channel / filter_channel;
}
/* TODO(ysh329): mobile has no case below
else {
LOG(FATAL) << "Not support conv3x3 case with"
<< " input_dims:" << input_dims << " output_dims:" <<
output_dims
<< " filter_dims:" << filter_dims;
}
*/
const std::vector<size_t>& default_work_size =
DefaultWorkSize(output_dims,
DDim(std::vector<DDim::value_type>{
static_cast<int64_t>(out_image_shape["width"]),
static_cast<int64_t>(out_image_shape["height"])}));
int c_block = default_work_size[0];
int w = default_work_size[1];
int nh = default_work_size[2];
VLOG(4) << "============ conv2d params ============";
VLOG(4) << "input_image_shape: " << input_image_shape["width"] << ","
<< input_image_shape["height"];
VLOG(4) << "input_c_block: " << input_c_block;
VLOG(4) << "input_c: " << input_c;
VLOG(4) << "input_image: " << input_image;
VLOG(4) << "input_dims: " << input_dims;
VLOG(4) << "filter_dims: " << filter_dims;
VLOG(4) << "filter_image: " << filter_image;
VLOG(4) << "output_dims: " << output_dims;
VLOG(4) << "out_image_shape: " << out_image_shape["width"] << ", "
<< out_image_shape["height"];
VLOG(4) << "paddings: " << paddings[0] << "," << paddings[1];
VLOG(4) << "has bias: " << has_bias;
VLOG(4) << "is_element_wise_bias : " << is_element_wise_bias;
VLOG(4) << "strides: " << strides[0] << "," << strides[1];
VLOG(4) << "offset: " << offset;
VLOG(4) << "dilations.size : " << dilations.size();
VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1];
VLOG(4) << "param.groups(groups):" << param.groups;
VLOG(4) << "new_groups:" << new_groups;
VLOG(4) << "default work size{c_block, w, nh}: "
<< "{" << c_block << ", " << w << ", " << nh << ""
<< "}";
CHECK_GE(dilations.size(), 2);
CHECK(dilations[0] == dilations[1]);
CHECK_GE(input_dims.size(), 4);
CHECK_GE(paddings.size(), 2);
CHECK(paddings[0] == paddings[1]);
CHECK_GE(strides.size(), 2);
CHECK(strides[0] == strides[1]);
const cl::Image2D* bias_image = nullptr;
if (has_bias) {
bias_image = bias_gpu_image_.data<float, cl::Image2D>();
}
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
STL::stringstream kernel_key;
kernel_key << kernel_func_names_[0] << build_options_[0];
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
VLOG(4) << "kernel_key: " << kernel_key.str();
VLOG(4) << "kernel ready ... " << kernel_key.str();
VLOG(4) << "w: " << w;
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, c_block);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, w);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, nh);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *filter_image);
CL_CHECK_FATAL(status);
if (has_bias) {
VLOG(4) << "set bias_image: ";
status = kernel.setArg(++arg_idx, *bias_image);
CL_CHECK_FATAL(status);
}
status = kernel.setArg(++arg_idx, *out_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, strides[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, offset);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_c_block);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, dilations[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_height);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_height);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_channel);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, filter_channel);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, filter_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, filter_height);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, new_groups);
CL_CHECK_FATAL(status);
auto global_work_size =
cl::NDRange{static_cast<size_t>(default_work_size.data()[0]),
static_cast<size_t>(default_work_size.data()[1]),
static_cast<size_t>(default_work_size.data()[2])};
VLOG(4) << "out_image: " << out_image;
VLOG(4) << "global_work_size[3D]: {" << global_work_size[0] << ","
<< global_work_size[1] << "," << global_work_size[2] << "}";
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
event_.get());
CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_image, event_);
}
void ConvImageCompute::Conv2d5x5() {
const auto& param = *param_.get_mutable<param_t>();
auto input_dims = param.x->dims();
auto paddings = *param.paddings;
auto strides = param.strides;
auto* input_image = param.x->data<float, cl::Image2D>();
auto* filter_image = filter_gpu_image_.data<float, cl::Image2D>();
auto filter_dims = param.filter->dims();
auto output_dims = param.output->dims();
int input_width = input_dims[3];
int input_height = input_dims[2];
int output_width = output_dims[3];
int output_height = output_dims[2];
int filter_width = filter_dims[3];
int filter_height = filter_dims[2];
auto out_image_shape = InitImageDimInfoWith(output_dims);
auto* out_image = param.output->mutable_data<float, cl::Image2D>(
out_image_shape["width"], out_image_shape["height"]);
const bool has_bias = param.bias != nullptr;
const bool is_element_wise_bias =
has_bias && param.output->dims() == param.bias->dims();
int offset = static_cast<int>(param.filter->dims()[2]) / 2 -
static_cast<int>(paddings[0]);
// calc input_c_block
auto input_image_shape = InitImageDimInfoWith(input_dims);
int input_c_block = input_image_shape["width"] / input_dims[3];
int input_c = input_dims[1];
auto dilations = *param.dilations;
const std::vector<size_t>& default_work_size =
DefaultWorkSize(output_dims,
DDim(std::vector<DDim::value_type>{
static_cast<int64_t>(out_image_shape["width"]),
static_cast<int64_t>(out_image_shape["height"])}));
int c_block = default_work_size[0];
int w = default_work_size[1];
int nh = default_work_size[2];
VLOG(4) << "============ conv2d params ============";
VLOG(4) << "input_image_shape: " << input_image_shape["width"] << ","
<< input_image_shape["height"];
VLOG(4) << "input_c_block: " << input_c_block;
VLOG(4) << "input_c: " << input_c;
VLOG(4) << "input_image: " << input_image;
VLOG(4) << "input_dims: " << input_dims;
VLOG(4) << "filter_dims: " << filter_dims;
VLOG(4) << "filter_image: " << filter_image;
VLOG(4) << "output_dims: " << output_dims;
VLOG(4) << "out_image_shape: " << out_image_shape["width"] << ", "
<< out_image_shape["height"];
VLOG(4) << "paddings: " << paddings[0] << "," << paddings[1];
VLOG(4) << "has bias: " << has_bias;
VLOG(4) << "is_element_wise_bias : " << is_element_wise_bias;
VLOG(4) << "strides: " << strides[0] << "," << strides[1];
VLOG(4) << "offset: " << offset;
VLOG(4) << "dilations.size : " << dilations.size();
VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1];
VLOG(4) << "default work size{c_block, w, nh}: "
<< "{" << c_block << ", " << w << ", " << nh << ""
<< "}";
CHECK_GE(dilations.size(), 2);
CHECK(dilations[0] == dilations[1]);
CHECK_GE(input_dims.size(), 4);
CHECK_GE(paddings.size(), 2);
CHECK(paddings[0] == paddings[1]);
CHECK_GE(strides.size(), 2);
CHECK(strides[0] == strides[1]);
const cl::Image2D* bias_image = nullptr;
if (has_bias) {
bias_image = bias_gpu_image_.data<float, cl::Image2D>();
}
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
STL::stringstream kernel_key;
kernel_key << kernel_func_names_[0] << build_options_[0];
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
VLOG(4) << "kernel_key: " << kernel_key.str();
VLOG(4) << "kernel ready ... " << kernel_key.str();
VLOG(4) << "w: " << w;
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, c_block);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, w);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, nh);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *filter_image);
CL_CHECK_FATAL(status);
if (has_bias) {
VLOG(4) << "set bias_image: ";
status = kernel.setArg(++arg_idx, *bias_image);
CL_CHECK_FATAL(status);
}
status = kernel.setArg(++arg_idx, *out_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, strides[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, offset);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_c_block);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, dilations[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_height);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_height);
CL_CHECK_FATAL(status);
auto global_work_size =
cl::NDRange{static_cast<size_t>(default_work_size.data()[0]),
static_cast<size_t>(default_work_size.data()[1]),
static_cast<size_t>(default_work_size.data()[2])};
VLOG(4) << "out_image: " << out_image;
VLOG(4) << "global_work_size[3D]: {" << global_work_size[0] << ","
<< global_work_size[1] << "," << global_work_size[2] << "}";
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
event_.get());
CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_image, event_);
}
void ConvImageCompute::Conv2d7x7() {
const auto& param = *param_.get_mutable<param_t>();
auto input_dims = param.x->dims();
auto paddings = *param.paddings;
auto strides = param.strides;
auto* input_image = param.x->data<float, cl::Image2D>();
auto* filter_image = filter_gpu_image_.data<float, cl::Image2D>();
auto filter_dims = param.filter->dims();
auto output_dims = param.output->dims();
int input_width = input_dims[3];
int input_height = input_dims[2];
int output_width = output_dims[3];
int output_height = output_dims[2];
int filter_width = filter_dims[3];
int filter_height = filter_dims[2];
auto out_image_shape = InitImageDimInfoWith(output_dims);
auto* out_image = param.output->mutable_data<float, cl::Image2D>(
out_image_shape["width"], out_image_shape["height"]);
const bool has_bias = param.bias != nullptr;
const bool is_element_wise_bias =
has_bias && param.output->dims() == param.bias->dims();
int offset = static_cast<int>(param.filter->dims()[2]) / 2 -
static_cast<int>(paddings[0]);
// calc input_c_block
auto input_image_shape = InitImageDimInfoWith(input_dims);
int input_c_block = input_image_shape["width"] / input_dims[3];
int input_c = input_dims[1];
auto dilations = *param.dilations;
const std::vector<size_t>& default_work_size =
DefaultWorkSize(output_dims,
DDim(std::vector<DDim::value_type>{
static_cast<int64_t>(out_image_shape["width"]),
static_cast<int64_t>(out_image_shape["height"])}));
int c_block = default_work_size[0];
int w = default_work_size[1];
int nh = default_work_size[2];
VLOG(4) << "============ conv2d params ============";
VLOG(4) << "input_image_shape: " << input_image_shape["width"] << ","
<< input_image_shape["height"];
VLOG(4) << "input_c_block: " << input_c_block;
VLOG(4) << "input_c: " << input_c;
VLOG(4) << "input_image: " << input_image;
VLOG(4) << "input_dims: " << input_dims;
VLOG(4) << "filter_dims: " << filter_dims;
VLOG(4) << "filter_image: " << filter_image;
VLOG(4) << "output_dims: " << output_dims;
VLOG(4) << "out_image_shape: " << out_image_shape["width"] << ", "
<< out_image_shape["height"];
VLOG(4) << "paddings: " << paddings[0] << "," << paddings[1];
VLOG(4) << "has bias: " << has_bias;
VLOG(4) << "is_element_wise_bias : " << is_element_wise_bias;
VLOG(4) << "strides: " << strides[0] << "," << strides[1];
VLOG(4) << "offset: " << offset;
VLOG(4) << "dilations.size : " << dilations.size();
VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1];
VLOG(4) << "default work size{c_block, w, nh}: "
<< "{" << c_block << ", " << w << ", " << nh << ""
<< "}";
CHECK_GE(dilations.size(), 2);
CHECK(dilations[0] == dilations[1]);
CHECK_GE(input_dims.size(), 4);
CHECK_GE(paddings.size(), 2);
CHECK(paddings[0] == paddings[1]);
CHECK_GE(strides.size(), 2);
CHECK(strides[0] == strides[1]);
const cl::Image2D* bias_image = nullptr;
if (has_bias) {
bias_image = bias_gpu_image_.data<float, cl::Image2D>();
}
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
STL::stringstream kernel_key;
kernel_key << kernel_func_names_[0] << build_options_[0];
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
VLOG(4) << "kernel_key: " << kernel_key.str();
VLOG(4) << "kernel ready ... " << kernel_key.str();
VLOG(4) << "w: " << w;
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, c_block);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, w);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, nh);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *filter_image);
CL_CHECK_FATAL(status);
if (has_bias) {
VLOG(4) << "set bias_image: ";
status = kernel.setArg(++arg_idx, *bias_image);
CL_CHECK_FATAL(status);
}
status = kernel.setArg(++arg_idx, *out_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, strides[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, offset);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_c_block);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, dilations[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_height);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_height);
CL_CHECK_FATAL(status);
auto global_work_size =
cl::NDRange{static_cast<size_t>(default_work_size.data()[0]),
static_cast<size_t>(default_work_size.data()[1]),
static_cast<size_t>(default_work_size.data()[2])};
VLOG(4) << "out_image: " << out_image;
VLOG(4) << "global_work_size[3D]: {" << global_work_size[0] << ","
<< global_work_size[1] << "," << global_work_size[2] << "}";
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
event_.get());
CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_image, event_);
}
void ConvImageCompute::DepthwiseConv2d3x3s1() {
const auto& param = *param_.get_mutable<param_t>();
auto x_dims = param.x->dims();
auto filter_dims = param.filter->dims();
auto output_dims = param.output->dims();
auto paddings = *param.paddings;
auto strides = param.strides;
auto dilations = *param.dilations;
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
auto* input_img = param.x->data<float, cl::Image2D>();
auto* filter_img = filter_gpu_image_.data<float, cl::Image2D>();
const cl::Image2D* bias_img = nullptr;
if (param.bias) {
bias_img = bias_gpu_image_.data<float, cl::Image2D>();
}
auto image_shape = InitImageDimInfoWith(output_dims);
auto* output_img = param.output->mutable_data<float, cl::Image2D>(
image_shape["width"], image_shape["height"]);
STL::stringstream kernel_key;
kernel_key << kernel_func_names_[0] << build_options_[0];
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
int c_block = (output_dims[1] + 3) / 4;
int w = output_dims[3];
int nh = output_dims[0] * output_dims[2];
int w_blk_size = 2;
int w_blk = (w + w_blk_size - 1) / w_blk_size;
auto global_work_size = cl::NDRange(c_block, w_blk, nh);
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, static_cast<const int>(c_block));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(w_blk));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(nh));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *filter_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *output_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(strides[0]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(paddings[0]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(dilations[0]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(x_dims[1]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(x_dims[3]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(x_dims[2]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(output_dims[3]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(output_dims[2]));
CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
event_.get());
CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(output_img, event_);
}
void ConvImageCompute::DepthwiseConv2d3x3() {
const auto& param = *param_.get_mutable<param_t>();
auto x_dims = param.x->dims();
auto filter_dims = param.filter->dims();
auto output_dims = param.output->dims();
auto paddings = *param.paddings;
auto strides = param.strides;
auto dilations = *param.dilations;
int offset = filter_dims[2] / 2 - paddings[0];
int input_c_block = (x_dims[1] + 3) / 4;
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
auto* input_img = param.x->data<float, cl::Image2D>();
auto* filter_img = filter_gpu_image_.data<float, cl::Image2D>();
const cl::Image2D* bias_img = nullptr;
if (param.bias) {
bias_img = bias_gpu_image_.data<float, cl::Image2D>();
}
auto image_shape = InitImageDimInfoWith(output_dims);
auto* output_img = param.output->mutable_data<float, cl::Image2D>(
image_shape["width"], image_shape["height"]);
STL::stringstream kernel_key;
kernel_key << kernel_func_names_[0] << build_options_[0];
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
int c_block = (output_dims[1] + 3) / 4;
int w = output_dims[3];
int nh = output_dims[0] * output_dims[2];
auto global_work_size = cl::NDRange(c_block, w, nh);
VLOG(4) << "setArg";
VLOG(4) << "c_block = " << c_block;
VLOG(4) << "w = " << w;
VLOG(4) << "nh = " << nh;
VLOG(4) << "strides = " << strides[0];
VLOG(4) << "offset = " << offset;
VLOG(4) << "dilations = " << dilations[0];
VLOG(4) << "input_c_block = " << input_c_block;
VLOG(4) << "x_dims[3] = " << x_dims[3];
VLOG(4) << "x_dims[2] = " << x_dims[2];
VLOG(4) << "output_dims[3] = " << output_dims[3];
VLOG(4) << "output_dims[2] = " << output_dims[2];
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, static_cast<const int>(c_block));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(w));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(nh));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *filter_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *output_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(strides[0]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(offset));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(dilations[0]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(input_c_block));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(x_dims[3]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(x_dims[2]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(output_dims[3]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(output_dims[2]));
CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
event_.get());
CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(output_img, event_);
}
void ConvImageCompute::DepthwiseConv2d() {
const auto& param = *param_.get_mutable<param_t>();
auto input_dims = param.x->dims();
auto paddings = *param.paddings;
auto strides = param.strides;
auto* input_image = param.x->data<float, cl::Image2D>();
auto* filter_image = filter_gpu_image_.data<float, cl::Image2D>();
auto filter_dims = param.filter->dims();
auto output_dims = param.output->dims();
int input_width = input_dims[3];
int input_height = input_dims[2];
int output_width = output_dims[3];
int output_height = output_dims[2];
int filter_width = filter_dims[3];
int filter_height = filter_dims[2];
auto out_image_shape = InitImageDimInfoWith(output_dims);
auto* out_image = param.output->mutable_data<float, cl::Image2D>(
out_image_shape["width"], out_image_shape["height"]);
const bool has_bias = param.bias != nullptr;
const bool is_element_wise_bias =
has_bias && param.output->dims() == param.bias->dims();
int offset = static_cast<int>(param.filter->dims()[2]) / 2 -
static_cast<int>(paddings[0]);
// calc input_c_block
auto input_image_shape = InitImageDimInfoWith(input_dims);
int input_c_block = input_image_shape["width"] / input_dims[3];
int input_c = input_dims[1];
auto dilations = *param.dilations;
const std::vector<size_t>& default_work_size =
DefaultWorkSize(output_dims,
DDim(std::vector<DDim::value_type>{
static_cast<int64_t>(out_image_shape["width"]),
static_cast<int64_t>(out_image_shape["height"])}));
int c_block = default_work_size[0];
int w = default_work_size[1];
int nh = default_work_size[2];
VLOG(4) << "============ depthwise conv2d params ============";
VLOG(4) << "input_image_shape: " << input_image_shape["width"] << ","
<< input_image_shape["height"];
VLOG(4) << "input_c_block: " << input_c_block;
VLOG(4) << "input_c: " << input_c;
VLOG(4) << "input_image: " << input_image;
VLOG(4) << "filter_dims: " << filter_dims;
VLOG(4) << "filter_image: " << filter_image;
VLOG(4) << "output_dims: " << output_dims;
VLOG(4) << "out_image_shape: " << out_image_shape["width"] << ", "
<< out_image_shape["height"];
VLOG(4) << "paddings: " << paddings[0] << "," << paddings[1];
VLOG(4) << "has bias: " << has_bias;
VLOG(4) << "is_element_wise_bias : " << is_element_wise_bias;
VLOG(4) << "strides: " << strides[0] << "," << strides[1];
VLOG(4) << "offset: " << offset;
VLOG(4) << "dilations.size : " << dilations.size();
VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1];
VLOG(4) << "default work size{c_block, w, nh}: "
<< "{" << c_block << ", " << w << ", " << nh << ""
<< "}";
CHECK_GE(dilations.size(), 2);
CHECK(dilations[0] == dilations[1]);
CHECK_GE(input_dims.size(), 4);
CHECK_GE(paddings.size(), 2);
CHECK(paddings[0] == paddings[1]);
CHECK_GE(strides.size(), 2);
CHECK(strides[0] == strides[1]);
// handle bias use buffer for channel wise , use image for element wise
const cl::Buffer* bias_buf = nullptr;
const cl::Image2D* bias_image = nullptr;
if (has_bias) {
bias_image = bias_gpu_image_.data<float, cl::Image2D>();
}
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
STL::stringstream kernel_key;
kernel_key << kernel_func_names_[0] << build_options_[0];
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
VLOG(4) << "kernel_key: " << kernel_key.str();
VLOG(4) << "kernel ready ... " << kernel_key.str();
VLOG(4) << "w: " << w;
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, c_block);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, w);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, nh);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *filter_image);
CL_CHECK_FATAL(status);
if (has_bias) {
VLOG(4) << "set bias_image: ";
status = kernel.setArg(++arg_idx, *bias_image);
CL_CHECK_FATAL(status);
}
status = kernel.setArg(++arg_idx, *out_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, strides[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, offset);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_c_block);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, dilations[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_height);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_height);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, filter_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, filter_height);
CL_CHECK_FATAL(status);
auto global_work_size =
cl::NDRange{static_cast<size_t>(default_work_size.data()[0]),
static_cast<size_t>(default_work_size.data()[1]),
static_cast<size_t>(default_work_size.data()[2])};
VLOG(4) << "out_image: " << out_image;
VLOG(4) << "global_work_size[3D]: {" << global_work_size[0] << ","
<< global_work_size[1] << "," << global_work_size[2] << "}";
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
event_.get());
CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_image, event_);
}
void ConvImageCompute::Run() { (this->*impl_)(); }
} // namespace opencl
} // namespace kernels
} // namespace lite
......
......@@ -42,16 +42,11 @@ class IoCopyHostToOpenCLCompute
CHECK(param.x->target() == TARGET(kHost) ||
param.x->target() == TARGET(kARM));
auto mem_size = param.x->memory_size();
VLOG(4) << "copy size " << mem_size;
VLOG(4) << "param.x->dims().size():" << param.x->dims().size();
VLOG(4) << "param.x->dims():" << param.x->dims()[0] << " "
<< param.x->dims()[1] << " " << param.x->dims()[2] << " "
<< param.x->dims()[3];
VLOG(4) << "param.y->dims().size():" << param.y->dims().size();
VLOG(4) << "param.y->dims():" << param.y->dims()[0] << " "
<< param.y->dims()[1] << " " << param.y->dims()[2] << " "
<< param.y->dims()[3];
VLOG(2) << "param.x->memory_size():" << mem_size;
VLOG(2) << "param.x->dims().size():" << param.x->dims().size();
VLOG(2) << "param.x->dims():" << param.x->dims();
VLOG(2) << "param.y->dims().size():" << param.y->dims().size();
VLOG(2) << "param.y->dims():" << param.y->dims();
auto* data = param.y->mutable_data(TARGET(kOpenCL), mem_size);
CopyFromHostSync(data, param.x->raw_data(), mem_size);
}
......@@ -89,23 +84,27 @@ class IoCopykOpenCLToHostCompute
auto& param = Param<operators::IoCopyParam>();
CHECK(param.x->target() == TARGET(kOpenCL));
auto mem_size = param.x->memory_size();
VLOG(4) << "copy size " << mem_size;
VLOG(4) << "param.x->dims().size():" << param.x->dims().size();
VLOG(4) << "param.x->dims():" << param.x->dims()[0] << " "
<< param.x->dims()[1] << " " << param.x->dims()[2] << " "
<< param.x->dims()[3];
VLOG(4) << "param.y->dims().size():" << param.y->dims().size();
VLOG(4) << "param.y->dims():" << param.y->dims()[0] << " "
<< param.y->dims()[1] << " " << param.y->dims()[2] << " "
<< param.y->dims()[3];
VLOG(2) << "copy size " << mem_size;
VLOG(2) << "param.x->dims().size():" << param.x->dims().size();
VLOG(2) << "param.x->dims():" << param.x->dims();
VLOG(2) << "param.y->dims().size():" << param.y->dims().size();
VLOG(2) << "param.y->dims():" << param.y->dims();
VLOG(2) << "param.process_type:" << param.process_type;
auto* data = param.y->mutable_data(TARGET(kHost), mem_size);
const cl::Buffer* x_ptr;
if (param.process_type == 1) {
x_ptr = param.x->data<uint8_t, cl::Buffer>();
} else {
x_ptr = param.x->data<float, cl::Buffer>();
}
auto& context = ctx_->As<OpenCLContext>();
auto* wait_list = context.cl_wait_list();
auto* x_ptr = param.x->data<float, cl::Buffer>();
auto it = wait_list->find(x_ptr);
if (it != wait_list->end()) {
VLOG(4) << "--- Find the sync event for the target cl tensor. ---";
VLOG(2) << "--- Find the sync event for the target cl tensor. ---";
auto& event = *(it->second);
event.wait();
} else {
......
......@@ -42,6 +42,7 @@ class LayoutComputeBufferChwToImageDefault
if (param.process_type == 1) {
kernel_func_name_ = "buffer_to_image2d_with_pre255";
}
VLOG(2) << "kernel_func_name_:" << kernel_func_name_;
auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel(
kernel_func_name_, "image/layout_kernel.cl", build_options_);
......@@ -73,20 +74,21 @@ class LayoutComputeBufferChwToImageDefault
const int Stride1 = out_H * out_W;
const int Stride0 = out_W;
VLOG(4) << "y image_shape(w,h):" << image_shape["width"] << " "
<< image_shape["height"];
VLOG(4) << "x_dims[" << x_dims.size() << "D]:" << x_dims[0] << " "
<< x_dims[1] << " " << x_dims[2] << " " << x_dims[3];
VLOG(4) << "y_dims[" << y_dims.size() << "D]:" << y_dims[0] << " "
<< y_dims[1] << " " << y_dims[2] << " " << y_dims[3];
VLOG(4) << "new_dims[" << new_dims.size() << "D]:" << new_dims[0] << " "
VLOG(2) << "param.process_type:" << param.process_type;
VLOG(2) << "x_dims:" << x_dims;
VLOG(2) << "param.x->memory_size():" << param.x->memory_size();
VLOG(2) << "new_dims[" << new_dims.size() << "D]:" << new_dims[0] << " "
<< new_dims[1] << " " << new_dims[2] << " " << new_dims[3];
VLOG(4) << "out_C:" << out_C;
VLOG(4) << "out_H:" << out_H;
VLOG(4) << "out_W:" << out_W;
VLOG(4) << "Stride2:" << Stride2;
VLOG(4) << "Stride1:" << Stride1;
VLOG(4) << "Stride0:" << Stride0;
VLOG(2) << "y_dims:" << y_dims;
VLOG(2) << "param.y->memory_size():" << param.y->memory_size();
VLOG(2) << "y image_shape(w,h):" << image_shape["width"] << " "
<< image_shape["height"];
VLOG(2) << "out_C:" << out_C;
VLOG(2) << "out_H:" << out_H;
VLOG(2) << "out_W:" << out_W;
VLOG(2) << "Stride2:" << Stride2;
VLOG(2) << "Stride1:" << Stride1;
VLOG(2) << "Stride0:" << Stride0;
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
......@@ -112,7 +114,7 @@ class LayoutComputeBufferChwToImageDefault
status = kernel.setArg(++arg_idx, static_cast<const int>(Stride2));
CL_CHECK_FATAL(status);
VLOG(4) << "gws:[3D]" << ((new_dims[1] + 3) / 4) << " " << new_dims[3]
VLOG(2) << "gws:[3D]" << ((new_dims[1] + 3) / 4) << " " << new_dims[3]
<< " " << (new_dims[0] * new_dims[2]);
auto global_work_size =
cl::NDRange{static_cast<cl::size_type>((new_dims[1] + 3) / 4),
......@@ -151,6 +153,7 @@ class LayoutComputeImageDefaultToBufferChw
if (param.process_type == 1) {
kernel_func_name_ = "image2d_to_buffer_with_post255";
}
VLOG(2) << "kernel_func_name_:" << kernel_func_name_;
auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel(
kernel_func_name_, "image/layout_kernel.cl", build_options_);
......@@ -174,14 +177,15 @@ class LayoutComputeImageDefaultToBufferChw
new_dims[4 - x_dims.size() + j] = x_dims[j];
}
VLOG(4) << "x_image_shape(w,h):" << x_image_shape["width"] << " "
VLOG(2) << "param.process_type:" << param.process_type;
VLOG(2) << "x_dims:" << x_dims;
VLOG(2) << "param.x->memory_size():" << param.x->memory_size();
VLOG(2) << "x_image_shape(w,h):" << x_image_shape["width"] << " "
<< x_image_shape["height"];
VLOG(4) << "x_dims[" << x_dims.size() << "D]:" << x_dims[0] << " "
<< x_dims[1] << " " << x_dims[2] << " " << x_dims[3];
VLOG(4) << "y_dims[" << y_dims.size() << "D]:" << y_dims[0] << " "
<< y_dims[1] << " " << y_dims[2] << " " << y_dims[3];
VLOG(4) << "new_dims[" << new_dims.size() << "D]:" << new_dims[0] << " "
VLOG(2) << "new_dims[" << new_dims.size() << "D]:" << new_dims[0] << " "
<< new_dims[1] << " " << new_dims[2] << " " << new_dims[3];
VLOG(2) << "y_dims:" << y_dims;
VLOG(2) << "param.y->memory_size():" << param.y->memory_size();
size_t C = new_dims[1];
size_t in_height = new_dims[2];
......@@ -213,7 +217,7 @@ class LayoutComputeImageDefaultToBufferChw
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(C));
CL_CHECK_FATAL(status);
VLOG(4) << "gws:[3D]" << ((new_dims[1] + 3) / 4) << " " << new_dims[3]
VLOG(2) << "gws:[3D]" << ((new_dims[1] + 3) / 4) << " " << new_dims[3]
<< " " << (new_dims[0] * new_dims[2]);
auto global_work_size =
cl::NDRange{static_cast<cl::size_type>((new_dims[1] + 3) / 4),
......@@ -307,7 +311,7 @@ class LayoutComputeBufferChwToImage2DNw
status = kernel.setArg(++arg_idx, static_cast<const int>(Stride2));
CL_CHECK_FATAL(status);
VLOG(4) << "gws:[3D]" << ((out_N + 3) / 4) << " " << out_W << " "
VLOG(2) << "gws:[3D]" << ((out_N + 3) / 4) << " " << out_W << " "
<< (out_C * out_H);
auto global_work_size =
cl::NDRange{static_cast<cl::size_type>((out_N + 3) / 4), // N blocks
......
......@@ -35,6 +35,9 @@ bool IoCopyOp::AttachImpl(const cpp::OpDesc &opdesc,
auto out = opdesc.Output("Out").front();
param_.x = GetTensor(scope, x);
param_.y = GetMutableTensor(scope, out);
if (opdesc.HasAttr("process_type")) {
param_.process_type = opdesc.GetAttr<int>("process_type");
}
return true;
}
std::string IoCopyOp::DebugString() const { return "io_copy_op"; }
......
......@@ -57,6 +57,7 @@ struct FetchParam {
struct IoCopyParam {
const lite::Tensor* x{};
lite::Tensor* y{};
int process_type{0};
};
struct LayoutParam {
......
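
Taken together, process_type is the single switch threaded through these files: the preprocess pass sets it to 1 on layout and io_copy ops, IoCopyOp::AttachImpl copies it into IoCopyParam, and the kernels then pick the uint8 buffer plus the pre255/post255 layout kernels. The standalone sketch below models that dispatch with simplified stand-ins (IoCopyParamSketch and PickLayoutKernel are illustrative names, not framework classes).

#include <cstdint>
#include <iostream>
#include <string>

// Simplified stand-in for operators::IoCopyParam: process_type defaults to 0
// and is overwritten from the op attribute when the pass has tagged the op.
struct IoCopyParamSketch {
  int process_type{0};
};

// Mirrors the kernel-name selection done in the layout computes above.
std::string PickLayoutKernel(bool to_image, int process_type) {
  if (to_image) {
    return process_type == 1 ? "buffer_to_image2d_with_pre255"
                             : "buffer_to_image2d";
  }
  return process_type == 1 ? "image2d_to_buffer_with_post255"
                           : "image2d_to_buffer";
}

int main() {
  IoCopyParamSketch param;
  param.process_type = 1;  // what type_layout_cast_preprocess_pass sets
  std::cout << PickLayoutKernel(/*to_image=*/true, param.process_type) << "\n"
            << PickLayoutKernel(/*to_image=*/false, param.process_type) << "\n";
  return 0;
}
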