Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
cced930b
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
cced930b
编写于
2月 23, 2021
作者:
Q
Qi Li
提交者:
GitHub
2月 23, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[ROCM] update fluid operators for rocm (part1), test=develop (#31077)
上级
99fd9815
变更
19
隐藏空白更改
内联
并排
Showing
19 changed file
with
142 addition
and
36 deletion
+142
-36
paddle/fluid/operators/controlflow/conditional_block_op.h
paddle/fluid/operators/controlflow/conditional_block_op.h
+1
-1
paddle/fluid/operators/controlflow/get_places_op.cc
paddle/fluid/operators/controlflow/get_places_op.cc
+2
-2
paddle/fluid/operators/controlflow/while_op_helper.cc
paddle/fluid/operators/controlflow/while_op_helper.cc
+1
-1
paddle/fluid/operators/detection/CMakeLists.txt
paddle/fluid/operators/detection/CMakeLists.txt
+5
-3
paddle/fluid/operators/detection/bbox_util.cu.h
paddle/fluid/operators/detection/bbox_util.cu.h
+19
-2
paddle/fluid/operators/detection/collect_fpn_proposals_op.cu
paddle/fluid/operators/detection/collect_fpn_proposals_op.cu
+35
-5
paddle/fluid/operators/detection/distribute_fpn_proposals_op.cu
.../fluid/operators/detection/distribute_fpn_proposals_op.cu
+28
-4
paddle/fluid/operators/detection/sigmoid_focal_loss_op.cu
paddle/fluid/operators/detection/sigmoid_focal_loss_op.cu
+0
-1
paddle/fluid/operators/detection/target_assign_op.h
paddle/fluid/operators/detection/target_assign_op.h
+2
-2
paddle/fluid/operators/distributed/CMakeLists.txt
paddle/fluid/operators/distributed/CMakeLists.txt
+1
-1
paddle/fluid/operators/distributed/brpc/brpc_sendrecvop_utils.cc
...fluid/operators/distributed/brpc/brpc_sendrecvop_utils.cc
+5
-2
paddle/fluid/operators/distributed/brpc/brpc_serde_test.cc
paddle/fluid/operators/distributed/brpc/brpc_serde_test.cc
+2
-2
paddle/fluid/operators/distributed/grpc/grpc_serde.cc
paddle/fluid/operators/distributed/grpc/grpc_serde.cc
+5
-2
paddle/fluid/operators/distributed/grpc/grpc_serde_test.cc
paddle/fluid/operators/distributed/grpc/grpc_serde_test.cc
+2
-2
paddle/fluid/operators/distributed/parameter_prefetch.cc
paddle/fluid/operators/distributed/parameter_prefetch.cc
+1
-1
paddle/fluid/operators/distributed/sendrecvop_utils.cc
paddle/fluid/operators/distributed/sendrecvop_utils.cc
+1
-1
paddle/fluid/operators/distributed/variable_response.cc
paddle/fluid/operators/distributed/variable_response.cc
+3
-3
paddle/fluid/operators/metrics/accuracy_op.cu
paddle/fluid/operators/metrics/accuracy_op.cu
+12
-1
paddle/fluid/operators/metrics/auc_op.cu
paddle/fluid/operators/metrics/auc_op.cu
+17
-0
未找到文件。
paddle/fluid/operators/controlflow/conditional_block_op.h
浏览文件 @
cced930b
...
...
@@ -73,7 +73,7 @@ class ConditionalOp : public framework::OperatorBase {
ips
[
0
]
->
numel
()));
bool
res
=
false
;
if
(
platform
::
is_gpu_place
(
ips
[
0
]
->
place
()))
{
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
framework
::
LoDTensor
cpu_tensor
;
framework
::
TensorCopy
(
*
ips
[
0
],
platform
::
CPUPlace
(),
&
cpu_tensor
);
platform
::
DeviceContextPool
::
Instance
().
Get
(
ips
[
0
]
->
place
())
->
Wait
();
...
...
paddle/fluid/operators/controlflow/get_places_op.cc
浏览文件 @
cced930b
...
...
@@ -26,7 +26,7 @@ namespace imperative {
class
OpBase
;
}
// namespace imperative
}
// namespace paddle
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/platform/gpu_info.h"
#endif
...
...
@@ -34,7 +34,7 @@ namespace paddle {
namespace
operators
{
static
size_t
CUDADevCount
()
{
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
return
platform
::
GetCUDADeviceCount
();
#else
return
0UL
;
...
...
paddle/fluid/operators/controlflow/while_op_helper.cc
浏览文件 @
cced930b
...
...
@@ -223,7 +223,7 @@ bool GetCondData(const framework::LoDTensor &cond) {
}
// when platform::is_gpu_place(cond.place()) is true
std
::
unique_ptr
<
framework
::
LoDTensor
>
cpu_cond
{
new
framework
::
LoDTensor
()};
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
framework
::
TensorCopySync
(
cond
,
platform
::
CPUPlace
(),
cpu_cond
.
get
());
#else
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
...
...
paddle/fluid/operators/detection/CMakeLists.txt
浏览文件 @
cced930b
...
...
@@ -40,10 +40,12 @@ detection_library(box_decoder_and_assign_op SRCS box_decoder_and_assign_op.cc bo
detection_library
(
sigmoid_focal_loss_op SRCS sigmoid_focal_loss_op.cc sigmoid_focal_loss_op.cu
)
detection_library
(
retinanet_detection_output_op SRCS retinanet_detection_output_op.cc
)
if
(
WITH_GPU
)
if
(
WITH_GPU
OR WITH_ROCM
)
set
(
TMPDEPS memory
)
if
(
${
CMAKE_CUDA_COMPILER_VERSION
}
LESS 11.0
)
set
(
TMPDEPS memory cub
)
if
(
WITH_GPU
)
if
(
${
CMAKE_CUDA_COMPILER_VERSION
}
LESS 11.0
)
set
(
TMPDEPS memory cub
)
endif
()
endif
()
detection_library
(
generate_proposals_op SRCS generate_proposals_op.cc generate_proposals_op.cu DEPS
${
TMPDEPS
}
)
detection_library
(
generate_proposals_v2_op SRCS generate_proposals_v2_op.cc generate_proposals_v2_op.cu DEPS
${
TMPDEPS
}
)
...
...
paddle/fluid/operators/detection/bbox_util.cu.h
浏览文件 @
cced930b
...
...
@@ -16,10 +16,16 @@ limitations under the License. */
#include <cfloat>
#include <string>
#include <vector>
#ifdef __NVCC__
#include "cub/cub.cuh"
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
#include "paddle/fluid/platform/miopen_helper.h"
#endif
#include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/for_range.h"
namespace
paddle
{
...
...
@@ -58,16 +64,27 @@ static void SortDescending(const platform::CUDADeviceContext &ctx,
// Determine temporary device storage requirements
size_t
temp_storage_bytes
=
0
;
#ifdef PADDLE_WITH_HIP
hipcub
::
DeviceRadixSort
::
SortPairsDescending
<
T
,
int
>
(
nullptr
,
temp_storage_bytes
,
keys_in
,
keys_out
,
idx_in
,
idx_out
,
num
);
#else
cub
::
DeviceRadixSort
::
SortPairsDescending
<
T
,
int
>
(
nullptr
,
temp_storage_bytes
,
keys_in
,
keys_out
,
idx_in
,
idx_out
,
num
);
#endif
// Allocate temporary storage
auto
place
=
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
ctx
.
GetPlace
());
auto
d_temp_storage
=
memory
::
Alloc
(
place
,
temp_storage_bytes
);
// Run sorting operation
// Run sorting operation
#ifdef PADDLE_WITH_HIP
hipcub
::
DeviceRadixSort
::
SortPairsDescending
<
T
,
int
>
(
d_temp_storage
->
ptr
(),
temp_storage_bytes
,
keys_in
,
keys_out
,
idx_in
,
idx_out
,
num
);
#else
cub
::
DeviceRadixSort
::
SortPairsDescending
<
T
,
int
>
(
d_temp_storage
->
ptr
(),
temp_storage_bytes
,
keys_in
,
keys_out
,
idx_in
,
idx_out
,
num
);
#endif
}
template
<
typename
T
>
...
...
paddle/fluid/operators/detection/collect_fpn_proposals_op.cu
浏览文件 @
cced930b
...
...
@@ -9,8 +9,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#i
nclude <paddle/fluid/memory/allocation/allocator.h>
#i
fdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
#endif
#include <paddle/fluid/memory/allocation/allocator.h>
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memcpy.h"
...
...
@@ -135,17 +141,29 @@ class GPUCollectFpnProposalsOpKernel : public framework::OpKernel<T> {
// Determine temporary device storage requirements
size_t
temp_storage_bytes
=
0
;
#ifdef PADDLE_WITH_HIP
hipcub
::
DeviceRadixSort
::
SortPairsDescending
<
T
,
int
>
(
nullptr
,
temp_storage_bytes
,
concat_scores
.
data
<
T
>
(),
keys_out
,
idx_in
,
idx_out
,
total_roi_num
);
#else
cub
::
DeviceRadixSort
::
SortPairsDescending
<
T
,
int
>
(
nullptr
,
temp_storage_bytes
,
concat_scores
.
data
<
T
>
(),
keys_out
,
idx_in
,
idx_out
,
total_roi_num
);
#endif
// Allocate temporary storage
auto
d_temp_storage
=
memory
::
Alloc
(
place
,
temp_storage_bytes
);
// Run sorting operation
// sort score to get corresponding index
// Run sorting operation
// sort score to get corresponding index
#ifdef PADDLE_WITH_HIP
hipcub
::
DeviceRadixSort
::
SortPairsDescending
<
T
,
int
>
(
d_temp_storage
->
ptr
(),
temp_storage_bytes
,
concat_scores
.
data
<
T
>
(),
keys_out
,
idx_in
,
idx_out
,
total_roi_num
);
#else
cub
::
DeviceRadixSort
::
SortPairsDescending
<
T
,
int
>
(
d_temp_storage
->
ptr
(),
temp_storage_bytes
,
concat_scores
.
data
<
T
>
(),
keys_out
,
idx_in
,
idx_out
,
total_roi_num
);
#endif
index_out_t
.
Resize
({
real_post_num
});
Tensor
sorted_rois
;
sorted_rois
.
mutable_data
<
T
>
({
real_post_num
,
kBBoxSize
},
dev_ctx
.
GetPlace
());
...
...
@@ -167,17 +185,29 @@ class GPUCollectFpnProposalsOpKernel : public framework::OpKernel<T> {
out_id_t
.
mutable_data
<
int
>
({
real_post_num
},
dev_ctx
.
GetPlace
());
// Determine temporary device storage requirements
temp_storage_bytes
=
0
;
#ifdef PADDLE_WITH_HIP
hipcub
::
DeviceRadixSort
::
SortPairs
<
int
,
int
>
(
nullptr
,
temp_storage_bytes
,
sorted_batch_id
.
data
<
int
>
(),
out_id_data
,
batch_idx_in
,
index_out_t
.
data
<
int
>
(),
real_post_num
);
#else
cub
::
DeviceRadixSort
::
SortPairs
<
int
,
int
>
(
nullptr
,
temp_storage_bytes
,
sorted_batch_id
.
data
<
int
>
(),
out_id_data
,
batch_idx_in
,
index_out_t
.
data
<
int
>
(),
real_post_num
);
#endif
// Allocate temporary storage
d_temp_storage
=
memory
::
Alloc
(
place
,
temp_storage_bytes
);
// Run sorting operation
// sort batch_id to get corresponding index
// Run sorting operation
// sort batch_id to get corresponding index
#ifdef PADDLE_WITH_HIP
hipcub
::
DeviceRadixSort
::
SortPairs
<
int
,
int
>
(
d_temp_storage
->
ptr
(),
temp_storage_bytes
,
sorted_batch_id
.
data
<
int
>
(),
out_id_data
,
batch_idx_in
,
index_out_t
.
data
<
int
>
(),
real_post_num
);
#else
cub
::
DeviceRadixSort
::
SortPairs
<
int
,
int
>
(
d_temp_storage
->
ptr
(),
temp_storage_bytes
,
sorted_batch_id
.
data
<
int
>
(),
out_id_data
,
batch_idx_in
,
index_out_t
.
data
<
int
>
(),
real_post_num
);
#endif
GPUGather
<
T
>
(
dev_ctx
,
sorted_rois
,
index_out_t
,
fpn_rois
);
...
...
paddle/fluid/operators/detection/distribute_fpn_proposals_op.cu
浏览文件 @
cced930b
...
...
@@ -12,8 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#i
nclude <paddle/fluid/memory/allocation/allocator.h>
#i
fdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
#endif
#include <paddle/fluid/memory/allocation/allocator.h>
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/detection/bbox_util.h"
#include "paddle/fluid/operators/detection/distribute_fpn_proposals_op.h"
...
...
@@ -143,24 +149,42 @@ class GPUDistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
// Determine temporary device storage requirements
size_t
temp_storage_bytes
=
0
;
#ifdef PADDLE_WITH_HIP
hipcub
::
DeviceRadixSort
::
SortPairs
<
int
,
int
>
(
nullptr
,
temp_storage_bytes
,
target_lvls_data
,
keys_out
,
idx_in
,
idx_out
,
roi_num
);
#else
cub
::
DeviceRadixSort
::
SortPairs
<
int
,
int
>
(
nullptr
,
temp_storage_bytes
,
target_lvls_data
,
keys_out
,
idx_in
,
idx_out
,
roi_num
);
#endif
// Allocate temporary storage
auto
d_temp_storage
=
memory
::
Alloc
(
place
,
temp_storage_bytes
);
// Run sorting operation
// sort target level to get corresponding index
// Run sorting operation
// sort target level to get corresponding index
#ifdef PADDLE_WITH_HIP
hipcub
::
DeviceRadixSort
::
SortPairs
<
int
,
int
>
(
d_temp_storage
->
ptr
(),
temp_storage_bytes
,
target_lvls_data
,
keys_out
,
idx_in
,
idx_out
,
roi_num
);
#else
cub
::
DeviceRadixSort
::
SortPairs
<
int
,
int
>
(
d_temp_storage
->
ptr
(),
temp_storage_bytes
,
target_lvls_data
,
keys_out
,
idx_in
,
idx_out
,
roi_num
);
#endif
int
*
restore_idx_data
=
restore_index
->
mutable_data
<
int
>
({
roi_num
,
1
},
dev_ctx
.
GetPlace
());
// sort current index to get restore index
// sort current index to get restore index
#ifdef PADDLE_WITH_HIP
hipcub
::
DeviceRadixSort
::
SortPairs
<
int
,
int
>
(
d_temp_storage
->
ptr
(),
temp_storage_bytes
,
idx_out
,
keys_out
,
idx_in
,
restore_idx_data
,
roi_num
);
#else
cub
::
DeviceRadixSort
::
SortPairs
<
int
,
int
>
(
d_temp_storage
->
ptr
(),
temp_storage_bytes
,
idx_out
,
keys_out
,
idx_in
,
restore_idx_data
,
roi_num
);
#endif
int
start
=
0
;
auto
multi_rois_num
=
ctx
.
MultiOutput
<
Tensor
>
(
"MultiLevelRoIsNum"
);
...
...
paddle/fluid/operators/detection/sigmoid_focal_loss_op.cu
浏览文件 @
cced930b
...
...
@@ -11,7 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "cub/cub.cuh"
#include "paddle/fluid/operators/detection/sigmoid_focal_loss_op.h"
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/platform/cuda_primitives.h"
...
...
paddle/fluid/operators/detection/target_assign_op.h
浏览文件 @
cced930b
...
...
@@ -107,7 +107,7 @@ class TargetAssignKernel : public framework::OpKernel<T> {
int64_t
k
=
x
->
dims
()[
2
];
auto
x_lod
=
x
->
lod
().
back
();
#if defined(PADDLE_WITH_CUDA)
#if defined(PADDLE_WITH_CUDA)
|| defined(PADDLE_WITH_HIP)
size_t
*
x_lod_data
=
x_lod
.
MutableData
(
ctx
.
GetPlace
());
#else
size_t
*
x_lod_data
=
x_lod
.
data
();
...
...
@@ -129,7 +129,7 @@ class TargetAssignKernel : public framework::OpKernel<T> {
"TargetAssignOp input(NegIndices) needs 1 level of LoD"
));
const
int
*
neg_idx_data
=
neg_indices
->
data
<
int
>
();
auto
neg_lod
=
neg_indices
->
lod
().
back
();
#if defined(PADDLE_WITH_CUDA)
#if defined(PADDLE_WITH_CUDA)
|| defined(PADDLE_WITH_HIP)
size_t
*
neg_lod_data
=
neg_lod
.
MutableData
(
ctx
.
GetPlace
());
#else
size_t
*
neg_lod_data
=
neg_lod
.
data
();
...
...
paddle/fluid/operators/distributed/CMakeLists.txt
浏览文件 @
cced930b
...
...
@@ -61,7 +61,7 @@ cc_library(parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory)
cc_library
(
parameter_recv SRCS parameter_recv.cc DEPS sendrecvop_rpc memory
)
cc_library
(
communicator SRCS communicator.cc DEPS scope selected_rows tensor variable_helper selected_rows_functor simple_threadpool parameter_send parameter_recv generator
)
cc_test
(
communicator_test SRCS communicator_test.cc DEPS communicator
)
if
(
WITH_GPU
)
if
(
WITH_GPU
OR WITH_ROCM
)
cc_test
(
collective_server_test SRCS collective_server_test.cc
DEPS sendrecvop_rpc executor
${
RPC_DEPS
}
selected_rows_functor scope math_function
)
...
...
paddle/fluid/operators/distributed/brpc/brpc_sendrecvop_utils.cc
浏览文件 @
cced930b
...
...
@@ -15,6 +15,9 @@ limitations under the License. */
#ifdef PADDLE_WITH_NCCL
#include <nccl.h>
#endif
#ifdef PADDLE_WITH_RCCL
#include <rccl.h>
#endif
#include <sys/time.h>
#include <limits>
#include <memory>
...
...
@@ -144,7 +147,7 @@ void SerializeToIOBuf(const std::string& name, framework::Variable* var,
}
else
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
request
->
set_type
(
::
sendrecv
::
SELECTED_ROWS
);
payload
.
reset
(
new
TensorPayload
(
GetSelectedRowsPayload
(
var
,
ctx
,
request
)));
#if
def PADDLE_WITH_NCCL
#if
defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
}
else
if
(
var
->
IsType
<
ncclUniqueId
>
())
{
request
->
set_type
(
::
sendrecv
::
NCCL_ID
);
const
ncclUniqueId
&
uid
=
var
->
Get
<
ncclUniqueId
>
();
...
...
@@ -172,7 +175,7 @@ void SerializeToIOBuf(const std::string& name, framework::Variable* var,
static_cast
<
const
char
*>
(
payload
->
ptr
()),
payload
->
memory_size
());
}
else
{
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
IOBufWriter
::
AppendZeroCopy
(
name
,
iobuf
,
::
sendrecv
::
VariableMessage
::
kSerializedFieldNumber
,
static_cast
<
const
char
*>
(
payload
->
ptr
()),
payload
->
memory_size
(),
...
...
paddle/fluid/operators/distributed/brpc/brpc_serde_test.cc
浏览文件 @
cced930b
...
...
@@ -159,7 +159,7 @@ void RunTestLodTensor(platform::Place place) {
TEST
(
LodTensor
,
Run
)
{
platform
::
CPUPlace
place
;
RunTestLodTensor
(
place
);
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
platform
::
CUDAPlace
gpu
(
0
);
RunTestLodTensor
(
gpu
);
#endif
...
...
@@ -168,7 +168,7 @@ TEST(LodTensor, Run) {
TEST
(
SelectedRows
,
Run
)
{
platform
::
CPUPlace
place
;
RunSerdeTestSelectedRows
(
place
);
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
platform
::
CUDAPlace
gpu
;
RunSerdeTestSelectedRows
(
gpu
);
#endif
...
...
paddle/fluid/operators/distributed/grpc/grpc_serde.cc
浏览文件 @
cced930b
...
...
@@ -15,6 +15,9 @@ limitations under the License. */
#ifdef PADDLE_WITH_NCCL
#include <nccl.h>
#endif
#ifdef PADDLE_WITH_RCCL
#include <rccl.h>
#endif
#include <limits>
#include <memory>
#include "grpcpp/impl/codegen/byte_buffer.h"
...
...
@@ -75,7 +78,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
}
else
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
request
.
set_type
(
::
sendrecv
::
SELECTED_ROWS
);
payload
=
new
TensorPayload
(
GetSelectedRowsPayload
(
var
,
ctx
,
&
request
));
#if
def PADDLE_WITH_NCCL
#if
defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
}
else
if
(
var
->
IsType
<
ncclUniqueId
>
())
{
request
.
set_type
(
::
sendrecv
::
NCCL_ID
);
#endif
...
...
@@ -91,7 +94,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
e
.
WriteRawBytes
(
std
::
string
(
header
.
data
(),
header
.
size
()));
// NCCLID is copied directly to the message, return bytebuffer
// with only one slice if serializing NCCLID.
#if
def PADDLE_WITH_NCCL
#if
defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
if
(
var
->
IsType
<
ncclUniqueId
>
())
{
e
.
WriteVarlengthBeginning
(
VarMsg
::
kSerializedFieldNumber
,
NCCL_UNIQUE_ID_BYTES
);
...
...
paddle/fluid/operators/distributed/grpc/grpc_serde_test.cc
浏览文件 @
cced930b
...
...
@@ -206,7 +206,7 @@ TEST(LodTensor, Run) {
platform
::
CPUPlace
place
;
RunTestLodTensor
(
place
);
RunTestLodTensor
(
place
,
1
);
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
platform
::
CUDAPlace
gpu
(
0
);
RunTestLodTensor
(
gpu
);
RunTestLodTensor
(
gpu
,
1
);
...
...
@@ -217,7 +217,7 @@ TEST(SelectedRows, Run) {
platform
::
CPUPlace
place
;
RunSerdeTestSelectedRows
(
place
);
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
platform
::
CUDAPlace
gpu
;
RunSerdeTestSelectedRows
(
gpu
);
#endif
...
...
paddle/fluid/operators/distributed/parameter_prefetch.cc
浏览文件 @
cced930b
...
...
@@ -281,7 +281,7 @@ void prefetchs(const std::vector<std::string> &id_var_names,
}
}
}
else
{
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
std
::
vector
<
float
>
ids_value_vec
(
ids_size
*
vec_dim_1
);
for
(
auto
idx
=
0
;
idx
<
static_cast
<
int
>
(
ids_size
);
idx
++
)
{
const
auto
&
id
=
ids
[
idx
];
...
...
paddle/fluid/operators/distributed/sendrecvop_utils.cc
浏览文件 @
cced930b
...
...
@@ -39,7 +39,7 @@ using VarMsg = sendrecv::VariableMessage;
static
TensorPayload
GetCommunicationAllocationFromTensor
(
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Tensor
&
tensor
)
{
if
(
is_gpu_place
(
ctx
.
GetPlace
()))
{
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PADDLE_ENFORCE_EQ
(
is_gpu_place
(
tensor
.
place
()),
true
,
platform
::
errors
::
PreconditionNotMet
(
"Please run in gpu place."
));
...
...
paddle/fluid/operators/distributed/variable_response.cc
浏览文件 @
cced930b
...
...
@@ -33,7 +33,7 @@ bool VariableResponse::ReadRaw(::google::protobuf::io::CodedInputStream* input,
int
total_written
=
0
;
if
(
platform
::
is_gpu_place
(
place
))
{
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto
&
gpu_dev_ctx
=
static_cast
<
const
platform
::
CUDADeviceContext
&>
(
dev_ctx
);
platform
::
CPUPlace
cpu
;
...
...
@@ -62,7 +62,7 @@ bool VariableResponse::ReadRaw(::google::protobuf::io::CodedInputStream* input,
gpu_dev_ctx
.
Wait
();
#else
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"Unexpected branch, please compile with
PADDLE_WITH_CUDA
"
));
"Unexpected branch, please compile with
WITH_GPU or WITH_ROCM
"
));
#endif
return
true
;
}
else
if
(
platform
::
is_xpu_place
(
place
))
{
...
...
@@ -221,7 +221,7 @@ bool VariableResponse::ProcSerializedField(
platform
::
errors
::
PreconditionNotMet
(
"meta info should be got first!"
));
if
(
meta_
.
type
()
==
sendrecv
::
NCCL_ID
)
{
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto
*
var
=
scope_
->
FindVar
(
meta_
.
varname
());
if
(
var
!=
nullptr
)
{
ncclUniqueId
*
id
=
var
->
GetMutable
<
ncclUniqueId
>
();
...
...
paddle/fluid/operators/metrics/accuracy_op.cu
浏览文件 @
cced930b
...
...
@@ -43,8 +43,19 @@ __global__ void AccuracyCudaKernel(const int N, const int D,
total
[
threadIdx
.
x
]
=
count
;
__syncthreads
();
// reduce the count with init value 0, and output accuracy.
// reduce the count with init value 0, and output accuracy.
#ifdef PADDLE_WITH_CUDA
int
result
=
thrust
::
reduce
(
thrust
::
device
,
total
,
total
+
BlockSize
,
0
);
#else
// HIP thrust::reduce not support __device__
for
(
int
s
=
BlockSize
/
2
;
s
>
0
;
s
>>=
1
)
{
if
(
threadIdx
.
x
<
s
)
{
total
[
threadIdx
.
x
]
+=
total
[
threadIdx
.
x
+
s
];
}
__syncthreads
();
}
int
result
=
total
[
0
];
#endif
if
(
threadIdx
.
x
==
0
)
{
*
correct_data
=
result
;
*
accuracy
=
static_cast
<
float
>
(
result
)
/
static_cast
<
float
>
(
N
);
...
...
paddle/fluid/operators/metrics/auc_op.cu
浏览文件 @
cced930b
...
...
@@ -130,6 +130,7 @@ class AucCUDAKernel : public framework::OpKernel<T> {
auto
*
pos_in_data
=
stat_pos_in_tensor
->
data
<
int64_t
>
();
auto
*
stat_neg_in_tensor
=
ctx
.
Input
<
Tensor
>
(
"StatNeg"
);
auto
*
neg_in_data
=
stat_neg_in_tensor
->
data
<
int64_t
>
();
#ifdef PADDLE_WITH_CUDA
if
(
stat_pos_in_tensor
!=
stat_pos
)
{
cudaMemcpy
(
origin_stat_pos
,
pos_in_data
,
((
1
+
slide_steps
)
*
(
num_thresholds
+
1
)
+
...
...
@@ -144,6 +145,22 @@ class AucCUDAKernel : public framework::OpKernel<T> {
sizeof
(
int64_t
),
cudaMemcpyDeviceToDevice
);
}
#else
if
(
stat_pos_in_tensor
!=
stat_pos
)
{
hipMemcpy
(
origin_stat_pos
,
pos_in_data
,
((
1
+
slide_steps
)
*
(
num_thresholds
+
1
)
+
(
slide_steps
>
0
?
1
:
0
))
*
sizeof
(
int64_t
),
hipMemcpyDeviceToDevice
);
}
if
(
stat_neg_in_tensor
!=
stat_neg
)
{
hipMemcpy
(
origin_stat_neg
,
neg_in_data
,
((
1
+
slide_steps
)
*
(
num_thresholds
+
1
)
+
(
slide_steps
>
0
?
1
:
0
))
*
sizeof
(
int64_t
),
hipMemcpyDeviceToDevice
);
}
#endif
statAuc
(
ctx
,
label
,
predict
,
num_thresholds
,
slide_steps
,
origin_stat_pos
,
origin_stat_neg
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录