未验证 提交 5b962bd9 编写于 作者: W WangXi 提交者: GitHub

[fleet_executor] Interceptor run op (#37623)

上级 b6307742
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/fluid/distributed/fleet_executor/compute_interceptor.h" #include "paddle/fluid/distributed/fleet_executor/compute_interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
...@@ -40,9 +41,22 @@ void ComputeInterceptor::PrepareDeps() { ...@@ -40,9 +41,22 @@ void ComputeInterceptor::PrepareDeps() {
for (auto down_id : downstream) { for (auto down_id : downstream) {
out_buffs_.emplace(down_id, std::make_pair(out_buff_size, 0)); out_buffs_.emplace(down_id, std::make_pair(out_buff_size, 0));
} }
// source compute node, should we add a new SourceInterceptor?
if (upstream.empty()) {
is_source_ = true;
PADDLE_ENFORCE_GT(node_->max_run_times(), 0,
platform::errors::InvalidArgument(
"Source ComputeInterceptor must run at least one "
"times, but now max_run_times=%ld",
node_->max_run_times()));
}
} }
void ComputeInterceptor::IncreaseReady(int64_t up_id) { void ComputeInterceptor::IncreaseReady(int64_t up_id) {
// source node has no upstream, data_is_ready is send by carrier or others
if (is_source_ && up_id == -1) return;
auto it = in_readys_.find(up_id); auto it = in_readys_.find(up_id);
PADDLE_ENFORCE_NE(it, in_readys_.end(), PADDLE_ENFORCE_NE(it, in_readys_.end(),
platform::errors::NotFound( platform::errors::NotFound(
...@@ -93,6 +107,12 @@ bool ComputeInterceptor::CanWriteOutput() { ...@@ -93,6 +107,12 @@ bool ComputeInterceptor::CanWriteOutput() {
return true; return true;
} }
// only source node need reset
bool ComputeInterceptor::ShouldReset() {
if (is_source_ && step_ == node_->max_run_times()) return true;
return false;
}
void ComputeInterceptor::SendDataReadyToDownStream() { void ComputeInterceptor::SendDataReadyToDownStream() {
for (auto& outs : out_buffs_) { for (auto& outs : out_buffs_) {
auto down_id = outs.first; auto down_id = outs.first;
...@@ -134,9 +154,27 @@ void ComputeInterceptor::ReplyCompletedToUpStream() { ...@@ -134,9 +154,27 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
} }
void ComputeInterceptor::Run() { void ComputeInterceptor::Run() {
while (IsInputReady() && CanWriteOutput()) { // If there is no limit, source interceptor can be executed
// an unlimited number of times.
// Now source node can only run
if (ShouldReset()) {
for (auto& out_buff : out_buffs_) {
// buffer is using
if (out_buff.second.second != 0) return;
}
step_ = 0; // reset
return;
}
while (IsInputReady() && CanWriteOutput() && !ShouldReset()) {
VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running"; VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running";
// TODO(wangxi): add op run
// step_ %= node_->max_run_times();
for (auto op : node_->ops()) {
auto* scope = microbatch_scopes_[step_ % node_->max_slot_nums()];
op->Run(*scope, place_);
}
++step_;
// send to downstream and increase buff used // send to downstream and increase buff used
SendDataReadyToDownStream(); SendDataReadyToDownStream();
...@@ -149,7 +187,7 @@ void ComputeInterceptor::ReceivedStop(int64_t up_id) { ...@@ -149,7 +187,7 @@ void ComputeInterceptor::ReceivedStop(int64_t up_id) {
received_stop_ = true; received_stop_ = true;
// source node has no upstream, stop is send by carrier or others // source node has no upstream, stop is send by carrier or others
if (up_id == -1) return; if (is_source_ && up_id == -1) return;
auto it = in_stops_.find(up_id); auto it = in_stops_.find(up_id);
PADDLE_ENFORCE_NE(it, in_stops_.end(), PADDLE_ENFORCE_NE(it, in_stops_.end(),
......
...@@ -31,6 +31,7 @@ class ComputeInterceptor : public Interceptor { ...@@ -31,6 +31,7 @@ class ComputeInterceptor : public Interceptor {
void DecreaseBuff(int64_t down_id); void DecreaseBuff(int64_t down_id);
bool IsInputReady(); bool IsInputReady();
bool CanWriteOutput(); bool CanWriteOutput();
bool ShouldReset();
void SendDataReadyToDownStream(); void SendDataReadyToDownStream();
void ReplyCompletedToUpStream(); void ReplyCompletedToUpStream();
...@@ -43,8 +44,9 @@ class ComputeInterceptor : public Interceptor { ...@@ -43,8 +44,9 @@ class ComputeInterceptor : public Interceptor {
void TryStop(); void TryStop();
private: private:
// FIXME(wangxi): if use step_ and max_steps_, how to restart step_ from 0 bool is_source_{false};
int64_t step_{0}; int64_t step_{0};
// upstream_id-->(max_ready_size, ready_size) // upstream_id-->(max_ready_size, ready_size)
std::map<int64_t, std::pair<int64_t, int64_t>> in_readys_{}; std::map<int64_t, std::pair<int64_t, int64_t>> in_readys_{};
// downstream_id-->(max_buffer_size, used_size) // downstream_id-->(max_buffer_size, used_size)
......
...@@ -26,8 +26,12 @@ ...@@ -26,8 +26,12 @@
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h" #include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/place.h"
namespace paddle { namespace paddle {
namespace framework {
class Scope;
}
namespace distributed { namespace distributed {
class TaskNode; class TaskNode;
...@@ -64,12 +68,34 @@ class Interceptor { ...@@ -64,12 +68,34 @@ class Interceptor {
bool Send(int64_t dst_id, InterceptorMessage& msg); // NOLINT bool Send(int64_t dst_id, InterceptorMessage& msg); // NOLINT
void SetPlace(const platform::Place& place) { place_ = place; }
void SetRootScope(framework::Scope* scope) { root_scope_ = scope; }
void SetMiniBatchScope(framework::Scope* scope) { minibatch_scope_ = scope; }
void SetMicroBatchScope(const std::vector<framework::Scope*>& scopes) {
microbatch_scopes_ = scopes;
}
TaskNode* GetTaskNode() const { return node_; }
DISABLE_COPY_AND_ASSIGN(Interceptor); DISABLE_COPY_AND_ASSIGN(Interceptor);
protected: protected:
TaskNode* GetTaskNode() const { return node_; } // interceptor id, handed from above layer
int64_t interceptor_id_;
// node need to be handled by this interceptor
TaskNode* node_;
// for stop
bool stop_{false}; bool stop_{false};
// for runtime
platform::Place place_;
framework::Scope* root_scope_{nullptr};
framework::Scope* minibatch_scope_{nullptr};
std::vector<framework::Scope*> microbatch_scopes_{};
private: private:
// pool the local mailbox, parse the Message // pool the local mailbox, parse the Message
void PoolTheMailbox(); void PoolTheMailbox();
...@@ -78,12 +104,6 @@ class Interceptor { ...@@ -78,12 +104,6 @@ class Interceptor {
// return true if remote mailbox not empty, otherwise return false // return true if remote mailbox not empty, otherwise return false
bool FetchRemoteMailbox(); bool FetchRemoteMailbox();
// interceptor id, handed from above layer
int64_t interceptor_id_;
// node need to be handled by this interceptor
TaskNode* node_;
// interceptor handle which process message // interceptor handle which process message
MsgHandle handle_{nullptr}; MsgHandle handle_{nullptr};
......
...@@ -48,6 +48,7 @@ class TaskNode final { ...@@ -48,6 +48,7 @@ class TaskNode final {
const std::unordered_set<int64_t>& downstream() const { return downstream_; } const std::unordered_set<int64_t>& downstream() const { return downstream_; }
const std::string& type() const { return type_; } const std::string& type() const { return type_; }
const paddle::framework::ProgramDesc& program() const { return program_; } const paddle::framework::ProgramDesc& program() const { return program_; }
const std::vector<OperatorBase*>& ops() const { return ops_; }
bool AddUpstreamTask(int64_t task_id); bool AddUpstreamTask(int64_t task_id);
bool AddDownstreamTask(int64_t task_id); bool AddDownstreamTask(int64_t task_id);
......
set_source_files_properties(interceptor_ping_pong_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(interceptor_ping_pong_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(compute_interceptor_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(interceptor_ping_pong_test SRCS interceptor_ping_pong_test.cc DEPS fleet_executor ${BRPC_DEPS}) cc_test(interceptor_ping_pong_test SRCS interceptor_ping_pong_test.cc DEPS fleet_executor ${BRPC_DEPS})
set_source_files_properties(compute_interceptor_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(compute_interceptor_test SRCS compute_interceptor_test.cc DEPS fleet_executor ${BRPC_DEPS}) cc_test(compute_interceptor_test SRCS compute_interceptor_test.cc DEPS fleet_executor ${BRPC_DEPS})
set_source_files_properties(compute_interceptor_run_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(compute_interceptor_run_op_test SRCS compute_interceptor_run_op_test.cc DEPS fleet_executor ${BRPC_DEPS} op_registry fill_constant_op elementwise_add_op scope device_context)
if(WITH_DISTRIBUTE AND WITH_PSCORE AND NOT (WITH_ASCEND OR WITH_ASCEND_CL)) if(WITH_DISTRIBUTE AND WITH_PSCORE AND NOT (WITH_ASCEND OR WITH_ASCEND_CL))
set_source_files_properties(interceptor_ping_pong_with_brpc_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(interceptor_ping_pong_with_brpc_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(interceptor_ping_pong_with_brpc_test SRCS interceptor_ping_pong_with_brpc_test.cc DEPS fleet_executor ${BRPC_DEPS}) cc_test(interceptor_ping_pong_with_brpc_test SRCS interceptor_ping_pong_with_brpc_test.cc DEPS fleet_executor ${BRPC_DEPS})
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
USE_OP(elementwise_add);
USE_OP(fill_constant);
namespace paddle {
namespace distributed {
std::vector<framework::OperatorBase*> GetOps() {
framework::AttributeMap attrs;
attrs["dtype"] = framework::proto::VarType::FP32;
attrs["shape"] = framework::vectorize<int>({2, 3});
attrs["value"] = 1.0f;
auto zero_op = framework::OpRegistry::CreateOp("fill_constant", {},
{{"Out", {"x"}}}, attrs);
auto op = framework::OpRegistry::CreateOp(
"elementwise_add", {{"X", {"x"}}, {"Y", {"x"}}}, {{"Out", {"out"}}},
framework::AttributeMap());
// NOTE: don't delete
return {zero_op.release(), op.release()};
}
framework::Scope* GetScope() {
framework::Scope* scope = new framework::Scope();
scope->Var("x")->GetMutable<framework::LoDTensor>();
scope->Var("out")->GetMutable<framework::LoDTensor>();
return scope;
}
TEST(ComputeInterceptor, Compute) {
std::vector<framework::OperatorBase*> ops = GetOps();
framework::Scope* scope = GetScope();
std::vector<framework::Scope*> scopes = {scope, scope};
platform::Place place = platform::CPUPlace();
MessageBus& msg_bus = MessageBus::Instance();
msg_bus.Init({{0, 0}, {1, 0}}, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");
Carrier& carrier = Carrier::Instance();
// FIXME: don't delete, otherwise interceptor will use undefined node
TaskNode* node_a =
new TaskNode(0, ops, 0, 0, 2, 2); // role, ops, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, 0, 0);
// a->b
node_a->AddDownstreamTask(1);
node_b->AddUpstreamTask(0);
auto* a = carrier.SetInterceptor(
0, InterceptorFactory::Create("Compute", 0, node_a));
carrier.SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b));
a->SetPlace(place);
a->SetMicroBatchScope(scopes);
carrier.SetCreatingFlag(false);
// start
InterceptorMessage msg;
msg.set_message_type(DATA_IS_READY);
msg.set_src_id(-1);
msg.set_dst_id(0);
carrier.EnqueueInterceptorMessage(msg);
// stop
InterceptorMessage stop;
stop.set_message_type(STOP);
stop.set_src_id(-1);
stop.set_dst_id(0);
carrier.EnqueueInterceptorMessage(stop);
}
} // namespace distributed
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册