提交 c93b928a 编写于 作者: L Liangliang He

Merge branch 'batch_norm_opt' into 'master'

Vectorization batch norm opencl kernel.

See merge request !83
...@@ -20,9 +20,12 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()( ...@@ -20,9 +20,12 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
const Tensor *epsilon, const Tensor *epsilon,
Tensor *output) { Tensor *output) {
index_t pixel_size = input->dim(2) * input->dim(3);
index_t blocks = (pixel_size + 3) / 4;
const uint32_t gws[3] = {static_cast<uint32_t>(input->dim(0)), const uint32_t gws[3] = {static_cast<uint32_t>(input->dim(0)),
static_cast<uint32_t>(input->dim(1)), static_cast<uint32_t>(input->dim(1)),
static_cast<uint32_t>(input->dim(2) * input->dim(3))}; static_cast<uint32_t>(blocks)};
auto runtime = OpenCLRuntime::Get(); auto runtime = OpenCLRuntime::Get();
...@@ -39,10 +42,10 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()( ...@@ -39,10 +42,10 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
bm_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(mean->buffer()))); bm_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(mean->buffer())));
bm_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(var->buffer()))); bm_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(var->buffer())));
bm_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(epsilon->buffer()))); bm_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(epsilon->buffer())));
bm_kernel.setArg(idx++, gws[2]); bm_kernel.setArg(idx++, static_cast<uint32_t>(pixel_size));
bm_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(output->buffer()))); bm_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(output->buffer())));
bm_kernel.setArg(idx++, lws[1] * sizeof(float), nullptr); bm_kernel.setArg(idx++, lws[1] * sizeof(float) * 4, nullptr);
bm_kernel.setArg(idx++, lws[1] * sizeof(float), nullptr); bm_kernel.setArg(idx++, lws[1] * sizeof(float) * 4, nullptr);
auto params_generator = [&kwg_size]()->std::vector<std::vector<uint32_t>> { auto params_generator = [&kwg_size]()->std::vector<std::vector<uint32_t>> {
return {{1, 1, 64}, return {{1, 1, 64},
......
...@@ -6,8 +6,8 @@ void kernel batch_norm(global const float *input, ...@@ -6,8 +6,8 @@ void kernel batch_norm(global const float *input,
global const float *epsilon, global const float *epsilon,
private const uint pixels, private const uint pixels,
global float *output, global float *output,
__local float *new_scale, __local float4 *new_scale,
__local float *new_offset) { __local float4 *new_offset) {
const int batch = get_global_id(0); const int batch = get_global_id(0);
const int channel = get_global_id(1); const int channel = get_global_id(1);
const int channels = get_global_size(1); const int channels = get_global_size(1);
...@@ -16,15 +16,26 @@ void kernel batch_norm(global const float *input, ...@@ -16,15 +16,26 @@ void kernel batch_norm(global const float *input,
const int local_pixel_idx = get_local_id(2); const int local_pixel_idx = get_local_id(2);
if(local_pixel_idx == 0) { if(local_pixel_idx == 0) {
new_scale[local_channel] = scale[channel] * rsqrt(var[channel] + *epsilon); new_scale[local_channel] = (float4)(scale[channel] * rsqrt(var[channel] + *epsilon));
new_offset[local_channel] = offset[channel] - mean[channel] * new_scale[local_channel]; new_offset[local_channel] = (float4)(offset[channel] - mean[channel] * new_scale[local_channel].x);
} }
barrier(CLK_LOCAL_MEM_FENCE); barrier(CLK_LOCAL_MEM_FENCE);
const int sample_offset = (batch * channels + channel) * pixels + pixel_offset; const int image_offset = (batch * channels + channel) * pixels + pixel_offset*4;
const float *input_ptr = input + sample_offset; const float *input_ptr = input + image_offset;
float *output_ptr = output + sample_offset; float *output_ptr = output + image_offset;
*output_ptr = new_scale[local_channel] * *input_ptr + new_offset[local_channel]; const int end = (batch * channels + channel + 1) * pixels;
if ((image_offset+4) > end) {
for (int i = image_offset; i < end; ++i) {
*output_ptr = new_scale[local_channel].x * *input_ptr + new_offset[local_channel].x;
++input_ptr;
++output_ptr;
}
} else {
float4 values = vload4(0, input_ptr);
values = values * new_scale[local_channel] + new_offset[local_channel];
vstore4(values, 0, output_ptr);
}
} }
...@@ -39,6 +39,7 @@ cc_library( ...@@ -39,6 +39,7 @@ cc_library(
copts = ["-std=c++11"], copts = ["-std=c++11"],
deps = [ deps = [
"//mace/core", "//mace/core",
"//mace/core:opencl_runtime",
], ],
) )
......
...@@ -33,7 +33,7 @@ class Tuner { ...@@ -33,7 +33,7 @@ class Tuner {
const std::function<std::vector<std::vector<param_type>>()> &param_generator, const std::function<std::vector<std::vector<param_type>>()> &param_generator,
const std::function<RetType(const std::vector<param_type> &)> &func) { const std::function<RetType(const std::vector<param_type> &)> &func) {
if (IsTuning()) { if (IsTuning() && param_generator != nullptr) {
// tune // tune
std::vector<param_type> opt_param = default_param; std::vector<param_type> opt_param = default_param;
RetType res = Tune<RetType>(param_generator, func, opt_param); RetType res = Tune<RetType>(param_generator, func, opt_param);
...@@ -68,7 +68,7 @@ class Tuner { ...@@ -68,7 +68,7 @@ class Tuner {
} }
inline void WriteRunParameters() { inline void WriteRunParameters() {
VLOG(0) << path_; VLOG(1) << path_;
if (path_ != nullptr) { if (path_ != nullptr) {
std::ofstream ofs(path_, std::ios::binary | std::ios::out); std::ofstream ofs(path_, std::ios::binary | std::ios::out);
if (ofs.is_open()) { if (ofs.is_open()) {
...@@ -78,14 +78,14 @@ class Tuner { ...@@ -78,14 +78,14 @@ class Tuner {
int32_t key_size = kp.first.size(); int32_t key_size = kp.first.size();
ofs.write(reinterpret_cast<char *>(&key_size), sizeof(key_size)); ofs.write(reinterpret_cast<char *>(&key_size), sizeof(key_size));
ofs.write(kp.first.c_str(), key_size); ofs.write(kp.first.c_str(), key_size);
VLOG(0) << kp.first.c_str(); VLOG(1) << kp.first.c_str();
auto &params = kp.second; auto &params = kp.second;
int32_t params_size = params.size() * sizeof(param_type); int32_t params_size = params.size() * sizeof(param_type);
ofs.write(reinterpret_cast<char*>(&params_size), sizeof(params_size)); ofs.write(reinterpret_cast<char*>(&params_size), sizeof(params_size));
for (auto &param : params) { for (auto &param : params) {
ofs.write(reinterpret_cast<char *>(&param), sizeof(params_size)); ofs.write(reinterpret_cast<char *>(&param), sizeof(params_size));
VLOG(0) << param; VLOG(1) << param;
} }
} }
ofs.close(); ofs.close();
...@@ -144,7 +144,7 @@ class Tuner { ...@@ -144,7 +144,7 @@ class Tuner {
} }
template <typename RetType> template <typename RetType>
inline RetType Tune(std::function<std::vector<std::vector<param_type>>()> param_generator, inline RetType Tune(const std::function<std::vector<std::vector<param_type>>()> &param_generator,
const std::function<RetType(const std::vector<param_type> &)> &func, const std::function<RetType(const std::vector<param_type> &)> &func,
std::vector<param_type> &opt_params) { std::vector<param_type> &opt_params) {
RetType res; RetType res;
......
...@@ -13,7 +13,8 @@ class TunerTest: public ::testing::Test { ...@@ -13,7 +13,8 @@ class TunerTest: public ::testing::Test {
protected: protected:
virtual void SetUp() { virtual void SetUp() {
remove( "/data/local/tmp/mace.config" ); remove( "/data/local/tmp/mace.config" );
setenv("MACE_RUN_PARAMTER_PATH", "/data/local/tmp/mace.config", 1); setenv("MACE_RUN_PARAMETER_PATH", "/data/local/tmp/mace.config", 1);
setenv("MACE_TUNING", "1", 1);
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册