Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
be5918e0
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
be5918e0
编写于
3月 25, 2022
作者:
Y
YuanRisheng
提交者:
GitHub
3月 25, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
move activation (#40913)
上级
c33b4f95
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
625 addition
and
434 deletion
+625
-434
paddle/fluid/operators/activation_op.cc
paddle/fluid/operators/activation_op.cc
+6
-12
paddle/fluid/operators/activation_op.h
paddle/fluid/operators/activation_op.h
+16
-257
paddle/fluid/operators/activation_op.kps
paddle/fluid/operators/activation_op.kps
+15
-165
paddle/phi/kernels/activation_grad_kernel.h
paddle/phi/kernels/activation_grad_kernel.h
+27
-0
paddle/phi/kernels/activation_kernel.h
paddle/phi/kernels/activation_kernel.h
+20
-0
paddle/phi/kernels/cpu/activation_grad_kernel.cc
paddle/phi/kernels/cpu/activation_grad_kernel.cc
+45
-0
paddle/phi/kernels/cpu/activation_kernel.cc
paddle/phi/kernels/cpu/activation_kernel.cc
+27
-0
paddle/phi/kernels/funcs/activation_functor.h
paddle/phi/kernels/funcs/activation_functor.h
+300
-0
paddle/phi/kernels/gpu/activation_grad_kernel.cu
paddle/phi/kernels/gpu/activation_grad_kernel.cu
+48
-0
paddle/phi/kernels/gpu/activation_kernel.cu
paddle/phi/kernels/gpu/activation_kernel.cu
+34
-0
paddle/phi/kernels/impl/activation_grad_impl.h
paddle/phi/kernels/impl/activation_grad_impl.h
+24
-0
paddle/phi/kernels/impl/activation_impl.h
paddle/phi/kernels/impl/activation_impl.h
+19
-0
paddle/phi/ops/compat/activation_sig.cc
paddle/phi/ops/compat/activation_sig.cc
+44
-0
未找到文件。
paddle/fluid/operators/activation_op.cc
浏览文件 @
be5918e0
...
@@ -1499,6 +1499,12 @@ REGISTER_ACTIVATION_OP(logsigmoid, LogSigmoid, LogSigmoidFunctor,
...
@@ -1499,6 +1499,12 @@ REGISTER_ACTIVATION_OP(logsigmoid, LogSigmoid, LogSigmoidFunctor,
REGISTER_ACTIVATION_OP
(
log2
,
Log2
,
Log2Functor
,
Log2GradFunctor
);
REGISTER_ACTIVATION_OP
(
log2
,
Log2
,
Log2Functor
,
Log2GradFunctor
);
REGISTER_ACTIVATION_OP
(
log10
,
Log10
,
Log10Functor
,
Log10GradFunctor
);
REGISTER_ACTIVATION_OP
(
log10
,
Log10
,
Log10Functor
,
Log10GradFunctor
);
REGISTER_ACTIVATION_OP
(
log1p
,
Log1p
,
Log1pFunctor
,
Log1pGradFunctor
);
REGISTER_ACTIVATION_OP
(
log1p
,
Log1p
,
Log1pFunctor
,
Log1pGradFunctor
);
REGISTER_ACTIVATION_OP
(
hard_swish
,
HardSwish
,
HardSwishFunctor
,
HardSwishGradFunctor
);
REGISTER_ACTIVATION_OP
(
swish
,
Swish
,
SwishFunctor
,
SwishGradFunctor
);
REGISTER_ACTIVATION_OP
(
round
,
Round
,
RoundFunctor
,
ZeroGradFunctor
);
REGISTER_ACTIVATION_OP
(
floor
,
Floor
,
FloorFunctor
,
ZeroGradFunctor
);
REGISTER_ACTIVATION_OP
(
ceil
,
Ceil
,
CeilFunctor
,
ZeroGradFunctor
);
/* ========================== sigmoid register =============================
/* ========================== sigmoid register =============================
*/
*/
...
@@ -1778,18 +1784,6 @@ REGISTER_OPERATOR(
...
@@ -1778,18 +1784,6 @@ REGISTER_OPERATOR(
ops
::
ActFwdInplaceInferer
,
void
>::
type
);
ops
::
ActFwdInplaceInferer
,
void
>::
type
);
REGISTER_OPERATOR
(
pow_grad
,
ops
::
PowOpGrad
,
REGISTER_OPERATOR
(
pow_grad
,
ops
::
PowOpGrad
,
ops
::
ActivationGradOpInplaceInferer
);
ops
::
ActivationGradOpInplaceInferer
);
REGISTER_OP_CPU_KERNEL
(
pow
,
ops
::
PowKernel
<
plat
::
CPUDeviceContext
,
ops
::
PowFunctor
<
float
>>
,
ops
::
PowKernel
<
plat
::
CPUDeviceContext
,
ops
::
PowFunctor
<
double
>>
,
ops
::
PowKernel
<
plat
::
CPUDeviceContext
,
ops
::
PowFunctor
<
int
>>
,
ops
::
PowKernel
<
plat
::
CPUDeviceContext
,
ops
::
PowFunctor
<
int64_t
>>
);
REGISTER_OP_CPU_KERNEL
(
pow_grad
,
ops
::
PowGradKernel
<
plat
::
CPUDeviceContext
,
ops
::
PowGradFunctor
<
float
>>
,
ops
::
PowGradKernel
<
plat
::
CPUDeviceContext
,
ops
::
PowGradFunctor
<
double
>>
,
ops
::
PowGradKernel
<
plat
::
CPUDeviceContext
,
ops
::
PowGradFunctor
<
int
>>
,
ops
::
PowGradKernel
<
plat
::
CPUDeviceContext
,
ops
::
PowGradFunctor
<
int64_t
>>
);
/* ========================================================================== */
/* ========================================================================== */
/* ========================== exp register ============================ */
/* ========================== exp register ============================ */
...
...
paddle/fluid/operators/activation_op.h
浏览文件 @
be5918e0
...
@@ -286,10 +286,25 @@ USE_PHI_DOUBLE_GRAD_FUNCTOR(Log)
...
@@ -286,10 +286,25 @@ USE_PHI_DOUBLE_GRAD_FUNCTOR(Log)
USE_PHI_FUNCTOR
(
Log2
)
USE_PHI_FUNCTOR
(
Log2
)
USE_PHI_FUNCTOR
(
Log10
)
USE_PHI_FUNCTOR
(
Log10
)
USE_PHI_FUNCTOR
(
Log1p
)
USE_PHI_FUNCTOR
(
Log1p
)
USE_PHI_FUNCTOR
(
Swish
)
USE_PHI_FUNCTOR
(
HardSwish
)
USE_PHI_FUNCTOR
(
Pow
)
template
<
typename
T
>
template
<
typename
T
>
using
ELUGradNegativeAlphaFunctor
=
phi
::
funcs
::
ELUGradNegativeAlphaFunctor
<
T
>
;
using
ELUGradNegativeAlphaFunctor
=
phi
::
funcs
::
ELUGradNegativeAlphaFunctor
<
T
>
;
template
<
typename
T
>
using
RoundFunctor
=
phi
::
funcs
::
RoundFunctor
<
T
>
;
template
<
typename
T
>
using
FloorFunctor
=
phi
::
funcs
::
FloorFunctor
<
T
>
;
template
<
typename
T
>
using
CeilFunctor
=
phi
::
funcs
::
CeilFunctor
<
T
>
;
template
<
typename
T
>
using
ZeroGradFunctor
=
phi
::
funcs
::
ZeroGradFunctor
<
T
>
;
// exp(x) = e^x
// exp(x) = e^x
template
<
typename
T
>
template
<
typename
T
>
struct
ExpFunctor
:
public
BaseActivationFunctor
<
T
>
{
struct
ExpFunctor
:
public
BaseActivationFunctor
<
T
>
{
...
@@ -391,46 +406,6 @@ struct RsqrtGradFunctor : public BaseActivationFunctor<T> {
...
@@ -391,46 +406,6 @@ struct RsqrtGradFunctor : public BaseActivationFunctor<T> {
}
}
};
};
// ceil(x) = ceiling(x)
template
<
typename
T
>
struct
CeilFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
const
{
out
.
device
(
d
)
=
x
.
ceil
();
}
};
template
<
typename
T
>
struct
ZeroGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
,
typename
X
,
typename
Out
,
typename
dOut
,
typename
dX
>
void
operator
()(
Device
d
,
X
x
,
Out
out
,
dOut
dout
,
dX
dx
)
const
{
dx
.
device
(
d
)
=
static_cast
<
T
>
(
0
)
*
out
;
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kNoDeps
;
}
};
// floor(x) = flooring(x)
template
<
typename
T
>
struct
FloorFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
const
{
out
.
device
(
d
)
=
x
.
floor
();
}
};
// round(x) = [x]
template
<
typename
T
>
struct
RoundFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
const
{
out
.
device
(
d
)
=
x
.
round
();
}
};
// reciprocal(x) = 1 / x
// reciprocal(x) = 1 / x
template
<
typename
T
>
template
<
typename
T
>
struct
ReciprocalFunctor
:
public
BaseActivationFunctor
<
T
>
{
struct
ReciprocalFunctor
:
public
BaseActivationFunctor
<
T
>
{
...
@@ -509,51 +484,6 @@ struct Relu6GradFunctor : public BaseActivationFunctor<T> {
...
@@ -509,51 +484,6 @@ struct Relu6GradFunctor : public BaseActivationFunctor<T> {
}
}
};
};
// HardSwish = min(max(0, x+3), 6) * x / 6
template
<
typename
T
>
struct
HardSwishFunctor
:
public
BaseActivationFunctor
<
T
>
{
float
threshold
;
float
scale
;
float
offset
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"threshold"
,
&
threshold
},
{
"scale"
,
&
scale
},
{
"offset"
,
&
offset
}};
}
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
const
{
out
.
device
(
d
)
=
(
x
+
static_cast
<
T
>
(
offset
))
.
cwiseMax
(
static_cast
<
T
>
(
0
))
.
cwiseMin
(
static_cast
<
T
>
(
threshold
))
*
x
/
static_cast
<
T
>
(
scale
);
}
};
template
<
typename
T
>
struct
HardSwishGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
float
threshold
;
float
scale
;
float
offset
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"threshold"
,
&
threshold
},
{
"scale"
,
&
scale
},
{
"offset"
,
&
offset
}};
}
template
<
typename
Device
,
typename
X
,
typename
Out
,
typename
dOut
,
typename
dX
>
void
operator
()(
Device
d
,
X
x
,
Out
out
,
dOut
dout
,
dX
dx
)
const
{
auto
tmp
=
((
x
+
static_cast
<
T
>
(
offset
))
<
static_cast
<
T
>
(
threshold
))
.
template
cast
<
T
>();
dx
.
device
(
d
)
=
dout
*
(((
x
+
static_cast
<
T
>
(
offset
))
>
static_cast
<
T
>
(
0
)).
template
cast
<
T
>()
*
(
static_cast
<
T
>
(
2
)
*
x
+
static_cast
<
T
>
(
offset
))
/
static_cast
<
T
>
(
scale
)
*
tmp
+
static_cast
<
T
>
(
1
)
*
(
static_cast
<
T
>
(
1
)
-
tmp
));
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
};
// For numerical stability, using the following formula instead of softplus(x) =
// For numerical stability, using the following formula instead of softplus(x) =
// log(1 + exp(x))
// log(1 + exp(x))
// softplus(x) = log(1 + exp(beta * x)) / beta when beta * x <= threshold(beta =
// softplus(x) = log(1 + exp(beta * x)) / beta when beta * x <= threshold(beta =
...
@@ -776,35 +706,6 @@ struct CELUGradFunctor : public BaseActivationFunctor<T> {
...
@@ -776,35 +706,6 @@ struct CELUGradFunctor : public BaseActivationFunctor<T> {
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
};
};
// FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5198
template
<
typename
T
>
struct
PowFunctor
:
public
BaseActivationFunctor
<
T
>
{
float
factor
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"factor"
,
&
factor
}};
}
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
const
{
out
.
device
(
d
)
=
x
.
pow
(
static_cast
<
T
>
(
factor
));
}
};
template
<
typename
T
>
struct
PowGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
float
factor
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"factor"
,
&
factor
}};
}
template
<
typename
Device
,
typename
X
,
typename
Out
,
typename
dOut
,
typename
dX
>
void
operator
()(
Device
d
,
X
x
,
Out
out
,
dOut
dout
,
dX
dx
)
const
{
dx
.
device
(
d
)
=
dout
*
static_cast
<
T
>
(
factor
)
*
x
.
pow
(
static_cast
<
T
>
(
factor
)
-
static_cast
<
T
>
(
1
));
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
};
template
<
typename
T
>
template
<
typename
T
>
struct
LogitFunctor
{
struct
LogitFunctor
{
template
<
typename
Device
,
typename
X
,
typename
Out
,
typename
P
>
template
<
typename
Device
,
typename
X
,
typename
Out
,
typename
P
>
...
@@ -870,39 +771,6 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
...
@@ -870,39 +771,6 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
};
};
template
<
typename
T
>
struct
SwishFunctor
:
public
BaseActivationFunctor
<
T
>
{
float
beta
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"beta"
,
&
beta
}};
}
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
const
{
out
.
device
(
d
)
=
x
/
(
static_cast
<
T
>
(
1
)
+
(
static_cast
<
T
>
(
-
beta
)
*
x
).
exp
());
}
};
template
<
typename
T
>
struct
SwishGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
float
beta
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"beta"
,
&
beta
}};
}
template
<
typename
Device
,
typename
X
,
typename
Out
,
typename
dOut
,
typename
dX
>
void
operator
()(
Device
d
,
X
x
,
Out
fake_out
,
dOut
dout
,
dX
dx
)
const
{
auto
temp1
=
static_cast
<
T
>
(
1
)
/
(
static_cast
<
T
>
(
1
)
+
(
static_cast
<
T
>
(
-
beta
)
*
x
).
exp
());
auto
out
=
x
*
temp1
;
auto
temp2
=
temp1
*
(
static_cast
<
T
>
(
1
)
-
(
static_cast
<
T
>
(
beta
)
*
out
));
dx
.
device
(
d
)
=
dout
*
((
static_cast
<
T
>
(
beta
)
*
out
)
+
temp2
);
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
};
template
<
typename
T
>
template
<
typename
T
>
struct
AbsGradGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
struct
AbsGradGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
>
template
<
typename
Device
>
...
@@ -1267,110 +1135,6 @@ class RsqrtDoubleGradKernel
...
@@ -1267,110 +1135,6 @@ class RsqrtDoubleGradKernel
}
}
};
};
template
<
typename
DeviceContext
,
typename
Functor
>
class
PowKernel
:
public
framework
::
OpKernel
<
typename
Functor
::
ELEMENT_TYPE
>
{
public:
using
T
=
typename
Functor
::
ELEMENT_TYPE
;
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
const
framework
::
Tensor
*
X
=
nullptr
;
framework
::
Tensor
*
Out
=
nullptr
;
ExtractActivationTensor
(
context
,
&
X
,
&
Out
);
Out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
x
=
framework
::
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
X
,
"Input"
,
"X"
,
"Pow"
));
auto
out
=
framework
::
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
Out
,
"Output"
,
"Out"
,
"Pow"
));
auto
*
place
=
context
.
template
device_context
<
DeviceContext
>().
eigen_device
();
Functor
functor
;
auto
attrs
=
functor
.
GetAttrs
();
for
(
auto
&
attr
:
attrs
)
{
*
attr
.
second
=
context
.
Attr
<
float
>
(
attr
.
first
);
}
// get FactorTensor
auto
*
factor_tensor
=
context
.
HasInput
(
"FactorTensor"
)
?
context
.
Input
<
framework
::
Tensor
>
(
"FactorTensor"
)
:
nullptr
;
if
(
factor_tensor
)
{
auto
*
factor_data
=
factor_tensor
->
data
<
float
>
();
framework
::
Tensor
cpu_factor_tensor
;
if
(
platform
::
is_gpu_place
(
factor_tensor
->
place
()))
{
framework
::
TensorCopySync
(
*
factor_tensor
,
platform
::
CPUPlace
(),
&
cpu_factor_tensor
);
factor_data
=
cpu_factor_tensor
.
data
<
float
>
();
}
auto
factor
=
std
::
vector
<
float
>
(
factor_data
,
factor_data
+
factor_tensor
->
numel
());
PADDLE_ENFORCE_EQ
(
factor
.
size
(),
1
,
platform
::
errors
::
InvalidArgument
(
"The shape of factor(tensor) must be [1] rather than %d"
,
factor
.
size
()));
for
(
auto
&
attr
:
attrs
)
{
*
attr
.
second
=
factor
[
0
];
}
}
functor
(
*
place
,
x
,
out
);
}
};
template
<
typename
DeviceContext
,
typename
Functor
>
class
PowGradKernel
:
public
framework
::
OpKernel
<
typename
Functor
::
ELEMENT_TYPE
>
{
public:
using
T
=
typename
Functor
::
ELEMENT_TYPE
;
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
const
framework
::
Tensor
*
X
,
*
Out
,
*
dOut
;
framework
::
Tensor
*
dX
=
nullptr
;
X
=
Out
=
dOut
=
nullptr
;
ExtractActivationGradTensor
<
Functor
::
FwdDeps
()
>
(
context
,
&
X
,
&
Out
,
&
dOut
,
&
dX
);
dX
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
dout
=
framework
::
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
dOut
,
"Input"
,
"Out@GRAD"
,
"PowGrad"
));
auto
out
=
framework
::
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
Out
,
"Input"
,
"Out"
,
"PowGrad"
));
auto
dx
=
framework
::
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
dX
,
"Output"
,
"X@GRAD"
,
"PowGrad"
));
auto
x
=
framework
::
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
X
,
"Input"
,
"X"
,
"PowGrad"
));
auto
*
place
=
context
.
template
device_context
<
DeviceContext
>().
eigen_device
();
Functor
functor
;
auto
attrs
=
functor
.
GetAttrs
();
for
(
auto
&
attr
:
attrs
)
{
*
attr
.
second
=
context
.
Attr
<
float
>
(
attr
.
first
);
}
// get FactorTensor
auto
*
factor_tensor
=
context
.
HasInput
(
"FactorTensor"
)
?
context
.
Input
<
framework
::
LoDTensor
>
(
"FactorTensor"
)
:
nullptr
;
if
(
factor_tensor
)
{
auto
*
factor_data
=
factor_tensor
->
data
<
float
>
();
framework
::
Tensor
cpu_factor_tensor
;
if
(
platform
::
is_gpu_place
(
factor_tensor
->
place
()))
{
framework
::
TensorCopySync
(
*
factor_tensor
,
platform
::
CPUPlace
(),
&
cpu_factor_tensor
);
factor_data
=
cpu_factor_tensor
.
data
<
float
>
();
}
auto
factor
=
std
::
vector
<
float
>
(
factor_data
,
factor_data
+
factor_tensor
->
numel
());
PADDLE_ENFORCE_EQ
(
factor
.
size
(),
1
,
platform
::
errors
::
InvalidArgument
(
"The shape of factor(tensor) must be [1] rather than %d"
,
factor
.
size
()));
for
(
auto
&
attr
:
attrs
)
{
*
attr
.
second
=
factor
[
0
];
}
}
functor
(
*
place
,
x
,
out
,
dout
,
dx
);
}
};
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
LogitKernel
:
public
framework
::
OpKernel
<
T
>
{
class
LogitKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
...
@@ -1418,15 +1182,10 @@ class LogitGradKernel : public framework::OpKernel<T> {
...
@@ -1418,15 +1182,10 @@ class LogitGradKernel : public framework::OpKernel<T> {
}
// namespace paddle
}
// namespace paddle
#define FOR_EACH_ACTIVATION_OP(__macro) \
#define FOR_EACH_ACTIVATION_OP(__macro) \
__macro(ceil, Ceil, CeilFunctor, ZeroGradFunctor); \
__macro(floor, Floor, FloorFunctor, ZeroGradFunctor); \
__macro(round, Round, RoundFunctor, ZeroGradFunctor); \
__macro(reciprocal, Reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \
__macro(reciprocal, Reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \
__macro(soft_relu, SoftRelu, SoftReluFunctor, SoftReluGradFunctor); \
__macro(soft_relu, SoftRelu, SoftReluFunctor, SoftReluGradFunctor); \
__macro(stanh, STanh, STanhFunctor, STanhGradFunctor); \
__macro(stanh, STanh, STanhFunctor, STanhGradFunctor); \
__macro(softplus, Softplus, SoftplusFunctor, SoftplusGradFunctor); \
__macro(softplus, Softplus, SoftplusFunctor, SoftplusGradFunctor); \
__macro(softsign, Softsign, SoftsignFunctor, SoftsignGradFunctor); \
__macro(softsign, Softsign, SoftsignFunctor, SoftsignGradFunctor); \
__macro(relu6, Relu6, Relu6Functor, Relu6GradFunctor); \
__macro(relu6, Relu6, Relu6Functor, Relu6GradFunctor); \
__macro(swish, Swish, SwishFunctor, SwishGradFunctor); \
__macro(mish, Mish, MishFunctor, MishGradFunctor);
__macro(mish, Mish, MishFunctor, MishGradFunctor); \
__macro(hard_swish, HardSwish, HardSwishFunctor, HardSwishGradFunctor);
paddle/fluid/operators/activation_op.kps
浏览文件 @
be5918e0
...
@@ -20,51 +20,6 @@ limitations under the License. */
...
@@ -20,51 +20,6 @@ limitations under the License. */
namespace paddle {
namespace paddle {
namespace operators {
namespace operators {
template <typename T>
struct CudaCeilFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// ceil(x) = ceil(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(ceil(x));
}
};
template <typename T>
struct CudaFloorFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// floor(x) = floor(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(floor(x));
}
};
template <typename T>
struct CudaRoundFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// round(x) = round(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(round(x));
}
};
// GradFunctor for ceil, floor and round
template <typename T>
struct CudaZeroGradFunctor : public BaseActivationFunctor<T> {
__device__ __forceinline__ T operator()(const T x) const {
return static_cast<T>(0.0f);
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kNoDeps;
}
};
template <typename T>
template <typename T>
struct CudaReciprocalFunctor : public BaseActivationFunctor<T> {
struct CudaReciprocalFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f);
T one = static_cast<T>(1.0f);
...
@@ -395,50 +350,6 @@ struct CudaRelu6GradFunctor : public BaseActivationFunctor<T> {
...
@@ -395,50 +350,6 @@ struct CudaRelu6GradFunctor : public BaseActivationFunctor<T> {
}
}
};
};
template <typename T>
struct CudaSwishFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
float beta;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"beta", &beta}};
}
// swish(x) = x / (1 + exp(-beta * x))
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
MPType b = static_cast<MPType>(beta);
return static_cast<T>(x / (one + exp(-b * x)));
}
};
template <typename T>
struct CudaSwishGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
float beta;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"beta", &beta}};
}
// dx = dout * (1 + exp(-b * x) + b * x * exp(-b * x) / (1 + exp(-b * x))^2)
__device__ __forceinline__ T operator()(const T arg_dout,
const T arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
MPType b = static_cast<MPType>(beta);
MPType temp1 = one / (one + exp(-b * x));
MPType out = x * temp1;
MPType temp2 = b * out;
MPType temp3 = temp1 * (one - temp2);
return static_cast<T>(dout * (temp2 + temp3));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
template <typename T>
struct CudaMishFunctor : public BaseActivationFunctor<T> {
struct CudaMishFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
using MPType = typename details::MPTypeTrait<T>::Type;
...
@@ -488,58 +399,6 @@ struct CudaMishGradFunctor : public BaseActivationFunctor<T> {
...
@@ -488,58 +399,6 @@ struct CudaMishGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
};
template <typename T>
struct CudaHardSwishFunctor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f);
float threshold;
float scale;
float offset;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
}
// hard_swish(x) = 0, when x <= -offset
// x , when x >= threshold - offset
// x * (x + offset) / scale, otherwise
// threshold = scale = 6, offset = 3 by default
__device__ __forceinline__ T operator()(const T x) const {
T t = static_cast<T>(threshold);
T temp = x + static_cast<T>(offset);
T temp_max = temp > zero ? temp : zero;
T temp_min = temp_max < t ? temp_max : t;
return temp_min * x / static_cast<T>(scale);
}
};
template <typename T>
struct CudaHardSwishGradFunctor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f);
T one = static_cast<T>(1.0f);
T two = static_cast<T>(2.0f);
float threshold;
float scale;
float offset;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
}
// dx = 0, when x <= -offset
// dout , when x >= threshold - offset
// dout * (2 * x / scale + offset / scale), otherwise
// threshold = scale = 6, offset = 3 by default
__device__ __forceinline__ T operator()(const T dout, const T x) const {
T o = static_cast<T>(offset);
T s = static_cast<T>(scale);
T temp1 = static_cast<T>(x + o > zero);
T temp2 = static_cast<T>(x + o < static_cast<T>(threshold));
return dout * (temp1 * temp2 * (two * x + o) / s + one - temp2);
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
template <typename T>
struct CudaCELUFunctor : public BaseActivationFunctor<T> {
struct CudaCELUFunctor : public BaseActivationFunctor<T> {
using CT = typename details::MPTypeTrait<T>::Type;
using CT = typename details::MPTypeTrait<T>::Type;
...
@@ -684,6 +543,20 @@ USE_PHI_FUNCTOR(CudaLog)
...
@@ -684,6 +543,20 @@ USE_PHI_FUNCTOR(CudaLog)
USE_PHI_FUNCTOR(CudaLog2)
USE_PHI_FUNCTOR(CudaLog2)
USE_PHI_FUNCTOR(CudaLog10)
USE_PHI_FUNCTOR(CudaLog10)
USE_PHI_FUNCTOR(CudaLog1p)
USE_PHI_FUNCTOR(CudaLog1p)
USE_PHI_FUNCTOR(CudaSwish)
USE_PHI_FUNCTOR(CudaHardSwish)
template <typename T>
using CudaRoundFunctor = phi::funcs::CudaRoundFunctor<T>;
template <typename T>
using CudaFloorFunctor = phi::funcs::CudaFloorFunctor<T>;
template <typename T>
using CudaCeilFunctor = phi::funcs::CudaCeilFunctor<T>;
template <typename T>
using CudaZeroGradFunctor = phi::funcs::CudaZeroGradFunctor<T>;
template <typename T>
template <typename T>
using CudaELUGradNegativeAlphaFunctor =
using CudaELUGradNegativeAlphaFunctor =
...
@@ -813,23 +686,6 @@ REGISTER_OP_CUDA_KERNEL(
...
@@ -813,23 +686,6 @@ REGISTER_OP_CUDA_KERNEL(
ops::SquareGradGradFunctor<int64_t>>);
ops::SquareGradGradFunctor<int64_t>>);
/* ========================================================================== */
/* ========================================================================== */
/* ========================== pow register ============================ */
REGISTER_OP_CUDA_KERNEL(
pow, ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<float>>,
ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<double>>,
ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<int>>,
ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<int64_t>>,
ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<plat::float16>>);
REGISTER_OP_CUDA_KERNEL(
pow_grad,
ops::PowGradKernel<plat::CUDADeviceContext, ops::PowGradFunctor<float>>,
ops::PowGradKernel<plat::CUDADeviceContext, ops::PowGradFunctor<double>>,
ops::PowGradKernel<plat::CUDADeviceContext, ops::PowGradFunctor<int>>,
ops::PowGradKernel<plat::CUDADeviceContext, ops::PowGradFunctor<int64_t>>,
ops::PowGradKernel<plat::CUDADeviceContext,
ops::PowGradFunctor<plat::float16>>);
/* ========================================================================== */
/* ========================== logit register ============================ */
/* ========================== logit register ============================ */
namespace ops = paddle::operators;
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
REGISTER_OP_CUDA_KERNEL(
...
@@ -889,9 +745,6 @@ REGISTER_OP_CUDA_KERNEL(
...
@@ -889,9 +745,6 @@ REGISTER_OP_CUDA_KERNEL(
#define FOR_EACH_ACTIVATION_CUDA_OP(__macro) \
#define FOR_EACH_ACTIVATION_CUDA_OP(__macro) \
__macro(softshrink, SoftShrink, CudaSoftShrinkFunctor, \
__macro(softshrink, SoftShrink, CudaSoftShrinkFunctor, \
CudaSoftShrinkGradFunctor); \
CudaSoftShrinkGradFunctor); \
__macro(ceil, Ceil, CudaCeilFunctor, CudaZeroGradFunctor); \
__macro(floor, Floor, CudaFloorFunctor, CudaZeroGradFunctor); \
__macro(round, Round, CudaRoundFunctor, CudaZeroGradFunctor); \
__macro(reciprocal, Reciprocal, CudaReciprocalFunctor, \
__macro(reciprocal, Reciprocal, CudaReciprocalFunctor, \
CudaReciprocalGradFunctor); \
CudaReciprocalGradFunctor); \
__macro(soft_relu, SoftRelu, CudaSoftReluFunctor, CudaSoftReluGradFunctor); \
__macro(soft_relu, SoftRelu, CudaSoftReluFunctor, CudaSoftReluGradFunctor); \
...
@@ -903,10 +756,7 @@ REGISTER_OP_CUDA_KERNEL(
...
@@ -903,10 +756,7 @@ REGISTER_OP_CUDA_KERNEL(
CudaTanhShrinkGradFunctor); \
CudaTanhShrinkGradFunctor); \
__macro(hard_shrink, HardShrink, CudaHardShrinkFunctor, \
__macro(hard_shrink, HardShrink, CudaHardShrinkFunctor, \
CudaHardShrinkGradFunctor); \
CudaHardShrinkGradFunctor); \
__macro(swish, Swish, CudaSwishFunctor, CudaSwishGradFunctor); \
__macro(mish, Mish, CudaMishFunctor, CudaMishGradFunctor);
__macro(mish, Mish, CudaMishFunctor, CudaMishGradFunctor); \
__macro(hard_swish, HardSwish, CudaHardSwishFunctor, \
CudaHardSwishGradFunctor);
FOR_EACH_ACTIVATION_CUDA_OP(REGISTER_ACTIVATION_CUDA_KERNEL)
FOR_EACH_ACTIVATION_CUDA_OP(REGISTER_ACTIVATION_CUDA_KERNEL)
#ifdef PADDLE_WITH_XPU_KP
#ifdef PADDLE_WITH_XPU_KP
...
...
paddle/phi/kernels/activation_grad_kernel.h
浏览文件 @
be5918e0
...
@@ -14,6 +14,7 @@ limitations under the License. */
...
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/infermeta/unary.h"
...
@@ -50,6 +51,11 @@ namespace phi {
...
@@ -50,6 +51,11 @@ namespace phi {
const DenseTensor& dout, \
const DenseTensor& dout, \
DenseTensor* dx);
DenseTensor* dx);
#define DECLARE_ACTIVATION_GRAD_KERNEL_NODEP(name) \
template <typename T, typename Context> \
void name##GradKernel( \
const Context& dev_ctx, const DenseTensor& dout, DenseTensor* dx);
#define DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPOUT(name, attr) \
#define DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPOUT(name, attr) \
template <typename T, typename Context> \
template <typename T, typename Context> \
void name##GradKernel(const Context& dev_ctx, \
void name##GradKernel(const Context& dev_ctx, \
...
@@ -143,6 +149,22 @@ void LogDoubleGradKernel(const Context& dev_ctx,
...
@@ -143,6 +149,22 @@ void LogDoubleGradKernel(const Context& dev_ctx,
DenseTensor
*
dx
,
DenseTensor
*
dx
,
DenseTensor
*
ddout
);
DenseTensor
*
ddout
);
template
<
typename
T
,
typename
Context
>
void
HardSwishGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
dout
,
float
threshold
,
float
scale
,
float
offset
,
DenseTensor
*
dx
);
template
<
typename
T
,
typename
Context
>
void
PowGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
dout
,
const
Scalar
&
factor
,
DenseTensor
*
dx
);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX
(
Cos
);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX
(
Cos
);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX
(
Tan
);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX
(
Tan
);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX
(
Acos
);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX
(
Acos
);
...
@@ -166,10 +188,15 @@ DECLARE_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu);
...
@@ -166,10 +188,15 @@ DECLARE_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPOUT
(
Tanh
);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPOUT
(
Tanh
);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPOUT
(
Sigmoid
);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPOUT
(
Sigmoid
);
DECLARE_ACTIVATION_GRAD_KERNEL_NODEP
(
Round
);
DECLARE_ACTIVATION_GRAD_KERNEL_NODEP
(
Floor
);
DECLARE_ACTIVATION_GRAD_KERNEL_NODEP
(
Ceil
);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX
(
LeakyRelu
,
alpha
);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX
(
LeakyRelu
,
alpha
);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX
(
ThresholdedRelu
,
threshold
);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX
(
ThresholdedRelu
,
threshold
);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX
(
SoftShrink
,
lambda
);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX
(
SoftShrink
,
lambda
);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX
(
HardShrink
,
threshold
);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX
(
HardShrink
,
threshold
);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX
(
Swish
,
beta
);
DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX
(
BRelu
,
t_min
,
t_max
);
DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX
(
BRelu
,
t_min
,
t_max
);
...
...
paddle/phi/kernels/activation_kernel.h
浏览文件 @
be5918e0
...
@@ -14,6 +14,7 @@ limitations under the License. */
...
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/infermeta/unary.h"
...
@@ -60,13 +61,32 @@ DECLARE_ACTIVATION_KERNEL(Log)
...
@@ -60,13 +61,32 @@ DECLARE_ACTIVATION_KERNEL(Log)
DECLARE_ACTIVATION_KERNEL
(
Log2
)
DECLARE_ACTIVATION_KERNEL
(
Log2
)
DECLARE_ACTIVATION_KERNEL
(
Log10
)
DECLARE_ACTIVATION_KERNEL
(
Log10
)
DECLARE_ACTIVATION_KERNEL
(
Log1p
)
DECLARE_ACTIVATION_KERNEL
(
Log1p
)
DECLARE_ACTIVATION_KERNEL
(
Round
)
DECLARE_ACTIVATION_KERNEL
(
Floor
)
DECLARE_ACTIVATION_KERNEL
(
Ceil
)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS
(
LeakyRelu
,
alpha
)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS
(
LeakyRelu
,
alpha
)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS
(
ThresholdedRelu
,
threshold
)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS
(
ThresholdedRelu
,
threshold
)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS
(
SoftShrink
,
lambda
)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS
(
SoftShrink
,
lambda
)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS
(
HardShrink
,
threshold
)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS
(
HardShrink
,
threshold
)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS
(
Elu
,
alpha
)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS
(
Elu
,
alpha
)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS
(
Swish
,
beta
)
DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS
(
BRelu
,
t_min
,
t_max
)
DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS
(
BRelu
,
t_min
,
t_max
)
DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS
(
HardSigmoid
,
slope
,
offset
)
DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS
(
HardSigmoid
,
slope
,
offset
)
template
<
typename
T
,
typename
Context
>
void
HardSwishKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
float
threshold
,
float
scale
,
float
offset
,
DenseTensor
*
out
);
template
<
typename
T
,
typename
Context
>
void
PowKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
Scalar
&
factor
,
DenseTensor
*
out
);
}
// namespace phi
}
// namespace phi
paddle/phi/kernels/cpu/activation_grad_kernel.cc
浏览文件 @
be5918e0
...
@@ -107,6 +107,15 @@ namespace phi {
...
@@ -107,6 +107,15 @@ namespace phi {
dev_ctx, nullptr, &out, &dout, dx, functor); \
dev_ctx, nullptr, &out, &dout, dx, functor); \
}
}
#define DEFINE_CPU_ACTIVATION_GRAD_KERNEL_NODEP(name, functor_class) \
template <typename T, typename Context> \
void name##GradKernel( \
const Context& dev_ctx, const DenseTensor& dout, DenseTensor* dx) { \
funcs::functor_class<T> functor; \
ActivationGradImpl<T, Context, funcs::functor_class<T>>( \
dev_ctx, nullptr, nullptr, &dout, dx, functor); \
}
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX
(
Cos
,
CosGradFunctor
);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX
(
Cos
,
CosGradFunctor
);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX
(
Tan
,
TanGradFunctor
);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX
(
Tan
,
TanGradFunctor
);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX
(
Acos
,
AcosGradFunctor
);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX
(
Acos
,
AcosGradFunctor
);
...
@@ -130,6 +139,10 @@ DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu, ReluGradFunctor);
...
@@ -130,6 +139,10 @@ DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu, ReluGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT
(
Tanh
,
TanhGradFunctor
);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT
(
Tanh
,
TanhGradFunctor
);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT
(
Sigmoid
,
SigmoidGradFunctor
);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT
(
Sigmoid
,
SigmoidGradFunctor
);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_NODEP
(
Round
,
ZeroGradFunctor
);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_NODEP
(
Floor
,
ZeroGradFunctor
);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_NODEP
(
Ceil
,
ZeroGradFunctor
);
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX
(
LeakyRelu
,
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX
(
LeakyRelu
,
LeakyReluGradFunctor
,
LeakyReluGradFunctor
,
alpha
);
alpha
);
...
@@ -142,6 +155,7 @@ DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(SoftShrink,
...
@@ -142,6 +155,7 @@ DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(SoftShrink,
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX
(
HardShrink
,
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX
(
HardShrink
,
HardShrinkGradFunctor
,
HardShrinkGradFunctor
,
threshold
);
threshold
);
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX
(
Swish
,
SwishGradFunctor
,
beta
);
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX
(
BRelu
,
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX
(
BRelu
,
BReluGradFunctor
,
BReluGradFunctor
,
...
@@ -183,6 +197,23 @@ void EluGradKernel(const Context& dev_ctx,
...
@@ -183,6 +197,23 @@ void EluGradKernel(const Context& dev_ctx,
}
}
}
}
template
<
typename
T
,
typename
Context
>
void
HardSwishGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
dout
,
float
threshold
,
float
scale
,
float
offset
,
DenseTensor
*
dx
)
{
funcs
::
HardSwishGradFunctor
<
T
>
functor
;
auto
attrs
=
functor
.
GetAttrs
();
*
(
attrs
[
0
].
second
)
=
threshold
;
*
(
attrs
[
1
].
second
)
=
scale
;
*
(
attrs
[
2
].
second
)
=
offset
;
ActivationGradImpl
<
T
,
Context
,
funcs
::
HardSwishGradFunctor
<
T
>>
(
dev_ctx
,
&
x
,
nullptr
,
&
dout
,
dx
,
functor
);
}
}
// namespace phi
}
// namespace phi
PD_REGISTER_KERNEL
(
PD_REGISTER_KERNEL
(
...
@@ -242,3 +273,17 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(log2_grad, Log2GradKernel)
...
@@ -242,3 +273,17 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(log2_grad, Log2GradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
log10_grad
,
Log10GradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
log10_grad
,
Log10GradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
log1p_grad
,
Log1pGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
log1p_grad
,
Log1pGradKernel
)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL
(
log_double_grad
,
LogDoubleGradKernel
)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL
(
log_double_grad
,
LogDoubleGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
hard_swish_grad
,
HardSwishGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
swish_grad
,
SwishGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
round_grad
,
RoundGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
floor_grad
,
FloorGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
ceil_grad
,
CeilGradKernel
)
PD_REGISTER_KERNEL
(
pow_grad
,
CPU
,
ALL_LAYOUT
,
phi
::
PowGradKernel
,
float
,
double
,
int
,
int64_t
)
{}
paddle/phi/kernels/cpu/activation_kernel.cc
浏览文件 @
be5918e0
...
@@ -78,6 +78,9 @@ DEFINE_CPU_ACTIVATION_KERNEL(Log, LogFunctor)
...
@@ -78,6 +78,9 @@ DEFINE_CPU_ACTIVATION_KERNEL(Log, LogFunctor)
DEFINE_CPU_ACTIVATION_KERNEL
(
Log2
,
Log2Functor
)
DEFINE_CPU_ACTIVATION_KERNEL
(
Log2
,
Log2Functor
)
DEFINE_CPU_ACTIVATION_KERNEL
(
Log10
,
Log10Functor
)
DEFINE_CPU_ACTIVATION_KERNEL
(
Log10
,
Log10Functor
)
DEFINE_CPU_ACTIVATION_KERNEL
(
Log1p
,
Log1pFunctor
)
DEFINE_CPU_ACTIVATION_KERNEL
(
Log1p
,
Log1pFunctor
)
DEFINE_CPU_ACTIVATION_KERNEL
(
Round
,
RoundFunctor
)
DEFINE_CPU_ACTIVATION_KERNEL
(
Floor
,
FloorFunctor
)
DEFINE_CPU_ACTIVATION_KERNEL
(
Ceil
,
CeilFunctor
)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS
(
LeakyRelu
,
LeakyReluFunctor
,
alpha
)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS
(
LeakyRelu
,
LeakyReluFunctor
,
alpha
)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS
(
ThresholdedRelu
,
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS
(
ThresholdedRelu
,
...
@@ -86,6 +89,7 @@ DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
...
@@ -86,6 +89,7 @@ DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS
(
HardShrink
,
HardShrinkFunctor
,
threshold
)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS
(
HardShrink
,
HardShrinkFunctor
,
threshold
)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS
(
SoftShrink
,
SoftShrinkFunctor
,
lambda
)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS
(
SoftShrink
,
SoftShrinkFunctor
,
lambda
)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS
(
Elu
,
ELUFunctor
,
alpha
)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS
(
Elu
,
ELUFunctor
,
alpha
)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS
(
Swish
,
SwishFunctor
,
beta
)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS
(
BRelu
,
BReluFunctor
,
t_min
,
t_max
)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS
(
BRelu
,
BReluFunctor
,
t_min
,
t_max
)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS
(
HardSigmoid
,
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS
(
HardSigmoid
,
...
@@ -93,6 +97,22 @@ DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(HardSigmoid,
...
@@ -93,6 +97,22 @@ DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(HardSigmoid,
slope
,
slope
,
offset
)
offset
)
template
<
typename
T
,
typename
Context
>
void
HardSwishKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
float
threshold
,
float
scale
,
float
offset
,
DenseTensor
*
out
)
{
funcs
::
HardSwishFunctor
<
T
>
functor
;
auto
attrs
=
functor
.
GetAttrs
();
*
(
attrs
[
0
].
second
)
=
threshold
;
*
(
attrs
[
1
].
second
)
=
scale
;
*
(
attrs
[
2
].
second
)
=
offset
;
ActivationImpl
<
T
,
Context
,
funcs
::
HardSwishFunctor
<
T
>>
(
dev_ctx
,
x
,
out
,
functor
);
}
}
// namespace phi
}
// namespace phi
PD_REGISTER_KERNEL
(
relu
,
CPU
,
ALL_LAYOUT
,
phi
::
ReluKernel
,
float
,
double
)
{}
PD_REGISTER_KERNEL
(
relu
,
CPU
,
ALL_LAYOUT
,
phi
::
ReluKernel
,
float
,
double
)
{}
...
@@ -126,3 +146,10 @@ PD_REGISTER_ACTIVATION_KERNEL(log, LogKernel)
...
@@ -126,3 +146,10 @@ PD_REGISTER_ACTIVATION_KERNEL(log, LogKernel)
PD_REGISTER_ACTIVATION_KERNEL
(
log2
,
Log2Kernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
log2
,
Log2Kernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
log10
,
Log10Kernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
log10
,
Log10Kernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
log1p
,
Log1pKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
log1p
,
Log1pKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
hard_swish
,
HardSwishKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
swish
,
SwishKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
round
,
RoundKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
floor
,
FloorKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
ceil
,
CeilKernel
)
PD_REGISTER_KERNEL
(
pow
,
CPU
,
ALL_LAYOUT
,
phi
::
PowKernel
,
float
,
double
,
int
,
int64_t
)
{}
paddle/phi/kernels/funcs/activation_functor.h
浏览文件 @
be5918e0
...
@@ -1350,6 +1350,165 @@ struct LogGradGradFunctor : public BaseActivationFunctor<T> {
...
@@ -1350,6 +1350,165 @@ struct LogGradGradFunctor : public BaseActivationFunctor<T> {
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
};
};
// HardSwish = min(max(0, x+3), 6) * x / 6
template
<
typename
T
>
struct
HardSwishFunctor
:
public
BaseActivationFunctor
<
T
>
{
float
threshold
;
float
scale
;
float
offset
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"threshold"
,
&
threshold
},
{
"scale"
,
&
scale
},
{
"offset"
,
&
offset
}};
}
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
const
{
out
.
device
(
d
)
=
(
x
+
static_cast
<
T
>
(
offset
))
.
cwiseMax
(
static_cast
<
T
>
(
0
))
.
cwiseMin
(
static_cast
<
T
>
(
threshold
))
*
x
/
static_cast
<
T
>
(
scale
);
}
};
template
<
typename
T
>
struct
HardSwishGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
float
threshold
;
float
scale
;
float
offset
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"threshold"
,
&
threshold
},
{
"scale"
,
&
scale
},
{
"offset"
,
&
offset
}};
}
template
<
typename
Device
,
typename
X
,
typename
Out
,
typename
dOut
,
typename
dX
>
void
operator
()(
Device
d
,
X
x
,
Out
out
,
dOut
dout
,
dX
dx
)
const
{
auto
tmp
=
((
x
+
static_cast
<
T
>
(
offset
))
<
static_cast
<
T
>
(
threshold
))
.
template
cast
<
T
>();
dx
.
device
(
d
)
=
dout
*
(((
x
+
static_cast
<
T
>
(
offset
))
>
static_cast
<
T
>
(
0
)).
template
cast
<
T
>()
*
(
static_cast
<
T
>
(
2
)
*
x
+
static_cast
<
T
>
(
offset
))
/
static_cast
<
T
>
(
scale
)
*
tmp
+
static_cast
<
T
>
(
1
)
*
(
static_cast
<
T
>
(
1
)
-
tmp
));
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
};
template
<
typename
T
>
struct
SwishFunctor
:
public
BaseActivationFunctor
<
T
>
{
float
beta
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"beta"
,
&
beta
}};
}
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
const
{
out
.
device
(
d
)
=
x
/
(
static_cast
<
T
>
(
1
)
+
(
static_cast
<
T
>
(
-
beta
)
*
x
).
exp
());
}
};
template
<
typename
T
>
struct
SwishGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
float
beta
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"beta"
,
&
beta
}};
}
template
<
typename
Device
,
typename
X
,
typename
Out
,
typename
dOut
,
typename
dX
>
void
operator
()(
Device
d
,
X
x
,
Out
fake_out
,
dOut
dout
,
dX
dx
)
const
{
auto
temp1
=
static_cast
<
T
>
(
1
)
/
(
static_cast
<
T
>
(
1
)
+
(
static_cast
<
T
>
(
-
beta
)
*
x
).
exp
());
auto
out
=
x
*
temp1
;
auto
temp2
=
temp1
*
(
static_cast
<
T
>
(
1
)
-
(
static_cast
<
T
>
(
beta
)
*
out
));
dx
.
device
(
d
)
=
dout
*
((
static_cast
<
T
>
(
beta
)
*
out
)
+
temp2
);
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
};
// FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5198
template
<
typename
T
>
struct
PowFunctor
:
public
BaseActivationFunctor
<
T
>
{
float
factor
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"factor"
,
&
factor
}};
}
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
const
{
out
.
device
(
d
)
=
x
.
pow
(
static_cast
<
T
>
(
factor
));
}
};
template
<
typename
T
>
struct
PowGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
float
factor
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"factor"
,
&
factor
}};
}
template
<
typename
Device
,
typename
X
,
typename
Out
,
typename
dOut
,
typename
dX
>
void
operator
()(
Device
d
,
X
x
,
Out
out
,
dOut
dout
,
dX
dx
)
const
{
dx
.
device
(
d
)
=
dout
*
static_cast
<
T
>
(
factor
)
*
x
.
pow
(
static_cast
<
T
>
(
factor
)
-
static_cast
<
T
>
(
1
));
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
};
// floor(x) = flooring(x)
template
<
typename
T
>
struct
FloorFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
const
{
out
.
device
(
d
)
=
x
.
floor
();
}
};
// round(x) = [x]
template
<
typename
T
>
struct
RoundFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
const
{
out
.
device
(
d
)
=
x
.
round
();
}
};
// ceil(x) = ceiling(x)
template
<
typename
T
>
struct
CeilFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
const
{
out
.
device
(
d
)
=
x
.
ceil
();
}
};
template
<
typename
T
>
struct
ZeroGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
,
typename
X
,
typename
Out
,
typename
dOut
,
typename
dX
>
void
operator
()(
Device
d
,
X
x
,
Out
out
,
dOut
dout
,
dX
dx
)
const
{
dx
.
device
(
d
)
=
static_cast
<
T
>
(
0
)
*
out
;
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kNoDeps
;
}
};
#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
template
<
typename
T
>
template
<
typename
T
>
struct
CudaReluFunctor
:
public
BaseActivationFunctor
<
T
>
{
struct
CudaReluFunctor
:
public
BaseActivationFunctor
<
T
>
{
...
@@ -2190,6 +2349,147 @@ struct CudaLog10GradFunctor : public BaseActivationFunctor<T> {
...
@@ -2190,6 +2349,147 @@ struct CudaLog10GradFunctor : public BaseActivationFunctor<T> {
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
};
};
template
<
typename
T
>
struct
CudaSwishFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
MPType
one
=
static_cast
<
MPType
>
(
1.0
f
);
float
beta
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"beta"
,
&
beta
}};
}
// swish(x) = x / (1 + exp(-beta * x))
__device__
__forceinline__
T
operator
()(
const
T
arg_x
)
const
{
MPType
x
=
static_cast
<
MPType
>
(
arg_x
);
MPType
b
=
static_cast
<
MPType
>
(
beta
);
return
static_cast
<
T
>
(
x
/
(
one
+
exp
(
-
b
*
x
)));
}
};
template
<
typename
T
>
struct
CudaSwishGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
MPType
one
=
static_cast
<
MPType
>
(
1.0
f
);
float
beta
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"beta"
,
&
beta
}};
}
// dx = dout * (1 + exp(-b * x) + b * x * exp(-b * x) / (1 + exp(-b * x))^2)
__device__
__forceinline__
T
operator
()(
const
T
arg_dout
,
const
T
arg_x
)
const
{
MPType
dout
=
static_cast
<
MPType
>
(
arg_dout
);
MPType
x
=
static_cast
<
MPType
>
(
arg_x
);
MPType
b
=
static_cast
<
MPType
>
(
beta
);
MPType
temp1
=
one
/
(
one
+
exp
(
-
b
*
x
));
MPType
out
=
x
*
temp1
;
MPType
temp2
=
b
*
out
;
MPType
temp3
=
temp1
*
(
one
-
temp2
);
return
static_cast
<
T
>
(
dout
*
(
temp2
+
temp3
));
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
};
template
<
typename
T
>
struct
CudaHardSwishFunctor
:
public
BaseActivationFunctor
<
T
>
{
T
zero
=
static_cast
<
T
>
(
0.0
f
);
float
threshold
;
float
scale
;
float
offset
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"threshold"
,
&
threshold
},
{
"scale"
,
&
scale
},
{
"offset"
,
&
offset
}};
}
// hard_swish(x) = 0, when x <= -offset
// x , when x >= threshold - offset
// x * (x + offset) / scale, otherwise
// threshold = scale = 6, offset = 3 by default
__device__
__forceinline__
T
operator
()(
const
T
x
)
const
{
T
t
=
static_cast
<
T
>
(
threshold
);
T
temp
=
x
+
static_cast
<
T
>
(
offset
);
T
temp_max
=
temp
>
zero
?
temp
:
zero
;
T
temp_min
=
temp_max
<
t
?
temp_max
:
t
;
return
temp_min
*
x
/
static_cast
<
T
>
(
scale
);
}
};
template
<
typename
T
>
struct
CudaHardSwishGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
T
zero
=
static_cast
<
T
>
(
0.0
f
);
T
one
=
static_cast
<
T
>
(
1.0
f
);
T
two
=
static_cast
<
T
>
(
2.0
f
);
float
threshold
;
float
scale
;
float
offset
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"threshold"
,
&
threshold
},
{
"scale"
,
&
scale
},
{
"offset"
,
&
offset
}};
}
// dx = 0, when x <= -offset
// dout , when x >= threshold - offset
// dout * (2 * x / scale + offset / scale), otherwise
// threshold = scale = 6, offset = 3 by default
__device__
__forceinline__
T
operator
()(
const
T
dout
,
const
T
x
)
const
{
T
o
=
static_cast
<
T
>
(
offset
);
T
s
=
static_cast
<
T
>
(
scale
);
T
temp1
=
static_cast
<
T
>
(
x
+
o
>
zero
);
T
temp2
=
static_cast
<
T
>
(
x
+
o
<
static_cast
<
T
>
(
threshold
));
return
dout
*
(
temp1
*
temp2
*
(
two
*
x
+
o
)
/
s
+
one
-
temp2
);
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepX
;
}
};
template
<
typename
T
>
struct
CudaCeilFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
// ceil(x) = ceil(x)
__device__
__forceinline__
T
operator
()(
const
T
arg_x
)
const
{
MPType
x
=
static_cast
<
MPType
>
(
arg_x
);
return
static_cast
<
T
>
(
ceil
(
x
));
}
};
template
<
typename
T
>
struct
CudaFloorFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
// floor(x) = floor(x)
__device__
__forceinline__
T
operator
()(
const
T
arg_x
)
const
{
MPType
x
=
static_cast
<
MPType
>
(
arg_x
);
return
static_cast
<
T
>
(
floor
(
x
));
}
};
template
<
typename
T
>
struct
CudaRoundFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
// round(x) = round(x)
__device__
__forceinline__
T
operator
()(
const
T
arg_x
)
const
{
MPType
x
=
static_cast
<
MPType
>
(
arg_x
);
return
static_cast
<
T
>
(
round
(
x
));
}
};
// GradFunctor for ceil, floor and round
template
<
typename
T
>
struct
CudaZeroGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
__device__
__forceinline__
T
operator
()(
const
T
x
)
const
{
return
static_cast
<
T
>
(
0.0
f
);
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kNoDeps
;
}
};
#endif
#endif
}
// namespace funcs
}
// namespace funcs
...
...
paddle/phi/kernels/gpu/activation_grad_kernel.cu
浏览文件 @
be5918e0
...
@@ -159,10 +159,23 @@ void ActivationGradGPUImpl(const Context& dev_ctx,
...
@@ -159,10 +159,23 @@ void ActivationGradGPUImpl(const Context& dev_ctx,
dev_ctx, nullptr, &out, &dout, dx, functor); \
dev_ctx, nullptr, &out, &dout, dx, functor); \
}
}
#define DEFINE_GPU_ACTIVATION_GRAD_KERNEL_NODEP(name, functor_class) \
template <typename T, typename Context> \
void name##GradKernel( \
const Context& dev_ctx, const DenseTensor& dout, DenseTensor* dx) { \
funcs::functor_class<T> functor; \
ActivationGradGPUImpl<T, Context, funcs::functor_class<T>>( \
dev_ctx, nullptr, nullptr, &dout, dx, functor); \
}
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT
(
Relu
,
CudaReluGradFunctor
);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT
(
Relu
,
CudaReluGradFunctor
);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT
(
Tanh
,
CudaTanhGradFunctor
);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT
(
Tanh
,
CudaTanhGradFunctor
);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT
(
Sigmoid
,
CudaSigmoidGradFunctor
);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT
(
Sigmoid
,
CudaSigmoidGradFunctor
);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_NODEP
(
Round
,
CudaZeroGradFunctor
);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_NODEP
(
Floor
,
CudaZeroGradFunctor
);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_NODEP
(
Ceil
,
CudaZeroGradFunctor
);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX
(
Cos
,
CudaCosGradFunctor
);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX
(
Cos
,
CudaCosGradFunctor
);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX
(
Tan
,
CudaTanGradFunctor
);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX
(
Tan
,
CudaTanGradFunctor
);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX
(
Acos
,
CudaAcosGradFunctor
);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX
(
Acos
,
CudaAcosGradFunctor
);
...
@@ -194,6 +207,9 @@ DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(SoftShrink,
...
@@ -194,6 +207,9 @@ DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(SoftShrink,
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX
(
HardShrink
,
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX
(
HardShrink
,
CudaHardShrinkGradFunctor
,
CudaHardShrinkGradFunctor
,
threshold
);
threshold
);
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX
(
Swish
,
CudaSwishGradFunctor
,
beta
);
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX
(
BRelu
,
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX
(
BRelu
,
CudaBReluGradFunctor
,
CudaBReluGradFunctor
,
...
@@ -227,6 +243,23 @@ void EluGradKernel(const Context& dev_ctx,
...
@@ -227,6 +243,23 @@ void EluGradKernel(const Context& dev_ctx,
}
}
}
}
template
<
typename
T
,
typename
Context
>
void
HardSwishGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
dout
,
float
threshold
,
float
scale
,
float
offset
,
DenseTensor
*
dx
)
{
funcs
::
CudaHardSwishGradFunctor
<
T
>
functor
;
auto
attrs
=
functor
.
GetAttrs
();
*
(
attrs
[
0
].
second
)
=
threshold
;
*
(
attrs
[
1
].
second
)
=
scale
;
*
(
attrs
[
2
].
second
)
=
offset
;
ActivationGradGPUImpl
<
T
,
Context
,
funcs
::
CudaHardSwishGradFunctor
<
T
>>
(
dev_ctx
,
&
x
,
nullptr
,
&
dout
,
dx
,
functor
);
}
}
// namespace phi
}
// namespace phi
#ifdef PADDLE_WITH_HIP
#ifdef PADDLE_WITH_HIP
...
@@ -315,3 +348,18 @@ PD_REGISTER_KERNEL(log_double_grad,
...
@@ -315,3 +348,18 @@ PD_REGISTER_KERNEL(log_double_grad,
float
,
float
,
double
,
double
,
phi
::
dtype
::
float16
)
{}
phi
::
dtype
::
float16
)
{}
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
hard_swish_grad
,
HardSwishGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
swish_grad
,
SwishGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
round_grad
,
RoundGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
floor_grad
,
FloorGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
ceil_grad
,
CeilGradKernel
)
PD_REGISTER_KERNEL
(
pow_grad
,
GPU
,
ALL_LAYOUT
,
phi
::
PowGradKernel
,
float
,
double
,
int
,
int64_t
,
phi
::
dtype
::
float16
)
{}
paddle/phi/kernels/gpu/activation_kernel.cu
浏览文件 @
be5918e0
...
@@ -97,6 +97,9 @@ DEFINE_GPU_ACTIVATION_KERNEL(Log, CudaLogFunctor)
...
@@ -97,6 +97,9 @@ DEFINE_GPU_ACTIVATION_KERNEL(Log, CudaLogFunctor)
DEFINE_GPU_ACTIVATION_KERNEL
(
Log2
,
CudaLog2Functor
)
DEFINE_GPU_ACTIVATION_KERNEL
(
Log2
,
CudaLog2Functor
)
DEFINE_GPU_ACTIVATION_KERNEL
(
Log10
,
CudaLog10Functor
)
DEFINE_GPU_ACTIVATION_KERNEL
(
Log10
,
CudaLog10Functor
)
DEFINE_GPU_ACTIVATION_KERNEL
(
Log1p
,
CudaLog1pFunctor
)
DEFINE_GPU_ACTIVATION_KERNEL
(
Log1p
,
CudaLog1pFunctor
)
DEFINE_GPU_ACTIVATION_KERNEL
(
Round
,
CudaRoundFunctor
)
DEFINE_GPU_ACTIVATION_KERNEL
(
Floor
,
CudaFloorFunctor
)
DEFINE_GPU_ACTIVATION_KERNEL
(
Ceil
,
CudaCeilFunctor
)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS
(
LeakyRelu
,
CudaLeakyReluFunctor
,
alpha
)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS
(
LeakyRelu
,
CudaLeakyReluFunctor
,
alpha
)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS
(
ThresholdedRelu
,
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS
(
ThresholdedRelu
,
...
@@ -107,6 +110,7 @@ DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink,
...
@@ -107,6 +110,7 @@ DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink,
threshold
)
threshold
)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS
(
SoftShrink
,
CudaSoftShrinkFunctor
,
lambda
)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS
(
SoftShrink
,
CudaSoftShrinkFunctor
,
lambda
)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS
(
Elu
,
CudaELUFunctor
,
alpha
)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS
(
Elu
,
CudaELUFunctor
,
alpha
)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS
(
Swish
,
CudaSwishFunctor
,
beta
)
DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS
(
BRelu
,
CudaBReluFunctor
,
t_min
,
t_max
)
DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS
(
BRelu
,
CudaBReluFunctor
,
t_min
,
t_max
)
DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS
(
HardSigmoid
,
DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS
(
HardSigmoid
,
...
@@ -114,6 +118,22 @@ DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(HardSigmoid,
...
@@ -114,6 +118,22 @@ DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(HardSigmoid,
slope
,
slope
,
offset
)
offset
)
template
<
typename
T
,
typename
Context
>
void
HardSwishKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
float
threshold
,
float
scale
,
float
offset
,
DenseTensor
*
out
)
{
funcs
::
CudaHardSwishFunctor
<
T
>
functor
;
auto
attrs
=
functor
.
GetAttrs
();
*
(
attrs
[
0
].
second
)
=
threshold
;
*
(
attrs
[
1
].
second
)
=
scale
;
*
(
attrs
[
2
].
second
)
=
offset
;
ActivationGPUImpl
<
T
,
Context
,
funcs
::
CudaHardSwishFunctor
<
T
>>
(
dev_ctx
,
x
,
out
,
functor
);
}
}
// namespace phi
}
// namespace phi
#ifdef PADDLE_WITH_HIP
#ifdef PADDLE_WITH_HIP
...
@@ -172,3 +192,17 @@ PD_REGISTER_ACTIVATION_KERNEL(log, LogKernel)
...
@@ -172,3 +192,17 @@ PD_REGISTER_ACTIVATION_KERNEL(log, LogKernel)
PD_REGISTER_ACTIVATION_KERNEL
(
log2
,
Log2Kernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
log2
,
Log2Kernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
log10
,
Log10Kernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
log10
,
Log10Kernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
log1p
,
Log1pKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
log1p
,
Log1pKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
hard_swish
,
HardSwishKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
swish
,
SwishKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
round
,
RoundKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
floor
,
FloorKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
ceil
,
CeilKernel
)
PD_REGISTER_KERNEL
(
pow
,
GPU
,
ALL_LAYOUT
,
phi
::
PowKernel
,
float
,
double
,
int
,
int64_t
,
phi
::
dtype
::
float16
)
{}
paddle/phi/kernels/impl/activation_grad_impl.h
浏览文件 @
be5918e0
...
@@ -293,4 +293,28 @@ void LogDoubleGradKernel(const Context& dev_ctx,
...
@@ -293,4 +293,28 @@ void LogDoubleGradKernel(const Context& dev_ctx,
functor
(
dev_ctx
,
&
x
,
&
ddx
,
ddout
,
&
dout
,
dx
);
functor
(
dev_ctx
,
&
x
,
&
ddx
,
ddout
,
&
dout
,
dx
);
}
}
template
<
typename
T
,
typename
Context
>
void
PowGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
dout
,
const
Scalar
&
factor
,
DenseTensor
*
dx
)
{
PADDLE_ENFORCE_NOT_NULL
(
dx
,
errors
::
NotFound
(
"The output DenseTensor dX can not be nullptr"
));
if
(
dx
)
{
dev_ctx
.
template
Alloc
<
T
>(
dx
);
}
auto
dout_flatten
=
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
&
dout
,
"Input"
,
"Out@GRAD"
,
"PowGrad"
));
auto
dx_flatten
=
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
dx
,
"Output"
,
"X@GRAD"
,
"PowGrad"
));
auto
x_flatten
=
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
&
x
,
"Input"
,
"X"
,
"PowGrad"
));
auto
*
place
=
dev_ctx
.
eigen_device
();
phi
::
funcs
::
PowGradFunctor
<
T
>
functor
;
auto
attrs
=
functor
.
GetAttrs
();
*
(
attrs
[
0
].
second
)
=
factor
.
to
<
float
>
();
functor
(
*
place
,
x_flatten
,
nullptr
,
dout_flatten
,
dx_flatten
);
}
}
// namespace phi
}
// namespace phi
paddle/phi/kernels/impl/activation_impl.h
浏览文件 @
be5918e0
...
@@ -47,4 +47,23 @@ void ActivationImpl(const Context& dev_ctx,
...
@@ -47,4 +47,23 @@ void ActivationImpl(const Context& dev_ctx,
}
}
}
}
template
<
typename
T
,
typename
Context
>
void
PowKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
Scalar
&
factor
,
DenseTensor
*
out
)
{
PADDLE_ENFORCE_NOT_NULL
(
out
,
errors
::
NotFound
(
"Output Out should not be nullptr"
));
dev_ctx
.
template
Alloc
<
T
>(
out
);
auto
x_flatten
=
phi
::
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
&
x
,
"Input"
,
"X"
,
"Activation"
));
auto
out_flatten
=
phi
::
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
out
,
"Output"
,
"Out"
,
"Activation"
));
auto
*
place
=
dev_ctx
.
eigen_device
();
phi
::
funcs
::
PowFunctor
<
T
>
functor
;
auto
attrs
=
functor
.
GetAttrs
();
*
(
attrs
[
0
].
second
)
=
factor
.
to
<
float
>
();
functor
(
*
place
,
x_flatten
,
out_flatten
);
}
}
// namespace phi
}
// namespace phi
paddle/phi/ops/compat/activation_sig.cc
浏览文件 @
be5918e0
...
@@ -34,6 +34,13 @@ namespace phi {
...
@@ -34,6 +34,13 @@ namespace phi {
{GradVarName("X")}); \
{GradVarName("X")}); \
}
}
#define DEFINE_ACT_GRAD_NODEP_OP_ARGMAP(func_name, op_name, attrs) \
KernelSignature func_name##GradOpArgumentMapping( \
const ArgumentMappingContext& ctx) { \
return KernelSignature( \
op_name "_grad", {GradVarName("Out")}, {attrs}, {GradVarName("X")}); \
}
#define comma ,
#define comma ,
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP
(
Cos
,
"cos"
,
);
// NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP
(
Cos
,
"cos"
,
);
// NOLINT
...
@@ -61,6 +68,11 @@ DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Log, "log", ); // NOLINT
...
@@ -61,6 +68,11 @@ DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Log, "log", ); // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP
(
Log2
,
"log2"
,
);
// NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP
(
Log2
,
"log2"
,
);
// NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP
(
Log10
,
"log10"
,
);
// NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP
(
Log10
,
"log10"
,
);
// NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP
(
Log1p
,
"log1p"
,
);
// NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP
(
Log1p
,
"log1p"
,
);
// NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP
(
HardSwish
,
"hard_swish"
,
"threshold"
comma
"scale"
comma
"offset"
);
// NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP
(
Swish
,
"swish"
,
"beta"
);
// NOLINT
DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP
(
Relu
,
"relu"
,
);
// NOLINT
DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP
(
Relu
,
"relu"
,
);
// NOLINT
DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP
(
Tanh
,
"tanh"
,
);
// NOLINT
DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP
(
Tanh
,
"tanh"
,
);
// NOLINT
...
@@ -69,6 +81,10 @@ DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(HardSigmoid,
...
@@ -69,6 +81,10 @@ DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(HardSigmoid,
"hard_sigmoid"
,
"hard_sigmoid"
,
"slope"
comma
"offset"
);
// NOLINT
"slope"
comma
"offset"
);
// NOLINT
DEFINE_ACT_GRAD_NODEP_OP_ARGMAP
(
Round
,
"round"
,
);
// NOLINT
DEFINE_ACT_GRAD_NODEP_OP_ARGMAP
(
Floor
,
"floor"
,
);
// NOLINT
DEFINE_ACT_GRAD_NODEP_OP_ARGMAP
(
Ceil
,
"ceil"
,
);
// NOLINT
KernelSignature
ReluDoubleGradOpArgumentMapping
(
KernelSignature
ReluDoubleGradOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
const
ArgumentMappingContext
&
ctx
)
{
return
KernelSignature
(
"relu_double_grad"
,
{
"Out"
,
"DDX"
},
{},
{
"DDOut"
});
return
KernelSignature
(
"relu_double_grad"
,
{
"Out"
,
"DDX"
},
{},
{
"DDOut"
});
...
@@ -135,6 +151,26 @@ KernelSignature LogDoubleGradOpArgumentMapping(
...
@@ -135,6 +151,26 @@ KernelSignature LogDoubleGradOpArgumentMapping(
"log_double_grad"
,
{
"X"
,
"DOut"
,
"DDX"
},
{},
{
"DX"
,
"DDOut"
});
"log_double_grad"
,
{
"X"
,
"DOut"
,
"DDX"
},
{},
{
"DX"
,
"DDOut"
});
}
}
KernelSignature
PowOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
if
(
ctx
.
HasInput
(
"FactorTensor"
))
{
return
KernelSignature
(
"pow"
,
{
"X"
},
{
"FactorTensor"
},
{
"Out"
});
}
else
{
return
KernelSignature
(
"pow"
,
{
"X"
},
{
"factor"
},
{
"Out"
});
}
}
KernelSignature
PowGradOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
if
(
ctx
.
HasInput
(
"FactorTensor"
))
{
return
KernelSignature
(
"pow_grad"
,
{
"X"
,
GradVarName
(
"Out"
)},
{
"FactorTensor"
},
{
GradVarName
(
"X"
)});
}
else
{
return
KernelSignature
(
"pow_grad"
,
{
"X"
,
GradVarName
(
"Out"
)},
{
"factor"
},
{
GradVarName
(
"X"
)});
}
}
}
// namespace phi
}
// namespace phi
PD_REGISTER_BASE_KERNEL_NAME
(
relu_grad_grad
,
relu_double_grad
);
PD_REGISTER_BASE_KERNEL_NAME
(
relu_grad_grad
,
relu_double_grad
);
...
@@ -197,3 +233,11 @@ PD_REGISTER_ARG_MAPPING_FN(log_grad_grad, phi::LogDoubleGradOpArgumentMapping);
...
@@ -197,3 +233,11 @@ PD_REGISTER_ARG_MAPPING_FN(log_grad_grad, phi::LogDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN
(
log2_grad
,
phi
::
Log2GradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
log2_grad
,
phi
::
Log2GradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
log10_grad
,
phi
::
Log10GradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
log10_grad
,
phi
::
Log10GradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
log1p_grad
,
phi
::
Log1pGradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
log1p_grad
,
phi
::
Log1pGradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
hard_swish_grad
,
phi
::
HardSwishGradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
swish_grad
,
phi
::
SwishGradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
round_grad
,
phi
::
RoundGradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
floor_grad
,
phi
::
FloorGradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
ceil_grad
,
phi
::
CeilGradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
pow_grad
,
phi
::
PowGradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
pow
,
phi
::
PowOpArgumentMapping
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录