提交 ed99a2d1 编写于 作者: L liuqi

Change the time statistics strategy for the tuner.

上级 1bbf62ee
...@@ -44,33 +44,17 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()( ...@@ -44,33 +44,17 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
bm_kernel.setArg(idx++, lws[1] * sizeof(float), nullptr); bm_kernel.setArg(idx++, lws[1] * sizeof(float), nullptr);
bm_kernel.setArg(idx++, lws[1] * sizeof(float), nullptr); bm_kernel.setArg(idx++, lws[1] * sizeof(float), nullptr);
std::function<std::vector<std::vector<uint32_t>>()> params_generator = nullptr; auto params_generator = [&kwg_size]()->std::vector<std::vector<uint32_t>> {
std::function<cl_int(const std::vector<uint32_t>& params)> func;
if (Tuning()) {
params_generator = [&kwg_size]()->std::vector<std::vector<uint32_t>> {
return {{1, 1, 64}, return {{1, 1, 64},
{1, 1, 128}, {1, 1, 128},
{1, kwg_size/16, 16},
{1, kwg_size/32, 32}, {1, kwg_size/32, 32},
{1, kwg_size/64, 64}, {1, kwg_size/64, 64},
{1, kwg_size/128, 128}, {1, kwg_size/128, 128},
{1, 1, kwg_size}, {1, 1, kwg_size},
{1, kwg_size, 1}}; {1, kwg_size, 1}};
}; };
func = [&](const std::vector<uint32_t>& params)->cl_int { auto func = [&](const std::vector<uint32_t>& params)->cl_int {
cl::Event event;
cl_int error = runtime->command_queue().enqueueNDRangeKernel(
bm_kernel, cl::NullRange,
cl::NDRange(gws[0], gws[1], gws[2]),
cl::NDRange(params[0], params[1], params[2]),
nullptr,
&event);
MACE_CHECK(error == CL_SUCCESS);
event.wait();
return error;
};
} else {
func = [&](const std::vector<uint32_t>& params)->cl_int {
cl_int error = runtime->command_queue().enqueueNDRangeKernel( cl_int error = runtime->command_queue().enqueueNDRangeKernel(
bm_kernel, cl::NullRange, bm_kernel, cl::NullRange,
cl::NDRange(gws[0], gws[1], gws[2]), cl::NDRange(gws[0], gws[1], gws[2]),
...@@ -79,7 +63,6 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()( ...@@ -79,7 +63,6 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
MACE_CHECK(error == CL_SUCCESS); MACE_CHECK(error == CL_SUCCESS);
return error; return error;
}; };
}
std::stringstream ss; std::stringstream ss;
ss << "batch_norm_opencl_kernel_" ss << "batch_norm_opencl_kernel_"
<< input->dim(0) << "_" << input->dim(0) << "_"
......
...@@ -10,18 +10,14 @@ ...@@ -10,18 +10,14 @@
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <fstream> #include <fstream>
#include <thread>
#include <limits> #include <limits>
#include "mace/core/logging.h" #include "mace/core/logging.h"
#include "mace/utils/utils.h" #include "mace/utils/utils.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
namespace mace { namespace mace {
// Tuning mode is enabled when the MACE_TUNING environment variable
// is set and its first character is '1'.
bool Tuning() {
  const char *flag = getenv("MACE_TUNING");
  if (flag == nullptr) {
    return false;
  }
  return flag[0] == '1';
}
template<typename param_type> template<typename param_type>
class Tuner { class Tuner {
...@@ -34,22 +30,22 @@ class Tuner { ...@@ -34,22 +30,22 @@ class Tuner {
template <typename RetType> template <typename RetType>
RetType TuneOrRun(const std::string param_key, RetType TuneOrRun(const std::string param_key,
const std::vector<param_type> &default_param, const std::vector<param_type> &default_param,
const std::function<std::vector<std::vector<param_type>>()> param_generator, const std::function<std::vector<std::vector<param_type>>()> &param_generator,
const std::function<RetType(const std::vector<param_type> &)>& func) { const std::function<RetType(const std::vector<param_type> &)> &func) {
if (param_generator == nullptr) { if (IsTuning()) {
// tune
std::vector<param_type> opt_param = default_param;
RetType res = Tune<RetType>(param_generator, func, opt_param);
param_table_[param_key] = opt_param;
return res;
} else {
// run // run
if (param_table_.find(param_key) != param_table_.end()) { if (param_table_.find(param_key) != param_table_.end()) {
return func(param_table_[param_key]); return func(param_table_[param_key]);
} else { } else {
return func(default_param); return func(default_param);
} }
} else {
// tune
std::vector<param_type> opt_param = default_param;
RetType res = Tune<RetType>(param_generator, func, opt_param);
param_table_[param_key] = opt_param;
return res;
} }
} }
...@@ -66,6 +62,11 @@ class Tuner { ...@@ -66,6 +62,11 @@ class Tuner {
Tuner(const Tuner&) = delete; Tuner(const Tuner&) = delete;
Tuner& operator=(const Tuner&) = delete; Tuner& operator=(const Tuner&) = delete;
// Returns true only when the MACE_TUNING environment variable is set
// to exactly the single character '1' (e.g. "11" or "" do not count).
inline bool IsTuning() {
  const char *flag = getenv("MACE_TUNING");
  if (flag == nullptr) {
    return false;
  }
  return strlen(flag) == 1 && flag[0] == '1';
}
inline void WriteRunParameters() { inline void WriteRunParameters() {
VLOG(0) << path_; VLOG(0) << path_;
if (path_ != nullptr) { if (path_ != nullptr) {
...@@ -127,24 +128,18 @@ class Tuner { ...@@ -127,24 +128,18 @@ class Tuner {
inline RetType Run(const std::function<RetType(const std::vector<param_type> &)> &func, inline RetType Run(const std::function<RetType(const std::vector<param_type> &)> &func,
const std::vector<param_type> &params, const std::vector<param_type> &params,
int num_runs, int num_runs,
int64_t sleep_millisecond,
double &time_us) { double &time_us) {
RetType res; RetType res;
int64_t total_time_us = 0; int64_t total_time_us = 0;
int64_t actual_num_runs = 0;
bool util_max_time = (num_runs <= 0);
for (int i = 0; util_max_time || i < num_runs; ++i) {
const int64_t start_time = NowInMicroSec(); const int64_t start_time = NowInMicroSec();
for (int i = 0; i < num_runs; ++i) {
res = func(params); res = func(params);
}
OpenCLRuntime::Get()->command_queue().finish();
const int64_t end_time = NowInMicroSec(); const int64_t end_time = NowInMicroSec();
total_time_us += end_time - start_time; total_time_us += end_time - start_time;
++(actual_num_runs);
if (sleep_millisecond > 0) { time_us = total_time_us * 1.0 / num_runs;
std::this_thread::sleep_for(std::chrono::milliseconds(sleep_millisecond));
}
}
time_us = total_time_us * 1.0 / actual_num_runs;
return res; return res;
} }
...@@ -158,10 +153,10 @@ class Tuner { ...@@ -158,10 +153,10 @@ class Tuner {
for (const auto &param: params) { for (const auto &param: params) {
double tmp_time = 0.0; double tmp_time = 0.0;
// warm up // warm up
Run<RetType>(func, param, 2, 10, tmp_time); Run<RetType>(func, param, 2, tmp_time);
// run // run
RetType tmp_res = Run<RetType>(func, param, 10, 10, tmp_time); RetType tmp_res = Run<RetType>(func, param, 10, tmp_time);
// Check the execution time // Check the execution time
if (tmp_time < opt_time) { if (tmp_time < opt_time) {
......
// //
// Copyright (c) 2017 XiaoMi All rights reserved. // Copyright (c) 2017 XiaoMi All rights reserved.
// //
#include <thread>
#include "gtest/gtest.h" #include "gtest/gtest.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册