diff --git a/mace/core/net.cc b/mace/core/net.cc
index 45db357cbcf43fc7423f51be1d80d1801b9d36e0..e255614fba3caeb9b744103d74b4b53ebec9cccf 100644
--- a/mace/core/net.cc
+++ b/mace/core/net.cc
@@ -38,7 +38,7 @@ bool SimpleNet::Run(RunMetadata *run_metadata) {
     VLOG(1) << "Running operator " << op->debug_def().name() << "("
             << op->debug_def().type() << ").";
     OperatorStats *op_stats = nullptr;
-    if (run_metadata) {
+    if (run_metadata && device_type_ != DeviceType::OPENCL) {
       op_stats = run_metadata->add_op_stats();
       op_stats->set_operator_name(op->debug_def().name());
       op_stats->set_type(op->debug_def().type());
@@ -50,14 +50,32 @@ bool SimpleNet::Run(RunMetadata *run_metadata) {
       LOG(ERROR) << "Operator failed: " << ProtoDebugString(op->debug_def());
       return false;
     }
-    if (op_stats) {
+
+    if (run_metadata) {
       if (device_type_ == DeviceType::OPENCL) {
         OpenCLRuntime::Get()->command_queue().finish();
+        op_stats = run_metadata->add_op_stats();
+        op_stats->set_operator_name(op->debug_def().name());
+        op_stats->set_type(op->debug_def().type());
+
+        op_stats->set_all_start_micros(
+            OpenCLRuntime::Get()->GetEventProfilingStartInfo() / 1000);
+        op_stats->set_op_start_rel_micros(
+            OpenCLRuntime::Get()->GetEventProfilingStartInfo() / 1000 -
+            op_stats->all_start_micros());
+
+        op_stats->set_op_end_rel_micros(
+            OpenCLRuntime::Get()->GetEventProfilingEndInfo() / 1000 -
+            op_stats->all_start_micros());
+        op_stats->set_all_end_rel_micros(
+            OpenCLRuntime::Get()->GetEventProfilingEndInfo() / 1000 -
+            op_stats->all_start_micros());
+      } else {
+        op_stats->set_op_end_rel_micros(NowInMicroSec() -
+                                        op_stats->all_start_micros());
+        op_stats->set_all_end_rel_micros(NowInMicroSec() -
+                                         op_stats->all_start_micros());
       }
-      op_stats->set_op_end_rel_micros(NowInMicroSec() -
-                                      op_stats->all_start_micros());
-      op_stats->set_all_end_rel_micros(NowInMicroSec() -
-                                       op_stats->all_start_micros());
     }
     VLOG(1) << "Op " << op->debug_def().name()
             << " has shape: " << internal::MakeString(op->Output(0)->shape());
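The arithmetic in the OPENCL branch above works on the raw event timestamps, which OpenCL reports in nanoseconds. A minimal sketch of the intended math, with hypothetical timestamp values (not part of the patch):

// Hypothetical nanosecond timestamps as returned by the profiling queries.
cl_ulong start_ns = 1000000000;  // GetEventProfilingStartInfo()
cl_ulong end_ns   = 1000250000;  // GetEventProfilingEndInfo()

int64_t all_start_micros    = start_ns / 1000;                     // 1000000
int64_t op_start_rel_micros = start_ns / 1000 - all_start_micros;  // always 0
int64_t op_end_rel_micros   = end_ns / 1000 - all_start_micros;    // 250 us of device time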
diff --git a/mace/core/runtime/opencl/opencl_runtime.cc b/mace/core/runtime/opencl/opencl_runtime.cc
index 585a151e72c0c950ef68e704e63e56798c9efeb9..e31db894e0a8fd445df2f7a3011809b03c7fcfb0 100644
--- a/mace/core/runtime/opencl/opencl_runtime.cc
+++ b/mace/core/runtime/opencl/opencl_runtime.cc
@@ -32,6 +32,8 @@ bool ReadSourceFile(const std::string &filename, std::string *content) {

 }  // namespace

+bool OpenCLRuntime::enable_profiling_ = false;
+cl::Event* OpenCLRuntime::profiling_ev_ = NULL;

 OpenCLRuntime *OpenCLRuntime::Get() {
   static std::once_flag init_once;
@@ -80,13 +82,35 @@ OpenCLRuntime *OpenCLRuntime::Get() {
     // a context is like a "runtime link" to the device and platform;
     // i.e. communication is possible
     cl::Context context({gpu_device});
-    cl::CommandQueue command_queue(context, gpu_device);
+    cl::CommandQueue command_queue(context, gpu_device,
+                                   enable_profiling_ ? CL_QUEUE_PROFILING_ENABLE : 0);
     instance = new OpenCLRuntime(context, gpu_device, command_queue);
   });

   return instance;
 }

+void OpenCLRuntime::EnableProfiling() {
+  if (!enable_profiling_) {
+    enable_profiling_ = true;
+    profiling_ev_ = new cl::Event();
+  }
+}
+
+cl::Event* OpenCLRuntime::GetDefaultEvent() {
+  return profiling_ev_;
+}
+
+cl_ulong OpenCLRuntime::GetEventProfilingStartInfo() {
+  MACE_CHECK(enable_profiling_, "should enable profiling first.");
+  return profiling_ev_->getProfilingInfo<CL_PROFILING_COMMAND_START>();
+}
+
+cl_ulong OpenCLRuntime::GetEventProfilingEndInfo() {
+  MACE_CHECK(enable_profiling_, "should enable profiling first.");
+  return profiling_ev_->getProfilingInfo<CL_PROFILING_COMMAND_END>();
+}
+
 OpenCLRuntime::OpenCLRuntime(cl::Context context,
                              cl::Device device,
                              cl::CommandQueue command_queue)
@@ -95,7 +119,10 @@ OpenCLRuntime::OpenCLRuntime(cl::Context context,
   kernel_path_ = std::string(kernel_path == nullptr ? "" : kernel_path) + "/";
 }

-OpenCLRuntime::~OpenCLRuntime() {}
+OpenCLRuntime::~OpenCLRuntime() {
+  if (profiling_ev_)
+    delete profiling_ev_;
+}

 cl::Context &OpenCLRuntime::context() { return context_; }
diff --git a/mace/core/runtime/opencl/opencl_runtime.h b/mace/core/runtime/opencl/opencl_runtime.h
index b4bda7d6f39ac97c493285c4139e32e07da34bd2..88086998d0779d8c7688de28e77d908611ba36ba 100644
--- a/mace/core/runtime/opencl/opencl_runtime.h
+++ b/mace/core/runtime/opencl/opencl_runtime.h
@@ -18,6 +18,13 @@ class OpenCLRuntime {
  public:
   static OpenCLRuntime *Get();

+  static void EnableProfiling();
+  cl::Event *GetDefaultEvent();
+
+  cl_ulong GetEventProfilingStartInfo();
+  cl_ulong GetEventProfilingEndInfo();
+
+
   cl::Context &context();
   cl::Device &device();
   cl::CommandQueue &command_queue();
@@ -41,6 +48,9 @@ class OpenCLRuntime {
                     cl::Program *program);

  private:
+  static bool enable_profiling_;
+  static cl::Event* profiling_ev_;
+
   cl::Context context_;
   cl::Device device_;
   cl::CommandQueue command_queue_;
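The runtime changes above follow the standard OpenCL event-profiling pattern: the command queue must be created with CL_QUEUE_PROFILING_ENABLE, an event has to be attached to each enqueue, and the CL_PROFILING_COMMAND_START/END timestamps are only valid once the command has completed. A minimal standalone sketch of that pattern (not part of the patch; the helper name and parameters are illustrative):

#include <CL/cl.hpp>

// Sketch only: measures one kernel dispatch in microseconds, assuming the
// queue was created with CL_QUEUE_PROFILING_ENABLE.
double ProfileKernelUs(cl::CommandQueue &queue, cl::Kernel &kernel,
                       const cl::NDRange &global, const cl::NDRange &local) {
  cl::Event ev;
  queue.enqueueNDRangeKernel(kernel, cl::NullRange, global, local,
                             NULL /* wait list */, &ev /* profiling event */);
  queue.finish();  // timestamps are undefined until the command completes
  cl_ulong start_ns = ev.getProfilingInfo<CL_PROFILING_COMMAND_START>();
  cl_ulong end_ns = ev.getProfilingInfo<CL_PROFILING_COMMAND_END>();
  return (end_ns - start_ns) / 1000.0;  // device nanoseconds -> microseconds
}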
diff --git a/mace/core/runtime/opencl/opencl_wrapper.cc b/mace/core/runtime/opencl/opencl_wrapper.cc
index e7ae8f9991c1cbf8cb22f6d8f7ca11591a2991cd..afd2b1737a103f20bc8a86b1ed2e091e4fb62c35 100644
--- a/mace/core/runtime/opencl/opencl_wrapper.cc
+++ b/mace/core/runtime/opencl/opencl_wrapper.cc
@@ -160,6 +160,11 @@ class OpenCLLibraryImpl final {
                                                    size_t,
                                                    void *,
                                                    size_t *);
+  using clGetEventProfilingInfoFunc = cl_int (*)(cl_event event,
+                                                 cl_profiling_info param_name,
+                                                 size_t param_value_size,
+                                                 void *param_value,
+                                                 size_t *param_value_size_ret);
   using clGetImageInfoFunc = cl_int (*)(cl_mem,
                                         cl_image_info,
                                         size_t,
@@ -209,6 +214,7 @@ class OpenCLLibraryImpl final {
   DEFINE_FUNC_PTR(clReleaseDevice);
   DEFINE_FUNC_PTR(clRetainEvent);
   DEFINE_FUNC_PTR(clGetKernelWorkGroupInfo);
+  DEFINE_FUNC_PTR(clGetEventProfilingInfo);
   DEFINE_FUNC_PTR(clGetImageInfo);

 #undef DEFINE_FUNC_PTR
@@ -333,6 +339,7 @@ void *OpenCLLibraryImpl::LoadFromPath(const std::string &path) {
   ASSIGN_FROM_DLSYM(clReleaseDevice);
   ASSIGN_FROM_DLSYM(clRetainEvent);
   ASSIGN_FROM_DLSYM(clGetKernelWorkGroupInfo);
+  ASSIGN_FROM_DLSYM(clGetEventProfilingInfo);
   ASSIGN_FROM_DLSYM(clGetImageInfo);

 #undef ASSIGN_FROM_DLSYM
@@ -879,6 +886,20 @@ cl_int clGetKernelWorkGroupInfo(cl_kernel kernel,
   }
 }

+cl_int clGetEventProfilingInfo(cl_event event,
+                               cl_profiling_info param_name,
+                               size_t param_value_size,
+                               void *param_value,
+                               size_t *param_value_size_ret) {
+  auto func = mace::OpenCLLibraryImpl::Get().clGetEventProfilingInfo;
+  if (func != nullptr) {
+    return func(event, param_name, param_value_size, param_value,
+                param_value_size_ret);
+  } else {
+    return CL_OUT_OF_RESOURCES;
+  }
+}
+
 cl_int clGetImageInfo(cl_mem image,
                       cl_image_info param_name,
                       size_t param_value_size,
@@ -892,3 +913,4 @@ cl_int clGetImageInfo(cl_mem image,
     return CL_OUT_OF_RESOURCES;
   }
 }
+
diff --git a/mace/kernels/opencl/addn.cc b/mace/kernels/opencl/addn.cc
index 26ba3cc4827dfd2a4e19497d5c0c913b778d858f..6c5106db25c0f12cb625b6e5e0c80c0497541804 100644
--- a/mace/kernels/opencl/addn.cc
+++ b/mace/kernels/opencl/addn.cc
@@ -31,7 +31,8 @@ static void Add2(const Tensor *input0, const Tensor *input1, Tensor *output) {
   cl_int error = runtime->command_queue().enqueueNDRangeKernel(
       addn_kernel, cl::NullRange,
       cl::NDRange(gws),
-      cl::NDRange(lws));
+      cl::NDRange(lws),
+      NULL, OpenCLRuntime::Get()->GetDefaultEvent());
   MACE_CHECK(error == CL_SUCCESS);
 }

diff --git a/mace/kernels/opencl/batch_norm_opencl.cc b/mace/kernels/opencl/batch_norm_opencl.cc
index 07a284887e9bff3e26b8a25aee0050f3b5c16815..c7cd37e3ec7e6c1e0dbe31cf335bb105869e35c2 100644
--- a/mace/kernels/opencl/batch_norm_opencl.cc
+++ b/mace/kernels/opencl/batch_norm_opencl.cc
@@ -62,7 +62,8 @@ void BatchNormFunctor::operator()(
   cl_int error = runtime->command_queue().enqueueNDRangeKernel(
       bm_kernel, cl::NullRange,
       cl::NDRange(gws[0], gws[1], gws[2]),
-      cl::NDRange(params[0], params[1], params[2]));
+      cl::NDRange(params[0], params[1], params[2]),
+      NULL, OpenCLRuntime::Get()->GetDefaultEvent());
   MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;

   return error;
diff --git a/mace/kernels/opencl/conv_2d_opencl_1x1.cc b/mace/kernels/opencl/conv_2d_opencl_1x1.cc
index aa8ee24fd642eab1cf03756149b33c27629985a8..9a112cdc4abb11275bb1494a55d6898c3af548cb 100644
--- a/mace/kernels/opencl/conv_2d_opencl_1x1.cc
+++ b/mace/kernels/opencl/conv_2d_opencl_1x1.cc
@@ -59,7 +59,8 @@ void Conv1x1V2(const Tensor *input,
       conv_2d_kernel, cl::NullRange,
       cl::NDRange(static_cast(batch), static_cast(channel_blocks),
                   static_cast(pixel_blocks)),
-      cl::NDRange(1, 2, kwg_size / 2));
+      cl::NDRange(1, 2, kwg_size / 2),
+      NULL, OpenCLRuntime::Get()->GetDefaultEvent());
   MACE_CHECK(error == CL_SUCCESS, error);
 }

@@ -104,7 +105,8 @@ void Conv1x1V3(const Tensor *input,
       conv_2d_kernel, cl::NullRange,
       cl::NDRange(static_cast(channel_blocks), static_cast(height),
                   static_cast(height * batch)),
-      cl::NDRange(4, 15, 8));
+      cl::NDRange(4, 15, 8),
+      NULL, OpenCLRuntime::Get()->GetDefaultEvent());
   MACE_CHECK(error == CL_SUCCESS, error);
 }

diff --git a/mace/kernels/opencl/conv_2d_opencl_3x3.cc b/mace/kernels/opencl/conv_2d_opencl_3x3.cc
index 02c67f1672bbd274b473ab7fe92c207422eff3fe..1adb80b85c6af1f93b4d69baf808a77278330d3d 100644
--- a/mace/kernels/opencl/conv_2d_opencl_3x3.cc
+++ b/mace/kernels/opencl/conv_2d_opencl_3x3.cc
@@ -52,7 +52,8 @@ static void InnerConv2dK3x3S12(const Tensor *input, const Tensor *filter,
   cl_int error = runtime->command_queue().enqueueNDRangeKernel(
       conv_kernel, cl::NullRange,
       cl::NDRange(gws[0], gws[1], gws[2]),
-      cl::NDRange(lws[0], lws[1], lws[2]));
+      cl::NDRange(lws[0], lws[1], lws[2]),
+      NULL, OpenCLRuntime::Get()->GetDefaultEvent());
   MACE_CHECK(error == CL_SUCCESS);
 }
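Each kernel file above and below gets the same two-line change: the call now fills the last two optional parameters of cl::CommandQueue::enqueueNDRangeKernel, the event wait list (left NULL) and the event to signal, so the shared profiling event records every dispatch. Annotated with its parameter roles, the call reads roughly as (illustrative, with placeholder ranges):

cl_int error = runtime->command_queue().enqueueNDRangeKernel(
    kernel,                                    // kernel to launch
    cl::NullRange,                             // no global work offset
    cl::NDRange(gws[0], gws[1], gws[2]),       // global work size
    cl::NDRange(lws[0], lws[1], lws[2]),       // local work size
    NULL,                                      // no event wait list
    OpenCLRuntime::Get()->GetDefaultEvent());  // event queried for timestamps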
diff --git a/mace/kernels/opencl/depthwise_conv_opencl_3x3.cc b/mace/kernels/opencl/depthwise_conv_opencl_3x3.cc
index eeb6ae5dc2c1573948bae9e0494129bfa943db83..60ce2a829a78a0a0439dd1e287c61f2dee4b490b 100644
--- a/mace/kernels/opencl/depthwise_conv_opencl_3x3.cc
+++ b/mace/kernels/opencl/depthwise_conv_opencl_3x3.cc
@@ -60,7 +60,8 @@ static void InnerDepthwiseConvOpenclK3x3S12(const Tensor *input,
   cl_int error = runtime->command_queue().enqueueNDRangeKernel(
       conv_kernel, cl::NullRange,
       cl::NDRange(gws[0], gws[1], gws[2]),
-      cl::NDRange(lws[0], lws[1], lws[2]));
+      cl::NDRange(lws[0], lws[1], lws[2]),
+      NULL, OpenCLRuntime::Get()->GetDefaultEvent());
   MACE_CHECK(error == CL_SUCCESS);
 }

diff --git a/mace/kernels/opencl/pooling_opencl.cc b/mace/kernels/opencl/pooling_opencl.cc
index 8bd73f50673766c7b0dd8d5dd706a1c2df8a6231..0aaa89ae2c649583dddafaffbcce428d4ffc94fd 100644
--- a/mace/kernels/opencl/pooling_opencl.cc
+++ b/mace/kernels/opencl/pooling_opencl.cc
@@ -52,7 +52,8 @@ static void Pooling3(const Tensor *input,
   cl_int error = runtime->command_queue().enqueueNDRangeKernel(
       pooling_kernel, cl::NullRange,
       cl::NDRange(gws[0], gws[1], gws[2]),
-      cl::NDRange(lws[0], lws[1], lws[2]));
+      cl::NDRange(lws[0], lws[1], lws[2]),
+      NULL, OpenCLRuntime::Get()->GetDefaultEvent());
   MACE_CHECK(error == CL_SUCCESS);
 }

@@ -100,7 +101,8 @@ static void PoolingN(const Tensor *input,
   cl_int error = runtime->command_queue().enqueueNDRangeKernel(
       pooling_kernel, cl::NullRange,
       cl::NDRange(gws[0], gws[1], gws[2]),
-      cl::NDRange(lws[0], lws[1], lws[2]));
+      cl::NDRange(lws[0], lws[1], lws[2]),
+      NULL, OpenCLRuntime::Get()->GetDefaultEvent());
   MACE_CHECK(error == CL_SUCCESS);
 }

diff --git a/mace/kernels/opencl/relu_opencl.cc b/mace/kernels/opencl/relu_opencl.cc
index dc1ade1aefbcaa66e080113d7d3de5903668eddf..1149b965a2fc91c5394c97b7028d872b827dc125 100644
--- a/mace/kernels/opencl/relu_opencl.cc
+++ b/mace/kernels/opencl/relu_opencl.cc
@@ -36,7 +36,8 @@ void ReluFunctor::operator()(const Tensor *input,
     cl_int error = runtime->command_queue().enqueueNDRangeKernel(
         relu_kernel, cl::NullRange,
         cl::NDRange(gws),
-        cl::NDRange(lws));
+        cl::NDRange(lws),
+        NULL, OpenCLRuntime::Get()->GetDefaultEvent());
     MACE_CHECK(error == CL_SUCCESS);
   } else {
     auto relu_kernel = runtime->BuildKernel("relu", "relux", built_options);
@@ -52,7 +53,8 @@ void ReluFunctor::operator()(const Tensor *input,
     cl_int error = runtime->command_queue().enqueueNDRangeKernel(
         relu_kernel, cl::NullRange,
         cl::NDRange(gws),
-        cl::NDRange(lws));
+        cl::NDRange(lws),
+        NULL, OpenCLRuntime::Get()->GetDefaultEvent());
     MACE_CHECK(error == CL_SUCCESS);
   }
 }
diff --git a/mace/kernels/opencl/resize_bilinear_opencl.cc b/mace/kernels/opencl/resize_bilinear_opencl.cc
index 91dbb05252659da6e9e62ad671389eb62d0209f1..7b77afea0fdd3aed146b22d736cacc5c6c165e79 100644
--- a/mace/kernels/opencl/resize_bilinear_opencl.cc
+++ b/mace/kernels/opencl/resize_bilinear_opencl.cc
@@ -50,7 +50,8 @@ void ResizeBilinearFunctor::operator()(
       cl::NDRange(static_cast(batch * channels),
                   static_cast(out_height), static_cast(out_width)),
       // TODO (heliangliang) tuning and fix when kwg_size < devisor
-      cl::NDRange(1, 16, kwg_size / 16));
+      cl::NDRange(1, 16, kwg_size / 16),
+      NULL, OpenCLRuntime::Get()->GetDefaultEvent());
   MACE_CHECK(error == CL_SUCCESS, error);
 }

diff --git a/mace/kernels/opencl/space_to_batch_opecl.cc b/mace/kernels/opencl/space_to_batch_opecl.cc
index 716e3d76696aec62f0a3f58d8e839b06ece9bb47..2716501c880fcd4fb2232e292b9396e27cfff2f3 100644
--- a/mace/kernels/opencl/space_to_batch_opecl.cc
+++ b/mace/kernels/opencl/space_to_batch_opecl.cc
@@ -45,7 +45,8 @@ void SpaceToBatchFunctor::operator()(Tensor *space_te
   cl_int error = runtime->command_queue().enqueueNDRangeKernel(
       s2b_kernel, cl::NullRange,
       cl::NDRange(gws[0], gws[1], gws[2]),
-      cl::NDRange(lws[0], lws[1], lws[2]));
+      cl::NDRange(lws[0], lws[1], lws[2]),
+      NULL, OpenCLRuntime::Get()->GetDefaultEvent());
   MACE_CHECK(error == CL_SUCCESS);
 }
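One ordering caveat these call sites rely on: CL_QUEUE_PROFILING_ENABLE is applied when the command queue is constructed inside the std::call_once in OpenCLRuntime::Get(), so EnableProfiling() only takes effect if it runs before the first Get(). The benchmark and tool changes below therefore enable profiling right after the device type is known, before any kernel is enqueued. In sketch form (not part of the patch):

// Enable profiling before the OpenCLRuntime singleton creates its queue;
// calling it later would leave the queue without CL_QUEUE_PROFILING_ENABLE.
if (device_type == DeviceType::OPENCL) {
  OpenCLRuntime::EnableProfiling();             // set flag, allocate the shared event
}
OpenCLRuntime *runtime = OpenCLRuntime::Get();  // queue now created with profiling on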
diff --git a/mace/ops/batch_norm_benchmark.cc b/mace/ops/batch_norm_benchmark.cc
index 499af6f29c5f1918f8233ef1e11ba155e35cc869..e0d56173d20e89799e7c2f1a9df33a90dbca47bd 100644
--- a/mace/ops/batch_norm_benchmark.cc
+++ b/mace/ops/batch_norm_benchmark.cc
@@ -3,6 +3,7 @@
 //

 #include "mace/core/operator.h"
+#include "mace/core/runtime/opencl/opencl_runtime.h"
 #include "mace/core/testing/test_benchmark.h"
 #include "mace/ops/ops_test_util.h"

@@ -12,6 +13,9 @@ static void BatchNorm(
     int iters, int batch, int channels, int height, int width) {
   mace::testing::StopTiming();

+  if ( D == OPENCL )
+    OpenCLRuntime::EnableProfiling();
+
   OpsTestNet net;
   OpDefBuilder("BatchNorm", "BatchNormBM")
       .Input("Input")
@@ -77,4 +81,4 @@ BM_BATCH_NORM(1, 512, 14, 14, float);
 BM_BATCH_NORM(1, 1024, 7, 7, float);
 BM_BATCH_NORM(32, 1, 256, 256, float);
 BM_BATCH_NORM(32, 3, 256, 256, float);
-}  // namespace mace
\ No newline at end of file
+}  // namespace mace
diff --git a/mace/tools/benchmark/benchmark_model.cc b/mace/tools/benchmark/benchmark_model.cc
index d4ae7b5d3bf754bca7d3d369b966ccea22ffdab4..09ac6fd62c947021a98fd328d55cf5663c9dabb3 100644
--- a/mace/tools/benchmark/benchmark_model.cc
+++ b/mace/tools/benchmark/benchmark_model.cc
@@ -3,6 +3,7 @@
 //

 #include "mace/core/net.h"
+#include "mace/core/runtime/opencl/opencl_runtime.h"
 #include "mace/tools/benchmark/stat_summarizer.h"
 #include "mace/utils/command_line_flags.h"
 #include "mace/utils/utils.h"
@@ -149,7 +150,7 @@ int Main(int argc, char **argv) {
   std::vector<Flag> flag_list = {
       Flag("model_file", &model_file, "graph file name"),
-      Flag("device", &device, "CPU/NEON"),
+      Flag("device", &device, "CPU/NEON/OPENCL"),
       Flag("input_layer", &input_layer_string, "input layer names"),
       Flag("input_layer_shape", &input_layer_shape_string, "input layer shape"),
       Flag("input_layer_type", &input_layer_type_string, "input layer type"),
@@ -259,6 +260,9 @@ int Main(int argc, char **argv) {
   DeviceType_Parse(device, &device_type);
   VLOG(0) << device_type;

+  if (device_type == DeviceType::OPENCL)
+    OpenCLRuntime::EnableProfiling();
+
   // load model
   std::ifstream model_file_stream(model_file, std::ios::in | std::ios::binary);
   if (!model_file_stream.is_open()) {
diff --git a/mace/utils/tuner.h b/mace/utils/tuner.h
index 1d36f7f5b170fc109bc7596bb556b0e8e3ed6959..38c29a8fe7e81a4ffc72bf048780d306ed1dd578 100644
--- a/mace/utils/tuner.h
+++ b/mace/utils/tuner.h
@@ -131,13 +131,14 @@ class Tuner {
                   double &time_us) {
     RetType res;
     int64_t total_time_us = 0;
-    const int64_t start_time = NowInMicroSec();
     for (int i = 0; i < num_runs; ++i) {
       res = func(params);
+      OpenCLRuntime::Get()->command_queue().finish();
+
+      double start_time = OpenCLRuntime::Get()->GetEventProfilingStartInfo() / 1000.0;
+      double end_time = OpenCLRuntime::Get()->GetEventProfilingEndInfo() / 1000.0;
+      total_time_us += end_time - start_time;
     }
-    OpenCLRuntime::Get()->command_queue().finish();
-    const int64_t end_time = NowInMicroSec();
-    total_time_us += end_time - start_time;

     time_us = total_time_us * 1.0 / num_runs;
     return res;
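With this change the tuner averages device time from the event timestamps of each run instead of wrapping the whole loop in NowInMicroSec() wall-clock calls; the per-iteration finish() is what makes the timestamps valid, at the cost of serializing the runs. The measurement loop boils down to the following (sketch only, assuming each func(params) call enqueues one kernel on the shared profiling event):

double total_time_us = 0.0;
for (int i = 0; i < num_runs; ++i) {
  func(params);                                    // enqueue the kernel(s)
  OpenCLRuntime::Get()->command_queue().finish();  // wait so the timestamps are valid
  total_time_us += (OpenCLRuntime::Get()->GetEventProfilingEndInfo() -
                    OpenCLRuntime::Get()->GetEventProfilingStartInfo()) / 1000.0;
}
time_us = total_time_us / num_runs;                // average device time per run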