Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
慢慢CG
Mace
提交
e8824833
Mace
项目概览
慢慢CG
/
Mace
与 Fork 源项目一致
Fork自
Xiaomi / Mace
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
e8824833
编写于
12月 04, 2017
作者:
Y
yejianwu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update cpu batch norm to adapt locality, modify op to use template dtype
上级
99963c98
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
52 addition
and
30 deletion
+52
-30
mace/kernels/batch_norm.h
mace/kernels/batch_norm.h
+31
-18
mace/kernels/opencl/batch_norm_opencl.cc
mace/kernels/opencl/batch_norm_opencl.cc
+9
-5
mace/kernels/opencl/cl/batch_norm.cl
mace/kernels/opencl/cl/batch_norm.cl
+6
-6
mace/ops/batch_norm.cc
mace/ops/batch_norm.cc
+6
-1
未找到文件。
mace/kernels/batch_norm.h
浏览文件 @
e8824833
...
...
@@ -28,8 +28,11 @@ struct BatchNormFunctor {
// new_scale = \frac{ \scale } { \sqrt{var+\variance_epsilon} }
// new_offset = \offset - mean * common_val;
// Y = new_scale * X + new_offset;
const
index_t
ch_pixel_size
=
input
->
dim
(
0
)
*
input
->
dim
(
1
)
*
input
->
dim
(
2
);
const
index_t
channel
=
input
->
dim
(
3
);
const
index_t
batchs
=
input
->
dim
(
0
);
const
index_t
height
=
input
->
dim
(
1
);
const
index_t
width
=
input
->
dim
(
2
);
const
index_t
height_width
=
height
*
width
;
const
index_t
channels
=
input
->
dim
(
3
);
Tensor
::
MappingGuard
input_mapper
(
input
);
Tensor
::
MappingGuard
scale_mapper
(
scale
);
...
...
@@ -47,15 +50,24 @@ struct BatchNormFunctor {
const
T
*
epsilon_ptr
=
epsilon
->
data
<
T
>
();
T
*
output_ptr
=
output
->
mutable_data
<
T
>
();
vector
<
T
>
new_scale
(
channels
);
vector
<
T
>
new_offset
(
channels
);
#pragma omp parallel for
for
(
index_t
c
=
0
;
c
<
channel
;
++
c
)
{
T
new_scale
=
scale_ptr
[
c
]
/
std
::
sqrt
(
var_ptr
[
c
]
+
*
epsilon_ptr
);
T
new_offset
=
offset_ptr
[
c
]
-
mean_ptr
[
c
]
*
new_scale
;
index_t
pos
=
c
;
for
(
index_t
c
=
0
;
c
<
channels
;
++
c
)
{
new_scale
[
c
]
=
scale_ptr
[
c
]
/
std
::
sqrt
(
var_ptr
[
c
]
+
*
epsilon_ptr
);
new_offset
[
c
]
=
offset_ptr
[
c
]
-
mean_ptr
[
c
]
*
new_scale
[
c
];
}
index_t
pos
=
0
;
for
(
index_t
i
=
0
;
i
<
ch_pixel_size
;
++
i
)
{
output_ptr
[
pos
]
=
new_scale
*
input_ptr
[
pos
]
+
new_offset
;
pos
+=
channel
;
#pragma omp parallel for
for
(
index_t
n
=
0
;
n
<
batchs
;
++
n
)
{
for
(
index_t
hb
=
0
;
hb
<
height_width
;
++
hb
)
{
for
(
index_t
c
=
0
;
c
<
channels
;
++
c
)
{
output_ptr
[
pos
]
=
new_scale
[
c
]
*
input_ptr
[
pos
]
+
new_offset
[
c
];
++
pos
;
}
}
}
}
...
...
@@ -71,15 +83,16 @@ void BatchNormFunctor<DeviceType::NEON, float>::operator()(
const
Tensor
*
epsilon
,
Tensor
*
output
);
template
<
>
void
BatchNormFunctor
<
DeviceType
::
OPENCL
,
float
>::
operator
()(
const
Tensor
*
input
,
const
Tensor
*
scale
,
const
Tensor
*
offset
,
const
Tensor
*
mean
,
const
Tensor
*
var
,
const
Tensor
*
epsilon
,
Tensor
*
output
);
template
<
typename
T
>
struct
BatchNormFunctor
<
DeviceType
::
OPENCL
,
T
>
{
void
operator
()(
const
Tensor
*
input
,
const
Tensor
*
scale
,
const
Tensor
*
offset
,
const
Tensor
*
mean
,
const
Tensor
*
var
,
const
Tensor
*
epsilon
,
Tensor
*
output
);
};
}
// namepsace kernels
}
// namespace mace
...
...
mace/kernels/opencl/batch_norm_opencl.cc
浏览文件 @
e8824833
...
...
@@ -11,8 +11,8 @@
namespace
mace
{
namespace
kernels
{
template
<
>
void
BatchNormFunctor
<
DeviceType
::
OPENCL
,
float
>::
operator
()(
template
<
typename
T
>
void
BatchNormFunctor
<
DeviceType
::
OPENCL
,
T
>::
operator
()(
const
Tensor
*
input
,
const
Tensor
*
scale
,
const
Tensor
*
offset
,
...
...
@@ -27,7 +27,6 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
const
index_t
channels
=
input
->
dim
(
3
);
const
index_t
channel_blocks
=
RoundUpDiv4
(
channels
);
const
index_t
width_blocks
=
RoundUpDiv4
(
width
);
const
uint32_t
gws
[
3
]
=
{
static_cast
<
uint32_t
>
(
channel_blocks
),
static_cast
<
uint32_t
>
(
width
),
...
...
@@ -35,8 +34,9 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
auto
runtime
=
OpenCLRuntime
::
Get
();
std
::
set
<
std
::
string
>
built_options
;
built_options
.
emplace
(
"-DDATA_TYPE="
+
DtToUpstreamCLDt
(
input
->
dtype
()));
built_options
.
emplace
(
"-DCMD_DATA_TYPE="
+
DtToUpstreamCLCMDDt
(
input
->
dtype
()));
auto
dt
=
DataTypeToEnum
<
T
>::
value
;
built_options
.
emplace
(
"-DDATA_TYPE="
+
DtToUpstreamCLDt
(
dt
));
built_options
.
emplace
(
"-DCMD_DATA_TYPE="
+
DtToUpstreamCLCMDDt
(
dt
));
auto
bm_kernel
=
runtime
->
BuildKernel
(
"batch_norm"
,
"batch_norm"
,
built_options
);
const
uint32_t
kwg_size
=
runtime
->
GetKernelMaxWorkGroupSize
(
bm_kernel
);
...
...
@@ -83,5 +83,9 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
func
);
}
template
struct
BatchNormFunctor
<
DeviceType
::
OPENCL
,
float
>;
template
struct
BatchNormFunctor
<
DeviceType
::
OPENCL
,
half
>;
}
// namespace kernels
}
// namespace mace
mace/kernels/opencl/cl/batch_norm.cl
浏览文件 @
e8824833
#
include
<common.h>
//
Supported
data
types:
half/float
__kernel
void
batch_norm
(
__read_only
image2d_t
input,
__read_only
image2d_t
scale,
__read_only
image2d_t
offset,
__read_only
image2d_t
mean,
__read_only
image2d_t
var,
global
const
DATA_TYPE
*epsilon,
__write_only
image2d_t
output
)
{
__read_only
image2d_t
scale,
__read_only
image2d_t
offset,
__read_only
image2d_t
mean,
__read_only
image2d_t
var,
__
global
const
DATA_TYPE
*epsilon,
__write_only
image2d_t
output
)
{
const
int
ch_blk
=
get_global_id
(
0
)
;
const
int
w_blk
=
get_global_id
(
1
)
;
const
int
hb_blk
=
get_global_id
(
2
)
;
...
...
mace/ops/batch_norm.cc
浏览文件 @
e8824833
...
...
@@ -23,4 +23,9 @@ REGISTER_OPENCL_OPERATOR(OpKeyBuilder("BatchNorm")
.
Build
(),
BatchNormOp
<
DeviceType
::
OPENCL
,
float
>
);
}
// namespace mace
\ No newline at end of file
REGISTER_OPENCL_OPERATOR
(
OpKeyBuilder
(
"BatchNorm"
)
.
TypeConstraint
<
half
>
(
"T"
)
.
Build
(),
BatchNormOp
<
DeviceType
::
OPENCL
,
half
>
);
}
// namespace mace
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录