提交 e2f43f41 编写于 作者: L liuqi

Fix batch norm bugs and add batch norm unit tests and benchmarks.

上级 8f246d81
bazel-*
.idea/
cmake-build-debug/
*.sh
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include <cmath>
#include <memory>
#include <random>
#include <vector>

#include "mace/core/testing/test_benchmark.h"
#include "mace/kernels/batch_norm.h"
namespace mace {
template <DeviceType D, typename T>
static void BatchNorm(int iters, int batch, int channels, int height, int width) {
std::random_device rd;
std::mt19937 gen(rd());
std::normal_distribution<float> nd(0, 1);
TIndex input_size = batch * channels * height * width;
std::vector<T> input(input_size, 0.0);
std::vector<T> scale(channels, 0.0);
std::vector<T> offset(channels, 0.0);
std::vector<T> mean(channels, 0.0);
std::vector<T> var(channels, 0.0);
for (int i = 0; i < input_size; ++i) {
input[i] = nd(gen);
}
for (int i = 0; i < channels; ++i) {
scale[i] = nd(gen);
offset[i] = nd(gen);
mean[i] = nd(gen);
var[i] = std::abs(nd(gen));
}
// declare output
std::unique_ptr<T[]> output(new T[input_size]);
auto functor = kernels::BatchNormFunctor<D, T>(1e-5);
while(iters--) {
functor(input.data(),
scale.data(),
offset.data(),
mean.data(),
var.data(),
batch,
channels,
height * width,
output.get());
}
}
// Registers one benchmark per (N, C, H, W, type, device) tuple. Each run
// reports items processed (tensor elements) and bytes processed so the
// framework can derive throughput.
#define BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, DEVICE) \
static void BM_BATCH_NORM_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE( \
int iters) { \
const int64 tot = static_cast<int64>(iters) * N * C * H * W; \
mace::testing::ItemsProcessed(tot); \
mace::testing::BytesProcessed(tot * (sizeof(TYPE)));\
BatchNorm<DEVICE, TYPE>(iters, N, C, H, W); \
} \
BENCHMARK(BM_BATCH_NORM_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
// Benchmarks every shape on both the CPU and NEON implementations so the
// two code paths can be compared directly.
#define BM_BATCH_NORM(N, C, H, W, TYPE) \
BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, CPU); \
BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, NEON);
// Single-channel feature maps at increasing resolutions and batch sizes.
BM_BATCH_NORM(1, 1, 128, 128, float);
BM_BATCH_NORM(1, 1, 512, 512, float);
BM_BATCH_NORM(1, 1, 1024, 1024, float);
BM_BATCH_NORM(16, 1, 256, 256, float);
BM_BATCH_NORM(32, 1, 256, 256, float);
BM_BATCH_NORM(64, 1, 256, 256, float);
// Three-channel (e.g. RGB) feature maps over the same shape sweep.
BM_BATCH_NORM(1, 3, 128, 128, float);
BM_BATCH_NORM(1, 3, 512, 512, float);
BM_BATCH_NORM(1, 3, 1024, 1024, float);
BM_BATCH_NORM(16, 3, 256, 256, float);
BM_BATCH_NORM(32, 3, 256, 256, float);
BM_BATCH_NORM(64, 3, 256, 256, float);
} // namespace mace
\ No newline at end of file
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include <random>
#include "gtest/gtest.h"
#include "mace/kernels/batch_norm.h"
namespace mace {
// Cross-checks the NEON batch-norm kernel against the CPU reference
// implementation on a randomly shaped, randomly filled NCHW tensor.
TEST(BatchNormNeonTest, Simple) {
  std::random_device rd;
  std::mt19937 gen(rd());
  std::normal_distribution<float> nd(0, 1);
  // Draw the random shape from the same engine as the data instead of
  // mixing in srand()/rand() global C PRNG state.
  std::uniform_int_distribution<TIndex> batch_dist(1, 128);   // was 1 + rand() % 128
  std::uniform_int_distribution<TIndex> extent_dist(2, 101);  // was 2 + rand() % 100

  // generate random input
  TIndex batch = batch_dist(gen);
  TIndex channels = 3;
  TIndex height = extent_dist(gen);
  TIndex width = extent_dist(gen);
  TIndex input_size = batch * channels * height * width;
  std::vector<float> input(input_size, 0.0);
  std::vector<float> scale(channels, 0.0);
  std::vector<float> offset(channels, 0.0);
  std::vector<float> mean(channels, 0.0);
  std::vector<float> var(channels, 0.0);

  // Use TIndex loop counters to match the TIndex bounds (avoids a
  // signed-width mismatch with plain int).
  for (TIndex i = 0; i < input_size; ++i) {
    input[i] = nd(gen);
  }
  for (TIndex i = 0; i < channels; ++i) {
    scale[i] = nd(gen);
    offset[i] = nd(gen);
    mean[i] = nd(gen);
    // Variance must be non-negative.
    var[i] = std::abs(nd(gen));
  }

  // declare output
  std::unique_ptr<float[]> output(new float[input_size]);
  std::unique_ptr<float[]> output_neon(new float[input_size]);

  // Reference result from the scalar CPU implementation.
  kernels::BatchNormFunctor<DeviceType::CPU, float>(1e-5)(
      input.data(),
      scale.data(),
      offset.data(),
      mean.data(),
      var.data(),
      batch,
      channels,
      height * width,
      output.get()
  );
  // Result under test from the NEON implementation.
  kernels::BatchNormFunctor<DeviceType::NEON, float>(1e-5)(
      input.data(),
      scale.data(),
      offset.data(),
      mean.data(),
      var.data(),
      batch,
      channels,
      height * width,
      output_neon.get()
  );

  // EXPECT_FLOAT_EQ allows up to 4 ULPs of difference, which tolerates the
  // rounding differences between the scalar and NEON code paths.
  for (TIndex i = 0; i < input_size; ++i) {
    EXPECT_FLOAT_EQ(output[i], output_neon[i]);
  }
}
} // namespace mace
\ No newline at end of file
......@@ -37,12 +37,12 @@ class BatchNormOp : public Operator<D, T> {
const index_t channel = input->dim(1);
const index_t sample_size = input->dim(2) * input->dim(3);
const float* input_ptr = input->data<float>();
const float* scale_ptr = scale->data<float>();
const float* offset_ptr = offset->data<float>();
const float* mean_ptr = mean->data<float>();
const float* var_ptr = var->data<float>();
float* output_ptr = output->mutable_data<float>();
const T* input_ptr = input->data<T>();
const T* scale_ptr = scale->data<T>();
const T* offset_ptr = offset->data<T>();
const T* mean_ptr = mean->data<T>();
const T* var_ptr = var->data<T>();
T* output_ptr = output->mutable_data<T>();
functor_(input_ptr, scale_ptr, offset_ptr, mean_ptr, var_ptr,
n, channel, sample_size,
......
......@@ -23,14 +23,10 @@ TEST_F(BatchNormOpTest, Simple) {
// Add input data
AddInputFromArray<float>("Input", {1, 1, 6, 2},
{5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15});
AddInputFromArray<float>("Scale", {2},
{4.0f, 4.0f});
AddInputFromArray<float>("Offset", {2},
{2.0, 2.0});
AddInputFromArray<float>("Mean", {2},
{10, 10});
AddInputFromArray<float>("Var", {2},
{11.67f, 11.67f});
AddInputFromArray<float>("Scale", {1}, {4.0f});
AddInputFromArray<float>("Offset", {1}, {2.0});
AddInputFromArray<float>("Mean", {1}, {10});
AddInputFromArray<float>("Var", {1}, {11.67f});
// Run
RunOp();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册