Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
11a2a2a1
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
11a2a2a1
编写于
6月 15, 2019
作者:
N
nhzlx
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'hongming/arm-fix' of
http://10.87.145.36/inference/paddlelite
into xzl/incubate/lite
上级
ec133f19
c07b215d
变更
20
显示空白变更内容
内联
并排
Showing
20 changed file
with
273 addition
and
85 deletion
+273
-85
.gitlab-ci.yml
.gitlab-ci.yml
+14
-4
CMakeLists.txt
CMakeLists.txt
+1
-0
paddle/fluid/lite/api/cxx_api_bin.cc
paddle/fluid/lite/api/cxx_api_bin.cc
+27
-13
paddle/fluid/lite/arm/math/elementwise.cc
paddle/fluid/lite/arm/math/elementwise.cc
+71
-8
paddle/fluid/lite/arm/math/elementwise.h
paddle/fluid/lite/arm/math/elementwise.h
+4
-0
paddle/fluid/lite/core/mir/passes.h
paddle/fluid/lite/core/mir/passes.h
+2
-0
paddle/fluid/lite/core/naive_test_model.py
paddle/fluid/lite/core/naive_test_model.py
+6
-6
paddle/fluid/lite/core/optimizer.h
paddle/fluid/lite/core/optimizer.h
+1
-1
paddle/fluid/lite/kernels/arm/conv_compute.cc
paddle/fluid/lite/kernels/arm/conv_compute.cc
+4
-4
paddle/fluid/lite/kernels/arm/conv_compute_test.cc
paddle/fluid/lite/kernels/arm/conv_compute_test.cc
+4
-3
paddle/fluid/lite/kernels/arm/elementwise_add_compute.cc
paddle/fluid/lite/kernels/arm/elementwise_add_compute.cc
+25
-2
paddle/fluid/lite/kernels/arm/elementwise_add_compute_test.cc
...le/fluid/lite/kernels/arm/elementwise_add_compute_test.cc
+82
-25
paddle/fluid/lite/kernels/arm/pool_compute.cc
paddle/fluid/lite/kernels/arm/pool_compute.cc
+1
-1
paddle/fluid/lite/kernels/arm/pool_compute_test.cc
paddle/fluid/lite/kernels/arm/pool_compute_test.cc
+1
-1
paddle/fluid/lite/kernels/arm/relu_compute.h
paddle/fluid/lite/kernels/arm/relu_compute.h
+2
-0
paddle/fluid/lite/operators/batch_norm_op_test.cc
paddle/fluid/lite/operators/batch_norm_op_test.cc
+2
-2
paddle/fluid/lite/operators/conv_op.h
paddle/fluid/lite/operators/conv_op.h
+9
-6
paddle/fluid/lite/operators/pool_op.h
paddle/fluid/lite/operators/pool_op.h
+13
-5
paddle/fluid/lite/operators/pool_op_test.cc
paddle/fluid/lite/operators/pool_op_test.cc
+3
-3
paddle/fluid/lite/operators/split_op.cc
paddle/fluid/lite/operators/split_op.cc
+1
-1
未找到文件。
.gitlab-ci.yml
浏览文件 @
11a2a2a1
...
...
@@ -9,6 +9,8 @@ stages:
-
build_mobile
check:prebuilt:
tags
:
-
lite
stage
:
ci
script
:
#- pip3 install pre-commit
...
...
@@ -24,17 +26,21 @@ check:prebuilt:
-
/root/.cache
build:server:
tags
:
-
lite
image
:
$SERVER_LITE_DOCKER_IMAGE
stage
:
build_server
cache
:
key
:
server_thirdparty
paths
:
-
build/third_party
-
/root/.ccache
script
:
#- export http_proxy=http://172.19.57.45:3128
#- export https_proxy=http://172.19.57.45:3128
-
export http_proxy=http://agent.baidu.com:8118
-
export https_proxy=http://agent.baidu.com:8118
-
apt install ccache
-
export http_proxy=http://172.19.57.45:3128
-
export https_proxy=http://172.19.57.45:3128
#- export http_proxy=http://agent.baidu.com:8118
#- export https_proxy=http://agent.baidu.com:8118
-
mkdir -p build
-
cd build
-
../paddle/fluid/lite/tools/build.sh cmake_x86
...
...
@@ -49,6 +55,8 @@ build:server:
-
check:prebuilt
build:mobile:
tags
:
-
lite
stage
:
build_mobile
image
:
$MOBILE_LITE_DOCKER_IMAGE
cache
:
...
...
@@ -56,7 +64,9 @@ build:mobile:
paths
:
-
$MOBILE_LITE_CACHE0
-
$MOBILE_LITE_CACHE1
-
/root/.ccache
script
:
-
apt install ccache
-
export http_proxy=http://172.19.57.45:3128
-
export https_proxy=http://172.19.57.45:3128
-
./paddle/fluid/lite/tools/build.sh build_test_arm
...
...
CMakeLists.txt
浏览文件 @
11a2a2a1
...
...
@@ -166,6 +166,7 @@ if (WITH_LITE AND LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
#include(external/zlib) # download, build, install gtest
include
(
external/protobuf
)
# download, build, install protobuf
include
(
external/eigen
)
# download eigen3
include
(
ccache
)
# set ccache for compilation
include
(
generic
)
# simplify cmake module
include
(
configure
)
# add paddle env configuration
...
...
paddle/fluid/lite/api/cxx_api_bin.cc
浏览文件 @
11a2a2a1
...
...
@@ -13,17 +13,25 @@
// limitations under the License.
#include "paddle/fluid/lite/api/cxx_api.h"
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
#include <chrono>
#include "paddle/fluid/lite/core/mir/passes.h"
#endif
#include "paddle/fluid/lite/core/op_registry.h"
namespace
paddle
{
namespace
lite
{
void
Run
(
const
char
*
model_dir
)
{
using
Time
=
decltype
(
std
::
chrono
::
high_resolution_clock
::
now
());
Time
time
()
{
return
std
::
chrono
::
high_resolution_clock
::
now
();
}
double
time_diff
(
Time
t1
,
Time
t2
)
{
typedef
std
::
chrono
::
microseconds
ms
;
auto
diff
=
t2
-
t1
;
ms
counter
=
std
::
chrono
::
duration_cast
<
ms
>
(
diff
);
return
counter
.
count
()
/
1000.0
;
}
void
Run
(
const
char
*
model_dir
,
int
repeat
)
{
#ifdef LITE_WITH_ARM
DeviceInfo
::
Init
();
#endif
lite
::
ExecutorLite
predictor
;
std
::
vector
<
Place
>
valid_places
({
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)},
Place
{
TARGET
(
kARM
),
PRECISION
(
kFloat
)}});
...
...
@@ -32,13 +40,19 @@ void Run(const char* model_dir) {
valid_places
);
auto
*
input_tensor
=
predictor
.
GetInput
(
0
);
input_tensor
->
Resize
(
DDim
(
std
::
vector
<
DDim
::
value_type
>
({
3
,
224
,
224
})));
input_tensor
->
Resize
(
DDim
(
std
::
vector
<
DDim
::
value_type
>
({
1
,
3
,
224
,
224
})));
auto
*
data
=
input_tensor
->
mutable_data
<
float
>
();
for
(
int
i
=
0
;
i
<
3
*
224
*
224
;
i
++
)
{
data
[
i
]
=
i
;
for
(
int
i
=
0
;
i
<
input_tensor
->
dims
().
production
()
;
i
++
)
{
data
[
i
]
=
1
;
}
predictor
.
Run
();
for
(
int
i
=
0
;
i
<
10
;
i
++
)
predictor
.
Run
();
auto
time1
=
time
();
for
(
int
i
=
0
;
i
<
repeat
;
i
++
)
predictor
.
Run
();
auto
time2
=
time
();
std
::
cout
<<
" predict cost: "
<<
time_diff
(
time1
,
time2
)
/
repeat
<<
"ms"
<<
std
::
endl
;
auto
*
out
=
predictor
.
GetOutput
(
0
);
LOG
(
INFO
)
<<
out
<<
" memory size "
<<
out
->
data_size
();
...
...
@@ -53,7 +67,7 @@ void Run(const char* model_dir) {
int
main
(
int
argc
,
char
**
argv
)
{
CHECK_EQ
(
argc
,
2
)
<<
"usage: ./cmd <model_dir>"
;
paddle
::
lite
::
Run
(
argv
[
1
]);
paddle
::
lite
::
Run
(
argv
[
1
]
,
1
);
return
0
;
}
...
...
@@ -66,7 +80,7 @@ USE_LITE_OP(fetch);
USE_LITE_OP
(
io_copy
);
USE_LITE_OP
(
conv2d
);
//
USE_LITE_OP(batch_norm);
USE_LITE_OP
(
batch_norm
);
USE_LITE_OP
(
relu
);
USE_LITE_OP
(
depthwise_conv2d
);
USE_LITE_OP
(
pool2d
);
...
...
@@ -85,7 +99,7 @@ USE_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL
(
batch_norm
,
kARM
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
relu
,
kARM
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
depthwise_conv2d
,
kARM
,
kFloat
,
kNCHW
,
def
);
//
USE_LITE_KERNEL(pool2d, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL
(
pool2d
,
kARM
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
elementwise_add
,
kARM
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
softmax
,
kARM
,
kFloat
,
kNCHW
,
def
);
...
...
paddle/fluid/lite/arm/math/elementwise.cc
浏览文件 @
11a2a2a1
...
...
@@ -41,15 +41,15 @@ void elementwise_add<float>(const float* dinx, const float* diny, float* dout,
float32x4_t
diny2
=
vld1q_f32
(
diny_ptr
+
8
);
float32x4_t
diny3
=
vld1q_f32
(
diny_ptr
+
12
);
float32x4_t
vsum
0
=
vaddq_f32
(
dinx0
,
diny0
);
float32x4_t
vsum
1
=
vaddq_f32
(
dinx1
,
diny1
);
float32x4_t
vsum
2
=
vaddq_f32
(
dinx2
,
diny2
);
float32x4_t
vsum
3
=
vaddq_f32
(
dinx3
,
diny3
);
dinx
0
=
vaddq_f32
(
dinx0
,
diny0
);
dinx
1
=
vaddq_f32
(
dinx1
,
diny1
);
dinx
2
=
vaddq_f32
(
dinx2
,
diny2
);
dinx
3
=
vaddq_f32
(
dinx3
,
diny3
);
vst1q_f32
(
dout_ptr
,
vsum
0
);
vst1q_f32
(
dout_ptr
+
4
,
vsum
1
);
vst1q_f32
(
dout_ptr
+
8
,
vsum
2
);
vst1q_f32
(
dout_ptr
+
12
,
vsum
3
);
vst1q_f32
(
dout_ptr
,
dinx
0
);
vst1q_f32
(
dout_ptr
+
4
,
dinx
1
);
vst1q_f32
(
dout_ptr
+
8
,
dinx
2
);
vst1q_f32
(
dout_ptr
+
12
,
dinx
3
);
}
if
(
remain
>
0
)
{
const
float
*
dinx_ptr
=
dinx
+
(
cnt
<<
4
);
...
...
@@ -64,6 +64,69 @@ void elementwise_add<float>(const float* dinx, const float* diny, float* dout,
}
}
template
<
>
void
elementwise_add_axis
<
float
>
(
const
float
*
dinx
,
const
float
*
diny
,
float
*
dout
,
int
batch
,
int
channels
,
int
num
)
{
#pragma omp parallel for collapse(2)
for
(
int
i
=
0
;
i
<
batch
;
++
i
)
{
for
(
int
j
=
0
;
j
<
channels
;
++
j
)
{
int
offset
=
(
i
*
channels
+
j
)
*
num
;
const
float
*
din_ptr
=
dinx
+
offset
;
const
float
diny_data
=
diny
[
j
];
float
*
dout_ptr
=
dout
+
offset
;
int
cnt
=
num
>>
4
;
int
remain
=
num
%
16
;
float32x4_t
rb
=
vdupq_n_f32
(
diny_data
);
for
(
int
k
=
0
;
k
<
cnt
;
++
k
)
{
float32x4_t
din0
=
vld1q_f32
(
din_ptr
);
float32x4_t
din1
=
vld1q_f32
(
din_ptr
+
4
);
float32x4_t
din2
=
vld1q_f32
(
din_ptr
+
8
);
float32x4_t
din3
=
vld1q_f32
(
din_ptr
+
12
);
din0
=
vaddq_f32
(
din0
,
rb
);
din1
=
vaddq_f32
(
din1
,
rb
);
din2
=
vaddq_f32
(
din2
,
rb
);
din3
=
vaddq_f32
(
din3
,
rb
);
vst1q_f32
(
dout_ptr
,
din0
);
vst1q_f32
(
dout_ptr
+
4
,
din1
);
vst1q_f32
(
dout_ptr
+
8
,
din2
);
vst1q_f32
(
dout_ptr
+
12
,
din3
);
din_ptr
+=
16
;
dout_ptr
+=
16
;
}
if
(
remain
>=
8
)
{
float32x4_t
din0
=
vld1q_f32
(
din_ptr
);
float32x4_t
din1
=
vld1q_f32
(
din_ptr
+
4
);
din0
=
vaddq_f32
(
din0
,
rb
);
din1
=
vaddq_f32
(
din1
,
rb
);
vst1q_f32
(
dout_ptr
,
din0
);
vst1q_f32
(
dout_ptr
+
4
,
din1
);
din_ptr
+=
8
;
dout_ptr
+=
8
;
remain
-=
8
;
}
if
(
remain
>=
4
)
{
float32x4_t
din0
=
vld1q_f32
(
din_ptr
);
din0
=
vaddq_f32
(
din0
,
rb
);
vst1q_f32
(
dout_ptr
,
din0
);
din_ptr
+=
4
;
dout_ptr
+=
4
;
remain
-=
4
;
}
if
(
remain
>
0
)
{
for
(
int
p
=
0
;
p
<
remain
;
p
++
)
{
*
dout_ptr
=
*
din_ptr
+
diny_data
;
dout_ptr
++
;
din_ptr
++
;
}
}
}
}
}
}
// namespace math
}
// namespace arm
}
// namespace lite
...
...
paddle/fluid/lite/arm/math/elementwise.h
浏览文件 @
11a2a2a1
...
...
@@ -22,6 +22,10 @@ namespace math {
template
<
typename
T
>
void
elementwise_add
(
const
T
*
dinx
,
const
T
*
diny
,
T
*
dout
,
int
num
);
template
<
typename
T
>
void
elementwise_add_axis
(
const
T
*
dinx
,
const
T
*
diny
,
T
*
dout
,
int
batch
,
int
channels
,
int
num
);
}
// namespace math
}
// namespace arm
}
// namespace lite
...
...
paddle/fluid/lite/core/mir/passes.h
浏览文件 @
11a2a2a1
...
...
@@ -21,6 +21,7 @@ namespace mir {} // namespace mir
}
// namespace lite
}
// namespace paddle
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
USE_MIR_PASS
(
demo
);
USE_MIR_PASS
(
lite_fc_fuse_pass
);
USE_MIR_PASS
(
lite_conv_elementwise_add_act_fuse_pass
);
...
...
@@ -30,6 +31,7 @@ USE_MIR_PASS(type_target_transform_pass);
USE_MIR_PASS
(
generate_program_pass
);
USE_MIR_PASS
(
io_copy_kernel_pick_pass
);
USE_MIR_PASS
(
argument_type_display_pass
);
#endif
USE_MIR_PASS
(
runtime_context_assign_pass
);
USE_MIR_PASS
(
lite_conv_bn_fuse_pass
);
USE_MIR_PASS
(
graph_visualze
);
paddle/fluid/lite/core/naive_test_model.py
浏览文件 @
11a2a2a1
...
...
@@ -18,10 +18,10 @@ import numpy as np
import
paddle.fluid
as
fluid
from
paddle.fluid.backward
import
append_backward
a
=
fluid
.
layers
.
data
(
name
=
"a"
,
shape
=
[
100
],
dtype
=
'float32'
)
label
=
fluid
.
layers
.
data
(
name
=
"label"
,
shape
=
[
10
0
],
dtype
=
'float32'
)
a
=
fluid
.
layers
.
data
(
name
=
"a"
,
shape
=
[
2
],
dtype
=
'float32'
)
label
=
fluid
.
layers
.
data
(
name
=
"label"
,
shape
=
[
10
],
dtype
=
'float32'
)
a1
=
fluid
.
layers
.
fc
(
input
=
a
,
size
=
500
,
act
=
None
,
bias_attr
=
False
)
a1
=
fluid
.
layers
.
fc
(
input
=
a
,
size
=
3
,
act
=
None
,
bias_attr
=
False
)
cost
=
fluid
.
layers
.
square_error_cost
(
a1
,
label
)
avg_cost
=
fluid
.
layers
.
mean
(
cost
)
...
...
@@ -36,7 +36,7 @@ exe.run(fluid.default_startup_program())
with
open
(
'startup_program.pb'
,
'wb'
)
as
f
:
f
.
write
(
fluid
.
default_startup_program
().
desc
.
serialize_to_string
())
data_1
=
np
.
array
(
numpy
.
random
.
random
([
100
,
100
]),
dtype
=
'float32'
)
#
data_1 = np.array(numpy.random.random([100, 100]), dtype='float32')
#fluid.default_main_program().desc.
...
...
@@ -50,7 +50,7 @@ with open('main_program.pb', 'wb') as f:
#outs = exe.run(program=prog, feed={'a':data_1, }, fetch_list=[cost])
sys
.
exit
(
0
)
#
sys.exit(0)
fluid
.
io
.
save_inference_model
(
"./model2"
,
[
a
.
name
],
[
a1
],
exe
)
print
(
numpy
.
array
(
outs
))
#
print(numpy.array(outs))
paddle/fluid/lite/core/optimizer.h
浏览文件 @
11a2a2a1
...
...
@@ -51,8 +51,8 @@ class Optimizer {
"lite_conv_bn_fuse_pass"
,
//
"lite_conv_elementwise_add_act_fuse_pass"
,
//
"lite_fc_fuse_pass"
,
//
"static_kernel_pick_pass"
,
//
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
"static_kernel_pick_pass"
,
//
"variable_place_inference_pass"
,
//
"argument_type_display_pass"
,
//
"type_target_transform_pass"
,
//
...
...
paddle/fluid/lite/kernels/arm/conv_compute.cc
浏览文件 @
11a2a2a1
...
...
@@ -100,15 +100,15 @@ void ConvCompute::Run() {
REGISTER_LITE_KERNEL
(
conv2d
,
kARM
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
arm
::
ConvCompute
,
def
)
.
BindInput
(
"Input"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindInput
(
"Bias"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
//
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.
BindInput
(
"Filter"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindOutput
(
"Out
put
"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
Finalize
();
REGISTER_LITE_KERNEL
(
depthwise_conv2d
,
kARM
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
arm
::
ConvCompute
,
def
)
.
BindInput
(
"Input"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindInput
(
"Bias"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
//
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.
BindInput
(
"Filter"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindOutput
(
"Out
put
"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
Finalize
();
paddle/fluid/lite/kernels/arm/conv_compute_test.cc
浏览文件 @
11a2a2a1
...
...
@@ -45,7 +45,7 @@ void conv_compute_ref(const operators::ConvParam& param) {
bias_data
=
param
.
bias
->
mutable_data
<
float
>
();
}
bool
flag_bias
=
bias_data
!=
nullptr
;
bool
flag_relu
=
false
;
// TODO(hong19860320) param.relu
bool
flag_relu
=
param
.
fuse_relu
;
int
num
=
input_dims
[
0
];
int
chout
=
output_dims
[
1
];
...
...
@@ -183,7 +183,8 @@ TEST(conv_arm, compute) {
auto
*
filter_data
=
filter
.
mutable_data
<
float
>
();
auto
*
output_data
=
output
.
mutable_data
<
float
>
();
for
(
int
i
=
0
;
i
<
input
.
dims
().
production
();
i
++
)
{
input_data
[
i
]
=
static_cast
<
float
>
(
i
%
128
);
float
sign
=
i
%
3
==
0
?
-
1.0
f
:
1.0
f
;
input_data
[
i
]
=
sign
*
static_cast
<
float
>
(
i
%
128
);
}
for
(
int
i
=
0
;
i
<
filter
.
dims
().
production
();
i
++
)
{
filter_data
[
i
]
=
...
...
@@ -208,7 +209,7 @@ TEST(conv_arm, compute) {
}
param
.
bias
=
&
bias
;
}
// TODO(hong19860320) param.
relu = flag_relu;
param
.
fuse_
relu
=
flag_relu
;
param
.
paddings
=
std
::
vector
<
int
>
({
padding
,
padding
});
param
.
strides
=
std
::
vector
<
int
>
({
stride
,
stride
});
param
.
dilations
=
...
...
paddle/fluid/lite/kernels/arm/elementwise_add_compute.cc
浏览文件 @
11a2a2a1
...
...
@@ -25,8 +25,31 @@ void ElementwiseAddCompute::Run() {
const
float
*
x_data
=
param
.
X
->
data
<
float
>
();
const
float
*
y_data
=
param
.
Y
->
data
<
float
>
();
float
*
out_data
=
param
.
Out
->
mutable_data
<
float
>
();
int
n
=
param
.
X
->
dims
().
production
();
lite
::
arm
::
math
::
elementwise_add
(
x_data
,
y_data
,
out_data
,
n
);
int
axis
=
param
.
axis
;
auto
x_dims
=
param
.
X
->
dims
();
auto
y_dims
=
param
.
Y
->
dims
();
if
(
axis
<
0
)
{
axis
=
x_dims
.
size
()
-
y_dims
.
size
();
}
if
(
x_dims
.
size
()
==
y_dims
.
size
())
{
lite
::
arm
::
math
::
elementwise_add
(
x_data
,
y_data
,
out_data
,
x_dims
.
production
());
}
else
{
int
batch
=
1
;
int
channels
=
1
;
int
num
=
1
;
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
batch
*=
x_dims
[
i
];
}
for
(
int
i
=
0
;
i
<
y_dims
.
size
();
++
i
)
{
channels
*=
y_dims
[
i
];
}
for
(
int
i
=
y_dims
.
size
()
+
axis
;
i
<
x_dims
.
size
();
++
i
)
{
num
*=
x_dims
[
i
];
}
lite
::
arm
::
math
::
elementwise_add_axis
(
x_data
,
y_data
,
out_data
,
batch
,
channels
,
num
);
}
}
}
// namespace arm
...
...
paddle/fluid/lite/kernels/arm/elementwise_add_compute_test.cc
浏览文件 @
11a2a2a1
...
...
@@ -41,40 +41,97 @@ void elementwise_add_compute_ref(const operators::ElementwiseParam& param) {
const
dtype
*
x_data
=
param
.
X
->
data
<
const
dtype
>
();
const
dtype
*
y_data
=
param
.
Y
->
data
<
const
dtype
>
();
dtype
*
out_data
=
param
.
Out
->
mutable_data
<
dtype
>
();
DDim
dim
=
param
.
X
->
dims
();
ASSERT_EQ
(
dim
.
data
(),
param
.
Out
->
dims
().
data
());
for
(
int
i
=
0
;
i
<
dim
.
production
();
i
++
)
{
out_data
[
i
]
=
x_data
[
i
]
+
y_data
[
i
];
auto
x_dims
=
param
.
X
->
dims
();
auto
y_dims
=
param
.
Y
->
dims
();
int
axis
=
param
.
axis
;
if
(
axis
<
0
)
{
axis
=
x_dims
.
size
()
-
y_dims
.
size
();
}
int
batch
=
1
;
int
channels
=
1
;
int
num
=
1
;
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
batch
*=
x_dims
[
i
];
}
for
(
int
i
=
0
;
i
<
y_dims
.
size
();
++
i
)
{
channels
*=
y_dims
[
i
];
}
for
(
int
i
=
y_dims
.
size
()
+
axis
;
i
<
x_dims
.
size
();
++
i
)
{
num
*=
x_dims
[
i
];
}
for
(
int
i
=
0
;
i
<
batch
;
++
i
)
{
for
(
int
j
=
0
;
j
<
channels
;
++
j
)
{
int
offset
=
(
i
*
channels
+
j
)
*
num
;
const
dtype
*
din_ptr
=
x_data
+
offset
;
const
dtype
diny_data
=
y_data
[
j
];
dtype
*
dout_ptr
=
out_data
+
offset
;
for
(
int
k
=
0
;
k
<
num
;
++
k
)
{
*
dout_ptr
=
*
din_ptr
+
diny_data
;
dout_ptr
++
;
din_ptr
++
;
}
}
}
}
TEST
(
elementwise_add
,
compute
)
{
ElementwiseAddCompute
elementwise_add
;
operators
::
ElementwiseParam
param
;
lite
::
Tensor
x
,
y
,
output
,
output_ref
;
for
(
auto
n
:
{
1
,
3
,
4
,
11
})
{
for
(
auto
c
:
{
1
,
3
,
4
,
11
})
{
for
(
auto
h
:
{
1
,
3
,
4
,
11
})
{
for
(
auto
w
:
{
1
,
3
,
4
,
11
})
{
for
(
auto
axis
:
{
-
1
,
0
,
1
,
2
,
3
})
{
for
(
auto
yd
:
{
std
::
vector
<
int64_t
>
({
n
}),
std
::
vector
<
int64_t
>
({
c
}),
std
::
vector
<
int64_t
>
({
h
}),
std
::
vector
<
int64_t
>
({
w
}),
std
::
vector
<
int64_t
>
({
n
,
c
}),
std
::
vector
<
int64_t
>
({
c
,
h
}),
std
::
vector
<
int64_t
>
({
h
,
w
}),
std
::
vector
<
int64_t
>
({
n
,
c
,
h
}),
std
::
vector
<
int64_t
>
({
c
,
h
,
w
}),
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
})})
{
auto
x_dim
=
DDim
(
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
}));
auto
y_dim
=
DDim
(
yd
);
int
axis_t
=
axis
<
0
?
x_dim
.
size
()
-
y_dim
.
size
()
:
axis
;
if
(
axis_t
+
y_dim
.
size
()
>
4
)
continue
;
bool
flag
=
false
;
for
(
int
i
=
0
;
i
<
y_dim
.
size
();
i
++
)
{
if
(
x_dim
[
i
+
axis_t
]
!=
y_dim
[
i
])
flag
=
true
;
}
if
(
flag
)
continue
;
lite
::
Tensor
x
,
y
,
out
,
out_ref
;
x
.
Resize
(
DDim
(
std
::
vector
<
int64_t
>
({
2
,
3
,
4
,
5
})));
y
.
Resize
(
DDim
(
std
::
vector
<
int64_t
>
({
2
,
3
,
4
,
5
})));
out
.
Resize
(
DDim
(
std
::
vector
<
int64_t
>
({
2
,
3
,
4
,
5
})));
out_ref
.
Resize
(
DDim
(
std
::
vector
<
int64_t
>
({
2
,
3
,
4
,
5
})));
x
.
Resize
(
x_dim
);
y
.
Resize
(
y_dim
);
output
.
Resize
(
x_dim
);
output_ref
.
Resize
(
x_dim
);
auto
*
x_data
=
x
.
mutable_data
<
float
>
();
auto
*
y_data
=
y
.
mutable_data
<
float
>
();
auto
*
out_data
=
out
.
mutable_data
<
float
>
();
auto
*
out_ref_data
=
out_ref
.
mutable_data
<
float
>
();
for
(
int
i
=
0
;
i
<
x
.
dims
().
production
();
i
++
)
{
x_data
[
i
]
=
y_data
[
i
]
=
i
;
auto
*
output_data
=
output
.
mutable_data
<
float
>
();
auto
*
output_ref_data
=
output_ref
.
mutable_data
<
float
>
();
for
(
int
i
=
0
;
i
<
x_dim
.
production
();
i
++
)
{
x_data
[
i
]
=
i
;
}
for
(
int
i
=
0
;
i
<
y_dim
.
production
();
i
++
)
{
y_data
[
i
]
=
i
;
}
param
.
X
=
&
x
;
param
.
Y
=
&
y
;
param
.
Out
=
&
out
;
param
.
axis
=
axis
;
param
.
Out
=
&
output
;
elementwise_add
.
SetParam
(
param
);
elementwise_add
.
Run
();
param
.
Out
=
&
out_ref
;
param
.
Out
=
&
output_ref
;
elementwise_add_compute_ref
<
float
>
(
param
);
for
(
int
i
=
0
;
i
<
out
.
dims
().
production
();
i
++
)
{
EXPECT_NEAR
(
out_data
[
i
],
out_ref_data
[
i
],
1e-5
);
for
(
int
i
=
0
;
i
<
output
.
dims
().
production
();
i
++
)
{
EXPECT_NEAR
(
output_data
[
i
],
output_ref_data
[
i
],
1e-5
);
}
}
}
}
}
}
}
}
...
...
paddle/fluid/lite/kernels/arm/pool_compute.cc
浏览文件 @
11a2a2a1
...
...
@@ -163,7 +163,7 @@ PrecisionType PoolCompute::precision() const { return PRECISION(kFloat); }
}
// namespace lite
}
// namespace paddle
REGISTER_LITE_KERNEL
(
pool
,
kARM
,
kFloat
,
kNCHW
,
REGISTER_LITE_KERNEL
(
pool
2d
,
kARM
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
arm
::
PoolCompute
,
def
)
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
...
...
paddle/fluid/lite/kernels/arm/pool_compute_test.cc
浏览文件 @
11a2a2a1
...
...
@@ -272,4 +272,4 @@ TEST(pool, retrive_op) {
}
// namespace lite
}
// namespace paddle
USE_LITE_KERNEL
(
pool
,
kARM
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
pool
2d
,
kARM
,
kFloat
,
kNCHW
,
def
);
paddle/fluid/lite/kernels/arm/relu_compute.h
浏览文件 @
11a2a2a1
...
...
@@ -45,4 +45,6 @@ class ReluCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
REGISTER_LITE_KERNEL
(
relu
,
kARM
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
arm
::
ReluCompute
,
def
)
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
Finalize
();
paddle/fluid/lite/operators/batch_norm_op_test.cc
浏览文件 @
11a2a2a1
...
...
@@ -46,7 +46,7 @@ TEST(batch_norm_op_lite, test) {
desc
.
SetInput
(
"Mean"
,
{
"mean"
});
desc
.
SetInput
(
"Variance"
,
{
"variance"
});
desc
.
SetOutput
(
"Y"
,
{
"y"
});
desc
.
SetAttr
(
"is_test"
,
true
);
desc
.
SetAttr
(
"is_test"
,
static_cast
<
int
>
(
1
)
);
desc
.
SetAttr
(
"use_global_stats"
,
false
);
desc
.
SetAttr
(
"epsilon"
,
1e-5
f
);
desc
.
SetAttr
(
"momentum"
,
0.9
f
);
...
...
@@ -101,7 +101,7 @@ TEST(batch_norm_op_lite, test_enable_is_test) {
desc
.
SetOutput
(
"VarianceOut"
,
{
"variance_out"
});
desc
.
SetOutput
(
"SavedMean"
,
{
"saved_mean"
});
desc
.
SetOutput
(
"SavedVariance"
,
{
"saved_variance"
});
desc
.
SetAttr
(
"is_test"
,
false
);
desc
.
SetAttr
(
"is_test"
,
static_cast
<
int
>
(
0
)
);
desc
.
SetAttr
(
"use_global_stats"
,
false
);
desc
.
SetAttr
(
"epsilon"
,
1e-5
f
);
desc
.
SetAttr
(
"momentum"
,
0.9
f
);
...
...
paddle/fluid/lite/operators/conv_op.h
浏览文件 @
11a2a2a1
...
...
@@ -56,23 +56,26 @@ class ConvOpLite : public OpLite {
if
(
std
::
find
(
input_arg_names
.
begin
(),
input_arg_names
.
end
(),
"Bias"
)
!=
input_arg_names
.
end
())
{
auto
bias_arguments
=
op_desc
.
Input
(
"Bias"
);
if
(
bias_arguments
.
size
()
!=
0
)
{
if
(
bias_arguments
.
size
()
>
0
)
{
auto
bias_var
=
scope
->
FindVar
(
bias_arguments
.
front
());
if
(
bias_var
!=
nullptr
)
{
param_
.
bias
=
bias_var
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
bias
=
const_cast
<
lite
::
Tensor
*>
(
&
(
bias_var
->
Get
<
lite
::
Tensor
>
()));
}
}
}
if
(
std
::
find
(
input_arg_names
.
begin
(),
input_arg_names
.
end
(),
"ResidualData"
)
!=
input_arg_names
.
end
())
{
auto
res_
argument
=
op_desc
.
Input
(
"ResidualData"
);
if
(
res_
argument
.
size
()
!=
0
)
{
auto
residual_data_var
=
scope
->
FindVar
(
res_
argument
.
front
());
auto
res_
data_arguments
=
op_desc
.
Input
(
"ResidualData"
);
if
(
res_
data_arguments
.
size
()
>
0
)
{
auto
residual_data_var
=
scope
->
FindVar
(
res_
data_arguments
.
front
());
if
(
residual_data_var
!=
nullptr
)
{
param_
.
residualData
=
residual_data_var
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
residualData
=
const_cast
<
lite
::
Tensor
*>
(
&
(
residual_data_var
->
Get
<
lite
::
Tensor
>
()));
}
}
}
param_
.
fuse_relu
=
op_desc
.
GetAttr
<
bool
>
(
"fuse_relu"
);
return
true
;
}
...
...
paddle/fluid/lite/operators/pool_op.h
浏览文件 @
11a2a2a1
...
...
@@ -53,17 +53,25 @@ class PoolOpLite : public OpLite {
param_
.
strides
=
op_desc
.
GetAttr
<
std
::
vector
<
int
>>
(
"strides"
);
param_
.
paddings
=
op_desc
.
GetAttr
<
std
::
vector
<
int
>>
(
"paddings"
);
if
(
op_desc
.
HasAttr
(
"exclusive"
))
{
param_
.
exclusive
=
op_desc
.
GetAttr
<
bool
>
(
"exclusive"
);
}
if
(
op_desc
.
HasAttr
(
"adaptive"
))
{
param_
.
adaptive
=
op_desc
.
GetAttr
<
bool
>
(
"adaptive"
);
}
if
(
op_desc
.
HasAttr
(
"ceil_mode"
))
{
param_
.
ceil_mode
=
op_desc
.
GetAttr
<
bool
>
(
"ceil_mode"
);
}
if
(
op_desc
.
HasAttr
(
"use_quantizer"
))
{
param_
.
use_quantizer
=
op_desc
.
GetAttr
<
bool
>
(
"use_quantizer"
);
}
// param_.data_format = op_desc.GetAttr<bool>("data_format");
return
true
;
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"pool"
;
}
std
::
string
DebugString
()
const
override
{
return
"pool
2d
"
;
}
private:
mutable
PoolParam
param_
;
...
...
paddle/fluid/lite/operators/pool_op_test.cc
浏览文件 @
11a2a2a1
...
...
@@ -38,7 +38,7 @@ TEST(pool_op_lite, test) {
// prepare op desc
cpp
::
OpDesc
desc
;
desc
.
SetType
(
"pool"
);
desc
.
SetType
(
"pool
2d
"
);
desc
.
SetInput
(
"X"
,
{
"x"
});
desc
.
SetOutput
(
"Out"
,
{
"output"
});
...
...
@@ -69,7 +69,7 @@ TEST(pool_op_lite, test) {
bool
use_quantizer
{
false
};
desc
.
SetAttr
(
"use_quantizer"
,
use_quantizer
);
PoolOpLite
pool
(
"pool"
);
PoolOpLite
pool
(
"pool
2d
"
);
pool
.
SetValidPlaces
({
Place
{
TARGET
(
kARM
),
PRECISION
(
kFloat
)}});
pool
.
Attach
(
desc
,
&
scope
);
auto
kernels
=
pool
.
CreateKernels
({
Place
{
TARGET
(
kARM
),
PRECISION
(
kFloat
)}});
...
...
@@ -86,5 +86,5 @@ TEST(pool_op_lite, test) {
}
// namespace paddle
#ifdef LITE_WITH_ARM
USE_LITE_KERNEL
(
pool
,
kARM
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
pool
2d
,
kARM
,
kFloat
,
kNCHW
,
def
);
#endif
paddle/fluid/lite/operators/split_op.cc
浏览文件 @
11a2a2a1
...
...
@@ -37,7 +37,7 @@ bool SplitOp::InferShape() const {
const
auto
&
sections
=
param_
.
sections
;
const
int
outs_number
=
outs
.
size
();
std
::
vector
<
lite
::
DDim
Hvy
>
outs_dims
;
std
::
vector
<
lite
::
DDim
>
outs_dims
;
outs_dims
.
reserve
(
outs_number
);
if
(
num
>
0
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录