Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
b666fd3c
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b666fd3c
编写于
9月 16, 2021
作者:
G
Guoxia Wang
提交者:
GitHub
9月 16, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support l2_normalize float16 (#35776)
* support fp16 dtype
上级
7babb3d2
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
79 addition
and
65 deletion
+79
-65
paddle/fluid/operators/norm_op.cu
paddle/fluid/operators/norm_op.cu
+35
-22
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+18
-40
python/paddle/fluid/tests/unittests/test_norm_op.py
python/paddle/fluid/tests/unittests/test_norm_op.py
+26
-3
未找到文件。
paddle/fluid/operators/norm_op.cu
浏览文件 @
b666fd3c
...
@@ -20,11 +20,17 @@ limitations under the License. */
...
@@ -20,11 +20,17 @@ limitations under the License. */
#include <hipcub/hipcub.hpp>
#include <hipcub/hipcub.hpp>
namespace
cub
=
hipcub
;
namespace
cub
=
hipcub
;
#endif
#endif
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/norm_op.h"
#include "paddle/fluid/operators/norm_op.h"
#include "paddle/fluid/platform/bfloat16.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
__device__
__forceinline__
platform
::
float16
square_root
(
platform
::
float16
x
)
{
return
static_cast
<
platform
::
float16
>
(
sqrtf
(
static_cast
<
float
>
(
x
)));
}
__device__
__forceinline__
float
square_root
(
float
x
)
{
return
sqrtf
(
x
);
}
__device__
__forceinline__
float
square_root
(
float
x
)
{
return
sqrtf
(
x
);
}
__device__
__forceinline__
double
square_root
(
double
x
)
{
return
sqrt
(
x
);
}
__device__
__forceinline__
double
square_root
(
double
x
)
{
return
sqrt
(
x
);
}
...
@@ -33,28 +39,29 @@ template <typename T, int BlockDim>
...
@@ -33,28 +39,29 @@ template <typename T, int BlockDim>
__global__
void
Normalize
(
const
T
*
x
,
const
int
pre
,
__global__
void
Normalize
(
const
T
*
x
,
const
int
pre
,
const
int
axis_n
,
// dim in axis
const
int
axis_n
,
// dim in axis
const
int
post
,
const
T
eps
,
T
*
y
,
T
*
out_norm
)
{
const
int
post
,
const
T
eps
,
T
*
y
,
T
*
out_norm
)
{
typedef
cub
::
BlockReduce
<
T
,
BlockDim
>
BlockReduce
;
using
MT
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
typedef
cub
::
BlockReduce
<
MT
,
BlockDim
>
BlockReduce
;
__shared__
typename
BlockReduce
::
TempStorage
temp_storage
;
__shared__
typename
BlockReduce
::
TempStorage
temp_storage
;
int
num
=
pre
*
post
;
int
num
=
pre
*
post
;
for
(
int
i
=
blockIdx
.
x
;
i
<
num
;
i
+=
gridDim
.
x
)
{
for
(
int
i
=
blockIdx
.
x
;
i
<
num
;
i
+=
gridDim
.
x
)
{
int
base
=
(
i
/
post
)
*
post
*
axis_n
+
(
i
%
post
);
int
base
=
(
i
/
post
)
*
post
*
axis_n
+
(
i
%
post
);
T
sum
=
0.0
;
M
T
sum
=
0.0
;
__shared__
T
norm
;
__shared__
M
T
norm
;
for
(
int
j
=
threadIdx
.
x
;
j
<
axis_n
;
j
+=
blockDim
.
x
)
{
for
(
int
j
=
threadIdx
.
x
;
j
<
axis_n
;
j
+=
blockDim
.
x
)
{
const
T
x_ij
=
x
[
base
+
j
*
post
]
;
const
MT
x_ij
=
static_cast
<
MT
>
(
x
[
base
+
j
*
post
])
;
sum
+=
x_ij
*
x_ij
;
sum
+=
x_ij
*
x_ij
;
}
}
T
reduce_result
=
BlockReduce
(
temp_storage
).
Sum
(
sum
);
M
T
reduce_result
=
BlockReduce
(
temp_storage
).
Sum
(
sum
);
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
norm
=
square_root
(
reduce_result
+
eps
);
norm
=
square_root
(
reduce_result
+
static_cast
<
MT
>
(
eps
)
);
out_norm
[
i
]
=
norm
;
out_norm
[
i
]
=
static_cast
<
T
>
(
norm
)
;
}
}
__syncthreads
();
__syncthreads
();
for
(
int
j
=
threadIdx
.
x
;
j
<
axis_n
;
j
+=
blockDim
.
x
)
{
for
(
int
j
=
threadIdx
.
x
;
j
<
axis_n
;
j
+=
blockDim
.
x
)
{
const
int
index
=
base
+
j
*
post
;
const
int
index
=
base
+
j
*
post
;
y
[
index
]
=
x
[
index
]
/
norm
;
y
[
index
]
=
static_cast
<
T
>
((
static_cast
<
MT
>
(
x
[
index
])
/
norm
))
;
}
}
}
}
}
}
...
@@ -109,34 +116,36 @@ template <typename T, int BlockDim>
...
@@ -109,34 +116,36 @@ template <typename T, int BlockDim>
__global__
void
NormalizeGradient
(
const
T
*
x
,
const
T
*
x_norm
,
const
T
*
y_grad
,
__global__
void
NormalizeGradient
(
const
T
*
x
,
const
T
*
x_norm
,
const
T
*
y_grad
,
const
int
pre
,
const
int
axis_n
,
const
int
pre
,
const
int
axis_n
,
const
int
post
,
T
*
x_grad
)
{
const
int
post
,
T
*
x_grad
)
{
typedef
cub
::
BlockReduce
<
T
,
BlockDim
>
BlockReduce
;
using
MT
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
typedef
cub
::
BlockReduce
<
MT
,
BlockDim
>
BlockReduce
;
__shared__
typename
BlockReduce
::
TempStorage
temp_storage_sum
;
__shared__
typename
BlockReduce
::
TempStorage
temp_storage_sum
;
int
num
=
pre
*
post
;
int
num
=
pre
*
post
;
for
(
int
i
=
blockIdx
.
x
;
i
<
num
;
i
+=
gridDim
.
x
)
{
for
(
int
i
=
blockIdx
.
x
;
i
<
num
;
i
+=
gridDim
.
x
)
{
T
sum
=
0.0
;
M
T
sum
=
0.0
;
__shared__
T
row_sum
;
__shared__
M
T
row_sum
;
__shared__
T
row_sqrt_norm
;
__shared__
M
T
row_sqrt_norm
;
__shared__
T
row_norm
;
__shared__
M
T
row_norm
;
auto
base
=
(
i
/
post
)
*
post
*
axis_n
+
(
i
%
post
);
auto
base
=
(
i
/
post
)
*
post
*
axis_n
+
(
i
%
post
);
for
(
int
j
=
threadIdx
.
x
;
j
<
axis_n
;
j
+=
blockDim
.
x
)
{
for
(
int
j
=
threadIdx
.
x
;
j
<
axis_n
;
j
+=
blockDim
.
x
)
{
int
index
=
base
+
j
*
post
;
int
index
=
base
+
j
*
post
;
sum
+=
x
[
index
]
*
y_grad
[
index
]
;
sum
+=
static_cast
<
MT
>
(
x
[
index
])
*
static_cast
<
MT
>
(
y_grad
[
index
])
;
}
}
T
reduce_result
=
BlockReduce
(
temp_storage_sum
).
Sum
(
sum
);
M
T
reduce_result
=
BlockReduce
(
temp_storage_sum
).
Sum
(
sum
);
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
row_sum
=
reduce_result
;
row_sum
=
reduce_result
;
row_sqrt_norm
=
x_norm
[
i
]
;
row_sqrt_norm
=
static_cast
<
MT
>
(
x_norm
[
i
])
;
row_norm
=
row_sqrt_norm
*
row_sqrt_norm
;
row_norm
=
row_sqrt_norm
*
row_sqrt_norm
;
}
}
__syncthreads
();
__syncthreads
();
for
(
int
j
=
threadIdx
.
x
;
j
<
axis_n
;
j
+=
blockDim
.
x
)
{
for
(
int
j
=
threadIdx
.
x
;
j
<
axis_n
;
j
+=
blockDim
.
x
)
{
int
index
=
base
+
j
*
post
;
int
index
=
base
+
j
*
post
;
const
T
x_ij
=
x
[
index
];
const
MT
x_ij
=
static_cast
<
MT
>
(
x
[
index
]);
const
T
dy_ij
=
y_grad
[
index
];
const
MT
dy_ij
=
static_cast
<
MT
>
(
y_grad
[
index
]);
x_grad
[
index
]
=
(
dy_ij
-
x_ij
*
row_sum
/
row_norm
)
/
row_sqrt_norm
;
x_grad
[
index
]
=
static_cast
<
T
>
((
dy_ij
-
x_ij
*
row_sum
/
row_norm
)
/
row_sqrt_norm
);
}
}
}
}
}
}
...
@@ -181,7 +190,11 @@ class NormGradCUDAKernel : public framework::OpKernel<T> {
...
@@ -181,7 +190,11 @@ class NormGradCUDAKernel : public framework::OpKernel<T> {
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
using
CUDA
=
paddle
::
platform
::
CUDADeviceContext
;
using
CUDA
=
paddle
::
platform
::
CUDADeviceContext
;
REGISTER_OP_CUDA_KERNEL
(
norm
,
ops
::
NormCUDAKernel
<
CUDA
,
float
>
,
REGISTER_OP_CUDA_KERNEL
(
norm
,
ops
::
NormCUDAKernel
<
CUDA
,
paddle
::
platform
::
float16
>
,
ops
::
NormCUDAKernel
<
CUDA
,
float
>
,
ops
::
NormCUDAKernel
<
CUDA
,
double
>
);
ops
::
NormCUDAKernel
<
CUDA
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
norm_grad
,
ops
::
NormGradCUDAKernel
<
CUDA
,
float
>
,
REGISTER_OP_CUDA_KERNEL
(
norm_grad
,
ops
::
NormGradCUDAKernel
<
CUDA
,
paddle
::
platform
::
float16
>
,
ops
::
NormGradCUDAKernel
<
CUDA
,
float
>
,
ops
::
NormGradCUDAKernel
<
CUDA
,
double
>
);
ops
::
NormGradCUDAKernel
<
CUDA
,
double
>
);
python/paddle/fluid/layers/nn.py
浏览文件 @
b666fd3c
...
@@ -5041,7 +5041,7 @@ def l2_normalize(x, axis, epsilon=1e-12, name=None):
...
@@ -5041,7 +5041,7 @@ def l2_normalize(x, axis, epsilon=1e-12, name=None):
slice along dimension `axis`.
slice along dimension `axis`.
Args:
Args:
x(Variable|list): The input tensor could be N-D tensor, and the input data type could be float32 or float64.
x(Variable|list): The input tensor could be N-D tensor, and the input data type could be float
16, float
32 or float64.
axis(int): The axis on which to apply normalization. If `axis < 0`, \
axis(int): The axis on which to apply normalization. If `axis < 0`, \
the dimension to normalization is rank(X) + axis. -1 is the
the dimension to normalization is rank(X) + axis. -1 is the
last dimension.
last dimension.
...
@@ -5055,50 +5055,28 @@ def l2_normalize(x, axis, epsilon=1e-12, name=None):
...
@@ -5055,50 +5055,28 @@ def l2_normalize(x, axis, epsilon=1e-12, name=None):
Examples:
Examples:
.. code-block:: python
.. code-block:: python
:name: code-example1
# declarative mode
import paddle.fluid as fluid
import numpy as np
import paddle
import paddle
paddle.enable_static()
input = fluid.data(name="input", shape=[2,3])
output = fluid.layers.l2_normalize(x=input,axis=0)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
input_data = np.random.rand(2,3).astype("float32")
print(input_data)
# [[0.5171216 0.12704141 0.56018186]
X = paddle.randn(shape=[3, 5], dtype='float64')
# [0.93251234 0.5382788 0.81709313]]
out = paddle.fluid.layers.l2_normalize(X, axis=-1)
print(out.numpy())
output_data = exe.run(fluid.default_main_program(),
# [[ 0.21558504 0.56360189 0.47466096 0.46269539 -0.44326736]
feed={"input":input_data},
# [-0.70602414 -0.52745777 0.37771788 -0.2804768 -0.04449922]
fetch_list=[output],
# [-0.33972208 -0.43014923 0.31772556 0.76617881 -0.10761525]]
return_numpy=True)
print(output_data)
# [array([[0.48496857, 0.22970329, 0.56545246],
# [0.8745316 , 0.9732607 , 0.82478094]], dtype=float32)]
# imperative mode
import paddle.fluid.dygraph as dg
with dg.guard(place) as g:
input = dg.to_variable(input_data)
output = fluid.layers.l2_normalize(x=input, axis=-1)
print(output.numpy())
# [[0.66907585 0.16437206 0.7247892 ]
# [0.6899054 0.3982376 0.6045142 ]]
"""
"""
if len(x.shape) == 1:
if len(x.shape) == 1:
axis = 0
axis = 0
check_variable_and_dtype(x, "X", ("float32", "float64"), "norm")
if in_dygraph_mode():
_, out = _C_ops.norm(x, 'axis', 1
if axis is None else axis, 'epsilon', epsilon)
return out
check_variable_and_dtype(x, "X", ("float16", "float32", "float64"), "norm")
helper = LayerHelper("l2_normalize", **locals())
helper = LayerHelper("l2_normalize", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
out = helper.create_variable_for_type_inference(dtype=x.dtype)
...
...
python/paddle/fluid/tests/unittests/test_norm_op.py
浏览文件 @
b666fd3c
...
@@ -33,23 +33,27 @@ class TestNormOp(OpTest):
...
@@ -33,23 +33,27 @@ class TestNormOp(OpTest):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
op_type
=
"norm"
self
.
op_type
=
"norm"
self
.
init_test_case
()
self
.
init_test_case
()
x
=
np
.
random
.
random
(
self
.
shape
).
astype
(
"float64"
)
self
.
init_dtype
()
x
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
y
,
norm
=
l2_norm
(
x
,
self
.
axis
,
self
.
epsilon
)
y
,
norm
=
l2_norm
(
x
,
self
.
axis
,
self
.
epsilon
)
self
.
inputs
=
{
'X'
:
x
}
self
.
inputs
=
{
'X'
:
x
}
self
.
attrs
=
{
'epsilon'
:
self
.
epsilon
,
'axis'
:
self
.
axis
}
self
.
attrs
=
{
'epsilon'
:
self
.
epsilon
,
'axis'
:
self
.
axis
}
self
.
outputs
=
{
'Out'
:
y
,
'Norm'
:
norm
}
self
.
outputs
=
{
'Out'
:
y
,
'Norm'
:
norm
}
def
test_check_output
(
self
):
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
(
atol
=
1e-5
)
def
test_check_grad
(
self
):
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'Out'
)
self
.
check_grad
([
'X'
],
'Out'
,
max_relative_error
=
0.008
)
def
init_test_case
(
self
):
def
init_test_case
(
self
):
self
.
shape
=
[
2
,
3
,
4
,
5
]
self
.
shape
=
[
2
,
3
,
4
,
5
]
self
.
axis
=
1
self
.
axis
=
1
self
.
epsilon
=
1e-8
self
.
epsilon
=
1e-8
def
init_dtype
(
self
):
self
.
dtype
=
"float64"
class
TestNormOp2
(
TestNormOp
):
class
TestNormOp2
(
TestNormOp
):
def
init_test_case
(
self
):
def
init_test_case
(
self
):
...
@@ -89,6 +93,25 @@ class TestNormOp5(TestNormOp):
...
@@ -89,6 +93,25 @@ class TestNormOp5(TestNormOp):
pass
pass
class
TestNormOp6
(
TestNormOp
):
def
init_dtype
(
self
):
self
.
dtype
=
"float32"
@
unittest
.
skipIf
(
not
fluid
.
core
.
is_compiled_with_cuda
(),
"core is not compiled with CUDA"
)
class
TestNormOp7
(
TestNormOp
):
def
init_dtype
(
self
):
self
.
dtype
=
"float16"
def
test_check_output
(
self
):
self
.
check_output_with_place
(
fluid
.
core
.
CUDAPlace
(
0
),
atol
=
5e-2
)
def
test_check_grad
(
self
):
self
.
check_grad_with_place
(
fluid
.
core
.
CUDAPlace
(
0
),
[
'X'
],
'Out'
,
max_relative_error
=
0.05
)
@
skip_check_grad_ci
(
reason
=
"skip check grad for test mode."
)
@
skip_check_grad_ci
(
reason
=
"skip check grad for test mode."
)
class
TestNormTestOp
(
OpTest
):
class
TestNormTestOp
(
OpTest
):
def
setUp
(
self
):
def
setUp
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录