Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
da7d2f29
P
Paddle
项目概览
PaddlePaddle
/
Paddle
接近 2 年 前同步成功
通知
2323
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
da7d2f29
编写于
10月 20, 2022
作者:
S
sneaxiy
提交者:
GitHub
10月 20, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Cherry-pick][Release/2.4] support pure bfloat16 for more ops
support pure bfloat16 for more ops
上级
c894d91d
变更
26
隐藏空白更改
内联
并排
Showing
26 changed file
with
160 addition
and
96 deletion
+160
-96
paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu
paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu
+11
-2
paddle/fluid/platform/device/gpu/gpu_primitives.h
paddle/fluid/platform/device/gpu/gpu_primitives.h
+56
-55
paddle/phi/kernels/empty_kernel.cc
paddle/phi/kernels/empty_kernel.cc
+1
-0
paddle/phi/kernels/funcs/activation_functor.h
paddle/phi/kernels/funcs/activation_functor.h
+5
-3
paddle/phi/kernels/funcs/eigen/broadcast.cu
paddle/phi/kernels/funcs/eigen/broadcast.cu
+1
-0
paddle/phi/kernels/gpu/activation_grad_kernel.cu
paddle/phi/kernels/gpu/activation_grad_kernel.cu
+2
-1
paddle/phi/kernels/gpu/activation_kernel.cu
paddle/phi/kernels/gpu/activation_kernel.cu
+9
-2
paddle/phi/kernels/gpu/adam_kernel.cu
paddle/phi/kernels/gpu/adam_kernel.cu
+4
-2
paddle/phi/kernels/gpu/clip_grad_kernel.cu
paddle/phi/kernels/gpu/clip_grad_kernel.cu
+1
-0
paddle/phi/kernels/gpu/clip_kernel.cu
paddle/phi/kernels/gpu/clip_kernel.cu
+2
-1
paddle/phi/kernels/gpu/embedding_grad_kernel.cu
paddle/phi/kernels/gpu/embedding_grad_kernel.cu
+4
-2
paddle/phi/kernels/gpu/embedding_kernel.cu
paddle/phi/kernels/gpu/embedding_kernel.cu
+2
-1
paddle/phi/kernels/gpu/gelu_grad_kernel.cu
paddle/phi/kernels/gpu/gelu_grad_kernel.cu
+2
-1
paddle/phi/kernels/gpu/gelu_kernel.cu
paddle/phi/kernels/gpu/gelu_kernel.cu
+2
-1
paddle/phi/kernels/gpu/pad3d_grad_kernel.cu
paddle/phi/kernels/gpu/pad3d_grad_kernel.cu
+2
-1
paddle/phi/kernels/gpu/pad3d_kernel.cu
paddle/phi/kernels/gpu/pad3d_kernel.cu
+1
-0
paddle/phi/kernels/gpu/pixel_shuffle_grad_kernel.cu
paddle/phi/kernels/gpu/pixel_shuffle_grad_kernel.cu
+3
-1
paddle/phi/kernels/gpu/pixel_shuffle_kernel.cu
paddle/phi/kernels/gpu/pixel_shuffle_kernel.cu
+8
-2
paddle/phi/kernels/gpu/selu_grad_kernel.cu
paddle/phi/kernels/gpu/selu_grad_kernel.cu
+7
-2
paddle/phi/kernels/gpu/tile_grad_kernel.cu
paddle/phi/kernels/gpu/tile_grad_kernel.cu
+2
-1
paddle/phi/kernels/gpu/where_grad_kernel.cu
paddle/phi/kernels/gpu/where_grad_kernel.cu
+4
-2
paddle/phi/kernels/gpu/where_kernel.cu
paddle/phi/kernels/gpu/where_kernel.cu
+10
-2
paddle/phi/kernels/impl/selu_kernel_impl.h
paddle/phi/kernels/impl/selu_kernel_impl.h
+8
-5
python/paddle/fluid/clip.py
python/paddle/fluid/clip.py
+8
-6
python/paddle/optimizer/adam.py
python/paddle/optimizer/adam.py
+1
-1
python/paddle/tensor/stat.py
python/paddle/tensor/stat.py
+4
-2
未找到文件。
paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu
浏览文件 @
da7d2f29
...
@@ -17,6 +17,7 @@ limitations under the License. */
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/scope_guard.h"
#include "paddle/fluid/framework/scope_guard.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/dynload/cublasLt.h"
#include "paddle/fluid/platform/dynload/cublasLt.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/float16.h"
...
@@ -62,6 +63,9 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
...
@@ -62,6 +63,9 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
if
(
std
::
is_same
<
T
,
paddle
::
platform
::
float16
>::
value
)
{
if
(
std
::
is_same
<
T
,
paddle
::
platform
::
float16
>::
value
)
{
mat_type
=
CUDA_R_16F
;
mat_type
=
CUDA_R_16F
;
}
}
if
(
std
::
is_same
<
T
,
platform
::
bfloat16
>::
value
)
{
mat_type
=
CUDA_R_16BF
;
}
if
(
std
::
is_same
<
T
,
double
>::
value
)
{
if
(
std
::
is_same
<
T
,
double
>::
value
)
{
mat_type
=
CUDA_R_64F
;
mat_type
=
CUDA_R_64F
;
scale_type
=
CUDA_R_64F
;
scale_type
=
CUDA_R_64F
;
...
@@ -352,6 +356,9 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
...
@@ -352,6 +356,9 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
if
(
std
::
is_same
<
T
,
paddle
::
platform
::
float16
>::
value
)
{
if
(
std
::
is_same
<
T
,
paddle
::
platform
::
float16
>::
value
)
{
mat_type
=
CUDA_R_16F
;
mat_type
=
CUDA_R_16F
;
}
}
if
(
std
::
is_same
<
T
,
platform
::
bfloat16
>::
value
)
{
mat_type
=
CUDA_R_16BF
;
}
if
(
std
::
is_same
<
T
,
double
>::
value
)
{
if
(
std
::
is_same
<
T
,
double
>::
value
)
{
mat_type
=
CUDA_R_64F
;
mat_type
=
CUDA_R_64F
;
scale_type
=
CUDA_R_64F
;
scale_type
=
CUDA_R_64F
;
...
@@ -686,12 +693,14 @@ REGISTER_OP_CUDA_KERNEL(
...
@@ -686,12 +693,14 @@ REGISTER_OP_CUDA_KERNEL(
fused_gemm_epilogue
,
fused_gemm_epilogue
,
ops
::
FusedGemmEpilogueKernel
<
phi
::
GPUContext
,
float
>
,
ops
::
FusedGemmEpilogueKernel
<
phi
::
GPUContext
,
float
>
,
ops
::
FusedGemmEpilogueKernel
<
phi
::
GPUContext
,
double
>
,
ops
::
FusedGemmEpilogueKernel
<
phi
::
GPUContext
,
double
>
,
ops
::
FusedGemmEpilogueKernel
<
phi
::
GPUContext
,
paddle
::
platform
::
float16
>
);
ops
::
FusedGemmEpilogueKernel
<
phi
::
GPUContext
,
paddle
::
platform
::
float16
>
,
ops
::
FusedGemmEpilogueKernel
<
phi
::
GPUContext
,
paddle
::
platform
::
bfloat16
>
);
REGISTER_OP_CUDA_KERNEL
(
REGISTER_OP_CUDA_KERNEL
(
fused_gemm_epilogue_grad
,
fused_gemm_epilogue_grad
,
ops
::
FusedGemmEpilogueGradKernel
<
phi
::
GPUContext
,
float
>
,
ops
::
FusedGemmEpilogueGradKernel
<
phi
::
GPUContext
,
float
>
,
ops
::
FusedGemmEpilogueGradKernel
<
phi
::
GPUContext
,
double
>
,
ops
::
FusedGemmEpilogueGradKernel
<
phi
::
GPUContext
,
double
>
,
ops
::
FusedGemmEpilogueGradKernel
<
phi
::
GPUContext
,
ops
::
FusedGemmEpilogueGradKernel
<
phi
::
GPUContext
,
paddle
::
platform
::
float16
>
);
paddle
::
platform
::
float16
>
,
ops
::
FusedGemmEpilogueKernel
<
phi
::
GPUContext
,
paddle
::
platform
::
bfloat16
>
);
#endif
#endif
paddle/fluid/platform/device/gpu/gpu_primitives.h
浏览文件 @
da7d2f29
...
@@ -198,61 +198,6 @@ __device__ __forceinline__ void fastAtomicAdd(T *arr,
...
@@ -198,61 +198,6 @@ __device__ __forceinline__ void fastAtomicAdd(T *arr,
T
value
)
{
T
value
)
{
CudaAtomicAdd
(
arr
+
index
,
value
);
CudaAtomicAdd
(
arr
+
index
,
value
);
}
}
#ifdef PADDLE_WITH_CUDA
/*
* One thead block deals with elementwise atomicAdd for vector of len.
* @in: [x1, x2, x3, ...]
* @out:[y1+x1, y2+x2, y3+x3, ...]
* */
template
<
typename
T
,
typename
std
::
enable_if
<
!
std
::
is_same
<
platform
::
float16
,
T
>
::
value
>::
type
*
=
nullptr
>
__device__
__forceinline__
void
VectorizedAtomicAddPerBlock
(
const
int64_t
len
,
int
tid
,
int
threads_per_block
,
const
T
*
in
,
T
*
out
)
{
for
(
int
i
=
tid
;
i
<
len
;
i
+=
threads_per_block
)
{
CudaAtomicAdd
(
&
out
[
i
],
in
[
i
]);
}
}
// Note: assume that len is even. If len is odd, call fastAtomicAdd directly.
template
<
typename
T
,
typename
std
::
enable_if
<
std
::
is_same
<
platform
::
float16
,
T
>
::
value
>::
type
*
=
nullptr
>
__device__
__forceinline__
void
VectorizedAtomicAddPerBlock
(
const
int64_t
len
,
int
tid
,
int
threads_per_block
,
const
T
*
in
,
T
*
out
)
{
#if ((CUDA_VERSION < 10000) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
for
(
int
i
=
tid
;
i
<
len
;
i
+=
threads_per_block
)
{
CudaAtomicAdd
(
&
out
[
i
],
in
[
i
]);
}
#else
int
i
=
0
;
int
loops
=
len
/
2
*
2
;
bool
aligned_half2
=
(
reinterpret_cast
<
std
::
uintptr_t
>
(
out
)
%
sizeof
(
__half2
)
==
0
);
if
(
aligned_half2
)
{
for
(
i
=
tid
*
2
;
i
<
loops
;
i
+=
threads_per_block
*
2
)
{
__half2
value2
;
T
value_1
=
in
[
i
];
T
value_2
=
in
[
i
+
1
];
value2
.
x
=
*
reinterpret_cast
<
__half
*>
(
&
value_1
);
value2
.
y
=
*
reinterpret_cast
<
__half
*>
(
&
value_2
);
atomicAdd
(
reinterpret_cast
<
__half2
*>
(
&
out
[
i
]),
value2
);
}
for
(;
i
<
len
;
i
+=
threads_per_block
)
{
fastAtomicAdd
(
out
,
i
,
len
,
in
[
i
]);
}
}
else
{
for
(
int
i
=
tid
;
i
<
len
;
i
+=
threads_per_block
)
{
fastAtomicAdd
(
out
,
i
,
len
,
in
[
i
]);
}
}
#endif
}
#endif
#endif
#endif
// NOTE(zhangbo): cuda do not have atomicCAS for __nv_bfloat16.
// NOTE(zhangbo): cuda do not have atomicCAS for __nv_bfloat16.
...
@@ -601,5 +546,61 @@ CUDA_ATOMIC_WRAPPER(Min, float16) {
...
@@ -601,5 +546,61 @@ CUDA_ATOMIC_WRAPPER(Min, float16) {
}
}
#endif
#endif
#ifdef PADDLE_CUDA_FP16
#ifdef PADDLE_WITH_CUDA
/*
* One thead block deals with elementwise atomicAdd for vector of len.
* @in: [x1, x2, x3, ...]
* @out:[y1+x1, y2+x2, y3+x3, ...]
* */
template
<
typename
T
,
typename
std
::
enable_if
<
!
std
::
is_same
<
platform
::
float16
,
T
>
::
value
>::
type
*
=
nullptr
>
__device__
__forceinline__
void
VectorizedAtomicAddPerBlock
(
const
int64_t
len
,
int
tid
,
int
threads_per_block
,
const
T
*
in
,
T
*
out
)
{
for
(
int
i
=
tid
;
i
<
len
;
i
+=
threads_per_block
)
{
CudaAtomicAdd
(
&
out
[
i
],
in
[
i
]);
}
}
// Note: assume that len is even. If len is odd, call fastAtomicAdd directly.
template
<
typename
T
,
typename
std
::
enable_if
<
std
::
is_same
<
platform
::
float16
,
T
>
::
value
>::
type
*
=
nullptr
>
__device__
__forceinline__
void
VectorizedAtomicAddPerBlock
(
const
int64_t
len
,
int
tid
,
int
threads_per_block
,
const
T
*
in
,
T
*
out
)
{
#if ((CUDA_VERSION < 10000) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
for
(
int
i
=
tid
;
i
<
len
;
i
+=
threads_per_block
)
{
CudaAtomicAdd
(
&
out
[
i
],
in
[
i
]);
}
#else
int
i
=
0
;
int
loops
=
len
/
2
*
2
;
bool
aligned_half2
=
(
reinterpret_cast
<
std
::
uintptr_t
>
(
out
)
%
sizeof
(
__half2
)
==
0
);
if
(
aligned_half2
)
{
for
(
i
=
tid
*
2
;
i
<
loops
;
i
+=
threads_per_block
*
2
)
{
__half2
value2
;
T
value_1
=
in
[
i
];
T
value_2
=
in
[
i
+
1
];
value2
.
x
=
*
reinterpret_cast
<
__half
*>
(
&
value_1
);
value2
.
y
=
*
reinterpret_cast
<
__half
*>
(
&
value_2
);
atomicAdd
(
reinterpret_cast
<
__half2
*>
(
&
out
[
i
]),
value2
);
}
for
(;
i
<
len
;
i
+=
threads_per_block
)
{
fastAtomicAdd
(
out
,
i
,
len
,
in
[
i
]);
}
}
else
{
for
(
int
i
=
tid
;
i
<
len
;
i
+=
threads_per_block
)
{
fastAtomicAdd
(
out
,
i
,
len
,
in
[
i
]);
}
}
#endif
}
#endif
#endif
}
// namespace platform
}
// namespace platform
}
// namespace paddle
}
// namespace paddle
paddle/phi/kernels/empty_kernel.cc
浏览文件 @
da7d2f29
...
@@ -88,6 +88,7 @@ PD_REGISTER_KERNEL(empty,
...
@@ -88,6 +88,7 @@ PD_REGISTER_KERNEL(empty,
int64_t
,
int64_t
,
bool
,
bool
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
,
phi
::
dtype
::
complex
<
float
>
,
phi
::
dtype
::
complex
<
float
>
,
phi
::
dtype
::
complex
<
double
>
)
{}
phi
::
dtype
::
complex
<
double
>
)
{}
...
...
paddle/phi/kernels/funcs/activation_functor.h
浏览文件 @
da7d2f29
...
@@ -2169,12 +2169,14 @@ struct CudaSeluFunctor : public BaseActivationFunctor<T> {
...
@@ -2169,12 +2169,14 @@ struct CudaSeluFunctor : public BaseActivationFunctor<T> {
}
}
__device__
__forceinline__
T
operator
()(
const
T
x
)
const
{
__device__
__forceinline__
T
operator
()(
const
T
x
)
const
{
T
res
=
x
;
using
MT
=
if
(
res
<=
zero
)
{
typename
std
::
conditional
<
(
sizeof
(
T
)
>
sizeof
(
float
)),
T
,
float
>::
type
;
MT
res
=
static_cast
<
MT
>
(
x
);
if
(
x
<=
zero
)
{
res
=
alpha
*
expf
(
res
)
-
alpha
;
res
=
alpha
*
expf
(
res
)
-
alpha
;
}
}
res
*=
scale
;
res
*=
scale
;
return
res
;
return
static_cast
<
T
>
(
res
)
;
}
}
private:
private:
...
...
paddle/phi/kernels/funcs/eigen/broadcast.cu
浏览文件 @
da7d2f29
...
@@ -84,6 +84,7 @@ INSTANTIATION(EigenBroadcast, int);
...
@@ -84,6 +84,7 @@ INSTANTIATION(EigenBroadcast, int);
INSTANTIATION
(
EigenBroadcast
,
int64_t
);
INSTANTIATION
(
EigenBroadcast
,
int64_t
);
INSTANTIATION
(
EigenBroadcastGrad
,
bool
);
INSTANTIATION
(
EigenBroadcastGrad
,
bool
);
INSTANTIATION
(
EigenBroadcastGrad
,
float
);
INSTANTIATION
(
EigenBroadcastGrad
,
float
);
INSTANTIATION
(
EigenBroadcastGrad
,
dtype
::
bfloat16
);
INSTANTIATION
(
EigenBroadcastGrad
,
dtype
::
float16
);
INSTANTIATION
(
EigenBroadcastGrad
,
dtype
::
float16
);
INSTANTIATION
(
EigenBroadcastGrad
,
double
);
INSTANTIATION
(
EigenBroadcastGrad
,
double
);
INSTANTIATION
(
EigenBroadcastGrad
,
dtype
::
complex
<
float
>
);
INSTANTIATION
(
EigenBroadcastGrad
,
dtype
::
complex
<
float
>
);
...
...
paddle/phi/kernels/gpu/activation_grad_kernel.cu
浏览文件 @
da7d2f29
...
@@ -449,4 +449,5 @@ PD_REGISTER_KERNEL(pow_grad,
...
@@ -449,4 +449,5 @@ PD_REGISTER_KERNEL(pow_grad,
double
,
double
,
int
,
int
,
int64_t
,
int64_t
,
phi
::
dtype
::
float16
)
{}
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{}
paddle/phi/kernels/gpu/activation_kernel.cu
浏览文件 @
da7d2f29
...
@@ -265,5 +265,12 @@ PD_REGISTER_KERNEL(pow,
...
@@ -265,5 +265,12 @@ PD_REGISTER_KERNEL(pow,
double
,
double
,
int
,
int
,
int64_t
,
int64_t
,
phi
::
dtype
::
float16
)
{}
phi
::
dtype
::
float16
,
PD_REGISTER_KERNEL
(
selu
,
GPU
,
ALL_LAYOUT
,
phi
::
SeluKernel
,
float
,
double
)
{}
phi
::
dtype
::
bfloat16
)
{}
PD_REGISTER_KERNEL
(
selu
,
GPU
,
ALL_LAYOUT
,
phi
::
SeluKernel
,
float
,
double
,
phi
::
dtype
::
bfloat16
)
{}
paddle/phi/kernels/gpu/adam_kernel.cu
浏览文件 @
da7d2f29
...
@@ -373,7 +373,8 @@ PD_REGISTER_KERNEL(adam,
...
@@ -373,7 +373,8 @@ PD_REGISTER_KERNEL(adam,
phi
::
AdamDenseKernel
,
phi
::
AdamDenseKernel
,
float
,
float
,
double
,
double
,
phi
::
dtype
::
float16
)
{
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{
// Skip beta1_pow, beta2_pow, skip_update data transform
// Skip beta1_pow, beta2_pow, skip_update data transform
kernel
->
InputAt
(
5
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
5
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
6
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
6
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
...
@@ -386,7 +387,8 @@ PD_REGISTER_KERNEL(merged_adam,
...
@@ -386,7 +387,8 @@ PD_REGISTER_KERNEL(merged_adam,
phi
::
MergedAdamKernel
,
phi
::
MergedAdamKernel
,
float
,
float
,
double
,
double
,
phi
::
dtype
::
float16
)
{
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{
// Skip beta1_pow, beta2_pow data transform
// Skip beta1_pow, beta2_pow data transform
kernel
->
InputAt
(
5
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
5
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
6
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
6
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
...
...
paddle/phi/kernels/gpu/clip_grad_kernel.cu
浏览文件 @
da7d2f29
...
@@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(clip_grad,
...
@@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(clip_grad,
double
,
double
,
int
,
int
,
int64_t
,
int64_t
,
phi
::
dtype
::
bfloat16
,
phi
::
dtype
::
float16
)
{}
phi
::
dtype
::
float16
)
{}
paddle/phi/kernels/gpu/clip_kernel.cu
浏览文件 @
da7d2f29
...
@@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(clip,
...
@@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(clip,
double
,
double
,
int
,
int
,
int64_t
,
int64_t
,
phi
::
dtype
::
float16
)
{}
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{}
paddle/phi/kernels/gpu/embedding_grad_kernel.cu
浏览文件 @
da7d2f29
...
@@ -249,7 +249,8 @@ PD_REGISTER_KERNEL(embedding_grad,
...
@@ -249,7 +249,8 @@ PD_REGISTER_KERNEL(embedding_grad,
phi
::
EmbeddingGradKernel
,
phi
::
EmbeddingGradKernel
,
float
,
float
,
double
,
double
,
phi
::
dtype
::
float16
)
{}
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{}
PD_REGISTER_KERNEL
(
embedding_sparse_grad
,
PD_REGISTER_KERNEL
(
embedding_sparse_grad
,
GPU
,
GPU
,
...
@@ -257,4 +258,5 @@ PD_REGISTER_KERNEL(embedding_sparse_grad,
...
@@ -257,4 +258,5 @@ PD_REGISTER_KERNEL(embedding_sparse_grad,
phi
::
EmbeddingSparseGradKernel
,
phi
::
EmbeddingSparseGradKernel
,
float
,
float
,
double
,
double
,
phi
::
dtype
::
float16
)
{}
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{}
paddle/phi/kernels/gpu/embedding_kernel.cu
浏览文件 @
da7d2f29
...
@@ -125,4 +125,5 @@ PD_REGISTER_KERNEL(embedding,
...
@@ -125,4 +125,5 @@ PD_REGISTER_KERNEL(embedding,
phi
::
EmbeddingKernel
,
phi
::
EmbeddingKernel
,
float
,
float
,
double
,
double
,
phi
::
dtype
::
float16
)
{}
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{}
paddle/phi/kernels/gpu/gelu_grad_kernel.cu
浏览文件 @
da7d2f29
...
@@ -99,4 +99,5 @@ PD_REGISTER_KERNEL(gelu_grad,
...
@@ -99,4 +99,5 @@ PD_REGISTER_KERNEL(gelu_grad,
phi
::
GeluGradKernel
,
phi
::
GeluGradKernel
,
float
,
float
,
double
,
double
,
phi
::
dtype
::
float16
)
{}
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{}
paddle/phi/kernels/gpu/gelu_kernel.cu
浏览文件 @
da7d2f29
...
@@ -93,4 +93,5 @@ PD_REGISTER_KERNEL(gelu,
...
@@ -93,4 +93,5 @@ PD_REGISTER_KERNEL(gelu,
phi
::
GeluKernel
,
phi
::
GeluKernel
,
float
,
float
,
double
,
double
,
phi
::
dtype
::
float16
)
{}
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{}
paddle/phi/kernels/gpu/pad3d_grad_kernel.cu
浏览文件 @
da7d2f29
...
@@ -509,4 +509,5 @@ PD_REGISTER_KERNEL(pad3d_grad,
...
@@ -509,4 +509,5 @@ PD_REGISTER_KERNEL(pad3d_grad,
phi
::
Pad3dGradKernel
,
phi
::
Pad3dGradKernel
,
float
,
float
,
double
,
double
,
phi
::
dtype
::
float16
)
{}
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{}
paddle/phi/kernels/gpu/pad3d_kernel.cu
浏览文件 @
da7d2f29
...
@@ -583,6 +583,7 @@ PD_REGISTER_KERNEL(pad3d,
...
@@ -583,6 +583,7 @@ PD_REGISTER_KERNEL(pad3d,
ALL_LAYOUT
,
ALL_LAYOUT
,
phi
::
Pad3dKernel
,
phi
::
Pad3dKernel
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
,
float
,
float
,
double
,
double
,
int
,
int
,
...
...
paddle/phi/kernels/gpu/pixel_shuffle_grad_kernel.cu
浏览文件 @
da7d2f29
...
@@ -23,4 +23,6 @@ PD_REGISTER_KERNEL(pixel_shuffle_grad,
...
@@ -23,4 +23,6 @@ PD_REGISTER_KERNEL(pixel_shuffle_grad,
ALL_LAYOUT
,
ALL_LAYOUT
,
phi
::
PixelShuffleGradKernel
,
phi
::
PixelShuffleGradKernel
,
float
,
float
,
double
)
{}
double
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{}
paddle/phi/kernels/gpu/pixel_shuffle_kernel.cu
浏览文件 @
da7d2f29
...
@@ -18,5 +18,11 @@
...
@@ -18,5 +18,11 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/pixel_shuffle_kernel_impl.h"
#include "paddle/phi/kernels/impl/pixel_shuffle_kernel_impl.h"
PD_REGISTER_KERNEL
(
PD_REGISTER_KERNEL
(
pixel_shuffle
,
pixel_shuffle
,
GPU
,
ALL_LAYOUT
,
phi
::
PixelShuffleKernel
,
float
,
double
)
{}
GPU
,
ALL_LAYOUT
,
phi
::
PixelShuffleKernel
,
float
,
double
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{}
paddle/phi/kernels/gpu/selu_grad_kernel.cu
浏览文件 @
da7d2f29
...
@@ -18,5 +18,10 @@
...
@@ -18,5 +18,10 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/selu_grad_kernel_impl.h"
#include "paddle/phi/kernels/impl/selu_grad_kernel_impl.h"
PD_REGISTER_KERNEL
(
PD_REGISTER_KERNEL
(
selu_grad
,
selu_grad
,
GPU
,
ALL_LAYOUT
,
phi
::
SeluGradKernel
,
float
,
double
)
{}
GPU
,
ALL_LAYOUT
,
phi
::
SeluGradKernel
,
float
,
double
,
phi
::
dtype
::
bfloat16
)
{}
paddle/phi/kernels/gpu/tile_grad_kernel.cu
浏览文件 @
da7d2f29
...
@@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(tile_grad,
...
@@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(tile_grad,
double
,
double
,
int
,
int
,
int64_t
,
int64_t
,
phi
::
dtype
::
float16
)
{}
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{}
paddle/phi/kernels/gpu/where_grad_kernel.cu
浏览文件 @
da7d2f29
...
@@ -25,10 +25,10 @@ __global__ void WhereGradCUDAKernel(
...
@@ -25,10 +25,10 @@ __global__ void WhereGradCUDAKernel(
int
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
int
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
for
(;
idx
<
N
;
idx
+=
blockDim
.
x
*
gridDim
.
x
)
{
for
(;
idx
<
N
;
idx
+=
blockDim
.
x
*
gridDim
.
x
)
{
if
(
dx
!=
nullptr
)
{
if
(
dx
!=
nullptr
)
{
dx
[
idx
]
=
cond
[
idx
]
?
dout
[
idx
]
:
0.
;
dx
[
idx
]
=
cond
[
idx
]
?
dout
[
idx
]
:
static_cast
<
T
>
(
0.
)
;
}
}
if
(
dy
!=
nullptr
)
{
if
(
dy
!=
nullptr
)
{
dy
[
idx
]
=
cond
[
idx
]
?
0.
:
dout
[
idx
];
dy
[
idx
]
=
cond
[
idx
]
?
static_cast
<
T
>
(
0.
)
:
dout
[
idx
];
}
}
}
}
}
}
...
@@ -61,6 +61,8 @@ PD_REGISTER_KERNEL(where_grad,
...
@@ -61,6 +61,8 @@ PD_REGISTER_KERNEL(where_grad,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
phi
::
WhereGradKernel
,
phi
::
WhereGradKernel
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
,
float
,
float
,
double
,
double
,
int
,
int
,
...
...
paddle/phi/kernels/gpu/where_kernel.cu
浏览文件 @
da7d2f29
...
@@ -45,5 +45,13 @@ void WhereKernel(const Context& ctx,
...
@@ -45,5 +45,13 @@ void WhereKernel(const Context& ctx,
}
// namespace phi
}
// namespace phi
PD_REGISTER_KERNEL
(
PD_REGISTER_KERNEL
(
where
,
where
,
GPU
,
ALL_LAYOUT
,
phi
::
WhereKernel
,
float
,
double
,
int
,
int64_t
)
{}
GPU
,
ALL_LAYOUT
,
phi
::
WhereKernel
,
float
,
double
,
int
,
int64_t
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{}
paddle/phi/kernels/impl/selu_kernel_impl.h
浏览文件 @
da7d2f29
...
@@ -57,14 +57,17 @@ struct SeluGradFunctor {
...
@@ -57,14 +57,17 @@ struct SeluGradFunctor {
dx_data_ptr_
(
dx_data_ptr
)
{}
dx_data_ptr_
(
dx_data_ptr
)
{}
HOSTDEVICE
void
operator
()(
size_t
idx
)
const
{
HOSTDEVICE
void
operator
()(
size_t
idx
)
const
{
T
y_ele
=
y_data_ptr_
[
idx
];
using
MT
=
T
dy_ele
=
dy_data_ptr_
[
idx
]
;
typename
std
::
conditional
<
(
sizeof
(
T
)
>
sizeof
(
float
)),
T
,
float
>::
type
;
float
tmp
=
scale_
;
auto
y_ele
=
static_cast
<
MT
>
(
y_data_ptr_
[
idx
]);
auto
dy_ele
=
static_cast
<
MT
>
(
dy_data_ptr_
[
idx
]);
auto
tmp
=
static_cast
<
MT
>
(
scale_
);
if
(
y_ele
<=
0
)
{
if
(
y_ele
<=
0
)
{
tmp
=
y_ele
+
la_
;
tmp
=
y_ele
+
static_cast
<
MT
>
(
la_
)
;
}
}
dx_data_ptr_
[
idx
]
=
dy_ele
*
tmp
;
dx_data_ptr_
[
idx
]
=
static_cast
<
T
>
(
dy_ele
*
tmp
)
;
}
}
const
T
*
y_data_ptr_
;
const
T
*
y_data_ptr_
;
const
T
*
dy_data_ptr_
;
const
T
*
dy_data_ptr_
;
...
...
python/paddle/fluid/clip.py
浏览文件 @
da7d2f29
...
@@ -52,8 +52,9 @@ def _clip_by_global_norm_using_mp_type(*args):
...
@@ -52,8 +52,9 @@ def _clip_by_global_norm_using_mp_type(*args):
def
_cast_to_mp_type_if_enabled
(
x
):
def
_cast_to_mp_type_if_enabled
(
x
):
if
x
.
dtype
==
core
.
VarDesc
.
VarType
.
FP16
and
_clip_by_global_norm_using_mp_type
(
if
(
x
.
dtype
==
core
.
VarDesc
.
VarType
.
FP16
):
or
x
.
dtype
==
core
.
VarDesc
.
VarType
.
BF16
)
and
_clip_by_global_norm_using_mp_type
():
return
x
.
astype
(
core
.
VarDesc
.
VarType
.
FP32
)
return
x
.
astype
(
core
.
VarDesc
.
VarType
.
FP32
)
else
:
else
:
return
x
return
x
...
@@ -65,7 +66,8 @@ def _squared_l2_norm(x):
...
@@ -65,7 +66,8 @@ def _squared_l2_norm(x):
"""
"""
x
=
_cast_to_mp_type_if_enabled
(
x
)
x
=
_cast_to_mp_type_if_enabled
(
x
)
if
core
.
is_compiled_with_xpu
()
or
x
.
dtype
==
core
.
VarDesc
.
VarType
.
FP16
:
if
core
.
is_compiled_with_xpu
(
)
or
x
.
dtype
==
core
.
VarDesc
.
VarType
.
FP16
or
x
.
dtype
==
core
.
VarDesc
.
VarType
.
BF16
:
square
=
layers
.
square
(
x
)
square
=
layers
.
square
(
x
)
sum_square
=
layers
.
reduce_sum
(
square
)
sum_square
=
layers
.
reduce_sum
(
square
)
return
sum_square
return
sum_square
...
@@ -501,7 +503,7 @@ class ClipGradByGlobalNorm(ClipGradBase):
...
@@ -501,7 +503,7 @@ class ClipGradByGlobalNorm(ClipGradBase):
merge_grad
=
layers
.
get_tensor_from_selected_rows
(
merge_grad
)
merge_grad
=
layers
.
get_tensor_from_selected_rows
(
merge_grad
)
sum_square
=
_squared_l2_norm
(
merge_grad
)
sum_square
=
_squared_l2_norm
(
merge_grad
)
if
sum_square
.
dtype
==
core
.
VarDesc
.
VarType
.
FP16
:
if
sum_square
.
dtype
==
core
.
VarDesc
.
VarType
.
FP16
or
sum_square
.
dtype
==
core
.
VarDesc
.
VarType
.
BF16
:
sum_square_list_fp16
.
append
(
sum_square
)
sum_square_list_fp16
.
append
(
sum_square
)
elif
sum_square
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
:
elif
sum_square
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
:
sum_square_list_fp32
.
append
(
sum_square
)
sum_square_list_fp32
.
append
(
sum_square
)
...
@@ -554,8 +556,8 @@ class ClipGradByGlobalNorm(ClipGradBase):
...
@@ -554,8 +556,8 @@ class ClipGradByGlobalNorm(ClipGradBase):
continue
continue
# TODO(wangxi): use inplace elementwise_mul
# TODO(wangxi): use inplace elementwise_mul
if
need_clip
:
if
need_clip
:
clip_input
=
(
clip_var
.
astype
(
'float16'
)
if
g
.
dtype
clip_input
=
(
clip_var
.
astype
(
g
.
dtype
)
==
core
.
VarDesc
.
VarType
.
FP16
else
clip_var
)
if
clip_var
.
dtype
!=
g
.
dtype
else
clip_var
)
new_grad
=
layers
.
elementwise_mul
(
g
,
clip_input
)
new_grad
=
layers
.
elementwise_mul
(
g
,
clip_input
)
params_and_grads
.
append
((
p
,
new_grad
))
params_and_grads
.
append
((
p
,
new_grad
))
else
:
else
:
...
...
python/paddle/optimizer/adam.py
浏览文件 @
da7d2f29
...
@@ -275,7 +275,7 @@ class Adam(Optimizer):
...
@@ -275,7 +275,7 @@ class Adam(Optimizer):
def
_add_moments_pows
(
self
,
p
):
def
_add_moments_pows
(
self
,
p
):
acc_dtype
=
p
.
dtype
acc_dtype
=
p
.
dtype
if
acc_dtype
==
core
.
VarDesc
.
VarType
.
FP16
:
if
acc_dtype
==
core
.
VarDesc
.
VarType
.
FP16
or
acc_dtype
==
core
.
VarDesc
.
VarType
.
BF16
:
acc_dtype
=
core
.
VarDesc
.
VarType
.
FP32
acc_dtype
=
core
.
VarDesc
.
VarType
.
FP32
self
.
_add_accumulator
(
self
.
_moment1_acc_str
,
p
,
dtype
=
acc_dtype
)
self
.
_add_accumulator
(
self
.
_moment1_acc_str
,
p
,
dtype
=
acc_dtype
)
self
.
_add_accumulator
(
self
.
_moment2_acc_str
,
p
,
dtype
=
acc_dtype
)
self
.
_add_accumulator
(
self
.
_moment2_acc_str
,
p
,
dtype
=
acc_dtype
)
...
...
python/paddle/tensor/stat.py
浏览文件 @
da7d2f29
...
@@ -159,8 +159,10 @@ def var(x, axis=None, unbiased=True, keepdim=False, name=None):
...
@@ -159,8 +159,10 @@ def var(x, axis=None, unbiased=True, keepdim=False, name=None):
u
=
mean
(
x
,
axis
,
True
,
name
)
u
=
mean
(
x
,
axis
,
True
,
name
)
out
=
paddle
.
sum
((
x
-
u
)
**
2
,
axis
,
keepdim
=
keepdim
,
name
=
name
)
out
=
paddle
.
sum
((
x
-
u
)
**
2
,
axis
,
keepdim
=
keepdim
,
name
=
name
)
n
=
paddle
.
cast
(
paddle
.
numel
(
x
),
x
.
dtype
)
\
dtype
=
x
.
dtype
/
paddle
.
cast
(
paddle
.
numel
(
out
),
x
.
dtype
)
n
=
paddle
.
cast
(
paddle
.
numel
(
x
),
paddle
.
int64
)
\
/
paddle
.
cast
(
paddle
.
numel
(
out
),
paddle
.
int64
)
n
=
n
.
astype
(
dtype
)
if
unbiased
:
if
unbiased
:
one_const
=
paddle
.
ones
([
1
],
x
.
dtype
)
one_const
=
paddle
.
ones
([
1
],
x
.
dtype
)
n
=
where
(
n
>
one_const
,
n
-
1.
,
one_const
)
n
=
where
(
n
>
one_const
,
n
-
1.
,
one_const
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录