提交 4dd645b8 编写于 作者: L liuqi

Vectorize the batch norm OpenCL kernel.

上级 aeb4c35e
...@@ -22,12 +22,12 @@ void kernel batch_norm(global const float *input, ...@@ -22,12 +22,12 @@ void kernel batch_norm(global const float *input,
barrier(CLK_LOCAL_MEM_FENCE); barrier(CLK_LOCAL_MEM_FENCE);
const int sample_offset = (batch * channels + channel) * pixels + pixel_offset*4; const int image_offset = (batch * channels + channel) * pixels + pixel_offset*4;
const float *input_ptr = input + sample_offset; const float *input_ptr = input + image_offset;
float *output_ptr = output + sample_offset; float *output_ptr = output + image_offset;
const int end = (batch * channels + channel + 1) * pixels; const int end = (batch * channels + channel + 1) * pixels;
if ((sample_offset+4) > end) { if ((image_offset+4) > end) {
for (int i = sample_offset; i < end; ++i) { for (int i = image_offset; i < end; ++i) {
*output_ptr = new_scale[local_channel].x * *input_ptr + new_offset[local_channel].x; *output_ptr = new_scale[local_channel].x * *input_ptr + new_offset[local_channel].x;
++input_ptr; ++input_ptr;
++output_ptr; ++output_ptr;
......
...@@ -39,6 +39,7 @@ cc_library( ...@@ -39,6 +39,7 @@ cc_library(
copts = ["-std=c++11"], copts = ["-std=c++11"],
deps = [ deps = [
"//mace/core", "//mace/core",
"//mace/core:opencl_runtime",
], ],
) )
......
...@@ -33,7 +33,7 @@ class Tuner { ...@@ -33,7 +33,7 @@ class Tuner {
const std::function<std::vector<std::vector<param_type>>()> &param_generator, const std::function<std::vector<std::vector<param_type>>()> &param_generator,
const std::function<RetType(const std::vector<param_type> &)> &func) { const std::function<RetType(const std::vector<param_type> &)> &func) {
if (IsTuning()) { if (IsTuning() && param_generator != nullptr) {
// tune // tune
std::vector<param_type> opt_param = default_param; std::vector<param_type> opt_param = default_param;
RetType res = Tune<RetType>(param_generator, func, opt_param); RetType res = Tune<RetType>(param_generator, func, opt_param);
...@@ -68,7 +68,7 @@ class Tuner { ...@@ -68,7 +68,7 @@ class Tuner {
} }
inline void WriteRunParameters() { inline void WriteRunParameters() {
VLOG(0) << path_; VLOG(1) << path_;
if (path_ != nullptr) { if (path_ != nullptr) {
std::ofstream ofs(path_, std::ios::binary | std::ios::out); std::ofstream ofs(path_, std::ios::binary | std::ios::out);
if (ofs.is_open()) { if (ofs.is_open()) {
...@@ -78,14 +78,14 @@ class Tuner { ...@@ -78,14 +78,14 @@ class Tuner {
int32_t key_size = kp.first.size(); int32_t key_size = kp.first.size();
ofs.write(reinterpret_cast<char *>(&key_size), sizeof(key_size)); ofs.write(reinterpret_cast<char *>(&key_size), sizeof(key_size));
ofs.write(kp.first.c_str(), key_size); ofs.write(kp.first.c_str(), key_size);
VLOG(0) << kp.first.c_str(); VLOG(1) << kp.first.c_str();
auto &params = kp.second; auto &params = kp.second;
int32_t params_size = params.size() * sizeof(param_type); int32_t params_size = params.size() * sizeof(param_type);
ofs.write(reinterpret_cast<char*>(&params_size), sizeof(params_size)); ofs.write(reinterpret_cast<char*>(&params_size), sizeof(params_size));
for (auto &param : params) { for (auto &param : params) {
ofs.write(reinterpret_cast<char *>(&param), sizeof(params_size)); ofs.write(reinterpret_cast<char *>(&param), sizeof(params_size));
VLOG(0) << param; VLOG(1) << param;
} }
} }
ofs.close(); ofs.close();
...@@ -144,7 +144,7 @@ class Tuner { ...@@ -144,7 +144,7 @@ class Tuner {
} }
template <typename RetType> template <typename RetType>
inline RetType Tune(std::function<std::vector<std::vector<param_type>>()> param_generator, inline RetType Tune(const std::function<std::vector<std::vector<param_type>>()> &param_generator,
const std::function<RetType(const std::vector<param_type> &)> &func, const std::function<RetType(const std::vector<param_type> &)> &func,
std::vector<param_type> &opt_params) { std::vector<param_type> &opt_params) {
RetType res; RetType res;
......
...@@ -13,7 +13,8 @@ class TunerTest: public ::testing::Test { ...@@ -13,7 +13,8 @@ class TunerTest: public ::testing::Test {
protected: protected:
virtual void SetUp() { virtual void SetUp() {
remove( "/data/local/tmp/mace.config" ); remove( "/data/local/tmp/mace.config" );
setenv("MACE_RUN_PARAMTER_PATH", "/data/local/tmp/mace.config", 1); setenv("MACE_RUN_PARAMETER_PATH", "/data/local/tmp/mace.config", 1);
setenv("MACE_TUNING", "1", 1);
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册