Commit f12366af authored by 刘琦

Merge branch 'master' into 'master'

Update OpenCL profiling so that per-kernel time is read from the OpenCL profiling API

See merge request !111
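
For context: this change switches per-op GPU timing from host-side wall-clock measurements around the enqueue call to the OpenCL event-profiling API, which reports device-side start/end timestamps for each kernel. Below is a minimal sketch of that underlying mechanism, independent of this codebase; the context, device, kernel, and work sizes are assumed to be set up elsewhere.

// Minimal sketch of OpenCL event-based kernel timing with the OpenCL C++
// wrapper (cl.hpp-era API, as used elsewhere in this repository). `context`,
// `device`, and `kernel` are assumed to be created by the caller; gws/lws are
// the global and local work sizes.
cl_ulong ProfileKernelNanos(cl::Context &context, cl::Device &device,
                            cl::Kernel &kernel, const cl::NDRange &gws,
                            const cl::NDRange &lws) {
  // The queue must be created with profiling enabled, otherwise
  // getProfilingInfo() fails with CL_PROFILING_INFO_NOT_AVAILABLE.
  cl::CommandQueue queue(context, device, CL_QUEUE_PROFILING_ENABLE);

  cl::Event event;
  queue.enqueueNDRangeKernel(kernel, cl::NullRange, gws, lws, NULL, &event);
  queue.finish();  // the event is only valid after the kernel has completed

  // Device-side timestamps, in nanoseconds.
  cl_ulong start = event.getProfilingInfo<CL_PROFILING_COMMAND_START>();
  cl_ulong end = event.getProfilingInfo<CL_PROFILING_COMMAND_END>();
  return end - start;
}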
@@ -38,7 +38,7 @@ bool SimpleNet::Run(RunMetadata *run_metadata) {
     VLOG(1) << "Running operator " << op->debug_def().name() << "("
             << op->debug_def().type() << ").";
     OperatorStats *op_stats = nullptr;
-    if (run_metadata) {
+    if (run_metadata && device_type_ != DeviceType::OPENCL) {
       op_stats = run_metadata->add_op_stats();
       op_stats->set_operator_name(op->debug_def().name());
       op_stats->set_type(op->debug_def().type());
@@ -50,14 +50,32 @@ bool SimpleNet::Run(RunMetadata *run_metadata) {
       LOG(ERROR) << "Operator failed: " << ProtoDebugString(op->debug_def());
       return false;
     }
-    if (op_stats) {
+    if (run_metadata) {
       if (device_type_ == DeviceType::OPENCL) {
         OpenCLRuntime::Get()->command_queue().finish();
+        op_stats = run_metadata->add_op_stats();
+        op_stats->set_operator_name(op->debug_def().name());
+        op_stats->set_type(op->debug_def().type());
+        op_stats->set_all_start_micros(
+            OpenCLRuntime::Get()->GetEventProfilingStartInfo() / 1000);
+        op_stats->set_op_start_rel_micros(
+            OpenCLRuntime::Get()->GetEventProfilingStartInfo() / 1000 -
+            op_stats->all_start_micros());
+        op_stats->set_op_end_rel_micros(
+            OpenCLRuntime::Get()->GetEventProfilingEndInfo() / 1000 -
+            op_stats->all_start_micros());
+        op_stats->set_all_end_rel_micros(
+            OpenCLRuntime::Get()->GetEventProfilingEndInfo() / 1000 -
+            op_stats->all_start_micros());
+      } else {
+        op_stats->set_op_end_rel_micros(NowInMicroSec() -
+                                        op_stats->all_start_micros());
+        op_stats->set_all_end_rel_micros(NowInMicroSec() -
+                                         op_stats->all_start_micros());
       }
-      op_stats->set_op_end_rel_micros(NowInMicroSec() -
-                                      op_stats->all_start_micros());
-      op_stats->set_all_end_rel_micros(NowInMicroSec() -
-                                       op_stats->all_start_micros());
     }
     VLOG(1) << "Op " << op->debug_def().name()
             << " has shape: " << internal::MakeString(op->Output(0)->shape());
...
@@ -32,6 +32,8 @@ bool ReadSourceFile(const std::string &filename, std::string *content) {
 }  // namespace
 
+bool OpenCLRuntime::enable_profiling_ = false;
+cl::Event* OpenCLRuntime::profiling_ev_ = NULL;
 
 OpenCLRuntime *OpenCLRuntime::Get() {
   static std::once_flag init_once;
@@ -80,13 +82,35 @@ OpenCLRuntime *OpenCLRuntime::Get() {
     // a context is like a "runtime link" to the device and platform;
     // i.e. communication is possible
     cl::Context context({gpu_device});
-    cl::CommandQueue command_queue(context, gpu_device);
+    cl::CommandQueue command_queue(context, gpu_device,
+        enable_profiling_ ? CL_QUEUE_PROFILING_ENABLE : 0);
     instance = new OpenCLRuntime(context, gpu_device, command_queue);
   });
 
   return instance;
 }
 
+void OpenCLRuntime::EnableProfiling() {
+  if (!enable_profiling_) {
+    enable_profiling_ = true;
+    profiling_ev_ = new cl::Event();
+  }
+}
+
+cl::Event* OpenCLRuntime::GetDefaultEvent() {
+  return profiling_ev_;
+}
+
+cl_ulong OpenCLRuntime::GetEventProfilingStartInfo() {
+  MACE_CHECK(enable_profiling_, "should enable profiling first.");
+  return profiling_ev_->getProfilingInfo<CL_PROFILING_COMMAND_START>();
+}
+
+cl_ulong OpenCLRuntime::GetEventProfilingEndInfo() {
+  MACE_CHECK(enable_profiling_, "should enable profiling first.");
+  return profiling_ev_->getProfilingInfo<CL_PROFILING_COMMAND_END>();
+}
+
 OpenCLRuntime::OpenCLRuntime(cl::Context context,
                              cl::Device device,
                              cl::CommandQueue command_queue)
@@ -95,7 +119,10 @@ OpenCLRuntime::OpenCLRuntime(cl::Context context,
   kernel_path_ = std::string(kernel_path == nullptr ? "" : kernel_path) + "/";
 }
 
-OpenCLRuntime::~OpenCLRuntime() {}
+OpenCLRuntime::~OpenCLRuntime() {
+  if (profiling_ev_)
+    delete profiling_ev_;
+}
 
 cl::Context &OpenCLRuntime::context() { return context_; }
...
@@ -18,6 +18,13 @@ class OpenCLRuntime {
  public:
   static OpenCLRuntime *Get();
+  static void EnableProfiling();
+
+  cl::Event *GetDefaultEvent();
+
+  cl_ulong GetEventProfilingStartInfo();
+  cl_ulong GetEventProfilingEndInfo();
+
   cl::Context &context();
   cl::Device &device();
   cl::CommandQueue &command_queue();
@@ -41,6 +48,9 @@ class OpenCLRuntime {
                     cl::Program *program);
 
  private:
+  static bool enable_profiling_;
+  static cl::Event* profiling_ev_;
+
   cl::Context context_;
   cl::Device device_;
   cl::CommandQueue command_queue_;
...
@@ -160,6 +160,11 @@ class OpenCLLibraryImpl final {
                                          size_t,
                                          void *,
                                          size_t *);
+  using clGetEventProfilingInfoFunc = cl_int (*)(cl_event event,
+                                                 cl_profiling_info param_name,
+                                                 size_t param_value_size,
+                                                 void *param_value,
+                                                 size_t *param_value_size_ret);
   using clGetImageInfoFunc = cl_int (*)(cl_mem,
                                         cl_image_info,
                                         size_t,
@@ -209,6 +214,7 @@ class OpenCLLibraryImpl final {
   DEFINE_FUNC_PTR(clReleaseDevice);
   DEFINE_FUNC_PTR(clRetainEvent);
   DEFINE_FUNC_PTR(clGetKernelWorkGroupInfo);
+  DEFINE_FUNC_PTR(clGetEventProfilingInfo);
   DEFINE_FUNC_PTR(clGetImageInfo);
 
 #undef DEFINE_FUNC_PTR
@@ -333,6 +339,7 @@ void *OpenCLLibraryImpl::LoadFromPath(const std::string &path) {
   ASSIGN_FROM_DLSYM(clReleaseDevice);
   ASSIGN_FROM_DLSYM(clRetainEvent);
   ASSIGN_FROM_DLSYM(clGetKernelWorkGroupInfo);
+  ASSIGN_FROM_DLSYM(clGetEventProfilingInfo);
   ASSIGN_FROM_DLSYM(clGetImageInfo);
 
 #undef ASSIGN_FROM_DLSYM
@@ -879,6 +886,20 @@ cl_int clGetKernelWorkGroupInfo(cl_kernel kernel,
   }
 }
 
+cl_int clGetEventProfilingInfo(cl_event event,
+                               cl_profiling_info param_name,
+                               size_t param_value_size,
+                               void *param_value,
+                               size_t *param_value_size_ret) {
+  auto func = mace::OpenCLLibraryImpl::Get().clGetEventProfilingInfo;
+  if (func != nullptr) {
+    return func(event, param_name, param_value_size, param_value,
+                param_value_size_ret);
+  } else {
+    return CL_OUT_OF_RESOURCES;
+  }
+}
+
 cl_int clGetImageInfo(cl_mem image,
                       cl_image_info param_name,
                       size_t param_value_size,
@@ -892,3 +913,4 @@ cl_int clGetImageInfo(cl_mem image,
     return CL_OUT_OF_RESOURCES;
   }
 }
...
@@ -31,7 +31,8 @@ static void Add2(const Tensor *input0, const Tensor *input1, Tensor *output) {
   cl_int error = runtime->command_queue().enqueueNDRangeKernel(
       addn_kernel, cl::NullRange,
       cl::NDRange(gws),
-      cl::NDRange(lws));
+      cl::NDRange(lws),
+      NULL, OpenCLRuntime::Get()->GetDefaultEvent());
   MACE_CHECK(error == CL_SUCCESS);
 }
...
@@ -62,7 +62,8 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
   cl_int error = runtime->command_queue().enqueueNDRangeKernel(
       bm_kernel, cl::NullRange,
       cl::NDRange(gws[0], gws[1], gws[2]),
-      cl::NDRange(params[0], params[1], params[2]));
+      cl::NDRange(params[0], params[1], params[2]),
+      NULL, OpenCLRuntime::Get()->GetDefaultEvent());
   MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
 
   return error;
...
@@ -59,7 +59,8 @@ void Conv1x1V2(const Tensor *input,
       conv_2d_kernel, cl::NullRange,
       cl::NDRange(static_cast<int>(batch), static_cast<int>(channel_blocks),
                   static_cast<int>(pixel_blocks)),
-      cl::NDRange(1, 2, kwg_size / 2));
+      cl::NDRange(1, 2, kwg_size / 2),
+      NULL, OpenCLRuntime::Get()->GetDefaultEvent());
   MACE_CHECK(error == CL_SUCCESS, error);
 }
@@ -104,7 +105,8 @@ void Conv1x1V3(const Tensor *input,
       conv_2d_kernel, cl::NullRange,
       cl::NDRange(static_cast<uint32_t>(channel_blocks), static_cast<uint32_t>(height),
                   static_cast<uint32_t>(height * batch)),
-      cl::NDRange(4, 15, 8));
+      cl::NDRange(4, 15, 8),
+      NULL, OpenCLRuntime::Get()->GetDefaultEvent());
   MACE_CHECK(error == CL_SUCCESS, error);
 }
...
@@ -52,7 +52,8 @@ static void InnerConv2dK3x3S12(const Tensor *input, const Tensor *filter,
   cl_int error = runtime->command_queue().enqueueNDRangeKernel(
       conv_kernel, cl::NullRange,
       cl::NDRange(gws[0], gws[1], gws[2]),
-      cl::NDRange(lws[0], lws[1], lws[2]));
+      cl::NDRange(lws[0], lws[1], lws[2]),
+      NULL, OpenCLRuntime::Get()->GetDefaultEvent());
   MACE_CHECK(error == CL_SUCCESS);
 }
...
@@ -60,7 +60,8 @@ static void InnerDepthwiseConvOpenclK3x3S12(const Tensor *input,
   cl_int error = runtime->command_queue().enqueueNDRangeKernel(
       conv_kernel, cl::NullRange,
       cl::NDRange(gws[0], gws[1], gws[2]),
-      cl::NDRange(lws[0], lws[1], lws[2]));
+      cl::NDRange(lws[0], lws[1], lws[2]),
+      NULL, OpenCLRuntime::Get()->GetDefaultEvent());
   MACE_CHECK(error == CL_SUCCESS);
 }
...
@@ -52,7 +52,8 @@ static void Pooling3(const Tensor *input,
   cl_int error = runtime->command_queue().enqueueNDRangeKernel(
       pooling_kernel, cl::NullRange,
       cl::NDRange(gws[0], gws[1], gws[2]),
-      cl::NDRange(lws[0], lws[1], lws[2]));
+      cl::NDRange(lws[0], lws[1], lws[2]),
+      NULL, OpenCLRuntime::Get()->GetDefaultEvent());
   MACE_CHECK(error == CL_SUCCESS);
 }
@@ -100,7 +101,8 @@ static void PoolingN(const Tensor *input,
   cl_int error = runtime->command_queue().enqueueNDRangeKernel(
       pooling_kernel, cl::NullRange,
       cl::NDRange(gws[0], gws[1], gws[2]),
-      cl::NDRange(lws[0], lws[1], lws[2]));
+      cl::NDRange(lws[0], lws[1], lws[2]),
+      NULL, OpenCLRuntime::Get()->GetDefaultEvent());
   MACE_CHECK(error == CL_SUCCESS);
 }
...
@@ -36,7 +36,8 @@ void ReluFunctor<DeviceType::OPENCL, float>::operator()(const Tensor *input,
     cl_int error = runtime->command_queue().enqueueNDRangeKernel(
         relu_kernel, cl::NullRange,
         cl::NDRange(gws),
-        cl::NDRange(lws));
+        cl::NDRange(lws),
+        NULL, OpenCLRuntime::Get()->GetDefaultEvent());
     MACE_CHECK(error == CL_SUCCESS);
   } else {
     auto relu_kernel = runtime->BuildKernel("relu", "relux", built_options);
@@ -52,7 +53,8 @@ void ReluFunctor<DeviceType::OPENCL, float>::operator()(const Tensor *input,
     cl_int error = runtime->command_queue().enqueueNDRangeKernel(
         relu_kernel, cl::NullRange,
         cl::NDRange(gws),
-        cl::NDRange(lws));
+        cl::NDRange(lws),
+        NULL, OpenCLRuntime::Get()->GetDefaultEvent());
     MACE_CHECK(error == CL_SUCCESS);
   }
 }
...
@@ -50,7 +50,8 @@ void ResizeBilinearFunctor<DeviceType::OPENCL, float>::operator()(
       cl::NDRange(static_cast<int>(batch * channels),
                   static_cast<int>(out_height), static_cast<int>(out_width)),
       // TODO (heliangliang) tuning and fix when kwg_size < devisor
-      cl::NDRange(1, 16, kwg_size / 16));
+      cl::NDRange(1, 16, kwg_size / 16),
+      NULL, OpenCLRuntime::Get()->GetDefaultEvent());
   MACE_CHECK(error == CL_SUCCESS, error);
 }
...
@@ -45,7 +45,8 @@ void SpaceToBatchFunctor<DeviceType::OPENCL, float>::operator()(Tensor *space_te
   cl_int error = runtime->command_queue().enqueueNDRangeKernel(
       s2b_kernel, cl::NullRange,
       cl::NDRange(gws[0], gws[1], gws[2]),
-      cl::NDRange(lws[0], lws[1], lws[2]));
+      cl::NDRange(lws[0], lws[1], lws[2]),
+      NULL, OpenCLRuntime::Get()->GetDefaultEvent());
   MACE_CHECK(error == CL_SUCCESS);
 }
...
@@ -3,6 +3,7 @@
 //
 
 #include "mace/core/operator.h"
+#include "mace/core/runtime/opencl/opencl_runtime.h"
 #include "mace/core/testing/test_benchmark.h"
 #include "mace/ops/ops_test_util.h"
 
@@ -12,6 +13,9 @@ static void BatchNorm(
     int iters, int batch, int channels, int height, int width) {
   mace::testing::StopTiming();
 
+  if ( D == OPENCL )
+    OpenCLRuntime::EnableProfiling();
+
   OpsTestNet net;
   OpDefBuilder("BatchNorm", "BatchNormBM")
       .Input("Input")
@@ -77,4 +81,4 @@ BM_BATCH_NORM(1, 512, 14, 14, float);
 BM_BATCH_NORM(1, 1024, 7, 7, float);
 BM_BATCH_NORM(32, 1, 256, 256, float);
 BM_BATCH_NORM(32, 3, 256, 256, float);
-}  // namespace mace
\ No newline at end of file
+}  // namespace mace
...
@@ -3,6 +3,7 @@
 //
 
 #include "mace/core/net.h"
+#include "mace/core/runtime/opencl/opencl_runtime.h"
 #include "mace/tools/benchmark/stat_summarizer.h"
 #include "mace/utils/command_line_flags.h"
 #include "mace/utils/utils.h"
@@ -149,7 +150,7 @@ int Main(int argc, char **argv) {
   std::vector<Flag> flag_list = {
       Flag("model_file", &model_file, "graph file name"),
-      Flag("device", &device, "CPU/NEON"),
+      Flag("device", &device, "CPU/NEON/OPENCL"),
      Flag("input_layer", &input_layer_string, "input layer names"),
       Flag("input_layer_shape", &input_layer_shape_string, "input layer shape"),
       Flag("input_layer_type", &input_layer_type_string, "input layer type"),
@@ -259,6 +260,9 @@ int Main(int argc, char **argv) {
   DeviceType_Parse(device, &device_type);
   VLOG(0) << device_type;
 
+  if (device_type == DeviceType::OPENCL)
+    OpenCLRuntime::EnableProfiling();
+
   // load model
   std::ifstream model_file_stream(model_file, std::ios::in | std::ios::binary);
   if (!model_file_stream.is_open()) {
...
@@ -131,13 +131,14 @@ class Tuner {
                  double &time_us) {
     RetType res;
     int64_t total_time_us = 0;
-    const int64_t start_time = NowInMicroSec();
     for (int i = 0; i < num_runs; ++i) {
       res = func(params);
+      OpenCLRuntime::Get()->command_queue().finish();
+      double start_time = OpenCLRuntime::Get()->GetEventProfilingStartInfo() / 1000.0;
+      double end_time = OpenCLRuntime::Get()->GetEventProfilingEndInfo() / 1000.0;
+      total_time_us += end_time - start_time;
     }
-    OpenCLRuntime::Get()->command_queue().finish();
-    const int64_t end_time = NowInMicroSec();
-    total_time_us += end_time - start_time;
 
     time_us = total_time_us * 1.0 / num_runs;
     return res;
...
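
For reference, a minimal sketch of how the pieces introduced in this merge request fit together. The wrapper function below and its call site are hypothetical; the individual calls are the ones added in the diff above, and the code is assumed to live inside namespace mace.

// Hypothetical helper showing the intended usage of the new profiling hooks.
void TimeOneKernel(cl::Kernel &kernel, const cl::NDRange &gws,
                   const cl::NDRange &lws) {
  // Must run before the first OpenCLRuntime::Get(), so that the singleton's
  // command queue is created with CL_QUEUE_PROFILING_ENABLE.
  OpenCLRuntime::EnableProfiling();

  auto runtime = OpenCLRuntime::Get();
  cl_int error = runtime->command_queue().enqueueNDRangeKernel(
      kernel, cl::NullRange, gws, lws,
      NULL, runtime->GetDefaultEvent());
  MACE_CHECK(error == CL_SUCCESS);

  // The profiling counters are only meaningful once the queue has drained.
  runtime->command_queue().finish();
  cl_ulong elapsed_ns = runtime->GetEventProfilingEndInfo() -
                        runtime->GetEventProfilingStartInfo();
  VLOG(1) << "Kernel time: " << elapsed_ns / 1000 << " us";
}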