Commit 7be79231, authored by typhoonzero

wip multi-trainer

Parent 1e549563
@@ -66,6 +66,12 @@ bool RPCClient::GetVariable(const framework::Scope& scope,
   return true;
 }
+
+void RPCClient::Wait() {
+  ClientContext context;
+  VoidMessage call_msg, ret_msg;
+  stub_->Wait(&context, call_msg, &ret_msg);
+}
 
 }  // namespace detail
 }  // namespace operators
 }  // namespace paddle
@@ -81,6 +81,7 @@ class RPCClient {
   bool SendVariable(const framework::Scope &scope, const std::string &inname);
   bool GetVariable(const framework::Scope &scope, const std::string &outname);
+  void Wait();
 
  private:
   std::unique_ptr<SendRecvService::Stub> stub_;
......
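Note (editorial, not part of the commit): the new blocking RPCClient::Wait() implies a matching barrier on the parameter-server side: trainers send their gradients, call Wait(), and are released only after RecvOp has run the optimize sub-program and called rpc_service_->Done(). The server-side change is not shown in this diff; below is a minimal sketch of such a barrier built from a mutex and a condition variable. The class and member names are illustrative assumptions, not the actual Paddle implementation.

```cpp
#include <condition_variable>
#include <mutex>

// Illustrative barrier only; names and structure are assumptions.
class RecvBarrier {
 public:
  // Called at the top of each cluster-batch (cf. rpc_service_->Start()).
  void Start() {
    std::lock_guard<std::mutex> lock(mu_);
    done_ = false;
  }
  // Called after the optimize sub-program ran (cf. rpc_service_->Done()).
  void Done() {
    {
      std::lock_guard<std::mutex> lock(mu_);
      done_ = true;
    }
    cv_.notify_all();
  }
  // Body of a Wait RPC handler: block the calling trainer until Done() fires.
  void Wait() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return done_; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  bool done_ = false;
};
```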
@@ -76,14 +76,14 @@ class RecvOp : public framework::OperatorBase {
            const platform::DeviceContext &dev_ctx) const override {
     // FIXME(typhoonzero): no new scopes for every run.
     framework::Scope &recv_scope = scope.NewScope();
-    rpc_service_.SetScope(&recv_scope);
+    rpc_service_->SetScope(&recv_scope);
     auto param_list = Attr<std::vector<std::string>>("ParamList");
     auto grad_list = Attr<std::vector<std::string>>("GradList");
     auto trainer_count = Attr<int>("Trainers");
     size_t param_count = param_list.size();
     // TODO(typhoonzero): change this to a while_op for every cluster-batch.
     while (true) {
-      rpc_service_.Start();
+      rpc_service_->Start();
       // Get from multiple trainers, we don't care about order in which
       // the gradient arrives, just add suffix 0~n then average the gradient.
       for (size_t i = 0; i < param_count * trainer_count; ++i) {
@@ -126,6 +126,7 @@ class RecvOp : public framework::OperatorBase {
       } catch (std::exception &e) {
         LOG(ERROR) << "run sub program error " << e.what();
       }
+      rpc_service_->Done();
       // for (size_t i = 0; i < param_count; ++i) {
       //   auto *out_var = recv_scope.FindVar(param_list[i]);
......
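Note (editorial, not part of the commit): the loop above collects param_count * trainer_count gradients, and the comment says they arrive under per-trainer suffixes ("suffix 0~n") and are then averaged; the averaging itself is still commented out in this WIP commit. A rough sketch of the elementwise average over plain float buffers, assuming every trainer sends a gradient of the same size (in the real op these would be tensors looked up in recv_scope):

```cpp
#include <cassert>
#include <vector>

// Illustration only: elementwise average of one parameter's gradient,
// where grads[t] is the buffer received from trainer t.
std::vector<float> AverageGradients(const std::vector<std::vector<float>> &grads) {
  assert(!grads.empty());
  std::vector<float> merged(grads[0].size(), 0.0f);
  for (const auto &g : grads) {
    assert(g.size() == merged.size());
    for (size_t k = 0; k < g.size(); ++k) merged[k] += g[k];
  }
  for (auto &v : merged) v /= static_cast<float>(grads.size());
  return merged;
}
```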
@@ -34,34 +34,36 @@ class SendOp : public framework::OperatorBase {
            const framework::AttributeMap &attrs)
       : OperatorBase(type, inputs, outputs, attrs) {
     // init client when the operator is created at runtime.
-    if (!client_) {
-      std::string endpoint = Attr<std::string>("endpoint");
-      client_.reset(new detail::RPCClient(
-          grpc::CreateChannel(endpoint, grpc::InsecureChannelCredentials())));
-      // TODO(typhoonzero): how to call InitVariables
+    std::vector<std::string> endpoints =
+        Attr<std::vector<std::string>>("endpoints");
+    for (auto ep : endpoints) {
+      client_map_[ep].reset(new detail::RPCClient(
+          grpc::CreateChannel(ep, grpc::InsecureChannelCredentials())));
     }
   }
 
   void Run(const framework::Scope &scope,
            const platform::DeviceContext &dev_ctx) const override {
     auto ins = Inputs("X");
-    // TODO(typhoonzero): currently it's non-blocking,
-    // should block until server responds.
-    for (auto in : ins) {
-      bool ret = client_->SendVariable(scope, in);
+    std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
+    // TODO(typhoonzero): use async calls to send multiple variable asyncly.
+    for (size_t i = 0; i < ins.size(); ++i) {
+      bool ret = client_map_[epmap[i]]->SendVariable(scope, ins[i]);
       if (!ret) {
-        LOG(ERROR) << "send variable error";
+        LOG(ERROR) << "send variable error: " << ins[i];
       }
     }
-    for (auto in : ins) {
-      bool ret = client_->GetVariable(scope);
+    client_map_[0]->Wait();  // TODO(typhoonzero): support async optimization
+    for (size_t i = 0; i < ins.size(); ++i) {
+      bool ret = client_map_[epmap[i]]->GetVariable(scope, ins[i]);
       if (!ret) {
-        LOG(ERROR) << "GetVariable error";
+        LOG(ERROR) << "GetVariable error: " << ins[i];
       }
     }
   }
 
  protected:
-  std::shared_ptr<detail::RPCClient> client_{nullptr};
+  mutable std::unordered_map<std::string, std::shared_ptr<detail::RPCClient>>
+      client_map_;
 };
 
 class SendOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -74,11 +76,13 @@ Recv operator
 This operator will recv tensor from send_op
 )DOC");
-    AddAttr<std::string>("endpoint",
-                         "(string, default 127.0.0.1:6164)"
-                         "IP address to listen on.")
-        .SetDefault("127.0.0.1:6164")
-        .AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
+    AddAttr<std::vector<std::string>>("endpoints",
+                                      "(string vector, default 127.0.0.1:6164)"
+                                      "Server endpoints to send variables to.");
+    AddAttr<std::vector<std::string>>("epmap",
+                                      "(string vector, default 127.0.0.1:6164)"
+                                      "Server endpoints in the order of input "
+                                      "variables for mapping");
   }
 };
......
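Note (editorial, not part of the commit): "endpoints" lists each parameter server once, while "epmap" carries one entry per input variable, so SendOp routes ins[i] through client_map_[epmap[i]]. The transpiler change below builds the ordered input list and epmap in lockstep. A small illustration with hypothetical variable names and addresses:

```cpp
#include <string>
#include <vector>

int main() {
  // One channel per distinct parameter server.
  std::vector<std::string> endpoints = {"127.0.0.1:6164", "127.0.0.1:6165"};
  // Gradients to send, in the order the transpiler emitted them.
  std::vector<std::string> ins = {"w1@GRAD", "w2@GRAD", "b1@GRAD"};
  // epmap[i] names the server that owns ins[i]: here w1 and b1 live on
  // the first pserver and w2 on the second.
  std::vector<std::string> epmap = {"127.0.0.1:6164", "127.0.0.1:6165",
                                    "127.0.0.1:6164"};
  return 0;
}
```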
@@ -145,14 +145,20 @@ class DistributeTranspiler:
         pserver_endpoints = kwargs["pservers"].split(",")
         self.param_grad_map = split_method(params_and_grads, pserver_endpoints)
 
-        for ep in pserver_endpoints:
-            # FIXME(typhoonzero): send to different servers can run in parrallel.
-            send_op = program.global_block().append_op(
-                type="send",
-                inputs={"X": self.param_grad_map[ep]["grads"]
-                        },  # inputs is a list of tensors to be send
-                outputs={},
-                attrs={"endpoint": ep})
+        send_op_ordered_inputs = []
+        epmap = []
+        for ep, v in self.param_grad_map.iteritems():
+            send_op_ordered_inputs.extend(v["grads"])
+            for i in v:
+                epmap.append(ep)
+
+        send_op = program.global_block().append_op(
+            type="send",
+            inputs={"X": send_op_ordered_inputs
+                    },  # inputs is a list of tensors to be send
+            outputs={},
+            attrs={"endpoints": pserver_endpoints,
+                   "epmap": epmap})
 
     def _create_var_for_trainers(self, block, var, trainers):
         var_list = []
......