Unverified · Commit b2706b0c authored by LiYuRio, committed by GitHub

add cond interceptor (#50019)

Parent 8a69292b
@@ -36,6 +36,7 @@ cc_library(
     interceptor.cc
     compute_interceptor.cc
     amplifier_interceptor.cc
+    cond_interceptor.cc
     source_interceptor.cc
     sink_interceptor.cc
     message_service.cc
@@ -66,6 +67,8 @@ if(WITH_DISTRIBUTE)
   set_source_files_properties(
     amplifier_interceptor.cc PROPERTIES COMPILE_FLAGS
     ${DISTRIBUTE_COMPILE_FLAGS})
+  set_source_files_properties(
+    cond_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
   set_source_files_properties(
     source_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
   set_source_files_properties(
...
@@ -33,6 +33,7 @@ USE_INTERCEPTOR(Source);
 USE_INTERCEPTOR(Compute);
 USE_INTERCEPTOR(Amplifier);
 USE_INTERCEPTOR(Sink);
+USE_INTERCEPTOR(Cond);

 void Carrier::Init(
     int64_t rank,
@@ -96,29 +97,30 @@ void Carrier::CopyParameters(
     int microbatch_id,
     const framework::ProgramDesc& program,
     const std::vector<std::string>& inference_root_scope_vars) {
-  auto& global_block = program.Block(0);
   std::map<std::string, int> inference_root_scope_var_map;
   for (auto var_name : inference_root_scope_vars) {
     inference_root_scope_var_map.insert({var_name, 1});
   }
-  for (auto& var : global_block.AllVars()) {
-    std::string var_name = var->Name();
-    bool force_root = inference_root_scope_var_map.find(var_name) !=
-                      inference_root_scope_var_map.end();
-    if (force_root) {
-      VLOG(4) << var_name << " will be forced to be created in the root scope.";
-    }
-    if ((var->Persistable() || force_root) && microbatch_id == 0) {
-      auto* ptr = root_scope_->Var(var->Name());
-      InitializeVariable(ptr, var->GetType());
-      VLOG(5) << "Create persistable var: " << var->Name()
-              << ", which pointer is " << ptr;
-    } else if (!var->Persistable()) {
-      auto* ptr = microbatch_scopes_[microbatch_id]->Var(var->Name());
-      VLOG(5) << "Create variable " << var->Name() << " for microbatch "
-              << microbatch_id << ", which pointer is " << ptr << ".";
-      InitializeVariable(ptr, var->GetType());
+  for (size_t i = 0; i < program.Size(); ++i) {
+    for (auto& var : program.Block(i).AllVars()) {
+      std::string var_name = var->Name();
+      bool force_root = inference_root_scope_var_map.find(var_name) !=
+                        inference_root_scope_var_map.end();
+      if (force_root) {
+        VLOG(4) << var_name
+                << " will be forced to be created in the root scope.";
+      }
+      if ((var->Persistable() || force_root) && microbatch_id == 0) {
+        auto* ptr = root_scope_->Var(var->Name());
+        InitializeVariable(ptr, var->GetType());
+        VLOG(5) << "Create persistable var: " << var->Name()
+                << ", which pointer is " << ptr;
+      } else if (!var->Persistable()) {
+        auto* ptr = microbatch_scopes_[microbatch_id]->Var(var->Name());
+        VLOG(5) << "Create variable " << var->Name() << " for microbatch "
+                << microbatch_id << ", which pointer is " << ptr << ".";
+        InitializeVariable(ptr, var->GetType());
+      }
     }
   }
 }
...
@@ -125,6 +125,7 @@ void ComputeInterceptor::SendDataReadyToDownStream() {
     InterceptorMessage ready_msg;
     ready_msg.set_message_type(DATA_IS_READY);
+    ready_msg.set_scope_idx(cur_scope_id_);
     VLOG(3) << "ComputeInterceptor " << interceptor_id_
             << " Send data_is_ready msg to " << down_id
             << " in scope: " << cur_scope_id_;
@@ -152,6 +153,7 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
     InterceptorMessage reply_msg;
     reply_msg.set_message_type(DATA_IS_USELESS);
+    reply_msg.set_scope_idx(cur_scope_id_);
     Send(up_id, reply_msg);
   }
 }
...
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/cond_interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/errors.h"
namespace paddle {
namespace distributed {
CondInterceptor::CondInterceptor(int64_t interceptor_id, TaskNode* node)
: Interceptor(interceptor_id, node) {
PrepareDeps();
RegisterMsgHandle([this](const InterceptorMessage& msg) { Run(msg); });
}
void CondInterceptor::PrepareDeps() {
auto& upstream = node_->upstream();
auto& downstream = node_->downstream();
auto& id_to_dep_type = node_->id_to_dep_type();
for (const auto& up : upstream) {
if (id_to_dep_type.at(up.first) == DependType::NORMAL) {
normal_in_id_.insert(up.first);
}
}
for (const auto& down : downstream) {
if (id_to_dep_type.at(down.first) == DependType::NORMAL) {
normal_out_id_.insert(down.first);
} else if (id_to_dep_type.at(down.first) == DependType::STOP_LOOP) {
stop_loop_id_ = down.first;
}
}
}
bool CondInterceptor::GetCondResult() {
PADDLE_ENFORCE_LT(cur_scope_id_,
microbatch_scopes_.size(),
platform::errors::InvalidArgument(
"Step out of range. There are %ld "
"microbatch_scopes, but recevice scope index %ld",
microbatch_scopes_.size(),
cur_scope_id_));
auto* cond_var =
microbatch_scopes_[cur_scope_id_]->FindVar(node_->cond_var());
PADDLE_ENFORCE(cond_var,
platform::errors::NotFound(
"Condition variable %s not exists in scope %ld",
node_->cond_var(),
cur_scope_id_));
const auto& cond_tensor = cond_var->Get<phi::DenseTensor>();
bool res = false;
if (platform::is_gpu_place(cond_tensor.place())) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
phi::DenseTensor cpu_tensor;
framework::TensorCopy(cond_tensor, platform::CPUPlace(), &cpu_tensor);
platform::DeviceContextPool::Instance().Get(cond_tensor.place())->Wait();
res = cpu_tensor.data<bool>()[0];
#endif
} else if (platform::is_cpu_place(cond_tensor.place())) {
res = cond_tensor.data<bool>()[0];
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupport device for cond interceptor."));
}
return res;
}
void CondInterceptor::SendDataReady(int64_t down_id) {
InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_READY);
ready_msg.set_scope_idx(cur_scope_id_);
Send(down_id, ready_msg);
}
void CondInterceptor::ReplyDataIsUseless(int64_t up_id) {
InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_USELESS);
ready_msg.set_scope_idx(cur_scope_id_);
Send(up_id, ready_msg);
}
void CondInterceptor::Compute() {
cur_scope_id_ = ready_queue_.front();
ready_queue_.pop();
bool cond = GetCondResult();
VLOG(3) << "Cond interceptor get condition var " << node_->cond_var()
<< " with value " << cond;
if (cond) {
VLOG(3) << "Loop again in scope " << cur_scope_id_;
for (auto& down_id : normal_out_id_) {
SendDataReady(down_id);
}
} else {
VLOG(3) << "Finish loop in scope " << cur_scope_id_;
SendDataReady(stop_loop_id_);
}
}
void CondInterceptor::Run(const InterceptorMessage& msg) {
if (msg.message_type() == DATA_IS_READY) {
ready_queue_.push(msg.scope_idx());
Compute();
} else if (msg.message_type() == DATA_IS_USELESS) {
if (node_->id_to_dep_type().at(msg.src_id()) == DependType::STOP_LOOP) {
for (auto& up_id : normal_in_id_) {
ReplyDataIsUseless(up_id);
}
// Gc the variable in while block
int64_t scope_id = msg.scope_idx();
if (gc_) {
VLOG(3) << "Release vars in while block in scope " << scope_id;
framework::DeleteUnusedTensors(*microbatch_scopes_[scope_id],
node_->while_block_vars(),
gc_.get());
}
}
}
}
REGISTER_INTERCEPTOR(Cond, CondInterceptor);
} // namespace distributed
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <queue>
#include <set>

#include "paddle/fluid/distributed/fleet_executor/interceptor.h"

namespace paddle {
namespace distributed {

/* Condition Interceptor
 * This is a special interceptor whose task node contains only one condition
 * op. The interceptor has two kinds of downstreams:
 * 1. If the condition result is true, one downstream (the loop body) is
 *    selected; otherwise the other one is selected.
 * 2. It is used to implement the while op in the program.
 */
class CondInterceptor final : public Interceptor {
 public:
  CondInterceptor(int64_t interceptor_id, TaskNode* node);

 private:
  void PrepareDeps();
  void Run(const InterceptorMessage& msg);
  void Compute();
  bool GetCondResult();
  void SendDataReady(int64_t down_id);
  void ReplyDataIsUseless(int64_t up_id);

  std::queue<int64_t> ready_queue_;
  int64_t cur_scope_id_;
  std::set<int64_t> normal_in_id_;
  std::set<int64_t> normal_out_id_;
  int64_t stop_loop_id_;
};

}  // namespace distributed
}  // namespace paddle
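The topology described by the comment in this header can be sketched with the Python bindings this commit adds. The following is a minimal, illustrative sketch, not part of the commit: the task ids, buffer sizes, and the condition variable name "tmp_0" are assumptions mirroring the unit test added later in this change, and it assumes a Paddle build that contains these bindings. It shows how NORMAL, LOOP, and STOP_LOOP edges around a Cond node form a while loop:

import paddle
import paddle.fluid.core as core
from paddle.distributed.fleet.fleet_executor_utils import TaskNode

paddle.enable_static()

n = 1  # number of micro-batches (illustrative)
# The three positional arguments mirror the unit test below: (rank, num_micro_batches, 0).
cond_node = TaskNode(0, n, 0, node_type="Cond", task_id=1,
                     program=paddle.static.Program(),
                     cond_var_name="tmp_0", lazy_initialize=True)
body_node = TaskNode(0, n, 0, node_type="Compute", task_id=2,
                     program=paddle.static.Program(), lazy_initialize=True)
exit_node = TaskNode(0, n, 0, node_type="Compute", task_id=3,
                     program=paddle.static.Program(), lazy_initialize=True)

# While the condition holds, the Cond node feeds its NORMAL downstream (the loop body).
cond_node.add_downstream_task(body_node.task_id(), 100)
body_node.add_upstream_task(cond_node.task_id(), 100)
# The loop body reports back to the Cond node through a LOOP edge.
body_node.add_downstream_task(cond_node.task_id(), 100, core.DependType.LOOP)
cond_node.add_upstream_task(body_node.task_id(), 100, core.DependType.LOOP)
# Once the condition is false, the scope is handed to the STOP_LOOP downstream.
cond_node.add_downstream_task(exit_node.task_id(), 100, core.DependType.STOP_LOOP)
exit_node.add_upstream_task(cond_node.task_id(), 100, core.DependType.STOP_LOOP)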
@@ -66,12 +66,11 @@ void FleetExecutor::Init(
           "Fleet executor is inited with empty task node"));
   // TODO(fleet_exe devs): the unused_vars should be got from run time graph
   std::vector<std::unique_ptr<framework::OperatorBase>> ops;
-  for (auto task_node : task_nodes) {
-    for (auto op : task_node->ops()) {
-      ops.emplace_back(std::unique_ptr<framework::OperatorBase>(op));
-    }
+  for (const auto& desc : program_desc.Block(0).AllOps()) {
+    ops.emplace_back(framework::OpRegistry::CreateOp(*desc));
   }
   auto unused_vars = framework::GetUnusedVars(program_desc.Block(0), ops, {});
   // NOTE: For inference, the vars in inference_root_scope_vars
   // shouldn't be deleted during inf, for that they may be the result of the
   // inf. If they are GCed, it will cause error during ZeroCopy the result.
@@ -107,6 +106,25 @@ void FleetExecutor::Init(
   std::unordered_map<int64_t, TaskNode*> interceptor_id_to_task;
   for (auto task_node : task_nodes) {
     task_node->SetUnusedVars(unused_vars);
+    if (task_node->type() == "Cond") {
+      std::vector<std::string> while_block_vars;
+      std::vector<std::string> vars_in_parent;
+      std::vector<std::string> vars_in_sub;
+      for (auto& var : program_desc.Block(0).AllVars()) {
+        vars_in_parent.emplace_back(var->Name());
+      }
+      for (auto& var : program_desc.Block(1).AllVars()) {
+        vars_in_sub.emplace_back(var->Name());
+      }
+      std::sort(vars_in_parent.begin(), vars_in_parent.end());
+      std::sort(vars_in_sub.begin(), vars_in_sub.end());
+      std::set_difference(vars_in_sub.begin(),
+                          vars_in_sub.end(),
+                          vars_in_parent.begin(),
+                          vars_in_parent.end(),
+                          std::back_inserter(while_block_vars));
+      task_node->SetWhileBlockVars(while_block_vars);
+    }
     int64_t interceptor_id = task_node->task_id();
     interceptor_id_to_task.emplace(interceptor_id, task_node);
   }
...
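In the Init() hunk above, while_block_vars for a Cond task node is the set of variables that appear in the while sub-block (Block 1) but not in the parent block (Block 0); these are the variables the Cond interceptor later releases from a micro-batch scope once the loop has finished (the DATA_IS_USELESS branch in cond_interceptor.cc). A rough Python analogue of that std::set_difference step, with made-up variable names for illustration:

def while_block_vars(parent_vars, sub_vars):
    # Variables private to the while body are the ones safe to garbage-collect.
    return sorted(set(sub_vars) - set(parent_vars))

print(while_block_vars(["i", "ten", "tmp_0"], ["i", "ten", "tmp_1", "tmp_2"]))
# prints ['tmp_1', 'tmp_2']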
@@ -141,13 +141,19 @@ TaskNode::TaskNode(int32_t role,
       max_run_times_(max_run_times),
       max_slot_nums_(max_slot_nums) {}

-bool TaskNode::AddUpstreamTask(int64_t task_id, int64_t buff_size) {
+bool TaskNode::AddUpstreamTask(int64_t task_id,
+                               int64_t buff_size,
+                               DependType type) {
   const auto& ret = upstream_.emplace(task_id, buff_size);
+  id_to_dep_type_.emplace(task_id, type);
   return ret.second;
 }

-bool TaskNode::AddDownstreamTask(int64_t task_id, int64_t buff_size) {
+bool TaskNode::AddDownstreamTask(int64_t task_id,
+                                 int64_t buff_size,
+                                 DependType type) {
   const auto& ret = downstream_.emplace(task_id, buff_size);
+  id_to_dep_type_.emplace(task_id, type);
   return ret.second;
 }
...
@@ -14,8 +14,10 @@
 #pragma once

 #include <cstdint>
+#include <functional>
 #include <memory>
 #include <string>
+#include <unordered_map>
 #include <unordered_set>
 #include <vector>
@@ -29,6 +31,8 @@ class OpDesc;
 }  // namespace framework

 namespace distributed {

+enum class DependType { NORMAL, LOOP, STOP_LOOP };
+
 class TaskNode final {
  public:
   using OperatorBase = paddle::framework::OperatorBase;
@@ -61,6 +65,7 @@ class TaskNode final {
            int64_t rank,
            int64_t max_run_times,
            int64_t max_slot_nums);
+
   ~TaskNode() = default;

   void SetProgram(paddle::framework::ProgramDesc* program);
@@ -74,6 +79,7 @@ class TaskNode final {
   int64_t run_at_offset() const { return run_at_offset_; }
   int64_t reply_up_per_steps() const { return reply_up_per_steps_; }
   int64_t send_down_per_steps() const { return send_down_per_steps_; }
+  const std::string& cond_var() const { return cond_var_; }
   const std::unordered_map<int64_t, int64_t>& upstream() const {
     return upstream_;
   }
@@ -86,11 +92,20 @@ class TaskNode final {
   const std::vector<std::unique_ptr<OperatorBase>>& unique_ops() const {
     return ops_vec_;
   }
+  const std::unordered_map<int64_t, DependType> id_to_dep_type() const {
+    return id_to_dep_type_;
+  }
   const std::unordered_map<const OperatorBase*, std::vector<std::string>>&
   unused_vars() const {
     return unused_vars_;
   }
+  const std::vector<std::string> while_block_vars() const {
+    return while_block_vars_;
+  }
+  void SetCondVarName(const std::string& cond_var_name) {
+    cond_var_ = cond_var_name;
+  }
   void SetRunPerSteps(int64_t value);
   void SetRunAtOffset(int64_t value);
   void SetReplyUpPerSteps(int64_t value);
@@ -101,10 +116,17 @@ class TaskNode final {
                       unused_vars) {
     unused_vars_ = unused_vars;
   }
+  void SetWhileBlockVars(const std::vector<std::string>& vars) {
+    while_block_vars_ = vars;
+  }

   // upstream need buffs?
-  bool AddUpstreamTask(int64_t task_id, int64_t buff_size = 1);
-  bool AddDownstreamTask(int64_t task_id, int64_t buff_size = 1);
+  bool AddUpstreamTask(int64_t task_id,
+                       int64_t buff_size = 1,
+                       DependType type = DependType::NORMAL);
+  bool AddDownstreamTask(int64_t task_id,
+                         int64_t buff_size = 1,
+                         DependType type = DependType::NORMAL);
   std::string DebugString() const;

  private:
@@ -115,10 +137,15 @@ class TaskNode final {
   // task_id-->buff_size
   std::unordered_map<int64_t, int64_t> upstream_;
   std::unordered_map<int64_t, int64_t> downstream_;
+  // task_id-->type
+  std::unordered_map<int64_t, DependType> id_to_dep_type_;

   framework::ProgramDesc* program_;
+  std::string cond_var_;
   std::vector<std::unique_ptr<OperatorBase>> ops_vec_;
   std::unordered_map<const OperatorBase*, std::vector<std::string>>
       unused_vars_;
+  std::vector<std::string> while_block_vars_;

   int32_t role_;
   int64_t rank_;
...
@@ -65,6 +65,7 @@ struct npy_format_descriptor<paddle::platform::float16> {
 namespace paddle {
 namespace pybind {

+using paddle::distributed::DependType;
 using paddle::distributed::DistModel;
 using paddle::distributed::DistModelConfig;
 using paddle::distributed::DistModelDataBuf;
@@ -164,6 +165,11 @@ void BindFleetExecutor(py::module* m) {
       .def(
           "run", &FleetExecutor::Run, py::call_guard<py::gil_scoped_release>());

+  py::enum_<DependType>(*m, "DependType")
+      .value("NORMAL", DependType::NORMAL)
+      .value("LOOP", DependType::LOOP)
+      .value("STOP_LOOP", DependType::STOP_LOOP);
+
   py::class_<TaskNode>(*m, "TaskNode")
       .def(py::init<framework::ProgramDesc*,
                     int64_t,
@@ -183,6 +189,7 @@ void BindFleetExecutor(py::module* m) {
       .def("set_run_pre_steps", &TaskNode::SetRunPerSteps)
       .def("set_run_at_offset", &TaskNode::SetRunAtOffset)
       .def("set_type", &TaskNode::SetType)
+      .def("set_cond_var_name", &TaskNode::SetCondVarName)
       .def("role", &TaskNode::role)
       .def("init", [](TaskNode& self) { self.Init(); })
       .def("set_program", &TaskNode::SetProgram);
...
@@ -33,6 +33,7 @@ class TaskNode:
         ops=None,
         program=None,
         lazy_initialize=False,
+        cond_var_name=None,
     ):
         """
         :param rank (int): Current rank of the task node.
@@ -44,6 +45,7 @@ class TaskNode:
         :param ops (list): A list of op.desc to init the task node. (Will be removed in the future)
         :param program (Program): An instance of Program to init the task node.
         :param lazy_initialize (bool): In user-defined task, the program may change adding feed/fetch op. As efficient consideration, the task node will have the C++ object later.
+        :param cond_var_name (string): The name of the condition variable used by the while op.
         """
         assert (ops is not None) ^ (
             program is not None
@@ -58,6 +60,7 @@ class TaskNode:
         self.node_type = node_type
         self.program = program
         self.lazy_initialize = lazy_initialize
+        self.cond_var_name = cond_var_name
         self.run_pre_steps = None
         self.run_at_offset = None
         self.node = None
@@ -93,10 +96,12 @@ class TaskNode:
             self.node.set_run_pre_steps(self.run_pre_steps)
         if self.run_at_offset:
             self.node.set_run_at_offset(self.run_at_offset)
+        if self.cond_var_name:
+            self.node.set_cond_var_name(self.cond_var_name)
         for up in self.upstreams:
-            self.node.add_upstream_task(up[0], up[1])
+            self.node.add_upstream_task(up[0], up[1], up[2])
         for down in self.downstreams:
-            self.node.add_downstream_task(down[0], down[1])
+            self.node.add_downstream_task(down[0], down[1], down[2])
         self.lazy_initialize = False
         return self.node
@@ -124,17 +129,21 @@ class TaskNode:
         else:
             self.node.set_run_at_offset(offset)

-    def add_upstream_task(self, upstream, buffer_size=2):
+    def add_upstream_task(
+        self, upstream, buffer_size=2, depend_type=core.DependType.NORMAL
+    ):
         if self.lazy_initialize:
-            self.upstreams.append((upstream, buffer_size))
+            self.upstreams.append((upstream, buffer_size, depend_type))
         else:
-            self.node.add_upstream_task(upstream, buffer_size)
+            self.node.add_upstream_task(upstream, buffer_size, depend_type)

-    def add_downstream_task(self, downstream, buffer_size=2):
+    def add_downstream_task(
+        self, downstream, buffer_size=2, depend_type=core.DependType.NORMAL
+    ):
         if self.lazy_initialize:
-            self.downstreams.append((downstream, buffer_size))
+            self.downstreams.append((downstream, buffer_size, depend_type))
         else:
-            self.node.add_downstream_task(downstream, buffer_size)
+            self.node.add_downstream_task(downstream, buffer_size, depend_type)

     def task_id(self):
         return self.id
...
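To make the new cond_var_name and depend_type parameters concrete, here is a hedged usage sketch (ids, buffer sizes, and the variable name are illustrative and assume a Paddle build containing this commit). With lazy_initialize=True the wrapper only records the values: edges are stored as (id, buffer_size, depend_type) triples and, together with cond_var_name, are forwarded to the C++ TaskNode via add_upstream_task/add_downstream_task and set_cond_var_name the first time task_node() is called:

import paddle
import paddle.fluid.core as core
from paddle.distributed.fleet.fleet_executor_utils import TaskNode

paddle.enable_static()

task = TaskNode(
    0, 1, 0,  # positional arguments mirror the unit test below
    node_type="Cond",
    task_id=1,
    program=paddle.static.Program(),
    cond_var_name="tmp_0",  # name of the boolean condition variable
    lazy_initialize=True,
)
task.add_upstream_task(0, buffer_size=2)  # depend_type defaults to NORMAL
task.add_downstream_task(2, buffer_size=100, depend_type=core.DependType.STOP_LOOP)

# Nothing has touched C++ yet; only the Python-side state is recorded.
print(task.cond_var_name)  # "tmp_0"
print(task.upstreams)      # recorded (id, buffer_size, depend_type) triples
print(task.downstreams)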
@@ -100,6 +100,7 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
   list(REMOVE_ITEM TEST_OPS test_fleet_executor_task_node)
   list(REMOVE_ITEM TEST_OPS test_fleet_exe_dist_model_run)
   list(REMOVE_ITEM TEST_OPS test_fleet_exe_dist_model_tensor)
+  list(REMOVE_ITEM TEST_OPS test_fleet_executor_cond_interceptor)
 endif()
 list(REMOVE_ITEM TEST_OPS test_deprecated_decorator)
...
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import paddle
import paddle.fluid.core as core
from paddle.distributed.fleet.fleet_executor_utils import TaskNode

paddle.enable_static()


def cond(i, ten):
    return i < ten


def body(i, ten):
    i = i + 1
    return [i, ten]


class TestFleetExecutor(unittest.TestCase):
    def test_cond_interceptor(self):
        main_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        with paddle.static.program_guard(main_program, startup_program):
            i = paddle.full(
                shape=[1], fill_value=0, dtype='int64'
            )  # loop counter
            ten = paddle.full(
                shape=[1], fill_value=10, dtype='int64'
            )  # loop length
            i, ten = paddle.static.nn.while_loop(cond, body, [i, ten])

        program_a = paddle.static.Program()
        program_b = paddle.static.Program()

        for var_name in main_program.block(0).vars:
            if var_name != "_generated_var_0":
                var = main_program.block(0).var(var_name)
                program_a.block(0).create_var(
                    name=var_name,
                    shape=var.shape,
                    dtype=var.dtype,
                    stop_gradient=var.stop_gradient,
                )
                program_b.block(0).create_var(
                    name=var_name,
                    shape=var.shape,
                    dtype=var.dtype,
                    stop_gradient=var.stop_gradient,
                )

        for op in main_program.block(0).ops:
            if op.type != "while":
                program_a.block(0).append_op(
                    type=op.type,
                    inputs=op.desc.inputs(),
                    outputs=op.desc.outputs(),
                    attrs=op.all_attrs(),
                )

        for var_name in main_program.block(1).vars:
            var = main_program.block(1).var(var_name)
            program_b.block(0).create_var(
                name=var_name,
                shape=var.shape,
                dtype=var.dtype,
                stop_gradient=var.stop_gradient,
            )

        for op in main_program.block(1).ops:
            program_b.block(0).append_op(
                type=op.type,
                inputs=op.desc.inputs(),
                outputs=op.desc.outputs(),
                attrs=op.all_attrs(),
            )

        cond_var_name = "tmp_0"

        num_micro_batches = 3
        task_a = TaskNode(
            0,
            num_micro_batches,
            0,
            node_type="Compute",
            task_id=0,
            program=program_a,
            lazy_initialize=True,
        )
        task_b = TaskNode(
            0,
            num_micro_batches,
            0,
            node_type="Cond",
            task_id=1,
            program=paddle.static.Program(),
            cond_var_name=cond_var_name,
            lazy_initialize=True,
        )
        task_c = TaskNode(
            0,
            num_micro_batches,
            0,
            node_type="Compute",
            task_id=2,
            program=program_b,
            lazy_initialize=True,
        )
        task_d = TaskNode(
            0,
            num_micro_batches,
            0,
            node_type="Compute",
            task_id=3,
            program=paddle.static.Program(),
            lazy_initialize=True,
        )
        task_e = TaskNode(
            0,
            num_micro_batches,
            0,
            node_type="Compute",
            task_id=4,
            program=paddle.static.Program(),
            lazy_initialize=True,
        )

        task_a.add_downstream_task(task_b.task_id(), 2)
        task_b.add_upstream_task(task_a.task_id(), 2)
        task_b.add_downstream_task(task_c.task_id(), 100)
        task_c.add_upstream_task(task_b.task_id(), 100)
        task_c.add_downstream_task(task_d.task_id(), 2)
        task_d.add_upstream_task(task_c.task_id(), 2)
        task_d.add_downstream_task(task_b.task_id(), 100, core.DependType.LOOP)
        task_b.add_upstream_task(task_d.task_id(), 100, core.DependType.LOOP)
        task_b.add_downstream_task(
            task_e.task_id(), 100, core.DependType.STOP_LOOP
        )
        task_e.add_upstream_task(
            task_b.task_id(), 100, core.DependType.STOP_LOOP
        )

        main_program._pipeline_opt = {
            "fleet_opt": {
                'tasks': [task_a, task_b, task_c, task_d, task_e],
                'task_id_to_rank': {
                    task_a.task_id(): 0,
                    task_b.task_id(): 0,
                    task_c.task_id(): 0,
                    task_d.task_id(): 0,
                    task_e.task_id(): 0,
                },
                'num_micro_batches': num_micro_batches,
            },
        }

        place = paddle.fluid.CUDAPlace(0)
        exe = paddle.fluid.Executor(place)
        exe.run(main_program)


if __name__ == "__main__":
    unittest.main()
@@ -31,9 +31,15 @@ class TestFleetExecutorTaskNode(unittest.TestCase):
         self.assertEqual(task_node_1.task_id(), 1)
         self.assertEqual(task_node_2.task_id(), 2)
         self.assertTrue(
-            task_node_0.add_downstream_task(task_node_1.task_id(), 1)
+            task_node_0.add_downstream_task(
+                task_node_1.task_id(), 1, core.DependType.NORMAL
+            )
+        )
+        self.assertTrue(
+            task_node_1.add_upstream_task(
+                task_node_0.task_id(), 1, core.DependType.NORMAL
+            )
         )
-        self.assertTrue(task_node_1.add_upstream_task(task_node_0.task_id(), 1))

     def test_lazy_task_node(self):
         program = paddle.static.Program()
...