Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
eca8dcc7
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
eca8dcc7
编写于
4月 27, 2021
作者:
Z
Zhang Zheng
提交者:
GitHub
4月 27, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Unify the implementation of activation operation (#32348)
上级
6f6e159a
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
759 addition
and
357 deletion
+759
-357
paddle/fluid/operators/activation_op.cu
paddle/fluid/operators/activation_op.cu
+757
-355
paddle/fluid/operators/activation_op.h
paddle/fluid/operators/activation_op.h
+2
-2
未找到文件。
paddle/fluid/operators/activation_op.cu
浏览文件 @
eca8dcc7
...
...
@@ -10,382 +10,719 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/math/math_cuda_utils.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/float16.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
float16
=
paddle
::
platform
::
float16
;
template
<
typename
T
>
struct
CudaReluFunctor
:
public
BaseActivationFunctor
<
T
>
{
T
zero
=
static_cast
<
T
>
(
0.0
f
);
// relu(x) = max(x, 0)
// Inputs: args[0], the input x
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
return
args
[
0
]
>
zero
?
args
[
0
]
:
zero
;
}
};
template
<
typename
T
>
struct
CudaVecType
{
using
type
=
T
;
static
constexpr
int
vecsize
=
1
;
struct
CudaReluGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
T
zero
=
static_cast
<
T
>
(
0.0
f
);
// dx = dout * (out > 0)
// Inputs: args[0], the input dout
// args[1], the input out
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
return
args
[
1
]
>
zero
?
args
[
0
]
:
zero
;
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDepOut
;
}
};
template
<
>
struct
CudaVecType
<
platform
::
float16
>
{
using
type
=
__half2
;
static
constexpr
int
vecsize
=
2
;
template
<
typename
T
>
struct
CudaLeakyReluFunctor
:
public
BaseActivationFunctor
<
T
>
{
T
zero
=
static_cast
<
T
>
(
0.0
f
);
float
alpha
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"alpha"
,
&
alpha
}};
}
// leakyrelu(x) = x > 0 ? x : alpha * x
// Inputs: args[0], the input x
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
return
args
[
0
]
>
zero
?
args
[
0
]
:
static_cast
<
T
>
(
alpha
)
*
args
[
0
];
}
};
template
<
>
struct
CudaVecType
<
float
>
{
using
type
=
float4
;
static
constexpr
int
vecsize
=
4
;
template
<
typename
T
>
struct
CudaLeakyReluGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
T
zero
=
static_cast
<
T
>
(
0.0
f
);
float
alpha
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"alpha"
,
&
alpha
}};
}
// dx = dout * (x > 0 ? 1 : alpha)
// Inputs: args[0], the input dout
// args[1], the input x
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
return
args
[
1
]
>
zero
?
args
[
0
]
:
static_cast
<
T
>
(
alpha
)
*
args
[
0
];
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDepX
;
}
};
template
<
typename
T
>
class
BaseGPUFunctor
{
public:
using
ELEMENT_TYPE
=
T
;
struct
CudaSigmoidFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
MPType
one
=
static_cast
<
MPType
>
(
1.0
f
);
// sigmoid(x) = 1 / (1 + exp(-x))
// Inputs: args[0], the input x
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
MPType
x
=
static_cast
<
MPType
>
(
args
[
0
]);
return
static_cast
<
T
>
(
one
/
(
one
+
exp
(
-
x
)));
}
};
using
AttrPair
=
std
::
vector
<
std
::
pair
<
const
char
*
,
float
*>>
;
template
<
typename
T
>
struct
CudaSigmoidGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
T
one
=
static_cast
<
T
>
(
1.0
f
);
// dx = dout * out * (1 - out)
// Inputs: args[0], the input dout
// args[1], the input out
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
return
args
[
0
]
*
args
[
1
]
*
(
one
-
args
[
1
]);
}
AttrPair
GetAttrs
()
{
return
AttrPair
()
;
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDepOut
;
}
};
/* ========================================================================== */
template
<
typename
T
>
struct
CudaSiluFunctor
:
public
BaseActivationFunctor
<
T
>
{
// MPType means Compute Type
using
MPType
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
MPType
one
=
static_cast
<
MPType
>
(
1.0
f
);
// silu(x) = x / (1 + exp(-x))
// Inputs: args[0], the input x
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
MPType
x
=
static_cast
<
MPType
>
(
args
[
0
]);
return
static_cast
<
T
>
(
x
/
(
one
+
exp
(
-
x
)));
}
};
/* =========================== relu forward ============================ */
template
<
typename
T
>
class
ReluGPUFunctor
:
public
BaseGPUFunctor
<
T
>
{
private:
T
zero_
;
struct
CudaSiluGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
MPType
one
=
static_cast
<
MPType
>
(
1.0
f
);
// dx = dout * (1 + exp(-x) + x * exp(-x) / (1 + exp(-x))^2)
// Inputs: args[0], the input dout
// args[1], the input x
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
MPType
dout
=
static_cast
<
MPType
>
(
args
[
0
]);
MPType
x
=
static_cast
<
MPType
>
(
args
[
1
]);
MPType
temp
=
one
/
(
one
+
exp
(
-
x
));
return
static_cast
<
T
>
(
dout
*
(
temp
*
(
one
+
x
*
(
one
-
temp
))));
}
public:
ReluGPUFunctor
()
{
zero_
=
static_cast
<
T
>
(
0.0
f
);
}
// for relu forward when T is double
__device__
__forceinline__
typename
CudaVecType
<
T
>::
type
Compute
(
const
typename
CudaVecType
<
T
>::
type
in
)
{
// relu forward : out = max(x, 0)
return
in
>
zero_
?
in
:
zero_
;
}
// when num % vecsize != 0 this func will be used
__device__
__forceinline__
T
ComputeRemainder
(
const
T
in
)
{
// relu forward : out = max(x, 0)
return
in
>
zero_
?
in
:
zero_
;
}
};
template
<
>
__device__
__forceinline__
CudaVecType
<
float
>::
type
ReluGPUFunctor
<
float
>::
Compute
(
const
CudaVecType
<
float
>::
type
in
)
{
// relu forward : out = max(in, 0)
return
make_float4
((
in
.
x
>
zero_
)
*
(
in
.
x
),
(
in
.
y
>
zero_
)
*
(
in
.
y
),
(
in
.
z
>
zero_
)
*
(
in
.
z
),
(
in
.
w
>
zero_
)
*
(
in
.
w
));
}
template
<
>
__device__
__forceinline__
CudaVecType
<
float16
>::
type
ReluGPUFunctor
<
float16
>::
Compute
(
const
CudaVecType
<
float16
>::
type
in
)
{
// relu forward : out = max(in, 0)
#ifdef __HIPCC__ || CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
const
half2
kzero
=
__float2half2_rn
(
0.0
f
);
return
__hmul2
(
__hgt2
(
in
,
kzero
),
in
);
#else
const
float2
xx
=
__half22float2
(
in
);
return
__floats2half2_rn
((
xx
.
x
>
0.0
f
)
*
static_cast
<
float
>
(
xx
.
x
),
(
xx
.
y
>
0.0
f
)
*
static_cast
<
float
>
(
xx
.
y
));
#endif
}
/* ========================================================================== */
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDepX
;
}
};
/* =========================== relu backward ============================
*/
template
<
typename
T
>
struct
CudaLogSigmoidFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
MPType
zero
=
static_cast
<
MPType
>
(
0.0
f
);
// logsigmoid(x) = log(1 / (1 + exp(-x)))
// For numerical stability,
// logsigmoid(x) =
// - (max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
// Inputs: args[0], the input x
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
MPType
x
=
static_cast
<
MPType
>
(
args
[
0
]);
MPType
temp
=
x
>
zero
?
zero
:
-
x
;
return
static_cast
<
T
>
(
-
temp
-
log
(
exp
(
-
temp
)
+
exp
(
-
x
-
temp
)));
}
};
template
<
typename
T
>
class
ReluGradGPUFunctor
:
public
BaseGPUFunctor
<
T
>
{
private:
T
zero_
;
struct
CudaLogSigmoidGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
MPType
zero
=
static_cast
<
MPType
>
(
0.0
f
);
// dx = dout * exp(-x) / (1 + exp(-x))
// For numerical stability:
// dx = dout * exp(-x - max(-x, 0)) / (exp(-max(-x, 0)) + exp(-x - max(-x,
// 0)))
// Inputs: args[0], the input dout
// args[1], the input x
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
MPType
dout
=
static_cast
<
MPType
>
(
args
[
0
]);
MPType
x
=
static_cast
<
MPType
>
(
args
[
1
]);
MPType
temp1
=
x
>
zero
?
zero
:
-
x
;
MPType
temp2
=
exp
(
-
x
-
temp1
);
return
static_cast
<
T
>
(
dout
*
(
temp2
/
(
exp
(
-
temp1
)
+
temp2
)));
}
public:
ReluGradGPUFunctor
()
{
zero_
=
static_cast
<
T
>
(
0.0
f
);
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDepX
;
}
};
template
<
typename
T
>
struct
CudaAtanFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
// atan(x) = atan(x)
// Inputs: args[0], the input x
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
MPType
x
=
static_cast
<
MPType
>
(
args
[
0
]);
return
static_cast
<
T
>
(
atan
(
x
));
}
};
template
<
typename
T
>
struct
CudaAtanGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
T
one
=
static_cast
<
T
>
(
1.0
f
);
// dx = dout / (1 + x^2)
// Inputs: args[0], the input dout
// args[1], the input x
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
return
args
[
0
]
/
(
one
+
args
[
1
]
*
args
[
1
]);
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDepX
;
}
};
template
<
typename
T
>
struct
CudaSoftShrinkFunctor
:
public
BaseActivationFunctor
<
T
>
{
float
lambda
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"lambda"
,
&
lambda
}};
}
// softshrink(x) = x - lambda, if x > lambda;
// x + lambda, if x < -lambda;
// 0, otherwise.
// Inputs: args[0], the input x
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
T
x
=
args
[
0
];
T
l
=
static_cast
<
T
>
(
lambda
);
T
temp1
=
static_cast
<
T
>
(
x
>
l
);
T
temp2
=
static_cast
<
T
>
(
x
<
-
l
);
return
temp1
*
(
x
-
l
)
+
temp2
*
(
x
+
l
);
}
};
template
<
typename
T
>
struct
CudaSoftShrinkGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
T
zero
=
static_cast
<
T
>
(
0.0
f
);
float
lambda
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"lambda"
,
&
lambda
}};
}
// dx = dout, if x > lambda or x < -lambda else 0
// Inputs: args[0], the input dout
// args[1], the input x
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
T
x
=
args
[
1
];
T
l
=
static_cast
<
T
>
(
lambda
);
return
(
x
>=
-
l
&&
x
<=
l
)
?
zero
:
args
[
0
];
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDepX
;
}
};
template
<
typename
T
>
struct
CudaCeilFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
// ceil(x) = ceil(x)
// Inputs: args[0], the input x
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
MPType
x
=
static_cast
<
MPType
>
(
args
[
0
]);
return
static_cast
<
T
>
(
ceil
(
x
));
}
};
template
<
typename
T
>
struct
CudaFloorFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
// floor(x) = floor(x)
// Inputs: args[0], the input x
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
MPType
x
=
static_cast
<
MPType
>
(
args
[
0
]);
return
static_cast
<
T
>
(
floor
(
x
));
}
};
template
<
typename
T
>
struct
CudaRoundFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
// round(x) = round(x)
// Inputs: args[0], the input x
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
MPType
x
=
static_cast
<
MPType
>
(
args
[
0
]);
return
static_cast
<
T
>
(
round
(
x
));
}
};
// grad functor for ceil, floor and round
template
<
typename
T
>
struct
CudaZeroGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
return
static_cast
<
T
>
(
0.0
f
);
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kNoDeps
;
}
};
template
<
typename
T
>
struct
CudaCosFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
// cos(x) = cos(x)
// Inputs: args[0], the input x
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
MPType
x
=
static_cast
<
MPType
>
(
args
[
0
]);
return
static_cast
<
T
>
(
cos
(
x
));
}
};
template
<
typename
T
>
struct
CudaCosGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
// dx = dout * (-sin(x))
// Inputs: args[0], the input dout
// args[1], the input x
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
MPType
dout
=
static_cast
<
MPType
>
(
args
[
0
]);
MPType
x
=
static_cast
<
MPType
>
(
args
[
1
]);
return
static_cast
<
T
>
(
-
dout
*
sin
(
x
));
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDepX
;
}
};
template
<
typename
T
>
struct
CudaSinFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
// sin(x) = sin(x)
// Inputs: args[0], the input x
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
MPType
x
=
static_cast
<
MPType
>
(
args
[
0
]);
return
static_cast
<
T
>
(
sin
(
x
));
}
};
template
<
typename
T
>
struct
CudaSinGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
// dx = dout * cos(x)
// Inputs: args[0], the input dout
// args[1], the input x
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
MPType
dout
=
static_cast
<
MPType
>
(
args
[
0
]);
MPType
x
=
static_cast
<
MPType
>
(
args
[
1
]);
return
static_cast
<
T
>
(
dout
*
cos
(
x
));
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDepX
;
}
};
template
<
typename
T
>
struct
CudaTanFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
// tan(x) = tan(x)
// Inputs: args[0], the input x
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
MPType
x
=
static_cast
<
MPType
>
(
args
[
0
]);
return
static_cast
<
T
>
(
tan
(
x
));
}
};
template
<
typename
T
>
struct
CudaTanGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
// dx = dout / cos(x)^2
// Inputs: args[0], the input dout
// args[1], the input x
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
MPType
dout
=
static_cast
<
MPType
>
(
args
[
0
]);
MPType
x
=
static_cast
<
MPType
>
(
args
[
1
]);
return
static_cast
<
T
>
(
dout
/
(
cos
(
x
)
*
cos
(
x
)));
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDepX
;
}
};
// for relu backward when T is double
__device__
__forceinline__
typename
CudaVecType
<
T
>::
type
Compute
(
const
typename
CudaVecType
<
T
>::
type
out
,
const
typename
CudaVecType
<
T
>::
type
dout
)
{
return
out
>
zero_
?
dout
:
zero_
;
template
<
typename
T
>
struct
CudaAsinFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
// asin(x) = asin(x)
// Inputs: args[0], the input x
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
MPType
x
=
static_cast
<
MPType
>
(
args
[
0
]);
return
static_cast
<
T
>
(
asin
(
x
));
}
};
// when num % vecsize != 0 this func will be used
__device__
__forceinline__
T
ComputeRemainder
(
const
T
out
,
const
T
dout
)
{
// relu backward : dx = out > 0 ? dout : 0
return
out
>
zero_
?
dout
:
zero_
;
template
<
typename
T
>
struct
CudaAsinGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
MPType
one
=
static_cast
<
MPType
>
(
1.0
f
);
// dx = dout / sqrt(1 - x^2)
// Inputs: args[0], the input dout
// args[1], the input x
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
MPType
dout
=
static_cast
<
MPType
>
(
args
[
0
]);
MPType
x
=
static_cast
<
MPType
>
(
args
[
1
]);
return
static_cast
<
T
>
(
dout
/
sqrt
(
one
-
x
*
x
));
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDepX
;
}
};
template
<
typename
T
>
struct
CudaAcosFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
// acos(x) = acos(x)
// Inputs: args[0], the input x
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
MPType
x
=
static_cast
<
MPType
>
(
args
[
0
]);
return
static_cast
<
T
>
(
acos
(
x
));
}
};
template
<
typename
T
>
struct
CudaAcosGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
MPType
one
=
static_cast
<
MPType
>
(
1.0
f
);
// dx = -dout / sqrt(1 - x^2)
// Inputs: args[0], the input dout
// args[1], the input x
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
MPType
dout
=
static_cast
<
MPType
>
(
args
[
0
]);
MPType
x
=
static_cast
<
MPType
>
(
args
[
1
]);
return
static_cast
<
T
>
(
-
dout
/
sqrt
(
one
-
x
*
x
));
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDepX
;
}
};
template
<
typename
T
>
struct
CudaCoshFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
// cosh(x) = cosh(x)
// Inputs: args[0], the input x
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
MPType
x
=
static_cast
<
MPType
>
(
args
[
0
]);
return
static_cast
<
T
>
(
cosh
(
x
));
}
};
template
<
typename
T
>
struct
CudaCoshGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
// dx = dout * sinh(x)
// Inputs: args[0], the input dout
// args[1], the input x
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
MPType
dout
=
static_cast
<
MPType
>
(
args
[
0
]);
MPType
x
=
static_cast
<
MPType
>
(
args
[
1
]);
return
static_cast
<
T
>
(
dout
*
sinh
(
x
));
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDepX
;
}
};
template
<
typename
T
>
struct
CudaSinhFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
// sinh(x) = sinh(x)
// Inputs: args[0], the input x
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
MPType
x
=
static_cast
<
MPType
>
(
args
[
0
]);
return
static_cast
<
T
>
(
sinh
(
x
));
}
};
template
<
typename
T
>
struct
CudaSinhGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
// dx = dout * cosh(x)
// Inputs: args[0], the input dout
// args[1], the input x
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
MPType
dout
=
static_cast
<
MPType
>
(
args
[
0
]);
MPType
x
=
static_cast
<
MPType
>
(
args
[
1
]);
return
static_cast
<
T
>
(
dout
*
cosh
(
x
));
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDepX
;
}
};
template
<
typename
T
>
struct
CudaTanhFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
// tanh(x) = tanh(x)
// Inputs: args[0], the input x
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
MPType
x
=
static_cast
<
MPType
>
(
args
[
0
]);
return
static_cast
<
T
>
(
tanh
(
x
));
}
};
template
<
typename
T
>
struct
CudaTanhGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
T
one
=
static_cast
<
T
>
(
1.0
f
);
// dx = dout * (1 - out^2)
// Inputs: args[0], the input dout
// args[1], the input out
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
T
dout
=
static_cast
<
T
>
(
args
[
0
]);
T
out
=
static_cast
<
T
>
(
args
[
1
]);
return
dout
*
(
one
-
out
*
out
);
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDepOut
;
}
};
template
<
>
__device__
__forceinline__
CudaVecType
<
float
>::
type
ReluGradGPUFunctor
<
float
>::
Compute
(
const
CudaVecType
<
float
>::
type
out
,
const
CudaVecType
<
float
>::
type
dout
)
{
// relu backward : dx = out > 0 ? dout : 0;
return
make_float4
((
out
.
x
>
zero_
)
*
(
dout
.
x
),
(
out
.
y
>
zero_
)
*
(
dout
.
y
),
(
out
.
z
>
zero_
)
*
(
dout
.
z
),
(
out
.
w
>
zero_
)
*
(
dout
.
w
));
}
template
<
>
__device__
__forceinline__
CudaVecType
<
float16
>::
type
ReluGradGPUFunctor
<
float16
>::
Compute
(
const
CudaVecType
<
float16
>::
type
out
,
const
CudaVecType
<
float16
>::
type
dout
)
{
// relu backward : dx = out > 0 ? dout : 0;
#ifdef __HIPCC__ || CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
const
half2
kzero
=
__float2half2_rn
(
0.0
f
);
return
__hmul2
(
__hgt2
(
out
,
kzero
),
dout
);
#else
const
float2
xx
=
__half22float2
(
out
);
const
float2
yy
=
__half22float2
(
dout
);
return
__floats2half2_rn
((
xx
.
x
>
0.0
f
)
*
static_cast
<
float
>
(
yy
.
x
),
(
xx
.
y
>
0.0
f
)
*
static_cast
<
float
>
(
yy
.
y
));
#endif
}
template
<
typename
T
>
struct
CudaReciprocalFunctor
:
public
BaseActivationFunctor
<
T
>
{
T
one
=
static_cast
<
T
>
(
1.0
f
);
// reciprocal(x) = 1 / x
// Inputs: args[0], the input x
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
return
one
/
args
[
0
];
}
};
/* ========================================================================== */
/* ======================== leaky relu forward ========================
*/
template
<
typename
T
>
class
LeakyReluGPUFunctor
:
public
BaseGPUFunctor
<
T
>
{
private:
T
zero_
;
float
alpha_
;
struct
CudaReciprocalGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
// dx = -dout * out^2
// Inputs: args[0], the input dout
// args[1], the input out
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
return
-
args
[
0
]
*
args
[
1
]
*
args
[
1
];
}
public:
LeakyReluGPUFunctor
()
{
zero_
=
static_cast
<
T
>
(
0.0
f
);
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDepOut
;
}
};
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"alpha"
,
&
alpha_
}};
}
// leakyrelu forward : out = x > 0 ? x : x * alpha
__device__
__forceinline__
typename
CudaVecType
<
T
>::
type
Compute
(
const
typename
CudaVecType
<
T
>::
type
in
)
{
return
in
>
zero_
?
in
:
static_cast
<
T
>
(
alpha_
)
*
in
;
}
__device__
__forceinline__
T
ComputeRemainder
(
const
T
in
)
{
// leakyrelu forward : out = x > 0 ? x : x * alpha
return
in
>
zero_
?
in
:
static_cast
<
T
>
(
alpha_
)
*
in
;
}
};
template
<
>
__device__
__forceinline__
CudaVecType
<
float
>::
type
LeakyReluGPUFunctor
<
float
>::
Compute
(
const
CudaVecType
<
float
>::
type
in
)
{
// leakyrelu forward : out = x > 0 ? x : x * alpha
return
make_float4
((
in
.
x
>
zero_
)
?
(
in
.
x
)
:
(
in
.
x
)
*
alpha_
,
(
in
.
y
>
zero_
)
?
(
in
.
y
)
:
(
in
.
y
)
*
alpha_
,
(
in
.
z
>
zero_
)
?
(
in
.
z
)
:
(
in
.
z
)
*
alpha_
,
(
in
.
w
>
zero_
)
?
(
in
.
w
)
:
(
in
.
w
)
*
alpha_
);
}
template
<
>
__device__
__forceinline__
CudaVecType
<
float16
>::
type
LeakyReluGPUFunctor
<
float16
>::
Compute
(
const
CudaVecType
<
float16
>::
type
in
)
{
// leakyrelu forward : out = x > 0 ? x : x * alpha
const
float2
xx
=
__half22float2
(
in
);
return
__floats2half2_rn
((
xx
.
x
>
0.0
f
)
?
xx
.
x
:
xx
.
x
*
alpha_
,
(
xx
.
y
>
0.0
f
)
?
xx
.
y
:
xx
.
y
*
alpha_
);
}
/* ========================================================================== */
template
<
typename
T
>
struct
CudaExpFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
// exp(x) = exp(x)
// Inputs: args[0], the input x
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
MPType
x
=
static_cast
<
MPType
>
(
args
[
0
]);
return
static_cast
<
T
>
(
exp
(
x
));
}
};
/* =========================== leaky relu backward =======================
*/
template
<
typename
T
>
class
LeakyReluGradGPUFunctor
:
public
BaseGPUFunctor
<
T
>
{
private:
T
zero_
;
float
alpha_
;
struct
CudaExpGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
// dx = dout * out
// Inputs: args[0], the input dout
// args[1], the input out
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
return
args
[
0
]
*
args
[
1
];
}
public:
LeakyReluGradGPUFunctor
()
{
zero_
=
static_cast
<
T
>
(
0.0
f
);
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDepOut
;
}
};
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"alpha"
,
&
alpha_
}};
template
<
typename
T
>
struct
CudaLogFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
// log(x) = log(x)
// Inputs: args[0], the input x
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
MPType
x
=
static_cast
<
MPType
>
(
args
[
0
]);
return
static_cast
<
T
>
(
log
(
x
));
}
};
// for leaky relu backward when T is double
__device__
__forceinline__
typename
CudaVecType
<
T
>::
type
Compute
(
const
typename
CudaVecType
<
T
>::
type
in
,
const
typename
CudaVecType
<
T
>::
type
dout
)
{
// leakyrelu backward : dx = x > 0 ? dout : alpha * dout
return
in
>
zero_
?
dout
:
static_cast
<
T
>
(
alpha_
)
*
dout
;
template
<
typename
T
>
struct
CudaLogGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
// dx = dout / x
// Inputs: args[0], the input dout
// args[1], the input x
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
return
args
[
0
]
/
args
[
1
];
}
// when num % vecsize != 0 this func will be used
__device__
__forceinline__
T
ComputeRemainder
(
const
T
in
,
const
T
dout
)
{
// leakyrelu backward : dx = x > 0 ? dout : alpha * dout
return
in
>
zero_
?
dout
:
static_cast
<
T
>
(
alpha_
)
*
dout
;
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDepX
;
}
};
template
<
typename
T
>
struct
CudaSquareFunctor
:
public
BaseActivationFunctor
<
T
>
{
// square(x) = x * x
// Inputs: args[0], the input x
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
return
args
[
0
]
*
args
[
0
];
}
};
template
<
typename
T
>
struct
CudaSquareGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
T
two
=
static_cast
<
T
>
(
2.0
f
);
// dx = dout * 2 * x
// Inputs: args[0], the input dout
// args[1], the input x
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
return
args
[
0
]
*
two
*
args
[
1
];
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDepX
;
}
};
template
<
>
__device__
__forceinline__
CudaVecType
<
float
>::
type
LeakyReluGradGPUFunctor
<
float
>::
Compute
(
const
CudaVecType
<
float
>::
type
in
,
const
CudaVecType
<
float
>::
type
dout
)
{
// leakyrelu backward : dx = x > 0 ? dout : alpha * dout
return
make_float4
((
in
.
x
>
zero_
)
?
(
dout
.
x
)
:
alpha_
*
(
dout
.
x
),
(
in
.
y
>
zero_
)
?
(
dout
.
y
)
:
alpha_
*
(
dout
.
y
),
(
in
.
z
>
zero_
)
?
(
dout
.
z
)
:
alpha_
*
(
dout
.
z
),
(
in
.
w
>
zero_
)
?
(
dout
.
w
)
:
alpha_
*
(
dout
.
w
));
}
template
<
>
__device__
__forceinline__
CudaVecType
<
float16
>::
type
LeakyReluGradGPUFunctor
<
float16
>::
Compute
(
const
CudaVecType
<
float16
>::
type
in
,
const
CudaVecType
<
float16
>::
type
dout
)
{
// leakyrelu backward : dx = x > 0 ? dout : alpha * dout
const
float2
xx
=
__half22float2
(
in
);
const
float2
yy
=
__half22float2
(
dout
);
return
__floats2half2_rn
((
xx
.
x
>
0.0
f
)
?
yy
.
x
:
alpha_
*
yy
.
x
,
(
xx
.
y
>
0.0
f
)
?
yy
.
y
:
alpha_
*
yy
.
y
);
}
template
<
typename
T
>
struct
CudaSqrtFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
// sqrt(x) = sqrt(x)
// Inputs: args[0], the input x
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
MPType
x
=
static_cast
<
MPType
>
(
args
[
0
]);
return
static_cast
<
T
>
(
sqrt
(
x
));
}
};
/* ========================================================================== */
template
<
typename
T
>
struct
CudaSqrtGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
T
one_half
=
static_cast
<
T
>
(
0.5
f
);
// dx = dout * 0.5 / out
// Inputs: args[0], the input dout
// args[1], the input out
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
return
one_half
*
args
[
0
]
/
args
[
1
];
}
template
<
typename
T
,
typename
Functor
>
__global__
void
ActivationGradKernelVec
(
const
T
*
forward_data
,
const
T
*
dout
,
T
*
dx
,
int
num
,
Functor
functor
)
{
using
VecType
=
typename
CudaVecType
<
T
>::
type
;
constexpr
int
vecsize
=
CudaVecType
<
T
>::
vecsize
;
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
stride
=
blockDim
.
x
*
gridDim
.
x
;
int
loop
=
num
/
vecsize
;
int
tail
=
num
%
vecsize
;
const
VecType
*
in_forward
=
reinterpret_cast
<
const
VecType
*>
(
forward_data
);
const
VecType
*
in_dout
=
reinterpret_cast
<
const
VecType
*>
(
dout
);
VecType
*
out
=
reinterpret_cast
<
VecType
*>
(
dx
);
VecType
forward_vec
,
dout_vec
;
T
in_data
,
dout_data
;
for
(
int
i
=
idx
;
i
<
loop
;
i
+=
stride
)
{
#ifdef __HIPCC__ || __CUDA_ARCH__ >= 350
forward_vec
=
__ldg
(
in_forward
+
i
);
dout_vec
=
__ldg
(
in_dout
+
i
);
#else
forward_vec
=
in_forward
[
i
];
dout_vec
=
in_dout
[
i
];
#endif
out
[
i
]
=
functor
.
Compute
(
forward_vec
,
dout_vec
);
}
while
(
idx
==
loop
&&
tail
)
{
in_data
=
forward_data
[
num
-
tail
];
dout_data
=
dout
[
num
-
tail
];
dx
[
num
-
tail
]
=
functor
.
ComputeRemainder
(
in_data
,
dout_data
);
--
tail
;
}
}
template
<
typename
T
,
typename
Functor
>
__global__
void
ActivationkernelVec
(
const
T
*
src
,
T
*
dst
,
int
num
,
Functor
functor
)
{
constexpr
int
vecsize
=
CudaVecType
<
T
>::
vecsize
;
using
VecType
=
typename
CudaVecType
<
T
>::
type
;
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
stride
=
blockDim
.
x
*
gridDim
.
x
;
int
loop
=
num
/
vecsize
;
int
tail
=
num
%
vecsize
;
const
VecType
*
in
=
reinterpret_cast
<
const
VecType
*>
(
src
);
VecType
*
out
=
reinterpret_cast
<
VecType
*>
(
dst
);
VecType
x_vec
;
for
(
int
i
=
idx
;
i
<
loop
;
i
+=
stride
)
{
#ifdef __HIPCC__ || __CUDA_ARCH__ >= 350
x_vec
=
__ldg
(
in
+
i
);
#else
x_vec
=
in
[
i
];
#endif
out
[
i
]
=
functor
.
Compute
(
x_vec
);
}
while
(
idx
==
loop
&&
tail
)
{
dst
[
num
-
tail
]
=
functor
.
ComputeRemainder
(
src
[
num
-
tail
]);
--
tail
;
}
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDepOut
;
}
};
template
<
typename
T
>
struct
CudaRsqrtFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MPType
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
// rsqrt(x) = rsqrt(x)
// Inputs: args[0], the input x
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
MPType
x
=
static_cast
<
MPType
>
(
args
[
0
]);
return
static_cast
<
T
>
(
rsqrt
(
x
));
}
};
template
<
typename
T
>
struct
CudaRsqrtGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
T
minus_one_half
=
static_cast
<
T
>
(
-
0.5
f
);
// dx = dout * -0.5 / out^3
// Inputs: args[0], the input dout
// args[1], the input out
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
T
out
=
args
[
1
];
return
minus_one_half
*
args
[
0
]
*
out
*
out
*
out
;
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDepOut
;
}
};
template
<
typename
DeviceContext
,
typename
Functor
>
class
Activation
GPU
Kernel
class
Activation
Cuda
Kernel
:
public
framework
::
OpKernel
<
typename
Functor
::
ELEMENT_TYPE
>
{
public:
using
T
=
typename
Functor
::
ELEMENT_TYPE
;
void
Compute
(
const
framework
::
ExecutionContext
&
c
ontext
)
const
override
{
const
framework
::
Tensor
*
in_
x
=
nullptr
;
void
Compute
(
const
framework
::
ExecutionContext
&
c
tx
)
const
override
{
const
framework
::
Tensor
*
x
=
nullptr
;
framework
::
Tensor
*
out
=
nullptr
;
ExtractActivationTensor
(
context
,
&
in_x
,
&
out
);
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
int
num
=
in_x
->
numel
();
const
T
*
input_data
=
in_x
->
data
<
T
>
();
T
*
output_data
=
out
->
mutable_data
<
T
>
(
dev_ctx
.
GetPlace
(),
static_cast
<
size_t
>
(
num
*
sizeof
(
T
)));
int
block
=
512
;
#ifdef __HIPCC__
block
=
256
;
#endif
Functor
functor
;
ExtractActivationTensor
(
ctx
,
&
x
,
&
out
);
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
std
::
vector
<
const
framework
::
Tensor
*>
ins
=
{
x
};
std
::
vector
<
framework
::
Tensor
*>
outs
=
{
out
};
auto
functor
=
Functor
();
auto
attrs
=
functor
.
GetAttrs
();
for
(
auto
&
attr
:
attrs
)
{
*
attr
.
second
=
c
ontext
.
Attr
<
float
>
(
attr
.
first
);
*
attr
.
second
=
c
tx
.
Attr
<
float
>
(
attr
.
first
);
}
constexpr
int
vecsize
=
CudaVecType
<
T
>::
vecsize
;
int
grid
=
max
((
num
/
vecsize
+
block
-
1
)
/
block
,
1
);
auto
stream
=
context
.
cuda_device_context
().
stream
();
ActivationkernelVec
<
T
,
Functor
><<<
grid
,
block
,
0
,
stream
>>>
(
input_data
,
output_data
,
num
,
functor
);
LaunchElementwiseCudaKernel
<
ElementwiseType
::
kUnary
,
T
>
(
dev_ctx
,
ins
,
&
outs
,
functor
);
}
};
template
<
typename
DeviceContext
,
typename
Functor
>
class
ActivationGrad
GPU
Kernel
class
ActivationGrad
Cuda
Kernel
:
public
framework
::
OpKernel
<
typename
Functor
::
ELEMENT_TYPE
>
{
public:
using
T
=
typename
Functor
::
ELEMENT_TYPE
;
void
Compute
(
const
framework
::
ExecutionContext
&
c
ontext
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
c
tx
)
const
override
{
const
framework
::
Tensor
*
x
,
*
out
,
*
d_out
;
framework
::
Tensor
*
d_x
=
nullptr
;
x
=
out
=
d_out
=
nullptr
;
ExtractActivationGradTensor
<
Functor
::
FwdDeps
()
>
(
c
ontext
,
&
x
,
&
out
,
&
d_out
,
ExtractActivationGradTensor
<
Functor
::
FwdDeps
()
>
(
c
tx
,
&
x
,
&
out
,
&
d_out
,
&
d_x
);
int
numel
=
d_out
->
numel
();
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
*
dx_data
=
d_x
->
mutable_data
<
T
>
(
dev_ctx
.
GetPlace
(),
static_cast
<
size_t
>
(
numel
*
sizeof
(
T
)));
auto
*
dout_data
=
d_out
->
data
<
T
>
();
d_x
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
auto
functor
=
Functor
();
auto
attrs
=
functor
.
GetAttrs
();
for
(
auto
&
attr
:
attrs
)
{
*
attr
.
second
=
ctx
.
Attr
<
float
>
(
attr
.
first
);
}
std
::
vector
<
const
framework
::
Tensor
*>
ins
=
{
d_out
};
std
::
vector
<
framework
::
Tensor
*>
outs
=
{
d_x
};
auto
*
forward_data
=
dout_data
;
if
(
static_cast
<
int
>
(
Functor
::
FwdDeps
())
==
static_cast
<
int
>
(
kDepOut
))
{
// Only need forward output Out
forward_data
=
out
->
data
<
T
>
();
ins
.
push_back
(
out
);
LaunchElementwiseCudaKernel
<
ElementwiseType
::
kBinary
,
T
>
(
dev_ctx
,
ins
,
&
outs
,
functor
);
}
else
if
(
static_cast
<
int
>
(
Functor
::
FwdDeps
())
==
static_cast
<
int
>
(
kDepX
))
{
// Only need forward input X
forward_data
=
x
->
data
<
T
>
();
}
int
block
=
512
;
#ifdef __HIPCC__
block
=
256
;
#endif
Functor
functor
;
auto
attrs
=
functor
.
GetAttrs
();
for
(
auto
&
attr
:
attrs
)
{
*
attr
.
second
=
context
.
Attr
<
float
>
(
attr
.
first
);
ins
.
push_back
(
x
);
LaunchElementwiseCudaKernel
<
ElementwiseType
::
kBinary
,
T
>
(
dev_ctx
,
ins
,
&
outs
,
functor
);
}
else
{
LaunchElementwiseCudaKernel
<
ElementwiseType
::
kUnary
,
T
>
(
dev_ctx
,
ins
,
&
outs
,
functor
);
}
constexpr
int
vecsize
=
CudaVecType
<
T
>::
vecsize
;
int
grid
=
max
((
numel
/
vecsize
+
block
-
1
)
/
block
,
1
);
auto
stream
=
context
.
cuda_device_context
().
stream
();
ActivationGradKernelVec
<
T
,
Functor
><<<
grid
,
block
,
0
,
stream
>>>
(
forward_data
,
dout_data
,
dx_data
,
numel
,
functor
);
}
};
...
...
@@ -395,12 +732,13 @@ class ActivationGradGPUKernel
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
#define REGISTER_ACTIVATION_
CUDA_KERNEL(act_type, op_name, functor,
\
#define REGISTER_ACTIVATION_
GPU_KERNEL(act_type, op_name, functor,
\
grad_functor) \
REGISTER_OP_CUDA_KERNEL( \
act_type, \
ops::ActivationKernel<plat::CUDADeviceContext, ops::functor<float>>, \
ops::ActivationKernel<plat::CUDADeviceContext, ops::functor<double>>, \
act_type, ops::ActivationKernel<paddle::platform::CUDADeviceContext, \
ops::functor<float>>, \
ops::ActivationKernel<paddle::platform::CUDADeviceContext, \
ops::functor<double>>, \
ops::ActivationKernel<plat::CUDADeviceContext, \
ops::functor<plat::float16>>); \
REGISTER_OP_CUDA_KERNEL( \
...
...
@@ -410,28 +748,28 @@ namespace plat = paddle::platform;
ops::grad_functor<double>>, \
ops::ActivationGradKernel<plat::CUDADeviceContext, \
ops::grad_functor<plat::float16>>);
FOR_EACH_ACTIVATION_OP
(
REGISTER_ACTIVATION_CUDA_KERNEL
);
#define REGISTER_ACTIVATION_
GPU_KERNEL(act_type, op_name, functor,
\
#define REGISTER_ACTIVATION_
CUDA_KERNEL(act_type, op_name, functor,
\
grad_functor) \
REGISTER_OP_CUDA_KERNEL( \
act_type, ops::Activation
GPUKernel<paddle::platform::CUDADeviceContext,
\
act_type, ops::Activation
CudaKernel<paddle::platform::CUDADeviceContext,
\
ops::functor<float>>, \
ops::Activation
GPUKernel<paddle::platform::CUDADeviceContext,
\
ops::Activation
CudaKernel<paddle::platform::CUDADeviceContext,
\
ops::functor<double>>, \
ops::Activation
GPUKernel<plat::CUDADeviceContext,
\
ops::Activation
CudaKernel<plat::CUDADeviceContext,
\
ops::functor<plat::float16>>); \
REGISTER_OP_CUDA_KERNEL( \
act_type##_grad, ops::ActivationGradGPUKernel<plat::CUDADeviceContext, \
act_type##_grad, \
ops::ActivationGradCudaKernel<plat::CUDADeviceContext, \
ops::grad_functor<float>>, \
ops::ActivationGrad
GPUKernel<plat::CUDADeviceContext,
\
ops::ActivationGrad
CudaKernel<plat::CUDADeviceContext,
\
ops::grad_functor<double>>, \
ops::ActivationGrad
GPUKernel<plat::CUDADeviceContext,
\
ops::ActivationGrad
CudaKernel<plat::CUDADeviceContext,
\
ops::grad_functor<plat::float16>>);
/* ======================== leaky relu register ============================ */
REGISTER_ACTIVATION_
GPU_KERNEL
(
leaky_relu
,
LeakyRelu
,
LeakyReluGPU
Functor
,
LeakyReluGradGPU
Functor
);
REGISTER_ACTIVATION_
CUDA_KERNEL
(
leaky_relu
,
LeakyRelu
,
CudaLeakyRelu
Functor
,
CudaLeakyReluGrad
Functor
);
REGISTER_OP_CUDA_KERNEL
(
leaky_relu_grad_grad
,
...
...
@@ -444,7 +782,7 @@ REGISTER_OP_CUDA_KERNEL(
/* ========================================================================== */
/* ======================== elu register ============================ */
REGISTER_ACTIVATION_
CUDA
_KERNEL
(
elu
,
ELU
,
ELUFunctor
,
ELUGradFunctor
);
REGISTER_ACTIVATION_
GPU
_KERNEL
(
elu
,
ELU
,
ELUFunctor
,
ELUGradFunctor
);
REGISTER_OP_CUDA_KERNEL
(
elu_grad_grad
,
ops
::
ELUDoubleGradKernel
<
plat
::
CUDADeviceContext
,
...
...
@@ -456,7 +794,8 @@ REGISTER_OP_CUDA_KERNEL(
/* ========================================================================== */
/* =========================== relu register ============================ */
REGISTER_ACTIVATION_GPU_KERNEL
(
relu
,
Relu
,
ReluGPUFunctor
,
ReluGradGPUFunctor
);
REGISTER_ACTIVATION_CUDA_KERNEL
(
relu
,
Relu
,
CudaReluFunctor
,
CudaReluGradFunctor
);
REGISTER_OP_CUDA_KERNEL
(
relu_grad_grad
,
...
...
@@ -469,7 +808,8 @@ REGISTER_OP_CUDA_KERNEL(
/* ========================================================================== */
/* =========================== tanh register ============================ */
REGISTER_ACTIVATION_CUDA_KERNEL
(
tanh
,
Tanh
,
TanhFunctor
,
TanhGradFunctor
);
REGISTER_ACTIVATION_CUDA_KERNEL
(
tanh
,
Tanh
,
CudaTanhFunctor
,
CudaTanhGradFunctor
);
REGISTER_OP_CUDA_KERNEL
(
tanh_grad_grad
,
...
...
@@ -482,7 +822,8 @@ REGISTER_OP_CUDA_KERNEL(
/* ========================================================================== */
/* =========================== sqrt register ============================= */
REGISTER_ACTIVATION_CUDA_KERNEL
(
sqrt
,
Sqrt
,
SqrtFunctor
,
SqrtGradFunctor
);
REGISTER_ACTIVATION_CUDA_KERNEL
(
sqrt
,
Sqrt
,
CudaSqrtFunctor
,
CudaSqrtGradFunctor
);
REGISTER_OP_CUDA_KERNEL
(
sqrt_grad_grad
,
...
...
@@ -496,7 +837,8 @@ REGISTER_OP_CUDA_KERNEL(
/* =========================== rsqrt register =============================
*/
REGISTER_ACTIVATION_CUDA_KERNEL
(
rsqrt
,
Rsqrt
,
RsqrtFunctor
,
RsqrtGradFunctor
);
REGISTER_ACTIVATION_CUDA_KERNEL
(
rsqrt
,
Rsqrt
,
CudaRsqrtFunctor
,
CudaRsqrtGradFunctor
);
REGISTER_OP_CUDA_KERNEL
(
rsqrt_grad_grad
,
...
...
@@ -510,24 +852,28 @@ REGISTER_OP_CUDA_KERNEL(
/* =========================== square register ============================ */
REGISTER_OP_CUDA_KERNEL
(
square
,
ops
::
ActivationKernel
<
plat
::
CUDADeviceContext
,
ops
::
SquareFunctor
<
float
>>
,
ops
::
ActivationKernel
<
plat
::
CUDADeviceContext
,
ops
::
SquareFunctor
<
double
>>
,
ops
::
ActivationKernel
<
plat
::
CUDADeviceContext
,
ops
::
SquareFunctor
<
int
>>
,
ops
::
ActivationKernel
<
plat
::
CUDADeviceContext
,
ops
::
SquareFunctor
<
int64_t
>>
,
ops
::
ActivationKernel
<
plat
::
CUDADeviceContext
,
ops
::
SquareFunctor
<
plat
::
float16
>>
);
square
,
ops
::
ActivationCudaKernel
<
plat
::
CUDADeviceContext
,
ops
::
CudaSquareFunctor
<
float
>>
,
ops
::
ActivationCudaKernel
<
plat
::
CUDADeviceContext
,
ops
::
CudaSquareFunctor
<
double
>>
,
ops
::
ActivationCudaKernel
<
plat
::
CUDADeviceContext
,
ops
::
CudaSquareFunctor
<
int
>>
,
ops
::
ActivationCudaKernel
<
plat
::
CUDADeviceContext
,
ops
::
CudaSquareFunctor
<
int64_t
>>
,
ops
::
ActivationCudaKernel
<
plat
::
CUDADeviceContext
,
ops
::
CudaSquareFunctor
<
plat
::
float16
>>
);
REGISTER_OP_CUDA_KERNEL
(
square_grad
,
ops
::
ActivationGradKernel
<
plat
::
CUDADeviceContext
,
ops
::
SquareGradFunctor
<
float
>>
,
ops
::
ActivationGradKernel
<
plat
::
CUDADeviceContext
,
ops
::
SquareGradFunctor
<
double
>>
,
ops
::
ActivationGradKernel
<
plat
::
CUDADeviceContext
,
ops
::
SquareGradFunctor
<
int
>>
,
ops
::
ActivationGradKernel
<
plat
::
CUDADeviceContext
,
ops
::
SquareGradFunctor
<
int64_t
>>
,
ops
::
ActivationGradKernel
<
plat
::
CUDADeviceContext
,
ops
::
SquareGradFunctor
<
plat
::
float16
>>
);
square_grad
,
ops
::
ActivationGradCudaKernel
<
plat
::
CUDADeviceContext
,
ops
::
CudaSquareGradFunctor
<
float
>>
,
ops
::
ActivationGradCudaKernel
<
plat
::
CUDADeviceContext
,
ops
::
CudaSquareGradFunctor
<
double
>>
,
ops
::
ActivationGradCudaKernel
<
plat
::
CUDADeviceContext
,
ops
::
CudaSquareGradFunctor
<
int
>>
,
ops
::
ActivationGradCudaKernel
<
plat
::
CUDADeviceContext
,
ops
::
CudaSquareGradFunctor
<
int64_t
>>
,
ops
::
ActivationGradCudaKernel
<
plat
::
CUDADeviceContext
,
ops
::
CudaSquareGradFunctor
<
plat
::
float16
>>
);
REGISTER_OP_CUDA_KERNEL
(
square_grad_grad
,
...
...
@@ -564,27 +910,29 @@ REGISTER_OP_CUDA_KERNEL(
/* ========================== exp register ============================ */
REGISTER_OP_CUDA_KERNEL
(
exp
,
ops
::
ActivationKernel
<
plat
::
CUDADeviceContext
,
ops
::
ExpFunctor
<
float
>>
,
ops
::
ActivationKernel
<
plat
::
CUDADeviceContext
,
ops
::
ExpFunctor
<
double
>>
,
exp
,
ops
::
ActivationCudaKernel
<
plat
::
CUDADeviceContext
,
ops
::
CudaExpFunctor
<
float
>>
,
ops
::
ActivationCudaKernel
<
plat
::
CUDADeviceContext
,
ops
::
CudaExpFunctor
<
double
>>
,
ops
::
ActivationKernel
<
plat
::
CUDADeviceContext
,
ops
::
ExpFunctor
<
int
>>
,
ops
::
ActivationKernel
<
plat
::
CUDADeviceContext
,
ops
::
ExpFunctor
<
int64_t
>>
,
ops
::
ActivationKernel
<
plat
::
CUDADeviceContext
,
ops
::
ExpFunctor
<
plat
::
float16
>>
);
ops
::
Activation
Cuda
Kernel
<
plat
::
CUDADeviceContext
,
ops
::
Cuda
ExpFunctor
<
plat
::
float16
>>
);
REGISTER_OP_CUDA_KERNEL
(
exp_grad
,
ops
::
ActivationGradKernel
<
plat
::
CUDADeviceContext
,
ops
::
ExpGradFunctor
<
float
>>
,
ops
::
ActivationGradKernel
<
plat
::
CUDADeviceContext
,
ops
::
ExpGradFunctor
<
double
>>
,
ops
::
ActivationGradKernel
<
plat
::
CUDADeviceContext
,
ops
::
ExpGradFunctor
<
int
>>
,
ops
::
ActivationGradKernel
<
plat
::
CUDADeviceContext
,
ops
::
ExpGradFunctor
<
int64_t
>>
,
ops
::
ActivationGradKernel
<
plat
::
CUDADeviceContext
,
ops
::
ExpGradFunctor
<
plat
::
float16
>>
);
exp_grad
,
ops
::
ActivationGrad
Cuda
Kernel
<
plat
::
CUDADeviceContext
,
ops
::
Cuda
ExpGradFunctor
<
float
>>
,
ops
::
ActivationGrad
Cuda
Kernel
<
plat
::
CUDADeviceContext
,
ops
::
Cuda
ExpGradFunctor
<
double
>>
,
ops
::
ActivationGrad
Cuda
Kernel
<
plat
::
CUDADeviceContext
,
ops
::
Cuda
ExpGradFunctor
<
int
>>
,
ops
::
ActivationGrad
Cuda
Kernel
<
plat
::
CUDADeviceContext
,
ops
::
Cuda
ExpGradFunctor
<
int64_t
>>
,
ops
::
ActivationGrad
Cuda
Kernel
<
plat
::
CUDADeviceContext
,
ops
::
Cuda
ExpGradFunctor
<
plat
::
float16
>>
);
/* ========================================================================== */
/* ========================== Log register ==================================*/
REGISTER_ACTIVATION_CUDA_KERNEL
(
log
,
Log
,
LogFunctor
,
LogGradFunctor
);
REGISTER_ACTIVATION_CUDA_KERNEL
(
log
,
Log
,
CudaLogFunctor
,
Cuda
LogGradFunctor
);
REGISTER_OP_CUDA_KERNEL
(
log_grad_grad
,
ops
::
LogDoubleGradKernel
<
plat
::
CUDADeviceContext
,
...
...
@@ -594,3 +942,57 @@ REGISTER_OP_CUDA_KERNEL(
ops
::
LogDoubleGradKernel
<
plat
::
CUDADeviceContext
,
ops
::
LogGradGradFunctor
<
plat
::
float16
>>
);
/* ========================================================================== */
REGISTER_ACTIVATION_CUDA_KERNEL
(
sigmoid
,
Sigmoid
,
CudaSigmoidFunctor
,
CudaSigmoidGradFunctor
);
REGISTER_ACTIVATION_CUDA_KERNEL
(
silu
,
Silu
,
CudaSiluFunctor
,
CudaSiluGradFunctor
);
REGISTER_ACTIVATION_CUDA_KERNEL
(
logsigmoid
,
LogSigmoid
,
CudaLogSigmoidFunctor
,
CudaLogSigmoidGradFunctor
);
REGISTER_ACTIVATION_CUDA_KERNEL
(
atan
,
Atan
,
CudaAtanFunctor
,
CudaAtanGradFunctor
);
REGISTER_ACTIVATION_CUDA_KERNEL
(
softshrink
,
SoftShrink
,
CudaSoftShrinkFunctor
,
CudaSoftShrinkGradFunctor
);
REGISTER_ACTIVATION_CUDA_KERNEL
(
ceil
,
Ceil
,
CudaCeilFunctor
,
CudaZeroGradFunctor
);
REGISTER_ACTIVATION_CUDA_KERNEL
(
floor
,
Floor
,
CudaFloorFunctor
,
CudaZeroGradFunctor
);
REGISTER_ACTIVATION_CUDA_KERNEL
(
cos
,
Cos
,
CudaCosFunctor
,
CudaCosGradFunctor
);
REGISTER_ACTIVATION_CUDA_KERNEL
(
tan
,
Tan
,
CudaTanFunctor
,
CudaTanGradFunctor
);
REGISTER_ACTIVATION_CUDA_KERNEL
(
acos
,
Acos
,
CudaAcosFunctor
,
CudaAcosGradFunctor
);
REGISTER_ACTIVATION_CUDA_KERNEL
(
sin
,
Sin
,
CudaSinFunctor
,
CudaSinGradFunctor
);
REGISTER_ACTIVATION_CUDA_KERNEL
(
asin
,
Asin
,
CudaAsinFunctor
,
CudaAsinGradFunctor
);
REGISTER_ACTIVATION_CUDA_KERNEL
(
sinh
,
Sinh
,
CudaSinhFunctor
,
CudaSinhGradFunctor
);
REGISTER_ACTIVATION_CUDA_KERNEL
(
cosh
,
Cosh
,
CudaCoshFunctor
,
CudaCoshGradFunctor
);
REGISTER_ACTIVATION_CUDA_KERNEL
(
round
,
Round
,
CudaRoundFunctor
,
CudaZeroGradFunctor
);
REGISTER_ACTIVATION_CUDA_KERNEL
(
reciprocal
,
Reciprocal
,
CudaReciprocalFunctor
,
CudaReciprocalGradFunctor
);
REGISTER_ACTIVATION_GPU_KERNEL
(
log1p
,
Log1p
,
Log1pFunctor
,
Log1pGradFunctor
);
REGISTER_ACTIVATION_GPU_KERNEL
(
log2
,
Log2
,
Log2Functor
,
Log2GradFunctor
);
REGISTER_ACTIVATION_GPU_KERNEL
(
log10
,
Log10
,
Log10Functor
,
Log10GradFunctor
);
REGISTER_ACTIVATION_GPU_KERNEL
(
brelu
,
BRelu
,
BReluFunctor
,
BReluGradFunctor
);
REGISTER_ACTIVATION_GPU_KERNEL
(
soft_relu
,
SoftRelu
,
SoftReluFunctor
,
SoftReluGradFunctor
);
REGISTER_ACTIVATION_GPU_KERNEL
(
stanh
,
STanh
,
STanhFunctor
,
STanhGradFunctor
);
REGISTER_ACTIVATION_GPU_KERNEL
(
softplus
,
Softplus
,
SoftplusFunctor
,
SoftplusGradFunctor
);
REGISTER_ACTIVATION_GPU_KERNEL
(
softsign
,
Softsign
,
SoftsignFunctor
,
SoftsignGradFunctor
);
REGISTER_ACTIVATION_GPU_KERNEL
(
relu6
,
Relu6
,
Relu6Functor
,
Relu6GradFunctor
);
REGISTER_ACTIVATION_GPU_KERNEL
(
tanh_shrink
,
TanhShrink
,
TanhShrinkFunctor
,
TanhShrinkGradFunctor
);
REGISTER_ACTIVATION_GPU_KERNEL
(
hard_shrink
,
HardShrink
,
HardShrinkFunctor
,
HardShrinkGradFunctor
);
REGISTER_ACTIVATION_GPU_KERNEL
(
hard_sigmoid
,
HardSigmoid
,
HardSigmoidFunctor
,
HardSigmoidGradFunctor
);
REGISTER_ACTIVATION_GPU_KERNEL
(
swish
,
Swish
,
SwishFunctor
,
SwishGradFunctor
);
REGISTER_ACTIVATION_GPU_KERNEL
(
thresholded_relu
,
ThresholdedRelu
,
ThresholdedReluFunctor
,
ThresholdedReluGradFunctor
);
REGISTER_ACTIVATION_GPU_KERNEL
(
hard_swish
,
HardSwish
,
HardSwishFunctor
,
HardSwishGradFunctor
);
paddle/fluid/operators/activation_op.h
浏览文件 @
eca8dcc7
...
...
@@ -455,7 +455,7 @@ struct HardShrinkFunctor : public BaseActivationFunctor<T> {
void
operator
()(
Device
d
,
X
x
,
Out
out
)
const
{
auto
temp1
=
x
<
static_cast
<
T
>
(
threshold
*
-
1.
f
);
auto
temp2
=
x
>
static_cast
<
T
>
(
threshold
);
out
.
device
(
d
)
=
x
*
(
temp1
+
temp2
).
template
cast
<
T
>();
out
.
device
(
d
)
=
x
*
(
temp1
||
temp2
).
template
cast
<
T
>();
}
};
...
...
@@ -472,7 +472,7 @@ struct HardShrinkGradFunctor : public BaseActivationFunctor<T> {
void
operator
()(
Device
d
,
X
x
,
Out
out
,
dOut
dout
,
dX
dx
)
const
{
auto
temp1
=
x
<
static_cast
<
T
>
(
threshold
*
-
1.
f
);
auto
temp2
=
x
>
static_cast
<
T
>
(
threshold
);
dx
.
device
(
d
)
=
dout
*
(
temp1
+
temp2
).
template
cast
<
T
>();
dx
.
device
(
d
)
=
dout
*
(
temp1
||
temp2
).
template
cast
<
T
>();
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDepX
;
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录