Commit 57f54d3b (unverified)
Authored on Mar 16, 2022 by YuanRisheng; committed via GitHub on Mar 16, 2022.
move activation kernel (#40565)
Parent: 603f8425
Showing 12 changed files with 919 additions and 740 deletions (+919 −740).
paddle/fluid/operators/activation_op.cc            +7    −16
paddle/fluid/operators/activation_op.h              +27   −231
paddle/fluid/operators/activation_op.kps            +9    −251
paddle/phi/kernels/activation_grad_kernel.h         +50   −20
paddle/phi/kernels/activation_kernel.h              +24   −18
paddle/phi/kernels/cpu/activation_grad_kernel.cc    +114  −69
paddle/phi/kernels/cpu/activation_kernel.cc         +76   −61
paddle/phi/kernels/funcs/activation_functor.h       +435  −0
paddle/phi/kernels/gpu/activation_grad_kernel.cu    +58   −22
paddle/phi/kernels/gpu/activation_kernel.cu         +32   −19
paddle/phi/kernels/impl/activation_grad_impl.h      +20   −0
paddle/phi/ops/compat/activation_sig.cc             +67   −33
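Read together with the commit message, the file list shows the shape of the change: the activation functors and per-device kernel registrations leave the fluid operator files and land under paddle/phi. As a minimal before/after sketch using the elu kernel (both registration lines are taken from hunks later in this diff; everything around them is elided):

    // Before, in paddle/fluid/operators/activation_op.cc:
    REGISTER_OP_CPU_KERNEL(
        elu, ops::ActivationKernel<paddle::platform::CPUDeviceContext,
                                   ops::ELUFunctor<float>>, ...);

    // After, in paddle/phi/kernels/cpu/activation_kernel.cc:
    PD_REGISTER_ACTIVATION_KERNEL(elu, EluKernel)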
paddle/fluid/operators/activation_op.cc

@@ -1485,6 +1485,13 @@ REGISTER_ACTIVATION_OP(atanh, Atanh, AtanhFunctor, AtanhGradFunctor);
REGISTER_ACTIVATION_OP(brelu, BRelu, BReluFunctor, BReluGradFunctor);
REGISTER_ACTIVATION_OP(thresholded_relu, ThresholdedRelu,
                       ThresholdedReluFunctor, ThresholdedReluGradFunctor);
REGISTER_ACTIVATION_OP(hard_shrink, HardShrink, HardShrinkFunctor,
                       HardShrinkGradFunctor);
REGISTER_ACTIVATION_OP(softshrink, SoftShrink, SoftShrinkFunctor,
                       SoftShrinkGradFunctor);
REGISTER_ACTIVATION_OP(tanh_shrink, TanhShrink, TanhShrinkFunctor,
                       TanhShrinkGradFunctor);
REGISTER_ACTIVATION_OP(silu, Silu, SiluFunctor, SiluGradFunctor);

/* ========================== sigmoid register ============================= */
@@ -1626,22 +1633,6 @@ REGISTER_OPERATOR(
    ops::ActivationOpDoubleGrad<ops::ELUGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer);

REGISTER_OP_CPU_KERNEL(
    elu, ops::ActivationKernel<paddle::platform::CPUDeviceContext,
                               ops::ELUFunctor<float>>,
    ops::ActivationKernel<paddle::platform::CPUDeviceContext,
                          ops::ELUFunctor<double>>);
REGISTER_OP_CPU_KERNEL(
    elu_grad, ops::ELUGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::ELUGradKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
    elu_grad_grad, ops::ELUDoubleGradKernel<plat::CPUDeviceContext,
                                            ops::ELUGradGradFunctor<float>>,
    ops::ELUDoubleGradKernel<plat::CPUDeviceContext,
                             ops::ELUGradGradFunctor<double>>,
    ops::ELUDoubleGradKernel<plat::CPUDeviceContext,
                             ops::ELUGradGradFunctor<plat::float16>>);

/* ========================================================================== */

/* ======================== logit register ============================ */
...
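The elu registrations removed in this hunk are re-created on the phi side later in the same commit, for example:

    PD_REGISTER_ACTIVATION_KERNEL(elu, EluKernel)
    PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_grad, EluGradKernel)
    PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(elu_double_grad, EluDoubleGradKernel)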
paddle/fluid/operators/activation_op.h

@@ -279,6 +279,15 @@ USE_PHI_FUNCTOR(BRelu)
USE_PHI_FUNCTOR(ThresholdedRelu)
USE_PHI_FUNCTOR(LeakyRelu)
USE_PHI_DOUBLE_GRAD_FUNCTOR(LeakyRelu)
USE_PHI_FUNCTOR(HardShrink)
USE_PHI_FUNCTOR(SoftShrink)
USE_PHI_FUNCTOR(TanhShrink)
USE_PHI_FUNCTOR(Silu)
USE_PHI_FUNCTOR(ELU)
USE_PHI_DOUBLE_GRAD_FUNCTOR(ELU)

template <typename T>
using ELUGradNegativeAlphaFunctor = phi::funcs::ELUGradNegativeAlphaFunctor<T>;

template <typename T>
struct SigmoidGradFunctor : public BaseActivationFunctor<T> {
@@ -392,31 +401,6 @@ struct SigmoidTripleGradFunctor : public BaseActivationFunctor<T> {
  }
};
// silu(x) = x / (1 + exp(-x))
template <typename T>
struct SiluFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto temp = static_cast<T>(1) / (static_cast<T>(1) + (-x).exp());
    out.device(d) = x * temp;
  }
};

// silu'(x) = (1 / (1 + e^{-x})) * (1 + out * e^{-x}))
template <typename T>
struct SiluGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto temp1 = static_cast<T>(1) + (-x).exp();  // 1+e^(-x)
    auto temp2 = x * (-x).exp();                  // x*e^(-x)
    dx.device(d) = dout * ((static_cast<T>(1) / temp1) *
                           (static_cast<T>(1) + (temp2 / temp1)));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// Originally: logsigmoid(x) = -log (1 + exp(-x))
// For numerical stability, we can use the log-sum-exp trick:
// https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/

@@ -512,99 +496,6 @@ using ReluGradGradFunctor = phi::funcs::ReluGradGradFunctor<T>;
template <typename T>
using ReluCUDAFunctor = phi::funcs::ReluCUDAFunctor<T>;

// tanhshrink(x) = x - tanh(x)
// where tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
template <typename T>
struct TanhShrinkFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x - x.tanh();
  }
};

template <typename T>
struct TanhShrinkGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * (x.tanh() * x.tanh());
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct HardShrinkFunctor : public BaseActivationFunctor<T> {
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto temp1 = x < static_cast<T>(threshold * -1.f);
    auto temp2 = x > static_cast<T>(threshold);
    out.device(d) = x * (temp1 || temp2).template cast<T>();
  }
};

template <typename T>
struct HardShrinkGradFunctor : public BaseActivationFunctor<T> {
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto temp1 = x < static_cast<T>(threshold * -1.f);
    auto temp2 = x > static_cast<T>(threshold);
    dx.device(d) = dout * (temp1 || temp2).template cast<T>();
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// softshrink(x) = x - lambda, if x > lambda; x + lambda, if x < -lambda; 0
// otherwise
template <typename T>
struct SoftShrinkFunctor : public BaseActivationFunctor<T> {
  float lambda;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto lambdaT = static_cast<T>(lambda);
    auto temp1 = (x > lambdaT).template cast<T>();
    auto temp2 = (x < -lambdaT).template cast<T>();
    out.device(d) = temp1 * (x - lambdaT) + temp2 * (x + lambdaT);
  }
};

template <typename T>
struct SoftShrinkGradFunctor : public BaseActivationFunctor<T> {
  float lambda;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto lambdaT = static_cast<T>(lambda);
    auto temp1 = (x > lambdaT).template cast<T>();
    auto temp2 = (x < -lambdaT).template cast<T>();
    dx.device(d) = dout * (temp1 + temp2).template cast<T>();
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// sqrt(x) = x^(1/2)
template <typename T>
struct SqrtFunctor : public BaseActivationFunctor<T> {
@@ -1036,59 +927,6 @@ struct SoftReluGradFunctor : public BaseActivationFunctor<T> {
  }
};
template <typename T>
struct ELUFunctor : public BaseActivationFunctor<T> {
  float alpha;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) =
        (x < static_cast<T>(0))
            .select(static_cast<T>(alpha) * (x.exp() - static_cast<T>(1)), x);
  }
};

template <typename T>
struct ELUGradFunctor : public BaseActivationFunctor<T> {
  float alpha;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    // case 1: alpha >= 0
    // dx = dout, if out > 0
    // dx = dout * (out + alpha), if out <= 0
    dx.device(d) = (out > static_cast<T>(0))
                       .select(dout, dout * (out + static_cast<T>(alpha)));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct ELUGradNegativeAlphaFunctor : public BaseActivationFunctor<T> {
  float alpha;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    // case 2: alpha < 0
    // dx = dout, if x > 0
    // dx = dout * (out + alpha), if x <=0
    dx.device(d) = (x > static_cast<T>(0))
                       .select(dout, dout * static_cast<T>(alpha) * x.exp());
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename DeviceContext, typename T>
class ELUGradKernel : public framework::OpKernel<T> {
 public:

@@ -1354,44 +1192,6 @@ struct AbsGradGradFunctor : public BaseActivationFunctor<T> {
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct ELUGradGradFunctor : public BaseActivationFunctor<T> {
  float alpha;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }
  template <typename Device>
  void operator()(const Device& dev, const framework::Tensor* X,
                  const framework::Tensor* ddX, framework::Tensor* ddOut,
                  const framework::Tensor* dOut, framework::Tensor* dX) const {
    auto* d = dev.eigen_device();
    auto ddx = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddX, "Input", "DDX", "ELUGradGrad"));
    auto x = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(X, "Input", "X", "ELUGradGrad"));

    if (dX) {
      auto dx = framework::EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dX, "Output", "DX", "ELUGradGrad"));
      auto dout = framework::EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dOut, "Output", "DOut", "ELUGradGrad"));
      dx.device(*d) = ddx * dout * static_cast<T>(alpha) * x.exp() *
                      (x <= static_cast<T>(0)).template cast<T>();
    }

    if (ddOut) {
      auto ddout = framework::EigenVector<T>::Flatten(
          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "ELUGradGrad"));
      ddout.device(*d) = ddx * ((x > static_cast<T>(0)).template cast<T>() +
                                static_cast<T>(alpha) * x.exp() *
                                    (x <= static_cast<T>(0)).template cast<T>())
                             .template cast<T>();
    }
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CELUGradGradFunctor : public BaseActivationFunctor<T> {
  float alpha;

@@ -2152,9 +1952,7 @@ struct LogGradGradFunctor : public BaseActivationFunctor<T> {
}  // namespace paddle

#define FOR_EACH_ACTIVATION_OP(__macro)                                       \
-  __macro(silu, Silu, SiluFunctor, SiluGradFunctor);                         \
  __macro(logsigmoid, LogSigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor);  \
-  __macro(softshrink, SoftShrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \
  __macro(ceil, Ceil, CeilFunctor, ZeroGradFunctor);                          \
  __macro(floor, Floor, FloorFunctor, ZeroGradFunctor);                       \
  __macro(round, Round, RoundFunctor, ZeroGradFunctor);                       \
@@ -2167,8 +1965,6 @@ struct LogGradGradFunctor : public BaseActivationFunctor<T> {
  __macro(softplus, Softplus, SoftplusFunctor, SoftplusGradFunctor);          \
  __macro(softsign, Softsign, SoftsignFunctor, SoftsignGradFunctor);          \
  __macro(relu6, Relu6, Relu6Functor, Relu6GradFunctor);                      \
-  __macro(tanh_shrink, TanhShrink, TanhShrinkFunctor, TanhShrinkGradFunctor); \
-  __macro(hard_shrink, HardShrink, HardShrinkFunctor, HardShrinkGradFunctor); \
  __macro(hard_sigmoid, HardSigmoid, HardSigmoidFunctor,                      \
          HardSigmoidGradFunctor);                                            \
  __macro(swish, Swish, SwishFunctor, SwishGradFunctor);                      \
...
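The USE_PHI_FUNCTOR lines added near the top of this file follow the same idea as the explicit alias spelled out for ELUGradNegativeAlphaFunctor: the fluid header keeps its old functor names but forwards them to the phi implementations. A rough, hypothetical expansion of one such line (the real macro body is defined elsewhere in activation_op.h and is not part of this diff):

    // Illustration only; the actual USE_PHI_FUNCTOR macro may differ in detail.
    template <typename T>
    using ELUFunctor = phi::funcs::ELUFunctor<T>;
    template <typename T>
    using ELUGradFunctor = phi::funcs::ELUGradFunctor<T>;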
paddle/fluid/operators/activation_op.kps

@@ -44,35 +44,6 @@ struct CudaSigmoidGradFunctor : public BaseActivationFunctor<T> {
  }
};
template <typename T>
struct CudaSiluFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // silu(x) = x / (1 + exp(-x))
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(x / (one + exp(-x)));
  }
};

template <typename T>
struct CudaSiluGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // dx = dout * (1 + exp(-x) + x * exp(-x) / (1 + exp(-x))^2)
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    MPType temp = one / (one + exp(-x));
    return static_cast<T>(dout * (temp * (one + x * (one - temp))));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaLogSigmoidFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

@@ -110,43 +81,6 @@ struct CudaLogSigmoidGradFunctor : public BaseActivationFunctor<T> {
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaSoftShrinkFunctor : public BaseActivationFunctor<T> {
  float lambda;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  // softshrink(x) = x - lambda, if x > lambda;
  //                 x + lambda, if x < -lambda;
  //                 0, otherwise.
  __device__ __forceinline__ T operator()(const T x) const {
    T l = static_cast<T>(lambda);
    T temp1 = static_cast<T>(x > l);
    T temp2 = static_cast<T>(x < -l);
    return temp1 * (x - l) + temp2 * (x + l);
  }
};

template <typename T>
struct CudaSoftShrinkGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float lambda;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  // dx = dout, if x > lambda or x < -lambda else 0
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    T l = static_cast<T>(lambda);
    return (x >= -l && x <= l) ? zero : dout;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaCeilFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

@@ -615,66 +549,6 @@ struct CudaRelu6GradFunctor : public BaseActivationFunctor<T> {
  }
};
template <typename T>
struct CudaTanhShrinkFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // tanhshrink(x) = x - tanh(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(x - tanh(x));
  }
};

template <typename T>
struct CudaTanhShrinkGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // dx = dout * tanh(x)^2
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(dout * tanh(x) * tanh(x));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaHardShrinkFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // hadrshrink(x) = (x > -threshold && x < threshold) ? 0 : x
  __device__ __forceinline__ T operator()(const T x) const {
    T t = static_cast<T>(threshold);
    return (x > -t && x < t) ? zero : x;
  }
};

template <typename T>
struct CudaHardShrinkGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // dx = (x > -threshold && x < threshold) ? 0 : dout
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    T t = static_cast<T>(threshold);
    return (x > -t && x < t) ? zero : dout;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaHardSigmoidFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);

@@ -863,110 +737,6 @@ struct CudaHardSwishGradFunctor : public BaseActivationFunctor<T> {
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaELUFunctor : public BaseActivationFunctor<T> {
  using CT = typename details::MPTypeTrait<T>::Type;
  CT zero = static_cast<CT>(0.0f);
  CT one = static_cast<CT>(1.0f);
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // elu(x) = x, if x > 0
  // elu(x) = alpha * (e^x - 1), if x <= 0
  __device__ __forceinline__ T operator()(const T arg_x) const {
    CT x = static_cast<CT>(arg_x);
    CT temp = static_cast<CT>(alpha) * (exp(x) - one);
    CT res = x > zero ? x : temp;
    return static_cast<T>(res);
  }
};

template <typename T>
struct CudaELUGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // case 1: alpha >= 0
  // dx = dout, if out > 0
  // dx = dout * (out + alpha), if out <= 0
  __device__ __forceinline__ T operator()(T arg_dout, T arg_out) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType out = static_cast<MPType>(arg_out);
    MPType a = static_cast<MPType>(alpha);
    MPType out_pos = static_cast<MPType>(out > zero);
    MPType out_neg = static_cast<MPType>(out <= zero);
    return static_cast<T>(dout * (out_pos + out_neg * (out + a)));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

template <typename T>
struct CudaELUGradNegativeAlphaFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // case 2: alpha < 0
  // dx = dout, if x > 0
  // dx = dout * (out + alpha), if x <=0
  __device__ __forceinline__ T operator()(const T arg_dout, const T arg_out,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType out = static_cast<MPType>(arg_out);
    MPType x = static_cast<MPType>(arg_x);
    MPType a = static_cast<MPType>(alpha);
    MPType x_pos = static_cast<MPType>(x > zero);
    MPType x_neg = static_cast<MPType>(x <= zero);
    return static_cast<T>(dout * (x_pos + x_neg * (out + a)));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename DeviceContext, typename T>
class ELUGradCudaKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const {
    auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* out = ctx.Input<framework::Tensor>("Out");
    auto* x = ctx.Input<framework::Tensor>("X");
    auto* d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    d_x->mutable_data<T>(ctx.GetPlace());
    const float alpha = ctx.Attr<float>("alpha");

    auto& dev_ctx = ctx.device_context<DeviceContext>();
    std::vector<const framework::Tensor*> ins = {d_out, out};
    std::vector<framework::Tensor*> outs = {d_x};
    if (alpha > 0) {
      CudaELUGradFunctor<T> functor;
      functor.alpha = alpha;
      paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx, ins,
                                                                &outs, functor);
    } else {
      CudaELUGradNegativeAlphaFunctor<T> functor;
      functor.alpha = alpha;
      ins.push_back(x);
      paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx, ins,
                                                                &outs, functor);
    }
  }
};

template <typename T>
struct CudaCELUFunctor : public BaseActivationFunctor<T> {
  using CT = typename details::MPTypeTrait<T>::Type;

@@ -1099,6 +869,15 @@ USE_PHI_FUNCTOR(CudaTanh)
USE_PHI_FUNCTOR(CudaBRelu)
USE_PHI_FUNCTOR(CudaLeakyRelu)
USE_PHI_FUNCTOR(CudaThresholdedRelu)
USE_PHI_FUNCTOR(CudaHardShrink)
USE_PHI_FUNCTOR(CudaSoftShrink)
USE_PHI_FUNCTOR(CudaTanhShrink)
USE_PHI_FUNCTOR(CudaSilu)
USE_PHI_FUNCTOR(CudaELU)

template <typename T>
using CudaELUGradNegativeAlphaFunctor =
    phi::funcs::CudaELUGradNegativeAlphaFunctor<T>;

}  // namespace operators
}  // namespace paddle
@@ -1158,26 +937,6 @@ namespace plat = paddle::platform;
                                    ops::ActivationGradCudaKernel<plat::CUDADeviceContext, \
                                                                  ops::grad_functor<plat::bfloat16>>);

/* ======================== elu register ============================ */
REGISTER_OP_CUDA_KERNEL(
    elu, ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,
                                   ops::CudaELUFunctor<float>>,
    ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,
                              ops::CudaELUFunctor<double>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaELUFunctor<plat::float16>>);

REGISTER_OP_CUDA_KERNEL(
    elu_grad, ops::ELUGradCudaKernel<plat::CUDADeviceContext, float>,
    ops::ELUGradCudaKernel<plat::CUDADeviceContext, double>,
    ops::ELUGradCudaKernel<plat::CUDADeviceContext, plat::float16>);

REGISTER_OP_CUDA_KERNEL(
    elu_grad_grad, ops::ELUDoubleGradKernel<plat::CUDADeviceContext,
                                            ops::ELUGradGradFunctor<float>>,
    ops::ELUDoubleGradKernel<plat::CUDADeviceContext,
                             ops::ELUGradGradFunctor<double>>,
    ops::ELUDoubleGradKernel<plat::CUDADeviceContext,
                             ops::ELUGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ======================== celu register ============================ */

@@ -1359,7 +1118,6 @@ REGISTER_OP_CUDA_KERNEL(
/* ========================================================================== */

#define FOR_EACH_ACTIVATION_CUDA_OP(__macro)                                 \
-  __macro(silu, Silu, CudaSiluFunctor, CudaSiluGradFunctor);                \
  __macro(logsigmoid, LogSigmoid, CudaLogSigmoidFunctor,                     \
          CudaLogSigmoidGradFunctor);                                        \
  __macro(softshrink, SoftShrink, CudaSoftShrinkFunctor,                     \
...
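The removed CUDA functors encode ELU exactly as their comments state. Restated as plain scalar C++ for reference (assumed helper names, not part of the Paddle API):

    #include <cmath>

    // elu(x) = x,                  if x > 0
    // elu(x) = alpha * (e^x - 1),  if x <= 0
    float elu(float x, float alpha) {
      return x > 0.f ? x : alpha * (std::exp(x) - 1.f);
    }

    // alpha >= 0 case of the backward pass, as in CudaELUGradFunctor:
    // dx = dout,                 if out > 0
    // dx = dout * (out + alpha), if out <= 0
    float elu_grad(float dout, float out, float alpha) {
      return out > 0.f ? dout : dout * (out + alpha);
    }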
paddle/phi/kernels/activation_grad_kernel.h

@@ -26,6 +26,23 @@ namespace phi {
                        const DenseTensor& dout, \
                        DenseTensor* dx);
#define DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX(name, attr) \
  template <typename T, typename Context>                       \
  void name##GradKernel(const Context& dev_ctx,                 \
                        const DenseTensor& x,                   \
                        const DenseTensor& dout,                \
                        float attr,                             \
                        DenseTensor* dx);

#define DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DepX(name, attr1, attr2) \
  template <typename T, typename Context>                               \
  void name##GradKernel(const Context& dev_ctx,                         \
                        const DenseTensor& x,                           \
                        const DenseTensor& dout,                        \
                        float attr1,                                    \
                        float attr2,                                    \
                        DenseTensor* dx);

#define DECLARE_ACTIVATION_GRAD_KERNEL_DepOut(name) \
  template <typename T, typename Context>           \
  void name##GradKernel(const Context& dev_ctx,     \
@@ -33,6 +50,14 @@ namespace phi {
                        const DenseTensor& dout, \
                        DenseTensor* dx);

#define DECLARE_ACTIVATION_GRAD_KERNEL_WITH_ONE_ATTRS_DepOut(name, attr) \
  template <typename T, typename Context>                                \
  void name##GradKernel(const Context& dev_ctx,                          \
                        const DenseTensor& out,                          \
                        const DenseTensor& dout,                         \
                        float attr,                                      \
                        DenseTensor* dx);
template <typename T, typename Context>
void ReluDoubleGradKernel(const Context& dev_ctx,
                          const DenseTensor& out,
@@ -59,34 +84,29 @@ void TanhTripleGradKernel(const Context& dev_ctx,
                          DenseTensor* d_ddx);
-template <typename T, typename Context>
-void BReluGradKernel(const Context& dev_ctx,
-                     const DenseTensor& x,
-                     const DenseTensor& dout,
-                     float t_min,
-                     float t_max,
-                     DenseTensor* dx);
-
-template <typename T, typename Context>
-void LeakyReluGradKernel(const Context& dev_ctx,
-                         const DenseTensor& x,
-                         const DenseTensor& dout,
-                         float alpha,
-                         DenseTensor* dx);
-
-template <typename T, typename Context>
-void LeakyReluDoubleGradKernel(const Context& dev_ctx,
-                               const DenseTensor& x,
-                               const DenseTensor& ddx,
-                               float alpha,
-                               DenseTensor* ddout);
-
-template <typename T, typename Context>
-void ThresholdedReluGradKernel(const Context& dev_ctx,
-                               const DenseTensor& x,
-                               const DenseTensor& dout,
-                               float threshold,
-                               DenseTensor* dx);
+template <typename T, typename Context>
+void LeakyReluDoubleGradKernel(const Context& dev_ctx,
+                               const DenseTensor& x,
+                               const DenseTensor& ddx,
+                               float alpha,
+                               DenseTensor* ddout);
+
+template <typename T, typename Context>
+void EluGradKernel(const Context& dev_ctx,
+                   const DenseTensor& x,
+                   const DenseTensor& out,
+                   const DenseTensor& dout,
+                   float alpha,
+                   DenseTensor* dx);
+
+template <typename T, typename Context>
+void EluDoubleGradKernel(const Context& dev_ctx,
+                         const DenseTensor& x,
+                         const DenseTensor& dout,
+                         const DenseTensor& ddx,
+                         float alpha,
+                         DenseTensor* dx,
+                         DenseTensor* ddout);
DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Cos);
DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Tan);
DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Acos);
@@ -98,7 +118,17 @@ DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Cosh);
DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Asinh);
DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Acosh);
DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Atanh);
DECLARE_ACTIVATION_GRAD_KERNEL_DepX(TanhShrink);
DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Silu);

DECLARE_ACTIVATION_GRAD_KERNEL_DepOut(Relu);
DECLARE_ACTIVATION_GRAD_KERNEL_DepOut(Tanh);

DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX(LeakyRelu, alpha)
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX(ThresholdedRelu, threshold)
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX(SoftShrink, lambda)
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX(HardShrink, threshold)
DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DepX(BRelu, t_min, t_max)

}  // namespace phi
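For concreteness, expanding one of the new declaration macros by hand (using the macro body shown in this hunk) gives back the kind of declaration it replaces:

    // DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX(LeakyRelu, alpha) declares:
    template <typename T, typename Context>
    void LeakyReluGradKernel(const Context& dev_ctx,
                             const DenseTensor& x,
                             const DenseTensor& dout,
                             float alpha,
                             DenseTensor* dx);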
paddle/phi/kernels/activation_kernel.h

@@ -24,6 +24,21 @@ namespace phi {
  void name##Kernel( \
      const Context& dev_ctx, const DenseTensor& x, DenseTensor* out);

#define DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(name, attr) \
  template <typename T, typename Context>                    \
  void name##Kernel(const Context& dev_ctx,                  \
                    const DenseTensor& x,                    \
                    float attr,                              \
                    DenseTensor* out);

#define DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(name, attr1, attr2) \
  template <typename T, typename Context>                            \
  void name##Kernel(const Context& dev_ctx,                          \
                    const DenseTensor& x,                            \
                    float attr1,                                     \
                    float attr2,                                     \
                    DenseTensor* out);
DECLARE_ACTIVATION_KERNEL(Cos)
DECLARE_ACTIVATION_KERNEL(Tan)
DECLARE_ACTIVATION_KERNEL(Acos)
@@ -37,24 +52,15 @@ DECLARE_ACTIVATION_KERNEL(Acosh)
DECLARE_ACTIVATION_KERNEL(Atanh)
DECLARE_ACTIVATION_KERNEL(Relu)
DECLARE_ACTIVATION_KERNEL(Tanh)
DECLARE_ACTIVATION_KERNEL(TanhShrink)
DECLARE_ACTIVATION_KERNEL(Silu)
-template <typename T, typename Context>
-void BReluKernel(const Context& dev_ctx,
-                 const DenseTensor& x,
-                 float t_min,
-                 float t_max,
-                 DenseTensor* out);
-
-template <typename T, typename Context>
-void LeakyReluKernel(const Context& dev_ctx,
-                     const DenseTensor& x,
-                     float alpha,
-                     DenseTensor* out);
-
-template <typename T, typename Context>
-void ThresholdedReluKernel(const Context& dev_ctx,
-                           const DenseTensor& x,
-                           float threshold,
-                           DenseTensor* out);
+DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(LeakyRelu, alpha)
+DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, threshold)
+DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(SoftShrink, lambda)
+DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(HardShrink, threshold)
+DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Elu, alpha)
+DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(BRelu, t_min, t_max)

}  // namespace phi
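Likewise, the hand-written BRelu declaration removed above is exactly what the new two-attribute macro produces when expanded (illustrative hand expansion of the macro body shown in this hunk):

    // DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(BRelu, t_min, t_max) declares:
    template <typename T, typename Context>
    void BReluKernel(const Context& dev_ctx,
                     const DenseTensor& x,
                     float t_min,
                     float t_max,
                     DenseTensor* out);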
paddle/phi/kernels/cpu/activation_grad_kernel.cc

@@ -21,18 +21,18 @@ limitations under the License. */

namespace phi {
-#define DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(name, functor_class) \
+#define DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(name, functor_class) \
  template <typename T, typename Context>                            \
  void name##GradKernel(const Context& dev_ctx,                      \
                        const DenseTensor& x,                        \
                        const DenseTensor& dout,                     \
                        DenseTensor* dx) {                           \
-    functor_class<T> functor;                                       \
-    ActivationGradImpl<T, Context, functor_class<T>>(               \
+    funcs::functor_class<T> functor;                                \
+    ActivationGradImpl<T, Context, funcs::functor_class<T>>(        \
        dev_ctx, &x, nullptr, &dout, dx, functor);                   \
  }

-#define DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX( \
+#define DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX( \
    name, functor_class, attr)                           \
  template <typename T, typename Context>                \
  void name##GradKernel(const Context& dev_ctx,          \
@@ -40,14 +40,14 @@ namespace phi {
                        const DenseTensor& dout,         \
                        float attr,                      \
                        DenseTensor* dx) {               \
-    functor_class<T> functor;                           \
+    funcs::functor_class<T> functor;                    \
    auto attrs = functor.GetAttrs();                     \
    *(attrs[0].second) = attr;                           \
-    ActivationGradImpl<T, Context, functor_class<T>>(   \
+    ActivationGradImpl<T, Context, funcs::functor_class<T>>( \
        dev_ctx, &x, nullptr, &dout, dx, functor);       \
  }

-#define DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DepX( \
+#define DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX( \
    name, functor_class, attr1, attr2)                   \
  template <typename T, typename Context>                \
  void name##GradKernel(const Context& dev_ctx,          \
@@ -56,26 +56,26 @@ namespace phi {
                        float attr1,                     \
                        float attr2,                     \
                        DenseTensor* dx) {               \
-    functor_class<T> functor;                           \
+    funcs::functor_class<T> functor;                    \
    auto attrs = functor.GetAttrs();                     \
    *(attrs[0].second) = attr1;                          \
    *(attrs[1].second) = attr2;                          \
-    ActivationGradImpl<T, Context, functor_class<T>>(   \
+    ActivationGradImpl<T, Context, funcs::functor_class<T>>( \
        dev_ctx, &x, nullptr, &dout, dx, functor);       \
  }

-#define DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepOut(name, functor_class) \
+#define DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(name, functor_class) \
  template <typename T, typename Context>                              \
  void name##GradKernel(const Context& dev_ctx,                        \
                        const DenseTensor& out,                        \
                        const DenseTensor& dout,                       \
                        DenseTensor* dx) {                             \
-    functor_class<T> functor;                                         \
-    ActivationGradImpl<T, Context, functor_class<T>>(                 \
+    funcs::functor_class<T> functor;                                  \
+    ActivationGradImpl<T, Context, funcs::functor_class<T>>(          \
        dev_ctx, nullptr, &out, &dout, dx, functor);                   \
  }

-#define DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepOut( \
+#define DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPOUT( \
    name, functor_class, attr)                             \
  template <typename T, typename Context>                  \
  void name##GradKernel(const Context& dev_ctx,            \
@@ -83,39 +83,78 @@ namespace phi {
                        const DenseTensor& dout,           \
                        float attr,                        \
                        DenseTensor* dx) {                 \
-    functor_class<T> functor;                             \
+    funcs::functor_class<T> functor;                      \
    auto attrs = functor.GetAttrs();                       \
    *(attrs[0].second) = attr;                             \
-    ActivationGradImpl<T, Context, functor_class<T>>(     \
+    ActivationGradImpl<T, Context, funcs::functor_class<T>>( \
        dev_ctx, nullptr, &out, &dout, dx, functor);       \
  }
-DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Cos, funcs::CosGradFunctor);
-DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Tan, funcs::TanGradFunctor);
-DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Acos, funcs::AcosGradFunctor);
-DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Sin, funcs::SinGradFunctor);
-DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Asin, funcs::AsinGradFunctor);
-DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Atan, funcs::AtanGradFunctor);
-DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Sinh, funcs::SinhGradFunctor);
-DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Cosh, funcs::CoshGradFunctor);
-DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Asinh, funcs::AsinhGradFunctor);
-DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Acosh, funcs::AcoshGradFunctor);
-DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Atanh, funcs::AtanhGradFunctor);
-DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepOut(Relu, funcs::ReluGradFunctor);
-DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepOut(Tanh, funcs::TanhGradFunctor);
-DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX(LeakyRelu,
-                                               funcs::LeakyReluGradFunctor,
-                                               alpha);
-DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX(ThresholdedRelu,
-                                               funcs::ThresholdedReluGradFunctor,
-                                               threshold);
-DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DepX(BRelu,
-                                               funcs::BReluGradFunctor,
-                                               t_min,
-                                               t_max);
+DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Cos, CosGradFunctor);
+DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Tan, TanGradFunctor);
+DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Acos, AcosGradFunctor);
+DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Sin, SinGradFunctor);
+DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Asin, AsinGradFunctor);
+DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Atan, AtanGradFunctor);
+DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Sinh, SinhGradFunctor);
+DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Cosh, CoshGradFunctor);
+DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Asinh, AsinhGradFunctor);
+DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Acosh, AcoshGradFunctor);
+DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Atanh, AtanhGradFunctor);
+DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(TanhShrink, TanhShrinkGradFunctor);
+DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Silu, SiluGradFunctor);
+DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu, ReluGradFunctor);
+DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh, TanhGradFunctor);
+DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(LeakyRelu,
+                                               LeakyReluGradFunctor,
+                                               alpha);
+DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(ThresholdedRelu,
+                                               ThresholdedReluGradFunctor,
+                                               threshold);
+DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(SoftShrink,
+                                               SoftShrinkGradFunctor,
+                                               lambda);
+DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(HardShrink,
+                                               HardShrinkGradFunctor,
+                                               threshold);
+DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(BRelu,
+                                               BReluGradFunctor,
+                                               t_min,
+                                               t_max);
template <typename T, typename Context>
void EluGradKernel(const Context& dev_ctx,
                   const DenseTensor& x,
                   const DenseTensor& out,
                   const DenseTensor& dout,
                   float alpha,
                   DenseTensor* dx) {
  dev_ctx.template Alloc<T>(dx);

  auto x_flatten =
      EigenVector<T>::Flatten(GET_DATA_SAFELY(&x, "Input", "X", "elu_grad"));
  auto out_flatten = EigenVector<T>::Flatten(
      GET_DATA_SAFELY(&out, "Input", "Out", "elu_grad"));
  auto dout_flatten = EigenVector<T>::Flatten(
      GET_DATA_SAFELY(&dout, "Input", "dOut", "elu_grad"));
  auto dx_flatten =
      EigenVector<T>::Flatten(GET_DATA_SAFELY(dx, "Output", "dX", "elu_grad"));
  auto* place = dev_ctx.eigen_device();

  if (alpha > 0) {
    funcs::ELUGradFunctor<T> functor;
    functor.alpha = alpha;
    functor(*place, x_flatten, out_flatten, dout_flatten, dx_flatten);
  } else {
    funcs::ELUGradNegativeAlphaFunctor<T> functor;
    functor.alpha = alpha;
    functor(*place, x_flatten, out_flatten, dout_flatten, dx_flatten);
  }
}

}  // namespace phi

PD_REGISTER_KERNEL(
@@ -144,6 +183,11 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(brelu_grad, BReluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(thresholded_relu_grad,
                                   ThresholdedReluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(soft_shrink_grad, SoftShrinkGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_shrink_grad, HardShrinkGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_shrink_grad, TanhShrinkGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_grad, EluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(silu_grad, SiluGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(relu_double_grad,
                                          ReluDoubleGradKernel)
@@ -151,6 +195,7 @@ PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(tanh_double_grad,
                                          TanhDoubleGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(leaky_relu_double_grad,
                                          LeakyReluDoubleGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(elu_double_grad, EluDoubleGradKernel)
PD_REGISTER_KERNEL(tanh_triple_grad,
                   CPU,
...
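The DEFINE_ macros above stamp out one kernel definition per activation. Expanding DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Silu, SiluGradFunctor) by hand from the macro body in this hunk gives roughly:

    template <typename T, typename Context>
    void SiluGradKernel(const Context& dev_ctx,
                        const DenseTensor& x,
                        const DenseTensor& dout,
                        DenseTensor* dx) {
      funcs::SiluGradFunctor<T> functor;
      ActivationGradImpl<T, Context, funcs::SiluGradFunctor<T>>(
          dev_ctx, &x, nullptr, &dout, dx, functor);
    }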
paddle/phi/kernels/cpu/activation_kernel.cc

@@ -23,8 +23,9 @@ namespace phi {
  template <typename T, typename Context>                               \
  void name##Kernel(                                                    \
      const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) { \
-    functor_class functor;                                             \
-    ActivationImpl<T, Context, functor_class>(dev_ctx, x, out, functor); \
+    funcs::functor_class<T> functor;                                   \
+    ActivationImpl<T, Context, funcs::functor_class<T>>(               \
+        dev_ctx, x, out, functor);                                     \
  }

#define DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(name, functor_class, attr) \
@@ -33,10 +34,11 @@ namespace phi {
                    const DenseTensor& x,                               \
                    float attr,                                         \
                    DenseTensor* out) {                                 \
-    functor_class<T> functor;                                          \
+    funcs::functor_class<T> functor;                                   \
    auto attrs = functor.GetAttrs();                                    \
    *(attrs[0].second) = attr;                                          \
-    ActivationImpl<T, Context, functor_class<T>>(dev_ctx, x, out, functor); \
+    ActivationImpl<T, Context, funcs::functor_class<T>>(               \
+        dev_ctx, x, out, functor);                                     \
  }

#define DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS( \
@@ -47,50 +49,63 @@ namespace phi {
                    float attr1,                                        \
                    float attr2,                                        \
                    DenseTensor* out) {                                 \
-    functor_class<T> functor;                                          \
+    funcs::functor_class<T> functor;                                   \
    auto attrs = functor.GetAttrs();                                    \
    *(attrs[0].second) = attr1;                                         \
    *(attrs[1].second) = attr2;                                         \
-    ActivationImpl<T, Context, functor_class<T>>(dev_ctx, x, out, functor); \
+    ActivationImpl<T, Context, funcs::functor_class<T>>(               \
+        dev_ctx, x, out, functor);                                     \
  }
-DEFINE_CPU_ACTIVATION_KERNEL(Sin, funcs::SinFunctor<T>)
-DEFINE_CPU_ACTIVATION_KERNEL(Cos, funcs::CosFunctor<T>)
-DEFINE_CPU_ACTIVATION_KERNEL(Tan, funcs::TanFunctor<T>)
-DEFINE_CPU_ACTIVATION_KERNEL(Asin, funcs::AsinFunctor<T>)
-DEFINE_CPU_ACTIVATION_KERNEL(Atan, funcs::AtanFunctor<T>)
-DEFINE_CPU_ACTIVATION_KERNEL(Acos, funcs::AcosFunctor<T>)
-DEFINE_CPU_ACTIVATION_KERNEL(Sinh, funcs::SinhFunctor<T>)
-DEFINE_CPU_ACTIVATION_KERNEL(Cosh, funcs::CoshFunctor<T>)
-DEFINE_CPU_ACTIVATION_KERNEL(Asinh, funcs::AsinhFunctor<T>)
-DEFINE_CPU_ACTIVATION_KERNEL(Acosh, funcs::AcoshFunctor<T>)
-DEFINE_CPU_ACTIVATION_KERNEL(Atanh, funcs::AtanhFunctor<T>)
-DEFINE_CPU_ACTIVATION_KERNEL(Relu, funcs::ReluCPUFunctor<T>)
-DEFINE_CPU_ACTIVATION_KERNEL(Tanh, funcs::TanhFunctor<T>)
-DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, funcs::LeakyReluFunctor, alpha)
-DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
-                                     funcs::ThresholdedReluFunctor,
-                                     threshold)
-DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(BRelu, funcs::BReluFunctor, t_min, t_max)
+DEFINE_CPU_ACTIVATION_KERNEL(Sin, SinFunctor)
+DEFINE_CPU_ACTIVATION_KERNEL(Cos, CosFunctor)
+DEFINE_CPU_ACTIVATION_KERNEL(Tan, TanFunctor)
+DEFINE_CPU_ACTIVATION_KERNEL(Asin, AsinFunctor)
+DEFINE_CPU_ACTIVATION_KERNEL(Atan, AtanFunctor)
+DEFINE_CPU_ACTIVATION_KERNEL(Acos, AcosFunctor)
+DEFINE_CPU_ACTIVATION_KERNEL(Sinh, SinhFunctor)
+DEFINE_CPU_ACTIVATION_KERNEL(Cosh, CoshFunctor)
+DEFINE_CPU_ACTIVATION_KERNEL(Asinh, AsinhFunctor)
+DEFINE_CPU_ACTIVATION_KERNEL(Acosh, AcoshFunctor)
+DEFINE_CPU_ACTIVATION_KERNEL(Atanh, AtanhFunctor)
+DEFINE_CPU_ACTIVATION_KERNEL(Relu, ReluCPUFunctor)
+DEFINE_CPU_ACTIVATION_KERNEL(Tanh, TanhFunctor)
+DEFINE_CPU_ACTIVATION_KERNEL(TanhShrink, TanhShrinkFunctor)
+DEFINE_CPU_ACTIVATION_KERNEL(Silu, SiluFunctor)
+DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, LeakyReluFunctor, alpha)
+DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
+                                     ThresholdedReluFunctor,
+                                     threshold)
+DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink, HardShrinkFunctor, threshold)
+DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(SoftShrink, SoftShrinkFunctor, lambda)
+DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Elu, ELUFunctor, alpha)
+DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(BRelu, BReluFunctor, t_min, t_max)

}  // namespace phi
PD_REGISTER_KERNEL(relu, CPU, ALL_LAYOUT, phi::ReluKernel, float, double) {}

#define PD_REGISTER_ACTIVATION_KERNEL(name, func) \
-  PD_REGISTER_KERNEL(name, CPU, ALL_LAYOUT, phi::func##Kernel, float, double) {}
+  PD_REGISTER_KERNEL(name, CPU, ALL_LAYOUT, phi::func, float, double) {}
-PD_REGISTER_ACTIVATION_KERNEL(sin, Sin)
-PD_REGISTER_ACTIVATION_KERNEL(cos, Cos)
-PD_REGISTER_ACTIVATION_KERNEL(tan, Tan)
-PD_REGISTER_ACTIVATION_KERNEL(acos, Acos)
-PD_REGISTER_ACTIVATION_KERNEL(asin, Asin)
-PD_REGISTER_ACTIVATION_KERNEL(atan, Atan)
-PD_REGISTER_ACTIVATION_KERNEL(sinh, Sinh)
-PD_REGISTER_ACTIVATION_KERNEL(cosh, Cosh)
-PD_REGISTER_ACTIVATION_KERNEL(asinh, Asinh)
-PD_REGISTER_ACTIVATION_KERNEL(acosh, Acosh)
-PD_REGISTER_ACTIVATION_KERNEL(atanh, Atanh)
-PD_REGISTER_ACTIVATION_KERNEL(tanh, Tanh)
-PD_REGISTER_ACTIVATION_KERNEL(brelu, BRelu)
-PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyRelu)
-PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedRelu)
+PD_REGISTER_ACTIVATION_KERNEL(sin, SinKernel)
+PD_REGISTER_ACTIVATION_KERNEL(cos, CosKernel)
+PD_REGISTER_ACTIVATION_KERNEL(tan, TanKernel)
+PD_REGISTER_ACTIVATION_KERNEL(acos, AcosKernel)
+PD_REGISTER_ACTIVATION_KERNEL(asin, AsinKernel)
+PD_REGISTER_ACTIVATION_KERNEL(atan, AtanKernel)
+PD_REGISTER_ACTIVATION_KERNEL(sinh, SinhKernel)
+PD_REGISTER_ACTIVATION_KERNEL(cosh, CoshKernel)
+PD_REGISTER_ACTIVATION_KERNEL(asinh, AsinhKernel)
+PD_REGISTER_ACTIVATION_KERNEL(acosh, AcoshKernel)
+PD_REGISTER_ACTIVATION_KERNEL(atanh, AtanhKernel)
+PD_REGISTER_ACTIVATION_KERNEL(tanh, TanhKernel)
+PD_REGISTER_ACTIVATION_KERNEL(brelu, BReluKernel)
+PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
+PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedReluKernel)
+PD_REGISTER_ACTIVATION_KERNEL(hard_shrink, HardShrinkKernel)
+PD_REGISTER_ACTIVATION_KERNEL(soft_shrink, SoftShrinkKernel)
+PD_REGISTER_ACTIVATION_KERNEL(tanh_shrink, TanhShrinkKernel)
+PD_REGISTER_ACTIVATION_KERNEL(elu, EluKernel)
+PD_REGISTER_ACTIVATION_KERNEL(silu, SiluKernel)
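Every call site switches from the short name (Sin) to the full kernel symbol (SinKernel) because the macro no longer appends ##Kernel itself. Expanded by hand for one entry:

    // PD_REGISTER_ACTIVATION_KERNEL(silu, SiluKernel) now expands to:
    PD_REGISTER_KERNEL(silu, CPU, ALL_LAYOUT, phi::SiluKernel, float, double) {}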
paddle/phi/kernels/funcs/activation_functor.h

@@ -29,11 +29,13 @@
#include <type_traits>
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/extensions.h"

#ifdef PADDLE_WITH_XPU_KP
#define __forceinline__ __inline__
@@ -780,6 +782,236 @@ struct ThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// tanhshrink(x) = x - tanh(x)
// where tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
template
<
typename
T
>
struct
TanhShrinkFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
const
{
out
.
device
(
d
)
=
x
-
x
.
tanh
();
}
};
template
<
typename
T
>
struct
TanhShrinkGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
,
typename
X
,
typename
Out
,
typename
dOut
,
typename
dX
>
void
operator
()(
Device
d
,
X
x
,
Out
out
,
dOut
dout
,
dX
dx
)
const
{
dx
.
device
(
d
)
=
dout
*
(
x
.
tanh
()
*
x
.
tanh
());
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
};
// tanhshrink(x) = x - tanh(x)
// where tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
template
<
typename
T
>
struct
HardShrinkFunctor
:
public
BaseActivationFunctor
<
T
>
{
float
threshold
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"threshold"
,
&
threshold
}};
}
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
const
{
auto
temp1
=
x
<
static_cast
<
T
>
(
threshold
*
-
1.
f
);
auto
temp2
=
x
>
static_cast
<
T
>
(
threshold
);
out
.
device
(
d
)
=
x
*
(
temp1
||
temp2
).
template
cast
<
T
>();
}
};
template
<
typename
T
>
struct
HardShrinkGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
float
threshold
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"threshold"
,
&
threshold
}};
}
template
<
typename
Device
,
typename
X
,
typename
Out
,
typename
dOut
,
typename
dX
>
void
operator
()(
Device
d
,
X
x
,
Out
out
,
dOut
dout
,
dX
dx
)
const
{
auto
temp1
=
x
<
static_cast
<
T
>
(
threshold
*
-
1.
f
);
auto
temp2
=
x
>
static_cast
<
T
>
(
threshold
);
dx
.
device
(
d
)
=
dout
*
(
temp1
||
temp2
).
template
cast
<
T
>();
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
};
// softshrink(x) = x - lambda, if x > lambda; x + lambda, if x < -lambda; 0
// otherwise
template
<
typename
T
>
struct
SoftShrinkFunctor
:
public
BaseActivationFunctor
<
T
>
{
float
lambda
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"lambda"
,
&
lambda
}};
}
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
const
{
auto
lambdaT
=
static_cast
<
T
>
(
lambda
);
auto
temp1
=
(
x
>
lambdaT
).
template
cast
<
T
>();
auto
temp2
=
(
x
<
-
lambdaT
).
template
cast
<
T
>();
out
.
device
(
d
)
=
temp1
*
(
x
-
lambdaT
)
+
temp2
*
(
x
+
lambdaT
);
}
};
template <typename T>
struct SoftShrinkGradFunctor : public BaseActivationFunctor<T> {
  float lambda;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto lambdaT = static_cast<T>(lambda);
    auto temp1 = (x > lambdaT).template cast<T>();
    auto temp2 = (x < -lambdaT).template cast<T>();
    dx.device(d) = dout * (temp1 + temp2).template cast<T>();
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
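For a concrete feel for the two shrink operators defined above, the following standalone scalar sketch (not part of the diff; the sample inputs and the 0.5 threshold are chosen arbitrarily here) evaluates them on a few values:

#include <cstdio>

// Scalar reference versions of the Eigen functors above (illustrative only).
float softshrink(float x, float lambda) {
  if (x > lambda) return x - lambda;
  if (x < -lambda) return x + lambda;
  return 0.f;
}
float hardshrink(float x, float threshold) {
  return (x > threshold || x < -threshold) ? x : 0.f;
}

int main() {
  // With lambda = threshold = 0.5:
  // softshrink: 1.2 -> 0.7, -0.3 -> 0.0, -2.0 -> -1.5
  // hardshrink: 1.2 -> 1.2, -0.3 -> 0.0, -2.0 -> -2.0
  std::printf("%f %f %f\n", softshrink(1.2f, 0.5f), softshrink(-0.3f, 0.5f),
              softshrink(-2.0f, 0.5f));
  std::printf("%f %f %f\n", hardshrink(1.2f, 0.5f), hardshrink(-0.3f, 0.5f),
              hardshrink(-2.0f, 0.5f));
  return 0;
}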
template <typename T>
struct ELUFunctor : public BaseActivationFunctor<T> {
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) =
        (x < static_cast<T>(0))
            .select(static_cast<T>(alpha) * (x.exp() - static_cast<T>(1)), x);
  }
};
template <typename T>
struct ELUGradFunctor : public BaseActivationFunctor<T> {
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    // case 1: alpha >= 0
    // dx = dout, if out > 0
    // dx = dout * (out + alpha), if out <= 0
    dx.device(d) = (out > static_cast<T>(0))
                       .select(dout, dout * (out + static_cast<T>(alpha)));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct ELUGradNegativeAlphaFunctor : public BaseActivationFunctor<T> {
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    // case 2: alpha < 0
    // dx = dout, if x > 0
    // dx = dout * (out + alpha), if x <= 0
    dx.device(d) = (x > static_cast<T>(0))
                       .select(dout, dout * static_cast<T>(alpha) * x.exp());
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct ELUGradGradFunctor : public BaseActivationFunctor<T> {
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Device>
  void operator()(const Device& dev,
                  const DenseTensor* X,
                  const DenseTensor* ddX,
                  DenseTensor* ddOut,
                  const DenseTensor* dOut,
                  DenseTensor* dX) const {
    auto* d = dev.eigen_device();
    auto ddx = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddX, "Input", "DDX", "ELUGradGrad"));
    auto x = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(X, "Input", "X", "ELUGradGrad"));

    if (dX) {
      auto dx = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dX, "Output", "DX", "ELUGradGrad"));
      auto dout = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dOut, "Output", "DOut", "ELUGradGrad"));
      dx.device(*d) = ddx * dout * static_cast<T>(alpha) * x.exp() *
                      (x <= static_cast<T>(0)).template cast<T>();
    }

    if (ddOut) {
      auto ddout = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "ELUGradGrad"));
      ddout.device(*d) = ddx *
                         ((x > static_cast<T>(0)).template cast<T>() +
                          static_cast<T>(alpha) * x.exp() *
                              (x <= static_cast<T>(0)).template cast<T>())
                             .template cast<T>();
    }
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// silu(x) = x / (1 + exp(-x))
template <typename T>
struct SiluFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto temp = static_cast<T>(1) / (static_cast<T>(1) + (-x).exp());
    out.device(d) = x * temp;
  }
};
// silu'(x) = (1 / (1 + e^{-x})) * (1 + out * e^{-x})
template <typename T>
struct SiluGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto temp1 = static_cast<T>(1) + (-x).exp();  // 1 + e^(-x)
    auto temp2 = x * (-x).exp();                  // x * e^(-x)
    dx.device(d) = dout * ((static_cast<T>(1) / temp1) *
                           (static_cast<T>(1) + (temp2 / temp1)));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
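The gradient expression used by SiluGradFunctor can be sanity-checked against a finite difference; this is a standalone sketch (not part of the diff), with the probe point and step size picked arbitrarily:

#include <cmath>
#include <cstdio>

// silu(x) = x / (1 + exp(-x)); its derivative should match
// (1 / (1 + e^-x)) * (1 + x * e^-x / (1 + e^-x)).
double silu(double x) { return x / (1.0 + std::exp(-x)); }

double silu_grad(double x) {
  double temp1 = 1.0 + std::exp(-x);  // 1 + e^(-x)
  double temp2 = x * std::exp(-x);    // x * e^(-x)
  return (1.0 / temp1) * (1.0 + temp2 / temp1);
}

int main() {
  const double x = 0.7, h = 1e-6;
  double numeric = (silu(x + h) - silu(x - h)) / (2 * h);
  std::printf("analytic=%f numeric=%f\n", silu_grad(x), numeric);
  return 0;
}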
#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
template <typename T>
struct CudaReluFunctor : public BaseActivationFunctor<T> {
...
@@ -1218,6 +1450,209 @@ struct CudaLeakyReluGradFunctor : public BaseActivationFunctor<T> {
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaSoftShrinkFunctor : public BaseActivationFunctor<T> {
  float lambda;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  // softshrink(x) = x - lambda, if x > lambda;
  //                 x + lambda, if x < -lambda;
  //                 0, otherwise.
  __device__ __forceinline__ T operator()(const T x) const {
    T l = static_cast<T>(lambda);
    T temp1 = static_cast<T>(x > l);
    T temp2 = static_cast<T>(x < -l);
    return temp1 * (x - l) + temp2 * (x + l);
  }
};
template <typename T>
struct CudaSoftShrinkGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float lambda;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  // dx = dout, if x > lambda or x < -lambda else 0
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    T l = static_cast<T>(lambda);
    return (x >= -l && x <= l) ? zero : dout;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaTanhShrinkFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // tanhshrink(x) = x - tanh(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(x - tanh(x));
  }
};
template <typename T>
struct CudaTanhShrinkGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // dx = dout * tanh(x)^2
  __device__ __forceinline__ T operator()(const T arg_dout,
                                           const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(dout * tanh(x) * tanh(x));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaHardShrinkFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // hardshrink(x) = (x > -threshold && x < threshold) ? 0 : x
  __device__ __forceinline__ T operator()(const T x) const {
    T t = static_cast<T>(threshold);
    return (x > -t && x < t) ? zero : x;
  }
};
template <typename T>
struct CudaHardShrinkGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // dx = (x > -threshold && x < threshold) ? 0 : dout
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    T t = static_cast<T>(threshold);
    return (x > -t && x < t) ? zero : dout;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaELUFunctor : public BaseActivationFunctor<T> {
  using CT = typename phi::dtype::MPTypeTrait<T>::Type;
  CT zero = static_cast<CT>(0.0f);
  CT one = static_cast<CT>(1.0f);
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // elu(x) = x, if x > 0
  // elu(x) = alpha * (e^x - 1), if x <= 0
  __device__ __forceinline__ T operator()(const T arg_x) const {
    CT x = static_cast<CT>(arg_x);
    CT temp = static_cast<CT>(alpha) * (exp(x) - one);
    CT res = x > zero ? x : temp;
    return static_cast<T>(res);
  }
};
template <typename T>
struct CudaELUGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // case 1: alpha >= 0
  // dx = dout, if out > 0
  // dx = dout * (out + alpha), if out <= 0
  __device__ __forceinline__ T operator()(T arg_dout, T arg_out) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType out = static_cast<MPType>(arg_out);
    MPType a = static_cast<MPType>(alpha);
    MPType out_pos = static_cast<MPType>(out > zero);
    MPType out_neg = static_cast<MPType>(out <= zero);
    return static_cast<T>(dout * (out_pos + out_neg * (out + a)));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};
template <typename T>
struct CudaELUGradNegativeAlphaFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // case 2: alpha < 0
  // dx = dout, if x > 0
  // dx = dout * (out + alpha), if x <= 0
  __device__ __forceinline__ T operator()(const T arg_dout,
                                           const T arg_out,
                                           const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType out = static_cast<MPType>(arg_out);
    MPType x = static_cast<MPType>(arg_x);
    MPType a = static_cast<MPType>(alpha);
    MPType x_pos = static_cast<MPType>(x > zero);
    MPType x_neg = static_cast<MPType>(x <= zero);
    return static_cast<T>(dout * (x_pos + x_neg * (out + a)));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaSiluFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // silu(x) = x / (1 + exp(-x))
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(x / (one + exp(-x)));
  }
};
template <typename T>
struct CudaSiluGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // dx = dout * (1 + exp(-x) + x * exp(-x)) / (1 + exp(-x))^2
  __device__ __forceinline__ T operator()(const T arg_dout,
                                           const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    MPType temp = one / (one + exp(-x));
    return static_cast<T>(dout * (temp * (one + x * (one - temp))));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
#endif

}  // namespace funcs
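A recurring pattern in the Cuda* functors above is the cast to MPTypeTrait<T>::Type, so that low-precision storage types (such as float16) are computed in a wider type and only the result is cast back. The sketch below illustrates the same idea generically; the ComputeType trait and the float-to-double mapping are illustrative stand-ins, not Paddle's implementation:

#include <cmath>
#include <cstdio>

// Map the storage type T to a (possibly wider) compute type, do the math
// there, cast back. Purely illustrative: here float is widened to double;
// in the diff, phi::dtype::MPTypeTrait maps float16 to float.
template <typename T> struct ComputeType { using Type = T; };
template <> struct ComputeType<float> { using Type = double; };

template <typename T>
T tanhshrink(T arg_x) {
  using MPType = typename ComputeType<T>::Type;
  MPType x = static_cast<MPType>(arg_x);
  return static_cast<T>(x - std::tanh(x));  // tanhshrink(x) = x - tanh(x)
}

int main() {
  std::printf("%f\n", tanhshrink(1.5f));  // ~0.59
  return 0;
}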
...

paddle/phi/kernels/gpu/activation_grad_kernel.cu (View file @ 57f54d3b)
@@ -73,7 +73,7 @@ void ActivationGradGPUImpl(const Context& dev_ctx,
   }
 }

-#define DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(name, functor_class) \
+#define DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(name, functor_class) \
   template <typename T, typename Context>                           \
   void name##GradKernel(const Context& dev_ctx,                     \
                         const DenseTensor& x,                       \
@@ -84,7 +84,7 @@ void ActivationGradGPUImpl(const Context& dev_ctx,
         dev_ctx, &x, nullptr, &dout, dx, functor);                  \
   }

-#define DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX( \
+#define DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX( \
     name, functor_class, attr)                          \
   template <typename T, typename Context>               \
   void name##GradKernel(const Context& dev_ctx,         \
@@ -99,7 +99,7 @@ void ActivationGradGPUImpl(const Context& dev_ctx,
         dev_ctx, &x, nullptr, &dout, dx, functor);      \
   }

-#define DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DepX( \
+#define DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX( \
     name, functor_class, attr1, attr2)                  \
   template <typename T, typename Context>               \
   void name##GradKernel(const Context& dev_ctx,         \
@@ -116,7 +116,7 @@ void ActivationGradGPUImpl(const Context& dev_ctx,
         dev_ctx, &x, nullptr, &dout, dx, functor);      \
   }

-#define DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepOut(name, functor_class) \
+#define DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(name, functor_class) \
   template <typename T, typename Context>                             \
   void name##GradKernel(const Context& dev_ctx,                       \
                         const DenseTensor& out,                       \
@@ -127,7 +127,7 @@ void ActivationGradGPUImpl(const Context& dev_ctx,
         dev_ctx, nullptr, &out, &dout, dx, functor);      \
   }

-#define DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepOut( \
+#define DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPOUT( \
     name, functor_class, attr)                            \
   template <typename T, typename Context>                 \
   void name##GradKernel(const Context& dev_ctx,           \
@@ -142,32 +142,62 @@ void ActivationGradGPUImpl(const Context& dev_ctx,
         dev_ctx, nullptr, &out, &dout, dx, functor);      \
   }
-DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepOut(Relu, CudaReluGradFunctor);
-DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepOut(Tanh, CudaTanhGradFunctor);
-DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Cos, CudaCosGradFunctor);
-DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Tan, CudaTanGradFunctor);
-DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Acos, CudaAcosGradFunctor);
-DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Sin, CudaSinGradFunctor);
-DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Asin, CudaAsinGradFunctor);
-DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Atan, CudaAtanGradFunctor);
-DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Sinh, CudaSinhGradFunctor);
-DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Cosh, CudaCoshGradFunctor);
-DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Asinh, CudaAsinhGradFunctor);
-DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Acosh, CudaAcoshGradFunctor);
-DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Atanh, CudaAtanhGradFunctor);
-DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX(LeakyRelu,
-                                               CudaLeakyReluGradFunctor,
-                                               alpha);
-DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX(ThresholdedRelu,
-                                               CudaThresholdedReluGradFunctor,
-                                               threshold);
-DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DepX(BRelu,
-                                               CudaBReluGradFunctor,
-                                               t_min,
-                                               t_max);
+DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu, CudaReluGradFunctor);
+DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh, CudaTanhGradFunctor);
+DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Cos, CudaCosGradFunctor);
+DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Tan, CudaTanGradFunctor);
+DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Acos, CudaAcosGradFunctor);
+DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Sin, CudaSinGradFunctor);
+DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Asin, CudaAsinGradFunctor);
+DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Atan, CudaAtanGradFunctor);
+DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Sinh, CudaSinhGradFunctor);
+DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Cosh, CudaCoshGradFunctor);
+DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Asinh, CudaAsinhGradFunctor);
+DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Acosh, CudaAcoshGradFunctor);
+DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Atanh, CudaAtanhGradFunctor);
+DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(TanhShrink, CudaTanhShrinkGradFunctor);
+DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Silu, CudaSiluGradFunctor);
+DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(LeakyRelu,
+                                               CudaLeakyReluGradFunctor,
+                                               alpha);
+DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(ThresholdedRelu,
+                                               CudaThresholdedReluGradFunctor,
+                                               threshold);
+DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(SoftShrink,
+                                               CudaSoftShrinkGradFunctor,
+                                               lambda);
+DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(HardShrink,
+                                               CudaHardShrinkGradFunctor,
+                                               threshold);
+DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(BRelu,
+                                               CudaBReluGradFunctor,
+                                               t_min,
+                                               t_max);
+
+template <typename T, typename Context>
+void EluGradKernel(const Context& dev_ctx,
+                   const DenseTensor& x,
+                   const DenseTensor& out,
+                   const DenseTensor& dout,
+                   float alpha,
+                   DenseTensor* dx) {
+  dev_ctx.template Alloc<T>(dx);
+  std::vector<const DenseTensor*> ins = {&dout, &out};
+  std::vector<DenseTensor*> outs = {dx};
+  if (alpha > 0) {
+    funcs::CudaELUGradFunctor<T> functor;
+    functor.alpha = alpha;
+    funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, functor);
+  } else {
+    funcs::CudaELUGradNegativeAlphaFunctor<T> functor;
+    functor.alpha = alpha;
+    ins.push_back(&x);
+    funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, functor);
+  }
+}

 }  // namespace phi
 #ifdef PADDLE_WITH_HIP
...
@@ -234,3 +264,9 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_double_grad,
                                    LeakyReluDoubleGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(thresholded_relu_grad,
                                    ThresholdedReluGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(soft_shrink_grad, SoftShrinkGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_shrink_grad, HardShrinkGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_shrink_grad, TanhShrinkGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(silu_grad, SiluGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_grad, EluGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_double_grad, EluDoubleGradKernel)
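One detail worth noting in EluGradKernel above: for alpha >= 0 the sign of out matches the sign of x, so the gradient can be computed from out and dout alone (CudaELUGradFunctor, whose FwdDeps() is kDepOut); for alpha < 0 the output can be positive even when x is negative, so x is appended to the inputs and CudaELUGradNegativeAlphaFunctor is used instead. A quick standalone check of that sign flip (illustrative values only, not part of the diff):

#include <cmath>
#include <cstdio>

// elu(x) = x for x > 0, alpha * (exp(x) - 1) otherwise.
double elu(double x, double alpha) {
  return x > 0 ? x : alpha * (std::exp(x) - 1.0);
}

int main() {
  // With alpha = -0.5 and x = -1, out is about +0.316: out > 0 while x < 0,
  // so a gradient rule keyed only on the sign of out would pick the wrong branch.
  std::printf("%f\n", elu(-1.0, -0.5));
  return 0;
}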
paddle/phi/kernels/gpu/activation_kernel.cu (View file @ 57f54d3b)

@@ -42,8 +42,9 @@ void ActivationGPUImpl(const Context& dev_ctx,
   template <typename T, typename Context>                                   \
   void name##Kernel(                                                        \
       const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) {     \
-    functor_class functor;                                                  \
-    ActivationGPUImpl<T, Context, functor_class>(dev_ctx, x, out, functor); \
+    funcs::functor_class<T> functor;                                        \
+    ActivationGPUImpl<T, Context, funcs::functor_class<T>>(                 \
+        dev_ctx, x, out, functor);                                          \
   }

 #define DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(name, functor_class, attr) \
...
@@ -75,24 +76,31 @@ void ActivationGPUImpl(const Context& dev_ctx,
         dev_ctx, x, out, functor);                                          \
   }
-DEFINE_GPU_ACTIVATION_KERNEL(Cos, funcs::CudaCosFunctor<T>)
-DEFINE_GPU_ACTIVATION_KERNEL(Tan, funcs::CudaTanFunctor<T>)
-DEFINE_GPU_ACTIVATION_KERNEL(Acos, funcs::CudaAcosFunctor<T>)
-DEFINE_GPU_ACTIVATION_KERNEL(Sin, funcs::CudaSinFunctor<T>)
-DEFINE_GPU_ACTIVATION_KERNEL(Asin, funcs::CudaAsinFunctor<T>)
-DEFINE_GPU_ACTIVATION_KERNEL(Atan, funcs::CudaAtanFunctor<T>)
-DEFINE_GPU_ACTIVATION_KERNEL(Sinh, funcs::CudaSinhFunctor<T>)
-DEFINE_GPU_ACTIVATION_KERNEL(Cosh, funcs::CudaCoshFunctor<T>)
-DEFINE_GPU_ACTIVATION_KERNEL(Asinh, funcs::CudaAsinhFunctor<T>)
-DEFINE_GPU_ACTIVATION_KERNEL(Acosh, funcs::CudaAcoshFunctor<T>)
-DEFINE_GPU_ACTIVATION_KERNEL(Atanh, funcs::CudaAtanhFunctor<T>)
-DEFINE_GPU_ACTIVATION_KERNEL(Relu, funcs::CudaReluFunctor<T>)
-DEFINE_GPU_ACTIVATION_KERNEL(Tanh, funcs::CudaTanhFunctor<T>)
+DEFINE_GPU_ACTIVATION_KERNEL(Cos, CudaCosFunctor)
+DEFINE_GPU_ACTIVATION_KERNEL(Tan, CudaTanFunctor)
+DEFINE_GPU_ACTIVATION_KERNEL(Acos, CudaAcosFunctor)
+DEFINE_GPU_ACTIVATION_KERNEL(Sin, CudaSinFunctor)
+DEFINE_GPU_ACTIVATION_KERNEL(Asin, CudaAsinFunctor)
+DEFINE_GPU_ACTIVATION_KERNEL(Atan, CudaAtanFunctor)
+DEFINE_GPU_ACTIVATION_KERNEL(Sinh, CudaSinhFunctor)
+DEFINE_GPU_ACTIVATION_KERNEL(Cosh, CudaCoshFunctor)
+DEFINE_GPU_ACTIVATION_KERNEL(Asinh, CudaAsinhFunctor)
+DEFINE_GPU_ACTIVATION_KERNEL(Acosh, CudaAcoshFunctor)
+DEFINE_GPU_ACTIVATION_KERNEL(Atanh, CudaAtanhFunctor)
+DEFINE_GPU_ACTIVATION_KERNEL(Relu, CudaReluFunctor)
+DEFINE_GPU_ACTIVATION_KERNEL(Tanh, CudaTanhFunctor)
+DEFINE_GPU_ACTIVATION_KERNEL(TanhShrink, CudaTanhShrinkFunctor)
+DEFINE_GPU_ACTIVATION_KERNEL(Silu, CudaSiluFunctor)
 DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, CudaLeakyReluFunctor, alpha)
 DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
                                      CudaThresholdedReluFunctor,
                                      threshold)
+DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink, CudaHardShrinkFunctor, threshold)
+DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(SoftShrink, CudaSoftShrinkFunctor, lambda)
+DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Elu, CudaELUFunctor, alpha)
 DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(BRelu, CudaBReluFunctor, t_min, t_max)
...
@@ -142,3 +150,8 @@ PD_REGISTER_ACTIVATION_KERNEL(tanh, TanhKernel)
 PD_REGISTER_ACTIVATION_KERNEL(brelu, BReluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedReluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
+PD_REGISTER_ACTIVATION_KERNEL(hard_shrink, HardShrinkKernel)
+PD_REGISTER_ACTIVATION_KERNEL(soft_shrink, SoftShrinkKernel)
+PD_REGISTER_ACTIVATION_KERNEL(tanh_shrink, TanhShrinkKernel)
+PD_REGISTER_ACTIVATION_KERNEL(elu, EluKernel)
+PD_REGISTER_ACTIVATION_KERNEL(silu, SiluKernel)
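Given the updated DEFINE_GPU_ACTIVATION_KERNEL macro at the top of this file, an invocation such as DEFINE_GPU_ACTIVATION_KERNEL(TanhShrink, CudaTanhShrinkFunctor) expands to roughly the following (a hand-expanded sketch for illustration; exact formatting and instantiation details may differ):

// Sketch of the macro expansion for the TanhShrink kernel.
template <typename T, typename Context>
void TanhShrinkKernel(
    const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) {
  funcs::CudaTanhShrinkFunctor<T> functor;
  ActivationGPUImpl<T, Context, funcs::CudaTanhShrinkFunctor<T>>(
      dev_ctx, x, out, functor);
}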
paddle/phi/kernels/impl/activation_grad_impl.h (View file @ 57f54d3b)

@@ -202,4 +202,24 @@ void TanhTripleGradKernel(const Context& dev_ctx,
                  d_ddx);  // output
 }
template <typename T, typename Context>
void EluDoubleGradKernel(const Context& dev_ctx,
                         const DenseTensor& x,
                         const DenseTensor& dout,
                         const DenseTensor& ddx,
                         float alpha,
                         DenseTensor* dx,
                         DenseTensor* ddout) {
  if (dx) {
    dx->Resize(x.dims());
    dev_ctx.template Alloc<T>(dx);
  }
  if (ddout) {
    dev_ctx.template Alloc<T>(ddout);
  }
  funcs::ELUGradGradFunctor<T> functor;
  functor.alpha = alpha;
  functor(dev_ctx, &x, &ddx, ddout, &dout, dx);
}

}  // namespace phi
paddle/phi/ops/compat/activation_sig.cc (View file @ 57f54d3b)

@@ -16,7 +16,7 @@ limitations under the License. */
 namespace phi {

-#define DefineActGradDepXOpArgMap(func_name, op_name, attrs)      \
+#define DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(func_name, op_name, attrs) \
   KernelSignature func_name##GradOpArgumentMapping(               \
       const ArgumentMappingContext& ctx) {                        \
     return KernelSignature(op_name "_grad",                       \
@@ -25,7 +25,7 @@ namespace phi {
                            {GradVarName("X")});                   \
   }

-#define DefineActGradDepOutOpArgMap(func_name, op_name, attrs)      \
+#define DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(func_name, op_name, attrs) \
   KernelSignature func_name##GradOpArgumentMapping(                 \
       const ArgumentMappingContext& ctx) {                          \
     return KernelSignature(op_name "_grad",                         \
@@ -36,25 +36,29 @@ namespace phi {
 #define comma ,

-DefineActGradDepXOpArgMap(Cos, "cos", );        // NOLINT
-DefineActGradDepXOpArgMap(Tan, "tan", );        // NOLINT
-DefineActGradDepXOpArgMap(Acos, "acos", );      // NOLINT
-DefineActGradDepXOpArgMap(Sin, "sin", );        // NOLINT
-DefineActGradDepXOpArgMap(Asin, "asin", );      // NOLINT
-DefineActGradDepXOpArgMap(Atan, "atan", );      // NOLINT
-DefineActGradDepXOpArgMap(Sinh, "sinh", );      // NOLINT
-DefineActGradDepXOpArgMap(Cosh, "cosh", );      // NOLINT
-DefineActGradDepXOpArgMap(Asinh, "asinh", );    // NOLINT
-DefineActGradDepXOpArgMap(Acosh, "acosh", );    // NOLINT
-DefineActGradDepXOpArgMap(Atanh, "atanh", );    // NOLINT
-DefineActGradDepXOpArgMap(BRelu, "brelu", "t_min" comma "t_max");  // NOLINT
-DefineActGradDepXOpArgMap(LeakyRelu, "leaky_relu", "alpha");       // NOLINT
-DefineActGradDepXOpArgMap(ThresholdedRelu,
-                          "thresholded_relu",
-                          "threshold");  // NOLINT
-DefineActGradDepOutOpArgMap(Relu, "relu", );  // NOLINT
-DefineActGradDepOutOpArgMap(Tanh, "tanh", );  // NOLINT
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Cos, "cos", );        // NOLINT
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Tan, "tan", );        // NOLINT
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Acos, "acos", );      // NOLINT
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Sin, "sin", );        // NOLINT
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Asin, "asin", );      // NOLINT
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Atan, "atan", );      // NOLINT
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Sinh, "sinh", );      // NOLINT
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Cosh, "cosh", );      // NOLINT
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Asinh, "asinh", );    // NOLINT
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Acosh, "acosh", );    // NOLINT
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Atanh, "atanh", );    // NOLINT
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(BRelu, "brelu", "t_min" comma "t_max");
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(LeakyRelu, "leaky_relu", "alpha");
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(ThresholdedRelu,
+                               "thresholded_relu",
+                               "threshold");
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(SoftShrink, "soft_shrink", "lambda");
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(HardShrink, "hard_shrink", "threshold");
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(TanhShrink, "tanh_shrink", );  // NOLINT
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Silu, "silu", );               // NOLINT
+DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Relu, "relu", );  // NOLINT
+DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Tanh, "tanh", );  // NOLINT

 KernelSignature ReluDoubleGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
...
@@ -85,11 +89,31 @@ KernelSignature LeakyReluOpArgumentMapping(const ArgumentMappingContext& ctx) {
   return KernelSignature("leaky_relu", {"X"}, {"alpha"}, {"Out"});
 }

+KernelSignature EluOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature("elu", {"X"}, {"alpha"}, {"Out"});
+}
+
+KernelSignature EluGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature("elu_grad",
+                         {"X", "Out", GradVarName("Out")},
+                         {"alpha"},
+                         {GradVarName("X")});
+}
+
+KernelSignature EluDoubleGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature(
+      "elu_double_grad", {"X", "DOut", "DDX"}, {"alpha"}, {"DX", "DDOut"});
+}
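These mappings tie the legacy fluid operator inputs and attributes to the new phi kernels registered earlier in this commit: for example, the "elu_grad" signature's inputs {X, Out, Out@GRAD} line up with the (x, out, dout) parameters of EluGradKernel, the "alpha" attribute with its float alpha argument, and the output X@GRAD with dx.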
 }  // namespace phi

 PD_REGISTER_BASE_KERNEL_NAME(relu_grad_grad, relu_double_grad);
 PD_REGISTER_BASE_KERNEL_NAME(tanh_grad_grad, tanh_double_grad);
 PD_REGISTER_BASE_KERNEL_NAME(leaky_relu_grad_grad, leaky_relu_double_grad);
+PD_REGISTER_BASE_KERNEL_NAME(softshrink, soft_shrink);
+PD_REGISTER_BASE_KERNEL_NAME(softshrink_grad, soft_shrink_grad);
+PD_REGISTER_BASE_KERNEL_NAME(elu_grad_grad, elu_double_grad);
 PD_REGISTER_ARG_MAPPING_FN(cos_grad, phi::CosGradOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(tan_grad, phi::TanGradOpArgumentMapping);
...
@@ -118,3 +142,13 @@ PD_REGISTER_ARG_MAPPING_FN(leaky_relu_grad_grad,
                            phi::LeakyReluDoubleGradOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(thresholded_relu_grad,
                            phi::ThresholdedReluGradOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(softshrink_grad, phi::SoftShrinkGradOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(hard_shrink_grad, phi::HardShrinkGradOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(tanh_shrink_grad, phi::TanhShrinkGradOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(elu, phi::EluOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(elu_grad, phi::EluGradOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(elu_grad_grad, phi::EluDoubleGradOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(silu_grad, phi::SiluGradOpArgumentMapping);