diff --git a/paddle/fluid/distributed/fleet_executor/message_bus.cc b/paddle/fluid/distributed/fleet_executor/message_bus.cc index 4a41f69411836e1154df23fa3c1d33dd0327ab24..2071477372c9e7ee5a58ae4f01af268b5b014211 100644 --- a/paddle/fluid/distributed/fleet_executor/message_bus.cc +++ b/paddle/fluid/distributed/fleet_executor/message_bus.cc @@ -51,15 +51,11 @@ void MessageBus::Init( #endif ListenPort(); - - std::call_once(once_flag_, []() { - std::atexit([]() { MessageBus::Instance().Release(); }); - }); } bool MessageBus::IsInit() const { return is_init_; } -void MessageBus::Release() { +MessageBus::~MessageBus() { VLOG(3) << "Message bus releases resource."; #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ !defined(PADDLE_WITH_ASCEND_CL) diff --git a/paddle/fluid/distributed/fleet_executor/message_bus.h b/paddle/fluid/distributed/fleet_executor/message_bus.h index 03bb7ed81a0c78b1645529a88def18cf4922c1ed..5b19a894aa35171a8b672804a9b90e7480db2668 100644 --- a/paddle/fluid/distributed/fleet_executor/message_bus.h +++ b/paddle/fluid/distributed/fleet_executor/message_bus.h @@ -50,11 +50,11 @@ class MessageBus final { bool IsInit() const; - void Release(); - // called by Interceptor, send InterceptorMessage to dst bool Send(const InterceptorMessage& interceptor_message); + ~MessageBus(); + DISABLE_COPY_AND_ASSIGN(MessageBus); private: diff --git a/paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt b/paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt index 1d034d510a9cc7c2dd2ce7e76ae8665f600da09f..7e6d887a2d0ed8afb352bd1102c2aca7482571dd 100644 --- a/paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt +++ b/paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt @@ -2,3 +2,7 @@ set_source_files_properties(interceptor_ping_pong_test.cc PROPERTIES COMPILE_FLA set_source_files_properties(compute_interceptor_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(interceptor_ping_pong_test SRCS interceptor_ping_pong_test.cc DEPS fleet_executor ${BRPC_DEPS}) cc_test(compute_interceptor_test SRCS compute_interceptor_test.cc DEPS fleet_executor ${BRPC_DEPS}) +if(WITH_DISTRIBUTE AND WITH_PSCORE AND NOT (WITH_ASCEND OR WITH_ASCEND_CL)) +set_source_files_properties(interceptor_ping_pong_with_brpc_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_test(interceptor_ping_pong_with_brpc_test SRCS interceptor_ping_pong_with_brpc_test.cc DEPS fleet_executor ${BRPC_DEPS}) +endif() diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..dbbcd647292db3a0a2e81fe1f837bfe1016113de --- /dev/null +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc @@ -0,0 +1,128 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include + +#include "gtest/gtest.h" + +#include "paddle/fluid/distributed/fleet_executor/carrier.h" +#include "paddle/fluid/distributed/fleet_executor/interceptor.h" +#include "paddle/fluid/distributed/fleet_executor/message_bus.h" + +namespace paddle { +namespace distributed { + +class PingPongInterceptor : public Interceptor { + public: + PingPongInterceptor(int64_t interceptor_id, TaskNode* node) + : Interceptor(interceptor_id, node) { + RegisterMsgHandle([this](const InterceptorMessage& msg) { PingPong(msg); }); + } + + void PingPong(const InterceptorMessage& msg) { + std::cout << GetInterceptorId() << " recv msg, count=" << count_ + << std::endl; + ++count_; + if (count_ == 20 && GetInterceptorId() == 0) { + InterceptorMessage stop; + stop.set_message_type(STOP); + Send(0, stop); + Send(1, stop); + return; + } + + InterceptorMessage resp; + int64_t dst = GetInterceptorId() == 0 ? 1 : 0; + Send(dst, resp); + } + + private: + int count_{0}; +}; + +REGISTER_INTERCEPTOR(PingPong, PingPongInterceptor); + +TEST(InterceptorTest, PingPong) { + std::cout << "Ping pong test through brpc" << std::endl; + unsigned int seed = time(0); + // random generated two ports in from 6000 to 9000 + int port0 = 6000 + rand_r(&seed) % 3000; + int port1 = port0 + 1; + + // using socket to check the availability of the port + int server_fd = -1; + server_fd = socket(AF_INET, SOCK_STREAM, 0); + int opt = 1; + linger ling; + ling.l_onoff = 1; + ling.l_linger = 0; + setsockopt(server_fd, SOL_SOCKET, SO_LINGER, &ling, sizeof(ling)); + setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)); + struct sockaddr_in address; + address.sin_family = AF_INET; + address.sin_addr.s_addr = INADDR_ANY; + address.sin_port = htons(port0); + while (bind(server_fd, (struct sockaddr*)&address, sizeof(address)) == -1) { + port0++; + address.sin_port = htons(port0); + } + close(server_fd); + + // use another socket to check another port + server_fd = socket(AF_INET, SOCK_STREAM, 0); + setsockopt(server_fd, SOL_SOCKET, SO_LINGER, &ling, sizeof(ling)); + setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)); + port1 = port0 + 1; + address.sin_port = htons(port1); + while (bind(server_fd, (struct sockaddr*)&address, sizeof(address)) == -1) { + port1++; + address.sin_port = htons(port1); + } + close(server_fd); + + std::string ip0 = "127.0.0.1:" + std::to_string(port0); + std::string ip1 = "127.0.0.1:" + std::to_string(port1); + std::cout << "ip0: " << ip0 << std::endl; + std::cout << "ip1: " << ip1 << std::endl; + + int pid = fork(); + if (pid == 0) { + MessageBus& msg_bus = MessageBus::Instance(); + msg_bus.Init({{0, 0}, {1, 1}}, {{0, ip0}, {1, ip1}}, ip0); + + Carrier& carrier = Carrier::Instance(); + + Interceptor* a = carrier.SetInterceptor( + 0, InterceptorFactory::Create("PingPong", 0, nullptr)); + carrier.SetCreatingFlag(false); + + InterceptorMessage msg; + a->Send(1, msg); + } else { + MessageBus& msg_bus = MessageBus::Instance(); + msg_bus.Init({{0, 0}, {1, 1}}, {{0, ip0}, {1, ip1}}, ip1); + + Carrier& carrier = Carrier::Instance(); + + carrier.SetInterceptor(1, + InterceptorFactory::Create("PingPong", 1, nullptr)); + carrier.SetCreatingFlag(false); + } +} + +} // namespace distributed +} // namespace paddle