提交 b4c5fdb8 编写于 作者: X xiebaiyuan 提交者: GitHub

[LITE][OPENCL][Image] conv 1x1 5x5 7x7 suite (#2998)

* [LITE][OPENCL][Image] conv 1x1 5x5 7x7 suite

# Conflicts:
#	lite/kernels/opencl/conv_image_compute_test.cc

* [LITE][OPENCL][Image] conv 1x1 5x5 7x7 suite,test=develop

* [LITE][OPENCL][Image] conv 1x1 5x5 7x7 suite,test=develop

* [LITE][OPENCL][Image] conv 1x1 5x5 7x7 suite,test=develop

* [LITE][OPENCL][Image] conv 1x1 5x5 7x7 suite,rm1x1 old,test=develop
上级 2c229275
......@@ -15,7 +15,7 @@ __kernel void conv2d_1x1(__private const int global_size_dim0,
__write_only image2d_t output_image,
__private const int stride,
__private const int offset,
__private const int input_c,
__private const int input_c_block,
__private const int input_c_origin,
__private const int dilation,
__private const int input_width, /* of one block */
......@@ -79,14 +79,14 @@ __kernel void conv2d_1x1(__private const int global_size_dim0,
CL_DTYPE4 output3 = 0.0f;
#endif
int max_w_bound = input_c * input_width;
int burndary_index = input_c * 4 - input_c_origin;
int max_w_bound = input_c_block * input_width;
int burndary_index = input_c_block * 4 - input_c_origin;
bool burndary_index_w =
burndary_index == 1 || burndary_index == 2 || burndary_index == 3;
bool burndary_index_z = burndary_index == 2 || burndary_index == 3;
bool burndary_index_y = burndary_index == 3;
for (int i = 0; i < input_c; ++i) {
for (int i = 0; i < input_c_block; ++i) {
// ------------0---------------
int2 pos_in = (int2)(i * input_width + in_pos_in_one_block0.x,
in_pos_in_one_block0.y);
......@@ -107,11 +107,81 @@ __kernel void conv2d_1x1(__private const int global_size_dim0,
input0.w = select(input0.w, zero, outof_bound && burndary_index_w);
input0.z = select(input0.z, zero, outof_bound && burndary_index_z);
input0.y = select(input0.y, zero, outof_bound && burndary_index_y);
#ifdef DEBUG
if (output_pos0.x == 0 && output_pos0.y == 0) {
printf("i ={ %d, }\n", i);
printf("in={ %f , %f , %f , %f } \n",
convert_float(input0.x),
convert_float(input0.y),
convert_float(input0.z),
convert_float(input0.w));
printf("filter0={ %f , %f , %f , %f } \n",
convert_float(weight0.x),
convert_float(weight0.y),
convert_float(weight0.z),
convert_float(weight0.w));
printf("filter1={ %f , %f , %f , %f } \n",
convert_float(weight1.x),
convert_float(weight1.y),
convert_float(weight1.z),
convert_float(weight1.w));
printf("filter2={ %f , %f , %f , %f } \n",
convert_float(weight2.x),
convert_float(weight2.y),
convert_float(weight2.z),
convert_float(weight2.w));
printf("filter3={ %f , %f , %f , %f } \n",
convert_float(weight3.x),
convert_float(weight3.y),
convert_float(weight3.z),
convert_float(weight3.w));
printf("000---- output={ %f , %f , %f , %f } \n",
convert_float(output0.x),
convert_float(output0.y),
convert_float(output0.z),
convert_float(output0.w));
}
#endif
output0 = mad(input0.x, weight0, output0);
#ifdef DEBUG
if (output_pos0.x == 0 && output_pos0.y == 0) {
printf("111---- output={ %f , %f , %f , %f } \n",
convert_float(output0.x),
convert_float(output0.y),
convert_float(output0.z),
convert_float(output0.w));
}
#endif
output0 = mad(input0.y, weight1, output0);
#ifdef DEBUG
if (output_pos0.x == 0 && output_pos0.y == 0) {
printf("222---- output={ %f , %f , %f , %f } \n",
convert_float(output0.x),
convert_float(output0.y),
convert_float(output0.z),
convert_float(output0.w));
}
#endif
output0 = mad(input0.z, weight2, output0);
#ifdef DEBUG
if (output_pos0.x == 0 && output_pos0.y == 0) {
printf("333---- output={ %f , %f , %f , %f } \n",
convert_float(output0.x),
convert_float(output0.y),
convert_float(output0.z),
convert_float(output0.w));
}
#endif
output0 = mad(input0.w, weight3, output0);
#ifdef DEBUG
if (output_pos0.x == 0 && output_pos0.y == 0) {
printf("444---- output={ %f , %f , %f , %f } \n",
convert_float(output0.x),
convert_float(output0.y),
convert_float(output0.z),
convert_float(output0.w));
}
#endif
// -------------1--------------
pos_in = (int2)(i * input_width + in_pos_in_one_block1.x,
in_pos_in_one_block1.y);
......@@ -171,6 +241,43 @@ __kernel void conv2d_1x1(__private const int global_size_dim0,
output3 = mad(input3.y, weight1, output3);
output3 = mad(input3.z, weight2, output3);
output3 = mad(input3.w, weight3, output3);
#ifdef DEBUG
if (output_pos0.x == 0 && output_pos0.y == 0) {
// printf("i,j,k ={ %d, %d , %d }\n", i,j,k);
printf("i ={ %d, }\n", i);
printf("in={ %f , %f , %f , %f } \n",
convert_float(input0.x),
convert_float(input0.y),
convert_float(input0.z),
convert_float(input0.w));
printf("filter0={ %f , %f , %f , %f } \n",
convert_float(weight0.x),
convert_float(weight0.y),
convert_float(weight0.z),
convert_float(weight0.w));
printf("filter1={ %f , %f , %f , %f } \n",
convert_float(weight1.x),
convert_float(weight1.y),
convert_float(weight1.z),
convert_float(weight1.w));
printf("filter2={ %f , %f , %f , %f } \n",
convert_float(weight2.x),
convert_float(weight2.y),
convert_float(weight2.z),
convert_float(weight2.w));
printf("filter3={ %f , %f , %f , %f } \n",
convert_float(weight3.x),
convert_float(weight3.y),
convert_float(weight3.z),
convert_float(weight3.w));
printf("output={ %f , %f , %f , %f } \n",
convert_float(output0.x),
convert_float(output0.y),
convert_float(output0.z),
convert_float(output0.w));
}
#endif
}
#ifdef BATCH_NORM
......@@ -195,7 +302,6 @@ __kernel void conv2d_1x1(__private const int global_size_dim0,
output1 = activation_type4(output1);
output2 = activation_type4(output2);
output3 = activation_type4(output3);
if (out_w0 < old_w) {
WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, output_pos0, output0);
}
......@@ -213,29 +319,30 @@ __kernel void conv2d_1x1(__private const int global_size_dim0,
}
}
__kernel void conv2d_1x1_simple(__private const int global_size_dim0,
__private const int global_size_dim1,
__private const int global_size_dim2,
__read_only image2d_t input_image,
__read_only image2d_t filter,
__kernel void conv2d_1x1_simple(
__private const int global_size_dim0,
__private const int global_size_dim1,
__private const int global_size_dim2,
__read_only image2d_t input_image,
__read_only image2d_t filter,
#if defined(BIASE_CH) || defined(BIASE_ELE)
__read_only image2d_t bias,
#endif
#ifdef BATCH_NORM
__read_only image2d_t new_scale,
__read_only image2d_t new_biase,
__read_only image2d_t new_scale,
__read_only image2d_t new_biase,
#endif
__write_only image2d_t output_image,
__private const int stride,
__private const int offset,
__private const int input_c,
__private const int input_c_origin,
__private const int dilation,
__private const int input_width, /* of one block */
__private const int input_height, /* of one block */
__private const int output_width,
__private const int output_height,
__private const int old_w) {
__write_only image2d_t output_image,
__private const int stride,
__private const int offset,
__private const int input_c,
__private const int input_c_origin,
__private const int dilation,
__private const int input_width, /* of one block */
__private const int input_height, /* of one block */
__private const int output_width,
__private const int output_height,
__private const int old_w) {
const int out_c = get_global_id(0);
const int out_w = get_global_id(1);
const int out_nh = get_global_id(2);
......@@ -358,13 +465,11 @@ __read_only image2d_t new_scale,
READ_IMG_TYPE(CL_DTYPE_CHAR, new_biase, sampler, (int2)(out_c, 0));
#endif
output0 = activation_type4(output0);
output1 = activation_type4(output1);
output2 = activation_type4(output2);
output3 = activation_type4(output3);
if (out_w0 < old_w) {
WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, output_pos0, output0);
}
......
......@@ -36,10 +36,10 @@ __kernel void conv2d_7x7(__private const int global_size_dim0,
const int batch_index = out_nh / output_height;
const int out_nh_in_one_batch = out_nh % output_height;
const filter_n0 = 4 * out_c + 0;
const filter_n1 = 4 * out_c + 1;
const filter_n2 = 4 * out_c + 2;
const filter_n3 = 4 * out_c + 3;
const int filter_n0 = 4 * out_c + 0;
const int filter_n1 = 4 * out_c + 1;
const int filter_n2 = 4 * out_c + 2;
const int filter_n3 = 4 * out_c + 3;
int2 stride_xy;
stride_xy.x = stride;
......
......@@ -15,6 +15,7 @@
#include <gtest/gtest.h>
#include <random>
#include "lite/backends/opencl/cl_image_converter.h"
#include "lite/backends/opencl/target_wrapper.h"
#include "lite/core/op_registry.h"
......@@ -26,7 +27,12 @@ namespace lite {
// #define SHADOW_LOG LOG(INFO)
#define SHADOW_LOG VLOG(4)
#define FP16_MAX_DIFF (1e0)
#define FP16_ABS_DIFF (1e-1)
// #define TEST_CONV_IMAGE_ALL_1
#define TEST_CONV_IMAGE_1x1
#define TEST_CONV_IMAGE_3x3
#define TEST_CONV_IMAGE_5x5
#define TEST_CONV_IMAGE_7x7
template <typename Dtype1, typename Dtype2>
static void conv_basic(const Dtype1* din,
Dtype2* dout,
......@@ -127,6 +133,8 @@ int ConvOutputSize(int input_size,
return output_size;
}
#ifdef TEST_CONV_IMAGE_1x1
// #define PRINT_RESULT
// #define LOOP_TEST
TEST(conv2d, compute_image2d_1x1) {
......@@ -140,304 +148,341 @@ TEST(conv2d, compute_image2d_1x1) {
#ifdef LOOP_TEST
for (int batch_size = 1; batch_size < 4; ++batch_size) {
for (int oc = 4; oc < 10; oc += 1) { // oc
for (int ih = 4; ih < 9; ih += 1) { // ih
for (int oc = 2; oc < 10; oc += 1) { // oc
for (int ih = 2; ih < 9; ih += 1) { // ih
int iw = ih;
for (int iw = 4; iw < 10; iw += 1) { // iw
for (int ic = 4; ic < 10; ic += 1) { // ic
for (bool bias_flag : {true, false}) {
for (std::string relu_flag : {"relu"}) {
for (int ic = 2; ic < 10; ic += 1) { // ic
for (bool bias_flag : {true, false}) {
for (std::string relu_flag : {""}) {
#else
const int batch_size = 1;
const int oc = 4;
const int ih = 8;
const int iw = 8;
const int ic = 4;
const bool bias_flag = true;
const std::string relu_flag = "relu";
const int oc = 2;
const int ih = 3;
const int iw = 3;
const int ic = 2;
const bool bias_flag = false;
const std::string relu_flag = "";
#endif
const int oh = ih;
const int ow = iw;
SHADOW_LOG << "to get kernel ...";
auto kernels =
KernelRegistry::Global().Create("conv2d",
TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault));
ASSERT_FALSE(kernels.empty());
auto kernel = std::move(kernels.front());
SHADOW_LOG << "created conv2d_1x1 kernel";
SHADOW_LOG << "prepare kernel ------";
lite::Tensor input, filter, bias, output;
operators::ConvParam param;
param.x = &input;
param.filter = &filter;
param.output = &output;
if (bias_flag) {
param.bias = &bias;
}
if (relu_flag == "relu") {
param.fuse_relu = true;
} else if (relu_flag == "None") {
param.fuse_relu = false;
} else if (relu_flag == "relu6") {
param.activation_param.Relu_clipped_coef = 6.f;
param.activation_param.has_active = true;
param.activation_param.active_type =
lite_api::ActivationType::kRelu6;
}
LOG(INFO) << "---------------------------- "
"conv1x1----------------------- "
"run---------------";
const int oh = ih;
const int ow = iw;
LOG(INFO) << "batch_size: " << batch_size;
LOG(INFO) << "ic: " << ic;
LOG(INFO) << "ih: " << ih;
LOG(INFO) << "iw: " << iw;
LOG(INFO) << "oc: " << oc;
LOG(INFO) << "bias_flag: " << bias_flag;
LOG(INFO) << "relu_flag: " << relu_flag;
std::vector<int> paddings = {pad, pad, pad, pad};
std::vector<int> dilations = {dilation, dilation};
param.paddings = std::make_shared<std::vector<int>>(paddings);
param.dilations = std::make_shared<std::vector<int>>(dilations);
param.strides = std::vector<int>{stride, stride};
std::unique_ptr<KernelContext> context(new KernelContext);
context->As<OpenCLContext>().InitOnce();
std::unique_ptr<KernelContext> conv_1x1_context(
new KernelContext);
context->As<OpenCLContext>().CopySharedTo(
&(conv_1x1_context->As<OpenCLContext>()));
kernel->SetContext(std::move(conv_1x1_context));
const DDim& input_dim =
lite::DDim{std::vector<int64_t>({batch_size, ic, ih, iw})};
const DDim& filter_dim =
lite::DDim{std::vector<int64_t>({oc, ic, ksize, ksize})};
const DDim& out_dim =
lite::DDim{std::vector<int64_t>({batch_size, oc, ih, iw})};
// element wise bias
const DDim& bias_dim = lite::DDim{std::vector<int64_t>({oc})};
param.x->Resize(input_dim);
param.filter->Resize(filter_dim);
param.output->Resize(out_dim);
if (bias_flag) {
param.bias->Resize(bias_dim);
}
SHADOW_LOG << "to get kernel ...";
auto kernels =
KernelRegistry::Global().Create("conv2d",
TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault));
ASSERT_FALSE(kernels.empty());
auto kernel = std::move(kernels.front());
SHADOW_LOG << "created conv2d_1x1 kernel";
kernel->SetParam(param);
SHADOW_LOG << "prepare kernel ------";
lite::Tensor input, filter, bias, output;
operators::ConvParam param;
param.x = &input;
param.filter = &filter;
param.output = &output;
if (bias_flag) {
param.bias = &bias;
}
if (relu_flag == "relu") {
param.fuse_relu = true;
} else if (relu_flag == "None") {
param.fuse_relu = false;
} else if (relu_flag == "relu6") {
param.activation_param.Relu_clipped_coef = 6.f;
param.activation_param.has_active = true;
param.activation_param.active_type =
lite_api::ActivationType::kRelu6;
}
size_t input_image_width = iw * ((ic + 3) / 4);
size_t input_image_height = ih * batch_size;
std::vector<int> paddings = {pad, pad, pad, pad};
std::vector<int> dilations = {dilation, dilation};
size_t out_image_width = ow * ((oc + 3) / 4);
size_t out_image_height = oh * batch_size;
param.paddings = std::make_shared<std::vector<int>>(paddings);
param.dilations = std::make_shared<std::vector<int>>(dilations);
param.strides = std::vector<int>{stride, stride};
size_t bias_image_width = ow * ((oc + 3) / 4);
size_t bias_image_height = oh * batch_size;
std::unique_ptr<KernelContext> context(new KernelContext);
context->As<OpenCLContext>().InitOnce();
size_t filter_image_width = ksize * ((oc + 3) / 4);
size_t filter_image_height = ic * ksize;
std::unique_ptr<KernelContext> conv_1x1_context(
new KernelContext);
context->As<OpenCLContext>().CopySharedTo(
&(conv_1x1_context->As<OpenCLContext>()));
kernel->SetContext(std::move(conv_1x1_context));
const size_t cl_image2d_row_pitch{0};
const size_t cl_image2d_slice_pitch{0};
const DDim& input_dim =
lite::DDim{std::vector<int64_t>({batch_size, ic, ih, iw})};
std::default_random_engine engine;
std::uniform_real_distribution<float> gen(-5, 5);
const DDim& filter_dim =
lite::DDim{std::vector<int64_t>({oc, ic, ksize, ksize})};
const DDim& out_dim =
lite::DDim{std::vector<int64_t>({batch_size, oc, ih, iw})};
// element wise bias
const DDim& bias_dim = lite::DDim{std::vector<int64_t>({oc})};
std::vector<float> input_v(batch_size * ic * ih * iw);
std::vector<float> filter_v(oc * ic * ksize * ksize);
std::vector<float> output_v(batch_size * oc * ih * iw);
std::vector<float> bias_v(oc);
param.x->Resize(input_dim);
param.filter->Resize(filter_dim);
param.output->Resize(out_dim);
if (bias_flag) {
param.bias->Resize(bias_dim);
}
SHADOW_LOG << "gen input and filter ...";
kernel->SetParam(param);
for (auto& i : input_v) {
i = gen(engine);
}
for (auto& f : filter_v) {
f = gen(engine);
}
size_t input_image_width = iw * ((ic + 3) / 4);
size_t input_image_height = ih * batch_size;
SHADOW_LOG << "after gen input and filter ...";
SHADOW_LOG << "input_v.size(): " << input_v.size();
SHADOW_LOG << "filter_v.size(): " << filter_v.size();
SHADOW_LOG << "output_v.size(): " << output_v.size();
SHADOW_LOG << "bias_v.size(): " << bias_v.size();
SHADOW_LOG << "input_dim.production(): "
<< input_dim.production();
SHADOW_LOG << "filter_dim.production(): "
<< filter_dim.production();
SHADOW_LOG << "out_dim.production(): " << out_dim.production();
SHADOW_LOG << "bias_dim.production(): "
<< bias_dim.production();
SHADOW_LOG << "4 * input_image_height * input_image_width: "
<< 4 * input_image_height * input_image_width;
SHADOW_LOG << "4 * filter_image_width * filter_image_height: "
<< 4 * filter_image_width * filter_image_height;
CHECK(input_dim.production() == input_v.size());
CHECK_LE(input_dim.production(),
4 * input_image_height * input_image_width);
CHECK(filter_dim.production() == filter_v.size());
CHECK_LE(filter_dim.production(),
4 * filter_image_width * filter_image_height);
paddle::lite::CLImageConverterDefault default_convertor;
SHADOW_LOG << "set mapped input ...";
std::vector<half_t> x_image_v(
input_image_width * input_image_height * 4); // 4 : RGBA
std::vector<half_t> filter_image_v(
filter_image_width * filter_image_height * 4); // 4 :RGBA
std::vector<half_t> bias_image_v(
bias_image_width * bias_image_height * 4); // 4 : RGBA
std::vector<half_t> out_image_v(
out_image_width * out_image_height * 4); // 4 : RGBA
default_convertor.NCHWToImage(
input_v.data(), x_image_v.data(), input_dim);
SHADOW_LOG << "set mapped filter ...";
paddle::lite::CLImageConverterNWBlock nw_convertor;
nw_convertor.NCHWToImage(
filter_v.data(), filter_image_v.data(), filter_dim);
auto* input_image2d = input.mutable_data<half_t, cl::Image2D>(
input_image_width, input_image_height, x_image_v.data());
// assign filter as target arm
filter.Assign<float, lite::DDim, TARGET(kARM)>(filter_v.data(),
filter_dim);
// auto* filter_image2d =
// filter.mutable_data<half_t, cl::Image2D>(
// filter_image_width,
// filter_image_height,
// filter_image_v.data());
SHADOW_LOG << "卷积核: ---- ";
for (int i = 0; i < filter_v.size(); i++) {
SHADOW_LOG << "(" << i << ")" << filter_v[i];
}
size_t out_image_width = ow * ((oc + 3) / 4);
size_t out_image_height = oh * batch_size;
SHADOW_LOG << "卷积核1: ---- ";
const float* filter_p = filter.data<float>();
for (int i = 0; i < filter_v.size(); i++) {
SHADOW_LOG << "(" << i << ")" << *filter_p;
filter_p++;
}
SHADOW_LOG << "卷积核2: ---- ";
const float* filter_p2 = filter.mutable_data<float>();
for (int i = 0; i < filter_v.size(); i++) {
SHADOW_LOG << "(" << i << ")" << *filter_p2;
filter_p2++;
}
if (bias_flag) {
for (int i = 0; i < bias_dim.production(); ++i) {
bias_v[i] = static_cast<int>(gen(engine));
}
bias.Assign<float, lite::DDim, TARGET(kARM)>(bias_v.data(),
bias_dim);
// CLImageConverterFolder folder_convertor;
// folder_convertor.NCHWToImage(
// bias_v.data(), bias_image_v.data(),
// bias_dim);
//
// auto* bias_data = bias.mutable_data<float,
// cl::Image2D>(
// bias_image_width, bias_image_height,
// bias_image_v.data());
}
size_t bias_image_width = ow * ((oc + 3) / 4);
size_t bias_image_height = oh * batch_size;
size_t filter_image_width = ksize * ((oc + 3) / 4);
size_t filter_image_height = ic * ksize;
const size_t cl_image2d_row_pitch{0};
const size_t cl_image2d_slice_pitch{0};
std::default_random_engine engine;
std::uniform_real_distribution<float> gen(-5, 5);
std::vector<float> input_v(batch_size * ic * ih * iw);
std::vector<float> filter_v(oc * ic * ksize * ksize);
std::vector<float> output_v(batch_size * oc * ih * iw);
std::vector<float> bias_v(oc);
SHADOW_LOG << "gen input and filter ...";
for (auto& i : input_v) {
i = gen(engine);
#ifdef TEST_CONV_IMAGE_ALL_1
i = 0.01;
#endif
}
for (auto& f : filter_v) {
f = gen(engine);
#ifdef TEST_CONV_IMAGE_ALL_1
f = 0.01;
#endif
}
SHADOW_LOG << "after gen input and filter ...";
SHADOW_LOG << "input_v.size(): " << input_v.size();
SHADOW_LOG << "filter_v.size(): " << filter_v.size();
SHADOW_LOG << "output_v.size(): " << output_v.size();
SHADOW_LOG << "bias_v.size(): " << bias_v.size();
SHADOW_LOG << "input_dim.production(): "
<< input_dim.production();
SHADOW_LOG << "filter_dim.production(): "
<< filter_dim.production();
SHADOW_LOG << "out_dim.production(): " << out_dim.production();
SHADOW_LOG << "bias_dim.production(): " << bias_dim.production();
SHADOW_LOG << "4 * input_image_height * input_image_width: "
<< 4 * input_image_height * input_image_width;
SHADOW_LOG << "4 * filter_image_width * filter_image_height: "
<< 4 * filter_image_width * filter_image_height;
CHECK(input_dim.production() == input_v.size());
CHECK_LE(input_dim.production(),
4 * input_image_height * input_image_width);
CHECK(filter_dim.production() == filter_v.size());
CHECK_LE(filter_dim.production(),
4 * filter_image_width * filter_image_height);
paddle::lite::CLImageConverterDefault default_convertor;
SHADOW_LOG << "set mapped input ...";
std::vector<half_t> x_image_v(
input_image_width * input_image_height * 4); // 4 : RGBA
std::vector<half_t> filter_image_v(
filter_image_width * filter_image_height * 4); // 4 :RGBA
std::vector<half_t> bias_image_v(
bias_image_width * bias_image_height * 4); // 4 : RGBA
std::vector<half_t> out_image_v(
out_image_width * out_image_height * 4); // 4 : RGBA
default_convertor.NCHWToImage(
input_v.data(), x_image_v.data(), input_dim);
SHADOW_LOG << "set mapped filter ...";
paddle::lite::CLImageConverterNWBlock nw_convertor;
nw_convertor.NCHWToImage(
filter_v.data(), filter_image_v.data(), filter_dim);
auto* input_image2d = input.mutable_data<half_t, cl::Image2D>(
input_image_width, input_image_height, x_image_v.data());
// assign filter as target arm
filter.Assign<float, lite::DDim, TARGET(kARM)>(filter_v.data(),
filter_dim);
SHADOW_LOG << " lite输入 input_v ..... ";
for (int i = 0; i < input_v.size(); i++) {
SHADOW_LOG << "(" << i << ")" << input_v[i];
}
SHADOW_LOG << " lite输入 input_image2d ..... ";
for (int i = 0; i < x_image_v.size(); i++) {
SHADOW_LOG << "(" << i << ")" << Half2Float(x_image_v[i]);
}
// auto* filter_image2d =
// filter.mutable_data<uint16_t, cl::Image2D>(
// filter_image_width,
// filter_image_height,
// filter_image_v.data());
SHADOW_LOG << "卷积核 : ---- ";
for (int i = 0; i < filter_v.size(); i++) {
SHADOW_LOG << "(" << i << ")" << filter_v[i];
}
SHADOW_LOG << "卷积核1: ---- ";
const float* filter_p = filter.data<float>();
for (int i = 0; i < filter_v.size(); i++) {
SHADOW_LOG << "(" << i << ")" << *filter_p;
filter_p++;
}
SHADOW_LOG << "卷积核2: ---- ";
const float* filter_p2 = filter.mutable_data<float>();
for (int i = 0; i < filter_v.size(); i++) {
SHADOW_LOG << "(" << i << ")" << *filter_p2;
filter_p2++;
}
SHADOW_LOG << "resize output ...";
output.Resize(out_dim);
// cpu conv basic calc
lite::Tensor out_ref;
out_ref.Resize(out_dim);
SHADOW_LOG << "prepare kernel ready";
SHADOW_LOG << "kernel launch ...";
kernel->Launch();
SHADOW_LOG << "mutable output ...";
auto* output_image2d = output.mutable_data<half_t, cl::Image2D>(
out_image_width, out_image_height);
auto* wait_list = context->As<OpenCLContext>().cl_wait_list();
auto* out_ptr = param.output->data<half_t, cl::Image2D>();
auto it = wait_list->find(out_ptr);
if (it != wait_list->end()) {
SHADOW_LOG << "--- Find the sync event for the target cl "
"tensor. ---";
auto& event = *(it->second);
event.wait();
} else {
LOG(FATAL) << "Could not find the sync event for the target"
"cl tensor.";
SHADOW_LOG << "卷积核 image : ---- ";
for (int i = 0; i < filter_image_v.size(); i++) {
SHADOW_LOG << "(" << i << ")" << Half2Float(filter_image_v[i]);
}
if (bias_flag) {
for (int i = 0; i < bias_dim.production(); ++i) {
bias_v[i] = static_cast<int>(gen(engine));
}
bias.Assign<float, lite::DDim, TARGET(kARM)>(bias_v.data(),
bias_dim);
// CLImageConverterFolder folder_convertor;
// folder_convertor.NCHWToImage(
// bias_v.data(), bias_image_v.data(),
// bias_dim);
//
// auto* bias_data = bias.mutable_data<float,
// cl::Image2D>(
// bias_image_width, bias_image_height,
// bias_image_v.data());
}
SHADOW_LOG << "resize output ...";
output.Resize(out_dim);
// cpu conv basic calc
lite::Tensor out_ref;
out_ref.Resize(out_dim);
SHADOW_LOG << "prepare kernel ready";
SHADOW_LOG << "kernel launch ...";
kernel->Launch();
SHADOW_LOG << "mutable output ...";
auto* output_image2d = output.mutable_data<half_t, cl::Image2D>(
out_image_width, out_image_height);
auto* wait_list = context->As<OpenCLContext>().cl_wait_list();
auto* out_ptr = param.output->data<half_t, cl::Image2D>();
auto it = wait_list->find(out_ptr);
if (it != wait_list->end()) {
SHADOW_LOG << "--- Find the sync event for the target cl "
"tensor. ---";
auto& event = *(it->second);
event.wait();
} else {
LOG(FATAL) << "Could not find the sync event for the target"
"cl tensor.";
}
TargetWrapperCL::ImgcpySync(out_image_v.data(),
output.data<half_t, cl::Image2D>(),
out_image_width,
out_image_height,
cl_image2d_row_pitch,
cl_image2d_slice_pitch,
IoDirection::DtoH);
DDim out_image_shape =
default_convertor.InitImageDimInfoWith(output.dims());
default_convertor.ImageToNCHW(out_image_v.data(),
output_v.data(),
out_image_shape,
output.dims());
SHADOW_LOG << " lite输出 out_image_v ..... ";
for (int i = 0; i < out_image_v.size(); i++) {
SHADOW_LOG << "(" << i << ")" << Half2Float(out_image_v[i]);
}
SHADOW_LOG << " lite输出 output_v ..... ";
for (int i = 0; i < output_v.size(); i++) {
SHADOW_LOG << "(" << i << ")" << output_v[i];
}
SHADOW_LOG << "mutable_data out_ref_data: ";
// run cpu ref
auto* out_ref_data = out_ref.mutable_data<float>(TARGET(kARM));
SHADOW_LOG << " conv_basic beigin ..... ";
TargetWrapperCL::ImgcpySync(out_image_v.data(),
output.data<half_t, cl::Image2D>(),
out_image_width,
out_image_height,
cl_image2d_row_pitch,
cl_image2d_slice_pitch,
IoDirection::DtoH);
DDim out_image_shape =
default_convertor.InitImageDimInfoWith(output.dims());
default_convertor.ImageToNCHW(out_image_v.data(),
output_v.data(),
out_image_shape,
output.dims());
SHADOW_LOG << "mutable_data out_ref_data: ";
// run cpu ref
auto* out_ref_data = out_ref.mutable_data<float>(TARGET(kARM));
SHADOW_LOG << " conv_basic beigin ..... ";
conv_basic<float, float>(input_v.data(),
out_ref_data,
batch_size,
oc,
oh,
ow,
ic,
ih,
iw,
filter_v.data(),
bias_v.data(), // mapped_bias,
group,
ksize,
ksize,
stride,
stride,
dilation,
dilation,
pad,
pad,
bias_flag,
relu_flag);
SHADOW_LOG << " conv_basic end ..... ";
SHADOW_LOG << " out_dim: " << out_dim;
const DDim& out_image_dims = lite::DDim{std::vector<int64_t>(
{static_cast<int64_t>(out_image_width),
static_cast<int64_t>(out_image_height)})};
for (int i = 0; i < out_dim.production(); i++) {
auto relative_diff =
COMPUTE_RELATIVE_DIFF(output_v[i], out_ref_data[i]);
EXPECT_LT(relative_diff, FP16_MAX_DIFF);
if (relative_diff > FP16_MAX_DIFF) {
LOG(FATAL) << "error idx:" << i << "output_v[" << i
<< "]:" << output_v[i] << " "
"out_ref_data["
<< i << "]:" << out_ref_data[i];
}
conv_basic<float, float>(input_v.data(),
out_ref_data,
batch_size,
oc,
oh,
ow,
ic,
ih,
iw,
filter_v.data(),
bias_v.data(), // mapped_bias,
group,
ksize,
ksize,
stride,
stride,
dilation,
dilation,
pad,
pad,
bias_flag,
relu_flag);
SHADOW_LOG << " conv_basic end ..... ";
SHADOW_LOG << " out_dim: " << out_dim;
const DDim& out_image_dims = lite::DDim{std::vector<int64_t>(
{static_cast<int64_t>(out_image_width),
static_cast<int64_t>(out_image_height)})};
for (int i = 0; i < out_dim.production(); i++) {
auto relative_diff =
COMPUTE_RELATIVE_DIFF(output_v[i], out_ref_data[i]);
auto abs_diff = COMPTUE_ABS_DIFF(output_v[i], out_ref_data[i]);
// EXPECT_LT(relative_diff, FP16_MAX_DIFF);
EXPECT_FALSE(relative_diff > FP16_MAX_DIFF &&
abs_diff > FP16_ABS_DIFF);
if (relative_diff > FP16_MAX_DIFF && abs_diff > FP16_ABS_DIFF) {
LOG(FATAL) << "error idx:" << i << "output_v[" << i
<< "]:" << output_v[i] << " "
"out_ref_data["
<< i << "]:" << out_ref_data[i];
}
#ifdef LOOP_TEST
}
#ifdef LOOP_TEST
}
}
}
......@@ -450,7 +495,9 @@ TEST(conv2d, compute_image2d_1x1) {
}
#undef LOOP_TEST
#undef PRINT_RESULT
#endif
#ifdef TEST_CONV_IMAGE_3x3
// #define PRINT_RESULT
// #define LOOP_TEST
TEST(conv2d, compute_image2d_3x3) {
......@@ -471,11 +518,11 @@ TEST(conv2d, compute_image2d_3x3) {
for (bool bias_flag : {true, false}) {
for (std::string relu_flag : {/*true,*/ "relu"}) {
#else
const int pad = 1;
const int dilation = 1;
const int pad = 1;
const int dilation = 1;
#if 0 // small scale with group, but result of cpu reference is wrong
const int stride = 2;
const int stride = 2;
const int group = 2;
const int batch_size = 1;
const int ic = 1;
......@@ -483,17 +530,17 @@ TEST(conv2d, compute_image2d_3x3) {
const int iw = 3;
const int oc = 2;
#else // big scale with group
const int stride = 1;
const int group = 32 / 1;
const int batch_size = 1;
const int ic = 32 / 1;
const int ih = 112 / 1;
const int iw = 112 / 1;
const int oc = 32 / 1;
const int stride = 1;
const int group = 32 / 1;
const int batch_size = 1;
const int ic = 32 / 1;
const int ih = 112 / 1;
const int iw = 112 / 1;
const int oc = 32 / 1;
#endif
const bool bias_flag = false;
const std::string relu_flag = "relu";
const bool bias_flag = false;
const std::string relu_flag = "relu";
#endif
int filter_channel = ic;
if (group > 1) {
......@@ -823,6 +870,10 @@ TEST(conv2d, compute_image2d_3x3) {
#undef LOOP_TEST
#undef PRINT_RESULT
#endif
#ifdef TEST_CONV_IMAGE_5x5
// #define PRINT_RESULT
// #define LOOP_TEST
TEST(conv2d, compute_image2d_5x5) {
......@@ -839,17 +890,18 @@ TEST(conv2d, compute_image2d_5x5) {
for (int oc = 1; oc < 10; oc += 1) { // oc
for (int ih = 5; ih < 9; ih += 1) { // ih
int iw = ih;
for (int ic = 1; ic < 10; ic += 1) { // ic
for (int ic = 2; ic < 10; ic += 1) { // ic
for (bool bias_flag : {true, false}) {
for (std::string relu_flag : {/*true,*/ "relu"}) {
#else
const int batch_size = 2;
const int oc = 1;
const int ih = 5;
const int iw = 5;
const int ic = 1;
const bool bias_flag = true;
const std::string relu_flag = "relu";
const int batch_size = 2;
const int oc = 1;
const int ih = 5;
const int iw = 5;
// ic = 1 会进入depthwise的路由 .
const int ic = 2;
const bool bias_flag = true;
const std::string relu_flag = "relu";
#endif
const int oh =
......@@ -1139,8 +1191,10 @@ TEST(conv2d, compute_image2d_5x5) {
for (int i = 0; i < out_dim.production(); i++) {
auto relative_diff =
COMPUTE_RELATIVE_DIFF(output_v[i], out_ref_data[i]);
EXPECT_LT(relative_diff, FP16_MAX_DIFF);
if (relative_diff > FP16_MAX_DIFF) {
auto abs_diff = COMPTUE_ABS_DIFF(output_v[i], out_ref_data[i]);
EXPECT_FALSE(relative_diff > FP16_MAX_DIFF &&
abs_diff > FP16_ABS_DIFF);
if (relative_diff > FP16_MAX_DIFF && abs_diff > FP16_ABS_DIFF) {
LOG(FATAL) << "error idx:" << i << "output_v[" << i
<< "]:" << output_v[i] << " "
"out_ref_data["
......@@ -1161,13 +1215,16 @@ TEST(conv2d, compute_image2d_5x5) {
}
#undef LOOP_TEST
#undef PRINT_RESULT
#endif
#ifdef TEST_CONV_IMAGE_7x7
#undef FP16_ABS_DIFF
#define FP16_ABS_DIFF (1e0)
// #define LOOP_TEST
TEST(conv2d, compute_image2d_7x7) {
// conv infos
const int ksize = 7;
const int stride = 1;
const int pad = 2;
const int pad = 3;
const int group = 1;
const int dilation = 1;
// int loop_cnt = 0;
......@@ -1177,17 +1234,18 @@ TEST(conv2d, compute_image2d_7x7) {
for (int oc = 1; oc < 10; oc += 1) { // oc
for (int ih = 7; ih < 15; ih += 1) { // ih
int iw = ih;
for (int ic = 1; ic < 10; ic += 1) { // ic
for (int ic = 2; ic < 10; ic += 1) { // ic
for (bool bias_flag : {true, false}) {
for (std::string relu_flag : {"relu"}) {
#else
const int batch_size = 2;
const int oc = 1;
const int ih = 7;
const int iw = 7;
const int ic = 1;
const bool bias_flag = false;
const std::string relu_flag = "";
const int batch_size = 2;
const int oc = 1;
const int ih = 7;
const int iw = 7;
// ic = 1会进入 depthwise路由
const int ic = 2;
const bool bias_flag = false;
const std::string relu_flag = "";
#endif
const int oh =
......@@ -1286,11 +1344,15 @@ TEST(conv2d, compute_image2d_7x7) {
SHADOW_LOG << "gen input and filter ...";
for (auto& i : input_v) {
i = gen(engine);
// i = 1;
#ifdef TEST_CONV_IMAGE_ALL_1
i = 1;
#endif
}
for (auto& f : filter_v) {
f = gen(engine);
// f = 1;
#ifdef TEST_CONV_IMAGE_ALL_1
f = 1;
#endif
}
LOG(INFO) << "bias: " << bias_flag;
LOG(INFO) << "relu: " << relu_flag;
......@@ -1340,7 +1402,7 @@ TEST(conv2d, compute_image2d_7x7) {
}
SHADOW_LOG << "输入image : ---- ";
for (int i = 0; i < x_image_v.size(); i++) {
SHADOW_LOG << "(" << i << ")" << x_image_v[i];
SHADOW_LOG << "(" << i << ")" << Half2Float(x_image_v[i]);
}
SHADOW_LOG << "set mapped filter ...";
CLImageConverterFolder folder_convertor;
......@@ -1353,7 +1415,7 @@ TEST(conv2d, compute_image2d_7x7) {
}
SHADOW_LOG << "卷积核image: ---- ";
for (int i = 0; i < filter_image_v.size(); i++) {
SHADOW_LOG << "(" << i << ")" << filter_image_v[i];
SHADOW_LOG << "(" << i << ")" << Half2Float(filter_image_v[i]);
}
auto* input_image2d = input.mutable_data<half_t, cl::Image2D>(
input_image_width, input_image_height, x_image_v.data());
......@@ -1437,7 +1499,7 @@ TEST(conv2d, compute_image2d_7x7) {
SHADOW_LOG << "输出image: ---- ";
for (int i = 0; i < out_image_v.size(); i++) {
SHADOW_LOG << "(" << i << ")" << out_image_v[i];
SHADOW_LOG << "(" << i << ")" << Half2Float(out_image_v[i]);
}
SHADOW_LOG << "mutable_data out_ref_data: ";
......@@ -1478,8 +1540,10 @@ TEST(conv2d, compute_image2d_7x7) {
for (int i = 0; i < out_dim.production(); i++) {
auto relative_diff =
COMPUTE_RELATIVE_DIFF(output_v[i], out_ref_data[i]);
EXPECT_LT(relative_diff, FP16_MAX_DIFF);
if (relative_diff > FP16_MAX_DIFF) {
auto abs_diff = COMPTUE_ABS_DIFF(output_v[i], out_ref_data[i]);
EXPECT_FALSE(relative_diff > FP16_MAX_DIFF &&
abs_diff > FP16_ABS_DIFF);
if (relative_diff > FP16_MAX_DIFF && abs_diff > FP16_ABS_DIFF) {
LOG(FATAL) << "error idx:" << i << "output_v[" << i
<< "]:" << output_v[i] << " "
"out_ref_data["
......@@ -1500,7 +1564,14 @@ TEST(conv2d, compute_image2d_7x7) {
}
#undef LOOP_TEST
#undef PRINT_RESULT
#endif
#undef SHADOW_LOG
#undef TEST_CONV_IMAGE_1x1
#undef TEST_CONV_IMAGE_3x3
#undef TEST_CONV_IMAGE_5x5
#undef TEST_CONV_IMAGE_7x7
#undef TEST_CONV_IMAGE_ALL_1
} // namespace lite
} // namespace paddle
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册