Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Xiaomi
Mace
提交
3f803f84
Mace
项目概览
Xiaomi
/
Mace
通知
106
Star
40
Fork
27
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
3f803f84
编写于
2月 13, 2018
作者:
刘
刘琦
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'bn' into 'master'
Update BatchNorm CPU kernel See merge request !238
上级
ac6fc00c
292f80aa
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
43 addition
and
22 deletion
+43
-22
mace/kernels/BUILD
mace/kernels/BUILD
+1
-3
mace/kernels/batch_norm.h
mace/kernels/batch_norm.h
+41
-10
mace/ops/batch_norm.cc
mace/ops/batch_norm.cc
+1
-9
未找到文件。
mace/kernels/BUILD
浏览文件 @
3f803f84
...
...
@@ -14,9 +14,7 @@ cc_library(
srcs
=
glob
([
"*.cc"
,
"opencl/*.cc"
,
])
+
if_neon_enabled
(
glob
([
"neon/batch_norm_neon.cc"
,
])),
]),
hdrs
=
glob
([
"*.h"
,
"opencl/*.h"
,
...
...
mace/kernels/batch_norm.h
浏览文件 @
3f803f84
...
...
@@ -5,11 +5,15 @@
#ifndef MACE_KERNELS_BATCH_NORM_H_
#define MACE_KERNELS_BATCH_NORM_H_
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
#include <arm_neon.h>
#endif
#include "mace/core/future.h"
#include "mace/core/public/mace.h"
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/tensor.h"
#include "mace/kernels/activation.h"
#include "mace/core/runtime/opencl/cl2_header.h"
namespace
mace
{
namespace
kernels
{
...
...
@@ -86,17 +90,44 @@ struct BatchNormFunctor : BatchNormFunctorBase {
}
}
const
T
*
scale_data
=
folded_constant_
?
scale_ptr
:
new_scale
.
data
();
const
T
*
offset_data
=
folded_constant_
?
offset_ptr
:
new_offset
.
data
();
#pragma omp parallel for collapse(4)
for
(
index_t
n
=
0
;
n
<
batch
;
++
n
)
{
for
(
index_t
h
=
0
;
h
<
height
;
++
h
)
{
for
(
index_t
w
=
0
;
w
<
width
;
++
w
)
{
for
(
index_t
c
=
0
;
c
<
channels
;
++
c
)
{
index_t
pos
=
(((
n
*
height
)
+
h
)
*
width
+
w
)
*
channels
+
c
;
const
T
*
scale_data
=
folded_constant_
?
scale_ptr
:
new_scale
.
data
();
const
T
*
offset_data
=
folded_constant_
?
offset_ptr
:
new_offset
.
data
();
const
int
elements
=
batch
*
height
*
width
;
constexpr
int
c_tile_size
=
4
;
const
int
c_tiles
=
channels
/
c_tile_size
;
const
index_t
remains_start
=
c_tiles
*
c_tile_size
;
if
(
c_tiles
>
0
)
{
#pragma omp parallel for collapse(2)
for
(
index_t
i
=
0
;
i
<
elements
;
++
i
)
{
for
(
int
cb
=
0
;
cb
<
c_tiles
;
++
cb
)
{
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
static_assert
(
c_tile_size
==
4
,
"channels tile size must be 4"
);
int
c
=
cb
*
c_tile_size
;
int
pos
=
i
*
channels
+
c
;
float32x4_t
scales
=
vld1q_f32
(
scale_data
+
c
);
float32x4_t
offsets
=
vld1q_f32
(
offset_data
+
c
);
float32x4_t
in
=
vld1q_f32
(
input_ptr
+
pos
);
float32x4_t
out
=
vfmaq_f32
(
offsets
,
scales
,
in
);
vst1q_f32
(
output_ptr
+
pos
,
out
);
#else
for
(
int
ci
=
0
;
ci
<
c_tile_size
;
++
ci
)
{
int
c
=
cb
*
c_tile_size
+
ci
;
index_t
pos
=
i
*
channels
+
c
;
output_ptr
[
pos
]
=
scale_data
[
c
]
*
input_ptr
[
pos
]
+
offset_data
[
c
];
}
#endif
}
}
}
if
(
remains_start
<
channels
)
{
#pragma omp parallel for collapse(2)
for
(
index_t
i
=
0
;
i
<
elements
;
++
i
)
{
for
(
index_t
c
=
remains_start
;
c
<
channels
;
++
c
)
{
index_t
pos
=
i
*
channels
+
c
;
output_ptr
[
pos
]
=
scale_data
[
c
]
*
input_ptr
[
pos
]
+
offset_data
[
c
];
}
}
}
...
...
mace/ops/batch_norm.cc
浏览文件 @
3f803f84
...
...
@@ -13,14 +13,6 @@ void Register_BatchNorm(OperatorRegistry *op_registry) {
.
Build
(),
BatchNormOp
<
DeviceType
::
CPU
,
float
>
);
#if MACE_ENABLE_NEON
REGISTER_OPERATOR
(
op_registry
,
OpKeyBuilder
(
"BatchNorm"
)
.
Device
(
DeviceType
::
NEON
)
.
TypeConstraint
<
float
>
(
"T"
)
.
Build
(),
BatchNormOp
<
DeviceType
::
NEON
,
float
>
);
#endif // MACE_ENABLE_NEON
REGISTER_OPERATOR
(
op_registry
,
OpKeyBuilder
(
"BatchNorm"
)
.
Device
(
DeviceType
::
OPENCL
)
.
TypeConstraint
<
float
>
(
"T"
)
...
...
@@ -34,4 +26,4 @@ void Register_BatchNorm(OperatorRegistry *op_registry) {
BatchNormOp
<
DeviceType
::
OPENCL
,
half
>
);
}
}
//
namespace mace
}
// namespace mace
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录