Commit 578b382a
Authored Sep 15, 2017 by 吴承辉

Merge branch 'style' into 'master'

Fix Google Style

See merge request !43

Parents: 9c9af68e, 8ae8f575

Showing 66 changed files with 1096 additions and 1361 deletions (+1096 −1361)
mace/core/allocator.cc                        +3    −7
mace/core/allocator.h                         +4    −6
mace/core/common.h                            +7    −7
mace/core/logging.cc                          +2    −3
mace/core/logging.h                           +18   −21
mace/core/macros.h                            +1    −2
mace/core/net.cc                              +10   −17
mace/core/net.h                               +11   −18
mace/core/operator.cc                         +9    −20
mace/core/operator.h                          +33   −50
mace/core/proto_utils.cc                      +69   −92
mace/core/proto_utils.h                       +35   −58
mace/core/registry.h                          +21   −22
mace/core/serializer.cc                       +13   −16
mace/core/serializer.h                        +4    −4
mace/core/tensor.h                            +25   −30
mace/core/testing/test_benchmark.cc           +3    −6
mace/core/testing/test_benchmark.h            +3    −3
mace/core/testing/test_benchmark_main.cc      +0    −1
mace/core/types.h                             +15   −16
mace/core/workspace.cc                        +9    −8
mace/core/workspace.h                         +3    −5
mace/examples/benchmark_example.cc            +2    −3
mace/kernels/addn.h                           +7    −9
mace/kernels/batch_norm.h                     +14   −25
mace/kernels/conv_2d.h                        +84   −95
mace/kernels/conv_pool_2d_util.cc             +17   −21
mace/kernels/conv_pool_2d_util.h              +11   −15
mace/kernels/neon/addn_neon.cc                +6    −7
mace/kernels/neon/batch_norm_neon.cc          +17   −21
mace/kernels/neon/conv_2d_neon.cc             +27   −55
mace/kernels/neon/conv_2d_neon_1x1.cc         +30   −31
mace/kernels/neon/conv_2d_neon_3x3.cc         +73   −66
mace/kernels/neon/conv_2d_neon_5x5.cc         +16   −16
mace/kernels/neon/max_pooling_neon_2x2.cc     +7    −12
mace/kernels/neon/max_pooling_neon_3x3.cc     +9    −14
mace/kernels/neon/pooling_neon.cc             +20   −34
mace/kernels/neon/relu_neon.cc                +6    −7
mace/kernels/pooling.h                        +28   −36
mace/kernels/relu.h                           +4    −4
mace/kernels/resize_bilinear.h                +28   −31
mace/ops/addn.cc                              +2    −2
mace/ops/addn.h                               +4    −4
mace/ops/addn_benchmark.cc                    +12   −15
mace/ops/addn_test.cc                         +1    −1
mace/ops/batch_norm.cc                        +2    −2
mace/ops/batch_norm.h                         +47   −42
mace/ops/batch_norm_benchmark.cc              +21   −21
mace/ops/batch_norm_test.cc                   +17   −18
mace/ops/conv_2d.cc                           +2    −2
mace/ops/conv_2d.h                            +12   −17
mace/ops/conv_2d_benchmark.cc                 +26   −22
mace/ops/conv_2d_test.cc                      +74   −102
mace/ops/conv_pool_2d_base.h                  +8    −9
mace/ops/ops_test_util.h                      +45   −40
mace/ops/pooling.cc                           +2    −3
mace/ops/pooling.h                            +19   −26
mace/ops/pooling_benchmark.cc                 +19   −17
mace/ops/pooling_test.cc                      +46   −71
mace/ops/relu.cc                              +2    −2
mace/ops/relu.h                               +4    −4
mace/ops/relu_benchmark.cc                    +11   −13
mace/ops/relu_test.cc                         +1    −1
mace/ops/resize_bilinear.cc                   +4    −3
mace/ops/resize_bilinear.h                    +10   −9
mace/ops/resize_bilinear_test.cc              +1    −1
mace/core/allocator.cc

@@ -7,13 +7,9 @@
namespace mace {

static std::unique_ptr<CPUAllocator> g_cpu_allocator(new CPUAllocator());
CPUAllocator* cpu_allocator() { return g_cpu_allocator.get(); }

void SetCPUAllocator(CPUAllocator* alloc) { g_cpu_allocator.reset(alloc); }

Allocator* GetDeviceAllocator(DeviceType type) {
  switch (type) {

@@ -26,4 +22,4 @@ Allocator* GetDeviceAllocator(DeviceType type) {
  return nullptr;
}

}  // namespace mace
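The hooks above let the process-wide CPU allocator be swapped out. A minimal usage sketch, assuming the headers from this commit; TrackingAllocator is an illustrative name, not part of the MACE sources:

```cpp
#include "mace/core/allocator.h"

// Hypothetical allocator that counts bytes handed out; illustrative only.
class TrackingAllocator : public mace::CPUAllocator {
 public:
  void* New(size_t nbytes) override {
    allocated_bytes_ += nbytes;
    return mace::CPUAllocator::New(nbytes);
  }

 private:
  size_t allocated_bytes_ = 0;
};

void InstallTrackingAllocator() {
  // SetCPUAllocator resets the global unique_ptr, which takes ownership.
  mace::SetCPUAllocator(new TrackingAllocator());
}
```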
mace/core/allocator.h

@@ -39,7 +39,7 @@ class Allocator {
  }
};

class CPUAllocator : public Allocator {
 public:
  ~CPUAllocator() override {}
  void* New(size_t nbytes) override {

@@ -55,9 +55,7 @@ class CPUAllocator: public Allocator {
    return data;
  }
  void Delete(void* data) override { free(data); }
  void CopyBytes(void* dst, const void* src, size_t size) override {
    memcpy(dst, src, size);

@@ -85,6 +83,6 @@ struct DeviceContext<DeviceType::NEON> {
Allocator* GetDeviceAllocator(DeviceType type);

}  // namespace mace

#endif  // MACE_CORE_ALLOCATOR_H_
mace/core/common.h

@@ -5,12 +5,12 @@
#ifndef MACE_CORE_COMMON_H_
#define MACE_CORE_COMMON_H_

#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "mace/core/logging.h"

@@ -24,9 +24,9 @@ typedef int64_t index_t;
// Disable the copy and assignment operator for a class.
#ifndef DISABLE_COPY_AND_ASSIGN
#define DISABLE_COPY_AND_ASSIGN(classname) \
 private:                                  \
  classname(const classname&) = delete;    \
  classname& operator=(const classname&) = delete
#endif

@@ -35,4 +35,4 @@
// TODO: need to fine tune this
#define kCostPerGroup 1024000000

#endif  // MACE_CORE_COMMON_H_
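A small sketch of DISABLE_COPY_AND_ASSIGN in use; NonCopyable is an illustrative class, not part of this commit:

```cpp
#include "mace/core/common.h"

class NonCopyable {
 public:
  NonCopyable() = default;

  // The macro expands to a private: section with the copy constructor and
  // copy assignment operator deleted, so instances cannot be copied.
  DISABLE_COPY_AND_ASSIGN(NonCopyable);
};
```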
mace/core/logging.cc

@@ -2,7 +2,6 @@
// Copyright (c) 2017 XiaoMi All rights reserved.
//

#include "mace/core/logging.h"

#include <stdlib.h>

@@ -62,11 +61,11 @@ void LogMessage::GenerateLogMessage() {
#else
void LogMessage::GenerateLogMessage() {
  fprintf(stderr, "%c %s:%d] %s\n", "IWEF"[severity_], fname_, line_,
          str().c_str());
}
#endif

namespace {

// Parse log level (int64_t) from environment variable (char*)
mace/core/logging.h

@@ -5,8 +5,8 @@
#ifndef MACE_CORE_LOGGING_H_
#define MACE_CORE_LOGGING_H_

#include <limits>
#include <sstream>
#include <string>

#undef ERROR

@@ -30,8 +30,8 @@ inline void MakeStringInternal(std::stringstream& ss, const T& t) {
}

template <typename T, typename... Args>
inline void MakeStringInternal(std::stringstream& ss, const T& t,
                               const Args&... args) {
  MakeStringInternal(ss, t);
  MakeStringInternal(ss, args...);
}

@@ -48,9 +48,7 @@
inline string MakeString(const string& str) {
  return str;
}
inline string MakeString(const char* c_str) { return string(c_str); }

class LogMessage : public std::basic_ostringstream<char> {
 public:

@@ -85,8 +83,7 @@ class LogMessageFatal : public LogMessage {
  ::mace::internal::LogMessage(__FILE__, __LINE__, mace::WARNING)
#define _MACE_LOG_ERROR \
  ::mace::internal::LogMessage(__FILE__, __LINE__, mace::ERROR)
#define _MACE_LOG_FATAL ::mace::internal::LogMessageFatal(__FILE__, __LINE__)

#define _MACE_LOG_QFATAL _MACE_LOG_FATAL

@@ -96,10 +93,10 @@ class LogMessageFatal : public LogMessage {
// Turn VLOG off when under mobile devices for considerations of binary size.
#define VLOG_IS_ON(lvl) ((lvl) <= 0)
#else
// Otherwise, Set MACE_CPP_MIN_VLOG_LEVEL environment to update minimum log
// level of VLOG
#define VLOG_IS_ON(lvl) ((lvl) <= ::mace::internal::LogMessage::MinVLogLevel())
#endif

#define VLOG(lvl) \

@@ -113,16 +110,16 @@ class LogMessageFatal : public LogMessage {
// MACE_CHECK(fp->Write(x) == 4)
// MACE_CHECK(fp->Write(x) == 4, "Write failed")
// which are not correct for MACE_ASSERT.
#define MACE_CHECK(condition, ...)                        \
  if (!(condition))                                       \
  LOG(FATAL) << "Check failed: " #condition " "           \
             << ::mace::internal::MakeString(__VA_ARGS__)

#ifndef NDEBUG
#define MACE_ASSERT(condition, ...)                       \
  if (!(condition))                                       \
  LOG(FATAL) << "Assert failed: " #condition " "          \
             << ::mace::internal::MakeString(__VA_ARGS__)
#else
#define MACE_ASSERT(condition, ...) ((void)0)
#endif

@@ -135,9 +132,9 @@ T&& CheckNotNull(const char* file, int line, const char* exprtext, T&& t) {
  return std::forward<T>(t);
}

#define MACE_CHECK_NOTNULL(val)                      \
  ::mace::internal::CheckNotNull(__FILE__, __LINE__, \
                                 "'" #val "' Must be non NULL", (val))

}  // namespace internal
}  // namespace mace
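A usage sketch for the check and logging macros above; Divide is an illustrative function, not part of this commit:

```cpp
#include "mace/core/logging.h"

int Divide(int a, int b) {
  // On failure MACE_CHECK logs FATAL (which aborts); the extra arguments
  // are concatenated through ::mace::internal::MakeString.
  MACE_CHECK(b != 0, "divisor must be non-zero, got ", b);
  // VLOG(1) only fires when VLOG_IS_ON(1); on mobile builds VLOG_IS_ON(lvl)
  // is (lvl) <= 0, so this line is disabled there.
  VLOG(1) << "dividing " << a << " by " << b;
  return a / b;
}
```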
mace/core/macros.h

@@ -17,5 +17,4 @@
#define MACE_PREDICT_TRUE(x) (x)
#endif

#endif  // MACE_CORE_MACROS_H_
mace/core/net.cc

@@ -6,22 +6,19 @@
namespace mace {

NetBase::NetBase(const std::shared_ptr<const NetDef>& net_def,
                 Workspace* ws,
                 DeviceType type)
    : name_(net_def->name()) {}

SimpleNet::SimpleNet(const std::shared_ptr<const NetDef>& net_def,
                     Workspace* ws,
                     DeviceType type)
    : NetBase(net_def, ws, type) {
  VLOG(1) << "Constructing SimpleNet " << net_def->name();
  for (int idx = 0; idx < net_def->op_size(); ++idx) {
    const auto& operator_def = net_def->op(idx);
    VLOG(1) << "Creating operator " << operator_def.name() << ":"
            << operator_def.type();
    std::unique_ptr<OperatorBase> op{nullptr};
    OperatorDef temp_def(operator_def);
    op = CreateOperator(temp_def, ws, type);
    operators_.emplace_back(std::move(op));

@@ -40,20 +37,16 @@ bool SimpleNet::Run() {
  return true;
}

unique_ptr<NetBase> CreateNet(const NetDef& net_def,
                              Workspace* ws,
                              DeviceType type) {
  std::shared_ptr<NetDef> tmp_net_def(new NetDef(net_def));
  return CreateNet(tmp_net_def, ws, type);
}

unique_ptr<NetBase> CreateNet(const std::shared_ptr<const NetDef>& net_def,
                              Workspace* ws,
                              DeviceType type) {
  unique_ptr<NetBase> net(new SimpleNet(net_def, ws, type));
  return net;
}

}  // namespace mace
mace/core/net.h

@@ -6,35 +6,31 @@
#define MACE_CORE_NET_H_

#include "mace/core/common.h"
#include "mace/core/operator.h"
#include "mace/core/workspace.h"
#include "mace/proto/mace.pb.h"

namespace mace {

class NetBase {
 public:
  NetBase(const std::shared_ptr<const NetDef>& net_def,
          Workspace* ws,
          DeviceType type);
  virtual ~NetBase() noexcept {}

  virtual bool Run() = 0;

  const string& Name() const { return name_; }

 protected:
  string name_;

  DISABLE_COPY_AND_ASSIGN(NetBase);
};

class SimpleNet : public NetBase {
 public:
  SimpleNet(const std::shared_ptr<const NetDef>& net_def,
            Workspace* ws,
            DeviceType type);

  bool Run() override;

@@ -42,17 +38,14 @@ class SimpleNet : public NetBase {
 protected:
  vector<unique_ptr<OperatorBase> > operators_;

  DISABLE_COPY_AND_ASSIGN(SimpleNet);
};

unique_ptr<NetBase> CreateNet(const NetDef& net_def,
                              Workspace* ws,
                              DeviceType type);
unique_ptr<NetBase> CreateNet(const std::shared_ptr<const NetDef>& net_def,
                              Workspace* ws,
                              DeviceType type);

}  // namespace mace

#endif  // MACE_CORE_NET_H_
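Putting the two CreateNet overloads to work; a sketch assuming a serialized NetDef on disk, loaded through ReadProtoFromBinaryFile from proto_utils.h:

```cpp
#include "mace/core/net.h"
#include "mace/core/proto_utils.h"

bool RunModel(const char* model_path, mace::Workspace* ws) {
  mace::NetDef net_def;
  if (!mace::ReadProtoFromBinaryFile(model_path, &net_def)) {
    return false;
  }
  // The NetDef overload copies into a shared_ptr and delegates to the
  // shared_ptr overload, which constructs a SimpleNet.
  std::unique_ptr<mace::NetBase> net =
      mace::CreateNet(net_def, ws, mace::DeviceType::CPU);
  return net->Run();
}
```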
mace/core/operator.cc

@@ -11,33 +11,22 @@ std::map<int32_t, OperatorRegistry*>* gDeviceTypeRegistry() {
  return &g_device_type_registry;
}

MACE_DEFINE_REGISTRY(CPUOperatorRegistry,
                     OperatorBase,
                     const OperatorDef&,
                     Workspace*);
MACE_REGISTER_DEVICE_TYPE(DeviceType::CPU, CPUOperatorRegistry);

MACE_DEFINE_REGISTRY(NEONOperatorRegistry,
                     OperatorBase,
                     const OperatorDef&,
                     Workspace*);
MACE_REGISTER_DEVICE_TYPE(DeviceType::NEON, NEONOperatorRegistry);

unique_ptr<OperatorBase> CreateOperator(const OperatorDef& operator_def,
                                        Workspace* ws,
                                        DeviceType type) {
  OperatorRegistry* registry = gDeviceTypeRegistry()->at(type);
  return registry->Create(operator_def.type(), operator_def, ws);
}

OperatorBase::OperatorBase(const OperatorDef& operator_def, Workspace* ws)
    : operator_ws_(ws),
      operator_def_(std::make_shared<OperatorDef>(operator_def)) {}

}  // namespace mace
mace/core/operator.h

@@ -5,12 +5,12 @@
#ifndef MACE_CORE_OPERATOR_H
#define MACE_CORE_OPERATOR_H

#include "mace/core/common.h"
#include "mace/core/proto_utils.h"
#include "mace/core/registry.h"
#include "mace/core/tensor.h"
#include "mace/core/workspace.h"
#include "mace/proto/mace.pb.h"

namespace mace {

@@ -23,22 +23,21 @@ class OperatorBase {
    MACE_CHECK(operator_def_, "operator_def was null!");
    return ArgumentHelper::HasArgument(*operator_def_, name);
  }

  template <typename T>
  inline T GetSingleArgument(const string& name,
                             const T& default_value) const {
    MACE_CHECK(operator_def_, "operator_def was null!");
    return ArgumentHelper::GetSingleArgument<OperatorDef, T>(
        *operator_def_, name, default_value);
  }

  template <typename T>
  inline bool HasSingleArgumentOfType(const string& name) const {
    MACE_CHECK(operator_def_, "operator_def was null!");
    return ArgumentHelper::HasSingleArgumentOfType<OperatorDef, T>(
        *operator_def_, name);
  }

  template <typename T>
  inline vector<T> GetRepeatedArgument(
      const string& name, const vector<T>& default_value = {}) const {
    MACE_CHECK(operator_def_, "operator_def was null!");
    return ArgumentHelper::GetRepeatedArgument<OperatorDef, T>(
        *operator_def_, name, default_value);

@@ -49,9 +48,7 @@ class OperatorBase {
    return inputs_[idx];
  }
  inline Tensor* Output(int idx) { return outputs_[idx]; }

  inline int InputSize() { return inputs_.size(); }
  inline int OutputSize() { return outputs_.size(); }

@@ -70,9 +67,7 @@ class OperatorBase {
    operator_def_ = operator_def;
  }

  inline bool has_debug_def() const { return operator_def_ != nullptr; }

 protected:
  Workspace* operator_ws_;

@@ -80,7 +75,7 @@ class OperatorBase {
  vector<const Tensor*> inputs_;
  vector<Tensor*> outputs_;

  DISABLE_COPY_AND_ASSIGN(OperatorBase);
};

template <DeviceType D, class T>

@@ -90,26 +85,22 @@ class Operator : public OperatorBase {
      : OperatorBase(operator_def, ws) {
    for (const string& input_str : operator_def.input()) {
      const Tensor* tensor = ws->GetTensor(input_str);
      MACE_CHECK(tensor != nullptr, "op ", operator_def.type(),
                 ": Encountered a non-existing input tensor: ", input_str);
      inputs_.push_back(tensor);
    }

    for (const string& output_str : operator_def.output()) {
      outputs_.push_back(MACE_CHECK_NOTNULL(ws->CreateTensor(
          output_str, DeviceContext<D>::allocator(), DataTypeToEnum<T>::v())));
    }
  }
  virtual bool Run() override = 0;
  ~Operator() noexcept override {}
};

// OP_INPUT_TAGS and OP_OUTPUT_TAGS are optional features to name the indices
// of the operator's inputs and outputs, in order to avoid confusion. For
// example, for a fully convolution layer that has input, weight and bias, you
// can define its input tags as:

@@ -119,9 +110,9 @@ class Operator : public OperatorBase {
// you can now do
//     auto& weight = Input(WEIGHT);
// to make it more clear.
#define OP_INPUT_TAGS(first_input, ...) \
  enum _InputTags { first_input = 0, __VA_ARGS__ }
#define OP_OUTPUT_TAGS(first_input, ...) \
  enum _OutputTags { first_input = 0, __VA_ARGS__ }

typedef Registry<std::string, OperatorBase, const OperatorDef&, Workspace*>

@@ -135,7 +126,7 @@ struct DeviceTypeRegisterer {
    if (gDeviceTypeRegistry()->count(type)) {
      LOG(ERROR) << "Device type " << type
                 << "registered twice. This should not happen. Did you have "
                    "duplicated numbers assigned to different devices?";
      std::exit(1);
    }
    // Calling the registry function to get the actual registry pointer.

@@ -143,39 +134,31 @@ struct DeviceTypeRegisterer {
  }
};

#define MACE_REGISTER_DEVICE_TYPE(type, registry_function) \
  namespace {                                              \
  static DeviceTypeRegisterer MACE_ANONYMOUS_VARIABLE(     \
      DeviceType)(type, &registry_function);               \
  }

MACE_DECLARE_REGISTRY(CPUOperatorRegistry,
                      OperatorBase,
                      const OperatorDef&,
                      Workspace*);

#define REGISTER_CPU_OPERATOR_CREATOR(key, ...) \
  MACE_REGISTER_CREATOR(CPUOperatorRegistry, key, __VA_ARGS__)
#define REGISTER_CPU_OPERATOR(name, ...) \
  MACE_REGISTER_CLASS(CPUOperatorRegistry, name, __VA_ARGS__)

MACE_DECLARE_REGISTRY(NEONOperatorRegistry,
                      OperatorBase,
                      const OperatorDef&,
                      Workspace*);

#define REGISTER_NEON_OPERATOR_CREATOR(key, ...) \
  MACE_REGISTER_CREATOR(NEONOperatorRegistry, key, __VA_ARGS__)
#define REGISTER_NEON_OPERATOR(name, ...) \
  MACE_REGISTER_CLASS(NEONOperatorRegistry, name, __VA_ARGS__)

unique_ptr<OperatorBase> CreateOperator(const OperatorDef& operator_def,
                                        Workspace* ws,
                                        DeviceType type);

}  // namespace mace

#endif  // MACE_CORE_OPERATOR_H
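The tag macros in action, following the convolution example in the comment above; ConvOp here is a sketch, not the operator this repository actually defines:

```cpp
#include "mace/core/operator.h"

template <mace::DeviceType D, class T>
class ConvOp : public mace::Operator<D, T> {
 public:
  ConvOp(const mace::OperatorDef& def, mace::Workspace* ws)
      : mace::Operator<D, T>(def, ws) {}

  bool Run() override {
    // Input(WEIGHT) reads clearer than Input(1).
    auto weight = this->Input(WEIGHT);
    (void)weight;
    return true;
  }

  // Expand to enums _InputTags / _OutputTags with members numbered from 0.
  OP_INPUT_TAGS(INPUT, WEIGHT, BIAS);
  OP_OUTPUT_TAGS(OUTPUT);
};
```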
mace/core/proto_utils.cc

@@ -5,9 +5,9 @@
#include "mace/core/proto_utils.h"

#include <fcntl.h>
#include <unistd.h>
#include <cerrno>
#include <fstream>

#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"

@@ -82,13 +82,12 @@ bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto) {
  return proto->ParseFromCodedStream(&coded_stream);
}

void WriteProtoToBinaryFile(const MessageLite& /*proto*/,
                            const char* /*filename*/) {
  LOG(FATAL) << "Not implemented yet.";
}

#else  // MACE_USE_LITE_PROTO

// Full protocol buffer.

@@ -118,7 +117,7 @@ void WriteProtoToTextFile(const Message& proto, const char* filename) {
}

bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto) {
#if defined(_MSC_VER)  // for MSC compiler binary flag needs to be specified
  int fd = open(filename, O_RDONLY | O_BINARY);
#else
  int fd = open(filename, O_RDONLY);

@@ -138,8 +137,8 @@ bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto) {
void WriteProtoToBinaryFile(const MessageLite& proto, const char* filename) {
  int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644);
  MACE_CHECK(fd != -1, "File cannot be created: ", filename,
             " error number: ", errno);
  std::unique_ptr<ZeroCopyOutputStream> raw_output(new FileOutputStream(fd));
  std::unique_ptr<CodedOutputStream> coded_output(
      new CodedOutputStream(raw_output.get()));

@@ -151,18 +150,17 @@ void WriteProtoToBinaryFile(const MessageLite& proto, const char* filename) {
#endif  // MACE_USE_LITE_PROTO

ArgumentHelper::ArgumentHelper(const OperatorDef& def) {
  for (auto& arg : def.arg()) {
    if (arg_map_.find(arg.name()) != arg_map_.end()) {
      MACE_CHECK(
          arg.SerializeAsString() == arg_map_[arg.name()].SerializeAsString(),
          "Found argument of the same name '", arg.name(),
          "' but with different contents: ", ProtoDebugString(def));
      LOG(WARNING) << "Duplicated argument name found in operator def: "
                   << ProtoDebugString(def) << ", arg: "
                   << ProtoDebugString(arg);
    }
    arg_map_[arg.name()] = arg;

@@ -171,10 +169,9 @@ ArgumentHelper::ArgumentHelper(const OperatorDef &def) {
ArgumentHelper::ArgumentHelper(const NetDef& netdef) {
  for (auto& arg : netdef.arg()) {
    MACE_CHECK(arg_map_.count(arg.name()) == 0,
               "Duplicated argument name found in net def: ",
               ProtoDebugString(netdef));
    arg_map_[arg.name()] = arg;
  }
}

@@ -192,32 +189,24 @@ bool SupportsLosslessConversion(const InputType& value) {
  }
}

#define INSTANTIATE_GET_SINGLE_ARGUMENT(T, fieldname,                         \
                                        enforce_lossless_conversion)          \
  template <>                                                                 \
  T ArgumentHelper::GetSingleArgument<T>(                                     \
      const string& name, const T& default_value) const {                     \
    if (arg_map_.count(name) == 0) {                                          \
      VLOG(1) << "Using default parameter value " << default_value            \
              << " for parameter " << name;                                   \
      return default_value;                                                   \
    }                                                                         \
    MACE_CHECK(arg_map_.at(name).has_##fieldname(), "Argument ", name,        \
               " does not have the right field: expected field " #fieldname); \
    auto value = arg_map_.at(name).fieldname();                               \
    if (enforce_lossless_conversion) {                                        \
      auto supportsConversion =                                               \
          SupportsLosslessConversion<decltype(value), T>(value);              \
      MACE_CHECK(supportsConversion, "Value", value, " of argument ", name,   \
                 "cannot be represented correctly in a target type");         \
    }                                                                         \
    return value;                                                             \
  }                                                                           \

@@ -242,30 +231,25 @@ INSTANTIATE_GET_SINGLE_ARGUMENT(size_t, i, true)
INSTANTIATE_GET_SINGLE_ARGUMENT(string, s, false)
#undef INSTANTIATE_GET_SINGLE_ARGUMENT

#define INSTANTIATE_GET_REPEATED_ARGUMENT(T, fieldname,                    \
                                          enforce_lossless_conversion)     \
  template <>                                                              \
  vector<T> ArgumentHelper::GetRepeatedArgument<T>(                        \
      const string& name, const std::vector<T>& default_value) const {     \
    if (arg_map_.count(name) == 0) {                                       \
      return default_value;                                                \
    }                                                                      \
    vector<T> values;                                                      \
    for (const auto& v : arg_map_.at(name).fieldname()) {                  \
      if (enforce_lossless_conversion) {                                   \
        auto supportsConversion =                                          \
            SupportsLosslessConversion<decltype(v), T>(v);                 \
        MACE_CHECK(supportsConversion, "Value", v, " of argument ", name,  \
                   "cannot be represented correctly in a target type");    \
      }                                                                    \
      values.push_back(v);                                                 \
    }                                                                      \
    return values;                                                         \
  }

INSTANTIATE_GET_REPEATED_ARGUMENT(float, floats, false)

@@ -281,14 +265,14 @@ INSTANTIATE_GET_REPEATED_ARGUMENT(size_t, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(string, strings, false)
#undef INSTANTIATE_GET_REPEATED_ARGUMENT

#define MACE_MAKE_SINGULAR_ARGUMENT(T, fieldname)             \
  template <>                                                 \
  Argument MakeArgument(const string& name, const T& value) { \
    Argument arg;                                             \
    arg.set_name(name);                                       \
    arg.set_##fieldname(value);                               \
    return arg;                                               \
  }

MACE_MAKE_SINGULAR_ARGUMENT(bool, i)
MACE_MAKE_SINGULAR_ARGUMENT(float, f)

@@ -305,16 +289,16 @@ Argument MakeArgument(const string& name, const MessageLite& value) {
  return arg;
}

#define MACE_MAKE_REPEATED_ARGUMENT(T, fieldname)                     \
  template <>                                                         \
  Argument MakeArgument(const string& name, const vector<T>& value) { \
    Argument arg;                                                     \
    arg.set_name(name);                                               \
    for (const auto& v : value) {                                     \
      arg.add_##fieldname(v);                                         \
    }                                                                 \
    return arg;                                                       \
  }

MACE_MAKE_REPEATED_ARGUMENT(float, floats)
MACE_MAKE_REPEATED_ARGUMENT(int, ints)

@@ -328,31 +312,24 @@ const Argument& GetArgument(const OperatorDef& def, const string& name) {
      return arg;
    }
  }
  MACE_CHECK(false, "Argument named ", name, "does not exist in operator ",
             ProtoDebugString(def));
}

bool GetFlagArgument(const OperatorDef& def,
                     const string& name,
                     bool def_value) {
  for (const Argument& arg : def.arg()) {
    if (arg.name() == name) {
      MACE_CHECK(arg.has_i(), "Can't parse argument as bool: ",
                 ProtoDebugString(arg));
      return arg.i();
    }
  }
  return def_value;
}

Argument* GetMutableArgument(const string& name,
                             const bool create_if_missing,
                             OperatorDef* def) {
  for (int i = 0; i < def->arg_size(); ++i) {
    if (def->arg(i).name() == name) {
      return def->mutable_arg(i);
mace/core/proto_utils.h

@@ -12,15 +12,14 @@
#include "google/protobuf/message.h"
#endif  // !MACE_USE_LITE_PROTO

#include "mace/core/common.h"
#include "mace/proto/mace.pb.h"

namespace mace {

using std::string;
using ::google::protobuf::MessageLite;

// Common interfaces that reads file contents into a string.
bool ReadStringFromFile(const char* filename, string* str);
bool WriteStringToFile(const string& str, const char* filename);

@@ -46,22 +45,20 @@ inline string ProtoDebugString(const MessageLite& proto) {
// Text format MessageLite wrappers: these functions do nothing but just
// allowing things to compile. It will produce a runtime error if you are using
// MessageLite but still want text support.
inline bool ReadProtoFromTextFile(const char* /*filename*/,
                                  MessageLite* /*proto*/) {
  LOG(FATAL) << "If you are running lite version, you should not be "
             << "calling any text-format protobuffers.";
  return false;  // Just to suppress compiler warning.
}
inline bool ReadProtoFromTextFile(const string filename, MessageLite* proto) {
  return ReadProtoFromTextFile(filename.c_str(), proto);
}

inline void WriteProtoToTextFile(const MessageLite& /*proto*/,
                                 const char* /*filename*/) {
  LOG(FATAL) << "If you are running lite version, you should not be "
             << "calling any text-format protobuffers.";
}
inline void WriteProtoToTextFile(const MessageLite& proto,
                                 const string& filename) {

@@ -107,16 +104,13 @@ inline bool ReadProtoFromFile(const string& filename, Message* proto) {
#endif  // MACE_USE_LITE_PROTO

template <class IterableInputs = std::initializer_list<string>,
          class IterableOutputs = std::initializer_list<string>,
          class IterableArgs = std::initializer_list<Argument>>
OperatorDef CreateOperatorDef(const string& type,
                              const string& name,
                              const IterableInputs& inputs,
                              const IterableOutputs& outputs,
                              const IterableArgs& args) {
  OperatorDef def;
  def.set_type(type);
  def.set_name(name);

@@ -134,20 +128,13 @@ OperatorDef CreateOperatorDef(
// A simplified version compared to the full CreateOperator, if you do not need
// to specify args.
template <class IterableInputs = std::initializer_list<string>,
          class IterableOutputs = std::initializer_list<string>>
inline OperatorDef CreateOperatorDef(const string& type,
                                     const string& name,
                                     const IterableInputs& inputs,
                                     const IterableOutputs& outputs) {
  return CreateOperatorDef(type, name, inputs, outputs,
                           std::vector<Argument>());
}

/**

@@ -166,10 +153,8 @@ class ArgumentHelper {
  }

  template <typename Def, typename T>
  static T GetSingleArgument(const Def& def,
                             const string& name,
                             const T& default_value) {
    return ArgumentHelper(def).GetSingleArgument<T>(name, default_value);
  }

@@ -180,8 +165,7 @@ class ArgumentHelper {
  template <typename Def, typename T>
  static vector<T> GetRepeatedArgument(
      const Def& def,
      const string& name,
      const std::vector<T>& default_value = std::vector<T>()) {
    return ArgumentHelper(def).GetRepeatedArgument<T>(name, default_value);
  }

@@ -192,9 +176,8 @@ class ArgumentHelper {
  }

  template <typename Def, typename MessageType>
  static vector<MessageType> GetRepeatedMessageArgument(const Def& def,
                                                        const string& name) {
    return ArgumentHelper(def).GetRepeatedMessageArgument<MessageType>(name);
  }

@@ -216,9 +199,8 @@ class ArgumentHelper {
    MACE_CHECK(arg_map_.count(name), "Cannot find parameter named " + name);
    MessageType message;
    if (arg_map_.at(name).has_s()) {
      MACE_CHECK(message.ParseFromString(arg_map_.at(name).s()),
                 "Faild to parse content from the string");
    } else {
      VLOG(1) << "Return empty message for parameter " << name;
    }

@@ -230,9 +212,8 @@ class ArgumentHelper {
    MACE_CHECK(arg_map_.count(name), "Cannot find parameter named " + name);
    vector<MessageType> messages(arg_map_.at(name).strings_size());
    for (int i = 0; i < messages.size(); ++i) {
      MACE_CHECK(messages[i].ParseFromString(arg_map_.at(name).strings(i)),
                 "Faild to parse content from the string");
    }
    return messages;
  }

@@ -242,15 +223,11 @@ class ArgumentHelper {
};

const Argument& GetArgument(const OperatorDef& def, const string& name);
bool GetFlagArgument(const OperatorDef& def,
                     const string& name,
                     bool def_value = false);

Argument* GetMutableArgument(const string& name,
                             const bool create_if_missing,
                             OperatorDef* def);

template <typename T>
Argument MakeArgument(const string& name, const T& value);
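A sketch of CreateOperatorDef and ArgumentHelper together. The braced inputs/outputs rely on the initializer_list default template arguments above (braced-init-lists are non-deduced, so the defaults kick in); the op name and argument are illustrative:

```cpp
#include "mace/core/proto_utils.h"

void BuildAndInspect() {
  mace::OperatorDef def = mace::CreateOperatorDef(
      "Relu", "relu1", {"input"}, {"output"},
      std::vector<mace::Argument>{mace::MakeArgument("max_limit", 6.0f)});

  // Reads the argument back; would fall back to 0.0f if it were missing.
  float limit =
      mace::ArgumentHelper::GetSingleArgument(def, "max_limit", 0.0f);
  (void)limit;  // 6.0f
}
```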
mace/core/registry.h

@@ -12,7 +12,7 @@ namespace mace {
template <class SrcType, class ObjectType, class... Args>
class Registry {
 public:
  typedef std::function<std::unique_ptr<ObjectType>(Args...)> Creator;

  Registry() : registry_() {}

@@ -24,7 +24,7 @@ class Registry {
  inline bool Has(const SrcType& key) { return registry_.count(key) != 0; }

  unique_ptr<ObjectType> Create(const SrcType& key, Args... args) {
    if (registry_.count(key) == 0) {
      VLOG(2) << "Key not registered: " << key;
      return nullptr;

@@ -60,7 +60,7 @@ class Registerer {
  }

  template <class DerivedType>
  static unique_ptr<ObjectType> DefaultCreator(Args... args) {
    return std::unique_ptr<ObjectType>(new DerivedType(args...));
  }
};

@@ -74,36 +74,35 @@ class Registerer {
#endif

#define MACE_DECLARE_TYPED_REGISTRY(RegistryName, SrcType, ObjectType, ...) \
  Registry<SrcType, ObjectType, ##__VA_ARGS__>* RegistryName();             \
  typedef Registerer<SrcType, ObjectType, ##__VA_ARGS__>                    \
      Registerer##RegistryName;

#define MACE_DEFINE_TYPED_REGISTRY(RegistryName, SrcType, ObjectType, ...) \
  Registry<SrcType, ObjectType, ##__VA_ARGS__>* RegistryName() {           \
    static Registry<SrcType, ObjectType, ##__VA_ARGS__>* registry =        \
        new Registry<SrcType, ObjectType, ##__VA_ARGS__>();                \
    return registry;                                                       \
  }

#define MACE_DECLARE_REGISTRY(RegistryName, ObjectType, ...)         \
  MACE_DECLARE_TYPED_REGISTRY(RegistryName, std::string, ObjectType, \
                              ##__VA_ARGS__)

#define MACE_DEFINE_REGISTRY(RegistryName, ObjectType, ...)         \
  MACE_DEFINE_TYPED_REGISTRY(RegistryName, std::string, ObjectType, \
                             ##__VA_ARGS__)

#define MACE_REGISTER_TYPED_CREATOR(RegistryName, key, ...)                  \
  namespace {                                                                \
  static Registerer##RegistryName MACE_ANONYMOUS_VARIABLE(g_##RegistryName)( \
      key, RegistryName(), __VA_ARGS__);                                     \
  }

#define MACE_REGISTER_TYPED_CLASS(RegistryName, key, ...)                    \
  namespace {                                                                \
  static Registerer##RegistryName MACE_ANONYMOUS_VARIABLE(g_##RegistryName)( \
      key, RegistryName(),                                                   \
      Registerer##RegistryName::DefaultCreator<__VA_ARGS__>);                \
  }

#define MACE_REGISTER_CREATOR(RegistryName, key, ...) \

@@ -112,6 +111,6 @@ class Registerer {
#define MACE_REGISTER_CLASS(RegistryName, key, ...) \
  MACE_REGISTER_TYPED_CLASS(RegistryName, #key, __VA_ARGS__)

}  // namespace mace

#endif  // MACE_CORE_REGISTRY_H_
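These macros implement a classic static-registration pattern: the header declares a registry function plus a Registerer typedef, and each translation unit registers a class under a string key via an anonymous static object. A sketch of registering an operator with them; NoOp is illustrative, not part of this commit:

```cpp
#include "mace/core/operator.h"

namespace mace {

template <DeviceType D, class T>
class NoOp : public Operator<D, T> {
 public:
  NoOp(const OperatorDef& def, Workspace* ws) : Operator<D, T>(def, ws) {}
  bool Run() override { return true; }  // intentionally does nothing
};

// Expands (via MACE_REGISTER_CLASS / MACE_REGISTER_TYPED_CLASS) to a static
// RegistererCPUOperatorRegistry that inserts the key "NoOp" with
// DefaultCreator<NoOp<DeviceType::CPU, float>> as its Creator.
REGISTER_CPU_OPERATOR(NoOp, NoOp<DeviceType::CPU, float>);

}  // namespace mace
```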
mace/core/serializer.cc
浏览文件 @
578b382a
...
@@ -4,19 +4,18 @@
#include "mace/core/serializer.h"

namespace mace {

unique_ptr<TensorProto> Serializer::Serialize(const Tensor &tensor,
                                              const string &name) {
  MACE_NOT_IMPLEMENTED;
  return nullptr;
}

unique_ptr<Tensor> Serializer::Deserialize(const TensorProto &proto,
                                           DeviceType type) {
  unique_ptr<Tensor> tensor(
      new Tensor(GetDeviceAllocator(type), proto.data_type()));
  vector<index_t> dims;
  for (const index_t d : proto.dims()) {
    dims.push_back(d);
...
@@ -25,8 +24,7 @@ unique_ptr<Tensor> Serializer::Deserialize(const TensorProto &proto,
  switch (proto.data_type()) {
    case DT_FLOAT:
      tensor->Copy<float>(proto.float_data().data(),
                          proto.float_data().size());
      break;
    case DT_DOUBLE:
      tensor->Copy<double>(proto.double_data().data(),
...
@@ -34,39 +32,38 @@ unique_ptr<Tensor> Serializer::Deserialize(const TensorProto &proto,
      break;
    case DT_INT32:
      tensor->template Copy<int32_t>(proto.int32_data().data(),
                                     proto.int32_data().size());
      break;
    case DT_UINT8:
      tensor->CopyWithCast<int32_t, uint8_t>(proto.int32_data().data(),
                                             proto.int32_data().size());
      break;
    case DT_INT16:
      tensor->CopyWithCast<int32_t, int16_t>(proto.int32_data().data(),
                                             proto.int32_data().size());
      break;
    case DT_INT8:
      tensor->CopyWithCast<int32_t, int8_t>(proto.int32_data().data(),
                                            proto.int32_data().size());
      break;
    case DT_INT64:
      tensor->Copy<int64_t>(proto.int64_data().data(),
                            proto.int64_data().size());
      break;
    case DT_UINT16:
      tensor->CopyWithCast<int32_t, uint16_t>(proto.int32_data().data(),
                                              proto.int32_data().size());
      break;
    case DT_BOOL:
      tensor->CopyWithCast<int32_t, bool>(proto.int32_data().data(),
                                          proto.int32_data().size());
      break;
    case DT_STRING: {
      string *content = tensor->mutable_data<string>();
      for (int i = 0; i < proto.string_data().size(); ++i) {
        content[i] = proto.string_data(i);
      }
    }
      break;
    default:
      MACE_NOT_IMPLEMENTED;
      break;
...
@@ -75,4 +72,4 @@ unique_ptr<Tensor> Serializer::Deserialize(const TensorProto &proto,
  return tensor;
}

}  // namespace mace
\ No newline at end of file
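
A note on the casts above: protobuf has no repeated int8/int16/uint8 fields, so those tensors travel in int32_data and are narrowed element by element on load. A minimal sketch of that pattern (types simplified, not the real Tensor API):

#include <cstdint>
#include <vector>

// Narrowing copy: each int32 payload element is cast to the destination type.
template <typename SrcType, typename DstType>
void CopyWithCast(const SrcType* src, size_t size, DstType* dst) {
  for (size_t i = 0; i < size; ++i) {
    dst[i] = static_cast<DstType>(src[i]);
  }
}

int main() {
  // A DT_UINT8 tensor as it would arrive in the proto's int32_data field.
  std::vector<int32_t> proto_int32_data = {0, 127, 255};
  std::vector<uint8_t> tensor_data(proto_int32_data.size());
  CopyWithCast(proto_int32_data.data(), proto_int32_data.size(),
               tensor_data.data());
  return 0;
}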
mace/core/serializer.h
...
@@ -5,9 +5,9 @@
#ifndef MACE_CORE_SERIALIZER_H_
#define MACE_CORE_SERIALIZER_H_

#include "mace/core/common.h"
#include "mace/core/tensor.h"
#include "mace/proto/mace.pb.h"

namespace mace {
...
@@ -20,9 +20,9 @@ class Serializer {
  unique_ptr<Tensor> Deserialize(const TensorProto &proto, DeviceType type);

  DISABLE_COPY_AND_ASSIGN(Serializer);
};

}  // namespace mace

#endif  // MACE_CORE_SERIALIZER_H_
mace/core/tensor.h
...
@@ -5,11 +5,11 @@
#ifndef MACE_CORE_TENSOR_H_
#define MACE_CORE_TENSOR_H_

#include "mace/core/allocator.h"
#include "mace/core/common.h"
#include "mace/core/logging.h"
#include "mace/core/types.h"
#include "mace/proto/mace.pb.h"

namespace mace {
...
@@ -25,13 +25,13 @@ namespace mace {
  switch (TYPE_ENUM) {                     \
    CASE(float, SINGLE_ARG(STMTS))         \
    CASE(double, SINGLE_ARG(STMTS))        \
    CASE(int32_t, SINGLE_ARG(STMTS))       \
    CASE(uint8_t, SINGLE_ARG(STMTS))       \
    CASE(uint16_t, SINGLE_ARG(STMTS))      \
    CASE(int16_t, SINGLE_ARG(STMTS))       \
    CASE(int8_t, SINGLE_ARG(STMTS))        \
    CASE(string, SINGLE_ARG(STMTS))        \
    CASE(int64_t, SINGLE_ARG(STMTS))       \
    CASE(bool, SINGLE_ARG(STMTS))          \
    case DT_INVALID:                       \
      INVALID;                             \
...
@@ -41,20 +41,17 @@ namespace mace {
      break;                               \
  }

#define CASES(TYPE_ENUM, STMTS)                                      \
  CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, LOG(FATAL) << "Type not set"; \
                     , LOG(FATAL) << "Unexpected type: " << TYPE_ENUM;)

class Tensor {
 public:
  Tensor()
      : alloc_(cpu_allocator()), size_(0), dtype_(DT_FLOAT), data_(nullptr) {};

  Tensor(Allocator *a, DataType type)
      : alloc_(a), size_(0), dtype_(type), data_(nullptr) {};

  ~Tensor() {
    if (alloc_ && data_.get()) {
...
@@ -92,9 +89,8 @@ class Tensor {
    if (data_.get() || size_ == 0) {
      return data_.get();
    } else {
      CASES(dtype_, data_.reset(alloc_->New(size_ * sizeof(T)),
                                [this](void *ptr) { alloc_->Delete(ptr); }));
      return data_.get();
    }
  }
...
@@ -116,13 +112,9 @@ class Tensor {
    }
  }

  inline void ResizeLike(const Tensor &other) { Resize(other.shape()); }

  inline void ResizeLike(const Tensor *other) { Resize(other->shape()); }

  template <typename T>
  inline void Copy(const T *src, index_t size) {
...
@@ -132,7 +124,8 @@ class Tensor {
  template <typename SrcType, typename DstType>
  inline void CopyWithCast(const SrcType *src, size_t size) {
    MACE_CHECK(static_cast<index_t>(size) == size_,
               "copy src and dst with different size.");
    unique_ptr<DstType[]> buffer(new DstType[size]);
    for (size_t i = 0; i < size; ++i) {
      buffer[i] = static_cast<DstType>(src[i]);
...
@@ -146,10 +139,11 @@ class Tensor {
  inline void DebugPrint() {
    std::stringstream os;
    for (int i : shape_) {
      os << i << ", ";
    }
    LOG(INFO) << "Tensor shape: " << os.str()
              << " type: " << DataType_Name(dtype_);

    os.str("");
    os.clear();
...
@@ -175,7 +169,8 @@ class Tensor {
 private:
  inline int64_t NumElements() const {
    return std::accumulate(shape_.begin(), shape_.end(), 1,
                           std::multiplies<int64_t>());
  }

  Allocator *alloc_;
...
@@ -184,9 +179,9 @@ class Tensor {
  std::shared_ptr<void> data_;
  vector<index_t> shape_;

  DISABLE_COPY_AND_ASSIGN(Tensor);
};

}  // namespace tensor

#endif  // MACE_CORE_TENSOR_H_
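
The CASES macro above is what lets raw_data() allocate lazily with the right element size: it expands the statement once per supported type, binding T to the C++ type that matches the runtime dtype_. A simplified sketch of the same dispatch, using a hypothetical two-value enum instead of the proto DataType:

#include <cstdint>
#include <cstdlib>
#include <memory>

enum SketchDataType { SK_FLOAT, SK_INT32 };  // illustrative stand-in

// Maps a runtime dtype to a compile-time sizeof(T) at allocation time and
// pairs the buffer with a matching deleter, as the CASES expansion does with
// alloc_->New / alloc_->Delete.
std::shared_ptr<void> AllocateFor(SketchDataType dtype, size_t count) {
  auto alloc = [](size_t bytes) { return std::malloc(bytes); };
  auto deleter = [](void* ptr) { std::free(ptr); };
  switch (dtype) {
    case SK_FLOAT:
      return std::shared_ptr<void>(alloc(count * sizeof(float)), deleter);
    case SK_INT32:
      return std::shared_ptr<void>(alloc(count * sizeof(int32_t)), deleter);
  }
  return nullptr;
}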
mace/core/testing/test_benchmark.cc
...
@@ -51,11 +51,8 @@ Benchmark* Benchmark::ArgPair(int x, int y) {
  return this;
}

// Run all benchmarks
void Benchmark::Run() { Run("all"); }

void Benchmark::Run(const char* pattern) {
  if (!all_benchmarks) return;
...
@@ -113,8 +110,8 @@ void Benchmark::Run(const char* pattern) {
               (items_processed * 1e-6) / seconds);
      full_label += buf;
    }
    printf("%-*s %10.0f %10d\t%s\n", width, name, seconds * 1e9 / iters,
           iters, full_label.c_str());
  }
}
}
...
mace/core/testing/test_benchmark.h
...
@@ -12,9 +12,9 @@
#include "mace/core/types.h"

#define MACE_BENCHMARK_CONCAT(a, b, c) a##b##c
#define BENCHMARK(n)                                        \
  static ::mace::testing::Benchmark* MACE_BENCHMARK_CONCAT( \
      __benchmark_, n, __LINE__) = (new ::mace::testing::Benchmark(#n, (n)))

namespace mace {
namespace testing {
...
mace/core/testing/test_benchmark_main.cc
...
@@ -17,4 +17,3 @@ int main(int argc, char** argv) {
  }
  return 0;
}
mace/core/types.h
...
@@ -18,26 +18,25 @@ struct DataTypeToEnum {
  static_assert(IsValidDataType<T>::value, "Specified Data Type not supported");
};

// EnumToDataType<VALUE>::Type is the type for DataType constant VALUE, e.g.
// EnumToDataType<DT_FLOAT>::Type is float.
template <DataType VALUE>
struct EnumToDataType {};  // Specializations below

// Template specialization for both DataTypeToEnum and EnumToDataType.
#define MATCH_TYPE_AND_ENUM(TYPE, ENUM)     \
  template <>                               \
  struct DataTypeToEnum<TYPE> {             \
    static DataType v() { return ENUM; }    \
    static constexpr DataType value = ENUM; \
  };                                        \
  template <>                               \
  struct IsValidDataType<TYPE> {            \
    static constexpr bool value = true;     \
  };                                        \
  template <>                               \
  struct EnumToDataType<ENUM> {             \
    typedef TYPE Type;                      \
  }

MATCH_TYPE_AND_ENUM(float, DT_FLOAT);
...
@@ -53,6 +52,6 @@ MATCH_TYPE_AND_ENUM(bool, DT_BOOL);
static const int32_t kint32_tmax = ((int32_t)0x7FFFFFFF);

}  // namespace mace

#endif  // MACE_CORE_TYPES_H_
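
With these specializations in place, the type/enum mapping is checkable entirely at compile time; for example, in a translation unit that includes mace/core/types.h:

#include <type_traits>

// Round-trip check: float maps to DT_FLOAT, and DT_FLOAT maps back to float.
static_assert(DataTypeToEnum<float>::value == DT_FLOAT,
              "float maps to DT_FLOAT");
static_assert(std::is_same<EnumToDataType<DT_FLOAT>::Type, float>::value,
              "DT_FLOAT maps back to float");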
mace/core/workspace.cc
...
@@ -2,8 +2,8 @@
// Copyright (c) 2017 XiaoMi All rights reserved.
//

#include "mace/core/workspace.h"
#include "mace/core/common.h"
#include "mace/core/serializer.h"

namespace mace {
...
@@ -16,8 +16,7 @@ vector<string> Workspace::Tensors() const {
  return names;
}

Tensor* Workspace::CreateTensor(const string& name, Allocator* alloc,
                                DataType type) {
  if (HasTensor(name)) {
    VLOG(1) << "Tensor " << name << " already exists. Skipping.";
...
@@ -46,14 +45,16 @@ const Tensor* Workspace::GetTensor(const string& name) const {
}

Tensor* Workspace::GetTensor(const string& name) {
  return const_cast<Tensor*>(
      static_cast<const Workspace*>(this)->GetTensor(name));
}

void Workspace::LoadModelTensor(const NetDef& net_def, DeviceType type) {
  Serializer serializer;
  for (auto& tensor_proto : net_def.tensors()) {
    tensor_map_[tensor_proto.name()] =
        serializer.Deserialize(tensor_proto, type);
  }
}

}  // namespace mace
\ No newline at end of file
mace/core/workspace.h
...
@@ -5,7 +5,6 @@
#ifndef MACE_CORE_WORKSPACE_H_
#define MACE_CORE_WORKSPACE_H_

#include "mace/core/common.h"
#include "mace/core/tensor.h"
#include "mace/proto/mace.pb.h"
...
@@ -37,10 +36,9 @@ class Workspace {
 private:
  TensorMap tensor_map_;

  DISABLE_COPY_AND_ASSIGN(Workspace);
};

}  // namespace mace

#endif  // MACE_CORE_WORKSPACE_H_
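
A usage sketch tying the workspace pieces together, assuming the tree above; the tensor name and shape are made up, GetDeviceAllocator is used as in serializer.cc, and a Resize overload taking a shape vector is inferred from ResizeLike in tensor.h:

#include "mace/core/workspace.h"

void PrepareWorkspace(const mace::NetDef& net_def) {
  mace::Workspace ws;
  // Deserialize all model weights from the proto into the tensor map.
  ws.LoadModelTensor(net_def, mace::DeviceType::CPU);
  // Create (or fetch, if it already exists) an activation buffer.
  mace::Tensor* t = ws.CreateTensor(
      "activation0", mace::GetDeviceAllocator(mace::DeviceType::CPU),
      mace::DT_FLOAT);
  t->Resize({1, 3, 224, 224});  // NCHW; allocation happens lazily on access
}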
mace/examples/benchmark_example.cc
...
@@ -14,7 +14,7 @@ static void foo(int iters) {
  float* out = new float[N];
  while (iters--) {
    for (int i = 0; i < N; i++) {
      out[i] = inp[i] * 2.0;
    }
  }
...
@@ -24,7 +24,6 @@ static void foo(int iters) {
BENCHMARK(foo);

static void bar(int iters, int n) {
  const int64_t tot = static_cast<int64_t>(iters) * n;
  mace::testing::ItemsProcessed(tot);
...
@@ -34,7 +33,7 @@ static void bar(int iters, int n) {
  float* out = new float[n];
  while (iters--) {
    for (int i = 0; i < n; i++) {
      out[i] = inp[i] * 2.0;
    }
  }
...
mace/kernels/addn.h
...
@@ -10,10 +10,9 @@
namespace mace {
namespace kernels {

template <DeviceType D, typename T>
struct AddNFunctor {
  void operator()(const vector<const T*>& inputs, T* output, index_t size) {
    memset(output, 0, size * sizeof(T));
    int n = inputs.size();
    for (int i = 0; i < n; ++i) {
...
@@ -25,11 +24,10 @@ struct AddNFunctor {
};

template <>
void AddNFunctor<DeviceType::NEON, float>::operator()(
    const vector<const float*>& inputs, float* output, index_t size);

}  // namespace kernels
}  // namespace mace

#endif  // MACE_KERNELS_ADDN_H_
\ No newline at end of file
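
A minimal sketch of driving the generic CPU path above; inputs are flattened to a single size, as the functor expects:

#include <vector>
#include "mace/kernels/addn.h"

void AddNExample() {
  float a[4] = {1, 2, 3, 4};
  float b[4] = {10, 20, 30, 40};
  float out[4];
  std::vector<const float*> inputs = {a, b};
  mace::kernels::AddNFunctor<mace::DeviceType::CPU, float> addn;
  addn(inputs, out, 4);  // out = {11, 22, 33, 44}
}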
mace/kernels/batch_norm.h
...
@@ -11,26 +11,21 @@
namespace mace {
namespace kernels {

template <DeviceType D, typename T>
struct BatchNormFunctor {
  float variance_epsilon_;

  BatchNormFunctor(const float variance_epsilon)
      : variance_epsilon_(variance_epsilon) {}

  void operator()(const T* input, const T* scale, const T* offset,
                  const T* mean, const T* var, const index_t n,
                  const index_t channel, const index_t sample_size,
                  T* output) {
    // Batch normalization in the paper https://arxiv.org/abs/1502.03167 .
    // The calculation formula for inference is
    // Y = \frac{ \scale } { \sqrt{var+\variance_epsilon} } * X +
    //     ( \offset - \frac { \scale * mean } {
    //     \sqrt{var+\variance_epsilon} }
    // new_scale = \frac{ \scale } { \sqrt{var+\variance_epsilon} }
    // new_offset = \offset - mean * common_val;
    // Y = new_scale * X + new_offset;
...
@@ -53,18 +48,12 @@ struct BatchNormFunctor {
};

template <>
void BatchNormFunctor<DeviceType::NEON, float>::operator()(
    const float* input, const float* scale, const float* offset,
    const float* mean, const float* var, const index_t n,
    const index_t channel, const index_t sample_size, float* output);

}  // namespace kernels
}  // namespace mace

#endif  // MACE_KERNELS_BATCH_NORM_H_
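
A worked scalar instance of the folding described in the comment above: with scale = 1, offset = 0.5, mean = 2, var = 3 and epsilon = 1, new_scale = 1/sqrt(3+1) = 0.5 and new_offset = 0.5 - 2 * 0.5 = -0.5, so an input of 4 normalizes to 0.5 * 4 - 0.5 = 1.5:

#include <cmath>
#include <cstdio>

int main() {
  float scale = 1.0f, offset = 0.5f, mean = 2.0f, var = 3.0f, eps = 1.0f;
  // Fold the four per-channel parameters into one multiply-add, as the
  // functor does once per channel before sweeping the sample.
  float new_scale = scale / std::sqrt(var + eps);
  float new_offset = offset - mean * new_scale;
  float x = 4.0f;
  std::printf("%f\n", new_scale * x + new_offset);  // prints 1.5
  return 0;
}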
mace/kernels/conv_2d.h
...
@@ -10,114 +10,103 @@
namespace mace {
namespace kernels {

template <DeviceType D, typename T>
class Conv2dFunctor {
 public:
  Conv2dFunctor(const int* strides, const int* paddings, const int* dilations)
      : strides_(strides), paddings_(paddings), dilations_(dilations) {}

  void operator()(const T* input,  // NCHW
                  const index_t* input_shape,
                  const T* filter,  // c_out, c_in, kernel_h, kernel_w
                  const index_t* filter_shape,
                  const T* bias,  // c_out
                  T* output,      // NCHW
                  const index_t* output_shape) {
    MACE_CHECK_NOTNULL(output);

    index_t batch = output_shape[0];
    index_t channels = output_shape[1];
    index_t height = output_shape[2];
    index_t width = output_shape[3];

    index_t input_batch = input_shape[0];
    index_t input_channels = input_shape[1];
    index_t input_height = input_shape[2];
    index_t input_width = input_shape[3];

    index_t kernel_h = filter_shape[2];
    index_t kernel_w = filter_shape[3];

    int stride_h = strides_[0];
    int stride_w = strides_[1];

    int dilation_h = dilations_[0];
    int dilation_w = dilations_[1];

    MACE_CHECK(batch == input_batch, "Input/Output batch size mismatch");

    // The left-upper most offset of the padded input
    int padded_h_start = 0 - paddings_[0] / 2;
    int padded_w_start = 0 - paddings_[1] / 2;
    index_t padded_h_stop = input_height + paddings_[0] - paddings_[0] / 2;
    index_t padded_w_stop = input_width + paddings_[1] - paddings_[1] / 2;

    index_t kernel_size = input_channels * kernel_h * kernel_w;

#pragma omp parallel for collapse(2)
    for (int n = 0; n < batch; ++n) {
      for (int c = 0; c < channels; ++c) {
        for (int h = 0; h < height; ++h) {
          for (int w = 0; w < width; ++w) {
            index_t offset = n * channels * height * width +
                             c * height * width + h * width + w;
            T sum = 0;
            const T* filter_ptr = filter + c * kernel_size;
            for (int inc = 0; inc < input_channels; ++inc) {
              for (int kh = 0; kh < kernel_h; ++kh) {
                for (int kw = 0; kw < kernel_w; ++kw) {
                  int inh = padded_h_start + h * stride_h + dilation_h * kh;
                  int inw = padded_w_start + w * stride_w + dilation_w * kw;
                  if (inh < 0 || inh >= input_height || inw < 0 ||
                      inw >= input_width) {
                    MACE_CHECK(inh >= padded_h_start && inh < padded_h_stop &&
                                   inw >= padded_w_start &&
                                   inw < padded_w_stop,
                               "Out of range read from input: ", inh, ", ",
                               inw);
                    // else padding with 0:
                    // sum += 0;
                  } else {
                    index_t input_offset =
                        n * input_channels * input_height * input_width +
                        inc * input_height * input_width + inh * input_width +
                        inw;
                    sum += input[input_offset] * *filter_ptr;
                  }
                  ++filter_ptr;
                }
              }
            }
            output[offset] = sum + bias[c];
          }
        }
      }
    }
  }

 private:
  const int* strides_;    // [stride_h, stride_w]
  const int* paddings_;   // [padding_h, padding_w]
  const int* dilations_;  // [dilation_h, dilation_w]
};

template <>
void Conv2dFunctor<DeviceType::NEON, float>::operator()(
    const float* input, const index_t* input_shape, const float* filter,
    const index_t* filter_shape, const float* bias, float* output,
    const index_t* output_shape);

}  // namespace kernels
}  // namespace mace

#endif  // MACE_KERNELS_CONV_2D_H_
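
A minimal sketch of invoking the naive CPU functor above, with a 1x1x3x3 input and a 1x1x1x1 filter that simply doubles every pixel (stride 1, no padding, no dilation); shapes follow the NCHW/OIHW comments:

#include "mace/kernels/conv_2d.h"

void Conv2dExample() {
  const int strides[2] = {1, 1};
  const int paddings[2] = {0, 0};
  const int dilations[2] = {1, 1};
  float input[9] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
  mace::index_t input_shape[4] = {1, 1, 3, 3};   // N, C, H, W
  float filter[1] = {2.0f};
  mace::index_t filter_shape[4] = {1, 1, 1, 1};  // c_out, c_in, kh, kw
  float bias[1] = {0.0f};
  float output[9];
  mace::index_t output_shape[4] = {1, 1, 3, 3};
  mace::kernels::Conv2dFunctor<mace::DeviceType::CPU, float> conv2d(
      strides, paddings, dilations);
  conv2d(input, input_shape, filter, filter_shape, bias, output, output_shape);
  // output is now the input scaled by 2.
}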
mace/kernels/conv_pool_2d_util.cc
...
@@ -7,12 +7,10 @@
namespace mace {
namespace kernels {

void CalcPaddingAndOutputSize(const index_t *input_shape,   // NCHW
                              const index_t *filter_shape,  // OIHW
                              const int *dilations,
                              const int *strides,
                              Padding padding,
                              index_t *output_shape,
                              int *padding_size) {
  MACE_CHECK(dilations[0] > 0 && dilations[1] > 0,
             "Invalid dilations, must >= 1");
...
@@ -43,14 +41,16 @@ void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW
      output_height = (input_shape[2] - k_extent_height) / strides[0] + 1;
      output_width = (input_shape[3] - k_extent_width) / strides[1] + 1;
      break;
    case SAME:
      output_height = (input_shape[2] - 1) / strides[0] + 1;
      output_width = (input_shape[3] - 1) / strides[1] + 1;
      break;
    case FULL:
      output_height = (input_shape[2] + k_extent_height - 2) / strides[0] + 1;
      output_width = (input_shape[3] + k_extent_width - 2) / strides[1] + 1;
      break;
    default:
      MACE_CHECK(false, "Unsupported padding type: ", padding);
  }

  // Note: TensorFlow may pad one more on the right/bottom side
...
@@ -58,10 +58,10 @@ void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW
  // utilize the more centered features. We need to benchmark
  // based on the model accuracy.
  padding_size[0] =
      (output_height - 1) * strides[0] + k_extent_height - input_shape[2];
  padding_size[1] =
      (output_width - 1) * strides[1] + k_extent_width - input_shape[3];

  output_shape[0] = input_shape[0];
  output_shape[1] = output_channels;
...
@@ -69,19 +69,15 @@ void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW
  output_shape[3] = output_width;
}

void ConstructInputWithPadding(const float *input,
                               const index_t *input_shape,
                               const int *paddings,
                               Tensor *output_tensor) {
  index_t batch = input_shape[0];
  index_t channels = input_shape[1];
  index_t height = input_shape[2];
  index_t width = input_shape[3];

  std::vector<index_t> output_shape(
      {batch, channels, paddings[0] + height, paddings[1] + width});

  const index_t output_width = output_shape[3];
  const int padded_top = paddings[0] / 2;
...
@@ -105,5 +101,5 @@ void ConstructInputWithPadding(const float *input,
    }
  }
}

}  // namespace kernels
}  // namespace mace
mace/kernels/conv_pool_2d_util.h
...
@@ -10,26 +10,22 @@
namespace mace {

enum Padding {
  VALID = 0,  // No padding
  SAME = 1,   // Pads with half the filter size (rounded down) on both sides
  FULL = 2,   // Pads with one less than the filter size on both sides
};

namespace kernels {

void CalcPaddingAndOutputSize(const index_t *input_shape,   // NCHW
                              const index_t *filter_shape,  // OIHW
                              const int *dilations,
                              const int *strides,
                              Padding padding,
                              index_t *output_shape,
                              int *padding_size);

void ConstructInputWithPadding(const float *input,
                               const index_t *input_shape,
                               const int *paddings,
                               Tensor *output_tensor);

}  // namespace kernels
}  // namespace mace

#endif  // MACE_KERNELS_CONV_POOL_2D_UTIL_H_
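
A worked instance of the arithmetic in CalcPaddingAndOutputSize for SAME padding, using a 7-wide input, 3-wide kernel, stride 2, dilation 1:

#include <cstdio>

// k_extent = (3 - 1) * 1 + 1 = 3
// output   = (7 - 1) / 2 + 1 = 4
// padding  = (4 - 1) * 2 + 3 - 7 = 2   (split as 1 on each side)
int main() {
  int input = 7, kernel = 3, stride = 2, dilation = 1;
  int k_extent = (kernel - 1) * dilation + 1;
  int output = (input - 1) / stride + 1;
  int padding = (output - 1) * stride + k_extent - input;
  std::printf("output=%d padding=%d\n", output, padding);  // output=4 padding=2
  return 0;
}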
mace/kernels/neon/addn_neon.cc
...
@@ -2,16 +2,15 @@
// Copyright (c) 2017 XiaoMi All rights reserved.
//

#include "mace/kernels/addn.h"
#include <arm_neon.h>

namespace mace {
namespace kernels {

template <>
void AddNFunctor<DeviceType::NEON, float>::operator()(
    const vector<const float*> &inputs, float *output, index_t size) {
  // TODO: neon mem copy
  memset(output, 0, size * sizeof(float));
  int n = inputs.size();
...
@@ -22,7 +21,7 @@ void AddNFunctor<DeviceType::NEON, float>::operator()(
  }
  int64_t element_per_group = size / groups;

#pragma omp parallel for num_threads(1)  // no significant performance improve
  for (int64_t i = 0; i < size; i += element_per_group) {
    int64_t count = std::min(element_per_group, size - i);
    int nn = count >> 2;
...
@@ -48,5 +47,5 @@ void AddNFunctor<DeviceType::NEON, float>::operator()(
  }
};

}  // namespace kernels
}  // namespace mace
\ No newline at end of file
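
The kernel above uses the standard four-lane pattern: count >> 2 full vector iterations, then a scalar tail for the remainder. A minimal standalone sketch of that shape:

#include <arm_neon.h>
#include <cstdint>

// dst[i] += src[i] for count elements, four floats per NEON iteration.
void AddInto(const float* src, float* dst, int64_t count) {
  int64_t nn = count >> 2;            // full 4-wide iterations
  int64_t remain = count - (nn << 2); // scalar tail
  for (; nn > 0; --nn) {
    float32x4_t s = vld1q_f32(src);
    float32x4_t d = vld1q_f32(dst);
    vst1q_f32(dst, vaddq_f32(d, s));
    src += 4;
    dst += 4;
  }
  for (; remain > 0; --remain) {
    *dst++ += *src++;
  }
}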
mace/kernels/neon/batch_norm_neon.cc
...
@@ -2,29 +2,25 @@
// Copyright (c) 2017 XiaoMi All rights reserved.
//

#include "mace/kernels/batch_norm.h"
#include <arm_neon.h>

namespace mace {
namespace kernels {

template <>
void BatchNormFunctor<DeviceType::NEON, float>::operator()(
    const float* input, const float* scale, const float* offset,
    const float* mean, const float* var, const index_t n,
    const index_t channel, const index_t sample_size, float* output) {
  // Batch normalization in the paper https://arxiv.org/abs/1502.03167 .
  // The calculation formula for inference is
  // Y = \frac{ \scale } { \sqrt{var+\variance_epsilon} } * X +
  //     ( \offset - \frac { \scale * mean } { \sqrt{var+\variance_epsilon} }
  // new_scale = \frac{ \scale } { \sqrt{var+\variance_epsilon} }
  // new_offset = \offset - mean * common_val;
  // Y = new_scale * X + new_offset;
  float new_scale, new_offset;
  index_t count = sample_size >> 2;
  index_t remain_count = sample_size - (count << 2);
...
@@ -36,8 +32,8 @@ void BatchNormFunctor<DeviceType::NEON, float>::operator()(const float* input,
    float32x4_t new_scale_f = vdupq_n_f32(new_scale);
    float32x4_t new_offset_f = vdupq_n_f32(new_offset);
    for (index_t i = 0; i < n; ++i) {
      const float* input_sample_ptr = input + pos;
      float* output_sample_ptr = output + pos;
      for (index_t j = 0; j < count; ++j) {
        float32x4_t input_f = vld1q_f32(input_sample_ptr);
...
@@ -57,5 +53,5 @@ void BatchNormFunctor<DeviceType::NEON, float>::operator()(const float* input,
  }
};

}  // namespace kernels
}  // namespace mace
\ No newline at end of file
mace/kernels/neon/conv_2d_neon.cc
...
@@ -20,62 +20,39 @@ extern void Conv2dNeonK5x5S1(const float *input, const index_t *input_shape,
                             const float *filter, const float *bias,
                             float *output, const index_t *output_shape);

template <>
void Conv2dFunctor<DeviceType::NEON, float>::operator()(
    const float *input,  // NCHW
    const index_t *input_shape,
    const float *filter,  // c_out, c_in, kernel_h, kernel_w
    const index_t *filter_shape,
    const float *bias,  // c_out
    float *output,      // NCHW
    const index_t *output_shape) {
  typedef void (*Conv2dNeonFunction)(
      const float *input,  // NCHW
      const index_t *input_shape,
      const float *filter,  // c_out, c_in, kernel_h, kernel_w
      const float *bias,    // c_out
      float *output,        // NCHW
      const index_t *output_shape);
  // Selection matrix: kernel_size x stride_size
  static const Conv2dNeonFunction selector[5][2] = {
      {Conv2dNeonK1x1S1, nullptr},
      {nullptr, nullptr},
      {Conv2dNeonK3x3S1, nullptr},
      {nullptr, nullptr},
      {Conv2dNeonK5x5S1, nullptr}};
  // not implemented yet
  index_t kernel_h = filter_shape[2];
  index_t kernel_w = filter_shape[3];
  if (kernel_h != kernel_w || kernel_h > 5 || strides_[0] != strides_[1] ||
      strides_[0] > 2 || dilations_[0] != 1 || dilations_[1] != 1 ||
      selector[kernel_h - 1][strides_[0] - 1] == nullptr) {
    LOG(WARNING) << "NEON conv2d kernel not implemented, using slow version";
    Conv2dFunctor<DeviceType::CPU, float>(strides_, paddings_, dilations_)(
        input, input_shape, filter, filter_shape, bias, output, output_shape);
    return;
  }
...
@@ -87,13 +64,8 @@ void Conv2dFunctor<DeviceType::NEON,
    input_shape = padded_input.shape().data();
  }
  auto conv2d_neon_func = selector[kernel_h - 1][strides_[0] - 1];
  conv2d_neon_func(input, input_shape, filter, bias, output, output_shape);
}

}  // namespace kernels
}  // namespace mace
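
A standalone sketch of the dispatch scheme above, decoupled from the conv signatures: a [kernel][stride] table of function pointers in which nullptr entries fall through to a generic, always-correct path:

typedef void (*KernelFn)(const float* in, float* out, int n);

// Generic reference path: always available, just slower.
void Generic(const float* in, float* out, int n) {
  for (int i = 0; i < n; ++i) out[i] = in[i];
}
// Stand-in for a specialized fast kernel (illustrative only).
void FastK1S1(const float* in, float* out, int n) { Generic(in, out, n); }

// selector[kernel - 1][stride - 1]; nullptr means "not implemented yet".
static const KernelFn selector[5][2] = {
    {FastK1S1, nullptr}, {nullptr, nullptr}, {nullptr, nullptr},
    {nullptr, nullptr},  {nullptr, nullptr}};

void Dispatch(int kernel, int stride, const float* in, float* out, int n) {
  KernelFn fn = nullptr;
  if (kernel >= 1 && kernel <= 5 && (stride == 1 || stride == 2)) {
    fn = selector[kernel - 1][stride - 1];
  }
  if (fn == nullptr) {
    Generic(in, out, n);  // fall back rather than fail
    return;
  }
  fn(in, out, n);
}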
mace/kernels/neon/conv_2d_neon_1x1.cc
...
@@ -8,25 +8,24 @@
namespace mace {
namespace kernels {

void Conv2dNeonK1x1S1(const float* input,  // NCHW
                      const index_t* input_shape,
                      const float* filter,  // c_out, c_in, kernel_h, kernel_w
                      const float* bias,    // c_out
                      float* output,        // NCHW
                      const index_t* output_shape) {
  const index_t batch = output_shape[0];
  const index_t channels = output_shape[1];
  const index_t height = output_shape[2];
  const index_t width = output_shape[3];

  const index_t input_batch = input_shape[0];
  const index_t input_channels = input_shape[1];
  const index_t input_height = input_shape[2];
  const index_t input_width = input_shape[3];

  MACE_CHECK(input_batch == batch && input_height == height &&
             input_width == width);

  const index_t total_pixels = height * width;
  // Process 4 * 2 = 8 pixels for each innermost loop
...
@@ -37,17 +36,18 @@ void Conv2dNeonK1x1S1(const float* input, // NCHW
  // benchmark omp collapsed(2)
  for (index_t n = 0; n < batch; ++n) {
    const float* filter_ptr = filter;
#pragma omp parallel for
    for (index_t c = 0; c < channels; ++c) {
      // TODO Will GCC opt these out?
      float* channel_output_start =
          output + n * channels * height * width + c * height * width;
      const float* input_ptr =
          input + n * input_channels * input_height * input_width;

      // Fill with bias
      float* output_ptr = channel_output_start;
      for (index_t ptr = 0; ptr < total_pixels; ++ptr) {
        output_ptr[ptr] = bias[c];  // TODO can we avoid this?
      }

      index_t inc = 0;
...
@@ -55,15 +55,14 @@ void Conv2dNeonK1x1S1(const float* input, // NCHW
      for (; inc + 3 < input_channels; inc += 4) {
        float* output_ptr = channel_output_start;
        // The beginning of each input feature map channel
        MACE_ASSERT(input_ptr ==
                    input + n * input_channels * input_height * input_width +
                        inc * input_height * input_width);

        const float* input_ptr1 = input_ptr + total_pixels;
        const float* input_ptr2 = input_ptr1 + total_pixels;
        const float* input_ptr3 = input_ptr2 + total_pixels;

        // filter is in c_out, c_in, 1, 1 order
        MACE_ASSERT(filter_ptr == filter + c * input_channels + inc);
        const float k0 = filter_ptr[0];
...
@@ -113,7 +112,7 @@ void Conv2dNeonK1x1S1(const float* input, // NCHW
          vst1q_f32(output_ptr + 4, out4);

          output_ptr += 8;
          input_ptr += 8;
          input_ptr1 += 8;
          input_ptr2 += 8;
          input_ptr3 += 8;
...
@@ -121,7 +120,7 @@ void Conv2dNeonK1x1S1(const float* input, // NCHW
        // Process the remaining pixels
        index_t remaining_pixels = loop_remaining;
        for (; remaining_pixels > 0; --remaining_pixels) {
          const float mul = *input_ptr * k0;
          const float mul1 = *input_ptr1 * k1;
          const float mul2 = *input_ptr2 * k2;
          const float mul3 = *input_ptr3 * k3;
...
@@ -141,9 +140,9 @@ void Conv2dNeonK1x1S1(const float* input, // NCHW
      // Process the remaining channels
      for (; inc < input_channels; ++inc) {
        float* output_ptr = channel_output_start;
        MACE_ASSERT(input_ptr ==
                    input + n * input_channels * input_height * input_width +
                        inc * input_height * input_width);
        MACE_ASSERT(filter_ptr == filter + c * input_channels + inc);
        const float k0 = filter_ptr[0];
...
@@ -166,13 +165,13 @@ void Conv2dNeonK1x1S1(const float* input, // NCHW
          vst1q_f32(output_ptr + 4, out4);

          output_ptr += 8;
          input_ptr += 8;
        }
        // Process the remaining pixels
        index_t remaining_pixels = loop_remaining;
        for (; remaining_pixels > 0; --remaining_pixels) {
          const float mul = *input_ptr * k0;
          *output_ptr += mul;
          ++output_ptr;
...
@@ -183,5 +182,5 @@ void Conv2dNeonK1x1S1(const float* input, // NCHW
  }
};

}  // namespace kernels
}  // namespace mace
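
For reference, a plain scalar statement of what this kernel computes: a 1x1 convolution is, per output channel and pixel, a bias plus a weighted sum over input channels at the same pixel (filter laid out c_out x c_in, batch of one here):

// Scalar reference for the NEON kernel above; input/output are CHW with
// batch 1, pixels = height * width.
void Conv1x1Reference(const float* input, const float* filter,
                      const float* bias, float* output, int in_channels,
                      int out_channels, int pixels) {
  for (int c = 0; c < out_channels; ++c) {
    for (int p = 0; p < pixels; ++p) {
      float sum = bias[c];
      for (int ic = 0; ic < in_channels; ++ic) {
        sum += filter[c * in_channels + ic] * input[ic * pixels + p];
      }
      output[c * pixels + p] = sum;
    }
  }
}

The NEON version above computes exactly this, but walks four input channels and eight pixels per step so that the multiply-accumulates stay in vector registers.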
mace/kernels/neon/conv_2d_neon_3x3.cc
...
@@ -10,78 +10,81 @@ namespace kernels {

static const int kRegisterSize = 4;

void Conv2dNeonK3x3S1(const float* input,  // NCHW
                      const index_t* input_shape,
                      const float* filter,  // c_out, c_in, kernel_h, kernel_w
                      const float* bias,    // c_out
                      float* output,        // NCHW
                      const index_t* output_shape) {
  int batch = output_shape[0];
  int channels = output_shape[1];
  int height = output_shape[2];
  int width = output_shape[3];

  int input_batch = input_shape[0];
  int input_channels = input_shape[1];
  int input_height = input_shape[2];
  int input_width = input_shape[3];

  int kernel_h = 3;
  int kernel_w = 3;

  int height_count = (height >> 1) << 1;
  for (int b = 0; b < batch; ++b) {
    float* output_ptr_base = output + b * channels * height * width;
    for (int oc = 0; oc < channels; ++oc) {
      const float* filter_ptr =
          filter + oc * input_channels * kernel_h * kernel_w;
      const float* input_ptr =
          input + b * input_channels * input_height * input_width;
      float* output_ptr = output_ptr_base + oc * height * width;

      std::fill(output_ptr, output_ptr + height * width, bias[oc]);
      for (int ic = 0; ic < input_channels; ++ic) {
        float32x4_t filter0 = vld1q_f32(filter_ptr);
        float32x4_t filter3 = vld1q_f32(filter_ptr + 3);
        float32x4_t filter6 = vld1q_f32(filter_ptr + 6);

        const float* row[kRegisterSize] = {
            input_ptr, input_ptr + input_width, input_ptr + 2 * input_width,
            input_ptr + 3 * input_width};

        float* output_ptr1 = output_ptr;
        float* output_ptr2 = output_ptr + width;
        for (int h = 0; h < height_count; h += 2) {
          int count = width >> 2;
          int remain_count = width & 3;
          for (; count > 0; --count) {
            float32x4_t sum0 = vdupq_n_f32(.0f);
            float32x4_t sum1 = vdupq_n_f32(.0f);
            float32x4_t row0_ext_0 = vld1q_f32(row[0]);  // 0123
            float32x4_t row0_latter =
                vld1q_f32(row[0] + kRegisterSize);  // 4567
            float32x4_t row0_ext_1 =
                vextq_f32(row0_ext_0, row0_latter, 1);  // 1234
            float32x4_t row0_ext_2 =
                vextq_f32(row0_ext_0, row0_latter, 2);  // 2345

            sum0 = vfmaq_laneq_f32(sum0, row0_ext_0, filter0, 0);
            sum0 = vfmaq_laneq_f32(sum0, row0_ext_1, filter0, 1);
            sum0 = vfmaq_laneq_f32(sum0, row0_ext_2, filter0, 2);

            float32x4_t row1_ext_0 = vld1q_f32(row[1]);  // 0123
            float32x4_t row1_latter =
                vld1q_f32(row[1] + kRegisterSize);  // 4567
            float32x4_t row1_ext_1 =
                vextq_f32(row1_ext_0, row1_latter, 1);  // 1234
            float32x4_t row1_ext_2 =
                vextq_f32(row1_ext_0, row1_latter, 2);  // 2345

            sum0 = vfmaq_laneq_f32(sum0, row1_ext_0, filter3, 0);
            sum0 = vfmaq_laneq_f32(sum0, row1_ext_1, filter3, 1);
            sum0 = vfmaq_laneq_f32(sum0, row1_ext_2, filter3, 2);

            row0_ext_0 = vld1q_f32(row[2]);                      // 0123
            row0_latter = vld1q_f32(row[2] + kRegisterSize);     // 4567
            row0_ext_1 = vextq_f32(row0_ext_0, row0_latter, 1);  // 1234
            row0_ext_2 = vextq_f32(row0_ext_0, row0_latter, 2);  // 2345

            sum0 = vfmaq_laneq_f32(sum0, row0_ext_0, filter6, 0);
            sum0 = vfmaq_laneq_f32(sum0, row0_ext_1, filter6, 1);
...
@@ -96,10 +99,10 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW
            sum1 = vfmaq_laneq_f32(sum1, row0_ext_1, filter3, 1);
            sum1 = vfmaq_laneq_f32(sum1, row0_ext_2, filter3, 2);

            row1_ext_0 = vld1q_f32(row[3]);                      // 0123
            row1_latter = vld1q_f32(row[3] + kRegisterSize);     // 4567
            row1_ext_1 = vextq_f32(row1_ext_0, row1_latter, 1);  // 1234
            row1_ext_2 = vextq_f32(row1_ext_0, row1_latter, 2);  // 2345

            sum1 = vfmaq_laneq_f32(sum1, row1_ext_0, filter6, 0);
            sum1 = vfmaq_laneq_f32(sum1, row1_ext_1, filter6, 1);
...
@@ -114,15 +117,15 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW
            output_ptr1 += kRegisterSize;
            output_ptr2 += kRegisterSize;
            for (int i = 0; i < kRegisterSize; ++i) {
              row[i] += kRegisterSize;
            }
          }
          for (; remain_count > 0; --remain_count) {
            float32x4_t row0 = vld1q_f32(row[0]);  // 0123
            float32x4_t row1 = vld1q_f32(row[1]);  // 0123
            float32x4_t row2 = vld1q_f32(row[2]);  // 0123
            float32x4_t row3 = vld1q_f32(row[3]);  // 0123

            float32x4_t sum = vmulq_f32(row0, filter0);
            sum = vmlaq_f32(sum, row1, filter3);
...
@@ -138,13 +141,13 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW
            ++output_ptr1;
            ++output_ptr2;
            for (int i = 0; i < kRegisterSize; ++i) {
              row[i] += 1;
            }
          }

          output_ptr1 += width;
          output_ptr2 += width;
          for (int i = 0; i < kRegisterSize; ++i) {
            row[i] += 2 + input_width;
          }
        }
...
@@ -152,30 +155,34 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW
        if (height != height_count) {
          int count = width >> 2;
          int remain_count = width & 3;
          for (; count > 0; --count) {
            float32x4_t sum0 = vdupq_n_f32(.0f);
            float32x4_t row0_ext_0 = vld1q_f32(row[0]);  // 0123
            float32x4_t row0_latter =
                vld1q_f32(row[0] + kRegisterSize);  // 4567
            float32x4_t row0_ext_1 =
                vextq_f32(row0_ext_0, row0_latter, 1);  // 1234
            float32x4_t row0_ext_2 =
                vextq_f32(row0_ext_0, row0_latter, 2);  // 2345

            sum0 = vfmaq_laneq_f32(sum0, row0_ext_0, filter0, 0);
            sum0 = vfmaq_laneq_f32(sum0, row0_ext_1, filter0, 1);
            sum0 = vfmaq_laneq_f32(sum0, row0_ext_2, filter0, 2);

            float32x4_t row1_ext_0 = vld1q_f32(row[1]);  // 0123
            float32x4_t row1_latter =
                vld1q_f32(row[1] + kRegisterSize);  // 4567
float32x4_t
row1_latter
=
vld1q_f32
(
row
[
1
]
+
kRegisterSize
);
// 4567
float32x4_t
row1_ext_1
=
vextq_f32
(
row1_ext_0
,
row1_latter
,
1
);
//1234
float32x4_t
row1_ext_1
=
float32x4_t
row1_ext_2
=
vextq_f32
(
row1_ext_0
,
row1_latter
,
2
);
//2345
vextq_f32
(
row1_ext_0
,
row1_latter
,
1
);
// 1234
float32x4_t
row1_ext_2
=
vextq_f32
(
row1_ext_0
,
row1_latter
,
2
);
// 2345
sum0
=
vfmaq_laneq_f32
(
sum0
,
row1_ext_0
,
filter3
,
0
);
sum0
=
vfmaq_laneq_f32
(
sum0
,
row1_ext_0
,
filter3
,
0
);
sum0
=
vfmaq_laneq_f32
(
sum0
,
row1_ext_1
,
filter3
,
1
);
sum0
=
vfmaq_laneq_f32
(
sum0
,
row1_ext_1
,
filter3
,
1
);
sum0
=
vfmaq_laneq_f32
(
sum0
,
row1_ext_2
,
filter3
,
2
);
sum0
=
vfmaq_laneq_f32
(
sum0
,
row1_ext_2
,
filter3
,
2
);
row0_ext_0
=
vld1q_f32
(
row
[
2
]);
//
0123
row0_ext_0
=
vld1q_f32
(
row
[
2
]);
//
0123
row0_latter
=
vld1q_f32
(
row
[
2
]
+
kRegisterSize
);
//
4567
row0_latter
=
vld1q_f32
(
row
[
2
]
+
kRegisterSize
);
//
4567
row0_ext_1
=
vextq_f32
(
row0_ext_0
,
row0_latter
,
1
);
//
1234
row0_ext_1
=
vextq_f32
(
row0_ext_0
,
row0_latter
,
1
);
//
1234
row0_ext_2
=
vextq_f32
(
row0_ext_0
,
row0_latter
,
2
);
//
2345
row0_ext_2
=
vextq_f32
(
row0_ext_0
,
row0_latter
,
2
);
//
2345
sum0
=
vfmaq_laneq_f32
(
sum0
,
row0_ext_0
,
filter6
,
0
);
sum0
=
vfmaq_laneq_f32
(
sum0
,
row0_ext_0
,
filter6
,
0
);
sum0
=
vfmaq_laneq_f32
(
sum0
,
row0_ext_1
,
filter6
,
1
);
sum0
=
vfmaq_laneq_f32
(
sum0
,
row0_ext_1
,
filter6
,
1
);
...
@@ -185,14 +192,14 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW
...
@@ -185,14 +192,14 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW
output_row0
=
vaddq_f32
(
output_row0
,
sum0
);
output_row0
=
vaddq_f32
(
output_row0
,
sum0
);
vst1q_f32
(
output_ptr1
,
output_row0
);
vst1q_f32
(
output_ptr1
,
output_row0
);
output_ptr1
+=
kRegisterSize
;
output_ptr1
+=
kRegisterSize
;
for
(
int
i
=
0
;
i
<
3
;
++
i
)
{
for
(
int
i
=
0
;
i
<
3
;
++
i
)
{
row
[
i
]
+=
kRegisterSize
;
row
[
i
]
+=
kRegisterSize
;
}
}
}
}
for
(;
remain_count
>
0
;
--
remain_count
)
{
for
(;
remain_count
>
0
;
--
remain_count
)
{
float32x4_t
row0
=
vld1q_f32
(
row
[
0
]);
//
0123
float32x4_t
row0
=
vld1q_f32
(
row
[
0
]);
//
0123
float32x4_t
row1
=
vld1q_f32
(
row
[
1
]);
//
0123
float32x4_t
row1
=
vld1q_f32
(
row
[
1
]);
//
0123
float32x4_t
row2
=
vld1q_f32
(
row
[
2
]);
//
0123
float32x4_t
row2
=
vld1q_f32
(
row
[
2
]);
//
0123
float32x4_t
sum
=
vmulq_f32
(
row0
,
filter0
);
float32x4_t
sum
=
vmulq_f32
(
row0
,
filter0
);
sum
=
vmlaq_f32
(
sum
,
row1
,
filter3
);
sum
=
vmlaq_f32
(
sum
,
row1
,
filter3
);
...
@@ -201,7 +208,7 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW
...
@@ -201,7 +208,7 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW
*
output_ptr1
=
vaddvq_f32
(
sum
);
*
output_ptr1
=
vaddvq_f32
(
sum
);
++
output_ptr1
;
++
output_ptr1
;
for
(
int
i
=
0
;
i
<
3
;
++
i
)
{
for
(
int
i
=
0
;
i
<
3
;
++
i
)
{
row
[
i
]
+=
1
;
row
[
i
]
+=
1
;
}
}
}
}
...
@@ -213,5 +220,5 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW
...
@@ -213,5 +220,5 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW
}
}
}
}
}
// namespace kernels
}
// namespace kernels
}
// namespace mace
}
// namespace mace
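The whole 3x3 kernel above is built from one sliding-window idiom: load a row twice, then let vextq_f32 materialize the x+1 and x+2 shifted views, so each 3-wide filter row costs three fused multiply-accumulates for four outputs. Below is a minimal stand-alone sketch of that idiom, assuming AArch64 NEON; the wrapper name Conv3TapX4 is illustrative, not from MACE.

#include <arm_neon.h>

// Computes out[0..3] = sum over k of in[k..k+3] * w[k] for k = 0, 1, 2.
// `in` must point at >= 8 valid floats; only w lanes 0..2 are used.
void Conv3TapX4(const float *in, const float32x4_t w, float *out) {
  float32x4_t x0 = vld1q_f32(in);         // elements 0123
  float32x4_t x4 = vld1q_f32(in + 4);     // elements 4567
  float32x4_t x1 = vextq_f32(x0, x4, 1);  // elements 1234
  float32x4_t x2 = vextq_f32(x0, x4, 2);  // elements 2345

  float32x4_t acc = vdupq_n_f32(0.f);
  acc = vfmaq_laneq_f32(acc, x0, w, 0);   // += in[i + 0] * w[0]
  acc = vfmaq_laneq_f32(acc, x1, w, 1);   // += in[i + 1] * w[1]
  acc = vfmaq_laneq_f32(acc, x2, w, 2);   // += in[i + 2] * w[2]
  vst1q_f32(out, acc);
}

The point of the trick is reuse: the two loads feed all three taps, so memory traffic stays at two vectors per four outputs.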
mace/kernels/neon/conv_2d_neon_5x5.cc  View file @ 578b382a
...
@@ -10,11 +10,11 @@
namespace mace {
namespace kernels {

void Conv2dNeonK5x5S1(const float* input,  // NCHW
                      const index_t* input_shape,
                      const float* filter,  // c_out, c_in, kernel_h, kernel_w
                      const float* bias,    // c_out
                      float* output,        // NCHW
                      const index_t* output_shape) {
  const index_t batch = output_shape[0];
  const index_t channels = output_shape[1];
...
@@ -30,17 +30,17 @@ void Conv2dNeonK5x5S1(const float* input, // NCHW
  const index_t input_total_pixels_per_channel = input_height * input_width;
  const index_t output_total_pixels_per_channel = height * width;
  const index_t input_total_pixels_per_batch =
      input_total_pixels_per_channel * input_channels;
  const index_t output_total_pixels_per_batch =
      output_total_pixels_per_channel * channels;
  const index_t patch_size = input_channels * 25;

#pragma omp parallel for collapse(2)
  for (index_t n = 0; n < batch; ++n) {
    for (index_t c = 0; c < channels; ++c) {
      float* output_ptr = output + n * output_total_pixels_per_batch +
                          c * output_total_pixels_per_channel;
      const float* input_ptr = input + n * input_total_pixels_per_batch;
      // Fill with bias
...
@@ -53,7 +53,7 @@ void Conv2dNeonK5x5S1(const float* input, // NCHW
        float* outptr2 = outptr + width;
        const float* inptr = input_ptr + inc * input_total_pixels_per_channel;
        const float* filter_ptr = filter + c * patch_size + inc * 25;

        const float* r0 = inptr;
        const float* r1 = inptr + input_width;
...
@@ -246,8 +246,8 @@ void Conv2dNeonK5x5S1(const float* input, // NCHW
        sum2 = r5[4] * k4[4];

        float32x2_t _ss = vadd_f32(vget_low_f32(_sum), vget_high_f32(_sum));
        float32x2_t _ss2 =
            vadd_f32(vget_low_f32(_sum2), vget_high_f32(_sum2));
        float32x2_t _ss_ss2 = vpadd_f32(_ss, _ss2);
        sum += vget_lane_f32(_ss_ss2, 0);
...
@@ -414,7 +414,7 @@ void Conv2dNeonK5x5S1(const float* input, // NCHW
}
}  // namespace kernels
}  // namespace mace
#endif  // MACE_KERNELS_NEON_CONV_2D_NEON_5X5_H_
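The _ss/_ss2/vpadd_f32 sequence in the @@ -246 hunk is the standard NEON horizontal reduction: split the 128-bit accumulator, fold it to two lanes, then pair two folds so one lane extract yields each final sum. A sketch of the single-vector form, with illustrative naming (on AArch64, vaddvq_f32, already used in the 3x3 kernel above, collapses this to one intrinsic):

#include <arm_neon.h>

float HorizontalSum(float32x4_t v) {
  float32x2_t half = vadd_f32(vget_low_f32(v), vget_high_f32(v));  // {v0+v2, v1+v3}
  float32x2_t pair = vpadd_f32(half, half);                        // {sum, sum}
  return vget_lane_f32(pair, 0);
}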
mace/kernels/neon/max_pooling_neon_2x2.cc  View file @ 578b382a
...
@@ -2,19 +2,17 @@
// Copyright (c) 2017 XiaoMi All rights reserved.
//

#include <arm_neon.h>
#include <float.h>
#include <limits>

#include "mace/core/common.h"

namespace mace {
namespace kernels {

void PoolingMaxNeonK2x2S2x2(const float *input,
                            const index_t *in_shape,
                            float *output,
                            const index_t *out_shape,
                            const int *paddings) {
  index_t batch = in_shape[0];
  index_t channels = in_shape[1];
...
@@ -44,7 +42,7 @@ void PoolingMaxNeonK2x2S2x2(const float *input,
      int w = 0;
      int num_vectors = 0;
      if (!((h == 0 && padding_top > 0) ||
            (h == out_height - 1 && padding_bottom > 0))) {
        r0 = input + input_offset + (h * 2 - padding_top) * in_width;
        r1 = r0 + in_width;
        if (padding_left > 0) {
...
@@ -86,8 +84,7 @@ void PoolingMaxNeonK2x2S2x2(const float *input,
          for (int kw = 0; kw < 2; ++kw) {
            int inh = h * 2 - padding_top + kh;
            int inw = w * 2 - padding_left + kw;
            if (inh >= 0 && inh < in_height && inw >= 0 && inw < in_width) {
              max = std::max(max, input[input_offset + inh * in_width + inw]);
            }
          }
...
@@ -104,10 +101,8 @@ void PoolingMaxNeonK2x2S2x2(const float *input,
}

// assume the input has already been padded
void PoolingMaxNeonK2x2S2x2Padded(const float *input,
                                  const index_t *in_shape,
                                  float *output, const index_t *out_shape) {
  index_t batch = in_shape[0];
  index_t channels = in_shape[1];
  index_t in_height = in_shape[2];
...
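For the vectorized fast path that the hunks above elide, the usual 2x2/stride-2 max-pooling idiom is a vertical vmaxq_f32 across the two rows followed by a pairwise vpmax_f32 across columns. An illustrative sketch under that assumption (not the literal MACE kernel; names are hypothetical):

#include <arm_neon.h>

// r0 and r1 point at the same 4 columns of two adjacent input rows;
// writes the 2 pooled outputs covering those columns.
void MaxPool2x2S2_TwoOutputs(const float *r0, const float *r1, float *out) {
  float32x4_t a = vld1q_f32(r0);
  float32x4_t b = vld1q_f32(r1);
  float32x4_t m = vmaxq_f32(a, b);  // column-wise max of the two rows
  float32x2_t p = vpmax_f32(vget_low_f32(m), vget_high_f32(m));
  // p = {max(m0, m1), max(m2, m3)}: the two 2x2 window maxima
  vst1_f32(out, p);
}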
mace/kernels/neon/max_pooling_neon_3x3.cc  View file @ 578b382a
...
@@ -2,19 +2,17 @@
// Copyright (c) 2017 XiaoMi All rights reserved.
//

#include <arm_neon.h>
#include <float.h>
#include <limits>

#include "mace/core/common.h"

namespace mace {
namespace kernels {

void PoolingMaxNeonK3x3S2x2(const float *input,
                            const index_t *in_shape,
                            float *output,
                            const index_t *out_shape,
                            const int *paddings) {
  index_t batch = in_shape[0];
  index_t channels = in_shape[1];
...
@@ -44,7 +42,7 @@ void PoolingMaxNeonK3x3S2x2(const float *input,
      int num_vectors = 0;
      const float *r0, *r1, *r2;
      if (!((h == 0 && padding_top > 0) ||
            (h == out_height - 1 && padding_bottom > 0))) {
        r0 = input + input_offset + (h * 2 - padding_top) * in_width;
        r1 = r0 + in_width;
        r2 = r1 + in_width;
...
@@ -112,8 +110,7 @@ void PoolingMaxNeonK3x3S2x2(const float *input,
          for (int kw = 0; kw < 3; ++kw) {
            int inh = h * 2 - padding_top + kh;
            int inw = w * 2 - padding_left + kw;
            if (inh >= 0 && inh < in_height && inw >= 0 && inw < in_width) {
              max = std::max(max, input[input_offset + inh * in_width + inw]);
            }
          }
...
@@ -130,10 +127,8 @@ void PoolingMaxNeonK3x3S2x2(const float *input,
}

// assume the input has already been padded
void PoolingMaxNeonK3x3S2x2Padded(const float *input,
                                  const index_t *in_shape,
                                  float *output, const index_t *out_shape) {
  index_t batch = in_shape[0];
  index_t channels = in_shape[1];
  index_t in_height = in_shape[2];
...
@@ -218,5 +213,5 @@ void PoolingMaxNeonK3x3S2x2Padded(const float *input,
}
}  // namespace kernels
}  // namespace mace
mace/kernels/neon/pooling_neon.cc  View file @ 578b382a
...
@@ -2,45 +2,36 @@
// Copyright (c) 2017 XiaoMi All rights reserved.
//

#include <arm_neon.h>

#include "mace/kernels/pooling.h"
#include "mace/kernels/conv_pool_2d_util.h"

namespace mace {
namespace kernels {

extern void PoolingMaxNeonK2x2S2x2(const float *input,
                                   const index_t *in_shape,
                                   float *output,
                                   const index_t *out_shape,
                                   const int *paddings);

extern void PoolingMaxNeonK3x3S2x2(const float *input,
                                   const index_t *in_shape,
                                   float *output,
                                   const index_t *out_shape,
                                   const int *paddings);

#ifdef __COPY_MAKE_PADDING
extern void PoolingMaxNeonK2x2S2x2Padded(const float *input,
                                         const index_t *in_shape,
                                         float *output,
                                         const index_t *out_shape);
extern void PoolingMaxNeonK3x3S2x2Padded(const float *input,
                                         const index_t *in_shape,
                                         float *output,
                                         const index_t *out_shape);
#endif

template <>
void PoolingFunctor<DeviceType::NEON, float>::operator()(
    const float *input, const index_t *input_shape, float *output,
    const index_t *output_shape) {
  if (kernels_[0] == 2 && kernels_[1] == 2 &&
      strides_[0] == 2 && strides_[1] == 2 &&
      pooling_type_ == MAX) {
#ifdef __COPY_MAKE_PADDING
    Tensor padded_input;
    ConstructInputWithPadding(input, input_shape, paddings_, &padded_input);
...
@@ -50,9 +41,8 @@ void PoolingFunctor<DeviceType::NEON, float>::operator()(
#else
    PoolingMaxNeonK2x2S2x2(input, input_shape, output, output_shape, paddings_);
#endif
  } else if (kernels_[0] == 3 && kernels_[1] == 3 &&
             strides_[0] == 2 && strides_[1] == 2 &&
             pooling_type_ == MAX) {
#ifdef __COPY_MAKE_PADDING
    Tensor padded_input;
    ConstructInputWithPadding(input, input_shape, paddings_, &padded_input);
...
@@ -65,13 +55,9 @@ void PoolingFunctor<DeviceType::NEON, float>::operator()(
  } else {  // not implement yet
    PoolingFunctor<DeviceType::CPU, float>(pooling_type_, kernels_, strides_,
                                           paddings_, dilations_)(
        input, input_shape, output, output_shape);
  }
}
}  // namespace kernels
}  // namespace mace
\ No newline at end of file
mace/kernels/neon/relu_neon.cc  View file @ 578b382a
...
@@ -2,17 +2,17 @@
// Copyright (c) 2017 XiaoMi All rights reserved.
//

#include <arm_neon.h>

#include "mace/kernels/relu.h"

namespace mace {
namespace kernels {

template <>
void ReluFunctor<DeviceType::NEON, float>::operator()(const float *input,
                                                      float *output,
                                                      index_t size) {
#pragma omp parallel for num_threads(1)  // no significant performance improve
  for (int64_t i = 0; i < size; i += kCostPerGroup) {
    int64_t count = std::min(static_cast<int64_t>(kCostPerGroup), size - i);
    int nn = count >> 2;
...
@@ -36,6 +36,5 @@ void ReluFunctor<DeviceType::NEON, float>::operator()(const float *input,
  }
};
}  // namespace kernels
}  // namespace mace
\ No newline at end of file
mace/kernels/pooling.h  View file @ 578b382a
...
@@ -11,29 +11,24 @@
namespace mace {

enum PoolingType {
  AVG = 1,  // avg_pool
  MAX = 2,  // max_pool
};

namespace kernels {

template <DeviceType D, typename T>
class PoolingFunctor {
 public:
  PoolingFunctor(const PoolingType pooling_type,
                 const int *kernels,
                 const int *strides,
                 const int *paddings,
                 const int *dilations)
      : pooling_type_(pooling_type),
        kernels_(kernels),
        strides_(strides),
        paddings_(paddings),
        dilations_(dilations) {}

  void operator()(const T *input,
                  const index_t *input_shape,
                  T *output,
                  const index_t *output_shape) {
    index_t batch = output_shape[0];
    index_t channels = output_shape[1];
...
@@ -60,32 +55,31 @@ class PoolingFunctor {
#pragma omp parallel for collapse(2)
    for (int n = 0; n < batch; ++n) {
      for (int c = 0; c < channels; ++c) {
        index_t out_offset =
            n * channels * height * width + c * height * width;
        index_t in_offset = n * input_channels * input_height * input_width +
                            c * input_height * input_width;
        for (int h = 0; h < height; ++h) {
          for (int w = 0; w < width; ++w) {
            T sum_or_max = 0;
            switch (pooling_type_) {
              case AVG:
                break;
              case MAX:
                sum_or_max = std::numeric_limits<T>::lowest();
                break;
              default:
                MACE_CHECK(false, "Unsupported pooling type: ", pooling_type_);
            }
            for (int kh = 0; kh < kernel_h; ++kh) {
              for (int kw = 0; kw < kernel_w; ++kw) {
                int inh = padded_h_start + h * stride_h + dilation_h * kh;
                int inw = padded_w_start + w * stride_w + dilation_w * kw;
                if (inh >= 0 && inh < input_height &&
                    inw >= 0 && inw < input_width) {
                  index_t input_offset = in_offset + inh * input_width + inw;
                  switch (pooling_type_) {
                    case AVG:
                      sum_or_max += input[input_offset];
                      break;
                    case MAX:
                      sum_or_max = std::max(sum_or_max, input[input_offset]);
...
@@ -98,14 +92,14 @@ class PoolingFunctor {
              }
            }
            switch (pooling_type_) {
              case AVG:
                output[out_offset] = sum_or_max / (kernel_h * kernel_w);
                break;
              case MAX:
                output[out_offset] = sum_or_max;
                break;
              default:
                MACE_CHECK(false, "Unsupported pooling type: ", pooling_type_);
            }
            out_offset += 1;
          }
...
@@ -122,14 +116,12 @@ class PoolingFunctor {
  const int *dilations_;
};

template <>
void PoolingFunctor<DeviceType::NEON, float>::operator()(
    const float *input, const index_t *input_shape, float *output,
    const index_t *output_shape);

}  // namespace kernels
}  // namespace mace

#endif  // MACE_KERNELS_POOLING_H
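Putting the two pieces together: the functor declared here is what pooling_neon.cc specializes, and callers construct it with the pooling type plus raw kernel/stride/padding/dilation arrays, then invoke it on NCHW buffers. A hedged usage sketch — the wrapper and the shape values are illustrative, and the two-element paddings follow the convention used throughout these files:

#include "mace/kernels/pooling.h"

// Illustrative only: runs 2x2/stride-2 MAX pooling, which the NEON
// specialization routes to PoolingMaxNeonK2x2S2x2.
void RunMaxPool2x2(const float *input, const mace::index_t input_shape[4],
                   float *output, const mace::index_t output_shape[4]) {
  const int kernels[2] = {2, 2};
  const int strides[2] = {2, 2};
  const int paddings[2] = {0, 0};
  const int dilations[2] = {1, 1};
  mace::kernels::PoolingFunctor<mace::DeviceType::NEON, float> pool(
      mace::MAX, kernels, strides, paddings, dilations);
  pool(input, input_shape, output, output_shape);
}

Any kernel/stride combination other than the two special-cased ones falls through to the generic CPU functor, as the `// not implement yet` branch above shows.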
mace/kernels/relu.h  View file @ 578b382a
...
@@ -10,7 +10,7 @@
namespace mace {
namespace kernels {

template <DeviceType D, typename T>
struct ReluFunctor {
  void operator()(const T *input, T *output, index_t size) {
    for (index_t i = 0; i < size; ++i) {
...
@@ -24,7 +24,7 @@ void ReluFunctor<DeviceType::NEON, float>::operator()(const float *input,
                                                      float *output,
                                                      index_t size);

}  // namespace kernels
}  // namespace mace

#endif  // MACE_KERNELS_RELU_H_
\ No newline at end of file
mace/kernels/resize_bilinear.h  View file @ 578b382a
...
@@ -22,8 +22,8 @@ struct CachedInterpolation {
inline float CalculateResizeScale(index_t in_size, index_t out_size,
                                  bool align_corners) {
  return (align_corners && out_size > 1)
             ? (in_size - 1) / static_cast<float>(out_size - 1)
             : in_size / static_cast<float>(out_size);
}

inline void ComputeInterpolationWeights(const index_t out_size,
...
@@ -41,21 +41,20 @@ inline void ComputeInterpolationWeights(const index_t out_size,
}

inline float ComputeLerp(const float top_left, const float top_right,
                         const float bottom_left, const float bottom_right,
                         const float x_lerp, const float y_lerp) {
  const float top = top_left + (top_right - top_left) * x_lerp;
  const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp;
  return top + (bottom - top) * y_lerp;
}

template <typename T>
void ResizeImage(const T *images,
                 const index_t batch_size,
                 const index_t in_height,
                 const index_t in_width,
                 const index_t out_height,
                 const index_t out_width,
                 const index_t channels,
                 const std::vector<CachedInterpolation> &xs_vec,
                 const std::vector<CachedInterpolation> &ys,
                 float *output) {
  const index_t in_channel_size = in_height * in_width;
  const index_t in_batch_num_values = channels * in_channel_size;
  const index_t out_channel_size = out_height * out_width;
...
@@ -65,10 +64,10 @@ void ResizeImage(const T *images,
#pragma omp parallel for collapse(2)
  for (index_t b = 0; b < batch_size; ++b) {
    for (index_t c = 0; c < channels; ++c) {
      const T *input_ptr =
          images + in_batch_num_values * b + in_channel_size * c;
      float *output_ptr =
          output + out_batch_num_values * b + out_channel_size * c;
      for (index_t y = 0; y < out_height; ++y) {
        const T *ys_input_lower_ptr = input_ptr + ys[y].lower * in_width;
        const T *ys_input_upper_ptr = input_ptr + ys[y].upper * in_width;
...
@@ -83,9 +82,8 @@ void ResizeImage(const T *images,
          const float bottom_left = ys_input_upper_ptr[xs_lower];
          const float bottom_right = ys_input_upper_ptr[xs_upper];

          output_ptr[x] = ComputeLerp(top_left, top_right, bottom_left,
                                      bottom_right, xs_lerp, ys_lerp);
        }
        output_ptr += out_width;
      }
...
@@ -94,16 +92,15 @@ void ResizeImage(const T *images,
  }
}

template <DeviceType D, typename T>
struct ResizeBilinearFunctor {
  bool align_corners_;

  ResizeBilinearFunctor(bool align_corners) : align_corners_(align_corners) {}

  void operator()(const T *input, T *output,
                  index_t n, index_t channels,
                  index_t in_height, index_t in_width,
                  index_t out_height, index_t out_width) {
    if (out_height == in_height && out_width == in_width) {
      std::copy(input, input + channels * in_height * in_width, output);
      return;
...
@@ -111,8 +108,8 @@ struct ResizeBilinearFunctor {
    float height_scale =
        CalculateResizeScale(in_height, out_height, align_corners_);
    float width_scale =
        CalculateResizeScale(in_width, out_width, align_corners_);

    std::vector<CachedInterpolation> ys(out_height + 1);
    std::vector<CachedInterpolation> xs(out_width + 1);
...
@@ -121,12 +118,12 @@ struct ResizeBilinearFunctor {
    ComputeInterpolationWeights(out_height, in_height, height_scale, ys.data());
    ComputeInterpolationWeights(out_width, in_width, width_scale, xs.data());

    ResizeImage(input, n, in_height, in_width, out_height, out_width,
                channels, xs, ys, output);
  }
};

}  // namespace kernels
}  // namespace mace

#endif  // MACE_KERNELS_RESIZE_BILINEAR_H_
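ComputeLerp above is a two-stage linear interpolation: blend horizontally along the top and the bottom sample rows, then blend the two results vertically. A worked instance with corners top_left = 1, top_right = 2, bottom_left = 3, bottom_right = 4 and x_lerp = y_lerp = 0.5:

float LerpDemo() {
  const float top = 1.f + (2.f - 1.f) * 0.5f;     // 1.5
  const float bottom = 3.f + (4.f - 3.f) * 0.5f;  // 3.5
  return top + (bottom - top) * 0.5f;             // 2.5, the mean of the corners
}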
mace/ops/addn.cc  View file @ 578b382a
...
@@ -10,6 +10,6 @@ REGISTER_CPU_OPERATOR(AddN, AddNOp<DeviceType::CPU, float>);
#if __ARM_NEON
REGISTER_NEON_OPERATOR(AddN, AddNOp<DeviceType::NEON, float>);
#endif  // __ARM_NEON

}  // namespace mace
mace/ops/addn.h  View file @ 578b382a
...
@@ -10,10 +10,10 @@
namespace mace {

template <DeviceType D, class T>
class AddNOp : public Operator<D, T> {
 public:
  AddNOp(const OperatorDef &operator_def, Workspace *ws)
      : Operator<D, T>(operator_def, ws) {}

  bool Run() override {
...
@@ -36,6 +36,6 @@ class AddNOp : public Operator<D, T> {
  kernels::AddNFunctor<D, T> functor_;
};

}  // namespace mace

#endif  // MACE_OPS_ADDN_H_
mace/ops/addn_benchmark.cc  View file @ 578b382a
...
@@ -10,7 +10,6 @@
namespace mace {

template <DeviceType D, typename T>
static void AddNBenchmark(int iters, int n, int size) {
  mace::testing::StopTiming();

  OpsTestNet net;
...
@@ -18,8 +17,7 @@ static void AddNBenchmark(int iters, int n, int size) {
  for (int i = 0; i < n; ++i) {
    op_def_builder.Input(internal::MakeString("Input", i).c_str());
  }
  op_def_builder.Output("Output").Finalize(net.operator_def());

  // Add input data
  for (int i = 0; i < n; ++i) {
...
@@ -32,27 +30,26 @@ static void AddNBenchmark(int iters, int n, int size) {
  }

  mace::testing::StartTiming();
  while (iters--) {
    net.RunOp(D);
  }
}

#define BM_ADDN_MACRO(N, SIZE, TYPE, DEVICE)                        \
  static void BM_ADDN_##N##_##SIZE##_##TYPE##_##DEVICE(int iters) { \
    const int64_t tot = static_cast<int64_t>(iters) * N * SIZE;     \
    mace::testing::ItemsProcessed(tot);                             \
    mace::testing::BytesProcessed(tot * (sizeof(TYPE)));            \
    AddNBenchmark<DEVICE, TYPE>(iters, N, SIZE);                    \
  }                                                                 \
  BENCHMARK(BM_ADDN_##N##_##SIZE##_##TYPE##_##DEVICE)

#define BM_ADDN(N, SIZE, TYPE)       \
  BM_ADDN_MACRO(N, SIZE, TYPE, CPU); \
  BM_ADDN_MACRO(N, SIZE, TYPE, NEON);

BM_ADDN(10, 1000, float);
BM_ADDN(10, 10000, float);
BM_ADDN(100, 1000, float);
BM_ADDN(100, 10000, float);

}  // namespace mace
\ No newline at end of file
mace/ops/addn_test.cc  View file @ 578b382a
...
@@ -36,4 +36,4 @@ TEST_F(AddnOpTest, AddnOp) {
  ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 0.01);
}

}  // namespace mace
mace/ops/batch_norm.cc  View file @ 578b382a
...
@@ -10,6 +10,6 @@ REGISTER_CPU_OPERATOR(BatchNorm, BatchNormOp<DeviceType::CPU, float>);
#if __ARM_NEON
REGISTER_NEON_OPERATOR(BatchNorm, BatchNormOp<DeviceType::NEON, float>);
#endif  // __ARM_NEON

}  // namespace mace
\ No newline at end of file
mace/ops/batch_norm.h  View file @ 578b382a
...
@@ -10,50 +10,55 @@
namespace mace {

template <DeviceType D, class T>
class BatchNormOp : public Operator<D, T> {
 public:
  BatchNormOp(const OperatorDef &operator_def, Workspace *ws)
      : Operator<D, T>(operator_def, ws),
        functor_(
            OperatorBase::GetSingleArgument<float>("variance_epsilon", 1e-4)) {}

  bool Run() override {
    const Tensor *input = this->Input(0);
    const Tensor *scale = this->Input(1);
    const Tensor *offset = this->Input(2);
    const Tensor *mean = this->Input(3);
    const Tensor *var = this->Input(4);

    MACE_CHECK(input->dim_size() == 4, "input must be 4-dimensional. ",
               input->dim_size());
    MACE_CHECK(scale->dim_size() == 1, "scale must be 1-dimensional. ",
               scale->dim_size());
    MACE_CHECK(offset->dim_size() == 1, "offset must be 1-dimensional. ",
               offset->dim_size());
    MACE_CHECK(mean->dim_size() == 1, "mean must be 1-dimensional. ",
               mean->dim_size());
    MACE_CHECK(var->dim_size() == 1, "var must be 1-dimensional. ",
               var->dim_size());

    Tensor *output = this->Output(0);
    output->ResizeLike(input);

    const index_t n = input->dim(0);
    const index_t channel = input->dim(1);
    const index_t sample_size = input->dim(2) * input->dim(3);

    const T *input_ptr = input->data<T>();
    const T *scale_ptr = scale->data<T>();
    const T *offset_ptr = offset->data<T>();
    const T *mean_ptr = mean->data<T>();
    const T *var_ptr = var->data<T>();
    T *output_ptr = output->mutable_data<T>();

    functor_(input_ptr, scale_ptr, offset_ptr, mean_ptr, var_ptr, n, channel,
             sample_size, output_ptr);
    return true;
  }

 private:
  kernels::BatchNormFunctor<D, T> functor_;
};

}  // namespace mace

#endif  // MACE_BATCH_NORM_H_
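For reference, a scalar sketch of what BatchNormFunctor is expected to compute per channel c — the standard inference-time transform y = scale[c] * (x - mean[c]) / sqrt(var[c] + variance_epsilon) + offset[c]. The actual kernel lives in mace/kernels/batch_norm.h and is not shown in this diff, so treat this as an assumption, not the MACE implementation:

#include <cmath>

void BatchNormScalar(const float *x, const float *scale, const float *offset,
                     const float *mean, const float *var, float epsilon,
                     int n, int channel, int sample_size, float *y) {
  for (int b = 0; b < n; ++b) {
    for (int c = 0; c < channel; ++c) {
      // Per-channel normalization factor, hoisted out of the inner loop.
      const float inv_std = 1.f / std::sqrt(var[c] + epsilon);
      const long base = (static_cast<long>(b) * channel + c) * sample_size;
      for (int i = 0; i < sample_size; ++i) {
        y[base + i] = scale[c] * (x[base + i] - mean[c]) * inv_std + offset[c];
      }
    }
  }
}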
mace/ops/batch_norm_benchmark.cc  View file @ 578b382a
...
@@ -8,19 +8,19 @@
namespace mace {

template <DeviceType D, typename T>
static void BatchNorm(int iters, int batch, int channels, int height,
                      int width) {
  mace::testing::StopTiming();

  OpsTestNet net;
  OpDefBuilder("BatchNorm", "BatchNormBM")
      .Input("Input")
      .Input("Scale")
      .Input("Offset")
      .Input("Mean")
      .Input("Var")
      .Output("Output")
      .Finalize(net.operator_def());

  // Add input data
  net.AddRandomInput<T>("Input", {batch, channels, height, width});
...
@@ -35,23 +35,23 @@ static void BatchNorm(int iters, int batch, int channels, int height, int width)
  }

  mace::testing::StartTiming();
  while (iters--) {
    net.RunOp(D);
  }
}

#define BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, DEVICE)                  \
  static void BM_BATCH_NORM_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE( \
      int iters) {                                                     \
    const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W;   \
    mace::testing::ItemsProcessed(tot);                                \
    mace::testing::BytesProcessed(tot *(sizeof(TYPE)));                \
    BatchNorm<DEVICE, TYPE>(iters, N, C, H, W);                        \
  }                                                                    \
  BENCHMARK(BM_BATCH_NORM_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)

#define BM_BATCH_NORM(N, C, H, W, TYPE)       \
  BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, CPU); \
  BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, NEON);

BM_BATCH_NORM(1, 1, 512, 512, float);
...
@@ -65,4 +65,4 @@ BM_BATCH_NORM(1, 128, 256, 256, float);
BM_BATCH_NORM(1, 128, 512, 512, float);
BM_BATCH_NORM(32, 1, 256, 256, float);
BM_BATCH_NORM(32, 3, 256, 256, float);

}  // namespace mace
\ No newline at end of file
mace/ops/batch_norm_test.cc  View file @ 578b382a
...
@@ -13,17 +13,17 @@ TEST_F(BatchNormOpTest, SimpleCPU) {
  // Construct graph
  auto &net = test_net();
  OpDefBuilder("BatchNorm", "BatchNormTest")
      .Input("Input")
      .Input("Scale")
      .Input("Offset")
      .Input("Mean")
      .Input("Var")
      .Output("Output")
      .Finalize(net.operator_def());

  // Add input data
  net.AddInputFromArray<float>("Input", {1, 1, 6, 2},
                               {5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15});
  net.AddInputFromArray<float>("Scale", {1}, {4.0f});
  net.AddInputFromArray<float>("Offset", {1}, {2.0});
  net.AddInputFromArray<float>("Mean", {1}, {10});
...
@@ -33,8 +33,8 @@ TEST_F(BatchNormOpTest, SimpleCPU) {
  net.RunOp();

  // Check
  auto expected =
      CreateTensor<float>({1, 1, 6, 2}, {-3.86, -3.86, -1.51, -1.51, 0.83, 0.83,
                                         3.17, 3.17, 5.51, 5.51, 7.86, 7.86});

  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.01);
...
@@ -51,13 +51,13 @@ TEST_F(BatchNormOpTest, SimpleNeon) {
  // Construct graph
  auto &net = test_net();
  OpDefBuilder("BatchNorm", "BatchNormTest")
      .Input("Input")
      .Input("Scale")
      .Input("Offset")
      .Input("Mean")
      .Input("Var")
      .Output("Output")
      .Finalize(net.operator_def());

  // Add input data
  net.AddRandomInput<float>("Input", {batch, channels, height, width});
...
@@ -77,5 +77,4 @@ TEST_F(BatchNormOpTest, SimpleNeon) {
  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
}
}
mace/ops/conv_2d.cc  View file @ 578b382a
...
@@ -11,6 +11,6 @@ REGISTER_CPU_OPERATOR(Conv2d, Conv2dOp<DeviceType::CPU, float>);
#if __ARM_NEON
REGISTER_NEON_OPERATOR(Conv2d, Conv2dOp<DeviceType::NEON, float>);
#endif  // __ARM_NEON

}  // namespace mace
mace/ops/conv_2d.h  View file @ 578b382a
...
@@ -13,11 +13,11 @@
namespace mace {

template <DeviceType D, typename T>
class Conv2dOp : public ConvPool2dOpBase<D, T> {
 public:
  Conv2dOp(const OperatorDef &op_def, Workspace *ws)
      : ConvPool2dOpBase<D, T>(op_def, ws) {};

  bool Run() override {
    const Tensor *input = this->Input(INPUT);
...
@@ -27,21 +27,16 @@ class Conv2dOp : public ConvPool2dOpBase<D, T> {
    std::vector<index_t> output_shape(4);
    std::vector<int> paddings(2);
    kernels::CalcPaddingAndOutputSize(
        input->shape().data(), filter->shape().data(),
        this->dilations_.data(), this->strides_.data(), this->padding_,
        output_shape.data(), paddings.data());
    output->Resize(output_shape);

    auto conv2d = kernels::Conv2dFunctor<D, T>(
        this->strides_.data(), paddings.data(), this->dilations_.data());
    conv2d(input->data<T>(), input->shape().data(), filter->data<T>(),
           filter->shape().data(), bias->data<T>(), output->mutable_data<T>(),
           output->shape().data());

    return true;
...
@@ -52,6 +47,6 @@ class Conv2dOp : public ConvPool2dOpBase<D, T> {
  OP_OUTPUT_TAGS(OUTPUT);
};

}  // namespace mace

#endif  // MACE_OPS_CONV_2D_H_
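CalcPaddingAndOutputSize is called here with the input and filter shapes, dilations, strides, and the padding mode, and fills in both the output shape and the concrete paddings. For reference, the output-size arithmetic it has to perform under the usual TensorFlow-style SAME/VALID convention — an assumption for illustration; the real rules live in mace/kernels/conv_pool_2d_util.cc:

// Assumed convention, for reference only.
int ConvOutputSize(int input, int kernel, int stride, int dilation,
                   bool same_padding) {
  const int effective_kernel = (kernel - 1) * dilation + 1;
  return same_padding ? (input + stride - 1) / stride             // ceil(in / s)
                      : (input - effective_kernel) / stride + 1;  // VALID
}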
mace/ops/conv_2d_benchmark.cc  View file @ 578b382a
...
@@ -13,17 +13,17 @@ namespace mace {
template <DeviceType D, typename T>
static void Conv2d(int iters, int batch, int channels, int height, int width,
                   int kernel_h, int kernel_w, int stride, Padding padding,
                   int output_channels) {
  mace::testing::StopTiming();

  OpsTestNet net;
  OpDefBuilder("Conv2d", "Conv2dTest")
      .Input("Input")
      .Input("Filter")
      .Input("Bias")
      .Output("Output")
      .Finalize(net.operator_def());

  // Add args
  net.AddIntsArg("strides", {stride, stride});
...
@@ -32,7 +32,8 @@ static void Conv2d(int iters, int batch, int channels, int height, int width,
  // Add input data
  net.AddRandomInput<float>("Input", {batch, channels, height, width});
  net.AddRandomInput<float>("Filter",
                            {output_channels, channels, kernel_h, kernel_w});
  net.AddRandomInput<float>("Bias", {output_channels});

  // Warm-up
...
@@ -41,27 +42,30 @@ static void Conv2d(int iters, int batch, int channels, int height, int width,
  }

  mace::testing::StartTiming();
  while (iters--) {
    net.RunOp(D);
  }
}

#define BM_CONV_2D_MACRO(N, C, H, W, KH, KW, STRIDE, P, OC, TYPE, DEVICE)      \
  static void                                                                  \
      BM_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_OC##_##TYPE##_##DEVICE( \
          int iters) {                                                         \
    const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W;           \
    mace::testing::ItemsProcessed(tot);                                        \
    mace::testing::BytesProcessed(tot*(sizeof(TYPE)));                         \
    Conv2d<DEVICE, TYPE>(iters, N, C, H, W, KH, KW, STRIDE, mace::Padding::P,  \
                         OC);                                                  \
  }                                                                            \
  BENCHMARK(                                                                   \
      BM_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_OC##_##TYPE##_##DEVICE)

#define BM_CONV_2D(N, C, H, W, KH, KW, S, P, OC, TYPE)       \
  BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, CPU); \
  BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, NEON);

BM_CONV_2D(1, 64, 32, 32, 1, 1, 1, VALID, 128, float);
BM_CONV_2D(1, 64, 33, 31, 1, 1, 1, VALID, 128, float);  // Test bad alignments
BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, VALID, 128, float);
BM_CONV_2D(1, 64, 33, 31, 3, 3, 1, VALID, 128, float);
BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, SAME, 128, float);
...
@@ -71,4 +75,4 @@ BM_CONV_2D(1, 64, 32, 31, 5, 5, 1, VALID, 128, float);
BM_CONV_2D(1, 64, 32, 32, 5, 5, 1, SAME, 128, float);
BM_CONV_2D(1, 64, 32, 31, 5, 5, 1, SAME, 128, float);

}  // namespace mace
mace/ops/conv_2d_test.cc  (view file @ 578b382a)
...
@@ -2,8 +2,8 @@
// Copyright (c) 2017 XiaoMi All rights reserved.
//

#include "mace/ops/conv_2d.h"
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"

using namespace mace;
...
@@ -14,11 +14,11 @@ TEST_F(Conv2dOpTest, Simple_VALID) {
  // Construct graph
  auto &net = test_net();
  OpDefBuilder("Conv2d", "Conv2dTest")
      .Input("Input")
      .Input("Filter")
      .Input("Bias")
      .Output("Output")
      .Finalize(net.operator_def());

  // Add args
  net.AddIntsArg("strides", {1, 1});
...
@@ -26,17 +26,13 @@ TEST_F(Conv2dOpTest, Simple_VALID) {
  net.AddIntsArg("dilations", {1, 1});

  // Add input data
  net.AddInputFromArray<float>(
      "Input", {1, 2, 3, 3},
      {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
  net.AddInputFromArray<float>(
      "Filter", {1, 2, 3, 3},
      {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
       1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f});
  net.AddInputFromArray<float>("Bias", {1}, {0.1f});

  // Run
...
@@ -52,11 +48,11 @@ TEST_F(Conv2dOpTest, Simple_SAME) {
  // Construct graph
  auto &net = test_net();
  OpDefBuilder("Conv2d", "Conv2dTest")
      .Input("Input")
      .Input("Filter")
      .Input("Bias")
      .Output("Output")
      .Finalize(net.operator_def());

  // Add args
  net.AddIntsArg("strides", {1, 1});
...
@@ -64,27 +60,22 @@ TEST_F(Conv2dOpTest, Simple_SAME) {
  net.AddIntsArg("dilations", {1, 1});

  // Add input data
  net.AddInputFromArray<float>(
      "Input", {1, 2, 3, 3},
      {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
  net.AddInputFromArray<float>(
      "Filter", {1, 2, 3, 3},
      {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
       1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f});
  net.AddInputFromArray<float>("Bias", {1}, {0.1f});

  // Run
  net.RunOp();

  // Check
  auto expected = CreateTensor<float>(
      {1, 1, 3, 3},
      {8.1f, 12.1f, 8.1f, 12.1f, 18.1f, 12.1f, 8.1f, 12.1f, 8.1f});

  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}
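A quick sanity check of the Simple_SAME expectation (an aside, not part of the diff): with all-ones input and an all-ones 3x3 filter over 2 channels, the centre output position covers 9 taps per channel, so 2 x 9 x 1.0 + 0.1 bias = 18.1f; the corners cover only a 2x2 valid region, 2 x 4 + 0.1 = 8.1f; and the edge midpoints cover 2x3 taps, 2 x 6 + 0.1 = 12.1f.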
...
@@ -93,11 +84,11 @@ TEST_F(Conv2dOpTest, Combined) {
  // Construct graph
  auto &net = test_net();
  OpDefBuilder("Conv2d", "Conv2dTest")
      .Input("Input")
      .Input("Filter")
      .Input("Bias")
      .Output("Output")
      .Finalize(net.operator_def());

  // Add args
  net.AddIntsArg("strides", {2, 2});
...
@@ -105,36 +96,24 @@ TEST_F(Conv2dOpTest, Combined) {
  net.AddIntsArg("dilations", {1, 1});

  // Add input data
  net.AddInputFromArray<float>(
      "Input", {1, 2, 5, 5},
      {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
  net.AddInputFromArray<float>(
      "Filter", {2, 2, 3, 3},
      {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
       1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
       0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,
       0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f});
  net.AddInputFromArray<float>("Bias", {2}, {0.1f, 0.2f});

  // Run
  net.RunOp();

  // Check
  auto expected = CreateTensor<float>(
      {1, 2, 3, 3},
      {8.1f, 12.1f, 8.1f, 12.1f, 18.1f, 12.1f, 8.1f, 12.1f, 8.1f,
       4.2f, 6.2f, 4.2f, 6.2f, 9.2f, 6.2f, 4.2f, 6.2f, 4.2f});

  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}
...
@@ -143,11 +122,11 @@ TEST_F(Conv2dOpTest, Conv1x1) {
  // Construct graph
  auto &net = test_net();
  OpDefBuilder("Conv2d", "Conv2dTest")
      .Input("Input")
      .Input("Filter")
      .Input("Bias")
      .Output("Output")
      .Finalize(net.operator_def());

  // Add args
  net.AddIntsArg("strides", {1, 1});
...
@@ -155,38 +134,32 @@ TEST_F(Conv2dOpTest, Conv1x1) {
  net.AddIntsArg("dilations", {1, 1});

  // Add input data
  net.AddInputFromArray<float>(
      "Input", {1, 5, 3, 10},
      {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
  net.AddInputFromArray<float>(
      "Filter", {2, 5, 1, 1},
      {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f});
  net.AddInputFromArray<float>("Bias", {2}, {0.1f, 0.2f});

  // Run
  net.RunOp();

  // Check
  auto expected = CreateTensor<float>(
      {1, 2, 3, 10},
      {5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f,
       5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f,
       5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f,
       10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f,
       10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f,
       10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f});

  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}
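Again as an aside (not part of the diff): the 1x1 convolution collapses the 5 input channels at each spatial position, so output channel 0 is 5 x 1.0 + 0.1 = 5.1f and output channel 1 is 5 x 2.0 + 0.2 = 10.2f, uniformly across all 3 x 10 positions.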
...
@@ -194,8 +167,7 @@ TEST_F(Conv2dOpTest, Conv1x1) {
// TODO we need more tests
TEST_F(Conv2dOpTest, ConvNxNS12) {
  testing::internal::LogToStderr();
  auto func = [&](int kernel_h, int kernel_w, int stride_h, int stride_w,
                  Padding type) {
    srand(time(NULL));
...
@@ -206,7 +178,7 @@ TEST_F(Conv2dOpTest, ConvNxNS12) {
    index_t width = 7 + rand() % 100;
    index_t output_channels = 1 + rand() % 50;
    // Construct graph
    auto &net = test_net();
    OpDefBuilder("Conv2d", "Conv2dTest")
        .Input("Input")
        .Input("Filter")
...
@@ -221,8 +193,8 @@ TEST_F(Conv2dOpTest, ConvNxNS12) {
    // Add input data
    net.AddRandomInput<float>("Input", {batch, input_channels, height, width});
    net.AddRandomInput<float>(
        "Filter", {output_channels, input_channels, kernel_h, kernel_w});
    net.AddRandomInput<float>("Bias", {output_channels});

    // run cpu
    net.RunOp();
...
mace/ops/conv_pool_2d_base.h  (view file @ 578b382a)
...
@@ -10,16 +10,15 @@
namespace mace {

template <DeviceType D, class T>
class ConvPool2dOpBase : public Operator<D, T> {
 public:
  ConvPool2dOpBase(const OperatorDef &op_def, Workspace *ws)
      : Operator<D, T>(op_def, ws),
        strides_(OperatorBase::GetRepeatedArgument<int>("strides")),
        padding_(static_cast<Padding>(OperatorBase::GetSingleArgument<int>(
            "padding", static_cast<int>(SAME)))),
        dilations_(OperatorBase::GetRepeatedArgument<int>("dilations")) {}

 protected:
  std::vector<int> strides_;
...
@@ -27,6 +26,6 @@ class ConvPool2dOpBase : public Operator<D, T> {
  std::vector<int> dilations_;
};

}  // namespace mace

#endif  // MACE_OPS_CONV_POOL_2D_BASE_H_
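To make the base class's role concrete, here is a minimal sketch of a derived op (hypothetical, written for illustration; the real subclasses in this tree are the conv and pooling ops). Note the `this->` qualification, which is required because the members live in a dependent base class:

  template <DeviceType D, class T>
  class MyConvLikeOp : public ConvPool2dOpBase<D, T> {
   public:
    MyConvLikeOp(const OperatorDef &op_def, Workspace *ws)
        : ConvPool2dOpBase<D, T>(op_def, ws) {}
    bool Run() override {
      // strides_, padding_ and dilations_ are already parsed by the base.
      return this->strides_.size() == 2 && this->dilations_.size() == 2;
    }
  };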
mace/ops/ops_test_util.h  (view file @ 578b382a)
...
@@ -43,31 +43,33 @@ class OpsTestNet {
 public:
  OpsTestNet() {}

  template <typename T>
  void AddInputFromArray(const char *name,
                         const std::vector<index_t> &shape,
                         const std::vector<T> &data) {
    Tensor *input =
        ws_.CreateTensor(name, cpu_allocator(), DataTypeToEnum<T>::v());
    input->Resize(shape);
    T *input_data = input->mutable_data<T>();
    MACE_CHECK(input->size() == data.size());
    memcpy(input_data, data.data(), data.size() * sizeof(T));
  }

  template <typename T>
  void AddRepeatedInput(const char *name,
                        const std::vector<index_t> &shape,
                        const T data) {
    Tensor *input =
        ws_.CreateTensor(name, cpu_allocator(), DataTypeToEnum<T>::v());
    input->Resize(shape);
    T *input_data = input->mutable_data<T>();
    MACE_CHECK(input->size() == data.size());
    std::fill(input_data, input_data + input->size(), data);
  }

  template <typename T>
  void AddRandomInput(const char *name, const std::vector<index_t> &shape,
                      bool positive = false) {
    Tensor *input =
        ws_.CreateTensor(name, cpu_allocator(), DataTypeToEnum<T>::v());
    input->Resize(shape);
    float *input_data = input->mutable_data<T>();
...
@@ -76,12 +78,16 @@ class OpsTestNet {
    std::normal_distribution<T> nd(0, 1);
    std::generate(input_data, input_data + input->size(),
                  [&gen, &nd, positive] {
                    return positive ? std::abs(nd(gen)) : nd(gen);
                  });
  }

  template <typename T>
  void AddFixedInput(const char *name, const std::vector<index_t> &shape,
                     T value) {
    Tensor *input =
        ws_.CreateTensor(name, cpu_allocator(), DataTypeToEnum<T>::v());
    input->Resize(shape);
    float *input_data = input->mutable_data<T>();
...
@@ -122,7 +128,8 @@ class OpsTestNet {
    }
  }

  void AddStringsArg(const char *name,
                     const std::vector<const char *> &values) {
    auto arg = op_def_.add_arg();
    arg->set_name(name);
    for (auto value : values) {
...
@@ -145,9 +152,7 @@ class OpsTestNet {
    return net_->Run();
  }

  bool RunOp() { return RunOp(DeviceType::CPU); }

  Tensor *GetOutput(const char *output_name) {
    return ws_.GetTensor(output_name);
...
@@ -177,8 +182,9 @@ class OpsTestBase : public ::testing::Test {
  OpsTestNet test_net_;
};

template <typename T>
unique_ptr<Tensor> CreateTensor(const std::vector<index_t> &shape,
                                const std::vector<T> &data) {
  unique_ptr<Tensor> res(new Tensor(cpu_allocator(), DataTypeToEnum<T>::v()));
  res->Resize(shape);
  T *input_data = res->mutable_data<T>();
...
@@ -209,40 +215,38 @@ inline std::string ShapeToString(const Tensor &x) {
  return std::string(stream.str());
}

template <typename T>
struct is_floating_point_type {
  static const bool value = std::is_same<T, float>::value ||
                            std::is_same<T, double>::value;
};

template <typename T>
inline void ExpectEqual(const T &a, const T &b) {
  EXPECT_EQ(a, b);
}

template <>
inline void ExpectEqual<float>(const float &a, const float &b) {
  EXPECT_FLOAT_EQ(a, b);
}

template <>
inline void ExpectEqual<double>(const double &a, const double &b) {
  EXPECT_DOUBLE_EQ(a, b);
}

inline void AssertSameTypeDims(const Tensor &x, const Tensor &y) {
  ASSERT_EQ(x.dtype(), y.dtype());
  ASSERT_TRUE(IsSameSize(x, y))
      << "x.shape [" << ShapeToString(x) << "] vs "
      << "y.shape [ " << ShapeToString(y) << "]";
}

template <typename T, bool is_fp = is_floating_point_type<T>::value>
struct Expector;

// Partial specialization for float and double.
template <typename T>
struct Expector<T, true> {
  static void Equal(const T &a, const T &b) { ExpectEqual(a, b); }
...
@@ -262,18 +266,19 @@ struct Expector<T, true> {
    auto a = x.data<T>();
    auto b = y.data<T>();
    for (int i = 0; i < x.size(); ++i) {
      EXPECT_NEAR(a[i], b[i], abs_err)
          << "a = " << a << " b = " << b << " index = " << i;
    }
  }
};

template <typename T>
void ExpectTensorNear(const Tensor &x, const Tensor &y, const double abs_err) {
  static_assert(is_floating_point_type<T>::value,
                "T is not a floating point type");
  Expector<T>::Near(x, y, abs_err);
}
}  // namespace mace

#endif  // MACE_OPS_TEST_UTIL_H_
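Taken together, these helpers support checks like the following (a usage sketch with made-up values, not taken from the diff):

  auto a = CreateTensor<float>({1, 2}, {1.0f, 2.0f});
  auto b = CreateTensor<float>({1, 2}, {1.0004f, 1.9996f});
  ExpectTensorNear<float>(*a, *b, 0.001);  // passes: every |a[i] - b[i]| <= 0.001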
mace/ops/pooling.cc  (view file @ 578b382a)
...
@@ -2,7 +2,6 @@
// Copyright (c) 2017 XiaoMi All rights reserved.
//

#include "mace/ops/pooling.h"

namespace mace {
...
@@ -11,6 +10,6 @@ REGISTER_CPU_OPERATOR(Pooling, PoolingOp<DeviceType::CPU, float>);
#if __ARM_NEON
REGISTER_NEON_OPERATOR(Pooling, PoolingOp<DeviceType::NEON, float>);
#endif  // __ARM_NEON

}  // namespace mace
mace/ops/pooling.h  (view file @ 578b382a)
...
@@ -11,17 +11,17 @@
namespace mace {

template <DeviceType D, class T>
class PoolingOp : public ConvPool2dOpBase<D, T> {
 public:
  PoolingOp(const OperatorDef &op_def, Workspace *ws)
      : ConvPool2dOpBase<D, T>(op_def, ws),
        kernels_(OperatorBase::GetRepeatedArgument<int>("kernels")),
        pooling_type_(static_cast<PoolingType>(
            OperatorBase::GetSingleArgument<int>(
                "pooling_type", static_cast<int>(AVG)))){};

  bool Run() override {
    const Tensor *input = this->Input(INPUT);
    Tensor *output = this->Output(OUTPUT);
    std::vector<index_t> in_shape = input->shape();
...
@@ -33,28 +33,21 @@
    filter_shape[1] = in_shape[0];
    filter_shape[2] = kernels_[0];
    filter_shape[3] = kernels_[1];
    kernels::CalcPaddingAndOutputSize(
        in_shape.data(), filter_shape.data(), this->dilations_.data(),
        this->strides_.data(), this->padding_, output_shape.data(),
        paddings.data());
    output->Resize(output_shape);

    auto pooling_func = kernels::PoolingFunctor<D, T>(
        pooling_type_, kernels_.data(), this->strides_.data(), paddings.data(),
        this->dilations_.data());
    pooling_func(input->data<float>(), in_shape.data(),
                 output->mutable_data<float>(), output->shape().data());
    return true;
  };

 protected:
  std::vector<int> kernels_;
  PoolingType pooling_type_;
...
@@ -62,6 +55,6 @@
  OP_OUTPUT_TAGS(OUTPUT);
};

}  // namespace mace

#endif  // MACE_OPS_POOLING_H_
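One detail worth calling out (an observation on the code above, not part of the diff): PoolingOp reuses the convolution helper kernels::CalcPaddingAndOutputSize by packing its kernel sizes into a synthetic filter shape (filter_shape[2] = kernels_[0], filter_shape[3] = kernels_[1]), so pooling and convolution share a single padding and output-size computation.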
mace/ops/pooling_benchmark.cc  (view file @ 578b382a)
...
@@ -2,20 +2,19 @@
// Copyright (c) 2017 XiaoMi All rights reserved.
//

#include "mace/core/operator.h"
#include "mace/kernels/pooling.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/kernels/conv_pool_2d_util.h"
#include "mace/ops/ops_test_util.h"

using namespace mace;
using namespace mace::kernels;

template <DeviceType D>
static void Pooling(int iters, int batch, int channels, int height, int width,
                    int kernel, int stride, Padding padding,
                    PoolingType pooling_type) {
  mace::testing::StopTiming();

  OpsTestNet net;
...
@@ -45,18 +44,21 @@ static void Pooling(int iters, int batch, int channels, int height,
  }
}

#define BM_POOLING_MACRO(N, C, H, W, KE, STRIDE, PA, PO, DEVICE)    \
  static void                                                       \
      BM_POOLING_##N##_##C##_##H##_##W##_K##KE##S##STRIDE##_##PA##_##PO##_##DEVICE( \
          int iters) {                                              \
    const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
    mace::testing::ItemsProcessed(tot);                             \
    mace::testing::BytesProcessed(tot * (sizeof(float)));           \
    Pooling<DEVICE>(iters, N, C, H, W, KE, STRIDE, Padding::PA,     \
                    PoolingType::PO);                               \
  }                                                                 \
  BENCHMARK(                                                        \
      BM_POOLING_##N##_##C##_##H##_##W##_K##KE##S##STRIDE##_##PA##_##PO##_##DEVICE)

#define BM_POOLING(N, C, H, W, K, S, PA, PO)       \
  BM_POOLING_MACRO(N, C, H, W, K, S, PA, PO, CPU); \
  BM_POOLING_MACRO(N, C, H, W, K, S, PA, PO, NEON);
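As with BM_CONV_2D above, each BM_POOLING line therefore registers two benchmarks, one per device: the invocation below produces a ..._CPU and a ..._NEON variant of the same 1x3x129x129 max-pooling workload.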
BM_POOLING(1, 3, 129, 129, 2, 2, SAME, MAX);
...
mace/ops/pooling_test.cc  (view file @ 578b382a)
...
@@ -5,9 +5,9 @@
#include "gtest/gtest.h"
#include "mace/core/operator.h"
#include "mace/ops/conv_pool_2d_base.h"
#include "mace/kernels/pooling.h"
#include "mace/ops/ops_test_util.h"

using namespace mace;
...
@@ -17,9 +17,9 @@ TEST_F(PoolingOpTest, MAX_VALID) {
  // Construct graph
  auto &net = test_net();
  OpDefBuilder("Pooling", "PoolingTest")
      .Input("Input")
      .Output("Output")
      .Finalize(net.operator_def());

  // Add args
  net.AddIntsArg("kernels", {2, 2});
...
@@ -29,34 +29,28 @@ TEST_F(PoolingOpTest, MAX_VALID) {
  net.AddIntArg("pooling_type", PoolingType::MAX);

  // Add input data
  net.AddInputFromArray<float>(
      "Input", {1, 2, 4, 4},
      {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
       18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31});

  // Run
  net.RunOp();

  // Check
  auto expected =
      CreateTensor<float>({1, 2, 2, 2}, {5, 7, 13, 15, 21, 23, 29, 31});

  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}
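Sanity check of the MAX_VALID expectation (not part of the diff): each channel is a 4x4 ramp, so with the 2x2 kernel (the stride arguments are elided in the hunk above) the block maxima of channel 0 are 5, 7, 13, 15, and channel 1 holds the same values shifted by 16, giving 21, 23, 29, 31.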
TEST_F(PoolingOpTest, AVG_VALID) {
  // Construct graph
  auto &net = test_net();
  OpDefBuilder("Pooling", "PoolingTest")
      .Input("Input")
      .Output("Output")
      .Finalize(net.operator_def());

  // Add args
  net.AddIntsArg("kernels", {2, 2});
...
@@ -66,22 +60,17 @@ TEST_F(PoolingOpTest, AVG_VALID) {
  net.AddIntArg("pooling_type", PoolingType::AVG);

  // Add input data
  net.AddInputFromArray<float>(
      "Input", {1, 2, 4, 4},
      {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
       18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31});

  // Run
  net.RunOp();

  // Check
  auto expected = CreateTensor<float>(
      {1, 2, 2, 2}, {2.5, 4.5, 10.5, 12.5, 18.5, 20.5, 26.5, 28.5});

  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}
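Similarly for AVG_VALID: the first 2x2 block of channel 0 is {0, 1, 4, 5}, whose mean is 2.5; the remaining outputs follow the same pattern, with channel 1 again offset by 16.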
...
@@ -90,9 +79,9 @@ TEST_F(PoolingOpTest, MAX_SAME) {
  // Construct graph
  auto &net = test_net();
  OpDefBuilder("Pooling", "PoolingTest")
      .Input("Input")
      .Output("Output")
      .Finalize(net.operator_def());

  // Add args
  net.AddIntsArg("kernels", {2, 2});
...
@@ -103,16 +92,13 @@ TEST_F(PoolingOpTest, MAX_SAME) {
  // Add input data
  net.AddInputFromArray<float>("Input", {1, 1, 3, 3},
                               {0, 1, 2, 3, 4, 5, 6, 7, 8});

  // Run
  net.RunOp();

  // Check
  auto expected = CreateTensor<float>({1, 1, 2, 2}, {4, 5, 7, 8});

  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}
...
@@ -121,9 +107,9 @@ TEST_F(PoolingOpTest, MAX_VALID_DILATION) {
  // Construct graph
  auto &net = test_net();
  OpDefBuilder("Pooling", "PoolingTest")
      .Input("Input")
      .Output("Output")
      .Finalize(net.operator_def());

  // Add args
  net.AddIntsArg("kernels", {2, 2});
...
@@ -133,18 +119,15 @@ TEST_F(PoolingOpTest, MAX_VALID_DILATION) {
  net.AddIntArg("pooling_type", PoolingType::MAX);

  // Add input data
  net.AddInputFromArray<float>(
      "Input", {1, 1, 4, 4},
      {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});

  // Run
  net.RunOp();

  // Check
  auto expected = CreateTensor<float>({1, 1, 2, 2}, {10, 11, 14, 15});

  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}
...
@@ -153,9 +136,9 @@ TEST_F(PoolingOpTest, MAX_k2x2s2x2) {
  // Construct graph
  auto &net = test_net();
  OpDefBuilder("Pooling", "PoolingTest")
      .Input("Input")
      .Output("Output")
      .Finalize(net.operator_def());

  // Add args
  net.AddIntArg("pooling_type", PoolingType::MAX);
...
@@ -165,18 +148,14 @@ TEST_F(PoolingOpTest, MAX_k2x2s2x2) {
  net.AddIntsArg("dilations", {1, 1});

  // Add input data
  net.AddInputFromArray<float>(
      "Input", {1, 1, 4, 5},
      {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
       11, 12, 13, 14, 15, 16, 17, 18, 19});

  // Run
  net.RunOp(DeviceType::NEON);

  // Check
  Tensor expected = CreateTensor<float>({1, 1, 2, 3}, {6, 8, 9, 16, 18, 19});

  ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 0.001);
}
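Checking MAX_k2x2s2x2 by hand (not part of the diff): the input is a 4x5 ramp 0..19; with a 2x2 window and stride 2 the top two rows yield block maxima 6 and 8, the odd right-hand column contributes 9, and the bottom two rows yield 16, 18 and 19.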
...
@@ -185,9 +164,9 @@ TEST_F(PoolingOpTest, MAX_k3x3s2x2) {
  // Construct graph
  auto &net = test_net();
  OpDefBuilder("Pooling", "PoolingTest")
      .Input("Input")
      .Output("Output")
      .Finalize(net.operator_def());

  // Add args
  net.AddIntArg("pooling_type", PoolingType::MAX);
...
@@ -197,18 +176,14 @@ TEST_F(PoolingOpTest, MAX_k3x3s2x2) {
  net.AddIntsArg("dilations", {1, 1});

  // Add input data
  net.AddInputFromArray<float>(
      "Input", {1, 1, 4, 5},
      {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
       11, 12, 13, 14, 15, 16, 17, 18, 19});

  // Run
  net.RunOp(DeviceType::NEON);

  // Check
  Tensor expected = CreateTensor<float>({1, 1, 2, 3}, {11, 13, 14, 16, 18, 19});

  ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 0.001);
}
mace/ops/relu.cc  (view file @ 578b382a)
...
@@ -10,6 +10,6 @@ REGISTER_CPU_OPERATOR(Relu, ReluOp<DeviceType::CPU, float>);
#if __ARM_NEON
REGISTER_NEON_OPERATOR(Relu, ReluOp<DeviceType::NEON, float>);
#endif  // __ARM_NEON

}  // namespace mace
mace/ops/relu.h  (view file @ 578b382a)
...
@@ -10,10 +10,10 @@
namespace mace {

template <DeviceType D, class T>
class ReluOp : public Operator<D, T> {
 public:
  ReluOp(const OperatorDef &operator_def, Workspace *ws)
      : Operator<D, T>(operator_def, ws) {}

  bool Run() override {
    const Tensor *input_tensor = this->inputs_[0];
...
@@ -31,6 +31,6 @@ class ReluOp : public Operator<D, T> {
  kernels::ReluFunctor<D, T> functor_;
};

}  // namespace mace

#endif  // MACE_OPS_RELU_H_
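For orientation, the functor this op delegates to computes the usual rectifier. A scalar reference sketch (written here as an illustration; it is not the mace/kernels implementation, which also carries a NEON path):

  #include <algorithm>
  #include <cstdint>

  // relu(x) = max(x, 0), applied elementwise over the tensor's buffer.
  void ReluReference(const float *input, float *output, int64_t size) {
    for (int64_t i = 0; i < size; ++i) {
      output[i] = std::max(input[i], 0.0f);
    }
  }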
mace/ops/relu_benchmark.cc  (view file @ 578b382a)
...
@@ -10,7 +10,6 @@
namespace mace {

template <DeviceType D, typename T>
static void ReluBenchmark(int iters, int size) {
  mace::testing::StopTiming();

  OpsTestNet net;
...
@@ -28,26 +27,25 @@ static void ReluBenchmark(int iters, int size) {
  }

  mace::testing::StartTiming();
  while (iters--) {
    net.RunOp(D);
  }
}

#define BM_RELU_MACRO(SIZE, TYPE, DEVICE)                     \
  static void BM_RELU_##SIZE##_##TYPE##_##DEVICE(int iters) { \
    const int64_t tot = static_cast<int64_t>(iters) * SIZE;   \
    mace::testing::ItemsProcessed(tot);                       \
    mace::testing::BytesProcessed(tot * (sizeof(TYPE)));      \
    ReluBenchmark<DEVICE, TYPE>(iters, SIZE);                 \
  }                                                           \
  BENCHMARK(BM_RELU_##SIZE##_##TYPE##_##DEVICE)

#define BM_RELU(SIZE, TYPE)       \
  BM_RELU_MACRO(SIZE, TYPE, CPU); \
  BM_RELU_MACRO(SIZE, TYPE, NEON);

BM_RELU(1000, float);
BM_RELU(100000, float);
BM_RELU(10000000, float);

}  // namespace mace
\ No newline at end of file
mace/ops/relu_test.cc  (view file @ 578b382a)
...
@@ -32,4 +32,4 @@ TEST_F(ReluOpTest, ReluOp) {
  ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 0.01);
}

}  // namespace mace
mace/ops/resize_bilinear.cc  (view file @ 578b382a)
...
@@ -9,7 +9,8 @@ namespace mace {
REGISTER_CPU_OPERATOR(ResizeBilinear, ResizeBilinearOp<DeviceType::CPU, float>);

#if __ARM_NEON
REGISTER_NEON_OPERATOR(ResizeBilinear,
                       ResizeBilinearOp<DeviceType::NEON, float>);
#endif  // __ARM_NEON

}  // namespace mace
mace/ops/resize_bilinear.h  (view file @ 578b382a)
...
@@ -5,18 +5,18 @@
#ifndef MACE_RESIZE_BILINEAR_H
#define MACE_RESIZE_BILINEAR_H

#include "mace/core/operator.h"
#include "mace/kernels/resize_bilinear.h"

namespace mace {

template <DeviceType D, class T>
class ResizeBilinearOp : public Operator<D, T> {
 public:
  ResizeBilinearOp(const OperatorDef &operator_def, Workspace *ws)
      : Operator<D, T>(operator_def, ws),
        functor_(
            OperatorBase::GetSingleArgument<bool>("align_corners", false)) {}

  bool Run() override {
    const Tensor *input = this->Input(0);
...
@@ -24,8 +24,8 @@ class ResizeBilinearOp : public Operator<D, T> {
    MACE_CHECK(input->dim_size() == 4, "input must be 4-dimensional.",
               input->dim_size());
    MACE_CHECK(resize_dims->dim_size() == 1,
               "resize dim must be 2-dimensional.", resize_dims->dim_size());
    Tensor *output = this->Output(0);
...
@@ -35,7 +35,7 @@ class ResizeBilinearOp : public Operator<D, T> {
    index_t in_width = input->dim(3);
    index_t out_height = resize_dims->data<index_t>()[0];
    index_t out_width = resize_dims->data<index_t>()[1];
    vector<index_t> out_shape{n, channels, out_height, out_width};
    output->Resize(out_shape);
    const T *input_ptr = input->data<T>();
...
@@ -45,10 +45,11 @@ class ResizeBilinearOp : public Operator<D, T> {
                 out_height, out_width);
    return true;
  }

 private:
  kernels::ResizeBilinearFunctor<D, T> functor_;
};

}  // namespace mace

#endif  // MACE_RESIZE_BILINEAR_H
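One piece of context for the align_corners flag parsed above: bilinear resizing maps each output coordinate back to an input coordinate through a scale factor, and align_corners changes how that scale is derived. A sketch of the conventional computation (an illustration based on the common TensorFlow-style semantics; the actual kernel lives in mace/kernels/resize_bilinear.h):

  #include <cstdint>

  // With align_corners the first and last samples of input and output
  // coincide, so the scale uses (size - 1); otherwise it is a plain ratio.
  inline float CalculateResizeScale(int64_t in_size, int64_t out_size,
                                    bool align_corners) {
    return (align_corners && out_size > 1)
               ? (in_size - 1) / static_cast<float>(out_size - 1)
               : in_size / static_cast<float>(out_size);
  }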
mace/ops/resize_bilinear_test.cc  (view file @ 578b382a)
...
@@ -2,9 +2,9 @@
// Copyright (c) 2017 XiaoMi All rights reserved.
//

#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
#include "mace/ops/resize_bilinear.h"

using namespace mace;
...