未验证 提交 be4eaba0 编写于 作者: Y Yuang Liu 提交者: GitHub

[fleet_executor] Framework for message and manager part. (#36966)

上级 bf9374c1
# Protobuf target generated from the fleet executor description.
proto_library(fleet_executor_desc_proto SRCS fleet_executor_desc.proto)
# NOTE(review): this declares the same `fleet_executor` target as the
# multi-source cc_library further down. Presumably it is the pre-change line of
# the diff shown next to its replacement -- confirm that only one of the two
# exists in the real CMakeLists.txt.
cc_library(fleet_executor SRCS fleet_executor.cc DEPS fleet_executor_desc_proto)
# Python bindings for the description proto are produced only when WITH_PYTHON
# is set.
if(WITH_PYTHON)
py_proto_compile(fleet_executor_desc_py_proto SRCS fleet_executor_desc.proto)
endif()
# Protobuf target for the interceptor message and its RPC service.
proto_library(interceptor_message_proto SRCS interceptor_message.proto)
# brpc is added as a dependency only for distributed builds that are neither
# WITH_ASCEND nor WITH_ASCEND_CL; otherwise the dependency list stays empty.
if(WITH_DISTRIBUTE AND NOT (WITH_ASCEND OR WITH_ASCEND_CL))
set(BRPC_DEPS brpc)
else()
set(BRPC_DEPS "")
endif()
# The fleet executor library: executor, carrier, interceptor, RPC service and
# message bus sources, linked against both proto targets and (optionally) brpc.
cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc
interceptor.cc interceptor_message_service.cc message_bus.cc
DEPS fleet_executor_desc_proto interceptor_message_proto ${BRPC_DEPS})
# For distributed builds, turn off the non-virtual-destructor diagnostics (and
# keep them from being promoted to errors) on the files listed below.
if(WITH_DISTRIBUTE)
  set(DISTRIBUTE_COMPILE_FLAGS
      "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
  # One call covers every file; each of them receives the same COMPILE_FLAGS.
  set_source_files_properties(
    message_bus.h
    message_bus.cc
    carrier.cc
    interceptor_message_service.h
    interceptor_message_service.cc
    PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
endif()
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace paddle {
namespace distributed {
// Constructs a carrier from the interceptor-id -> TaskNode mapping.
//
// The mapping is copied into interceptor_id_to_node_ so it stays available
// after the caller's map goes away. Only the map is copied: the TaskNode
// objects are still referenced through the raw pointers supplied by the
// caller.
Carrier::Carrier(
    const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node)
    : interceptor_id_to_node_(interceptor_id_to_node) {
  // Interceptor creation is not wired in yet (see CreateInterceptors()).
}
// Destructor. No explicit teardown is implemented yet, so only the members'
// own destructors run.
Carrier::~Carrier() = default;
// Hands a message to the interceptor it is addressed to.
//
// Placeholder: nothing is enqueued yet and success is reported
// unconditionally.
bool Carrier::EnqueueInterceptorMessage(
    const InterceptorMessage& interceptor_message) {
  const bool enqueued = true;
  return enqueued;
}
// Placeholder for creating the interceptors; intentionally a no-op for now.
void Carrier::CreateInterceptors() {}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace distributed {
class Interceptor;
class TaskNode;
class InterceptorMessageServiceImpl;
// Carrier keeps the interceptors, keyed by their logical id, together with
// the TaskNode each of them is associated with, and offers an entry point for
// enqueueing an InterceptorMessage to the matching interceptor.
class Carrier final {
 public:
  Carrier() = delete;

  // interceptor_id_to_node: interceptor logical id -> TaskNode info.
  // `explicit` keeps a plain map from being converted into a Carrier
  // implicitly.
  explicit Carrier(
      const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node);

  ~Carrier();

  // Enqueue a message to the corresponding interceptor id.
  // Returns a bool status.
  bool EnqueueInterceptorMessage(const InterceptorMessage& interceptor_message);

  DISABLE_COPY_AND_ASSIGN(Carrier);

 private:
  // Create each Interceptor.
  void CreateInterceptors();

  // Get the interceptor registered under the given interceptor id.
  Interceptor* GetInterceptor(int64_t interceptor_id);

  // Interceptor logical id -> node info (raw pointers to the TaskNode
  // objects).
  std::unordered_map<int64_t, TaskNode*> interceptor_id_to_node_;

  // Interceptor logical id -> the actual interceptor, owned by this carrier
  // through unique_ptr.
  std::unordered_map<int64_t, std::unique_ptr<Interceptor>>
      interceptor_idx_to_interceptor_;
};
} // namespace distributed
} // namespace paddle
......@@ -39,5 +39,15 @@ void FleetExecutor::Release() {
// Release
}
std::shared_ptr<Carrier> FleetExecutor::GetCarrier() {
// get carrier
return nullptr;
}
std::shared_ptr<MessageBus> FleetExecutor::GetMessageBus() {
// get message bus
return nullptr;
}
} // namespace distributed
} // namespace paddle
......@@ -14,6 +14,7 @@
#pragma once
#include <memory>
#include "paddle/fluid/distributed/fleet_executor/fleet_executor_desc.pb.h"
#include "paddle/fluid/platform/macros.h"
......@@ -24,6 +25,8 @@ class ProgramDesc;
namespace distributed {
class RuntimeGraph;
class Carrier;
class MessageBus;
class FleetExecutor final {
public:
......@@ -33,11 +36,15 @@ class FleetExecutor final {
void Init(const paddle::framework::ProgramDesc& program_desc);
void Run();
void Release();
static std::shared_ptr<Carrier> GetCarrier();
static std::shared_ptr<MessageBus> GetMessageBus();
private:
DISABLE_COPY_AND_ASSIGN(FleetExecutor);
FleetExecutorDesc exe_desc_;
std::unique_ptr<RuntimeGraph> runtime_graph_;
static std::shared_ptr<Carrier> global_carrier_;
static std::shared_ptr<MessageBus> global_message_bus_;
};
} // namespace distributed
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
namespace paddle {
namespace distributed {
Interceptor::Interceptor(int64_t interceptor_id_, TaskNode* node) {
// init
}
int64_t Interceptor::GetInterceptorId() const {
// return the interceptor id
return 0;
}
// Called by Carrier to put an InterceptorMessage into the remote mailbox.
//
// Placeholder: the message is not stored yet and success is reported
// unconditionally.
bool Interceptor::EnqueueRemoteInterceptorMessage(
    const InterceptorMessage& interceptor_message) {
  const bool accepted = true;
  return accepted;
}
// Placeholder for polling the local mailbox and parsing its messages;
// currently a no-op.
void Interceptor::PoolTheMailbox() {}
// Intended to move every message from the remote mailbox into the local one,
// returning true when the remote mailbox was not empty and false otherwise.
//
// Placeholder: nothing is moved yet and true is returned unconditionally.
bool Interceptor::FetchRemoteMailbox() {
  const bool had_messages = true;
  return had_messages;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <condition_variable>
#include <map>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace distributed {
class TaskNode;
// Interceptor is bound to one TaskNode and exchanges InterceptorMessages
// through two mailboxes: a remote mailbox that other components write into,
// and a local mailbox that the interceptor reads from.
class Interceptor {
 public:
  Interceptor() = delete;

  // interceptor_id: id handed down from the layer above.
  // node:           task node this interceptor has to handle.
  Interceptor(int64_t interceptor_id, TaskNode* node);

  virtual ~Interceptor() = default;

  // Return the interceptor id.
  int64_t GetInterceptorId() const;

  // Called by Carrier, enqueue an InterceptorMessage to the remote mailbox.
  // Returns a bool status.
  bool EnqueueRemoteInterceptorMessage(
      const InterceptorMessage& interceptor_message);

  DISABLE_COPY_AND_ASSIGN(Interceptor);

 private:
  // Poll the local mailbox and parse the messages. ("Pool" in the name is a
  // misspelling of "Poll"; the name is kept so that the out-of-line
  // definition still matches.)
  void PoolTheMailbox();

  // Fetch all messages from the remote mailbox into the local mailbox.
  // Return true if the remote mailbox was not empty, otherwise false.
  bool FetchRemoteMailbox();

  // Interceptor id, handed down from the layer above. Default-initialized so
  // it never holds an indeterminate value.
  int64_t interceptor_id_{0};

  // Node that needs to be handled by this interceptor (raw pointer).
  TaskNode* node_{nullptr};

  // Mutex to control read/write conflicts on the remote mailbox.
  std::mutex remote_mailbox_mutex_;

  // Thread in which the interceptor runs PoolTheMailbox() to poll the local
  // mailbox.
  std::thread interceptor_thread_;

  // Condition variable for blocking the thread when an empty remote mailbox
  // is fetched.
  std::condition_variable cond_var_;

  // Remote mailbox: written by EnqueueRemoteInterceptorMessage(), read by
  // FetchRemoteMailbox().
  std::queue<InterceptorMessage> remote_mailbox_;

  // Local mailbox: written by FetchRemoteMailbox(), read by PoolTheMailbox().
  std::queue<InterceptorMessage> local_mailbox_;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package paddle.distributed;
option cc_generic_services = true;
option cc_enable_arenas = true;
// Kinds of messages exchanged between interceptors.
enum MessageType {
  // Required so that the correctly spelled DATA_IS_USELESS can share number 3
  // with the original, misspelled DATE_IS_USELESS.
  option allow_alias = true;
  STOP = 1;             // STOP an Interceptor
  DATA_IS_READY = 2;    // upstream data is ready
  DATE_IS_USELESS = 3;  // misspelled original name, kept for existing users
  DATA_IS_USELESS = 3;  // downstream has used the data
  ERROR = 4;            // current Interceptor encounters error
  RESET = 5;            // reset the status
}
// Payload sent through TheInterceptorMessageService. Every field is optional
// and carries an explicit default.
message InterceptorMessage {
  // Source id (presumably the sending interceptor's id -- confirm).
  optional int64 src_id = 1 [default = 0];
  // Destination id (presumably the receiving interceptor's id -- confirm).
  optional int64 dst_id = 2 [default = 0];
  // What the message announces; RESET when unset.
  optional MessageType message_type = 3 [default = RESET];
  // Flag marking a control message; false when unset.
  optional bool ctrl_message = 4 [default = false];
}
// Reply to an InterceptorMessage; `rst` is a boolean result that is false
// when unset.
message InterceptorResponse {
  optional bool rst = 1 [default = false];
}
// RPC service with a single method: it takes an InterceptorMessage and
// answers with an InterceptorResponse.
service TheInterceptorMessageService {
  rpc InterceptorMessageService(InterceptorMessage) returns (InterceptorResponse);
}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PADDLE_WITH_ASCEND_CL
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
namespace paddle {
namespace distributed {
// Handler for the InterceptorMessageService RPC.
//
// control_base: RPC controller supplied by the framework (unused for now).
// request:      the incoming InterceptorMessage (unused for now).
// response:     reply object; left untouched because message handling is not
//               implemented yet.
// done:         completion closure. It has to be run exactly once when the
//               method is finished, otherwise the call never completes for
//               the caller.
void InterceptorMessageServiceImpl::InterceptorMessageService(
    google::protobuf::RpcController* control_base,
    const InterceptorMessage* request, InterceptorResponse* response,
    google::protobuf::Closure* done) {
  // receive msg (handling not implemented yet)

  // Signal completion; tolerate a null closure.
  if (done != nullptr) {
    done->Run();
  }
}
} // namespace distributed
} // namespace paddle
#endif
#endif
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PADDLE_WITH_ASCEND_CL
#ifdef PADDLE_WITH_DISTRIBUTE
#pragma once
#include "brpc/server.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
namespace paddle {
namespace distributed {
class InterceptorMessageServiceImpl : public TheInterceptorMessageService {
public:
InterceptorMessageServiceImpl() {}
virtual ~InterceptorMessageServiceImpl() {}
virtual void InterceptorMessageService(
google::protobuf::RpcController* control_base,
const InterceptorMessage* request, InterceptorResponse* response,
google::protobuf::Closure* done);
};
} // namespace distributed
} // namespace paddle
#endif
#endif
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
namespace paddle {
namespace distributed {
// Destructor. No explicit teardown is implemented yet, so only the members'
// own destructors run.
MessageBus::~MessageBus() = default;
// Called by an Interceptor to send an InterceptorMessage to its destination.
//
// Placeholder: nothing is routed yet and success is reported unconditionally.
bool MessageBus::Send(const InterceptorMessage& interceptor_message) {
  const bool delivered = true;
  return delivered;
}
// Placeholder for the routine that keeps listening on the port and handles
// incoming messages; currently a no-op.
void MessageBus::ListenPort() {}
// Intended to tell whether the destination lives on the same rank as the
// source.
//
// Placeholder: the ids are not inspected yet and true is returned
// unconditionally.
bool MessageBus::IsSameRank(int64_t src_id, int64_t dst_id) {
  const bool same_rank = true;
  return same_rank;
}
#ifndef PADDLE_WITH_ASCEND_CL
#ifdef PADDLE_WITH_DISTRIBUTE
// Sends a message whose destination is on a different rank than its source.
//
// Placeholder: nothing is transmitted yet and success is reported
// unconditionally.
bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) {
  const bool sent = true;
  return sent;
}
#endif
#endif
// Sends a message whose destination is on the same rank as its source.
//
// Placeholder: nothing is delivered yet and success is reported
// unconditionally.
bool MessageBus::SendIntraRank(const InterceptorMessage& interceptor_message) {
  const bool sent = true;
  return sent;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <thread>
#include <unordered_map>
#ifndef PADDLE_WITH_ASCEND_CL
#ifdef PADDLE_WITH_DISTRIBUTE
#include "brpc/channel.h"
#include "brpc/server.h"
#endif
#endif
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace distributed {
class Carrier;
// MessageBus is the sending side used by interceptors. It stores the
// interceptor-id -> rank and rank -> address tables it is constructed with,
// plus the address it is meant to listen on. The brpc-based parts exist only
// in distributed builds that are not PADDLE_WITH_ASCEND_CL builds.
class MessageBus final {
 public:
  MessageBus() = delete;

  // Copies the two lookup tables and the listen address into the bus.
  explicit MessageBus(
      const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
      const std::unordered_map<int64_t, std::string>& rank_to_addr,
      const std::string& addr)
      : interceptor_id_to_rank_(interceptor_id_to_rank),
        rank_to_addr_(rank_to_addr),
        addr_(addr) {}

  ~MessageBus();

  // Entry point for an Interceptor: send the given message to its
  // destination. Returns a bool status.
  bool Send(const InterceptorMessage& interceptor_message);

  DISABLE_COPY_AND_ASSIGN(MessageBus);

 private:
  // Keeps listening on the port and handles the messages that arrive.
  void ListenPort();

  // Tells whether dst is on the same rank as src or on a different one.
  bool IsSameRank(int64_t src_id, int64_t dst_id);

#if !defined(PADDLE_WITH_ASCEND_CL) && defined(PADDLE_WITH_DISTRIBUTE)
  // Sends a message across ranks (dst is on a different rank than src).
  bool SendInterRank(const InterceptorMessage& interceptor_message);
#endif

  // Sends a message within one rank (dst is on the same rank as src).
  bool SendIntraRank(const InterceptorMessage& interceptor_message);

  // Interceptor id -> rank id, provided by the layer above.
  std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_;

  // Rank id -> address, provided by the layer above.
  std::unordered_map<int64_t, std::string> rank_to_addr_;

  // The address that has to be listened on.
  std::string addr_;

#if !defined(PADDLE_WITH_ASCEND_CL) && defined(PADDLE_WITH_DISTRIBUTE)
  // The brpc server instance.
  brpc::Server server_;
#endif

  // Thread meant to run ListenPort(), i.e. to keep listening on the port for
  // remote messages.
  std::thread listen_port_thread_;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace paddle {
namespace distributed {
// Placeholder task description. It carries no state yet; construction and
// destruction are the compiler-generated defaults, and the class cannot be
// derived from.
class TaskNode final {
 public:
  ~TaskNode() = default;
  TaskNode() = default;
};
} // namespace distributed
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册