未验证 提交 2a68646c 编写于 作者: Y Yuan Shuai 提交者: GitHub

[LITE][OPENCL] comment All opencl buffer kernels. test=develop (#2874)

* comment All opencl buffer kernels. test=develop

* refactor conv, depthwise into one routing selection. test=develop
上级 5394e674
...@@ -123,10 +123,10 @@ TEST(MobileNetV1, test_arm) { ...@@ -123,10 +123,10 @@ TEST(MobileNetV1, test_arm) {
#ifdef LITE_WITH_OPENCL #ifdef LITE_WITH_OPENCL
TEST(MobileNetV1, test_opencl) { TEST(MobileNetV1, test_opencl) {
std::vector<Place> valid_places({ std::vector<Place> valid_places({
Place{TARGET(kOpenCL), PRECISION(kFP16), DATALAYOUT(kNCHW)}, Place{TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kImageDefault)},
Place{TARGET(kOpenCL), PRECISION(kFP16), DATALAYOUT(kNHWC)},
Place{TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kNCHW)}, Place{TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kNCHW)},
Place{TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kNHWC)}, Place{TARGET(kOpenCL), PRECISION(kAny), DATALAYOUT(kImageDefault)},
Place{TARGET(kOpenCL), PRECISION(kAny), DATALAYOUT(kNCHW)},
TARGET(kARM), // enable kARM CPU kernel when no opencl kernel TARGET(kARM), // enable kARM CPU kernel when no opencl kernel
}); });
......
...@@ -89,13 +89,13 @@ std::vector<Place> ParserValidPlaces() { ...@@ -89,13 +89,13 @@ std::vector<Place> ParserValidPlaces() {
valid_places.emplace_back(TARGET(kARM)); valid_places.emplace_back(TARGET(kARM));
} else if (target_repr == "opencl") { } else if (target_repr == "opencl") {
valid_places.emplace_back( valid_places.emplace_back(
Place{TARGET(kOpenCL), PRECISION(kFP16), DATALAYOUT(kNCHW)}); Place{TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kImageDefault)});
valid_places.emplace_back(
Place{TARGET(kOpenCL), PRECISION(kFP16), DATALAYOUT(kNHWC)});
valid_places.emplace_back( valid_places.emplace_back(
Place{TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kNCHW)}); Place{TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kNCHW)});
valid_places.emplace_back( valid_places.emplace_back(
Place{TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kNHWC)}); Place{TARGET(kOpenCL), PRECISION(kAny), DATALAYOUT(kImageDefault)});
valid_places.emplace_back(
Place{TARGET(kOpenCL), PRECISION(kAny), DATALAYOUT(kNCHW)});
valid_places.emplace_back( valid_places.emplace_back(
TARGET(kARM)); // enable kARM CPU kernel when no opencl kernel TARGET(kARM)); // enable kARM CPU kernel when no opencl kernel
} else if (target_repr == "x86") { } else if (target_repr == "x86") {
......
...@@ -142,7 +142,7 @@ __kernel void depth_conv2d_3x3(__private const int global_size_dim0, ...@@ -142,7 +142,7 @@ __kernel void depth_conv2d_3x3(__private const int global_size_dim0,
#endif #endif
#ifdef RELU #ifdef RELU
output = activation(output); output = activation_type4(output);
#endif #endif
...@@ -309,8 +309,8 @@ __kernel void depth_conv2d_3x3s1(__private const int ou_ch_blk, ...@@ -309,8 +309,8 @@ __kernel void depth_conv2d_3x3s1(__private const int ou_ch_blk,
#endif #endif
#ifdef RELU #ifdef RELU
output[0] = activation(output[0]); output[0] = activation_type4(output[0]);
output[1] = activation(output[1]); output[1] = activation_type4(output[1]);
#endif #endif
WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, (int2)(ou_x, ou_nh_id), output[0]); WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, (int2)(ou_x, ou_nh_id), output[0]);
......
...@@ -16,7 +16,6 @@ add_kernel(io_copy_compute_opencl OPENCL basic SRCS io_copy_compute.cc DEPS ${te ...@@ -16,7 +16,6 @@ add_kernel(io_copy_compute_opencl OPENCL basic SRCS io_copy_compute.cc DEPS ${te
add_kernel(relu_opencl OPENCL basic SRCS relu_compute.cc DEPS ${cl_kernel_deps}) add_kernel(relu_opencl OPENCL basic SRCS relu_compute.cc DEPS ${cl_kernel_deps})
add_kernel(sigmoid_opencl OPENCL basic SRCS sigmoid_compute.cc DEPS ${cl_kernel_deps}) add_kernel(sigmoid_opencl OPENCL basic SRCS sigmoid_compute.cc DEPS ${cl_kernel_deps})
add_kernel(depthwise_conv2d_opencl OPENCL basic SRCS depthwise_conv2d_compute.cc DEPS ${cl_kernel_deps}) add_kernel(depthwise_conv2d_opencl OPENCL basic SRCS depthwise_conv2d_compute.cc DEPS ${cl_kernel_deps})
#add_kernel(conv2d_1x1_opencl OPENCL basic SRCS conv2d_1x1_compute.cc DEPS ${cl_kernel_deps})
add_kernel(reshape_opencl OPENCL basic SRCS reshape_compute.cc DEPS ${cl_kernel_deps}) add_kernel(reshape_opencl OPENCL basic SRCS reshape_compute.cc DEPS ${cl_kernel_deps})
add_kernel(conv_opencl OPENCL basic SRCS conv_compute.cc DEPS ${cl_kernel_deps} cl_image_converter) add_kernel(conv_opencl OPENCL basic SRCS conv_compute.cc DEPS ${cl_kernel_deps} cl_image_converter)
add_kernel(layout_opencl OPENCL basic SRCS layout_compute.cc DEPS ${cl_kernel_deps}) add_kernel(layout_opencl OPENCL basic SRCS layout_compute.cc DEPS ${cl_kernel_deps})
...@@ -62,14 +61,10 @@ lite_cc_test(test_depthwise_conv2d_opencl SRCS depthwise_conv2d_compute_test.cc ...@@ -62,14 +61,10 @@ lite_cc_test(test_depthwise_conv2d_opencl SRCS depthwise_conv2d_compute_test.cc
DEPS depthwise_conv2d_opencl op_registry program context DEPS depthwise_conv2d_opencl op_registry program context
ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl)
lite_cc_test(test_depthwise_conv2d_basic_opencl SRCS depthwise_conv2d_basic_compute_test.cc lite_cc_test(test_depthwise_conv2d_image2d_opencl SRCS depthwise_conv2d_image2d_compute_test.cc
DEPS depthwise_conv2d_opencl op_registry program context DEPS conv_opencl op_registry program context
ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl)
#lite_cc_test(test_conv2d_1x1_opencl SRCS conv2d_1x1_compute_test.cc
# DEPS conv2d_1x1_opencl op_registry program context
# ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl)
lite_cc_test(test_reshape_opencl SRCS reshape_compute_test.cc lite_cc_test(test_reshape_opencl SRCS reshape_compute_test.cc
DEPS reshape_opencl op_registry program context DEPS reshape_opencl op_registry program context
ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl)
......
...@@ -356,17 +356,17 @@ REGISTER_LITE_KERNEL( ...@@ -356,17 +356,17 @@ REGISTER_LITE_KERNEL(
DATALAYOUT(kImageDefault))}) DATALAYOUT(kImageDefault))})
.Finalize(); .Finalize();
REGISTER_LITE_KERNEL(concat, kOpenCL, kFloat, kNCHW, Concat_buffer, def) // REGISTER_LITE_KERNEL(concat, kOpenCL, kFloat, kNCHW, Concat_buffer, def)
.BindInput("X", // .BindInput("X",
{LiteType::GetTensorTy(TARGET(kOpenCL), // {LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFloat), // PRECISION(kFloat),
DATALAYOUT(kNCHW))}) // DATALAYOUT(kNCHW))})
.BindInput("AxisTensor", // .BindInput("AxisTensor",
{LiteType::GetTensorTy(TARGET(kOpenCL), // {LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kInt32), // PRECISION(kInt32),
DATALAYOUT(kNCHW))}) // DATALAYOUT(kNCHW))})
.BindOutput("Out", // .BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kOpenCL), // {LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFloat), // PRECISION(kFloat),
DATALAYOUT(kNCHW))}) // DATALAYOUT(kNCHW))})
.Finalize(); // .Finalize();
...@@ -73,7 +73,7 @@ void concat_mul_compute_ref(std::vector<const dtype *> ins_data, ...@@ -73,7 +73,7 @@ void concat_mul_compute_ref(std::vector<const dtype *> ins_data,
} }
} }
} }
#if 1 // concat_buffer #if 0 // concat_buffer
TEST(opencl_concat_buffer, compute) { TEST(opencl_concat_buffer, compute) {
// prepare data // prepare data
const DDim x0_dim = DDim(std::vector<DDim::value_type>{1, 2, 3, 4}); const DDim x0_dim = DDim(std::vector<DDim::value_type>{1, 2, 3, 4});
...@@ -382,7 +382,7 @@ TEST(concat_image2d_fp32, compute) { ...@@ -382,7 +382,7 @@ TEST(concat_image2d_fp32, compute) {
} // namespace paddle } // namespace paddle
// concat buffer // concat buffer
USE_LITE_KERNEL(concat, kOpenCL, kFloat, kNCHW, def); // USE_LITE_KERNEL(concat, kOpenCL, kFloat, kNCHW, def);
// concat image2d fp32 // concat image2d fp32
USE_LITE_KERNEL(layout, kOpenCL, kAny, kImageDefault, NCHW_to_ImageDefault); USE_LITE_KERNEL(layout, kOpenCL, kAny, kImageDefault, NCHW_to_ImageDefault);
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "lite/backends/opencl/cl_include.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/opencl/image_helper.h"
#include "lite/operators/op_params.h"
#include "lite/utils/replace_stl/stream.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace opencl {
#define USE_BUFFER_FOR_CONV1x1_BIAS
class Conv2d1x1Image2DCompute : public KernelLite<TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageDefault)> {
public:
using param_t = operators::ConvParam;
void PrepareForRun() override {
const auto& param = *param_.get_mutable<param_t>();
if (param.fuse_relu) {
build_options_ += " -DRELU";
}
const bool has_bias = param.bias != nullptr;
const bool is_element_wise_bias =
has_bias && param.output->dims() == param.bias->dims();
if (has_bias) {
build_options_ += is_element_wise_bias ? " -DBIASE_ELE" : " -DBIASE_CH";
}
auto& context = ctx_->As<OpenCLContext>();
if (param.x->dims()[1] % 4 == 0) {
context.cl_context()->AddKernel(kernel_func_name_simple_,
"image/conv2d_1x1_kernel.cl",
build_options_);
} else {
context.cl_context()->AddKernel(
kernel_func_name_, "image/conv2d_1x1_kernel.cl", build_options_);
}
}
void Run() override {
const auto& param = *param_.get_mutable<param_t>();
auto input_dims = param.x->dims();
auto paddings = *param.paddings;
auto strides = param.strides;
auto* input_image = param.x->data<float, cl::Image2D>();
auto* filter_image = param.filter->data<float, cl::Image2D>();
auto filter_dims = param.filter->dims();
auto output_dims = param.output->dims();
int input_width = input_dims[3];
int input_height = input_dims[2];
int output_width = output_dims[3];
int output_height = output_dims[2];
auto out_image_shape = InitImageDimInfoWith(output_dims);
auto* out_image = param.output->mutable_data<float, cl::Image2D>(
out_image_shape["width"], out_image_shape["height"]);
const bool has_bias = param.bias != nullptr;
const bool is_element_wise_bias =
has_bias && param.output->dims() == param.bias->dims();
int offset = static_cast<int>(param.filter->dims()[2]) / 2 -
static_cast<int>(paddings[0]);
// calc input_c_block
auto input_image_shape = InitImageDimInfoWith(input_dims);
int input_c_block = input_image_shape["width"] / input_dims[3];
int input_c = input_dims[1];
auto dilations = *param.dilations;
const std::vector<size_t>& default_work_size =
DefaultWorkSize(output_dims,
DDim(std::vector<DDim::value_type>{
static_cast<int64_t>(out_image_shape["width"]),
static_cast<int64_t>(out_image_shape["height"])}));
int c_block = default_work_size[0];
int w = default_work_size[1];
int nh = default_work_size[2];
VLOG(4) << "============ conv2d_1x1 params ============";
VLOG(4) << "input_image_shape: " << input_image_shape["width"] << ","
<< input_image_shape["height"];
VLOG(4) << "input_c_block: " << input_c_block;
VLOG(4) << "input_c: " << input_c;
VLOG(4) << "input_image: " << input_image;
VLOG(4) << "filter_dims: " << filter_dims;
VLOG(4) << "filter_image: " << filter_image;
VLOG(4) << "output_dims: " << output_dims;
VLOG(4) << "out_image_shape: " << out_image_shape["width"] << ", "
<< out_image_shape["height"];
VLOG(4) << "paddings: " << paddings[0] << "," << paddings[1];
VLOG(4) << "has bias: " << has_bias;
VLOG(4) << "is_element_wise_bias : " << is_element_wise_bias;
VLOG(4) << "strides: " << strides[0] << "," << strides[1];
VLOG(4) << "offset: " << offset;
VLOG(4) << "dilations.size : " << dilations.size();
VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1];
VLOG(4) << "default work size{c_block, w, nh}: "
<< "{" << c_block << ", " << w << ", " << nh << ""
<< "}";
CHECK_GE(dilations.size(), 2);
CHECK(dilations[0] == dilations[1]);
CHECK_GE(input_dims.size(), 4);
CHECK_GE(paddings.size(), 2);
CHECK(paddings[0] == paddings[1]);
CHECK_GE(strides.size(), 2);
CHECK(strides[0] == strides[1]);
// handle bias use buffer for channel wise , use image for element wise
const cl::Buffer* bias_buf = nullptr;
const cl::Image2D* bias_image = nullptr;
if (has_bias) {
#ifndef USE_BUFFER_FOR_CONV1x1_BIAS
is_element_wise_bias
? (bias_image = param.bias->data<float, cl::Image2D>())
: (bias_buf = param.bias->data<float, cl::Buffer>());
#else
bias_image = param.bias->data<float, cl::Image2D>();
#endif
}
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
STL::stringstream kernel_key;
if (input_dims[1] % 4 == 0) {
kernel_key << kernel_func_name_simple_ << build_options_;
} else {
kernel_key << kernel_func_name_ << build_options_;
}
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
int maped_w = maptofactor(w, 4);
VLOG(4) << "kernel_key: " << kernel_key.str();
VLOG(4) << "kernel ready ... " << kernel_key.str();
VLOG(4) << "maped_w: " << maped_w;
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, c_block);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, maped_w);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, nh);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *filter_image);
CL_CHECK_FATAL(status);
if (has_bias) {
#ifndef USE_BUFFER_FOR_CONV1x1_BIAS
if (is_element_wise_bias != 0) {
VLOG(4) << "set bias_image: ";
status = kernel.setArg(++arg_idx, *bias_image);
} else {
VLOG(4) << "set bias_buf: ";
status = kernel.setArg(++arg_idx, *bias_buf);
}
#else
status = kernel.setArg(++arg_idx, *bias_image);
#endif
CL_CHECK_FATAL(status);
}
status = kernel.setArg(++arg_idx, *out_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, strides[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, offset);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_c_block);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_c);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, dilations[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_height);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_height);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, w);
CL_CHECK_FATAL(status);
auto global_work_size =
cl::NDRange{static_cast<size_t>(default_work_size.data()[0]),
static_cast<size_t>(maped_w),
static_cast<size_t>(default_work_size.data()[2])};
VLOG(4) << "out_image: " << out_image;
VLOG(4) << "global_work_size[3D]: {" << global_work_size[0] << ","
<< global_work_size[1] << "," << global_work_size[2] << "}";
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
event_.get());
CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_image, event_);
}
private:
std::string kernel_func_name_{"conv2d_1x1"};
std::string kernel_func_name_simple_{"conv2d_1x1_simple"};
std::string build_options_{"-DCL_DTYPE_float"};
std::shared_ptr<cl::Event> event_{new cl::Event};
};
} // namespace opencl
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(conv2d_1x1,
kOpenCL,
kFloat,
kImageDefault,
paddle::lite::kernels::opencl::Conv2d1x1Image2DCompute,
image2d)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageDefault))})
.BindInput("Bias",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageDefault))})
.BindInput("Filter",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageNW))})
.BindOutput("Output",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageDefault))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <random>
#include "lite/backends/opencl/cl_image_converter.h"
#include "lite/backends/opencl/target_wrapper.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/utils/logging.h"
namespace paddle {
namespace lite {
template <typename Dtype1, typename Dtype2>
static void conv_basic(const Dtype1* din,
Dtype2* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
const Dtype1* weights,
const Dtype2* bias,
int group,
int kernel_w,
int kernel_h,
int stride_w,
int stride_h,
int dila_w,
int dila_h,
int pad_w,
int pad_h,
bool flag_bias,
bool flag_relu) {
Dtype2 beta = 0;
auto src_data = din;
auto dst_data_ref = dout;
auto weights_data = weights;
auto with_bias = flag_bias;
auto bias_data = bias;
int in_num = num;
int out_channels = chout;
int out_h = hout;
int out_w = wout;
int in_channel = chin;
int in_h = hin;
int in_w = win;
int out_c_group = out_channels / group;
int in_c_group = in_channel / group;
for (int n = 0; n < in_num; ++n) {
for (int g = 0; g < group; ++g) {
for (int oc = 0; oc < out_c_group; ++oc) {
for (int oh = 0; oh < out_h; ++oh) {
for (int ow = 0; ow < out_w; ++ow) {
int out_idx = n * group * out_c_group * out_h * out_w +
g * out_c_group * out_h * out_w + oc * out_h * out_w +
oh * out_w + ow;
Dtype2 bias_d =
with_bias ? (bias_data[g * out_c_group + oc]) : (Dtype2)0;
dst_data_ref[out_idx] = bias_d; // + dst_data_ref[out_idx] * beta;
for (int ic = 0; ic < in_c_group; ++ic) {
for (int kh = 0; kh < kernel_h; ++kh) {
for (int kw = 0; kw < kernel_w; ++kw) {
int iw = ow * stride_w - pad_w + kw * (dila_w);
int ih = oh * stride_h - pad_h + kh * (dila_h);
if (iw < 0 || iw >= in_w) continue;
if (ih < 0 || ih >= in_h) continue;
int iidx = n * in_channel * in_h * in_w +
g * in_c_group * in_h * in_w + ic * in_h * in_w +
ih * in_w + iw;
int widx =
g * out_c_group * in_c_group * kernel_h * kernel_w +
oc * in_c_group * kernel_h * kernel_w +
ic * kernel_h * kernel_w + kh * kernel_w + kw;
dst_data_ref[out_idx] += src_data[iidx] * weights_data[widx];
}
}
}
if (flag_relu) {
dst_data_ref[out_idx] = dst_data_ref[out_idx] > (Dtype2)0
? dst_data_ref[out_idx]
: (Dtype2)0;
}
}
}
}
}
}
}
TEST(conv2d_1x1, compute) {
// conv infos
const int ksize = 1;
const int stride = 1;
const int pad = 0;
const int group = 1;
const int dilation = 0;
// int loop_cnt = 0;
#ifdef LOOP_TEST
for (int batch_size = 1; batch_size < 4; ++batch_size) {
for (int oc = 4; oc < 10; oc += 1) { // oc
for (int ih = 4; ih < 9; ih += 1) { // ih
/*int iw = ih;*/ for (int iw = 4; iw < 10; iw += 1) { // iw
for (int ic = 4; ic < 10; ic += 1) { // ic
for (bool bias_flag : {true, false}) {
for (bool relu_flag : {true, false}) {
#else
const int batch_size = 1;
const int oc = 4;
const int ih = 8;
const int iw = 8;
const int ic = 4;
const bool bias_flag = false;
const bool relu_flag = false;
#endif
const int oh = ih;
const int ow = iw;
VLOG(4) << "to get kernel ...";
auto kernels =
KernelRegistry::Global().Create("conv2d_1x1",
TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageDefault));
ASSERT_FALSE(kernels.empty());
auto kernel = std::move(kernels.front());
VLOG(4) << "created conv2d_1x1 kernel";
VLOG(4) << "prepare kernel ------";
lite::Tensor input, filter, bias, output;
operators::ConvParam param;
param.x = &input;
param.filter = &filter;
param.output = &output;
if (bias_flag) {
param.bias = &bias;
}
param.fuse_relu = relu_flag;
std::vector<int> paddings = {pad, pad, pad, pad};
std::vector<int> dilations = {dilation, dilation};
param.paddings = std::make_shared<std::vector<int>>(paddings);
param.dilations = std::make_shared<std::vector<int>>(dilations);
param.strides = std::vector<int>{stride, stride};
std::unique_ptr<KernelContext> context(new KernelContext);
context->As<OpenCLContext>().InitOnce();
std::unique_ptr<KernelContext> conv_1x1_context(
new KernelContext);
context->As<OpenCLContext>().CopySharedTo(
&(conv_1x1_context->As<OpenCLContext>()));
kernel->SetContext(std::move(conv_1x1_context));
const DDim& input_dim =
lite::DDim{std::vector<int64_t>({batch_size, ic, ih, iw})};
const DDim& filter_dim =
lite::DDim{std::vector<int64_t>({oc, ic, ksize, ksize})};
const DDim& out_dim =
lite::DDim{std::vector<int64_t>({batch_size, oc, ih, iw})};
// element wise bias
const DDim& bias_dim = lite::DDim{std::vector<int64_t>({oc})};
param.x->Resize(input_dim);
param.filter->Resize(filter_dim);
param.output->Resize(out_dim);
if (bias_flag) {
param.bias->Resize(bias_dim);
}
kernel->SetParam(param);
size_t input_image_width = iw * ((ic + 3) / 4);
size_t input_image_height = ih * batch_size;
size_t out_image_width = ow * ((oc + 3) / 4);
size_t out_image_height = oh * batch_size;
size_t bias_image_width = ow * ((oc + 3) / 4);
size_t bias_image_height = oh * batch_size;
size_t filter_image_width = ksize * ((oc + 3) / 4);
size_t filter_image_height = ic * ksize;
const size_t cl_image2d_row_pitch{0};
const size_t cl_image2d_slice_pitch{0};
std::default_random_engine engine;
std::uniform_real_distribution<float> gen(-5, 5);
std::vector<float> input_v(batch_size * ic * ih * iw);
std::vector<float> filter_v(oc * ic * ksize * ksize);
std::vector<float> output_v(batch_size * oc * ih * iw);
std::vector<float> bias_v(oc);
VLOG(4) << "gen input and filter ...";
for (auto& i : input_v) {
i = gen(engine);
}
for (auto& f : filter_v) {
f = gen(engine);
}
VLOG(4) << "after gen input and filter ...";
VLOG(4) << "input_v.size(): " << input_v.size();
VLOG(4) << "filter_v.size(): " << filter_v.size();
VLOG(4) << "output_v.size(): " << output_v.size();
VLOG(4) << "bias_v.size(): " << bias_v.size();
VLOG(4) << "input_dim.production(): " << input_dim.production();
VLOG(4) << "filter_dim.production(): "
<< filter_dim.production();
VLOG(4) << "out_dim.production(): " << out_dim.production();
VLOG(4) << "bias_dim.production(): " << bias_dim.production();
VLOG(4) << "4 * input_image_height * input_image_width: "
<< 4 * input_image_height * input_image_width;
VLOG(4) << "4 * filter_image_width * filter_image_height: "
<< 4 * filter_image_width * filter_image_height;
CHECK(input_dim.production() == input_v.size());
CHECK_LE(input_dim.production(),
4 * input_image_height * input_image_width);
CHECK(filter_dim.production() == filter_v.size());
CHECK_LE(filter_dim.production(),
4 * filter_image_width * filter_image_height);
paddle::lite::CLImageConverterDefault default_convertor;
VLOG(4) << "set mapped input ...";
std::vector<float> x_image_v(
input_image_width * input_image_height * 4); // 4 : RGBA
std::vector<float> filter_image_v(
filter_image_width * filter_image_height * 4); // 4 : RGBA
std::vector<float> bias_image_v(
bias_image_width * bias_image_height * 4); // 4 : RGBA
std::vector<float> out_image_v(
out_image_width * out_image_height * 4); // 4 : RGBA
default_convertor.NCHWToImage(
input_v.data(), x_image_v.data(), input_dim);
/* for (int j = 0; j < input_v.size(); j += 1) {
// VLOG(4) << "input_v
input[" << j << "]:
// " << input_v.data()[j];
std::cout << j << " " << input_v.data()[j] <<
std::endl;
}
std::cout << std::endl;
for (int j = 0; j < x_image_v.size(); j += 1) {
// VLOG(4) << "x_image_v
input[" << j <<
// "]: " <<
x_image_v.data()[j];
std::cout << j << " " << x_image_v.data()[j]
<< std::endl;
}*/
VLOG(4) << "set mapped filter ...";
paddle::lite::CLImageConverterNWBlock nw_convertor;
nw_convertor.NCHWToImage(
filter_v.data(), filter_image_v.data(), filter_dim);
auto* input_image2d = input.mutable_data<float, cl::Image2D>(
input_image_width, input_image_height, x_image_v.data());
auto* filter_image2d = filter.mutable_data<float, cl::Image2D>(
filter_image_width,
filter_image_height,
filter_image_v.data());
if (bias_flag) {
nw_convertor.NCHWToImage(
filter_v.data(), filter_image_v.data(), filter_dim);
for (int i = 0; i < bias_dim.production(); ++i) {
bias_v[i] = static_cast<int>(gen(engine));
}
CLImageConverterFolder folder_convertor;
folder_convertor.NCHWToImage(
bias_v.data(), bias_image_v.data(), bias_dim);
auto* bias_data = bias.mutable_data<float, cl::Image2D>(
bias_image_width, bias_image_height, bias_image_v.data());
}
VLOG(4) << "resize output ...";
output.Resize(out_dim);
// cpu conv basic calc
lite::Tensor out_ref;
out_ref.Resize(out_dim);
VLOG(4) << "prepare kernel ready";
VLOG(4) << "kernel launch ...";
kernel->Launch();
VLOG(4) << "mutable output ...";
auto* output_image2d = output.mutable_data<float, cl::Image2D>(
out_image_width, out_image_height);
auto* wait_list = context->As<OpenCLContext>().cl_wait_list();
auto* out_ptr = param.output->data<float, cl::Image2D>();
auto it = wait_list->find(out_ptr);
if (it != wait_list->end()) {
VLOG(4) << "--- Find the sync event for the target cl "
"tensor. ---";
auto& event = *(it->second);
event.wait();
} else {
LOG(FATAL) << "Could not find the sync event for the target "
"cl tensor.";
}
TargetWrapperCL::ImgcpySync(out_image_v.data(),
output.data<float, cl::Image2D>(),
out_image_width,
out_image_height,
cl_image2d_row_pitch,
cl_image2d_slice_pitch,
IoDirection::DtoH);
DDim out_image_shape =
default_convertor.InitImageDimInfoWith(output.dims());
default_convertor.ImageToNCHW(out_image_v.data(),
output_v.data(),
out_image_shape,
output.dims());
VLOG(4) << "mutable_data out_ref_data: ";
// run cpu ref
auto* out_ref_data = out_ref.mutable_data<float>(TARGET(kARM));
VLOG(4) << " conv_basic beigin ..... ";
conv_basic<float, float>(input_v.data(),
out_ref_data,
batch_size,
oc,
oh,
ow,
ic,
ih,
iw,
filter_v.data(),
bias_v.data(), // mapped_bias,
group,
ksize,
ksize,
stride,
stride,
dilation,
dilation,
pad,
pad,
bias_flag,
relu_flag);
VLOG(4) << " conv_basic end ..... ";
VLOG(4) << " out_dim: " << out_dim;
const DDim& out_image_dims = lite::DDim{std::vector<int64_t>(
{static_cast<int64_t>(out_image_width),
static_cast<int64_t>(out_image_height)})};
for (int i = 0; i < out_dim.production(); i++) {
EXPECT_NEAR(output_v[i], out_ref_data[i], 1e-2);
if (abs(output_v[i] - out_ref_data[i]) > 1e-2) {
LOG(FATAL) << "error idx:" << i;
}
}
#ifdef LOOP_TEST
}
}
}
}
}
}
}
#else
// nothing to do.
#endif
}
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(conv2d_1x1, kOpenCL, kFloat, kImageDefault, image2d);
...@@ -362,6 +362,40 @@ void ConvImageCompute::PrepareForRun() { ...@@ -362,6 +362,40 @@ void ConvImageCompute::PrepareForRun() {
filter_image_dims[0], filter_image_dims[1], filter_image_v.data()); filter_image_dims[0], filter_image_dims[1], filter_image_v.data());
impl_ = &ConvImageCompute::Conv2d1x1; impl_ = &ConvImageCompute::Conv2d1x1;
} else if (filter_dims[1] == 1 && x_dims[1] == output_dims[1] &&
kernel_h == 3 && kernel_w == 3 && groups > 1) {
// depth_conv2d_3x3s1, depth_conv2d_3x3
if (stride_h == 1 && dilations[0] == 1) {
kernel_func_names_.push_back("depth_conv2d_3x3s1");
impl_ = &ConvImageCompute::DepthwiseConv2d3x3s1;
} else {
kernel_func_names_.push_back("depth_conv2d_3x3");
impl_ = &ConvImageCompute::DepthwiseConv2d3x3;
}
kernel_func_paths_.push_back("image/depthwise_conv2d_kernel.cl");
CLImageConverterDWBlock converter;
const DDim& filter_image_dims = converter.InitImageDimInfoWith(filter_dims);
std::vector<float> filter_image_v(filter_image_dims[0] *
filter_image_dims[1] * 4); // 4 : RGBA
converter.NCHWToImage(filter_cpu, filter_image_v.data(), filter_dims);
filter_gpu_image_.mutable_data<float, cl::Image2D>(
filter_image_dims[0], filter_image_dims[1], filter_image_v.data());
} else if (filter_dims[1] == 1 && x_dims[1] == output_dims[1] &&
kernel_h != 3 && groups > 1) {
// depth_conv2d
kernel_func_names_.push_back("depth_conv2d");
kernel_func_paths_.push_back("image/depthwise_conv2d_basic_kernel.cl");
CLImageConverterDWBlock converter;
const DDim& filter_image_dims = converter.InitImageDimInfoWith(filter_dims);
std::vector<float> filter_image_v(filter_image_dims[0] *
filter_image_dims[1] * 4); // 4 : RGBA
converter.NCHWToImage(filter_cpu, filter_image_v.data(), filter_dims);
filter_gpu_image_.mutable_data<float, cl::Image2D>(
filter_image_dims[0], filter_image_dims[1], filter_image_v.data());
impl_ = &ConvImageCompute::DepthwiseConv2d;
} else if (kernel_h == 3 && kernel_h == 3) { } else if (kernel_h == 3 && kernel_h == 3) {
// conv2d_3x3 // conv2d_3x3
kernel_func_names_.push_back("conv2d_3x3"); kernel_func_names_.push_back("conv2d_3x3");
...@@ -407,6 +441,8 @@ void ConvImageCompute::PrepareForRun() { ...@@ -407,6 +441,8 @@ void ConvImageCompute::PrepareForRun() {
} else { } else {
LOG(FATAL) << "conv image compute not support this condition yet! "; LOG(FATAL) << "conv image compute not support this condition yet! ";
} }
VLOG(1) << "kernel_func_names_[0]:" << kernel_func_names_[0]
<< " kernel_func_paths_[0]:" << kernel_func_paths_[0];
std::string build_options_single(" -DCL_DTYPE_float"); std::string build_options_single(" -DCL_DTYPE_float");
// relu options // relu options
...@@ -1064,6 +1100,326 @@ void ConvImageCompute::Conv2d7x7() { ...@@ -1064,6 +1100,326 @@ void ConvImageCompute::Conv2d7x7() {
context.cl_wait_list()->emplace(out_image, event_); context.cl_wait_list()->emplace(out_image, event_);
} }
void ConvImageCompute::DepthwiseConv2d3x3s1() {
const auto& param = *param_.get_mutable<param_t>();
auto x_dims = param.x->dims();
auto filter_dims = param.filter->dims();
auto output_dims = param.output->dims();
auto paddings = *param.paddings;
auto strides = param.strides;
auto dilations = *param.dilations;
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
auto* input_img = param.x->data<float, cl::Image2D>();
auto* filter_img = filter_gpu_image_.data<float, cl::Image2D>();
const cl::Image2D* bias_img = nullptr;
if (param.bias) {
bias_img = bias_gpu_image_.data<float, cl::Image2D>();
}
auto image_shape = InitImageDimInfoWith(output_dims);
auto* output_img = param.output->mutable_data<float, cl::Image2D>(
image_shape["width"], image_shape["height"]);
STL::stringstream kernel_key;
kernel_key << kernel_func_names_[0] << build_options_[0];
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
int c_block = (output_dims[1] + 3) / 4;
int w = output_dims[3];
int nh = output_dims[0] * output_dims[2];
int w_blk_size = 2;
int w_blk = (w + w_blk_size - 1) / w_blk_size;
auto global_work_size = cl::NDRange(c_block, w_blk, nh);
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, static_cast<const int>(c_block));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(w_blk));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(nh));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *filter_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *output_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(strides[0]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(paddings[0]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(dilations[0]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(x_dims[1]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(x_dims[3]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(x_dims[2]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(output_dims[3]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(output_dims[2]));
CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
event_.get());
CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(output_img, event_);
}
void ConvImageCompute::DepthwiseConv2d3x3() {
const auto& param = *param_.get_mutable<param_t>();
auto x_dims = param.x->dims();
auto filter_dims = param.filter->dims();
auto output_dims = param.output->dims();
auto paddings = *param.paddings;
auto strides = param.strides;
auto dilations = *param.dilations;
int offset = filter_dims[2] / 2 - paddings[0];
int input_c_block = (x_dims[1] + 3) / 4;
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
auto* input_img = param.x->data<float, cl::Image2D>();
auto* filter_img = filter_gpu_image_.data<float, cl::Image2D>();
const cl::Image2D* bias_img = nullptr;
if (param.bias) {
bias_img = bias_gpu_image_.data<float, cl::Image2D>();
}
auto image_shape = InitImageDimInfoWith(output_dims);
auto* output_img = param.output->mutable_data<float, cl::Image2D>(
image_shape["width"], image_shape["height"]);
STL::stringstream kernel_key;
kernel_key << kernel_func_names_[0] << build_options_[0];
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
int c_block = (output_dims[1] + 3) / 4;
int w = output_dims[3];
int nh = output_dims[0] * output_dims[2];
auto global_work_size = cl::NDRange(c_block, w, nh);
VLOG(4) << "setArg";
VLOG(4) << "c_block = " << c_block;
VLOG(4) << "w = " << w;
VLOG(4) << "nh = " << nh;
VLOG(4) << "strides = " << strides[0];
VLOG(4) << "offset = " << offset;
VLOG(4) << "dilations = " << dilations[0];
VLOG(4) << "input_c_block = " << input_c_block;
VLOG(4) << "x_dims[3] = " << x_dims[3];
VLOG(4) << "x_dims[2] = " << x_dims[2];
VLOG(4) << "output_dims[3] = " << output_dims[3];
VLOG(4) << "output_dims[2] = " << output_dims[2];
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, static_cast<const int>(c_block));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(w));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(nh));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *filter_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *output_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(strides[0]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(offset));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(dilations[0]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(input_c_block));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(x_dims[3]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(x_dims[2]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(output_dims[3]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(output_dims[2]));
CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
event_.get());
CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(output_img, event_);
}
void ConvImageCompute::DepthwiseConv2d() {
const auto& param = *param_.get_mutable<param_t>();
auto input_dims = param.x->dims();
auto paddings = *param.paddings;
auto strides = param.strides;
auto* input_image = param.x->data<float, cl::Image2D>();
auto* filter_image = filter_gpu_image_.data<float, cl::Image2D>();
auto filter_dims = param.filter->dims();
auto output_dims = param.output->dims();
int input_width = input_dims[3];
int input_height = input_dims[2];
int output_width = output_dims[3];
int output_height = output_dims[2];
int filter_width = filter_dims[3];
int filter_height = filter_dims[2];
auto out_image_shape = InitImageDimInfoWith(output_dims);
auto* out_image = param.output->mutable_data<float, cl::Image2D>(
out_image_shape["width"], out_image_shape["height"]);
const bool has_bias = param.bias != nullptr;
const bool is_element_wise_bias =
has_bias && param.output->dims() == param.bias->dims();
int offset = static_cast<int>(param.filter->dims()[2]) / 2 -
static_cast<int>(paddings[0]);
// calc input_c_block
auto input_image_shape = InitImageDimInfoWith(input_dims);
int input_c_block = input_image_shape["width"] / input_dims[3];
int input_c = input_dims[1];
auto dilations = *param.dilations;
const std::vector<size_t>& default_work_size =
DefaultWorkSize(output_dims,
DDim(std::vector<DDim::value_type>{
static_cast<int64_t>(out_image_shape["width"]),
static_cast<int64_t>(out_image_shape["height"])}));
int c_block = default_work_size[0];
int w = default_work_size[1];
int nh = default_work_size[2];
VLOG(4) << "============ depthwise conv2d params ============";
VLOG(4) << "input_image_shape: " << input_image_shape["width"] << ","
<< input_image_shape["height"];
VLOG(4) << "input_c_block: " << input_c_block;
VLOG(4) << "input_c: " << input_c;
VLOG(4) << "input_image: " << input_image;
VLOG(4) << "filter_dims: " << filter_dims;
VLOG(4) << "filter_image: " << filter_image;
VLOG(4) << "output_dims: " << output_dims;
VLOG(4) << "out_image_shape: " << out_image_shape["width"] << ", "
<< out_image_shape["height"];
VLOG(4) << "paddings: " << paddings[0] << "," << paddings[1];
VLOG(4) << "has bias: " << has_bias;
VLOG(4) << "is_element_wise_bias : " << is_element_wise_bias;
VLOG(4) << "strides: " << strides[0] << "," << strides[1];
VLOG(4) << "offset: " << offset;
VLOG(4) << "dilations.size : " << dilations.size();
VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1];
VLOG(4) << "default work size{c_block, w, nh}: "
<< "{" << c_block << ", " << w << ", " << nh << ""
<< "}";
CHECK_GE(dilations.size(), 2);
CHECK(dilations[0] == dilations[1]);
CHECK_GE(input_dims.size(), 4);
CHECK_GE(paddings.size(), 2);
CHECK(paddings[0] == paddings[1]);
CHECK_GE(strides.size(), 2);
CHECK(strides[0] == strides[1]);
// handle bias use buffer for channel wise , use image for element wise
const cl::Buffer* bias_buf = nullptr;
const cl::Image2D* bias_image = nullptr;
if (has_bias) {
bias_image = bias_gpu_image_.data<float, cl::Image2D>();
}
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
STL::stringstream kernel_key;
kernel_key << kernel_func_names_[0] << build_options_[0];
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
VLOG(4) << "kernel_key: " << kernel_key.str();
VLOG(4) << "kernel ready ... " << kernel_key.str();
VLOG(4) << "w: " << w;
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, c_block);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, w);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, nh);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *filter_image);
CL_CHECK_FATAL(status);
if (has_bias) {
VLOG(4) << "set bias_image: ";
status = kernel.setArg(++arg_idx, *bias_image);
CL_CHECK_FATAL(status);
}
status = kernel.setArg(++arg_idx, *out_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, strides[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, offset);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_c_block);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, dilations[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_height);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_height);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, filter_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, filter_height);
CL_CHECK_FATAL(status);
auto global_work_size =
cl::NDRange{static_cast<size_t>(default_work_size.data()[0]),
static_cast<size_t>(default_work_size.data()[1]),
static_cast<size_t>(default_work_size.data()[2])};
VLOG(4) << "out_image: " << out_image;
VLOG(4) << "global_work_size[3D]: {" << global_work_size[0] << ","
<< global_work_size[1] << "," << global_work_size[2] << "}";
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
event_.get());
CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_image, event_);
}
void ConvImageCompute::Run() { (this->*impl_)(); } void ConvImageCompute::Run() { (this->*impl_)(); }
} // namespace opencl } // namespace opencl
...@@ -1071,19 +1427,37 @@ void ConvImageCompute::Run() { (this->*impl_)(); } ...@@ -1071,19 +1427,37 @@ void ConvImageCompute::Run() { (this->*impl_)(); }
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
// REGISTER_LITE_KERNEL(conv2d,
// kOpenCL,
// kFloat,
// kNCHW,
// paddle::lite::kernels::opencl::ConvCompute,
// def)
// .BindInput("Input", {LiteType::GetTensorTy(TARGET(kOpenCL))})
// .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kOpenCL))})
// .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kOpenCL))})
// .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kOpenCL))})
// .Finalize();
REGISTER_LITE_KERNEL(conv2d, REGISTER_LITE_KERNEL(conv2d,
kOpenCL, kOpenCL,
kFloat, kFloat,
kNCHW, kImageDefault,
paddle::lite::kernels::opencl::ConvCompute, paddle::lite::kernels::opencl::ConvImageCompute,
def) image2d)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kOpenCL))}) .BindInput("Input",
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kOpenCL))}) {LiteType::GetTensorTy(TARGET(kOpenCL),
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kOpenCL))}) PRECISION(kFloat),
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kOpenCL))}) DATALAYOUT(kImageDefault))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Output",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageDefault))})
.Finalize(); .Finalize();
REGISTER_LITE_KERNEL(conv2d, REGISTER_LITE_KERNEL(depthwise_conv2d,
kOpenCL, kOpenCL,
kFloat, kFloat,
kImageDefault, kImageDefault,
......
...@@ -74,6 +74,9 @@ class ConvImageCompute : public KernelLite<TARGET(kOpenCL), ...@@ -74,6 +74,9 @@ class ConvImageCompute : public KernelLite<TARGET(kOpenCL),
void Conv2d3x3(); void Conv2d3x3();
void Conv2d5x5(); void Conv2d5x5();
void Conv2d7x7(); void Conv2d7x7();
void DepthwiseConv2d3x3s1();
void DepthwiseConv2d3x3();
void DepthwiseConv2d();
kernel_t impl_; kernel_t impl_;
std::vector<std::string> kernel_func_names_{}; std::vector<std::string> kernel_func_names_{};
......
...@@ -166,6 +166,8 @@ void PrintData(std::string name, ...@@ -166,6 +166,8 @@ void PrintData(std::string name,
} }
} }
// buffer
#if 0
// #define PRINT_RESULT // #define PRINT_RESULT
#define LOOP_TEST #define LOOP_TEST
TEST(conv2d, compute_conv2d_1x1) { TEST(conv2d, compute_conv2d_1x1) {
...@@ -623,8 +625,9 @@ TEST(conv2d, compute_conv2d_gemm) { ...@@ -623,8 +625,9 @@ TEST(conv2d, compute_conv2d_gemm) {
} // batch_size } // batch_size
#endif #endif
} }
#endif
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
USE_LITE_KERNEL(conv2d, kOpenCL, kFloat, kNCHW, def); // USE_LITE_KERNEL(conv2d, kOpenCL, kFloat, kNCHW, def);
...@@ -559,9 +559,11 @@ TEST(conv2d, compute_image2d_3x3) { ...@@ -559,9 +559,11 @@ TEST(conv2d, compute_image2d_3x3) {
// element wise bias // element wise bias
const DDim& bias_dim = lite::DDim{std::vector<int64_t>({oc})}; const DDim& bias_dim = lite::DDim{std::vector<int64_t>({oc})};
LOG(INFO) << "input_dim:" << input_dim VLOG(2) << "input_dim:" << input_dim
<< " filter_dim:" << filter_dim << " filter_dim:" << filter_dim << " out_dim:" << out_dim
<< " out_dim:" << out_dim; << " bias_flag:" << bias_flag << " bias_dim:" << bias_dim
<< " group:" << group << " stride:" << stride
<< " pad:" << pad << " dilation:" << dilation;
param.x->Resize(input_dim); param.x->Resize(input_dim);
param.filter->Resize(filter_dim); param.filter->Resize(filter_dim);
...@@ -902,6 +904,12 @@ TEST(conv2d, compute_image2d_5x5) { ...@@ -902,6 +904,12 @@ TEST(conv2d, compute_image2d_5x5) {
// element wise bias // element wise bias
const DDim& bias_dim = lite::DDim{std::vector<int64_t>({oc})}; const DDim& bias_dim = lite::DDim{std::vector<int64_t>({oc})};
VLOG(2) << "input_dim:" << input_dim
<< " filter_dim:" << filter_dim << " out_dim:" << out_dim
<< " bias_flag:" << bias_flag << " bias_dim:" << bias_dim
<< " group:" << group << " stride:" << stride
<< " pad:" << pad << " dilation:" << dilation;
param.x->Resize(input_dim); param.x->Resize(input_dim);
param.filter->Resize(filter_dim); param.filter->Resize(filter_dim);
param.output->Resize(out_dim); param.output->Resize(out_dim);
......
...@@ -123,420 +123,6 @@ class DepthwiseConv2dCompute ...@@ -123,420 +123,6 @@ class DepthwiseConv2dCompute
std::shared_ptr<cl::Event> event_{new cl::Event}; std::shared_ptr<cl::Event> event_{new cl::Event};
}; };
class DepthwiseConv2dComputeFP16Image
: public KernelLite<TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault)> {
public:
using param_t = operators::ConvParam;
std::string doc() const override {
return "DepthwiseConv2d using cl::Image2D/kImageDefault, kFP16";
}
void PrepareForRun() override {
const auto& param = *param_.get_mutable<param_t>();
if (param.fuse_relu) {
build_options_ += " -DRELU";
} else if (param.activation_param.active_type ==
lite_api::ActivationType::kRelu6) {
build_options_ += " -DRELU6";
}
auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel(
kernel_func_name_, "image/depthwise_conv2d_kernel.cl", build_options_);
}
void Run() override {
const auto& param = *param_.get_mutable<param_t>();
auto x_dims = param.x->dims();
auto filter_dims = param.filter->dims();
auto output_dims = param.output->dims();
auto paddings = *param.paddings;
auto strides = param.strides;
auto dilations = *param.dilations;
int offset = filter_dims[2] / 2 - paddings[0];
int input_c_block = (x_dims[1] + 3) / 4;
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
auto* input_img = param.x->data<int16_t, cl::Image2D>();
auto* filter_img = param.filter->data<int16_t, cl::Image2D>();
auto* bias_img = param.bias == nullptr
? static_cast<cl::Image2D*>(nullptr)
: param.bias->data<int16_t, cl::Image2D>();
auto image_shape = InitImageDimInfoWith(output_dims);
auto* output_img = param.output->mutable_data<int16_t, cl::Image2D>(
image_shape["width"], image_shape["height"]);
STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_;
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
int c_block = (output_dims[1] + 3) / 4;
int w = output_dims[3];
int nh = output_dims[0] * output_dims[2];
auto global_work_size = cl::NDRange(c_block, w, nh);
VLOG(4) << "setArg";
VLOG(4) << "c_block = " << c_block;
VLOG(4) << "w = " << w;
VLOG(4) << "nh = " << nh;
VLOG(4) << "strides = " << strides[0];
VLOG(4) << "offset = " << offset;
VLOG(4) << "dilations = " << dilations[0];
VLOG(4) << "input_c_block = " << input_c_block;
VLOG(4) << "x_dims[3] = " << x_dims[3];
VLOG(4) << "x_dims[2] = " << x_dims[2];
VLOG(4) << "output_dims[3] = " << output_dims[3];
VLOG(4) << "output_dims[2] = " << output_dims[2];
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, static_cast<const int>(c_block));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(w));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(nh));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *filter_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *output_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(strides[0]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(offset));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(dilations[0]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(input_c_block));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(x_dims[3]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(x_dims[2]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(output_dims[3]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(output_dims[2]));
CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
event_.get());
CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(output_img, event_);
}
private:
std::string kernel_func_name_{"depth_conv2d_3x3"};
std::string build_options_{"-DCL_DTYPE_half"};
std::shared_ptr<cl::Event> event_{new cl::Event};
};
class DepthwiseConv2d3x3s1ComputeFP16Image
: public KernelLite<TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault)> {
public:
using param_t = operators::ConvParam;
std::string doc() const override {
return "DepthwiseConv2d3x3s1 using cl::Image2D/kImageDefault, kFP16";
}
void PrepareForRun() override {
const auto& param = *param_.get_mutable<param_t>();
if (param.fuse_relu) {
build_options_ += " -DRELU";
} else if (param.activation_param.active_type ==
lite_api::ActivationType::kRelu6) {
build_options_ += " -DRELU6";
}
auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel(
kernel_func_name_, "image/depthwise_conv2d_kernel.cl", build_options_);
}
void Run() override {
const auto& param = *param_.get_mutable<param_t>();
auto x_dims = param.x->dims();
auto filter_dims = param.filter->dims();
auto output_dims = param.output->dims();
auto paddings = *param.paddings;
auto strides = param.strides;
auto dilations = *param.dilations;
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
auto* input_img = param.x->data<int16_t, cl::Image2D>();
auto* filter_img = param.filter->data<int16_t, cl::Image2D>();
auto* bias_img = param.bias == nullptr
? static_cast<cl::Image2D*>(nullptr)
: param.bias->data<int16_t, cl::Image2D>();
auto image_shape = InitImageDimInfoWith(output_dims);
auto* output_img = param.output->mutable_data<int16_t, cl::Image2D>(
image_shape["width"], image_shape["height"]);
STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_;
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
int c_block = (output_dims[1] + 3) / 4;
int w = output_dims[3];
int nh = output_dims[0] * output_dims[2];
int w_blk_size = 2;
int w_blk = (w + w_blk_size - 1) / w_blk_size;
auto global_work_size = cl::NDRange(c_block, w_blk, nh);
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, static_cast<const int>(c_block));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(w_blk));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(nh));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *filter_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *output_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(strides[0]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(paddings[0]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(dilations[0]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(x_dims[1]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(x_dims[3]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(x_dims[2]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(output_dims[3]));
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, static_cast<const int>(output_dims[2]));
CL_CHECK_FATAL(status);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
event_.get());
CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(output_img, event_);
}
private:
std::string kernel_func_name_{"depth_conv2d_3x3s1"};
std::string build_options_{"-DCL_DTYPE_half"};
std::shared_ptr<cl::Event> event_{new cl::Event};
};
class DepthwiseConv2dBasicComputeFP32Image
: public KernelLite<TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageDefault)> {
public:
using param_t = operators::ConvParam;
std::string doc() const override {
return "DepthwiseConv2d basic using cl::Image2D/kImageDefault, kFloat32";
}
void PrepareForRun() override {
const auto& param = *param_.get_mutable<param_t>();
const bool has_bias = param.bias != nullptr;
const bool is_element_wise_bias =
has_bias && param.output->dims() == param.bias->dims();
if (param.fuse_relu) {
build_options_ += " -DRELU";
} else if (param.activation_param.active_type ==
lite_api::ActivationType::kRelu6) {
build_options_ += " -DRELU6";
}
if (has_bias) {
build_options_ += is_element_wise_bias ? " -DBIASE_ELE" : " -DBIASE_CH";
}
auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel(kernel_func_name_,
"image/depthwise_conv2d_basic_kernel.cl",
build_options_);
}
void Run() override {
const auto& param = *param_.get_mutable<param_t>();
auto input_dims = param.x->dims();
auto paddings = *param.paddings;
auto strides = param.strides;
auto* input_image = param.x->data<float, cl::Image2D>();
auto* filter_image = param.filter->data<float, cl::Image2D>();
auto filter_dims = param.filter->dims();
auto output_dims = param.output->dims();
int input_width = input_dims[3];
int input_height = input_dims[2];
int output_width = output_dims[3];
int output_height = output_dims[2];
int filter_width = filter_dims[3];
int filter_height = filter_dims[2];
auto out_image_shape = InitImageDimInfoWith(output_dims);
auto* out_image = param.output->mutable_data<float, cl::Image2D>(
out_image_shape["width"], out_image_shape["height"]);
const bool has_bias = param.bias != nullptr;
const bool is_element_wise_bias =
has_bias && param.output->dims() == param.bias->dims();
int offset = static_cast<int>(param.filter->dims()[2]) / 2 -
static_cast<int>(paddings[0]);
// calc input_c_block
auto input_image_shape = InitImageDimInfoWith(input_dims);
int input_c_block = input_image_shape["width"] / input_dims[3];
int input_c = input_dims[1];
auto dilations = *param.dilations;
const std::vector<size_t>& default_work_size =
DefaultWorkSize(output_dims,
DDim(std::vector<DDim::value_type>{
static_cast<int64_t>(out_image_shape["width"]),
static_cast<int64_t>(out_image_shape["height"])}));
int c_block = default_work_size[0];
int w = default_work_size[1];
int nh = default_work_size[2];
VLOG(4) << "============ depthwise conv2d params ============";
VLOG(4) << "input_image_shape: " << input_image_shape["width"] << ","
<< input_image_shape["height"];
VLOG(4) << "input_c_block: " << input_c_block;
VLOG(4) << "input_c: " << input_c;
VLOG(4) << "input_image: " << input_image;
VLOG(4) << "filter_dims: " << filter_dims;
VLOG(4) << "filter_image: " << filter_image;
VLOG(4) << "output_dims: " << output_dims;
VLOG(4) << "out_image_shape: " << out_image_shape["width"] << ", "
<< out_image_shape["height"];
VLOG(4) << "paddings: " << paddings[0] << "," << paddings[1];
VLOG(4) << "has bias: " << has_bias;
VLOG(4) << "is_element_wise_bias : " << is_element_wise_bias;
VLOG(4) << "strides: " << strides[0] << "," << strides[1];
VLOG(4) << "offset: " << offset;
VLOG(4) << "dilations.size : " << dilations.size();
VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1];
VLOG(4) << "default work size{c_block, w, nh}: "
<< "{" << c_block << ", " << w << ", " << nh << ""
<< "}";
CHECK_GE(dilations.size(), 2);
CHECK(dilations[0] == dilations[1]);
CHECK_GE(input_dims.size(), 4);
CHECK_GE(paddings.size(), 2);
CHECK(paddings[0] == paddings[1]);
CHECK_GE(strides.size(), 2);
CHECK(strides[0] == strides[1]);
// Bias handling: a channel-wise bias could be passed as a cl::Buffer and an
// element-wise bias as a cl::Image2D; this kernel binds the bias as an
// image in both cases, so bias_buf stays unused here.
const cl::Buffer* bias_buf = nullptr;
const cl::Image2D* bias_image = nullptr;
if (has_bias) {
bias_image = param.bias->data<float, cl::Image2D>();
}
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_;
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
VLOG(4) << "kernel_key: " << kernel_key.str();
VLOG(4) << "kernel ready ... " << kernel_key.str();
VLOG(4) << "w: " << w;
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, c_block);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, w);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, nh);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *input_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *filter_image);
CL_CHECK_FATAL(status);
if (has_bias) {
VLOG(4) << "set bias_image: ";
status = kernel.setArg(++arg_idx, *bias_image);
CL_CHECK_FATAL(status);
}
status = kernel.setArg(++arg_idx, *out_image);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, strides[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, offset);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_c_block);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, dilations[0]);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, input_height);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, output_height);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, filter_width);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, filter_height);
CL_CHECK_FATAL(status);
auto global_work_size =
cl::NDRange{static_cast<size_t>(default_work_size.data()[0]),
static_cast<size_t>(default_work_size.data()[1]),
static_cast<size_t>(default_work_size.data()[2])};
VLOG(4) << "out_image: " << out_image;
VLOG(4) << "global_work_size[3D]: {" << global_work_size[0] << ","
<< global_work_size[1] << "," << global_work_size[2] << "}";
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
event_.get());
CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_image, event_);
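// enqueueNDRangeKernel is asynchronous; recording the event against
// out_image in the wait list lets a downstream consumer of this output
// block on completion only when it actually reads the image.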
}
private:
std::string kernel_func_name_{"depth_conv2d"};
std::string build_options_{"-DCL_DTYPE_float"};
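// -DCL_DTYPE_float selects the fp32 variant of the generic kernel source;
// an fp16 build would presumably pass -DCL_DTYPE_half instead (assumed
// counterpart, not shown in this patch).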
std::shared_ptr<cl::Event> event_{new cl::Event};
};
} // namespace opencl
} // namespace kernels
} // namespace lite
...@@ -553,52 +139,3 @@ REGISTER_LITE_KERNEL(depthwise_conv2d,
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kOpenCL))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kOpenCL))})
.Finalize();
REGISTER_LITE_KERNEL(
depthwise_conv2d,
kOpenCL,
kFP16,
kImageDefault,
paddle::lite::kernels::opencl::DepthwiseConv2dComputeFP16Image,
image2d)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault))})
.BindInput("Bias",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault))})
.BindInput("Filter",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageNW))})
.BindOutput("Output",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault))})
.Finalize();
REGISTER_LITE_KERNEL(
depthwise_conv2d_basic,
kOpenCL,
kFloat,
kImageDefault,
paddle::lite::kernels::opencl::DepthwiseConv2dBasicComputeFP32Image,
image2d)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageDefault))})
.BindInput("Bias",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageDefault))})
.BindInput("Filter",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageNW))})
.BindOutput("Output",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageDefault))})
.Finalize();
...@@ -177,135 +177,7 @@ TEST(depthwise_conv2d_buffer_fp32, compute) {
TargetWrapperCL::Unmap(input_data, mapped_input);
}
TEST(depthwise_conv2d_image2d_fp16, compute) {
LOG(INFO) << "to get kernel ...";
auto kernels = KernelRegistry::Global().Create("depthwise_conv2d",
TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault));
ASSERT_FALSE(kernels.empty());
auto kernel = std::move(kernels.front());
LOG(INFO) << "get kernel";
lite::Tensor input, filter, output;
operators::ConvParam param;
param.x = &input;
param.filter = &filter;
param.output = &output;
std::vector<int> paddings = {0, 0};
param.paddings = std::make_shared<std::vector<int>>(paddings);
param.strides = std::vector<int>{1, 1};
std::vector<int> dilations = {1, 1};
param.dilations = std::make_shared<std::vector<int>>(dilations);
std::unique_ptr<KernelContext> context(new KernelContext);
context->As<OpenCLContext>().InitOnce();
kernel->SetParam(param);
std::unique_ptr<KernelContext> dep_context(new KernelContext);
context->As<OpenCLContext>().CopySharedTo(
&(dep_context->As<OpenCLContext>()));
kernel->SetContext(std::move(dep_context));
LOG(INFO) << "kernel ready";
std::default_random_engine engine;
std::uniform_real_distribution<float> gen(-5, 5);
std::vector<float> input_v(1 * 32 * 112 * 112);
std::vector<float> filter_v(32 * 1 * 3 * 3);
for (auto& i : input_v) {
i = gen(engine);
}
for (auto& f : filter_v) {
f = gen(engine);
}
LOG(INFO) << "prepare input";
input.Resize({1, 32, 112, 112});
CLImageConverterDefault* default_converter = new CLImageConverterDefault();
DDim input_image_shape =
default_converter->InitImageDimInfoWith(input.dims());
LOG(INFO) << "input_image_shape = " << input_image_shape[0] << " "
<< input_image_shape[1];
std::vector<float> input_image_data(input_image_shape.production() *
4); // 4 : RGBA
default_converter->NCHWToImage(
input_v.data(), input_image_data.data(), input.dims());
auto* input_image = input.mutable_data<int16_t, cl::Image2D>(
input_image_shape[0], input_image_shape[1], input_image_data.data());
LOG(INFO) << "prepare kernel";
filter.Resize({32, 1, 3, 3});
CLImageConverterNWBlock* nw_converter = new CLImageConverterNWBlock();
DDim filter_image_shape = nw_converter->InitImageDimInfoWith(filter.dims());
LOG(INFO) << "filter_image_shape = " << filter_image_shape[0] << " "
<< filter_image_shape[1];
std::vector<float> filter_image_data(filter_image_shape.production() *
4); // 4 : RGBA
nw_converter->NCHWToImage(
filter_v.data(), filter_image_data.data(), filter.dims());
auto* filter_image = filter.mutable_data<int16_t, cl::Image2D>(
filter_image_shape[0], filter_image_shape[1], filter_image_data.data());
LOG(INFO) << "launch";
output.Resize({1, 32, 110, 110});
DDim output_image_shape =
default_converter->InitImageDimInfoWith(output.dims());
LOG(INFO) << "output_image_shape = " << output_image_shape[0] << " "
<< output_image_shape[1];
auto* output_image = output.mutable_data<int16_t, cl::Image2D>(
output_image_shape[0], output_image_shape[1]);
kernel->Launch();
auto* wait_list = context->As<OpenCLContext>().cl_wait_list();
auto* out_ptr = param.output->data<int16_t, cl::Image2D>();
auto it = wait_list->find(out_ptr);
if (it != wait_list->end()) {
VLOG(4) << "--- Find the sync event for the target cl tensor. ---";
LOG(INFO) << "--- Find the sync event for the target cl tensor. ---";
auto& event = *(it->second);
event.wait();
} else {
LOG(FATAL) << "Could not find the sync event for the target cl tensor.";
LOG(INFO) << "Could not find the sync event for the target cl tensor.";
}
lite::Tensor output_ref;
output_ref.Resize({1, 32, 110, 110});
auto* output_ref_data = output_ref.mutable_data<float>(TARGET(kARM));
depth_conv<float, 1, 1>(input_v.data(),
input.dims(),
filter_v.data(),
filter.dims(),
output_ref_data,
output_ref.dims());
const size_t cl_image2d_row_pitch{0};
const size_t cl_image2d_slice_pitch{0};
float* output_image_data = new float[output_image_shape.production() * 4];
TargetWrapperCL::ImgcpySync(output_image_data,
output_image,
output_image_shape[0],
output_image_shape[1],
cl_image2d_row_pitch,
cl_image2d_slice_pitch,
IoDirection::DtoH);
float* output_data = new float[output_image_shape.production() * 4];
default_converter->ImageToNCHW(
output_image_data, output_data, output_image_shape, output.dims());
LOG(INFO) << "output_data vs output_ref_data";
for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-4);
LOG(INFO) << output_data[i] << " " << output_ref_data[i];
}
}
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(depthwise_conv2d, kOpenCL, kFloat, kNCHW, def);
USE_LITE_KERNEL(depthwise_conv2d, kOpenCL, kFP16, kImageDefault, image2d);
...@@ -142,7 +142,7 @@ TEST(depthwise_conv2d_basic, compute) {
VLOG(4) << "to get kernel ...";
auto kernels =
KernelRegistry::Global().Create("depthwise_conv2d",
TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageDefault));
...@@ -383,7 +383,133 @@ TEST(depthwise_conv2d_basic, compute) {
#endif
}
TEST(depthwise_conv2d_image2d_fp16, compute) {
LOG(INFO) << "to get kernel ...";
auto kernels = KernelRegistry::Global().Create("depthwise_conv2d",
TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageDefault));
ASSERT_FALSE(kernels.empty());
auto kernel = std::move(kernels.front());
LOG(INFO) << "get kernel";
lite::Tensor input, filter, output;
operators::ConvParam param;
param.x = &input;
param.filter = &filter;
param.output = &output;
std::vector<int> paddings = {0, 0};
param.paddings = std::make_shared<std::vector<int>>(paddings);
param.strides = std::vector<int>{1, 1};
std::vector<int> dilations = {1, 1};
param.dilations = std::make_shared<std::vector<int>>(dilations);
std::unique_ptr<KernelContext> context(new KernelContext);
context->As<OpenCLContext>().InitOnce();
kernel->SetParam(param);
std::unique_ptr<KernelContext> dep_context(new KernelContext);
context->As<OpenCLContext>().CopySharedTo(
&(dep_context->As<OpenCLContext>()));
kernel->SetContext(std::move(dep_context));
LOG(INFO) << "kernel ready";
std::default_random_engine engine;
std::uniform_real_distribution<float> gen(-5, 5);
std::vector<float> input_v(1 * 32 * 112 * 112);
std::vector<float> filter_v(32 * 1 * 3 * 3);
for (auto& i : input_v) {
i = gen(engine);
}
for (auto& f : filter_v) {
f = gen(engine);
}
LOG(INFO) << "prepare input";
input.Resize({1, 32, 112, 112});
CLImageConverterDefault* default_converter = new CLImageConverterDefault();
DDim input_image_shape =
default_converter->InitImageDimInfoWith(input.dims());
LOG(INFO) << "input_image_shape = " << input_image_shape[0] << " "
<< input_image_shape[1];
std::vector<float> input_image_data(input_image_shape.production() *
4); // 4 : RGBA
default_converter->NCHWToImage(
input_v.data(), input_image_data.data(), input.dims());
auto* input_image = input.mutable_data<int16_t, cl::Image2D>(
input_image_shape[0], input_image_shape[1], input_image_data.data());
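// Host data is fp32 but the image is allocated with int16_t elements,
// i.e. half-precision device storage. CLImageConverterDefault is assumed
// to lay NCHW out as a (ceil(C/4) * W) x (N * H) RGBA image, so
// {1, 32, 112, 112} maps to 896 x 112.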
LOG(INFO) << "prepare kernel";
filter.Resize({32, 1, 3, 3});
CLImageConverterNWBlock* nw_converter = new CLImageConverterNWBlock();
DDim filter_image_shape = nw_converter->InitImageDimInfoWith(filter.dims());
LOG(INFO) << "filter_image_shape = " << filter_image_shape[0] << " "
<< filter_image_shape[1];
std::vector<float> filter_image_data(filter_image_shape.production() *
4); // 4 : RGBA
nw_converter->NCHWToImage(
filter_v.data(), filter_image_data.data(), filter.dims());
auto* filter_image = filter.mutable_data<int16_t, cl::Image2D>(
filter_image_shape[0], filter_image_shape[1], filter_image_data.data());
LOG(INFO) << "launch";
output.Resize({1, 32, 110, 110});
DDim output_image_shape =
default_converter->InitImageDimInfoWith(output.dims());
LOG(INFO) << "output_image_shape = " << output_image_shape[0] << " "
<< output_image_shape[1];
auto* output_image = output.mutable_data<int16_t, cl::Image2D>(
output_image_shape[0], output_image_shape[1]);
kernel->Launch();
auto* wait_list = context->As<OpenCLContext>().cl_wait_list();
auto* out_ptr = param.output->data<int16_t, cl::Image2D>();
auto it = wait_list->find(out_ptr);
if (it != wait_list->end()) {
VLOG(4) << "--- Find the sync event for the target cl tensor. ---";
LOG(INFO) << "--- Find the sync event for the target cl tensor. ---";
auto& event = *(it->second);
event.wait();
} else {
// LOG(FATAL) aborts, so nothing after it in this branch would run.
LOG(FATAL) << "Could not find the sync event for the target cl tensor.";
}
lite::Tensor output_ref;
output_ref.Resize({1, 32, 110, 110});
auto* output_ref_data = output_ref.mutable_data<float>(TARGET(kARM));
depth_conv<float, 1, 1>(input_v.data(),
input.dims(),
filter_v.data(),
filter.dims(),
output_ref_data,
output_ref.dims());
const size_t cl_image2d_row_pitch{0};
const size_t cl_image2d_slice_pitch{0};
float* output_image_data = new float[output_image_shape.production() * 4];
TargetWrapperCL::ImgcpySync(output_image_data,
output_image,
output_image_shape[0],
output_image_shape[1],
cl_image2d_row_pitch,
cl_image2d_slice_pitch,
IoDirection::DtoH);
float* output_data = new float[output_image_shape.production() * 4];
default_converter->ImageToNCHW(
output_image_data, output_data, output_image_shape, output.dims());
LOG(INFO) << "output_data vs output_ref_data";
for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-4);
LOG(INFO) << output_data[i] << " " << output_ref_data[i];
}
}
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(depthwise_conv2d, kOpenCL, kFloat, kImageDefault, image2d);
...@@ -66,6 +66,8 @@ void PrintData(std::string name, float* a, const int rows, const int cols) {
}
}
// buffer
#if 0 // fc_buffer
// #define PRINT_RESULT
#define LOOP_TEST
TEST(fc, compute) {
...@@ -193,8 +195,9 @@ TEST(fc, compute) {
} // m
#endif
}
#endif // fc_buffer
} // namespace lite
} // namespace paddle
// USE_LITE_KERNEL(fc, kOpenCL, kFloat, kNCHW, def);
...@@ -229,15 +229,15 @@ class PoolComputeImage2D : public KernelLite<TARGET(kOpenCL),
} // namespace lite
} // namespace paddle
// REGISTER_LITE_KERNEL(pool2d,
// kOpenCL,
// kFloat,
// kNCHW,
// paddle::lite::kernels::opencl::PoolCompute,
// def)
// .BindInput("X", {LiteType::GetTensorTy(TARGET(kOpenCL))})
// .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kOpenCL))})
// .Finalize();
REGISTER_LITE_KERNEL(pool2d,
kOpenCL,
...
...@@ -73,6 +73,8 @@ void pool_avg(const int padding_height,
}
}
// buffer
#if 0 // pool_buffer
TEST(pool2d_buffer_fp32, compute) {
LOG(INFO) << "to get kernel ...";
auto kernels = KernelRegistry::Global().Create(
...@@ -141,6 +143,7 @@ TEST(pool2d_buffer_fp32, compute) {
}
TargetWrapperCL::Unmap(out_data, mapped_out);
}
#endif // pool_buffer
TEST(pool2d_image2d_fp32, compute) {
LOG(INFO) << "to get kernel ...";
...@@ -239,5 +242,5 @@ TEST(pool2d_image2d_fp32, compute) {
} // namespace lite
} // namespace paddle
// USE_LITE_KERNEL(pool2d, kOpenCL, kFloat, kNCHW, def);
USE_LITE_KERNEL(pool2d, kOpenCL, kFloat, kImageDefault, image2d);
...@@ -227,15 +227,15 @@ class SigmoidComputeFP16ImageDefault
} // namespace lite
} // namespace paddle
// REGISTER_LITE_KERNEL(sigmoid,
// kOpenCL,
// kFloat,
// kNCHW,
// paddle::lite::kernels::opencl::SigmoidCompute,
// def)
// .BindInput("X", {LiteType::GetTensorTy(TARGET(kOpenCL))})
// .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kOpenCL))})
// .Finalize();
REGISTER_LITE_KERNEL(
sigmoid,
...
...@@ -32,7 +32,8 @@ void sigmoid_compute_ref(const dtype *x_data,
}
}
// buffer
#if 0 // sigmoid_buffer
TEST(opencl_sigmoid_buffer, compute) {
// prepare data
const DDim x_dim = DDim(std::vector<DDim::value_type>{3, 6, 10, 10});
...@@ -414,7 +415,7 @@ TEST(sigmoid_image2d_fp16, compute) {
} // namespace paddle
// sigmoid buffer
// USE_LITE_KERNEL(sigmoid, kOpenCL, kFloat, kNCHW, def);
// sigmoid image2d fp32
USE_LITE_KERNEL(layout, kOpenCL, kAny, kImageDefault, NCHW_to_ImageDefault);
...