Commit e8824833 authored by: Y yejianwu

update cpu batch norm to improve data locality, modify op to use a templated dtype

Parent 99963c98
@@ -28,8 +28,11 @@ struct BatchNormFunctor {
     // new_scale = \frac{ \scale } { \sqrt{var+\variance_epsilon} }
     // new_offset = \offset - mean * common_val;
     // Y = new_scale * X + new_offset;
-    const index_t ch_pixel_size = input->dim(0) * input->dim(1) * input->dim(2);
-    const index_t channel = input->dim(3);
+    const index_t batchs = input->dim(0);
+    const index_t height = input->dim(1);
+    const index_t width = input->dim(2);
+    const index_t height_width = height * width;
+    const index_t channels = input->dim(3);
     Tensor::MappingGuard input_mapper(input);
     Tensor::MappingGuard scale_mapper(scale);
@@ -47,15 +50,24 @@ struct BatchNormFunctor {
     const T *epsilon_ptr = epsilon->data<T>();
     T *output_ptr = output->mutable_data<T>();
+    vector<T> new_scale(channels);
+    vector<T> new_offset(channels);
 #pragma omp parallel for
-    for (index_t c = 0; c < channel; ++c) {
-      T new_scale = scale_ptr[c] / std::sqrt(var_ptr[c] + *epsilon_ptr);
-      T new_offset = offset_ptr[c] - mean_ptr[c] * new_scale;
-      index_t pos = c;
-      for (index_t i = 0; i < ch_pixel_size; ++i) {
-        output_ptr[pos] = new_scale * input_ptr[pos] + new_offset;
-        pos += channel;
+    for (index_t c = 0; c < channels; ++c) {
+      new_scale[c] = scale_ptr[c] / std::sqrt(var_ptr[c] + *epsilon_ptr);
+      new_offset[c] = offset_ptr[c] - mean_ptr[c] * new_scale[c];
+    }
+    index_t pos = 0;
+#pragma omp parallel for
+    for (index_t n = 0; n < batchs; ++n) {
+      for (index_t hb = 0; hb < height_width; ++hb) {
+        for (index_t c = 0; c < channels; ++c) {
+          output_ptr[pos] = new_scale[c] * input_ptr[pos] + new_offset[c];
+          ++pos;
+        }
       }
     }
   }
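The hunk above is the core of the locality change: the removed code walked each channel through the NHWC buffer with a stride of channels elements (pos += channel), so for tensors larger than the cache the same cache lines were re-fetched once per channel. The new code folds scale/offset into per-channel constants once, then streams through the buffer in memory order. A minimal, self-contained sketch of the new access pattern, using plain std::vector in place of mace::Tensor (the function and parameter names below are illustrative, not MACE's API):

#include <cmath>
#include <cstddef>
#include <vector>

// Sketch of the locality-friendly NHWC batch norm: fold scale/offset per
// channel once, then make a single contiguous pass over the buffer so that
// consecutive iterations touch consecutive elements.
void BatchNormNHWC(const std::vector<float> &input,   // batch*height*width*channels, NHWC order
                   const std::vector<float> &scale,   // channels
                   const std::vector<float> &offset,  // channels
                   const std::vector<float> &mean,    // channels
                   const std::vector<float> &var,     // channels
                   float epsilon,
                   std::size_t channels,
                   std::vector<float> *output) {
  std::vector<float> new_scale(channels);
  std::vector<float> new_offset(channels);
  for (std::size_t c = 0; c < channels; ++c) {
    new_scale[c] = scale[c] / std::sqrt(var[c] + epsilon);
    new_offset[c] = offset[c] - mean[c] * new_scale[c];
  }
  output->resize(input.size());
  for (std::size_t i = 0; i < input.size(); ++i) {
    const std::size_t c = i % channels;  // channel cycles with the innermost dimension
    (*output)[i] = new_scale[c] * input[i] + new_offset[c];
  }
}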
@@ -71,15 +83,16 @@ void BatchNormFunctor<DeviceType::NEON, float>::operator()(
     const Tensor *epsilon,
     Tensor *output);
-template <>
-void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
-    const Tensor *input,
+template <typename T>
+struct BatchNormFunctor<DeviceType::OPENCL, T> {
+  void operator()(const Tensor *input,
     const Tensor *scale,
     const Tensor *offset,
     const Tensor *mean,
     const Tensor *var,
     const Tensor *epsilon,
     Tensor *output);
+};
}  // namespace kernels
}  // namespace mace
@@ -11,8 +11,8 @@
 namespace mace {
 namespace kernels {
-template <>
-void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
+template <typename T>
+void BatchNormFunctor<DeviceType::OPENCL, T>::operator()(
     const Tensor *input,
     const Tensor *scale,
     const Tensor *offset,
@@ -27,7 +27,6 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
   const index_t channels = input->dim(3);
   const index_t channel_blocks = RoundUpDiv4(channels);
-  const index_t width_blocks = RoundUpDiv4(width);
   const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
                            static_cast<uint32_t>(width),
@@ -35,8 +34,9 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
   auto runtime = OpenCLRuntime::Get();
   std::set<std::string> built_options;
-  built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(input->dtype()));
-  built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(input->dtype()));
+  auto dt = DataTypeToEnum<T>::value;
+  built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
+  built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
   auto bm_kernel = runtime->BuildKernel("batch_norm", "batch_norm", built_options);
   const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(bm_kernel);
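The hunk above swaps the runtime query input->dtype() for the compile-time DataTypeToEnum<T>::value, so the -DDATA_TYPE / -DCMD_DATA_TYPE options handed to the OpenCL compiler now follow the functor's template parameter rather than the input tensor. The general idea, sketched without MACE's helpers (the ClTypeName trait and BuildOptionsFor function below are assumptions for illustration, not MACE code):

#include <set>
#include <string>

// Illustrative trait mapping a host type to the OpenCL source-level type name.
// MACE instead goes through its DataType enum (DataTypeToEnum<T>::value) and
// helpers such as DtToUpstreamCLDt.
template <typename T> struct ClTypeName;
template <> struct ClTypeName<float> { static constexpr const char *value = "float"; };
// A half specialization would map to "half" (needs cl_khr_fp16 on the device).

template <typename T>
std::set<std::string> BuildOptionsFor() {
  std::set<std::string> built_options;
  // Resolved at compile time from T, not from the tensor seen at runtime.
  built_options.emplace(std::string("-DDATA_TYPE=") + ClTypeName<T>::value);
  return built_options;
}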
@@ -83,5 +83,9 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
       func);
 }
+template
+struct BatchNormFunctor<DeviceType::OPENCL, float>;
+template
+struct BatchNormFunctor<DeviceType::OPENCL, half>;
 }  // namespace kernels
 }  // namespace mace
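The two instantiations added at the end of this file pair with the header change earlier in the commit: the OpenCL functor becomes a partial specialization templated on the data type, its operator() is defined in this .cc, and explicit instantiation emits definitions for exactly the types that get registered. A stripped-down, self-contained sketch of that pattern (the Device enum and printed strings below are illustrative, not MACE code):

#include <cstdio>

enum class Device { CPU, OPENCL };

// Primary template: stands in for the generic (CPU) implementation.
template <Device D, typename T>
struct BatchNormFunctor {
  void operator()() { std::puts("generic functor"); }
};

// Partial specialization over the data type for one device; in a split build
// this declaration would sit in the header.
template <typename T>
struct BatchNormFunctor<Device::OPENCL, T> {
  void operator()();
};

// Out-of-line definition, as in the .cc above.
template <typename T>
void BatchNormFunctor<Device::OPENCL, T>::operator()() {
  std::printf("opencl functor, sizeof(T) = %zu\n", sizeof(T));
}

// Explicit instantiation forces these definitions to be emitted here, so other
// translation units that only see the declaration still link.
template struct BatchNormFunctor<Device::OPENCL, float>;
template struct BatchNormFunctor<Device::OPENCL, double>;  // stand-in for half

int main() {
  BatchNormFunctor<Device::OPENCL, float>()();
  BatchNormFunctor<Device::OPENCL, double>()();
  return 0;
}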
@@ -5,7 +5,7 @@ __kernel void batch_norm(__read_only image2d_t input,
                          __read_only image2d_t offset,
                          __read_only image2d_t mean,
                          __read_only image2d_t var,
-                         global const DATA_TYPE *epsilon,
+                         __global const DATA_TYPE *epsilon,
                          __write_only image2d_t output) {
   const int ch_blk = get_global_id(0);
   const int w_blk = get_global_id(1);
@@ -23,4 +23,9 @@ REGISTER_OPENCL_OPERATOR(OpKeyBuilder("BatchNorm")
                             .Build(),
                         BatchNormOp<DeviceType::OPENCL, float>);
+REGISTER_OPENCL_OPERATOR(OpKeyBuilder("BatchNorm")
+                             .TypeConstraint<half>("T")
+                             .Build(),
+                         BatchNormOp<DeviceType::OPENCL, half>);
 }  // namespace mace
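The second REGISTER_OPENCL_OPERATOR call exposes the half build of the op alongside the float one: the key built by OpKeyBuilder includes the T type constraint, so "BatchNorm" with T=float and "BatchNorm" with T=half are distinct registry entries. A toy sketch of a type-keyed registry in that spirit (the names and structure below are assumptions for illustration, not MACE's registry):

#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>

// Toy type-keyed operator registry: the lookup key pairs the op name with the
// data type it was registered for, so "BatchNorm"/float and "BatchNorm"/half
// resolve to different factories.
struct OpBase {
  virtual ~OpBase() = default;
};

using OpKey = std::pair<std::string, std::string>;  // {op name, dtype name}
using OpFactory = std::function<std::unique_ptr<OpBase>()>;

std::map<OpKey, OpFactory> &OpRegistry() {
  static std::map<OpKey, OpFactory> registry;
  return registry;
}

template <typename OpType>
void RegisterOp(const std::string &name, const std::string &dtype) {
  OpRegistry()[OpKey(name, dtype)] = [] {
    return std::unique_ptr<OpBase>(new OpType());
  };
}

std::unique_ptr<OpBase> CreateOp(const std::string &name, const std::string &dtype) {
  auto it = OpRegistry().find(OpKey(name, dtype));
  if (it == OpRegistry().end()) return nullptr;
  return it->second();
}

With both entries registered, a graph node whose T attribute is half can be resolved to BatchNormOp<DeviceType::OPENCL, half>, matching the half instantiation and the half kernel build options added earlier in this commit.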