提交 e8824833 编写于 作者: Y yejianwu

Update CPU batch norm to improve memory locality; modify op to use a templated dtype

上级 99963c98
...@@ -28,8 +28,11 @@ struct BatchNormFunctor { ...@@ -28,8 +28,11 @@ struct BatchNormFunctor {
// new_scale = \frac{ \scale } { \sqrt{var+\variance_epsilon} } // new_scale = \frac{ \scale } { \sqrt{var+\variance_epsilon} }
// new_offset = \offset - mean * common_val; // new_offset = \offset - mean * common_val;
// Y = new_scale * X + new_offset; // Y = new_scale * X + new_offset;
const index_t ch_pixel_size = input->dim(0) * input->dim(1) * input->dim(2); const index_t batchs = input->dim(0);
const index_t channel = input->dim(3); const index_t height = input->dim(1);
const index_t width = input->dim(2);
const index_t height_width = height * width;
const index_t channels = input->dim(3);
Tensor::MappingGuard input_mapper(input); Tensor::MappingGuard input_mapper(input);
Tensor::MappingGuard scale_mapper(scale); Tensor::MappingGuard scale_mapper(scale);
...@@ -47,15 +50,24 @@ struct BatchNormFunctor { ...@@ -47,15 +50,24 @@ struct BatchNormFunctor {
const T *epsilon_ptr = epsilon->data<T>(); const T *epsilon_ptr = epsilon->data<T>();
T *output_ptr = output->mutable_data<T>(); T *output_ptr = output->mutable_data<T>();
vector<T> new_scale(channels);
vector<T> new_offset(channels);
#pragma omp parallel for #pragma omp parallel for
for (index_t c = 0; c < channel; ++c) { for (index_t c = 0; c < channels; ++c) {
T new_scale = scale_ptr[c] / std::sqrt(var_ptr[c] + *epsilon_ptr); new_scale[c] = scale_ptr[c] / std::sqrt(var_ptr[c] + *epsilon_ptr);
T new_offset = offset_ptr[c] - mean_ptr[c] * new_scale; new_offset[c] = offset_ptr[c] - mean_ptr[c] * new_scale[c];
index_t pos = c; }
index_t pos = 0;
for (index_t i = 0; i < ch_pixel_size; ++i) { #pragma omp parallel for
output_ptr[pos] = new_scale * input_ptr[pos] + new_offset; for (index_t n = 0; n < batchs; ++n) {
pos += channel; for (index_t hb = 0; hb < height_width; ++hb) {
for (index_t c = 0; c < channels; ++c) {
output_ptr[pos] = new_scale[c] * input_ptr[pos] + new_offset[c];
++pos;
}
} }
} }
} }
...@@ -71,15 +83,16 @@ void BatchNormFunctor<DeviceType::NEON, float>::operator()( ...@@ -71,15 +83,16 @@ void BatchNormFunctor<DeviceType::NEON, float>::operator()(
const Tensor *epsilon, const Tensor *epsilon,
Tensor *output); Tensor *output);
template <> template <typename T>
void BatchNormFunctor<DeviceType::OPENCL, float>::operator()( struct BatchNormFunctor<DeviceType::OPENCL, T> {
const Tensor *input, void operator()(const Tensor *input,
const Tensor *scale, const Tensor *scale,
const Tensor *offset, const Tensor *offset,
const Tensor *mean, const Tensor *mean,
const Tensor *var, const Tensor *var,
const Tensor *epsilon, const Tensor *epsilon,
Tensor *output); Tensor *output);
};
} // namespace kernels } // namespace kernels
} // namespace mace } // namespace mace
......
...@@ -11,8 +11,8 @@ ...@@ -11,8 +11,8 @@
namespace mace { namespace mace {
namespace kernels { namespace kernels {
template <> template <typename T>
void BatchNormFunctor<DeviceType::OPENCL, float>::operator()( void BatchNormFunctor<DeviceType::OPENCL, T>::operator()(
const Tensor *input, const Tensor *input,
const Tensor *scale, const Tensor *scale,
const Tensor *offset, const Tensor *offset,
...@@ -27,7 +27,6 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()( ...@@ -27,7 +27,6 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
const index_t channels = input->dim(3); const index_t channels = input->dim(3);
const index_t channel_blocks = RoundUpDiv4(channels); const index_t channel_blocks = RoundUpDiv4(channels);
const index_t width_blocks = RoundUpDiv4(width);
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks), const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
static_cast<uint32_t>(width), static_cast<uint32_t>(width),
...@@ -35,8 +34,9 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()( ...@@ -35,8 +34,9 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
auto runtime = OpenCLRuntime::Get(); auto runtime = OpenCLRuntime::Get();
std::set<std::string> built_options; std::set<std::string> built_options;
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(input->dtype())); auto dt = DataTypeToEnum<T>::value;
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(input->dtype())); built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
auto bm_kernel = runtime->BuildKernel("batch_norm", "batch_norm", built_options); auto bm_kernel = runtime->BuildKernel("batch_norm", "batch_norm", built_options);
const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(bm_kernel); const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(bm_kernel);
...@@ -83,5 +83,9 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()( ...@@ -83,5 +83,9 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
func); func);
} }
template
struct BatchNormFunctor<DeviceType::OPENCL, float>;
template
struct BatchNormFunctor<DeviceType::OPENCL, half>;
} // namespace kernels } // namespace kernels
} // namespace mace } // namespace mace
#include <common.h> #include <common.h>
// Supported data types: half/float // Supported data types: half/float
__kernel void batch_norm(__read_only image2d_t input, __kernel void batch_norm(__read_only image2d_t input,
__read_only image2d_t scale, __read_only image2d_t scale,
__read_only image2d_t offset, __read_only image2d_t offset,
__read_only image2d_t mean, __read_only image2d_t mean,
__read_only image2d_t var, __read_only image2d_t var,
global const DATA_TYPE *epsilon, __global const DATA_TYPE *epsilon,
__write_only image2d_t output) { __write_only image2d_t output) {
const int ch_blk = get_global_id(0); const int ch_blk = get_global_id(0);
const int w_blk = get_global_id(1); const int w_blk = get_global_id(1);
const int hb_blk = get_global_id(2); const int hb_blk = get_global_id(2);
......
...@@ -23,4 +23,9 @@ REGISTER_OPENCL_OPERATOR(OpKeyBuilder("BatchNorm") ...@@ -23,4 +23,9 @@ REGISTER_OPENCL_OPERATOR(OpKeyBuilder("BatchNorm")
.Build(), .Build(),
BatchNormOp<DeviceType::OPENCL, float>); BatchNormOp<DeviceType::OPENCL, float>);
} // namespace mace REGISTER_OPENCL_OPERATOR(OpKeyBuilder("BatchNorm")
\ No newline at end of file .TypeConstraint<half>("T")
.Build(),
BatchNormOp<DeviceType::OPENCL, half>);
} // namespace mace
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册