提交 ff3fac69 编写于 作者: 刘琦

Merge branch 'hide-tuning-params' into 'master'

Obfuscate tuning parameters

See merge request !204
......@@ -96,6 +96,7 @@ cc_library(
deps = [
":opencl_headers",
"//mace/codegen:generated_opencl_dev",
"//mace/utils:utils_hdrs",
],
)
......
......@@ -5,24 +5,11 @@
#include <vector>
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/utils/utils.h"
namespace mace {
namespace {
// Decrypts an embedded OpenCL source blob by XOR-ing each byte of |src| with
// a repeating key and appending the result to |*dst|.
//
// src: encrypted bytes.
// dst: output vector, must be non-null; decrypted bytes are appended after
//      whatever it already contains.
inline void DecryptOpenCLSource(const std::vector<unsigned char> &src,
                                std::vector<unsigned char> *dst) {
  // Keep consistent with encrypt in python tool
  const std::string decrypt_lookup_table = "Xiaomi-AI-Platform-Mace";
  const size_t lookup_table_size = decrypt_lookup_table.size();
  // Reserve for the appended bytes on top of any existing content.
  dst->reserve(dst->size() + src.size());
  // size_t index: no signed/unsigned comparison against src.size().
  for (size_t i = 0; i < src.size(); ++i) {
    dst->push_back(src[i] ^ decrypt_lookup_table[i % lookup_table_size]);
  }
}
} // namespace
bool GetSourceOrBinaryProgram(const std::string &program_name,
const std::string &binary_file_name_prefix,
cl::Context &context,
......@@ -36,9 +23,8 @@ bool GetSourceOrBinaryProgram(const std::string &program_name,
return false;
}
cl::Program::Sources sources;
std::vector<unsigned char> decrypt_source;
DecryptOpenCLSource(it_source->second, &decrypt_source);
sources.push_back(std::string(decrypt_source.begin(), decrypt_source.end()));
std::string kernel_source(it_source->second.begin(), it_source->second.end());
sources.push_back(ObfuscateString(kernel_source));
*program = cl::Program(context, sources);
return true;
......
......@@ -124,6 +124,9 @@ cl::CommandQueue &OpenCLRuntime::command_queue() { return *command_queue_; }
std::string OpenCLRuntime::GenerateCLBinaryFilenamePrefix(
const std::string &filename_msg) {
#ifdef MACE_OBFUSCATE_LITERALS
return ObfuscateSymbolWithCollision(filename_msg);
#else
std::string filename_prefix = filename_msg;
for (auto it = filename_prefix.begin(); it != filename_prefix.end(); ++it) {
if (*it == ' ' || *it == '-' || *it == '=') {
......@@ -131,6 +134,7 @@ std::string OpenCLRuntime::GenerateCLBinaryFilenamePrefix(
}
}
return filename_prefix;
#endif
}
extern bool GetSourceOrBinaryProgram(const std::string &program_name,
......@@ -219,7 +223,7 @@ cl::Kernel OpenCLRuntime::BuildKernel(
program = built_program_it->second;
} else {
std::string binary_file_name_prefix =
GenerateCLBinaryFilenamePrefix(built_program_key);
GenerateCLBinaryFilenamePrefix(built_program_key);
this->BuildProgram(program_name, binary_file_name_prefix,
build_options_str, &program);
built_program_map_.emplace(built_program_key, program);
......
......@@ -31,10 +31,12 @@ static void AddN(const std::vector<const Tensor *> &input_tensors,
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
auto dt = DataTypeToEnum<T>::value;
std::string kernel_name = MACE_KERNRL_NAME("addn");
built_options.emplace("-Daddn=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
built_options.emplace("-DINPUT_NUM=" + ToString(input_tensors.size()));
auto addn_kernel = runtime->BuildKernel("addn", "addn", built_options);
auto addn_kernel = runtime->BuildKernel("addn", kernel_name, built_options);
uint32_t idx = 0;
for (auto input : input_tensors) {
......
......@@ -34,6 +34,8 @@ void BatchNormFunctor<DeviceType::OPENCL, T>::operator()(
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
auto dt = DataTypeToEnum<T>::value;
std::string kernel_name = MACE_KERNRL_NAME("batch_norm");
built_options.emplace("-Dbatch_norm=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
if (folded_constant_) {
......@@ -42,7 +44,7 @@ void BatchNormFunctor<DeviceType::OPENCL, T>::operator()(
if (fused_relu_) {
built_options.emplace("-DFUSED_RELU");
}
auto bm_kernel = runtime->BuildKernel("batch_norm", "batch_norm", built_options);
auto bm_kernel = runtime->BuildKernel("batch_norm", kernel_name, built_options);
uint32_t idx = 0;
bm_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
......
......@@ -31,9 +31,11 @@ void BiasAddFunctor<DeviceType::OPENCL, T>::operator()(
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
auto dt = DataTypeToEnum<T>::value;
std::string kernel_name = MACE_KERNRL_NAME("bias_add");
built_options.emplace("-Dbias_add=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
auto bias_kernel = runtime->BuildKernel("bias_add", "bias_add", built_options);
auto bias_kernel = runtime->BuildKernel("bias_add", kernel_name, built_options);
const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(bias_kernel);
const std::vector<uint32_t> lws = {8, 16, 8};
......
......@@ -25,15 +25,6 @@ void BufferToImageFunctor<DeviceType::OPENCL, T>::operator()(Tensor *buffer,
buffer->Resize(image->shape());
}
std::set<std::string> built_options;
if (buffer->dtype() == image->dtype()) {
built_options.emplace("-DDATA_TYPE=" + DtToCLDt(DataTypeToEnum<T>::value));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToCLCMDDt(DataTypeToEnum<T>::value));
} else {
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(DataTypeToEnum<T>::value));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(DataTypeToEnum<T>::value));
}
auto runtime = OpenCLRuntime::Global();
string kernel_name;
switch (type) {
case FILTER:
......@@ -46,8 +37,21 @@ void BufferToImageFunctor<DeviceType::OPENCL, T>::operator()(Tensor *buffer,
kernel_name = i2b_ ? "arg_image_to_buffer" : "arg_buffer_to_image";
break;
}
string obfuscated_kernel_name = MACE_KERNRL_NAME(kernel_name);
std::set<std::string> built_options;
std::stringstream kernel_name_ss;
kernel_name_ss << "-D" << kernel_name << "=" << obfuscated_kernel_name;
built_options.emplace(kernel_name_ss.str());
if (buffer->dtype() == image->dtype()) {
built_options.emplace("-DDATA_TYPE=" + DtToCLDt(DataTypeToEnum<T>::value));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToCLCMDDt(DataTypeToEnum<T>::value));
} else {
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(DataTypeToEnum<T>::value));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(DataTypeToEnum<T>::value));
}
auto runtime = OpenCLRuntime::Global();
auto b2f_kernel = runtime->BuildKernel("buffer_to_image",
kernel_name,
obfuscated_kernel_name,
built_options);
uint32_t idx = 0;
......
......@@ -25,6 +25,8 @@ static void Concat2(const Tensor *input0,
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
std::string kernel_name = MACE_KERNRL_NAME("concat_channel");
built_options.emplace("-Dconcat_channel=" + kernel_name);
if (input0->dtype() == output->dtype()) {
built_options.emplace("-DDATA_TYPE=" + DtToCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToCLCMDDt(dt));
......@@ -35,7 +37,7 @@ static void Concat2(const Tensor *input0,
if (input0->dim(3) % 4 == 0) {
built_options.emplace("-DDIVISIBLE_FOUR");
}
auto concat_kernel = runtime->BuildKernel("concat", "concat_channel", built_options);
auto concat_kernel = runtime->BuildKernel("concat", kernel_name, built_options);
uint32_t idx = 0;
concat_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input0->buffer())));
......
......@@ -36,6 +36,8 @@ void Conv1x1(const Tensor *input,
MACE_CHECK(input_batch == batch);
std::set<std::string> built_options;
std::string kernel_name = MACE_KERNRL_NAME("conv_2d_1x1");
built_options.emplace("-Dconv_2d_1x1=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
built_options.emplace("-DSTRIDE=" + ToString(stride));
......@@ -47,7 +49,7 @@ void Conv1x1(const Tensor *input,
}
auto runtime = OpenCLRuntime::Global();
auto conv_2d_kernel = runtime->BuildKernel("conv_2d_1x1", "conv_2d_1x1", built_options);
auto conv_2d_kernel = runtime->BuildKernel("conv_2d_1x1", kernel_name, built_options);
uint32_t idx = 0;
conv_2d_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
......
......@@ -28,6 +28,8 @@ static void Conv2d3x3S12(const Tensor *input, const Tensor *filter,
const index_t width_blocks = RoundUpDiv<index_t, 5>(width);
std::set<std::string> built_options;
std::string kernel_name = MACE_KERNRL_NAME("conv_2d_3x3");
built_options.emplace("-Dconv_2d_3x3=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
built_options.emplace(bias != nullptr ? "-DBIAS" : "");
......@@ -37,7 +39,7 @@ static void Conv2d3x3S12(const Tensor *input, const Tensor *filter,
}
auto runtime = OpenCLRuntime::Global();
auto conv_2d_kernel = runtime->BuildKernel("conv_2d_3x3", "conv_2d_3x3", built_options);
auto conv_2d_kernel = runtime->BuildKernel("conv_2d_3x3", kernel_name, built_options);
uint32_t idx = 0;
conv_2d_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
......
......@@ -28,6 +28,8 @@ void Conv2dOpencl(const Tensor *input, const Tensor *filter,
const index_t width_blocks = RoundUpDiv4(width);
std::set<std::string> built_options;
std::string kernel_name = MACE_KERNRL_NAME("conv_2d");
built_options.emplace("-Dconv_2d=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
built_options.emplace(bias != nullptr ? "-DBIAS" : "");
......@@ -37,7 +39,7 @@ void Conv2dOpencl(const Tensor *input, const Tensor *filter,
}
auto runtime = OpenCLRuntime::Global();
auto conv_2d_kernel = runtime->BuildKernel("conv_2d", "conv_2d", built_options);
auto conv_2d_kernel = runtime->BuildKernel("conv_2d", kernel_name, built_options);
uint32_t idx = 0;
conv_2d_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
......
......@@ -33,10 +33,12 @@ static void InnerDepthwiseConvOpenclK3x3S12(const Tensor *input,
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
std::string kernel_name = MACE_KERNRL_NAME("depthwise_conv_3x3");
built_options.emplace("-Ddepthwise_conv_3x3=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(input->dtype()));
built_options.emplace(stride == 1 ? "-DSTRIDE_1" : "");
built_options.emplace(bias != nullptr ? "-DBIAS" : "");
auto conv_kernel = runtime->BuildKernel("depthwise_conv_3x3", "depthwise_conv_3x3", built_options);
auto conv_kernel = runtime->BuildKernel("depthwise_conv_3x3", kernel_name, built_options);
uint32_t idx = 0;
conv_kernel.setArg(idx++, *(static_cast<const cl::Buffer *>(input->buffer())));
......
......@@ -5,6 +5,7 @@
#ifndef MACE_KERNELS_OPENCL_HELPER_H_
#define MACE_KERNELS_OPENCL_HELPER_H_
#include "mace/core/types.h"
#include "mace/utils/utils.h"
namespace mace {
namespace kernels {
......
......@@ -28,6 +28,8 @@ static void Pooling(const Tensor *input,
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
std::string kernel_name = MACE_KERNRL_NAME("pooling");
built_options.emplace("-Dpooling=" + kernel_name);
if (type == MAX && input->dtype() == output->dtype()) {
built_options.emplace("-DDATA_TYPE=" + DtToCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToCLCMDDt(dt));
......@@ -39,7 +41,7 @@ static void Pooling(const Tensor *input,
if (type == AVG) {
built_options.emplace("-DPOOL_AVG");
}
auto pooling_kernel = runtime->BuildKernel("pooling", "pooling", built_options);
auto pooling_kernel = runtime->BuildKernel("pooling", kernel_name, built_options);
uint32_t idx = 0;
pooling_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
......
......@@ -32,13 +32,17 @@ void ReluFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
cl::Kernel relu_kernel;
if (max_limit_ < 0) {
relu_kernel = runtime->BuildKernel("relu", "relu", built_options);
std::string kernel_name = MACE_KERNRL_NAME("relu");
built_options.emplace("-Drelu=" + kernel_name);
relu_kernel = runtime->BuildKernel("relu", kernel_name, built_options);
uint32_t idx = 0;
relu_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
relu_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer())));
} else {
relu_kernel = runtime->BuildKernel("relu", "relux", built_options);
std::string kernel_name = MACE_KERNRL_NAME("relux");
built_options.emplace("-Drelux=" + kernel_name);
relu_kernel = runtime->BuildKernel("relu", kernel_name, built_options);
uint32_t idx = 0;
relu_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
......
......@@ -40,10 +40,12 @@ void ResizeBilinearFunctor<DeviceType::OPENCL, T>::operator()(
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
std::string kernel_name = MACE_KERNRL_NAME("resize_bilinear_nocache");
built_options.emplace("-Dresize_bilinear_nocache=" + kernel_name);
auto dt = DataTypeToEnum<T>::value;
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
auto rb_kernel = runtime->BuildKernel("resize_bilinear", "resize_bilinear_nocache", built_options);
auto rb_kernel = runtime->BuildKernel("resize_bilinear", kernel_name, built_options);
uint32_t idx = 0;
rb_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
......
......@@ -26,10 +26,12 @@ void SoftmaxFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *logits,
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
std::string kernel_name = MACE_KERNRL_NAME("softmax");
built_options.emplace("-Dsoftmax=" + kernel_name);
auto dt = DataTypeToEnum<T>::value;
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
cl::Kernel softmax_kernel = runtime->BuildKernel("softmax", "softmax", built_options);
cl::Kernel softmax_kernel = runtime->BuildKernel("softmax", kernel_name, built_options);
uint32_t idx = 0;
softmax_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(logits->buffer())));
......
......@@ -30,8 +30,12 @@ void SpaceToBatchFunctor<DeviceType::OPENCL, T>::operator()(Tensor *space_tensor
batch_tensor->ResizeImage(output_shape, output_image_shape);
kernel_name = "space_to_batch";
}
std::string obfuscated_kernel_name = MACE_KERNRL_NAME(kernel_name);
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
std::stringstream kernel_name_ss;
kernel_name_ss << "-D" << kernel_name << "=" << obfuscated_kernel_name;
built_options.emplace(kernel_name_ss.str());
built_options.emplace("-DDATA_TYPE=" + DtToCLDt(DataTypeToEnum<T>::value));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToCLCMDDt(DataTypeToEnum<T>::value));
auto s2b_kernel = runtime->BuildKernel("space_to_batch", kernel_name, built_options);
......
......@@ -27,6 +27,7 @@ def generate_cpp_source():
key_size, = struct.unpack("i", binary_array[idx:idx+4])
idx += 4
key, = struct.unpack(str(key_size) + "s", binary_array[idx:idx+key_size])
key = ''.join([ "\\%03o" % ord(c) if not c.isalnum() else c for c in key])
idx += key_size
params_size, = struct.unpack("i", binary_array[idx:idx+4])
idx += 4
......
......@@ -82,3 +82,21 @@ cc_library(
":utils_hdrs",
],
)
# Test target "utils_test", built from utils_test.cc and linked statically
# against the :utils_hdrs target and gtest (including gtest_main).
cc_test(
name = "utils_test",
testonly = 1,
srcs = [
"utils_test.cc",
],
# Extra linker flags wrapped in if_android() -- presumably only applied to
# Android builds; confirm against the if_android definition.
linkopts = if_android([
"-pie",
"-lm",
]),
linkstatic = 1,
deps = [
":utils_hdrs",
"@gtest//:gtest",
"@gtest//:gtest_main",
],
)
......@@ -15,6 +15,10 @@
#include "mace/utils/logging.h"
#include "mace/utils/timer.h"
#include "mace/utils/utils.h"
namespace {
} // namespace
namespace mace {
......@@ -42,22 +46,23 @@ class Tuner {
&param_generator,
const std::function<RetType(const std::vector<param_type> &)> &func,
Timer *timer) {
std::string obfucated_param_key = MACE_OBFUSCATE_SYMBOLS(param_key);
if (IsTuning() && param_generator != nullptr) {
// tune
std::vector<param_type> opt_param = default_param;
RetType res = Tune<RetType>(param_generator, func, timer, &opt_param);
VLOG(1) << "Tuning result. "
<< param_key << ": " << internal::MakeString(opt_param);
param_table_[param_key] = opt_param;
param_table_[obfucated_param_key] = opt_param;
return res;
} else {
// run
if (param_table_.find(param_key) != param_table_.end()) {
if (param_table_.find(obfucated_param_key) != param_table_.end()) {
VLOG(1) << param_key << ": "
<< internal::MakeString(param_table_[param_key]);
return func(param_table_[param_key]);
<< internal::MakeString(param_table_[obfucated_param_key]);
return func(param_table_[obfucated_param_key]);
} else {
LOG(WARNING) << "Fallback to default parameter: " << param_key;
LOG(WARNING) << "Fallback to default parameter: " << param_key << ", table size: " << param_table_.size();
return func(default_param);
}
}
......@@ -85,7 +90,8 @@ class Tuner {
int32_t key_size = kp.first.size();
ofs.write(reinterpret_cast<char *>(&key_size), sizeof(key_size));
ofs.write(kp.first.c_str(), key_size);
VLOG(1) << kp.first.c_str();
VLOG(1) << "Write tuning param: "
<< MACE_OBFUSCATE_SYMBOLS(kp.first.c_str());
auto &params = kp.second;
int32_t params_size = params.size() * sizeof(param_type);
......
......@@ -48,5 +48,44 @@ inline std::string ToString(T v) {
return ss.str();
}
// XORs every byte of |src| with the bytes of |lookup_table|, cycling through
// the table as often as needed, and returns the result (same length as
// |src|). Applying the function twice with the same table restores |src|.
//
// An empty |lookup_table| returns |src| unchanged rather than taking a
// modulo by zero.
inline std::string ObfuscateString(const std::string &src,
                                   const std::string &lookup_table) {
  const size_t table_size = lookup_table.size();
  if (table_size == 0) {
    return src;
  }
  std::string dest;
  dest.resize(src.size());
  // size_t index: no signed/unsigned comparison against src.size().
  for (size_t i = 0; i < src.size(); ++i) {
    dest[i] = src[i] ^ lookup_table[i % table_size];
  }
  // Plain return so the compiler can elide the copy; std::move would block it.
  return dest;
}
// Single-argument overload: obfuscates |src| with a fixed lookup table.
// ObfuscateString(ObfuscateString(str)) ==> str
inline std::string ObfuscateString(const std::string &src) {
  // Keep consistent with obfuscation in python tools
  const std::string fixed_lookup_table = "Xiaomi-AI-Platform-Mace";
  return ObfuscateString(src, fixed_lookup_table);
}
// Obfuscate symbol or path string.
// The input is first passed through ObfuscateString(src); every resulting
// byte is then mapped onto one of the 52 ASCII letters, so the output has the
// same length as |src| and contains only [A-Za-z]. The byte-to-letter mapping
// is many-to-one, so distinct inputs can yield the same output.
inline std::string ObfuscateSymbolWithCollision(const std::string &src) {
  std::string dest = ObfuscateString(src);
  const std::string encode_dict =
      "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
  for (size_t i = 0; i < dest.size(); ++i) {
    // Index through unsigned char so bytes >= 0x80 pick the same letter
    // whether plain char is signed or unsigned on the target platform.
    const unsigned char byte = static_cast<unsigned char>(dest[i]);
    dest[i] = encode_dict[byte % encode_dict.size()];
  }
  // Plain return so the compiler can elide the copy; std::move would block it.
  return dest;
}
// Obfuscation macros. With MACE_OBFUSCATE_LITERALS defined they forward to
// ObfuscateString / ObfuscateSymbolWithCollision; otherwise each macro
// expands to its argument unchanged.
#ifdef MACE_OBFUSCATE_LITERALS
// Obfuscates str with the single-argument ObfuscateString overload.
#define MACE_OBFUSCATE_STRING(str) ObfuscateString(str)
// This table is deliberately selected to avoid '\0' in generated literal
#define MACE_OBFUSCATE_SYMBOLS(str) ObfuscateString(str, "!@#$%^&*()+?")
// OpenCL will report error if there is name collision
#define MACE_KERNRL_NAME(name) ObfuscateSymbolWithCollision(name)
#else
// Obfuscation disabled: pass the argument through as-is.
#define MACE_OBFUSCATE_STRING(str) (str)
#define MACE_OBFUSCATE_SYMBOLS(str) (str)
#define MACE_KERNRL_NAME(name) (name)
#endif
} // namespace mace
#endif // MACE_UTILS_UTILS_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include <thread>
#include "gtest/gtest.h"
#include "mace/utils/tuner.h"
namespace mace {
// Fixture shared by the tuner tests: before every test it deletes the
// parameter file and sets the two environment variables below.
class TunerTest : public ::testing::Test {
 protected:
  void SetUp() override {
    const char *const parameter_file = "/data/local/tmp/mace.config";
    // Remove any parameter file left behind by an earlier run.
    remove(parameter_file);
    setenv("MACE_RUN_PARAMETER_PATH", parameter_file, 1);
    setenv("MACE_TUNING", "1", 1);
  }
};
// Calls TuneOrRun without a parameter generator and checks that the result
// follows the default parameters handed in: default 1 yields `expect`,
// default 2 yields `expect + 1`.
TEST_F(TunerTest, SimpleRun) {
  const int expect = 1;
  // Maps the first parameter to a result: 1 -> expect, anything else ->
  // expect + 1.
  auto result_for_params =
      [&](const std::vector<unsigned int> &params) -> int {
    return params.front() == 1 ? expect : expect + 1;
  };

  WallClockTimer timer;
  std::vector<unsigned int> default_params{1u};

  int res = Tuner<unsigned int>::Get()->template TuneOrRun<unsigned int>(
      "SimpleRun", default_params, nullptr, result_for_params, &timer);
  EXPECT_EQ(expect, res);

  // Same key, different default parameter.
  default_params[0] = 2;
  res = Tuner<unsigned int>::Get()->template TuneOrRun<unsigned int>(
      "SimpleRun", default_params, nullptr, result_for_params, &timer);
  EXPECT_EQ(expect + 1, res);
}
// Tunes over the candidates {1}, {2}, {3}, {4} and then calls TuneOrRun again
// without a generator; both calls are expected to return `expect` (3).
TEST_F(TunerTest, SimpleTune) {
  int expect = 3;
  // Every candidate other than `expect` sleeps 10 ms before returning its own
  // value, so the candidate equal to `expect` is the quickest one --
  // presumably that is what makes the tuner pick it.
  auto timed_func = [&](const std::vector<unsigned int> &params) -> int {
    if (params.front() != expect) {
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
      return params.front();
    }
    return expect;
  };

  std::vector<unsigned int> default_params{1u};
  auto candidate_generator = []() -> std::vector<std::vector<unsigned int>> {
    return {{1}, {2}, {3}, {4}};
  };

  WallClockTimer timer;

  // tune
  int res = Tuner<unsigned int>::Get()->template TuneOrRun<unsigned int>(
      "SimpleRun", default_params, *candidate_generator, timed_func, &timer);
  EXPECT_EQ(expect, res);

  // run
  res = Tuner<unsigned int>::Get()->template TuneOrRun<unsigned int>(
      "SimpleRun", default_params, nullptr, timed_func, &timer);
  EXPECT_EQ(expect, res);
}
} // namespace mace
......@@ -45,7 +45,7 @@ build_and_run()
round=0 # only warm up
else
tuning_flag=0
round=100
round=2
fi
bazel build --verbose_failures -c opt --strip always mace/examples:mace_run \
......@@ -56,6 +56,7 @@ build_and_run()
--copt="-D_GLIBCXX_USE_C99_MATH_TR1" \
--copt="-Werror=return-type" \
--copt="-DMACE_MODEL_TAG=${MODEL_TAG}" \
--copt="-DMACE_OBFUSCATE_LITERALS" \
$PRODUCTION_MODE_BUILD_FLAGS \
$TUNING_MODE_BUILD_FLAGS || exit -1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册