Unverified Commit ddf38a3f authored by WangXi, committed by GitHub

[fleet_executor] Add amplifier interceptor and 1F1B scheduler test (#37755)

Parent c0d5b7ec
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
...
@@ -11,7 +11,7 @@ else()
 endif()
 cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc
-            interceptor.cc compute_interceptor.cc interceptor_message_service.cc message_bus.cc
+            interceptor.cc compute_interceptor.cc amplifier_interceptor.cc interceptor_message_service.cc message_bus.cc
             DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto collective_helper
             ${BRPC_DEPS})
@@ -19,6 +19,7 @@ if(WITH_DISTRIBUTE)
   set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
   set_source_files_properties(interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
   set_source_files_properties(compute_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
+  set_source_files_properties(amplifier_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
   set_source_files_properties(message_bus.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
   set_source_files_properties(message_bus.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
   set_source_files_properties(fleet_executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
...
paddle/fluid/distributed/fleet_executor/amplifier_interceptor.cc (new file)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/amplifier_interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace distributed {
AmplifierInterceptor::AmplifierInterceptor(int64_t interceptor_id,
TaskNode* node)
: ComputeInterceptor(interceptor_id, node) {
run_per_steps_ = node->run_per_steps();
run_at_offset_ = node->run_at_offset();
reply_up_per_steps_ = node->reply_up_per_steps();
send_down_per_steps_ = node->send_down_per_steps();
PADDLE_ENFORCE_GE(
run_per_steps_, 1,
platform::errors::InvalidArgument(
"run_per_steps must >= 1, but now is %ld", run_per_steps_));
PADDLE_ENFORCE_GE(
run_at_offset_, 0,
platform::errors::InvalidArgument(
"run_at_offset must >= 0, but now is %ld", run_at_offset_));
PADDLE_ENFORCE_LT(run_at_offset_, run_per_steps_,
platform::errors::InvalidArgument(
"run_at_offset must < run_per_steps, must now "
"run_at_offset=%ld run_per_steps=%ld",
run_at_offset_, run_per_steps_));
PADDLE_ENFORCE_GE(
reply_up_per_steps_, 1,
platform::errors::InvalidArgument(
"reply_up_per_steps must >= 1, but now is %ld", reply_up_per_steps_));
PADDLE_ENFORCE_GE(send_down_per_steps_, 1,
platform::errors::InvalidArgument(
"send_down_per_steps must >= 1, but now is %ld",
send_down_per_steps_));
}
void AmplifierInterceptor::RunOps() {
// run_per_steps_, run_at_offset_
// 4, 0 --> run at step 0, 4, 8, 12
// 4, 3 --> run at step 3, 7, 11, 15
if ((step_ % run_per_steps_) == run_at_offset_) {
ComputeInterceptor::RunOps();
}
}
void AmplifierInterceptor::SendDataReadyToDownStream() {
// run multiple times but send ready to downstream only once, i.e.
// multiple inputs yield one output
if (step_ % send_down_per_steps_ == 0) {
ComputeInterceptor::SendDataReadyToDownStream();
}
}
void AmplifierInterceptor::ReplyCompletedToUpStream() {
// run multiple times but reply to upstream only once, i.e.
// one input yields multiple outputs
if (step_ % reply_up_per_steps_ == 0) {
ComputeInterceptor::ReplyCompletedToUpStream();
}
}
REGISTER_INTERCEPTOR(Amplifier, AmplifierInterceptor);
} // namespace distributed
} // namespace paddle
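
The gating above is easiest to verify with concrete numbers. Below is a minimal standalone sketch (not part of this commit) that replays the three gates, including the detail that step_ is incremented between RunOps() and the send/reply checks in ComputeInterceptor::Run(); the parameter values are illustrative only:

#include <cstdint>
#include <iostream>

int main() {
  // Illustrative amplifier config: run every 4th step starting at offset 3,
  // reply upstream every run, send downstream once per 4 runs.
  const int64_t run_per_steps = 4, run_at_offset = 3;
  const int64_t reply_up_per_steps = 1, send_down_per_steps = 4;
  for (int64_t step = 0; step < 8; ++step) {
    const bool runs = (step % run_per_steps) == run_at_offset;  // RunOps gate
    const int64_t after = step + 1;  // value of step_ after ++step_ in Run()
    const bool sends = (after % send_down_per_steps) == 0;      // send gate
    const bool replies = (after % reply_up_per_steps) == 0;     // reply gate
    std::cout << "step " << step << ": run=" << runs
              << " send_down=" << sends << " reply_up=" << replies << "\n";
  }
  return 0;
}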
paddle/fluid/distributed/fleet_executor/amplifier_interceptor.h (new file)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <utility>
#include "paddle/fluid/distributed/fleet_executor/compute_interceptor.h"
namespace paddle {
namespace distributed {
class AmplifierInterceptor : public ComputeInterceptor {
public:
AmplifierInterceptor(int64_t interceptor_id, TaskNode* node);
private:
void RunOps() override;
void SendDataReadyToDownStream() override;
void ReplyCompletedToUpStream() override;
int64_t run_per_steps_{1};
int64_t run_at_offset_{0};
  // one input yields multiple outputs
  int64_t reply_up_per_steps_{1};
  // one output needs multiple inputs
  int64_t send_down_per_steps_{1};
};
} // namespace distributed
} // namespace paddle
paddle/fluid/distributed/fleet_executor/carrier.cc
...
@@ -24,6 +24,7 @@ namespace paddle {
 namespace distributed {
 USE_INTERCEPTOR(Compute);
+USE_INTERCEPTOR(Amplifier);
 void Carrier::Init(std::shared_ptr<RuntimeGraph> runtime_graph,
                    framework::Scope* root_scope,
...
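
USE_INTERCEPTOR(Amplifier) pairs with the REGISTER_INTERCEPTOR(Amplifier, AmplifierInterceptor) line in amplifier_interceptor.cc. As a rough sketch of the idiom (these are not Paddle's actual macro definitions), registration fills a name-to-creator map at static-initialization time, and the USE_ side references the registering translation unit so the linker keeps it:

#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

// Toy hierarchy standing in for Paddle's interceptors.
struct Interceptor {
  virtual ~Interceptor() = default;
  virtual void Run() = 0;
};
struct Amplifier : Interceptor {
  void Run() override { std::cout << "amplifier runs\n"; }
};

// The registry: type name -> creator function.
using Creator = std::function<std::unique_ptr<Interceptor>()>;
std::map<std::string, Creator>& Registry() {
  static std::map<std::string, Creator> r;  // function-local static: safe init
  return r;
}

// REGISTER_-style helper: its constructor runs during static initialization
// of the translation unit that defines the interceptor.
struct Registrar {
  Registrar(const std::string& name, Creator c) {
    Registry()[name] = std::move(c);
  }
};
static Registrar amplifier_registrar("Amplifier", [] {
  return std::unique_ptr<Interceptor>(new Amplifier);
});

int main() {
  // Factory lookup by type string, as InterceptorFactory::Create does.
  Registry()["Amplifier"]()->Run();
  return 0;
}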
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
...
@@ -160,15 +160,18 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
   }
 }
+void ComputeInterceptor::RunOps() {
+  VLOG(3) << "ComputeInterceptor " << interceptor_id_ << " running ops.";
+  for (auto op : node_->ops()) {
+    op->Run(*microbatch_scopes_[step_ % node_->max_run_times()], place_);
+  }
+}
 void ComputeInterceptor::Run() {
   while (IsInputReady() && CanWriteOutput() && !ShouldReset()) {
     VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running";
-    // step_ %= node_->max_run_times();
-    for (auto op : node_->ops()) {
-      auto* scope = microbatch_scopes_[step_ % node_->max_run_times()];
-      op->Run(*scope, place_);
-    }
+    RunOps();
     ++step_;
     // send to downstream and increase buff used
...
@@ -176,7 +179,7 @@ void ComputeInterceptor::Run() {
     // reply to upstream and decrease ready data
     ReplyCompletedToUpStream();
     // Try to stop Carrier
-    if (step_ % node_->max_run_times() == 0 && is_last_) {
+    if (is_last_ && (step_ % node_->max_run_times() == 0)) {
       StopCarrier();
     }
   }
...
paddle/fluid/distributed/fleet_executor/compute_interceptor.h
...
@@ -25,6 +25,14 @@ class ComputeInterceptor : public Interceptor {
  public:
   ComputeInterceptor(int64_t interceptor_id, TaskNode* node);
+ protected:
+  virtual void RunOps();
+  virtual void SendDataReadyToDownStream();
+  virtual void ReplyCompletedToUpStream();
+
+  int64_t step_{0};
+
  private:
   void PrepareDeps();
   void IncreaseReady(int64_t up_id);
@@ -33,19 +41,14 @@ class ComputeInterceptor : public Interceptor {
   bool CanWriteOutput();
   bool ShouldReset();
-  void SendDataReadyToDownStream();
-  void ReplyCompletedToUpStream();
   void Run();
   void Compute(const InterceptorMessage& msg);
   void ReceivedStop(int64_t up_id);
   void TryStop();
- private:
   bool is_source_{false};
   bool is_last_{false};
-  int64_t step_{0};
   // upstream_id-->(max_ready_size, ready_size)
   std::map<int64_t, std::pair<int64_t, int64_t>> in_readys_{};
...
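
The header refactor is a textbook template method: Run() keeps the fixed compute/send/reply skeleton, and the three hooks plus step_ move to a protected section so AmplifierInterceptor can override them. A self-contained sketch of that structure (simplified names, no message bus):

#include <cstdint>
#include <iostream>

class BaseStage {
 public:
  virtual ~BaseStage() = default;
  // Fixed skeleton, analogous to ComputeInterceptor::Run().
  void Run() {
    RunOps();
    ++step_;
    SendDataReadyToDownStream();
    ReplyCompletedToUpStream();
  }

 protected:
  virtual void RunOps() { std::cout << "run ops at step " << step_ << "\n"; }
  virtual void SendDataReadyToDownStream() { std::cout << "  send down\n"; }
  virtual void ReplyCompletedToUpStream() { std::cout << "  reply up\n"; }
  int64_t step_{0};
};

// Analogous to AmplifierInterceptor: dilute the reply so that one upstream
// input covers three runs.
class AmplifiedStage : public BaseStage {
 protected:
  void ReplyCompletedToUpStream() override {
    if (step_ % 3 == 0) BaseStage::ReplyCompletedToUpStream();
  }
};

int main() {
  AmplifiedStage stage;
  for (int i = 0; i < 6; ++i) stage.Run();  // replies fire after runs 3 and 6
  return 0;
}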
paddle/fluid/distributed/fleet_executor/task_node.h
...
@@ -44,12 +44,22 @@ class TaskNode final {
   int32_t role() const { return role_; }
   int64_t max_run_times() const { return max_run_times_; }
   int64_t max_slot_nums() const { return max_slot_nums_; }
+  int64_t run_per_steps() const { return run_per_steps_; }
+  int64_t run_at_offset() const { return run_at_offset_; }
+  int64_t reply_up_per_steps() const { return reply_up_per_steps_; }
+  int64_t send_down_per_steps() const { return send_down_per_steps_; }
   const std::unordered_set<int64_t>& upstream() const { return upstream_; }
   const std::unordered_set<int64_t>& downstream() const { return downstream_; }
   const std::string& type() const { return type_; }
   const paddle::framework::ProgramDesc& program() const { return program_; }
   const std::vector<OperatorBase*>& ops() const { return ops_; }
+  void SetRunPerSteps(int64_t value) { run_per_steps_ = value; }
+  void SetRunAtOffset(int64_t value) { run_at_offset_ = value; }
+  void SetReplyUpPerSteps(int64_t value) { reply_up_per_steps_ = value; }
+  void SetSendDownPerSteps(int64_t value) { send_down_per_steps_ = value; }
+  void SetType(const std::string& type) { type_ = type; }
   bool AddUpstreamTask(int64_t task_id);
   bool AddDownstreamTask(int64_t task_id);
   std::string DebugString() const;
@@ -76,6 +86,13 @@ class TaskNode final {
   int64_t max_run_times_;
   int64_t max_slot_nums_;
+  int64_t run_per_steps_{1};
+  int64_t run_at_offset_{0};
+  // one input yields multiple outputs
+  int64_t reply_up_per_steps_{1};
+  // one output needs multiple inputs
+  int64_t send_down_per_steps_{1};
   std::string type_;
 };
...
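
A hypothetical caller of the new setters might configure a pipeline tail stage like this (illustrative only; ConfigureTailStage is not in the codebase, and the tests below do the real wiring). It compiles only inside the Paddle tree, where the fleet_executor headers are on the include path:

#include <cstdint>
#include "paddle/fluid/distributed/fleet_executor/task_node.h"

// Sketch: a stage that runs once per micro_steps steps, at the last offset,
// and acknowledges its upstream once per micro_steps runs.
void ConfigureTailStage(paddle::distributed::TaskNode* node,
                        int64_t micro_steps) {
  node->SetRunPerSteps(micro_steps);      // run every micro_steps-th step
  node->SetRunAtOffset(micro_steps - 1);  // ... at the last slot in the window
  node->SetReplyUpPerSteps(micro_steps);  // one reply per micro_steps runs
}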
paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt
...
@@ -4,6 +4,12 @@ cc_test(interceptor_ping_pong_test SRCS interceptor_ping_pong_test.cc DEPS fleet
 set_source_files_properties(compute_interceptor_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
 cc_test(compute_interceptor_test SRCS compute_interceptor_test.cc DEPS fleet_executor ${BRPC_DEPS})
+set_source_files_properties(interceptor_pipeline_short_path_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
+cc_test(interceptor_pipeline_short_path_test SRCS interceptor_pipeline_short_path_test.cc DEPS fleet_executor ${BRPC_DEPS})
+set_source_files_properties(interceptor_pipeline_long_path_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
+cc_test(interceptor_pipeline_long_path_test SRCS interceptor_pipeline_long_path_test.cc DEPS fleet_executor ${BRPC_DEPS})
 set_source_files_properties(compute_interceptor_run_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
 cc_test(compute_interceptor_run_op_test SRCS compute_interceptor_run_op_test.cc DEPS fleet_executor ${BRPC_DEPS} op_registry fill_constant_op elementwise_add_op scope device_context)
...
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc (new file)
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace paddle {
namespace distributed {
void LinkNodes(const std::vector<TaskNode*>& nodes) {
size_t size = nodes.size();
if (size <= 1) return;
{ // i = 0
TaskNode* now = nodes[0];
TaskNode* next = nodes[1];
now->AddDownstreamTask(next->task_id());
}
{ // i = size - 1
TaskNode* prev = nodes[size - 2];
TaskNode* now = nodes[size - 1];
now->AddUpstreamTask(prev->task_id());
}
for (size_t i = 1; i < size - 1; ++i) {
TaskNode* prev = nodes[i - 1];
TaskNode* now = nodes[i];
TaskNode* next = nodes[i + 1];
now->AddUpstreamTask(prev->task_id());
now->AddDownstreamTask(next->task_id());
}
}
TEST(AmplifierInterceptor, Amplifier) {
Carrier& carrier = Carrier::Instance();
MessageBus& msg_bus = MessageBus::Instance();
msg_bus.Init({{0, 0}, {1, 0}, {2, 0}, {3, 0}, {4, 0}, {5, 0}},
{{0, "127.0.0.0:0"}}, "127.0.0.0:0");
int64_t micro_steps = 3;
// NOTE: don't delete, otherwise interceptor will use undefined node
  TaskNode* node_a =
      new TaskNode(0, 0, 0, 1, 0);  // role, rank, task_id, max_run_times, max_slot_nums
TaskNode* node_b = new TaskNode(0, 0, 1, 1, 0);
TaskNode* node_c = new TaskNode(0, 0, 2, 1, 0);
TaskNode* node_d = new TaskNode(0, 0, 3, 1, 0);
TaskNode* node_e = new TaskNode(0, 0, 4, 1, 0);
TaskNode* node_f = new TaskNode(0, 0, 5, 1, 0);
// a->b->c->d->e->f
LinkNodes({node_a, node_b, node_c, node_d, node_e, node_f});
// LR->b(1:3)->F->B->e(3:1)->U
node_b->SetReplyUpPerSteps(micro_steps);
node_e->SetSendDownPerSteps(micro_steps);
carrier.SetInterceptor(0, InterceptorFactory::Create("Compute", 0, node_a));
carrier.SetInterceptor(1, InterceptorFactory::Create("Amplifier", 1, node_b));
carrier.SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c));
carrier.SetInterceptor(3, InterceptorFactory::Create("Compute", 3, node_d));
carrier.SetInterceptor(4, InterceptorFactory::Create("Amplifier", 4, node_e));
carrier.SetInterceptor(5, InterceptorFactory::Create("Compute", 5, node_f));
carrier.SetCreatingFlag(false);
// start
InterceptorMessage msg;
msg.set_message_type(DATA_IS_READY);
msg.set_src_id(-1);
msg.set_dst_id(0);
carrier.EnqueueInterceptorMessage(msg);
}
} // namespace distributed
} // namespace paddle
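
Sanity check for this test (a standalone sketch, not part of the test binary): both amplifier gates compare the post-increment step_ against micro_steps, so with micro_steps = 3, node_b replies to node_a only after its 3rd, 6th, ... run, and node_e forwards DATA_IS_READY to node_f at the same cadence:

#include <iostream>

int main() {
  const int micro_steps = 3;              // as in the test above
  for (int run = 1; run <= 9; ++run) {    // run == value of step_ after ++step_
    const bool fires = (run % micro_steps) == 0;
    std::cout << "after run " << run
              << ": b replies up / e sends down = " << fires << "\n";
  }
  return 0;
}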
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc (new file)
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace paddle {
namespace distributed {
void LinkNodes(const std::vector<TaskNode*>& nodes) {
size_t size = nodes.size();
if (size <= 1) return;
{ // i = 0
TaskNode* now = nodes[0];
TaskNode* next = nodes[1];
now->AddDownstreamTask(next->task_id());
}
{ // i = size - 1
TaskNode* prev = nodes[size - 2];
TaskNode* now = nodes[size - 1];
now->AddUpstreamTask(prev->task_id());
}
for (size_t i = 1; i < size - 1; ++i) {
TaskNode* prev = nodes[i - 1];
TaskNode* now = nodes[i];
TaskNode* next = nodes[i + 1];
now->AddUpstreamTask(prev->task_id());
now->AddDownstreamTask(next->task_id());
}
}
TEST(AmplifierInterceptor, Amplifier) {
Carrier& carrier = Carrier::Instance();
MessageBus& msg_bus = MessageBus::Instance();
msg_bus.Init({{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {{0, ""}}, "");
int64_t micro_steps = 3;
// NOTE: don't delete, otherwise interceptor will use undefined node
  TaskNode* node_a =
      new TaskNode(0, 0, 0, micro_steps, 0);  // role, rank, task_id, max_run_times, max_slot_nums
TaskNode* node_b = new TaskNode(0, 0, 1, 3, 0);
TaskNode* node_c = new TaskNode(0, 0, 2, 3, 0);
TaskNode* node_d = new TaskNode(0, 0, 3, micro_steps, 0);
// a->b->c->d
LinkNodes({node_a, node_b, node_c, node_d});
node_a->SetRunPerSteps(micro_steps);
node_d->SetRunPerSteps(micro_steps);
node_d->SetRunAtOffset(micro_steps - 1);
carrier.SetInterceptor(0, InterceptorFactory::Create("Amplifier", 0, node_a));
carrier.SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b));
carrier.SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c));
carrier.SetInterceptor(3, InterceptorFactory::Create("Amplifier", 3, node_d));
carrier.SetCreatingFlag(false);
// start
InterceptorMessage msg;
msg.set_message_type(DATA_IS_READY);
msg.set_src_id(-1);
msg.set_dst_id(0);
carrier.EnqueueInterceptorMessage(msg);
}
} // namespace distributed
} // namespace paddle
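
In this 1F1B-style pipeline the two RunOps gates interleave: node_a (offset 0) fires at steps 0, 3, 6, ... while node_d (offset micro_steps - 1) fires at steps 2, 5, 8, .... A standalone sketch of that schedule:

#include <iostream>

int main() {
  const int micro_steps = 3;  // as in the test above
  for (int step = 0; step < 9; ++step) {
    const bool a_runs = (step % micro_steps) == 0;               // offset 0
    const bool d_runs = (step % micro_steps) == micro_steps - 1; // last offset
    std::cout << "step " << step << ": node_a runs=" << a_runs
              << " node_d runs=" << d_runs << "\n";
  }
  return 0;
}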