提交 578b382a 编写于 作者: 吴承辉

Merge branch 'style' into 'master'

Fix Google Style

See merge request !43
...@@ -7,13 +7,9 @@ ...@@ -7,13 +7,9 @@
namespace mace { namespace mace {
static std::unique_ptr<CPUAllocator> g_cpu_allocator(new CPUAllocator()); static std::unique_ptr<CPUAllocator> g_cpu_allocator(new CPUAllocator());
CPUAllocator* cpu_allocator() { CPUAllocator* cpu_allocator() { return g_cpu_allocator.get(); }
return g_cpu_allocator.get();
}
void SetCPUAllocator(CPUAllocator* alloc) { void SetCPUAllocator(CPUAllocator* alloc) { g_cpu_allocator.reset(alloc); }
g_cpu_allocator.reset(alloc);
}
Allocator* GetDeviceAllocator(DeviceType type) { Allocator* GetDeviceAllocator(DeviceType type) {
switch (type) { switch (type) {
...@@ -26,4 +22,4 @@ Allocator* GetDeviceAllocator(DeviceType type) { ...@@ -26,4 +22,4 @@ Allocator* GetDeviceAllocator(DeviceType type) {
return nullptr; return nullptr;
} }
} // namespace mace } // namespace mace
...@@ -39,7 +39,7 @@ class Allocator { ...@@ -39,7 +39,7 @@ class Allocator {
} }
}; };
class CPUAllocator: public Allocator { class CPUAllocator : public Allocator {
public: public:
~CPUAllocator() override {} ~CPUAllocator() override {}
void* New(size_t nbytes) override { void* New(size_t nbytes) override {
...@@ -55,9 +55,7 @@ class CPUAllocator: public Allocator { ...@@ -55,9 +55,7 @@ class CPUAllocator: public Allocator {
return data; return data;
} }
void Delete(void* data) override { void Delete(void* data) override { free(data); }
free(data);
}
void CopyBytes(void* dst, const void* src, size_t size) override { void CopyBytes(void* dst, const void* src, size_t size) override {
memcpy(dst, src, size); memcpy(dst, src, size);
...@@ -85,6 +83,6 @@ struct DeviceContext<DeviceType::NEON> { ...@@ -85,6 +83,6 @@ struct DeviceContext<DeviceType::NEON> {
Allocator* GetDeviceAllocator(DeviceType type); Allocator* GetDeviceAllocator(DeviceType type);
} // namespace mace } // namespace mace
#endif // MACE_CORE_ALLOCATOR_H_ #endif // MACE_CORE_ALLOCATOR_H_
...@@ -5,12 +5,12 @@ ...@@ -5,12 +5,12 @@
#ifndef MACE_CORE_COMMON_H_ #ifndef MACE_CORE_COMMON_H_
#define MACE_CORE_COMMON_H_ #define MACE_CORE_COMMON_H_
#include <set> #include <algorithm>
#include <map> #include <map>
#include <string>
#include <memory> #include <memory>
#include <set>
#include <string>
#include <vector> #include <vector>
#include <algorithm>
#include "mace/core/logging.h" #include "mace/core/logging.h"
...@@ -24,9 +24,9 @@ typedef int64_t index_t; ...@@ -24,9 +24,9 @@ typedef int64_t index_t;
// Disable the copy and assignment operator for a class. // Disable the copy and assignment operator for a class.
#ifndef DISABLE_COPY_AND_ASSIGN #ifndef DISABLE_COPY_AND_ASSIGN
#define DISABLE_COPY_AND_ASSIGN(classname) \ #define DISABLE_COPY_AND_ASSIGN(classname) \
private: \ private: \
classname(const classname&) = delete; \ classname(const classname&) = delete; \
classname& operator=(const classname&) = delete classname& operator=(const classname&) = delete
#endif #endif
...@@ -35,4 +35,4 @@ private: \ ...@@ -35,4 +35,4 @@ private: \
// TODO: need to fine tune this // TODO: need to fine tune this
#define kCostPerGroup 1024000000 #define kCostPerGroup 1024000000
#endif // MACE_CORE_COMMON_H_ #endif // MACE_CORE_COMMON_H_
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
// Copyright (c) 2017 XiaoMi All rights reserved. // Copyright (c) 2017 XiaoMi All rights reserved.
// //
#include "mace/core/logging.h" #include "mace/core/logging.h"
#include <stdlib.h> #include <stdlib.h>
...@@ -62,11 +61,11 @@ void LogMessage::GenerateLogMessage() { ...@@ -62,11 +61,11 @@ void LogMessage::GenerateLogMessage() {
#else #else
void LogMessage::GenerateLogMessage() { void LogMessage::GenerateLogMessage() {
fprintf(stderr, "%c %s:%d] %s\n", "IWEF"[severity_], fname_, line_, str().c_str()); fprintf(stderr, "%c %s:%d] %s\n", "IWEF"[severity_], fname_, line_,
str().c_str());
} }
#endif #endif
namespace { namespace {
// Parse log level (int64_t) from environment variable (char*) // Parse log level (int64_t) from environment variable (char*)
......
...@@ -5,8 +5,8 @@ ...@@ -5,8 +5,8 @@
#ifndef MACE_CORE_LOGGING_H_ #ifndef MACE_CORE_LOGGING_H_
#define MACE_CORE_LOGGING_H_ #define MACE_CORE_LOGGING_H_
#include <sstream>
#include <limits> #include <limits>
#include <sstream>
#include <string> #include <string>
#undef ERROR #undef ERROR
...@@ -30,8 +30,8 @@ inline void MakeStringInternal(std::stringstream& ss, const T& t) { ...@@ -30,8 +30,8 @@ inline void MakeStringInternal(std::stringstream& ss, const T& t) {
} }
template <typename T, typename... Args> template <typename T, typename... Args>
inline void inline void MakeStringInternal(std::stringstream& ss, const T& t,
MakeStringInternal(std::stringstream& ss, const T& t, const Args&... args) { const Args&... args) {
MakeStringInternal(ss, t); MakeStringInternal(ss, t);
MakeStringInternal(ss, args...); MakeStringInternal(ss, args...);
} }
...@@ -48,9 +48,7 @@ template <> ...@@ -48,9 +48,7 @@ template <>
inline string MakeString(const string& str) { inline string MakeString(const string& str) {
return str; return str;
} }
inline string MakeString(const char* c_str) { inline string MakeString(const char* c_str) { return string(c_str); }
return string(c_str);
}
class LogMessage : public std::basic_ostringstream<char> { class LogMessage : public std::basic_ostringstream<char> {
public: public:
...@@ -85,8 +83,7 @@ class LogMessageFatal : public LogMessage { ...@@ -85,8 +83,7 @@ class LogMessageFatal : public LogMessage {
::mace::internal::LogMessage(__FILE__, __LINE__, mace::WARNING) ::mace::internal::LogMessage(__FILE__, __LINE__, mace::WARNING)
#define _MACE_LOG_ERROR \ #define _MACE_LOG_ERROR \
::mace::internal::LogMessage(__FILE__, __LINE__, mace::ERROR) ::mace::internal::LogMessage(__FILE__, __LINE__, mace::ERROR)
#define _MACE_LOG_FATAL \ #define _MACE_LOG_FATAL ::mace::internal::LogMessageFatal(__FILE__, __LINE__)
::mace::internal::LogMessageFatal(__FILE__, __LINE__)
#define _MACE_LOG_QFATAL _MACE_LOG_FATAL #define _MACE_LOG_QFATAL _MACE_LOG_FATAL
...@@ -96,10 +93,10 @@ class LogMessageFatal : public LogMessage { ...@@ -96,10 +93,10 @@ class LogMessageFatal : public LogMessage {
// Turn VLOG off when under mobile devices for considerations of binary size. // Turn VLOG off when under mobile devices for considerations of binary size.
#define VLOG_IS_ON(lvl) ((lvl) <= 0) #define VLOG_IS_ON(lvl) ((lvl) <= 0)
#else #else
// Otherwise, Set MACE_CPP_MIN_VLOG_LEVEL environment to update minimum log level // Otherwise, Set MACE_CPP_MIN_VLOG_LEVEL environment to update minimum log
// level
// of VLOG // of VLOG
#define VLOG_IS_ON(lvl) \ #define VLOG_IS_ON(lvl) ((lvl) <= ::mace::internal::LogMessage::MinVLogLevel())
((lvl) <= ::mace::internal::LogMessage::MinVLogLevel())
#endif #endif
#define VLOG(lvl) \ #define VLOG(lvl) \
...@@ -113,16 +110,16 @@ class LogMessageFatal : public LogMessage { ...@@ -113,16 +110,16 @@ class LogMessageFatal : public LogMessage {
// MACE_CHECK(fp->Write(x) == 4) // MACE_CHECK(fp->Write(x) == 4)
// MACE_CHECK(fp->Write(x) == 4, "Write failed") // MACE_CHECK(fp->Write(x) == 4, "Write failed")
// which are not correct for MACE_ASSERT. // which are not correct for MACE_ASSERT.
#define MACE_CHECK(condition, ...) \ #define MACE_CHECK(condition, ...) \
if (!(condition)) \ if (!(condition)) \
LOG(FATAL) << "Check failed: " #condition " " \ LOG(FATAL) << "Check failed: " #condition " " \
<< ::mace::internal::MakeString(__VA_ARGS__) << ::mace::internal::MakeString(__VA_ARGS__)
#ifndef NDEBUG #ifndef NDEBUG
#define MACE_ASSERT(condition, ...) \ #define MACE_ASSERT(condition, ...) \
if (!(condition)) \ if (!(condition)) \
LOG(FATAL) << "Assert failed: " #condition " " \ LOG(FATAL) << "Assert failed: " #condition " " \
<< ::mace::internal::MakeString(__VA_ARGS__) << ::mace::internal::MakeString(__VA_ARGS__)
#else #else
#define MACE_ASSERT(condition, ...) ((void)0) #define MACE_ASSERT(condition, ...) ((void)0)
#endif #endif
...@@ -135,9 +132,9 @@ T&& CheckNotNull(const char* file, int line, const char* exprtext, T&& t) { ...@@ -135,9 +132,9 @@ T&& CheckNotNull(const char* file, int line, const char* exprtext, T&& t) {
return std::forward<T>(t); return std::forward<T>(t);
} }
#define MACE_CHECK_NOTNULL(val) \ #define MACE_CHECK_NOTNULL(val) \
::mace::internal::CheckNotNull(__FILE__, __LINE__, \ ::mace::internal::CheckNotNull(__FILE__, __LINE__, \
"'" #val "' Must be non NULL", (val)) "'" #val "' Must be non NULL", (val))
} // namespace internal } // namespace internal
} // namespace mace } // namespace mace
......
...@@ -17,5 +17,4 @@ ...@@ -17,5 +17,4 @@
#define MACE_PREDICT_TRUE(x) (x) #define MACE_PREDICT_TRUE(x) (x)
#endif #endif
#endif // MACE_CORE_MACROS_H_
#endif //MACE_CORE_MACROS_H_
...@@ -6,22 +6,19 @@ ...@@ -6,22 +6,19 @@
namespace mace { namespace mace {
NetBase::NetBase(const std::shared_ptr<const NetDef> &net_def, NetBase::NetBase(const std::shared_ptr<const NetDef>& net_def, Workspace* ws,
Workspace *ws,
DeviceType type) DeviceType type)
: name_(net_def->name()) { : name_(net_def->name()) {}
}
SimpleNet::SimpleNet(const std::shared_ptr<const NetDef> &net_def, SimpleNet::SimpleNet(const std::shared_ptr<const NetDef>& net_def,
Workspace *ws, Workspace* ws, DeviceType type)
DeviceType type) : NetBase(net_def, ws, type) { : NetBase(net_def, ws, type) {
VLOG(1) << "Constructing SimpleNet " << net_def->name(); VLOG(1) << "Constructing SimpleNet " << net_def->name();
for (int idx = 0; idx < net_def->op_size(); ++idx) { for (int idx = 0; idx < net_def->op_size(); ++idx) {
const auto& operator_def = net_def->op(idx); const auto& operator_def = net_def->op(idx);
VLOG(1) << "Creating operator " << operator_def.name() << ":" VLOG(1) << "Creating operator " << operator_def.name() << ":"
<< operator_def.type(); << operator_def.type();
std::unique_ptr<OperatorBase> op {nullptr}; std::unique_ptr<OperatorBase> op{nullptr};
OperatorDef temp_def(operator_def); OperatorDef temp_def(operator_def);
op = CreateOperator(temp_def, ws, type); op = CreateOperator(temp_def, ws, type);
operators_.emplace_back(std::move(op)); operators_.emplace_back(std::move(op));
...@@ -40,20 +37,16 @@ bool SimpleNet::Run() { ...@@ -40,20 +37,16 @@ bool SimpleNet::Run() {
return true; return true;
} }
unique_ptr<NetBase> CreateNet(const NetDef& net_def, unique_ptr<NetBase> CreateNet(const NetDef& net_def, Workspace* ws,
Workspace* ws,
DeviceType type) { DeviceType type) {
std::shared_ptr<NetDef> tmp_net_def(new NetDef(net_def)); std::shared_ptr<NetDef> tmp_net_def(new NetDef(net_def));
return CreateNet(tmp_net_def, ws, type); return CreateNet(tmp_net_def, ws, type);
} }
unique_ptr<NetBase> CreateNet( unique_ptr<NetBase> CreateNet(const std::shared_ptr<const NetDef>& net_def,
const std::shared_ptr<const NetDef>& net_def, Workspace* ws, DeviceType type) {
Workspace* ws,
DeviceType type) {
unique_ptr<NetBase> net(new SimpleNet(net_def, ws, type)); unique_ptr<NetBase> net(new SimpleNet(net_def, ws, type));
return net; return net;
} }
} // namespace mace
} // namespace mace
...@@ -6,35 +6,31 @@ ...@@ -6,35 +6,31 @@
#define MACE_CORE_NET_H_ #define MACE_CORE_NET_H_
#include "mace/core/common.h" #include "mace/core/common.h"
#include "mace/proto/mace.pb.h"
#include "mace/core/operator.h" #include "mace/core/operator.h"
#include "mace/core/workspace.h" #include "mace/core/workspace.h"
#include "mace/proto/mace.pb.h"
namespace mace { namespace mace {
class NetBase { class NetBase {
public: public:
NetBase(const std::shared_ptr<const NetDef> &net_def, NetBase(const std::shared_ptr<const NetDef>& net_def, Workspace* ws,
Workspace* ws,
DeviceType type); DeviceType type);
virtual ~NetBase() noexcept {} virtual ~NetBase() noexcept {}
virtual bool Run() = 0; virtual bool Run() = 0;
const string &Name() const { const string& Name() const { return name_; }
return name_;
}
protected: protected:
string name_; string name_;
DISABLE_COPY_AND_ASSIGN(NetBase); DISABLE_COPY_AND_ASSIGN(NetBase);
}; };
class SimpleNet : public NetBase { class SimpleNet : public NetBase {
public: public:
SimpleNet(const std::shared_ptr<const NetDef>& net_def, SimpleNet(const std::shared_ptr<const NetDef>& net_def, Workspace* ws,
Workspace* ws,
DeviceType type); DeviceType type);
bool Run() override; bool Run() override;
...@@ -42,17 +38,14 @@ class SimpleNet : public NetBase { ...@@ -42,17 +38,14 @@ class SimpleNet : public NetBase {
protected: protected:
vector<unique_ptr<OperatorBase> > operators_; vector<unique_ptr<OperatorBase> > operators_;
DISABLE_COPY_AND_ASSIGN(SimpleNet); DISABLE_COPY_AND_ASSIGN(SimpleNet);
}; };
unique_ptr<NetBase> CreateNet(const NetDef& net_def, unique_ptr<NetBase> CreateNet(const NetDef& net_def, Workspace* ws,
Workspace* ws,
DeviceType type); DeviceType type);
unique_ptr<NetBase> CreateNet( unique_ptr<NetBase> CreateNet(const std::shared_ptr<const NetDef>& net_def,
const std::shared_ptr<const NetDef>& net_def, Workspace* ws, DeviceType type);
Workspace* ws,
DeviceType type);
} // namespace mace } // namespace mace
#endif // MACE_CORE_NET_H_ #endif // MACE_CORE_NET_H_
...@@ -11,33 +11,22 @@ std::map<int32_t, OperatorRegistry*>* gDeviceTypeRegistry() { ...@@ -11,33 +11,22 @@ std::map<int32_t, OperatorRegistry*>* gDeviceTypeRegistry() {
return &g_device_type_registry; return &g_device_type_registry;
} }
MACE_DEFINE_REGISTRY( MACE_DEFINE_REGISTRY(CPUOperatorRegistry, OperatorBase, const OperatorDef&,
CPUOperatorRegistry, Workspace*);
OperatorBase,
const OperatorDef&,
Workspace*);
MACE_REGISTER_DEVICE_TYPE(DeviceType::CPU, CPUOperatorRegistry); MACE_REGISTER_DEVICE_TYPE(DeviceType::CPU, CPUOperatorRegistry);
MACE_DEFINE_REGISTRY( MACE_DEFINE_REGISTRY(NEONOperatorRegistry, OperatorBase, const OperatorDef&,
NEONOperatorRegistry, Workspace*);
OperatorBase,
const OperatorDef&,
Workspace*);
MACE_REGISTER_DEVICE_TYPE(DeviceType::NEON, NEONOperatorRegistry); MACE_REGISTER_DEVICE_TYPE(DeviceType::NEON, NEONOperatorRegistry);
unique_ptr<OperatorBase> CreateOperator( unique_ptr<OperatorBase> CreateOperator(const OperatorDef& operator_def,
const OperatorDef& operator_def, Workspace* ws, DeviceType type) {
Workspace* ws,
DeviceType type) {
OperatorRegistry* registry = gDeviceTypeRegistry()->at(type); OperatorRegistry* registry = gDeviceTypeRegistry()->at(type);
return registry->Create(operator_def.type(), operator_def, ws); return registry->Create(operator_def.type(), operator_def, ws);
} }
OperatorBase::OperatorBase(const OperatorDef& operator_def, Workspace* ws)
OperatorBase::OperatorBase(const OperatorDef &operator_def, Workspace *ws)
: operator_ws_(ws), : operator_ws_(ws),
operator_def_(std::make_shared<OperatorDef>(operator_def)) { operator_def_(std::make_shared<OperatorDef>(operator_def)) {}
}
} // namespace mace } // namespace mace
...@@ -5,12 +5,12 @@ ...@@ -5,12 +5,12 @@
#ifndef MACE_CORE_OPERATOR_H #ifndef MACE_CORE_OPERATOR_H
#define MACE_CORE_OPERATOR_H #define MACE_CORE_OPERATOR_H
#include "mace/core/proto_utils.h"
#include "mace/core/common.h" #include "mace/core/common.h"
#include "mace/proto/mace.pb.h" #include "mace/core/proto_utils.h"
#include "mace/core/tensor.h"
#include "mace/core/registry.h" #include "mace/core/registry.h"
#include "mace/core/tensor.h"
#include "mace/core/workspace.h" #include "mace/core/workspace.h"
#include "mace/proto/mace.pb.h"
namespace mace { namespace mace {
...@@ -23,22 +23,21 @@ class OperatorBase { ...@@ -23,22 +23,21 @@ class OperatorBase {
MACE_CHECK(operator_def_, "operator_def was null!"); MACE_CHECK(operator_def_, "operator_def was null!");
return ArgumentHelper::HasArgument(*operator_def_, name); return ArgumentHelper::HasArgument(*operator_def_, name);
} }
template<typename T> template <typename T>
inline T GetSingleArgument(const string &name, const T &default_value) const { inline T GetSingleArgument(const string &name, const T &default_value) const {
MACE_CHECK(operator_def_, "operator_def was null!"); MACE_CHECK(operator_def_, "operator_def was null!");
return ArgumentHelper::GetSingleArgument<OperatorDef, T>( return ArgumentHelper::GetSingleArgument<OperatorDef, T>(
*operator_def_, name, default_value); *operator_def_, name, default_value);
} }
template<typename T> template <typename T>
inline bool HasSingleArgumentOfType(const string &name) const { inline bool HasSingleArgumentOfType(const string &name) const {
MACE_CHECK(operator_def_, "operator_def was null!"); MACE_CHECK(operator_def_, "operator_def was null!");
return ArgumentHelper::HasSingleArgumentOfType<OperatorDef, T>( return ArgumentHelper::HasSingleArgumentOfType<OperatorDef, T>(
*operator_def_, name); *operator_def_, name);
} }
template<typename T> template <typename T>
inline vector<T> GetRepeatedArgument( inline vector<T> GetRepeatedArgument(
const string &name, const string &name, const vector<T> &default_value = {}) const {
const vector<T> &default_value = {}) const {
MACE_CHECK(operator_def_, "operator_def was null!"); MACE_CHECK(operator_def_, "operator_def was null!");
return ArgumentHelper::GetRepeatedArgument<OperatorDef, T>( return ArgumentHelper::GetRepeatedArgument<OperatorDef, T>(
*operator_def_, name, default_value); *operator_def_, name, default_value);
...@@ -49,9 +48,7 @@ class OperatorBase { ...@@ -49,9 +48,7 @@ class OperatorBase {
return inputs_[idx]; return inputs_[idx];
} }
inline Tensor *Output(int idx) { inline Tensor *Output(int idx) { return outputs_[idx]; }
return outputs_[idx];
}
inline int InputSize() { return inputs_.size(); } inline int InputSize() { return inputs_.size(); }
inline int OutputSize() { return outputs_.size(); } inline int OutputSize() { return outputs_.size(); }
...@@ -70,9 +67,7 @@ class OperatorBase { ...@@ -70,9 +67,7 @@ class OperatorBase {
operator_def_ = operator_def; operator_def_ = operator_def;
} }
inline bool has_debug_def() const { inline bool has_debug_def() const { return operator_def_ != nullptr; }
return operator_def_ != nullptr;
}
protected: protected:
Workspace *operator_ws_; Workspace *operator_ws_;
...@@ -80,7 +75,7 @@ class OperatorBase { ...@@ -80,7 +75,7 @@ class OperatorBase {
vector<const Tensor *> inputs_; vector<const Tensor *> inputs_;
vector<Tensor *> outputs_; vector<Tensor *> outputs_;
DISABLE_COPY_AND_ASSIGN(OperatorBase); DISABLE_COPY_AND_ASSIGN(OperatorBase);
}; };
template <DeviceType D, class T> template <DeviceType D, class T>
...@@ -90,26 +85,22 @@ class Operator : public OperatorBase { ...@@ -90,26 +85,22 @@ class Operator : public OperatorBase {
: OperatorBase(operator_def, ws) { : OperatorBase(operator_def, ws) {
for (const string &input_str : operator_def.input()) { for (const string &input_str : operator_def.input()) {
const Tensor *tensor = ws->GetTensor(input_str); const Tensor *tensor = ws->GetTensor(input_str);
MACE_CHECK( MACE_CHECK(tensor != nullptr, "op ", operator_def.type(),
tensor != nullptr, ": Encountered a non-existing input tensor: ", input_str);
"op ",
operator_def.type(),
": Encountered a non-existing input tensor: ",
input_str);
inputs_.push_back(tensor); inputs_.push_back(tensor);
} }
for (const string &output_str : operator_def.output()) { for (const string &output_str : operator_def.output()) {
outputs_.push_back(MACE_CHECK_NOTNULL(ws->CreateTensor(output_str, outputs_.push_back(MACE_CHECK_NOTNULL(ws->CreateTensor(
DeviceContext<D>::allocator(), output_str, DeviceContext<D>::allocator(), DataTypeToEnum<T>::v())));
DataTypeToEnum<T>::v())));
} }
} }
virtual bool Run() override = 0; virtual bool Run() override = 0;
~Operator() noexcept override {} ~Operator() noexcept override {}
}; };
// OP_INPUT_TAGS and OP_OUTPUT_TAGS are optional features to name the indices of the // OP_INPUT_TAGS and OP_OUTPUT_TAGS are optional features to name the indices of
// the
// operator's inputs and outputs, in order to avoid confusion. For example, for // operator's inputs and outputs, in order to avoid confusion. For example, for
// a fully convolution layer that has input, weight and bias, you can define its // a fully convolution layer that has input, weight and bias, you can define its
// input tags as: // input tags as:
...@@ -119,9 +110,9 @@ class Operator : public OperatorBase { ...@@ -119,9 +110,9 @@ class Operator : public OperatorBase {
// you can now do // you can now do
// auto& weight = Input(WEIGHT); // auto& weight = Input(WEIGHT);
// to make it more clear. // to make it more clear.
#define OP_INPUT_TAGS(first_input, ...) \ #define OP_INPUT_TAGS(first_input, ...) \
enum _InputTags { first_input = 0, __VA_ARGS__ } enum _InputTags { first_input = 0, __VA_ARGS__ }
#define OP_OUTPUT_TAGS(first_input, ...) \ #define OP_OUTPUT_TAGS(first_input, ...) \
enum _OutputTags { first_input = 0, __VA_ARGS__ } enum _OutputTags { first_input = 0, __VA_ARGS__ }
typedef Registry<std::string, OperatorBase, const OperatorDef &, Workspace *> typedef Registry<std::string, OperatorBase, const OperatorDef &, Workspace *>
...@@ -135,7 +126,7 @@ struct DeviceTypeRegisterer { ...@@ -135,7 +126,7 @@ struct DeviceTypeRegisterer {
if (gDeviceTypeRegistry()->count(type)) { if (gDeviceTypeRegistry()->count(type)) {
LOG(ERROR) << "Device type " << type LOG(ERROR) << "Device type " << type
<< "registered twice. This should not happen. Did you have " << "registered twice. This should not happen. Did you have "
"duplicated numbers assigned to different devices?"; "duplicated numbers assigned to different devices?";
std::exit(1); std::exit(1);
} }
// Calling the registry function to get the actual registry pointer. // Calling the registry function to get the actual registry pointer.
...@@ -143,39 +134,31 @@ struct DeviceTypeRegisterer { ...@@ -143,39 +134,31 @@ struct DeviceTypeRegisterer {
} }
}; };
#define MACE_REGISTER_DEVICE_TYPE(type, registry_function) \ #define MACE_REGISTER_DEVICE_TYPE(type, registry_function) \
namespace { \ namespace { \
static DeviceTypeRegisterer MACE_ANONYMOUS_VARIABLE( \ static DeviceTypeRegisterer MACE_ANONYMOUS_VARIABLE(DeviceType)( \
DeviceType)(type, &registry_function); \ type, &registry_function); \
} }
MACE_DECLARE_REGISTRY( MACE_DECLARE_REGISTRY(CPUOperatorRegistry, OperatorBase, const OperatorDef &,
CPUOperatorRegistry, Workspace *);
OperatorBase,
const OperatorDef&,
Workspace*);
#define REGISTER_CPU_OPERATOR_CREATOR(key, ...) \ #define REGISTER_CPU_OPERATOR_CREATOR(key, ...) \
MACE_REGISTER_CREATOR(CPUOperatorRegistry, key, __VA_ARGS__) MACE_REGISTER_CREATOR(CPUOperatorRegistry, key, __VA_ARGS__)
#define REGISTER_CPU_OPERATOR(name, ...) \ #define REGISTER_CPU_OPERATOR(name, ...) \
MACE_REGISTER_CLASS(CPUOperatorRegistry, name, __VA_ARGS__) MACE_REGISTER_CLASS(CPUOperatorRegistry, name, __VA_ARGS__)
MACE_DECLARE_REGISTRY( MACE_DECLARE_REGISTRY(NEONOperatorRegistry, OperatorBase, const OperatorDef &,
NEONOperatorRegistry, Workspace *);
OperatorBase,
const OperatorDef&,
Workspace*);
#define REGISTER_NEON_OPERATOR_CREATOR(key, ...) \ #define REGISTER_NEON_OPERATOR_CREATOR(key, ...) \
MACE_REGISTER_CREATOR(NEONOperatorRegistry, key, __VA_ARGS__) MACE_REGISTER_CREATOR(NEONOperatorRegistry, key, __VA_ARGS__)
#define REGISTER_NEON_OPERATOR(name, ...) \ #define REGISTER_NEON_OPERATOR(name, ...) \
MACE_REGISTER_CLASS(NEONOperatorRegistry, name, __VA_ARGS__) MACE_REGISTER_CLASS(NEONOperatorRegistry, name, __VA_ARGS__)
unique_ptr<OperatorBase> CreateOperator( unique_ptr<OperatorBase> CreateOperator(const OperatorDef &operator_def,
const OperatorDef &operator_def, Workspace *ws, DeviceType type);
Workspace *ws,
DeviceType type);
} // namespace mace } // namespace mace
#endif //MACE_CORE_OPERATOR_H #endif // MACE_CORE_OPERATOR_H
...@@ -5,9 +5,9 @@ ...@@ -5,9 +5,9 @@
#include "mace/core/proto_utils.h" #include "mace/core/proto_utils.h"
#include <fcntl.h> #include <fcntl.h>
#include <unistd.h>
#include <cerrno> #include <cerrno>
#include <fstream> #include <fstream>
#include <unistd.h>
#include "google/protobuf/io/coded_stream.h" #include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream_impl.h" #include "google/protobuf/io/zero_copy_stream_impl.h"
...@@ -82,13 +82,12 @@ bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto) { ...@@ -82,13 +82,12 @@ bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto) {
return proto->ParseFromCodedStream(&coded_stream); return proto->ParseFromCodedStream(&coded_stream);
} }
void WriteProtoToBinaryFile( void WriteProtoToBinaryFile(const MessageLite& /*proto*/,
const MessageLite& /*proto*/, const char* /*filename*/) {
const char* /*filename*/) {
LOG(FATAL) << "Not implemented yet."; LOG(FATAL) << "Not implemented yet.";
} }
#else // MACE_USE_LITE_PROTO #else // MACE_USE_LITE_PROTO
// Full protocol buffer. // Full protocol buffer.
...@@ -118,7 +117,7 @@ void WriteProtoToTextFile(const Message& proto, const char* filename) { ...@@ -118,7 +117,7 @@ void WriteProtoToTextFile(const Message& proto, const char* filename) {
} }
bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto) { bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto) {
#if defined (_MSC_VER) // for MSC compiler binary flag needs to be specified #if defined(_MSC_VER) // for MSC compiler binary flag needs to be specified
int fd = open(filename, O_RDONLY | O_BINARY); int fd = open(filename, O_RDONLY | O_BINARY);
#else #else
int fd = open(filename, O_RDONLY); int fd = open(filename, O_RDONLY);
...@@ -138,8 +137,8 @@ bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto) { ...@@ -138,8 +137,8 @@ bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto) {
void WriteProtoToBinaryFile(const MessageLite& proto, const char* filename) { void WriteProtoToBinaryFile(const MessageLite& proto, const char* filename) {
int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644); int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644);
MACE_CHECK( MACE_CHECK(fd != -1, "File cannot be created: ", filename, " error number: ",
fd != -1, "File cannot be created: ", filename, " error number: ", errno); errno);
std::unique_ptr<ZeroCopyOutputStream> raw_output(new FileOutputStream(fd)); std::unique_ptr<ZeroCopyOutputStream> raw_output(new FileOutputStream(fd));
std::unique_ptr<CodedOutputStream> coded_output( std::unique_ptr<CodedOutputStream> coded_output(
new CodedOutputStream(raw_output.get())); new CodedOutputStream(raw_output.get()));
...@@ -151,18 +150,17 @@ void WriteProtoToBinaryFile(const MessageLite& proto, const char* filename) { ...@@ -151,18 +150,17 @@ void WriteProtoToBinaryFile(const MessageLite& proto, const char* filename) {
#endif // MACE_USE_LITE_PROTO #endif // MACE_USE_LITE_PROTO
ArgumentHelper::ArgumentHelper(const OperatorDef &def) { ArgumentHelper::ArgumentHelper(const OperatorDef& def) {
for (auto &arg : def.arg()) { for (auto& arg : def.arg()) {
if (arg_map_.find(arg.name()) != arg_map_.end()) { if (arg_map_.find(arg.name()) != arg_map_.end()) {
MACE_CHECK( MACE_CHECK(
arg.SerializeAsString() == arg_map_[arg.name()].SerializeAsString(), arg.SerializeAsString() == arg_map_[arg.name()].SerializeAsString(),
"Found argument of the same name '", "Found argument of the same name '", arg.name(),
arg.name(), "' but with different contents: ", ProtoDebugString(def));
"' but with different contents: ",
ProtoDebugString(def));
LOG(WARNING) << "Duplicated argument name found in operator def: " LOG(WARNING) << "Duplicated argument name found in operator def: "
<< ProtoDebugString(def) << ", arg: " << ProtoDebugString(arg); << ProtoDebugString(def)
<< ", arg: " << ProtoDebugString(arg);
} }
arg_map_[arg.name()] = arg; arg_map_[arg.name()] = arg;
...@@ -171,10 +169,9 @@ ArgumentHelper::ArgumentHelper(const OperatorDef &def) { ...@@ -171,10 +169,9 @@ ArgumentHelper::ArgumentHelper(const OperatorDef &def) {
ArgumentHelper::ArgumentHelper(const NetDef& netdef) { ArgumentHelper::ArgumentHelper(const NetDef& netdef) {
for (auto& arg : netdef.arg()) { for (auto& arg : netdef.arg()) {
MACE_CHECK( MACE_CHECK(arg_map_.count(arg.name()) == 0,
arg_map_.count(arg.name()) == 0, "Duplicated argument name found in net def: ",
"Duplicated argument name found in net def: ", ProtoDebugString(netdef));
ProtoDebugString(netdef));
arg_map_[arg.name()] = arg; arg_map_[arg.name()] = arg;
} }
} }
...@@ -192,32 +189,24 @@ bool SupportsLosslessConversion(const InputType& value) { ...@@ -192,32 +189,24 @@ bool SupportsLosslessConversion(const InputType& value) {
} }
} }
#define INSTANTIATE_GET_SINGLE_ARGUMENT( \ #define INSTANTIATE_GET_SINGLE_ARGUMENT(T, fieldname, \
T, fieldname, enforce_lossless_conversion) \ enforce_lossless_conversion) \
template <> \ template <> \
T ArgumentHelper::GetSingleArgument<T>( \ T ArgumentHelper::GetSingleArgument<T>(const string& name, \
const string& name, const T& default_value) const { \ const T& default_value) const { \
if (arg_map_.count(name) == 0) { \ if (arg_map_.count(name) == 0) { \
VLOG(1) << "Using default parameter value " << default_value \ VLOG(1) << "Using default parameter value " << default_value \
<< " for parameter " << name; \ << " for parameter " << name; \
return default_value; \ return default_value; \
} \ } \
MACE_CHECK( \ MACE_CHECK(arg_map_.at(name).has_##fieldname(), "Argument ", name, \
arg_map_.at(name).has_##fieldname(), \ " does not have the right field: expected field " #fieldname); \
"Argument ", \
name, \
" does not have the right field: expected field " #fieldname); \
auto value = arg_map_.at(name).fieldname(); \ auto value = arg_map_.at(name).fieldname(); \
if (enforce_lossless_conversion) { \ if (enforce_lossless_conversion) { \
auto supportsConversion = \ auto supportsConversion = \
SupportsLosslessConversion<decltype(value), T>(value); \ SupportsLosslessConversion<decltype(value), T>(value); \
MACE_CHECK( \ MACE_CHECK(supportsConversion, "Value", value, " of argument ", name, \
supportsConversion, \ "cannot be represented correctly in a target type"); \
"Value", \
value, \
" of argument ", \
name, \
"cannot be represented correctly in a target type"); \
} \ } \
return value; \ return value; \
} \ } \
...@@ -242,30 +231,25 @@ INSTANTIATE_GET_SINGLE_ARGUMENT(size_t, i, true) ...@@ -242,30 +231,25 @@ INSTANTIATE_GET_SINGLE_ARGUMENT(size_t, i, true)
INSTANTIATE_GET_SINGLE_ARGUMENT(string, s, false) INSTANTIATE_GET_SINGLE_ARGUMENT(string, s, false)
#undef INSTANTIATE_GET_SINGLE_ARGUMENT #undef INSTANTIATE_GET_SINGLE_ARGUMENT
#define INSTANTIATE_GET_REPEATED_ARGUMENT( \ #define INSTANTIATE_GET_REPEATED_ARGUMENT(T, fieldname, \
T, fieldname, enforce_lossless_conversion) \ enforce_lossless_conversion) \
template <> \ template <> \
vector<T> ArgumentHelper::GetRepeatedArgument<T>( \ vector<T> ArgumentHelper::GetRepeatedArgument<T>( \
const string& name, const std::vector<T>& default_value) const { \ const string& name, const std::vector<T>& default_value) const { \
if (arg_map_.count(name) == 0) { \ if (arg_map_.count(name) == 0) { \
return default_value; \ return default_value; \
} \ } \
vector<T> values; \ vector<T> values; \
for (const auto& v : arg_map_.at(name).fieldname()) { \ for (const auto& v : arg_map_.at(name).fieldname()) { \
if (enforce_lossless_conversion) { \ if (enforce_lossless_conversion) { \
auto supportsConversion = \ auto supportsConversion = \
SupportsLosslessConversion<decltype(v), T>(v); \ SupportsLosslessConversion<decltype(v), T>(v); \
MACE_CHECK( \ MACE_CHECK(supportsConversion, "Value", v, " of argument ", name, \
supportsConversion, \ "cannot be represented correctly in a target type"); \
"Value", \ } \
v, \ values.push_back(v); \
" of argument ", \ } \
name, \ return values; \
"cannot be represented correctly in a target type"); \
} \
values.push_back(v); \
} \
return values; \
} }
INSTANTIATE_GET_REPEATED_ARGUMENT(float, floats, false) INSTANTIATE_GET_REPEATED_ARGUMENT(float, floats, false)
...@@ -281,14 +265,14 @@ INSTANTIATE_GET_REPEATED_ARGUMENT(size_t, ints, true) ...@@ -281,14 +265,14 @@ INSTANTIATE_GET_REPEATED_ARGUMENT(size_t, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(string, strings, false) INSTANTIATE_GET_REPEATED_ARGUMENT(string, strings, false)
#undef INSTANTIATE_GET_REPEATED_ARGUMENT #undef INSTANTIATE_GET_REPEATED_ARGUMENT
#define MACE_MAKE_SINGULAR_ARGUMENT(T, fieldname) \ #define MACE_MAKE_SINGULAR_ARGUMENT(T, fieldname) \
template <> \ template <> \
Argument MakeArgument(const string& name, const T& value) { \ Argument MakeArgument(const string& name, const T& value) { \
Argument arg; \ Argument arg; \
arg.set_name(name); \ arg.set_name(name); \
arg.set_##fieldname(value); \ arg.set_##fieldname(value); \
return arg; \ return arg; \
} }
MACE_MAKE_SINGULAR_ARGUMENT(bool, i) MACE_MAKE_SINGULAR_ARGUMENT(bool, i)
MACE_MAKE_SINGULAR_ARGUMENT(float, f) MACE_MAKE_SINGULAR_ARGUMENT(float, f)
...@@ -305,16 +289,16 @@ Argument MakeArgument(const string& name, const MessageLite& value) { ...@@ -305,16 +289,16 @@ Argument MakeArgument(const string& name, const MessageLite& value) {
return arg; return arg;
} }
#define MACE_MAKE_REPEATED_ARGUMENT(T, fieldname) \ #define MACE_MAKE_REPEATED_ARGUMENT(T, fieldname) \
template <> \ template <> \
Argument MakeArgument(const string& name, const vector<T>& value) { \ Argument MakeArgument(const string& name, const vector<T>& value) { \
Argument arg; \ Argument arg; \
arg.set_name(name); \ arg.set_name(name); \
for (const auto& v : value) { \ for (const auto& v : value) { \
arg.add_##fieldname(v); \ arg.add_##fieldname(v); \
} \ } \
return arg; \ return arg; \
} }
MACE_MAKE_REPEATED_ARGUMENT(float, floats) MACE_MAKE_REPEATED_ARGUMENT(float, floats)
MACE_MAKE_REPEATED_ARGUMENT(int, ints) MACE_MAKE_REPEATED_ARGUMENT(int, ints)
...@@ -328,31 +312,24 @@ const Argument& GetArgument(const OperatorDef& def, const string& name) { ...@@ -328,31 +312,24 @@ const Argument& GetArgument(const OperatorDef& def, const string& name) {
return arg; return arg;
} }
} }
MACE_CHECK(false, MACE_CHECK(false, "Argument named ", name, "does not exist in operator ",
"Argument named ", ProtoDebugString(def));
name,
"does not exist in operator ",
ProtoDebugString(def));
} }
bool GetFlagArgument( bool GetFlagArgument(const OperatorDef& def, const string& name,
const OperatorDef& def, bool def_value) {
const string& name,
bool def_value) {
for (const Argument& arg : def.arg()) { for (const Argument& arg : def.arg()) {
if (arg.name() == name) { if (arg.name() == name) {
MACE_CHECK( MACE_CHECK(arg.has_i(), "Can't parse argument as bool: ",
arg.has_i(), "Can't parse argument as bool: ", ProtoDebugString(arg)); ProtoDebugString(arg));
return arg.i(); return arg.i();
} }
} }
return def_value; return def_value;
} }
Argument* GetMutableArgument( Argument* GetMutableArgument(const string& name, const bool create_if_missing,
const string& name, OperatorDef* def) {
const bool create_if_missing,
OperatorDef* def) {
for (int i = 0; i < def->arg_size(); ++i) { for (int i = 0; i < def->arg_size(); ++i) {
if (def->arg(i).name() == name) { if (def->arg(i).name() == name) {
return def->mutable_arg(i); return def->mutable_arg(i);
......
...@@ -12,15 +12,14 @@ ...@@ -12,15 +12,14 @@
#include "google/protobuf/message.h" #include "google/protobuf/message.h"
#endif // !MACE_USE_LITE_PROTO #endif // !MACE_USE_LITE_PROTO
#include "mace/proto/mace.pb.h"
#include "mace/core/common.h" #include "mace/core/common.h"
#include "mace/proto/mace.pb.h"
namespace mace { namespace mace {
using std::string; using std::string;
using ::google::protobuf::MessageLite; using ::google::protobuf::MessageLite;
// Common interfaces that reads file contents into a string. // Common interfaces that reads file contents into a string.
bool ReadStringFromFile(const char* filename, string* str); bool ReadStringFromFile(const char* filename, string* str);
bool WriteStringToFile(const string& str, const char* filename); bool WriteStringToFile(const string& str, const char* filename);
...@@ -46,22 +45,20 @@ inline string ProtoDebugString(const MessageLite& proto) { ...@@ -46,22 +45,20 @@ inline string ProtoDebugString(const MessageLite& proto) {
// Text format MessageLite wrappers: these functions do nothing but just // Text format MessageLite wrappers: these functions do nothing but just
// allowing things to compile. It will produce a runtime error if you are using // allowing things to compile. It will produce a runtime error if you are using
// MessageLite but still want text support. // MessageLite but still want text support.
inline bool ReadProtoFromTextFile( inline bool ReadProtoFromTextFile(const char* /*filename*/,
const char* /*filename*/, MessageLite* /*proto*/) {
MessageLite* /*proto*/) {
LOG(FATAL) << "If you are running lite version, you should not be " LOG(FATAL) << "If you are running lite version, you should not be "
<< "calling any text-format protobuffers."; << "calling any text-format protobuffers.";
return false; // Just to suppress compiler warning. return false; // Just to suppress compiler warning.
} }
inline bool ReadProtoFromTextFile(const string filename, MessageLite* proto) { inline bool ReadProtoFromTextFile(const string filename, MessageLite* proto) {
return ReadProtoFromTextFile(filename.c_str(), proto); return ReadProtoFromTextFile(filename.c_str(), proto);
} }
inline void WriteProtoToTextFile( inline void WriteProtoToTextFile(const MessageLite& /*proto*/,
const MessageLite& /*proto*/, const char* /*filename*/) {
const char* /*filename*/) {
LOG(FATAL) << "If you are running lite version, you should not be " LOG(FATAL) << "If you are running lite version, you should not be "
<< "calling any text-format protobuffers."; << "calling any text-format protobuffers.";
} }
inline void WriteProtoToTextFile(const MessageLite& proto, inline void WriteProtoToTextFile(const MessageLite& proto,
const string& filename) { const string& filename) {
...@@ -107,16 +104,13 @@ inline bool ReadProtoFromFile(const string& filename, Message* proto) { ...@@ -107,16 +104,13 @@ inline bool ReadProtoFromFile(const string& filename, Message* proto) {
#endif // MACE_USE_LITE_PROTO #endif // MACE_USE_LITE_PROTO
template < template <class IterableInputs = std::initializer_list<string>,
class IterableInputs = std::initializer_list<string>, class IterableOutputs = std::initializer_list<string>,
class IterableOutputs = std::initializer_list<string>, class IterableArgs = std::initializer_list<Argument>>
class IterableArgs = std::initializer_list<Argument>> OperatorDef CreateOperatorDef(const string& type, const string& name,
OperatorDef CreateOperatorDef( const IterableInputs& inputs,
const string& type, const IterableOutputs& outputs,
const string& name, const IterableArgs& args) {
const IterableInputs& inputs,
const IterableOutputs& outputs,
const IterableArgs& args) {
OperatorDef def; OperatorDef def;
def.set_type(type); def.set_type(type);
def.set_name(name); def.set_name(name);
...@@ -134,20 +128,13 @@ OperatorDef CreateOperatorDef( ...@@ -134,20 +128,13 @@ OperatorDef CreateOperatorDef(
// A simplified version compared to the full CreateOperator, if you do not need // A simplified version compared to the full CreateOperator, if you do not need
// to specify args. // to specify args.
template < template <class IterableInputs = std::initializer_list<string>,
class IterableInputs = std::initializer_list<string>, class IterableOutputs = std::initializer_list<string>>
class IterableOutputs = std::initializer_list<string>> inline OperatorDef CreateOperatorDef(const string& type, const string& name,
inline OperatorDef CreateOperatorDef( const IterableInputs& inputs,
const string& type, const IterableOutputs& outputs) {
const string& name, return CreateOperatorDef(type, name, inputs, outputs,
const IterableInputs& inputs, std::vector<Argument>());
const IterableOutputs& outputs) {
return CreateOperatorDef(
type,
name,
inputs,
outputs,
std::vector<Argument>());
} }
/** /**
...@@ -166,10 +153,8 @@ class ArgumentHelper { ...@@ -166,10 +153,8 @@ class ArgumentHelper {
} }
template <typename Def, typename T> template <typename Def, typename T>
static T GetSingleArgument( static T GetSingleArgument(const Def& def, const string& name,
const Def& def, const T& default_value) {
const string& name,
const T& default_value) {
return ArgumentHelper(def).GetSingleArgument<T>(name, default_value); return ArgumentHelper(def).GetSingleArgument<T>(name, default_value);
} }
...@@ -180,8 +165,7 @@ class ArgumentHelper { ...@@ -180,8 +165,7 @@ class ArgumentHelper {
template <typename Def, typename T> template <typename Def, typename T>
static vector<T> GetRepeatedArgument( static vector<T> GetRepeatedArgument(
const Def& def, const Def& def, const string& name,
const string& name,
const std::vector<T>& default_value = std::vector<T>()) { const std::vector<T>& default_value = std::vector<T>()) {
return ArgumentHelper(def).GetRepeatedArgument<T>(name, default_value); return ArgumentHelper(def).GetRepeatedArgument<T>(name, default_value);
} }
...@@ -192,9 +176,8 @@ class ArgumentHelper { ...@@ -192,9 +176,8 @@ class ArgumentHelper {
} }
template <typename Def, typename MessageType> template <typename Def, typename MessageType>
static vector<MessageType> GetRepeatedMessageArgument( static vector<MessageType> GetRepeatedMessageArgument(const Def& def,
const Def& def, const string& name) {
const string& name) {
return ArgumentHelper(def).GetRepeatedMessageArgument<MessageType>(name); return ArgumentHelper(def).GetRepeatedMessageArgument<MessageType>(name);
} }
...@@ -216,9 +199,8 @@ class ArgumentHelper { ...@@ -216,9 +199,8 @@ class ArgumentHelper {
MACE_CHECK(arg_map_.count(name), "Cannot find parameter named " + name); MACE_CHECK(arg_map_.count(name), "Cannot find parameter named " + name);
MessageType message; MessageType message;
if (arg_map_.at(name).has_s()) { if (arg_map_.at(name).has_s()) {
MACE_CHECK( MACE_CHECK(message.ParseFromString(arg_map_.at(name).s()),
message.ParseFromString(arg_map_.at(name).s()), "Faild to parse content from the string");
"Faild to parse content from the string");
} else { } else {
VLOG(1) << "Return empty message for parameter " << name; VLOG(1) << "Return empty message for parameter " << name;
} }
...@@ -230,9 +212,8 @@ class ArgumentHelper { ...@@ -230,9 +212,8 @@ class ArgumentHelper {
MACE_CHECK(arg_map_.count(name), "Cannot find parameter named " + name); MACE_CHECK(arg_map_.count(name), "Cannot find parameter named " + name);
vector<MessageType> messages(arg_map_.at(name).strings_size()); vector<MessageType> messages(arg_map_.at(name).strings_size());
for (int i = 0; i < messages.size(); ++i) { for (int i = 0; i < messages.size(); ++i) {
MACE_CHECK( MACE_CHECK(messages[i].ParseFromString(arg_map_.at(name).strings(i)),
messages[i].ParseFromString(arg_map_.at(name).strings(i)), "Faild to parse content from the string");
"Faild to parse content from the string");
} }
return messages; return messages;
} }
...@@ -242,15 +223,11 @@ class ArgumentHelper { ...@@ -242,15 +223,11 @@ class ArgumentHelper {
}; };
const Argument& GetArgument(const OperatorDef& def, const string& name); const Argument& GetArgument(const OperatorDef& def, const string& name);
bool GetFlagArgument( bool GetFlagArgument(const OperatorDef& def, const string& name,
const OperatorDef& def, bool def_value = false);
const string& name,
bool def_value = false); Argument* GetMutableArgument(const string& name, const bool create_if_missing,
OperatorDef* def);
Argument* GetMutableArgument(
const string& name,
const bool create_if_missing,
OperatorDef* def);
template <typename T> template <typename T>
Argument MakeArgument(const string& name, const T& value); Argument MakeArgument(const string& name, const T& value);
......
...@@ -12,7 +12,7 @@ namespace mace { ...@@ -12,7 +12,7 @@ namespace mace {
template <class SrcType, class ObjectType, class... Args> template <class SrcType, class ObjectType, class... Args>
class Registry { class Registry {
public: public:
typedef std::function<std::unique_ptr<ObjectType> (Args ...)> Creator; typedef std::function<std::unique_ptr<ObjectType>(Args...)> Creator;
Registry() : registry_() {} Registry() : registry_() {}
...@@ -24,7 +24,7 @@ class Registry { ...@@ -24,7 +24,7 @@ class Registry {
inline bool Has(const SrcType& key) { return registry_.count(key) != 0; } inline bool Has(const SrcType& key) { return registry_.count(key) != 0; }
unique_ptr<ObjectType> Create(const SrcType& key, Args ... args) { unique_ptr<ObjectType> Create(const SrcType& key, Args... args) {
if (registry_.count(key) == 0) { if (registry_.count(key) == 0) {
VLOG(2) << "Key not registered: " << key; VLOG(2) << "Key not registered: " << key;
return nullptr; return nullptr;
...@@ -60,7 +60,7 @@ class Registerer { ...@@ -60,7 +60,7 @@ class Registerer {
} }
template <class DerivedType> template <class DerivedType>
static unique_ptr<ObjectType> DefaultCreator(Args ... args) { static unique_ptr<ObjectType> DefaultCreator(Args... args) {
return std::unique_ptr<ObjectType>(new DerivedType(args...)); return std::unique_ptr<ObjectType>(new DerivedType(args...));
} }
}; };
...@@ -74,36 +74,35 @@ class Registerer { ...@@ -74,36 +74,35 @@ class Registerer {
#endif #endif
#define MACE_DECLARE_TYPED_REGISTRY(RegistryName, SrcType, ObjectType, ...) \ #define MACE_DECLARE_TYPED_REGISTRY(RegistryName, SrcType, ObjectType, ...) \
Registry<SrcType, ObjectType, ##__VA_ARGS__>* RegistryName(); \ Registry<SrcType, ObjectType, ##__VA_ARGS__>* RegistryName(); \
typedef Registerer<SrcType, ObjectType, ##__VA_ARGS__> \ typedef Registerer<SrcType, ObjectType, ##__VA_ARGS__> \
Registerer##RegistryName; Registerer##RegistryName;
#define MACE_DEFINE_TYPED_REGISTRY(RegistryName, SrcType, ObjectType, ...) \ #define MACE_DEFINE_TYPED_REGISTRY(RegistryName, SrcType, ObjectType, ...) \
Registry<SrcType, ObjectType, ##__VA_ARGS__>* RegistryName() { \ Registry<SrcType, ObjectType, ##__VA_ARGS__>* RegistryName() { \
static Registry<SrcType, ObjectType, ##__VA_ARGS__>* registry = \ static Registry<SrcType, ObjectType, ##__VA_ARGS__>* registry = \
new Registry<SrcType, ObjectType, ##__VA_ARGS__>(); \ new Registry<SrcType, ObjectType, ##__VA_ARGS__>(); \
return registry; \ return registry; \
} }
#define MACE_DECLARE_REGISTRY(RegistryName, ObjectType, ...) \ #define MACE_DECLARE_REGISTRY(RegistryName, ObjectType, ...) \
MACE_DECLARE_TYPED_REGISTRY( \ MACE_DECLARE_TYPED_REGISTRY(RegistryName, std::string, ObjectType, \
RegistryName, std::string, ObjectType, ##__VA_ARGS__) ##__VA_ARGS__)
#define MACE_DEFINE_REGISTRY(RegistryName, ObjectType, ...) \ #define MACE_DEFINE_REGISTRY(RegistryName, ObjectType, ...) \
MACE_DEFINE_TYPED_REGISTRY( \ MACE_DEFINE_TYPED_REGISTRY(RegistryName, std::string, ObjectType, \
RegistryName, std::string, ObjectType, ##__VA_ARGS__) ##__VA_ARGS__)
#define MACE_REGISTER_TYPED_CREATOR(RegistryName, key, ...) \ #define MACE_REGISTER_TYPED_CREATOR(RegistryName, key, ...) \
namespace { \ namespace { \
static Registerer##RegistryName MACE_ANONYMOUS_VARIABLE(g_##RegistryName)( \ static Registerer##RegistryName MACE_ANONYMOUS_VARIABLE(g_##RegistryName)( \
key, RegistryName(), __VA_ARGS__); key, RegistryName(), __VA_ARGS__);
#define MACE_REGISTER_TYPED_CLASS(RegistryName, key, ...) \ #define MACE_REGISTER_TYPED_CLASS(RegistryName, key, ...) \
namespace { \ namespace { \
static Registerer##RegistryName MACE_ANONYMOUS_VARIABLE(g_##RegistryName)( \ static Registerer##RegistryName MACE_ANONYMOUS_VARIABLE(g_##RegistryName)( \
key, \ key, RegistryName(), \
RegistryName(), \ Registerer##RegistryName::DefaultCreator<__VA_ARGS__>); \
Registerer##RegistryName::DefaultCreator<__VA_ARGS__>); \
} }
#define MACE_REGISTER_CREATOR(RegistryName, key, ...) \ #define MACE_REGISTER_CREATOR(RegistryName, key, ...) \
...@@ -112,6 +111,6 @@ class Registerer { ...@@ -112,6 +111,6 @@ class Registerer {
#define MACE_REGISTER_CLASS(RegistryName, key, ...) \ #define MACE_REGISTER_CLASS(RegistryName, key, ...) \
MACE_REGISTER_TYPED_CLASS(RegistryName, #key, __VA_ARGS__) MACE_REGISTER_TYPED_CLASS(RegistryName, #key, __VA_ARGS__)
} // namespace mace } // namespace mace
#endif // MACE_CORE_REGISTRY_H_ #endif // MACE_CORE_REGISTRY_H_
...@@ -4,19 +4,18 @@ ...@@ -4,19 +4,18 @@
#include "mace/core/serializer.h" #include "mace/core/serializer.h"
namespace mace { namespace mace {
unique_ptr<TensorProto> Serializer::Serialize(const Tensor &tensor, unique_ptr<TensorProto> Serializer::Serialize(const Tensor &tensor,
const string &name) { const string &name) {
MACE_NOT_IMPLEMENTED; MACE_NOT_IMPLEMENTED;
return nullptr; return nullptr;
} }
unique_ptr<Tensor> Serializer::Deserialize(const TensorProto &proto, unique_ptr<Tensor> Serializer::Deserialize(const TensorProto &proto,
DeviceType type) { DeviceType type) {
unique_ptr<Tensor> tensor(new Tensor(GetDeviceAllocator(type), unique_ptr<Tensor> tensor(
proto.data_type())); new Tensor(GetDeviceAllocator(type), proto.data_type()));
vector<index_t> dims; vector<index_t> dims;
for (const index_t d : proto.dims()) { for (const index_t d : proto.dims()) {
dims.push_back(d); dims.push_back(d);
...@@ -25,8 +24,7 @@ unique_ptr<Tensor> Serializer::Deserialize(const TensorProto &proto, ...@@ -25,8 +24,7 @@ unique_ptr<Tensor> Serializer::Deserialize(const TensorProto &proto,
switch (proto.data_type()) { switch (proto.data_type()) {
case DT_FLOAT: case DT_FLOAT:
tensor->Copy<float>(proto.float_data().data(), tensor->Copy<float>(proto.float_data().data(), proto.float_data().size());
proto.float_data().size());
break; break;
case DT_DOUBLE: case DT_DOUBLE:
tensor->Copy<double>(proto.double_data().data(), tensor->Copy<double>(proto.double_data().data(),
...@@ -34,39 +32,38 @@ unique_ptr<Tensor> Serializer::Deserialize(const TensorProto &proto, ...@@ -34,39 +32,38 @@ unique_ptr<Tensor> Serializer::Deserialize(const TensorProto &proto,
break; break;
case DT_INT32: case DT_INT32:
tensor->template Copy<int32_t>(proto.int32_data().data(), tensor->template Copy<int32_t>(proto.int32_data().data(),
proto.int32_data().size()); proto.int32_data().size());
break; break;
case DT_UINT8: case DT_UINT8:
tensor->CopyWithCast<int32_t, uint8_t>(proto.int32_data().data(), tensor->CopyWithCast<int32_t, uint8_t>(proto.int32_data().data(),
proto.int32_data().size()); proto.int32_data().size());
break; break;
case DT_INT16: case DT_INT16:
tensor->CopyWithCast<int32_t, int16_t>(proto.int32_data().data(), tensor->CopyWithCast<int32_t, int16_t>(proto.int32_data().data(),
proto.int32_data().size()); proto.int32_data().size());
break; break;
case DT_INT8: case DT_INT8:
tensor->CopyWithCast<int32_t, int8_t>(proto.int32_data().data(), tensor->CopyWithCast<int32_t, int8_t>(proto.int32_data().data(),
proto.int32_data().size()); proto.int32_data().size());
break; break;
case DT_INT64: case DT_INT64:
tensor->Copy<int64_t>(proto.int64_data().data(), tensor->Copy<int64_t>(proto.int64_data().data(),
proto.int64_data().size()); proto.int64_data().size());
break; break;
case DT_UINT16: case DT_UINT16:
tensor->CopyWithCast<int32_t, uint16_t>(proto.int32_data().data(), tensor->CopyWithCast<int32_t, uint16_t>(proto.int32_data().data(),
proto.int32_data().size()); proto.int32_data().size());
break; break;
case DT_BOOL: case DT_BOOL:
tensor->CopyWithCast<int32_t, bool>(proto.int32_data().data(), tensor->CopyWithCast<int32_t, bool>(proto.int32_data().data(),
proto.int32_data().size()); proto.int32_data().size());
break; break;
case DT_STRING: { case DT_STRING: {
string *content = tensor->mutable_data<string>(); string *content = tensor->mutable_data<string>();
for (int i = 0; i < proto.string_data().size(); ++i) { for (int i = 0; i < proto.string_data().size(); ++i) {
content[i] = proto.string_data(i); content[i] = proto.string_data(i);
} }
} } break;
break;
default: default:
MACE_NOT_IMPLEMENTED; MACE_NOT_IMPLEMENTED;
break; break;
...@@ -75,4 +72,4 @@ unique_ptr<Tensor> Serializer::Deserialize(const TensorProto &proto, ...@@ -75,4 +72,4 @@ unique_ptr<Tensor> Serializer::Deserialize(const TensorProto &proto,
return tensor; return tensor;
} }
} // namespace mace } // namespace mace
\ No newline at end of file \ No newline at end of file
...@@ -5,9 +5,9 @@ ...@@ -5,9 +5,9 @@
#ifndef MACE_CORE_SERIALIZER_H_ #ifndef MACE_CORE_SERIALIZER_H_
#define MACE_CORE_SERIALIZER_H_ #define MACE_CORE_SERIALIZER_H_
#include "mace/proto/mace.pb.h"
#include "mace/core/common.h" #include "mace/core/common.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/proto/mace.pb.h"
namespace mace { namespace mace {
...@@ -20,9 +20,9 @@ class Serializer { ...@@ -20,9 +20,9 @@ class Serializer {
unique_ptr<Tensor> Deserialize(const TensorProto& proto, DeviceType type); unique_ptr<Tensor> Deserialize(const TensorProto& proto, DeviceType type);
DISABLE_COPY_AND_ASSIGN(Serializer); DISABLE_COPY_AND_ASSIGN(Serializer);
}; };
} // namespace mace } // namespace mace
#endif // MACE_CORE_SERIALIZER_H_ #endif // MACE_CORE_SERIALIZER_H_
...@@ -5,11 +5,11 @@ ...@@ -5,11 +5,11 @@
#ifndef MACE_CORE_TENSOR_H_ #ifndef MACE_CORE_TENSOR_H_
#define MACE_CORE_TENSOR_H_ #define MACE_CORE_TENSOR_H_
#include "mace/core/common.h"
#include "mace/proto/mace.pb.h"
#include "mace/core/allocator.h" #include "mace/core/allocator.h"
#include "mace/core/types.h" #include "mace/core/common.h"
#include "mace/core/logging.h" #include "mace/core/logging.h"
#include "mace/core/types.h"
#include "mace/proto/mace.pb.h"
namespace mace { namespace mace {
...@@ -25,13 +25,13 @@ namespace mace { ...@@ -25,13 +25,13 @@ namespace mace {
switch (TYPE_ENUM) { \ switch (TYPE_ENUM) { \
CASE(float, SINGLE_ARG(STMTS)) \ CASE(float, SINGLE_ARG(STMTS)) \
CASE(double, SINGLE_ARG(STMTS)) \ CASE(double, SINGLE_ARG(STMTS)) \
CASE(int32_t, SINGLE_ARG(STMTS)) \ CASE(int32_t, SINGLE_ARG(STMTS)) \
CASE(uint8_t, SINGLE_ARG(STMTS)) \ CASE(uint8_t, SINGLE_ARG(STMTS)) \
CASE(uint16_t, SINGLE_ARG(STMTS)) \ CASE(uint16_t, SINGLE_ARG(STMTS)) \
CASE(int16_t, SINGLE_ARG(STMTS)) \ CASE(int16_t, SINGLE_ARG(STMTS)) \
CASE(int8_t, SINGLE_ARG(STMTS)) \ CASE(int8_t, SINGLE_ARG(STMTS)) \
CASE(string, SINGLE_ARG(STMTS)) \ CASE(string, SINGLE_ARG(STMTS)) \
CASE(int64_t, SINGLE_ARG(STMTS)) \ CASE(int64_t, SINGLE_ARG(STMTS)) \
CASE(bool, SINGLE_ARG(STMTS)) \ CASE(bool, SINGLE_ARG(STMTS)) \
case DT_INVALID: \ case DT_INVALID: \
INVALID; \ INVALID; \
...@@ -41,20 +41,17 @@ namespace mace { ...@@ -41,20 +41,17 @@ namespace mace {
break; \ break; \
} }
#define CASES(TYPE_ENUM, STMTS) \ #define CASES(TYPE_ENUM, STMTS) \
CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, LOG(FATAL) << "Type not set"; \ CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, LOG(FATAL) << "Type not set"; \
, LOG(FATAL) << "Unexpected type: " << TYPE_ENUM;) , LOG(FATAL) << "Unexpected type: " << TYPE_ENUM;)
class Tensor { class Tensor {
public: public:
Tensor() Tensor()
: alloc_(cpu_allocator()), : alloc_(cpu_allocator()), size_(0), dtype_(DT_FLOAT), data_(nullptr){};
size_(0), dtype_(DT_FLOAT), data_(nullptr) {};
Tensor(Allocator* a, DataType type) Tensor(Allocator* a, DataType type)
: alloc_(a), size_(0), dtype_(type), data_(nullptr) {}; : alloc_(a), size_(0), dtype_(type), data_(nullptr){};
~Tensor() { ~Tensor() {
if (alloc_ && data_.get()) { if (alloc_ && data_.get()) {
...@@ -92,9 +89,8 @@ class Tensor { ...@@ -92,9 +89,8 @@ class Tensor {
if (data_.get() || size_ == 0) { if (data_.get() || size_ == 0) {
return data_.get(); return data_.get();
} else { } else {
CASES(dtype_, data_.reset(alloc_->New(size_ * sizeof(T)), [this](void* ptr) { CASES(dtype_, data_.reset(alloc_->New(size_ * sizeof(T)),
alloc_->Delete(ptr); [this](void* ptr) { alloc_->Delete(ptr); }));
}));
return data_.get(); return data_.get();
} }
} }
...@@ -116,13 +112,9 @@ class Tensor { ...@@ -116,13 +112,9 @@ class Tensor {
} }
} }
inline void ResizeLike(const Tensor& other) { inline void ResizeLike(const Tensor& other) { Resize(other.shape()); }
Resize(other.shape());
}
inline void ResizeLike(const Tensor* other) { inline void ResizeLike(const Tensor* other) { Resize(other->shape()); }
Resize(other->shape());
}
template <typename T> template <typename T>
inline void Copy(const T* src, index_t size) { inline void Copy(const T* src, index_t size) {
...@@ -132,7 +124,8 @@ class Tensor { ...@@ -132,7 +124,8 @@ class Tensor {
template <typename SrcType, typename DstType> template <typename SrcType, typename DstType>
inline void CopyWithCast(const SrcType* src, size_t size) { inline void CopyWithCast(const SrcType* src, size_t size) {
MACE_CHECK(static_cast<index_t>(size) == size_, "copy src and dst with different size."); MACE_CHECK(static_cast<index_t>(size) == size_,
"copy src and dst with different size.");
unique_ptr<DstType[]> buffer(new DstType[size]); unique_ptr<DstType[]> buffer(new DstType[size]);
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
buffer[i] = static_cast<DstType>(src[i]); buffer[i] = static_cast<DstType>(src[i]);
...@@ -146,10 +139,11 @@ class Tensor { ...@@ -146,10 +139,11 @@ class Tensor {
inline void DebugPrint() { inline void DebugPrint() {
std::stringstream os; std::stringstream os;
for (int i: shape_) { for (int i : shape_) {
os << i << ", "; os << i << ", ";
} }
LOG(INFO) << "Tensor shape: " << os.str() << " type: " << DataType_Name(dtype_); LOG(INFO) << "Tensor shape: " << os.str()
<< " type: " << DataType_Name(dtype_);
os.str(""); os.str("");
os.clear(); os.clear();
...@@ -175,7 +169,8 @@ class Tensor { ...@@ -175,7 +169,8 @@ class Tensor {
private: private:
inline int64_t NumElements() const { inline int64_t NumElements() const {
return std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<int64_t>()); return std::accumulate(shape_.begin(), shape_.end(), 1,
std::multiplies<int64_t>());
} }
Allocator* alloc_; Allocator* alloc_;
...@@ -184,9 +179,9 @@ class Tensor { ...@@ -184,9 +179,9 @@ class Tensor {
std::shared_ptr<void> data_; std::shared_ptr<void> data_;
vector<index_t> shape_; vector<index_t> shape_;
DISABLE_COPY_AND_ASSIGN(Tensor); DISABLE_COPY_AND_ASSIGN(Tensor);
}; };
} // namespace tensor } // namespace tensor
#endif //MACE_CORE_TENSOR_H_ #endif // MACE_CORE_TENSOR_H_
...@@ -51,11 +51,8 @@ Benchmark* Benchmark::ArgPair(int x, int y) { ...@@ -51,11 +51,8 @@ Benchmark* Benchmark::ArgPair(int x, int y) {
return this; return this;
} }
// Run all benchmarks // Run all benchmarks
void Benchmark::Run() { void Benchmark::Run() { Run("all"); }
Run("all");
}
void Benchmark::Run(const char* pattern) { void Benchmark::Run(const char* pattern) {
if (!all_benchmarks) return; if (!all_benchmarks) return;
...@@ -113,8 +110,8 @@ void Benchmark::Run(const char* pattern) { ...@@ -113,8 +110,8 @@ void Benchmark::Run(const char* pattern) {
(items_processed * 1e-6) / seconds); (items_processed * 1e-6) / seconds);
full_label += buf; full_label += buf;
} }
printf("%-*s %10.0f %10d\t%s\n", width, name, printf("%-*s %10.0f %10d\t%s\n", width, name, seconds * 1e9 / iters,
seconds * 1e9 / iters, iters, full_label.c_str()); iters, full_label.c_str());
} }
} }
} }
......
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
#include "mace/core/types.h" #include "mace/core/types.h"
#define MACE_BENCHMARK_CONCAT(a, b, c) a##b##c #define MACE_BENCHMARK_CONCAT(a, b, c) a##b##c
#define BENCHMARK(n) \ #define BENCHMARK(n) \
static ::mace::testing::Benchmark* MACE_BENCHMARK_CONCAT(__benchmark_, n, __LINE__) = \ static ::mace::testing::Benchmark* MACE_BENCHMARK_CONCAT( \
(new ::mace::testing::Benchmark(#n, (n))) __benchmark_, n, __LINE__) = (new ::mace::testing::Benchmark(#n, (n)))
namespace mace { namespace mace {
namespace testing { namespace testing {
......
...@@ -17,4 +17,3 @@ int main(int argc, char** argv) { ...@@ -17,4 +17,3 @@ int main(int argc, char** argv) {
} }
return 0; return 0;
} }
...@@ -18,26 +18,25 @@ struct DataTypeToEnum { ...@@ -18,26 +18,25 @@ struct DataTypeToEnum {
static_assert(IsValidDataType<T>::value, "Specified Data Type not supported"); static_assert(IsValidDataType<T>::value, "Specified Data Type not supported");
}; };
// EnumToDataType<VALUE>::Type is the type for DataType constant VALUE, e.g. // EnumToDataType<VALUE>::Type is the type for DataType constant VALUE, e.g.
// EnumToDataType<DT_FLOAT>::Type is float. // EnumToDataType<DT_FLOAT>::Type is float.
template <DataType VALUE> template <DataType VALUE>
struct EnumToDataType {}; // Specializations below struct EnumToDataType {}; // Specializations below
// Template specialization for both DataTypeToEnum and EnumToDataType. // Template specialization for both DataTypeToEnum and EnumToDataType.
#define MATCH_TYPE_AND_ENUM(TYPE, ENUM) \ #define MATCH_TYPE_AND_ENUM(TYPE, ENUM) \
template <> \ template <> \
struct DataTypeToEnum<TYPE> { \ struct DataTypeToEnum<TYPE> { \
static DataType v() { return ENUM; } \ static DataType v() { return ENUM; } \
static constexpr DataType value = ENUM; \ static constexpr DataType value = ENUM; \
}; \ }; \
template <> \ template <> \
struct IsValidDataType<TYPE> { \ struct IsValidDataType<TYPE> { \
static constexpr bool value = true; \ static constexpr bool value = true; \
}; \ }; \
template <> \ template <> \
struct EnumToDataType<ENUM> { \ struct EnumToDataType<ENUM> { \
typedef TYPE Type; \ typedef TYPE Type; \
} }
MATCH_TYPE_AND_ENUM(float, DT_FLOAT); MATCH_TYPE_AND_ENUM(float, DT_FLOAT);
...@@ -53,6 +52,6 @@ MATCH_TYPE_AND_ENUM(bool, DT_BOOL); ...@@ -53,6 +52,6 @@ MATCH_TYPE_AND_ENUM(bool, DT_BOOL);
static const int32_t kint32_tmax = ((int32_t)0x7FFFFFFF); static const int32_t kint32_tmax = ((int32_t)0x7FFFFFFF);
} // namespace mace } // namespace mace
#endif // MACE_CORE_TYPES_H_ #endif // MACE_CORE_TYPES_H_
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
// Copyright (c) 2017 XiaoMi All rights reserved. // Copyright (c) 2017 XiaoMi All rights reserved.
// //
#include "mace/core/common.h"
#include "mace/core/workspace.h" #include "mace/core/workspace.h"
#include "mace/core/common.h"
#include "mace/core/serializer.h" #include "mace/core/serializer.h"
namespace mace { namespace mace {
...@@ -16,8 +16,7 @@ vector<string> Workspace::Tensors() const { ...@@ -16,8 +16,7 @@ vector<string> Workspace::Tensors() const {
return names; return names;
} }
Tensor* Workspace::CreateTensor(const string& name, Tensor* Workspace::CreateTensor(const string& name, Allocator* alloc,
Allocator* alloc,
DataType type) { DataType type) {
if (HasTensor(name)) { if (HasTensor(name)) {
VLOG(1) << "Tensor " << name << " already exists. Skipping."; VLOG(1) << "Tensor " << name << " already exists. Skipping.";
...@@ -46,14 +45,16 @@ const Tensor* Workspace::GetTensor(const string& name) const { ...@@ -46,14 +45,16 @@ const Tensor* Workspace::GetTensor(const string& name) const {
} }
Tensor* Workspace::GetTensor(const string& name) { Tensor* Workspace::GetTensor(const string& name) {
return const_cast<Tensor*>(static_cast<const Workspace*>(this)->GetTensor(name)); return const_cast<Tensor*>(
static_cast<const Workspace*>(this)->GetTensor(name));
} }
void Workspace::LoadModelTensor(const NetDef &net_def, DeviceType type) { void Workspace::LoadModelTensor(const NetDef& net_def, DeviceType type) {
Serializer serializer; Serializer serializer;
for (auto& tensor_proto: net_def.tensors()) { for (auto& tensor_proto : net_def.tensors()) {
tensor_map_[tensor_proto.name()] = serializer.Deserialize(tensor_proto, type); tensor_map_[tensor_proto.name()] =
serializer.Deserialize(tensor_proto, type);
} }
} }
} // namespace mace } // namespace mace
\ No newline at end of file \ No newline at end of file
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
#ifndef MACE_CORE_WORKSPACE_H_ #ifndef MACE_CORE_WORKSPACE_H_
#define MACE_CORE_WORKSPACE_H_ #define MACE_CORE_WORKSPACE_H_
#include "mace/core/common.h" #include "mace/core/common.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/proto/mace.pb.h" #include "mace/proto/mace.pb.h"
...@@ -37,10 +36,9 @@ class Workspace { ...@@ -37,10 +36,9 @@ class Workspace {
private: private:
TensorMap tensor_map_; TensorMap tensor_map_;
DISABLE_COPY_AND_ASSIGN(Workspace); DISABLE_COPY_AND_ASSIGN(Workspace);
}; };
} // namespace mace } // namespace mace
#endif // MACE_CORE_WORKSPACE_H_ #endif // MACE_CORE_WORKSPACE_H_
...@@ -14,7 +14,7 @@ static void foo(int iters) { ...@@ -14,7 +14,7 @@ static void foo(int iters) {
float* out = new float[N]; float* out = new float[N];
while (iters--) { while (iters--) {
for (int i=0; i < N; i++) { for (int i = 0; i < N; i++) {
out[i] = inp[i] * 2.0; out[i] = inp[i] * 2.0;
} }
} }
...@@ -24,7 +24,6 @@ static void foo(int iters) { ...@@ -24,7 +24,6 @@ static void foo(int iters) {
BENCHMARK(foo); BENCHMARK(foo);
static void bar(int iters, int n) { static void bar(int iters, int n) {
const int64_t tot = static_cast<int64_t>(iters) * n; const int64_t tot = static_cast<int64_t>(iters) * n;
mace::testing::ItemsProcessed(tot); mace::testing::ItemsProcessed(tot);
...@@ -34,7 +33,7 @@ static void bar(int iters, int n) { ...@@ -34,7 +33,7 @@ static void bar(int iters, int n) {
float* out = new float[n]; float* out = new float[n];
while (iters--) { while (iters--) {
for (int i=0; i < n; i++) { for (int i = 0; i < n; i++) {
out[i] = inp[i] * 2.0; out[i] = inp[i] * 2.0;
} }
} }
......
...@@ -10,10 +10,9 @@ ...@@ -10,10 +10,9 @@
namespace mace { namespace mace {
namespace kernels { namespace kernels {
template<DeviceType D, typename T> template <DeviceType D, typename T>
struct AddNFunctor { struct AddNFunctor {
void operator()(const vector<const T*>& inputs, void operator()(const vector<const T*>& inputs, T* output, index_t size) {
T *output, index_t size) {
memset(output, 0, size * sizeof(T)); memset(output, 0, size * sizeof(T));
int n = inputs.size(); int n = inputs.size();
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
...@@ -25,11 +24,10 @@ struct AddNFunctor { ...@@ -25,11 +24,10 @@ struct AddNFunctor {
}; };
template <> template <>
void AddNFunctor<DeviceType::NEON, float>::operator()(const vector<const float*>& inputs, void AddNFunctor<DeviceType::NEON, float>::operator()(
float *output, const vector<const float*>& inputs, float* output, index_t size);
index_t size);
} // namespace kernels } // namespace kernels
} // namespace mace } // namespace mace
#endif // MACE_KERNELS_ADDN_H_ #endif // MACE_KERNELS_ADDN_H_
\ No newline at end of file \ No newline at end of file
...@@ -11,26 +11,21 @@ ...@@ -11,26 +11,21 @@
namespace mace { namespace mace {
namespace kernels { namespace kernels {
template<DeviceType D, typename T> template <DeviceType D, typename T>
struct BatchNormFunctor { struct BatchNormFunctor {
float variance_epsilon_; float variance_epsilon_;
BatchNormFunctor(const float variance_epsilon) BatchNormFunctor(const float variance_epsilon)
: variance_epsilon_(variance_epsilon){} : variance_epsilon_(variance_epsilon) {}
void operator()(const T* input, void operator()(const T* input, const T* scale, const T* offset,
const T* scale, const T* mean, const T* var, const index_t n,
const T* offset, const index_t channel, const index_t sample_size, T* output) {
const T* mean,
const T* var,
const index_t n,
const index_t channel,
const index_t sample_size,
T* output) {
// Batch normalization in the paper https://arxiv.org/abs/1502.03167 . // Batch normalization in the paper https://arxiv.org/abs/1502.03167 .
// The calculation formula for inference is // The calculation formula for inference is
// Y = \frac{ \scale } { \sqrt{var+\variance_epsilon} } * X + // Y = \frac{ \scale } { \sqrt{var+\variance_epsilon} } * X +
// ( \offset - \frac { \scale * mean } { \sqrt{var+\variance_epsilon} } // ( \offset - \frac { \scale * mean } {
// \sqrt{var+\variance_epsilon} }
// new_scale = \frac{ \scale } { \sqrt{var+\variance_epsilon} } // new_scale = \frac{ \scale } { \sqrt{var+\variance_epsilon} }
// new_offset = \offset - mean * common_val; // new_offset = \offset - mean * common_val;
// Y = new_scale * X + new_offset; // Y = new_scale * X + new_offset;
...@@ -53,18 +48,12 @@ struct BatchNormFunctor { ...@@ -53,18 +48,12 @@ struct BatchNormFunctor {
}; };
template <> template <>
void BatchNormFunctor<DeviceType::NEON, float>::operator()(const float* input, void BatchNormFunctor<DeviceType::NEON, float>::operator()(
const float* scale, const float* input, const float* scale, const float* offset,
const float* offset, const float* mean, const float* var, const index_t n, const index_t channel,
const float* mean, const index_t sample_size, float* output);
const float* var,
const index_t n,
const index_t channel,
const index_t sample_size,
float* output);
} // namepsace kernels
} // namespace mace
} // namepsace kernels #endif // MACE_KERNELS_BATCH_NORM_H_
} // namespace mace
#endif // MACE_KERNELS_BATCH_NORM_H_
...@@ -10,114 +10,103 @@ ...@@ -10,114 +10,103 @@
namespace mace { namespace mace {
namespace kernels { namespace kernels {
template<DeviceType D, typename T> template <DeviceType D, typename T>
class Conv2dFunctor { class Conv2dFunctor {
public: public:
Conv2dFunctor(const int* strides, Conv2dFunctor(const int* strides, const int* paddings, const int* dilations)
const int* paddings, : strides_(strides), paddings_(paddings), dilations_(dilations) {}
const int* dilations) :
strides_(strides), void operator()(const T* input, // NCHW
paddings_(paddings), const index_t* input_shape,
dilations_(dilations) {} const T* filter, // c_out, c_in, kernel_h, kernel_w
const index_t* filter_shape,
void operator()(const T* input, // NCHW const T* bias, // c_out
const index_t* input_shape, T* output, // NCHW
const T* filter, // c_out, c_in, kernel_h, kernel_w const index_t* output_shape) {
const index_t* filter_shape, MACE_CHECK_NOTNULL(output);
const T* bias, // c_out
T* output, // NCHW index_t batch = output_shape[0];
const index_t* output_shape) { index_t channels = output_shape[1];
MACE_CHECK_NOTNULL(output); index_t height = output_shape[2];
index_t width = output_shape[3];
index_t batch = output_shape[0];
index_t channels = output_shape[1]; index_t input_batch = input_shape[0];
index_t height = output_shape[2]; index_t input_channels = input_shape[1];
index_t width = output_shape[3]; index_t input_height = input_shape[2];
index_t input_width = input_shape[3];
index_t input_batch = input_shape[0];
index_t input_channels = input_shape[1]; index_t kernel_h = filter_shape[2];
index_t input_height = input_shape[2]; index_t kernel_w = filter_shape[3];
index_t input_width = input_shape[3];
int stride_h = strides_[0];
index_t kernel_h = filter_shape[2]; int stride_w = strides_[1];
index_t kernel_w = filter_shape[3];
int dilation_h = dilations_[0];
int stride_h = strides_[0]; int dilation_w = dilations_[1];
int stride_w = strides_[1];
MACE_CHECK(batch == input_batch, "Input/Output batch size mismatch");
int dilation_h = dilations_[0];
int dilation_w = dilations_[1]; // The left-upper most offset of the padded input
int padded_h_start = 0 - paddings_[0] / 2;
MACE_CHECK(batch == input_batch, "Input/Output batch size mismatch"); int padded_w_start = 0 - paddings_[1] / 2;
index_t padded_h_stop = input_height + paddings_[0] - paddings_[0] / 2;
// The left-upper most offset of the padded input index_t padded_w_stop = input_width + paddings_[1] - paddings_[1] / 2;
int padded_h_start = 0 - paddings_[0] / 2;
int padded_w_start = 0 - paddings_[1] / 2; index_t kernel_size = input_channels * kernel_h * kernel_w;
index_t padded_h_stop = input_height + paddings_[0] - paddings_[0] / 2;
index_t padded_w_stop = input_width + paddings_[1] - paddings_[1] / 2;
index_t kernel_size = input_channels * kernel_h * kernel_w;
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (int n = 0; n < batch; ++n) { for (int n = 0; n < batch; ++n) {
for (int c = 0; c < channels; ++c) { for (int c = 0; c < channels; ++c) {
for (int h = 0; h < height; ++h) { for (int h = 0; h < height; ++h) {
for (int w = 0; w < width; ++w) { for (int w = 0; w < width; ++w) {
index_t offset = n * channels * height * width + index_t offset = n * channels * height * width +
c * height * width + c * height * width + h * width + w;
h * width + w; T sum = 0;
T sum = 0; const T* filter_ptr = filter + c * kernel_size;
const T* filter_ptr = filter + c * kernel_size; for (int inc = 0; inc < input_channels; ++inc) {
for (int inc = 0; inc < input_channels; ++inc) { for (int kh = 0; kh < kernel_h; ++kh) {
for (int kh = 0; kh < kernel_h; ++kh) { for (int kw = 0; kw < kernel_w; ++kw) {
for (int kw = 0; kw < kernel_w; ++kw) { int inh = padded_h_start + h * stride_h + dilation_h * kh;
int inw = padded_w_start + w * stride_w + dilation_w * kw;
int inh = padded_h_start + h * stride_h + dilation_h * kh; if (inh < 0 || inh >= input_height || inw < 0 ||
int inw = padded_w_start + w * stride_w + dilation_w * kw; inw >= input_width) {
if (inh < 0 || inh >= input_height || MACE_CHECK(inh >= padded_h_start && inh < padded_h_stop &&
inw < 0 || inw >= input_width) { inw >= padded_w_start && inw < padded_w_stop,
MACE_CHECK(inh >= padded_h_start && "Out of range read from input: ", inh, ", ",
inh < padded_h_stop && inw);
inw >= padded_w_start && // else padding with 0:
inw < padded_w_stop, // sum += 0;
"Out of range read from input: ", } else {
inh, ", ", inw); index_t input_offset =
// else padding with 0:
// sum += 0;
} else {
index_t input_offset =
n * input_channels * input_height * input_width + n * input_channels * input_height * input_width +
inc * input_height * input_width + inc * input_height * input_width + inh * input_width +
inh * input_width + inw; inw;
sum += input[input_offset] * *filter_ptr; sum += input[input_offset] * *filter_ptr;
}
++filter_ptr;
} }
++filter_ptr;
} }
output[offset] = sum + bias[c];
} }
output[offset] = sum + bias[c];
} }
} }
} }
} }
} }
}
private: private:
const int* strides_; // [stride_h, stride_w] const int* strides_; // [stride_h, stride_w]
const int* paddings_; // [padding_h, padding_w] const int* paddings_; // [padding_h, padding_w]
const int* dilations_; // [dilation_h, dilation_w] const int* dilations_; // [dilation_h, dilation_w]
}; };
template <> template <>
void Conv2dFunctor<DeviceType::NEON, float>::operator()(const float* input, void Conv2dFunctor<DeviceType::NEON, float>::operator()(
const index_t* input_shape, const float* input, const index_t* input_shape, const float* filter,
const float* filter, const index_t* filter_shape, const float* bias, float* output,
const index_t* filter_shape, const index_t* output_shape);
const float* bias,
float* output, } // namespace kernels
const index_t* output_shape); } // namespace mace
} // namespace kernels #endif // MACE_KERNELS_CONV_2D_H_
} // namespace mace
#endif // MACE_KERNELS_CONV_2D_H_
...@@ -7,12 +7,10 @@ ...@@ -7,12 +7,10 @@
namespace mace { namespace mace {
namespace kernels { namespace kernels {
void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW
const index_t *filter_shape, // OIHW const index_t *filter_shape, // OIHW
const int *dilations, const int *dilations, const int *strides,
const int *strides, Padding padding, index_t *output_shape,
Padding padding,
index_t *output_shape,
int *padding_size) { int *padding_size) {
MACE_CHECK(dilations[0] > 0 && dilations[1] > 0, MACE_CHECK(dilations[0] > 0 && dilations[1] > 0,
"Invalid dilations, must >= 1"); "Invalid dilations, must >= 1");
...@@ -43,14 +41,16 @@ void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW ...@@ -43,14 +41,16 @@ void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW
output_height = (input_shape[2] - k_extent_height) / strides[0] + 1; output_height = (input_shape[2] - k_extent_height) / strides[0] + 1;
output_width = (input_shape[3] - k_extent_width) / strides[1] + 1; output_width = (input_shape[3] - k_extent_width) / strides[1] + 1;
break; break;
case SAME:output_height = (input_shape[2] - 1) / strides[0] + 1; case SAME:
output_height = (input_shape[2] - 1) / strides[0] + 1;
output_width = (input_shape[3] - 1) / strides[1] + 1; output_width = (input_shape[3] - 1) / strides[1] + 1;
break; break;
case FULL: case FULL:
output_height = (input_shape[2] + k_extent_height - 2) / strides[0] + 1; output_height = (input_shape[2] + k_extent_height - 2) / strides[0] + 1;
output_width = (input_shape[3] + k_extent_width - 2) / strides[1] + 1; output_width = (input_shape[3] + k_extent_width - 2) / strides[1] + 1;
break; break;
default:MACE_CHECK(false, "Unsupported padding type: ", padding); default:
MACE_CHECK(false, "Unsupported padding type: ", padding);
} }
// Note: TensorFlow may padded one more on the right/bottom side // Note: TensorFlow may padded one more on the right/bottom side
...@@ -58,10 +58,10 @@ void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW ...@@ -58,10 +58,10 @@ void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW
// utilize the more centered features. We need to benchmark // utilize the more centered features. We need to benchmark
// based on the model accuracy. // based on the model accuracy.
padding_size[0] = (output_height - 1) * strides[0] + padding_size[0] =
k_extent_height - input_shape[2]; (output_height - 1) * strides[0] + k_extent_height - input_shape[2];
padding_size[1] = (output_width - 1) * strides[1] + padding_size[1] =
k_extent_width - input_shape[3]; (output_width - 1) * strides[1] + k_extent_width - input_shape[3];
output_shape[0] = input_shape[0]; output_shape[0] = input_shape[0];
output_shape[1] = output_channels; output_shape[1] = output_channels;
...@@ -69,19 +69,15 @@ void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW ...@@ -69,19 +69,15 @@ void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW
output_shape[3] = output_width; output_shape[3] = output_width;
} }
void ConstructInputWithPadding(const float *input, void ConstructInputWithPadding(const float *input, const index_t *input_shape,
const index_t *input_shape, const int *paddings, Tensor *output_tensor) {
const int *paddings,
Tensor *output_tensor) {
index_t batch = input_shape[0]; index_t batch = input_shape[0];
index_t channels = input_shape[1]; index_t channels = input_shape[1];
index_t height = input_shape[2]; index_t height = input_shape[2];
index_t width = input_shape[3]; index_t width = input_shape[3];
std::vector<index_t> output_shape({batch, std::vector<index_t> output_shape(
channels, {batch, channels, paddings[0] + height, paddings[1] + width});
paddings[0] + height,
paddings[1] + width});
const index_t output_width = output_shape[3]; const index_t output_width = output_shape[3];
const int padded_top = paddings[0] / 2; const int padded_top = paddings[0] / 2;
...@@ -105,5 +101,5 @@ void ConstructInputWithPadding(const float *input, ...@@ -105,5 +101,5 @@ void ConstructInputWithPadding(const float *input,
} }
} }
} }
} // namespace kernels } // namespace kernels
} // namespace mace } // namespace mace
...@@ -10,26 +10,22 @@ ...@@ -10,26 +10,22 @@
namespace mace { namespace mace {
enum Padding { enum Padding {
VALID = 0, // No padding VALID = 0, // No padding
SAME = 1, // Pads with half the filter size (rounded down) on both sides SAME = 1, // Pads with half the filter size (rounded down) on both sides
FULL = 2, // Pads with one less than the filter size on both sides FULL = 2, // Pads with one less than the filter size on both sides
}; };
namespace kernels { namespace kernels {
void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW
const index_t *filter_shape, // OIHW const index_t *filter_shape, // OIHW
const int *dilations, const int *dilations, const int *strides,
const int *strides, Padding padding, index_t *output_shape,
Padding padding,
index_t *output_shape,
int *padding_size); int *padding_size);
void ConstructInputWithPadding(const float *input, void ConstructInputWithPadding(const float *input, const index_t *input_shape,
const index_t *input_shape, const int *paddings, Tensor *output_tensor);
const int *paddings, } // namespace kernels
Tensor *output_tensor); } // namespace mace
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_CONV_POOL_2D_UTIL_H_ #endif // MACE_KERNELS_CONV_POOL_2D_UTIL_H_
...@@ -2,16 +2,15 @@ ...@@ -2,16 +2,15 @@
// Copyright (c) 2017 XiaoMi All rights reserved. // Copyright (c) 2017 XiaoMi All rights reserved.
// //
#include <arm_neon.h>
#include "mace/kernels/addn.h" #include "mace/kernels/addn.h"
#include <arm_neon.h>
namespace mace { namespace mace {
namespace kernels { namespace kernels {
template <> template <>
void AddNFunctor<DeviceType::NEON, float>::operator()(const vector<const float*>& inputs, void AddNFunctor<DeviceType::NEON, float>::operator()(
float *output, const vector<const float *> &inputs, float *output, index_t size) {
index_t size) {
// TODO: neon mem copy // TODO: neon mem copy
memset(output, 0, size * sizeof(float)); memset(output, 0, size * sizeof(float));
int n = inputs.size(); int n = inputs.size();
...@@ -22,7 +21,7 @@ void AddNFunctor<DeviceType::NEON, float>::operator()(const vector<const float*> ...@@ -22,7 +21,7 @@ void AddNFunctor<DeviceType::NEON, float>::operator()(const vector<const float*>
} }
int64_t element_per_group = size / groups; int64_t element_per_group = size / groups;
#pragma omp parallel for num_threads(1) // no significant performance improve #pragma omp parallel for num_threads(1) // no significant performance improve
for (int64_t i = 0; i < size; i += element_per_group) { for (int64_t i = 0; i < size; i += element_per_group) {
int64_t count = std::min(element_per_group, size - i); int64_t count = std::min(element_per_group, size - i);
int nn = count >> 2; int nn = count >> 2;
...@@ -48,5 +47,5 @@ void AddNFunctor<DeviceType::NEON, float>::operator()(const vector<const float*> ...@@ -48,5 +47,5 @@ void AddNFunctor<DeviceType::NEON, float>::operator()(const vector<const float*>
} }
}; };
} // namespace kernels } // namespace kernels
} // namespace mace } // namespace mace
\ No newline at end of file \ No newline at end of file
...@@ -2,29 +2,25 @@ ...@@ -2,29 +2,25 @@
// Copyright (c) 2017 XiaoMi All rights reserved. // Copyright (c) 2017 XiaoMi All rights reserved.
// //
#include <arm_neon.h>
#include "mace/kernels/batch_norm.h" #include "mace/kernels/batch_norm.h"
#include <arm_neon.h>
namespace mace { namespace mace {
namespace kernels { namespace kernels {
template <> template <>
void BatchNormFunctor<DeviceType::NEON, float>::operator()(const float* input, void BatchNormFunctor<DeviceType::NEON, float>::operator()(
const float* scale, const float* input, const float* scale, const float* offset,
const float* offset, const float* mean, const float* var, const index_t n, const index_t channel,
const float* mean, const index_t sample_size, float* output) {
const float* var, // Batch normalization in the paper https://arxiv.org/abs/1502.03167 .
const index_t n, // The calculation formula for inference is
const index_t channel, // Y = \frac{ \scale } { \sqrt{var+\variance_epsilon} } * X +
const index_t sample_size, // ( \offset - \frac { \scale * mean } { \sqrt{var+\variance_epsilon}
float* output) { // }
// Batch normalization in the paper https://arxiv.org/abs/1502.03167 . // new_scale = \frac{ \scale } { \sqrt{var+\variance_epsilon} }
// The calculation formula for inference is // new_offset = \offset - mean * common_val;
// Y = \frac{ \scale } { \sqrt{var+\variance_epsilon} } * X + // Y = new_scale * X + new_offset;
// ( \offset - \frac { \scale * mean } { \sqrt{var+\variance_epsilon} }
// new_scale = \frac{ \scale } { \sqrt{var+\variance_epsilon} }
// new_offset = \offset - mean * common_val;
// Y = new_scale * X + new_offset;
float new_scale, new_offset; float new_scale, new_offset;
index_t count = sample_size >> 2; index_t count = sample_size >> 2;
index_t remain_count = sample_size - (count << 2); index_t remain_count = sample_size - (count << 2);
...@@ -36,8 +32,8 @@ void BatchNormFunctor<DeviceType::NEON, float>::operator()(const float* input, ...@@ -36,8 +32,8 @@ void BatchNormFunctor<DeviceType::NEON, float>::operator()(const float* input,
float32x4_t new_scale_f = vdupq_n_f32(new_scale); float32x4_t new_scale_f = vdupq_n_f32(new_scale);
float32x4_t new_offset_f = vdupq_n_f32(new_offset); float32x4_t new_offset_f = vdupq_n_f32(new_offset);
for (index_t i = 0; i < n; ++i) { for (index_t i = 0; i < n; ++i) {
const float *input_sample_ptr = input + pos; const float* input_sample_ptr = input + pos;
float *output_sample_ptr = output + pos; float* output_sample_ptr = output + pos;
for (index_t j = 0; j < count; ++j) { for (index_t j = 0; j < count; ++j) {
float32x4_t input_f = vld1q_f32(input_sample_ptr); float32x4_t input_f = vld1q_f32(input_sample_ptr);
...@@ -57,5 +53,5 @@ void BatchNormFunctor<DeviceType::NEON, float>::operator()(const float* input, ...@@ -57,5 +53,5 @@ void BatchNormFunctor<DeviceType::NEON, float>::operator()(const float* input,
} }
}; };
} // namespace kernels } // namespace kernels
} // namespace mace } // namespace mace
\ No newline at end of file \ No newline at end of file
...@@ -20,62 +20,39 @@ extern void Conv2dNeonK5x5S1(const float *input, const index_t *input_shape, ...@@ -20,62 +20,39 @@ extern void Conv2dNeonK5x5S1(const float *input, const index_t *input_shape,
const float *filter, const float *bias, const float *filter, const float *bias,
float *output, const index_t *output_shape); float *output, const index_t *output_shape);
template<> template <>
void Conv2dFunctor<DeviceType::NEON, void Conv2dFunctor<DeviceType::NEON,
float>::operator()(const float *input, // NCHW float>::
const index_t *input_shape, operator()(const float *input, // NCHW
const float *filter, // c_out, c_in, kernel_h, kernel_w const index_t *input_shape,
const index_t *filter_shape, const float *filter, // c_out, c_in, kernel_h, kernel_w
const float *bias, // c_out const index_t *filter_shape,
float *output, // NCHW const float *bias, // c_out
const index_t *output_shape) { float *output, // NCHW
const index_t *output_shape) {
typedef void (*Conv2dNeonFunction)(const float *input, // NCHW typedef void (*Conv2dNeonFunction)(
const index_t *input_shape, const float *input, // NCHW
const float *filter, // c_out, c_in, kernel_h, kernel_w const index_t *input_shape,
const float *bias, // c_out const float *filter, // c_out, c_in, kernel_h, kernel_w
float *output, // NCHW const float *bias, // c_out
const index_t *output_shape); float *output, // NCHW
const index_t *output_shape);
// Selection matrix: kernel_size x stride_size // Selection matrix: kernel_size x stride_size
static const Conv2dNeonFunction selector[5][2] = { static const Conv2dNeonFunction selector[5][2] = {
{ {Conv2dNeonK1x1S1, nullptr},
Conv2dNeonK1x1S1, {nullptr, nullptr},
nullptr {Conv2dNeonK3x3S1, nullptr},
}, {nullptr, nullptr},
{ {Conv2dNeonK5x5S1, nullptr}};
nullptr,
nullptr
},
{
Conv2dNeonK3x3S1,
nullptr
},
{
nullptr,
nullptr
},
{
Conv2dNeonK5x5S1,
nullptr
}
};
// not implement yet // not implement yet
index_t kernel_h = filter_shape[2]; index_t kernel_h = filter_shape[2];
index_t kernel_w = filter_shape[3]; index_t kernel_w = filter_shape[3];
if (kernel_h != kernel_w || kernel_h > 5 || if (kernel_h != kernel_w || kernel_h > 5 || strides_[0] != strides_[1] ||
strides_[0] != strides_[1] || strides_[0] > 2 || strides_[0] > 2 || dilations_[0] != 1 || dilations_[1] != 1 ||
dilations_[0] != 1 || dilations_[1] != 1 ||
selector[kernel_h - 1][strides_[0] - 1] == nullptr) { selector[kernel_h - 1][strides_[0] - 1] == nullptr) {
LOG(WARNING) << "NEON conv2d kernel not implementated, using slow vesion"; LOG(WARNING) << "NEON conv2d kernel not implementated, using slow vesion";
Conv2dFunctor<DeviceType::CPU, float>(strides_, paddings_, dilations_)( Conv2dFunctor<DeviceType::CPU, float>(strides_, paddings_, dilations_)(
input, input, input_shape, filter, filter_shape, bias, output, output_shape);
input_shape,
filter,
filter_shape,
bias,
output,
output_shape
);
return; return;
} }
...@@ -87,13 +64,8 @@ void Conv2dFunctor<DeviceType::NEON, ...@@ -87,13 +64,8 @@ void Conv2dFunctor<DeviceType::NEON,
input_shape = padded_input.shape().data(); input_shape = padded_input.shape().data();
} }
auto conv2d_neon_func = selector[kernel_h - 1][strides_[0] - 1]; auto conv2d_neon_func = selector[kernel_h - 1][strides_[0] - 1];
conv2d_neon_func(input, conv2d_neon_func(input, input_shape, filter, bias, output, output_shape);
input_shape,
filter,
bias,
output,
output_shape);
} }
} // namespace kernels } // namespace kernels
} // namespace mace } // namespace mace
...@@ -8,25 +8,24 @@ ...@@ -8,25 +8,24 @@
namespace mace { namespace mace {
namespace kernels { namespace kernels {
void Conv2dNeonK1x1S1(const float* input, // NCHW void Conv2dNeonK1x1S1(const float* input, // NCHW
const index_t* input_shape, const index_t* input_shape,
const float* filter, // c_out, c_in, kernel_h, kernel_w const float* filter, // c_out, c_in, kernel_h, kernel_w
const float* bias, // c_out const float* bias, // c_out
float* output, // NCHW float* output, // NCHW
const index_t* output_shape) { const index_t* output_shape) {
const index_t batch = output_shape[0]; const index_t batch = output_shape[0];
const index_t channels = output_shape[1]; const index_t channels = output_shape[1];
const index_t height = output_shape[2]; const index_t height = output_shape[2];
const index_t width = output_shape[3]; const index_t width = output_shape[3];
const index_t input_batch = input_shape[0]; const index_t input_batch = input_shape[0];
const index_t input_channels = input_shape[1]; const index_t input_channels = input_shape[1];
const index_t input_height = input_shape[2]; const index_t input_height = input_shape[2];
const index_t input_width = input_shape[3]; const index_t input_width = input_shape[3];
MACE_CHECK(input_batch == batch && MACE_CHECK(input_batch == batch && input_height == height &&
input_height == height && input_width == width);
input_width == width);
const index_t total_pixels = height * width; const index_t total_pixels = height * width;
// Process 4 * 2 = 8 pixels for each innermost loop // Process 4 * 2 = 8 pixels for each innermost loop
...@@ -37,17 +36,18 @@ void Conv2dNeonK1x1S1(const float* input, // NCHW ...@@ -37,17 +36,18 @@ void Conv2dNeonK1x1S1(const float* input, // NCHW
// benchmark omp collapsed(2) // benchmark omp collapsed(2)
for (index_t n = 0; n < batch; ++n) { for (index_t n = 0; n < batch; ++n) {
const float* filter_ptr = filter; const float* filter_ptr = filter;
#pragma omp parallel for #pragma omp parallel for
for (index_t c = 0; c < channels; ++c) { for (index_t c = 0; c < channels; ++c) {
// TODO Will GCC opt these out? // TODO Will GCC opt these out?
float* channel_output_start = float* channel_output_start =
output + n * channels * height * width + c * height * width; output + n * channels * height * width + c * height * width;
const float* input_ptr = input + n * input_channels * input_height * input_width; const float* input_ptr =
input + n * input_channels * input_height * input_width;
// Fill with bias // Fill with bias
float* output_ptr = channel_output_start; float* output_ptr = channel_output_start;
for (index_t ptr = 0; ptr < total_pixels; ++ptr) { for (index_t ptr = 0; ptr < total_pixels; ++ptr) {
output_ptr[ptr] = bias[c]; // TODO can we avoid this? output_ptr[ptr] = bias[c]; // TODO can we avoid this?
} }
index_t inc = 0; index_t inc = 0;
...@@ -55,15 +55,14 @@ void Conv2dNeonK1x1S1(const float* input, // NCHW ...@@ -55,15 +55,14 @@ void Conv2dNeonK1x1S1(const float* input, // NCHW
for (; inc + 3 < input_channels; inc += 4) { for (; inc + 3 < input_channels; inc += 4) {
float* output_ptr = channel_output_start; float* output_ptr = channel_output_start;
// The begining of each input feature map channel // The begining of each input feature map channel
MACE_ASSERT(input_ptr == input + n * input_channels * MACE_ASSERT(input_ptr ==
input_height * input_width + input + n * input_channels * input_height * input_width +
inc * input_height * input_width); inc * input_height * input_width);
const float* input_ptr1 = input_ptr + total_pixels; const float* input_ptr1 = input_ptr + total_pixels;
const float* input_ptr2 = input_ptr1 + total_pixels; const float* input_ptr2 = input_ptr1 + total_pixels;
const float* input_ptr3 = input_ptr2 + total_pixels; const float* input_ptr3 = input_ptr2 + total_pixels;
// filter is in c_out, c_in, 1, 1 order // filter is in c_out, c_in, 1, 1 order
MACE_ASSERT(filter_ptr == filter + c * input_channels + inc); MACE_ASSERT(filter_ptr == filter + c * input_channels + inc);
const float k0 = filter_ptr[0]; const float k0 = filter_ptr[0];
...@@ -113,7 +112,7 @@ void Conv2dNeonK1x1S1(const float* input, // NCHW ...@@ -113,7 +112,7 @@ void Conv2dNeonK1x1S1(const float* input, // NCHW
vst1q_f32(output_ptr + 4, out4); vst1q_f32(output_ptr + 4, out4);
output_ptr += 8; output_ptr += 8;
input_ptr += 8; input_ptr += 8;
input_ptr1 += 8; input_ptr1 += 8;
input_ptr2 += 8; input_ptr2 += 8;
input_ptr3 += 8; input_ptr3 += 8;
...@@ -121,7 +120,7 @@ void Conv2dNeonK1x1S1(const float* input, // NCHW ...@@ -121,7 +120,7 @@ void Conv2dNeonK1x1S1(const float* input, // NCHW
// Process the remaining pixels // Process the remaining pixels
index_t remaining_pixels = loop_remaining; index_t remaining_pixels = loop_remaining;
for (; remaining_pixels > 0; --remaining_pixels) { for (; remaining_pixels > 0; --remaining_pixels) {
const float mul = *input_ptr * k0; const float mul = *input_ptr * k0;
const float mul1 = *input_ptr1 * k1; const float mul1 = *input_ptr1 * k1;
const float mul2 = *input_ptr2 * k2; const float mul2 = *input_ptr2 * k2;
const float mul3 = *input_ptr3 * k3; const float mul3 = *input_ptr3 * k3;
...@@ -141,9 +140,9 @@ void Conv2dNeonK1x1S1(const float* input, // NCHW ...@@ -141,9 +140,9 @@ void Conv2dNeonK1x1S1(const float* input, // NCHW
// Process the remaining channels // Process the remaining channels
for (; inc < input_channels; ++inc) { for (; inc < input_channels; ++inc) {
float* output_ptr = channel_output_start; float* output_ptr = channel_output_start;
MACE_ASSERT(input_ptr == input + n * input_channels * MACE_ASSERT(input_ptr ==
input_height * input_width + input + n * input_channels * input_height * input_width +
inc * input_height * input_width); inc * input_height * input_width);
MACE_ASSERT(filter_ptr == filter + c * input_channels + inc); MACE_ASSERT(filter_ptr == filter + c * input_channels + inc);
const float k0 = filter_ptr[0]; const float k0 = filter_ptr[0];
...@@ -166,13 +165,13 @@ void Conv2dNeonK1x1S1(const float* input, // NCHW ...@@ -166,13 +165,13 @@ void Conv2dNeonK1x1S1(const float* input, // NCHW
vst1q_f32(output_ptr + 4, out4); vst1q_f32(output_ptr + 4, out4);
output_ptr += 8; output_ptr += 8;
input_ptr += 8; input_ptr += 8;
} }
// Process the remaining pixels // Process the remaining pixels
index_t remaining_pixels = loop_remaining; index_t remaining_pixels = loop_remaining;
for (; remaining_pixels > 0; --remaining_pixels) { for (; remaining_pixels > 0; --remaining_pixels) {
const float mul = *input_ptr * k0; const float mul = *input_ptr * k0;
*output_ptr += mul; *output_ptr += mul;
++output_ptr; ++output_ptr;
...@@ -183,5 +182,5 @@ void Conv2dNeonK1x1S1(const float* input, // NCHW ...@@ -183,5 +182,5 @@ void Conv2dNeonK1x1S1(const float* input, // NCHW
} }
}; };
} // namespace kernels } // namespace kernels
} // namespace mace } // namespace mace
...@@ -10,78 +10,81 @@ namespace kernels { ...@@ -10,78 +10,81 @@ namespace kernels {
static const int kRegisterSize = 4; static const int kRegisterSize = 4;
void Conv2dNeonK3x3S1(const float* input, // NCHW void Conv2dNeonK3x3S1(const float* input, // NCHW
const index_t* input_shape, const index_t* input_shape,
const float* filter, // c_out, c_in, kernel_h, kernel_w const float* filter, // c_out, c_in, kernel_h, kernel_w
const float* bias, // c_out const float* bias, // c_out
float* output, // NCHW float* output, // NCHW
const index_t* output_shape) { const index_t* output_shape) {
int batch = output_shape[0];
int batch = output_shape[0];
int channels = output_shape[1]; int channels = output_shape[1];
int height = output_shape[2]; int height = output_shape[2];
int width = output_shape[3]; int width = output_shape[3];
int input_batch = input_shape[0]; int input_batch = input_shape[0];
int input_channels = input_shape[1]; int input_channels = input_shape[1];
int input_height = input_shape[2]; int input_height = input_shape[2];
int input_width = input_shape[3]; int input_width = input_shape[3];
int kernel_h = 3; int kernel_h = 3;
int kernel_w = 3; int kernel_w = 3;
int height_count = (height >> 1) << 1; int height_count = (height >> 1) << 1;
for (int b = 0; b < batch; ++b) { for (int b = 0; b < batch; ++b) {
float* output_ptr_base = output + b * channels * height * width; float* output_ptr_base = output + b * channels * height * width;
for (int oc = 0; oc < channels; ++oc) { for (int oc = 0; oc < channels; ++oc) {
const float* filter_ptr = filter + oc * input_channels * kernel_h * kernel_w; const float* filter_ptr =
const float* input_ptr = input + b * input_channels * input_height * input_width; filter + oc * input_channels * kernel_h * kernel_w;
const float* input_ptr =
input + b * input_channels * input_height * input_width;
float* output_ptr = output_ptr_base + oc * height * width; float* output_ptr = output_ptr_base + oc * height * width;
std::fill(output_ptr, output_ptr + height * width, bias[oc]); std::fill(output_ptr, output_ptr + height * width, bias[oc]);
for (int ic = 0; ic < input_channels; ++ic) { for (int ic = 0; ic < input_channels; ++ic) {
float32x4_t filter0 = vld1q_f32(filter_ptr); float32x4_t filter0 = vld1q_f32(filter_ptr);
float32x4_t filter3 = vld1q_f32(filter_ptr+3); float32x4_t filter3 = vld1q_f32(filter_ptr + 3);
float32x4_t filter6 = vld1q_f32(filter_ptr+6); float32x4_t filter6 = vld1q_f32(filter_ptr + 6);
const float* row[kRegisterSize] = { const float* row[kRegisterSize] = {input_ptr, input_ptr + input_width,
input_ptr, input_ptr + input_width, input_ptr + 2 * input_width,
input_ptr + 2 * input_width, input_ptr + 3 * input_width input_ptr + 3 * input_width};
};
float* output_ptr1 = output_ptr; float* output_ptr1 = output_ptr;
float* output_ptr2 = output_ptr + width; float* output_ptr2 = output_ptr + width;
for (int h = 0; h < height_count; h += 2) { for (int h = 0; h < height_count; h += 2) {
int count = width >> 2; int count = width >> 2;
int remain_count = width & 3; int remain_count = width & 3;
for (; count > 0; --count) { for (; count > 0; --count) {
float32x4_t sum0 = vdupq_n_f32(.0f); float32x4_t sum0 = vdupq_n_f32(.0f);
float32x4_t sum1 = vdupq_n_f32(.0f); float32x4_t sum1 = vdupq_n_f32(.0f);
float32x4_t row0_ext_0 = vld1q_f32(row[0]); //0123 float32x4_t row0_ext_0 = vld1q_f32(row[0]); // 0123
float32x4_t row0_latter = vld1q_f32(row[0] + kRegisterSize); //4567 float32x4_t row0_latter = vld1q_f32(row[0] + kRegisterSize); // 4567
float32x4_t row0_ext_1 = vextq_f32(row0_ext_0, row0_latter, 1); //1234 float32x4_t row0_ext_1 =
float32x4_t row0_ext_2 = vextq_f32(row0_ext_0, row0_latter, 2); //2345 vextq_f32(row0_ext_0, row0_latter, 1); // 1234
float32x4_t row0_ext_2 =
vextq_f32(row0_ext_0, row0_latter, 2); // 2345
sum0 = vfmaq_laneq_f32(sum0, row0_ext_0, filter0, 0); sum0 = vfmaq_laneq_f32(sum0, row0_ext_0, filter0, 0);
sum0 = vfmaq_laneq_f32(sum0, row0_ext_1, filter0, 1); sum0 = vfmaq_laneq_f32(sum0, row0_ext_1, filter0, 1);
sum0 = vfmaq_laneq_f32(sum0, row0_ext_2, filter0, 2); sum0 = vfmaq_laneq_f32(sum0, row0_ext_2, filter0, 2);
float32x4_t row1_ext_0 = vld1q_f32(row[1]); //0123 float32x4_t row1_ext_0 = vld1q_f32(row[1]); // 0123
float32x4_t row1_latter = vld1q_f32(row[1] + kRegisterSize); //4567 float32x4_t row1_latter = vld1q_f32(row[1] + kRegisterSize); // 4567
float32x4_t row1_ext_1 = vextq_f32(row1_ext_0, row1_latter, 1); //1234 float32x4_t row1_ext_1 =
float32x4_t row1_ext_2 = vextq_f32(row1_ext_0, row1_latter, 2); //2345 vextq_f32(row1_ext_0, row1_latter, 1); // 1234
float32x4_t row1_ext_2 =
vextq_f32(row1_ext_0, row1_latter, 2); // 2345
sum0 = vfmaq_laneq_f32(sum0, row1_ext_0, filter3, 0); sum0 = vfmaq_laneq_f32(sum0, row1_ext_0, filter3, 0);
sum0 = vfmaq_laneq_f32(sum0, row1_ext_1, filter3, 1); sum0 = vfmaq_laneq_f32(sum0, row1_ext_1, filter3, 1);
sum0 = vfmaq_laneq_f32(sum0, row1_ext_2, filter3, 2); sum0 = vfmaq_laneq_f32(sum0, row1_ext_2, filter3, 2);
row0_ext_0 = vld1q_f32(row[2]); //0123 row0_ext_0 = vld1q_f32(row[2]); // 0123
row0_latter = vld1q_f32(row[2] + kRegisterSize); //4567 row0_latter = vld1q_f32(row[2] + kRegisterSize); // 4567
row0_ext_1 = vextq_f32(row0_ext_0, row0_latter, 1); //1234 row0_ext_1 = vextq_f32(row0_ext_0, row0_latter, 1); // 1234
row0_ext_2 = vextq_f32(row0_ext_0, row0_latter, 2); //2345 row0_ext_2 = vextq_f32(row0_ext_0, row0_latter, 2); // 2345
sum0 = vfmaq_laneq_f32(sum0, row0_ext_0, filter6, 0); sum0 = vfmaq_laneq_f32(sum0, row0_ext_0, filter6, 0);
sum0 = vfmaq_laneq_f32(sum0, row0_ext_1, filter6, 1); sum0 = vfmaq_laneq_f32(sum0, row0_ext_1, filter6, 1);
...@@ -96,10 +99,10 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW ...@@ -96,10 +99,10 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW
sum1 = vfmaq_laneq_f32(sum1, row0_ext_1, filter3, 1); sum1 = vfmaq_laneq_f32(sum1, row0_ext_1, filter3, 1);
sum1 = vfmaq_laneq_f32(sum1, row0_ext_2, filter3, 2); sum1 = vfmaq_laneq_f32(sum1, row0_ext_2, filter3, 2);
row1_ext_0 = vld1q_f32(row[3]); //0123 row1_ext_0 = vld1q_f32(row[3]); // 0123
row1_latter = vld1q_f32(row[3] + kRegisterSize); //4567 row1_latter = vld1q_f32(row[3] + kRegisterSize); // 4567
row1_ext_1 = vextq_f32(row1_ext_0, row1_latter, 1); //1234 row1_ext_1 = vextq_f32(row1_ext_0, row1_latter, 1); // 1234
row1_ext_2 = vextq_f32(row1_ext_0, row1_latter, 2); //2345 row1_ext_2 = vextq_f32(row1_ext_0, row1_latter, 2); // 2345
sum1 = vfmaq_laneq_f32(sum1, row1_ext_0, filter6, 0); sum1 = vfmaq_laneq_f32(sum1, row1_ext_0, filter6, 0);
sum1 = vfmaq_laneq_f32(sum1, row1_ext_1, filter6, 1); sum1 = vfmaq_laneq_f32(sum1, row1_ext_1, filter6, 1);
...@@ -114,15 +117,15 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW ...@@ -114,15 +117,15 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW
output_ptr1 += kRegisterSize; output_ptr1 += kRegisterSize;
output_ptr2 += kRegisterSize; output_ptr2 += kRegisterSize;
for(int i = 0; i < kRegisterSize; ++i) { for (int i = 0; i < kRegisterSize; ++i) {
row[i] += kRegisterSize; row[i] += kRegisterSize;
} }
} }
for (; remain_count > 0; --remain_count) { for (; remain_count > 0; --remain_count) {
float32x4_t row0 = vld1q_f32(row[0]); //0123 float32x4_t row0 = vld1q_f32(row[0]); // 0123
float32x4_t row1 = vld1q_f32(row[1]); //0123 float32x4_t row1 = vld1q_f32(row[1]); // 0123
float32x4_t row2 = vld1q_f32(row[2]); //0123 float32x4_t row2 = vld1q_f32(row[2]); // 0123
float32x4_t row3 = vld1q_f32(row[3]); //0123 float32x4_t row3 = vld1q_f32(row[3]); // 0123
float32x4_t sum = vmulq_f32(row0, filter0); float32x4_t sum = vmulq_f32(row0, filter0);
sum = vmlaq_f32(sum, row1, filter3); sum = vmlaq_f32(sum, row1, filter3);
...@@ -138,13 +141,13 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW ...@@ -138,13 +141,13 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW
++output_ptr1; ++output_ptr1;
++output_ptr2; ++output_ptr2;
for(int i = 0; i < kRegisterSize; ++i) { for (int i = 0; i < kRegisterSize; ++i) {
row[i] += 1; row[i] += 1;
} }
} }
output_ptr1 += width; output_ptr1 += width;
output_ptr2 += width; output_ptr2 += width;
for(int i = 0; i < kRegisterSize; ++i) { for (int i = 0; i < kRegisterSize; ++i) {
row[i] += 2 + input_width; row[i] += 2 + input_width;
} }
} }
...@@ -152,30 +155,34 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW ...@@ -152,30 +155,34 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW
if (height != height_count) { if (height != height_count) {
int count = width >> 2; int count = width >> 2;
int remain_count = width & 3; int remain_count = width & 3;
for(; count > 0; --count) { for (; count > 0; --count) {
float32x4_t sum0 = vdupq_n_f32(.0f); float32x4_t sum0 = vdupq_n_f32(.0f);
float32x4_t row0_ext_0 = vld1q_f32(row[0]); //0123 float32x4_t row0_ext_0 = vld1q_f32(row[0]); // 0123
float32x4_t row0_latter = vld1q_f32(row[0] + kRegisterSize); //4567 float32x4_t row0_latter = vld1q_f32(row[0] + kRegisterSize); // 4567
float32x4_t row0_ext_1 = vextq_f32(row0_ext_0, row0_latter, 1); //1234 float32x4_t row0_ext_1 =
float32x4_t row0_ext_2 = vextq_f32(row0_ext_0, row0_latter, 2); //2345 vextq_f32(row0_ext_0, row0_latter, 1); // 1234
float32x4_t row0_ext_2 =
vextq_f32(row0_ext_0, row0_latter, 2); // 2345
sum0 = vfmaq_laneq_f32(sum0, row0_ext_0, filter0, 0); sum0 = vfmaq_laneq_f32(sum0, row0_ext_0, filter0, 0);
sum0 = vfmaq_laneq_f32(sum0, row0_ext_1, filter0, 1); sum0 = vfmaq_laneq_f32(sum0, row0_ext_1, filter0, 1);
sum0 = vfmaq_laneq_f32(sum0, row0_ext_2, filter0, 2); sum0 = vfmaq_laneq_f32(sum0, row0_ext_2, filter0, 2);
float32x4_t row1_ext_0 = vld1q_f32(row[1]); //0123 float32x4_t row1_ext_0 = vld1q_f32(row[1]); // 0123
float32x4_t row1_latter = vld1q_f32(row[1] + kRegisterSize); //4567 float32x4_t row1_latter = vld1q_f32(row[1] + kRegisterSize); // 4567
float32x4_t row1_ext_1 = vextq_f32(row1_ext_0, row1_latter, 1); //1234 float32x4_t row1_ext_1 =
float32x4_t row1_ext_2 = vextq_f32(row1_ext_0, row1_latter, 2); //2345 vextq_f32(row1_ext_0, row1_latter, 1); // 1234
float32x4_t row1_ext_2 =
vextq_f32(row1_ext_0, row1_latter, 2); // 2345
sum0 = vfmaq_laneq_f32(sum0, row1_ext_0, filter3, 0); sum0 = vfmaq_laneq_f32(sum0, row1_ext_0, filter3, 0);
sum0 = vfmaq_laneq_f32(sum0, row1_ext_1, filter3, 1); sum0 = vfmaq_laneq_f32(sum0, row1_ext_1, filter3, 1);
sum0 = vfmaq_laneq_f32(sum0, row1_ext_2, filter3, 2); sum0 = vfmaq_laneq_f32(sum0, row1_ext_2, filter3, 2);
row0_ext_0 = vld1q_f32(row[2]); //0123 row0_ext_0 = vld1q_f32(row[2]); // 0123
row0_latter = vld1q_f32(row[2] + kRegisterSize); //4567 row0_latter = vld1q_f32(row[2] + kRegisterSize); // 4567
row0_ext_1 = vextq_f32(row0_ext_0, row0_latter, 1); //1234 row0_ext_1 = vextq_f32(row0_ext_0, row0_latter, 1); // 1234
row0_ext_2 = vextq_f32(row0_ext_0, row0_latter, 2); //2345 row0_ext_2 = vextq_f32(row0_ext_0, row0_latter, 2); // 2345
sum0 = vfmaq_laneq_f32(sum0, row0_ext_0, filter6, 0); sum0 = vfmaq_laneq_f32(sum0, row0_ext_0, filter6, 0);
sum0 = vfmaq_laneq_f32(sum0, row0_ext_1, filter6, 1); sum0 = vfmaq_laneq_f32(sum0, row0_ext_1, filter6, 1);
...@@ -185,14 +192,14 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW ...@@ -185,14 +192,14 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW
output_row0 = vaddq_f32(output_row0, sum0); output_row0 = vaddq_f32(output_row0, sum0);
vst1q_f32(output_ptr1, output_row0); vst1q_f32(output_ptr1, output_row0);
output_ptr1 += kRegisterSize; output_ptr1 += kRegisterSize;
for(int i = 0; i < 3; ++i) { for (int i = 0; i < 3; ++i) {
row[i] += kRegisterSize; row[i] += kRegisterSize;
} }
} }
for (; remain_count > 0; --remain_count) { for (; remain_count > 0; --remain_count) {
float32x4_t row0 = vld1q_f32(row[0]); //0123 float32x4_t row0 = vld1q_f32(row[0]); // 0123
float32x4_t row1 = vld1q_f32(row[1]); //0123 float32x4_t row1 = vld1q_f32(row[1]); // 0123
float32x4_t row2 = vld1q_f32(row[2]); //0123 float32x4_t row2 = vld1q_f32(row[2]); // 0123
float32x4_t sum = vmulq_f32(row0, filter0); float32x4_t sum = vmulq_f32(row0, filter0);
sum = vmlaq_f32(sum, row1, filter3); sum = vmlaq_f32(sum, row1, filter3);
...@@ -201,7 +208,7 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW ...@@ -201,7 +208,7 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW
*output_ptr1 = vaddvq_f32(sum); *output_ptr1 = vaddvq_f32(sum);
++output_ptr1; ++output_ptr1;
for(int i = 0; i < 3; ++i) { for (int i = 0; i < 3; ++i) {
row[i] += 1; row[i] += 1;
} }
} }
...@@ -213,5 +220,5 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW ...@@ -213,5 +220,5 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW
} }
} }
} // namespace kernels } // namespace kernels
} // namespace mace } // namespace mace
...@@ -10,11 +10,11 @@ ...@@ -10,11 +10,11 @@
namespace mace { namespace mace {
namespace kernels { namespace kernels {
void Conv2dNeonK5x5S1(const float* input, // NCHW void Conv2dNeonK5x5S1(const float* input, // NCHW
const index_t* input_shape, const index_t* input_shape,
const float* filter, // c_out, c_in, kernel_h, kernel_w const float* filter, // c_out, c_in, kernel_h, kernel_w
const float* bias, // c_out const float* bias, // c_out
float* output, // NCHW float* output, // NCHW
const index_t* output_shape) { const index_t* output_shape) {
const index_t batch = output_shape[0]; const index_t batch = output_shape[0];
const index_t channels = output_shape[1]; const index_t channels = output_shape[1];
...@@ -30,17 +30,17 @@ void Conv2dNeonK5x5S1(const float* input, // NCHW ...@@ -30,17 +30,17 @@ void Conv2dNeonK5x5S1(const float* input, // NCHW
const index_t input_total_pixels_per_channel = input_height * input_width; const index_t input_total_pixels_per_channel = input_height * input_width;
const index_t output_total_pixels_per_channel = height * width; const index_t output_total_pixels_per_channel = height * width;
const index_t input_total_pixels_per_batch = input_total_pixels_per_channel const index_t input_total_pixels_per_batch =
* input_channels; input_total_pixels_per_channel * input_channels;
const index_t output_total_pixels_per_batch = output_total_pixels_per_channel const index_t output_total_pixels_per_batch =
* channels; output_total_pixels_per_channel * channels;
const index_t patch_size = input_channels * 25; const index_t patch_size = input_channels * 25;
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (index_t n = 0; n < batch; ++n) { for (index_t n = 0; n < batch; ++n) {
for (index_t c = 0; c < channels; ++c) { for (index_t c = 0; c < channels; ++c) {
float* output_ptr = output + n * output_total_pixels_per_batch float* output_ptr = output + n * output_total_pixels_per_batch +
+ c * output_total_pixels_per_channel; c * output_total_pixels_per_channel;
const float* input_ptr = input + n * input_total_pixels_per_batch; const float* input_ptr = input + n * input_total_pixels_per_batch;
// Fill with bias // Fill with bias
...@@ -53,7 +53,7 @@ void Conv2dNeonK5x5S1(const float* input, // NCHW ...@@ -53,7 +53,7 @@ void Conv2dNeonK5x5S1(const float* input, // NCHW
float* outptr2 = outptr + width; float* outptr2 = outptr + width;
const float* inptr = input_ptr + inc * input_total_pixels_per_channel; const float* inptr = input_ptr + inc * input_total_pixels_per_channel;
const float* filter_ptr = filter + c * patch_size + inc * 25; const float* filter_ptr = filter + c * patch_size + inc * 25;
const float* r0 = inptr; const float* r0 = inptr;
const float* r1 = inptr + input_width; const float* r1 = inptr + input_width;
...@@ -246,8 +246,8 @@ void Conv2dNeonK5x5S1(const float* input, // NCHW ...@@ -246,8 +246,8 @@ void Conv2dNeonK5x5S1(const float* input, // NCHW
sum2 = r5[4] * k4[4]; sum2 = r5[4] * k4[4];
float32x2_t _ss = vadd_f32(vget_low_f32(_sum), vget_high_f32(_sum)); float32x2_t _ss = vadd_f32(vget_low_f32(_sum), vget_high_f32(_sum));
float32x2_t float32x2_t _ss2 =
_ss2 = vadd_f32(vget_low_f32(_sum2), vget_high_f32(_sum2)); vadd_f32(vget_low_f32(_sum2), vget_high_f32(_sum2));
float32x2_t _ss_ss2 = vpadd_f32(_ss, _ss2); float32x2_t _ss_ss2 = vpadd_f32(_ss, _ss2);
sum += vget_lane_f32(_ss_ss2, 0); sum += vget_lane_f32(_ss_ss2, 0);
...@@ -414,7 +414,7 @@ void Conv2dNeonK5x5S1(const float* input, // NCHW ...@@ -414,7 +414,7 @@ void Conv2dNeonK5x5S1(const float* input, // NCHW
} }
} }
} // namespace kernels } // namespace kernels
} // namespace mace } // namespace mace
#endif // MACE_KERNELS_NEON_CONV_2D_NEON_5X5_H_ #endif // MACE_KERNELS_NEON_CONV_2D_NEON_5X5_H_
...@@ -2,19 +2,17 @@ ...@@ -2,19 +2,17 @@
// Copyright (c) 2017 XiaoMi All rights reserved. // Copyright (c) 2017 XiaoMi All rights reserved.
// //
#include <arm_neon.h>
#include <float.h> #include <float.h>
#include <limits> #include <limits>
#include <arm_neon.h>
#include "mace/core/common.h" #include "mace/core/common.h"
namespace mace { namespace mace {
namespace kernels { namespace kernels {
void PoolingMaxNeonK2x2S2x2(const float *input, void PoolingMaxNeonK2x2S2x2(const float *input, const index_t *in_shape,
const index_t *in_shape, float *output, const index_t *out_shape,
float *output,
const index_t *out_shape,
const int *paddings) { const int *paddings) {
index_t batch = in_shape[0]; index_t batch = in_shape[0];
index_t channels = in_shape[1]; index_t channels = in_shape[1];
...@@ -44,7 +42,7 @@ void PoolingMaxNeonK2x2S2x2(const float *input, ...@@ -44,7 +42,7 @@ void PoolingMaxNeonK2x2S2x2(const float *input,
int w = 0; int w = 0;
int num_vectors = 0; int num_vectors = 0;
if (!((h == 0 && padding_top > 0) || if (!((h == 0 && padding_top > 0) ||
(h == out_height - 1 && padding_bottom > 0))) { (h == out_height - 1 && padding_bottom > 0))) {
r0 = input + input_offset + (h * 2 - padding_top) * in_width; r0 = input + input_offset + (h * 2 - padding_top) * in_width;
r1 = r0 + in_width; r1 = r0 + in_width;
if (padding_left > 0) { if (padding_left > 0) {
...@@ -86,8 +84,7 @@ void PoolingMaxNeonK2x2S2x2(const float *input, ...@@ -86,8 +84,7 @@ void PoolingMaxNeonK2x2S2x2(const float *input,
for (int kw = 0; kw < 2; ++kw) { for (int kw = 0; kw < 2; ++kw) {
int inh = h * 2 - padding_top + kh; int inh = h * 2 - padding_top + kh;
int inw = w * 2 - padding_left + kw; int inw = w * 2 - padding_left + kw;
if (inh >= 0 && inh < in_height && if (inh >= 0 && inh < in_height && inw >= 0 && inw < in_width) {
inw >= 0 && inw < in_width) {
max = std::max(max, input[input_offset + inh * in_width + inw]); max = std::max(max, input[input_offset + inh * in_width + inw]);
} }
} }
...@@ -104,10 +101,8 @@ void PoolingMaxNeonK2x2S2x2(const float *input, ...@@ -104,10 +101,8 @@ void PoolingMaxNeonK2x2S2x2(const float *input,
} }
// assume the input has already been padded // assume the input has already been padded
void PoolingMaxNeonK2x2S2x2Padded(const float *input, void PoolingMaxNeonK2x2S2x2Padded(const float *input, const index_t *in_shape,
const index_t *in_shape, float *output, const index_t *out_shape) {
float *output,
const index_t *out_shape) {
index_t batch = in_shape[0]; index_t batch = in_shape[0];
index_t channels = in_shape[1]; index_t channels = in_shape[1];
index_t in_height = in_shape[2]; index_t in_height = in_shape[2];
......
...@@ -2,19 +2,17 @@ ...@@ -2,19 +2,17 @@
// Copyright (c) 2017 XiaoMi All rights reserved. // Copyright (c) 2017 XiaoMi All rights reserved.
// //
#include <arm_neon.h>
#include <float.h> #include <float.h>
#include <limits> #include <limits>
#include <arm_neon.h>
#include "mace/core/common.h" #include "mace/core/common.h"
namespace mace { namespace mace {
namespace kernels { namespace kernels {
void PoolingMaxNeonK3x3S2x2(const float *input, void PoolingMaxNeonK3x3S2x2(const float *input, const index_t *in_shape,
const index_t *in_shape, float *output, const index_t *out_shape,
float *output,
const index_t *out_shape,
const int *paddings) { const int *paddings) {
index_t batch = in_shape[0]; index_t batch = in_shape[0];
index_t channels = in_shape[1]; index_t channels = in_shape[1];
...@@ -44,7 +42,7 @@ void PoolingMaxNeonK3x3S2x2(const float *input, ...@@ -44,7 +42,7 @@ void PoolingMaxNeonK3x3S2x2(const float *input,
int num_vectors = 0; int num_vectors = 0;
const float *r0, *r1, *r2; const float *r0, *r1, *r2;
if (!((h == 0 && padding_top > 0) || if (!((h == 0 && padding_top > 0) ||
(h == out_height - 1 && padding_bottom > 0))) { (h == out_height - 1 && padding_bottom > 0))) {
r0 = input + input_offset + (h * 2 - padding_top) * in_width; r0 = input + input_offset + (h * 2 - padding_top) * in_width;
r1 = r0 + in_width; r1 = r0 + in_width;
r2 = r1 + in_width; r2 = r1 + in_width;
...@@ -112,8 +110,7 @@ void PoolingMaxNeonK3x3S2x2(const float *input, ...@@ -112,8 +110,7 @@ void PoolingMaxNeonK3x3S2x2(const float *input,
for (int kw = 0; kw < 3; ++kw) { for (int kw = 0; kw < 3; ++kw) {
int inh = h * 2 - padding_top + kh; int inh = h * 2 - padding_top + kh;
int inw = w * 2 - padding_left + kw; int inw = w * 2 - padding_left + kw;
if (inh >= 0 && inh < in_height && if (inh >= 0 && inh < in_height && inw >= 0 && inw < in_width) {
inw >= 0 && inw < in_width) {
max = std::max(max, input[input_offset + inh * in_width + inw]); max = std::max(max, input[input_offset + inh * in_width + inw]);
} }
} }
...@@ -130,10 +127,8 @@ void PoolingMaxNeonK3x3S2x2(const float *input, ...@@ -130,10 +127,8 @@ void PoolingMaxNeonK3x3S2x2(const float *input,
} }
// assume the input has already been padded // assume the input has already been padded
void PoolingMaxNeonK3x3S2x2Padded(const float *input, void PoolingMaxNeonK3x3S2x2Padded(const float *input, const index_t *in_shape,
const index_t *in_shape, float *output, const index_t *out_shape) {
float *output,
const index_t *out_shape) {
index_t batch = in_shape[0]; index_t batch = in_shape[0];
index_t channels = in_shape[1]; index_t channels = in_shape[1];
index_t in_height = in_shape[2]; index_t in_height = in_shape[2];
...@@ -218,5 +213,5 @@ void PoolingMaxNeonK3x3S2x2Padded(const float *input, ...@@ -218,5 +213,5 @@ void PoolingMaxNeonK3x3S2x2Padded(const float *input,
} }
} }
} // namespace kernels } // namespace kernels
} // namespace mace } // namespace mace
...@@ -2,45 +2,36 @@ ...@@ -2,45 +2,36 @@
// Copyright (c) 2017 XiaoMi All rights reserved. // Copyright (c) 2017 XiaoMi All rights reserved.
// //
#include <arm_neon.h>
#include "mace/kernels/pooling.h" #include "mace/kernels/pooling.h"
#include <arm_neon.h>
#include "mace/kernels/conv_pool_2d_util.h" #include "mace/kernels/conv_pool_2d_util.h"
namespace mace { namespace mace {
namespace kernels { namespace kernels {
// Forward declarations of the NEON max-pooling kernels that are defined in
// separate translation units. Each takes raw input/output buffers together
// with their shape arrays.

// 2x2 kernel, stride 2; padding is passed in via `paddings`.
extern void PoolingMaxNeonK2x2S2x2(const float *input, const index_t *in_shape,
                                   float *output, const index_t *out_shape,
                                   const int *paddings);

// 3x3 kernel, stride 2; padding is passed in via `paddings`.
extern void PoolingMaxNeonK3x3S2x2(const float *input, const index_t *in_shape,
                                   float *output, const index_t *out_shape,
                                   const int *paddings);

#ifdef __COPY_MAKE_PADDING
// Variants without a `paddings` argument; they assume the input has already
// been padded by the caller.
extern void PoolingMaxNeonK2x2S2x2Padded(const float *input,
                                         const index_t *in_shape, float *output,
                                         const index_t *out_shape);
extern void PoolingMaxNeonK3x3S2x2Padded(const float *input,
                                         const index_t *in_shape, float *output,
                                         const index_t *out_shape);
#endif
template<> template <>
void PoolingFunctor<DeviceType::NEON, float>::operator()( void PoolingFunctor<DeviceType::NEON, float>::operator()(
const float *input, const float *input, const index_t *input_shape, float *output,
const index_t *input_shape,
float *output,
const index_t *output_shape) { const index_t *output_shape) {
if (kernels_[0] == 2 && kernels_[1] == 2 && if (kernels_[0] == 2 && kernels_[1] == 2 && strides_[0] == 2 &&
strides_[0] == 2 && strides_[1] == 2 && strides_[1] == 2 && pooling_type_ == MAX) {
pooling_type_ == MAX) {
#ifdef __COPY_MAKE_PADDING #ifdef __COPY_MAKE_PADDING
Tensor padded_input; Tensor padded_input;
ConstructInputWithPadding(input, input_shape, paddings_, &padded_input); ConstructInputWithPadding(input, input_shape, paddings_, &padded_input);
...@@ -50,9 +41,8 @@ void PoolingFunctor<DeviceType::NEON, float>::operator()( ...@@ -50,9 +41,8 @@ void PoolingFunctor<DeviceType::NEON, float>::operator()(
#else #else
PoolingMaxNeonK2x2S2x2(input, input_shape, output, output_shape, paddings_); PoolingMaxNeonK2x2S2x2(input, input_shape, output, output_shape, paddings_);
#endif #endif
} else if (kernels_[0] == 3 && kernels_[1] == 3 && } else if (kernels_[0] == 3 && kernels_[1] == 3 && strides_[0] == 2 &&
strides_[0] == 2 && strides_[1] == 2 && strides_[1] == 2 && pooling_type_ == MAX) {
pooling_type_ == MAX) {
#ifdef __COPY_MAKE_PADDING #ifdef __COPY_MAKE_PADDING
Tensor padded_input; Tensor padded_input;
ConstructInputWithPadding(input, input_shape, paddings_, &padded_input); ConstructInputWithPadding(input, input_shape, paddings_, &padded_input);
...@@ -65,13 +55,9 @@ void PoolingFunctor<DeviceType::NEON, float>::operator()( ...@@ -65,13 +55,9 @@ void PoolingFunctor<DeviceType::NEON, float>::operator()(
} else { // not implement yet } else { // not implement yet
PoolingFunctor<DeviceType::CPU, float>(pooling_type_, kernels_, strides_, PoolingFunctor<DeviceType::CPU, float>(pooling_type_, kernels_, strides_,
paddings_, dilations_)( paddings_, dilations_)(
input, input, input_shape, output, output_shape);
input_shape,
output,
output_shape
);
} }
} }
} // namespace kernels } // namespace kernels
} // namespace mace } // namespace mace
\ No newline at end of file \ No newline at end of file
...@@ -2,17 +2,17 @@ ...@@ -2,17 +2,17 @@
// Copyright (c) 2017 XiaoMi All rights reserved. // Copyright (c) 2017 XiaoMi All rights reserved.
// //
#include <arm_neon.h>
#include "mace/kernels/relu.h" #include "mace/kernels/relu.h"
#include <arm_neon.h>
namespace mace { namespace mace {
namespace kernels { namespace kernels {
template <> template <>
void ReluFunctor<DeviceType::NEON, float>::operator()(const float *input, void ReluFunctor<DeviceType::NEON, float>::operator()(const float *input,
float *output, float *output,
index_t size) { index_t size) {
#pragma omp parallel for num_threads(1) // no significant performance improve #pragma omp parallel for num_threads(1) // no significant performance improve
for (int64_t i = 0; i < size; i += kCostPerGroup) { for (int64_t i = 0; i < size; i += kCostPerGroup) {
int64_t count = std::min(static_cast<int64_t>(kCostPerGroup), size - i); int64_t count = std::min(static_cast<int64_t>(kCostPerGroup), size - i);
int nn = count >> 2; int nn = count >> 2;
...@@ -36,6 +36,5 @@ void ReluFunctor<DeviceType::NEON, float>::operator()(const float *input, ...@@ -36,6 +36,5 @@ void ReluFunctor<DeviceType::NEON, float>::operator()(const float *input,
} }
}; };
} // namespace kernels
} // namespace kernels } // namespace mace
} // namespace mace \ No newline at end of file
\ No newline at end of file
...@@ -11,29 +11,24 @@ ...@@ -11,29 +11,24 @@
namespace mace { namespace mace {
// Selects the reduction a pooling functor applies over each window.
// The numeric values are spelled out explicitly and must stay stable.
enum PoolingType {
  AVG = 1,  // avg_pool
  MAX = 2,  // max_pool
};
namespace kernels { namespace kernels {
template<DeviceType D, typename T> template <DeviceType D, typename T>
class PoolingFunctor { class PoolingFunctor {
public: public:
PoolingFunctor(const PoolingType pooling_type, PoolingFunctor(const PoolingType pooling_type, const int *kernels,
const int *kernels, const int *strides, const int *paddings, const int *dilations)
const int *strides,
const int *paddings,
const int *dilations)
: pooling_type_(pooling_type), : pooling_type_(pooling_type),
kernels_(kernels), kernels_(kernels),
strides_(strides), strides_(strides),
paddings_(paddings), paddings_(paddings),
dilations_(dilations) {} dilations_(dilations) {}
void operator()(const T *input, void operator()(const T *input, const index_t *input_shape, T *output,
const index_t *input_shape,
T *output,
const index_t *output_shape) { const index_t *output_shape) {
index_t batch = output_shape[0]; index_t batch = output_shape[0];
index_t channels = output_shape[1]; index_t channels = output_shape[1];
...@@ -60,32 +55,31 @@ class PoolingFunctor { ...@@ -60,32 +55,31 @@ class PoolingFunctor {
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (int n = 0; n < batch; ++n) { for (int n = 0; n < batch; ++n) {
for (int c = 0; c < channels; ++c) { for (int c = 0; c < channels; ++c) {
index_t out_offset = n * channels * height * width + index_t out_offset = n * channels * height * width + c * height * width;
c * height * width;
index_t in_offset = n * input_channels * input_height * input_width + index_t in_offset = n * input_channels * input_height * input_width +
c * input_height * input_width; c * input_height * input_width;
for (int h = 0; h < height; ++h) { for (int h = 0; h < height; ++h) {
for (int w = 0; w < width; ++w) { for (int w = 0; w < width; ++w) {
T sum_or_max = 0; T sum_or_max = 0;
switch (pooling_type_) { switch (pooling_type_) {
case AVG:break; case AVG:
case MAX:sum_or_max = std::numeric_limits<T>::lowest(); break;
case MAX:
sum_or_max = std::numeric_limits<T>::lowest();
break; break;
default: default:
MACE_CHECK(false, MACE_CHECK(false, "Unsupported pooling type: ", pooling_type_);
"Unsupported pooling type: ",
pooling_type_);
} }
for (int kh = 0; kh < kernel_h; ++kh) { for (int kh = 0; kh < kernel_h; ++kh) {
for (int kw = 0; kw < kernel_w; ++kw) { for (int kw = 0; kw < kernel_w; ++kw) {
int inh = padded_h_start + h * stride_h + dilation_h * kh; int inh = padded_h_start + h * stride_h + dilation_h * kh;
int inw = padded_w_start + w * stride_w + dilation_w * kw; int inw = padded_w_start + w * stride_w + dilation_w * kw;
if (inh >= 0 && inh < input_height && if (inh >= 0 && inh < input_height && inw >= 0 &&
inw >= 0 && inw < input_width) { inw < input_width) {
index_t input_offset = in_offset + index_t input_offset = in_offset + inh * input_width + inw;
inh * input_width + inw;
switch (pooling_type_) { switch (pooling_type_) {
case AVG:sum_or_max += input[input_offset]; case AVG:
sum_or_max += input[input_offset];
break; break;
case MAX: case MAX:
sum_or_max = std::max(sum_or_max, input[input_offset]); sum_or_max = std::max(sum_or_max, input[input_offset]);
...@@ -98,14 +92,14 @@ class PoolingFunctor { ...@@ -98,14 +92,14 @@ class PoolingFunctor {
} }
} }
switch (pooling_type_) { switch (pooling_type_) {
case AVG:output[out_offset] = sum_or_max / (kernel_h * kernel_w); case AVG:
output[out_offset] = sum_or_max / (kernel_h * kernel_w);
break; break;
case MAX:output[out_offset] = sum_or_max; case MAX:
output[out_offset] = sum_or_max;
break; break;
default: default:
MACE_CHECK(false, MACE_CHECK(false, "Unsupported pooling type: ", pooling_type_);
"Unsupported pooling type: ",
pooling_type_);
} }
out_offset += 1; out_offset += 1;
} }
...@@ -122,14 +116,12 @@ class PoolingFunctor { ...@@ -122,14 +116,12 @@ class PoolingFunctor {
const int *dilations_; const int *dilations_;
}; };
template<> template <>
void PoolingFunctor<DeviceType::NEON, float>::operator()( void PoolingFunctor<DeviceType::NEON, float>::operator()(
const float *input, const float *input, const index_t *input_shape, float *output,
const index_t *input_shape,
float *output,
const index_t *output_shape); const index_t *output_shape);
} // namespace kernels } // namespace kernels
} // namespace mace } // namespace mace
#endif //MACE_KERNELS_POOLING_H #endif // MACE_KERNELS_POOLING_H
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
namespace mace { namespace mace {
namespace kernels { namespace kernels {
template<DeviceType D, typename T> template <DeviceType D, typename T>
struct ReluFunctor { struct ReluFunctor {
void operator()(const T *input, T *output, index_t size) { void operator()(const T *input, T *output, index_t size) {
for (index_t i = 0; i < size; ++i) { for (index_t i = 0; i < size; ++i) {
...@@ -24,7 +24,7 @@ void ReluFunctor<DeviceType::NEON, float>::operator()(const float *input, ...@@ -24,7 +24,7 @@ void ReluFunctor<DeviceType::NEON, float>::operator()(const float *input,
float *output, float *output,
index_t size); index_t size);
} // namespace kernels } // namespace kernels
} // namespace mace } // namespace mace
#endif // MACE_KERNELS_RELU_H_ #endif // MACE_KERNELS_RELU_H_
\ No newline at end of file \ No newline at end of file
...@@ -22,8 +22,8 @@ struct CachedInterpolation { ...@@ -22,8 +22,8 @@ struct CachedInterpolation {
inline float CalculateResizeScale(index_t in_size, index_t out_size, inline float CalculateResizeScale(index_t in_size, index_t out_size,
bool align_corners) { bool align_corners) {
return (align_corners && out_size > 1) return (align_corners && out_size > 1)
? (in_size - 1) / static_cast<float>(out_size - 1) ? (in_size - 1) / static_cast<float>(out_size - 1)
: in_size / static_cast<float>(out_size); : in_size / static_cast<float>(out_size);
} }
inline void ComputeInterpolationWeights(const index_t out_size, inline void ComputeInterpolationWeights(const index_t out_size,
...@@ -41,21 +41,20 @@ inline void ComputeInterpolationWeights(const index_t out_size, ...@@ -41,21 +41,20 @@ inline void ComputeInterpolationWeights(const index_t out_size,
} }
inline float ComputeLerp(const float top_left, const float top_right, inline float ComputeLerp(const float top_left, const float top_right,
const float bottom_left, const float bottom_right, const float bottom_left, const float bottom_right,
const float x_lerp, const float y_lerp) { const float x_lerp, const float y_lerp) {
const float top = top_left + (top_right - top_left) * x_lerp; const float top = top_left + (top_right - top_left) * x_lerp;
const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp; const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp;
return top + (bottom - top) * y_lerp; return top + (bottom - top) * y_lerp;
} }
template<typename T> template <typename T>
void ResizeImage(const T *images, void ResizeImage(const T *images, const index_t batch_size,
const index_t batch_size, const index_t in_height, const index_t in_height, const index_t in_width,
const index_t in_width, const index_t out_height, const index_t out_height, const index_t out_width,
const index_t out_width, const index_t channels, const index_t channels,
const std::vector<CachedInterpolation> &xs_vec, const std::vector<CachedInterpolation> &xs_vec,
const std::vector<CachedInterpolation> &ys, const std::vector<CachedInterpolation> &ys, float *output) {
float *output) {
const index_t in_channel_size = in_height * in_width; const index_t in_channel_size = in_height * in_width;
const index_t in_batch_num_values = channels * in_channel_size; const index_t in_batch_num_values = channels * in_channel_size;
const index_t out_channel_size = out_height * out_width; const index_t out_channel_size = out_height * out_width;
...@@ -65,10 +64,10 @@ void ResizeImage(const T *images, ...@@ -65,10 +64,10 @@ void ResizeImage(const T *images,
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch_size; ++b) { for (index_t b = 0; b < batch_size; ++b) {
for (index_t c = 0; c < channels; ++c) { for (index_t c = 0; c < channels; ++c) {
const T* input_ptr = images + in_batch_num_values * b const T *input_ptr =
+ in_channel_size * c; images + in_batch_num_values * b + in_channel_size * c;
float *output_ptr = output + out_batch_num_values * b float *output_ptr =
+ out_channel_size * c; output + out_batch_num_values * b + out_channel_size * c;
for (index_t y = 0; y < out_height; ++y) { for (index_t y = 0; y < out_height; ++y) {
const T *ys_input_lower_ptr = input_ptr + ys[y].lower * in_width; const T *ys_input_lower_ptr = input_ptr + ys[y].lower * in_width;
const T *ys_input_upper_ptr = input_ptr + ys[y].upper * in_width; const T *ys_input_upper_ptr = input_ptr + ys[y].upper * in_width;
...@@ -83,9 +82,8 @@ void ResizeImage(const T *images, ...@@ -83,9 +82,8 @@ void ResizeImage(const T *images,
const float bottom_left = ys_input_upper_ptr[xs_lower]; const float bottom_left = ys_input_upper_ptr[xs_lower];
const float bottom_right = ys_input_upper_ptr[xs_upper]; const float bottom_right = ys_input_upper_ptr[xs_upper];
output_ptr[x] = output_ptr[x] = ComputeLerp(top_left, top_right, bottom_left,
ComputeLerp(top_left, top_right, bottom_left, bottom_right, bottom_right, xs_lerp, ys_lerp);
xs_lerp, ys_lerp);
} }
output_ptr += out_width; output_ptr += out_width;
} }
...@@ -94,16 +92,15 @@ void ResizeImage(const T *images, ...@@ -94,16 +92,15 @@ void ResizeImage(const T *images,
} }
} }
template<DeviceType D, typename T> template <DeviceType D, typename T>
struct ResizeBilinearFunctor { struct ResizeBilinearFunctor {
bool align_corners_; bool align_corners_;
ResizeBilinearFunctor(bool align_corners) ResizeBilinearFunctor(bool align_corners) : align_corners_(align_corners) {}
: align_corners_(align_corners) {}
void operator()(const T *input, T *output, void operator()(const T *input, T *output, index_t n, index_t channels,
index_t n, index_t channels, index_t in_height, index_t in_height, index_t in_width, index_t out_height,
index_t in_width, index_t out_height, index_t out_width) { index_t out_width) {
if (out_height == in_height && out_width == in_width) { if (out_height == in_height && out_width == in_width) {
std::copy(input, input + channels * in_height * in_width, output); std::copy(input, input + channels * in_height * in_width, output);
return; return;
...@@ -111,8 +108,8 @@ struct ResizeBilinearFunctor { ...@@ -111,8 +108,8 @@ struct ResizeBilinearFunctor {
float height_scale = float height_scale =
CalculateResizeScale(in_height, out_height, align_corners_); CalculateResizeScale(in_height, out_height, align_corners_);
float float width_scale =
width_scale = CalculateResizeScale(in_width, out_width, align_corners_); CalculateResizeScale(in_width, out_width, align_corners_);
std::vector<CachedInterpolation> ys(out_height + 1); std::vector<CachedInterpolation> ys(out_height + 1);
std::vector<CachedInterpolation> xs(out_width + 1); std::vector<CachedInterpolation> xs(out_width + 1);
...@@ -121,12 +118,12 @@ struct ResizeBilinearFunctor { ...@@ -121,12 +118,12 @@ struct ResizeBilinearFunctor {
ComputeInterpolationWeights(out_height, in_height, height_scale, ys.data()); ComputeInterpolationWeights(out_height, in_height, height_scale, ys.data());
ComputeInterpolationWeights(out_width, in_width, width_scale, xs.data()); ComputeInterpolationWeights(out_width, in_width, width_scale, xs.data());
ResizeImage(input, n, in_height, in_width, out_height, out_width, ResizeImage(input, n, in_height, in_width, out_height, out_width, channels,
channels, xs, ys, output); xs, ys, output);
} }
}; };
} // namespace kernels } // namespace kernels
} // namespace mace } // namespace mace
#endif // MACE_KERNELS_RESIZE_BILINEAR_H_ #endif // MACE_KERNELS_RESIZE_BILINEAR_H_
...@@ -10,6 +10,6 @@ REGISTER_CPU_OPERATOR(AddN, AddNOp<DeviceType::CPU, float>); ...@@ -10,6 +10,6 @@ REGISTER_CPU_OPERATOR(AddN, AddNOp<DeviceType::CPU, float>);
#if __ARM_NEON #if __ARM_NEON
REGISTER_NEON_OPERATOR(AddN, AddNOp<DeviceType::NEON, float>); REGISTER_NEON_OPERATOR(AddN, AddNOp<DeviceType::NEON, float>);
#endif // __ARM_NEON #endif // __ARM_NEON
} // namespace mace } // namespace mace
...@@ -10,10 +10,10 @@ ...@@ -10,10 +10,10 @@
namespace mace { namespace mace {
template<DeviceType D, class T> template <DeviceType D, class T>
class AddNOp : public Operator<D, T> { class AddNOp : public Operator<D, T> {
public: public:
AddNOp(const OperatorDef &operator_def, Workspace *ws) AddNOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<D, T>(operator_def, ws) {} : Operator<D, T>(operator_def, ws) {}
bool Run() override { bool Run() override {
...@@ -36,6 +36,6 @@ class AddNOp : public Operator<D, T> { ...@@ -36,6 +36,6 @@ class AddNOp : public Operator<D, T> {
kernels::AddNFunctor<D, T> functor_; kernels::AddNFunctor<D, T> functor_;
}; };
} // namespace mace } // namespace mace
#endif // MACE_OPS_ADDN_H_ #endif // MACE_OPS_ADDN_H_
...@@ -10,7 +10,6 @@ ...@@ -10,7 +10,6 @@
namespace mace { namespace mace {
template <DeviceType D, typename T> template <DeviceType D, typename T>
static void AddNBenchmark(int iters, int n, int size) { static void AddNBenchmark(int iters, int n, int size) {
mace::testing::StopTiming(); mace::testing::StopTiming();
OpsTestNet net; OpsTestNet net;
...@@ -18,8 +17,7 @@ static void AddNBenchmark(int iters, int n, int size) { ...@@ -18,8 +17,7 @@ static void AddNBenchmark(int iters, int n, int size) {
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
op_def_builder.Input(internal::MakeString("Input", i).c_str()); op_def_builder.Input(internal::MakeString("Input", i).c_str());
} }
op_def_builder.Output("Output") op_def_builder.Output("Output").Finalize(net.operator_def());
.Finalize(net.operator_def());
// Add input data // Add input data
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
...@@ -32,27 +30,26 @@ static void AddNBenchmark(int iters, int n, int size) { ...@@ -32,27 +30,26 @@ static void AddNBenchmark(int iters, int n, int size) {
} }
mace::testing::StartTiming(); mace::testing::StartTiming();
while(iters--) { while (iters--) {
net.RunOp(D); net.RunOp(D);
} }
} }
#define BM_ADDN_MACRO(N, SIZE, TYPE, DEVICE) \ #define BM_ADDN_MACRO(N, SIZE, TYPE, DEVICE) \
static void BM_ADDN_##N##_##SIZE##_##TYPE##_##DEVICE( \ static void BM_ADDN_##N##_##SIZE##_##TYPE##_##DEVICE(int iters) { \
int iters) { \ const int64_t tot = static_cast<int64_t>(iters) * N * SIZE; \
const int64_t tot = static_cast<int64_t>(iters) * N * SIZE; \ mace::testing::ItemsProcessed(tot); \
mace::testing::ItemsProcessed(tot); \ mace::testing::BytesProcessed(tot*(sizeof(TYPE))); \
mace::testing::BytesProcessed(tot * (sizeof(TYPE))); \ AddNBenchmark<DEVICE, TYPE>(iters, N, SIZE); \
AddNBenchmark<DEVICE, TYPE>(iters, N, SIZE); \ } \
} \
BENCHMARK(BM_ADDN_##N##_##SIZE##_##TYPE##_##DEVICE) BENCHMARK(BM_ADDN_##N##_##SIZE##_##TYPE##_##DEVICE)
#define BM_ADDN(N, SIZE, TYPE) \ #define BM_ADDN(N, SIZE, TYPE) \
BM_ADDN_MACRO(N, SIZE, TYPE, CPU); \ BM_ADDN_MACRO(N, SIZE, TYPE, CPU); \
BM_ADDN_MACRO(N, SIZE, TYPE, NEON); BM_ADDN_MACRO(N, SIZE, TYPE, NEON);
BM_ADDN(10, 1000, float); BM_ADDN(10, 1000, float);
BM_ADDN(10, 10000, float); BM_ADDN(10, 10000, float);
BM_ADDN(100, 1000, float); BM_ADDN(100, 1000, float);
BM_ADDN(100, 10000, float); BM_ADDN(100, 10000, float);
} // namespace mace } // namespace mace
\ No newline at end of file \ No newline at end of file
...@@ -36,4 +36,4 @@ TEST_F(AddnOpTest, AddnOp) { ...@@ -36,4 +36,4 @@ TEST_F(AddnOpTest, AddnOp) {
ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 0.01); ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 0.01);
} }
} // namespace mace } // namespace mace
...@@ -10,6 +10,6 @@ REGISTER_CPU_OPERATOR(BatchNorm, BatchNormOp<DeviceType::CPU, float>); ...@@ -10,6 +10,6 @@ REGISTER_CPU_OPERATOR(BatchNorm, BatchNormOp<DeviceType::CPU, float>);
#if __ARM_NEON #if __ARM_NEON
REGISTER_NEON_OPERATOR(BatchNorm, BatchNormOp<DeviceType::NEON, float>); REGISTER_NEON_OPERATOR(BatchNorm, BatchNormOp<DeviceType::NEON, float>);
#endif // __ARM_NEON #endif // __ARM_NEON
} // namespace mace } // namespace mace
\ No newline at end of file \ No newline at end of file
...@@ -10,50 +10,55 @@ ...@@ -10,50 +10,55 @@
namespace mace { namespace mace {
template<DeviceType D, class T> template <DeviceType D, class T>
class BatchNormOp : public Operator<D, T> { class BatchNormOp : public Operator<D, T> {
public: public:
BatchNormOp(const OperatorDef &operator_def, Workspace *ws) BatchNormOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<D, T>(operator_def, ws), : Operator<D, T>(operator_def, ws),
functor_(OperatorBase::GetSingleArgument<float>("variance_epsilon", 1e-4)){} functor_(
OperatorBase::GetSingleArgument<float>("variance_epsilon", 1e-4)) {}
bool Run() override {
const Tensor* input = this->Input(0);
const Tensor* scale = this->Input(1);
const Tensor* offset = this->Input(2);
const Tensor* mean = this->Input(3);
const Tensor* var = this->Input(4);
MACE_CHECK(input->dim_size() == 4, "input must be 4-dimensional. ", input->dim_size());
MACE_CHECK(scale->dim_size() == 1, "scale must be 1-dimensional. ", scale->dim_size());
MACE_CHECK(offset->dim_size() == 1, "offset must be 1-dimensional. ", offset->dim_size());
MACE_CHECK(mean->dim_size() == 1, "mean must be 1-dimensional. ", mean->dim_size());
MACE_CHECK(var->dim_size() == 1, "var must be 1-dimensional. ", var->dim_size());
Tensor* output = this->Output(0);
output->ResizeLike(input);
const index_t n = input->dim(0);
const index_t channel = input->dim(1);
const index_t sample_size = input->dim(2) * input->dim(3);
const T* input_ptr = input->data<T>();
const T* scale_ptr = scale->data<T>();
const T* offset_ptr = offset->data<T>();
const T* mean_ptr = mean->data<T>();
const T* var_ptr = var->data<T>();
T* output_ptr = output->mutable_data<T>();
functor_(input_ptr, scale_ptr, offset_ptr, mean_ptr, var_ptr,
n, channel, sample_size,
output_ptr);
return true;
}
private:
kernels::BatchNormFunctor<D, T> functor_;
bool Run() override {
const Tensor* input = this->Input(0);
const Tensor* scale = this->Input(1);
const Tensor* offset = this->Input(2);
const Tensor* mean = this->Input(3);
const Tensor* var = this->Input(4);
MACE_CHECK(input->dim_size() == 4, "input must be 4-dimensional. ",
input->dim_size());
MACE_CHECK(scale->dim_size() == 1, "scale must be 1-dimensional. ",
scale->dim_size());
MACE_CHECK(offset->dim_size() == 1, "offset must be 1-dimensional. ",
offset->dim_size());
MACE_CHECK(mean->dim_size() == 1, "mean must be 1-dimensional. ",
mean->dim_size());
MACE_CHECK(var->dim_size() == 1, "var must be 1-dimensional. ",
var->dim_size());
Tensor* output = this->Output(0);
output->ResizeLike(input);
const index_t n = input->dim(0);
const index_t channel = input->dim(1);
const index_t sample_size = input->dim(2) * input->dim(3);
const T* input_ptr = input->data<T>();
const T* scale_ptr = scale->data<T>();
const T* offset_ptr = offset->data<T>();
const T* mean_ptr = mean->data<T>();
const T* var_ptr = var->data<T>();
T* output_ptr = output->mutable_data<T>();
functor_(input_ptr, scale_ptr, offset_ptr, mean_ptr, var_ptr, n, channel,
sample_size, output_ptr);
return true;
}
private:
kernels::BatchNormFunctor<D, T> functor_;
}; };
} // namespace mace } // namespace mace
#endif // MACE_BATCH_NORM_H_ #endif // MACE_BATCH_NORM_H_
...@@ -8,19 +8,19 @@ ...@@ -8,19 +8,19 @@
namespace mace { namespace mace {
template <DeviceType D, typename T> template <DeviceType D, typename T>
static void BatchNorm(int iters, int batch, int channels, int height, int width) { static void BatchNorm(int iters, int batch, int channels, int height,
int width) {
mace::testing::StopTiming(); mace::testing::StopTiming();
OpsTestNet net; OpsTestNet net;
OpDefBuilder("BatchNorm", "BatchNormBM") OpDefBuilder("BatchNorm", "BatchNormBM")
.Input("Input") .Input("Input")
.Input("Scale") .Input("Scale")
.Input("Offset") .Input("Offset")
.Input("Mean") .Input("Mean")
.Input("Var") .Input("Var")
.Output("Output") .Output("Output")
.Finalize(net.operator_def()); .Finalize(net.operator_def());
// Add input data // Add input data
net.AddRandomInput<T>("Input", {batch, channels, height, width}); net.AddRandomInput<T>("Input", {batch, channels, height, width});
...@@ -35,23 +35,23 @@ static void BatchNorm(int iters, int batch, int channels, int height, int width) ...@@ -35,23 +35,23 @@ static void BatchNorm(int iters, int batch, int channels, int height, int width)
} }
mace::testing::StartTiming(); mace::testing::StartTiming();
while(iters--) { while (iters--) {
net.RunOp(D); net.RunOp(D);
} }
} }
#define BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, DEVICE) \ #define BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, DEVICE) \
static void BM_BATCH_NORM_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE( \ static void BM_BATCH_NORM_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE( \
int iters) { \ int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \ const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::ItemsProcessed(tot); \ mace::testing::ItemsProcessed(tot); \
mace::testing::BytesProcessed(tot * (sizeof(TYPE))); \ mace::testing::BytesProcessed(tot*(sizeof(TYPE))); \
BatchNorm<DEVICE, TYPE>(iters, N, C, H, W); \ BatchNorm<DEVICE, TYPE>(iters, N, C, H, W); \
} \ } \
BENCHMARK(BM_BATCH_NORM_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE) BENCHMARK(BM_BATCH_NORM_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
#define BM_BATCH_NORM(N, C, H, W, TYPE) \ #define BM_BATCH_NORM(N, C, H, W, TYPE) \
BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, CPU); \ BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, CPU); \
BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, NEON); BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, NEON);
BM_BATCH_NORM(1, 1, 512, 512, float); BM_BATCH_NORM(1, 1, 512, 512, float);
...@@ -65,4 +65,4 @@ BM_BATCH_NORM(1, 128, 256, 256, float); ...@@ -65,4 +65,4 @@ BM_BATCH_NORM(1, 128, 256, 256, float);
BM_BATCH_NORM(1, 128, 512, 512, float); BM_BATCH_NORM(1, 128, 512, 512, float);
BM_BATCH_NORM(32, 1, 256, 256, float); BM_BATCH_NORM(32, 1, 256, 256, float);
BM_BATCH_NORM(32, 3, 256, 256, float); BM_BATCH_NORM(32, 3, 256, 256, float);
} // namespace mace } // namespace mace
\ No newline at end of file \ No newline at end of file
...@@ -13,17 +13,17 @@ TEST_F(BatchNormOpTest, SimpleCPU) { ...@@ -13,17 +13,17 @@ TEST_F(BatchNormOpTest, SimpleCPU) {
// Construct graph // Construct graph
auto& net = test_net(); auto& net = test_net();
OpDefBuilder("BatchNorm", "BatchNormTest") OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("Input") .Input("Input")
.Input("Scale") .Input("Scale")
.Input("Offset") .Input("Offset")
.Input("Mean") .Input("Mean")
.Input("Var") .Input("Var")
.Output("Output") .Output("Output")
.Finalize(net.operator_def()); .Finalize(net.operator_def());
// Add input data // Add input data
net.AddInputFromArray<float>("Input", {1, 1, 6, 2}, net.AddInputFromArray<float>("Input", {1, 1, 6, 2},
{5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15}); {5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15});
net.AddInputFromArray<float>("Scale", {1}, {4.0f}); net.AddInputFromArray<float>("Scale", {1}, {4.0f});
net.AddInputFromArray<float>("Offset", {1}, {2.0}); net.AddInputFromArray<float>("Offset", {1}, {2.0});
net.AddInputFromArray<float>("Mean", {1}, {10}); net.AddInputFromArray<float>("Mean", {1}, {10});
...@@ -33,8 +33,8 @@ TEST_F(BatchNormOpTest, SimpleCPU) { ...@@ -33,8 +33,8 @@ TEST_F(BatchNormOpTest, SimpleCPU) {
net.RunOp(); net.RunOp();
// Check // Check
auto expected = CreateTensor<float>({1, 1, 6, 2}, auto expected =
{-3.86, -3.86, -1.51, -1.51, 0.83, 0.83, CreateTensor<float>({1, 1, 6, 2}, {-3.86, -3.86, -1.51, -1.51, 0.83, 0.83,
3.17, 3.17, 5.51, 5.51, 7.86, 7.86}); 3.17, 3.17, 5.51, 5.51, 7.86, 7.86});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.01); ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.01);
...@@ -51,13 +51,13 @@ TEST_F(BatchNormOpTest, SimpleNeon) { ...@@ -51,13 +51,13 @@ TEST_F(BatchNormOpTest, SimpleNeon) {
// Construct graph // Construct graph
auto& net = test_net(); auto& net = test_net();
OpDefBuilder("BatchNorm", "BatchNormTest") OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("Input") .Input("Input")
.Input("Scale") .Input("Scale")
.Input("Offset") .Input("Offset")
.Input("Mean") .Input("Mean")
.Input("Var") .Input("Var")
.Output("Output") .Output("Output")
.Finalize(net.operator_def()); .Finalize(net.operator_def());
// Add input data // Add input data
net.AddRandomInput<float>("Input", {batch, channels, height, width}); net.AddRandomInput<float>("Input", {batch, channels, height, width});
...@@ -77,5 +77,4 @@ TEST_F(BatchNormOpTest, SimpleNeon) { ...@@ -77,5 +77,4 @@ TEST_F(BatchNormOpTest, SimpleNeon) {
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5); ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
} }
} }
...@@ -11,6 +11,6 @@ REGISTER_CPU_OPERATOR(Conv2d, Conv2dOp<DeviceType::CPU, float>); ...@@ -11,6 +11,6 @@ REGISTER_CPU_OPERATOR(Conv2d, Conv2dOp<DeviceType::CPU, float>);
#if __ARM_NEON #if __ARM_NEON
REGISTER_NEON_OPERATOR(Conv2d, Conv2dOp<DeviceType::NEON, float>); REGISTER_NEON_OPERATOR(Conv2d, Conv2dOp<DeviceType::NEON, float>);
#endif // __ARM_NEON #endif // __ARM_NEON
} // namespace mace } // namespace mace
...@@ -13,11 +13,11 @@ ...@@ -13,11 +13,11 @@
namespace mace { namespace mace {
template<DeviceType D, typename T> template <DeviceType D, typename T>
class Conv2dOp : public ConvPool2dOpBase<D, T> { class Conv2dOp : public ConvPool2dOpBase<D, T> {
public: public:
Conv2dOp(const OperatorDef& op_def, Workspace* ws) Conv2dOp(const OperatorDef& op_def, Workspace* ws)
: ConvPool2dOpBase<D, T>(op_def, ws) {}; : ConvPool2dOpBase<D, T>(op_def, ws){};
bool Run() override { bool Run() override {
const Tensor* input = this->Input(INPUT); const Tensor* input = this->Input(INPUT);
...@@ -27,21 +27,16 @@ class Conv2dOp : public ConvPool2dOpBase<D, T> { ...@@ -27,21 +27,16 @@ class Conv2dOp : public ConvPool2dOpBase<D, T> {
std::vector<index_t> output_shape(4); std::vector<index_t> output_shape(4);
std::vector<int> paddings(2); std::vector<int> paddings(2);
kernels::CalcPaddingAndOutputSize(input->shape().data(), kernels::CalcPaddingAndOutputSize(
filter->shape().data(), input->shape().data(), filter->shape().data(), this->dilations_.data(),
this->dilations_.data(), this->strides_.data(), this->padding_, output_shape.data(),
this->strides_.data(), paddings.data());
this->padding_,
output_shape.data(),
paddings.data());
output->Resize(output_shape); output->Resize(output_shape);
auto conv2d = kernels::Conv2dFunctor<D, T>(this->strides_.data(), auto conv2d = kernels::Conv2dFunctor<D, T>(
paddings.data(), this->strides_.data(), paddings.data(), this->dilations_.data());
this->dilations_.data()); conv2d(input->data<T>(), input->shape().data(), filter->data<T>(),
conv2d(input->data<T>(), input->shape().data(), filter->shape().data(), bias->data<T>(), output->mutable_data<T>(),
filter->data<T>(), filter->shape().data(),
bias->data<T>(), output->mutable_data<T>(),
output->shape().data()); output->shape().data());
return true; return true;
...@@ -52,6 +47,6 @@ class Conv2dOp : public ConvPool2dOpBase<D, T> { ...@@ -52,6 +47,6 @@ class Conv2dOp : public ConvPool2dOpBase<D, T> {
OP_OUTPUT_TAGS(OUTPUT); OP_OUTPUT_TAGS(OUTPUT);
}; };
} // namespace mace } // namespace mace
#endif // MACE_OPS_CONV_2D_H_ #endif // MACE_OPS_CONV_2D_H_
...@@ -13,17 +13,17 @@ namespace mace { ...@@ -13,17 +13,17 @@ namespace mace {
template <DeviceType D, typename T> template <DeviceType D, typename T>
static void Conv2d(int iters, int batch, int channels, int height, int width, static void Conv2d(int iters, int batch, int channels, int height, int width,
int kernel_h, int kernel_w, int stride, int kernel_h, int kernel_w, int stride, Padding padding,
Padding padding, int output_channels) { int output_channels) {
mace::testing::StopTiming(); mace::testing::StopTiming();
OpsTestNet net; OpsTestNet net;
OpDefBuilder("Conv2d", "Conv2dTest") OpDefBuilder("Conv2d", "Conv2dTest")
.Input("Input") .Input("Input")
.Input("Filter") .Input("Filter")
.Input("Bias") .Input("Bias")
.Output("Output") .Output("Output")
.Finalize(net.operator_def()); .Finalize(net.operator_def());
// Add args // Add args
net.AddIntsArg("strides", {stride, stride}); net.AddIntsArg("strides", {stride, stride});
...@@ -32,7 +32,8 @@ static void Conv2d(int iters, int batch, int channels, int height, int width, ...@@ -32,7 +32,8 @@ static void Conv2d(int iters, int batch, int channels, int height, int width,
// Add input data // Add input data
net.AddRandomInput<float>("Input", {batch, channels, height, width}); net.AddRandomInput<float>("Input", {batch, channels, height, width});
net.AddRandomInput<float>("Filter", {output_channels, channels, kernel_h, kernel_w}); net.AddRandomInput<float>("Filter",
{output_channels, channels, kernel_h, kernel_w});
net.AddRandomInput<float>("Bias", {output_channels}); net.AddRandomInput<float>("Bias", {output_channels});
// Warm-up // Warm-up
...@@ -41,27 +42,30 @@ static void Conv2d(int iters, int batch, int channels, int height, int width, ...@@ -41,27 +42,30 @@ static void Conv2d(int iters, int batch, int channels, int height, int width,
} }
mace::testing::StartTiming(); mace::testing::StartTiming();
while(iters--) { while (iters--) {
net.RunOp(D); net.RunOp(D);
} }
} }
#define BM_CONV_2D_MACRO(N, C, H, W, KH, KW, STRIDE, P, OC, TYPE, DEVICE) \ #define BM_CONV_2D_MACRO(N, C, H, W, KH, KW, STRIDE, P, OC, TYPE, DEVICE) \
static void BM_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_OC##_##TYPE##_##DEVICE( \ static void \
int iters) { \ BM_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_OC##_##TYPE##_##DEVICE( \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \ int iters) { \
mace::testing::ItemsProcessed(tot); \ const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::BytesProcessed(tot * (sizeof(TYPE))); \ mace::testing::ItemsProcessed(tot); \
Conv2d<DEVICE, TYPE>(iters, N, C, H, W, KH, KW, STRIDE, mace::Padding::P, OC); \ mace::testing::BytesProcessed(tot*(sizeof(TYPE))); \
} \ Conv2d<DEVICE, TYPE>(iters, N, C, H, W, KH, KW, STRIDE, mace::Padding::P, \
BENCHMARK(BM_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_OC##_##TYPE##_##DEVICE) OC); \
} \
BENCHMARK( \
BM_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_OC##_##TYPE##_##DEVICE)
#define BM_CONV_2D(N, C, H, W, KH, KW, S, P, OC, TYPE) \ #define BM_CONV_2D(N, C, H, W, KH, KW, S, P, OC, TYPE) \
BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, CPU); \ BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, CPU); \
BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, NEON); BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, NEON);
BM_CONV_2D(1, 64, 32, 32, 1, 1, 1, VALID, 128, float); BM_CONV_2D(1, 64, 32, 32, 1, 1, 1, VALID, 128, float);
BM_CONV_2D(1, 64, 33, 31, 1, 1, 1, VALID, 128, float); // Test bad alignments BM_CONV_2D(1, 64, 33, 31, 1, 1, 1, VALID, 128, float); // Test bad alignments
BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, VALID, 128, float); BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, VALID, 128, float);
BM_CONV_2D(1, 64, 33, 31, 3, 3, 1, VALID, 128, float); BM_CONV_2D(1, 64, 33, 31, 3, 3, 1, VALID, 128, float);
BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, SAME, 128, float); BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, SAME, 128, float);
...@@ -71,4 +75,4 @@ BM_CONV_2D(1, 64, 32, 31, 5, 5, 1, VALID, 128, float); ...@@ -71,4 +75,4 @@ BM_CONV_2D(1, 64, 32, 31, 5, 5, 1, VALID, 128, float);
BM_CONV_2D(1, 64, 32, 32, 5, 5, 1, SAME, 128, float); BM_CONV_2D(1, 64, 32, 32, 5, 5, 1, SAME, 128, float);
BM_CONV_2D(1, 64, 32, 31, 5, 5, 1, SAME, 128, float); BM_CONV_2D(1, 64, 32, 31, 5, 5, 1, SAME, 128, float);
} // namespace mace } // namespace mace
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
// Copyright (c) 2017 XiaoMi All rights reserved. // Copyright (c) 2017 XiaoMi All rights reserved.
// //
#include "mace/core/operator.h"
#include "mace/ops/conv_2d.h" #include "mace/ops/conv_2d.h"
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h" #include "mace/ops/ops_test_util.h"
using namespace mace; using namespace mace;
...@@ -14,11 +14,11 @@ TEST_F(Conv2dOpTest, Simple_VALID) { ...@@ -14,11 +14,11 @@ TEST_F(Conv2dOpTest, Simple_VALID) {
// Construct graph // Construct graph
auto& net = test_net(); auto& net = test_net();
OpDefBuilder("Conv2d", "Conv2dTest") OpDefBuilder("Conv2d", "Conv2dTest")
.Input("Input") .Input("Input")
.Input("Filter") .Input("Filter")
.Input("Bias") .Input("Bias")
.Output("Output") .Output("Output")
.Finalize(net.operator_def()); .Finalize(net.operator_def());
// Add args // Add args
net.AddIntsArg("strides", {1, 1}); net.AddIntsArg("strides", {1, 1});
...@@ -26,17 +26,13 @@ TEST_F(Conv2dOpTest, Simple_VALID) { ...@@ -26,17 +26,13 @@ TEST_F(Conv2dOpTest, Simple_VALID) {
net.AddIntsArg("dilations", {1, 1}); net.AddIntsArg("dilations", {1, 1});
// Add input data // Add input data
net.AddInputFromArray<float>("Input", {1, 2, 3, 3}, net.AddInputFromArray<float>(
{1, 1, 1, "Input", {1, 2, 3, 3},
1, 1, 1, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
1, 1, 1, net.AddInputFromArray<float>(
1, 1, 1, "Filter", {1, 2, 3, 3},
1, 1, 1, {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1, 1, 1}); 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f});
net.AddInputFromArray<float>("Filter", {1, 2, 3, 3},
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f});
net.AddInputFromArray<float>("Bias", {1}, {0.1f}); net.AddInputFromArray<float>("Bias", {1}, {0.1f});
// Run // Run
...@@ -52,11 +48,11 @@ TEST_F(Conv2dOpTest, Simple_SAME) { ...@@ -52,11 +48,11 @@ TEST_F(Conv2dOpTest, Simple_SAME) {
// Construct graph // Construct graph
auto& net = test_net(); auto& net = test_net();
OpDefBuilder("Conv2d", "Conv2dTest") OpDefBuilder("Conv2d", "Conv2dTest")
.Input("Input") .Input("Input")
.Input("Filter") .Input("Filter")
.Input("Bias") .Input("Bias")
.Output("Output") .Output("Output")
.Finalize(net.operator_def()); .Finalize(net.operator_def());
// Add args // Add args
net.AddIntsArg("strides", {1, 1}); net.AddIntsArg("strides", {1, 1});
...@@ -64,27 +60,22 @@ TEST_F(Conv2dOpTest, Simple_SAME) { ...@@ -64,27 +60,22 @@ TEST_F(Conv2dOpTest, Simple_SAME) {
net.AddIntsArg("dilations", {1, 1}); net.AddIntsArg("dilations", {1, 1});
// Add input data // Add input data
net.AddInputFromArray<float>("Input", {1, 2, 3, 3}, net.AddInputFromArray<float>(
{1, 1, 1, "Input", {1, 2, 3, 3},
1, 1, 1, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
1, 1, 1, net.AddInputFromArray<float>(
1, 1, 1, "Filter", {1, 2, 3, 3},
1, 1, 1, {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1, 1, 1}); 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f});
net.AddInputFromArray<float>("Filter", {1, 2, 3, 3},
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f});
net.AddInputFromArray<float>("Bias", {1}, {0.1f}); net.AddInputFromArray<float>("Bias", {1}, {0.1f});
// Run // Run
net.RunOp(); net.RunOp();
// Check // Check
auto expected = CreateTensor<float>({1, 1, 3, 3}, auto expected = CreateTensor<float>(
{ 8.1f, 12.1f, 8.1f, {1, 1, 3, 3},
12.1f, 18.1f, 12.1f, {8.1f, 12.1f, 8.1f, 12.1f, 18.1f, 12.1f, 8.1f, 12.1f, 8.1f});
8.1f, 12.1f, 8.1f});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001); ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
} }
...@@ -93,11 +84,11 @@ TEST_F(Conv2dOpTest, Combined) { ...@@ -93,11 +84,11 @@ TEST_F(Conv2dOpTest, Combined) {
// Construct graph // Construct graph
auto& net = test_net(); auto& net = test_net();
OpDefBuilder("Conv2d", "Conv2dTest") OpDefBuilder("Conv2d", "Conv2dTest")
.Input("Input") .Input("Input")
.Input("Filter") .Input("Filter")
.Input("Bias") .Input("Bias")
.Output("Output") .Output("Output")
.Finalize(net.operator_def()); .Finalize(net.operator_def());
// Add args // Add args
net.AddIntsArg("strides", {2, 2}); net.AddIntsArg("strides", {2, 2});
...@@ -105,36 +96,24 @@ TEST_F(Conv2dOpTest, Combined) { ...@@ -105,36 +96,24 @@ TEST_F(Conv2dOpTest, Combined) {
net.AddIntsArg("dilations", {1, 1}); net.AddIntsArg("dilations", {1, 1});
// Add input data // Add input data
net.AddInputFromArray<float>("Input", {1, 2, 5, 5}, net.AddInputFromArray<float>(
{1, 1, 1, 1, 1, "Input", {1, 2, 5, 5}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
1, 1, 1, 1, 1, net.AddInputFromArray<float>(
1, 1, 1, 1, 1, "Filter", {2, 2, 3, 3},
1, 1, 1, 1, 1, {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1, 1, 1, 1, 1, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,
1, 1, 1, 1, 1, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f});
1, 1, 1, 1, 1,
1, 1, 1, 1, 1});
net.AddInputFromArray<float>("Filter", {2, 2, 3, 3},
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,
0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f});
net.AddInputFromArray<float>("Bias", {2}, {0.1f, 0.2f}); net.AddInputFromArray<float>("Bias", {2}, {0.1f, 0.2f});
// Run // Run
net.RunOp(); net.RunOp();
// Check // Check
auto expected = CreateTensor<float>({1, 2, 3, 3}, auto expected = CreateTensor<float>(
{ 8.1f, 12.1f, 8.1f, {1, 2, 3, 3}, {8.1f, 12.1f, 8.1f, 12.1f, 18.1f, 12.1f, 8.1f, 12.1f, 8.1f,
12.1f, 18.1f, 12.1f, 4.2f, 6.2f, 4.2f, 6.2f, 9.2f, 6.2f, 4.2f, 6.2f, 4.2f});
8.1f, 12.1f, 8.1f,
4.2f, 6.2f, 4.2f,
6.2f, 9.2f, 6.2f,
4.2f, 6.2f, 4.2f});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001); ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
} }
...@@ -143,11 +122,11 @@ TEST_F(Conv2dOpTest, Conv1x1) { ...@@ -143,11 +122,11 @@ TEST_F(Conv2dOpTest, Conv1x1) {
// Construct graph // Construct graph
auto& net = test_net(); auto& net = test_net();
OpDefBuilder("Conv2d", "Conv2dTest") OpDefBuilder("Conv2d", "Conv2dTest")
.Input("Input") .Input("Input")
.Input("Filter") .Input("Filter")
.Input("Bias") .Input("Bias")
.Output("Output") .Output("Output")
.Finalize(net.operator_def()); .Finalize(net.operator_def());
// Add args // Add args
net.AddIntsArg("strides", {1, 1}); net.AddIntsArg("strides", {1, 1});
...@@ -155,38 +134,32 @@ TEST_F(Conv2dOpTest, Conv1x1) { ...@@ -155,38 +134,32 @@ TEST_F(Conv2dOpTest, Conv1x1) {
net.AddIntsArg("dilations", {1, 1}); net.AddIntsArg("dilations", {1, 1});
// Add input data // Add input data
net.AddInputFromArray<float>("Input", {1, 5, 3, 10}, net.AddInputFromArray<float>(
{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, "Input", {1, 5, 3, 10},
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, net.AddInputFromArray<float>(
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, "Filter", {2, 5, 1, 1},
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f});
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
net.AddInputFromArray<float>("Filter", {2, 5, 1, 1},
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
2.0f, 2.0f, 2.0f, 2.0f, 2.0f});
net.AddInputFromArray<float>("Bias", {2}, {0.1f, 0.2f}); net.AddInputFromArray<float>("Bias", {2}, {0.1f, 0.2f});
// Run // Run
net.RunOp(); net.RunOp();
// Check // Check
auto expected = CreateTensor<float>({1, 2, 3, 10}, auto expected = CreateTensor<float>(
{5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, {1, 2, 3, 10},
5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, {5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f,
5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f,
10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f,
10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f,
10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f}); 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f,
10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001); ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
} }
...@@ -194,8 +167,7 @@ TEST_F(Conv2dOpTest, Conv1x1) { ...@@ -194,8 +167,7 @@ TEST_F(Conv2dOpTest, Conv1x1) {
// TODO we need more tests // TODO we need more tests
TEST_F(Conv2dOpTest, ConvNxNS12) { TEST_F(Conv2dOpTest, ConvNxNS12) {
testing::internal::LogToStderr(); testing::internal::LogToStderr();
auto func = [&](int kernel_h, int kernel_w, auto func = [&](int kernel_h, int kernel_w, int stride_h, int stride_w,
int stride_h, int stride_w,
Padding type) { Padding type) {
srand(time(NULL)); srand(time(NULL));
...@@ -206,7 +178,7 @@ TEST_F(Conv2dOpTest, ConvNxNS12) { ...@@ -206,7 +178,7 @@ TEST_F(Conv2dOpTest, ConvNxNS12) {
index_t width = 7 + rand() % 100; index_t width = 7 + rand() % 100;
index_t output_channels = 1 + rand() % 50; index_t output_channels = 1 + rand() % 50;
// Construct graph // Construct graph
auto &net = test_net(); auto& net = test_net();
OpDefBuilder("Conv2d", "Conv2dTest") OpDefBuilder("Conv2d", "Conv2dTest")
.Input("Input") .Input("Input")
.Input("Filter") .Input("Filter")
...@@ -221,8 +193,8 @@ TEST_F(Conv2dOpTest, ConvNxNS12) { ...@@ -221,8 +193,8 @@ TEST_F(Conv2dOpTest, ConvNxNS12) {
// Add input data // Add input data
net.AddRandomInput<float>("Input", {batch, input_channels, height, width}); net.AddRandomInput<float>("Input", {batch, input_channels, height, width});
net.AddRandomInput<float>("Filter", {output_channels, input_channels, net.AddRandomInput<float>(
kernel_h, kernel_w}); "Filter", {output_channels, input_channels, kernel_h, kernel_w});
net.AddRandomInput<float>("Bias", {output_channels}); net.AddRandomInput<float>("Bias", {output_channels});
// run cpu // run cpu
net.RunOp(); net.RunOp();
......
...@@ -10,16 +10,15 @@ ...@@ -10,16 +10,15 @@
namespace mace { namespace mace {
template<DeviceType D, class T> template <DeviceType D, class T>
class ConvPool2dOpBase : public Operator<D, T> { class ConvPool2dOpBase : public Operator<D, T> {
public: public:
ConvPool2dOpBase(const OperatorDef& op_def, Workspace* ws) ConvPool2dOpBase(const OperatorDef& op_def, Workspace* ws)
: Operator<D, T>(op_def, ws), : Operator<D, T>(op_def, ws),
strides_(OperatorBase::GetRepeatedArgument<int>("strides")), strides_(OperatorBase::GetRepeatedArgument<int>("strides")),
padding_(static_cast<Padding>( padding_(static_cast<Padding>(OperatorBase::GetSingleArgument<int>(
OperatorBase::GetSingleArgument<int>("padding", "padding", static_cast<int>(SAME)))),
static_cast<int>(SAME)))), dilations_(OperatorBase::GetRepeatedArgument<int>("dilations")) {}
dilations_(OperatorBase::GetRepeatedArgument<int>("dilations")) {}
protected: protected:
std::vector<int> strides_; std::vector<int> strides_;
...@@ -27,6 +26,6 @@ class ConvPool2dOpBase : public Operator<D, T> { ...@@ -27,6 +26,6 @@ class ConvPool2dOpBase : public Operator<D, T> {
std::vector<int> dilations_; std::vector<int> dilations_;
}; };
} // namespace mace } // namespace mace
#endif // MACE_OPS_CONV_POOL_2D_BASE_H_ #endif // MACE_OPS_CONV_POOL_2D_BASE_H_
...@@ -43,31 +43,33 @@ class OpsTestNet { ...@@ -43,31 +43,33 @@ class OpsTestNet {
public: public:
OpsTestNet() {} OpsTestNet() {}
template<typename T> template <typename T>
void AddInputFromArray(const char *name, void AddInputFromArray(const char *name, const std::vector<index_t> &shape,
const std::vector<index_t> &shape,
const std::vector<T> &data) { const std::vector<T> &data) {
Tensor *input = ws_.CreateTensor(name, cpu_allocator(), DataTypeToEnum<T>::v()); Tensor *input =
ws_.CreateTensor(name, cpu_allocator(), DataTypeToEnum<T>::v());
input->Resize(shape); input->Resize(shape);
T *input_data = input->mutable_data<T>(); T *input_data = input->mutable_data<T>();
MACE_CHECK(input->size() == data.size()); MACE_CHECK(input->size() == data.size());
memcpy(input_data, data.data(), data.size() * sizeof(T)); memcpy(input_data, data.data(), data.size() * sizeof(T));
} }
template<typename T> template <typename T>
void AddRepeatedInput(const char *name, void AddRepeatedInput(const char *name, const std::vector<index_t> &shape,
const std::vector<index_t> &shape, const T data) {
const T data) { Tensor *input =
Tensor *input = ws_.CreateTensor(name, cpu_allocator(), DataTypeToEnum<T>::v()); ws_.CreateTensor(name, cpu_allocator(), DataTypeToEnum<T>::v());
input->Resize(shape); input->Resize(shape);
T *input_data = input->mutable_data<T>(); T *input_data = input->mutable_data<T>();
MACE_CHECK(input->size() == data.size()); MACE_CHECK(input->size() == data.size());
std::fill(input_data, input_data + input->size(), data); std::fill(input_data, input_data + input->size(), data);
} }
template<typename T> template <typename T>
void AddRandomInput(const char *name, const std::vector<index_t> &shape, bool positive = false) { void AddRandomInput(const char *name, const std::vector<index_t> &shape,
Tensor *input = ws_.CreateTensor(name, cpu_allocator(), DataTypeToEnum<T>::v()); bool positive = false) {
Tensor *input =
ws_.CreateTensor(name, cpu_allocator(), DataTypeToEnum<T>::v());
input->Resize(shape); input->Resize(shape);
float *input_data = input->mutable_data<T>(); float *input_data = input->mutable_data<T>();
...@@ -76,12 +78,16 @@ class OpsTestNet { ...@@ -76,12 +78,16 @@ class OpsTestNet {
std::normal_distribution<T> nd(0, 1); std::normal_distribution<T> nd(0, 1);
std::generate(input_data, input_data + input->size(), std::generate(input_data, input_data + input->size(),
[&gen, &nd, positive] { return positive ? std::abs(nd(gen)) : nd(gen); }); [&gen, &nd, positive] {
return positive ? std::abs(nd(gen)) : nd(gen);
});
} }
template<typename T> template <typename T>
void AddFixedInput(const char *name, const std::vector<index_t> &shape, T value) { void AddFixedInput(const char *name, const std::vector<index_t> &shape,
Tensor *input = ws_.CreateTensor(name, cpu_allocator(), DataTypeToEnum<T>::v()); T value) {
Tensor *input =
ws_.CreateTensor(name, cpu_allocator(), DataTypeToEnum<T>::v());
input->Resize(shape); input->Resize(shape);
float *input_data = input->mutable_data<T>(); float *input_data = input->mutable_data<T>();
...@@ -122,7 +128,8 @@ class OpsTestNet { ...@@ -122,7 +128,8 @@ class OpsTestNet {
} }
} }
void AddStringsArg(const char *name, const std::vector<const char *> &values) { void AddStringsArg(const char *name,
const std::vector<const char *> &values) {
auto arg = op_def_.add_arg(); auto arg = op_def_.add_arg();
arg->set_name(name); arg->set_name(name);
for (auto value : values) { for (auto value : values) {
...@@ -145,9 +152,7 @@ class OpsTestNet { ...@@ -145,9 +152,7 @@ class OpsTestNet {
return net_->Run(); return net_->Run();
} }
// Convenience overload: runs the operator on DeviceType::CPU and forwards
// the result of RunOp(DeviceType).
bool RunOp() {
  const bool succeeded = RunOp(DeviceType::CPU);
  return succeeded;
}
Tensor *GetOutput(const char *output_name) { Tensor *GetOutput(const char *output_name) {
return ws_.GetTensor(output_name); return ws_.GetTensor(output_name);
...@@ -177,8 +182,9 @@ class OpsTestBase : public ::testing::Test { ...@@ -177,8 +182,9 @@ class OpsTestBase : public ::testing::Test {
OpsTestNet test_net_; OpsTestNet test_net_;
}; };
template<typename T> template <typename T>
unique_ptr<Tensor> CreateTensor(const std::vector<index_t> &shape, const std::vector<T> &data) { unique_ptr<Tensor> CreateTensor(const std::vector<index_t> &shape,
const std::vector<T> &data) {
unique_ptr<Tensor> res(new Tensor(cpu_allocator(), DataTypeToEnum<T>::v())); unique_ptr<Tensor> res(new Tensor(cpu_allocator(), DataTypeToEnum<T>::v()));
res->Resize(shape); res->Resize(shape);
T *input_data = res->mutable_data<T>(); T *input_data = res->mutable_data<T>();
...@@ -209,40 +215,38 @@ inline std::string ShapeToString(const Tensor &x) { ...@@ -209,40 +215,38 @@ inline std::string ShapeToString(const Tensor &x) {
return std::string(stream.str()); return std::string(stream.str());
} }
// Compile-time predicate: `value` is true only when T is exactly `float` or
// exactly `double`; every other type yields false.
template <typename T>
struct is_floating_point_type {
  static const bool value =
      std::is_same<T, double>::value || std::is_same<T, float>::value;
};
template<typename T> template <typename T>
inline void ExpectEqual(const T &a, const T &b) { inline void ExpectEqual(const T &a, const T &b) {
EXPECT_EQ(a, b); EXPECT_EQ(a, b);
} }
template<> template <>
inline void ExpectEqual<float>(const float &a, const float &b) { inline void ExpectEqual<float>(const float &a, const float &b) {
EXPECT_FLOAT_EQ(a, b); EXPECT_FLOAT_EQ(a, b);
} }
template<> template <>
inline void ExpectEqual<double>(const double &a, const double &b) { inline void ExpectEqual<double>(const double &a, const double &b) {
EXPECT_DOUBLE_EQ(a, b); EXPECT_DOUBLE_EQ(a, b);
} }
inline void AssertSameTypeDims(const Tensor &x, const Tensor &y) { inline void AssertSameTypeDims(const Tensor &x, const Tensor &y) {
ASSERT_EQ(x.dtype(), y.dtype()); ASSERT_EQ(x.dtype(), y.dtype());
ASSERT_TRUE(IsSameSize(x, y)) ASSERT_TRUE(IsSameSize(x, y)) << "x.shape [" << ShapeToString(x) << "] vs "
<< "x.shape [" << ShapeToString(x) << "] vs " << "y.shape [ " << ShapeToString(y) << "]";
<< "y.shape [ " << ShapeToString(y) << "]";
} }
// Primary template, declared only. The second parameter defaults to
// is_floating_point_type<T>::value, so Expector<T> selects a specialization
// based on whether T is float/double.
template <typename T, bool is_fp = is_floating_point_type<T>::value>
struct Expector;
// Partial specialization for float and double. // Partial specialization for float and double.
template<typename T> template <typename T>
struct Expector<T, true> { struct Expector<T, true> {
static void Equal(const T &a, const T &b) { ExpectEqual(a, b); } static void Equal(const T &a, const T &b) { ExpectEqual(a, b); }
...@@ -262,18 +266,19 @@ struct Expector<T, true> { ...@@ -262,18 +266,19 @@ struct Expector<T, true> {
auto a = x.data<T>(); auto a = x.data<T>();
auto b = y.data<T>(); auto b = y.data<T>();
for (int i = 0; i < x.size(); ++i) { for (int i = 0; i < x.size(); ++i) {
EXPECT_NEAR(a[i], b[i], abs_err) EXPECT_NEAR(a[i], b[i], abs_err) << "a = " << a << " b = " << b
<< "a = " << a << " b = " << b << " index = " << i; << " index = " << i;
} }
} }
}; };
template<typename T> template <typename T>
void ExpectTensorNear(const Tensor &x, const Tensor &y, const double abs_err) { void ExpectTensorNear(const Tensor &x, const Tensor &y, const double abs_err) {
static_assert(is_floating_point_type<T>::value, "T is not a floating point type"); static_assert(is_floating_point_type<T>::value,
"T is not a floating point type");
Expector<T>::Near(x, y, abs_err); Expector<T>::Near(x, y, abs_err);
} }
} // namespace mace } // namespace mace
#endif // MACE_OPS_TEST_UTIL_H_ #endif // MACE_OPS_TEST_UTIL_H_
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
// Copyright (c) 2017 XiaoMi All rights reserved. // Copyright (c) 2017 XiaoMi All rights reserved.
// //
#include "mace/ops/pooling.h" #include "mace/ops/pooling.h"
namespace mace { namespace mace {
...@@ -11,6 +10,6 @@ REGISTER_CPU_OPERATOR(Pooling, PoolingOp<DeviceType::CPU, float>); ...@@ -11,6 +10,6 @@ REGISTER_CPU_OPERATOR(Pooling, PoolingOp<DeviceType::CPU, float>);
#if __ARM_NEON #if __ARM_NEON
REGISTER_NEON_OPERATOR(Pooling, PoolingOp<DeviceType::NEON, float>); REGISTER_NEON_OPERATOR(Pooling, PoolingOp<DeviceType::NEON, float>);
#endif // __ARM_NEON #endif // __ARM_NEON
} // namespace mace } // namespace mace
...@@ -11,17 +11,17 @@ ...@@ -11,17 +11,17 @@
namespace mace { namespace mace {
template<DeviceType D, class T> template <DeviceType D, class T>
class PoolingOp : public ConvPool2dOpBase<D, T> { class PoolingOp : public ConvPool2dOpBase<D, T> {
public: public:
PoolingOp(const OperatorDef& op_def, Workspace* ws) PoolingOp(const OperatorDef& op_def, Workspace* ws)
: ConvPool2dOpBase<D, T>(op_def, ws), : ConvPool2dOpBase<D, T>(op_def, ws),
kernels_(OperatorBase::GetRepeatedArgument<int>("kernels")), kernels_(OperatorBase::GetRepeatedArgument<int>("kernels")),
pooling_type_(static_cast<PoolingType>( pooling_type_(
OperatorBase::GetSingleArgument<int>( static_cast<PoolingType>(OperatorBase::GetSingleArgument<int>(
"pooling_type", static_cast<int>(AVG)))) {}; "pooling_type", static_cast<int>(AVG)))){};
bool Run() override{ bool Run() override {
const Tensor* input = this->Input(INPUT); const Tensor* input = this->Input(INPUT);
Tensor* output = this->Output(OUTPUT); Tensor* output = this->Output(OUTPUT);
std::vector<index_t> in_shape = input->shape(); std::vector<index_t> in_shape = input->shape();
...@@ -33,28 +33,21 @@ public: ...@@ -33,28 +33,21 @@ public:
filter_shape[1] = in_shape[0]; filter_shape[1] = in_shape[0];
filter_shape[2] = kernels_[0]; filter_shape[2] = kernels_[0];
filter_shape[3] = kernels_[1]; filter_shape[3] = kernels_[1];
kernels::CalcPaddingAndOutputSize(in_shape.data(), kernels::CalcPaddingAndOutputSize(in_shape.data(), filter_shape.data(),
filter_shape.data(),
this->dilations_.data(), this->dilations_.data(),
this->strides_.data(), this->strides_.data(), this->padding_,
this->padding_, output_shape.data(), paddings.data());
output_shape.data(),
paddings.data());
output->Resize(output_shape); output->Resize(output_shape);
auto pooling_func = kernels::PoolingFunctor<D, T>(pooling_type_, auto pooling_func = kernels::PoolingFunctor<D, T>(
kernels_.data(), pooling_type_, kernels_.data(), this->strides_.data(), paddings.data(),
this->strides_.data(), this->dilations_.data());
paddings.data(), pooling_func(input->data<float>(), in_shape.data(),
this->dilations_.data()); output->mutable_data<float>(), output->shape().data());
pooling_func(input->data<float>(),
in_shape.data(),
output->mutable_data<float>(),
output->shape().data());
return true; return true;
}; };
protected: protected:
std::vector<int> kernels_; std::vector<int> kernels_;
PoolingType pooling_type_; PoolingType pooling_type_;
...@@ -62,6 +55,6 @@ protected: ...@@ -62,6 +55,6 @@ protected:
OP_OUTPUT_TAGS(OUTPUT); OP_OUTPUT_TAGS(OUTPUT);
}; };
} // namespace mace } // namespace mace
#endif //MACE_OPS_POOLING_H_ #endif // MACE_OPS_POOLING_H_
...@@ -2,20 +2,19 @@ ...@@ -2,20 +2,19 @@
// Copyright (c) 2017 XiaoMi All rights reserved. // Copyright (c) 2017 XiaoMi All rights reserved.
// //
#include "mace/core/testing/test_benchmark.h"
#include "mace/core/operator.h"
#include "mace/kernels/pooling.h" #include "mace/kernels/pooling.h"
#include "mace/core/operator.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/kernels/conv_pool_2d_util.h" #include "mace/kernels/conv_pool_2d_util.h"
#include "mace/ops/ops_test_util.h" #include "mace/ops/ops_test_util.h"
using namespace mace; using namespace mace;
using namespace mace::kernels; using namespace mace::kernels;
template<DeviceType D> template <DeviceType D>
static void Pooling(int iters, int batch, int channels, int height, static void Pooling(int iters, int batch, int channels, int height, int width,
int width, int kernel, int stride, Padding padding, int kernel, int stride, Padding padding,
PoolingType pooling_type) { PoolingType pooling_type) {
mace::testing::StopTiming(); mace::testing::StopTiming();
OpsTestNet net; OpsTestNet net;
...@@ -45,18 +44,21 @@ static void Pooling(int iters, int batch, int channels, int height, ...@@ -45,18 +44,21 @@ static void Pooling(int iters, int batch, int channels, int height,
} }
} }
// Defines and registers one pooling benchmark.
// The generated function name encodes every argument
// (BM_POOLING_<N>_<C>_<H>_<W>_K<KE>S<STRIDE>_<PA>_<PO>_<DEVICE>).
// It reports iters * N * C * H * W items processed and that count times
// sizeof(float) bytes processed, then runs Pooling<DEVICE> with
// Padding::PA and PoolingType::PO.
#define BM_POOLING_MACRO(N, C, H, W, KE, STRIDE, PA, PO, DEVICE) \
  static void \
  BM_POOLING_##N##_##C##_##H##_##W##_K##KE##S##STRIDE##_##PA##_##PO##_##DEVICE( \
      int iters) { \
    const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
    mace::testing::ItemsProcessed(tot); \
    mace::testing::BytesProcessed(tot * (sizeof(float))); \
    Pooling<DEVICE>(iters, N, C, H, W, KE, STRIDE, Padding::PA, \
                    PoolingType::PO); \
  } \
  BENCHMARK( \
      BM_POOLING_##N##_##C##_##H##_##W##_K##KE##S##STRIDE##_##PA##_##PO##_##DEVICE)
// Instantiates the same pooling benchmark configuration twice, once for the
// CPU device and once for the NEON device.
#define BM_POOLING(N, C, H, W, K, S, PA, PO) \
  BM_POOLING_MACRO(N, C, H, W, K, S, PA, PO, CPU); \
  BM_POOLING_MACRO(N, C, H, W, K, S, PA, PO, NEON);
BM_POOLING(1, 3, 129, 129, 2, 2, SAME, MAX); BM_POOLING(1, 3, 129, 129, 2, 2, SAME, MAX);
......
...@@ -5,9 +5,9 @@ ...@@ -5,9 +5,9 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "mace/core/operator.h" #include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
#include "mace/ops/conv_pool_2d_base.h"
#include "mace/kernels/pooling.h" #include "mace/kernels/pooling.h"
#include "mace/ops/conv_pool_2d_base.h"
#include "mace/ops/ops_test_util.h"
using namespace mace; using namespace mace;
...@@ -17,9 +17,9 @@ TEST_F(PoolingOpTest, MAX_VALID) { ...@@ -17,9 +17,9 @@ TEST_F(PoolingOpTest, MAX_VALID) {
// Construct graph // Construct graph
auto& net = test_net(); auto& net = test_net();
OpDefBuilder("Pooling", "PoolingTest") OpDefBuilder("Pooling", "PoolingTest")
.Input("Input") .Input("Input")
.Output("Output") .Output("Output")
.Finalize(net.operator_def()); .Finalize(net.operator_def());
// Add args // Add args
net.AddIntsArg("kernels", {2, 2}); net.AddIntsArg("kernels", {2, 2});
...@@ -29,34 +29,28 @@ TEST_F(PoolingOpTest, MAX_VALID) { ...@@ -29,34 +29,28 @@ TEST_F(PoolingOpTest, MAX_VALID) {
net.AddIntArg("pooling_type", PoolingType::MAX); net.AddIntArg("pooling_type", PoolingType::MAX);
// Add input data // Add input data
net.AddInputFromArray<float>("Input", {1, 2, 4, 4}, net.AddInputFromArray<float>(
{0, 1, 2, 3, "Input", {1, 2, 4, 4},
4, 5, 6, 7, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
8, 9, 10, 11, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31});
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23,
24, 25, 26, 27,
28, 29, 30, 31});
// Run // Run
net.RunOp(); net.RunOp();
// Check // Check
auto expected = CreateTensor<float>({1, 2, 2, 2}, auto expected =
{5, 7, 13, 15, 21, 23, 29, 31}); CreateTensor<float>({1, 2, 2, 2}, {5, 7, 13, 15, 21, 23, 29, 31});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001); ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
} }
TEST_F(PoolingOpTest, AVG_VALID) { TEST_F(PoolingOpTest, AVG_VALID) {
// Construct graph // Construct graph
auto& net = test_net(); auto& net = test_net();
OpDefBuilder("Pooling", "PoolingTest") OpDefBuilder("Pooling", "PoolingTest")
.Input("Input") .Input("Input")
.Output("Output") .Output("Output")
.Finalize(net.operator_def()); .Finalize(net.operator_def());
// Add args // Add args
net.AddIntsArg("kernels", {2, 2}); net.AddIntsArg("kernels", {2, 2});
...@@ -66,22 +60,17 @@ TEST_F(PoolingOpTest, AVG_VALID) { ...@@ -66,22 +60,17 @@ TEST_F(PoolingOpTest, AVG_VALID) {
net.AddIntArg("pooling_type", PoolingType::AVG); net.AddIntArg("pooling_type", PoolingType::AVG);
// Add input data // Add input data
net.AddInputFromArray<float>("Input", {1, 2, 4, 4}, net.AddInputFromArray<float>(
{0, 1, 2, 3, "Input", {1, 2, 4, 4},
4, 5, 6, 7, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
8, 9, 10, 11, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31});
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23,
24, 25, 26, 27,
28, 29, 30, 31});
// Run // Run
net.RunOp(); net.RunOp();
// Check // Check
auto expected = CreateTensor<float>({1, 2, 2, 2}, auto expected = CreateTensor<float>(
{2.5, 4.5, 10.5, 12.5, 18.5, 20.5, 26.5, 28.5}); {1, 2, 2, 2}, {2.5, 4.5, 10.5, 12.5, 18.5, 20.5, 26.5, 28.5});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001); ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
} }
...@@ -90,9 +79,9 @@ TEST_F(PoolingOpTest, MAX_SAME) { ...@@ -90,9 +79,9 @@ TEST_F(PoolingOpTest, MAX_SAME) {
// Construct graph // Construct graph
auto& net = test_net(); auto& net = test_net();
OpDefBuilder("Pooling", "PoolingTest") OpDefBuilder("Pooling", "PoolingTest")
.Input("Input") .Input("Input")
.Output("Output") .Output("Output")
.Finalize(net.operator_def()); .Finalize(net.operator_def());
// Add args // Add args
net.AddIntsArg("kernels", {2, 2}); net.AddIntsArg("kernels", {2, 2});
...@@ -103,16 +92,13 @@ TEST_F(PoolingOpTest, MAX_SAME) { ...@@ -103,16 +92,13 @@ TEST_F(PoolingOpTest, MAX_SAME) {
// Add input data // Add input data
net.AddInputFromArray<float>("Input", {1, 1, 3, 3}, net.AddInputFromArray<float>("Input", {1, 1, 3, 3},
{0, 1, 2, {0, 1, 2, 3, 4, 5, 6, 7, 8});
3, 4, 5,
6, 7, 8});
// Run // Run
net.RunOp(); net.RunOp();
// Check // Check
auto expected = CreateTensor<float>({1, 1, 2, 2}, auto expected = CreateTensor<float>({1, 1, 2, 2}, {4, 5, 7, 8});
{4, 5, 7, 8});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001); ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
} }
...@@ -121,9 +107,9 @@ TEST_F(PoolingOpTest, MAX_VALID_DILATION) { ...@@ -121,9 +107,9 @@ TEST_F(PoolingOpTest, MAX_VALID_DILATION) {
// Construct graph // Construct graph
auto& net = test_net(); auto& net = test_net();
OpDefBuilder("Pooling", "PoolingTest") OpDefBuilder("Pooling", "PoolingTest")
.Input("Input") .Input("Input")
.Output("Output") .Output("Output")
.Finalize(net.operator_def()); .Finalize(net.operator_def());
// Add args // Add args
net.AddIntsArg("kernels", {2, 2}); net.AddIntsArg("kernels", {2, 2});
...@@ -133,18 +119,15 @@ TEST_F(PoolingOpTest, MAX_VALID_DILATION) { ...@@ -133,18 +119,15 @@ TEST_F(PoolingOpTest, MAX_VALID_DILATION) {
net.AddIntArg("pooling_type", PoolingType::MAX); net.AddIntArg("pooling_type", PoolingType::MAX);
// Add input data // Add input data
net.AddInputFromArray<float>("Input", {1, 1, 4, 4}, net.AddInputFromArray<float>(
{0, 1, 2, 3, "Input", {1, 1, 4, 4},
4, 5, 6, 7, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
8, 9, 10, 11,
12, 13, 14, 15});
// Run // Run
net.RunOp(); net.RunOp();
// Check // Check
auto expected = CreateTensor<float>({1, 1, 2, 2}, auto expected = CreateTensor<float>({1, 1, 2, 2}, {10, 11, 14, 15});
{10, 11, 14, 15});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001); ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
} }
...@@ -153,9 +136,9 @@ TEST_F(PoolingOpTest, MAX_k2x2s2x2) { ...@@ -153,9 +136,9 @@ TEST_F(PoolingOpTest, MAX_k2x2s2x2) {
// Construct graph // Construct graph
auto& net = test_net(); auto& net = test_net();
OpDefBuilder("Pooling", "PoolingTest") OpDefBuilder("Pooling", "PoolingTest")
.Input("Input") .Input("Input")
.Output("Output") .Output("Output")
.Finalize(net.operator_def()); .Finalize(net.operator_def());
// Add args // Add args
net.AddIntArg("pooling_type", PoolingType::MAX); net.AddIntArg("pooling_type", PoolingType::MAX);
...@@ -165,18 +148,14 @@ TEST_F(PoolingOpTest, MAX_k2x2s2x2) { ...@@ -165,18 +148,14 @@ TEST_F(PoolingOpTest, MAX_k2x2s2x2) {
net.AddIntsArg("dilations", {1, 1}); net.AddIntsArg("dilations", {1, 1});
// Add input data // Add input data
net.AddInputFromArray<float>("Input", {1, 1, 4, 5}, net.AddInputFromArray<float>(
{0, 1, 2, 3, 4, "Input", {1, 1, 4, 5},
5, 6, 7, 8, 9, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19});
10, 11, 12, 13, 14,
15, 16, 17, 18, 19});
// Run // Run
net.RunOp(DeviceType::NEON); net.RunOp(DeviceType::NEON);
// Check // Check
Tensor expected = CreateTensor<float>({1, 1, 2, 3}, Tensor expected = CreateTensor<float>({1, 1, 2, 3}, {6, 8, 9, 16, 18, 19});
{6, 8, 9,
16, 18, 19});
ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 0.001); ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 0.001);
} }
...@@ -185,9 +164,9 @@ TEST_F(PoolingOpTest, MAX_k3x3s2x2) { ...@@ -185,9 +164,9 @@ TEST_F(PoolingOpTest, MAX_k3x3s2x2) {
// Construct graph // Construct graph
auto& net = test_net(); auto& net = test_net();
OpDefBuilder("Pooling", "PoolingTest") OpDefBuilder("Pooling", "PoolingTest")
.Input("Input") .Input("Input")
.Output("Output") .Output("Output")
.Finalize(net.operator_def()); .Finalize(net.operator_def());
// Add args // Add args
net.AddIntArg("pooling_type", PoolingType::MAX); net.AddIntArg("pooling_type", PoolingType::MAX);
...@@ -197,18 +176,14 @@ TEST_F(PoolingOpTest, MAX_k3x3s2x2) { ...@@ -197,18 +176,14 @@ TEST_F(PoolingOpTest, MAX_k3x3s2x2) {
net.AddIntsArg("dilations", {1, 1}); net.AddIntsArg("dilations", {1, 1});
// Add input data // Add input data
net.AddInputFromArray<float>("Input", {1, 1, 4, 5}, net.AddInputFromArray<float>(
{0, 1, 2, 3, 4, "Input", {1, 1, 4, 5},
5, 6, 7, 8, 9, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19});
10, 11, 12, 13, 14,
15, 16, 17, 18, 19});
// Run // Run
net.RunOp(DeviceType::NEON); net.RunOp(DeviceType::NEON);
// Check // Check
Tensor expected = CreateTensor<float>({1, 1, 2, 3}, Tensor expected = CreateTensor<float>({1, 1, 2, 3}, {11, 13, 14, 16, 18, 19});
{11, 13, 14,
16, 18, 19});
ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 0.001); ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 0.001);
} }
...@@ -10,6 +10,6 @@ REGISTER_CPU_OPERATOR(Relu, ReluOp<DeviceType::CPU, float>); ...@@ -10,6 +10,6 @@ REGISTER_CPU_OPERATOR(Relu, ReluOp<DeviceType::CPU, float>);
#if __ARM_NEON #if __ARM_NEON
REGISTER_NEON_OPERATOR(Relu, ReluOp<DeviceType::NEON, float>); REGISTER_NEON_OPERATOR(Relu, ReluOp<DeviceType::NEON, float>);
#endif // __ARM_NEON #endif // __ARM_NEON
} // namespace mace } // namespace mace
...@@ -10,10 +10,10 @@ ...@@ -10,10 +10,10 @@
namespace mace { namespace mace {
template<DeviceType D, class T> template <DeviceType D, class T>
class ReluOp : public Operator<D, T> { class ReluOp : public Operator<D, T> {
public: public:
ReluOp(const OperatorDef &operator_def, Workspace *ws) ReluOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<D, T>(operator_def, ws) {} : Operator<D, T>(operator_def, ws) {}
bool Run() override { bool Run() override {
const Tensor* input_tensor = this->inputs_[0]; const Tensor* input_tensor = this->inputs_[0];
...@@ -31,6 +31,6 @@ class ReluOp : public Operator<D, T> { ...@@ -31,6 +31,6 @@ class ReluOp : public Operator<D, T> {
kernels::ReluFunctor<D, T> functor_; kernels::ReluFunctor<D, T> functor_;
}; };
} // namespace mace } // namespace mace
#endif // MACE_OPS_RELU_H_ #endif // MACE_OPS_RELU_H_
...@@ -10,7 +10,6 @@ ...@@ -10,7 +10,6 @@
namespace mace { namespace mace {
template <DeviceType D, typename T> template <DeviceType D, typename T>
static void ReluBenchmark(int iters, int size) { static void ReluBenchmark(int iters, int size) {
mace::testing::StopTiming(); mace::testing::StopTiming();
OpsTestNet net; OpsTestNet net;
...@@ -28,26 +27,25 @@ static void ReluBenchmark(int iters, int size) { ...@@ -28,26 +27,25 @@ static void ReluBenchmark(int iters, int size) {
} }
mace::testing::StartTiming(); mace::testing::StartTiming();
while(iters--) { while (iters--) {
net.RunOp(D); net.RunOp(D);
} }
} }
// Defines and registers one Relu benchmark named BM_RELU_<SIZE>_<TYPE>_<DEVICE>.
// It reports iters * SIZE items processed and that count times sizeof(TYPE)
// bytes processed, then runs ReluBenchmark<DEVICE, TYPE>(iters, SIZE).
#define BM_RELU_MACRO(SIZE, TYPE, DEVICE) \
  static void BM_RELU_##SIZE##_##TYPE##_##DEVICE(int iters) { \
    const int64_t tot = static_cast<int64_t>(iters) * SIZE; \
    mace::testing::ItemsProcessed(tot); \
    mace::testing::BytesProcessed(tot * (sizeof(TYPE))); \
    ReluBenchmark<DEVICE, TYPE>(iters, SIZE); \
  } \
  BENCHMARK(BM_RELU_##SIZE##_##TYPE##_##DEVICE)
// Instantiates the Relu benchmark for the given size and element type on
// both the CPU and the NEON device.
#define BM_RELU(SIZE, TYPE) \
  BM_RELU_MACRO(SIZE, TYPE, CPU); \
  BM_RELU_MACRO(SIZE, TYPE, NEON);
BM_RELU(1000, float); BM_RELU(1000, float);
BM_RELU(100000, float); BM_RELU(100000, float);
BM_RELU(10000000, float); BM_RELU(10000000, float);
} // namespace mace } // namespace mace
\ No newline at end of file \ No newline at end of file
...@@ -32,4 +32,4 @@ TEST_F(ReluOpTest, ReluOp) { ...@@ -32,4 +32,4 @@ TEST_F(ReluOpTest, ReluOp) {
ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 0.01); ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 0.01);
} }
} // namespace mace } // namespace mace
...@@ -9,7 +9,8 @@ namespace mace { ...@@ -9,7 +9,8 @@ namespace mace {
REGISTER_CPU_OPERATOR(ResizeBilinear, ResizeBilinearOp<DeviceType::CPU, float>); REGISTER_CPU_OPERATOR(ResizeBilinear, ResizeBilinearOp<DeviceType::CPU, float>);
#if __ARM_NEON #if __ARM_NEON
REGISTER_NEON_OPERATOR(ResizeBilinear, ResizeBilinearOp<DeviceType::NEON, float>); REGISTER_NEON_OPERATOR(ResizeBilinear,
#endif // __ARM_NEON ResizeBilinearOp<DeviceType::NEON, float>);
#endif // __ARM_NEON
} // namespace mace } // namespace mace
...@@ -5,18 +5,18 @@ ...@@ -5,18 +5,18 @@
#ifndef MACE_RESIZE_BILINEAR_H #ifndef MACE_RESIZE_BILINEAR_H
#define MACE_RESIZE_BILINEAR_H #define MACE_RESIZE_BILINEAR_H
#include "mace/core/operator.h" #include "mace/core/operator.h"
#include "mace/kernels/resize_bilinear.h" #include "mace/kernels/resize_bilinear.h"
namespace mace { namespace mace {
template<DeviceType D, class T> template <DeviceType D, class T>
class ResizeBilinearOp : public Operator<D, T> { class ResizeBilinearOp : public Operator<D, T> {
public: public:
ResizeBilinearOp(const OperatorDef &operator_def, Workspace *ws) ResizeBilinearOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<D, T>(operator_def, ws), : Operator<D, T>(operator_def, ws),
functor_(OperatorBase::GetSingleArgument<bool>("align_corners", false)) {} functor_(
OperatorBase::GetSingleArgument<bool>("align_corners", false)) {}
bool Run() override { bool Run() override {
const Tensor* input = this->Input(0); const Tensor* input = this->Input(0);
...@@ -24,8 +24,8 @@ class ResizeBilinearOp : public Operator<D, T> { ...@@ -24,8 +24,8 @@ class ResizeBilinearOp : public Operator<D, T> {
MACE_CHECK(input->dim_size() == 4, "input must be 4-dimensional.", MACE_CHECK(input->dim_size() == 4, "input must be 4-dimensional.",
input->dim_size()); input->dim_size());
MACE_CHECK(resize_dims->dim_size() == 1, "resize dim must be 2-dimensional.", MACE_CHECK(resize_dims->dim_size() == 1,
resize_dims->dim_size()); "resize dim must be 2-dimensional.", resize_dims->dim_size());
Tensor* output = this->Output(0); Tensor* output = this->Output(0);
...@@ -35,7 +35,7 @@ class ResizeBilinearOp : public Operator<D, T> { ...@@ -35,7 +35,7 @@ class ResizeBilinearOp : public Operator<D, T> {
index_t in_width = input->dim(3); index_t in_width = input->dim(3);
index_t out_height = resize_dims->data<index_t>()[0]; index_t out_height = resize_dims->data<index_t>()[0];
index_t out_width = resize_dims->data<index_t>()[1]; index_t out_width = resize_dims->data<index_t>()[1];
vector<index_t> out_shape {n, channels, out_height, out_width}; vector<index_t> out_shape{n, channels, out_height, out_width};
output->Resize(out_shape); output->Resize(out_shape);
const T* input_ptr = input->data<T>(); const T* input_ptr = input->data<T>();
...@@ -45,10 +45,11 @@ class ResizeBilinearOp : public Operator<D, T> { ...@@ -45,10 +45,11 @@ class ResizeBilinearOp : public Operator<D, T> {
out_height, out_width); out_height, out_width);
return true; return true;
} }
private: private:
kernels::ResizeBilinearFunctor<D, T> functor_; kernels::ResizeBilinearFunctor<D, T> functor_;
}; };
} // namespace mace } // namespace mace
#endif // MACE_RESIZE_BILINEAR_H #endif // MACE_RESIZE_BILINEAR_H
...@@ -2,9 +2,9 @@ ...@@ -2,9 +2,9 @@
// Copyright (c) 2017 XiaoMi All rights reserved. // Copyright (c) 2017 XiaoMi All rights reserved.
// //
#include "mace/ops/resize_bilinear.h"
#include "mace/core/operator.h" #include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h" #include "mace/ops/ops_test_util.h"
#include "mace/ops/resize_bilinear.h"
using namespace mace; using namespace mace;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册