提交 3e506ffc 编写于 作者: L Liangliang He

Merge branch 'master' into 'master'

Refactor ops and net

See merge request !2
...@@ -6,20 +6,15 @@ ...@@ -6,20 +6,15 @@
#ifndef MACE_CORE_ALLOCATOR_H_ #ifndef MACE_CORE_ALLOCATOR_H_
#define MACE_CORE_ALLOCATOR_H_ #define MACE_CORE_ALLOCATOR_H_
#include <unordered_map>
#include <functional>
#include <malloc.h> #include <malloc.h>
#include <cstring>
#include "mace/core/common.h" #include "mace/core/common.h"
#include "mace/proto/mace.pb.h" #include "mace/proto/mace.pb.h"
namespace mace { namespace mace {
// 16 bytes = 32 * 4 (Neon)
constexpr size_t kMaceAlignment = 16; constexpr size_t kMaceAlignment = 16;
using MemoryDeleter = std::function<void(void* ptr)>;
class Allocator { class Allocator {
public: public:
Allocator() {} Allocator() {}
...@@ -44,9 +39,9 @@ class CPUAllocator: public Allocator { ...@@ -44,9 +39,9 @@ class CPUAllocator: public Allocator {
void* New(size_t nbytes) override { void* New(size_t nbytes) override {
void* data = nullptr; void* data = nullptr;
#ifdef __ANDROID__ #ifdef __ANDROID__
data = memalign(gMaceAlignment, nbytes); data = memalign(kMaceAlignment, nbytes);
#elif defined(_MSC_VER) #elif defined(_MSC_VER)
data = _aligned_malloc(nbytes, gMaceAlignment); data = _aligned_malloc(nbytes, kMaceAlignment);
#else #else
CHECK(posix_memalign(&data, kMaceAlignment, nbytes) == 0); CHECK(posix_memalign(&data, kMaceAlignment, nbytes) == 0);
#endif #endif
...@@ -72,7 +67,7 @@ CPUAllocator* cpu_allocator(); ...@@ -72,7 +67,7 @@ CPUAllocator* cpu_allocator();
// ownership of the pointer. // ownership of the pointer.
void SetCPUAllocator(CPUAllocator* alloc); void SetCPUAllocator(CPUAllocator* alloc);
template <DeviceType DT> template <DeviceType D>
struct DeviceContext {}; struct DeviceContext {};
template <> template <>
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#define MACE_CORE_COMMON_H_ #define MACE_CORE_COMMON_H_
#include <set> #include <set>
#include <map>
#include <string> #include <string>
#include <memory> #include <memory>
#include <vector> #include <vector>
...@@ -15,6 +16,7 @@ ...@@ -15,6 +16,7 @@
#include "mace/core/logging.h" #include "mace/core/logging.h"
using std::set; using std::set;
using std::map;
using std::string; using std::string;
using std::unique_ptr; using std::unique_ptr;
using std::vector; using std::vector;
......
...@@ -8,8 +8,8 @@ namespace mace { ...@@ -8,8 +8,8 @@ namespace mace {
NetBase::NetBase(const std::shared_ptr<const NetDef> &net_def, NetBase::NetBase(const std::shared_ptr<const NetDef> &net_def,
Workspace *ws, Workspace *ws,
DeviceType type) { DeviceType type)
: name_(net_def->name()) {
} }
......
...@@ -14,7 +14,9 @@ namespace mace { ...@@ -14,7 +14,9 @@ namespace mace {
class NetBase { class NetBase {
public: public:
NetBase(const std::shared_ptr<const NetDef> &net_def, Workspace* ws, DeviceType type); NetBase(const std::shared_ptr<const NetDef> &net_def,
Workspace* ws,
DeviceType type);
virtual ~NetBase() noexcept {} virtual ~NetBase() noexcept {}
virtual bool Run() = 0; virtual bool Run() = 0;
...@@ -31,9 +33,11 @@ class NetBase { ...@@ -31,9 +33,11 @@ class NetBase {
class SimpleNet : public NetBase { class SimpleNet : public NetBase {
public: public:
SimpleNet(const std::shared_ptr<const NetDef>& net_def, Workspace* ws, DeviceType type); SimpleNet(const std::shared_ptr<const NetDef>& net_def,
Workspace* ws,
DeviceType type);
virtual bool Run() override; bool Run() override;
protected: protected:
vector<unique_ptr<OperatorBase> > operators_; vector<unique_ptr<OperatorBase> > operators_;
...@@ -41,7 +45,9 @@ class SimpleNet : public NetBase { ...@@ -41,7 +45,9 @@ class SimpleNet : public NetBase {
DISABLE_COPY_AND_ASSIGN(SimpleNet); DISABLE_COPY_AND_ASSIGN(SimpleNet);
}; };
unique_ptr<NetBase> CreateNet(const NetDef& net_def, Workspace* ws, DeviceType type); unique_ptr<NetBase> CreateNet(const NetDef& net_def,
Workspace* ws,
DeviceType type);
unique_ptr<NetBase> CreateNet( unique_ptr<NetBase> CreateNet(
const std::shared_ptr<const NetDef>& net_def, const std::shared_ptr<const NetDef>& net_def,
Workspace* ws, Workspace* ws,
......
...@@ -6,8 +6,8 @@ ...@@ -6,8 +6,8 @@
namespace mace { namespace mace {
std::map<int32_t, OperatorRegistry*>* gDeviceTypeRegistry() { std::map<int32, OperatorRegistry*>* gDeviceTypeRegistry() {
static std::map<int32_t, OperatorRegistry*> g_device_type_registry; static std::map<int32, OperatorRegistry*> g_device_type_registry;
return &g_device_type_registry; return &g_device_type_registry;
} }
......
...@@ -58,10 +58,7 @@ class OperatorBase { ...@@ -58,10 +58,7 @@ class OperatorBase {
inline const vector<const Tensor *> &Inputs() const { return inputs_; } inline const vector<const Tensor *> &Inputs() const { return inputs_; }
inline const vector<Tensor *> &Outputs() { return outputs_; } inline const vector<Tensor *> &Outputs() { return outputs_; }
virtual bool Run() { virtual bool Run() = 0;
MACE_NOT_IMPLEMENTED;
return false;
}
inline const OperatorDef &debug_def() const { inline const OperatorDef &debug_def() const {
REQUIRE(has_debug_def(), "operator_def was null!"); REQUIRE(has_debug_def(), "operator_def was null!");
...@@ -108,10 +105,7 @@ class Operator : public OperatorBase { ...@@ -108,10 +105,7 @@ class Operator : public OperatorBase {
DataTypeToEnum<T>::v()))); DataTypeToEnum<T>::v())));
} }
} }
virtual bool Run() { virtual bool Run() = 0;
MACE_NOT_IMPLEMENTED;
return false;
}
~Operator() noexcept override {} ~Operator() noexcept override {}
}; };
......
...@@ -5,11 +5,7 @@ ...@@ -5,11 +5,7 @@
#ifndef MACE_CORE_REGISTRY_H_ #ifndef MACE_CORE_REGISTRY_H_
#define MACE_CORE_REGISTRY_H_ #define MACE_CORE_REGISTRY_H_
#include <memory>
#include <mutex> #include <mutex>
#include <string>
#include <map>
#include "mace/core/common.h"
namespace mace { namespace mace {
......
...@@ -53,7 +53,7 @@ class Tensor { ...@@ -53,7 +53,7 @@ class Tensor {
size_(0), dtype_(DT_FLOAT), data_(nullptr) {}; size_(0), dtype_(DT_FLOAT), data_(nullptr) {};
Tensor(Allocator* a, DataType type) Tensor(Allocator* a, DataType type)
: alloc_(a), size_(0), dtype_(DT_FLOAT), data_(nullptr) {}; : alloc_(a), size_(0), dtype_(type), data_(nullptr) {};
~Tensor() { ~Tensor() {
if (alloc_ && data_.get()) { if (alloc_ && data_.get()) {
...@@ -65,10 +65,6 @@ class Tensor { ...@@ -65,10 +65,6 @@ class Tensor {
inline const vector<TIndex>& shape() const { return shape_; } inline const vector<TIndex>& shape() const { return shape_; }
inline int64 NumElements() const {
return std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<int64>());
}
inline TIndex dim_size() { return shape_.size(); } inline TIndex dim_size() { return shape_.size(); }
inline TIndex size() const { return size_; } inline TIndex size() const { return size_; }
...@@ -86,10 +82,6 @@ class Tensor { ...@@ -86,10 +82,6 @@ class Tensor {
return static_cast<T*>(data_.get()); return static_cast<T*>(data_.get());
} }
void Deleter(void* data) {
alloc_->Delete(data);
}
inline void* raw_mutable_data() { inline void* raw_mutable_data() {
if (data_.get() || size_ == 0) { if (data_.get() || size_ == 0) {
return data_.get(); return data_.get();
...@@ -113,7 +105,7 @@ class Tensor { ...@@ -113,7 +105,7 @@ class Tensor {
shape_ = shape; shape_ = shape;
TIndex size = NumElements(); TIndex size = NumElements();
if (size_ != size) { if (size_ != size) {
size_ = NumElements(); size_ = size;
data_.reset(); data_.reset();
} }
} }
...@@ -127,6 +119,10 @@ class Tensor { ...@@ -127,6 +119,10 @@ class Tensor {
} }
private: private:
inline int64 NumElements() const {
return std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<int64>());
}
Allocator* alloc_; Allocator* alloc_;
TIndex size_; TIndex size_;
DataType dtype_; DataType dtype_;
......
...@@ -16,7 +16,7 @@ struct IsValidDataType; ...@@ -16,7 +16,7 @@ struct IsValidDataType;
template <class T> template <class T>
struct DataTypeToEnum { struct DataTypeToEnum {
static_assert(IsValidDataType<T>::value, "Specified Data Type not supported"); static_assert(IsValidDataType<T>::value, "Specified Data Type not supported");
}; // Specializations below };
// EnumToDataType<VALUE>::Type is the type for DataType constant VALUE, e.g. // EnumToDataType<VALUE>::Type is the type for DataType constant VALUE, e.g.
......
...@@ -15,7 +15,9 @@ vector<string> Workspace::Tensors() const { ...@@ -15,7 +15,9 @@ vector<string> Workspace::Tensors() const {
return names; return names;
} }
Tensor* Workspace::CreateTensor(const string& name, Allocator* alloc, DataType type) { Tensor* Workspace::CreateTensor(const string& name,
Allocator* alloc,
DataType type) {
if (HasTensor(name)) { if (HasTensor(name)) {
VLOG(1) << "Tensor " << name << " already exists. Skipping."; VLOG(1) << "Tensor " << name << " already exists. Skipping.";
} else { } else {
......
...@@ -14,7 +14,7 @@ namespace mace { ...@@ -14,7 +14,7 @@ namespace mace {
class Workspace { class Workspace {
public: public:
typedef std::map<string, unique_ptr<Tensor>> TensorMap; typedef map<string, unique_ptr<Tensor>> TensorMap;
Workspace() {} Workspace() {}
......
...@@ -22,4 +22,4 @@ def if_android_arm64(a): ...@@ -22,4 +22,4 @@ def if_android_arm64(a):
return select({ return select({
"//mace:android_arm64": a, "//mace:android_arm64": a,
"//conditions:default": [], "//conditions:default": [],
}) })
\ No newline at end of file
...@@ -9,7 +9,7 @@ package( ...@@ -9,7 +9,7 @@ package(
licenses(["notice"]) # Apache 2.0 licenses(["notice"]) # Apache 2.0
cc_library( cc_library(
name = "op", name = "ops",
srcs = ["relu.cc"], srcs = ["relu.cc"],
hdrs = glob(["*.h"]), hdrs = glob(["*.h"]),
deps = [ deps = [
...@@ -19,10 +19,11 @@ cc_library( ...@@ -19,10 +19,11 @@ cc_library(
) )
cc_test( cc_test(
name = "op_test", name = "relu_test",
srcs = ["relu_test.cc",], srcs = ["relu_test.cc",],
deps = [ deps = [
"@gtest//:gtest", "@gtest//:gtest",
":op", ":ops",
], ],
) )
\ No newline at end of file
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
// Copyright (c) 2017 XiaoMi All rights reserved. // Copyright (c) 2017 XiaoMi All rights reserved.
// //
#ifndef MACE_OPERATORS_RELU_H_ #ifndef MACE_OPS_RELU_H_
#define MACE_OPERATORS_RELU_H_ #define MACE_OPS_RELU_H_
#include "mace/core/operator.h" #include "mace/core/operator.h"
...@@ -19,4 +19,4 @@ class ReluOp : public Operator<D, T> { ...@@ -19,4 +19,4 @@ class ReluOp : public Operator<D, T> {
} // namespace mace } // namespace mace
#endif // MACE_OPERATORS_RELU_H_ #endif // MACE_OPS_RELU_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册