Commit 7be79231 authored by typhoonzero

wip multi-trainer

Parent 1e549563
......@@ -66,6 +66,12 @@ bool RPCClient::GetVariable(const framework::Scope& scope,
return true;
}
void RPCClient::Wait() {
ClientContext context;
VoidMessage call_msg, ret_msg;
stub_->Wait(&context, call_msg, &ret_msg);
}
} // namespace detail
} // namespace operators
} // namespace paddle
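For orientation: the new Wait() is a blocking RPC on the trainer side, and SendOp::Run() (further down in this diff) calls it between pushing gradients and pulling variables back; on the pserver side, RecvOp::Run() gains an rpc_service_->Done() call after the optimizer sub-program. The server half of Wait is not shown in this commit, so the following is only a toy sketch of the barrier these two calls presumably form; threading.Event and every name below are illustrative, not PaddlePaddle API:

```python
import threading

# Toy model of the presumed Wait()/Done() handshake.
batch_done = threading.Event()

def pserver_done():
    # Assumed counterpart of rpc_service_->Done() in RecvOp::Run(): signal that
    # the gradients were aggregated and the optimizer sub-program has run.
    batch_done.set()

def trainer_wait():
    # Assumed counterpart of RPCClient::Wait(): block until the server reports
    # the batch as finished, after which fetching variables is safe.
    batch_done.wait()

# A worker thread plays the pserver and releases the waiting trainer.
threading.Timer(0.1, pserver_done).start()
trainer_wait()
print("batch finished on the pserver side")
```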
......@@ -81,6 +81,7 @@ class RPCClient {
bool SendVariable(const framework::Scope &scope, const std::string &inname);
bool GetVariable(const framework::Scope &scope, const std::string &outname);
void Wait();
private:
std::unique_ptr<SendRecvService::Stub> stub_;
......
......@@ -76,14 +76,14 @@ class RecvOp : public framework::OperatorBase {
const platform::DeviceContext &dev_ctx) const override {
// FIXME(typhoonzero): don't create a new scope for every run.
framework::Scope &recv_scope = scope.NewScope();
rpc_service_.SetScope(&recv_scope);
rpc_service_->SetScope(&recv_scope);
auto param_list = Attr<std::vector<std::string>>("ParamList");
auto grad_list = Attr<std::vector<std::string>>("GradList");
auto trainer_count = Attr<int>("Trainers");
size_t param_count = param_list.size();
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
while (true) {
rpc_service_.Start();
rpc_service_->Start();
// Get gradients from multiple trainers; we don't care about the order in
// which they arrive, just add suffix 0~n and then average the gradients.
for (size_t i = 0; i < param_count * trainer_count; ++i) {
......@@ -126,6 +126,7 @@ class RecvOp : public framework::OperatorBase {
} catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}
rpc_service_->Done();
// for (size_t i = 0; i < param_count; ++i) {
// auto *out_var = recv_scope.FindVar(param_list[i]);
......
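The comment in the receive loop above describes the aggregation scheme: each arriving gradient is stored under its name plus a per-trainer suffix 0~n, and the copies are averaged before the optimizer sub-program runs. The actual suffix format and averaging kernel are not part of this hunk, so the snippet below is only a small numerical illustration of the idea; the scope dict and the ".trainer_i" naming are assumptions:

```python
def average_suffixed_grads(recv_scope, grad_name, trainer_count):
    """Average grad copies stored as '<grad_name>.trainer_<i>' (assumed naming)."""
    copies = [recv_scope["%s.trainer_%d" % (grad_name, i)]
              for i in range(trainer_count)]
    n = float(trainer_count)
    # element-wise mean over the per-trainer copies
    return [sum(vals) / n for vals in zip(*copies)]

# Example: gradients for 'w@GRAD' received from 3 trainers.
scope = {
    "w@GRAD.trainer_0": [0.25, 0.5],
    "w@GRAD.trainer_1": [0.5, 1.0],
    "w@GRAD.trainer_2": [0.75, 1.5],
}
print(average_suffixed_grads(scope, "w@GRAD", 3))  # -> [0.5, 1.0]
```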
......@@ -34,34 +34,36 @@ class SendOp : public framework::OperatorBase {
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {
// init client when the operator is created at runtime.
if (!client_) {
std::string endpoint = Attr<std::string>("endpoint");
client_.reset(new detail::RPCClient(
grpc::CreateChannel(endpoint, grpc::InsecureChannelCredentials())));
// TODO(typhoonzero): how to call InitVariables
std::vector<std::string> endpoints =
Attr<std::vector<std::string>>("endpoints");
for (auto ep : endpoints) {
client_map_[ep].reset(new detail::RPCClient(
grpc::CreateChannel(ep, grpc::InsecureChannelCredentials())));
}
}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
auto ins = Inputs("X");
// TODO(typhoonzero): currently it's non-blocking,
// should block until server responds.
for (auto in : ins) {
bool ret = client_->SendVariable(scope, in);
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
// TODO(typhoonzero): use async calls to send multiple variables asynchronously.
for (size_t i = 0; i < ins.size(); ++i) {
bool ret = client_map_[epmap[i]]->SendVariable(scope, ins[i]);
if (!ret) {
LOG(ERROR) << "send variable error";
LOG(ERROR) << "send variable error: " << ins[i];
}
}
for (auto in : ins) {
bool ret = client_->GetVariable(scope, in);
for (auto &it : client_map_) {  // TODO(typhoonzero): support async optimization
  it.second->Wait();
}
for (size_t i = 0; i < ins.size(); ++i) {
bool ret = client_map_[epmap[i]]->GetVariable(scope, ins[i]);
if (!ret) {
LOG(ERROR) << "GetVariable error";
LOG(ERROR) << "GetVariable error: " << ins[i];
}
}
}
protected:
std::shared_ptr<detail::RPCClient> client_{nullptr};
mutable std::unordered_map<std::string, std::shared_ptr<detail::RPCClient>>
client_map_;
};
class SendOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -74,11 +76,13 @@ Recv operator
This operator will recv tensor from send_op
)DOC");
AddAttr<std::string>("endpoint",
"(string, default 127.0.0.1:6164)"
"IP address to listen on.")
.SetDefault("127.0.0.1:6164")
.AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
AddAttr<std::vector<std::string>>("endpoints",
"(string vector, default 127.0.0.1:6164)"
"Server endpoints to send variables to.");
AddAttr<std::vector<std::string>>("epmap",
"(string vector, default 127.0.0.1:6164)"
"Server endpoints in the order of input "
"variables for mapping");
}
};
......
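To make the new attributes concrete: client_map_ holds one RPCClient per distinct endpoint listed in "endpoints", and "epmap" is parallel to the inputs, so ins[i] is routed through the client for epmap[i]. A rough Python sketch of that dispatch, with a fake client standing in for detail::RPCClient; only SendVariable/GetVariable/Wait correspond to calls in the diff, everything else is made up:

```python
class FakeClient(object):
    """Stand-in for detail::RPCClient; only records what it was asked to send."""
    def __init__(self, endpoint):
        self.endpoint = endpoint
        self.sent = []

    def send_variable(self, scope, name):   # mirrors SendVariable(scope, name)
        self.sent.append(name)
        return True

    def wait(self):                         # mirrors the new blocking Wait()
        pass

    def get_variable(self, scope, name):    # mirrors GetVariable(scope, name)
        return True

endpoints = ["127.0.0.1:6164", "127.0.0.1:6165"]
# one client per distinct endpoint, like client_map_ in SendOp's constructor
client_map = {ep: FakeClient(ep) for ep in endpoints}

ins = ["w1@GRAD", "w2@GRAD", "w3@GRAD"]
epmap = ["127.0.0.1:6164", "127.0.0.1:6165", "127.0.0.1:6164"]  # parallel to ins
scope = {}

for i in range(len(ins)):                  # send loop, as in SendOp::Run()
    client_map[epmap[i]].send_variable(scope, ins[i])
for client in client_map.values():         # wait for every pserver to finish
    client.wait()
for i in range(len(ins)):                  # then fetch the updated variables
    client_map[epmap[i]].get_variable(scope, ins[i])

for ep, client in sorted(client_map.items()):
    print("%s got %s" % (ep, client.sent))
# 127.0.0.1:6164 got ['w1@GRAD', 'w3@GRAD']
# 127.0.0.1:6165 got ['w2@GRAD']
```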
......@@ -145,14 +145,20 @@ class DistributeTranspiler:
pserver_endpoints = kwargs["pservers"].split(",")
self.param_grad_map = split_method(params_and_grads, pserver_endpoints)
for ep in pserver_endpoints:
# FIXME(typhoonzero): sends to different servers can run in parallel.
send_op = program.global_block().append_op(
type="send",
inputs={"X": self.param_grad_map[ep]["grads"]
}, # inputs is a list of tensors to be sent
outputs={},
attrs={"endpoint": ep})
send_op_ordered_inputs = []
epmap = []
for ep, v in self.param_grad_map.iteritems():
send_op_ordered_inputs.extend(v["grads"])
for i in v["grads"]:
epmap.append(ep)
send_op = program.global_block().append_op(
type="send",
inputs={"X": send_op_ordered_inputs
}, # inputs is a list of tensors to be sent
outputs={},
attrs={"endpoints": pserver_endpoints,
"epmap": epmap})
def _create_var_for_trainers(self, block, var, trainers):
var_list = []
......
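Since the send op now takes a flat "X" list plus a parallel "epmap", the transpiler has to build both in lock-step: every gradient appended to send_op_ordered_inputs gets the owning pserver's endpoint appended to epmap at the same index, which is what the loop above does. A toy run of that pairing; the param_grad_map layout is inferred from the code above, and the endpoints and variable names are made up:

```python
# Assumed shape of self.param_grad_map after split_method():
# endpoint -> {"grads": [...], "params": [...]}
param_grad_map = {
    "127.0.0.1:6164": {"grads": ["w1@GRAD", "w3@GRAD"], "params": ["w1", "w3"]},
    "127.0.0.1:6165": {"grads": ["w2@GRAD"], "params": ["w2"]},
}

send_op_ordered_inputs = []
epmap = []
for ep, v in sorted(param_grad_map.items()):  # .iteritems() in the Python 2 code above
    send_op_ordered_inputs.extend(v["grads"])
    epmap.extend([ep] * len(v["grads"]))      # one endpoint entry per gradient

for name, ep in zip(send_op_ordered_inputs, epmap):
    print("%s -> %s" % (name, ep))
# w1@GRAD -> 127.0.0.1:6164
# w3@GRAD -> 127.0.0.1:6164
# w2@GRAD -> 127.0.0.1:6165
```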