Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
c3055d23
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c3055d23
编写于
4月 18, 2023
作者:
C
chenxujun
提交者:
GitHub
4月 18, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
【Hackathon No.60】prelu, clip_by_norm, multi_dot 算子FP16/BF16单测完善 (#52666)
* Add prelu, clip_by_norm, multi_dot tests * Fix code * Fix code
上级
534efcb6
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
187 addition
and
39 deletion
+187
-39
paddle/fluid/operators/math/prelu.cu
paddle/fluid/operators/math/prelu.cu
+4
-0
paddle/phi/kernels/gpu/clip_by_norm_kernel.cu
paddle/phi/kernels/gpu/clip_by_norm_kernel.cu
+10
-15
paddle/phi/kernels/gpu/multi_dot_grad_kernel.cu
paddle/phi/kernels/gpu/multi_dot_grad_kernel.cu
+3
-4
paddle/phi/kernels/gpu/multi_dot_kernel.cu
paddle/phi/kernels/gpu/multi_dot_kernel.cu
+9
-5
paddle/phi/kernels/gpu/prelu_grad_kernel.cu
paddle/phi/kernels/gpu/prelu_grad_kernel.cu
+1
-0
paddle/phi/kernels/gpu/prelu_kernel.cu
paddle/phi/kernels/gpu/prelu_kernel.cu
+1
-0
python/paddle/fluid/tests/unittests/test_clip_by_norm_op.py
python/paddle/fluid/tests/unittests/test_clip_by_norm_op.py
+43
-1
python/paddle/fluid/tests/unittests/test_multi_dot_op.py
python/paddle/fluid/tests/unittests/test_multi_dot_op.py
+49
-1
python/paddle/fluid/tests/unittests/test_prelu_op.py
python/paddle/fluid/tests/unittests/test_prelu_op.py
+58
-9
python/paddle/nn/clip.py
python/paddle/nn/clip.py
+3
-1
python/paddle/nn/functional/activation.py
python/paddle/nn/functional/activation.py
+5
-2
python/paddle/tensor/linalg.py
python/paddle/tensor/linalg.py
+1
-1
未找到文件。
paddle/fluid/operators/math/prelu.cu
浏览文件 @
c3055d23
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/prelu.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/float16.h"
namespace
paddle
{
...
...
@@ -135,14 +136,17 @@ void PreluScalarDirectCUDAFunctor<T>::operator()(gpuStream_t stream,
template
class
PreluChannelWiseDirectCUDAFunctor
<
float
>;
template
class
PreluChannelWiseDirectCUDAFunctor
<
platform
::
float16
>;
template
class
PreluChannelWiseDirectCUDAFunctor
<
platform
::
bfloat16
>;
template
class
PreluChannelWiseDirectCUDAFunctor
<
double
>;
template
class
PreluElementWiseDirectCUDAFunctor
<
float
>;
template
class
PreluElementWiseDirectCUDAFunctor
<
platform
::
float16
>;
template
class
PreluElementWiseDirectCUDAFunctor
<
platform
::
bfloat16
>;
template
class
PreluElementWiseDirectCUDAFunctor
<
double
>;
template
class
PreluScalarDirectCUDAFunctor
<
float
>;
template
class
PreluScalarDirectCUDAFunctor
<
platform
::
float16
>;
template
class
PreluScalarDirectCUDAFunctor
<
platform
::
bfloat16
>;
template
class
PreluScalarDirectCUDAFunctor
<
double
>;
}
// namespace math
...
...
paddle/phi/kernels/gpu/clip_by_norm_kernel.cu
浏览文件 @
c3055d23
...
...
@@ -17,7 +17,7 @@
#include <typeinfo>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/
float16
.h"
#include "paddle/phi/common/
amp_type_traits
.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
...
...
@@ -34,7 +34,7 @@ void ClipByNormKernel(const Context& dev_ctx,
return
ClipByNormFunctor
<
float
,
Context
>
(
dev_ctx
,
in
,
max_norm
,
output
);
}
auto
input
=
&
in
;
dev_ctx
.
template
Alloc
<
dtype
::
float16
>(
output
);
dev_ctx
.
template
Alloc
<
T
>(
output
);
PADDLE_ENFORCE_NOT_NULL
(
input
,
phi
::
errors
::
InvalidArgument
(
...
...
@@ -49,20 +49,14 @@ void ClipByNormKernel(const Context& dev_ctx,
auto
*
tmp
=
&
tmp_tensor
;
tmp
->
Resize
({
1
});
dev_ctx
.
template
Alloc
<
float
>(
tmp
);
phi
::
funcs
::
ReduceKernel
<
dtype
::
float16
,
float
,
kps
::
AddFunctor
,
kps
::
SquareFunctor
<
dtype
::
float16
,
float
>>
(
dev_ctx
,
*
input
,
tmp
,
kps
::
SquareFunctor
<
dtype
::
float16
,
float
>
(),
reduce_dims
);
phi
::
funcs
::
ReduceKernel
<
T
,
float
,
kps
::
AddFunctor
,
kps
::
SquareFunctor
<
T
,
float
>>
(
dev_ctx
,
*
input
,
tmp
,
kps
::
SquareFunctor
<
T
,
float
>
(),
reduce_dims
);
auto
tmp_eigen
=
phi
::
EigenVector
<
float
>::
Flatten
(
*
tmp
);
auto
x_norm
=
tmp_eigen
.
sqrt
();
auto
x
=
phi
::
EigenVector
<
dtype
::
float16
>::
Flatten
(
*
input
);
auto
out
=
phi
::
EigenVector
<
dtype
::
float16
>::
Flatten
(
*
output
);
auto
x
=
phi
::
EigenVector
<
T
>::
Flatten
(
*
input
);
auto
out
=
phi
::
EigenVector
<
T
>::
Flatten
(
*
output
);
auto
*
place
=
dev_ctx
.
eigen_device
();
auto
temp
=
(
x_norm
<=
max_norm
).
template
cast
<
float
>();
...
...
@@ -72,7 +66,7 @@ void ClipByNormKernel(const Context& dev_ctx,
auto
scaling
=
(
temp
+
(
static_cast
<
float
>
(
1
)
-
temp
)
*
max_norm
/
(
x_norm
+
epsilon
))
.
template
cast
<
dtype
::
float16
>();
.
template
cast
<
T
>();
Eigen
::
array
<
int
,
1
>
one_dim
{{
1
}};
Eigen
::
DSizes
<
int
,
1
>
m_dsize
(
input
->
numel
());
...
...
@@ -86,4 +80,5 @@ PD_REGISTER_KERNEL(clip_by_norm,
ALL_LAYOUT
,
phi
::
ClipByNormKernel
,
float
,
phi
::
dtype
::
float16
)
{}
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{}
paddle/phi/kernels/gpu/multi_dot_grad_kernel.cu
浏览文件 @
c3055d23
...
...
@@ -15,16 +15,15 @@ limitations under the License. */
#include "paddle/phi/kernels/multi_dot_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/
float16
.h"
#include "paddle/phi/common/
amp_type_traits
.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/multi_dot_kernel_impl.h"
using
float16
=
phi
::
dtype
::
float16
;
PD_REGISTER_KERNEL
(
multi_dot_grad
,
GPU
,
ALL_LAYOUT
,
phi
::
MultiDotGradKernel
,
float
,
double
,
float16
)
{}
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{}
paddle/phi/kernels/gpu/multi_dot_kernel.cu
浏览文件 @
c3055d23
...
...
@@ -15,11 +15,15 @@ limitations under the License. */
#include "paddle/phi/kernels/multi_dot_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/
float16
.h"
#include "paddle/phi/common/
amp_type_traits
.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/multi_dot_kernel_impl.h"
using
float16
=
phi
::
dtype
::
float16
;
PD_REGISTER_KERNEL
(
multi_dot
,
GPU
,
ALL_LAYOUT
,
phi
::
MultiDotKernel
,
float
,
double
,
float16
)
{}
PD_REGISTER_KERNEL
(
multi_dot
,
GPU
,
ALL_LAYOUT
,
phi
::
MultiDotKernel
,
float
,
double
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{}
paddle/phi/kernels/gpu/prelu_grad_kernel.cu
浏览文件 @
c3055d23
...
...
@@ -189,4 +189,5 @@ PD_REGISTER_KERNEL(prelu_grad,
phi
::
PReluGradKernel
,
float
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
,
double
)
{}
paddle/phi/kernels/gpu/prelu_kernel.cu
浏览文件 @
c3055d23
...
...
@@ -79,4 +79,5 @@ PD_REGISTER_KERNEL(prelu,
phi
::
PReluKernel
,
float
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
,
double
)
{}
python/paddle/fluid/tests/unittests/test_clip_by_norm_op.py
浏览文件 @
c3055d23
...
...
@@ -15,7 +15,7 @@
import
unittest
import
numpy
as
np
from
eager_op_test
import
OpTest
from
eager_op_test
import
OpTest
,
convert_float_to_uint16
from
op
import
Operator
import
paddle
...
...
@@ -102,6 +102,48 @@ class TestClipByNormOpFp16Case3(TestClipByNormOpFp16):
self
.
max_norm
=
1.0
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
()
or
not
core
.
is_bfloat16_supported
(
core
.
CUDAPlace
(
0
)),
"core is not compiled with CUDA or not support bfloat16"
,
)
class
TestClipByNormBF16Op
(
OpTest
):
def
setUp
(
self
):
self
.
max_relative_error
=
0.006
self
.
python_api
=
clip
.
clip_by_norm
self
.
init_dtype
()
self
.
initTestCase
()
input
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
np_dtype
)
input
[
np
.
abs
(
input
)
<
self
.
max_relative_error
]
=
0.5
self
.
op_type
=
"clip_by_norm"
self
.
inputs
=
{
'X'
:
input
,
}
self
.
attrs
=
{}
self
.
attrs
[
'max_norm'
]
=
self
.
max_norm
norm
=
np
.
sqrt
(
np
.
sum
(
np
.
square
(
input
)))
if
norm
>
self
.
max_norm
:
output
=
self
.
max_norm
*
input
/
norm
else
:
output
=
input
self
.
outputs
=
{
'Out'
:
output
}
self
.
inputs
[
'X'
]
=
convert_float_to_uint16
(
self
.
inputs
[
'X'
])
self
.
outputs
[
'Out'
]
=
convert_float_to_uint16
(
self
.
outputs
[
'Out'
])
self
.
place
=
core
.
CUDAPlace
(
0
)
def
test_check_output
(
self
):
self
.
check_output_with_place
(
self
.
place
)
def
initTestCase
(
self
):
self
.
shape
=
(
100
,)
self
.
max_norm
=
1.0
def
init_dtype
(
self
):
self
.
dtype
=
np
.
uint16
self
.
np_dtype
=
np
.
float32
class
TestClipByNormOpWithSelectedRows
(
unittest
.
TestCase
):
def
check_with_place
(
self
,
place
):
self
.
config_test_case
()
...
...
python/paddle/fluid/tests/unittests/test_multi_dot_op.py
浏览文件 @
c3055d23
...
...
@@ -15,10 +15,11 @@
import
unittest
import
numpy
as
np
from
eager_op_test
import
OpTest
from
eager_op_test
import
OpTest
,
convert_float_to_uint16
from
numpy.linalg
import
multi_dot
import
paddle
from
paddle.fluid
import
core
paddle
.
enable_static
()
...
...
@@ -49,6 +50,53 @@ class TestMultiDotOp(OpTest):
self
.
check_grad
([
'x1'
],
'Out'
)
class
TestMultiDotFP16Op
(
TestMultiDotOp
):
def
get_dtype
(
self
):
return
"float16"
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
()
or
not
core
.
is_bfloat16_supported
(
core
.
CUDAPlace
(
0
)),
"core is not compiled with CUDA or not support bfloat16"
,
)
class
TestMultiDotBF16Op
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"multi_dot"
self
.
python_api
=
paddle
.
linalg
.
multi_dot
self
.
dtype
=
self
.
get_dtype
()
self
.
get_inputs_and_outputs
()
self
.
place
=
core
.
CUDAPlace
(
0
)
def
get_dtype
(
self
):
self
.
np_dtype
=
"float32"
return
np
.
uint16
def
get_inputs_and_outputs
(
self
):
self
.
A
=
np
.
random
.
random
((
2
,
8
)).
astype
(
self
.
np_dtype
)
self
.
B
=
np
.
random
.
random
((
8
,
4
)).
astype
(
self
.
np_dtype
)
self
.
inputs
=
{
'X'
:
[
(
'x0'
,
convert_float_to_uint16
(
self
.
A
)),
(
'x1'
,
convert_float_to_uint16
(
self
.
B
)),
]
}
self
.
outputs
=
{
'Out'
:
convert_float_to_uint16
(
multi_dot
([
self
.
A
,
self
.
B
]))
}
def
test_check_output
(
self
):
self
.
check_output_with_place
(
self
.
place
)
def
test_check_grad
(
self
):
self
.
check_grad_with_place
(
self
.
place
,
[
'x0'
],
'Out'
,
numeric_grad_delta
=
0.01
)
self
.
check_grad_with_place
(
self
.
place
,
[
'x1'
],
'Out'
,
numeric_grad_delta
=
0.01
)
# (A*B)*C
class
TestMultiDotOp3Mat
(
TestMultiDotOp
):
def
get_inputs_and_outputs
(
self
):
...
...
python/paddle/fluid/tests/unittests/test_prelu_op.py
浏览文件 @
c3055d23
...
...
@@ -15,7 +15,7 @@
import
unittest
import
numpy
as
np
from
eager_op_test
import
OpTest
,
skip_check_grad_ci
from
eager_op_test
import
OpTest
,
convert_float_to_uint16
,
skip_check_grad_ci
import
paddle
import
paddle.nn.functional
as
F
...
...
@@ -174,7 +174,11 @@ class PReluTest(OpTest):
self
.
op_type
=
"prelu"
self
.
python_api
=
prelu_api_wrapper
x_np
=
np
.
random
.
uniform
(
-
1
,
1
,
self
.
x_shape
).
astype
(
self
.
dtype
)
if
self
.
dtype
==
np
.
uint16
:
as_type
=
self
.
np_dtype
else
:
as_type
=
self
.
dtype
x_np
=
np
.
random
.
uniform
(
-
1
,
1
,
self
.
x_shape
).
astype
(
as_type
)
# Since zero point in prelu is not differentiable, avoid randomize
# zero.
x_np
[
np
.
abs
(
x_np
)
<
0.005
]
=
0.02
...
...
@@ -190,7 +194,7 @@ class PReluTest(OpTest):
alpha_np
=
np
.
random
.
uniform
(
-
1
,
-
0.5
,
[
1
,
1
,
1
,
self
.
x_shape
[
-
1
]])
else
:
alpha_np
=
np
.
random
.
uniform
(
-
1
,
-
0.5
,
[
1
]
+
self
.
x_shape
[
1
:])
alpha_np
=
alpha_np
.
astype
(
self
.
d
type
)
alpha_np
=
alpha_np
.
astype
(
as_
type
)
self
.
inputs
=
{
'X'
:
x_np
,
'Alpha'
:
alpha_np
}
...
...
@@ -393,18 +397,48 @@ def create_test_fp16_class(
def
test_check_grad
(
self
):
place
=
core
.
CUDAPlace
(
0
)
if
core
.
is_float16_supported
(
place
)
and
check_grad
:
self
.
check_grad_with_place
(
place
,
[
'X'
,
'Alpha'
],
'Out'
,
max_relative_error
=
max_relative_error
,
)
# Use the default max_relative_error, not use max_relative_error
self
.
check_grad_with_place
(
place
,
[
'X'
,
'Alpha'
],
'Out'
)
cls_name
=
"{}_{}"
.
format
(
parent
.
__name__
,
"Fp16Op"
)
TestPReluFp16Case
.
__name__
=
cls_name
globals
()[
cls_name
]
=
TestPReluFp16Case
def
create_test_bf16_class
(
parent
,
check_grad
=
True
,
atol
=
1e-3
,
max_relative_error
=
0.05
):
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
()
or
not
core
.
is_bfloat16_supported
(
core
.
CUDAPlace
(
0
)),
"core is not complied with CUDA and not support the bfloat16"
,
)
class
TestPReluBF16Op
(
parent
):
def
setUp
(
self
):
super
().
setUp
()
self
.
inputs
[
'X'
]
=
convert_float_to_uint16
(
self
.
inputs
[
'X'
])
self
.
inputs
[
'Alpha'
]
=
convert_float_to_uint16
(
self
.
inputs
[
'Alpha'
])
self
.
outputs
[
'Out'
]
=
convert_float_to_uint16
(
self
.
outputs
[
'Out'
])
def
init_dtype
(
self
):
self
.
dtype
=
np
.
uint16
self
.
np_dtype
=
np
.
float32
def
test_check_output
(
self
):
place
=
core
.
CUDAPlace
(
0
)
self
.
check_output_with_place
(
place
,
atol
=
atol
)
def
test_check_grad
(
self
):
place
=
core
.
CUDAPlace
(
0
)
if
check_grad
:
# Use the default max_relative_error, not use max_relative_error
self
.
check_grad_with_place
(
place
,
[
'X'
,
'Alpha'
],
'Out'
)
cls_name
=
"{}_{}"
.
format
(
parent
.
__name__
,
"BF16Op"
)
TestPReluBF16Op
.
__name__
=
cls_name
globals
()[
cls_name
]
=
TestPReluBF16Op
create_test_fp16_class
(
TestModeElt
)
create_test_fp16_class
(
TestModeAllRank3
)
create_test_fp16_class
(
TestModeAllRank6
)
...
...
@@ -420,6 +454,21 @@ create_test_fp16_class(TestModeChannelRank6NHWC)
create_test_fp16_class
(
TestModeElementRank3NHWC
)
create_test_fp16_class
(
TestModeElementRank6NHWC
)
create_test_bf16_class
(
TestModeElt
)
create_test_bf16_class
(
TestModeAllRank3
)
create_test_bf16_class
(
TestModeAllRank6
)
create_test_bf16_class
(
TestModeChannelRank3
)
create_test_bf16_class
(
TestModeChannelRank6
)
create_test_bf16_class
(
TestModeElementRank3
)
create_test_bf16_class
(
TestModeElementRank6
)
create_test_bf16_class
(
TestModeEltNHWC
)
create_test_bf16_class
(
TestModeAllRank3NHWC
)
create_test_bf16_class
(
TestModeAllRank6NHWC
)
create_test_bf16_class
(
TestModeChannelRank3NHWC
)
create_test_bf16_class
(
TestModeChannelRank6NHWC
)
create_test_bf16_class
(
TestModeElementRank3NHWC
)
create_test_bf16_class
(
TestModeElementRank6NHWC
)
def
prelu_t
(
x
,
mode
,
param_attr
=
None
,
name
=
None
,
data_format
=
'NCHW'
):
helper
=
fluid
.
layer_helper
.
LayerHelper
(
'prelu'
,
**
locals
())
...
...
python/paddle/nn/clip.py
浏览文件 @
c3055d23
...
...
@@ -63,7 +63,9 @@ def clip_by_norm(x, max_norm, name=None):
return
_legacy_C_ops
.
clip_by_norm
(
x
,
'max_norm'
,
max_norm
)
helper
=
LayerHelper
(
"clip_by_norm"
,
**
locals
())
check_variable_and_dtype
(
x
,
'X'
,
[
'float32'
,
'float16'
],
'clip_by_norm'
)
check_variable_and_dtype
(
x
,
'X'
,
[
'float16'
,
'float32'
,
'uint16'
],
'clip_by_norm'
)
check_type
(
max_norm
,
'max_norm'
,
(
float
),
'clip_by_norm'
)
if
name
is
None
:
...
...
python/paddle/nn/functional/activation.py
浏览文件 @
c3055d23
...
...
@@ -538,10 +538,13 @@ def prelu(x, weight, data_format="NCHW", name=None):
return
_C_ops
.
prelu
(
x
,
weight
,
data_format
,
mode
)
else
:
check_variable_and_dtype
(
x
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
],
'prelu'
x
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
,
'uint16'
],
'prelu'
)
check_variable_and_dtype
(
weight
,
'weight'
,
[
'float16'
,
'float32'
,
'float64'
],
'prelu'
weight
,
'weight'
,
[
'float16'
,
'float32'
,
'float64'
,
'uint16'
],
'prelu'
,
)
helper
=
LayerHelper
(
'prelu'
,
**
locals
())
out
=
helper
.
create_variable_for_type_inference
(
x
.
dtype
)
...
...
python/paddle/tensor/linalg.py
浏览文件 @
c3055d23
...
...
@@ -2489,7 +2489,7 @@ def multi_dot(x, name=None):
check_variable_and_dtype
(
item
,
'x['
+
str
(
id
)
+
']'
,
[
'float16'
,
'float32'
,
'float64'
],
[
'float16'
,
'float32'
,
'float64'
,
'uint16'
],
'multi_dot'
,
)
if
item
.
dtype
!=
x
[
0
].
dtype
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录