未验证 提交 76b32187 编写于 作者: Z Zeng Jinle 提交者: GitHub

fix cuda dev_ctx by event, test=develop (#20553)

上级 bd99df71
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/sparse_all_reduce_op_handle.h" #include "paddle/fluid/framework/details/sparse_all_reduce_op_handle.h"
#include <algorithm> #include <algorithm>
#include <utility>
#include "dgc/dgc.h" #include "dgc/dgc.h"
#include "paddle/fluid/framework/details/container_cast.h" #include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/reduce_and_gather.h" #include "paddle/fluid/framework/details/reduce_and_gather.h"
...@@ -41,11 +42,30 @@ SparseAllReduceOpHandle::SparseAllReduceOpHandle( ...@@ -41,11 +42,30 @@ SparseAllReduceOpHandle::SparseAllReduceOpHandle(
} }
} }
void SparseAllReduceOpHandle::WaitInputVarGenerated() {
#ifdef PADDLE_WITH_CUDA
for (auto &p : dev_ctxes_) {
if (platform::is_gpu_place(p.first)) {
int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
auto *compute_dev_ctx =
platform::DeviceContextPool::Instance().GetByPlace(
platform::CUDAPlace(dev_id));
auto *dev_ctx = static_cast<platform::CUDADeviceContext *>(p.second);
if (compute_dev_ctx->stream() != dev_ctx->stream()) {
auto &event = events_.at(dev_id);
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaEventRecord(event, compute_dev_ctx->stream()));
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaStreamWaitEvent(dev_ctx->stream(), event, 0));
}
}
}
#endif
}
void SparseAllReduceOpHandle::RunImplEncoded() { void SparseAllReduceOpHandle::RunImplEncoded() {
platform::RecordEvent record_event(Name()); platform::RecordEvent record_event(Name());
WaitInputVarGenerated();
auto in_var_handles = DynamicCast<VarHandle>(this->Inputs()); auto in_var_handles = DynamicCast<VarHandle>(this->Inputs());
auto out_var_handles = DynamicCast<VarHandle>(this->Outputs()); auto out_var_handles = DynamicCast<VarHandle>(this->Outputs());
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
...@@ -87,6 +107,8 @@ void SparseAllReduceOpHandle::RunImplEncoded() { ...@@ -87,6 +107,8 @@ void SparseAllReduceOpHandle::RunImplEncoded() {
PADDLE_ENFORCE(nranks_ > 1); PADDLE_ENFORCE(nranks_ > 1);
std::vector<std::function<void()>> all_reduce_calls; std::vector<std::function<void()>> all_reduce_calls;
std::vector<memory::AllocationPtr> allocations;
for (size_t i = 0; i < local_scopes_.size(); ++i) { for (size_t i = 0; i < local_scopes_.size(); ++i) {
auto &place = places_[i]; auto &place = places_[i];
auto &in = *ins[i]; auto &in = *ins[i];
...@@ -104,7 +126,6 @@ void SparseAllReduceOpHandle::RunImplEncoded() { ...@@ -104,7 +126,6 @@ void SparseAllReduceOpHandle::RunImplEncoded() {
int dev_id = boost::get<platform::CUDAPlace>(place).device; int dev_id = boost::get<platform::CUDAPlace>(place).device;
auto *nccl_ctxs = nccl_ctxs_->GetRunEnvNCCLCtx(run_order_, false); auto *nccl_ctxs = nccl_ctxs_->GetRunEnvNCCLCtx(run_order_, false);
auto &nccl_ctx = nccl_ctxs->at(dev_id); auto &nccl_ctx = nccl_ctxs->at(dev_id);
auto *dev_ctx = nccl_ctxs->DevCtx(dev_id);
auto stream = nccl_ctx.stream(); auto stream = nccl_ctx.stream();
auto comm = nccl_ctx.comm_; auto comm = nccl_ctx.comm_;
...@@ -112,8 +133,9 @@ void SparseAllReduceOpHandle::RunImplEncoded() { ...@@ -112,8 +133,9 @@ void SparseAllReduceOpHandle::RunImplEncoded() {
// dgc use ncclAllGather to get all the encoded data // dgc use ncclAllGather to get all the encoded data
// so the buffer need nranks. // so the buffer need nranks.
int buf_size = nranks_ * encode_size; int buf_size = nranks_ * encode_size;
auto tmp_ious_data = memory::Alloc(*dev_ctx, buf_size); auto tmp_ious_data = memory::Alloc(place, buf_size);
void *gather_buff = reinterpret_cast<void *>(tmp_ious_data->ptr()); void *gather_buff = reinterpret_cast<void *>(tmp_ious_data->ptr());
allocations.emplace_back(std::move(tmp_ious_data));
VLOG(10) << "in_numel:" << in_numel << ", out_numel:" << out_numel VLOG(10) << "in_numel:" << in_numel << ", out_numel:" << out_numel
<< ", nranks:" << nranks_ << ", gather_buf size:" << buf_size << ", nranks:" << nranks_ << ", gather_buf size:" << buf_size
...@@ -126,6 +148,7 @@ void SparseAllReduceOpHandle::RunImplEncoded() { ...@@ -126,6 +148,7 @@ void SparseAllReduceOpHandle::RunImplEncoded() {
}); });
} }
WaitInputVarGenerated();
NCCLAllReduceFunc(all_reduce_calls); NCCLAllReduceFunc(all_reduce_calls);
} }
......
...@@ -36,6 +36,8 @@ class SparseAllReduceOpHandle : public AllReduceOpHandle { ...@@ -36,6 +36,8 @@ class SparseAllReduceOpHandle : public AllReduceOpHandle {
bool is_encoded = false, int nranks = -1); bool is_encoded = false, int nranks = -1);
std::string Name() const override; std::string Name() const override;
void WaitInputVarGenerated() override;
protected: protected:
void RunImpl() override; void RunImpl() override;
int GetKValue(const std::string &grad_name); int GetKValue(const std::string &grad_name);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册