提交 68a335f1 编写于 作者: L Liangliang He

Merge branch 'tensor-warning' into 'master'

Finish batch norm op

See merge request !14
......@@ -37,6 +37,7 @@ bool SimpleNet::Run() {
return false;
}
}
return true;
}
unique_ptr<NetBase> CreateNet(const NetDef& net_def,
......
......@@ -44,7 +44,7 @@ class OperatorBase {
*operator_def_, name, default_value);
}
inline const Tensor *Input(int idx) {
inline const Tensor *Input(TIndex idx) {
MACE_CHECK(idx < inputs_.size());
return inputs_[idx];
}
......
......@@ -66,7 +66,13 @@ class Tensor {
inline const vector<TIndex>& shape() const { return shape_; }
inline TIndex dim_size() { return shape_.size(); }
inline TIndex dim_size() const { return shape_.size(); }
// Returns the extent of dimension `index`.
// Aborts via MACE_CHECK when `index` is negative or not smaller than the
// number of dimensions stored in shape_.
inline TIndex dim(TIndex index) const {
  // The sign check must come first: comparing a negative signed index with the
  // unsigned shape_.size() would convert it to a huge unsigned value, so the
  // range check would fire with a misleading message and the sign check would
  // never be reached. (Assumes TIndex is a signed integer type, which the
  // existing `index >= 0` check suggests -- TODO confirm.)
  MACE_CHECK(index >= 0, "Cannot have negative dimension index");
  MACE_CHECK(index < static_cast<TIndex>(shape_.size()), "Exceeding ndim limit");
  return shape_[index];
}
inline TIndex size() const { return size_; }
......@@ -120,16 +126,16 @@ class Tensor {
}
template <typename T>
inline void Copy(const T* src, size_t size) {
inline void Copy(const T* src, TIndex size) {
MACE_CHECK(size == size_, "copy src and dst with different size.");
CopyBytes(static_cast<const void*>(src), sizeof(T) * size);
}
template <typename SrcType, typename DstType>
inline void CopyWithCast(const SrcType* src, size_t size) {
MACE_CHECK(size == size_, "copy src and dst with different size.");
MACE_CHECK(static_cast<TIndex>(size) == size_, "copy src and dst with different size.");
unique_ptr<DstType[]> buffer(new DstType[size]);
for (int i = 0; i < size; ++i) {
for (size_t i = 0; i < size; ++i) {
buffer[i] = static_cast<DstType>(src[i]);
}
CopyBytes(static_cast<const void*>(buffer.get()), sizeof(DstType) * size);
......
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_KERNELS_BATCH_NORM_H_
#define MACE_KERNELS_BATCH_NORM_H_
#include "mace/core/tensor.h"
#include "mace/proto/mace.pb.h"
namespace mace {
namespace kernels {
// Shared state for every BatchNormFunctor variant: stores the epsilon that is
// added to the variance before the square root is taken.
template <DeviceType D, typename T>
struct BatchNormFunctorBase {
  float variance_epsilon_;

  BatchNormFunctorBase(const float variance_epsilon)
    : variance_epsilon_(variance_epsilon) {}
};
// Reference (scalar) batch normalization for inference.
//
// See https://arxiv.org/abs/1502.03167 . With per-channel parameters the
// inference formula collapses to one multiply-add per element:
//   new_scale  = scale / sqrt(var + variance_epsilon)
//   new_offset = offset - mean * new_scale
//   Y          = new_scale * X + new_offset
//
// The indexing below treats `input` and `output` as contiguous
// [n][channel][sample_size] buffers; scale/offset/mean/var are read at
// indices 0..channel-1.
template<DeviceType D, typename T>
struct BatchNormFunctor : public BatchNormFunctorBase<D, T> {
  BatchNormFunctor(const float variance_epsilon)
    : BatchNormFunctorBase<D, T>(variance_epsilon) {}

  void operator()(const T* input,
                  const T* scale,
                  const T* offset,
                  const T* mean,
                  const T* var,
                  const TIndex n,
                  const TIndex channel,
                  const TIndex sample_size,
                  T* output) {
    for (TIndex c = 0; c < channel; ++c) {
      // Fold the four per-channel parameters into one scale and one offset.
      const T new_scale = scale[c] / std::sqrt(var[c] + this->variance_epsilon_);
      const T new_offset = offset[c] - mean[c] * new_scale;

      for (TIndex i = 0; i < n; ++i) {
        // Start of channel c inside batch item i.
        const TIndex base = (i * channel + c) * sample_size;
        const T* src = input + base;
        T* dst = output + base;
        for (TIndex j = 0; j < sample_size; ++j) {
          dst[j] = new_scale * src[j] + new_offset;
        }
      }
    }
  }
};
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_BATCH_NORM_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#if __ARM_NEON
#include <arm_neon.h>
#include "mace/kernels/batch_norm.h"
namespace mace {
namespace kernels {
// NEON batch normalization for inference.
//
// See https://arxiv.org/abs/1502.03167 . With per-channel parameters the
// inference formula collapses to one multiply-add per element:
//   new_scale  = scale / sqrt(var + variance_epsilon)
//   new_offset = offset - mean * new_scale
//   Y          = new_scale * X + new_offset
//
// The indexing below treats `input` and `output` as contiguous
// [n][channel][sample_size] buffers; scale/offset/mean/var are read at
// indices 0..channel-1.
//
// NOTE(review): the body uses float32 intrinsics and `float*` locals, so it
// only compiles when T is float -- confirm no other T is ever instantiated.
template <typename T>
struct BatchNormFunctor<DeviceType::NEON, T> : public BatchNormFunctorBase<DeviceType::NEON, T> {
  BatchNormFunctor(const float variance_epsilon)
    : BatchNormFunctorBase<DeviceType::NEON, T>(variance_epsilon) {}

  // Sizes are taken as TIndex, matching the primary template, so values passed
  // by the caller are not narrowed to int.
  void operator()(const T* input,
                  const T* scale,
                  const T* offset,
                  const T* mean,
                  const T* var,
                  const TIndex n,
                  const TIndex channel,
                  const TIndex sample_size,
                  T* output) {
    // Number of full 4-lane vectors per sample, and the scalar tail that is
    // left over. The tail is sample_size minus the elements already covered
    // by the vector loop (count * 4), not sample_size minus count.
    const TIndex count = sample_size >> 2;
    const TIndex remain_count = sample_size - (count << 2);

    for (TIndex c = 0; c < channel; ++c) {
      const T new_scale = scale[c] / std::sqrt(var[c] + this->variance_epsilon_);
      const T new_offset = offset[c] - mean[c] * new_scale;
      TIndex pos = c * sample_size;

      // Broadcast the folded per-channel scale and offset to all four lanes.
      const float32x4_t new_scale_f = vdupq_n_f32(new_scale);
      const float32x4_t new_offset_f = vdupq_n_f32(new_offset);

      for (TIndex i = 0; i < n; ++i) {
        const float* input_sample_ptr = input + pos;
        float* output_sample_ptr = output + pos;

        // Vector body: output = new_offset + input * new_scale, 4 at a time.
        for (TIndex j = 0; j < count; ++j) {
          const float32x4_t input_f = vld1q_f32(input_sample_ptr);
          const float32x4_t output_f = vfmaq_f32(new_offset_f, input_f, new_scale_f);
          vst1q_f32(output_sample_ptr, output_f);
          input_sample_ptr += 4;
          output_sample_ptr += 4;
        }

        // Scalar tail for the last (sample_size % 4) elements.
        for (TIndex j = 0; j < remain_count; ++j) {
          *output_sample_ptr = new_scale * *input_sample_ptr + new_offset;
          ++output_sample_ptr;
          ++input_sample_ptr;
        }

        // Advance to the same channel in the next batch item.
        pos += channel * sample_size;
      }
    }
  }
};
} // namespace kernels
} // namespace mace
#endif // __ARM_NEON
......@@ -5,20 +5,49 @@ package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"]) # Apache 2.0
load("//mace:mace.bzl", "if_android")
# Shared helper header for operator unit tests (ops_test_util.h).
# Marked testonly, so only test targets may depend on it.
cc_library(
name = "test",
testonly = 1,
hdrs = [
"ops_test_util.h",
],
deps = [
"//mace/core",
"@gtest//:gtest",
],
)
cc_library(
name = "ops",
srcs = glob(["*.cc"]),
hdrs = glob(["*.h"]),
srcs = glob(
["*.cc"],
exclude = ["*_test.cc"],
),
hdrs = glob(
["*.h"],
exclude = ["ops_test_util.h"],
),
copts = ["-std=c++11"],
deps = [
"//mace/core",
"//mace/kernels",
"//mace/proto:cc_proto",
"//mace/core:core",
"//mace/kernels:kernels",
],
copts = ['-std=c++11'],
alwayslink = 1,
)
# Unit test for the BatchNorm operator. Links the ":ops" library and the
# ":test" helpers together with gtest's main().
cc_test(
name = "batch_norm_test",
srcs = ["batch_norm_test.cc"],
copts = ["-std=c++11"],
linkstatic = 1,
deps = [
":ops",
":test",
"@gtest//:gtest_main",
],
)
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/ops/batch_norm.h"
namespace mace {
// Presumably makes the float BatchNorm operator available for the CPU device
// (the registration macro is defined outside this file).
REGISTER_CPU_OPERATOR(BatchNorm, BatchNormOp<DeviceType::CPU, float>);
// The NEON variant is only compiled in when the toolchain defines __ARM_NEON.
#if __ARM_NEON
REGISTER_NEON_OPERATOR(BatchNorm, BatchNormOp<DeviceType::NEON, float>);
#endif // __ARM_NEON
} // namespace mace
\ No newline at end of file
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_BATCH_NORM_H_
#define MACE_BATCH_NORM_H_
#include "mace/core/operator.h"
#include "mace/kernels/batch_norm.h"
namespace mace {
// Batch normalization operator (inference).
//
// Inputs, by index: 0 = input (4-D), 1 = scale, 2 = offset, 3 = mean,
// 4 = var (all 1-D). Output 0 is resized to match the input.
// Dimension 1 of the input is used as the channel count and dimensions 2 and 3
// are flattened into one per-channel sample. The optional "variance_epsilon"
// argument defaults to 1e-4.
template<DeviceType D, class T>
class BatchNormOp : public Operator<D, T> {
 public:
  BatchNormOp(const OperatorDef &operator_def, Workspace *ws)
    : Operator<D, T>(operator_def, ws),
      functor_(OperatorBase::GetSingleArgument<float>("variance_epsilon", 1e-4)){}

  // Validates the input ranks and parameter lengths, resizes the output and
  // runs the device-specific functor. Always returns true; invalid inputs are
  // reported through MACE_CHECK.
  bool Run() override {
    const Tensor* input = this->Input(0);
    const Tensor* scale = this->Input(1);
    const Tensor* offset = this->Input(2);
    const Tensor* mean = this->Input(3);
    const Tensor* var = this->Input(4);

    MACE_CHECK(input->dim_size() == 4, "input must be 4-dimensional. ", input->dim_size());
    MACE_CHECK(scale->dim_size() == 1, "scale must be 1-dimensional. ", scale->dim_size());
    MACE_CHECK(offset->dim_size() == 1, "offset must be 1-dimensional. ", offset->dim_size());
    MACE_CHECK(mean->dim_size() == 1, "mean must be 1-dimensional. ", mean->dim_size());
    MACE_CHECK(var->dim_size() == 1, "var must be 1-dimensional. ", var->dim_size());

    const TIndex n = input->dim(0);
    const TIndex channel = input->dim(1);
    const TIndex sample_size = input->dim(2) * input->dim(3);

    // The functor reads entries 0..channel-1 of every parameter tensor, so a
    // shorter tensor would be read out of bounds. Only a lower bound is
    // enforced so that callers passing longer parameter vectors keep working.
    MACE_CHECK(scale->size() >= channel, "scale has fewer entries than channels. ", scale->size());
    MACE_CHECK(offset->size() >= channel, "offset has fewer entries than channels. ", offset->size());
    MACE_CHECK(mean->size() >= channel, "mean has fewer entries than channels. ", mean->size());
    MACE_CHECK(var->size() >= channel, "var has fewer entries than channels. ", var->size());

    Tensor* output = this->Output(0);
    output->ResizeLike(input);

    // Use the operator's element type T rather than a hard-coded float so the
    // pointers always match the functor's signature.
    const T* input_ptr = input->data<T>();
    const T* scale_ptr = scale->data<T>();
    const T* offset_ptr = offset->data<T>();
    const T* mean_ptr = mean->data<T>();
    const T* var_ptr = var->data<T>();
    T* output_ptr = output->mutable_data<T>();

    functor_(input_ptr, scale_ptr, offset_ptr, mean_ptr, var_ptr,
             n, channel, sample_size,
             output_ptr);
    return true;
  }

 private:
  kernels::BatchNormFunctor<D, T> functor_;
};
} // namespace mace
#endif // MACE_BATCH_NORM_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
class BatchNormOpTest : public OpsTestBase {};
TEST_F(BatchNormOpTest, Simple) {
// Construct graph
OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("Input")
.Input("Scale")
.Input("Offset")
.Input("Mean")
.Input("Var")
.Output("Output")
.Finalize(operator_def());
// Add input data
AddInputFromArray<float>("Input", {1, 1, 6, 2},
{5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15});
AddInputFromArray<float>("Scale", {2},
{4.0f, 4.0f});
AddInputFromArray<float>("Offset", {2},
{2.0, 2.0});
AddInputFromArray<float>("Mean", {2},
{10, 10});
AddInputFromArray<float>("Var", {2},
{11.67f, 11.67f});
// Run
RunOp();
// Check
Tensor expected = CreateTensor<float>({1, 1, 6, 2},
{-3.86, -3.86, -1.51, -1.51, 0.83, 0.83,
3.17, 3.17, 5.51, 5.51, 7.86, 7.86});
ExpectTensorNear<float>(expected, *GetOutput("Output"), 0.01);
}
}
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_OPS_TEST_UTIL_H_
#define MACE_OPS_TEST_UTIL_H_
#include "gtest/gtest.h"
#include "mace/core/common.h"
#include "mace/core/tensor.h"
#include "mace/core/net.h"
namespace mace {
class OpDefBuilder {
public:
OpDefBuilder(const char* type, const char* name) {
op_def_.set_type(type);
op_def_.set_name(name);
}
OpDefBuilder& Input(const char* input_name) {
op_def_.add_input(input_name);
return *this;
}
OpDefBuilder& Output(const char* output_name) {
op_def_.add_output(output_name);
return *this;
}
void Finalize(OperatorDef* op_def) const {
MACE_CHECK(op_def != NULL, "input should not be null.");
*op_def = op_def_;
}
OperatorDef op_def_;
};
// gtest fixture for single-operator tests. It owns a Workspace and one
// OperatorDef: tests fill in the definition, add input tensors, run the
// operator as a one-op net on the CPU device and read back outputs.
class OpsTestBase : public ::testing::Test {
 protected:
  // Removes every tensor from the workspace after each test.
  void TearDown() override {
    auto tensor_names = ws_.Tensors();
    for (auto& name : tensor_names) {
      ws_.RemoveTensor(name);
    }
  }

 public:
  // Creates tensor `name` with the given shape and copies `data` into it.
  // `data` must hold exactly as many elements as the shape describes.
  template <typename T>
  void AddInputFromArray(const char* name, const std::vector<TIndex>& shape, const std::vector<T>& data) {
    Tensor* input = ws_.CreateTensor(name, cpu_allocator(), DataTypeToEnum<T>::v());
    input->Resize(shape);
    // Guard the memcpy below against writing past the tensor's buffer.
    // (Assumes Resize() updates size() to the element count -- TODO confirm.)
    MACE_CHECK(static_cast<TIndex>(data.size()) == input->size(),
               "data size does not match tensor shape. ", data.size());
    // Access the buffer as T, matching the dtype the tensor was created with.
    T* input_data = input->mutable_data<T>();
    memcpy(input_data, data.data(), data.size() * sizeof(T));
  }

  OperatorDef* operator_def() { return &op_def_; }

  // Wraps the operator definition in a NetDef, builds a CPU net and runs it.
  // Returns the net's Run() result.
  bool RunOp() {
    NetDef net_def;
    net_def.add_op()->CopyFrom(op_def_);
    VLOG(0) << net_def.DebugString();
    auto net = CreateNet(net_def, &ws_, DeviceType::CPU);
    return net->Run();
  }

  // Looks up an output tensor by name in the workspace.
  Tensor* GetOutput(const char* output_name) {
    return ws_.GetTensor(output_name);
  }

 private:
  Workspace ws_;
  OperatorDef op_def_;
};
// Builds a CPU tensor of element type T with the given shape and copies
// `data` into it. `data` must hold exactly as many elements as the shape
// describes.
template <typename T>
Tensor CreateTensor(const std::vector<TIndex>& shape, const std::vector<T>& data) {
  Tensor res(cpu_allocator(), DataTypeToEnum<T>::v());
  res.Resize(shape);
  // Guard the memcpy below against writing past the tensor's buffer.
  // (Assumes Resize() updates size() to the element count -- TODO confirm.)
  MACE_CHECK(static_cast<TIndex>(data.size()) == res.size(),
             "data size does not match tensor shape. ", data.size());
  // Access the buffer as T, matching the dtype the tensor was created with.
  T* input_data = res.mutable_data<T>();
  memcpy(input_data, data.data(), data.size() * sizeof(T));
  return res;
}
inline bool IsSameSize(const Tensor& x, const Tensor& y) {
if (x.dim_size() != y.dim_size()) return false;
for (int d = 0; d < x.dim_size(); ++d) {
if (x.dim(d) != y.dim(d)) return false;
}
return true;
}
inline std::string ShapeToString(const Tensor& x) {
std::stringstream stream;
for (int i = 0; i < x.dim_size(); i++) {
if (i > 0) stream<<",";
int64 dim = x.dim(i);
if (dim < 0) {
stream<<"?";
} else {
stream<<dim;
}
}
stream<<"]";
return std::string(stream.str());
}
// Compile-time predicate: `value` is true only when T is exactly float or
// exactly double.
template <typename T>
struct is_floating_point_type
    : std::integral_constant<bool,
                             std::is_same<T, float>::value ||
                             std::is_same<T, double>::value> {};
// Records a non-fatal gtest failure unless `a` and `b` compare equal
// (generic version, uses EXPECT_EQ).
template <typename T>
inline void ExpectEqual(const T& a, const T& b) {
EXPECT_EQ(a, b);
}
// float specialization: compares with gtest's EXPECT_FLOAT_EQ instead of
// EXPECT_EQ.
template <>
inline void ExpectEqual<float>(const float& a, const float& b) {
EXPECT_FLOAT_EQ(a, b);
}
// double specialization: compares with gtest's EXPECT_DOUBLE_EQ instead of
// EXPECT_EQ.
template <>
inline void ExpectEqual<double>(const double& a, const double& b) {
EXPECT_DOUBLE_EQ(a, b);
}
// Raises a fatal gtest failure unless `x` and `y` have the same dtype and the
// same shape; the shape failure message prints both shapes.
// Note: ASSERT_* only returns from this helper, not from the function that
// called it, so callers continue running after a failure here.
inline void AssertSameTypeDims(const Tensor& x, const Tensor& y) {
ASSERT_EQ(x.dtype(), y.dtype());
ASSERT_TRUE(IsSameSize(x, y))
<< "x.shape [" << ShapeToString(x) << "] vs "
<< "y.shape [ " << ShapeToString(y) << "]";
}
template <typename T, bool is_fp = is_floating_point_type<T>::value>
struct Expector;
// Partial specialization for float and double.
template <typename T>
struct Expector<T, true> {
static void Equal(const T& a, const T& b) { ExpectEqual(a, b); }
static void Equal(const Tensor& x, const Tensor& y) {
ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v());
AssertSameTypeDims(x, y);
auto a = x.data<T>();
auto b = y.data<T>();
for (int i = 0; i < x.size(); ++i) {
ExpectEqual(a(i), b(i));
}
}
static void Near(const Tensor& x, const Tensor& y, const double abs_err) {
ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v());
AssertSameTypeDims(x, y);
auto a = x.data<T>();
auto b = y.data<T>();
for (int i = 0; i < x.size(); ++i) {
EXPECT_NEAR(a[i], b[i], abs_err)
<< "a = " << a << " b = " << b << " index = " << i;
}
}
};
template <typename T>
void ExpectTensorNear(const Tensor& x, const Tensor& y, const double abs_err) {
static_assert(is_floating_point_type<T>::value, "T is not a floating point type");
Expector<T>::Near(x, y ,abs_err);
}
} // namespace mace
#endif // MACE_OPS_TEST_UTIL_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册