Commit ebff986d authored by Liangliang He

Reformat with google coding style

Parent a23d053d
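The diff below is purely mechanical: reference and pointer declarators move from the type to the variable name, over-long statements are re-wrapped, and includes are reordered. A minimal before/after sketch of the declarator rule (illustrative only, not taken from the commit):

    #include <string>

    // Before this commit: '&' and '*' attach to the type.
    void LogBefore(const std::string& msg, const char* tag);

    // After this commit (Google style): '&' and '*' attach to the name.
    void LogAfter(const std::string &msg, const char *tag);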
@@ -26,8 +26,8 @@ typedef int64_t index_t;
 #ifndef DISABLE_COPY_AND_ASSIGN
 #define DISABLE_COPY_AND_ASSIGN(classname) \
  private: \
-  classname(const classname&) = delete; \
-  classname& operator=(const classname&) = delete
+  classname(const classname &) = delete; \
+  classname &operator=(const classname &) = delete
 #endif
 #define MACE_NOT_IMPLEMENTED MACE_CHECK(false, "not implemented")
......
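For context, a minimal sketch of how the DISABLE_COPY_AND_ASSIGN macro above is typically used; the Buffer class is hypothetical:

    #define DISABLE_COPY_AND_ASSIGN(classname) \
     private: \
      classname(const classname &) = delete; \
      classname &operator=(const classname &) = delete

    // Hypothetical example: copy construction and copy assignment are deleted.
    class Buffer {
     public:
      Buffer() = default;

      DISABLE_COPY_AND_ASSIGN(Buffer);
    };

    int main() {
      Buffer a;
      // Buffer b = a;  // would not compile: the copy constructor is deleted
      return 0;
    }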
@@ -14,7 +14,7 @@
 namespace mace {
 namespace internal {
-LogMessage::LogMessage(const char* fname, int line, int severity)
+LogMessage::LogMessage(const char *fname, int line, int severity)
     : fname_(fname), line_(line), severity_(severity) {}
 #if defined(PLATFORM_POSIX_ANDROID)
@@ -43,7 +43,7 @@ void LogMessage::GenerateLogMessage() {
   }
   std::stringstream ss;
-  const char* const partial_name = strrchr(fname_, '/');
+  const char *const partial_name = strrchr(fname_, '/');
   ss << (partial_name != nullptr ? partial_name + 1 : fname_) << ":" << line_
      << " " << str();
   __android_log_write(android_log_level, "native", ss.str().c_str());
@@ -69,7 +69,7 @@ void LogMessage::GenerateLogMessage() {
 namespace {
 // Parse log level (int64_t) from environment variable (char*)
-int64_t LogLevelStrToInt(const char* mace_env_var_val) {
+int64_t LogLevelStrToInt(const char *mace_env_var_val) {
   if (mace_env_var_val == nullptr) {
     return 0;
   }
@@ -89,12 +89,12 @@ int64_t LogLevelStrToInt(const char* mace_env_var_val) {
 }
 int64_t MinLogLevelFromEnv() {
-  const char* mace_env_var_val = getenv("MACE_CPP_MIN_LOG_LEVEL");
+  const char *mace_env_var_val = getenv("MACE_CPP_MIN_LOG_LEVEL");
   return LogLevelStrToInt(mace_env_var_val);
 }
 int64_t MinVLogLevelFromEnv() {
-  const char* mace_env_var_val = getenv("MACE_CPP_MIN_VLOG_LEVEL");
+  const char *mace_env_var_val = getenv("MACE_CPP_MIN_VLOG_LEVEL");
   return LogLevelStrToInt(mace_env_var_val);
 }
@@ -111,7 +111,7 @@ int64_t LogMessage::MinVLogLevel() {
   return min_vlog_level;
 }
-LogMessageFatal::LogMessageFatal(const char* file, int line)
+LogMessageFatal::LogMessageFatal(const char *file, int line)
     : LogMessage(file, line, FATAL) {}
 LogMessageFatal::~LogMessageFatal() {
   // abort() ensures we don't return (we promised we would not via
......
@@ -23,23 +23,23 @@ namespace internal {
 using std::string;
-inline void MakeStringInternal(std::stringstream& /*ss*/) {}
+inline void MakeStringInternal(std::stringstream & /*ss*/) {}
 template <typename T>
-inline void MakeStringInternal(std::stringstream& ss, const T& t) {
+inline void MakeStringInternal(std::stringstream &ss, const T &t) {
   ss << t;
 }
 template <typename T, typename... Args>
-inline void MakeStringInternal(std::stringstream& ss,
-                               const T& t,
-                               const Args&... args) {
+inline void MakeStringInternal(std::stringstream &ss,
+                               const T &t,
+                               const Args &... args) {
   MakeStringInternal(ss, t);
   MakeStringInternal(ss, args...);
 }
 template <typename... Args>
-string MakeString(const Args&... args) {
+string MakeString(const Args &... args) {
   std::stringstream ss;
   MakeStringInternal(ss, args...);
   return ss.str();
@@ -48,7 +48,7 @@ string MakeString(const Args&... args) {
 template <typename T>
 string MakeString(const std::vector<T> &args) {
   std::stringstream ss;
-  for (const T& arg: args) {
+  for (const T &arg : args) {
     ss << arg << ", ";
   }
   return ss.str();
@@ -56,14 +56,14 @@ string MakeString(const std::vector<T> &args) {
 // Specializations for already-a-string types.
 template <>
-inline string MakeString(const string& str) {
+inline string MakeString(const string &str) {
   return str;
 }
-inline string MakeString(const char* c_str) { return string(c_str); }
+inline string MakeString(const char *c_str) { return string(c_str); }
 class LogMessage : public std::basic_ostringstream<char> {
  public:
-  LogMessage(const char* fname, int line, int severity);
+  LogMessage(const char *fname, int line, int severity);
   ~LogMessage();
   // Returns the minimum log level for VLOG statements.
@@ -75,7 +75,7 @@ class LogMessage : public std::basic_ostringstream<char> {
   void GenerateLogMessage();
  private:
-  const char* fname_;
+  const char *fname_;
   int line_;
   int severity_;
 };
@@ -84,7 +84,7 @@ class LogMessage : public std::basic_ostringstream<char> {
 // logging this message.
 class LogMessageFatal : public LogMessage {
  public:
-  LogMessageFatal(const char* file, int line);
+  LogMessageFatal(const char *file, int line);
   ~LogMessageFatal();
 };
@@ -136,7 +136,7 @@ class LogMessageFatal : public LogMessage {
 #endif
 template <typename T>
-T&& CheckNotNull(const char* file, int line, const char* exprtext, T&& t) {
+T &&CheckNotNull(const char *file, int line, const char *exprtext, T &&t) {
   if (t == nullptr) {
     LogMessageFatal(file, line) << string(exprtext);
   }
......
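A self-contained sketch of how the variadic MakeString helpers above compose; the definitions are simplified copies of the code in this diff:

    #include <sstream>
    #include <string>

    inline void MakeStringInternal(std::stringstream & /*ss*/) {}

    template <typename T>
    inline void MakeStringInternal(std::stringstream &ss, const T &t) {
      ss << t;
    }

    template <typename T, typename... Args>
    inline void MakeStringInternal(std::stringstream &ss,
                                   const T &t,
                                   const Args &... args) {
      MakeStringInternal(ss, t);
      MakeStringInternal(ss, args...);
    }

    template <typename... Args>
    std::string MakeString(const Args &... args) {
      std::stringstream ss;
      MakeStringInternal(ss, args...);
      return ss.str();
    }

    int main() {
      // Yields "shape: 1x224x224x3".
      std::string s = MakeString("shape: ", 1, "x", 224, "x", 224, "x", 3);
      return 0;
    }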
@@ -7,18 +7,18 @@
 namespace mace {
-NetBase::NetBase(const std::shared_ptr<const NetDef>& net_def,
-                 Workspace* ws,
+NetBase::NetBase(const std::shared_ptr<const NetDef> &net_def,
+                 Workspace *ws,
                  DeviceType type)
     : name_(net_def->name()) {}
-SimpleNet::SimpleNet(const std::shared_ptr<const NetDef>& net_def,
-                     Workspace* ws,
+SimpleNet::SimpleNet(const std::shared_ptr<const NetDef> &net_def,
+                     Workspace *ws,
                      DeviceType type)
     : NetBase(net_def, ws, type) {
   VLOG(1) << "Constructing SimpleNet " << net_def->name();
   for (int idx = 0; idx < net_def->op_size(); ++idx) {
-    const auto& operator_def = net_def->op(idx);
+    const auto &operator_def = net_def->op(idx);
     VLOG(1) << "Creating operator " << operator_def.name() << ":"
             << operator_def.type();
     std::unique_ptr<OperatorBase> op{nullptr};
@@ -29,26 +29,29 @@ SimpleNet::SimpleNet(const std::shared_ptr<const NetDef>& net_def,
     }
   }
 }
-bool SimpleNet::Run(RunMetadata* run_metadata) {
+bool SimpleNet::Run(RunMetadata *run_metadata) {
   VLOG(1) << "Running net " << name_;
-  for (auto& op : operators_) {
+  for (auto &op : operators_) {
     VLOG(1) << "Running operator " << op->debug_def().name() << "("
             << op->debug_def().type() << ").";
-    OperatorStats* op_stats = nullptr;
+    OperatorStats *op_stats = nullptr;
     if (run_metadata) {
       op_stats = run_metadata->add_op_stats();
       op_stats->set_operator_name(op->debug_def().name());
       op_stats->set_type(op->debug_def().type());
       op_stats->set_all_start_micros(NowInMicroSec());
-      op_stats->set_op_start_rel_micros(NowInMicroSec() - op_stats->all_start_micros());
+      op_stats->set_op_start_rel_micros(NowInMicroSec() -
+                                        op_stats->all_start_micros());
     }
     if (!op->Run()) {
      LOG(ERROR) << "Operator failed: " << ProtoDebugString(op->debug_def());
      return false;
     }
     if (op_stats) {
-      op_stats->set_op_end_rel_micros(NowInMicroSec() - op_stats->all_start_micros());
-      op_stats->set_all_end_rel_micros(NowInMicroSec() - op_stats->all_start_micros());
+      op_stats->set_op_end_rel_micros(NowInMicroSec() -
+                                      op_stats->all_start_micros());
+      op_stats->set_all_end_rel_micros(NowInMicroSec() -
+                                       op_stats->all_start_micros());
     }
     VLOG(1) << "Op " << op->debug_def().name()
             << " has shape: " << internal::MakeString(op->Output(0)->shape());
@@ -56,15 +59,15 @@ bool SimpleNet::Run(RunMetadata* run_metadata) {
   return true;
 }
-unique_ptr<NetBase> CreateNet(const NetDef& net_def,
-                              Workspace* ws,
+unique_ptr<NetBase> CreateNet(const NetDef &net_def,
+                              Workspace *ws,
                               DeviceType type) {
   std::shared_ptr<NetDef> tmp_net_def(new NetDef(net_def));
   return CreateNet(tmp_net_def, ws, type);
 }
-unique_ptr<NetBase> CreateNet(const std::shared_ptr<const NetDef>& net_def,
-                              Workspace* ws,
+unique_ptr<NetBase> CreateNet(const std::shared_ptr<const NetDef> &net_def,
+                              Workspace *ws,
                               DeviceType type) {
   unique_ptr<NetBase> net(new SimpleNet(net_def, ws, type));
   return net;
......
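A minimal sketch of driving the functions above, assuming a populated net_def and ws as in the surrounding code; error handling is elided:

    // Build a SimpleNet for the CPU and run it, collecting per-op timing.
    RunMetadata metadata;
    unique_ptr<NetBase> net = CreateNet(net_def, &ws, DeviceType::CPU);
    if (!net->Run(&metadata)) {
      LOG(ERROR) << "Net execution failed";
    }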
@@ -15,14 +15,14 @@ namespace mace {
 class NetBase {
  public:
-  NetBase(const std::shared_ptr<const NetDef>& net_def,
-          Workspace* ws,
+  NetBase(const std::shared_ptr<const NetDef> &net_def,
+          Workspace *ws,
           DeviceType type);
   virtual ~NetBase() noexcept {}
-  virtual bool Run(RunMetadata* run_metadata = nullptr) = 0;
-  const string& Name() const { return name_; }
+  virtual bool Run(RunMetadata *run_metadata = nullptr) = 0;
+  const string &Name() const { return name_; }
  protected:
   string name_;
@@ -32,11 +32,11 @@ class NetBase {
 class SimpleNet : public NetBase {
  public:
-  SimpleNet(const std::shared_ptr<const NetDef>& net_def,
-            Workspace* ws,
+  SimpleNet(const std::shared_ptr<const NetDef> &net_def,
+            Workspace *ws,
             DeviceType type);
-  bool Run(RunMetadata* run_metadata = nullptr) override;
+  bool Run(RunMetadata *run_metadata = nullptr) override;
  protected:
   vector<unique_ptr<OperatorBase> > operators_;
@@ -44,11 +44,11 @@ class SimpleNet : public NetBase {
   DISABLE_COPY_AND_ASSIGN(SimpleNet);
 };
-unique_ptr<NetBase> CreateNet(const NetDef& net_def,
-                              Workspace* ws,
+unique_ptr<NetBase> CreateNet(const NetDef &net_def,
+                              Workspace *ws,
                               DeviceType type);
-unique_ptr<NetBase> CreateNet(const std::shared_ptr<const NetDef>& net_def,
-                              Workspace* ws,
+unique_ptr<NetBase> CreateNet(const std::shared_ptr<const NetDef> &net_def,
+                              Workspace *ws,
                               DeviceType type);
 }  // namespace mace
......
@@ -18,7 +18,7 @@
 namespace mace {
-bool ReadStringFromFile(const char* filename, string* str) {
+bool ReadStringFromFile(const char *filename, string *str) {
   std::ifstream ifs(filename, std::ios::in);
   if (!ifs) {
     VLOG(1) << "File cannot be opened: " << filename
@@ -33,7 +33,7 @@ bool ReadStringFromFile(const char* filename, string* str) {
   return true;
 }
-bool WriteStringToFile(const string& str, const char* filename) {
+bool WriteStringToFile(const string &str, const char *filename) {
   std::ofstream ofs(filename, std::ios::out | std::ios::trunc);
   if (!ofs.is_open()) {
     VLOG(1) << "File cannot be created: " << filename
@@ -54,15 +54,15 @@ bool WriteStringToFile(const string& str, const char* filename) {
 namespace {
 class IfstreamInputStream : public ::google::protobuf::io::CopyingInputStream {
  public:
-  explicit IfstreamInputStream(const string& filename)
+  explicit IfstreamInputStream(const string &filename)
       : ifs_(filename.c_str(), std::ios::in | std::ios::binary) {}
   ~IfstreamInputStream() { ifs_.close(); }
-  int Read(void* buffer, int size) {
+  int Read(void *buffer, int size) {
     if (!ifs_) {
       return -1;
     }
-    ifs_.read(static_cast<char*>(buffer), size);
+    ifs_.read(static_cast<char *>(buffer), size);
     return ifs_.gcount();
   }
@@ -71,7 +71,7 @@ class IfstreamInputStream : public ::google::protobuf::io::CopyingInputStream {
 };
 }  // namespace
-bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto) {
+bool ReadProtoFromBinaryFile(const char *filename, MessageLite *proto) {
   ::google::protobuf::io::CopyingInputStreamAdaptor stream(
       new IfstreamInputStream(filename));
   stream.SetOwnsCopyingStream(true);
@@ -82,8 +82,8 @@ bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto) {
   return proto->ParseFromCodedStream(&coded_stream);
 }
-void WriteProtoToBinaryFile(const MessageLite& /*proto*/,
-                            const char* /*filename*/) {
+void WriteProtoToBinaryFile(const MessageLite & /*proto*/,
+                            const char * /*filename*/) {
   LOG(FATAL) << "Not implemented yet.";
 }
@@ -98,25 +98,25 @@ using ::google::protobuf::io::CodedInputStream;
 using ::google::protobuf::io::ZeroCopyOutputStream;
 using ::google::protobuf::io::CodedOutputStream;
-bool ReadProtoFromTextFile(const char* filename, Message* proto) {
+bool ReadProtoFromTextFile(const char *filename, Message *proto) {
   int fd = open(filename, O_RDONLY);
   MACE_CHECK(fd != -1, "File not found: ", filename);
-  FileInputStream* input = new FileInputStream(fd);
+  FileInputStream *input = new FileInputStream(fd);
   bool success = google::protobuf::TextFormat::Parse(input, proto);
   delete input;
   close(fd);
   return success;
 }
-void WriteProtoToTextFile(const Message& proto, const char* filename) {
+void WriteProtoToTextFile(const Message &proto, const char *filename) {
   int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644);
-  FileOutputStream* output = new FileOutputStream(fd);
+  FileOutputStream *output = new FileOutputStream(fd);
   MACE_CHECK(google::protobuf::TextFormat::Print(proto, output));
   delete output;
   close(fd);
 }
-bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto) {
+bool ReadProtoFromBinaryFile(const char *filename, MessageLite *proto) {
 #if defined(_MSC_VER)  // for MSC compiler binary flag needs to be specified
   int fd = open(filename, O_RDONLY | O_BINARY);
 #else
@@ -135,7 +135,7 @@ bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto) {
   return success;
 }
-void WriteProtoToBinaryFile(const MessageLite& proto, const char* filename) {
+void WriteProtoToBinaryFile(const MessageLite &proto, const char *filename) {
   int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644);
   MACE_CHECK(fd != -1, "File cannot be created: ", filename, " error number: ",
              errno);
@@ -150,8 +150,8 @@ void WriteProtoToBinaryFile(const MessageLite& proto, const char* filename) {
 #endif  // MACE_USE_LITE_PROTO
-ArgumentHelper::ArgumentHelper(const OperatorDef& def) {
-  for (auto& arg : def.arg()) {
+ArgumentHelper::ArgumentHelper(const OperatorDef &def) {
+  for (auto &arg : def.arg()) {
     if (arg_map_.find(arg.name()) != arg_map_.end()) {
       MACE_CHECK(
           arg.SerializeAsString() == arg_map_[arg.name()].SerializeAsString(),
@@ -167,8 +167,8 @@ ArgumentHelper::ArgumentHelper(const OperatorDef& def) {
   }
 }
-ArgumentHelper::ArgumentHelper(const NetDef& netdef) {
-  for (auto& arg : netdef.arg()) {
+ArgumentHelper::ArgumentHelper(const NetDef &netdef) {
+  for (auto &arg : netdef.arg()) {
     MACE_CHECK(arg_map_.count(arg.name()) == 0,
                "Duplicated argument name found in net def: ",
                ProtoDebugString(netdef));
@@ -176,7 +176,7 @@ ArgumentHelper::ArgumentHelper(const NetDef& netdef) {
   }
 }
-bool ArgumentHelper::HasArgument(const string& name) const {
+bool ArgumentHelper::HasArgument(const string &name) const {
   return arg_map_.count(name);
 }
@@ -184,7 +184,7 @@ namespace {
 // Helper function to verify that conversion between types won't loose any
 // significant bit.
 template <typename InputType, typename TargetType>
-bool SupportsLosslessConversion(const InputType& value) {
+bool SupportsLosslessConversion(const InputType &value) {
   return static_cast<InputType>(static_cast<TargetType>(value)) == value;
 }
 }
@@ -192,8 +192,8 @@ bool SupportsLosslessConversion(const InputType& value) {
 #define INSTANTIATE_GET_SINGLE_ARGUMENT(T, fieldname, \
                                         enforce_lossless_conversion) \
   template <> \
-  T ArgumentHelper::GetSingleArgument<T>(const string& name, \
-                                         const T& default_value) const { \
+  T ArgumentHelper::GetSingleArgument<T>(const string &name, \
+                                         const T &default_value) const { \
     if (arg_map_.count(name) == 0) { \
       VLOG(1) << "Using default parameter value " << default_value \
               << " for parameter " << name; \
@@ -211,7 +211,7 @@ bool SupportsLosslessConversion(const InputType& value) {
     return value; \
   } \
   template <> \
-  bool ArgumentHelper::HasSingleArgumentOfType<T>(const string& name) const { \
+  bool ArgumentHelper::HasSingleArgumentOfType<T>(const string &name) const { \
     if (arg_map_.count(name) == 0) { \
       return false; \
     } \
@@ -235,12 +235,12 @@ INSTANTIATE_GET_SINGLE_ARGUMENT(string, s, false)
                                           enforce_lossless_conversion) \
   template <> \
   vector<T> ArgumentHelper::GetRepeatedArgument<T>( \
-      const string& name, const std::vector<T>& default_value) const { \
+      const string &name, const std::vector<T> &default_value) const { \
     if (arg_map_.count(name) == 0) { \
       return default_value; \
     } \
     vector<T> values; \
-    for (const auto& v : arg_map_.at(name).fieldname()) { \
+    for (const auto &v : arg_map_.at(name).fieldname()) { \
       if (enforce_lossless_conversion) { \
         auto supportsConversion = \
             SupportsLosslessConversion<decltype(v), T>(v); \
@@ -267,7 +267,7 @@ INSTANTIATE_GET_REPEATED_ARGUMENT(string, strings, false)
 #define MACE_MAKE_SINGULAR_ARGUMENT(T, fieldname) \
   template <> \
-  Argument MakeArgument(const string& name, const T& value) { \
+  Argument MakeArgument(const string &name, const T &value) { \
     Argument arg; \
     arg.set_name(name); \
     arg.set_##fieldname(value); \
@@ -282,7 +282,7 @@ MACE_MAKE_SINGULAR_ARGUMENT(string, s)
 #undef MACE_MAKE_SINGULAR_ARGUMENT
 template <>
-Argument MakeArgument(const string& name, const MessageLite& value) {
+Argument MakeArgument(const string &name, const MessageLite &value) {
   Argument arg;
   arg.set_name(name);
   arg.set_s(value.SerializeAsString());
@@ -291,10 +291,10 @@ Argument MakeArgument(const string& name, const MessageLite& value) {
 #define MACE_MAKE_REPEATED_ARGUMENT(T, fieldname) \
   template <> \
-  Argument MakeArgument(const string& name, const vector<T>& value) { \
+  Argument MakeArgument(const string &name, const vector<T> &value) { \
     Argument arg; \
     arg.set_name(name); \
-    for (const auto& v : value) { \
+    for (const auto &v : value) { \
       arg.add_##fieldname(v); \
     } \
     return arg; \
@@ -306,8 +306,8 @@ MACE_MAKE_REPEATED_ARGUMENT(int64_t, ints)
 MACE_MAKE_REPEATED_ARGUMENT(string, strings)
 #undef MACE_MAKE_REPEATED_ARGUMENT
-const Argument& GetArgument(const OperatorDef& def, const string& name) {
-  for (const Argument& arg : def.arg()) {
+const Argument &GetArgument(const OperatorDef &def, const string &name) {
+  for (const Argument &arg : def.arg()) {
     if (arg.name() == name) {
       return arg;
     }
@@ -318,10 +318,10 @@ const Argument& GetArgument(const OperatorDef& def, const string& name) {
   return std::move(Argument());
 }
-bool GetFlagArgument(const OperatorDef& def,
-                     const string& name,
+bool GetFlagArgument(const OperatorDef &def,
+                     const string &name,
                      bool def_value) {
-  for (const Argument& arg : def.arg()) {
+  for (const Argument &arg : def.arg()) {
     if (arg.name() == name) {
       MACE_CHECK(arg.has_i(), "Can't parse argument as bool: ",
                  ProtoDebugString(arg));
@@ -331,9 +331,9 @@ bool GetFlagArgument(const OperatorDef& def,
   return def_value;
 }
-Argument* GetMutableArgument(const string& name,
+Argument *GetMutableArgument(const string &name,
                              const bool create_if_missing,
-                             OperatorDef* def) {
+                             OperatorDef *def) {
   for (int i = 0; i < def->arg_size(); ++i) {
     if (def->arg(i).name() == name) {
       return def->mutable_arg(i);
@@ -341,7 +341,7 @@ Argument* GetMutableArgument(const string& name,
   }
   // If no argument of the right name is found...
   if (create_if_missing) {
-    Argument* arg = def->add_arg();
+    Argument *arg = def->add_arg();
     arg->set_name(name);
     return arg;
   } else {
......
@@ -21,56 +21,56 @@ using std::string;
 using ::google::protobuf::MessageLite;
 // Common interfaces that reads file contents into a string.
-bool ReadStringFromFile(const char* filename, string* str);
-bool WriteStringToFile(const string& str, const char* filename);
+bool ReadStringFromFile(const char *filename, string *str);
+bool WriteStringToFile(const string &str, const char *filename);
 // Common interfaces that are supported by both lite and full protobuf.
-bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto);
-inline bool ReadProtoFromBinaryFile(const string filename, MessageLite* proto) {
+bool ReadProtoFromBinaryFile(const char *filename, MessageLite *proto);
+inline bool ReadProtoFromBinaryFile(const string filename, MessageLite *proto) {
   return ReadProtoFromBinaryFile(filename.c_str(), proto);
 }
-void WriteProtoToBinaryFile(const MessageLite& proto, const char* filename);
-inline void WriteProtoToBinaryFile(const MessageLite& proto,
-                                   const string& filename) {
+void WriteProtoToBinaryFile(const MessageLite &proto, const char *filename);
+inline void WriteProtoToBinaryFile(const MessageLite &proto,
+                                   const string &filename) {
   return WriteProtoToBinaryFile(proto, filename.c_str());
 }
 #ifdef MACE_USE_LITE_PROTO
-inline string ProtoDebugString(const MessageLite& proto) {
+inline string ProtoDebugString(const MessageLite &proto) {
   return proto.SerializeAsString();
 }
 // Text format MessageLite wrappers: these functions do nothing but just
 // allowing things to compile. It will produce a runtime error if you are using
 // MessageLite but still want text support.
-inline bool ReadProtoFromTextFile(const char* /*filename*/,
-                                  MessageLite* /*proto*/) {
+inline bool ReadProtoFromTextFile(const char * /*filename*/,
+                                  MessageLite * /*proto*/) {
   LOG(FATAL) << "If you are running lite version, you should not be "
              << "calling any text-format protobuffers.";
   return false;  // Just to suppress compiler warning.
 }
-inline bool ReadProtoFromTextFile(const string filename, MessageLite* proto) {
+inline bool ReadProtoFromTextFile(const string filename, MessageLite *proto) {
   return ReadProtoFromTextFile(filename.c_str(), proto);
 }
-inline void WriteProtoToTextFile(const MessageLite& /*proto*/,
-                                 const char* /*filename*/) {
+inline void WriteProtoToTextFile(const MessageLite & /*proto*/,
+                                 const char * /*filename*/) {
   LOG(FATAL) << "If you are running lite version, you should not be "
              << "calling any text-format protobuffers.";
 }
-inline void WriteProtoToTextFile(const MessageLite& proto,
-                                 const string& filename) {
+inline void WriteProtoToTextFile(const MessageLite &proto,
+                                 const string &filename) {
   return WriteProtoToTextFile(proto, filename.c_str());
 }
-inline bool ReadProtoFromFile(const char* filename, MessageLite* proto) {
+inline bool ReadProtoFromFile(const char *filename, MessageLite *proto) {
   return (ReadProtoFromBinaryFile(filename, proto) ||
           ReadProtoFromTextFile(filename, proto));
 }
-inline bool ReadProtoFromFile(const string& filename, MessageLite* proto) {
+inline bool ReadProtoFromFile(const string &filename, MessageLite *proto) {
   return ReadProtoFromFile(filename.c_str(), proto);
 }
@@ -78,27 +78,27 @@ inline bool ReadProtoFromFile(const string& filename, MessageLite* proto) {
 using ::google::protobuf::Message;
-inline string ProtoDebugString(const Message& proto) {
+inline string ProtoDebugString(const Message &proto) {
   return proto.ShortDebugString();
 }
-bool ReadProtoFromTextFile(const char* filename, Message* proto);
-inline bool ReadProtoFromTextFile(const string filename, Message* proto) {
+bool ReadProtoFromTextFile(const char *filename, Message *proto);
+inline bool ReadProtoFromTextFile(const string filename, Message *proto) {
   return ReadProtoFromTextFile(filename.c_str(), proto);
 }
-void WriteProtoToTextFile(const Message& proto, const char* filename);
-inline void WriteProtoToTextFile(const Message& proto, const string& filename) {
+void WriteProtoToTextFile(const Message &proto, const char *filename);
+inline void WriteProtoToTextFile(const Message &proto, const string &filename) {
   return WriteProtoToTextFile(proto, filename.c_str());
 }
 // Read Proto from a file, letting the code figure out if it is text or binary.
-inline bool ReadProtoFromFile(const char* filename, Message* proto) {
+inline bool ReadProtoFromFile(const char *filename, Message *proto) {
   return (ReadProtoFromBinaryFile(filename, proto) ||
           ReadProtoFromTextFile(filename, proto));
 }
-inline bool ReadProtoFromFile(const string& filename, Message* proto) {
+inline bool ReadProtoFromFile(const string &filename, Message *proto) {
   return ReadProtoFromFile(filename.c_str(), proto);
 }
@@ -107,21 +107,21 @@ inline bool ReadProtoFromFile(const string& filename, Message* proto) {
 template <class IterableInputs = std::initializer_list<string>,
           class IterableOutputs = std::initializer_list<string>,
           class IterableArgs = std::initializer_list<Argument>>
-OperatorDef CreateOperatorDef(const string& type,
-                              const string& name,
-                              const IterableInputs& inputs,
-                              const IterableOutputs& outputs,
-                              const IterableArgs& args) {
+OperatorDef CreateOperatorDef(const string &type,
+                              const string &name,
+                              const IterableInputs &inputs,
+                              const IterableOutputs &outputs,
+                              const IterableArgs &args) {
   OperatorDef def;
   def.set_type(type);
   def.set_name(name);
-  for (const string& in : inputs) {
+  for (const string &in : inputs) {
     def.add_input(in);
   }
-  for (const string& out : outputs) {
+  for (const string &out : outputs) {
     def.add_output(out);
   }
-  for (const Argument& arg : args) {
+  for (const Argument &arg : args) {
     def.add_arg()->CopyFrom(arg);
   }
   return def;
@@ -131,10 +131,10 @@ OperatorDef CreateOperatorDef(const string& type,
 // to specify args.
 template <class IterableInputs = std::initializer_list<string>,
           class IterableOutputs = std::initializer_list<string>>
-inline OperatorDef CreateOperatorDef(const string& type,
-                                     const string& name,
-                                     const IterableInputs& inputs,
-                                     const IterableOutputs& outputs) {
+inline OperatorDef CreateOperatorDef(const string &type,
+                                     const string &name,
+                                     const IterableInputs &inputs,
+                                     const IterableOutputs &outputs) {
   return CreateOperatorDef(type, name, inputs, outputs,
                            std::vector<Argument>());
 }
@@ -150,56 +150,56 @@ inline OperatorDef CreateOperatorDef(const string& type,
 class ArgumentHelper {
  public:
   template <typename Def>
-  static bool HasArgument(const Def& def, const string& name) {
+  static bool HasArgument(const Def &def, const string &name) {
     return ArgumentHelper(def).HasArgument(name);
   }
   template <typename Def, typename T>
-  static T GetSingleArgument(const Def& def,
-                             const string& name,
-                             const T& default_value) {
+  static T GetSingleArgument(const Def &def,
+                             const string &name,
+                             const T &default_value) {
     return ArgumentHelper(def).GetSingleArgument<T>(name, default_value);
   }
   template <typename Def, typename T>
-  static bool HasSingleArgumentOfType(const Def& def, const string& name) {
+  static bool HasSingleArgumentOfType(const Def &def, const string &name) {
     return ArgumentHelper(def).HasSingleArgumentOfType<T>(name);
   }
   template <typename Def, typename T>
   static vector<T> GetRepeatedArgument(
-      const Def& def,
-      const string& name,
-      const std::vector<T>& default_value = std::vector<T>()) {
+      const Def &def,
+      const string &name,
+      const std::vector<T> &default_value = std::vector<T>()) {
     return ArgumentHelper(def).GetRepeatedArgument<T>(name, default_value);
   }
   template <typename Def, typename MessageType>
-  static MessageType GetMessageArgument(const Def& def, const string& name) {
+  static MessageType GetMessageArgument(const Def &def, const string &name) {
     return ArgumentHelper(def).GetMessageArgument<MessageType>(name);
   }
   template <typename Def, typename MessageType>
-  static vector<MessageType> GetRepeatedMessageArgument(const Def& def,
-                                                        const string& name) {
+  static vector<MessageType> GetRepeatedMessageArgument(const Def &def,
+                                                        const string &name) {
     return ArgumentHelper(def).GetRepeatedMessageArgument<MessageType>(name);
   }
-  explicit ArgumentHelper(const OperatorDef& def);
-  explicit ArgumentHelper(const NetDef& netdef);
-  bool HasArgument(const string& name) const;
+  explicit ArgumentHelper(const OperatorDef &def);
+  explicit ArgumentHelper(const NetDef &netdef);
+  bool HasArgument(const string &name) const;
   template <typename T>
-  T GetSingleArgument(const string& name, const T& default_value) const;
+  T GetSingleArgument(const string &name, const T &default_value) const;
   template <typename T>
-  bool HasSingleArgumentOfType(const string& name) const;
+  bool HasSingleArgumentOfType(const string &name) const;
   template <typename T>
   vector<T> GetRepeatedArgument(
-      const string& name,
-      const std::vector<T>& default_value = std::vector<T>()) const;
+      const string &name,
+      const std::vector<T> &default_value = std::vector<T>()) const;
   template <typename MessageType>
-  MessageType GetMessageArgument(const string& name) const {
+  MessageType GetMessageArgument(const string &name) const {
     MACE_CHECK(arg_map_.count(name), "Cannot find parameter named " + name);
     MessageType message;
     if (arg_map_.at(name).has_s()) {
@@ -212,7 +212,7 @@ class ArgumentHelper {
   }
   template <typename MessageType>
-  vector<MessageType> GetRepeatedMessageArgument(const string& name) const {
+  vector<MessageType> GetRepeatedMessageArgument(const string &name) const {
     MACE_CHECK(arg_map_.count(name), "Cannot find parameter named " + name);
     vector<MessageType> messages(arg_map_.at(name).strings_size());
     for (int i = 0; i < messages.size(); ++i) {
@@ -226,20 +226,20 @@ class ArgumentHelper {
   std::map<string, Argument> arg_map_;
 };
-const Argument& GetArgument(const OperatorDef& def, const string& name);
-bool GetFlagArgument(const OperatorDef& def,
-                     const string& name,
+const Argument &GetArgument(const OperatorDef &def, const string &name);
+bool GetFlagArgument(const OperatorDef &def,
+                     const string &name,
                      bool def_value = false);
-Argument* GetMutableArgument(const string& name,
+Argument *GetMutableArgument(const string &name,
                              const bool create_if_missing,
-                             OperatorDef* def);
+                             OperatorDef *def);
 template <typename T>
-Argument MakeArgument(const string& name, const T& value);
+Argument MakeArgument(const string &name, const T &value);
 template <typename T>
-inline void AddArgument(const string& name, const T& value, OperatorDef* def) {
+inline void AddArgument(const string &name, const T &value, OperatorDef *def) {
   GetMutableArgument(name, true, def)->CopyFrom(MakeArgument(name, value));
 }
......
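A minimal sketch of round-tripping an operator argument through the helpers above; the "axis" argument and operator names are hypothetical:

    OperatorDef def = CreateOperatorDef("Conv2D", "conv1", {"input"}, {"output"});
    AddArgument<int64_t>("axis", 1, &def);

    ArgumentHelper helper(def);
    int64_t axis = helper.GetSingleArgument<int64_t>("axis", /*default_value=*/0);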
@@ -16,15 +16,15 @@ class Registry {
   Registry() : registry_() {}
-  void Register(const SrcType& key, Creator creator) {
+  void Register(const SrcType &key, Creator creator) {
     std::lock_guard<std::mutex> lock(register_mutex_);
     MACE_CHECK(registry_.count(key) == 0, "Key already registered.");
     registry_[key] = creator;
   }
-  inline bool Has(const SrcType& key) { return registry_.count(key) != 0; }
-  unique_ptr<ObjectType> Create(const SrcType& key, Args... args) {
+  inline bool Has(const SrcType &key) { return registry_.count(key) != 0; }
+  unique_ptr<ObjectType> Create(const SrcType &key, Args... args) {
     if (registry_.count(key) == 0) {
       LOG(FATAL) << "Key not registered: " << key;
     }
@@ -36,7 +36,7 @@ class Registry {
   */
   vector<SrcType> Keys() {
     vector<SrcType> keys;
-    for (const auto& it : registry_) {
+    for (const auto &it : registry_) {
       keys.push_back(it.first);
     }
     return keys;
@@ -52,8 +52,8 @@ class Registry {
 template <class SrcType, class ObjectType, class... Args>
 class Registerer {
  public:
-  Registerer(const SrcType& key,
-             Registry<SrcType, ObjectType, Args...>* registry,
+  Registerer(const SrcType &key,
+             Registry<SrcType, ObjectType, Args...> *registry,
              typename Registry<SrcType, ObjectType, Args...>::Creator creator) {
     registry->Register(key, creator);
   }
@@ -73,13 +73,13 @@ class Registerer {
 #endif
 #define MACE_DECLARE_TYPED_REGISTRY(RegistryName, SrcType, ObjectType, ...) \
-  Registry<SrcType, ObjectType, ##__VA_ARGS__>* RegistryName(); \
+  Registry<SrcType, ObjectType, ##__VA_ARGS__> *RegistryName(); \
   typedef Registerer<SrcType, ObjectType, ##__VA_ARGS__> \
       Registerer##RegistryName;
 #define MACE_DEFINE_TYPED_REGISTRY(RegistryName, SrcType, ObjectType, ...) \
-  Registry<SrcType, ObjectType, ##__VA_ARGS__>* RegistryName() { \
-    static Registry<SrcType, ObjectType, ##__VA_ARGS__>* registry = \
+  Registry<SrcType, ObjectType, ##__VA_ARGS__> *RegistryName() { \
+    static Registry<SrcType, ObjectType, ##__VA_ARGS__> *registry = \
         new Registry<SrcType, ObjectType, ##__VA_ARGS__>(); \
     return registry; \
   }
......
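For context, a simplified standalone sketch of the registry pattern these macros wrap; the real class above also rejects duplicate keys and guards Register() with a mutex, and the Op types here are hypothetical:

    #include <functional>
    #include <map>
    #include <memory>
    #include <string>

    template <class SrcType, class ObjectType>
    class Registry {
     public:
      // Assumed creator signature: a callable producing a new object.
      using Creator = std::function<std::unique_ptr<ObjectType>()>;

      void Register(const SrcType &key, Creator creator) {
        registry_[key] = creator;
      }

      std::unique_ptr<ObjectType> Create(const SrcType &key) {
        return registry_.at(key)();
      }

     private:
      std::map<SrcType, Creator> registry_;
    };

    struct Op { virtual ~Op() = default; };
    struct AddOp : Op {};

    int main() {
      Registry<std::string, Op> ops;
      ops.Register("AddN", [] { return std::unique_ptr<Op>(new AddOp()); });
      std::unique_ptr<Op> op = ops.Create("AddN");
      return 0;
    }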
@@ -3,8 +3,8 @@
 //
 #include "mace/core/runtime/opencl/opencl_allocator.h"
-#include "mace/core/runtime/opencl/opencl_runtime.h"
 #include "mace/core/runtime/opencl/cl2.hpp"
+#include "mace/core/runtime/opencl/opencl_runtime.h"
 namespace mace {
......
@@ -30,7 +30,9 @@ bool ReadSourceFile(const char *filename, std::string *content) {
   return true;
 }
-bool BuildProgram(OpenCLRuntime *runtime, const char *filename, cl::Program *program) {
+bool BuildProgram(OpenCLRuntime *runtime,
+                  const char *filename,
+                  cl::Program *program) {
   MACE_CHECK_NOTNULL(filename);
   MACE_CHECK_NOTNULL(program);
......
@@ -16,9 +16,9 @@ class Serializer {
   Serializer() {}
   ~Serializer() {}
-  unique_ptr<TensorProto> Serialize(const Tensor& tensor, const string& name);
-  unique_ptr<Tensor> Deserialize(const TensorProto& proto, DeviceType type);
+  unique_ptr<TensorProto> Serialize(const Tensor &tensor, const string &name);
+  unique_ptr<Tensor> Deserialize(const TensorProto &proto, DeviceType type);
   DISABLE_COPY_AND_ASSIGN(Serializer);
 };
......
@@ -16,36 +16,36 @@
 namespace mace {
 namespace testing {
-static std::vector<Benchmark*>* all_benchmarks = nullptr;
+static std::vector<Benchmark *> *all_benchmarks = nullptr;
 static std::string label;
 static int64_t bytes_processed;
 static int64_t items_processed;
 static int64_t accum_time = 0;
 static int64_t start_time = 0;
-Benchmark::Benchmark(const char* name, void (*fn)(int))
+Benchmark::Benchmark(const char *name, void (*fn)(int))
     : name_(name), num_args_(0), fn0_(fn) {
   args_.push_back(std::make_pair(-1, -1));
   Register();
 }
-Benchmark::Benchmark(const char* name, void (*fn)(int, int))
+Benchmark::Benchmark(const char *name, void (*fn)(int, int))
     : name_(name), num_args_(1), fn1_(fn) {
   Register();
 }
-Benchmark::Benchmark(const char* name, void (*fn)(int, int, int))
+Benchmark::Benchmark(const char *name, void (*fn)(int, int, int))
     : name_(name), num_args_(2), fn2_(fn) {
   Register();
 }
-Benchmark* Benchmark::Arg(int x) {
+Benchmark *Benchmark::Arg(int x) {
   MACE_CHECK(num_args_ == 1);
   args_.push_back(std::make_pair(x, -1));
   return this;
 }
-Benchmark* Benchmark::ArgPair(int x, int y) {
+Benchmark *Benchmark::ArgPair(int x, int y) {
   MACE_CHECK(num_args_ == 2);
   args_.push_back(std::make_pair(x, y));
   return this;
@@ -54,7 +54,7 @@ Benchmark* Benchmark::ArgPair(int x, int y) {
 // Run all benchmarks
 void Benchmark::Run() { Run("all"); }
-void Benchmark::Run(const char* pattern) {
+void Benchmark::Run(const char *pattern) {
   if (!all_benchmarks) return;
   if (std::string(pattern) == "all") {
@@ -117,11 +117,11 @@ void Benchmark::Run(const char* pattern) {
 }
 void Benchmark::Register() {
-  if (!all_benchmarks) all_benchmarks = new std::vector<Benchmark*>;
+  if (!all_benchmarks) all_benchmarks = new std::vector<Benchmark *>;
   all_benchmarks->push_back(this);
 }
-void Benchmark::Run(int arg1, int arg2, int* run_count, double* run_seconds) {
+void Benchmark::Run(int arg1, int arg2, int *run_count, double *run_seconds) {
   static const int64_t kMinIters = 10;
   static const int64_t kMaxIters = 1000000000;
   static const double kMinTime = 0.5;
......
@@ -13,7 +13,7 @@
 #define MACE_BENCHMARK_CONCAT(a, b, c) a##b##c
 #define BENCHMARK(n) \
-  static ::mace::testing::Benchmark* MACE_BENCHMARK_CONCAT( \
+  static ::mace::testing::Benchmark *MACE_BENCHMARK_CONCAT( \
       __benchmark_, n, __LINE__) = (new ::mace::testing::Benchmark(#n, (n)))
 namespace mace {
@@ -21,14 +21,14 @@ namespace testing {
 class Benchmark {
  public:
-  Benchmark(const char* name, void (*fn)(int));
-  Benchmark(const char* name, void (*fn)(int, int));
-  Benchmark(const char* name, void (*fn)(int, int, int));
-  Benchmark* Arg(int x);
-  Benchmark* ArgPair(int x, int y);
+  Benchmark(const char *name, void (*fn)(int));
+  Benchmark(const char *name, void (*fn)(int, int));
+  Benchmark(const char *name, void (*fn)(int, int, int));
+  Benchmark *Arg(int x);
+  Benchmark *ArgPair(int x, int y);
   static void Run();
-  static void Run(const char* pattern);
+  static void Run(const char *pattern);
  private:
   string name_;
@@ -39,7 +39,7 @@ class Benchmark {
   void (*fn2_)(int, int, int) = nullptr;
   void Register();
-  void Run(int arg1, int arg2, int* run_count, double* run_seconds);
+  void Run(int arg1, int arg2, int *run_count, double *run_seconds);
 };
 void RunBenchmarks();
......
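A minimal sketch of registering a benchmark through the BENCHMARK macro above; the function body is hypothetical and the header is assumed to be included:

    static void BM_Noop(int iters) {
      while (iters--) {
        // work under measurement goes here
      }
    }
    BENCHMARK(BM_Noop);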
@@ -6,7 +6,7 @@
 #include "mace/core/testing/test_benchmark.h"
-int main(int argc, char** argv) {
+int main(int argc, char **argv) {
   std::cout << "Running main() from test_main.cc\n";
   // TODO Use gflags
......
@@ -10,14 +10,14 @@ namespace mace {
 vector<string> Workspace::Tensors() const {
   vector<string> names;
-  for (auto& entry : tensor_map_) {
+  for (auto &entry : tensor_map_) {
     names.push_back(entry.first);
   }
   return names;
 }
-Tensor* Workspace::CreateTensor(const string& name,
-                                Allocator* alloc,
+Tensor *Workspace::CreateTensor(const string &name,
+                                Allocator *alloc,
                                 DataType type) {
   if (HasTensor(name)) {
     VLOG(1) << "Tensor " << name << " already exists. Skipping.";
@@ -28,7 +28,7 @@ Tensor* Workspace::CreateTensor(const string& name,
   return GetTensor(name);
 }
-bool Workspace::RemoveTensor(const string& name) {
+bool Workspace::RemoveTensor(const string &name) {
   auto it = tensor_map_.find(name);
   if (it != tensor_map_.end()) {
     VLOG(1) << "Removing blob " << name << " from this workspace.";
@@ -38,7 +38,7 @@ bool Workspace::RemoveTensor(const string& name) {
   return false;
 }
-const Tensor* Workspace::GetTensor(const string& name) const {
+const Tensor *Workspace::GetTensor(const string &name) const {
   if (tensor_map_.count(name)) {
     return tensor_map_.at(name).get();
   } else {
@@ -47,18 +47,17 @@ const Tensor* Workspace::GetTensor(const string& name) const {
   return nullptr;
 }
-Tensor* Workspace::GetTensor(const string& name) {
-  return const_cast<Tensor*>(
-      static_cast<const Workspace*>(this)->GetTensor(name));
+Tensor *Workspace::GetTensor(const string &name) {
+  return const_cast<Tensor *>(
+      static_cast<const Workspace *>(this)->GetTensor(name));
 }
-void Workspace::LoadModelTensor(const NetDef& net_def, DeviceType type) {
+void Workspace::LoadModelTensor(const NetDef &net_def, DeviceType type) {
   Serializer serializer;
-  for (auto& tensor_proto : net_def.tensors()) {
-    VLOG(1) << "Load tensor: " << tensor_proto.name()
-            << " has shape: " << internal::MakeString(vector<index_t>(
-                tensor_proto.dims().begin(), tensor_proto.dims().end()));
+  for (auto &tensor_proto : net_def.tensors()) {
+    VLOG(1) << "Load tensor: " << tensor_proto.name() << " has shape: "
+            << internal::MakeString(vector<index_t>(tensor_proto.dims().begin(),
+                                                    tensor_proto.dims().end()));
     tensor_map_[tensor_proto.name()] =
         serializer.Deserialize(tensor_proto, type);
   }
......
@@ -19,19 +19,19 @@ class Workspace {
   vector<string> Tensors() const;
-  Tensor* CreateTensor(const string& name, Allocator* alloc, DataType type);
-  bool RemoveTensor(const string& name);
-  inline bool HasTensor(const string& name) const {
+  Tensor *CreateTensor(const string &name, Allocator *alloc, DataType type);
+  bool RemoveTensor(const string &name);
+  inline bool HasTensor(const string &name) const {
     return tensor_map_.count(name);
   }
-  const Tensor* GetTensor(const string& name) const;
-  Tensor* GetTensor(const string& name);
-  void LoadModelTensor(const NetDef& net_def, DeviceType type);
+  const Tensor *GetTensor(const string &name) const;
+  Tensor *GetTensor(const string &name);
+  void LoadModelTensor(const NetDef &net_def, DeviceType type);
  private:
   TensorMap tensor_map_;
......
@@ -10,8 +10,8 @@ static void foo(int iters) {
   mace::testing::ItemsProcessed(tot);
   mace::testing::BytesProcessed(tot * (sizeof(float)));
-  float* inp = new float[N];
-  float* out = new float[N];
+  float *inp = new float[N];
+  float *out = new float[N];
   while (iters--) {
     for (int i = 0; i < N; i++) {
@@ -29,8 +29,8 @@ static void bar(int iters, int n) {
   mace::testing::ItemsProcessed(tot);
   mace::testing::BytesProcessed(tot * (sizeof(float)));
-  float* inp = new float[n];
-  float* out = new float[n];
+  float *inp = new float[n];
+  float *out = new float[n];
   while (iters--) {
     for (int i = 0; i < n; i++) {
......
@@ -12,8 +12,8 @@
  * --output_file=mace.out \
  * --device=NEON
  */
-#include <fstream>
 #include <sys/time.h>
+#include <fstream>
 #include "mace/core/net.h"
 #include "mace/utils/command_line_flags.h"
@@ -83,12 +83,11 @@ int main(int argc, char **argv) {
   Workspace ws;
   ws.LoadModelTensor(net_def, DeviceType::CPU);
-  Tensor *input_tensor = ws.CreateTensor(input_node + ":0",
-                                         cpu_allocator(), DT_FLOAT);
+  Tensor *input_tensor =
+      ws.CreateTensor(input_node + ":0", cpu_allocator(), DT_FLOAT);
   input_tensor->Resize(shape);
   float *input_data = input_tensor->mutable_data<float>();
   // load input
   ifstream in_file(input_file, ios::in | ios::binary);
   in_file.read(reinterpret_cast<char *>(input_data),
@@ -112,14 +111,17 @@ int main(int argc, char **argv) {
     net->Run();
   }
   gettimeofday(&tv2, NULL);
-  cout << "avg duration: " << ((tv2.tv_sec - tv1.tv_sec) * 1000
-      + (tv2.tv_usec - tv1.tv_usec) / 1000) / round << endl;
+  cout << "avg duration: "
+       << ((tv2.tv_sec - tv1.tv_sec) * 1000 +
+           (tv2.tv_usec - tv1.tv_usec) / 1000) /
+              round
+       << endl;
   // save output
   const Tensor *output = ws.GetTensor(output_node + ":0");
   ofstream out_file(output_file, ios::binary);
-  out_file.write((const char *) (output->data<float>()),
+  out_file.write((const char *)(output->data<float>()),
                  output->size() * sizeof(float));
   out_file.flush();
   out_file.close();
......
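The timing block above averages wall-clock time over `round` runs by differencing two gettimeofday samples and truncating to milliseconds. A self-contained sketch of the same measurement, assuming a POSIX system and a no-op standing in for net->Run():

#include <sys/time.h>

#include <iostream>

int main() {
  const int round = 10;
  struct timeval tv1, tv2;
  gettimeofday(&tv1, NULL);
  for (int i = 0; i < round; ++i) {
    // net->Run() would go here.
  }
  gettimeofday(&tv2, NULL);
  // Whole seconds contribute 1000 ms each; microseconds truncate to ms.
  long avg_ms = ((tv2.tv_sec - tv1.tv_sec) * 1000 +
                 (tv2.tv_usec - tv1.tv_usec) / 1000) /
                round;
  std::cout << "avg duration: " << avg_ms << " ms" << std::endl;
  return 0;
}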
...@@ -20,7 +20,7 @@ cc_library( ...@@ -20,7 +20,7 @@ cc_library(
linkopts = if_android(["-lm"]), linkopts = if_android(["-lm"]),
deps = [ deps = [
"//mace/core", "//mace/core",
"//mace/utils:utils", "//mace/utils",
], ],
) )
......
...@@ -12,7 +12,7 @@ namespace kernels { ...@@ -12,7 +12,7 @@ namespace kernels {
template <DeviceType D, typename T> template <DeviceType D, typename T>
struct AddNFunctor { struct AddNFunctor {
void operator()(const vector<const T*>& inputs, T* output, index_t size) { void operator()(const vector<const T *> &inputs, T *output, index_t size) {
memset(output, 0, size * sizeof(T)); memset(output, 0, size * sizeof(T));
int n = inputs.size(); int n = inputs.size();
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
...@@ -25,7 +25,7 @@ struct AddNFunctor { ...@@ -25,7 +25,7 @@ struct AddNFunctor {
template <> template <>
void AddNFunctor<DeviceType::NEON, float>::operator()( void AddNFunctor<DeviceType::NEON, float>::operator()(
const vector<const float*>& inputs, float* output, index_t size); const vector<const float *> &inputs, float *output, index_t size);
} // namespace kernels } // namespace kernels
} // namespace mace } // namespace mace
......
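AddNFunctor above zero-fills the output with one memset and then accumulates each input in turn; the NEON specialization only swaps out the float inner loop. A scalar sketch of the same two-pass scheme:

#include <cstring>
#include <vector>

// Element-wise sum of several equally sized arrays into output.
void AddN(const std::vector<const float *> &inputs, float *output,
          long size) {
  std::memset(output, 0, size * sizeof(float));  // zero the accumulator
  for (const float *input : inputs) {
    for (long j = 0; j < size; ++j) {
      output[j] += input[j];
    }
  }
}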
...@@ -13,17 +13,16 @@ namespace kernels { ...@@ -13,17 +13,16 @@ namespace kernels {
template <DeviceType D, typename T> template <DeviceType D, typename T>
struct BatchNormFunctor { struct BatchNormFunctor {
void operator()(const T *input,
void operator()(const T* input, const T *scale,
const T* scale, const T *offset,
const T* offset, const T *mean,
const T* mean, const T *var,
const T* var,
const float variance_epsilon, const float variance_epsilon,
const index_t n, const index_t n,
const index_t channel, const index_t channel,
const index_t sample_size, const index_t sample_size,
T* output) { T *output) {
// Batch normalization in the paper https://arxiv.org/abs/1502.03167 . // Batch normalization in the paper https://arxiv.org/abs/1502.03167 .
// The calculation formula for inference is // The calculation formula for inference is
// Y = \frac{ \scale } { \sqrt{var+\variance_epsilon} } * X + // Y = \frac{ \scale } { \sqrt{var+\variance_epsilon} } * X +
...@@ -40,8 +39,8 @@ struct BatchNormFunctor { ...@@ -40,8 +39,8 @@ struct BatchNormFunctor {
index_t pos = c * sample_size; index_t pos = c * sample_size;
for (index_t i = 0; i < n; ++i) { for (index_t i = 0; i < n; ++i) {
const T* input_sample_ptr = input + pos; const T *input_sample_ptr = input + pos;
T* output_sample_ptr = output + pos; T *output_sample_ptr = output + pos;
for (index_t j = 0; j < sample_size; ++j) { for (index_t j = 0; j < sample_size; ++j) {
output_sample_ptr[j] = new_scale * input_sample_ptr[j] + new_offset; output_sample_ptr[j] = new_scale * input_sample_ptr[j] + new_offset;
} }
...@@ -53,16 +52,16 @@ struct BatchNormFunctor { ...@@ -53,16 +52,16 @@ struct BatchNormFunctor {
template <> template <>
void BatchNormFunctor<DeviceType::NEON, float>::operator()( void BatchNormFunctor<DeviceType::NEON, float>::operator()(
const float* input, const float *input,
const float* scale, const float *scale,
const float* offset, const float *offset,
const float* mean, const float *mean,
const float* var, const float *var,
const float variance_epsilon, const float variance_epsilon,
const index_t n, const index_t n,
const index_t channel, const index_t channel,
const index_t sample_size, const index_t sample_size,
float* output); float *output);
} // namespace kernels } // namespace kernels
} // namespace mace } // namespace mace
......
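The formula comment in the hunk above is cut off at the hunk boundary; the folding it describes is Y = (scale / sqrt(var + epsilon)) * X + (offset - mean * scale / sqrt(var + epsilon)), which leaves one multiply-add per element once new_scale and new_offset are precomputed per channel. A scalar sketch of that precomputation:

#include <cmath>

// Inference-time batch norm over one channel of `count` values.
void BatchNormChannel(const float *x, float scale, float offset, float mean,
                      float var, float epsilon, long count, float *y) {
  // Fold the four learned parameters into a single affine transform.
  const float new_scale = scale / std::sqrt(var + epsilon);
  const float new_offset = offset - mean * new_scale;
  for (long i = 0; i < count; ++i) {
    y[i] = new_scale * x[i] + new_offset;
  }
}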
...@@ -10,11 +10,10 @@ ...@@ -10,11 +10,10 @@
namespace mace { namespace mace {
namespace kernels { namespace kernels {
template<DeviceType D, typename T> template <DeviceType D, typename T>
class ChannelShuffleFunctor { class ChannelShuffleFunctor {
public: public:
ChannelShuffleFunctor(const int group) ChannelShuffleFunctor(const int group) : group_(group) {}
: group_(group) {}
void operator()(const T *input, const index_t *input_shape, T *output) { void operator()(const T *input, const index_t *input_shape, T *output) {
index_t batch = input_shape[0]; index_t batch = input_shape[0];
...@@ -28,8 +27,8 @@ class ChannelShuffleFunctor { ...@@ -28,8 +27,8 @@ class ChannelShuffleFunctor {
for (int b = 0; b < batch; ++b) { for (int b = 0; b < batch; ++b) {
for (int c = 0; c < channels_of_group; ++c) { for (int c = 0; c < channels_of_group; ++c) {
for (int g = 0; g < group_; ++g) { for (int g = 0; g < group_; ++g) {
index_t input_offset = (b * channels + g * channels_of_group + c) * index_t input_offset =
image_size; (b * channels + g * channels_of_group + c) * image_size;
index_t output_offset = (b * channels + c * group_ + g) * image_size; index_t output_offset = (b * channels + c * group_ + g) * image_size;
memcpy(output + output_offset, input + input_offset, memcpy(output + output_offset, input + input_offset,
image_size * sizeof(T)); image_size * sizeof(T));
......
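The two offsets in ChannelShuffleFunctor transpose a (group, channels_of_group) channel grid: input channel g * channels_of_group + c lands on output channel c * group_ + g. A tiny program that prints the mapping for 8 channels with group = 4; it yields the permutation 0, 2, 4, 6, 1, 3, 5, 7, which matches the C8G4 test later in this commit:

#include <cstdio>

int main() {
  const int group = 4, channels = 8;
  const int channels_of_group = channels / group;
  // Output channel c * group + g reads input channel g * channels_of_group + c.
  for (int c = 0; c < channels_of_group; ++c) {
    for (int g = 0; g < group; ++g) {
      std::printf("out %d <- in %d\n", c * group + g,
                  g * channels_of_group + c);
    }
  }
  return 0;
}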
...@@ -5,13 +5,13 @@ ...@@ -5,13 +5,13 @@
#ifndef MACE_KERNELS_CONCAT_H_ #ifndef MACE_KERNELS_CONCAT_H_
#define MACE_KERNELS_CONCAT_H_ #define MACE_KERNELS_CONCAT_H_
#include "mace/proto/mace.pb.h"
#include "mace/core/common.h" #include "mace/core/common.h"
#include "mace/core/types.h" #include "mace/core/types.h"
#include "mace/proto/mace.pb.h"
namespace mace { namespace mace {
namespace kernels { namespace kernels {
template<DeviceType D, typename T> template <DeviceType D, typename T>
struct ConcatFunctor { struct ConcatFunctor {
void operator()(std::vector<const T *> &input_list, void operator()(std::vector<const T *> &input_list,
const index_t inner_dim, const index_t inner_dim,
......
...@@ -11,15 +11,13 @@ ...@@ -11,15 +11,13 @@
namespace mace { namespace mace {
namespace kernels { namespace kernels {
template<DeviceType D, typename T> template <DeviceType D, typename T>
struct Conv2dFunctor { struct Conv2dFunctor {
Conv2dFunctor() {} Conv2dFunctor() {}
Conv2dFunctor(const int *strides, Conv2dFunctor(const int *strides,
const std::vector<int> &paddings, const std::vector<int> &paddings,
const int *dilations) : const int *dilations)
strides_(strides), : strides_(strides), paddings_(paddings), dilations_(dilations) {}
paddings_(paddings),
dilations_(dilations) {}
void operator()(const T *input, // NCHW void operator()(const T *input, // NCHW
const index_t *input_shape, const index_t *input_shape,
...@@ -106,7 +104,7 @@ struct Conv2dFunctor { ...@@ -106,7 +104,7 @@ struct Conv2dFunctor {
const int *dilations_; // [dilation_h, dilation_w] const int *dilations_; // [dilation_h, dilation_w]
}; };
template<> template <>
void Conv2dFunctor<DeviceType::NEON, float>::operator()( void Conv2dFunctor<DeviceType::NEON, float>::operator()(
const float *input, const float *input,
const index_t *input_shape, const index_t *input_shape,
......
...@@ -77,7 +77,6 @@ void CalPaddingSize(const index_t *input_shape, // NCHW ...@@ -77,7 +77,6 @@ void CalPaddingSize(const index_t *input_shape, // NCHW
const int *strides, const int *strides,
Padding padding, Padding padding,
int *padding_size) { int *padding_size) {
MACE_CHECK(dilations[0] > 0 && dilations[1] > 0, MACE_CHECK(dilations[0] > 0 && dilations[1] > 0,
"Invalid dilations, must >= 1"); "Invalid dilations, must >= 1");
MACE_CHECK((dilations[0] == 1 || strides[0] == 1) && MACE_CHECK((dilations[0] == 1 || strides[0] == 1) &&
......
...@@ -5,22 +5,20 @@ ...@@ -5,22 +5,20 @@
#ifndef MACE_KERNELS_DEPTHWISE_CONV_H_ #ifndef MACE_KERNELS_DEPTHWISE_CONV_H_
#define MACE_KERNELS_DEPTHWISE_CONV_H_ #define MACE_KERNELS_DEPTHWISE_CONV_H_
#include "mace/proto/mace.pb.h"
#include "mace/core/common.h" #include "mace/core/common.h"
#include "mace/kernels/conv_pool_2d_util.h" #include "mace/kernels/conv_pool_2d_util.h"
#include "mace/proto/mace.pb.h"
namespace mace { namespace mace {
namespace kernels { namespace kernels {
template<DeviceType D, typename T> template <DeviceType D, typename T>
struct DepthwiseConv2dFunctor { struct DepthwiseConv2dFunctor {
DepthwiseConv2dFunctor() {} DepthwiseConv2dFunctor() {}
DepthwiseConv2dFunctor(const int *strides, DepthwiseConv2dFunctor(const int *strides,
const std::vector<int> &paddings, const std::vector<int> &paddings,
const int *dilations) : const int *dilations)
strides_(strides), : strides_(strides), paddings_(paddings), dilations_(dilations) {}
paddings_(paddings),
dilations_(dilations) {}
void operator()(const T *input, // NCHW void operator()(const T *input, // NCHW
const index_t *input_shape, const index_t *input_shape,
...@@ -80,15 +78,14 @@ struct DepthwiseConv2dFunctor { ...@@ -80,15 +78,14 @@ struct DepthwiseConv2dFunctor {
inw >= input_width) { inw >= input_width) {
MACE_CHECK(inh >= padded_h_start && inh < padded_h_stop && MACE_CHECK(inh >= padded_h_start && inh < padded_h_stop &&
inw >= padded_w_start && inw < padded_w_stop, inw >= padded_w_start && inw < padded_w_stop,
"Out of range read from input: ", inh, ", ", "Out of range read from input: ", inh, ", ", inw);
inw);
// else padding with 0: // else padding with 0:
// sum += 0; // sum += 0;
} else { } else {
index_t input_offset = index_t input_offset =
n * input_channels * input_height * input_width + n * input_channels * input_height * input_width +
(c / multiplier) * input_height * input_width + inh * input_width + (c / multiplier) * input_height * input_width +
inw; inh * input_width + inw;
sum += input[input_offset] * *filter_ptr; sum += input[input_offset] * *filter_ptr;
} }
++filter_ptr; ++filter_ptr;
...@@ -106,8 +103,9 @@ struct DepthwiseConv2dFunctor { ...@@ -106,8 +103,9 @@ struct DepthwiseConv2dFunctor {
const int *dilations_; // [dilation_h, dilation_w] const int *dilations_; // [dilation_h, dilation_w]
}; };
template<> template <>
void DepthwiseConv2dFunctor<DeviceType::NEON, float>::operator()(const float *input, void DepthwiseConv2dFunctor<DeviceType::NEON, float>::operator()(
const float *input,
const index_t *input_shape, const index_t *input_shape,
const float *filter, const float *filter,
const index_t *filter_shape, const index_t *filter_shape,
......
...@@ -35,9 +35,7 @@ struct GlobalAvgPoolingFunctor { ...@@ -35,9 +35,7 @@ struct GlobalAvgPoolingFunctor {
template <> template <>
void GlobalAvgPoolingFunctor<DeviceType::NEON, float>::operator()( void GlobalAvgPoolingFunctor<DeviceType::NEON, float>::operator()(
const float *input, const float *input, const index_t *input_shape, float *output);
const index_t *input_shape,
float *output);
} // namespace kernels } // namespace kernels
} // namespace mace } // namespace mace
......
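GlobalAvgPoolingFunctor reduces each (batch, channel) plane to its mean; the NEON specialization above changes only how the sum is formed. A scalar sketch of the reduction:

// Global average pooling over NCHW input: one output per (batch, channel).
void GlobalAvgPool(const float *input, const long *shape /* NCHW */,
                   float *output) {
  const long batch = shape[0], channels = shape[1];
  const long image_size = shape[2] * shape[3];
  for (long b = 0; b < batch; ++b) {
    for (long c = 0; c < channels; ++c) {
      const float *in = input + (b * channels + c) * image_size;
      float sum = 0.0f;
      for (long i = 0; i < image_size; ++i) sum += in[i];
      output[b * channels + c] = sum / image_size;
    }
  }
}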
...@@ -33,7 +33,7 @@ void PoolingAvgNeonK3x3S2x2(const float *input, ...@@ -33,7 +33,7 @@ void PoolingAvgNeonK3x3S2x2(const float *input,
int out_image_size = out_height * out_width; int out_image_size = out_height * out_width;
index_t input_offset = 0; index_t input_offset = 0;
index_t output_offset = 0; index_t output_offset = 0;
float avg_factors[4] = {1.0/9.0, 1.0/9.0, 1.0/9.0, 1.0/9.0}; float avg_factors[4] = {1.0 / 9.0, 1.0 / 9.0, 1.0 / 9.0, 1.0 / 9.0};
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (int b = 0; b < batch; ++b) { for (int b = 0; b < batch; ++b) {
...@@ -147,7 +147,7 @@ void PoolingAvgNeonK3x3S2x2Padded(const float *input, ...@@ -147,7 +147,7 @@ void PoolingAvgNeonK3x3S2x2Padded(const float *input,
int out_image_size = out_height * out_width; int out_image_size = out_height * out_width;
index_t input_offset = 0; index_t input_offset = 0;
index_t output_offset = 0; index_t output_offset = 0;
float avg_factors[4] = {1.0/9.0, 1.0/9.0, 1.0/9.0, 1.0/9.0}; float avg_factors[4] = {1.0 / 9.0, 1.0 / 9.0, 1.0 / 9.0, 1.0 / 9.0};
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (int b = 0; b < batch; ++b) { for (int b = 0; b < batch; ++b) {
...@@ -200,8 +200,9 @@ void PoolingAvgNeonK3x3S2x2Padded(const float *input, ...@@ -200,8 +200,9 @@ void PoolingAvgNeonK3x3S2x2Padded(const float *input,
} }
for (; remain > 0; remain--) { for (; remain > 0; remain--) {
*outptr = (r0[0] + r0[1] + r0[2] + r1[0] + r1[1] + r1[2] + *outptr = (r0[0] + r0[1] + r0[2] + r1[0] + r1[1] + r1[2] + r2[0] +
r2[0] + r2[1] + r2[2]) / 9.0; r2[1] + r2[2]) /
9.0;
r0 += 2; r0 += 2;
r1 += 2; r1 += 2;
......
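The avg_factors arrays above replicate 1/9 across four lanes so a 3x3 window sum can be scaled with a single vector multiply; the scalar tail at the bottom of the hunk divides by 9.0 directly. A scalar sketch of one output row at stride 2, ignoring padding:

// One output row of 3x3 average pooling, stride 2, valid padding.
// r0, r1, r2 point at three consecutive input rows.
void AvgPool3x3S2Row(const float *r0, const float *r1, const float *r2,
                     int out_width, float *out) {
  const float factor = 1.0f / 9.0f;  // the constant the NEON path splats
  for (int w = 0; w < out_width; ++w) {
    float sum = r0[0] + r0[1] + r0[2] + r1[0] + r1[1] + r1[2] + r2[0] +
                r2[1] + r2[2];
    out[w] = sum * factor;
    r0 += 2;  // advance by the stride
    r1 += 2;
    r2 += 2;
  }
}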
...@@ -10,16 +10,16 @@ namespace kernels { ...@@ -10,16 +10,16 @@ namespace kernels {
template <> template <>
void BatchNormFunctor<DeviceType::NEON, float>::operator()( void BatchNormFunctor<DeviceType::NEON, float>::operator()(
const float* input, const float *input,
const float* scale, const float *scale,
const float* offset, const float *offset,
const float* mean, const float *mean,
const float* var, const float *var,
const float variance_epsilon, const float variance_epsilon,
const index_t n, const index_t n,
const index_t channel, const index_t channel,
const index_t sample_size, const index_t sample_size,
float* output) { float *output) {
// Batch normalization in the paper https://arxiv.org/abs/1502.03167 . // Batch normalization in the paper https://arxiv.org/abs/1502.03167 .
// The calculation formula for inference is // The calculation formula for inference is
// Y = \frac{ \scale } { \sqrt{var+\variance_epsilon} } * X + // Y = \frac{ \scale } { \sqrt{var+\variance_epsilon} } * X +
...@@ -40,8 +40,8 @@ void BatchNormFunctor<DeviceType::NEON, float>::operator()( ...@@ -40,8 +40,8 @@ void BatchNormFunctor<DeviceType::NEON, float>::operator()(
float32x4_t new_scale_f = vdupq_n_f32(new_scale); float32x4_t new_scale_f = vdupq_n_f32(new_scale);
float32x4_t new_offset_f = vdupq_n_f32(new_offset); float32x4_t new_offset_f = vdupq_n_f32(new_offset);
for (index_t i = 0; i < n; ++i) { for (index_t i = 0; i < n; ++i) {
const float* input_sample_ptr = input + pos; const float *input_sample_ptr = input + pos;
float* output_sample_ptr = output + pos; float *output_sample_ptr = output + pos;
for (index_t j = 0; j < count; ++j) { for (index_t j = 0; j < count; ++j) {
float32x4_t input_f = vld1q_f32(input_sample_ptr); float32x4_t input_f = vld1q_f32(input_sample_ptr);
......
...@@ -41,7 +41,8 @@ extern void Conv2dNeonK5x5S1(const float *input, ...@@ -41,7 +41,8 @@ extern void Conv2dNeonK5x5S1(const float *input,
const index_t *output_shape); const index_t *output_shape);
template <> template <>
void Conv2dFunctor<DeviceType::NEON, float>::operator()(const float *input, void Conv2dFunctor<DeviceType::NEON, float>::operator()(
const float *input,
const index_t *input_shape, const index_t *input_shape,
const float *filter, const float *filter,
const index_t *filter_shape, const index_t *filter_shape,
...@@ -49,12 +50,8 @@ void Conv2dFunctor<DeviceType::NEON, float>::operator()(const float *input, ...@@ -49,12 +50,8 @@ void Conv2dFunctor<DeviceType::NEON, float>::operator()(const float *input,
float *output, float *output,
const index_t *output_shape) { const index_t *output_shape) {
typedef void (*Conv2dNeonFunction)( typedef void (*Conv2dNeonFunction)(
const float *input, const float *input, const index_t *input_shape, const float *filter,
const index_t *input_shape, const index_t *filter_shape, const float *bias, float *output,
const float *filter,
const index_t *filter_shape,
const float *bias,
float *output,
const index_t *output_shape); const index_t *output_shape);
// Selection matrix: kernel_size x stride_size // Selection matrix: kernel_size x stride_size
static const Conv2dNeonFunction selector[5][2] = { static const Conv2dNeonFunction selector[5][2] = {
...@@ -81,12 +78,14 @@ void Conv2dFunctor<DeviceType::NEON, float>::operator()(const float *input, ...@@ -81,12 +78,14 @@ void Conv2dFunctor<DeviceType::NEON, float>::operator()(const float *input,
// Keep this alive during kernel execution // Keep this alive during kernel execution
Tensor padded_input; Tensor padded_input;
if (paddings_[0] > 0 || paddings_[1] > 0) { if (paddings_[0] > 0 || paddings_[1] > 0) {
ConstructInputWithPadding(input, input_shape, paddings_.data(), &padded_input); ConstructInputWithPadding(input, input_shape, paddings_.data(),
&padded_input);
input = padded_input.data<float>(); input = padded_input.data<float>();
input_shape = padded_input.shape().data(); input_shape = padded_input.shape().data();
} }
auto conv2d_neon_func = selector[kernel_h - 1][strides_[0] - 1]; auto conv2d_neon_func = selector[kernel_h - 1][strides_[0] - 1];
conv2d_neon_func(input, input_shape, filter, nullptr, bias, output, output_shape); conv2d_neon_func(input, input_shape, filter, nullptr, bias, output,
output_shape);
} }
} // namespace kernels } // namespace kernels
......
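The NEON Conv2d entry point above dispatches through a 5x2 table of function pointers indexed by (kernel_h - 1, strides_[0] - 1); nullptr entries mean no fast path exists for that combination. A condensed sketch of the dispatch shape, with placeholder kernel names:

typedef void (*KernelFn)(const float *input, float *output, long n);

static void K1x1S1(const float *input, float *output, long n) { /* ... */ }
static void K3x3S1(const float *input, float *output, long n) { /* ... */ }

// Rows: kernel sizes 1..5; columns: strides 1..2.
static const KernelFn selector[5][2] = {
    {K1x1S1, nullptr},
    {nullptr, nullptr},
    {K3x3S1, nullptr},
    {nullptr, nullptr},
    {nullptr, nullptr},
};

void Dispatch(int kernel, int stride, const float *in, float *out, long n) {
  KernelFn fn = selector[kernel - 1][stride - 1];
  if (fn != nullptr) {
    fn(in, out, n);  // specialized fast path
  }
  // else: fall back to the generic implementation.
}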
...@@ -10,9 +10,8 @@ namespace mace { ...@@ -10,9 +10,8 @@ namespace mace {
namespace kernels { namespace kernels {
static constexpr index_t kInputChannelBlockSize = 2; static constexpr index_t kInputChannelBlockSize = 2;
static constexpr index_t kOutputChannelBlockSize = 4; static constexpr index_t kOutputChannelBlockSize = 4;
static __attribute__((__aligned__(64))) int32_t mask_array[8] = { static __attribute__((__aligned__(64)))
0, 0, 0, 0, -1, -1, -1, -1 int32_t mask_array[8] = {0, 0, 0, 0, -1, -1, -1, -1};
};
static inline void NeonConv2x4Kernel(index_t input_channels, static inline void NeonConv2x4Kernel(index_t input_channels,
index_t pixel_size, index_t pixel_size,
...@@ -77,15 +76,15 @@ static inline void NeonConv2x4Kernel(index_t input_channels, ...@@ -77,15 +76,15 @@ static inline void NeonConv2x4Kernel(index_t input_channels,
output3 = output3 + pixel_size - 4; output3 = output3 + pixel_size - 4;
float32x4_t voutput3 = vld1q_f32(output3); float32x4_t voutput3 = vld1q_f32(output3);
const float32x4_t vinput0 = vreinterpretq_f32_s32( const float32x4_t vinput0 = vreinterpretq_f32_s32(vandq_s32(
vandq_s32(vmask, vreinterpretq_s32_f32(vld1q_f32(&input0[pixel_size - 4])))); vmask, vreinterpretq_s32_f32(vld1q_f32(&input0[pixel_size - 4]))));
voutput0 = vfmaq_lane_f32(voutput0, vinput0, vfilter0x, 0); voutput0 = vfmaq_lane_f32(voutput0, vinput0, vfilter0x, 0);
voutput1 = vfmaq_lane_f32(voutput1, vinput0, vfilter1x, 0); voutput1 = vfmaq_lane_f32(voutput1, vinput0, vfilter1x, 0);
voutput2 = vfmaq_lane_f32(voutput2, vinput0, vfilter2x, 0); voutput2 = vfmaq_lane_f32(voutput2, vinput0, vfilter2x, 0);
voutput3 = vfmaq_lane_f32(voutput3, vinput0, vfilter3x, 0); voutput3 = vfmaq_lane_f32(voutput3, vinput0, vfilter3x, 0);
const float32x4_t vinput1 = vreinterpretq_f32_s32( const float32x4_t vinput1 = vreinterpretq_f32_s32(vandq_s32(
vandq_s32(vmask, vreinterpretq_s32_f32(vld1q_f32(&input1[pixel_size - 4])))); vmask, vreinterpretq_s32_f32(vld1q_f32(&input1[pixel_size - 4]))));
voutput0 = vfmaq_lane_f32(voutput0, vinput1, vfilter0x, 1); voutput0 = vfmaq_lane_f32(voutput0, vinput1, vfilter0x, 1);
voutput1 = vfmaq_lane_f32(voutput1, vinput1, vfilter1x, 1); voutput1 = vfmaq_lane_f32(voutput1, vinput1, vfilter1x, 1);
voutput2 = vfmaq_lane_f32(voutput2, vinput1, vfilter2x, 1); voutput2 = vfmaq_lane_f32(voutput2, vinput1, vfilter2x, 1);
...@@ -98,7 +97,8 @@ static inline void NeonConv2x4Kernel(index_t input_channels, ...@@ -98,7 +97,8 @@ static inline void NeonConv2x4Kernel(index_t input_channels,
} }
} }
static inline void NeonConv2x4SubBlockKernel(index_t input_channels_subblock_size, static inline void NeonConv2x4SubBlockKernel(
index_t input_channels_subblock_size,
index_t output_channels_subblock_size, index_t output_channels_subblock_size,
index_t input_channels, index_t input_channels,
index_t pixel_size, index_t pixel_size,
...@@ -204,16 +204,16 @@ static inline void NeonConv2x4SubBlockKernel(index_t input_channels_subblock_siz ...@@ -204,16 +204,16 @@ static inline void NeonConv2x4SubBlockKernel(index_t input_channels_subblock_siz
} }
} }
const float32x4_t vinput0 = vreinterpretq_f32_s32( const float32x4_t vinput0 = vreinterpretq_f32_s32(vandq_s32(
vandq_s32(vmask, vreinterpretq_s32_f32(vld1q_f32(&input0[pixel_size - 4])))); vmask, vreinterpretq_s32_f32(vld1q_f32(&input0[pixel_size - 4]))));
voutput0 = vfmaq_lane_f32(voutput0, vinput0, vfilter0x, 0); voutput0 = vfmaq_lane_f32(voutput0, vinput0, vfilter0x, 0);
voutput1 = vfmaq_lane_f32(voutput1, vinput0, vfilter1x, 0); voutput1 = vfmaq_lane_f32(voutput1, vinput0, vfilter1x, 0);
voutput2 = vfmaq_lane_f32(voutput2, vinput0, vfilter2x, 0); voutput2 = vfmaq_lane_f32(voutput2, vinput0, vfilter2x, 0);
voutput3 = vfmaq_lane_f32(voutput3, vinput0, vfilter3x, 0); voutput3 = vfmaq_lane_f32(voutput3, vinput0, vfilter3x, 0);
if (input_channels_subblock_size > 1) { if (input_channels_subblock_size > 1) {
const float32x4_t vinput1 = vreinterpretq_f32_s32( const float32x4_t vinput1 = vreinterpretq_f32_s32(vandq_s32(
vandq_s32(vmask, vreinterpretq_s32_f32(vld1q_f32(&input1[pixel_size - 4])))); vmask, vreinterpretq_s32_f32(vld1q_f32(&input1[pixel_size - 4]))));
voutput0 = vfmaq_lane_f32(voutput0, vinput1, vfilter0x, 1); voutput0 = vfmaq_lane_f32(voutput0, vinput1, vfilter0x, 1);
voutput1 = vfmaq_lane_f32(voutput1, vinput1, vfilter1x, 1); voutput1 = vfmaq_lane_f32(voutput1, vinput1, vfilter1x, 1);
voutput2 = vfmaq_lane_f32(voutput2, vinput1, vfilter2x, 1); voutput2 = vfmaq_lane_f32(voutput2, vinput1, vfilter2x, 1);
...@@ -259,22 +259,27 @@ void Conv2dNeonK1x1S1(const float *input, // NCHW ...@@ -259,22 +259,27 @@ void Conv2dNeonK1x1S1(const float *input, // NCHW
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (index_t n = 0; n < batch; ++n) { for (index_t n = 0; n < batch; ++n) {
for (int i = 0; i < channels; ++i) { for (int i = 0; i < channels; ++i) {
float *output_ptr_base = output + n * channels * total_pixels + i * total_pixels; float *output_ptr_base =
std::fill(output_ptr_base, output_ptr_base + total_pixels, bias ? bias[i] : 0); output + n * channels * total_pixels + i * total_pixels;
std::fill(output_ptr_base, output_ptr_base + total_pixels,
bias ? bias[i] : 0);
} }
} }
// benchmark omp collapsed(2) // benchmark omp collapsed(2)
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (index_t n = 0; n < batch; ++n) { for (index_t n = 0; n < batch; ++n) {
for (index_t c = 0; c < round_up_channels; c += kOutputChannelBlockSize) { for (index_t c = 0; c < round_up_channels; c += kOutputChannelBlockSize) {
const float *input_ptr = input + n * input_channels * total_pixels; const float *input_ptr = input + n * input_channels * total_pixels;
const float *filter_ptr = filter + c * input_channels; const float *filter_ptr = filter + c * input_channels;
float *output_ptr = output + n * channels * total_pixels + c * total_pixels; float *output_ptr =
const index_t output_channel_block_size = std::min(channels - c, kOutputChannelBlockSize); output + n * channels * total_pixels + c * total_pixels;
const index_t output_channel_block_size =
std::min(channels - c, kOutputChannelBlockSize);
index_t remain_input_channels = input_channels; index_t remain_input_channels = input_channels;
if (c + kOutputChannelBlockSize <= channels) { if (c + kOutputChannelBlockSize <= channels) {
while (remain_input_channels >= kInputChannelBlockSize) { while (remain_input_channels >= kInputChannelBlockSize) {
NeonConv2x4Kernel(input_channels, total_pixels, input_ptr, filter_ptr, output_ptr); NeonConv2x4Kernel(input_channels, total_pixels, input_ptr, filter_ptr,
output_ptr);
input_ptr += kInputChannelBlockSize * total_pixels; input_ptr += kInputChannelBlockSize * total_pixels;
filter_ptr += kInputChannelBlockSize; filter_ptr += kInputChannelBlockSize;
...@@ -282,19 +287,21 @@ void Conv2dNeonK1x1S1(const float *input, // NCHW ...@@ -282,19 +287,21 @@ void Conv2dNeonK1x1S1(const float *input, // NCHW
} }
} }
while (remain_input_channels != 0) { while (remain_input_channels != 0) {
const index_t input_channel_block_size = std::min(remain_input_channels, kInputChannelBlockSize); const index_t input_channel_block_size =
NeonConv2x4SubBlockKernel(input_channel_block_size, output_channel_block_size, std::min(remain_input_channels, kInputChannelBlockSize);
input_channels, total_pixels, input_ptr, filter_ptr, output_ptr); NeonConv2x4SubBlockKernel(
input_channel_block_size, output_channel_block_size, input_channels,
total_pixels, input_ptr, filter_ptr, output_ptr);
input_ptr += kInputChannelBlockSize * total_pixels; input_ptr += kInputChannelBlockSize * total_pixels;
filter_ptr += kInputChannelBlockSize; filter_ptr += kInputChannelBlockSize;
remain_input_channels -= input_channel_block_size; remain_input_channels -= input_channel_block_size;
} }
} }
} }
}; };
void Conv2dNeonPixelK1x1S1(const float *input, // NCHW void Conv2dNeonPixelK1x1S1(
const float *input, // NCHW
const index_t *input_shape, const index_t *input_shape,
const float *filter, // c_out, c_in, kernel_h, kernel_w const float *filter, // c_out, c_in, kernel_h, kernel_w
const index_t *filter_shape, const index_t *filter_shape,
...@@ -320,7 +327,7 @@ void Conv2dNeonPixelK1x1S1(const float *input, // NCHW ...@@ -320,7 +327,7 @@ void Conv2dNeonPixelK1x1S1(const float *input, // NCHW
const index_t total_loops = total_pixels >> 3; const index_t total_loops = total_pixels >> 3;
const index_t loop_remaining = total_pixels & 7; const index_t loop_remaining = total_pixels & 7;
// benchmark omp collapsed(2) // benchmark omp collapsed(2)
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (index_t n = 0; n < batch; ++n) { for (index_t n = 0; n < batch; ++n) {
for (index_t c = 0; c < channels; ++c) { for (index_t c = 0; c < channels; ++c) {
......
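The aligned mask_array and the vandq_s32 loads above implement a masked tail read: the final pixels of a row are fetched with a full, in-bounds 4-lane load ending exactly at the row boundary, and the lanes already handled are zeroed so the subsequent fused multiply-adds contribute nothing for them. A standalone sketch of the trick (NEON only; the offset convention here is illustrative, not MACE's exact code):

#include <arm_neon.h>

static const int32_t mask_array[8] = {0, 0, 0, 0, -1, -1, -1, -1};

// Load the last four floats of a row, keeping only the final `remain`
// (1..4) lanes; the leading lanes are zeroed.
static inline float32x4_t LoadTailMasked(const float *row_end, int remain) {
  int32x4_t vmask = vld1q_s32(mask_array + remain);  // remain=1 -> {0,0,0,-1}
  float32x4_t v = vld1q_f32(row_end - 4);  // never reads out of bounds
  return vreinterpretq_f32_s32(vandq_s32(vmask, vreinterpretq_s32_f32(v)));
}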
...@@ -18,7 +18,6 @@ void Conv2dNeonK3x3S1(const float *input, // NCHW ...@@ -18,7 +18,6 @@ void Conv2dNeonK3x3S1(const float *input, // NCHW
const float *bias, // c_out const float *bias, // c_out
float *output, // NCHW float *output, // NCHW
const index_t *output_shape) { const index_t *output_shape) {
int height_count = (output_shape[2] >> 1) << 1; int height_count = (output_shape[2] >> 1) << 1;
int output_batch = output_shape[0]; int output_batch = output_shape[0];
...@@ -29,26 +28,32 @@ void Conv2dNeonK3x3S1(const float *input, // NCHW ...@@ -29,26 +28,32 @@ void Conv2dNeonK3x3S1(const float *input, // NCHW
int input_channels = input_shape[1]; int input_channels = input_shape[1];
int input_height = input_shape[2]; int input_height = input_shape[2];
int input_width = input_shape[3]; int input_width = input_shape[3];
int multiplier = filter_shape == nullptr ? 0 : (filter_shape[0] / input_channels); int multiplier =
int filter_in_channels = filter_shape == nullptr ? input_channels : filter_shape[1]; filter_shape == nullptr ? 0 : (filter_shape[0] / input_channels);
int filter_in_channels =
filter_shape == nullptr ? input_channels : filter_shape[1];
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (int b = 0; b < output_batch; ++b) { for (int b = 0; b < output_batch; ++b) {
for (int oc = 0; oc < output_channels; ++oc) { for (int oc = 0; oc < output_channels; ++oc) {
float *output_ptr_base = output + b * output_channels * output_height * output_width; float *output_ptr_base =
output + b * output_channels * output_height * output_width;
const float *filter_ptr = filter + oc * filter_in_channels * kFilterSize; const float *filter_ptr = filter + oc * filter_in_channels * kFilterSize;
const float *input_ptr = input + b * input_channels * input_height * input_width; const float *input_ptr =
input + b * input_channels * input_height * input_width;
if (filter_shape != nullptr) { if (filter_shape != nullptr) {
input_ptr += (oc / multiplier) * input_height * input_width; input_ptr += (oc / multiplier) * input_height * input_width;
} }
float *output_ptr = output_ptr_base + oc * output_height * output_width; float *output_ptr = output_ptr_base + oc * output_height * output_width;
std::fill(output_ptr, output_ptr + output_height * output_width, bias ? bias[oc] : 0); std::fill(output_ptr, output_ptr + output_height * output_width,
bias ? bias[oc] : 0);
for (int ic = 0; ic < filter_in_channels; ++ic) { for (int ic = 0; ic < filter_in_channels; ++ic) {
float32x4_t n_filter_v[3] = {vld1q_f32(filter_ptr), vld1q_f32(filter_ptr+3), vld1q_f32(filter_ptr+6)}; float32x4_t n_filter_v[3] = {vld1q_f32(filter_ptr),
vld1q_f32(filter_ptr + 3),
vld1q_f32(filter_ptr + 6)};
const float *row_ptr_v[kRegisterSize] = { const float *row_ptr_v[kRegisterSize] = {
input_ptr, input_ptr + input_width, input_ptr, input_ptr + input_width, input_ptr + 2 * input_width,
input_ptr + 2 * input_width, input_ptr + 3 * input_width input_ptr + 3 * input_width};
};
float *output_ptr_v[] = {output_ptr, output_ptr + output_width}; float *output_ptr_v[] = {output_ptr, output_ptr + output_width};
...@@ -69,8 +74,10 @@ void Conv2dNeonK3x3S1(const float *input, // NCHW ...@@ -69,8 +74,10 @@ void Conv2dNeonK3x3S1(const float *input, // NCHW
float32x4_t n_row1_former = vld1q_f32(row_ptr_v[1]); float32x4_t n_row1_former = vld1q_f32(row_ptr_v[1]);
float32x4_t n_row1_latter = vld1q_f32(row_ptr_v[1] + kRegisterSize); float32x4_t n_row1_latter = vld1q_f32(row_ptr_v[1] + kRegisterSize);
float32x4_t n_row1_ext0 = vextq_f32(n_row1_former, n_row1_latter, 1); float32x4_t n_row1_ext0 =
float32x4_t n_row1_ext1 = vextq_f32(n_row1_former, n_row1_latter, 2); vextq_f32(n_row1_former, n_row1_latter, 1);
float32x4_t n_row1_ext1 =
vextq_f32(n_row1_former, n_row1_latter, 2);
n_sum0 = vfmaq_laneq_f32(n_sum0, n_row1_former, n_filter_v[1], 0); n_sum0 = vfmaq_laneq_f32(n_sum0, n_row1_former, n_filter_v[1], 0);
n_sum0 = vfmaq_laneq_f32(n_sum0, n_row1_ext0, n_filter_v[1], 1); n_sum0 = vfmaq_laneq_f32(n_sum0, n_row1_ext0, n_filter_v[1], 1);
n_sum0 = vfmaq_laneq_f32(n_sum0, n_row1_ext1, n_filter_v[1], 2); n_sum0 = vfmaq_laneq_f32(n_sum0, n_row1_ext1, n_filter_v[1], 2);
...@@ -115,11 +122,9 @@ void Conv2dNeonK3x3S1(const float *input, // NCHW ...@@ -115,11 +122,9 @@ void Conv2dNeonK3x3S1(const float *input, // NCHW
} }
} }
for (; remain_count > 0; --remain_count) { for (; remain_count > 0; --remain_count) {
float32x4_t n_row_v[] = { float32x4_t n_row_v[] = {vld1q_f32(row_ptr_v[0]),
vld1q_f32(row_ptr_v[0]),
vld1q_f32(row_ptr_v[1]), vld1q_f32(row_ptr_v[1]),
vld1q_f32(row_ptr_v[2]) vld1q_f32(row_ptr_v[2])};
};
float32x4_t n_sum0 = vmulq_f32(n_row_v[0], n_filter_v[0]); float32x4_t n_sum0 = vmulq_f32(n_row_v[0], n_filter_v[0]);
n_sum0 = vmlaq_f32(n_sum0, n_row_v[1], n_filter_v[1]); n_sum0 = vmlaq_f32(n_sum0, n_row_v[1], n_filter_v[1]);
n_sum0 = vmlaq_f32(n_sum0, n_row_v[2], n_filter_v[2]); n_sum0 = vmlaq_f32(n_sum0, n_row_v[2], n_filter_v[2]);
...@@ -185,8 +190,7 @@ void Conv2dNeonK3x3S1(const float *input, // NCHW ...@@ -185,8 +190,7 @@ void Conv2dNeonK3x3S1(const float *input, // NCHW
} }
for (; remain_count > 0; --remain_count) { for (; remain_count > 0; --remain_count) {
float32x4_t n_row_v[] = { float32x4_t n_row_v[] = {
vld1q_f32(row_ptr_v[0]), vld1q_f32(row_ptr_v[0]), vld1q_f32(row_ptr_v[1]),
vld1q_f32(row_ptr_v[1]),
vld1q_f32(row_ptr_v[2]), vld1q_f32(row_ptr_v[2]),
}; };
...@@ -227,26 +231,32 @@ void Conv2dNeonK3x3S2(const float *input, // NCHW ...@@ -227,26 +231,32 @@ void Conv2dNeonK3x3S2(const float *input, // NCHW
int input_channels = input_shape[1]; int input_channels = input_shape[1];
int input_height = input_shape[2]; int input_height = input_shape[2];
int input_width = input_shape[3]; int input_width = input_shape[3];
int multiplier = filter_shape == nullptr ? 0 : (filter_shape[0] / input_channels); int multiplier =
int filter_in_channels = filter_shape == nullptr ? input_channels : filter_shape[1]; filter_shape == nullptr ? 0 : (filter_shape[0] / input_channels);
int filter_in_channels =
filter_shape == nullptr ? input_channels : filter_shape[1];
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (int b = 0; b < output_batch; ++b) { for (int b = 0; b < output_batch; ++b) {
for (int oc = 0; oc < output_channels; ++oc) { for (int oc = 0; oc < output_channels; ++oc) {
float *output_ptr_base = output + b * output_channels * output_height * output_width; float *output_ptr_base =
output + b * output_channels * output_height * output_width;
const float *filter_ptr = filter + oc * filter_in_channels * kFilterSize; const float *filter_ptr = filter + oc * filter_in_channels * kFilterSize;
const float *input_ptr = input + b * input_channels * input_height * input_width; const float *input_ptr =
input + b * input_channels * input_height * input_width;
if (filter_shape != nullptr) { if (filter_shape != nullptr) {
input_ptr += (oc / multiplier) * input_height * input_width; input_ptr += (oc / multiplier) * input_height * input_width;
} }
float *output_ptr = output_ptr_base + oc * output_height * output_width; float *output_ptr = output_ptr_base + oc * output_height * output_width;
std::fill(output_ptr, output_ptr + output_height * output_width, bias ? bias[oc] : 0); std::fill(output_ptr, output_ptr + output_height * output_width,
bias ? bias[oc] : 0);
for (int ic = 0; ic < filter_in_channels; ++ic) { for (int ic = 0; ic < filter_in_channels; ++ic) {
float32x4_t n_filter_v[3] = {vld1q_f32(filter_ptr), vld1q_f32(filter_ptr+3), vld1q_f32(filter_ptr+6)}; float32x4_t n_filter_v[3] = {vld1q_f32(filter_ptr),
vld1q_f32(filter_ptr + 3),
vld1q_f32(filter_ptr + 6)};
const float *row_ptr_v[3] = { const float *row_ptr_v[3] = {input_ptr, input_ptr + input_width,
input_ptr, input_ptr + input_width, input_ptr + 2 * input_width input_ptr + 2 * input_width};
};
float *output_ptr_inner = output_ptr; float *output_ptr_inner = output_ptr;
...@@ -259,24 +269,33 @@ void Conv2dNeonK3x3S2(const float *input, // NCHW ...@@ -259,24 +269,33 @@ void Conv2dNeonK3x3S2(const float *input, // NCHW
float32x4x2_t n_row_former = vld2q_f32(row_ptr_v[0]); float32x4x2_t n_row_former = vld2q_f32(row_ptr_v[0]);
float32x4_t n_row_latter = vld1q_f32(row_ptr_v[0] + 8); float32x4_t n_row_latter = vld1q_f32(row_ptr_v[0] + 8);
float32x4_t n_row_ext = vextq_f32(n_row_former.val[0], n_row_latter, 1); float32x4_t n_row_ext =
vextq_f32(n_row_former.val[0], n_row_latter, 1);
n_sum = vfmaq_laneq_f32(n_sum, n_row_former.val[0], n_filter_v[0], 0); n_sum =
n_sum = vfmaq_laneq_f32(n_sum, n_row_former.val[1], n_filter_v[0], 1); vfmaq_laneq_f32(n_sum, n_row_former.val[0], n_filter_v[0], 0);
n_sum =
vfmaq_laneq_f32(n_sum, n_row_former.val[1], n_filter_v[0], 1);
n_sum = vfmaq_laneq_f32(n_sum, n_row_ext, n_filter_v[0], 2); n_sum = vfmaq_laneq_f32(n_sum, n_row_ext, n_filter_v[0], 2);
float32x4x2_t n_row1_former = vld2q_f32(row_ptr_v[1]); float32x4x2_t n_row1_former = vld2q_f32(row_ptr_v[1]);
float32x4_t n_row1_latter = vld1q_f32(row_ptr_v[1] + 8); float32x4_t n_row1_latter = vld1q_f32(row_ptr_v[1] + 8);
float32x4_t n_row1_ext = vextq_f32(n_row1_former.val[0], n_row1_latter, 1); float32x4_t n_row1_ext =
n_sum = vfmaq_laneq_f32(n_sum, n_row1_former.val[0], n_filter_v[1], 0); vextq_f32(n_row1_former.val[0], n_row1_latter, 1);
n_sum = vfmaq_laneq_f32(n_sum, n_row1_former.val[1], n_filter_v[1], 1); n_sum =
vfmaq_laneq_f32(n_sum, n_row1_former.val[0], n_filter_v[1], 0);
n_sum =
vfmaq_laneq_f32(n_sum, n_row1_former.val[1], n_filter_v[1], 1);
n_sum = vfmaq_laneq_f32(n_sum, n_row1_ext, n_filter_v[1], 2); n_sum = vfmaq_laneq_f32(n_sum, n_row1_ext, n_filter_v[1], 2);
float32x4x2_t n_row2_former = vld2q_f32(row_ptr_v[2]); float32x4x2_t n_row2_former = vld2q_f32(row_ptr_v[2]);
float32x4_t n_row2_latter = vld1q_f32(row_ptr_v[2] + 8); float32x4_t n_row2_latter = vld1q_f32(row_ptr_v[2] + 8);
float32x4_t n_row2_ext = vextq_f32(n_row2_former.val[0], n_row2_latter, 1); float32x4_t n_row2_ext =
n_sum = vfmaq_laneq_f32(n_sum, n_row2_former.val[0], n_filter_v[2], 0); vextq_f32(n_row2_former.val[0], n_row2_latter, 1);
n_sum = vfmaq_laneq_f32(n_sum, n_row2_former.val[1], n_filter_v[2], 1); n_sum =
vfmaq_laneq_f32(n_sum, n_row2_former.val[0], n_filter_v[2], 0);
n_sum =
vfmaq_laneq_f32(n_sum, n_row2_former.val[1], n_filter_v[2], 1);
n_sum = vfmaq_laneq_f32(n_sum, n_row2_ext, n_filter_v[2], 2); n_sum = vfmaq_laneq_f32(n_sum, n_row2_ext, n_filter_v[2], 2);
float32x4_t n_output_row = vld1q_f32(output_ptr_inner); float32x4_t n_output_row = vld1q_f32(output_ptr_inner);
...@@ -288,11 +307,9 @@ void Conv2dNeonK3x3S2(const float *input, // NCHW ...@@ -288,11 +307,9 @@ void Conv2dNeonK3x3S2(const float *input, // NCHW
} }
} }
for (; remain_count > 0; --remain_count) { for (; remain_count > 0; --remain_count) {
float32x4_t n_row_v[] = { float32x4_t n_row_v[] = {vld1q_f32(row_ptr_v[0]),
vld1q_f32(row_ptr_v[0]),
vld1q_f32(row_ptr_v[1]), vld1q_f32(row_ptr_v[1]),
vld1q_f32(row_ptr_v[2]) vld1q_f32(row_ptr_v[2])};
};
float32x4_t n_sum = vmulq_f32(n_row_v[0], n_filter_v[0]); float32x4_t n_sum = vmulq_f32(n_row_v[0], n_filter_v[0]);
n_sum = vmlaq_f32(n_sum, n_row_v[1], n_filter_v[1]); n_sum = vmlaq_f32(n_sum, n_row_v[1], n_filter_v[1]);
n_sum = vmlaq_f32(n_sum, n_row_v[2], n_filter_v[2]); n_sum = vmlaq_f32(n_sum, n_row_v[2], n_filter_v[2]);
......
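The 3x3 stride-1 kernel above builds its shifted input windows with vextq_f32: two adjacent loads cover eight pixels, and extracting at offsets 1 and 2 produces the in[i+1] and in[i+2] windows without further loads. A minimal sketch of one filter row applied to four adjacent outputs; accumulating three such rows gives the full 3x3 result:

#include <arm_neon.h>

// out[i] = in[i]*k[0] + in[i+1]*k[1] + in[i+2]*k[2] for four adjacent i.
static inline float32x4_t Conv1x3Row(const float *in, const float *k) {
  float32x4_t a = vld1q_f32(in);        // in[0..3]
  float32x4_t b = vld1q_f32(in + 4);    // in[4..7]
  float32x4_t s1 = vextq_f32(a, b, 1);  // in[1..4]
  float32x4_t s2 = vextq_f32(a, b, 2);  // in[2..5]
  float32x4_t sum = vmulq_n_f32(a, k[0]);
  sum = vmlaq_n_f32(sum, s1, k[1]);
  sum = vmlaq_n_f32(sum, s2, k[2]);
  return sum;
}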
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
// Copyright (c) 2017 XiaoMi All rights reserved. // Copyright (c) 2017 XiaoMi All rights reserved.
// //
#include "mace/kernels/depthwise_conv2d.h"
#include "mace/kernels/conv_2d.h" #include "mace/kernels/conv_2d.h"
#include "mace/kernels/depthwise_conv2d.h"
namespace mace { namespace mace {
namespace kernels { namespace kernels {
...@@ -24,8 +24,9 @@ extern void Conv2dNeonK3x3S2(const float *input, ...@@ -24,8 +24,9 @@ extern void Conv2dNeonK3x3S2(const float *input,
float *output, float *output,
const index_t *output_shape); const index_t *output_shape);
template<> template <>
void DepthwiseConv2dFunctor<DeviceType::NEON, float>::operator()(const float *input, // NCHW void DepthwiseConv2dFunctor<DeviceType::NEON, float>::operator()(
const float *input, // NCHW
const index_t *input_shape, const index_t *input_shape,
const float *filter, // c_out, c_in, kernel_h, kernel_w const float *filter, // c_out, c_in, kernel_h, kernel_w
const index_t *filter_shape, const index_t *filter_shape,
...@@ -33,12 +34,8 @@ void DepthwiseConv2dFunctor<DeviceType::NEON, float>::operator()(const float *in ...@@ -33,12 +34,8 @@ void DepthwiseConv2dFunctor<DeviceType::NEON, float>::operator()(const float *in
float *output, // NCHW float *output, // NCHW
const index_t *output_shape) { const index_t *output_shape) {
typedef void (*Conv2dNeonFunction)( typedef void (*Conv2dNeonFunction)(
const float *input, const float *input, const index_t *input_shape, const float *filter,
const index_t *input_shape, const index_t *filter_shape, const float *bias, float *output,
const float *filter,
const index_t *filter_shape,
const float *bias,
float *output,
const index_t *output_shape); const index_t *output_shape);
// Selection matrix: kernel_size x stride_size // Selection matrix: kernel_size x stride_size
static const Conv2dNeonFunction selector[5][2] = { static const Conv2dNeonFunction selector[5][2] = {
...@@ -57,7 +54,8 @@ void DepthwiseConv2dFunctor<DeviceType::NEON, float>::operator()(const float *in ...@@ -57,7 +54,8 @@ void DepthwiseConv2dFunctor<DeviceType::NEON, float>::operator()(const float *in
<< "filter" << kernel_h << "x" << kernel_w << "," << "filter" << kernel_h << "x" << kernel_w << ","
<< " stride " << strides_[0] << "x" << strides_[1] << " stride " << strides_[0] << "x" << strides_[1]
<< " is not implemented yet, using slow version"; << " is not implemented yet, using slow version";
DepthwiseConv2dFunctor<DeviceType::CPU, float>(strides_, paddings_, dilations_)( DepthwiseConv2dFunctor<DeviceType::CPU, float>(strides_, paddings_,
dilations_)(
input, input_shape, filter, filter_shape, bias, output, output_shape); input, input_shape, filter, filter_shape, bias, output, output_shape);
return; return;
} }
...@@ -65,12 +63,14 @@ void DepthwiseConv2dFunctor<DeviceType::NEON, float>::operator()(const float *in ...@@ -65,12 +63,14 @@ void DepthwiseConv2dFunctor<DeviceType::NEON, float>::operator()(const float *in
// Keep this alive during kernel execution // Keep this alive during kernel execution
Tensor padded_input; Tensor padded_input;
if (paddings_[0] > 0 || paddings_[1] > 0) { if (paddings_[0] > 0 || paddings_[1] > 0) {
ConstructInputWithPadding(input, input_shape, paddings_.data(), &padded_input); ConstructInputWithPadding(input, input_shape, paddings_.data(),
&padded_input);
input = padded_input.data<float>(); input = padded_input.data<float>();
input_shape = padded_input.shape().data(); input_shape = padded_input.shape().data();
} }
auto conv2d_neon_func = selector[kernel_h - 1][strides_[0] - 1]; auto conv2d_neon_func = selector[kernel_h - 1][strides_[0] - 1];
conv2d_neon_func(input, input_shape, filter, filter_shape, bias, output, output_shape); conv2d_neon_func(input, input_shape, filter, filter_shape, bias, output,
output_shape);
} }
} // namespace kernels } // namespace kernels
......
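Two details of the depthwise specialization above are easy to miss: unsupported kernel/stride pairs fall back to the generic CPU functor, and output channel c reads only input channel c / multiplier. A deliberately simplified sketch of that channel mapping, with one value per channel so the indexing stays visible:

// Depthwise bookkeeping: with `multiplier` filters per input channel,
// output channel c convolves only input channel c / multiplier.
void DepthwiseChannelMap(const float *input, const float *filter,
                         int input_channels, int multiplier, float *output) {
  const int output_channels = input_channels * multiplier;
  for (int c = 0; c < output_channels; ++c) {
    output[c] = input[c / multiplier] * filter[c];
  }
}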
...@@ -8,11 +8,9 @@ ...@@ -8,11 +8,9 @@
namespace mace { namespace mace {
namespace kernels { namespace kernels {
template<> template <>
void GlobalAvgPoolingFunctor<DeviceType::NEON, float>::operator()( void GlobalAvgPoolingFunctor<DeviceType::NEON, float>::operator()(
const float *input, const float *input, const index_t *input_shape, float *output) {
const index_t *input_shape,
float *output) {
index_t batch = input_shape[0]; index_t batch = input_shape[0];
index_t channels = input_shape[1]; index_t channels = input_shape[1];
index_t height = input_shape[2]; index_t height = input_shape[2];
......
...@@ -55,7 +55,7 @@ extern void PoolingAvgNeonK3x3S2x2Padded(const float *input, ...@@ -55,7 +55,7 @@ extern void PoolingAvgNeonK3x3S2x2Padded(const float *input,
const index_t *out_shape); const index_t *out_shape);
#endif #endif
template<> template <>
void PoolingFunctor<DeviceType::NEON, float>::operator()( void PoolingFunctor<DeviceType::NEON, float>::operator()(
const float *input, const float *input,
const index_t *input_shape, const index_t *input_shape,
......
...@@ -13,18 +13,18 @@ namespace mace { ...@@ -13,18 +13,18 @@ namespace mace {
template <DeviceType D, class T> template <DeviceType D, class T>
class AddNOp : public Operator<D, T> { class AddNOp : public Operator<D, T> {
public: public:
AddNOp(const OperatorDef& operator_def, Workspace* ws) AddNOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws) {} : Operator<D, T>(operator_def, ws) {}
bool Run() override { bool Run() override {
Tensor* output_tensor = this->outputs_[0]; Tensor *output_tensor = this->outputs_[0];
output_tensor->ResizeLike(this->inputs_[0]); output_tensor->ResizeLike(this->inputs_[0]);
T* output = output_tensor->mutable_data<T>(); T *output = output_tensor->mutable_data<T>();
index_t size = this->inputs_[0]->size(); index_t size = this->inputs_[0]->size();
int n = this->inputs_.size(); int n = this->inputs_.size();
vector<const T*> inputs(n); vector<const T *> inputs(n);
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
const Tensor* input_tensor = this->inputs_[i]; const Tensor *input_tensor = this->inputs_[i];
inputs[i] = input_tensor->data<T>(); inputs[i] = input_tensor->data<T>();
} }
......
...@@ -39,7 +39,7 @@ static void AddNBenchmark(int iters, int n, int size) { ...@@ -39,7 +39,7 @@ static void AddNBenchmark(int iters, int n, int size) {
static void BM_ADDN_##N##_##SIZE##_##TYPE##_##DEVICE(int iters) { \ static void BM_ADDN_##N##_##SIZE##_##TYPE##_##DEVICE(int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * SIZE; \ const int64_t tot = static_cast<int64_t>(iters) * N * SIZE; \
mace::testing::ItemsProcessed(tot); \ mace::testing::ItemsProcessed(tot); \
mace::testing::BytesProcessed(tot*(sizeof(TYPE))); \ mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
AddNBenchmark<DEVICE, TYPE>(iters, N, SIZE); \ AddNBenchmark<DEVICE, TYPE>(iters, N, SIZE); \
} \ } \
BENCHMARK(BM_ADDN_##N##_##SIZE##_##TYPE##_##DEVICE) BENCHMARK(BM_ADDN_##N##_##SIZE##_##TYPE##_##DEVICE)
......
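The BM_ADDN macro above token-pastes N and SIZE into a unique benchmark function name and hands it to BENCHMARK for registration. A stripped-down sketch of the pattern, using a hypothetical RegisterBenchmark hook in place of MACE's testing library:

#include <cstdio>

// Hypothetical stand-in for the real registration machinery.
static int RegisterBenchmark(const char *name, void (*fn)(int)) {
  std::printf("registered %s\n", name);
  (void)fn;
  return 0;
}

// ## pastes a unique function name per (N, SIZE); # builds the report name.
#define BM_ADD_MACRO(N, SIZE)                                  \
  static void BM_ADD_##N##_##SIZE(int iters) {                 \
    long tot = static_cast<long>(iters) * N * SIZE; /* work */ \
    (void)tot;                                                 \
  }                                                            \
  static int bm_add_##N##_##SIZE##_dummy =                     \
      RegisterBenchmark("BM_ADD_" #N "_" #SIZE, BM_ADD_##N##_##SIZE)

BM_ADD_MACRO(10, 1000);
BM_ADD_MACRO(100, 1000);

int main() { return 0; }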
...@@ -11,7 +11,7 @@ class AddnOpTest : public OpsTestBase {}; ...@@ -11,7 +11,7 @@ class AddnOpTest : public OpsTestBase {};
TEST_F(AddnOpTest, AddnOp) { TEST_F(AddnOpTest, AddnOp) {
// Construct graph // Construct graph
auto& net = test_net(); auto &net = test_net();
OpDefBuilder("AddN", "AddNTest") OpDefBuilder("AddN", "AddNTest")
.Input("Input1") .Input("Input1")
.Input("Input2") .Input("Input2")
......
...@@ -13,17 +13,16 @@ namespace mace { ...@@ -13,17 +13,16 @@ namespace mace {
template <DeviceType D, class T> template <DeviceType D, class T>
class BatchNormOp : public Operator<D, T> { class BatchNormOp : public Operator<D, T> {
public: public:
BatchNormOp(const OperatorDef& operator_def, Workspace* ws) BatchNormOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws), : Operator<D, T>(operator_def, ws), functor_() {}
functor_() {}
bool Run() override { bool Run() override {
const Tensor* input = this->Input(0); const Tensor *input = this->Input(0);
const Tensor* scale = this->Input(1); const Tensor *scale = this->Input(1);
const Tensor* offset = this->Input(2); const Tensor *offset = this->Input(2);
const Tensor* mean = this->Input(3); const Tensor *mean = this->Input(3);
const Tensor* var = this->Input(4); const Tensor *var = this->Input(4);
const Tensor* epsilon = this->Input(5); const Tensor *epsilon = this->Input(5);
MACE_CHECK(input->dim_size() == 4, "input must be 4-dimensional. ", MACE_CHECK(input->dim_size() == 4, "input must be 4-dimensional. ",
input->dim_size()); input->dim_size());
...@@ -38,23 +37,23 @@ class BatchNormOp : public Operator<D, T> { ...@@ -38,23 +37,23 @@ class BatchNormOp : public Operator<D, T> {
MACE_CHECK(epsilon->dim_size() == 0, "epsilon must be 0-dimensional. ", MACE_CHECK(epsilon->dim_size() == 0, "epsilon must be 0-dimensional. ",
epsilon->dim_size()); epsilon->dim_size());
Tensor* output = this->Output(0); Tensor *output = this->Output(0);
output->ResizeLike(input); output->ResizeLike(input);
const index_t n = input->dim(0); const index_t n = input->dim(0);
const index_t channel = input->dim(1); const index_t channel = input->dim(1);
const index_t sample_size = input->dim(2) * input->dim(3); const index_t sample_size = input->dim(2) * input->dim(3);
const T* input_ptr = input->data<T>(); const T *input_ptr = input->data<T>();
const T* scale_ptr = scale->data<T>(); const T *scale_ptr = scale->data<T>();
const T* offset_ptr = offset->data<T>(); const T *offset_ptr = offset->data<T>();
const T* mean_ptr = mean->data<T>(); const T *mean_ptr = mean->data<T>();
const T* var_ptr = var->data<T>(); const T *var_ptr = var->data<T>();
const T* epsilon_ptr = epsilon->data<T>(); const T *epsilon_ptr = epsilon->data<T>();
T* output_ptr = output->mutable_data<T>(); T *output_ptr = output->mutable_data<T>();
functor_(input_ptr, scale_ptr, offset_ptr, mean_ptr, var_ptr, *epsilon_ptr, n, channel, functor_(input_ptr, scale_ptr, offset_ptr, mean_ptr, var_ptr, *epsilon_ptr,
sample_size, output_ptr); n, channel, sample_size, output_ptr);
return true; return true;
} }
......
...@@ -47,7 +47,7 @@ static void BatchNorm( ...@@ -47,7 +47,7 @@ static void BatchNorm(
int iters) { \ int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \ const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::ItemsProcessed(tot); \ mace::testing::ItemsProcessed(tot); \
mace::testing::BytesProcessed(tot*(sizeof(TYPE))); \ mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
BatchNorm<DEVICE, TYPE>(iters, N, C, H, W); \ BatchNorm<DEVICE, TYPE>(iters, N, C, H, W); \
} \ } \
BENCHMARK(BM_BATCH_NORM_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE) BENCHMARK(BM_BATCH_NORM_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
......
...@@ -11,7 +11,7 @@ class BatchNormOpTest : public OpsTestBase {}; ...@@ -11,7 +11,7 @@ class BatchNormOpTest : public OpsTestBase {};
TEST_F(BatchNormOpTest, SimpleCPU) { TEST_F(BatchNormOpTest, SimpleCPU) {
// Construct graph // Construct graph
auto& net = test_net(); auto &net = test_net();
OpDefBuilder("BatchNorm", "BatchNormTest") OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("Input") .Input("Input")
.Input("Scale") .Input("Scale")
...@@ -51,7 +51,7 @@ TEST_F(BatchNormOpTest, SimpleNeon) { ...@@ -51,7 +51,7 @@ TEST_F(BatchNormOpTest, SimpleNeon) {
index_t height = 103; index_t height = 103;
index_t width = 113; index_t width = 113;
// Construct graph // Construct graph
auto& net = test_net(); auto &net = test_net();
OpDefBuilder("BatchNorm", "BatchNormTest") OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("Input") .Input("Input")
.Input("Scale") .Input("Scale")
...@@ -74,7 +74,7 @@ TEST_F(BatchNormOpTest, SimpleNeon) { ...@@ -74,7 +74,7 @@ TEST_F(BatchNormOpTest, SimpleNeon) {
net.RunOp(); net.RunOp();
// Check // Check
Tensor* expected = net.GetOutput("Output"); Tensor *expected = net.GetOutput("Output");
// Run NEON // Run NEON
net.RunOp(DeviceType::NEON); net.RunOp(DeviceType::NEON);
......
...@@ -12,10 +12,10 @@ ...@@ -12,10 +12,10 @@
namespace mace { namespace mace {
template<DeviceType D, typename T> template <DeviceType D, typename T>
class ChannelShuffleOp : public Operator<D, T> { class ChannelShuffleOp : public Operator<D, T> {
public: public:
ChannelShuffleOp(const OperatorDef& operator_def, Workspace* ws) ChannelShuffleOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws), : Operator<D, T>(operator_def, ws),
group_(OperatorBase::GetSingleArgument<int>("group", 1)), group_(OperatorBase::GetSingleArgument<int>("group", 1)),
functor_(this->group_) {} functor_(this->group_) {}
......
...@@ -11,12 +11,8 @@ using namespace mace; ...@@ -11,12 +11,8 @@ using namespace mace;
using namespace mace::kernels; using namespace mace::kernels;
template <DeviceType D> template <DeviceType D>
static void ChannelShuffle(int iters, static void ChannelShuffle(
int batch, int iters, int batch, int channels, int height, int width, int group) {
int channels,
int height,
int width,
int group) {
mace::testing::StopTiming(); mace::testing::StopTiming();
OpsTestNet net; OpsTestNet net;
...@@ -41,12 +37,11 @@ static void ChannelShuffle(int iters, ...@@ -41,12 +37,11 @@ static void ChannelShuffle(int iters,
} }
#define BM_CHANNEL_SHUFFLE_MACRO(N, C, H, W, G, DEVICE) \ #define BM_CHANNEL_SHUFFLE_MACRO(N, C, H, W, G, DEVICE) \
static void \ static void BM_CHANNEL_SHUFFLE_##N##_##C##_##H##_##W##_##G##_##DEVICE( \
BM_CHANNEL_SHUFFLE_##N##_##C##_##H##_##W##_##G##_##DEVICE( \
int iters) { \ int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \ const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::ItemsProcessed(tot); \ mace::testing::ItemsProcessed(tot); \
mace::testing::BytesProcessed(tot*(sizeof(float))); \ mace::testing::BytesProcessed(tot *(sizeof(float))); \
ChannelShuffle<DEVICE>(iters, N, C, H, W, G); \ ChannelShuffle<DEVICE>(iters, N, C, H, W, G); \
} \ } \
BENCHMARK(BM_CHANNEL_SHUFFLE_##N##_##C##_##H##_##W##_##G##_##DEVICE) BENCHMARK(BM_CHANNEL_SHUFFLE_##N##_##C##_##H##_##W##_##G##_##DEVICE)
......
...@@ -10,7 +10,7 @@ class ChannelShuffleOpTest : public OpsTestBase {}; ...@@ -10,7 +10,7 @@ class ChannelShuffleOpTest : public OpsTestBase {};
TEST_F(ChannelShuffleOpTest, C8G4) { TEST_F(ChannelShuffleOpTest, C8G4) {
// Construct graph // Construct graph
auto& net = test_net(); auto &net = test_net();
OpDefBuilder("ChannelShuffle", "ChannelShuffleTest") OpDefBuilder("ChannelShuffle", "ChannelShuffleTest")
.Input("Input") .Input("Input")
.Output("Output") .Output("Output")
...@@ -21,17 +21,14 @@ TEST_F(ChannelShuffleOpTest, C8G4) { ...@@ -21,17 +21,14 @@ TEST_F(ChannelShuffleOpTest, C8G4) {
// Add input data // Add input data
net.AddInputFromArray<float>( net.AddInputFromArray<float>(
"Input", {1, 8, 1, 2}, "Input", {1, 8, 1, 2},
{0, 1, 2, 3, 4, 5, 6, 7, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
8, 9, 10, 11, 12, 13, 14, 15});
// Run // Run
net.RunOp(); net.RunOp();
// Check // Check
auto expected = auto expected = CreateTensor<float>(
CreateTensor<float>({1, 8, 1, 2}, {1, 8, 1, 2}, {0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15});
{0, 1, 4, 5, 8, 9, 12, 13,
2, 3, 6, 7, 10, 11, 14, 15});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001); ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
} }
......
...@@ -5,12 +5,12 @@ ...@@ -5,12 +5,12 @@
#ifndef MACE_OPS_CONCAT_H_ #ifndef MACE_OPS_CONCAT_H_
#define MACE_OPS_CONCAT_H_ #define MACE_OPS_CONCAT_H_
#include "mace/proto/mace.pb.h"
#include "mace/core/operator.h" #include "mace/core/operator.h"
#include "mace/kernels/concat.h" #include "mace/kernels/concat.h"
#include "mace/proto/mace.pb.h"
namespace mace { namespace mace {
template<DeviceType D, typename T> template <DeviceType D, typename T>
class ConcatOp : public Operator<D, T> { class ConcatOp : public Operator<D, T> {
public: public:
ConcatOp(const OperatorDef &op_def, Workspace *ws) ConcatOp(const OperatorDef &op_def, Workspace *ws)
...@@ -25,9 +25,11 @@ class ConcatOp : public Operator<D, T> { ...@@ -25,9 +25,11 @@ class ConcatOp : public Operator<D, T> {
axis_tensor->dim_size()); axis_tensor->dim_size());
const int32_t concat_axis = *(axis_tensor->data<int32_t>()); const int32_t concat_axis = *(axis_tensor->data<int32_t>());
const int32_t input_dims = input0->dim_size(); const int32_t input_dims = input0->dim_size();
const int32_t axis = concat_axis < 0 ? concat_axis + input_dims : concat_axis; const int32_t axis =
MACE_CHECK((0 <= axis && axis < input_dims), "Expected concatenating axis in the range [", concat_axis < 0 ? concat_axis + input_dims : concat_axis;
-input_dims, ", ", input_dims, "], but got", concat_axis); MACE_CHECK((0 <= axis && axis < input_dims),
"Expected concatenating axis in the range [", -input_dims, ", ",
input_dims, "], but got", concat_axis);
std::vector<index_t> output_shape(input0->shape()); std::vector<index_t> output_shape(input0->shape());
index_t inner_size = 1; index_t inner_size = 1;
for (int i = 0; i < axis; ++i) { for (int i = 0; i < axis; ++i) {
...@@ -40,10 +42,14 @@ class ConcatOp : public Operator<D, T> { ...@@ -40,10 +42,14 @@ class ConcatOp : public Operator<D, T> {
const Tensor *input = nullptr; const Tensor *input = nullptr;
for (int i = 1; i < values_count; ++i) { for (int i = 1; i < values_count; ++i) {
input = this->Input(i); input = this->Input(i);
MACE_CHECK(input->dim_size() == input0->dim_size(), "Ranks of all input tensors must be the same."); MACE_CHECK(input->dim_size() == input0->dim_size(),
"Ranks of all input tensors must be the same.");
for (int j = 0; j < axis_tensor->dim_size(); ++j) { for (int j = 0; j < axis_tensor->dim_size(); ++j) {
if (j == axis) { continue; } if (j == axis) {
MACE_CHECK(input->dim(j) == input0->dim(j), "Dimensions of inputs should be equal except axis."); continue;
}
MACE_CHECK(input->dim(j) == input0->dim(j),
"Dimensions of inputs should equal except axis.");
} }
input_list[i] = input->data<T>(); input_list[i] = input->data<T>();
outer_sizes[i] = input->size() / inner_size; outer_sizes[i] = input->size() / inner_size;
...@@ -53,9 +59,11 @@ class ConcatOp : public Operator<D, T> { ...@@ -53,9 +59,11 @@ class ConcatOp : public Operator<D, T> {
Tensor *output = this->Output(OUTPUT); Tensor *output = this->Output(OUTPUT);
output->Resize(output_shape); output->Resize(output_shape);
functor_(input_list, inner_size, outer_sizes.data(), output->mutable_data<T>()); functor_(input_list, inner_size, outer_sizes.data(),
output->mutable_data<T>());
return true; return true;
} }
private: private:
kernels::ConcatFunctor<D, T> functor_; kernels::ConcatFunctor<D, T> functor_;
......
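ConcatOp above normalizes a negative axis by adding the rank, then factors each input's element count into an inner size (the product of dimensions before the axis, shared by all inputs) and per-input outer sizes (everything from the axis on). A compact sketch of that bookkeeping:

#include <functional>
#include <numeric>
#include <vector>

// Map a possibly negative axis into [0, ndim), as the op does.
int NormalizeAxis(int axis, int ndim) {
  return axis < 0 ? axis + ndim : axis;
}

// inner = product of dims before `axis`; outer = product from `axis` on.
void SplitSizes(const std::vector<long> &shape, int axis, long *inner,
                long *outer) {
  *inner = std::accumulate(shape.begin(), shape.begin() + axis, 1L,
                           std::multiplies<long>());
  *outer = std::accumulate(shape.begin() + axis, shape.end(), 1L,
                           std::multiplies<long>());
}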
...@@ -7,9 +7,8 @@ ...@@ -7,9 +7,8 @@
#include "mace/ops/ops_test_util.h" #include "mace/ops/ops_test_util.h"
namespace mace { namespace mace {
template<DeviceType D, typename T> template <DeviceType D, typename T>
static void ConcatHelper( static void ConcatHelper(int iters, int concat_dim, int dim1) {
int iters, int concat_dim, int dim1) {
mace::testing::StopTiming(); mace::testing::StopTiming();
OpsTestNet net; OpsTestNet net;
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
// //
#include "mace/ops/concat.h" #include "mace/ops/concat.h"
#include "mace/ops/ops_test_util.h"
#include "gmock/gmock.h" #include "gmock/gmock.h"
#include "mace/ops/ops_test_util.h"
using namespace mace; using namespace mace;
...@@ -99,9 +99,7 @@ TEST_F(ConcatOpTest, Random) { ...@@ -99,9 +99,7 @@ TEST_F(ConcatOpTest, Random) {
for (int i = 0; i < num_inputs; ++i) { for (int i = 0; i < num_inputs; ++i) {
builder = builder.Input(("Input" + ToString(i)).c_str()); builder = builder.Input(("Input" + ToString(i)).c_str());
} }
builder.Input("Axis") builder.Input("Axis").Output("Output").Finalize(net.operator_def());
.Output("Output")
.Finalize(net.operator_def());
std::vector<index_t> shape_data; std::vector<index_t> shape_data;
GenerateRandomIntTypeData<index_t>({dim}, shape_data, 1, dim); GenerateRandomIntTypeData<index_t>({dim}, shape_data, 1, dim);
...@@ -114,7 +112,8 @@ TEST_F(ConcatOpTest, Random) { ...@@ -114,7 +112,8 @@ TEST_F(ConcatOpTest, Random) {
concat_axis_size += input_shapes[i][axis]; concat_axis_size += input_shapes[i][axis];
GenerateRandomRealTypeData(input_shapes[i], inputs[i]); GenerateRandomRealTypeData(input_shapes[i], inputs[i]);
input_ptrs[i] = inputs[i].data(); input_ptrs[i] = inputs[i].data();
net.AddInputFromArray<float>(("Input" + ToString(i)).c_str(), input_shapes[i], inputs[i]); net.AddInputFromArray<float>(("Input" + ToString(i)).c_str(),
input_shapes[i], inputs[i]);
} }
net.AddInputFromArray<int>("Axis", {}, {axis}); net.AddInputFromArray<int>("Axis", {}, {axis});
...@@ -131,9 +130,9 @@ TEST_F(ConcatOpTest, Random) { ...@@ -131,9 +130,9 @@ TEST_F(ConcatOpTest, Random) {
const float *output_ptr = output->data<float>(); const float *output_ptr = output->data<float>();
while (output_ptr != (output->data<float>() + output->size())) { while (output_ptr != (output->data<float>() + output->size())) {
for (int i = 0; i < num_inputs; ++i) { for (int i = 0; i < num_inputs; ++i) {
index_t num_elements = std::accumulate(input_shapes[i].begin() + axis, index_t num_elements =
input_shapes[i].end(), 1, std::accumulate(input_shapes[i].begin() + axis, input_shapes[i].end(),
std::multiplies<index_t>()); 1, std::multiplies<index_t>());
for (int j = 0; j < num_elements; ++j) { for (int j = 0; j < num_elements; ++j) {
EXPECT_EQ(*input_ptrs[i]++, *output_ptr++); EXPECT_EQ(*input_ptrs[i]++, *output_ptr++);
} }
......
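The check above walks the output pointer and, for each input, skips ahead by the number of elements one slice of that input contributes: the product of its dimensions from the concat axis onward. A standalone sketch of that block-size arithmetic, with plain vectors in place of MACE tensors:

#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>
using index_t = int64_t;

int main() {
  const std::vector<index_t> shape{2, 3, 4};
  const int axis = 1;
  // Same accumulate call as in the test: product of dims [axis, end).
  const index_t num_elements =
      std::accumulate(shape.begin() + axis, shape.end(),
                      static_cast<index_t>(1), std::multiplies<index_t>());
  std::cout << num_elements << "\n";  // 3 * 4 = 12 contiguous values per block
  return 0;
}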
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
namespace mace { namespace mace {
template<DeviceType D, typename T> template <DeviceType D, typename T>
class Conv2dOp : public ConvPool2dOpBase<D, T> { class Conv2dOp : public ConvPool2dOpBase<D, T> {
public: public:
Conv2dOp(const OperatorDef &op_def, Workspace *ws) Conv2dOp(const OperatorDef &op_def, Workspace *ws)
...@@ -35,11 +35,10 @@ class Conv2dOp : public ConvPool2dOpBase<D, T> { ...@@ -35,11 +35,10 @@ class Conv2dOp : public ConvPool2dOpBase<D, T> {
std::vector<index_t> output_shape(4); std::vector<index_t> output_shape(4);
std::vector<int> paddings(2); std::vector<int> paddings(2);
kernels::CalcPaddingAndOutputSize(input->shape().data(), kernels::CalcPaddingAndOutputSize(
filter->shape().data(), input->shape().data(), filter->shape().data(), this->dilations_.data(),
this->dilations_.data(), this->strides_.data(), this->padding_, output_shape.data(),
this->strides_.data(), this->padding_, paddings.data());
output_shape.data(), paddings.data());
output->Resize(output_shape); output->Resize(output_shape);
functor_.paddings_ = paddings; functor_.paddings_ = paddings;
......
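CalcPaddingAndOutputSize itself is not part of this diff; assuming it implements the standard convolution geometry, the output extent along one spatial dimension follows from the dilated kernel size. A hedged sketch under that assumption, with the hypothetical name ConvOutputSize:

#include <cstdint>
#include <iostream>

int64_t ConvOutputSize(int64_t in, int64_t k, int64_t dilation, int64_t stride,
                       int64_t pad) {
  const int64_t effective_k = dilation * (k - 1) + 1;  // dilated kernel extent
  return (in + 2 * pad - effective_k) / stride + 1;
}

int main() {
  // 32-wide input, 3x3 kernel, stride 1: pad 0 (VALID) -> 30, pad 1 -> 32.
  std::cout << ConvOutputSize(32, 3, 1, 1, 0) << " "
            << ConvOutputSize(32, 3, 1, 1, 1) << "\n";
  return 0;
}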
...@@ -60,8 +60,9 @@ static void Conv2d(int iters, ...@@ -60,8 +60,9 @@ static void Conv2d(int iters,
int iters) { \ int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \ const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::ItemsProcessed(tot); \ mace::testing::ItemsProcessed(tot); \
mace::testing::BytesProcessed(tot*(sizeof(TYPE))); \ mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
Conv2d<DEVICE, TYPE>(iters, N, C, H, W, KH, KW, STRIDE, mace::Padding::P, OC); \ Conv2d<DEVICE, TYPE>(iters, N, C, H, W, KH, KW, STRIDE, mace::Padding::P, \
OC); \
} \ } \
BENCHMARK( \ BENCHMARK( \
BM_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_##OC##_##TYPE##_##DEVICE) BM_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_##OC##_##TYPE##_##DEVICE)
......
...@@ -12,7 +12,7 @@ class Conv2dOpTest : public OpsTestBase {}; ...@@ -12,7 +12,7 @@ class Conv2dOpTest : public OpsTestBase {};
TEST_F(Conv2dOpTest, Simple_VALID) { TEST_F(Conv2dOpTest, Simple_VALID) {
// Construct graph // Construct graph
auto& net = test_net(); auto &net = test_net();
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
.Input("Input") .Input("Input")
.Input("Filter") .Input("Filter")
...@@ -46,7 +46,7 @@ TEST_F(Conv2dOpTest, Simple_VALID) { ...@@ -46,7 +46,7 @@ TEST_F(Conv2dOpTest, Simple_VALID) {
TEST_F(Conv2dOpTest, Simple_SAME) { TEST_F(Conv2dOpTest, Simple_SAME) {
// Construct graph // Construct graph
auto& net = test_net(); auto &net = test_net();
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
.Input("Input") .Input("Input")
.Input("Filter") .Input("Filter")
...@@ -82,7 +82,7 @@ TEST_F(Conv2dOpTest, Simple_SAME) { ...@@ -82,7 +82,7 @@ TEST_F(Conv2dOpTest, Simple_SAME) {
TEST_F(Conv2dOpTest, Combined) { TEST_F(Conv2dOpTest, Combined) {
// Construct graph // Construct graph
auto& net = test_net(); auto &net = test_net();
OpDefBuilder("Conv2D", "Conv2DTest") OpDefBuilder("Conv2D", "Conv2DTest")
.Input("Input") .Input("Input")
.Input("Filter") .Input("Filter")
...@@ -120,7 +120,7 @@ TEST_F(Conv2dOpTest, Combined) { ...@@ -120,7 +120,7 @@ TEST_F(Conv2dOpTest, Combined) {
TEST_F(Conv2dOpTest, Conv1x1) { TEST_F(Conv2dOpTest, Conv1x1) {
// Construct graph // Construct graph
auto& net = test_net(); auto &net = test_net();
OpDefBuilder("Conv2D", "Conv2DTest") OpDefBuilder("Conv2D", "Conv2DTest")
.Input("Input") .Input("Input")
.Input("Filter") .Input("Filter")
...@@ -172,13 +172,13 @@ TEST_F(Conv2dOpTest, IdleConvNxNS12) { ...@@ -172,13 +172,13 @@ TEST_F(Conv2dOpTest, IdleConvNxNS12) {
srand(time(NULL)); srand(time(NULL));
// generate random input // generate random input
index_t batch = 3 ; index_t batch = 3;
index_t input_channels = 64; index_t input_channels = 64;
index_t height = 32; index_t height = 32;
index_t width = 32; index_t width = 32;
index_t output_channels = 128; index_t output_channels = 128;
// Construct graph // Construct graph
auto& net = test_net(); auto &net = test_net();
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
.Input("Input") .Input("Input")
.Input("Filter") .Input("Filter")
...@@ -229,7 +229,7 @@ TEST_F(Conv2dOpTest, DisgustConvNxNS12) { ...@@ -229,7 +229,7 @@ TEST_F(Conv2dOpTest, DisgustConvNxNS12) {
index_t width = 113; index_t width = 113;
index_t output_channels = 3 + rand() % 10; index_t output_channels = 3 + rand() % 10;
// Construct graph // Construct graph
auto& net = test_net(); auto &net = test_net();
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
.Input("Input") .Input("Input")
.Input("Filter") .Input("Filter")
......
...@@ -13,12 +13,13 @@ namespace mace { ...@@ -13,12 +13,13 @@ namespace mace {
template <DeviceType D, class T> template <DeviceType D, class T>
class ConvPool2dOpBase : public Operator<D, T> { class ConvPool2dOpBase : public Operator<D, T> {
public: public:
ConvPool2dOpBase(const OperatorDef& op_def, Workspace* ws) ConvPool2dOpBase(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws), : Operator<D, T>(op_def, ws),
strides_(OperatorBase::GetRepeatedArgument<int>("strides")), strides_(OperatorBase::GetRepeatedArgument<int>("strides")),
padding_(static_cast<Padding>(OperatorBase::GetSingleArgument<int>( padding_(static_cast<Padding>(OperatorBase::GetSingleArgument<int>(
"padding", static_cast<int>(SAME)))), "padding", static_cast<int>(SAME)))),
dilations_(OperatorBase::GetRepeatedArgument<int>("dilations", {1, 1})) {} dilations_(
OperatorBase::GetRepeatedArgument<int>("dilations", {1, 1})) {}
protected: protected:
std::vector<int> strides_; std::vector<int> strides_;
......
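The base class above pulls strides, padding, and dilations out of the OperatorDef, with {1, 1} as the fallback when "dilations" is absent. A minimal sketch of such a defaulted lookup, using a plain map in place of the real protobuf-backed OperatorBase::GetRepeatedArgument:

#include <iostream>
#include <map>
#include <string>
#include <vector>

std::vector<int> GetRepeatedArgument(
    const std::map<std::string, std::vector<int>> &args,
    const std::string &name, const std::vector<int> &default_value) {
  const auto it = args.find(name);
  return it == args.end() ? default_value : it->second;
}

int main() {
  std::map<std::string, std::vector<int>> args{{"strides", {2, 2}}};
  // "dilations" is missing, so the caller sees {1, 1}, as in the ctor above.
  for (int d : GetRepeatedArgument(args, "dilations", {1, 1}))
    std::cout << d << " ";
  std::cout << "\n";
  return 0;
}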
...@@ -6,10 +6,12 @@ ...@@ -6,10 +6,12 @@
namespace mace { namespace mace {
REGISTER_CPU_OPERATOR(DepthwiseConv2d, DepthwiseConv2dOp<DeviceType::CPU, float>); REGISTER_CPU_OPERATOR(DepthwiseConv2d,
DepthwiseConv2dOp<DeviceType::CPU, float>);
#if __ARM_NEON #if __ARM_NEON
REGISTER_NEON_OPERATOR(DepthwiseConv2d, DepthwiseConv2dOp<DeviceType::NEON, float>); REGISTER_NEON_OPERATOR(DepthwiseConv2d,
DepthwiseConv2dOp<DeviceType::NEON, float>);
#endif // __ARM_NEON #endif // __ARM_NEON
} // namespace mace } // namespace mace
...@@ -9,12 +9,12 @@ ...@@ -9,12 +9,12 @@
#include "mace/core/operator.h" #include "mace/core/operator.h"
#include "mace/kernels/conv_2d.h" #include "mace/kernels/conv_2d.h"
#include "mace/ops/conv_pool_2d_base.h"
#include "mace/kernels/depthwise_conv2d.h" #include "mace/kernels/depthwise_conv2d.h"
#include "mace/ops/conv_pool_2d_base.h"
namespace mace { namespace mace {
template<DeviceType D, typename T> template <DeviceType D, typename T>
class DepthwiseConv2dOp : public ConvPool2dOpBase<D, T> { class DepthwiseConv2dOp : public ConvPool2dOpBase<D, T> {
public: public:
DepthwiseConv2dOp(const OperatorDef &op_def, Workspace *ws) DepthwiseConv2dOp(const OperatorDef &op_def, Workspace *ws)
...@@ -34,16 +34,16 @@ class DepthwiseConv2dOp : public ConvPool2dOpBase<D, T> { ...@@ -34,16 +34,16 @@ class DepthwiseConv2dOp : public ConvPool2dOpBase<D, T> {
Tensor *output = this->Output(OUTPUT); Tensor *output = this->Output(OUTPUT);
// resize filter shape. // resize filter shape.
std::vector<index_t> filter_shape(filter->shape().begin(), filter->shape().end()); std::vector<index_t> filter_shape(filter->shape().begin(),
filter->shape().end());
filter_shape[0] *= filter_shape[1]; filter_shape[0] *= filter_shape[1];
filter_shape[1] = 1; filter_shape[1] = 1;
std::vector<index_t> output_shape(4); std::vector<index_t> output_shape(4);
std::vector<int> paddings(2); std::vector<int> paddings(2);
kernels::CalcPaddingAndOutputSize(input->shape().data(), kernels::CalcPaddingAndOutputSize(
filter_shape.data(), input->shape().data(), filter_shape.data(), this->dilations_.data(),
this->dilations_.data(), this->strides_.data(), this->padding_, output_shape.data(),
this->strides_.data(), this->padding_, paddings.data());
output_shape.data(), paddings.data());
output->Resize(output_shape); output->Resize(output_shape);
functor_.paddings_ = paddings; functor_.paddings_ = paddings;
......
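The reshape above turns a depthwise filter of shape {multiplier, in_channels, kh, kw} (the layout the test below constructs) into {multiplier * in_channels, 1, kh, kw}, so the shared padding/output-size code sees one single-channel kernel per output channel. A standalone sketch:

#include <cstdint>
#include <iostream>
#include <vector>
using index_t = int64_t;

std::vector<index_t> DepthwiseFilterShape(std::vector<index_t> shape) {
  shape[0] *= shape[1];  // fold input channels into the output-channel dim
  shape[1] = 1;          // each output channel now reads one input channel
  return shape;
}

int main() {
  for (index_t d : DepthwiseFilterShape({3, 16, 3, 3})) std::cout << d << " ";
  std::cout << "\n";  // prints: 48 1 3 3
  return 0;
}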
...@@ -12,7 +12,7 @@ class DepthwiseConv2dOpTest : public OpsTestBase {}; ...@@ -12,7 +12,7 @@ class DepthwiseConv2dOpTest : public OpsTestBase {};
TEST_F(DepthwiseConv2dOpTest, Simple_VALID) { TEST_F(DepthwiseConv2dOpTest, Simple_VALID) {
testing::internal::LogToStderr(); testing::internal::LogToStderr();
// Construct graph // Construct graph
auto& net = test_net(); auto &net = test_net();
OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest") OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
.Input("Input") .Input("Input")
.Input("Filter") .Input("Filter")
...@@ -26,23 +26,20 @@ TEST_F(DepthwiseConv2dOpTest, Simple_VALID) { ...@@ -26,23 +26,20 @@ TEST_F(DepthwiseConv2dOpTest, Simple_VALID) {
net.AddIntsArg("dilations", {1, 1}); net.AddIntsArg("dilations", {1, 1});
// Add input data // Add input data
net.AddInputFromArray<float>( net.AddInputFromArray<float>("Input", {1, 2, 2, 3},
"Input", {1, 2, 2, 3},
{1, 3, 5, 7, 9, 11, 2, 4, 6, 8, 10, 12}); {1, 3, 5, 7, 9, 11, 2, 4, 6, 8, 10, 12});
net.AddInputFromArray<float>( net.AddInputFromArray<float>(
"Filter", {2, 2, 2, 2}, "Filter", {2, 2, 2, 2},
{1.0f, 5.0f, 9.0f, 13.0f, {1.0f, 5.0f, 9.0f, 13.0f, 2.0f, 6.0f, 10.0f, 14.0f, 3.0f, 7.0f, 11.0f,
2.0f, 6.0f, 10.0f, 14.0f, 15.0f, 4.0f, 8.0f, 12.0f, 16.0f});
3.0f, 7.0f, 11.0f, 15.0f,
4.0f, 8.0f, 12.0f, 16.0f});
net.AddInputFromArray<float>("Bias", {4}, {.1f, .2f, .3f, .4f}); net.AddInputFromArray<float>("Bias", {4}, {.1f, .2f, .3f, .4f});
// Run // Run
net.RunOp(); net.RunOp();
// Check // Check
auto expected = CreateTensor<float>({1, 4, 1, 2}, auto expected = CreateTensor<float>(
{196.1f, 252.1f, 216.2f, 280.2f, {1, 4, 1, 2},
272.3f, 344.3f, 296.4f, 376.4f}); {196.1f, 252.1f, 216.2f, 280.2f, 272.3f, 344.3f, 296.4f, 376.4f});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5); ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
} }
...@@ -60,7 +57,7 @@ TEST_F(DepthwiseConv2dOpTest, ConvNxNS12) { ...@@ -60,7 +57,7 @@ TEST_F(DepthwiseConv2dOpTest, ConvNxNS12) {
index_t width = 113; index_t width = 113;
index_t multiplier = 3 + rand() % 10; index_t multiplier = 3 + rand() % 10;
// Construct graph // Construct graph
auto& net = test_net(); auto &net = test_net();
OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest") OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
.Input("Input") .Input("Input")
.Input("Filter") .Input("Filter")
...@@ -75,8 +72,8 @@ TEST_F(DepthwiseConv2dOpTest, ConvNxNS12) { ...@@ -75,8 +72,8 @@ TEST_F(DepthwiseConv2dOpTest, ConvNxNS12) {
// Add input data // Add input data
net.AddRandomInput<float>("Input", {batch, input_channels, height, width}); net.AddRandomInput<float>("Input", {batch, input_channels, height, width});
net.AddRandomInput<float>( net.AddRandomInput<float>("Filter",
"Filter", {multiplier, input_channels, kernel_h, kernel_w}); {multiplier, input_channels, kernel_h, kernel_w});
net.AddRandomInput<float>("Bias", {multiplier * input_channels}); net.AddRandomInput<float>("Bias", {multiplier * input_channels});
// run cpu // run cpu
net.RunOp(); net.RunOp();
......
...@@ -54,14 +54,16 @@ static void DepthwiseConv2d(int iters, ...@@ -54,14 +54,16 @@ static void DepthwiseConv2d(int iters,
} }
} }
#define BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, STRIDE, P, OC, TYPE, DEVICE) \ #define BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, STRIDE, P, OC, TYPE, \
DEVICE) \
static void \ static void \
BM_DEPTHWISE_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_##OC##_##TYPE##_##DEVICE( \ BM_DEPTHWISE_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_##OC##_##TYPE##_##DEVICE( \
int iters) { \ int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \ const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::ItemsProcessed(tot); \ mace::testing::ItemsProcessed(tot); \
mace::testing::BytesProcessed(tot*(sizeof(TYPE))); \ mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
DepthwiseConv2d<DEVICE, TYPE>(iters, N, C, H, W, KH, KW, STRIDE, mace::Padding::P, OC); \ DepthwiseConv2d<DEVICE, TYPE>(iters, N, C, H, W, KH, KW, STRIDE, \
mace::Padding::P, OC); \
} \ } \
BENCHMARK( \ BENCHMARK( \
BM_DEPTHWISE_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_##OC##_##TYPE##_##DEVICE) BM_DEPTHWISE_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_##OC##_##TYPE##_##DEVICE)
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
namespace mace { namespace mace {
template<DeviceType D, class T> template <DeviceType D, class T>
class GlobalAvgPoolingOp : public Operator<D, T> { class GlobalAvgPoolingOp : public Operator<D, T> {
public: public:
GlobalAvgPoolingOp(const OperatorDef &operator_def, Workspace *ws) GlobalAvgPoolingOp(const OperatorDef &operator_def, Workspace *ws)
......
...@@ -11,11 +11,8 @@ using namespace mace; ...@@ -11,11 +11,8 @@ using namespace mace;
using namespace mace::kernels; using namespace mace::kernels;
template <DeviceType D> template <DeviceType D>
static void GlobalAvgPooling(int iters, static void GlobalAvgPooling(
int batch, int iters, int batch, int channels, int height, int width) {
int channels,
int height,
int width) {
mace::testing::StopTiming(); mace::testing::StopTiming();
OpsTestNet net; OpsTestNet net;
...@@ -39,12 +36,11 @@ static void GlobalAvgPooling(int iters, ...@@ -39,12 +36,11 @@ static void GlobalAvgPooling(int iters,
} }
#define BM_GLOBAL_AVG_POOLING_MACRO(N, C, H, W, DEVICE) \ #define BM_GLOBAL_AVG_POOLING_MACRO(N, C, H, W, DEVICE) \
static void \ static void BM_GLOBAL_AVG_POOLING_##N##_##C##_##H##_##W##_##DEVICE( \
BM_GLOBAL_AVG_POOLING_##N##_##C##_##H##_##W##_##DEVICE( \
int iters) { \ int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \ const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::ItemsProcessed(tot); \ mace::testing::ItemsProcessed(tot); \
mace::testing::BytesProcessed(tot*(sizeof(float))); \ mace::testing::BytesProcessed(tot *(sizeof(float))); \
GlobalAvgPooling<DEVICE>(iters, N, C, H, W); \ GlobalAvgPooling<DEVICE>(iters, N, C, H, W); \
} \ } \
BENCHMARK(BM_GLOBAL_AVG_POOLING_##N##_##C##_##H##_##W##_##DEVICE) BENCHMARK(BM_GLOBAL_AVG_POOLING_##N##_##C##_##H##_##W##_##DEVICE)
......
...@@ -10,7 +10,7 @@ class GlobalAvgPoolingOpTest : public OpsTestBase {}; ...@@ -10,7 +10,7 @@ class GlobalAvgPoolingOpTest : public OpsTestBase {};
TEST_F(GlobalAvgPoolingOpTest, 3x7x7_CPU) { TEST_F(GlobalAvgPoolingOpTest, 3x7x7_CPU) {
// Construct graph // Construct graph
auto& net = test_net(); auto &net = test_net();
OpDefBuilder("GlobalAvgPooling", "GlobalAvgPoolingTest") OpDefBuilder("GlobalAvgPooling", "GlobalAvgPoolingTest")
.Input("Input") .Input("Input")
.Output("Output") .Output("Output")
...@@ -19,24 +19,22 @@ TEST_F(GlobalAvgPoolingOpTest, 3x7x7_CPU) { ...@@ -19,24 +19,22 @@ TEST_F(GlobalAvgPoolingOpTest, 3x7x7_CPU) {
// Add input data // Add input data
std::vector<float> input(147); std::vector<float> input(147);
for (int i = 0; i < 147; ++i) { for (int i = 0; i < 147; ++i) {
input[i] = i/49 + 1; input[i] = i / 49 + 1;
} }
net.AddInputFromArray<float>( net.AddInputFromArray<float>("Input", {1, 3, 7, 7}, input);
"Input", {1, 3, 7, 7}, input);
// Run // Run
net.RunOp(); net.RunOp();
// Check // Check
auto expected = auto expected = CreateTensor<float>({1, 3, 1, 1}, {1, 2, 3});
CreateTensor<float>({1, 3, 1, 1}, {1, 2, 3});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001); ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
} }
TEST_F(GlobalAvgPoolingOpTest, 3x7x7_NEON) { TEST_F(GlobalAvgPoolingOpTest, 3x7x7_NEON) {
// Construct graph // Construct graph
auto& net = test_net(); auto &net = test_net();
OpDefBuilder("GlobalAvgPooling", "GlobalAvgPoolingTest") OpDefBuilder("GlobalAvgPooling", "GlobalAvgPoolingTest")
.Input("Input") .Input("Input")
.Output("Output") .Output("Output")
...@@ -45,17 +43,15 @@ TEST_F(GlobalAvgPoolingOpTest, 3x7x7_NEON) { ...@@ -45,17 +43,15 @@ TEST_F(GlobalAvgPoolingOpTest, 3x7x7_NEON) {
// Add input data // Add input data
std::vector<float> input(147); std::vector<float> input(147);
for (int i = 0; i < 147; ++i) { for (int i = 0; i < 147; ++i) {
input[i] = i/49 + 1; input[i] = i / 49 + 1;
} }
net.AddInputFromArray<float>( net.AddInputFromArray<float>("Input", {1, 3, 7, 7}, input);
"Input", {1, 3, 7, 7}, input);
// Run // Run
net.RunOp(DeviceType::NEON); net.RunOp(DeviceType::NEON);
// Check // Check
auto expected = auto expected = CreateTensor<float>({1, 3, 1, 1}, {1, 2, 3});
CreateTensor<float>({1, 3, 1, 1}, {1, 2, 3});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001); ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
} }
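Why {1, 2, 3} is the expected output in both tests above: input[i] = i / 49 + 1 uses integer division, so the first 7x7 channel (49 elements) is all 1s, the second all 2s, the third all 3s, and the global average of a constant channel is that constant. A quick standalone check:

#include <iostream>

int main() {
  float sum[3] = {0, 0, 0};
  for (int i = 0; i < 147; ++i) sum[i / 49] += i / 49 + 1;
  for (int c = 0; c < 3; ++c) std::cout << sum[c] / 49 << " ";  // 1 2 3
  std::cout << "\n";
  return 0;
}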
...@@ -43,7 +43,7 @@ class OpsTestNet { ...@@ -43,7 +43,7 @@ class OpsTestNet {
public: public:
OpsTestNet() {} OpsTestNet() {}
template<typename T> template <typename T>
void AddInputFromArray(const char *name, void AddInputFromArray(const char *name,
const std::vector<index_t> &shape, const std::vector<index_t> &shape,
const std::vector<T> &data) { const std::vector<T> &data) {
...@@ -55,7 +55,7 @@ class OpsTestNet { ...@@ -55,7 +55,7 @@ class OpsTestNet {
memcpy(input_data, data.data(), data.size() * sizeof(T)); memcpy(input_data, data.data(), data.size() * sizeof(T));
} }
template<typename T> template <typename T>
void AddRepeatedInput(const char *name, void AddRepeatedInput(const char *name,
const std::vector<index_t> &shape, const std::vector<index_t> &shape,
const T data) { const T data) {
...@@ -66,7 +66,7 @@ class OpsTestNet { ...@@ -66,7 +66,7 @@ class OpsTestNet {
std::fill(input_data, input_data + input->size(), data); std::fill(input_data, input_data + input->size(), data);
} }
template<typename T> template <typename T>
void AddRandomInput(const char *name, void AddRandomInput(const char *name,
const std::vector<index_t> &shape, const std::vector<index_t> &shape,
bool positive = false) { bool positive = false) {
...@@ -173,38 +173,37 @@ class OpsTestBase : public ::testing::Test { ...@@ -173,38 +173,37 @@ class OpsTestBase : public ::testing::Test {
OpsTestNet test_net_; OpsTestNet test_net_;
}; };
template<typename T> template <typename T>
void GenerateRandomRealTypeData(const std::vector<index_t> &shape, std::vector<T> &res) { void GenerateRandomRealTypeData(const std::vector<index_t> &shape,
std::vector<T> &res) {
std::random_device rd; std::random_device rd;
std::mt19937 gen(rd()); std::mt19937 gen(rd());
std::normal_distribution<T> nd(0, 1); std::normal_distribution<T> nd(0, 1);
index_t size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<index_t>()); index_t size = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<index_t>());
res.resize(size); res.resize(size);
std::generate(res.begin(), res.end(), std::generate(res.begin(), res.end(), [&gen, &nd] { return nd(gen); });
[&gen, &nd] {
return nd(gen);
});
} }
template<typename T> template <typename T>
void GenerateRandomIntTypeData(const std::vector<index_t> &shape, std::vector<T> &res, void GenerateRandomIntTypeData(const std::vector<index_t> &shape,
const T a = 0, const T b = std::numeric_limits<T>::max()) { std::vector<T> &res,
const T a = 0,
const T b = std::numeric_limits<T>::max()) {
std::random_device rd; std::random_device rd;
std::mt19937 gen(rd()); std::mt19937 gen(rd());
std::uniform_int_distribution<> nd(a, b); std::uniform_int_distribution<> nd(a, b);
index_t size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<index_t>()); index_t size = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<index_t>());
res.resize(size); res.resize(size);
std::generate(res.begin(), res.end(), std::generate(res.begin(), res.end(), [&gen, &nd] { return nd(gen); });
[&gen, &nd] {
return nd(gen);
});
} }
template<typename T> template <typename T>
unique_ptr<Tensor> CreateTensor(const std::vector<index_t> &shape, unique_ptr<Tensor> CreateTensor(const std::vector<index_t> &shape,
const std::vector<T> &data) { const std::vector<T> &data) {
unique_ptr<Tensor> res(new Tensor(cpu_allocator(), DataTypeToEnum<T>::v())); unique_ptr<Tensor> res(new Tensor(cpu_allocator(), DataTypeToEnum<T>::v()));
...@@ -237,23 +236,23 @@ inline std::string ShapeToString(const Tensor &x) { ...@@ -237,23 +236,23 @@ inline std::string ShapeToString(const Tensor &x) {
return std::string(stream.str()); return std::string(stream.str());
} }
template<typename T> template <typename T>
struct is_floating_point_type { struct is_floating_point_type {
static const bool value = static const bool value =
std::is_same<T, float>::value || std::is_same<T, double>::value; std::is_same<T, float>::value || std::is_same<T, double>::value;
}; };
template<typename T> template <typename T>
inline void ExpectEqual(const T &a, const T &b) { inline void ExpectEqual(const T &a, const T &b) {
EXPECT_EQ(a, b); EXPECT_EQ(a, b);
} }
template<> template <>
inline void ExpectEqual<float>(const float &a, const float &b) { inline void ExpectEqual<float>(const float &a, const float &b) {
EXPECT_FLOAT_EQ(a, b); EXPECT_FLOAT_EQ(a, b);
} }
template<> template <>
inline void ExpectEqual<double>(const double &a, const double &b) { inline void ExpectEqual<double>(const double &a, const double &b) {
EXPECT_DOUBLE_EQ(a, b); EXPECT_DOUBLE_EQ(a, b);
} }
...@@ -264,11 +263,11 @@ inline void AssertSameTypeDims(const Tensor &x, const Tensor &y) { ...@@ -264,11 +263,11 @@ inline void AssertSameTypeDims(const Tensor &x, const Tensor &y) {
<< "y.shape [ " << ShapeToString(y) << "]"; << "y.shape [ " << ShapeToString(y) << "]";
} }
template<typename T, bool is_fp = is_floating_point_type<T>::value> template <typename T, bool is_fp = is_floating_point_type<T>::value>
struct Expector; struct Expector;
// Partial specialization for float and double. // Partial specialization for float and double.
template<typename T> template <typename T>
struct Expector<T, true> { struct Expector<T, true> {
static void Equal(const T &a, const T &b) { ExpectEqual(a, b); } static void Equal(const T &a, const T &b) { ExpectEqual(a, b); }
...@@ -294,17 +293,17 @@ struct Expector<T, true> { ...@@ -294,17 +293,17 @@ struct Expector<T, true> {
} }
}; };
template<typename T> template <typename T>
void ExpectTensorNear(const Tensor &x, const Tensor &y, const double abs_err) { void ExpectTensorNear(const Tensor &x, const Tensor &y, const double abs_err) {
static_assert(is_floating_point_type<T>::value, static_assert(is_floating_point_type<T>::value,
"T is not a floating point type"); "T is not a floating point type");
Expector<T>::Near(x, y, abs_err); Expector<T>::Near(x, y, abs_err);
} }
template<typename T> template <typename T>
std::string ToString(const T& input) { std::string ToString(const T &input) {
std::stringstream ss; std::stringstream ss;
ss<<input; ss << input;
return ss.str(); return ss.str();
} }
......
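GenerateRandomRealTypeData above is plain standard library: an mt19937 seeded from random_device, a normal distribution, and std::generate with a capturing lambda. The same pattern in self-contained form, with MACE shapes replaced by a fixed-size vector:

#include <algorithm>
#include <iostream>
#include <random>
#include <vector>

int main() {
  std::random_device rd;
  std::mt19937 gen(rd());
  std::normal_distribution<float> nd(0, 1);  // mean 0, stddev 1, as above
  std::vector<float> res(8);
  std::generate(res.begin(), res.end(), [&gen, &nd] { return nd(gen); });
  for (float v : res) std::cout << v << " ";
  std::cout << "\n";
  return 0;
}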
...@@ -14,7 +14,7 @@ namespace mace { ...@@ -14,7 +14,7 @@ namespace mace {
template <DeviceType D, class T> template <DeviceType D, class T>
class PoolingOp : public ConvPool2dOpBase<D, T> { class PoolingOp : public ConvPool2dOpBase<D, T> {
public: public:
PoolingOp(const OperatorDef& op_def, Workspace* ws) PoolingOp(const OperatorDef &op_def, Workspace *ws)
: ConvPool2dOpBase<D, T>(op_def, ws), : ConvPool2dOpBase<D, T>(op_def, ws),
kernels_(OperatorBase::GetRepeatedArgument<int>("kernels")), kernels_(OperatorBase::GetRepeatedArgument<int>("kernels")),
pooling_type_( pooling_type_(
...@@ -22,8 +22,8 @@ class PoolingOp : public ConvPool2dOpBase<D, T> { ...@@ -22,8 +22,8 @@ class PoolingOp : public ConvPool2dOpBase<D, T> {
"pooling_type", static_cast<int>(AVG)))){}; "pooling_type", static_cast<int>(AVG)))){};
bool Run() override { bool Run() override {
const Tensor* input = this->Input(INPUT); const Tensor *input = this->Input(INPUT);
Tensor* output = this->Output(OUTPUT); Tensor *output = this->Output(OUTPUT);
std::vector<index_t> output_shape(4); std::vector<index_t> output_shape(4);
std::vector<int> paddings(2); std::vector<int> paddings(2);
...@@ -34,11 +34,10 @@ class PoolingOp : public ConvPool2dOpBase<D, T> { ...@@ -34,11 +34,10 @@ class PoolingOp : public ConvPool2dOpBase<D, T> {
filter_shape[2] = kernels_[0]; filter_shape[2] = kernels_[0];
filter_shape[3] = kernels_[1]; filter_shape[3] = kernels_[1];
kernels::CalcPaddingAndOutputSize(input->shape().data(), kernels::CalcPaddingAndOutputSize(
filter_shape.data(), input->shape().data(), filter_shape.data(), this->dilations_.data(),
this->dilations_.data(), this->strides_.data(), this->padding_, output_shape.data(),
this->strides_.data(), this->padding_, paddings.data());
output_shape.data(), paddings.data());
output->Resize(output_shape); output->Resize(output_shape);
auto pooling_func = kernels::PoolingFunctor<D, T>( auto pooling_func = kernels::PoolingFunctor<D, T>(
......
...@@ -56,7 +56,7 @@ static void Pooling(int iters, ...@@ -56,7 +56,7 @@ static void Pooling(int iters,
int iters) { \ int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \ const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::ItemsProcessed(tot); \ mace::testing::ItemsProcessed(tot); \
mace::testing::BytesProcessed(tot*(sizeof(float))); \ mace::testing::BytesProcessed(tot *(sizeof(float))); \
Pooling<DEVICE>(iters, N, C, H, W, KE, STRIDE, Padding::PA, \ Pooling<DEVICE>(iters, N, C, H, W, KE, STRIDE, Padding::PA, \
PoolingType::PO); \ PoolingType::PO); \
} \ } \
......
...@@ -15,7 +15,7 @@ class PoolingOpTest : public OpsTestBase {}; ...@@ -15,7 +15,7 @@ class PoolingOpTest : public OpsTestBase {};
TEST_F(PoolingOpTest, MAX_VALID) { TEST_F(PoolingOpTest, MAX_VALID) {
// Construct graph // Construct graph
auto& net = test_net(); auto &net = test_net();
OpDefBuilder("Pooling", "PoolingTest") OpDefBuilder("Pooling", "PoolingTest")
.Input("Input") .Input("Input")
.Output("Output") .Output("Output")
...@@ -46,7 +46,7 @@ TEST_F(PoolingOpTest, MAX_VALID) { ...@@ -46,7 +46,7 @@ TEST_F(PoolingOpTest, MAX_VALID) {
TEST_F(PoolingOpTest, AVG_VALID) { TEST_F(PoolingOpTest, AVG_VALID) {
// Construct graph // Construct graph
auto& net = test_net(); auto &net = test_net();
OpDefBuilder("Pooling", "PoolingTest") OpDefBuilder("Pooling", "PoolingTest")
.Input("Input") .Input("Input")
.Output("Output") .Output("Output")
...@@ -77,7 +77,7 @@ TEST_F(PoolingOpTest, AVG_VALID) { ...@@ -77,7 +77,7 @@ TEST_F(PoolingOpTest, AVG_VALID) {
TEST_F(PoolingOpTest, MAX_SAME) { TEST_F(PoolingOpTest, MAX_SAME) {
// Construct graph // Construct graph
auto& net = test_net(); auto &net = test_net();
OpDefBuilder("Pooling", "PoolingTest") OpDefBuilder("Pooling", "PoolingTest")
.Input("Input") .Input("Input")
.Output("Output") .Output("Output")
...@@ -105,7 +105,7 @@ TEST_F(PoolingOpTest, MAX_SAME) { ...@@ -105,7 +105,7 @@ TEST_F(PoolingOpTest, MAX_SAME) {
TEST_F(PoolingOpTest, MAX_VALID_DILATION) { TEST_F(PoolingOpTest, MAX_VALID_DILATION) {
// Construct graph // Construct graph
auto& net = test_net(); auto &net = test_net();
OpDefBuilder("Pooling", "PoolingTest") OpDefBuilder("Pooling", "PoolingTest")
.Input("Input") .Input("Input")
.Output("Output") .Output("Output")
...@@ -134,7 +134,7 @@ TEST_F(PoolingOpTest, MAX_VALID_DILATION) { ...@@ -134,7 +134,7 @@ TEST_F(PoolingOpTest, MAX_VALID_DILATION) {
TEST_F(PoolingOpTest, MAX_k2x2s2x2) { TEST_F(PoolingOpTest, MAX_k2x2s2x2) {
// Construct graph // Construct graph
auto& net = test_net(); auto &net = test_net();
OpDefBuilder("Pooling", "PoolingTest") OpDefBuilder("Pooling", "PoolingTest")
.Input("Input") .Input("Input")
.Output("Output") .Output("Output")
...@@ -148,9 +148,9 @@ TEST_F(PoolingOpTest, MAX_k2x2s2x2) { ...@@ -148,9 +148,9 @@ TEST_F(PoolingOpTest, MAX_k2x2s2x2) {
net.AddIntsArg("dilations", {1, 1}); net.AddIntsArg("dilations", {1, 1});
// Add input data // Add input data
net.AddInputFromArray<float>("Input", {1, 1, 2, 9}, net.AddInputFromArray<float>(
{0, 1, 2, 3, 4, 5, 6, 7, 8, "Input", {1, 1, 2, 9},
9, 10, 11, 12, 13, 14, 15, 16, 17}); {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17});
// Run // Run
net.RunOp(DeviceType::NEON); net.RunOp(DeviceType::NEON);
...@@ -162,7 +162,7 @@ TEST_F(PoolingOpTest, MAX_k2x2s2x2) { ...@@ -162,7 +162,7 @@ TEST_F(PoolingOpTest, MAX_k2x2s2x2) {
TEST_F(PoolingOpTest, MAX_k3x3s2x2) { TEST_F(PoolingOpTest, MAX_k3x3s2x2) {
// Construct graph // Construct graph
auto& net = test_net(); auto &net = test_net();
OpDefBuilder("Pooling", "PoolingTest") OpDefBuilder("Pooling", "PoolingTest")
.Input("Input") .Input("Input")
.Output("Output") .Output("Output")
...@@ -176,10 +176,10 @@ TEST_F(PoolingOpTest, MAX_k3x3s2x2) { ...@@ -176,10 +176,10 @@ TEST_F(PoolingOpTest, MAX_k3x3s2x2) {
net.AddIntsArg("dilations", {1, 1}); net.AddIntsArg("dilations", {1, 1});
// Add input data // Add input data
net.AddInputFromArray<float>("Input", {1, 1, 3, 9}, net.AddInputFromArray<float>(
{0, 1, 2, 3, 4, 5, 6, 7, 8, "Input", {1, 1, 3, 9},
9, 10, 11, 12, 13, 14, 15, 16, 17, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
18, 19, 20, 21, 22, 23, 24, 25, 26}); 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26});
// Run // Run
net.RunOp(DeviceType::NEON); net.RunOp(DeviceType::NEON);
...@@ -191,7 +191,7 @@ TEST_F(PoolingOpTest, MAX_k3x3s2x2) { ...@@ -191,7 +191,7 @@ TEST_F(PoolingOpTest, MAX_k3x3s2x2) {
TEST_F(PoolingOpTest, AVG_k2x2s2x2) { TEST_F(PoolingOpTest, AVG_k2x2s2x2) {
// Construct graph // Construct graph
auto& net = test_net(); auto &net = test_net();
OpDefBuilder("Pooling", "PoolingTest") OpDefBuilder("Pooling", "PoolingTest")
.Input("Input") .Input("Input")
.Output("Output") .Output("Output")
...@@ -207,15 +207,12 @@ TEST_F(PoolingOpTest, AVG_k2x2s2x2) { ...@@ -207,15 +207,12 @@ TEST_F(PoolingOpTest, AVG_k2x2s2x2) {
// Add input data // Add input data
net.AddInputFromArray<float>( net.AddInputFromArray<float>(
"Input", {1, 1, 2, 8}, "Input", {1, 1, 2, 8},
{0, 1, 2, 3, 4, 5, 6, 7, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
8, 9, 10, 11, 12, 13, 14, 15});
// Run // Run
net.RunOp(DeviceType::NEON); net.RunOp(DeviceType::NEON);
// Check // Check
auto expected = CreateTensor<float>({1, 1, 1, 4}, auto expected = CreateTensor<float>({1, 1, 1, 4}, {4.5, 6.5, 8.5, 10.5});
{4.5, 6.5, 8.5, 10.5});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001); ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
} }
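Checking the expected averages in AVG_k2x2s2x2 by hand: the {1, 1, 2, 8} input holds rows {0..7} and {8..15}, and each 2x2 stride-2 window averages two columns of both rows, e.g. {0, 1, 8, 9} -> 4.5. A standalone verification:

#include <iostream>

int main() {
  float in[2][8];
  for (int r = 0; r < 2; ++r)
    for (int c = 0; c < 8; ++c) in[r][c] = r * 8 + c;
  for (int c = 0; c < 8; c += 2)
    std::cout << (in[0][c] + in[0][c + 1] + in[1][c] + in[1][c + 1]) / 4
              << " ";  // 4.5 6.5 8.5 10.5
  std::cout << "\n";
  return 0;
}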
...@@ -13,17 +13,17 @@ namespace mace { ...@@ -13,17 +13,17 @@ namespace mace {
template <DeviceType D, class T> template <DeviceType D, class T>
class ReluOp : public Operator<D, T> { class ReluOp : public Operator<D, T> {
public: public:
ReluOp(const OperatorDef& operator_def, Workspace* ws) ReluOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws) { : Operator<D, T>(operator_def, ws) {
functor_.max_limit_ = functor_.max_limit_ =
OperatorBase::GetSingleArgument<T>("max_limit", static_cast<T>(-1)); OperatorBase::GetSingleArgument<T>("max_limit", static_cast<T>(-1));
} }
bool Run() override { bool Run() override {
const Tensor* input_tensor = this->inputs_[0]; const Tensor *input_tensor = this->inputs_[0];
Tensor* output_tensor = this->outputs_[0]; Tensor *output_tensor = this->outputs_[0];
output_tensor->ResizeLike(input_tensor); output_tensor->ResizeLike(input_tensor);
const T* input = input_tensor->data<T>(); const T *input = input_tensor->data<T>();
T* output = output_tensor->mutable_data<T>(); T *output = output_tensor->mutable_data<T>();
index_t size = input_tensor->size(); index_t size = input_tensor->size();
functor_(input, output, size); functor_(input, output, size);
......
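The ReLU functor body lives in mace/kernels and is not part of this diff; judging from the max_limit argument (default -1) and the ReluOpWithMax test below, it presumably clamps below at zero and, when max_limit is non-negative, above at max_limit. A hedged sketch under that assumption (ReluSketch is a made-up name):

#include <algorithm>
#include <cstdint>
#include <iostream>
using index_t = int64_t;

void ReluSketch(const float *input, float *output, index_t size,
                float max_limit) {
  for (index_t i = 0; i < size; ++i) {
    float v = std::max(input[i], 0.f);
    if (max_limit >= 0) v = std::min(v, max_limit);  // "ReluOpWithMax" case
    output[i] = v;
  }
}

int main() {
  const float in[4] = {-2.f, -0.1f, 3.f, 9.f};
  float out[4];
  ReluSketch(in, out, 4, 6.f);
  for (float v : out) std::cout << v << " ";  // prints: 0 0 3 6
  std::cout << "\n";
  return 0;
}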
...@@ -36,7 +36,7 @@ static void ReluBenchmark(int iters, int size) { ...@@ -36,7 +36,7 @@ static void ReluBenchmark(int iters, int size) {
static void BM_RELU_##SIZE##_##TYPE##_##DEVICE(int iters) { \ static void BM_RELU_##SIZE##_##TYPE##_##DEVICE(int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * SIZE; \ const int64_t tot = static_cast<int64_t>(iters) * SIZE; \
mace::testing::ItemsProcessed(tot); \ mace::testing::ItemsProcessed(tot); \
mace::testing::BytesProcessed(tot*(sizeof(TYPE))); \ mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
ReluBenchmark<DEVICE, TYPE>(iters, SIZE); \ ReluBenchmark<DEVICE, TYPE>(iters, SIZE); \
} \ } \
BENCHMARK(BM_RELU_##SIZE##_##TYPE##_##DEVICE) BENCHMARK(BM_RELU_##SIZE##_##TYPE##_##DEVICE)
......
...@@ -11,7 +11,7 @@ class ReluOpTest : public OpsTestBase {}; ...@@ -11,7 +11,7 @@ class ReluOpTest : public OpsTestBase {};
TEST_F(ReluOpTest, ReluOp) { TEST_F(ReluOpTest, ReluOp) {
// Construct graph // Construct graph
auto& net = test_net(); auto &net = test_net();
OpDefBuilder("Relu", "ReluTest") OpDefBuilder("Relu", "ReluTest")
.Input("Input") .Input("Input")
.Output("Output") .Output("Output")
...@@ -34,7 +34,7 @@ TEST_F(ReluOpTest, ReluOp) { ...@@ -34,7 +34,7 @@ TEST_F(ReluOpTest, ReluOp) {
TEST_F(ReluOpTest, ReluOpWithMax) { TEST_F(ReluOpTest, ReluOpWithMax) {
// Construct graph // Construct graph
auto& net = test_net(); auto &net = test_net();
OpDefBuilder("Relu", "ReluTestWithMax") OpDefBuilder("Relu", "ReluTestWithMax")
.Input("Input") .Input("Input")
.Output("Output") .Output("Output")
...@@ -56,5 +56,4 @@ TEST_F(ReluOpTest, ReluOpWithMax) { ...@@ -56,5 +56,4 @@ TEST_F(ReluOpTest, ReluOpWithMax) {
ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 0.01); ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 0.01);
} }
} // namespace mace } // namespace mace
...@@ -13,21 +13,21 @@ namespace mace { ...@@ -13,21 +13,21 @@ namespace mace {
template <DeviceType D, class T> template <DeviceType D, class T>
class ResizeBilinearOp : public Operator<D, T> { class ResizeBilinearOp : public Operator<D, T> {
public: public:
ResizeBilinearOp(const OperatorDef& operator_def, Workspace* ws) ResizeBilinearOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws), : Operator<D, T>(operator_def, ws),
functor_( functor_(
OperatorBase::GetSingleArgument<bool>("align_corners", false)) {} OperatorBase::GetSingleArgument<bool>("align_corners", false)) {}
bool Run() override { bool Run() override {
const Tensor* input = this->Input(0); const Tensor *input = this->Input(0);
const Tensor* resize_dims = this->Input(1); const Tensor *resize_dims = this->Input(1);
MACE_CHECK(input->dim_size() == 4, "input must be 4-dimensional.", MACE_CHECK(input->dim_size() == 4, "input must be 4-dimensional.",
input->dim_size()); input->dim_size());
MACE_CHECK(resize_dims->dim_size() == 1, MACE_CHECK(resize_dims->dim_size() == 1,
"resize dim must be 2-dimensional.", resize_dims->dim_size()); "resize dim must be 2-dimensional.", resize_dims->dim_size());
Tensor* output = this->Output(0); Tensor *output = this->Output(0);
index_t n = input->dim(0); index_t n = input->dim(0);
index_t channels = input->dim(1); index_t channels = input->dim(1);
...@@ -38,8 +38,8 @@ class ResizeBilinearOp : public Operator<D, T> { ...@@ -38,8 +38,8 @@ class ResizeBilinearOp : public Operator<D, T> {
vector<index_t> out_shape{n, channels, out_height, out_width}; vector<index_t> out_shape{n, channels, out_height, out_width};
output->Resize(out_shape); output->Resize(out_shape);
const T* input_ptr = input->data<T>(); const T *input_ptr = input->data<T>();
T* output_ptr = output->mutable_data<T>(); T *output_ptr = output->mutable_data<T>();
functor_(input_ptr, output_ptr, n, channels, in_height, in_width, functor_(input_ptr, output_ptr, n, channels, in_height, in_width,
out_height, out_width); out_height, out_width);
......
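Only the align_corners flag is visible in this diff; assuming the kernel follows the common TensorFlow-style convention, the flag changes how input coordinates are derived from output coordinates. A hedged sketch (CalculateResizeScale is a hypothetical name; the real functor may differ):

#include <cstdint>
#include <iostream>

float CalculateResizeScale(int64_t in_size, int64_t out_size,
                           bool align_corners) {
  return (align_corners && out_size > 1)
             ? (in_size - 1) / static_cast<float>(out_size - 1)
             : in_size / static_cast<float>(out_size);
}

int main() {
  // Upscaling 4 -> 7: with align_corners the first and last samples map
  // exactly onto the input corners (scale 0.5); without it the scale is 4/7.
  std::cout << CalculateResizeScale(4, 7, true) << " "
            << CalculateResizeScale(4, 7, false) << "\n";
  return 0;
}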
...@@ -13,7 +13,7 @@ class ResizeBilinearTest : public OpsTestBase {}; ...@@ -13,7 +13,7 @@ class ResizeBilinearTest : public OpsTestBase {};
TEST_F(ResizeBilinearTest, ResizeBilinearWOAlignCorners) { TEST_F(ResizeBilinearTest, ResizeBilinearWOAlignCorners) {
testing::internal::LogToStderr(); testing::internal::LogToStderr();
// Construct graph // Construct graph
auto& net = test_net(); auto &net = test_net();
OpDefBuilder("ResizeBilinear", "ResizeBilinearTest") OpDefBuilder("ResizeBilinear", "ResizeBilinearTest")
.Input("Input") .Input("Input")
.Input("OutSize") .Input("OutSize")
...@@ -38,7 +38,7 @@ TEST_F(ResizeBilinearTest, ResizeBilinearWOAlignCorners) { ...@@ -38,7 +38,7 @@ TEST_F(ResizeBilinearTest, ResizeBilinearWOAlignCorners) {
TEST_F(ResizeBilinearTest, ResizeBilinearWAlignCorners) { TEST_F(ResizeBilinearTest, ResizeBilinearWAlignCorners) {
testing::internal::LogToStderr(); testing::internal::LogToStderr();
// Construct graph // Construct graph
auto& net = test_net(); auto &net = test_net();
OpDefBuilder("ResizeBilinear", "ResizeBilinearTest") OpDefBuilder("ResizeBilinear", "ResizeBilinearTest")
.Input("Input") .Input("Input")
.Input("OutSize") .Input("OutSize")
......
...@@ -33,8 +33,8 @@ cc_proto_library( ...@@ -33,8 +33,8 @@ cc_proto_library(
py_proto_library( py_proto_library(
name = "mace_py", name = "mace_py",
srcs = ["mace.proto"], srcs = ["mace.proto"],
default_runtime = "@com_google_protobuf//:protobuf_python",
protoc = "@com_google_protobuf//:protoc",
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = ["@com_google_protobuf//:protobuf_python"], deps = ["@com_google_protobuf//:protobuf_python"],
protoc = "@com_google_protobuf//:protoc",
default_runtime = "@com_google_protobuf//:protobuf_python",
) )
...@@ -16,4 +16,3 @@ py_binary( ...@@ -16,4 +16,3 @@ py_binary(
"@six_archive//:six", "@six_archive//:six",
], ],
) )
...@@ -3,14 +3,13 @@ ...@@ -3,14 +3,13 @@
// //
#include "mace/core/net.h" #include "mace/core/net.h"
#include "mace/utils/command_line_flags.h"
#include "mace/tools/benchmark/stat_summarizer.h" #include "mace/tools/benchmark/stat_summarizer.h"
#include "mace/utils/command_line_flags.h"
#include "mace/utils/utils.h" #include "mace/utils/utils.h"
#include <fstream> #include <fstream>
#include <thread> #include <thread>
namespace mace { namespace mace {
namespace str_util { namespace str_util {
...@@ -29,8 +28,9 @@ std::vector<std::string> Split(const string &str, char delims) { ...@@ -29,8 +28,9 @@ std::vector<std::string> Split(const string &str, char delims) {
return result; return result;
} }
bool SplitAndParseToInts(const string &str,
bool SplitAndParseToInts(const string &str, char delims, std::vector<index_t>* result) { char delims,
std::vector<index_t> *result) {
string tmp = str; string tmp = str;
while (!tmp.empty()) { while (!tmp.empty()) {
index_t dim = atoi(tmp.data()); index_t dim = atoi(tmp.data());
...@@ -48,9 +48,11 @@ bool SplitAndParseToInts(const string &str, char delims, std::vector<index_t>* r ...@@ -48,9 +48,11 @@ bool SplitAndParseToInts(const string &str, char delims, std::vector<index_t>* r
namespace benchmark { namespace benchmark {
bool RunInference(NetBase* net, StatSummarizer* summarizer, int64_t* inference_time_us) { bool RunInference(NetBase *net,
StatSummarizer *summarizer,
int64_t *inference_time_us) {
RunMetadata run_metadata; RunMetadata run_metadata;
RunMetadata* run_metadata_ptr = nullptr; RunMetadata *run_metadata_ptr = nullptr;
if (summarizer) { if (summarizer) {
run_metadata_ptr = &run_metadata; run_metadata_ptr = &run_metadata;
} }
...@@ -71,9 +73,13 @@ bool RunInference(NetBase* net, StatSummarizer* summarizer, int64_t* inference_t ...@@ -71,9 +73,13 @@ bool RunInference(NetBase* net, StatSummarizer* summarizer, int64_t* inference_t
return true; return true;
} }
bool Run(NetBase* net, StatSummarizer* summarizer, bool Run(NetBase *net,
int num_runs, double max_time_sec, int64_t sleep_sec, StatSummarizer *summarizer,
int64_t* total_time_us, int64_t* actual_num_runs) { int num_runs,
double max_time_sec,
int64_t sleep_sec,
int64_t *total_time_us,
int64_t *actual_num_runs) {
*total_time_us = 0; *total_time_us = 0;
LOG(INFO) << "Running benchmark for max " << num_runs << " iterators, max " LOG(INFO) << "Running benchmark for max " << num_runs << " iterators, max "
...@@ -85,7 +91,7 @@ bool Run(NetBase* net, StatSummarizer* summarizer, ...@@ -85,7 +91,7 @@ bool Run(NetBase* net, StatSummarizer* summarizer,
Stat<int64_t> stat; Stat<int64_t> stat;
bool util_max_time = (num_runs <= 0); bool util_max_time = (num_runs <= 0);
for (int i = 0; util_max_time || i < num_runs ; ++i) { for (int i = 0; util_max_time || i < num_runs; ++i) {
int64_t inference_time_us = 0; int64_t inference_time_us = 0;
bool s = RunInference(net, summarizer, &inference_time_us); bool s = RunInference(net, summarizer, &inference_time_us);
stat.UpdateStat(inference_time_us); stat.UpdateStat(inference_time_us);
...@@ -113,7 +119,7 @@ bool Run(NetBase* net, StatSummarizer* summarizer, ...@@ -113,7 +119,7 @@ bool Run(NetBase* net, StatSummarizer* summarizer,
return true; return true;
} }
int Main(int argc, char** argv) { int Main(int argc, char **argv) {
std::string model_file = "/data/local/tmp/mobi_mace.pb"; std::string model_file = "/data/local/tmp/mobi_mace.pb";
std::string device = "CPU"; std::string device = "CPU";
std::string input_layer_string = "input:0"; std::string input_layer_string = "input:0";
...@@ -182,8 +188,10 @@ int Main(int argc, char** argv) { ...@@ -182,8 +188,10 @@ int Main(int argc, char** argv) {
return -1; return -1;
} }
std::vector<std::string> input_layers = str_util::Split(input_layer_string, ','); std::vector<std::string> input_layers =
std::vector<std::string> input_layer_shapes = str_util::Split(input_layer_shape_string, ':'); str_util::Split(input_layer_string, ',');
std::vector<std::string> input_layer_shapes =
str_util::Split(input_layer_shape_string, ':');
std::vector<string> input_layer_types = std::vector<string> input_layer_types =
str_util::Split(input_layer_type_string, ','); str_util::Split(input_layer_type_string, ',');
std::vector<string> input_layer_files = std::vector<string> input_layer_files =
...@@ -260,17 +268,17 @@ int Main(int argc, char** argv) { ...@@ -260,17 +268,17 @@ int Main(int argc, char** argv) {
ws.LoadModelTensor(net_def, DeviceType::CPU); ws.LoadModelTensor(net_def, DeviceType::CPU);
// Load inputs // Load inputs
for (size_t i = 0; i < inputs_count; ++i) { for (size_t i = 0; i < inputs_count; ++i) {
Tensor *input_tensor = ws.CreateTensor(input_layers[i], Tensor *input_tensor =
cpu_allocator(), DT_FLOAT); ws.CreateTensor(input_layers[i], cpu_allocator(), DT_FLOAT);
vector<index_t> shapes; vector<index_t> shapes;
str_util::SplitAndParseToInts(input_layer_shapes[i], ',', &shapes); str_util::SplitAndParseToInts(input_layer_shapes[i], ',', &shapes);
input_tensor->Resize(shapes); input_tensor->Resize(shapes);
float *input_data = input_tensor->mutable_data<float>(); float *input_data = input_tensor->mutable_data<float>();
// load input // load input
if (i < input_layer_files.size()) { if (i < input_layer_files.size()) {
std::ifstream in_file(input_layer_files[i], std::ios::in | std::ios::binary); std::ifstream in_file(input_layer_files[i],
std::ios::in | std::ios::binary);
in_file.read(reinterpret_cast<char *>(input_data), in_file.read(reinterpret_cast<char *>(input_data),
input_tensor->size() * sizeof(float)); input_tensor->size() * sizeof(float));
in_file.close(); in_file.close();
...@@ -285,31 +293,31 @@ int Main(int argc, char** argv) { ...@@ -285,31 +293,31 @@ int Main(int argc, char** argv) {
int64_t warmup_time_us = 0; int64_t warmup_time_us = 0;
int64_t num_warmup_runs = 0; int64_t num_warmup_runs = 0;
if (warmup_runs > 0) { if (warmup_runs > 0) {
bool status = Run(net.get(), nullptr, bool status =
warmup_runs, -1.0, inter_inference_sleep_seconds, Run(net.get(), nullptr, warmup_runs, -1.0,
&warmup_time_us, &num_warmup_runs); inter_inference_sleep_seconds, &warmup_time_us, &num_warmup_runs);
if (!status) { if (!status) {
LOG(ERROR) << "Failed at warm up run"; LOG(ERROR) << "Failed at warm up run";
} }
} }
if (inter_benchmark_sleep_seconds > 0) { if (inter_benchmark_sleep_seconds > 0) {
std::this_thread::sleep_for(std::chrono::seconds(inter_benchmark_sleep_seconds)); std::this_thread::sleep_for(
std::chrono::seconds(inter_benchmark_sleep_seconds));
} }
int64_t no_stat_time_us = 0; int64_t no_stat_time_us = 0;
int64_t no_stat_runs = 0; int64_t no_stat_runs = 0;
bool status = Run(net.get(), nullptr, bool status =
max_num_runs, max_benchmark_time_seconds, inter_inference_sleep_seconds, Run(net.get(), nullptr, max_num_runs, max_benchmark_time_seconds,
&no_stat_time_us, &no_stat_runs); inter_inference_sleep_seconds, &no_stat_time_us, &no_stat_runs);
if (!status) { if (!status) {
LOG(ERROR) << "Failed at normal no-stat run"; LOG(ERROR) << "Failed at normal no-stat run";
} }
int64_t stat_time_us = 0; int64_t stat_time_us = 0;
int64_t stat_runs = 0; int64_t stat_runs = 0;
status = Run(net.get(), stats.get(), status = Run(net.get(), stats.get(), max_num_runs, max_benchmark_time_seconds,
max_num_runs, max_benchmark_time_seconds, inter_inference_sleep_seconds, inter_inference_sleep_seconds, &stat_time_us, &stat_runs);
&stat_time_us, &stat_runs);
if (!status) { if (!status) {
LOG(ERROR) << "Failed at normal stat run"; LOG(ERROR) << "Failed at normal stat run";
} }
...@@ -328,6 +336,4 @@ int Main(int argc, char** argv) { ...@@ -328,6 +336,4 @@ int Main(int argc, char** argv) {
} // namespace benchmark } // namespace benchmark
} // namespace mace } // namespace mace
int main (int argc, char** argv) { int main(int argc, char **argv) { mace::benchmark::Main(argc, argv); }
mace::benchmark::Main(argc, argv);
}
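The flag parsing above first splits --input_layer_shape on ':' to get one "N,C,H,W" string per input, then SplitAndParseToInts turns each into an index_t vector. A self-contained sketch of that second step (standard library only; the example flag value is hypothetical, inferred from the Split calls above):

#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <string>
#include <vector>
using index_t = int64_t;

int main() {
  const std::string shape = "1,3,224,224";  // one entry of the ':'-joined flag
  std::vector<index_t> dims;
  size_t pos = 0;
  for (;;) {  // same effect as SplitAndParseToInts(shape, ',', &dims)
    dims.push_back(std::atoll(shape.c_str() + pos));
    const size_t next = shape.find(',', pos);
    if (next == std::string::npos) break;
    pos = next + 1;
  }
  for (index_t d : dims) std::cout << d << " ";  // prints: 1 3 224 224
  std::cout << "\n";
  return 0;
}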
...@@ -2,17 +2,16 @@ ...@@ -2,17 +2,16 @@
// Copyright (c) 2017 XiaoMi All rights reserved. // Copyright (c) 2017 XiaoMi All rights reserved.
// //
#include "mace/core/common.h"
#include "mace/tools/benchmark/stat_summarizer.h" #include "mace/tools/benchmark/stat_summarizer.h"
#include "mace/core/common.h"
#include "mace/proto/stats.pb.h" #include "mace/proto/stats.pb.h"
#include <iomanip> #include <iomanip>
#include <queue> #include <queue>
namespace mace { namespace mace {
StatSummarizer::StatSummarizer(const StatSummarizerOptions& options) StatSummarizer::StatSummarizer(const StatSummarizerOptions &options)
: options_(options) {} : options_(options) {}
StatSummarizer::~StatSummarizer() {} StatSummarizer::~StatSummarizer() {}
...@@ -23,17 +22,14 @@ void StatSummarizer::Reset() { ...@@ -23,17 +22,14 @@ void StatSummarizer::Reset() {
details_.clear(); details_.clear();
} }
void StatSummarizer::ProcessMetadata(const RunMetadata &run_metadata) { void StatSummarizer::ProcessMetadata(const RunMetadata &run_metadata) {
int64_t curr_total_us = 0; int64_t curr_total_us = 0;
int64_t mem_total = 0; int64_t mem_total = 0;
int64_t first_node_start_us = int64_t first_node_start_us = run_metadata.op_stats(0).all_start_micros();
run_metadata.op_stats(0).all_start_micros();
int node_num = 0; int node_num = 0;
for (const auto& ops : run_metadata.op_stats()) { for (const auto &ops : run_metadata.op_stats()) {
std::string name = ops.operator_name(); std::string name = ops.operator_name();
std::string op_type = ops.type(); std::string op_type = ops.type();
...@@ -41,7 +37,7 @@ void StatSummarizer::ProcessMetadata(const RunMetadata &run_metadata) { ...@@ -41,7 +37,7 @@ void StatSummarizer::ProcessMetadata(const RunMetadata &run_metadata) {
const int64_t curr_time = ops.all_end_rel_micros(); const int64_t curr_time = ops.all_end_rel_micros();
curr_total_us += curr_time; curr_total_us += curr_time;
auto result = details_.emplace(name, Detail()); auto result = details_.emplace(name, Detail());
Detail* detail = &(result.first->second); Detail *detail = &(result.first->second);
detail->start_us.UpdateStat(ops.all_start_micros() - first_node_start_us); detail->start_us.UpdateStat(ops.all_start_micros() - first_node_start_us);
detail->rel_end_us.UpdateStat(curr_time); detail->rel_end_us.UpdateStat(curr_time);
...@@ -77,13 +73,13 @@ std::string StatSummarizer::ShortSummary() const { ...@@ -77,13 +73,13 @@ std::string StatSummarizer::ShortSummary() const {
return stream.str(); return stream.str();
} }
std::ostream& InitField(std::ostream& stream, int width) { std::ostream &InitField(std::ostream &stream, int width) {
stream << "\t" << std::right << std::setw(width) << std::fixed stream << "\t" << std::right << std::setw(width) << std::fixed
<< std::setprecision(3); << std::setprecision(3);
return stream; return stream;
} }
std::string StatSummarizer::HeaderString(const std::string& title) const { std::string StatSummarizer::HeaderString(const std::string &title) const {
std::stringstream stream; std::stringstream stream;
stream << "============================== " << title stream << "============================== " << title
...@@ -102,9 +98,9 @@ std::string StatSummarizer::HeaderString(const std::string& title) const { ...@@ -102,9 +98,9 @@ std::string StatSummarizer::HeaderString(const std::string& title) const {
return stream.str(); return stream.str();
} }
std::string StatSummarizer::ColumnString(const StatSummarizer::Detail& detail, std::string StatSummarizer::ColumnString(const StatSummarizer::Detail &detail,
const int64_t cumulative_stat_on_node, const int64_t cumulative_stat_on_node,
const Stat<int64_t>& stat) const { const Stat<int64_t> &stat) const {
const double start_ms = detail.start_us.avg() / 1000.0; const double start_ms = detail.start_us.avg() / 1000.0;
const double first_time_ms = detail.rel_end_us.first() / 1000.0; const double first_time_ms = detail.rel_end_us.first() / 1000.0;
const double avg_time_ms = detail.rel_end_us.avg() / 1000.0; const double avg_time_ms = detail.rel_end_us.avg() / 1000.0;
...@@ -127,12 +123,12 @@ std::string StatSummarizer::ColumnString(const StatSummarizer::Detail& detail, ...@@ -127,12 +123,12 @@ std::string StatSummarizer::ColumnString(const StatSummarizer::Detail& detail,
} }
void StatSummarizer::OrderNodesByMetric( void StatSummarizer::OrderNodesByMetric(
SortingMetric metric, std::vector<const Detail*>* details) const { SortingMetric metric, std::vector<const Detail *> *details) const {
std::priority_queue<std::pair<std::string, const Detail*>> sorted_list; std::priority_queue<std::pair<std::string, const Detail *>> sorted_list;
const int num_nodes = details_.size(); const int num_nodes = details_.size();
for (const auto& det : details_) { for (const auto &det : details_) {
const Detail* detail = &(det.second); const Detail *detail = &(det.second);
std::stringstream stream; std::stringstream stream;
stream << std::setw(20) << std::right << std::setprecision(10) stream << std::setw(20) << std::right << std::setprecision(10)
<< std::fixed; << std::fixed;
...@@ -169,16 +165,16 @@ void StatSummarizer::OrderNodesByMetric( ...@@ -169,16 +165,16 @@ void StatSummarizer::OrderNodesByMetric(
} }
void StatSummarizer::ComputeStatsByType( void StatSummarizer::ComputeStatsByType(
std::map<std::string, int64_t>* node_type_map_count, std::map<std::string, int64_t> *node_type_map_count,
std::map<std::string, int64_t>* node_type_map_time, std::map<std::string, int64_t> *node_type_map_time,
std::map<std::string, int64_t>* node_type_map_memory, std::map<std::string, int64_t> *node_type_map_memory,
std::map<std::string, int64_t>* node_type_map_times_called, std::map<std::string, int64_t> *node_type_map_times_called,
int64_t* accumulated_us) const { int64_t *accumulated_us) const {
int64_t run_count = run_total_us_.count(); int64_t run_count = run_total_us_.count();
for (const auto& det : details_) { for (const auto &det : details_) {
const std::string node_name = det.first; const std::string node_name = det.first;
const Detail& detail = det.second; const Detail &detail = det.second;
int64_t curr_time_val = int64_t curr_time_val =
static_cast<int64_t>(detail.rel_end_us.sum() / run_count); static_cast<int64_t>(detail.rel_end_us.sum() / run_count);
...@@ -186,7 +182,7 @@ void StatSummarizer::ComputeStatsByType( ...@@ -186,7 +182,7 @@ void StatSummarizer::ComputeStatsByType(
int64_t curr_memory_val = detail.mem_used.newest(); int64_t curr_memory_val = detail.mem_used.newest();
const std::string& node_type = detail.type; const std::string &node_type = detail.type;
(*node_type_map_count)[node_type] += 1; (*node_type_map_count)[node_type] += 1;
(*node_type_map_time)[node_type] += curr_time_val; (*node_type_map_time)[node_type] += curr_time_val;
...@@ -215,8 +211,9 @@ std::string StatSummarizer::GetStatsByNodeType() const { ...@@ -215,8 +211,9 @@ std::string StatSummarizer::GetStatsByNodeType() const {
&accumulated_us); &accumulated_us);
// Sort them. // Sort them.
std::priority_queue<std::pair<int64_t, std::pair<std::string, int64_t>>> timings; std::priority_queue<std::pair<int64_t, std::pair<std::string, int64_t>>>
for (const auto& node_type : node_type_map_time) { timings;
for (const auto &node_type : node_type_map_time) {
const int64_t mem_used = node_type_map_memory[node_type.first]; const int64_t mem_used = node_type_map_memory[node_type.first];
timings.emplace(node_type.second, timings.emplace(node_type.second,
std::pair<std::string, int64_t>(node_type.first, mem_used)); std::pair<std::string, int64_t>(node_type.first, mem_used));
...@@ -259,10 +256,10 @@ std::string StatSummarizer::GetStatsByNodeType() const { ...@@ -259,10 +256,10 @@ std::string StatSummarizer::GetStatsByNodeType() const {
return stream.str(); return stream.str();
} }
std::string StatSummarizer::GetStatsByMetric(const std::string& title, std::string StatSummarizer::GetStatsByMetric(const std::string &title,
SortingMetric sorting_metric, SortingMetric sorting_metric,
int num_stats) const { int num_stats) const {
std::vector<const Detail*> details; std::vector<const Detail *> details;
OrderNodesByMetric(sorting_metric, &details); OrderNodesByMetric(sorting_metric, &details);
double cumulative_stat_on_node = 0; double cumulative_stat_on_node = 0;
......
@@ -12,7 +12,6 @@
 #include <sstream>
 #include <string>
 namespace mace {
 class RunMetadata;
@@ -62,7 +61,7 @@ class Stat {
     return all_same() ? 0 : std::sqrt(squared_sum_ / count_ - avg() * avg());
   }
-  void OutputToStream(std::ostream* stream) const {
+  void OutputToStream(std::ostream *stream) const {
     if (empty()) {
       *stream << "count=0";
     } else if (all_same()) {
@@ -75,8 +74,8 @@ class Stat {
     }
   }
-  friend std::ostream& operator<<(std::ostream& stream,
-                                  const Stat<ValueType>& stat) {
+  friend std::ostream &operator<<(std::ostream &stream,
+                                  const Stat<ValueType> &stat) {
     stat.OutputToStream(&stream);
     return stream;
   }
@@ -131,12 +130,12 @@ class StatSummarizer {
     BY_TYPE,
   };
-  explicit StatSummarizer(const StatSummarizerOptions& options);
+  explicit StatSummarizer(const StatSummarizerOptions &options);
   ~StatSummarizer();
   // Adds another run's StepStats output to the aggregate counts.
-  void ProcessMetadata(const RunMetadata& run_metadata);
+  void ProcessMetadata(const RunMetadata &run_metadata);
   // Returns a string detailing the accumulated runtime stats in a tab-separated
   // format which can be pasted into a spreadsheet for further analysis.
@@ -147,15 +146,16 @@ class StatSummarizer {
   // Prints the string returned by GetOutputString().
   void PrintOperatorStats() const;
-  void ComputeStatsByType(std::map<std::string, int64_t>* node_type_map_count,
-                          std::map<std::string, int64_t>* node_type_map_time,
-                          std::map<std::string, int64_t>* node_type_map_memory,
-                          std::map<std::string, int64_t>* node_type_map_times_called,
-                          int64_t* accumulated_us) const;
+  void ComputeStatsByType(
+      std::map<std::string, int64_t> *node_type_map_count,
+      std::map<std::string, int64_t> *node_type_map_time,
+      std::map<std::string, int64_t> *node_type_map_memory,
+      std::map<std::string, int64_t> *node_type_map_times_called,
+      int64_t *accumulated_us) const;
   std::string GetStatsByNodeType() const;
-  std::string GetStatsByMetric(const std::string& title,
+  std::string GetStatsByMetric(const std::string &title,
                                SortingMetric sorting_metric,
                                int num_stats) const;
@@ -165,7 +165,7 @@
   int num_runs() const { return run_total_us_.count(); }
   // Returns stats of total microseconds spent by all nodes in each run.
-  const Stat<int64_t>& run_total_us() const { return run_total_us_; }
+  const Stat<int64_t> &run_total_us() const { return run_total_us_; }
 private:
   struct Detail {
@@ -179,12 +179,12 @@ class StatSummarizer {
   };
   void OrderNodesByMetric(SortingMetric sorting_metric,
-                          std::vector<const Detail*>* details) const;
-  std::string HeaderString(const std::string& title) const;
-  std::string ColumnString(const Detail& detail,
+                          std::vector<const Detail *> *details) const;
+  std::string HeaderString(const std::string &title) const;
+  std::string ColumnString(const Detail &detail,
                            const int64_t cumulative_stat_on_node,
-                           const Stat<int64_t>& stat) const;
+                           const Stat<int64_t> &stat) const;
   Stat<int64_t> run_total_us_;
   Stat<int64_t> memory_;
...
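For context, the header above declares the profiling aggregator used by MACE's benchmark tooling: each run's RunMetadata is folded into per-node Stat accumulators. A minimal usage sketch, assuming the include path, a default-constructible StatSummarizerOptions, and a RunModelOnce helper, none of which are shown in this diff:

#include "mace/core/stat_summarizer.h"  // assumed header location

void BenchmarkModel(int num_runs) {
  mace::StatSummarizerOptions options;        // assumed default-constructible
  mace::StatSummarizer summarizer(options);
  for (int i = 0; i < num_runs; ++i) {
    mace::RunMetadata run_metadata;
    RunModelOnce(&run_metadata);              // hypothetical: one inference run
    summarizer.ProcessMetadata(run_metadata); // aggregate this run's timings
  }
  summarizer.PrintOperatorStats();            // per-operator summary table
}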
@@ -10,19 +10,21 @@ namespace mace {
 namespace {
 bool StringConsume(string &arg, const string &x) {
-  if ((arg.size() >= x.size())
-      && (memcmp(arg.data(), x.data(), x.size()) == 0)) {
+  if ((arg.size() >= x.size()) &&
+      (memcmp(arg.data(), x.data(), x.size()) == 0)) {
     arg = arg.substr(x.size());
     return true;
   }
   return false;
 }
-bool ParseStringFlag(string arg, string flag,
-                     string *dst, bool *value_parsing_ok) {
+bool ParseStringFlag(string arg,
+                     string flag,
+                     string *dst,
+                     bool *value_parsing_ok) {
   *value_parsing_ok = true;
-  if (StringConsume(arg, "--") && StringConsume(arg, flag)
-      && StringConsume(arg, "=")) {
+  if (StringConsume(arg, "--") && StringConsume(arg, flag) &&
+      StringConsume(arg, "=")) {
     *dst = arg;
     return true;
   }
@@ -30,11 +32,13 @@ bool ParseStringFlag(string arg, string flag,
   return false;
 }
-bool ParseInt32Flag(string arg, string flag,
-                    int32_t *dst, bool *value_parsing_ok) {
+bool ParseInt32Flag(string arg,
+                    string flag,
+                    int32_t *dst,
+                    bool *value_parsing_ok) {
   *value_parsing_ok = true;
-  if (StringConsume(arg, "--") && StringConsume(arg, flag)
-      && StringConsume(arg, "=")) {
+  if (StringConsume(arg, "--") && StringConsume(arg, flag) &&
+      StringConsume(arg, "=")) {
     char extra;
     if (sscanf(arg.data(), "%d%c", dst, &extra) != 1) {
       LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
@@ -47,11 +51,13 @@ bool ParseInt32Flag(string arg, string flag,
   return false;
 }
-bool ParseInt64Flag(string arg, string flag,
-                    long long *dst, bool *value_parsing_ok) {
+bool ParseInt64Flag(string arg,
+                    string flag,
+                    long long *dst,
+                    bool *value_parsing_ok) {
   *value_parsing_ok = true;
-  if (StringConsume(arg, "--") && StringConsume(arg, flag)
-      && StringConsume(arg, "=")) {
+  if (StringConsume(arg, "--") && StringConsume(arg, flag) &&
+      StringConsume(arg, "=")) {
     char extra;
     if (sscanf(arg.data(), "%lld%c", dst, &extra) != 1) {
       LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
@@ -64,8 +70,7 @@ bool ParseInt64Flag(string arg, string flag,
   return false;
 }
-bool ParseBoolFlag(string arg, string flag,
-                   bool *dst, bool *value_parsing_ok) {
+bool ParseBoolFlag(string arg, string flag, bool *dst, bool *value_parsing_ok) {
   *value_parsing_ok = true;
   if (StringConsume(arg, "--") && StringConsume(arg, flag)) {
     if (arg.empty()) {
@@ -90,11 +95,13 @@ bool ParseBoolFlag(string arg, string flag,
   return false;
 }
-bool ParseFloatFlag(string arg, string flag,
-                    float *dst, bool *value_parsing_ok) {
+bool ParseFloatFlag(string arg,
+                    string flag,
+                    float *dst,
+                    bool *value_parsing_ok) {
   *value_parsing_ok = true;
-  if (StringConsume(arg, "--") && StringConsume(arg, flag)
-      && StringConsume(arg, "=")) {
+  if (StringConsume(arg, "--") && StringConsume(arg, flag) &&
+      StringConsume(arg, "=")) {
     char extra;
     if (sscanf(arg.data(), "%f%c", dst, &extra) != 1) {
       LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
@@ -152,7 +159,8 @@ bool Flag::Parse(string arg, bool *value_parsing_ok) const {
   return result;
 }
-/*static*/ bool Flags::Parse(int *argc, char **argv,
-                             const std::vector<Flag> &flag_list) {
+/*static*/ bool Flags::Parse(int *argc,
+                             char **argv,
+                             const std::vector<Flag> &flag_list) {
   bool result = true;
   std::vector<char *> unknown_flags;
...
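A detail worth noting in the parsers above: the extra `%c` in each `sscanf` format acts as a trailing-garbage probe. It can only match if characters remain after the number, so a clean parse returns exactly 1 and anything like "42x" returns 2 and is rejected. A self-contained sketch of the same trick (standalone illustration, not code from this commit):

#include <cstdio>
#include <string>

// Returns true only if `value` is a complete integer with nothing after it.
bool ParseInt32Value(const std::string &value, int *dst) {
  char extra;
  // "42"  -> sscanf returns 1 (number only)         -> accepted
  // "42x" -> sscanf returns 2 (number + extra char) -> rejected
  return std::sscanf(value.c_str(), "%d%c", dst, &extra) == 1;
}

int main() {
  int v = 0;
  std::printf("%d\n", ParseInt32Value("42", &v));   // prints 1, v == 42
  std::printf("%d\n", ParseInt32Value("42x", &v));  // prints 0
  return 0;
}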
@@ -39,14 +39,12 @@ class Flags {
   // with matching flags, and remove the matching arguments from (*argc, argv).
   // Return true iff all recognized flag values were parsed correctly, and the
   // first remaining argument is not "--help".
-  static bool Parse(int *argc,
-                    char **argv,
-                    const std::vector<Flag> &flag_list);
+  static bool Parse(int *argc, char **argv, const std::vector<Flag> &flag_list);
   // Return a usage message with command line cmdline, and the
   // usage_text strings in flag_list[].
   static string Usage(const string &cmdline,
-                      const std::vector <Flag> &flag_list);
+                      const std::vector<Flag> &flag_list);
 };
 }  // namespace mace
...
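Typical call-site wiring for the Flags API declared above, sketched under assumptions: the include path and the three-argument Flag constructor (name, destination, usage text) follow common usage in MACE's tools but are not shown in this diff:

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

#include "mace/utils/command_line_flags.h"  // assumed header location

int main(int argc, char **argv) {
  std::string model_file;
  int32_t num_runs = 10;
  std::vector<mace::Flag> flag_list = {
      // Flag(name, dst, usage_text) is assumed for illustration.
      mace::Flag("model_file", &model_file, "path to the model file"),
      mace::Flag("num_runs", &num_runs, "how many times to run the model"),
  };
  // Usage() builds a help string from the flag descriptions.
  const std::string usage = mace::Flags::Usage(argv[0], flag_list);
  // Parse() strips recognized "--flag=value" arguments out of argv.
  if (!mace::Flags::Parse(&argc, argv, flag_list)) {
    std::cerr << usage;
    return 1;
  }
  // ... run the tool using model_file / num_runs ...
  return 0;
}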