Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
133a914b
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
133a914b
编写于
3月 08, 2021
作者:
Q
Qi Li
提交者:
GitHub
3月 08, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[ROCM] fix test_dist_op ci test, test=develop (#31468)
上级
f9377965
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
68 addition
and
12 deletion
+68
-12
paddle/fluid/operators/dist_op.cu
paddle/fluid/operators/dist_op.cu
+9
-0
paddle/fluid/operators/math/math_cuda_utils.h
paddle/fluid/operators/math/math_cuda_utils.h
+10
-4
python/paddle/fluid/tests/unittests/dist_test.sh
python/paddle/fluid/tests/unittests/dist_test.sh
+33
-2
python/paddle/fluid/tests/unittests/test_dist_op.py
python/paddle/fluid/tests/unittests/test_dist_op.py
+16
-6
未找到文件。
paddle/fluid/operators/dist_op.cu
浏览文件 @
133a914b
...
...
@@ -15,9 +15,18 @@
#include "paddle/fluid/operators/dist_op.h"
namespace
ops
=
paddle
::
operators
;
#ifdef PADDLE_WITH_HIP
// Eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorReductionGpu.h:922
// do not support double in HIPCC platform (Eigen3 to be fixed)
REGISTER_OP_CUDA_KERNEL
(
dist
,
ops
::
DistKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
);
REGISTER_OP_CUDA_KERNEL
(
dist_grad
,
ops
::
DistGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
);
#else
REGISTER_OP_CUDA_KERNEL
(
dist
,
ops
::
DistKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
DistKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
dist_grad
,
ops
::
DistGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
DistGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
#endif
paddle/fluid/operators/math/math_cuda_utils.h
浏览文件 @
133a914b
...
...
@@ -214,7 +214,7 @@ __inline__ __device__ T warpReduceMax(T val, unsigned lane_mask) {
template
<
typename
T
>
__inline__
__device__
T
warpReduceMin
(
T
val
,
unsigned
lane_mask
)
{
for
(
int
mask
=
HALF_WARP
;
mask
>
0
;
mask
>>=
1
)
#if
__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000
#if
defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
val
=
min
(
val
,
__shfl_xor_sync
(
lane_mask
,
val
,
mask
,
warpSize
));
#else
val
=
min
(
val
,
__shfl_xor
(
val
,
mask
,
warpSize
));
...
...
@@ -226,7 +226,7 @@ __inline__ __device__ T warpReduceMin(T val, unsigned lane_mask) {
* threads are less than warpSize.*/
template
<
typename
T
>
__inline__
__device__
T
PartialWarpReduceMin
(
T
val
,
unsigned
lane_mask
)
{
#if
__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000
#if
defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
T
warp_val
=
__shfl_sync
(
lane_mask
,
val
,
0
,
warpSize
);
#else
T
warp_val
=
__shfl
(
...
...
@@ -235,7 +235,7 @@ __inline__ __device__ T PartialWarpReduceMin(T val, unsigned lane_mask) {
warp_val
=
val
;
for
(
int
offset
=
HALF_WARP
;
offset
>
0
;
offset
>>=
1
)
#if
__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000
#if
defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
warp_val
=
min
(
warp_val
,
__shfl_down_sync
(
lane_mask
,
warp_val
,
offset
,
warpSize
));
#else
...
...
@@ -298,9 +298,15 @@ __inline__ __device__ T PartialBlockReduceMin(T val, unsigned mask) {
__syncthreads
();
shared
[
lane
]
=
PartialWarpReduceMin
(
shared
[
lane
],
mask
);
#if defined(PADDLE_WITH_HIP)
// HIP do not support __syncwarp, using __syncthreads() instead is ok,
// although bringing a few performance decrease.
__syncthreads
();
#else
__syncwarp
();
#endif
#if
__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000
#if
defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
val
=
__shfl_sync
(
mask
,
shared
[
lane
],
0
,
warpSize
);
#else
val
=
__shfl
(
shared
[
lane
],
0
,
warpSize
);
...
...
python/paddle/fluid/tests/unittests/dist_test.sh
浏览文件 @
133a914b
#!/bin/bash
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
unset
https_proxy http_proxy
export
FLAGS_rpc_disable_reuse_port
=
1
...
...
@@ -50,14 +65,30 @@ do
cat
-n
${
log
}
done
# check CUDA or ROCM env
GPU_SYS_INFO_CMD
=
nvidia-smi
which
${
GPU_SYS_INFO_CMD
}
exit_code
=
$?
if
[[
$exit_code
-ne
0
]]
;
then
GPU_SYS_INFO_CMD
=
rocm-smi
fi
which
${
GPU_SYS_INFO_CMD
}
exit_code
=
$?
if
[[
$exit_code
-ne
0
]]
;
then
echo
"nvidia-smi or rocm-smi faild with
${
exit_code
}
"
exit
${
exit_code
}
fi
#display system context
for
i
in
{
1..2
}
;
do
sleep
3
ps
-aux
netstat
-anlp
if
hash
"
nvidia-smi
"
>
/dev/null
;
then
nvidia-smi
if
hash
"
${
GPU_SYS_INFO_CMD
}
"
>
/dev/null
;
then
${
GPU_SYS_INFO_CMD
}
fi
done
...
...
python/paddle/fluid/tests/unittests/test_dist_op.py
浏览文件 @
133a914b
...
...
@@ -39,9 +39,10 @@ class TestDistOp(OpTest):
self
.
op_type
=
'dist'
self
.
attrs
=
{}
self
.
init_case
()
self
.
init_data_type
()
self
.
inputs
=
{
"X"
:
np
.
random
.
random
(
self
.
x_shape
).
astype
(
"float64"
),
"Y"
:
np
.
random
.
random
(
self
.
y_shape
).
astype
(
"float64"
)
"X"
:
np
.
random
.
random
(
self
.
x_shape
).
astype
(
self
.
data_type
),
"Y"
:
np
.
random
.
random
(
self
.
y_shape
).
astype
(
self
.
data_type
)
}
self
.
attrs
[
"p"
]
=
self
.
p
...
...
@@ -55,6 +56,10 @@ class TestDistOp(OpTest):
self
.
y_shape
=
(
120
)
self
.
p
=
0.
def
init_data_type
(
self
):
self
.
data_type
=
np
.
float32
if
core
.
is_compiled_with_rocm
(
)
else
np
.
float64
def
calc_gradient
(
self
):
x
=
self
.
inputs
[
"X"
]
y
=
self
.
inputs
[
"Y"
]
...
...
@@ -143,15 +148,20 @@ class TestDistOpCase5(TestDistOp):
class
TestDistAPI
(
unittest
.
TestCase
):
def
init_data_type
(
self
):
self
.
data_type
=
'float32'
if
core
.
is_compiled_with_rocm
(
)
else
'float64'
def
test_api
(
self
):
self
.
init_data_type
()
main_program
=
fluid
.
Program
()
startup_program
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main_program
,
startup_program
):
x
=
fluid
.
data
(
name
=
'x'
,
shape
=
[
2
,
3
,
4
,
5
],
dtype
=
'float64'
)
y
=
fluid
.
data
(
name
=
'y'
,
shape
=
[
3
,
1
,
5
],
dtype
=
'float64'
)
x
=
fluid
.
data
(
name
=
'x'
,
shape
=
[
2
,
3
,
4
,
5
],
dtype
=
self
.
data_type
)
y
=
fluid
.
data
(
name
=
'y'
,
shape
=
[
3
,
1
,
5
],
dtype
=
self
.
data_type
)
p
=
2
x_i
=
np
.
random
.
random
((
2
,
3
,
4
,
5
)).
astype
(
"float64"
)
y_i
=
np
.
random
.
random
((
3
,
1
,
5
)).
astype
(
"float64"
)
x_i
=
np
.
random
.
random
((
2
,
3
,
4
,
5
)).
astype
(
self
.
data_type
)
y_i
=
np
.
random
.
random
((
3
,
1
,
5
)).
astype
(
self
.
data_type
)
result
=
paddle
.
dist
(
x
,
y
,
p
)
place
=
fluid
.
CUDAPlace
(
0
)
if
core
.
is_compiled_with_cuda
(
)
else
fluid
.
CPUPlace
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录