From 1d996637e6e5972363306677de9b5659a0624a8f Mon Sep 17 00:00:00 2001
From: Qi Li
Date: Mon, 22 Feb 2021 10:57:36 +0800
Subject: [PATCH] [ROCM] update fluid imperative for rocm (part1), test=develop (#31017)

* [ROCM] update fluid imperative for rocm (part1), test=develop

* [ROCM] update reducer.cc after merge, test=develop

* update reducer cmake after merge, test=develop
---
 paddle/fluid/imperative/CMakeLists.txt        |  9 +++++--
 paddle/fluid/imperative/all_reduce.cc         | 24 +++++++++++++++----
 paddle/fluid/imperative/all_reduce.h          |  2 +-
 .../fluid/imperative/gradient_accumulator.cc  | 20 ++++++++--------
 paddle/fluid/imperative/nccl_context.cc       | 18 ++++++++++----
 paddle/fluid/imperative/nccl_context.h        | 11 +++++++--
 paddle/fluid/imperative/reducer.cc            |  9 +++----
 paddle/fluid/imperative/reducer.cu            |  2 +-
 paddle/fluid/imperative/reducer.h             |  3 ++-
 paddle/fluid/imperative/tests/CMakeLists.txt  |  4 ++--
 .../imperative/tests/nccl_context_test.cc     |  2 +-
 .../tests/test_gradient_accmulator.cc         | 10 ++++----
 paddle/fluid/imperative/tests/test_group.cc   |  4 ++--
 .../fluid/imperative/tests/test_prepare_op.cc |  2 +-
 paddle/fluid/imperative/tests/test_tracer.cc  |  4 ++--
 paddle/fluid/imperative/tracer.cc             |  2 +-
 16 files changed, 82 insertions(+), 44 deletions(-)

diff --git a/paddle/fluid/imperative/CMakeLists.txt b/paddle/fluid/imperative/CMakeLists.txt
index 22b30403a62..a24c0ac09c7 100644
--- a/paddle/fluid/imperative/CMakeLists.txt
+++ b/paddle/fluid/imperative/CMakeLists.txt
@@ -9,10 +9,15 @@ cc_library(basic_engine SRCS basic_engine.cc DEPS layer gradient_accumulator)
 cc_library(engine SRCS basic_engine.cc partial_grad_engine.cc DEPS layer gradient_accumulator)
 cc_library(imperative_profiler SRCS profiler.cc)
 if(NOT WIN32)
-  if(WITH_NCCL)
+  if(WITH_NCCL OR WITH_RCCL)
     cc_library(imperative_all_reduce SRCS all_reduce.cc DEPS collective_helper device_context selected_rows tensor)
     cc_library(nccl_context SRCS nccl_context.cc DEPS collective_helper device_context imperative_all_reduce var_type_traits)
-    nv_library(reducer SRCS reducer.cc reducer.cu DEPS layer imperative_all_reduce)
+    if(WITH_NCCL)
+      nv_library(reducer SRCS reducer.cc reducer.cu DEPS layer imperative_all_reduce)
+    endif()
+    if(WITH_RCCL)
+      hip_library(reducer SRCS reducer.cc reducer.cu DEPS layer imperative_all_reduce)
+    endif()
   endif()
   if(WITH_XPU_BKCL)
     cc_library(bkcl_context SRCS bkcl_context.cc DEPS collective_helper device_context tensor var_type_traits)
diff --git a/paddle/fluid/imperative/all_reduce.cc b/paddle/fluid/imperative/all_reduce.cc
index 3b018374f4f..b922811b4f1 100644
--- a/paddle/fluid/imperative/all_reduce.cc
+++ b/paddle/fluid/imperative/all_reduce.cc
@@ -12,11 +12,17 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#ifdef PADDLE_WITH_NCCL
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/imperative/all_reduce.h"
 
+#ifdef PADDLE_WITH_NCCL
 #include <nccl.h>
+#endif
+
+#ifdef PADDLE_WITH_RCCL
+#include <rccl.h>
+#endif
 
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/selected_rows.h"
@@ -46,7 +52,7 @@ static const platform::Place &GetVarPlace(const framework::Variable &src) {
 }
 
 static void AllReduce(const framework::Tensor &src, framework::Tensor *dst,
-                      const cudaStream_t stream,
+                      const gpuStream_t stream,
                       const platform::NCCLComm *comm) {
   const auto &place = src.place();
   PADDLE_ENFORCE_EQ(
@@ -67,7 +73,7 @@ static void AllReduce(const framework::Tensor &src, framework::Tensor *dst,
 static void AllReduce(const framework::SelectedRows &src,
                       framework::SelectedRows *dst,
                       const ParallelStrategy &strategy,
-                      const cudaStream_t stream,
+                      const gpuStream_t stream,
                       const platform::NCCLComm *comm) {
   VLOG(3) << "SelectedRows AllReduce start";
   const auto &src_tensor = src.value();
@@ -99,7 +105,11 @@ static void AllReduce(const framework::SelectedRows &src,
                                       comm->comm(), stream));
 
   if (!use_calc_stream) {
+#ifdef PADDLE_WITH_RCCL
+    PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream));
+#else
     PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
+#endif
   }
 
   const auto *cpu_rows_num_ptr = rows_num_vector.data();
@@ -176,7 +186,7 @@ void AllReduce(const framework::Variable &src, framework::Variable *dst,
       platform::DeviceContextPool::Instance().Get(place));
   platform::NCCLComm *comm =
       platform::NCCLCommContext::Instance().Get(ring_id, place);
-  cudaStream_t stream = (use_calc_stream ? dev_ctx->stream() : comm->stream());
+  gpuStream_t stream = (use_calc_stream ? dev_ctx->stream() : comm->stream());
 
   if (src.IsType<framework::LoDTensor>()) {
     if (!dst->IsType<framework::LoDTensor>()) {
@@ -199,8 +209,12 @@ void AllReduce(const framework::Variable &src, framework::Variable *dst,
     AllReduce(src.Get<framework::SelectedRows>(),
               tmp_dst.GetMutable<framework::SelectedRows>(), strategy, stream,
               comm);
-    // stream must synchronize to ensure accuracy of the move operation
+// stream must synchronize to ensure accuracy of the move operation
+#ifdef PADDLE_WITH_RCCL
+    PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream));
+#else
     PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
+#endif
     *dst = std::move(tmp_dst);
   }
 #endif
diff --git a/paddle/fluid/imperative/all_reduce.h b/paddle/fluid/imperative/all_reduce.h
index 2185c19b696..6ef528025b0 100644
--- a/paddle/fluid/imperative/all_reduce.h
+++ b/paddle/fluid/imperative/all_reduce.h
@@ -14,7 +14,7 @@
 
 #pragma once
 
-#ifdef PADDLE_WITH_NCCL
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 
 namespace paddle {
 namespace framework {
diff --git a/paddle/fluid/imperative/gradient_accumulator.cc b/paddle/fluid/imperative/gradient_accumulator.cc
index ff8494a3888..deb504a1b65 100644
--- a/paddle/fluid/imperative/gradient_accumulator.cc
+++ b/paddle/fluid/imperative/gradient_accumulator.cc
@@ -99,7 +99,7 @@ class TensorAddFunctor : public boost::static_visitor<> {
   }
 #endif
 
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   void operator()(const platform::CUDAPlace& place) {
     platform::CUDADeviceContext* ctx =
         dynamic_cast<platform::CUDADeviceContext*>(
@@ -186,7 +186,7 @@ void TensorAdd(const framework::Variable& src, framework::Variable* dst) {
 
   if (data_type == framework::proto::VarType::FP16) {
     if (platform::is_gpu_place(place)) {
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
       return TensorAddImpl<platform::CUDADeviceContext, platform::float16>(
           src_tensor, dst_tensor, place);
 #else
@@ -224,7 +224,7 @@ void SelectedRowsAddToTensor(const framework::Variable& src,
     return;                                                                  \
   }
 
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   if (paddle::platform::is_gpu_place(place)) {
     PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(platform::CUDADeviceContext, float);
     PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(platform::CUDADeviceContext, double);
@@ -232,7 +232,7 @@ void SelectedRowsAddToTensor(const framework::Variable& src,
 #endif
   PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(platform::CPUDeviceContext, float);
   PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(platform::CPUDeviceContext, double);
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 }
 #endif
 
@@ -267,7 +267,7 @@ static void SelectedRowsAddTensor(
     return;                                                                  \
   }
 
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   if (platform::is_gpu_place(place)) {
     PADDLE_SELECTED_ROWS_ADD_TENSOR(platform::CUDADeviceContext, float);
     PADDLE_SELECTED_ROWS_ADD_TENSOR(platform::CUDADeviceContext, double);
@@ -275,7 +275,7 @@ static void SelectedRowsAddTensor(
 #endif
   PADDLE_SELECTED_ROWS_ADD_TENSOR(platform::CPUDeviceContext, float);
   PADDLE_SELECTED_ROWS_ADD_TENSOR(platform::CPUDeviceContext, double);
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 }
 #endif
 
@@ -314,7 +314,7 @@ std::shared_ptr<VariableWrapper> SelectedRowsMerge(
     return dst_var;                                                          \
   }
 
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   if (paddle::platform::is_gpu_place(place)) {
     PADDLE_SELECTED_ROWS_ADD(platform::CUDADeviceContext, float);
     PADDLE_SELECTED_ROWS_ADD(platform::CUDADeviceContext, double);
@@ -322,7 +322,7 @@ std::shared_ptr<VariableWrapper> SelectedRowsMerge(
 #endif
   PADDLE_SELECTED_ROWS_ADD(platform::CPUDeviceContext, float);
   PADDLE_SELECTED_ROWS_ADD(platform::CPUDeviceContext, double);
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 }
 #endif
 
@@ -518,7 +518,7 @@ void SortedGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var,
     }
   }
 
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   if (paddle::platform::is_gpu_place(place)) {
     // sum selected rows firstly
     for (auto& var_info : tmp_grad_vars_) {
@@ -579,7 +579,7 @@ void SortedGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var,
       // Increase count
       IncreaseCurCnt();
     }
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   }
 #endif
   tmp_grad_vars_.clear();
diff --git a/paddle/fluid/imperative/nccl_context.cc b/paddle/fluid/imperative/nccl_context.cc
index 4ec23e4b7d6..eb0135d15e0 100644
--- a/paddle/fluid/imperative/nccl_context.cc
+++ b/paddle/fluid/imperative/nccl_context.cc
@@ -14,7 +14,7 @@
 
 #include "paddle/fluid/imperative/nccl_context.h"
 
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/imperative/all_reduce.h"
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/gen_comm_id_helper.h"
@@ -31,7 +31,7 @@ class Variable;
 namespace paddle {
 namespace imperative {
 
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 
 void NCCLParallelContext::BcastNCCLId(
     std::vector<ncclUniqueId> &nccl_ids,  // NOLINT
@@ -113,9 +113,14 @@ void NCCLParallelContext::WaitCompute(int ring_id) {
       platform::NCCLCommContext::Instance().Get(ring_id, place_)->stream();
   auto event = compute_events_[ring_id].get();
 
-  // compute_stream-->event-->comm_stream
+// compute_stream-->event-->comm_stream
+#ifdef PADDLE_WITH_HIP
+  PADDLE_ENFORCE_CUDA_SUCCESS(hipEventRecord(event, compute_stream));
+  PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamWaitEvent(comm_stream, event, 0));
+#else
   PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(event, compute_stream));
   PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamWaitEvent(comm_stream, event, 0));
+#endif
 }
 
 void NCCLParallelContext::WaitComm(int ring_id) {
@@ -134,9 +139,14 @@
       platform::NCCLCommContext::Instance().Get(ring_id, place_)->stream();
   auto event = comm_events_[ring_id].get();
 
-  // comm_stream-->event-->compute_stream
+// comm_stream-->event-->compute_stream
+#ifdef PADDLE_WITH_HIP
+  PADDLE_ENFORCE_CUDA_SUCCESS(hipEventRecord(event, comm_stream));
+  PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamWaitEvent(compute_stream, event, 0));
+#else
   PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(event, comm_stream));
   PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamWaitEvent(compute_stream, event, 0));
+#endif
 }
 
 #endif
diff --git a/paddle/fluid/imperative/nccl_context.h b/paddle/fluid/imperative/nccl_context.h
index 1a93f897526..51e5743aebd 100644
--- a/paddle/fluid/imperative/nccl_context.h
+++ b/paddle/fluid/imperative/nccl_context.h
@@ -17,11 +17,18 @@
 #include <memory>
 #include <vector>
 
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/platform/cuda_resource_pool.h"
+#endif
+
+#ifdef PADDLE_WITH_NCCL
 #include "paddle/fluid/platform/dynload/nccl.h"
 #endif
 
+#ifdef PADDLE_WITH_RCCL
+#include "paddle/fluid/platform/dynload/rccl.h"
+#endif
+
 #include "paddle/fluid/imperative/parallel_context.h"
 
 namespace paddle {
@@ -33,7 +40,7 @@ class Variable;
 namespace paddle {
 namespace imperative {
 
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 class NCCLParallelContext : public ParallelContext {
  public:
   explicit NCCLParallelContext(const ParallelStrategy& strategy,
diff --git a/paddle/fluid/imperative/reducer.cc b/paddle/fluid/imperative/reducer.cc
index 2289d6600f5..f8740940d04 100644
--- a/paddle/fluid/imperative/reducer.cc
+++ b/paddle/fluid/imperative/reducer.cc
@@ -27,7 +27,8 @@
 namespace paddle {
 namespace imperative {
 
-#if (defined PADDLE_WITH_NCCL) || (defined PADDLE_WITH_XPU_BKCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
+    defined(PADDLE_WITH_XPU_BKCL)
 // div the nranks
 void Group::DivNRanks(const platform::DeviceContext &context, int64_t nranks) {
   framework::Tensor *tensor =
@@ -37,7 +38,7 @@ void Group::DivNRanks(const platform::DeviceContext &context, int64_t nranks) {
           : dense_contents_.GetMutable<framework::LoDTensor>();
 
   if (platform::is_gpu_place(tensor->place())) {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
     DivNRanks(tensor, nranks, context);
 #endif
   } else if (platform::is_cpu_place(tensor->place())) {
@@ -206,7 +207,7 @@ void SplitTensorsWithType(
 void Group::ConcatTensors(const platform::DeviceContext &context) {
   auto place = context.GetPlace();
   if (platform::is_gpu_place(place)) {
-#ifdef PADDLE_WITH_NCCL
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
     ConcatTensorsWithType(
         static_cast<const platform::CUDADeviceContext &>(context),
         dense_tensors_, &dense_contents_, dtype_);
@@ -238,7 +239,7 @@ void Group::ConcatTensors(const platform::DeviceContext &context) {
 void Group::SplitTensors(const platform::DeviceContext &context) {
   auto place = context.GetPlace();
   if (platform::is_gpu_place(place)) {
-#ifdef PADDLE_WITH_NCCL
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
     SplitTensorsWithType(
         static_cast<const platform::CUDADeviceContext &>(context),
         &dense_contents_, &dense_tensors_, dtype_);
diff --git a/paddle/fluid/imperative/reducer.cu b/paddle/fluid/imperative/reducer.cu
index 96e1de5b3d1..ca233292b34 100644
--- a/paddle/fluid/imperative/reducer.cu
+++ b/paddle/fluid/imperative/reducer.cu
@@ -17,7 +17,7 @@
 namespace paddle {
 namespace imperative {
 
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 void Group::DivNRanks(framework::Tensor *tensor, int64_t nranks,
                       const platform::DeviceContext &context) {
   framework::VisitDataTypeSmall(
diff --git a/paddle/fluid/imperative/reducer.h b/paddle/fluid/imperative/reducer.h
index 1ac9f155a00..f352ad17fda 100644
--- a/paddle/fluid/imperative/reducer.h
+++ b/paddle/fluid/imperative/reducer.h
@@ -47,7 +47,8 @@ class VariableWrapper;
 namespace paddle {
 namespace imperative {
 
-#if (defined PADDLE_WITH_NCCL) || (defined PADDLE_WITH_XPU_BKCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
+    defined(PADDLE_WITH_XPU_BKCL)
 
 template <typename T>
 struct DivNRanksFunctor {
diff --git a/paddle/fluid/imperative/tests/CMakeLists.txt b/paddle/fluid/imperative/tests/CMakeLists.txt
index 353c137fbf9..adb560df77c 100644
--- a/paddle/fluid/imperative/tests/CMakeLists.txt
+++ b/paddle/fluid/imperative/tests/CMakeLists.txt
@@ -1,7 +1,7 @@
 if(WIN32)
     cc_test(nccl_context_test SRCS nccl_context_test.cc DEPS device_context)
 else()
-    if (WITH_NCCL)
+    if (WITH_NCCL OR WITH_RCCL)
         cc_test(nccl_context_test SRCS nccl_context_test.cc DEPS nccl_context)
     endif()
     if (WITH_XPU_BKCL)
@@ -16,6 +16,6 @@ cc_test(test_prepare_op SRCS test_prepare_op.cc DEPS prepared_operator op_info s
 cc_test(test_tracer SRCS test_tracer.cc DEPS tracer layer proto_desc operator op_registry variable_helper mul_op reduce_sum_op elementwise_add_op memcpy)
 cc_test(test_hooks SRCS test_hooks.cc DEPS tracer basic_engine layer proto_desc operator op_registry variable_helper mul_op elementwise_add_op memcpy)
 
-if (WITH_NCCL OR WITH_XPU_BKCL)
+if (WITH_NCCL OR WITH_RCCL OR WITH_XPU_BKCL)
   cc_test(test_group SRCS test_group.cc DEPS reducer concat_and_split memcpy)
 endif()
diff --git a/paddle/fluid/imperative/tests/nccl_context_test.cc b/paddle/fluid/imperative/tests/nccl_context_test.cc
index ab4d4add069..4967df5341d 100644
--- a/paddle/fluid/imperative/tests/nccl_context_test.cc
+++ b/paddle/fluid/imperative/tests/nccl_context_test.cc
@@ -33,7 +33,7 @@ imperative::ParallelStrategy GetStrategy(int local_rank) {
   return strategy;
 }
 
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 void BcastNCCLId(int local_rank, std::vector<ncclUniqueId>* nccl_ids) {
   auto strategy = GetStrategy(local_rank);
   platform::CUDAPlace gpu(local_rank);
diff --git a/paddle/fluid/imperative/tests/test_gradient_accmulator.cc b/paddle/fluid/imperative/tests/test_gradient_accmulator.cc
index c394ce07df3..cb4ab2e79cb 100644
--- a/paddle/fluid/imperative/tests/test_gradient_accmulator.cc
+++ b/paddle/fluid/imperative/tests/test_gradient_accmulator.cc
@@ -53,7 +53,7 @@ int TensorddTest(Place place, T t1, T t2) {
                          sizeof(T) * src_data.size());
     paddle::memory::Copy(place, dst_mutable, src_place, dst_data.data(),
                          sizeof(T) * dst_data.size());
-#if defined(PADDLE_WITH_CUDA)
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   } else {
     paddle::memory::Copy(place, src_mutable, src_place, src_data.data(),
                          sizeof(T) * src_data.size(), 0);
@@ -74,7 +74,7 @@ int TensorddTest(Place place, T t1, T t2) {
 }
 
 TEST(test_add_functor, add_functor) {
-#if defined(PADDLE_WITH_CUDA)
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   platform::CUDAPlace gpu_place(0);
 #endif
   platform::CPUPlace cpu_place;
@@ -88,7 +88,7 @@ TEST(test_add_functor, add_functor) {
   cpu_res = TensorddTest(cpu_place, static_cast<platform::float16>(1.0),
                          static_cast<platform::float16>(2.0));
   EXPECT_EQ(cpu_res, 0);
-#if defined(PADDLE_WITH_CUDA)
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   int gpu_res = 1;
   gpu_res = TensorddTest(gpu_place, 1.0, 0.0);
   EXPECT_EQ(gpu_res, 0);
@@ -107,7 +107,7 @@ TEST(test_add_functor, execption) {
   platform::CPUPlace cpu_place;
 
   ASSERT_ANY_THROW(TensorddTest(cpu_place, 1, 0));
-#if defined(PADDLE_WITH_CUDA)
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   ASSERT_ANY_THROW(TensorddTest(cuda_pinned_place, 1.0, 0.0));
   ASSERT_ANY_THROW(TensorddTest(cuda_pinned_place,
                                 static_cast<platform::float16>(1.0),
@@ -358,7 +358,7 @@ TEST(test_gradient_accumulator, test_unchange_input) {
   for (auto sort_gradient : {false, true}) {
     TestGradientAccumulatorTestUnchangeInput(platform::CPUPlace(),
                                              sort_gradient);
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
     TestGradientAccumulatorTestUnchangeInput(platform::CUDAPlace(0),
                                              sort_gradient);
 #endif
diff --git a/paddle/fluid/imperative/tests/test_group.cc b/paddle/fluid/imperative/tests/test_group.cc
index 60814dcb6cc..0c058038968 100644
--- a/paddle/fluid/imperative/tests/test_group.cc
+++ b/paddle/fluid/imperative/tests/test_group.cc
@@ -73,7 +73,7 @@ void GroupConcatSplit(Place place, size_t size) {
   }
 
   if (std::is_same<T, platform::float16>::value) {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
     paddle::memory::Copy(place, data, cpu_place, value.data(),
                          sizeof(T) * value.size(), 0);
 #endif
@@ -133,7 +133,7 @@ void GroupConcatSplit(Place place, size_t size) {
   }
 }
 
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 TEST(TestGroup, TestConcatSplit) {
   platform::CUDAPlace cuda_place(0);
   platform::CPUPlace cpu_place;
diff --git a/paddle/fluid/imperative/tests/test_prepare_op.cc b/paddle/fluid/imperative/tests/test_prepare_op.cc
index ea009a4f5a4..7d6882a4ee7 100644
--- a/paddle/fluid/imperative/tests/test_prepare_op.cc
+++ b/paddle/fluid/imperative/tests/test_prepare_op.cc
@@ -106,7 +106,7 @@ TEST(test_prepare_op, test_get_tensor_from_var) {
   ASSERT_TRUE(ts != nullptr);
 }
 
-#if defined(PADDLE_WITH_CUDA)
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 TEST(test_prepare_op, test_prepare_data) {
   std::shared_ptr<imperative::VarBase> vin(
       new imperative::VarBase(false, "vin"));
diff --git a/paddle/fluid/imperative/tests/test_tracer.cc b/paddle/fluid/imperative/tests/test_tracer.cc
index c2ead38e4c1..e3b5ff67036 100644
--- a/paddle/fluid/imperative/tests/test_tracer.cc
+++ b/paddle/fluid/imperative/tests/test_tracer.cc
@@ -195,7 +195,7 @@ TEST(test_tracer, test_track_backward_input) {
   ASSERT_EQ(y_in->GradVarBase()->GradOpNum(), 0UL);
   ASSERT_EQ(vout->GradVarBase()->GradOpNum(), 1UL);
 }
-#if defined(PADDLE_WITH_CUDA)
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 TEST(test_tracer, test_trace_op_with_multi_device_inputs) {
   // Doing an mul
   imperative::Tracer tracer;
@@ -521,7 +521,7 @@ static void TestVarOpDestructionMain(const platform::Place& place,
 TEST(test_tracer, test_var_op_destruction) {
   TestVarOpDestructionMain(platform::CPUPlace());
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   TestVarOpDestructionMain(platform::CUDAPlace(0));
 #endif
 }
diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc
index 7003e569d19..3c20c1f647a 100644
--- a/paddle/fluid/imperative/tracer.cc
+++ b/paddle/fluid/imperative/tracer.cc
@@ -201,7 +201,7 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
 void Tracer::SetExpectedPlace(platform::Place place) {
   // NOTE(wangxi): set device id before launch device kernel
   if (platform::is_gpu_place(place)) {
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
     platform::SetDeviceId(BOOST_GET_CONST(platform::CUDAPlace, place).device);
 #else
     PADDLE_THROW(platform::errors::PreconditionNotMet(
-- 
GitLab
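
Reading note: the patch above does not duplicate call sites per vendor; it leans on a small portability layer (the gpuStream_t alias plus PADDLE_WITH_HIP / PADDLE_WITH_RCCL guards) so most hunks only widen an #if condition. The sketch below is illustrative only and is not part of the patch; gpuStream_t is assumed to come from Paddle's platform headers, and the local alias here merely stands in for it so the snippet compiles on its own against the stock CUDA or HIP runtime headers.

// Minimal sketch of the CUDA/HIP dual-target pattern the patch relies on:
// one stream typedef plus preprocessor guards lets the same call site build
// for both NVIDIA (CUDA/NCCL) and AMD (ROCm/RCCL) toolchains.
// PADDLE_WITH_HIP is the build flag the patch keys on; everything else in
// this snippet is a hypothetical stand-in for illustration.
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
using gpuStream_t = hipStream_t;   // ROCm stream handle
#else
#include <cuda_runtime.h>
using gpuStream_t = cudaStream_t;  // CUDA stream handle
#endif

// One signature serves both backends; only the runtime call is guarded,
// mirroring the hipStreamSynchronize/cudaStreamSynchronize hunks above.
inline void SyncStream(gpuStream_t stream) {
#ifdef PADDLE_WITH_HIP
  hipStreamSynchronize(stream);
#else
  cudaStreamSynchronize(stream);
#endif
}

Once the stream type and the handful of runtime calls are abstracted this way, the NCCL-oriented collective code compiles essentially unchanged against RCCL, whose C API mirrors NCCL's.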