diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt
index 600c47ad5f2b368634789705be6d3c55656b1638..1bcd8412eb2d618b923bcd0557d118af62271f4a 100644
--- a/paddle/fluid/framework/details/CMakeLists.txt
+++ b/paddle/fluid/framework/details/CMakeLists.txt
@@ -3,7 +3,6 @@ cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context
 cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
 cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
 cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
-cc_library(send_op_handle SRCS send_op_handle.cc DEPS framework_proto scope place operator op_registry)
 cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry)
 
 cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
@@ -27,7 +26,7 @@ endif()
 cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
 
 cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
-        scale_loss_grad_op_handle send_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)
+        scale_loss_grad_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)
 
 cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
 cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
index 14b73b368117b4816e1aeee8bb5c73f64257c91e..25711e0e47fd9ce878f9945bcfa68ae0112137ab 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -19,7 +19,6 @@
 #include "paddle/fluid/framework/details/reduce_op_handle.h"
 #include "paddle/fluid/framework/details/rpc_op_handle.h"
 #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
-#include "paddle/fluid/framework/details/send_op_handle.h"
 #include "paddle/fluid/framework/op_info.h"
 #include "paddle/fluid/framework/scope.h"
 
@@ -141,7 +140,6 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
 
   return checker(op.OutputArgumentNames(), send_vars) ||
          checker(op.InputArgumentNames(), recv_vars);
-  return false;
 }
 
 bool MultiDevSSAGraphBuilder::IsRPCOp(const OpDesc &op) const {
@@ -471,17 +469,16 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
     ConnectOp(result, result->ops_.back().get(), "send_barrier");
   } else if (op.Type() == "fetch_barrier") {
     ConnectOp(result, result->ops_.back().get(), "recv");
-  } else if (op.Type() == "send" || op.Type() == "send_vars") {
+  } else if (op.Type() == "send_vars") {
     // do nothing
   } else {
     PADDLE_THROW(
-        "rpc op should be in [send,"
+        "rpc op should be in ["
         "send_vars, send_barrier. recv, fetch_barrier]");
   }
 
-  // FIXME(wuyi): send op always copy from GPU 0
-  // Create inputs for output on original place and no ssa output
-  // is created for send op.
+  // TODO(Yancey1989): schedule rpc op on different place may
+  // increase throughput
   CreateOpHandleIOs(result, op, 0);
 }
 
diff --git a/paddle/fluid/framework/details/rpc_op_handle.cc b/paddle/fluid/framework/details/rpc_op_handle.cc
index 03f53421b1d3cbe3d455d39e657b723b724f70c0..7f4da4c01de1010467d839ee5490c5e0d02d8c24 100644
--- a/paddle/fluid/framework/details/rpc_op_handle.cc
+++ b/paddle/fluid/framework/details/rpc_op_handle.cc
@@ -31,6 +31,7 @@ void RPCOpHandle::RunImpl() {
   // Wait input done
   for (auto *in : inputs_) {
     auto &p = static_cast<VarHandle *>(in)->place_;
+    // FIXME(Yancey1989): need a better solution instead of using DebugString()
     if (in->DebugString() == "dummy") {  // HACK
       continue;
     }
diff --git a/paddle/fluid/framework/details/send_op_handle.cc b/paddle/fluid/framework/details/send_op_handle.cc
deleted file mode 100644
index 7109659dd7001f91e7674ac7bebbe3a59794cfc0..0000000000000000000000000000000000000000
--- a/paddle/fluid/framework/details/send_op_handle.cc
+++ /dev/null
@@ -1,49 +0,0 @@
-// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/fluid/framework/details/send_op_handle.h"
-
-namespace paddle {
-namespace framework {
-namespace details {
-
-SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc,
-                           const Scope *local_scope,
-                           const platform::Place &place)
-    : op_(framework::OpRegistry::CreateOp(op_desc)),
-      local_scope_(local_scope),
-      place_(place) {}
-
-void SendOpHandle::RunImpl() {
-  // TODO(wuyi): need further analysis whether wait VarDummyHandle.
-  // Wait input done
-  for (auto *in : inputs_) {
-    auto &p = static_cast<VarHandle *>(in)->place_;
-    if (in->DebugString() == "dummy") {  // HACK
-      continue;
-    }
-    if (in->generated_op_) {
-      in->generated_op_->RecordWaitEventOnCtx(dev_ctxes_[p]);
-    }
-  }
-  auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
-  // FIXME(wuyi): can not use RunAndRecordEvent here, for it will cause dead
-  // lock.
-  op_->Run(*tmp_scope, place_);
-}
-
-std::string SendOpHandle::Name() const { return "send"; }
-}  // namespace details
-}  // namespace framework
-}  // namespace paddle
diff --git a/paddle/fluid/framework/details/send_op_handle.h b/paddle/fluid/framework/details/send_op_handle.h
deleted file mode 100644
index 2f78811fad50642b5e45776c41910df6f4cc48f6..0000000000000000000000000000000000000000
--- a/paddle/fluid/framework/details/send_op_handle.h
+++ /dev/null
@@ -1,51 +0,0 @@
-// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-
-#include <string>
-#include <vector>
-
-#include "paddle/fluid/framework/details/op_handle_base.h"
-#include "paddle/fluid/framework/lod_tensor.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/framework/scope.h"
-
-namespace paddle {
-namespace framework {
-namespace details {
-
-struct SendOpHandle : public OpHandleBase {
-  SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
-               const platform::Place& place);
-
-  std::string Name() const override;
-
-  // Delay and buffer nccl_all_reduce together can significantly increase
-  // performance. Disable this feature by returning false.
-  bool IsMultiDeviceTransfer() override { return false; };
-
- protected:
-  void RunImpl() override;
-
- private:
-  std::unique_ptr<OperatorBase> op_;
-  const Scope* local_scope_;
-  const platform::Place& place_;
-};
-
-}  // namespace details
-}  // namespace framework
-}  // namespace paddle
diff --git a/paddle/fluid/framework/variable.h b/paddle/fluid/framework/variable.h
index 387e06bca6e477dabbc073763344e560715738e4..e7f87ab6f8b889a7319abb48e7f418bb87d1ca21 100644
--- a/paddle/fluid/framework/variable.h
+++ b/paddle/fluid/framework/variable.h
@@ -39,6 +39,7 @@ class Variable {
 
   template <typename T>
   T* GetMutable() {
+    // TODO(Yancey1989): need to make Variable completely thread-safe.
    std::unique_lock<std::mutex> lock(mutex_);
     if (!IsType<T>()) {
       holder_.reset(new PlaceholderImpl<T>(new T()));
diff --git a/paddle/fluid/inference/analysis/device.h b/paddle/fluid/inference/analysis/device.h
index 9fad445edecf09a45551d8db9ce530329037cebc..585c9923291e5f9cb6e50dbc4bcd28c374191048 100644
--- a/paddle/fluid/inference/analysis/device.h
+++ b/paddle/fluid/inference/analysis/device.h
@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #pragma once
-#pragma once
-
 namespace paddle {
 namespace inference {
 namespace analysis {
diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc
index 51f0d2a7427a3923f038b3d85057fa0e5c8cf6a8..4c9c7be40c143c748c12dc08c22a09ea590366a2 100644
--- a/paddle/fluid/operators/detail/grpc_client.cc
+++ b/paddle/fluid/operators/detail/grpc_client.cc
@@ -249,6 +249,7 @@ bool RPCClient::Proceed() {
   return true;
 }
 
 std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
+  // TODO(Yancey1989): make grpc client completely thread-safe
   std::unique_lock<std::mutex> lock(mutex_);
   auto it = channels_.find(ep);
   if (it != channels_.end()) {
diff --git a/paddle/fluid/operators/recv_op.cc b/paddle/fluid/operators/recv_op.cc
index 1255ed4c49bbbd8c743d18c4fc1fedd6fc34ae0b..d416ba1e1fda4e7803d0b5a00cb2f7b26ce215b8 100644
--- a/paddle/fluid/operators/recv_op.cc
+++ b/paddle/fluid/operators/recv_op.cc
@@ -38,7 +38,7 @@ class RecvOp : public framework::OperatorBase {
     auto outs = Outputs("Out");
     std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
     auto client_var_name = Output("RPCClient");
-    int sync_recv = Attr<int>("sync_recv");
+    int sync_mode = Attr<int>("sync_mode");
 
     platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
     auto& ctx = *pool.Get(place);
@@ -55,7 +55,7 @@ class RecvOp : public framework::OperatorBase {
       VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
       rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
     }
-    if (sync_recv) {
+    if (sync_mode) {
       PADDLE_ENFORCE(rpc_client->Wait());
     }
   }
@@ -78,7 +78,7 @@ This operator can get variables from server side.
                  "Server endpoints in the order of input "
                  "variables for mapping")
        .SetDefault({});
-    AddAttr<int>("sync_recv",
+    AddAttr<int>("sync_mode",
                  "(int, default 0)"
                  "sync recv or async recv.")
        .SetDefault(0);
diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index 72a02f24a339ba7d36dbf58a0479e4b4e681cab3..a9de5419faadba82f92913526999d22dd4c64f3e 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -360,19 +360,6 @@ class DistributeTranspiler:
         ps_dispatcher.reset()
         eplist = ps_dispatcher.dispatch(recv_vars)
 
-        #program.global_block().append_op(
-        #    type="recv",
-        #    inputs={},
-        #    outputs={"Out": recv_vars,
-        #             "RPCClient": rpc_client_var},
-        #    attrs={"epmap": eplist})
-
-        #program.global_block().append_op(
-        #    type="fetch_barrier",
-        #    inputs={},
-        #    outputs={"RPCClient": rpc_client_var},
-        #    attrs={"endpoints": pserver_endpoints})
-
         for i, ep in enumerate(eplist):
             self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
             self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
diff --git a/python/paddle/fluid/transpiler/ps_dispatcher.py b/python/paddle/fluid/transpiler/ps_dispatcher.py
index 9ba3bf82161c2f105f61e87239c6f3f5477f515d..d6a68677527deb09ace0e3a23cbc093d6d7b4349 100644
--- a/python/paddle/fluid/transpiler/ps_dispatcher.py
+++ b/python/paddle/fluid/transpiler/ps_dispatcher.py
@@ -41,7 +41,7 @@ class PSDispatcher(object):
 
 class HashName(PSDispatcher):
     """
-    Hash variable names to servral endpoints
+    Hash variable names to several endpoints
     """
 
     def __init__(self, pserver_endpoints):
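Note on the dispatch step kept in distribute_transpiler.py: eplist = ps_dispatcher.dispatch(recv_vars) asks the dispatcher for one parameter-server endpoint per variable, and HashName (whose docstring is corrected above) does this by hashing the variable name onto the endpoint list. The standalone Python sketch below only illustrates that idea; the function name hash_dispatch, the deterministic crc32 hash, and the example endpoints are assumptions for illustration, not PaddlePaddle's ps_dispatcher implementation.

# Illustrative sketch only (assumption, not PaddlePaddle code): map each
# variable name to one of several parameter-server endpoints by hashing the
# name, which is the behaviour HashName's docstring describes.
import zlib

def hash_dispatch(var_names, pserver_endpoints):
    # Deterministic hash so the same variable always lands on the same endpoint.
    return [
        pserver_endpoints[zlib.crc32(name.encode()) % len(pserver_endpoints)]
        for name in var_names
    ]

# Example with made-up endpoints: every name maps to exactly one endpoint.
print(hash_dispatch(["fc_0.w_0", "fc_0.b_0"],
                    ["127.0.0.1:6174", "127.0.0.1:6175"]))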