Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
ae3ebea5
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
338
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ae3ebea5
编写于
4月 13, 2020
作者:
C
cc
提交者:
GitHub
4月 13, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix gather and concat, add abs op, test=develop (#3395)
上级
4a7284f9
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
171 addition
and
107 deletion
+171
-107
lite/api/benchmark.cc
lite/api/benchmark.cc
+20
-15
lite/api/cxx_api.cc
lite/api/cxx_api.cc
+4
-0
lite/backends/arm/math/activation.cc
lite/backends/arm/math/activation.cc
+9
-0
lite/backends/arm/math/activation.h
lite/backends/arm/math/activation.h
+3
-0
lite/backends/arm/math/concat.cc
lite/backends/arm/math/concat.cc
+0
-43
lite/backends/arm/math/concat.h
lite/backends/arm/math/concat.h
+32
-2
lite/kernels/arm/activation_compute.cc
lite/kernels/arm/activation_compute.cc
+15
-0
lite/kernels/arm/activation_compute.h
lite/kernels/arm/activation_compute.h
+9
-0
lite/kernels/arm/concat_compute.cc
lite/kernels/arm/concat_compute.cc
+41
-32
lite/kernels/arm/concat_compute.h
lite/kernels/arm/concat_compute.h
+1
-1
lite/kernels/arm/concat_compute_test.cc
lite/kernels/arm/concat_compute_test.cc
+3
-4
lite/kernels/arm/gather_compute.cc
lite/kernels/arm/gather_compute.cc
+34
-10
未找到文件。
lite/api/benchmark.cc
浏览文件 @
ae3ebea5
...
@@ -187,6 +187,10 @@ void Run(const std::vector<int64_t>& input_shape,
...
@@ -187,6 +187,10 @@ void Run(const std::vector<int64_t>& input_shape,
}
}
LOG
(
INFO
)
<<
"max_value:"
<<
max_value
;
LOG
(
INFO
)
<<
"max_value:"
<<
max_value
;
LOG
(
INFO
)
<<
"max_index:"
<<
max_index
;
LOG
(
INFO
)
<<
"max_index:"
<<
max_index
;
LOG
(
INFO
)
<<
"output data[0:10]:"
;
for
(
int
i
=
0
;
i
<
10
;
i
++
)
{
LOG
(
INFO
)
<<
out_data
[
i
];
}
}
}
}
}
#endif
#endif
...
@@ -198,32 +202,33 @@ void print_usage() {
...
@@ -198,32 +202,33 @@ void print_usage() {
std
::
string
help_info
=
std
::
string
help_info
=
"Usage:
\n
"
"Usage:
\n
"
"./benchmark_bin
\n
"
"./benchmark_bin
\n
"
" --optimized_model_path (the path of the model that is optimized
\n
"
" --optimized_model_path (The path of the model that is optimized
\n
"
" by opt.) type: string
\n
"
" by opt. If the model is optimized, please set the param.)
\n
"
" --model_dir (the path of the model that is not optimized by opt,
\n
"
" type: string
\n
"
" --model_dir (The path of the model that is not optimized by opt,
\n
"
" the model and param files is under model_dir.) type: string
\n
"
" the model and param files is under model_dir.) type: string
\n
"
" --model_filename (
t
he filename of model file. When the model is
\n
"
" --model_filename (
T
he filename of model file. When the model is
\n
"
" combined formate, please set model_file. Otherwise, it is not
\n
"
" combined formate, please set model_file. Otherwise, it is not
\n
"
" necessary to set it.) type: string
\n
"
" necessary to set it.) type: string
\n
"
" --param_filename (
t
he filename of param file, set param_file when
\n
"
" --param_filename (
T
he filename of param file, set param_file when
\n
"
" the model is combined formate. Otherwise, it is not necessary
\n
"
" the model is combined formate. Otherwise, it is not necessary
\n
"
" to set it.) type: string
\n
"
" to set it.) type: string
\n
"
" --input_shape (
s
et input shapes according to the model, separated by
\n
"
" --input_shape (
T
et input shapes according to the model, separated by
\n
"
" colon and comma, such as 1,3,244,244) type: string
\n
"
" colon and comma, such as 1,3,244,244) type: string
\n
"
" default: 1,3,224,224
\n
"
" default: 1,3,224,224
\n
"
" --input_img_path (
t
he path of input image, if not set
\n
"
" --input_img_path (
T
he path of input image, if not set
\n
"
" input_img_path, the input will be 1.0.) type: string
\n
"
" input_img_path, the input will be 1.0.) type: string
\n
"
" --power_mode (
a
rm power mode: 0 for big cluster, 1 for little
\n
"
" --power_mode (
A
rm power mode: 0 for big cluster, 1 for little
\n
"
" cluster, 2 for all cores, 3 for no bind) type: int32 default: 3
\n
"
" cluster, 2 for all cores, 3 for no bind) type: int32 default: 3
\n
"
" --repeats (
r
epeats times) type: int32 default: 1
\n
"
" --repeats (
R
epeats times) type: int32 default: 1
\n
"
" --result_filename (
s
ave the inference time to the file.) type:
\n
"
" --result_filename (
S
ave the inference time to the file.) type:
\n
"
" string default: result.txt
\n
"
" string default: result.txt
\n
"
" --threads (
t
hreads num) type: int32 default: 1
\n
"
" --threads (
T
hreads num) type: int32 default: 1
\n
"
" --warmup (
w
armup times) type: int32 default: 0
\n
"
" --warmup (
W
armup times) type: int32 default: 0
\n
"
"Note that:
\n
"
"Note that:
\n
"
" If load the optimized model, set optimized_model_path
, or set
\n
"
" If load the optimized model, set optimized_model_path
. Otherwise,
\n
"
"
model_dir, model_filename and param_filename according to the
\n
"
"
set model_dir, model_filename and param_filename according to
\n
"
" model.
\n
"
;
"
the
model.
\n
"
;
LOG
(
INFO
)
<<
help_info
;
LOG
(
INFO
)
<<
help_info
;
}
}
...
...
lite/api/cxx_api.cc
浏览文件 @
ae3ebea5
...
@@ -295,6 +295,10 @@ void Predictor::Build(const cpp::ProgramDesc &desc,
...
@@ -295,6 +295,10 @@ void Predictor::Build(const cpp::ProgramDesc &desc,
inner_places
.
emplace_back
(
TARGET
(
kHost
),
PRECISION
(
kAny
),
DATALAYOUT
(
kAny
));
inner_places
.
emplace_back
(
TARGET
(
kHost
),
PRECISION
(
kAny
),
DATALAYOUT
(
kAny
));
inner_places
.
emplace_back
(
inner_places
.
emplace_back
(
TARGET
(
kHost
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
));
TARGET
(
kHost
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
));
inner_places
.
emplace_back
(
TARGET
(
kHost
),
PRECISION
(
kInt32
),
DATALAYOUT
(
kNCHW
));
inner_places
.
emplace_back
(
TARGET
(
kHost
),
PRECISION
(
kInt64
),
DATALAYOUT
(
kNCHW
));
// Analysis whether the modle is quantized.
// Analysis whether the modle is quantized.
// For quantized model, add place(arm, int8) to inner_places
// For quantized model, add place(arm, int8) to inner_places
...
...
lite/backends/arm/math/activation.cc
浏览文件 @
ae3ebea5
...
@@ -744,6 +744,15 @@ void act_reciprocal<float>(const float* din,
...
@@ -744,6 +744,15 @@ void act_reciprocal<float>(const float* din,
}
}
}
}
template
<
>
void
act_abs
<
float
>
(
const
float
*
din
,
float
*
dout
,
int
size
,
int
threads
)
{
for
(
int
i
=
0
;
i
<
size
;
++
i
)
{
dout
[
0
]
=
(
din
[
0
]
>
0
?
din
[
0
]
:
-
din
[
0
]);
din
++
;
dout
++
;
}
}
#ifdef LITE_WITH_TRAIN
#ifdef LITE_WITH_TRAIN
template
<
>
template
<
>
void
act_square_grad
(
const
float
*
din
,
void
act_square_grad
(
const
float
*
din
,
...
...
lite/backends/arm/math/activation.h
浏览文件 @
ae3ebea5
...
@@ -83,6 +83,9 @@ void act_hard_swish(const T* din,
...
@@ -83,6 +83,9 @@ void act_hard_swish(const T* din,
template
<
typename
T
>
template
<
typename
T
>
void
act_reciprocal
(
const
T
*
din
,
T
*
dout
,
int
size
,
int
threads
);
void
act_reciprocal
(
const
T
*
din
,
T
*
dout
,
int
size
,
int
threads
);
template
<
typename
T
>
void
act_abs
(
const
T
*
din
,
T
*
dout
,
int
size
,
int
threads
);
#ifdef LITE_WITH_TRAIN
#ifdef LITE_WITH_TRAIN
template
<
typename
T
>
template
<
typename
T
>
void
act_square_grad
(
void
act_square_grad
(
...
...
lite/backends/arm/math/concat.cc
浏览文件 @
ae3ebea5
...
@@ -16,46 +16,3 @@
...
@@ -16,46 +16,3 @@
#include <algorithm>
#include <algorithm>
#include <limits>
#include <limits>
#include <memory>
#include <memory>
#include "lite/backends/arm/math/funcs.h"
namespace
paddle
{
namespace
lite
{
namespace
arm
{
namespace
math
{
void
concat_func
(
const
std
::
vector
<
lite
::
Tensor
*>
&
input
,
const
int
axis
,
lite
::
Tensor
*
output
)
{
int64_t
concat_input_size
=
1
;
int64_t
num_cancats
=
1
;
auto
dim_0
=
input
[
0
]
->
dims
();
size_t
num
=
input
.
size
();
for
(
int
i
=
axis
+
1
;
i
<
dim_0
.
size
();
i
++
)
{
concat_input_size
*=
dim_0
[
i
];
}
for
(
int
i
=
0
;
i
<
axis
;
i
++
)
{
num_cancats
*=
dim_0
[
i
];
}
float
*
dst_ptr
=
output
->
mutable_data
<
float
>
();
const
int
out_concat_axis
=
output
->
dims
()[
axis
];
int64_t
offset_concat_axis
=
0
;
int64_t
out_sum
=
out_concat_axis
*
concat_input_size
;
for
(
int
n
=
0
;
n
<
num
;
n
++
)
{
auto
dims
=
input
[
n
]
->
dims
();
const
float
*
src_ptr
=
input
[
n
]
->
data
<
float
>
();
int64_t
in_concat_axis
=
dims
[
axis
];
float
*
dout_ptr
=
dst_ptr
+
offset_concat_axis
*
concat_input_size
;
int64_t
in_sum
=
in_concat_axis
*
concat_input_size
;
for
(
int
i
=
0
;
i
<
num_cancats
;
i
++
)
{
std
::
memcpy
(
dout_ptr
,
src_ptr
,
sizeof
(
float
)
*
in_sum
);
dout_ptr
+=
out_sum
;
src_ptr
+=
in_sum
;
}
offset_concat_axis
+=
in_concat_axis
;
}
}
}
// namespace math
}
// namespace arm
}
// namespace lite
}
// namespace paddle
lite/backends/arm/math/concat.h
浏览文件 @
ae3ebea5
...
@@ -25,9 +25,39 @@ namespace lite {
...
@@ -25,9 +25,39 @@ namespace lite {
namespace
arm
{
namespace
arm
{
namespace
math
{
namespace
math
{
void
concat_func
(
const
std
::
vector
<
lite
::
Tensor
*>
&
input
,
template
<
typename
T
>
void
concat_func
(
const
std
::
vector
<
lite
::
Tensor
*>&
input
,
const
int
axis
,
const
int
axis
,
lite
::
Tensor
*
output
);
lite
::
Tensor
*
output
)
{
size_t
num
=
input
.
size
();
auto
dim_0
=
input
[
0
]
->
dims
();
int64_t
concat_input_size
=
1
;
int64_t
num_cancats
=
1
;
for
(
int
i
=
axis
+
1
;
i
<
dim_0
.
size
();
i
++
)
{
concat_input_size
*=
dim_0
[
i
];
}
for
(
int
i
=
0
;
i
<
axis
;
i
++
)
{
num_cancats
*=
dim_0
[
i
];
}
auto
*
dst_ptr
=
output
->
mutable_data
<
T
>
();
const
int
out_concat_axis
=
output
->
dims
()[
axis
];
int64_t
offset_concat_axis
=
0
;
int64_t
out_sum
=
out_concat_axis
*
concat_input_size
;
for
(
int
n
=
0
;
n
<
num
;
n
++
)
{
auto
dims
=
input
[
n
]
->
dims
();
auto
*
src_ptr
=
input
[
n
]
->
data
<
T
>
();
int64_t
in_concat_axis
=
dims
[
axis
];
auto
*
dout_ptr
=
dst_ptr
+
offset_concat_axis
*
concat_input_size
;
int64_t
in_sum
=
in_concat_axis
*
concat_input_size
;
for
(
int
i
=
0
;
i
<
num_cancats
;
i
++
)
{
std
::
memcpy
(
dout_ptr
,
src_ptr
,
sizeof
(
T
)
*
in_sum
);
dout_ptr
+=
out_sum
;
src_ptr
+=
in_sum
;
}
offset_concat_axis
+=
in_concat_axis
;
}
}
}
// namespace math
}
// namespace math
}
// namespace arm
}
// namespace arm
...
...
lite/kernels/arm/activation_compute.cc
浏览文件 @
ae3ebea5
...
@@ -207,6 +207,16 @@ void ReciprocalCompute::Run() {
...
@@ -207,6 +207,16 @@ void ReciprocalCompute::Run() {
x_data
,
output_data
,
x_dims
.
production
(),
ctx
.
threads
());
x_data
,
output_data
,
x_dims
.
production
(),
ctx
.
threads
());
}
}
void
AbsCompute
::
Run
()
{
auto
&
param
=
this
->
Param
<
param_t
>
();
auto
&
ctx
=
this
->
ctx_
->
template
As
<
ARMContext
>();
auto
x_dims
=
param
.
X
->
dims
();
auto
x_data
=
param
.
X
->
data
<
float
>
();
auto
output_data
=
param
.
Out
->
mutable_data
<
float
>
();
lite
::
arm
::
math
::
act_abs
<
float
>
(
x_data
,
output_data
,
x_dims
.
production
(),
ctx
.
threads
());
}
}
// namespace arm
}
// namespace arm
}
// namespace kernels
}
// namespace kernels
}
// namespace lite
}
// namespace lite
...
@@ -321,3 +331,8 @@ REGISTER_LITE_KERNEL(reciprocal,
...
@@ -321,3 +331,8 @@ REGISTER_LITE_KERNEL(reciprocal,
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
Finalize
();
.
Finalize
();
REGISTER_LITE_KERNEL
(
abs
,
kARM
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
arm
::
AbsCompute
,
def
)
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
Finalize
();
lite/kernels/arm/activation_compute.h
浏览文件 @
ae3ebea5
...
@@ -166,6 +166,15 @@ class ReciprocalCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
...
@@ -166,6 +166,15 @@ class ReciprocalCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
virtual
~
ReciprocalCompute
()
=
default
;
virtual
~
ReciprocalCompute
()
=
default
;
};
};
class
AbsCompute
:
public
KernelLite
<
TARGET
(
kARM
),
PRECISION
(
kFloat
)
>
{
public:
using
param_t
=
operators
::
ActivationParam
;
void
Run
()
override
;
virtual
~
AbsCompute
()
=
default
;
};
}
// namespace arm
}
// namespace arm
}
// namespace kernels
}
// namespace kernels
}
// namespace lite
}
// namespace lite
...
...
lite/kernels/arm/concat_compute.cc
浏览文件 @
ae3ebea5
...
@@ -34,40 +34,21 @@ std::vector<size_t> stride_numel(const DDim& ddim) {
...
@@ -34,40 +34,21 @@ std::vector<size_t> stride_numel(const DDim& ddim) {
return
strides
;
return
strides
;
}
}
void
ConcatCompute
::
Run
()
{
template
<
typename
T
>
auto
&
param
=
Param
<
operators
::
ConcatParam
>
();
void
ConcatFunc
(
const
std
::
vector
<
lite
::
Tensor
*>
inputs
,
std
::
vector
<
lite
::
Tensor
*>
inputs
=
param
.
x
;
int
axis
,
auto
*
out
=
param
.
output
;
lite
::
Tensor
*
out
)
{
int
axis
=
param
.
axis
;
// Sometimes direct copies will be faster, this maybe need deeply analysis.
auto
*
axis_tensor
=
param
.
axis_tensor
;
if
(
axis_tensor
!=
nullptr
)
{
auto
*
axis_tensor_data
=
axis_tensor
->
data
<
int
>
();
axis
=
axis_tensor_data
[
0
];
}
out
->
mutable_data
<
float
>
();
/// Sometimes direct copies will be faster, this maybe need deeply analysis.
if
(
axis
==
0
&&
inputs
.
size
()
<
10
)
{
if
(
axis
==
0
&&
inputs
.
size
()
<
10
)
{
size_t
output_offset
=
0
;
size_t
output_offset
=
0
;
for
(
auto
*
in
:
inputs
)
{
for
(
auto
*
in
:
inputs
)
{
auto
in_stride
=
stride_numel
(
in
->
dims
());
auto
in_stride
=
stride_numel
(
in
->
dims
());
auto
out_stride
=
stride_numel
(
out
->
dims
());
auto
out_stride
=
stride_numel
(
out
->
dims
());
void
*
dst
=
out
->
mutable_data
<
float
>
()
+
output_offset
;
void
*
dst
=
out
->
mutable_data
<
T
>
()
+
output_offset
;
const
void
*
src
=
in
->
data
<
float
>
();
const
void
*
src
=
in
->
data
<
T
>
();
#if 0
LOG(INFO) << "out_stride.size():" << out_stride.size();
LOG(INFO) << "out_stride[0]" << out_stride[0];
for (int i=0; i < out_stride.size(); ++i) {
LOG(INFO) << "out_stride[" << i << "]:" << out_stride[i];
}
LOG(INFO) << "in_stride.size():" << in_stride.size();
for (int i=0; i < in_stride.size(); ++i) {
LOG(INFO) << "in_stride[" << i << "]:" << in_stride[i];
}
#endif
// src and dst tensor should have the same dims size.
// src and dst tensor should have the same dims size.
CHECK
(
in_stride
.
size
()
==
out_stride
.
size
());
CHECK
(
in_stride
.
size
()
==
out_stride
.
size
());
std
::
memcpy
(
dst
,
src
,
sizeof
(
float
)
*
in_stride
[
0
]);
std
::
memcpy
(
dst
,
src
,
sizeof
(
T
)
*
in_stride
[
0
]);
output_offset
+=
in_stride
[
0
];
output_offset
+=
in_stride
[
0
];
}
}
}
else
{
}
else
{
...
@@ -75,9 +56,37 @@ void ConcatCompute::Run() {
...
@@ -75,9 +56,37 @@ void ConcatCompute::Run() {
for
(
int
j
=
0
;
j
<
inputs
.
size
();
++
j
)
{
for
(
int
j
=
0
;
j
<
inputs
.
size
();
++
j
)
{
inputs_concat
[
j
]
=
inputs
[
j
];
inputs_concat
[
j
]
=
inputs
[
j
];
}
}
lite
::
arm
::
math
::
concat_func
(
inputs_concat
,
axis
,
out
);
lite
::
arm
::
math
::
concat_func
<
T
>
(
inputs_concat
,
axis
,
out
);
}
}
void
ConcatCompute
::
Run
()
{
auto
&
param
=
Param
<
operators
::
ConcatParam
>
();
std
::
vector
<
lite
::
Tensor
*>
inputs
=
param
.
x
;
CHECK_GE
(
inputs
.
size
(),
1
);
auto
*
out
=
param
.
output
;
int
axis
=
param
.
axis
;
auto
*
axis_tensor
=
param
.
axis_tensor
;
if
(
axis_tensor
!=
nullptr
)
{
auto
*
axis_tensor_data
=
axis_tensor
->
data
<
int
>
();
axis
=
axis_tensor_data
[
0
];
}
switch
(
inputs
.
front
()
->
precision
())
{
case
PRECISION
(
kFloat
):
ConcatFunc
<
float
>
(
inputs
,
axis
,
out
);
break
;
case
PRECISION
(
kInt32
):
ConcatFunc
<
int32_t
>
(
inputs
,
axis
,
out
);
break
;
case
PRECISION
(
kInt64
):
ConcatFunc
<
int64_t
>
(
inputs
,
axis
,
out
);
break
;
default:
LOG
(
FATAL
)
<<
"Concat does not implement for the "
<<
"input type:"
<<
static_cast
<
int
>
(
inputs
.
front
()
->
precision
());
}
}
return
;
}
}
}
// namespace arm
}
// namespace arm
...
@@ -86,9 +95,9 @@ void ConcatCompute::Run() {
...
@@ -86,9 +95,9 @@ void ConcatCompute::Run() {
}
// namespace paddle
}
// namespace paddle
REGISTER_LITE_KERNEL
(
REGISTER_LITE_KERNEL
(
concat
,
kARM
,
k
Float
,
kNCHW
,
paddle
::
lite
::
kernels
::
arm
::
ConcatCompute
,
def
)
concat
,
kARM
,
k
Any
,
kNCHW
,
paddle
::
lite
::
kernels
::
arm
::
ConcatCompute
,
def
)
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
)
,
PRECISION
(
kAny
)
)})
.
BindInput
(
"AxisTensor"
,
.
BindInput
(
"AxisTensor"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
),
PRECISION
(
kInt32
))})
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
),
PRECISION
(
kInt32
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
)
,
PRECISION
(
kAny
)
)})
.
Finalize
();
.
Finalize
();
lite/kernels/arm/concat_compute.h
浏览文件 @
ae3ebea5
...
@@ -22,7 +22,7 @@ namespace lite {
...
@@ -22,7 +22,7 @@ namespace lite {
namespace
kernels
{
namespace
kernels
{
namespace
arm
{
namespace
arm
{
class
ConcatCompute
:
public
KernelLite
<
TARGET
(
kARM
),
PRECISION
(
k
Float
)
>
{
class
ConcatCompute
:
public
KernelLite
<
TARGET
(
kARM
),
PRECISION
(
k
Any
)
>
{
public:
public:
using
param_t
=
operators
::
ConcatParam
;
using
param_t
=
operators
::
ConcatParam
;
...
...
lite/kernels/arm/concat_compute_test.cc
浏览文件 @
ae3ebea5
...
@@ -95,7 +95,7 @@ void concat_compute_ref(const operators::ConcatParam& param) {
...
@@ -95,7 +95,7 @@ void concat_compute_ref(const operators::ConcatParam& param) {
TEST
(
concat_arm
,
init
)
{
TEST
(
concat_arm
,
init
)
{
ConcatCompute
concat
;
ConcatCompute
concat
;
ASSERT_EQ
(
concat
.
precision
(),
PRECISION
(
k
Float
));
ASSERT_EQ
(
concat
.
precision
(),
PRECISION
(
k
Any
));
ASSERT_EQ
(
concat
.
target
(),
TARGET
(
kARM
));
ASSERT_EQ
(
concat
.
target
(),
TARGET
(
kARM
));
}
}
...
@@ -222,8 +222,7 @@ TEST(concat_arm, compute_input_multi) {
...
@@ -222,8 +222,7 @@ TEST(concat_arm, compute_input_multi) {
TEST
(
concat
,
retrive_op
)
{
TEST
(
concat
,
retrive_op
)
{
auto
concat
=
auto
concat
=
KernelRegistry
::
Global
().
Create
<
TARGET
(
kARM
),
PRECISION
(
kFloat
)
>
(
KernelRegistry
::
Global
().
Create
<
TARGET
(
kARM
),
PRECISION
(
kAny
)
>
(
"concat"
);
"concat"
);
ASSERT_FALSE
(
concat
.
empty
());
ASSERT_FALSE
(
concat
.
empty
());
ASSERT_TRUE
(
concat
.
front
());
ASSERT_TRUE
(
concat
.
front
());
}
}
...
@@ -233,4 +232,4 @@ TEST(concat, retrive_op) {
...
@@ -233,4 +232,4 @@ TEST(concat, retrive_op) {
}
// namespace lite
}
// namespace lite
}
// namespace paddle
}
// namespace paddle
USE_LITE_KERNEL
(
concat
,
kARM
,
k
Float
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
concat
,
kARM
,
k
Any
,
kNCHW
,
def
);
lite/kernels/arm/gather_compute.cc
浏览文件 @
ae3ebea5
...
@@ -20,24 +20,48 @@ namespace lite {
...
@@ -20,24 +20,48 @@ namespace lite {
namespace
kernels
{
namespace
kernels
{
namespace
arm
{
namespace
arm
{
void
GatherCompute
::
Run
()
{
template
<
typename
T
>
auto
&
param
=
this
->
Param
<
operators
::
GatherParam
>
();
void
GatherFunc
(
const
operators
::
GatherParam
&
param
)
{
auto
*
p_output
=
param
.
Out
->
mutable_data
<
float
>
();
auto
index_size
=
param
.
Index
->
dims
()[
0
];
auto
src_dims
=
param
.
X
->
dims
();
auto
src_dims
=
param
.
X
->
dims
();
const
float
*
p_src
=
param
.
X
->
data
<
float
>
();
auto
index_size
=
param
.
Index
->
dims
()[
0
];
auto
*
p_src
=
param
.
X
->
data
<
T
>
();
const
int
*
p_index
=
param
.
Index
->
data
<
int
>
();
const
int
*
p_index
=
param
.
Index
->
data
<
int
>
();
auto
*
p_output
=
param
.
Out
->
mutable_data
<
T
>
();
int
slice_size
=
1
;
int
slice_size
=
1
;
for
(
in
t
i
=
1
;
i
<
src_dims
.
size
();
++
i
)
{
for
(
size_
t
i
=
1
;
i
<
src_dims
.
size
();
++
i
)
{
slice_size
*=
src_dims
[
i
];
slice_size
*=
src_dims
[
i
];
}
}
for
(
int
i
=
0
;
i
<
index_size
;
++
i
)
{
for
(
int
i
=
0
;
i
<
index_size
;
++
i
)
{
int
index_
=
p_index
[
i
];
int
index_
=
p_index
[
i
];
memcpy
(
p_output
+
i
*
slice_size
,
memcpy
(
p_output
+
i
*
slice_size
,
p_src
+
index_
*
slice_size
,
p_src
+
index_
*
slice_size
,
slice_size
*
sizeof
(
float
));
slice_size
*
sizeof
(
T
));
}
}
void
GatherCompute
::
Run
()
{
auto
&
param
=
this
->
Param
<
operators
::
GatherParam
>
();
switch
(
param
.
X
->
precision
())
{
case
PRECISION
(
kFloat
):
GatherFunc
<
float
>
(
param
);
break
;
case
PRECISION
(
kInt8
):
GatherFunc
<
int8_t
>
(
param
);
break
;
case
PRECISION
(
kInt16
):
GatherFunc
<
int16_t
>
(
param
);
break
;
case
PRECISION
(
kInt32
):
GatherFunc
<
int32_t
>
(
param
);
break
;
case
PRECISION
(
kInt64
):
GatherFunc
<
int64_t
>
(
param
);
break
;
default:
LOG
(
FATAL
)
<<
"Gather does not implement for the "
<<
"input type:"
<<
static_cast
<
int
>
(
param
.
X
->
precision
());
}
}
}
}
...
@@ -48,8 +72,8 @@ void GatherCompute::Run() {
...
@@ -48,8 +72,8 @@ void GatherCompute::Run() {
REGISTER_LITE_KERNEL
(
REGISTER_LITE_KERNEL
(
gather
,
kARM
,
kAny
,
kNCHW
,
paddle
::
lite
::
kernels
::
arm
::
GatherCompute
,
def
)
gather
,
kARM
,
kAny
,
kNCHW
,
paddle
::
lite
::
kernels
::
arm
::
GatherCompute
,
def
)
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
)
,
PRECISION
(
kAny
)
)})
.
BindInput
(
"Index"
,
.
BindInput
(
"Index"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
),
PRECISION
(
kInt32
))})
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
),
PRECISION
(
kInt32
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
)
,
PRECISION
(
kAny
)
)})
.
Finalize
();
.
Finalize
();
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录