diff --git a/mace/kernels/opencl/batch_norm_opencl.cc b/mace/kernels/opencl/batch_norm_opencl.cc index 4d0771f15819e71496100e441799801f35b191f2..fbb0d7a4043a4467d4ce7771db7ae4e4cb9d3e51 100644 --- a/mace/kernels/opencl/batch_norm_opencl.cc +++ b/mace/kernels/opencl/batch_norm_opencl.cc @@ -29,7 +29,7 @@ void BatchNormFunctor::operator()( auto bm_kernel = cl::Kernel(program, "batch_norm"); const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(bm_kernel); - const uint32_t lws[3] = {1, kwg_size/128, 128}; + const uint32_t lws[3] = {1, 1, kwg_size}; uint32_t idx = 0; bm_kernel.setArg(idx++, *(static_cast(input->buffer()))); @@ -43,6 +43,7 @@ void BatchNormFunctor::operator()( bm_kernel.setArg(idx++, lws[1] * sizeof(float), nullptr); bm_kernel.setArg(idx++, lws[1] * sizeof(float), nullptr); + //TODO need to design the new way to tune. cl_int error = runtime->command_queue().enqueueNDRangeKernel( bm_kernel, cl::NullRange, cl::NDRange(gws[0], gws[1], gws[2]), diff --git a/mace/ops/ops_test_util.h b/mace/ops/ops_test_util.h index 678f855fc7b98b5c7ec66f13e07e75ab121e067e..f8bb94c2d5fee156774222fff5d6a04e4cd4ca6b 100644 --- a/mace/ops/ops_test_util.h +++ b/mace/ops/ops_test_util.h @@ -137,13 +137,10 @@ class OpsTestNet { Workspace *ws() { return &ws_; } bool RunOp(DeviceType device) { - if (!net_ || device_ != device) { - NetDef net_def; - net_def.add_op()->CopyFrom(op_def_); - VLOG(3) << net_def.DebugString(); - net_ = CreateNet(net_def, &ws_, device); - device_ = device; - } + NetDef net_def; + net_def.add_op()->CopyFrom(op_def_); + VLOG(3) << net_def.DebugString(); + net_ = CreateNet(net_def, &ws_, device); return net_->Run(); } @@ -163,7 +160,6 @@ class OpsTestNet { Workspace ws_; OperatorDef op_def_; std::unique_ptr net_; - DeviceType device_; }; class OpsTestBase : public ::testing::Test { diff --git a/mace/utils/BUILD b/mace/utils/BUILD index 06e2ccc490aef5f1a75920bd7cb0afeb9172f64c..e193887aaba89278242721644b341f02290c4b2b 100644 --- a/mace/utils/BUILD +++ b/mace/utils/BUILD @@ -28,3 +28,14 @@ cc_library( ], copts = ["-std=c++11"], ) + +cc_library( + name = "tuner", + hdrs = [ + "tuner.h", + ], + copts = ["-std=c++11"], + deps = [ + "//mace/core", + ], +) diff --git a/mace/utils/tuner.h b/mace/utils/tuner.h new file mode 100644 index 0000000000000000000000000000000000000000..4ad87ccc4aa8ac859962f45944c90e42a010d20a --- /dev/null +++ b/mace/utils/tuner.h @@ -0,0 +1,135 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#ifndef MACE_UTILS_TUNER_H_ +#define MACE_UTILS_TUNER_H_ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "mace/core/logging.h" + +namespace mace { + +template +class Tuner { + public: + static Tuner* Get() { + static Tuner tuner; + return &tuner; + } + void TuneOrRun(const std::string ¶m_key, + const std::vector &default_param, + std::function>()> param_generator, + const std::function &)> &func) { + + if (param_generator == nullptr) { + // run + if (param_table_.find(param_key) != param_table_.end()) { + func(param_table_[param_key]); + } else { + func(default_param); + } + } else { + // tune + std::vector opt_param = default_param; + Tune(param_generator, func, opt_param); + param_table_[param_key] = opt_param; + } + } + + private: + Tuner() { + path_ = getenv("MACE_RUN_PARAMTER_PATH"); + ReadRunParamters(); + } + + ~Tuner() { + WriteRunParameters(); + } + + Tuner(const Tuner&) = delete; + Tuner& operator=(const Tuner&) = delete; + + inline void WriteRunParameters() { + if (path_ != nullptr) { + std::ofstream ofs(path_, std::ios::binary | std::ios::out); + if (ofs.is_open()) { + for (auto &kp : param_table_) { + int32_t key_size = kp.first.size() + 1; + ofs.write(static_cast(&key_size), sizeof(key_size)); + ofs.write(&kp.first.c_str(), key_size); + + auto ¶ms = kp.second; + int32_t params_size = params.size() * sizeof(param_type); + ofs.write(static_cast(¶ms_size), sizeof(params_size)); + for (auto ¶m : params) { + ofs.write(¶m, sizeof(params_size)); + } + } + ofs.close(); + } else { + LOG(WARNING) << "Write run parameter file failed."; + } + } + } + + inline void ReadRunParamters() { + if (path_ != nullptr) { + std::ifstream ifs(path_, std::ios::binary | std::ios::in); + if (ifs.is_open()) { + int32_t key_size = 0; + int32_t params_size = 0; + int32_t params_count = 0; + while (!ifs.eof()) { + ifs.read(static_cast(&key_size), sizeof(key_size)); + std::string key(key_size, ''); + ifs.read(&key[0], key_size); + + ifs.read(static_cast(¶ms_size), sizeof(params_size)); + params_count = params_size / sizeof(param_type); + std::vector params(params_count); + for (int i = 0; i < params_count; ++i) { + ifs.read(¶ms[i], sizeof(param_type)); + } + param_table_.emplace(key, params); + } + ifs.close(); + } else { + LOG(WARNING) << "Write run parameter file failed."; + } + } + } + + inline void Tune(std::function>()> param_generator, + const std::function &)> &func, + std::vector &opt_params) { + double opt_time = std::numeric_limits::max(); + auto params = param_generator(); + for (const auto ¶m: params) { + auto start = std::chrono::high_resolution_clock::now(); + func(param); + auto end = std::chrono::high_resolution_clock::now(); + auto duration_time = end - start; + + // Check the execution time + if (duration_time.count() < opt_time) { + opt_time = duration_time.count(); + opt_params = param; + } + } + } + + private: + const char* path_; + std::unordered_map> param_table_; +}; + +} // namespace mace +#endif // MACE_UTILS_TUNER_H_