未验证 提交 d92a75be 编写于 作者: Y Yancey 提交者: GitHub

Merge pull request #10550 from Yancey1989/overlap_send_op

overlap send ops and backward ops
......@@ -3,7 +3,7 @@ cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context
cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(send_op_handle SRCS send_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
......@@ -26,7 +26,7 @@ endif()
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
scale_loss_grad_op_handle send_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)
scale_loss_grad_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
......
......@@ -12,12 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include <fstream>
#include <utility>
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/details/send_op_handle.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/scope.h"
......@@ -28,6 +29,10 @@
#include <string>
#include <vector>
DEFINE_string(ssa_graph_path, "/tmp/ssa_graph.dot",
"the ssa graph path only print with GLOG_v=10,"
"default /tmp/graph.dot");
namespace paddle {
namespace framework {
namespace details {
......@@ -79,9 +84,44 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
}
}
bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
OpDesc *send_op) const {
if (send_op == nullptr) {
std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
const ProgramDesc &program) const {
std::vector<std::string> send_vars;
// since parameters are all in block 0,
// it's enough to only scan send ops in block 0
for (auto *op : program.Block(0).AllOps()) {
// TODO(Yancey1989): use a graceful method to find send op,
// instead of the the hard code string
if (op->Type() == "send_vars") {
auto op_vars = op->InputArgumentNames();
send_vars.reserve(send_vars.size() +
std::distance(op_vars.begin(), op_vars.end()));
send_vars.insert(send_vars.end(), op_vars.begin(), op_vars.end());
}
}
return send_vars;
}
std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
const ProgramDesc &program) const {
std::vector<std::string> recv_vars;
for (auto *op : program.Block(0).AllOps()) {
// TODO(Yancey1989): use a graceful method to find recv op,
// instead of the hard code string
if (op->Type() == "recv") {
auto op_vars = op->OutputArgumentNames();
recv_vars.reserve(recv_vars.size() +
std::distance(op_vars.begin(), op_vars.end()));
recv_vars.insert(recv_vars.end(), op_vars.begin(), op_vars.end());
}
}
return recv_vars;
}
bool MultiDevSSAGraphBuilder::IsDistTrainOp(
const OpDesc &op, const std::vector<std::string> &send_vars,
const std::vector<std::string> &recv_vars) const {
if (send_vars.size() == 0 || recv_vars.size() == 0) {
return false;
}
......@@ -89,22 +129,21 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
* Check any of opvars contains `.block` and in sendvars
*/
auto checker = [](const std::vector<std::string> &opvars,
const std::vector<std::string> &sendvars) -> bool {
const std::vector<std::string> &rpc_vars) -> bool {
for (auto &var : opvars) {
// a variable name with the suffix `.block` means it's a splited
// variable by (DistributeTranspiler)
// [python/paddle/fluid/transpiler/distribute_transpiler.py]
if (var.find(".block") != std::string::npos &&
std::find(sendvars.begin(), sendvars.end(), var) != sendvars.end()) {
std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
return true;
}
}
return false;
};
if (op.Type() == "split" || op.Type() == "split_byref") {
return checker(op.OutputArgumentNames(), send_op->InputArgumentNames());
} else if (op.Type() == "concat") {
return checker(op.InputArgumentNames(), send_op->OutputArgumentNames());
}
return false;
return checker(op.OutputArgumentNames(), send_vars) ||
checker(op.InputArgumentNames(), recv_vars);
}
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
......@@ -123,8 +162,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
places_.size());
// Find "send" op first for split is in front of send.
OpDesc *send_op = GetSendOpDesc(program);
// find send/recv vars so that we can place the distributed training
// realted op in the place 0
auto send_vars = FindDistTrainSendVars(program);
auto recv_vars = FindDistTrainRecvVars(program);
size_t cur_device_id = 0;
std::vector<std::unordered_set<std::string>> var_name_on_devices;
......@@ -134,12 +175,14 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
bool is_forwarding = true;
for (auto *op : program.Block(0).AllOps()) {
if (op->Type() == "send") {
// append send op if program is distributed trainer main program.
if (boost::get<int>(
op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
static_cast<int>(OpRole::kRPC)) {
// append rpc op if program is distributed trainer main program.
// always use the first device
CreateSendOp(&result, *op);
} else if (IsDistTrainOp(*op, send_op)) {
CreateComputationalOps(&result, *op, 1);
CreateRPCOp(&result, *op);
} else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
CreateDistTrainOp(&result, *op);
} else if (IsScaleLossOp(*op)) {
// user can customize loss@grad if not use_default_grad_scale_
if (strategy_.gradient_scale_ !=
......@@ -218,9 +261,8 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
AddOutputToLeafOps(&result);
if (VLOG_IS_ON(10)) {
std::ostringstream sout;
PrintGraphviz(*graph, sout);
VLOG(10) << sout.str();
std::ofstream fout(FLAGS_ssa_graph_path);
PrintGraphviz(*graph, fout);
}
return std::unique_ptr<SSAGraph>(graph);
......@@ -270,15 +312,6 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result,
CreateOpHandleIOs(result, op, dev_id);
}
OpDesc *MultiDevSSAGraphBuilder::GetSendOpDesc(
const ProgramDesc &program) const {
for (auto *op : program.Block(0).AllOps()) {
if (op->Type() == "send") {
return op;
}
}
return nullptr;
}
void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
SSAGraph *result, const std::string &og) const {
#ifdef PADDLE_WITH_CUDA
......@@ -401,14 +434,48 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
return var;
}
void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result,
void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
const std::string &prev_op_name) const {
for (auto &prev_op : result->ops_) {
if (prev_op->Name() == prev_op_name) {
auto *dep_var = new DummyVarHandle();
prev_op->AddOutput(dep_var);
result->dep_vars_.emplace(dep_var);
op->AddInput(dep_var);
}
}
}
void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
const OpDesc &op) const {
CreateComputationalOp(result, op, 0);
if (op.Type() == "concat") {
ConnectOp(result, result->ops_.back().get(), "fetch_barrier");
}
}
void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
const OpDesc &op) const {
auto &p = places_[0];
auto *s = local_scopes_[0];
// FIXME(wuyi): send op always copy from GPU 0
result->ops_.emplace_back(new SendOpHandle(op, s, p));
// Create inputs for output on original place and no ssa output
// is created for send op.
result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type()));
if (op.Type() == "send_barrier") {
ConnectOp(result, result->ops_.back().get(), "send_vars");
} else if (op.Type() == "recv") {
ConnectOp(result, result->ops_.back().get(), "send_barrier");
} else if (op.Type() == "fetch_barrier") {
ConnectOp(result, result->ops_.back().get(), "recv");
} else if (op.Type() == "send_vars") {
// do nothing
} else {
PADDLE_THROW(
"rpc op should be in ["
"send_vars, send_barrier. recv, fetch_barrier]");
}
// TODO(Yancey1989): schedule rpc op on different place may
// increate throughput
CreateOpHandleIOs(result, op, 0);
}
......
......@@ -64,12 +64,24 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
bool IsScaleLossOp(const OpDesc &op) const;
void CreateSendOp(SSAGraph *result, const OpDesc &op) const;
void CreateRPCOp(SSAGraph *result, const OpDesc &op) const;
void CreateDistTrainOp(SSAGraph *result, const OpDesc &op) const;
/**
* Is this operator as the end-point operator before/after send operator.
*/
bool IsDistTrainOp(const OpDesc &op, OpDesc *send_op) const;
bool IsDistTrainOp(const OpDesc &op,
const std::vector<std::string> &send_vars,
const std::vector<std::string> &recv_vars) const;
std::vector<std::string> FindDistTrainSendVars(
const ProgramDesc &program) const;
std::vector<std::string> FindDistTrainRecvVars(
const ProgramDesc &program) const;
void ConnectOp(SSAGraph *result, OpHandleBase *op,
const std::string &prev_op_name) const;
void CreateComputationalOps(SSAGraph *result, const OpDesc &op,
size_t num_places) const;
......@@ -93,12 +105,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,
size_t src_dev_id) const;
/**
* Get send op in the global block of program.
* nullptr if not found.
*/
OpDesc *GetSendOpDesc(const ProgramDesc &program) const;
bool IsSparseGradient(
const std::unordered_map<std::string, proto::VarType::Type> &var_types,
const std::string &og) const;
......
......@@ -12,24 +12,26 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/send_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h"
namespace paddle {
namespace framework {
namespace details {
SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc,
const Scope *local_scope,
const platform::Place &place)
RPCOpHandle::RPCOpHandle(const framework::OpDesc &op_desc,
const Scope *local_scope, const platform::Place &place,
const std::string &name)
: op_(framework::OpRegistry::CreateOp(op_desc)),
local_scope_(local_scope),
place_(place) {}
place_(place),
name_(name) {}
void SendOpHandle::RunImpl() {
void RPCOpHandle::RunImpl() {
// TODO(wuyi): need further analysis whether wait VarDummyHandle.
// Wait input done
for (auto *in : inputs_) {
auto &p = static_cast<VarHandle *>(in)->place_;
// FIXME(Yancey1989): need a better solution instead of use DebugString()
if (in->DebugString() == "dummy") { // HACK
continue;
}
......@@ -43,7 +45,7 @@ void SendOpHandle::RunImpl() {
op_->Run(*tmp_scope, place_);
}
std::string SendOpHandle::Name() const { return "send"; }
std::string RPCOpHandle::Name() const { return name_; }
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -27,9 +27,9 @@ namespace paddle {
namespace framework {
namespace details {
struct SendOpHandle : public OpHandleBase {
SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
const platform::Place& place);
struct RPCOpHandle : public OpHandleBase {
RPCOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
const platform::Place& place, const std::string& name);
std::string Name() const override;
......@@ -44,6 +44,7 @@ struct SendOpHandle : public OpHandleBase {
std::unique_ptr<OperatorBase> op_;
const Scope* local_scope_;
const platform::Place& place_;
const std::string name_;
};
} // namespace details
......
......@@ -66,7 +66,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
.InEnum(
{static_cast<int>(OpRole::kForward),
static_cast<int>(OpRole::kBackward),
static_cast<int>(OpRole::kOptimize),
static_cast<int>(OpRole::kOptimize), static_cast<int>(OpRole::kRPC),
static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kForward),
static_cast<int>(OpRole::kLoss) |
static_cast<int>(OpRole::kBackward),
......
......@@ -24,6 +24,7 @@ enum class OpRole {
kForward = 0x0000,
kBackward = 0x0001,
kOptimize = 0x0002,
kRPC = 0x0003,
kLoss = 0x0100,
// The default value of op's role. This should be only used for unittests and
......
......@@ -35,7 +35,7 @@ TEST(DataFlowGraph, BFS) {
GraphTraits<DataFlowGraph> trait(&dfg);
auto nodes = trait.nodes();
int count = 0;
size_t count = 0;
for (auto it = nodes.begin(); it != nodes.end(); ++it) {
LOG(INFO) << "visiting " << it->name();
++count;
......@@ -49,7 +49,7 @@ TEST(DataFlowGraph, DFS) {
dfg.Build();
GraphTraits<DataFlowGraph> trait(&dfg);
auto nodes = trait.nodes_in_DFS();
int count = 0;
size_t count = 0;
for (auto it = nodes.begin(); it != nodes.end(); ++it) {
LOG(INFO) << "visiting " << it->name();
++count;
......
......@@ -200,7 +200,9 @@ if(WITH_DISTRIBUTE)
op_library(send_vars_op DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(send_vars_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
op_library(send_barrier_op DEPS ${DISTRIBUTE_DEPS})
op_library(fetch_barrier_op DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(send_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(fetch_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
#set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
#cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op
# listen_and_serv_op sum_op executor SERIAL)
......@@ -214,7 +216,7 @@ if(WITH_DISTRIBUTE)
set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op)
endif()
else()
set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op gen_nccl_id_op)
set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op fetch_barrier_op gen_nccl_id_op)
endif()
op_library(cross_entropy_op DEPS cross_entropy)
......
......@@ -25,6 +25,21 @@ namespace paddle {
namespace operators {
namespace detail {
std::once_flag RPCClient::init_flag_;
std::unique_ptr<RPCClient> RPCClient::rpc_client_(nullptr);
RPCClient* RPCClient::GetInstance() {
std::call_once(init_flag_, &RPCClient::Init);
return rpc_client_.get();
}
void RPCClient::Init() {
if (rpc_client_.get() == nullptr) {
rpc_client_.reset(new RPCClient());
}
}
bool RPCClient::AsyncSendVariable(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
......@@ -60,7 +75,6 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
call->StartCall();
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
});
req_count_++;
return true;
......@@ -249,8 +263,9 @@ bool RPCClient::Proceed() {
delete c;
return true;
}
std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
// TODO(Yancey1989): make grpc client completely thread-safe
std::unique_lock<std::mutex> lock(mutex_);
auto it = channels_.find(ep);
if (it != channels_.end()) {
return it->second;
......@@ -263,7 +278,6 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
auto ch =
grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
channels_[ep] = ch;
return ch;
}
......
......@@ -21,6 +21,7 @@ limitations under the License. */
#include <functional>
#include <iostream>
#include <map>
#include <mutex> // NOLINT
#include <string>
#include <vector>
......@@ -35,6 +36,7 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace paddle {
namespace operators {
......@@ -161,6 +163,10 @@ class FetchBarrierProcessor : public BaseProcessor {
class RPCClient {
public:
RPCClient() {}
static RPCClient* GetInstance();
bool AsyncSendVariable(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
......@@ -191,11 +197,17 @@ class RPCClient {
private:
bool Proceed();
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
// Init is called by GetInstance.
static void Init();
private:
grpc::CompletionQueue cq_;
std::map<std::string, std::shared_ptr<grpc::Channel>> channels_;
int64_t req_count_ = 0;
std::atomic<int64_t> req_count_{0};
std::mutex mutex_;
static std::unique_ptr<RPCClient> rpc_client_;
static std::once_flag init_flag_;
DISABLE_COPY_AND_ASSIGN(RPCClient);
};
} // namespace detail
......
......@@ -121,10 +121,10 @@ TEST(PREFETCH, DISABLED_CPU) {
std::string in_var_name("ids");
std::string out_var_name("out");
detail::RPCClient client;
client.AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name,
auto client = detail::RPCClient::GetInstance();
client->AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name,
out_var_name);
client.Wait();
client->Wait();
auto var = scope.Var(out_var_name);
auto value = var->GetMutable<framework::SelectedRows>()->value();
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <future> // NOLINT
#include <ostream>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
class FetchBarrierOp : public framework::OperatorBase {
public:
FetchBarrierOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
// For profiling
platform::RecordEvent record_event(Type(), &ctx);
auto rpc_client = detail::RPCClient::GetInstance();
PADDLE_ENFORCE(rpc_client->Wait());
for (auto& ep : eps) {
VLOG(3) << "fetch barrier, ep: " << ep;
rpc_client->AsyncSendFetchBarrier(ep);
}
PADDLE_ENFORCE(rpc_client->Wait());
}
};
class FetchBarrierOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddComment(R"DOC(
SendBarrier operator
This operator will send a send barrier signal to list_and_serv op, so that
the Parameter Server would knew all variables have been sent.
)DOC");
AddAttr<std::vector<std::string>>("endpoints",
"(string vector, default 127.0.0.1:6164)"
"Server endpoints to send variables to.")
.SetDefault({"127.0.0.1:6164"});
}
};
class FetchBarrierOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(fetch_barrier, ops::FetchBarrierOp,
paddle::framework::EmptyGradOpMaker, ops::FetchBarrierOpMaker,
ops::FetchBarrierOpShapeInference);
......@@ -41,12 +41,7 @@ class PrefetchOp : public framework::OperatorBase {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
auto client_var_name = Output("RPCClient");
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
"Can not find variable '%s' in the scope.",
client_var_name);
auto* client_var = scope.FindVar(client_var_name);
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
auto rpc_client = detail::RPCClient::GetInstance();
for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) {
......@@ -66,9 +61,6 @@ class PrefetchOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(LoDTensor) Input Id variables to be sent").AsDuplicable();
AddOutput("RPCClient",
"(RPCClient) The RPC client object which will be"
"initialized at most once.");
AddOutput("Out",
"(LoDTensor) result "
"to be fetched from parameter server")
......@@ -87,17 +79,6 @@ the parameter server and fetch result back.
}
};
class PrefetchOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto out_var_name = op_desc.Output("RPCClient").front();
auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarType::RAW;
out_var.SetType(var_type);
}
};
class PrefetchOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {}
......@@ -110,5 +91,4 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(prefetch, ops::PrefetchOp,
paddle::framework::EmptyGradOpMaker, ops::PrefetchOpMaker,
ops::PrefetchOpVarTypeInference,
ops::PrefetchOpShapeInference);
......@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
......@@ -36,19 +37,23 @@ class RecvOp : public framework::OperatorBase {
const platform::Place& place) const override {
auto outs = Outputs("Out");
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
int sync_mode = Attr<int>("sync_mode");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
// For profiling
platform::RecordEvent record_event(Type(), &ctx);
auto rpc_client = detail::RPCClient::GetInstance();
for (size_t i = 0; i < outs.size(); i++) {
VLOG(3) << "getting " << outs[i];
client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
}
if (sync_mode) {
PADDLE_ENFORCE(rpc_client->Wait());
}
PADDLE_ENFORCE(client_.Wait());
}
private:
mutable detail::RPCClient client_;
};
class RecvOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -65,6 +70,10 @@ This operator can get variables from server side.
"Server endpoints in the order of input "
"variables for mapping")
.SetDefault({});
AddAttr<int>("sync_mode",
"(int, default 0)"
"sync recv or async recv.")
.SetDefault(0);
}
};
......
......@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
......@@ -36,31 +37,30 @@ class SendBarrierOp : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints");
bool sync_mode = Attr<bool>("sync_mode");
auto client_var_name = Output("RPCClient");
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
"Can not find variable '%s' in the scope.",
client_var_name);
auto* client_var = scope.FindVar(client_var_name);
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
// For profiling
platform::RecordEvent record_event(Type(), &ctx);
auto rpc_client = detail::RPCClient::GetInstance();
// need to wait before sending send_barrier message
PADDLE_ENFORCE(rpc_client->Wait());
if (sync_mode) {
for (auto& ep : eps) {
VLOG(3) << "send barrier, ep: " << ep;
rpc_client->AsyncSendBatchBarrier(ep);
}
PADDLE_ENFORCE(rpc_client->Wait());
}
}
};
class SendBarrierOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddOutput("RPCClient",
"(RPCClient) The RPC client object which is"
"initialized at most once.");
AddComment(R"DOC(
SendBarrier operator
......@@ -72,17 +72,7 @@ the Parameter Server would knew all variables have been sent.
"(string vector, default 127.0.0.1:6164)"
"Server endpoints to send variables to.")
.SetDefault({"127.0.0.1:6164"});
}
};
class SendBarrierOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto out_var_name = op_desc.Output("RPCClient").front();
auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarType::RAW;
out_var.SetType(var_type);
AddAttr<bool>("sync_mode", "work in sync_mode or not").SetDefault(true);
}
};
......@@ -98,5 +88,4 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(send_barrier, ops::SendBarrierOp,
paddle::framework::EmptyGradOpMaker, ops::SendBarrierOpMaker,
ops::SendBarrierOpVarTypeInference,
ops::SendBarrierOpShapeInference);
......@@ -49,12 +49,7 @@ class SendOp : public framework::OperatorBase {
// For profiling
platform::RecordEvent record_event(Type(), &ctx);
auto client_var_name = Output("RPCClient");
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
"Can not find variable '%s' in the scope.",
client_var_name);
auto* client_var = scope.FindVar(client_var_name);
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
auto rpc_client = detail::RPCClient::GetInstance();
for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) {
......@@ -96,9 +91,6 @@ class SendOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", "(Tensor) Input tensor to be sent").AsDuplicable();
AddOutput("Out", "(Tensor) Output tensor to be received from server")
.AsDuplicable();
AddOutput("RPCClient",
"(RPCClient) The RPC client object which is"
"initialized at most once.");
AddComment(R"DOC(
Send operator
......@@ -119,17 +111,6 @@ This operator will send tensor to recv_op at the parameter server.
}
};
class SendOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto out_var_name = op_desc.Output("RPCClient").front();
auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarType::RAW;
out_var.SetType(var_type);
}
};
class SendOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {}
......@@ -141,5 +122,4 @@ class SendOpShapeInference : public framework::InferShapeBase {
namespace ops = paddle::operators;
REGISTER_OPERATOR(send, ops::SendOp, paddle::framework::EmptyGradOpMaker,
ops::SendOpMaker, ops::SendOpVarTypeInference,
ops::SendOpShapeInference);
ops::SendOpMaker, ops::SendOpShapeInference);
......@@ -156,6 +156,7 @@ TEST(SendRecvOp, CPUDense) {
std::thread server_thread(StartServerNet, false, &initialized);
while (!initialized) {
}
static_cast<paddle::operators::ListenAndServOp *>(listen_and_serv_op.get())
->WaitServerReady();
......@@ -175,9 +176,10 @@ TEST(SendRecvOp, CPUDense) {
std::string endpoint = paddle::string::Sprintf("127.0.0.1:%d", selected_port);
attrs.insert({"endpoints", std::vector<std::string>({endpoint})});
attrs.insert({"epmap", std::vector<std::string>({endpoint})});
auto send_op = f::OpRegistry::CreateOp(
"send", {{"X", {"x1"}}},
{{"Out", {"Out"}}, {"RPCClient", {"RPC_CLIENT_VAR"}}}, attrs);
const f::VariableNameMap &inputs = {{"X", {"x1"}}};
const f::VariableNameMap &outputs = {{"Out", {"Out"}}};
auto send_op = f::OpRegistry::CreateOp("send", inputs, outputs, attrs);
send_op->Run(scope, place);
auto in_var = scope.Var("x1");
......@@ -220,9 +222,8 @@ TEST(SendRecvOp, CPUSparse) {
std::string endpoint = paddle::string::Sprintf("127.0.0.1:%d", selected_port);
attrs.insert({"endpoints", std::vector<std::string>({endpoint})});
attrs.insert({"epmap", std::vector<std::string>({endpoint})});
auto send_op = f::OpRegistry::CreateOp(
"send", {{"X", {"x1"}}},
{{"Out", {"Out"}}, {"RPCClient", {"RPC_CLIENT_VAR"}}}, attrs);
auto send_op = f::OpRegistry::CreateOp("send", {{"X", {"x1"}}},
{{"Out", {"Out"}}}, attrs);
send_op->Run(scope, place);
auto x0 = scope.Var("x0")->GetMutable<f::SelectedRows>();
......
......@@ -20,6 +20,9 @@ namespace operators {
inline bool NeedSend(const framework::Scope& scope,
const std::string& varname) {
// dummy variable is only used in parallel executor to represent
// some dependency relationship, we don't need to send/recv it.
if (varname == "dummy") return false;
auto* var = scope.FindVar(varname);
PADDLE_ENFORCE_NOT_NULL(var, "Can not find variable '%s' in the send side.",
varname);
......
......@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/send_recv_util.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
......@@ -41,12 +42,10 @@ class SendVarsOp : public framework::OperatorBase {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
auto client_var_name = Output("RPCClient");
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
"Can not find variable '%s' in the scope.",
client_var_name);
auto* client_var = scope.FindVar(client_var_name);
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
// For profiling
platform::RecordEvent record_event(Type(), &ctx);
auto rpc_client = detail::RPCClient::GetInstance();
for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) {
......@@ -69,9 +68,6 @@ class SendVarsOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() {
AddInput("X", "(Tensor, SelectedRows) Input variables to be sent")
.AsDuplicable();
AddOutput("RPCClient",
"(RPCClient) The RPC client object which will be"
"initialized at most once.");
AddComment(R"DOC(
Send operator
......@@ -89,17 +85,6 @@ This operator will send variables to listen_and_serve op at the parameter server
}
};
class SendVarsOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto out_var_name = op_desc.Output("RPCClient").front();
auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarType::RAW;
out_var.SetType(var_type);
}
};
class SendVarsOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {}
......@@ -112,5 +97,4 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(send_vars, ops::SendVarsOp,
paddle::framework::EmptyGradOpMaker, ops::SendVarsOpMaker,
ops::SendVarsOpVarTypeInference,
ops::SendVarsOpShapeInference);
......@@ -32,7 +32,8 @@ void BindConstValue(pybind11::module* m) {
.value("Forward", framework::OpRole::kForward)
.value("Backward", framework::OpRole::kBackward)
.value("Optimize", framework::OpRole::kOptimize)
.value("Loss", framework::OpRole::kLoss);
.value("Loss", framework::OpRole::kLoss)
.value("RPC", framework::OpRole::kRPC);
op_proto_and_checker_maker.def(
"kOpRoleAttrName", framework::OpProtoAndCheckerMaker::OpRoleAttrName);
......
......@@ -195,21 +195,23 @@ def Send(endpoints, send_vars, get_vars=None):
endpoints = list(set(epmap))
helper = LayerHelper("Send", **locals())
rpc_client_var = default_main_program().global_block().create_var(
name="RPC_CLIENT_VAR", persistable=True, type=core.VarDesc.VarType.RAW)
if not get_vars:
get_vars = []
for s in send_vars:
v = helper.create_tmp_variable(dtype=s.dtype, stop_gradient=True)
get_vars.append(v)
rpc_op_role_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
helper.append_op(
type="send",
inputs={"X": send_vars},
outputs={"Out": get_vars,
"RPCClient": rpc_client_var},
attrs={"endpoints": endpoints,
"epmap": epmap})
outputs={"Out": get_vars},
attrs={
"endpoints": endpoints,
"epmap": epmap,
rpc_op_role_name: core.op_proto_and_checker_maker.OpRole.RPC
})
return get_vars
......
......@@ -49,7 +49,6 @@ class TestDistTranspiler(unittest.TestCase):
def test_transpiler(self):
trainer = self.get_trainer()
pserver, startup = self.get_pserver(self.current_pserver_ep)
self.assertEqual([op.type for op in trainer.global_block().ops],
self.get_expect_trainer_ops())
......@@ -86,8 +85,12 @@ class TestDistTranspiler(unittest.TestCase):
optimize_ops, params_grads = self.net_conf()
delete_ops(trainer.global_block(), optimize_ops)
return [op.type for op in trainer.global_block().ops
] + ["split_byref", "send", "concat"]
ops = [op.type for op in trainer.global_block().ops] + [
"split_byref", "send_vars", "send_barrier", "recv", "recv",
"fetch_barrier", "concat"
]
ops.insert(ops.index("elementwise_add_grad") + 1, "send_vars")
return ops
def get_trainer(self):
return self._transpiler_instance().get_trainer_program()
......
......@@ -16,8 +16,9 @@ from distribute_transpiler import DistributeTranspiler
from inference_transpiler import InferenceTranspiler
from memory_optimization_transpiler import memory_optimize, release_memory
from distribute_transpiler_simple import SimpleDistributeTranspiler
from ps_dispatcher import HashName, RoundRobin
__all__ = [
"DistributeTranspiler", "InferenceTranspiler", "SimpleDistributeTranspiler",
"memory_optimize", "release_memory"
"memory_optimize", "release_memory", "HashName", "RoundRobin"
]
......@@ -16,7 +16,7 @@ from __future__ import print_function
import math
import distributed_splitter as splitter
from ps_dispatcher import RoundRobin, HashName, PSDispatcher
from .. import core, framework
from ..framework import Program, default_main_program, \
default_startup_program, \
......@@ -24,7 +24,9 @@ from ..framework import Program, default_main_program, \
LOOKUP_TABLE_TYPE = "lookup_table"
LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
RPC_CLIENT_VAR_NAME = "RPC_CLIENT_VAR"
RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName(
)
RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
class VarBlock:
......@@ -149,13 +151,27 @@ def delete_ops(block, ops):
block.program.sync_with_cpp()
def find_op_by_input_arg(block, arg_name):
for index, op in enumerate(block.ops):
if arg_name in op.input_arg_names:
return index
return -1
def find_op_by_output_arg(block, arg_name):
for index, op in enumerate(block.ops):
if arg_name in op.output_arg_names:
return index
return -1
class DistributeTranspiler:
def transpile(self,
trainer_id,
program=None,
pservers="127.0.0.1:6174",
trainers=1,
split_method=splitter.round_robin,
split_method=RoundRobin,
sync_mode=True):
"""
Transpile the program to distributed data-parallelism programs.
......@@ -196,7 +212,7 @@ class DistributeTranspiler:
:param sync_mode: if sync_mode is set True, it means that dist transpiler
will transpile the program into sync_mode pserver and trainer program.
"""
assert (callable(split_method))
assert (split_method.__bases__[0] == PSDispatcher)
if program is None:
program = default_main_program()
self.origin_program = program
......@@ -209,6 +225,7 @@ class DistributeTranspiler:
pserver_endpoints = pservers.split(",")
self.pserver_endpoints = pserver_endpoints
self.optimize_ops, params_grads = self._get_optimize_pass()
ps_dispatcher = split_method(pserver_endpoints)
# process lookup_table_op
# 1. check all lookup_table_op is distributed
......@@ -268,54 +285,110 @@ class DistributeTranspiler:
grad_blocks = split_dense_variable(grad_list, len(pserver_endpoints))
param_blocks = split_dense_variable(param_list, len(pserver_endpoints))
assert (len(grad_blocks) == len(param_blocks))
# step2: Create new vars for the parameters and gradients blocks and
# add ops to do the split.
grad_var_mapping = self._append_split_op(program, grad_blocks)
param_var_mapping = self._create_vars_from_blocklist(program,
param_blocks)
grad_var_mapping = self._create_vars_from_blocklist(
program, grad_blocks, add_trainer_suffix=self.trainer_num > 1)
grad_param_mapping = dict()
for g, p in zip(grad_blocks, param_blocks):
g_name, g_bid, _ = g.split(":")
p_name, p_bid, _ = p.split(":")
grad_param_mapping[grad_var_mapping[g_name][int(g_bid)]] = \
param_var_mapping[p_name][int(p_bid)]
# step 3: transpile trainer side program, insert recv op and send op.
# step3: Add gradients as send op inputs and parameters as send
# op outputs.
send_inputs = []
send_outputs = []
for b in grad_blocks: # append by order
varname, block_id, _ = b.split(":")
send_inputs.append(grad_var_mapping[varname][int(block_id)])
for b in param_blocks:
varname, block_id, _ = b.split(":")
send_outputs.append(param_var_mapping[varname][int(block_id)])
# let send_op know which endpoint to send which var to, eplist has the same
# order as send_inputs.
eplist = split_method(send_inputs, pserver_endpoints)
# create mapping of endpoint -> split var to create pserver side program
self.param_grad_ep_mapping = dict()
for i, ep in enumerate(eplist):
param = send_outputs[i]
grad = send_inputs[i]
if not self.param_grad_ep_mapping.has_key(ep):
self.param_grad_ep_mapping[ep] = {"params": [], "grads": []}
self.param_grad_ep_mapping[ep]["params"].append(param)
self.param_grad_ep_mapping[ep]["grads"].append(grad)
rpc_client_var = program.global_block().create_var(
name=RPC_CLIENT_VAR_NAME,
persistable=True,
type=core.VarDesc.VarType.RAW)
[
self.param_grad_ep_mapping.update({
ep: {
"params": [],
"grads": []
}
}) for ep in self.pserver_endpoints
]
# step 3.1: insert send op to send gradient vars to parameter servers
ps_dispatcher.reset()
send_vars = []
for orig_varname, splited_vars in grad_var_mapping.items():
eplist = ps_dispatcher.dispatch(splited_vars)
if len(splited_vars) == 1:
orig_varname = splited_vars[0].name
index = find_op_by_output_arg(program.global_block(),
orig_varname)
elif len(splited_vars) > 1:
orig_var = program.global_block().vars[orig_varname]
index = find_op_by_output_arg(program.global_block(),
orig_varname)
self._insert_split_op(program, orig_var, index, splited_vars)
index += 1
else:
AssertionError("Can not insert the send op by original "
"variable name :", orig_varname)
program.global_block().insert_op(
index=index + 1,
type="send_vars",
inputs={"X": splited_vars},
outputs={},
attrs={
"epmap": eplist,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
for _, var in enumerate(splited_vars):
send_vars.append(var)
# create send_op
if self.sync_mode:
program.global_block().append_op(
type="send",
inputs={"X": send_inputs},
outputs={"Out": send_outputs,
"RPCClient": rpc_client_var},
type="send_barrier",
inputs={},
outputs={},
attrs={
"endpoints": pserver_endpoints,
"epmap": eplist,
"sync_mode": self.sync_mode
"sync_mode": self.sync_mode,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
# step 3.2: insert recv op to receive parameters from parameter server
recv_vars = []
for _, var in enumerate(send_vars):
recv_vars.append(grad_param_mapping[var])
ps_dispatcher.reset()
eplist = ps_dispatcher.dispatch(recv_vars)
for i, ep in enumerate(eplist):
self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
# step4: Concat the parameters splits together after recv.
for varname, splited_var in param_var_mapping.iteritems():
eps = []
for var in splited_var:
index = [v.name for v in recv_vars].index(var.name)
eps.append(eplist[index])
program.global_block().append_op(
type="recv",
inputs={},
outputs={"Out": splited_var},
attrs={
"epmap": eps,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
program.global_block().append_op(
type="fetch_barrier",
inputs={},
outputs={},
attrs={
"endpoints": pserver_endpoints,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
for varname, splited_var in param_var_mapping.iteritems():
if len(splited_var) <= 1:
continue
......@@ -327,10 +400,8 @@ class DistributeTranspiler:
attrs={"axis": 0})
if self.has_distributed_lookup_table:
self._replace_lookup_table_op_with_prefetch(program, rpc_client_var,
eplist)
self._split_table_grad_and_add_send_vars(program, rpc_client_var,
pserver_endpoints)
self._replace_lookup_table_op_with_prefetch(program, eplist)
self._split_table_grad_and_add_send_vars(program, pserver_endpoints)
def get_trainer_program(self):
# remove optimize ops and add a send op to main_program
......@@ -550,8 +621,7 @@ class DistributeTranspiler:
return s_prog
# transpiler function for dis lookup_table
def _replace_lookup_table_op_with_prefetch(self, program, rpc_client_var,
eplist):
def _replace_lookup_table_op_with_prefetch(self, program, eplist):
# 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op
self.prefetch_input_vars = None
self.prefetch_output_vars = None
......@@ -598,11 +668,11 @@ class DistributeTranspiler:
index=op_index + 1,
type="prefetch",
inputs={'X': self.prefetch_input_vars},
outputs={
"Out": self.prefetch_output_vars,
"RPCClient": rpc_client_var
},
attrs={"epmap": eplist})
outputs={"Out": self.prefetch_output_vars},
attrs={
"epmap": eplist,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
# insert concat_op
program.global_block().insert_op(
......@@ -622,8 +692,7 @@ class DistributeTranspiler:
# break for loop
break
def _split_table_grad_and_add_send_vars(self, program, rpc_client_var,
pserver_endpoints):
def _split_table_grad_and_add_send_vars(self, program, pserver_endpoints):
# 2. add split_ids_op and send_vars_op to send gradient to pservers
# there should only be one table_name
all_ops = program.global_block().ops
......@@ -643,9 +712,12 @@ class DistributeTranspiler:
index=op_index + 2,
type="send_vars",
inputs={'X': self.table_grad_list},
outputs={"RPCClient": rpc_client_var},
attrs={"sync_send": True,
"epmap": pserver_endpoints})
outputs={},
attrs={
"sync_send": True,
"epmap": pserver_endpoints,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
break
def _create_prefetch_block(self, pserver_index, pserver_program,
......@@ -838,32 +910,13 @@ class DistributeTranspiler:
lod_level=var.lod_level,
persistable=persistable)
def _append_split_op(self, program, gradblocks):
"""
Split variables that need to be split and append respective ops
Args:
program (ProgramDesc): ProgramDesc that gradients blong.
gradblocks (list[(varname, block_id, block_size)]): List of gradient blocks.
Returns:
var_mapping (dict(varname->[new_splitted_variable])):A dict mapping
from original var name to each var split.
"""
add_suffix = False
if self.trainer_num > 1:
add_suffix = True
var_mapping = self._create_vars_from_blocklist(
program, gradblocks, add_trainer_suffix=add_suffix)
for varname, splited_vars in var_mapping.iteritems():
# variable that don't need to split have empty splited_vars
if len(splited_vars) <= 1:
continue
orig_var = program.global_block().vars[varname]
def _insert_split_op(self, program, orig_var, index, splited_vars):
if orig_var.type == core.VarDesc.VarType.SELECTED_ROWS:
height_sections = []
for v in splited_vars:
height_sections.append(v.shape[0])
program.global_block().append_op(
program.global_block().insert_op(
index=index + 1,
type="split_selected_rows",
inputs={"X": orig_var},
outputs={"Out": splited_vars},
......@@ -872,7 +925,8 @@ class DistributeTranspiler:
sections = []
for v in splited_vars:
sections.append(v.shape[0])
program.global_block().append_op(
program.global_block().insert_op(
index=index + 1,
type="split_byref",
inputs={"X": orig_var},
outputs={"Out": splited_vars},
......@@ -881,7 +935,6 @@ class DistributeTranspiler:
else:
AssertionError("Variable type should be in set "
"[LOD_TENSOR, SELECTED_ROWS]")
return var_mapping
def _get_optimizer_input_shape(self, op_type, varkey, orig_shape,
param_shape):
......
......@@ -13,45 +13,66 @@
# limitations under the License.
def hash_name(varlist, pserver_endpoints):
class PSDispatcher(object):
"""
hash variable names to several endpoints.
PSDispatcher is the base class for dispatching vars
into different pserver instance.
You need to implement the `dispatch` inferface.
"""
def __init__(self, pserver_endpoints):
self._eps = pserver_endpoints
self._step = 0
@property
def eps(self):
return self._eps
def reset(self):
self._step = 0
def dispatch(self, varlist):
"""
:param varlist: a list of Variables
:return: a map of pserver endpoint -> varname
"""
AssertionError("Interface has not been implemented.")
Args:
varlist(list): a list of Variables
Returns(dict): a map of pserver endpoint -> varname
class HashName(PSDispatcher):
"""
Hash variable names to several endpoints
"""
def _hash_block(block_str, total):
def __init__(self, pserver_endpoints):
super(self.__class__, self).__init__(pserver_endpoints)
def _hash_block(self, block_str, total):
return hash(block_str) % total
def dispatch(self, varlist):
eplist = []
for var in varlist:
server_id = _hash_block(var.name(), len(pserver_endpoints))
server_for_param = pserver_endpoints[server_id]
server_id = self._hash_block(var.name(), len(self._eps))
server_for_param = self._eps[server_id]
eplist.append(server_for_param)
return eplist
def round_robin(varlist, pserver_endpoints):
class RoundRobin(PSDispatcher):
"""
Distribute variables to several endpoints.
Args:
varlist(list): a list of variables
pserver_endpoints(list): a list of pserver endpoints
Returns(list[int]): the endpoint for each variable
Distribute variables to serveral endpoints.
"""
assert (len(varlist) >= len(pserver_endpoints))
def __init__(self, pserver_endpoints):
super(self.__class__, self).__init__(pserver_endpoints)
def dispatch(self, varlist):
eplist = []
pserver_idx = 0
for var in varlist:
server_for_param = pserver_endpoints[pserver_idx]
server_for_param = self._eps[self._step]
eplist.append(server_for_param)
pserver_idx += 1
if pserver_idx >= len(pserver_endpoints):
pserver_idx = 0
self._step += 1
if self._step >= len(self._eps):
self._step = 0
return eplist
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册