Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
87cc8d48
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
87cc8d48
编写于
9月 30, 2021
作者:
G
Guoxia Wang
提交者:
GitHub
9月 30, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support fp16 (#35888) (#36191)
上级
dcd17d6b
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
56 addition
and
29 deletion
+56
-29
paddle/fluid/operators/elementwise/elementwise_max_op.cu
paddle/fluid/operators/elementwise/elementwise_max_op.cu
+4
-0
paddle/fluid/operators/elementwise/elementwise_max_op.h
paddle/fluid/operators/elementwise/elementwise_max_op.h
+2
-2
paddle/fluid/operators/p_norm_op.cu
paddle/fluid/operators/p_norm_op.cu
+48
-26
python/paddle/nn/functional/norm.py
python/paddle/nn/functional/norm.py
+2
-1
未找到文件。
paddle/fluid/operators/elementwise/elementwise_max_op.cu
浏览文件 @
87cc8d48
...
@@ -41,12 +41,16 @@ namespace ops = paddle::operators;
...
@@ -41,12 +41,16 @@ namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL
(
REGISTER_OP_CUDA_KERNEL
(
elementwise_max
,
elementwise_max
,
ops
::
ElementwiseMaxKernel
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
platform
::
float16
>
,
ops
::
ElementwiseMaxKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
ElementwiseMaxKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
ElementwiseMaxKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
ElementwiseMaxKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
ElementwiseMaxKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
ops
::
ElementwiseMaxKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
ops
::
ElementwiseMaxKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
>
);
ops
::
ElementwiseMaxKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
>
);
REGISTER_OP_CUDA_KERNEL
(
REGISTER_OP_CUDA_KERNEL
(
elementwise_max_grad
,
elementwise_max_grad
,
ops
::
ElementwiseMaxGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
platform
::
float16
>
,
ops
::
ElementwiseMaxGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
ElementwiseMaxGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
ElementwiseMaxGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
ElementwiseMaxGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
ElementwiseMaxGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
ops
::
ElementwiseMaxGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
...
...
paddle/fluid/operators/elementwise/elementwise_max_op.h
浏览文件 @
87cc8d48
...
@@ -39,14 +39,14 @@ class ElementwiseMaxKernel : public framework::OpKernel<T> {
...
@@ -39,14 +39,14 @@ class ElementwiseMaxKernel : public framework::OpKernel<T> {
template
<
typename
T
>
template
<
typename
T
>
struct
MaxGradDx
{
struct
MaxGradDx
{
HOSTDEVICE
T
operator
()(
T
x
,
T
y
,
T
out
,
T
dout
)
const
{
HOSTDEVICE
T
operator
()(
T
x
,
T
y
,
T
out
,
T
dout
)
const
{
return
dout
*
(
x
>
y
);
return
dout
*
static_cast
<
T
>
(
x
>
y
);
}
}
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
MaxGradDy
{
struct
MaxGradDy
{
HOSTDEVICE
T
operator
()(
T
x
,
T
y
,
T
out
,
T
dout
)
const
{
HOSTDEVICE
T
operator
()(
T
x
,
T
y
,
T
out
,
T
dout
)
const
{
return
dout
*
(
x
<=
y
);
return
dout
*
static_cast
<
T
>
(
x
<=
y
);
}
}
};
};
...
...
paddle/fluid/operators/p_norm_op.cu
浏览文件 @
87cc8d48
...
@@ -20,7 +20,9 @@ limitations under the License. */
...
@@ -20,7 +20,9 @@ limitations under the License. */
#include <hipcub/hipcub.hpp>
#include <hipcub/hipcub.hpp>
namespace
cub
=
hipcub
;
namespace
cub
=
hipcub
;
#endif
#endif
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/p_norm_op.h"
#include "paddle/fluid/operators/p_norm_op.h"
#include "paddle/fluid/platform/float16.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -30,12 +32,23 @@ __device__ __forceinline__ int sgn(T val) {
...
@@ -30,12 +32,23 @@ __device__ __forceinline__ int sgn(T val) {
return
(
T
(
0
)
<
val
)
-
(
val
<
T
(
0
));
return
(
T
(
0
)
<
val
)
-
(
val
<
T
(
0
));
}
}
__device__
__forceinline__
platform
::
float16
inline_abs
(
platform
::
float16
x
)
{
return
static_cast
<
platform
::
float16
>
(
abs
(
static_cast
<
float
>
(
x
)));
}
__device__
__forceinline__
float
inline_abs
(
float
x
)
{
return
abs
(
x
);
}
__device__
__forceinline__
float
inline_abs
(
float
x
)
{
return
abs
(
x
);
}
__device__
__forceinline__
double
inline_abs
(
double
x
)
{
return
abs
(
x
);
}
__device__
__forceinline__
double
inline_abs
(
double
x
)
{
return
abs
(
x
);
}
__device__
__forceinline__
int
inline_sign
(
platform
::
float16
x
)
{
return
sgn
<
platform
::
float16
>
(
x
);
}
__device__
__forceinline__
int
inline_sign
(
float
x
)
{
return
sgn
<
float
>
(
x
);
}
__device__
__forceinline__
int
inline_sign
(
float
x
)
{
return
sgn
<
float
>
(
x
);
}
__device__
__forceinline__
int
inline_sign
(
double
x
)
{
return
sgn
<
double
>
(
x
);
}
__device__
__forceinline__
int
inline_sign
(
double
x
)
{
return
sgn
<
double
>
(
x
);
}
__device__
__forceinline__
platform
::
float16
inline_pow
(
platform
::
float16
base
,
platform
::
float16
exponent
)
{
return
static_cast
<
platform
::
float16
>
(
pow
(
static_cast
<
float
>
(
base
),
static_cast
<
float
>
(
exponent
)));
}
__device__
__forceinline__
float
inline_pow
(
float
base
,
float
exponent
)
{
__device__
__forceinline__
float
inline_pow
(
float
base
,
float
exponent
)
{
return
pow
(
base
,
exponent
);
return
pow
(
base
,
exponent
);
}
}
...
@@ -47,21 +60,23 @@ template <typename T, int BlockDim>
...
@@ -47,21 +60,23 @@ template <typename T, int BlockDim>
__global__
void
Pnorm
(
const
T
*
x
,
const
int
pre
,
__global__
void
Pnorm
(
const
T
*
x
,
const
int
pre
,
const
int
axis_n
,
// dim in axis
const
int
axis_n
,
// dim in axis
const
int
post
,
float
porder
,
T
*
out_norm
)
{
const
int
post
,
float
porder
,
T
*
out_norm
)
{
typedef
cub
::
BlockReduce
<
T
,
BlockDim
>
BlockReduce
;
using
MT
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
typedef
cub
::
BlockReduce
<
MT
,
BlockDim
>
BlockReduce
;
__shared__
typename
BlockReduce
::
TempStorage
temp_storage
;
__shared__
typename
BlockReduce
::
TempStorage
temp_storage
;
int
num
=
pre
*
post
;
int
num
=
pre
*
post
;
auto
porder_t
=
static_cast
<
T
>
(
porder
);
auto
porder_t
=
static_cast
<
M
T
>
(
porder
);
auto
porder_inv
=
static_cast
<
T
>
(
1.0
/
porder
);
auto
porder_inv
=
static_cast
<
M
T
>
(
1.0
/
porder
);
for
(
int
i
=
blockIdx
.
x
;
i
<
num
;
i
+=
gridDim
.
x
)
{
for
(
int
i
=
blockIdx
.
x
;
i
<
num
;
i
+=
gridDim
.
x
)
{
int
base
=
(
i
/
post
)
*
post
*
axis_n
+
(
i
%
post
);
int
base
=
(
i
/
post
)
*
post
*
axis_n
+
(
i
%
post
);
T
sum
=
0.0
;
MT
sum
=
static_cast
<
MT
>
(
0.0
)
;
for
(
int
j
=
threadIdx
.
x
;
j
<
axis_n
;
j
+=
blockDim
.
x
)
{
for
(
int
j
=
threadIdx
.
x
;
j
<
axis_n
;
j
+=
blockDim
.
x
)
{
const
T
x_ij
=
x
[
base
+
j
*
post
]
;
const
MT
x_ij
=
static_cast
<
MT
>
(
x
[
base
+
j
*
post
])
;
sum
+=
inline_pow
(
inline_abs
(
x_ij
),
porder_t
);
sum
+=
inline_pow
(
inline_abs
(
x_ij
),
porder_t
);
}
}
T
reduce_result
=
BlockReduce
(
temp_storage
).
Sum
(
sum
);
MT
reduce_result
=
BlockReduce
(
temp_storage
).
Sum
(
sum
);
if
(
threadIdx
.
x
==
0
)
out_norm
[
i
]
=
inline_pow
(
reduce_result
,
porder_inv
);
if
(
threadIdx
.
x
==
0
)
out_norm
[
i
]
=
static_cast
<
T
>
(
inline_pow
(
reduce_result
,
porder_inv
));
}
}
}
}
...
@@ -69,18 +84,19 @@ template <typename T, int BlockDim>
...
@@ -69,18 +84,19 @@ template <typename T, int BlockDim>
__global__
void
ZeorNorm
(
const
T
*
x
,
const
int
pre
,
__global__
void
ZeorNorm
(
const
T
*
x
,
const
int
pre
,
const
int
axis_n
,
// dim in axis
const
int
axis_n
,
// dim in axis
const
int
post
,
T
*
out_norm
)
{
const
int
post
,
T
*
out_norm
)
{
typedef
cub
::
BlockReduce
<
T
,
BlockDim
>
BlockReduce
;
using
MT
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
typedef
cub
::
BlockReduce
<
MT
,
BlockDim
>
BlockReduce
;
__shared__
typename
BlockReduce
::
TempStorage
temp_storage
;
__shared__
typename
BlockReduce
::
TempStorage
temp_storage
;
int
num
=
pre
*
post
;
int
num
=
pre
*
post
;
for
(
int
i
=
blockIdx
.
x
;
i
<
num
;
i
+=
gridDim
.
x
)
{
for
(
int
i
=
blockIdx
.
x
;
i
<
num
;
i
+=
gridDim
.
x
)
{
int
base
=
(
i
/
post
)
*
post
*
axis_n
+
(
i
%
post
);
int
base
=
(
i
/
post
)
*
post
*
axis_n
+
(
i
%
post
);
T
sum
=
0.0
;
MT
sum
=
static_cast
<
MT
>
(
0.0
)
;
for
(
int
j
=
threadIdx
.
x
;
j
<
axis_n
;
j
+=
blockDim
.
x
)
{
for
(
int
j
=
threadIdx
.
x
;
j
<
axis_n
;
j
+=
blockDim
.
x
)
{
const
T
x_ij
=
x
[
base
+
j
*
post
]
;
const
MT
x_ij
=
static_cast
<
MT
>
(
x
[
base
+
j
*
post
])
;
sum
+=
static_cast
<
T
>
(
x_ij
!=
0
);
sum
+=
static_cast
<
MT
>
(
static_cast
<
double
>
(
x_ij
)
!=
0
);
}
}
T
reduce_result
=
BlockReduce
(
temp_storage
).
Sum
(
sum
);
M
T
reduce_result
=
BlockReduce
(
temp_storage
).
Sum
(
sum
);
if
(
threadIdx
.
x
==
0
)
out_norm
[
i
]
=
reduce_result
;
if
(
threadIdx
.
x
==
0
)
out_norm
[
i
]
=
static_cast
<
T
>
(
reduce_result
)
;
}
}
}
}
...
@@ -172,27 +188,29 @@ __global__ void PnormGradient(const T* x, const T* x_norm, const T* y_grad,
...
@@ -172,27 +188,29 @@ __global__ void PnormGradient(const T* x, const T* x_norm, const T* y_grad,
const
float
porder
,
const
int
pre
,
const
float
porder
,
const
int
pre
,
const
int
axis_n
,
const
int
post
,
const
T
eps
,
const
int
axis_n
,
const
int
post
,
const
T
eps
,
T
*
x_grad
)
{
T
*
x_grad
)
{
using
MT
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
// dx = (x/pnorm_broadcast).pow(p-1) * norm_dy.broadcast * sign(x)
// dx = (x/pnorm_broadcast).pow(p-1) * norm_dy.broadcast * sign(x)
int
num
=
pre
*
post
;
int
num
=
pre
*
post
;
auto
porder_grad
=
static_cast
<
T
>
(
porder
-
1.0
f
);
auto
porder_grad
=
static_cast
<
M
T
>
(
porder
-
1.0
f
);
for
(
int
i
=
blockIdx
.
x
;
i
<
num
;
i
+=
gridDim
.
x
)
{
for
(
int
i
=
blockIdx
.
x
;
i
<
num
;
i
+=
gridDim
.
x
)
{
__shared__
T
pnorm_i
;
__shared__
M
T
pnorm_i
;
__shared__
T
yout_i
;
__shared__
M
T
yout_i
;
auto
base
=
(
i
/
post
)
*
post
*
axis_n
+
(
i
%
post
);
auto
base
=
(
i
/
post
)
*
post
*
axis_n
+
(
i
%
post
);
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
pnorm_i
=
x_norm
[
i
]
;
pnorm_i
=
static_cast
<
MT
>
(
x_norm
[
i
])
;
yout_i
=
y_grad
[
i
]
;
yout_i
=
static_cast
<
MT
>
(
y_grad
[
i
])
;
}
}
__syncthreads
();
__syncthreads
();
for
(
int
j
=
threadIdx
.
x
;
j
<
axis_n
;
j
+=
blockDim
.
x
)
{
for
(
int
j
=
threadIdx
.
x
;
j
<
axis_n
;
j
+=
blockDim
.
x
)
{
int
index
=
base
+
j
*
post
;
int
index
=
base
+
j
*
post
;
const
T
x_ij
=
inline_abs
(
x
[
index
]);
const
MT
x_ij
=
static_cast
<
MT
>
(
inline_abs
(
x
[
index
]));
x_grad
[
index
]
=
inline_pow
(
x_ij
,
porder_grad
)
/
x_grad
[
index
]
=
static_cast
<
T
>
(
(
inline_pow
(
pnorm_i
,
porder_grad
)
+
eps
)
*
yout_i
*
inline_pow
(
x_ij
,
porder_grad
)
/
inline_sign
(
x
[
index
]);
(
inline_pow
(
pnorm_i
,
porder_grad
)
+
static_cast
<
MT
>
(
eps
))
*
yout_i
*
static_cast
<
MT
>
(
inline_sign
(
x
[
index
])));
}
}
}
}
}
}
...
@@ -216,7 +234,7 @@ __global__ void InfNormGradient(const T* x, const T* x_norm, const T* y_grad,
...
@@ -216,7 +234,7 @@ __global__ void InfNormGradient(const T* x, const T* x_norm, const T* y_grad,
int
index
=
base
+
j
*
post
;
int
index
=
base
+
j
*
post
;
const
T
x_ij
=
inline_abs
(
x
[
index
]);
const
T
x_ij
=
inline_abs
(
x
[
index
]);
if
(
x_ij
==
pnorm_i
)
{
if
(
x_ij
==
pnorm_i
)
{
x_grad
[
index
]
=
inline_sign
(
x
[
index
]
)
*
yout_i
;
x_grad
[
index
]
=
static_cast
<
T
>
(
inline_sign
(
x
[
index
])
)
*
yout_i
;
}
else
{
}
else
{
x_grad
[
index
]
=
static_cast
<
T
>
(
0
);
x_grad
[
index
]
=
static_cast
<
T
>
(
0
);
}
}
...
@@ -278,7 +296,11 @@ class PnormGradCUDAKernel : public framework::OpKernel<T> {
...
@@ -278,7 +296,11 @@ class PnormGradCUDAKernel : public framework::OpKernel<T> {
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
using
CUDA
=
paddle
::
platform
::
CUDADeviceContext
;
using
CUDA
=
paddle
::
platform
::
CUDADeviceContext
;
REGISTER_OP_CUDA_KERNEL
(
p_norm
,
ops
::
PnormCUDAKernel
<
CUDA
,
float
>
,
REGISTER_OP_CUDA_KERNEL
(
p_norm
,
ops
::
PnormCUDAKernel
<
CUDA
,
paddle
::
platform
::
float16
>
,
ops
::
PnormCUDAKernel
<
CUDA
,
float
>
,
ops
::
PnormCUDAKernel
<
CUDA
,
double
>
);
ops
::
PnormCUDAKernel
<
CUDA
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
p_norm_grad
,
ops
::
PnormGradCUDAKernel
<
CUDA
,
float
>
,
REGISTER_OP_CUDA_KERNEL
(
p_norm_grad
,
ops
::
PnormGradCUDAKernel
<
CUDA
,
paddle
::
platform
::
float16
>
,
ops
::
PnormGradCUDAKernel
<
CUDA
,
float
>
,
ops
::
PnormGradCUDAKernel
<
CUDA
,
double
>
);
ops
::
PnormGradCUDAKernel
<
CUDA
,
double
>
);
python/paddle/nn/functional/norm.py
浏览文件 @
87cc8d48
...
@@ -86,7 +86,8 @@ def normalize(x, p=2, axis=1, epsilon=1e-12, name=None):
...
@@ -86,7 +86,8 @@ def normalize(x, p=2, axis=1, epsilon=1e-12, name=None):
check_type
(
p
,
'p'
,
(
float
,
int
),
'normalize'
)
check_type
(
p
,
'p'
,
(
float
,
int
),
'normalize'
)
check_type
(
axis
,
'axis'
,
(
int
),
'normalize'
)
check_type
(
axis
,
'axis'
,
(
int
),
'normalize'
)
check_variable_and_dtype
(
x
,
'x'
,
[
'float32'
,
'float64'
],
'normalize'
)
check_variable_and_dtype
(
x
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
],
'normalize'
)
if
len
(
x
.
shape
)
==
1
and
axis
!=
0
and
axis
!=
-
1
:
if
len
(
x
.
shape
)
==
1
and
axis
!=
0
and
axis
!=
-
1
:
raise
ValueError
(
raise
ValueError
(
"Axis must be 0 or -1 when x is a 1-D tensor, but received axis = {}"
.
"Axis must be 0 or -1 when x is a 1-D tensor, but received axis = {}"
.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录