未验证 提交 4974fdfd 编写于 作者: L LiYuRio 提交者: GitHub

[FleetExecutor] Add source interceptor and test (#41122)

上级 7c555f4e
...@@ -13,7 +13,7 @@ endif() ...@@ -13,7 +13,7 @@ endif()
cc_library(task_loop_thread_pool SRCS task_loop_thread_pool.cc task_loop_thread.cc task_loop.cc DEPS enforce glog) cc_library(task_loop_thread_pool SRCS task_loop_thread_pool.cc task_loop_thread.cc task_loop.cc DEPS enforce glog)
cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc dist_model.cc interceptor.cc cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc dist_model.cc interceptor.cc
compute_interceptor.cc amplifier_interceptor.cc message_service.cc message_bus.cc dist_model_tensor_wrapper.cc compute_interceptor.cc amplifier_interceptor.cc source_interceptor.cc message_service.cc message_bus.cc dist_model_tensor_wrapper.cc
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto task_loop_thread_pool collective_helper DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto task_loop_thread_pool collective_helper
op_registry executor_gc_helper gflags glog ${BRPC_DEPS}) op_registry executor_gc_helper gflags glog ${BRPC_DEPS})
...@@ -25,6 +25,7 @@ if(WITH_DISTRIBUTE) ...@@ -25,6 +25,7 @@ if(WITH_DISTRIBUTE)
set_source_files_properties(interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(compute_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(compute_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(amplifier_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(amplifier_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(source_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(message_bus.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(message_bus.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(message_bus.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(message_bus.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(fleet_executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(fleet_executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
USE_INTERCEPTOR(Source);
USE_INTERCEPTOR(Compute); USE_INTERCEPTOR(Compute);
USE_INTERCEPTOR(Amplifier); USE_INTERCEPTOR(Amplifier);
......
...@@ -164,7 +164,7 @@ void ComputeInterceptor::ReplyCompletedToUpStream() { ...@@ -164,7 +164,7 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
if (up_id == -1) return; if (up_id == -1) return;
InterceptorMessage reply_msg; InterceptorMessage reply_msg;
reply_msg.set_message_type(DATE_IS_USELESS); reply_msg.set_message_type(DATA_IS_USELESS);
Send(up_id, reply_msg); Send(up_id, reply_msg);
} }
} }
...@@ -247,7 +247,7 @@ void ComputeInterceptor::Compute(const InterceptorMessage& msg) { ...@@ -247,7 +247,7 @@ void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
if (msg.message_type() == DATA_IS_READY) { if (msg.message_type() == DATA_IS_READY) {
IncreaseReady(msg.src_id()); IncreaseReady(msg.src_id());
Run(); Run();
} else if (msg.message_type() == DATE_IS_USELESS) { } else if (msg.message_type() == DATA_IS_USELESS) {
DecreaseBuff(msg.src_id()); DecreaseBuff(msg.src_id());
Run(); Run();
} else if (msg.message_type() == STOP) { } else if (msg.message_type() == STOP) {
......
...@@ -20,9 +20,10 @@ option cc_enable_arenas = true; ...@@ -20,9 +20,10 @@ option cc_enable_arenas = true;
enum MessageType { enum MessageType {
STOP = 1; // STOP an Interceptor STOP = 1; // STOP an Interceptor
DATA_IS_READY = 2; // upstream data is ready DATA_IS_READY = 2; // upstream data is ready
DATE_IS_USELESS = 3; // downstream has used the data DATA_IS_USELESS = 3; // downstream has used the data
ERR = 4; // current Interceptor encounters error ERR = 4; // current Interceptor encounters error
RESET = 5; // reset the status RESET = 5; // reset the status
START = 6;
} }
message InterceptorMessage { message InterceptorMessage {
...@@ -30,6 +31,7 @@ message InterceptorMessage { ...@@ -30,6 +31,7 @@ message InterceptorMessage {
optional int64 dst_id = 2 [ default = 0 ]; optional int64 dst_id = 2 [ default = 0 ];
optional MessageType message_type = 3 [ default = RESET ]; optional MessageType message_type = 3 [ default = RESET ];
optional bool ctrl_message = 4 [ default = false ]; optional bool ctrl_message = 4 [ default = false ];
optional int64 scope_idx = 5 [ default = 0 ];
} }
message InterceptorResponse { optional bool rst = 1 [ default = false ]; } message InterceptorResponse { optional bool rst = 1 [ default = false ]; }
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/source_interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace paddle {
namespace distributed {
// Builds the source interceptor for `node`.
//
// The per-step message budget is read from node->max_run_times(). Every
// downstream listed by node->downstream() gets a step counter starting at 0,
// and all incoming messages are routed to Run().
SourceInterceptor::SourceInterceptor(int64_t interceptor_id, TaskNode* node)
    : Interceptor(interceptor_id, node),
      max_run_times_(node->max_run_times()) {
  // One counter per downstream: how many DATA_IS_READY messages it has been
  // sent so far.
  for (const auto& downstream_entry : node->downstream()) {
    downstream_step_.emplace(downstream_entry.first, 0);
  }

  auto dispatch = [this](const InterceptorMessage& msg) { Run(msg); };
  RegisterMsgHandle(dispatch);
}
// Sends the next DATA_IS_READY message to `downstream_id`, if that downstream
// has not yet received max_run_times_ of them.
//
// The message carries scope_idx = current step % max_run_times_, and the
// downstream's step counter is advanced after the send. An id that is not in
// downstream_step_ makes std::map::at throw std::out_of_range.
void SourceInterceptor::SendDataReadyToDownStream(int64_t downstream_id) {
  int64_t& step_slot = downstream_step_.at(downstream_id);
  const int64_t current_step = step_slot;

  if (current_step < max_run_times_) {
    InterceptorMessage ready_msg;
    ready_msg.set_message_type(DATA_IS_READY);
    ready_msg.set_scope_idx(current_step % max_run_times_);
    Send(downstream_id, ready_msg);
    // Written from the value read before Send(), as a plain overwrite.
    step_slot = current_step + 1;
  }
}
// Message handler of the source interceptor.
//
//  - START: begin a new step. Every downstream's counter is reset to 0 and it
//    is sent its first DATA_IS_READY.
//  - DATA_IS_USELESS: the sender (msg.src_id()) is sent its next
//    DATA_IS_READY, subject to the max_run_times_ limit.
//  - Any other message type is ignored.
void SourceInterceptor::Run(const InterceptorMessage& msg) {
  switch (msg.message_type()) {
    case START:
      for (auto& step_entry : downstream_step_) {
        step_entry.second = 0;
        SendDataReadyToDownStream(step_entry.first);
      }
      break;
    case DATA_IS_USELESS:
      SendDataReadyToDownStream(msg.src_id());
      break;
    default:
      break;
  }
}
REGISTER_INTERCEPTOR(Source, SourceInterceptor);
} // namespace distributed
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstdint>
#include <map>

#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
namespace paddle {
namespace distributed {
/*
 * Source interceptor
 * There is only one source in the runtime graph
 * Take charge of:
 *   1. receive `START` message from carrier; this resets the per-downstream
 *      step counters and sends the first `DATA_IS_READY` to every downstream
 *   2. on each `DATA_IS_USELESS` reply, send the next `DATA_IS_READY` to the
 *      replying downstream, up to max_run_times_ messages per downstream
 * Every `DATA_IS_READY` message carries the scope index of its micro step.
 */
class SourceInterceptor : public Interceptor {
 public:
  // `node` supplies max_run_times() and the downstream ids; one step counter
  // is created per downstream.
  SourceInterceptor(int64_t interceptor_id, TaskNode* node);

 private:
  // Sends one DATA_IS_READY to `downstream_id` unless that downstream has
  // already been sent max_run_times_ messages, then advances its counter.
  void SendDataReadyToDownStream(int64_t downstream_id);

  // Message handler registered in the constructor; reacts to START and
  // DATA_IS_USELESS.
  void Run(const InterceptorMessage& msg);

  // Upper bound on DATA_IS_READY messages sent to each downstream after a
  // START; taken from TaskNode::max_run_times().
  int64_t max_run_times_;

  // downstream_id->cur_step
  std::map<int64_t, int64_t> downstream_step_;
};
} // namespace distributed
} // namespace paddle
...@@ -4,6 +4,9 @@ cc_test(interceptor_ping_pong_test SRCS interceptor_ping_pong_test.cc DEPS fleet ...@@ -4,6 +4,9 @@ cc_test(interceptor_ping_pong_test SRCS interceptor_ping_pong_test.cc DEPS fleet
set_source_files_properties(compute_interceptor_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(compute_interceptor_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(compute_interceptor_test SRCS compute_interceptor_test.cc DEPS fleet_executor ${BRPC_DEPS}) cc_test(compute_interceptor_test SRCS compute_interceptor_test.cc DEPS fleet_executor ${BRPC_DEPS})
set_source_files_properties(source_interceptor_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(source_interceptor_test SRCS source_interceptor_test.cc DEPS fleet_executor ${BRPC_DEPS})
set_source_files_properties(interceptor_pipeline_short_path_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(interceptor_pipeline_short_path_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(interceptor_pipeline_short_path_test SRCS interceptor_pipeline_short_path_test.cc DEPS fleet_executor ${BRPC_DEPS}) cc_test(interceptor_pipeline_short_path_test SRCS interceptor_pipeline_short_path_test.cc DEPS fleet_executor ${BRPC_DEPS})
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace paddle {
namespace distributed {
// Test double that stands in for a downstream of the source interceptor.
// For every DATA_IS_READY it prints the scope index, replies DATA_IS_USELESS
// to interceptor -1, and wakes the carrier once it has handled
// node_->max_run_times() such messages.
class FakeInterceptor : public Interceptor {
 public:
  FakeInterceptor(int64_t interceptor_id, TaskNode* node)
      : Interceptor(interceptor_id, node), step_(0) {
    RegisterMsgHandle([this](const InterceptorMessage& msg) { NOP(msg); });
  }

  // Handles one message; anything other than DATA_IS_READY is ignored.
  void NOP(const InterceptorMessage& msg) {
    if (msg.message_type() != DATA_IS_READY) {
      return;
    }

    std::cout << "FakeInterceptor run in scope " << msg.scope_idx()
              << std::endl;

    // Tell interceptor -1 that the data has been consumed.
    InterceptorMessage reply;
    reply.set_message_type(DATA_IS_USELESS);
    Send(-1, reply);

    ++step_;
    if (step_ == node_->max_run_times()) {
      carrier_->WakeUp();
    }
  }

 private:
  // Number of DATA_IS_READY messages handled so far.
  int64_t step_;
};
// End-to-end check of the source interceptor: a "Source" interceptor (id -1)
// is wired to a FakeInterceptor (id 0) on a single carrier, a START message is
// enqueued for the source, and the test blocks in carrier->Wait() until the
// FakeInterceptor calls carrier_->WakeUp().
TEST(SourceInterceptor, Source) {
std::string carrier_id = "0";
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
// NOTE(review): the map presumably assigns interceptor ids -1 and 0 to rank 0
// -- confirm against Carrier::Init.
carrier->Init(0, {{-1, 0}, {0, 0}});
MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
// NOTE: don't delete, otherwise interceptor will use undefined node
// NOTE(review): the constructor is called with five arguments but the trailing
// comments name only three; the fourth (3) is presumably max_run_times --
// confirm against TaskNode.
TaskNode* source = new TaskNode(0, -1, 0, 3, 0); // role, rank, task_id
TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id
// Link the two nodes: task 0 is downstream of the source, task -1 is upstream
// of node_a.
source->AddDownstreamTask(0, 1);
node_a->AddUpstreamTask(-1, 1);
// The source is created through the factory under the name "Source"; the
// downstream side is the local FakeInterceptor.
carrier->SetInterceptor(-1, InterceptorFactory::Create("Source", -1, source));
carrier->SetInterceptor(0, std::make_unique<FakeInterceptor>(0, node_a));
// start
InterceptorMessage msg;
msg.set_message_type(START);
msg.set_dst_id(-1);
carrier->EnqueueInterceptorMessage(msg);
// Blocks until FakeInterceptor wakes the carrier up.
carrier->Wait();
carrier->Release();
}
} // namespace distributed
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册