Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
c552d1ac
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c552d1ac
编写于
3月 16, 2022
作者:
P
phlrain
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add forward case
上级
4e23ac69
变更
15
显示空白变更内容
内联
并排
Showing
15 changed file
with
740 addition
and
407 deletion
+740
-407
paddle/fluid/operators/activation_op.cc
paddle/fluid/operators/activation_op.cc
+0
-28
paddle/fluid/operators/activation_op.h
paddle/fluid/operators/activation_op.h
+1
-168
paddle/fluid/operators/activation_op.kps
paddle/fluid/operators/activation_op.kps
+0
-146
paddle/fluid/operators/math/selected_rows_functor.cc
paddle/fluid/operators/math/selected_rows_functor.cc
+137
-40
paddle/fluid/operators/math/selected_rows_functor.cu
paddle/fluid/operators/math/selected_rows_functor.cu
+178
-18
paddle/phi/kernels/CMakeLists.txt
paddle/phi/kernels/CMakeLists.txt
+1
-1
paddle/phi/kernels/activation_grad_kernel.h
paddle/phi/kernels/activation_grad_kernel.h
+1
-0
paddle/phi/kernels/activation_kernel.h
paddle/phi/kernels/activation_kernel.h
+14
-0
paddle/phi/kernels/cpu/activation_grad_kernel.cc
paddle/phi/kernels/cpu/activation_grad_kernel.cc
+10
-0
paddle/phi/kernels/cpu/activation_kernel.cc
paddle/phi/kernels/cpu/activation_kernel.cc
+37
-0
paddle/phi/kernels/funcs/activation_functor.h
paddle/phi/kernels/funcs/activation_functor.h
+294
-0
paddle/phi/kernels/gpu/activation_grad_kernel.cu
paddle/phi/kernels/gpu/activation_grad_kernel.cu
+10
-0
paddle/phi/kernels/gpu/activation_kernel.cu
paddle/phi/kernels/gpu/activation_kernel.cu
+35
-0
paddle/phi/kernels/gpu/clip_by_norm_kernel.cu
paddle/phi/kernels/gpu/clip_by_norm_kernel.cu
+6
-6
paddle/phi/kernels/impl/activation_impl.h
paddle/phi/kernels/impl/activation_impl.h
+16
-0
未找到文件。
paddle/fluid/operators/activation_op.cc
浏览文件 @
c552d1ac
...
@@ -1650,9 +1650,6 @@ REGISTER_OPERATOR(logit, ops::LogitOp, ops::LogitOpMaker,
...
@@ -1650,9 +1650,6 @@ REGISTER_OPERATOR(logit, ops::LogitOp, ops::LogitOpMaker,
ops
::
LogitGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
LogitGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
LogitGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
ops
::
LogitGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
REGISTER_OPERATOR
(
logit_grad
,
ops
::
LogitGradOp
);
REGISTER_OPERATOR
(
logit_grad
,
ops
::
LogitGradOp
);
REGISTER_OP_CPU_KERNEL
(
logit
,
ops
::
LogitKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
LogitKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
REGISTER_OP_CPU_KERNEL
(
logit_grad
,
ops
::
LogitGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
logit_grad
,
ops
::
LogitGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
LogitGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
ops
::
LogitGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
...
@@ -1830,24 +1827,6 @@ REGISTER_OPERATOR(
...
@@ -1830,24 +1827,6 @@ REGISTER_OPERATOR(
REGISTER_OPERATOR
(
exp_grad
,
ops
::
ActivationOpGrad
,
REGISTER_OPERATOR
(
exp_grad
,
ops
::
ActivationOpGrad
,
ops
::
ActivationGradOpInplaceInferer
);
ops
::
ActivationGradOpInplaceInferer
);
REGISTER_OP_CPU_KERNEL
(
exp
,
ops
::
ActivationKernel
<
paddle
::
platform
::
CPUDeviceContext
,
ops
::
ExpFunctor
<
float
>>
,
ops
::
ActivationKernel
<
paddle
::
platform
::
CPUDeviceContext
,
ops
::
ExpFunctor
<
double
>>
,
ops
::
ActivationKernel
<
paddle
::
platform
::
CPUDeviceContext
,
ops
::
ExpFunctor
<
int
>>
,
ops
::
ActivationKernel
<
paddle
::
platform
::
CPUDeviceContext
,
ops
::
ExpFunctor
<
int64_t
>>
);
REGISTER_OP_CPU_KERNEL
(
exp_grad
,
ops
::
ActivationGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
ops
::
ExpGradFunctor
<
float
>>
,
ops
::
ActivationGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
ops
::
ExpGradFunctor
<
double
>>
,
ops
::
ActivationGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
ops
::
ExpGradFunctor
<
int
>>
,
ops
::
ActivationGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
ops
::
ExpGradFunctor
<
int64_t
>>
);
/* ========================================================================== */
/* ========================================================================== */
/* ========================== expm1 register ============================ */
/* ========================== expm1 register ============================ */
...
@@ -1862,13 +1841,6 @@ REGISTER_OPERATOR(
...
@@ -1862,13 +1841,6 @@ REGISTER_OPERATOR(
REGISTER_OPERATOR
(
expm1_grad
,
ops
::
ActivationOpGrad
,
REGISTER_OPERATOR
(
expm1_grad
,
ops
::
ActivationOpGrad
,
ops
::
ActivationGradOpInplaceInferer
);
ops
::
ActivationGradOpInplaceInferer
);
REGISTER_OP_CPU_KERNEL
(
expm1
,
ops
::
ActivationKernel
<
paddle
::
platform
::
CPUDeviceContext
,
ops
::
Expm1Functor
<
float
>>
,
ops
::
ActivationKernel
<
paddle
::
platform
::
CPUDeviceContext
,
ops
::
Expm1Functor
<
double
>>
,
ops
::
ActivationKernel
<
paddle
::
platform
::
CPUDeviceContext
,
ops
::
Expm1Functor
<
plat
::
float16
>>
);
REGISTER_OP_CPU_KERNEL
(
REGISTER_OP_CPU_KERNEL
(
expm1_grad
,
ops
::
ActivationGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
expm1_grad
,
ops
::
ActivationGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
ops
::
Expm1GradFunctor
<
float
>>
,
ops
::
Expm1GradFunctor
<
float
>>
,
...
...
paddle/fluid/operators/activation_op.h
浏览文件 @
c552d1ac
...
@@ -273,6 +273,7 @@ USE_PHI_FUNCTOR(Asinh)
...
@@ -273,6 +273,7 @@ USE_PHI_FUNCTOR(Asinh)
USE_PHI_FUNCTOR
(
Acosh
)
USE_PHI_FUNCTOR
(
Acosh
)
USE_PHI_FUNCTOR
(
Atanh
)
USE_PHI_FUNCTOR
(
Atanh
)
USE_PHI_FUNCTOR
(
Tanh
)
USE_PHI_FUNCTOR
(
Tanh
)
USE_PHI_FUNCTOR
(
Exp
)
USE_PHI_DOUBLE_GRAD_FUNCTOR
(
Tanh
)
USE_PHI_DOUBLE_GRAD_FUNCTOR
(
Tanh
)
USE_PHI_TRIPLE_GRAD_FUNCTOR
(
Tanh
)
USE_PHI_TRIPLE_GRAD_FUNCTOR
(
Tanh
)
USE_PHI_FUNCTOR
(
BRelu
)
USE_PHI_FUNCTOR
(
BRelu
)
...
@@ -455,37 +456,6 @@ struct LogSigmoidGradFunctor : public BaseActivationFunctor<T> {
...
@@ -455,37 +456,6 @@ struct LogSigmoidGradFunctor : public BaseActivationFunctor<T> {
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
};
};
// exp(x) = e^x
template
<
typename
T
>
struct
ExpFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
const
{
out
.
device
(
d
)
=
x
.
exp
();
}
};
template
<
typename
T
>
struct
ExpGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
,
typename
X
,
typename
Out
,
typename
dOut
,
typename
dX
>
void
operator
()(
Device
d
,
X
x
,
Out
out
,
dOut
dout
,
dX
dx
)
const
{
dx
.
device
(
d
)
=
dout
*
out
;
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepOut
;
}
};
// expm1(x) = e^x - 1
template
<
typename
T
>
struct
Expm1Functor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
const
{
out
.
device
(
d
)
=
x
.
expm1
();
}
};
template
<
typename
T
>
template
<
typename
T
>
struct
Expm1GradFunctor
:
public
BaseActivationFunctor
<
T
>
{
struct
Expm1GradFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
,
typename
X
,
typename
Out
,
typename
dOut
,
template
<
typename
Device
,
typename
X
,
typename
Out
,
typename
dOut
,
...
@@ -605,15 +575,6 @@ struct SoftShrinkGradFunctor : public BaseActivationFunctor<T> {
...
@@ -605,15 +575,6 @@ struct SoftShrinkGradFunctor : public BaseActivationFunctor<T> {
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
};
};
// sqrt(x) = x^(1/2)
template
<
typename
T
>
struct
SqrtFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
const
{
out
.
device
(
d
)
=
x
.
sqrt
();
}
};
template
<
typename
T
>
template
<
typename
T
>
struct
SqrtGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
struct
SqrtGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
,
typename
X
,
typename
Out
,
typename
dOut
,
template
<
typename
Device
,
typename
X
,
typename
Out
,
typename
dOut
,
...
@@ -627,15 +588,6 @@ struct SqrtGradFunctor : public BaseActivationFunctor<T> {
...
@@ -627,15 +588,6 @@ struct SqrtGradFunctor : public BaseActivationFunctor<T> {
}
}
};
};
// rsqrt(x) = x^(-1/2)
template
<
typename
T
>
struct
RsqrtFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
const
{
out
.
device
(
d
)
=
x
.
rsqrt
();
}
};
template
<
typename
T
>
template
<
typename
T
>
struct
RsqrtGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
struct
RsqrtGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
,
typename
X
,
typename
Out
,
typename
dOut
,
template
<
typename
Device
,
typename
X
,
typename
Out
,
typename
dOut
,
...
@@ -689,15 +641,6 @@ struct RoundFunctor : public BaseActivationFunctor<T> {
...
@@ -689,15 +641,6 @@ struct RoundFunctor : public BaseActivationFunctor<T> {
}
}
};
};
// reciprocal(x) = 1 / x
template
<
typename
T
>
struct
ReciprocalFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
const
{
out
.
device
(
d
)
=
static_cast
<
T
>
(
1
)
/
x
;
}
};
template
<
typename
T
>
template
<
typename
T
>
struct
ReciprocalGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
struct
ReciprocalGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
,
typename
X
,
typename
Out
,
typename
dOut
,
template
<
typename
Device
,
typename
X
,
typename
Out
,
typename
dOut
,
...
@@ -793,15 +736,6 @@ struct Log1pGradFunctor : public BaseActivationFunctor<T> {
...
@@ -793,15 +736,6 @@ struct Log1pGradFunctor : public BaseActivationFunctor<T> {
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
};
};
// square(x) = x^2
template
<
typename
T
>
struct
SquareFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
const
{
out
.
device
(
d
)
=
x
.
square
();
}
};
template
<
typename
T
>
template
<
typename
T
>
struct
SquareGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
struct
SquareGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
,
typename
X
,
typename
Out
,
typename
dOut
,
template
<
typename
Device
,
typename
X
,
typename
Out
,
typename
dOut
,
...
@@ -894,27 +828,6 @@ struct HardSwishGradFunctor : public BaseActivationFunctor<T> {
...
@@ -894,27 +828,6 @@ struct HardSwishGradFunctor : public BaseActivationFunctor<T> {
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
};
};
// For numerical stability, using the following formula instead of softplus(x) =
// log(1 + exp(x))
// softplus(x) = log(1 + exp(beta * x)) / beta when beta * x <= threshold(beta =
// 1, threshold = 20 by default), otherwise x
template
<
typename
T
>
struct
SoftplusFunctor
:
public
BaseActivationFunctor
<
T
>
{
float
beta
;
float
threshold
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"beta"
,
&
beta
},
{
"threshold"
,
&
threshold
}};
}
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
{
auto
x_beta
=
static_cast
<
T
>
(
beta
)
*
x
;
out
.
device
(
d
)
=
(
x_beta
>
static_cast
<
T
>
(
threshold
))
.
select
(
x
,
(
static_cast
<
T
>
(
1
)
+
x_beta
.
exp
()).
log
()
/
static_cast
<
T
>
(
beta
));
}
};
// For numerical stability, using the following formula instead of
// For numerical stability, using the following formula instead of
// d(softplus(x))/dx = 1 / (1 + exp(-x))
// d(softplus(x))/dx = 1 / (1 + exp(-x))
// d(softplus(x))/dx = 1 / (1 + exp(-beta * x)) when beta * x <= threshold(beta
// d(softplus(x))/dx = 1 / (1 + exp(-beta * x)) when beta * x <= threshold(beta
...
@@ -939,24 +852,6 @@ struct SoftplusGradFunctor : public BaseActivationFunctor<T> {
...
@@ -939,24 +852,6 @@ struct SoftplusGradFunctor : public BaseActivationFunctor<T> {
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
};
};
// mish(x) = x * tanh(softplus(x))
// softplus(x) = x, if x > threshold
// = ln(1 + exp(x)), otherwise
template
<
typename
T
>
struct
MishFunctor
:
public
BaseActivationFunctor
<
T
>
{
float
threshold
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"threshold"
,
&
threshold
}};
}
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
{
auto
sp
=
(
x
>
static_cast
<
T
>
(
threshold
))
.
select
(
x
,
(
static_cast
<
T
>
(
1
)
+
x
.
exp
()).
log
());
out
.
device
(
d
)
=
x
*
sp
.
tanh
();
}
};
// dx = dout * (tanh(sp) + x * (1 - tanh(sp) ** 2) * (1 - exp(-sp)))
// dx = dout * (tanh(sp) + x * (1 - tanh(sp) ** 2) * (1 - exp(-sp)))
// sp = softplus(x)
// sp = softplus(x)
template
<
typename
T
>
template
<
typename
T
>
...
@@ -979,15 +874,6 @@ struct MishGradFunctor : public BaseActivationFunctor<T> {
...
@@ -979,15 +874,6 @@ struct MishGradFunctor : public BaseActivationFunctor<T> {
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
};
};
// softsign(x) = x / (1 + |x|)
template
<
typename
T
>
struct
SoftsignFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
{
out
.
device
(
d
)
=
x
/
(
static_cast
<
T
>
(
1
)
+
x
.
abs
());
}
};
// d(softsign(x))/dx = 1 / (1 + |x|)^2
// d(softsign(x))/dx = 1 / (1 + |x|)^2
// Taken from https://en.wikipedia.org/wiki/Activation_function
// Taken from https://en.wikipedia.org/wiki/Activation_function
template
<
typename
T
>
template
<
typename
T
>
...
@@ -1198,24 +1084,6 @@ struct PowGradFunctor : public BaseActivationFunctor<T> {
...
@@ -1198,24 +1084,6 @@ struct PowGradFunctor : public BaseActivationFunctor<T> {
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
};
};
template
<
typename
T
>
struct
LogitFunctor
{
template
<
typename
Device
,
typename
X
,
typename
Out
,
typename
P
>
void
operator
()(
Device
d
,
X
x
,
Out
out
,
P
p
,
float
eps
)
const
{
// logit(x) = ln(x/(1-x))
auto
tmp_x
=
(
x
.
cwiseMin
(
static_cast
<
T
>
(
1.0
-
eps
))).
cwiseMax
(
static_cast
<
T
>
(
eps
));
if
(
!
eps
)
{
out
.
device
(
d
)
=
(
x
<
static_cast
<
T
>
(
0.0
)
||
x
>
static_cast
<
T
>
(
1.0
))
.
select
(
p
.
constant
(
static_cast
<
T
>
(
NAN
)),
(
tmp_x
/
(
static_cast
<
T
>
(
1
)
-
tmp_x
)).
log
());
}
else
{
out
.
device
(
d
)
=
(
tmp_x
/
(
static_cast
<
T
>
(
1
)
-
tmp_x
)).
log
();
}
}
};
template
<
typename
T
>
template
<
typename
T
>
struct
LogitGradFunctor
{
struct
LogitGradFunctor
{
template
<
typename
Device
,
typename
X
,
typename
dOut
,
typename
dX
,
typename
P
>
template
<
typename
Device
,
typename
X
,
typename
dOut
,
typename
dX
,
typename
P
>
...
@@ -1228,21 +1096,6 @@ struct LogitGradFunctor {
...
@@ -1228,21 +1096,6 @@ struct LogitGradFunctor {
}
}
};
};
template
<
typename
T
>
struct
STanhFunctor
:
public
BaseActivationFunctor
<
T
>
{
float
scale_a
;
float
scale_b
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"scale_a"
,
&
scale_a
},
{
"scale_b"
,
&
scale_b
}};
}
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
const
{
out
.
device
(
d
)
=
static_cast
<
T
>
(
scale_b
)
*
(
static_cast
<
T
>
(
scale_a
)
*
x
).
tanh
();
}
};
template
<
typename
T
>
template
<
typename
T
>
struct
STanhGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
struct
STanhGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
float
scale_a
;
float
scale_a
;
...
@@ -2075,26 +1928,6 @@ class PowGradKernel
...
@@ -2075,26 +1928,6 @@ class PowGradKernel
}
}
};
};
template
<
typename
DeviceContext
,
typename
T
>
class
LogitKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
*
in
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
eps
=
context
.
Attr
<
float
>
(
"eps"
);
out
->
mutable_data
<
T
>
(
in
->
place
());
auto
eigen_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
out
);
auto
eigen_in
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
in
);
auto
&
place
=
*
context
.
template
device_context
<
DeviceContext
>().
eigen_device
();
auto
eigen_p
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
out
);
LogitFunctor
<
T
>
functor
;
functor
(
place
,
eigen_in
,
eigen_out
,
eigen_p
,
eps
);
}
};
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
LogitGradKernel
:
public
framework
::
OpKernel
<
T
>
{
class
LogitGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
...
...
paddle/fluid/operators/activation_op.kps
浏览文件 @
c552d1ac
...
@@ -192,14 +192,6 @@ struct CudaZeroGradFunctor : public BaseActivationFunctor<T> {
...
@@ -192,14 +192,6 @@ struct CudaZeroGradFunctor : public BaseActivationFunctor<T> {
}
}
};
};
template <typename T>
struct CudaReciprocalFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f);
// reciprocal(x) = 1 / x
__device__ __forceinline__ T operator()(const T x) const { return one / x; }
};
template <typename T>
template <typename T>
struct CudaReciprocalGradFunctor : public BaseActivationFunctor<T> {
struct CudaReciprocalGradFunctor : public BaseActivationFunctor<T> {
// dx = -dout * out^2
// dx = -dout * out^2
...
@@ -212,40 +204,6 @@ struct CudaReciprocalGradFunctor : public BaseActivationFunctor<T> {
...
@@ -212,40 +204,6 @@ struct CudaReciprocalGradFunctor : public BaseActivationFunctor<T> {
}
}
};
};
template <typename T>
struct CudaExpFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// exp(x) = exp(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(exp(x));
}
};
template <typename T>
struct CudaExpGradFunctor : public BaseActivationFunctor<T> {
// dx = dout * out
__device__ __forceinline__ T operator()(const T dout, const T out) const {
return dout * out;
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
template <typename T>
struct CudaExpm1Functor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// expm1(x) = expm1(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(expm1(x));
}
};
template <typename T>
template <typename T>
struct CudaExpm1GradFunctor : public BaseActivationFunctor<T> {
struct CudaExpm1GradFunctor : public BaseActivationFunctor<T> {
// dx = dout * out
// dx = dout * out
...
@@ -279,12 +237,6 @@ struct CudaLogGradFunctor : public BaseActivationFunctor<T> {
...
@@ -279,12 +237,6 @@ struct CudaLogGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
};
template <typename T>
struct CudaSquareFunctor : public BaseActivationFunctor<T> {
// square(x) = x * x
__device__ __forceinline__ T operator()(const T x) const { return x * x; }
};
template <typename T>
template <typename T>
struct CudaSquareGradFunctor : public BaseActivationFunctor<T> {
struct CudaSquareGradFunctor : public BaseActivationFunctor<T> {
T two = static_cast<T>(2.0f);
T two = static_cast<T>(2.0f);
...
@@ -297,17 +249,6 @@ struct CudaSquareGradFunctor : public BaseActivationFunctor<T> {
...
@@ -297,17 +249,6 @@ struct CudaSquareGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
};
template <typename T>
struct CudaSqrtFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// sqrt(x) = sqrt(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(sqrt(x));
}
};
template <typename T>
template <typename T>
struct CudaSqrtGradFunctor : public BaseActivationFunctor<T> {
struct CudaSqrtGradFunctor : public BaseActivationFunctor<T> {
T one_half = static_cast<T>(0.5f);
T one_half = static_cast<T>(0.5f);
...
@@ -322,17 +263,6 @@ struct CudaSqrtGradFunctor : public BaseActivationFunctor<T> {
...
@@ -322,17 +263,6 @@ struct CudaSqrtGradFunctor : public BaseActivationFunctor<T> {
}
}
};
};
template <typename T>
struct CudaRsqrtFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// rsqrt(x) = rsqrt(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(rsqrt(x));
}
};
template <typename T>
template <typename T>
struct CudaRsqrtGradFunctor : public BaseActivationFunctor<T> {
struct CudaRsqrtGradFunctor : public BaseActivationFunctor<T> {
T minus_one_half = static_cast<T>(-0.5f);
T minus_one_half = static_cast<T>(-0.5f);
...
@@ -466,25 +396,6 @@ struct CudaSoftReluGradFunctor : public BaseActivationFunctor<T> {
...
@@ -466,25 +396,6 @@ struct CudaSoftReluGradFunctor : public BaseActivationFunctor<T> {
}
}
};
};
template <typename T>
struct CudaSTanhFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
float scale_a;
float scale_b;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
}
// stanh(x) = b * tanh(a * x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
MPType a = static_cast<MPType>(scale_a);
MPType b = static_cast<MPType>(scale_b);
return static_cast<T>(b * tanh(a * x));
}
};
template <typename T>
template <typename T>
struct CudaSTanhGradFunctor : public BaseActivationFunctor<T> {
struct CudaSTanhGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
using MPType = typename details::MPTypeTrait<T>::Type;
...
@@ -510,27 +421,6 @@ struct CudaSTanhGradFunctor : public BaseActivationFunctor<T> {
...
@@ -510,27 +421,6 @@ struct CudaSTanhGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
};
template <typename T>
struct CudaSoftplusFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
float beta;
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"beta", &beta}, {"threshold", &threshold}};
}
// softplus(x) = beta * x > threshold ? x : log(1 + exp(beta * x)) / beta
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
MPType b = static_cast<MPType>(beta);
MPType t = static_cast<MPType>(threshold);
MPType x_beta = x * beta;
return static_cast<T>(x_beta > t ? x : log(one + exp(x_beta)) / b);
}
};
template <typename T>
template <typename T>
struct CudaSoftplusGradFunctor : public BaseActivationFunctor<T> {
struct CudaSoftplusGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
using MPType = typename details::MPTypeTrait<T>::Type;
...
@@ -556,16 +446,6 @@ struct CudaSoftplusGradFunctor : public BaseActivationFunctor<T> {
...
@@ -556,16 +446,6 @@ struct CudaSoftplusGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
};
template <typename T>
struct CudaSoftsignFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f);
// softsign(x) = x / (1 + abs(x))
__device__ __forceinline__ T operator()(const T x) const {
return x / (one + abs(x));
}
};
template <typename T>
template <typename T>
struct CudaSoftsignGradFunctor : public BaseActivationFunctor<T> {
struct CudaSoftsignGradFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f);
T one = static_cast<T>(1.0f);
...
@@ -762,27 +642,6 @@ struct CudaSwishGradFunctor : public BaseActivationFunctor<T> {
...
@@ -762,27 +642,6 @@ struct CudaSwishGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
};
template <typename T>
struct CudaMishFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}
// mish(x) = x * tanh(softplus(x))
// softplus(x) = x, if x > threshold
// = ln(1 + exp(x)), otherwise
// Inputs: args[0], the input x
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
MPType sp = (x > static_cast<MPType>(threshold)) ? x : log(one + exp(x));
return static_cast<T>(x * tanh(sp));
}
};
template <typename T>
template <typename T>
struct CudaMishGradFunctor : public BaseActivationFunctor<T> {
struct CudaMishGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
using MPType = typename details::MPTypeTrait<T>::Type;
...
@@ -1292,11 +1151,6 @@ REGISTER_OP_CUDA_KERNEL(
...
@@ -1292,11 +1151,6 @@ REGISTER_OP_CUDA_KERNEL(
/* ========================== logit register ============================ */
/* ========================== logit register ============================ */
namespace ops = paddle::operators;
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
logit, ops::LogitKernel<paddle::platform::CUDADeviceContext, float>,
ops::LogitKernel<paddle::platform::CUDADeviceContext, double>,
ops::LogitKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
REGISTER_OP_CUDA_KERNEL(
REGISTER_OP_CUDA_KERNEL(
logit_grad,
logit_grad,
ops::LogitGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::LogitGradKernel<paddle::platform::CUDADeviceContext, float>,
...
...
paddle/fluid/operators/math/selected_rows_functor.cc
浏览文件 @
c552d1ac
...
@@ -279,6 +279,46 @@ struct SelectedRowsAddToTensor<platform::CPUDeviceContext, T> {
...
@@ -279,6 +279,46 @@ struct SelectedRowsAddToTensor<platform::CPUDeviceContext, T> {
}
}
};
};
template
<
typename
T
>
struct
SelectedRowsAddToTensor
<
phi
::
CPUContext
,
T
>
{
void
operator
()(
const
phi
::
CPUContext
&
context
,
const
phi
::
SelectedRows
&
input1
,
framework
::
Tensor
*
input2
)
{
if
(
UNLIKELY
(
input1
.
rows
().
size
()
==
0
))
{
LOG
(
WARNING
)
<<
"input selected rows is empty!"
;
return
;
}
auto
in1_height
=
input1
.
height
();
auto
in2_dims
=
input2
->
dims
();
PADDLE_ENFORCE_EQ
(
in1_height
,
in2_dims
[
0
],
platform
::
errors
::
InvalidArgument
(
"The two inputs height must be equal."
"But recieved first input height = "
"[%d], second input height = [%d]"
,
in1_height
,
in2_dims
[
0
]));
auto
&
in1_value
=
input1
.
value
();
auto
&
in1_rows
=
input1
.
rows
();
int64_t
in1_row_numel
=
in1_value
.
numel
()
/
in1_rows
.
size
();
PADDLE_ENFORCE_EQ
(
in1_row_numel
,
input2
->
numel
()
/
in1_height
,
platform
::
errors
::
InvalidArgument
(
"The two inputs width must be equal."
"But recieved first input width = [%d], second input width = [%d]"
,
in1_row_numel
,
input2
->
numel
()
/
in1_height
));
auto
*
in1_data
=
in1_value
.
data
<
T
>
();
auto
*
input2_data
=
input2
->
data
<
T
>
();
for
(
size_t
i
=
0
;
i
<
in1_rows
.
size
();
i
++
)
{
for
(
int64_t
j
=
0
;
j
<
in1_row_numel
;
j
++
)
{
input2_data
[
in1_rows
[
i
]
*
in1_row_numel
+
j
]
+=
in1_data
[
i
*
in1_row_numel
+
j
];
}
}
}
};
template
struct
SelectedRowsAddToTensor
<
platform
::
CPUDeviceContext
,
float
>;
template
struct
SelectedRowsAddToTensor
<
platform
::
CPUDeviceContext
,
float
>;
template
struct
SelectedRowsAddToTensor
<
platform
::
CPUDeviceContext
,
double
>;
template
struct
SelectedRowsAddToTensor
<
platform
::
CPUDeviceContext
,
double
>;
template
struct
SelectedRowsAddToTensor
<
platform
::
CPUDeviceContext
,
int
>;
template
struct
SelectedRowsAddToTensor
<
platform
::
CPUDeviceContext
,
int
>;
...
@@ -286,6 +326,11 @@ template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, int64_t>;
...
@@ -286,6 +326,11 @@ template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, int64_t>;
template
struct
SelectedRowsAddToTensor
<
platform
::
CPUDeviceContext
,
template
struct
SelectedRowsAddToTensor
<
platform
::
CPUDeviceContext
,
platform
::
bfloat16
>;
platform
::
bfloat16
>;
template
struct
SelectedRowsAddToTensor
<
phi
::
CPUContext
,
float
>;
template
struct
SelectedRowsAddToTensor
<
phi
::
CPUContext
,
double
>;
template
struct
SelectedRowsAddToTensor
<
phi
::
CPUContext
,
int
>;
template
struct
SelectedRowsAddToTensor
<
phi
::
CPUContext
,
int64_t
>;
template
struct
SelectedRowsAddToTensor
<
phi
::
CPUContext
,
platform
::
bfloat16
>;
// This is a separated namespace for manipulate SelectedRows typed
// This is a separated namespace for manipulate SelectedRows typed
// data. Like merge duplicated rows, adding two SelectedRows etc.
// data. Like merge duplicated rows, adding two SelectedRows etc.
//
//
...
@@ -294,30 +339,30 @@ template struct SelectedRowsAddToTensor<platform::CPUDeviceContext,
...
@@ -294,30 +339,30 @@ template struct SelectedRowsAddToTensor<platform::CPUDeviceContext,
// add or mul.
// add or mul.
namespace
scatter
{
namespace
scatter
{
template
<
typename
T
>
template
<
typename
T
,
typename
DeviceContext
>
typename
std
::
enable_if
<!
std
::
is_integral
<
T
>::
value
>::
type
elementwise_add_to
(
typename
std
::
enable_if
<!
std
::
is_integral
<
T
>::
value
>::
type
elementwise_add_to
(
phi
::
funcs
::
BlasT
<
platform
::
CPUDeviceContext
,
T
>*
blas
,
size_t
data_le
n
,
phi
::
funcs
::
BlasT
<
DeviceContext
,
T
>*
blas
,
size_t
data_len
,
const
T
*
i
n
,
const
T
*
in
,
T
*
out
)
{
T
*
out
)
{
blas
->
AXPY
(
data_len
,
T
(
1.
f
),
in
,
out
);
blas
->
AXPY
(
data_len
,
T
(
1.
f
),
in
,
out
);
}
}
template
<
typename
T
>
template
<
typename
T
,
typename
DeviceContext
>
typename
std
::
enable_if
<
std
::
is_integral
<
T
>::
value
>::
type
elementwise_add_to
(
typename
std
::
enable_if
<
std
::
is_integral
<
T
>::
value
>::
type
elementwise_add_to
(
phi
::
funcs
::
BlasT
<
platform
::
CPUDeviceContext
,
T
>*
blas
,
size_t
data_le
n
,
phi
::
funcs
::
BlasT
<
DeviceContext
,
T
>*
blas
,
size_t
data_len
,
const
T
*
i
n
,
const
T
*
in
,
T
*
out
)
{
T
*
out
)
{
for
(
size_t
i
=
0
;
i
<
data_len
;
i
++
)
{
for
(
size_t
i
=
0
;
i
<
data_len
;
i
++
)
{
out
[
i
]
+=
in
[
i
];
out
[
i
]
+=
in
[
i
];
}
}
}
}
template
<
typename
T
>
template
<
typename
T
,
typename
DeviceContext
>
typename
std
::
enable_if
<
std
::
is_same
<
T
,
platform
::
bfloat16
>::
value
>::
type
typename
std
::
enable_if
<
std
::
is_same
<
T
,
platform
::
bfloat16
>::
value
>::
type
add_sparse_inputs
(
const
std
::
vector
<
const
phi
::
SelectedRows
*>&
inputs
,
add_sparse_inputs
(
const
std
::
vector
<
const
phi
::
SelectedRows
*>&
inputs
,
const
std
::
unordered_map
<
int64_t
,
size_t
>&
rows_to_id
,
const
std
::
unordered_map
<
int64_t
,
size_t
>&
rows_to_id
,
int64_t
input_width
,
int64_t
input_width
,
const
DeviceContext
&
context
,
const
platform
::
CPUDeviceContext
&
context
,
T
*
out_data
)
{
T
*
out_data
)
{
#ifndef PADDLE_WITH_MKLDNN
#ifndef PADDLE_WITH_MKLDNN
auto
blas
=
phi
::
funcs
::
GetBlas
<
platform
::
CPU
DeviceContext
,
T
>
(
context
);
auto
blas
=
phi
::
funcs
::
GetBlas
<
DeviceContext
,
T
>
(
context
);
#endif
#endif
for
(
auto
*
input
:
inputs
)
{
for
(
auto
*
input
:
inputs
)
{
if
(
input
->
rows
().
size
()
==
0
)
{
if
(
input
->
rows
().
size
()
==
0
)
{
...
@@ -336,22 +381,22 @@ add_sparse_inputs(const std::vector<const phi::SelectedRows*>& inputs,
...
@@ -336,22 +381,22 @@ add_sparse_inputs(const std::vector<const phi::SelectedRows*>& inputs,
#else
#else
for
(
size_t
i
=
0
;
i
<
input_rows
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
input_rows
.
size
();
i
++
)
{
size_t
out_i
=
rows_to_id
.
at
(
input_rows
[
i
]);
size_t
out_i
=
rows_to_id
.
at
(
input_rows
[
i
]);
elementwise_add_to
<
T
>
(
&
blas
,
static_cast
<
size_t
>
(
input_width
),
elementwise_add_to
<
T
,
DeviceContext
>
(
&
input_data
[
i
*
input_width
],
&
blas
,
static_cast
<
size_t
>
(
input_width
),
&
input_data
[
i
*
input_width
],
&
out_data
[
out_i
*
input_width
]);
&
out_data
[
out_i
*
input_width
]);
}
}
#endif
#endif
}
}
}
}
template
<
typename
T
>
template
<
typename
T
,
typename
DeviceContext
>
typename
std
::
enable_if
<!
std
::
is_same
<
T
,
platform
::
bfloat16
>::
value
>::
type
typename
std
::
enable_if
<!
std
::
is_same
<
T
,
platform
::
bfloat16
>::
value
>::
type
add_sparse_inputs
(
const
std
::
vector
<
const
phi
::
SelectedRows
*>&
inputs
,
add_sparse_inputs
(
const
std
::
vector
<
const
phi
::
SelectedRows
*>&
inputs
,
const
std
::
unordered_map
<
int64_t
,
size_t
>&
rows_to_id
,
const
std
::
unordered_map
<
int64_t
,
size_t
>&
rows_to_id
,
int64_t
input_width
,
int64_t
input_width
,
const
DeviceContext
&
context
,
const
platform
::
CPUDeviceContext
&
context
,
T
*
out_data
)
{
T
*
out_data
)
{
VLOG
(
4
)
<<
"[CPU] add_sparse_inputs <"
<<
typeid
(
T
).
name
();
VLOG
(
4
)
<<
"[CPU] add_sparse_inputs <"
<<
typeid
(
T
).
name
();
auto
blas
=
phi
::
funcs
::
GetBlas
<
platform
::
CPU
DeviceContext
,
T
>
(
context
);
auto
blas
=
phi
::
funcs
::
GetBlas
<
DeviceContext
,
T
>
(
context
);
for
(
auto
*
input
:
inputs
)
{
for
(
auto
*
input
:
inputs
)
{
if
(
input
->
rows
().
size
()
==
0
)
{
if
(
input
->
rows
().
size
()
==
0
)
{
continue
;
continue
;
...
@@ -361,16 +406,16 @@ add_sparse_inputs(const std::vector<const phi::SelectedRows*>& inputs,
...
@@ -361,16 +406,16 @@ add_sparse_inputs(const std::vector<const phi::SelectedRows*>& inputs,
for
(
size_t
i
=
0
;
i
<
input_rows
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
input_rows
.
size
();
i
++
)
{
size_t
out_i
=
rows_to_id
.
at
(
input_rows
[
i
]);
size_t
out_i
=
rows_to_id
.
at
(
input_rows
[
i
]);
elementwise_add_to
<
T
>
(
&
blas
,
static_cast
<
size_t
>
(
input_width
),
elementwise_add_to
<
T
,
DeviceContext
>
(
&
input_data
[
i
*
input_width
],
&
blas
,
static_cast
<
size_t
>
(
input_width
),
&
input_data
[
i
*
input_width
],
&
out_data
[
out_i
*
input_width
]);
&
out_data
[
out_i
*
input_width
]);
}
}
}
}
}
}
template
<
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
struct
MergeAdd
<
platform
::
CPUDeviceContext
,
T
>
{
struct
MergeAdd
Impl
{
phi
::
SelectedRows
operator
()(
const
platform
::
CPU
DeviceContext
&
context
,
phi
::
SelectedRows
operator
()(
const
DeviceContext
&
context
,
const
phi
::
SelectedRows
&
input
,
const
phi
::
SelectedRows
&
input
,
const
bool
sorted_result
=
false
)
{
const
bool
sorted_result
=
false
)
{
phi
::
SelectedRows
out
;
phi
::
SelectedRows
out
;
...
@@ -378,15 +423,14 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
...
@@ -378,15 +423,14 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
return
out
;
return
out
;
}
}
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
void
operator
()(
const
DeviceContext
&
context
,
const
phi
::
SelectedRows
&
input
,
const
phi
::
SelectedRows
&
input
,
phi
::
SelectedRows
*
output
,
phi
::
SelectedRows
*
output
,
const
bool
sorted_result
=
false
)
{
const
bool
sorted_result
=
false
)
{
std
::
vector
<
const
phi
::
SelectedRows
*>
inputs
;
std
::
vector
<
const
phi
::
SelectedRows
*>
inputs
;
inputs
.
push_back
(
&
input
);
inputs
.
push_back
(
&
input
);
(
*
this
)(
context
,
inputs
,
output
,
sorted_result
);
(
*
this
)(
context
,
inputs
,
output
,
sorted_result
);
}
}
void
operator
()(
const
platform
::
CPU
DeviceContext
&
context
,
void
operator
()(
const
DeviceContext
&
context
,
const
std
::
vector
<
const
phi
::
SelectedRows
*>&
inputs
,
const
std
::
vector
<
const
phi
::
SelectedRows
*>&
inputs
,
phi
::
SelectedRows
*
output
,
const
bool
sorted_result
=
false
)
{
phi
::
SelectedRows
*
output
,
const
bool
sorted_result
=
false
)
{
if
(
inputs
.
size
()
==
0
)
{
if
(
inputs
.
size
()
==
0
)
{
...
@@ -461,7 +505,7 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
...
@@ -461,7 +505,7 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
out
.
set_rows
(
merge_rows
);
out
.
set_rows
(
merge_rows
);
phi
::
funcs
::
SetConstant
<
platform
::
CPU
DeviceContext
,
T
>
constant_functor
;
phi
::
funcs
::
SetConstant
<
DeviceContext
,
T
>
constant_functor
;
constant_functor
(
context
,
out
.
mutable_value
(),
static_cast
<
T
>
(
0.
f
));
constant_functor
(
context
,
out
.
mutable_value
(),
static_cast
<
T
>
(
0.
f
));
std
::
unordered_map
<
int64_t
,
size_t
>
rows_to_id
;
std
::
unordered_map
<
int64_t
,
size_t
>
rows_to_id
;
...
@@ -469,11 +513,75 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
...
@@ -469,11 +513,75 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
rows_to_id
[
merge_rows
[
i
]]
=
i
;
rows_to_id
[
merge_rows
[
i
]]
=
i
;
}
}
add_sparse_inputs
<
T
>
(
inputs
,
rows_to_id
,
input_width
,
context
,
out_data
);
add_sparse_inputs
<
T
,
DeviceContext
>
(
inputs
,
rows_to_id
,
input_width
,
context
,
out_data
);
}
}
}
}
};
};
template
<
typename
T
>
struct
MergeAdd
<
platform
::
CPUDeviceContext
,
T
>
{
// unary functor, merge by adding duplicated rows in
// the input SelectedRows object.
phi
::
SelectedRows
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
phi
::
SelectedRows
&
input
,
const
bool
sorted_result
)
{
return
MergeAddImpl
<
platform
::
CPUDeviceContext
,
T
>
()(
context
,
input
,
sorted_result
);
}
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
phi
::
SelectedRows
&
input
,
phi
::
SelectedRows
*
output
,
const
bool
sorted_result
)
{
MergeAddImpl
<
platform
::
CPUDeviceContext
,
T
>
()(
context
,
input
,
output
,
sorted_result
);
}
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
std
::
vector
<
const
phi
::
SelectedRows
*>&
inputs
,
phi
::
SelectedRows
*
output
,
const
bool
sorted_result
)
{
MergeAddImpl
<
platform
::
CPUDeviceContext
,
T
>
()(
context
,
inputs
,
output
,
sorted_result
);
}
};
template
<
typename
T
>
struct
MergeAdd
<
phi
::
CPUContext
,
T
>
{
// unary functor, merge by adding duplicated rows in
// the input SelectedRows object.
phi
::
SelectedRows
operator
()(
const
phi
::
CPUContext
&
context
,
const
phi
::
SelectedRows
&
input
,
const
bool
sorted_result
)
{
return
MergeAddImpl
<
phi
::
CPUContext
,
T
>
()(
context
,
input
,
sorted_result
);
}
void
operator
()(
const
phi
::
CPUContext
&
context
,
const
phi
::
SelectedRows
&
input
,
phi
::
SelectedRows
*
output
,
const
bool
sorted_result
)
{
MergeAddImpl
<
phi
::
CPUContext
,
T
>
()(
context
,
input
,
output
,
sorted_result
);
}
void
operator
()(
const
phi
::
CPUContext
&
context
,
const
std
::
vector
<
const
phi
::
SelectedRows
*>&
inputs
,
phi
::
SelectedRows
*
output
,
const
bool
sorted_result
)
{
MergeAddImpl
<
phi
::
CPUContext
,
T
>
()(
context
,
inputs
,
output
,
sorted_result
);
}
};
#define TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(dtype) \
template struct MergeAddImpl<platform::CPUDeviceContext, dtype>; \
template struct MergeAddImpl<phi::CPUContext, dtype>; \
template struct MergeAdd<platform::CPUDeviceContext, dtype>; \
template struct MergeAdd<phi::CPUContext, dtype>;
TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU
(
float
)
TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU
(
double
)
TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU
(
int
)
TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU
(
int64_t
)
TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU
(
platform
::
bfloat16
)
TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU
(
platform
::
complex
<
float
>
)
TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU
(
platform
::
complex
<
double
>
)
#ifdef PADDLE_WITH_XPU
#ifdef PADDLE_WITH_XPU
template
<
typename
T
>
template
<
typename
T
>
struct
MergeAdd
<
platform
::
XPUDeviceContext
,
T
>
{
struct
MergeAdd
<
platform
::
XPUDeviceContext
,
T
>
{
...
@@ -714,17 +822,6 @@ struct MergeAverage<platform::CPUDeviceContext, T> {
...
@@ -714,17 +822,6 @@ struct MergeAverage<platform::CPUDeviceContext, T> {
}
}
};
};
template
struct
MergeAdd
<
platform
::
CPUDeviceContext
,
int
>;
template
struct
MergeAdd
<
platform
::
CPUDeviceContext
,
int64_t
>;
template
struct
MergeAdd
<
platform
::
CPUDeviceContext
,
float
>;
template
struct
MergeAdd
<
platform
::
CPUDeviceContext
,
double
>;
template
struct
MergeAdd
<
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
<
float
>
>
;
template
struct
MergeAdd
<
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
<
double
>
>
;
template
struct
MergeAdd
<
platform
::
CPUDeviceContext
,
paddle
::
platform
::
bfloat16
>;
#ifdef PADDLE_WITH_XPU
#ifdef PADDLE_WITH_XPU
template
struct
MergeAdd
<
platform
::
XPUDeviceContext
,
float
>;
template
struct
MergeAdd
<
platform
::
XPUDeviceContext
,
float
>;
#endif
#endif
...
...
paddle/fluid/operators/math/selected_rows_functor.cu
浏览文件 @
c552d1ac
...
@@ -174,12 +174,77 @@ struct SelectedRowsAddTensor<platform::CUDADeviceContext, T> {
...
@@ -174,12 +174,77 @@ struct SelectedRowsAddTensor<platform::CUDADeviceContext, T> {
}
}
};
};
template
<
typename
T
>
struct
SelectedRowsAddTensor
<
phi
::
GPUContext
,
T
>
{
void
operator
()(
const
phi
::
GPUContext
&
context
,
const
phi
::
SelectedRows
&
input1
,
const
framework
::
Tensor
&
input2
,
framework
::
Tensor
*
output
)
{
auto
in1_height
=
input1
.
height
();
auto
in2_dims
=
input2
.
dims
();
auto
out_dims
=
output
->
dims
();
PADDLE_ENFORCE_EQ
(
in1_height
,
in2_dims
[
0
],
platform
::
errors
::
InvalidArgument
(
"The two inputs height must be equal."
"But recieved first input height = [%d], first input height = [%d]"
,
in1_height
,
in2_dims
[
0
]));
PADDLE_ENFORCE_EQ
(
in1_height
,
out_dims
[
0
],
platform
::
errors
::
InvalidArgument
(
"The input and output height must be equal."
"But recieved input height = [%d], output height = [%d]"
,
in1_height
,
out_dims
[
0
]));
auto
&
in1_value
=
input1
.
value
();
auto
&
in1_rows
=
input1
.
rows
();
int64_t
in1_row_numel
=
in1_value
.
numel
()
/
in1_rows
.
size
();
PADDLE_ENFORCE_EQ
(
in1_row_numel
,
input2
.
numel
()
/
in1_height
,
platform
::
errors
::
InvalidArgument
(
"The two inputs width must be equal."
"But recieved first input width = [%d], second input width = [%d]"
,
in1_row_numel
,
input2
.
numel
()
/
in1_height
));
PADDLE_ENFORCE_EQ
(
in1_row_numel
,
output
->
numel
()
/
in1_height
,
platform
::
errors
::
InvalidArgument
(
"The input and output width must be equal."
"But recieved input width = [%d], output width = [%d]"
,
in1_row_numel
,
output
->
numel
()
/
in1_height
));
auto
*
in1_data
=
in1_value
.
data
<
T
>
();
auto
*
in2_data
=
input2
.
data
<
T
>
();
auto
*
out_data
=
output
->
data
<
T
>
();
phi
::
funcs
::
SetConstant
<
phi
::
GPUContext
,
T
>
functor
;
functor
(
context
,
output
,
static_cast
<
T
>
(
0
));
const
int
block_size
=
256
;
dim3
threads
(
block_size
,
1
);
dim3
grid
(
in1_rows
.
size
(),
1
);
paddle
::
framework
::
MixVector
<
int64_t
>
mixv_in1_rows
(
&
in1_rows
);
SelectedRowsAddTensorKernel
<
T
,
block_size
><<<
grid
,
threads
,
0
,
context
.
stream
()
>>>
(
in1_data
,
mixv_in1_rows
.
CUDAData
(
context
.
GetPlace
()),
out_data
,
in1_row_numel
);
auto
out_eigen
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
output
);
auto
in2_eigen
=
framework
::
EigenVector
<
T
>::
Flatten
(
input2
);
out_eigen
.
device
(
*
context
.
eigen_device
())
=
out_eigen
+
in2_eigen
;
}
};
template
struct
SelectedRowsAddTensor
<
platform
::
CUDADeviceContext
,
float
>;
template
struct
SelectedRowsAddTensor
<
platform
::
CUDADeviceContext
,
float
>;
template
struct
SelectedRowsAddTensor
<
platform
::
CUDADeviceContext
,
double
>;
template
struct
SelectedRowsAddTensor
<
platform
::
CUDADeviceContext
,
double
>;
template
struct
SelectedRowsAdd
<
platform
::
CUDADeviceContext
,
platform
::
float16
>;
template
struct
SelectedRowsAdd
<
platform
::
CUDADeviceContext
,
platform
::
float16
>;
template
struct
SelectedRowsAddTensor
<
platform
::
CUDADeviceContext
,
template
struct
SelectedRowsAddTensor
<
platform
::
CUDADeviceContext
,
platform
::
float16
>;
platform
::
float16
>;
template
struct
SelectedRowsAddTensor
<
phi
::
GPUContext
,
float
>;
template
struct
SelectedRowsAddTensor
<
phi
::
GPUContext
,
double
>;
template
struct
SelectedRowsAdd
<
phi
::
GPUContext
,
platform
::
float16
>;
template
struct
SelectedRowsAddTensor
<
phi
::
GPUContext
,
platform
::
float16
>;
template
<
typename
T
>
template
<
typename
T
>
struct
SelectedRowsAddTo
<
platform
::
CUDADeviceContext
,
T
>
{
struct
SelectedRowsAddTo
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
...
@@ -285,12 +350,54 @@ struct SelectedRowsAddToTensor<platform::CUDADeviceContext, T> {
...
@@ -285,12 +350,54 @@ struct SelectedRowsAddToTensor<platform::CUDADeviceContext, T> {
}
}
};
};
template
<
typename
T
>
struct
SelectedRowsAddToTensor
<
phi
::
GPUContext
,
T
>
{
void
operator
()(
const
phi
::
GPUContext
&
context
,
const
phi
::
SelectedRows
&
input1
,
framework
::
Tensor
*
input2
)
{
auto
in1_height
=
input1
.
height
();
auto
in2_dims
=
input2
->
dims
();
PADDLE_ENFORCE_EQ
(
in1_height
,
in2_dims
[
0
],
platform
::
errors
::
InvalidArgument
(
"The two inputs height must be equal."
"But recieved first input height = "
"[%d], second input height = [%d]"
,
in1_height
,
in2_dims
[
0
]));
auto
&
in1_value
=
input1
.
value
();
auto
&
in1_rows
=
input1
.
rows
();
int64_t
in1_row_numel
=
in1_value
.
numel
()
/
in1_rows
.
size
();
PADDLE_ENFORCE_EQ
(
in1_row_numel
,
input2
->
numel
()
/
in1_height
,
platform
::
errors
::
InvalidArgument
(
"The two inputs width must be equal."
"But recieved first input width = [%d], second input width = [%d]"
,
in1_row_numel
,
input2
->
numel
()
/
in1_height
));
auto
*
in1_data
=
in1_value
.
data
<
T
>
();
auto
*
in2_data
=
input2
->
data
<
T
>
();
const
int
block_size
=
256
;
dim3
threads
(
block_size
,
1
);
dim3
grid
(
in1_rows
.
size
(),
1
);
paddle
::
framework
::
MixVector
<
int64_t
>
mixv_in1_rows
(
&
in1_rows
);
SelectedRowsAddToTensorKernel
<
T
,
block_size
><<<
grid
,
threads
,
0
,
context
.
stream
()
>>>
(
in1_data
,
mixv_in1_rows
.
CUDAData
(
context
.
GetPlace
()),
in2_data
,
in1_row_numel
);
}
};
template
struct
SelectedRowsAddToTensor
<
platform
::
CUDADeviceContext
,
float
>;
template
struct
SelectedRowsAddToTensor
<
platform
::
CUDADeviceContext
,
float
>;
template
struct
SelectedRowsAddToTensor
<
platform
::
CUDADeviceContext
,
double
>;
template
struct
SelectedRowsAddToTensor
<
platform
::
CUDADeviceContext
,
double
>;
template
struct
SelectedRowsAddToTensor
<
platform
::
CUDADeviceContext
,
int
>;
template
struct
SelectedRowsAddToTensor
<
platform
::
CUDADeviceContext
,
int
>;
template
struct
SelectedRowsAddToTensor
<
platform
::
CUDADeviceContext
,
int64_t
>;
template
struct
SelectedRowsAddToTensor
<
platform
::
CUDADeviceContext
,
int64_t
>;
template
struct
SelectedRowsAddToTensor
<
platform
::
CUDADeviceContext
,
template
struct
SelectedRowsAddToTensor
<
platform
::
CUDADeviceContext
,
platform
::
float16
>;
platform
::
float16
>;
template
struct
SelectedRowsAddToTensor
<
phi
::
GPUContext
,
float
>;
template
struct
SelectedRowsAddToTensor
<
phi
::
GPUContext
,
double
>;
template
struct
SelectedRowsAddToTensor
<
phi
::
GPUContext
,
int
>;
template
struct
SelectedRowsAddToTensor
<
phi
::
GPUContext
,
int64_t
>;
template
struct
SelectedRowsAddToTensor
<
phi
::
GPUContext
,
platform
::
float16
>;
namespace
scatter
{
namespace
scatter
{
...
@@ -319,9 +426,9 @@ __global__ void MergeAddKernel(const T* input, const int64_t* input_rows,
...
@@ -319,9 +426,9 @@ __global__ void MergeAddKernel(const T* input, const int64_t* input_rows,
}
}
}
}
template
<
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
struct
MergeAdd
<
platform
::
CUDADeviceContext
,
T
>
{
struct
MergeAdd
Impl
{
phi
::
SelectedRows
operator
()(
const
platform
::
CUDA
DeviceContext
&
context
,
phi
::
SelectedRows
operator
()(
const
DeviceContext
&
context
,
const
phi
::
SelectedRows
&
input
,
const
phi
::
SelectedRows
&
input
,
const
bool
sorted_result
=
false
)
{
const
bool
sorted_result
=
false
)
{
phi
::
SelectedRows
out
;
phi
::
SelectedRows
out
;
...
@@ -329,9 +436,8 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
...
@@ -329,9 +436,8 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
return
out
;
return
out
;
}
}
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
void
operator
()(
const
DeviceContext
&
context
,
const
phi
::
SelectedRows
&
input
,
const
phi
::
SelectedRows
&
input
,
phi
::
SelectedRows
*
output
,
phi
::
SelectedRows
*
output
,
const
bool
sorted_result
=
false
)
{
const
bool
sorted_result
=
false
)
{
framework
::
Vector
<
int64_t
>
input_rows
(
input
.
rows
());
framework
::
Vector
<
int64_t
>
input_rows
(
input
.
rows
());
if
(
input_rows
.
size
()
==
0
)
{
if
(
input_rows
.
size
()
==
0
)
{
return
;
return
;
...
@@ -350,7 +456,7 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
...
@@ -350,7 +456,7 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
phi
::
make_ddim
({
static_cast
<
int64_t
>
(
merge_rows
.
size
()),
input_width
}),
phi
::
make_ddim
({
static_cast
<
int64_t
>
(
merge_rows
.
size
()),
input_width
}),
context
.
GetPlace
());
context
.
GetPlace
());
phi
::
funcs
::
SetConstant
<
platform
::
CUDA
DeviceContext
,
T
>
constant_functor
;
phi
::
funcs
::
SetConstant
<
DeviceContext
,
T
>
constant_functor
;
constant_functor
(
context
,
out
.
mutable_value
(),
static_cast
<
T
>
(
0
));
constant_functor
(
context
,
out
.
mutable_value
(),
static_cast
<
T
>
(
0
));
auto
*
out_data
=
out
.
mutable_value
()
->
data
<
T
>
();
auto
*
out_data
=
out
.
mutable_value
()
->
data
<
T
>
();
...
@@ -369,7 +475,7 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
...
@@ -369,7 +475,7 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
mix_vector_out
.
CopyToCPU
();
mix_vector_out
.
CopyToCPU
();
}
}
void
operator
()(
const
platform
::
CUDA
DeviceContext
&
context
,
void
operator
()(
const
DeviceContext
&
context
,
const
std
::
vector
<
const
phi
::
SelectedRows
*>&
inputs
,
const
std
::
vector
<
const
phi
::
SelectedRows
*>&
inputs
,
phi
::
SelectedRows
*
output
,
const
bool
sorted_result
=
false
)
{
phi
::
SelectedRows
*
output
,
const
bool
sorted_result
=
false
)
{
if
(
inputs
.
size
()
==
0
)
{
if
(
inputs
.
size
()
==
0
)
{
...
@@ -414,7 +520,7 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
...
@@ -414,7 +520,7 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
phi
::
make_ddim
({
static_cast
<
int64_t
>
(
merge_rows
.
size
()),
input_width
}),
phi
::
make_ddim
({
static_cast
<
int64_t
>
(
merge_rows
.
size
()),
input_width
}),
context
.
GetPlace
());
context
.
GetPlace
());
phi
::
funcs
::
SetConstant
<
platform
::
CUDA
DeviceContext
,
T
>
constant_functor
;
phi
::
funcs
::
SetConstant
<
DeviceContext
,
T
>
constant_functor
;
constant_functor
(
context
,
out
.
mutable_value
(),
static_cast
<
T
>
(
0
));
constant_functor
(
context
,
out
.
mutable_value
(),
static_cast
<
T
>
(
0
));
auto
*
out_data
=
out
.
mutable_value
()
->
data
<
T
>
();
auto
*
out_data
=
out
.
mutable_value
()
->
data
<
T
>
();
...
@@ -441,15 +547,69 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
...
@@ -441,15 +547,69 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
}
}
};
};
template
struct
MergeAdd
<
platform
::
CUDADeviceContext
,
float
>;
template
<
typename
T
>
template
struct
MergeAdd
<
platform
::
CUDADeviceContext
,
double
>;
struct
MergeAdd
<
platform
::
CUDADeviceContext
,
T
>
{
template
struct
MergeAdd
<
platform
::
CUDADeviceContext
,
int
>;
// unary functor, merge by adding duplicated rows in
template
struct
MergeAdd
<
platform
::
CUDADeviceContext
,
int64_t
>;
// the input SelectedRows object.
template
struct
MergeAdd
<
platform
::
CUDADeviceContext
,
platform
::
float16
>;
phi
::
SelectedRows
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
template
struct
MergeAdd
<
platform
::
CUDADeviceContext
,
platform
::
bfloat16
>;
const
phi
::
SelectedRows
&
input
,
template
struct
MergeAdd
<
platform
::
CUDADeviceContext
,
platform
::
complex
<
float
>
>
;
const
bool
sorted_result
)
{
template
struct
MergeAdd
<
platform
::
CUDADeviceContext
,
return
MergeAddImpl
<
platform
::
CUDADeviceContext
,
T
>
()(
context
,
input
,
platform
::
complex
<
double
>
>
;
sorted_result
);
}
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
phi
::
SelectedRows
&
input
,
phi
::
SelectedRows
*
output
,
const
bool
sorted_result
)
{
MergeAddImpl
<
platform
::
CUDADeviceContext
,
T
>
()(
context
,
input
,
output
,
sorted_result
);
}
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
std
::
vector
<
const
phi
::
SelectedRows
*>&
inputs
,
phi
::
SelectedRows
*
output
,
const
bool
sorted_result
)
{
MergeAddImpl
<
platform
::
CUDADeviceContext
,
T
>
()(
context
,
inputs
,
output
,
sorted_result
);
}
};
template
<
typename
T
>
struct
MergeAdd
<
phi
::
GPUContext
,
T
>
{
// unary functor, merge by adding duplicated rows in
// the input SelectedRows object.
phi
::
SelectedRows
operator
()(
const
phi
::
GPUContext
&
context
,
const
phi
::
SelectedRows
&
input
,
const
bool
sorted_result
)
{
return
MergeAddImpl
<
phi
::
GPUContext
,
T
>
()(
context
,
input
,
sorted_result
);
}
void
operator
()(
const
phi
::
GPUContext
&
context
,
const
phi
::
SelectedRows
&
input
,
phi
::
SelectedRows
*
output
,
const
bool
sorted_result
)
{
MergeAddImpl
<
phi
::
GPUContext
,
T
>
()(
context
,
input
,
output
,
sorted_result
);
}
void
operator
()(
const
phi
::
GPUContext
&
context
,
const
std
::
vector
<
const
phi
::
SelectedRows
*>&
inputs
,
phi
::
SelectedRows
*
output
,
const
bool
sorted_result
)
{
MergeAddImpl
<
phi
::
GPUContext
,
T
>
()(
context
,
inputs
,
output
,
sorted_result
);
}
};
#define TEMPLATE_SPECIALIZED_FOR_MERGEADD(dtype) \
template struct MergeAddImpl<platform::CUDADeviceContext, dtype>; \
template struct MergeAddImpl<phi::GPUContext, dtype>; \
template struct MergeAdd<platform::CUDADeviceContext, dtype>; \
template struct MergeAdd<phi::GPUContext, dtype>;
TEMPLATE_SPECIALIZED_FOR_MERGEADD
(
float
)
TEMPLATE_SPECIALIZED_FOR_MERGEADD
(
double
)
TEMPLATE_SPECIALIZED_FOR_MERGEADD
(
int
)
TEMPLATE_SPECIALIZED_FOR_MERGEADD
(
int64_t
)
TEMPLATE_SPECIALIZED_FOR_MERGEADD
(
platform
::
float16
)
TEMPLATE_SPECIALIZED_FOR_MERGEADD
(
platform
::
bfloat16
)
TEMPLATE_SPECIALIZED_FOR_MERGEADD
(
platform
::
complex
<
float
>
)
TEMPLATE_SPECIALIZED_FOR_MERGEADD
(
platform
::
complex
<
double
>
)
template
<
typename
T
,
int
block_size
>
template
<
typename
T
,
int
block_size
>
__global__
void
UpdateToTensorKernel
(
const
T
*
selected_rows
,
__global__
void
UpdateToTensorKernel
(
const
T
*
selected_rows
,
...
...
paddle/phi/kernels/CMakeLists.txt
浏览文件 @
c552d1ac
...
@@ -11,7 +11,7 @@ set_property(GLOBAL PROPERTY PHI_KERNELS "")
...
@@ -11,7 +11,7 @@ set_property(GLOBAL PROPERTY PHI_KERNELS "")
# [ 1. Common kernel compilation dependencies ]
# [ 1. Common kernel compilation dependencies ]
set
(
COMMON_KERNEL_DEPS dense_tensor sparse_coo_tensor sparse_csr_tensor kernel_context kernel_factory arg_map_context convert_utils lod_utils custom_kernel
)
set
(
COMMON_KERNEL_DEPS dense_tensor sparse_coo_tensor sparse_csr_tensor kernel_context kernel_factory arg_map_context convert_utils lod_utils custom_kernel
)
set
(
COMMON_KERNEL_DEPS
${
COMMON_KERNEL_DEPS
}
eigen_function blas math_function im2col vol2col concat_and_split_functor
)
set
(
COMMON_KERNEL_DEPS
${
COMMON_KERNEL_DEPS
}
eigen_function blas math_function im2col vol2col concat_and_split_functor
selected_rows_functor
)
# remove this dep after removing fluid deps on tensor creation
# remove this dep after removing fluid deps on tensor creation
set
(
COMMON_KERNEL_DEPS
${
COMMON_KERNEL_DEPS
}
phi_api_utils
)
set
(
COMMON_KERNEL_DEPS
${
COMMON_KERNEL_DEPS
}
phi_api_utils
)
set
(
COMMON_KERNEL_DEPS
${
COMMON_KERNEL_DEPS
}
infermeta
)
set
(
COMMON_KERNEL_DEPS
${
COMMON_KERNEL_DEPS
}
infermeta
)
...
...
paddle/phi/kernels/activation_grad_kernel.h
浏览文件 @
c552d1ac
...
@@ -100,5 +100,6 @@ DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Acosh);
...
@@ -100,5 +100,6 @@ DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Acosh);
DECLARE_ACTIVATION_GRAD_KERNEL_DepX
(
Atanh
);
DECLARE_ACTIVATION_GRAD_KERNEL_DepX
(
Atanh
);
DECLARE_ACTIVATION_GRAD_KERNEL_DepOut
(
Relu
);
DECLARE_ACTIVATION_GRAD_KERNEL_DepOut
(
Relu
);
DECLARE_ACTIVATION_GRAD_KERNEL_DepOut
(
Tanh
);
DECLARE_ACTIVATION_GRAD_KERNEL_DepOut
(
Tanh
);
DECLARE_ACTIVATION_GRAD_KERNEL_DepOut
(
Exp
);
}
// namespace phi
}
// namespace phi
paddle/phi/kernels/activation_kernel.h
浏览文件 @
c552d1ac
...
@@ -37,6 +37,8 @@ DECLARE_ACTIVATION_KERNEL(Acosh)
...
@@ -37,6 +37,8 @@ DECLARE_ACTIVATION_KERNEL(Acosh)
DECLARE_ACTIVATION_KERNEL
(
Atanh
)
DECLARE_ACTIVATION_KERNEL
(
Atanh
)
DECLARE_ACTIVATION_KERNEL
(
Relu
)
DECLARE_ACTIVATION_KERNEL
(
Relu
)
DECLARE_ACTIVATION_KERNEL
(
Tanh
)
DECLARE_ACTIVATION_KERNEL
(
Tanh
)
DECLARE_ACTIVATION_KERNEL
(
Exp
)
DECLARE_ACTIVATION_KERNEL
(
Expm1
)
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
void
BReluKernel
(
const
Context
&
dev_ctx
,
void
BReluKernel
(
const
Context
&
dev_ctx
,
...
@@ -57,4 +59,16 @@ void ThresholdedReluKernel(const Context& dev_ctx,
...
@@ -57,4 +59,16 @@ void ThresholdedReluKernel(const Context& dev_ctx,
float
threshold
,
float
threshold
,
DenseTensor
*
out
);
DenseTensor
*
out
);
template
<
typename
T
,
typename
Context
>
void
LogitKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
float
eps
,
DenseTensor
*
out
);
template
<
typename
T
,
typename
Context
>
void
MishKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
float
threshold
,
DenseTensor
*
out
);
}
// namespace phi
}
// namespace phi
paddle/phi/kernels/cpu/activation_grad_kernel.cc
浏览文件 @
c552d1ac
...
@@ -104,6 +104,7 @@ DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Atanh, funcs::AtanhGradFunctor);
...
@@ -104,6 +104,7 @@ DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Atanh, funcs::AtanhGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepOut
(
Relu
,
funcs
::
ReluGradFunctor
);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepOut
(
Relu
,
funcs
::
ReluGradFunctor
);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepOut
(
Tanh
,
funcs
::
TanhGradFunctor
);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepOut
(
Tanh
,
funcs
::
TanhGradFunctor
);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepOut
(
Exp
,
funcs
::
ExpGradFunctor
);
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX
(
LeakyRelu
,
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX
(
LeakyRelu
,
funcs
::
LeakyReluGradFunctor
,
funcs
::
LeakyReluGradFunctor
,
...
@@ -159,3 +160,12 @@ PD_REGISTER_KERNEL(tanh_triple_grad,
...
@@ -159,3 +160,12 @@ PD_REGISTER_KERNEL(tanh_triple_grad,
float
,
float
,
double
,
double
,
phi
::
dtype
::
float16
)
{}
phi
::
dtype
::
float16
)
{}
PD_REGISTER_KERNEL
(
exp_grad
,
CPU
,
ALL_LAYOUT
,
phi
::
ExpGradKernel
,
float
,
double
,
int
,
int64_t
)
{}
paddle/phi/kernels/cpu/activation_kernel.cc
浏览文件 @
c552d1ac
...
@@ -15,6 +15,7 @@ limitations under the License. */
...
@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/phi/kernels/activation_kernel.h"
#include "paddle/phi/kernels/activation_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/activation_functor.h"
#include "paddle/phi/kernels/impl/activation_impl.h"
#include "paddle/phi/kernels/impl/activation_impl.h"
namespace
phi
{
namespace
phi
{
...
@@ -67,11 +68,27 @@ DEFINE_CPU_ACTIVATION_KERNEL(Acosh, funcs::AcoshFunctor<T>)
...
@@ -67,11 +68,27 @@ DEFINE_CPU_ACTIVATION_KERNEL(Acosh, funcs::AcoshFunctor<T>)
DEFINE_CPU_ACTIVATION_KERNEL
(
Atanh
,
funcs
::
AtanhFunctor
<
T
>
)
DEFINE_CPU_ACTIVATION_KERNEL
(
Atanh
,
funcs
::
AtanhFunctor
<
T
>
)
DEFINE_CPU_ACTIVATION_KERNEL
(
Relu
,
funcs
::
ReluCPUFunctor
<
T
>
)
DEFINE_CPU_ACTIVATION_KERNEL
(
Relu
,
funcs
::
ReluCPUFunctor
<
T
>
)
DEFINE_CPU_ACTIVATION_KERNEL
(
Tanh
,
funcs
::
TanhFunctor
<
T
>
)
DEFINE_CPU_ACTIVATION_KERNEL
(
Tanh
,
funcs
::
TanhFunctor
<
T
>
)
DEFINE_CPU_ACTIVATION_KERNEL
(
Exp
,
funcs
::
ExpFunctor
<
T
>
)
DEFINE_CPU_ACTIVATION_KERNEL
(
Expm1
,
funcs
::
Expm1Functor
<
T
>
)
DEFINE_CPU_ACTIVATION_KERNEL
(
Reciprocal
,
funcs
::
ReciprocalFunctor
<
T
>
)
DEFINE_CPU_ACTIVATION_KERNEL
(
Square
,
funcs
::
SquareFunctor
<
T
>
)
DEFINE_CPU_ACTIVATION_KERNEL
(
Sqrt
,
funcs
::
SqrtFunctor
<
T
>
)
DEFINE_CPU_ACTIVATION_KERNEL
(
Rsqrt
,
funcs
::
RsqrtFunctor
<
T
>
)
DEFINE_CPU_ACTIVATION_KERNEL
(
Softsign
,
funcs
::
SoftsignFunctor
<
T
>
)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS
(
LeakyRelu
,
funcs
::
LeakyReluFunctor
,
alpha
)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS
(
LeakyRelu
,
funcs
::
LeakyReluFunctor
,
alpha
)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS
(
ThresholdedRelu
,
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS
(
ThresholdedRelu
,
funcs
::
ThresholdedReluFunctor
,
funcs
::
ThresholdedReluFunctor
,
threshold
)
threshold
)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS
(
Mish
,
funcs
::
MishFunctor
,
threshold
)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS
(
BRelu
,
funcs
::
BReluFunctor
,
t_min
,
t_max
)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS
(
BRelu
,
funcs
::
BReluFunctor
,
t_min
,
t_max
)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS
(
STanh
,
funcs
::
STanhFunctor
,
scale_a
,
scale_b
)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS
(
Softplus
,
funcs
::
SoftplusFunctor
,
beta
,
threshold
)
}
// namespace phi
}
// namespace phi
PD_REGISTER_KERNEL
(
relu
,
CPU
,
ALL_LAYOUT
,
phi
::
ReluKernel
,
float
,
double
)
{}
PD_REGISTER_KERNEL
(
relu
,
CPU
,
ALL_LAYOUT
,
phi
::
ReluKernel
,
float
,
double
)
{}
...
@@ -94,3 +111,23 @@ PD_REGISTER_ACTIVATION_KERNEL(tanh, Tanh)
...
@@ -94,3 +111,23 @@ PD_REGISTER_ACTIVATION_KERNEL(tanh, Tanh)
PD_REGISTER_ACTIVATION_KERNEL
(
brelu
,
BRelu
)
PD_REGISTER_ACTIVATION_KERNEL
(
brelu
,
BRelu
)
PD_REGISTER_ACTIVATION_KERNEL
(
leaky_relu
,
LeakyRelu
)
PD_REGISTER_ACTIVATION_KERNEL
(
leaky_relu
,
LeakyRelu
)
PD_REGISTER_ACTIVATION_KERNEL
(
thresholded_relu
,
ThresholdedRelu
)
PD_REGISTER_ACTIVATION_KERNEL
(
thresholded_relu
,
ThresholdedRelu
)
PD_REGISTER_ACTIVATION_KERNEL
(
mish
,
Mish
)
PD_REGISTER_ACTIVATION_KERNEL
(
stanh
,
STanh
)
PD_REGISTER_ACTIVATION_KERNEL
(
reciprocal
,
Reciprocal
)
PD_REGISTER_ACTIVATION_KERNEL
(
sqrt
,
Sqrt
)
PD_REGISTER_ACTIVATION_KERNEL
(
rsqrt
,
Rsqrt
)
PD_REGISTER_ACTIVATION_KERNEL
(
softplus
,
Softplus
)
PD_REGISTER_ACTIVATION_KERNEL
(
softsign
,
Softsign
)
PD_REGISTER_KERNEL
(
exp
,
CPU
,
ALL_LAYOUT
,
phi
::
ExpKernel
,
float
,
double
,
int
,
int64_t
)
{}
PD_REGISTER_KERNEL
(
expm1
,
CPU
,
ALL_LAYOUT
,
phi
::
Expm1Kernel
,
float
,
double
,
phi
::
dtype
::
float16
)
{}
PD_REGISTER_KERNEL
(
logit
,
CPU
,
ALL_LAYOUT
,
phi
::
LogitKernel
,
float
,
double
)
{}
PD_REGISTER_KERNEL
(
square
,
CPU
,
ALL_LAYOUT
,
phi
::
SquareKernel
,
float
,
double
,
int
,
int64_t
)
{}
paddle/phi/kernels/funcs/activation_functor.h
浏览文件 @
c552d1ac
...
@@ -100,6 +100,15 @@ struct SinFunctor : public BaseActivationFunctor<T> {
...
@@ -100,6 +100,15 @@ struct SinFunctor : public BaseActivationFunctor<T> {
}
}
};
};
// reciprocal(x) = 1 / x
template
<
typename
T
>
struct
ReciprocalFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
const
{
out
.
device
(
d
)
=
static_cast
<
T
>
(
1
)
/
x
;
}
};
// cosine'(x) = -sin(x)
// cosine'(x) = -sin(x)
template
<
typename
T
>
template
<
typename
T
>
struct
CosGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
struct
CosGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
...
@@ -124,6 +133,57 @@ struct CosFunctor : public BaseActivationFunctor<T> {
...
@@ -124,6 +133,57 @@ struct CosFunctor : public BaseActivationFunctor<T> {
}
}
};
};
template
<
typename
T
>
struct
LogitFunctor
{
template
<
typename
Device
,
typename
X
,
typename
Out
,
typename
P
>
void
operator
()(
Device
d
,
X
x
,
Out
out
,
P
p
,
float
eps
)
const
{
// logit(x) = ln(x/(1-x))
auto
tmp_x
=
(
x
.
cwiseMin
(
static_cast
<
T
>
(
1.0
-
eps
))).
cwiseMax
(
static_cast
<
T
>
(
eps
));
if
(
!
eps
)
{
out
.
device
(
d
)
=
(
x
<
static_cast
<
T
>
(
0.0
)
||
x
>
static_cast
<
T
>
(
1.0
))
.
select
(
p
.
constant
(
static_cast
<
T
>
(
NAN
)),
(
tmp_x
/
(
static_cast
<
T
>
(
1
)
-
tmp_x
)).
log
());
}
else
{
out
.
device
(
d
)
=
(
tmp_x
/
(
static_cast
<
T
>
(
1
)
-
tmp_x
)).
log
();
}
}
};
// mish(x) = x * tanh(softplus(x))
// softplus(x) = x, if x > threshold
// = ln(1 + exp(x)), otherwise
template
<
typename
T
>
struct
MishFunctor
:
public
BaseActivationFunctor
<
T
>
{
float
threshold
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"threshold"
,
&
threshold
}};
}
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
{
auto
sp
=
(
x
>
static_cast
<
T
>
(
threshold
))
.
select
(
x
,
(
static_cast
<
T
>
(
1
)
+
x
.
exp
()).
log
());
out
.
device
(
d
)
=
x
*
sp
.
tanh
();
}
};
template
<
typename
T
>
struct
STanhFunctor
:
public
BaseActivationFunctor
<
T
>
{
float
scale_a
;
float
scale_b
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"scale_a"
,
&
scale_a
},
{
"scale_b"
,
&
scale_b
}};
}
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
const
{
out
.
device
(
d
)
=
static_cast
<
T
>
(
scale_b
)
*
(
static_cast
<
T
>
(
scale_a
)
*
x
).
tanh
();
}
};
template
<
typename
T
>
template
<
typename
T
>
struct
Tangent
{
struct
Tangent
{
HOSTDEVICE
T
operator
()(
const
T
&
val
)
const
{
return
tan
(
val
);
}
HOSTDEVICE
T
operator
()(
const
T
&
val
)
const
{
return
tan
(
val
);
}
...
@@ -151,6 +211,55 @@ struct TanGradFunctor : public BaseActivationFunctor<T> {
...
@@ -151,6 +211,55 @@ struct TanGradFunctor : public BaseActivationFunctor<T> {
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDepX
;
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDepX
;
}
};
};
// square(x) = x^2
template
<
typename
T
>
struct
SquareFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
const
{
out
.
device
(
d
)
=
x
.
square
();
}
};
// sqrt(x) = x^(1/2)
template
<
typename
T
>
struct
SqrtFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
const
{
out
.
device
(
d
)
=
x
.
sqrt
();
}
};
// rsqrt(x) = x^(-1/2)
template
<
typename
T
>
struct
RsqrtFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
const
{
out
.
device
(
d
)
=
x
.
rsqrt
();
}
};
// For numerical stability, using the following formula instead of softplus(x) =
// log(1 + exp(x))
// softplus(x) = log(1 + exp(beta * x)) / beta when beta * x <= threshold(beta =
// 1, threshold = 20 by default), otherwise x
template
<
typename
T
>
struct
SoftplusFunctor
:
public
BaseActivationFunctor
<
T
>
{
float
beta
;
float
threshold
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"beta"
,
&
beta
},
{
"threshold"
,
&
threshold
}};
}
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
{
auto
x_beta
=
static_cast
<
T
>
(
beta
)
*
x
;
out
.
device
(
d
)
=
(
x_beta
>
static_cast
<
T
>
(
threshold
))
.
select
(
x
,
(
static_cast
<
T
>
(
1
)
+
x_beta
.
exp
()).
log
()
/
static_cast
<
T
>
(
beta
));
}
};
// Tangent(x) = tan(x)
// Tangent(x) = tan(x)
template
<
typename
T
>
template
<
typename
T
>
struct
TanFunctor
:
public
BaseActivationFunctor
<
T
>
{
struct
TanFunctor
:
public
BaseActivationFunctor
<
T
>
{
...
@@ -452,6 +561,41 @@ struct AtanhGradFunctor : public BaseActivationFunctor<T> {
...
@@ -452,6 +561,41 @@ struct AtanhGradFunctor : public BaseActivationFunctor<T> {
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDepX
;
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDepX
;
}
};
};
// exp functor
// exp(x) = e^x
template
<
typename
T
>
struct
ExpFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
const
{
out
.
device
(
d
)
=
x
.
exp
();
}
};
template
<
typename
T
>
struct
ExpGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
,
typename
X
,
typename
Out
,
typename
dOut
,
typename
dX
>
void
operator
()(
Device
d
,
X
x
,
Out
out
,
dOut
dout
,
dX
dx
)
const
{
dx
.
device
(
d
)
=
dout
*
out
;
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepOut
;
}
};
// expm1(x) = e^x - 1
template
<
typename
T
>
struct
Expm1Functor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
const
{
out
.
device
(
d
)
=
x
.
expm1
();
}
};
// relu(x) = max(x, 0)
// relu(x) = max(x, 0)
template
<
typename
T
>
template
<
typename
T
>
struct
ReluCPUFunctor
:
public
BaseActivationFunctor
<
T
>
{
struct
ReluCPUFunctor
:
public
BaseActivationFunctor
<
T
>
{
...
@@ -672,6 +816,15 @@ struct BReluGradFunctor : public BaseActivationFunctor<T> {
...
@@ -672,6 +816,15 @@ struct BReluGradFunctor : public BaseActivationFunctor<T> {
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
};
};
// softsign(x) = x / (1 + |x|)
template
<
typename
T
>
struct
SoftsignFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
{
out
.
device
(
d
)
=
x
/
(
static_cast
<
T
>
(
1
)
+
x
.
abs
());
}
};
template
<
typename
T
>
template
<
typename
T
>
struct
LeakyReluFunctor
:
public
BaseActivationFunctor
<
T
>
{
struct
LeakyReluFunctor
:
public
BaseActivationFunctor
<
T
>
{
float
alpha
;
float
alpha
;
...
@@ -827,6 +980,54 @@ struct CudaCosGradFunctor : public BaseActivationFunctor<T> {
...
@@ -827,6 +980,54 @@ struct CudaCosGradFunctor : public BaseActivationFunctor<T> {
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
};
};
template
<
typename
T
>
struct
CudaExpFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
// exp(x) = exp(x)
__device__
__forceinline__
T
operator
()(
const
T
arg_x
)
const
{
MPType
x
=
static_cast
<
MPType
>
(
arg_x
);
return
static_cast
<
T
>
(
exp
(
x
));
}
};
template
<
typename
T
>
struct
CudaSquareFunctor
:
public
BaseActivationFunctor
<
T
>
{
// square(x) = x * x
__device__
__forceinline__
T
operator
()(
const
T
x
)
const
{
return
x
*
x
;
}
};
template
<
typename
T
>
struct
CudaExpGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
// dx = dout * out
__device__
__forceinline__
T
operator
()(
const
T
dout
,
const
T
out
)
const
{
return
dout
*
out
;
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepOut
;
}
};
template
<
typename
T
>
struct
CudaReciprocalFunctor
:
public
BaseActivationFunctor
<
T
>
{
T
one
=
static_cast
<
T
>
(
1.0
f
);
// reciprocal(x) = 1 / x
__device__
__forceinline__
T
operator
()(
const
T
x
)
const
{
return
one
/
x
;
}
};
template
<
typename
T
>
struct
CudaExpm1Functor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
// expm1(x) = expm1(x)
__device__
__forceinline__
T
operator
()(
const
T
arg_x
)
const
{
MPType
x
=
static_cast
<
MPType
>
(
arg_x
);
return
static_cast
<
T
>
(
expm1
(
x
));
}
};
template
<
typename
T
>
template
<
typename
T
>
struct
CudaSinFunctor
:
public
BaseActivationFunctor
<
T
>
{
struct
CudaSinFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
using
MPType
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
...
@@ -838,6 +1039,16 @@ struct CudaSinFunctor : public BaseActivationFunctor<T> {
...
@@ -838,6 +1039,16 @@ struct CudaSinFunctor : public BaseActivationFunctor<T> {
}
}
};
};
template
<
typename
T
>
struct
CudaSoftsignFunctor
:
public
BaseActivationFunctor
<
T
>
{
T
one
=
static_cast
<
T
>
(
1.0
f
);
// softsign(x) = x / (1 + abs(x))
__device__
__forceinline__
T
operator
()(
const
T
x
)
const
{
return
x
/
(
one
+
abs
(
x
));
}
};
template
<
typename
T
>
template
<
typename
T
>
struct
CudaSinGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
struct
CudaSinGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
using
MPType
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
...
@@ -1049,6 +1260,46 @@ struct CudaAtanhFunctor : public BaseActivationFunctor<T> {
...
@@ -1049,6 +1260,46 @@ struct CudaAtanhFunctor : public BaseActivationFunctor<T> {
}
}
};
};
template
<
typename
T
>
struct
CudaSTanhFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
float
scale_a
;
float
scale_b
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"scale_a"
,
&
scale_a
},
{
"scale_b"
,
&
scale_b
}};
}
// stanh(x) = b * tanh(a * x)
__device__
__forceinline__
T
operator
()(
const
T
arg_x
)
const
{
MPType
x
=
static_cast
<
MPType
>
(
arg_x
);
MPType
a
=
static_cast
<
MPType
>
(
scale_a
);
MPType
b
=
static_cast
<
MPType
>
(
scale_b
);
return
static_cast
<
T
>
(
b
*
tanh
(
a
*
x
));
}
};
template
<
typename
T
>
struct
CudaSoftplusFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
MPType
one
=
static_cast
<
MPType
>
(
1.0
f
);
float
beta
;
float
threshold
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"beta"
,
&
beta
},
{
"threshold"
,
&
threshold
}};
}
// softplus(x) = beta * x > threshold ? x : log(1 + exp(beta * x)) / beta
__device__
__forceinline__
T
operator
()(
const
T
arg_x
)
const
{
MPType
x
=
static_cast
<
MPType
>
(
arg_x
);
MPType
b
=
static_cast
<
MPType
>
(
beta
);
MPType
t
=
static_cast
<
MPType
>
(
threshold
);
MPType
x_beta
=
x
*
beta
;
return
static_cast
<
T
>
(
x_beta
>
t
?
x
:
log
(
one
+
exp
(
x_beta
))
/
b
);
}
};
template
<
typename
T
>
template
<
typename
T
>
struct
CudaAtanhGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
struct
CudaAtanhGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
using
MPType
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
...
@@ -1064,6 +1315,28 @@ struct CudaAtanhGradFunctor : public BaseActivationFunctor<T> {
...
@@ -1064,6 +1315,28 @@ struct CudaAtanhGradFunctor : public BaseActivationFunctor<T> {
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
};
};
template
<
typename
T
>
struct
CudaSqrtFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
// sqrt(x) = sqrt(x)
__device__
__forceinline__
T
operator
()(
const
T
arg_x
)
const
{
MPType
x
=
static_cast
<
MPType
>
(
arg_x
);
return
static_cast
<
T
>
(
sqrt
(
x
));
}
};
template
<
typename
T
>
struct
CudaRsqrtFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
// rsqrt(x) = rsqrt(x)
__device__
__forceinline__
T
operator
()(
const
T
arg_x
)
const
{
MPType
x
=
static_cast
<
MPType
>
(
arg_x
);
return
static_cast
<
T
>
(
rsqrt
(
x
));
}
};
template
<
typename
T
>
template
<
typename
T
>
struct
CudaAtanFunctor
:
public
BaseActivationFunctor
<
T
>
{
struct
CudaAtanFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
using
MPType
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
...
@@ -1131,6 +1404,27 @@ struct CudaBReluFunctor : public BaseActivationFunctor<T> {
...
@@ -1131,6 +1404,27 @@ struct CudaBReluFunctor : public BaseActivationFunctor<T> {
}
}
};
};
template
<
typename
T
>
struct
CudaMishFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
MPType
one
=
static_cast
<
MPType
>
(
1.0
f
);
float
threshold
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"threshold"
,
&
threshold
}};
}
// mish(x) = x * tanh(softplus(x))
// softplus(x) = x, if x > threshold
// = ln(1 + exp(x)), otherwise
// Inputs: args[0], the input x
__device__
__forceinline__
T
operator
()(
const
T
arg_x
)
const
{
MPType
x
=
static_cast
<
MPType
>
(
arg_x
);
MPType
sp
=
(
x
>
static_cast
<
MPType
>
(
threshold
))
?
x
:
log
(
one
+
exp
(
x
));
return
static_cast
<
T
>
(
x
*
tanh
(
sp
));
}
};
template
<
typename
T
>
template
<
typename
T
>
struct
CudaBReluGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
struct
CudaBReluGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
T
zero
=
static_cast
<
T
>
(
0.0
f
);
T
zero
=
static_cast
<
T
>
(
0.0
f
);
...
...
paddle/phi/kernels/gpu/activation_grad_kernel.cu
浏览文件 @
c552d1ac
...
@@ -155,6 +155,7 @@ DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Cosh, CudaCoshGradFunctor);
...
@@ -155,6 +155,7 @@ DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Cosh, CudaCoshGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX
(
Asinh
,
CudaAsinhGradFunctor
);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX
(
Asinh
,
CudaAsinhGradFunctor
);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX
(
Acosh
,
CudaAcoshGradFunctor
);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX
(
Acosh
,
CudaAcoshGradFunctor
);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX
(
Atanh
,
CudaAtanhGradFunctor
);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX
(
Atanh
,
CudaAtanhGradFunctor
);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepOut
(
Exp
,
CudaExpGradFunctor
);
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX
(
LeakyRelu
,
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX
(
LeakyRelu
,
CudaLeakyReluGradFunctor
,
CudaLeakyReluGradFunctor
,
...
@@ -234,3 +235,12 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_double_grad,
...
@@ -234,3 +235,12 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_double_grad,
LeakyReluDoubleGradKernel
)
LeakyReluDoubleGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
thresholded_relu_grad
,
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
thresholded_relu_grad
,
ThresholdedReluGradKernel
)
ThresholdedReluGradKernel
)
PD_REGISTER_KERNEL
(
exp_grad
,
GPU
,
ALL_LAYOUT
,
phi
::
ExpGradKernel
,
float
,
double
,
int
,
int64_t
)
{}
paddle/phi/kernels/gpu/activation_kernel.cu
浏览文件 @
c552d1ac
...
@@ -20,6 +20,7 @@ limitations under the License. */
...
@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/impl/activation_grad_impl.h"
#include "paddle/phi/kernels/impl/activation_grad_impl.h"
#include "paddle/phi/kernels/impl/activation_impl.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
...
@@ -88,13 +89,27 @@ DEFINE_GPU_ACTIVATION_KERNEL(Acosh, funcs::CudaAcoshFunctor<T>)
...
@@ -88,13 +89,27 @@ DEFINE_GPU_ACTIVATION_KERNEL(Acosh, funcs::CudaAcoshFunctor<T>)
DEFINE_GPU_ACTIVATION_KERNEL
(
Atanh
,
funcs
::
CudaAtanhFunctor
<
T
>
)
DEFINE_GPU_ACTIVATION_KERNEL
(
Atanh
,
funcs
::
CudaAtanhFunctor
<
T
>
)
DEFINE_GPU_ACTIVATION_KERNEL
(
Relu
,
funcs
::
CudaReluFunctor
<
T
>
)
DEFINE_GPU_ACTIVATION_KERNEL
(
Relu
,
funcs
::
CudaReluFunctor
<
T
>
)
DEFINE_GPU_ACTIVATION_KERNEL
(
Tanh
,
funcs
::
CudaTanhFunctor
<
T
>
)
DEFINE_GPU_ACTIVATION_KERNEL
(
Tanh
,
funcs
::
CudaTanhFunctor
<
T
>
)
DEFINE_GPU_ACTIVATION_KERNEL
(
Exp
,
funcs
::
CudaExpFunctor
<
T
>
)
DEFINE_GPU_ACTIVATION_KERNEL
(
Expm1
,
funcs
::
CudaExpm1Functor
<
T
>
)
DEFINE_GPU_ACTIVATION_KERNEL
(
Reciprocal
,
funcs
::
CudaReciprocalFunctor
<
T
>
)
DEFINE_GPU_ACTIVATION_KERNEL
(
Square
,
funcs
::
CudaSquareFunctor
<
T
>
)
DEFINE_GPU_ACTIVATION_KERNEL
(
Sqrt
,
funcs
::
CudaSqrtFunctor
<
T
>
)
DEFINE_GPU_ACTIVATION_KERNEL
(
Rsqrt
,
funcs
::
CudaRsqrtFunctor
<
T
>
)
DEFINE_GPU_ACTIVATION_KERNEL
(
Softsign
,
funcs
::
CudaSoftsignFunctor
<
T
>
)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS
(
LeakyRelu
,
CudaLeakyReluFunctor
,
alpha
)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS
(
LeakyRelu
,
CudaLeakyReluFunctor
,
alpha
)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS
(
ThresholdedRelu
,
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS
(
ThresholdedRelu
,
CudaThresholdedReluFunctor
,
CudaThresholdedReluFunctor
,
threshold
)
threshold
)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS
(
Mish
,
CudaMishFunctor
,
threshold
)
DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS
(
BRelu
,
CudaBReluFunctor
,
t_min
,
t_max
)
DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS
(
BRelu
,
CudaBReluFunctor
,
t_min
,
t_max
)
DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS
(
Stanh
,
CudaSTanhFunctor
,
scale_a
,
scale_b
)
DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS
(
Softplus
,
CudaSoftplusFunctor
,
beta
,
threshold
)
}
// namespace phi
}
// namespace phi
...
@@ -142,3 +157,23 @@ PD_REGISTER_ACTIVATION_KERNEL(tanh, TanhKernel)
...
@@ -142,3 +157,23 @@ PD_REGISTER_ACTIVATION_KERNEL(tanh, TanhKernel)
PD_REGISTER_ACTIVATION_KERNEL
(
brelu
,
BReluKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
brelu
,
BReluKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
thresholded_relu
,
ThresholdedReluKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
thresholded_relu
,
ThresholdedReluKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
leaky_relu
,
LeakyReluKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
leaky_relu
,
LeakyReluKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
mish
,
MishKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
stanh
,
StanhKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
reciprocal
,
ReciprocalKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
sqrt
,
SqrtKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
rsqrt
,
RsqrtKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
softplus
,
SoftplusKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
softsign
,
SoftsignKernel
)
PD_REGISTER_KERNEL
(
exp
,
GPU
,
ALL_LAYOUT
,
phi
::
ExpKernel
,
float
,
double
,
int
,
int64_t
)
{}
PD_REGISTER_KERNEL
(
expm1
,
GPU
,
ALL_LAYOUT
,
phi
::
Expm1Kernel
,
float
,
double
,
phi
::
dtype
::
float16
)
{}
PD_REGISTER_KERNEL
(
logit
,
GPU
,
ALL_LAYOUT
,
phi
::
LogitKernel
,
float
,
double
)
{}
PD_REGISTER_KERNEL
(
square
,
GPU
,
ALL_LAYOUT
,
phi
::
SquareKernel
,
float
,
double
,
int
,
int64_t
)
{}
paddle/phi/kernels/gpu/clip_by_norm_kernel.cu
浏览文件 @
c552d1ac
...
@@ -40,7 +40,8 @@ void ClipByNormKernel<phi::dtype::float16, phi::GPUContext>(
...
@@ -40,7 +40,8 @@ void ClipByNormKernel<phi::dtype::float16, phi::GPUContext>(
DenseTensor
tmp
;
DenseTensor
tmp
;
tmp
.
Resize
({
1
});
tmp
.
Resize
({
1
});
dev_ctx
.
template
Alloc
<
float
>(
&
tmp
);
dev_ctx
.
template
Alloc
<
float
>(
&
tmp
);
kernels
::
TensorReduceImpl
<
dtype
::
float16
,
phi
::
funcs
::
ReduceKernel
<
dtype
::
float16
,
float
,
float
,
kps
::
AddFunctor
,
kps
::
AddFunctor
,
kps
::
SquareFunctor
<
dtype
::
float16
,
float
>>
(
kps
::
SquareFunctor
<
dtype
::
float16
,
float
>>
(
...
@@ -48,8 +49,7 @@ void ClipByNormKernel<phi::dtype::float16, phi::GPUContext>(
...
@@ -48,8 +49,7 @@ void ClipByNormKernel<phi::dtype::float16, phi::GPUContext>(
x_in
,
x_in
,
&
tmp
,
&
tmp
,
kps
::
SquareFunctor
<
dtype
::
float16
,
float
>
(),
kps
::
SquareFunctor
<
dtype
::
float16
,
float
>
(),
reduce_dims
,
reduce_dims
);
dev_ctx
.
stream
());
auto
tmp_eigen
=
EigenVector
<
float
>::
Flatten
(
tmp
);
auto
tmp_eigen
=
EigenVector
<
float
>::
Flatten
(
tmp
);
auto
x_norm
=
tmp_eigen
.
sqrt
();
auto
x_norm
=
tmp_eigen
.
sqrt
();
...
...
paddle/phi/kernels/impl/activation_impl.h
浏览文件 @
c552d1ac
...
@@ -47,4 +47,20 @@ void ActivationImpl(const Context& dev_ctx,
...
@@ -47,4 +47,20 @@ void ActivationImpl(const Context& dev_ctx,
}
}
}
}
template
<
typename
T
,
typename
Context
>
void
LogitKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
float
eps
,
DenseTensor
*
out
)
{
dev_ctx
.
template
Alloc
<
T
>(
out
);
auto
eigen_out
=
EigenVector
<
T
>::
Flatten
(
*
out
);
auto
eigen_in
=
EigenVector
<
T
>::
Flatten
(
x
);
auto
&
place
=
*
dev_ctx
.
eigen_device
();
auto
eigen_p
=
EigenVector
<
T
>::
Flatten
(
*
out
);
funcs
::
LogitFunctor
<
T
>
functor
;
functor
(
place
,
eigen_in
,
eigen_out
,
eigen_p
,
eps
);
}
}
// namespace phi
}
// namespace phi
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录