Commit 62af10d4 authored by Yancey1989

support multiple devices

Parent 274df85c
@@ -4,6 +4,7 @@ cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_h
 cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
 cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
 cc_library(send_op_handle SRCS send_op_handle.cc DEPS framework_proto scope place operator op_registry)
+cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry)
 cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
 cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
@@ -26,7 +27,7 @@ endif()
 cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
 cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
-        scale_loss_grad_op_handle send_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)
+        scale_loss_grad_op_handle send_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)
 cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
 cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
......
@@ -12,10 +12,12 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 #include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
+#include <fstream>
 #include <utility>
 #include "paddle/fluid/framework/details/broadcast_op_handle.h"
 #include "paddle/fluid/framework/details/computation_op_handle.h"
 #include "paddle/fluid/framework/details/reduce_op_handle.h"
+#include "paddle/fluid/framework/details/rpc_op_handle.h"
 #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
 #include "paddle/fluid/framework/details/send_op_handle.h"
 #include "paddle/fluid/framework/scope.h"
@@ -77,7 +79,6 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
     CreateOpOutput(result, op_handle, each_var_name, p, place_id);
   }
 }
-
 bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
                                             OpDesc *send_op) const {
   if (send_op == nullptr) {
@@ -98,7 +99,7 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
     return false;
   };
-  if (op.Type() == "split") {
+  if (op.Type() == "split" || op.Type() == "split_byref") {
     return checker(op.OutputArgumentNames(), send_op->InputArgumentNames());
   } else if (op.Type() == "concat") {
     return checker(op.InputArgumentNames(), send_op->OutputArgumentNames());
@@ -106,6 +107,15 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
   return false;
 }
+
+bool MultiDevSSAGraphBuilder::IsRPCOp(const OpDesc &op) const {
+  for (auto &name : op.OutputNames()) {
+    if (name == "RPCClient") {
+      return true;
+    }
+  }
+  return false;
+}
 std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
     const ProgramDesc &program) const {
   std::unordered_map<std::string, proto::VarType::Type> var_types;
@@ -133,10 +143,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   bool is_forwarding = true;
   for (auto *op : program.Block(0).AllOps()) {
-    if (op->Type() == "send") {
-      // append send op if program is distributed trainer main program.
+    if (IsRPCOp(*op)) {
+      // append rpc op if program is distributed trainer main program.
       // always use the first device
-      CreateSendOp(&result, *op);
+      CreateRPCOp(&result, *op);
     } else if (IsDistTrainOp(*op, send_op)) {
       CreateComputationalOps(&result, *op, 1);
     } else if (IsScaleLossOp(*op)) {
@@ -203,9 +213,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   AddOutputToLeafOps(&result);
   if (VLOG_IS_ON(10)) {
-    std::ostringstream sout;
-    PrintGraphviz(*graph, sout);
-    VLOG(10) << sout.str();
+    std::string filename = "/tmp/graph";
+    std::ofstream fout(filename);
+    PrintGraphviz(*graph, fout);
   }
   return std::unique_ptr<SSAGraph>(graph);
@@ -386,12 +396,40 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
   return var;
 }
-void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result,
-                                           const OpDesc &op) const {
+void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result,
+                                        std::string op_name) const {
+  for (auto &prev_op : result->ops_) {
+    if (prev_op->Name() == op_name) {
+      auto *dep_var = new DummyVarHandle();
+      prev_op->AddOutput(dep_var);
+      result->dep_vars_.emplace(dep_var);
+      result->ops_.back().get()->AddInput(dep_var);
+    }
+  }
+}
+
+void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
+                                          const OpDesc &op) const {
   auto &p = places_[0];
   auto *s = local_scopes_[0];
+  VLOG(3) << "create rpc op: " << op.Type();
+  result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type()));
+  if (op.Type() == "send_barrier") {
+    ConnectOp(result, "send_vars");
+  } else if (op.Type() == "recv") {
+    ConnectOp(result, "send_barrier");
+  } else if (op.Type() == "fetch_barrier") {
+    ConnectOp(result, "recv");
+  } else if (op.Type() == "send" || op.Type() == "send_vars") {
+    // do nothing
+  } else {
+    PADDLE_THROW(
+        "rpc op should be in [send, "
+        "send_vars, send_barrier, recv, fetch_barrier]");
+  }
   // FIXME(wuyi): send op always copy from GPU 0
-  result->ops_.emplace_back(new SendOpHandle(op, s, p));
+  // result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type()));
   // Create inputs for output on original place and no ssa output
   // is created for send op.
   CreateOpHandleIOs(result, op, 0);
......
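The new CreateRPCOp/ConnectOp pair above serializes the distributed phases by threading DummyVarHandle control edges between them. Below is a standalone sketch of that wiring; the simplified Op/Graph types and main() are illustrative assumptions, not Paddle's SSAGraph classes, and ownership of the dummy edges is elided for brevity.

```cpp
// Sketch of the ConnectOp/CreateRPCOp dependency chaining from the diff.
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

struct DummyVar {};  // carries no data, only a dependency

struct Op {
  std::string name;
  std::vector<DummyVar*> inputs, outputs;
};

struct Graph {
  std::vector<std::unique_ptr<Op>> ops;
  std::unordered_set<DummyVar*> dep_vars;

  // Mirror of ConnectOp: make the most recently added op depend on every
  // earlier op named `prev_name`.
  void Connect(const std::string& prev_name) {
    for (auto& prev : ops) {
      if (prev->name == prev_name) {
        auto* dep = new DummyVar();
        prev->outputs.push_back(dep);
        dep_vars.insert(dep);
        ops.back()->inputs.push_back(dep);
      }
    }
  }

  // Mirror of CreateRPCOp's chaining rules.
  void AddRPCOp(const std::string& type) {
    ops.emplace_back(new Op{type, {}, {}});
    if (type == "send_barrier") {
      Connect("send_vars");
    } else if (type == "recv") {
      Connect("send_barrier");
    } else if (type == "fetch_barrier") {
      Connect("recv");
    }
  }
};

int main() {
  Graph g;
  // Resulting control chain: send_vars -> send_barrier -> recv -> fetch_barrier
  for (auto type : {"send_vars", "send_barrier", "recv", "fetch_barrier"}) {
    g.AddRPCOp(type);
  }
}
```

The effect is that every send_barrier waits for all send_vars already in the graph, recv waits for the barrier, and fetch_barrier waits for every recv.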
@@ -65,12 +65,17 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   bool IsScaleLossOp(const OpDesc &op) const;
   void CreateSendOp(SSAGraph *result, const OpDesc &op) const;
+  void CreateRPCOp(SSAGraph *result, const OpDesc &op) const;
+
   /**
    * Is this operator the end-point operator before/after the send operator?
    */
   bool IsDistTrainOp(const OpDesc &op, OpDesc *send_op) const;
+  bool IsRPCOp(const OpDesc &op) const;
+  void ConnectOp(SSAGraph *result, std::string op_name) const;
+
   void CreateComputationalOps(SSAGraph *result, const OpDesc &op,
                               size_t num_places) const;
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/rpc_op_handle.h"
namespace paddle {
namespace framework {
namespace details {
RPCOpHandle::RPCOpHandle(const framework::OpDesc &op_desc,
const Scope *local_scope, const platform::Place &place,
const std::string &name)
: op_(framework::OpRegistry::CreateOp(op_desc)),
local_scope_(local_scope),
place_(place),
name_(name) {}
void RPCOpHandle::RunImpl() {
// TODO(wuyi): need further analysis whether wait VarDummyHandle.
// Wait input done
for (auto *in : inputs_) {
auto &p = static_cast<VarHandle *>(in)->place_;
if (in->DebugString() == "dummy") { // HACK
continue;
}
if (in->generated_op_) {
in->generated_op_->RecordWaitEventOnCtx(dev_ctxes_[p]);
}
}
auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
// FIXME(wuyi): can not use RunAndRecordEvent here, for it will cause dead
// lock.
op_->Run(*tmp_scope, place_);
}
std::string RPCOpHandle::Name() const { return name_; }
} // namespace details
} // namespace framework
} // namespace paddle
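RunImpl above skips dummy inputs and waits on every real input's producer before running the wrapped operator. A minimal sketch of that loop, with Event and Input as illustrative stand-ins for Paddle's VarHandle hierarchy:

```cpp
// Standalone sketch of RPCOpHandle::RunImpl's skip-dummies-then-wait pattern.
#include <iostream>
#include <string>
#include <vector>

struct Event {
  std::string producer;  // name of the op that generates this input
};

struct Input {
  std::string debug;          // "dummy" marks a control-only edge
  const Event* generated_by;  // null if nothing produces the input
};

// Dummy inputs are skipped, real inputs wait on their producer, then the
// wrapped operator runs.
void RunRPCOp(const std::vector<Input>& inputs) {
  for (const auto& in : inputs) {
    if (in.debug == "dummy") continue;  // the HACK in the diff
    if (in.generated_by != nullptr) {
      std::cout << "wait on " << in.generated_by->producer << "\n";
    }
  }
  std::cout << "run rpc operator\n";
}

int main() {
  Event send_vars{"send_vars"};
  RunRPCOp({{"dummy", nullptr}, {"w@GRAD", &send_vars}});
}
```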
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <memory>
#include <string>
#include <vector>

#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"

namespace paddle {
namespace framework {
namespace details {

struct RPCOpHandle : public OpHandleBase {
  RPCOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
              const platform::Place& place, const std::string& name);

  std::string Name() const override;

  // Delaying and buffering nccl_all_reduce together can significantly
  // increase performance. Disable this feature by returning false.
  bool IsMultiDeviceTransfer() override { return false; }

 protected:
  void RunImpl() override;

 private:
  std::unique_ptr<OperatorBase> op_;
  const Scope* local_scope_;
  const platform::Place& place_;
  const std::string name_;
};

}  // namespace details
}  // namespace framework
}  // namespace paddle
@@ -14,6 +14,7 @@
 #pragma once
 #include <memory>
+#include <mutex>  // NOLINT
 #include <string>
 #include <typeindex>
 #include <typeinfo>
@@ -38,6 +39,7 @@ class Variable {
   template <typename T>
   T* GetMutable() {
+    std::unique_lock<std::mutex> lock(mutex_);
     if (!IsType<T>()) {
       holder_.reset(new PlaceholderImpl<T>(new T()));
     }
@@ -90,6 +92,7 @@ class Variable {
   // by its address but not the unreadable name.
   friend class Scope;
   const std::string* name_;
+  std::mutex mutex_;
 };
 }  // namespace framework
......
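GetMutable is locked because, with multiple devices, several threads can race to lazily create the same variable (e.g. the shared RPCClient). Unsynchronized, both threads could reset holder_ and one caller would keep a pointer into a destroyed object. A self-contained sketch of the guarded lazy init; the simplified Variable here is an assumption, with Paddle's type checks elided:

```cpp
// Sketch of the locked lazy initialization the diff adds to Variable::GetMutable.
#include <memory>
#include <mutex>
#include <thread>
#include <vector>

class Variable {
 public:
  template <typename T>
  T* GetMutable() {
    // Without the lock, two threads could both see holder_ == nullptr and
    // reset it twice, leaving one of them with a dangling pointer.
    std::unique_lock<std::mutex> lock(mutex_);
    if (holder_ == nullptr) {
      holder_ = std::make_shared<T>();
    }
    return static_cast<T*>(holder_.get());
  }

 private:
  std::shared_ptr<void> holder_;
  std::mutex mutex_;
};

int main() {
  Variable var;
  std::vector<std::thread> workers;
  // Mimic several device threads racing to create the same RPCClient variable.
  for (int i = 0; i < 4; ++i) {
    workers.emplace_back([&var] { var.GetMutable<int>(); });
  }
  for (auto& t : workers) t.join();
}
```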
...@@ -33,7 +33,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep, ...@@ -33,7 +33,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
const std::string ep_val = ep; const std::string ep_val = ep;
const std::string var_name_val = var_name; const std::string var_name_val = var_name;
const framework::Scope* p_scope = &scope; const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val); const auto ch = GetChannel(ep_val, ep_val + ":" + var_name_val);
framework::AsyncIO([var_name_val, p_ctx, ep_val, p_scope, time_out, ch, framework::AsyncIO([var_name_val, p_ctx, ep_val, p_scope, time_out, ch,
this] { this] {
...@@ -88,7 +88,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep, ...@@ -88,7 +88,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
const std::string ep_val = ep; const std::string ep_val = ep;
const std::string var_name_val = var_name; const std::string var_name_val = var_name;
const framework::Scope* p_scope = &scope; const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val); const auto ch = GetChannel(ep_val, ep_val + ":" + var_name_val);
framework::AsyncIO([var_name_val, ep_val, p_scope, p_ctx, time_out, ch, framework::AsyncIO([var_name_val, ep_val, p_scope, p_ctx, time_out, ch,
this] { this] {
...@@ -132,7 +132,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep, ...@@ -132,7 +132,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
const std::string in_var_name_val = in_var_name; const std::string in_var_name_val = in_var_name;
const std::string out_var_name_val = out_var_name; const std::string out_var_name_val = out_var_name;
const framework::Scope* p_scope = &scope; const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val); const auto ch = GetChannel(ep_val, ep_val + ":" + in_var_name_val);
framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx, framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
time_out, ch, this] { time_out, ch, this] {
...@@ -165,7 +165,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep, ...@@ -165,7 +165,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
} }
void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) { void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
const auto ch = GetChannel(ep); const auto ch = GetChannel(ep, ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
s->Prepare(time_out); s->Prepare(time_out);
...@@ -178,7 +178,7 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) { ...@@ -178,7 +178,7 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
} }
void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) { void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
const auto ch = GetChannel(ep); const auto ch = GetChannel(ep, ep);
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch); FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
s->Prepare(time_out); s->Prepare(time_out);
...@@ -243,12 +243,19 @@ bool RPCClient::Proceed() { ...@@ -243,12 +243,19 @@ bool RPCClient::Proceed() {
delete c; delete c;
return true; return true;
} }
std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep,
std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) { const std::string& key) {
auto it = channels_.find(ep); VLOG(3) << "this addr: " << this;
std::unique_lock<std::mutex> lock(mutex_);
auto it = channels_.find(key);
if (it != channels_.end()) { if (it != channels_.end()) {
VLOG(3) << "find ep: " << ep;
return it->second; return it->second;
} }
VLOG(3) << "can not find ep: " << ep;
for (auto it = channels_.begin(); it != channels_.end(); ++it) {
VLOG(3) << "ep: " << it->first;
}
grpc::ChannelArguments args; grpc::ChannelArguments args;
args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE); args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
...@@ -257,8 +264,7 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) { ...@@ -257,8 +264,7 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
auto ch = auto ch =
grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args); grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
channels_[key] = ch;
channels_[ep] = ch;
return ch; return ch;
} }
......
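GetChannel now caches one channel per key rather than one per endpoint, and guards the cache with a mutex since the client is shared across device threads: data RPCs pass "ep:var_name" so each variable gets its own channel, while barrier RPCs pass the bare endpoint. A compilable sketch with a stub Channel standing in for grpc::Channel; the ChannelCache type is an assumption:

```cpp
// Sketch of the new per-key, mutex-guarded channel cache from the diff.
#include <cassert>
#include <map>
#include <memory>
#include <mutex>
#include <string>

struct Channel {
  std::string endpoint;  // stub; a real grpc::Channel would live here
};

class ChannelCache {
 public:
  // `ep` is the server address; `key` is "ep:var_name" for data RPCs and
  // just "ep" for barriers, so each variable gets its own channel.
  std::shared_ptr<Channel> GetChannel(const std::string& ep,
                                      const std::string& key) {
    std::unique_lock<std::mutex> lock(mutex_);  // cache is shared by threads
    auto it = channels_.find(key);
    if (it != channels_.end()) return it->second;
    auto ch = std::make_shared<Channel>(Channel{ep});
    channels_[key] = ch;
    return ch;
  }

 private:
  std::map<std::string, std::shared_ptr<Channel>> channels_;
  std::mutex mutex_;
};

int main() {
  ChannelCache cache;
  auto a = cache.GetChannel("127.0.0.1:6174", "127.0.0.1:6174:w@GRAD");
  auto b = cache.GetChannel("127.0.0.1:6174", "127.0.0.1:6174");  // barrier
  assert(a != b);  // distinct channels for the same endpoint
}
```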
@@ -21,6 +21,7 @@ limitations under the License. */
 #include <functional>
 #include <iostream>
 #include <map>
+#include <mutex>  // NOLINT
 #include <string>
 #include <vector>
@@ -190,12 +191,14 @@ class RPCClient {
 private:
   bool Proceed();
-  std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
+  std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep,
+                                            const std::string& key);
 private:
   grpc::CompletionQueue cq_;
   std::map<std::string, std::shared_ptr<grpc::Channel>> channels_;
   int64_t req_count_ = 0;
+  std::mutex mutex_;
 };
 }  // namespace detail
......
@@ -21,6 +21,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detail/grpc_client.h"
+#include "paddle/fluid/platform/profiler.h"
 namespace paddle {
 namespace operators {
@@ -37,6 +38,11 @@ class FetchBarrierOp : public framework::OperatorBase {
                const platform::Place& place) const override {
     std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints");
+    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+    auto& ctx = *pool.Get(place);
+    // For profiling
+    platform::RecordEvent record_event(Type(), &ctx);
+
     auto client_var_name = Output("RPCClient");
     PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
                             "Can not find variable '%s' in the scope.",
......
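The four RPC ops all gain the same two steps: fetch a device context, then open a platform::RecordEvent scoped to the operator body so the profiler captures the whole RPC. A standalone sketch of that RAII pattern; ScopedEvent is an illustrative stand-in for RecordEvent, not Paddle's class:

```cpp
// Sketch of the RAII profiling scope the diff adds to the RPC operators:
// the timer starts in the constructor and stops in the destructor.
#include <chrono>
#include <iostream>
#include <string>

class ScopedEvent {
 public:
  explicit ScopedEvent(std::string name)
      : name_(std::move(name)), start_(std::chrono::steady_clock::now()) {}
  ~ScopedEvent() {
    auto us = std::chrono::duration_cast<std::chrono::microseconds>(
                  std::chrono::steady_clock::now() - start_)
                  .count();
    std::cout << name_ << " took " << us << " us\n";  // profiler sink
  }

 private:
  std::string name_;
  std::chrono::steady_clock::time_point start_;
};

int main() {
  // Mirrors `platform::RecordEvent record_event(Type(), &ctx);`:
  // the whole operator body runs inside the event's lifetime.
  ScopedEvent event("fetch_barrier");
  // ... operator work would happen here ...
}
```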
@@ -21,6 +21,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detail/grpc_client.h"
+#include "paddle/fluid/platform/profiler.h"
 namespace paddle {
 namespace operators {
@@ -37,15 +38,18 @@ class RecvOp : public framework::OperatorBase {
     auto outs = Outputs("Out");
     std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
     auto client_var_name = Output("RPCClient");
+    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+    auto& ctx = *pool.Get(place);
+    // For profiling
+    platform::RecordEvent record_event(Type(), &ctx);
+
     PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
                             "Can not find variable '%s' in the scope.",
                             client_var_name);
     auto* client_var = scope.FindVar(client_var_name);
     detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
-    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
-    auto& ctx = *pool.Get(place);
     for (size_t i = 0; i < outs.size(); i++) {
       VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
       rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
......
@@ -21,6 +21,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detail/grpc_client.h"
+#include "paddle/fluid/platform/profiler.h"
 namespace paddle {
 namespace operators {
@@ -37,6 +38,10 @@ class SendBarrierOp : public framework::OperatorBase {
                const platform::Place& place) const override {
     std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints");
+    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+    auto& ctx = *pool.Get(place);
+    // For profiling
+    platform::RecordEvent record_event(Type(), &ctx);
     auto client_var_name = Output("RPCClient");
     PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
                             "Can not find variable '%s' in the scope.",
......
@@ -20,6 +20,9 @@ namespace operators {
 inline bool NeedSend(const framework::Scope& scope,
                      const std::string& varname) {
+  // The dummy variable is only used in the parallel executor to represent
+  // a dependency relationship; we don't need to send/recv it.
+  if (varname == "dummy") return false;
   auto* var = scope.FindVar(varname);
   PADDLE_ENFORCE_NOT_NULL(var, "Can not find variable '%s' in the send side.",
                           varname);
......
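With the graph builder emitting control-only variables literally named "dummy", NeedSend filters them out before any RPC is issued. A minimal sketch of the guard under a toy Scope; the real check additionally enforces that the variable exists and is initialized:

```cpp
// Sketch of the NeedSend guard: control-only "dummy" variables inserted by
// the parallel executor are filtered out before any RPC.
#include <cassert>
#include <set>
#include <string>

struct Scope {
  std::set<std::string> vars;  // stand-in for Paddle's variable scope
};

bool NeedSend(const Scope& scope, const std::string& varname) {
  if (varname == "dummy") return false;   // dependency edge, no payload
  return scope.vars.count(varname) > 0;   // real variable, worth sending
}

int main() {
  Scope scope{{"w@GRAD", "dummy"}};
  assert(NeedSend(scope, "w@GRAD"));
  assert(!NeedSend(scope, "dummy"));  // skipped even though it exists
}
```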
@@ -20,6 +20,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detail/grpc_client.h"
 #include "paddle/fluid/operators/send_recv_util.h"
+#include "paddle/fluid/platform/profiler.h"
 namespace paddle {
 namespace operators {
@@ -41,12 +42,17 @@ class SendVarsOp : public framework::OperatorBase {
     platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
     auto& ctx = *pool.Get(place);
+    // For profiling
+    platform::RecordEvent record_event(Type(), &ctx);
+
     auto client_var_name = Output("RPCClient");
     PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
                             "Can not find variable '%s' in the scope.",
                             client_var_name);
     auto* client_var = scope.FindVar(client_var_name);
+    VLOG(3) << "client var addr: " << client_var;
     detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
+    VLOG(3) << "rpc_client addr: " << rpc_client;
     for (size_t i = 0; i < ins.size(); i++) {
       if (NeedSend(scope, ins[i])) {
......