提交 72ef73a5 编写于 作者: 李寅

Refactor ops and net

上级 4656b708
......@@ -6,20 +6,15 @@
#ifndef MACE_CORE_ALLOCATOR_H_
#define MACE_CORE_ALLOCATOR_H_
#include <unordered_map>
#include <functional>
#include <malloc.h>
#include <cstring>
#include "mace/core/common.h"
#include "mace/proto/mace.pb.h"
namespace mace {
// 16 bytes = 32 * 4 (Neon)
constexpr size_t kMaceAlignment = 16;
using MemoryDeleter = std::function<void(void* ptr)>;
class Allocator {
public:
Allocator() {}
......@@ -44,9 +39,9 @@ class CPUAllocator: public Allocator {
void* New(size_t nbytes) override {
void* data = nullptr;
#ifdef __ANDROID__
data = memalign(gMaceAlignment, nbytes);
data = memalign(kMaceAlignment, nbytes);
#elif defined(_MSC_VER)
data = _aligned_malloc(nbytes, gMaceAlignment);
data = _aligned_malloc(nbytes, kMaceAlignment);
#else
CHECK(posix_memalign(&data, kMaceAlignment, nbytes) == 0);
#endif
......@@ -72,7 +67,7 @@ CPUAllocator* cpu_allocator();
// ownership of the pointer.
void SetCPUAllocator(CPUAllocator* alloc);
template <DeviceType DT>
template <DeviceType D>
struct DeviceContext {};
template <>
......
......@@ -6,6 +6,7 @@
#define MACE_CORE_COMMON_H_
#include <set>
#include <map>
#include <string>
#include <memory>
#include <vector>
......@@ -15,6 +16,7 @@
#include "mace/core/logging.h"
using std::set;
using std::map;
using std::string;
using std::unique_ptr;
using std::vector;
......
......@@ -8,8 +8,8 @@ namespace mace {
NetBase::NetBase(const std::shared_ptr<const NetDef> &net_def,
Workspace *ws,
DeviceType type) {
DeviceType type)
: name_(net_def->name()) {
}
......
......@@ -14,7 +14,9 @@ namespace mace {
class NetBase {
public:
NetBase(const std::shared_ptr<const NetDef> &net_def, Workspace* ws, DeviceType type);
NetBase(const std::shared_ptr<const NetDef> &net_def,
Workspace* ws,
DeviceType type);
virtual ~NetBase() noexcept {}
virtual bool Run() = 0;
......@@ -31,9 +33,11 @@ class NetBase {
class SimpleNet : public NetBase {
public:
SimpleNet(const std::shared_ptr<const NetDef>& net_def, Workspace* ws, DeviceType type);
SimpleNet(const std::shared_ptr<const NetDef>& net_def,
Workspace* ws,
DeviceType type);
virtual bool Run() override;
bool Run() override;
protected:
vector<unique_ptr<OperatorBase> > operators_;
......@@ -41,7 +45,9 @@ class SimpleNet : public NetBase {
DISABLE_COPY_AND_ASSIGN(SimpleNet);
};
unique_ptr<NetBase> CreateNet(const NetDef& net_def, Workspace* ws, DeviceType type);
unique_ptr<NetBase> CreateNet(const NetDef& net_def,
Workspace* ws,
DeviceType type);
unique_ptr<NetBase> CreateNet(
const std::shared_ptr<const NetDef>& net_def,
Workspace* ws,
......
......@@ -6,8 +6,8 @@
namespace mace {
std::map<int32_t, OperatorRegistry*>* gDeviceTypeRegistry() {
static std::map<int32_t, OperatorRegistry*> g_device_type_registry;
std::map<int32, OperatorRegistry*>* gDeviceTypeRegistry() {
static std::map<int32, OperatorRegistry*> g_device_type_registry;
return &g_device_type_registry;
}
......
......@@ -58,10 +58,7 @@ class OperatorBase {
inline const vector<const Tensor *> &Inputs() const { return inputs_; }
inline const vector<Tensor *> &Outputs() { return outputs_; }
virtual bool Run() {
MACE_NOT_IMPLEMENTED;
return false;
}
virtual bool Run() = 0;
inline const OperatorDef &debug_def() const {
REQUIRE(has_debug_def(), "operator_def was null!");
......@@ -108,10 +105,7 @@ class Operator : public OperatorBase {
DataTypeToEnum<T>::v())));
}
}
virtual bool Run() {
MACE_NOT_IMPLEMENTED;
return false;
}
virtual bool Run() = 0;
~Operator() noexcept override {}
};
......
......@@ -5,11 +5,7 @@
#ifndef MACE_CORE_REGISTRY_H_
#define MACE_CORE_REGISTRY_H_
#include <memory>
#include <mutex>
#include <string>
#include <map>
#include "mace/core/common.h"
namespace mace {
......
......@@ -53,7 +53,7 @@ class Tensor {
size_(0), dtype_(DT_FLOAT), data_(nullptr) {};
Tensor(Allocator* a, DataType type)
: alloc_(a), size_(0), dtype_(DT_FLOAT), data_(nullptr) {};
: alloc_(a), size_(0), dtype_(type), data_(nullptr) {};
~Tensor() {
if (alloc_ && data_.get()) {
......@@ -65,10 +65,6 @@ class Tensor {
inline const vector<TIndex>& shape() const { return shape_; }
inline int64 NumElements() const {
return std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<int64>());
}
inline TIndex dim_size() { return shape_.size(); }
inline TIndex size() const { return size_; }
......@@ -86,10 +82,6 @@ class Tensor {
return static_cast<T*>(data_.get());
}
void Deleter(void* data) {
alloc_->Delete(data);
}
inline void* raw_mutable_data() {
if (data_.get() || size_ == 0) {
return data_.get();
......@@ -113,7 +105,7 @@ class Tensor {
shape_ = shape;
TIndex size = NumElements();
if (size_ != size) {
size_ = NumElements();
size_ = size;
data_.reset();
}
}
......@@ -127,6 +119,10 @@ class Tensor {
}
private:
inline int64 NumElements() const {
return std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<int64>());
}
Allocator* alloc_;
TIndex size_;
DataType dtype_;
......
......@@ -16,7 +16,7 @@ struct IsValidDataType;
template <class T>
struct DataTypeToEnum {
static_assert(IsValidDataType<T>::value, "Specified Data Type not supported");
}; // Specializations below
};
// EnumToDataType<VALUE>::Type is the type for DataType constant VALUE, e.g.
......
......@@ -15,7 +15,9 @@ vector<string> Workspace::Tensors() const {
return names;
}
Tensor* Workspace::CreateTensor(const string& name, Allocator* alloc, DataType type) {
Tensor* Workspace::CreateTensor(const string& name,
Allocator* alloc,
DataType type) {
if (HasTensor(name)) {
VLOG(1) << "Tensor " << name << " already exists. Skipping.";
} else {
......
......@@ -14,7 +14,7 @@ namespace mace {
class Workspace {
public:
typedef std::map<string, unique_ptr<Tensor>> TensorMap;
typedef map<string, unique_ptr<Tensor>> TensorMap;
Workspace() {}
......
......@@ -22,4 +22,4 @@ def if_android_arm64(a):
return select({
"//mace:android_arm64": a,
"//conditions:default": [],
})
})
\ No newline at end of file
......@@ -9,7 +9,7 @@ package(
licenses(["notice"]) # Apache 2.0
cc_library(
name = "op",
name = "ops",
srcs = ["relu.cc"],
hdrs = glob(["*.h"]),
deps = [
......@@ -19,10 +19,11 @@ cc_library(
)
cc_test(
name = "op_test",
name = "relu_test",
srcs = ["relu_test.cc",],
deps = [
"@gtest//:gtest",
":op",
":ops",
],
)
\ No newline at end of file
)
......@@ -2,8 +2,8 @@
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_OPERATORS_RELU_H_
#define MACE_OPERATORS_RELU_H_
#ifndef MACE_OPS_RELU_H_
#define MACE_OPS_RELU_H_
#include "mace/core/operator.h"
......@@ -19,4 +19,4 @@ class ReluOp : public Operator<D, T> {
} // namespace mace
#endif // MACE_OPERATORS_RELU_H_
#endif // MACE_OPS_RELU_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册