Xiaomi / Mace
Commit ebff986d

Reformat with google coding style

Authored Oct 17, 2017 by Liangliang He
Parent: a23d053d

Showing 80 changed files with 850 additions and 838 deletions (+850, -838)

Changed files:
mace/core/common.h (+2, -2)
mace/core/logging.cc (+6, -6)
mace/core/logging.h (+13, -13)
mace/core/net.cc (+18, -15)
mace/core/net.h (+11, -11)
mace/core/proto_utils.cc (+37, -37)
mace/core/proto_utils.h (+63, -63)
mace/core/registry.h (+9, -9)
mace/core/runtime/opencl/opencl_allocator.cc (+1, -1)
mace/core/runtime/opencl/opencl_runtime.cc (+3, -1)
mace/core/serializer.h (+2, -2)
mace/core/tensor.h (+9, -9)
mace/core/testing/test_benchmark.cc (+9, -9)
mace/core/testing/test_benchmark.h (+8, -8)
mace/core/testing/test_benchmark_main.cc (+1, -1)
mace/core/types.cc (+1, -1)
mace/core/workspace.cc (+13, -14)
mace/core/workspace.h (+6, -6)
mace/examples/benchmark_example.cc (+4, -4)
mace/examples/mace_run.cc (+9, -7)
mace/kernels/BUILD (+1, -1)
mace/kernels/addn.h (+2, -2)
mace/kernels/batch_norm.h (+14, -15)
mace/kernels/channel_shuffle.h (+4, -5)
mace/kernels/concat.h (+4, -4)
mace/kernels/conv_2d.h (+11, -13)
mace/kernels/conv_pool_2d_util.cc (+6, -7)
mace/kernels/conv_pool_2d_util.h (+5, -5)
mace/kernels/depthwise_conv2d.h (+29, -31)
mace/kernels/global_avg_pooling.h (+1, -3)
mace/kernels/neon/avg_pooling_neon_2x2.cc (+1, -1)
mace/kernels/neon/avg_pooling_neon_3x3.cc (+6, -5)
mace/kernels/neon/batch_norm_neon.cc (+8, -8)
mace/kernels/neon/conv_2d_neon.cc (+14, -15)
mace/kernels/neon/conv_2d_neon_1x1.cc (+51, -44)
mace/kernels/neon/conv_2d_neon_3x3.cc (+81, -64)
mace/kernels/neon/conv_2d_neon_5x5.cc (+3, -3)
mace/kernels/neon/depthwise_conv_neon.cc (+20, -20)
mace/kernels/neon/global_avg_pooling_neon.cc (+2, -4)
mace/kernels/neon/pooling_neon.cc (+6, -6)
mace/ops/addn.h (+5, -5)
mace/ops/addn_benchmark.cc (+1, -1)
mace/ops/addn_test.cc (+1, -1)
mace/ops/batch_norm.h (+18, -19)
mace/ops/batch_norm_benchmark.cc (+1, -1)
mace/ops/batch_norm_test.cc (+3, -3)
mace/ops/channel_shuffle.h (+2, -2)
mace/ops/channel_shuffle_benchmark.cc (+11, -16)
mace/ops/channel_shuffle_test.cc (+5, -8)
mace/ops/concat.h (+19, -11)
mace/ops/concat_benchmark.cc (+2, -3)
mace/ops/concat_test.cc (+7, -8)
mace/ops/conv_2d.h (+5, -6)
mace/ops/conv_2d_benchmark.cc (+10, -9)
mace/ops/conv_2d_test.cc (+7, -7)
mace/ops/conv_pool_2d_base.h (+4, -3)
mace/ops/depthwise_conv2d.cc (+4, -2)
mace/ops/depthwise_conv2d.h (+10, -10)
mace/ops/depthwise_conv2d_test.cc (+11, -14)
mace/ops/depthwise_conv_2d_benchmark.cc (+20, -18)
mace/ops/global_avg_pooling.h (+1, -1)
mace/ops/global_avg_pooling_benchmark.cc (+10, -14)
mace/ops/global_avg_pooling_test.cc (+8, -12)
mace/ops/ops_test_util.h (+28, -29)
mace/ops/pooling.h (+7, -8)
mace/ops/pooling_benchmark.cc (+1, -1)
mace/ops/pooling_test.cc (+16, -19)
mace/ops/relu.h (+7, -7)
mace/ops/relu_benchmark.cc (+1, -1)
mace/ops/relu_test.cc (+2, -3)
mace/ops/resize_bilinear.h (+6, -6)
mace/ops/resize_bilinear_test.cc (+2, -2)
mace/proto/BUILD (+3, -3)
mace/python/tools/BUILD (+0, -1)
mace/tools/benchmark/benchmark_model.cc (+39, -33)
mace/tools/benchmark/stat_summarizer.cc (+26, -29)
mace/tools/benchmark/stat_summarizer.h (+18, -18)
mace/utils/command_line_flags.cc (+29, -21)
mace/utils/command_line_flags.h (+4, -6)
mace/utils/utils.h (+2, -2)
mace/core/common.h
@@ -26,8 +26,8 @@ typedef int64_t index_t;
 #ifndef DISABLE_COPY_AND_ASSIGN
 #define DISABLE_COPY_AND_ASSIGN(classname) \
  private:                                  \
   classname(const classname &) = delete;   \
-  classname & operator=(const classname &) = delete
+  classname &operator=(const classname &) = delete
 #endif
 #define MACE_NOT_IMPLEMENTED MACE_CHECK(false, "not implemented")
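For context, this macro is used throughout the commit (for example in SimpleNet and Serializer below) to make a class non-copyable. A minimal usage sketch; Foo is a hypothetical class name:

class Foo {
 public:
  Foo() {}

 private:
  int state_;
  // The macro switches access to private:, deletes the copy constructor,
  // and deletes copy assignment; the trailing semicolon completes it.
  DISABLE_COPY_AND_ASSIGN(Foo);
};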
mace/core/logging.cc
@@ -14,7 +14,7 @@
 namespace mace {
 namespace internal {

 LogMessage::LogMessage(const char *fname, int line, int severity)
     : fname_(fname), line_(line), severity_(severity) {}

 #if defined(PLATFORM_POSIX_ANDROID)
@@ -43,7 +43,7 @@ void LogMessage::GenerateLogMessage() {
   }
   std::stringstream ss;
   const char *const partial_name = strrchr(fname_, '/');
   ss << (partial_name != nullptr ? partial_name + 1 : fname_) << ":" << line_
      << " " << str();
   __android_log_write(android_log_level, "native", ss.str().c_str());
@@ -69,7 +69,7 @@ void LogMessage::GenerateLogMessage() {
 namespace {

 // Parse log level (int64_t) from environment variable (char*)
 int64_t LogLevelStrToInt(const char *mace_env_var_val) {
   if (mace_env_var_val == nullptr) {
     return 0;
   }
@@ -89,12 +89,12 @@ int64_t LogLevelStrToInt(const char* mace_env_var_val) {
 }

 int64_t MinLogLevelFromEnv() {
   const char *mace_env_var_val = getenv("MACE_CPP_MIN_LOG_LEVEL");
   return LogLevelStrToInt(mace_env_var_val);
 }

 int64_t MinVLogLevelFromEnv() {
   const char *mace_env_var_val = getenv("MACE_CPP_MIN_VLOG_LEVEL");
   return LogLevelStrToInt(mace_env_var_val);
 }
@@ -111,7 +111,7 @@ int64_t LogMessage::MinVLogLevel() {
   return min_vlog_level;
 }

 LogMessageFatal::LogMessageFatal(const char *file, int line)
     : LogMessage(file, line, FATAL) {}
 LogMessageFatal::~LogMessageFatal() {
   // abort() ensures we don't return (we promised we would not via
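The two getenv lookups above are how runtime verbosity is controlled. A hedged sketch of using them from a test harness; MinVLogLevel() appears to cache its result in min_vlog_level, so the variable must be set before the first log statement runs:

#include <cstdlib>

int main(int argc, char **argv) {
  // Raise VLOG verbosity to 1 before any logging happens; once the level
  // has been read and cached, later changes to the env var have no effect.
  setenv("MACE_CPP_MIN_VLOG_LEVEL", "1", /*overwrite=*/1);
  // ... construct and run nets; VLOG(1) messages are now emitted ...
  return 0;
}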
mace/core/logging.h
@@ -23,23 +23,23 @@ namespace internal {
 using std::string;

 inline void MakeStringInternal(std::stringstream & /*ss*/) {}

 template <typename T>
 inline void MakeStringInternal(std::stringstream &ss, const T &t) {
   ss << t;
 }

 template <typename T, typename... Args>
 inline void MakeStringInternal(std::stringstream &ss,
                                const T &t,
                                const Args &... args) {
   MakeStringInternal(ss, t);
   MakeStringInternal(ss, args...);
 }

 template <typename... Args>
 string MakeString(const Args &... args) {
   std::stringstream ss;
   MakeStringInternal(ss, args...);
   return ss.str();
@@ -48,7 +48,7 @@ string MakeString(const Args&... args) {
 template <typename T>
 string MakeString(const std::vector<T> &args) {
   std::stringstream ss;
   for (const T &arg : args) {
     ss << arg << ", ";
   }
   return ss.str();
@@ -56,14 +56,14 @@ string MakeString(const std::vector<T> &args) {
 // Specializations for already-a-string types.
 template <>
 inline string MakeString(const string &str) {
   return str;
 }
 inline string MakeString(const char *c_str) { return string(c_str); }

 class LogMessage : public std::basic_ostringstream<char> {
  public:
   LogMessage(const char *fname, int line, int severity);
   ~LogMessage();

   // Returns the minimum log level for VLOG statements.
@@ -75,7 +75,7 @@ class LogMessage : public std::basic_ostringstream<char> {
   void GenerateLogMessage();

  private:
   const char *fname_;
   int line_;
   int severity_;
 };
@@ -84,7 +84,7 @@ class LogMessage : public std::basic_ostringstream<char> {
 // logging this message.
 class LogMessageFatal : public LogMessage {
  public:
   LogMessageFatal(const char *file, int line);
   ~LogMessageFatal();
 };
@@ -136,7 +136,7 @@ class LogMessageFatal : public LogMessage {
 #endif

 template <typename T>
 T &&CheckNotNull(const char *file, int line, const char *exprtext, T &&t) {
   if (t == nullptr) {
     LogMessageFatal(file, line) << string(exprtext);
   }
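A quick sketch of what the MakeString helpers above produce; the values are illustrative:

using mace::internal::MakeString;

std::string a = MakeString("batch=", 32, ", lr=", 0.1);  // "batch=32, lr=0.1"
std::string b = MakeString(std::vector<int>{1, 224, 224, 3});
// "1, 224, 224, 3, " -- note the vector overload appends ", " after every
// element, including the last one.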
mace/core/net.cc
@@ -7,18 +7,18 @@
 namespace mace {

 NetBase::NetBase(const std::shared_ptr<const NetDef> &net_def,
                  Workspace *ws,
                  DeviceType type)
     : name_(net_def->name()) {}

 SimpleNet::SimpleNet(const std::shared_ptr<const NetDef> &net_def,
                      Workspace *ws,
                      DeviceType type)
     : NetBase(net_def, ws, type) {
   VLOG(1) << "Constructing SimpleNet " << net_def->name();
   for (int idx = 0; idx < net_def->op_size(); ++idx) {
     const auto &operator_def = net_def->op(idx);
     VLOG(1) << "Creating operator " << operator_def.name() << ":"
             << operator_def.type();
     std::unique_ptr<OperatorBase> op{nullptr};
@@ -29,26 +29,29 @@ SimpleNet::SimpleNet(const std::shared_ptr<const NetDef>& net_def,
   }
 }

 bool SimpleNet::Run(RunMetadata *run_metadata) {
   VLOG(1) << "Running net " << name_;
   for (auto &op : operators_) {
     VLOG(1) << "Running operator " << op->debug_def().name() << "("
             << op->debug_def().type() << ").";
     OperatorStats *op_stats = nullptr;
     if (run_metadata) {
       op_stats = run_metadata->add_op_stats();
       op_stats->set_operator_name(op->debug_def().name());
       op_stats->set_type(op->debug_def().type());
       op_stats->set_all_start_micros(NowInMicroSec());
       op_stats->set_op_start_rel_micros(NowInMicroSec() -
                                         op_stats->all_start_micros());
     }
     if (!op->Run()) {
       LOG(ERROR) << "Operator failed: " << ProtoDebugString(op->debug_def());
       return false;
     }
     if (op_stats) {
       op_stats->set_op_end_rel_micros(NowInMicroSec() -
                                       op_stats->all_start_micros());
       op_stats->set_all_end_rel_micros(NowInMicroSec() -
                                        op_stats->all_start_micros());
     }
     VLOG(1) << "Op " << op->debug_def().name()
             << " has shape: " << internal::MakeString(op->Output(0)->shape());
@@ -56,15 +59,15 @@ bool SimpleNet::Run(RunMetadata* run_metadata) {
   return true;
 }

 unique_ptr<NetBase> CreateNet(const NetDef &net_def,
                               Workspace *ws,
                               DeviceType type) {
   std::shared_ptr<NetDef> tmp_net_def(new NetDef(net_def));
   return CreateNet(tmp_net_def, ws, type);
 }

 unique_ptr<NetBase> CreateNet(const std::shared_ptr<const NetDef> &net_def,
                               Workspace *ws,
                               DeviceType type) {
   unique_ptr<NetBase> net(new SimpleNet(net_def, ws, type));
   return net;
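Taken together with the proto helpers that appear later in this diff, the factory reads as follows. A hedged usage sketch; the model path and the DeviceType enumerator are placeholders:

NetDef net_def;
MACE_CHECK(ReadProtoFromFile("model.pb", &net_def));  // "model.pb" is a placeholder

Workspace ws;
auto net = CreateNet(net_def, &ws, DeviceType::CPU);  // assumes a CPU enumerator

RunMetadata metadata;  // optional; pass nullptr to skip per-op stats
if (!net->Run(&metadata)) {
  LOG(ERROR) << "Net " << net->Name() << " failed";
}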
mace/core/net.h
@@ -15,14 +15,14 @@ namespace mace {
 class NetBase {
  public:
   NetBase(const std::shared_ptr<const NetDef> &net_def,
           Workspace *ws,
           DeviceType type);
   virtual ~NetBase() noexcept {}

   virtual bool Run(RunMetadata *run_metadata = nullptr) = 0;

   const string &Name() const { return name_; }

  protected:
   string name_;
@@ -32,11 +32,11 @@ class NetBase {
 class SimpleNet : public NetBase {
  public:
   SimpleNet(const std::shared_ptr<const NetDef> &net_def,
             Workspace *ws,
             DeviceType type);

   bool Run(RunMetadata *run_metadata = nullptr) override;

  protected:
   vector<unique_ptr<OperatorBase> > operators_;
@@ -44,11 +44,11 @@ class SimpleNet : public NetBase {
   DISABLE_COPY_AND_ASSIGN(SimpleNet);
 };

 unique_ptr<NetBase> CreateNet(const NetDef &net_def,
                               Workspace *ws,
                               DeviceType type);
 unique_ptr<NetBase> CreateNet(const std::shared_ptr<const NetDef> &net_def,
                               Workspace *ws,
                               DeviceType type);

 }  // namespace mace
mace/core/proto_utils.cc
@@ -18,7 +18,7 @@
 namespace mace {

 bool ReadStringFromFile(const char *filename, string *str) {
   std::ifstream ifs(filename, std::ios::in);
   if (!ifs) {
     VLOG(1) << "File cannot be opened: " << filename
@@ -33,7 +33,7 @@ bool ReadStringFromFile(const char* filename, string* str) {
   return true;
 }

 bool WriteStringToFile(const string &str, const char *filename) {
   std::ofstream ofs(filename, std::ios::out | std::ios::trunc);
   if (!ofs.is_open()) {
     VLOG(1) << "File cannot be created: " << filename
@@ -54,15 +54,15 @@ bool WriteStringToFile(const string& str, const char* filename) {
 namespace {

 class IfstreamInputStream : public ::google::protobuf::io::CopyingInputStream {
  public:
   explicit IfstreamInputStream(const string &filename)
       : ifs_(filename.c_str(), std::ios::in | std::ios::binary) {}
   ~IfstreamInputStream() { ifs_.close(); }

   int Read(void *buffer, int size) {
     if (!ifs_) {
       return -1;
     }
     ifs_.read(static_cast<char *>(buffer), size);
     return ifs_.gcount();
   }
@@ -71,7 +71,7 @@ class IfstreamInputStream : public ::google::protobuf::io::CopyingInputStream {
 };
 }  // namespace

 bool ReadProtoFromBinaryFile(const char *filename, MessageLite *proto) {
   ::google::protobuf::io::CopyingInputStreamAdaptor stream(
       new IfstreamInputStream(filename));
   stream.SetOwnsCopyingStream(true);
@@ -82,8 +82,8 @@ bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto) {
   return proto->ParseFromCodedStream(&coded_stream);
 }

 void WriteProtoToBinaryFile(const MessageLite & /*proto*/,
                             const char * /*filename*/) {
   LOG(FATAL) << "Not implemented yet.";
 }
@@ -98,25 +98,25 @@ using ::google::protobuf::io::CodedInputStream;
 using ::google::protobuf::io::ZeroCopyOutputStream;
 using ::google::protobuf::io::CodedOutputStream;

 bool ReadProtoFromTextFile(const char *filename, Message *proto) {
   int fd = open(filename, O_RDONLY);
   MACE_CHECK(fd != -1, "File not found: ", filename);
   FileInputStream *input = new FileInputStream(fd);
   bool success = google::protobuf::TextFormat::Parse(input, proto);
   delete input;
   close(fd);
   return success;
 }

 void WriteProtoToTextFile(const Message &proto, const char *filename) {
   int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644);
   FileOutputStream *output = new FileOutputStream(fd);
   MACE_CHECK(google::protobuf::TextFormat::Print(proto, output));
   delete output;
   close(fd);
 }

 bool ReadProtoFromBinaryFile(const char *filename, MessageLite *proto) {
 #if defined(_MSC_VER)  // for MSC compiler binary flag needs to be specified
   int fd = open(filename, O_RDONLY | O_BINARY);
 #else
@@ -135,7 +135,7 @@ bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto) {
   return success;
 }

 void WriteProtoToBinaryFile(const MessageLite &proto, const char *filename) {
   int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644);
   MACE_CHECK(fd != -1, "File cannot be created: ", filename,
              " error number: ", errno);
@@ -150,8 +150,8 @@ void WriteProtoToBinaryFile(const MessageLite& proto, const char* filename) {
 #endif  // MACE_USE_LITE_PROTO

 ArgumentHelper::ArgumentHelper(const OperatorDef &def) {
   for (auto &arg : def.arg()) {
     if (arg_map_.find(arg.name()) != arg_map_.end()) {
       MACE_CHECK(
           arg.SerializeAsString() == arg_map_[arg.name()].SerializeAsString(),
@@ -167,8 +167,8 @@ ArgumentHelper::ArgumentHelper(const OperatorDef& def) {
   }
 }

 ArgumentHelper::ArgumentHelper(const NetDef &netdef) {
   for (auto &arg : netdef.arg()) {
     MACE_CHECK(arg_map_.count(arg.name()) == 0,
                "Duplicated argument name found in net def: ",
                ProtoDebugString(netdef));
@@ -176,7 +176,7 @@ ArgumentHelper::ArgumentHelper(const NetDef& netdef) {
   }
 }

 bool ArgumentHelper::HasArgument(const string &name) const {
   return arg_map_.count(name);
 }
@@ -184,7 +184,7 @@ namespace {
 // Helper function to verify that conversion between types won't loose any
 // significant bit.
 template <typename InputType, typename TargetType>
 bool SupportsLosslessConversion(const InputType &value) {
   return static_cast<InputType>(static_cast<TargetType>(value)) == value;
 }
 }
@@ -192,8 +192,8 @@ bool SupportsLosslessConversion(const InputType& value) {
 #define INSTANTIATE_GET_SINGLE_ARGUMENT(T, fieldname,                         \
                                         enforce_lossless_conversion)          \
   template <>                                                                  \
   T ArgumentHelper::GetSingleArgument<T>(const string &name,                   \
                                          const T &default_value) const {       \
     if (arg_map_.count(name) == 0) {                                           \
       VLOG(1) << "Using default parameter value " << default_value             \
               << " for parameter " << name;                                    \
@@ -211,7 +211,7 @@ bool SupportsLosslessConversion(const InputType& value) {
     return value;                                                              \
   }                                                                            \
   template <>                                                                  \
   bool ArgumentHelper::HasSingleArgumentOfType<T>(const string &name) const {  \
     if (arg_map_.count(name) == 0) {                                           \
       return false;                                                            \
     }                                                                          \
@@ -235,12 +235,12 @@ INSTANTIATE_GET_SINGLE_ARGUMENT(string, s, false)
                                           enforce_lossless_conversion)         \
   template <>                                                                   \
   vector<T> ArgumentHelper::GetRepeatedArgument<T>(                             \
-      const string & name, const std::vector<T>& default_value) const {         \
+      const string &name, const std::vector<T> &default_value) const {          \
     if (arg_map_.count(name) == 0) {                                            \
       return default_value;                                                     \
     }                                                                           \
     vector<T> values;                                                           \
     for (const auto &v : arg_map_.at(name).fieldname()) {                       \
       if (enforce_lossless_conversion) {                                        \
         auto supportsConversion =                                               \
             SupportsLosslessConversion<decltype(v), T>(v);                      \
@@ -267,7 +267,7 @@ INSTANTIATE_GET_REPEATED_ARGUMENT(string, strings, false)
 #define MACE_MAKE_SINGULAR_ARGUMENT(T, fieldname)               \
   template <>                                                   \
-  Argument MakeArgument(const string & name, const T& value) {  \
+  Argument MakeArgument(const string &name, const T &value) {   \
     Argument arg;                                               \
     arg.set_name(name);                                         \
     arg.set_##fieldname(value);                                 \
@@ -282,7 +282,7 @@ MACE_MAKE_SINGULAR_ARGUMENT(string, s)
 #undef MACE_MAKE_SINGULAR_ARGUMENT

 template <>
 Argument MakeArgument(const string &name, const MessageLite &value) {
   Argument arg;
   arg.set_name(name);
   arg.set_s(value.SerializeAsString());
@@ -291,10 +291,10 @@ Argument MakeArgument(const string& name, const MessageLite& value) {
 #define MACE_MAKE_REPEATED_ARGUMENT(T, fieldname)                      \
   template <>                                                          \
-  Argument MakeArgument(const string & name, const vector<T>& value) { \
+  Argument MakeArgument(const string &name, const vector<T> &value) {  \
     Argument arg;                                                      \
     arg.set_name(name);                                                \
     for (const auto &v : value) {                                      \
       arg.add_##fieldname(v);                                          \
     }                                                                  \
     return arg;                                                        \
@@ -306,8 +306,8 @@ MACE_MAKE_REPEATED_ARGUMENT(int64_t, ints)
 MACE_MAKE_REPEATED_ARGUMENT(string, strings)
 #undef MACE_MAKE_REPEATED_ARGUMENT

 const Argument &GetArgument(const OperatorDef &def, const string &name) {
   for (const Argument &arg : def.arg()) {
     if (arg.name() == name) {
       return arg;
     }
@@ -318,10 +318,10 @@ const Argument& GetArgument(const OperatorDef& def, const string& name) {
   return std::move(Argument());
 }

 bool GetFlagArgument(const OperatorDef &def,
                      const string &name,
                      bool def_value) {
   for (const Argument &arg : def.arg()) {
     if (arg.name() == name) {
       MACE_CHECK(arg.has_i(), "Can't parse argument as bool: ",
                  ProtoDebugString(arg));
@@ -331,9 +331,9 @@ bool GetFlagArgument(const OperatorDef& def,
   return def_value;
 }

 Argument *GetMutableArgument(const string &name,
                              const bool create_if_missing,
                              OperatorDef *def) {
   for (int i = 0; i < def->arg_size(); ++i) {
     if (def->arg(i).name() == name) {
       return def->mutable_arg(i);
@@ -341,7 +341,7 @@ Argument* GetMutableArgument(const string& name,
   }
   // If no argument of the right name is found...
   if (create_if_missing) {
     Argument *arg = def->add_arg();
     arg->set_name(name);
     return arg;
   } else {
mace/core/proto_utils.h
@@ -21,56 +21,56 @@ using std::string;
 using ::google::protobuf::MessageLite;

 // Common interfaces that reads file contents into a string.
 bool ReadStringFromFile(const char *filename, string *str);
 bool WriteStringToFile(const string &str, const char *filename);

 // Common interfaces that are supported by both lite and full protobuf.
 bool ReadProtoFromBinaryFile(const char *filename, MessageLite *proto);
 inline bool ReadProtoFromBinaryFile(const string filename, MessageLite *proto) {
   return ReadProtoFromBinaryFile(filename.c_str(), proto);
 }

 void WriteProtoToBinaryFile(const MessageLite &proto, const char *filename);
 inline void WriteProtoToBinaryFile(const MessageLite &proto,
                                    const string &filename) {
   return WriteProtoToBinaryFile(proto, filename.c_str());
 }

 #ifdef MACE_USE_LITE_PROTO

 inline string ProtoDebugString(const MessageLite &proto) {
   return proto.SerializeAsString();
 }

 // Text format MessageLite wrappers: these functions do nothing but just
 // allowing things to compile. It will produce a runtime error if you are using
 // MessageLite but still want text support.
 inline bool ReadProtoFromTextFile(const char * /*filename*/,
                                   MessageLite * /*proto*/) {
   LOG(FATAL) << "If you are running lite version, you should not be "
              << "calling any text-format protobuffers.";
   return false;  // Just to suppress compiler warning.
 }
 inline bool ReadProtoFromTextFile(const string filename, MessageLite *proto) {
   return ReadProtoFromTextFile(filename.c_str(), proto);
 }

 inline void WriteProtoToTextFile(const MessageLite & /*proto*/,
                                  const char * /*filename*/) {
   LOG(FATAL) << "If you are running lite version, you should not be "
              << "calling any text-format protobuffers.";
 }
 inline void WriteProtoToTextFile(const MessageLite &proto,
                                  const string &filename) {
   return WriteProtoToTextFile(proto, filename.c_str());
 }

 inline bool ReadProtoFromFile(const char *filename, MessageLite *proto) {
   return (ReadProtoFromBinaryFile(filename, proto) ||
           ReadProtoFromTextFile(filename, proto));
 }

 inline bool ReadProtoFromFile(const string &filename, MessageLite *proto) {
   return ReadProtoFromFile(filename.c_str(), proto);
 }
@@ -78,27 +78,27 @@ inline bool ReadProtoFromFile(const string& filename, MessageLite* proto) {
 using ::google::protobuf::Message;

 inline string ProtoDebugString(const Message &proto) {
   return proto.ShortDebugString();
 }

 bool ReadProtoFromTextFile(const char *filename, Message *proto);
 inline bool ReadProtoFromTextFile(const string filename, Message *proto) {
   return ReadProtoFromTextFile(filename.c_str(), proto);
 }

 void WriteProtoToTextFile(const Message &proto, const char *filename);
 inline void WriteProtoToTextFile(const Message &proto, const string &filename) {
   return WriteProtoToTextFile(proto, filename.c_str());
 }

 // Read Proto from a file, letting the code figure out if it is text or binary.
 inline bool ReadProtoFromFile(const char *filename, Message *proto) {
   return (ReadProtoFromBinaryFile(filename, proto) ||
           ReadProtoFromTextFile(filename, proto));
 }

 inline bool ReadProtoFromFile(const string &filename, Message *proto) {
   return ReadProtoFromFile(filename.c_str(), proto);
 }
@@ -107,21 +107,21 @@ inline bool ReadProtoFromFile(const string& filename, Message* proto) {
 template <class IterableInputs = std::initializer_list<string>,
           class IterableOutputs = std::initializer_list<string>,
           class IterableArgs = std::initializer_list<Argument>>
 OperatorDef CreateOperatorDef(const string &type,
                               const string &name,
                               const IterableInputs &inputs,
                               const IterableOutputs &outputs,
                               const IterableArgs &args) {
   OperatorDef def;
   def.set_type(type);
   def.set_name(name);
   for (const string &in : inputs) {
     def.add_input(in);
   }
   for (const string &out : outputs) {
     def.add_output(out);
   }
   for (const Argument &arg : args) {
     def.add_arg()->CopyFrom(arg);
   }
   return def;
@@ -131,10 +131,10 @@ OperatorDef CreateOperatorDef(const string& type,
 // to specify args.
 template <class IterableInputs = std::initializer_list<string>,
           class IterableOutputs = std::initializer_list<string>>
 inline OperatorDef CreateOperatorDef(const string &type,
                                      const string &name,
                                      const IterableInputs &inputs,
                                      const IterableOutputs &outputs) {
   return CreateOperatorDef(type, name, inputs, outputs,
                            std::vector<Argument>());
 }
@@ -150,56 +150,56 @@ inline OperatorDef CreateOperatorDef(const string& type,
 class ArgumentHelper {
  public:
   template <typename Def>
   static bool HasArgument(const Def &def, const string &name) {
     return ArgumentHelper(def).HasArgument(name);
   }

   template <typename Def, typename T>
   static T GetSingleArgument(const Def &def,
                              const string &name,
                              const T &default_value) {
     return ArgumentHelper(def).GetSingleArgument<T>(name, default_value);
   }

   template <typename Def, typename T>
   static bool HasSingleArgumentOfType(const Def &def, const string &name) {
     return ArgumentHelper(def).HasSingleArgumentOfType<T>(name);
   }

   template <typename Def, typename T>
   static vector<T> GetRepeatedArgument(
       const Def &def,
       const string &name,
       const std::vector<T> &default_value = std::vector<T>()) {
     return ArgumentHelper(def).GetRepeatedArgument<T>(name, default_value);
   }

   template <typename Def, typename MessageType>
   static MessageType GetMessageArgument(const Def &def, const string &name) {
     return ArgumentHelper(def).GetMessageArgument<MessageType>(name);
   }

   template <typename Def, typename MessageType>
   static vector<MessageType> GetRepeatedMessageArgument(const Def &def,
                                                         const string &name) {
     return ArgumentHelper(def).GetRepeatedMessageArgument<MessageType>(name);
   }

   explicit ArgumentHelper(const OperatorDef &def);
   explicit ArgumentHelper(const NetDef &netdef);
   bool HasArgument(const string &name) const;

   template <typename T>
   T GetSingleArgument(const string &name, const T &default_value) const;
   template <typename T>
   bool HasSingleArgumentOfType(const string &name) const;
   template <typename T>
   vector<T> GetRepeatedArgument(
       const string &name,
       const std::vector<T> &default_value = std::vector<T>()) const;

   template <typename MessageType>
   MessageType GetMessageArgument(const string &name) const {
     MACE_CHECK(arg_map_.count(name), "Cannot find parameter named " + name);
     MessageType message;
     if (arg_map_.at(name).has_s()) {
@@ -212,7 +212,7 @@ class ArgumentHelper {
   }

   template <typename MessageType>
   vector<MessageType> GetRepeatedMessageArgument(const string &name) const {
     MACE_CHECK(arg_map_.count(name), "Cannot find parameter named " + name);
     vector<MessageType> messages(arg_map_.at(name).strings_size());
     for (int i = 0; i < messages.size(); ++i) {
@@ -226,20 +226,20 @@ class ArgumentHelper {
   std::map<string, Argument> arg_map_;
 };

 const Argument &GetArgument(const OperatorDef &def, const string &name);
 bool GetFlagArgument(const OperatorDef &def,
                      const string &name,
                      bool def_value = false);

 Argument *GetMutableArgument(const string &name,
                              const bool create_if_missing,
                              OperatorDef *def);

 template <typename T>
 Argument MakeArgument(const string &name, const T &value);

 template <typename T>
 inline void AddArgument(const string &name, const T &value, OperatorDef *def) {
   GetMutableArgument(name, true, def)->CopyFrom(MakeArgument(name, value));
 }
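A hedged sketch of these helpers working together; the op type, tensor names, and argument name are illustrative, and the int64_t instantiation is assumed from the MACE_MAKE_REPEATED_ARGUMENT(int64_t, ints) line in proto_utils.cc above:

OperatorDef def = CreateOperatorDef("Conv2D", "conv1",
                                    {"input", "filter"}, {"output"});
AddArgument("strides", std::vector<int64_t>{2, 2}, &def);

ArgumentHelper helper(def);
std::vector<int64_t> strides = helper.GetRepeatedArgument<int64_t>("strides");
bool has_padding = helper.HasArgument("padding");  // false in this sketch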
mace/core/registry.h
@@ -16,15 +16,15 @@ class Registry {
   Registry() : registry_() {}

   void Register(const SrcType &key, Creator creator) {
     std::lock_guard<std::mutex> lock(register_mutex_);
     MACE_CHECK(registry_.count(key) == 0, "Key already registered.");
     registry_[key] = creator;
   }

   inline bool Has(const SrcType &key) { return registry_.count(key) != 0; }

   unique_ptr<ObjectType> Create(const SrcType &key, Args... args) {
     if (registry_.count(key) == 0) {
       LOG(FATAL) << "Key not registered: " << key;
     }
@@ -36,7 +36,7 @@ class Registry {
   */
   vector<SrcType> Keys() {
     vector<SrcType> keys;
     for (const auto &it : registry_) {
       keys.push_back(it.first);
     }
     return keys;
@@ -52,8 +52,8 @@ class Registry {
 template <class SrcType, class ObjectType, class... Args>
 class Registerer {
  public:
   Registerer(const SrcType &key,
              Registry<SrcType, ObjectType, Args...> *registry,
              typename Registry<SrcType, ObjectType, Args...>::Creator creator) {
     registry->Register(key, creator);
   }
@@ -73,13 +73,13 @@ class Registerer {
 #endif

 #define MACE_DECLARE_TYPED_REGISTRY(RegistryName, SrcType, ObjectType, ...) \
   Registry<SrcType, ObjectType, ##__VA_ARGS__> *RegistryName();             \
   typedef Registerer<SrcType, ObjectType, ##__VA_ARGS__>                    \
       Registerer##RegistryName;

 #define MACE_DEFINE_TYPED_REGISTRY(RegistryName, SrcType, ObjectType, ...) \
   Registry<SrcType, ObjectType, ##__VA_ARGS__> *RegistryName() {           \
     static Registry<SrcType, ObjectType, ##__VA_ARGS__> *registry =        \
         new Registry<SrcType, ObjectType, ##__VA_ARGS__>();                \
     return registry;                                                       \
   }
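A hedged sketch of the declare/define pair above. ShapeRegistry, Shape, Circle, and MakeCircle are illustrative names, and the Creator type is assumed to accept a plain function returning unique_ptr<ObjectType>:

// In a header:
MACE_DECLARE_TYPED_REGISTRY(ShapeRegistry, std::string, Shape);

// In one .cc file:
MACE_DEFINE_TYPED_REGISTRY(ShapeRegistry, std::string, Shape);

std::unique_ptr<Shape> MakeCircle() {
  return std::unique_ptr<Shape>(new Circle);
}
// Static construction runs Register("circle", MakeCircle) at load time.
static RegistererShapeRegistry g_circle("circle", ShapeRegistry(), MakeCircle);

// Anywhere else:
std::unique_ptr<Shape> shape = ShapeRegistry()->Create("circle");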
mace/core/runtime/opencl/opencl_allocator.cc
@@ -3,8 +3,8 @@
 //
 #include "mace/core/runtime/opencl/opencl_allocator.h"
-#include "mace/core/runtime/opencl/opencl_runtime.h"
 #include "mace/core/runtime/opencl/cl2.hpp"
+#include "mace/core/runtime/opencl/opencl_runtime.h"

 namespace mace {
mace/core/runtime/opencl/opencl_runtime.cc
@@ -30,7 +30,9 @@ bool ReadSourceFile(const char *filename, std::string *content) {
   return true;
 }

-bool BuildProgram(OpenCLRuntime *runtime, const char *filename, cl::Program *program) {
+bool BuildProgram(OpenCLRuntime *runtime,
+                  const char *filename,
+                  cl::Program *program) {
   MACE_CHECK_NOTNULL(filename);
   MACE_CHECK_NOTNULL(program);
mace/core/serializer.h
@@ -16,9 +16,9 @@ class Serializer {
   Serializer() {}
   ~Serializer() {}

   unique_ptr<TensorProto> Serialize(const Tensor &tensor, const string &name);

   unique_ptr<Tensor> Deserialize(const TensorProto &proto, DeviceType type);

   DISABLE_COPY_AND_ASSIGN(Serializer);
 };
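A hedged round-trip sketch of the interface above; the tensor contents and the DeviceType value are placeholders:

Serializer serializer;
Tensor weights;  // assume already allocated and filled

// Tensor -> TensorProto -> Tensor
unique_ptr<TensorProto> proto = serializer.Serialize(weights, "weights");
unique_ptr<Tensor> restored = serializer.Deserialize(*proto, DeviceType::CPU);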
mace/core/tensor.h
@@ -202,15 +202,15 @@ class Tensor {
   }

   class MappingGuard {
    public:
     MappingGuard(Tensor *tensor) : tensor_(tensor) {
       MACE_ASSERT(tensor_ != nullptr);
       tensor_->Map();
     }
     ~MappingGuard() { tensor_->Unmap(); }

    private:
     Tensor *tensor_;
   };

  private:
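The guard is a small RAII wrapper pairing Map() with Unmap(); a usage sketch, with the host-side access left abstract since no data accessor appears in this diff:

void ReadBack(Tensor *output) {
  Tensor::MappingGuard guard(output);  // Map() on entry
  // ... read or write host-visible memory here ...
}  // Unmap() when guard leaves scope, even on early return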
mace/core/testing/test_benchmark.cc
...
@@ -16,36 +16,36 @@
 namespace mace {
 namespace testing {
 
-static std::vector<Benchmark*>* all_benchmarks = nullptr;
+static std::vector<Benchmark*> *all_benchmarks = nullptr;
 static std::string label;
 static int64_t bytes_processed;
 static int64_t items_processed;
 static int64_t accum_time = 0;
 static int64_t start_time = 0;
 
 Benchmark::Benchmark(const char *name, void (*fn)(int))
     : name_(name), num_args_(0), fn0_(fn) {
   args_.push_back(std::make_pair(-1, -1));
   Register();
 }
 
 Benchmark::Benchmark(const char *name, void (*fn)(int, int))
     : name_(name), num_args_(1), fn1_(fn) {
   Register();
 }
 
 Benchmark::Benchmark(const char *name, void (*fn)(int, int, int))
     : name_(name), num_args_(2), fn2_(fn) {
   Register();
 }
 
 Benchmark *Benchmark::Arg(int x) {
   MACE_CHECK(num_args_ == 1);
   args_.push_back(std::make_pair(x, -1));
   return this;
 }
 
 Benchmark *Benchmark::ArgPair(int x, int y) {
   MACE_CHECK(num_args_ == 2);
   args_.push_back(std::make_pair(x, y));
   return this;
...
@@ -54,7 +54,7 @@ Benchmark* Benchmark::ArgPair(int x, int y) {
 // Run all benchmarks
 void Benchmark::Run() { Run("all"); }
 
 void Benchmark::Run(const char *pattern) {
   if (!all_benchmarks) return;
 
   if (std::string(pattern) == "all") {
...
@@ -117,11 +117,11 @@ void Benchmark::Run(const char* pattern) {
 }
 
 void Benchmark::Register() {
   if (!all_benchmarks) all_benchmarks = new std::vector<Benchmark*>;
   all_benchmarks->push_back(this);
 }
 
 void Benchmark::Run(int arg1, int arg2, int *run_count, double *run_seconds) {
   static const int64_t kMinIters = 10;
   static const int64_t kMaxIters = 1000000000;
   static const double kMinTime = 0.5;
...
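The three constants at the end drive adaptive iteration counting: start small and grow the iteration count until one timed run lasts at least kMinTime, capped at kMaxIters. The body of that loop is outside this hunk; a typical shape for such logic looks like the sketch below, where TimeOneRun is an assumed helper, not part of MACE:

    #include <algorithm>
    #include <cstdint>

    double TimeOneRun(int64_t iters);  // assumed: runs the benchmark fn iters times, returns seconds

    int64_t PickIterationCount() {
      int64_t iters = 10;  // kMinIters
      while (true) {
        double seconds = TimeOneRun(iters);
        if (seconds >= 0.5 || iters >= 1000000000) break;  // kMinTime / kMaxIters
        iters = std::min<int64_t>(iters * 10, 1000000000);
      }
      return iters;
    }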
mace/core/testing/test_benchmark.h
...
@@ -13,7 +13,7 @@
 #define MACE_BENCHMARK_CONCAT(a, b, c) a##b##c
 #define BENCHMARK(n)                                        \
   static ::mace::testing::Benchmark *MACE_BENCHMARK_CONCAT( \
       __benchmark_, n, __LINE__) = (new ::mace::testing::Benchmark(#n, (n)))
 
 namespace mace {
...
@@ -21,14 +21,14 @@ namespace testing {
 class Benchmark {
  public:
   Benchmark(const char *name, void (*fn)(int));
   Benchmark(const char *name, void (*fn)(int, int));
   Benchmark(const char *name, void (*fn)(int, int, int));
 
   Benchmark *Arg(int x);
   Benchmark *ArgPair(int x, int y);
   static void Run();
   static void Run(const char *pattern);
 
  private:
   string name_;
...
@@ -39,7 +39,7 @@ class Benchmark {
   void (*fn2_)(int, int, int) = nullptr;
 
   void Register();
   void Run(int arg1, int arg2, int *run_count, double *run_seconds);
 };
 
 void RunBenchmarks();
...
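The BENCHMARK macro performs static registration: it defines a file-scope Benchmark pointer whose initializer runs before main(), and the __LINE__ concatenation keeps the generated variable name unique. A hedged usage sketch (the benchmark body is illustrative):

    #include <cstring>
    #include "mace/core/testing/test_benchmark.h"

    static void BM_Memcpy(int iters) {
      // iters is supplied by the framework; see the Run() loop in test_benchmark.cc.
      static char src[4096], dst[4096];
      while (iters-- > 0) {
        std::memcpy(dst, src, sizeof(src));
      }
    }
    BENCHMARK(BM_Memcpy);  // registers via the Benchmark constructor before main()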
mace/core/testing/test_benchmark_main.cc
...
@@ -6,7 +6,7 @@
 #include "mace/core/testing/test_benchmark.h"
 
 int main(int argc, char **argv) {
   std::cout << "Running main() from test_main.cc\n";
 
   // TODO Use gflags
...
mace/core/types.cc
...
@@ -23,4 +23,4 @@ bool DataTypeCanUseMemcpy(DataType dt) {
   }
 }
 
 }  // namespace mace
\ No newline at end of file
mace/core/workspace.cc
...
@@ -10,14 +10,14 @@ namespace mace {
 vector<string> Workspace::Tensors() const {
   vector<string> names;
   for (auto &entry : tensor_map_) {
     names.push_back(entry.first);
   }
   return names;
 }
 
 Tensor *Workspace::CreateTensor(const string &name, Allocator *alloc,
                                 DataType type) {
   if (HasTensor(name)) {
     VLOG(1) << "Tensor " << name << " already exists. Skipping.";
...
@@ -28,7 +28,7 @@ Tensor* Workspace::CreateTensor(const string& name,
   return GetTensor(name);
 }
 
 bool Workspace::RemoveTensor(const string &name) {
   auto it = tensor_map_.find(name);
   if (it != tensor_map_.end()) {
     VLOG(1) << "Removing blob " << name << " from this workspace.";
...
@@ -38,7 +38,7 @@ bool Workspace::RemoveTensor(const string& name) {
   return false;
 }
 
 const Tensor *Workspace::GetTensor(const string &name) const {
   if (tensor_map_.count(name)) {
     return tensor_map_.at(name).get();
   } else {
...
@@ -47,18 +47,17 @@ const Tensor* Workspace::GetTensor(const string& name) const {
   return nullptr;
 }
 
 Tensor *Workspace::GetTensor(const string &name) {
   return const_cast<Tensor *>(
       static_cast<const Workspace *>(this)->GetTensor(name));
 }
 
 void Workspace::LoadModelTensor(const NetDef &net_def, DeviceType type) {
   Serializer serializer;
   for (auto &tensor_proto : net_def.tensors()) {
     VLOG(1) << "Load tensor: " << tensor_proto.name()
             << " has shape: " << internal::MakeString(vector<index_t>(
                    tensor_proto.dims().begin(), tensor_proto.dims().end()));
     tensor_map_[tensor_proto.name()] =
         serializer.Deserialize(tensor_proto, type);
   }
...
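The non-const GetTensor above uses the standard "call the const overload, then const_cast the result" idiom, so the lookup logic exists in exactly one place. The general shape of the pattern, independent of Workspace:

    #include <string>

    struct Widget;

    class Container {
     public:
      const Widget *Get(const std::string &name) const;  // does the real work
      Widget *Get(const std::string &name) {
        // Safe: *this really is non-const here, so shedding the const that
        // was added by the cast is well-defined.
        return const_cast<Widget *>(
            static_cast<const Container *>(this)->Get(name));
      }
    };

The static_cast is what forces overload resolution to pick the const version instead of recursing into the non-const one.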
mace/core/workspace.h
...
@@ -19,19 +19,19 @@ class Workspace {
   vector<string> Tensors() const;
 
   Tensor *CreateTensor(const string &name, Allocator *alloc, DataType type);
 
   bool RemoveTensor(const string &name);
 
   inline bool HasTensor(const string &name) const {
     return tensor_map_.count(name);
   }
 
   const Tensor *GetTensor(const string &name) const;
 
   Tensor *GetTensor(const string &name);
 
   void LoadModelTensor(const NetDef &net_def, DeviceType type);
 
  private:
   TensorMap tensor_map_;
...
mace/examples/benchmark_example.cc
...
@@ -10,8 +10,8 @@ static void foo(int iters) {
   mace::testing::ItemsProcessed(tot);
   mace::testing::BytesProcessed(tot * (sizeof(float)));
   float *inp = new float[N];
   float *out = new float[N];
 
   while (iters--) {
     for (int i = 0; i < N; i++) {
...
@@ -29,8 +29,8 @@ static void bar(int iters, int n) {
   mace::testing::ItemsProcessed(tot);
   mace::testing::BytesProcessed(tot * (sizeof(float)));
   float *inp = new float[n];
   float *out = new float[n];
 
   while (iters--) {
     for (int i = 0; i < n; i++) {
...
mace/examples/mace_run.cc
...
@@ -12,8 +12,8 @@
  * --output_file=mace.out \
  * --device=NEON
  */
-#include <fstream>
 #include <sys/time.h>
+#include <fstream>
 
 #include "mace/core/net.h"
 #include "mace/utils/command_line_flags.h"
...
@@ -83,12 +83,11 @@ int main(int argc, char **argv) {
   Workspace ws;
   ws.LoadModelTensor(net_def, DeviceType::CPU);
 
-  Tensor *input_tensor = ws.CreateTensor(input_node + ":0",
-                                         cpu_allocator(), DT_FLOAT);
+  Tensor *input_tensor =
+      ws.CreateTensor(input_node + ":0", cpu_allocator(), DT_FLOAT);
   input_tensor->Resize(shape);
   float *input_data = input_tensor->mutable_data<float>();
 
   // load input
   ifstream in_file(input_file, ios::in | ios::binary);
   in_file.read(reinterpret_cast<char *>(input_data),
...
@@ -112,14 +111,17 @@ int main(int argc, char **argv) {
     net->Run();
   }
   gettimeofday(&tv2, NULL);
-  cout << "avg duration: " << ((tv2.tv_sec - tv1.tv_sec) * 1000
-      + (tv2.tv_usec - tv1.tv_usec) / 1000) / round << endl;
+  cout << "avg duration: "
+       << ((tv2.tv_sec - tv1.tv_sec) * 1000 +
+           (tv2.tv_usec - tv1.tv_usec) / 1000) /
+              round
+       << endl;
 
   // save output
   const Tensor *output = ws.GetTensor(output_node + ":0");
   ofstream out_file(output_file, ios::binary);
-  out_file.write((const char *) (output->data<float>()),
+  out_file.write((const char *)(output->data<float>()),
                  output->size() * sizeof(float));
   out_file.flush();
   out_file.close();
...
mace/kernels/BUILD
...
@@ -20,7 +20,7 @@ cc_library(
     linkopts = if_android(["-lm"]),
     deps = [
         "//mace/core",
-        "//mace/utils:utils",
+        "//mace/utils",
     ],
 )
...
mace/kernels/addn.h
...
@@ -12,7 +12,7 @@ namespace kernels {
 template <DeviceType D, typename T>
 struct AddNFunctor {
-  void operator()(const vector<const T *>& inputs, T *output, index_t size) {
+  void operator()(const vector<const T *> &inputs, T *output, index_t size) {
     memset(output, 0, size * sizeof(T));
     int n = inputs.size();
     for (int i = 0; i < n; ++i) {
...
@@ -25,7 +25,7 @@ struct AddNFunctor {
 template <>
 void AddNFunctor<DeviceType::NEON, float>::operator()(
-    const vector<const float *>& inputs, float *output, index_t size);
+    const vector<const float *> &inputs, float *output, index_t size);
 
 }  // namespace kernels
 }  // namespace mace
...
mace/kernels/batch_norm.h
...
@@ -13,17 +13,16 @@ namespace kernels {
 template <DeviceType D, typename T>
 struct BatchNormFunctor {
   void operator()(const T *input,
                   const T *scale,
                   const T *offset,
                   const T *mean,
                   const T *var,
                   const float variance_epsilon,
                   const index_t n,
                   const index_t channel,
                   const index_t sample_size,
                   T *output) {
     // Batch normalization in the paper https://arxiv.org/abs/1502.03167 .
     // The calculation formula for inference is
     // Y = \frac{ \scale } { \sqrt{var+\variance_epsilon} } * X +
...
@@ -40,8 +39,8 @@ struct BatchNormFunctor {
       index_t pos = c * sample_size;
       for (index_t i = 0; i < n; ++i) {
         const T *input_sample_ptr = input + pos;
         T *output_sample_ptr = output + pos;
         for (index_t j = 0; j < sample_size; ++j) {
           output_sample_ptr[j] = new_scale * input_sample_ptr[j] + new_offset;
         }
...
@@ -53,16 +52,16 @@ struct BatchNormFunctor {
 template <>
 void BatchNormFunctor<DeviceType::NEON, float>::operator()(
     const float *input,
     const float *scale,
     const float *offset,
     const float *mean,
     const float *var,
     const float variance_epsilon,
     const index_t n,
     const index_t channel,
     const index_t sample_size,
     float *output);
 
 }  // namepsace kernels
 }  // namespace mace
...
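The comment in both versions refers to the inference-time simplification from the batch-norm paper: since scale, offset, mean, and var are per-channel constants at inference, the functor folds them into two scalars before the inner loop, so each element costs one multiply-add. In the code's notation:

    new_scale  = scale / sqrt(var + variance_epsilon)
    new_offset = offset - new_scale * mean
    Y          = new_scale * X + new_offset

which is algebraically identical to Y = scale * (X - mean) / sqrt(var + variance_epsilon) + offset.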
mace/kernels/channel_shuffle.h
...
@@ -10,11 +10,10 @@
 namespace mace {
 namespace kernels {
 
 template <DeviceType D, typename T>
 class ChannelShuffleFunctor {
  public:
-  ChannelShuffleFunctor(const int group)
-      : group_(group) {}
+  ChannelShuffleFunctor(const int group) : group_(group) {}
 
   void operator()(const T *input, const index_t *input_shape, T *output) {
     index_t batch = input_shape[0];
...
@@ -28,8 +27,8 @@ class ChannelShuffleFunctor {
     for (int b = 0; b < batch; ++b) {
       for (int c = 0; c < channels_of_group; ++c) {
         for (int g = 0; g < group_; ++g) {
-          index_t input_offset = (b * channels + g * channels_of_group + c) *
-              image_size;
+          index_t input_offset =
+              (b * channels + g * channels_of_group + c) * image_size;
           index_t output_offset = (b * channels + c * group_ + g) * image_size;
           memcpy(output + output_offset, input + input_offset,
                  image_size * sizeof(T));
...
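The two offsets implement a ShuffleNet-style channel shuffle: viewing the channel axis as a (group, channels_of_group) matrix and writing out its transpose. For example, with channels = 6 and group_ = 2 (groups 012 | 345), input channel order 0 1 2 3 4 5 becomes 0 3 1 4 2 5. A standalone check of that mapping (illustrative only):

    #include <cstdio>

    int main() {
      const int group = 2, channels = 6, channels_of_group = channels / group;
      for (int c = 0; c < channels_of_group; ++c)
        for (int g = 0; g < group; ++g)
          // output channel (c * group + g) copies input channel
          // (g * channels_of_group + c), exactly as in the memcpy above
          std::printf("out %d <- in %d\n", c * group + g,
                      g * channels_of_group + c);
      return 0;  // prints 0<-0, 1<-3, 2<-1, 3<-4, 4<-2, 5<-5
    }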
mace/kernels/concat.h
...
@@ -5,13 +5,13 @@
 #ifndef MACE_KERNELS_CONCAT_H_
 #define MACE_KERNELS_CONCAT_H_
 
-#include "mace/proto/mace.pb.h"
 #include "mace/core/common.h"
 #include "mace/core/types.h"
+#include "mace/proto/mace.pb.h"
 
 namespace mace {
 namespace kernels {
 
 template <DeviceType D, typename T>
 struct ConcatFunctor {
   void operator()(std::vector<const T *> &input_list,
                   const index_t inner_dim,
...
@@ -35,6 +35,6 @@ struct ConcatFunctor {
 };
 
 }  // namepsace kernels
 }  // namespace mace
 
 #endif  // MACE_KERNELS_CONCAT_H_
mace/kernels/conv_2d.h
...
@@ -11,15 +11,13 @@
 namespace mace {
 namespace kernels {
 
 template <DeviceType D, typename T>
 struct Conv2dFunctor {
   Conv2dFunctor() {}
   Conv2dFunctor(const int *strides,
                 const std::vector<int> &paddings,
-                const int *dilations) :
-      strides_(strides),
-      paddings_(paddings),
-      dilations_(dilations) {}
+                const int *dilations)
+      : strides_(strides), paddings_(paddings), dilations_(dilations) {}
 
   void operator()(const T *input,  // NCHW
                   const index_t *input_shape,
...
@@ -66,7 +64,7 @@ struct Conv2dFunctor {
         for (int h = 0; h < height; ++h) {
           for (int w = 0; w < width; ++w) {
             index_t offset = n * channels * height * width +
                              c * height * width + h * width + w;
             output[offset] = bias_channel;
             T sum = 0;
             const T *filter_ptr = filter + c * kernel_size;
...
@@ -78,7 +76,7 @@ struct Conv2dFunctor {
                 if (inh < 0 || inh >= input_height || inw < 0 ||
                     inw >= input_width) {
                   MACE_CHECK(inh >= padded_h_start && inh < padded_h_stop &&
                                  inw >= padded_w_start && inw < padded_w_stop,
                              "Out of range read from input: ", inh, ", ", inw);
                   // else padding with 0:
...
@@ -86,8 +84,8 @@ struct Conv2dFunctor {
                 } else {
                   index_t input_offset =
                       n * input_channels * input_height * input_width +
                       inc * input_height * input_width + inh * input_width +
                       inw;
                   sum += input[input_offset] * *filter_ptr;
                 }
                 ++filter_ptr;
...
@@ -101,12 +99,12 @@ struct Conv2dFunctor {
       }
     }
   }
 
   const int *strides_;         // [stride_h, stride_w]
   std::vector<int> paddings_;  // [padding_h, padding_w]
   const int *dilations_;       // [dilation_h, dilation_w]
 };
 
 template <>
 void Conv2dFunctor<DeviceType::NEON, float>::operator()(
     const float *input,
     const index_t *input_shape,
...
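Both offset hunks above are the same NCHW addressing arithmetic, just rewrapped. The flat index of element (n, c, h, w) in a contiguous NCHW buffer is

    offset = ((n * channels + c) * height + h) * width + w
           =  n * channels * height * width + c * height * width + h * width + w

and the input_offset computation is the identical formula applied to the input tensor's dimensions.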
mace/kernels/conv_pool_2d_util.cc
...
@@ -72,16 +72,15 @@ void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW
 }
 
 void CalPaddingSize(const index_t *input_shape,   // NCHW
                     const index_t *filter_shape,  // OIHW
                     const int *dilations,
                     const int *strides,
                     Padding padding,
                     int *padding_size) {
   MACE_CHECK(dilations[0] > 0 && dilations[1] > 0,
              "Invalid dilations, must >= 1");
   MACE_CHECK((dilations[0] == 1 || strides[0] == 1) &&
                  (dilations[1] == 1 || strides[1] == 1),
              "If dilations > 1, strides should be 1");
   MACE_CHECK_NOTNULL(padding_size);
...
mace/kernels/conv_pool_2d_util.h
...
@@ -26,11 +26,11 @@ void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW
                               int *padding_size);
 
 void CalPaddingSize(const index_t *input_shape,   // NCHW
                     const index_t *filter_shape,  // OIHW
                     const int *dilations,
                     const int *strides,
                     Padding padding,
                     int *padding_size);
 
 void ConstructInputWithPadding(const float *input,
                                const index_t *input_shape,
...
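For reference, the relationship these helpers enforce is standard convolution arithmetic (the exact rounding policy per Padding mode is MACE's; the formula itself is generic): with input size i, kernel size k, stride s, dilation d, and total padding p along one spatial axis,

    o = (i + p - d * (k - 1) - 1) / s + 1    (integer division)

The MACE_CHECK in CalPaddingSize additionally restricts this implementation to d == 1 or s == 1, never both greater than one.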
mace/kernels/depthwise_conv2d.h
...
@@ -5,29 +5,27 @@
 #ifndef MACE_KERNELS_DEPTHWISE_CONV_H_
 #define MACE_KERNELS_DEPTHWISE_CONV_H_
 
-#include "mace/proto/mace.pb.h"
 #include "mace/core/common.h"
 #include "mace/kernels/conv_pool_2d_util.h"
+#include "mace/proto/mace.pb.h"
 
 namespace mace {
 namespace kernels {
 
 template <DeviceType D, typename T>
 struct DepthwiseConv2dFunctor {
   DepthwiseConv2dFunctor() {}
   DepthwiseConv2dFunctor(const int *strides,
                          const std::vector<int> &paddings,
-                         const int *dilations) :
-      strides_(strides),
-      paddings_(paddings),
-      dilations_(dilations) {}
+                         const int *dilations)
+      : strides_(strides), paddings_(paddings), dilations_(dilations) {}
 
   void operator()(const T *input,  // NCHW
                   const index_t *input_shape,
                   const T *filter,  // c_out, c_in, kernel_h, kernel_w
                   const index_t *filter_shape,
                   const T *bias,  // c_out
                   T *output,      // NCHW
                   const index_t *output_shape) {
     MACE_CHECK_NOTNULL(output);
...
@@ -68,7 +66,7 @@ struct DepthwiseConv2dFunctor {
         for (int h = 0; h < height; ++h) {
           for (int w = 0; w < width; ++w) {
             index_t offset = n * channels * height * width +
                              c * height * width + h * width + w;
             output[offset] = bias_channel;
             T sum = 0;
             const T *filter_ptr = filter + c * kernel_size;
...
@@ -79,16 +77,15 @@ struct DepthwiseConv2dFunctor {
                 if (inh < 0 || inh >= input_height || inw < 0 ||
                     inw >= input_width) {
                   MACE_CHECK(inh >= padded_h_start && inh < padded_h_stop &&
                                  inw >= padded_w_start && inw < padded_w_stop,
                              "Out of range read from input: ", inh, ", ", inw);
                   // else padding with 0:
                   // sum += 0;
                 } else {
                   index_t input_offset =
                       n * input_channels * input_height * input_width +
                       (c / multiplier) * input_height * input_width +
                       inh * input_width + inw;
                   sum += input[input_offset] * *filter_ptr;
                 }
                 ++filter_ptr;
...
@@ -101,20 +98,21 @@ struct DepthwiseConv2dFunctor {
       }
     }
   }
 
   const int *strides_;         // [stride_h, stride_w]
   std::vector<int> paddings_;  // [padding_h, padding_w]
   const int *dilations_;       // [dilation_h, dilation_w]
 };
 
 template <>
-void DepthwiseConv2dFunctor<DeviceType::NEON, float>::operator()(const float *input,
-                                                                 const index_t *input_shape,
-                                                                 const float *filter,
-                                                                 const index_t *filter_shape,
-                                                                 const float *bias,
-                                                                 float *output,
-                                                                 const index_t *output_shape);
+void DepthwiseConv2dFunctor<DeviceType::NEON, float>::operator()(
+    const float *input,
+    const index_t *input_shape,
+    const float *filter,
+    const index_t *filter_shape,
+    const float *bias,
+    float *output,
+    const index_t *output_shape);
 
 }  // namespace kernels
 }  // namespace mace
 
 #endif  // MACE_KERNELS_DEPTHWISE_CONV_H_
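The one substantive difference from the regular Conv2dFunctor is the input-channel indexing: output channel c reads only input channel c / multiplier, where multiplier is the depth multiplier (output_channels / input_channels). A standalone illustration of that mapping:

    #include <cstdio>

    // Depthwise channel mapping (illustrative): each input channel produces
    // `multiplier` consecutive output channels, so inverting the relation
    // gives input_channel = output_channel / multiplier.
    int main() {
      const int input_channels = 3, multiplier = 2;
      for (int oc = 0; oc < input_channels * multiplier; ++oc)
        std::printf("output channel %d <- input channel %d\n",
                    oc, oc / multiplier);
      return 0;
    }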
mace/kernels/global_avg_pooling.h
...
@@ -35,9 +35,7 @@ struct GlobalAvgPoolingFunctor {
 template <>
 void GlobalAvgPoolingFunctor<DeviceType::NEON, float>::operator()(
     const float *input, const index_t *input_shape, float *output);
 
 }  // namespace kernels
 }  // namespace mace
...
mace/kernels/neon/avg_pooling_neon_2x2.cc
...
@@ -45,7 +45,7 @@ void PoolingAvgNeonK2x2S2x2(const float *input,
         int w = 0;
         int num_vectors = 0;
         if (!((h == 0 && padding_top > 0) ||
               (h == out_height - 1 && padding_bottom > 0))) {
           r0 = input + input_offset + (h * 2 - padding_top) * in_width;
           r1 = r0 + in_width;
           if (padding_left > 0) {
...
mace/kernels/neon/avg_pooling_neon_3x3.cc
...
@@ -33,7 +33,7 @@ void PoolingAvgNeonK3x3S2x2(const float *input,
   int out_image_size = out_height * out_width;
   index_t input_offset = 0;
   index_t output_offset = 0;
   float avg_factors[4] = {1.0 / 9.0, 1.0 / 9.0, 1.0 / 9.0, 1.0 / 9.0};
 
 #pragma omp parallel for collapse(2)
   for (int b = 0; b < batch; ++b) {
...
@@ -45,7 +45,7 @@ void PoolingAvgNeonK3x3S2x2(const float *input,
         int num_vectors = 0;
         const float *r0, *r1, *r2;
         if (!((h == 0 && padding_top > 0) ||
               (h == out_height - 1 && padding_bottom > 0))) {
           r0 = input + input_offset + (h * 2 - padding_top) * in_width;
           r1 = r0 + in_width;
           r2 = r1 + in_width;
...
@@ -147,7 +147,7 @@ void PoolingAvgNeonK3x3S2x2Padded(const float *input,
   int out_image_size = out_height * out_width;
   index_t input_offset = 0;
   index_t output_offset = 0;
   float avg_factors[4] = {1.0 / 9.0, 1.0 / 9.0, 1.0 / 9.0, 1.0 / 9.0};
 
 #pragma omp parallel for collapse(2)
   for (int b = 0; b < batch; ++b) {
...
@@ -200,8 +200,9 @@ void PoolingAvgNeonK3x3S2x2Padded(const float *input,
         }
         for (; remain > 0; remain--) {
           *outptr = (r0[0] + r0[1] + r0[2] + r1[0] + r1[1] + r1[2] + r2[0] +
                      r2[1] + r2[2]) /
                     9.0;
 
           r0 += 2;
           r1 += 2;
...
mace/kernels/neon/batch_norm_neon.cc
...
@@ -10,16 +10,16 @@ namespace kernels {
 template <>
 void BatchNormFunctor<DeviceType::NEON, float>::operator()(
     const float *input,
     const float *scale,
     const float *offset,
     const float *mean,
     const float *var,
     const float variance_epsilon,
     const index_t n,
     const index_t channel,
     const index_t sample_size,
     float *output) {
   // Batch normalization in the paper https://arxiv.org/abs/1502.03167 .
   // The calculation formula for inference is
   // Y = \frac{ \scale } { \sqrt{var+\variance_epsilon} } * X +
...
@@ -40,8 +40,8 @@ void BatchNormFunctor<DeviceType::NEON, float>::operator()(
     float32x4_t new_scale_f = vdupq_n_f32(new_scale);
     float32x4_t new_offset_f = vdupq_n_f32(new_offset);
     for (index_t i = 0; i < n; ++i) {
       const float *input_sample_ptr = input + pos;
       float *output_sample_ptr = output + pos;
 
       for (index_t j = 0; j < count; ++j) {
         float32x4_t input_f = vld1q_f32(input_sample_ptr);
...
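The vdupq_n_f32 calls broadcast the folded per-channel scalars into all four lanes of a NEON register, so the inner loop is one load, one fused multiply-add, one store per four floats. A minimal sketch of that structure, assuming an ARM/AArch64 target (this helper is illustrative, not MACE code):

    #include <arm_neon.h>  // assumes an ARM NEON target

    // Computes y[i] = new_scale * x[i] + new_offset over `blocks` groups of
    // four floats (remainder handling omitted for brevity).
    void ScaleOffsetNeon(const float *x, float new_scale, float new_offset,
                         long blocks, float *y) {
      float32x4_t vscale = vdupq_n_f32(new_scale);    // broadcast to 4 lanes
      float32x4_t voffset = vdupq_n_f32(new_offset);
      for (long i = 0; i < blocks; ++i) {
        float32x4_t vx = vld1q_f32(x);
        vst1q_f32(y, vmlaq_f32(voffset, vx, vscale)); // voffset + vx * vscale
        x += 4;
        y += 4;
      }
    }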
mace/kernels/neon/conv_2d_neon.cc
...
@@ -41,20 +41,17 @@ extern void Conv2dNeonK5x5S1(const float *input,
                              const index_t *output_shape);
 
 template <>
 void Conv2dFunctor<DeviceType::NEON, float>::operator()(
     const float *input,
     const index_t *input_shape,
     const float *filter,
     const index_t *filter_shape,
     const float *bias,
     float *output,
     const index_t *output_shape) {
   typedef void (*Conv2dNeonFunction)(
       const float *input, const index_t *input_shape, const float *filter,
       const index_t *filter_shape, const float *bias, float *output,
       const index_t *output_shape);
   // Selection matrix: kernel_size x stride_size
   static const Conv2dNeonFunction selector[5][2] = {
...
@@ -81,12 +78,14 @@ void Conv2dFunctor<DeviceType::NEON, float>::operator()(const float *input,
   // Keep this alive during kernel execution
   Tensor padded_input;
   if (paddings_[0] > 0 || paddings_[1] > 0) {
     ConstructInputWithPadding(input, input_shape, paddings_.data(),
                               &padded_input);
     input = padded_input.data<float>();
     input_shape = padded_input.shape().data();
   }
   auto conv2d_neon_func = selector[kernel_h - 1][strides_[0] - 1];
   conv2d_neon_func(input, input_shape, filter, nullptr, bias, output,
                    output_shape);
 }
 
 }  // namespace kernels
...
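The selector table is plain function-pointer dispatch: kernel height and stride index into a 5x2 matrix of specialized kernels, with unfilled entries presumably falling back to a generic path outside this hunk. The shape of the pattern, with hypothetical kernel names:

    #include <cstddef>

    typedef void (*KernelFn)(const float *in, float *out, size_t n);

    // Illustrative kernels; both just copy here.
    void Generic(const float *in, float *out, size_t n) {
      for (size_t i = 0; i < n; ++i) out[i] = in[i];
    }
    void K1S1(const float *in, float *out, size_t n) {
      for (size_t i = 0; i < n; ++i) out[i] = in[i];  // "specialized" variant
    }

    // Rows = kernel size 1..5, columns = stride 1..2; nullptr means no
    // specialization exists for that combination.
    static const KernelFn selector[5][2] = {
        {K1S1, nullptr}, {nullptr, nullptr}, {nullptr, nullptr},
        {nullptr, nullptr}, {nullptr, nullptr},
    };

    void Dispatch(int kernel, int stride, const float *in, float *out, size_t n) {
      KernelFn fn = selector[kernel - 1][stride - 1];
      (fn ? fn : Generic)(in, out, n);
    }

The table is static const, so dispatch is a single indexed load; adding a new specialization touches only one entry.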
mace/kernels/neon/conv_2d_neon_1x1.cc
...
@@ -10,9 +10,8 @@ namespace mace {
 namespace kernels {
 
 static constexpr index_t kInputChannelBlockSize = 2;
 static constexpr index_t kOutputChannelBlockSize = 4;
 
-static __attribute__((__aligned__(64))) int32_t mask_array[8] = {
-    0, 0, 0, 0, -1, -1, -1, -1};
+static __attribute__((__aligned__(64)))
+int32_t mask_array[8] = {0, 0, 0, 0, -1, -1, -1, -1};
 
 static inline void NeonConv2x4Kernel(index_t input_channels,
                                      index_t pixel_size,
...
@@ -77,15 +76,15 @@ static inline void NeonConv2x4Kernel(index_t input_channels,
     output3 = output3 + pixel_size - 4;
     float32x4_t voutput3 = vld1q_f32(output3);
 
     const float32x4_t vinput0 = vreinterpretq_f32_s32(vandq_s32(
         vmask, vreinterpretq_s32_f32(vld1q_f32(&input0[pixel_size - 4]))));
     voutput0 = vfmaq_lane_f32(voutput0, vinput0, vfilter0x, 0);
     voutput1 = vfmaq_lane_f32(voutput1, vinput0, vfilter1x, 0);
     voutput2 = vfmaq_lane_f32(voutput2, vinput0, vfilter2x, 0);
     voutput3 = vfmaq_lane_f32(voutput3, vinput0, vfilter3x, 0);
 
     const float32x4_t vinput1 = vreinterpretq_f32_s32(vandq_s32(
         vmask, vreinterpretq_s32_f32(vld1q_f32(&input1[pixel_size - 4]))));
     voutput0 = vfmaq_lane_f32(voutput0, vinput1, vfilter0x, 1);
     voutput1 = vfmaq_lane_f32(voutput1, vinput1, vfilter1x, 1);
     voutput2 = vfmaq_lane_f32(voutput2, vinput1, vfilter2x, 1);
...
@@ -98,13 +97,14 @@ static inline void NeonConv2x4Kernel(index_t input_channels,
   }
 }
 
 static inline void NeonConv2x4SubBlockKernel(
     index_t input_channels_subblock_size,
     index_t output_channels_subblock_size,
     index_t input_channels,
     index_t pixel_size,
     const float *input,
     const float *filter,
     float *output) {
   const float *input0 = input;
   const float *input1 = input + pixel_size;
...
@@ -204,16 +204,16 @@ static inline void NeonConv2x4SubBlockKernel(index_t input_channels_subblock_siz
     }
   }
 
   const float32x4_t vinput0 = vreinterpretq_f32_s32(vandq_s32(
       vmask, vreinterpretq_s32_f32(vld1q_f32(&input0[pixel_size - 4]))));
   voutput0 = vfmaq_lane_f32(voutput0, vinput0, vfilter0x, 0);
   voutput1 = vfmaq_lane_f32(voutput1, vinput0, vfilter1x, 0);
   voutput2 = vfmaq_lane_f32(voutput2, vinput0, vfilter2x, 0);
   voutput3 = vfmaq_lane_f32(voutput3, vinput0, vfilter3x, 0);
 
   if (input_channels_subblock_size > 1) {
     const float32x4_t vinput1 = vreinterpretq_f32_s32(vandq_s32(
         vmask, vreinterpretq_s32_f32(vld1q_f32(&input1[pixel_size - 4]))));
     voutput0 = vfmaq_lane_f32(voutput0, vinput1, vfilter0x, 1);
     voutput1 = vfmaq_lane_f32(voutput1, vinput1, vfilter1x, 1);
     voutput2 = vfmaq_lane_f32(voutput2, vinput1, vfilter2x, 1);
...
@@ -237,8 +237,8 @@ void Conv2dNeonK1x1S1(const float *input, // NCHW
                       const index_t *input_shape,
                       const float *filter,  // c_out, c_in, filter_h, filter_w
                       const index_t *filter_shape,
                       const float *bias,  // c_out
                       float *output,      // NCHW
                       const index_t *output_shape) {
   const index_t batch = output_shape[0];
   const index_t channels = output_shape[1];
...
@@ -251,7 +251,7 @@ void Conv2dNeonK1x1S1(const float *input, // NCHW
   const index_t input_width = input_shape[3];
 
   MACE_CHECK(input_batch == batch && input_height == height &&
              input_width == width);
 
   const index_t total_pixels = height * width;
   const index_t round_up_channels = RoundUp(channels, kOutputChannelBlockSize);
...
@@ -259,22 +259,27 @@ void Conv2dNeonK1x1S1(const float *input, // NCHW
 #pragma omp parallel for collapse(2)
   for (index_t n = 0; n < batch; ++n) {
     for (int i = 0; i < channels; ++i) {
       float *output_ptr_base =
           output + n * channels * total_pixels + i * total_pixels;
       std::fill(output_ptr_base, output_ptr_base + total_pixels,
                 bias ? bias[i] : 0);
     }
   }
 
   // benchmark omp collapsed(2)
 #pragma omp parallel for collapse(2)
   for (index_t n = 0; n < batch; ++n) {
     for (index_t c = 0; c < round_up_channels; c += kOutputChannelBlockSize) {
       const float *input_ptr = input + n * input_channels * total_pixels;
       const float *filter_ptr = filter + c * input_channels;
       float *output_ptr =
           output + n * channels * total_pixels + c * total_pixels;
       const index_t output_channel_block_size =
           std::min(channels - c, kOutputChannelBlockSize);
       index_t remain_input_channels = input_channels;
       if (c + kOutputChannelBlockSize <= channels) {
         while (remain_input_channels >= kInputChannelBlockSize) {
           NeonConv2x4Kernel(input_channels, total_pixels, input_ptr,
                             filter_ptr, output_ptr);
           input_ptr += kInputChannelBlockSize * total_pixels;
           filter_ptr += kInputChannelBlockSize;
...
@@ -282,25 +287,27 @@ void Conv2dNeonK1x1S1(const float *input, // NCHW
         }
       }
       while (remain_input_channels != 0) {
         const index_t input_channel_block_size =
             std::min(remain_input_channels, kInputChannelBlockSize);
         NeonConv2x4SubBlockKernel(input_channel_block_size,
                                   output_channel_block_size, input_channels,
                                   total_pixels, input_ptr, filter_ptr,
                                   output_ptr);
         input_ptr += kInputChannelBlockSize * total_pixels;
         filter_ptr += kInputChannelBlockSize;
         remain_input_channels -= input_channel_block_size;
       }
     }
   }
 };
 
 void Conv2dNeonPixelK1x1S1(
     const float *input,  // NCHW
     const index_t *input_shape,
     const float *filter,  // c_out, c_in, kernel_h, kernel_w
     const index_t *filter_shape,
     const float *bias,  // c_out
     float *output,      // NCHW
     const index_t *output_shape) {
   const index_t batch = output_shape[0];
   const index_t channels = output_shape[1];
   const index_t height = output_shape[2];
...
@@ -312,7 +319,7 @@ void Conv2dNeonPixelK1x1S1(const float *input, // NCHW
   const index_t input_width = input_shape[3];
 
   MACE_CHECK(input_batch == batch && input_height == height &&
              input_width == width);
 
   const index_t total_pixels = height * width;
   // Process 4 * 2 = 8 pixels for each innermost loop
...
@@ -320,7 +327,7 @@ void Conv2dNeonPixelK1x1S1(const float *input, // NCHW
   const index_t total_loops = total_pixels >> 3;
   const index_t loop_remaining = total_pixels & 7;
 
   // benchmark omp collapsed(2)
 #pragma omp parallel for collapse(2)
   for (index_t n = 0; n < batch; ++n) {
     for (index_t c = 0; c < channels; ++c) {
...
@@ -341,8 +348,8 @@ void Conv2dNeonPixelK1x1S1(const float *input, // NCHW
         float *output_ptr = channel_output_start;
         // The begining of each input feature map channel
         MACE_ASSERT(input_ptr ==
                     input + n * input_channels * input_height * input_width +
                         inc * input_height * input_width);
 
         const float *input_ptr1 = input_ptr + total_pixels;
         const float *input_ptr2 = input_ptr1 + total_pixels;
...
@@ -426,8 +433,8 @@ void Conv2dNeonPixelK1x1S1(const float *input, // NCHW
       for (; inc < input_channels; ++inc) {
         float *output_ptr = channel_output_start;
         MACE_ASSERT(input_ptr ==
                     input + n * input_channels * input_height * input_width +
                         inc * input_height * input_width);
         MACE_ASSERT(filter_ptr == filter + c * input_channels + inc);
 
         const float k0 = filter_ptr[0];
...
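The 64-byte-aligned mask_array supports an overlapped tail load: when pixel_size is not a multiple of 4, the kernels re-read the last four floats of a row and AND them with a mask whose leading lanes are zero, so elements already handled by the main vector loop contribute nothing when accumulated a second time. A hedged scalar rendering of the same idea (MaskedTail is illustrative, not MACE code):

    // Assumes pixel_size % 4 != 0, so the final 4-wide window overlaps
    // work the main loop already did.
    void MaskedTail(const float *row, long pixel_size, float weight,
                    float *out) {
      long remainder = pixel_size % 4;  // lanes not yet processed
      long tail_start = pixel_size - 4; // window re-reads 4 - remainder lanes
      for (long i = 0; i < 4; ++i) {
        // Lanes before (pixel_size - remainder) were already accumulated;
        // masking them to zero makes re-accumulation a no-op.
        float lane = (tail_start + i < pixel_size - remainder)
                         ? 0.0f
                         : row[tail_start + i];
        out[tail_start + i] += weight * lane;
      }
    }

With eight mask entries {0,0,0,0,-1,-1,-1,-1}, loading four int32s starting at mask_array + remainder yields exactly `remainder` live trailing lanes, which matches this scalar predicate.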
mace/kernels/neon/conv_2d_neon_3x3.cc
浏览文件 @
ebff986d
...
@@ -11,44 +11,49 @@ namespace kernels {
...
@@ -11,44 +11,49 @@ namespace kernels {
static
const
int
kRegisterSize
=
4
;
static
const
int
kRegisterSize
=
4
;
static
const
int
kFilterSize
=
9
;
static
const
int
kFilterSize
=
9
;
void
Conv2dNeonK3x3S1
(
const
float
*
input
,
// NCHW
void
Conv2dNeonK3x3S1
(
const
float
*
input
,
// NCHW
const
index_t
*
input_shape
,
const
index_t
*
input_shape
,
const
float
*
filter
,
// c_out, c_in, kernel_h, kernel_w
const
float
*
filter
,
// c_out, c_in, kernel_h, kernel_w
const
index_t
*
filter_shape
,
const
index_t
*
filter_shape
,
const
float
*
bias
,
// c_out
const
float
*
bias
,
// c_out
float
*
output
,
// NCHW
float
*
output
,
// NCHW
const
index_t
*
output_shape
)
{
const
index_t
*
output_shape
)
{
int
height_count
=
(
output_shape
[
2
]
>>
1
)
<<
1
;
int
height_count
=
(
output_shape
[
2
]
>>
1
)
<<
1
;
int
output_batch
=
output_shape
[
0
];
int
output_batch
=
output_shape
[
0
];
int
output_channels
=
output_shape
[
1
];
int
output_channels
=
output_shape
[
1
];
int
output_height
=
output_shape
[
2
];
int
output_height
=
output_shape
[
2
];
int
output_width
=
output_shape
[
3
];
int
output_width
=
output_shape
[
3
];
int
input_batch
=
input_shape
[
0
];
int
input_batch
=
input_shape
[
0
];
int input_channels = input_shape[1];
int input_height = input_shape[2];
int input_width = input_shape[3];
int multiplier =
    filter_shape == nullptr ? 0 : (filter_shape[0] / input_channels);
int filter_in_channels =
    filter_shape == nullptr ? input_channels : filter_shape[1];
#pragma omp parallel for collapse(2)
for (int b = 0; b < output_batch; ++b) {
  for (int oc = 0; oc < output_channels; ++oc) {
    float *output_ptr_base =
        output + b * output_channels * output_height * output_width;
    const float *filter_ptr = filter + oc * filter_in_channels * kFilterSize;
    const float *input_ptr =
        input + b * input_channels * input_height * input_width;
    if (filter_shape != nullptr) {
      input_ptr += (oc / multiplier) * input_height * input_width;
    }
    float *output_ptr = output_ptr_base + oc * output_height * output_width;
    std::fill(output_ptr, output_ptr + output_height * output_width,
              bias ? bias[oc] : 0);
    for (int ic = 0; ic < filter_in_channels; ++ic) {
      float32x4_t n_filter_v[3] = {vld1q_f32(filter_ptr),
                                   vld1q_f32(filter_ptr + 3),
                                   vld1q_f32(filter_ptr + 6)};
      const float *row_ptr_v[kRegisterSize] = {
          input_ptr, input_ptr + input_width, input_ptr + 2 * input_width,
          input_ptr + 3 * input_width};
      float *output_ptr_v[] = {output_ptr, output_ptr + output_width};
...
@@ -69,8 +74,10 @@ void Conv2dNeonK3x3S1(const float *input, // NCHW
      float32x4_t n_row1_former = vld1q_f32(row_ptr_v[1]);
      float32x4_t n_row1_latter = vld1q_f32(row_ptr_v[1] + kRegisterSize);
      float32x4_t n_row1_ext0 = vextq_f32(n_row1_former, n_row1_latter, 1);
      float32x4_t n_row1_ext1 = vextq_f32(n_row1_former, n_row1_latter, 2);
      n_sum0 = vfmaq_laneq_f32(n_sum0, n_row1_former, n_filter_v[1], 0);
      n_sum0 = vfmaq_laneq_f32(n_sum0, n_row1_ext0, n_filter_v[1], 1);
      n_sum0 = vfmaq_laneq_f32(n_sum0, n_row1_ext1, n_filter_v[1], 2);
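To make the sliding-window trick above easier to follow, here is a minimal standalone sketch (ours, not part of this commit) of the same idea: vextq_f32 shifts the loaded row left by one and two lanes, so four adjacent 3-tap row convolutions cost only three multiply-accumulates. The helper name conv3_row4 is hypothetical.

#include <arm_neon.h>

// row must have at least 8 readable floats; k points at the 3 filter taps.
static inline float32x4_t conv3_row4(const float *row, const float *k) {
  float32x4_t former = vld1q_f32(row);              // x[0..3]
  float32x4_t latter = vld1q_f32(row + 4);          // x[4..7]
  float32x4_t ext0 = vextq_f32(former, latter, 1);  // x[1..4]
  float32x4_t ext1 = vextq_f32(former, latter, 2);  // x[2..5]
  float32x4_t sum = vmulq_n_f32(former, k[0]);
  sum = vmlaq_n_f32(sum, ext0, k[1]);
  sum = vmlaq_n_f32(sum, ext1, k[2]);
  return sum;  // sum[i] = k[0]*x[i] + k[1]*x[i+1] + k[2]*x[i+2], i = 0..3
}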
...
@@ -115,11 +122,9 @@ void Conv2dNeonK3x3S1(const float *input, // NCHW
        }
      }
      for (; remain_count > 0; --remain_count) {
        float32x4_t n_row_v[] = {vld1q_f32(row_ptr_v[0]),
                                 vld1q_f32(row_ptr_v[1]),
                                 vld1q_f32(row_ptr_v[2])};
        float32x4_t n_sum0 = vmulq_f32(n_row_v[0], n_filter_v[0]);
        n_sum0 = vmlaq_f32(n_sum0, n_row_v[1], n_filter_v[1]);
        n_sum0 = vmlaq_f32(n_sum0, n_row_v[2], n_filter_v[2]);
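In this tail loop each vector holds one three-tap filter row of products, so each remaining output pixel needs a horizontal reduction over the first three lanes. The hunk elides that step; the sketch below shows one way it can be done on AArch64 (our assumption, not necessarily what the elided code does):

#include <arm_neon.h>

static inline float reduce3(float32x4_t v) {
  v = vsetq_lane_f32(0.0f, v, 3);  // the 4th lane holds padding, zero it
  return vaddvq_f32(v);            // horizontal add across lanes (AArch64)
}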
...
@@ -185,8 +190,7 @@ void Conv2dNeonK3x3S1(const float *input, // NCHW
      }
    }
    for (; remain_count > 0; --remain_count) {
      float32x4_t n_row_v[] = {
          vld1q_f32(row_ptr_v[0]), vld1q_f32(row_ptr_v[1]),
          vld1q_f32(row_ptr_v[2]),
      };
...
@@ -210,43 +214,49 @@ void Conv2dNeonK3x3S1(const float *input, // NCHW
  }
}

void Conv2dNeonK3x3S2(const float *input,   // NCHW
                      const index_t *input_shape,
                      const float *filter,  // c_out, c_in, kernel_h, kernel_w
                      const index_t *filter_shape,
                      const float *bias,    // c_out
                      float *output,        // NCHW
                      const index_t *output_shape) {
  int tail_step = 2 * (input_shape[3] - output_shape[3]);

  int output_batch = output_shape[0];
  int output_channels = output_shape[1];
  int output_height = output_shape[2];
  int output_width = output_shape[3];

  int input_batch = input_shape[0];
  int input_channels = input_shape[1];
  int input_height = input_shape[2];
  int input_width = input_shape[3];

  int multiplier =
      filter_shape == nullptr ? 0 : (filter_shape[0] / input_channels);
  int filter_in_channels =
      filter_shape == nullptr ? input_channels : filter_shape[1];

#pragma omp parallel for collapse(2)
  for (int b = 0; b < output_batch; ++b) {
    for (int oc = 0; oc < output_channels; ++oc) {
      float *output_ptr_base =
          output + b * output_channels * output_height * output_width;
      const float *filter_ptr = filter + oc * filter_in_channels * kFilterSize;
      const float *input_ptr =
          input + b * input_channels * input_height * input_width;
      if (filter_shape != nullptr) {
        input_ptr += (oc / multiplier) * input_height * input_width;
      }
      float *output_ptr = output_ptr_base + oc * output_height * output_width;
      std::fill(output_ptr, output_ptr + output_height * output_width,
                bias ? bias[oc] : 0);
      for (int ic = 0; ic < filter_in_channels; ++ic) {
        float32x4_t n_filter_v[3] = {vld1q_f32(filter_ptr),
                                     vld1q_f32(filter_ptr + 3),
                                     vld1q_f32(filter_ptr + 6)};
        const float *row_ptr_v[3] = {input_ptr, input_ptr + input_width,
                                     input_ptr + 2 * input_width};
        float *output_ptr_inner = output_ptr;
...
@@ -259,24 +269,33 @@ void Conv2dNeonK3x3S2(const float *input, // NCHW
        float32x4x2_t n_row_former = vld2q_f32(row_ptr_v[0]);
        float32x4_t n_row_latter = vld1q_f32(row_ptr_v[0] + 8);
        float32x4_t n_row_ext =
            vextq_f32(n_row_former.val[0], n_row_latter, 1);
        n_sum = vfmaq_laneq_f32(n_sum, n_row_former.val[0], n_filter_v[0], 0);
        n_sum = vfmaq_laneq_f32(n_sum, n_row_former.val[1], n_filter_v[0], 1);
        n_sum = vfmaq_laneq_f32(n_sum, n_row_ext, n_filter_v[0], 2);

        float32x4x2_t n_row1_former = vld2q_f32(row_ptr_v[1]);
        float32x4_t n_row1_latter = vld1q_f32(row_ptr_v[1] + 8);
        float32x4_t n_row1_ext =
            vextq_f32(n_row1_former.val[0], n_row1_latter, 1);
        n_sum = vfmaq_laneq_f32(n_sum, n_row1_former.val[0], n_filter_v[1], 0);
        n_sum = vfmaq_laneq_f32(n_sum, n_row1_former.val[1], n_filter_v[1], 1);
        n_sum = vfmaq_laneq_f32(n_sum, n_row1_ext, n_filter_v[1], 2);

        float32x4x2_t n_row2_former = vld2q_f32(row_ptr_v[2]);
        float32x4_t n_row2_latter = vld1q_f32(row_ptr_v[2] + 8);
        float32x4_t n_row2_ext =
            vextq_f32(n_row2_former.val[0], n_row2_latter, 1);
        n_sum = vfmaq_laneq_f32(n_sum, n_row2_former.val[0], n_filter_v[2], 0);
        n_sum = vfmaq_laneq_f32(n_sum, n_row2_former.val[1], n_filter_v[2], 1);
        n_sum = vfmaq_laneq_f32(n_sum, n_row2_ext, n_filter_v[2], 2);

        float32x4_t n_output_row = vld1q_f32(output_ptr_inner);
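A note on the stride-2 variant above, with a hedged standalone sketch (ours, not this commit's code): vld2q_f32 deinterleaves 8 consecutive floats into even lanes x[0,2,4,6] and odd lanes x[1,3,5,7], which are exactly the first and second taps of four stride-2 windows; one vext against the next block supplies the third tap x[2,4,6,8].

#include <arm_neon.h>

static inline float32x4_t conv3_row4_s2(const float *row, const float *k) {
  float32x4x2_t de = vld2q_f32(row);                // val[0]=even, val[1]=odd
  float32x4_t next = vld1q_f32(row + 8);            // x[8..11]
  float32x4_t ext = vextq_f32(de.val[0], next, 1);  // x[2], x[4], x[6], x[8]
  float32x4_t sum = vmulq_n_f32(de.val[0], k[0]);   // k0 * x[2i]
  sum = vmlaq_n_f32(sum, de.val[1], k[1]);          // + k1 * x[2i+1]
  sum = vmlaq_n_f32(sum, ext, k[2]);                // + k2 * x[2i+2]
  return sum;
}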
...
@@ -288,11 +307,9 @@ void Conv2dNeonK3x3S2(const float *input, // NCHW
        }
      }
      for (; remain_count > 0; --remain_count) {
        float32x4_t n_row_v[] = {vld1q_f32(row_ptr_v[0]),
                                 vld1q_f32(row_ptr_v[1]),
                                 vld1q_f32(row_ptr_v[2])};
        float32x4_t n_sum = vmulq_f32(n_row_v[0], n_filter_v[0]);
        n_sum = vmlaq_f32(n_sum, n_row_v[1], n_filter_v[1]);
        n_sum = vmlaq_f32(n_sum, n_row_v[2], n_filter_v[2]);
...
@@ -315,5 +332,5 @@ void Conv2dNeonK3x3S2(const float *input, // NCHW
      }
    }
  }
}

}  // namespace kernels
}  // namespace mace
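For reference, a scalar sketch of what one (b, oc, ic) iteration of the kernels above accumulates per output pixel (names and boundary handling are ours; the real kernels also apply the depthwise filter indexing shown earlier):

// 3x3 valid convolution of one input channel into one output row, stride s.
static void Conv3x3RowScalar(const float *in, int in_width, const float *k,
                             float *out, int out_width, int s) {
  for (int x = 0; x < out_width; ++x) {
    float sum = 0.0f;
    for (int ky = 0; ky < 3; ++ky) {
      for (int kx = 0; kx < 3; ++kx) {
        sum += in[ky * in_width + x * s + kx] * k[ky * 3 + kx];
      }
    }
    out[x] += sum;  // accumulates across input channels, as the kernels do
  }
}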
mace/kernels/neon/conv_2d_neon_5x5.cc
...
@@ -14,8 +14,8 @@ void Conv2dNeonK5x5S1(const float *input, // NCHW
                      const index_t *input_shape,
                      const float *filter,  // c_out, c_in, kernel_h, kernel_w
                      const index_t *filter_shape,
                      const float *bias,    // c_out
                      float *output,        // NCHW
                      const index_t *output_shape) {
  const index_t batch = output_shape[0];
  const index_t channels = output_shape[1];
...
@@ -41,7 +41,7 @@ void Conv2dNeonK5x5S1(const float *input, // NCHW
  for (index_t n = 0; n < batch; ++n) {
    for (index_t c = 0; c < channels; ++c) {
      float *output_ptr = output + n * output_total_pixels_per_batch +
                          c * output_total_pixels_per_channel;
      const float *input_ptr = input + n * input_total_pixels_per_batch;
      // Fill with bias
...
mace/kernels/neon/depthwise_conv_neon.cc
...
@@ -2,8 +2,8 @@
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/kernels/conv_2d.h"
#include "mace/kernels/depthwise_conv2d.h"

namespace mace {
namespace kernels {
...
@@ -24,21 +24,18 @@ extern void Conv2dNeonK3x3S2(const float *input,
                             float *output,
                             const index_t *output_shape);

template <>
void DepthwiseConv2dFunctor<DeviceType::NEON, float>::operator()(
    const float *input,   // NCHW
    const index_t *input_shape,
    const float *filter,  // c_out, c_in, kernel_h, kernel_w
    const index_t *filter_shape,
    const float *bias,    // c_out
    float *output,        // NCHW
    const index_t *output_shape) {
  typedef void (*Conv2dNeonFunction)(
      const float *input, const index_t *input_shape, const float *filter,
      const index_t *filter_shape, const float *bias, float *output,
      const index_t *output_shape);
  // Selection matrix: kernel_size x stride_size
  static const Conv2dNeonFunction selector[5][2] = {
...
@@ -57,7 +54,8 @@ void DepthwiseConv2dFunctor<DeviceType::NEON, float>::operator()(const float *in
        << "filter" << kernel_h << "x" << kernel_w << ","
        << " stride " << strides_[0] << "x" << strides_[1]
        << " is not implemented yet, using slow version";
    DepthwiseConv2dFunctor<DeviceType::CPU, float>(strides_, paddings_,
                                                   dilations_)(
        input, input_shape, filter, filter_shape, bias, output, output_shape);
    return;
  }
...
@@ -65,13 +63,15 @@ void DepthwiseConv2dFunctor<DeviceType::NEON, float>::operator()(const float *in
  // Keep this alive during kernel execution
  Tensor padded_input;
  if (paddings_[0] > 0 || paddings_[1] > 0) {
    ConstructInputWithPadding(input, input_shape, paddings_.data(),
                              &padded_input);
    input = padded_input.data<float>();
    input_shape = padded_input.shape().data();
  }
  auto conv2d_neon_func = selector[kernel_h - 1][strides_[0] - 1];
  conv2d_neon_func(input, input_shape, filter, filter_shape, bias, output,
                   output_shape);
}

}  // namespace kernels
}  // namespace mace
\ No newline at end of file
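The selector table in this file is a function-pointer dispatch keyed on (kernel_size - 1, stride - 1), with nullptr entries marking unsupported shapes so the functor can fall back to the generic CPU implementation. A hedged sketch of the pattern (table contents illustrative, not the commit's actual entries):

typedef void (*KernelFn)(const float *in, float *out);

static void K3S1(const float *in, float *out) { /* specialized 3x3, stride 1 */ }
static void K3S2(const float *in, float *out) { /* specialized 3x3, stride 2 */ }

static const KernelFn kSelector[5][2] = {
    {nullptr, nullptr},  // 1x1
    {nullptr, nullptr},  // 2x2
    {K3S1, K3S2},        // 3x3
    {nullptr, nullptr},  // 4x4
    {nullptr, nullptr},  // 5x5
};

// Usage: KernelFn fn = kSelector[kernel - 1][stride - 1];
//        if (fn == nullptr) { /* generic slow path */ } else fn(in, out);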
mace/kernels/neon/global_avg_pooling_neon.cc
...
@@ -8,11 +8,9 @@
namespace mace {
namespace kernels {

template <>
void GlobalAvgPoolingFunctor<DeviceType::NEON, float>::operator()(
    const float *input, const index_t *input_shape, float *output) {
  index_t batch = input_shape[0];
  index_t channels = input_shape[1];
  index_t height = input_shape[2];
...
mace/kernels/neon/pooling_neon.cc
...
@@ -55,7 +55,7 @@ extern void PoolingAvgNeonK3x3S2x2Padded(const float *input,
                                         const index_t *out_shape);
#endif

template <>
void PoolingFunctor<DeviceType::NEON, float>::operator()(
    const float *input,
    const index_t *input_shape,
...
@@ -71,14 +71,14 @@ void PoolingFunctor<DeviceType::NEON, float>::operator()(
  if (kernels_[0] == 2 && kernels_[1] == 2 && strides_[0] == 2 &&
      strides_[1] == 2) {
    // kernel_size: 2x2, strides: 2x2
    if (pooling_type_ == MAX) {  // MAX_POOL_2x2s2x2
#ifdef __COPY_MAKE_PADDING
      PoolingMaxNeonK2x2S2x2Padded(input, input_shape, output, output_shape);
#else
      PoolingMaxNeonK2x2S2x2(input, input_shape, output, output_shape,
                             paddings_);
#endif
    } else {  // AVG_POOL_2x2s2x2
#ifdef __COPY_MAKE_PADDING
      PoolingAvgNeonK2x2S2x2Padded(input, input_shape, output, output_shape);
#else
...
@@ -87,16 +87,16 @@ void PoolingFunctor<DeviceType::NEON, float>::operator()(
#endif
    }
  } else if (kernels_[0] == 3 && kernels_[1] == 3 && strides_[0] == 2 &&
             strides_[1] == 2) {
    // kernel_size: 3x3, strides: 2x2
    if (pooling_type_ == MAX) {  // MAX_POOL_3x3s2x2
#ifdef __COPY_MAKE_PADDING
      PoolingMaxNeonK3x3S2x2Padded(input, input_shape, output, output_shape);
#else
      PoolingMaxNeonK3x3S2x2(input, input_shape, output, output_shape,
                             paddings_);
#endif
    } else {  // AVG_POOL_3x3s2x2
#ifdef __COPY_MAKE_PADDING
      PoolingAvgNeonK3x3S2x2Padded(input, input_shape, output, output_shape);
#else
...
mace/ops/addn.h
...
@@ -13,18 +13,18 @@ namespace mace {
template <DeviceType D, class T>
class AddNOp : public Operator<D, T> {
 public:
  AddNOp(const OperatorDef &operator_def, Workspace *ws)
      : Operator<D, T>(operator_def, ws) {}

  bool Run() override {
    Tensor *output_tensor = this->outputs_[0];
    output_tensor->ResizeLike(this->inputs_[0]);
    T *output = output_tensor->mutable_data<T>();
    index_t size = this->inputs_[0]->size();
    int n = this->inputs_.size();
    vector<const T *> inputs(n);
    for (int i = 0; i < n; ++i) {
      const Tensor *input_tensor = this->inputs_[i];
      inputs[i] = input_tensor->data<T>();
    }
...
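The elided functor call presumably sums the n gathered input pointers elementwise into output; a hedged sketch of that pattern (ours, not the kernel's actual code):

#include <cstddef>
#include <vector>

template <typename T>
void AddNRef(const std::vector<const T *> &inputs, size_t size, T *output) {
  for (size_t j = 0; j < size; ++j) {
    T sum = 0;
    for (const T *in : inputs) sum += in[j];  // all inputs are equally sized
    output[j] = sum;
  }
}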
mace/ops/addn_benchmark.cc
...
@@ -39,7 +39,7 @@ static void AddNBenchmark(int iters, int n, int size) {
  static void BM_ADDN_##N##_##SIZE##_##TYPE##_##DEVICE(int iters) { \
    const int64_t tot = static_cast<int64_t>(iters) * N * SIZE;     \
    mace::testing::ItemsProcessed(tot);                             \
    mace::testing::BytesProcessed(tot *(sizeof(TYPE)));             \
    AddNBenchmark<DEVICE, TYPE>(iters, N, SIZE);                    \
  }                                                                 \
  BENCHMARK(BM_ADDN_##N##_##SIZE##_##TYPE##_##DEVICE)
...
mace/ops/addn_test.cc
...
@@ -11,7 +11,7 @@ class AddnOpTest : public OpsTestBase {};
TEST_F(AddnOpTest, AddnOp) {
  // Construct graph
  auto &net = test_net();
  OpDefBuilder("AddN", "AddNTest")
      .Input("Input1")
      .Input("Input2")
...
mace/ops/batch_norm.h
...
@@ -13,17 +13,16 @@ namespace mace {
template <DeviceType D, class T>
class BatchNormOp : public Operator<D, T> {
 public:
  BatchNormOp(const OperatorDef &operator_def, Workspace *ws)
      : Operator<D, T>(operator_def, ws), functor_() {}

  bool Run() override {
    const Tensor *input = this->Input(0);
    const Tensor *scale = this->Input(1);
    const Tensor *offset = this->Input(2);
    const Tensor *mean = this->Input(3);
    const Tensor *var = this->Input(4);
    const Tensor *epsilon = this->Input(5);

    MACE_CHECK(input->dim_size() == 4, "input must be 4-dimensional. ",
               input->dim_size());
...
@@ -38,23 +37,23 @@ class BatchNormOp : public Operator<D, T> {
    MACE_CHECK(epsilon->dim_size() == 0, "epsilon must be 0-dimensional. ",
               epsilon->dim_size());

    Tensor *output = this->Output(0);
    output->ResizeLike(input);

    const index_t n = input->dim(0);
    const index_t channel = input->dim(1);
    const index_t sample_size = input->dim(2) * input->dim(3);

    const T *input_ptr = input->data<T>();
    const T *scale_ptr = scale->data<T>();
    const T *offset_ptr = offset->data<T>();
    const T *mean_ptr = mean->data<T>();
    const T *var_ptr = var->data<T>();
    const T *epsilon_ptr = epsilon->data<T>();
    T *output_ptr = output->mutable_data<T>();

    functor_(input_ptr, scale_ptr, offset_ptr, mean_ptr, var_ptr, *epsilon_ptr,
             n, channel, sample_size, output_ptr);
    return true;
  }
...
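For clarity, the functor_ call above performs standard batch-norm inference per channel: y = scale[c] * (x - mean[c]) / sqrt(var[c] + epsilon) + offset[c]. A hedged scalar sketch of that arithmetic (the real functor may fold the terms differently):

#include <cmath>

void BatchNormRef(const float *x, const float *scale, const float *offset,
                  const float *mean, const float *var, float epsilon,
                  long n, long channel, long sample_size, float *y) {
  for (long b = 0; b < n; ++b) {
    for (long c = 0; c < channel; ++c) {
      const float inv_std = 1.0f / std::sqrt(var[c] + epsilon);
      const long base = (b * channel + c) * sample_size;
      for (long s = 0; s < sample_size; ++s) {
        y[base + s] = scale[c] * (x[base + s] - mean[c]) * inv_std + offset[c];
      }
    }
  }
}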
mace/ops/batch_norm_benchmark.cc
...
@@ -47,7 +47,7 @@ static void BatchNorm(
      int iters) {                                                   \
    const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
    mace::testing::ItemsProcessed(tot);                              \
    mace::testing::BytesProcessed(tot *(sizeof(TYPE)));              \
    BatchNorm<DEVICE, TYPE>(iters, N, C, H, W);                      \
  }                                                                  \
  BENCHMARK(BM_BATCH_NORM_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
...
mace/ops/batch_norm_test.cc
...
@@ -11,7 +11,7 @@ class BatchNormOpTest : public OpsTestBase {};
TEST_F(BatchNormOpTest, SimpleCPU) {
  // Construct graph
  auto &net = test_net();
  OpDefBuilder("BatchNorm", "BatchNormTest")
      .Input("Input")
      .Input("Scale")
...
@@ -51,7 +51,7 @@ TEST_F(BatchNormOpTest, SimpleNeon) {
  index_t height = 103;
  index_t width = 113;

  // Construct graph
  auto &net = test_net();
  OpDefBuilder("BatchNorm", "BatchNormTest")
      .Input("Input")
      .Input("Scale")
...
@@ -74,7 +74,7 @@ TEST_F(BatchNormOpTest, SimpleNeon) {
  net.RunOp();

  // Check
  Tensor *expected = net.GetOutput("Output");

  // Run NEON
  net.RunOp(DeviceType::NEON);
...
mace/ops/channel_shuffle.h
...
@@ -12,10 +12,10 @@
namespace mace {

template <DeviceType D, typename T>
class ChannelShuffleOp : public Operator<D, T> {
 public:
  ChannelShuffleOp(const OperatorDef &operator_def, Workspace *ws)
      : Operator<D, T>(operator_def, ws),
        group_(OperatorBase::GetSingleArgument<int>("group", 1)),
        functor_(this->group_) {}
...
mace/ops/channel_shuffle_benchmark.cc
...
@@ -11,12 +11,8 @@ using namespace mace;
using namespace mace::kernels;

template <DeviceType D>
static void ChannelShuffle(int iters, int batch, int channels, int height,
                           int width, int group) {
  mace::testing::StopTiming();

  OpsTestNet net;
...
@@ -40,18 +36,17 @@ static void ChannelShuffle(int iters,
  }
}

#define BM_CHANNEL_SHUFFLE_MACRO(N, C, H, W, G, DEVICE)                  \
  static void BM_CHANNEL_SHUFFLE_##N##_##C##_##H##_##W##_##G##_##DEVICE( \
      int iters) {                                                       \
    const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W;     \
    mace::testing::ItemsProcessed(tot);                                  \
    mace::testing::BytesProcessed(tot *(sizeof(float)));                 \
    ChannelShuffle<DEVICE>(iters, N, C, H, W, G);                        \
  }                                                                      \
  BENCHMARK(BM_CHANNEL_SHUFFLE_##N##_##C##_##H##_##W##_##G##_##DEVICE)

#define BM_CHANNEL_SHUFFLE(N, C, H, W, G) \
  BM_CHANNEL_SHUFFLE_MACRO(N, C, H, W, G, CPU);

BM_CHANNEL_SHUFFLE(1, 64, 64, 64, 8);
...
mace/ops/channel_shuffle_test.cc
...
@@ -10,7 +10,7 @@ class ChannelShuffleOpTest : public OpsTestBase {};
TEST_F(ChannelShuffleOpTest, C8G4) {
  // Construct graph
  auto &net = test_net();
  OpDefBuilder("ChannelShuffle", "ChannelShuffleTest")
      .Input("Input")
      .Output("Output")
...
@@ -20,18 +20,15 @@ TEST_F(ChannelShuffleOpTest, C8G4) {
  // Add input data
  net.AddInputFromArray<float>(
      "Input", {1, 8, 1, 2},
      {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});

  // Run
  net.RunOp();

  // Check
  auto expected = CreateTensor<float>(
      {1, 8, 1, 2},
      {0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15});

  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}
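A worked check of the expected tensor above (the standard channel-shuffle index math, our sketch rather than the kernel source): with C = 8 channels and G = 4 groups, output channel j reads input channel (j % G) * (C / G) + j / G, giving the channel order 0, 2, 4, 6, 1, 3, 5, 7, which is exactly the expected data.

static int ShuffledChannel(int j, int C, int G) {
  return (j % G) * (C / G) + j / G;  // e.g. j=1 -> 2, j=4 -> 1 for C=8, G=4
}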
mace/ops/concat.h
...
@@ -5,12 +5,12 @@
#ifndef MACE_OPS_CONCAT_H_
#define MACE_OPS_CONCAT_H_

#include "mace/core/operator.h"
#include "mace/kernels/concat.h"
#include "mace/proto/mace.pb.h"

namespace mace {

template <DeviceType D, typename T>
class ConcatOp : public Operator<D, T> {
 public:
  ConcatOp(const OperatorDef &op_def, Workspace *ws)
...
@@ -25,9 +25,11 @@ class ConcatOp : public Operator<D, T> {
               axis_tensor->dim_size());
    const int32_t concat_axis = *(axis_tensor->data<int32_t>());
    const int32_t input_dims = input0->dim_size();
    const int32_t axis =
        concat_axis < 0 ? concat_axis + input_dims : concat_axis;
    MACE_CHECK((0 <= axis && axis < input_dims),
               "Expected concatenating axis in the range [", -input_dims, ", ",
               input_dims, "], but got", concat_axis);
    std::vector<index_t> output_shape(input0->shape());
    index_t inner_size = 1;
    for (int i = 0; i < axis; ++i) {
...
@@ -40,10 +42,14 @@ class ConcatOp : public Operator<D, T> {
    const Tensor *input = nullptr;
    for (int i = 1; i < values_count; ++i) {
      input = this->Input(i);
      MACE_CHECK(input->dim_size() == input0->dim_size(),
                 "Ranks of all input tensors must be same.");
      for (int j = 0; j < axis_tensor->dim_size(); ++j) {
        if (j == axis) {
          continue;
        }
        MACE_CHECK(input->dim(j) == input0->dim(j),
                   "Dimensions of inputs should equal except axis.");
      }
      input_list[i] = input->data<T>();
      outer_sizes[i] = input->size() / inner_size;
...
@@ -53,9 +59,11 @@ class ConcatOp : public Operator<D, T> {
    Tensor *output = this->Output(OUTPUT);
    output->Resize(output_shape);

    functor_(input_list, inner_size, outer_sizes.data(),
             output->mutable_data<T>());
    return true;
  }

 private:
  kernels::ConcatFunctor<D, T> functor_;
...
@@ -63,6 +71,6 @@ class ConcatOp : public Operator<D, T> {
  OP_OUTPUT_TAGS(OUTPUT);
};

}  // namespace mace

#endif  // MACE_OPS_CONCAT_H_
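One worked case of the axis normalization added in this file: for a 4-D input, concat_axis = -1 maps to axis = -1 + 4 = 3, and the MACE_CHECK then accepts any original value in [-4, 4), while a concat_axis of 4 or -5 fails the check.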
mace/ops/concat_benchmark.cc
...
@@ -7,9 +7,8 @@
#include "mace/ops/ops_test_util.h"

namespace mace {
template <DeviceType D, typename T>
static void ConcatHelper(int iters, int concat_dim, int dim1) {
  mace::testing::StopTiming();

  OpsTestNet net;
...
浏览文件 @
ebff986d
...
@@ -3,8 +3,8 @@
...
@@ -3,8 +3,8 @@
//
//
#include "mace/ops/concat.h"
#include "mace/ops/concat.h"
#include "mace/ops/ops_test_util.h"
#include "gmock/gmock.h"
#include "gmock/gmock.h"
#include "mace/ops/ops_test_util.h"
using
namespace
mace
;
using
namespace
mace
;
...
@@ -99,9 +99,7 @@ TEST_F(ConcatOpTest, Random) {
...
@@ -99,9 +99,7 @@ TEST_F(ConcatOpTest, Random) {
for
(
int
i
=
0
;
i
<
num_inputs
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_inputs
;
++
i
)
{
builder
=
builder
.
Input
((
"Input"
+
ToString
(
i
)).
c_str
());
builder
=
builder
.
Input
((
"Input"
+
ToString
(
i
)).
c_str
());
}
}
builder
.
Input
(
"Axis"
)
builder
.
Input
(
"Axis"
).
Output
(
"Output"
).
Finalize
(
net
.
operator_def
());
.
Output
(
"Output"
)
.
Finalize
(
net
.
operator_def
());
std
::
vector
<
index_t
>
shape_data
;
std
::
vector
<
index_t
>
shape_data
;
GenerateRandomIntTypeData
<
index_t
>
({
dim
},
shape_data
,
1
,
dim
);
GenerateRandomIntTypeData
<
index_t
>
({
dim
},
shape_data
,
1
,
dim
);
...
@@ -114,7 +112,8 @@ TEST_F(ConcatOpTest, Random) {
...
@@ -114,7 +112,8 @@ TEST_F(ConcatOpTest, Random) {
concat_axis_size
+=
input_shapes
[
i
][
axis
];
concat_axis_size
+=
input_shapes
[
i
][
axis
];
GenerateRandomRealTypeData
(
input_shapes
[
i
],
inputs
[
i
]);
GenerateRandomRealTypeData
(
input_shapes
[
i
],
inputs
[
i
]);
input_ptrs
[
i
]
=
inputs
[
i
].
data
();
input_ptrs
[
i
]
=
inputs
[
i
].
data
();
net
.
AddInputFromArray
<
float
>
((
"Input"
+
ToString
(
i
)).
c_str
(),
input_shapes
[
i
],
inputs
[
i
]);
net
.
AddInputFromArray
<
float
>
((
"Input"
+
ToString
(
i
)).
c_str
(),
input_shapes
[
i
],
inputs
[
i
]);
}
}
net
.
AddInputFromArray
<
int
>
(
"Axis"
,
{},
{
axis
});
net
.
AddInputFromArray
<
int
>
(
"Axis"
,
{},
{
axis
});
...
@@ -131,9 +130,9 @@ TEST_F(ConcatOpTest, Random) {
...
@@ -131,9 +130,9 @@ TEST_F(ConcatOpTest, Random) {
const
float
*
output_ptr
=
output
->
data
<
float
>
();
const
float
*
output_ptr
=
output
->
data
<
float
>
();
while
(
output_ptr
!=
(
output
->
data
<
float
>
()
+
output
->
size
()))
{
while
(
output_ptr
!=
(
output
->
data
<
float
>
()
+
output
->
size
()))
{
for
(
int
i
=
0
;
i
<
num_inputs
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_inputs
;
++
i
)
{
index_t
num_elements
=
std
::
accumulate
(
input_shapes
[
i
].
begin
()
+
axis
,
index_t
num_elements
=
input_shapes
[
i
].
end
(),
1
,
std
::
accumulate
(
input_shapes
[
i
].
begin
()
+
axis
,
input_shapes
[
i
].
end
()
,
std
::
multiplies
<
index_t
>
());
1
,
std
::
multiplies
<
index_t
>
());
for
(
int
j
=
0
;
j
<
num_elements
;
++
j
)
{
for
(
int
j
=
0
;
j
<
num_elements
;
++
j
)
{
EXPECT_EQ
(
*
input_ptrs
[
i
]
++
,
*
output_ptr
++
);
EXPECT_EQ
(
*
input_ptrs
[
i
]
++
,
*
output_ptr
++
);
}
}
...
...
mace/ops/conv_2d.h
...
@@ -13,7 +13,7 @@
namespace mace {

template <DeviceType D, typename T>
class Conv2dOp : public ConvPool2dOpBase<D, T> {
 public:
  Conv2dOp(const OperatorDef &op_def, Workspace *ws)
...
@@ -35,11 +35,10 @@ class Conv2dOp : public ConvPool2dOpBase<D, T> {
    std::vector<index_t> output_shape(4);
    std::vector<int> paddings(2);
    kernels::CalcPaddingAndOutputSize(
        input->shape().data(), filter->shape().data(),
        this->dilations_.data(), this->strides_.data(), this->padding_,
        output_shape.data(), paddings.data());
    output->Resize(output_shape);
    functor_.paddings_ = paddings;
...
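CalcPaddingAndOutputSize itself is not shown in this diff; for orientation, a hedged sketch of the standard output-size formula that such helpers implement (our code, not the MACE helper itself):

// out = floor((in + 2*pad - dilation*(kernel - 1) - 1) / stride) + 1
static int ConvOutputSize(int in, int kernel, int stride, int pad,
                          int dilation) {
  return (in + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1;
}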
mace/ops/conv_2d_benchmark.cc
...
@@ -54,16 +54,17 @@ static void Conv2d(int iters,
  }
}

#define BM_CONV_2D_MACRO(N, C, H, W, KH, KW, STRIDE, P, OC, TYPE, DEVICE)     \
  static void                                                                 \
      BM_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_##OC##_##TYPE##_##DEVICE( \
          int iters) {                                                        \
    const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W;          \
    mace::testing::ItemsProcessed(tot);                                       \
    mace::testing::BytesProcessed(tot *(sizeof(TYPE)));                       \
    Conv2d<DEVICE, TYPE>(iters, N, C, H, W, KH, KW, STRIDE, mace::Padding::P, \
                         OC);                                                 \
  }                                                                           \
  BENCHMARK(                                                                  \
      BM_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_##OC##_##TYPE##_##DEVICE)

#define BM_CONV_2D(N, C, H, W, KH, KW, S, P, OC, TYPE) \
...
mace/ops/conv_2d_test.cc
...
@@ -12,7 +12,7 @@ class Conv2dOpTest : public OpsTestBase {};
TEST_F(Conv2dOpTest, Simple_VALID) {
  // Construct graph
  auto &net = test_net();
  OpDefBuilder("Conv2D", "Conv2dTest")
      .Input("Input")
      .Input("Filter")
...
@@ -46,7 +46,7 @@ TEST_F(Conv2dOpTest, Simple_VALID) {
TEST_F(Conv2dOpTest, Simple_SAME) {
  // Construct graph
  auto &net = test_net();
  OpDefBuilder("Conv2D", "Conv2dTest")
      .Input("Input")
      .Input("Filter")
...
@@ -82,7 +82,7 @@ TEST_F(Conv2dOpTest, Simple_SAME) {
TEST_F(Conv2dOpTest, Combined) {
  // Construct graph
  auto &net = test_net();
  OpDefBuilder("Conv2D", "Conv2DTest")
      .Input("Input")
      .Input("Filter")
...
@@ -120,7 +120,7 @@ TEST_F(Conv2dOpTest, Combined) {
TEST_F(Conv2dOpTest, Conv1x1) {
  // Construct graph
  auto &net = test_net();
  OpDefBuilder("Conv2D", "Conv2DTest")
      .Input("Input")
      .Input("Filter")
...
@@ -172,13 +172,13 @@ TEST_F(Conv2dOpTest, IdleConvNxNS12) {
  srand(time(NULL));

  // generate random input
  index_t batch = 3;
  index_t input_channels = 64;
  index_t height = 32;
  index_t width = 32;
  index_t output_channels = 128;
  // Construct graph
  auto &net = test_net();
  OpDefBuilder("Conv2D", "Conv2dTest")
      .Input("Input")
      .Input("Filter")
...
@@ -229,7 +229,7 @@ TEST_F(Conv2dOpTest, DisgustConvNxNS12) {
  index_t width = 113;
  index_t output_channels = 3 + rand() % 10;
  // Construct graph
  auto &net = test_net();
  OpDefBuilder("Conv2D", "Conv2dTest")
      .Input("Input")
      .Input("Filter")
...
浏览文件 @
ebff986d
...
@@ -13,13 +13,14 @@ namespace mace {
...
@@ -13,13 +13,14 @@ namespace mace {
template
<
DeviceType
D
,
class
T
>
template
<
DeviceType
D
,
class
T
>
class
ConvPool2dOpBase
:
public
Operator
<
D
,
T
>
{
class
ConvPool2dOpBase
:
public
Operator
<
D
,
T
>
{
public:
public:
ConvPool2dOpBase
(
const
OperatorDef
&
op_def
,
Workspace
*
ws
)
ConvPool2dOpBase
(
const
OperatorDef
&
op_def
,
Workspace
*
ws
)
:
Operator
<
D
,
T
>
(
op_def
,
ws
),
:
Operator
<
D
,
T
>
(
op_def
,
ws
),
strides_
(
OperatorBase
::
GetRepeatedArgument
<
int
>
(
"strides"
)),
strides_
(
OperatorBase
::
GetRepeatedArgument
<
int
>
(
"strides"
)),
padding_
(
static_cast
<
Padding
>
(
OperatorBase
::
GetSingleArgument
<
int
>
(
padding_
(
static_cast
<
Padding
>
(
OperatorBase
::
GetSingleArgument
<
int
>
(
"padding"
,
static_cast
<
int
>
(
SAME
)))),
"padding"
,
static_cast
<
int
>
(
SAME
)))),
dilations_
(
OperatorBase
::
GetRepeatedArgument
<
int
>
(
"dilations"
,
{
1
,
1
}))
{}
dilations_
(
OperatorBase
::
GetRepeatedArgument
<
int
>
(
"dilations"
,
{
1
,
1
}))
{}
protected:
protected:
std
::
vector
<
int
>
strides_
;
std
::
vector
<
int
>
strides_
;
Padding
padding_
;
Padding
padding_
;
...
...
mace/ops/depthwise_conv2d.cc
浏览文件 @
ebff986d
...
@@ -6,10 +6,12 @@
...
@@ -6,10 +6,12 @@
namespace
mace
{
namespace
mace
{
REGISTER_CPU_OPERATOR
(
DepthwiseConv2d
,
DepthwiseConv2dOp
<
DeviceType
::
CPU
,
float
>
);
REGISTER_CPU_OPERATOR
(
DepthwiseConv2d
,
DepthwiseConv2dOp
<
DeviceType
::
CPU
,
float
>
);
#if __ARM_NEON
#if __ARM_NEON
REGISTER_NEON_OPERATOR
(
DepthwiseConv2d
,
DepthwiseConv2dOp
<
DeviceType
::
NEON
,
float
>
);
REGISTER_NEON_OPERATOR
(
DepthwiseConv2d
,
DepthwiseConv2dOp
<
DeviceType
::
NEON
,
float
>
);
#endif // __ARM_NEON
#endif // __ARM_NEON
}
// namespace mace
}
// namespace mace
mace/ops/depthwise_conv2d.h
...
@@ -9,12 +9,12 @@
#include "mace/core/operator.h"
#include "mace/kernels/conv_2d.h"
#include "mace/kernels/depthwise_conv2d.h"
#include "mace/ops/conv_pool_2d_base.h"

namespace mace {

template <DeviceType D, typename T>
class DepthwiseConv2dOp : public ConvPool2dOpBase<D, T> {
 public:
  DepthwiseConv2dOp(const OperatorDef &op_def, Workspace *ws)
...
@@ -34,16 +34,16 @@ class DepthwiseConv2dOp : public ConvPool2dOpBase<D, T> {
    Tensor *output = this->Output(OUTPUT);

    // resize filter shape.
    std::vector<index_t> filter_shape(filter->shape().begin(),
                                      filter->shape().end());
    filter_shape[0] *= filter_shape[1];
    filter_shape[1] = 1;

    std::vector<index_t> output_shape(4);
    std::vector<int> paddings(2);
    kernels::CalcPaddingAndOutputSize(
        input->shape().data(), filter_shape.data(), this->dilations_.data(),
        this->strides_.data(), this->padding_, output_shape.data(),
        paddings.data());
    output->Resize(output_shape);
    functor_.paddings_ = paddings;
...
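A worked example of the filter reshape above (our illustration): a depthwise filter stored as (multiplier, in_channels, kh, kw) = (2, 3, 3, 3) is viewed as (multiplier * in_channels, 1, kh, kw) = (6, 1, 3, 3), so the shared padding helper sees 6 output channels that each convolve a single input channel.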
@@ -62,6 +62,6 @@ class DepthwiseConv2dOp : public ConvPool2dOpBase<D, T> {
  OP_OUTPUT_TAGS(OUTPUT);
};

}  // namespace mace

#endif  // MACE_OPS_DEPTHWISE_CONV_H_
mace/ops/depthwise_conv2d_test.cc
...
@@ -12,7 +12,7 @@ class DepthwiseConv2dOpTest : public OpsTestBase {};
TEST_F(DepthwiseConv2dOpTest, Simple_VALID) {
  testing::internal::LogToStderr();
  // Construct graph
  auto &net = test_net();
  OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
      .Input("Input")
      .Input("Filter")
...
@@ -26,23 +26,20 @@ TEST_F(DepthwiseConv2dOpTest, Simple_VALID) {
  net.AddIntsArg("dilations", {1, 1});

  // Add input data
  net.AddInputFromArray<float>(
      "Input", {1, 2, 2, 3},
      {1, 3, 5, 7, 9, 11, 2, 4, 6, 8, 10, 12});
  net.AddInputFromArray<float>(
      "Filter", {2, 2, 2, 2},
      {1.0f, 5.0f, 9.0f, 13.0f, 2.0f, 6.0f, 10.0f, 14.0f, 3.0f, 7.0f, 11.0f,
       15.0f, 4.0f, 8.0f, 12.0f, 16.0f});
  net.AddInputFromArray<float>("Bias", {4}, {.1f, .2f, .3f, .4f});
  // Run
  net.RunOp();

  // Check
  auto expected = CreateTensor<float>(
      {1, 4, 1, 2},
      {196.1f, 252.1f, 216.2f, 280.2f, 272.3f, 344.3f, 296.4f, 376.4f});

  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
}
...
@@ -60,7 +57,7 @@ TEST_F(DepthwiseConv2dOpTest, ConvNxNS12) {
  index_t width = 113;
  index_t multiplier = 3 + rand() % 10;
  // Construct graph
  auto &net = test_net();
  OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
      .Input("Input")
      .Input("Filter")
...
@@ -75,8 +72,8 @@ TEST_F(DepthwiseConv2dOpTest, ConvNxNS12) {
  // Add input data
  net.AddRandomInput<float>("Input", {batch, input_channels, height, width});
  net.AddRandomInput<float>(
      "Filter", {multiplier, input_channels, kernel_h, kernel_w});
  net.AddRandomInput<float>("Bias", {multiplier * input_channels});
  // run cpu
  net.RunOp();
...
浏览文件 @
ebff986d
...
@@ -13,15 +13,15 @@ namespace mace {
...
@@ -13,15 +13,15 @@ namespace mace {
template
<
DeviceType
D
,
typename
T
>
template
<
DeviceType
D
,
typename
T
>
static
void
DepthwiseConv2d
(
int
iters
,
static
void
DepthwiseConv2d
(
int
iters
,
int
batch
,
int
batch
,
int
channels
,
int
channels
,
int
height
,
int
height
,
int
width
,
int
width
,
int
kernel_h
,
int
kernel_h
,
int
kernel_w
,
int
kernel_w
,
int
stride
,
int
stride
,
Padding
padding
,
Padding
padding
,
int
output_channels
)
{
int
output_channels
)
{
mace
::
testing
::
StopTiming
();
mace
::
testing
::
StopTiming
();
OpsTestNet
net
;
OpsTestNet
net
;
...
@@ -54,16 +54,18 @@ static void DepthwiseConv2d(int iters,
...
@@ -54,16 +54,18 @@ static void DepthwiseConv2d(int iters,
}
}
}
}
#define BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, STRIDE, P, OC, TYPE, DEVICE) \
#define BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, STRIDE, P, OC, TYPE, \
static void \
DEVICE) \
static void \
BM_DEPTHWISE_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_##OC##_##TYPE##_##DEVICE( \
BM_DEPTHWISE_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_##OC##_##TYPE##_##DEVICE( \
int iters) { \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::ItemsProcessed(tot); \
mace::testing::ItemsProcessed(tot); \
mace::testing::BytesProcessed(tot*(sizeof(TYPE))); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
DepthwiseConv2d<DEVICE, TYPE>(iters, N, C, H, W, KH, KW, STRIDE, mace::Padding::P, OC); \
DepthwiseConv2d<DEVICE, TYPE>(iters, N, C, H, W, KH, KW, STRIDE, \
} \
mace::Padding::P, OC); \
BENCHMARK( \
} \
BENCHMARK( \
BM_DEPTHWISE_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_##OC##_##TYPE##_##DEVICE)
BM_DEPTHWISE_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_##OC##_##TYPE##_##DEVICE)
#define BM_DEPTHWISE_CONV_2D(N, C, H, W, KH, KW, S, P, OC, TYPE) \
#define BM_DEPTHWISE_CONV_2D(N, C, H, W, KH, KW, S, P, OC, TYPE) \
...
...
mace/ops/global_avg_pooling.h
...
@@ -10,7 +10,7 @@
namespace mace {

template <DeviceType D, class T>
class GlobalAvgPoolingOp : public Operator<D, T> {
 public:
  GlobalAvgPoolingOp(const OperatorDef &operator_def, Workspace *ws)
...
mace/ops/global_avg_pooling_benchmark.cc
...
@@ -11,11 +11,8 @@ using namespace mace;
using namespace mace::kernels;

template <DeviceType D>
static void GlobalAvgPooling(int iters, int batch, int channels, int height,
                             int width) {
  mace::testing::StopTiming();

  OpsTestNet net;
...
@@ -38,15 +35,14 @@ static void GlobalAvgPooling(int iters,
  }
}

#define BM_GLOBAL_AVG_POOLING_MACRO(N, C, H, W, DEVICE)               \
  static void BM_GLOBAL_AVG_POOLING_##N##_##C##_##H##_##W##_##DEVICE( \
      int iters) {                                                    \
    const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W;  \
    mace::testing::ItemsProcessed(tot);                               \
    mace::testing::BytesProcessed(tot *(sizeof(float)));              \
    GlobalAvgPooling<DEVICE>(iters, N, C, H, W);                      \
  }                                                                   \
  BENCHMARK(BM_GLOBAL_AVG_POOLING_##N##_##C##_##H##_##W##_##DEVICE)

#define BM_GLOBAL_AVG_POOLING(N, C, H, W) \
...
mace/ops/global_avg_pooling_test.cc
...
@@ -10,7 +10,7 @@ class GlobalAvgPoolingOpTest : public OpsTestBase {};
TEST_F(GlobalAvgPoolingOpTest, 3x7x7_CPU) {
  // Construct graph
  auto &net = test_net();
  OpDefBuilder("GlobalAvgPooling", "GlobalAvgPoolingTest")
      .Input("Input")
      .Output("Output")
...
@@ -19,24 +19,22 @@ TEST_F(GlobalAvgPoolingOpTest, 3x7x7_CPU) {
  // Add input data
  std::vector<float> input(147);
  for (int i = 0; i < 147; ++i) {
    input[i] = i / 49 + 1;
  }
  net.AddInputFromArray<float>("Input", {1, 3, 7, 7}, input);

  // Run
  net.RunOp();

  // Check
  auto expected = CreateTensor<float>({1, 3, 1, 1}, {1, 2, 3});

  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}
TEST_F
(
GlobalAvgPoolingOpTest
,
3
x7x7_NEON
)
{
TEST_F
(
GlobalAvgPoolingOpTest
,
3
x7x7_NEON
)
{
// Construct graph
// Construct graph
auto
&
net
=
test_net
();
auto
&
net
=
test_net
();
OpDefBuilder
(
"GlobalAvgPooling"
,
"GlobalAvgPoolingTest"
)
OpDefBuilder
(
"GlobalAvgPooling"
,
"GlobalAvgPoolingTest"
)
.
Input
(
"Input"
)
.
Input
(
"Input"
)
.
Output
(
"Output"
)
.
Output
(
"Output"
)
...
@@ -45,17 +43,15 @@ TEST_F(GlobalAvgPoolingOpTest, 3x7x7_NEON) {
...
@@ -45,17 +43,15 @@ TEST_F(GlobalAvgPoolingOpTest, 3x7x7_NEON) {
// Add input data
// Add input data
std
::
vector
<
float
>
input
(
147
);
std
::
vector
<
float
>
input
(
147
);
for
(
int
i
=
0
;
i
<
147
;
++
i
)
{
for
(
int
i
=
0
;
i
<
147
;
++
i
)
{
input
[
i
]
=
i
/
49
+
1
;
input
[
i
]
=
i
/
49
+
1
;
}
}
net
.
AddInputFromArray
<
float
>
(
net
.
AddInputFromArray
<
float
>
(
"Input"
,
{
1
,
3
,
7
,
7
},
input
);
"Input"
,
{
1
,
3
,
7
,
7
},
input
);
// Run
// Run
net
.
RunOp
(
DeviceType
::
NEON
);
net
.
RunOp
(
DeviceType
::
NEON
);
// Check
// Check
auto
expected
=
auto
expected
=
CreateTensor
<
float
>
({
1
,
3
,
1
,
1
},
{
1
,
2
,
3
});
CreateTensor
<
float
>
({
1
,
3
,
1
,
1
},
{
1
,
2
,
3
});
ExpectTensorNear
<
float
>
(
*
expected
,
*
net
.
GetOutput
(
"Output"
),
0.001
);
ExpectTensorNear
<
float
>
(
*
expected
,
*
net
.
GetOutput
(
"Output"
),
0.001
);
}
}
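The expected tensor can be checked by hand: input[i] = i / 49 + 1 uses integer division, so each 7x7 channel of 49 elements is filled with a constant (channel 0 with 1s, channel 1 with 2s, channel 2 with 3s), and the global average over each channel is exactly {1, 2, 3}.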
mace/ops/ops_test_util.h
@@ -43,7 +43,7 @@ class OpsTestNet {
 public:
  OpsTestNet() {}

  template <typename T>
  void AddInputFromArray(const char *name,
                         const std::vector<index_t> &shape,
                         const std::vector<T> &data) {
...
@@ -55,7 +55,7 @@ class OpsTestNet {
    memcpy(input_data, data.data(), data.size() * sizeof(T));
  }

  template <typename T>
  void AddRepeatedInput(const char *name,
                        const std::vector<index_t> &shape,
                        const T data) {
...
@@ -66,7 +66,7 @@ class OpsTestNet {
    std::fill(input_data, input_data + input->size(), data);
  }

  template <typename T>
  void AddRandomInput(const char *name,
                      const std::vector<index_t> &shape,
                      bool positive = false) {
...
@@ -173,38 +173,37 @@ class OpsTestBase : public ::testing::Test {
  OpsTestNet test_net_;
};

template <typename T>
void GenerateRandomRealTypeData(const std::vector<index_t> &shape,
                                std::vector<T> &res) {
  std::random_device rd;
  std::mt19937 gen(rd());
  std::normal_distribution<T> nd(0, 1);

  index_t size = std::accumulate(shape.begin(), shape.end(), 1,
                                 std::multiplies<index_t>());
  res.resize(size);
  std::generate(res.begin(), res.end(), [&gen, &nd] { return nd(gen); });
}

template <typename T>
void GenerateRandomIntTypeData(const std::vector<index_t> &shape,
                               std::vector<T> &res, const T a = 0,
                               const T b = std::numeric_limits<T>::max()) {
  std::random_device rd;
  std::mt19937 gen(rd());
  std::uniform_int_distribution<> nd(a, b);

  index_t size = std::accumulate(shape.begin(), shape.end(), 1,
                                 std::multiplies<index_t>());
  res.resize(size);
  std::generate(res.begin(), res.end(), [&gen, &nd] { return nd(gen); });
}

template <typename T>
unique_ptr<Tensor> CreateTensor(const std::vector<index_t> &shape,
                                const std::vector<T> &data) {
  unique_ptr<Tensor> res(new Tensor(cpu_allocator(), DataTypeToEnum<T>::v()));
...
@@ -237,23 +236,23 @@ inline std::string ShapeToString(const Tensor &x) {
  return std::string(stream.str());
}

template <typename T>
struct is_floating_point_type {
  static const bool value =
      std::is_same<T, float>::value || std::is_same<T, double>::value;
};

template <typename T>
inline void ExpectEqual(const T &a, const T &b) {
  EXPECT_EQ(a, b);
}

template <>
inline void ExpectEqual<float>(const float &a, const float &b) {
  EXPECT_FLOAT_EQ(a, b);
}

template <>
inline void ExpectEqual<double>(const double &a, const double &b) {
  EXPECT_DOUBLE_EQ(a, b);
}
...
@@ -264,11 +263,11 @@ inline void AssertSameTypeDims(const Tensor &x, const Tensor &y) {
      << "y.shape [ " << ShapeToString(y) << "]";
}

template <typename T, bool is_fp = is_floating_point_type<T>::value>
struct Expector;

// Partial specialization for float and double.
template <typename T>
struct Expector<T, true> {
  static void Equal(const T &a, const T &b) { ExpectEqual(a, b); }
...
@@ -294,17 +293,17 @@ struct Expector<T, true> {
  }
};

template <typename T>
void ExpectTensorNear(const Tensor &x, const Tensor &y, const double abs_err) {
  static_assert(is_floating_point_type<T>::value,
                "T is not a floating point type");
  Expector<T>::Near(x, y, abs_err);
}

template <typename T>
std::string ToString(const T &input) {
  std::stringstream ss;
  ss << input;
  return ss.str();
}
...
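Taken together, these helpers define the pattern every op test in this commit follows; a minimal sketch of that flow, using only calls that appear verbatim in the tests above (the shapes and values are illustrative):

// Stage a named input tensor, run the op on a device, then compare the
// output against a reference tensor built with CreateTensor.
net.AddInputFromArray<float>("Input", {1, 3, 7, 7}, input);
net.RunOp(DeviceType::NEON);
auto expected = CreateTensor<float>({1, 3, 1, 1}, {1, 2, 3});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);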
mace/ops/pooling.h
@@ -14,7 +14,7 @@ namespace mace {

template <DeviceType D, class T>
class PoolingOp : public ConvPool2dOpBase<D, T> {
 public:
  PoolingOp(const OperatorDef &op_def, Workspace *ws)
      : ConvPool2dOpBase<D, T>(op_def, ws),
        kernels_(OperatorBase::GetRepeatedArgument<int>("kernels")),
        pooling_type_(
...
@@ -22,8 +22,8 @@ class PoolingOp : public ConvPool2dOpBase<D, T> {
            "pooling_type", static_cast<int>(AVG)))){};

  bool Run() override {
    const Tensor *input = this->Input(INPUT);
    Tensor *output = this->Output(OUTPUT);
    std::vector<index_t> output_shape(4);
    std::vector<int> paddings(2);
...
@@ -34,11 +34,10 @@ class PoolingOp : public ConvPool2dOpBase<D, T> {
    filter_shape[2] = kernels_[0];
    filter_shape[3] = kernels_[1];
    kernels::CalcPaddingAndOutputSize(
        input->shape().data(), filter_shape.data(), this->dilations_.data(),
        this->strides_.data(), this->padding_, output_shape.data(),
        paddings.data());
    output->Resize(output_shape);

    auto pooling_func = kernels::PoolingFunctor<D, T>(
...
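The shape filled in by kernels::CalcPaddingAndOutputSize follows standard convolution arithmetic; as a hedged reference (the exact rounding is defined in the kernel sources, not in this header), VALID padding gives

  output_size = (input_size - (kernel_size - 1) * dilation - 1) / stride + 1

while SAME padding rounds input_size / stride up and reports the leftover rows and columns through the paddings output.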
mace/ops/pooling_benchmark.cc
@@ -56,7 +56,7 @@ static void Pooling(int iters,
      int iters) {                                                    \
    const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W;  \
    mace::testing::ItemsProcessed(tot);                               \
    mace::testing::BytesProcessed(tot*(sizeof(float)));               \
    Pooling<DEVICE>(iters, N, C, H, W, KE, STRIDE, Padding::PA,       \
                    PoolingType::PO);                                 \
  }                                                                   \
...
mace/ops/pooling_test.cc
@@ -15,7 +15,7 @@ class PoolingOpTest : public OpsTestBase {};

TEST_F(PoolingOpTest, MAX_VALID) {
  // Construct graph
  auto &net = test_net();
  OpDefBuilder("Pooling", "PoolingTest")
      .Input("Input")
      .Output("Output")
...
@@ -46,7 +46,7 @@ TEST_F(PoolingOpTest, MAX_VALID) {

TEST_F(PoolingOpTest, AVG_VALID) {
  // Construct graph
  auto &net = test_net();
  OpDefBuilder("Pooling", "PoolingTest")
      .Input("Input")
      .Output("Output")
...
@@ -77,7 +77,7 @@ TEST_F(PoolingOpTest, AVG_VALID) {

TEST_F(PoolingOpTest, MAX_SAME) {
  // Construct graph
  auto &net = test_net();
  OpDefBuilder("Pooling", "PoolingTest")
      .Input("Input")
      .Output("Output")
...
@@ -105,7 +105,7 @@ TEST_F(PoolingOpTest, MAX_SAME) {

TEST_F(PoolingOpTest, MAX_VALID_DILATION) {
  // Construct graph
  auto &net = test_net();
  OpDefBuilder("Pooling", "PoolingTest")
      .Input("Input")
      .Output("Output")
...
@@ -134,7 +134,7 @@ TEST_F(PoolingOpTest, MAX_VALID_DILATION) {

TEST_F(PoolingOpTest, MAX_k2x2s2x2) {
  // Construct graph
  auto &net = test_net();
  OpDefBuilder("Pooling", "PoolingTest")
      .Input("Input")
      .Output("Output")
...
@@ -148,9 +148,9 @@ TEST_F(PoolingOpTest, MAX_k2x2s2x2) {
  net.AddIntsArg("dilations", {1, 1});

  // Add input data
  net.AddInputFromArray<float>(
      "Input", {1, 1, 2, 9},
      {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17});

  // Run
  net.RunOp(DeviceType::NEON);
...
@@ -162,7 +162,7 @@ TEST_F(PoolingOpTest, MAX_k2x2s2x2) {

TEST_F(PoolingOpTest, MAX_k3x3s2x2) {
  // Construct graph
  auto &net = test_net();
  OpDefBuilder("Pooling", "PoolingTest")
      .Input("Input")
      .Output("Output")
...
@@ -176,10 +176,10 @@ TEST_F(PoolingOpTest, MAX_k3x3s2x2) {
  net.AddIntsArg("dilations", {1, 1});

  // Add input data
  net.AddInputFromArray<float>(
      "Input", {1, 1, 3, 9},
      {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13,
       14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26});

  // Run
  net.RunOp(DeviceType::NEON);
...
@@ -191,7 +191,7 @@ TEST_F(PoolingOpTest, MAX_k3x3s2x2) {

TEST_F(PoolingOpTest, AVG_k2x2s2x2) {
  // Construct graph
  auto &net = test_net();
  OpDefBuilder("Pooling", "PoolingTest")
      .Input("Input")
      .Output("Output")
...
@@ -207,15 +207,12 @@ TEST_F(PoolingOpTest, AVG_k2x2s2x2) {
  // Add input data
  net.AddInputFromArray<float>(
      "Input", {1, 1, 2, 8},
      {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});

  // Run
  net.RunOp(DeviceType::NEON);

  // Check
  auto expected = CreateTensor<float>({1, 1, 1, 4}, {4.5, 6.5, 8.5, 10.5});

  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}
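The AVG_k2x2s2x2 expectation checks out by hand: the {1, 1, 2, 8} input is two rows, {0..7} over {8..15}, so each stride-2 2x2 window averages {0, 1, 8, 9} = 4.5, {2, 3, 10, 11} = 6.5, {4, 5, 12, 13} = 8.5, and {6, 7, 14, 15} = 10.5, matching the expected {1, 1, 1, 4} tensor.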
mace/ops/relu.h
@@ -13,17 +13,17 @@ namespace mace {

template <DeviceType D, class T>
class ReluOp : public Operator<D, T> {
 public:
  ReluOp(const OperatorDef &operator_def, Workspace *ws)
      : Operator<D, T>(operator_def, ws) {
    functor_.max_limit_ =
        OperatorBase::GetSingleArgument<T>("max_limit", static_cast<T>(-1));
  }

  bool Run() override {
    const Tensor *input_tensor = this->inputs_[0];
    Tensor *output_tensor = this->outputs_[0];
    output_tensor->ResizeLike(input_tensor);
    const T *input = input_tensor->data<T>();
    T *output = output_tensor->mutable_data<T>();
    index_t size = input_tensor->size();

    functor_(input, output, size);
...
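The operator only wires tensors to functor_; a minimal sketch of the conventional clamped-ReLU semantics this suggests, assuming max_limit_ <= 0 (the default of -1) means "no upper bound" (the real kernels::ReluFunctor may differ):

// Sketch only: assumed semantics, not code from this commit.
for (index_t i = 0; i < size; ++i) {
  T v = std::max(input[i], static_cast<T>(0));                 // standard ReLU
  output[i] = (max_limit_ > 0) ? std::min(v, max_limit_) : v;  // optional cap
}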
mace/ops/relu_benchmark.cc
@@ -36,7 +36,7 @@ static void ReluBenchmark(int iters, int size) {
  static void BM_RELU_##SIZE##_##TYPE##_##DEVICE(int iters) {  \
    const int64_t tot = static_cast<int64_t>(iters) * SIZE;    \
    mace::testing::ItemsProcessed(tot);                        \
    mace::testing::BytesProcessed(tot*(sizeof(TYPE)));         \
    ReluBenchmark<DEVICE, TYPE>(iters, SIZE);                  \
  }                                                            \
  BENCHMARK(BM_RELU_##SIZE##_##TYPE##_##DEVICE)
...
mace/ops/relu_test.cc
@@ -11,7 +11,7 @@ class ReluOpTest : public OpsTestBase {};

TEST_F(ReluOpTest, ReluOp) {
  // Construct graph
  auto &net = test_net();
  OpDefBuilder("Relu", "ReluTest")
      .Input("Input")
      .Output("Output")
...
@@ -34,7 +34,7 @@ TEST_F(ReluOpTest, ReluOp) {

TEST_F(ReluOpTest, ReluOpWithMax) {
  // Construct graph
  auto &net = test_net();
  OpDefBuilder("Relu", "ReluTestWithMax")
      .Input("Input")
      .Output("Output")
...
@@ -56,5 +56,4 @@ TEST_F(ReluOpTest, ReluOpWithMax) {
  ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 0.01);
}

}  // namespace mace
mace/ops/resize_bilinear.h
@@ -13,21 +13,21 @@ namespace mace {

template <DeviceType D, class T>
class ResizeBilinearOp : public Operator<D, T> {
 public:
  ResizeBilinearOp(const OperatorDef &operator_def, Workspace *ws)
      : Operator<D, T>(operator_def, ws),
        functor_(
            OperatorBase::GetSingleArgument<bool>("align_corners", false)) {}

  bool Run() override {
    const Tensor *input = this->Input(0);
    const Tensor *resize_dims = this->Input(1);

    MACE_CHECK(input->dim_size() == 4, "input must be 4-dimensional.",
               input->dim_size());
    MACE_CHECK(resize_dims->dim_size() == 1,
               "resize dim must be 2-dimensional.", resize_dims->dim_size());

    Tensor *output = this->Output(0);

    index_t n = input->dim(0);
    index_t channels = input->dim(1);
...
@@ -38,8 +38,8 @@ class ResizeBilinearOp : public Operator<D, T> {
    vector<index_t> out_shape{n, channels, out_height, out_width};
    output->Resize(out_shape);

    const T *input_ptr = input->data<T>();
    T *output_ptr = output->mutable_data<T>();

    functor_(input_ptr, output_ptr, n, channels, in_height, in_width,
             out_height, out_width);
...
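The align_corners flag forwarded to the functor conventionally changes how output coordinates are scaled back into the input; a hedged sketch of the usual convention (not necessarily Mace's exact kernel):

// Assumed convention: align_corners maps corner pixels of input and output
// exactly onto each other; otherwise the scale is a plain size ratio.
float height_scale =
    align_corners ? (in_height - 1) / static_cast<float>(out_height - 1)
                  : in_height / static_cast<float>(out_height);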
mace/ops/resize_bilinear_test.cc
@@ -13,7 +13,7 @@ class ResizeBilinearTest : public OpsTestBase {};

TEST_F(ResizeBilinearTest, ResizeBilinearWOAlignCorners) {
  testing::internal::LogToStderr();
  // Construct graph
  auto &net = test_net();
  OpDefBuilder("ResizeBilinear", "ResizeBilinearTest")
      .Input("Input")
      .Input("OutSize")
...
@@ -38,7 +38,7 @@ TEST_F(ResizeBilinearTest, ResizeBilinearWOAlignCorners) {

TEST_F(ResizeBilinearTest, ResizeBilinearWAlignCorners) {
  testing::internal::LogToStderr();
  // Construct graph
  auto &net = test_net();
  OpDefBuilder("ResizeBilinear", "ResizeBilinearTest")
      .Input("Input")
      .Input("OutSize")
...
mace/proto/BUILD
@@ -33,8 +33,8 @@ cc_proto_library(
py_proto_library(
    name = "mace_py",
    srcs = ["mace.proto"],
    default_runtime = "@com_google_protobuf//:protobuf_python",
    protoc = "@com_google_protobuf//:protoc",
    srcs_version = "PY2AND3",
    deps = ["@com_google_protobuf//:protobuf_python"],
)
\ No newline at end of file
mace/python/tools/BUILD
@@ -16,4 +16,3 @@ py_binary(
        "@six_archive//:six",
    ],
)
mace/tools/benchmark/benchmark_model.cc
@@ -3,14 +3,13 @@
//

#include "mace/core/net.h"
#include "mace/tools/benchmark/stat_summarizer.h"
#include "mace/utils/command_line_flags.h"
#include "mace/utils/utils.h"

#include <fstream>
#include <thread>

namespace mace {
namespace str_util {
...
@@ -29,8 +28,9 @@ std::vector<std::string> Split(const string &str, char delims) {
  return result;
}

bool SplitAndParseToInts(const string &str, char delims,
                         std::vector<index_t> *result) {
  string tmp = str;
  while (!tmp.empty()) {
    index_t dim = atoi(tmp.data());
...
@@ -44,13 +44,15 @@ bool SplitAndParseToInts(const string &str, char delims, std::vector<index_t>* r
  }
}

}  // namespace str_util

namespace benchmark {

bool RunInference(NetBase *net, StatSummarizer *summarizer,
                  int64_t *inference_time_us) {
  RunMetadata run_metadata;
  RunMetadata *run_metadata_ptr = nullptr;
  if (summarizer) {
    run_metadata_ptr = &run_metadata;
  }
...
@@ -71,9 +73,13 @@ bool RunInference(NetBase* net, StatSummarizer* summarizer, int64_t* inference_t
  return true;
}

bool Run(NetBase *net, StatSummarizer *summarizer, int num_runs,
         double max_time_sec, int64_t sleep_sec, int64_t *total_time_us,
         int64_t *actual_num_runs) {
  *total_time_us = 0;

  LOG(INFO) << "Running benchmark for max " << num_runs << " iterators, max "
...
@@ -85,7 +91,7 @@ bool Run(NetBase* net, StatSummarizer* summarizer,
  Stat<int64_t> stat;

  bool util_max_time = (num_runs <= 0);
  for (int i = 0; util_max_time || i < num_runs; ++i) {
    int64_t inference_time_us = 0;
    bool s = RunInference(net, summarizer, &inference_time_us);
    stat.UpdateStat(inference_time_us);
...
@@ -113,7 +119,7 @@ bool Run(NetBase* net, StatSummarizer* summarizer,
  return true;
}

int Main(int argc, char **argv) {
  std::string model_file = "/data/local/tmp/mobi_mace.pb";
  std::string device = "CPU";
  std::string input_layer_string = "input:0";
...
@@ -182,8 +188,10 @@ int Main(int argc, char** argv) {
    return -1;
  }

  std::vector<std::string> input_layers =
      str_util::Split(input_layer_string, ',');
  std::vector<std::string> input_layer_shapes =
      str_util::Split(input_layer_shape_string, ':');
  std::vector<string> input_layer_types =
      str_util::Split(input_layer_type_string, ',');
  std::vector<string> input_layer_files =
...
@@ -260,17 +268,17 @@ int Main(int argc, char** argv) {
  ws.LoadModelTensor(net_def, DeviceType::CPU);

  // Load inputs
  for (size_t i = 0; i < inputs_count; ++i) {
    Tensor *input_tensor =
        ws.CreateTensor(input_layers[i], cpu_allocator(), DT_FLOAT);
    vector<index_t> shapes;
    str_util::SplitAndParseToInts(input_layer_shapes[i], ',', &shapes);
    input_tensor->Resize(shapes);
    float *input_data = input_tensor->mutable_data<float>();

    // load input
    if (i < input_layer_files.size()) {
      std::ifstream in_file(input_layer_files[i],
                            std::ios::in | std::ios::binary);
      in_file.read(reinterpret_cast<char *>(input_data),
                   input_tensor->size() * sizeof(float));
      in_file.close();
...
@@ -285,31 +293,31 @@ int Main(int argc, char** argv) {
  int64_t warmup_time_us = 0;
  int64_t num_warmup_runs = 0;
  if (warmup_runs > 0) {
    bool status =
        Run(net.get(), nullptr, warmup_runs, -1.0,
            inter_inference_sleep_seconds, &warmup_time_us, &num_warmup_runs);
    if (!status) {
      LOG(ERROR) << "Failed at warm up run";
    }
  }

  if (inter_benchmark_sleep_seconds > 0) {
    std::this_thread::sleep_for(
        std::chrono::seconds(inter_benchmark_sleep_seconds));
  }
  int64_t no_stat_time_us = 0;
  int64_t no_stat_runs = 0;
  bool status =
      Run(net.get(), nullptr, max_num_runs, max_benchmark_time_seconds,
          inter_inference_sleep_seconds, &no_stat_time_us, &no_stat_runs);
  if (!status) {
    LOG(ERROR) << "Failed at normal no-stat run";
  }

  int64_t stat_time_us = 0;
  int64_t stat_runs = 0;
  status = Run(net.get(), stats.get(), max_num_runs,
               max_benchmark_time_seconds, inter_inference_sleep_seconds,
               &stat_time_us, &stat_runs);
  if (!status) {
    LOG(ERROR) << "Failed at normal stat run";
  }
...
@@ -325,9 +333,7 @@ int Main(int argc, char** argv) {
  return 0;
}

}  // namespace benchmark
}  // namespace mace

int main(int argc, char **argv) { mace::benchmark::Main(argc, argv); }
mace/tools/benchmark/stat_summarizer.cc
@@ -2,17 +2,16 @@
// Copyright (c) 2017 XiaoMi All rights reserved.
//

#include "mace/tools/benchmark/stat_summarizer.h"
#include "mace/core/common.h"
#include "mace/proto/stats.pb.h"

#include <iomanip>
#include <queue>

namespace mace {

StatSummarizer::StatSummarizer(const StatSummarizerOptions &options)
    : options_(options) {}

StatSummarizer::~StatSummarizer() {}
...
@@ -23,17 +22,14 @@ void StatSummarizer::Reset() {
  details_.clear();
}

void StatSummarizer::ProcessMetadata(const RunMetadata &run_metadata) {
  int64_t curr_total_us = 0;
  int64_t mem_total = 0;

  int64_t first_node_start_us =
      run_metadata.op_stats(0).all_start_micros();

  int node_num = 0;
  for (const auto &ops : run_metadata.op_stats()) {
    std::string name = ops.operator_name();
    std::string op_type = ops.type();
...
@@ -41,7 +37,7 @@ void StatSummarizer::ProcessMetadata(const RunMetadata &run_metadata) {
    const int64_t curr_time = ops.all_end_rel_micros();
    curr_total_us += curr_time;
    auto result = details_.emplace(name, Detail());
    Detail *detail = &(result.first->second);

    detail->start_us.UpdateStat(ops.all_start_micros() - first_node_start_us);
    detail->rel_end_us.UpdateStat(curr_time);
...
@@ -77,13 +73,13 @@ std::string StatSummarizer::ShortSummary() const {
  return stream.str();
}

std::ostream &InitField(std::ostream &stream, int width) {
  stream << "\t" << std::right << std::setw(width) << std::fixed
         << std::setprecision(3);
  return stream;
}

std::string StatSummarizer::HeaderString(const std::string &title) const {
  std::stringstream stream;

  stream << "============================== " << title
...
@@ -102,9 +98,9 @@ std::string StatSummarizer::HeaderString(const std::string& title) const {
  return stream.str();
}

std::string StatSummarizer::ColumnString(const StatSummarizer::Detail &detail,
                                         const int64_t cumulative_stat_on_node,
                                         const Stat<int64_t> &stat) const {
  const double start_ms = detail.start_us.avg() / 1000.0;
  const double first_time_ms = detail.rel_end_us.first() / 1000.0;
  const double avg_time_ms = detail.rel_end_us.avg() / 1000.0;
...
@@ -127,12 +123,12 @@ std::string StatSummarizer::ColumnString(const StatSummarizer::Detail& detail,
}

void StatSummarizer::OrderNodesByMetric(
    SortingMetric metric, std::vector<const Detail *> *details) const {
  std::priority_queue<std::pair<std::string, const Detail *>> sorted_list;
  const int num_nodes = details_.size();

  for (const auto &det : details_) {
    const Detail *detail = &(det.second);
    std::stringstream stream;
    stream << std::setw(20) << std::right << std::setprecision(10)
           << std::fixed;
...
@@ -169,16 +165,16 @@ void StatSummarizer::OrderNodesByMetric(
}

void StatSummarizer::ComputeStatsByType(
    std::map<std::string, int64_t> *node_type_map_count,
    std::map<std::string, int64_t> *node_type_map_time,
    std::map<std::string, int64_t> *node_type_map_memory,
    std::map<std::string, int64_t> *node_type_map_times_called,
    int64_t *accumulated_us) const {
  int64_t run_count = run_total_us_.count();

  for (const auto &det : details_) {
    const std::string node_name = det.first;
    const Detail &detail = det.second;

    int64_t curr_time_val =
        static_cast<int64_t>(detail.rel_end_us.sum() / run_count);
...
@@ -186,7 +182,7 @@ void StatSummarizer::ComputeStatsByType(
    int64_t curr_memory_val = detail.mem_used.newest();

    const std::string &node_type = detail.type;

    (*node_type_map_count)[node_type] += 1;
    (*node_type_map_time)[node_type] += curr_time_val;
...
@@ -215,8 +211,9 @@ std::string StatSummarizer::GetStatsByNodeType() const {
      &accumulated_us);

  // Sort them.
  std::priority_queue<std::pair<int64_t, std::pair<std::string, int64_t>>>
      timings;
  for (const auto &node_type : node_type_map_time) {
    const int64_t mem_used = node_type_map_memory[node_type.first];
    timings.emplace(node_type.second,
                    std::pair<std::string, int64_t>(node_type.first, mem_used));
...
@@ -259,10 +256,10 @@ std::string StatSummarizer::GetStatsByNodeType() const {
  return stream.str();
}

std::string StatSummarizer::GetStatsByMetric(const std::string &title,
                                             SortingMetric sorting_metric,
                                             int num_stats) const {
  std::vector<const Detail *> details;
  OrderNodesByMetric(sorting_metric, &details);

  double cumulative_stat_on_node = 0;
...
mace/tools/benchmark/stat_summarizer.h
@@ -12,7 +12,6 @@
#include <sstream>
#include <string>

namespace mace {

class RunMetadata;
...
@@ -62,7 +61,7 @@ class Stat {
    return all_same() ? 0 : std::sqrt(squared_sum_ / count_ - avg() * avg());
  }

  void OutputToStream(std::ostream *stream) const {
    if (empty()) {
      *stream << "count=0";
    } else if (all_same()) {
...
@@ -75,8 +74,8 @@ class Stat {
    }
  }

  friend std::ostream &operator<<(std::ostream &stream,
                                  const Stat<ValueType> &stat) {
    stat.OutputToStream(&stream);
    return stream;
  }
...
@@ -131,12 +130,12 @@ class StatSummarizer {
    BY_TYPE,
  };

  explicit StatSummarizer(const StatSummarizerOptions &options);
  ~StatSummarizer();

  // Adds another run's StepStats output to the aggregate counts.
  void ProcessMetadata(const RunMetadata &run_metadata);

  // Returns a string detailing the accumulated runtime stats in a tab-separated
  // format which can be pasted into a spreadsheet for further analysis.
...
@@ -147,15 +146,16 @@ class StatSummarizer {
  // Prints the string returned by GetOutputString().
  void PrintOperatorStats() const;

  void ComputeStatsByType(
      std::map<std::string, int64_t> *node_type_map_count,
      std::map<std::string, int64_t> *node_type_map_time,
      std::map<std::string, int64_t> *node_type_map_memory,
      std::map<std::string, int64_t> *node_type_map_times_called,
      int64_t *accumulated_us) const;

  std::string GetStatsByNodeType() const;

  std::string GetStatsByMetric(const std::string &title,
                               SortingMetric sorting_metric,
                               int num_stats) const;
...
@@ -165,7 +165,7 @@ class StatSummarizer {
  int num_runs() const { return run_total_us_.count(); }

  // Returns stats of total microseconds spent by all nodes in each run.
  const Stat<int64_t> &run_total_us() const { return run_total_us_; }

 private:
  struct Detail {
...
@@ -179,12 +179,12 @@ class StatSummarizer {
  };

  void OrderNodesByMetric(SortingMetric sorting_metric,
                          std::vector<const Detail *> *details) const;

  std::string HeaderString(const std::string &title) const;
  std::string ColumnString(const Detail &detail,
                           const int64_t cumulative_stat_on_node,
                           const Stat<int64_t> &stat) const;

  Stat<int64_t> run_total_us_;
  Stat<int64_t> memory_;
...
@@ -193,6 +193,6 @@ class StatSummarizer {
  StatSummarizerOptions options_;
};

}  // namespace mace

#endif  // MACE_TOOLS_BENCHMARK_STAT_SUMMARIZER_H_
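A minimal sketch of how Stat<int64_t> accumulates per-run timings, consistent with the members shown above (UpdateStat, avg, and std_deviation computed as sqrt(E[x^2] - E[x]^2)):

Stat<int64_t> run_us;
run_us.UpdateStat(120);  // first run took 120 us
run_us.UpdateStat(140);  // second run took 140 us
// avg() == 130; std_deviation() == sqrt((120*120 + 140*140) / 2 - 130*130) == 10
LOG(INFO) << run_us;     // operator<< streams via OutputToStream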
mace/utils/command_line_flags.cc
@@ -10,19 +10,21 @@ namespace mace {
namespace {

bool StringConsume(string &arg, const string &x) {
  if ((arg.size() >= x.size()) &&
      (memcmp(arg.data(), x.data(), x.size()) == 0)) {
    arg = arg.substr(x.size());
    return true;
  }

  return false;
}

bool ParseStringFlag(string arg, string flag, string *dst,
                     bool *value_parsing_ok) {
  *value_parsing_ok = true;
  if (StringConsume(arg, "--") && StringConsume(arg, flag) &&
      StringConsume(arg, "=")) {
    *dst = arg;
    return true;
  }
...
@@ -30,11 +32,13 @@ bool ParseStringFlag(string arg, string flag,
  return false;
}

bool ParseInt32Flag(string arg, string flag, int32_t *dst,
                    bool *value_parsing_ok) {
  *value_parsing_ok = true;
  if (StringConsume(arg, "--") && StringConsume(arg, flag) &&
      StringConsume(arg, "=")) {
    char extra;
    if (sscanf(arg.data(), "%d%c", dst, &extra) != 1) {
      LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
...
@@ -47,11 +51,13 @@ bool ParseInt32Flag(string arg, string flag,
  return false;
}

bool ParseInt64Flag(string arg, string flag, long long *dst,
                    bool *value_parsing_ok) {
  *value_parsing_ok = true;
  if (StringConsume(arg, "--") && StringConsume(arg, flag) &&
      StringConsume(arg, "=")) {
    char extra;
    if (sscanf(arg.data(), "%lld%c", dst, &extra) != 1) {
      LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
...
@@ -64,8 +70,7 @@ bool ParseInt64Flag(string arg, string flag,
  return false;
}

bool ParseBoolFlag(string arg, string flag, bool *dst,
                   bool *value_parsing_ok) {
  *value_parsing_ok = true;
  if (StringConsume(arg, "--") && StringConsume(arg, flag)) {
    if (arg.empty()) {
...
@@ -90,11 +95,13 @@ bool ParseBoolFlag(string arg, string flag,
  return false;
}

bool ParseFloatFlag(string arg, string flag, float *dst,
                    bool *value_parsing_ok) {
  *value_parsing_ok = true;
  if (StringConsume(arg, "--") && StringConsume(arg, flag) &&
      StringConsume(arg, "=")) {
    char extra;
    if (sscanf(arg.data(), "%f%c", dst, &extra) != 1) {
      LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
...
@@ -152,7 +159,8 @@ bool Flag::Parse(string arg, bool *value_parsing_ok) const {
  return result;
}

/*static*/ bool Flags::Parse(int *argc, char **argv,
                             const std::vector<Flag> &flag_list) {
  bool result = true;
  std::vector<char *> unknown_flags;
...
mace/utils/command_line_flags.h
@@ -39,16 +39,14 @@ class Flags {
  // with matching flags, and remove the matching arguments from (*argc, argv).
  // Return true iff all recognized flag values were parsed correctly, and the
  // first remaining argument is not "--help".
  static bool Parse(int *argc, char **argv,
                    const std::vector<Flag> &flag_list);

  // Return a usage message with command line cmdline, and the
  // usage_text strings in flag_list[].
  static string Usage(const string &cmdline,
                      const std::vector<Flag> &flag_list);
};

}  // namespace mace

#endif  // MACE_CORE_COMMAND_LINE_FLAGS_H
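A minimal sketch of driving this parser, modeled on benchmark_model.cc above; the Flag constructor arguments are an assumption, since the Flag class body is elided from this diff:

// Hypothetical flag registration (ctor signature assumed, not shown here).
std::string model_file = "/data/local/tmp/mobi_mace.pb";
std::vector<Flag> flag_list = {
    Flag("model_file", &model_file, "path to the frozen model"),
};
if (!Flags::Parse(&argc, argv, flag_list)) {
  LOG(ERROR) << Flags::Usage(argv[0], flag_list);
}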
mace/utils/utils.h
@@ -24,5 +24,5 @@ inline int64_t NowInMicroSec() {
  return static_cast<int64_t>(tv.tv_sec * 1000000 + tv.tv_usec);
}

}  // namespace mace

#endif  // MACE_UTILS_UTILS_H_