diff --git a/mace/kernels/batch_norm.h b/mace/kernels/batch_norm.h
index 1340f26a46229da91d581bf14a1cff5b4da28dec..5f00d747967695bfd379f331c03033eeaebd20f9 100644
--- a/mace/kernels/batch_norm.h
+++ b/mace/kernels/batch_norm.h
@@ -28,8 +28,11 @@ struct BatchNormFunctor {
     // new_scale = \frac{ \scale } { \sqrt{var+\variance_epsilon} }
    // new_offset = \offset - mean * common_val;
    // Y = new_scale * X + new_offset;
-    const index_t ch_pixel_size = input->dim(0) * input->dim(1) * input->dim(2);
-    const index_t channel = input->dim(3);
+    const index_t batchs = input->dim(0);
+    const index_t height = input->dim(1);
+    const index_t width = input->dim(2);
+    const index_t height_width = height * width;
+    const index_t channels = input->dim(3);
 
     Tensor::MappingGuard input_mapper(input);
     Tensor::MappingGuard scale_mapper(scale);
@@ -47,15 +50,24 @@ struct BatchNormFunctor {
     const T *epsilon_ptr = epsilon->data<T>();
     T *output_ptr = output->mutable_data<T>();
 
+    vector<T> new_scale(channels);
+    vector<T> new_offset(channels);
+
 #pragma omp parallel for
-    for (index_t c = 0; c < channel; ++c) {
-      T new_scale = scale_ptr[c] / std::sqrt(var_ptr[c] + *epsilon_ptr);
-      T new_offset = offset_ptr[c] - mean_ptr[c] * new_scale;
-      index_t pos = c;
+    for (index_t c = 0; c < channels; ++c) {
+      new_scale[c] = scale_ptr[c] / std::sqrt(var_ptr[c] + *epsilon_ptr);
+      new_offset[c] = offset_ptr[c] - mean_ptr[c] * new_scale[c];
+    }
+
+    index_t pos = 0;
 
-      for (index_t i = 0; i < ch_pixel_size; ++i) {
-        output_ptr[pos] = new_scale * input_ptr[pos] + new_offset;
-        pos += channel;
+#pragma omp parallel for
+    for (index_t n = 0; n < batchs; ++n) {
+      for (index_t hb = 0; hb < height_width; ++hb) {
+        for (index_t c = 0; c < channels; ++c) {
+          output_ptr[pos] = new_scale[c] * input_ptr[pos] + new_offset[c];
+          ++pos;
+        }
       }
     }
   }
@@ -71,15 +83,16 @@ void BatchNormFunctor<DeviceType::NEON, float>::operator()(
     const Tensor *epsilon,
     Tensor *output);
 
-template <>
-void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
-    const Tensor *input,
-    const Tensor *scale,
-    const Tensor *offset,
-    const Tensor *mean,
-    const Tensor *var,
-    const Tensor *epsilon,
-    Tensor *output);
+template <typename T>
+struct BatchNormFunctor<DeviceType::OPENCL, T> {
+  void operator()(const Tensor *input,
+                  const Tensor *scale,
+                  const Tensor *offset,
+                  const Tensor *mean,
+                  const Tensor *var,
+                  const Tensor *epsilon,
+                  Tensor *output);
+};
 
 }  // namepsace kernels
 }  // namespace mace
diff --git a/mace/kernels/opencl/batch_norm_opencl.cc b/mace/kernels/opencl/batch_norm_opencl.cc
index e9dc00b9d5d93fe4792ea15cc25e40791afff6e2..c810d9290259912660a35b522dac5661d2fa9c11 100644
--- a/mace/kernels/opencl/batch_norm_opencl.cc
+++ b/mace/kernels/opencl/batch_norm_opencl.cc
@@ -11,8 +11,8 @@
 namespace mace {
 namespace kernels {
 
-template <>
-void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
+template <typename T>
+void BatchNormFunctor<DeviceType::OPENCL, T>::operator()(
     const Tensor *input,
     const Tensor *scale,
     const Tensor *offset,
@@ -27,7 +27,6 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
   const index_t channels = input->dim(3);
 
   const index_t channel_blocks = RoundUpDiv4(channels);
-  const index_t width_blocks = RoundUpDiv4(width);
 
   const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
                            static_cast<uint32_t>(width),
@@ -35,8 +34,9 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
   auto runtime = OpenCLRuntime::Get();
 
   std::set<std::string> built_options;
-  built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(input->dtype()));
-  built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(input->dtype()));
+  auto dt = DataTypeToEnum<T>::value;
+  built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
+  built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
   auto bm_kernel = runtime->BuildKernel("batch_norm", "batch_norm", built_options);
 
   const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(bm_kernel);
@@ -83,5 +83,9 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
       func);
 }
 
+template
+struct BatchNormFunctor<DeviceType::OPENCL, float>;
+template
+struct BatchNormFunctor<DeviceType::OPENCL, half>;
 }  // namespace kernels
 }  // namespace mace
diff --git a/mace/kernels/opencl/cl/batch_norm.cl b/mace/kernels/opencl/cl/batch_norm.cl
index 8294d6dfae313688bbd2d5061e81fe4b1276b7db..e3d5ae5ec01c19679391c86d7e04015df6ac23b2 100644
--- a/mace/kernels/opencl/cl/batch_norm.cl
+++ b/mace/kernels/opencl/cl/batch_norm.cl
@@ -1,12 +1,12 @@
 #include <common.h>
 // Supported data types: half/float
 __kernel void batch_norm(__read_only image2d_t input,
-                     __read_only image2d_t scale,
-                     __read_only image2d_t offset,
-                     __read_only image2d_t mean,
-                     __read_only image2d_t var,
-                     global const DATA_TYPE *epsilon,
-                     __write_only image2d_t output) {
+                         __read_only image2d_t scale,
+                         __read_only image2d_t offset,
+                         __read_only image2d_t mean,
+                         __read_only image2d_t var,
+                         __global const DATA_TYPE *epsilon,
+                         __write_only image2d_t output) {
   const int ch_blk = get_global_id(0);
   const int w_blk = get_global_id(1);
   const int hb_blk = get_global_id(2);
diff --git a/mace/ops/batch_norm.cc b/mace/ops/batch_norm.cc
index 34ba41a6fbab4dff60e711efb852793b6509f6ee..76723b2dc2c369257b79fb66b8c472752253700d 100644
--- a/mace/ops/batch_norm.cc
+++ b/mace/ops/batch_norm.cc
@@ -23,4 +23,9 @@ REGISTER_OPENCL_OPERATOR(OpKeyBuilder("BatchNorm")
                             .Build(),
                         BatchNormOp<DeviceType::OPENCL, float>);
 
-}  // namespace mace
\ No newline at end of file
+REGISTER_OPENCL_OPERATOR(OpKeyBuilder("BatchNorm")
+                            .TypeConstraint<half>("T")
+                            .Build(),
+                        BatchNormOp<DeviceType::OPENCL, half>);
+
+}  // namespace mace
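
For reference, the following is a minimal standalone sketch of the NHWC batch-norm transform that the restructured CPU path above computes. It is not MACE code: the helper name batch_norm_nhwc and the use of plain std::vector<float> in place of mace::Tensor are illustrative assumptions. The per-channel scale and offset are folded once, then applied as a fused multiply-add over every N*H*W*C element.

// Standalone sketch (not MACE code): fold per-channel parameters once,
// then apply y = new_scale[c] * x + new_offset[c] over an NHWC buffer.
#include <cmath>
#include <cstdio>
#include <vector>

// Hypothetical helper: y = (x - mean) / sqrt(var + eps) * scale + offset,
// rewritten as y = new_scale * x + new_offset.
void batch_norm_nhwc(const std::vector<float> &input,
                     const std::vector<float> &scale,
                     const std::vector<float> &offset,
                     const std::vector<float> &mean,
                     const std::vector<float> &var,
                     float epsilon,
                     int batch, int height, int width, int channels,
                     std::vector<float> *output) {
  std::vector<float> new_scale(channels);
  std::vector<float> new_offset(channels);
  for (int c = 0; c < channels; ++c) {
    new_scale[c] = scale[c] / std::sqrt(var[c] + epsilon);
    new_offset[c] = offset[c] - mean[c] * new_scale[c];
  }
  size_t pos = 0;  // flat index into the NHWC buffer
  for (int n = 0; n < batch; ++n) {
    for (int hw = 0; hw < height * width; ++hw) {
      for (int c = 0; c < channels; ++c) {
        (*output)[pos] = new_scale[c] * input[pos] + new_offset[c];
        ++pos;
      }
    }
  }
}

int main() {
  // 1 x 1 x 2 x 2 NHWC tensor with identity-like parameters as a quick check.
  std::vector<float> x = {1.f, 2.f, 3.f, 4.f};
  std::vector<float> scale = {1.f, 1.f}, offset = {0.f, 0.f};
  std::vector<float> mean = {0.f, 0.f}, var = {1.f, 1.f};
  std::vector<float> y(x.size());
  batch_norm_nhwc(x, scale, offset, mean, var, 1e-5f, 1, 1, 2, 2, &y);
  for (float v : y) std::printf("%f\n", v);
  return 0;
}

The sketch keeps the flat pos counter serial; if the outer loop were parallelized (as with the #pragma omp parallel for in the patch), the index would need to be derived from n, hw, and c per iteration rather than shared across threads.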