提交 75c8a0ba 编写于 作者: 李寅

Merge branch 'batch_norm_epsilon' into 'master'

Change batch_norm's epsilon from a float operator argument ("variance_epsilon") to a 0-dimensional input tensor.

See merge request !49
......@@ -13,16 +13,13 @@ namespace kernels {
template <DeviceType D, typename T>
struct BatchNormFunctor {
float variance_epsilon_;
BatchNormFunctor(const float variance_epsilon)
: variance_epsilon_(variance_epsilon) {}
void operator()(const T* input,
const T* scale,
const T* offset,
const T* mean,
const T* var,
const float variance_epsilon,
const index_t n,
const index_t channel,
const index_t sample_size,
......@@ -37,7 +34,7 @@ struct BatchNormFunctor {
// Y = new_scale * X + new_offset;
T new_scale, new_offset;
for (index_t c = 0; c < channel; ++c) {
new_scale = scale[c] / std::sqrt(var[c] + this->variance_epsilon_);
new_scale = scale[c] / std::sqrt(var[c] + variance_epsilon);
new_offset = offset[c] - mean[c] * new_scale;
index_t pos = c * sample_size;
......@@ -60,6 +57,7 @@ void BatchNormFunctor<DeviceType::NEON, float>::operator()(
const float* offset,
const float* mean,
const float* var,
const float variance_epsilon,
const index_t n,
const index_t channel,
const index_t sample_size,
......
......@@ -15,6 +15,7 @@ void BatchNormFunctor<DeviceType::NEON, float>::operator()(
const float* offset,
const float* mean,
const float* var,
const float variance_epsilon,
const index_t n,
const index_t channel,
const index_t sample_size,
......@@ -31,7 +32,7 @@ void BatchNormFunctor<DeviceType::NEON, float>::operator()(
index_t count = sample_size >> 2;
index_t remain_count = sample_size - (count << 2);
for (index_t c = 0; c < channel; ++c) {
new_scale = scale[c] / std::sqrt(var[c] + this->variance_epsilon_);
new_scale = scale[c] / std::sqrt(var[c] + variance_epsilon);
new_offset = offset[c] - mean[c] * new_scale;
index_t pos = c * sample_size;
......
......@@ -3,7 +3,6 @@
//
#include <arm_neon.h>
#include <float.h>
#include <limits>
#include "mace/core/common.h"
......
......@@ -15,8 +15,7 @@ class BatchNormOp : public Operator<D, T> {
public:
BatchNormOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<D, T>(operator_def, ws),
functor_(
OperatorBase::GetSingleArgument<float>("variance_epsilon", 1e-4)) {}
functor_() {}
bool Run() override {
const Tensor* input = this->Input(0);
......@@ -24,6 +23,7 @@ class BatchNormOp : public Operator<D, T> {
const Tensor* offset = this->Input(2);
const Tensor* mean = this->Input(3);
const Tensor* var = this->Input(4);
const Tensor* epsilon = this->Input(5);
MACE_CHECK(input->dim_size() == 4, "input must be 4-dimensional. ",
input->dim_size());
......@@ -35,6 +35,8 @@ class BatchNormOp : public Operator<D, T> {
mean->dim_size());
MACE_CHECK(var->dim_size() == 1, "var must be 1-dimensional. ",
var->dim_size());
MACE_CHECK(epsilon->dim_size() == 0, "epsilon must be 0-dimensional. ",
epsilon->dim_size());
Tensor* output = this->Output(0);
output->ResizeLike(input);
......@@ -48,9 +50,10 @@ class BatchNormOp : public Operator<D, T> {
const T* offset_ptr = offset->data<T>();
const T* mean_ptr = mean->data<T>();
const T* var_ptr = var->data<T>();
const T* epsilon_ptr = epsilon->data<T>();
T* output_ptr = output->mutable_data<T>();
functor_(input_ptr, scale_ptr, offset_ptr, mean_ptr, var_ptr, n, channel,
functor_(input_ptr, scale_ptr, offset_ptr, mean_ptr, var_ptr, *epsilon_ptr, n, channel,
sample_size, output_ptr);
return true;
}
......
......@@ -19,6 +19,7 @@ static void BatchNorm(
.Input("Offset")
.Input("Mean")
.Input("Var")
.Input("Epsilon")
.Output("Output")
.Finalize(net.operator_def());
......@@ -28,6 +29,7 @@ static void BatchNorm(
net.AddRandomInput<T>("Offset", {channels});
net.AddRandomInput<T>("Mean", {channels});
net.AddRandomInput<T>("Var", {channels}, true);
net.AddInputFromArray<float>("Epsilon", {}, {1e-3});
// Warm-up
for (int i = 0; i < 5; ++i) {
......
......@@ -18,6 +18,7 @@ TEST_F(BatchNormOpTest, SimpleCPU) {
.Input("Offset")
.Input("Mean")
.Input("Var")
.Input("Epsilon")
.Output("Output")
.Finalize(net.operator_def());
......@@ -28,6 +29,7 @@ TEST_F(BatchNormOpTest, SimpleCPU) {
net.AddInputFromArray<float>("Offset", {1}, {2.0});
net.AddInputFromArray<float>("Mean", {1}, {10});
net.AddInputFromArray<float>("Var", {1}, {11.67f});
net.AddInputFromArray<float>("Epsilon", {}, {1e-3});
// Run
net.RunOp();
......@@ -46,8 +48,8 @@ TEST_F(BatchNormOpTest, SimpleNeon) {
// generate random input
index_t batch = 1 + rand() % 10;
index_t channels = 3 + rand() % 50;
index_t height = 10 + rand() % 50;
index_t width = 10 + rand() % 50;
index_t height = 103;
index_t width = 113;
// Construct graph
auto& net = test_net();
OpDefBuilder("BatchNorm", "BatchNormTest")
......@@ -56,6 +58,7 @@ TEST_F(BatchNormOpTest, SimpleNeon) {
.Input("Offset")
.Input("Mean")
.Input("Var")
.Input("Epsilon")
.Output("Output")
.Finalize(net.operator_def());
......@@ -65,6 +68,7 @@ TEST_F(BatchNormOpTest, SimpleNeon) {
net.AddRandomInput<float>("Offset", {channels});
net.AddRandomInput<float>("Mean", {channels});
net.AddRandomInput<float>("Var", {channels}, true);
net.AddInputFromArray<float>("Epsilon", {}, {1e-3});
// run cpu
net.RunOp();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册