Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
9c406531
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9c406531
编写于
4月 28, 2023
作者:
iSerendipity
提交者:
GitHub
4月 28, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
【Hackathon No.52】为 Paddle dist 算子实现 float16 数据类型支持 (#50915)
上级
64adfe7a
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
116 addition
and
56 deletion
+116
-56
paddle/phi/kernels/dist_grad_kernel.cc
paddle/phi/kernels/dist_grad_kernel.cc
+7
-2
paddle/phi/kernels/funcs/math_cuda_utils.h
paddle/phi/kernels/funcs/math_cuda_utils.h
+10
-17
paddle/phi/kernels/gpu/dist_kernel.cu
paddle/phi/kernels/gpu/dist_kernel.cu
+51
-33
python/paddle/fluid/tests/unittests/test_dist_op.py
python/paddle/fluid/tests/unittests/test_dist_op.py
+40
-0
python/paddle/tensor/linalg.py
python/paddle/tensor/linalg.py
+8
-4
未找到文件。
paddle/phi/kernels/dist_grad_kernel.cc
浏览文件 @
9c406531
...
...
@@ -98,6 +98,11 @@ PD_REGISTER_KERNEL(
dist_grad
,
CPU
,
ALL_LAYOUT
,
phi
::
DistGradKernel
,
float
,
double
)
{}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL
(
dist_grad
,
GPU
,
ALL_LAYOUT
,
phi
::
DistGradKernel
,
float
,
double
)
{}
PD_REGISTER_KERNEL
(
dist_grad
,
GPU
,
ALL_LAYOUT
,
phi
::
DistGradKernel
,
phi
::
dtype
::
float16
,
float
,
double
)
{}
#endif
paddle/phi/kernels/funcs/math_cuda_utils.h
浏览文件 @
9c406531
...
...
@@ -23,6 +23,9 @@ limitations under the License. */
#include <algorithm>
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/common/data_type.h"
namespace
phi
{
namespace
funcs
{
...
...
@@ -170,11 +173,7 @@ struct KeyValuePair<half> {
template
<
typename
T
>
__inline__
__device__
T
WarpReduceSum
(
T
val
,
unsigned
lane_mask
)
{
for
(
int
mask
=
HALF_WARP
;
mask
>
0
;
mask
>>=
1
)
#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
val
+=
__shfl_xor_sync
(
lane_mask
,
val
,
mask
,
warpSize
);
#else
val
+=
__shfl_xor
(
val
,
mask
,
warpSize
);
#endif
val
+=
phi
::
backends
::
gpu
::
CudaShuffleXorSync
(
lane_mask
,
val
,
mask
);
return
val
;
}
...
...
@@ -243,11 +242,8 @@ __inline__ __device__ T BlockReduceSumV2(T *val) {
template
<
typename
T
>
__inline__
__device__
T
WarpReduceMax
(
T
val
,
unsigned
lane_mask
)
{
for
(
int
mask
=
HALF_WARP
;
mask
>
0
;
mask
>>=
1
)
#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
val
=
max
(
val
,
__shfl_xor_sync
(
lane_mask
,
val
,
mask
,
warpSize
));
#else
val
=
max
(
val
,
__shfl_xor
(
val
,
mask
,
warpSize
));
#endif
val
=
std
::
max
(
val
,
phi
::
backends
::
gpu
::
CudaShuffleXorSync
(
lane_mask
,
val
,
mask
));
return
val
;
}
...
...
@@ -265,11 +261,8 @@ __inline__ __device__ T WarpReduceMaxV2(T *val) {
template
<
typename
T
>
__inline__
__device__
T
WarpReduceMin
(
T
val
,
unsigned
lane_mask
)
{
for
(
int
mask
=
HALF_WARP
;
mask
>
0
;
mask
>>=
1
)
#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
val
=
min
(
val
,
__shfl_xor_sync
(
lane_mask
,
val
,
mask
,
warpSize
));
#else
val
=
min
(
val
,
__shfl_xor
(
val
,
mask
,
warpSize
));
#endif
val
=
std
::
min
(
val
,
phi
::
backends
::
gpu
::
CudaShuffleXorSync
(
lane_mask
,
val
,
mask
));
return
val
;
}
...
...
@@ -310,7 +303,7 @@ __inline__ __device__ T BlockReduceMax(T val, unsigned mask) {
// align block_span to warpSize
int
block_span
=
(
blockDim
.
x
+
warpSize
-
1
)
>>
5
;
val
=
(
lane
<
block_span
)
?
shared
[
lane
]
:
-
1e10
f
;
val
=
(
lane
<
block_span
)
?
shared
[
lane
]
:
std
::
numeric_limits
<
T
>::
min
()
;
val
=
WarpReduceMax
(
val
,
mask
);
return
val
;
...
...
@@ -358,7 +351,7 @@ __inline__ __device__ T BlockReduceMin(T val, unsigned mask) {
// align block_span to warpSize
int
block_span
=
(
blockDim
.
x
+
warpSize
-
1
)
>>
5
;
val
=
(
lane
<
block_span
)
?
shared
[
lane
]
:
1e10
f
;
val
=
(
lane
<
block_span
)
?
shared
[
lane
]
:
std
::
numeric_limits
<
T
>::
max
()
;
val
=
WarpReduceMin
(
val
,
mask
);
return
val
;
...
...
paddle/phi/kernels/gpu/dist_kernel.cu
浏览文件 @
9c406531
...
...
@@ -12,9 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/dist_kernel.h"
#include <algorithm>
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/dist_kernel.h"
#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
#include "paddle/phi/kernels/gpu/reduce.h"
...
...
@@ -24,47 +27,56 @@ namespace phi {
#define FULL_MASK 0xffffffff
template
<
typename
T
>
template
<
typename
T
x
,
typename
Ty
=
Tx
>
struct
ZeroOrderFunctor
{
public:
__device__
T
operator
()(
const
T
&
x
,
const
T
&
y
)
const
{
return
static_cast
<
T
>
((
x
-
y
)
!=
0
);
HOSTDEVICE
explicit
inline
ZeroOrderFunctor
()
{}
HOSTDEVICE
inline
Ty
operator
()(
const
Tx
&
x
,
const
Tx
&
y
)
const
{
return
static_cast
<
Ty
>
(
x
!=
y
);
}
};
template
<
typename
T
>
template
<
typename
T
x
,
typename
Ty
=
Tx
>
struct
OtherOrderFunctor
{
explicit
OtherOrderFunctor
(
const
T
&
p_order
)
:
p_order_
(
p_order
)
{}
__device__
T
operator
()(
const
T
&
x
,
const
T
&
y
)
const
{
return
static_cast
<
T
>
(
pow
(
abs
(
x
-
y
),
p_order_
));
HOSTDEVICE
explicit
inline
OtherOrderFunctor
(
const
Ty
&
_p_order
)
:
p_order
(
_p_order
)
{}
HOSTDEVICE
inline
Ty
operator
()(
const
Tx
&
x
,
const
Tx
&
y
)
const
{
return
static_cast
<
Ty
>
(
pow
(
abs
(
static_cast
<
Ty
>
(
x
)
-
static_cast
<
Ty
>
(
y
)),
p_order
));
}
private:
T
p_order_
;
T
y
p_order
;
};
template
<
typename
T
>
template
<
typename
T
x
,
typename
Ty
=
Tx
>
struct
PowFunctor
{
explicit
PowFunctor
(
const
T
&
p_order
)
:
p_order_
(
p_order
)
{}
HOSTDEVICE
inline
T
operator
()(
const
T
x
)
const
{
return
static_cast
<
T
>
(
pow
(
x
,
p_order_
));
HOSTDEVICE
explicit
inline
PowFunctor
(
const
Ty
&
_p_order
)
:
p_order
(
_p_order
)
{}
HOSTDEVICE
inline
Tx
operator
()(
const
Tx
x
)
const
{
return
static_cast
<
Tx
>
(
pow
(
static_cast
<
Ty
>
(
x
),
p_order
));
}
T
p_order_
;
private:
Ty
p_order
;
};
template
<
typename
T
,
typename
Functor
>
__global__
void
ReduceSumWithSubtract
(
const
T
*
x
,
const
T
*
y
,
T
*
out
,
int64_t
N
,
Functor
func
)
{
T
sum_val
=
0
;
using
MT
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
MT
sum_val
(
0.0
);
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
N
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
sum_val
+=
func
(
x
[
i
],
y
[
i
]
);
sum_val
+=
static_cast
<
MT
>
(
func
(
x
[
i
],
y
[
i
])
);
}
__syncthreads
();
sum_val
=
phi
::
funcs
::
BlockReduceSum
<
T
>
(
sum_val
,
FULL_MASK
);
sum_val
=
phi
::
funcs
::
BlockReduceSum
<
M
T
>
(
sum_val
,
FULL_MASK
);
if
(
threadIdx
.
x
==
0
)
{
out
[
blockIdx
.
x
]
=
s
um_val
;
out
[
blockIdx
.
x
]
=
s
tatic_cast
<
T
>
(
sum_val
)
;
}
}
...
...
@@ -73,10 +85,10 @@ __global__ void ReduceMaxWithSubtract(const T* x,
const
T
*
y
,
T
*
out
,
int64_t
N
)
{
T
max_val
=
-
1e10
f
;
T
max_val
=
std
::
numeric_limits
<
T
>::
min
()
;
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
N
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
max_val
=
max
(
max_val
,
abs
(
x
[
i
]
-
y
[
i
]));
max_val
=
std
::
max
(
max_val
,
abs
(
x
[
i
]
-
y
[
i
]));
}
__syncthreads
();
...
...
@@ -91,10 +103,10 @@ __global__ void ReduceMinWithSubtract(const T* x,
const
T
*
y
,
T
*
out
,
int64_t
N
)
{
T
min_val
=
1e10
f
;
T
min_val
=
std
::
numeric_limits
<
T
>::
max
()
;
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
N
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
min_val
=
min
(
min_val
,
abs
(
x
[
i
]
-
y
[
i
]));
min_val
=
std
::
min
(
min_val
,
abs
(
x
[
i
]
-
y
[
i
]));
}
__syncthreads
();
...
...
@@ -110,6 +122,7 @@ void DistKernel(const Context& dev_ctx,
const
DenseTensor
&
y
,
float
p
,
DenseTensor
*
out
)
{
using
MT
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
DenseTensor
intermediate
;
const
T
*
x_ptr
=
x
.
data
<
T
>
();
const
T
*
y_ptr
=
y
.
data
<
T
>
();
...
...
@@ -131,9 +144,8 @@ void DistKernel(const Context& dev_ctx,
ReduceSumWithSubtract
<
T
>
<<<
config
.
block_per_grid
.
x
,
config
.
thread_per_block
.
x
,
0
,
stream
>>>
(
x_ptr
,
y_ptr
,
i_ptr
,
n
,
ZeroOrderFunctor
<
T
>
());
phi
::
funcs
::
ReduceKernel
<
T
,
T
,
kps
::
AddFunctor
,
kps
::
IdentityFunctor
<
T
>>
(
dev_ctx
,
intermediate
,
out
,
kps
::
IdentityFunctor
<
T
>
(),
reduce_axis
);
phi
::
funcs
::
ReduceKernel
<
T
,
T
,
kps
::
AddFunctor
,
kps
::
IdentityFunctor
<
MT
>>
(
dev_ctx
,
intermediate
,
out
,
kps
::
IdentityFunctor
<
MT
>
(),
reduce_axis
);
}
else
if
(
p
==
INFINITY
)
{
ReduceMaxWithSubtract
<
T
>
<<<
config
.
block_per_grid
.
x
,
config
.
thread_per_block
.
x
,
0
,
stream
>>>
(
...
...
@@ -150,19 +162,19 @@ void DistKernel(const Context& dev_ctx,
dev_ctx
,
intermediate
,
out
,
kps
::
IdentityFunctor
<
T
>
(),
reduce_axis
);
}
else
{
T
p_order
=
static_cast
<
T
>
(
p
);
MT
p_order
=
static_cast
<
M
T
>
(
p
);
ReduceSumWithSubtract
<
T
>
<<<
config
.
block_per_grid
.
x
,
config
.
thread_per_block
.
x
,
0
,
stream
>>>
(
x_ptr
,
y_ptr
,
i_ptr
,
n
,
OtherOrderFunctor
<
T
>
(
p_order
));
phi
::
funcs
::
ReduceKernel
<
T
,
T
,
kps
::
AddFunctor
,
kps
::
IdentityFunctor
<
T
>>
(
dev_ctx
,
intermediate
,
out
,
kps
::
IdentityFunctor
<
T
>
(),
reduce_axis
);
x_ptr
,
y_ptr
,
i_ptr
,
n
,
OtherOrderFunctor
<
T
,
MT
>
(
p_order
));
phi
::
funcs
::
ReduceKernel
<
T
,
T
,
kps
::
AddFunctor
,
kps
::
IdentityFunctor
<
M
T
>>
(
dev_ctx
,
intermediate
,
out
,
kps
::
IdentityFunctor
<
M
T
>
(),
reduce_axis
);
const
DenseTensor
*
tmp_norm
=
out
;
std
::
vector
<
const
DenseTensor
*>
ins
=
{
tmp_norm
};
std
::
vector
<
DenseTensor
*>
outs
=
{
out
};
T
p_order_
=
static_cast
<
T
>
(
1.
/
p_order
);
MT
p_order_
=
static_cast
<
MT
>
(
static_cast
<
MT
>
(
1.
)
/
p_order
);
phi
::
funcs
::
ElementwiseKernel
<
T
>
(
dev_ctx
,
ins
,
&
outs
,
PowFunctor
<
T
>
(
p_order_
));
dev_ctx
,
ins
,
&
outs
,
PowFunctor
<
T
,
MT
>
(
p_order_
));
}
}
else
{
...
...
@@ -173,4 +185,10 @@ void DistKernel(const Context& dev_ctx,
}
// namespace phi
PD_REGISTER_KERNEL
(
dist
,
GPU
,
ALL_LAYOUT
,
phi
::
DistKernel
,
float
,
double
)
{}
PD_REGISTER_KERNEL
(
dist
,
GPU
,
ALL_LAYOUT
,
phi
::
DistKernel
,
phi
::
dtype
::
float16
,
float
,
double
)
{}
python/paddle/fluid/tests/unittests/test_dist_op.py
浏览文件 @
9c406531
...
...
@@ -158,6 +158,46 @@ class TestDistOpCase5(TestDistOp):
self
.
p
=
1.5
class
TestDistFP16Op
(
OpTest
):
def
init_data_type
(
self
):
self
.
data_type
=
'float16'
class
TestDistFP16OpCase1
(
TestDistFP16Op
):
def
init_case
(
self
):
self
.
x_shape
=
(
3
,
5
,
5
,
6
)
self
.
y_shape
=
(
5
,
5
,
6
)
self
.
p
=
1.0
class
TestDistFP16OpCase2
(
TestDistFP16Op
):
def
init_case
(
self
):
self
.
x_shape
=
(
10
,
10
)
self
.
y_shape
=
(
4
,
10
,
10
)
self
.
p
=
2.0
class
TestDistFP16OpCase3
(
TestDistFP16Op
):
def
init_case
(
self
):
self
.
x_shape
=
(
15
,
10
)
self
.
y_shape
=
(
15
,
10
)
self
.
p
=
float
(
"inf"
)
class
TestDistFP16OpCase4
(
TestDistFP16Op
):
def
init_case
(
self
):
self
.
x_shape
=
(
2
,
3
,
4
,
5
,
8
)
self
.
y_shape
=
(
3
,
1
,
5
,
8
)
self
.
p
=
float
(
"-inf"
)
class
TestDistFP16OpCase5
(
TestDistFP16Op
):
def
init_case
(
self
):
self
.
x_shape
=
(
4
,
1
,
4
,
8
)
self
.
y_shape
=
(
2
,
2
,
1
,
4
,
4
,
8
)
self
.
p
=
1.5
class
TestDistAPI
(
unittest
.
TestCase
):
def
init_data_type
(
self
):
self
.
data_type
=
(
...
...
python/paddle/tensor/linalg.py
浏览文件 @
9c406531
...
...
@@ -675,8 +675,8 @@ def dist(x, y, p=2, name=None):
||z||_{p}=(\sum_{i=1}^{m}|z_i|^p)^{\\frac{1}{p}}
Args:
x (Tensor): 1-D to 6-D Tensor, its data type is float32 or float64.
y (Tensor): 1-D to 6-D Tensor, its data type is float32 or float64.
x (Tensor): 1-D to 6-D Tensor, its data type is float
16, float
32 or float64.
y (Tensor): 1-D to 6-D Tensor, its data type is float
16, float
32 or float64.
p (float, optional): The norm to be computed, its data type is float32 or float64. Default: 2.
name (str, optional): The default value is `None`. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
...
...
@@ -706,8 +706,12 @@ def dist(x, y, p=2, name=None):
if
in_dygraph_mode
():
return
_C_ops
.
dist
(
x
,
y
,
p
)
check_variable_and_dtype
(
x
,
'dtype'
,
[
'float32'
,
'float64'
],
'dist'
)
check_variable_and_dtype
(
y
,
'dtype'
,
[
'float32'
,
'float64'
],
'dist'
)
check_variable_and_dtype
(
x
,
'dtype'
,
[
'float16'
,
'float32'
,
'float64'
],
'dist'
)
check_variable_and_dtype
(
y
,
'dtype'
,
[
'float16'
,
'float32'
,
'float64'
],
'dist'
)
check_type
(
p
,
'p'
,
(
float
,
int
),
'dist'
)
helper
=
LayerHelper
(
"dist"
,
**
locals
())
out
=
helper
.
create_variable_for_type_inference
(
x
.
dtype
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录