Commit b666fd3c (unverified)
Authored Sep 16, 2021 by Guoxia Wang; committed via GitHub on Sep 16, 2021
support l2_normalize float16 (#35776)
* support fp16 dtype
Parent 7babb3d2
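A minimal usage sketch of what this change enables (not part of the commit; it assumes a CUDA build of this branch so the float16 kernel below is registered):

import numpy as np
import paddle
import paddle.fluid as fluid

paddle.disable_static()  # dygraph mode
# float16 input; after this commit the norm op accepts fp16 on the CUDA place
x = paddle.to_tensor(np.random.rand(2, 3).astype("float16"))
out = fluid.layers.l2_normalize(x, axis=-1)
print(out.dtype)  # output stays float16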
Showing 3 changed files with 79 additions and 65 deletions
paddle/fluid/operators/norm_op.cu  +35 −22
python/paddle/fluid/layers/nn.py  +18 −40
python/paddle/fluid/tests/unittests/test_norm_op.py  +26 −3
paddle/fluid/operators/norm_op.cu
@@ -20,11 +20,17 @@ limitations under the License. */
 #include <hipcub/hipcub.hpp>
 namespace cub = hipcub;
 #endif
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/norm_op.h"
+#include "paddle/fluid/platform/bfloat16.h"

 namespace paddle {
 namespace operators {

+__device__ __forceinline__ platform::float16 square_root(platform::float16 x) {
+  return static_cast<platform::float16>(sqrtf(static_cast<float>(x)));
+}
+
 __device__ __forceinline__ float square_root(float x) { return sqrtf(x); }

 __device__ __forceinline__ double square_root(double x) { return sqrt(x); }
@@ -33,28 +39,29 @@ template <typename T, int BlockDim>
 __global__ void Normalize(const T* x, const int pre,
                           const int axis_n,  // dim in axis
                           const int post, const T eps, T* y, T* out_norm) {
-  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
+  using MT = typename details::MPTypeTrait<T>::Type;
+  typedef cub::BlockReduce<MT, BlockDim> BlockReduce;
   __shared__ typename BlockReduce::TempStorage temp_storage;
   int num = pre * post;
   for (int i = blockIdx.x; i < num; i += gridDim.x) {
     int base = (i / post) * post * axis_n + (i % post);
-    T sum = 0.0;
-    __shared__ T norm;
+    MT sum = 0.0;
+    __shared__ MT norm;
     for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
-      const T x_ij = x[base + j * post];
+      const MT x_ij = static_cast<MT>(x[base + j * post]);
       sum += x_ij * x_ij;
     }
-    T reduce_result = BlockReduce(temp_storage).Sum(sum);
+    MT reduce_result = BlockReduce(temp_storage).Sum(sum);
     if (threadIdx.x == 0) {
-      norm = square_root(reduce_result + eps);
-      out_norm[i] = norm;
+      norm = square_root(reduce_result + static_cast<MT>(eps));
+      out_norm[i] = static_cast<T>(norm);
     }
     __syncthreads();
     for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
       const int index = base + j * post;
-      y[index] = x[index] / norm;
+      y[index] = static_cast<T>((static_cast<MT>(x[index]) / norm));
     }
   }
 }
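The forward kernel now follows the usual mixed-precision pattern: elements are loaded as T (possibly float16), but the squared sum and the norm are accumulated in MPTypeTrait<T>::Type (float for float16), and only the final outputs are cast back to T. A NumPy sketch of the same math, for reference only (the function name and the float32 accumulator choice are illustrative, not part of the diff):

import numpy as np

def l2_normalize_ref(x, axis, eps=1e-12):
    # accumulate in float32 when the input is float16, mirroring MPTypeTrait<T>::Type
    mt = np.float32 if x.dtype == np.float16 else x.dtype
    x_mt = x.astype(mt)
    norm = np.sqrt(np.sum(x_mt * x_mt, axis=axis, keepdims=True) + eps)
    y = (x_mt / norm).astype(x.dtype)      # cast the results back to the input dtype
    return y, norm.astype(x.dtype)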
@@ -109,34 +116,36 @@ template <typename T, int BlockDim>
 __global__ void NormalizeGradient(const T* x, const T* x_norm, const T* y_grad,
                                   const int pre, const int axis_n,
                                   const int post, T* x_grad) {
-  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
+  using MT = typename details::MPTypeTrait<T>::Type;
+  typedef cub::BlockReduce<MT, BlockDim> BlockReduce;
   __shared__ typename BlockReduce::TempStorage temp_storage_sum;
   int num = pre * post;
   for (int i = blockIdx.x; i < num; i += gridDim.x) {
-    T sum = 0.0;
-    __shared__ T row_sum;
-    __shared__ T row_sqrt_norm;
-    __shared__ T row_norm;
+    MT sum = 0.0;
+    __shared__ MT row_sum;
+    __shared__ MT row_sqrt_norm;
+    __shared__ MT row_norm;
     auto base = (i / post) * post * axis_n + (i % post);
     for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
       int index = base + j * post;
-      sum += x[index] * y_grad[index];
+      sum += static_cast<MT>(x[index]) * static_cast<MT>(y_grad[index]);
     }
-    T reduce_result = BlockReduce(temp_storage_sum).Sum(sum);
+    MT reduce_result = BlockReduce(temp_storage_sum).Sum(sum);
     if (threadIdx.x == 0) {
       row_sum = reduce_result;
-      row_sqrt_norm = x_norm[i];
+      row_sqrt_norm = static_cast<MT>(x_norm[i]);
       row_norm = row_sqrt_norm * row_sqrt_norm;
     }
     __syncthreads();
     for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
       int index = base + j * post;
-      const T x_ij = x[index];
-      const T dy_ij = y_grad[index];
-      x_grad[index] = (dy_ij - x_ij * row_sum / row_norm) / row_sqrt_norm;
+      const MT x_ij = static_cast<MT>(x[index]);
+      const MT dy_ij = static_cast<MT>(y_grad[index]);
+      x_grad[index] =
+          static_cast<T>((dy_ij - x_ij * row_sum / row_norm) / row_sqrt_norm);
     }
   }
 }
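The backward kernel applies the same idea to the gradient formula x_grad = (dy - x * sum(x*dy) / norm^2) / norm, keeping every intermediate in MT. A NumPy sketch for reference (illustrative names; it assumes x_norm keeps the reduced axis so it broadcasts):

import numpy as np

def l2_normalize_grad_ref(x, x_norm, y_grad, axis):
    # same mixed-precision accumulation as the kernel: float32 when inputs are float16
    mt = np.float32 if x.dtype == np.float16 else x.dtype
    x_mt, dy_mt = x.astype(mt), y_grad.astype(mt)
    row_sqrt_norm = x_norm.astype(mt)                 # per-slice sqrt(sum(x^2) + eps)
    row_norm = row_sqrt_norm * row_sqrt_norm
    row_sum = np.sum(x_mt * dy_mt, axis=axis, keepdims=True)
    grad = (dy_mt - x_mt * row_sum / row_norm) / row_sqrt_norm
    return grad.astype(x.dtype)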
@@ -181,7 +190,11 @@ class NormGradCUDAKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 using CUDA = paddle::platform::CUDADeviceContext;

-REGISTER_OP_CUDA_KERNEL(norm, ops::NormCUDAKernel<CUDA, float>,
+REGISTER_OP_CUDA_KERNEL(norm,
+                        ops::NormCUDAKernel<CUDA, paddle::platform::float16>,
+                        ops::NormCUDAKernel<CUDA, float>,
                         ops::NormCUDAKernel<CUDA, double>);
-REGISTER_OP_CUDA_KERNEL(norm_grad, ops::NormGradCUDAKernel<CUDA, float>,
+REGISTER_OP_CUDA_KERNEL(norm_grad,
+                        ops::NormGradCUDAKernel<CUDA, paddle::platform::float16>,
+                        ops::NormGradCUDAKernel<CUDA, float>,
                         ops::NormGradCUDAKernel<CUDA, double>);
python/paddle/fluid/layers/nn.py
@@ -5041,7 +5041,7 @@ def l2_normalize(x, axis, epsilon=1e-12, name=None):
     slice along dimension `axis`.

     Args:
-        x(Variable|list): The input tensor could be N-D tensor, and the input data type could be float32 or float64.
+        x(Variable|list): The input tensor could be N-D tensor, and the input data type could be float16, float32 or float64.
         axis(int): The axis on which to apply normalization. If `axis < 0`, \
             the dimension to normalization is rank(X) + axis. -1 is the
             last dimension.
@@ -5055,50 +5055,28 @@ def l2_normalize(x, axis, epsilon=1e-12, name=None):
     Examples:

         .. code-block:: python
             :name: code-example1

-            # declarative mode
-            import paddle.fluid as fluid
-            import numpy as np
             import paddle
-            paddle.enable_static()
-            input = fluid.data(name="input", shape=[2,3])
-            output = fluid.layers.l2_normalize(x=input,axis=0)
-            place = fluid.CPUPlace()
-            exe = fluid.Executor(place)
-            exe.run(fluid.default_startup_program())
-            input_data = np.random.rand(2,3).astype("float32")
-            print(input_data)
-            # [[0.5171216  0.12704141 0.56018186]
-            # [0.93251234 0.5382788  0.81709313]]
+            X = paddle.randn(shape=[3, 5], dtype='float64')
+            out = paddle.fluid.layers.l2_normalize(X, axis=-1)
+            print(out.numpy())
-            output_data = exe.run(fluid.default_main_program(),
-                                  feed={"input":input_data},
-                                  fetch_list=[output],
-                                  return_numpy=True)
-            print(output_data)
-            # [array([[0.48496857, 0.22970329, 0.56545246],
-            # [0.8745316 , 0.9732607 , 0.82478094]], dtype=float32)]
-            # imperative mode
-            import paddle.fluid.dygraph as dg
-            with dg.guard(place) as g:
-                input = dg.to_variable(input_data)
-                output = fluid.layers.l2_normalize(x=input, axis=-1)
-                print(output.numpy())
-                # [[0.66907585 0.16437206 0.7247892 ]
-                # [0.6899054  0.3982376  0.6045142 ]]
+            # [[ 0.21558504  0.56360189  0.47466096  0.46269539 -0.44326736]
+            #  [-0.70602414 -0.52745777  0.37771788 -0.2804768  -0.04449922]
+            #  [-0.33972208 -0.43014923  0.31772556  0.76617881 -0.10761525]]
     """
     if len(x.shape) == 1:
         axis = 0
-    check_variable_and_dtype(x, "X", ("float32", "float64"), "norm")
     if in_dygraph_mode():
         _, out = _C_ops.norm(x, 'axis', 1
                              if axis is None else axis, 'epsilon', epsilon)
         return out

+    check_variable_and_dtype(x, "X", ("float16", "float32", "float64"), "norm")
     helper = LayerHelper("l2_normalize", **locals())
     out = helper.create_variable_for_type_inference(dtype=x.dtype)
python/paddle/fluid/tests/unittests/test_norm_op.py
@@ -33,23 +33,27 @@ class TestNormOp(OpTest):
     def setUp(self):
         self.op_type = "norm"
         self.init_test_case()
-        x = np.random.random(self.shape).astype("float64")
+        self.init_dtype()
+        x = np.random.random(self.shape).astype(self.dtype)
         y, norm = l2_norm(x, self.axis, self.epsilon)
         self.inputs = {'X': x}
         self.attrs = {'epsilon': self.epsilon, 'axis': self.axis}
         self.outputs = {'Out': y, 'Norm': norm}

     def test_check_output(self):
-        self.check_output()
+        self.check_output(atol=1e-5)

     def test_check_grad(self):
-        self.check_grad(['X'], 'Out')
+        self.check_grad(['X'], 'Out', max_relative_error=0.008)

     def init_test_case(self):
         self.shape = [2, 3, 4, 5]
         self.axis = 1
         self.epsilon = 1e-8

+    def init_dtype(self):
+        self.dtype = "float64"
+

 class TestNormOp2(TestNormOp):
     def init_test_case(self):
@@ -89,6 +93,25 @@ class TestNormOp5(TestNormOp):
     pass


+class TestNormOp6(TestNormOp):
+    def init_dtype(self):
+        self.dtype = "float32"
+
+
+@unittest.skipIf(not fluid.core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestNormOp7(TestNormOp):
+    def init_dtype(self):
+        self.dtype = "float16"
+
+    def test_check_output(self):
+        self.check_output_with_place(fluid.core.CUDAPlace(0), atol=5e-2)
+
+    def test_check_grad(self):
+        self.check_grad_with_place(
+            fluid.core.CUDAPlace(0), ['X'], 'Out', max_relative_error=0.05)
+
+
 @skip_check_grad_ci(reason="skip check grad for test mode.")
 class TestNormTestOp(OpTest):
     def setUp(self):
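The looser tolerances in the new float16 test (atol=5e-2, max_relative_error=0.05) reflect half precision's roughly three decimal digits. A rough NumPy illustration of that error scale, not part of the test suite:

import numpy as np

x64 = np.random.random((2, 3, 4, 5)).astype("float64")
x16 = x64.astype("float16").astype("float64")   # simulate float16 storage of the input
norm = lambda a: a / np.sqrt(np.sum(a * a, axis=1, keepdims=True) + 1e-8)
err = np.max(np.abs(norm(x16) - norm(x64)))
print(err)  # typically on the order of 1e-3, far above float64 round-off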