Unverified · Commit 4cade607 authored by 武毅, committed by GitHub

Merge pull request #6983 from typhoonzero/fix_sendrecv_ut

Fix sendrecv ut
@@ -89,6 +89,9 @@ class OperatorBase {
   /// Net will call this function to Run an op.
   virtual void Run(const Scope& scope, const platform::Place& place) const = 0;
 
+  // FIXME(typhoonzero): this is only used for recv_op to stop event_loop.
+  virtual void Stop() {}
+
   virtual bool IsNetOp() const { return false; }
   virtual bool SupportGPU() const { return false; }
...
@@ -62,6 +62,8 @@ class SendRecvServerImpl final : public SendRecvService::Service {
   const TensorWithName Get() { return this->var_recv_queue_.Pop(); }
 
+  void Push(const TensorWithName &msg) { this->var_recv_queue_.Push(msg); }
+
  private:
   // received variable from RPC, operators fetch variable from this queue.
   SimpleBlockQueue<TensorWithName> var_recv_queue_;
...
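Push is the producer-side counterpart of the blocking Pop that Get wraps; RecvOp::Stop (below) uses it to inject a message into the queue from outside the RPC path. As background, here is a minimal sketch of what a blocking queue of this shape can look like — an illustrative assumption, not the actual contents of paddle/operators/detail/simple_block_queue.h:

    #include <condition_variable>
    #include <deque>
    #include <mutex>

    template <typename T>
    class SimpleBlockQueue {
     public:
      void Push(const T &item) {
        {
          std::lock_guard<std::mutex> lock(mutex_);
          queue_.push_back(item);
        }
        cv_.notify_one();  // wake one consumer blocked in Pop()
      }

      T Pop() {
        std::unique_lock<std::mutex> lock(mutex_);
        // Block until at least one item is available.
        cv_.wait(lock, [this] { return !queue_.empty(); });
        T item = queue_.front();
        queue_.pop_front();
        return item;
      }

     private:
      std::mutex mutex_;
      std::condition_variable cv_;
      std::deque<T> queue_;
    };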
@@ -28,6 +28,8 @@ limitations under the License. */
 #include "paddle/operators/detail/send_recv_impl.h"
 #include "paddle/operators/detail/simple_block_queue.h"
 
+#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
+
 namespace paddle {
 namespace operators {
@@ -39,7 +41,7 @@ void RunServer(Server **rpc_server,
   builder.RegisterService(service.get());
   std::unique_ptr<Server> server(builder.BuildAndStart());
   *rpc_server = server.get();
-  LOG(INFO) << "Server listening on " << server_address << std::endl;
+  LOG(INFO) << "Server listening on " << server_address;
   server->Wait();
 }
@@ -57,7 +59,10 @@ class RecvOp : public framework::OperatorBase {
     }
   }
 
-  virtual ~RecvOp() {
+  void Stop() override {
+    detail::TensorWithName term_msg;
+    term_msg.first = LISTEN_TERMINATE_MESSAGE;
+    rpc_service_->Push(term_msg);
     rpc_server_->Shutdown();
     server_thread_->join();
   }
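Shutdown ordering matters here: the server loop below blocks inside rpc_service_->Get(), so calling rpc_server_->Shutdown() alone would leave server_thread_ parked on the queue and join() would hang. Pushing LISTEN_TERMINATE_MESSAGE first wakes the blocked Pop so the loop can observe the sentinel and exit. A condensed sketch of this sentinel-shutdown pattern, reusing the hypothetical SimpleBlockQueue above (all names except the LISTEN_TERMINATE_MESSAGE value are illustrative):

    #include <string>
    #include <thread>
    #include <utility>

    using TensorWithName = std::pair<std::string, int>;  // payload type simplified

    SimpleBlockQueue<TensorWithName> queue;

    void EventLoop() {
      while (true) {
        TensorWithName v = queue.Pop();          // blocks until a message arrives
        if (v.first == "TERMINATE@RECV") break;  // sentinel ends the loop cleanly
        // ... handle a received variable ...
      }
    }

    int main() {
      std::thread loop(EventLoop);
      queue.Push({"TERMINATE@RECV", 0});  // what Stop() does via rpc_service_->Push
      loop.join();                        // join() no longer hangs on a blocked Pop()
      return 0;
    }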
@@ -83,13 +88,18 @@ class RecvOp : public framework::OperatorBase {
     size_t param_count = param_list.size();
     rpc_service_->Reset();
     // TODO(typhoonzero): change this to a while_op for every cluster-batch.
-    while (true) {
+    bool exit_flag = false;
+    while (!exit_flag) {
       // Get from multiple trainers, we don't care about order in which
       // the gradient arrives, just add suffix 0~n then average the gradient.
       for (size_t i = 0; i < param_count * trainer_count; ++i) {
         // blocking get one var from client.
         const detail::TensorWithName &v = rpc_service_->Get();
         auto grad_var_name = v.first;
+        if (grad_var_name == LISTEN_TERMINATE_MESSAGE) {
+          exit_flag = true;
+          break;
+        }
         auto it = std::find(grad_list.begin(), grad_list.end(), grad_var_name);
         std::string param_var_name;
         if (it != grad_list.end()) {
@@ -114,8 +124,11 @@ class RecvOp : public framework::OperatorBase {
       auto *tensor = var->GetMutable<framework::LoDTensor>();
       // FIXME(typhoonzero): do not copy
       platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
-      auto &dev_ctx = *pool.Borrow(place);
-      framework::CopyFrom(v.second, place, dev_ctx, tensor);
+      auto &dev_ctx = *pool.Borrow(dev_place);
+      framework::CopyFrom(v.second, dev_place, dev_ctx, tensor);
+    }
+    if (exit_flag) {
+      break;
     }
     rpc_service_->Reset();
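Note that the sentinel has to escape two nested loops: the inner for-loop breaks as soon as it sees LISTEN_TERMINATE_MESSAGE, and the added if (exit_flag) break; then stops the outer while-loop before Reset() and the optimize program below can run on a partial batch.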
@@ -123,7 +136,7 @@ class RecvOp : public framework::OperatorBase {
     framework::proto::ProgramDesc program_desc;
     program_desc.ParseFromString(program_str);
     framework::ProgramDesc program(program_desc);
-    framework::Executor executor(place);
+    framework::Executor executor(dev_place);
     // Run sub graph to get optimized tensor
     try {
       executor.Run(program, &recv_scope, 0, /*global_block*/
...
@@ -41,9 +41,11 @@ class SendOp : public framework::OperatorBase {
           grpc::CreateChannel(ep, grpc::InsecureChannelCredentials())));
     }
   }
+
   void Run(const framework::Scope &scope,
-           const platform::DeviceContext &dev_ctx) const override {
+           const platform::Place &dev_place) const override {
     auto ins = Inputs("X");
+    auto outs = Outputs("Out");
     std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
     // TODO(typhoonzero): use async calls to send multiple variables asynchronously.
     for (size_t i = 0; i < ins.size(); ++i) {
@@ -54,10 +56,10 @@ class SendOp : public framework::OperatorBase {
     }
     // TODO(typhoonzero): support async optimization
     client_map_[epmap[0]]->Wait();
-    for (size_t i = 0; i < ins.size(); ++i) {
-      bool ret = client_map_[epmap[i]]->GetVariable(scope, ins[i]);
+    for (size_t i = 0; i < outs.size(); ++i) {
+      bool ret = client_map_[epmap[i]]->GetVariable(scope, outs[i]);
       if (!ret) {
-        LOG(ERROR) << "GetVariable error: " << ins[i];
+        LOG(ERROR) << "GetVariable error: " << outs[i];
       }
     }
   }
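With these changes Run performs a full synchronous round trip: each input in "X" (a gradient) is sent to its mapped endpoint, Wait() blocks on the first endpoint, and each output in "Out" (the corresponding updated parameter) is fetched back using the same epmap ordering, so ins[i], outs[i], and epmap[i] must stay aligned.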
@@ -72,6 +74,8 @@ class SendOpMaker : public framework::OpProtoAndCheckerMaker {
   SendOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "(Tensor) Input tensor to be send").AsDuplicable();
+    AddOutput("Out", "(Tensor) Output tensor to get from server")
+        .AsDuplicable();
     AddComment(R"DOC(
Recv operator
@@ -79,11 +83,13 @@ This operator will recv tensor from send_op
 )DOC");
     AddAttr<std::vector<std::string>>("endpoints",
                                       "(string vector, default 127.0.0.1:6164)"
-                                      "Server endpoints to send variables to.");
+                                      "Server endpoints to send variables to.")
+        .SetDefault({});
     AddAttr<std::vector<std::string>>("epmap",
                                       "(string vector, default 127.0.0.1:6164)"
                                       "Server endpoints in the order of input "
-                                      "variables for mapping");
+                                      "variables for mapping")
+        .SetDefault({});
   }
 };
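The .SetDefault({}) calls make both attributes optional at op-creation time; without a default, the attribute checker rejects a send op whose attributes are not all set explicitly, which is presumably what the unit test below tripped over.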
...
@@ -12,9 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-// TODO(typhoonzero): add python bindings for this test as
-// a RemoteOptimizer.
-
 #include <unistd.h>
 #include <string>
 #include <thread>
@@ -86,18 +83,19 @@ void StartServerNet() {
   paddle::framework::ProgramDesc program;
   paddle::framework::BlockDesc *block = program.MutableBlock(0);
   // X for server side tensors, RX for received tensors, must be of same shape.
-  AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, block);
+  AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"x0"}}}, {}, block);
 
   paddle::framework::AttributeMap attrs;
   attrs.insert({"endpoint", std::string("127.0.0.1:6174")});
+  attrs.insert({"ParamList", std::vector<std::string>({"x0"})});
+  attrs.insert({"GradList", std::vector<std::string>({"x1"})});
   std::string program_proto;
   PADDLE_ENFORCE(program.Proto()->SerializeToString(&program_proto));
 
   attrs.insert({"OptimizeProgram", program_proto});
-  recv_op = paddle::framework::OpRegistry::CreateOp(
-      "recv", {{"RX", {"x0", "x1"}}}, {{"Out", {"Out"}}}, attrs);
-  paddle::platform::CPUDeviceContext ctx(place);
-  recv_op->Run(scope, ctx);
+  recv_op = paddle::framework::OpRegistry::CreateOp("recv", {{"RX", {"x1"}}},
+                                                    {}, attrs);
+  recv_op->Run(scope, place);
 }
 
 TEST(SendRecvOp, CPU) {
@@ -109,25 +107,25 @@ TEST(SendRecvOp, CPU) {
   InitTensorsInScope(scope, place);
 
   paddle::framework::AttributeMap attrs;
-  attrs.insert({"endpoint", std::string("127.0.0.1:6174")});
+  attrs.insert({"endpoints", std::vector<std::string>({"127.0.0.1:6174"})});
+  attrs.insert({"epmap", std::vector<std::string>({"127.0.0.1:6174"})});
 
   auto send_op = paddle::framework::OpRegistry::CreateOp(
-      "send", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, attrs);
-  paddle::platform::CPUDeviceContext ctx(place);
-  send_op->Run(scope, ctx);
+      "send", {{"X", {"x1"}}}, {{"Out", {"x0"}}}, attrs);
+  send_op->Run(scope, place);
 
-  auto in_var = scope.Var("x0");
+  auto in_var = scope.Var("x1");
   auto tensor = in_var->GetMutable<paddle::framework::LoDTensor>();
   float *expected = tensor->data<float>();
 
-  auto out_var = scope.Var("Out");
+  auto out_var = scope.Var("x0");
   auto target = out_var->GetMutable<paddle::framework::LoDTensor>();
-  // send fail cause output is none.
+  // x1 * 2 == x0
   EXPECT_NE(target->memory_size(), size_t(0));
   float *actual = target->data<float>();
   for (int64_t i = 0; i < target->numel(); ++i) {
     EXPECT_EQ(expected[i] * 2, actual[i]);
   }
-  recv_op.reset();  // dtor can shutdown and join server thread.
+  recv_op->Stop();
   server_thread.join();
+  // recv_op.reset();
 }
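The reworked test now exercises the whole round trip: the server thread runs recv_op, whose optimize program computes sum over {x0, x1} back into x0; the client sends gradient x1 and fetches parameter x0. Assuming InitTensorsInScope fills x0 and x1 with identical values (which the x1 * 2 == x0 comment implies), every fetched element must equal twice the sent one. Teardown goes through the new Stop() hook rather than the recv_op destructor, so the terminate message unblocks the server loop before server_thread.join().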
@@ -141,16 +141,18 @@ class DistributeTranspiler:
         self.param_grad_map = split_method(params_and_grads, pserver_endpoints)
 
         send_op_ordered_inputs = []
+        send_op_ordered_outputs = []
         epmap = []
         for ep, v in self.param_grad_map.iteritems():
             send_op_ordered_inputs.extend(v["grads"])
+            send_op_ordered_outputs.extend(v["params"])
             for i in v["grads"]:
                 epmap.append(ep)
         send_op = program.global_block().append_op(
             type="send",
             inputs={"X": send_op_ordered_inputs
                     },  # inputs is a list of tensors to be send
-            outputs={},
+            outputs={"Out": send_op_ordered_outputs},
             attrs={"endpoints": pserver_endpoints,
                    "epmap": epmap})
...
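The transpiler mirrors the C++ contract: v["grads"] populate "X", v["params"] populate "Out", and epmap gains one endpoint entry per gradient, keeping ins[i], outs[i], and epmap[i] aligned for SendOp::Run.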