diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc index 9a80470c1e7afc15047d89db264ba0f4b9eff0b9..09b86ba18e34be69a79a8cd2a25d543b5643793a 100644 --- a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/distributed/fleet_executor/compute_interceptor.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h" +#include "paddle/fluid/framework/operator.h" namespace paddle { namespace distributed { @@ -40,9 +41,22 @@ void ComputeInterceptor::PrepareDeps() { for (auto down_id : downstream) { out_buffs_.emplace(down_id, std::make_pair(out_buff_size, 0)); } + + // source compute node, should we add a new SourceInterceptor? + if (upstream.empty()) { + is_source_ = true; + PADDLE_ENFORCE_GT(node_->max_run_times(), 0, + platform::errors::InvalidArgument( + "Source ComputeInterceptor must run at least one " + "times, but now max_run_times=%ld", + node_->max_run_times())); + } } void ComputeInterceptor::IncreaseReady(int64_t up_id) { + // source node has no upstream, data_is_ready is send by carrier or others + if (is_source_ && up_id == -1) return; + auto it = in_readys_.find(up_id); PADDLE_ENFORCE_NE(it, in_readys_.end(), platform::errors::NotFound( @@ -93,6 +107,12 @@ bool ComputeInterceptor::CanWriteOutput() { return true; } +// only source node need reset +bool ComputeInterceptor::ShouldReset() { + if (is_source_ && step_ == node_->max_run_times()) return true; + return false; +} + void ComputeInterceptor::SendDataReadyToDownStream() { for (auto& outs : out_buffs_) { auto down_id = outs.first; @@ -134,9 +154,27 @@ void ComputeInterceptor::ReplyCompletedToUpStream() { } void ComputeInterceptor::Run() { - while (IsInputReady() && CanWriteOutput()) { + // If there is no limit, source interceptor can be executed + // an unlimited number of times. + // Now source node can only run + if (ShouldReset()) { + for (auto& out_buff : out_buffs_) { + // buffer is using + if (out_buff.second.second != 0) return; + } + step_ = 0; // reset + return; + } + + while (IsInputReady() && CanWriteOutput() && !ShouldReset()) { VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running"; - // TODO(wangxi): add op run + + // step_ %= node_->max_run_times(); + for (auto op : node_->ops()) { + auto* scope = microbatch_scopes_[step_ % node_->max_slot_nums()]; + op->Run(*scope, place_); + } + ++step_; // send to downstream and increase buff used SendDataReadyToDownStream(); @@ -149,7 +187,7 @@ void ComputeInterceptor::ReceivedStop(int64_t up_id) { received_stop_ = true; // source node has no upstream, stop is send by carrier or others - if (up_id == -1) return; + if (is_source_ && up_id == -1) return; auto it = in_stops_.find(up_id); PADDLE_ENFORCE_NE(it, in_stops_.end(), diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.h b/paddle/fluid/distributed/fleet_executor/compute_interceptor.h index a116abeff80bf27f6b1e666cb9800ab2a9c4e505..fd540e81afacae1cedfc410464ddf774b1bc7f27 100644 --- a/paddle/fluid/distributed/fleet_executor/compute_interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.h @@ -31,6 +31,7 @@ class ComputeInterceptor : public Interceptor { void DecreaseBuff(int64_t down_id); bool IsInputReady(); bool CanWriteOutput(); + bool ShouldReset(); void SendDataReadyToDownStream(); void ReplyCompletedToUpStream(); @@ -43,8 +44,9 @@ class ComputeInterceptor : public Interceptor { void TryStop(); private: - // FIXME(wangxi): if use step_ and max_steps_, how to restart step_ from 0 + bool is_source_{false}; int64_t step_{0}; + // upstream_id-->(max_ready_size, ready_size) std::map> in_readys_{}; // downstream_id-->(max_buffer_size, used_size) diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.h b/paddle/fluid/distributed/fleet_executor/interceptor.h index 9f74f99ded6c35061a6724036b4e8a0323e4eaef..052c0cc55d550cbae3b4ce7803143ac76cee6c76 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/interceptor.h @@ -26,8 +26,12 @@ #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/errors.h" #include "paddle/fluid/platform/macros.h" +#include "paddle/fluid/platform/place.h" namespace paddle { +namespace framework { +class Scope; +} namespace distributed { class TaskNode; @@ -64,12 +68,34 @@ class Interceptor { bool Send(int64_t dst_id, InterceptorMessage& msg); // NOLINT + void SetPlace(const platform::Place& place) { place_ = place; } + + void SetRootScope(framework::Scope* scope) { root_scope_ = scope; } + void SetMiniBatchScope(framework::Scope* scope) { minibatch_scope_ = scope; } + void SetMicroBatchScope(const std::vector& scopes) { + microbatch_scopes_ = scopes; + } + + TaskNode* GetTaskNode() const { return node_; } + DISABLE_COPY_AND_ASSIGN(Interceptor); protected: - TaskNode* GetTaskNode() const { return node_; } + // interceptor id, handed from above layer + int64_t interceptor_id_; + + // node need to be handled by this interceptor + TaskNode* node_; + + // for stop bool stop_{false}; + // for runtime + platform::Place place_; + framework::Scope* root_scope_{nullptr}; + framework::Scope* minibatch_scope_{nullptr}; + std::vector microbatch_scopes_{}; + private: // pool the local mailbox, parse the Message void PoolTheMailbox(); @@ -78,12 +104,6 @@ class Interceptor { // return true if remote mailbox not empty, otherwise return false bool FetchRemoteMailbox(); - // interceptor id, handed from above layer - int64_t interceptor_id_; - - // node need to be handled by this interceptor - TaskNode* node_; - // interceptor handle which process message MsgHandle handle_{nullptr}; diff --git a/paddle/fluid/distributed/fleet_executor/task_node.h b/paddle/fluid/distributed/fleet_executor/task_node.h index ec2ea0c0093cd4bb8b8b8b9a4b5cfd0056dcec23..8f4f9d80c42a581cdb7ca0ac5764e9366e29e7ea 100644 --- a/paddle/fluid/distributed/fleet_executor/task_node.h +++ b/paddle/fluid/distributed/fleet_executor/task_node.h @@ -48,6 +48,7 @@ class TaskNode final { const std::unordered_set& downstream() const { return downstream_; } const std::string& type() const { return type_; } const paddle::framework::ProgramDesc& program() const { return program_; } + const std::vector& ops() const { return ops_; } bool AddUpstreamTask(int64_t task_id); bool AddDownstreamTask(int64_t task_id); diff --git a/paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt b/paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt index 7e6d887a2d0ed8afb352bd1102c2aca7482571dd..b0f00d7058476896fe5a9bc5ac1a1b098c14464e 100644 --- a/paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt +++ b/paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt @@ -1,7 +1,12 @@ set_source_files_properties(interceptor_ping_pong_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) -set_source_files_properties(compute_interceptor_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(interceptor_ping_pong_test SRCS interceptor_ping_pong_test.cc DEPS fleet_executor ${BRPC_DEPS}) + +set_source_files_properties(compute_interceptor_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(compute_interceptor_test SRCS compute_interceptor_test.cc DEPS fleet_executor ${BRPC_DEPS}) + +set_source_files_properties(compute_interceptor_run_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_test(compute_interceptor_run_op_test SRCS compute_interceptor_run_op_test.cc DEPS fleet_executor ${BRPC_DEPS} op_registry fill_constant_op elementwise_add_op scope device_context) + if(WITH_DISTRIBUTE AND WITH_PSCORE AND NOT (WITH_ASCEND OR WITH_ASCEND_CL)) set_source_files_properties(interceptor_ping_pong_with_brpc_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(interceptor_ping_pong_with_brpc_test SRCS interceptor_ping_pong_with_brpc_test.cc DEPS fleet_executor ${BRPC_DEPS}) diff --git a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..2d9776738f83184bfaadade01dd231712b3b6241 --- /dev/null +++ b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc @@ -0,0 +1,103 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "gtest/gtest.h" + +#include "paddle/fluid/distributed/fleet_executor/carrier.h" +#include "paddle/fluid/distributed/fleet_executor/interceptor.h" +#include "paddle/fluid/distributed/fleet_executor/message_bus.h" +#include "paddle/fluid/distributed/fleet_executor/task_node.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/program_desc.h" + +USE_OP(elementwise_add); +USE_OP(fill_constant); + +namespace paddle { +namespace distributed { + +std::vector GetOps() { + framework::AttributeMap attrs; + attrs["dtype"] = framework::proto::VarType::FP32; + attrs["shape"] = framework::vectorize({2, 3}); + attrs["value"] = 1.0f; + + auto zero_op = framework::OpRegistry::CreateOp("fill_constant", {}, + {{"Out", {"x"}}}, attrs); + + auto op = framework::OpRegistry::CreateOp( + "elementwise_add", {{"X", {"x"}}, {"Y", {"x"}}}, {{"Out", {"out"}}}, + framework::AttributeMap()); + + // NOTE: don't delete + return {zero_op.release(), op.release()}; +} + +framework::Scope* GetScope() { + framework::Scope* scope = new framework::Scope(); + + scope->Var("x")->GetMutable(); + scope->Var("out")->GetMutable(); + return scope; +} + +TEST(ComputeInterceptor, Compute) { + std::vector ops = GetOps(); + framework::Scope* scope = GetScope(); + std::vector scopes = {scope, scope}; + platform::Place place = platform::CPUPlace(); + + MessageBus& msg_bus = MessageBus::Instance(); + msg_bus.Init({{0, 0}, {1, 0}}, {{0, "127.0.0.0:0"}}, "127.0.0.0:0"); + + Carrier& carrier = Carrier::Instance(); + + // FIXME: don't delete, otherwise interceptor will use undefined node + TaskNode* node_a = + new TaskNode(0, ops, 0, 0, 2, 2); // role, ops, rank, task_id + TaskNode* node_b = new TaskNode(0, 0, 1, 0, 0); + + // a->b + node_a->AddDownstreamTask(1); + node_b->AddUpstreamTask(0); + + auto* a = carrier.SetInterceptor( + 0, InterceptorFactory::Create("Compute", 0, node_a)); + carrier.SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b)); + + a->SetPlace(place); + a->SetMicroBatchScope(scopes); + + carrier.SetCreatingFlag(false); + + // start + InterceptorMessage msg; + msg.set_message_type(DATA_IS_READY); + msg.set_src_id(-1); + msg.set_dst_id(0); + carrier.EnqueueInterceptorMessage(msg); + + // stop + InterceptorMessage stop; + stop.set_message_type(STOP); + stop.set_src_id(-1); + stop.set_dst_id(0); + carrier.EnqueueInterceptorMessage(stop); +} + +} // namespace distributed +} // namespace paddle