diff --git a/lite/api/cxx_api_impl.cc b/lite/api/cxx_api_impl.cc index 967decc4a1e164e4c22181c4d314f07fbb9511e9..972210c8f9ea05ba1b041382c43efad64aeacc1b 100644 --- a/lite/api/cxx_api_impl.cc +++ b/lite/api/cxx_api_impl.cc @@ -38,11 +38,13 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) { std::vector passes{}; auto use_layout_preprocess_pass = config.model_dir().find("OPENCL_PRE_PRECESS"); - if (use_layout_preprocess_pass != std::string::npos) { + VLOG(1) << "use_layout_preprocess_pass:" << use_layout_preprocess_pass; + if (places[0].target == TARGET(kOpenCL) && + use_layout_preprocess_pass != std::string::npos) { passes = {"type_layout_cast_preprocess_pass"}; + VLOG(1) << "add pass:" << passes[0]; } raw_predictor_.Build(config, places, passes); - mode_ = config.power_mode(); threads_ = config.threads(); diff --git a/lite/api/paddle_api.cc b/lite/api/paddle_api.cc index 9f071cf7780e27defdd1fcd6be02844618165fb6..2cb2064da518bca442e882d0733c5c6966c4fac0 100644 --- a/lite/api/paddle_api.cc +++ b/lite/api/paddle_api.cc @@ -38,6 +38,7 @@ void Tensor::Resize(const shape_t &shape) { tensor(raw_tensor_)->Resize(shape); } +// Tensor::data template <> const float *Tensor::data() const { return ctensor(raw_tensor_)->data(); @@ -47,15 +48,19 @@ const int8_t *Tensor::data() const { return ctensor(raw_tensor_)->data(); } template <> +const uint8_t *Tensor::data() const { + return ctensor(raw_tensor_)->data(); +} +template <> const int64_t *Tensor::data() const { return ctensor(raw_tensor_)->data(); } - template <> const int32_t *Tensor::data() const { return ctensor(raw_tensor_)->data(); } +// Tensor::mutable_data template <> int *Tensor::mutable_data(TargetType type) const { return tensor(raw_tensor_)->mutable_data(type); @@ -69,6 +74,10 @@ int8_t *Tensor::mutable_data(TargetType type) const { return tensor(raw_tensor_)->mutable_data(type); } template <> +uint8_t *Tensor::mutable_data(TargetType type) const { + return tensor(raw_tensor_)->mutable_data(type); +} +template <> int64_t *Tensor::mutable_data(TargetType type) const { return tensor(raw_tensor_)->mutable_data(type); } @@ -116,18 +125,22 @@ void Tensor::CopyToCpu(T *data) const { template void Tensor::CopyFromCpu(const int *); template void Tensor::CopyFromCpu(const float *); template void Tensor::CopyFromCpu(const int8_t *); +template void Tensor::CopyFromCpu(const uint8_t *); template void Tensor::CopyFromCpu(const int *); template void Tensor::CopyFromCpu(const float *); template void Tensor::CopyFromCpu(const int8_t *); +template void Tensor::CopyFromCpu(const uint8_t *); + template void Tensor::CopyFromCpu(const int *); template void Tensor::CopyFromCpu(const int64_t *); template void Tensor::CopyFromCpu(const float *); template void Tensor::CopyFromCpu(const int8_t *); -template void Tensor::CopyToCpu(int8_t *) const; template void Tensor::CopyToCpu(float *) const; template void Tensor::CopyToCpu(int *) const; +template void Tensor::CopyToCpu(int8_t *) const; +template void Tensor::CopyToCpu(uint8_t *) const; shape_t Tensor::shape() const { return ctensor(raw_tensor_)->dims().Vectorize(); diff --git a/lite/backends/opencl/cl_kernel/image/layout_kernel.cl b/lite/backends/opencl/cl_kernel/image/layout_kernel.cl index 65f4e9fb03164635707596df8ab0c0bef7c95d9c..6c419fe3c134614d28b3bcee3eabac5e8f7bdf6e 100644 --- a/lite/backends/opencl/cl_kernel/image/layout_kernel.cl +++ b/lite/backends/opencl/cl_kernel/image/layout_kernel.cl @@ -15,7 +15,9 @@ limitations under the License. */ #include // #define DEBUG +//////////////////////////////////////////////////////// // buffer -> image2d +//////////////////////////////////////////////////////// __kernel void buffer_to_image2d(__global CL_DTYPE *in, __write_only image2d_t output_image, __private const int out_H, @@ -80,8 +82,9 @@ __kernel void buffer_to_image2d(__global CL_DTYPE *in, WRITE_IMG_TYPE(CL_COMPUTE_DTYPE_CHAR, output_image, output_pos, output); } - +//////////////////////////////////////////////////////// // image2d -> buffer +//////////////////////////////////////////////////////// __kernel void image2d_to_buffer(__read_only image2d_t input, __private const int in_width, __private const int in_height, @@ -125,8 +128,10 @@ __kernel void image2d_to_buffer(__read_only image2d_t input, } -#if 0 +#if 0 // NOTE(ysh329): keep, un-used from paddle-mobile +//////////////////////////////////////////////////////// // buffer -> image2d_nw +//////////////////////////////////////////////////////// __kernel void buffer_to_image2d_nw(__global CL_DTYPE* in, __write_only image2d_t output_image, __private const int out_H, @@ -178,7 +183,7 @@ __kernel void buffer_to_image2d_nw(__global CL_DTYPE* in, #endif -#if 0 +#if 0 // NOTE(ysh329): keep, un-used from paddle-mobile // image2d -> buffer __kernel void image2d_to_buffer_2d(__private const int in_height, __private const int in_width, @@ -200,7 +205,9 @@ __kernel void image2d_to_buffer_2d(__private const int in_height, } #endif +//////////////////////////////////////////////////////// // buffer -> image2d (divide by 255 to normalize) +//////////////////////////////////////////////////////// __kernel void buffer_to_image2d_with_pre255(__global uchar *in, __write_only image2d_t output_image, __private const int out_H, @@ -248,7 +255,10 @@ __kernel void buffer_to_image2d_with_pre255(__global uchar *in, WRITE_IMG_TYPE(CL_COMPUTE_DTYPE_CHAR, output_image, output_pos, output); } + +//////////////////////////////////////////////////////// // image2d -> buffer (multiply by 255 to de-normalize) +//////////////////////////////////////////////////////// __kernel void image2d_to_buffer_with_post255(__read_only image2d_t input, __private const int in_width, __private const int in_height, @@ -267,17 +277,22 @@ __kernel void image2d_to_buffer_with_post255(__read_only image2d_t input, CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; const int pos_x = mad24(in_c, in_width, in_w); - CL_COMPUTE_DTYPE4 in = READ_IMG_TYPE(CL_COMPUTE_DTYPE_CHAR, input, sampler, (int2)(pos_x, in_nh)); + CL_COMPUTE_DTYPE4 in = READ_IMG_TYPE(CL_COMPUTE_DTYPE_CHAR, input, sampler, (int2)(pos_x, in_nh)) * 255; + +#ifdef DEBUG + printf("in_c:%d, in_w:%d, in_nh:%d ===> in(%d,%d): %.2f %.2f %.2f %.2f\n", + in_c, in_w, in_nh, pos_x, in_nh, in.x, in.y, in.z, in.w); +#endif const int index = in_n * size_batch + in_c * size_block + in_h * in_width + in_w; - out[index] = convert_uchar_sat(in.x * 255); + out[index] = convert_uchar_sat(in.x); if(C - 4 * in_c>=2){ - out[index + size_ch] = convert_uchar_sat(in.y * 255); + out[index + size_ch] = convert_uchar_sat(in.y); } if(C - 4 * in_c>=3){ - out[index + size_ch * 2] = convert_uchar_sat(in.z * 255); + out[index + size_ch * 2] = convert_uchar_sat(in.z); } if(C - 4 * in_c>=4){ - out[index + size_ch * 3] = convert_uchar_sat(in.w * 255); + out[index + size_ch * 3] = convert_uchar_sat(in.w); } } diff --git a/lite/core/mir/type_layout_cast_pass.cc b/lite/core/mir/type_layout_cast_pass.cc index 5625a34e93c8301e7df8736b5cab26d7b88d7e2a..6cf03ee3b5dd5c1d497c12797bbba631c87480f1 100644 --- a/lite/core/mir/type_layout_cast_pass.cc +++ b/lite/core/mir/type_layout_cast_pass.cc @@ -217,7 +217,9 @@ void OpenCLTypeLayoutTransformPass::Apply( for (auto& node : nodes) { VLOG(4) << "!node->IsStmt():" << !node->IsStmt(); if (!node->IsStmt() || node->AsStmt().op_type() == "while") continue; - if (node->AsStmt().op_type() == "layout") { + VLOG(1) << "node->AsStmt().op_type():" << node->AsStmt().op_type(); + if (node->AsStmt().op_type() == "layout" || + node->AsStmt().op_type() == "io_copy") { auto new_op = node->AsStmt().mutable_op_info(); int process_type = 1; new_op->SetAttr("process_type", process_type); diff --git a/lite/kernels/opencl/conv_buffer_compute.cc b/lite/kernels/opencl/conv_buffer_compute.cc index 502b853fa003c0b8d94dd3c9720204664391c26e..0fa607c938fc30efef06d34ec05640069dc3aeea 100644 --- a/lite/kernels/opencl/conv_buffer_compute.cc +++ b/lite/kernels/opencl/conv_buffer_compute.cc @@ -297,1135 +297,6 @@ void ConvCompute::GemmBatched(cl::Kernel& kernel, void ConvCompute::Run() { (this->*impl_)(); } -/* image kernel*/ -void ConvImageCompute::PrepareForRun() { - const auto& param = this->Param(); - auto x_dims = param.x->dims(); - auto filter_dims = param.filter->dims(); - auto output_dims = param.output->dims(); - - float* filter_cpu = param.filter->mutable_data(); - auto& context = ctx_->As(); - CHECK(context.cl_context() != nullptr); - - int bs = x_dims[0]; - int c_in = x_dims[1]; - int h_out = output_dims[2]; - int w_out = output_dims[3]; - int kernel_h = filter_dims[2]; // oihw - int kernel_w = filter_dims[3]; - auto paddings = *param.paddings; - auto dilations = *param.dilations; - int stride_h = param.strides[0]; - int stride_w = param.strides[1]; - int pad_h = paddings[0]; - int pad_w = paddings[2]; - int groups = param.groups; - bool relu_fused = param.fuse_relu; - bool no_dilation = (dilations[0] == 1) && (dilations[1] == 1); - bool zero_pad = (pad_h == 0) && (pad_w == 0); - - bool pad_equal = - ((paddings[0] == paddings[1]) && (paddings[1] == paddings[2]) && - (paddings[2] == paddings[3])); - bool stride_equal = stride_h == stride_w; - bool dilation_equal = dilations[0] == dilations[1]; - - CHECK(pad_equal && stride_equal && dilation_equal); - - VLOG(3) << "Is relu fused? / " << (relu_fused ? "Yes" : "No"); - VLOG(3) << "groups:" << groups << " stride_h:" << stride_h - << " stride_w:" << stride_w << " pad_h:" << pad_h - << " pad_w:" << pad_w << " kernel_h:" << kernel_h - << " kernel_h:" << kernel_h; - VLOG(3) << "x_dims:" << x_dims[0] << " " << x_dims[1] << " " << x_dims[2] - << " " << x_dims[3]; - VLOG(3) << "output_dims:" << output_dims[0] << " " << output_dims[1] << " " - << output_dims[2] << " " << output_dims[3]; - VLOG(3) << "filter_dims:" << filter_dims[0] << " " << filter_dims[1] << " " - << filter_dims[2] << " " << filter_dims[3]; - if (kernel_h == 1 && kernel_w == 1) { - // conv2d_1x1 - if (param.x->dims()[1] % 4 == 0) { - kernel_func_names_.push_back("conv2d_1x1_simple"); - } else { - kernel_func_names_.push_back("conv2d_1x1"); - } - kernel_func_paths_.push_back("image/conv2d_1x1_kernel.cl"); - - CLImageConverterNWBlock converter; - const DDim& filter_image_dims = converter.InitImageDimInfoWith(filter_dims); - std::vector filter_image_v(filter_image_dims[0] * - filter_image_dims[1] * 4); // 4 : RGBA - converter.NCHWToImage(filter_cpu, filter_image_v.data(), filter_dims); - filter_gpu_image_.mutable_data( - filter_image_dims[0], filter_image_dims[1], filter_image_v.data()); - - impl_ = &ConvImageCompute::Conv2d1x1; -#if 1 // TODO(ysh329): enable general dwconv - } else if (filter_dims[1] == 1 && x_dims[1] == output_dims[1]) { -#else // TODO(ysh329): remove dwconv3x3s1 and dwconv3x3 temporarily, need fix - } else if (filter_dims[1] == 1 && x_dims[1] == output_dims[1] && - kernel_h == 3 && kernel_w == 3 && groups > 1) { - // depth_conv2d_3x3s1, depth_conv2d_3x3 - if (stride_h == 1 && dilations[0] == 1) { - kernel_func_names_.push_back("depth_conv2d_3x3s1"); - impl_ = &ConvImageCompute::DepthwiseConv2d3x3s1; - } else { - kernel_func_names_.push_back("depth_conv2d_3x3"); - impl_ = &ConvImageCompute::DepthwiseConv2d3x3; - } - kernel_func_paths_.push_back("image/depthwise_conv2d_kernel.cl"); - - CLImageConverterNWBlock converter; - const DDim& filter_image_dims = converter.InitImageDimInfoWith(filter_dims); - std::vector filter_image_v(filter_image_dims[0] * - filter_image_dims[1] * 4); // 4 : RGBA - converter.NCHWToImage(filter_cpu, filter_image_v.data(), filter_dims); - filter_gpu_image_.mutable_data( - filter_image_dims[0], filter_image_dims[1], filter_image_v.data()); - } else if (filter_dims[1] == 1 && x_dims[1] == output_dims[1] && - kernel_h != 3) { -#endif - // depth_conv2d - kernel_func_names_.push_back("depth_conv2d"); - kernel_func_paths_.push_back("image/depthwise_conv2d_basic_kernel.cl"); - - CLImageConverterNWBlock converter; - const DDim& filter_image_dims = converter.InitImageDimInfoWith(filter_dims); - std::vector filter_image_v(filter_image_dims[0] * - filter_image_dims[1] * 4); // 4 : RGBA - converter.NCHWToImage(filter_cpu, filter_image_v.data(), filter_dims); - filter_gpu_image_.mutable_data( - filter_image_dims[0], filter_image_dims[1], filter_image_v.data()); - - impl_ = &ConvImageCompute::DepthwiseConv2d; - } else if (kernel_h == 3 && kernel_h == 3) { - // conv2d_3x3 - kernel_func_names_.push_back("conv2d_3x3"); - kernel_func_paths_.push_back("image/conv2d_3x3_kernel.cl"); - - CLImageConverterFolder converter; - const DDim& filter_image_dims = converter.InitImageDimInfoWith(filter_dims); - std::vector filter_image_v(filter_image_dims[0] * - filter_image_dims[1] * 4); // 4 : RGBA - converter.NCHWToImage(filter_cpu, filter_image_v.data(), filter_dims); - filter_gpu_image_.mutable_data( - filter_image_dims[0], filter_image_dims[1], filter_image_v.data()); - - impl_ = &ConvImageCompute::Conv2d3x3; - } else if (kernel_h == 5 && kernel_w == 5) { - // conv2d_5x5 - kernel_func_names_.push_back("conv2d_5x5"); - kernel_func_paths_.push_back("image/conv2d_5x5_kernel.cl"); - - CLImageConverterFolder converter; - const DDim& filter_image_dims = converter.InitImageDimInfoWith(filter_dims); - std::vector filter_image_v(filter_image_dims[0] * - filter_image_dims[1] * 4); // 4 : RGBA - converter.NCHWToImage(filter_cpu, filter_image_v.data(), filter_dims); - filter_gpu_image_.mutable_data( - filter_image_dims[0], filter_image_dims[1], filter_image_v.data()); - - impl_ = &ConvImageCompute::Conv2d5x5; - } else if (kernel_h == 7 && kernel_w == 7) { - // conv2d_7x7 - kernel_func_names_.push_back("conv2d_7x7"); - kernel_func_paths_.push_back("image/conv2d_7x7_kernel.cl"); - - CLImageConverterFolder converter; - const DDim& filter_image_dims = converter.InitImageDimInfoWith(filter_dims); - std::vector filter_image_v(filter_image_dims[0] * - filter_image_dims[1] * 4); // 4 : RGBA - converter.NCHWToImage(filter_cpu, filter_image_v.data(), filter_dims); - this->filter_gpu_image_.mutable_data( - filter_image_dims[0], filter_image_dims[1], filter_image_v.data()); - - impl_ = &ConvImageCompute::Conv2d7x7; - } else { - LOG(FATAL) << "conv image compute not support this condition yet! "; - } - VLOG(1) << "kernel_func_names_[0]:" << kernel_func_names_[0] - << " kernel_func_paths_[0]:" << kernel_func_paths_[0]; - - std::string build_options_single(" -DCL_DTYPE_float"); - // relu options - if (relu_fused) { - build_options_single += " -DRELU"; - } else if (param.activation_param.active_type == - lite_api::ActivationType::kRelu6) { - build_options_single += " -DRELU6"; - } else { - // do nothing - } - // bias options - const bool has_bias = param.bias != nullptr; - const bool is_element_wise_bias = - has_bias && param.output->dims() == param.bias->dims(); - if (has_bias) { - build_options_single += - is_element_wise_bias ? " -DBIASE_ELE" : " -DBIASE_CH"; - - // convert cpu buffer bias --> gpu image - CLImageConverterFolder bias_converter; - const DDim& bias_image_dims = - bias_converter.InitImageDimInfoWith(param.bias->dims()); - std::vector bias_image_v(bias_image_dims[0] * bias_image_dims[1] * - 4); - float* bias_cpu_data = param.bias->mutable_data(); - bias_converter.NCHWToImage( - bias_cpu_data, bias_image_v.data(), param.bias->dims()); - this->bias_gpu_image_.mutable_data( - bias_image_dims[0], bias_image_dims[1], bias_image_v.data()); - // convert cpu buffer bias --> gpu image --- end ---- - } - - build_options_.push_back(build_options_single); - - for (size_t i = 0; i < kernel_func_names_.size(); i++) { - context.cl_context()->AddKernel( - kernel_func_names_[i], kernel_func_paths_[i], build_options_[i]); - } -} - -void ConvImageCompute::Conv2d1x1() { - const auto& param = *param_.get_mutable(); - auto input_dims = param.x->dims(); - auto paddings = *param.paddings; - auto strides = param.strides; - auto* input_image = param.x->data(); - auto* filter_image = filter_gpu_image_.data(); - auto filter_dims = param.filter->dims(); - auto output_dims = param.output->dims(); - - int input_width = input_dims[3]; - int input_height = input_dims[2]; - int output_width = output_dims[3]; - int output_height = output_dims[2]; - auto out_image_shape = InitImageDimInfoWith(output_dims); - auto* out_image = param.output->mutable_data( - out_image_shape["width"], out_image_shape["height"]); - - const bool has_bias = param.bias != nullptr; - const bool is_element_wise_bias = - has_bias && param.output->dims() == param.bias->dims(); - int offset = static_cast(param.filter->dims()[2]) / 2 - - static_cast(paddings[0]); - - // calc input_c_block - auto input_image_shape = InitImageDimInfoWith(input_dims); - int input_c_block = input_image_shape["width"] / input_dims[3]; - int input_c = input_dims[1]; - auto dilations = *param.dilations; - - const std::vector& default_work_size = - DefaultWorkSize(output_dims, - DDim(std::vector{ - static_cast(out_image_shape["width"]), - static_cast(out_image_shape["height"])})); - - int c_block = default_work_size[0]; - int w = default_work_size[1]; - int nh = default_work_size[2]; - - VLOG(4) << "============ conv2d_1x1 params ============"; - VLOG(4) << "input_image_shape: " << input_image_shape["width"] << "," - << input_image_shape["height"]; - VLOG(4) << "input_c_block: " << input_c_block; - VLOG(4) << "input_c: " << input_c; - VLOG(4) << "input_image: " << input_image; - VLOG(4) << "filter_dims: " << filter_dims; - VLOG(4) << "filter_image: " << filter_image; - VLOG(4) << "output_dims: " << output_dims; - VLOG(4) << "out_image_shape: " << out_image_shape["width"] << ", " - << out_image_shape["height"]; - VLOG(4) << "paddings: " << paddings[0] << "," << paddings[1]; - VLOG(4) << "has bias: " << has_bias; - VLOG(4) << "is_element_wise_bias : " << is_element_wise_bias; - VLOG(4) << "strides: " << strides[0] << "," << strides[1]; - VLOG(4) << "offset: " << offset; - VLOG(4) << "dilations.size : " << dilations.size(); - VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1]; - VLOG(4) << "default work size{c_block, w, nh}: " - << "{" << c_block << ", " << w << ", " << nh << "" - << "}"; - - CHECK_GE(dilations.size(), 2); - CHECK(dilations[0] == dilations[1]); - CHECK_GE(input_dims.size(), 4); - CHECK_GE(paddings.size(), 2); - CHECK(paddings[0] == paddings[1]); - CHECK_GE(strides.size(), 2); - CHECK(strides[0] == strides[1]); - - // handle bias use buffer for channel wise , use image for element wise - const cl::Buffer* bias_buf = nullptr; - const cl::Image2D* bias_image = nullptr; - if (has_bias) { - bias_image = bias_gpu_image_.data(); - } - - auto& context = ctx_->As(); - CHECK(context.cl_context() != nullptr); - std::stringstream kernel_key; - kernel_key << kernel_func_names_[0] << build_options_[0]; - auto kernel = context.cl_context()->GetKernel(kernel_key.str()); - int maped_w = maptofactor(w, 4); - - VLOG(4) << "kernel_key: " << kernel_key.str(); - VLOG(4) << "kernel ready ... " << kernel_key.str(); - VLOG(4) << "maped_w: " << maped_w; - VLOG(4) << "hasbias: " << has_bias; - - cl_int status; - int arg_idx = 0; - status = kernel.setArg(arg_idx, c_block); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, maped_w); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, nh); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *input_image); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *filter_image); - CL_CHECK_FATAL(status); - if (has_bias) { - status = kernel.setArg(++arg_idx, *bias_image); - CL_CHECK_FATAL(status); - } - status = kernel.setArg(++arg_idx, *out_image); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, strides[0]); - CL_CHECK_FATAL(status); - - status = kernel.setArg(++arg_idx, offset); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_c_block); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_c); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, dilations[0]); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_height); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, output_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, output_height); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, w); - CL_CHECK_FATAL(status); - - auto global_work_size = - cl::NDRange{static_cast(default_work_size.data()[0]), - static_cast(maped_w), - static_cast(default_work_size.data()[2])}; - - VLOG(4) << "out_image: " << out_image; - VLOG(4) << "global_work_size[3D]: {" << global_work_size[0] << "," - << global_work_size[1] << "," << global_work_size[2] << "}"; - - status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( - kernel, - cl::NullRange, - global_work_size, - cl::NullRange, - nullptr, - event_.get()); - CL_CHECK_FATAL(status); - context.cl_wait_list()->emplace(out_image, event_); -} - -void ConvImageCompute::Conv2d3x3() { - const auto& param = *param_.get_mutable(); - auto input_dims = param.x->dims(); - auto paddings = *param.paddings; - auto strides = param.strides; - - auto* input_image = param.x->data(); - auto* filter_image = filter_gpu_image_.data(); - auto filter_dims = param.filter->dims(); - auto output_dims = param.output->dims(); - - int input_width = input_dims[3]; - int input_height = input_dims[2]; - int input_channel = input_dims[1]; - int output_width = output_dims[3]; - int output_height = output_dims[2]; - int output_channel = output_dims[1]; - int filter_width = filter_dims[3]; - int filter_height = filter_dims[2]; - int filter_channel = filter_dims[1]; - auto out_image_shape = InitImageDimInfoWith(output_dims); - auto* out_image = param.output->mutable_data( - out_image_shape["width"], out_image_shape["height"]); - - const bool has_bias = param.bias != nullptr; - const bool is_element_wise_bias = - has_bias && param.output->dims() == param.bias->dims(); - int offset = static_cast(param.filter->dims()[2]) / 2 - - static_cast(paddings[0]); - - // calc input_c_block - auto input_image_shape = InitImageDimInfoWith(input_dims); - int input_c_block = input_image_shape["width"] / input_dims[3]; - int input_c = input_dims[1]; - auto dilations = *param.dilations; - - // re-calc group - int new_groups{param.groups}; - if (filter_dims[0] == output_dims[1] && filter_dims[1] == input_dims[1]) { - new_groups = 1; - } else if (!(filter_dims[0] == input_dims[1] && filter_dims[1] == 1)) { - new_groups = input_channel / filter_channel; - } - /* TODO(ysh329): mobile has no case below - else { - LOG(FATAL) << "Not support conv3x3 case with" - << " input_dims:" << input_dims << " output_dims:" << - output_dims - << " filter_dims:" << filter_dims; - } - */ - - const std::vector& default_work_size = - DefaultWorkSize(output_dims, - DDim(std::vector{ - static_cast(out_image_shape["width"]), - static_cast(out_image_shape["height"])})); - - int c_block = default_work_size[0]; - int w = default_work_size[1]; - int nh = default_work_size[2]; - - VLOG(4) << "============ conv2d params ============"; - VLOG(4) << "input_image_shape: " << input_image_shape["width"] << "," - << input_image_shape["height"]; - VLOG(4) << "input_c_block: " << input_c_block; - VLOG(4) << "input_c: " << input_c; - VLOG(4) << "input_image: " << input_image; - VLOG(4) << "input_dims: " << input_dims; - VLOG(4) << "filter_dims: " << filter_dims; - VLOG(4) << "filter_image: " << filter_image; - VLOG(4) << "output_dims: " << output_dims; - VLOG(4) << "out_image_shape: " << out_image_shape["width"] << ", " - << out_image_shape["height"]; - VLOG(4) << "paddings: " << paddings[0] << "," << paddings[1]; - VLOG(4) << "has bias: " << has_bias; - VLOG(4) << "is_element_wise_bias : " << is_element_wise_bias; - VLOG(4) << "strides: " << strides[0] << "," << strides[1]; - VLOG(4) << "offset: " << offset; - VLOG(4) << "dilations.size : " << dilations.size(); - VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1]; - VLOG(4) << "param.groups(groups):" << param.groups; - VLOG(4) << "new_groups:" << new_groups; - VLOG(4) << "default work size{c_block, w, nh}: " - << "{" << c_block << ", " << w << ", " << nh << "" - << "}"; - - CHECK_GE(dilations.size(), 2); - CHECK(dilations[0] == dilations[1]); - CHECK_GE(input_dims.size(), 4); - CHECK_GE(paddings.size(), 2); - CHECK(paddings[0] == paddings[1]); - CHECK_GE(strides.size(), 2); - CHECK(strides[0] == strides[1]); - - const cl::Image2D* bias_image = nullptr; - if (has_bias) { - bias_image = bias_gpu_image_.data(); - } - - auto& context = ctx_->As(); - CHECK(context.cl_context() != nullptr); - STL::stringstream kernel_key; - kernel_key << kernel_func_names_[0] << build_options_[0]; - auto kernel = context.cl_context()->GetKernel(kernel_key.str()); - VLOG(4) << "kernel_key: " << kernel_key.str(); - VLOG(4) << "kernel ready ... " << kernel_key.str(); - VLOG(4) << "w: " << w; - - cl_int status; - int arg_idx = 0; - status = kernel.setArg(arg_idx, c_block); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, w); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, nh); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *input_image); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *filter_image); - CL_CHECK_FATAL(status); - if (has_bias) { - VLOG(4) << "set bias_image: "; - status = kernel.setArg(++arg_idx, *bias_image); - CL_CHECK_FATAL(status); - } - status = kernel.setArg(++arg_idx, *out_image); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, strides[0]); - CL_CHECK_FATAL(status); - - status = kernel.setArg(++arg_idx, offset); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_c_block); - CL_CHECK_FATAL(status); - - status = kernel.setArg(++arg_idx, dilations[0]); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_height); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, output_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, output_height); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, output_channel); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, filter_channel); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, filter_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, filter_height); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, new_groups); - CL_CHECK_FATAL(status); - - auto global_work_size = - cl::NDRange{static_cast(default_work_size.data()[0]), - static_cast(default_work_size.data()[1]), - static_cast(default_work_size.data()[2])}; - - VLOG(4) << "out_image: " << out_image; - VLOG(4) << "global_work_size[3D]: {" << global_work_size[0] << "," - << global_work_size[1] << "," << global_work_size[2] << "}"; - - status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( - kernel, - cl::NullRange, - global_work_size, - cl::NullRange, - nullptr, - event_.get()); - CL_CHECK_FATAL(status); - context.cl_wait_list()->emplace(out_image, event_); -} - -void ConvImageCompute::Conv2d5x5() { - const auto& param = *param_.get_mutable(); - auto input_dims = param.x->dims(); - auto paddings = *param.paddings; - auto strides = param.strides; - auto* input_image = param.x->data(); - auto* filter_image = filter_gpu_image_.data(); - auto filter_dims = param.filter->dims(); - auto output_dims = param.output->dims(); - - int input_width = input_dims[3]; - int input_height = input_dims[2]; - int output_width = output_dims[3]; - int output_height = output_dims[2]; - int filter_width = filter_dims[3]; - int filter_height = filter_dims[2]; - auto out_image_shape = InitImageDimInfoWith(output_dims); - auto* out_image = param.output->mutable_data( - out_image_shape["width"], out_image_shape["height"]); - - const bool has_bias = param.bias != nullptr; - const bool is_element_wise_bias = - has_bias && param.output->dims() == param.bias->dims(); - int offset = static_cast(param.filter->dims()[2]) / 2 - - static_cast(paddings[0]); - - // calc input_c_block - auto input_image_shape = InitImageDimInfoWith(input_dims); - int input_c_block = input_image_shape["width"] / input_dims[3]; - int input_c = input_dims[1]; - auto dilations = *param.dilations; - - const std::vector& default_work_size = - DefaultWorkSize(output_dims, - DDim(std::vector{ - static_cast(out_image_shape["width"]), - static_cast(out_image_shape["height"])})); - - int c_block = default_work_size[0]; - int w = default_work_size[1]; - int nh = default_work_size[2]; - - VLOG(4) << "============ conv2d params ============"; - VLOG(4) << "input_image_shape: " << input_image_shape["width"] << "," - << input_image_shape["height"]; - VLOG(4) << "input_c_block: " << input_c_block; - VLOG(4) << "input_c: " << input_c; - VLOG(4) << "input_image: " << input_image; - VLOG(4) << "input_dims: " << input_dims; - VLOG(4) << "filter_dims: " << filter_dims; - VLOG(4) << "filter_image: " << filter_image; - VLOG(4) << "output_dims: " << output_dims; - VLOG(4) << "out_image_shape: " << out_image_shape["width"] << ", " - << out_image_shape["height"]; - VLOG(4) << "paddings: " << paddings[0] << "," << paddings[1]; - VLOG(4) << "has bias: " << has_bias; - VLOG(4) << "is_element_wise_bias : " << is_element_wise_bias; - VLOG(4) << "strides: " << strides[0] << "," << strides[1]; - VLOG(4) << "offset: " << offset; - VLOG(4) << "dilations.size : " << dilations.size(); - VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1]; - VLOG(4) << "default work size{c_block, w, nh}: " - << "{" << c_block << ", " << w << ", " << nh << "" - << "}"; - - CHECK_GE(dilations.size(), 2); - CHECK(dilations[0] == dilations[1]); - CHECK_GE(input_dims.size(), 4); - CHECK_GE(paddings.size(), 2); - CHECK(paddings[0] == paddings[1]); - CHECK_GE(strides.size(), 2); - CHECK(strides[0] == strides[1]); - - const cl::Image2D* bias_image = nullptr; - if (has_bias) { - bias_image = bias_gpu_image_.data(); - } - - auto& context = ctx_->As(); - CHECK(context.cl_context() != nullptr); - STL::stringstream kernel_key; - kernel_key << kernel_func_names_[0] << build_options_[0]; - auto kernel = context.cl_context()->GetKernel(kernel_key.str()); - VLOG(4) << "kernel_key: " << kernel_key.str(); - VLOG(4) << "kernel ready ... " << kernel_key.str(); - VLOG(4) << "w: " << w; - - cl_int status; - int arg_idx = 0; - status = kernel.setArg(arg_idx, c_block); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, w); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, nh); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *input_image); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *filter_image); - CL_CHECK_FATAL(status); - if (has_bias) { - VLOG(4) << "set bias_image: "; - status = kernel.setArg(++arg_idx, *bias_image); - CL_CHECK_FATAL(status); - } - status = kernel.setArg(++arg_idx, *out_image); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, strides[0]); - CL_CHECK_FATAL(status); - - status = kernel.setArg(++arg_idx, offset); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_c_block); - CL_CHECK_FATAL(status); - - status = kernel.setArg(++arg_idx, dilations[0]); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_height); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, output_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, output_height); - CL_CHECK_FATAL(status); - - auto global_work_size = - cl::NDRange{static_cast(default_work_size.data()[0]), - static_cast(default_work_size.data()[1]), - static_cast(default_work_size.data()[2])}; - - VLOG(4) << "out_image: " << out_image; - VLOG(4) << "global_work_size[3D]: {" << global_work_size[0] << "," - << global_work_size[1] << "," << global_work_size[2] << "}"; - - status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( - kernel, - cl::NullRange, - global_work_size, - cl::NullRange, - nullptr, - event_.get()); - CL_CHECK_FATAL(status); - context.cl_wait_list()->emplace(out_image, event_); -} - -void ConvImageCompute::Conv2d7x7() { - const auto& param = *param_.get_mutable(); - auto input_dims = param.x->dims(); - auto paddings = *param.paddings; - auto strides = param.strides; - auto* input_image = param.x->data(); - auto* filter_image = filter_gpu_image_.data(); - auto filter_dims = param.filter->dims(); - auto output_dims = param.output->dims(); - - int input_width = input_dims[3]; - int input_height = input_dims[2]; - int output_width = output_dims[3]; - int output_height = output_dims[2]; - int filter_width = filter_dims[3]; - int filter_height = filter_dims[2]; - auto out_image_shape = InitImageDimInfoWith(output_dims); - auto* out_image = param.output->mutable_data( - out_image_shape["width"], out_image_shape["height"]); - - const bool has_bias = param.bias != nullptr; - const bool is_element_wise_bias = - has_bias && param.output->dims() == param.bias->dims(); - int offset = static_cast(param.filter->dims()[2]) / 2 - - static_cast(paddings[0]); - - // calc input_c_block - auto input_image_shape = InitImageDimInfoWith(input_dims); - int input_c_block = input_image_shape["width"] / input_dims[3]; - int input_c = input_dims[1]; - auto dilations = *param.dilations; - - const std::vector& default_work_size = - DefaultWorkSize(output_dims, - DDim(std::vector{ - static_cast(out_image_shape["width"]), - static_cast(out_image_shape["height"])})); - - int c_block = default_work_size[0]; - int w = default_work_size[1]; - int nh = default_work_size[2]; - - VLOG(4) << "============ conv2d params ============"; - VLOG(4) << "input_image_shape: " << input_image_shape["width"] << "," - << input_image_shape["height"]; - VLOG(4) << "input_c_block: " << input_c_block; - VLOG(4) << "input_c: " << input_c; - VLOG(4) << "input_image: " << input_image; - VLOG(4) << "input_dims: " << input_dims; - VLOG(4) << "filter_dims: " << filter_dims; - VLOG(4) << "filter_image: " << filter_image; - VLOG(4) << "output_dims: " << output_dims; - VLOG(4) << "out_image_shape: " << out_image_shape["width"] << ", " - << out_image_shape["height"]; - VLOG(4) << "paddings: " << paddings[0] << "," << paddings[1]; - VLOG(4) << "has bias: " << has_bias; - VLOG(4) << "is_element_wise_bias : " << is_element_wise_bias; - VLOG(4) << "strides: " << strides[0] << "," << strides[1]; - VLOG(4) << "offset: " << offset; - VLOG(4) << "dilations.size : " << dilations.size(); - VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1]; - VLOG(4) << "default work size{c_block, w, nh}: " - << "{" << c_block << ", " << w << ", " << nh << "" - << "}"; - - CHECK_GE(dilations.size(), 2); - CHECK(dilations[0] == dilations[1]); - CHECK_GE(input_dims.size(), 4); - CHECK_GE(paddings.size(), 2); - CHECK(paddings[0] == paddings[1]); - CHECK_GE(strides.size(), 2); - CHECK(strides[0] == strides[1]); - - const cl::Image2D* bias_image = nullptr; - if (has_bias) { - bias_image = bias_gpu_image_.data(); - } - - auto& context = ctx_->As(); - CHECK(context.cl_context() != nullptr); - STL::stringstream kernel_key; - kernel_key << kernel_func_names_[0] << build_options_[0]; - auto kernel = context.cl_context()->GetKernel(kernel_key.str()); - VLOG(4) << "kernel_key: " << kernel_key.str(); - VLOG(4) << "kernel ready ... " << kernel_key.str(); - VLOG(4) << "w: " << w; - - cl_int status; - int arg_idx = 0; - status = kernel.setArg(arg_idx, c_block); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, w); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, nh); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *input_image); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *filter_image); - CL_CHECK_FATAL(status); - if (has_bias) { - VLOG(4) << "set bias_image: "; - status = kernel.setArg(++arg_idx, *bias_image); - CL_CHECK_FATAL(status); - } - status = kernel.setArg(++arg_idx, *out_image); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, strides[0]); - CL_CHECK_FATAL(status); - - status = kernel.setArg(++arg_idx, offset); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_c_block); - CL_CHECK_FATAL(status); - - status = kernel.setArg(++arg_idx, dilations[0]); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_height); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, output_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, output_height); - CL_CHECK_FATAL(status); - - auto global_work_size = - cl::NDRange{static_cast(default_work_size.data()[0]), - static_cast(default_work_size.data()[1]), - static_cast(default_work_size.data()[2])}; - - VLOG(4) << "out_image: " << out_image; - VLOG(4) << "global_work_size[3D]: {" << global_work_size[0] << "," - << global_work_size[1] << "," << global_work_size[2] << "}"; - - status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( - kernel, - cl::NullRange, - global_work_size, - cl::NullRange, - nullptr, - event_.get()); - CL_CHECK_FATAL(status); - context.cl_wait_list()->emplace(out_image, event_); -} - -void ConvImageCompute::DepthwiseConv2d3x3s1() { - const auto& param = *param_.get_mutable(); - auto x_dims = param.x->dims(); - auto filter_dims = param.filter->dims(); - auto output_dims = param.output->dims(); - auto paddings = *param.paddings; - auto strides = param.strides; - auto dilations = *param.dilations; - - auto& context = ctx_->As(); - CHECK(context.cl_context() != nullptr); - auto* input_img = param.x->data(); - auto* filter_img = filter_gpu_image_.data(); - - const cl::Image2D* bias_img = nullptr; - if (param.bias) { - bias_img = bias_gpu_image_.data(); - } - - auto image_shape = InitImageDimInfoWith(output_dims); - - auto* output_img = param.output->mutable_data( - image_shape["width"], image_shape["height"]); - - STL::stringstream kernel_key; - kernel_key << kernel_func_names_[0] << build_options_[0]; - auto kernel = context.cl_context()->GetKernel(kernel_key.str()); - - int c_block = (output_dims[1] + 3) / 4; - int w = output_dims[3]; - int nh = output_dims[0] * output_dims[2]; - - int w_blk_size = 2; - int w_blk = (w + w_blk_size - 1) / w_blk_size; - - auto global_work_size = cl::NDRange(c_block, w_blk, nh); - - cl_int status; - int arg_idx = 0; - status = kernel.setArg(arg_idx, static_cast(c_block)); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(w_blk)); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(nh)); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *input_img); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *filter_img); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *output_img); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(strides[0])); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(paddings[0])); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(dilations[0])); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(x_dims[1])); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(x_dims[3])); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(x_dims[2])); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(output_dims[3])); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(output_dims[2])); - CL_CHECK_FATAL(status); - - status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( - kernel, - cl::NullRange, - global_work_size, - cl::NullRange, - nullptr, - event_.get()); - CL_CHECK_FATAL(status); - context.cl_wait_list()->emplace(output_img, event_); -} - -void ConvImageCompute::DepthwiseConv2d3x3() { - const auto& param = *param_.get_mutable(); - auto x_dims = param.x->dims(); - auto filter_dims = param.filter->dims(); - auto output_dims = param.output->dims(); - auto paddings = *param.paddings; - auto strides = param.strides; - auto dilations = *param.dilations; - int offset = filter_dims[2] / 2 - paddings[0]; - int input_c_block = (x_dims[1] + 3) / 4; - - auto& context = ctx_->As(); - CHECK(context.cl_context() != nullptr); - auto* input_img = param.x->data(); - auto* filter_img = filter_gpu_image_.data(); - - const cl::Image2D* bias_img = nullptr; - if (param.bias) { - bias_img = bias_gpu_image_.data(); - } - - auto image_shape = InitImageDimInfoWith(output_dims); - - auto* output_img = param.output->mutable_data( - image_shape["width"], image_shape["height"]); - - STL::stringstream kernel_key; - kernel_key << kernel_func_names_[0] << build_options_[0]; - auto kernel = context.cl_context()->GetKernel(kernel_key.str()); - - int c_block = (output_dims[1] + 3) / 4; - int w = output_dims[3]; - int nh = output_dims[0] * output_dims[2]; - auto global_work_size = cl::NDRange(c_block, w, nh); - - VLOG(4) << "setArg"; - VLOG(4) << "c_block = " << c_block; - VLOG(4) << "w = " << w; - VLOG(4) << "nh = " << nh; - - VLOG(4) << "strides = " << strides[0]; - VLOG(4) << "offset = " << offset; - VLOG(4) << "dilations = " << dilations[0]; - VLOG(4) << "input_c_block = " << input_c_block; - VLOG(4) << "x_dims[3] = " << x_dims[3]; - VLOG(4) << "x_dims[2] = " << x_dims[2]; - VLOG(4) << "output_dims[3] = " << output_dims[3]; - VLOG(4) << "output_dims[2] = " << output_dims[2]; - - cl_int status; - int arg_idx = 0; - status = kernel.setArg(arg_idx, static_cast(c_block)); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(w)); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(nh)); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *input_img); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *filter_img); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *output_img); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(strides[0])); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(offset)); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(dilations[0])); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(input_c_block)); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(x_dims[3])); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(x_dims[2])); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(output_dims[3])); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(output_dims[2])); - CL_CHECK_FATAL(status); - - status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( - kernel, - cl::NullRange, - global_work_size, - cl::NullRange, - nullptr, - event_.get()); - CL_CHECK_FATAL(status); - context.cl_wait_list()->emplace(output_img, event_); -} - -void ConvImageCompute::DepthwiseConv2d() { - const auto& param = *param_.get_mutable(); - auto input_dims = param.x->dims(); - auto paddings = *param.paddings; - auto strides = param.strides; - auto* input_image = param.x->data(); - auto* filter_image = filter_gpu_image_.data(); - auto filter_dims = param.filter->dims(); - auto output_dims = param.output->dims(); - - int input_width = input_dims[3]; - int input_height = input_dims[2]; - int output_width = output_dims[3]; - int output_height = output_dims[2]; - int filter_width = filter_dims[3]; - int filter_height = filter_dims[2]; - auto out_image_shape = InitImageDimInfoWith(output_dims); - auto* out_image = param.output->mutable_data( - out_image_shape["width"], out_image_shape["height"]); - - const bool has_bias = param.bias != nullptr; - const bool is_element_wise_bias = - has_bias && param.output->dims() == param.bias->dims(); - int offset = static_cast(param.filter->dims()[2]) / 2 - - static_cast(paddings[0]); - - // calc input_c_block - auto input_image_shape = InitImageDimInfoWith(input_dims); - int input_c_block = input_image_shape["width"] / input_dims[3]; - int input_c = input_dims[1]; - auto dilations = *param.dilations; - - const std::vector& default_work_size = - DefaultWorkSize(output_dims, - DDim(std::vector{ - static_cast(out_image_shape["width"]), - static_cast(out_image_shape["height"])})); - - int c_block = default_work_size[0]; - int w = default_work_size[1]; - int nh = default_work_size[2]; - - VLOG(4) << "============ depthwise conv2d params ============"; - VLOG(4) << "input_image_shape: " << input_image_shape["width"] << "," - << input_image_shape["height"]; - VLOG(4) << "input_c_block: " << input_c_block; - VLOG(4) << "input_c: " << input_c; - VLOG(4) << "input_image: " << input_image; - VLOG(4) << "filter_dims: " << filter_dims; - VLOG(4) << "filter_image: " << filter_image; - VLOG(4) << "output_dims: " << output_dims; - VLOG(4) << "out_image_shape: " << out_image_shape["width"] << ", " - << out_image_shape["height"]; - VLOG(4) << "paddings: " << paddings[0] << "," << paddings[1]; - VLOG(4) << "has bias: " << has_bias; - VLOG(4) << "is_element_wise_bias : " << is_element_wise_bias; - VLOG(4) << "strides: " << strides[0] << "," << strides[1]; - VLOG(4) << "offset: " << offset; - VLOG(4) << "dilations.size : " << dilations.size(); - VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1]; - VLOG(4) << "default work size{c_block, w, nh}: " - << "{" << c_block << ", " << w << ", " << nh << "" - << "}"; - - CHECK_GE(dilations.size(), 2); - CHECK(dilations[0] == dilations[1]); - CHECK_GE(input_dims.size(), 4); - CHECK_GE(paddings.size(), 2); - CHECK(paddings[0] == paddings[1]); - CHECK_GE(strides.size(), 2); - CHECK(strides[0] == strides[1]); - - // handle bias use buffer for channel wise , use image for element wise - const cl::Buffer* bias_buf = nullptr; - const cl::Image2D* bias_image = nullptr; - if (has_bias) { - bias_image = bias_gpu_image_.data(); - } - - auto& context = ctx_->As(); - CHECK(context.cl_context() != nullptr); - STL::stringstream kernel_key; - kernel_key << kernel_func_names_[0] << build_options_[0]; - auto kernel = context.cl_context()->GetKernel(kernel_key.str()); - VLOG(4) << "kernel_key: " << kernel_key.str(); - VLOG(4) << "kernel ready ... " << kernel_key.str(); - VLOG(4) << "w: " << w; - - cl_int status; - int arg_idx = 0; - status = kernel.setArg(arg_idx, c_block); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, w); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, nh); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *input_image); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *filter_image); - CL_CHECK_FATAL(status); - if (has_bias) { - VLOG(4) << "set bias_image: "; - status = kernel.setArg(++arg_idx, *bias_image); - CL_CHECK_FATAL(status); - } - status = kernel.setArg(++arg_idx, *out_image); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, strides[0]); - CL_CHECK_FATAL(status); - - status = kernel.setArg(++arg_idx, offset); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_c_block); - CL_CHECK_FATAL(status); - - status = kernel.setArg(++arg_idx, dilations[0]); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_height); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, output_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, output_height); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, filter_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, filter_height); - CL_CHECK_FATAL(status); - - auto global_work_size = - cl::NDRange{static_cast(default_work_size.data()[0]), - static_cast(default_work_size.data()[1]), - static_cast(default_work_size.data()[2])}; - - VLOG(4) << "out_image: " << out_image; - VLOG(4) << "global_work_size[3D]: {" << global_work_size[0] << "," - << global_work_size[1] << "," << global_work_size[2] << "}"; - - status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( - kernel, - cl::NullRange, - global_work_size, - cl::NullRange, - nullptr, - event_.get()); - CL_CHECK_FATAL(status); - context.cl_wait_list()->emplace(out_image, event_); -} - -void ConvImageCompute::Run() { (this->*impl_)(); } - } // namespace opencl } // namespace kernels } // namespace lite diff --git a/lite/kernels/opencl/io_copy_buffer_compute.cc b/lite/kernels/opencl/io_copy_buffer_compute.cc index 0148e6b143ebae1d4c2ca60e6f40f8a0228a2979..0e9a5941c0a3484ffbb72012f64c07296694078b 100644 --- a/lite/kernels/opencl/io_copy_buffer_compute.cc +++ b/lite/kernels/opencl/io_copy_buffer_compute.cc @@ -42,16 +42,11 @@ class IoCopyHostToOpenCLCompute CHECK(param.x->target() == TARGET(kHost) || param.x->target() == TARGET(kARM)); auto mem_size = param.x->memory_size(); - - VLOG(4) << "copy size " << mem_size; - VLOG(4) << "param.x->dims().size():" << param.x->dims().size(); - VLOG(4) << "param.x->dims():" << param.x->dims()[0] << " " - << param.x->dims()[1] << " " << param.x->dims()[2] << " " - << param.x->dims()[3]; - VLOG(4) << "param.y->dims().size():" << param.y->dims().size(); - VLOG(4) << "param.y->dims():" << param.y->dims()[0] << " " - << param.y->dims()[1] << " " << param.y->dims()[2] << " " - << param.y->dims()[3]; + VLOG(2) << "param.x->memory_size():" << mem_size; + VLOG(2) << "param.x->dims().size():" << param.x->dims().size(); + VLOG(2) << "param.x->dims():" << param.x->dims(); + VLOG(2) << "param.y->dims().size():" << param.y->dims().size(); + VLOG(2) << "param.y->dims():" << param.y->dims(); auto* data = param.y->mutable_data(TARGET(kOpenCL), mem_size); CopyFromHostSync(data, param.x->raw_data(), mem_size); } @@ -89,23 +84,27 @@ class IoCopykOpenCLToHostCompute auto& param = Param(); CHECK(param.x->target() == TARGET(kOpenCL)); auto mem_size = param.x->memory_size(); - VLOG(4) << "copy size " << mem_size; - VLOG(4) << "param.x->dims().size():" << param.x->dims().size(); - VLOG(4) << "param.x->dims():" << param.x->dims()[0] << " " - << param.x->dims()[1] << " " << param.x->dims()[2] << " " - << param.x->dims()[3]; - VLOG(4) << "param.y->dims().size():" << param.y->dims().size(); - VLOG(4) << "param.y->dims():" << param.y->dims()[0] << " " - << param.y->dims()[1] << " " << param.y->dims()[2] << " " - << param.y->dims()[3]; + + VLOG(2) << "copy size " << mem_size; + VLOG(2) << "param.x->dims().size():" << param.x->dims().size(); + VLOG(2) << "param.x->dims():" << param.x->dims(); + VLOG(2) << "param.y->dims().size():" << param.y->dims().size(); + VLOG(2) << "param.y->dims():" << param.y->dims(); + VLOG(2) << "param.process_type:" << param.process_type; + auto* data = param.y->mutable_data(TARGET(kHost), mem_size); + const cl::Buffer* x_ptr; + if (param.process_type == 1) { + x_ptr = param.x->data(); + } else { + x_ptr = param.x->data(); + } + auto& context = ctx_->As(); auto* wait_list = context.cl_wait_list(); - auto* x_ptr = param.x->data(); - auto it = wait_list->find(x_ptr); if (it != wait_list->end()) { - VLOG(4) << "--- Find the sync event for the target cl tensor. ---"; + VLOG(2) << "--- Find the sync event for the target cl tensor. ---"; auto& event = *(it->second); event.wait(); } else { diff --git a/lite/kernels/opencl/layout_image_compute.cc b/lite/kernels/opencl/layout_image_compute.cc index fad37aa709441f95f2aadc535e7eb5db895765f4..901f6a31871af107b12a77b4a97a4b26d9701baf 100644 --- a/lite/kernels/opencl/layout_image_compute.cc +++ b/lite/kernels/opencl/layout_image_compute.cc @@ -42,6 +42,7 @@ class LayoutComputeBufferChwToImageDefault if (param.process_type == 1) { kernel_func_name_ = "buffer_to_image2d_with_pre255"; } + VLOG(2) << "kernel_func_name_:" << kernel_func_name_; auto& context = ctx_->As(); context.cl_context()->AddKernel( kernel_func_name_, "image/layout_kernel.cl", build_options_); @@ -73,20 +74,21 @@ class LayoutComputeBufferChwToImageDefault const int Stride1 = out_H * out_W; const int Stride0 = out_W; - VLOG(4) << "y image_shape(w,h):" << image_shape["width"] << " " - << image_shape["height"]; - VLOG(4) << "x_dims[" << x_dims.size() << "D]:" << x_dims[0] << " " - << x_dims[1] << " " << x_dims[2] << " " << x_dims[3]; - VLOG(4) << "y_dims[" << y_dims.size() << "D]:" << y_dims[0] << " " - << y_dims[1] << " " << y_dims[2] << " " << y_dims[3]; - VLOG(4) << "new_dims[" << new_dims.size() << "D]:" << new_dims[0] << " " + VLOG(2) << "param.process_type:" << param.process_type; + VLOG(2) << "x_dims:" << x_dims; + VLOG(2) << "param.x->memory_size():" << param.x->memory_size(); + VLOG(2) << "new_dims[" << new_dims.size() << "D]:" << new_dims[0] << " " << new_dims[1] << " " << new_dims[2] << " " << new_dims[3]; - VLOG(4) << "out_C:" << out_C; - VLOG(4) << "out_H:" << out_H; - VLOG(4) << "out_W:" << out_W; - VLOG(4) << "Stride2:" << Stride2; - VLOG(4) << "Stride1:" << Stride1; - VLOG(4) << "Stride0:" << Stride0; + VLOG(2) << "y_dims:" << y_dims; + VLOG(2) << "param.y->memory_size():" << param.y->memory_size(); + VLOG(2) << "y image_shape(w,h):" << image_shape["width"] << " " + << image_shape["height"]; + VLOG(2) << "out_C:" << out_C; + VLOG(2) << "out_H:" << out_H; + VLOG(2) << "out_W:" << out_W; + VLOG(2) << "Stride2:" << Stride2; + VLOG(2) << "Stride1:" << Stride1; + VLOG(2) << "Stride0:" << Stride0; auto& context = ctx_->As(); CHECK(context.cl_context() != nullptr); @@ -112,7 +114,7 @@ class LayoutComputeBufferChwToImageDefault status = kernel.setArg(++arg_idx, static_cast(Stride2)); CL_CHECK_FATAL(status); - VLOG(4) << "gws:[3D]" << ((new_dims[1] + 3) / 4) << " " << new_dims[3] + VLOG(2) << "gws:[3D]" << ((new_dims[1] + 3) / 4) << " " << new_dims[3] << " " << (new_dims[0] * new_dims[2]); auto global_work_size = cl::NDRange{static_cast((new_dims[1] + 3) / 4), @@ -151,6 +153,7 @@ class LayoutComputeImageDefaultToBufferChw if (param.process_type == 1) { kernel_func_name_ = "image2d_to_buffer_with_post255"; } + VLOG(2) << "kernel_func_name_:" << kernel_func_name_; auto& context = ctx_->As(); context.cl_context()->AddKernel( kernel_func_name_, "image/layout_kernel.cl", build_options_); @@ -174,14 +177,15 @@ class LayoutComputeImageDefaultToBufferChw new_dims[4 - x_dims.size() + j] = x_dims[j]; } - VLOG(4) << "x_image_shape(w,h):" << x_image_shape["width"] << " " + VLOG(2) << "param.process_type:" << param.process_type; + VLOG(2) << "x_dims:" << x_dims; + VLOG(2) << "param.x->memory_size():" << param.x->memory_size(); + VLOG(2) << "x_image_shape(w,h):" << x_image_shape["width"] << " " << x_image_shape["height"]; - VLOG(4) << "x_dims[" << x_dims.size() << "D]:" << x_dims[0] << " " - << x_dims[1] << " " << x_dims[2] << " " << x_dims[3]; - VLOG(4) << "y_dims[" << y_dims.size() << "D]:" << y_dims[0] << " " - << y_dims[1] << " " << y_dims[2] << " " << y_dims[3]; - VLOG(4) << "new_dims[" << new_dims.size() << "D]:" << new_dims[0] << " " + VLOG(2) << "new_dims[" << new_dims.size() << "D]:" << new_dims[0] << " " << new_dims[1] << " " << new_dims[2] << " " << new_dims[3]; + VLOG(2) << "y_dims:" << y_dims; + VLOG(2) << "param.y->memory_size():" << param.y->memory_size(); size_t C = new_dims[1]; size_t in_height = new_dims[2]; @@ -213,7 +217,7 @@ class LayoutComputeImageDefaultToBufferChw CL_CHECK_FATAL(status); status = kernel.setArg(++arg_idx, static_cast(C)); CL_CHECK_FATAL(status); - VLOG(4) << "gws:[3D]" << ((new_dims[1] + 3) / 4) << " " << new_dims[3] + VLOG(2) << "gws:[3D]" << ((new_dims[1] + 3) / 4) << " " << new_dims[3] << " " << (new_dims[0] * new_dims[2]); auto global_work_size = cl::NDRange{static_cast((new_dims[1] + 3) / 4), @@ -307,7 +311,7 @@ class LayoutComputeBufferChwToImage2DNw status = kernel.setArg(++arg_idx, static_cast(Stride2)); CL_CHECK_FATAL(status); - VLOG(4) << "gws:[3D]" << ((out_N + 3) / 4) << " " << out_W << " " + VLOG(2) << "gws:[3D]" << ((out_N + 3) / 4) << " " << out_W << " " << (out_C * out_H); auto global_work_size = cl::NDRange{static_cast((out_N + 3) / 4), // N blocks diff --git a/lite/operators/io_copy_op.cc b/lite/operators/io_copy_op.cc index f7e72a6e1e1ecb01e866fece1a09d7b9c4e7a695..7df636d7b2d877a5539a980080077be785d47505 100644 --- a/lite/operators/io_copy_op.cc +++ b/lite/operators/io_copy_op.cc @@ -35,6 +35,9 @@ bool IoCopyOp::AttachImpl(const cpp::OpDesc &opdesc, auto out = opdesc.Output("Out").front(); param_.x = GetTensor(scope, x); param_.y = GetMutableTensor(scope, out); + if (opdesc.HasAttr("process_type")) { + param_.process_type = opdesc.GetAttr("process_type"); + } return true; } std::string IoCopyOp::DebugString() const { return "io_copy_op"; } diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index f524d8a6aec748dd2fca4375d6f68602148c93d4..612ef253f8a5452872057ab2e10e58e530642b8b 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -57,6 +57,7 @@ struct FetchParam { struct IoCopyParam { const lite::Tensor* x{}; lite::Tensor* y{}; + int process_type{0}; }; struct LayoutParam {