提交 1bbf62ee 编写于 作者: L liuqi

Add auto-tuning code.

上级 fd61adfd
......@@ -22,6 +22,7 @@ cc_library(
"//mace/core",
"//mace/core:opencl_runtime",
"//mace/utils",
"//mace/utils:tuner",
],
)
......
......@@ -5,6 +5,7 @@
#include "mace/kernels/batch_norm.h"
#include "mace/core/runtime/opencl/cl2.hpp"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/utils/tuner.h"
namespace mace {
namespace kernels {
......@@ -29,7 +30,7 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
auto bm_kernel = cl::Kernel(program, "batch_norm");
const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(bm_kernel);
const uint32_t lws[3] = {1, 1, kwg_size};
const std::vector<uint32_t> lws = {1, 1, kwg_size};
uint32_t idx = 0;
bm_kernel.setArg(idx++, *(static_cast<const cl::Buffer *>(input->buffer())));
......@@ -43,12 +44,52 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
bm_kernel.setArg(idx++, lws[1] * sizeof(float), nullptr);
bm_kernel.setArg(idx++, lws[1] * sizeof(float), nullptr);
//TODO need to design the new way to tune.
cl_int error = runtime->command_queue().enqueueNDRangeKernel(
bm_kernel, cl::NullRange,
cl::NDRange(gws[0], gws[1], gws[2]),
cl::NDRange(lws[0], lws[1], lws[2]));
MACE_CHECK(error == CL_SUCCESS);
std::function<std::vector<std::vector<uint32_t>>()> params_generator = nullptr;
std::function<cl_int(const std::vector<uint32_t>& params)> func;
if (Tuning()) {
params_generator = [&kwg_size]()->std::vector<std::vector<uint32_t>> {
return {{1, 1, 64},
{1, 1, 128},
{1, kwg_size/32, 32},
{1, kwg_size/64, 64},
{1, kwg_size/128, 128},
{1, 1, kwg_size},
{1, kwg_size, 1}};
};
func = [&](const std::vector<uint32_t>& params)->cl_int {
cl::Event event;
cl_int error = runtime->command_queue().enqueueNDRangeKernel(
bm_kernel, cl::NullRange,
cl::NDRange(gws[0], gws[1], gws[2]),
cl::NDRange(params[0], params[1], params[2]),
nullptr,
&event);
MACE_CHECK(error == CL_SUCCESS);
event.wait();
return error;
};
} else {
func = [&](const std::vector<uint32_t>& params)->cl_int {
cl_int error = runtime->command_queue().enqueueNDRangeKernel(
bm_kernel, cl::NullRange,
cl::NDRange(gws[0], gws[1], gws[2]),
cl::NDRange(params[0], params[1], params[2]));
MACE_CHECK(error == CL_SUCCESS);
return error;
};
}
std::stringstream ss;
ss << "batch_norm_opencl_kernel_"
<< input->dim(0) << "_"
<< input->dim(1) << "_"
<< input->dim(2) << "_"
<< input->dim(3);
Tuner<uint32_t>::Get()->template TuneOrRun<cl_int>(ss.str(),
lws,
params_generator,
func);
}
} // namespace kernels
......
......@@ -31,6 +31,11 @@ static void BatchNorm(
net.AddRandomInput<D, T>("Var", {channels}, true);
net.AddInputFromArray<D, float>("Epsilon", {}, {1e-3});
// tuning
setenv("MACE_TUNING", "1", 1);
net.RunOp(D);
unsetenv("MACE_TUNING");
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
......@@ -55,8 +60,6 @@ static void BatchNorm(
BENCHMARK(BM_BATCH_NORM_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
#define BM_BATCH_NORM(N, C, H, W, TYPE) \
BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, CPU); \
BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, NEON); \
BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, OPENCL);
BM_BATCH_NORM(1, 1, 512, 512, float);
......
......@@ -165,7 +165,12 @@ TEST_F(BatchNormOpTest, SimpleRandomOPENCL) {
net.AddRandomInput<DeviceType::OPENCL, float>("Var", {channels}, true);
net.AddInputFromArray<DeviceType::OPENCL, float>("Epsilon", {}, {1e-3});
// Run NEON
// tuning
setenv("MACE_TUNING", "1", 1);
net.RunOp(DeviceType::OPENCL);
unsetenv("MACE_TUNING");
// Run on opencl
net.RunOp(DeviceType::OPENCL);
// Check
......@@ -206,7 +211,12 @@ TEST_F(BatchNormOpTest, ComplexRandomOPENCL) {
net.AddRandomInput<DeviceType::OPENCL, float>("Var", {channels}, true);
net.AddInputFromArray<DeviceType::OPENCL, float>("Epsilon", {}, {1e-3});
// Run NEON
// tuning
setenv("MACE_TUNING", "1", 1);
net.RunOp(DeviceType::OPENCL);
unsetenv("MACE_TUNING");
// Run on opencl
net.RunOp(DeviceType::OPENCL);
net.Sync();
......
......@@ -7,6 +7,8 @@ package(
licenses(["notice"]) # Apache 2.0
load("//mace:mace.bzl", "if_android")
cc_library(
name = "command_line_flags",
srcs = [
......@@ -39,3 +41,19 @@ cc_library(
"//mace/core",
],
)
# Unit-test target for the Tuner utility (//mace/utils:tuner).
# Built as a static binary so it can be pushed to and run on an Android
# device; -lm/-ldl are only needed for Android links.
cc_test(
name = "tuner_test",
testonly = 1,
srcs = [
"tuner_test.cc",
],
copts = ["-std=c++11"],
linkopts = if_android(["-lm", "-ldl"]),
linkstatic = 1,
deps = [
":tuner",
"@gtest//:gtest",
"@gtest//:gtest_main",
],
)
......@@ -10,13 +10,19 @@
#include <string>
#include <unordered_map>
#include <fstream>
#include <chrono>
#include <thread>
#include <limits>
#include "mace/core/logging.h"
#include "mace/utils/utils.h"
namespace mace {
// Returns true when tuning mode is requested: the MACE_TUNING environment
// variable must be present and its value must begin with '1'.
bool Tuning() {
  const char *flag = getenv("MACE_TUNING");
  if (flag == nullptr) {
    return false;
  }
  return flag[0] == '1';
}
template<typename param_type>
class Tuner {
public:
......@@ -24,29 +30,32 @@ class Tuner {
static Tuner tuner;
return &tuner;
}
void TuneOrRun(const std::string &param_key,
template <typename RetType>
RetType TuneOrRun(const std::string param_key,
const std::vector<param_type> &default_param,
std::function<std::vector<std::vector<param_type>>()> param_generator,
const std::function<void(const std::vector<param_type> &)> &func) {
const std::function<std::vector<std::vector<param_type>>()> param_generator,
const std::function<RetType(const std::vector<param_type> &)>& func) {
if (param_generator == nullptr) {
// run
if (param_table_.find(param_key) != param_table_.end()) {
func(param_table_[param_key]);
return func(param_table_[param_key]);
} else {
func(default_param);
return func(default_param);
}
} else {
// tune
std::vector<param_type> opt_param = default_param;
Tune(param_generator, func, opt_param);
RetType res = Tune<RetType>(param_generator, func, opt_param);
param_table_[param_key] = opt_param;
return res;
}
}
private:
// Constructs the singleton: records the persistence file path from the
// MACE_RUN_PARAMETER_PATH environment variable (may be unset, leaving
// path_ null) and loads any previously tuned parameters from it.
Tuner() {
  path_ = getenv("MACE_RUN_PARAMETER_PATH");
  ReadRunParamters();  // NOTE(review): "Paramters" typo lives in the method name itself.
}
......@@ -58,19 +67,24 @@ class Tuner {
Tuner& operator=(const Tuner&) = delete;
inline void WriteRunParameters() {
VLOG(0) << path_;
if (path_ != nullptr) {
std::ofstream ofs(path_, std::ios::binary | std::ios::out);
if (ofs.is_open()) {
size_t num_pramas = param_table_.size();
ofs.write(reinterpret_cast<char *>(&num_pramas), sizeof(num_pramas));
for (auto &kp : param_table_) {
int32_t key_size = kp.first.size() + 1;
ofs.write(static_cast<char*>(&key_size), sizeof(key_size));
ofs.write(&kp.first.c_str(), key_size);
int32_t key_size = kp.first.size();
ofs.write(reinterpret_cast<char *>(&key_size), sizeof(key_size));
ofs.write(kp.first.c_str(), key_size);
VLOG(0) << kp.first.c_str();
auto &params = kp.second;
int32_t params_size = params.size() * sizeof(param_type);
ofs.write(static_cast<char*>(&params_size), sizeof(params_size));
ofs.write(reinterpret_cast<char*>(&params_size), sizeof(params_size));
for (auto &param : params) {
ofs.write(&param, sizeof(params_size));
ofs.write(reinterpret_cast<char *>(&param), sizeof(params_size));
VLOG(0) << param;
}
}
ofs.close();
......@@ -87,43 +101,76 @@ class Tuner {
int32_t key_size = 0;
int32_t params_size = 0;
int32_t params_count = 0;
while (!ifs.eof()) {
ifs.read(static_cast<char*>(&key_size), sizeof(key_size));
std::string key(key_size, '');
size_t num_pramas = 0;
ifs.read(reinterpret_cast<char *>(&num_pramas), sizeof(num_pramas));
while (num_pramas--) {
ifs.read(reinterpret_cast<char *>(&key_size), sizeof(key_size));
std::string key(key_size, ' ');
ifs.read(&key[0], key_size);
ifs.read(static_cast<char*>(&params_size), sizeof(params_size));
ifs.read(reinterpret_cast<char *>(&params_size), sizeof(params_size));
params_count = params_size / sizeof(param_type);
std::vector<param_type> params(params_count);
for (int i = 0; i < params_count; ++i) {
ifs.read(&params[i], sizeof(param_type));
ifs.read(reinterpret_cast<char*>(&params[i]), sizeof(param_type));
}
param_table_.emplace(key, params);
}
ifs.close();
} else {
LOG(WARNING) << "Write run parameter file failed.";
LOG(WARNING) << "Read run parameter file failed.";
}
}
}
inline void Tune(std::function<std::vector<std::vector<param_type>>()> param_generator,
const std::function<void(const std::vector<param_type> &)> &func,
// Executes func(params) num_runs times, reporting the mean wall-clock time
// per run through time_us (microseconds) and returning the last result.
// An optional sleep between runs lets the device cool down / settle.
//
// NOTE(review): when num_runs <= 0 the loop condition never terminates —
// the name "util_max_time" suggests a "run until max time" mode, but no
// time cap is checked, so this is an infinite loop. Confirm callers always
// pass num_runs > 0, or add a time limit.
template <typename RetType>
inline RetType Run(const std::function<RetType(const std::vector<param_type> &)> &func,
const std::vector<param_type> &params,
int num_runs,
int64_t sleep_millisecond,
double &time_us) {
RetType res;
int64_t total_time_us = 0;
int64_t actual_num_runs = 0;
// "until max time" flag — see NOTE(review) above about the missing cap.
bool util_max_time = (num_runs <= 0);
for (int i = 0; util_max_time || i < num_runs; ++i) {
// NowInMicroSec() is a project helper; presumably a monotonic clock — verify.
const int64_t start_time = NowInMicroSec();
res = func(params);
const int64_t end_time = NowInMicroSec();
total_time_us += end_time - start_time;
++(actual_num_runs);
// Optional cool-down between timed iterations.
if (sleep_millisecond > 0) {
std::this_thread::sleep_for(std::chrono::milliseconds(sleep_millisecond));
}
}
// Mean time per run; actual_num_runs >= 1 whenever this line is reached.
time_us = total_time_us * 1.0 / actual_num_runs;
return res;
}
// Benchmarks every candidate produced by param_generator with Run() and
// writes the fastest parameter set into opt_params, returning the result
// func produced for that winner. (This replaces the old single-shot
// high_resolution_clock timing; stale pre-refactor lines that double-ran
// func and compared the raw duration have been removed.)
template <typename RetType>
inline RetType Tune(std::function<std::vector<std::vector<param_type>>()> param_generator,
                    const std::function<RetType(const std::vector<param_type> &)> &func,
                    std::vector<param_type> &opt_params) {
  RetType res;
  double opt_time = std::numeric_limits<double>::max();
  auto params = param_generator();
  for (const auto &param : params) {
    double tmp_time = 0.0;
    // Warm up: 2 untimed-for-ranking runs with a 10 ms cool-down.
    Run<RetType>(func, param, 2, 10, tmp_time);
    // Measure: mean over 10 runs, 10 ms cool-down between them.
    RetType tmp_res = Run<RetType>(func, param, 10, 10, tmp_time);
    // Keep the fastest candidate seen so far.
    if (tmp_time < opt_time) {
      opt_time = tmp_time;
      opt_params = param;
      res = tmp_res;
    }
  }
  return res;
}
private:
......
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "gtest/gtest.h"
#include "mace/utils/tuner.h"
namespace mace {
// Fixture: give every test a clean parameter file and point the tuner at it.
class TunerTest: public ::testing::Test {
 protected:
  virtual void SetUp() {
    // Drop results persisted by a previous test run.
    remove("/data/local/tmp/mace.config");
    // Fix: the Tuner constructor reads MACE_RUN_PARAMETER_PATH; the old
    // value "MACE_RUN_PARAMTER_PATH" (typo) was never read, so parameters
    // were silently not persisted.
    // NOTE(review): Tuner is a lazily-constructed singleton, so this env var
    // only takes effect if set before the first Tuner<T>::Get() call — verify.
    setenv("MACE_RUN_PARAMETER_PATH", "/data/local/tmp/mace.config", 1);
  }
};
// With no parameter generator, TuneOrRun executes the functor once with the
// supplied defaults and forwards its return value unchanged.
TEST_F(TunerTest, SimpleRun) {
  int expect = 1;
  auto runner = [&expect](const std::vector<int> &params) -> int {
    return (params.front() == 1) ? expect : expect + 1;
  };
  std::vector<int> default_params{1};
  int res = Tuner<int>::Get()->template TuneOrRun<int>(
      "SimpleRun", default_params, nullptr, runner);
  EXPECT_EQ(expect, res);
  default_params[0] = 2;
  res = Tuner<int>::Get()->template TuneOrRun<int>(
      "SimpleRun", default_params, nullptr, runner);
  EXPECT_EQ(expect + 1, res);
}
// Tuning should pick the fastest candidate ({3}: every other candidate is
// penalized with a 10 ms sleep), and a subsequent run without a generator
// should reuse that cached winner.
TEST_F(TunerTest, SimpleTune) {
  int expect = 3;
  auto runner = [&expect](const std::vector<int> &params) -> int {
    if (params.front() != expect) {
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
      return params.front();
    }
    return expect;
  };
  std::vector<int> default_params{1};
  auto candidates = []() -> std::vector<std::vector<int>> {
    return {{1}, {2}, {3}, {4}};
  };
  // Tune pass (the plain lambda converts to std::function directly; the old
  // `*candidates` dereference relied on the captureless-lambda-to-function-
  // pointer conversion and behaved identically).
  int res = Tuner<int>::Get()->template TuneOrRun<int>(
      "SimpleRun", default_params, candidates, runner);
  EXPECT_EQ(expect, res);
  // Run pass: no generator, so the cached winner under "SimpleRun" is used.
  res = Tuner<int>::Get()->template TuneOrRun<int>(
      "SimpleRun", default_params, nullptr, runner);
  EXPECT_EQ(expect, res);
}
} // namespace mace
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册