Unverified commit 01e2874a, authored by ShenLiang, committed by GitHub

Support multi-stream communication for dynamic graph distributed (#29525)

* fix fleet for multi-stream

* fix memcpy for ncclid

* use sync to solve move operation
Parent f350aa59
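The diff below spreads the allreduce of gradient fusion groups over several NCCL rings, each with its own communicator and communication stream, instead of funneling every group through ring 0. A minimal standalone C++ sketch of the scheduling idea follows (illustrative only, not part of the commit; `nrings`, `num_groups`, and `ring_to_groups` are assumed names, not Paddle APIs), mirroring the `run_order = next_group_ % nrings_` logic added to `Reducer::MarkGroupReady`:

```cpp
// Standalone sketch: round-robin assignment of gradient fusion groups to NCCL
// rings, so allreduces of consecutive groups can run on different comm streams.
#include <cstdio>
#include <vector>

int main() {
  const int nrings = 2;      // e.g. FLAGS_nccl_nrings=2 (assumed value)
  const int num_groups = 5;  // number of gradient fusion groups (assumed)
  std::vector<std::vector<int>> ring_to_groups(nrings);

  for (int group = 0; group < num_groups; ++group) {
    int run_order = group % nrings;  // ring whose comm stream runs this allreduce
    ring_to_groups[run_order].push_back(group);
  }

  for (int ring = 0; ring < nrings; ++ring) {
    std::printf("ring %d allreduces groups:", ring);
    for (int g : ring_to_groups[ring]) std::printf(" %d", g);
    std::printf("\n");
  }
  return 0;
}
```

With two rings, consecutive groups alternate between the two communication streams, so their allreduces can overlap instead of queuing on a single stream.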
@@ -16,19 +16,27 @@
 #include "paddle/fluid/imperative/all_reduce.h"

+#include <string>
+#include <utility>
+
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/nccl_helper.h"
+#include "paddle/fluid/string/string_helper.h"
+
 namespace paddle {
 namespace imperative {

+static const platform::Place &GetVarPlace(const framework::Variable &src) {
+  if (src.IsType<framework::LoDTensor>()) {
+    return src.Get<framework::LoDTensor>().place();
+#if NCCL_VERSION_CODE >= 2212
+  } else if (src.IsType<framework::SelectedRows>()) {
+    return src.Get<framework::SelectedRows>().value().place();
+#endif
+  } else {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "Cannot get unsupported variable type %s for imperative allreduce, "
+        "only "
+        "LoDTensor and SelectedRows are supported.",
+        platform::demangle(framework::ToTypeName(src.Type()))));
+  }
+}
+
 static void AllReduce(const framework::Tensor &src, framework::Tensor *dst,
-                      const ParallelStrategy &strategy, cudaStream_t stream) {
+                      const cudaStream_t stream,
+                      const platform::NCCLComm *comm) {
   const auto &place = src.place();
   PADDLE_ENFORCE_EQ(
       platform::is_gpu_place(place), true,
@@ -36,23 +44,20 @@ static void AllReduce(const framework::Tensor &src, framework::Tensor *dst,
           "Imperative mode does not support multi-CPU training yet."));
   const void *src_ptr = src.data<void>();
   dst->Resize(src.dims());
   auto *dst_ptr = dst->mutable_data(src.place(), src.type());
   auto nccl_dtype = platform::ToNCCLDataType(src.type());
-  auto comm = static_cast<platform::CUDADeviceContext *>(
-                  platform::DeviceContextPool::Instance().Get(place))
-                  ->nccl_comm();
   PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce(
-      src_ptr, dst_ptr, src.numel(), nccl_dtype, ncclSum, comm, stream));
+      src_ptr, dst_ptr, src.numel(), nccl_dtype, ncclSum, comm->comm(),
+      stream));
 }

 #if NCCL_VERSION_CODE >= 2212
 static void AllReduce(const framework::SelectedRows &src,
                       framework::SelectedRows *dst,
-                      const ParallelStrategy &strategy, cudaStream_t stream) {
+                      const ParallelStrategy &strategy,
+                      const cudaStream_t stream,
+                      const platform::NCCLComm *comm) {
   VLOG(3) << "SelectedRows AllReduce start";
   const auto &src_tensor = src.value();
   const auto &place = src_tensor.place();
@@ -65,7 +70,8 @@ static void AllReduce(const framework::SelectedRows &src,
   auto nccl_dtype = platform::ToNCCLDataType(dtype);
   auto *dev_ctx = static_cast<platform::CUDADeviceContext *>(
       platform::DeviceContextPool::Instance().Get(place));
-  auto comm = dev_ctx->nccl_comm();
+  bool use_calc_stream = (dev_ctx->stream() == stream);

   // 1. Gather rows number from all workers. Here use ncclAllGather to do this,
   // but we can use other ways to implement is in the future
@@ -74,12 +80,14 @@ static void AllReduce(const framework::SelectedRows &src,
   rows_num_vector[strategy.local_rank_] = static_cast<int64_t>(src_rows.size());
   // CUDAMutableData use CalStream
   auto *gpu_rows_num_ptr = rows_num_vector.CUDAMutableData(place);
-  if (stream != dev_ctx->stream()) dev_ctx->Wait();
+  if (!use_calc_stream) {
+    dev_ctx->Wait();
+  }
   PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
       gpu_rows_num_ptr + strategy.local_rank_, gpu_rows_num_ptr, 1, ncclInt64,
-      comm, stream));
+      comm->comm(), stream));

-  if (stream != dev_ctx->stream()) {
+  if (!use_calc_stream) {
     PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
   }
@@ -108,19 +116,21 @@ static void AllReduce(const framework::SelectedRows &src,
   auto sizeof_dtype = framework::SizeOfType(dtype);
   int64_t row_offset = 0;
-  if (stream != dev_ctx->stream()) dev_ctx->Wait();
+  if (!use_calc_stream) {
+    dev_ctx->Wait();
+  }
   for (int i = 0; i < strategy.nranks_; ++i) {
     if (cpu_rows_num_ptr[i] > 0) {
       // 2. Broadcast the rows of SelectedRows
       PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBroadcast(
           src_rows_ptr, dst_rows_ptr + row_offset, cpu_rows_num_ptr[i],
-          ncclInt64, i, comm, stream));
+          ncclInt64, i, comm->comm(), stream));
       // 3. Broadcast the tensor data of SelectedRows
       auto *dst_tensor_ptr_i = reinterpret_cast<uint8_t *>(dst_tensor_ptr) +
                                row_offset * feature_size * sizeof_dtype;
       PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBroadcast(
           src_tensor_ptr, dst_tensor_ptr_i, cpu_rows_num_ptr[i] * feature_size,
-          nccl_dtype, i, comm, stream));
+          nccl_dtype, i, comm->comm(), stream));
       row_offset += cpu_rows_num_ptr[i];
     }
   }
@@ -133,13 +143,21 @@ static void AllReduce(const framework::SelectedRows &src,
 #endif

 void AllReduce(const framework::Variable &src, framework::Variable *dst,
-               const ParallelStrategy &strategy, cudaStream_t stream) {
+               const ParallelStrategy &strategy, int ring_id,
+               bool use_calc_stream) {
+  const auto &place = GetVarPlace(src);
+  auto *dev_ctx = static_cast<platform::CUDADeviceContext *>(
+      platform::DeviceContextPool::Instance().Get(place));
+  platform::NCCLComm *comm =
+      platform::NCCLCommContext::Instance().Get(ring_id, place);
+  cudaStream_t stream = (use_calc_stream ? dev_ctx->stream() : comm->stream());
   if (src.IsType<framework::LoDTensor>()) {
     if (!dst->IsType<framework::LoDTensor>()) {
       dst->Clear();
     }
     AllReduce(src.Get<framework::LoDTensor>(),
-              dst->GetMutable<framework::LoDTensor>(), strategy, stream);
+              dst->GetMutable<framework::LoDTensor>(), stream, comm);
 #if NCCL_VERSION_CODE >= 2212
   } else if (src.IsType<framework::SelectedRows>()) {
     if (&src != dst) {
@@ -147,13 +165,16 @@ void AllReduce(const framework::Variable &src, framework::Variable *dst,
         dst->Clear();
       }
       AllReduce(src.Get<framework::SelectedRows>(),
-                dst->GetMutable<framework::SelectedRows>(), strategy, stream);
+                dst->GetMutable<framework::SelectedRows>(), strategy, stream,
+                comm);
     } else {
       // SelectedRows cannot be allreduce in-place
       framework::Variable tmp_dst;
       AllReduce(src.Get<framework::SelectedRows>(),
-                tmp_dst.GetMutable<framework::SelectedRows>(), strategy,
-                stream);
+                tmp_dst.GetMutable<framework::SelectedRows>(), strategy, stream,
+                comm);
+      // stream must synchronize to ensure accuracy of the move operation
+      PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
       *dst = std::move(tmp_dst);
     }
 #endif
@@ -165,33 +186,9 @@ void AllReduce(const framework::Variable &src, framework::Variable *dst,
   }
 }

-static const platform::Place &GetVarPlace(const framework::Variable &src) {
-  if (src.IsType<framework::LoDTensor>()) {
-    return src.Get<framework::LoDTensor>().place();
-#if NCCL_VERSION_CODE >= 2212
-  } else if (src.IsType<framework::SelectedRows>()) {
-    return src.Get<framework::SelectedRows>().value().place();
-#endif
-  } else {
-    PADDLE_THROW(platform::errors::InvalidArgument(
-        "Cannot get unsupported variable type %s for imperative allreduce, "
-        "only "
-        "LoDTensor and SelectedRows are supported.",
-        platform::demangle(framework::ToTypeName(src.Type()))));
-  }
-}
-
 void AllReduce(const framework::Variable &src, framework::Variable *dst,
                const ParallelStrategy &strategy) {
-  const auto &place = GetVarPlace(src);
-  PADDLE_ENFORCE_EQ(
-      platform::is_gpu_place(place), true,
-      platform::errors::Unimplemented(
-          "Imperative mode does not support multi-CPU training yet."));
-  auto *dev_ctx = static_cast<platform::CUDADeviceContext *>(
-      platform::DeviceContextPool::Instance().Get(place));
-  auto stream = dev_ctx->stream();
-  AllReduce(src, dst, strategy, stream);
+  AllReduce(src, dst, strategy, 0, true);
 }

 }  // namespace imperative
......
@@ -19,11 +19,17 @@
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <nccl.h>
+
+#include <string>
+#include <utility>
+
 #include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/selected_rows.h"
 #include "paddle/fluid/framework/variable.h"
 #include "paddle/fluid/imperative/nccl_context.h"
+#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/nccl_helper.h"
+#include "paddle/fluid/string/string_helper.h"

 namespace paddle {
 namespace framework {
@@ -40,7 +46,8 @@ void AllReduce(const framework::Variable &src, framework::Variable *dst,
                const ParallelStrategy &strategy);

 void AllReduce(const framework::Variable &src, framework::Variable *dst,
-               const ParallelStrategy &strategy, cudaStream_t stream);
+               const ParallelStrategy &strategy, int ring_id,
+               bool use_calc_stream);

 }  // namespace imperative
 }  // namespace paddle
......
@@ -17,8 +17,10 @@
 namespace paddle {
 namespace imperative {
 #if defined(PADDLE_WITH_NCCL)
-void NCCLParallelContext::RecvNCCLID(const std::string &ep,
-                                     ncclUniqueId *nccl_id) {
+void NCCLParallelContext::RecvNCCLID(
+    const std::string &ep,
+    std::vector<ncclUniqueId> &nccl_ids) {  // NOLINT
+  int nrings = nccl_ids.size();
   auto addr = paddle::string::Split(ep, ':');
   PADDLE_ENFORCE_EQ(
       addr.size(), 2UL,
@@ -85,14 +87,16 @@ void NCCLParallelContext::RecvNCCLID(const std::string &ep,
   }
   VLOG(3) << "recevived the ncclUniqueId";
-  memcpy(nccl_id, buffer, NCCL_UNIQUE_ID_BYTES);
+
+  memcpy(&nccl_ids[0], buffer, nrings * NCCL_UNIQUE_ID_BYTES);

   VLOG(3) << "closing the socket server: " << ep;
   close(server_fd);
 }

-void NCCLParallelContext::SendNCCLID(const std::string &ep,
-                                     ncclUniqueId *nccl_id) {
+void NCCLParallelContext::SendNCCLID(
+    const std::string &ep, const std::vector<ncclUniqueId> &nccl_ids) {
+  int nrings = nccl_ids.size();
   auto addr = paddle::string::Split(ep, ':');
   PADDLE_ENFORCE_EQ(
       addr.size(), 2UL,
@@ -100,12 +104,12 @@ void NCCLParallelContext::SendNCCLID(const std::string &ep,
           "The endpoint should contain host and port, but got %s.", ep));
   std::string host = addr[0];
   int port = std::stoi(addr[1]);
-  // struct sockaddr_in address;
   int sock = 0;
   struct sockaddr_in serv_addr;
   char buffer[1024] = {0};

-  memcpy(buffer, nccl_id, NCCL_UNIQUE_ID_BYTES);
+  memcpy(buffer, &nccl_ids[0], nrings * NCCL_UNIQUE_ID_BYTES);
   if ((sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
     PADDLE_THROW(platform::errors::Unavailable("Create socket failed."));
   }
@@ -149,40 +153,46 @@ void NCCLParallelContext::SendNCCLID(const std::string &ep,
       continue;
     }
     VLOG(3) << "sending the ncclUniqueId to " << ep;
-    send(sock, buffer, NCCL_UNIQUE_ID_BYTES, 0);
+    send(sock, buffer, NCCL_UNIQUE_ID_BYTES * nrings, 0);
     break;
   }
   close(sock);
 }

-void NCCLParallelContext::BcastNCCLId(ncclUniqueId *nccl_id, int root) {
+void NCCLParallelContext::BcastNCCLId(
+    std::vector<ncclUniqueId> &nccl_ids,  // NOLINT
+    int root) {
   if (strategy_.local_rank_ == root) {
     for (auto ep : strategy_.trainer_endpoints_) {
-      if (ep != strategy_.current_endpoint_) SendNCCLID(ep, nccl_id);
+      if (ep != strategy_.current_endpoint_) SendNCCLID(ep, nccl_ids);
     }
   } else {
-    RecvNCCLID(strategy_.current_endpoint_, nccl_id);
+    RecvNCCLID(strategy_.current_endpoint_, nccl_ids);
   }
 }

 void NCCLParallelContext::Init() {
-  for (int ring_id = 0; ring_id < strategy_.nrings_; ring_id++) {
-    ncclUniqueId nccl_id;
-    if (strategy_.local_rank_ == 0) {
-      // generate the unique ncclid on the root worker
-      platform::dynload::ncclGetUniqueId(&nccl_id);
-      BcastNCCLId(&nccl_id, 0);
-    } else {
-      BcastNCCLId(&nccl_id, 0);
-    }
-    int gpu_id = BOOST_GET_CONST(platform::CUDAPlace, place_).device;
+  std::vector<ncclUniqueId> nccl_ids;
+  nccl_ids.resize(strategy_.nrings_);
+  if (strategy_.local_rank_ == 0) {
+    // generate the unique ncclid on the root worker
+    for (size_t i = 0; i < nccl_ids.size(); ++i) {
+      platform::dynload::ncclGetUniqueId(&nccl_ids[i]);
+    }
+    BcastNCCLId(nccl_ids, 0);
+  } else {
+    BcastNCCLId(nccl_ids, 0);
+  }
+  int gpu_id = BOOST_GET_CONST(platform::CUDAPlace, place_).device;
+  for (int ring_id = 0; ring_id < strategy_.nrings_; ring_id++) {
     VLOG(0) << "init nccl context nranks: " << strategy_.nranks_
             << " local rank: " << strategy_.local_rank_ << " gpu id: " << gpu_id
             << " ring id: " << ring_id;
     // it will assign nccl_comm in CUDADeviceContext within ring_id
     platform::NCCLCommContext::Instance().CreateNCCLComm(
-        &nccl_id, strategy_.nranks_, strategy_.local_rank_, gpu_id, ring_id);
+        &nccl_ids[ring_id], strategy_.nranks_, strategy_.local_rank_, gpu_id,
+        ring_id);
   }
 }
@@ -193,15 +203,7 @@ void NCCLParallelContext::AllReduceByStream(const framework::Variable &src,
       platform::is_gpu_place(place_), true,
       platform::errors::Unimplemented(
           "Dynamic graph mode does not support multi-CPU training yet."));
-  auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place_);
-  cudaStream_t stream = nullptr;
-  if (use_calc_stream) {
-    auto dev_ctx = platform::DeviceContextPool::Instance().Get(place_);
-    stream = static_cast<platform::CUDADeviceContext *>(dev_ctx)->stream();
-  } else {
-    stream = comm->stream();
-  }
-  AllReduce(src, dst, strategy_, stream);
+  AllReduce(src, dst, strategy_, ring_id, use_calc_stream);
 }

 paddle::platform::CUDADeviceContext *NCCLParallelContext::GetDeviceContext(
......
@@ -73,6 +73,8 @@ class ParallelContext {
                                  int ring_id) = 0;
 #endif

+  inline int GetNRings() { return strategy_.nrings_; }
+
  protected:
   ParallelStrategy strategy_;
   platform::Place place_;
@@ -87,7 +89,7 @@ class NCCLParallelContext : public ParallelContext {
   ~NCCLParallelContext() {}

-  void BcastNCCLId(ncclUniqueId* nccl_id, int root);
+  void BcastNCCLId(std::vector<ncclUniqueId>& nccl_ids, int root);  // NOLINT

   void Init() override;
@@ -98,9 +100,11 @@ class NCCLParallelContext : public ParallelContext {
   paddle::platform::CUDADeviceContext* GetDeviceContext(int ring_id) override;

  protected:
-  void RecvNCCLID(const std::string& endpoint, ncclUniqueId* nccl_id);
+  void RecvNCCLID(const std::string& endpoint,
+                  std::vector<ncclUniqueId>& nccl_ids);  // NOLINT

-  void SendNCCLID(const std::string& endpoint, ncclUniqueId* nccl_id);
+  void SendNCCLID(const std::string& endpoint,
+                  const std::vector<ncclUniqueId>& nccl_ids);
 };
 #endif
......
@@ -68,7 +68,7 @@ void Group::SplitTensors(const platform::CUDADeviceContext &context) {

 std::ostream &operator<<(std::ostream &out, const Group &group) {
   const auto &vars = group.variable_indices_;
-  out << "numul: " << group.all_length_ << " ;is_sparse: " << group.is_sparse_
+  out << "numel: " << group.all_length_ << " ;is_sparse: " << group.is_sparse_
       << " ;var number: " << vars.size() << "\n";
   auto begin = vars.begin();
   auto end = vars.end();
@@ -95,6 +95,7 @@ Reducer::Reducer(const std::vector<std::shared_ptr<imperative::VarBase>> &vars,
       parallel_ctx_(parallel_ctx),
       group_size_limits_(group_size_limits) {
   VLOG(3) << "Start construct the Reducer ...";
+  nrings_ = parallel_ctx->GetNRings();
   // initialize groups
   InitializeGroups(group_indices);
   for (size_t global_var_index = 0; global_var_index < vars_.size();
@@ -109,11 +110,13 @@ Reducer::Reducer(const std::vector<std::shared_ptr<imperative::VarBase>> &vars,
   compute_stream_ = static_cast<platform::CUDADeviceContext *>(
                         platform::DeviceContextPool::Instance().Get(place_))
                         ->stream();
-  comm_stream_ = platform::NCCLCommContext::Instance().Get(0, place_)->stream();
-  // create events
+  for (int i = 0; i < nrings_; ++i) {
+    comm_streams_.emplace_back(
+        platform::NCCLCommContext::Instance().Get(i, place_)->stream());
+    comm_events_.emplace_back(platform::CudaEventResourcePool::Instance().New(
+        BOOST_GET_CONST(platform::CUDAPlace, place_).device));
+  }
   CreateGroupEvents(group_indices.size());
-  comm_enent_ = platform::CudaEventResourcePool::Instance().New(
-      BOOST_GET_CONST(platform::CUDAPlace, place_).device);

   std::call_once(once_flag_, []() {
     std::atexit([]() { Reducer::GetInstance()->ReleaseReducer(); });
@@ -121,20 +124,22 @@ Reducer::Reducer(const std::vector<std::shared_ptr<imperative::VarBase>> &vars,
 }

 void Reducer::ReleaseReducer() {
-  for (auto &event : events_) {
+  for (auto &event : group_events_) {
+    event.reset();
+  }
+  for (auto &event : comm_events_) {
     event.reset();
   }
-  comm_enent_.reset();
 }

 void Reducer::CreateGroupEvents(int group_num) {
   // release old events
-  for (auto &event : events_) {
+  for (auto &event : group_events_) {
     event.reset();
   }
-  events_.clear();
-  events_.resize(group_num);
-  for (auto &event : events_) {
+  group_events_.clear();
+  group_events_.resize(group_num);
+  for (auto &event : group_events_) {
     event = platform::CudaEventResourcePool::Instance().New(
         BOOST_GET_CONST(platform::CUDAPlace, place_).device);
   }
@@ -194,7 +199,7 @@ void Reducer::InitializeDenseGroups(
 // Each parameter will be initialized according to the group information.
 // For the sparse parameter, sparse_contents_ in the group directly points
 // to the parameter. For dense parameters, first construct an empty Tensor().
-// Then specify the actual memory in MarkVariableReady.
+// Then specify the actual memory in MarkDenseVarReady.
 void Reducer::InitializeGroups(
     const std::vector<std::vector<size_t>> &group_indices) {
   VLOG(3) << "Start initialize groups ..";
@@ -218,7 +223,6 @@ void Reducer::InitializeGroups(
     if (variable_indices_.size() == 1 &&
         is_sparse_gradient_[variable_indices_.front()]) {
       // process the sparse gradient. one sparse, one group
-      group.sparse_contents_ = first_varbase->MutableGradVar();
       group.dtype_ = first_varbase->DataType();
       group.is_sparse_ = true;
     } else {
@@ -232,7 +236,7 @@ void Reducer::InitializeGroups(
     // map variables to this group by VariableLocator
     size_t inside_group_index = 0;
-    for (const auto var_index : group_indices[group_index]) {
+    for (const auto var_index : variable_indices_) {
       variable_locators_[var_index] = VariableLocator{
           .group_index = group_index,
           .inside_group_index = inside_group_index++,
@@ -260,7 +264,7 @@ void Reducer::PrepareForBackward() {
 // Add hook function to each leaf node. When the gradient of a leaf node is
 // generated, if it is the sparse parameter, it will directly execute allreduce,
 // if it is the dense parameter, it will execute three steps: 1,
-// MarkVariableReady. Find the position of the corresponding group
+// MarkDenseVarReady. Find the position of the corresponding group
 // through var_index, share the gradient memory and the group dense_tensors,
 // the group counter is reduced by 1. 2, MarkGroupReady: When the group
 // counter is 0, it means that allreduce can be emitted, and
@@ -278,8 +282,11 @@ void Reducer::AddDistHook(VariableWrapper *var_warpper, size_t var_index) {
   if (!group.is_sparse_) {
     // Only dense_contents_ need memory copy
-    MarkVariableReady(var_index, var_warpper);
+    MarkDenseVarReady(var_index, var_warpper);
+  } else {
+    MarkSparseVarReady(var_index, var_warpper);
   }

   if (--group.pending_ == 0) {
     // can start allreduce
     MarkGroupReady(group_index);
@@ -290,7 +297,7 @@ void Reducer::AddDistHook(VariableWrapper *var_warpper, size_t var_index) {
   }
 }

-void Reducer::MarkVariableReady(size_t var_index,
+void Reducer::MarkDenseVarReady(size_t var_index,
                                 VariableWrapper *var_warpper) {
   const auto &var_locator = variable_locators_[var_index];
   auto group_index = var_locator.group_index;
@@ -303,6 +310,14 @@ void Reducer::MarkVariableReady(size_t var_index,
                             {static_cast<int64_t>(length)});
 }

+void Reducer::MarkSparseVarReady(size_t var_index,
+                                 VariableWrapper *var_warpper) {
+  const auto &var_locator = variable_locators_[var_index];
+  auto group_index = var_locator.group_index;
+  auto &group = groups_[group_index];
+  group.sparse_contents_ = var_warpper->MutableVar();
+}
+
 void Reducer::MarkGroupReady(size_t group_index) {
   if (group_index > next_group_) {
     VLOG(3) << "It will adjust the order of group in next batch automatically";
@@ -310,29 +325,35 @@ void Reducer::MarkGroupReady(size_t group_index) {
   }

   PADDLE_ENFORCE_CUDA_SUCCESS(
-      cudaEventRecord(events_[group_index].get(), compute_stream_));
-  PADDLE_ENFORCE_CUDA_SUCCESS(
-      cudaStreamWaitEvent(comm_stream_, events_[group_index].get(), 0));
+      cudaEventRecord(group_events_[group_index].get(), compute_stream_));
+
+  for (int i = 0; i < nrings_; ++i) {
+    PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamWaitEvent(
+        comm_streams_[i], group_events_[group_index].get(), 0));
+  }

   for (; next_group_ < groups_.size() && groups_[next_group_].pending_ == 0;
        ++next_group_) {
     auto &group = groups_[next_group_];
+    int run_order = next_group_ % nrings_;
     if (group.is_sparse_) {
-      VLOG(3) << "sparse group [" << next_group_ << "] start allreduce...";
-      parallel_ctx_->AllReduceByStream(*group.sparse_contents_,
-                                       group.sparse_contents_, 0, false);
+      VLOG(3) << "sparse group [" << next_group_ << "] start allreduce in ring["
+              << run_order << "]";
+      parallel_ctx_->AllReduceByStream(
+          *group.sparse_contents_, group.sparse_contents_, run_order, false);
     } else {
-      VLOG(3) << "dense group [" << next_group_ << "] start allreduce...";
+      VLOG(3) << "dense group [" << next_group_ << "] start allreduce in ring["
+              << run_order << "]";
       // Select common commstream to concat tensors
       // group.dense_tensors ---> group.dense_contents_
-      group.ConcatTensors(*parallel_ctx_->GetDeviceContext(0));
+      group.ConcatTensors(*parallel_ctx_->GetDeviceContext(run_order));
       // Start allreduce
-      parallel_ctx_->AllReduceByStream(group.dense_contents_,
-                                       &(group.dense_contents_), 0, false);
+      parallel_ctx_->AllReduceByStream(
+          group.dense_contents_, &(group.dense_contents_), run_order, false);
       // Select common commstream to split tensors
       // group.dense_contents_ ---> group.dense_tensors
-      group.SplitTensors(*parallel_ctx_->GetDeviceContext(0));
+      group.SplitTensors(*parallel_ctx_->GetDeviceContext(run_order));
     }
   }
 }
@@ -351,9 +372,16 @@ std::vector<std::vector<size_t>> Reducer::RebuildGruops() {
 }

 void Reducer::FinalizeBackward() {
-  PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(comm_enent_.get(), comm_stream_));
-  PADDLE_ENFORCE_CUDA_SUCCESS(
-      cudaStreamWaitEvent(compute_stream_, comm_enent_.get(), 0));
+  // Must prevent compute_stream_ starting until all comm streams have finished
+  for (int i = 0; i < nrings_; ++i) {
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        cudaEventRecord(comm_events_[i].get(), comm_streams_[i]));
+  }
+  for (int i = 0; i < nrings_; ++i) {
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        cudaStreamWaitEvent(compute_stream_, comm_events_[i].get(), 0));
+  }

   if (!has_rebuilt_group_) {
     VLOG(3) << "Start rebuilding the groups";
     auto rebuild_group_indices = RebuildGruops();
@@ -362,6 +390,7 @@ void Reducer::FinalizeBackward() {
     CreateGroupEvents(rebuild_group_number);
     InitializeGroups(group_indices_);
   }

   VLOG(3) << "In the batch, Reducer is finished...";
 }
......
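The reducer changes above record one CUDA event per communication stream in FinalizeBackward and make the compute stream wait on all of them before later work (for example, the optimizer step) is enqueued. A small standalone sketch of that synchronization pattern using the plain CUDA runtime API follows (assumes a CUDA-capable device; stream and event names are illustrative and the actual allreduce launches are elided):

```cpp
// Sketch: compute stream waits on events recorded on several comm streams,
// so nothing enqueued later on the compute stream can overtake the allreduces.
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

#define CHECK(cmd)                                          \
  do {                                                      \
    cudaError_t e = (cmd);                                  \
    if (e != cudaSuccess) {                                 \
      std::printf("CUDA error: %s\n", cudaGetErrorString(e)); \
      return 1;                                             \
    }                                                       \
  } while (0)

int main() {
  const int nrings = 2;  // number of communication streams (assumed)
  cudaStream_t compute_stream;
  CHECK(cudaStreamCreate(&compute_stream));

  std::vector<cudaStream_t> comm_streams(nrings);
  std::vector<cudaEvent_t> comm_events(nrings);
  for (int i = 0; i < nrings; ++i) {
    CHECK(cudaStreamCreate(&comm_streams[i]));
    CHECK(cudaEventCreateWithFlags(&comm_events[i], cudaEventDisableTiming));
  }

  // ... allreduce kernels would be launched on comm_streams here ...

  // Record an event on each comm stream, then make the compute stream wait on
  // every one of them before any subsequent work is enqueued.
  for (int i = 0; i < nrings; ++i) {
    CHECK(cudaEventRecord(comm_events[i], comm_streams[i]));
  }
  for (int i = 0; i < nrings; ++i) {
    CHECK(cudaStreamWaitEvent(compute_stream, comm_events[i], 0));
  }

  CHECK(cudaStreamSynchronize(compute_stream));
  for (int i = 0; i < nrings; ++i) {
    CHECK(cudaEventDestroy(comm_events[i]));
    CHECK(cudaStreamDestroy(comm_streams[i]));
  }
  CHECK(cudaStreamDestroy(compute_stream));
  std::printf("compute stream synchronized with %d comm streams\n", nrings);
  return 0;
}
```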
@@ -16,6 +16,7 @@

 #include <algorithm>
 #include <iostream>
+#include <map>
 #include <memory>
 #include <string>
 #include <unordered_map>
@@ -133,7 +134,9 @@ class Reducer {

   void AddDistHook(VariableWrapper* var_warpper, size_t var_index);

-  void MarkVariableReady(size_t var_index, VariableWrapper* var_warpper);
+  void MarkDenseVarReady(size_t var_index, VariableWrapper* var_warpper);
+
+  void MarkSparseVarReady(size_t var_index, VariableWrapper* var_warpper);

   void MarkGroupReady(size_t group_index);
@@ -180,10 +183,11 @@ class Reducer {
   std::vector<VariableLocator> variable_locators_;

   // Following variables are to help sync stream
-  std::vector<std::shared_ptr<platform::CudaEventObject>> events_;
-  std::shared_ptr<platform::CudaEventObject> comm_enent_;
+  std::vector<std::shared_ptr<platform::CudaEventObject>> group_events_;
+  std::vector<std::shared_ptr<platform::CudaEventObject>> comm_events_;
   cudaStream_t compute_stream_;
-  cudaStream_t comm_stream_;
+  std::vector<cudaStream_t> comm_streams_;
+  int nrings_ = 1;

   // Following variables are to help rebuild group
   bool has_rebuilt_group_{false};
......
@@ -19,6 +19,7 @@
 namespace imperative = paddle::imperative;
 namespace platform = paddle::platform;

+int nrings = 2;
 imperative::ParallelStrategy GetStrategy(int local_rank) {
   std::vector<std::string> eps = {"127.0.0.1:9866", "localhost:9867"};
   imperative::ParallelStrategy strategy;
@@ -26,27 +27,38 @@ imperative::ParallelStrategy GetStrategy(int local_rank) {
   strategy.current_endpoint_ = eps[local_rank];
   strategy.nranks_ = 2;
   strategy.local_rank_ = local_rank;
+  strategy.nrings_ = nrings;
   return strategy;
 }

 #if defined(PADDLE_WITH_NCCL)
-void BcastNCCLId(int local_rank, ncclUniqueId *nccl_id) {
+void BcastNCCLId(int local_rank, std::vector<ncclUniqueId>* nccl_ids) {
   auto strategy = GetStrategy(local_rank);
   platform::CUDAPlace gpu(local_rank);
   imperative::NCCLParallelContext ctx(strategy, gpu);
-  ctx.BcastNCCLId(nccl_id, 0);
+  ctx.BcastNCCLId(*nccl_ids, 0);
 }

 TEST(BcastNCCLId, Run) {
-  ncclUniqueId nccl_id;
-  platform::dynload::ncclGetUniqueId(&nccl_id);
-  std::thread t(BcastNCCLId, 0, &nccl_id);
+  std::vector<ncclUniqueId> nccl_ids;
+  nccl_ids.resize(nrings);
+  for (int i = 0; i < nrings; ++i) {
+    platform::dynload::ncclGetUniqueId(&nccl_ids[i]);
+  }

-  ncclUniqueId recv_nccl_id;
-  BcastNCCLId(1, &recv_nccl_id);
+  std::thread t(BcastNCCLId, 0, &nccl_ids);
+
+  std::vector<ncclUniqueId> recv_nccl_ids;
+  recv_nccl_ids.resize(nrings);
+  for (int i = 0; i < nrings; ++i) {
+    platform::dynload::ncclGetUniqueId(&recv_nccl_ids[i]);
+  }
+  BcastNCCLId(1, &recv_nccl_ids);

   t.join();
-  EXPECT_EQ(0, std::memcmp(nccl_id.internal, recv_nccl_id.internal,
-                           NCCL_UNIQUE_ID_BYTES));
+  for (int i = 0; i < nrings; ++i) {
+    EXPECT_EQ(0, std::memcmp(nccl_ids[i].internal, recv_nccl_ids[i].internal,
+                             NCCL_UNIQUE_ID_BYTES));
+  }
 }
 #endif
@@ -33,7 +33,7 @@ TEST(TestGroup, TestPrintGroupMessage) {
   std::stringstream stream1, stream2;
   stream1 << group;
   ASSERT_STREQ(stream1.str().c_str(),
-               "numul: 0 ;is_sparse: 0 ;var number: 0\n[]\n");
+               "numel: 0 ;is_sparse: 0 ;var number: 0\n[]\n");

   std::vector<size_t> vars;
   size_t vars_num = 102;
@@ -44,7 +44,7 @@ TEST(TestGroup, TestPrintGroupMessage) {
   group.all_length_ = 102;
   group.is_sparse_ = false;

-  std::string head = "numul: 102 ;is_sparse: 0 ;var number: 102\n";
+  std::string head = "numel: 102 ;is_sparse: 0 ;var number: 102\n";
   head = head + "[";
   auto begin = vars.begin();
   auto end = vars.end();
......
@@ -1261,7 +1261,13 @@ void BindImperative(py::module *m_ptr) {
             return self.current_endpoint_;
           },
           [](imperative::ParallelStrategy &self,
-             const std::string &ep) { self.current_endpoint_ = ep; });
+             const std::string &ep) { self.current_endpoint_ = ep; })
+      .def_property(
+          "nrings",
+          [](const imperative::ParallelStrategy &self) { return self.nrings_; },
+          [](imperative::ParallelStrategy &self, int nrings) {
+            self.nrings_ = nrings;
+          });

   m.def(
       "dygraph_partial_grad",
......
@@ -16,6 +16,7 @@ from __future__ import print_function
 import copy
 import warnings
 import paddle
+import os
 from paddle.fluid.framework import dygraph_only
 from paddle.fluid import compiler
 from .role_maker import UserDefinedRoleMaker, PaddleCloudRoleMaker, RoleMakerBase
@@ -221,6 +222,15 @@ class Fleet(object):
                 warnings.warn(
                     "The dygraph parallel environment has been initialized.")
             else:
+                # FLAGS_nccl_nrings is used for dynamic graph multi-stream communication
+                if "FLAGS_nccl_nrings" in os.environ:
+                    warnings.warn(
+                        "You have set the environment variable FLAGS_nccl_nrings "
+                        "outside the program, so the nccl_comm_num in "
+                        "DistributedStrategy will not take effect here.")
+                else:
+                    os.environ["FLAGS_nccl_nrings"] = str(
+                        self._user_defined_strategy.nccl_comm_num)
                 paddle.distributed.init_parallel_env()

     def is_first_worker(self):
......
@@ -166,6 +166,7 @@ def init_parallel_env():
     strategy.local_rank = parallel_env.rank
     strategy.trainer_endpoints = parallel_env.trainer_endpoints
     strategy.current_endpoint = parallel_env.current_endpoint
+    strategy.nrings = parallel_env.nrings

     # NOTE(chenweihang): [ why config global place here? ]
     # the dygraph mode will be set to default mode,
......
@@ -114,6 +114,11 @@ class ParallelEnv(object):
         self._trainer_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS",
                                             "").split(",")
         self._current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT", "")
+        self._nrings = int(os.getenv("FLAGS_nccl_nrings", "1"))
+        assert self._nrings > 0, \
+            "nccl_nrings must be an integer greater than 0."
+        assert self._nrings < 9, \
+            "nccl_nrings should be less than 9, which is enough in most scenarios."

     @property
     def rank(self):
@@ -211,6 +216,25 @@ class ParallelEnv(object):
         """
         return self._trainer_endpoints

+    @property
+    def nrings(self):
+        """
+        Nrings of current trainer.
+
+        Its value is equal to the value of the environment variable ``FLAGS_nccl_nrings`` . The default value is 1.
+
+        Examples:
+          .. code-block:: python
+
+            # execute this command in terminal: export FLAGS_nccl_nrings=1
+            import paddle.distributed as dist
+
+            env = dist.ParallelEnv()
+            print("The nrings is %d" % env.nrings)
+            # the number of ring is 1
+        """
+        return self._nrings
+
     # [aliases] Compatible with old method names
     local_rank = rank
     nranks = world_size
@@ -397,8 +421,8 @@ class DataParallel(layers.Layer):
         else:
             warnings.warn("The program will return to single-card operation. "
                           "Please check 1, whether you use spawn or fleetrun "
-                          "to start the program. 2. Whether it is a multi-card "
-                          "program. 3. Is the current environment multi-card.")
+                          "to start the program. 2, Whether it is a multi-card "
+                          "program. 3, Is the current environment multi-card.")

     def init_reducer(self):
         layers_param = []
@@ -424,7 +448,7 @@ class DataParallel(layers.Layer):
             if isinstance(sublayer, paddle.nn.layer.common.Embedding):
                 return sublayer._sparse
             # NOTE(shenliang03):This is for compatibility. If paddle.fluid.dygraph.Embedding
-            # is removed in the future, the judgment will also be removed here.
+            # is removed in the future, the check will also be removed here.
             if isinstance(sublayer, paddle.fluid.dygraph.Embedding):
                 return sublayer._is_sparse
             return False
......