Unverified commit a6d2fddb, authored by ronnywang, committed by GitHub

refine structure for cuda and rocm (#37202)

* refine structure for cuda and rocm

* update

* update

* update

* update
Parent 9ccb6228
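Every hunk in this commit applies one pattern: call sites stop branching on PADDLE_WITH_HIP themselves and instead use unified helpers (PADDLE_ENFORCE_GPU_SUCCESS, platform::GpuStreamSync, platform::GetGPUDeviceCount) plus headers that now live under paddle/fluid/platform/device/gpu/. Below is a minimal, self-contained sketch of that wrapper pattern, assuming only the CUDA or ROCm runtime headers; it is not the actual Paddle implementation, and ENFORCE_GPU_SUCCESS and the gpu* aliases are illustrative stand-ins for Paddle's own macros and typedefs.

#include <cstdio>
#include <cstdlib>

#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
using gpuStream_t = hipStream_t;
using gpuError_t = hipError_t;
#define gpuSuccess hipSuccess
#define gpuStreamSynchronize hipStreamSynchronize
#define gpuGetLastError hipGetLastError
#else
#include <cuda_runtime.h>
using gpuStream_t = cudaStream_t;
using gpuError_t = cudaError_t;
#define gpuSuccess cudaSuccess
#define gpuStreamSynchronize cudaStreamSynchronize
#define gpuGetLastError cudaGetLastError
#endif

// Stand-in for PADDLE_ENFORCE_GPU_SUCCESS: abort with a message on failure.
#define ENFORCE_GPU_SUCCESS(expr)                            \
  do {                                                       \
    gpuError_t err = (expr);                                 \
    if (err != gpuSuccess) {                                 \
      std::fprintf(stderr, "GPU call failed: %s\n", #expr);  \
      std::abort();                                          \
    }                                                        \
  } while (0)

namespace platform {

// Unified stream synchronization: the same call works on CUDA and ROCm builds,
// which is what lets the #ifdef blocks in the hunks below collapse to one line.
inline void GpuStreamSync(gpuStream_t stream) {
  ENFORCE_GPU_SUCCESS(gpuStreamSynchronize(stream));
}

// Unified "last error" query, mirroring platform::GpuGetLastError() in the diff.
inline gpuError_t GpuGetLastError() { return gpuGetLastError(); }

}  // namespace platform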
@@ -17,11 +17,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator_kernel_configs.h"
-#ifdef PADDLE_WITH_HIP
-#include "paddle/fluid/platform/miopen_helper.h"
-#else
-#include "paddle/fluid/platform/cudnn_helper.h"
-#endif
+#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
 namespace paddle {
 namespace framework {
...
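The hunk above folds the cuDNN/MIOpen include choice into a single gpu_dnn.h header. A hedged sketch of such a dispatch header follows; it is not the actual contents of Paddle's gpu_dnn.h, and the include paths are simply the pre-move helpers shown above, used only to illustrate the idea.

// gpu_dnn.h-style dispatch header (illustrative only): callers include one
// header and the backend-specific DNN helper is chosen at compile time.
#pragma once

#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/miopen_helper.h"  // ROCm / MIOpen build
#else
#include "paddle/fluid/platform/cudnn_helper.h"   // CUDA / cuDNN build
#endif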
@@ -291,13 +291,9 @@ void AllReduceOpHandle::SyncNCCLAllReduce() {
 nccl_ctxs_->GetRunEnvNCCLCtx(run_order_, use_hierarchical_allreduce_);
 auto &nccl_ctx = nccl_ctxs->at(dev_id);
 auto stream = nccl_ctx.stream();
-#ifdef PADDLE_WITH_HIP
-PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream));
-PADDLE_ENFORCE_CUDA_SUCCESS(hipGetLastError());
-#else
-PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
-PADDLE_ENFORCE_CUDA_SUCCESS(cudaGetLastError());
-#endif
+platform::GpuStreamSync(stream);
+PADDLE_ENFORCE_GPU_SUCCESS(platform::GpuGetLastError());
 }
 }
 }
...
@@ -33,7 +33,7 @@ class NCCLCommunicator;
 } // namespace paddle
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/framework/details/nccl_op_handle.h"
-#include "paddle/fluid/platform/nccl_helper.h"
+#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #elif defined(PADDLE_WITH_XPU_BKCL)
 #include "paddle/fluid/framework/details/bkcl_op_handle.h"
 #include "paddle/fluid/platform/device/xpu/bkcl_helper.h"
...
@@ -111,7 +111,7 @@ void BroadcastOpHandle::BroadcastOneVar(
 broadcast_calls.emplace_back(
 [send_recv_buffer, numel, type, root_id, &nccl_ctx] {
-PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBcast(
+PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclBcast(
 send_recv_buffer, numel, static_cast<ncclDataType_t>(type),
 root_id, nccl_ctx.comm_, nccl_ctx.stream()));
 });
...
@@ -44,7 +44,7 @@ struct BKCLContextMap;
 } // namespace paddle
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-#include "paddle/fluid/platform/nccl_helper.h"
+#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #elif defined(PADDLE_WITH_XPU_BKCL)
 #include "paddle/fluid/platform/device/xpu/bkcl_helper.h"
 #endif
...
@@ -95,7 +95,7 @@ struct TestBroadcastOpHandle {
 #endif
 } else if (use_device_ == p::kCUDA) {
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-int count = p::GetCUDADeviceCount();
+int count = p::GetGPUDeviceCount();
 if (count <= 1) {
 LOG(WARNING) << "Cannot test multi-gpu Broadcast, because the CUDA "
 "device count is "
...
@@ -40,7 +40,7 @@ class NCCLCommunicator;
 } // namespace paddle
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-#include "paddle/fluid/platform/nccl_helper.h"
+#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #elif defined(PADDLE_WITH_XPU) && defined(PADDLE_WITH_XPU_BKCL)
 #include "paddle/fluid/platform/device/xpu/bkcl_helper.h"
 #endif
...
@@ -49,10 +49,10 @@ EagerDeletionOpHandle::EagerDeletionOpHandle(
 platform::CUDADeviceGuard guard(
 BOOST_GET_CONST(platform::CUDAPlace, place).device);
 #ifdef PADDLE_WITH_HIP
-PADDLE_ENFORCE_CUDA_SUCCESS(
+PADDLE_ENFORCE_GPU_SUCCESS(
 hipEventCreateWithFlags(&event_, hipEventDisableTiming));
 #else
-PADDLE_ENFORCE_CUDA_SUCCESS(
+PADDLE_ENFORCE_GPU_SUCCESS(
 cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
 #endif
 PADDLE_ENFORCE_NOT_NULL(event_, platform::errors::InvalidArgument(
@@ -75,9 +75,9 @@ EagerDeletionOpHandle::~EagerDeletionOpHandle() {
 auto gpu_place = BOOST_GET_CONST(platform::CUDAPlace, dev_ctx_->GetPlace());
 platform::CUDADeviceGuard guard(gpu_place.device);
 #ifdef PADDLE_WITH_HIP
-PADDLE_ENFORCE_CUDA_SUCCESS(hipEventDestroy(event_));
+PADDLE_ENFORCE_GPU_SUCCESS(hipEventDestroy(event_));
 #else
-PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventDestroy(event_));
+PADDLE_ENFORCE_GPU_SUCCESS(cudaEventDestroy(event_));
 #endif
 }
 #endif
@@ -160,12 +160,12 @@ void EagerDeletionOpHandle::ClearGarbages(
 reinterpret_cast<StreamGarbageCollector *>(gc_)->stream();
 auto callback_func = [=]() {
 #ifdef PADDLE_WITH_HIP
-PADDLE_ENFORCE_CUDA_SUCCESS(hipEventRecord(event_, compute_stream));
-PADDLE_ENFORCE_CUDA_SUCCESS(
+PADDLE_ENFORCE_GPU_SUCCESS(hipEventRecord(event_, compute_stream));
+PADDLE_ENFORCE_GPU_SUCCESS(
 hipStreamWaitEvent(callback_stream, event_, 0));
 #else
-PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(event_, compute_stream));
-PADDLE_ENFORCE_CUDA_SUCCESS(
+PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(event_, compute_stream));
+PADDLE_ENFORCE_GPU_SUCCESS(
 cudaStreamWaitEvent(callback_stream, event_, 0));
 #endif
 };
...
@@ -55,9 +55,9 @@ FusedAllReduceOpHandle::~FusedAllReduceOpHandle() {
 auto destroy_event = [](gpuEvent_t event) {
 if (event == nullptr) return;
 #ifdef PADDLE_WITH_HIP
-PADDLE_ENFORCE_CUDA_SUCCESS(hipEventDestroy(event));
+PADDLE_ENFORCE_GPU_SUCCESS(hipEventDestroy(event));
 #else
-PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventDestroy(event));
+PADDLE_ENFORCE_GPU_SUCCESS(cudaEventDestroy(event));
 #endif
 };
 destroy_event(start_event_);
@@ -87,10 +87,10 @@ void FusedAllReduceOpHandle::RunImpl() {
 auto create_event = [](gpuEvent_t *event) {
 if (*event) return;
 #ifdef PADDLE_WITH_HIP
-PADDLE_ENFORCE_CUDA_SUCCESS(
+PADDLE_ENFORCE_GPU_SUCCESS(
 hipEventCreateWithFlags(event, hipEventDisableTiming));
 #else
-PADDLE_ENFORCE_CUDA_SUCCESS(
+PADDLE_ENFORCE_GPU_SUCCESS(
 cudaEventCreateWithFlags(event, cudaEventDisableTiming));
 #endif
 };
@@ -109,12 +109,12 @@ void FusedAllReduceOpHandle::RunImpl() {
 auto &nccl_ctx = flat_nccl_ctxs->at(gpu_place.device);
 nccl_stream = nccl_ctx.stream();
 #ifdef PADDLE_WITH_HIP
-PADDLE_ENFORCE_CUDA_SUCCESS(hipEventRecord(start_event_, compute_stream));
-PADDLE_ENFORCE_CUDA_SUCCESS(
+PADDLE_ENFORCE_GPU_SUCCESS(hipEventRecord(start_event_, compute_stream));
+PADDLE_ENFORCE_GPU_SUCCESS(
 hipStreamWaitEvent(nccl_stream, start_event_, 0));
 #else
-PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(start_event_, compute_stream));
-PADDLE_ENFORCE_CUDA_SUCCESS(
+PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(start_event_, compute_stream));
+PADDLE_ENFORCE_GPU_SUCCESS(
 cudaStreamWaitEvent(nccl_stream, start_event_, 0));
 #endif
 } else {
@@ -169,12 +169,12 @@ void FusedAllReduceOpHandle::RunImpl() {
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 if (FLAGS_allreduce_record_one_event) {
 #ifdef PADDLE_WITH_HIP
-PADDLE_ENFORCE_CUDA_SUCCESS(hipEventRecord(end_event_, nccl_stream));
-PADDLE_ENFORCE_CUDA_SUCCESS(
+PADDLE_ENFORCE_GPU_SUCCESS(hipEventRecord(end_event_, nccl_stream));
+PADDLE_ENFORCE_GPU_SUCCESS(
 hipStreamWaitEvent(compute_stream, end_event_, 0));
 #else
-PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(end_event_, nccl_stream));
-PADDLE_ENFORCE_CUDA_SUCCESS(
+PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(end_event_, nccl_stream));
+PADDLE_ENFORCE_GPU_SUCCESS(
 cudaStreamWaitEvent(compute_stream, end_event_, 0));
 #endif
 }
...
@@ -35,7 +35,7 @@ class NCCLCommunicator;
 } // namespace paddle
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/framework/details/nccl_op_handle.h"
-#include "paddle/fluid/platform/nccl_helper.h"
+#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #elif defined(PADDLE_WITH_XPU_BKCL)
 #include "paddle/fluid/platform/device/xpu/bkcl_helper.h"
 #endif
...
@@ -37,7 +37,7 @@ struct NCCLContextMap;
 } // namespace paddle
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-#include "paddle/fluid/platform/nccl_helper.h"
+#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #endif
 namespace paddle {
...
@@ -48,7 +48,7 @@ struct TestGatherOpHandle {
 void InitCtxOnGpu(bool use_gpu) {
 if (use_gpu) {
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-int count = p::GetCUDADeviceCount();
+int count = p::GetGPUDeviceCount();
 if (count <= 1) {
 LOG(WARNING) << "Cannot test multi-gpu Broadcast, because the CUDA "
 "device count is "
...
@@ -35,7 +35,7 @@ class NCCLCommunicator;
 } // namespace paddle
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/framework/details/nccl_op_handle.h"
-#include "paddle/fluid/platform/nccl_helper.h"
+#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #endif
 namespace paddle {
...
@@ -40,7 +40,7 @@ static std::vector<std::mutex>& multi_op_var2gpu_str_mutex() {
 }
 static void InitMultiGPUOpVarMap() {
-int dev_count = platform::GetCUDADeviceCount();
+int dev_count = platform::GetGPUDeviceCount();
 PADDLE_ENFORCE_GT(dev_count, 0,
 platform::errors::NotFound(
 "cuda device must > 0, now dev_count=%d", dev_count));
@@ -161,11 +161,11 @@ void TensorCheckerVisitor<platform::CUDADeviceContext>::apply(
 op_var));
 #ifdef __HIPCC__
-PADDLE_ENFORCE_CUDA_SUCCESS(
+PADDLE_ENFORCE_GPU_SUCCESS(
 hipMemcpyAsync(gpu_str_ptr, iter->first.c_str(), op_var.length() + 1,
 hipMemcpyHostToDevice, dev_ctx->stream()));
 #else
-PADDLE_ENFORCE_CUDA_SUCCESS(
+PADDLE_ENFORCE_GPU_SUCCESS(
 cudaMemcpyAsync(gpu_str_ptr, iter->first.c_str(), op_var.length() + 1,
 cudaMemcpyHostToDevice, dev_ctx->stream()));
 #endif
...
@@ -27,7 +27,7 @@
 #ifdef PADDLE_WITH_HIP
 #include "paddle/fluid/platform/dynload/rccl.h"
 #endif
-#include "paddle/fluid/platform/nccl_helper.h"
+#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 DECLARE_bool(sync_nccl_allreduce);
@@ -52,16 +52,16 @@ class NCCLOpHandleBase : public OpHandleBase {
 virtual ~NCCLOpHandleBase() {
 for (auto& ev : inter_events_) {
 #ifdef PADDLE_WITH_HIP
-PADDLE_ENFORCE_CUDA_SUCCESS(hipEventDestroy(ev.second));
+PADDLE_ENFORCE_GPU_SUCCESS(hipEventDestroy(ev.second));
 #else
-PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventDestroy(ev.second));
+PADDLE_ENFORCE_GPU_SUCCESS(cudaEventDestroy(ev.second));
 #endif
 }
 for (auto& ev : exter_events_) {
 #ifdef PADDLE_WITH_HIP
-PADDLE_ENFORCE_CUDA_SUCCESS(hipEventDestroy(ev.second));
+PADDLE_ENFORCE_GPU_SUCCESS(hipEventDestroy(ev.second));
 #else
-PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventDestroy(ev.second));
+PADDLE_ENFORCE_GPU_SUCCESS(cudaEventDestroy(ev.second));
 #endif
 }
 }
@@ -109,14 +109,14 @@ class NCCLOpHandleBase : public OpHandleBase {
 platform::SetDeviceId(dev_id);
 #ifdef PADDLE_WITH_HIP
-PADDLE_ENFORCE_CUDA_SUCCESS(hipEventCreateWithFlags(
+PADDLE_ENFORCE_GPU_SUCCESS(hipEventCreateWithFlags(
 &inter_events_[dev_id], hipEventDisableTiming));
-PADDLE_ENFORCE_CUDA_SUCCESS(hipEventCreateWithFlags(
+PADDLE_ENFORCE_GPU_SUCCESS(hipEventCreateWithFlags(
 &exter_events_[dev_id], hipEventDisableTiming));
 #else
-PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventCreateWithFlags(
+PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreateWithFlags(
 &inter_events_[dev_id], cudaEventDisableTiming));
-PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventCreateWithFlags(
+PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreateWithFlags(
 &exter_events_[dev_id], cudaEventDisableTiming));
 #endif
 VLOG(10) << "Create events on dev_id:" << dev_id
@@ -142,7 +142,7 @@ class NCCLOpHandleBase : public OpHandleBase {
 << ", dev_id:" << dev_id << ", dtype:" << datatype
 << ", place:" << place;
-PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce(
+PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
 sendbuff, recvbuff, count, datatype, op, comm, stream));
 }
@@ -192,7 +192,7 @@ class NCCLOpHandleBase : public OpHandleBase {
 << ", dtype:" << datatype << ", place:" << place
 << ", stream:" << stream;
-PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclReduce(
+PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduce(
 sendbuff, recvbuff, count, datatype, ncclSum, 0, comm, stream));
 #ifdef PADDLE_WITH_HIP
@@ -202,11 +202,7 @@ class NCCLOpHandleBase : public OpHandleBase {
 #endif
 if (FLAGS_sync_nccl_allreduce) {
-#ifdef PADDLE_WITH_HIP
-PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream));
-#else
-PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
-#endif
+platform::GpuStreamSync(stream);
 }
 }
@@ -230,26 +226,21 @@ class NCCLOpHandleBase : public OpHandleBase {
 #ifdef PADDLE_WITH_HIP
 hipStreamWaitEvent(stream, inter_events_.at(dev_id), 0);
-PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce(
+PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
 sendbuff, recvbuff, count, datatype, op, comm, stream));
 hipEventRecord(exter_events_.at(dev_id), stream);
-if (FLAGS_sync_nccl_allreduce) {
-PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream));
-}
 #else
 cudaStreamWaitEvent(stream, inter_events_.at(dev_id), 0);
-PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce(
+PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
 sendbuff, recvbuff, count, datatype, op, comm, stream));
 cudaEventRecord(exter_events_.at(dev_id), stream);
+#endif
 if (FLAGS_sync_nccl_allreduce) {
-PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
+platform::GpuStreamSync(stream);
 }
-#endif
 }
 void InterBroadCast(platform::Place place, void* sendbuff, size_t count,
@@ -269,7 +260,7 @@ class NCCLOpHandleBase : public OpHandleBase {
 #else
 cudaStreamWaitEvent(stream, exter_events_.at(dev_id), 0);
 #endif
-PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBcast(
+PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclBcast(
 sendbuff, count, datatype, 0, comm, stream));
 }
...
@@ -35,9 +35,9 @@ OpHandleBase::~OpHandleBase() PADDLE_MAY_THROW {
 for (auto &ev : events_) {
 if (ev.second) {
 #ifdef PADDLE_WITH_HIP
-PADDLE_ENFORCE_CUDA_SUCCESS(hipEventDestroy(ev.second));
+PADDLE_ENFORCE_GPU_SUCCESS(hipEventDestroy(ev.second));
 #else
-PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventDestroy(ev.second));
+PADDLE_ENFORCE_GPU_SUCCESS(cudaEventDestroy(ev.second));
 #endif
 }
 }
@@ -50,10 +50,10 @@ void OpHandleBase::InitCUDA() {
 int dev_id = BOOST_GET_CONST(platform::CUDAPlace, p.first).device;
 platform::SetDeviceId(dev_id);
 #ifdef PADDLE_WITH_HIP
-PADDLE_ENFORCE_CUDA_SUCCESS(
+PADDLE_ENFORCE_GPU_SUCCESS(
 hipEventCreateWithFlags(&events_[dev_id], hipEventDisableTiming));
 #else
-PADDLE_ENFORCE_CUDA_SUCCESS(
+PADDLE_ENFORCE_GPU_SUCCESS(
 cudaEventCreateWithFlags(&events_[dev_id], cudaEventDisableTiming));
 #endif
 }
@@ -182,9 +182,9 @@ void OpHandleBase::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) {
 static_cast<platform::CUDADeviceContext *>(waited_ctx)->stream();
 for (auto &ev : events_) {
 #ifdef PADDLE_WITH_HIP
-PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamWaitEvent(stream, ev.second, 0));
+PADDLE_ENFORCE_GPU_SUCCESS(hipStreamWaitEvent(stream, ev.second, 0));
 #else
-PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamWaitEvent(stream, ev.second, 0));
+PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamWaitEvent(stream, ev.second, 0));
 #endif
 }
 }
@@ -221,10 +221,10 @@ void OpHandleBase::WaitInputVarGenerated(bool wait_for_feed) {
 static_cast<platform::CUDADeviceContext *>(dev_ctxes_.at(place))
 ->stream();
 #ifdef PADDLE_WITH_HIP
-PADDLE_ENFORCE_CUDA_SUCCESS(
+PADDLE_ENFORCE_GPU_SUCCESS(
 hipStreamWaitEvent(stream, in_var_handle->GetEvent(), 0));
 #else
-PADDLE_ENFORCE_CUDA_SUCCESS(
+PADDLE_ENFORCE_GPU_SUCCESS(
 cudaStreamWaitEvent(stream, in_var_handle->GetEvent(), 0));
 #endif
 #else
@@ -250,11 +250,7 @@ void OpHandleBase::WaitInputVarGenerated(bool wait_for_feed) {
 auto stream =
 static_cast<platform::CUDADeviceContext *>(pool.Get(place))
 ->stream();
-#ifdef PADDLE_WITH_HIP
-PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream));
-#else
-PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
-#endif
+platform::GpuStreamSync(stream);
 #else
 PADDLE_THROW(platform::errors::PreconditionNotMet(
 "Not compiled with CUDA."));
@@ -279,10 +275,10 @@ void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) {
 dev_ctxes_.at(in_var_handle->place()))
 ->stream();
 #ifdef PADDLE_WITH_HIP
-PADDLE_ENFORCE_CUDA_SUCCESS(
+PADDLE_ENFORCE_GPU_SUCCESS(
 hipStreamWaitEvent(stream, in_var_handle->GetEvent(), 0));
 #else
-PADDLE_ENFORCE_CUDA_SUCCESS(
+PADDLE_ENFORCE_GPU_SUCCESS(
 cudaStreamWaitEvent(stream, in_var_handle->GetEvent(), 0));
 #endif
 #else
@@ -319,10 +315,10 @@ void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
 auto *cuda_dev_ctx = static_cast<platform::CUDADeviceContext *>(p.second);
 VLOG(10) << "cudadevicecontext:" << cuda_dev_ctx << ", dev_id:" << dev_id;
 #ifdef PADDLE_WITH_HIP
-PADDLE_ENFORCE_CUDA_SUCCESS(
+PADDLE_ENFORCE_GPU_SUCCESS(
 hipEventRecord(events_.at(dev_id), cuda_dev_ctx->stream()));
 #else
-PADDLE_ENFORCE_CUDA_SUCCESS(
+PADDLE_ENFORCE_GPU_SUCCESS(
 cudaEventRecord(events_.at(dev_id), cuda_dev_ctx->stream()));
 #endif
 }
...
@@ -193,7 +193,7 @@ void ReduceOpHandle::RunImpl() {
 size_t numel = static_cast<size_t>(lod_tensor.numel());
 all_reduce_calls.emplace_back(
 [buffer, recvbuffer, type, numel, root_id, &nccl_ctx] {
-PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclReduce(
+PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduce(
 buffer, recvbuffer, numel, static_cast<ncclDataType_t>(type),
 ncclSum, root_id, nccl_ctx.comm_, nccl_ctx.stream()));
 });
...
@@ -41,7 +41,7 @@ struct NCCLContextMap;
 } // namespace platform
 } // namespace paddle
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-#include "paddle/fluid/platform/nccl_helper.h"
+#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #elif defined(PADDLE_WITH_XPU_BKCL)
 #include "paddle/fluid/platform/device/xpu/bkcl_helper.h"
 #endif
...
@@ -59,7 +59,7 @@ struct TestReduceOpHandle {
 use_gpu_ = use_gpu;
 if (use_gpu) {
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-int count = p::GetCUDADeviceCount();
+int count = p::GetGPUDeviceCount();
 if (count <= 1) {
 LOG(WARNING) << "Cannot test multi-gpu Broadcast, because the CUDA "
 "device count is "
...
@@ -23,7 +23,7 @@
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/platform/cuda_device_guard.h"
-#include "paddle/fluid/platform/gpu_info.h"
+#include "paddle/fluid/platform/device/gpu/gpu_info.h"
 #include "paddle/fluid/platform/profiler.h"
 DECLARE_bool(sync_nccl_allreduce);
@@ -182,7 +182,7 @@ void SparseAllReduceOpHandle::RunImplEncoded() {
 << ", k:" << k << ", place:" << place << ", dtype:" << dtype;
 all_gather_calls.emplace_back([=] {
-PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
+PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather(
 in_tensor_buf, gather_buff, 2 * k, static_cast<ncclDataType_t>(dtype),
 comm, stream));
 });
...
@@ -21,7 +21,7 @@
 #include "paddle/fluid/framework/details/dgc_const_values.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/scope.h"
-#include "paddle/fluid/platform/nccl_helper.h"
+#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 namespace paddle {
 namespace framework {
...
@@ -54,7 +54,7 @@ class DeviceContext;
 } // namespace paddle
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-#include "paddle/fluid/platform/nccl_helper.h"
+#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #endif
 namespace paddle {
...
@@ -115,7 +115,7 @@ void TestMainLoop() {
 std::vector<platform::Place> places{platform::CPUPlace(),
 platform::CUDAPlace(0),
 platform::CUDAPinnedPlace()};
-if (platform::GetCUDADeviceCount() > 1) {
+if (platform::GetGPUDeviceCount() > 1) {
 places.emplace_back(platform::CUDAPlace(1));
 }
 #else
...
@@ -24,7 +24,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/framework/lod_tensor.h"
-#include "paddle/fluid/platform/gpu_info.h"
+#include "paddle/fluid/platform/device/gpu/gpu_info.h"
 #include "paddle/fluid/platform/place.h"
 #include "paddle/fluid/platform/timer.h"
...
@@ -19,7 +19,7 @@
 #include <memory>
 #include <numeric>
 #include "paddle/fluid/framework/lod_tensor.h"
-#include "paddle/fluid/platform/gpu_info.h"
+#include "paddle/fluid/platform/device/gpu/gpu_info.h"
 namespace paddle {
 namespace framework {
...
@@ -19,7 +19,7 @@
 #include <numeric>
 #include "paddle/fluid/framework/fleet/box_wrapper.h"
 #include "paddle/fluid/framework/lod_tensor.h"
-#include "paddle/fluid/platform/gpu_info.h"
+#include "paddle/fluid/platform/device/gpu/gpu_info.h"
 namespace paddle {
 namespace framework {
...
@@ -40,7 +40,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/data_set.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/scope.h"
-#include "paddle/fluid/platform/gpu_info.h"
+#include "paddle/fluid/platform/device/gpu/gpu_info.h"
 #include "paddle/fluid/platform/place.h"
 #include "paddle/fluid/platform/timer.h"
 #include "paddle/fluid/string/string_helper.h"
@@ -397,7 +397,7 @@ class BoxWrapper {
 if (nullptr != s_instance_) {
 VLOG(3) << "Begin InitializeGPU";
 std::vector<gpuStream_t*> stream_list;
-for (int i = 0; i < platform::GetCUDADeviceCount(); ++i) {
+for (int i = 0; i < platform::GetGPUDeviceCount(); ++i) {
 VLOG(3) << "before get context i[" << i << "]";
 platform::CUDADeviceContext* context =
 dynamic_cast<platform::CUDADeviceContext*>(
@@ -416,7 +416,7 @@ class BoxWrapper {
 slot_name_omited_in_feedpass_.insert(slot_name);
 }
 slot_vector_ = slot_vector;
-keys_tensor.resize(platform::GetCUDADeviceCount());
+keys_tensor.resize(platform::GetGPUDeviceCount());
 }
 }
...
@@ -740,10 +740,10 @@ void FleetWrapper::PushDenseVarsAsync(
 BOOST_GET_CONST(platform::CUDAPlace, place), g_data,
 sizeof(float) * count, stream);
 #ifdef PADDLE_WITH_HIP
-PADDLE_ENFORCE_CUDA_SUCCESS(hipEventRecord(event, stream));
+PADDLE_ENFORCE_GPU_SUCCESS(hipEventRecord(event, stream));
 hipEventSynchronize(event);
 #else
-PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(event, stream));
+PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(event, stream));
 cudaEventSynchronize(event);
 #endif
...
@@ -35,7 +35,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/variable_helper.h"
 #include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
 #ifdef PADDLE_WITH_HETERPS
-#include "paddle/fluid/platform/type_defs.h"
+#include "paddle/fluid/platform/device/gpu/gpu_types.h"
 #endif
 namespace paddle {
...
@@ -28,7 +28,7 @@ limitations under the License. */
 // #include "cudf/concurrent_unordered_map.cuh.h"
 #include "paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h"
 #ifdef PADDLE_WITH_HETERPS
-#include "paddle/fluid/platform/type_defs.h"
+#include "paddle/fluid/platform/device/gpu/gpu_types.h"
 namespace paddle {
 namespace framework {
...
@@ -347,7 +347,7 @@ void HeterComm<KeyType, ValType, GradType>::build_ps(int num, KeyType* h_keys,
 gpuStream_t streams[stream_num];
 for (int i = 0; i < stream_num; ++i) {
-PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreate(&(streams[i])));
+PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreate(&(streams[i])));
 auto d_k_buf = memory::AllocShared(place, chunk_size * sizeof(KeyType));
 auto d_v_buf = memory::AllocShared(place, chunk_size * sizeof(ValType));
 d_key_bufs.push_back(d_k_buf);
@@ -360,11 +360,11 @@ void HeterComm<KeyType, ValType, GradType>::build_ps(int num, KeyType* h_keys,
 while (cur_len < len) {
 cur_stream = cur_stream % stream_num;
 int tmp_len = cur_len + chunk_size > len ? len - cur_len : chunk_size;
-PADDLE_ENFORCE_CUDA_SUCCESS(
+PADDLE_ENFORCE_GPU_SUCCESS(
 cudaMemcpyAsync(d_key_bufs[cur_stream]->ptr(), h_keys + cur_len,
 sizeof(KeyType) * tmp_len, cudaMemcpyHostToDevice,
 streams[cur_stream]));
-PADDLE_ENFORCE_CUDA_SUCCESS(
+PADDLE_ENFORCE_GPU_SUCCESS(
 cudaMemcpyAsync(d_val_bufs[cur_stream]->ptr(), h_vals + cur_len,
 sizeof(ValType) * tmp_len, cudaMemcpyHostToDevice,
 streams[cur_stream]));
@@ -378,7 +378,7 @@ void HeterComm<KeyType, ValType, GradType>::build_ps(int num, KeyType* h_keys,
 for (int i = 0; i < stream_num; ++i) {
 cudaStreamSynchronize(streams[i]);
-PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(streams[i]));
+PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(streams[i]));
 }
 }
@@ -402,14 +402,14 @@ void HeterComm<KeyType, ValType, GradType>::merge_grad(int gpu_num,
 GradType* d_merge_grads_ptr =
 reinterpret_cast<GradType*>(d_merge_grads->ptr());
-PADDLE_ENFORCE_CUDA_SUCCESS(cub::DeviceRadixSort::SortPairs(
+PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs(
 NULL, temp_storage_bytes, d_keys, d_merge_keys_ptr, d_grads,
 d_merge_grads_ptr, len, 0, 8 * sizeof(KeyType), stream, false));
 void* d_buff = NULL;
 auto d_temp_storage = memory::AllocShared(place, temp_storage_bytes);
-PADDLE_ENFORCE_CUDA_SUCCESS(cub::DeviceRadixSort::SortPairs(
+PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs(
 d_temp_storage->ptr(), temp_storage_bytes, d_keys, d_merge_keys_ptr,
 d_grads, d_merge_grads_ptr, len, 0, 8 * sizeof(KeyType), stream, false));
 temp_storage_bytes = 0;
@@ -417,7 +417,7 @@ void HeterComm<KeyType, ValType, GradType>::merge_grad(int gpu_num,
 auto d_num_runs_out_mem = memory::AllocShared(place, sizeof(int));
 int* d_num_runs_out = reinterpret_cast<int*>(d_num_runs_out_mem->ptr());
-PADDLE_ENFORCE_CUDA_SUCCESS(cub::DeviceReduce::ReduceByKey(
+PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceReduce::ReduceByKey(
 NULL, temp_storage_bytes, d_merge_keys_ptr, d_keys, d_merge_grads_ptr,
 d_grads, d_num_runs_out, merger_, len, stream, false));
@@ -426,13 +426,13 @@ void HeterComm<KeyType, ValType, GradType>::merge_grad(int gpu_num,
 d_temp_storage = memory::AllocShared(place, temp_storage_bytes);
 }
-PADDLE_ENFORCE_CUDA_SUCCESS(cub::DeviceReduce::ReduceByKey(
+PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceReduce::ReduceByKey(
 d_temp_storage->ptr(), temp_storage_bytes, d_merge_keys_ptr, d_keys,
 d_merge_grads_ptr, d_grads, d_num_runs_out, merger_, len, stream, false));
 cudaMemcpyAsync(&uniq_len, d_num_runs_out, sizeof(int),
 cudaMemcpyDeviceToHost, stream);
-PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
+PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
 }
 template <typename KeyType, typename ValType, typename GradType>
@@ -461,12 +461,12 @@ void HeterComm<KeyType, ValType, GradType>::split_input_to_shard(
 size_t temp_storage_bytes;
 const int num_bits = 1 + log2i(total_gpu);
-PADDLE_ENFORCE_CUDA_SUCCESS(cub::DeviceRadixSort::SortPairs(
+PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs(
 NULL, temp_storage_bytes, d_shard_index_tmp_ptr, d_shard_index_ptr,
 d_idx_tmp_ptr, d_idx_ptr, len, 0, num_bits, stream));
 auto d_temp_storage = memory::AllocShared(place, temp_storage_bytes);
-PADDLE_ENFORCE_CUDA_SUCCESS(cub::DeviceRadixSort::SortPairs(
+PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs(
 d_temp_storage->ptr(), temp_storage_bytes, d_shard_index_tmp_ptr,
 d_shard_index_ptr, d_idx_tmp_ptr, d_idx_ptr, len, 0, num_bits, stream));
 calc_shard_offset<<<grid_size, block_size_, 0, stream>>>(d_shard_index_ptr,
@@ -720,12 +720,12 @@ int HeterComm<KeyType, ValType, GradType>::gather_one_node_grad(
 cudaMemcpyHostToDevice);
 // allgather grad len
-PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart());
-PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
+PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
+PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather(
 (const void*)(d_node_len + gpu_num), (void*)d_node_len, 1, ncclInt,
 nccl_inner_comm, stream));
-PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd());
-PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
+PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
+PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
 cudaMemcpy(h_node_len, d_node_len, sizeof(int) * total_gpu,
 cudaMemcpyDeviceToHost);
@@ -737,15 +737,15 @@ int HeterComm<KeyType, ValType, GradType>::gather_one_node_grad(
 storage.alloc(max_size * total_gpu);
 // allgather keys and grads
-PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart());
-PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
+PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
+PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather(
 d_keys, storage.all_keys, max_size, ncclUint64, nccl_inner_comm, stream));
-PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
+PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather(
 d_grads, storage.all_grads, max_size * sizeof(GradType), ncclUint8,
 nccl_inner_comm, stream));
-PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd());
-PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
+PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
+PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
 int h_left[total_gpu];
 int h_right[total_gpu];
@@ -802,11 +802,11 @@ int HeterComm<KeyType, ValType, GradType>::gather_multi_node_grad(
 cudaMemcpy(d_node_len, h_node_len, sizeof(int), cudaMemcpyHostToDevice);
 // allgather grad len
-PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart());
-PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
+PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
+PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather(
 d_node_len, d_node_len, 1, ncclInt, nccl_inter_comm, stream));
-PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd());
-PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
+PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
+PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
 cudaMemcpy(h_node_len, d_node_len, sizeof(int) * node_size_,
 cudaMemcpyDeviceToHost);
@@ -818,15 +818,15 @@ int HeterComm<KeyType, ValType, GradType>::gather_multi_node_grad(
 storage.alloc(max_size * node_size_);
 // allgather keys and grads
-PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart());
-PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
+PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
+PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather(
 d_keys, storage.all_keys, max_size, ncclUint64, nccl_inter_comm, stream));
-PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
+PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather(
 d_grads, storage.all_grads, max_size * sizeof(GradType), ncclUint8,
 nccl_inter_comm, stream));
-PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd());
-PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
+PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
+PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
 int merge_num = 0;
 for (int i = 0; i < node_size_; ++i) {
...
@@ -30,11 +30,11 @@ GPUResource::GPUResource(std::vector<int>& dev_ids, int index) {
 remote_streams_.resize(dev_ids_.size());
 for (size_t i = 0; i < dev_ids_.size(); ++i) {
-PADDLE_ENFORCE_CUDA_SUCCESS(
+PADDLE_ENFORCE_GPU_SUCCESS(
 cudaStreamCreateWithFlags(&local_streams_[i], cudaStreamNonBlocking));
-PADDLE_ENFORCE_CUDA_SUCCESS(
+PADDLE_ENFORCE_GPU_SUCCESS(
 cudaStreamCreateWithFlags(&comm_streams_[i], cudaStreamNonBlocking));
-PADDLE_ENFORCE_CUDA_SUCCESS(
+PADDLE_ENFORCE_GPU_SUCCESS(
 cudaStreamCreateWithFlags(&remote_streams_[i], cudaStreamNonBlocking));
 }
 }
@@ -42,13 +42,13 @@ GPUResource::GPUResource(std::vector<int>& dev_ids, int index) {
 GPUResource::~GPUResource() {
 platform::CUDADeviceGuard guard(dev_id_);
 for (size_t i = 0; i < local_streams_.size(); ++i) {
-PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(local_streams_[i]));
+PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(local_streams_[i]));
 }
 for (size_t i = 0; i < comm_streams_.size(); ++i) {
-PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(comm_streams_[i]));
+PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(comm_streams_[i]));
 }
 for (size_t i = 0; i < remote_streams_.size(); ++i) {
-PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(remote_streams_[i]));
+PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(remote_streams_[i]));
 }
 }
@@ -58,7 +58,7 @@ void HeterPsResource::enable_p2p() {
 for (size_t j = 0; j < dev_ids_.size(); ++j) {
 if (i != j) {
 int p2p_flag;
-PADDLE_ENFORCE_CUDA_SUCCESS(
+PADDLE_ENFORCE_GPU_SUCCESS(
 cudaDeviceCanAccessPeer(&p2p_flag, dev_ids_[i], dev_ids_[j]));
 if (p2p_flag == 1) {
 cudaError_t ret = cudaDeviceEnablePeerAccess(dev_ids_[j], 0);
...
@@ -22,7 +22,7 @@ bool NCCLWrapper::is_initialized_ = false;
 void NCCLWrapper::InitNCCL() {
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclCommInitRank(
+PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclCommInitRank(
 &(nccl_info_.comm_), nccl_info_.global_ranks_, nccl_info_.nccl_id_,
 nccl_info_.my_global_rank_));
 #endif
@@ -38,7 +38,7 @@ void NCCLWrapper::SetNCCLId(const NCCLInfo& nccl_info) {
 NCCLInfo NCCLWrapper::GetNCCLId() {
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-PADDLE_ENFORCE_CUDA_SUCCESS(
+PADDLE_ENFORCE_GPU_SUCCESS(
 platform::dynload::ncclGetUniqueId(&(nccl_info_.nccl_id_)));
 #endif
 return nccl_info_;
@@ -52,9 +52,9 @@ void NCCLWrapper::SetRankInfo(const int local_rank, const int global_rank,
 nccl_info_.global_ranks_ = ranks;
 platform::SetDeviceId(local_rank);
 #ifdef PADDLE_WITH_RCCL
-PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamCreate(&(nccl_info_.stream_)));
+PADDLE_ENFORCE_GPU_SUCCESS(hipStreamCreate(&(nccl_info_.stream_)));
 #else
-PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreate(&(nccl_info_.stream_)));
+PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreate(&(nccl_info_.stream_)));
 #endif
 #endif
 return;
@@ -67,7 +67,7 @@ void NCCLWrapper::SyncVar(const int root_rank, const Scope& scope,
 auto var = scope.FindVar(name);
 LoDTensor* tensor = var->GetMutable<LoDTensor>();
 int32_t total_size = tensor->numel();
-PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBcast(
+PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclBcast(
 reinterpret_cast<void*>(tensor->data<float>()), total_size, ncclFloat,
 root_rank, nccl_info_.comm_, nccl_info_.stream_));
 #ifdef PADDLE_WITH_RCCL
...
@@ -20,7 +20,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h"
 #include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h"
 #include "paddle/fluid/framework/lod_tensor.h"
-#include "paddle/fluid/platform/gpu_info.h"
+#include "paddle/fluid/platform/device/gpu/gpu_info.h"
 namespace paddle {
 namespace framework {
...
@@ -37,8 +37,8 @@ limitations under the License. */
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/framework/variable_helper.h"
+#include "paddle/fluid/platform/device/gpu/gpu_info.h"
 #include "paddle/fluid/platform/dynload/nccl.h"
-#include "paddle/fluid/platform/gpu_info.h"
 #include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
 #include "paddle/fluid/platform/place.h"
 #ifdef PADDLE_WITH_PSCORE
@@ -230,7 +230,7 @@ class PSGPUWrapper {
 ? 1.0
 : config["mf_max_bound"];
 for (size_t i = 0; i < heter_devices_.size(); i++) {
-PADDLE_ENFORCE_CUDA_SUCCESS(cudaSetDevice(heter_devices_[i]));
+PADDLE_ENFORCE_GPU_SUCCESS(cudaSetDevice(heter_devices_[i]));
 this->SetSparseSGD(nonclk_coeff, clk_coeff, min_bound, max_bound,
 learning_rate, initial_g2sum, initial_range);
 this->SetEmbedxSGD(mf_create_thresholds, mf_learning_rate,
...
...@@ -83,9 +83,9 @@ StreamGarbageCollector::StreamGarbageCollector(const platform::CUDAPlace &place, ...@@ -83,9 +83,9 @@ StreamGarbageCollector::StreamGarbageCollector(const platform::CUDAPlace &place,
: GarbageCollector(place, max_memory_size) { : GarbageCollector(place, max_memory_size) {
platform::CUDADeviceGuard guard(place.device); platform::CUDADeviceGuard guard(place.device);
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamCreate(&stream_)); PADDLE_ENFORCE_GPU_SUCCESS(hipStreamCreate(&stream_));
#else #else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreate(&stream_)); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreate(&stream_));
callback_manager_.reset( callback_manager_.reset(
new platform::StreamCallbackManager<gpuStream_t>(stream_)); new platform::StreamCallbackManager<gpuStream_t>(stream_));
#endif #endif
...@@ -94,13 +94,8 @@ StreamGarbageCollector::StreamGarbageCollector(const platform::CUDAPlace &place, ...@@ -94,13 +94,8 @@ StreamGarbageCollector::StreamGarbageCollector(const platform::CUDAPlace &place,
StreamGarbageCollector::~StreamGarbageCollector() { StreamGarbageCollector::~StreamGarbageCollector() {
auto place = BOOST_GET_CONST(platform::CUDAPlace, this->dev_ctx_->GetPlace()); auto place = BOOST_GET_CONST(platform::CUDAPlace, this->dev_ctx_->GetPlace());
platform::CUDADeviceGuard guard(place.device); platform::CUDADeviceGuard guard(place.device);
#ifdef PADDLE_WITH_HIP platform::GpuStreamSync(stream_);
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream_)); platform::GpuDestroyStream(stream_);
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamDestroy(stream_));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream_));
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(stream_));
#endif
} }
gpuStream_t StreamGarbageCollector::stream() const { return stream_; } gpuStream_t StreamGarbageCollector::stream() const { return stream_; }
......
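For reference, the platform::GpuStreamSync and platform::GpuDestroyStream calls above collapse the old per-backend #ifdef blocks into single helpers. A minimal sketch of what such helpers could look like, assuming only the runtime APIs and the PADDLE_ENFORCE_GPU_SUCCESS macro already shown in this diff (the actual definitions under device/gpu/ may differ):

#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
using gpuStream_t = hipStream_t;
#else
#include <cuda_runtime.h>
using gpuStream_t = cudaStream_t;
#endif

namespace platform {
// Block the host until all work queued on the stream has finished.
inline void GpuStreamSync(gpuStream_t stream) {
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_GPU_SUCCESS(hipStreamSynchronize(stream));
#else
  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
#endif
}

// Release the stream once the caller has synchronized it.
inline void GpuDestroyStream(gpuStream_t stream) {
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_GPU_SUCCESS(hipStreamDestroy(stream));
#else
  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(stream));
#endif
}
}  // namespace platform

Keeping the #ifdef in one place is what lets call sites such as the destructor above stay backend-agnostic.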
...@@ -18,8 +18,8 @@ limitations under the License. */ ...@@ -18,8 +18,8 @@ limitations under the License. */
#include <memory> #include <memory>
#include <utility> #include <utility>
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gpu_info.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -33,7 +33,7 @@ const std::shared_ptr<Generator>& GetDefaultCUDAGenerator(int64_t device_id) { ...@@ -33,7 +33,7 @@ const std::shared_ptr<Generator>& GetDefaultCUDAGenerator(int64_t device_id) {
static std::vector<std::shared_ptr<Generator>> default_cuda_generators; static std::vector<std::shared_ptr<Generator>> default_cuda_generators;
std::call_once(num_devices_init_flag, []() { std::call_once(num_devices_init_flag, []() {
num_cuda_devices = paddle::platform::GetCUDADeviceCount(); num_cuda_devices = paddle::platform::GetGPUDeviceCount();
cuda_device_flags.resize(num_cuda_devices); cuda_device_flags.resize(num_cuda_devices);
default_cuda_generators.resize(num_cuda_devices); default_cuda_generators.resize(num_cuda_devices);
}); });
......
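The GetCUDADeviceCount-to-GetGPUDeviceCount rename above points at a backend-neutral device query. A simplified sketch under the same assumptions as the other helpers; the real implementation may cache the result and apply extra error handling:

namespace platform {
// Illustrative only: ask whichever runtime was compiled in how many devices exist.
inline int GetGPUDeviceCount() {
  int count = 0;
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_GPU_SUCCESS(hipGetDeviceCount(&count));
#else
  PADDLE_ENFORCE_GPU_SUCCESS(cudaGetDeviceCount(&count));
#endif
  return count;
}
}  // namespace platform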
...@@ -51,11 +51,11 @@ void HeterXpuTrainer::Initialize(const TrainerDesc& trainer_desc, ...@@ -51,11 +51,11 @@ void HeterXpuTrainer::Initialize(const TrainerDesc& trainer_desc,
platform::CUDAPlace place = platform::CUDAPlace(num); platform::CUDAPlace place = platform::CUDAPlace(num);
platform::CUDADeviceGuard guard(place.device); platform::CUDADeviceGuard guard(place.device);
cudaStream_t stream; cudaStream_t stream;
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreate(&stream)); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreate(&stream));
copy_streams_.push_back(stream); copy_streams_.push_back(stream);
places_.push_back(place); places_.push_back(place);
cudaEvent_t event; cudaEvent_t event;
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
cudaEventCreateWithFlags(&event, cudaEventDisableTiming)); cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
events_.push_back(event); events_.push_back(event);
#endif #endif
...@@ -104,7 +104,7 @@ void HeterXpuTrainer::Initialize(const TrainerDesc& trainer_desc, ...@@ -104,7 +104,7 @@ void HeterXpuTrainer::Initialize(const TrainerDesc& trainer_desc,
// platform::CUDAPlace place = platform::CUDAPlace(num); // platform::CUDAPlace place = platform::CUDAPlace(num);
// platform::CUDADeviceGuard guard(place.device); // platform::CUDADeviceGuard guard(place.device);
// cudaStream_t stream; // cudaStream_t stream;
// PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreate(&stream)); // PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreate(&stream));
// copy_streams_.push_back(stream); // copy_streams_.push_back(stream);
// places_.push_back(place); // places_.push_back(place);
// } // }
...@@ -157,7 +157,7 @@ void HeterXpuTrainer::CreateThreadParam(const ProgramDesc& program, int num) { ...@@ -157,7 +157,7 @@ void HeterXpuTrainer::CreateThreadParam(const ProgramDesc& program, int num) {
} }
} }
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(event, stream)); PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(event, stream));
cudaEventSynchronize(event); cudaEventSynchronize(event);
#endif #endif
} }
...@@ -287,7 +287,7 @@ void HeterXpuTrainer::InitOtherEnv(const ProgramDesc& main_program) { ...@@ -287,7 +287,7 @@ void HeterXpuTrainer::InitOtherEnv(const ProgramDesc& main_program) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
auto dev_id = BOOST_GET_CONST(platform::CUDAPlace, place).device; auto dev_id = BOOST_GET_CONST(platform::CUDAPlace, place).device;
platform::CUDADeviceGuard guard(dev_id); platform::CUDADeviceGuard guard(dev_id);
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
cudaEventCreateWithFlags(&context->event_, cudaEventDisableTiming)); cudaEventCreateWithFlags(&context->event_, cudaEventDisableTiming));
#endif #endif
object_pool_.Push(context); object_pool_.Push(context);
...@@ -441,7 +441,7 @@ int HeterXpuTrainer::RunTask(const HeterRequest* request, ...@@ -441,7 +441,7 @@ int HeterXpuTrainer::RunTask(const HeterRequest* request,
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
auto dev_id = BOOST_GET_CONST(platform::CUDAPlace, place).device; auto dev_id = BOOST_GET_CONST(platform::CUDAPlace, place).device;
platform::CUDADeviceGuard guard(dev_id); platform::CUDADeviceGuard guard(dev_id);
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
cudaEventCreateWithFlags(&context->event_, cudaEventDisableTiming)); cudaEventCreateWithFlags(&context->event_, cudaEventDisableTiming));
#endif #endif
} }
...@@ -461,7 +461,7 @@ int HeterXpuTrainer::RunTask(const HeterRequest* request, ...@@ -461,7 +461,7 @@ int HeterXpuTrainer::RunTask(const HeterRequest* request,
#endif #endif
} }
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
cudaEventRecord(context->event_, copy_streams_[context->place_num_])); cudaEventRecord(context->event_, copy_streams_[context->place_num_]));
while (cudaEventQuery(context->event_) != cudaSuccess) { while (cudaEventQuery(context->event_) != cudaSuccess) {
VLOG(3) << "wait for kernel"; VLOG(3) << "wait for kernel";
...@@ -481,7 +481,7 @@ int HeterXpuTrainer::RunTask(const HeterRequest* request, ...@@ -481,7 +481,7 @@ int HeterXpuTrainer::RunTask(const HeterRequest* request,
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
auto* dev_ctx = static_cast<platform::CUDADeviceContext*>( auto* dev_ctx = static_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(place)); platform::DeviceContextPool::Instance().Get(place));
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
cudaEventRecord(context->event_, dev_ctx->stream())); cudaEventRecord(context->event_, dev_ctx->stream()));
// cudaEventSynchronize(context->event_); // cudaEventSynchronize(context->event_);
{ {
......
...@@ -24,12 +24,8 @@ class Node; ...@@ -24,12 +24,8 @@ class Node;
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h" #include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#endif
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/miopen_helper.h"
#endif
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
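The single gpu_dnn.h include above acts as an umbrella header for the DNN helpers, so callers no longer need the paired cuDNN/MIOpen #ifdef blocks. A minimal sketch of such a header; the concrete include paths below are assumptions, not taken from this diff:

// Illustrative only: select the cuDNN or MIOpen helper at compile time.
#pragma once
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/device/gpu/rocm/miopen_helper.h"  // assumed path
#else
#include "paddle/fluid/platform/device/gpu/cuda/cudnn_helper.h"   // assumed path
#endif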
...@@ -15,13 +15,8 @@ ...@@ -15,13 +15,8 @@
#include "paddle/fluid/framework/ir/fuse_bn_add_act_pass.h" #include "paddle/fluid/framework/ir/fuse_bn_add_act_pass.h"
#include <string> #include <string>
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/miopen_helper.h"
#endif
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include "glog/logging.h" #include "glog/logging.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/framework/mixed_vector.h" #include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h"
template <typename T> template <typename T>
using vec = paddle::framework::Vector<T>; using vec = paddle::framework::Vector<T>;
...@@ -63,7 +63,7 @@ TEST(mixed_vector, GPU_VECTOR) { ...@@ -63,7 +63,7 @@ TEST(mixed_vector, GPU_VECTOR) {
} }
TEST(mixed_vector, MultiGPU) { TEST(mixed_vector, MultiGPU) {
if (paddle::platform::GetCUDADeviceCount() < 2) { if (paddle::platform::GetGPUDeviceCount() < 2) {
LOG(WARNING) << "Skip mixed_vector.MultiGPU since there are not multiple " LOG(WARNING) << "Skip mixed_vector.MultiGPU since there are not multiple "
"GPUs in your machine."; "GPUs in your machine.";
return; return;
......
...@@ -398,13 +398,8 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) { ...@@ -398,13 +398,8 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
/*For profiling/benchmark only*/ /*For profiling/benchmark only*/
if (FLAGS_benchmark) { if (FLAGS_benchmark) {
instr_node.DeviceContext().Wait(); instr_node.DeviceContext().Wait();
#if defined(PADDLE_WITH_CUDA) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PADDLE_ENFORCE_CUDA_SUCCESS(cudaGetLastError()); PADDLE_ENFORCE_GPU_SUCCESS(platform::GpuGetLastError());
VLOG(4) << "Operator(" << op->Type()
<< "): context wait and get last error";
#endif
#if defined(PADDLE_WITH_HIP)
PADDLE_ENFORCE_CUDA_SUCCESS(hipGetLastError());
VLOG(4) << "Operator(" << op->Type() VLOG(4) << "Operator(" << op->Type()
<< "): context wait and get last error"; << "): context wait and get last error";
#endif #endif
......
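platform::GpuGetLastError above replaces the paired cudaGetLastError/hipGetLastError branches. A minimal sketch, assuming the unified gpuError_t alias that this diff uses elsewhere:

// Illustrative only: report the backend's last asynchronous error.
#ifdef PADDLE_WITH_HIP
using gpuError_t = hipError_t;
#else
using gpuError_t = cudaError_t;
#endif

namespace platform {
inline gpuError_t GpuGetLastError() {
#ifdef PADDLE_WITH_HIP
  return hipGetLastError();
#else
  return cudaGetLastError();
#endif
}
}  // namespace platform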
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#pragma once #pragma once
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/timer.h" #include "paddle/fluid/platform/timer.h"
namespace paddle { namespace paddle {
...@@ -45,7 +45,7 @@ class ProfilerGuard { ...@@ -45,7 +45,7 @@ class ProfilerGuard {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto cuda_place = BOOST_GET_CONST(platform::CUDAPlace, place); auto cuda_place = BOOST_GET_CONST(platform::CUDAPlace, place);
cost_info_->device_memory_bytes = cost_info_->device_memory_bytes =
platform::RecordedCudaMallocSize(cuda_place.device); platform::RecordedGpuMallocSize(cuda_place.device);
#endif #endif
} }
} }
......
...@@ -1212,14 +1212,10 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -1212,14 +1212,10 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
/*For profiling/benchmark only*/ /*For profiling/benchmark only*/
if (FLAGS_benchmark) { if (FLAGS_benchmark) {
dev_ctx->Wait(); dev_ctx->Wait();
#if defined(PADDLE_WITH_CUDA) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PADDLE_ENFORCE_CUDA_SUCCESS(cudaGetLastError()); PADDLE_ENFORCE_GPU_SUCCESS(platform::GpuGetLastError());
VLOG(4) << "Operator(" << Type() << "): context wait and get last error";
#endif #endif
#if defined(PADDLE_WITH_HIP)
PADDLE_ENFORCE_CUDA_SUCCESS(hipGetLastError());
VLOG(4) << "Operator(" << Type() << "): context wait and get last error"; VLOG(4) << "Operator(" << Type() << "): context wait and get last error";
#endif
} }
if (FLAGS_check_nan_inf) { if (FLAGS_check_nan_inf) {
......
...@@ -34,7 +34,7 @@ limitations under the License. */ ...@@ -34,7 +34,7 @@ limitations under the License. */
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/nccl_helper.h" #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif #endif
namespace paddle { namespace paddle {
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_NCCL) #if defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h" #include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
#include "paddle/fluid/platform/nccl_helper.h" #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif #endif
#include <cudnn.h> #include <cudnn.h>
#include "paddle/fluid/operators/conv_cudnn_op_cache.h" #include "paddle/fluid/operators/conv_cudnn_op_cache.h"
...@@ -30,8 +30,8 @@ ...@@ -30,8 +30,8 @@
#endif #endif
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
#if defined(PADDLE_WITH_RCCL) #if defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h" // NOLINT #include "paddle/fluid/operators/nccl/nccl_gpu_common.h" // NOLINT
#include "paddle/fluid/platform/nccl_helper.h" // NOLINT #include "paddle/fluid/platform/device/gpu/nccl_helper.h" // NOLINT
#endif #endif
#include "paddle/fluid/operators/conv_cudnn_op_cache.h" // NOLINT #include "paddle/fluid/operators/conv_cudnn_op_cache.h" // NOLINT
#include "paddle/fluid/operators/miopen_rnn_cache.h" #include "paddle/fluid/operators/miopen_rnn_cache.h"
......
...@@ -23,15 +23,15 @@ ...@@ -23,15 +23,15 @@
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_NCCL) #if defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h" #include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
#include "paddle/fluid/platform/nccl_helper.h" #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif #endif
#include "paddle/fluid/operators/conv_cudnn_op_cache.h" #include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/operators/cudnn_rnn_cache.h" #include "paddle/fluid/operators/cudnn_rnn_cache.h"
#endif #endif
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
#if defined(PADDLE_WITH_RCCL) #if defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h" // NOLINT #include "paddle/fluid/operators/nccl/nccl_gpu_common.h" // NOLINT
#include "paddle/fluid/platform/nccl_helper.h" // NOLINT #include "paddle/fluid/platform/device/gpu/nccl_helper.h" // NOLINT
#endif #endif
#include "paddle/fluid/operators/conv_cudnn_op_cache.h" // NOLINT #include "paddle/fluid/operators/conv_cudnn_op_cache.h" // NOLINT
#include "paddle/fluid/operators/miopen_rnn_cache.h" #include "paddle/fluid/operators/miopen_rnn_cache.h"
......
...@@ -28,8 +28,8 @@ ...@@ -28,8 +28,8 @@
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/parallel_context.h" #include "paddle/fluid/imperative/parallel_context.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/nccl_helper.h"
#include "paddle/fluid/string/string_helper.h" #include "paddle/fluid/string/string_helper.h"
namespace paddle { namespace paddle {
...@@ -64,7 +64,7 @@ static void AllReduce(const framework::Tensor &src, framework::Tensor *dst, ...@@ -64,7 +64,7 @@ static void AllReduce(const framework::Tensor &src, framework::Tensor *dst,
dst->Resize(src.dims()); dst->Resize(src.dims());
auto *dst_ptr = dst->mutable_data(src.place(), src.type()); auto *dst_ptr = dst->mutable_data(src.place(), src.type());
auto nccl_dtype = platform::ToNCCLDataType(src.type()); auto nccl_dtype = platform::ToNCCLDataType(src.type());
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce( PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
src_ptr, dst_ptr, src.numel(), nccl_dtype, ncclSum, comm->comm(), src_ptr, dst_ptr, src.numel(), nccl_dtype, ncclSum, comm->comm(),
stream)); stream));
} }
...@@ -100,16 +100,12 @@ static void AllReduce(const framework::SelectedRows &src, ...@@ -100,16 +100,12 @@ static void AllReduce(const framework::SelectedRows &src,
if (!use_calc_stream) { if (!use_calc_stream) {
dev_ctx->Wait(); dev_ctx->Wait();
} }
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather( PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather(
gpu_rows_num_ptr + strategy.local_rank_, gpu_rows_num_ptr, 1, ncclInt64, gpu_rows_num_ptr + strategy.local_rank_, gpu_rows_num_ptr, 1, ncclInt64,
comm->comm(), stream)); comm->comm(), stream));
if (!use_calc_stream) { if (!use_calc_stream) {
#ifdef PADDLE_WITH_RCCL platform::GpuStreamSync(stream);
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
#endif
} }
const auto *cpu_rows_num_ptr = rows_num_vector.data(); const auto *cpu_rows_num_ptr = rows_num_vector.data();
...@@ -146,11 +142,11 @@ static void AllReduce(const framework::SelectedRows &src, ...@@ -146,11 +142,11 @@ static void AllReduce(const framework::SelectedRows &src,
// allgather is used to speed up the allreduce by replacing broadcast. // allgather is used to speed up the allreduce by replacing broadcast.
auto row_sendcount = cpu_rows_num_ptr[0]; auto row_sendcount = cpu_rows_num_ptr[0];
VLOG(3) << "allgather replaces broadcast to speed up in sparse allreduce"; VLOG(3) << "allgather replaces broadcast to speed up in sparse allreduce";
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather( PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather(
src_rows_ptr, dst_rows_ptr, row_sendcount, ncclInt64, comm->comm(), src_rows_ptr, dst_rows_ptr, row_sendcount, ncclInt64, comm->comm(),
stream)); stream));
auto value_sendcount = cpu_rows_num_ptr[0] * feature_size; auto value_sendcount = cpu_rows_num_ptr[0] * feature_size;
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather( PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather(
src_tensor_ptr, dst_tensor_ptr, value_sendcount, nccl_dtype, src_tensor_ptr, dst_tensor_ptr, value_sendcount, nccl_dtype,
comm->comm(), stream)); comm->comm(), stream));
return; return;
...@@ -158,13 +154,13 @@ static void AllReduce(const framework::SelectedRows &src, ...@@ -158,13 +154,13 @@ static void AllReduce(const framework::SelectedRows &src,
for (int i = 0; i < strategy.nranks_; ++i) { for (int i = 0; i < strategy.nranks_; ++i) {
if (cpu_rows_num_ptr[i] > 0) { if (cpu_rows_num_ptr[i] > 0) {
// 2. Broadcast the rows of SelectedRows // 2. Broadcast the rows of SelectedRows
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBroadcast( PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclBroadcast(
src_rows_ptr, dst_rows_ptr + row_offset, cpu_rows_num_ptr[i], src_rows_ptr, dst_rows_ptr + row_offset, cpu_rows_num_ptr[i],
ncclInt64, i, comm->comm(), stream)); ncclInt64, i, comm->comm(), stream));
// 3. Broadcast the tensor data of SelectedRows // 3. Broadcast the tensor data of SelectedRows
auto *dst_tensor_ptr_i = reinterpret_cast<uint8_t *>(dst_tensor_ptr) + auto *dst_tensor_ptr_i = reinterpret_cast<uint8_t *>(dst_tensor_ptr) +
row_offset * feature_size * sizeof_dtype; row_offset * feature_size * sizeof_dtype;
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBroadcast( PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclBroadcast(
src_tensor_ptr, dst_tensor_ptr_i, cpu_rows_num_ptr[i] * feature_size, src_tensor_ptr, dst_tensor_ptr_i, cpu_rows_num_ptr[i] * feature_size,
nccl_dtype, i, comm->comm(), stream)); nccl_dtype, i, comm->comm(), stream));
row_offset += cpu_rows_num_ptr[i]; row_offset += cpu_rows_num_ptr[i];
...@@ -209,12 +205,8 @@ void AllReduce(const framework::Variable &src, framework::Variable *dst, ...@@ -209,12 +205,8 @@ void AllReduce(const framework::Variable &src, framework::Variable *dst,
AllReduce(src.Get<framework::SelectedRows>(), AllReduce(src.Get<framework::SelectedRows>(),
tmp_dst.GetMutable<framework::SelectedRows>(), strategy, stream, tmp_dst.GetMutable<framework::SelectedRows>(), strategy, stream,
comm); comm);
// the stream must be synchronized to ensure the correctness of the move operation // the stream must be synchronized to ensure the correctness of the move operation
#ifdef PADDLE_WITH_RCCL platform::GpuStreamSync(stream);
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
#endif
*dst = std::move(tmp_dst); *dst = std::move(tmp_dst);
} }
#endif #endif
......
...@@ -153,11 +153,11 @@ void NCCLParallelContext::WaitCompute(int ring_id) { ...@@ -153,11 +153,11 @@ void NCCLParallelContext::WaitCompute(int ring_id) {
// compute_stream-->event-->comm_stream // compute_stream-->event-->comm_stream
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(hipEventRecord(event, compute_stream)); PADDLE_ENFORCE_GPU_SUCCESS(hipEventRecord(event, compute_stream));
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamWaitEvent(comm_stream, event, 0)); PADDLE_ENFORCE_GPU_SUCCESS(hipStreamWaitEvent(comm_stream, event, 0));
#else #else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(event, compute_stream)); PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(event, compute_stream));
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamWaitEvent(comm_stream, event, 0)); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamWaitEvent(comm_stream, event, 0));
#endif #endif
} }
...@@ -179,11 +179,11 @@ void NCCLParallelContext::WaitComm(int ring_id) { ...@@ -179,11 +179,11 @@ void NCCLParallelContext::WaitComm(int ring_id) {
// comm_stream-->event-->compute_stream // comm_stream-->event-->compute_stream
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(hipEventRecord(event, comm_stream)); PADDLE_ENFORCE_GPU_SUCCESS(hipEventRecord(event, comm_stream));
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamWaitEvent(compute_stream, event, 0)); PADDLE_ENFORCE_GPU_SUCCESS(hipStreamWaitEvent(compute_stream, event, 0));
#else #else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(event, comm_stream)); PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(event, comm_stream));
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamWaitEvent(compute_stream, event, 0)); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamWaitEvent(compute_stream, event, 0));
#endif #endif
} }
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#include <vector> #include <vector>
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/cuda_resource_pool.h" #include "paddle/fluid/platform/device/gpu/gpu_resource_pool.h"
#endif #endif
#ifdef PADDLE_WITH_NCCL #ifdef PADDLE_WITH_NCCL
......
...@@ -24,6 +24,8 @@ ...@@ -24,6 +24,8 @@
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
#include "paddle/fluid/platform/device/xpu/xpu_op_list.h" #include "paddle/fluid/platform/device/xpu/xpu_op_list.h"
#endif #endif
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
DECLARE_bool(check_nan_inf); DECLARE_bool(check_nan_inf);
DECLARE_bool(run_pten_kernel); DECLARE_bool(run_pten_kernel);
DECLARE_bool(benchmark); DECLARE_bool(benchmark);
...@@ -523,12 +525,8 @@ static void PreparedOpRunPtImpl( ...@@ -523,12 +525,8 @@ static void PreparedOpRunPtImpl(
if (FLAGS_benchmark) { if (FLAGS_benchmark) {
dev_ctx->Wait(); dev_ctx->Wait();
#if defined(PADDLE_WITH_CUDA) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PADDLE_ENFORCE_CUDA_SUCCESS(cudaGetLastError()); PADDLE_ENFORCE_GPU_SUCCESS(platform::GpuGetLastError());
VLOG(4) << "Operator(" << op.Type() << "): context wait and get last error";
#endif
#if defined(PADDLE_WITH_HIP)
PADDLE_ENFORCE_CUDA_SUCCESS(hipGetLastError());
VLOG(4) << "Operator(" << op.Type() << "): context wait and get last error"; VLOG(4) << "Operator(" << op.Type() << "): context wait and get last error";
#endif #endif
} }
......
...@@ -19,8 +19,8 @@ ...@@ -19,8 +19,8 @@
#include "paddle/fluid/inference/api/paddle_pass_builder.h" #include "paddle/fluid/inference/api/paddle_pass_builder.h"
#include "paddle/fluid/inference/utils/table_printer.h" #include "paddle/fluid/inference/utils/table_printer.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gpu_info.h"
#ifdef PADDLE_WITH_TENSORRT #ifdef PADDLE_WITH_TENSORRT
#include "paddle/fluid/inference/tensorrt/helper.h" #include "paddle/fluid/inference/tensorrt/helper.h"
......
...@@ -41,8 +41,8 @@ ...@@ -41,8 +41,8 @@
#include "paddle/fluid/inference/utils/singleton.h" #include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/cpu_helper.h" #include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
#include "paddle/pten/api/ext/op_meta_info.h" #include "paddle/pten/api/ext/op_meta_info.h"
......
...@@ -27,7 +27,7 @@ using paddle::PaddleDType; ...@@ -27,7 +27,7 @@ using paddle::PaddleDType;
void* TensorUtils::CudaMallocPinnedMemory(size_t size) { void* TensorUtils::CudaMallocPinnedMemory(size_t size) {
#if defined(PADDLE_WITH_CUDA) #if defined(PADDLE_WITH_CUDA)
void* ptr = nullptr; void* ptr = nullptr;
PADDLE_ENFORCE_CUDA_SUCCESS(cudaMallocHost(&ptr, size)); PADDLE_ENFORCE_GPU_SUCCESS(cudaMallocHost(&ptr, size));
return ptr; return ptr;
#else #else
return nullptr; return nullptr;
...@@ -36,7 +36,7 @@ void* TensorUtils::CudaMallocPinnedMemory(size_t size) { ...@@ -36,7 +36,7 @@ void* TensorUtils::CudaMallocPinnedMemory(size_t size) {
void TensorUtils::CudaFreePinnedMemory(void* ptr) { void TensorUtils::CudaFreePinnedMemory(void* ptr) {
#if defined(PADDLE_WITH_CUDA) #if defined(PADDLE_WITH_CUDA)
PADDLE_ENFORCE_CUDA_SUCCESS(cudaFreeHost(ptr)); PADDLE_ENFORCE_GPU_SUCCESS(cudaFreeHost(ptr));
#endif #endif
} }
......
...@@ -45,7 +45,7 @@ class DefaultIOConverter : public EngineIOConverter { ...@@ -45,7 +45,7 @@ class DefaultIOConverter : public EngineIOConverter {
"the input max_size. But in's memory_size = %u, max_size = %u.", "the input max_size. But in's memory_size = %u, max_size = %u.",
size, max_size)); size, max_size));
if (is_cpu_place(place)) { if (is_cpu_place(place)) {
PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemcpyAsync( PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(
out, in.data<float>(), size, cudaMemcpyHostToDevice, *stream_)); out, in.data<float>(), size, cudaMemcpyHostToDevice, *stream_));
} else if (is_gpu_place(place)) { } else if (is_gpu_place(place)) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
......
...@@ -20,8 +20,8 @@ limitations under the License. */ ...@@ -20,8 +20,8 @@ limitations under the License. */
#include "cuda_runtime_api.h" // NOLINT #include "cuda_runtime_api.h" // NOLINT
#include "paddle/fluid/inference/tensorrt/helper.h" #include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gpu_info.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
......
...@@ -43,16 +43,16 @@ nvinfer1::Weights DeformableConvPlugin::copyToDevice(const void* hostData, ...@@ -43,16 +43,16 @@ nvinfer1::Weights DeformableConvPlugin::copyToDevice(const void* hostData,
size_t count) { size_t count) {
int num_bytes = (data_type_ == nvinfer1::DataType::kFLOAT ? 4 : 2); int num_bytes = (data_type_ == nvinfer1::DataType::kFLOAT ? 4 : 2);
void* deviceData; void* deviceData;
PADDLE_ENFORCE_CUDA_SUCCESS(cudaMalloc(&deviceData, count * num_bytes)); PADDLE_ENFORCE_GPU_SUCCESS(cudaMalloc(&deviceData, count * num_bytes));
PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemcpy( PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpy(deviceData, hostData, count * num_bytes,
deviceData, hostData, count * num_bytes, cudaMemcpyHostToDevice)); cudaMemcpyHostToDevice));
return nvinfer1::Weights{data_type_, deviceData, int64_t(count)}; return nvinfer1::Weights{data_type_, deviceData, int64_t(count)};
} }
void DeformableConvPlugin::serializeFromDevice( void DeformableConvPlugin::serializeFromDevice(
void** hostBuffer, const nvinfer1::Weights& deviceWeights) const { void** hostBuffer, const nvinfer1::Weights& deviceWeights) const {
int num_bytes = (data_type_ == nvinfer1::DataType::kFLOAT ? 4 : 2); int num_bytes = (data_type_ == nvinfer1::DataType::kFLOAT ? 4 : 2);
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemcpy(static_cast<char*>(*hostBuffer), deviceWeights.values, cudaMemcpy(static_cast<char*>(*hostBuffer), deviceWeights.values,
deviceWeights.count * num_bytes, cudaMemcpyDeviceToHost)); deviceWeights.count * num_bytes, cudaMemcpyDeviceToHost));
hostBuffer += deviceWeights.count * num_bytes; hostBuffer += deviceWeights.count * num_bytes;
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include <vector> #include <vector>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.h" #include "paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.h"
#include "paddle/fluid/platform/cudnn_helper.h" #include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
......
...@@ -136,7 +136,7 @@ int SplitPlugin::enqueue(int batchSize, const void* const* inputs, ...@@ -136,7 +136,7 @@ int SplitPlugin::enqueue(int batchSize, const void* const* inputs,
float const* input_ptr = reinterpret_cast<float const*>(inputs[0]); float const* input_ptr = reinterpret_cast<float const*>(inputs[0]);
float* const* h_odatas = reinterpret_cast<float* const*>(outputs); float* const* h_odatas = reinterpret_cast<float* const*>(outputs);
float** output_ptrs = thrust::raw_pointer_cast(&d_output_ptrs_[0]); float** output_ptrs = thrust::raw_pointer_cast(&d_output_ptrs_[0]);
PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemcpyAsync( PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(
output_ptrs, h_odatas, d_output_ptrs_.size() * sizeof(float*), output_ptrs, h_odatas, d_output_ptrs_.size() * sizeof(float*),
cudaMemcpyHostToDevice, stream)); cudaMemcpyHostToDevice, stream));
...@@ -263,7 +263,7 @@ int SplitPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* input_desc, ...@@ -263,7 +263,7 @@ int SplitPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
float* const* h_odatas = reinterpret_cast<float* const*>(outputs); float* const* h_odatas = reinterpret_cast<float* const*>(outputs);
float** output_ptrs = thrust::raw_pointer_cast(&d_output_ptrs[0]); float** output_ptrs = thrust::raw_pointer_cast(&d_output_ptrs[0]);
PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemcpyAsync( PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(
output_ptrs, h_odatas, d_output_ptrs.size() * sizeof(float*), output_ptrs, h_odatas, d_output_ptrs.size() * sizeof(float*),
cudaMemcpyHostToDevice, stream)); cudaMemcpyHostToDevice, stream));
...@@ -279,7 +279,7 @@ int SplitPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* input_desc, ...@@ -279,7 +279,7 @@ int SplitPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
half* const* h_odatas = reinterpret_cast<half* const*>(outputs); half* const* h_odatas = reinterpret_cast<half* const*>(outputs);
half** output_ptrs = thrust::raw_pointer_cast(&d_output_ptrs[0]); half** output_ptrs = thrust::raw_pointer_cast(&d_output_ptrs[0]);
PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemcpyAsync( PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(
output_ptrs, h_odatas, d_output_ptrs.size() * sizeof(half*), output_ptrs, h_odatas, d_output_ptrs.size() * sizeof(half*),
cudaMemcpyHostToDevice, stream)); cudaMemcpyHostToDevice, stream));
......
...@@ -85,7 +85,7 @@ bool TRTInt8Calibrator::setBatch( ...@@ -85,7 +85,7 @@ bool TRTInt8Calibrator::setBatch(
engine_name_, it.first)); engine_name_, it.first));
} }
const auto& d = dataptr->second; const auto& d = dataptr->second;
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemcpy(d.first, it.second, d.second, cudaMemcpyDeviceToDevice)); cudaMemcpy(d.first, it.second, d.second, cudaMemcpyDeviceToDevice));
} }
......
...@@ -30,13 +30,10 @@ ...@@ -30,13 +30,10 @@
#include "paddle/fluid/memory/allocation/pinned_allocator.h" #include "paddle/fluid/memory/allocation/pinned_allocator.h"
#include "paddle/fluid/memory/allocation/stream_safe_cuda_allocator.h" #include "paddle/fluid/memory/allocation/stream_safe_cuda_allocator.h"
#include "paddle/fluid/memory/allocation/thread_local_allocator.h" #include "paddle/fluid/memory/allocation/thread_local_allocator.h"
#include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h"
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include <cuda_runtime.h> #include "paddle/fluid/platform/device/gpu/cuda/cuda_graph.h"
#include "paddle/fluid/platform/cuda_graph.h"
#else
#include <hip/hip_runtime.h>
#endif #endif
#if CUDA_VERSION >= 10020 #if CUDA_VERSION >= 10020
...@@ -145,8 +142,7 @@ class AllocatorFacadePrivate { ...@@ -145,8 +142,7 @@ class AllocatorFacadePrivate {
"naive_best_fit strategy"; "naive_best_fit strategy";
FLAGS_use_stream_safe_cuda_allocator = false; FLAGS_use_stream_safe_cuda_allocator = false;
} }
for (int dev_id = 0; dev_id < platform::GetCUDADeviceCount(); for (int dev_id = 0; dev_id < platform::GetGPUDeviceCount(); ++dev_id) {
++dev_id) {
InitNaiveBestFitCUDAAllocator(platform::CUDAPlace(dev_id)); InitNaiveBestFitCUDAAllocator(platform::CUDAPlace(dev_id));
} }
InitNaiveBestFitCUDAPinnedAllocator(); InitNaiveBestFitCUDAPinnedAllocator();
...@@ -172,13 +168,13 @@ class AllocatorFacadePrivate { ...@@ -172,13 +168,13 @@ class AllocatorFacadePrivate {
if (FLAGS_use_stream_safe_cuda_allocator) { if (FLAGS_use_stream_safe_cuda_allocator) {
// TODO(Ruibiao): Support multi-stream allocator for other strategies // TODO(Ruibiao): Support multi-stream allocator for other strategies
default_stream_ = nullptr; default_stream_ = nullptr;
for (int dev_id = 0; dev_id < platform::GetCUDADeviceCount(); for (int dev_id = 0; dev_id < platform::GetGPUDeviceCount();
++dev_id) { ++dev_id) {
InitStreamSafeCUDAAllocator(platform::CUDAPlace(dev_id), InitStreamSafeCUDAAllocator(platform::CUDAPlace(dev_id),
default_stream_); default_stream_);
} }
} else { } else {
for (int dev_id = 0; dev_id < platform::GetCUDADeviceCount(); for (int dev_id = 0; dev_id < platform::GetGPUDeviceCount();
++dev_id) { ++dev_id) {
InitAutoGrowthCUDAAllocator(platform::CUDAPlace(dev_id), InitAutoGrowthCUDAAllocator(platform::CUDAPlace(dev_id),
allow_free_idle_chunk_); allow_free_idle_chunk_);
...@@ -208,8 +204,7 @@ class AllocatorFacadePrivate { ...@@ -208,8 +204,7 @@ class AllocatorFacadePrivate {
FLAGS_use_stream_safe_cuda_allocator = false; FLAGS_use_stream_safe_cuda_allocator = false;
} }
for (int dev_id = 0; dev_id < platform::GetCUDADeviceCount(); for (int dev_id = 0; dev_id < platform::GetGPUDeviceCount(); ++dev_id) {
++dev_id) {
InitThreadLocalCUDAAllocator(platform::CUDAPlace(dev_id)); InitThreadLocalCUDAAllocator(platform::CUDAPlace(dev_id));
} }
InitNaiveBestFitCUDAPinnedAllocator(); InitNaiveBestFitCUDAPinnedAllocator();
...@@ -399,10 +394,10 @@ class AllocatorFacadePrivate { ...@@ -399,10 +394,10 @@ class AllocatorFacadePrivate {
CUdevice device; CUdevice device;
int val; int val;
try { try {
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cuDeviceGet(&device, p.GetDeviceId())); paddle::platform::dynload::cuDeviceGet(&device, p.GetDeviceId()));
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cuDeviceGetAttribute( paddle::platform::dynload::cuDeviceGetAttribute(
&val, CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED, &val, CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED,
device)); device));
...@@ -476,10 +471,10 @@ class AllocatorFacadePrivate { ...@@ -476,10 +471,10 @@ class AllocatorFacadePrivate {
CUdevice device; CUdevice device;
int val; int val;
try { try {
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cuDeviceGet(&device, p.GetDeviceId())); paddle::platform::dynload::cuDeviceGet(&device, p.GetDeviceId()));
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cuDeviceGetAttribute( paddle::platform::dynload::cuDeviceGetAttribute(
&val, CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED, &val, CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED,
device)); device));
...@@ -599,7 +594,7 @@ class AllocatorFacadePrivate { ...@@ -599,7 +594,7 @@ class AllocatorFacadePrivate {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
system_allocators_[platform::CUDAPinnedPlace()] = system_allocators_[platform::CUDAPinnedPlace()] =
std::make_shared<CPUPinnedAllocator>(); std::make_shared<CPUPinnedAllocator>();
int device_count = platform::GetCUDADeviceCount(); int device_count = platform::GetGPUDeviceCount();
for (int i = 0; i < device_count; ++i) { for (int i = 0; i < device_count; ++i) {
platform::CUDAPlace p(i); platform::CUDAPlace p(i);
system_allocators_[p] = std::make_shared<CUDAAllocator>(p); system_allocators_[p] = std::make_shared<CUDAAllocator>(p);
...@@ -612,7 +607,7 @@ class AllocatorFacadePrivate { ...@@ -612,7 +607,7 @@ class AllocatorFacadePrivate {
std::vector<platform::Place> places; std::vector<platform::Place> places;
places.emplace_back(platform::CPUPlace()); places.emplace_back(platform::CPUPlace());
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
int device_count = platform::GetCUDADeviceCount(); int device_count = platform::GetGPUDeviceCount();
for (int dev_id = 0; dev_id < device_count; ++dev_id) { for (int dev_id = 0; dev_id < device_count; ++dev_id) {
places.emplace_back(platform::CUDAPlace(dev_id)); places.emplace_back(platform::CUDAPlace(dev_id));
} }
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
#include "paddle/fluid/memory/allocation/npu_pinned_allocator.h" #include "paddle/fluid/memory/allocation/npu_pinned_allocator.h"
#endif #endif
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h"
#endif #endif
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
#include <thread> // NOLINT #include <thread> // NOLINT
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h" #include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h"
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
DECLARE_double(fraction_of_gpu_memory_to_use); DECLARE_double(fraction_of_gpu_memory_to_use);
......
...@@ -25,8 +25,8 @@ ...@@ -25,8 +25,8 @@
#include <string> #include <string>
#include "paddle/fluid/platform/cuda_device_guard.h" #include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gpu_info.h"
namespace paddle { namespace paddle {
namespace memory { namespace memory {
...@@ -37,8 +37,8 @@ void CUDAAllocator::FreeImpl(Allocation* allocation) { ...@@ -37,8 +37,8 @@ void CUDAAllocator::FreeImpl(Allocation* allocation) {
BOOST_GET_CONST(platform::CUDAPlace, allocation->place()), place_, BOOST_GET_CONST(platform::CUDAPlace, allocation->place()), place_,
platform::errors::PermissionDenied( platform::errors::PermissionDenied(
"GPU memory is freed in incorrect device. This may be a bug")); "GPU memory is freed in incorrect device. This may be a bug"));
platform::RecordedCudaFree(allocation->ptr(), allocation->size(), platform::RecordedGpuFree(allocation->ptr(), allocation->size(),
place_.device); place_.device);
delete allocation; delete allocation;
} }
...@@ -46,13 +46,13 @@ Allocation* CUDAAllocator::AllocateImpl(size_t size) { ...@@ -46,13 +46,13 @@ Allocation* CUDAAllocator::AllocateImpl(size_t size) {
std::call_once(once_flag_, [this] { platform::SetDeviceId(place_.device); }); std::call_once(once_flag_, [this] { platform::SetDeviceId(place_.device); });
void* ptr; void* ptr;
auto result = platform::RecordedCudaMalloc(&ptr, size, place_.device); auto result = platform::RecordedGpuMalloc(&ptr, size, place_.device);
if (LIKELY(result == gpuSuccess)) { if (LIKELY(result == gpuSuccess)) {
return new Allocation(ptr, size, platform::Place(place_)); return new Allocation(ptr, size, platform::Place(place_));
} }
size_t avail, total, actual_avail, actual_total; size_t avail, total, actual_avail, actual_total;
bool is_limited = platform::RecordedCudaMemGetInfo( bool is_limited = platform::RecordedGpuMemGetInfo(
&avail, &total, &actual_avail, &actual_total, place_.device); &avail, &total, &actual_avail, &actual_total, place_.device);
size_t allocated = total - avail; size_t allocated = total - avail;
......
...@@ -81,10 +81,10 @@ class CUDADeviceContextAllocator : public Allocator { ...@@ -81,10 +81,10 @@ class CUDADeviceContextAllocator : public Allocator {
: place_(place), default_stream_(default_stream) { : place_(place), default_stream_(default_stream) {
platform::CUDADeviceGuard guard(place_.device); platform::CUDADeviceGuard guard(place_.device);
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
hipEventCreateWithFlags(&event_, hipEventDisableTiming)); hipEventCreateWithFlags(&event_, hipEventDisableTiming));
#else #else
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
cudaEventCreate(&event_, cudaEventDisableTiming)); cudaEventCreate(&event_, cudaEventDisableTiming));
#endif #endif
} }
...@@ -93,9 +93,9 @@ class CUDADeviceContextAllocator : public Allocator { ...@@ -93,9 +93,9 @@ class CUDADeviceContextAllocator : public Allocator {
if (event_) { if (event_) {
platform::CUDADeviceGuard guard(place_.device); platform::CUDADeviceGuard guard(place_.device);
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(hipEventDestroy(event_)); PADDLE_ENFORCE_GPU_SUCCESS(hipEventDestroy(event_));
#else #else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventDestroy(event_)); PADDLE_ENFORCE_GPU_SUCCESS(cudaEventDestroy(event_));
#endif #endif
} }
} }
...@@ -111,12 +111,11 @@ class CUDADeviceContextAllocator : public Allocator { ...@@ -111,12 +111,11 @@ class CUDADeviceContextAllocator : public Allocator {
new CUDADeviceContextAllocation(memory::Alloc(place_, size)); new CUDADeviceContextAllocation(memory::Alloc(place_, size));
// Wait for the event on stream // Wait for the event on stream
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(hipEventRecord(event_, default_stream_)); PADDLE_ENFORCE_GPU_SUCCESS(hipEventRecord(event_, default_stream_));
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamWaitEvent(default_stream_, event_, 0)); PADDLE_ENFORCE_GPU_SUCCESS(hipStreamWaitEvent(default_stream_, event_, 0));
#else #else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(event_, default_stream_)); PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(event_, default_stream_));
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamWaitEvent(default_stream_, event_, 0));
cudaStreamWaitEvent(default_stream_, event_, 0));
#endif #endif
return allocation; return allocation;
} }
......
...@@ -23,8 +23,8 @@ ...@@ -23,8 +23,8 @@
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_device_guard.h" #include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/dynload/cuda_driver.h" #include "paddle/fluid/platform/dynload/cuda_driver.h"
#include "paddle/fluid/platform/gpu_info.h"
#endif #endif
#if CUDA_VERSION >= 10020 #if CUDA_VERSION >= 10020
...@@ -49,10 +49,10 @@ CUDAVirtualMemAllocator::CUDAVirtualMemAllocator( ...@@ -49,10 +49,10 @@ CUDAVirtualMemAllocator::CUDAVirtualMemAllocator(
// Prepare the access descriptor array indicating where and how the backings // Prepare the access descriptor array indicating where and how the backings
// should be visible. // should be visible.
for (int dev_id = 0; dev_id < platform::GetCUDADeviceCount(); ++dev_id) { for (int dev_id = 0; dev_id < platform::GetGPUDeviceCount(); ++dev_id) {
if (place.device != dev_id) { if (place.device != dev_id) {
int capable = 0; int capable = 0;
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
cudaDeviceCanAccessPeer(&capable, place.device, dev_id)); cudaDeviceCanAccessPeer(&capable, place.device, dev_id));
if (!capable) { if (!capable) {
VLOG(1) << "device(" << place.device VLOG(1) << "device(" << place.device
...@@ -73,10 +73,10 @@ CUDAVirtualMemAllocator::CUDAVirtualMemAllocator( ...@@ -73,10 +73,10 @@ CUDAVirtualMemAllocator::CUDAVirtualMemAllocator(
// Get the minimum granularity needed for all devices // Get the minimum granularity needed for all devices
// (the max of the minimum granularity of each participating device) // (the max of the minimum granularity of each participating device)
granularity_ = 0; granularity_ = 0;
for (int dev_id = 0; dev_id < platform::GetCUDADeviceCount(); ++dev_id) { for (int dev_id = 0; dev_id < platform::GetGPUDeviceCount(); ++dev_id) {
size_t granularity; size_t granularity;
prop.location.id = dev_id; prop.location.id = dev_id;
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cuMemGetAllocationGranularity( paddle::platform::dynload::cuMemGetAllocationGranularity(
&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM)); &granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
granularity_ = std::max(granularity, granularity_); granularity_ = std::max(granularity, granularity_);
...@@ -84,7 +84,7 @@ CUDAVirtualMemAllocator::CUDAVirtualMemAllocator( ...@@ -84,7 +84,7 @@ CUDAVirtualMemAllocator::CUDAVirtualMemAllocator(
size_t actual_avail, actual_total; size_t actual_avail, actual_total;
paddle::platform::CUDADeviceGuard guard(place.device); paddle::platform::CUDADeviceGuard guard(place.device);
PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemGetInfo(&actual_avail, &actual_total)); PADDLE_ENFORCE_GPU_SUCCESS(cudaMemGetInfo(&actual_avail, &actual_total));
virtual_mem_size_ = AlignedSize(actual_total, granularity_); virtual_mem_size_ = AlignedSize(actual_total, granularity_);
...@@ -93,7 +93,7 @@ CUDAVirtualMemAllocator::CUDAVirtualMemAllocator( ...@@ -93,7 +93,7 @@ CUDAVirtualMemAllocator::CUDAVirtualMemAllocator(
// GPU, // GPU,
// so the virtual address space size we reserve is equal to the GPU video // so the virtual address space size we reserve is equal to the GPU video
// memory size // memory size
PADDLE_ENFORCE_CUDA_SUCCESS(paddle::platform::dynload::cuMemAddressReserve( PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cuMemAddressReserve(
&virtual_mem_base_, virtual_mem_size_, 0, 0, 0)); &virtual_mem_base_, virtual_mem_size_, 0, 0, 0));
virtual_mem_alloced_offset_ = 0; virtual_mem_alloced_offset_ = 0;
...@@ -123,11 +123,11 @@ void CUDAVirtualMemAllocator::FreeImpl(Allocation* allocation) { ...@@ -123,11 +123,11 @@ void CUDAVirtualMemAllocator::FreeImpl(Allocation* allocation) {
auto result = auto result =
paddle::platform::dynload::cuMemUnmap(iter->first, iter->second.second); paddle::platform::dynload::cuMemUnmap(iter->first, iter->second.second);
if (result != CUDA_ERROR_DEINITIALIZED) { if (result != CUDA_ERROR_DEINITIALIZED) {
PADDLE_ENFORCE_CUDA_SUCCESS(result); PADDLE_ENFORCE_GPU_SUCCESS(result);
} }
if (result != CUDA_ERROR_DEINITIALIZED) { if (result != CUDA_ERROR_DEINITIALIZED) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::RecordedCuMemRelease( PADDLE_ENFORCE_GPU_SUCCESS(platform::RecordedGpuMemRelease(
iter->second.first, iter->second.second, place_.device)); iter->second.first, iter->second.second, place_.device));
} }
...@@ -166,12 +166,12 @@ Allocation* CUDAVirtualMemAllocator::AllocateImpl(size_t size) { ...@@ -166,12 +166,12 @@ Allocation* CUDAVirtualMemAllocator::AllocateImpl(size_t size) {
// Create physical memory backing allocation. // Create physical memory backing allocation.
auto result = auto result =
platform::RecordedCuMemCreate(&handle, size, &prop_, 0, place_.device); platform::RecordedGpuMemCreate(&handle, size, &prop_, 0, place_.device);
if (result != CUDA_SUCCESS) { if (result != CUDA_SUCCESS) {
if (result == CUDA_ERROR_OUT_OF_MEMORY) { if (result == CUDA_ERROR_OUT_OF_MEMORY) {
size_t actual_avail, actual_total; size_t actual_avail, actual_total;
PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemGetInfo(&actual_avail, &actual_total)); PADDLE_ENFORCE_GPU_SUCCESS(cudaMemGetInfo(&actual_avail, &actual_total));
size_t actual_allocated = actual_total - actual_avail; size_t actual_allocated = actual_total - actual_avail;
PADDLE_THROW_BAD_ALLOC(platform::errors::ResourceExhausted( PADDLE_THROW_BAD_ALLOC(platform::errors::ResourceExhausted(
...@@ -186,7 +186,7 @@ Allocation* CUDAVirtualMemAllocator::AllocateImpl(size_t size) { ...@@ -186,7 +186,7 @@ Allocation* CUDAVirtualMemAllocator::AllocateImpl(size_t size) {
string::HumanReadableSize(actual_allocated), string::HumanReadableSize(actual_allocated),
string::HumanReadableSize(actual_avail), place_.device)); string::HumanReadableSize(actual_avail), place_.device));
} else { } else {
PADDLE_ENFORCE_CUDA_SUCCESS(result); PADDLE_ENFORCE_GPU_SUCCESS(result);
} }
return nullptr; return nullptr;
} }
...@@ -197,8 +197,8 @@ Allocation* CUDAVirtualMemAllocator::AllocateImpl(size_t size) { ...@@ -197,8 +197,8 @@ Allocation* CUDAVirtualMemAllocator::AllocateImpl(size_t size) {
result = paddle::platform::dynload::cuMemMap(ptr, size, 0, handle, 0); result = paddle::platform::dynload::cuMemMap(ptr, size, 0, handle, 0);
if (result != CUDA_SUCCESS) { if (result != CUDA_SUCCESS) {
platform::RecordedCuMemRelease(handle, size, place_.device); platform::RecordedGpuMemRelease(handle, size, place_.device);
PADDLE_ENFORCE_CUDA_SUCCESS(result); PADDLE_ENFORCE_GPU_SUCCESS(result);
return nullptr; return nullptr;
} }
...@@ -208,8 +208,8 @@ Allocation* CUDAVirtualMemAllocator::AllocateImpl(size_t size) { ...@@ -208,8 +208,8 @@ Allocation* CUDAVirtualMemAllocator::AllocateImpl(size_t size) {
if (result != CUDA_SUCCESS) { if (result != CUDA_SUCCESS) {
paddle::platform::dynload::cuMemUnmap(ptr, size); paddle::platform::dynload::cuMemUnmap(ptr, size);
platform::RecordedCuMemRelease(handle, size, place_.device); platform::RecordedGpuMemRelease(handle, size, place_.device);
PADDLE_ENFORCE_CUDA_SUCCESS(result); PADDLE_ENFORCE_GPU_SUCCESS(result);
return nullptr; return nullptr;
} }
......
...@@ -20,8 +20,8 @@ ...@@ -20,8 +20,8 @@
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/memory/detail/buddy_allocator.h" #include "paddle/fluid/memory/detail/buddy_allocator.h"
#include "paddle/fluid/memory/detail/system_allocator.h" #include "paddle/fluid/memory/detail/system_allocator.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
......
...@@ -20,18 +20,18 @@ namespace allocation { ...@@ -20,18 +20,18 @@ namespace allocation {
bool CPUPinnedAllocator::IsAllocThreadSafe() const { return true; } bool CPUPinnedAllocator::IsAllocThreadSafe() const { return true; }
void CPUPinnedAllocator::FreeImpl(Allocation *allocation) { void CPUPinnedAllocator::FreeImpl(Allocation *allocation) {
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(hipHostFree(allocation->ptr())); PADDLE_ENFORCE_GPU_SUCCESS(hipHostFree(allocation->ptr()));
#else #else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaFreeHost(allocation->ptr())); PADDLE_ENFORCE_GPU_SUCCESS(cudaFreeHost(allocation->ptr()));
#endif #endif
delete allocation; delete allocation;
} }
Allocation *CPUPinnedAllocator::AllocateImpl(size_t size) { Allocation *CPUPinnedAllocator::AllocateImpl(size_t size) {
void *ptr; void *ptr;
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(hipHostMalloc(&ptr, size, hipHostMallocPortable)); PADDLE_ENFORCE_GPU_SUCCESS(hipHostMalloc(&ptr, size, hipHostMallocPortable));
#else #else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaHostAlloc(&ptr, size, cudaHostAllocPortable)); PADDLE_ENFORCE_GPU_SUCCESS(cudaHostAlloc(&ptr, size, cudaHostAllocPortable));
#endif #endif
return new Allocation(ptr, size, platform::CUDAPinnedPlace()); return new Allocation(ptr, size, platform::CUDAPinnedPlace());
} }
......
...@@ -112,13 +112,13 @@ void StreamSafeCUDAAllocator::CreateEventForAllRecordedStream( ...@@ -112,13 +112,13 @@ void StreamSafeCUDAAllocator::CreateEventForAllRecordedStream(
for (gpuStream_t stream : *recorded_streams) { for (gpuStream_t stream : *recorded_streams) {
gpuEvent_t event; gpuEvent_t event;
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
cudaEventCreateWithFlags(&event, cudaEventDisableTiming)); cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(event, stream)); PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(event, stream));
#else #else
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
hipEventCreateWithFlags(&event, hipEventDisableTiming)); hipEventCreateWithFlags(&event, hipEventDisableTiming));
PADDLE_ENFORCE_CUDA_SUCCESS(hipEventRecord(event, stream)); PADDLE_ENFORCE_GPU_SUCCESS(hipEventRecord(event, stream));
#endif #endif
outstanding_events->emplace_back(event); outstanding_events->emplace_back(event);
VLOG(9) << "Record event " << event << " in stream " << stream; VLOG(9) << "Record event " << event << " in stream " << stream;
...@@ -162,8 +162,8 @@ void StreamSafeCUDAAllocator::ProcessEventsAndFree() { ...@@ -162,8 +162,8 @@ void StreamSafeCUDAAllocator::ProcessEventsAndFree() {
outstanding_events.erase(outstanding_events.begin(), deque_it); outstanding_events.erase(outstanding_events.begin(), deque_it);
break; break;
} }
PADDLE_ENFORCE_CUDA_SUCCESS(err); PADDLE_ENFORCE_GPU_SUCCESS(err);
PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventDestroy(*deque_it)); PADDLE_ENFORCE_GPU_SUCCESS(cudaEventDestroy(*deque_it));
#else #else
gpuError_t err = hipEventQuery(*deque_it); gpuError_t err = hipEventQuery(*deque_it);
if (err == hipErrorNotReady) { if (err == hipErrorNotReady) {
...@@ -173,8 +173,8 @@ void StreamSafeCUDAAllocator::ProcessEventsAndFree() { ...@@ -173,8 +173,8 @@ void StreamSafeCUDAAllocator::ProcessEventsAndFree() {
outstanding_events.erase(outstanding_events.begin(), deque_it); outstanding_events.erase(outstanding_events.begin(), deque_it);
break; break;
} }
PADDLE_ENFORCE_CUDA_SUCCESS(err); PADDLE_ENFORCE_GPU_SUCCESS(err);
PADDLE_ENFORCE_CUDA_SUCCESS(hipEventDestroy(*deque_it)); PADDLE_ENFORCE_GPU_SUCCESS(hipEventDestroy(*deque_it));
#endif #endif
++deque_it; ++deque_it;
} }
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
#include "paddle/fluid/memory/allocation/allocator.h" #include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/memory/detail/buddy_allocator.h" #include "paddle/fluid/memory/detail/buddy_allocator.h"
#include "paddle/fluid/memory/detail/system_allocator.h" #include "paddle/fluid/memory/detail/system_allocator.h"
#include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h"
namespace paddle { namespace paddle {
namespace memory { namespace memory {
......
...@@ -25,8 +25,8 @@ limitations under the License. */ ...@@ -25,8 +25,8 @@ limitations under the License. */
#include "paddle/fluid/memory/detail/memory_block.h" #include "paddle/fluid/memory/detail/memory_block.h"
#include "paddle/fluid/memory/detail/system_allocator.h" #include "paddle/fluid/memory/detail/system_allocator.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/npu/npu_info.h" #include "paddle/fluid/platform/device/npu/npu_info.h"
#include "paddle/fluid/platform/gpu_info.h"
namespace paddle { namespace paddle {
namespace memory { namespace memory {
......
...@@ -24,8 +24,8 @@ limitations under the License. */ ...@@ -24,8 +24,8 @@ limitations under the License. */
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/npu/npu_info.h" #include "paddle/fluid/platform/device/npu/npu_info.h"
#include "paddle/fluid/platform/gpu_info.h"
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \ #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \
defined(PADDLE_WITH_ASCEND_CL) defined(PADDLE_WITH_ASCEND_CL)
......
...@@ -27,9 +27,9 @@ limitations under the License. */ ...@@ -27,9 +27,9 @@ limitations under the License. */
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "paddle/fluid/memory/allocation/allocator.h" #include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/npu/npu_info.h" #include "paddle/fluid/platform/device/npu/npu_info.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gpu_info.h"
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/platform/cuda_device_guard.h" #include "paddle/fluid/platform/cuda_device_guard.h"
...@@ -115,7 +115,7 @@ void* GPUAllocator::Alloc(size_t* index, size_t size) { ...@@ -115,7 +115,7 @@ void* GPUAllocator::Alloc(size_t* index, size_t size) {
if (size <= 0) return nullptr; if (size <= 0) return nullptr;
void* p; void* p;
auto result = platform::RecordedCudaMalloc(&p, size, gpu_id_); auto result = platform::RecordedGpuMalloc(&p, size, gpu_id_);
if (result == gpuSuccess) { if (result == gpuSuccess) {
*index = 0; *index = 0;
...@@ -123,7 +123,7 @@ void* GPUAllocator::Alloc(size_t* index, size_t size) { ...@@ -123,7 +123,7 @@ void* GPUAllocator::Alloc(size_t* index, size_t size) {
return p; return p;
} else { } else {
size_t avail, total, actual_avail, actual_total; size_t avail, total, actual_avail, actual_total;
bool is_limited = platform::RecordedCudaMemGetInfo( bool is_limited = platform::RecordedGpuMemGetInfo(
&avail, &total, &actual_avail, &actual_total, gpu_id_); &avail, &total, &actual_avail, &actual_total, gpu_id_);
size_t allocated = total - avail; size_t allocated = total - avail;
...@@ -166,7 +166,7 @@ void GPUAllocator::Free(void* p, size_t size, size_t index) { ...@@ -166,7 +166,7 @@ void GPUAllocator::Free(void* p, size_t size, size_t index) {
size, gpu_alloc_size_)); size, gpu_alloc_size_));
gpu_alloc_size_ -= size; gpu_alloc_size_ -= size;
platform::RecordedCudaFree(p, size, gpu_id_); platform::RecordedGpuFree(p, size, gpu_id_);
} }
bool GPUAllocator::UseGpu() const { return true; } bool GPUAllocator::UseGpu() const { return true; }
......
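The system allocator hunk also renames the recorded-allocation helpers from RecordedCuda* to RecordedGpu*. A usage sketch based only on the calls visible above; the wrapper function, variable names, and simplified error handling are assumptions for illustration, not Paddle code:

#include "paddle/fluid/platform/device/gpu/gpu_info.h"

// Try to allocate `size` bytes on card `gpu_id`; on failure, query the
// recorded (possibly flag-limited) memory view so the caller can report it.
bool TryRecordedAlloc(void** ptr, size_t size, int gpu_id) {
  gpuError_t result = paddle::platform::RecordedGpuMalloc(ptr, size, gpu_id);
  if (result == gpuSuccess) return true;
  size_t avail = 0, total = 0, actual_avail = 0, actual_total = 0;
  bool is_limited = paddle::platform::RecordedGpuMemGetInfo(
      &avail, &total, &actual_avail, &actual_total, gpu_id);
  (void)is_limited;  // whether a per-card limit capped the request
  return false;
}
// Matching release when the buffer is no longer needed:
// paddle::platform::RecordedGpuFree(ptr, size, gpu_id);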
...@@ -19,6 +19,9 @@ limitations under the License. */ ...@@ -19,6 +19,9 @@ limitations under the License. */
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/memory/allocation/allocator.h" #include "paddle/fluid/memory/allocation/allocator.h"
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#endif
DECLARE_bool(use_pinned_memory); DECLARE_bool(use_pinned_memory);
...@@ -77,11 +80,7 @@ TEST(GPUAllocator, AllocFailure) { ...@@ -77,11 +80,7 @@ TEST(GPUAllocator, AllocFailure) {
allocator.Alloc(&index, alloc_size); allocator.Alloc(&index, alloc_size);
ASSERT_TRUE(false); ASSERT_TRUE(false);
} catch (paddle::memory::allocation::BadAlloc&) { } catch (paddle::memory::allocation::BadAlloc&) {
#ifdef PADDLE_WITH_HIP PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::GpuGetLastError());
PADDLE_ENFORCE_CUDA_SUCCESS(hipGetLastError());
#else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaGetLastError());
#endif
} }
} }
#endif #endif
......
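In the tests, the same unification removes the last per-runtime branch around error draining. The before/after, taken directly from the hunk above:

// Before: one branch per runtime.
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(hipGetLastError());
#else
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaGetLastError());
#endif

// After: a single portable call.
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::GpuGetLastError());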
...@@ -14,7 +14,7 @@ limitations under the License. */ ...@@ -14,7 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
namespace paddle { namespace paddle {
......
...@@ -19,7 +19,7 @@ limitations under the License. */ ...@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/fluid/memory/memory.h" #include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
// This unit test is an example comparing the performance between using pinned // This unit test is an example comparing the performance between using pinned
......
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h"
namespace paddle { namespace paddle {
namespace memory { namespace memory {
...@@ -53,9 +53,9 @@ class StreamSafeCUDAAllocTest : public ::testing::Test { ...@@ -53,9 +53,9 @@ class StreamSafeCUDAAllocTest : public ::testing::Test {
for (size_t i = 1; i < stream_num_; ++i) { for (size_t i = 1; i < stream_num_; ++i) {
gpuStream_t stream; gpuStream_t stream;
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreate(&stream)); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreate(&stream));
#else #else
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamCreate(&stream)); PADDLE_ENFORCE_GPU_SUCCESS(hipStreamCreate(&stream));
#endif #endif
streams_.emplace_back(stream); streams_.emplace_back(stream);
} }
...@@ -65,10 +65,10 @@ class StreamSafeCUDAAllocTest : public ::testing::Test { ...@@ -65,10 +65,10 @@ class StreamSafeCUDAAllocTest : public ::testing::Test {
std::shared_ptr<Allocation> allocation = std::shared_ptr<Allocation> allocation =
AllocShared(place_, allocation_size, streams_[i]); AllocShared(place_, allocation_size, streams_[i]);
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemset(allocation->ptr(), 0, allocation->size())); cudaMemset(allocation->ptr(), 0, allocation->size()));
#else #else
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
hipMemset(allocation->ptr(), 0, allocation->size())); hipMemset(allocation->ptr(), 0, allocation->size()));
#endif #endif
allocations_.emplace_back(allocation); allocations_.emplace_back(allocation);
...@@ -111,13 +111,13 @@ class StreamSafeCUDAAllocTest : public ::testing::Test { ...@@ -111,13 +111,13 @@ class StreamSafeCUDAAllocTest : public ::testing::Test {
// tricky code, the allocations are still accessible even though // tricky code, the allocations are still accessible even though
// allocations_.clear() has been called // allocations_.clear() has been called
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemcpy(host_x.get(), allocations_[i]->ptr(), cudaMemcpy(host_x.get(), allocations_[i]->ptr(),
data_num_ * sizeof(int), cudaMemcpyDeviceToHost)); data_num_ * sizeof(int), cudaMemcpyDeviceToHost));
#else #else
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(hipMemcpy(host_x.get(), allocations_[i]->ptr(),
hipMemcpy(host_x.get(), allocations_[i]->ptr(), data_num_ * sizeof(int),
data_num_ * sizeof(int), hipMemcpyDeviceToHost)); hipMemcpyDeviceToHost));
#endif #endif
for (int j = 0; j < data_num_; ++j) { for (int j = 0; j < data_num_; ++j) {
EXPECT_TRUE(host_x[j] == (j % thread_num) * stream_num_); EXPECT_TRUE(host_x[j] == (j % thread_num) * stream_num_);
...@@ -127,9 +127,9 @@ class StreamSafeCUDAAllocTest : public ::testing::Test { ...@@ -127,9 +127,9 @@ class StreamSafeCUDAAllocTest : public ::testing::Test {
void TearDown() override { void TearDown() override {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_CUDA_SUCCESS(cudaDeviceSynchronize()); PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
#else #else
PADDLE_ENFORCE_CUDA_SUCCESS(hipDeviceSynchronize()); PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
#endif #endif
for (gpuStream_t stream : streams_) { for (gpuStream_t stream : streams_) {
Release(place_, stream); Release(place_, stream);
...@@ -137,14 +137,14 @@ class StreamSafeCUDAAllocTest : public ::testing::Test { ...@@ -137,14 +137,14 @@ class StreamSafeCUDAAllocTest : public ::testing::Test {
for (size_t i = 1; i < stream_num_; ++i) { for (size_t i = 1; i < stream_num_; ++i) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(streams_[i])); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(streams_[i]));
#else #else
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamDestroy(streams_[i])); PADDLE_ENFORCE_GPU_SUCCESS(hipStreamDestroy(streams_[i]));
#endif #endif
} }
uint64_t cuda_malloc_size = uint64_t cuda_malloc_size =
platform::RecordedCudaMallocSize(place_.GetDeviceId()); platform::RecordedGpuMallocSize(place_.GetDeviceId());
ASSERT_EQ(cuda_malloc_size, 0) << "Found " << cuda_malloc_size ASSERT_EQ(cuda_malloc_size, 0) << "Found " << cuda_malloc_size
<< " bytes memory that not released yet," << " bytes memory that not released yet,"
<< " there may be a memory leak problem"; << " there may be a memory leak problem";
...@@ -192,11 +192,11 @@ TEST(StreamSafeCUDAAllocRetryTest, RetryTest) { ...@@ -192,11 +192,11 @@ TEST(StreamSafeCUDAAllocRetryTest, RetryTest) {
platform::CUDAPlace place = platform::CUDAPlace(); platform::CUDAPlace place = platform::CUDAPlace();
gpuStream_t stream1, stream2; gpuStream_t stream1, stream2;
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreate(&stream1)); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreate(&stream1));
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreate(&stream2)); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreate(&stream2));
#else #else
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamCreate(&stream1)); PADDLE_ENFORCE_GPU_SUCCESS(hipStreamCreate(&stream1));
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamCreate(&stream2)); PADDLE_ENFORCE_GPU_SUCCESS(hipStreamCreate(&stream2));
#endif #endif
size_t available_size = platform::GpuAvailableMemToAlloc(); size_t available_size = platform::GpuAvailableMemToAlloc();
// alloc_size < available_size < 2 * alloc_size // alloc_size < available_size < 2 * alloc_size
...@@ -216,9 +216,9 @@ TEST(StreamSafeCUDAAllocRetryTest, RetryTest) { ...@@ -216,9 +216,9 @@ TEST(StreamSafeCUDAAllocRetryTest, RetryTest) {
allocation2.reset(); allocation2.reset();
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_CUDA_SUCCESS(cudaDeviceSynchronize()); PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
#else #else
PADDLE_ENFORCE_CUDA_SUCCESS(hipDeviceSynchronize()); PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
#endif #endif
Release(place, stream1); Release(place, stream1);
......
...@@ -14,11 +14,7 @@ ...@@ -14,11 +14,7 @@
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/activation_op.h" #include "paddle/fluid/operators/activation_op.h"
#ifdef PADDLE_WITH_HIP #include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/fluid/platform/miopen_desc.h"
#else
#include "paddle/fluid/platform/cudnn_desc.h"
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -14,11 +14,7 @@ ...@@ -14,11 +14,7 @@
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/activation_op.h" #include "paddle/fluid/operators/activation_op.h"
#ifdef PADDLE_WITH_HIP #include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/fluid/platform/miopen_desc.h"
#else
#include "paddle/fluid/platform/cudnn_desc.h"
#endif
namespace paddle { namespace paddle {
namespace platform { namespace platform {
...@@ -64,13 +60,13 @@ struct CudnnActivationFunctor { ...@@ -64,13 +60,13 @@ struct CudnnActivationFunctor {
x_desc.set(x); x_desc.set(x);
out_desc.set(GET_DATA_SAFELY(out, "Output", "Out", "CudnnActivation")); out_desc.set(GET_DATA_SAFELY(out, "Output", "Out", "CudnnActivation"));
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenActivationForward( PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenActivationForward(
ctx_.cudnn_handle(), act_desc.desc(), ctx_.cudnn_handle(), act_desc.desc(),
platform::CudnnDataType<T>::kOne(), x_desc.desc(), x.data<T>(), platform::CudnnDataType<T>::kOne(), x_desc.desc(), x.data<T>(),
platform::CudnnDataType<T>::kZero(), out_desc.desc(), platform::CudnnDataType<T>::kZero(), out_desc.desc(),
out->mutable_data<T>(ctx_.GetPlace()))); out->mutable_data<T>(ctx_.GetPlace())));
#else #else
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnActivationForward( PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnActivationForward(
ctx_.cudnn_handle(), act_desc.desc(), ctx_.cudnn_handle(), act_desc.desc(),
platform::CudnnDataType<T>::kOne(), x_desc.desc(), x.data<T>(), platform::CudnnDataType<T>::kOne(), x_desc.desc(), x.data<T>(),
platform::CudnnDataType<T>::kZero(), out_desc.desc(), platform::CudnnDataType<T>::kZero(), out_desc.desc(),
...@@ -108,14 +104,14 @@ struct CudnnActivationGradFunctor { ...@@ -108,14 +104,14 @@ struct CudnnActivationGradFunctor {
dout_desc.set(dout); dout_desc.set(dout);
dx_desc.set(GET_DATA_SAFELY(dx, "Output", "X@GRAD", "CudnnActivationGrad")); dx_desc.set(GET_DATA_SAFELY(dx, "Output", "X@GRAD", "CudnnActivationGrad"));
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenActivationBackward( PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenActivationBackward(
ctx_.cudnn_handle(), act_desc.desc(), ctx_.cudnn_handle(), act_desc.desc(),
platform::CudnnDataType<T>::kOne(), out_desc.desc(), out.data<T>(), platform::CudnnDataType<T>::kOne(), out_desc.desc(), out.data<T>(),
dout_desc.desc(), dout.data<T>(), x_desc.desc(), x.data<T>(), dout_desc.desc(), dout.data<T>(), x_desc.desc(), x.data<T>(),
platform::CudnnDataType<T>::kZero(), dx_desc.desc(), platform::CudnnDataType<T>::kZero(), dx_desc.desc(),
dx->mutable_data<T>(ctx_.GetPlace()))); dx->mutable_data<T>(ctx_.GetPlace())));
#else #else
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnActivationBackward( PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnActivationBackward(
ctx_.cudnn_handle(), act_desc.desc(), ctx_.cudnn_handle(), act_desc.desc(),
platform::CudnnDataType<T>::kOne(), out_desc.desc(), out.data<T>(), platform::CudnnDataType<T>::kOne(), out_desc.desc(), out.data<T>(),
dout_desc.desc(), dout.data<T>(), x_desc.desc(), x.data<T>(), dout_desc.desc(), dout.data<T>(), x_desc.desc(), x.data<T>(),
......
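The DNN-related operators follow the same recipe: the backend-specific descriptor/helper headers collapse into one gpu_dnn.h include, while the actual cudnn/miopen calls stay behind PADDLE_WITH_HIP guards as before. Sketch of the include change only, copied from the hunks above:

// Before:
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/miopen_desc.h"
#else
#include "paddle/fluid/platform/cudnn_desc.h"
#endif

// After:
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"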
...@@ -14,7 +14,7 @@ limitations under the License. */ ...@@ -14,7 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/math/math_cuda_utils.h" #include "paddle/fluid/operators/math/math_cuda_utils.h"
#include "paddle/fluid/platform/bfloat16.h" #include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -23,7 +23,7 @@ namespace cub = hipcub; ...@@ -23,7 +23,7 @@ namespace cub = hipcub;
#include "paddle/fluid/framework/data_layout.h" #include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -16,7 +16,7 @@ limitations under the License. */ ...@@ -16,7 +16,7 @@ limitations under the License. */
// HIP not support cudnnSpatialTfGridGeneratorForward // HIP not support cudnnSpatialTfGridGeneratorForward
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/cudnn_helper.h" #include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -108,7 +108,7 @@ class CUDNNAffineGridGradOpKernel : public framework::OpKernel<T> { ...@@ -108,7 +108,7 @@ class CUDNNAffineGridGradOpKernel : public framework::OpKernel<T> {
const T* output_grad_data = output_grad->data<T>(); const T* output_grad_data = output_grad->data<T>();
T* theta_grad_data = theta_grad->mutable_data<T>(ctx.GetPlace()); T* theta_grad_data = theta_grad->mutable_data<T>(ctx.GetPlace());
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnSpatialTfGridGeneratorBackward( platform::dynload::cudnnSpatialTfGridGeneratorBackward(
handle, cudnn_st_desc, output_grad_data, theta_grad_data)); handle, cudnn_st_desc, output_grad_data, theta_grad_data));
} }
......
...@@ -18,12 +18,7 @@ limitations under the License. */ ...@@ -18,12 +18,7 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/miopen_helper.h"
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -14,9 +14,9 @@ limitations under the License. */ ...@@ -14,9 +14,9 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/affine_grid_op.h" #include "paddle/fluid/operators/affine_grid_op.h"
#include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -26,8 +26,8 @@ namespace cub = hipcub; ...@@ -26,8 +26,8 @@ namespace cub = hipcub;
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/argsort_op.h" #include "paddle/fluid/operators/argsort_op.h"
#include "paddle/fluid/operators/transpose_op.h" #include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#ifdef __HIPCC__ #ifdef __HIPCC__
namespace rocprim { namespace rocprim {
...@@ -169,7 +169,7 @@ void ArgFullSort(const platform::CUDADeviceContext& ctx, const Tensor* input, ...@@ -169,7 +169,7 @@ void ArgFullSort(const platform::CUDADeviceContext& ctx, const Tensor* input,
num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8, num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
cu_stream); cu_stream);
} }
PADDLE_ENFORCE_CUDA_SUCCESS(err); PADDLE_ENFORCE_GPU_SUCCESS(err);
Tensor temp_storage; Tensor temp_storage;
temp_storage.mutable_data<uint8_t>(ctx.GetPlace(), temp_storage_bytes); temp_storage.mutable_data<uint8_t>(ctx.GetPlace(), temp_storage_bytes);
...@@ -188,7 +188,7 @@ void ArgFullSort(const platform::CUDADeviceContext& ctx, const Tensor* input, ...@@ -188,7 +188,7 @@ void ArgFullSort(const platform::CUDADeviceContext& ctx, const Tensor* input,
cu_stream); cu_stream);
} }
PADDLE_ENFORCE_CUDA_SUCCESS(err); PADDLE_ENFORCE_GPU_SUCCESS(err);
} }
template <typename T, typename IndType> template <typename T, typename IndType>
......
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/average_accumulates_op.h" #include "paddle/fluid/operators/average_accumulates_op.h"
#include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -16,8 +16,8 @@ limitations under the License. */ ...@@ -16,8 +16,8 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/batch_fc_op.h" #include "paddle/fluid/operators/batch_fc_op.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -197,18 +197,18 @@ class BatchNormKernel<platform::CUDADeviceContext, T> ...@@ -197,18 +197,18 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
// miopenTensorDescriptor_t bn_param_desc_; // miopenTensorDescriptor_t bn_param_desc_;
// miopenBatchNormMode_t mode_; // miopenBatchNormMode_t mode_;
// PADDLE_ENFORCE_CUDA_SUCCESS( // PADDLE_ENFORCE_GPU_SUCCESS(
// platform::dynload::miopenCreateTensorDescriptor(&data_desc_)); // platform::dynload::miopenCreateTensorDescriptor(&data_desc_));
// PADDLE_ENFORCE_CUDA_SUCCESS( // PADDLE_ENFORCE_GPU_SUCCESS(
// platform::dynload::miopenCreateTensorDescriptor(&bn_param_desc_)); // platform::dynload::miopenCreateTensorDescriptor(&bn_param_desc_));
#else #else
cudnnTensorDescriptor_t data_desc_; cudnnTensorDescriptor_t data_desc_;
cudnnTensorDescriptor_t bn_param_desc_; cudnnTensorDescriptor_t bn_param_desc_;
cudnnBatchNormMode_t mode_; cudnnBatchNormMode_t mode_;
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnCreateTensorDescriptor(&data_desc_)); platform::dynload::cudnnCreateTensorDescriptor(&data_desc_));
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_)); platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_));
#endif #endif
...@@ -251,23 +251,22 @@ class BatchNormKernel<platform::CUDADeviceContext, T> ...@@ -251,23 +251,22 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
// TODO(wangran16): wait for MIOpen to improve the performance of BN // TODO(wangran16): wait for MIOpen to improve the performance of BN
// PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetTensorDescriptor( // PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenSetTensorDescriptor(
// data_desc_, CudnnDataType<T>::type, // data_desc_, CudnnDataType<T>::type,
// x_dims.size() > 3 ? x_dims.size() : 4, const_cast<int *>(dims.data()), // x_dims.size() > 3 ? x_dims.size() : 4, const_cast<int *>(dims.data()),
// const_cast<int *>(strides.data()))); // const_cast<int *>(strides.data())));
// Note: PERSISTENT not implemented for inference // Note: PERSISTENT not implemented for inference
// PADDLE_ENFORCE_CUDA_SUCCESS( // PADDLE_ENFORCE_GPU_SUCCESS(
// platform::dynload::miopenDeriveBNTensorDescriptor( // platform::dynload::miopenDeriveBNTensorDescriptor(
// bn_param_desc_, data_desc_, test_mode ? miopenBNSpatial : mode_)); // bn_param_desc_, data_desc_, test_mode ? miopenBNSpatial : mode_));
#else #else
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
data_desc_, CudnnDataType<T>::type, data_desc_, CudnnDataType<T>::type,
x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data())); x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
// Note: PERSISTENT not implemented for inference // Note: PERSISTENT not implemented for inference
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnDeriveBNTensorDescriptor(
platform::dynload::cudnnDeriveBNTensorDescriptor( bn_param_desc_, data_desc_,
bn_param_desc_, data_desc_, test_mode ? CUDNN_BATCHNORM_SPATIAL : mode_));
test_mode ? CUDNN_BATCHNORM_SPATIAL : mode_));
#endif #endif
const auto *scale = ctx.Input<Tensor>("Scale"); const auto *scale = ctx.Input<Tensor>("Scale");
...@@ -341,7 +340,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T> ...@@ -341,7 +340,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
} }
// TODO(wangran16): wait for MIOpen to improve the performance of BN // TODO(wangran16): wait for MIOpen to improve the performance of BN
// PADDLE_ENFORCE_CUDA_SUCCESS( // PADDLE_ENFORCE_GPU_SUCCESS(
// platform::dynload::miopenBatchNormalizationForwardInference( // platform::dynload::miopenBatchNormalizationForwardInference(
// handle, miopenBNSpatial, // handle, miopenBNSpatial,
// const_cast<void *>( // const_cast<void *>(
...@@ -364,7 +363,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T> ...@@ -364,7 +363,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
// est_var->template data<BatchNormParamType<T>>())), // est_var->template data<BatchNormParamType<T>>())),
// epsilon)); // epsilon));
#else #else
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnBatchNormalizationForwardInference( platform::dynload::cudnnBatchNormalizationForwardInference(
handle, handle,
// Note: PERSISTENT not implemented for inference // Note: PERSISTENT not implemented for inference
...@@ -426,7 +425,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T> ...@@ -426,7 +425,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
"The argument ReserveSpace of batch_norm op is not found.")); "The argument ReserveSpace of batch_norm op is not found."));
// --------------- cudnn batchnorm workspace --------------- // --------------- cudnn batchnorm workspace ---------------
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload:: platform::dynload::
cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize( cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
/*handle=*/handle, /*handle=*/handle,
...@@ -440,7 +439,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T> ...@@ -440,7 +439,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
/*sizeInBytes=*/&workspace_size)); /*sizeInBytes=*/&workspace_size));
// -------------- cudnn batchnorm reserve space -------------- // -------------- cudnn batchnorm reserve space --------------
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload:: platform::dynload::
cudnnGetBatchNormalizationTrainingExReserveSpaceSize( cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
/*handle=*/handle, /*handle=*/handle,
...@@ -454,7 +453,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T> ...@@ -454,7 +453,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
ctx.GetPlace(), transformed_x.type(), reserve_space_size); ctx.GetPlace(), transformed_x.type(), reserve_space_size);
workspace_ptr = workspace_tensor.mutable_data( workspace_ptr = workspace_tensor.mutable_data(
ctx.GetPlace(), transformed_x.type(), workspace_size); ctx.GetPlace(), transformed_x.type(), workspace_size);
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnBatchNormalizationForwardTrainingEx( platform::dynload::cudnnBatchNormalizationForwardTrainingEx(
handle, mode_, CUDNN_BATCHNORM_OPS_BN, CudnnDataType<T>::kOne(), handle, mode_, CUDNN_BATCHNORM_OPS_BN, CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(), data_desc_, CudnnDataType<T>::kZero(), data_desc_,
...@@ -508,7 +507,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T> ...@@ -508,7 +507,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
} }
// TODO(wangran16): wait for MIOpen to improve the performance of BN // TODO(wangran16): wait for MIOpen to improve the performance of BN
// PADDLE_ENFORCE_CUDA_SUCCESS( // PADDLE_ENFORCE_GPU_SUCCESS(
// platform::dynload::miopenBatchNormalizationForwardTraining( // platform::dynload::miopenBatchNormalizationForwardTraining(
// handle, mode_, const_cast<void *>(static_cast<const void *>( // handle, mode_, const_cast<void *>(static_cast<const void *>(
// CudnnDataType<T>::kOne())), // CudnnDataType<T>::kOne())),
...@@ -537,7 +536,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T> ...@@ -537,7 +536,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
// static_cast<void *>(saved_variance->template mutable_data< // static_cast<void *>(saved_variance->template mutable_data<
// BatchNormParamType<T>>(ctx.GetPlace())))); // BatchNormParamType<T>>(ctx.GetPlace()))));
#else #else
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnBatchNormalizationForwardTraining( platform::dynload::cudnnBatchNormalizationForwardTraining(
handle, mode_, CudnnDataType<T>::kOne(), handle, mode_, CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(), data_desc_, CudnnDataType<T>::kZero(), data_desc_,
...@@ -568,15 +567,15 @@ class BatchNormKernel<platform::CUDADeviceContext, T> ...@@ -568,15 +567,15 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
// TODO(wangran16): wait for MIOpen to improve the performance of BN // TODO(wangran16): wait for MIOpen to improve the performance of BN
// clean when exit. // clean when exit.
// PADDLE_ENFORCE_CUDA_SUCCESS( // PADDLE_ENFORCE_GPU_SUCCESS(
// platform::dynload::miopenDestroyTensorDescriptor(data_desc_)); // platform::dynload::miopenDestroyTensorDescriptor(data_desc_));
// PADDLE_ENFORCE_CUDA_SUCCESS( // PADDLE_ENFORCE_GPU_SUCCESS(
// platform::dynload::miopenDestroyTensorDescriptor(bn_param_desc_)); // platform::dynload::miopenDestroyTensorDescriptor(bn_param_desc_));
#else #else
// clean when exit. // clean when exit.
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnDestroyTensorDescriptor(data_desc_)); platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_)); platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_));
#endif #endif
} }
...@@ -981,18 +980,18 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T> ...@@ -981,18 +980,18 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
// miopenTensorDescriptor_t bn_param_desc_; // miopenTensorDescriptor_t bn_param_desc_;
// miopenBatchNormMode_t mode_; // miopenBatchNormMode_t mode_;
// PADDLE_ENFORCE_CUDA_SUCCESS( // PADDLE_ENFORCE_GPU_SUCCESS(
// platform::dynload::miopenCreateTensorDescriptor(&data_desc_)); // platform::dynload::miopenCreateTensorDescriptor(&data_desc_));
// PADDLE_ENFORCE_CUDA_SUCCESS( // PADDLE_ENFORCE_GPU_SUCCESS(
// platform::dynload::miopenCreateTensorDescriptor(&bn_param_desc_)); // platform::dynload::miopenCreateTensorDescriptor(&bn_param_desc_));
#else #else
cudnnTensorDescriptor_t data_desc_; cudnnTensorDescriptor_t data_desc_;
cudnnTensorDescriptor_t bn_param_desc_; cudnnTensorDescriptor_t bn_param_desc_;
cudnnBatchNormMode_t mode_; cudnnBatchNormMode_t mode_;
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnCreateTensorDescriptor(&data_desc_)); platform::dynload::cudnnCreateTensorDescriptor(&data_desc_));
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_)); platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_));
#endif #endif
if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) { if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
...@@ -1022,18 +1021,18 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T> ...@@ -1022,18 +1021,18 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
// TODO(wangran16): wait for MIOpen to improve the performance of BN // TODO(wangran16): wait for MIOpen to improve the performance of BN
// PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetTensorDescriptor( // PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenSetTensorDescriptor(
// data_desc_, CudnnDataType<T>::type, // data_desc_, CudnnDataType<T>::type,
// x_dims.size() > 3 ? x_dims.size() : 4, const_cast<int *>(dims.data()), // x_dims.size() > 3 ? x_dims.size() : 4, const_cast<int *>(dims.data()),
// const_cast<int *>(strides.data()))); // const_cast<int *>(strides.data())));
// PADDLE_ENFORCE_CUDA_SUCCESS( // PADDLE_ENFORCE_GPU_SUCCESS(
// platform::dynload::miopenDeriveBNTensorDescriptor(bn_param_desc_, // platform::dynload::miopenDeriveBNTensorDescriptor(bn_param_desc_,
// data_desc_, mode_)); // data_desc_, mode_));
#else #else
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
data_desc_, CudnnDataType<T>::type, data_desc_, CudnnDataType<T>::type,
x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data())); x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnDeriveBNTensorDescriptor(bn_param_desc_, platform::dynload::cudnnDeriveBNTensorDescriptor(bn_param_desc_,
data_desc_, mode_)); data_desc_, mode_));
#endif #endif
...@@ -1063,7 +1062,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T> ...@@ -1063,7 +1062,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
Tensor workspace_tensor; Tensor workspace_tensor;
auto reserve_space_size = reserve_space->memory_size(); auto reserve_space_size = reserve_space->memory_size();
// --------------- cudnn batchnorm workspace --------------- // --------------- cudnn batchnorm workspace ---------------
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload:: platform::dynload::
cudnnGetBatchNormalizationBackwardExWorkspaceSize( cudnnGetBatchNormalizationBackwardExWorkspaceSize(
/*handle=*/dev_ctx.cudnn_handle(), /*handle=*/dev_ctx.cudnn_handle(),
...@@ -1081,7 +1080,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T> ...@@ -1081,7 +1080,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
workspace_ptr = workspace_tensor.mutable_data( workspace_ptr = workspace_tensor.mutable_data(
ctx.GetPlace(), transformed_x.type(), workspace_size); ctx.GetPlace(), transformed_x.type(), workspace_size);
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnBatchNormalizationBackwardEx( platform::dynload::cudnnBatchNormalizationBackwardEx(
/*handle=*/dev_ctx.cudnn_handle(), /*handle=*/dev_ctx.cudnn_handle(),
/*mode=*/mode_, /*mode=*/mode_,
...@@ -1151,7 +1150,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T> ...@@ -1151,7 +1150,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
} }
// TODO(wangran16): wait for MIOpen to improve the performance of BN // TODO(wangran16): wait for MIOpen to improve the performance of BN
// PADDLE_ENFORCE_CUDA_SUCCESS( // PADDLE_ENFORCE_GPU_SUCCESS(
// platform::dynload::miopenBatchNormalizationBackward( // platform::dynload::miopenBatchNormalizationBackward(
// dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(), // dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
// CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(), // CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(),
...@@ -1166,7 +1165,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T> ...@@ -1166,7 +1165,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
// ctx.GetPlace()), // ctx.GetPlace()),
// epsilon, saved_mean_data, saved_var_data)); // epsilon, saved_mean_data, saved_var_data));
#else #else
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnBatchNormalizationBackward( platform::dynload::cudnnBatchNormalizationBackward(
dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(), dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(),
...@@ -1231,15 +1230,15 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T> ...@@ -1231,15 +1230,15 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
// TODO(wangran16): wait for MIOpen to improve the performance of BN // TODO(wangran16): wait for MIOpen to improve the performance of BN
// clean when exit. // clean when exit.
// PADDLE_ENFORCE_CUDA_SUCCESS( // PADDLE_ENFORCE_GPU_SUCCESS(
// platform::dynload::miopenDestroyTensorDescriptor(data_desc_)); // platform::dynload::miopenDestroyTensorDescriptor(data_desc_));
// PADDLE_ENFORCE_CUDA_SUCCESS( // PADDLE_ENFORCE_GPU_SUCCESS(
// platform::dynload::miopenDestroyTensorDescriptor(bn_param_desc_)); // platform::dynload::miopenDestroyTensorDescriptor(bn_param_desc_));
#else #else
// clean when exit. // clean when exit.
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnDestroyTensorDescriptor(data_desc_)); platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_)); platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_));
#endif #endif
} else { } else {
......
...@@ -14,8 +14,8 @@ limitations under the License. */ ...@@ -14,8 +14,8 @@ limitations under the License. */
#include <algorithm> #include <algorithm>
#include "paddle/fluid/operators/bce_loss_op.h" #include "paddle/fluid/operators/bce_loss_op.h"
#include "paddle/fluid/operators/math.h" #include "paddle/fluid/operators/math.h"
#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/gpu_launch_config.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/hostdevice.h" #include "paddle/fluid/platform/hostdevice.h"
namespace paddle { namespace paddle {
......
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
#include <algorithm> #include <algorithm>
#include <string> #include <string>
#include "paddle/fluid/operators/bilateral_slice_op.h" #include "paddle/fluid/operators/bilateral_slice_op.h"
#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/gpu_launch_config.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -14,8 +14,8 @@ limitations under the License. */ ...@@ -14,8 +14,8 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/bincount_op.h" #include "paddle/fluid/operators/bincount_op.h"
#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/gpu_launch_config.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/hostdevice.h" #include "paddle/fluid/platform/hostdevice.h"
namespace paddle { namespace paddle {
......
...@@ -14,8 +14,8 @@ limitations under the License. */ ...@@ -14,8 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/cast_op.h" #include "paddle/fluid/operators/cast_op.h"
#include "paddle/fluid/platform/aligned_vector.h" #include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/gpu_launch_config.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -14,8 +14,8 @@ limitations under the License. */ ...@@ -14,8 +14,8 @@ limitations under the License. */
#include <iostream> #include <iostream>
#include "paddle/fluid/operators/center_loss_op.h" #include "paddle/fluid/operators/center_loss_op.h"
#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -131,27 +131,26 @@ class CholeskyGPUKernel : public framework::OpKernel<T> { ...@@ -131,27 +131,26 @@ class CholeskyGPUKernel : public framework::OpKernel<T> {
int lda, int* info) const { \ int lda, int* info) const { \
auto handle = dev_ctx.cusolver_dn_handle(); \ auto handle = dev_ctx.cusolver_dn_handle(); \
int workspace_size = 0; \ int workspace_size = 0; \
PADDLE_ENFORCE_CUDA_SUCCESS( \ PADDLE_ENFORCE_GPU_SUCCESS( \
platform::dynload::cusolverDn##C##potrf_bufferSize( \ platform::dynload::cusolverDn##C##potrf_bufferSize( \
handle, uplo, n, A, lda, &workspace_size)); \ handle, uplo, n, A, lda, &workspace_size)); \
auto workspace = memory::Alloc(dev_ctx, workspace_size); \ auto workspace = memory::Alloc(dev_ctx, workspace_size); \
T* workspace_ptr = reinterpret_cast<T*>(workspace->ptr()); \ T* workspace_ptr = reinterpret_cast<T*>(workspace->ptr()); \
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDn##C##potrf( \ PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDn##C##potrf( \
handle, uplo, n, A, lda, workspace_ptr, workspace_size, info)); \ handle, uplo, n, A, lda, workspace_ptr, workspace_size, info)); \
} }
FUNC_WITH_TYPES(POTRF_INSTANCE); FUNC_WITH_TYPES(POTRF_INSTANCE);
#if CUDA_VERSION >= 9020 && !defined(_WIN32) #if CUDA_VERSION >= 9020 && !defined(_WIN32)
#define POTRF_BATCH_INSTANCE(T, C) \ #define POTRF_BATCH_INSTANCE(T, C) \
template <> \ template <> \
void CholeskyGPUKernel<T>::PotrfBatched( \ void CholeskyGPUKernel<T>::PotrfBatched( \
const platform::CUDADeviceContext& dev_ctx, cublasFillMode_t uplo, \ const platform::CUDADeviceContext& dev_ctx, cublasFillMode_t uplo, \
int n, T* Aarray[], int lda, int* info_array, int batch_size) const { \ int n, T* Aarray[], int lda, int* info_array, int batch_size) const { \
auto handle = dev_ctx.cusolver_dn_handle(); \ auto handle = dev_ctx.cusolver_dn_handle(); \
PADDLE_ENFORCE_CUDA_SUCCESS( \ PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDn##C##potrfBatched( \
platform::dynload::cusolverDn##C##potrfBatched( \ handle, uplo, n, Aarray, lda, info_array, batch_size)); \
handle, uplo, n, Aarray, lda, info_array, batch_size)); \
} }
FUNC_WITH_TYPES(POTRF_BATCH_INSTANCE); FUNC_WITH_TYPES(POTRF_BATCH_INSTANCE);
......
...@@ -18,9 +18,9 @@ limitations under the License. */ ...@@ -18,9 +18,9 @@ limitations under the License. */
#include "cinn/runtime/cinn_runtime.h" #include "cinn/runtime/cinn_runtime.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/device/gpu/gpu_types.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/type_defs.h"
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include <cuda_runtime.h> #include <cuda_runtime.h>
...@@ -45,9 +45,9 @@ void CUDART_CB ReleaseBuffers(void* data) { ...@@ -45,9 +45,9 @@ void CUDART_CB ReleaseBuffers(void* data) {
template <> template <>
void ReleaseResource<platform::CUDADeviceContext>( void ReleaseResource<platform::CUDADeviceContext>(
const std::vector<void*>& resources, void* stream) { const std::vector<void*>& resources, void* stream) {
PADDLE_ENFORCE_CUDA_SUCCESS(cudaLaunchHostFunc( PADDLE_ENFORCE_GPU_SUCCESS(cudaLaunchHostFunc(
static_cast<gpuStream_t>(stream), ReleaseScope, resources[0])); static_cast<gpuStream_t>(stream), ReleaseScope, resources[0]));
PADDLE_ENFORCE_CUDA_SUCCESS(cudaLaunchHostFunc( PADDLE_ENFORCE_GPU_SUCCESS(cudaLaunchHostFunc(
static_cast<gpuStream_t>(stream), ReleaseBuffers, resources[1])); static_cast<gpuStream_t>(stream), ReleaseBuffers, resources[1]));
} }
......
...@@ -30,7 +30,7 @@ namespace cub = hipcub; ...@@ -30,7 +30,7 @@ namespace cub = hipcub;
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h" #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif #endif
namespace paddle { namespace paddle {
...@@ -335,7 +335,7 @@ class ClassCenterSampleCUDAKernel : public framework::OpKernel<T> { ...@@ -335,7 +335,7 @@ class ClassCenterSampleCUDAKernel : public framework::OpKernel<T> {
static_cast<platform::CUDADeviceContext*>( static_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(ctx.GetPlace())) platform::DeviceContextPool::Instance().Get(ctx.GetPlace()))
->stream(); ->stream();
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce( PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
num_classes_per_device_ptr, num_classes_per_device_ptr, num_classes_per_device_ptr, num_classes_per_device_ptr,
num_classes_per_device.numel(), num_classes_per_device.numel(),
platform::ToNCCLDataType(num_classes_per_device.type()), ncclSum, platform::ToNCCLDataType(num_classes_per_device.type()), ncclSum,
...@@ -346,13 +346,13 @@ class ClassCenterSampleCUDAKernel : public framework::OpKernel<T> { ...@@ -346,13 +346,13 @@ class ClassCenterSampleCUDAKernel : public framework::OpKernel<T> {
// step 2: Determine temporary device storage requirements // step 2: Determine temporary device storage requirements
int num_buffer_ele = std::max(batch_size, num_classes); int num_buffer_ele = std::max(batch_size, num_classes);
size_t cub_sort_temp_store_size = 0; size_t cub_sort_temp_store_size = 0;
PADDLE_ENFORCE_CUDA_SUCCESS((cub::DeviceRadixSort::SortPairs<T, T>( PADDLE_ENFORCE_GPU_SUCCESS((cub::DeviceRadixSort::SortPairs<T, T>(
nullptr, cub_sort_temp_store_size, nullptr, nullptr, nullptr, nullptr, nullptr, cub_sort_temp_store_size, nullptr, nullptr, nullptr, nullptr,
num_buffer_ele, 0, sizeof(T) * 8, ctx.cuda_device_context().stream()))); num_buffer_ele, 0, sizeof(T) * 8, ctx.cuda_device_context().stream())));
size_t cub_sum_temp_store_size = 0; size_t cub_sum_temp_store_size = 0;
NotEqualToPreviousAdjacentIterator<T> unique_counting_iter_temp(nullptr, 0); NotEqualToPreviousAdjacentIterator<T> unique_counting_iter_temp(nullptr, 0);
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
(cub::DeviceScan::InclusiveSum<NotEqualToPreviousAdjacentIterator<T>, (cub::DeviceScan::InclusiveSum<NotEqualToPreviousAdjacentIterator<T>,
T*>( T*>(
nullptr, cub_sum_temp_store_size, unique_counting_iter_temp, nullptr, cub_sum_temp_store_size, unique_counting_iter_temp,
...@@ -360,7 +360,7 @@ class ClassCenterSampleCUDAKernel : public framework::OpKernel<T> { ...@@ -360,7 +360,7 @@ class ClassCenterSampleCUDAKernel : public framework::OpKernel<T> {
size_t cub_scan_temp_store_size = 0; size_t cub_scan_temp_store_size = 0;
ActualNumSampledFunctor<T> actual_num_sampled_op_temp(num_samples); ActualNumSampledFunctor<T> actual_num_sampled_op_temp(num_samples);
PADDLE_ENFORCE_CUDA_SUCCESS((cub::DeviceScan::InclusiveScan( PADDLE_ENFORCE_GPU_SUCCESS((cub::DeviceScan::InclusiveScan(
nullptr, cub_scan_temp_store_size, num_classes_per_device_ptr, nullptr, cub_scan_temp_store_size, num_classes_per_device_ptr,
num_classes_per_device_ptr, actual_num_sampled_op_temp, nranks + 1, num_classes_per_device_ptr, actual_num_sampled_op_temp, nranks + 1,
ctx.cuda_device_context().stream()))); ctx.cuda_device_context().stream())));
...@@ -384,7 +384,7 @@ class ClassCenterSampleCUDAKernel : public framework::OpKernel<T> { ...@@ -384,7 +384,7 @@ class ClassCenterSampleCUDAKernel : public framework::OpKernel<T> {
void* cub_temp_storage_ptr = memory_buffer.cub_temp_storage_ptr(); void* cub_temp_storage_ptr = memory_buffer.cub_temp_storage_ptr();
// step 4: Calculate class interval among nranks // step 4: Calculate class interval among nranks
PADDLE_ENFORCE_CUDA_SUCCESS((cub::DeviceScan::InclusiveSum( PADDLE_ENFORCE_GPU_SUCCESS((cub::DeviceScan::InclusiveSum(
cub_temp_storage_ptr, cub_temp_storage_bytes, cub_temp_storage_ptr, cub_temp_storage_bytes,
num_classes_per_device_ptr, class_interval_ptr, nranks + 1, num_classes_per_device_ptr, class_interval_ptr, nranks + 1,
ctx.cuda_device_context().stream()))); ctx.cuda_device_context().stream())));
...@@ -415,13 +415,13 @@ class ClassCenterSampleCUDAKernel : public framework::OpKernel<T> { ...@@ -415,13 +415,13 @@ class ClassCenterSampleCUDAKernel : public framework::OpKernel<T> {
// step 7: sort class center by ascending, so that positive class center // step 7: sort class center by ascending, so that positive class center
// always be sampled. // always be sampled.
PADDLE_ENFORCE_CUDA_SUCCESS((cub::DeviceRadixSort::SortPairs<T, T>( PADDLE_ENFORCE_GPU_SUCCESS((cub::DeviceRadixSort::SortPairs<T, T>(
cub_temp_storage_ptr, cub_temp_storage_bytes, cub_sort_keys_ptr, cub_temp_storage_ptr, cub_temp_storage_bytes, cub_sort_keys_ptr,
cub_sort_keys_out_ptr, cub_sort_values_ptr, cub_sort_values_out_ptr, cub_sort_keys_out_ptr, cub_sort_values_ptr, cub_sort_values_out_ptr,
num_classes, 0, sizeof(T) * 8, ctx.cuda_device_context().stream()))); num_classes, 0, sizeof(T) * 8, ctx.cuda_device_context().stream())));
// step 8: sort input label ascending // step 8: sort input label ascending
PADDLE_ENFORCE_CUDA_SUCCESS((cub::DeviceRadixSort::SortPairs<T, T>( PADDLE_ENFORCE_GPU_SUCCESS((cub::DeviceRadixSort::SortPairs<T, T>(
cub_temp_storage_ptr, cub_temp_storage_bytes, label->data<T>(), cub_temp_storage_ptr, cub_temp_storage_bytes, label->data<T>(),
cub_sort_keys_out_ptr, cub_sort_values_ptr, cub_sort_keys_ptr, cub_sort_keys_out_ptr, cub_sort_values_ptr, cub_sort_keys_ptr,
batch_size, 0, sizeof(T) * 8, ctx.cuda_device_context().stream()))); batch_size, 0, sizeof(T) * 8, ctx.cuda_device_context().stream())));
...@@ -430,8 +430,8 @@ class ClassCenterSampleCUDAKernel : public framework::OpKernel<T> { ...@@ -430,8 +430,8 @@ class ClassCenterSampleCUDAKernel : public framework::OpKernel<T> {
// label // label
NotEqualToPreviousAdjacentIterator<T> unique_counting_iter( NotEqualToPreviousAdjacentIterator<T> unique_counting_iter(
cub_sort_keys_out_ptr, 0); cub_sort_keys_out_ptr, 0);
PADDLE_ENFORCE_CUDA_SUCCESS((cub::DeviceScan::InclusiveSum< PADDLE_ENFORCE_GPU_SUCCESS((cub::DeviceScan::InclusiveSum<
NotEqualToPreviousAdjacentIterator<T>, T*>( NotEqualToPreviousAdjacentIterator<T>, T*>(
cub_temp_storage_ptr, cub_temp_storage_bytes, unique_counting_iter, cub_temp_storage_ptr, cub_temp_storage_bytes, unique_counting_iter,
cub_sort_values_ptr, batch_size, ctx.cuda_device_context().stream()))); cub_sort_values_ptr, batch_size, ctx.cuda_device_context().stream())));
...@@ -445,13 +445,13 @@ class ClassCenterSampleCUDAKernel : public framework::OpKernel<T> { ...@@ -445,13 +445,13 @@ class ClassCenterSampleCUDAKernel : public framework::OpKernel<T> {
// Since maybe num_positive_class_center > num_samples, // Since maybe num_positive_class_center > num_samples,
// we need to ensure all positive class center per device are sampled. // we need to ensure all positive class center per device are sampled.
ActualNumSampledFunctor<T> actual_num_sampled_op(num_samples); ActualNumSampledFunctor<T> actual_num_sampled_op(num_samples);
PADDLE_ENFORCE_CUDA_SUCCESS((cub::DeviceScan::InclusiveScan( PADDLE_ENFORCE_GPU_SUCCESS((cub::DeviceScan::InclusiveScan(
cub_temp_storage_ptr, cub_temp_storage_bytes, bound_value_ptr, cub_temp_storage_ptr, cub_temp_storage_bytes, bound_value_ptr,
num_classes_per_device_ptr, actual_num_sampled_op, nranks + 1, num_classes_per_device_ptr, actual_num_sampled_op, nranks + 1,
ctx.cuda_device_context().stream()))); ctx.cuda_device_context().stream())));
// step 12: Calculate actual sampled class interval among nranks // step 12: Calculate actual sampled class interval among nranks
PADDLE_ENFORCE_CUDA_SUCCESS((cub::DeviceScan::InclusiveSum( PADDLE_ENFORCE_GPU_SUCCESS((cub::DeviceScan::InclusiveSum(
cub_temp_storage_ptr, cub_temp_storage_bytes, cub_temp_storage_ptr, cub_temp_storage_bytes,
num_classes_per_device_ptr, class_interval_ptr, nranks + 1, num_classes_per_device_ptr, class_interval_ptr, nranks + 1,
ctx.cuda_device_context().stream()))); ctx.cuda_device_context().stream())));
......
...@@ -22,7 +22,7 @@ limitations under the License. */ ...@@ -22,7 +22,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/nccl_helper.h" #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif #endif
namespace paddle { namespace paddle {
...@@ -69,15 +69,11 @@ class AllReduceOpKernel : public framework::OpKernel<T> { ...@@ -69,15 +69,11 @@ class AllReduceOpKernel : public framework::OpKernel<T> {
red_type = ncclMin; red_type = ncclMin;
break; break;
} }
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce( PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
sendbuff, recvbuff, numel, static_cast<ncclDataType_t>(dtype), red_type, sendbuff, recvbuff, numel, static_cast<ncclDataType_t>(dtype), red_type,
comm, stream)); comm, stream));
if (ctx.Attr<bool>("sync_mode")) { if (ctx.Attr<bool>("sync_mode")) {
#ifdef PADDLE_WITH_RCCL platform::GpuStreamSync(stream);
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
#endif
} }
#else #else
PADDLE_THROW(platform::errors::PreconditionNotMet( PADDLE_THROW(platform::errors::PreconditionNotMet(
......
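For the collective kernels, stream synchronization is likewise routed through one helper instead of an RCCL/CUDA branch. The before/after from the hunk above, where stream is whichever NCCL stream the kernel already holds:

// Before:
#ifdef PADDLE_WITH_RCCL
  PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream));
#else
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
#endif

// After:
platform::GpuStreamSync(stream);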
...@@ -16,7 +16,7 @@ limitations under the License. */ ...@@ -16,7 +16,7 @@ limitations under the License. */
#if defined(PADDLE_WITH_NCCL) #if defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h" #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif #endif
namespace paddle { namespace paddle {
...@@ -62,15 +62,15 @@ class AllToAllOpCUDAKernel : public framework::OpKernel<T> { ...@@ -62,15 +62,15 @@ class AllToAllOpCUDAKernel : public framework::OpKernel<T> {
auto recv_buf = out->mutable_data<T>(out_dims, place); auto recv_buf = out->mutable_data<T>(out_dims, place);
size_t offset = 0; size_t offset = 0;
send_numel /= nranks; send_numel /= nranks;
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart()); PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
for (auto i = 0; i < nranks; ++i) { for (auto i = 0; i < nranks; ++i) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclSend( PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
send_buf + offset, send_numel, dtype, i, comm->comm(), stream)); send_buf + offset, send_numel, dtype, i, comm->comm(), stream));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclRecv( PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
recv_buf + offset, send_numel, dtype, i, comm->comm(), stream)); recv_buf + offset, send_numel, dtype, i, comm->comm(), stream));
offset += send_numel; offset += send_numel;
} }
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd()); PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
#else #else
PADDLE_THROW( PADDLE_THROW(
platform::errors::Unavailable("NCCL version >= 2.7.3 is needed.")); platform::errors::Unavailable("NCCL version >= 2.7.3 is needed."));
......
...@@ -16,7 +16,7 @@ limitations under the License. */ ...@@ -16,7 +16,7 @@ limitations under the License. */
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h" #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif #endif
namespace paddle { namespace paddle {
...@@ -41,13 +41,9 @@ class BarrierOpCUDAKernel : public framework::OpKernel<T> { ...@@ -41,13 +41,9 @@ class BarrierOpCUDAKernel : public framework::OpKernel<T> {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
auto stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream(); auto stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
ncclRedOp_t nccl_red_type = ncclSum; ncclRedOp_t nccl_red_type = ncclSum;
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce( PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
sendbuff, recvbuff, numel, dtype, nccl_red_type, comm->comm(), stream)); sendbuff, recvbuff, numel, dtype, nccl_red_type, comm->comm(), stream));
#ifdef PADDLE_WITH_RCCL platform::GpuStreamSync(stream);
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
#endif
#else #else
PADDLE_THROW(platform::errors::Unavailable( PADDLE_THROW(platform::errors::Unavailable(
"PaddlePaddle should compile with NCCL.")); "PaddlePaddle should compile with NCCL."));
......
(The remaining file diffs in this change are collapsed.)