未验证 提交 fbd5f689 编写于 作者: H helinwang 提交者: GitHub

Merge pull request #7980 from typhoonzero/grpc_perf_conn_once

Performance enhancement by reuse connection
...@@ -42,28 +42,32 @@ class SendOp : public framework::OperatorBase { ...@@ -42,28 +42,32 @@ class SendOp : public framework::OperatorBase {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place); auto& ctx = *pool.Get(place);
auto client_var_name = Output("RPCClient");
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
"Can not find variable '%s' in the scope.",
client_var_name);
auto* client_var = scope.FindVar(client_var_name);
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i]; VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
client_.AsyncSendVariable(epmap[i], ctx, scope, ins[i]); rpc_client->AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
} }
PADDLE_ENFORCE(client_.Wait()); PADDLE_ENFORCE(rpc_client->Wait());
for (auto& ep : endpoints) { for (auto& ep : endpoints) {
VLOG(3) << "batch barrier, ep: " << ep; VLOG(3) << "batch barrier, ep: " << ep;
client_.AsyncSendBatchBarrier(ep); rpc_client->AsyncSendBatchBarrier(ep);
} }
PADDLE_ENFORCE(client_.Wait()); PADDLE_ENFORCE(rpc_client->Wait());
for (size_t i = 0; i < outs.size(); i++) { for (size_t i = 0; i < outs.size(); i++) {
VLOG(3) << "getting " << outs[i] << " from " << epmap[i]; VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]); rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
} }
PADDLE_ENFORCE(rpc_client->Wait());
PADDLE_ENFORCE(client_.Wait());
} }
private:
mutable detail::RPCClient client_;
}; };
class SendOpMaker : public framework::OpProtoAndCheckerMaker { class SendOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -73,6 +77,9 @@ class SendOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -73,6 +77,9 @@ class SendOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", "(Tensor) Input tensor to be sent").AsDuplicable(); AddInput("X", "(Tensor) Input tensor to be sent").AsDuplicable();
AddOutput("Out", "(Tensor) Output tensor to be received from server") AddOutput("Out", "(Tensor) Output tensor to be received from server")
.AsDuplicable(); .AsDuplicable();
AddOutput("RPCClient",
"(RPCClient) The RPC client object which is"
"initialized at most once.");
AddComment(R"DOC( AddComment(R"DOC(
Send operator Send operator
......
...@@ -153,11 +153,18 @@ class DistributeTranspiler: ...@@ -153,11 +153,18 @@ class DistributeTranspiler:
self.param_grad_ep_mapping[ep]["params"].append(param) self.param_grad_ep_mapping[ep]["params"].append(param)
self.param_grad_ep_mapping[ep]["grads"].append(grad) self.param_grad_ep_mapping[ep]["grads"].append(grad)
rpc_client_var = program.global_block().create_var(
name="RPC_CLIENT_VAR",
psersistable=True,
dtype='float32', # dtype and shape is not used in fact
shape=[0])
# create send_op # create send_op
send_op = program.global_block().append_op( send_op = program.global_block().append_op(
type="send", type="send",
inputs={"X": send_inputs}, inputs={"X": send_inputs},
outputs={"Out": send_outputs}, outputs={"Out": send_outputs,
"RPCClient": rpc_client_var},
attrs={"endpoints": pserver_endpoints, attrs={"endpoints": pserver_endpoints,
"epmap": eplist}) "epmap": eplist})
# step4 # step4
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册