Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
eae31856
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
eae31856
编写于
7月 05, 2021
作者:
W
WangXi
提交者:
GitHub
7月 05, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add fused elemwise gelu and optimize performance (#33480)
上级
fa5ddfd9
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
202 addition
and
77 deletion
+202
-77
paddle/fluid/operators/elementwise/elementwise_op_function.h
paddle/fluid/operators/elementwise/elementwise_op_function.h
+103
-76
paddle/fluid/operators/fused/fused_elemwise_activation_op.cc
paddle/fluid/operators/fused/fused_elemwise_activation_op.cc
+1
-1
paddle/fluid/operators/fused/fused_elemwise_activation_op.h
paddle/fluid/operators/fused/fused_elemwise_activation_op.h
+17
-0
paddle/fluid/operators/math/functors.h
paddle/fluid/operators/math/functors.h
+58
-0
python/paddle/fluid/tests/unittests/test_fused_elemwise_activation_op.py
...luid/tests/unittests/test_fused_elemwise_activation_op.py
+23
-0
未找到文件。
paddle/fluid/operators/elementwise/elementwise_op_function.h
浏览文件 @
eae31856
...
@@ -57,6 +57,10 @@ constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
...
@@ -57,6 +57,10 @@ constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
*mod = dividend_copy % divisor; \
*mod = dividend_copy % divisor; \
} while (0)
} while (0)
#define DIVUP(x, y) (((x) + (y)-1) / (y))
#define ROUNDUP(x, y) (DIVUP((x), (y)) * (y))
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -2156,10 +2160,10 @@ template <typename T, typename CompoundFunctor, bool BcastY,
...
@@ -2156,10 +2160,10 @@ template <typename T, typename CompoundFunctor, bool BcastY,
static
__global__
void
FusedElemwiseAndActBroadcast1CUDAKernel
(
static
__global__
void
FusedElemwiseAndActBroadcast1CUDAKernel
(
const
T
*
x
,
const
T
*
y
,
int
h
,
int
w
,
CompoundFunctor
compound_functor
,
const
T
*
x
,
const
T
*
y
,
int
h
,
int
w
,
CompoundFunctor
compound_functor
,
T
*
out
,
T
*
intermediate_out
)
{
T
*
out
,
T
*
intermediate_out
)
{
int
j
=
blockIdx
.
x
;
int
i
=
blockIdx
.
x
;
int
i
=
threadIdx
.
x
;
int
j
=
threadIdx
.
x
;
while
(
i
<
h
)
{
while
(
j
<
w
)
{
int
offset
=
i
*
w
+
j
;
int
offset
=
i
*
w
+
j
;
T
y_val
=
BcastY
?
y
[
j
]
:
y
[
offset
];
T
y_val
=
BcastY
?
y
[
j
]
:
y
[
offset
];
...
@@ -2185,7 +2189,7 @@ static __global__ void FusedElemwiseAndActBroadcast1CUDAKernel(
...
@@ -2185,7 +2189,7 @@ static __global__ void FusedElemwiseAndActBroadcast1CUDAKernel(
out
[
offset
]
=
compound_functor
.
GetOut
(
x_val
,
y_val
);
out
[
offset
]
=
compound_functor
.
GetOut
(
x_val
,
y_val
);
}
}
i
+=
ELEMWISE_MAX_BLOCK_DIM
;
j
+=
ELEMWISE_MAX_BLOCK_DIM
;
}
}
}
}
...
@@ -2196,8 +2200,8 @@ static void FusedElemwiseAndActBroadcast1CUDA(gpuStream_t stream, const T *x,
...
@@ -2196,8 +2200,8 @@ static void FusedElemwiseAndActBroadcast1CUDA(gpuStream_t stream, const T *x,
CompoundFunctor
compound_functor
,
CompoundFunctor
compound_functor
,
int
h
,
int
w
,
T
*
out
,
int
h
,
int
w
,
T
*
out
,
T
*
intermediate_out
)
{
T
*
intermediate_out
)
{
int
block_size
=
std
::
min
(
ELEMWISE_MAX_BLOCK_DIM
,
h
);
int
block_size
=
std
::
min
(
ELEMWISE_MAX_BLOCK_DIM
,
w
);
int
gird_size
=
w
;
int
gird_size
=
h
;
FusedElemwiseAndActBroadcast1CUDAKernel
<
FusedElemwiseAndActBroadcast1CUDAKernel
<
T
,
CompoundFunctor
,
BcastY
,
KeepIntermediateOut
,
T
,
CompoundFunctor
,
BcastY
,
KeepIntermediateOut
,
SameShapeOfIntermediateOutAndOut
><<<
gird_size
,
block_size
,
0
,
stream
>>>
(
SameShapeOfIntermediateOutAndOut
><<<
gird_size
,
block_size
,
0
,
stream
>>>
(
...
@@ -2585,106 +2589,129 @@ static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel(
...
@@ -2585,106 +2589,129 @@ static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel(
const
T
*
x
,
const
T
*
y
,
const
T
*
intermediate_out
,
const
T
*
out
,
const
T
*
x
,
const
T
*
y
,
const
T
*
intermediate_out
,
const
T
*
out
,
const
T
*
dout
,
int
h
,
int
w
,
DX_OP
dx_op
,
DY_OP
dy_op
,
const
T
*
dout
,
int
h
,
int
w
,
DX_OP
dx_op
,
DY_OP
dy_op
,
DIntermediate_OP
dintermediate_op
,
T
*
dx
,
T
*
dy
,
T
*
d_intermediate
)
{
DIntermediate_OP
dintermediate_op
,
T
*
dx
,
T
*
dy
,
T
*
d_intermediate
)
{
int
j
=
blockIdx
.
x
;
__shared__
T
sdata
[
BLOCK_Y
][
BLOCK_X
];
int
i
=
threadIdx
.
x
;
size_t
idx
=
threadIdx
.
x
+
BLOCK_X
*
blockIdx
.
x
;
int
tid
=
threadIdx
.
x
;
size_t
width_stride
=
gridDim
.
x
*
BLOCK_X
;
T
val
(
0
),
inter_val
(
0
);
int64_t
tmp_out_idx
,
x_idx
,
y_idx
;
size_t
full_w
=
ROUNDUP
(
w
,
BLOCK_X
);
T
zero
=
static_cast
<
T
>
(
0
);
T
zero
=
static_cast
<
T
>
(
0
);
do
{
for
(
size_t
j
=
idx
;
j
<
full_w
;
j
+=
width_stride
)
{
int
offset
=
i
*
w
+
j
;
T
val
(
0
),
inter_val
(
0
);
if
(
j
<
w
)
{
for
(
size_t
i
=
threadIdx
.
y
;
i
<
h
;
i
+=
BLOCK_Y
)
{
size_t
offset
=
i
*
w
+
j
;
tmp_out_idx
=
BcastY
?
j
:
offset
;
size_t
tmp_out_idx
=
BcastY
?
j
:
offset
;
y_idx
=
BcastY
?
j
:
offset
;
size_t
y_idx
=
BcastY
?
j
:
offset
;
x_idx
=
BcastY
?
offset
:
j
;
size_t
x_idx
=
BcastY
?
offset
:
j
;
T
x_val
=
(
x
==
nullptr
)
?
zero
:
x
[
x_idx
];
T
x_val
=
(
x
==
nullptr
)
?
zero
:
x
[
x_idx
];
T
y_val
=
(
y
==
nullptr
)
?
zero
:
y
[
y_idx
];
T
y_val
=
(
y
==
nullptr
)
?
zero
:
y
[
y_idx
];
if
(
SameShapeOfIntermediateOutAndOut
)
{
if
(
SameShapeOfIntermediateOutAndOut
)
{
tmp_out_idx
=
offset
;
tmp_out_idx
=
offset
;
}
}
if
(
dx
!=
nullptr
)
{
if
(
dx
!=
nullptr
)
{
T
tmp
=
UseIntermediateOut
T
tmp
=
UseIntermediateOut
?
dx_op
.
UseIntermediateOut
(
x_val
,
y_val
,
?
dx_op
.
UseIntermediateOut
(
x_val
,
y_val
,
intermediate_out
[
tmp_out_idx
],
intermediate_out
[
tmp_out_idx
],
out
[
offset
],
dout
[
offset
])
out
[
offset
],
dout
[
offset
])
:
dx_op
.
Recompute
(
x_val
,
y_val
,
out
[
offset
],
dout
[
offset
]);
:
dx_op
.
Recompute
(
x_val
,
y_val
,
out
[
offset
],
dout
[
offset
]);
if
(
BcastY
)
{
if
(
BcastY
)
{
dx
[
x_idx
]
=
tmp
;
dx
[
x_idx
]
=
tmp
;
}
else
{
}
else
{
val
+=
tmp
;
val
+=
tmp
;
}
}
}
}
if
(
dy
!=
nullptr
)
{
if
(
dy
!=
nullptr
)
{
T
tmp
=
UseIntermediateOut
T
tmp
=
UseIntermediateOut
?
dy_op
.
UseIntermediateOut
(
x_val
,
y_val
,
?
dy_op
.
UseIntermediateOut
(
x_val
,
y_val
,
intermediate_out
[
tmp_out_idx
],
intermediate_out
[
tmp_out_idx
],
out
[
offset
],
dout
[
offset
])
out
[
offset
],
dout
[
offset
])
:
dy_op
.
Recompute
(
x_val
,
y_val
,
out
[
offset
],
dout
[
offset
]);
:
dy_op
.
Recompute
(
x_val
,
y_val
,
out
[
offset
],
dout
[
offset
]);
if
(
BcastY
)
{
if
(
BcastY
)
{
val
+=
tmp
;
val
+=
tmp
;
}
else
{
}
else
{
dy
[
y_idx
]
=
tmp
;
dy
[
y_idx
]
=
tmp
;
}
}
}
}
if
(
d_intermediate
!=
nullptr
)
{
if
(
d_intermediate
!=
nullptr
)
{
T
tmp
=
UseIntermediateOut
T
tmp
=
UseIntermediateOut
?
dintermediate_op
.
UseIntermediateOut
(
?
dintermediate_op
.
UseIntermediateOut
(
y
[
y_idx
],
intermediate_out
[
tmp_out_idx
],
out
[
offset
],
y
[
y_idx
],
intermediate_out
[
tmp_out_idx
],
dout
[
offset
])
out
[
offset
],
dout
[
offset
])
:
dintermediate_op
.
Recompute
(
x_val
,
y_val
,
out
[
offset
],
:
dintermediate_op
.
Recompute
(
x_val
,
y_val
,
out
[
offset
],
dout
[
offset
]);
dout
[
offset
]);
if
(
SameShapeOfIntermediateOutAndOut
)
{
if
(
SameShapeOfIntermediateOutAndOut
)
{
d_intermediate
[
tmp_out_idx
]
=
tmp
;
d_intermediate
[
tmp_out_idx
]
=
tmp
;
}
else
{
}
else
{
inter_val
+=
tmp
;
inter_val
+=
tmp
;
}
}
}
}
}
}
i
+=
ELEMWISE_MAX_BLOCK_DIM
;
// transpose, for ReduceSum with wrap
}
while
(
i
<
h
);
sdata
[
threadIdx
.
y
][
threadIdx
.
x
]
=
val
;
__syncthreads
();
val
=
sdata
[
threadIdx
.
x
][
threadIdx
.
y
];
#pragma unroll
for
(
int
i
=
BLOCK_X
>>
1
;
i
>
0
;
i
>>=
1
)
{
// reduce sum with wrap
val
+=
platform
::
CudaShuffleXorSync
(
0xFFFFFFFF
,
val
,
i
);
}
h
=
h
>
ELEMWISE_MAX_BLOCK_DIM
?
ELEMWISE_MAX_BLOCK_DIM
:
h
;
size_t
idx_j
=
j
+
threadIdx
.
y
;
if
(
BcastY
)
{
if
(
BcastY
)
{
if
(
dy
)
{
if
(
dy
)
{
val
=
paddle
::
platform
::
reduceSum
(
val
,
tid
,
h
);
if
(
threadIdx
.
x
==
0
&&
(
idx_j
<
w
))
dy
[
idx_j
]
=
val
;
if
(
threadIdx
.
x
==
0
)
{
dy
[
j
]
=
val
;
}
}
}
}
else
{
}
else
{
if
(
dx
)
{
if
(
dx
)
{
if
(
threadIdx
.
x
==
0
&&
(
idx_j
<
w
))
dx
[
idx_j
]
=
val
;
val
=
paddle
::
platform
::
reduceSum
(
val
,
tid
,
h
);
if
(
threadIdx
.
x
==
0
)
{
dx
[
j
]
=
val
;
}
}
}
}
}
if
(
!
SameShapeOfIntermediateOutAndOut
)
{
if
(
!
SameShapeOfIntermediateOutAndOut
)
{
if
(
d_intermediate
)
{
if
(
d_intermediate
)
{
inter_val
=
paddle
::
platform
::
reduceSum
(
inter_val
,
tid
,
h
);
sdata
[
threadIdx
.
y
][
threadIdx
.
x
]
=
inter_val
;
if
(
threadIdx
.
x
==
0
)
{
__syncthreads
();
d_intermediate
[
j
]
=
inter_val
;
inter_val
=
sdata
[
threadIdx
.
x
][
threadIdx
.
y
];
#pragma unroll
for
(
int
i
=
BLOCK_X
>>
1
;
i
>
0
;
i
>>=
1
)
{
// reduce sum with wrap
inter_val
+=
platform
::
CudaShuffleXorSync
(
0xFFFFFFFF
,
inter_val
,
i
);
}
if
(
threadIdx
.
x
==
0
&&
(
idx_j
<
w
))
d_intermediate
[
idx_j
]
=
inter_val
;
}
}
}
}
}
}
// end for
}
}
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
,
typename
DIntermediate_OP
,
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
,
typename
DIntermediate_OP
,
bool
UseIntermediateOut
,
bool
BcastY
,
bool
UseIntermediateOut
,
bool
BcastY
,
bool
SameShapeOfIntermediateOutAndOut
>
bool
SameShapeOfIntermediateOutAndOut
>
static
void
FusedElemwiseAndActGradBroadcast1CUDA
(
static
void
FusedElemwiseAndActGradBroadcast1CUDA
(
gpuStream_t
stream
,
const
T
*
x
,
const
T
*
y
,
const
T
*
intermediate_out
,
const
framework
::
ExecutionContext
&
ctx
,
const
T
*
x
,
const
T
*
y
,
const
T
*
out
,
const
T
*
dout
,
int
h
,
int
w
,
DX_OP
dx_op
,
DY_OP
dy_op
,
const
T
*
intermediate_out
,
const
T
*
out
,
const
T
*
dout
,
int
h
,
int
w
,
DIntermediate_OP
dintermediate_op
,
T
*
dx
,
T
*
dy
,
T
*
d_intermediate
)
{
DX_OP
dx_op
,
DY_OP
dy_op
,
DIntermediate_OP
dintermediate_op
,
T
*
dx
,
T
*
dy
,
int
block_size
=
std
::
min
(
ELEMWISE_MAX_BLOCK_DIM
,
h
);
T
*
d_intermediate
)
{
int
gird_size
=
w
;
gpuStream_t
stream
=
ctx
.
cuda_device_context
().
stream
();
dim3
blocks
(
BLOCK_X
,
BLOCK_Y
);
int
max_gpu_threads
=
ctx
.
cuda_device_context
().
GetMaxPhysicalThreadCount
();
int
max_blocks
=
std
::
max
(
max_gpu_threads
/
(
BLOCK_X
*
BLOCK_Y
),
1
);
int
theory_block
=
(
w
+
BLOCK_X
-
1
)
/
BLOCK_X
;
dim3
grids
(
std
::
min
(
theory_block
,
max_blocks
));
FusedElemwiseAndActGradBroadcast1CUDAKernel
<
FusedElemwiseAndActGradBroadcast1CUDAKernel
<
T
,
DX_OP
,
DY_OP
,
DIntermediate_OP
,
UseIntermediateOut
,
BcastY
,
T
,
DX_OP
,
DY_OP
,
DIntermediate_OP
,
UseIntermediateOut
,
BcastY
,
SameShapeOfIntermediateOutAndOut
><<<
g
ird_size
,
block_size
,
0
,
stream
>>>
(
SameShapeOfIntermediateOutAndOut
><<<
g
rids
,
blocks
,
0
,
stream
>>>
(
x
,
y
,
intermediate_out
,
out
,
dout
,
h
,
w
,
dx_op
,
dy_op
,
dintermediate_op
,
x
,
y
,
intermediate_out
,
out
,
dout
,
h
,
w
,
dx_op
,
dy_op
,
dintermediate_op
,
dx
,
dy
,
d_intermediate
);
dx
,
dy
,
d_intermediate
);
}
}
...
@@ -2836,7 +2863,7 @@ void FusedElemwiseAndActGradComputeWithBroadcast(
...
@@ -2836,7 +2863,7 @@ void FusedElemwiseAndActGradComputeWithBroadcast(
FusedElemwiseAndActGradBroadcast1CUDA
<
T
,
DX_OP
,
DY_OP
,
DIntermediate_OP
,
FusedElemwiseAndActGradBroadcast1CUDA
<
T
,
DX_OP
,
DY_OP
,
DIntermediate_OP
,
UseIntermediateOut
,
BcastY
,
UseIntermediateOut
,
BcastY
,
SameShapeOfIntermediateOutAndOut
>
(
SameShapeOfIntermediateOutAndOut
>
(
ctx
.
template
device_context
<
DeviceContext
>().
stream
()
,
x_data
,
y_data
,
ctx
,
x_data
,
y_data
,
intermediate_out
==
nullptr
?
nullptr
:
intermediate_out
->
data
<
T
>
(),
intermediate_out
==
nullptr
?
nullptr
:
intermediate_out
->
data
<
T
>
(),
out
->
data
<
T
>
(),
dout
->
data
<
T
>
(),
h
,
w
,
dx_op
,
dy_op
,
dintermediate_op
,
out
->
data
<
T
>
(),
dout
->
data
<
T
>
(),
h
,
w
,
dx_op
,
dy_op
,
dintermediate_op
,
dx
==
nullptr
?
nullptr
:
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
dx
==
nullptr
?
nullptr
:
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
...
...
paddle/fluid/operators/fused/fused_elemwise_activation_op.cc
浏览文件 @
eae31856
...
@@ -69,7 +69,7 @@ static bool IsSupportedCompound(const std::vector<std::string> &functors) {
...
@@ -69,7 +69,7 @@ static bool IsSupportedCompound(const std::vector<std::string> &functors) {
functors
.
size
(),
2
));
functors
.
size
(),
2
));
static
std
::
unordered_set
<
std
::
string
>
unary_fun
=
{
"scale"
,
"relu"
,
"tanh"
,
static
std
::
unordered_set
<
std
::
string
>
unary_fun
=
{
"scale"
,
"relu"
,
"tanh"
,
"sigmoid"
};
"sigmoid"
,
"gelu"
};
static
std
::
unordered_set
<
std
::
string
>
binary_fun
=
{
"elementwise_add"
,
static
std
::
unordered_set
<
std
::
string
>
binary_fun
=
{
"elementwise_add"
,
"elementwise_mul"
};
"elementwise_mul"
};
...
...
paddle/fluid/operators/fused/fused_elemwise_activation_op.h
浏览文件 @
eae31856
...
@@ -275,6 +275,13 @@ static void RunFunctors(const framework::ExecutionContext &ctx,
...
@@ -275,6 +275,13 @@ static void RunFunctors(const framework::ExecutionContext &ctx,
paddle
::
operators
::
math
::
SigmoidFunctor
<
T
>>
(
paddle
::
operators
::
math
::
SigmoidFunctor
<
T
>>
(
ctx
,
paddle
::
operators
::
math
::
MulFunctor
<
T
>
(),
ctx
,
paddle
::
operators
::
math
::
MulFunctor
<
T
>
(),
paddle
::
operators
::
math
::
SigmoidFunctor
<
T
>
(),
in_x
,
in_y
,
outputs
);
paddle
::
operators
::
math
::
SigmoidFunctor
<
T
>
(),
in_x
,
in_y
,
outputs
);
}
else
if
(
funcs_str
==
"gelu,elementwise_add"
)
{
// Z = Unary(Binary(X, Y))
RunUnaryCompoundFunctors
<
DeviceContext
,
T
,
paddle
::
operators
::
math
::
GeluFunctor
<
T
>
,
paddle
::
operators
::
math
::
AddFunctor
<
T
>>
(
ctx
,
paddle
::
operators
::
math
::
GeluFunctor
<
T
>
(),
paddle
::
operators
::
math
::
AddFunctor
<
T
>
(),
in_x
,
in_y
,
outputs
);
}
else
{
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s has not been implemented."
,
funcs_str
));
"%s has not been implemented."
,
funcs_str
));
...
@@ -374,6 +381,16 @@ static void RunGradFunctors(
...
@@ -374,6 +381,16 @@ static void RunGradFunctors(
paddle
::
operators
::
math
::
SigmoidFunctor
<
T
>
(),
paddle
::
operators
::
math
::
SigmoidFunctor
<
T
>
(),
paddle
::
operators
::
math
::
SigmoidGradFunctor
<
T
>
(),
in_x
,
in_y
,
in_out
,
paddle
::
operators
::
math
::
SigmoidGradFunctor
<
T
>
(),
in_x
,
in_y
,
in_out
,
in_intermediate_out
,
in_out_grad
,
x_grad
,
y_grad
,
d_intermediate_out
);
in_intermediate_out
,
in_out_grad
,
x_grad
,
y_grad
,
d_intermediate_out
);
}
else
if
(
funcs_str
==
"gelu_grad,elementwise_add_grad"
)
{
// The backward of Z = Unary(Binary(X, Y))
RunUnaryCompoundGradFunctors
<
DeviceContext
,
T
,
paddle
::
operators
::
math
::
GeluGradFunctor
<
T
>
,
paddle
::
operators
::
math
::
AddFunctor
<
T
>
,
paddle
::
operators
::
math
::
AddGradFunctor
<
T
>
,
InPlace
>
(
ctx
,
paddle
::
operators
::
math
::
GeluGradFunctor
<
T
>
(),
paddle
::
operators
::
math
::
AddFunctor
<
T
>
(),
paddle
::
operators
::
math
::
AddGradFunctor
<
T
>
(),
in_x
,
in_y
,
in_out
,
in_intermediate_out
,
in_out_grad
,
x_grad
,
y_grad
,
d_intermediate_out
);
}
else
{
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s has not been implemented."
,
funcs_str
));
"%s has not been implemented."
,
funcs_str
));
...
...
paddle/fluid/operators/math/functors.h
浏览文件 @
eae31856
...
@@ -14,6 +14,7 @@ limitations under the License. */
...
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#pragma once
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/operators/math.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -130,6 +131,63 @@ struct SigmoidGradFunctor {
...
@@ -130,6 +131,63 @@ struct SigmoidGradFunctor {
}
}
};
};
template
<
typename
T
>
struct
GeluFunctor
{
using
MT
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
inline
HOSTDEVICE
T
operator
()(
T
x
)
{
// this function is tanh approximation of gelu
// actual gelu is:
// x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
MT
mx
=
static_cast
<
MT
>
(
x
);
MT
out
=
mx
*
static_cast
<
MT
>
(
0.5
)
*
(
static_cast
<
MT
>
(
1.0
)
+
tanh
(
static_cast
<
MT
>
(
0.79788456
)
*
mx
*
(
static_cast
<
MT
>
(
1
)
+
static_cast
<
MT
>
(
0.044715
)
*
mx
*
mx
)));
return
static_cast
<
T
>
(
out
);
}
};
template
<
typename
T
>
struct
GeluGradFunctor
{
using
MT
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
inline
HOSTDEVICE
T
UseX
(
T
x
)
{
MT
mx
=
static_cast
<
MT
>
(
x
);
MT
tanh_out
=
tanh
(
static_cast
<
MT
>
(
0.79788456
)
*
mx
*
(
static_cast
<
MT
>
(
1
)
+
static_cast
<
MT
>
(
0.044715
)
*
mx
*
mx
));
MT
ans
=
static_cast
<
MT
>
(
0.5
)
*
mx
*
((
static_cast
<
MT
>
(
1
)
-
tanh_out
*
tanh_out
)
*
(
static_cast
<
MT
>
(
0.79788456
)
+
static_cast
<
MT
>
(
0.1070322243
)
*
mx
*
mx
))
+
static_cast
<
MT
>
(
0.5
)
*
(
static_cast
<
MT
>
(
1
)
+
tanh_out
);
return
static_cast
<
T
>
(
ans
);
}
inline
HOSTDEVICE
T
UseOut
(
T
x
)
{
MT
mx
=
static_cast
<
MT
>
(
x
);
MT
tanh_out
=
tanh
(
static_cast
<
MT
>
(
0.79788456
)
*
mx
*
(
static_cast
<
MT
>
(
1
)
+
static_cast
<
MT
>
(
0.044715
)
*
mx
*
mx
));
MT
ans
=
static_cast
<
MT
>
(
0.5
)
*
mx
*
((
static_cast
<
MT
>
(
1
)
-
tanh_out
*
tanh_out
)
*
(
static_cast
<
MT
>
(
0.79788456
)
+
static_cast
<
MT
>
(
0.1070322243
)
*
mx
*
mx
))
+
static_cast
<
MT
>
(
0.5
)
*
(
static_cast
<
MT
>
(
1
)
+
tanh_out
);
return
static_cast
<
T
>
(
ans
);
}
inline
HOSTDEVICE
T
UseXAndOut
(
T
x
,
T
out
)
{
MT
mx
=
static_cast
<
MT
>
(
x
);
MT
tanh_out
=
tanh
(
static_cast
<
MT
>
(
0.79788456
)
*
mx
*
(
static_cast
<
MT
>
(
1
)
+
static_cast
<
MT
>
(
0.044715
)
*
mx
*
mx
));
MT
ans
=
static_cast
<
MT
>
(
0.5
)
*
mx
*
((
static_cast
<
MT
>
(
1
)
-
tanh_out
*
tanh_out
)
*
(
static_cast
<
MT
>
(
0.79788456
)
+
static_cast
<
MT
>
(
0.1070322243
)
*
mx
*
mx
))
+
static_cast
<
MT
>
(
0.5
)
*
(
static_cast
<
MT
>
(
1
)
+
tanh_out
);
return
static_cast
<
T
>
(
ans
);
}
};
}
// namespace math
}
// namespace math
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
python/paddle/fluid/tests/unittests/test_fused_elemwise_activation_op.py
浏览文件 @
eae31856
...
@@ -305,6 +305,15 @@ def mul_scale_func(x, y, x_bcast, y_bcast, scale, mode=0):
...
@@ -305,6 +305,15 @@ def mul_scale_func(x, y, x_bcast, y_bcast, scale, mode=0):
return
y
,
x
,
x
*
scale
,
y_bcast
*
(
x_bcast
*
scale
)
return
y
,
x
,
x
*
scale
,
y_bcast
*
(
x_bcast
*
scale
)
def
gelu_add_func
(
x
,
y
,
x_bcast
,
y_bcast
,
mode
=
0
):
im
=
x_bcast
+
y_bcast
out
=
im
*
0.5
*
(
1.0
+
np
.
tanh
(
0.79788456
*
im
*
(
1
+
0.044715
*
im
*
im
)))
if
mode
==
0
:
return
x
,
y
,
im
,
out
else
:
return
y
,
x
,
im
,
out
scale
=
0.1
scale
=
0.1
scale_add_func
=
partial
(
scale_add_func
,
scale
=
scale
)
scale_add_func
=
partial
(
scale_add_func
,
scale
=
scale
)
add_scale_func
=
partial
(
add_scale_func
,
scale
=
scale
)
add_scale_func
=
partial
(
add_scale_func
,
scale
=
scale
)
...
@@ -316,6 +325,7 @@ for mode in {0, 1}:
...
@@ -316,6 +325,7 @@ for mode in {0, 1}:
mul_scale_func
=
partial
(
mul_scale_func
,
mode
=
mode
)
mul_scale_func
=
partial
(
mul_scale_func
,
mode
=
mode
)
relu_add_func
=
partial
(
relu_add_func
,
mode
=
mode
)
relu_add_func
=
partial
(
relu_add_func
,
mode
=
mode
)
add_relu_func
=
partial
(
add_relu_func
,
mode
=
mode
)
add_relu_func
=
partial
(
add_relu_func
,
mode
=
mode
)
gelu_add_func
=
partial
(
gelu_add_func
,
mode
=
mode
)
for
save_intermediate_out
in
{
True
,
False
}:
for
save_intermediate_out
in
{
True
,
False
}:
suffix
=
(
"_save_intermediate_out"
if
save_intermediate_out
else
""
)
\
suffix
=
(
"_save_intermediate_out"
if
save_intermediate_out
else
""
)
\
...
@@ -343,6 +353,11 @@ for mode in {0, 1}:
...
@@ -343,6 +353,11 @@ for mode in {0, 1}:
'functor_list'
:
[
"elementwise_mul"
,
"scale"
],
'functor_list'
:
[
"elementwise_mul"
,
"scale"
],
'save_intermediate_out'
:
save_intermediate_out
,
'save_intermediate_out'
:
save_intermediate_out
,
})
})
create_test_class
(
'gelu_add'
+
suffix
,
gelu_add_func
,
{
'functor_list'
:
[
"gelu"
,
"elementwise_add"
],
'save_intermediate_out'
:
save_intermediate_out
,
})
if
core
.
is_compiled_with_cuda
():
if
core
.
is_compiled_with_cuda
():
create_test_class
(
create_test_class
(
'scale_add_fp16'
+
suffix
,
'scale_add_fp16'
+
suffix
,
...
@@ -388,6 +403,14 @@ for mode in {0, 1}:
...
@@ -388,6 +403,14 @@ for mode in {0, 1}:
},
},
dtype
=
np
.
float16
,
dtype
=
np
.
float16
,
grad_chek
=
False
)
grad_chek
=
False
)
create_test_class
(
'gelu_add_fp16'
+
suffix
,
gelu_add_func
,
{
'functor_list'
:
[
"gelu"
,
"elementwise_add"
],
'save_intermediate_out'
:
save_intermediate_out
,
},
dtype
=
np
.
float16
,
grad_chek
=
False
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
import
paddle
import
paddle
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录