diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt
index 7a371af510b8050aec3708d82923c707fd9d3a90..77e94e998c4db14cac9c4b2cb3136f1a6b37d5c6 100644
--- a/paddle/fluid/framework/details/CMakeLists.txt
+++ b/paddle/fluid/framework/details/CMakeLists.txt
@@ -5,6 +5,7 @@ cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_h
 cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
 cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
 cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry)
+cc_library(fetch_barrier_op_handle SRCS fetch_barrier_op_handle.cc DEPS framework_proto scope place operator op_registry)
 
 cc_library(multi_devices_helper SRCS multi_devices_helper.cc DEPS graph graph_helper)
 cc_library(multi_devices_graph_print_pass SRCS multi_devices_graph_print_pass.cc DEPS multi_devices_helper)
@@ -72,7 +73,7 @@ cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS grap
 cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_helper pass)
 
 cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
-        scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle fused_broadcast_op_handle)
+        scale_loss_grad_op_handle rpc_op_handle fetch_barrier_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle fused_broadcast_op_handle)
 
 cc_library(fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph graph_helper fused_all_reduce_op_handle)
 
diff --git a/paddle/fluid/framework/details/fetch_barrier_op_handle.cc b/paddle/fluid/framework/details/fetch_barrier_op_handle.cc
new file mode 100644
index 0000000000000000000000000000000000000000..019ecfbb61028537692c8fdeb874c6c490f75430
--- /dev/null
+++ b/paddle/fluid/framework/details/fetch_barrier_op_handle.cc
@@ -0,0 +1,80 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/details/fetch_barrier_op_handle.h"
+
+#include <string>
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+// Creates the wrapped operator from the node's OpDesc and registers a device
+// context for every entry of `places`. The operator itself is run with
+// element 0 of `local_scopes` and element 0 of `places` (see RunImpl).
+FetchBarrierOpHandle::FetchBarrierOpHandle(
+    ir::Node *node, const std::vector<Scope *> &local_scopes,
+    const std::vector<platform::Place> &places)
+    // fetch_barrier op always runs on place0, but outputs on all places.
+    // at(0) instead of [0]: an empty scope or place list throws
+    // std::out_of_range rather than reading out of bounds.
+    : OpHandleBase(node),
+      op_(framework::OpRegistry::CreateOp(*node->Op())),
+      local_scopes_(local_scopes),
+      places_(places),
+      run_scope_(local_scopes.at(0)),
+      place_(places.at(0)) {
+  for (auto &p : places) {
+    this->SetDeviceContext(p, platform::DeviceContextPool::Instance().Get(p));
+  }
+}
+
+bool FetchBarrierOpHandle::IsMultiDeviceTransfer() {
+  // override IsMultiDeviceTransfer to return true
+  return true;
+}
+
+// Calls WaitInputVarGenerated(place_), then runs the wrapped operator on
+// place_ in the scope stored under kLocalExecScopeName in run_scope_.
+void FetchBarrierOpHandle::RunImpl() {
+  WaitInputVarGenerated(place_);
+
+  auto run_func = [this]() {
+    op_->Run(*run_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
+  };
+
+  // The flag defaults to false and nothing in this patch changes it, so the
+  // RunAndRecordEvent() branch is the one taken here.
+  if (is_lock_and_record_event_free_) {
+    run_func();
+  } else {
+    this->RunAndRecordEvent(run_func);
+  }
+}
+
+// An input needs a wait only if it has a generating op whose device context
+// for place_ differs from this handle's device context for place_.
+bool FetchBarrierOpHandle::NeedWait(VarHandleBase *in_var) {
+  bool need_wait =
+      in_var && in_var->GeneratedOp() &&
+      in_var->GeneratedOp()->DeviceContext(place_) != dev_ctxes_.at(place_);
+  return need_wait;
+}
+
+// The handle reports the type of the wrapped operator as its name.
+std::string FetchBarrierOpHandle::Name() const { return op_->Type(); }
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/details/fetch_barrier_op_handle.h b/paddle/fluid/framework/details/fetch_barrier_op_handle.h
new file mode 100644
index 0000000000000000000000000000000000000000..b4d12785e0345c887f179bc53c8446dc1438f889
--- /dev/null
+++ b/paddle/fluid/framework/details/fetch_barrier_op_handle.h
@@ -0,0 +1,67 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/framework/details/op_handle_base.h"
+#include "paddle/fluid/framework/feed_fetch_type.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/platform/device_context.h"
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+// **NOTE**: the fetch_barrier op is special: it outputs all received
+// variables on all places. If there are multiple places, the handle must be
+// initialized with multiple dev_ctxes_.
+
+struct FetchBarrierOpHandle : public OpHandleBase {
+ public:
+  // `local_scopes` and `places` must be non-empty: element 0 of each is used
+  // to run the operator, and an empty list makes the constructor throw.
+  FetchBarrierOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
+                       const std::vector<platform::Place> &places);
+
+  bool IsMultiDeviceTransfer() override;
+
+  std::string Name() const override;
+
+ protected:
+  void RunImpl() override;
+
+  bool NeedWait(VarHandleBase *in_var) override;
+
+ private:
+  // Operator created from the graph node's OpDesc.
+  std::unique_ptr<OperatorBase> op_;
+  std::vector<Scope *> local_scopes_;
+  std::vector<platform::Place> places_;
+  // Element 0 of local_scopes_ / places_; the operator is run with these.
+  Scope *run_scope_;
+  platform::Place place_;
+
+  // When true, RunImpl() calls the operator directly instead of going through
+  // RunAndRecordEvent().
+  bool is_lock_and_record_event_free_{false};
+};
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
index 125dbf746c3880e142af4d4bffd3ccda8654c0a1..253cf5b4a8221ad6a1f0c70f2bebccb589a5668e 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
@@ -23,6 +23,7 @@
 #include "paddle/fluid/framework/details/all_reduce_op_handle.h"
 #include "paddle/fluid/framework/details/broadcast_op_handle.h"
 #include "paddle/fluid/framework/details/computation_op_handle.h"
+#include "paddle/fluid/framework/details/fetch_barrier_op_handle.h"
 #include "paddle/fluid/framework/details/fused_broadcast_op_handle.h"
 #include "paddle/fluid/framework/details/reduce_op_handle.h"
 #include "paddle/fluid/framework/details/rpc_op_handle.h"
@@ -851,9 +852,17 @@ int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const {
 
   PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
                  node->Op()->Type());
-  result->Get<GraphOps>(kGraphOps).emplace_back(new RPCOpHandle(
-      result->CreateOpNode(node->Op()), *node->Op(), local_scopes_[op_dev_id],
-      node->Op()->Type(), places_[op_dev_id]));
+
+  // Create a fetch_barrier op handle to enable output on all devices.
+  // **NOTE** fetch_barrier should output the same variable list as recv does.
+  if (node->Op()->Type() == "fetch_barrier") {
+    result->Get<GraphOps>(kGraphOps).emplace_back(new FetchBarrierOpHandle(
+        result->CreateOpNode(node->Op()), local_scopes_, places_));
+  } else {
+    result->Get<GraphOps>(kGraphOps).emplace_back(new RPCOpHandle(
+        result->CreateOpNode(node->Op()), *node->Op(), local_scopes_[op_dev_id],
+        node->Op()->Type(), places_[op_dev_id]));
+  }
 
   if (node->Op()->Type() == "send") {
     CreateOpHandleIOs(result, node, op_dev_id);
diff --git a/paddle/fluid/framework/details/op_handle_base.cc b/paddle/fluid/framework/details/op_handle_base.cc
index 158da6f606f3f5a7062a4aaed7cf7e3fe71c817a..413b14961631b3459e0d05af685ad1c5395844c2 100644
--- a/paddle/fluid/framework/details/op_handle_base.cc
+++ b/paddle/fluid/framework/details/op_handle_base.cc
@@ -55,7 +55,7 @@ void OpHandleBase::Run(bool use_cuda) {
         if (out_var_handle) {
           int dev_id =
               boost::get<platform::CUDAPlace>(out_var_handle->place()).device;
-          out_var_handle->SetGenerateEvent(events_[dev_id]);
+          out_var_handle->SetGenerateEvent(events_.at(dev_id));
         }
       }
     } else {
@@ -71,7 +71,7 @@ void OpHandleBase::Run(bool use_cuda) {
                          "The place of input(%s) is not consistent with the "
                          "place of current op(%s).",
                          out_var_handle->Name(), Name());
-          out_var_handle->SetGenerateEvent(events_[dev_id]);
+          out_var_handle->SetGenerateEvent(events_.at(dev_id));
         }
       }
     }