From 196fd847c7964abd6139f9689905ed69ab6fbe28 Mon Sep 17 00:00:00 2001
From: liuqi
Date: Wed, 6 Sep 2017 15:51:26 +0800
Subject: [PATCH] Add dim function to tensor and fix some compare warnings.

---
 mace/core/operator.h                 |   2 +-
 mace/core/tensor.h                   |  21 +++-
 mace/kernels/batch_norm.h            |  68 +++++++++++
 mace/kernels/neon/batch_norm_neon.cc |  66 +++++++++++
 mace/ops/batch_norm.cc               |  16 +++
 mace/ops/batch_norm.h                |  58 +++++++++
 mace/ops/batch_norm_test.cc          |  46 ++++++++
 mace/ops/ops_test_util.h             | 170 +++++++++++++++++++++++++++
 8 files changed, 442 insertions(+), 5 deletions(-)
 create mode 100644 mace/kernels/batch_norm.h
 create mode 100644 mace/kernels/neon/batch_norm_neon.cc
 create mode 100644 mace/ops/batch_norm.cc
 create mode 100644 mace/ops/batch_norm.h
 create mode 100644 mace/ops/batch_norm_test.cc
 create mode 100644 mace/ops/ops_test_util.h

diff --git a/mace/core/operator.h b/mace/core/operator.h
index fc883855..ddf8fb2e 100644
--- a/mace/core/operator.h
+++ b/mace/core/operator.h
@@ -45,7 +45,7 @@ class OperatorBase {
   }
 
   inline const Tensor *Input(int idx) {
-    MACE_CHECK(idx < inputs_.size());
+    MACE_CHECK(static_cast<size_t>(idx) < inputs_.size());
     return inputs_[idx];
   }
 
diff --git a/mace/core/tensor.h b/mace/core/tensor.h
index 7aea8b15..2c45255d 100644
--- a/mace/core/tensor.h
+++ b/mace/core/tensor.h
@@ -66,7 +66,20 @@ class Tensor {
 
   inline const vector<TIndex>& shape() const { return shape_; }
 
-  inline TIndex dim_size() { return shape_.size(); }
+  inline TIndex dim_size() const { return shape_.size(); }
+
+  inline int dim32(int index) const {
+    MACE_CHECK(static_cast<size_t>(index) < shape_.size(), "Exceeding ndim limit");
+    MACE_CHECK(index >= 0, "Cannot have negative dimension index");
+    MACE_CHECK(shape_[index] <= std::numeric_limits<int>::max(), "Dimension is too large to fit in an int");
+    return static_cast<int>(shape_[index]);
+  }
+
+  inline TIndex dim(int index) const {
+    MACE_CHECK(static_cast<size_t>(index) < shape_.size(), "Exceeding ndim limit");
+    MACE_CHECK(index >= 0, "Cannot have negative dimension index");
+    return shape_[index];
+  }
 
   inline TIndex size() const { return size_; }
 
@@ -121,15 +134,15 @@ class Tensor {
 
   template <typename T>
   inline void Copy(const T* src, size_t size) {
-    MACE_CHECK(size == size_, "copy src and dst with different size.");
+    MACE_CHECK(static_cast<TIndex>(size) == size_, "copy src and dst with different size.");
     CopyBytes(static_cast<const void*>(src), sizeof(T) * size);
   }
 
   template <typename SrcType, typename DstType>
   inline void CopyWithCast(const SrcType* src, size_t size) {
-    MACE_CHECK(size == size_, "copy src and dst with different size.");
+    MACE_CHECK(static_cast<TIndex>(size) == size_, "copy src and dst with different size.");
     unique_ptr<DstType[]> buffer(new DstType[size]);
-    for (int i = 0; i < size; ++i) {
+    for (size_t i = 0; i < size; ++i) {
       buffer[i] = static_cast<DstType>(src[i]);
     }
     CopyBytes(static_cast<const void*>(buffer.get()), sizeof(DstType) * size);
diff --git a/mace/kernels/batch_norm.h b/mace/kernels/batch_norm.h
new file mode 100644
index 00000000..fd405fa5
--- /dev/null
+++ b/mace/kernels/batch_norm.h
@@ -0,0 +1,68 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#ifndef MACE_KERNELS_BATCH_NORM_H_
+#define MACE_KERNELS_BATCH_NORM_H_
+
+#include "mace/core/tensor.h"
+#include "mace/proto/mace.pb.h"
+
+namespace mace {
+namespace kernels {
+
+
+template <DeviceType D, typename T>
+struct BatchNormFunctor {
+  void operator()(const float* input,
+                  const float* scale,
+                  const float* offset,
+                  const float* mean,
+                  const float* var,
+                  const int n,
+                  const int channel,
+                  const int sample_size,
+                  const float variance_epsilon,
+                  float* output);
+};
+
+template <>
+struct BatchNormFunctor<DeviceType::CPU, float> {
+  void operator()(const float* input,
+                  const float* scale,
+                  const float* offset,
+                  const float* mean,
+                  const float* var,
+                  const int n,
+                  const int channel,
+                  const int sample_size,
+                  const float variance_epsilon,
+                  float* output) {
+    // Batch normalization in the paper https://arxiv.org/abs/1502.03167 .
+    // The calculation formula for inference is
+    //   Y = \frac{scale}{\sqrt{var + variance\_epsilon}} * X
+    //       + (offset - \frac{scale * mean}{\sqrt{var + variance\_epsilon}})
+    // new_scale  = \frac{scale}{\sqrt{var + variance\_epsilon}}
+    // new_offset = offset - mean * new_scale
+    // Y = new_scale * X + new_offset
+    float new_scale, new_offset;
+    for (int c = 0; c < channel; ++c) {
+      new_scale = scale[c] / std::sqrt(var[c] + variance_epsilon);
+      new_offset = offset[c] - mean[c] * new_scale;
+
+      for (int i = 0; i < n; ++i) {
+        int pos = i * channel * sample_size + c * sample_size;
+        const float* input_sample_ptr = input + pos;
+        float* output_sample_ptr = output + pos;
+        for (int j = 0; j < sample_size; ++j) {
+          output_sample_ptr[j] = new_scale * input_sample_ptr[j] + new_offset;
+        }
+      }
+    }
+  }
+};
+
+}  // namespace kernels
+}  // namespace mace
+
+#endif  // MACE_KERNELS_BATCH_NORM_H_
diff --git a/mace/kernels/neon/batch_norm_neon.cc b/mace/kernels/neon/batch_norm_neon.cc
new file mode 100644
index 00000000..2fbf6ece
--- /dev/null
+++ b/mace/kernels/neon/batch_norm_neon.cc
@@ -0,0 +1,66 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#if __ARM_NEON
+#include <arm_neon.h>
+#include "mace/kernels/batch_norm.h"
+
+namespace mace {
+namespace kernels {
+
+template <>
+struct BatchNormFunctor<DeviceType::NEON, float> {
+  void operator()(const float* input,
+                  const float* scale,
+                  const float* offset,
+                  const float* mean,
+                  const float* var,
+                  const int n,
+                  const int channel,
+                  const int sample_size,
+                  const float variance_epsilon,
+                  float* output) {
+
+    // Batch normalization in the paper https://arxiv.org/abs/1502.03167 .
+    // The calculation formula for inference is
+    //   Y = \frac{scale}{\sqrt{var + variance\_epsilon}} * X
+    //       + (offset - \frac{scale * mean}{\sqrt{var + variance\_epsilon}})
+    // new_scale  = \frac{scale}{\sqrt{var + variance\_epsilon}}
+    // new_offset = offset - mean * new_scale
+    // Y = new_scale * X + new_offset
+    float new_scale, new_offset;
+    int count = sample_size >> 2;
+    int remain_count = sample_size - (count << 2);
+    for (int c = 0; c < channel; ++c) {
+      new_scale = scale[c] / std::sqrt(var[c] + variance_epsilon);
+      new_offset = offset[c] - mean[c] * new_scale;
+
+      float32x4_t new_scale_f = vdupq_n_f32(new_scale);
+      float32x4_t new_offset_f = vdupq_n_f32(new_offset);
+      for (int i = 0; i < n; ++i) {
+        int pos = i * channel * sample_size + c * sample_size;
+        const float* input_sample_ptr = input + pos;
+        float* output_sample_ptr = output + pos;
+
+        for (int j = 0; j < count; ++j) {
+          float32x4_t input_f = vld1q_f32(input_sample_ptr);
+          float32x4_t output_f = new_offset_f;
+          output_f = vfmaq_f32(output_f, input_f, new_scale_f);
+          vst1q_f32(output_sample_ptr, output_f);
+          input_sample_ptr += 4;
+          output_sample_ptr += 4;
+        }
+        for (int j = 0; j < remain_count; ++j) {
+          *output_sample_ptr = new_scale * *input_sample_ptr + new_offset;
+          ++output_sample_ptr;
+          ++input_sample_ptr;
+        }
+      }
+    }
+  }
+};
+
+}  // namespace kernels
+}  // namespace mace
+#endif  // __ARM_NEON
diff --git a/mace/ops/batch_norm.cc b/mace/ops/batch_norm.cc
new file mode 100644
index 00000000..09d0e300
--- /dev/null
+++ b/mace/ops/batch_norm.cc
@@ -0,0 +1,16 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include "mace/ops/batch_norm.h"
+#include "mace/proto/mace.pb.h"
+
+namespace mace {
+
+REGISTER_CPU_OPERATOR(BatchNorm, BatchNormOp<DeviceType::CPU, float>);
+
+#if __ARM_NEON
+REGISTER_NEON_OPERATOR(BatchNorm, BatchNormOp<DeviceType::NEON, float>);
+#endif  // __ARM_NEON
+
+}  // namespace mace
\ No newline at end of file
diff --git a/mace/ops/batch_norm.h b/mace/ops/batch_norm.h
new file mode 100644
index 00000000..a2e175a7
--- /dev/null
+++ b/mace/ops/batch_norm.h
@@ -0,0 +1,58 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#ifndef MACE_BATCH_NORM_H_
+#define MACE_BATCH_NORM_H_
+
+#include "mace/core/operator.h"
+#include "mace/kernels/batch_norm.h"
+
+namespace mace {
+
+template <DeviceType D, typename T>
+class BatchNormOp : public Operator<D, T> {
+ public:
+  BatchNormOp(const OperatorDef &operator_def, Workspace *ws)
+      : Operator<D, T>(operator_def, ws) {}
+
+  bool Run() override {
+    const Tensor* input = this->Input(0);
+    const Tensor* scale = this->Input(1);
+    const Tensor* offset = this->Input(2);
+    const Tensor* mean = this->Input(3);
+    const Tensor* var = this->Input(4);
+
+    const float variance_epsilon = this->template GetSingleArgument<float>("variance_epsilon", 1e-4);
+
+    REQUIRE(input->dim_size() == 4, "input must be 4-dimensional. ", input->dim_size());
+    REQUIRE(scale->dim_size() == 1, "scale must be 1-dimensional. ", scale->dim_size());
+    REQUIRE(offset->dim_size() == 1, "offset must be 1-dimensional. ", offset->dim_size());
+    REQUIRE(mean->dim_size() == 1, "mean must be 1-dimensional. ", mean->dim_size());
+    REQUIRE(var->dim_size() == 1, "var must be 1-dimensional. ", var->dim_size());
", var->dim_size()); + + Tensor* output = this->Output(0); + output->ResizeLike(input); + + const int n = input->dim32(0); + const int channel = input->dim32(1); + const int sample_size = input->dim32(2) * input->dim32(3); + + const float* input_ptr = input->data(); + const float* scale_ptr = scale->data(); + const float* offset_ptr = offset->data(); + const float* mean_ptr = mean->data(); + const float* var_ptr = var->data(); + float* output_ptr = output->mutable_data(); + + kernels::BatchNormFunctor()(input_ptr, scale_ptr, offset_ptr, mean_ptr, var_ptr, + n, channel, sample_size, + variance_epsilon, output_ptr); + return true; + } + +}; + +} // namespace mace + +#endif // MACE_BATCH_NORM_H_ diff --git a/mace/ops/batch_norm_test.cc b/mace/ops/batch_norm_test.cc new file mode 100644 index 00000000..5b52d059 --- /dev/null +++ b/mace/ops/batch_norm_test.cc @@ -0,0 +1,46 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/core/operator.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { + +class BatchNormOpTest : public OpsTestBase {}; + +TEST_F(BatchNormOpTest, Simple) { + // Construct graph + OpDefBuilder("BatchNorm", "BatchNormTest") + .Input("Input") + .Input("Scale") + .Input("Offset") + .Input("Mean") + .Input("Var") + .Output("Output") + .Finalize(operator_def()); + + // Add input data + AddInputFromArray("Input", {1, 1, 6, 2}, + {5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15}); + AddInputFromArray("Scale", {2}, + {4.0f, 4.0f}); + AddInputFromArray("Offset", {2}, + {2.0, 2.0}); + AddInputFromArray("Mean", {2}, + {10, 10}); + AddInputFromArray("Var", {2}, + {11.67f, 11.67f}); + + // Run + RunOp(); + + // Check + Tensor expected = CreateTensor({1, 1, 6, 2}, + {-3.86, -3.86, -1.51, -1.51, 0.83, 0.83, + 3.17, 3.17, 5.51, 5.51, 7.86, 7.86}); + + ExpectTensorNear(expected, *GetOutput("Output"), 0.01); +} + +} diff --git a/mace/ops/ops_test_util.h b/mace/ops/ops_test_util.h new file mode 100644 index 00000000..82b7063d --- /dev/null +++ b/mace/ops/ops_test_util.h @@ -0,0 +1,170 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. 
+//
+
+#ifndef MACE_OPS_TEST_UTIL_H_
+#define MACE_OPS_TEST_UTIL_H_
+
+#include "gtest/gtest.h"
+#include "mace/core/common.h"
+#include "mace/core/tensor.h"
+#include "mace/core/net.h"
+
+namespace mace {
+
+class OpDefBuilder {
+ public:
+  OpDefBuilder(const char* type, const char* name) {
+    op_def_.set_type(type);
+    op_def_.set_name(name);
+  }
+  OpDefBuilder& Input(const char* input_name) {
+    op_def_.add_input(input_name);
+    return *this;
+  }
+  OpDefBuilder& Output(const char* output_name) {
+    op_def_.add_output(output_name);
+    return *this;
+  }
+  void Finalize(OperatorDef* op_def) const {
+    REQUIRE(op_def != NULL, "input should not be null.");
+    *op_def = op_def_;
+  }
+  OperatorDef op_def_;
+};
+
+class OpsTestBase : public ::testing::Test {
+ protected:
+  virtual void TearDown() {
+    auto tensor_names = ws_.Tensors();
+    for (auto& name : tensor_names) {
+      ws_.RemoveTensor(name);
+    }
+  }
+ public:
+  template <typename T>
+  void AddInputFromArray(const char* name, const std::vector<TIndex>& shape, const std::vector<T>& data) {
+    Tensor* input = ws_.CreateTensor(name, cpu_allocator(), DataTypeToEnum<T>::v());
+    input->Resize(shape);
+    float* input_data = input->mutable_data<float>();
+    memcpy(input_data, data.data(), data.size() * sizeof(T));
+  }
+
+  OperatorDef* operator_def() { return &op_def_; }
+
+  bool RunOp() {
+    NetDef net_def;
+    net_def.add_op()->CopyFrom(op_def_);
+    VLOG(0) << net_def.DebugString();
+    auto net = CreateNet(net_def, &ws_, DeviceType::CPU);
+    return net->Run();
+  }
+
+  Tensor* GetOutput(const char* output_name) {
+    return ws_.GetTensor(output_name);
+  }
+
+ private:
+  Workspace ws_;
+  OperatorDef op_def_;
+};
+
+template <typename T>
+Tensor CreateTensor(const std::vector<TIndex>& shape, const std::vector<T>& data) {
+  Tensor res(cpu_allocator(), DataTypeToEnum<T>::v());
+  res.Resize(shape);
+  float* input_data = res.mutable_data<float>();
+  memcpy(input_data, data.data(), data.size() * sizeof(T));
+  return res;
+}
+
+inline bool IsSameSize(const Tensor& x, const Tensor& y) {
+  if (x.dim_size() != y.dim_size()) return false;
+  for (int d = 0; d < x.dim_size(); ++d) {
+    if (x.dim(d) != y.dim(d)) return false;
+  }
+  return true;
+}
+
+inline std::string ShapeToString(const Tensor& x) {
+  std::stringstream stream;
+  for (int i = 0; i < x.dim_size(); i++) {
+    if (i > 0) stream << ",";
+    int64 dim = x.dim(i);
+    if (dim < 0) {
+      stream << "?";
+    } else {
+      stream << dim;
+    }
+  }
+  return stream.str();
+}
+
+template <typename T>
+struct is_floating_point_type {
+  static const bool value = std::is_same<T, float>::value ||
+                            std::is_same<T, double>::value;
+};
+
+template <typename T>
+inline void ExpectEqual(const T& a, const T& b) {
+  EXPECT_EQ(a, b);
+}
+
+template <>
+inline void ExpectEqual(const float& a, const float& b) {
+  EXPECT_FLOAT_EQ(a, b);
+}
+
+template <>
+inline void ExpectEqual(const double& a, const double& b) {
+  EXPECT_DOUBLE_EQ(a, b);
+}
+
+inline void AssertSameTypeDims(const Tensor& x, const Tensor& y) {
+  ASSERT_EQ(x.dtype(), y.dtype());
+  ASSERT_TRUE(IsSameSize(x, y))
+      << "x.shape [" << ShapeToString(x) << "] vs "
+      << "y.shape [" << ShapeToString(y) << "]";
+}
+
+template <typename T, bool is_fp = is_floating_point_type<T>::value>
+struct Expector;
+// Partial specialization for float and double.
+template <typename T>
+struct Expector<T, true> {
+  static void Equal(const T& a, const T& b) { ExpectEqual(a, b); }
+
+  static void Equal(const Tensor& x, const Tensor& y) {
+    ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v());
+    AssertSameTypeDims(x, y);
+    auto a = x.data<T>();
+    auto b = y.data<T>();
+    for (int i = 0; i < x.size(); ++i) {
+      ExpectEqual(a[i], b[i]);
+    }
+  }
+
+  static void Near(const Tensor& x, const Tensor& y, const double abs_err) {
+    ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v());
+    AssertSameTypeDims(x, y);
+    auto a = x.data<T>();
+    auto b = y.data<T>();
+    for (int i = 0; i < x.size(); ++i) {
+      EXPECT_NEAR(a[i], b[i], abs_err)
+          << "a = " << a[i] << " b = " << b[i] << " index = " << i;
+    }
+  }
+};
+
+template <typename T>
+void ExpectTensorNear(const Tensor& x, const Tensor& y, const double abs_err) {
+  static_assert(is_floating_point_type<T>::value, "T is not a floating point type");
+  Expector<T>::Near(x, y, abs_err);
+}
+
+}  // namespace mace
+
+#endif  // MACE_OPS_TEST_UTIL_H_
-- 
GitLab
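
Reviewer note (not part of the patch): the expected values in BatchNormOpTest.Simple
can be reproduced with the same inference-time folding the functor performs. Below is
a minimal standalone C++ sketch; the file name is hypothetical, it depends on no MACE
headers, and it is illustrative only.

// batch_norm_check.cc (hypothetical standalone file): verifies the test's
// expected outputs against the folded batch-norm formula from the kernel.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // Tensors from batch_norm_test.cc: shape {1, 1, 6, 2} in NCHW,
  // i.e. n = 1, channel = 1, sample_size = 6 * 2 = 12.
  const std::vector<float> input = {5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15};
  const float scale = 4.0f, offset = 2.0f, mean = 10.0f, var = 11.67f;
  const float variance_epsilon = 1e-4f;  // BatchNormOp's default argument

  // Fold the four per-channel parameters once, as the kernel does:
  //   new_scale  = scale / sqrt(var + variance_epsilon)
  //   new_offset = offset - mean * new_scale
  //   Y = new_scale * X + new_offset
  const float new_scale = scale / std::sqrt(var + variance_epsilon);
  const float new_offset = offset - mean * new_scale;

  const std::vector<float> expected = {-3.86f, -3.86f, -1.51f, -1.51f,
                                       0.83f,  0.83f,  3.17f,  3.17f,
                                       5.51f,  5.51f,  7.86f,  7.86f};
  for (size_t i = 0; i < input.size(); ++i) {
    const float y = new_scale * input[i] + new_offset;
    // Each |y - expected| stays within the 0.01 tolerance used by
    // ExpectTensorNear in the test.
    std::printf("x=%5.2f  y=%8.4f  expected=%5.2f\n", input[i], y, expected[i]);
  }
  return 0;
}

Folding scale/offset/mean/var into new_scale and new_offset once per channel is also
what lets the NEON path reduce the inner loop to one fused multiply-add (vfmaq_f32)
per four elements.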