Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
191c441a
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
191c441a
编写于
5月 20, 2022
作者:
Y
YuanRisheng
提交者:
GitHub
5月 20, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
move activation kernel (#42880)
上级
d8b69124
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
475 addition
and
511 deletion
+475
-511
paddle/fluid/operators/activation_op.cc
paddle/fluid/operators/activation_op.cc
+0
-36
paddle/fluid/operators/activation_op.h
paddle/fluid/operators/activation_op.h
+5
-345
paddle/fluid/operators/activation_op.kps
paddle/fluid/operators/activation_op.kps
+29
-125
paddle/phi/kernels/activation_grad_kernel.h
paddle/phi/kernels/activation_grad_kernel.h
+34
-0
paddle/phi/kernels/activation_kernel.h
paddle/phi/kernels/activation_kernel.h
+1
-0
paddle/phi/kernels/cpu/activation_grad_kernel.cc
paddle/phi/kernels/cpu/activation_grad_kernel.cc
+17
-0
paddle/phi/kernels/cpu/activation_kernel.cc
paddle/phi/kernels/cpu/activation_kernel.cc
+5
-4
paddle/phi/kernels/funcs/activation_functor.h
paddle/phi/kernels/funcs/activation_functor.h
+243
-0
paddle/phi/kernels/gpu/activation_grad_kernel.cu
paddle/phi/kernels/gpu/activation_grad_kernel.cu
+17
-0
paddle/phi/kernels/gpu/activation_kernel.cu
paddle/phi/kernels/gpu/activation_kernel.cu
+3
-1
paddle/phi/kernels/impl/activation_grad_impl.h
paddle/phi/kernels/impl/activation_grad_impl.h
+83
-0
paddle/phi/ops/compat/activation_sig.cc
paddle/phi/ops/compat/activation_sig.cc
+38
-0
未找到文件。
paddle/fluid/operators/activation_op.cc
浏览文件 @
191c441a
...
@@ -1659,15 +1659,6 @@ REGISTER_OPERATOR(
...
@@ -1659,15 +1659,6 @@ REGISTER_OPERATOR(
ops
::
ActivationOpDoubleGrad
<
ops
::
CELUGradFunctor
<
float
>::
FwdDeps
()
>
,
ops
::
ActivationOpDoubleGrad
<
ops
::
CELUGradFunctor
<
float
>::
FwdDeps
()
>
,
ops
::
ActivationDoubleGradOpInplaceInferer
);
ops
::
ActivationDoubleGradOpInplaceInferer
);
REGISTER_ACTIVATION_CPU_KERNEL
(
celu
,
CELU
,
CELUFunctor
,
CELUGradFunctor
);
REGISTER_OP_CPU_KERNEL
(
celu_grad_grad
,
ops
::
CELUDoubleGradKernel
<
plat
::
CPUDeviceContext
,
ops
::
CELUGradGradFunctor
<
float
>>
,
ops
::
CELUDoubleGradKernel
<
plat
::
CPUDeviceContext
,
ops
::
CELUGradGradFunctor
<
double
>>
,
ops
::
CELUDoubleGradKernel
<
plat
::
CPUDeviceContext
,
ops
::
CELUGradGradFunctor
<
plat
::
float16
>>
);
/* ========================================================================== */
/* ========================================================================== */
/* =========================== sqrt register ============================= */
/* =========================== sqrt register ============================= */
...
@@ -1687,13 +1678,6 @@ REGISTER_OPERATOR(
...
@@ -1687,13 +1678,6 @@ REGISTER_OPERATOR(
ops
::
ActivationOpDoubleGrad
<
ops
::
SqrtGradGradFunctor
<
float
>::
FwdDeps
()
>
,
ops
::
ActivationOpDoubleGrad
<
ops
::
SqrtGradGradFunctor
<
float
>::
FwdDeps
()
>
,
ops
::
ActivationDoubleGradOpInplaceInferer
);
ops
::
ActivationDoubleGradOpInplaceInferer
);
REGISTER_OP_CPU_KERNEL
(
sqrt_grad_grad
,
ops
::
SqrtDoubleGradKernel
<
plat
::
CPUDeviceContext
,
ops
::
SqrtGradGradFunctor
<
float
>>
,
ops
::
SqrtDoubleGradKernel
<
plat
::
CPUDeviceContext
,
ops
::
SqrtGradGradFunctor
<
double
>>
,
ops
::
SqrtDoubleGradKernel
<
plat
::
CPUDeviceContext
,
ops
::
SqrtGradGradFunctor
<
plat
::
float16
>>
);
/* ========================================================================== */
/* ========================================================================== */
/* =========================== rsqrt register =============================
/* =========================== rsqrt register =============================
...
@@ -1714,14 +1698,6 @@ REGISTER_OPERATOR(
...
@@ -1714,14 +1698,6 @@ REGISTER_OPERATOR(
ops
::
ActivationOpDoubleGrad
<
ops
::
RsqrtGradGradFunctor
<
float
>::
FwdDeps
()
>
,
ops
::
ActivationOpDoubleGrad
<
ops
::
RsqrtGradGradFunctor
<
float
>::
FwdDeps
()
>
,
ops
::
ActivationDoubleGradOpInplaceInferer
);
ops
::
ActivationDoubleGradOpInplaceInferer
);
REGISTER_OP_CPU_KERNEL
(
rsqrt_grad_grad
,
ops
::
RsqrtDoubleGradKernel
<
plat
::
CPUDeviceContext
,
ops
::
RsqrtGradGradFunctor
<
float
>>
,
ops
::
RsqrtDoubleGradKernel
<
plat
::
CPUDeviceContext
,
ops
::
RsqrtGradGradFunctor
<
double
>>
,
ops
::
RsqrtDoubleGradKernel
<
plat
::
CPUDeviceContext
,
ops
::
RsqrtGradGradFunctor
<
plat
::
float16
>>
);
/* ========================================================================== */
/* ========================================================================== */
/* ========================== square register ============================ */
/* ========================== square register ============================ */
...
@@ -1742,18 +1718,6 @@ REGISTER_OPERATOR(
...
@@ -1742,18 +1718,6 @@ REGISTER_OPERATOR(
ops
::
ActivationOpDoubleGrad
<
ops
::
SquareGradGradFunctor
<
float
>::
FwdDeps
()
>
,
ops
::
ActivationOpDoubleGrad
<
ops
::
SquareGradGradFunctor
<
float
>::
FwdDeps
()
>
,
ops
::
ActivationDoubleGradOpInplaceInferer
);
ops
::
ActivationDoubleGradOpInplaceInferer
);
REGISTER_OP_CPU_KERNEL
(
square_grad_grad
,
ops
::
SquareDoubleGradKernel
<
plat
::
CPUDeviceContext
,
ops
::
SquareGradGradFunctor
<
float
>>
,
ops
::
SquareDoubleGradKernel
<
plat
::
CPUDeviceContext
,
ops
::
SquareGradGradFunctor
<
double
>>
,
ops
::
SquareDoubleGradKernel
<
plat
::
CPUDeviceContext
,
ops
::
SquareGradGradFunctor
<
plat
::
float16
>>
,
ops
::
SquareDoubleGradKernel
<
plat
::
CPUDeviceContext
,
ops
::
SquareGradGradFunctor
<
int
>>
,
ops
::
SquareDoubleGradKernel
<
plat
::
CPUDeviceContext
,
ops
::
SquareGradGradFunctor
<
int64_t
>>
);
/* ========================================================================== */
/* ========================================================================== */
/* ========================== pow register ============================ */
/* ========================== pow register ============================ */
...
...
paddle/fluid/operators/activation_op.h
浏览文件 @
191c441a
...
@@ -296,9 +296,14 @@ USE_PHI_FUNCTOR(Mish)
...
@@ -296,9 +296,14 @@ USE_PHI_FUNCTOR(Mish)
USE_PHI_FUNCTOR
(
STanh
)
USE_PHI_FUNCTOR
(
STanh
)
USE_PHI_FUNCTOR
(
Reciprocal
)
USE_PHI_FUNCTOR
(
Reciprocal
)
USE_PHI_FUNCTOR
(
Square
)
USE_PHI_FUNCTOR
(
Square
)
USE_PHI_DOUBLE_GRAD_FUNCTOR
(
Square
)
USE_PHI_FUNCTOR
(
Sqrt
)
USE_PHI_FUNCTOR
(
Sqrt
)
USE_PHI_DOUBLE_GRAD_FUNCTOR
(
Sqrt
)
USE_PHI_FUNCTOR
(
Rsqrt
)
USE_PHI_FUNCTOR
(
Rsqrt
)
USE_PHI_DOUBLE_GRAD_FUNCTOR
(
Rsqrt
)
USE_PHI_FUNCTOR
(
Softplus
)
USE_PHI_FUNCTOR
(
Softplus
)
USE_PHI_FUNCTOR
(
CELU
)
USE_PHI_DOUBLE_GRAD_FUNCTOR
(
CELU
)
template
<
typename
T
>
template
<
typename
T
>
using
ELUGradNegativeAlphaFunctor
=
phi
::
funcs
::
ELUGradNegativeAlphaFunctor
<
T
>
;
using
ELUGradNegativeAlphaFunctor
=
phi
::
funcs
::
ELUGradNegativeAlphaFunctor
<
T
>
;
...
@@ -331,68 +336,6 @@ using ReluGradGradFunctor = phi::funcs::ReluGradGradFunctor<T>;
...
@@ -331,68 +336,6 @@ using ReluGradGradFunctor = phi::funcs::ReluGradGradFunctor<T>;
template
<
typename
T
>
template
<
typename
T
>
using
ReluCUDAFunctor
=
phi
::
funcs
::
ReluCUDAFunctor
<
T
>
;
using
ReluCUDAFunctor
=
phi
::
funcs
::
ReluCUDAFunctor
<
T
>
;
template
<
typename
T
>
struct
SqrtGradGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
>
void
operator
()(
const
Device
&
dev
,
const
framework
::
Tensor
*
Out
,
const
framework
::
Tensor
*
ddX
,
framework
::
Tensor
*
ddOut
,
framework
::
Tensor
*
dOut
,
const
framework
::
Tensor
*
dX
)
const
{
auto
*
d
=
dev
.
eigen_device
();
auto
ddx
=
framework
::
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
ddX
,
"Input"
,
"DDX"
,
"SqrtGradGrad"
));
auto
out
=
framework
::
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
Out
,
"Output"
,
"Out"
,
"SqrtGradGrad"
));
// sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx
// calculate dy first, so ddy can inplace ddx
if
(
dOut
)
{
auto
dx
=
framework
::
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
dX
,
"Output"
,
"DX"
,
"SqrtGradGrad"
));
auto
dout
=
framework
::
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
dOut
,
"Output"
,
"DOut"
,
"SqrtGradGrad"
));
dout
.
device
(
*
d
)
=
dx
*
ddx
*
static_cast
<
T
>
(
-
1
)
/
out
;
}
if
(
ddOut
)
{
auto
ddout
=
framework
::
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
ddOut
,
"Output"
,
"DDOut"
,
"SqrtGradGrad"
));
ddout
.
device
(
*
d
)
=
ddx
*
static_cast
<
T
>
(
0.5
)
/
out
;
}
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepOut
;
}
};
template
<
typename
T
>
struct
RsqrtGradGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
>
void
operator
()(
const
Device
&
dev
,
const
framework
::
Tensor
*
Out
,
const
framework
::
Tensor
*
ddX
,
framework
::
Tensor
*
ddOut
,
framework
::
Tensor
*
dOut
,
const
framework
::
Tensor
*
dX
)
const
{
auto
*
d
=
dev
.
eigen_device
();
auto
ddx
=
framework
::
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
ddX
,
"Input"
,
"DDX"
,
"RsqrtGradGrad"
));
auto
out
=
framework
::
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
Out
,
"Output"
,
"Out"
,
"RsqrtGradGrad"
));
// rsqrt GradGrad: ddy = -0.5 * ddx * y * y * y, dy = (3/y) * dx * ddx
if
(
dOut
)
{
auto
dx
=
framework
::
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
dX
,
"Output"
,
"DX"
,
"RsqrtGradGrad"
));
auto
dout
=
framework
::
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
dOut
,
"Output"
,
"DOut"
,
"RsqrtGradGrad"
));
dout
.
device
(
*
d
)
=
(
static_cast
<
T
>
(
3.0
)
/
out
)
*
dx
*
ddx
;
}
if
(
ddOut
)
{
auto
ddout
=
framework
::
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
ddOut
,
"Output"
,
"DDOut"
,
"RsqrtGradGrad"
));
ddout
.
device
(
*
d
)
=
ddx
*
static_cast
<
T
>
(
-
0.5
)
*
out
*
out
*
out
;
}
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepOut
;
}
};
// relu6(x) = min(max(0, x), 6)
// relu6(x) = min(max(0, x), 6)
template
<
typename
T
>
template
<
typename
T
>
struct
Relu6Functor
:
public
BaseActivationFunctor
<
T
>
{
struct
Relu6Functor
:
public
BaseActivationFunctor
<
T
>
{
...
@@ -498,51 +441,6 @@ class ELUGradKernel : public framework::OpKernel<T> {
...
@@ -498,51 +441,6 @@ class ELUGradKernel : public framework::OpKernel<T> {
}
}
};
};
template
<
typename
T
>
struct
CELUFunctor
:
public
BaseActivationFunctor
<
T
>
{
float
alpha
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"alpha"
,
&
alpha
}};
}
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
const
{
out
.
device
(
d
)
=
(
x
<
static_cast
<
T
>
(
0
))
.
select
(
static_cast
<
T
>
(
alpha
)
*
((
x
/
static_cast
<
T
>
(
alpha
)).
exp
()
-
static_cast
<
T
>
(
1
)),
x
);
}
};
template
<
typename
T
>
struct
CELUGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
float
alpha
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"alpha"
,
&
alpha
}};
}
template
<
typename
Device
,
typename
X
,
typename
Out
,
typename
dOut
,
typename
dX
>
void
operator
()(
Device
d
,
X
x
,
Out
out
,
dOut
dout
,
dX
dx
)
const
{
auto
temp_a_pos
=
static_cast
<
T
>
(
alpha
>
0
);
auto
temp_a_neg
=
static_cast
<
T
>
(
alpha
<=
0
);
auto
temp_x_pos
=
(
x
>
static_cast
<
T
>
(
0
)).
template
cast
<
T
>();
auto
temp_x_neg
=
(
x
<=
static_cast
<
T
>
(
0
)).
template
cast
<
T
>();
// dx = dout, if alpha > 0 and x > 0
// dx = dout * (x/alpha).exp(), if alpha > 0 and x <= 0
// dx = dout , if alpha < 0 and x > 0
// dx = dout * (x/alpha).exp(), if alpha < 0 and x <=0
dx
.
device
(
d
)
=
dout
*
temp_a_pos
*
temp_x_pos
+
dout
*
(
x
/
static_cast
<
T
>
(
alpha
)).
exp
()
*
temp_a_pos
*
temp_x_neg
+
dout
*
temp_a_neg
*
temp_x_pos
+
dout
*
(
x
/
static_cast
<
T
>
(
alpha
)).
exp
()
*
temp_a_neg
*
temp_x_neg
;
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
};
template
<
typename
T
>
template
<
typename
T
>
struct
AbsGradGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
struct
AbsGradGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
>
template
<
typename
Device
>
...
@@ -564,74 +462,6 @@ struct AbsGradGradFunctor : public BaseActivationFunctor<T> {
...
@@ -564,74 +462,6 @@ struct AbsGradGradFunctor : public BaseActivationFunctor<T> {
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
};
};
template
<
typename
T
>
struct
CELUGradGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
float
alpha
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"alpha"
,
&
alpha
}};
}
template
<
typename
Device
>
void
operator
()(
const
Device
&
dev
,
const
framework
::
Tensor
*
X
,
const
framework
::
Tensor
*
ddX
,
framework
::
Tensor
*
ddOut
,
const
framework
::
Tensor
*
dOut
,
framework
::
Tensor
*
dX
)
const
{
auto
*
d
=
dev
.
eigen_device
();
auto
ddx
=
framework
::
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
ddX
,
"Input"
,
"DDX"
,
"CELUGradGrad"
));
auto
x
=
framework
::
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
X
,
"Input"
,
"X"
,
"CELUGradGrad"
));
if
(
dX
)
{
auto
dx
=
framework
::
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
dX
,
"Output"
,
"DX"
,
"CELUGradGrad"
));
auto
dout
=
framework
::
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
dOut
,
"Output"
,
"DOut"
,
"CELUGradGrad"
));
dx
.
device
(
*
d
)
=
ddx
*
dout
/
static_cast
<
T
>
(
alpha
)
*
(
x
/
static_cast
<
T
>
(
alpha
)).
exp
()
*
(
x
<=
static_cast
<
T
>
(
0
)).
template
cast
<
T
>();
}
if
(
ddOut
)
{
auto
ddout
=
framework
::
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
ddOut
,
"Output"
,
"DDOut"
,
"CELUGradGrad"
));
ddout
.
device
(
*
d
)
=
ddx
*
((
x
>
static_cast
<
T
>
(
0
)).
template
cast
<
T
>()
+
(
x
/
static_cast
<
T
>
(
alpha
)).
exp
()
*
(
x
<=
static_cast
<
T
>
(
0
)).
template
cast
<
T
>())
.
template
cast
<
T
>();
}
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
};
template
<
typename
T
>
struct
SquareGradGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
>
void
operator
()(
const
Device
&
dev
,
const
framework
::
Tensor
*
X
,
const
framework
::
Tensor
*
ddX
,
framework
::
Tensor
*
ddOut
,
const
framework
::
Tensor
*
dOut
,
framework
::
Tensor
*
dX
)
const
{
auto
*
d
=
dev
.
eigen_device
();
auto
ddx
=
framework
::
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
ddX
,
"Input"
,
"DDX"
,
"SquareGradGrad"
));
auto
x
=
framework
::
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
X
,
"Input"
,
"X"
,
"SquareGradGrad"
));
// square GradGrad: ddy=2x*ddx, dx=2dy*ddx
// calculate dx first, so ddy can inplace ddx
if
(
dX
)
{
auto
dx
=
framework
::
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
dX
,
"Output"
,
"DX"
,
"SquareGradGrad"
));
auto
dout
=
framework
::
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
dOut
,
"Output"
,
"DOut"
,
"SquareGradGrad"
));
dx
.
device
(
*
d
)
=
ddx
*
static_cast
<
T
>
(
2
)
*
dout
;
}
if
(
ddOut
)
{
auto
ddout
=
framework
::
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
ddOut
,
"Output"
,
"DDOut"
,
"SquareGradGrad"
));
ddout
.
device
(
*
d
)
=
ddx
*
static_cast
<
T
>
(
2
)
*
x
;
}
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
};
// TODO(dengkaipeng): double gradient calculation for Square/Sqrt need
// TODO(dengkaipeng): double gradient calculation for Square/Sqrt need
// DOut(dy) as input(not output), tensor extraction is different from
// DOut(dy) as input(not output), tensor extraction is different from
// others. Impliment extraction kernel separately here.
// others. Impliment extraction kernel separately here.
...
@@ -675,29 +505,6 @@ inline void ExtractDoubleGradTensorWithInputDOut(
...
@@ -675,29 +505,6 @@ inline void ExtractDoubleGradTensorWithInputDOut(
}
}
}
}
template
<
typename
DeviceContext
,
typename
Functor
>
class
SquareDoubleGradKernel
:
public
framework
::
OpKernel
<
typename
Functor
::
ELEMENT_TYPE
>
{
public:
using
T
=
typename
Functor
::
ELEMENT_TYPE
;
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
Tensor
*
X
,
*
ddX
,
*
dOut
;
X
=
ddX
=
dOut
=
nullptr
;
framework
::
Tensor
*
dX
,
*
ddOut
;
dX
=
ddOut
=
nullptr
;
ExtractDoubleGradTensorWithInputDOut
(
ctx
,
&
X
,
&
ddX
,
&
dX
,
&
dOut
,
&
ddOut
);
if
(
dX
)
dX
->
mutable_data
<
T
>
(
X
->
dims
(),
ctx
.
GetPlace
());
if
(
ddOut
)
ddOut
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
&
place
=
ctx
.
template
device_context
<
DeviceContext
>();
Functor
functor
;
functor
(
place
,
X
,
ddX
,
ddOut
,
dOut
,
dX
);
}
};
template
<
typename
T
>
template
<
typename
T
>
struct
SoftsignFunctor
:
public
BaseActivationFunctor
<
T
>
{
struct
SoftsignFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
,
typename
X
,
typename
Out
>
template
<
typename
Device
,
typename
X
,
typename
Out
>
...
@@ -721,153 +528,6 @@ struct SoftsignGradFunctor : public BaseActivationFunctor<T> {
...
@@ -721,153 +528,6 @@ struct SoftsignGradFunctor : public BaseActivationFunctor<T> {
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
};
};
template
<
typename
DeviceContext
,
typename
Functor
>
class
CELUDoubleGradKernel
:
public
framework
::
OpKernel
<
typename
Functor
::
ELEMENT_TYPE
>
{
public:
using
T
=
typename
Functor
::
ELEMENT_TYPE
;
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
Tensor
*
X
,
*
ddX
,
*
dOut
;
X
=
ddX
=
dOut
=
nullptr
;
framework
::
Tensor
*
dX
,
*
ddOut
;
dX
=
ddOut
=
nullptr
;
ExtractDoubleGradTensorWithInputDOut
(
ctx
,
&
X
,
&
ddX
,
&
dX
,
&
dOut
,
&
ddOut
);
if
(
dX
)
dX
->
mutable_data
<
T
>
(
X
->
dims
(),
ctx
.
GetPlace
());
if
(
ddOut
)
ddOut
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
&
place
=
ctx
.
template
device_context
<
DeviceContext
>();
Functor
functor
;
auto
attrs
=
functor
.
GetAttrs
();
for
(
auto
&
attr
:
attrs
)
{
*
attr
.
second
=
ctx
.
Attr
<
float
>
(
attr
.
first
);
}
functor
(
place
,
X
,
ddX
,
ddOut
,
dOut
,
dX
);
}
};
template
<
typename
DeviceContext
,
typename
Functor
>
class
SqrtDoubleGradKernel
:
public
framework
::
OpKernel
<
typename
Functor
::
ELEMENT_TYPE
>
{
public:
using
T
=
typename
Functor
::
ELEMENT_TYPE
;
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
Tensor
*
Out
,
*
dX
,
*
ddX
;
Out
=
dX
=
ddX
=
nullptr
;
framework
::
Tensor
*
ddOut
,
*
dOut
;
ddOut
=
dOut
=
nullptr
;
// extract ddx(input), ddout(output)
auto
ddx_var
=
ctx
.
InputVar
(
"DDX"
);
auto
ddo_var
=
ctx
.
OutputVar
(
"DDOut"
);
PADDLE_ENFORCE_NOT_NULL
(
ddx_var
,
platform
::
errors
::
NotFound
(
"Cannot get input Variable DDX, variable name = %s"
,
ctx
.
InputName
(
"DDX"
)));
ddX
=
ctx
.
Input
<
framework
::
Tensor
>
(
"DDX"
);
if
(
ddo_var
)
{
ddOut
=
ctx
.
Output
<
framework
::
Tensor
>
(
"DDOut"
);
}
PADDLE_ENFORCE_NOT_NULL
(
ddX
,
platform
::
errors
::
NotFound
(
"Cannot get input Variable DDX, variable name = %s"
,
ctx
.
InputName
(
"DDX"
)));
// extract out(input), dout(output)
auto
out_var
=
ctx
.
InputVar
(
"Out"
);
PADDLE_ENFORCE_NOT_NULL
(
out_var
,
platform
::
errors
::
NotFound
(
"Cannot get input Variable Out, variable name = %s"
,
ctx
.
InputName
(
"Out"
)));
auto
dout_var
=
ctx
.
OutputVar
(
"DOut"
);
Out
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Out"
);
if
(
dout_var
)
{
dOut
=
ctx
.
Output
<
framework
::
Tensor
>
(
"DOut"
);
}
// extract dx(input)
auto
dx_var
=
ctx
.
InputVar
(
"DX"
);
PADDLE_ENFORCE_NOT_NULL
(
dx_var
,
platform
::
errors
::
NotFound
(
"Cannot get input Variable DX, variable name = %s"
,
ctx
.
InputName
(
"DX"
)));
if
(
dx_var
)
{
dX
=
ctx
.
Input
<
framework
::
Tensor
>
(
"DX"
);
}
if
(
dOut
)
dOut
->
mutable_data
<
T
>
(
Out
->
dims
(),
ctx
.
GetPlace
());
if
(
ddOut
)
ddOut
->
mutable_data
<
T
>
(
Out
->
dims
(),
ctx
.
GetPlace
());
auto
&
place
=
ctx
.
template
device_context
<
DeviceContext
>();
Functor
functor
;
functor
(
place
,
Out
,
ddX
,
ddOut
,
dOut
,
dX
);
}
};
// rsqrt Grad: dx = -0.5 * dy * y * y * y
// rsqrt GradGrad: ddy = -0.5 * ddx * y * y * y, dy = (3 / y) * dx * ddx
template
<
typename
DeviceContext
,
typename
Functor
>
class
RsqrtDoubleGradKernel
:
public
framework
::
OpKernel
<
typename
Functor
::
ELEMENT_TYPE
>
{
public:
using
T
=
typename
Functor
::
ELEMENT_TYPE
;
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
Tensor
*
Out
,
*
dX
,
*
ddX
;
Out
=
dX
=
ddX
=
nullptr
;
framework
::
Tensor
*
ddOut
,
*
dOut
;
ddOut
=
dOut
=
nullptr
;
// extract ddx(input), ddout(output)
auto
ddx_var
=
ctx
.
InputVar
(
"DDX"
);
auto
ddo_var
=
ctx
.
OutputVar
(
"DDOut"
);
PADDLE_ENFORCE_NOT_NULL
(
ddx_var
,
platform
::
errors
::
NotFound
(
"Cannot get input Variable DDX, variable name = %s"
,
ctx
.
InputName
(
"DDX"
)));
ddX
=
ctx
.
Input
<
framework
::
Tensor
>
(
"DDX"
);
if
(
ddo_var
)
{
ddOut
=
ctx
.
Output
<
framework
::
Tensor
>
(
"DDOut"
);
}
PADDLE_ENFORCE_NOT_NULL
(
ddX
,
platform
::
errors
::
NotFound
(
"Cannot get input Variable DDX, variable name = %s"
,
ctx
.
InputName
(
"DDX"
)));
// extract out(input), dout(output)
auto
out_var
=
ctx
.
InputVar
(
"Out"
);
PADDLE_ENFORCE_NOT_NULL
(
out_var
,
platform
::
errors
::
NotFound
(
"Cannot get input Variable Out, variable name = %s"
,
ctx
.
InputName
(
"Out"
)));
auto
dout_var
=
ctx
.
OutputVar
(
"DOut"
);
Out
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Out"
);
if
(
dout_var
)
{
dOut
=
ctx
.
Output
<
framework
::
Tensor
>
(
"DOut"
);
}
// extract dx(input)
auto
dx_var
=
ctx
.
InputVar
(
"DX"
);
PADDLE_ENFORCE_NOT_NULL
(
dx_var
,
platform
::
errors
::
NotFound
(
"Cannot get input Variable DX, variable name = %s"
,
ctx
.
InputName
(
"DX"
)));
if
(
dx_var
)
{
dX
=
ctx
.
Input
<
framework
::
Tensor
>
(
"DX"
);
}
if
(
dOut
)
dOut
->
mutable_data
<
T
>
(
Out
->
dims
(),
ctx
.
GetPlace
());
if
(
ddOut
)
ddOut
->
mutable_data
<
T
>
(
Out
->
dims
(),
ctx
.
GetPlace
());
auto
&
place
=
ctx
.
template
device_context
<
DeviceContext
>();
Functor
functor
;
functor
(
place
,
Out
,
ddX
,
ddOut
,
dOut
,
dX
);
}
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
...
...
paddle/fluid/operators/activation_op.kps
浏览文件 @
191c441a
...
@@ -126,59 +126,6 @@ struct CudaSoftsignGradFunctor : public BaseActivationFunctor<T> {
...
@@ -126,59 +126,6 @@ struct CudaSoftsignGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
};
template <typename T>
struct CudaCELUFunctor : public BaseActivationFunctor<T> {
using CT = typename details::MPTypeTrait<T>::Type;
CT zero = static_cast<CT>(0.0f);
CT one = static_cast<CT>(1.0f);
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
// celu(x) = max(0, x) + min(0, alpha * (exp(x/alpha) - 1))
__device__ __forceinline__ T operator()(const T arg_x) const {
CT x = static_cast<CT>(arg_x);
CT temp = static_cast<CT>(alpha) * (exp(x / static_cast<CT>(alpha)) - one);
CT res = (x > zero ? x : zero) + (temp > zero ? zero : temp);
return static_cast<T>(res);
}
};
template <typename T>
struct CudaCELUGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType zero = static_cast<MPType>(0.0f);
MPType one = static_cast<MPType>(1.0f);
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
// dx = dout, if alpha > 0 and x > 0
// dx = dout * (x/alpha).exp(), if alpha > 0 and x <= 0
// dx = dout , if alpha < 0 and x > 0
// dx = dout * (x/alpha).exp(), if alpha < 0 and x <=0
__device__ __forceinline__ T operator()(const T arg_dout,
const T arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
MPType a = static_cast<MPType>(alpha);
MPType temp_a_pos = static_cast<MPType>(alpha > 0.0f);
MPType temp_a_neg = static_cast<MPType>(alpha <= 0.0f);
MPType temp_x_pos = static_cast<MPType>(x > zero);
MPType temp_x_neg = static_cast<MPType>(x <= zero);
return static_cast<T>(
dout *
(temp_a_pos * temp_x_pos + temp_a_pos * temp_x_neg * exp(x / a) +
temp_a_neg * temp_x_pos + exp(x / a) * temp_a_neg * temp_x_neg));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename DeviceContext, typename Functor>
template <typename DeviceContext, typename Functor>
class ActivationCudaKernel
class ActivationCudaKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
...
@@ -357,79 +304,35 @@ namespace plat = paddle::platform;
...
@@ -357,79 +304,35 @@ namespace plat = paddle::platform;
ops::ActivationGradCudaKernel<plat::CUDADeviceContext, \
ops::ActivationGradCudaKernel<plat::CUDADeviceContext, \
ops::grad_functor<plat::bfloat16>>);
ops::grad_functor<plat::bfloat16>>);
/* ========================================================================== */
/* ======================== celu register ============================ */
REGISTER_ACTIVATION_CUDA_KERNEL(celu, CELU, CudaCELUFunctor,
CudaCELUGradFunctor);
REGISTER_OP_CUDA_KERNEL(
REGISTER_OP_CUDA_KERNEL(
celu_grad_grad, ops::CELUDoubleGradKernel<plat::CUDADeviceContext,
relu6, ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,
ops::CELUGradGradFunctor<float>>,
ops::CudaRelu6Functor<float>>,
ops::CELUDoubleGradKernel<plat::CUDADeviceContext,
ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,
ops::CELUGradGradFunctor<double>>,
ops::CudaRelu6Functor<double>>,
ops::CELUDoubleGradKernel<plat::CUDADeviceContext,
ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,
ops::CELUGradGradFunctor<plat::float16>>);
ops::CudaRelu6Functor<int>>,
/* ========================================================================== */
ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,
ops::CudaRelu6Functor<int64_t>>,
/* =========================== sqrt register ============================= */
ops::ActivationCudaKernel<plat::CUDADeviceContext,
ops::CudaRelu6Functor<plat::float16>>,
REGISTER_OP_CUDA_KERNEL(
ops::ActivationCudaKernel<plat::CUDADeviceContext,
sqrt_grad_grad,
ops::CudaRelu6Functor<plat::bfloat16>>);
ops::SqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::SqrtGradGradFunctor<float>>,
ops::SqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::SqrtGradGradFunctor<double>>,
ops::SqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::SqrtGradGradFunctor<plat::float16>>,
ops::SqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::SqrtGradGradFunctor<plat::bfloat16>>);
/* ========================================================================== */
/* =========================== rsqrt register =============================
*/
REGISTER_OP_CUDA_KERNEL(
REGISTER_OP_CUDA_KERNEL(
rsqrt_grad_grad,
relu6_grad, ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
ops::RsqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::CudaRelu6GradFunctor<float>>,
ops::RsqrtGradGradFunctor<float>>,
ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
ops::RsqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::CudaRelu6GradFunctor<double>>,
ops::RsqrtGradGradFunctor<double>>,
ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
ops::RsqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::CudaRelu6GradFunctor<int>>,
ops::RsqrtGradGradFunctor<plat::float16>>);
ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
/* ========================================================================== */
ops::CudaRelu6GradFunctor<int64_t>>,
ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
/* =========================== square register ============================ */
ops::CudaRelu6GradFunctor<plat::float16>>,
ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
REGISTER_OP_CUDA_KERNEL(
ops::CudaRelu6GradFunctor<plat::bfloat16>>);
square_grad_grad,
ops::SquareDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::SquareGradGradFunctor<float>>,
ops::SquareDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::SquareGradGradFunctor<double>>,
ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
ops::SquareGradGradFunctor<plat::float16>>,
ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
ops::SquareGradGradFunctor<plat::bfloat16>>,
ops::SquareDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::SquareGradGradFunctor<int>>,
ops::SquareDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::SquareGradGradFunctor<int64_t>>);
/* ========================================================================== */
/* ========================== logit register ============================ */
namespace ops = paddle::operators;
/* ========================================================================== */
/* ========================== exp register ============================ */
/* ========================================================================== */
/* ========================== expm1 register ============================ */
/* ========================================================================== */
#define FOR_EACH_ACTIVATION_CUDA_OP(__macro) \
#define FOR_EACH_ACTIVATION_CUDA_OP(__macro) \
__macro(soft_relu, SoftRelu, CudaSoftReluFunctor, CudaSoftReluGradFunctor); \
__macro(soft_relu, SoftRelu, CudaSoftReluFunctor, CudaSoftReluGradFunctor); \
__macro(relu6, Relu6, CudaRelu6Functor, CudaRelu6GradFunctor); \
__macro(softsign, Softsign, CudaSoftsignFunctor, CudaSoftsignGradFunctor);
__macro(softsign, Softsign, CudaSoftsignFunctor, CudaSoftsignGradFunctor);
FOR_EACH_ACTIVATION_CUDA_OP(REGISTER_ACTIVATION_CUDA_KERNEL)
FOR_EACH_ACTIVATION_CUDA_OP(REGISTER_ACTIVATION_CUDA_KERNEL)
...
@@ -452,13 +355,14 @@ REGISTER_OP_KERNEL(
...
@@ -452,13 +355,14 @@ REGISTER_OP_KERNEL(
ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
ops::CudaZeroGradFunctor<float>>);
ops::CudaZeroGradFunctor<float>>);
REGISTER_OP_KERNEL(celu, KP, plat::XPUPlace,
REGISTER_OP_KERNEL(
ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
celu, KP, plat::XPUPlace,
ops::CudaCELUFunctor<float>>);
ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
phi::funcs::CudaCELUFunctor<float>>);
REGISTER_OP_KERNEL(
REGISTER_OP_KERNEL(
celu_grad, KP, plat::XPUPlace,
celu_grad, KP, plat::XPUPlace,
ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
op
s::CudaCELUGradFunctor<float>>);
phi::func
s::CudaCELUGradFunctor<float>>);
REGISTER_OP_KERNEL(elu, KP, plat::XPUPlace,
REGISTER_OP_KERNEL(elu, KP, plat::XPUPlace,
ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
...
...
paddle/phi/kernels/activation_grad_kernel.h
浏览文件 @
191c441a
...
@@ -150,6 +150,39 @@ void LogDoubleGradKernel(const Context& dev_ctx,
...
@@ -150,6 +150,39 @@ void LogDoubleGradKernel(const Context& dev_ctx,
DenseTensor
*
dx
,
DenseTensor
*
dx
,
DenseTensor
*
ddout
);
DenseTensor
*
ddout
);
template
<
typename
T
,
typename
Context
>
void
SqrtDoubleGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
out
,
const
DenseTensor
&
dx
,
const
DenseTensor
&
ddx
,
DenseTensor
*
dout
,
DenseTensor
*
ddout
);
template
<
typename
T
,
typename
Context
>
void
RsqrtDoubleGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
out
,
const
DenseTensor
&
dx
,
const
DenseTensor
&
ddx
,
DenseTensor
*
dout
,
DenseTensor
*
ddout
);
template
<
typename
T
,
typename
Context
>
void
CeluDoubleGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
dout
,
const
DenseTensor
&
ddx
,
float
alpha
,
DenseTensor
*
dx
,
DenseTensor
*
ddout
);
template
<
typename
T
,
typename
Context
>
void
SquareDoubleGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
dout
,
const
DenseTensor
&
ddx
,
DenseTensor
*
dx
,
DenseTensor
*
ddout
);
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
void
HardSwishGradKernel
(
const
Context
&
dev_ctx
,
void
HardSwishGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
x
,
...
@@ -200,6 +233,7 @@ DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(SoftShrink, lambda);
...
@@ -200,6 +233,7 @@ DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(SoftShrink, lambda);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX
(
HardShrink
,
threshold
);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX
(
HardShrink
,
threshold
);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX
(
Swish
,
beta
);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX
(
Swish
,
beta
);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX
(
Logit
,
eps
);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX
(
Logit
,
eps
);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX
(
Celu
,
alpha
);
DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX
(
BRelu
,
t_min
,
t_max
);
DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX
(
BRelu
,
t_min
,
t_max
);
...
...
paddle/phi/kernels/activation_kernel.h
浏览文件 @
191c441a
...
@@ -78,6 +78,7 @@ DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(SoftShrink, lambda)
...
@@ -78,6 +78,7 @@ DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(SoftShrink, lambda)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS
(
HardShrink
,
threshold
)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS
(
HardShrink
,
threshold
)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS
(
Elu
,
alpha
)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS
(
Elu
,
alpha
)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS
(
Swish
,
beta
)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS
(
Swish
,
beta
)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS
(
celu
,
alpha
)
DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS
(
BRelu
,
t_min
,
t_max
)
DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS
(
BRelu
,
t_min
,
t_max
)
DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS
(
STanh
,
scale_a
,
scale_b
)
DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS
(
STanh
,
scale_a
,
scale_b
)
...
...
paddle/phi/kernels/cpu/activation_grad_kernel.cc
浏览文件 @
191c441a
...
@@ -167,6 +167,7 @@ DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Swish, SwishGradFunctor, beta);
...
@@ -167,6 +167,7 @@ DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Swish, SwishGradFunctor, beta);
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX
(
Mish
,
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX
(
Mish
,
MishGradFunctor
,
MishGradFunctor
,
threshold
);
threshold
);
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX
(
Celu
,
CELUGradFunctor
,
alpha
);
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX
(
BRelu
,
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX
(
BRelu
,
BReluGradFunctor
,
BReluGradFunctor
,
...
@@ -281,6 +282,10 @@ PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(tanh_double_grad,
...
@@ -281,6 +282,10 @@ PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(tanh_double_grad,
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL
(
leaky_relu_double_grad
,
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL
(
leaky_relu_double_grad
,
LeakyReluDoubleGradKernel
)
LeakyReluDoubleGradKernel
)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL
(
elu_double_grad
,
EluDoubleGradKernel
)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL
(
elu_double_grad
,
EluDoubleGradKernel
)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL
(
sqrt_double_grad
,
SqrtDoubleGradKernel
)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL
(
rsqrt_double_grad
,
RsqrtDoubleGradKernel
)
PD_REGISTER_KERNEL
(
tanh_triple_grad
,
PD_REGISTER_KERNEL
(
tanh_triple_grad
,
CPU
,
CPU
,
...
@@ -317,6 +322,15 @@ PD_REGISTER_KERNEL(square_grad,
...
@@ -317,6 +322,15 @@ PD_REGISTER_KERNEL(square_grad,
double
,
double
,
int
,
int
,
int64_t
)
{}
int64_t
)
{}
PD_REGISTER_KERNEL
(
square_double_grad
,
CPU
,
ALL_LAYOUT
,
phi
::
SquareDoubleGradKernel
,
float
,
double
,
phi
::
dtype
::
float16
,
int
,
int64_t
)
{}
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
sigmoid_grad
,
SigmoidGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
sigmoid_grad
,
SigmoidGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
sigmoid_double_grad
,
SigmoidDoubleGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
sigmoid_double_grad
,
SigmoidDoubleGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
sigmoid_triple_grad
,
SigmoidTripleGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
sigmoid_triple_grad
,
SigmoidTripleGradKernel
)
...
@@ -332,6 +346,9 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)
...
@@ -332,6 +346,9 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
round_grad
,
RoundGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
round_grad
,
RoundGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
floor_grad
,
FloorGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
floor_grad
,
FloorGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
ceil_grad
,
CeilGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
ceil_grad
,
CeilGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
celu_grad
,
CeluGradKernel
)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL
(
celu_double_grad
,
CeluDoubleGradKernel
)
PD_REGISTER_KERNEL
(
pow_grad
,
PD_REGISTER_KERNEL
(
pow_grad
,
CPU
,
CPU
,
...
...
paddle/phi/kernels/cpu/activation_kernel.cc
浏览文件 @
191c441a
...
@@ -90,19 +90,19 @@ DEFINE_CPU_ACTIVATION_KERNEL(Floor, FloorFunctor)
...
@@ -90,19 +90,19 @@ DEFINE_CPU_ACTIVATION_KERNEL(Floor, FloorFunctor)
DEFINE_CPU_ACTIVATION_KERNEL
(
Ceil
,
CeilFunctor
)
DEFINE_CPU_ACTIVATION_KERNEL
(
Ceil
,
CeilFunctor
)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS
(
LeakyRelu
,
LeakyReluFunctor
,
alpha
)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS
(
LeakyRelu
,
LeakyReluFunctor
,
alpha
)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS
(
ThresholdedRelu
,
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS
(
ThresholdedRelu
,
ThresholdedReluFunctor
,
ThresholdedReluFunctor
,
threshold
)
threshold
)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS
(
Mish
,
MishFunctor
,
threshold
)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS
(
Mish
,
MishFunctor
,
threshold
)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS
(
BRelu
,
BReluFunctor
,
t_min
,
t_max
)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS
(
STanh
,
STanhFunctor
,
scale_a
,
scale_b
)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS
(
Softplus
,
SoftplusFunctor
,
beta
,
threshold
)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS
(
HardShrink
,
HardShrinkFunctor
,
threshold
)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS
(
HardShrink
,
HardShrinkFunctor
,
threshold
)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS
(
SoftShrink
,
SoftShrinkFunctor
,
lambda
)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS
(
SoftShrink
,
SoftShrinkFunctor
,
lambda
)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS
(
Elu
,
ELUFunctor
,
alpha
)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS
(
Elu
,
ELUFunctor
,
alpha
)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS
(
Swish
,
SwishFunctor
,
beta
)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS
(
Swish
,
SwishFunctor
,
beta
)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS
(
Celu
,
CELUFunctor
,
alpha
)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS
(
BRelu
,
BReluFunctor
,
t_min
,
t_max
)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS
(
STanh
,
STanhFunctor
,
scale_a
,
scale_b
)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS
(
Softplus
,
SoftplusFunctor
,
beta
,
threshold
)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS
(
HardSigmoid
,
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS
(
HardSigmoid
,
HardSigmoidFunctor
,
HardSigmoidFunctor
,
slope
,
slope
,
...
@@ -181,5 +181,6 @@ PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel)
...
@@ -181,5 +181,6 @@ PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel)
PD_REGISTER_ACTIVATION_KERNEL
(
round
,
RoundKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
round
,
RoundKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
floor
,
FloorKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
floor
,
FloorKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
ceil
,
CeilKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
ceil
,
CeilKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
celu
,
CeluKernel
)
PD_REGISTER_KERNEL
(
PD_REGISTER_KERNEL
(
pow
,
CPU
,
ALL_LAYOUT
,
phi
::
PowKernel
,
float
,
double
,
int
,
int64_t
)
{}
pow
,
CPU
,
ALL_LAYOUT
,
phi
::
PowKernel
,
float
,
double
,
int
,
int64_t
)
{}
paddle/phi/kernels/funcs/activation_functor.h
浏览文件 @
191c441a
...
@@ -1832,6 +1832,196 @@ struct ZeroGradFunctor : public BaseActivationFunctor<T> {
...
@@ -1832,6 +1832,196 @@ struct ZeroGradFunctor : public BaseActivationFunctor<T> {
}
}
};
};
template
<
typename
T
>
struct
SqrtGradGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
>
void
operator
()(
const
Device
&
dev
,
const
DenseTensor
*
Out
,
const
DenseTensor
*
dX
,
const
DenseTensor
*
ddX
,
DenseTensor
*
dOut
,
DenseTensor
*
ddOut
)
const
{
auto
*
d
=
dev
.
eigen_device
();
auto
ddx
=
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
ddX
,
"Input"
,
"DDX"
,
"SqrtGradGrad"
));
auto
out
=
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
Out
,
"Output"
,
"Out"
,
"SqrtGradGrad"
));
// sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx
// calculate dy first, so ddy can inplace ddx
if
(
dOut
)
{
auto
dx
=
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
dX
,
"Output"
,
"DX"
,
"SqrtGradGrad"
));
auto
dout
=
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
dOut
,
"Output"
,
"DOut"
,
"SqrtGradGrad"
));
dout
.
device
(
*
d
)
=
dx
*
ddx
*
static_cast
<
T
>
(
-
1
)
/
out
;
}
if
(
ddOut
)
{
auto
ddout
=
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
ddOut
,
"Output"
,
"DDOut"
,
"SqrtGradGrad"
));
ddout
.
device
(
*
d
)
=
ddx
*
static_cast
<
T
>
(
0.5
)
/
out
;
}
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepOut
;
}
};
template
<
typename
T
>
struct
RsqrtGradGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
>
void
operator
()(
const
Device
&
dev
,
const
DenseTensor
*
Out
,
const
DenseTensor
*
dX
,
const
DenseTensor
*
ddX
,
DenseTensor
*
dOut
,
DenseTensor
*
ddOut
)
const
{
auto
*
d
=
dev
.
eigen_device
();
auto
ddx
=
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
ddX
,
"Input"
,
"DDX"
,
"RsqrtGradGrad"
));
auto
out
=
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
Out
,
"Output"
,
"Out"
,
"RsqrtGradGrad"
));
// rsqrt GradGrad: ddy = -0.5 * ddx * y * y * y, dy = (3/y) * dx * ddx
if
(
dOut
)
{
auto
dx
=
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
dX
,
"Output"
,
"DX"
,
"RsqrtGradGrad"
));
auto
dout
=
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
dOut
,
"Output"
,
"DOut"
,
"RsqrtGradGrad"
));
dout
.
device
(
*
d
)
=
(
static_cast
<
T
>
(
3.0
)
/
out
)
*
dx
*
ddx
;
}
if
(
ddOut
)
{
auto
ddout
=
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
ddOut
,
"Output"
,
"DDOut"
,
"RsqrtGradGrad"
));
ddout
.
device
(
*
d
)
=
ddx
*
static_cast
<
T
>
(
-
0.5
)
*
out
*
out
*
out
;
}
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepOut
;
}
};
template
<
typename
T
>
struct
CELUFunctor
:
public
BaseActivationFunctor
<
T
>
{
float
alpha
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"alpha"
,
&
alpha
}};
}
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
const
{
out
.
device
(
d
)
=
(
x
<
static_cast
<
T
>
(
0
))
.
select
(
static_cast
<
T
>
(
alpha
)
*
((
x
/
static_cast
<
T
>
(
alpha
)).
exp
()
-
static_cast
<
T
>
(
1
)),
x
);
}
};
template
<
typename
T
>
struct
CELUGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
float
alpha
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"alpha"
,
&
alpha
}};
}
template
<
typename
Device
,
typename
X
,
typename
Out
,
typename
dOut
,
typename
dX
>
void
operator
()(
Device
d
,
X
x
,
Out
out
,
dOut
dout
,
dX
dx
)
const
{
auto
temp_a_pos
=
static_cast
<
T
>
(
alpha
>
0
);
auto
temp_a_neg
=
static_cast
<
T
>
(
alpha
<=
0
);
auto
temp_x_pos
=
(
x
>
static_cast
<
T
>
(
0
)).
template
cast
<
T
>();
auto
temp_x_neg
=
(
x
<=
static_cast
<
T
>
(
0
)).
template
cast
<
T
>();
// dx = dout, if alpha > 0 and x > 0
// dx = dout * (x/alpha).exp(), if alpha > 0 and x <= 0
// dx = dout , if alpha < 0 and x > 0
// dx = dout * (x/alpha).exp(), if alpha < 0 and x <=0
dx
.
device
(
d
)
=
dout
*
temp_a_pos
*
temp_x_pos
+
dout
*
(
x
/
static_cast
<
T
>
(
alpha
)).
exp
()
*
temp_a_pos
*
temp_x_neg
+
dout
*
temp_a_neg
*
temp_x_pos
+
dout
*
(
x
/
static_cast
<
T
>
(
alpha
)).
exp
()
*
temp_a_neg
*
temp_x_neg
;
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
};
template
<
typename
T
>
struct
CELUGradGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
float
alpha
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"alpha"
,
&
alpha
}};
}
template
<
typename
Device
>
void
operator
()(
const
Device
&
dev
,
const
DenseTensor
*
X
,
const
DenseTensor
*
dOut
,
const
DenseTensor
*
ddX
,
DenseTensor
*
dX
,
DenseTensor
*
ddOut
)
const
{
auto
*
d
=
dev
.
eigen_device
();
auto
ddx
=
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
ddX
,
"Input"
,
"DDX"
,
"CELUGradGrad"
));
auto
x
=
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
X
,
"Input"
,
"X"
,
"CELUGradGrad"
));
if
(
dX
)
{
auto
dx
=
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
dX
,
"Output"
,
"DX"
,
"CELUGradGrad"
));
auto
dout
=
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
dOut
,
"Output"
,
"DOut"
,
"CELUGradGrad"
));
dx
.
device
(
*
d
)
=
ddx
*
dout
/
static_cast
<
T
>
(
alpha
)
*
(
x
/
static_cast
<
T
>
(
alpha
)).
exp
()
*
(
x
<=
static_cast
<
T
>
(
0
)).
template
cast
<
T
>();
}
if
(
ddOut
)
{
auto
ddout
=
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
ddOut
,
"Output"
,
"DDOut"
,
"CELUGradGrad"
));
ddout
.
device
(
*
d
)
=
ddx
*
((
x
>
static_cast
<
T
>
(
0
)).
template
cast
<
T
>()
+
(
x
/
static_cast
<
T
>
(
alpha
)).
exp
()
*
(
x
<=
static_cast
<
T
>
(
0
)).
template
cast
<
T
>())
.
template
cast
<
T
>();
}
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
};
template
<
typename
T
>
struct
SquareGradGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
>
void
operator
()(
const
Device
&
dev
,
const
DenseTensor
*
X
,
const
DenseTensor
*
dOut
,
const
DenseTensor
*
ddX
,
DenseTensor
*
dX
,
DenseTensor
*
ddOut
)
const
{
auto
*
d
=
dev
.
eigen_device
();
auto
ddx
=
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
ddX
,
"Input"
,
"DDX"
,
"SquareGradGrad"
));
auto
x
=
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
X
,
"Input"
,
"X"
,
"SquareGradGrad"
));
// square GradGrad: ddy=2x*ddx, dx=2dy*ddx
// calculate dx first, so ddy can inplace ddx
if
(
dX
)
{
auto
dx
=
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
dX
,
"Output"
,
"DX"
,
"SquareGradGrad"
));
auto
dout
=
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
dOut
,
"Output"
,
"DOut"
,
"SquareGradGrad"
));
dx
.
device
(
*
d
)
=
ddx
*
static_cast
<
T
>
(
2
)
*
dout
;
}
if
(
ddOut
)
{
auto
ddout
=
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
ddOut
,
"Output"
,
"DDOut"
,
"SquareGradGrad"
));
ddout
.
device
(
*
d
)
=
ddx
*
static_cast
<
T
>
(
2
)
*
x
;
}
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
};
#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
template
<
typename
T
>
template
<
typename
T
>
struct
CudaReluFunctor
:
public
BaseActivationFunctor
<
T
>
{
struct
CudaReluFunctor
:
public
BaseActivationFunctor
<
T
>
{
...
@@ -3091,6 +3281,59 @@ struct CudaZeroGradFunctor : public BaseActivationFunctor<T> {
...
@@ -3091,6 +3281,59 @@ struct CudaZeroGradFunctor : public BaseActivationFunctor<T> {
}
}
};
};
template
<
typename
T
>
struct
CudaCELUFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
CT
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
CT
zero
=
static_cast
<
CT
>
(
0.0
f
);
CT
one
=
static_cast
<
CT
>
(
1.0
f
);
float
alpha
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"alpha"
,
&
alpha
}};
}
// celu(x) = max(0, x) + min(0, alpha * (exp(x/alpha) - 1))
__device__
__forceinline__
T
operator
()(
const
T
arg_x
)
const
{
CT
x
=
static_cast
<
CT
>
(
arg_x
);
CT
temp
=
static_cast
<
CT
>
(
alpha
)
*
(
exp
(
x
/
static_cast
<
CT
>
(
alpha
))
-
one
);
CT
res
=
(
x
>
zero
?
x
:
zero
)
+
(
temp
>
zero
?
zero
:
temp
);
return
static_cast
<
T
>
(
res
);
}
};
template
<
typename
T
>
struct
CudaCELUGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
MPType
zero
=
static_cast
<
MPType
>
(
0.0
f
);
MPType
one
=
static_cast
<
MPType
>
(
1.0
f
);
float
alpha
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"alpha"
,
&
alpha
}};
}
// dx = dout, if alpha > 0 and x > 0
// dx = dout * (x/alpha).exp(), if alpha > 0 and x <= 0
// dx = dout , if alpha < 0 and x > 0
// dx = dout * (x/alpha).exp(), if alpha < 0 and x <=0
__device__
__forceinline__
T
operator
()(
const
T
arg_dout
,
const
T
arg_x
)
const
{
MPType
dout
=
static_cast
<
MPType
>
(
arg_dout
);
MPType
x
=
static_cast
<
MPType
>
(
arg_x
);
MPType
a
=
static_cast
<
MPType
>
(
alpha
);
MPType
temp_a_pos
=
static_cast
<
MPType
>
(
alpha
>
0.0
f
);
MPType
temp_a_neg
=
static_cast
<
MPType
>
(
alpha
<=
0.0
f
);
MPType
temp_x_pos
=
static_cast
<
MPType
>
(
x
>
zero
);
MPType
temp_x_neg
=
static_cast
<
MPType
>
(
x
<=
zero
);
return
static_cast
<
T
>
(
dout
*
(
temp_a_pos
*
temp_x_pos
+
temp_a_pos
*
temp_x_neg
*
exp
(
x
/
a
)
+
temp_a_neg
*
temp_x_pos
+
exp
(
x
/
a
)
*
temp_a_neg
*
temp_x_neg
));
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
};
#endif
#endif
}
// namespace funcs
}
// namespace funcs
...
...
paddle/phi/kernels/gpu/activation_grad_kernel.cu
浏览文件 @
191c441a
...
@@ -221,6 +221,9 @@ DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Swish,
...
@@ -221,6 +221,9 @@ DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Swish,
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX
(
Mish
,
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX
(
Mish
,
CudaMishGradFunctor
,
CudaMishGradFunctor
,
threshold
);
threshold
);
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX
(
Celu
,
CudaCELUGradFunctor
,
alpha
);
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX
(
BRelu
,
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX
(
BRelu
,
CudaBReluGradFunctor
,
CudaBReluGradFunctor
,
...
@@ -351,7 +354,9 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(stanh_grad, STanhGradKernel)
...
@@ -351,7 +354,9 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(stanh_grad, STanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
reciprocal_grad
,
ReciprocalGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
reciprocal_grad
,
ReciprocalGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
softplus_grad
,
SoftplusGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
softplus_grad
,
SoftplusGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
sqrt_grad
,
SqrtGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
sqrt_grad
,
SqrtGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
sqrt_double_grad
,
SqrtDoubleGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
rsqrt_grad
,
RsqrtGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
rsqrt_grad
,
RsqrtGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
rsqrt_double_grad
,
RsqrtDoubleGradKernel
)
PD_REGISTER_KERNEL
(
exp_grad
,
PD_REGISTER_KERNEL
(
exp_grad
,
GPU
,
GPU
,
...
@@ -396,6 +401,16 @@ PD_REGISTER_KERNEL(square_grad,
...
@@ -396,6 +401,16 @@ PD_REGISTER_KERNEL(square_grad,
int64_t
,
int64_t
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{}
phi
::
dtype
::
bfloat16
)
{}
PD_REGISTER_KERNEL
(
square_double_grad
,
GPU
,
ALL_LAYOUT
,
phi
::
SquareDoubleGradKernel
,
float
,
double
,
int
,
int64_t
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{}
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
sigmoid_grad
,
SigmoidGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
sigmoid_grad
,
SigmoidGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
sigmoid_double_grad
,
SigmoidDoubleGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
sigmoid_double_grad
,
SigmoidDoubleGradKernel
)
...
@@ -418,6 +433,8 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)
...
@@ -418,6 +433,8 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
round_grad
,
RoundGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
round_grad
,
RoundGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
floor_grad
,
FloorGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
floor_grad
,
FloorGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
ceil_grad
,
CeilGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
ceil_grad
,
CeilGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
celu_grad
,
CeluGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
celu_double_grad
,
CeluDoubleGradKernel
)
PD_REGISTER_KERNEL
(
pow_grad
,
PD_REGISTER_KERNEL
(
pow_grad
,
GPU
,
GPU
,
...
...
paddle/phi/kernels/gpu/activation_kernel.cu
浏览文件 @
191c441a
...
@@ -118,8 +118,8 @@ DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink,
...
@@ -118,8 +118,8 @@ DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink,
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS
(
SoftShrink
,
CudaSoftShrinkFunctor
,
lambda
)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS
(
SoftShrink
,
CudaSoftShrinkFunctor
,
lambda
)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS
(
Elu
,
CudaELUFunctor
,
alpha
)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS
(
Elu
,
CudaELUFunctor
,
alpha
)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS
(
Swish
,
CudaSwishFunctor
,
beta
)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS
(
Swish
,
CudaSwishFunctor
,
beta
)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS
(
Mish
,
CudaMishFunctor
,
threshold
)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS
(
Mish
,
CudaMishFunctor
,
threshold
)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS
(
Celu
,
CudaCELUFunctor
,
alpha
)
DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS
(
BRelu
,
CudaBReluFunctor
,
t_min
,
t_max
)
DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS
(
BRelu
,
CudaBReluFunctor
,
t_min
,
t_max
)
DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS
(
Stanh
,
CudaSTanhFunctor
,
scale_a
,
scale_b
)
DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS
(
Stanh
,
CudaSTanhFunctor
,
scale_a
,
scale_b
)
...
@@ -234,6 +234,7 @@ PD_REGISTER_KERNEL(square,
...
@@ -234,6 +234,7 @@ PD_REGISTER_KERNEL(square,
int64_t
,
int64_t
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{}
phi
::
dtype
::
bfloat16
)
{}
PD_REGISTER_ACTIVATION_KERNEL
(
hard_shrink
,
HardShrinkKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
hard_shrink
,
HardShrinkKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
soft_shrink
,
SoftShrinkKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
soft_shrink
,
SoftShrinkKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
tanh_shrink
,
TanhShrinkKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
tanh_shrink
,
TanhShrinkKernel
)
...
@@ -251,6 +252,7 @@ PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel)
...
@@ -251,6 +252,7 @@ PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel)
PD_REGISTER_ACTIVATION_KERNEL
(
round
,
RoundKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
round
,
RoundKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
floor
,
FloorKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
floor
,
FloorKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
ceil
,
CeilKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
ceil
,
CeilKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
celu
,
CeluKernel
)
PD_REGISTER_KERNEL
(
pow
,
PD_REGISTER_KERNEL
(
pow
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
...
...
paddle/phi/kernels/impl/activation_grad_impl.h
浏览文件 @
191c441a
...
@@ -335,4 +335,87 @@ void PowGradKernel(const Context& dev_ctx,
...
@@ -335,4 +335,87 @@ void PowGradKernel(const Context& dev_ctx,
functor
(
*
place
,
x_flatten
,
nullptr
,
dout_flatten
,
dx_flatten
);
functor
(
*
place
,
x_flatten
,
nullptr
,
dout_flatten
,
dx_flatten
);
}
}
template
<
typename
T
,
typename
Context
>
void
SqrtDoubleGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
out
,
const
DenseTensor
&
dx
,
const
DenseTensor
&
ddx
,
DenseTensor
*
dout
,
DenseTensor
*
ddout
)
{
if
(
dout
)
{
dout
->
Resize
(
out
.
dims
());
dev_ctx
.
template
Alloc
<
T
>(
dout
);
}
if
(
ddout
)
{
ddout
->
Resize
(
out
.
dims
());
dev_ctx
.
template
Alloc
<
T
>(
ddout
);
}
phi
::
funcs
::
SqrtGradGradFunctor
<
T
>
functor
;
functor
(
dev_ctx
,
&
out
,
&
dx
,
&
ddx
,
dout
,
ddout
);
}
// rsqrt Grad: dx = -0.5 * dy * y * y * y
// rsqrt GradGrad: ddy = -0.5 * ddx * y * y * y, dy = (3 / y) * dx * ddx
template
<
typename
T
,
typename
Context
>
void
RsqrtDoubleGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
out
,
const
DenseTensor
&
dx
,
const
DenseTensor
&
ddx
,
DenseTensor
*
dout
,
DenseTensor
*
ddout
)
{
if
(
dout
)
{
dout
->
Resize
(
out
.
dims
());
dev_ctx
.
template
Alloc
<
T
>(
dout
);
}
if
(
ddout
)
{
ddout
->
Resize
(
out
.
dims
());
dev_ctx
.
template
Alloc
<
T
>(
ddout
);
}
phi
::
funcs
::
RsqrtGradGradFunctor
<
T
>
functor
;
functor
(
dev_ctx
,
&
out
,
&
dx
,
&
ddx
,
dout
,
ddout
);
}
template
<
typename
T
,
typename
Context
>
void
CeluDoubleGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
dout
,
const
DenseTensor
&
ddx
,
float
alpha
,
DenseTensor
*
dx
,
DenseTensor
*
ddout
)
{
if
(
dx
)
{
dx
->
Resize
(
x
.
dims
());
dev_ctx
.
template
Alloc
<
T
>(
dx
);
}
if
(
ddout
)
{
dev_ctx
.
template
Alloc
<
T
>(
ddout
);
}
phi
::
funcs
::
CELUGradGradFunctor
<
T
>
functor
;
auto
attrs
=
functor
.
GetAttrs
();
*
(
attrs
[
0
].
second
)
=
alpha
;
functor
(
dev_ctx
,
&
x
,
&
dout
,
&
ddx
,
dx
,
ddout
);
}
template
<
typename
T
,
typename
Context
>
void
SquareDoubleGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
dout
,
const
DenseTensor
&
ddx
,
DenseTensor
*
dx
,
DenseTensor
*
ddout
)
{
if
(
dx
)
{
dx
->
Resize
(
x
.
dims
());
dev_ctx
.
template
Alloc
<
T
>(
dx
);
}
if
(
ddout
)
{
dev_ctx
.
template
Alloc
<
T
>(
ddout
);
}
phi
::
funcs
::
SquareGradGradFunctor
<
T
>
functor
;
functor
(
dev_ctx
,
&
x
,
&
dout
,
&
ddx
,
dx
,
ddout
);
}
}
// namespace phi
}
// namespace phi
paddle/phi/ops/compat/activation_sig.cc
浏览文件 @
191c441a
...
@@ -67,6 +67,7 @@ DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Log, "log", ); // NOLINT
...
@@ -67,6 +67,7 @@ DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Log, "log", ); // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP
(
Log2
,
"log2"
,
);
// NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP
(
Log2
,
"log2"
,
);
// NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP
(
Log10
,
"log10"
,
);
// NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP
(
Log10
,
"log10"
,
);
// NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP
(
Log1p
,
"log1p"
,
);
// NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP
(
Log1p
,
"log1p"
,
);
// NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP
(
Celu
,
"celu"
,
"alpha"
);
// NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP
(
HardSwish
,
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP
(
HardSwish
,
"hard_swish"
,
"hard_swish"
,
"threshold"
comma
"scale"
comma
"threshold"
comma
"scale"
comma
...
@@ -181,6 +182,30 @@ KernelSignature LogDoubleGradOpArgumentMapping(
...
@@ -181,6 +182,30 @@ KernelSignature LogDoubleGradOpArgumentMapping(
"log_double_grad"
,
{
"X"
,
"DOut"
,
"DDX"
},
{},
{
"DX"
,
"DDOut"
});
"log_double_grad"
,
{
"X"
,
"DOut"
,
"DDX"
},
{},
{
"DX"
,
"DDOut"
});
}
}
KernelSignature
SqrtDoubleGradOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
return
KernelSignature
(
"sqrt_double_grad"
,
{
"Out"
,
"DX"
,
"DDX"
},
{},
{
"DOut"
,
"DDOut"
});
}
KernelSignature
RsqrtDoubleGradOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
return
KernelSignature
(
"rsqrt_double_grad"
,
{
"Out"
,
"DX"
,
"DDX"
},
{},
{
"DOut"
,
"DDOut"
});
}
KernelSignature
CeluDoubleGradOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
return
KernelSignature
(
"celu_double_grad"
,
{
"X"
,
"DOut"
,
"DDX"
},
{
"alpha"
},
{
"DX"
,
"DDOut"
});
}
KernelSignature
SquareDoubleGradOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
return
KernelSignature
(
"square_double_grad"
,
{
"X"
,
"DOut"
,
"DDX"
},
{},
{
"DX"
,
"DDOut"
});
}
KernelSignature
PowOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
KernelSignature
PowOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
if
(
ctx
.
HasInput
(
"FactorTensor"
))
{
if
(
ctx
.
HasInput
(
"FactorTensor"
))
{
return
KernelSignature
(
"pow"
,
{
"X"
},
{
"FactorTensor"
},
{
"Out"
});
return
KernelSignature
(
"pow"
,
{
"X"
},
{
"FactorTensor"
},
{
"Out"
});
...
@@ -209,6 +234,10 @@ PD_REGISTER_BASE_KERNEL_NAME(softshrink_grad, soft_shrink_grad);
...
@@ -209,6 +234,10 @@ PD_REGISTER_BASE_KERNEL_NAME(softshrink_grad, soft_shrink_grad);
PD_REGISTER_BASE_KERNEL_NAME
(
elu_grad_grad
,
elu_double_grad
);
PD_REGISTER_BASE_KERNEL_NAME
(
elu_grad_grad
,
elu_double_grad
);
PD_REGISTER_BASE_KERNEL_NAME
(
sigmoid_grad_grad
,
sigmoid_double_grad
);
PD_REGISTER_BASE_KERNEL_NAME
(
sigmoid_grad_grad
,
sigmoid_double_grad
);
PD_REGISTER_BASE_KERNEL_NAME
(
log_grad_grad
,
log_double_grad
);
PD_REGISTER_BASE_KERNEL_NAME
(
log_grad_grad
,
log_double_grad
);
PD_REGISTER_BASE_KERNEL_NAME
(
sqrt_grad_grad
,
sqrt_double_grad
);
PD_REGISTER_BASE_KERNEL_NAME
(
rsqrt_grad_grad
,
rsqrt_double_grad
);
PD_REGISTER_BASE_KERNEL_NAME
(
celu_grad_grad
,
celu_double_grad
);
PD_REGISTER_BASE_KERNEL_NAME
(
square_grad_grad
,
square_double_grad
);
PD_REGISTER_ARG_MAPPING_FN
(
cos_grad
,
phi
::
CosGradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
cos_grad
,
phi
::
CosGradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
tan_grad
,
phi
::
TanGradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
tan_grad
,
phi
::
TanGradOpArgumentMapping
);
...
@@ -229,7 +258,11 @@ PD_REGISTER_ARG_MAPPING_FN(square_grad, phi::SquareGradOpArgumentMapping);
...
@@ -229,7 +258,11 @@ PD_REGISTER_ARG_MAPPING_FN(square_grad, phi::SquareGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN
(
reciprocal_grad
,
PD_REGISTER_ARG_MAPPING_FN
(
reciprocal_grad
,
phi
::
ReciprocalGradOpArgumentMapping
);
phi
::
ReciprocalGradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
sqrt_grad
,
phi
::
SqrtGradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
sqrt_grad
,
phi
::
SqrtGradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
sqrt_grad_grad
,
phi
::
SqrtDoubleGradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
rsqrt_grad
,
phi
::
RsqrtGradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
rsqrt_grad
,
phi
::
RsqrtGradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
rsqrt_grad_grad
,
phi
::
RsqrtDoubleGradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
mish_grad
,
phi
::
MishGradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
mish_grad
,
phi
::
MishGradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
stanh_grad
,
phi
::
STanhGradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
stanh_grad
,
phi
::
STanhGradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
softplus_grad
,
phi
::
SoftplusGradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
softplus_grad
,
phi
::
SoftplusGradOpArgumentMapping
);
...
@@ -286,3 +319,8 @@ PD_REGISTER_ARG_MAPPING_FN(floor_grad, phi::FloorGradOpArgumentMapping);
...
@@ -286,3 +319,8 @@ PD_REGISTER_ARG_MAPPING_FN(floor_grad, phi::FloorGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN
(
ceil_grad
,
phi
::
CeilGradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
ceil_grad
,
phi
::
CeilGradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
pow_grad
,
phi
::
PowGradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
pow_grad
,
phi
::
PowGradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
pow
,
phi
::
PowOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
pow
,
phi
::
PowOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
celu_grad
,
phi
::
CeluGradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
celu_grad_grad
,
phi
::
CeluDoubleGradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
square_grad_grad
,
phi
::
SquareDoubleGradOpArgumentMapping
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录