提交 62af10d4 编写于 作者: Y Yancey1989

support multiple devices

上级 274df85c
......@@ -4,6 +4,7 @@ cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_h
cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(send_op_handle SRCS send_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
......@@ -26,7 +27,7 @@ endif()
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
scale_loss_grad_op_handle send_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)
scale_loss_grad_op_handle send_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
......
......@@ -12,10 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include <fstream>
#include <utility>
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/details/send_op_handle.h"
#include "paddle/fluid/framework/scope.h"
......@@ -77,7 +79,6 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
CreateOpOutput(result, op_handle, each_var_name, p, place_id);
}
}
bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
OpDesc *send_op) const {
if (send_op == nullptr) {
......@@ -98,7 +99,7 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
return false;
};
if (op.Type() == "split") {
if (op.Type() == "split" || op.Type() == "split_byref") {
return checker(op.OutputArgumentNames(), send_op->InputArgumentNames());
} else if (op.Type() == "concat") {
return checker(op.InputArgumentNames(), send_op->OutputArgumentNames());
......@@ -106,6 +107,15 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
return false;
}
bool MultiDevSSAGraphBuilder::IsRPCOp(const OpDesc &op) const {
for (auto &name : op.OutputNames()) {
if (name == "RPCClient") {
return true;
}
}
return false;
}
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
const ProgramDesc &program) const {
std::unordered_map<std::string, proto::VarType::Type> var_types;
......@@ -133,10 +143,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
bool is_forwarding = true;
for (auto *op : program.Block(0).AllOps()) {
if (op->Type() == "send") {
// append send op if program is distributed trainer main program.
if (IsRPCOp(*op)) {
// append rpc op if program is distributed trainer main program.
// always use the first device
CreateSendOp(&result, *op);
CreateRPCOp(&result, *op);
} else if (IsDistTrainOp(*op, send_op)) {
CreateComputationalOps(&result, *op, 1);
} else if (IsScaleLossOp(*op)) {
......@@ -203,9 +213,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
AddOutputToLeafOps(&result);
if (VLOG_IS_ON(10)) {
std::ostringstream sout;
PrintGraphviz(*graph, sout);
VLOG(10) << sout.str();
std::string filename = "/tmp/graph";
std::ofstream fout(filename);
PrintGraphviz(*graph, fout);
}
return std::unique_ptr<SSAGraph>(graph);
......@@ -386,12 +396,40 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
return var;
}
void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result,
const OpDesc &op) const {
void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result,
std::string op_name) const {
for (auto &prev_op : result->ops_) {
if (prev_op->Name() == op_name) {
auto *dep_var = new DummyVarHandle();
prev_op->AddOutput(dep_var);
result->dep_vars_.emplace(dep_var);
result->ops_.back().get()->AddInput(dep_var);
}
}
}
void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
const OpDesc &op) const {
auto &p = places_[0];
auto *s = local_scopes_[0];
VLOG(3) << "create rpc op: " << op.Type();
result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type()));
if (op.Type() == "send_barrier") {
ConnectOp(result, "send_vars");
} else if (op.Type() == "recv") {
ConnectOp(result, "send_barrier");
} else if (op.Type() == "fetch_barrier") {
ConnectOp(result, "recv");
} else if (op.Type() == "send" || op.Type() == "send_vars") {
// do nothing
} else {
PADDLE_THROW(
"rpc op should be in [send,"
"send_vars, send_barrier. recv, fetch_barrier]");
}
// FIXME(wuyi): send op always copy from GPU 0
result->ops_.emplace_back(new SendOpHandle(op, s, p));
// result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type()));
// Create inputs for output on original place and no ssa output
// is created for send op.
CreateOpHandleIOs(result, op, 0);
......
......@@ -65,12 +65,17 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
bool IsScaleLossOp(const OpDesc &op) const;
void CreateSendOp(SSAGraph *result, const OpDesc &op) const;
void CreateRPCOp(SSAGraph *result, const OpDesc &op) const;
/**
* Is this operator as the end-point operator before/after send operator.
*/
bool IsDistTrainOp(const OpDesc &op, OpDesc *send_op) const;
bool IsRPCOp(const OpDesc &op) const;
void ConnectOp(SSAGraph *result, std::string op_name) const;
void CreateComputationalOps(SSAGraph *result, const OpDesc &op,
size_t num_places) const;
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/rpc_op_handle.h"
namespace paddle {
namespace framework {
namespace details {
RPCOpHandle::RPCOpHandle(const framework::OpDesc &op_desc,
const Scope *local_scope, const platform::Place &place,
const std::string &name)
: op_(framework::OpRegistry::CreateOp(op_desc)),
local_scope_(local_scope),
place_(place),
name_(name) {}
void RPCOpHandle::RunImpl() {
// TODO(wuyi): need further analysis whether wait VarDummyHandle.
// Wait input done
for (auto *in : inputs_) {
auto &p = static_cast<VarHandle *>(in)->place_;
if (in->DebugString() == "dummy") { // HACK
continue;
}
if (in->generated_op_) {
in->generated_op_->RecordWaitEventOnCtx(dev_ctxes_[p]);
}
}
auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
// FIXME(wuyi): can not use RunAndRecordEvent here, for it will cause dead
// lock.
op_->Run(*tmp_scope, place_);
}
std::string RPCOpHandle::Name() const { return name_; }
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
namespace framework {
namespace details {
struct RPCOpHandle : public OpHandleBase {
RPCOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
const platform::Place& place, const std::string& name);
std::string Name() const override;
// Delay and buffer nccl_all_reduce together can significantly increase
// performance. Disable this feature by returning false.
bool IsMultiDeviceTransfer() override { return false; };
protected:
void RunImpl() override;
private:
std::unique_ptr<OperatorBase> op_;
const Scope* local_scope_;
const platform::Place& place_;
const std::string name_;
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -14,6 +14,7 @@
#pragma once
#include <memory>
#include <mutex> // NOLINT
#include <string>
#include <typeindex>
#include <typeinfo>
......@@ -38,6 +39,7 @@ class Variable {
template <typename T>
T* GetMutable() {
std::unique_lock<std::mutex> lock(mutex_);
if (!IsType<T>()) {
holder_.reset(new PlaceholderImpl<T>(new T()));
}
......@@ -90,6 +92,7 @@ class Variable {
// by its address but not the unreadable name.
friend class Scope;
const std::string* name_;
std::mutex mutex_;
};
} // namespace framework
......
......@@ -33,7 +33,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
const std::string ep_val = ep;
const std::string var_name_val = var_name;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
const auto ch = GetChannel(ep_val, ep_val + ":" + var_name_val);
framework::AsyncIO([var_name_val, p_ctx, ep_val, p_scope, time_out, ch,
this] {
......@@ -88,7 +88,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
const std::string ep_val = ep;
const std::string var_name_val = var_name;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
const auto ch = GetChannel(ep_val, ep_val + ":" + var_name_val);
framework::AsyncIO([var_name_val, ep_val, p_scope, p_ctx, time_out, ch,
this] {
......@@ -132,7 +132,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
const std::string in_var_name_val = in_var_name;
const std::string out_var_name_val = out_var_name;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
const auto ch = GetChannel(ep_val, ep_val + ":" + in_var_name_val);
framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
time_out, ch, this] {
......@@ -165,7 +165,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
}
void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
const auto ch = GetChannel(ep);
const auto ch = GetChannel(ep, ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
s->Prepare(time_out);
......@@ -178,7 +178,7 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
}
void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
const auto ch = GetChannel(ep);
const auto ch = GetChannel(ep, ep);
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
s->Prepare(time_out);
......@@ -243,12 +243,19 @@ bool RPCClient::Proceed() {
delete c;
return true;
}
std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
auto it = channels_.find(ep);
std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep,
const std::string& key) {
VLOG(3) << "this addr: " << this;
std::unique_lock<std::mutex> lock(mutex_);
auto it = channels_.find(key);
if (it != channels_.end()) {
VLOG(3) << "find ep: " << ep;
return it->second;
}
VLOG(3) << "can not find ep: " << ep;
for (auto it = channels_.begin(); it != channels_.end(); ++it) {
VLOG(3) << "ep: " << it->first;
}
grpc::ChannelArguments args;
args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
......@@ -257,8 +264,7 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
auto ch =
grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
channels_[ep] = ch;
channels_[key] = ch;
return ch;
}
......
......@@ -21,6 +21,7 @@ limitations under the License. */
#include <functional>
#include <iostream>
#include <map>
#include <mutex> // NOLINT
#include <string>
#include <vector>
......@@ -190,12 +191,14 @@ class RPCClient {
private:
bool Proceed();
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep,
const std::string& key);
private:
grpc::CompletionQueue cq_;
std::map<std::string, std::shared_ptr<grpc::Channel>> channels_;
int64_t req_count_ = 0;
std::mutex mutex_;
};
} // namespace detail
......
......@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
......@@ -37,6 +38,11 @@ class FetchBarrierOp : public framework::OperatorBase {
const platform::Place& place) const override {
std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
// For profiling
platform::RecordEvent record_event(Type(), &ctx);
auto client_var_name = Output("RPCClient");
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
"Can not find variable '%s' in the scope.",
......
......@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
......@@ -37,15 +38,18 @@ class RecvOp : public framework::OperatorBase {
auto outs = Outputs("Out");
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
auto client_var_name = Output("RPCClient");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
// For profiling
platform::RecordEvent record_event(Type(), &ctx);
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
"Can not find variable '%s' in the scope.",
client_var_name);
auto* client_var = scope.FindVar(client_var_name);
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
for (size_t i = 0; i < outs.size(); i++) {
VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
......
......@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
......@@ -37,6 +38,10 @@ class SendBarrierOp : public framework::OperatorBase {
const platform::Place& place) const override {
std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
// For profiling
platform::RecordEvent record_event(Type(), &ctx);
auto client_var_name = Output("RPCClient");
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
"Can not find variable '%s' in the scope.",
......
......@@ -20,6 +20,9 @@ namespace operators {
inline bool NeedSend(const framework::Scope& scope,
const std::string& varname) {
// dummy variable is only used in parallel executor to represent
// some dependency relationship, we don't need to send/recv it.
if (varname == "dummy") return false;
auto* var = scope.FindVar(varname);
PADDLE_ENFORCE_NOT_NULL(var, "Can not find variable '%s' in the send side.",
varname);
......
......@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/send_recv_util.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
......@@ -41,12 +42,17 @@ class SendVarsOp : public framework::OperatorBase {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
// For profiling
platform::RecordEvent record_event(Type(), &ctx);
auto client_var_name = Output("RPCClient");
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
"Can not find variable '%s' in the scope.",
client_var_name);
auto* client_var = scope.FindVar(client_var_name);
VLOG(3) << "client var addr: " << client_var;
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
VLOG(3) << "rpc_client addr: " << rpc_client;
for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册