提交 47388020 编写于 作者: F fengjiayi

fix bugs

上级 2e320079
...@@ -20,10 +20,24 @@ namespace paddle { ...@@ -20,10 +20,24 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
#ifdef PADDLE_WITH_CUDA
DataBalanceOpHandle::DataBalanceOpHandle(
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *ctxs)
: local_scopes_(local_scopes), places_(places) {
if (ctxs) {
for (auto &p : places_) {
this->dev_ctxes_[p] = ctxs->DevCtx(p);
}
}
}
#else
DataBalanceOpHandle::DataBalanceOpHandle( DataBalanceOpHandle::DataBalanceOpHandle(
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places) const std::vector<platform::Place> &places)
: local_scopes_(local_scopes), places_(places) {} : local_scopes_(local_scopes), places_(places) {}
#endif
std::string DataBalanceOpHandle::Name() const { return "data balance"; } std::string DataBalanceOpHandle::Name() const { return "data balance"; }
...@@ -104,6 +118,7 @@ void DataBalanceOpHandle::RunImpl() { ...@@ -104,6 +118,7 @@ void DataBalanceOpHandle::RunImpl() {
} }
} }
const auto &balance_plan = GetBalancePlan(device_sizes); const auto &balance_plan = GetBalancePlan(device_sizes);
for (const auto &trans : balance_plan) { for (const auto &trans : balance_plan) {
for (int data_idx = 0; data_idx < data_num; ++data_idx) { for (int data_idx = 0; data_idx < data_num; ++data_idx) {
LoDTensor *src_tensor = lod_tensors[data_idx][trans[0]]; LoDTensor *src_tensor = lod_tensors[data_idx][trans[0]];
......
...@@ -19,6 +19,9 @@ ...@@ -19,6 +19,9 @@
#include "paddle/fluid/framework/details/op_handle_base.h" #include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -26,8 +29,14 @@ namespace details { ...@@ -26,8 +29,14 @@ namespace details {
struct DataBalanceOpHandle : public OpHandleBase { struct DataBalanceOpHandle : public OpHandleBase {
public: public:
#ifdef PADDLE_WITH_CUDA
DataBalanceOpHandle(const std::vector<Scope *> &local_scopes, DataBalanceOpHandle(const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places); const std::vector<platform::Place> &places,
const platform::NCCLContextMap *ctxs);
#else
DataBalanceOpHandle(const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> *places)
#endif
std::string Name() const override; std::string Name() const override;
......
...@@ -368,7 +368,12 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(SSAGraph *result, ...@@ -368,7 +368,12 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(SSAGraph *result,
void MultiDevSSAGraphBuilder::InsertDataBalanceOp( void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
SSAGraph *result, const std::vector<std::string> &datas) const { SSAGraph *result, const std::vector<std::string> &datas) const {
#ifdef PADDLE_WITH_CUDA
result->ops_.emplace_back(
new DataBalanceOpHandle(local_scopes_, places_, nccl_ctxs_));
#else
result->ops_.emplace_back(new DataBalanceOpHandle(local_scopes_, places_)); result->ops_.emplace_back(new DataBalanceOpHandle(local_scopes_, places_));
#endif
auto *op_handle = result->ops_.back().get(); auto *op_handle = result->ops_.back().get();
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i]; auto &p = places_[i];
......
...@@ -60,6 +60,7 @@ void OpHandleBase::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) { ...@@ -60,6 +60,7 @@ void OpHandleBase::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (platform::is_cpu_place(waited_ctx->GetPlace()) || events_.empty()) { if (platform::is_cpu_place(waited_ctx->GetPlace()) || events_.empty()) {
for (auto &dev_ctx : dev_ctxes_) { for (auto &dev_ctx : dev_ctxes_) {
PADDLE_ENFORCE_NOT_NULL(dev_ctx.second);
dev_ctx.second->Wait(); dev_ctx.second->Wait();
} }
} else { } else {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册