未验证 提交 627d7a64 编写于 作者: G gongweibao 提交者: GitHub

Clean `sendop` `recv` operator. (#11309)

上级 fa29ef0b
......@@ -89,7 +89,7 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
for (auto *op : program.Block(0).AllOps()) {
// TODO(Yancey1989): use a graceful method to find send op,
// instead of the the hard code string
if (op->Type() == "send_vars") {
if (op->Type() == "send") {
auto op_vars = op->InputArgumentNames();
send_vars.reserve(send_vars.size() +
std::distance(op_vars.begin(), op_vars.end()));
......@@ -468,17 +468,17 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
new RPCOpHandle(op, local_scopes_[0], op.Type(), places_[0]));
if (op.Type() == "send_barrier") {
ConnectOp(result, result->ops_.back().get(), "send_vars");
ConnectOp(result, result->ops_.back().get(), "send");
} else if (op.Type() == "recv") {
ConnectOp(result, result->ops_.back().get(), "send_barrier");
} else if (op.Type() == "fetch_barrier") {
ConnectOp(result, result->ops_.back().get(), "recv");
} else if (op.Type() == "send_vars") {
} else if (op.Type() == "send") {
// do nothing
} else {
PADDLE_THROW(
"rpc op should be in ["
"send_vars, send_barrier. recv, fetch_barrier]");
"send, send_barrier. recv, fetch_barrier]");
}
// TODO(Yancey1989): schedule rpc op on different place may
......
......@@ -189,16 +189,14 @@ if(WITH_DISTRIBUTE)
set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
op_library(send_op DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(send_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
op_library(prefetch_op DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(prefetch_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
op_library(recv_op DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(recv_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
op_library(listen_and_serv_op DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(listen_and_serv_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
op_library(send_vars_op DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(send_vars_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
op_library(send_op DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(send_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
op_library(send_barrier_op DEPS ${DISTRIBUTE_DEPS})
op_library(fetch_barrier_op DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(send_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
......@@ -208,15 +206,14 @@ if(WITH_DISTRIBUTE)
# listen_and_serv_op sum_op executor SERIAL)
if(WITH_GPU)
set_source_files_properties(test_send_nccl_id.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(test_send_nccl_id SRCS test_send_nccl_id.cc DEPS send_op
listen_and_serv_op executor SERIAL)
cc_test(test_send_nccl_id SRCS test_send_nccl_id.cc DEPS listen_and_serv_op executor SERIAL)
op_library(gen_nccl_id_op DEPS nccl_common sendrecvop_grpc)
set_source_files_properties(gen_nccl_id_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else()
set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op)
endif()
else()
set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op fetch_barrier_op gen_nccl_id_op)
set(DEPS_OPS ${DEPS_OPS} prefetch_op recv_op listen_and_serv_op send_op send_barrier_op fetch_barrier_op gen_nccl_id_op)
endif()
op_library(cross_entropy_op DEPS cross_entropy)
......
......@@ -78,9 +78,15 @@ This operator can get variables from server side.
}
};
class RecvOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(recv, ops::RecvOp, ops::RecvOpMaker);
REGISTER_OPERATOR(recv, ops::RecvOp, paddle::framework::EmptyGradOpMaker,
ops::RecvOpMaker, ops::RecvOpShapeInference);
......@@ -16,7 +16,6 @@ limitations under the License. */
#include <ostream>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
......@@ -36,12 +35,9 @@ class SendOp : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
auto ins = Inputs("X");
auto outs = Outputs("Out");
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
std::vector<std::string> endpoints =
Attr<std::vector<std::string>>("endpoints");
bool sync_mode = Attr<bool>("sync_mode");
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
int sync_send = Attr<int>("sync_mode");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
......@@ -55,32 +51,14 @@ class SendOp : public framework::OperatorBase {
for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
// TODO(Yancey1989): we need to use an IO threadpool which has
// a larger number of threads than the computing threadpool.
rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i]);
} else {
VLOG(3) << "don't send no-initialied variable: " << ins[i];
}
}
rpc_client->Wait();
if (sync_mode) {
for (auto& ep : endpoints) {
VLOG(3) << "batch barrier, ep: " << ep;
rpc_client->AsyncSendBatchBarrier(ep);
}
rpc_client->Wait();
}
if (outs.size() > 0) {
for (size_t i = 0; i < outs.size(); i++) {
VLOG(2) << "getting " << outs[i] << " from " << epmap[i];
rpc_client->AsyncGetVar(epmap[i], ctx, scope, outs[i]);
}
rpc_client->Wait();
// tell pservers that current trainer have called fetch
for (auto& ep : endpoints) {
VLOG(2) << "send fetch barrier, ep: " << ep;
rpc_client->AsyncSendFetchBarrier(ep);
}
if (sync_send) {
rpc_client->Wait();
}
}
......@@ -89,26 +67,22 @@ class SendOp : public framework::OperatorBase {
class SendOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor) Input tensor to be sent").AsDuplicable();
AddOutput("Out", "(Tensor) Output tensor to be received from server")
AddInput("X", "(Tensor, SelectedRows) Input variables to be sent")
.AsDuplicable();
AddComment(R"DOC(
Send operator
This operator will send tensor to recv_op at the parameter server.
This operator will send variables to listen_and_serve op at the parameter server.
)DOC");
// TODO(typhoonzero): remove this attr generate de-duplicated vector from
// epmap when initializing.
AddAttr<std::vector<std::string>>("endpoints",
"(string vector, default 127.0.0.1:6164)"
"Server endpoints to send variables to.")
.SetDefault({});
AddAttr<int>("sync_mode",
"(int, default 0)"
"sync send or async send.")
.SetDefault(0);
AddAttr<std::vector<std::string>>("epmap",
"(string vector, default 127.0.0.1:6164)"
"Server endpoints in the order of input "
"variables for mapping")
.SetDefault({});
AddAttr<bool>("sync_mode", "work in sync_mode or not").SetDefault(true);
.SetDefault({"127.0.0.1:6164"});
}
};
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <future> // NOLINT
#include <ostream>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/send_recv_util.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
class SendVarsOp : public framework::OperatorBase {
public:
SendVarsOp(const std::string& type, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
auto ins = Inputs("X");
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
int sync_send = Attr<int>("sync_send");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
// For profiling
platform::RecordEvent record_event(Type(), &ctx);
detail::RPCClient* rpc_client =
detail::RPCClient::GetInstance<detail::GRPCClient>();
for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
// TODO(Yancey1989): we need to use an IO threadpool which has
// a larger number of threads than the computing threadpool.
rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i]);
} else {
VLOG(3) << "don't send no-initialied variable: " << ins[i];
}
}
if (sync_send) {
rpc_client->Wait();
}
}
};
class SendVarsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor, SelectedRows) Input variables to be sent")
.AsDuplicable();
AddComment(R"DOC(
Send operator
This operator will send variables to listen_and_serve op at the parameter server.
)DOC");
AddAttr<int>("sync_send",
"(int, default 0)"
"sync send or async send.")
.SetDefault(0);
AddAttr<std::vector<std::string>>("epmap",
"(string vector, default 127.0.0.1:6164)"
"Server endpoints in the order of input "
"variables for mapping")
.SetDefault({"127.0.0.1:6164"});
}
};
class SendVarsOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(send_vars, ops::SendVarsOp,
paddle::framework::EmptyGradOpMaker, ops::SendVarsOpMaker,
ops::SendVarsOpShapeInference);
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.fluid as fluid
from paddle.fluid.transpiler.distribute_transpiler import delete_ops
......@@ -54,10 +55,10 @@ class TestDistTranspiler(TranspilerTest):
delete_ops(trainer.global_block(), optimize_ops)
ops = [op.type for op in trainer.global_block().ops] + [
"split_byref", "send_vars", "send_barrier", "recv", "recv",
"split_byref", "send", "send_barrier", "recv", "recv",
"fetch_barrier", "concat"
]
ops.insert(ops.index("elementwise_add_grad") + 1, "send_vars")
ops.insert(ops.index("elementwise_add_grad") + 1, "send")
return ops
......
......@@ -59,9 +59,9 @@ class TestSimpleDistTranspiler(TranspilerTest):
delete_ops(trainer.global_block(), optimize_ops)
ops = [op.type for op in trainer.global_block().ops] + [
"send_vars", "send_barrier", "recv", "recv", "fetch_barrier"
"send", "send_barrier", "recv", "recv", "fetch_barrier"
]
ops.insert(ops.index("elementwise_add_grad") + 1, "send_vars")
ops.insert(ops.index("elementwise_add_grad") + 1, "send")
return ops
def _transpiler_instance(self):
......
......@@ -24,9 +24,9 @@ Steps to transpile trainer:
1. split variable to multiple blocks, aligned by product(dim[1:]) (width).
2. rename splited grad variables to add trainer_id suffix ".trainer_%d".
3. modify trainer program add split_op to each grad variable.
4. append send_op to send splited variables to server and fetch
params(splited blocks or origin param) from server.
5. append concat_op to merge splited blocks to update local weights.
4. append send_op to send splited variables to server and
5. add recv_op to fetch params(splited blocks or origin param) from server.
6. append concat_op to merge splited blocks to update local weights.
Steps to transpile pserver:
1. create new program for parameter server.
......@@ -317,7 +317,7 @@ class DistributeTranspiler:
program.global_block().insert_op(
index=index + 1,
type="send_vars",
type="send",
inputs={"X": splited_vars},
outputs={},
attrs={
......@@ -678,7 +678,7 @@ class DistributeTranspiler:
break
def _split_table_grad_and_add_send_vars(self, program, pserver_endpoints):
# 2. add split_ids_op and send_vars_op to send gradient to pservers
# 2. add split_ids_op and send_op to send gradient to pservers
# there should only be one table_name
all_ops = program.global_block().ops
table_grad_name = grad_var_name(self.table_name)
......@@ -695,11 +695,11 @@ class DistributeTranspiler:
outputs={"Out": self.trainer_side_table_grad_list})
program.global_block().insert_op(
index=op_index + 2,
type="send_vars",
type="send",
inputs={'X': self.trainer_side_table_grad_list},
outputs={},
attrs={
"sync_send": True,
"sync_mode": True,
"epmap": pserver_endpoints,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册