Commit 63c8675c authored by chenjiaoAngel

pull code

Merge branch 'opencl' of https://github.com/chenjiaoAngel/Paddle-Lite into int8
@@ -33,6 +33,7 @@ void add_bias_rowwise(Tensor* input,
     for (int w = start_w; w < w_adds; ++w) {
       i_data[w] += b_data[w];
     }
+    i_data += width;
   }
 }

 void vector_dot(
@@ -67,15 +68,8 @@ void vector_dot(
   for (int i = 0; i < remain; ++i) {
     if (!v2) {
       out_ptr[i] = in_ptr[i] * v1_ptr[i];
-      ++out_ptr;
-      ++in_ptr;
-      ++v1_ptr;
     } else {
       out_ptr[i] = in_ptr[i] + v1_ptr[i] * v2_ptr[i];
-      ++out_ptr;
-      ++in_ptr;
-      ++v1_ptr;
-      ++v2_ptr;
     }
   }
 }
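The deleted `++out_ptr` / `++in_ptr` / `++v1_ptr` / `++v2_ptr` lines were the bug this hunk fixes: the tail loop already indexes with `[i]`, so advancing the base pointers as well makes each iteration touch element `2*i` and run past the end of the arrays. A minimal standalone C++ sketch of the failure mode (hypothetical values, not from the patch):

#include <cstdio>

int main() {
  float in[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  float v1[8] = {10, 10, 10, 10, 10, 10, 10, 10};
  float out[8] = {0};
  float* out_ptr = out;
  float* in_ptr = in;
  float* v1_ptr = v1;
  for (int i = 0; i < 4; ++i) {
    out_ptr[i] = in_ptr[i] * v1_ptr[i];  // indexes relative to a shifted base
    ++out_ptr;                           // shifting the base too gives stride 2
    ++in_ptr;
    ++v1_ptr;
  }
  // Prints "10 0 30 0 50 0 70 0" instead of the intended "10 20 30 40";
  // dropping the increments, as the patch does, restores a stride of 1.
  for (int i = 0; i < 8; ++i) printf("%g ", out[i]);
  return 0;
}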
...
@@ -28,6 +28,7 @@ namespace lite {
 class CLContext {
  public:
   ~CLContext() {
+    GetCommandQueue().finish();
     for (size_t kidx = 0; kidx < kernels_.size(); ++kidx) {
       // Note(ysh329): Don't need `clReleaseKernel`
       kernels_[kidx].reset();
...
@@ -100,16 +100,18 @@ TEST(cl_test, kernel_test) {
   size_t width = in_image.ImageWidth();
   size_t height = in_image.ImageHeight();
   auto global_work_size = cl::NDRange{width, height};
-  cl::Event event;
   status = context->GetCommandQueue().enqueueNDRangeKernel(
-      kernel, cl::NullRange, global_work_size, cl::NullRange, nullptr, &event);
+      kernel, cl::NullRange, global_work_size, cl::NullRange, nullptr, nullptr);
   CL_CHECK_FATAL(status);
   status = context->GetCommandQueue().finish();
   CL_CHECK_FATAL(status);
+#if 0
   double start_nanos = event.getProfilingInfo<CL_PROFILING_COMMAND_START>();
   double stop_nanos = event.getProfilingInfo<CL_PROFILING_COMMAND_END>();
   double elapsed_micros = (stop_nanos - start_nanos) / 1000.0;
   LOG(INFO) << "Kernel Run Cost Time: " << elapsed_micros << " us.";
+#endif
   LOG(INFO) << out_image;
 }
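The timing block is parked under `#if 0` because the enqueue now passes `nullptr` instead of an event. For reference, a hedged sketch of what re-enabling it would take (standard OpenCL C++ API, not part of the patch): profiling counters are only populated when the queue was created with CL_QUEUE_PROFILING_ENABLE and after the event has completed.

#include <CL/cl2.hpp>

// Sketch: time one kernel launch in microseconds via event profiling.
double TimeKernelUs(cl::Context& ctx, cl::Device& dev, cl::Kernel& kernel,
                    cl::NDRange global_work_size) {
  cl::CommandQueue queue(ctx, dev, CL_QUEUE_PROFILING_ENABLE);
  cl::Event event;
  queue.enqueueNDRangeKernel(kernel, cl::NullRange, global_work_size,
                             cl::NullRange, nullptr, &event);
  event.wait();  // profiling info is valid only after the command completes
  cl_ulong start = event.getProfilingInfo<CL_PROFILING_COMMAND_START>();  // ns
  cl_ulong stop = event.getProfilingInfo<CL_PROFILING_COMMAND_END>();     // ns
  return (stop - start) / 1000.0;
}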
...
@@ -73,7 +73,7 @@ void CLImageConverterDefault::NCHWToImage(float *nchw,
       i2 += 4;
       p++;
     } else {
-      image[i2] = 0.0;
+      image[i2] = Float2Half(0.f);
       i2 += 4;
     }
   }
@@ -261,7 +261,7 @@ void CLImageConverterNWBlock::NCHWToImage(float *tensor,
       image[index] = Float2Half(*p);
       p++;
     } else {
-      image[index] = 0.0;
+      image[index] = Float2Half(0.f);
     }
     if (index >= (width * height * 4)) {
       LOG(INFO) << " index out of range ";
...
@@ -11,7 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

 #pragma once

 /////////////////////////////////
@@ -108,7 +107,8 @@ inline CL_DTYPE4 activation_type4(CL_DTYPE4 in
 #endif

 #ifdef RELU6
-  output = clamp(in, (CL_DTYPE4)0, (CL_DTYPE4)6);
+  in = fmax((CL_DTYPE4)(0.0f, 0.0f, 0.0f, 0.0f), in);
+  output = fmin((CL_DTYPE4)(6.0f, 6.0f, 6.0f, 6.0f), in);
 #endif
   return output;
 }
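The new fmax/fmin pair computes the same thing as the old `clamp` call, written out per lane: relu6(x) = min(max(x, 0), 6). A scalar reference, as a sketch rather than anything taken from the patch:

#include <cmath>

float relu6(float x) { return std::fmin(6.0f, std::fmax(0.0f, x)); }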
@@ -14,36 +14,30 @@ limitations under the License. */
 #include <cl_common.h>

 __kernel void relu(__read_only image2d_t input,
                    __write_only image2d_t output,
                    __private const float threshold,
                    __private const float scale) {
   const int x = get_global_id(0);  // image_width
   const int y = get_global_id(1);  // image_height

-  const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
-                            CLK_ADDRESS_CLAMP |
-                            CLK_FILTER_NEAREST;
+  const sampler_t sampler =
+      CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

   CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x, y));
   in = max((CL_DTYPE4)(0.0f), in);
   WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(x, y), in);
 }

 __kernel void relu6(__read_only image2d_t input,
                     __write_only image2d_t output,
                     __private const float threshold,
-                    __private const float scale){
+                    __private const float scale) {
   const int x = get_global_id(0);
   const int y = get_global_id(1);
-  const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
-                            CLK_ADDRESS_CLAMP |
-                            CLK_FILTER_NEAREST;
+  const sampler_t sampler =
+      CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

   CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x, y));
   in = max((CL_DTYPE4)(0.0f, 0.0f, 0.0f, 0.0f), in);
@@ -51,7 +45,6 @@ __kernel void relu6(__read_only image2d_t input,
   WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(x, y), in);
 }

 __kernel void sigmoid(__read_only image2d_t input,
                       __write_only image2d_t output,
                       __private const float threshold,
@@ -64,70 +57,66 @@ __kernel void sigmoid(__read_only image2d_t input,
   CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x, y));
   CL_DTYPE4 out;
-  out.x = 1.0 / (1.0 + pow(2.71828182, -1.0 * (float)(in.x)));
-  out.y = 1.0 / (1.0 + pow(2.71828182, -1.0 * (float)(in.y)));
-  out.z = 1.0 / (1.0 + pow(2.71828182, -1.0 * (float)(in.z)));
-  out.w = 1.0 / (1.0 + pow(2.71828182, -1.0 * (float)(in.w)));
+  out.x = (CL_DTYPE)(1.0f / (1.0f + pow(2.71828182f, -1.0f * (float)(in.x))));
+  out.y = (CL_DTYPE)(1.0f / (1.0f + pow(2.71828182f, -1.0f * (float)(in.y))));
+  out.z = (CL_DTYPE)(1.0f / (1.0f + pow(2.71828182f, -1.0f * (float)(in.z))));
+  out.w = (CL_DTYPE)(1.0f / (1.0f + pow(2.71828182f, -1.0f * (float)(in.w))));
   WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(x, y), out);
 }
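Each lane of `sigmoid` evaluates the logistic function 1 / (1 + e^(-x)); `pow(2.71828182f, -x)` is simply exp(-x) with the base written out, and the new `(CL_DTYPE)` casts make the half-precision build explicit. A scalar reference sketch:

#include <cmath>

float sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }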
 __kernel void leaky_relu(__read_only image2d_t input,
                          __write_only image2d_t output,
                          __private const float threshold,
                          __private const float scale) {
   const int x = get_global_id(0);
   const int y = get_global_id(1);
-  const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
-                            CLK_ADDRESS_CLAMP |
-                            CLK_FILTER_NEAREST;
+  const sampler_t sampler =
+      CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

   CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x, y));
   CL_DTYPE4 s_val = CONVERT_TYPE_TO(scale, CL_DTYPE) * in;
-  if (in.x < 0.0f){
+  if (in.x < 0.0f) {
     in.x = s_val.x;
   }
-  if (in.y < 0.0f){
+  if (in.y < 0.0f) {
     in.y = s_val.y;
   }
-  if (in.z < 0.0f){
+  if (in.z < 0.0f) {
     in.z = s_val.z;
   }
-  if (in.w < 0.0f){
+  if (in.w < 0.0f) {
     in.w = s_val.w;
   }
   WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(x, y), in);
 }

 __kernel void tanh_act(__read_only image2d_t input,
                        __write_only image2d_t output,
                        __private const float threshold,
                        __private const float scale) {
   const int x = get_global_id(0);  // image_width
   const int y = get_global_id(1);  // image_height

-  const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
-                            CLK_ADDRESS_CLAMP |
-                            CLK_FILTER_NEAREST;
+  const sampler_t sampler =
+      CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

   CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x, y));
-  CL_DTYPE4 out= (exp(in) - exp(-in))/ (exp(in) + exp(-in));
+  CL_DTYPE4 out = (exp(in) - exp(-in)) / (exp(in) + exp(-in));
   WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(x, y), out);
 }

 __kernel void exp_act(__read_only image2d_t input,
                       __write_only image2d_t output,
                       __private const float threshold,
                       __private const float scale) {
   const int x = get_global_id(0);  // image_width
   const int y = get_global_id(1);  // image_height

-  const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
-                            CLK_ADDRESS_CLAMP |
-                            CLK_FILTER_NEAREST;
+  const sampler_t sampler =
+      CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

   CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x, y));
   CL_DTYPE4 out = exp(in);
@@ -135,19 +124,16 @@ __kernel void exp_act(__read_only image2d_t input,
 }

 __kernel void swish(__read_only image2d_t input,
                     __write_only image2d_t output,
                     __private const float threshold,
                     __private const float scale) {
   const int x = get_global_id(0);  // image_width
   const int y = get_global_id(1);  // image_height

-  const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
-                            CLK_ADDRESS_CLAMP |
-                            CLK_FILTER_NEAREST;
+  const sampler_t sampler =
+      CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

   CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x, y));
   CL_DTYPE4 out = in / (1 + exp(-(CL_DTYPE)scale * in));
   WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(x, y), out);
 }
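`swish` computes x * sigmoid(scale * x), which is exactly the `in / (1 + exp(-scale * in))` form above. Scalar reference sketch (not from the patch):

#include <cmath>

float swish(float x, float scale) { return x / (1.0f + std::exp(-scale * x)); }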
 #include <cl_common.h>

-__kernel void conv2d_1x1_opt(__private const int global_size_dim0,
-                             __private const int global_size_dim1,
-                             __private const int global_size_dim2,
-                             __read_only image2d_t input_image,
-                             __read_only image2d_t filter,
+__kernel void conv2d_1x1_opt(
+    __private const int global_size_dim0,
+    __private const int global_size_dim1,
+    __private const int global_size_dim2,
+    __read_only image2d_t input_image,
+    __read_only image2d_t filter,
 #if defined(BIASE_CH) || defined(BIASE_ELE)
     __read_only image2d_t bias,
 #endif
 #ifdef BATCH_NORM
     __read_only image2d_t new_scale,
     __read_only image2d_t new_biase,
 #endif
     __write_only image2d_t output_image,
     __private const int stride,
     __private const int offset,
     __private const int input_c_block,
     __private const int input_c_origin,
     __private const int dilation,
     __private const int input_width,  /* of one block */
     __private const int input_height, /* of one block */
     __private const int output_width,
     __private const int output_height,
     __private const int old_w) {
   const int out_c = get_global_id(0);
   const int out_w = get_global_id(1);
@@ -287,7 +288,7 @@ __kernel void conv2d_1x1_simple(
     __read_only image2d_t bias,
 #endif
 #ifdef BATCH_NORM
     __read_only image2d_t new_scale,
     __read_only image2d_t new_biase,
 #endif
     __write_only image2d_t output_image,
...
@@ -18,7 +18,7 @@ limitations under the License. */
 ////////////////////////////////////////////////////////
 // buffer -> image2d
 ////////////////////////////////////////////////////////
-__kernel void buffer_to_image2d(__global CL_DTYPE *in,
+__kernel void buffer_to_image2d(__global CL_DTYPE* in,
                                 __write_only image2d_t output_image,
                                 __private const int out_H,
                                 __private const int out_W,
@@ -26,7 +26,6 @@ __kernel void buffer_to_image2d(__global CL_DTYPE *in,
                                 __private const int Stride0,
                                 __private const int Stride1,
                                 __private const int Stride2) {
   const int out_c = get_global_id(0);
   const int out_w = get_global_id(1);
   const int out_nh = get_global_id(2);
@@ -66,16 +65,25 @@ __kernel void buffer_to_image2d(__global CL_DTYPE *in,
 #ifdef DEBUG
   if (out_w > 2045) {
-    printf("out_w:%d, out_C - 4 * out_c:%d, input[pos0~pos3]:%.2f %.2f %.2f %.2f\n",
-           out_w,
-           out_C - 4 * out_c,
-           (float)(in[input_pos0]),
-           (float)(in[input_pos1]),
-           (float)(in[input_pos2]),
-           (float)(in[input_pos3]));
-    printf("buffer2image ===> %d,%d,%d, out(%d,%d): %.2f %.2f %.2f %.2f \n", out_c, out_w, out_nh,
-           output_pos.x, output_pos.y,
-           (float)(output.x), (float)(output.y), (float)(output.z), (float)(output.w));
+    printf(
+        "out_w:%d, out_C - 4 * out_c:%d, input[pos0~pos3]:%.2f %.2f %.2f "
+        "%.2f\n",
+        out_w,
+        out_C - 4 * out_c,
+        (float)(in[input_pos0]),
+        (float)(in[input_pos1]),
+        (float)(in[input_pos2]),
+        (float)(in[input_pos3]));
+    printf("buffer2image ===> %d,%d,%d, out(%d,%d): %.2f %.2f %.2f %.2f \n",
+           out_c,
+           out_w,
+           out_nh,
+           output_pos.x,
+           output_pos.y,
+           (float)(output.x),
+           (float)(output.y),
+           (float)(output.z),
+           (float)(output.w));
   }
 #endif
@@ -101,34 +109,42 @@ __kernel void image2d_to_buffer(__read_only image2d_t input,
   const int in_h = in_nh % in_height;

   const sampler_t sampler =
       CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
   const int pos_x = mad24(in_c, in_width, in_w);
-  CL_COMPUTE_DTYPE4 in = READ_IMG_TYPE(CL_COMPUTE_DTYPE_CHAR, input, sampler, (int2)(pos_x, in_nh));
+  CL_COMPUTE_DTYPE4 in = READ_IMG_TYPE(
+      CL_COMPUTE_DTYPE_CHAR, input, sampler, (int2)(pos_x, in_nh));

 #ifdef DEBUG
   if (in_w > 2045) {
-    printf("image2buffer ===> %d,%d,%d, in(%d,%d): %.2f %.2f %.2f %.2f \n", in_c, in_w, in_nh,
-           pos_x, in_nh,
-           (float)(in.x), (float)(in.y), (float)(in.z), (float)(in.w));
+    printf("image2buffer ===> %d,%d,%d, in(%d,%d): %.2f %.2f %.2f %.2f \n",
+           in_c,
+           in_w,
+           in_nh,
+           pos_x,
+           in_nh,
+           (float)(in.x),
+           (float)(in.y),
+           (float)(in.z),
+           (float)(in.w));
   }
 #endif

-  const int index = in_n * size_batch + in_c * size_block + in_h * in_width + in_w;
+  const int index =
+      in_n * size_batch + in_c * size_block + in_h * in_width + in_w;
   out[index] = CONVERT_TYPE_TO(in.x, CL_DTYPE);
   if (C - 4 * in_c >= 2) {
     out[index + size_ch] = CONVERT_TYPE_TO(in.y, CL_DTYPE);
   }
-  if(C - 4 * in_c >= 3) {
+  if (C - 4 * in_c >= 3) {
     out[index + size_ch * 2] = CONVERT_TYPE_TO(in.z, CL_DTYPE);
   }
-  if(C - 4 * in_c >= 4) {
+  if (C - 4 * in_c >= 4) {
     out[index + size_ch * 3] = CONVERT_TYPE_TO(in.w, CL_DTYPE);
   }
 }
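For orientation, the index arithmetic in `image2d_to_buffer`: channels are packed four per texel, `mad24(in_c, in_width, in_w)` is `in_c * in_width + in_w`, and lane k of a texel maps to channel `4 * in_c + k`, i.e. buffer offset `index + k * size_ch`. A C++ sketch, under the assumption (suggested by the `out[index + size_ch]` stepping but not stated in the hunk) that `size_ch` is one H*W channel plane, `size_block` four planes, and `size_batch` one full sample:

// Hypothetical helper names; they mirror the kernel's arguments.
int image_x(int c_block, int w, int W) { return c_block * W + w; }  // pos_x
int image_y(int n, int h, int H) { return n * H + h; }              // in_nh
int buffer_index(int n, int c_block, int h, int w, int W,
                 int size_ch, int size_block, int size_batch) {
  // lane k of texel (image_x, image_y) lands at this index + k * size_ch
  return n * size_batch + c_block * size_block + h * W + w;
}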
-#if 0 // NOTE(ysh329): keep, un-used from paddle-mobile
+#if 0  // NOTE(ysh329): keep, un-used from paddle-mobile
 ////////////////////////////////////////////////////////
 // buffer -> image2d_nw
 ////////////////////////////////////////////////////////
@@ -182,8 +198,7 @@ __kernel void buffer_to_image2d_nw(__global CL_DTYPE* in,
 }
 #endif

-#if 0 // NOTE(ysh329): keep, un-used from paddle-mobile
+#if 0  // NOTE(ysh329): keep, un-used from paddle-mobile
 // image2d -> buffer
 __kernel void image2d_to_buffer_2d(__private const int in_height,
                                    __private const int in_width,
@@ -208,15 +223,14 @@ __kernel void image2d_to_buffer_2d(__private const int in_height,
 ////////////////////////////////////////////////////////
 // buffer -> image2d (divide by 255 to normalize)
 ////////////////////////////////////////////////////////
-__kernel void buffer_to_image2d_with_pre255(__global uchar *in,
+__kernel void buffer_to_image2d_with_pre255(__global uchar* in,
                                             __write_only image2d_t output_image,
                                             __private const int out_H,
                                             __private const int out_W,
                                             __private const int out_C,
                                             __private const int Stride0,
                                             __private const int Stride1,
-                                            __private const int Stride2){
+                                            __private const int Stride2) {
   const int out_c = get_global_id(0);
   const int out_w = get_global_id(1);
   const int out_nh = get_global_id(2);
@@ -231,7 +245,6 @@ __kernel void buffer_to_image2d_with_pre255(__global uchar *in,
   const int in_h = out_h;
   const int in_w = out_w;
   int input_pos0 = in_n * Stride2 + in_c0 * Stride1 + in_h * Stride0 + in_w;
   int input_pos1 = in_n * Stride2 + in_c1 * Stride1 + in_h * Stride0 + in_w;
   int input_pos2 = in_n * Stride2 + in_c2 * Stride1 + in_h * Stride0 + in_w;
@@ -243,30 +256,29 @@ __kernel void buffer_to_image2d_with_pre255(__global uchar *in,
   CL_COMPUTE_DTYPE4 output = (CL_COMPUTE_DTYPE4)0.0f;
   output.x = CONVERT_TYPE_TO(in[input_pos0], CL_COMPUTE_DTYPE) / 255;
-  if(out_C - 4 * out_c>=2){
+  if (out_C - 4 * out_c >= 2) {
     output.y = CONVERT_TYPE_TO(in[input_pos1], CL_COMPUTE_DTYPE) / 255;
   }
-  if(out_C - 4 * out_c>=3){
+  if (out_C - 4 * out_c >= 3) {
     output.z = CONVERT_TYPE_TO(in[input_pos2], CL_COMPUTE_DTYPE) / 255;
   }
-  if(out_C - 4 * out_c>=4){
+  if (out_C - 4 * out_c >= 4) {
     output.w = CONVERT_TYPE_TO(in[input_pos3], CL_COMPUTE_DTYPE) / 255;
   }

   WRITE_IMG_TYPE(CL_COMPUTE_DTYPE_CHAR, output_image, output_pos, output);
 }

 ////////////////////////////////////////////////////////
 // image2d -> buffer (multiply by 255 to de-normalize)
 ////////////////////////////////////////////////////////
 __kernel void image2d_to_buffer_with_post255(__read_only image2d_t input,
                                              __private const int in_width,
                                              __private const int in_height,
                                              __global uchar* out,
                                              __private const int size_ch,
                                              __private const int size_block,
                                              __private const int size_batch,
                                              __private const int C) {
   const int in_c = get_global_id(0);
   const int in_w = get_global_id(1);
   const int in_nh = get_global_id(2);
@@ -277,22 +289,34 @@ __kernel void image2d_to_buffer_with_post255(__read_only image2d_t input,
       CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
   const int pos_x = mad24(in_c, in_width, in_w);
-  CL_COMPUTE_DTYPE4 in = READ_IMG_TYPE(CL_COMPUTE_DTYPE_CHAR, input, sampler, (int2)(pos_x, in_nh)) * 255;
+  CL_COMPUTE_DTYPE4 in =
+      READ_IMG_TYPE(
+          CL_COMPUTE_DTYPE_CHAR, input, sampler, (int2)(pos_x, in_nh)) *
+      255;

 #ifdef DEBUG
-  printf("in_c:%d, in_w:%d, in_nh:%d ===> in(%d,%d): %.2f %.2f %.2f %.2f\n",
-         in_c, in_w, in_nh, pos_x, in_nh, in.x, in.y, in.z, in.w);
+  printf("in_c:%d, in_w:%d, in_nh:%d ===> in(%d,%d): %.2f %.2f %.2f %.2f\n",
+         in_c,
+         in_w,
+         in_nh,
+         pos_x,
+         in_nh,
+         in.x,
+         in.y,
+         in.z,
+         in.w);
 #endif

-  const int index = in_n * size_batch + in_c * size_block + in_h * in_width + in_w;
+  const int index =
+      in_n * size_batch + in_c * size_block + in_h * in_width + in_w;
   out[index] = convert_uchar_sat(in.x);
-  if(C - 4 * in_c>=2){
+  if (C - 4 * in_c >= 2) {
     out[index + size_ch] = convert_uchar_sat(in.y);
   }
-  if(C - 4 * in_c>=3){
+  if (C - 4 * in_c >= 3) {
     out[index + size_ch * 2] = convert_uchar_sat(in.z);
   }
-  if(C - 4 * in_c>=4){
+  if (C - 4 * in_c >= 4) {
     out[index + size_ch * 3] = convert_uchar_sat(in.w);
   }
 }
@@ -45,6 +45,9 @@ bool CLRuntime::Init() {
   bool is_device_init = InitializeDevice();
   is_init_success_ = is_platform_init && is_device_init;
   initialized_ = true;
+  context_ = CreateContext();
+  command_queue_ = CreateCommandQueue(context());
   return initialized_;
 }
@@ -55,7 +58,7 @@ cl::Platform& CLRuntime::platform() {
 cl::Context& CLRuntime::context() {
   if (context_ == nullptr) {
-    context_ = CreateContext();
+    LOG(FATAL) << "context_ create failed. ";
   }
   return *context_;
 }
@@ -67,7 +70,7 @@ cl::Device& CLRuntime::device() {
 cl::CommandQueue& CLRuntime::command_queue() {
   if (command_queue_ == nullptr) {
-    command_queue_ = CreateCommandQueue(context());
+    LOG(FATAL) << "command_queue_ create failed. ";
   }
   return *command_queue_;
 }
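The pattern adopted here: create the context and command queue once, eagerly, inside `Init()`, and turn the accessors into fail-fast getters instead of lazy factories, which also avoids racy first-use creation. A minimal sketch of the shape of this design, with stand-in types rather than the real CLRuntime:

#include <cstdlib>
#include <memory>

struct Context {};
struct Queue {};
std::unique_ptr<Context> CreateContext() { return std::make_unique<Context>(); }
std::unique_ptr<Queue> CreateQueue(Context&) { return std::make_unique<Queue>(); }

class Runtime {
 public:
  bool Init() {
    context_ = CreateContext();       // eager: created once during Init()
    queue_ = CreateQueue(*context_);  // the queue depends on the context
    return context_ && queue_;
  }
  Context& context() {
    if (!context_) std::abort();  // fail fast; no lazy creation on first use
    return *context_;
  }

 private:
  std::unique_ptr<Context> context_;
  std::unique_ptr<Queue> queue_;
};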
@@ -96,7 +99,7 @@ std::unique_ptr<cl::UserEvent> CLRuntime::CreateEvent(
 bool CLRuntime::BuildProgram(cl::Program* program, const std::string& options) {
   /* -I +CLRuntime::Global()->cl_path() + "/cl_kernel"*/
-  std::string build_option = options + " -cl-fast-relaxed-math ";
+  std::string build_option = options + " -cl-fast-relaxed-math -cl-mad-enable";
   VLOG(4) << "OpenCL build_option: " << build_option;
   status_ = program->build({*device_}, build_option.c_str());
   CL_CHECK_ERROR(status_);
...
@@ -66,7 +66,8 @@ void *TargetWrapperCL::MallocImage<float>(const size_t cl_image2d_width,
   cl_int status;
   cl::Image2D *cl_image =
       new cl::Image2D(CLRuntime::Global()->context(),
-                      CL_MEM_READ_WRITE | (host_ptr ? CL_MEM_COPY_HOST_PTR : 0),
+                      CL_MEM_READ_WRITE | (host_ptr ? CL_MEM_COPY_HOST_PTR
+                                                    : CL_MEM_ALLOC_HOST_PTR),
                       img_format,
                       cl_image2d_width,
                       cl_image2d_height,
@@ -89,7 +90,8 @@ void *TargetWrapperCL::MallocImage<uint16_t>(const size_t cl_image2d_width,
   cl_int status;
   cl::Image2D *cl_image =
       new cl::Image2D(CLRuntime::Global()->context(),
-                      CL_MEM_READ_WRITE | (host_ptr ? CL_MEM_COPY_HOST_PTR : 0),
+                      CL_MEM_READ_WRITE | (host_ptr ? CL_MEM_COPY_HOST_PTR
+                                                    : CL_MEM_ALLOC_HOST_PTR),
                       img_format,
                       cl_image2d_width,
                       cl_image2d_height,
@@ -112,7 +114,8 @@ void *TargetWrapperCL::MallocImage<int32_t>(const size_t cl_image2d_width,
   cl_int status;
   cl::Image2D *cl_image =
       new cl::Image2D(CLRuntime::Global()->context(),
-                      CL_MEM_READ_WRITE | (host_ptr ? CL_MEM_COPY_HOST_PTR : 0),
+                      CL_MEM_READ_WRITE | (host_ptr ? CL_MEM_COPY_HOST_PTR
+                                                    : CL_MEM_ALLOC_HOST_PTR),
                       img_format,
                       cl_image2d_width,
                       cl_image2d_height,
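The flag change substitutes `CL_MEM_ALLOC_HOST_PTR` for the old plain device allocation (`0`) when no host pointer is supplied. In standard OpenCL semantics, `CL_MEM_COPY_HOST_PTR` initializes the image from host memory at creation time, while `CL_MEM_ALLOC_HOST_PTR` asks the driver for host-accessible (often pinned) backing storage, which can make later map/read operations cheaper on mobile GPUs. The flag expression in isolation (`host_ptr` is the function's argument):

cl_mem_flags flags =
    CL_MEM_READ_WRITE |
    (host_ptr ? CL_MEM_COPY_HOST_PTR     // copy caller's data into the image now
              : CL_MEM_ALLOC_HOST_PTR);  // driver-allocated, host-visible storage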
@@ -192,7 +195,6 @@ void TargetWrapperCL::MemcpySync(void *dst,
                                  size_t size,
                                  IoDirection dir) {
   cl_int status;
-  cl::Event event;
   auto stream = CLRuntime::Global()->command_queue();
   switch (dir) {
     case IoDirection::DtoD:
@@ -202,9 +204,9 @@ void TargetWrapperCL::MemcpySync(void *dst,
                                         0,
                                         size,
                                         nullptr,
-                                        &event);
+                                        nullptr);
       CL_CHECK_FATAL(status);
-      event.wait();
+      CLRuntime::Global()->command_queue().finish();
       break;
     case IoDirection::HtoD:
       status = stream.enqueueWriteBuffer(*static_cast<cl::Buffer *>(dst),
@@ -283,7 +285,6 @@ void TargetWrapperCL::ImgcpySync(void *dst,
   cl::array<size_t, 3> origin = {0, 0, 0};
   cl::array<size_t, 3> region = {cl_image2d_width, cl_image2d_height, 1};
   cl_int status;
-  cl::Event event;
   auto stream = CLRuntime::Global()->command_queue();
   switch (dir) {
     case IoDirection::DtoD:
@@ -293,9 +294,9 @@ void TargetWrapperCL::ImgcpySync(void *dst,
                                        origin,
                                        region,
                                        nullptr,
-                                       &event);
+                                       nullptr);
       CL_CHECK_FATAL(status);
-      event.wait();
+      CLRuntime::Global()->command_queue().finish();
       break;
     case IoDirection::HtoD:
       status = stream.enqueueWriteImage(*static_cast<cl::Image2D *>(dst),
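`event.wait()` blocks on one specific command, while `finish()` drains everything enqueued so far; the patch trades the finer-grained wait for the simpler whole-queue barrier. A sketch of the two styles side by side (standard OpenCL C++ API, not taken from the patch):

#include <CL/cl2.hpp>

void CopyAndSync(cl::CommandQueue& queue, cl::Buffer& src, cl::Buffer& dst,
                 size_t size) {
  cl::Event event;
  queue.enqueueCopyBuffer(src, dst, 0, 0, size, nullptr, &event);
  event.wait();    // old style: blocks on this single copy only

  queue.enqueueCopyBuffer(src, dst, 0, 0, size, nullptr, nullptr);
  queue.finish();  // new style, as in the patch: waits for the whole queue
}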
...
@@ -340,27 +340,17 @@ class Context<TargetType::kX86> {
 template <>
 class Context<TargetType::kOpenCL> {
   std::shared_ptr<CLContext> cl_context_;
-  using WaitListType =
-      std::unordered_map<decltype(static_cast<const void*>(nullptr)),
-                         std::shared_ptr<cl::Event>>;
-  std::shared_ptr<WaitListType> cl_wait_list_;

  public:
   CLContext* cl_context() { return cl_context_.get(); }
-  WaitListType* cl_wait_list() { return cl_wait_list_.get(); }

   void InitOnce() {
     // Init cl runtime.
     CHECK(CLRuntime::Global()->IsInitSuccess()) << "OpenCL runtime init failed";
     cl_context_ = std::make_shared<CLContext>();
-    cl_wait_list_ = std::make_shared<WaitListType>();
   }

-  void CopySharedTo(OpenCLContext* ctx) {
-    ctx->cl_context_ = cl_context_;
-    ctx->cl_wait_list_ = cl_wait_list_;
-  }
+  void CopySharedTo(OpenCLContext* ctx) { ctx->cl_context_ = cl_context_; }
 };
 #endif
...
@@ -25,16 +25,16 @@ namespace lite {
 bool OpLite::InferShape() {
   // if input_tensor_ptrs and output_tensor_ptrs are overloaded in param_
   // InferShapeByMemoryInternal will be applied.
-  if (param_.input_tensor_ptrs() && param_.output_tensor_ptrs()) {
+  if (op_param_ && op_param_->input_tensor_ptrs() &&
+      op_param_->output_tensor_ptrs()) {
     return this->InferShapeWithCache();
   } else {
     // otherwise, InferShapeImpl is applied directly.
     return this->InferShapeImpl();
   }
 }

 bool OpLite::InferShapeWithCache() {
   // 1. Get vector of current input tensors
-  auto *current_inputs = param_.input_tensor_ptrs();
+  auto *current_inputs = op_param_->input_tensor_ptrs();
   // 2. Get hash value of current inputs shape and lod
   size_t new_hash = 0;
   for (auto iter = current_inputs->begin(); iter != current_inputs->end();
@@ -59,7 +59,7 @@ bool OpLite::InferShapeWithCache() {
   if (new_hash == io_shape_lod_hash_ && new_hash != 0) {
     // if current hash value is consistent with io_shape_lod_hash_,
     // previous outputs shape and lod are reused.
-    auto *current_outputs = param_.output_tensor_ptrs();
+    auto *current_outputs = op_param_->output_tensor_ptrs();
     for (size_t i = 0; i < current_outputs->size(); i++) {
       current_outputs->at(i)->Resize(last_output_shapes[i]);
       current_outputs->at(i)->set_lod(last_output_lods[i]);
@@ -68,10 +68,12 @@ bool OpLite::InferShapeWithCache() {
     // otherwise, current hash value is changed, InferShapeImpl will apply.
     io_shape_lod_hash_ = new_hash;
     this->InferShapeImpl();
-    auto *current_outputs = param_.output_tensor_ptrs();
+    auto *current_outputs = op_param_->output_tensor_ptrs();
+    last_output_shapes.clear();
+    last_output_lods.clear();
     for (size_t i = 0; i < current_outputs->size(); i++) {
-      last_output_shapes[i] = current_outputs->at(i)->dims();
-      last_output_lods[i] = current_outputs->at(i)->lod();
+      last_output_shapes.push_back(current_outputs->at(i)->dims());
+      last_output_lods.push_back(current_outputs->at(i)->lod());
    }
  }
  return true;
...
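Two things change here: the cache goes through the new `op_param_` pointer, and the refresh path stops writing `last_output_shapes[i]` into a vector that may still be empty (undefined behavior), rebuilding it with `clear()` + `push_back()` instead. A simplified sketch of the caching scheme, with an assumed hash combiner rather than the real one:

#include <cstdint>
#include <functional>
#include <vector>

using Dims = std::vector<int64_t>;

// Assumed combiner for illustration; the real hash mixes shapes and lod.
size_t HashDims(const std::vector<Dims>& inputs) {
  size_t h = 0;
  for (const auto& d : inputs)
    for (int64_t v : d) h = h * 131 + std::hash<int64_t>()(v);
  return h;
}

struct ShapeCache {
  size_t hash = 0;
  std::vector<Dims> outputs;

  bool Hit(const std::vector<Dims>& inputs) const {
    size_t h = HashDims(inputs);
    return h != 0 && h == hash;  // unchanged inputs -> reuse cached outputs
  }
  void Store(const std::vector<Dims>& inputs, const std::vector<Dims>& outs) {
    hash = HashDims(inputs);
    outputs.clear();                                  // start from empty,
    for (const auto& o : outs) outputs.push_back(o);  // then append, never index
  }
};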
@@ -77,6 +77,11 @@ class OpLite : public Registry {
   // Link the external execution environ to internal context.
   bool Attach(const cpp::OpDesc &opdesc, lite::Scope *scope);

+  template <typename T>
+  inline void AttachParam(T *param) {
+    op_param_ = static_cast<T *>(param);
+  }
+
   const OpInfo *op_info() const { return op_info_.get(); }
   OpInfo *mutable_op_info() { return op_info_.get(); }
@@ -167,11 +172,10 @@ class OpLite : public Registry {
   std::vector<Place> valid_places_;
   Place kernel_place_{TARGET(kHost), PRECISION(kFloat)};
   std::unique_ptr<OpInfo> op_info_;
   std::vector<DDimLite> last_output_shapes{};
   std::vector<std::vector<std::vector<uint64_t>>> last_output_lods{};
   size_t io_shape_lod_hash_{};
-  mutable operators::ParamBase param_;
+  mutable operators::ParamBase *op_param_{nullptr};

  private:
   // Infer Shape according to memory, if current input shapes are consistent
...
@@ -56,7 +56,6 @@ add_kernel(negative_compute_arm ARM extra SRCS negative_compute.cc DEPS ${lite_k
 add_kernel(crop_compute_arm ARM extra SRCS crop_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(power_compute_arm ARM extra SRCS power_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(norm_compute_arm ARM extra SRCS norm_compute.cc DEPS ${lite_kernel_deps} math_arm)
-add_kernel(assign_compute_arm ARM extra SRCS assign_compute.cc DEPS ${lite_kernel_deps} math_arm)

 ## 3. extra kernels
 add_kernel(lrn_compute_arm ARM extra SRCS lrn_compute.cc DEPS ${lite_kernel_deps} math_arm)
@@ -92,8 +91,6 @@ add_kernel(sequence_softmax_compute_arm ARM extra SRCS sequence_softmax_compute.
 add_kernel(while_compute_arm ARM extra SRCS while_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(topk_compute_arm ARM extra SRCS topk_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(increment_compute_arm ARM extra SRCS increment_compute.cc DEPS ${lite_kernel_deps} math_arm)
-add_kernel(write_to_array_compute_arm ARM extra SRCS write_to_array_compute.cc DEPS ${lite_kernel_deps} math_arm)
-add_kernel(read_from_array_compute_arm ARM extra SRCS read_from_array_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(beam_search_compute_arm ARM extra SRCS beam_search_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(fill_constant_compute_arm ARM basic SRCS fill_constant_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(fill_constant_batch_size_like_compute_arm ARM basic SRCS fill_constant_batch_size_like_compute.cc DEPS ${lite_kernel_deps} math_arm)
...
@@ -114,14 +114,14 @@ struct BeamSearchDecoder {
     lod.push_back(source_level_lod);
     lod.push_back(sentence_level_lod);

-    *(id_tensor->mutable_lod()) = lod;
+    id_tensor->set_lod(lod);
     id_tensor->Resize({static_cast<int64_t>(id_data.size())});
     auto id_ptr = id_tensor->mutable_data<int64_t>();
     TargetCopy(
         TARGET(kARM), id_ptr, id_data.data(), id_data.size() * sizeof(int64_t));

-    *(score_tensor->mutable_lod()) = lod;
+    score_tensor->set_lod(lod);
     score_tensor->Resize({static_cast<int64_t>(score_data.size())});
     auto score_ptr = score_tensor->mutable_data<T>();
     TargetCopy(TARGET(kARM),
...
@@ -109,6 +109,8 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
   int pw = paddings[2];
   int sh = param.strides[1];
   int sw = param.strides[0];
+  int hin = param.x->dims()[2];
+  int win = param.x->dims()[3];
   bool pads_all_equal = (pads_equal && paddings[0] == paddings[2]);
   bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh);
@@ -116,13 +118,12 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
   bool flag_dw_3x3 = (kw == 3 && kh == 3 && (sw == 1 || sw == 2));
   bool flag_dw_5x5 = pads_all_equal && (kw == 5 && (sw == 1 || sw == 2));
   bool flag_dw = flag_dw_3x3 || flag_dw_5x5;
   if (param.groups == ic && ic == oc && kps_equal && pads_equal &&
       no_dilation && flag_dw) {
     impl_ = new DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>;
     // VLOG(3) << "Run DepthwiseConv Int8";
   } else if (param.groups == 1 && kw == 3 && (sw == 1 || sw == 2) &&
-             kps_equal && no_dilation) {
+             ic * oc < 4 * hin * win && kps_equal && no_dilation) {
     impl_ = new DirectConv<PRECISION(kInt8), PRECISION(kFloat)>;
     // VLOG(3) << "Run DirectConv Int8";
   } else {
@@ -154,6 +155,8 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
   int pw = paddings[2];
   int sh = param.strides[1];
   int sw = param.strides[0];
+  int hin = param.x->dims()[2];
+  int win = param.x->dims()[3];
   bool pads_all_equal = (pads_equal && paddings[0] == paddings[2]);
   bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh);
@@ -167,7 +170,7 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
     impl_ = new DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>;
     // VLOG(3) << "Run DepthwiseConv Int8";
   } else if (param.groups == 1 && kw == 3 && (sw == 1 || sw == 2) &&
-             kps_equal && no_dilation) {
+             ic * oc < 4 * hin * win && kps_equal && no_dilation) {
     impl_ = new DirectConv<PRECISION(kInt8), PRECISION(kInt8)>;
     // VLOG(3) << "Run DirectConv Int8";
   } else {
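The new `ic * oc < 4 * hin * win` term gates DirectConv on the ratio of channel volume to input area, steering large-channel/small-spatial layers to the gemm-based fallback. A worked example (illustrative numbers only):

bool prefer_direct(int ic, int oc, int hin, int win) {
  return ic * oc < 4 * hin * win;  // few channels vs. large feature map -> direct
}
// prefer_direct(32, 64, 56, 56) : 2048   < 12544 -> true  (DirectConv)
// prefer_direct(512, 512, 7, 7) : 262144 < 196   -> false (gemm-like fallback)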
...
@@ -10,3 +10,6 @@ add_kernel(crf_decoding_compute_host Host extra SRCS crf_decoding_compute.cc DEP
 add_kernel(compare_compute_host Host extra SRCS compare_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(logical_compute_host Host extra SRCS logical_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(ctc_align_compute_host Host extra SRCS ctc_align_compute.cc DEPS ${lite_kernel_deps})
+add_kernel(write_to_array_compute_host Host extra SRCS write_to_array_compute.cc DEPS ${lite_kernel_deps})
+add_kernel(read_from_array_compute_host Host extra SRCS read_from_array_compute.cc DEPS ${lite_kernel_deps})
+add_kernel(assign_compute_host Host extra SRCS assign_compute.cc DEPS ${lite_kernel_deps})
@@ -12,29 +12,42 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "lite/kernels/arm/assign_compute.h"
-#include <vector>
-#include "lite/backends/arm/math/funcs.h"
-#include "lite/core/op_registry.h"
-#include "lite/core/type_system.h"
+#include "lite/kernels/host/assign_compute.h"

 namespace paddle {
 namespace lite {
 namespace kernels {
-namespace arm {
+namespace host {

 void AssignCompute::Run() {
   auto& param = Param<param_t>();
-  param.Out->CopyDataFrom(*param.X);
+  if (param.X != nullptr) {
+    param.Out->CopyDataFrom(*param.X);
+  } else if (param.X_array != nullptr) {
+    auto x_array = param.X_array;
+    auto out_array = param.Out_array;
+    out_array->resize(x_array->size());
+    for (size_t i = 0; i < x_array->size(); i++) {
+      out_array->at(i).CopyDataFrom(x_array->at(i));
+    }
+  } else {
+    LOG(FATAL) << "x or x_array of assign must be set.";
+  }
 }

-}  // namespace arm
+}  // namespace host
 }  // namespace kernels
 }  // namespace lite
 }  // namespace paddle

 REGISTER_LITE_KERNEL(
-    assign, kARM, kAny, kNCHW, paddle::lite::kernels::arm::AssignCompute, def)
-    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
-    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
+    assign, kHost, kAny, kAny, paddle::lite::kernels::host::AssignCompute, def)
+    .BindInput("X",
+               {LiteType::GetTensorTy(TARGET(kHost),
+                                      PRECISION(kAny),
+                                      DATALAYOUT(kAny))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kHost),
+                                       PRECISION(kAny),
+                                       DATALAYOUT(kAny))})
     .Finalize();
@@ -15,14 +15,15 @@
 #pragma once
 #include <algorithm>
 #include "lite/core/kernel.h"
-#include "lite/operators/assign_op.h"
+#include "lite/core/op_registry.h"

 namespace paddle {
 namespace lite {
 namespace kernels {
-namespace arm {
+namespace host {

-class AssignCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
+class AssignCompute
+    : public KernelLite<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
  public:
   using param_t = operators::AssignParam;

@@ -31,7 +32,7 @@ class AssignCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
   virtual ~AssignCompute() = default;
 };

-}  // namespace arm
+}  // namespace host
 }  // namespace kernels
 }  // namespace lite
 }  // namespace paddle
@@ -12,17 +12,15 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "lite/kernels/arm/read_from_array_compute.h"
-#include "lite/backends/arm/math/funcs.h"
+#include "lite/kernels/host/read_from_array_compute.h"

 namespace paddle {
 namespace lite {
 namespace kernels {
-namespace arm {
+namespace host {

 void ReadFromArrayCompute::Run() {
-  auto& ctx = this->ctx_->template As<ARMContext>();
-  auto& param = this->Param<param_t>();
+  auto& param = this->Param<operators::ReadFromArrayParam>();

   CHECK_EQ(param.I->numel(), 1) << "I should have only one element";
   int id = param.I->data<int64_t>()[0];
@@ -33,18 +31,27 @@ void ReadFromArrayCompute::Run() {
   param.Out->CopyDataFrom((*param.X)[id]);
 }

-}  // namespace arm
+}  // namespace host
 }  // namespace kernels
 }  // namespace lite
 }  // namespace paddle

 REGISTER_LITE_KERNEL(read_from_array,
-                     kARM,
-                     kAny,
-                     kNCHW,
-                     paddle::lite::kernels::arm::ReadFromArrayCompute,
+                     kHost,
+                     kAny,
+                     kAny,
+                     paddle::lite::kernels::host::ReadFromArrayCompute,
                      def)
-    .BindInput("X", {LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kAny))})
-    .BindInput("I", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
-    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
+    .BindInput("X",
+               {LiteType::GetTensorListTy(TARGET(kHost),
+                                          PRECISION(kAny),
+                                          DATALAYOUT(kAny))})
+    .BindInput("I",
+               {LiteType::GetTensorTy(TARGET(kARM),
+                                      PRECISION(kInt64),
+                                      DATALAYOUT(kAny))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kHost),
+                                       PRECISION(kAny),
+                                       DATALAYOUT(kAny))})
     .Finalize();
@@ -13,20 +13,17 @@
 // limitations under the License.

 #pragma once
-#include <stdint.h>
-#include "lite/backends/arm/math/type_trans.h"
 #include "lite/core/kernel.h"
 #include "lite/core/op_registry.h"

 namespace paddle {
 namespace lite {
 namespace kernels {
-namespace arm {
+namespace host {

-class ReadFromArrayCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
+class ReadFromArrayCompute
+    : public KernelLite<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
  public:
-  using param_t = operators::ReadFromArrayParam;
-
   void Run() override;

   ~ReadFromArrayCompute() {}
@@ -34,7 +31,7 @@ class ReadFromArrayCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
  private:
 };

-}  // namespace arm
+}  // namespace host
 }  // namespace kernels
 }  // namespace lite
 }  // namespace paddle
@@ -12,16 +12,14 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "lite/kernels/arm/write_to_array_compute.h"
-#include "lite/backends/arm/math/funcs.h"
+#include "lite/kernels/host/write_to_array_compute.h"

 namespace paddle {
 namespace lite {
 namespace kernels {
-namespace arm {
+namespace host {

 void WriteToArrayCompute::Run() {
-  auto& ctx = this->ctx_->template As<ARMContext>();
   auto& param = this->template Param<operators::WriteToArrayParam>();

   CHECK_EQ(param.I->numel(), 1) << "input2 should have only one element";
@@ -32,19 +30,27 @@ void WriteToArrayCompute::Run() {
   param.Out->at(id).CopyDataFrom(*param.X);
 }

-}  // namespace arm
+}  // namespace host
 }  // namespace kernels
 }  // namespace lite
 }  // namespace paddle

 REGISTER_LITE_KERNEL(write_to_array,
-                     kARM,
-                     kAny,
-                     kNCHW,
-                     paddle::lite::kernels::arm::WriteToArrayCompute,
+                     kHost,
+                     kAny,
+                     kAny,
+                     paddle::lite::kernels::host::WriteToArrayCompute,
                      def)
-    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
-    .BindInput("I", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
+    .BindInput("X",
+               {LiteType::GetTensorTy(TARGET(kHost),
+                                      PRECISION(kAny),
+                                      DATALAYOUT(kAny))})
+    .BindInput("I",
+               {LiteType::GetTensorTy(TARGET(kHost),
+                                      PRECISION(kInt64),
+                                      DATALAYOUT(kAny))})
     .BindOutput("Out",
-                {LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kAny))})
+                {LiteType::GetTensorListTy(TARGET(kHost),
+                                           PRECISION(kAny),
+                                           DATALAYOUT(kAny))})
     .Finalize();
@@ -13,17 +13,16 @@
 // limitations under the License.

 #pragma once
-#include <stdint.h>
-#include "lite/backends/arm/math/type_trans.h"
 #include "lite/core/kernel.h"
 #include "lite/core/op_registry.h"

 namespace paddle {
 namespace lite {
 namespace kernels {
-namespace arm {
+namespace host {

-class WriteToArrayCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
+class WriteToArrayCompute
+    : public KernelLite<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
  public:
   void Run() override;

@@ -32,7 +31,7 @@ class WriteToArrayCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
  private:
 };

-}  // namespace arm
+}  // namespace host
 }  // namespace kernels
 }  // namespace lite
 }  // namespace paddle
...@@ -62,23 +62,21 @@ class ReluCompute ...@@ -62,23 +62,21 @@ class ReluCompute
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
auto global_work_size = cl::NDRange{count}; auto global_work_size = cl::NDRange{count};
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, kernel,
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
nullptr, nullptr,
event_.get()); nullptr);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_buf, event_);
} }
private: private:
std::string kernel_func_name_{"relu"}; std::string kernel_func_name_{"relu"};
std::string build_options_{"-DCL_DTYPE_float -DRELU"}; std::string build_options_{"-DCL_DTYPE_float -DRELU"};
std::string time_stamp_{GetTimeStamp()}; std::string time_stamp_{GetTimeStamp()};
std::shared_ptr<cl::Event> event_{nullptr};
}; };
class SigmoidCompute class SigmoidCompute
...@@ -121,23 +119,21 @@ class SigmoidCompute ...@@ -121,23 +119,21 @@ class SigmoidCompute
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
auto global_work_size = cl::NDRange{count}; auto global_work_size = cl::NDRange{count};
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, kernel,
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
nullptr, nullptr,
event_.get()); nullptr);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_buf, event_);
} }
private: private:
std::string kernel_func_name_{"sigmoid"}; std::string kernel_func_name_{"sigmoid"};
std::string build_options_{"-DCL_DTYPE_float -DSIGMOID"}; std::string build_options_{"-DCL_DTYPE_float -DSIGMOID"};
std::string time_stamp_{GetTimeStamp()}; std::string time_stamp_{GetTimeStamp()};
std::shared_ptr<cl::Event> event_{nullptr};
}; };
} // namespace opencl } // namespace opencl
......
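The recurring change in this patch, isolated as a minimal sketch: the per-launch cl::Event that used to be recorded into cl_wait_list is dropped, and synchronization moves to a single queue-wide finish(). The stand-ins queue, kernel, and gws mirror the locals in the hunks above.

// Old scheme: record an event per launch; consumers wait on that event.
void launch_old(cl::CommandQueue& queue, cl::Kernel& kernel, cl::NDRange gws) {
  cl::Event event;  // previously stored in cl_wait_list, keyed by the output
  queue.enqueueNDRangeKernel(kernel, cl::NullRange, gws, cl::NullRange,
                             nullptr, &event);
  event.wait();  // blocks on exactly this launch
}

// New scheme: no event; one coarse barrier drains the whole queue.
void launch_new(cl::CommandQueue& queue, cl::Kernel& kernel, cl::NDRange gws) {
  queue.enqueueNDRangeKernel(kernel, cl::NullRange, gws, cl::NullRange,
                             nullptr, nullptr);
  queue.finish();  // waits for every enqueued command, not just this one
}

The trade-off: finish() is coarser, but it removes the wait-list bookkeeping and the "Could not find the sync event" failure mode the deleted LOG(FATAL) branches guarded against.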
...@@ -85,16 +85,9 @@ TEST(opencl_relu_buffer, compute) { ...@@ -85,16 +85,9 @@ TEST(opencl_relu_buffer, compute) {
kernel->Launch(); kernel->Launch();
auto *wait_list = context->As<OpenCLContext>().cl_wait_list();
auto *out_ptr = param.Out->data<float, cl::Buffer>(); auto *out_ptr = param.Out->data<float, cl::Buffer>();
auto it = wait_list->find(out_ptr);
if (it != wait_list->end()) { CLRuntime::Global()->command_queue().finish();
VLOG(4) << "--- Find the sync event for the target cl tensor. ---";
auto &event = *(it->second);
event.wait();
} else {
LOG(FATAL) << "Could not find the sync event for the target cl tensor.";
}
// run compute ref and check // run compute ref and check
std::unique_ptr<float[]> out_ref(new float[x_dim.production()]); std::unique_ptr<float[]> out_ref(new float[x_dim.production()]);
...@@ -145,16 +138,9 @@ TEST(opencl_sigmoid_buffer, compute) { ...@@ -145,16 +138,9 @@ TEST(opencl_sigmoid_buffer, compute) {
kernel->Launch(); kernel->Launch();
auto *wait_list = context->As<OpenCLContext>().cl_wait_list();
auto *out_ptr = param.Out->data<float, cl::Buffer>(); auto *out_ptr = param.Out->data<float, cl::Buffer>();
auto it = wait_list->find(out_ptr);
if (it != wait_list->end()) { CLRuntime::Global()->command_queue().finish();
VLOG(4) << "--- Find the sync event for the target cl tensor. ---";
auto &event = *(it->second);
event.wait();
} else {
LOG(FATAL) << "Could not find the sync event for the target cl tensor.";
}
// run compute ref and check // run compute ref and check
std::unique_ptr<float[]> out_ref(new float[x_dim.production()]); std::unique_ptr<float[]> out_ref(new float[x_dim.production()]);
......
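The same simplification on the test side, as a sketch of the pattern the hunks above converge on:

kernel->Launch();                               // enqueue the OpenCL work
CLRuntime::Global()->command_queue().finish();  // one queue-wide sync point
// ...then compute the CPU reference and compare, as in the tests above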
...@@ -147,16 +147,15 @@ class ActivationComputeImageDefault ...@@ -147,16 +147,15 @@ class ActivationComputeImageDefault
auto& context = ctx_->As<OpenCLContext>(); auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr); CHECK(context.cl_context() != nullptr);
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, kernel,
cl::NullRange, cl::NullRange,
global_work_size_, global_work_size_,
cl::NullRange, cl::NullRange,
nullptr, nullptr,
event_.get()); nullptr);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_img, event_);
} }
private: private:
...@@ -175,7 +174,6 @@ class ActivationComputeImageDefault ...@@ -175,7 +174,6 @@ class ActivationComputeImageDefault
static_cast<size_t>(1), static_cast<size_t>(1), static_cast<size_t>(1)}; static_cast<size_t>(1), static_cast<size_t>(1), static_cast<size_t>(1)};
std::string build_options_{"-DCL_DTYPE_half"}; std::string build_options_{"-DCL_DTYPE_half"};
std::string time_stamp_{GetTimeStamp()}; std::string time_stamp_{GetTimeStamp()};
std::shared_ptr<cl::Event> event_{nullptr};
}; };
} // namespace opencl } // namespace opencl
} // namespace kernels } // namespace kernels
......
...@@ -234,19 +234,9 @@ TEST(act_image2d_fp16, compute) { ...@@ -234,19 +234,9 @@ TEST(act_image2d_fp16, compute) {
img_to_buf_kernel->Launch(); img_to_buf_kernel->Launch();
// wait for opencl // wait for opencl
auto *wait_list = context->As<OpenCLContext>().cl_wait_list();
auto *out_ptr = ImageToBufferParam.y->data<float, cl::Buffer>(); auto *out_ptr = ImageToBufferParam.y->data<float, cl::Buffer>();
auto it = wait_list->find(out_ptr);
if (it != wait_list->end()) { CLRuntime::Global()->command_queue().finish();
VLOG(4) << "--- Find the sync event for the target cl "
"tensor. ---";
auto &event = *(it->second);
event.wait();
} else {
LOG(FATAL) << "Could not find the sync event for the target "
"cl tensor.";
}
// compute ref cpu // compute ref cpu
act_compute_ref<float>( act_compute_ref<float>(
......
...@@ -142,16 +142,14 @@ class BilinearInterpImageCompute ...@@ -142,16 +142,14 @@ class BilinearInterpImageCompute
static_cast<cl::size_type>(default_work_size[1]), static_cast<cl::size_type>(default_work_size[1]),
static_cast<cl::size_type>(default_work_size[2])}; static_cast<cl::size_type>(default_work_size[2])};
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, kernel,
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
nullptr, nullptr,
event_.get()); nullptr);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_img, event_);
#ifndef LITE_SHUTDOWN_LOG #ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "global_work_size:[2D]:" << global_work_size[0] << " " VLOG(4) << "global_work_size:[2D]:" << global_work_size[0] << " "
<< global_work_size[1] << " " << global_work_size[2]; << global_work_size[1] << " " << global_work_size[2];
...@@ -163,7 +161,6 @@ class BilinearInterpImageCompute ...@@ -163,7 +161,6 @@ class BilinearInterpImageCompute
std::string kernel_func_name_{"bilinear_interp"}; std::string kernel_func_name_{"bilinear_interp"};
std::string build_options_{"-DCL_DTYPE_half"}; std::string build_options_{"-DCL_DTYPE_half"};
std::string time_stamp_{GetTimeStamp()}; std::string time_stamp_{GetTimeStamp()};
std::shared_ptr<cl::Event> event_{nullptr};
}; };
} // namespace opencl } // namespace opencl
......
...@@ -187,18 +187,7 @@ TEST(bilinear_interp_image2d, compute) { ...@@ -187,18 +187,7 @@ TEST(bilinear_interp_image2d, compute) {
// LOG(INFO) << "out_image:" << out_image; // LOG(INFO) << "out_image:" << out_image;
kernel->Launch(); kernel->Launch();
auto* wait_list = context->As<OpenCLContext>().cl_wait_list(); CLRuntime::Global()->command_queue().finish();
auto* out_ptr = param.Out->data<half_t, cl::Image2D>();
auto it = wait_list->find(out_ptr);
if (it != wait_list->end()) {
VLOG(4) << "--- Find the sync event for the target cl "
"tensor. ---";
auto& event = *(it->second);
event.wait();
} else {
LOG(FATAL) << "Could not find the sync event for the "
"target cl tensor.";
}
std::unique_ptr<float[]> out_ref( std::unique_ptr<float[]> out_ref(
new float[out_dim.production()]); new float[out_dim.production()]);
......
...@@ -47,8 +47,10 @@ class BoxCoderComputeImage : public KernelLite<TARGET(kOpenCL), ...@@ -47,8 +47,10 @@ class BoxCoderComputeImage : public KernelLite<TARGET(kOpenCL),
} }
CHECK(context.cl_context() != nullptr); CHECK(context.cl_context() != nullptr);
VLOG(1) << "kernel_func_name_:" << kernel_func_name_; VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
context.cl_context()->AddKernel( context.cl_context()->AddKernel(kernel_func_name_,
kernel_func_name_, "image/box_coder_kernel.cl", build_options_); "image/box_coder_kernel.cl",
build_options_,
time_stamp_);
} }
void Run() override { void Run() override {
...@@ -81,7 +83,7 @@ class BoxCoderComputeImage : public KernelLite<TARGET(kOpenCL), ...@@ -81,7 +83,7 @@ class BoxCoderComputeImage : public KernelLite<TARGET(kOpenCL),
auto& context = ctx_->As<OpenCLContext>(); auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr); CHECK(context.cl_context() != nullptr);
STL::stringstream kernel_key; STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_; kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
auto kernel = context.cl_context()->GetKernel(kernel_key.str()); auto kernel = context.cl_context()->GetKernel(kernel_key.str());
auto default_work_size = auto default_work_size =
...@@ -120,16 +122,14 @@ class BoxCoderComputeImage : public KernelLite<TARGET(kOpenCL), ...@@ -120,16 +122,14 @@ class BoxCoderComputeImage : public KernelLite<TARGET(kOpenCL),
cl::NDRange{static_cast<cl::size_type>(default_work_size[0]), cl::NDRange{static_cast<cl::size_type>(default_work_size[0]),
static_cast<cl::size_type>(default_work_size[2])}; static_cast<cl::size_type>(default_work_size[2])};
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, kernel,
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
nullptr, nullptr,
event_.get()); nullptr);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_buf, event_);
#ifndef LITE_SHUTDOWN_LOG #ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "global_work_size:[2D]:" << global_work_size[0] << " " VLOG(4) << "global_work_size:[2D]:" << global_work_size[0] << " "
...@@ -142,7 +142,7 @@ class BoxCoderComputeImage : public KernelLite<TARGET(kOpenCL), ...@@ -142,7 +142,7 @@ class BoxCoderComputeImage : public KernelLite<TARGET(kOpenCL),
param_t* boxcoder_param_{nullptr}; param_t* boxcoder_param_{nullptr};
std::string kernel_func_name_{}; std::string kernel_func_name_{};
std::string build_options_{" -DCL_DTYPE_half"}; std::string build_options_{" -DCL_DTYPE_half"};
std::shared_ptr<cl::Event> event_{nullptr}; std::string time_stamp_{GetTimeStamp()};
}; };
} // namespace opencl } // namespace opencl
......
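The other systematic change is threading time_stamp_ through both AddKernel and the cache lookup. A minimal sketch of the key convention, with an illustrative value:

STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
// e.g. "box_coder" + " -DCL_DTYPE_half" + "1587958745" (illustrative value)
auto kernel = context.cl_context()->GetKernel(kernel_key.str());

Since time_stamp_{GetTimeStamp()} is initialized once per kernel object, two live instances that share a function name and build options no longer collide on one cache entry.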
...@@ -216,18 +216,7 @@ TEST(box_coder_image2d, compute) { ...@@ -216,18 +216,7 @@ TEST(box_coder_image2d, compute) {
out_image_shape[0], out_image_shape[1]); out_image_shape[0], out_image_shape[1]);
kernel->Launch(); kernel->Launch();
auto* wait_list = context->As<OpenCLContext>().cl_wait_list(); CLRuntime::Global()->command_queue().finish();
auto* out_ptr = param.proposals->data<half_t, cl::Image2D>();
auto it = wait_list->find(out_ptr);
if (it != wait_list->end()) {
VLOG(4) << "--- Find the sync event for the target cl "
"tensor. ---";
auto& event = *(it->second);
event.wait();
} else {
LOG(FATAL) << "Could not find the sync event for the "
"target cl tensor.";
}
lite::Tensor out_ref_tensor; lite::Tensor out_ref_tensor;
out_ref_tensor.Resize(out_dim); out_ref_tensor.Resize(out_dim);
......
...@@ -123,16 +123,15 @@ class ConcatCompute : public KernelLite<TARGET(kOpenCL), ...@@ -123,16 +123,15 @@ class ConcatCompute : public KernelLite<TARGET(kOpenCL),
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, total1); status = kernel.setArg(++arg_idx, total1);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, kernel,
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
nullptr, nullptr,
event_.get()); nullptr);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_buf, event_);
} else { } else {
auto start = 0; auto start = 0;
for (int i = 0; i < inputs.size(); i++) { for (int i = 0; i < inputs.size(); i++) {
...@@ -157,16 +156,15 @@ class ConcatCompute : public KernelLite<TARGET(kOpenCL), ...@@ -157,16 +156,15 @@ class ConcatCompute : public KernelLite<TARGET(kOpenCL),
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, total0); status = kernel.setArg(++arg_idx, total0);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, kernel,
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
nullptr, nullptr,
event_.get()); nullptr);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_buf, event_);
start += size; start += size;
} }
} }
...@@ -182,7 +180,6 @@ class ConcatCompute : public KernelLite<TARGET(kOpenCL), ...@@ -182,7 +180,6 @@ class ConcatCompute : public KernelLite<TARGET(kOpenCL),
std::string kernel_func_name_{}; std::string kernel_func_name_{};
std::string build_options_{"-DCL_DTYPE_float"}; std::string build_options_{"-DCL_DTYPE_float"};
std::string time_stamp_{GetTimeStamp()}; std::string time_stamp_{GetTimeStamp()};
std::shared_ptr<cl::Event> event_{nullptr};
}; };
} // namespace opencl } // namespace opencl
......
...@@ -142,16 +142,7 @@ TEST(opencl_concat_buffer, compute) { ...@@ -142,16 +142,7 @@ TEST(opencl_concat_buffer, compute) {
kernel->SetContext(std::move(concat_context)); kernel->SetContext(std::move(concat_context));
kernel->Launch(); kernel->Launch();
auto *wait_list = context->As<OpenCLContext>().cl_wait_list(); CLRuntime::Global()->command_queue().finish();
auto *out_ptr = param.output->data<float, cl::Buffer>();
auto it = wait_list->find(out_ptr);
if (it != wait_list->end()) {
VLOG(4) << "--- Find the sync event for the target cl tensor. ---";
auto &event = *(it->second);
event.wait();
} else {
LOG(FATAL) << "Could not find the sync event for the target cl tensor.";
}
// run compute ref and check // run compute ref and check
auto *out_ref_data = out_ref.mutable_data<float>(TARGET(kARM)); auto *out_ref_data = out_ref.mutable_data<float>(TARGET(kARM));
......
...@@ -187,16 +187,15 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL), ...@@ -187,16 +187,15 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, width_); status = kernel.setArg(++arg_idx, width_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, kernel,
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
nullptr, nullptr,
event_.get()); nullptr);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_buf, event_);
} else { } else {
auto start = 0; auto start = 0;
for (int i = 0; i < inputs.size(); i++) { for (int i = 0; i < inputs.size(); i++) {
...@@ -231,16 +230,15 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL), ...@@ -231,16 +230,15 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
status = kernel.setArg(++arg_idx, width_); status = kernel.setArg(++arg_idx, width_);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, kernel,
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
nullptr, nullptr,
event_.get()); nullptr);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_buf, event_);
start += inputs[i]->dims()[axis_]; start += inputs[i]->dims()[axis_];
} }
} }
...@@ -256,7 +254,6 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL), ...@@ -256,7 +254,6 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
std::string kernel_func_name_{}; std::string kernel_func_name_{};
std::string build_options_{" -DCL_DTYPE_half"}; std::string build_options_{" -DCL_DTYPE_half"};
std::string time_stamp_{GetTimeStamp()}; std::string time_stamp_{GetTimeStamp()};
std::shared_ptr<cl::Event> event_{nullptr};
}; };
} // namespace opencl } // namespace opencl
......
...@@ -245,20 +245,7 @@ TEST(concat_image2d, compute) { ...@@ -245,20 +245,7 @@ TEST(concat_image2d, compute) {
LOG(INFO) << "run kernel: img_to_buf_kernel"; LOG(INFO) << "run kernel: img_to_buf_kernel";
img_to_buf_kernel->Launch(); img_to_buf_kernel->Launch();
// wait for opencl CLRuntime::Global()->command_queue().finish();
auto *wait_list = context->As<OpenCLContext>().cl_wait_list();
auto *out_ptr = ImageToBufferParam.y->data<float, cl::Buffer>();
auto it = wait_list->find(out_ptr);
if (it != wait_list->end()) {
VLOG(4) << "--- Find the sync event for the target cl "
"tensor. ---";
auto &event = *(it->second);
event.wait();
} else {
LOG(FATAL) << "Could not find the sync event for the target "
"cl tensor.";
}
// compute ref cpu // compute ref cpu
std::vector<const float *> ins_ptr; std::vector<const float *> ins_ptr;
......
...@@ -205,7 +205,7 @@ void ConvCompute::GemmlikeConv2d() { ...@@ -205,7 +205,7 @@ void ConvCompute::GemmlikeConv2d() {
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
auto global_work_size = cl::NDRange{static_cast<size_t>(out_stride)}; auto global_work_size = cl::NDRange{static_cast<size_t>(out_stride)};
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
img2col_kernel, img2col_kernel,
cl::NullRange, cl::NullRange,
...@@ -301,17 +301,14 @@ void ConvCompute::GemmBatched(cl::Kernel& kernel, ...@@ -301,17 +301,14 @@ void ConvCompute::GemmBatched(cl::Kernel& kernel,
status = kernel.setArg(++arg_idx, batch_size); status = kernel.setArg(++arg_idx, batch_size);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, kernel,
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
local_work_size, local_work_size,
nullptr, nullptr,
event_.get()); nullptr);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(output_d, event_);
} }
void ConvCompute::Run() { (this->*impl_)(); } void ConvCompute::Run() { (this->*impl_)(); }
......
...@@ -57,7 +57,6 @@ class ConvCompute ...@@ -57,7 +57,6 @@ class ConvCompute
std::vector<std::string> kernel_func_paths_{}; std::vector<std::string> kernel_func_paths_{};
std::vector<std::string> build_options_{}; std::vector<std::string> build_options_{};
std::string time_stamp_{GetTimeStamp()}; std::string time_stamp_{GetTimeStamp()};
std::shared_ptr<cl::Event> event_{nullptr};
}; };
} // namespace opencl } // namespace opencl
......
...@@ -304,25 +304,14 @@ TEST(conv2d, compute_conv2d_1x1) { ...@@ -304,25 +304,14 @@ TEST(conv2d, compute_conv2d_1x1) {
// run opencl kernel // run opencl kernel
kernel->Launch(); kernel->Launch();
auto* wait_list = context->As<OpenCLContext>().cl_wait_list(); CLRuntime::Global()->command_queue().finish();
auto* out_ptr = param.output->data<float, cl::Buffer>(); // double start_nanos =
auto it = wait_list->find(out_ptr); // event.getProfilingInfo<CL_PROFILING_COMMAND_START>();
if (it != wait_list->end()) { // double stop_nanos =
VLOG(4) << "--- Find the sync event for the target cl " // event.getProfilingInfo<CL_PROFILING_COMMAND_END>();
"tensor. ---"; // double elapsed_micros = (stop_nanos - start_nanos) / 1000.0;
auto& event = *(it->second); // LOG(INFO) << "Kernel Run Cost Time: " << elapsed_micros
event.wait(); // << " us.";
double start_nanos =
event.getProfilingInfo<CL_PROFILING_COMMAND_START>();
double stop_nanos =
event.getProfilingInfo<CL_PROFILING_COMMAND_END>();
double elapsed_micros = (stop_nanos - start_nanos) / 1000.0;
LOG(INFO) << "Kernel Run Cost Time: " << elapsed_micros
<< " us.";
} else {
LOG(FATAL) << "Could not find the sync event for the target "
"cl tensor.";
}
// run cpu ref // run cpu ref
auto* out_ref_data = out_ref.mutable_data<float>(TARGET(kARM)); auto* out_ref_data = out_ref.mutable_data<float>(TARGET(kARM));
...@@ -536,25 +525,15 @@ TEST(conv2d, compute_conv2d_gemm) { ...@@ -536,25 +525,15 @@ TEST(conv2d, compute_conv2d_gemm) {
// run opencl kernel // run opencl kernel
kernel->Launch(); kernel->Launch();
auto* wait_list = context->As<OpenCLContext>().cl_wait_list(); CLRuntime::Global()->command_queue().finish();
auto* out_ptr = param.output->data<float, cl::Buffer>(); // double start_nanos =
auto it = wait_list->find(out_ptr); // event.getProfilingInfo<CL_PROFILING_COMMAND_START>();
if (it != wait_list->end()) { // double stop_nanos =
VLOG(4) << "--- Find the sync event for the target cl " // event.getProfilingInfo<CL_PROFILING_COMMAND_END>();
"tensor. ---"; // double elapsed_micros = (stop_nanos - start_nanos) /
auto& event = *(it->second); // 1000.0;
event.wait(); // LOG(INFO) << "Kernel Run Cost Time: " << elapsed_micros
double start_nanos = // << " us.";
event.getProfilingInfo<CL_PROFILING_COMMAND_START>();
double stop_nanos =
event.getProfilingInfo<CL_PROFILING_COMMAND_END>();
double elapsed_micros = (stop_nanos - start_nanos) / 1000.0;
LOG(INFO) << "Kernel Run Cost Time: " << elapsed_micros
<< " us.";
} else {
LOG(FATAL) << "Could not find the sync event for the target "
"cl tensor.";
}
// run cpu ref // run cpu ref
auto* out_ref_data = out_ref.mutable_data<float>(TARGET(kARM)); auto* out_ref_data = out_ref.mutable_data<float>(TARGET(kARM));
......
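The timing code commented out above relied on OpenCL event profiling. As a sketch, assuming the command queue was created with CL_QUEUE_PROFILING_ENABLE (an assumption, not shown in this patch), it would read:

cl::Event event;
queue.enqueueNDRangeKernel(kernel, cl::NullRange, gws, cl::NullRange,
                           nullptr, &event);
event.wait();  // profiling info is valid only after the command completes
double start_nanos = event.getProfilingInfo<CL_PROFILING_COMMAND_START>();
double stop_nanos = event.getProfilingInfo<CL_PROFILING_COMMAND_END>();
LOG(INFO) << "Kernel Run Cost Time: " << (stop_nanos - start_nanos) / 1000.0
          << " us.";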
...@@ -58,9 +58,11 @@ class ConvImageCompute : public KernelLite<TARGET(kOpenCL), ...@@ -58,9 +58,11 @@ class ConvImageCompute : public KernelLite<TARGET(kOpenCL),
std::vector<std::string> kernel_func_paths_{}; std::vector<std::string> kernel_func_paths_{};
std::vector<std::string> build_options_{}; std::vector<std::string> build_options_{};
std::string time_stamp_{GetTimeStamp()}; std::string time_stamp_{GetTimeStamp()};
std::shared_ptr<cl::Event> event_{nullptr};
std::unique_ptr<Tensor> filter_gpu_image_{nullptr}; std::unique_ptr<Tensor> filter_gpu_image_{nullptr};
std::unique_ptr<Tensor> bias_gpu_image_{nullptr}; std::unique_ptr<Tensor> bias_gpu_image_{nullptr};
std::unique_ptr<Tensor> tensor_hold_filter_image_{nullptr};
std::unique_ptr<Tensor> tensor_hold_bias_image_{nullptr};
cl::NDRange global_work_size_ = cl::NDRange{ cl::NDRange global_work_size_ = cl::NDRange{
static_cast<size_t>(1), static_cast<size_t>(1), static_cast<size_t>(1)}; static_cast<size_t>(1), static_cast<size_t>(1), static_cast<size_t>(1)};
int c_blk_ = 1; int c_blk_ = 1;
......
...@@ -395,19 +395,7 @@ TEST(conv2d, compute_image2d_1x1) { ...@@ -395,19 +395,7 @@ TEST(conv2d, compute_image2d_1x1) {
auto* output_image2d = output.mutable_data<half_t, cl::Image2D>( auto* output_image2d = output.mutable_data<half_t, cl::Image2D>(
out_image_width, out_image_height); out_image_width, out_image_height);
auto* wait_list = context->As<OpenCLContext>().cl_wait_list(); CLRuntime::Global()->command_queue().finish();
auto* out_ptr = param.output->data<half_t, cl::Image2D>();
auto it = wait_list->find(out_ptr);
if (it != wait_list->end()) {
SHADOW_LOG << "--- Find the sync event for the target cl "
"tensor. ---";
auto& event = *(it->second);
event.wait();
} else {
LOG(FATAL) << "Could not find the sync event for the target"
"cl tensor.";
}
TargetWrapperCL::ImgcpySync(out_image_v.data(), TargetWrapperCL::ImgcpySync(out_image_v.data(),
output.data<half_t, cl::Image2D>(), output.data<half_t, cl::Image2D>(),
...@@ -530,11 +518,11 @@ const int stride = 2; ...@@ -530,11 +518,11 @@ const int stride = 2;
const int iw = 3; const int iw = 3;
const int oc = 2; const int oc = 2;
#else // big scale with group #else // big scale with group
const int stride = 1; const int stride = 2;
const int group = 32 / 1; const int group = 1;
const int batch_size = 2; const int batch_size = 1;
const int ic = 32 / 1; const int ic = 3 / 1;
const int ih = 112 / 1; const int ih = 224 / 1;
const int iw = 112 / 1; const int iw = 112 / 1;
const int oc = 32 / 1; const int oc = 32 / 1;
#endif #endif
...@@ -652,10 +640,10 @@ const int stride = 2; ...@@ -652,10 +640,10 @@ const int stride = 2;
SHADOW_LOG << "gen input and filter ..."; SHADOW_LOG << "gen input and filter ...";
for (int i = 0; i < input_v.size(); ++i) { for (int i = 0; i < input_v.size(); ++i) {
input_v[i] = i * 0.001; // gen(engine); input_v[i] = gen(engine);
} }
for (int i = 0; i < filter_v.size(); ++i) { for (int i = 0; i < filter_v.size(); ++i) {
filter_v[i] = 1 * 0.001; // gen(engine); filter_v[i] = gen(engine);
} }
SHADOW_LOG << "after gen input and filter ..."; SHADOW_LOG << "after gen input and filter ...";
...@@ -763,20 +751,7 @@ const int stride = 2; ...@@ -763,20 +751,7 @@ const int stride = 2;
auto* output_image2d = output.mutable_data<half_t, cl::Image2D>( auto* output_image2d = output.mutable_data<half_t, cl::Image2D>(
out_image_width, out_image_height); out_image_width, out_image_height);
auto* wait_list = context->As<OpenCLContext>().cl_wait_list(); CLRuntime::Global()->command_queue().finish();
auto* out_ptr = param.output->data<half_t, cl::Image2D>();
auto it = wait_list->find(out_ptr);
if (it != wait_list->end()) {
SHADOW_LOG << "--- Find the sync event for the target cl "
"tensor. ---";
auto& event = *(it->second);
event.wait();
} else {
LOG(FATAL) << "Could not find the sync event for the target "
"cl tensor.";
}
TargetWrapperCL::ImgcpySync(out_image_v.data(), TargetWrapperCL::ImgcpySync(out_image_v.data(),
output.data<half_t, cl::Image2D>(), output.data<half_t, cl::Image2D>(),
out_image_width, out_image_width,
...@@ -848,8 +823,13 @@ const int stride = 2; ...@@ -848,8 +823,13 @@ const int stride = 2;
for (int i = 0; i < out_dim.production(); i++) { for (int i = 0; i < out_dim.production(); i++) {
auto relative_diff = auto relative_diff =
COMPUTE_RELATIVE_DIFF(output_v[i], out_ref_data[i]); COMPUTE_RELATIVE_DIFF(output_v[i], out_ref_data[i]);
EXPECT_LT(relative_diff, FP16_MAX_DIFF); auto abs_diff = COMPUTE_ABS_DIFF(output_v[i], out_ref_data[i]);
if (relative_diff > FP16_MAX_DIFF) { // EXPECT_LT(relative_diff, FP16_MAX_DIFF);
// EXPECT_LT(abs_diff, FP16_ABS_DIFF);
EXPECT_FALSE(relative_diff > FP16_MAX_DIFF &&
abs_diff > FP16_ABS_DIFF);
if (relative_diff > FP16_MAX_DIFF && abs_diff > FP16_ABS_DIFF) {
LOG(FATAL) << "error idx:" << i << "output_v[" << i LOG(FATAL) << "error idx:" << i << "output_v[" << i
<< "]:" << output_v[i] << " " << "]:" << output_v[i] << " "
"out_ref_data[" "out_ref_data["
...@@ -1115,19 +1095,7 @@ TEST(conv2d, compute_image2d_5x5) { ...@@ -1115,19 +1095,7 @@ TEST(conv2d, compute_image2d_5x5) {
auto* output_image2d = output.mutable_data<half_t, cl::Image2D>( auto* output_image2d = output.mutable_data<half_t, cl::Image2D>(
out_image_width, out_image_height); out_image_width, out_image_height);
auto* wait_list = context->As<OpenCLContext>().cl_wait_list(); CLRuntime::Global()->command_queue().finish();
auto* out_ptr = param.output->data<half_t, cl::Image2D>();
auto it = wait_list->find(out_ptr);
if (it != wait_list->end()) {
SHADOW_LOG << "--- Find the sync event for the target cl "
"tensor. ---";
auto& event = *(it->second);
event.wait();
} else {
LOG(FATAL) << "Could not find the sync event for the target "
"cl tensor.";
}
TargetWrapperCL::ImgcpySync(out_image_v.data(), TargetWrapperCL::ImgcpySync(out_image_v.data(),
output.data<half_t, cl::Image2D>(), output.data<half_t, cl::Image2D>(),
...@@ -1468,19 +1436,7 @@ TEST(conv2d, compute_image2d_7x7) { ...@@ -1468,19 +1436,7 @@ TEST(conv2d, compute_image2d_7x7) {
auto* output_image2d = output.mutable_data<half_t, cl::Image2D>( auto* output_image2d = output.mutable_data<half_t, cl::Image2D>(
out_image_width, out_image_height); out_image_width, out_image_height);
auto* wait_list = context->As<OpenCLContext>().cl_wait_list(); CLRuntime::Global()->command_queue().finish();
auto* out_ptr = param.output->data<half_t, cl::Image2D>();
auto it = wait_list->find(out_ptr);
if (it != wait_list->end()) {
SHADOW_LOG << "--- Find the sync event for the target cl "
"tensor. ---";
auto& event = *(it->second);
event.wait();
} else {
LOG(FATAL) << "Could not find the sync event for the target "
"cl tensor.";
}
TargetWrapperCL::ImgcpySync(out_image_v.data(), TargetWrapperCL::ImgcpySync(out_image_v.data(),
output.data<half_t, cl::Image2D>(), output.data<half_t, cl::Image2D>(),
......
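The 3x3 test above also relaxes the FP16 check: an element now fails only when both the relative and the absolute difference exceed their thresholds. In isolation, with the macro semantics assumed from their names:

// Assumed semantics: COMPUTE_RELATIVE_DIFF(a, b) ~ |a - b| / max(|a|, |b|),
// COMPUTE_ABS_DIFF(a, b) ~ |a - b|.
auto relative_diff = COMPUTE_RELATIVE_DIFF(output_v[i], out_ref_data[i]);
auto abs_diff = COMPUTE_ABS_DIFF(output_v[i], out_ref_data[i]);
// Near-zero outputs are judged by absolute error, large outputs by relative
// error; only a value failing both counts as a mismatch.
EXPECT_FALSE(relative_diff > FP16_MAX_DIFF && abs_diff > FP16_ABS_DIFF);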
...@@ -108,23 +108,21 @@ class DepthwiseConv2dCompute ...@@ -108,23 +108,21 @@ class DepthwiseConv2dCompute
status = kernel.setArg(++arg_idx, *bias_buf); status = kernel.setArg(++arg_idx, *bias_buf);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
auto global_work_size = cl::NDRange(static_cast<size_t>(numel)); auto global_work_size = cl::NDRange(static_cast<size_t>(numel));
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, kernel,
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
nullptr, nullptr,
event_.get()); nullptr);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(output_buf, event_);
} }
private: private:
std::string kernel_func_name_{"depthwise_conv2d"}; std::string kernel_func_name_{"depthwise_conv2d"};
std::string build_options_{"-DCL_DTYPE_float"}; std::string build_options_{"-DCL_DTYPE_float"};
std::string time_stamp_{GetTimeStamp()}; std::string time_stamp_{GetTimeStamp()};
std::shared_ptr<cl::Event> event_{nullptr};
}; };
} // namespace opencl } // namespace opencl
......
...@@ -137,16 +137,7 @@ TEST(depthwise_conv2d_buffer_fp32, compute) { ...@@ -137,16 +137,7 @@ TEST(depthwise_conv2d_buffer_fp32, compute) {
output.Resize({4, 32, 110, 110}); output.Resize({4, 32, 110, 110});
kernel->Launch(); kernel->Launch();
auto* wait_list = context->As<OpenCLContext>().cl_wait_list(); CLRuntime::Global()->command_queue().finish();
auto* out_ptr = param.output->data<float, cl::Buffer>();
auto it = wait_list->find(out_ptr);
if (it != wait_list->end()) {
VLOG(4) << "--- Find the sync event for the target cl tensor. ---";
auto& event = *(it->second);
event.wait();
} else {
LOG(FATAL) << "Could not find the sync event for the target cl tensor.";
}
lite::Tensor output_ref; lite::Tensor output_ref;
output_ref.Resize({4, 32, 110, 110}); output_ref.Resize({4, 32, 110, 110});
......
...@@ -312,19 +312,7 @@ TEST(depthwise_conv2d, compute_basic) { ...@@ -312,19 +312,7 @@ TEST(depthwise_conv2d, compute_basic) {
auto* output_image2d = output.mutable_data<half_t, cl::Image2D>( auto* output_image2d = output.mutable_data<half_t, cl::Image2D>(
out_image_width, out_image_height); out_image_width, out_image_height);
auto* wait_list = context->As<OpenCLContext>().cl_wait_list(); CLRuntime::Global()->command_queue().finish();
auto* out_ptr = param.output->data<half_t, cl::Image2D>();
auto it = wait_list->find(out_ptr);
if (it != wait_list->end()) {
VLOG(4) << "--- Find the sync event for the target cl "
"tensor. ---";
auto& event = *(it->second);
event.wait();
} else {
LOG(FATAL) << "Could not find the sync event for the target "
"cl tensor.";
}
TargetWrapperCL::ImgcpySync(out_image_v.data(), TargetWrapperCL::ImgcpySync(out_image_v.data(),
output.data<half_t, cl::Image2D>(), output.data<half_t, cl::Image2D>(),
...@@ -503,20 +491,7 @@ TEST(depthwise_conv2d, compute_image2d_3x3) { ...@@ -503,20 +491,7 @@ TEST(depthwise_conv2d, compute_image2d_3x3) {
kernel->Launch(); kernel->Launch();
auto* wait_list = context->As<OpenCLContext>().cl_wait_list(); CLRuntime::Global()->command_queue().finish();
auto* out_ptr = param.output->data<half_t, cl::Image2D>();
auto it = wait_list->find(out_ptr);
if (it != wait_list->end()) {
VLOG(4) << "--- Find the sync event for the target cl tensor. ---";
LOG(INFO) << "--- Find the sync event for the target cl tensor. ---";
auto& event = *(it->second);
event.wait();
} else {
LOG(FATAL)
<< "Could not find the sync event for the target cl tensor.";
LOG(INFO)
<< "Could not find the sync event for the target cl tensor.";
}
lite::Tensor out_ref; lite::Tensor out_ref;
out_ref.Resize(output_dim); out_ref.Resize(output_dim);
......
...@@ -89,23 +89,20 @@ class DropoutComputeImage2D : public KernelLite<TARGET(kOpenCL), ...@@ -89,23 +89,20 @@ class DropoutComputeImage2D : public KernelLite<TARGET(kOpenCL),
static_cast<cl::size_type>(default_work_size.data()[1]), static_cast<cl::size_type>(default_work_size.data()[1]),
static_cast<cl::size_type>(default_work_size.data()[2])}; static_cast<cl::size_type>(default_work_size.data()[2])};
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, kernel,
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
nullptr, nullptr,
event_.get()); nullptr);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_img, event_);
} }
private: private:
std::string kernel_func_name_{"dropout"}; std::string kernel_func_name_{"dropout"};
std::string build_options_{"-DCL_DTYPE_half"}; std::string build_options_{"-DCL_DTYPE_half"};
std::string time_stamp_{GetTimeStamp()}; std::string time_stamp_{GetTimeStamp()};
std::shared_ptr<cl::Event> event_{nullptr};
}; };
} // namespace opencl } // namespace opencl
......
...@@ -86,16 +86,7 @@ TEST(dropout_image2d_fp16, compute) { ...@@ -86,16 +86,7 @@ TEST(dropout_image2d_fp16, compute) {
LOG(INFO) << "out_image:" << out_image; LOG(INFO) << "out_image:" << out_image;
kernel->Launch(); kernel->Launch();
auto* wait_list = context->As<OpenCLContext>().cl_wait_list(); CLRuntime::Global()->command_queue().finish();
auto* out_ptr = param.output->data<half_t, cl::Image2D>();
auto it = wait_list->find(out_ptr);
if (it != wait_list->end()) {
VLOG(4) << "--- Find the sync event for the target cl tensor. ---";
auto& event = *(it->second);
event.wait();
} else {
LOG(FATAL) << "Could not find the sync event for the target cl tensor.";
}
std::unique_ptr<float[]> out_ref(new float[out_dim.production()]); std::unique_ptr<float[]> out_ref(new float[out_dim.production()]);
dropout(input_v.data(), in_dim, out_ref.get(), 0.6); dropout(input_v.data(), in_dim, out_ref.get(), 0.6);
......
...@@ -63,16 +63,10 @@ void ElementwiseAddCompute::Run() { ...@@ -63,16 +63,10 @@ void ElementwiseAddCompute::Run() {
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
auto global_work_size = cl::NDRange{channels_, batch_}; auto global_work_size = cl::NDRange{channels_, batch_};
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, kernel, cl::NullRange, global_work_size, cl::NullRange, nullptr, nullptr);
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
event_.get());
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_buf, event_);
} }
void ElementwiseAddCompute::UpdateParams() { void ElementwiseAddCompute::UpdateParams() {
......
...@@ -48,7 +48,6 @@ class ElementwiseAddCompute ...@@ -48,7 +48,6 @@ class ElementwiseAddCompute
std::string kernel_func_name_{"elementwise_add"}; std::string kernel_func_name_{"elementwise_add"};
std::string build_options_{"-DCL_DTYPE_float"}; std::string build_options_{"-DCL_DTYPE_float"};
std::string time_stamp_{GetTimeStamp()}; std::string time_stamp_{GetTimeStamp()};
std::shared_ptr<cl::Event> event_{nullptr};
}; };
} // namespace opencl } // namespace opencl
......
...@@ -144,16 +144,7 @@ TEST(elementwise_add_buffer, compute) { ...@@ -144,16 +144,7 @@ TEST(elementwise_add_buffer, compute) {
kernel->Launch(); kernel->Launch();
auto *wait_list = context->As<OpenCLContext>().cl_wait_list(); CLRuntime::Global()->command_queue().finish();
auto *out_ptr = param.Out->data<float, cl::Buffer>();
auto it = wait_list->find(out_ptr);
if (it != wait_list->end()) {
VLOG(4) << "--- Find the sync event for the target cl tensor. ---";
auto &event = *(it->second);
event.wait();
} else {
LOG(FATAL) << "Could not find the sync event for the target cl tensor.";
}
std::unique_ptr<float[]> out_ref(new float[out_dim.production()]); std::unique_ptr<float[]> out_ref(new float[out_dim.production()]);
elementwise_compute_ref<float>( elementwise_compute_ref<float>(
...@@ -225,16 +216,7 @@ TEST(fusion_elementwise_add_activation_buffer, compute) { ...@@ -225,16 +216,7 @@ TEST(fusion_elementwise_add_activation_buffer, compute) {
kernel->Launch(); kernel->Launch();
auto *wait_list = context->As<OpenCLContext>().cl_wait_list(); CLRuntime::Global()->command_queue().finish();
auto *out_ptr = param.Out->data<float, cl::Buffer>();
auto it = wait_list->find(out_ptr);
if (it != wait_list->end()) {
VLOG(4) << "--- Find the sync event for the target cl tensor. ---";
auto &event = *(it->second);
event.wait();
} else {
LOG(FATAL) << "Could not find the sync event for the target cl tensor.";
}
std::unique_ptr<float[]> out_ref(new float[out_dim.production()]); std::unique_ptr<float[]> out_ref(new float[out_dim.production()]);
elementwise_compute_ref<float>( elementwise_compute_ref<float>(
......
...@@ -153,16 +153,15 @@ void ElementwiseAddImageCompute::Run() { ...@@ -153,16 +153,15 @@ void ElementwiseAddImageCompute::Run() {
auto& context = ctx_->As<OpenCLContext>(); auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr); CHECK(context.cl_context() != nullptr);
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, kernel,
cl::NullRange, cl::NullRange,
global_work_size_, global_work_size_,
cl::NullRange, cl::NullRange,
nullptr, nullptr,
event_.get()); nullptr);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_img, event_);
} }
} // namespace opencl } // namespace opencl
......
...@@ -63,7 +63,6 @@ class ElementwiseAddImageCompute ...@@ -63,7 +63,6 @@ class ElementwiseAddImageCompute
cl::Kernel kernel_; cl::Kernel kernel_;
cl::NDRange global_work_size_ = cl::NDRange{ cl::NDRange global_work_size_ = cl::NDRange{
static_cast<size_t>(1), static_cast<size_t>(1), static_cast<size_t>(1)}; static_cast<size_t>(1), static_cast<size_t>(1), static_cast<size_t>(1)};
std::shared_ptr<cl::Event> event_{nullptr};
}; };
} // namespace opencl } // namespace opencl
......
...@@ -50,8 +50,10 @@ void ElementwiseMulFloatImageCompute::PrepareForRun() { ...@@ -50,8 +50,10 @@ void ElementwiseMulFloatImageCompute::PrepareForRun() {
VLOG(4) << "y_dims.size():" << y_dims.size(); VLOG(4) << "y_dims.size():" << y_dims.size();
auto& context = ctx_->As<OpenCLContext>(); auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel( context.cl_context()->AddKernel(kernel_func_name_,
kernel_func_name_, "image/elementwise_mul_kernel.cl", build_options_); "image/elementwise_mul_kernel.cl",
build_options_,
time_stamp_);
} }
void ElementwiseMulFloatImageCompute::Run() { void ElementwiseMulFloatImageCompute::Run() {
...@@ -88,7 +90,7 @@ void ElementwiseMulFloatImageCompute::Run() { ...@@ -88,7 +90,7 @@ void ElementwiseMulFloatImageCompute::Run() {
<< out_img_shape[1]; << out_img_shape[1];
STL::stringstream kernel_key; STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_; kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
auto kernel = context.cl_context()->GetKernel(kernel_key.str()); auto kernel = context.cl_context()->GetKernel(kernel_key.str());
int arg_idx = 0; int arg_idx = 0;
...@@ -150,16 +152,16 @@ void ElementwiseMulFloatImageCompute::Run() { ...@@ -150,16 +152,16 @@ void ElementwiseMulFloatImageCompute::Run() {
auto global_work_size = cl::NDRange{static_cast<cl::size_type>(x_img_width), auto global_work_size = cl::NDRange{static_cast<cl::size_type>(x_img_width),
static_cast<cl::size_type>(x_img_height)}; static_cast<cl::size_type>(x_img_height)};
event_ = std::shared_ptr<cl::Event>(new cl::Event);
auto status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( auto status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, kernel,
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
nullptr, nullptr,
event_.get()); nullptr);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_img, event_);
VLOG(4) << "global_work_size:[2D]:" << x_img_width << " " << x_img_height; VLOG(4) << "global_work_size:[2D]:" << x_img_width << " " << x_img_height;
} }
......
...@@ -185,16 +185,15 @@ class ElementwiseMulImageCompute ...@@ -185,16 +185,15 @@ class ElementwiseMulImageCompute
auto global_work_size = auto global_work_size =
cl::NDRange{static_cast<cl::size_type>(x_img_width), cl::NDRange{static_cast<cl::size_type>(x_img_width),
static_cast<cl::size_type>(x_img_height)}; static_cast<cl::size_type>(x_img_height)};
event_ = std::shared_ptr<cl::Event>(new cl::Event);
auto status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( auto status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, kernel,
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
nullptr, nullptr,
event_.get()); nullptr);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_img, event_);
#ifndef LITE_SHUTDOWN_LOG #ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "global_work_size:[2D]:" << x_img_width << " " << x_img_height; VLOG(4) << "global_work_size:[2D]:" << x_img_width << " " << x_img_height;
#endif #endif
...@@ -205,7 +204,6 @@ class ElementwiseMulImageCompute ...@@ -205,7 +204,6 @@ class ElementwiseMulImageCompute
std::string kernel_func_name_{"elementwise_mul"}; std::string kernel_func_name_{"elementwise_mul"};
std::string build_options_{"-DCL_DTYPE_half"}; std::string build_options_{"-DCL_DTYPE_half"};
std::string time_stamp_{GetTimeStamp()}; std::string time_stamp_{GetTimeStamp()};
std::shared_ptr<cl::Event> event_{nullptr};
}; };
} // namespace opencl } // namespace opencl
......
...@@ -138,16 +138,9 @@ void ElementwiseSubImageCompute::Run() { ...@@ -138,16 +138,9 @@ void ElementwiseSubImageCompute::Run() {
VLOG(4) << "global_work_size:[2D]:" << x_img_width << " " << x_img_height; VLOG(4) << "global_work_size:[2D]:" << x_img_width << " " << x_img_height;
#endif #endif
event_ = std::shared_ptr<cl::Event>(new cl::Event);
auto status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( auto status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, kernel, cl::NullRange, global_work_size, cl::NullRange, nullptr, nullptr);
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
event_.get());
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_img, event_);
} }
} // namespace opencl } // namespace opencl
......
...@@ -46,7 +46,6 @@ class ElementwiseSubImageCompute ...@@ -46,7 +46,6 @@ class ElementwiseSubImageCompute
std::string kernel_func_name_{"elementwise_sub"}; std::string kernel_func_name_{"elementwise_sub"};
std::string build_options_{"-DCL_DTYPE_half"}; std::string build_options_{"-DCL_DTYPE_half"};
std::string time_stamp_{GetTimeStamp()}; std::string time_stamp_{GetTimeStamp()};
std::shared_ptr<cl::Event> event_{nullptr};
}; };
} // namespace opencl } // namespace opencl
......
...@@ -123,16 +123,15 @@ class FcCompute ...@@ -123,16 +123,15 @@ class FcCompute
auto& context = ctx_->As<OpenCLContext>(); auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr); CHECK(context.cl_context() != nullptr);
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, kernel,
cl::NullRange, cl::NullRange,
global_work_size_, global_work_size_,
cl::NullRange, cl::NullRange,
nullptr, nullptr,
event_.get()); nullptr);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_buf, event_);
} }
private: private:
...@@ -145,7 +144,6 @@ class FcCompute ...@@ -145,7 +144,6 @@ class FcCompute
DDim last_x_dims_; DDim last_x_dims_;
cl::NDRange global_work_size_; cl::NDRange global_work_size_;
cl::Kernel kernel_; cl::Kernel kernel_;
std::shared_ptr<cl::Event> event_{nullptr};
}; };
} // namespace opencl } // namespace opencl
......
...@@ -162,17 +162,8 @@ TEST(fc, compute) { ...@@ -162,17 +162,8 @@ TEST(fc, compute) {
// run opencl kernel // run opencl kernel
kernel->Launch(); kernel->Launch();
// kernel->Launch();
CLRuntime::Global()->command_queue().finish();
auto* wait_list = context->As<OpenCLContext>().cl_wait_list();
auto* out_ptr = param.output->data<float, cl::Buffer>();
auto it = wait_list->find(out_ptr);
if (it != wait_list->end()) {
VLOG(4) << "--- Find the sync event for the target cl tensor. ---";
auto& event = *(it->second);
event.wait();
auto command_queue = CLRuntime::Global()->command_queue();
command_queue.finish();
#if 0 #if 0
double start_nanos = double start_nanos =
event.getProfilingInfo<CL_PROFILING_COMMAND_START>(); event.getProfilingInfo<CL_PROFILING_COMMAND_START>();
...@@ -181,10 +172,6 @@ TEST(fc, compute) { ...@@ -181,10 +172,6 @@ TEST(fc, compute) {
double elapsed_micros = (stop_nanos - start_nanos) / 1000.0; double elapsed_micros = (stop_nanos - start_nanos) / 1000.0;
LOG(INFO) << "Kernel Run Cost Time: " << elapsed_micros << " us."; LOG(INFO) << "Kernel Run Cost Time: " << elapsed_micros << " us.";
#endif #endif
} else {
LOG(FATAL)
<< "Could not find the sync event for the target cl tensor.";
}
std::vector<float> out_data_from_gpu(out_dim.production()); std::vector<float> out_data_from_gpu(out_dim.production());
TargetWrapperCL::MemcpySync( TargetWrapperCL::MemcpySync(
......
...@@ -130,16 +130,15 @@ class GridSamplerImageCompute : public KernelLite<TARGET(kOpenCL), ...@@ -130,16 +130,15 @@ class GridSamplerImageCompute : public KernelLite<TARGET(kOpenCL),
auto& context = ctx_->As<OpenCLContext>(); auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr); CHECK(context.cl_context() != nullptr);
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, kernel,
cl::NullRange, cl::NullRange,
global_work_size_, global_work_size_,
cl::NullRange, cl::NullRange,
nullptr, nullptr,
event_.get()); nullptr);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_img, event_);
} }
protected: protected:
...@@ -154,7 +153,6 @@ class GridSamplerImageCompute : public KernelLite<TARGET(kOpenCL), ...@@ -154,7 +153,6 @@ class GridSamplerImageCompute : public KernelLite<TARGET(kOpenCL),
static_cast<size_t>(1), static_cast<size_t>(1), static_cast<size_t>(1)}; static_cast<size_t>(1), static_cast<size_t>(1), static_cast<size_t>(1)};
std::string build_options_{"-DCL_DTYPE_half"}; std::string build_options_{"-DCL_DTYPE_half"};
std::string time_stamp_{GetTimeStamp()}; std::string time_stamp_{GetTimeStamp()};
std::shared_ptr<cl::Event> event_{nullptr};
}; };
} // namespace opencl } // namespace opencl
......
...@@ -191,17 +191,7 @@ TEST(grid_samler_image2d, compute) { ...@@ -191,17 +191,7 @@ TEST(grid_samler_image2d, compute) {
// LOG(INFO) << "out_image:" << out_image; // LOG(INFO) << "out_image:" << out_image;
kernel->Launch(); kernel->Launch();
auto* wait_list = context->As<OpenCLContext>().cl_wait_list(); CLRuntime::Global()->command_queue().finish();
auto* out_ptr = param.out->data<half_t, cl::Image2D>();
auto it = wait_list->find(out_ptr);
if (it != wait_list->end()) {
VLOG(4) << "--- Find the sync event for the target cl tensor. ---";
auto& event = *(it->second);
event.wait();
} else {
LOG(FATAL)
<< "Could not find the sync event for the target cl tensor.";
}
std::unique_ptr<float[]> out_ref(new float[out_dim.production()]); std::unique_ptr<float[]> out_ref(new float[out_dim.production()]);
gird_sampler_ref( gird_sampler_ref(
......
...@@ -137,16 +137,14 @@ class InstanceNormImageCompute : public KernelLite<TARGET(kOpenCL), ...@@ -137,16 +137,14 @@ class InstanceNormImageCompute : public KernelLite<TARGET(kOpenCL),
status = kernel.setArg(7, *out_img); status = kernel.setArg(7, *out_img);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, kernel,
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
local_work_size, local_work_size,
nullptr, nullptr,
event_.get()); nullptr);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_img, event_);
} }
#else // paddle version #else // paddle version
...@@ -260,16 +258,14 @@ class InstanceNormImageCompute : public KernelLite<TARGET(kOpenCL), ...@@ -260,16 +258,14 @@ class InstanceNormImageCompute : public KernelLite<TARGET(kOpenCL),
status = kernel.setArg(arg_idx++, in_w); status = kernel.setArg(arg_idx++, in_w);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, kernel,
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
local_work_size, local_work_size,
nullptr, nullptr,
event_.get()); nullptr);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_img, event_);
} }
#endif #endif
...@@ -278,7 +274,7 @@ class InstanceNormImageCompute : public KernelLite<TARGET(kOpenCL), ...@@ -278,7 +274,7 @@ class InstanceNormImageCompute : public KernelLite<TARGET(kOpenCL),
std::string kernel_func_name_{"instance_norm_onnx"}; std::string kernel_func_name_{"instance_norm_onnx"};
std::string build_options_{"-DCL_DTYPE_half"}; std::string build_options_{"-DCL_DTYPE_half"};
std::string time_stamp_{GetTimeStamp()}; std::string time_stamp_{GetTimeStamp()};
std::shared_ptr<cl::Event> event_{nullptr};
Tensor scale_image_; Tensor scale_image_;
Tensor bias_image_; Tensor bias_image_;
}; };
......
...@@ -105,20 +105,11 @@ class IoCopykOpenCLToHostCompute ...@@ -105,20 +105,11 @@ class IoCopykOpenCLToHostCompute
} }
auto& context = ctx_->As<OpenCLContext>(); auto& context = ctx_->As<OpenCLContext>();
auto* wait_list = context.cl_wait_list();
auto it = wait_list->find(x_ptr);
if (it != wait_list->end()) {
#ifndef LITE_SHUTDOWN_LOG #ifndef LITE_SHUTDOWN_LOG
VLOG(2) << "--- Find the sync event for the target cl tensor. ---"; VLOG(2) << "--- Find the sync event for the target cl tensor. ---";
#endif #endif
auto& event = *(it->second); CLRuntime::Global()->command_queue().finish();
event.wait();
auto command_queue = CLRuntime::Global()->command_queue();
command_queue.finish();
} else {
LOG(FATAL) << "Could not find the sync event for the target cl tensor.";
}
CopyToHostSync(data, param.x->raw_data(), mem_size); CopyToHostSync(data, param.x->raw_data(), mem_size);
} }
......
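With the wait-list lookup gone, the device-to-host path collapses to a queue drain followed by a blocking copy; both lines are taken from the hunk above:

CLRuntime::Global()->command_queue().finish();        // drain pending kernels
CopyToHostSync(data, param.x->raw_data(), mem_size);  // then copy to host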
...@@ -65,10 +65,7 @@ TEST(io_copy, compute) { ...@@ -65,10 +65,7 @@ TEST(io_copy, compute) {
h2d_kernel->Launch(); h2d_kernel->Launch();
auto* event_key = d_y.data<float, cl::Buffer>(); auto* event_key = d_y.data<float, cl::Buffer>();
std::shared_ptr<cl::Event> event(new cl::Event);
context->As<OpenCLContext>().cl_wait_list()->emplace(event_key, event);
d2h_kernel->Launch(); d2h_kernel->Launch();
auto* h_y_data = h_y.data<float>(); auto* h_y_data = h_y.data<float>();
for (int i = 0; i < 3 * 9 * 28 * 28; i++) { for (int i = 0; i < 3 * 9 * 28 * 28; i++) {
......
...@@ -44,8 +44,10 @@ class LayoutComputeBufferChwToImageDefault ...@@ -44,8 +44,10 @@ class LayoutComputeBufferChwToImageDefault
} }
VLOG(1) << "kernel_func_name_:" << kernel_func_name_; VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
auto& context = ctx_->As<OpenCLContext>(); auto& context = ctx_->As<OpenCLContext>();
context.cl_context()->AddKernel( context.cl_context()->AddKernel(kernel_func_name_,
kernel_func_name_, "image/layout_kernel.cl", build_options_); "image/layout_kernel.cl",
build_options_,
time_stamp_);
} }
void Run() override { void Run() override {
...@@ -95,7 +97,7 @@ class LayoutComputeBufferChwToImageDefault ...@@ -95,7 +97,7 @@ class LayoutComputeBufferChwToImageDefault
auto& context = ctx_->As<OpenCLContext>(); auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr); CHECK(context.cl_context() != nullptr);
STL::stringstream kernel_key; STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_; kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
auto kernel = context.cl_context()->GetKernel(kernel_key.str()); auto kernel = context.cl_context()->GetKernel(kernel_key.str());
int arg_idx = 0; int arg_idx = 0;
...@@ -122,16 +124,15 @@ class LayoutComputeBufferChwToImageDefault ...@@ -122,16 +124,15 @@ class LayoutComputeBufferChwToImageDefault
cl::NDRange{static_cast<cl::size_type>((new_dims[1] + 3) / 4), cl::NDRange{static_cast<cl::size_type>((new_dims[1] + 3) / 4),
static_cast<cl::size_type>(new_dims[3]), static_cast<cl::size_type>(new_dims[3]),
static_cast<cl::size_type>(new_dims[0] * new_dims[2])}; static_cast<cl::size_type>(new_dims[0] * new_dims[2])};
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel, kernel,
cl::NullRange, cl::NullRange,
global_work_size, global_work_size,
cl::NullRange, cl::NullRange,
nullptr, nullptr,
event_.get()); nullptr);
CL_CHECK_FATAL(status); CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(y_data, event_);
} }
std::string doc() const override { std::string doc() const override {
...@@ -140,9 +141,9 @@ class LayoutComputeBufferChwToImageDefault ...@@ -140,9 +141,9 @@ class LayoutComputeBufferChwToImageDefault
} }
private: private:
std::string time_stamp_{GetTimeStamp()};
std::string kernel_func_name_{"buffer_to_image2d"}; std::string kernel_func_name_{"buffer_to_image2d"};
std::string build_options_{"-DCL_DTYPE_float"}; std::string build_options_{"-DCL_DTYPE_float"};
std::shared_ptr<cl::Event> event_{nullptr};
}; };
 // [ImageDefault] -> [NCHW]
@@ -158,8 +159,10 @@ class LayoutComputeImageDefaultToBufferChw
   }
     VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
     auto& context = ctx_->As<OpenCLContext>();
-    context.cl_context()->AddKernel(
-        kernel_func_name_, "image/layout_kernel.cl", build_options_);
+    context.cl_context()->AddKernel(kernel_func_name_,
+                                    "image/layout_kernel.cl",
+                                    build_options_,
+                                    time_stamp_);
   }

   void Run() override {
@@ -202,7 +205,7 @@ class LayoutComputeImageDefaultToBufferChw
     auto& context = ctx_->As<OpenCLContext>();
     CHECK(context.cl_context() != nullptr);
     STL::stringstream kernel_key;
-    kernel_key << kernel_func_name_ << build_options_;
+    kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
     auto kernel = context.cl_context()->GetKernel(kernel_key.str());
     int arg_idx = 0;
@@ -230,16 +233,15 @@ class LayoutComputeImageDefaultToBufferChw
         cl::NDRange{static_cast<cl::size_type>((new_dims[1] + 3) / 4),
                     static_cast<cl::size_type>(new_dims[3]),
                     static_cast<cl::size_type>(new_dims[0] * new_dims[2])};
-    event_ = std::shared_ptr<cl::Event>(new cl::Event);
     status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
         kernel,
         cl::NullRange,
         global_work_size,
         cl::NullRange,
         nullptr,
-        event_.get());
+        nullptr);
     CL_CHECK_FATAL(status);
-    context.cl_wait_list()->emplace(y_data, event_);
   }

   std::string doc() const override {
@@ -248,9 +250,9 @@ class LayoutComputeImageDefaultToBufferChw
   }

  private:
+  std::string time_stamp_{GetTimeStamp()};
   std::string kernel_func_name_{"image2d_to_buffer"};
   std::string build_options_{"-DCL_DTYPE_float"};
-  std::shared_ptr<cl::Event> event_{nullptr};
 };
 // [NCHW] -> [ImageDW]
@@ -263,8 +265,10 @@ class LayoutComputeBufferChwToImage2DNw
   void PrepareForRun() override {
     auto& context = ctx_->As<OpenCLContext>();
-    context.cl_context()->AddKernel(
-        kernel_func_name_, "buffer/layout_kernel.cl", build_options_);
+    context.cl_context()->AddKernel(kernel_func_name_,
+                                    "buffer/layout_kernel.cl",
+                                    build_options_,
+                                    time_stamp_);
   }

   void Run() override {
@@ -298,7 +302,7 @@ class LayoutComputeBufferChwToImage2DNw
     auto& context = ctx_->As<OpenCLContext>();
     CHECK(context.cl_context() != nullptr);
     STL::stringstream kernel_key;
-    kernel_key << kernel_func_name_ << build_options_;
+    kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
     auto kernel = context.cl_context()->GetKernel(kernel_key.str());
     int arg_idx = 0;
@@ -325,16 +329,15 @@ class LayoutComputeBufferChwToImage2DNw
         cl::NDRange{static_cast<cl::size_type>((out_N + 3) / 4),  // N blocks
                     static_cast<cl::size_type>(out_W),            // w
                     static_cast<cl::size_type>(out_C * out_H)};   // ch
-    event_ = std::shared_ptr<cl::Event>(new cl::Event);
     status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
         kernel,
         cl::NullRange,
         global_work_size,
         cl::NullRange,
         nullptr,
-        event_.get());
+        nullptr);
     CL_CHECK_FATAL(status);
-    context.cl_wait_list()->emplace(y_data, event_);
   }

   std::string doc() const override {
@@ -342,9 +345,10 @@ class LayoutComputeBufferChwToImage2DNw
   }

  private:
+  std::string time_stamp_{GetTimeStamp()};
   std::string kernel_func_name_{"buffer_to_image2d_nw"};
   std::string build_options_{"-DCL_DTYPE_float "};
-  std::shared_ptr<cl::Event> event_{nullptr};
 };

 }  // namespace opencl
......
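Every hunk above follows the same recipe: the per-launch `cl::Event` that used to be registered in the context's wait list is dropped, and callers synchronize with a queue-wide finish instead. A minimal sketch of the two styles against the standard OpenCL C++ wrapper (illustrative only; `run_and_sync` and its arguments are made up for this example, not Paddle-Lite code):

#include <CL/cl2.hpp>  // assumed OpenCL C++ bindings header

void run_and_sync(cl::CommandQueue& queue,
                  cl::Kernel& kernel,
                  const cl::NDRange& global_work_size) {
  // Old style: capture a per-launch event and block on it alone.
  cl::Event event;
  queue.enqueueNDRangeKernel(kernel, cl::NullRange, global_work_size,
                             cl::NullRange, nullptr, &event);
  event.wait();  // returns once this one launch has completed

  // New style: pass nullptr for the event and drain the whole queue.
  queue.enqueueNDRangeKernel(kernel, cl::NullRange, global_work_size,
                             cl::NullRange, nullptr, nullptr);
  queue.finish();  // returns once every enqueued command has completed
}

The finish-based style is coarser, since it waits on every command in the queue rather than a single launch, but it removes the per-tensor event bookkeeping that each kernel previously had to maintain.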
@@ -246,20 +246,7 @@ TEST(layout_ImageDefault_With_Pre_Post, compute) {
   LOG(INFO) << "run kernel: image2d_to_buffer_with_post255";
   img_to_buf_kernel->Launch();
-  // wait for opencl
-  auto* wait_list = context->As<OpenCLContext>().cl_wait_list();
-  auto* out_ptr = ImageToBufferParam.y->data<float, cl::Buffer>();
-  auto it = wait_list->find(out_ptr);
-  if (it != wait_list->end()) {
-    VLOG(4) << "--- Find the sync event for the target cl "
-               "tensor. ---";
-    auto& event = *(it->second);
-    event.wait();
-  } else {
-    LOG(FATAL) << "Could not find the sync event for the target "
-                  "cl tensor.";
-  }
+  CLRuntime::Global()->command_queue().finish();

   // result
 #ifdef PRINT_RESULT
......
@@ -128,16 +128,14 @@ class LrnImageCompute : public KernelLite<TARGET(kOpenCL),
                     static_cast<cl::size_type>(default_work_size[1]),
                     static_cast<cl::size_type>(default_work_size[2])};
-    event_ = std::shared_ptr<cl::Event>(new cl::Event);
     status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
         kernel,
         cl::NullRange,
         global_work_size,
         cl::NullRange,
         nullptr,
-        event_.get());
+        nullptr);
     CL_CHECK_FATAL(status);
-    context.cl_wait_list()->emplace(out_img, event_);
 #ifndef LITE_SHUTDOWN_LOG
     VLOG(4) << "global_work_size:[2D]:" << global_work_size[0] << " "
             << global_work_size[1] << " " << global_work_size[2];
@@ -154,7 +152,6 @@ class LrnImageCompute : public KernelLite<TARGET(kOpenCL),
   std::string kernel_func_name_{"lrn"};
   std::string build_options_{"-DCL_DTYPE_half"};
   std::string time_stamp_{GetTimeStamp()};
-  std::shared_ptr<cl::Event> event_{nullptr};
 };

 }  // namespace opencl
......
@@ -181,19 +181,7 @@ TEST(lrn_image2d, compute) {
   // LOG(INFO) << "out_image:" << out_image;
   kernel->Launch();
-  auto* wait_list =
-      context->As<OpenCLContext>().cl_wait_list();
-  auto* out_ptr = param.Out->data<half_t, cl::Image2D>();
-  auto it = wait_list->find(out_ptr);
-  if (it != wait_list->end()) {
-    VLOG(4) << "--- Find the sync event for the target cl "
-               "tensor. ---";
-    auto& event = *(it->second);
-    event.wait();
-  } else {
-    LOG(FATAL) << "Could not find the sync event for the "
-                  "target cl tensor.";
-  }
+  CLRuntime::Global()->command_queue().finish();

   std::unique_ptr<float[]> out_ref(
       new float[out_dim.production()]);
......
@@ -91,16 +91,15 @@ class MulCompute
     auto global_work_size = cl::NDRange{static_cast<size_t>((m_ + 3) / 4),
                                         static_cast<size_t>((n_ + 3) / 4)};
-    event_ = std::shared_ptr<cl::Event>(new cl::Event);
     status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
         kernel,
         cl::NullRange,
         global_work_size,
         cl::NullRange,
         nullptr,
-        event_.get());
+        nullptr);
     CL_CHECK_FATAL(status);
-    context.cl_wait_list()->emplace(out_buf, event_);
   }

  private:
@@ -108,7 +107,6 @@ class MulCompute
   std::string kernel_func_name_{"mat_mul"};
   std::string build_options_{"-DCL_DTYPE_float"};
   std::string time_stamp_{GetTimeStamp()};
-  std::shared_ptr<cl::Event> event_{nullptr};
 };

 }  // namespace opencl
......
@@ -123,17 +123,7 @@ TEST(mul, compute) {
   // run opencl kernel
   kernel->Launch();
-  auto* wait_list = context->As<OpenCLContext>().cl_wait_list();
-  auto* out_ptr = param.output->data<float, cl::Buffer>();
-  auto it = wait_list->find(out_ptr);
-  if (it != wait_list->end()) {
-    VLOG(4) << "--- Find the sync event for the target cl tensor. ---";
-    auto& event = *(it->second);
-    event.wait();
-  } else {
-    LOG(FATAL)
-        << "Could not find the sync event for the target cl tensor.";
-  }
+  CLRuntime::Global()->command_queue().finish();

   // run cpu ref
   auto* out_ref_data = out_ref.mutable_data<float>(TARGET(kARM));
......
This diff has been collapsed.
@@ -73,6 +73,7 @@ bool BatchNormOp::InferShapeImpl() const {
 }

 bool BatchNormOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
+  AttachParam(&param_);
   param_.x = scope->FindVar(op_desc.Input("X").front())->GetMutable<Tensor>();
   param_.bias =
       scope->FindVar(op_desc.Input("Bias").front())->GetMutable<Tensor>();
......
@@ -66,6 +66,7 @@ bool ConcatOpLite::InferShapeImpl() const {
 // TODO(Superjomn) replace framework::OpDesc with a lite one.
 bool ConcatOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
+  AttachParam(&param_);
   auto inputs = op_desc.Input("X");
   auto out = op_desc.Output("Out").front();
......
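Both op hunks above insert `AttachParam(&param_)` as the first statement of `AttachImpl`. A plausible reading, assumed here rather than confirmed by this diff, is that the call registers the op's param struct with the op base class so the framework can hand the same struct to the kernel later. A self-contained sketch of that pattern (every name below, `OpBase`, `ConcatParam`, `Param`, is a hypothetical stand-in, not the Paddle-Lite API):

#include <cassert>

// Sketch: the base op stores a type-erased pointer to the param struct;
// kernels fetch it back with the concrete type they expect.
struct OpBase {
  void AttachParam(void* param) { param_ = param; }
  template <typename T>
  T* Param() { return static_cast<T*>(param_); }

 private:
  void* param_ = nullptr;
};

struct ConcatParam {
  int axis = 0;
};

struct ConcatOp : OpBase {
  bool AttachImpl() {
    AttachParam(&param_);  // register the storage before filling fields
    param_.axis = 1;       // field assignments can follow the registration
    return true;
  }
  ConcatParam param_;
};

int main() {
  ConcatOp op;
  op.AttachImpl();
  assert(op.Param<ConcatParam>()->axis == 1);  // kernel-side view
  return 0;
}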
The remaining 21 diffs have been collapsed.