From e882483363642c518d7fb3c7dab1c9921095d7c6 Mon Sep 17 00:00:00 2001 From: yejianwu Date: Mon, 4 Dec 2017 12:46:09 +0800 Subject: [PATCH] update cpu batch norm to adapt locality, modify op to use template dtype --- mace/kernels/batch_norm.h | 49 +++++++++++++++--------- mace/kernels/opencl/batch_norm_opencl.cc | 14 ++++--- mace/kernels/opencl/cl/batch_norm.cl | 12 +++--- mace/ops/batch_norm.cc | 7 +++- 4 files changed, 52 insertions(+), 30 deletions(-) diff --git a/mace/kernels/batch_norm.h b/mace/kernels/batch_norm.h index 1340f26a..5f00d747 100644 --- a/mace/kernels/batch_norm.h +++ b/mace/kernels/batch_norm.h @@ -28,8 +28,11 @@ struct BatchNormFunctor { // new_scale = \frac{ \scale } { \sqrt{var+\variance_epsilon} } // new_offset = \offset - mean * common_val; // Y = new_scale * X + new_offset; - const index_t ch_pixel_size = input->dim(0) * input->dim(1) * input->dim(2); - const index_t channel = input->dim(3); + const index_t batchs = input->dim(0); + const index_t height = input->dim(1); + const index_t width = input->dim(2); + const index_t height_width = height * width; + const index_t channels = input->dim(3); Tensor::MappingGuard input_mapper(input); Tensor::MappingGuard scale_mapper(scale); @@ -47,15 +50,24 @@ struct BatchNormFunctor { const T *epsilon_ptr = epsilon->data(); T *output_ptr = output->mutable_data(); + vector new_scale(channels); + vector new_offset(channels); + #pragma omp parallel for - for (index_t c = 0; c < channel; ++c) { - T new_scale = scale_ptr[c] / std::sqrt(var_ptr[c] + *epsilon_ptr); - T new_offset = offset_ptr[c] - mean_ptr[c] * new_scale; - index_t pos = c; + for (index_t c = 0; c < channels; ++c) { + new_scale[c] = scale_ptr[c] / std::sqrt(var_ptr[c] + *epsilon_ptr); + new_offset[c] = offset_ptr[c] - mean_ptr[c] * new_scale[c]; + } + + index_t pos = 0; - for (index_t i = 0; i < ch_pixel_size; ++i) { - output_ptr[pos] = new_scale * input_ptr[pos] + new_offset; - pos += channel; +#pragma omp parallel for + for (index_t n = 0; n < batchs; ++n) { + for (index_t hb = 0; hb < height_width; ++hb) { + for (index_t c = 0; c < channels; ++c) { + output_ptr[pos] = new_scale[c] * input_ptr[pos] + new_offset[c]; + ++pos; + } } } } @@ -71,15 +83,16 @@ void BatchNormFunctor::operator()( const Tensor *epsilon, Tensor *output); -template <> -void BatchNormFunctor::operator()( - const Tensor *input, - const Tensor *scale, - const Tensor *offset, - const Tensor *mean, - const Tensor *var, - const Tensor *epsilon, - Tensor *output); +template +struct BatchNormFunctor { + void operator()(const Tensor *input, + const Tensor *scale, + const Tensor *offset, + const Tensor *mean, + const Tensor *var, + const Tensor *epsilon, + Tensor *output); +}; } // namepsace kernels } // namespace mace diff --git a/mace/kernels/opencl/batch_norm_opencl.cc b/mace/kernels/opencl/batch_norm_opencl.cc index e9dc00b9..c810d929 100644 --- a/mace/kernels/opencl/batch_norm_opencl.cc +++ b/mace/kernels/opencl/batch_norm_opencl.cc @@ -11,8 +11,8 @@ namespace mace { namespace kernels { -template <> -void BatchNormFunctor::operator()( +template +void BatchNormFunctor::operator()( const Tensor *input, const Tensor *scale, const Tensor *offset, @@ -27,7 +27,6 @@ void BatchNormFunctor::operator()( const index_t channels = input->dim(3); const index_t channel_blocks = RoundUpDiv4(channels); - const index_t width_blocks = RoundUpDiv4(width); const uint32_t gws[3] = {static_cast(channel_blocks), static_cast(width), @@ -35,8 +34,9 @@ void BatchNormFunctor::operator()( auto runtime = OpenCLRuntime::Get(); std::set built_options; - built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(input->dtype())); - built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(input->dtype())); + auto dt = DataTypeToEnum::value; + built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt)); + built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt)); auto bm_kernel = runtime->BuildKernel("batch_norm", "batch_norm", built_options); const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(bm_kernel); @@ -83,5 +83,9 @@ void BatchNormFunctor::operator()( func); } +template +struct BatchNormFunctor; +template +struct BatchNormFunctor; } // namespace kernels } // namespace mace diff --git a/mace/kernels/opencl/cl/batch_norm.cl b/mace/kernels/opencl/cl/batch_norm.cl index 8294d6df..e3d5ae5e 100644 --- a/mace/kernels/opencl/cl/batch_norm.cl +++ b/mace/kernels/opencl/cl/batch_norm.cl @@ -1,12 +1,12 @@ #include // Supported data types: half/float __kernel void batch_norm(__read_only image2d_t input, - __read_only image2d_t scale, - __read_only image2d_t offset, - __read_only image2d_t mean, - __read_only image2d_t var, - global const DATA_TYPE *epsilon, - __write_only image2d_t output) { + __read_only image2d_t scale, + __read_only image2d_t offset, + __read_only image2d_t mean, + __read_only image2d_t var, + __global const DATA_TYPE *epsilon, + __write_only image2d_t output) { const int ch_blk = get_global_id(0); const int w_blk = get_global_id(1); const int hb_blk = get_global_id(2); diff --git a/mace/ops/batch_norm.cc b/mace/ops/batch_norm.cc index 34ba41a6..76723b2d 100644 --- a/mace/ops/batch_norm.cc +++ b/mace/ops/batch_norm.cc @@ -23,4 +23,9 @@ REGISTER_OPENCL_OPERATOR(OpKeyBuilder("BatchNorm") .Build(), BatchNormOp); -} // namespace mace \ No newline at end of file +REGISTER_OPENCL_OPERATOR(OpKeyBuilder("BatchNorm") + .TypeConstraint("T") + .Build(), + BatchNormOp); + +} // namespace mace -- GitLab