Commit 13c99434 (unverified), authored Mar 23, 2022 by YuanRisheng, committed by GitHub on Mar 23, 2022, in BaiXuePrincess/Paddle (forked from PaddlePaddle/Paddle).
[Phi]Move log/log2/log10/log1p Kernels to Phi (#40785)
* move activation
* fix bugs when run ce
Parent: b03ef424
Showing 13 changed files with 332 additions and 265 deletions (+332, -265).
paddle/fluid/framework/operator.cc (+9, -1)
paddle/fluid/operators/activation_op.cc (+3, -9)
paddle/fluid/operators/activation_op.h (+5, -146)
paddle/fluid/operators/activation_op.kps (+4, -108)
paddle/phi/kernels/activation_grad_kernel.h (+12, -0)
paddle/phi/kernels/activation_kernel.h (+4, -0)
paddle/phi/kernels/cpu/activation_grad_kernel.cc (+9, -0)
paddle/phi/kernels/cpu/activation_kernel.cc (+8, -0)
paddle/phi/kernels/funcs/activation_functor.h (+220, -0)
paddle/phi/kernels/gpu/activation_grad_kernel.cu (+15, -0)
paddle/phi/kernels/gpu/activation_kernel.cu (+9, -1)
paddle/phi/kernels/impl/activation_grad_impl.h (+18, -0)
paddle/phi/ops/compat/activation_sig.cc (+16, -0)
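Taken together, the hunks below follow the usual fluid-to-Phi migration pattern: the Eigen and CUDA functors for log/log2/log10/log1p move into paddle/phi/kernels/funcs/activation_functor.h, the Phi kernels are declared and registered through the activation macros (DECLARE_ACTIVATION_KERNEL, DEFINE_CPU_ACTIVATION_KERNEL / DEFINE_GPU_ACTIVATION_KERNEL, PD_REGISTER_ACTIVATION_KERNEL), and the old fluid-side functors and registrations are deleted, with USE_PHI_FUNCTOR keeping the fluid operators compiling against the phi functors. As a reading aid, a rough sketch of what the declaration macro boils down to for Log; the expanded signature is an assumption for illustration, not the literal macro expansion in Paddle:

// Hypothetical expansion of DECLARE_ACTIVATION_KERNEL(Log) from
// paddle/phi/kernels/activation_kernel.h (for orientation only; the real
// macro may differ in detail).
namespace phi {

template <typename T, typename Context>
void LogKernel(const Context& dev_ctx,  // CPU or GPU device context
               const DenseTensor& x,    // input tensor
               DenseTensor* out);       // output: elementwise log(x)

}  // namespace phi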
paddle/fluid/framework/operator.cc
...
@@ -1122,7 +1122,15 @@ static void CheckTensorNANOrInf(const std::string& op_type,
bool OperatorWithKernel::SupportsMKLDNN(
    const proto::VarType::Type data_type) const {
  auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
  auto op_kernel_iter = OperatorWithKernel::AllOpKernels().find(type_);
  if (op_kernel_iter == OperatorWithKernel::AllOpKernels().end()) {
    VLOG(6) << "Warning: " << type_
            << " don't find its MKLDNN Kernel in Fluid "
               "Registered Kernels. And We don't "
               "search its kernels in phi lib, "
               "SupportsMKLDNN() return false.";
    return false;
  }
  auto& op_kernels = op_kernel_iter->second;
  return std::any_of(op_kernels.begin(), op_kernels.end(),
                     [data_type](OpKernelMap::const_reference kern_pair) {
                       return platform::is_cpu_place(kern_pair.first.place_) &&
...

paddle/fluid/operators/activation_op.cc
...
@@ -1496,6 +1496,9 @@ REGISTER_ACTIVATION_OP(hard_sigmoid, HardSigmoid, HardSigmoidFunctor,
                       HardSigmoidGradFunctor);
REGISTER_ACTIVATION_OP(logsigmoid, LogSigmoid, LogSigmoidFunctor,
                       LogSigmoidGradFunctor);
REGISTER_ACTIVATION_OP(log2, Log2, Log2Functor, Log2GradFunctor);
REGISTER_ACTIVATION_OP(log10, Log10, Log10Functor, Log10GradFunctor);
REGISTER_ACTIVATION_OP(log1p, Log1p, Log1pFunctor, Log1pGradFunctor);

/* ==========================  sigmoid register  ============================= */
...
@@ -1867,15 +1870,6 @@ REGISTER_OPERATOR(
    ops::ActivationOpDoubleGrad<ops::LogGradGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer);

REGISTER_ACTIVATION_CPU_KERNEL(log, Log, LogFunctor, LogGradFunctor);

REGISTER_OP_CPU_KERNEL(
    log_grad_grad, ops::LogDoubleGradKernel<plat::CPUDeviceContext,
                                            ops::LogGradGradFunctor<float>>,
    ops::LogDoubleGradKernel<plat::CPUDeviceContext,
                             ops::LogGradGradFunctor<double>>,
    ops::LogDoubleGradKernel<plat::CPUDeviceContext,
                             ops::LogGradGradFunctor<plat::float16>>);

/* ========================================================================== */
/* ==========================  register checkpoint  ===========================*/
...

paddle/fluid/operators/activation_op.h
...
@@ -281,6 +281,11 @@ USE_PHI_DOUBLE_GRAD_FUNCTOR(Sigmoid)
USE_PHI_TRIPLE_GRAD_FUNCTOR(Sigmoid)
USE_PHI_FUNCTOR(LogSigmoid)
USE_PHI_FUNCTOR(HardSigmoid)
USE_PHI_FUNCTOR(Log)
USE_PHI_DOUBLE_GRAD_FUNCTOR(Log)
USE_PHI_FUNCTOR(Log2)
USE_PHI_FUNCTOR(Log10)
USE_PHI_FUNCTOR(Log1p)

template <typename T>
using ELUGradNegativeAlphaFunctor = phi::funcs::ELUGradNegativeAlphaFunctor<T>;
...
@@ -448,88 +453,6 @@ struct ReciprocalGradFunctor : public BaseActivationFunctor<T> {
  }
};

// log(x) = natural logarithm of x
template <typename T>
struct LogFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.log();
  }
};

template <typename T>
struct LogGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * (static_cast<T>(1) / x);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// log2(x) = logarithm to the base 2 of the elements of x
template <typename T>
struct Log2Functor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.log() / static_cast<T>(log(2));
  }
};

// the gradient of log2(x) is 1/(x*ln(2))
template <typename T>
struct Log2GradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * static_cast<T>(1) / (x * static_cast<T>(log(2)));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// log10(x) = logarithm to the base 10 of the elements of x
template <typename T>
struct Log10Functor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.log() / static_cast<T>(log(10));
  }
};

// the gradient of log10(x) is 1/(x*ln(10))
template <typename T>
struct Log10GradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * static_cast<T>(1) / (x * static_cast<T>(log(10)));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// log1p(x) = natural logarithm of x+1
template <typename T>
struct Log1pFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = (static_cast<T>(1) + x).log();
  }
};

template <typename T>
struct Log1pGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * (static_cast<T>(1) / (x + static_cast<T>(1)));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// square(x) = x^2
template <typename T>
struct SquareFunctor : public BaseActivationFunctor<T> {
...
@@ -1197,37 +1120,6 @@ class SquareDoubleGradKernel
  }
};

template <typename DeviceContext, typename Functor>
class LogDoubleGradKernel
    : public SquareDoubleGradKernel<DeviceContext, Functor> {};

template <typename DeviceContext, typename Functor>
class ELUDoubleGradKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;
  void Compute(const framework::ExecutionContext& ctx) const override {
    const framework::Tensor *X, *ddX, *dOut;
    X = ddX = dOut = nullptr;
    framework::Tensor *dX, *ddOut;
    dX = ddOut = nullptr;
    ExtractDoubleGradTensorWithInputDOut(ctx, &X, &ddX, &dX, &dOut, &ddOut);

    if (dX) dX->mutable_data<T>(X->dims(), ctx.GetPlace());
    if (ddOut) ddOut->mutable_data<T>(ctx.GetPlace());

    auto& place = ctx.template device_context<DeviceContext>();

    Functor functor;
    auto attrs = functor.GetAttrs();
    for (auto& attr : attrs) {
      *attr.second = ctx.Attr<float>(attr.first);
    }
    functor(place, X, ddX, ddOut, dOut, dX);
  }
};

template <typename DeviceContext, typename Functor>
class CELUDoubleGradKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
...
@@ -1522,36 +1414,6 @@ class LogitGradKernel : public framework::OpKernel<T> {
  }
};

template <typename T>
struct LogGradGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device>
  void operator()(const Device& dev, const framework::Tensor* X,
                  const framework::Tensor* ddX, framework::Tensor* ddOut,
                  const framework::Tensor* dOut, framework::Tensor* dX) const {
    auto* d = dev.eigen_device();
    auto ddx = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddX, "Input", "DDX", "LogGradGrad"));
    auto x = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(X, "Input", "X", "LogGradGrad"));
    // ddout = ddx / x; dx = -(dout / x) * (ddx / x)
    // calculate dx first, so ddout can inplace ddx
    if (dX) {
      auto dout = framework::EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dOut, "Output", "DOut", "LogGradGrad"));
      auto dx = framework::EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dX, "Output", "DX", "LogGradGrad"));
      dx.device(*d) = dout * static_cast<T>(-1) * ddx / (x * x);
    }
    if (ddOut) {
      auto ddout = framework::EigenVector<T>::Flatten(
          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "LogGradGrad"));
      ddout.device(*d) = ddx * static_cast<T>(1) / x;
    }
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

}  // namespace operators
}  // namespace paddle
...
@@ -1560,9 +1422,6 @@ struct LogGradGradFunctor : public BaseActivationFunctor<T> {
  __macro(floor, Floor, FloorFunctor, ZeroGradFunctor);                      \
  __macro(round, Round, RoundFunctor, ZeroGradFunctor);                      \
  __macro(reciprocal, Reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \
  __macro(log1p, Log1p, Log1pFunctor, Log1pGradFunctor);                     \
  __macro(log2, Log2, Log2Functor, Log2GradFunctor);                         \
  __macro(log10, Log10, Log10Functor, Log10GradFunctor);                     \
  __macro(soft_relu, SoftRelu, SoftReluFunctor, SoftReluGradFunctor);        \
  __macro(stanh, STanh, STanhFunctor, STanhGradFunctor);                     \
  __macro(softplus, Softplus, SoftplusFunctor, SoftplusGradFunctor);         \
...

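The USE_PHI_FUNCTOR(Log) / USE_PHI_DOUBLE_GRAD_FUNCTOR(Log) lines added in the first hunk above are what keep the remaining fluid code compiling once the functor definitions are deleted from this header: they point the old names at the phi::funcs implementations, in the same spirit as the explicit ELUGradNegativeAlphaFunctor alias visible in that hunk. A sketch of the presumed effect for Log; the exact macro body is an assumption, not the literal Paddle source:

// Presumed effect of USE_PHI_FUNCTOR(Log) inside paddle::operators; the real
// macro in activation_op.h may differ in detail.
template <typename T>
using LogFunctor = phi::funcs::LogFunctor<T>;          // forward functor now lives in phi
template <typename T>
using LogGradFunctor = phi::funcs::LogGradFunctor<T>;  // backward functor now lives in phi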
paddle/fluid/operators/activation_op.kps
...
...
@@ -131,27 +131,6 @@ struct CudaExpm1GradFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct CudaLogFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// log(x) = log(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(log(x));
}
};
template <typename T>
struct CudaLogGradFunctor : public BaseActivationFunctor<T> {
// dx = dout / x
__device__ __forceinline__ T operator()(const T dout, const T x) const {
return dout / x;
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaSquareFunctor : public BaseActivationFunctor<T> {
// square(x) = x * x
...
...
@@ -220,78 +199,6 @@ struct CudaRsqrtGradFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct CudaLog1pFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
// log1p(x) = log(1 + x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(log(one + x));
}
};
template <typename T>
struct CudaLog1pGradFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f);
// dx = dout / (1 + x)
__device__ __forceinline__ T operator()(const T dout, const T x) const {
return dout / (one + x);
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaLog2Functor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// log2(x) = log2(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(log2(x));
}
};
template <typename T>
struct CudaLog2GradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
T log_two = static_cast<T>(log(static_cast<MPType>(2.0f)));
// dx = dout / (x * log(2))
__device__ __forceinline__ T operator()(const T dout, const T x) const {
return dout / (x * log_two);
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaLog10Functor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// log10(x) = log10(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(log10(x));
}
};
template <typename T>
struct CudaLog10GradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
T log_ten = static_cast<T>(log(static_cast<MPType>(10.0f)));
// dx = dout / (x * log(10))
__device__ __forceinline__ T operator()(const T dout, const T x) const {
return dout / (x * log_ten);
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaSoftReluFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
...
...
@@ -773,6 +680,10 @@ USE_PHI_FUNCTOR(CudaELU)
USE_PHI_FUNCTOR(CudaSigmoid)
USE_PHI_FUNCTOR(CudaLogSigmoid)
USE_PHI_FUNCTOR(CudaHardSigmoid)
USE_PHI_FUNCTOR(CudaLog)
USE_PHI_FUNCTOR(CudaLog2)
USE_PHI_FUNCTOR(CudaLog10)
USE_PHI_FUNCTOR(CudaLog1p)
template <typename T>
using CudaELUGradNegativeAlphaFunctor =
...
...
@@ -975,18 +886,6 @@ REGISTER_OP_CUDA_KERNEL(
ops::CudaExpm1GradFunctor<plat::float16>>);
/* ========================================================================== */
/* ========================== Log register ==================================*/
REGISTER_ACTIVATION_CUDA_KERNEL(log, Log, CudaLogFunctor, CudaLogGradFunctor);
REGISTER_OP_CUDA_KERNEL(
log_grad_grad, ops::LogDoubleGradKernel<plat::CUDADeviceContext,
ops::LogGradGradFunctor<float>>,
ops::LogDoubleGradKernel<plat::CUDADeviceContext,
ops::LogGradGradFunctor<double>>,
ops::LogDoubleGradKernel<plat::CUDADeviceContext,
ops::LogGradGradFunctor<plat::float16>>);
/* ========================================================================== */
#define FOR_EACH_ACTIVATION_CUDA_OP(__macro) \
__macro(softshrink, SoftShrink, CudaSoftShrinkFunctor, \
CudaSoftShrinkGradFunctor); \
...
...
@@ -995,9 +894,6 @@ REGISTER_OP_CUDA_KERNEL(
__macro(round, Round, CudaRoundFunctor, CudaZeroGradFunctor); \
__macro(reciprocal, Reciprocal, CudaReciprocalFunctor, \
CudaReciprocalGradFunctor); \
__macro(log1p, Log1p, CudaLog1pFunctor, CudaLog1pGradFunctor); \
__macro(log2, Log2, CudaLog2Functor, CudaLog2GradFunctor); \
__macro(log10, Log10, CudaLog10Functor, CudaLog10GradFunctor); \
__macro(soft_relu, SoftRelu, CudaSoftReluFunctor, CudaSoftReluGradFunctor); \
__macro(stanh, STanh, CudaSTanhFunctor, CudaSTanhGradFunctor); \
__macro(softplus, Softplus, CudaSoftplusFunctor, CudaSoftplusGradFunctor); \
...
...
paddle/phi/kernels/activation_grad_kernel.h
...
@@ -135,6 +135,14 @@ void SigmoidTripleGradKernel(const Context& dev_ctx,
                             DenseTensor* d_dout,
                             DenseTensor* d_ddx);

template <typename T, typename Context>
void LogDoubleGradKernel(const Context& dev_ctx,
                         const DenseTensor& x,
                         const DenseTensor& dout,
                         const DenseTensor& ddx,
                         DenseTensor* dx,
                         DenseTensor* ddout);

DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Cos);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Tan);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Acos);
...
@@ -149,6 +157,10 @@ DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Atanh);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(TanhShrink);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Silu);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(LogSigmoid);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Log);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Log2);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Log10);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Log1p);

DECLARE_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh);
...

paddle/phi/kernels/activation_kernel.h
...
@@ -56,6 +56,10 @@ DECLARE_ACTIVATION_KERNEL(TanhShrink)
DECLARE_ACTIVATION_KERNEL(Silu)
DECLARE_ACTIVATION_KERNEL(Sigmoid)
DECLARE_ACTIVATION_KERNEL(LogSigmoid)
DECLARE_ACTIVATION_KERNEL(Log)
DECLARE_ACTIVATION_KERNEL(Log2)
DECLARE_ACTIVATION_KERNEL(Log10)
DECLARE_ACTIVATION_KERNEL(Log1p)

DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(LeakyRelu, alpha)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, threshold)
...

paddle/phi/kernels/cpu/activation_grad_kernel.cc
...
@@ -121,6 +121,10 @@ DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Atanh, AtanhGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(TanhShrink, TanhShrinkGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Silu, SiluGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(LogSigmoid, LogSigmoidGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Log, LogGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Log2, Log2GradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Log10, Log10GradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Log1p, Log1pGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu, ReluGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh, TanhGradFunctor);
...
@@ -233,3 +237,8 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_double_grad, SigmoidDoubleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_triple_grad, SigmoidTripleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_sigmoid_grad, HardSigmoidGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(logsigmoid_grad, LogSigmoidGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(log_grad, LogGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(log2_grad, Log2GradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(log10_grad, Log10GradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(log1p_grad, Log1pGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(log_double_grad, LogDoubleGradKernel)

paddle/phi/kernels/cpu/activation_kernel.cc
...
@@ -74,6 +74,10 @@ DEFINE_CPU_ACTIVATION_KERNEL(TanhShrink, TanhShrinkFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Silu, SiluFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Sigmoid, SigmoidFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(LogSigmoid, LogSigmoidFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Log, LogFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Log2, Log2Functor)
DEFINE_CPU_ACTIVATION_KERNEL(Log10, Log10Functor)
DEFINE_CPU_ACTIVATION_KERNEL(Log1p, Log1pFunctor)

DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, LeakyReluFunctor, alpha)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
...
@@ -118,3 +122,7 @@ PD_REGISTER_ACTIVATION_KERNEL(silu, SiluKernel)
PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(logsigmoid, LogSigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_sigmoid, HardSigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(log, LogKernel)
PD_REGISTER_ACTIVATION_KERNEL(log2, Log2Kernel)
PD_REGISTER_ACTIVATION_KERNEL(log10, Log10Kernel)
PD_REGISTER_ACTIVATION_KERNEL(log1p, Log1pKernel)

paddle/phi/kernels/funcs/activation_functor.h
...
@@ -1223,6 +1223,133 @@ struct HardSigmoidGradFunctor : public BaseActivationFunctor<T> {
  }
};

// log(x) = natural logarithm of x
template <typename T>
struct LogFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.log();
  }
};

template <typename T>
struct LogGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * (static_cast<T>(1) / x);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// log2(x) = logarithm to the base 2 of the elements of x
template <typename T>
struct Log2Functor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.log() / static_cast<T>(log(2));
  }
};

// the gradient of log2(x) is 1/(x*ln(2))
template <typename T>
struct Log2GradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * static_cast<T>(1) / (x * static_cast<T>(log(2)));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// log10(x) = logarithm to the base 10 of the elements of x
template <typename T>
struct Log10Functor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.log() / static_cast<T>(log(10));
  }
};

// the gradient of log10(x) is 1/(x*ln(10))
template <typename T>
struct Log10GradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * static_cast<T>(1) / (x * static_cast<T>(log(10)));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// log1p(x) = natural logarithm of x+1
template <typename T>
struct Log1pFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = (static_cast<T>(1) + x).log();
  }
};

template <typename T>
struct Log1pGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * (static_cast<T>(1) / (x + static_cast<T>(1)));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct LogGradGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device>
  void operator()(const Device& dev, const DenseTensor* X,
                  const DenseTensor* ddX, DenseTensor* ddOut,
                  const DenseTensor* dOut, DenseTensor* dX) const {
    auto* d = dev.eigen_device();
    auto ddx = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddX, "Input", "DDX", "LogGradGrad"));
    auto x = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(X, "Input", "X", "LogGradGrad"));
    // ddout = ddx / x; dx = -(dout / x) * (ddx / x)
    // calculate dx first, so ddout can inplace ddx
    if (dX) {
      auto dout = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dOut, "Output", "DOut", "LogGradGrad"));
      auto dx = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dX, "Output", "DX", "LogGradGrad"));
      dx.device(*d) = dout * static_cast<T>(-1) * ddx / (x * x);
    }
    if (ddOut) {
      auto ddout = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "LogGradGrad"));
      ddout.device(*d) = ddx * static_cast<T>(1) / x;
    }
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
template <typename T>
struct CudaReluFunctor : public BaseActivationFunctor<T> {
...
@@ -1970,6 +2097,99 @@ struct CudaHardSigmoidGradFunctor : public BaseActivationFunctor<T> {
  }
};

template <typename T>
struct CudaLogFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // log(x) = log(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(log(x));
  }
};

template <typename T>
struct CudaLogGradFunctor : public BaseActivationFunctor<T> {
  // dx = dout / x
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    return dout / x;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaLog1pFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // log1p(x) = log(1 + x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(log(one + x));
  }
};

template <typename T>
struct CudaLog1pGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // dx = dout / (1 + x)
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    return dout / (one + x);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaLog2Functor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // log2(x) = log2(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(log2(x));
  }
};

template <typename T>
struct CudaLog2GradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  T log_two = static_cast<T>(log(static_cast<MPType>(2.0f)));

  // dx = dout / (x * log(2))
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    return dout / (x * log_two);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaLog10Functor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // log10(x) = log10(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(log10(x));
  }
};

template <typename T>
struct CudaLog10GradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  T log_ten = static_cast<T>(log(static_cast<MPType>(10.0f)));

  // dx = dout / (x * log(10))
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    return dout / (x * log_ten);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

#endif

}  // namespace funcs
...

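For reference, the derivatives these functors implement (they match the in-code comments above) are, in LaTeX:

\[
\frac{d}{dx}\ln x = \frac{1}{x},\qquad
\frac{d}{dx}\log_2 x = \frac{1}{x\ln 2},\qquad
\frac{d}{dx}\log_{10} x = \frac{1}{x\ln 10},\qquad
\frac{d}{dx}\ln(1+x) = \frac{1}{1+x},
\]

and for the double grad of log implemented by LogGradGradFunctor (with ddx the incoming second-order gradient):

\[
ddout = \frac{ddx}{x},\qquad
dx = -\,dout\cdot\frac{ddx}{x^{2}}.
\]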
paddle/phi/kernels/gpu/activation_grad_kernel.cu
...
@@ -177,6 +177,10 @@ DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Atanh, CudaAtanhGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(TanhShrink, CudaTanhShrinkGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Silu, CudaSiluGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(LogSigmoid, CudaLogSigmoidGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Log, CudaLogGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Log2, CudaLog2GradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Log10, CudaLog10GradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Log1p, CudaLog1pGradFunctor);
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(LeakyRelu,
                                               CudaLeakyReluGradFunctor,
...
@@ -300,3 +304,14 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_double_grad, SigmoidDoubleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_triple_grad, SigmoidTripleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_sigmoid_grad, HardSigmoidGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(logsigmoid_grad, LogSigmoidGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(log_grad, LogGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(log2_grad, Log2GradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(log10_grad, Log10GradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(log1p_grad, Log1pGradKernel)

PD_REGISTER_KERNEL(log_double_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::LogDoubleGradKernel,
                   float,
                   double,
                   phi::dtype::float16) {}

paddle/phi/kernels/gpu/activation_kernel.cu
...
@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/impl/activation_grad_impl.h"
#include "paddle/phi/kernels/impl/activation_impl.h"

#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
...
@@ -93,6 +93,10 @@ DEFINE_GPU_ACTIVATION_KERNEL(TanhShrink, CudaTanhShrinkFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Silu, CudaSiluFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Sigmoid, CudaSigmoidFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(LogSigmoid, CudaLogSigmoidFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Log, CudaLogFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Log2, CudaLog2Functor)
DEFINE_GPU_ACTIVATION_KERNEL(Log10, CudaLog10Functor)
DEFINE_GPU_ACTIVATION_KERNEL(Log1p, CudaLog1pFunctor)

DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, CudaLeakyReluFunctor, alpha)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
...
@@ -164,3 +168,7 @@ PD_REGISTER_ACTIVATION_KERNEL(silu, SiluKernel)
PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(logsigmoid, LogSigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_sigmoid, HardSigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(log, LogKernel)
PD_REGISTER_ACTIVATION_KERNEL(log2, Log2Kernel)
PD_REGISTER_ACTIVATION_KERNEL(log10, Log10Kernel)
PD_REGISTER_ACTIVATION_KERNEL(log1p, Log1pKernel)

paddle/phi/kernels/impl/activation_grad_impl.h
...
@@ -275,4 +275,22 @@ void SigmoidTripleGradKernel(const Context& dev_ctx,
                             d_ddx);
}

template <typename T, typename Context>
void LogDoubleGradKernel(const Context& dev_ctx,
                         const DenseTensor& x,
                         const DenseTensor& dout,
                         const DenseTensor& ddx,
                         DenseTensor* dx,
                         DenseTensor* ddout) {
  if (dx) {
    dx->Resize(x.dims());
    dev_ctx.template Alloc<T>(dx);
  }
  if (ddout) {
    dev_ctx.template Alloc<T>(ddout);
  }
  funcs::LogGradGradFunctor<T> functor;
  functor(dev_ctx, &x, &ddx, ddout, &dout, dx);
}

}  // namespace phi

paddle/phi/ops/compat/activation_sig.cc
...
@@ -57,6 +57,10 @@ DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(HardShrink, "hard_shrink", "threshold");
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(TanhShrink, "tanh_shrink", );   // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Silu, "silu", );                // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(LogSigmoid, "logsigmoid", );    // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Log, "log", );                  // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Log2, "log2", );                // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Log10, "log10", );              // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Log1p, "log1p", );              // NOLINT
DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Relu, "relu", );              // NOLINT
DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Tanh, "tanh", );              // NOLINT
...
@@ -125,6 +129,12 @@ KernelSignature EluDoubleGradOpArgumentMapping(
      "elu_double_grad", {"X", "DOut", "DDX"}, {"alpha"}, {"DX", "DDOut"});
}

KernelSignature LogDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "log_double_grad", {"X", "DOut", "DDX"}, {}, {"DX", "DDOut"});
}

}  // namespace phi

PD_REGISTER_BASE_KERNEL_NAME(relu_grad_grad, relu_double_grad);
...
@@ -134,6 +144,7 @@ PD_REGISTER_BASE_KERNEL_NAME(softshrink, soft_shrink);
PD_REGISTER_BASE_KERNEL_NAME(softshrink_grad, soft_shrink_grad);
PD_REGISTER_BASE_KERNEL_NAME(elu_grad_grad, elu_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(sigmoid_grad_grad, sigmoid_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(log_grad_grad, log_double_grad);

PD_REGISTER_ARG_MAPPING_FN(cos_grad, phi::CosGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(tan_grad, phi::TanGradOpArgumentMapping);
...
@@ -181,3 +192,8 @@ PD_REGISTER_ARG_MAPPING_FN(logsigmoid_grad,
                           phi::LogSigmoidGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(hard_sigmoid_grad,
                           phi::HardSigmoidGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(log_grad, phi::LogGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(log_grad_grad, phi::LogDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(log2_grad, phi::Log2GradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(log10_grad, phi::Log10GradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(log1p_grad, phi::Log1pGradOpArgumentMapping);