未验证 提交 e583a55d 编写于 作者: X xiebaiyuan 提交者: GitHub

[LITE][OPENCL] Add conv2d_1x1 opencl kernel (#2591)

* Add the OpenCL conv1x1 image implementation; unit tests pass with ReLU and bias.
Add layout_compute buffer -> image2d (float32); unit tests pass.
Test suite checked against more cases, test=develop

* add opencl conv1x1 image impl and unit test pass with relu & bias,
add layout_compute --> buffer2image float32 --> with unit test pass
suite checked test for more situation , test=develop

* fix white space cpp lint , test=develop
上级 aab5f53f
......@@ -103,6 +103,7 @@ class CLImageConverterNormal : public CLImageConverterBase {
};
class CLImageConverterNWBlock : public CLImageConverterBase {
public:
DDim InitImageDimInfoWith(const DDim &tensor_dim) override;
void NCHWToImage(float *tensor,
float *image,
......
......@@ -61,6 +61,57 @@ __kernel void buffer_to_image2d(__global CL_DTYPE *in,
write_imagef(output_image, output_pos, output);
}
// buffer -> image2d_nw
// Packs up to four consecutive N entries of the input buffer into the
// x/y/z/w lanes of one output pixel. Lanes whose N index would be >= out_N
// are left at zero.
// Global ids: dim0 = block of four N entries, dim1 = w, dim2 = c * out_H + h.
// Stride0 / Stride1 / Stride2 are the element strides applied to the h, c and
// n indices when addressing `in`.
__kernel void buffer_to_image2d_nw(__global CL_DTYPE* in,
                                   __write_only image2d_t output_image,
                                   __private const int out_H,
                                   __private const int out_W,
                                   __private const int out_N,
                                   __private const int Stride0,
                                   __private const int Stride1,
                                   __private const int Stride2) {
  const int n_block = get_global_id(0);
  const int w_idx = get_global_id(1);
  const int ch_idx = get_global_id(2);

  // Split the fused (c, h) index back into its components.
  const int c_idx = ch_idx / out_H;
  const int h_idx = ch_idx % out_H;

  // Offset inside one N slice; shared by all four lane reads.
  const int slice_offset = c_idx * Stride1 + h_idx * Stride0 + w_idx;
  const int n_first = n_block * 4;
  // Number of valid N entries left starting at n_first.
  const int n_left = out_N - n_first;

  CL_DTYPE4 pixel = (CL_DTYPE4)0.0f;
  pixel.x = convert_float(in[n_first * Stride2 + slice_offset]);
  if (n_left >= 2) {
    pixel.y = convert_float(in[(n_first + 1) * Stride2 + slice_offset]);
  }
  if (n_left >= 3) {
    pixel.z = convert_float(in[(n_first + 2) * Stride2 + slice_offset]);
  }
  if (n_left >= 4) {
    pixel.w = convert_float(in[(n_first + 3) * Stride2 + slice_offset]);
  }

  // Image x = n_block * out_W + w, image y = c * out_H + h.
  const int2 coord = (int2)(n_block * out_W + w_idx, ch_idx);
  write_imagef(output_image, coord, pixel);
}
// image2d -> buffer
__kernel void image2d_to_buffer(__read_only image2d_t input,
__private const int in_width,
......
......@@ -61,3 +61,19 @@ inline CL_DTYPE activation(CL_DTYPE in
#endif
return output;
}
// Four-lane activation helper.
//   PRELU defined: lanes >= 0 pass through, other lanes are scaled by
//                  prelu_alpha (the extra parameter only exists in this mode).
//   RELU defined:  lanes are clamped to be >= 0.
//   neither:       the input is returned unchanged.
inline CL_DTYPE4 activation_type4(CL_DTYPE4 in
#ifdef PRELU
                                  ,
                                  CL_DTYPE4 prelu_alpha
#endif
                                  ) {
  // Start from the input so the result is well defined even when no
  // activation macro is set (previously an uninitialised value was returned).
  CL_DTYPE4 output = in;
#ifdef PRELU
  output = select(prelu_alpha * in, in, in >= (CL_DTYPE4)0.0);
#endif
#ifdef RELU
  output = fmax(in, (CL_DTYPE4)0);
#endif
  return output;
}
#include <cl_common.h>
// 1x1 convolution on image2d-backed tensors.
//
// Work-item mapping: id(0) = output channel block, id(1) = output column,
// id(2) = output image row. Each work-item accumulates four output pixels at
// columns out_w + k * global_size_dim1 (k = 0..3) and writes only those whose
// column is below old_w.
//
// Compile-time switches:
//   BIASE_CH   - accumulators start from bias pixel (out_c, 0)
//   BIASE_ELE  - accumulators start from the bias pixel at each output position
//   BATCH_NORM - result is scaled/shifted by new_scale / new_biase at (out_c, 0)
//   RELU       - result goes through activation_type4
//
// input_c is the number of four-channel blocks iterated over; input_c_origin
// is compared against input_c * 4 to find how many trailing lanes of the last
// block must be treated as zero.
// global_size_dim0, global_size_dim2, dilation, input_height, output_width and
// output_height are accepted but not read by this kernel.
__kernel void conv_1x1(
    __private const int global_size_dim0,
    __private const int global_size_dim1,
    __private const int global_size_dim2,
    __read_only image2d_t input_image,
    __read_only image2d_t filter,
#if defined(BIASE_CH) || defined(BIASE_ELE)
    __read_only image2d_t bias,
#endif
#ifdef BATCH_NORM
    __read_only image2d_t new_scale, __read_only image2d_t new_biase,
#endif
    __write_only image2d_t output_image,
    __private const int stride,
    __private const int offset,
    __private const int input_c,
    __private const int input_c_origin,
    __private const int dilation,
    __private const int input_width,  /* of one block */
    __private const int input_height, /* of one block */
    __private const int output_width,
    __private const int output_height,
    __private const int old_w) {
  CL_DTYPE zero = 0.0f;
  const int out_c = get_global_id(0);
  const int out_w = get_global_id(1);
  const int out_nh = get_global_id(2);

  // The four output columns handled by this work-item.
  int out_w0 = out_w;
  int out_w1 = out_w + global_size_dim1;
  int out_w2 = out_w + global_size_dim1 * 2;
  int out_w3 = out_w + global_size_dim1 * 3;

  int outpos_main = mul24(out_c, old_w);
  int2 output_pos0 = (int2)(outpos_main + out_w0, out_nh);
  int2 output_pos1 = (int2)(outpos_main + out_w1, out_nh);
  int2 output_pos2 = (int2)(outpos_main + out_w2, out_nh);
  int2 output_pos3 = (int2)(outpos_main + out_w3, out_nh);

  // Every read below uses integer pixel coordinates. OpenCL C only defines
  // read_imagef with integer coordinates for samplers that use unnormalized
  // coordinates, so CLK_NORMALIZED_COORDS_FALSE is required here.
  const sampler_t sampler =
      CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

  int2 stride_xy = (int2)(stride, stride);

  int2 ouput_pos_in_one_block0 = (int2)(out_w0, out_nh);
  int2 in_pos_in_one_block0 =
      ouput_pos_in_one_block0 * stride_xy + (int2)(offset, offset);

  int2 ouput_pos_in_one_block1 = (int2)(out_w1, out_nh);
  int2 in_pos_in_one_block1 =
      ouput_pos_in_one_block1 * stride_xy + (int2)(offset, offset);

  int2 ouput_pos_in_one_block2 = (int2)(out_w2, out_nh);
  int2 in_pos_in_one_block2 =
      ouput_pos_in_one_block2 * stride_xy + (int2)(offset, offset);

  int2 ouput_pos_in_one_block3 = (int2)(out_w3, out_nh);
  int2 in_pos_in_one_block3 =
      ouput_pos_in_one_block3 * stride_xy + (int2)(offset, offset);

#ifdef BIASE_CH
  // Channel-wise bias: one bias pixel per output channel block.
  CL_DTYPE4 output0 = read_imagef(bias, sampler, (int2)(out_c, 0));
  CL_DTYPE4 output1 = output0;
  CL_DTYPE4 output2 = output0;
  CL_DTYPE4 output3 = output0;
#elif defined(BIASE_ELE)
  // Element-wise bias: each output pixel has its own bias pixel, so every
  // accumulator must be seeded from its own output position.
  CL_DTYPE4 output0 = read_imagef(bias, sampler, output_pos0);
  CL_DTYPE4 output1 = read_imagef(bias, sampler, output_pos1);
  CL_DTYPE4 output2 = read_imagef(bias, sampler, output_pos2);
  CL_DTYPE4 output3 = read_imagef(bias, sampler, output_pos3);
#else
  CL_DTYPE4 output0 = 0.0f;
  CL_DTYPE4 output1 = 0.0f;
  CL_DTYPE4 output2 = 0.0f;
  CL_DTYPE4 output3 = 0.0f;
#endif

  int max_w_bound = input_c * input_width;
  // Number of lanes in the last channel block that lie beyond input_c_origin.
  int burndary_index = input_c * 4 - input_c_origin;
  bool burndary_index_w =
      burndary_index == 1 || burndary_index == 2 || burndary_index == 3;
  bool burndary_index_z = burndary_index == 2 || burndary_index == 3;
  bool burndary_index_y = burndary_index == 3;

  for (int i = 0; i < input_c; ++i) {
    // The four filter pixels for this input channel block are shared by all
    // four output pixels.
    CL_DTYPE4 weight0 = read_imagef(filter, sampler, (int2)(out_c, i * 4 + 0));
    CL_DTYPE4 weight1 = read_imagef(filter, sampler, (int2)(out_c, i * 4 + 1));
    CL_DTYPE4 weight2 = read_imagef(filter, sampler, (int2)(out_c, i * 4 + 2));
    CL_DTYPE4 weight3 = read_imagef(filter, sampler, (int2)(out_c, i * 4 + 3));

    // ------------0---------------
    int2 pos_in = (int2)(i * input_width + in_pos_in_one_block0.x,
                         in_pos_in_one_block0.y);
    CL_DTYPE4 input0 = read_imagef(input_image, sampler, pos_in);

    // outof_bound is true when pos_in.x falls in the last input_width columns
    // below max_w_bound, i.e. in the last channel block; the surplus lanes of
    // that block are then forced to zero.
    int bound_gap = max_w_bound - pos_in.x - 1;
    bool outof_bound = bound_gap < input_width && bound_gap >= 0;
    input0.w = select(input0.w, zero, outof_bound && burndary_index_w);
    input0.z = select(input0.z, zero, outof_bound && burndary_index_z);
    input0.y = select(input0.y, zero, outof_bound && burndary_index_y);

    output0 = mad(input0.x, weight0, output0);
    output0 = mad(input0.y, weight1, output0);
    output0 = mad(input0.z, weight2, output0);
    output0 = mad(input0.w, weight3, output0);

    // -------------1--------------
    pos_in = (int2)(i * input_width + in_pos_in_one_block1.x,
                    in_pos_in_one_block1.y);
    CL_DTYPE4 input1 = read_imagef(input_image, sampler, pos_in);

    bound_gap = max_w_bound - pos_in.x - 1;
    outof_bound = bound_gap < input_width && bound_gap >= 0;
    input1.w = select(input1.w, zero, outof_bound && burndary_index_w);
    input1.z = select(input1.z, zero, outof_bound && burndary_index_z);
    input1.y = select(input1.y, zero, outof_bound && burndary_index_y);

    output1 = mad(input1.x, weight0, output1);
    output1 = mad(input1.y, weight1, output1);
    output1 = mad(input1.z, weight2, output1);
    output1 = mad(input1.w, weight3, output1);

    // -------------2--------------
    pos_in = (int2)(i * input_width + in_pos_in_one_block2.x,
                    in_pos_in_one_block2.y);
    CL_DTYPE4 input2 = read_imagef(input_image, sampler, pos_in);

    bound_gap = max_w_bound - pos_in.x - 1;
    outof_bound = bound_gap < input_width && bound_gap >= 0;
    input2.w = select(input2.w, zero, outof_bound && burndary_index_w);
    input2.z = select(input2.z, zero, outof_bound && burndary_index_z);
    input2.y = select(input2.y, zero, outof_bound && burndary_index_y);

    output2 = mad(input2.x, weight0, output2);
    output2 = mad(input2.y, weight1, output2);
    output2 = mad(input2.z, weight2, output2);
    output2 = mad(input2.w, weight3, output2);

    // -------------3--------------
    pos_in = (int2)(i * input_width + in_pos_in_one_block3.x,
                    in_pos_in_one_block3.y);
    CL_DTYPE4 input3 = read_imagef(input_image, sampler, pos_in);

    bound_gap = max_w_bound - pos_in.x - 1;
    outof_bound = bound_gap < input_width && bound_gap >= 0;
    input3.w = select(input3.w, zero, outof_bound && burndary_index_w);
    input3.z = select(input3.z, zero, outof_bound && burndary_index_z);
    input3.y = select(input3.y, zero, outof_bound && burndary_index_y);

    output3 = mad(input3.x, weight0, output3);
    output3 = mad(input3.y, weight1, output3);
    output3 = mad(input3.z, weight2, output3);
    output3 = mad(input3.w, weight3, output3);
  }

#ifdef BATCH_NORM
  // Scale and shift are the same for all four pixels; read them once.
  CL_DTYPE4 bn_scale = read_imagef(new_scale, sampler, (int2)(out_c, 0));
  CL_DTYPE4 bn_shift = read_imagef(new_biase, sampler, (int2)(out_c, 0));
  output0 = output0 * bn_scale + bn_shift;
  output1 = output1 * bn_scale + bn_shift;
  output2 = output2 * bn_scale + bn_shift;
  output3 = output3 * bn_scale + bn_shift;
#endif

#ifdef RELU
  output0 = activation_type4(output0);
  output1 = activation_type4(output1);
  output2 = activation_type4(output2);
  output3 = activation_type4(output3);
#endif

  // Columns at or beyond old_w are padding introduced by splitting the row
  // into four parts; they must not be written.
  if (out_w0 < old_w) {
    write_imagef(output_image, output_pos0, output0);
  }
  if (out_w1 < old_w) {
    write_imagef(output_image, output_pos1, output1);
  }
  if (out_w2 < old_w) {
    write_imagef(output_image, output_pos2, output2);
  }
  if (out_w3 < old_w) {
    write_imagef(output_image, output_pos3, output3);
  }
}
\ No newline at end of file
......@@ -14,6 +14,7 @@ add_kernel(pool_opencl OPENCL basic SRCS pool_compute.cc DEPS ${cl_kernel_deps})
add_kernel(io_copy_compute_opencl OPENCL basic SRCS io_copy_compute.cc DEPS ${tensor_lite} ${cl_kernel_deps})
add_kernel(relu_opencl OPENCL basic SRCS relu_compute.cc DEPS ${cl_kernel_deps})
add_kernel(depthwise_conv2d_opencl OPENCL basic SRCS depthwise_conv2d_compute.cc DEPS ${cl_kernel_deps})
add_kernel(conv2d_1x1_opencl OPENCL basic SRCS conv2d_1x1_compute.cc DEPS ${cl_kernel_deps})
add_kernel(conv_opencl OPENCL basic SRCS conv_compute.cc DEPS ${cl_kernel_deps})
add_kernel(layout_opencl OPENCL basic SRCS layout_compute.cc DEPS ${cl_kernel_deps})
......@@ -47,10 +48,14 @@ lite_cc_test(test_depthwise_conv2d_opencl SRCS depthwise_conv2d_compute_test.cc
DEPS depthwise_conv2d_opencl op_registry program context
ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl)
lite_cc_test(test_conv2d_1x1_opencl SRCS conv2d_1x1_compute_test.cc
DEPS conv2d_1x1_opencl cl_image_converter op_registry program context
ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl)
lite_cc_test(test_conv_opencl SRCS conv_compute_test.cc
DEPS conv_opencl op_registry program context
ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl)
lite_cc_test(test_layout_opencl SRCS layout_compute_test.cc
DEPS layout_opencl op_registry program context
DEPS layout_opencl op_registry program context cl_image_converter
ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "lite/backends/opencl/cl_include.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/opencl/image_helper.h"
#include "lite/operators/op_params.h"
#include "lite/utils/replace_stl/stream.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace opencl {
#define USE_BUFFER_FOR_CONV1x1_BIAS
// OpenCL image2d kernel wrapper for 1x1 conv2d.
//
// PrepareForRun() selects the compile options (-DRELU, -DBIASE_ELE or
// -DBIASE_CH) from the ConvParam and registers the "conv_1x1" program.
// Run() binds the arguments in the exact order of the conv_1x1 kernel
// signature and enqueues it; the completion event is published in the
// context's wait list keyed by the output image.
class Conv2d1x1Image2DCompute
    : public KernelLite<TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kNHWC)> {
 public:
  using param_t = operators::ConvParam;

  // Chooses build options from the parameters and adds the kernel to the
  // OpenCL context.
  void PrepareForRun() override {
    LOG(INFO) << "PrepareForRun ...";
    const auto& param = *param_.get_mutable<param_t>();
    if (param.fuse_relu) {
      build_options_ += " -DRELU";
    }

    // A bias with the same dims as the output is treated as element-wise,
    // any other bias as channel-wise.
    const bool has_bias = param.bias != nullptr;
    const bool is_element_wise_bias =
        has_bias && param.output->dims() == param.bias->dims();
    if (has_bias) {
      build_options_ += is_element_wise_bias ? " -DBIASE_ELE" : " -DBIASE_CH";
    }

    auto& context = ctx_->As<OpenCLContext>();
    context.cl_context()->AddKernel(
        kernel_func_name_, "image/conv_1x1_kernel.cl", build_options_);
    LOG(INFO) << "PrepareForRun Ready";
  }

  // Sets all kernel arguments and enqueues the convolution.
  void Run() override {
    // Diagnostics here run on every inference, so they are emitted at
    // verbose level rather than INFO.
    VLOG(4) << "opencl conv 1x1 run begin ...";
    VLOG(4) << "param_.get_mutable<param_t> ...";
    const auto& param = *param_.get_mutable<param_t>();

    VLOG(4) << "get param dims ...";
    auto input_dims = param.x->dims();
    CHECK_GE(input_dims.size(), 4);
    VLOG(4) << "input_dims: " << input_dims;

    int input_width = input_dims[3];
    int input_height = input_dims[2];

    auto filter_dims = param.filter->dims();
    VLOG(4) << "filter_dims: " << filter_dims;

    auto output_dims = param.output->dims();
    VLOG(4) << "output_dims: " << output_dims;

    int output_width = output_dims[3];
    int output_height = output_dims[2];

    // mute output image
    auto out_image_shape = InitImageDimInfoWith(output_dims);
    VLOG(4) << "out_image_shape: " << out_image_shape["width"] << ", "
            << out_image_shape["height"];
    auto* out_image = param.output->mutable_data<float, cl::Image2D>(
        out_image_shape["width"], out_image_shape["height"]);

    // gen default_work_size
    const std::vector<size_t>& default_work_size =
        DefaultWorkSize(output_dims,
                        DDim(std::vector<DDim::value_type>{
                            static_cast<int64_t>(out_image_shape["width"]),
                            static_cast<int64_t>(out_image_shape["height"])}));

    int c_block = default_work_size[0];
    int w = default_work_size[1];
    int nh = default_work_size[2];
    VLOG(4) << "default work size: "
            << "{" << c_block << ", " << w << ", " << nh << ""
            << "}";

    // Validate the vector sizes before any element is indexed (the original
    // code logged paddings[0..1] and dilations[0..1] before checking).
    const auto& paddings = *param.paddings;
    CHECK_GE(paddings.size(), 2);
    CHECK(paddings[0] == paddings[1]);
    VLOG(4) << "paddings: " << paddings[0] << "," << paddings[1];

    const auto& strides = param.strides;
    CHECK_GE(strides.size(), 2);
    CHECK(strides[0] == strides[1]);

    const auto& dilations = *param.dilations;
    VLOG(4) << "dilations.size : " << dilations.size();
    CHECK_GE(dilations.size(), 2);
    CHECK(dilations[0] == dilations[1]);
    VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1];

    auto& context = ctx_->As<OpenCLContext>();
    CHECK(context.cl_context() != nullptr);

    auto* input_image = param.x->data<float, cl::Image2D>();
    auto* filter_image = param.filter->data<float, cl::Image2D>();

    // handle bias use buffer for channel wise , use image for element wise
    const bool has_bias = param.bias != nullptr;
    const bool is_element_wise_bias =
        has_bias && param.output->dims() == param.bias->dims();
    VLOG(4) << "has bias: " << has_bias;
    VLOG(4) << "is_element_wise_bias : " << is_element_wise_bias;
    VLOG(4) << "get kernel ...";

    const cl::Buffer* bias_buf = nullptr;
    const cl::Image2D* bias_image = nullptr;
    if (has_bias) {
#ifndef USE_BUFFER_FOR_CONV1x1_BIAS
      is_element_wise_bias
          ? (bias_image = param.bias->data<float, cl::Image2D>())
          : (bias_buf = param.bias->data<float, cl::Buffer>());
#else
      bias_image = param.bias->data<float, cl::Image2D>();
#endif
    }

    STL::stringstream kernel_key;
    kernel_key << kernel_func_name_ << build_options_;
    VLOG(4) << "kernel_key: " << kernel_key.str();
    auto kernel = context.cl_context()->GetKernel(kernel_key.str());
    VLOG(4) << "kernel ready ... " << kernel_key.str();

    // Arguments are bound strictly in kernel-signature order.
    cl_int status;
    int arg_idx = 0;
    status = kernel.setArg(arg_idx, c_block);
    CL_CHECK_FATAL(status);

    // Each work-item handles four output columns, so dimension 1 of the
    // global range is ceil(w / 4).
    int maped_w = maptofactor(w, 4);
    VLOG(4) << "maped_w: " << maped_w;
    status = kernel.setArg(++arg_idx, maped_w);
    CL_CHECK_FATAL(status);

    status = kernel.setArg(++arg_idx, nh);
    VLOG(4) << "nh: " << nh;
    CL_CHECK_FATAL(status);

    status = kernel.setArg(++arg_idx, *input_image);
    VLOG(4) << "input_image: ";
    CL_CHECK_FATAL(status);

    status = kernel.setArg(++arg_idx, *filter_image);
    VLOG(4) << "filter_image: ";
    CL_CHECK_FATAL(status);

    if (has_bias) {
#ifndef USE_BUFFER_FOR_CONV1x1_BIAS
      if (is_element_wise_bias != 0) {
        VLOG(4) << "set bias_image: ";
        status = kernel.setArg(++arg_idx, *bias_image);
      } else {
        VLOG(4) << "set bias_buf: ";
        status = kernel.setArg(++arg_idx, *bias_buf);
      }
#else
      status = kernel.setArg(++arg_idx, *bias_image);
#endif
      CL_CHECK_FATAL(status);
    }

    status = kernel.setArg(++arg_idx, *out_image);
    VLOG(4) << "out_image: ";
    CL_CHECK_FATAL(status);

    status = kernel.setArg(++arg_idx, strides[0]);
    VLOG(4) << "strides: " << strides[0] << "," << strides[1];
    CL_CHECK_FATAL(status);

    int offset = static_cast<int>(param.filter->dims()[2]) / 2 -
                 static_cast<int>(paddings[0]);
    VLOG(4) << "offset: " << offset;
    status = kernel.setArg(++arg_idx, offset);
    CL_CHECK_FATAL(status);

    // calc input_c_block
    auto input_image_shape = InitImageDimInfoWith(input_dims);
    VLOG(4) << "input_image_shape: " << input_image_shape["width"] << ","
            << input_image_shape["height"];
    int input_c_block = input_image_shape["width"] / input_dims[3];
    VLOG(4) << "input_c_block: " << input_c_block;
    status = kernel.setArg(++arg_idx, input_c_block);
    CL_CHECK_FATAL(status);

    int input_c = input_dims[1];
    VLOG(4) << "input_c: " << input_c;
    status = kernel.setArg(++arg_idx, input_c);
    CL_CHECK_FATAL(status);

    status = kernel.setArg(++arg_idx, dilations[0]);
    CL_CHECK_FATAL(status);

    status = kernel.setArg(++arg_idx, input_width);
    CL_CHECK_FATAL(status);
    status = kernel.setArg(++arg_idx, input_height);
    CL_CHECK_FATAL(status);
    status = kernel.setArg(++arg_idx, output_width);
    CL_CHECK_FATAL(status);
    status = kernel.setArg(++arg_idx, output_height);
    CL_CHECK_FATAL(status);
    status = kernel.setArg(++arg_idx, w);
    CL_CHECK_FATAL(status);

    // clac gloabl_work_size
    auto global_work_size =
        cl::NDRange{static_cast<size_t>(default_work_size.data()[0]),
                    static_cast<size_t>(maped_w),
                    static_cast<size_t>(default_work_size.data()[2])};
    VLOG(4) << "global_work_size :" << global_work_size;

    status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
        kernel,
        cl::NullRange,
        global_work_size,
        cl::NullRange,
        nullptr,
        event_.get());
    CL_CHECK_FATAL(status);
    // Publish the completion event so consumers of out_image can wait on it.
    context.cl_wait_list()->emplace(out_image, event_);
  }

 private:
  std::string kernel_func_name_{"conv_1x1"};
  std::string build_options_{"-DCL_DTYPE=float "};
  std::shared_ptr<cl::Event> event_{new cl::Event};
};
} // namespace opencl
} // namespace kernels
} // namespace lite
} // namespace paddle
// Registers Conv2d1x1Image2DCompute as the "conv2d_1x1" kernel for
// kOpenCL / kFloat / kNHWC under the alias "image2d".
// Tensor bindings: "Input", "Bias" and "Output" are declared with the
// kImageDefault layout, "Filter" with the kImageNW layout.
REGISTER_LITE_KERNEL(conv2d_1x1,
kOpenCL,
kFloat,
kNHWC,
paddle::lite::kernels::opencl::Conv2d1x1Image2DCompute,
image2d)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageDefault))})
.BindInput("Bias",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageDefault))})
.BindInput("Filter",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageNW))})
.BindOutput("Output",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageDefault))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <random>
#include "lite/backends/opencl/cl_image_converter.h"
#include "lite/backends/opencl/target_wrapper.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
// Naive reference implementation of a grouped 2-D convolution.
//
// din / dout are indexed as [n][c][h][w]; weights as
// [group][out_c_per_group][in_c_per_group][kh][kw]; bias (read only when
// flag_bias is true) as [out_channel]. Input taps that fall outside
// [0, hin) x [0, win) are skipped, which is equivalent to zero padding.
// When flag_relu is true every output value that is not greater than zero is
// replaced by zero.
template <typename Dtype1, typename Dtype2>
static void conv_basic(const Dtype1* din,
                       Dtype2* dout,
                       int num,
                       int chout,
                       int hout,
                       int wout,
                       int chin,
                       int hin,
                       int win,
                       const Dtype1* weights,
                       const Dtype2* bias,
                       int group,
                       int kernel_w,
                       int kernel_h,
                       int stride_w,
                       int stride_h,
                       int dila_w,
                       int dila_h,
                       int pad_w,
                       int pad_h,
                       bool flag_bias,
                       bool flag_relu) {
  const int oc_per_group = chout / group;
  const int ic_per_group = chin / group;
  const int out_plane = hout * wout;
  const int in_plane = hin * win;
  const int kernel_area = kernel_h * kernel_w;

  for (int n = 0; n < num; ++n) {
    for (int g = 0; g < group; ++g) {
      for (int oc = 0; oc < oc_per_group; ++oc) {
        // Start of the output plane for (n, g, oc).
        const int out_base = n * group * oc_per_group * out_plane +
                             g * oc_per_group * out_plane + oc * out_plane;
        for (int oh = 0; oh < hout; ++oh) {
          for (int ow = 0; ow < wout; ++ow) {
            // Accumulate directly in the destination element.
            Dtype2& acc = dout[out_base + oh * wout + ow];
            acc = flag_bias ? bias[g * oc_per_group + oc]
                            : static_cast<Dtype2>(0);

            for (int ic = 0; ic < ic_per_group; ++ic) {
              const int in_base = n * chin * in_plane +
                                  g * ic_per_group * in_plane + ic * in_plane;
              const int w_base =
                  g * oc_per_group * ic_per_group * kernel_area +
                  oc * ic_per_group * kernel_area + ic * kernel_area;

              for (int kh = 0; kh < kernel_h; ++kh) {
                const int ih = oh * stride_h - pad_h + kh * dila_h;
                // A row outside the input contributes nothing for any kw.
                if (ih < 0 || ih >= hin) continue;

                for (int kw = 0; kw < kernel_w; ++kw) {
                  const int iw = ow * stride_w - pad_w + kw * dila_w;
                  if (iw < 0 || iw >= win) continue;

                  acc += din[in_base + ih * win + iw] *
                         weights[w_base + kh * kernel_w + kw];
                }
              }
            }

            if (flag_relu && !(acc > static_cast<Dtype2>(0))) {
              acc = static_cast<Dtype2>(0);
            }
          }
        }
      }
    }
  }
}
// End-to-end check of the conv2d_1x1 OpenCL image kernel against the CPU
// reference conv_basic: random input/filter, channel-wise bias and fused ReLU.
TEST(conv2d_1x1, compute) {
  // conv infos
  const int ksize = 1;
  const int stride = 1;
  const int pad = 0;
  const int group = 1;
  // 1 is the neutral dilation; for a 1x1 filter the value has no effect on
  // the sampled position.
  const int dilation = 1;
  // int loop_cnt = 0;

  const bool bias_flag = true;
  const bool relu_flag = true;
  const int batch_size = 8;
  const int oc = 64;
  const int ih = 28;
  const int iw = 28;
  // Deliberately not a multiple of 4 so the padded last channel block is
  // exercised.
  const int ic = 63;

  const int oh = ih;
  const int ow = iw;

  LOG(INFO) << "to get kernel ...";
  auto kernels = KernelRegistry::Global().Create(
      "conv2d_1x1", TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kNHWC));
  ASSERT_FALSE(kernels.empty());

  auto kernel = std::move(kernels.front());
  LOG(INFO) << "created conv2d_1x1 kernel";

  LOG(INFO) << "prepare kernel ------";

  lite::Tensor input, filter, bias, output;
  operators::ConvParam param;
  param.x = &input;
  param.filter = &filter;
  param.output = &output;
  if (bias_flag) {
    param.bias = &bias;
  }
  param.fuse_relu = relu_flag;

  std::vector<int> paddings = {pad, pad, pad, pad};
  std::vector<int> dilations = {dilation, dilation};

  param.paddings = std::make_shared<std::vector<int>>(paddings);
  param.dilations = std::make_shared<std::vector<int>>(dilations);
  param.strides = std::vector<int>{stride, stride};

  std::unique_ptr<KernelContext> context(new KernelContext);
  context->As<OpenCLContext>().InitOnce();

  std::unique_ptr<KernelContext> conv_1x1_context(new KernelContext);
  context->As<OpenCLContext>().CopySharedTo(
      &(conv_1x1_context->As<OpenCLContext>()));
  kernel->SetContext(std::move(conv_1x1_context));

  const DDim& input_dim =
      lite::DDim{std::vector<int64_t>({batch_size, ic, ih, iw})};

  const DDim& filter_dim =
      lite::DDim{std::vector<int64_t>({oc, ic, ksize, ksize})};
  const DDim& out_dim =
      lite::DDim{std::vector<int64_t>({batch_size, oc, ih, iw})};
  // channel-wise bias: one value per output channel
  const DDim& bias_dim = lite::DDim{std::vector<int64_t>({oc})};

  param.x->Resize(input_dim);
  param.filter->Resize(filter_dim);
  param.output->Resize(out_dim);
  if (bias_flag) {
    param.bias->Resize(bias_dim);
  }

  kernel->SetParam(param);

  size_t input_image_width = iw * ((ic + 3) / 4);
  size_t input_image_height = ih * batch_size;

  size_t out_image_width = ow * ((oc + 3) / 4);
  size_t out_image_height = oh * batch_size;

  size_t bias_image_width = ow * ((oc + 3) / 4);
  size_t bias_image_height = oh * batch_size;

  size_t filter_image_width = ksize * ((oc + 3) / 4);
  size_t filter_image_height = ic * ksize;

  auto* input_data = input.mutable_data<float, cl::Image2D>(input_image_width,
                                                            input_image_height);
  auto* filter_data = filter.mutable_data<float, cl::Image2D>(
      filter_image_width, filter_image_height);
  // One allocation is enough; the original requested the same image twice.
  auto* bias_data = bias.mutable_data<float, cl::Image2D>(bias_image_width,
                                                          bias_image_height);

  const size_t cl_image2d_row_pitch{0};
  const size_t cl_image2d_slice_pitch{0};

  LOG(INFO) << "map input ...";
  auto* mapped_input =
      static_cast<float*>(TargetWrapperCL::MapImage(input_data,
                                                    input_image_width,
                                                    input_image_height,
                                                    cl_image2d_row_pitch,
                                                    cl_image2d_slice_pitch));

  LOG(INFO) << "map filter ...";
  auto* mapped_filter =
      static_cast<float*>(TargetWrapperCL::MapImage(filter_data,
                                                    filter_image_width,
                                                    filter_image_height,
                                                    cl_image2d_row_pitch,
                                                    cl_image2d_slice_pitch));

  std::default_random_engine engine;
  std::uniform_real_distribution<float> gen(-5, 5);
  std::vector<float> input_v(batch_size * ic * ih * iw);
  std::vector<float> filter_v(oc * ic * ksize * ksize);
  std::vector<float> output_v(batch_size * oc * ih * iw);
  std::vector<float> bias_v(oc);

  float* input_v_data = &input_v[0];
  float* filter_v_data = &filter_v[0];
  float* output_v_data = &output_v[0];
  float* bias_v_data = &bias_v[0];

  LOG(INFO) << "gen input and filter ...";

  for (auto& i : input_v) {
    i = gen(engine);
  }
  for (auto& f : filter_v) {
    f = gen(engine);
  }

  LOG(INFO) << "after gen input and filter ...";
  LOG(INFO) << "input_v.size(): " << input_v.size();
  LOG(INFO) << "filter_v.size(): " << filter_v.size();
  LOG(INFO) << "output_v.size(): " << output_v.size();
  LOG(INFO) << "bias_v.size(): " << bias_v.size();
  LOG(INFO) << "input_dim.production(): " << input_dim.production();
  LOG(INFO) << "filter_dim.production(): " << filter_dim.production();
  LOG(INFO) << "out_dim.production(): " << out_dim.production();
  LOG(INFO) << "bias_dim.production(): " << bias_dim.production();
  LOG(INFO) << "4 * input_image_height * input_image_width: "
            << 4 * input_image_height * input_image_width;
  LOG(INFO) << "4 * filter_image_width * filter_image_height: "
            << 4 * filter_image_width * filter_image_height;

  // The images hold 4 floats per pixel, so they must be able to contain every
  // tensor element.
  CHECK(input_dim.production() == input_v.size());
  CHECK_LE(input_dim.production(), 4 * input_image_height * input_image_width);
  CHECK(filter_dim.production() == filter_v.size());
  CHECK_LE(filter_dim.production(),
           4 * filter_image_width * filter_image_height);

  paddle::lite::CLImageConverterDefault default_convertor;
  LOG(INFO) << "set mapped input ...";
  default_convertor.NCHWToImage(input_v_data, mapped_input, input_dim);
  LOG(INFO) << "set mapped filter ...";
  paddle::lite::CLImageConverterNWBlock nw_convertor;
  nw_convertor.NCHWToImage(filter_v_data, mapped_filter, filter_dim);

  LOG(INFO) << "resize output ...";
  output.Resize(out_dim);

  // cpu conv basic calc
  lite::Tensor out_ref;
  out_ref.Resize(out_dim);

  float* mapped_bias = nullptr;
  if (bias_flag) {
    mapped_bias =
        static_cast<float*>(TargetWrapperCL::MapImage(bias_data,
                                                      bias_image_width,
                                                      bias_image_height,
                                                      cl_image2d_row_pitch,
                                                      cl_image2d_slice_pitch));

    // Integer-valued biases drawn from the same distribution.
    for (int i = 0; i < bias_dim.production(); ++i) {
      bias_v[i] = static_cast<int>(gen(engine));
    }

    CLImageConverterFolder folder_convertor;
    folder_convertor.NCHWToImage(bias_v_data, mapped_bias, bias_dim);
  }

  LOG(INFO) << "prepare kernel ready";

  LOG(INFO) << "kernel launch ...";
  kernel->Launch();
  LOG(INFO) << "mutable output ...";
  auto* output_data = output.mutable_data<float, cl::Image2D>(out_image_width,
                                                              out_image_height);

  // Wait for the event the kernel registered for its output image.
  auto* wait_list = context->As<OpenCLContext>().cl_wait_list();
  auto* out_ptr = param.output->data<float, cl::Image2D>();
  auto it = wait_list->find(out_ptr);

  if (it != wait_list->end()) {
    VLOG(4) << "--- Find the sync event for the target cl tensor. ---";
    auto& event = *(it->second);
    event.wait();
  } else {
    LOG(FATAL) << "Could not find the sync event for the target cl tensor.";
  }

  auto* mapped_output =
      static_cast<float*>(TargetWrapperCL::MapImage(output_data,
                                                    out_image_width,
                                                    out_image_height,
                                                    cl_image2d_row_pitch,
                                                    cl_image2d_slice_pitch));
  LOG(INFO) << "mutable_data out_ref_data: ";

  // run cpu ref
  auto* out_ref_data = out_ref.mutable_data<float>(TARGET(kARM));

  LOG(INFO) << " conv_basic beigin ..... ";

  conv_basic<float, float>(input_v_data,
                           out_ref_data,
                           batch_size,
                           oc,
                           oh,
                           ow,
                           ic,
                           ih,
                           iw,
                           filter_v_data,
                           bias_v_data,  // mapped_bias,
                           group,
                           ksize,
                           ksize,
                           stride,
                           stride,
                           dilation,
                           dilation,
                           pad,
                           pad,
                           bias_flag,
                           relu_flag);
  LOG(INFO) << " conv_basic end ..... ";

  LOG(INFO) << " out_dim: " << out_dim;
  const DDim& out_image_dims = lite::DDim{
      std::vector<int64_t>({static_cast<int64_t>(out_image_width),
                            static_cast<int64_t>(out_image_height)})};
  default_convertor.ImageToNCHW(
      mapped_output, output_v_data, out_image_dims, out_dim);

  for (int i = 0; i < out_dim.production(); i++) {
    EXPECT_NEAR(output_v_data[i], out_ref_data[i], 1e-3);
    // Compare the magnitude explicitly: an unqualified abs() on a float may
    // resolve to the integer overload and truncate the difference to zero.
    const float diff = output_v_data[i] - out_ref_data[i];
    if (diff > 1e-3 || diff < -1e-3) {
      LOG(FATAL) << "error idx:" << i;
    }
  }

  TargetWrapperCL::Unmap(output_data, mapped_output);
  TargetWrapperCL::Unmap(filter_data, mapped_filter);
  TargetWrapperCL::Unmap(input_data, mapped_input);
  if (bias_flag) {
    if (mapped_bias) {
      TargetWrapperCL::Unmap(bias_data, mapped_bias);
    }
  }
}
} // namespace lite
} // namespace paddle
// References the conv2d_1x1 kernel registered for kOpenCL / kFloat / kNHWC
// under the alias "image2d".
// NOTE(review): presumably this forces the kernel's registration to be linked
// into the test binary — confirm against the USE_LITE_KERNEL definition.
USE_LITE_KERNEL(conv2d_1x1, kOpenCL, kFloat, kNHWC, image2d);
......@@ -40,6 +40,39 @@ static std::map<std::string, size_t> InitImageDimInfoWith(
size_t height = H * N;
return std::map<std::string, size_t>({{"width", width}, {"height", height}});
}
// Number of chunks of size `factor` needed to cover `i` items, i.e. `i`
// divided by `factor` rounded up (for non-negative i and positive factor).
inline static int maptofactor(int i, int factor) {
  const int biased = i + factor - 1;
  return biased / factor;
}
// Derives a default 3-D global work size from a tensor's dims and the shape
// (width, height) of the image that backs it.
//
//   4-D dims (n, c, h, w): {image_width / w, w, n * h}
//   2-D dims:              {1, image_shape[0], image_shape[1]}
//   1-D dims:              {1, image_shape[0], 1}
//   3-D dims (c, h, w):    {(c + 3) / 4, w, h}
//
// Any other rank is reported through LOG(FATAL).
static std::vector<size_t> DefaultWorkSize(const DDim& image_dim,
                                           const DDim& image_shape) {
  // n c h w
  // auto image_dim = image.dims();
  if (image_dim.size() == 4) {
    auto n = image_dim[0];
    auto h = image_dim[2];
    auto w = image_dim[3];
    auto image_width = image_shape[0];
    // Number of channel blocks laid side by side along the image width.
    size_t work_size_0 = image_width / w;
    size_t work_size_1 = w;
    size_t work_size_2 = n * h;
    return {work_size_0, work_size_1, work_size_2};
  } else if (image_dim.size() == 2) {
    return {1,
            static_cast<unsigned int>(image_shape[0]),
            static_cast<unsigned int>(image_shape[1])};
  } else if (image_dim.size() == 1) {
    return {1, static_cast<unsigned int>(image_shape[0]), 1};
  } else if (image_dim.size() == 3) {
    size_t c = image_dim[0];
    size_t h = image_dim[1];
    size_t w = image_dim[2];
    return {(c + 3) / 4, w, h};
  }
  LOG(FATAL) << " not support this dim, need imp ";
  // Keeps every control path returning a value; not expected to be reached
  // once LOG(FATAL) has fired.
  return {};
}
} // namespace opencl
} // namespace kernels
......
......@@ -126,6 +126,102 @@ class LayoutComputeBufferChwToImage2DHwc
std::shared_ptr<cl::Event> event_{new cl::Event};
};
// buffer chw 2 image2d nw
class LayoutComputeBufferChwToImage2DNw
: public KernelLite<TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageNW)> {
public:
using param_t = operators::LayoutParam;
// Registers the layout-conversion program named by kernel_func_name_ with the
// OpenCL context, compiled with build_options_.
void PrepareForRun() override {
  auto& opencl_ctx = ctx_->As<OpenCLContext>();
  opencl_ctx.cl_context()->AddKernel(
      kernel_func_name_, "buffer/layout_kernel.cl", build_options_);
}
// Converts the 4-D buffer tensor param.x into the image2d tensor param.y.
// The image is W * ceil(N / 4) pixels wide and C * H pixels high. The kernel
// is enqueued over {ceil(N / 4), W, C * H} work-items and the command queue
// is drained before returning.
void Run() override {
  auto& param = Param<param_t>();
  auto* x_data = param.x->data<float, cl::Buffer>();
  auto x_dims = param.x->dims();

  CHECK(x_dims.size() == 4) << " Tensor dim is not 4.";
  size_t image_width = x_dims[3] * ((x_dims[0] + 3) / 4);
  size_t image_height = x_dims[1] * x_dims[2];

  auto* y_data =
      param.y->mutable_data<float, cl::Image2D>(image_width, image_height);

  // out info: right-align the dims into a 4-element (N, C, H, W) vector.
  std::vector<size_t> new_dims = {1, 1, 1, 1};
  const size_t rank = x_dims.size();
  for (size_t tidx = 0; tidx < rank; ++tidx) {
    new_dims[4 - rank + tidx] = x_dims[tidx];
  }

  const int out_N = new_dims[0];
  const int out_C = new_dims[1];
  const int out_H = new_dims[2];
  const int out_W = new_dims[3];

  // Element strides of the source buffer for the n, c and h indices.
  const int Stride2 = out_C * out_H * out_W;
  const int Stride1 = out_H * out_W;
  const int Stride0 = out_W;

  auto& context = ctx_->As<OpenCLContext>();
  CHECK(context.cl_context() != nullptr);
  STL::stringstream kernel_key;
  kernel_key << kernel_func_name_ << build_options_;
  auto kernel = context.cl_context()->GetKernel(kernel_key.str());

  // Arguments follow the buffer_to_image2d_nw signature:
  // in, output_image, out_H, out_W, out_N, Stride0, Stride1, Stride2.
  int arg_idx = 0;
  cl_int status = kernel.setArg(arg_idx, *x_data);
  CL_CHECK_FATAL(status);
  status = kernel.setArg(++arg_idx, *y_data);
  CL_CHECK_FATAL(status);
  status = kernel.setArg(++arg_idx, out_H);
  CL_CHECK_FATAL(status);
  status = kernel.setArg(++arg_idx, out_W);
  CL_CHECK_FATAL(status);
  status = kernel.setArg(++arg_idx, out_N);
  CL_CHECK_FATAL(status);
  status = kernel.setArg(++arg_idx, Stride0);
  CL_CHECK_FATAL(status);
  status = kernel.setArg(++arg_idx, Stride1);
  CL_CHECK_FATAL(status);
  status = kernel.setArg(++arg_idx, Stride2);
  CL_CHECK_FATAL(status);

  VLOG(4) << "gws:[3D]" << ((out_N + 3) / 4) << " " << out_W << " "
          << (out_C * out_H);
  auto global_work_size =
      cl::NDRange{static_cast<cl::size_type>((out_N + 3) / 4),  // N blocks
                  static_cast<cl::size_type>(out_W),            // w
                  static_cast<cl::size_type>(out_C * out_H)};   // ch
  status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
      kernel,
      cl::NullRange,
      global_work_size,
      cl::NullRange,
      nullptr,
      event_.get());
  CL_CHECK_FATAL(status);
  // TODO(ysh329): io_copy(device->host) jammed if emplace to `cl_wait_list`
  // context.cl_wait_list()->emplace(y_data, event_);
  context.cl_context()->GetCommandQueue().finish();
  // auto image_shape = InitImageDimInfoWith(x_dims);
}
std::string doc() const override {
return "Trans Layout from cl::Buffer(NCHW) to cl::Image2D(CLNW)";
}
private:
std::string kernel_func_name_{"buffer_to_image2d_nw"};
std::string build_options_{"-DCL_DTYPE_float "};
std::shared_ptr<cl::Event> event_{new cl::Event};
};
class LayoutComputeImage2DHwcToBufferChw
: public KernelLite<TARGET(kOpenCL), PRECISION(kAny), DATALAYOUT(kNCHW)> {
public:
......@@ -293,3 +389,21 @@ REGISTER_LITE_KERNEL(
PRECISION(kAny),
DATALAYOUT(kNCHW))})
.Finalize();
// buffer [nchw] -> image2d [nw]
// Registers LayoutComputeBufferChwToImage2DNw for the `layout_once` op under
// (kOpenCL, kFloat, kImageNW) with the alias
// buffer_chw_to_image2d_nw_opencl_fp32. "Input" is bound as an OpenCL float
// tensor in kNCHW layout and "Out" as an OpenCL float tensor in kImageNW
// layout.
REGISTER_LITE_KERNEL(
layout_once,
kOpenCL,
kFloat,
kImageNW,
paddle::lite::kernels::opencl::LayoutComputeBufferChwToImage2DNw,
buffer_chw_to_image2d_nw_opencl_fp32)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageNW))})
.Finalize();
......@@ -144,7 +144,141 @@ TEST(layout, compute) {
// nothing to do.
#endif
}
// Runs the OpenCL buffer(NCHW) -> image2d(NW) layout kernel and compares its
// output with the CPU reference produced by CLImageConverterNWBlock.
//
// With LOOP_TEST defined a grid of (n, c, h, w) shapes is swept; otherwise a
// single 1x1x1x100 input is used. PRINT_RESULT additionally dumps
// input/output pairs to stdout.
TEST(layout, compute_buffer2image2dnw) {
#ifdef LOOP_TEST
  for (int n = 1; n <= 100; n += 21) {
    for (auto c : {1, 3}) {
      for (int h = 1; h <= 100; h += 13) {
        for (int w = 1; w <= 100; w += 17) {
#else
  const int n = 1;
  const int c = 1;
  const int h = 1;
  const int w = 100;
#endif  // LOOP_TEST

          LOG(INFO) << "======== input shape[n,c,h,w]:" << n << " " << c << " "
                    << h << " " << w << " ========";

          // set layout kernels
          auto buf_to_img_nw_kernels =
              KernelRegistry::Global().Create("layout_once",
                                              TARGET(kOpenCL),
                                              PRECISION(kFloat),
                                              DATALAYOUT(kImageNW));
          ASSERT_FALSE(buf_to_img_nw_kernels.empty());
          auto buf_to_img_nw_kernel = std::move(buf_to_img_nw_kernels.front());
          LOG(INFO) << "get 1st kernel: " << buf_to_img_nw_kernel->doc();

          // set tensors about op param
          operators::LayoutParam bufferToImageParam;
          lite::Tensor x, y, cpu_y;
          bufferToImageParam.x = &x;
          bufferToImageParam.y = &y;

          const DDim x_dim = DDim(std::vector<DDim::value_type>{n, c, h, w});
          x.Resize(x_dim);
          y.Resize(x_dim);  // useless for image2D
          cpu_y.Resize(x_dim);

          // initialize tensors
          LOG(INFO) << "initialize tensors";

          // input buffer
          auto* x_data = x.mutable_data<float, cl::Buffer>(TARGET(kOpenCL));

          // output image (nw layout): width = w * ceil(n / 4), height = c * h
          size_t image_width = w * ((n + 3) / 4);
          size_t image_height = c * h;
          auto* y_data =
              y.mutable_data<float, cl::Image2D>(image_width, image_height);
          auto* cpu_y_data =
              cpu_y.mutable_data<float, cl::Image2D>(image_width, image_height);

          auto* mapped_x = static_cast<float*>(TargetWrapperCL::Map(
              x_data, 0, sizeof(float) * x_dim.production()));

          const size_t cl_image2d_row_pitch{0};
          const size_t cl_image2d_slice_pitch{0};
          auto* mapped_y = static_cast<float*>(
              TargetWrapperCL::MapImage(y_data,
                                        image_width,
                                        image_height,
                                        cl_image2d_row_pitch,
                                        cl_image2d_slice_pitch));
          auto* mapped_cpu_y = static_cast<float*>(
              TargetWrapperCL::MapImage(cpu_y_data,
                                        image_width,
                                        image_height,
                                        cl_image2d_row_pitch,
                                        cl_image2d_slice_pitch));

          // random input data (default-seeded engine, so runs are repeatable)
          std::default_random_engine engine;
          std::uniform_real_distribution<float> gen(-5, 5);
          for (int i = 0; i < x_dim.production(); ++i) {
            mapped_x[i] = gen(engine);
          }

          // CPU reference for the expected image contents
          CLImageConverterNWBlock nw_converter;
          nw_converter.NCHWToImage(mapped_x, mapped_cpu_y, x_dim);

          // set context and kernel args
          LOG(INFO) << "set context and kernel args";
          std::unique_ptr<KernelContext> context(new KernelContext);
          context->As<OpenCLContext>().InitOnce();

          // set kernel params
          buf_to_img_nw_kernel->SetParam(bufferToImageParam);

          std::unique_ptr<KernelContext> buf_to_img_context(new KernelContext);
          context->As<OpenCLContext>().CopySharedTo(
              &(buf_to_img_context->As<OpenCLContext>()));

          // set context
          buf_to_img_nw_kernel->SetContext(std::move(buf_to_img_context));

          // run kernels
          LOG(INFO) << "run kernel: buf_to_img_kernel";
          buf_to_img_nw_kernel->Launch();

          // result
#ifdef PRINT_RESULT
          LOG(INFO) << "---- print result ----";
          // The loop condition must test eidx; the previous `i` is not
          // declared in this scope.
          for (int eidx = 0; eidx < x_dim.production(); ++eidx) {
            std::cout << mapped_x[eidx] << " -> " << mapped_y[eidx]
                      << std::endl;
          }
#endif  // PRINT_RESULT

          // check result: compare the CPU reference with the kernel output
          for (int eidx = 0; eidx < x_dim.production(); eidx++) {
            EXPECT_NEAR(mapped_cpu_y[eidx], mapped_y[eidx], 1e-3);
            // Compare the float difference directly: an unqualified abs() may
            // resolve to the integer overload and truncate it to 0.
            const float diff = mapped_cpu_y[eidx] - mapped_y[eidx];
            if (diff > 1e-3 || diff < -1e-3) {
              LOG(INFO) << "1st diff in this case at eidx[from 0]:" << eidx
                        << " / " << x_dim.production() << ", mapped_cpu_y["
                        << eidx << "]:" << mapped_cpu_y[eidx] << ", mapped_y["
                        << eidx << "]:" << mapped_y[eidx];
              break;
            }
          }

          // free
          LOG(INFO) << "free: unmap x, y";
          TargetWrapperCL::Unmap(x_data, mapped_x);
          TargetWrapperCL::Unmap(y_data, mapped_y);
          // The reference image was mapped above as well; release it too.
          TargetWrapperCL::Unmap(cpu_y_data, mapped_cpu_y);

#ifdef LOOP_TEST
        }  // w
      }    // h
    }      // c
  }        // n
#else
// nothing to do.
#endif
}
} // namespace lite
} // namespace paddle
......@@ -152,3 +286,8 @@ USE_LITE_KERNEL(
layout, kOpenCL, kAny, kNHWC, buffer_chw_to_image2d_hwc_opencl_fp32);
// NOTE(review): USE_LITE_KERNEL is defined outside this chunk; presumably it
// forces the named kernel registration to be linked into the test binary —
// confirm against the macro definition.
// `layout` kernel for (kOpenCL, kAny, kNCHW), alias
// image2d_hwc_to_buffer_chw_opencl_fp32.
USE_LITE_KERNEL(
layout, kOpenCL, kAny, kNCHW, image2d_hwc_to_buffer_chw_opencl_fp32);
// `layout_once` kernel for (kOpenCL, kFloat, kImageNW), alias
// buffer_chw_to_image2d_nw_opencl_fp32: the same op/target/precision/layout
// that TEST(layout, compute_buffer2image2dnw) requests from KernelRegistry.
USE_LITE_KERNEL(layout_once,
kOpenCL,
kFloat,
kImageNW,
buffer_chw_to_image2d_nw_opencl_fp32);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册