提交 6ed08429 编写于 作者: Y yejianwu

Update variable names and local work-group size in batch norm

上级 ffdae79f
...@@ -28,10 +28,9 @@ struct BatchNormFunctor { ...@@ -28,10 +28,9 @@ struct BatchNormFunctor {
// new_scale = \frac{ \scale } { \sqrt{var+\variance_epsilon} } // new_scale = \frac{ \scale } { \sqrt{var+\variance_epsilon} }
// new_offset = \offset - mean * common_val; // new_offset = \offset - mean * common_val;
// Y = new_scale * X + new_offset; // Y = new_scale * X + new_offset;
const index_t batchs = input->dim(0); const index_t batch = input->dim(0);
const index_t height = input->dim(1); const index_t height = input->dim(1);
const index_t width = input->dim(2); const index_t width = input->dim(2);
const index_t height_width = height * width;
const index_t channels = input->dim(3); const index_t channels = input->dim(3);
Tensor::MappingGuard input_mapper(input); Tensor::MappingGuard input_mapper(input);
...@@ -62,11 +61,13 @@ struct BatchNormFunctor { ...@@ -62,11 +61,13 @@ struct BatchNormFunctor {
index_t pos = 0; index_t pos = 0;
#pragma omp parallel for #pragma omp parallel for
for (index_t n = 0; n < batchs; ++n) { for (index_t n = 0; n < batch; ++n) {
for (index_t hb = 0; hb < height_width; ++hb) { for (index_t h = 0; h < height; ++h) {
for (index_t c = 0; c < channels; ++c) { for (index_t w = 0; w < width; ++w) {
output_ptr[pos] = new_scale[c] * input_ptr[pos] + new_offset[c]; for (index_t c = 0; c < channels; ++c) {
++pos; output_ptr[pos] = new_scale[c] * input_ptr[pos] + new_offset[c];
++pos;
}
} }
} }
} }
......
...@@ -21,7 +21,7 @@ void BatchNormFunctor<DeviceType::OPENCL, T>::operator()( ...@@ -21,7 +21,7 @@ void BatchNormFunctor<DeviceType::OPENCL, T>::operator()(
const Tensor *epsilon, const Tensor *epsilon,
Tensor *output) { Tensor *output) {
const index_t batchs = input->dim(0); const index_t batch = input->dim(0);
const index_t height = input->dim(1); const index_t height = input->dim(1);
const index_t width = input->dim(2); const index_t width = input->dim(2);
const index_t channels = input->dim(3); const index_t channels = input->dim(3);
...@@ -30,7 +30,7 @@ void BatchNormFunctor<DeviceType::OPENCL, T>::operator()( ...@@ -30,7 +30,7 @@ void BatchNormFunctor<DeviceType::OPENCL, T>::operator()(
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks), const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
static_cast<uint32_t>(width), static_cast<uint32_t>(width),
static_cast<uint32_t>(height * batchs)}; static_cast<uint32_t>(height * batch)};
auto runtime = OpenCLRuntime::Get(); auto runtime = OpenCLRuntime::Get();
std::set<std::string> built_options; std::set<std::string> built_options;
...@@ -40,7 +40,7 @@ void BatchNormFunctor<DeviceType::OPENCL, T>::operator()( ...@@ -40,7 +40,7 @@ void BatchNormFunctor<DeviceType::OPENCL, T>::operator()(
auto bm_kernel = runtime->BuildKernel("batch_norm", "batch_norm", built_options); auto bm_kernel = runtime->BuildKernel("batch_norm", "batch_norm", built_options);
const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(bm_kernel); const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(bm_kernel);
const std::vector<uint32_t> lws = {1, 1, kwg_size}; const std::vector<uint32_t> lws = {1, kwg_size, 1};
uint32_t idx = 0; uint32_t idx = 0;
bm_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer()))); bm_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
...@@ -52,7 +52,8 @@ void BatchNormFunctor<DeviceType::OPENCL, T>::operator()( ...@@ -52,7 +52,8 @@ void BatchNormFunctor<DeviceType::OPENCL, T>::operator()(
bm_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer()))); bm_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer())));
auto params_generator = [&kwg_size]()->std::vector<std::vector<uint32_t>> { auto params_generator = [&kwg_size]()->std::vector<std::vector<uint32_t>> {
return {{1, 1, 64}, return {{8, 128, 1}, //SNPE size
{1, 1, 64},
{1, 1, 128}, {1, 1, 128},
{1, kwg_size/16, 16}, {1, kwg_size/16, 16},
{1, kwg_size/32, 32}, {1, kwg_size/32, 32},
......
...@@ -8,24 +8,21 @@ __kernel void batch_norm(__read_only image2d_t input, ...@@ -8,24 +8,21 @@ __kernel void batch_norm(__read_only image2d_t input,
__global const DATA_TYPE *epsilon, __global const DATA_TYPE *epsilon,
__write_only image2d_t output) { __write_only image2d_t output) {
const int ch_blk = get_global_id(0); const int ch_blk = get_global_id(0);
const int w_blk = get_global_id(1); const int w = get_global_id(1);
const int hb_blk = get_global_id(2); const int hb = get_global_id(2);
const int width = get_global_size(1); const int width = get_global_size(1);
const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; DATA_TYPE4 scale_value = READ_IMAGET(scale, SAMPLER, (int2)(ch_blk, 0));
DATA_TYPE4 offset_value = READ_IMAGET(offset, SAMPLER, (int2)(ch_blk, 0));
DATA_TYPE4 mean_value = READ_IMAGET(mean, SAMPLER, (int2)(ch_blk, 0));
DATA_TYPE4 scale_value = READ_IMAGET(scale, sampler, (int2)(ch_blk, 0)); DATA_TYPE4 var_value = READ_IMAGET(var, SAMPLER, (int2)(ch_blk, 0));
DATA_TYPE4 offset_value = READ_IMAGET(offset, sampler, (int2)(ch_blk, 0));
DATA_TYPE4 mean_value = READ_IMAGET(mean, sampler, (int2)(ch_blk, 0));
DATA_TYPE4 var_value = READ_IMAGET(var, sampler, (int2)(ch_blk, 0));
DATA_TYPE4 new_scale = scale_value * rsqrt(var_value + (DATA_TYPE4)(*epsilon)); DATA_TYPE4 new_scale = scale_value * rsqrt(var_value + (DATA_TYPE4)(*epsilon));
DATA_TYPE4 new_offset = offset_value - mean_value * new_scale; DATA_TYPE4 new_offset = offset_value - mean_value * new_scale;
const int pos = ch_blk * width + w_blk; const int pos = ch_blk * width + w;
DATA_TYPE4 in = READ_IMAGET(input, sampler, (int2)(pos, hb_blk)); DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(pos, hb));
DATA_TYPE4 out = in * new_scale + new_offset; DATA_TYPE4 out = in * new_scale + new_offset;
WRITE_IMAGET(output, (int2)(pos, hb_blk), out); WRITE_IMAGET(output, (int2)(pos, hb), out);
} }
...@@ -5,8 +5,6 @@ ...@@ -5,8 +5,6 @@
#include "mace/core/operator.h" #include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h" #include "mace/ops/ops_test_util.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
namespace mace { namespace mace {
class BatchNormOpTest : public OpsTestBase {}; class BatchNormOpTest : public OpsTestBase {};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册