提交 297f5771 编写于 作者: liuqi

Change the type of batch_norm's attribute epsilon to tensor.

上级 5b21653b
...@@ -13,16 +13,13 @@ namespace kernels { ...@@ -13,16 +13,13 @@ namespace kernels {
template <DeviceType D, typename T> template <DeviceType D, typename T>
struct BatchNormFunctor { struct BatchNormFunctor {
float variance_epsilon_;
BatchNormFunctor(const float variance_epsilon)
: variance_epsilon_(variance_epsilon) {}
void operator()(const T* input, void operator()(const T* input,
const T* scale, const T* scale,
const T* offset, const T* offset,
const T* mean, const T* mean,
const T* var, const T* var,
const float variance_epsilon,
const index_t n, const index_t n,
const index_t channel, const index_t channel,
const index_t sample_size, const index_t sample_size,
...@@ -37,7 +34,7 @@ struct BatchNormFunctor { ...@@ -37,7 +34,7 @@ struct BatchNormFunctor {
// Y = new_scale * X + new_offset; // Y = new_scale * X + new_offset;
T new_scale, new_offset; T new_scale, new_offset;
for (index_t c = 0; c < channel; ++c) { for (index_t c = 0; c < channel; ++c) {
new_scale = scale[c] / std::sqrt(var[c] + this->variance_epsilon_); new_scale = scale[c] / std::sqrt(var[c] + variance_epsilon);
new_offset = offset[c] - mean[c] * new_scale; new_offset = offset[c] - mean[c] * new_scale;
index_t pos = c * sample_size; index_t pos = c * sample_size;
...@@ -60,6 +57,7 @@ void BatchNormFunctor<DeviceType::NEON, float>::operator()( ...@@ -60,6 +57,7 @@ void BatchNormFunctor<DeviceType::NEON, float>::operator()(
const float* offset, const float* offset,
const float* mean, const float* mean,
const float* var, const float* var,
const float variance_epsilon,
const index_t n, const index_t n,
const index_t channel, const index_t channel,
const index_t sample_size, const index_t sample_size,
......
...@@ -15,6 +15,7 @@ void BatchNormFunctor<DeviceType::NEON, float>::operator()( ...@@ -15,6 +15,7 @@ void BatchNormFunctor<DeviceType::NEON, float>::operator()(
const float* offset, const float* offset,
const float* mean, const float* mean,
const float* var, const float* var,
const float variance_epsilon,
const index_t n, const index_t n,
const index_t channel, const index_t channel,
const index_t sample_size, const index_t sample_size,
...@@ -31,7 +32,7 @@ void BatchNormFunctor<DeviceType::NEON, float>::operator()( ...@@ -31,7 +32,7 @@ void BatchNormFunctor<DeviceType::NEON, float>::operator()(
index_t count = sample_size >> 2; index_t count = sample_size >> 2;
index_t remain_count = sample_size - (count << 2); index_t remain_count = sample_size - (count << 2);
for (index_t c = 0; c < channel; ++c) { for (index_t c = 0; c < channel; ++c) {
new_scale = scale[c] / std::sqrt(var[c] + this->variance_epsilon_); new_scale = scale[c] / std::sqrt(var[c] + variance_epsilon);
new_offset = offset[c] - mean[c] * new_scale; new_offset = offset[c] - mean[c] * new_scale;
index_t pos = c * sample_size; index_t pos = c * sample_size;
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
// //
#include <arm_neon.h> #include <arm_neon.h>
#include <float.h>
#include <limits> #include <limits>
#include "mace/core/common.h" #include "mace/core/common.h"
......
...@@ -15,8 +15,7 @@ class BatchNormOp : public Operator<D, T> { ...@@ -15,8 +15,7 @@ class BatchNormOp : public Operator<D, T> {
public: public:
BatchNormOp(const OperatorDef& operator_def, Workspace* ws) BatchNormOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<D, T>(operator_def, ws), : Operator<D, T>(operator_def, ws),
functor_( functor_() {}
OperatorBase::GetSingleArgument<float>("variance_epsilon", 1e-4)) {}
bool Run() override { bool Run() override {
const Tensor* input = this->Input(0); const Tensor* input = this->Input(0);
...@@ -24,6 +23,7 @@ class BatchNormOp : public Operator<D, T> { ...@@ -24,6 +23,7 @@ class BatchNormOp : public Operator<D, T> {
const Tensor* offset = this->Input(2); const Tensor* offset = this->Input(2);
const Tensor* mean = this->Input(3); const Tensor* mean = this->Input(3);
const Tensor* var = this->Input(4); const Tensor* var = this->Input(4);
const Tensor* epsilon = this->Input(5);
MACE_CHECK(input->dim_size() == 4, "input must be 4-dimensional. ", MACE_CHECK(input->dim_size() == 4, "input must be 4-dimensional. ",
input->dim_size()); input->dim_size());
...@@ -35,6 +35,8 @@ class BatchNormOp : public Operator<D, T> { ...@@ -35,6 +35,8 @@ class BatchNormOp : public Operator<D, T> {
mean->dim_size()); mean->dim_size());
MACE_CHECK(var->dim_size() == 1, "var must be 1-dimensional. ", MACE_CHECK(var->dim_size() == 1, "var must be 1-dimensional. ",
var->dim_size()); var->dim_size());
MACE_CHECK(epsilon->dim_size() == 0, "epsilon must be 0-dimensional. ",
epsilon->dim_size());
Tensor* output = this->Output(0); Tensor* output = this->Output(0);
output->ResizeLike(input); output->ResizeLike(input);
...@@ -48,9 +50,10 @@ class BatchNormOp : public Operator<D, T> { ...@@ -48,9 +50,10 @@ class BatchNormOp : public Operator<D, T> {
const T* offset_ptr = offset->data<T>(); const T* offset_ptr = offset->data<T>();
const T* mean_ptr = mean->data<T>(); const T* mean_ptr = mean->data<T>();
const T* var_ptr = var->data<T>(); const T* var_ptr = var->data<T>();
const T* epsilon_ptr = epsilon->data<T>();
T* output_ptr = output->mutable_data<T>(); T* output_ptr = output->mutable_data<T>();
functor_(input_ptr, scale_ptr, offset_ptr, mean_ptr, var_ptr, n, channel, functor_(input_ptr, scale_ptr, offset_ptr, mean_ptr, var_ptr, *epsilon_ptr, n, channel,
sample_size, output_ptr); sample_size, output_ptr);
return true; return true;
} }
......
...@@ -19,6 +19,7 @@ static void BatchNorm( ...@@ -19,6 +19,7 @@ static void BatchNorm(
.Input("Offset") .Input("Offset")
.Input("Mean") .Input("Mean")
.Input("Var") .Input("Var")
.Input("Epsilon")
.Output("Output") .Output("Output")
.Finalize(net.operator_def()); .Finalize(net.operator_def());
...@@ -28,6 +29,7 @@ static void BatchNorm( ...@@ -28,6 +29,7 @@ static void BatchNorm(
net.AddRandomInput<T>("Offset", {channels}); net.AddRandomInput<T>("Offset", {channels});
net.AddRandomInput<T>("Mean", {channels}); net.AddRandomInput<T>("Mean", {channels});
net.AddRandomInput<T>("Var", {channels}, true); net.AddRandomInput<T>("Var", {channels}, true);
net.AddInputFromArray<float>("Epsilon", {}, {1e-3});
// Warm-up // Warm-up
for (int i = 0; i < 5; ++i) { for (int i = 0; i < 5; ++i) {
......
...@@ -18,6 +18,7 @@ TEST_F(BatchNormOpTest, SimpleCPU) { ...@@ -18,6 +18,7 @@ TEST_F(BatchNormOpTest, SimpleCPU) {
.Input("Offset") .Input("Offset")
.Input("Mean") .Input("Mean")
.Input("Var") .Input("Var")
.Input("Epsilon")
.Output("Output") .Output("Output")
.Finalize(net.operator_def()); .Finalize(net.operator_def());
...@@ -28,6 +29,7 @@ TEST_F(BatchNormOpTest, SimpleCPU) { ...@@ -28,6 +29,7 @@ TEST_F(BatchNormOpTest, SimpleCPU) {
net.AddInputFromArray<float>("Offset", {1}, {2.0}); net.AddInputFromArray<float>("Offset", {1}, {2.0});
net.AddInputFromArray<float>("Mean", {1}, {10}); net.AddInputFromArray<float>("Mean", {1}, {10});
net.AddInputFromArray<float>("Var", {1}, {11.67f}); net.AddInputFromArray<float>("Var", {1}, {11.67f});
net.AddInputFromArray<float>("Epsilon", {}, {1e-3});
// Run // Run
net.RunOp(); net.RunOp();
...@@ -46,8 +48,8 @@ TEST_F(BatchNormOpTest, SimpleNeon) { ...@@ -46,8 +48,8 @@ TEST_F(BatchNormOpTest, SimpleNeon) {
// generate random input // generate random input
index_t batch = 1 + rand() % 10; index_t batch = 1 + rand() % 10;
index_t channels = 3 + rand() % 50; index_t channels = 3 + rand() % 50;
index_t height = 10 + rand() % 50; index_t height = 103;
index_t width = 10 + rand() % 50; index_t width = 113;
// Construct graph // Construct graph
auto& net = test_net(); auto& net = test_net();
OpDefBuilder("BatchNorm", "BatchNormTest") OpDefBuilder("BatchNorm", "BatchNormTest")
...@@ -56,6 +58,7 @@ TEST_F(BatchNormOpTest, SimpleNeon) { ...@@ -56,6 +58,7 @@ TEST_F(BatchNormOpTest, SimpleNeon) {
.Input("Offset") .Input("Offset")
.Input("Mean") .Input("Mean")
.Input("Var") .Input("Var")
.Input("Epsilon")
.Output("Output") .Output("Output")
.Finalize(net.operator_def()); .Finalize(net.operator_def());
...@@ -65,6 +68,7 @@ TEST_F(BatchNormOpTest, SimpleNeon) { ...@@ -65,6 +68,7 @@ TEST_F(BatchNormOpTest, SimpleNeon) {
net.AddRandomInput<float>("Offset", {channels}); net.AddRandomInput<float>("Offset", {channels});
net.AddRandomInput<float>("Mean", {channels}); net.AddRandomInput<float>("Mean", {channels});
net.AddRandomInput<float>("Var", {channels}, true); net.AddRandomInput<float>("Var", {channels}, true);
net.AddInputFromArray<float>("Epsilon", {}, {1e-3});
// run cpu // run cpu
net.RunOp(); net.RunOp();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请先注册。