Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
6d7e40a9
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
331
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6d7e40a9
编写于
3月 23, 2020
作者:
X
xiaogang
提交者:
GitHub
3月 23, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat: add elementwise_grad op (#3246)
上级
5045d394
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
1084 addition
and
3 deletion
+1084
-3
lite/backends/arm/math/elementwise.cc
lite/backends/arm/math/elementwise.cc
+144
-0
lite/backends/arm/math/elementwise.h
lite/backends/arm/math/elementwise.h
+14
-0
lite/kernels/arm/CMakeLists.txt
lite/kernels/arm/CMakeLists.txt
+1
-0
lite/kernels/arm/elementwise_grad_compute.cc
lite/kernels/arm/elementwise_grad_compute.cc
+199
-0
lite/kernels/arm/elementwise_grad_compute.h
lite/kernels/arm/elementwise_grad_compute.h
+68
-0
lite/operators/CMakeLists.txt
lite/operators/CMakeLists.txt
+1
-0
lite/operators/elementwise_grad_ops.cc
lite/operators/elementwise_grad_ops.cc
+67
-0
lite/operators/elementwise_grad_ops.h
lite/operators/elementwise_grad_ops.h
+44
-0
lite/operators/op_params.h
lite/operators/op_params.h
+4
-3
lite/tests/kernels/CMakeLists.txt
lite/tests/kernels/CMakeLists.txt
+1
-0
lite/tests/kernels/elementwise_grad_compute_test.cc
lite/tests/kernels/elementwise_grad_compute_test.cc
+541
-0
未找到文件。
lite/backends/arm/math/elementwise.cc
浏览文件 @
6d7e40a9
...
@@ -266,6 +266,72 @@ void elementwise_add_relu_broadcast<float>(const float* dinx,
...
@@ -266,6 +266,72 @@ void elementwise_add_relu_broadcast<float>(const float* dinx,
}
}
}
}
template
<
>
void
elementwise_add_grad
<
float
>
(
const
float
*
dout_grad
,
float
*
x_grad
,
int
num
)
{
int
cnt
=
num
>>
4
;
int
remain
=
num
&
0x0f
;
#pragma omp parallel for
for
(
int
i
=
0
;
i
<
cnt
;
++
i
)
{
const
float
*
out_data
=
dout_grad
+
16
*
i
;
float
*
x_data
=
x_grad
+
16
*
i
;
float32x4_t
din0
=
vld1q_f32
(
out_data
);
float32x4_t
din1
=
vld1q_f32
(
out_data
+
4
);
float32x4_t
din2
=
vld1q_f32
(
out_data
+
8
);
float32x4_t
din3
=
vld1q_f32
(
out_data
+
12
);
vst1q_f32
(
x_data
,
din0
);
vst1q_f32
(
x_data
+
4
,
din1
);
vst1q_f32
(
x_data
+
8
,
din2
);
vst1q_f32
(
x_data
+
12
,
din3
);
}
if
(
remain
>
0
)
{
const
float
*
out_data
=
dout_grad
+
16
*
cnt
;
float
*
x_data
=
x_grad
+
16
*
cnt
;
for
(
int
i
=
0
;
i
<
remain
;
++
i
)
{
x_data
[
i
]
=
out_data
[
i
];
}
}
}
// we assume that y_data numel less than x_data, otherwise, call this function
// by change x_grad and y_grad position
template
<
>
void
elementwise_add_grad_broadcast
<
float
>
(
const
float
*
dout_grad
,
float
*
x_grad
,
float
*
y_grad
,
int
pre
,
int
n
,
int
post
)
{
if
(
x_grad
)
{
elementwise_add_grad
(
dout_grad
,
x_grad
,
pre
*
n
*
post
);
}
if
(
y_grad
)
{
memset
(
y_grad
,
0
,
n
*
sizeof
(
float
));
#pragma omp parallel for
for
(
int
i
=
0
;
i
<
pre
;
++
i
)
{
for
(
int
j
=
0
;
j
<
n
;
++
j
)
{
float
sum
=
0
;
int
cnt
=
post
>>
2
;
int
remain
=
post
&
0x03
;
const
float
*
out_data
=
dout_grad
+
(
i
*
n
+
j
)
*
post
;
float32x4_t
sum_v
=
vdupq_n_f32
(
0
);
for
(
int
ci
=
0
;
ci
<
cnt
;
++
ci
)
{
float32x4_t
din
=
vld1q_f32
(
out_data
+
4
*
ci
);
sum_v
=
vaddq_f32
(
sum_v
,
din
);
}
out_data
+=
4
*
cnt
;
for
(
int
ci
=
0
;
ci
<
remain
;
++
ci
)
{
sum
+=
out_data
[
ci
];
}
float32x2_t
high
=
vget_high_f32
(
sum_v
);
float32x2_t
low
=
vget_low_f32
(
sum_v
);
sum
+=
vget_lane_f32
(
high
,
0
)
+
vget_lane_f32
(
high
,
1
)
+
vget_lane_f32
(
low
,
0
)
+
vget_lane_f32
(
low
,
1
);
y_grad
[
j
]
+=
sum
;
}
}
}
}
template
<
>
template
<
>
void
elementwise_sub
<
float
>
(
const
float
*
dinx
,
void
elementwise_sub
<
float
>
(
const
float
*
dinx
,
const
float
*
diny
,
const
float
*
diny
,
...
@@ -510,6 +576,84 @@ void elementwise_sub_relu_broadcast<float>(const float* dinx,
...
@@ -510,6 +576,84 @@ void elementwise_sub_relu_broadcast<float>(const float* dinx,
}
}
}
}
}
}
// we assume the formula is x-y
template
<
>
void
elementwise_sub_grad
<
float
>
(
const
float
*
dout_grad
,
float
*
x_grad
,
float
*
y_grad
,
int
num
)
{
if
(
x_grad
)
{
elementwise_add_grad
(
dout_grad
,
x_grad
,
num
);
}
if
(
y_grad
)
{
int
cnt
=
num
>>
4
;
int
remain
=
num
&
0x0f
;
float32x4_t
minus
=
vdupq_n_f32
(
-
1
);
#pragma omp parallel for
for
(
int
i
=
0
;
i
<
cnt
;
++
i
)
{
const
float
*
out_data
=
dout_grad
+
16
*
i
;
float
*
y_data
=
y_grad
+
16
*
i
;
float32x4_t
din0
=
vld1q_f32
(
out_data
);
float32x4_t
din1
=
vld1q_f32
(
out_data
+
4
);
float32x4_t
din2
=
vld1q_f32
(
out_data
+
8
);
float32x4_t
din3
=
vld1q_f32
(
out_data
+
12
);
din0
=
vmulq_f32
(
din0
,
minus
);
din1
=
vmulq_f32
(
din1
,
minus
);
din2
=
vmulq_f32
(
din2
,
minus
);
din3
=
vmulq_f32
(
din3
,
minus
);
vst1q_f32
(
y_data
,
din0
);
vst1q_f32
(
y_data
+
4
,
din1
);
vst1q_f32
(
y_data
+
8
,
din2
);
vst1q_f32
(
y_data
+
12
,
din3
);
}
if
(
remain
>
0
)
{
const
float
*
out_data
=
dout_grad
+
16
*
cnt
;
float
*
y_data
=
y_grad
+
16
*
cnt
;
for
(
int
i
=
0
;
i
<
remain
;
++
i
)
{
y_data
[
i
]
=
-
out_data
[
i
];
}
}
}
}
// we assume that y_data numel less than x_data, otherwise, call this function
// by change x_grad and y_grad position
template
<
>
void
elementwise_sub_grad_broadcast
<
float
>
(
const
float
*
dout_grad
,
float
*
x_grad
,
float
*
y_grad
,
int
pre
,
int
n
,
int
post
)
{
if
(
x_grad
)
{
elementwise_add_grad
(
dout_grad
,
x_grad
,
pre
*
n
*
post
);
}
if
(
y_grad
)
{
memset
(
y_grad
,
0
,
n
*
sizeof
(
float
));
#pragma omp parallel for
for
(
int
i
=
0
;
i
<
pre
;
++
i
)
{
for
(
int
j
=
0
;
j
<
n
;
++
j
)
{
float
sum
=
0
;
int
cnt
=
post
<<
2
;
int
remain
=
post
&
0x03
;
const
float
*
out_data
=
dout_grad
+
(
i
*
n
+
j
)
*
post
;
float32x4_t
sum_v
=
vdupq_n_f32
(
0
);
for
(
int
ci
=
0
;
ci
<
cnt
;
++
ci
)
{
float32x4_t
din
=
vld1q_f32
(
out_data
+
4
*
ci
);
sum_v
=
vaddq_f32
(
sum_v
,
din
);
}
out_data
+=
4
*
cnt
;
for
(
int
ci
=
0
;
ci
<
remain
;
++
ci
)
{
sum
-=
out_data
[
ci
];
}
float32x2_t
high
=
vget_high_f32
(
sum_v
);
float32x2_t
low
=
vget_low_f32
(
sum_v
);
sum
-=
vget_lane_f32
(
high
,
0
)
+
vget_lane_f32
(
high
,
1
)
+
vget_lane_f32
(
low
,
0
)
+
vget_lane_f32
(
low
,
1
);
y_grad
[
j
]
+=
sum
;
}
}
}
}
template
<
>
template
<
>
void
elementwise_mul
<
float
>
(
const
float
*
dinx
,
void
elementwise_mul
<
float
>
(
const
float
*
dinx
,
...
...
lite/backends/arm/math/elementwise.h
浏览文件 @
6d7e40a9
...
@@ -183,6 +183,13 @@ template <typename T>
...
@@ -183,6 +183,13 @@ template <typename T>
void
elementwise_add_relu_broadcast
(
void
elementwise_add_relu_broadcast
(
const
T
*
dinx
,
const
T
*
diny
,
T
*
dout
,
int
batch
,
int
channels
,
int
num
);
const
T
*
dinx
,
const
T
*
diny
,
T
*
dout
,
int
batch
,
int
channels
,
int
num
);
template
<
typename
T
>
void
elementwise_add_grad
(
const
T
*
dout
,
T
*
dinx
,
int
num
);
template
<
typename
T
>
void
elementwise_add_grad_broadcast
(
const
T
*
dout_grad
,
T
*
x_grad
,
T
*
y_grad
,
int
pre
,
int
n
,
int
post
);
template
<
typename
T
>
template
<
typename
T
>
void
elementwise_sub
(
const
T
*
dinx
,
const
T
*
diny
,
T
*
dout
,
int
num
);
void
elementwise_sub
(
const
T
*
dinx
,
const
T
*
diny
,
T
*
dout
,
int
num
);
...
@@ -197,6 +204,13 @@ template <typename T>
...
@@ -197,6 +204,13 @@ template <typename T>
void
elementwise_sub_relu_broadcast
(
void
elementwise_sub_relu_broadcast
(
const
T
*
dinx
,
const
T
*
diny
,
T
*
dout
,
int
batch
,
int
channels
,
int
num
);
const
T
*
dinx
,
const
T
*
diny
,
T
*
dout
,
int
batch
,
int
channels
,
int
num
);
template
<
typename
T
>
void
elementwise_sub_grad
(
const
T
*
dout
,
T
*
dinx
,
T
*
diny
,
int
num
);
template
<
typename
T
>
void
elementwise_sub_grad_broadcast
(
const
T
*
dout_grad
,
T
*
x_grad
,
T
*
y_grad
,
int
pre
,
int
n
,
int
post
);
template
<
typename
T
>
template
<
typename
T
>
void
elementwise_mul
(
const
T
*
dinx
,
const
T
*
diny
,
T
*
dout
,
int
num
);
void
elementwise_mul
(
const
T
*
dinx
,
const
T
*
diny
,
T
*
dout
,
int
num
);
...
...
lite/kernels/arm/CMakeLists.txt
浏览文件 @
6d7e40a9
...
@@ -109,6 +109,7 @@ add_kernel(mean_compute_arm ARM extra SRCS mean_compute.cc DEPS ${lite_kernel_de
...
@@ -109,6 +109,7 @@ add_kernel(mean_compute_arm ARM extra SRCS mean_compute.cc DEPS ${lite_kernel_de
if
(
LITE_WITH_TRAIN
)
if
(
LITE_WITH_TRAIN
)
add_kernel
(
mean_grad_compute_arm ARM extra SRCS mean_grad_compute.cc DEPS
${
lite_kernel_deps
}
math_arm
)
add_kernel
(
mean_grad_compute_arm ARM extra SRCS mean_grad_compute.cc DEPS
${
lite_kernel_deps
}
math_arm
)
add_kernel
(
activation_grad_compute_arm ARM basic SRCS activation_grad_compute.cc DEPS
${
lite_kernel_deps
}
math_arm
)
add_kernel
(
activation_grad_compute_arm ARM basic SRCS activation_grad_compute.cc DEPS
${
lite_kernel_deps
}
math_arm
)
add_kernel
(
elementwise_grad_compute_arm ARM basic SRCS elementwise_grad_compute.cc DEPS
${
lite_kernel_deps
}
math_arm
)
add_kernel
(
mul_grad_compute_arm ARM extra SRCS mul_grad_compute.cc DEPS
${
lite_kernel_deps
}
math_arm
)
add_kernel
(
mul_grad_compute_arm ARM extra SRCS mul_grad_compute.cc DEPS
${
lite_kernel_deps
}
math_arm
)
add_kernel
(
sgd_compute_arm ARM extra SRCS sgd_compute.cc DEPS
${
lite_kernel_deps
}
math_arm
)
add_kernel
(
sgd_compute_arm ARM extra SRCS sgd_compute.cc DEPS
${
lite_kernel_deps
}
math_arm
)
endif
()
endif
()
...
...
lite/kernels/arm/elementwise_grad_compute.cc
0 → 100644
浏览文件 @
6d7e40a9
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/elementwise_grad_compute.h"
#include <string>
#include <vector>
#include "lite/backends/arm/math/funcs.h"
namespace
paddle
{
namespace
lite
{
namespace
kernels
{
namespace
arm
{
inline
DDim
trim_trailing_singular_dims
(
const
DDim
&
dims
)
{
// Remove trailing dimensions of size 1 for y
auto
actual_dims_size
=
dims
.
size
();
for
(;
actual_dims_size
!=
0
;
--
actual_dims_size
)
{
if
(
dims
[
actual_dims_size
-
1
]
!=
1
)
break
;
}
std
::
vector
<
int64_t
>
trim_dims
;
trim_dims
.
resize
(
actual_dims_size
);
for
(
int
i
=
0
;
i
<
actual_dims_size
;
++
i
)
{
trim_dims
[
i
]
=
dims
[
i
];
}
if
(
trim_dims
.
size
()
==
0
)
{
return
DDim
();
}
return
DDim
(
trim_dims
);
}
inline
bool
is_broadcast
(
const
DDim
&
x_dims
,
const
DDim
&
y_dims
,
int
axis
,
int
*
pre
,
int
*
n
,
int
*
post
)
{
if
(
axis
<
0
)
{
axis
=
x_dims
.
size
()
-
y_dims
.
size
();
}
DDim
y_dim_trim
=
trim_trailing_singular_dims
(
y_dims
);
axis
=
(
y_dim_trim
.
size
()
==
0
)
?
x_dims
.
size
()
:
axis
;
if
(
x_dims
.
size
()
==
y_dim_trim
.
size
())
{
return
false
;
}
*
pre
=
1
;
*
n
=
1
;
*
post
=
1
;
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
(
*
pre
)
*=
x_dims
[
i
];
}
for
(
int
i
=
0
;
i
<
y_dim_trim
.
size
();
++
i
)
{
CHECK_EQ
(
x_dims
[
i
+
axis
],
y_dim_trim
[
i
])
<<
"Broadcast dimension mismatch."
;
(
*
n
)
*=
y_dim_trim
[
i
];
}
for
(
int
i
=
axis
+
y_dim_trim
.
size
();
i
<
x_dims
.
size
();
++
i
)
{
(
*
post
)
*=
x_dims
[
i
];
}
return
true
;
}
void
ElementwiseAddGradCompute
::
Run
()
{
auto
&
param
=
Param
<
operators
::
ElementwiseGradParam
>
();
const
float
*
x_data
=
param
.
X
->
data
<
float
>
();
const
float
*
y_data
=
param
.
Y
->
data
<
float
>
();
const
float
*
out_grad_data
=
param
.
OutGrad
->
data
<
float
>
();
float
*
x_grad_data
=
param
.
XGrad
->
mutable_data
<
float
>
();
float
*
y_grad_data
=
param
.
YGrad
->
mutable_data
<
float
>
();
int
axis
=
param
.
axis
;
auto
x_dims
=
param
.
X
->
dims
();
auto
y_dims
=
param
.
Y
->
dims
();
int
pre
,
n
,
post
;
if
(
x_dims
.
size
()
<
y_dims
.
size
()
&&
is_broadcast
(
y_dims
,
x_dims
,
axis
,
&
pre
,
&
n
,
&
post
))
{
lite
::
arm
::
math
::
elementwise_add_grad_broadcast
(
out_grad_data
,
y_grad_data
,
x_grad_data
,
pre
,
n
,
post
);
}
else
if
(
is_broadcast
(
x_dims
,
y_dims
,
axis
,
&
pre
,
&
n
,
&
post
))
{
lite
::
arm
::
math
::
elementwise_add_grad_broadcast
(
out_grad_data
,
x_grad_data
,
y_grad_data
,
pre
,
n
,
post
);
}
else
{
lite
::
arm
::
math
::
elementwise_add_grad
(
out_grad_data
,
x_grad_data
,
x_dims
.
production
());
lite
::
arm
::
math
::
elementwise_add_grad
(
out_grad_data
,
y_grad_data
,
y_dims
.
production
());
}
}
void
ElementwiseSubGradCompute
::
Run
()
{
auto
&
param
=
Param
<
operators
::
ElementwiseGradParam
>
();
const
float
*
x_data
=
param
.
X
->
data
<
float
>
();
const
float
*
y_data
=
param
.
Y
->
data
<
float
>
();
const
float
*
out_data
=
param
.
OutGrad
->
data
<
float
>
();
float
*
x_grad_data
=
param
.
XGrad
->
mutable_data
<
float
>
();
float
*
y_grad_data
=
param
.
YGrad
->
mutable_data
<
float
>
();
int
axis
=
param
.
axis
;
auto
x_dims
=
param
.
X
->
dims
();
auto
y_dims
=
param
.
Y
->
dims
();
int
pre
,
n
,
post
;
if
(
x_dims
.
size
()
<
y_dims
.
size
())
{
LOG
(
FATAL
)
<<
"elewise div grad don't support x_dims size < y_dims size"
;
}
if
(
is_broadcast
(
x_dims
,
y_dims
,
axis
,
&
pre
,
&
n
,
&
post
))
{
lite
::
arm
::
math
::
elementwise_sub_grad_broadcast
(
out_data
,
x_grad_data
,
y_grad_data
,
pre
,
n
,
post
);
}
else
{
lite
::
arm
::
math
::
elementwise_sub_grad
(
out_data
,
x_grad_data
,
y_grad_data
,
x_dims
.
production
());
}
}
template
<
typename
T
,
PrecisionType
PType
>
void
ElementwiseMulGradCompute
<
T
,
PType
>::
Run
()
{
LOG
(
FATAL
)
<<
"elementwise mul_grad not implement yet"
;
}
void
ElementwiseMaxGradCompute
::
Run
()
{
LOG
(
FATAL
)
<<
"elementwise max_grad not implement yet"
;
}
void
ElementwiseDivGradCompute
::
Run
()
{
LOG
(
FATAL
)
<<
"elementwise div_grad not implement yet"
;
}
}
// namespace arm
}
// namespace kernels
}
// namespace lite
}
// namespace paddle
using
elementwise_mul_grad_float
=
paddle
::
lite
::
kernels
::
arm
::
ElementwiseMulGradCompute
<
float
,
PRECISION
(
kFloat
)
>
;
REGISTER_LITE_KERNEL
(
elementwise_add_grad
,
kARM
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
arm
::
ElementwiseAddGradCompute
,
def
)
.
BindInput
(
"Y"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindInput
(
"Out@Grad"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindOutput
(
"X@Grad"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindOutput
(
"Y@Grad"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
Finalize
();
REGISTER_LITE_KERNEL
(
elementwise_sub_grad
,
kARM
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
arm
::
ElementwiseSubGradCompute
,
def
)
.
BindInput
(
"Y"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindInput
(
"Out@Grad"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindOutput
(
"X@Grad"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindOutput
(
"Y@Grad"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
Finalize
();
REGISTER_LITE_KERNEL
(
elementwise_div_grad
,
kARM
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
arm
::
ElementwiseDivGradCompute
,
def
)
.
BindInput
(
"Y"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindInput
(
"Out@Grad"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindOutput
(
"X@Grad"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindOutput
(
"Y@Grad"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
Finalize
();
REGISTER_LITE_KERNEL
(
elementwise_mul_grad
,
kARM
,
kFloat
,
kNCHW
,
elementwise_mul_grad_float
,
def
)
.
BindInput
(
"Y"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindInput
(
"Out@Grad"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindOutput
(
"X@Grad"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindOutput
(
"Y@Grad"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
Finalize
();
REGISTER_LITE_KERNEL
(
elementwise_max_grad
,
kARM
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
arm
::
ElementwiseMaxGradCompute
,
def
)
.
BindInput
(
"Y"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindInput
(
"Out@Grad"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindOutput
(
"X@Grad"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindOutput
(
"Y@Grad"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
Finalize
();
lite/kernels/arm/elementwise_grad_compute.h
0 → 100644
浏览文件 @
6d7e40a9
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
kernels
{
namespace
arm
{
class
ElementwiseAddGradCompute
:
public
KernelLite
<
TARGET
(
kARM
),
PRECISION
(
kFloat
)
>
{
public:
void
Run
()
override
;
virtual
~
ElementwiseAddGradCompute
()
=
default
;
};
class
ElementwiseSubGradCompute
:
public
KernelLite
<
TARGET
(
kARM
),
PRECISION
(
kFloat
)
>
{
public:
void
Run
()
override
;
virtual
~
ElementwiseSubGradCompute
()
=
default
;
};
template
<
typename
T
,
PrecisionType
PType
>
class
ElementwiseMulGradCompute
:
public
KernelLite
<
TARGET
(
kARM
),
PType
>
{
public:
void
Run
()
override
;
virtual
~
ElementwiseMulGradCompute
()
=
default
;
};
class
ElementwiseMaxGradCompute
:
public
KernelLite
<
TARGET
(
kARM
),
PRECISION
(
kFloat
)
>
{
public:
void
Run
()
override
;
virtual
~
ElementwiseMaxGradCompute
()
=
default
;
};
class
ElementwiseDivGradCompute
:
public
KernelLite
<
TARGET
(
kARM
),
PRECISION
(
kFloat
)
>
{
public:
void
Run
()
override
;
virtual
~
ElementwiseDivGradCompute
()
=
default
;
};
}
// namespace arm
}
// namespace kernels
}
// namespace lite
}
// namespace paddle
lite/operators/CMakeLists.txt
浏览文件 @
6d7e40a9
...
@@ -144,6 +144,7 @@ add_operator(mean_op extra SRCS mean_op.cc DEPS ${op_DEPS})
...
@@ -144,6 +144,7 @@ add_operator(mean_op extra SRCS mean_op.cc DEPS ${op_DEPS})
if
(
LITE_WITH_TRAIN
)
if
(
LITE_WITH_TRAIN
)
add_operator
(
mean_grad_op extra SRCS mean_grad_op.cc DEPS
${
op_DEPS
}
)
add_operator
(
mean_grad_op extra SRCS mean_grad_op.cc DEPS
${
op_DEPS
}
)
add_operator
(
activation_grad_ops basic SRCS activation_grad_ops.cc DEPS
${
op_DEPS
}
)
add_operator
(
activation_grad_ops basic SRCS activation_grad_ops.cc DEPS
${
op_DEPS
}
)
add_operator
(
elementwise_grad_op extra SRCS elementwise_grad_ops.cc DEPS
${
op_DEPS
}
)
add_operator
(
mul_grad_op basic SRCS mul_grad_op.cc DEPS
${
op_DEPS
}
)
add_operator
(
mul_grad_op basic SRCS mul_grad_op.cc DEPS
${
op_DEPS
}
)
add_operator
(
sgd_op extra SRCS sgd_op.cc DEPS
${
op_DEPS
}
)
add_operator
(
sgd_op extra SRCS sgd_op.cc DEPS
${
op_DEPS
}
)
endif
()
endif
()
...
...
lite/operators/elementwise_grad_ops.cc
0 → 100644
浏览文件 @
6d7e40a9
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/elementwise_grad_ops.h"
#include <algorithm>
#include <cmath>
#include "lite/core/op_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
operators
{
bool
ElementwiseGradOp
::
CheckShape
()
const
{
CHECK_OR_FALSE
(
param_
.
XGrad
);
CHECK_OR_FALSE
(
param_
.
YGrad
);
CHECK_OR_FALSE
(
param_
.
OutGrad
);
return
true
;
}
bool
ElementwiseGradOp
::
InferShape
()
const
{
auto
x_dim
=
param_
.
X
->
dims
();
auto
y_dim
=
param_
.
Y
->
dims
();
param_
.
XGrad
->
Resize
(
x_dim
);
param_
.
YGrad
->
Resize
(
y_dim
);
return
true
;
}
bool
ElementwiseGradOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
auto
Y_name
=
opdesc
.
Input
(
"Y"
).
front
();
auto
X_name
=
opdesc
.
Input
(
"X"
).
front
();
auto
Out_name
=
opdesc
.
Input
(
"Out@Grad"
).
front
();
auto
x_grad_name
=
opdesc
.
Output
(
"X@Grad"
).
front
();
auto
y_grad_name
=
opdesc
.
Output
(
"Y@Grad"
).
front
();
param_
.
X
=
GetVar
<
lite
::
Tensor
>
(
scope
,
X_name
);
param_
.
Y
=
GetVar
<
lite
::
Tensor
>
(
scope
,
Y_name
);
param_
.
XGrad
=
GetMutableVar
<
lite
::
Tensor
>
(
scope
,
x_grad_name
);
param_
.
YGrad
=
GetMutableVar
<
lite
::
Tensor
>
(
scope
,
y_grad_name
);
param_
.
OutGrad
=
GetVar
<
lite
::
Tensor
>
(
scope
,
Out_name
);
param_
.
axis
=
opdesc
.
GetAttr
<
int
>
(
"axis"
);
return
true
;
}
}
// namespace operators
}
// namespace lite
}
// namespace paddle
REGISTER_LITE_OP
(
elementwise_grad_sub
,
paddle
::
lite
::
operators
::
ElementwiseGradOp
);
REGISTER_LITE_OP
(
elementwise_grad_add
,
paddle
::
lite
::
operators
::
ElementwiseGradOp
);
REGISTER_LITE_OP
(
elementwise_grad_mul
,
paddle
::
lite
::
operators
::
ElementwiseGradOp
);
REGISTER_LITE_OP
(
elementwise_grad_max
,
paddle
::
lite
::
operators
::
ElementwiseGradOp
);
lite/operators/elementwise_grad_ops.h
0 → 100644
浏览文件 @
6d7e40a9
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
namespace
paddle
{
namespace
lite
{
namespace
operators
{
class
ElementwiseGradOp
:
public
OpLite
{
public:
explicit
ElementwiseGradOp
(
const
std
::
string
&
op_type
)
:
OpLite
(
op_type
)
{}
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"elementwise_grad_op"
;
}
private:
mutable
operators
::
ElementwiseGradParam
param_
;
};
}
// namespace operators
}
// namespace lite
}
// namespace paddle
lite/operators/op_params.h
浏览文件 @
6d7e40a9
...
@@ -387,10 +387,11 @@ struct ElementwiseParam {
...
@@ -387,10 +387,11 @@ struct ElementwiseParam {
};
};
struct
ElementwiseGradParam
{
struct
ElementwiseGradParam
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
Y
{};
const
lite
::
Tensor
*
Y
{};
const
lite
::
Tensor
*
Out
_g
rad
{};
const
lite
::
Tensor
*
Out
G
rad
{};
lite
::
Tensor
*
X
_g
rad
{};
lite
::
Tensor
*
X
G
rad
{};
lite
::
Tensor
*
Y
_g
rad
{};
lite
::
Tensor
*
Y
G
rad
{};
int
axis
{
-
1
};
// for broadcasting.
int
axis
{
-
1
};
// for broadcasting.
};
};
...
...
lite/tests/kernels/CMakeLists.txt
浏览文件 @
6d7e40a9
...
@@ -65,6 +65,7 @@ if(LITE_BUILD_EXTRA)
...
@@ -65,6 +65,7 @@ if(LITE_BUILD_EXTRA)
if
(
LITE_WITH_TRAIN
)
if
(
LITE_WITH_TRAIN
)
lite_cc_test
(
test_kernel_mean_compute SRCS mean_compute_test.cc DEPS arena_framework
${
xpu_kernels
}
${
npu_kernels
}
${
x86_kernels
}
${
cuda_kernels
}
${
arm_kernels
}
${
lite_ops
}
${
host_kernels
}
)
lite_cc_test
(
test_kernel_mean_compute SRCS mean_compute_test.cc DEPS arena_framework
${
xpu_kernels
}
${
npu_kernels
}
${
x86_kernels
}
${
cuda_kernels
}
${
arm_kernels
}
${
lite_ops
}
${
host_kernels
}
)
lite_cc_test
(
test_kernel_activation_grad_compute SRCS activation_grad_compute_test.cc DEPS arena_framework
${
xpu_kernels
}
${
npu_kernels
}
${
x86_kernels
}
${
cuda_kernels
}
${
arm_kernels
}
${
lite_ops
}
${
host_kernels
}
)
lite_cc_test
(
test_kernel_activation_grad_compute SRCS activation_grad_compute_test.cc DEPS arena_framework
${
xpu_kernels
}
${
npu_kernels
}
${
x86_kernels
}
${
cuda_kernels
}
${
arm_kernels
}
${
lite_ops
}
${
host_kernels
}
)
lite_cc_test
(
test_kernel_elementwise_grad_compute SRCS elementwise_grad_compute_test.cc DEPS arena_framework
${
xpu_kernels
}
${
npu_kernels
}
${
x86_kernels
}
${
cuda_kernels
}
${
arm_kernels
}
${
lite_ops
}
${
host_kernels
}
)
lite_cc_test
(
test_kernel_mul_grad_compute SRCS mul_grad_compute_test.cc DEPS arena_framework
${
xpu_kernels
}
${
npu_kernels
}
${
x86_kernels
}
${
cuda_kernels
}
${
arm_kernels
}
${
lite_ops
}
${
host_kernels
}
)
lite_cc_test
(
test_kernel_mul_grad_compute SRCS mul_grad_compute_test.cc DEPS arena_framework
${
xpu_kernels
}
${
npu_kernels
}
${
x86_kernels
}
${
cuda_kernels
}
${
arm_kernels
}
${
lite_ops
}
${
host_kernels
}
)
lite_cc_test
(
test_kernel_sgd_compute SRCS sgd_compute_test.cc DEPS arena_framework
${
xpu_kernels
}
${
npu_kernels
}
${
x86_kernels
}
${
cuda_kernels
}
${
arm_kernels
}
${
lite_ops
}
${
host_kernels
}
)
lite_cc_test
(
test_kernel_sgd_compute SRCS sgd_compute_test.cc DEPS arena_framework
${
xpu_kernels
}
${
npu_kernels
}
${
x86_kernels
}
${
cuda_kernels
}
${
arm_kernels
}
${
lite_ops
}
${
host_kernels
}
)
endif
()
endif
()
...
...
lite/tests/kernels/elementwise_grad_compute_test.cc
0 → 100644
浏览文件 @
6d7e40a9
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/elementwise_grad_compute.h"
#include <gtest/gtest.h>
#include "lite/core/op_registry.h"
#include "lite/kernels/arm/elementwise_compute.h"
#include "lite/tests/utils/fill_data.h"
namespace
paddle
{
namespace
lite
{
namespace
kernels
{
namespace
arm
{
using
param_t
=
operators
::
ElementwiseParam
;
using
grad_param_t
=
operators
::
ElementwiseGradParam
;
using
kernel_add_t
=
ElementwiseAddCompute
;
using
grad_kernel_add_t
=
ElementwiseAddGradCompute
;
using
kernel_sub_t
=
ElementwiseSubCompute
;
using
grad_kernel_sub_t
=
ElementwiseSubGradCompute
;
void
elementwise_common
(
grad_param_t
&
param
,
// NOLINT
std
::
vector
<
float
>&
out_grad
,
// NOLINT
std
::
vector
<
float
>&
x_grad
,
// NOLINT
std
::
vector
<
float
>&
y_grad
,
// NOLINT
std
::
string
flag
)
{
auto
x_dims
=
param
.
X
->
dims
();
auto
y_dims
=
param
.
Y
->
dims
();
if
(
x_dims
==
y_dims
)
{
for
(
int
i
=
0
;
i
<
x_dims
.
production
();
++
i
)
{
if
(
flag
==
"add"
)
{
x_grad
[
i
]
=
out_grad
[
i
];
y_grad
[
i
]
=
out_grad
[
i
];
}
if
(
flag
==
"sub"
)
{
x_grad
[
i
]
=
out_grad
[
i
];
y_grad
[
i
]
=
-
out_grad
[
i
];
}
}
}
else
{
LOG
(
FATAL
)
<<
"unsupport dims"
;
}
}
class
ElementwiseAddGradTester
{
public:
explicit
ElementwiseAddGradTester
(
const
DDim
&
x_dims
,
const
DDim
&
y_dims
,
int
axis
)
:
x_dims_
(
x_dims
),
y_dims_
(
y_dims
),
axis_
(
axis
)
{}
void
prepare_kernel
()
{
std
::
unique_ptr
<
KernelContext
>
ctx1
(
new
KernelContext
);
ctx1
->
As
<
ARMContext
>
();
kernel_
.
SetContext
(
std
::
move
(
ctx1
));
std
::
unique_ptr
<
KernelContext
>
ctx3
(
new
KernelContext
);
ctx3
->
As
<
ARMContext
>
();
grad_kernel_
.
SetContext
(
std
::
move
(
ctx3
));
}
void
run_forward
(
param_t
*
param
,
kernel_add_t
*
kernel
,
const
std
::
vector
<
float
>&
x_vec
,
const
std
::
vector
<
float
>&
y_vec
,
float
*
out_vec
)
{
Tensor
x
;
Tensor
y
;
Tensor
output
;
x
.
Resize
(
x_dims_
);
y
.
Resize
(
y_dims_
);
output
.
Resize
(
DDim
(
out_dims_
));
auto
*
x_data
=
x
.
mutable_data
<
float
>
();
auto
*
y_data
=
y
.
mutable_data
<
float
>
();
for
(
int
i
=
0
;
i
<
x_dims_
.
production
();
i
++
)
{
x_data
[
i
]
=
x_vec
[
i
];
}
for
(
int
i
=
0
;
i
<
y_dims_
.
production
();
i
++
)
{
y_data
[
i
]
=
y_vec
[
i
];
}
param
->
X
=
&
x
;
param
->
Y
=
&
y
;
param
->
Out
=
&
output
;
param
->
axis
=
axis_
;
kernel
->
SetParam
(
*
param
);
kernel
->
Launch
();
auto
*
output_data
=
output
.
mutable_data
<
float
>
();
for
(
int
i
=
0
;
i
<
out_dims_
.
production
();
i
++
)
{
out_vec
[
i
]
=
output_data
[
i
];
}
}
void
run_backward
(
grad_param_t
*
param
,
grad_kernel_add_t
*
kernel
,
const
std
::
vector
<
float
>&
x_vec
,
const
std
::
vector
<
float
>&
y_vec
,
const
std
::
vector
<
float
>&
out_grad_vec
,
float
*
x_grad_vec
,
float
*
y_grad_vec
)
{
Tensor
x
;
Tensor
x_grad
;
Tensor
y
;
Tensor
y_grad
;
Tensor
out_grad
;
x
.
Resize
(
x_dims_
);
x_grad
.
Resize
(
x_dims_
);
y
.
Resize
(
y_dims_
);
y_grad
.
Resize
(
y_dims_
);
out_grad
.
Resize
(
out_dims_
);
auto
*
x_data
=
x
.
mutable_data
<
float
>
();
auto
*
y_data
=
y
.
mutable_data
<
float
>
();
auto
*
out_grad_data
=
out_grad
.
mutable_data
<
float
>
();
for
(
int
i
=
0
;
i
<
x_dims_
.
production
();
i
++
)
{
x_data
[
i
]
=
x_vec
[
i
];
}
for
(
int
i
=
0
;
i
<
y_dims_
.
production
();
i
++
)
{
y_data
[
i
]
=
y_vec
[
i
];
}
for
(
int
i
=
0
;
i
<
out_dims_
.
production
();
i
++
)
{
out_grad_data
[
i
]
=
out_grad_vec
[
i
];
}
param
->
X
=
&
x
;
param
->
XGrad
=
&
x_grad
;
param
->
Y
=
&
y
;
param
->
YGrad
=
&
y_grad
;
param
->
OutGrad
=
&
out_grad
;
param
->
axis
=
axis_
;
kernel
->
SetParam
(
*
param
);
kernel
->
Launch
();
auto
*
x_grad_data
=
x_grad
.
mutable_data
<
float
>
();
auto
*
y_grad_data
=
y_grad
.
mutable_data
<
float
>
();
for
(
int
i
=
0
;
i
<
x_dims_
.
production
();
i
++
)
{
x_grad_vec
[
i
]
=
x_grad_data
[
i
];
}
for
(
int
i
=
0
;
i
<
y_dims_
.
production
();
i
++
)
{
y_grad_vec
[
i
]
=
y_grad_data
[
i
];
}
}
void
check_grad
(
float
delta2
,
float
max_grad_delta2
)
{
std
::
vector
<
int64_t
>
out_shape
;
// infer shape
auto
x_dim
=
x_dims_
;
auto
y_dim
=
y_dims_
;
if
(
x_dim
==
y_dim
)
{
out_dims_
=
x_dim
;
}
else
{
int
max_dim
=
(
x_dim
.
size
()
>
y_dim
.
size
()
?
x_dim
.
size
()
:
y_dim
.
size
());
int
axis
=
param_
.
axis
;
axis
=
(
axis
==
-
1
?
std
::
abs
(
static_cast
<
int
>
(
x_dim
.
size
()
-
y_dim
.
size
()))
:
axis
);
std
::
vector
<
int64_t
>
x_dims_array
(
max_dim
);
std
::
vector
<
int64_t
>
y_dims_array
(
max_dim
);
std
::
vector
<
int64_t
>
out_dims_array
(
max_dim
);
if
(
x_dim
.
size
()
>
y_dim
.
size
())
{
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
y_dims_array
[
i
]
=
1
;
}
if
(
axis
+
y_dim
.
size
()
<
max_dim
)
{
for
(
int
i
=
axis
+
y_dim
.
size
();
i
<
max_dim
;
++
i
)
{
y_dims_array
[
i
]
=
1
;
}
}
x_dims_array
=
x_dim
.
Vectorize
();
for
(
int
i
=
0
;
i
<
y_dim
.
size
();
++
i
)
{
y_dims_array
[
i
+
axis
]
=
y_dim
[
i
];
}
}
else
{
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
x_dims_array
[
i
]
=
1
;
}
if
(
axis
+
x_dim
.
size
()
<
max_dim
)
{
for
(
int
i
=
axis
+
x_dim
.
size
();
i
<
max_dim
;
++
i
)
{
x_dims_array
[
i
]
=
1
;
}
}
y_dims_array
=
y_dim
.
Vectorize
();
for
(
int
i
=
0
;
i
<
x_dim
.
size
();
++
i
)
{
x_dims_array
[
i
+
axis
]
=
x_dim
[
i
];
}
}
for
(
int
i
=
0
;
i
<
max_dim
;
i
++
)
{
if
(
x_dims_array
[
i
]
==
-
1
||
y_dims_array
[
i
]
==
-
1
)
{
out_dims_array
[
i
]
=
-
1
;
}
else
{
out_dims_array
[
i
]
=
std
::
max
(
x_dims_array
[
i
],
y_dims_array
[
i
]);
}
}
out_dims_
=
DDim
(
out_dims_array
);
}
// infer end
// forward
std
::
vector
<
float
>
x
(
x_dims_
.
production
());
std
::
vector
<
float
>
y
(
y_dims_
.
production
());
std
::
vector
<
float
>
out
(
out_dims_
.
production
());
fill_data_rand
(
x
.
data
(),
-
1.
f
,
1.
f
,
x_dims_
.
production
());
fill_data_rand
(
y
.
data
(),
-
1.
f
,
1.
f
,
y_dims_
.
production
());
this
->
run_forward
(
&
param_
,
&
kernel_
,
x
,
y
,
out
.
data
());
for
(
int
i
=
0
;
i
<
x_dims_
.
production
();
i
++
)
{
LOG
(
INFO
)
<<
"x_"
<<
i
<<
": "
<<
x
[
i
];
}
for
(
int
i
=
0
;
i
<
y_dims_
.
production
();
i
++
)
{
LOG
(
INFO
)
<<
"y_"
<<
i
<<
": "
<<
y
[
i
];
}
for
(
int
i
=
0
;
i
<
out_dims_
.
production
();
i
++
)
{
LOG
(
INFO
)
<<
"out_"
<<
i
<<
": "
<<
out
[
i
];
}
// backward
std
::
vector
<
float
>
out_grad
(
out_dims_
.
production
());
std
::
vector
<
float
>
x_grad
(
x_dims_
.
production
());
std
::
vector
<
float
>
y_grad
(
y_dims_
.
production
());
for
(
int
i
=
0
;
i
<
out_dims_
.
production
();
i
++
)
{
out_grad
[
i
]
=
1.0
;
}
this
->
run_backward
(
&
grad_param_
,
&
grad_kernel_
,
x
,
y
,
out_grad
,
x_grad
.
data
(),
y_grad
.
data
());
for
(
int
i
=
0
;
i
<
x_grad
.
size
();
i
++
)
{
LOG
(
INFO
)
<<
"x_grad_"
<<
i
<<
": "
<<
x_grad
[
i
];
}
for
(
int
i
=
0
;
i
<
y_grad
.
size
();
i
++
)
{
LOG
(
INFO
)
<<
"y_grad_"
<<
i
<<
": "
<<
y_grad
[
i
];
}
// get numeric gradient
std
::
vector
<
float
>
x_delta
(
x_dims_
.
production
());
std
::
vector
<
float
>
y_delta
(
y_dims_
.
production
());
std
::
vector
<
float
>
out_delta
(
out_dims_
.
production
());
Tensor
tensor_x
;
Tensor
tensor_y
;
tensor_x
.
Resize
(
x_dims_
);
tensor_y
.
Resize
(
y_dims_
);
grad_param_
.
X
=
&
tensor_x
;
grad_param_
.
Y
=
&
tensor_y
;
elementwise_common
(
grad_param_
,
out_grad
,
x_delta
,
y_delta
,
"add"
);
float
max_grad_delta
=
0.0005
;
for
(
int
i
=
0
;
i
<
x_dims_
.
production
();
i
++
)
{
EXPECT_NEAR
(
x_grad
[
i
],
x_delta
[
i
],
max_grad_delta
);
EXPECT_NEAR
(
y_grad
[
i
],
y_delta
[
i
],
max_grad_delta
);
}
}
private:
DDim
x_dims_
;
DDim
y_dims_
;
DDim
out_dims_
;
int
axis_
;
kernel_add_t
kernel_
;
grad_kernel_add_t
grad_kernel_
;
param_t
param_
;
grad_param_t
grad_param_
;
};
class
ElementwiseSubGradTester
{
public:
explicit
ElementwiseSubGradTester
(
const
DDim
&
x_dims
,
const
DDim
&
y_dims
,
int
axis
)
:
x_dims_
(
x_dims
),
y_dims_
(
y_dims
),
axis_
(
axis
)
{}
void
prepare_kernel
()
{
std
::
unique_ptr
<
KernelContext
>
ctx1
(
new
KernelContext
);
ctx1
->
As
<
ARMContext
>
();
kernel_
.
SetContext
(
std
::
move
(
ctx1
));
std
::
unique_ptr
<
KernelContext
>
ctx3
(
new
KernelContext
);
ctx3
->
As
<
ARMContext
>
();
grad_kernel_
.
SetContext
(
std
::
move
(
ctx3
));
}
void
run_forward
(
param_t
*
param
,
kernel_sub_t
*
kernel
,
const
std
::
vector
<
float
>&
x_vec
,
const
std
::
vector
<
float
>&
y_vec
,
float
*
out_vec
)
{
Tensor
x
;
Tensor
y
;
Tensor
output
;
x
.
Resize
(
x_dims_
);
y
.
Resize
(
y_dims_
);
output
.
Resize
(
DDim
(
out_dims_
));
auto
*
x_data
=
x
.
mutable_data
<
float
>
();
auto
*
y_data
=
y
.
mutable_data
<
float
>
();
for
(
int
i
=
0
;
i
<
x_dims_
.
production
();
i
++
)
{
x_data
[
i
]
=
x_vec
[
i
];
}
for
(
int
i
=
0
;
i
<
y_dims_
.
production
();
i
++
)
{
y_data
[
i
]
=
y_vec
[
i
];
}
param
->
X
=
&
x
;
param
->
Y
=
&
y
;
param
->
Out
=
&
output
;
param
->
axis
=
axis_
;
kernel
->
SetParam
(
*
param
);
kernel
->
Launch
();
auto
*
output_data
=
output
.
mutable_data
<
float
>
();
for
(
int
i
=
0
;
i
<
out_dims_
.
production
();
i
++
)
{
out_vec
[
i
]
=
output_data
[
i
];
}
}
void
run_backward
(
grad_param_t
*
param
,
grad_kernel_sub_t
*
kernel
,
const
std
::
vector
<
float
>&
x_vec
,
const
std
::
vector
<
float
>&
y_vec
,
const
std
::
vector
<
float
>&
out_grad_vec
,
float
*
x_grad_vec
,
float
*
y_grad_vec
)
{
Tensor
x
;
Tensor
x_grad
;
Tensor
y
;
Tensor
y_grad
;
Tensor
out_grad
;
x
.
Resize
(
x_dims_
);
x_grad
.
Resize
(
x_dims_
);
y
.
Resize
(
y_dims_
);
y_grad
.
Resize
(
y_dims_
);
out_grad
.
Resize
(
out_dims_
);
auto
*
x_data
=
x
.
mutable_data
<
float
>
();
auto
*
y_data
=
y
.
mutable_data
<
float
>
();
auto
*
out_grad_data
=
out_grad
.
mutable_data
<
float
>
();
for
(
int
i
=
0
;
i
<
x_dims_
.
production
();
i
++
)
{
x_data
[
i
]
=
x_vec
[
i
];
}
for
(
int
i
=
0
;
i
<
y_dims_
.
production
();
i
++
)
{
y_data
[
i
]
=
y_vec
[
i
];
}
for
(
int
i
=
0
;
i
<
out_dims_
.
production
();
i
++
)
{
out_grad_data
[
i
]
=
out_grad_vec
[
i
];
}
param
->
X
=
&
x
;
param
->
XGrad
=
&
x_grad
;
param
->
Y
=
&
y
;
param
->
YGrad
=
&
y_grad
;
param
->
OutGrad
=
&
out_grad
;
param
->
axis
=
axis_
;
kernel
->
SetParam
(
*
param
);
kernel
->
Launch
();
auto
*
x_grad_data
=
x_grad
.
mutable_data
<
float
>
();
auto
*
y_grad_data
=
y_grad
.
mutable_data
<
float
>
();
for
(
int
i
=
0
;
i
<
x_dims_
.
production
();
i
++
)
{
x_grad_vec
[
i
]
=
x_grad_data
[
i
];
}
for
(
int
i
=
0
;
i
<
y_dims_
.
production
();
i
++
)
{
y_grad_vec
[
i
]
=
y_grad_data
[
i
];
}
}
void
check_grad
(
float
delta2
,
float
max_grad_delta2
)
{
std
::
vector
<
int64_t
>
out_shape
;
// infer shape
auto
x_dim
=
x_dims_
;
auto
y_dim
=
y_dims_
;
if
(
x_dim
==
y_dim
)
{
out_dims_
=
x_dim
;
}
else
{
int
max_dim
=
(
x_dim
.
size
()
>
y_dim
.
size
()
?
x_dim
.
size
()
:
y_dim
.
size
());
int
axis
=
param_
.
axis
;
axis
=
(
axis
==
-
1
?
std
::
abs
(
static_cast
<
int
>
(
x_dim
.
size
()
-
y_dim
.
size
()))
:
axis
);
std
::
vector
<
int64_t
>
x_dims_array
(
max_dim
);
std
::
vector
<
int64_t
>
y_dims_array
(
max_dim
);
std
::
vector
<
int64_t
>
out_dims_array
(
max_dim
);
if
(
x_dim
.
size
()
>
y_dim
.
size
())
{
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
y_dims_array
[
i
]
=
1
;
}
if
(
axis
+
y_dim
.
size
()
<
max_dim
)
{
for
(
int
i
=
axis
+
y_dim
.
size
();
i
<
max_dim
;
++
i
)
{
y_dims_array
[
i
]
=
1
;
}
}
x_dims_array
=
x_dim
.
Vectorize
();
for
(
int
i
=
0
;
i
<
y_dim
.
size
();
++
i
)
{
y_dims_array
[
i
+
axis
]
=
y_dim
[
i
];
}
}
else
{
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
x_dims_array
[
i
]
=
1
;
}
if
(
axis
+
x_dim
.
size
()
<
max_dim
)
{
for
(
int
i
=
axis
+
x_dim
.
size
();
i
<
max_dim
;
++
i
)
{
x_dims_array
[
i
]
=
1
;
}
}
y_dims_array
=
y_dim
.
Vectorize
();
for
(
int
i
=
0
;
i
<
x_dim
.
size
();
++
i
)
{
x_dims_array
[
i
+
axis
]
=
x_dim
[
i
];
}
}
for
(
int
i
=
0
;
i
<
max_dim
;
i
++
)
{
if
(
x_dims_array
[
i
]
==
-
1
||
y_dims_array
[
i
]
==
-
1
)
{
out_dims_array
[
i
]
=
-
1
;
}
else
{
out_dims_array
[
i
]
=
std
::
max
(
x_dims_array
[
i
],
y_dims_array
[
i
]);
}
}
out_dims_
=
DDim
(
out_dims_array
);
}
// infer end
// forward
std
::
vector
<
float
>
x
(
x_dims_
.
production
());
std
::
vector
<
float
>
y
(
y_dims_
.
production
());
std
::
vector
<
float
>
out
(
out_dims_
.
production
());
fill_data_rand
(
x
.
data
(),
-
1.
f
,
1.
f
,
x_dims_
.
production
());
fill_data_rand
(
y
.
data
(),
-
1.
f
,
1.
f
,
y_dims_
.
production
());
this
->
run_forward
(
&
param_
,
&
kernel_
,
x
,
y
,
out
.
data
());
for
(
int
i
=
0
;
i
<
x_dims_
.
production
();
i
++
)
{
LOG
(
INFO
)
<<
"x_"
<<
i
<<
": "
<<
x
[
i
];
}
for
(
int
i
=
0
;
i
<
y_dims_
.
production
();
i
++
)
{
LOG
(
INFO
)
<<
"y_"
<<
i
<<
": "
<<
y
[
i
];
}
for
(
int
i
=
0
;
i
<
out_dims_
.
production
();
i
++
)
{
LOG
(
INFO
)
<<
"out_"
<<
i
<<
": "
<<
out
[
i
];
}
// backward
std
::
vector
<
float
>
out_grad
(
out_dims_
.
production
());
std
::
vector
<
float
>
x_grad
(
x_dims_
.
production
());
std
::
vector
<
float
>
y_grad
(
y_dims_
.
production
());
for
(
int
i
=
0
;
i
<
out_dims_
.
production
();
i
++
)
{
out_grad
[
i
]
=
1.0
;
}
this
->
run_backward
(
&
grad_param_
,
&
grad_kernel_
,
x
,
y
,
out_grad
,
x_grad
.
data
(),
y_grad
.
data
());
for
(
int
i
=
0
;
i
<
x_grad
.
size
();
i
++
)
{
LOG
(
INFO
)
<<
"x_grad_"
<<
i
<<
": "
<<
x_grad
[
i
];
}
for
(
int
i
=
0
;
i
<
y_grad
.
size
();
i
++
)
{
LOG
(
INFO
)
<<
"y_grad_"
<<
i
<<
": "
<<
y_grad
[
i
];
}
// get numeric gradient
std
::
vector
<
float
>
x_delta
(
x_dims_
.
production
());
std
::
vector
<
float
>
y_delta
(
y_dims_
.
production
());
std
::
vector
<
float
>
out_delta
(
out_dims_
.
production
());
Tensor
tensor_x
;
Tensor
tensor_y
;
tensor_x
.
Resize
(
x_dims_
);
tensor_y
.
Resize
(
y_dims_
);
grad_param_
.
X
=
&
tensor_x
;
grad_param_
.
Y
=
&
tensor_y
;
elementwise_common
(
grad_param_
,
out_grad
,
x_delta
,
y_delta
,
"sub"
);
float
max_grad_delta
=
0.0005
;
for
(
int
i
=
0
;
i
<
x_dims_
.
production
();
i
++
)
{
EXPECT_NEAR
(
x_grad
[
i
],
x_delta
[
i
],
max_grad_delta
);
EXPECT_NEAR
(
y_grad
[
i
],
y_delta
[
i
],
max_grad_delta
);
}
}
private:
DDim
x_dims_
;
DDim
y_dims_
;
DDim
out_dims_
;
int
axis_
;
kernel_sub_t
kernel_
;
grad_kernel_sub_t
grad_kernel_
;
param_t
param_
;
grad_param_t
grad_param_
;
};
void
TestNormalCase
(
const
std
::
vector
<
int64_t
>&
x_dims
,
const
std
::
vector
<
int64_t
>&
y_dims
,
int
axis
)
{
std
::
unique_ptr
<
ElementwiseAddGradTester
>
tester_add
(
new
ElementwiseAddGradTester
(
DDim
(
x_dims
),
DDim
(
y_dims
),
axis
));
std
::
unique_ptr
<
ElementwiseSubGradTester
>
tester_sub
(
new
ElementwiseSubGradTester
(
DDim
(
x_dims
),
DDim
(
y_dims
),
axis
));
tester_add
->
prepare_kernel
();
tester_sub
->
prepare_kernel
();
float
delta
=
0.001
;
float
max_grad_delta
=
0.005
;
tester_add
->
check_grad
(
delta
,
max_grad_delta
);
tester_sub
->
check_grad
(
delta
,
max_grad_delta
);
}
TEST
(
mul_grad_arm
,
compute
)
{
LOG
(
INFO
)
<<
"Test Elementwise grad"
;
DeviceInfo
::
Init
();
TestNormalCase
({
3
,
2
},
{
3
,
2
},
0
);
TestNormalCase
({
3
,
5
},
{
3
,
5
},
1
);
TestNormalCase
({
3
,
4
,
3
},
{
3
,
4
,
3
},
0
);
TestNormalCase
({
9
,
2
,
5
},
{
9
,
2
,
5
},
1
);
}
}
// namespace arm
}
// namespace kernels
}
// namespace lite
}
// namespace paddle
USE_LITE_KERNEL
(
elementwise_add_grad
,
kARM
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
elementwise_add
,
kARM
,
kFloat
,
kNCHW
,
def
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录