Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
d020229c
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d020229c
编写于
6月 14, 2019
作者:
H
Hong Ming
提交者:
Tensor Tang
6月 14, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
enable conv_winograd, fix conv_gemmlike bug, and update the unit tests of conv op
test=develop
上级
5f0d7166
变更
11
显示空白变更内容
内联
并排
Showing
11 changed file
with
810 addition
and
0 deletion
+810
-0
paddle/fluid/lite/arm/math/scale.cc
paddle/fluid/lite/arm/math/scale.cc
+105
-0
paddle/fluid/lite/arm/math/scale.h
paddle/fluid/lite/arm/math/scale.h
+8
-0
paddle/fluid/lite/kernels/arm/CMakeLists.txt
paddle/fluid/lite/kernels/arm/CMakeLists.txt
+3
-0
paddle/fluid/lite/kernels/arm/batch_norm_compute.cc
paddle/fluid/lite/kernels/arm/batch_norm_compute.cc
+114
-0
paddle/fluid/lite/kernels/arm/batch_norm_compute.h
paddle/fluid/lite/kernels/arm/batch_norm_compute.h
+42
-0
paddle/fluid/lite/kernels/arm/batch_norm_compute_test.cc
paddle/fluid/lite/kernels/arm/batch_norm_compute_test.cc
+221
-0
paddle/fluid/lite/operators/CMakeLists.txt
paddle/fluid/lite/operators/CMakeLists.txt
+3
-0
paddle/fluid/lite/operators/batch_norm_op.cc
paddle/fluid/lite/operators/batch_norm_op.cc
+110
-0
paddle/fluid/lite/operators/batch_norm_op.h
paddle/fluid/lite/operators/batch_norm_op.h
+46
-0
paddle/fluid/lite/operators/batch_norm_op_test.cc
paddle/fluid/lite/operators/batch_norm_op_test.cc
+139
-0
paddle/fluid/lite/operators/op_params.h
paddle/fluid/lite/operators/op_params.h
+19
-0
未找到文件。
paddle/fluid/lite/arm/math/scale.cc
浏览文件 @
d020229c
...
@@ -58,6 +58,111 @@ void scale<float>(const float* din, float* dout, int num, float scale,
...
@@ -58,6 +58,111 @@ void scale<float>(const float* din, float* dout, int num, float scale,
}
}
}
}
template
<
>
void
scale
<
float
>
(
const
float
*
din
,
float
*
dout
,
int
outer_dim
,
int
scale_dim
,
int
inner_dim
,
const
float
*
scale_data
,
const
float
*
bias_data
)
{
int
cnt
=
inner_dim
>>
4
;
int
remain
=
inner_dim
%
16
;
int
size
=
inner_dim
*
scale_dim
;
for
(
int
n
=
0
;
n
<
outer_dim
;
n
++
)
{
const
float
*
din_ptr_n
=
din
+
n
*
size
;
float
*
dout_ptr_n
=
dout
+
n
*
size
;
#pragma omp parallel for
for
(
int
i
=
0
;
i
<
scale_dim
;
i
++
)
{
const
float
*
din_ptr
=
din_ptr_n
+
i
*
inner_dim
;
float
*
dout_ptr
=
dout_ptr_n
+
i
*
inner_dim
;
float
scale
=
scale_data
[
i
];
float32x4_t
vscale
=
vdupq_n_f32
(
scale
);
float
bias
=
bias_data
[
i
];
float32x4_t
vbias
=
vdupq_n_f32
(
bias
);
for
(
int
j
=
0
;
j
<
cnt
;
j
++
)
{
float32x4_t
din0
=
vld1q_f32
(
din_ptr
);
float32x4_t
din1
=
vld1q_f32
(
din_ptr
+
4
);
float32x4_t
din2
=
vld1q_f32
(
din_ptr
+
8
);
float32x4_t
din3
=
vld1q_f32
(
din_ptr
+
12
);
float32x4_t
vsum1
=
vmlaq_f32
(
vbias
,
din0
,
vscale
);
float32x4_t
vsum2
=
vmlaq_f32
(
vbias
,
din1
,
vscale
);
float32x4_t
vsum3
=
vmlaq_f32
(
vbias
,
din2
,
vscale
);
float32x4_t
vsum4
=
vmlaq_f32
(
vbias
,
din3
,
vscale
);
din_ptr
+=
16
;
vst1q_f32
(
dout_ptr
,
vsum1
);
vst1q_f32
(
dout_ptr
+
4
,
vsum2
);
vst1q_f32
(
dout_ptr
+
8
,
vsum3
);
vst1q_f32
(
dout_ptr
+
12
,
vsum4
);
dout_ptr
+=
16
;
}
for
(
int
j
=
0
;
j
<
remain
;
j
++
)
{
*
dout_ptr
=
*
din_ptr
*
scale
+
bias
;
dout_ptr
++
;
din_ptr
++
;
}
}
}
}
template
<
>
void
scale
<
float
>
(
const
float
*
din
,
float
*
dout
,
int
outer_dim
,
int
scale_dim
,
const
float
*
scale_data
,
const
float
*
bias_data
)
{
int
cnt
=
scale_dim
>>
4
;
int
remain
=
scale_dim
%
16
;
for
(
int
n
=
0
;
n
<
outer_dim
;
n
++
)
{
const
float
*
din_ptr_n
=
din
+
n
*
scale_dim
;
float
*
dout_ptr_n
=
dout
+
n
*
scale_dim
;
#pragma omp parallel for
for
(
int
i
=
0
;
i
<
cnt
;
i
++
)
{
int
idx
=
i
<<
4
;
const
float
*
din_ptr
=
din_ptr_n
+
idx
;
const
float
*
scale_ptr
=
scale_data
+
idx
;
const
float
*
bias_ptr
=
bias_data
+
idx
;
float
*
dout_ptr
=
dout_ptr_n
+
idx
;
float32x4_t
din0
=
vld1q_f32
(
din_ptr
);
float32x4_t
vscale0
=
vld1q_f32
(
scale_ptr
);
float32x4_t
vbias0
=
vld1q_f32
(
bias_ptr
);
float32x4_t
din1
=
vld1q_f32
(
din_ptr
+
4
);
float32x4_t
vscale1
=
vld1q_f32
(
scale_ptr
+
4
);
float32x4_t
vbias1
=
vld1q_f32
(
bias_ptr
+
4
);
float32x4_t
din2
=
vld1q_f32
(
din_ptr
+
8
);
float32x4_t
vscale2
=
vld1q_f32
(
scale_ptr
+
8
);
float32x4_t
vbias2
=
vld1q_f32
(
bias_ptr
+
8
);
float32x4_t
vsum1
=
vmlaq_f32
(
vbias0
,
din0
,
vscale0
);
float32x4_t
vsum2
=
vmlaq_f32
(
vbias1
,
din1
,
vscale1
);
float32x4_t
din3
=
vld1q_f32
(
din_ptr
+
12
);
float32x4_t
vscale3
=
vld1q_f32
(
scale_ptr
+
12
);
float32x4_t
vbias3
=
vld1q_f32
(
bias_ptr
+
12
);
vst1q_f32
(
dout_ptr
,
vsum1
);
vst1q_f32
(
dout_ptr
+
4
,
vsum2
);
float32x4_t
vsum3
=
vmlaq_f32
(
vbias2
,
din2
,
vscale2
);
float32x4_t
vsum4
=
vmlaq_f32
(
vbias3
,
din3
,
vscale3
);
vst1q_f32
(
dout_ptr
+
8
,
vsum3
);
vst1q_f32
(
dout_ptr
+
12
,
vsum4
);
}
int
idx
=
cnt
<<
4
;
const
float
*
din_ptr
=
din_ptr_n
+
idx
;
float
*
dout_ptr
=
dout_ptr_n
+
idx
;
const
float
*
scale_ptr
=
scale_data
+
idx
;
const
float
*
bias_ptr
=
bias_data
+
idx
;
for
(
int
j
=
0
;
j
<
remain
;
j
++
)
{
*
dout_ptr
=
*
din_ptr
*
(
*
scale_ptr
)
+
(
*
bias_ptr
);
dout_ptr
++
;
din_ptr
++
;
scale_ptr
++
;
bias_ptr
++
;
}
}
}
}
// namespace math
}
// namespace math
}
// namespace arm
}
// namespace arm
}
// namespace lite
}
// namespace lite
...
...
paddle/fluid/lite/arm/math/scale.h
浏览文件 @
d020229c
...
@@ -22,6 +22,14 @@ namespace math {
...
@@ -22,6 +22,14 @@ namespace math {
template
<
typename
T
>
template
<
typename
T
>
void
scale
(
const
T
*
din
,
T
*
dout
,
int
num
,
float
scale
,
float
bias
);
void
scale
(
const
T
*
din
,
T
*
dout
,
int
num
,
float
scale
,
float
bias
);
template
<
typename
T
>
void
scale
(
const
T
*
din
,
T
*
dout
,
int
outer_dim
,
int
scale_dim
,
int
inner_dim
,
const
float
*
scale_data
,
const
float
*
bias_data
);
template
<
typename
T
>
void
scale
(
const
T
*
din
,
T
*
dout
,
int
outer_dim
,
int
scale_dim
,
const
float
*
scale_data
,
const
float
*
bias_data
);
}
// namespace math
}
// namespace math
}
// namespace arm
}
// namespace arm
}
// namespace lite
}
// namespace lite
...
...
paddle/fluid/lite/kernels/arm/CMakeLists.txt
浏览文件 @
d020229c
...
@@ -10,6 +10,7 @@ cc_library(mul_compute_arm SRCS mul_compute.cc DEPS ${lite_kernel_deps} math_arm
...
@@ -10,6 +10,7 @@ cc_library(mul_compute_arm SRCS mul_compute.cc DEPS ${lite_kernel_deps} math_arm
cc_library
(
scale_compute_arm SRCS scale_compute.cc DEPS
${
lite_kernel_deps
}
math_arm
)
cc_library
(
scale_compute_arm SRCS scale_compute.cc DEPS
${
lite_kernel_deps
}
math_arm
)
cc_library
(
softmax_compute_arm SRCS softmax_compute.cc DEPS
${
lite_kernel_deps
}
math_arm
)
cc_library
(
softmax_compute_arm SRCS softmax_compute.cc DEPS
${
lite_kernel_deps
}
math_arm
)
cc_library
(
conv_compute_arm SRCS conv_compute.cc DEPS
${
lite_kernel_deps
}
math_arm
)
cc_library
(
conv_compute_arm SRCS conv_compute.cc DEPS
${
lite_kernel_deps
}
math_arm
)
cc_library
(
batch_norm_compute_arm SRCS batch_norm_compute.cc DEPS
${
lite_kernel_deps
}
math_arm
)
cc_library
(
elementwise_add_compute_arm SRCS elementwise_add_compute.cc DEPS
${
lite_kernel_deps
}
math_arm
)
cc_library
(
elementwise_add_compute_arm SRCS elementwise_add_compute.cc DEPS
${
lite_kernel_deps
}
math_arm
)
cc_library
(
pool_compute_arm SRCS pool_compute.cc DEPS
${
lite_kernel_deps
}
math_arm
)
cc_library
(
pool_compute_arm SRCS pool_compute.cc DEPS
${
lite_kernel_deps
}
math_arm
)
cc_library
(
split_compute_arm SRCS split_compute.cc DEPS
${
lite_kernel_deps
}
math_arm
)
cc_library
(
split_compute_arm SRCS split_compute.cc DEPS
${
lite_kernel_deps
}
math_arm
)
...
@@ -18,6 +19,7 @@ lite_cc_test(test_fc_compute_arm SRCS fc_compute_test.cc DEPS fc_compute_arm mat
...
@@ -18,6 +19,7 @@ lite_cc_test(test_fc_compute_arm SRCS fc_compute_test.cc DEPS fc_compute_arm mat
lite_cc_test
(
test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm
)
lite_cc_test
(
test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm
)
lite_cc_test
(
test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm
)
lite_cc_test
(
test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm
)
lite_cc_test
(
test_conv_compute_arm SRCS conv_compute_test.cc DEPS conv_compute_arm
)
lite_cc_test
(
test_conv_compute_arm SRCS conv_compute_test.cc DEPS conv_compute_arm
)
lite_cc_test
(
test_batch_norm_compute_arm SRCS batch_norm_compute_test.cc DEPS batch_norm_compute_arm
)
lite_cc_test
(
test_elementwise_add_compute_arm SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_arm
)
lite_cc_test
(
test_elementwise_add_compute_arm SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_arm
)
lite_cc_test
(
test_pool_compute_arm SRCS pool_compute_test.cc DEPS pool_compute_arm
)
lite_cc_test
(
test_pool_compute_arm SRCS pool_compute_test.cc DEPS pool_compute_arm
)
lite_cc_test
(
test_mul_compute_arm SRCS mul_compute_test.cc DEPS mul_compute_arm
)
lite_cc_test
(
test_mul_compute_arm SRCS mul_compute_test.cc DEPS mul_compute_arm
)
...
@@ -30,6 +32,7 @@ set(arm_kernels
...
@@ -30,6 +32,7 @@ set(arm_kernels
scale_compute_arm
scale_compute_arm
softmax_compute_arm
softmax_compute_arm
conv_compute_arm
conv_compute_arm
batch_norm_compute_arm
elementwise_add_compute_arm
elementwise_add_compute_arm
pool_compute_arm
pool_compute_arm
split_compute_arm
split_compute_arm
...
...
paddle/fluid/lite/kernels/arm/batch_norm_compute.cc
0 → 100644
浏览文件 @
d020229c
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/arm/batch_norm_compute.h"
#include "paddle/fluid/lite/arm/math/funcs.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/type_system.h"
namespace
paddle
{
namespace
lite
{
namespace
kernels
{
namespace
arm
{
void
BatchNormCompute
::
PrepareForRun
()
{
auto
&
param
=
this
->
Param
<
param_t
>
();
auto
x_dims
=
param
.
x
->
dims
();
bool
global_stats
=
param
.
is_test
||
param
.
use_global_stats
;
if
(
global_stats
)
{
int64_t
channel_size
=
0
;
switch
(
param
.
data_layout
)
{
case
DATALAYOUT
(
kNCHW
):
channel_size
=
x_dims
[
1
];
break
;
// case DATALAYOUT(kNHWC):
// channel_size = x_dims[x_dims.size() - 1];
// break;
default:
LOG
(
FATAL
)
<<
"Unknown storage order: "
<<
DataLayoutToStr
(
param
.
data_layout
);
break
;
}
new_scale
.
Resize
({
channel_size
});
new_bias
.
Resize
({
channel_size
});
auto
*
scale_data
=
param
.
scale
->
mutable_data
<
float
>
();
auto
*
bias_data
=
param
.
bias
->
mutable_data
<
float
>
();
auto
*
mean_data
=
param
.
mean
->
mutable_data
<
float
>
();
auto
*
variance_data
=
param
.
variance
->
mutable_data
<
float
>
();
auto
*
new_scale_data
=
new_scale
.
mutable_data
<
float
>
();
auto
*
new_bias_data
=
new_bias
.
mutable_data
<
float
>
();
for
(
int
c
=
0
;
c
<
channel_size
;
c
++
)
{
float
inv_scale
=
1.
f
/
(
std
::
sqrt
(
variance_data
[
c
]
+
param
.
epsilon
));
new_bias_data
[
c
]
=
bias_data
[
c
]
-
inv_scale
*
scale_data
[
c
]
*
mean_data
[
c
];
new_scale_data
[
c
]
=
inv_scale
*
scale_data
[
c
];
}
}
}
void
BatchNormCompute
::
Run
()
{
auto
&
param
=
this
->
Param
<
param_t
>
();
auto
x_dims
=
param
.
x
->
dims
();
auto
x_data
=
param
.
x
->
mutable_data
<
float
>
();
auto
y_data
=
param
.
y
->
mutable_data
<
float
>
();
bool
global_stats
=
param
.
is_test
||
param
.
use_global_stats
;
if
(
global_stats
)
{
auto
*
new_scale_data
=
new_scale
.
mutable_data
<
float
>
();
auto
*
new_bias_data
=
new_bias
.
mutable_data
<
float
>
();
int64_t
outer_size
=
0
;
int64_t
channel_size
=
0
;
int64_t
inner_size
=
0
;
switch
(
param
.
data_layout
)
{
case
DATALAYOUT
(
kNCHW
):
outer_size
=
x_dims
[
0
];
channel_size
=
x_dims
[
1
];
inner_size
=
x_dims
.
Slice
(
2
,
x_dims
.
size
()).
production
();
lite
::
arm
::
math
::
scale
(
x_data
,
y_data
,
outer_size
,
channel_size
,
inner_size
,
new_scale_data
,
new_bias_data
);
break
;
// case DATALAYOUT(kNHWC):
// outer_size = x_dims.Slice(0, x_dims.size() - 1).production();
// channel_size = x_dims[x_dims.size() - 1];
// lite::arm::math::scale(x_data, y_data, outer_size, channel_size,
// new_scale_data, new_bias_data);
// break;
default:
LOG
(
FATAL
)
<<
"Unknown storage order: "
<<
DataLayoutToStr
(
param
.
data_layout
);
break
;
}
}
else
{
// TODO(hong19860320) calculate mean_out, variance_out, saved_mean and
// saved_variance
}
}
}
// namespace arm
}
// namespace kernels
}
// namespace lite
}
// namespace paddle
REGISTER_LITE_KERNEL
(
batch_norm
,
kARM
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
arm
::
BatchNormCompute
,
def
)
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindInput
(
"Scale"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindInput
(
"Bias"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindInput
(
"Mean"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindInput
(
"Variance"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindOutput
(
"Y"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindOutput
(
"MeanOut"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindOutput
(
"VarianceOut"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindOutput
(
"SavedMean"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindOutput
(
"SavedVariance"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
Finalize
();
paddle/fluid/lite/kernels/arm/batch_norm_compute.h
0 → 100644
浏览文件 @
d020229c
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
kernels
{
namespace
arm
{
class
BatchNormCompute
:
public
KernelLite
<
TARGET
(
kARM
),
PRECISION
(
kFloat
)
>
{
public:
using
param_t
=
operators
::
BatchNormParam
;
void
PrepareForRun
()
override
;
void
Run
()
override
;
virtual
~
BatchNormCompute
()
=
default
;
private:
Tensor
new_scale
;
Tensor
new_bias
;
};
}
// namespace arm
}
// namespace kernels
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/kernels/arm/batch_norm_compute_test.cc
0 → 100644
浏览文件 @
d020229c
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/arm/batch_norm_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
kernels
{
namespace
arm
{
template
<
typename
dtype
>
void
batch_norm_compute_ref
(
const
operators
::
BatchNormParam
&
param
)
{
DDim
x_dims
=
param
.
x
->
dims
();
auto
x_data
=
param
.
x
->
mutable_data
<
dtype
>
();
auto
scale_data
=
param
.
scale
->
mutable_data
<
dtype
>
();
auto
bias_data
=
param
.
bias
->
mutable_data
<
dtype
>
();
auto
mean_data
=
param
.
mean
->
mutable_data
<
dtype
>
();
auto
variance_data
=
param
.
variance
->
mutable_data
<
dtype
>
();
auto
y_data
=
param
.
y
->
mutable_data
<
dtype
>
();
float
epsilon
=
param
.
epsilon
;
float
momentum
=
param
.
momentum
;
DataLayoutType
data_layout
=
param
.
data_layout
;
bool
global_stats
=
param
.
is_test
||
param
.
use_global_stats
;
if
(
global_stats
)
{
int64_t
outer_size
=
0
;
int64_t
channel_size
=
0
;
int64_t
inner_size
=
0
;
switch
(
data_layout
)
{
case
DATALAYOUT
(
kNCHW
):
outer_size
=
x_dims
[
0
];
channel_size
=
x_dims
[
1
];
inner_size
=
x_dims
.
Slice
(
2
,
x_dims
.
size
()).
production
();
break
;
// case DATALAYOUT(kNHWC):
// outer_size = x_dims.Slice(0, x_dims.size() - 1).production();
// channel_size = x_dims[x_dims.size() - 1];
// inner_size = 1;
// break;
default:
LOG
(
FATAL
)
<<
"Unknown storage order: "
<<
DataLayoutToStr
(
data_layout
);
break
;
}
auto
x_ptr
=
x_data
;
auto
y_ptr
=
y_data
;
for
(
int
o
=
0
;
o
<
outer_size
;
o
++
)
{
for
(
int
c
=
0
;
c
<
channel_size
;
c
++
)
{
for
(
int
i
=
0
;
i
<
inner_size
;
i
++
)
{
dtype
norm_x
=
(
*
x_ptr
-
mean_data
[
c
])
/
std
::
sqrt
(
variance_data
[
c
]
+
epsilon
);
*
y_ptr
=
norm_x
*
scale_data
[
c
]
+
bias_data
[
c
];
x_ptr
++
;
y_ptr
++
;
}
}
}
}
else
{
// TODO(hong19860320) calculate mean_out, variance_out, saved_mean and
// saved_variance
}
}
TEST
(
batch_norm_arm
,
retrive_op
)
{
auto
batch_norm
=
KernelRegistry
::
Global
().
Create
<
TARGET
(
kARM
),
PRECISION
(
kFloat
)
>
(
"batch_norm"
);
ASSERT_FALSE
(
batch_norm
.
empty
());
ASSERT_TRUE
(
batch_norm
.
front
());
}
TEST
(
batch_norm_arm
,
init
)
{
BatchNormCompute
batch_norm
;
ASSERT_EQ
(
batch_norm
.
precision
(),
PRECISION
(
kFloat
));
ASSERT_EQ
(
batch_norm
.
target
(),
TARGET
(
kARM
));
}
TEST
(
batch_norm_arm
,
compute
)
{
DeviceInfo
::
Init
();
for
(
auto
n
:
{
1
,
2
})
{
for
(
auto
c
:
{
6
,
32
/*, 128*/
})
{
for
(
auto
h
:
{
9
,
18
/*, 56 , 112, 224, 512*/
})
{
for
(
auto
w
:
{
9
,
18
/*, 56, 112, 224, 512*/
})
{
for
(
auto
is_test
:
{
/*false, */
true
})
{
for
(
auto
use_global_stats
:
{
false
,
true
})
{
for
(
auto
epsilon
:
{
1e-4
f
,
1e-5
f
})
{
for
(
auto
momentum
:
{
0.9
f
,
0.99
f
})
{
for
(
auto
data_layout
:
{
DATALAYOUT
(
kNCHW
)
/*, DATALAYOUT(kNHWC)*/
})
{
Tensor
x
;
Tensor
scale
;
Tensor
bias
;
Tensor
mean
;
Tensor
variance
;
Tensor
y
;
Tensor
mean_out
;
Tensor
variance_out
;
Tensor
saved_mean
;
Tensor
saved_variance
;
Tensor
y_ref
;
Tensor
mean_out_ref
;
Tensor
variance_out_ref
;
Tensor
saved_mean_ref
;
Tensor
saved_variance_ref
;
// set the dims of input, output, ref output tensors
std
::
vector
<
int64_t
>
in_out_shape
;
switch
(
data_layout
)
{
case
DATALAYOUT
(
kNCHW
):
in_out_shape
=
{
n
,
c
,
h
,
w
};
break
;
// case DATALAYOUT(kNHWC):
// in_out_shape = {n, h, w, c};
// break;
default:
LOG
(
FATAL
)
<<
"Unknown storage order: "
<<
DataLayoutToStr
(
data_layout
);
break
;
}
x
.
Resize
(
in_out_shape
);
scale
.
Resize
({
c
});
bias
.
Resize
({
c
});
mean
.
Resize
({
c
});
variance
.
Resize
({
c
});
y
.
Resize
(
in_out_shape
);
mean_out
.
Resize
({
c
});
variance_out
.
Resize
({
c
});
saved_mean
.
Resize
({
c
});
saved_variance
.
Resize
({
c
});
y_ref
.
Resize
(
in_out_shape
);
mean_out_ref
.
Resize
({
c
});
variance_out_ref
.
Resize
({
c
});
saved_mean_ref
.
Resize
({
c
});
saved_variance_ref
.
Resize
({
c
});
// initialize the data of input tensors
auto
*
x_data
=
x
.
mutable_data
<
float
>
();
auto
*
scale_data
=
scale
.
mutable_data
<
float
>
();
auto
*
bias_data
=
bias
.
mutable_data
<
float
>
();
auto
*
mean_data
=
mean
.
mutable_data
<
float
>
();
auto
*
variance_data
=
variance
.
mutable_data
<
float
>
();
auto
*
y_data
=
y
.
mutable_data
<
float
>
();
for
(
int
i
=
0
;
i
<
x
.
dims
().
production
();
i
++
)
{
x_data
[
i
]
=
static_cast
<
float
>
(
i
%
64
);
}
for
(
int
i
=
0
;
i
<
scale
.
dims
().
production
();
i
++
)
{
scale_data
[
i
]
=
static_cast
<
float
>
(
i
)
*
0.01
f
+
0.03
f
;
}
for
(
int
i
=
0
;
i
<
bias
.
dims
().
production
();
i
++
)
{
bias_data
[
i
]
=
static_cast
<
float
>
(
i
)
*
0.065
f
+
0.1
f
;
}
for
(
int
i
=
0
;
i
<
mean
.
dims
().
production
();
i
++
)
{
mean_data
[
i
]
=
static_cast
<
float
>
(
i
)
*
0.0565
f
;
}
for
(
int
i
=
0
;
i
<
variance
.
dims
().
production
();
i
++
)
{
variance_data
[
i
]
=
static_cast
<
float
>
(
i
)
*
2.08
f
+
1.5
f
;
}
// prepare kernel params and run
BatchNormCompute
batch_norm
;
std
::
unique_ptr
<
KernelContext
>
ctx
(
new
KernelContext
);
ctx
->
As
<
ARMContext
>
();
batch_norm
.
SetContext
(
std
::
move
(
ctx
));
operators
::
BatchNormParam
param
;
param
.
x
=
&
x
;
param
.
scale
=
&
scale
;
param
.
bias
=
&
bias
;
param
.
mean
=
&
mean
;
param
.
variance
=
&
variance
;
param
.
is_test
=
is_test
;
param
.
use_global_stats
=
use_global_stats
;
param
.
epsilon
=
epsilon
;
param
.
momentum
=
momentum
;
param
.
data_layout
=
data_layout
;
param
.
y
=
&
y
;
param
.
mean_out
=
&
mean_out
;
param
.
variance_out
=
&
variance_out
;
param
.
saved_mean
=
&
saved_mean
;
param
.
saved_variance
=
&
saved_variance
;
batch_norm
.
SetParam
(
param
);
batch_norm
.
Launch
();
// invoking ref implementation and compare results
param
.
y
=
&
y_ref
;
param
.
mean_out
=
&
mean_out_ref
;
param
.
variance_out
=
&
variance_out_ref
;
param
.
saved_mean
=
&
saved_mean_ref
;
param
.
saved_variance
=
&
saved_variance_ref
;
batch_norm_compute_ref
<
float
>
(
param
);
auto
*
y_ref_data
=
y_ref
.
mutable_data
<
float
>
();
for
(
int
i
=
0
;
i
<
y
.
dims
().
production
();
i
++
)
{
EXPECT_NEAR
(
y_data
[
i
],
y_ref_data
[
i
],
1e-5
);
}
}
}
}
}
}
}
}
}
}
}
}
// namespace arm
}
// namespace kernels
}
// namespace lite
}
// namespace paddle
USE_LITE_KERNEL
(
batch_norm
,
kARM
,
kFloat
,
kNCHW
,
def
);
paddle/fluid/lite/operators/CMakeLists.txt
浏览文件 @
d020229c
...
@@ -8,6 +8,7 @@ cc_library(mul_op_lite SRCS mul_op.cc DEPS ${op_DEPS})
...
@@ -8,6 +8,7 @@ cc_library(mul_op_lite SRCS mul_op.cc DEPS ${op_DEPS})
cc_library
(
scale_op_lite SRCS scale_op.cc DEPS
${
op_DEPS
}
)
cc_library
(
scale_op_lite SRCS scale_op.cc DEPS
${
op_DEPS
}
)
cc_library
(
softmax_op_lite SRCS softmax_op.cc DEPS
${
op_DEPS
}
)
cc_library
(
softmax_op_lite SRCS softmax_op.cc DEPS
${
op_DEPS
}
)
cc_library
(
reshape_op_lite SRCS reshape_op.cc DEPS
${
op_DEPS
}
)
cc_library
(
reshape_op_lite SRCS reshape_op.cc DEPS
${
op_DEPS
}
)
cc_library
(
batch_norm_op_lite SRCS batch_norm_op.cc DEPS
${
op_DEPS
}
)
cc_library
(
feed_op_lite SRCS feed_op.cc DEPS
${
op_DEPS
}
)
cc_library
(
feed_op_lite SRCS feed_op.cc DEPS
${
op_DEPS
}
)
cc_library
(
fetch_op_lite SRCS fetch_op.cc DEPS
${
op_DEPS
}
)
cc_library
(
fetch_op_lite SRCS fetch_op.cc DEPS
${
op_DEPS
}
)
cc_library
(
io_copy_op_lite SRCS io_copy_op.cc DEPS
${
op_DEPS
}
)
cc_library
(
io_copy_op_lite SRCS io_copy_op.cc DEPS
${
op_DEPS
}
)
...
@@ -30,6 +31,7 @@ set(ops_lite
...
@@ -30,6 +31,7 @@ set(ops_lite
scale_op_lite
scale_op_lite
softmax_op_lite
softmax_op_lite
reshape_op_lite
reshape_op_lite
batch_norm_op_lite
feed_op_lite
feed_op_lite
fetch_op_lite
fetch_op_lite
io_copy_op_lite
io_copy_op_lite
...
@@ -52,4 +54,5 @@ lite_cc_test(test_pool_op_lite SRCS pool_op_test.cc
...
@@ -52,4 +54,5 @@ lite_cc_test(test_pool_op_lite SRCS pool_op_test.cc
lite_cc_test
(
test_scale_op_lite SRCS scale_op_test.cc DEPS scale_op_lite memory_lite
)
lite_cc_test
(
test_scale_op_lite SRCS scale_op_test.cc DEPS scale_op_lite memory_lite
)
lite_cc_test
(
test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite memory_lite
)
lite_cc_test
(
test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite memory_lite
)
lite_cc_test
(
test_reshape_op_lite SRCS reshape_op_test.cc DEPS reshape_op_lite memory_lite
)
lite_cc_test
(
test_reshape_op_lite SRCS reshape_op_test.cc DEPS reshape_op_lite memory_lite
)
lite_cc_test
(
test_batch_norm_op_lite SRCS batch_norm_op_test.cc DEPS batch_norm_op_lite memory_lite
)
lite_cc_test
(
test_concat_op_lite SRCS concat_op_test.cc DEPS concat_op_lite memory_lite
)
lite_cc_test
(
test_concat_op_lite SRCS concat_op_test.cc DEPS concat_op_lite memory_lite
)
paddle/fluid/lite/operators/batch_norm_op.cc
0 → 100644
浏览文件 @
d020229c
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/batch_norm_op.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
operators
{
bool
BatchNormOp
::
CheckShape
()
const
{
CHECK_OR_FALSE
(
param_
.
x
);
CHECK_OR_FALSE
(
param_
.
bias
);
CHECK_OR_FALSE
(
param_
.
scale
);
CHECK_OR_FALSE
(
param_
.
mean
);
CHECK_OR_FALSE
(
param_
.
variance
);
CHECK_OR_FALSE
(
param_
.
y
);
if
(
!
param_
.
is_test
)
{
CHECK_OR_FALSE
(
param_
.
mean_out
);
CHECK_OR_FALSE
(
param_
.
variance_out
);
CHECK_OR_FALSE
(
param_
.
saved_mean
);
CHECK_OR_FALSE
(
param_
.
saved_variance
);
}
auto
x_dims
=
param_
.
x
->
dims
();
auto
scale_dims
=
param_
.
scale
->
dims
();
auto
bias_dims
=
param_
.
bias
->
dims
();
auto
mean_dims
=
param_
.
mean
->
dims
();
auto
variance_dims
=
param_
.
variance
->
dims
();
CHECK
(
x_dims
.
size
()
>=
2
&&
x_dims
.
size
()
<=
5
)
<<
"Input X must have 2 to 5 dimensions."
;
CHECK_EQ
(
scale_dims
.
size
(),
1UL
)
<<
"Input Scale must have 1 dimensions."
;
CHECK_EQ
(
bias_dims
.
size
(),
1UL
)
<<
"Input Bias must have 1 dimensions."
;
CHECK_EQ
(
mean_dims
.
size
(),
1UL
)
<<
"Input Mean must have 1 dimensions."
;
CHECK_EQ
(
variance_dims
.
size
(),
1UL
)
<<
"Input Variance must have 1 dimensions."
;
return
true
;
}
bool
BatchNormOp
::
InferShape
()
const
{
auto
x_dims
=
param_
.
x
->
dims
();
int64_t
channel_size
=
0
;
switch
(
param_
.
data_layout
)
{
case
DATALAYOUT
(
kNCHW
):
channel_size
=
x_dims
[
1
];
break
;
// case DATALAYOUT(kNHWC):
// channel_size = x_dims[x_dims.size() - 1];
// break;
default:
LOG
(
FATAL
)
<<
"Unknown storage order: "
<<
DataLayoutToStr
(
param_
.
data_layout
);
break
;
}
if
(
!
param_
.
is_test
)
{
param_
.
mean_out
->
Resize
({
channel_size
});
param_
.
variance_out
->
Resize
({
channel_size
});
param_
.
saved_mean
->
Resize
({
channel_size
});
param_
.
saved_variance
->
Resize
({
channel_size
});
}
param_
.
y
->
Resize
(
x_dims
);
return
true
;
}
bool
BatchNormOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
{
param_
.
x
=
scope
->
FindVar
(
op_desc
.
Input
(
"X"
).
front
())
->
GetMutable
<
Tensor
>
();
param_
.
bias
=
scope
->
FindVar
(
op_desc
.
Input
(
"Bias"
).
front
())
->
GetMutable
<
Tensor
>
();
param_
.
scale
=
scope
->
FindVar
(
op_desc
.
Input
(
"Scale"
).
front
())
->
GetMutable
<
Tensor
>
();
param_
.
mean
=
scope
->
FindVar
(
op_desc
.
Input
(
"Mean"
).
front
())
->
GetMutable
<
Tensor
>
();
param_
.
variance
=
scope
->
FindVar
(
op_desc
.
Input
(
"Variance"
).
front
())
->
GetMutable
<
Tensor
>
();
param_
.
y
=
scope
->
FindVar
(
op_desc
.
Output
(
"Y"
).
front
())
->
GetMutable
<
Tensor
>
();
param_
.
is_test
=
op_desc
.
GetAttr
<
bool
>
(
"is_test"
);
param_
.
use_global_stats
=
op_desc
.
GetAttr
<
bool
>
(
"use_global_stats"
);
if
(
!
param_
.
is_test
)
{
param_
.
mean_out
=
scope
->
FindVar
(
op_desc
.
Output
(
"MeanOut"
).
front
())
->
GetMutable
<
Tensor
>
();
param_
.
variance_out
=
scope
->
FindVar
(
op_desc
.
Output
(
"VarianceOut"
).
front
())
->
GetMutable
<
Tensor
>
();
param_
.
saved_mean
=
scope
->
FindVar
(
op_desc
.
Output
(
"SavedMean"
).
front
())
->
GetMutable
<
Tensor
>
();
param_
.
saved_variance
=
scope
->
FindVar
(
op_desc
.
Output
(
"SavedVariance"
).
front
())
->
GetMutable
<
Tensor
>
();
}
param_
.
epsilon
=
op_desc
.
GetAttr
<
float
>
(
"epsilon"
);
param_
.
momentum
=
op_desc
.
GetAttr
<
float
>
(
"momentum"
);
std
::
string
data_layout
=
op_desc
.
GetAttr
<
std
::
string
>
(
"data_layout"
);
CHECK_EQ
(
data_layout
,
"NCHW"
)
<<
"TODO(hong19860320): Only support NCHW."
;
// param_.data_layout = StringToDataLayout(data_layout);
return
true
;
}
}
// namespace operators
}
// namespace lite
}
// namespace paddle
REGISTER_LITE_OP
(
batch_norm
,
paddle
::
lite
::
operators
::
BatchNormOp
);
paddle/fluid/lite/operators/batch_norm_op.h
0 → 100644
浏览文件 @
d020229c
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/utils/all.h"
namespace
paddle
{
namespace
lite
{
namespace
operators
{
class
BatchNormOp
:
public
OpLite
{
public:
BatchNormOp
()
{}
explicit
BatchNormOp
(
const
std
::
string
&
op_type
)
:
OpLite
(
op_type
)
{}
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"batch_norm"
;
}
private:
mutable
BatchNormParam
param_
;
};
}
// namespace operators
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/operators/batch_norm_op_test.cc
0 → 100644
浏览文件 @
d020229c
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/batch_norm_op.h"
#include <gtest/gtest.h>
#include "paddle/fluid/lite/core/op_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
operators
{
TEST
(
batch_norm_op_lite
,
test
)
{
// prepare variables
Scope
scope
;
auto
*
x
=
scope
.
Var
(
"x"
)
->
GetMutable
<
Tensor
>
();
auto
*
scale
=
scope
.
Var
(
"scale"
)
->
GetMutable
<
Tensor
>
();
auto
*
bias
=
scope
.
Var
(
"bias"
)
->
GetMutable
<
Tensor
>
();
auto
*
mean
=
scope
.
Var
(
"mean"
)
->
GetMutable
<
Tensor
>
();
auto
*
variance
=
scope
.
Var
(
"variance"
)
->
GetMutable
<
Tensor
>
();
auto
*
y
=
scope
.
Var
(
"y"
)
->
GetMutable
<
Tensor
>
();
x
->
Resize
({
2
,
32
,
10
,
20
});
auto
x_dims
=
x
->
dims
();
const
int64_t
channel_size
=
x_dims
[
1
];
// NCHW
scale
->
Resize
({
channel_size
});
bias
->
Resize
({
channel_size
});
mean
->
Resize
({
channel_size
});
variance
->
Resize
(
DDim
({
channel_size
}));
// prepare op desc
cpp
::
OpDesc
desc
;
desc
.
SetType
(
"batch_norm"
);
desc
.
SetInput
(
"X"
,
{
"x"
});
desc
.
SetInput
(
"Scale"
,
{
"scale"
});
desc
.
SetInput
(
"Bias"
,
{
"bias"
});
desc
.
SetInput
(
"Mean"
,
{
"mean"
});
desc
.
SetInput
(
"Variance"
,
{
"variance"
});
desc
.
SetOutput
(
"Y"
,
{
"y"
});
desc
.
SetAttr
(
"is_test"
,
true
);
desc
.
SetAttr
(
"use_global_stats"
,
false
);
desc
.
SetAttr
(
"epsilon"
,
1e-5
f
);
desc
.
SetAttr
(
"momentum"
,
0.9
f
);
desc
.
SetAttr
(
"data_layout"
,
std
::
string
(
"NCHW"
));
BatchNormOp
batch_norm
(
"batch_norm"
);
batch_norm
.
SetValidPlaces
({
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}});
batch_norm
.
Attach
(
desc
,
&
scope
);
batch_norm
.
CheckShape
();
batch_norm
.
InferShape
();
// check output dims
auto
y_dims
=
y
->
dims
();
CHECK_EQ
(
y_dims
.
size
(),
x_dims
.
size
());
for
(
size_t
i
=
0
;
i
<
y_dims
.
size
();
i
++
)
{
CHECK_EQ
(
y_dims
[
i
],
x_dims
[
i
]);
}
}
TEST
(
batch_norm_op_lite
,
test_enable_is_test
)
{
// prepare variables
Scope
scope
;
auto
*
x
=
scope
.
Var
(
"x"
)
->
GetMutable
<
Tensor
>
();
auto
*
scale
=
scope
.
Var
(
"scale"
)
->
GetMutable
<
Tensor
>
();
auto
*
bias
=
scope
.
Var
(
"bias"
)
->
GetMutable
<
Tensor
>
();
auto
*
mean
=
scope
.
Var
(
"mean"
)
->
GetMutable
<
Tensor
>
();
auto
*
variance
=
scope
.
Var
(
"variance"
)
->
GetMutable
<
Tensor
>
();
auto
*
y
=
scope
.
Var
(
"y"
)
->
GetMutable
<
Tensor
>
();
auto
*
mean_out
=
scope
.
Var
(
"mean_out"
)
->
GetMutable
<
Tensor
>
();
auto
*
variance_out
=
scope
.
Var
(
"variance_out"
)
->
GetMutable
<
Tensor
>
();
auto
*
saved_mean
=
scope
.
Var
(
"saved_mean"
)
->
GetMutable
<
Tensor
>
();
auto
*
saved_variance
=
scope
.
Var
(
"saved_variance"
)
->
GetMutable
<
Tensor
>
();
x
->
Resize
({
2
,
32
,
10
,
20
});
auto
x_dims
=
x
->
dims
();
const
int64_t
channel_size
=
x_dims
[
1
];
// NCHW
scale
->
Resize
({
channel_size
});
bias
->
Resize
({
channel_size
});
mean
->
Resize
({
channel_size
});
variance
->
Resize
({
channel_size
});
// prepare op desc
cpp
::
OpDesc
desc
;
desc
.
SetType
(
"batch_norm"
);
desc
.
SetInput
(
"X"
,
{
"x"
});
desc
.
SetInput
(
"Scale"
,
{
"scale"
});
desc
.
SetInput
(
"Bias"
,
{
"bias"
});
desc
.
SetInput
(
"Mean"
,
{
"mean"
});
desc
.
SetInput
(
"Variance"
,
{
"variance"
});
desc
.
SetOutput
(
"Y"
,
{
"y"
});
desc
.
SetOutput
(
"MeanOut"
,
{
"mean_out"
});
desc
.
SetOutput
(
"VarianceOut"
,
{
"variance_out"
});
desc
.
SetOutput
(
"SavedMean"
,
{
"saved_mean"
});
desc
.
SetOutput
(
"SavedVariance"
,
{
"saved_variance"
});
desc
.
SetAttr
(
"is_test"
,
false
);
desc
.
SetAttr
(
"use_global_stats"
,
false
);
desc
.
SetAttr
(
"epsilon"
,
1e-5
f
);
desc
.
SetAttr
(
"momentum"
,
0.9
f
);
desc
.
SetAttr
(
"data_layout"
,
std
::
string
(
"NCHW"
));
BatchNormOp
batch_norm
(
"batch_norm"
);
batch_norm
.
SetValidPlaces
({
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}});
batch_norm
.
Attach
(
desc
,
&
scope
);
batch_norm
.
CheckShape
();
batch_norm
.
InferShape
();
// check output dims
auto
y_dims
=
y
->
dims
();
CHECK_EQ
(
y_dims
.
size
(),
x_dims
.
size
());
for
(
size_t
i
=
0
;
i
<
y_dims
.
size
();
i
++
)
{
CHECK_EQ
(
y_dims
[
i
],
x_dims
[
i
]);
}
auto
mean_out_dims
=
mean_out
->
dims
();
auto
variance_out_dims
=
variance_out
->
dims
();
auto
saved_mean_dims
=
saved_mean
->
dims
();
auto
saved_variance_dims
=
saved_variance
->
dims
();
CHECK_EQ
(
mean_out_dims
.
size
(),
1UL
);
CHECK_EQ
(
variance_out_dims
.
size
(),
1UL
);
CHECK_EQ
(
saved_mean_dims
.
size
(),
1UL
);
CHECK_EQ
(
saved_variance_dims
.
size
(),
1UL
);
CHECK_EQ
(
mean_out_dims
[
0
],
channel_size
);
CHECK_EQ
(
variance_out_dims
[
0
],
channel_size
);
CHECK_EQ
(
saved_mean_dims
[
0
],
channel_size
);
CHECK_EQ
(
saved_variance_dims
[
0
],
channel_size
);
}
}
// namespace operators
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/operators/op_params.h
浏览文件 @
d020229c
...
@@ -146,6 +146,25 @@ struct ConvParam {
...
@@ -146,6 +146,25 @@ struct ConvParam {
std
::
string
data_format
{
"Anylayout"
};
std
::
string
data_format
{
"Anylayout"
};
};
};
// For BatchNorm op
struct
BatchNormParam
{
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
bias
{};
lite
::
Tensor
*
scale
{};
lite
::
Tensor
*
mean
{};
lite
::
Tensor
*
variance
{};
lite
::
Tensor
*
y
{};
lite
::
Tensor
*
mean_out
{};
lite
::
Tensor
*
variance_out
{};
lite
::
Tensor
*
saved_mean
{};
lite
::
Tensor
*
saved_variance
{};
bool
is_test
{
true
};
bool
use_global_stats
{
false
};
float
epsilon
;
float
momentum
;
DataLayoutType
data_layout
{
DATALAYOUT
(
kNCHW
)};
};
// For Pooling op
// For Pooling op
struct
PoolParam
{
struct
PoolParam
{
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
x
{};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录