Unverified commit b2706b0c, authored by LiYuRio and committed by GitHub

add cond interceptor (#50019)

Parent 8a69292b
......@@ -36,6 +36,7 @@ cc_library(
interceptor.cc
compute_interceptor.cc
amplifier_interceptor.cc
cond_interceptor.cc
source_interceptor.cc
sink_interceptor.cc
message_service.cc
......@@ -66,6 +67,8 @@ if(WITH_DISTRIBUTE)
set_source_files_properties(
amplifier_interceptor.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
cond_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
source_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
......
......@@ -33,6 +33,7 @@ USE_INTERCEPTOR(Source);
USE_INTERCEPTOR(Compute);
USE_INTERCEPTOR(Amplifier);
USE_INTERCEPTOR(Sink);
USE_INTERCEPTOR(Cond);
void Carrier::Init(
int64_t rank,
......@@ -96,29 +97,30 @@ void Carrier::CopyParameters(
int microbatch_id,
const framework::ProgramDesc& program,
const std::vector<std::string>& inference_root_scope_vars) {
auto& global_block = program.Block(0);
std::map<std::string, int> inference_root_scope_var_map;
for (auto var_name : inference_root_scope_vars) {
inference_root_scope_var_map.insert({var_name, 1});
}
for (auto& var : global_block.AllVars()) {
std::string var_name = var->Name();
bool force_root = inference_root_scope_var_map.find(var_name) !=
inference_root_scope_var_map.end();
if (force_root) {
VLOG(4) << var_name << " will be forced to be created in the root scope.";
}
if ((var->Persistable() || force_root) && microbatch_id == 0) {
auto* ptr = root_scope_->Var(var->Name());
InitializeVariable(ptr, var->GetType());
VLOG(5) << "Create persistable var: " << var->Name()
<< ", which pointer is " << ptr;
} else if (!var->Persistable()) {
auto* ptr = microbatch_scopes_[microbatch_id]->Var(var->Name());
VLOG(5) << "Create variable " << var->Name() << " for microbatch "
<< microbatch_id << ", which pointer is " << ptr << ".";
InitializeVariable(ptr, var->GetType());
for (size_t i = 0; i < program.Size(); ++i) {
for (auto& var : program.Block(i).AllVars()) {
std::string var_name = var->Name();
bool force_root = inference_root_scope_var_map.find(var_name) !=
inference_root_scope_var_map.end();
if (force_root) {
VLOG(4) << var_name
<< " will be forced to be created in the root scope.";
}
if ((var->Persistable() || force_root) && microbatch_id == 0) {
auto* ptr = root_scope_->Var(var->Name());
InitializeVariable(ptr, var->GetType());
VLOG(5) << "Create persistable var: " << var->Name()
<< ", which pointer is " << ptr;
} else if (!var->Persistable()) {
auto* ptr = microbatch_scopes_[microbatch_id]->Var(var->Name());
VLOG(5) << "Create variable " << var->Name() << " for microbatch "
<< microbatch_id << ", which pointer is " << ptr << ".";
InitializeVariable(ptr, var->GetType());
}
}
}
}
......
......@@ -125,6 +125,7 @@ void ComputeInterceptor::SendDataReadyToDownStream() {
InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_READY);
ready_msg.set_scope_idx(cur_scope_id_);
VLOG(3) << "ComputeInterceptor " << interceptor_id_
<< " Send data_is_ready msg to " << down_id
<< " in scope: " << cur_scope_id_;
......@@ -152,6 +153,7 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
InterceptorMessage reply_msg;
reply_msg.set_message_type(DATA_IS_USELESS);
reply_msg.set_scope_idx(cur_scope_id_);
Send(up_id, reply_msg);
}
}
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/cond_interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/errors.h"
namespace paddle {
namespace distributed {
CondInterceptor::CondInterceptor(int64_t interceptor_id, TaskNode* node)
: Interceptor(interceptor_id, node) {
PrepareDeps();
RegisterMsgHandle([this](const InterceptorMessage& msg) { Run(msg); });
}
void CondInterceptor::PrepareDeps() {
auto& upstream = node_->upstream();
auto& downstream = node_->downstream();
auto& id_to_dep_type = node_->id_to_dep_type();
for (const auto& up : upstream) {
if (id_to_dep_type.at(up.first) == DependType::NORMAL) {
normal_in_id_.insert(up.first);
}
}
for (const auto& down : downstream) {
if (id_to_dep_type.at(down.first) == DependType::NORMAL) {
normal_out_id_.insert(down.first);
} else if (id_to_dep_type.at(down.first) == DependType::STOP_LOOP) {
stop_loop_id_ = down.first;
}
}
}
bool CondInterceptor::GetCondResult() {
PADDLE_ENFORCE_LT(cur_scope_id_,
microbatch_scopes_.size(),
platform::errors::InvalidArgument(
"Step out of range. There are %ld "
"microbatch_scopes, but recevice scope index %ld",
microbatch_scopes_.size(),
cur_scope_id_));
auto* cond_var =
microbatch_scopes_[cur_scope_id_]->FindVar(node_->cond_var());
PADDLE_ENFORCE(cond_var,
platform::errors::NotFound(
"Condition variable %s not exists in scope %ld",
node_->cond_var(),
cur_scope_id_));
const auto& cond_tensor = cond_var->Get<phi::DenseTensor>();
bool res = false;
if (platform::is_gpu_place(cond_tensor.place())) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
phi::DenseTensor cpu_tensor;
framework::TensorCopy(cond_tensor, platform::CPUPlace(), &cpu_tensor);
platform::DeviceContextPool::Instance().Get(cond_tensor.place())->Wait();
res = cpu_tensor.data<bool>()[0];
#endif
} else if (platform::is_cpu_place(cond_tensor.place())) {
res = cond_tensor.data<bool>()[0];
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupport device for cond interceptor."));
}
return res;
}
void CondInterceptor::SendDataReady(int64_t down_id) {
InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_READY);
ready_msg.set_scope_idx(cur_scope_id_);
Send(down_id, ready_msg);
}
void CondInterceptor::ReplyDataIsUseless(int64_t up_id) {
InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_USELESS);
ready_msg.set_scope_idx(cur_scope_id_);
Send(up_id, ready_msg);
}
void CondInterceptor::Compute() {
cur_scope_id_ = ready_queue_.front();
ready_queue_.pop();
bool cond = GetCondResult();
VLOG(3) << "Cond interceptor get condition var " << node_->cond_var()
<< " with value " << cond;
if (cond) {
VLOG(3) << "Loop again in scope " << cur_scope_id_;
for (auto& down_id : normal_out_id_) {
SendDataReady(down_id);
}
} else {
VLOG(3) << "Finish loop in scope " << cur_scope_id_;
SendDataReady(stop_loop_id_);
}
}
void CondInterceptor::Run(const InterceptorMessage& msg) {
if (msg.message_type() == DATA_IS_READY) {
ready_queue_.push(msg.scope_idx());
Compute();
} else if (msg.message_type() == DATA_IS_USELESS) {
if (node_->id_to_dep_type().at(msg.src_id()) == DependType::STOP_LOOP) {
for (auto& up_id : normal_in_id_) {
ReplyDataIsUseless(up_id);
}
// GC the variables in the while block
int64_t scope_id = msg.scope_idx();
if (gc_) {
VLOG(3) << "Release vars in while block in scope " << scope_id;
framework::DeleteUnusedTensors(*microbatch_scopes_[scope_id],
node_->while_block_vars(),
gc_.get());
}
}
}
}
REGISTER_INTERCEPTOR(Cond, CondInterceptor);
} // namespace distributed
} // namespace paddle
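Taken together, Run and Compute implement a simple dispatch: on DATA_IS_READY the interceptor pops a scope id, evaluates the condition variable, and sends DATA_IS_READY either to the normal downstreams (loop again) or to the STOP_LOOP downstream (loop finished); when the STOP_LOOP node answers DATA_IS_USELESS, it replies upstream and GCs the while-block variables. A minimal pure-Python sketch of that branching, not the Paddle API (ids and names are illustrative):
from collections import deque

def cond_dispatch(ready_queue, cond_value, normal_out_ids, stop_loop_id):
    # Pop one scope id and pick the downstreams that get DATA_IS_READY,
    # mirroring CondInterceptor::Compute above.
    scope_id = ready_queue.popleft()
    targets = sorted(normal_out_ids) if cond_value else [stop_loop_id]
    return [(target, "DATA_IS_READY", scope_id) for target in targets]

print(cond_dispatch(deque([0]), True, normal_out_ids={2}, stop_loop_id=4))   # loop again
print(cond_dispatch(deque([0]), False, normal_out_ids={2}, stop_loop_id=4))  # leave the loop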
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <queue>
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
namespace paddle {
namespace distributed {
/* Condition Interceptor
 * This is a special interceptor whose task node contains only the condition
 * of a while op.
 * The interceptor has two kinds of downstreams:
 * 1. If the condition result is true, it selects the normal downstreams;
 *    otherwise it selects the STOP_LOOP downstream.
 * 2. It is used to implement the while op in the program.
 */
class CondInterceptor final : public Interceptor {
public:
CondInterceptor(int64_t interceptor_id, TaskNode* node);
private:
void PrepareDeps();
void Run(const InterceptorMessage& msg);
void Compute();
bool GetCondResult();
void SendDataReady(int64_t down_id);
void ReplyDataIsUseless(int64_t up_id);
std::queue<int64_t> ready_queue_;
int64_t cur_scope_id_;
std::set<int64_t> normal_in_id_;
std::set<int64_t> normal_out_id_;
int64_t stop_loop_id_;
};
} // namespace distributed
} // namespace paddle
......@@ -66,12 +66,11 @@ void FleetExecutor::Init(
"Fleet executor is inited with empty task node"));
// TODO(fleet_exe devs): the unused_vars should be got from run time graph
std::vector<std::unique_ptr<framework::OperatorBase>> ops;
for (auto task_node : task_nodes) {
for (auto op : task_node->ops()) {
ops.emplace_back(std::unique_ptr<framework::OperatorBase>(op));
}
for (const auto& desc : program_desc.Block(0).AllOps()) {
ops.emplace_back(framework::OpRegistry::CreateOp(*desc));
}
auto unused_vars = framework::GetUnusedVars(program_desc.Block(0), ops, {});
// NOTE: For inference, the vars in inference_root_scope_vars
// should not be deleted during inference, because they may hold the
// inference results. If they are GCed, ZeroCopy of the results will fail.
......@@ -107,6 +106,25 @@ void FleetExecutor::Init(
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_task;
for (auto task_node : task_nodes) {
task_node->SetUnusedVars(unused_vars);
if (task_node->type() == "Cond") {
std::vector<std::string> while_block_vars;
std::vector<std::string> vars_in_parent;
std::vector<std::string> vars_in_sub;
for (auto& var : program_desc.Block(0).AllVars()) {
vars_in_parent.emplace_back(var->Name());
}
for (auto& var : program_desc.Block(1).AllVars()) {
vars_in_sub.emplace_back(var->Name());
}
std::sort(vars_in_parent.begin(), vars_in_parent.end());
std::sort(vars_in_sub.begin(), vars_in_sub.end());
std::set_difference(vars_in_sub.begin(),
vars_in_sub.end(),
vars_in_parent.begin(),
vars_in_parent.end(),
std::back_inserter(while_block_vars));
task_node->SetWhileBlockVars(while_block_vars);
}
int64_t interceptor_id = task_node->task_id();
interceptor_id_to_task.emplace(interceptor_id, task_node);
}
......
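The while_block_vars computed above are the names that appear in the while sub-block (program_desc.Block(1)) but not in the parent block (Block(0)); these are exactly the variables the Cond interceptor later GCs once a loop iteration finishes. A pure-Python sketch of the same set difference (variable names are made up for illustration):
# Equivalent of the std::sort + std::set_difference over the two name lists.
vars_in_parent = sorted(["i", "ten", "tmp_0"])
vars_in_sub = sorted(["i", "ten", "tmp_0", "tmp_1"])
while_block_vars = sorted(set(vars_in_sub) - set(vars_in_parent))
print(while_block_vars)  # only the names that exist purely inside the sub-block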
......@@ -141,13 +141,19 @@ TaskNode::TaskNode(int32_t role,
max_run_times_(max_run_times),
max_slot_nums_(max_slot_nums) {}
bool TaskNode::AddUpstreamTask(int64_t task_id, int64_t buff_size) {
bool TaskNode::AddUpstreamTask(int64_t task_id,
int64_t buff_size,
DependType type) {
const auto& ret = upstream_.emplace(task_id, buff_size);
id_to_dep_type_.emplace(task_id, type);
return ret.second;
}
bool TaskNode::AddDownstreamTask(int64_t task_id, int64_t buff_size) {
bool TaskNode::AddDownstreamTask(int64_t task_id,
int64_t buff_size,
DependType type) {
const auto& ret = downstream_.emplace(task_id, buff_size);
id_to_dep_type_.emplace(task_id, type);
return ret.second;
}
......
......@@ -14,8 +14,10 @@
#pragma once
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
......@@ -29,6 +31,8 @@ class OpDesc;
} // namespace framework
namespace distributed {
enum class DependType { NORMAL, LOOP, STOP_LOOP };
class TaskNode final {
public:
using OperatorBase = paddle::framework::OperatorBase;
......@@ -61,6 +65,7 @@ class TaskNode final {
int64_t rank,
int64_t max_run_times,
int64_t max_slot_nums);
~TaskNode() = default;
void SetProgram(paddle::framework::ProgramDesc* program);
......@@ -74,6 +79,7 @@ class TaskNode final {
int64_t run_at_offset() const { return run_at_offset_; }
int64_t reply_up_per_steps() const { return reply_up_per_steps_; }
int64_t send_down_per_steps() const { return send_down_per_steps_; }
const std::string& cond_var() const { return cond_var_; }
const std::unordered_map<int64_t, int64_t>& upstream() const {
return upstream_;
}
......@@ -86,11 +92,20 @@ class TaskNode final {
const std::vector<std::unique_ptr<OperatorBase>>& unique_ops() const {
return ops_vec_;
}
const std::unordered_map<int64_t, DependType>& id_to_dep_type() const {
return id_to_dep_type_;
}
const std::unordered_map<const OperatorBase*, std::vector<std::string>>&
unused_vars() const {
return unused_vars_;
}
const std::vector<std::string>& while_block_vars() const {
return while_block_vars_;
}
void SetCondVarName(const std::string& cond_var_name) {
cond_var_ = cond_var_name;
}
void SetRunPerSteps(int64_t value);
void SetRunAtOffset(int64_t value);
void SetReplyUpPerSteps(int64_t value);
......@@ -101,10 +116,17 @@ class TaskNode final {
unused_vars) {
unused_vars_ = unused_vars;
}
void SetWhileBlockVars(const std::vector<std::string>& vars) {
while_block_vars_ = vars;
}
// Does the upstream need buffers?
bool AddUpstreamTask(int64_t task_id, int64_t buff_size = 1);
bool AddDownstreamTask(int64_t task_id, int64_t buff_size = 1);
bool AddUpstreamTask(int64_t task_id,
int64_t buff_size = 1,
DependType type = DependType::NORMAL);
bool AddDownstreamTask(int64_t task_id,
int64_t buff_size = 1,
DependType type = DependType::NORMAL);
std::string DebugString() const;
private:
......@@ -115,10 +137,15 @@ class TaskNode final {
// task_id-->buff_size
std::unordered_map<int64_t, int64_t> upstream_;
std::unordered_map<int64_t, int64_t> downstream_;
// task_id-->type
std::unordered_map<int64_t, DependType> id_to_dep_type_;
framework::ProgramDesc* program_;
std::string cond_var_;
std::vector<std::unique_ptr<OperatorBase>> ops_vec_;
std::unordered_map<const OperatorBase*, std::vector<std::string>>
unused_vars_;
std::vector<std::string> while_block_vars_;
int32_t role_;
int64_t rank_;
......
......@@ -65,6 +65,7 @@ struct npy_format_descriptor<paddle::platform::float16> {
namespace paddle {
namespace pybind {
using paddle::distributed::DependType;
using paddle::distributed::DistModel;
using paddle::distributed::DistModelConfig;
using paddle::distributed::DistModelDataBuf;
......@@ -164,6 +165,11 @@ void BindFleetExecutor(py::module* m) {
.def(
"run", &FleetExecutor::Run, py::call_guard<py::gil_scoped_release>());
py::enum_<DependType>(*m, "DependType")
.value("NORMAL", DependType::NORMAL)
.value("LOOP", DependType::LOOP)
.value("STOP_LOOP", DependType::STOP_LOOP);
py::class_<TaskNode>(*m, "TaskNode")
.def(py::init<framework::ProgramDesc*,
int64_t,
......@@ -183,6 +189,7 @@ void BindFleetExecutor(py::module* m) {
.def("set_run_pre_steps", &TaskNode::SetRunPerSteps)
.def("set_run_at_offset", &TaskNode::SetRunAtOffset)
.def("set_type", &TaskNode::SetType)
.def("set_cond_var_name", &TaskNode::SetCondVarName)
.def("role", &TaskNode::role)
.def("init", [](TaskNode& self) { self.Init(); })
.def("set_program", &TaskNode::SetProgram);
......
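With these bindings, the new enum is reachable from Python as paddle.fluid.core.DependType and is passed through TaskNode.add_upstream_task / add_downstream_task (see the Python wrapper and the unit test below). A quick sanity check, assuming a Paddle build that compiles the fleet executor:
import paddle.fluid.core as core

# pybind11 enums expose their values through __members__.
print(list(core.DependType.__members__))  # expected: ['NORMAL', 'LOOP', 'STOP_LOOP']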
......@@ -33,6 +33,7 @@ class TaskNode:
ops=None,
program=None,
lazy_initialize=False,
cond_var_name=None,
):
"""
:param rank (int): Current rank of the task node.
......@@ -44,6 +45,7 @@ class TaskNode:
:param ops (list): A list of op.desc to init the task node. (Will be removed in the future)
:param program (Program): An instance of Program to init the task node.
:param lazy_initialize (bool): In a user-defined task, the program may change (e.g. by adding feed/fetch ops). For efficiency, the task node creates its C++ object lazily.
:param cond_var_name (string): The name of the condition variable of the while op.
"""
assert (ops is not None) ^ (
program is not None
......@@ -58,6 +60,7 @@ class TaskNode:
self.node_type = node_type
self.program = program
self.lazy_initialize = lazy_initialize
self.cond_var_name = cond_var_name
self.run_pre_steps = None
self.run_at_offset = None
self.node = None
......@@ -93,10 +96,12 @@ class TaskNode:
self.node.set_run_pre_steps(self.run_pre_steps)
if self.run_at_offset:
self.node.set_run_at_offset(self.run_at_offset)
if self.cond_var_name:
self.node.set_cond_var_name(self.cond_var_name)
for up in self.upstreams:
self.node.add_upstream_task(up[0], up[1])
self.node.add_upstream_task(up[0], up[1], up[2])
for down in self.downstreams:
self.node.add_downstream_task(down[0], down[1])
self.node.add_downstream_task(down[0], down[1], down[2])
self.lazy_initialize = False
return self.node
......@@ -124,17 +129,21 @@ class TaskNode:
else:
self.node.set_run_at_offset(offset)
def add_upstream_task(self, upstream, buffer_size=2):
def add_upstream_task(
self, upstream, buffer_size=2, depend_type=core.DependType.NORMAL
):
if self.lazy_initialize:
self.upstreams.append((upstream, buffer_size))
self.upstreams.append((upstream, buffer_size, depend_type))
else:
self.node.add_upstream_task(upstream, buffer_size)
self.node.add_upstream_task(upstream, buffer_size, depend_type)
def add_downstream_task(self, downstream, buffer_size=2):
def add_downstream_task(
self, downstream, buffer_size=2, depend_type=core.DependType.NORMAL
):
if self.lazy_initialize:
self.downstreams.append((downstream, buffer_size))
self.downstreams.append((downstream, buffer_size, depend_type))
else:
self.node.add_downstream_task(downstream, buffer_size)
self.node.add_downstream_task(downstream, buffer_size, depend_type)
def task_id(self):
return self.id
......
......@@ -100,6 +100,7 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
list(REMOVE_ITEM TEST_OPS test_fleet_executor_task_node)
list(REMOVE_ITEM TEST_OPS test_fleet_exe_dist_model_run)
list(REMOVE_ITEM TEST_OPS test_fleet_exe_dist_model_tensor)
list(REMOVE_ITEM TEST_OPS test_fleet_executor_cond_interceptor)
endif()
list(REMOVE_ITEM TEST_OPS test_deprecated_decorator)
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import paddle.fluid.core as core
from paddle.distributed.fleet.fleet_executor_utils import TaskNode
paddle.enable_static()
def cond(i, ten):
return i < ten
def body(i, ten):
i = i + 1
return [i, ten]
class TestFleetExecutor(unittest.TestCase):
def test_cond_interceptor(self):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
i = paddle.full(
shape=[1], fill_value=0, dtype='int64'
) # loop counter
ten = paddle.full(
shape=[1], fill_value=10, dtype='int64'
) # loop length
i, ten = paddle.static.nn.while_loop(cond, body, [i, ten])
program_a = paddle.static.Program()
program_b = paddle.static.Program()
for var_name in main_program.block(0).vars:
if var_name != "_generated_var_0":
var = main_program.block(0).var(var_name)
program_a.block(0).create_var(
name=var_name,
shape=var.shape,
dtype=var.dtype,
stop_gradient=var.stop_gradient,
)
program_b.block(0).create_var(
name=var_name,
shape=var.shape,
dtype=var.dtype,
stop_gradient=var.stop_gradient,
)
for op in main_program.block(0).ops:
if op.type != "while":
program_a.block(0).append_op(
type=op.type,
inputs=op.desc.inputs(),
outputs=op.desc.outputs(),
attrs=op.all_attrs(),
)
for var_name in main_program.block(1).vars:
var = main_program.block(1).var(var_name)
program_b.block(0).create_var(
name=var_name,
shape=var.shape,
dtype=var.dtype,
stop_gradient=var.stop_gradient,
)
for op in main_program.block(1).ops:
program_b.block(0).append_op(
type=op.type,
inputs=op.desc.inputs(),
outputs=op.desc.outputs(),
attrs=op.all_attrs(),
)
cond_var_name = "tmp_0"
num_micro_batches = 3
task_a = TaskNode(
0,
num_micro_batches,
0,
node_type="Compute",
task_id=0,
program=program_a,
lazy_initialize=True,
)
task_b = TaskNode(
0,
num_micro_batches,
0,
node_type="Cond",
task_id=1,
program=paddle.static.Program(),
cond_var_name=cond_var_name,
lazy_initialize=True,
)
task_c = TaskNode(
0,
num_micro_batches,
0,
node_type="Compute",
task_id=2,
program=program_b,
lazy_initialize=True,
)
task_d = TaskNode(
0,
num_micro_batches,
0,
node_type="Compute",
task_id=3,
program=paddle.static.Program(),
lazy_initialize=True,
)
task_e = TaskNode(
0,
num_micro_batches,
0,
node_type="Compute",
task_id=4,
program=paddle.static.Program(),
lazy_initialize=True,
)
task_a.add_downstream_task(task_b.task_id(), 2)
task_b.add_upstream_task(task_a.task_id(), 2)
task_b.add_downstream_task(task_c.task_id(), 100)
task_c.add_upstream_task(task_b.task_id(), 100)
task_c.add_downstream_task(task_d.task_id(), 2)
task_d.add_upstream_task(task_c.task_id(), 2)
task_d.add_downstream_task(task_b.task_id(), 100, core.DependType.LOOP)
task_b.add_upstream_task(task_d.task_id(), 100, core.DependType.LOOP)
task_b.add_downstream_task(
task_e.task_id(), 100, core.DependType.STOP_LOOP
)
task_e.add_upstream_task(
task_b.task_id(), 100, core.DependType.STOP_LOOP
)
main_program._pipeline_opt = {
"fleet_opt": {
'tasks': [task_a, task_b, task_c, task_d, task_e],
'task_id_to_rank': {
task_a.task_id(): 0,
task_b.task_id(): 0,
task_c.task_id(): 0,
task_d.task_id(): 0,
task_e.task_id(): 0,
},
'num_micro_batches': num_micro_batches,
},
}
place = paddle.fluid.CUDAPlace(0)
exe = paddle.fluid.Executor(place)
exe.run(main_program)
if __name__ == "__main__":
unittest.main()
......@@ -31,9 +31,15 @@ class TestFleetExecutorTaskNode(unittest.TestCase):
self.assertEqual(task_node_1.task_id(), 1)
self.assertEqual(task_node_2.task_id(), 2)
self.assertTrue(
task_node_0.add_downstream_task(task_node_1.task_id(), 1)
task_node_0.add_downstream_task(
task_node_1.task_id(), 1, core.DependType.NORMAL
)
)
self.assertTrue(
task_node_1.add_upstream_task(
task_node_0.task_id(), 1, core.DependType.NORMAL
)
)
self.assertTrue(task_node_1.add_upstream_task(task_node_0.task_id(), 1))
def test_lazy_task_node(self):
program = paddle.static.Program()
......