diff --git a/mace/core/allocator.h b/mace/core/allocator.h index e4482f6bc0a33521ef95a37c0c7d4e6f63ee27cb..fa4f188983801d005da46c64c672d4aa3dd9e910 100644 --- a/mace/core/allocator.h +++ b/mace/core/allocator.h @@ -6,20 +6,15 @@ #ifndef MACE_CORE_ALLOCATOR_H_ #define MACE_CORE_ALLOCATOR_H_ -#include -#include #include -#include - #include "mace/core/common.h" #include "mace/proto/mace.pb.h" namespace mace { +// 16 bytes = 32 * 4 (Neon) constexpr size_t kMaceAlignment = 16; -using MemoryDeleter = std::function; - class Allocator { public: Allocator() {} @@ -44,9 +39,9 @@ class CPUAllocator: public Allocator { void* New(size_t nbytes) override { void* data = nullptr; #ifdef __ANDROID__ - data = memalign(gMaceAlignment, nbytes); + data = memalign(kMaceAlignment, nbytes); #elif defined(_MSC_VER) - data = _aligned_malloc(nbytes, gMaceAlignment); + data = _aligned_malloc(nbytes, kMaceAlignment); #else CHECK(posix_memalign(&data, kMaceAlignment, nbytes) == 0); #endif @@ -72,7 +67,7 @@ CPUAllocator* cpu_allocator(); // ownership of the pointer. void SetCPUAllocator(CPUAllocator* alloc); -template +template struct DeviceContext {}; template <> diff --git a/mace/core/common.h b/mace/core/common.h index 5c24503e549a6503f87165d0a6b9c4fe849ca06f..ae295d3d0aacdabaf60503d9686c7b1a0344b6cc 100644 --- a/mace/core/common.h +++ b/mace/core/common.h @@ -6,6 +6,7 @@ #define MACE_CORE_COMMON_H_ #include +#include #include #include #include @@ -15,6 +16,7 @@ #include "mace/core/logging.h" using std::set; +using std::map; using std::string; using std::unique_ptr; using std::vector; diff --git a/mace/core/net.cc b/mace/core/net.cc index 96ee4656f0e574e6bdfbb2eeac68b6ef066a54f4..2933df8c4aec358afc5f15637747ecb92ae4a1ca 100644 --- a/mace/core/net.cc +++ b/mace/core/net.cc @@ -8,8 +8,8 @@ namespace mace { NetBase::NetBase(const std::shared_ptr &net_def, Workspace *ws, - DeviceType type) { - + DeviceType type) + : name_(net_def->name()) { } diff --git a/mace/core/net.h b/mace/core/net.h index 37df69de1649d3df2f6c1ed9ea0b6622c5c9eb8a..93ce98ce78ee54b0d92c64e4b237d9d185f4c67a 100644 --- a/mace/core/net.h +++ b/mace/core/net.h @@ -14,7 +14,9 @@ namespace mace { class NetBase { public: - NetBase(const std::shared_ptr &net_def, Workspace* ws, DeviceType type); + NetBase(const std::shared_ptr &net_def, + Workspace* ws, + DeviceType type); virtual ~NetBase() noexcept {} virtual bool Run() = 0; @@ -31,9 +33,11 @@ class NetBase { class SimpleNet : public NetBase { public: - SimpleNet(const std::shared_ptr& net_def, Workspace* ws, DeviceType type); + SimpleNet(const std::shared_ptr& net_def, + Workspace* ws, + DeviceType type); - virtual bool Run() override; + bool Run() override; protected: vector > operators_; @@ -41,7 +45,9 @@ class SimpleNet : public NetBase { DISABLE_COPY_AND_ASSIGN(SimpleNet); }; -unique_ptr CreateNet(const NetDef& net_def, Workspace* ws, DeviceType type); +unique_ptr CreateNet(const NetDef& net_def, + Workspace* ws, + DeviceType type); unique_ptr CreateNet( const std::shared_ptr& net_def, Workspace* ws, diff --git a/mace/core/operator.cc b/mace/core/operator.cc index a233b273859d8fa90a493e14c804281418d6ee2c..0072b58add999b1b60c7a1ea0a3bef0172931909 100644 --- a/mace/core/operator.cc +++ b/mace/core/operator.cc @@ -6,8 +6,8 @@ namespace mace { -std::map* gDeviceTypeRegistry() { - static std::map g_device_type_registry; +std::map* gDeviceTypeRegistry() { + static std::map g_device_type_registry; return &g_device_type_registry; } diff --git a/mace/core/operator.h b/mace/core/operator.h index b079bdac4f6fb88fcf702af640dc6d4e9110a142..4b7555262c115b6a44215bb9034f1b80718d78f9 100644 --- a/mace/core/operator.h +++ b/mace/core/operator.h @@ -58,10 +58,7 @@ class OperatorBase { inline const vector &Inputs() const { return inputs_; } inline const vector &Outputs() { return outputs_; } - virtual bool Run() { - MACE_NOT_IMPLEMENTED; - return false; - } + virtual bool Run() = 0; inline const OperatorDef &debug_def() const { REQUIRE(has_debug_def(), "operator_def was null!"); @@ -108,10 +105,7 @@ class Operator : public OperatorBase { DataTypeToEnum::v()))); } } - virtual bool Run() { - MACE_NOT_IMPLEMENTED; - return false; - } + virtual bool Run() = 0; ~Operator() noexcept override {} }; diff --git a/mace/core/registry.h b/mace/core/registry.h index 4064e7d501bb49b0f755adf848fbdf14b6eb7af7..f098933582d3a09f322a822c43f0ea0b195c4053 100644 --- a/mace/core/registry.h +++ b/mace/core/registry.h @@ -5,11 +5,7 @@ #ifndef MACE_CORE_REGISTRY_H_ #define MACE_CORE_REGISTRY_H_ -#include #include -#include -#include -#include "mace/core/common.h" namespace mace { diff --git a/mace/core/tensor.h b/mace/core/tensor.h index d1059971a249dc9194a899665ac5d5f25ff94a92..1e15b425c9475491bd5af29fc08110b4adefac53 100644 --- a/mace/core/tensor.h +++ b/mace/core/tensor.h @@ -53,7 +53,7 @@ class Tensor { size_(0), dtype_(DT_FLOAT), data_(nullptr) {}; Tensor(Allocator* a, DataType type) - : alloc_(a), size_(0), dtype_(DT_FLOAT), data_(nullptr) {}; + : alloc_(a), size_(0), dtype_(type), data_(nullptr) {}; ~Tensor() { if (alloc_ && data_.get()) { @@ -65,10 +65,6 @@ class Tensor { inline const vector& shape() const { return shape_; } - inline int64 NumElements() const { - return std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies()); - } - inline TIndex dim_size() { return shape_.size(); } inline TIndex size() const { return size_; } @@ -86,10 +82,6 @@ class Tensor { return static_cast(data_.get()); } - void Deleter(void* data) { - alloc_->Delete(data); - } - inline void* raw_mutable_data() { if (data_.get() || size_ == 0) { return data_.get(); @@ -113,7 +105,7 @@ class Tensor { shape_ = shape; TIndex size = NumElements(); if (size_ != size) { - size_ = NumElements(); + size_ = size; data_.reset(); } } @@ -127,6 +119,10 @@ class Tensor { } private: + inline int64 NumElements() const { + return std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies()); + } + Allocator* alloc_; TIndex size_; DataType dtype_; diff --git a/mace/core/types.h b/mace/core/types.h index 476fa54dafb3885d565b753ccd7bd520bbeccb9e..2e33a5687aeaefc0bbb00bd45d0f8241937d339e 100644 --- a/mace/core/types.h +++ b/mace/core/types.h @@ -16,7 +16,7 @@ struct IsValidDataType; template struct DataTypeToEnum { static_assert(IsValidDataType::value, "Specified Data Type not supported"); -}; // Specializations below +}; // EnumToDataType::Type is the type for DataType constant VALUE, e.g. diff --git a/mace/core/workspace.cc b/mace/core/workspace.cc index 20ff5727d937c44566874f9d7b8b3df718153fa4..14431bc66229c72f9600c182c9081a7b7af7dbe3 100644 --- a/mace/core/workspace.cc +++ b/mace/core/workspace.cc @@ -15,7 +15,9 @@ vector Workspace::Tensors() const { return names; } -Tensor* Workspace::CreateTensor(const string& name, Allocator* alloc, DataType type) { +Tensor* Workspace::CreateTensor(const string& name, + Allocator* alloc, + DataType type) { if (HasTensor(name)) { VLOG(1) << "Tensor " << name << " already exists. Skipping."; } else { diff --git a/mace/core/workspace.h b/mace/core/workspace.h index 3f16077fbd8a7940030d17a6f0670f9030e56004..93043744fc275b19430f0457ea918646c2dbf9fc 100644 --- a/mace/core/workspace.h +++ b/mace/core/workspace.h @@ -14,7 +14,7 @@ namespace mace { class Workspace { public: - typedef std::map> TensorMap; + typedef map> TensorMap; Workspace() {} diff --git a/mace/mace.bzl b/mace/mace.bzl index 3b57e72421beaecbdc19deb9a1cfb0174fd11831..f9e7b6afc50d2908eef34292f522a0f3c4946c75 100644 --- a/mace/mace.bzl +++ b/mace/mace.bzl @@ -22,4 +22,4 @@ def if_android_arm64(a): return select({ "//mace:android_arm64": a, "//conditions:default": [], - }) + }) \ No newline at end of file diff --git a/mace/ops/BUILD b/mace/ops/BUILD index 1f06a345ee83fb4a196eb6c1cd987a6ba7b38069..9443f7a39c71473f0e9e6a37980703a0bab4d755 100644 --- a/mace/ops/BUILD +++ b/mace/ops/BUILD @@ -9,7 +9,7 @@ package( licenses(["notice"]) # Apache 2.0 cc_library( - name = "op", + name = "ops", srcs = ["relu.cc"], hdrs = glob(["*.h"]), deps = [ @@ -19,10 +19,11 @@ cc_library( ) cc_test( - name = "op_test", + name = "relu_test", srcs = ["relu_test.cc",], deps = [ "@gtest//:gtest", - ":op", + ":ops", ], -) \ No newline at end of file +) + diff --git a/mace/ops/relu.h b/mace/ops/relu.h index b658050652f4a04d743ba80a078f14ebb2219f9c..8a0ea34df62509f103172fceadab473ed7cdda45 100644 --- a/mace/ops/relu.h +++ b/mace/ops/relu.h @@ -2,8 +2,8 @@ // Copyright (c) 2017 XiaoMi All rights reserved. // -#ifndef MACE_OPERATORS_RELU_H_ -#define MACE_OPERATORS_RELU_H_ +#ifndef MACE_OPS_RELU_H_ +#define MACE_OPS_RELU_H_ #include "mace/core/operator.h" @@ -19,4 +19,4 @@ class ReluOp : public Operator { } // namespace mace -#endif // MACE_OPERATORS_RELU_H_ +#endif // MACE_OPS_RELU_H_