提交 82d2903b 编写于 作者: C chengduozh

Fix fast ParallelExe bug

test=develop
上级 dcfb6875
......@@ -49,6 +49,8 @@ struct VarHandleBase {
void AddOutput(OpHandleBase* out, ir::Node* node) {
if (pending_ops_.find(out) == pending_ops_.end()) {
PADDLE_ENFORCE(out != nullptr, "The output of %s should not be nullptr",
this->Node()->Name());
pending_ops_.insert(out);
node_->outputs.push_back(node);
}
......
......@@ -299,6 +299,12 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
}
ParallelExecutor::~ParallelExecutor() {
const auto dev_ctxs =
platform::DeviceContextPool::Instance().GetAllDeviceContexts();
for (auto &dev_ctx : dev_ctxs) {
dev_ctx->Wait();
}
if (member_->own_local_scope_) {
for (size_t i = 1; i < member_->local_scopes_.size(); ++i) {
Scope *local_scope = member_->local_scopes_[i];
......
......@@ -35,6 +35,16 @@ platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
return it->second.get();
}
const std::vector<const DeviceContext*>
DeviceContextPool::GetAllDeviceContexts() const {
std::vector<const DeviceContext*> all_device_ctx;
all_device_ctx.reserve(device_contexts_.size());
for (auto& dev_ctx : device_contexts_) {
all_device_ctx.emplace_back(dev_ctx.second.get());
}
return all_device_ctx;
}
DeviceContextPool::DeviceContextPool(
const std::vector<platform::Place>& places) {
PADDLE_ENFORCE_GT(places.size(), 0);
......
......@@ -217,6 +217,9 @@ class DeviceContextPool {
/*! \brief Return handle of single device context. */
platform::DeviceContext* Get(const platform::Place& place);
/*! \brief Return all the device contexts. */
const std::vector<const DeviceContext*> GetAllDeviceContexts() const;
template <typename Place>
const typename DefaultDeviceContextType<Place>::TYPE* GetByPlace(
const Place& place) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册