diff --git a/lite/api/mobilenetv1_test.cc b/lite/api/mobilenetv1_test.cc index a10908c644dc659c86134a6bbbccf3e58a47df55..9164129dcf4566fc02803c1c7dcffd9e97a830d6 100644 --- a/lite/api/mobilenetv1_test.cc +++ b/lite/api/mobilenetv1_test.cc @@ -53,9 +53,13 @@ void TestModel(const std::vector& valid_places, predictor.Run(); } - auto start = GetCurrentUS(); + double sum_duration = 0.0; // millisecond; for (int i = 0; i < FLAGS_repeats; ++i) { + auto start = GetCurrentUS(); predictor.Run(); + auto duration = (GetCurrentUS() - start) / 1000.0; + sum_duration += duration; + VLOG(1) << "run_idx:" << i << " " << duration << " ms"; } if (save_model) { @@ -68,8 +72,7 @@ void TestModel(const std::vector& valid_places, LOG(INFO) << "================== Speed Report ==================="; LOG(INFO) << "Model: " << model_dir << ", threads num " << FLAGS_threads << ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats - << ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0 - << " ms in average."; + << ", spend " << sum_duration / FLAGS_repeats << " ms in average."; std::vector> ref; ref.emplace_back(std::vector( @@ -115,13 +118,11 @@ void TestModel(const std::vector& valid_places, } // Get detailed result - auto* pred = &predictor; - size_t output_tensor_num = pred->GetOutputNames().size(); + size_t output_tensor_num = predictor.GetOutputNames().size(); VLOG(1) << "output tesnor num:" << output_tensor_num; for (size_t tidx = 0; tidx < output_tensor_num; ++tidx) { - std::unique_ptr output_tensor( - std::move(pred->GetOutput(tidx))); + auto* output_tensor = predictor.GetOutput(tidx); VLOG(1) << "============= output tensor " << tidx << " =============\n"; auto out_dims = output_tensor->dims(); VLOG(1) << "out_dims:" << out_dims; diff --git a/lite/api/mobilenetv2_test.cc b/lite/api/mobilenetv2_test.cc index dfb7045d853de9e0c468420bb844ed200f30f70f..26b9dc93da73e8f637c01fca8f7ea99a8e5e9af0 100644 --- a/lite/api/mobilenetv2_test.cc +++ b/lite/api/mobilenetv2_test.cc @@ -54,9 +54,13 @@ void TestModel(const std::vector& valid_places, predictor.Run(); } - auto start = GetCurrentUS(); + double sum_duration = 0.0; // millisecond; for (int i = 0; i < FLAGS_repeats; ++i) { + auto start = GetCurrentUS(); predictor.Run(); + auto duration = (GetCurrentUS() - start) / 1000.0; + sum_duration += duration; + VLOG(1) << "run_idx:" << i << " " << duration << " ms"; } if (save_model) { @@ -69,8 +73,7 @@ void TestModel(const std::vector& valid_places, LOG(INFO) << "================== Speed Report ==================="; LOG(INFO) << "Model: " << model_dir << ", threads num " << FLAGS_threads << ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats - << ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0 - << " ms in average."; + << ", spend " << sum_duration / FLAGS_repeats << " ms in average."; std::vector> ref; // i = 1 @@ -117,13 +120,11 @@ void TestModel(const std::vector& valid_places, } // Get detailed result - auto* pred = &predictor; - size_t output_tensor_num = pred->GetOutputNames().size(); + size_t output_tensor_num = predictor.GetOutputNames().size(); VLOG(1) << "output tesnor num:" << output_tensor_num; for (size_t tidx = 0; tidx < output_tensor_num; ++tidx) { - std::unique_ptr output_tensor( - std::move(pred->GetOutput(tidx))); + auto* output_tensor = predictor.GetOutput(tidx); VLOG(1) << "============= output tensor " << tidx << " =============\n"; auto out_dims = output_tensor->dims(); VLOG(1) << "out_dims:" << out_dims; diff --git a/lite/backends/opencl/cl_kernel/buffer/layout_kernel.cl b/lite/backends/opencl/cl_kernel/image/layout_kernel.cl similarity index 100% rename from lite/backends/opencl/cl_kernel/buffer/layout_kernel.cl rename to lite/backends/opencl/cl_kernel/image/layout_kernel.cl diff --git a/lite/kernels/opencl/CMakeLists.txt b/lite/kernels/opencl/CMakeLists.txt index f7d3fae8785e4e0fc85476f6a13cba1cc99d7c4b..c1a2afdabb2b28ada4c11e49ef0bd8e27008b184 100644 --- a/lite/kernels/opencl/CMakeLists.txt +++ b/lite/kernels/opencl/CMakeLists.txt @@ -18,7 +18,7 @@ add_kernel(pool_opencl OPENCL basic SRCS pool_image_compute.cc DEPS ${cl_kernel_ add_kernel(activation_opencl OPENCL basic SRCS activation_image_compute.cc DEPS ${cl_kernel_deps}) add_kernel(reshape_opencl OPENCL basic SRCS reshape_image_compute.cc DEPS ${cl_kernel_deps}) add_kernel(conv_opencl OPENCL basic SRCS conv_image_compute.cc DEPS ${cl_kernel_deps}) -add_kernel(layout_opencl OPENCL basic SRCS layout_compute.cc DEPS ${cl_kernel_deps}) +add_kernel(layout_opencl OPENCL basic SRCS layout_image_compute.cc DEPS ${cl_kernel_deps}) add_kernel(concat_opencl OPENCL basic SRCS concat_image_compute.cc DEPS ${cl_kernel_deps}) add_kernel(nearest_interp_opencl OPENCL basic SRCS nearest_interp_image_compute.cc DEPS ${cl_kernel_deps}) add_kernel(scale_opencl OPENCL basic SRCS scale_image_compute.cc DEPS ${cl_kernel_deps}) @@ -68,7 +68,7 @@ lite_cc_test(test_elementwise_mul_image_opencl SRCS elementwise_mul_image_comput DEPS elementwise_mul_opencl op_registry program context ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) -lite_cc_test(test_layout_opencl SRCS layout_compute_test.cc +lite_cc_test(test_layout_image_opencl SRCS layout_image_compute_test.cc DEPS layout_opencl op_registry program context ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) diff --git a/lite/kernels/opencl/activation_image_compute.cc b/lite/kernels/opencl/activation_image_compute.cc index eecbd56afbd24fe138f175a398154c0b1f4d876e..e1b863e17f7225e71ebd6c68b326c7cae475ad1a 100644 --- a/lite/kernels/opencl/activation_image_compute.cc +++ b/lite/kernels/opencl/activation_image_compute.cc @@ -44,9 +44,9 @@ class ReluComputeImageDefault : public KernelLite(); const auto& x_dims = param.X->dims(); - auto* x_buf = param.X->data(); + auto* x_img = param.X->data(); auto image_shape = InitImageDimInfoWith(x_dims); - auto* out_buf = param.Out->mutable_data( + auto* out_img = param.Out->mutable_data( image_shape["width"], image_shape["height"]); const auto& y_dims = param.Out->dims(); // useless: check dim only @@ -57,9 +57,9 @@ class ReluComputeImageDefault : public KernelLiteGetKernel(kernel_key.str()); int arg_idx = 0; - cl_int status = kernel.setArg(arg_idx, *x_buf); + cl_int status = kernel.setArg(arg_idx, *x_img); CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *out_buf); + status = kernel.setArg(++arg_idx, *out_img); CL_CHECK_FATAL(status); VLOG(4) << TargetToStr(param.X->target()); @@ -82,9 +82,7 @@ class ReluComputeImageDefault : public KernelLitehost) jammed if emplace to `cl_wait_list` - // context.cl_wait_list()->emplace(out_buf, event_); - context.cl_context()->GetCommandQueue().finish(); + context.cl_wait_list()->emplace(out_img, event_); } private: @@ -112,9 +110,9 @@ class Relu6ComputeImageDefault : public KernelLite(); const auto& x_dims = param.X->dims(); - auto* x_buf = param.X->data(); + auto* x_img = param.X->data(); auto image_shape = InitImageDimInfoWith(x_dims); - auto* out_buf = param.Out->mutable_data( + auto* out_img = param.Out->mutable_data( image_shape["width"], image_shape["height"]); const auto& y_dims = param.Out->dims(); // useless: check dim only auto threshold = param.Relu_clipped_coef; @@ -126,9 +124,9 @@ class Relu6ComputeImageDefault : public KernelLiteGetKernel(kernel_key.str()); int arg_idx = 0; - cl_int status = kernel.setArg(arg_idx, *x_buf); + cl_int status = kernel.setArg(arg_idx, *x_img); CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *out_buf); + status = kernel.setArg(++arg_idx, *out_img); CL_CHECK_FATAL(status); status = kernel.setArg(++arg_idx, threshold); CL_CHECK_FATAL(status); @@ -154,9 +152,7 @@ class Relu6ComputeImageDefault : public KernelLitehost) jammed if emplace to `cl_wait_list` - // context.cl_wait_list()->emplace(out_buf, event_); - context.cl_context()->GetCommandQueue().finish(); + context.cl_wait_list()->emplace(out_img, event_); } private: @@ -185,11 +181,11 @@ class SigmoidComputeImageDefault void Run() override { auto& param = *param_.get_mutable(); const auto& x_dims = param.X->dims(); - auto* x_buf = + auto* x_img = param.X->data(); // use half_t represents half float auto image_shape = InitImageDimInfoWith(x_dims); - auto* out_buf = param.Out->mutable_data( // use half_t + auto* out_img = param.Out->mutable_data( // use half_t // represents half float image_shape["width"], image_shape["height"]); @@ -202,9 +198,9 @@ class SigmoidComputeImageDefault auto kernel = context.cl_context()->GetKernel(kernel_key.str()); int arg_idx = 0; - cl_int status = kernel.setArg(arg_idx, *x_buf); + cl_int status = kernel.setArg(arg_idx, *x_img); CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *out_buf); + status = kernel.setArg(++arg_idx, *out_img); CL_CHECK_FATAL(status); VLOG(4) << TargetToStr(param.X->target()); @@ -227,9 +223,7 @@ class SigmoidComputeImageDefault nullptr, event_.get()); CL_CHECK_FATAL(status); - // TODO(ysh329): io_copy(device->host) jammed if emplace to `cl_wait_list` - // context.cl_wait_list()->emplace(out_buf, event_); - context.cl_context()->GetCommandQueue().finish(); + context.cl_wait_list()->emplace(out_img, event_); } private: diff --git a/lite/kernels/opencl/activation_image_compute_test.cc b/lite/kernels/opencl/activation_image_compute_test.cc index 09f48eb86f484ecb7073d9fe23dffd770f6cf687..e3f1084d216d6abe8a66bef9cf090c3e63c16487 100644 --- a/lite/kernels/opencl/activation_image_compute_test.cc +++ b/lite/kernels/opencl/activation_image_compute_test.cc @@ -18,6 +18,9 @@ #include "lite/core/op_registry.h" #include "lite/core/tensor.h" #include "lite/kernels/opencl/image_helper.h" +#include "lite/kernels/opencl/test_helper.h" + +#define FP16_MAX_DIFF (1e0) namespace paddle { namespace lite { @@ -58,8 +61,8 @@ TEST(relu_image2d_fp16, compute) { "-> host"; #ifdef RELU_FP16_LOOP_TEST - for (int n = 1; n <= 100; n += 33) { - for (auto c : {1, 3}) { + for (int n = 1; n <= 2; n += 1) { + for (auto c : {1}) { for (int h = 12; h <= 100; h += 13) { for (int w = 12; w <= 100; w += 25) { #else @@ -169,6 +172,21 @@ TEST(relu_image2d_fp16, compute) { LOG(INFO) << "run kernel: img_to_buf_kernel"; img_to_buf_kernel->Launch(); + // wait for opencl + auto *wait_list = context->As().cl_wait_list(); + auto *out_ptr = ImageToBufferParam.y->data(); + auto it = wait_list->find(out_ptr); + + if (it != wait_list->end()) { + VLOG(4) << "--- Find the sync event for the target cl " + "tensor. ---"; + auto &event = *(it->second); + event.wait(); + } else { + LOG(FATAL) << "Could not find the sync event for the target " + "cl tensor."; + } + // compute ref cpu relu_compute_ref(mapped_x, x_dim, y_data_ref); // result @@ -176,18 +194,24 @@ TEST(relu_image2d_fp16, compute) { LOG(INFO) << "---- print kernel result (input -> output) ----"; for (int eidx = 0; eidx < x_dim.production(); ++eidx) { std::cout << mapped_x[eidx] << " -> " << mapped_y[eidx] - << std::endl; + << ", ref: " << y_data_ref[eidx] << std::endl; } #endif // RELU_FP16_PRINT_RESULT // check result: compare kernel output and cpu output(y_data_ref) - for (int eidx = 0; eidx < x_dim.production(); eidx++) { - EXPECT_NEAR(y_data_ref[eidx], mapped_y[eidx], 1e-6); - if (abs(y_data_ref[eidx] - mapped_y[eidx]) > 1e-6) { - LOG(INFO) << "1st diff in this case at eidx[from 0]:" << eidx - << " / " << x_dim.production() << ", y_data_ref[" - << eidx << "]:" << y_data_ref[eidx] << ", mapped_y[" - << eidx << "]:" << mapped_y[eidx]; + for (int eidx = 0; eidx < x_dim.production(); ++eidx) { + auto abs_diff = COMPUTE_ABS_DIFF(y_data_ref[eidx], mapped_y[eidx]); + auto relative_diff = + COMPUTE_RELATIVE_DIFF(y_data_ref[eidx], mapped_y[eidx]); + EXPECT_EQ( + (relative_diff <= FP16_MAX_DIFF) || (abs_diff <= FP16_MAX_DIFF), + true); + if ((relative_diff > FP16_MAX_DIFF) && (abs_diff > FP16_MAX_DIFF)) { + LOG(ERROR) << "error idx:" << eidx << ", y_data_ref[" << eidx + << "]:" << y_data_ref[eidx] << ", mapped_y[" << eidx + << "]:" << mapped_y[eidx] << " abs_diff:" << abs_diff + << " relative_diff:" << relative_diff + << " FP16_MAX_DIFF:" << FP16_MAX_DIFF; break; } } @@ -206,7 +230,7 @@ TEST(relu_image2d_fp16, compute) { #endif } -// #define RELU6_FP16_LOOP_TEST +// #define RELU6_FP16_LOOP_TEST // #define RELU6_FP16_PRINT_RESULT TEST(relu6_image2d_fp16, compute) { LOG(INFO) << "main steps of test: host -> layout(buf2img) -> relu6(img) -> " @@ -287,7 +311,7 @@ TEST(relu6_image2d_fp16, compute) { auto *mapped_y = static_cast(TargetWrapperCL::Map( y_data, 0, sizeof(float) * x_dim.production())); for (int i = 0; i < x_dim.production(); ++i) { - mapped_x[i] = static_cast(i) - x_dim.production() / 2; + mapped_x[i] = static_cast(i) - x_dim.production() / 2 * 0.1; mapped_y[i] = static_cast(0); } auto *relu_in_data = relu_in.mutable_data( @@ -326,6 +350,21 @@ TEST(relu6_image2d_fp16, compute) { LOG(INFO) << "run kernel: img_to_buf_kernel"; img_to_buf_kernel->Launch(); + // wait for opencl + auto *wait_list = context->As().cl_wait_list(); + auto *out_ptr = ImageToBufferParam.y->data(); + auto it = wait_list->find(out_ptr); + + if (it != wait_list->end()) { + VLOG(4) << "--- Find the sync event for the target cl " + "tensor. ---"; + auto &event = *(it->second); + event.wait(); + } else { + LOG(FATAL) << "Could not find the sync event for the target " + "cl tensor."; + } + // compute ref cpu relu_compute_ref(mapped_x, x_dim, y_data_ref, 6.f); // result @@ -333,14 +372,14 @@ TEST(relu6_image2d_fp16, compute) { LOG(INFO) << "---- print kernel result (input -> output) ----"; for (int eidx = 0; eidx < x_dim.production(); ++eidx) { std::cout << mapped_x[eidx] << " -> " << mapped_y[eidx] - << std::endl; + << ", ref: " << y_data_ref[eidx] << std::endl; } #endif // RELU6_FP16_PRINT_RESULT // check result: compare kernel output and cpu output(y_data_ref) for (int eidx = 0; eidx < x_dim.production(); eidx++) { - EXPECT_NEAR(y_data_ref[eidx], mapped_y[eidx], 1e-6); - if (abs(y_data_ref[eidx] - mapped_y[eidx]) > 1e-6) { + EXPECT_NEAR(y_data_ref[eidx], mapped_y[eidx], FP16_MAX_DIFF); + if (abs(y_data_ref[eidx] - mapped_y[eidx]) > FP16_MAX_DIFF) { LOG(INFO) << "1st diff in this case at eidx[from 0]:" << eidx << " / " << x_dim.production() << ", y_data_ref[" << eidx << "]:" << y_data_ref[eidx] << ", mapped_y[" @@ -485,6 +524,21 @@ TEST(sigmoid_image2d_fp16, compute) { LOG(INFO) << "run kernel: img_to_buf_kernel"; img_to_buf_kernel->Launch(); + // wait for opencl + auto *wait_list = context->As().cl_wait_list(); + auto *out_ptr = ImageToBufferParam.y->data(); + auto it = wait_list->find(out_ptr); + + if (it != wait_list->end()) { + VLOG(4) << "--- Find the sync event for the target cl " + "tensor. ---"; + auto &event = *(it->second); + event.wait(); + } else { + LOG(FATAL) << "Could not find the sync event for the target " + "cl tensor."; + } + // compute ref cpu sigmoid_compute_ref(mapped_x, x_dim, y_data_ref); // result @@ -492,14 +546,14 @@ TEST(sigmoid_image2d_fp16, compute) { LOG(INFO) << "---- print kernel result (input -> output) ----"; for (int eidx = 0; eidx < x_dim.production(); ++eidx) { std::cout << mapped_x[eidx] << " -> " << mapped_y[eidx] - << std::endl; + << ", ref:" << y_data_ref[eidx] << std::endl; } #endif // SIGMOID_FP16_PRINT_RESULT // check result: compare kernel output and cpu output(y_data_ref) for (int eidx = 0; eidx < x_dim.production(); eidx++) { - EXPECT_NEAR(y_data_ref[eidx], mapped_y[eidx], 1e-3); - if (abs(y_data_ref[eidx] - mapped_y[eidx]) > 1e-3) { + EXPECT_NEAR(y_data_ref[eidx], mapped_y[eidx], FP16_MAX_DIFF); + if (abs(y_data_ref[eidx] - mapped_y[eidx]) > FP16_MAX_DIFF) { LOG(INFO) << "1st diff in this case at eidx[from 0]:" << eidx << " / " << x_dim.production() << ", y_data_ref[" << eidx << "]: " << y_data_ref[eidx] << ", mapped_y[" diff --git a/lite/kernels/opencl/concat_image_compute.cc b/lite/kernels/opencl/concat_image_compute.cc index ce1a5e6ceab677ceaf0a3cd8e86a608507b54d54..a9b66e37ebc0d497ed65ea1dfe7aaa63b4c2aabe 100644 --- a/lite/kernels/opencl/concat_image_compute.cc +++ b/lite/kernels/opencl/concat_image_compute.cc @@ -109,25 +109,28 @@ class ConcatComputeImage : public KernelLitedims()[inputs[0]->dims().size() - 1]; - LOG(INFO) << "concat 输入尺寸: "; + VLOG(4) << "concat 输入尺寸: "; for (size_t i = 0; i < inputs.size(); i++) { - LOG(INFO) << "inputs [" << i << "]" - << "[" << inputs[i]->dims().size() << "D]:" - << " dims:" << inputs[i]->dims()[0] << " " - << inputs[i]->dims()[1] << " " << inputs[i]->dims()[2] << " " - << inputs[i]->dims()[3]; + VLOG(4) << "inputs [" << i << "]" + << "[" << inputs[i]->dims().size() << "D]:" + << " dims:" << inputs[i]->dims()[0] << " " + << inputs[i]->dims()[1] << " " << inputs[i]->dims()[2] << " " + << inputs[i]->dims()[3]; } - LOG(INFO) << "concat 输出尺寸: "; - LOG(INFO) << " out dims: " - << "[" << x_dims.size() << "D]:" << x_dims[0] << " " << x_dims[1] - << " " << x_dims[2] << " " << x_dims[3]; - LOG(INFO) << "axis_: " << axis_; - LOG(INFO) << "flag_: " << flag_; + + VLOG(4) << "concat 输出尺寸: "; + VLOG(4) << " out dims: " + << "[" << x_dims.size() << "D]:" << x_dims[0] << " " << x_dims[1] + << " " << x_dims[2] << " " << x_dims[3]; + VLOG(4) << "axis_: " << axis_; + VLOG(4) << "flag_: " << flag_; + auto global_work_size = cl::NDRange{static_cast(x_dims[x_dims.size() - 1]), static_cast(image_shape["width"] / x_dims[x_dims.size() - 1]), static_cast(image_shape["height"])}; + VLOG(4) << TargetToStr(param.output->target()); VLOG(4) << "image_shape(w,h):" << image_shape["width"] << " " << image_shape["height"]; @@ -136,16 +139,17 @@ class ConcatComputeImage : public KernelLiteGetKernel(kernel_key.str()); int out_w = x_dims[x_dims.size() - 1]; int out_c = x_dims[1]; if (inputs.size() == 2) { - auto* x_buf0 = inputs[0]->data(); - auto* x_buf1 = inputs[1]->data(); + auto* x_buf0 = inputs[0]->data(); + auto* x_buf1 = inputs[1]->data(); cl_int status = kernel.setArg(arg_idx, *x_buf0); CL_CHECK_FATAL(status); status = kernel.setArg(++arg_idx, *x_buf1); @@ -171,14 +175,14 @@ class ConcatComputeImage : public KernelLiteGetCommandQueue().finish(); + context.cl_wait_list()->emplace(out_buf, event_); } else { auto start = 0; for (int i = 0; i < inputs.size(); i++) { arg_idx = 0; auto in_dims = inputs[i]->dims(); image_shape = InitImageDimInfoWith(in_dims); - auto* x_buf = inputs[i]->data(); + auto* x_buf = inputs[i]->data(); int in_w = in_dims[in_dims.size() - 1]; VLOG(4) << "image_shape(w,h):" << image_shape["width"] << " " << image_shape["height"]; @@ -212,7 +216,7 @@ class ConcatComputeImage : public KernelLiteGetCommandQueue().finish(); + context.cl_wait_list()->emplace(out_buf, event_); start += inputs[i]->dims()[axis_]; } } diff --git a/lite/kernels/opencl/concat_image_compute_test.cc b/lite/kernels/opencl/concat_image_compute_test.cc index 43b37d9b74aa6b37ca10a140b8485345bff46571..38958acfbccecdf1d8e96a2d571e0804e172d049 100644 --- a/lite/kernels/opencl/concat_image_compute_test.cc +++ b/lite/kernels/opencl/concat_image_compute_test.cc @@ -245,6 +245,21 @@ TEST(concat_image2d, compute) { LOG(INFO) << "run kernel: img_to_buf_kernel"; img_to_buf_kernel->Launch(); + // wait for opencl + auto *wait_list = context->As().cl_wait_list(); + auto *out_ptr = ImageToBufferParam.y->data(); + auto it = wait_list->find(out_ptr); + + if (it != wait_list->end()) { + VLOG(4) << "--- Find the sync event for the target cl " + "tensor. ---"; + auto &event = *(it->second); + event.wait(); + } else { + LOG(FATAL) << "Could not find the sync event for the target " + "cl tensor."; + } + // compute ref cp_u std::vector ins_ptr; std::vector in_dim; diff --git a/lite/kernels/opencl/conv_image_compute_test.cc b/lite/kernels/opencl/conv_image_compute_test.cc index 94c0f96ab958ba053da3a922a91e414c5c2f6203..0d76ef11eef0f7f784354d841c116e0adb19d306 100644 --- a/lite/kernels/opencl/conv_image_compute_test.cc +++ b/lite/kernels/opencl/conv_image_compute_test.cc @@ -471,7 +471,7 @@ TEST(conv2d, compute_image2d_1x1) { for (int i = 0; i < out_dim.production(); i++) { auto relative_diff = COMPUTE_RELATIVE_DIFF(output_v[i], out_ref_data[i]); - auto abs_diff = COMPTUE_ABS_DIFF(output_v[i], out_ref_data[i]); + auto abs_diff = COMPUTE_ABS_DIFF(output_v[i], out_ref_data[i]); // EXPECT_LT(relative_diff, FP16_MAX_DIFF); EXPECT_FALSE(relative_diff > FP16_MAX_DIFF && abs_diff > FP16_ABS_DIFF); @@ -1191,7 +1191,7 @@ TEST(conv2d, compute_image2d_5x5) { for (int i = 0; i < out_dim.production(); i++) { auto relative_diff = COMPUTE_RELATIVE_DIFF(output_v[i], out_ref_data[i]); - auto abs_diff = COMPTUE_ABS_DIFF(output_v[i], out_ref_data[i]); + auto abs_diff = COMPUTE_ABS_DIFF(output_v[i], out_ref_data[i]); EXPECT_FALSE(relative_diff > FP16_MAX_DIFF && abs_diff > FP16_ABS_DIFF); if (relative_diff > FP16_MAX_DIFF && abs_diff > FP16_ABS_DIFF) { @@ -1540,7 +1540,7 @@ TEST(conv2d, compute_image2d_7x7) { for (int i = 0; i < out_dim.production(); i++) { auto relative_diff = COMPUTE_RELATIVE_DIFF(output_v[i], out_ref_data[i]); - auto abs_diff = COMPTUE_ABS_DIFF(output_v[i], out_ref_data[i]); + auto abs_diff = COMPUTE_ABS_DIFF(output_v[i], out_ref_data[i]); EXPECT_FALSE(relative_diff > FP16_MAX_DIFF && abs_diff > FP16_ABS_DIFF); if (relative_diff > FP16_MAX_DIFF && abs_diff > FP16_ABS_DIFF) { diff --git a/lite/kernels/opencl/fc_buffer_compute.cc b/lite/kernels/opencl/fc_buffer_compute.cc index 1f8ba6ae2f603ba02e4025e63158249a49dbc815..642ae145dfe415eba50d0a1d7e2e5683d5308e2b 100644 --- a/lite/kernels/opencl/fc_buffer_compute.cc +++ b/lite/kernels/opencl/fc_buffer_compute.cc @@ -57,6 +57,10 @@ class FcCompute global_work_size_ = cl::NDRange{static_cast((m_ + 3) / 4), static_cast((n_ + 3) / 4)}; } + + if (param.activation_type == "relu") { + build_options_ += "-DRELU"; + } auto& context = ctx_->As(); context.cl_context()->AddKernel( kernel_func_name_, "buffer/fc_kernel.cl", build_options_); @@ -107,7 +111,7 @@ class FcCompute private: int m_, n_, k_; std::string kernel_func_name_{}; - std::string build_options_{"-DCL_DTYPE=float"}; + std::string build_options_{"-DCL_DTYPE_float "}; cl::NDRange global_work_size_; std::shared_ptr event_{new cl::Event}; }; diff --git a/lite/kernels/opencl/io_copy_buffer_compute.cc b/lite/kernels/opencl/io_copy_buffer_compute.cc index 3387a0887d3422636e39e742149f84672e8e75d4..0148e6b143ebae1d4c2ca60e6f40f8a0228a2979 100644 --- a/lite/kernels/opencl/io_copy_buffer_compute.cc +++ b/lite/kernels/opencl/io_copy_buffer_compute.cc @@ -103,9 +103,6 @@ class IoCopykOpenCLToHostCompute auto* wait_list = context.cl_wait_list(); auto* x_ptr = param.x->data(); - /* TODO(ysh329): io_copy(device->host) jammed if `it` emplaced to - `cl_wait_list` - in kernel and `wait_list` enabled auto it = wait_list->find(x_ptr); if (it != wait_list->end()) { VLOG(4) << "--- Find the sync event for the target cl tensor. ---"; @@ -114,7 +111,6 @@ class IoCopykOpenCLToHostCompute } else { LOG(FATAL) << "Could not find the sync event for the target cl tensor."; } - */ CopyToHostSync(data, param.x->raw_data(), mem_size); } diff --git a/lite/kernels/opencl/layout_compute.cc b/lite/kernels/opencl/layout_image_compute.cc similarity index 94% rename from lite/kernels/opencl/layout_compute.cc rename to lite/kernels/opencl/layout_image_compute.cc index 2a82aec526eba13515986c90ace790d6efed21ad..fad37aa709441f95f2aadc535e7eb5db895765f4 100644 --- a/lite/kernels/opencl/layout_compute.cc +++ b/lite/kernels/opencl/layout_image_compute.cc @@ -44,7 +44,7 @@ class LayoutComputeBufferChwToImageDefault } auto& context = ctx_->As(); context.cl_context()->AddKernel( - kernel_func_name_, "buffer/layout_kernel.cl", build_options_); + kernel_func_name_, "image/layout_kernel.cl", build_options_); } void Run() override { @@ -126,9 +126,7 @@ class LayoutComputeBufferChwToImageDefault nullptr, event_.get()); CL_CHECK_FATAL(status); - // TODO(ysh329): io_copy(device->host) jammed if emplace to `cl_wait_list` - // context.cl_wait_list()->emplace(y_data, event_); - context.cl_context()->GetCommandQueue().finish(); + context.cl_wait_list()->emplace(y_data, event_); } std::string doc() const override { @@ -155,7 +153,7 @@ class LayoutComputeImageDefaultToBufferChw } auto& context = ctx_->As(); context.cl_context()->AddKernel( - kernel_func_name_, "buffer/layout_kernel.cl", build_options_); + kernel_func_name_, "image/layout_kernel.cl", build_options_); } void Run() override { @@ -229,9 +227,7 @@ class LayoutComputeImageDefaultToBufferChw nullptr, event_.get()); CL_CHECK_FATAL(status); - // TODO(ysh329): io_copy(device->host) jammed if emplace to `cl_wait_list` - // context.cl_wait_list()->emplace(y_data, event_); - context.cl_context()->GetCommandQueue().finish(); + context.cl_wait_list()->emplace(y_data, event_); } std::string doc() const override { @@ -325,10 +321,7 @@ class LayoutComputeBufferChwToImage2DNw nullptr, event_.get()); CL_CHECK_FATAL(status); - // TODO(ysh329): io_copy(device->host) jammed if emplace to `cl_wait_list` - // context.cl_wait_list()->emplace(y_data, event_); - context.cl_context()->GetCommandQueue().finish(); - // auto image_shape = InitImageDimInfoWith(x_dims); + context.cl_wait_list()->emplace(y_data, event_); } std::string doc() const override { diff --git a/lite/kernels/opencl/layout_compute_test.cc b/lite/kernels/opencl/layout_image_compute_test.cc similarity index 90% rename from lite/kernels/opencl/layout_compute_test.cc rename to lite/kernels/opencl/layout_image_compute_test.cc index a523c896fa893a32090c7c153879d00b2d4e5513..9cdfbe0a1d64176db9dfd2698ab3ab0631a4b118 100644 --- a/lite/kernels/opencl/layout_compute_test.cc +++ b/lite/kernels/opencl/layout_image_compute_test.cc @@ -18,6 +18,9 @@ #include "lite/core/op_registry.h" #include "lite/core/tensor.h" #include "lite/kernels/opencl/image_helper.h" +#include "lite/kernels/opencl/test_helper.h" + +#define FP16_MAX_DIFF (1e0) namespace paddle { namespace lite { @@ -86,7 +89,7 @@ TEST(layout_ImageDefault, compute) { auto* mapped_y = static_cast(TargetWrapperCL::Map( y_data, 0, sizeof(float) * x_dim.production())); for (int i = 0; i < x_dim.production(); ++i) { - mapped_x[i] = static_cast(i) * 2; + mapped_x[i] = static_cast(i) * 0.01; } // set context and kernel args @@ -122,14 +125,19 @@ TEST(layout_ImageDefault, compute) { #endif // PRINT_RESULT // check result: compare input and output - float MAX_PASS_DIFF = 1e-4; - for (int eidx = 0; eidx < x_dim.production(); eidx++) { - EXPECT_NEAR(mapped_x[eidx], mapped_y[eidx], MAX_PASS_DIFF); - if (abs(mapped_x[eidx] - mapped_y[eidx]) > MAX_PASS_DIFF) { - LOG(INFO) << "1st diff in this case at eidx[from 0]:" << eidx - << " / " << x_dim.production() << ", mapped_x[" << eidx - << "]:" << mapped_x[eidx] << ", mapped_y[" << eidx - << "]:" << mapped_y[eidx]; + for (int i = 0; i < x_dim.production(); i++) { + auto abs_diff = COMPUTE_ABS_DIFF(mapped_x[i], mapped_y[i]); + auto relative_diff = + COMPUTE_RELATIVE_DIFF(mapped_x[i], mapped_y[i]); + EXPECT_EQ( + (relative_diff <= FP16_MAX_DIFF) || (abs_diff <= FP16_MAX_DIFF), + true); + if ((relative_diff > FP16_MAX_DIFF) && (abs_diff > FP16_MAX_DIFF)) { + LOG(ERROR) << "error idx:" << i << " mapped_x[" << i + << "]:" << mapped_x[i] << " mapped_y[" << i + << "]:" << mapped_y[i] << " abs_diff:" << abs_diff + << " relative_diff:" << relative_diff + << " FP16_MAX_DIFF:" << FP16_MAX_DIFF; break; } } @@ -238,12 +246,27 @@ TEST(layout_ImageDefault_With_Pre_Post, compute) { LOG(INFO) << "run kernel: image2d_to_buffer_with_post255"; img_to_buf_kernel->Launch(); + // wait for opencl + auto* wait_list = context->As().cl_wait_list(); + auto* out_ptr = ImageToBufferParam.y->data(); + auto it = wait_list->find(out_ptr); + + if (it != wait_list->end()) { + VLOG(4) << "--- Find the sync event for the target cl " + "tensor. ---"; + auto& event = *(it->second); + event.wait(); + } else { + LOG(FATAL) << "Could not find the sync event for the target " + "cl tensor."; + } + // result #ifdef PRINT_RESULT LOG(INFO) << "---- print result ----"; for (int eidx = 0; eidx < x_dim.production(); ++eidx) { - std::cout << mapped_x[eidx] << " -> " - << static_cast(mapped_y[eidx]) << std::endl; + std::cout << +mapped_x[eidx] << " -> " + << +static_cast(mapped_y[eidx]) << std::endl; } #endif // PRINT_RESULT diff --git a/lite/kernels/opencl/nearest_interp_image_compute.cc b/lite/kernels/opencl/nearest_interp_image_compute.cc index c22e38a8c2c57f158afc9a1e68524b269cd4cc6e..5f9aa252c5b5ba4fd7730b28e5c295711a2e7f64 100644 --- a/lite/kernels/opencl/nearest_interp_image_compute.cc +++ b/lite/kernels/opencl/nearest_interp_image_compute.cc @@ -46,11 +46,11 @@ class NearestInterpComputeImageDefault auto& param = *param_.get_mutable(); const auto& x_dims = param.X->dims(); const auto& y_dims = param.Out->dims(); - auto* x_buf = + auto* x_img = param.X->data(); // use half_t represents half float auto out_image_shape = InitImageDimInfoWith(y_dims); - auto* out_buf = param.Out->mutable_data( // use half_t + auto* out_img = param.Out->mutable_data( // use half_t // represents half float out_image_shape["width"], out_image_shape["height"]); @@ -69,9 +69,9 @@ class NearestInterpComputeImageDefault auto kernel = context.cl_context()->GetKernel(kernel_key.str()); int arg_idx = 0; - cl_int status = kernel.setArg(arg_idx, *x_buf); + cl_int status = kernel.setArg(arg_idx, *x_img); CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *out_buf); + status = kernel.setArg(++arg_idx, *out_img); CL_CHECK_FATAL(status); status = kernel.setArg(++arg_idx, static_cast(scale_h)); CL_CHECK_FATAL(status); @@ -112,9 +112,7 @@ class NearestInterpComputeImageDefault nullptr, event_.get()); CL_CHECK_FATAL(status); - // TODO(ysh329): io_copy(device->host) jammed if emplace to `cl_wait_list` - // context.cl_wait_list()->emplace(out_buf, event_); - context.cl_context()->GetCommandQueue().finish(); + context.cl_wait_list()->emplace(out_img, event_); } private: diff --git a/lite/kernels/opencl/nearest_interp_image_compute_test.cc b/lite/kernels/opencl/nearest_interp_image_compute_test.cc index 37389d7a3d72fa7ab79c9e63d034c1ad06c80a51..a91e853a865f6abb2536606be6628e860cf7d6b9 100644 --- a/lite/kernels/opencl/nearest_interp_image_compute_test.cc +++ b/lite/kernels/opencl/nearest_interp_image_compute_test.cc @@ -208,6 +208,21 @@ TEST(nearest_interp_image2d, compute) { LOG(INFO) << "run kernel: img_to_buf_kernel"; img_to_buf_kernel->Launch(); + // wait for opencl + auto *wait_list = context->As().cl_wait_list(); + auto *out_ptr = ImageToBufferParam.y->data(); + auto it = wait_list->find(out_ptr); + + if (it != wait_list->end()) { + VLOG(4) << "--- Find the sync event for the target cl " + "tensor. ---"; + auto &event = *(it->second); + event.wait(); + } else { + LOG(FATAL) << "Could not find the sync event for the target " + "cl tensor."; + } + // compute ref cpu for (int nid = 0; nid < x_dim[0]; ++nid) { for (int cid = 0; cid < x_dim[1]; ++cid) { diff --git a/lite/kernels/opencl/pool_image_compute.cc b/lite/kernels/opencl/pool_image_compute.cc index adfa57f15b35a8999bff5ae0c3f938f32d0709e8..f2c35b8ddb057a77d4a056d2e856ad6d86f112aa 100644 --- a/lite/kernels/opencl/pool_image_compute.cc +++ b/lite/kernels/opencl/pool_image_compute.cc @@ -69,14 +69,14 @@ class PoolComputeImage2D : public KernelLitedata(); - LOG(INFO) << "x_image" << x_img; + VLOG(4) << "x_image" << x_img; auto out_image_shape = InitImageDimInfoWith(out_dims); - LOG(INFO) << "out_image_shape = " << out_image_shape["width"] << " " - << out_image_shape["height"]; + VLOG(4) << "out_image_shape = " << out_image_shape["width"] << " " + << out_image_shape["height"]; auto* out_img = param.output->mutable_data( out_image_shape["width"], out_image_shape["height"]); - LOG(INFO) << "out_image" << out_img; + VLOG(4) << "out_image" << out_img; STL::stringstream kernel_key; kernel_key << kernel_func_name_ << build_options_; diff --git a/lite/kernels/opencl/reshape_image_compute.cc b/lite/kernels/opencl/reshape_image_compute.cc index 4b50cfd0501b58c87989ad30dc4579d6d4d99272..6bf1cfb2f2a78c3fa9d825c33e4d4eeeb804a1e4 100644 --- a/lite/kernels/opencl/reshape_image_compute.cc +++ b/lite/kernels/opencl/reshape_image_compute.cc @@ -63,7 +63,7 @@ class ReshapeComputeFloatImage : public KernelLitemutable_data( out_image_shape.at("width"), out_image_shape.at("height")); - LOG(INFO) << "out_dims= " << out_dims; + VLOG(4) << "out_dims= " << out_dims; const std::vector& default_work_size = DefaultWorkSize( out_dims, diff --git a/lite/kernels/opencl/test_helper.h b/lite/kernels/opencl/test_helper.h index af8a1c1e4651d5add6021b41cfe2f2c7caf26f68..a1b875688e1ade3aa3fb441506d2a11c5a06ab19 100644 --- a/lite/kernels/opencl/test_helper.h +++ b/lite/kernels/opencl/test_helper.h @@ -14,7 +14,7 @@ #pragma once -#define COMPTUE_ABS_DIFF(res0, res1) abs(res0 - res1) +#define COMPUTE_ABS_DIFF(res0, res1) abs(res0 - res1) #define COMPUTE_RELATIVE_DIFF(res0, res1) abs(abs(res0 - res1) / (res1 + 1e-5))