From 72ef73a560cf90647c690fec11833bf5e3121d33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=AF=85?= Date: Tue, 29 Aug 2017 20:40:31 +0800 Subject: [PATCH] Refactor ops and net --- mace/core/allocator.h | 13 ++++--------- mace/core/common.h | 2 ++ mace/core/net.cc | 4 ++-- mace/core/net.h | 14 ++++++++++---- mace/core/operator.cc | 4 ++-- mace/core/operator.h | 10 ++-------- mace/core/registry.h | 4 ---- mace/core/tensor.h | 16 ++++++---------- mace/core/types.h | 2 +- mace/core/workspace.cc | 4 +++- mace/core/workspace.h | 2 +- mace/mace.bzl | 2 +- mace/ops/BUILD | 9 +++++---- mace/ops/relu.h | 6 +++--- 14 files changed, 42 insertions(+), 50 deletions(-) diff --git a/mace/core/allocator.h b/mace/core/allocator.h index e4482f6b..fa4f1889 100644 --- a/mace/core/allocator.h +++ b/mace/core/allocator.h @@ -6,20 +6,15 @@ #ifndef MACE_CORE_ALLOCATOR_H_ #define MACE_CORE_ALLOCATOR_H_ -#include -#include #include -#include - #include "mace/core/common.h" #include "mace/proto/mace.pb.h" namespace mace { +// 16 bytes = 32 * 4 (Neon) constexpr size_t kMaceAlignment = 16; -using MemoryDeleter = std::function; - class Allocator { public: Allocator() {} @@ -44,9 +39,9 @@ class CPUAllocator: public Allocator { void* New(size_t nbytes) override { void* data = nullptr; #ifdef __ANDROID__ - data = memalign(gMaceAlignment, nbytes); + data = memalign(kMaceAlignment, nbytes); #elif defined(_MSC_VER) - data = _aligned_malloc(nbytes, gMaceAlignment); + data = _aligned_malloc(nbytes, kMaceAlignment); #else CHECK(posix_memalign(&data, kMaceAlignment, nbytes) == 0); #endif @@ -72,7 +67,7 @@ CPUAllocator* cpu_allocator(); // ownership of the pointer. void SetCPUAllocator(CPUAllocator* alloc); -template +template struct DeviceContext {}; template <> diff --git a/mace/core/common.h b/mace/core/common.h index 5c24503e..ae295d3d 100644 --- a/mace/core/common.h +++ b/mace/core/common.h @@ -6,6 +6,7 @@ #define MACE_CORE_COMMON_H_ #include +#include #include #include #include @@ -15,6 +16,7 @@ #include "mace/core/logging.h" using std::set; +using std::map; using std::string; using std::unique_ptr; using std::vector; diff --git a/mace/core/net.cc b/mace/core/net.cc index 96ee4656..2933df8c 100644 --- a/mace/core/net.cc +++ b/mace/core/net.cc @@ -8,8 +8,8 @@ namespace mace { NetBase::NetBase(const std::shared_ptr &net_def, Workspace *ws, - DeviceType type) { - + DeviceType type) + : name_(net_def->name()) { } diff --git a/mace/core/net.h b/mace/core/net.h index 37df69de..93ce98ce 100644 --- a/mace/core/net.h +++ b/mace/core/net.h @@ -14,7 +14,9 @@ namespace mace { class NetBase { public: - NetBase(const std::shared_ptr &net_def, Workspace* ws, DeviceType type); + NetBase(const std::shared_ptr &net_def, + Workspace* ws, + DeviceType type); virtual ~NetBase() noexcept {} virtual bool Run() = 0; @@ -31,9 +33,11 @@ class NetBase { class SimpleNet : public NetBase { public: - SimpleNet(const std::shared_ptr& net_def, Workspace* ws, DeviceType type); + SimpleNet(const std::shared_ptr& net_def, + Workspace* ws, + DeviceType type); - virtual bool Run() override; + bool Run() override; protected: vector > operators_; @@ -41,7 +45,9 @@ class SimpleNet : public NetBase { DISABLE_COPY_AND_ASSIGN(SimpleNet); }; -unique_ptr CreateNet(const NetDef& net_def, Workspace* ws, DeviceType type); +unique_ptr CreateNet(const NetDef& net_def, + Workspace* ws, + DeviceType type); unique_ptr CreateNet( const std::shared_ptr& net_def, Workspace* ws, diff --git a/mace/core/operator.cc b/mace/core/operator.cc index a233b273..0072b58a 100644 --- a/mace/core/operator.cc +++ b/mace/core/operator.cc @@ -6,8 +6,8 @@ namespace mace { -std::map* gDeviceTypeRegistry() { - static std::map g_device_type_registry; +std::map* gDeviceTypeRegistry() { + static std::map g_device_type_registry; return &g_device_type_registry; } diff --git a/mace/core/operator.h b/mace/core/operator.h index b079bdac..4b755526 100644 --- a/mace/core/operator.h +++ b/mace/core/operator.h @@ -58,10 +58,7 @@ class OperatorBase { inline const vector &Inputs() const { return inputs_; } inline const vector &Outputs() { return outputs_; } - virtual bool Run() { - MACE_NOT_IMPLEMENTED; - return false; - } + virtual bool Run() = 0; inline const OperatorDef &debug_def() const { REQUIRE(has_debug_def(), "operator_def was null!"); @@ -108,10 +105,7 @@ class Operator : public OperatorBase { DataTypeToEnum::v()))); } } - virtual bool Run() { - MACE_NOT_IMPLEMENTED; - return false; - } + virtual bool Run() = 0; ~Operator() noexcept override {} }; diff --git a/mace/core/registry.h b/mace/core/registry.h index 4064e7d5..f0989335 100644 --- a/mace/core/registry.h +++ b/mace/core/registry.h @@ -5,11 +5,7 @@ #ifndef MACE_CORE_REGISTRY_H_ #define MACE_CORE_REGISTRY_H_ -#include #include -#include -#include -#include "mace/core/common.h" namespace mace { diff --git a/mace/core/tensor.h b/mace/core/tensor.h index d1059971..1e15b425 100644 --- a/mace/core/tensor.h +++ b/mace/core/tensor.h @@ -53,7 +53,7 @@ class Tensor { size_(0), dtype_(DT_FLOAT), data_(nullptr) {}; Tensor(Allocator* a, DataType type) - : alloc_(a), size_(0), dtype_(DT_FLOAT), data_(nullptr) {}; + : alloc_(a), size_(0), dtype_(type), data_(nullptr) {}; ~Tensor() { if (alloc_ && data_.get()) { @@ -65,10 +65,6 @@ class Tensor { inline const vector& shape() const { return shape_; } - inline int64 NumElements() const { - return std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies()); - } - inline TIndex dim_size() { return shape_.size(); } inline TIndex size() const { return size_; } @@ -86,10 +82,6 @@ class Tensor { return static_cast(data_.get()); } - void Deleter(void* data) { - alloc_->Delete(data); - } - inline void* raw_mutable_data() { if (data_.get() || size_ == 0) { return data_.get(); @@ -113,7 +105,7 @@ class Tensor { shape_ = shape; TIndex size = NumElements(); if (size_ != size) { - size_ = NumElements(); + size_ = size; data_.reset(); } } @@ -127,6 +119,10 @@ class Tensor { } private: + inline int64 NumElements() const { + return std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies()); + } + Allocator* alloc_; TIndex size_; DataType dtype_; diff --git a/mace/core/types.h b/mace/core/types.h index 476fa54d..2e33a568 100644 --- a/mace/core/types.h +++ b/mace/core/types.h @@ -16,7 +16,7 @@ struct IsValidDataType; template struct DataTypeToEnum { static_assert(IsValidDataType::value, "Specified Data Type not supported"); -}; // Specializations below +}; // EnumToDataType::Type is the type for DataType constant VALUE, e.g. diff --git a/mace/core/workspace.cc b/mace/core/workspace.cc index 20ff5727..14431bc6 100644 --- a/mace/core/workspace.cc +++ b/mace/core/workspace.cc @@ -15,7 +15,9 @@ vector Workspace::Tensors() const { return names; } -Tensor* Workspace::CreateTensor(const string& name, Allocator* alloc, DataType type) { +Tensor* Workspace::CreateTensor(const string& name, + Allocator* alloc, + DataType type) { if (HasTensor(name)) { VLOG(1) << "Tensor " << name << " already exists. Skipping."; } else { diff --git a/mace/core/workspace.h b/mace/core/workspace.h index 3f16077f..93043744 100644 --- a/mace/core/workspace.h +++ b/mace/core/workspace.h @@ -14,7 +14,7 @@ namespace mace { class Workspace { public: - typedef std::map> TensorMap; + typedef map> TensorMap; Workspace() {} diff --git a/mace/mace.bzl b/mace/mace.bzl index 3b57e724..f9e7b6af 100644 --- a/mace/mace.bzl +++ b/mace/mace.bzl @@ -22,4 +22,4 @@ def if_android_arm64(a): return select({ "//mace:android_arm64": a, "//conditions:default": [], - }) + }) \ No newline at end of file diff --git a/mace/ops/BUILD b/mace/ops/BUILD index 1f06a345..9443f7a3 100644 --- a/mace/ops/BUILD +++ b/mace/ops/BUILD @@ -9,7 +9,7 @@ package( licenses(["notice"]) # Apache 2.0 cc_library( - name = "op", + name = "ops", srcs = ["relu.cc"], hdrs = glob(["*.h"]), deps = [ @@ -19,10 +19,11 @@ cc_library( ) cc_test( - name = "op_test", + name = "relu_test", srcs = ["relu_test.cc",], deps = [ "@gtest//:gtest", - ":op", + ":ops", ], -) \ No newline at end of file +) + diff --git a/mace/ops/relu.h b/mace/ops/relu.h index b6580506..8a0ea34d 100644 --- a/mace/ops/relu.h +++ b/mace/ops/relu.h @@ -2,8 +2,8 @@ // Copyright (c) 2017 XiaoMi All rights reserved. // -#ifndef MACE_OPERATORS_RELU_H_ -#define MACE_OPERATORS_RELU_H_ +#ifndef MACE_OPS_RELU_H_ +#define MACE_OPS_RELU_H_ #include "mace/core/operator.h" @@ -19,4 +19,4 @@ class ReluOp : public Operator { } // namespace mace -#endif // MACE_OPERATORS_RELU_H_ +#endif // MACE_OPS_RELU_H_ -- GitLab