Commit aeb4c35e, authored by liuqi

Use vector operations to optimize the batch_norm OpenCL kernel.

Parent commit: 5e103649
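For reference, the computation being vectorized is the standard batch-norm inference transform, which folds per channel into a single multiply-add: new_scale = scale * rsqrt(var + epsilon), new_offset = offset - mean * new_scale, output = new_scale * input + new_offset. Below is a minimal scalar host-side sketch of that reference computation; it is not part of the commit, and the function name and the NCHW layout with pixels = H * W are assumptions inferred from the kernel further down.

#include <cmath>
#include <cstddef>

// Scalar reference for batch-norm inference on NCHW data (illustrative only).
// The OpenCL kernel in this commit computes the same per-channel
// new_scale/new_offset and applies them to 4 pixels at a time with float4
// loads/stores; 1/sqrt(x) here corresponds to rsqrt(x) in the kernel.
void BatchNormReference(const float *input, const float *scale,
                        const float *offset, const float *mean,
                        const float *var, float epsilon,
                        size_t batch, size_t channels, size_t pixels,
                        float *output) {
  for (size_t b = 0; b < batch; ++b) {
    for (size_t c = 0; c < channels; ++c) {
      const float new_scale = scale[c] / std::sqrt(var[c] + epsilon);
      const float new_offset = offset[c] - mean[c] * new_scale;
      const float *in = input + (b * channels + c) * pixels;
      float *out = output + (b * channels + c) * pixels;
      for (size_t p = 0; p < pixels; ++p) {
        out[p] = new_scale * in[p] + new_offset;
      }
    }
  }
}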
@@ -20,9 +20,12 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
     const Tensor *epsilon,
     Tensor *output) {
+  index_t pixel_size = input->dim(2) * input->dim(3);
+  index_t blocks = (pixel_size + 3) / 4;
   const uint32_t gws[3] = {static_cast<uint32_t>(input->dim(0)),
                            static_cast<uint32_t>(input->dim(1)),
-                           static_cast<uint32_t>(input->dim(2) * input->dim(3))};
+                           static_cast<uint32_t>(blocks)};
   auto runtime = OpenCLRuntime::Get();
@@ -39,10 +42,10 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
   bm_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(mean->buffer())));
   bm_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(var->buffer())));
   bm_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(epsilon->buffer())));
-  bm_kernel.setArg(idx++, gws[2]);
+  bm_kernel.setArg(idx++, static_cast<uint32_t>(pixel_size));
   bm_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(output->buffer())));
-  bm_kernel.setArg(idx++, lws[1] * sizeof(float), nullptr);
-  bm_kernel.setArg(idx++, lws[1] * sizeof(float), nullptr);
+  bm_kernel.setArg(idx++, lws[1] * sizeof(float) * 4, nullptr);
+  bm_kernel.setArg(idx++, lws[1] * sizeof(float) * 4, nullptr);
   auto params_generator = [&kwg_size]()->std::vector<std::vector<uint32_t>> {
     return {{1, 1, 64},

@@ -6,8 +6,8 @@ void kernel batch_norm(global const float *input,
                        global const float *epsilon,
                        private const uint pixels,
                        global float *output,
-                       __local float *new_scale,
-                       __local float *new_offset) {
+                       __local float4 *new_scale,
+                       __local float4 *new_offset) {
   const int batch = get_global_id(0);
   const int channel = get_global_id(1);
   const int channels = get_global_size(1);
@@ -16,15 +16,26 @@ void kernel batch_norm(global const float *input,
   const int local_pixel_idx = get_local_id(2);
   if(local_pixel_idx == 0) {
-    new_scale[local_channel] = scale[channel] * rsqrt(var[channel] + *epsilon);
-    new_offset[local_channel] = offset[channel] - mean[channel] * new_scale[local_channel];
+    new_scale[local_channel] = (float4)(scale[channel] * rsqrt(var[channel] + *epsilon));
+    new_offset[local_channel] = (float4)(offset[channel] - mean[channel] * new_scale[local_channel].x);
   }
   barrier(CLK_LOCAL_MEM_FENCE);
-  const int sample_offset = (batch * channels + channel) * pixels + pixel_offset;
+  const int sample_offset = (batch * channels + channel) * pixels + pixel_offset*4;
   const float *input_ptr = input + sample_offset;
   float *output_ptr = output + sample_offset;
-  *output_ptr = new_scale[local_channel] * *input_ptr + new_offset[local_channel];
+  const int end = (batch * channels + channel + 1) * pixels;
+  if ((sample_offset+4) > end) {
+    for (int i = sample_offset; i < end; ++i) {
+      *output_ptr = new_scale[local_channel].x * *input_ptr + new_offset[local_channel].x;
+      ++input_ptr;
+      ++output_ptr;
+    }
+  } else {
+    float4 values = vload4(0, input_ptr);
+    values = values * new_scale[local_channel] + new_offset[local_channel];
+    vstore4(values, 0, output_ptr);
+  }
 }
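The scalar branch above only fires for the last 4-pixel block of a channel when pixel_size is not a multiple of 4; all other work items take the float4 path. A small host-side sketch of the same blocking pattern (illustrative only, not part of the commit; the function name is made up):

#include <algorithm>
#include <cstddef>

// Apply y = new_scale * x + new_offset over `pixels` elements in blocks of 4,
// mirroring the kernel's launch: blocks = (pixel_size + 3) / 4 work items per
// channel, with only the final short block processed element by element.
void ScaleOffsetBlocked(const float *input, float new_scale, float new_offset,
                        size_t pixels, float *output) {
  const size_t blocks = (pixels + 3) / 4;
  for (size_t b = 0; b < blocks; ++b) {
    const size_t start = b * 4;
    const size_t end = std::min(start + 4, pixels);  // scalar tail for the last block
    for (size_t i = start; i < end; ++i) {
      output[i] = new_scale * input[i] + new_offset;
    }
  }
}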