提交 37e5bd43 编写于 作者: 李寅

Add Ops, Net, and Dummy Relu

上级 c418edb2
workspace(name = "mace")
# proto_library rules implicitly depend on @com_google_protobuf//:protoc,
# which is the proto-compiler.
# This statement defines the @com_google_protobuf repo.
http_archive(
name = "com_google_protobuf",
urls = ["https://github.com/google/protobuf/archive/b4b0e304be5a68de3d0ee1af9b286f958750f5e4.zip"],
strip_prefix = "protobuf-b4b0e304be5a68de3d0ee1af9b286f958750f5e4",
sha256 = "ff771a662fb6bd4d3cc209bcccedef3e93980a49f71df1e987f6afa3bcdcba3a",
)
# cc_proto_library rules implicitly depend on @com_google_protobuf_cc//:cc_toolchain,
# which is the C++ proto runtime (base classes and common utilities).
http_archive(
name = "com_google_protobuf_cc",
urls = ["https://github.com/google/protobuf/archive/b4b0e304be5a68de3d0ee1af9b286f958750f5e4.zip"],
strip_prefix = "protobuf-b4b0e304be5a68de3d0ee1af9b286f958750f5e4",
sha256 = "ff771a662fb6bd4d3cc209bcccedef3e93980a49f71df1e987f6afa3bcdcba3a",
)
new_http_archive(
name = "gtest",
url = "https://github.com/google/googletest/archive/release-1.8.0.zip",
sha256 = "f3ed3b58511efd272eb074a3a6d6fb79d7c2e6a0e374323d1e6bcbcc1ef141bf",
build_file = "mace/third_party/gtest.BUILD",
strip_prefix = "googletest-release-1.8.0/googletest",
)
# Set up Android NDK
android_ndk_repository(
name = "androidndk",
......
package(default_visibility = ["//visibility:public"])
# Description:
# Mace core.
#
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"]) # Apache 2.0
cc_library(
name = "lib_core",
hdrs = [
"logging.h"
],
srcs = [
],
)
name = "core",
srcs = glob(["*.cc"]),
hdrs = glob(["*.h"]),
deps = [
"//mace/proto:cc_proto",
],
)
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/allocator.h"
namespace mace {
static std::unique_ptr<CPUAllocator> g_cpu_allocator(new CPUAllocator());
CPUAllocator* cpu_allocator() {
return g_cpu_allocator.get();
}
void SetCPUAllocator(CPUAllocator* alloc) {
g_cpu_allocator.reset(alloc);
}
} // namespace mace
//
// Created by liyin on 8/28/17.
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_CORE_ALLOCATOR_H_
#define MACE_CORE_ALLOCATOR_H_
#include <unordered_map>
#include <functional>
#include <malloc.h>
#include <cstring>
#include "mace/core/common.h"
#include "mace/proto/mace.pb.h"
namespace mace {
constexpr size_t kMaceAlignment = 16;
using MemoryDeleter = std::function<void(void* ptr)>;
class Allocator {
public:
Allocator() {}
virtual ~Allocator() noexcept {}
virtual void* New(size_t nbytes) = 0;
virtual void Delete(void* data) = 0;
template <typename T>
T* New(size_t num_elements) {
if (num_elements > (std::numeric_limits<size_t>::max() / sizeof(T))) {
return NULL;
}
void* p = New(sizeof(T) * num_elements);
T* typed_p = reinterpret_cast<T*>(p);
return typed_p;
}
};
class CPUAllocator: public Allocator {
public:
~CPUAllocator() override {}
void* New(size_t nbytes) override {
void* data = nullptr;
#ifdef __ANDROID__
data = memalign(gMaceAlignment, nbytes);
#elif defined(_MSC_VER)
data = _aligned_malloc(nbytes, gMaceAlignment);
#else
CHECK(posix_memalign(&data, kMaceAlignment, nbytes) == 0);
#endif
CHECK_NOTNULL(data);
memset(data, 0, nbytes);
return data;
}
#ifdef _MSC_VER
void Delete(void* data) {
_aligned_free(data);
}
#else
void Delete(void* data) {
free(data);
}
#endif
};
// Get the CPU Alloctor.
CPUAllocator* cpu_allocator();
// Sets the CPU allocator to the given allocator: the caller gives away the
// ownership of the pointer.
void SetCPUAllocator(CPUAllocator* alloc);
template <DeviceType DT>
struct DeviceContext {};
template <>
struct DeviceContext<DeviceType::CPU> {
static Allocator* alloctor() { return cpu_allocator(); }
};
} // namespace mace
#endif // MACE_CORE_ALLOCATOR_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_CORE_COMMON_H_
#define MACE_CORE_COMMON_H_
#include <set>
#include <string>
#include <memory>
#include <vector>
#include <algorithm>
#include "mace/core/integral_types.h"
#include "mace/core/logging.h"
using std::set;
using std::string;
using std::unique_ptr;
using std::vector;
typedef int64 TIndex;
// Disable the copy and assignment operator for a class.
#ifndef DISABLE_COPY_AND_ASSIGN
#define DISABLE_COPY_AND_ASSIGN(classname) \
private: \
classname(const classname&) = delete; \
classname& operator=(const classname&) = delete
#endif
#define MACE_NOT_IMPLEMENTED REQUIRE(false, "not implemented")
#endif // MACE_CORE_COMMON_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_CORE_INTEGRAL_TYPES_H_
#define MACE_CORE_INTEGRAL_TYPES_H_
typedef signed char int8;
typedef short int16;
typedef int int32;
typedef long long int64;
typedef unsigned char uint8;
typedef unsigned short uint16;
typedef unsigned int uint32;
typedef unsigned long long uint64;
#endif // MACE_CORE_INTEGRAL_TYPES_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/logging.h"
#include <stdlib.h>
#if defined(PLATFORM_POSIX_ANDROID)
#include <android/log.h>
#include <iostream>
#include <sstream>
#endif
namespace mace {
namespace internal {
LogMessage::LogMessage(const char* fname, int line, int severity)
: fname_(fname), line_(line), severity_(severity) {}
#if defined(PLATFORM_POSIX_ANDROID)
void LogMessage::GenerateLogMessage() {
int android_log_level;
switch (severity_) {
case INFO:
android_log_level = ANDROID_LOG_INFO;
break;
case WARNING:
android_log_level = ANDROID_LOG_WARN;
break;
case ERROR:
android_log_level = ANDROID_LOG_ERROR;
break;
case FATAL:
android_log_level = ANDROID_LOG_FATAL;
break;
default:
if (severity_ < INFO) {
android_log_level = ANDROID_LOG_VERBOSE;
} else {
android_log_level = ANDROID_LOG_ERROR;
}
break;
}
std::stringstream ss;
const char* const partial_name = strrchr(fname_, '/');
ss << (partial_name != nullptr ? partial_name + 1 : fname_) << ":" << line_
<< " " << str();
__android_log_write(android_log_level, "native", ss.str().c_str());
// Also log to stderr (for standalone Android apps).
std::cerr << "native : " << ss.str() << std::endl;
// Android logging at level FATAL does not terminate execution, so abort()
// is still required to stop the program.
if (severity_ == FATAL) {
abort();
}
}
#else
void LogMessage::GenerateLogMessage() {
fprintf(stderr, "%c %s:%d] %s\n", "IWEF"[severity_], fname_, line_, str().c_str());
}
#endif
namespace {
// Parse log level (int64) from environment variable (char*)
int64 LogLevelStrToInt(const char* tf_env_var_val) {
if (tf_env_var_val == nullptr) {
return 0;
}
// Ideally we would use env_var / safe_strto64, but it is
// hard to use here without pulling in a lot of dependencies,
// so we use std:istringstream instead
string min_log_level(tf_env_var_val);
std::istringstream ss(min_log_level);
int64 level;
if (!(ss >> level)) {
// Invalid vlog level setting, set level to default (0)
level = 0;
}
return level;
}
int64 MinLogLevelFromEnv() {
const char* tf_env_var_val = getenv("MACE_CPP_MIN_LOG_LEVEL");
return LogLevelStrToInt(tf_env_var_val);
}
int64 MinVLogLevelFromEnv() {
const char* tf_env_var_val = getenv("MACE_CPP_MIN_VLOG_LEVEL");
return LogLevelStrToInt(tf_env_var_val);
}
} // namespace
LogMessage::~LogMessage() {
// Read the min log level once during the first call to logging.
static int64 min_log_level = MinLogLevelFromEnv();
if (severity_ >= min_log_level) GenerateLogMessage();
}
int64 LogMessage::MinVLogLevel() {
static int64 min_vlog_level = MinVLogLevelFromEnv();
return min_vlog_level;
}
LogMessageFatal::LogMessageFatal(const char* file, int line)
: LogMessage(file, line, FATAL) {}
LogMessageFatal::~LogMessageFatal() {
// abort() ensures we don't return (we promised we would not via
// ATTRIBUTE_NORETURN).
GenerateLogMessage();
abort();
}
} // namespace internal
} // namespace mace
#ifndef MACE_COMMON_LOGGING_H_
#define MACE_COMMON_LOGGING_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifdef __ANDROID__
#include <android/log.h>
#else
#include <cstdio>
#endif
#ifndef MACE_CORE_LOGGING_H_
#define MACE_CORE_LOGGING_H_
namespace mace {
#include <sstream>
#include <limits>
#include <string>
#include "mace/core/integral_types.h"
const int FATAL = 0;
const int ERROR = 1;
const int WARN = 2;
const int INFO = 3;
const int DEBUG = 4;
const int VERBOSE = 5;
#undef ERROR
namespace mace {
const int INFO = 0; // base_logging::INFO;
const int WARNING = 1; // base_logging::WARNING;
const int ERROR = 2; // base_logging::ERROR;
const int FATAL = 3; // base_logging::FATAL;
const int NUM_SEVERITIES = 4; // base_logging::NUM_SEVERITIES;
namespace internal {
const char *kTag = "MACE";
using std::string;
inline void MakeStringInternal(std::stringstream& /*ss*/) {}
template <typename T>
inline void MakeStringInternal(std::stringstream& ss, const T& t) {
ss << t;
}
template <typename T, typename... Args>
inline void
MakeStringInternal(std::stringstream& ss, const T& t, const Args&... args) {
MakeStringInternal(ss, t);
MakeStringInternal(ss, args...);
}
template <typename... Args>
string MakeString(const Args&... args) {
std::stringstream ss;
MakeStringInternal(ss, args...);
return string(ss.str());
}
// Specializations for already-a-string types.
template <>
inline string MakeString(const string& str) {
return str;
}
inline string MakeString(const char* c_str) {
return string(c_str);
}
class LogMessage : public std::basic_ostringstream<char> {
public:
LogMessage(const char* fname, int line, int severity);
~LogMessage();
// Returns the minimum log level for VLOG statements.
// E.g., if MinVLogLevel() is 2, then VLOG(2) statements will produce output,
// but VLOG(3) will not. Defaults to 0.
static int64 MinVLogLevel();
protected:
void GenerateLogMessage();
#ifdef __ANDROID__
private:
const char* fname_;
int line_;
int severity_;
};
// LogMessageFatal ensures the process will exit in failure after
// logging this message.
class LogMessageFatal : public LogMessage {
public:
LogMessageFatal(const char* file, int line);
~LogMessageFatal();
};
#define _MACE_LOG_INFO \
::mace::internal::LogMessage(__FILE__, __LINE__, mace::INFO)
#define _MACE_LOG_WARNING \
::mace::internal::LogMessage(__FILE__, __LINE__, mace::WARNING)
#define _MACE_LOG_ERROR \
::mace::internal::LogMessage(__FILE__, __LINE__, mace::ERROR)
#define _MACE_LOG_FATAL \
do { \
__android_log_print(ANDROID_LOG_FATAL, mace::internal::kTag, __VA_ARGS__); \
abort(); \
} while (0)
#define _MACE_LOG_ERROR(...) \
__android_log_print(ANDROID_LOG_ERROR, mace::internal::kTag, __VA_ARGS__)
#define _MACE_LOG_WARN(...) \
__android_log_print(ANDROID_LOG_WARN, mace::internal::kTag, __VA_ARGS__)
#define _MACE_LOG_INFO(...) \
__android_log_print(ANDROID_LOG_INFO, mace::internal::kTag, __VA_ARGS__)
#define _MACE_LOG_DEBUG(...) \
__android_log_print(ANDROID_LOG_DEBUG, mace::internal::kTag, __VA_ARGS__)
#define _MACE_LOG_VERBOSE(...) \
__android_log_print(ANDROID_LOG_VERBOSE, mace::internal::kTag, __VA_ARGS__)
#define LOG(severity, ...) _MACE_LOG_##severity(__VA_ARGS__)
#else // Non Android, just for tests
// TODO(heliangliang): Fix newline
#define LOG(severity, ...) \
do { \
printf(__VA_ARGS__); \
printf("\n"); \
} while (0)
#endif // __ANDROID__
} // namespace internal
} // namespace mace
#endif // MACE_COMMON_LOGGING_H_
::mace::internal::LogMessageFatal(__FILE__, __LINE__)
#define _MACE_LOG_QFATAL _MACE_LOG_FATAL
#define LOG(severity) _MACE_LOG_##severity
#ifdef IS_MOBILE_PLAMACEORM
// Turn VLOG off when under mobile devices for considerations of binary size.
#define VLOG_IS_ON(lvl) ((lvl) <= 0)
#else
// Otherwise, Set MACE_CPP_MIN_VLOG_LEVEL environment to update minimum log level
// of VLOG
#define VLOG_IS_ON(lvl) \
((lvl) <= ::mace::internal::LogMessage::MinVLogLevel())
#endif
#define VLOG(lvl) \
if (VLOG_IS_ON(lvl)) \
::mace::internal::LogMessage(__FILE__, __LINE__, mace::INFO)
// CHECK dies with a fatal error if condition is not true. It is *not*
// controlled by NDEBUG, so the check will be executed regardless of
// compilation mode. Therefore, it is safe to do things like:
// CHECK(fp->Write(x) == 4)
#define CHECK(condition) \
if (!(condition)) \
LOG(FATAL) << "Check failed: " #condition " "
#define REQUIRE(condition, ...) \
if (!(condition)) \
LOG(FATAL) << "Check failed: " #condition " " << ::mace::internal::MakeString(__VA_ARGS__)
template <typename T>
T&& CheckNotNull(const char* file, int line, const char* exprtext, T&& t) {
if (t == nullptr) {
LogMessageFatal(file, line) << string(exprtext);
}
return std::forward<T>(t);
}
#define CHECK_NOTNULL(val) \
::mace::internal::CheckNotNull(__FILE__, __LINE__, \
"'" #val "' Must be non NULL", (val))
} // namespace internal
} // namespace mace
#endif // MACE_CORE_LOGGING_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/net.h"
namespace mace {
NetBase::NetBase(const std::shared_ptr<const NetDef> &net_def,
Workspace *ws,
DeviceType type) {
}
SimpleNet::SimpleNet(const std::shared_ptr<const NetDef> &net_def,
Workspace *ws,
DeviceType type) : NetBase(net_def, ws, type) {
VLOG(1) << "Constructing SimpleNet " << net_def->name();
for (int idx = 0; idx < net_def->op_size(); ++idx) {
const auto& operator_def = net_def->op(idx);
VLOG(1) << "Creating operator " << operator_def.name() << ":"
<< operator_def.type();
std::unique_ptr<OperatorBase> op {nullptr};
OperatorDef temp_def(operator_def);
op = CreateOperator(temp_def, ws, type);
operators_.emplace_back(std::move(op));
}
}
bool SimpleNet::Run() {
VLOG(1) << "Running net " << name_;
for (auto& op : operators_) {
VLOG(1) << "Running operator " << op->debug_def().name() << "("
<< op->debug_def().type() << ").";
if (!op->Run()) {
LOG(ERROR) << "Operator failed: " << ProtoDebugString(op->debug_def());
return false;
}
}
}
unique_ptr<NetBase> CreateNet(const NetDef& net_def,
Workspace* ws,
DeviceType type) {
std::shared_ptr<NetDef> tmp_net_def(new NetDef(net_def));
return CreateNet(tmp_net_def, ws, type);
}
unique_ptr<NetBase> CreateNet(
const std::shared_ptr<const NetDef>& net_def,
Workspace* ws,
DeviceType type) {
unique_ptr<NetBase> net(new SimpleNet(net_def, ws, type));
return net;
}
} // namespace mace
\ No newline at end of file
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_CORE_NET_H_
#define MACE_CORE_NET_H_
#include "mace/core/common.h"
#include "mace/proto/mace.pb.h"
#include "mace/core/operator.h"
#include "mace/core/workspace.h"
namespace mace {
class NetBase {
public:
NetBase(const std::shared_ptr<const NetDef> &net_def, Workspace* ws, DeviceType type);
virtual ~NetBase() noexcept {}
virtual bool Run() = 0;
const string &Name() const {
return name_;
}
protected:
string name_;
DISABLE_COPY_AND_ASSIGN(NetBase);
};
class SimpleNet : public NetBase {
public:
SimpleNet(const std::shared_ptr<const NetDef>& net_def, Workspace* ws, DeviceType type);
virtual bool Run() override;
protected:
vector<unique_ptr<OperatorBase> > operators_;
DISABLE_COPY_AND_ASSIGN(SimpleNet);
};
unique_ptr<NetBase> CreateNet(const NetDef& net_def, Workspace* ws, DeviceType type);
unique_ptr<NetBase> CreateNet(
const std::shared_ptr<const NetDef>& net_def,
Workspace* ws,
DeviceType type);
} // namespace mace
#endif // MACE_CORE_NET_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/operator.h"
namespace mace {
std::map<int32_t, OperatorRegistry*>* gDeviceTypeRegistry() {
static std::map<int32_t, OperatorRegistry*> g_device_type_registry;
return &g_device_type_registry;
}
MACE_DEFINE_REGISTRY(
CPUOperatorRegistry,
OperatorBase,
const OperatorDef&,
Workspace*);
MACE_REGISTER_DEVICE_TYPE(DeviceType::CPU, CPUOperatorRegistry);
unique_ptr<OperatorBase> CreateOperator(
const OperatorDef& operator_def,
Workspace* ws,
DeviceType type) {
OperatorRegistry* registry = gDeviceTypeRegistry()->at(type);
return registry->Create(operator_def.type(), operator_def, ws);
}
OperatorBase::OperatorBase(const OperatorDef &operator_def, Workspace *ws)
: operator_ws_(ws),
operator_def_(std::make_shared<OperatorDef>(operator_def)) {
}
} // namespace mace
\ No newline at end of file
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_CORE_OPERATOR_H
#define MACE_CORE_OPERATOR_H
#include "mace/core/proto_utils.h"
#include "mace/core/common.h"
#include "mace/proto/mace.pb.h"
#include "mace/core/tensor.h"
#include "mace/core/registry.h"
#include "mace/core/workspace.h"
namespace mace {
class OperatorBase {
public:
explicit OperatorBase(const OperatorDef &operator_def, Workspace *ws);
virtual ~OperatorBase() noexcept {}
inline bool HasArgument(const string &name) const {
REQUIRE(operator_def_, "operator_def was null!");
return ArgumentHelper::HasArgument(*operator_def_, name);
}
template<typename T>
inline T GetSingleArgument(const string &name, const T &default_value) const {
REQUIRE(operator_def_, "operator_def was null!");
return ArgumentHelper::GetSingleArgument<OperatorDef, T>(
*operator_def_, name, default_value);
}
template<typename T>
inline bool HasSingleArgumentOfType(const string &name) const {
REQUIRE(operator_def_, "operator_def was null!");
return ArgumentHelper::HasSingleArgumentOfType<OperatorDef, T>(
*operator_def_, name);
}
template<typename T>
inline vector<T> GetRepeatedArgument(
const string &name,
const vector<T> &default_value = {}) const {
REQUIRE(operator_def_, "operator_def was null!");
return ArgumentHelper::GetRepeatedArgument<OperatorDef, T>(
*operator_def_, name, default_value);
}
inline const Tensor *Input(int idx) {
CHECK(idx < inputs_.size());
return inputs_[idx];
}
inline Tensor *Output(int idx) {
return outputs_[idx];
}
inline int InputSize() { return inputs_.size(); }
inline int OutputSize() { return outputs_.size(); }
inline const vector<const Tensor *> &Inputs() const { return inputs_; }
inline const vector<Tensor *> &Outputs() { return outputs_; }
virtual bool Run() {
MACE_NOT_IMPLEMENTED;
return false;
}
inline const OperatorDef &debug_def() const {
REQUIRE(has_debug_def(), "operator_def was null!");
return *operator_def_;
}
inline void set_debug_def(
const std::shared_ptr<const OperatorDef> &operator_def) {
operator_def_ = operator_def;
}
inline bool has_debug_def() const {
return operator_def_ != nullptr;
}
protected:
Workspace *operator_ws_;
std::shared_ptr<const OperatorDef> operator_def_;
vector<const Tensor *> inputs_;
vector<Tensor *> outputs_;
DISABLE_COPY_AND_ASSIGN(OperatorBase);
};
template <DeviceType D, class T>
class Operator : public OperatorBase {
public:
explicit Operator(const OperatorDef &operator_def, Workspace *ws)
: OperatorBase(operator_def, ws) {
for (const string &input_str : operator_def.input()) {
const Tensor *tensor = ws->GetTensor(input_str);
REQUIRE(
tensor != nullptr,
"op ",
operator_def.type(),
": Encountered a non-existing input tensor: ",
input_str);
inputs_.push_back(tensor);
}
for (const string &output_str : operator_def.output()) {
outputs_.push_back(CHECK_NOTNULL(ws->CreateTensor(output_str,
DeviceContext<D>::alloctor(),
DataTypeToEnum<T>::v())));
}
}
virtual bool Run() {
MACE_NOT_IMPLEMENTED;
return false;
}
~Operator() noexcept override {}
};
typedef Registry<std::string, OperatorBase, const OperatorDef &, Workspace *>
OperatorRegistry;
typedef Registry<std::string, OperatorBase, const OperatorDef &, Workspace *> *(
*RegistryFunction)();
std::map<int32_t, OperatorRegistry *> *gDeviceTypeRegistry();
struct DeviceTypeRegisterer {
explicit DeviceTypeRegisterer(int32_t type, RegistryFunction func) {
if (gDeviceTypeRegistry()->count(type)) {
LOG(ERROR) << "Device type " << type
<< "registered twice. This should not happen. Did you have "
"duplicated numbers assigned to different devices?";
std::exit(1);
}
// Calling the registry function to get the actual registry pointer.
gDeviceTypeRegistry()->emplace(type, func());
}
};
#define MACE_REGISTER_DEVICE_TYPE(type, registry_function) \
namespace { \
static DeviceTypeRegisterer MACE_ANONYMOUS_VARIABLE( \
DeviceType)(type, &registry_function); \
}
MACE_DECLARE_REGISTRY(
CPUOperatorRegistry,
OperatorBase,
const OperatorDef&,
Workspace*);
#define REGISTER_CPU_OPERATOR_CREATOR(key, ...) \
MACE_REGISTER_CREATOR(CPUOperatorRegistry, key, __VA_ARGS__)
#define REGISTER_CPU_OPERATOR(name, ...) \
MACE_REGISTER_CLASS(CPUOperatorRegistry, name, __VA_ARGS__)
unique_ptr<OperatorBase> CreateOperator(
const OperatorDef &operator_def,
Workspace *ws,
DeviceType type);
} // namespace mace
#endif //MACE_CORE_OPERATOR_H
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/proto_utils.h"
#include <fcntl.h>
#include <cerrno>
#include <fstream>
#include <unistd.h>
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#ifndef MACE_USE_LITE_PROTO
#include "google/protobuf/text_format.h"
#endif // !MACE_USE_LITE_PROTO
namespace mace {
bool ReadStringFromFile(const char* filename, string* str) {
std::ifstream ifs(filename, std::ios::in);
if (!ifs) {
VLOG(1) << "File cannot be opened: " << filename
<< " error: " << ifs.rdstate();
return false;
}
ifs.seekg(0, std::ios::end);
size_t n = ifs.tellg();
str->resize(n);
ifs.seekg(0);
ifs.read(&(*str)[0], n);
return true;
}
bool WriteStringToFile(const string& str, const char* filename) {
std::ofstream ofs(filename, std::ios::out | std::ios::trunc);
if (!ofs.is_open()) {
VLOG(1) << "File cannot be created: " << filename
<< " error: " << ofs.rdstate();
return false;
}
ofs << str;
return true;
}
// IO-specific proto functions: we will deal with the protocol buffer lite and
// full versions differently.
#ifdef MACE_USE_LITE_PROTO
// Lite runtime.
namespace {
class IfstreamInputStream : public ::google::protobuf::io::CopyingInputStream {
public:
explicit IfstreamInputStream(const string& filename)
: ifs_(filename.c_str(), std::ios::in | std::ios::binary) {}
~IfstreamInputStream() { ifs_.close(); }
int Read(void* buffer, int size) {
if (!ifs_) {
return -1;
}
ifs_.read(static_cast<char*>(buffer), size);
return ifs_.gcount();
}
private:
std::ifstream ifs_;
};
} // namespace
bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto) {
::google::protobuf::io::CopyingInputStreamAdaptor stream(
new IfstreamInputStream(filename));
stream.SetOwnsCopyingStream(true);
// Total bytes hard limit / warning limit are set to 1GB and 512MB
// respectively.
::google::protobuf::io::CodedInputStream coded_stream(&stream);
coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20);
return proto->ParseFromCodedStream(&coded_stream);
}
void WriteProtoToBinaryFile(
const MessageLite& /*proto*/,
const char* /*filename*/) {
LOG(FATAL) << "Not implemented yet.";
}
#else // MACE_USE_LITE_PROTO
// Full protocol buffer.
using ::google::protobuf::io::FileInputStream;
using ::google::protobuf::io::FileOutputStream;
using ::google::protobuf::io::ZeroCopyInputStream;
using ::google::protobuf::io::CodedInputStream;
using ::google::protobuf::io::ZeroCopyOutputStream;
using ::google::protobuf::io::CodedOutputStream;
bool ReadProtoFromTextFile(const char* filename, Message* proto) {
int fd = open(filename, O_RDONLY);
REQUIRE(fd != -1, "File not found: ", filename);
FileInputStream* input = new FileInputStream(fd);
bool success = google::protobuf::TextFormat::Parse(input, proto);
delete input;
close(fd);
return success;
}
void WriteProtoToTextFile(const Message& proto, const char* filename) {
int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644);
FileOutputStream* output = new FileOutputStream(fd);
CHECK(google::protobuf::TextFormat::Print(proto, output));
delete output;
close(fd);
}
bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto) {
#if defined (_MSC_VER) // for MSC compiler binary flag needs to be specified
int fd = open(filename, O_RDONLY | O_BINARY);
#else
int fd = open(filename, O_RDONLY);
#endif
REQUIRE(fd != -1, "File not found: ", filename);
std::unique_ptr<ZeroCopyInputStream> raw_input(new FileInputStream(fd));
std::unique_ptr<CodedInputStream> coded_input(
new CodedInputStream(raw_input.get()));
// A hack to manually allow using very large protocol buffers.
coded_input->SetTotalBytesLimit(1073741824, 536870912);
bool success = proto->ParseFromCodedStream(coded_input.get());
coded_input.reset();
raw_input.reset();
close(fd);
return success;
}
void WriteProtoToBinaryFile(const MessageLite& proto, const char* filename) {
int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644);
REQUIRE(
fd != -1, "File cannot be created: ", filename, " error number: ", errno);
std::unique_ptr<ZeroCopyOutputStream> raw_output(new FileOutputStream(fd));
std::unique_ptr<CodedOutputStream> coded_output(
new CodedOutputStream(raw_output.get()));
CHECK(proto.SerializeToCodedStream(coded_output.get()));
coded_output.reset();
raw_output.reset();
close(fd);
}
#endif // MACE_USE_LITE_PROTO
ArgumentHelper::ArgumentHelper(const OperatorDef &def) {
for (auto &arg : def.arg()) {
if (arg_map_.count(arg.name())) {
REQUIRE(
arg.SerializeAsString() == arg_map_[arg.name()].SerializeAsString(),
"Found argument of the same name ",
arg.name(),
"but with different contents.",
ProtoDebugString(def));
} else {
LOG(WARNING) << "Duplicated argument name found in operator def: "
<< ProtoDebugString(def);
}
arg_map_[arg.name()] = arg;
}
}
ArgumentHelper::ArgumentHelper(const NetDef& netdef) {
for (auto& arg : netdef.arg()) {
REQUIRE(
arg_map_.count(arg.name()) == 0,
"Duplicated argument name found in net def: ",
ProtoDebugString(netdef));
arg_map_[arg.name()] = arg;
}
}
bool ArgumentHelper::HasArgument(const string& name) const {
return arg_map_.count(name);
}
namespace {
// Helper function to verify that conversion between types won't loose any
// significant bit.
template <typename InputType, typename TargetType>
bool SupportsLosslessConversion(const InputType& value) {
return static_cast<InputType>(static_cast<TargetType>(value)) == value;
}
}
#define INSTANTIATE_GET_SINGLE_ARGUMENT( \
T, fieldname, enforce_lossless_conversion) \
template <> \
T ArgumentHelper::GetSingleArgument<T>( \
const string& name, const T& default_value) const { \
if (arg_map_.count(name) == 0) { \
VLOG(1) << "Using default parameter value " << default_value \
<< " for parameter " << name; \
return default_value; \
} \
REQUIRE( \
arg_map_.at(name).has_##fieldname(), \
"Argument ", \
name, \
" does not have the right field: expected field " #fieldname); \
auto value = arg_map_.at(name).fieldname(); \
if (enforce_lossless_conversion) { \
auto supportsConversion = \
SupportsLosslessConversion<decltype(value), T>(value); \
REQUIRE( \
supportsConversion, \
"Value", \
value, \
" of argument ", \
name, \
"cannot be represented correctly in a target type"); \
} \
return value; \
} \
template <> \
bool ArgumentHelper::HasSingleArgumentOfType<T>(const string& name) const { \
if (arg_map_.count(name) == 0) { \
return false; \
} \
return arg_map_.at(name).has_##fieldname(); \
}
INSTANTIATE_GET_SINGLE_ARGUMENT(float, f, false)
INSTANTIATE_GET_SINGLE_ARGUMENT(double, f, false)
INSTANTIATE_GET_SINGLE_ARGUMENT(bool, i, false)
INSTANTIATE_GET_SINGLE_ARGUMENT(int8_t, i, true)
INSTANTIATE_GET_SINGLE_ARGUMENT(int16_t, i, true)
INSTANTIATE_GET_SINGLE_ARGUMENT(int, i, true)
INSTANTIATE_GET_SINGLE_ARGUMENT(int64_t, i, true)
INSTANTIATE_GET_SINGLE_ARGUMENT(uint8_t, i, true)
INSTANTIATE_GET_SINGLE_ARGUMENT(uint16_t, i, true)
INSTANTIATE_GET_SINGLE_ARGUMENT(size_t, i, true)
INSTANTIATE_GET_SINGLE_ARGUMENT(string, s, false)
#undef INSTANTIATE_GET_SINGLE_ARGUMENT
#define INSTANTIATE_GET_REPEATED_ARGUMENT( \
T, fieldname, enforce_lossless_conversion) \
template <> \
vector<T> ArgumentHelper::GetRepeatedArgument<T>( \
const string& name, const std::vector<T>& default_value) const { \
if (arg_map_.count(name) == 0) { \
return default_value; \
} \
vector<T> values; \
for (const auto& v : arg_map_.at(name).fieldname()) { \
if (enforce_lossless_conversion) { \
auto supportsConversion = \
SupportsLosslessConversion<decltype(v), T>(v); \
REQUIRE( \
supportsConversion, \
"Value", \
v, \
" of argument ", \
name, \
"cannot be represented correctly in a target type"); \
} \
values.push_back(v); \
} \
return values; \
}
INSTANTIATE_GET_REPEATED_ARGUMENT(float, floats, false)
INSTANTIATE_GET_REPEATED_ARGUMENT(double, floats, false)
INSTANTIATE_GET_REPEATED_ARGUMENT(bool, ints, false)
INSTANTIATE_GET_REPEATED_ARGUMENT(int8_t, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(int16_t, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(int, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(int64_t, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(uint8_t, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(uint16_t, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(size_t, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(string, strings, false)
#undef INSTANTIATE_GET_REPEATED_ARGUMENT
#define MACE_MAKE_SINGULAR_ARGUMENT(T, fieldname) \
template <> \
Argument MakeArgument(const string& name, const T& value) { \
Argument arg; \
arg.set_name(name); \
arg.set_##fieldname(value); \
return arg; \
}
MACE_MAKE_SINGULAR_ARGUMENT(bool, i)
MACE_MAKE_SINGULAR_ARGUMENT(float, f)
MACE_MAKE_SINGULAR_ARGUMENT(int, i)
MACE_MAKE_SINGULAR_ARGUMENT(int64_t, i)
MACE_MAKE_SINGULAR_ARGUMENT(string, s)
#undef MACE_MAKE_SINGULAR_ARGUMENT
template <>
Argument MakeArgument(const string& name, const MessageLite& value) {
Argument arg;
arg.set_name(name);
arg.set_s(value.SerializeAsString());
return arg;
}
#define MACE_MAKE_REPEATED_ARGUMENT(T, fieldname) \
template <> \
Argument MakeArgument(const string& name, const vector<T>& value) { \
Argument arg; \
arg.set_name(name); \
for (const auto& v : value) { \
arg.add_##fieldname(v); \
} \
return arg; \
}
MACE_MAKE_REPEATED_ARGUMENT(float, floats)
MACE_MAKE_REPEATED_ARGUMENT(int, ints)
MACE_MAKE_REPEATED_ARGUMENT(int64_t, ints)
MACE_MAKE_REPEATED_ARGUMENT(string, strings)
#undef MACE_MAKE_REPEATED_ARGUMENT
const Argument& GetArgument(const OperatorDef& def, const string& name) {
for (const Argument& arg : def.arg()) {
if (arg.name() == name) {
return arg;
}
}
REQUIRE(false,
"Argument named ",
name,
"does not exist in operator ",
ProtoDebugString(def));
}
bool GetFlagArgument(
const OperatorDef& def,
const string& name,
bool def_value) {
for (const Argument& arg : def.arg()) {
if (arg.name() == name) {
REQUIRE(
arg.has_i(), "Can't parse argument as bool: ", ProtoDebugString(arg));
return arg.i();
}
}
return def_value;
}
Argument* GetMutableArgument(
const string& name,
const bool create_if_missing,
OperatorDef* def) {
for (int i = 0; i < def->arg_size(); ++i) {
if (def->arg(i).name() == name) {
return def->mutable_arg(i);
}
}
// If no argument of the right name is found...
if (create_if_missing) {
Argument* arg = def->add_arg();
arg->set_name(name);
return arg;
} else {
return nullptr;
}
}
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_CORE_PROTO_UTILS_H_
#define MACE_CORE_PROTO_UTILS_H_
#include <map>
#include "google/protobuf/message_lite.h"
#ifndef MACE_USE_LITE_PROTO
#include "google/protobuf/message.h"
#endif // !MACE_USE_LITE_PROTO
#include "mace/proto/mace.pb.h"
#include "mace/core/common.h"
namespace mace {
using std::string;
using ::google::protobuf::MessageLite;
// Common interfaces that reads file contents into a string.
bool ReadStringFromFile(const char* filename, string* str);
bool WriteStringToFile(const string& str, const char* filename);
// Common interfaces that are supported by both lite and full protobuf.
bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto);
inline bool ReadProtoFromBinaryFile(const string filename, MessageLite* proto) {
return ReadProtoFromBinaryFile(filename.c_str(), proto);
}
void WriteProtoToBinaryFile(const MessageLite& proto, const char* filename);
inline void WriteProtoToBinaryFile(const MessageLite& proto,
const string& filename) {
return WriteProtoToBinaryFile(proto, filename.c_str());
}
#ifdef MACE_USE_LITE_PROTO
inline string ProtoDebugString(const MessageLite& proto) {
return proto.SerializeAsString();
}
// Text format MessageLite wrappers: these functions do nothing but just
// allowing things to compile. It will produce a runtime error if you are using
// MessageLite but still want text support.
inline bool ReadProtoFromTextFile(
const char* /*filename*/,
MessageLite* /*proto*/) {
LOG(FATAL) << "If you are running lite version, you should not be "
<< "calling any text-format protobuffers.";
return false; // Just to suppress compiler warning.
}
inline bool ReadProtoFromTextFile(const string filename, MessageLite* proto) {
return ReadProtoFromTextFile(filename.c_str(), proto);
}
inline void WriteProtoToTextFile(
const MessageLite& /*proto*/,
const char* /*filename*/) {
LOG(FATAL) << "If you are running lite version, you should not be "
<< "calling any text-format protobuffers.";
}
inline void WriteProtoToTextFile(const MessageLite& proto,
const string& filename) {
return WriteProtoToTextFile(proto, filename.c_str());
}
inline bool ReadProtoFromFile(const char* filename, MessageLite* proto) {
return (ReadProtoFromBinaryFile(filename, proto) ||
ReadProtoFromTextFile(filename, proto));
}
inline bool ReadProtoFromFile(const string& filename, MessageLite* proto) {
return ReadProtoFromFile(filename.c_str(), proto);
}
#else // MACE_USE_LITE_PROTO
using ::google::protobuf::Message;
inline string ProtoDebugString(const Message& proto) {
return proto.ShortDebugString();
}
bool ReadProtoFromTextFile(const char* filename, Message* proto);
inline bool ReadProtoFromTextFile(const string filename, Message* proto) {
return ReadProtoFromTextFile(filename.c_str(), proto);
}
void WriteProtoToTextFile(const Message& proto, const char* filename);
inline void WriteProtoToTextFile(const Message& proto, const string& filename) {
return WriteProtoToTextFile(proto, filename.c_str());
}
// Read Proto from a file, letting the code figure out if it is text or binary.
inline bool ReadProtoFromFile(const char* filename, Message* proto) {
return (ReadProtoFromBinaryFile(filename, proto) ||
ReadProtoFromTextFile(filename, proto));
}
inline bool ReadProtoFromFile(const string& filename, Message* proto) {
return ReadProtoFromFile(filename.c_str(), proto);
}
#endif // MACE_USE_LITE_PROTO
template <
class IterableInputs = std::initializer_list<string>,
class IterableOutputs = std::initializer_list<string>,
class IterableArgs = std::initializer_list<Argument>>
OperatorDef CreateOperatorDef(
const string& type,
const string& name,
const IterableInputs& inputs,
const IterableOutputs& outputs,
const IterableArgs& args) {
OperatorDef def;
def.set_type(type);
def.set_name(name);
for (const string& in : inputs) {
def.add_input(in);
}
for (const string& out : outputs) {
def.add_output(out);
}
for (const Argument& arg : args) {
def.add_arg()->CopyFrom(arg);
}
return def;
}
// A simplified version compared to the full CreateOperator, if you do not need
// to specify args.
template <
class IterableInputs = std::initializer_list<string>,
class IterableOutputs = std::initializer_list<string>>
inline OperatorDef CreateOperatorDef(
const string& type,
const string& name,
const IterableInputs& inputs,
const IterableOutputs& outputs) {
return CreateOperatorDef(
type,
name,
inputs,
outputs,
std::vector<Argument>());
}
/**
* @brief A helper class to index into arguments.
*
* This helper helps us to more easily index into a set of arguments
* that are present in the operator. To save memory, the argument helper
* does not copy the operator def, so one would need to make sure that the
* lifetime of the OperatorDef object outlives that of the ArgumentHelper.
*/
class ArgumentHelper {
public:
template <typename Def>
static bool HasArgument(const Def& def, const string& name) {
return ArgumentHelper(def).HasArgument(name);
}
template <typename Def, typename T>
static T GetSingleArgument(
const Def& def,
const string& name,
const T& default_value) {
return ArgumentHelper(def).GetSingleArgument<T>(name, default_value);
}
template <typename Def, typename T>
static bool HasSingleArgumentOfType(const Def& def, const string& name) {
return ArgumentHelper(def).HasSingleArgumentOfType<T>(name);
}
template <typename Def, typename T>
static vector<T> GetRepeatedArgument(
const Def& def,
const string& name,
const std::vector<T>& default_value = std::vector<T>()) {
return ArgumentHelper(def).GetRepeatedArgument<T>(name, default_value);
}
template <typename Def, typename MessageType>
static MessageType GetMessageArgument(const Def& def, const string& name) {
return ArgumentHelper(def).GetMessageArgument<MessageType>(name);
}
template <typename Def, typename MessageType>
static vector<MessageType> GetRepeatedMessageArgument(
const Def& def,
const string& name) {
return ArgumentHelper(def).GetRepeatedMessageArgument<MessageType>(name);
}
explicit ArgumentHelper(const OperatorDef& def);
explicit ArgumentHelper(const NetDef& netdef);
bool HasArgument(const string& name) const;
template <typename T>
T GetSingleArgument(const string& name, const T& default_value) const;
template <typename T>
bool HasSingleArgumentOfType(const string& name) const;
template <typename T>
vector<T> GetRepeatedArgument(
const string& name,
const std::vector<T>& default_value = std::vector<T>()) const;
template <typename MessageType>
MessageType GetMessageArgument(const string& name) const {
REQUIRE(arg_map_.count(name), "Cannot find parameter named " + name);
MessageType message;
if (arg_map_.at(name).has_s()) {
REQUIRE(
message.ParseFromString(arg_map_.at(name).s()),
"Faild to parse content from the string");
} else {
VLOG(1) << "Return empty message for parameter " << name;
}
return message;
}
template <typename MessageType>
vector<MessageType> GetRepeatedMessageArgument(const string& name) const {
REQUIRE(arg_map_.count(name), "Cannot find parameter named " + name);
vector<MessageType> messages(arg_map_.at(name).strings_size());
for (int i = 0; i < messages.size(); ++i) {
REQUIRE(
messages[i].ParseFromString(arg_map_.at(name).strings(i)),
"Faild to parse content from the string");
}
return messages;
}
private:
std::map<string, Argument> arg_map_;
};
const Argument& GetArgument(const OperatorDef& def, const string& name);
bool GetFlagArgument(
const OperatorDef& def,
const string& name,
bool def_value = false);
Argument* GetMutableArgument(
const string& name,
const bool create_if_missing,
OperatorDef* def);
template <typename T>
Argument MakeArgument(const string& name, const T& value);
template <typename T>
inline void AddArgument(const string& name, const T& value, OperatorDef* def) {
GetMutableArgument(name, true, def)->CopyFrom(MakeArgument(name, value));
}
} // namespace mace
#endif // MACE_CORE_PROTO_UTILS_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_CORE_REGISTRY_H_
#define MACE_CORE_REGISTRY_H_
#include <memory>
#include <mutex>
#include <string>
#include <map>
#include "mace/core/common.h"
namespace mace {
template <class SrcType, class ObjectType, class... Args>
class Registry {
public:
typedef std::function<std::unique_ptr<ObjectType> (Args ...)> Creator;
Registry() : registry_() {}
void Register(const SrcType& key, Creator creator) {
std::lock_guard<std::mutex> lock(register_mutex_);
REQUIRE(registry_.count(key) == 0, "Key already registered.");
registry_[key] = creator;
}
inline bool Has(const SrcType& key) { return registry_.count(key) != 0; }
unique_ptr<ObjectType> Create(const SrcType& key, Args ... args) {
if (registry_.count(key) == 0) {
return nullptr;
}
return registry_[key](args...);
}
/**
* Returns the keys currently registered as a vector.
*/
vector<SrcType> Keys() {
vector<SrcType> keys;
for (const auto& it : registry_) {
keys.push_back(it.first);
}
return keys;
}
private:
std::map<SrcType, Creator> registry_;
std::mutex register_mutex_;
DISABLE_COPY_AND_ASSIGN(Registry);
};
template <class SrcType, class ObjectType, class... Args>
class Registerer {
public:
Registerer(const SrcType& key,
Registry<SrcType, ObjectType, Args...>* registry,
typename Registry<SrcType, ObjectType, Args...>::Creator creator) {
registry->Register(key, creator);
}
template <class DerivedType>
static unique_ptr<ObjectType> DefaultCreator(Args ... args) {
return std::unique_ptr<ObjectType>(new DerivedType(args...));
}
};
#define MACE_CONCATENATE_IMPL(s1, s2) s1##s2
#define MACE_CONCATENATE(s1, s2) MACE_CONCATENATE_IMPL(s1, s2)
#ifdef __COUNTER__
#define MACE_ANONYMOUS_VARIABLE(str) MACE_CONCATENATE(str, __COUNTER__)
#else
#define MACE_ANONYMOUS_VARIABLE(str) MACE_CONCATENATE(str, __LINE__)
#endif
#define MACE_DECLARE_TYPED_REGISTRY(RegistryName, SrcType, ObjectType, ...) \
Registry<SrcType, ObjectType, ##__VA_ARGS__>* RegistryName(); \
typedef Registerer<SrcType, ObjectType, ##__VA_ARGS__> \
Registerer##RegistryName;
#define MACE_DEFINE_TYPED_REGISTRY(RegistryName, SrcType, ObjectType, ...) \
Registry<SrcType, ObjectType, ##__VA_ARGS__>* RegistryName() { \
static Registry<SrcType, ObjectType, ##__VA_ARGS__>* registry = \
new Registry<SrcType, ObjectType, ##__VA_ARGS__>(); \
return registry; \
}
#define MACE_DECLARE_REGISTRY(RegistryName, ObjectType, ...) \
MACE_DECLARE_TYPED_REGISTRY( \
RegistryName, std::string, ObjectType, ##__VA_ARGS__)
#define MACE_DEFINE_REGISTRY(RegistryName, ObjectType, ...) \
MACE_DEFINE_TYPED_REGISTRY( \
RegistryName, std::string, ObjectType, ##__VA_ARGS__)
#define MACE_REGISTER_TYPED_CREATOR(RegistryName, key, ...) \
namespace { \
static Registerer##RegistryName MACE_ANONYMOUS_VARIABLE(g_##RegistryName)( \
key, RegistryName(), __VA_ARGS__);
#define MACE_REGISTER_TYPED_CLASS(RegistryName, key, ...) \
namespace { \
static Registerer##RegistryName MACE_ANONYMOUS_VARIABLE(g_##RegistryName)( \
key, \
RegistryName(), \
Registerer##RegistryName::DefaultCreator<__VA_ARGS__>); \
}
#define MACE_REGISTER_CREATOR(RegistryName, key, ...) \
MACE_REGISTER_TYPED_CREATOR(RegistryName, #key, __VA_ARGS__)
#define MACE_REGISTER_CLASS(RegistryName, key, ...) \
MACE_REGISTER_TYPED_CLASS(RegistryName, #key, __VA_ARGS__)
} // namespace mace
#endif // MACE_CORE_REGISTRY_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_CORE_TENSOR_H_
#define MACE_CORE_TENSOR_H_
#include "mace/core/common.h"
#include "mace/proto/mace.pb.h"
#include "mace/core/allocator.h"
#include "mace/core/types.h"
namespace mace {
#define SINGLE_ARG(...) __VA_ARGS__
#define CASE(TYPE, STMTS) \
case DataTypeToEnum<TYPE>::value: { \
typedef TYPE T; \
STMTS; \
break; \
}
#define CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, INVALID, DEFAULT) \
switch (TYPE_ENUM) { \
CASE(float, SINGLE_ARG(STMTS)) \
CASE(double, SINGLE_ARG(STMTS)) \
CASE(int32, SINGLE_ARG(STMTS)) \
CASE(uint8, SINGLE_ARG(STMTS)) \
CASE(uint16, SINGLE_ARG(STMTS)) \
CASE(int16, SINGLE_ARG(STMTS)) \
CASE(int8, SINGLE_ARG(STMTS)) \
CASE(string, SINGLE_ARG(STMTS)) \
CASE(int64, SINGLE_ARG(STMTS)) \
CASE(bool, SINGLE_ARG(STMTS)) \
case DT_INVALID: \
INVALID; \
break; \
default: \
DEFAULT; \
break; \
}
#define CASES(TYPE_ENUM, STMTS) \
CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, LOG(FATAL) << "Type not set"; \
, LOG(FATAL) << "Unexpected type: " << TYPE_ENUM;)
class Tensor {
public:
Tensor()
: alloc_(cpu_allocator()),
size_(0), dtype_(DT_FLOAT), data_(nullptr) {};
Tensor(Allocator* a, DataType type)
: alloc_(a), size_(0), dtype_(DT_FLOAT), data_(nullptr) {};
~Tensor() {
if (alloc_ && data_.get()) {
data_.reset();
}
}
inline DataType dtype() const { return dtype_; }
inline const vector<TIndex>& shape() const { return shape_; }
inline int64 NumElements() const {
return std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<int64>());
}
inline TIndex dim_size() { return shape_.size(); }
inline TIndex size() const { return size_; }
inline const void* raw_data() const {
CHECK(data_.get() || size_ == 0);
return data_.get();
}
template <typename T>
inline const T* data() const {
REQUIRE(
data_.get() || size_ == 0,
"The tensor is of non-zero shape, but its data is not allocated yet. ");
return static_cast<T*>(data_.get());
}
void Deleter(void* data) {
alloc_->Delete(data);
}
inline void* raw_mutable_data() {
if (data_.get() || size_ == 0) {
return data_.get();
} else {
CASES(dtype_, data_.reset(alloc_->New(size_ * sizeof(T)), [this](void* ptr) {
alloc_->Delete(ptr);
}));
return data_.get();
}
}
template <typename T>
inline T* mutable_data() {
if (size_ == 0 || data_.get()) {
return static_cast<T*>(data_.get());
}
return static_cast<T*>(raw_mutable_data());
}
inline void Resize(const vector<TIndex>& shape) {
shape_ = shape;
TIndex size = NumElements();
if (size_ != size) {
size_ = NumElements();
data_.reset();
}
}
inline void ResizeLike(const Tensor& other) {
Resize(other.shape());
}
inline void ResizeLike(const Tensor* other) {
Resize(other->shape());
}
private:
Allocator* alloc_;
TIndex size_;
DataType dtype_;
std::shared_ptr<void> data_;
vector<TIndex> shape_;
};
} // namespace tensor
#endif //MACE_CORE_TENSOR_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_CORE_TYPES_H_
#define MACE_CORE_TYPES_H_
#include "mace/core/common.h"
#include "mace/proto/mace.pb.h"
namespace mace {
template <class T>
struct IsValidDataType;
template <class T>
struct DataTypeToEnum {
static_assert(IsValidDataType<T>::value, "Specified Data Type not supported");
}; // Specializations below
// EnumToDataType<VALUE>::Type is the type for DataType constant VALUE, e.g.
// EnumToDataType<DT_FLOAT>::Type is float.
template <DataType VALUE>
struct EnumToDataType {}; // Specializations below
// Template specialization for both DataTypeToEnum and EnumToDataType.
#define MATCH_TYPE_AND_ENUM(TYPE, ENUM) \
template <> \
struct DataTypeToEnum<TYPE> { \
static DataType v() { return ENUM; } \
static constexpr DataType value = ENUM; \
}; \
template <> \
struct IsValidDataType<TYPE> { \
static constexpr bool value = true; \
}; \
template <> \
struct EnumToDataType<ENUM> { \
typedef TYPE Type; \
}
MATCH_TYPE_AND_ENUM(float, DT_FLOAT);
MATCH_TYPE_AND_ENUM(double, DT_DOUBLE);
MATCH_TYPE_AND_ENUM(int32, DT_INT32);
MATCH_TYPE_AND_ENUM(uint16, DT_UINT16);
MATCH_TYPE_AND_ENUM(uint8, DT_UINT8);
MATCH_TYPE_AND_ENUM(int16, DT_INT16);
MATCH_TYPE_AND_ENUM(int8, DT_INT8);
MATCH_TYPE_AND_ENUM(string, DT_STRING);
MATCH_TYPE_AND_ENUM(int64, DT_INT64);
MATCH_TYPE_AND_ENUM(bool, DT_BOOL);
} // namespace mace
#endif // MACE_CORE_TYPES_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/common.h"
#include "mace/core/workspace.h"
namespace mace {
vector<string> Workspace::Tensors() const {
vector<string> names;
for (auto& entry : tensor_map_) {
names.push_back(entry.first);
}
return names;
}
Tensor* Workspace::CreateTensor(const string& name, Allocator* alloc, DataType type) {
if (HasTensor(name)) {
VLOG(1) << "Tensor " << name << " already exists. Skipping.";
} else {
VLOG(1) << "Creating Tensor " << name;
tensor_map_[name] = unique_ptr<Tensor>(new Tensor(alloc, type));
}
return GetTensor(name);
}
bool Workspace::RemoveTensor(const string& name) {
auto it = tensor_map_.find(name);
if (it != tensor_map_.end()) {
VLOG(1) << "Removing blob " << name << " from this workspace.";
tensor_map_.erase(it);
return true;
}
return false;
}
const Tensor* Workspace::GetTensor(const string& name) const {
if (tensor_map_.count(name)) {
return tensor_map_.at(name).get();
}
return nullptr;
}
Tensor* Workspace::GetTensor(const string& name) {
return const_cast<Tensor*>(static_cast<const Workspace*>(this)->GetTensor(name));
}
bool RunNet();
} // namespace mace
\ No newline at end of file
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_CORE_WORKSPACE_H_
#define MACE_CORE_WORKSPACE_H_
#include "mace/core/common.h"
#include "mace/core/tensor.h"
#include "mace/proto/mace.pb.h"
namespace mace {
class Workspace {
public:
typedef std::map<string, unique_ptr<Tensor>> TensorMap;
Workspace() {}
vector<string> Tensors() const;
Tensor* CreateTensor(const string& name, Allocator* alloc, DataType type);
bool RemoveTensor(const string& name);
inline bool HasTensor(const string& name) const {
return tensor_map_.count(name);
}
const Tensor* GetTensor(const string& name) const;
Tensor* GetTensor(const string& name);
private:
TensorMap tensor_map_;
DISABLE_COPY_AND_ASSIGN(Workspace);
};
} // namespace mace
#endif // MACE_CORE_WORKSPACE_H_
# Description:
# Mace operators.
#
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"]) # Apache 2.0
cc_library(
name = "op",
srcs = ["relu.cc"],
hdrs = glob(["*.h"]),
deps = [
"//mace/proto:cc_proto",
"//mace/core:core",
],
)
cc_test(
name = "op_test",
srcs = ["relu_test.cc",],
deps = [
"@gtest//:gtest",
":op",
],
)
\ No newline at end of file
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/ops/relu.h"
#include "mace/proto/mace.pb.h"
namespace mace {
template <>
bool ReluOp<DeviceType::CPU, float>::Run() {
const Tensor* X = Input(0);
Tensor* Y = Output(0);
Y->ResizeLike(X);
const float* Xdata = X-> data<float>();
float* Ydata = Y->mutable_data<float>();
for (int i = 0; i < X->size(); ++i) {
Ydata[i] = std::max(Xdata[i], 0.f);
VLOG(0) << i << ": " << Xdata[i] << " " << Ydata[i];
}
return true;
}
REGISTER_CPU_OPERATOR(Relu, ReluOp<DeviceType::CPU, float>);
} // namespace mace
\ No newline at end of file
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_OPERATORS_RELU_H_
#define MACE_OPERATORS_RELU_H_
#include "mace/core/operator.h"
namespace mace {
template<DeviceType D, class T>
class ReluOp : public Operator<D, T> {
public:
ReluOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws) {}
bool Run() override;
};
} // namespace mace
#endif // MACE_OPERATORS_RELU_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "gtest/gtest.h"
#include "mace/core/operator.h"
#include "mace/core/net.h"
using namespace mace;
TEST(ReluTest, Relu) {
OperatorRegistry* registry = gDeviceTypeRegistry()->at(DeviceType::CPU);
vector<string> registry_keys = registry->Keys();
for (auto& key: registry_keys) {
VLOG(0) << "registry_op: " << key;
}
// Construct graph
OperatorDef op_def;
op_def.add_input("Input0");
op_def.add_output("Output0");
op_def.set_name("ReluTest");
op_def.set_type("Relu");
auto arg = op_def.add_arg();
arg->set_name("arg0");
arg->set_f(1.5);
NetDef net_def;
net_def.set_name("NetTest");
net_def.add_op()->CopyFrom(op_def);
VLOG(0) << net_def.DebugString();
// Create workspace and input tensor
Workspace ws;
Tensor* input = ws.CreateTensor("Input0", cpu_allocator(), DataType::DT_FLOAT);
input->Resize({2,3});
float* input_data = input->mutable_data<float>();
for (int i = 0; i < 6; ++i) {
input_data[i] = i-3;
}
// Create Net & run
auto net = CreateNet(net_def, &ws, DeviceType::CPU);
net->Run();
// Create Op & run
auto op = CreateOperator(op_def, &ws, DeviceType::CPU);
ASSERT_FLOAT_EQ(1.5f, op->GetSingleArgument<float>("arg0", 1.0f));
}
\ No newline at end of file
# Description:
# mace proto.
#
package(
default_visibility = ["//visibility:public"],
)
proto_library(
name = "proto",
srcs = ["mace.proto"],
)
cc_proto_library(
name = "cc_proto",
deps = [":proto"],
)
\ No newline at end of file
syntax = "proto2";
package mace;
enum DeviceType {
CPU = 0; // In default, we will use CPU.
GPU = 1;
}
enum DataType {
DT_INVALID = 0;
// Data types that all computation devices are expected to be
// capable to support.
DT_FLOAT = 1;
DT_DOUBLE = 2;
DT_INT32 = 3;
DT_UINT8 = 4;
DT_INT16 = 5;
DT_INT8 = 6;
DT_STRING = 7;
DT_INT64 = 8;
DT_UINT16 = 9;
DT_BOOL = 10;
}
message TensorProto {
// The dimensions in the tensor.
repeated int64 dims = 1;
optional DataType data_type = 2 [default = DT_FLOAT];
// For float
repeated float float_data = 3 [packed = true];
// For int32, uint8, int8, uint16, int16, bool, and float16
// Note about float16: in storage we will basically convert float16 byte-wise
// to unsigned short and then store them in the int32_data field.
repeated int32 int32_data = 4 [packed = true];
// For bytes
optional bytes byte_data = 5;
// For strings
repeated bytes string_data = 6;
// For double
repeated double double_data = 9 [packed = true];
// For int64
repeated int64 int64_data = 10 [packed = true];
// Optionally, a name for the tensor.
optional string name = 7;
}
message Argument {
optional string name = 1;
optional float f = 2;
optional int64 i = 3;
optional bytes s = 4;
repeated float floats = 5;
repeated int64 ints = 6;
repeated bytes strings = 7;
}
message OperatorDef {
repeated string input = 1;
repeated string output = 2;
optional string name = 3;
optional string type = 4;
repeated Argument arg = 5;
}
message NetDef {
optional string name = 1;
repeated OperatorDef op = 2;
optional string version = 3;
repeated Argument arg = 4;
repeated TensorProto tensors = 5;
}
\ No newline at end of file
# Description:
# Google test
licenses(["notice"])
cc_library(
name = "gtest",
srcs = glob(
["src/*.cc"],
exclude = ["src/gtest-all.cc"]
),
hdrs = glob([
"include/**/*.h",
"src/*.h"
]),
copts = ["-Iexternal/gtest/include"],
linkopts = ["-pthread"],
visibility = ["//visibility:public"],
)
\ No newline at end of file
# Description:
# Mace utils.
#
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"]) # Apache 2.0
cc_library(
name = "utils",
srcs = glob(["*.cc"]),
hdrs = glob(["*.h"]),
deps = [
"//mace/proto:cc_proto",
],
)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册