diff --git a/WORKSPACE b/WORKSPACE index cbd99e3434737c84faae457e53e3379d3e8b2fc6..7cae3f70830582e27a434f0a151261937bc8996b 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -1,3 +1,32 @@ +workspace(name = "mace") + +# proto_library rules implicitly depend on @com_google_protobuf//:protoc, +# which is the proto-compiler. +# This statement defines the @com_google_protobuf repo. +http_archive( + name = "com_google_protobuf", + urls = ["https://github.com/google/protobuf/archive/b4b0e304be5a68de3d0ee1af9b286f958750f5e4.zip"], + strip_prefix = "protobuf-b4b0e304be5a68de3d0ee1af9b286f958750f5e4", + sha256 = "ff771a662fb6bd4d3cc209bcccedef3e93980a49f71df1e987f6afa3bcdcba3a", +) + +# cc_proto_library rules implicitly depend on @com_google_protobuf_cc//:cc_toolchain, +# which is the C++ proto runtime (base classes and common utilities). +http_archive( + name = "com_google_protobuf_cc", + urls = ["https://github.com/google/protobuf/archive/b4b0e304be5a68de3d0ee1af9b286f958750f5e4.zip"], + strip_prefix = "protobuf-b4b0e304be5a68de3d0ee1af9b286f958750f5e4", + sha256 = "ff771a662fb6bd4d3cc209bcccedef3e93980a49f71df1e987f6afa3bcdcba3a", +) + +new_http_archive( + name = "gtest", + url = "https://github.com/google/googletest/archive/release-1.8.0.zip", + sha256 = "f3ed3b58511efd272eb074a3a6d6fb79d7c2e6a0e374323d1e6bcbcc1ef141bf", + build_file = "mace/third_party/gtest.BUILD", + strip_prefix = "googletest-release-1.8.0/googletest", +) + # Set up Android NDK android_ndk_repository( name = "androidndk", diff --git a/mace/core/BUILD b/mace/core/BUILD index b6b2cdbd285f452f4e76a4e0edb945979f7ea15b..8b63a4b1ea8a55539134c8421a02f462c5f08ff9 100644 --- a/mace/core/BUILD +++ b/mace/core/BUILD @@ -1,10 +1,19 @@ -package(default_visibility = ["//visibility:public"]) +# Description: +# Mace core. +# +package( + default_visibility = ["//visibility:public"], +) + + +licenses(["notice"]) # Apache 2.0 cc_library( - name = "lib_core", - hdrs = [ - "logging.h" - ], - srcs = [ - ], - ) + name = "core", + srcs = glob(["*.cc"]), + hdrs = glob(["*.h"]), + deps = [ + "//mace/proto:cc_proto", + ], +) + diff --git a/mace/core/allocator.cc b/mace/core/allocator.cc new file mode 100644 index 0000000000000000000000000000000000000000..61e28e9d4a78fb3d8b40275c62366c104f6f2847 --- /dev/null +++ b/mace/core/allocator.cc @@ -0,0 +1,18 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/core/allocator.h" + +namespace mace { + +static std::unique_ptr g_cpu_allocator(new CPUAllocator()); +CPUAllocator* cpu_allocator() { + return g_cpu_allocator.get(); +} + +void SetCPUAllocator(CPUAllocator* alloc) { + g_cpu_allocator.reset(alloc); +} + +} // namespace mace diff --git a/mace/core/allocator.h b/mace/core/allocator.h new file mode 100644 index 0000000000000000000000000000000000000000..e4482f6bc0a33521ef95a37c0c7d4e6f63ee27cb --- /dev/null +++ b/mace/core/allocator.h @@ -0,0 +1,86 @@ +// +// Created by liyin on 8/28/17. +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#ifndef MACE_CORE_ALLOCATOR_H_ +#define MACE_CORE_ALLOCATOR_H_ + +#include +#include +#include +#include + +#include "mace/core/common.h" +#include "mace/proto/mace.pb.h" + +namespace mace { + +constexpr size_t kMaceAlignment = 16; + +using MemoryDeleter = std::function; + +class Allocator { + public: + Allocator() {} + virtual ~Allocator() noexcept {} + virtual void* New(size_t nbytes) = 0; + virtual void Delete(void* data) = 0; + + template + T* New(size_t num_elements) { + if (num_elements > (std::numeric_limits::max() / sizeof(T))) { + return NULL; + } + void* p = New(sizeof(T) * num_elements); + T* typed_p = reinterpret_cast(p); + return typed_p; + } +}; + +class CPUAllocator: public Allocator { + public: + ~CPUAllocator() override {} + void* New(size_t nbytes) override { + void* data = nullptr; +#ifdef __ANDROID__ + data = memalign(gMaceAlignment, nbytes); +#elif defined(_MSC_VER) + data = _aligned_malloc(nbytes, gMaceAlignment); +#else + CHECK(posix_memalign(&data, kMaceAlignment, nbytes) == 0); +#endif + CHECK_NOTNULL(data); + memset(data, 0, nbytes); + return data; + } + +#ifdef _MSC_VER + void Delete(void* data) { + _aligned_free(data); + } +#else + void Delete(void* data) { + free(data); + } +#endif +}; + +// Get the CPU Alloctor. +CPUAllocator* cpu_allocator(); +// Sets the CPU allocator to the given allocator: the caller gives away the +// ownership of the pointer. +void SetCPUAllocator(CPUAllocator* alloc); + +template +struct DeviceContext {}; + +template <> +struct DeviceContext { + static Allocator* alloctor() { return cpu_allocator(); } +}; + + +} // namespace mace + +#endif // MACE_CORE_ALLOCATOR_H_ diff --git a/mace/core/common.h b/mace/core/common.h new file mode 100644 index 0000000000000000000000000000000000000000..5c24503e549a6503f87165d0a6b9c4fe849ca06f --- /dev/null +++ b/mace/core/common.h @@ -0,0 +1,34 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#ifndef MACE_CORE_COMMON_H_ +#define MACE_CORE_COMMON_H_ + +#include +#include +#include +#include +#include + +#include "mace/core/integral_types.h" +#include "mace/core/logging.h" + +using std::set; +using std::string; +using std::unique_ptr; +using std::vector; + +typedef int64 TIndex; + +// Disable the copy and assignment operator for a class. +#ifndef DISABLE_COPY_AND_ASSIGN +#define DISABLE_COPY_AND_ASSIGN(classname) \ +private: \ + classname(const classname&) = delete; \ + classname& operator=(const classname&) = delete +#endif + +#define MACE_NOT_IMPLEMENTED REQUIRE(false, "not implemented") + +#endif // MACE_CORE_COMMON_H_ diff --git a/mace/core/integral_types.h b/mace/core/integral_types.h new file mode 100644 index 0000000000000000000000000000000000000000..10a330539b5ab54a9cda03b947192beb4efcb0f3 --- /dev/null +++ b/mace/core/integral_types.h @@ -0,0 +1,19 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + + +#ifndef MACE_CORE_INTEGRAL_TYPES_H_ +#define MACE_CORE_INTEGRAL_TYPES_H_ + +typedef signed char int8; +typedef short int16; +typedef int int32; +typedef long long int64; + +typedef unsigned char uint8; +typedef unsigned short uint16; +typedef unsigned int uint32; +typedef unsigned long long uint64; + +#endif // MACE_CORE_INTEGRAL_TYPES_H_ diff --git a/mace/core/logging.cc b/mace/core/logging.cc new file mode 100644 index 0000000000000000000000000000000000000000..5e0982d58e5d38fa1117b9d35ba2bec8a55dc092 --- /dev/null +++ b/mace/core/logging.cc @@ -0,0 +1,125 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + + +#include "mace/core/logging.h" + +#include +#if defined(PLATFORM_POSIX_ANDROID) +#include +#include +#include +#endif + +namespace mace { +namespace internal { + +LogMessage::LogMessage(const char* fname, int line, int severity) + : fname_(fname), line_(line), severity_(severity) {} + +#if defined(PLATFORM_POSIX_ANDROID) +void LogMessage::GenerateLogMessage() { + int android_log_level; + switch (severity_) { + case INFO: + android_log_level = ANDROID_LOG_INFO; + break; + case WARNING: + android_log_level = ANDROID_LOG_WARN; + break; + case ERROR: + android_log_level = ANDROID_LOG_ERROR; + break; + case FATAL: + android_log_level = ANDROID_LOG_FATAL; + break; + default: + if (severity_ < INFO) { + android_log_level = ANDROID_LOG_VERBOSE; + } else { + android_log_level = ANDROID_LOG_ERROR; + } + break; + } + + std::stringstream ss; + const char* const partial_name = strrchr(fname_, '/'); + ss << (partial_name != nullptr ? partial_name + 1 : fname_) << ":" << line_ + << " " << str(); + __android_log_write(android_log_level, "native", ss.str().c_str()); + + // Also log to stderr (for standalone Android apps). + std::cerr << "native : " << ss.str() << std::endl; + + // Android logging at level FATAL does not terminate execution, so abort() + // is still required to stop the program. + if (severity_ == FATAL) { + abort(); + } +} + +#else + +void LogMessage::GenerateLogMessage() { + fprintf(stderr, "%c %s:%d] %s\n", "IWEF"[severity_], fname_, line_, str().c_str()); +} +#endif + + +namespace { + +// Parse log level (int64) from environment variable (char*) +int64 LogLevelStrToInt(const char* tf_env_var_val) { + if (tf_env_var_val == nullptr) { + return 0; + } + + // Ideally we would use env_var / safe_strto64, but it is + // hard to use here without pulling in a lot of dependencies, + // so we use std:istringstream instead + string min_log_level(tf_env_var_val); + std::istringstream ss(min_log_level); + int64 level; + if (!(ss >> level)) { + // Invalid vlog level setting, set level to default (0) + level = 0; + } + + return level; +} + +int64 MinLogLevelFromEnv() { + const char* tf_env_var_val = getenv("MACE_CPP_MIN_LOG_LEVEL"); + return LogLevelStrToInt(tf_env_var_val); +} + +int64 MinVLogLevelFromEnv() { + const char* tf_env_var_val = getenv("MACE_CPP_MIN_VLOG_LEVEL"); + return LogLevelStrToInt(tf_env_var_val); +} + +} // namespace + +LogMessage::~LogMessage() { + // Read the min log level once during the first call to logging. + static int64 min_log_level = MinLogLevelFromEnv(); + if (severity_ >= min_log_level) GenerateLogMessage(); +} + +int64 LogMessage::MinVLogLevel() { + static int64 min_vlog_level = MinVLogLevelFromEnv(); + return min_vlog_level; +} + +LogMessageFatal::LogMessageFatal(const char* file, int line) + : LogMessage(file, line, FATAL) {} +LogMessageFatal::~LogMessageFatal() { + // abort() ensures we don't return (we promised we would not via + // ATTRIBUTE_NORETURN). + GenerateLogMessage(); + abort(); +} + +} // namespace internal +} // namespace mace diff --git a/mace/core/logging.h b/mace/core/logging.h index ba5e4c2a753c3580dad8dc98438854ca53c36c6d..8a8715d2fa3a4da10720f1ce8bb3b4cf5da592b9 100644 --- a/mace/core/logging.h +++ b/mace/core/logging.h @@ -1,60 +1,138 @@ -#ifndef MACE_COMMON_LOGGING_H_ -#define MACE_COMMON_LOGGING_H_ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// -#ifdef __ANDROID__ -#include -#else -#include -#endif +#ifndef MACE_CORE_LOGGING_H_ +#define MACE_CORE_LOGGING_H_ -namespace mace { +#include +#include +#include + +#include "mace/core/integral_types.h" -const int FATAL = 0; -const int ERROR = 1; -const int WARN = 2; -const int INFO = 3; -const int DEBUG = 4; -const int VERBOSE = 5; +#undef ERROR + +namespace mace { +const int INFO = 0; // base_logging::INFO; +const int WARNING = 1; // base_logging::WARNING; +const int ERROR = 2; // base_logging::ERROR; +const int FATAL = 3; // base_logging::FATAL; +const int NUM_SEVERITIES = 4; // base_logging::NUM_SEVERITIES; namespace internal { -const char *kTag = "MACE"; +using std::string; + +inline void MakeStringInternal(std::stringstream& /*ss*/) {} + +template +inline void MakeStringInternal(std::stringstream& ss, const T& t) { + ss << t; +} + +template +inline void +MakeStringInternal(std::stringstream& ss, const T& t, const Args&... args) { + MakeStringInternal(ss, t); + MakeStringInternal(ss, args...); +} + +template +string MakeString(const Args&... args) { + std::stringstream ss; + MakeStringInternal(ss, args...); + return string(ss.str()); +} + +// Specializations for already-a-string types. +template <> +inline string MakeString(const string& str) { + return str; +} +inline string MakeString(const char* c_str) { + return string(c_str); +} + +class LogMessage : public std::basic_ostringstream { + public: + LogMessage(const char* fname, int line, int severity); + ~LogMessage(); + + // Returns the minimum log level for VLOG statements. + // E.g., if MinVLogLevel() is 2, then VLOG(2) statements will produce output, + // but VLOG(3) will not. Defaults to 0. + static int64 MinVLogLevel(); + protected: + void GenerateLogMessage(); -#ifdef __ANDROID__ + private: + const char* fname_; + int line_; + int severity_; +}; +// LogMessageFatal ensures the process will exit in failure after +// logging this message. +class LogMessageFatal : public LogMessage { + public: + LogMessageFatal(const char* file, int line); + ~LogMessageFatal(); +}; + +#define _MACE_LOG_INFO \ + ::mace::internal::LogMessage(__FILE__, __LINE__, mace::INFO) +#define _MACE_LOG_WARNING \ + ::mace::internal::LogMessage(__FILE__, __LINE__, mace::WARNING) +#define _MACE_LOG_ERROR \ + ::mace::internal::LogMessage(__FILE__, __LINE__, mace::ERROR) #define _MACE_LOG_FATAL \ - do { \ - __android_log_print(ANDROID_LOG_FATAL, mace::internal::kTag, __VA_ARGS__); \ - abort(); \ - } while (0) - -#define _MACE_LOG_ERROR(...) \ - __android_log_print(ANDROID_LOG_ERROR, mace::internal::kTag, __VA_ARGS__) -#define _MACE_LOG_WARN(...) \ - __android_log_print(ANDROID_LOG_WARN, mace::internal::kTag, __VA_ARGS__) -#define _MACE_LOG_INFO(...) \ - __android_log_print(ANDROID_LOG_INFO, mace::internal::kTag, __VA_ARGS__) -#define _MACE_LOG_DEBUG(...) \ - __android_log_print(ANDROID_LOG_DEBUG, mace::internal::kTag, __VA_ARGS__) -#define _MACE_LOG_VERBOSE(...) \ - __android_log_print(ANDROID_LOG_VERBOSE, mace::internal::kTag, __VA_ARGS__) - - -#define LOG(severity, ...) _MACE_LOG_##severity(__VA_ARGS__) - -#else // Non Android, just for tests - -// TODO(heliangliang): Fix newline -#define LOG(severity, ...) \ - do { \ - printf(__VA_ARGS__); \ - printf("\n"); \ - } while (0) - -#endif // __ANDROID__ - -} // namespace internal -} // namespace mace - -#endif // MACE_COMMON_LOGGING_H_ + ::mace::internal::LogMessageFatal(__FILE__, __LINE__) + +#define _MACE_LOG_QFATAL _MACE_LOG_FATAL + +#define LOG(severity) _MACE_LOG_##severity + +#ifdef IS_MOBILE_PLAMACEORM +// Turn VLOG off when under mobile devices for considerations of binary size. +#define VLOG_IS_ON(lvl) ((lvl) <= 0) +#else +// Otherwise, Set MACE_CPP_MIN_VLOG_LEVEL environment to update minimum log level +// of VLOG +#define VLOG_IS_ON(lvl) \ + ((lvl) <= ::mace::internal::LogMessage::MinVLogLevel()) +#endif + +#define VLOG(lvl) \ + if (VLOG_IS_ON(lvl)) \ + ::mace::internal::LogMessage(__FILE__, __LINE__, mace::INFO) + +// CHECK dies with a fatal error if condition is not true. It is *not* +// controlled by NDEBUG, so the check will be executed regardless of +// compilation mode. Therefore, it is safe to do things like: +// CHECK(fp->Write(x) == 4) +#define CHECK(condition) \ + if (!(condition)) \ + LOG(FATAL) << "Check failed: " #condition " " + +#define REQUIRE(condition, ...) \ + if (!(condition)) \ + LOG(FATAL) << "Check failed: " #condition " " << ::mace::internal::MakeString(__VA_ARGS__) + +template +T&& CheckNotNull(const char* file, int line, const char* exprtext, T&& t) { + if (t == nullptr) { + LogMessageFatal(file, line) << string(exprtext); + } + return std::forward(t); +} + +#define CHECK_NOTNULL(val) \ + ::mace::internal::CheckNotNull(__FILE__, __LINE__, \ + "'" #val "' Must be non NULL", (val)) + +} // namespace internal +} // namespace mace + +#endif // MACE_CORE_LOGGING_H_ diff --git a/mace/core/net.cc b/mace/core/net.cc new file mode 100644 index 0000000000000000000000000000000000000000..96ee4656f0e574e6bdfbb2eeac68b6ef066a54f4 --- /dev/null +++ b/mace/core/net.cc @@ -0,0 +1,58 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/core/net.h" + +namespace mace { + +NetBase::NetBase(const std::shared_ptr &net_def, + Workspace *ws, + DeviceType type) { + +} + + +SimpleNet::SimpleNet(const std::shared_ptr &net_def, + Workspace *ws, + DeviceType type) : NetBase(net_def, ws, type) { + VLOG(1) << "Constructing SimpleNet " << net_def->name(); + for (int idx = 0; idx < net_def->op_size(); ++idx) { + const auto& operator_def = net_def->op(idx); + VLOG(1) << "Creating operator " << operator_def.name() << ":" + << operator_def.type(); + std::unique_ptr op {nullptr}; + OperatorDef temp_def(operator_def); + op = CreateOperator(temp_def, ws, type); + operators_.emplace_back(std::move(op)); + } +} +bool SimpleNet::Run() { + VLOG(1) << "Running net " << name_; + for (auto& op : operators_) { + VLOG(1) << "Running operator " << op->debug_def().name() << "(" + << op->debug_def().type() << ")."; + if (!op->Run()) { + LOG(ERROR) << "Operator failed: " << ProtoDebugString(op->debug_def()); + return false; + } + } +} + +unique_ptr CreateNet(const NetDef& net_def, + Workspace* ws, + DeviceType type) { + std::shared_ptr tmp_net_def(new NetDef(net_def)); + return CreateNet(tmp_net_def, ws, type); +} + +unique_ptr CreateNet( + const std::shared_ptr& net_def, + Workspace* ws, + DeviceType type) { + unique_ptr net(new SimpleNet(net_def, ws, type)); + return net; +} + + +} // namespace mace \ No newline at end of file diff --git a/mace/core/net.h b/mace/core/net.h new file mode 100644 index 0000000000000000000000000000000000000000..37df69de1649d3df2f6c1ed9ea0b6622c5c9eb8a --- /dev/null +++ b/mace/core/net.h @@ -0,0 +1,52 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#ifndef MACE_CORE_NET_H_ +#define MACE_CORE_NET_H_ + +#include "mace/core/common.h" +#include "mace/proto/mace.pb.h" +#include "mace/core/operator.h" +#include "mace/core/workspace.h" + +namespace mace { + +class NetBase { + public: + NetBase(const std::shared_ptr &net_def, Workspace* ws, DeviceType type); + virtual ~NetBase() noexcept {} + + virtual bool Run() = 0; + + const string &Name() const { + return name_; + } + + protected: + string name_; + + DISABLE_COPY_AND_ASSIGN(NetBase); +}; + +class SimpleNet : public NetBase { + public: + SimpleNet(const std::shared_ptr& net_def, Workspace* ws, DeviceType type); + + virtual bool Run() override; + + protected: + vector > operators_; + + DISABLE_COPY_AND_ASSIGN(SimpleNet); +}; + +unique_ptr CreateNet(const NetDef& net_def, Workspace* ws, DeviceType type); +unique_ptr CreateNet( + const std::shared_ptr& net_def, + Workspace* ws, + DeviceType type); + +} // namespace mace + +#endif // MACE_CORE_NET_H_ diff --git a/mace/core/operator.cc b/mace/core/operator.cc new file mode 100644 index 0000000000000000000000000000000000000000..a233b273859d8fa90a493e14c804281418d6ee2c --- /dev/null +++ b/mace/core/operator.cc @@ -0,0 +1,36 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/core/operator.h" + +namespace mace { + +std::map* gDeviceTypeRegistry() { + static std::map g_device_type_registry; + return &g_device_type_registry; +} + +MACE_DEFINE_REGISTRY( + CPUOperatorRegistry, + OperatorBase, + const OperatorDef&, + Workspace*); +MACE_REGISTER_DEVICE_TYPE(DeviceType::CPU, CPUOperatorRegistry); + +unique_ptr CreateOperator( + const OperatorDef& operator_def, + Workspace* ws, + DeviceType type) { + OperatorRegistry* registry = gDeviceTypeRegistry()->at(type); + return registry->Create(operator_def.type(), operator_def, ws); +} + + +OperatorBase::OperatorBase(const OperatorDef &operator_def, Workspace *ws) + : operator_ws_(ws), + operator_def_(std::make_shared(operator_def)) { +} + + +} // namespace mace \ No newline at end of file diff --git a/mace/core/operator.h b/mace/core/operator.h new file mode 100644 index 0000000000000000000000000000000000000000..b079bdac4f6fb88fcf702af640dc6d4e9110a142 --- /dev/null +++ b/mace/core/operator.h @@ -0,0 +1,161 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#ifndef MACE_CORE_OPERATOR_H +#define MACE_CORE_OPERATOR_H + +#include "mace/core/proto_utils.h" +#include "mace/core/common.h" +#include "mace/proto/mace.pb.h" +#include "mace/core/tensor.h" +#include "mace/core/registry.h" +#include "mace/core/workspace.h" + +namespace mace { + +class OperatorBase { + public: + explicit OperatorBase(const OperatorDef &operator_def, Workspace *ws); + virtual ~OperatorBase() noexcept {} + + inline bool HasArgument(const string &name) const { + REQUIRE(operator_def_, "operator_def was null!"); + return ArgumentHelper::HasArgument(*operator_def_, name); + } + template + inline T GetSingleArgument(const string &name, const T &default_value) const { + REQUIRE(operator_def_, "operator_def was null!"); + return ArgumentHelper::GetSingleArgument( + *operator_def_, name, default_value); + } + template + inline bool HasSingleArgumentOfType(const string &name) const { + REQUIRE(operator_def_, "operator_def was null!"); + return ArgumentHelper::HasSingleArgumentOfType( + *operator_def_, name); + } + template + inline vector GetRepeatedArgument( + const string &name, + const vector &default_value = {}) const { + REQUIRE(operator_def_, "operator_def was null!"); + return ArgumentHelper::GetRepeatedArgument( + *operator_def_, name, default_value); + } + + inline const Tensor *Input(int idx) { + CHECK(idx < inputs_.size()); + return inputs_[idx]; + } + + inline Tensor *Output(int idx) { + return outputs_[idx]; + } + + inline int InputSize() { return inputs_.size(); } + inline int OutputSize() { return outputs_.size(); } + inline const vector &Inputs() const { return inputs_; } + inline const vector &Outputs() { return outputs_; } + + virtual bool Run() { + MACE_NOT_IMPLEMENTED; + return false; + } + + inline const OperatorDef &debug_def() const { + REQUIRE(has_debug_def(), "operator_def was null!"); + return *operator_def_; + } + + inline void set_debug_def( + const std::shared_ptr &operator_def) { + operator_def_ = operator_def; + } + + inline bool has_debug_def() const { + return operator_def_ != nullptr; + } + + protected: + Workspace *operator_ws_; + std::shared_ptr operator_def_; + vector inputs_; + vector outputs_; + + DISABLE_COPY_AND_ASSIGN(OperatorBase); +}; + +template +class Operator : public OperatorBase { + public: + explicit Operator(const OperatorDef &operator_def, Workspace *ws) + : OperatorBase(operator_def, ws) { + for (const string &input_str : operator_def.input()) { + const Tensor *tensor = ws->GetTensor(input_str); + REQUIRE( + tensor != nullptr, + "op ", + operator_def.type(), + ": Encountered a non-existing input tensor: ", + input_str); + inputs_.push_back(tensor); + } + + for (const string &output_str : operator_def.output()) { + outputs_.push_back(CHECK_NOTNULL(ws->CreateTensor(output_str, + DeviceContext::alloctor(), + DataTypeToEnum::v()))); + } + } + virtual bool Run() { + MACE_NOT_IMPLEMENTED; + return false; + } + ~Operator() noexcept override {} +}; + +typedef Registry + OperatorRegistry; +typedef Registry *( + *RegistryFunction)(); +std::map *gDeviceTypeRegistry(); + +struct DeviceTypeRegisterer { + explicit DeviceTypeRegisterer(int32_t type, RegistryFunction func) { + if (gDeviceTypeRegistry()->count(type)) { + LOG(ERROR) << "Device type " << type + << "registered twice. This should not happen. Did you have " + "duplicated numbers assigned to different devices?"; + std::exit(1); + } + // Calling the registry function to get the actual registry pointer. + gDeviceTypeRegistry()->emplace(type, func()); + } +}; + +#define MACE_REGISTER_DEVICE_TYPE(type, registry_function) \ + namespace { \ + static DeviceTypeRegisterer MACE_ANONYMOUS_VARIABLE( \ + DeviceType)(type, ®istry_function); \ + } + +MACE_DECLARE_REGISTRY( + CPUOperatorRegistry, + OperatorBase, + const OperatorDef&, + Workspace*); + +#define REGISTER_CPU_OPERATOR_CREATOR(key, ...) \ + MACE_REGISTER_CREATOR(CPUOperatorRegistry, key, __VA_ARGS__) +#define REGISTER_CPU_OPERATOR(name, ...) \ + MACE_REGISTER_CLASS(CPUOperatorRegistry, name, __VA_ARGS__) + +unique_ptr CreateOperator( + const OperatorDef &operator_def, + Workspace *ws, + DeviceType type); + +} // namespace mace + +#endif //MACE_CORE_OPERATOR_H diff --git a/mace/core/proto_utils.cc b/mace/core/proto_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..db9af80a80fba32bfac466235e61b2afbd8cbdb4 --- /dev/null +++ b/mace/core/proto_utils.cc @@ -0,0 +1,371 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/core/proto_utils.h" + +#include +#include +#include +#include + +#include "google/protobuf/io/coded_stream.h" +#include "google/protobuf/io/zero_copy_stream_impl.h" + +#ifndef MACE_USE_LITE_PROTO +#include "google/protobuf/text_format.h" +#endif // !MACE_USE_LITE_PROTO + +namespace mace { + +bool ReadStringFromFile(const char* filename, string* str) { + std::ifstream ifs(filename, std::ios::in); + if (!ifs) { + VLOG(1) << "File cannot be opened: " << filename + << " error: " << ifs.rdstate(); + return false; + } + ifs.seekg(0, std::ios::end); + size_t n = ifs.tellg(); + str->resize(n); + ifs.seekg(0); + ifs.read(&(*str)[0], n); + return true; +} + +bool WriteStringToFile(const string& str, const char* filename) { + std::ofstream ofs(filename, std::ios::out | std::ios::trunc); + if (!ofs.is_open()) { + VLOG(1) << "File cannot be created: " << filename + << " error: " << ofs.rdstate(); + return false; + } + ofs << str; + return true; +} + +// IO-specific proto functions: we will deal with the protocol buffer lite and +// full versions differently. + +#ifdef MACE_USE_LITE_PROTO + +// Lite runtime. + +namespace { +class IfstreamInputStream : public ::google::protobuf::io::CopyingInputStream { + public: + explicit IfstreamInputStream(const string& filename) + : ifs_(filename.c_str(), std::ios::in | std::ios::binary) {} + ~IfstreamInputStream() { ifs_.close(); } + + int Read(void* buffer, int size) { + if (!ifs_) { + return -1; + } + ifs_.read(static_cast(buffer), size); + return ifs_.gcount(); + } + + private: + std::ifstream ifs_; +}; +} // namespace + +bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto) { + ::google::protobuf::io::CopyingInputStreamAdaptor stream( + new IfstreamInputStream(filename)); + stream.SetOwnsCopyingStream(true); + // Total bytes hard limit / warning limit are set to 1GB and 512MB + // respectively. + ::google::protobuf::io::CodedInputStream coded_stream(&stream); + coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20); + return proto->ParseFromCodedStream(&coded_stream); +} + +void WriteProtoToBinaryFile( + const MessageLite& /*proto*/, + const char* /*filename*/) { + LOG(FATAL) << "Not implemented yet."; +} + +#else // MACE_USE_LITE_PROTO + +// Full protocol buffer. + +using ::google::protobuf::io::FileInputStream; +using ::google::protobuf::io::FileOutputStream; +using ::google::protobuf::io::ZeroCopyInputStream; +using ::google::protobuf::io::CodedInputStream; +using ::google::protobuf::io::ZeroCopyOutputStream; +using ::google::protobuf::io::CodedOutputStream; + +bool ReadProtoFromTextFile(const char* filename, Message* proto) { + int fd = open(filename, O_RDONLY); + REQUIRE(fd != -1, "File not found: ", filename); + FileInputStream* input = new FileInputStream(fd); + bool success = google::protobuf::TextFormat::Parse(input, proto); + delete input; + close(fd); + return success; +} + +void WriteProtoToTextFile(const Message& proto, const char* filename) { + int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644); + FileOutputStream* output = new FileOutputStream(fd); + CHECK(google::protobuf::TextFormat::Print(proto, output)); + delete output; + close(fd); +} + +bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto) { +#if defined (_MSC_VER) // for MSC compiler binary flag needs to be specified + int fd = open(filename, O_RDONLY | O_BINARY); +#else + int fd = open(filename, O_RDONLY); +#endif + REQUIRE(fd != -1, "File not found: ", filename); + std::unique_ptr raw_input(new FileInputStream(fd)); + std::unique_ptr coded_input( + new CodedInputStream(raw_input.get())); + // A hack to manually allow using very large protocol buffers. + coded_input->SetTotalBytesLimit(1073741824, 536870912); + bool success = proto->ParseFromCodedStream(coded_input.get()); + coded_input.reset(); + raw_input.reset(); + close(fd); + return success; +} + +void WriteProtoToBinaryFile(const MessageLite& proto, const char* filename) { + int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644); + REQUIRE( + fd != -1, "File cannot be created: ", filename, " error number: ", errno); + std::unique_ptr raw_output(new FileOutputStream(fd)); + std::unique_ptr coded_output( + new CodedOutputStream(raw_output.get())); + CHECK(proto.SerializeToCodedStream(coded_output.get())); + coded_output.reset(); + raw_output.reset(); + close(fd); +} + +#endif // MACE_USE_LITE_PROTO + +ArgumentHelper::ArgumentHelper(const OperatorDef &def) { + for (auto &arg : def.arg()) { + if (arg_map_.count(arg.name())) { + REQUIRE( + arg.SerializeAsString() == arg_map_[arg.name()].SerializeAsString(), + "Found argument of the same name ", + arg.name(), + "but with different contents.", + ProtoDebugString(def)); + } else { + LOG(WARNING) << "Duplicated argument name found in operator def: " + << ProtoDebugString(def); + } + + arg_map_[arg.name()] = arg; + } +} + +ArgumentHelper::ArgumentHelper(const NetDef& netdef) { + for (auto& arg : netdef.arg()) { + REQUIRE( + arg_map_.count(arg.name()) == 0, + "Duplicated argument name found in net def: ", + ProtoDebugString(netdef)); + arg_map_[arg.name()] = arg; + } +} + +bool ArgumentHelper::HasArgument(const string& name) const { + return arg_map_.count(name); +} + +namespace { +// Helper function to verify that conversion between types won't loose any +// significant bit. +template +bool SupportsLosslessConversion(const InputType& value) { + return static_cast(static_cast(value)) == value; +} +} + +#define INSTANTIATE_GET_SINGLE_ARGUMENT( \ + T, fieldname, enforce_lossless_conversion) \ + template <> \ + T ArgumentHelper::GetSingleArgument( \ + const string& name, const T& default_value) const { \ + if (arg_map_.count(name) == 0) { \ + VLOG(1) << "Using default parameter value " << default_value \ + << " for parameter " << name; \ + return default_value; \ + } \ + REQUIRE( \ + arg_map_.at(name).has_##fieldname(), \ + "Argument ", \ + name, \ + " does not have the right field: expected field " #fieldname); \ + auto value = arg_map_.at(name).fieldname(); \ + if (enforce_lossless_conversion) { \ + auto supportsConversion = \ + SupportsLosslessConversion(value); \ + REQUIRE( \ + supportsConversion, \ + "Value", \ + value, \ + " of argument ", \ + name, \ + "cannot be represented correctly in a target type"); \ + } \ + return value; \ + } \ + template <> \ + bool ArgumentHelper::HasSingleArgumentOfType(const string& name) const { \ + if (arg_map_.count(name) == 0) { \ + return false; \ + } \ + return arg_map_.at(name).has_##fieldname(); \ + } + +INSTANTIATE_GET_SINGLE_ARGUMENT(float, f, false) +INSTANTIATE_GET_SINGLE_ARGUMENT(double, f, false) +INSTANTIATE_GET_SINGLE_ARGUMENT(bool, i, false) +INSTANTIATE_GET_SINGLE_ARGUMENT(int8_t, i, true) +INSTANTIATE_GET_SINGLE_ARGUMENT(int16_t, i, true) +INSTANTIATE_GET_SINGLE_ARGUMENT(int, i, true) +INSTANTIATE_GET_SINGLE_ARGUMENT(int64_t, i, true) +INSTANTIATE_GET_SINGLE_ARGUMENT(uint8_t, i, true) +INSTANTIATE_GET_SINGLE_ARGUMENT(uint16_t, i, true) +INSTANTIATE_GET_SINGLE_ARGUMENT(size_t, i, true) +INSTANTIATE_GET_SINGLE_ARGUMENT(string, s, false) +#undef INSTANTIATE_GET_SINGLE_ARGUMENT + +#define INSTANTIATE_GET_REPEATED_ARGUMENT( \ + T, fieldname, enforce_lossless_conversion) \ + template <> \ + vector ArgumentHelper::GetRepeatedArgument( \ + const string& name, const std::vector& default_value) const { \ + if (arg_map_.count(name) == 0) { \ + return default_value; \ + } \ + vector values; \ + for (const auto& v : arg_map_.at(name).fieldname()) { \ + if (enforce_lossless_conversion) { \ + auto supportsConversion = \ + SupportsLosslessConversion(v); \ + REQUIRE( \ + supportsConversion, \ + "Value", \ + v, \ + " of argument ", \ + name, \ + "cannot be represented correctly in a target type"); \ + } \ + values.push_back(v); \ + } \ + return values; \ + } + +INSTANTIATE_GET_REPEATED_ARGUMENT(float, floats, false) +INSTANTIATE_GET_REPEATED_ARGUMENT(double, floats, false) +INSTANTIATE_GET_REPEATED_ARGUMENT(bool, ints, false) +INSTANTIATE_GET_REPEATED_ARGUMENT(int8_t, ints, true) +INSTANTIATE_GET_REPEATED_ARGUMENT(int16_t, ints, true) +INSTANTIATE_GET_REPEATED_ARGUMENT(int, ints, true) +INSTANTIATE_GET_REPEATED_ARGUMENT(int64_t, ints, true) +INSTANTIATE_GET_REPEATED_ARGUMENT(uint8_t, ints, true) +INSTANTIATE_GET_REPEATED_ARGUMENT(uint16_t, ints, true) +INSTANTIATE_GET_REPEATED_ARGUMENT(size_t, ints, true) +INSTANTIATE_GET_REPEATED_ARGUMENT(string, strings, false) +#undef INSTANTIATE_GET_REPEATED_ARGUMENT + +#define MACE_MAKE_SINGULAR_ARGUMENT(T, fieldname) \ +template <> \ +Argument MakeArgument(const string& name, const T& value) { \ + Argument arg; \ + arg.set_name(name); \ + arg.set_##fieldname(value); \ + return arg; \ +} + +MACE_MAKE_SINGULAR_ARGUMENT(bool, i) +MACE_MAKE_SINGULAR_ARGUMENT(float, f) +MACE_MAKE_SINGULAR_ARGUMENT(int, i) +MACE_MAKE_SINGULAR_ARGUMENT(int64_t, i) +MACE_MAKE_SINGULAR_ARGUMENT(string, s) +#undef MACE_MAKE_SINGULAR_ARGUMENT + +template <> +Argument MakeArgument(const string& name, const MessageLite& value) { + Argument arg; + arg.set_name(name); + arg.set_s(value.SerializeAsString()); + return arg; +} + +#define MACE_MAKE_REPEATED_ARGUMENT(T, fieldname) \ +template <> \ +Argument MakeArgument(const string& name, const vector& value) { \ + Argument arg; \ + arg.set_name(name); \ + for (const auto& v : value) { \ + arg.add_##fieldname(v); \ + } \ + return arg; \ +} + +MACE_MAKE_REPEATED_ARGUMENT(float, floats) +MACE_MAKE_REPEATED_ARGUMENT(int, ints) +MACE_MAKE_REPEATED_ARGUMENT(int64_t, ints) +MACE_MAKE_REPEATED_ARGUMENT(string, strings) +#undef MACE_MAKE_REPEATED_ARGUMENT + +const Argument& GetArgument(const OperatorDef& def, const string& name) { + for (const Argument& arg : def.arg()) { + if (arg.name() == name) { + return arg; + } + } + REQUIRE(false, + "Argument named ", + name, + "does not exist in operator ", + ProtoDebugString(def)); +} + +bool GetFlagArgument( + const OperatorDef& def, + const string& name, + bool def_value) { + for (const Argument& arg : def.arg()) { + if (arg.name() == name) { + REQUIRE( + arg.has_i(), "Can't parse argument as bool: ", ProtoDebugString(arg)); + return arg.i(); + } + } + return def_value; +} + +Argument* GetMutableArgument( + const string& name, + const bool create_if_missing, + OperatorDef* def) { + for (int i = 0; i < def->arg_size(); ++i) { + if (def->arg(i).name() == name) { + return def->mutable_arg(i); + } + } + // If no argument of the right name is found... + if (create_if_missing) { + Argument* arg = def->add_arg(); + arg->set_name(name); + return arg; + } else { + return nullptr; + } +} + +} // namespace mace diff --git a/mace/core/proto_utils.h b/mace/core/proto_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..263b01b21d35da976aae17eca61637b7f7329bf3 --- /dev/null +++ b/mace/core/proto_utils.h @@ -0,0 +1,265 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#ifndef MACE_CORE_PROTO_UTILS_H_ +#define MACE_CORE_PROTO_UTILS_H_ + +#include + +#include "google/protobuf/message_lite.h" +#ifndef MACE_USE_LITE_PROTO +#include "google/protobuf/message.h" +#endif // !MACE_USE_LITE_PROTO + +#include "mace/proto/mace.pb.h" +#include "mace/core/common.h" + +namespace mace { + +using std::string; +using ::google::protobuf::MessageLite; + + +// Common interfaces that reads file contents into a string. +bool ReadStringFromFile(const char* filename, string* str); +bool WriteStringToFile(const string& str, const char* filename); + +// Common interfaces that are supported by both lite and full protobuf. +bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto); +inline bool ReadProtoFromBinaryFile(const string filename, MessageLite* proto) { + return ReadProtoFromBinaryFile(filename.c_str(), proto); +} + +void WriteProtoToBinaryFile(const MessageLite& proto, const char* filename); +inline void WriteProtoToBinaryFile(const MessageLite& proto, + const string& filename) { + return WriteProtoToBinaryFile(proto, filename.c_str()); +} + +#ifdef MACE_USE_LITE_PROTO + +inline string ProtoDebugString(const MessageLite& proto) { + return proto.SerializeAsString(); +} + +// Text format MessageLite wrappers: these functions do nothing but just +// allowing things to compile. It will produce a runtime error if you are using +// MessageLite but still want text support. +inline bool ReadProtoFromTextFile( + const char* /*filename*/, + MessageLite* /*proto*/) { + LOG(FATAL) << "If you are running lite version, you should not be " + << "calling any text-format protobuffers."; + return false; // Just to suppress compiler warning. +} +inline bool ReadProtoFromTextFile(const string filename, MessageLite* proto) { + return ReadProtoFromTextFile(filename.c_str(), proto); +} + +inline void WriteProtoToTextFile( + const MessageLite& /*proto*/, + const char* /*filename*/) { + LOG(FATAL) << "If you are running lite version, you should not be " + << "calling any text-format protobuffers."; +} +inline void WriteProtoToTextFile(const MessageLite& proto, + const string& filename) { + return WriteProtoToTextFile(proto, filename.c_str()); +} + +inline bool ReadProtoFromFile(const char* filename, MessageLite* proto) { + return (ReadProtoFromBinaryFile(filename, proto) || + ReadProtoFromTextFile(filename, proto)); +} + +inline bool ReadProtoFromFile(const string& filename, MessageLite* proto) { + return ReadProtoFromFile(filename.c_str(), proto); +} + +#else // MACE_USE_LITE_PROTO + +using ::google::protobuf::Message; + +inline string ProtoDebugString(const Message& proto) { + return proto.ShortDebugString(); +} + +bool ReadProtoFromTextFile(const char* filename, Message* proto); +inline bool ReadProtoFromTextFile(const string filename, Message* proto) { + return ReadProtoFromTextFile(filename.c_str(), proto); +} + +void WriteProtoToTextFile(const Message& proto, const char* filename); +inline void WriteProtoToTextFile(const Message& proto, const string& filename) { + return WriteProtoToTextFile(proto, filename.c_str()); +} + +// Read Proto from a file, letting the code figure out if it is text or binary. +inline bool ReadProtoFromFile(const char* filename, Message* proto) { + return (ReadProtoFromBinaryFile(filename, proto) || + ReadProtoFromTextFile(filename, proto)); +} + +inline bool ReadProtoFromFile(const string& filename, Message* proto) { + return ReadProtoFromFile(filename.c_str(), proto); +} + +#endif // MACE_USE_LITE_PROTO + +template < + class IterableInputs = std::initializer_list, + class IterableOutputs = std::initializer_list, + class IterableArgs = std::initializer_list> +OperatorDef CreateOperatorDef( + const string& type, + const string& name, + const IterableInputs& inputs, + const IterableOutputs& outputs, + const IterableArgs& args) { + OperatorDef def; + def.set_type(type); + def.set_name(name); + for (const string& in : inputs) { + def.add_input(in); + } + for (const string& out : outputs) { + def.add_output(out); + } + for (const Argument& arg : args) { + def.add_arg()->CopyFrom(arg); + } + return def; +} + +// A simplified version compared to the full CreateOperator, if you do not need +// to specify args. +template < + class IterableInputs = std::initializer_list, + class IterableOutputs = std::initializer_list> +inline OperatorDef CreateOperatorDef( + const string& type, + const string& name, + const IterableInputs& inputs, + const IterableOutputs& outputs) { + return CreateOperatorDef( + type, + name, + inputs, + outputs, + std::vector()); +} + +/** + * @brief A helper class to index into arguments. + * + * This helper helps us to more easily index into a set of arguments + * that are present in the operator. To save memory, the argument helper + * does not copy the operator def, so one would need to make sure that the + * lifetime of the OperatorDef object outlives that of the ArgumentHelper. + */ +class ArgumentHelper { + public: + template + static bool HasArgument(const Def& def, const string& name) { + return ArgumentHelper(def).HasArgument(name); + } + + template + static T GetSingleArgument( + const Def& def, + const string& name, + const T& default_value) { + return ArgumentHelper(def).GetSingleArgument(name, default_value); + } + + template + static bool HasSingleArgumentOfType(const Def& def, const string& name) { + return ArgumentHelper(def).HasSingleArgumentOfType(name); + } + + template + static vector GetRepeatedArgument( + const Def& def, + const string& name, + const std::vector& default_value = std::vector()) { + return ArgumentHelper(def).GetRepeatedArgument(name, default_value); + } + + template + static MessageType GetMessageArgument(const Def& def, const string& name) { + return ArgumentHelper(def).GetMessageArgument(name); + } + + template + static vector GetRepeatedMessageArgument( + const Def& def, + const string& name) { + return ArgumentHelper(def).GetRepeatedMessageArgument(name); + } + + explicit ArgumentHelper(const OperatorDef& def); + explicit ArgumentHelper(const NetDef& netdef); + bool HasArgument(const string& name) const; + + template + T GetSingleArgument(const string& name, const T& default_value) const; + template + bool HasSingleArgumentOfType(const string& name) const; + template + vector GetRepeatedArgument( + const string& name, + const std::vector& default_value = std::vector()) const; + + template + MessageType GetMessageArgument(const string& name) const { + REQUIRE(arg_map_.count(name), "Cannot find parameter named " + name); + MessageType message; + if (arg_map_.at(name).has_s()) { + REQUIRE( + message.ParseFromString(arg_map_.at(name).s()), + "Faild to parse content from the string"); + } else { + VLOG(1) << "Return empty message for parameter " << name; + } + return message; + } + + template + vector GetRepeatedMessageArgument(const string& name) const { + REQUIRE(arg_map_.count(name), "Cannot find parameter named " + name); + vector messages(arg_map_.at(name).strings_size()); + for (int i = 0; i < messages.size(); ++i) { + REQUIRE( + messages[i].ParseFromString(arg_map_.at(name).strings(i)), + "Faild to parse content from the string"); + } + return messages; + } + + private: + std::map arg_map_; +}; + +const Argument& GetArgument(const OperatorDef& def, const string& name); +bool GetFlagArgument( + const OperatorDef& def, + const string& name, + bool def_value = false); + +Argument* GetMutableArgument( + const string& name, + const bool create_if_missing, + OperatorDef* def); + +template +Argument MakeArgument(const string& name, const T& value); + +template +inline void AddArgument(const string& name, const T& value, OperatorDef* def) { + GetMutableArgument(name, true, def)->CopyFrom(MakeArgument(name, value)); +} + +} // namespace mace + +#endif // MACE_CORE_PROTO_UTILS_H_ diff --git a/mace/core/registry.h b/mace/core/registry.h new file mode 100644 index 0000000000000000000000000000000000000000..4064e7d501bb49b0f755adf848fbdf14b6eb7af7 --- /dev/null +++ b/mace/core/registry.h @@ -0,0 +1,120 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#ifndef MACE_CORE_REGISTRY_H_ +#define MACE_CORE_REGISTRY_H_ + +#include +#include +#include +#include +#include "mace/core/common.h" + +namespace mace { + +template +class Registry { + public: + typedef std::function (Args ...)> Creator; + + Registry() : registry_() {} + + void Register(const SrcType& key, Creator creator) { + std::lock_guard lock(register_mutex_); + REQUIRE(registry_.count(key) == 0, "Key already registered."); + registry_[key] = creator; + } + + inline bool Has(const SrcType& key) { return registry_.count(key) != 0; } + + unique_ptr Create(const SrcType& key, Args ... args) { + if (registry_.count(key) == 0) { + return nullptr; + } + return registry_[key](args...); + } + + /** + * Returns the keys currently registered as a vector. + */ + vector Keys() { + vector keys; + for (const auto& it : registry_) { + keys.push_back(it.first); + } + return keys; + } + + private: + std::map registry_; + std::mutex register_mutex_; + + DISABLE_COPY_AND_ASSIGN(Registry); +}; + +template +class Registerer { + public: + Registerer(const SrcType& key, + Registry* registry, + typename Registry::Creator creator) { + registry->Register(key, creator); + } + + template + static unique_ptr DefaultCreator(Args ... args) { + return std::unique_ptr(new DerivedType(args...)); + } +}; + +#define MACE_CONCATENATE_IMPL(s1, s2) s1##s2 +#define MACE_CONCATENATE(s1, s2) MACE_CONCATENATE_IMPL(s1, s2) +#ifdef __COUNTER__ +#define MACE_ANONYMOUS_VARIABLE(str) MACE_CONCATENATE(str, __COUNTER__) +#else +#define MACE_ANONYMOUS_VARIABLE(str) MACE_CONCATENATE(str, __LINE__) +#endif + +#define MACE_DECLARE_TYPED_REGISTRY(RegistryName, SrcType, ObjectType, ...) \ + Registry* RegistryName(); \ + typedef Registerer \ + Registerer##RegistryName; + +#define MACE_DEFINE_TYPED_REGISTRY(RegistryName, SrcType, ObjectType, ...) \ + Registry* RegistryName() { \ + static Registry* registry = \ + new Registry(); \ + return registry; \ + } + +#define MACE_DECLARE_REGISTRY(RegistryName, ObjectType, ...) \ + MACE_DECLARE_TYPED_REGISTRY( \ + RegistryName, std::string, ObjectType, ##__VA_ARGS__) + +#define MACE_DEFINE_REGISTRY(RegistryName, ObjectType, ...) \ + MACE_DEFINE_TYPED_REGISTRY( \ + RegistryName, std::string, ObjectType, ##__VA_ARGS__) + +#define MACE_REGISTER_TYPED_CREATOR(RegistryName, key, ...) \ + namespace { \ + static Registerer##RegistryName MACE_ANONYMOUS_VARIABLE(g_##RegistryName)( \ + key, RegistryName(), __VA_ARGS__); + +#define MACE_REGISTER_TYPED_CLASS(RegistryName, key, ...) \ + namespace { \ + static Registerer##RegistryName MACE_ANONYMOUS_VARIABLE(g_##RegistryName)( \ + key, \ + RegistryName(), \ + Registerer##RegistryName::DefaultCreator<__VA_ARGS__>); \ + } + +#define MACE_REGISTER_CREATOR(RegistryName, key, ...) \ + MACE_REGISTER_TYPED_CREATOR(RegistryName, #key, __VA_ARGS__) + +#define MACE_REGISTER_CLASS(RegistryName, key, ...) \ + MACE_REGISTER_TYPED_CLASS(RegistryName, #key, __VA_ARGS__) + +} // namespace mace + +#endif // MACE_CORE_REGISTRY_H_ diff --git a/mace/core/tensor.h b/mace/core/tensor.h new file mode 100644 index 0000000000000000000000000000000000000000..d1059971a249dc9194a899665ac5d5f25ff94a92 --- /dev/null +++ b/mace/core/tensor.h @@ -0,0 +1,139 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#ifndef MACE_CORE_TENSOR_H_ +#define MACE_CORE_TENSOR_H_ + +#include "mace/core/common.h" +#include "mace/proto/mace.pb.h" +#include "mace/core/allocator.h" +#include "mace/core/types.h" + +namespace mace { + +#define SINGLE_ARG(...) __VA_ARGS__ +#define CASE(TYPE, STMTS) \ + case DataTypeToEnum::value: { \ + typedef TYPE T; \ + STMTS; \ + break; \ + } + +#define CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, INVALID, DEFAULT) \ + switch (TYPE_ENUM) { \ + CASE(float, SINGLE_ARG(STMTS)) \ + CASE(double, SINGLE_ARG(STMTS)) \ + CASE(int32, SINGLE_ARG(STMTS)) \ + CASE(uint8, SINGLE_ARG(STMTS)) \ + CASE(uint16, SINGLE_ARG(STMTS)) \ + CASE(int16, SINGLE_ARG(STMTS)) \ + CASE(int8, SINGLE_ARG(STMTS)) \ + CASE(string, SINGLE_ARG(STMTS)) \ + CASE(int64, SINGLE_ARG(STMTS)) \ + CASE(bool, SINGLE_ARG(STMTS)) \ + case DT_INVALID: \ + INVALID; \ + break; \ + default: \ + DEFAULT; \ + break; \ + } + + +#define CASES(TYPE_ENUM, STMTS) \ + CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, LOG(FATAL) << "Type not set"; \ + , LOG(FATAL) << "Unexpected type: " << TYPE_ENUM;) + + +class Tensor { + public: + Tensor() + : alloc_(cpu_allocator()), + size_(0), dtype_(DT_FLOAT), data_(nullptr) {}; + + Tensor(Allocator* a, DataType type) + : alloc_(a), size_(0), dtype_(DT_FLOAT), data_(nullptr) {}; + + ~Tensor() { + if (alloc_ && data_.get()) { + data_.reset(); + } + } + + inline DataType dtype() const { return dtype_; } + + inline const vector& shape() const { return shape_; } + + inline int64 NumElements() const { + return std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies()); + } + + inline TIndex dim_size() { return shape_.size(); } + + inline TIndex size() const { return size_; } + + inline const void* raw_data() const { + CHECK(data_.get() || size_ == 0); + return data_.get(); + } + + template + inline const T* data() const { + REQUIRE( + data_.get() || size_ == 0, + "The tensor is of non-zero shape, but its data is not allocated yet. "); + return static_cast(data_.get()); + } + + void Deleter(void* data) { + alloc_->Delete(data); + } + + inline void* raw_mutable_data() { + if (data_.get() || size_ == 0) { + return data_.get(); + } else { + CASES(dtype_, data_.reset(alloc_->New(size_ * sizeof(T)), [this](void* ptr) { + alloc_->Delete(ptr); + })); + return data_.get(); + } + } + + template + inline T* mutable_data() { + if (size_ == 0 || data_.get()) { + return static_cast(data_.get()); + } + return static_cast(raw_mutable_data()); + } + + inline void Resize(const vector& shape) { + shape_ = shape; + TIndex size = NumElements(); + if (size_ != size) { + size_ = NumElements(); + data_.reset(); + } + } + + inline void ResizeLike(const Tensor& other) { + Resize(other.shape()); + } + + inline void ResizeLike(const Tensor* other) { + Resize(other->shape()); + } + + private: + Allocator* alloc_; + TIndex size_; + DataType dtype_; + std::shared_ptr data_; + vector shape_; +}; + +} // namespace tensor + +#endif //MACE_CORE_TENSOR_H_ diff --git a/mace/core/types.h b/mace/core/types.h new file mode 100644 index 0000000000000000000000000000000000000000..476fa54dafb3885d565b753ccd7bd520bbeccb9e --- /dev/null +++ b/mace/core/types.h @@ -0,0 +1,56 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#ifndef MACE_CORE_TYPES_H_ +#define MACE_CORE_TYPES_H_ + +#include "mace/core/common.h" +#include "mace/proto/mace.pb.h" + +namespace mace { + +template +struct IsValidDataType; + +template +struct DataTypeToEnum { + static_assert(IsValidDataType::value, "Specified Data Type not supported"); +}; // Specializations below + + +// EnumToDataType::Type is the type for DataType constant VALUE, e.g. +// EnumToDataType::Type is float. +template +struct EnumToDataType {}; // Specializations below + +// Template specialization for both DataTypeToEnum and EnumToDataType. +#define MATCH_TYPE_AND_ENUM(TYPE, ENUM) \ + template <> \ + struct DataTypeToEnum { \ + static DataType v() { return ENUM; } \ + static constexpr DataType value = ENUM; \ + }; \ + template <> \ + struct IsValidDataType { \ + static constexpr bool value = true; \ + }; \ + template <> \ + struct EnumToDataType { \ + typedef TYPE Type; \ + } + +MATCH_TYPE_AND_ENUM(float, DT_FLOAT); +MATCH_TYPE_AND_ENUM(double, DT_DOUBLE); +MATCH_TYPE_AND_ENUM(int32, DT_INT32); +MATCH_TYPE_AND_ENUM(uint16, DT_UINT16); +MATCH_TYPE_AND_ENUM(uint8, DT_UINT8); +MATCH_TYPE_AND_ENUM(int16, DT_INT16); +MATCH_TYPE_AND_ENUM(int8, DT_INT8); +MATCH_TYPE_AND_ENUM(string, DT_STRING); +MATCH_TYPE_AND_ENUM(int64, DT_INT64); +MATCH_TYPE_AND_ENUM(bool, DT_BOOL); + +} // namespace mace + +#endif // MACE_CORE_TYPES_H_ diff --git a/mace/core/workspace.cc b/mace/core/workspace.cc new file mode 100644 index 0000000000000000000000000000000000000000..20ff5727d937c44566874f9d7b8b3df718153fa4 --- /dev/null +++ b/mace/core/workspace.cc @@ -0,0 +1,51 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/core/common.h" +#include "mace/core/workspace.h" + +namespace mace { + +vector Workspace::Tensors() const { + vector names; + for (auto& entry : tensor_map_) { + names.push_back(entry.first); + } + return names; +} + +Tensor* Workspace::CreateTensor(const string& name, Allocator* alloc, DataType type) { + if (HasTensor(name)) { + VLOG(1) << "Tensor " << name << " already exists. Skipping."; + } else { + VLOG(1) << "Creating Tensor " << name; + tensor_map_[name] = unique_ptr(new Tensor(alloc, type)); + } + return GetTensor(name); +} + +bool Workspace::RemoveTensor(const string& name) { + auto it = tensor_map_.find(name); + if (it != tensor_map_.end()) { + VLOG(1) << "Removing blob " << name << " from this workspace."; + tensor_map_.erase(it); + return true; + } + return false; +} + +const Tensor* Workspace::GetTensor(const string& name) const { + if (tensor_map_.count(name)) { + return tensor_map_.at(name).get(); + } + return nullptr; +} + +Tensor* Workspace::GetTensor(const string& name) { + return const_cast(static_cast(this)->GetTensor(name)); +} + +bool RunNet(); + +} // namespace mace \ No newline at end of file diff --git a/mace/core/workspace.h b/mace/core/workspace.h new file mode 100644 index 0000000000000000000000000000000000000000..3f16077fbd8a7940030d17a6f0670f9030e56004 --- /dev/null +++ b/mace/core/workspace.h @@ -0,0 +1,44 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#ifndef MACE_CORE_WORKSPACE_H_ +#define MACE_CORE_WORKSPACE_H_ + + +#include "mace/core/common.h" +#include "mace/core/tensor.h" +#include "mace/proto/mace.pb.h" + +namespace mace { + +class Workspace { + public: + typedef std::map> TensorMap; + + Workspace() {} + + vector Tensors() const; + + Tensor* CreateTensor(const string& name, Allocator* alloc, DataType type); + + bool RemoveTensor(const string& name); + + inline bool HasTensor(const string& name) const { + return tensor_map_.count(name); + } + + const Tensor* GetTensor(const string& name) const; + + Tensor* GetTensor(const string& name); + + private: + TensorMap tensor_map_; + + DISABLE_COPY_AND_ASSIGN(Workspace); +}; + +} // namespace mace + + +#endif // MACE_CORE_WORKSPACE_H_ diff --git a/mace/ops/BUILD b/mace/ops/BUILD index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..1f06a345ee83fb4a196eb6c1cd987a6ba7b38069 100644 --- a/mace/ops/BUILD +++ b/mace/ops/BUILD @@ -0,0 +1,28 @@ +# Description: +# Mace operators. +# +package( + default_visibility = ["//visibility:public"], +) + + +licenses(["notice"]) # Apache 2.0 + +cc_library( + name = "op", + srcs = ["relu.cc"], + hdrs = glob(["*.h"]), + deps = [ + "//mace/proto:cc_proto", + "//mace/core:core", + ], +) + +cc_test( + name = "op_test", + srcs = ["relu_test.cc",], + deps = [ + "@gtest//:gtest", + ":op", + ], +) \ No newline at end of file diff --git a/mace/ops/relu.cc b/mace/ops/relu.cc new file mode 100644 index 0000000000000000000000000000000000000000..a31c25ddf96b8fbc5ffd1233690576f3b4d5a57d --- /dev/null +++ b/mace/ops/relu.cc @@ -0,0 +1,28 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/ops/relu.h" +#include "mace/proto/mace.pb.h" + +namespace mace { + +template <> +bool ReluOp::Run() { + const Tensor* X = Input(0); + Tensor* Y = Output(0); + Y->ResizeLike(X); + + const float* Xdata = X-> data(); + float* Ydata = Y->mutable_data(); + for (int i = 0; i < X->size(); ++i) { + Ydata[i] = std::max(Xdata[i], 0.f); + VLOG(0) << i << ": " << Xdata[i] << " " << Ydata[i]; + } + + return true; +} + +REGISTER_CPU_OPERATOR(Relu, ReluOp); + +} // namespace mace \ No newline at end of file diff --git a/mace/ops/relu.h b/mace/ops/relu.h new file mode 100644 index 0000000000000000000000000000000000000000..b658050652f4a04d743ba80a078f14ebb2219f9c --- /dev/null +++ b/mace/ops/relu.h @@ -0,0 +1,22 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#ifndef MACE_OPERATORS_RELU_H_ +#define MACE_OPERATORS_RELU_H_ + +#include "mace/core/operator.h" + +namespace mace { + +template +class ReluOp : public Operator { + public: + ReluOp(const OperatorDef &operator_def, Workspace *ws) + : Operator(operator_def, ws) {} + bool Run() override; +}; + +} // namespace mace + +#endif // MACE_OPERATORS_RELU_H_ diff --git a/mace/ops/relu_test.cc b/mace/ops/relu_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..209fe83a890081163ccabfd7091e1537b3a230e6 --- /dev/null +++ b/mace/ops/relu_test.cc @@ -0,0 +1,52 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "gtest/gtest.h" + +#include "mace/core/operator.h" +#include "mace/core/net.h" + +using namespace mace; + +TEST(ReluTest, Relu) { + OperatorRegistry* registry = gDeviceTypeRegistry()->at(DeviceType::CPU); + vector registry_keys = registry->Keys(); + for (auto& key: registry_keys) { + VLOG(0) << "registry_op: " << key; + } + + // Construct graph + OperatorDef op_def; + op_def.add_input("Input0"); + op_def.add_output("Output0"); + op_def.set_name("ReluTest"); + op_def.set_type("Relu"); + auto arg = op_def.add_arg(); + arg->set_name("arg0"); + arg->set_f(1.5); + + NetDef net_def; + net_def.set_name("NetTest"); + net_def.add_op()->CopyFrom(op_def); + + VLOG(0) << net_def.DebugString(); + + // Create workspace and input tensor + Workspace ws; + Tensor* input = ws.CreateTensor("Input0", cpu_allocator(), DataType::DT_FLOAT); + input->Resize({2,3}); + float* input_data = input->mutable_data(); + for (int i = 0; i < 6; ++i) { + input_data[i] = i-3; + } + + // Create Net & run + auto net = CreateNet(net_def, &ws, DeviceType::CPU); + net->Run(); + + // Create Op & run + auto op = CreateOperator(op_def, &ws, DeviceType::CPU); + ASSERT_FLOAT_EQ(1.5f, op->GetSingleArgument("arg0", 1.0f)); + +} \ No newline at end of file diff --git a/mace/proto/BUILD b/mace/proto/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..bb2512bcf34f95a703760010e7592bc8f1ec4e39 --- /dev/null +++ b/mace/proto/BUILD @@ -0,0 +1,16 @@ +# Description: +# mace proto. +# +package( + default_visibility = ["//visibility:public"], +) + +proto_library( + name = "proto", + srcs = ["mace.proto"], +) + +cc_proto_library( + name = "cc_proto", + deps = [":proto"], +) \ No newline at end of file diff --git a/mace/proto/mace.proto b/mace/proto/mace.proto new file mode 100644 index 0000000000000000000000000000000000000000..10c37f12267b996a30265a40cbffbf89ef01bb2a --- /dev/null +++ b/mace/proto/mace.proto @@ -0,0 +1,73 @@ +syntax = "proto2"; + +package mace; + +enum DeviceType { + CPU = 0; // In default, we will use CPU. + GPU = 1; +} + +enum DataType { + DT_INVALID = 0; + + // Data types that all computation devices are expected to be + // capable to support. + DT_FLOAT = 1; + DT_DOUBLE = 2; + DT_INT32 = 3; + DT_UINT8 = 4; + DT_INT16 = 5; + DT_INT8 = 6; + DT_STRING = 7; + DT_INT64 = 8; + DT_UINT16 = 9; + DT_BOOL = 10; +} + +message TensorProto { + // The dimensions in the tensor. + repeated int64 dims = 1; + optional DataType data_type = 2 [default = DT_FLOAT]; + // For float + repeated float float_data = 3 [packed = true]; + // For int32, uint8, int8, uint16, int16, bool, and float16 + // Note about float16: in storage we will basically convert float16 byte-wise + // to unsigned short and then store them in the int32_data field. + repeated int32 int32_data = 4 [packed = true]; + // For bytes + optional bytes byte_data = 5; + // For strings + repeated bytes string_data = 6; + // For double + repeated double double_data = 9 [packed = true]; + // For int64 + repeated int64 int64_data = 10 [packed = true]; + // Optionally, a name for the tensor. + optional string name = 7; +} + +message Argument { + optional string name = 1; + optional float f = 2; + optional int64 i = 3; + optional bytes s = 4; + repeated float floats = 5; + repeated int64 ints = 6; + repeated bytes strings = 7; +} + +message OperatorDef { + repeated string input = 1; + repeated string output = 2; + optional string name = 3; + optional string type = 4; + repeated Argument arg = 5; +} + +message NetDef { + optional string name = 1; + repeated OperatorDef op = 2; + optional string version = 3; + repeated Argument arg = 4; + repeated TensorProto tensors = 5; +} \ No newline at end of file diff --git a/mace/third_party/gtest.BUILD b/mace/third_party/gtest.BUILD new file mode 100644 index 0000000000000000000000000000000000000000..be8010ba780f6b12f88cdbecd3263793c006b5f6 --- /dev/null +++ b/mace/third_party/gtest.BUILD @@ -0,0 +1,19 @@ +# Description: +# Google test + +licenses(["notice"]) + +cc_library( + name = "gtest", + srcs = glob( + ["src/*.cc"], + exclude = ["src/gtest-all.cc"] + ), + hdrs = glob([ + "include/**/*.h", + "src/*.h" + ]), + copts = ["-Iexternal/gtest/include"], + linkopts = ["-pthread"], + visibility = ["//visibility:public"], +) \ No newline at end of file diff --git a/mace/utils/BUILD b/mace/utils/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..29615ad093fd8cb50cc0f3c88427175bab4e3634 --- /dev/null +++ b/mace/utils/BUILD @@ -0,0 +1,18 @@ +# Description: +# Mace utils. +# +package( + default_visibility = ["//visibility:public"], +) + + +licenses(["notice"]) # Apache 2.0 + +cc_library( + name = "utils", + srcs = glob(["*.cc"]), + hdrs = glob(["*.h"]), + deps = [ + "//mace/proto:cc_proto", + ], +)