Unverified commit 1d996637, authored by Qi Li, committed by GitHub

[ROCM] update fluid imperative for rocm (part1), test=develop (#31017)

* [ROCM] update fluid imperative for rocm (part1), test=develop

* [ROCM] update reducer.cc after merge, test=develop

* update reducer cmake after merge, test=develop
Parent b95eb38b
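The diff below applies one recurring portability pattern: each `PADDLE_WITH_NCCL`-only guard gains a `PADDLE_WITH_RCCL` (ROCm) branch, `cudaStream_t` parameters become the back-end-neutral `gpuStream_t`, and direct `cuda*` runtime calls are switched to their `hip*` equivalents when building for ROCm. The sketch that follows is a hedged illustration of that pattern, not Paddle code: the macro names (`PADDLE_WITH_NCCL`, `PADDLE_WITH_RCCL`, `PADDLE_WITH_HIP`) and the name `gpuStream_t` appear in the diff itself, while the alias definitions and the `SyncStream` helper are assumptions added for the example.

// Hedged sketch of the CUDA/HIP portability pattern used throughout this commit.
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)

#ifdef PADDLE_WITH_RCCL            // ROCm build
#include <hip/hip_runtime.h>
using gpuStream_t = hipStream_t;   // assumed alias; Paddle defines its own
#else                              // CUDA build
#include <cuda_runtime.h>
using gpuStream_t = cudaStream_t;  // assumed alias; Paddle defines its own
#endif

// One code path for both back ends: the stream type is neutral and only the
// runtime call differs per platform, mirroring the #ifdef blocks added in
// all_reduce.cc below.
inline bool SyncStream(gpuStream_t stream) {
#ifdef PADDLE_WITH_RCCL
  return hipStreamSynchronize(stream) == hipSuccess;
#else
  return cudaStreamSynchronize(stream) == cudaSuccess;
#endif
}

#endif  // PADDLE_WITH_NCCL || PADDLE_WITH_RCCL

In the diff itself the synchronize calls are wrapped in PADDLE_ENFORCE_CUDA_SUCCESS; the boolean return above just keeps the sketch dependency-free.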
......@@ -9,10 +9,15 @@ cc_library(basic_engine SRCS basic_engine.cc DEPS layer gradient_accumulator)
cc_library(engine SRCS basic_engine.cc partial_grad_engine.cc DEPS layer gradient_accumulator)
cc_library(imperative_profiler SRCS profiler.cc)
if(NOT WIN32)
if(WITH_NCCL)
if(WITH_NCCL OR WITH_RCCL)
cc_library(imperative_all_reduce SRCS all_reduce.cc DEPS collective_helper device_context selected_rows tensor)
cc_library(nccl_context SRCS nccl_context.cc DEPS collective_helper device_context imperative_all_reduce var_type_traits)
nv_library(reducer SRCS reducer.cc reducer.cu DEPS layer imperative_all_reduce)
if(WITH_NCCL)
nv_library(reducer SRCS reducer.cc reducer.cu DEPS layer imperative_all_reduce)
endif()
if(WITH_RCCL)
hip_library(reducer SRCS reducer.cc reducer.cu DEPS layer imperative_all_reduce)
endif()
endif()
if(WITH_XPU_BKCL)
cc_library(bkcl_context SRCS bkcl_context.cc DEPS collective_helper device_context tensor var_type_traits)
......
......@@ -12,11 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_NCCL
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/imperative/all_reduce.h"
#ifdef PADDLE_WITH_NCCL
#include <nccl.h>
#endif
#ifdef PADDLE_WITH_RCCL
#include <rccl.h>
#endif
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
......@@ -46,7 +52,7 @@ static const platform::Place &GetVarPlace(const framework::Variable &src) {
}
static void AllReduce(const framework::Tensor &src, framework::Tensor *dst,
const cudaStream_t stream,
const gpuStream_t stream,
const platform::NCCLComm *comm) {
const auto &place = src.place();
PADDLE_ENFORCE_EQ(
......@@ -67,7 +73,7 @@ static void AllReduce(const framework::Tensor &src, framework::Tensor *dst,
static void AllReduce(const framework::SelectedRows &src,
framework::SelectedRows *dst,
const ParallelStrategy &strategy,
const cudaStream_t stream,
const gpuStream_t stream,
const platform::NCCLComm *comm) {
VLOG(3) << "SelectedRows AllReduce start";
const auto &src_tensor = src.value();
......@@ -99,7 +105,11 @@ static void AllReduce(const framework::SelectedRows &src,
comm->comm(), stream));
if (!use_calc_stream) {
#ifdef PADDLE_WITH_RCCL
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
#endif
}
const auto *cpu_rows_num_ptr = rows_num_vector.data();
......@@ -176,7 +186,7 @@ void AllReduce(const framework::Variable &src, framework::Variable *dst,
platform::DeviceContextPool::Instance().Get(place));
platform::NCCLComm *comm =
platform::NCCLCommContext::Instance().Get(ring_id, place);
cudaStream_t stream = (use_calc_stream ? dev_ctx->stream() : comm->stream());
gpuStream_t stream = (use_calc_stream ? dev_ctx->stream() : comm->stream());
if (src.IsType<framework::LoDTensor>()) {
if (!dst->IsType<framework::LoDTensor>()) {
......@@ -199,8 +209,12 @@ void AllReduce(const framework::Variable &src, framework::Variable *dst,
AllReduce(src.Get<framework::SelectedRows>(),
tmp_dst.GetMutable<framework::SelectedRows>(), strategy, stream,
comm);
// stream must synchronize to ensure accuracy of the move operation
#ifdef PADDLE_WITH_RCCL
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
#endif
*dst = std::move(tmp_dst);
}
#endif
......
......@@ -14,7 +14,7 @@
#pragma once
#ifdef PADDLE_WITH_NCCL
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
namespace paddle {
namespace framework {
......
......@@ -99,7 +99,7 @@ class TensorAddFunctor : public boost::static_visitor<> {
}
#endif
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
void operator()(const platform::CUDAPlace& place) {
platform::CUDADeviceContext* ctx =
dynamic_cast<platform::CUDADeviceContext*>(
......@@ -186,7 +186,7 @@ void TensorAdd(const framework::Variable& src, framework::Variable* dst) {
if (data_type == framework::proto::VarType::FP16) {
if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
return TensorAddImpl<platform::CUDADeviceContext, platform::float16>(
src_tensor, dst_tensor, place);
#else
......@@ -224,7 +224,7 @@ void SelectedRowsAddToTensor(const framework::Variable& src,
return; \
}
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (paddle::platform::is_gpu_place(place)) {
PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(platform::CUDADeviceContext, float);
PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(platform::CUDADeviceContext, double);
......@@ -232,7 +232,7 @@ void SelectedRowsAddToTensor(const framework::Variable& src,
#endif
PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(platform::CPUDeviceContext, float);
PADDLE_SELECTED_ROWS_ADD_TO_TENSOR(platform::CPUDeviceContext, double);
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
}
#endif
......@@ -267,7 +267,7 @@ static void SelectedRowsAddTensor(
return; \
}
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(place)) {
PADDLE_SELECTED_ROWS_ADD_TENSOR(platform::CUDADeviceContext, float);
PADDLE_SELECTED_ROWS_ADD_TENSOR(platform::CUDADeviceContext, double);
......@@ -275,7 +275,7 @@ static void SelectedRowsAddTensor(
#endif
PADDLE_SELECTED_ROWS_ADD_TENSOR(platform::CPUDeviceContext, float);
PADDLE_SELECTED_ROWS_ADD_TENSOR(platform::CPUDeviceContext, double);
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
}
#endif
......@@ -314,7 +314,7 @@ std::shared_ptr<VariableWrapper> SelectedRowsMerge(
return dst_var; \
}
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (paddle::platform::is_gpu_place(place)) {
PADDLE_SELECTED_ROWS_ADD(platform::CUDADeviceContext, float);
PADDLE_SELECTED_ROWS_ADD(platform::CUDADeviceContext, double);
......@@ -322,7 +322,7 @@ std::shared_ptr<VariableWrapper> SelectedRowsMerge(
#endif
PADDLE_SELECTED_ROWS_ADD(platform::CPUDeviceContext, float);
PADDLE_SELECTED_ROWS_ADD(platform::CPUDeviceContext, double);
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
}
#endif
......@@ -518,7 +518,7 @@ void SortedGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var,
}
}
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (paddle::platform::is_gpu_place(place)) {
// sum selected rows firstly
for (auto& var_info : tmp_grad_vars_) {
......@@ -579,7 +579,7 @@ void SortedGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var,
// Increase count
IncreaseCurCnt();
}
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
}
#endif
tmp_grad_vars_.clear();
......
......@@ -14,7 +14,7 @@
#include "paddle/fluid/imperative/nccl_context.h"
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/imperative/all_reduce.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
......@@ -31,7 +31,7 @@ class Variable;
namespace paddle {
namespace imperative {
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
void NCCLParallelContext::BcastNCCLId(
std::vector<ncclUniqueId> &nccl_ids, // NOLINT
......@@ -113,9 +113,14 @@ void NCCLParallelContext::WaitCompute(int ring_id) {
platform::NCCLCommContext::Instance().Get(ring_id, place_)->stream();
auto event = compute_events_[ring_id].get();
// compute_stream-->event-->comm_stream
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(hipEventRecord(event, compute_stream));
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamWaitEvent(comm_stream, event, 0));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(event, compute_stream));
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamWaitEvent(comm_stream, event, 0));
#endif
}
void NCCLParallelContext::WaitComm(int ring_id) {
......@@ -134,9 +139,14 @@ void NCCLParallelContext::WaitComm(int ring_id) {
platform::NCCLCommContext::Instance().Get(ring_id, place_)->stream();
auto event = comm_events_[ring_id].get();
// comm_stream-->event-->compute_stream
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(hipEventRecord(event, comm_stream));
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamWaitEvent(compute_stream, event, 0));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(event, comm_stream));
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamWaitEvent(compute_stream, event, 0));
#endif
}
#endif
......
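The WaitCompute/WaitComm hunks above order the compute stream and the communication stream with an event rather than a host-side synchronize: record the event on the producing stream, then queue a wait for it on the consuming stream. A minimal, hedged illustration of that idea follows (HIP spelling; the CUDA branch in the diff is identical with cuda* names; the function name and the boolean error handling are assumptions, standing in for Paddle's PADDLE_ENFORCE_CUDA_SUCCESS and resource pool).

#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>

// compute_stream --> event --> comm_stream, as in WaitCompute() above.
inline bool MakeCommWaitOnCompute(hipStream_t compute_stream,
                                  hipStream_t comm_stream,
                                  hipEvent_t event) {
  // Mark "everything queued on compute_stream so far".
  if (hipEventRecord(event, compute_stream) != hipSuccess) return false;
  // The wait is queued on comm_stream; the host thread does not block.
  return hipStreamWaitEvent(comm_stream, event, 0) == hipSuccess;
}
#endif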
......@@ -17,11 +17,18 @@
#include <string>
#include <vector>
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/cuda_resource_pool.h"
#endif
#ifdef PADDLE_WITH_NCCL
#include "paddle/fluid/platform/dynload/nccl.h"
#endif
#ifdef PADDLE_WITH_RCCL
#include "paddle/fluid/platform/dynload/rccl.h"
#endif
#include "paddle/fluid/imperative/parallel_context.h"
namespace paddle {
......@@ -33,7 +40,7 @@ class Variable;
namespace paddle {
namespace imperative {
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
class NCCLParallelContext : public ParallelContext {
public:
explicit NCCLParallelContext(const ParallelStrategy& strategy,
......
......@@ -27,7 +27,8 @@
namespace paddle {
namespace imperative {
#if (defined PADDLE_WITH_NCCL) || (defined PADDLE_WITH_XPU_BKCL)
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL)
// div the nranks
void Group::DivNRanks(const platform::DeviceContext &context, int64_t nranks) {
framework::Tensor *tensor =
......@@ -37,7 +38,7 @@ void Group::DivNRanks(const platform::DeviceContext &context, int64_t nranks) {
: dense_contents_.GetMutable<framework::LoDTensor>();
if (platform::is_gpu_place(tensor->place())) {
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
DivNRanks(tensor, nranks, context);
#endif
} else if (platform::is_cpu_place(tensor->place())) {
......@@ -206,7 +207,7 @@ void SplitTensorsWithType<platform::XPUDeviceContext>(
void Group::ConcatTensors(const platform::DeviceContext &context) {
auto place = context.GetPlace();
if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_NCCL
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
ConcatTensorsWithType(
static_cast<const platform::CUDADeviceContext &>(context),
dense_tensors_, &dense_contents_, dtype_);
......@@ -238,7 +239,7 @@ void Group::ConcatTensors(const platform::DeviceContext &context) {
void Group::SplitTensors(const platform::DeviceContext &context) {
auto place = context.GetPlace();
if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_NCCL
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
SplitTensorsWithType(
static_cast<const platform::CUDADeviceContext &>(context),
&dense_contents_, &dense_tensors_, dtype_);
......
......@@ -17,7 +17,7 @@
namespace paddle {
namespace imperative {
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
void Group::DivNRanks(framework::Tensor *tensor, int64_t nranks,
const platform::DeviceContext &context) {
framework::VisitDataTypeSmall(
......
......@@ -47,7 +47,8 @@ class VariableWrapper;
namespace paddle {
namespace imperative {
#if (defined PADDLE_WITH_NCCL) || (defined PADDLE_WITH_XPU_BKCL)
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL)
template <typename T>
struct DivNRanksFunctor {
......
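For context on the DivNRanks hunks above: after the all-reduce has summed the fused gradients across all trainers, each element is scaled by 1/nranks so the result is a mean. A hedged, stand-alone sketch of that step (the container type and function name here are illustrative, not Paddle's DivNRanksFunctor):

#include <cstdint>
#include <vector>

// Scale an already all-reduced (summed) gradient buffer by 1/nranks,
// turning the cross-trainer sum into a mean.
template <typename T>
void DivByNRanks(std::vector<T>* summed_grad, int64_t nranks) {
  const T scale = static_cast<T>(1) / static_cast<T>(nranks);
  for (T& v : *summed_grad) v *= scale;
}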
if(WIN32)
cc_test(nccl_context_test SRCS nccl_context_test.cc DEPS device_context)
else()
if (WITH_NCCL)
if (WITH_NCCL OR WITH_RCCL)
cc_test(nccl_context_test SRCS nccl_context_test.cc DEPS nccl_context)
endif()
if (WITH_XPU_BKCL)
......@@ -16,6 +16,6 @@ cc_test(test_prepare_op SRCS test_prepare_op.cc DEPS prepared_operator op_info s
cc_test(test_tracer SRCS test_tracer.cc DEPS tracer layer proto_desc operator op_registry variable_helper mul_op reduce_sum_op elementwise_add_op memcpy)
cc_test(test_hooks SRCS test_hooks.cc DEPS tracer basic_engine layer proto_desc operator op_registry variable_helper mul_op elementwise_add_op memcpy)
if (WITH_NCCL OR WITH_XPU_BKCL)
if (WITH_NCCL OR WITH_RCCL OR WITH_XPU_BKCL)
cc_test(test_group SRCS test_group.cc DEPS reducer concat_and_split memcpy)
endif()
......@@ -33,7 +33,7 @@ imperative::ParallelStrategy GetStrategy(int local_rank) {
return strategy;
}
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
void BcastNCCLId(int local_rank, std::vector<ncclUniqueId>* nccl_ids) {
auto strategy = GetStrategy(local_rank);
platform::CUDAPlace gpu(local_rank);
......
......@@ -53,7 +53,7 @@ int TensorddTest(Place place, T t1, T t2) {
sizeof(T) * src_data.size());
paddle::memory::Copy(place, dst_mutable, src_place, dst_data.data(),
sizeof(T) * dst_data.size());
#if defined(PADDLE_WITH_CUDA)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
} else {
paddle::memory::Copy(place, src_mutable, src_place, src_data.data(),
sizeof(T) * src_data.size(), 0);
......@@ -74,7 +74,7 @@ int TensorddTest(Place place, T t1, T t2) {
}
TEST(test_add_functor, add_functor) {
#if defined(PADDLE_WITH_CUDA)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
platform::CUDAPlace gpu_place(0);
#endif
platform::CPUPlace cpu_place;
......@@ -88,7 +88,7 @@ TEST(test_add_functor, add_functor) {
cpu_res = TensorddTest(cpu_place, static_cast<platform::float16>(1.0),
static_cast<platform::float16>(2.0));
EXPECT_EQ(cpu_res, 0);
#if defined(PADDLE_WITH_CUDA)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
int gpu_res = 1;
gpu_res = TensorddTest(gpu_place, 1.0, 0.0);
EXPECT_EQ(gpu_res, 0);
......@@ -107,7 +107,7 @@ TEST(test_add_functor, execption) {
platform::CPUPlace cpu_place;
ASSERT_ANY_THROW(TensorddTest(cpu_place, 1, 0));
#if defined(PADDLE_WITH_CUDA)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
ASSERT_ANY_THROW(TensorddTest(cuda_pinned_place, 1.0, 0.0));
ASSERT_ANY_THROW(TensorddTest(cuda_pinned_place,
static_cast<platform::float16>(1.0),
......@@ -358,7 +358,7 @@ TEST(test_gradient_accumulator, test_unchange_input) {
for (auto sort_gradient : {false, true}) {
TestGradientAccumulatorTestUnchangeInput(platform::CPUPlace(),
sort_gradient);
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
TestGradientAccumulatorTestUnchangeInput(platform::CUDAPlace(0),
sort_gradient);
#endif
......
......@@ -73,7 +73,7 @@ void GroupConcatSplit(Place place, size_t size) {
}
if (std::is_same<Place, platform::CUDAPlace>::value) {
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
paddle::memory::Copy(place, data, cpu_place, value.data(),
sizeof(T) * value.size(), 0);
#endif
......@@ -133,7 +133,7 @@ void GroupConcatSplit(Place place, size_t size) {
}
}
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
TEST(TestGroup, TestConcatSplit) {
platform::CUDAPlace cuda_place(0);
platform::CPUPlace cpu_place;
......
......@@ -106,7 +106,7 @@ TEST(test_prepare_op, test_get_tensor_from_var) {
ASSERT_TRUE(ts != nullptr);
}
#if defined(PADDLE_WITH_CUDA)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
TEST(test_prepare_op, test_prepare_data) {
std::shared_ptr<imperative::VarBase> vin(
new imperative::VarBase(false, "vin"));
......
......@@ -195,7 +195,7 @@ TEST(test_tracer, test_track_backward_input) {
ASSERT_EQ(y_in->GradVarBase()->GradOpNum(), 0UL);
ASSERT_EQ(vout->GradVarBase()->GradOpNum(), 1UL);
}
#if defined(PADDLE_WITH_CUDA)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
TEST(test_tracer, test_trace_op_with_multi_device_inputs) {
// Doing an mul
imperative::Tracer tracer;
......@@ -521,7 +521,7 @@ static void TestVarOpDestructionMain(const platform::Place& place,
TEST(test_tracer, test_var_op_destruction) {
TestVarOpDestructionMain(platform::CPUPlace());
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
TestVarOpDestructionMain(platform::CUDAPlace(0));
#endif
}
......
......@@ -201,7 +201,7 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
void Tracer::SetExpectedPlace(platform::Place place) {
// NOTE(wangxi): set device id before launch device kernel
if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
platform::SetDeviceId(BOOST_GET_CONST(platform::CUDAPlace, place).device);
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
......