diff --git a/paddle/fluid/framework/details/broadcast_op_handle.cc b/paddle/fluid/framework/details/broadcast_op_handle.cc index e6edcbfd19752abbb8524c71766a710b03113759..05bab5334a5737a78b984d4f9644886b0d9ff5a7 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle.cc +++ b/paddle/fluid/framework/details/broadcast_op_handle.cc @@ -34,40 +34,21 @@ BroadcastOpHandle::BroadcastOpHandle(const std::vector &local_scopes, : local_scopes_(local_scopes), places_(places) {} void BroadcastOpHandle::RunImpl() { - // the input may have dummy var. - std::vector in_var_handle; - for (auto *in : inputs_) { - auto *out_handle = dynamic_cast(in); - if (out_handle) { - in_var_handle.push_back(out_handle); - } - } + // the input and output may have dummy var. + std::vector in_var_handle = GetValidVarHandles(inputs_); + std::vector out_var_handles = GetValidVarHandles(outputs_); + PADDLE_ENFORCE_EQ(in_var_handle.size(), 1, "The number of input should be one."); - - // the output may have dummy var. 
- std::vector out_var_handles; - for (auto *out : outputs_) { - auto *out_handle = dynamic_cast(out); - if (out_handle) { - out_var_handles.push_back(out_handle); - } - } - PADDLE_ENFORCE_EQ( out_var_handles.size(), places_.size(), "The number of output should equal to the number of places."); - // Wait input done, this Wait is asynchronous operation - auto &in_place = in_var_handle[0]->place_; - if (in_var_handle[0]->generated_op_) { - for (auto *out : out_var_handles) { - auto &out_p = out->place_; - in_var_handle[0]->generated_op_->Wait(dev_ctxes_[out_p]); - } - } + // Wait until the input is done; this Wait is an asynchronous + // operation. + WaitEvents(out_var_handles, in_var_handle); - // + auto in_place = in_var_handle[0]->place_; auto in_scope_idx = in_var_handle[0]->scope_idx_; auto in_var = local_scopes_.at(in_scope_idx)->FindVar(in_var_handle[0]->name_); @@ -107,6 +88,29 @@ void BroadcastOpHandle::RunImpl() { } } +void BroadcastOpHandle::WaitEvents( + const std::vector &out_var_handles, + const std::vector &in_var_handle) { + if (in_var_handle[0]->generated_op_) { + for (auto *out : out_var_handles) { + auto &out_p = out->place_; + in_var_handle[0]->generated_op_->Wait(dev_ctxes_[out_p]); + } + } +} + +std::vector BroadcastOpHandle::GetValidVarHandles( + const std::vector &inputs) { + std::vector in_var_handle; + for (auto *in : inputs) { + auto *out_handle = dynamic_cast(in); + if (out_handle) { + in_var_handle.push_back(out_handle); + } + } + return in_var_handle; +} + std::string BroadcastOpHandle::Name() const { return "broadcast"; } } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/details/broadcast_op_handle.h b/paddle/fluid/framework/details/broadcast_op_handle.h index b3292422522b64a38a50f39f04e6f0d2e15492dd..e1311aceaf978bd266105c2ac27b137ab361c96c 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle.h +++ b/paddle/fluid/framework/details/broadcast_op_handle.h @@ -41,6 +41,12 @@ struct 
BroadcastOpHandle : public OpHandleBase { protected: void RunImpl() override; + + std::vector GetValidVarHandles( + const std::vector &inputs); + + void WaitEvents(const std::vector &out_var_handles, + const std::vector &in_var_handle); }; } // namespace details diff --git a/paddle/fluid/framework/details/gather_op_handle.cc b/paddle/fluid/framework/details/gather_op_handle.cc index ae2bc9899c69bbbff02698cb6e985c8053d44f9b..df55e4dad1fc563d6d8aa3669d6da2cb0e5f2f4b 100644 --- a/paddle/fluid/framework/details/gather_op_handle.cc +++ b/paddle/fluid/framework/details/gather_op_handle.cc @@ -23,26 +23,13 @@ GatherOpHandle::GatherOpHandle(const std::vector &local_scopes, : local_scopes_(local_scopes), places_(places) {} void GatherOpHandle::RunImpl() { - // the input may have dummy var. - std::vector in_var_handles; - for (auto *in : inputs_) { - auto *in_handle = dynamic_cast(in); - if (in_handle) { - in_var_handles.push_back(in_handle); - } - } + // the input and output may have dummy var. + std::vector in_var_handles = GetValidVarHandles(inputs_); + std::vector out_var_handles = GetValidVarHandles(outputs_); + PADDLE_ENFORCE_EQ( in_var_handles.size(), places_.size(), "The number of output should equal to the number of places."); - - // the output may have dummy var. 
- std::vector out_var_handles; - for (auto *out : outputs_) { - auto *out_handle = dynamic_cast(out); - if (out_handle) { - out_var_handles.push_back(out_handle); - } - } PADDLE_ENFORCE_EQ(out_var_handles.size(), 1, "The number of output should be one."); @@ -58,11 +45,7 @@ void GatherOpHandle::RunImpl() { "The place of input and output should be the same."); // Wait input done, this Wait is asynchronous operation - for (auto *in : in_var_handles) { - if (in->generated_op_) { - in->generated_op_->Wait(dev_ctxes_[in->place_]); - } - } + WaitEvents(in_var_handles); std::vector out_rows; std::vector in_tensors; @@ -111,7 +94,7 @@ void GatherOpHandle::RunImpl() { // copy auto dev_ctx = dev_ctxes_[out_place]; - RunAndRecordEvent(out_place, [in_tensors, out_var, dev_ctx, out_place] { + RunAndRecordEvent(out_place, [in_tensors, out_tensor, dev_ctx, out_place] { int s = 0, e = 0; for (size_t j = 0; j < in_tensors.size(); ++j) { e += in_tensors[j].dims()[0]; @@ -123,6 +106,27 @@ void GatherOpHandle::RunImpl() { }); } +void GatherOpHandle::WaitEvents( + const std::vector &in_var_handles) { + for (auto *in : in_var_handles) { + if (in->generated_op_) { + in->generated_op_->Wait(dev_ctxes_[in->place_]); + } + } +} + +std::vector GatherOpHandle::GetValidVarHandles( + const std::vector &inputs) { + std::vector in_var_handles; + for (auto *in : inputs) { + auto *in_handle = dynamic_cast(in); + if (in_handle) { + in_var_handles.push_back(in_handle); + } + } + return in_var_handles; +} + std::string GatherOpHandle::Name() const { return "gather"; } } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/details/gather_op_handle.h b/paddle/fluid/framework/details/gather_op_handle.h index 6c0231f642c05e6b558b7e2518a15e08c816fe4b..b13dc4ceb3ea2c5598ed77c09210493c40a99290 100644 --- a/paddle/fluid/framework/details/gather_op_handle.h +++ b/paddle/fluid/framework/details/gather_op_handle.h @@ -41,6 +41,11 @@ struct GatherOpHandle : public OpHandleBase { 
protected: void RunImpl() override; + + std::vector GetValidVarHandles( + const std::vector &); + + void WaitEvents(const std::vector &in_var_handles); }; } // namespace details