提交 6ed08429 编写于 作者: Y yejianwu

update variable name and local group size in batch norm

上级 ffdae79f
......@@ -28,10 +28,9 @@ struct BatchNormFunctor {
// new_scale = \frac{ \scale } { \sqrt{var+\variance_epsilon} }
// new_offset = \offset - mean * common_val;
// Y = new_scale * X + new_offset;
const index_t batchs = input->dim(0);
const index_t batch = input->dim(0);
const index_t height = input->dim(1);
const index_t width = input->dim(2);
const index_t height_width = height * width;
const index_t channels = input->dim(3);
Tensor::MappingGuard input_mapper(input);
......@@ -62,11 +61,13 @@ struct BatchNormFunctor {
index_t pos = 0;
#pragma omp parallel for
for (index_t n = 0; n < batchs; ++n) {
for (index_t hb = 0; hb < height_width; ++hb) {
for (index_t c = 0; c < channels; ++c) {
output_ptr[pos] = new_scale[c] * input_ptr[pos] + new_offset[c];
++pos;
for (index_t n = 0; n < batch; ++n) {
for (index_t h = 0; h < height; ++h) {
for (index_t w = 0; w < width; ++w) {
for (index_t c = 0; c < channels; ++c) {
output_ptr[pos] = new_scale[c] * input_ptr[pos] + new_offset[c];
++pos;
}
}
}
}
......
......@@ -21,7 +21,7 @@ void BatchNormFunctor<DeviceType::OPENCL, T>::operator()(
const Tensor *epsilon,
Tensor *output) {
const index_t batchs = input->dim(0);
const index_t batch = input->dim(0);
const index_t height = input->dim(1);
const index_t width = input->dim(2);
const index_t channels = input->dim(3);
......@@ -30,7 +30,7 @@ void BatchNormFunctor<DeviceType::OPENCL, T>::operator()(
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
static_cast<uint32_t>(width),
static_cast<uint32_t>(height * batchs)};
static_cast<uint32_t>(height * batch)};
auto runtime = OpenCLRuntime::Get();
std::set<std::string> built_options;
......@@ -40,7 +40,7 @@ void BatchNormFunctor<DeviceType::OPENCL, T>::operator()(
auto bm_kernel = runtime->BuildKernel("batch_norm", "batch_norm", built_options);
const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(bm_kernel);
const std::vector<uint32_t> lws = {1, 1, kwg_size};
const std::vector<uint32_t> lws = {1, kwg_size, 1};
uint32_t idx = 0;
bm_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
......@@ -52,7 +52,8 @@ void BatchNormFunctor<DeviceType::OPENCL, T>::operator()(
bm_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer())));
auto params_generator = [&kwg_size]()->std::vector<std::vector<uint32_t>> {
return {{1, 1, 64},
return {{8, 128, 1}, //SNPE size
{1, 1, 64},
{1, 1, 128},
{1, kwg_size/16, 16},
{1, kwg_size/32, 32},
......
......@@ -8,24 +8,21 @@ __kernel void batch_norm(__read_only image2d_t input,
__global const DATA_TYPE *epsilon,
__write_only image2d_t output) {
const int ch_blk = get_global_id(0);
const int w_blk = get_global_id(1);
const int hb_blk = get_global_id(2);
const int w = get_global_id(1);
const int hb = get_global_id(2);
const int width = get_global_size(1);
const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
DATA_TYPE4 scale_value = READ_IMAGET(scale, sampler, (int2)(ch_blk, 0));
DATA_TYPE4 offset_value = READ_IMAGET(offset, sampler, (int2)(ch_blk, 0));
DATA_TYPE4 mean_value = READ_IMAGET(mean, sampler, (int2)(ch_blk, 0));
DATA_TYPE4 var_value = READ_IMAGET(var, sampler, (int2)(ch_blk, 0));
DATA_TYPE4 scale_value = READ_IMAGET(scale, SAMPLER, (int2)(ch_blk, 0));
DATA_TYPE4 offset_value = READ_IMAGET(offset, SAMPLER, (int2)(ch_blk, 0));
DATA_TYPE4 mean_value = READ_IMAGET(mean, SAMPLER, (int2)(ch_blk, 0));
DATA_TYPE4 var_value = READ_IMAGET(var, SAMPLER, (int2)(ch_blk, 0));
DATA_TYPE4 new_scale = scale_value * rsqrt(var_value + (DATA_TYPE4)(*epsilon));
DATA_TYPE4 new_offset = offset_value - mean_value * new_scale;
const int pos = ch_blk * width + w_blk;
const int pos = ch_blk * width + w;
DATA_TYPE4 in = READ_IMAGET(input, sampler, (int2)(pos, hb_blk));
DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(pos, hb));
DATA_TYPE4 out = in * new_scale + new_offset;
WRITE_IMAGET(output, (int2)(pos, hb_blk), out);
WRITE_IMAGET(output, (int2)(pos, hb), out);
}
......@@ -5,8 +5,6 @@
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
namespace mace {
class BatchNormOpTest : public OpsTestBase {};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册