未验证 提交 2273471d 编写于 作者: L LiYuRio 提交者: GitHub

[fleet_executor] Support multi carriers (#38650)

上级 2d2609ea
......@@ -30,11 +30,9 @@ USE_INTERCEPTOR(Amplifier);
void Carrier::Init(
int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_set<int64_t>& interceptor_ids) {
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank) {
rank_ = rank;
interceptor_id_to_rank_ = interceptor_id_to_rank;
interceptor_ids_ = interceptor_ids;
// TODO(fleet_exe dev): thread pool
thread_num_ = 1;
......@@ -45,14 +43,12 @@ void Carrier::Init(
void Carrier::Init(
int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_set<int64_t>& interceptor_ids,
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
framework::Scope* root_scope, framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place) {
rank_ = rank;
interceptor_id_to_rank_ = interceptor_id_to_rank;
interceptor_ids_ = interceptor_ids;
interceptor_id_to_node_ = interceptor_id_to_node;
minibatch_scope_ = minibatch_scope;
microbatch_scopes_ = microbatch_scopes;
......@@ -156,9 +152,7 @@ bool Carrier::Send(const InterceptorMessage& msg) {
if (src_rank == dst_rank) {
VLOG(3) << "Send a message from interceptor " << src_id
<< " to interceptor " << dst_id << ", which are in the same ranks.";
int64_t carrier_id = *GlobalMap<int64_t, int64_t>::Get(dst_id);
return GlobalMap<int64_t, Carrier>::Get(carrier_id)
->EnqueueInterceptorMessage(msg);
return EnqueueInterceptorMessage(msg);
} else {
PADDLE_ENFORCE_NOT_NULL(
msg_bus_.get(),
......@@ -192,9 +186,6 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
loop, platform::errors::Fatal("thread task loop must not null"));
interceptor->RegisterTaskLoop(loop);
// TODO(liyurui): Using struct InterceptorID replace int64_t
GlobalMap<int64_t, int64_t>::Create(interceptor_id, carrier_id_);
auto* ptr = interceptor.get();
interceptor_idx_to_interceptor_.insert(
std::make_pair(interceptor_id, std::move(interceptor)));
......@@ -220,19 +211,15 @@ static std::shared_ptr<framework::GarbageCollector> GetGC(
}
void Carrier::CreateInterceptors() {
if (interceptor_ids_.empty()) return;
if (interceptor_id_to_node_.empty()) return;
auto gc = GetGC(place_);
// create each Interceptor
// no auto init since there is no config
for (int64_t interceptor_id : interceptor_ids_) {
const auto& task_node_iter = interceptor_id_to_node_.find(interceptor_id);
PADDLE_ENFORCE_NE(
task_node_iter, interceptor_id_to_node_.end(),
platform::errors::NotFound("Can not find task node for interceptor %ld",
interceptor_id));
TaskNode* task_node = task_node_iter->second;
for (const auto& item : interceptor_id_to_node_) {
int64_t interceptor_id = item.first;
TaskNode* task_node = item.second;
PADDLE_ENFORCE_LT(
task_node->run_at_offset(), task_node->run_per_steps(),
......
......@@ -43,17 +43,17 @@ class InterceptorMessageServiceImpl;
class RuntimeGraph;
class MessageBus;
// TODO(liyurui): Add CarrierId instead of std::string
class Carrier final {
public:
explicit Carrier(int64_t carrier_id) : carrier_id_(carrier_id) {}
explicit Carrier(const std::string& carrier_id) : carrier_id_(carrier_id) {}
~Carrier();
void Init(int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_set<int64_t>& interceptor_ids);
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank);
void Init(
int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_set<int64_t>& interceptor_ids,
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
framework::Scope* root_scope, framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes,
......@@ -109,7 +109,7 @@ class Carrier final {
paddle::platform::DeviceContext* dev_ctx_{nullptr};
std::shared_ptr<MessageBus> msg_bus_;
int64_t rank_;
int64_t carrier_id_;
std::string carrier_id_;
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_node_;
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_;
int thread_num_;
......
......@@ -36,14 +36,15 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
FleetExecutor::~FleetExecutor() {
root_scope_->DropKids();
for (const auto& item : runtime_graph_->carrier_id_to_interceptor_ids()) {
GlobalMap<int64_t, Carrier>::Get(item.first)->Release();
for (const auto& carrier_id : carrier_ids_) {
GlobalMap<std::string, Carrier>::Get(carrier_id)->Release();
}
}
void FleetExecutor::Init(
const framework::ProgramDesc& program_desc, framework::Scope* scope,
const platform::Place& place, const std::vector<TaskNode*>& task_nodes,
const std::string& carrier_id, const framework::ProgramDesc& program_desc,
framework::Scope* scope, const platform::Place& place,
const std::vector<TaskNode*>& task_nodes,
const std::unordered_map<int64_t, int64_t>& task_id_to_rank) {
PADDLE_ENFORCE_GT(task_nodes.size(), 0,
platform::errors::InvalidArgument(
......@@ -58,19 +59,13 @@ void FleetExecutor::Init(
auto unused_vars = framework::GetUnusedVars(program_desc.Block(0), ops, {});
runtime_graph_ = std::make_shared<RuntimeGraph>();
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_task;
std::unordered_map<int64_t, std::unordered_set<int64_t>>
carrier_id_to_interceptor_ids;
std::unordered_set<int64_t> interceptor_ids;
for (auto task_node : task_nodes) {
task_node->SetUnusedVars(unused_vars);
int64_t interceptor_id = task_node->task_id();
interceptor_id_to_task.emplace(interceptor_id, task_node);
interceptor_ids.insert(interceptor_id);
}
carrier_id_to_interceptor_ids.emplace(0, interceptor_ids);
runtime_graph_->SetInterceptorIdToRank(task_id_to_rank);
runtime_graph_->SetInterceptorIdToNode(interceptor_id_to_task);
runtime_graph_->SetCarrierIdToInterceptorIds(carrier_id_to_interceptor_ids);
for (auto& unique_op : ops) {
unique_op.release();
}
......@@ -87,27 +82,23 @@ void FleetExecutor::Init(
}
VLOG(5) << runtime_graph_->DebugString();
msg_bus_ = std::make_shared<MessageBus>();
for (const auto& item : runtime_graph_->carrier_id_to_interceptor_ids()) {
GlobalMap<int64_t, Carrier>::Create(item.first, item.first);
}
InitCarrier();
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier_ids_.insert(carrier_id);
GlobalVal<std::string>::Set(carrier_id);
// TODO(liyurui): Maybe message bus should be created only once
InitCarrier(carrier);
InitMessageBus();
// Wait for all message bus connected.
msg_bus_->Barrier();
}
void FleetExecutor::InitCarrier() {
for (const auto& item : runtime_graph_->carrier_id_to_interceptor_ids()) {
Carrier* carrier = GlobalMap<int64_t, Carrier>::Get(item.first);
PADDLE_ENFORCE_NOT_NULL(carrier, platform::errors::InvalidArgument(
"Carrier has not been created."));
carrier->SetMsgBus(msg_bus_);
carrier->Init(exe_desc_.cur_rank(),
runtime_graph_->interceptor_id_to_rank(), item.second,
runtime_graph_->interceptor_id_to_node(), root_scope_,
minibatch_scope_, microbatch_scopes_, place_);
}
void FleetExecutor::InitCarrier(Carrier* carrier) {
carrier->SetMsgBus(msg_bus_);
carrier->Init(exe_desc_.cur_rank(), runtime_graph_->interceptor_id_to_rank(),
runtime_graph_->interceptor_id_to_node(), root_scope_,
minibatch_scope_, microbatch_scopes_, place_);
}
void FleetExecutor::InitMessageBus() {
......@@ -145,10 +136,9 @@ void FleetExecutor::InitMessageBus() {
}
}
void FleetExecutor::Run() {
for (const auto& item : runtime_graph_->carrier_id_to_interceptor_ids()) {
GlobalMap<int64_t, Carrier>::Get(item.first)->Start();
}
void FleetExecutor::Run(const std::string& carrier_id) {
GlobalMap<std::string, Carrier>::Get(carrier_id)->Start();
GlobalVal<std::string>::Set(carrier_id);
for (auto* micro_scop : microbatch_scopes_) {
// By default, we should delete all kid scopes after run executor because
// some operators may create local scope when running, such as while_op.
......
......@@ -37,16 +37,17 @@ class FleetExecutor final {
FleetExecutor() = delete;
explicit FleetExecutor(const std::string& exe_desc_str);
~FleetExecutor();
void Init(const framework::ProgramDesc& program_desc, framework::Scope* scope,
void Init(const std::string& carrier_id,
const framework::ProgramDesc& program_desc, framework::Scope* scope,
const platform::Place& place,
const std::vector<TaskNode*>& task_nodes,
const std::unordered_map<int64_t, int64_t>& task_id_to_rank);
void Run();
void Run(const std::string& carrier_id);
private:
DISABLE_COPY_AND_ASSIGN(FleetExecutor);
void InitMessageBus();
void InitCarrier();
void InitCarrier(Carrier* carrier);
void CopyParameters(int microbatch_id, const framework::ProgramDesc& program);
FleetExecutorDesc exe_desc_;
std::shared_ptr<RuntimeGraph> runtime_graph_;
......@@ -57,6 +58,7 @@ class FleetExecutor final {
// The carriers under FleetExecutor will share message bus,
// using shared_ptr to manage lifetime and condition race.
std::shared_ptr<MessageBus> msg_bus_;
std::unordered_set<std::string> carrier_ids_;
};
} // namespace distributed
......
......@@ -17,6 +17,24 @@
namespace paddle {
namespace distributed {
// TODO(liyurui): Change this file to global.h
template <typename T>
class GlobalVal final {
public:
static T Get() { return *GetPtr(); }
static T Set(T val) {
auto* ptr = GetPtr();
*ptr = val;
return val;
}
private:
static T* GetPtr() {
static T value;
return &value;
}
};
template <typename KeyT, typename ValueT>
class GlobalMap final {
public:
......@@ -26,6 +44,7 @@ class GlobalMap final {
item, platform::errors::NotFound("This value is not in global map."));
return item;
}
template <typename... Args>
static ValueT* Create(KeyT id, Args&&... args) {
auto* ptr = GetPPtr(id);
......@@ -37,6 +56,34 @@ class GlobalMap final {
return item;
}
private:
static std::unique_ptr<ValueT>* GetPPtr(KeyT id) {
static std::unordered_map<KeyT, std::unique_ptr<ValueT>> id_to_ptr;
return &id_to_ptr[id];
}
};
template <typename KeyT, typename ValueT>
class ThreadSafeGlobalMap final {
public:
static ValueT* Get(KeyT id) {
ValueT* item = GetPPtr(id)->get();
PADDLE_ENFORCE_NOT_NULL(
item, platform::errors::NotFound(
"This value is not in thread safe global map."));
return item;
}
template <typename... Args>
static ValueT* Create(KeyT id, Args&&... args) {
auto* ptr = GetPPtr(id);
PADDLE_ENFORCE_EQ(ptr->get(), nullptr,
platform::errors::AlreadyExists(
"This value has already in thread safe global map."));
ValueT* item = new ValueT(std::forward<Args>(args)...);
ptr->reset(item);
return item;
}
private:
static std::unique_ptr<ValueT>* GetPPtr(KeyT id) {
static std::mutex mutex;
......
......@@ -29,14 +29,8 @@ void InterceptorMessageServiceImpl::InterceptorMessageService(
VLOG(3) << "Interceptor Message Service receives a message from interceptor "
<< request->src_id() << " to interceptor " << request->dst_id()
<< ", with the message: " << request->message_type();
// TODO(liyurui): Remove this hard code.
int64_t carrier_id;
if (request->ctrl_message()) {
carrier_id = 0;
} else {
carrier_id = *GlobalMap<int64_t, int64_t>::Get(request->dst_id());
}
bool flag = GlobalMap<int64_t, Carrier>::Get(carrier_id)
const auto& carrier_id = GlobalVal<std::string>::Get();
bool flag = GlobalMap<std::string, Carrier>::Get(carrier_id)
->EnqueueInterceptorMessage(*request);
response->set_rst(flag);
}
......
......@@ -35,10 +35,6 @@ class RuntimeGraph final {
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank() const {
return interceptor_id_to_rank_;
}
const std::unordered_map<int64_t, std::unordered_set<int64_t>>&
carrier_id_to_interceptor_ids() const {
return carrier_id_to_interceptor_ids_;
}
void SetInterceptorIdToRank(
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank) {
interceptor_id_to_rank_ = interceptor_id_to_rank;
......@@ -47,19 +43,12 @@ class RuntimeGraph final {
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node) {
interceptor_id_to_node_ = interceptor_id_to_node;
}
void SetCarrierIdToInterceptorIds(
const std::unordered_map<int64_t, std::unordered_set<int64_t>>&
carrier_id_to_interceptor_ids) {
carrier_id_to_interceptor_ids_ = carrier_id_to_interceptor_ids;
}
std::string DebugString() const;
private:
DISABLE_COPY_AND_ASSIGN(RuntimeGraph);
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_node_;
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_;
std::unordered_map<int64_t, std::unordered_set<int64_t>>
carrier_id_to_interceptor_ids_;
};
} // namespace distributed
......
......@@ -13,9 +13,6 @@ cc_test(interceptor_pipeline_long_path_test SRCS interceptor_pipeline_long_path_
set_source_files_properties(compute_interceptor_run_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(compute_interceptor_run_op_test SRCS compute_interceptor_run_op_test.cc DEPS fleet_executor ${BRPC_DEPS} op_registry fill_constant_op elementwise_add_op scope device_context)
set_source_files_properties(interceptor_pass_the_parcel_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(interceptor_pass_the_parcel_test SRCS interceptor_pass_the_parcel_test.cc DEPS fleet_executor ${BRPC_DEPS})
if(WITH_DISTRIBUTE AND WITH_PSCORE AND NOT (WITH_ASCEND OR WITH_ASCEND_CL))
set_source_files_properties(interceptor_ping_pong_with_brpc_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(interceptor_ping_pong_with_brpc_test SRCS interceptor_ping_pong_with_brpc_test.cc DEPS fleet_executor ${BRPC_DEPS})
......
......@@ -62,8 +62,10 @@ TEST(ComputeInterceptor, Compute) {
std::vector<framework::Scope*> scopes = {scope, scope};
platform::Place place = platform::CPUPlace();
Carrier* carrier = GlobalMap<int64_t, Carrier>::Create(0, 0);
carrier->Init(0, {{0, 0}, {1, 0}}, {0, 1});
std::string carrier_id = "0";
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{0, 0}, {1, 0}});
auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
......
......@@ -47,8 +47,10 @@ class StartInterceptor : public Interceptor {
};
TEST(ComputeInterceptor, Compute) {
Carrier* carrier = GlobalMap<int64_t, Carrier>::Create(0, 0);
carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}}, {0, 1, 2});
std::string carrier_id = "0";
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}});
auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
namespace paddle {
namespace distributed {
class ParcelInterceptor : public Interceptor {
public:
ParcelInterceptor(int64_t interceptor_id, TaskNode* node)
: Interceptor(interceptor_id, node) {
RegisterMsgHandle(
[this](const InterceptorMessage& msg) { PassParcel(msg); });
}
void PassParcel(const InterceptorMessage& msg) {
if (msg.message_type() == STOP) {
stop_ = true;
return;
}
std::cout << GetInterceptorId() << " recv msg, count=" << count_
<< std::endl;
if (count_ == 5 && interceptor_id_ == 0) {
InterceptorMessage stop;
stop.set_message_type(STOP);
Send(0, stop);
Send(1, stop);
Send(2, stop);
Send(3, stop);
StopCarrier();
return;
}
++count_;
InterceptorMessage new_msg;
if (msg.dst_id() == 3) {
Send(0, new_msg);
} else {
Send(msg.dst_id() + 1, new_msg);
}
}
private:
int count_{0};
};
REGISTER_INTERCEPTOR(Parcel, ParcelInterceptor);
TEST(InterceptorTest, PassTheParcel) {
auto msg_bus = std::make_shared<MessageBus>();
Carrier* carrier_0 = GlobalMap<int64_t, Carrier>::Create(0, 0);
carrier_0->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {0});
carrier_0->SetMsgBus(msg_bus);
Carrier* carrier_1 = GlobalMap<int64_t, Carrier>::Create(1, 1);
carrier_1->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {1});
carrier_1->SetMsgBus(msg_bus);
Carrier* carrier_2 = GlobalMap<int64_t, Carrier>::Create(2, 2);
carrier_2->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {2});
carrier_2->SetMsgBus(msg_bus);
Carrier* carrier_3 = GlobalMap<int64_t, Carrier>::Create(3, 3);
carrier_3->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {3});
carrier_3->SetMsgBus(msg_bus);
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
Interceptor* a = carrier_0->SetInterceptor(
0, InterceptorFactory::Create("Parcel", 0, nullptr));
carrier_1->SetInterceptor(1,
InterceptorFactory::Create("Parcel", 1, nullptr));
carrier_2->SetInterceptor(2,
InterceptorFactory::Create("Parcel", 2, nullptr));
carrier_3->SetInterceptor(3,
InterceptorFactory::Create("Parcel", 3, nullptr));
InterceptorMessage msg;
a->Send(1, msg);
carrier_0->Wait();
}
} // namespace distributed
} // namespace paddle
......@@ -60,8 +60,10 @@ class PingPongInterceptor : public Interceptor {
REGISTER_INTERCEPTOR(PingPong, PingPongInterceptor);
TEST(InterceptorTest, PingPong) {
Carrier* carrier = GlobalMap<int64_t, Carrier>::Create(0, 0);
carrier->Init(0, {{0, 0}, {1, 0}}, {0, 1});
std::string carrier_id = "0";
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{0, 0}, {1, 0}});
auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
carrier->SetMsgBus(msg_bus);
......
......@@ -106,16 +106,18 @@ TEST(InterceptorTest, PingPong) {
std::cout << "ip1: " << ip1 << std::endl;
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank = {{0, 0},
{1, 1}};
std::string carrier_id = "0";
int pid = fork();
if (pid == 0) {
Carrier* carrier = GlobalMap<int64_t, Carrier>::Create(0, 0);
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
GlobalVal<std::string>::Set(carrier_id);
auto msg_bus = std::make_shared<MessageBus>();
carrier->SetMsgBus(msg_bus);
// NOTE: need Init msg_bus after carrier SetMsgBus
carrier->Init(0, interceptor_id_to_rank, {0});
carrier->Init(0, interceptor_id_to_rank);
msg_bus->Init(0, {{0, ip0}, {1, ip1}}, ip0);
carrier->SetMsgBus(msg_bus);
Interceptor* a = carrier->SetInterceptor(
0, InterceptorFactory::Create("PingPong", 0, nullptr));
msg_bus->Barrier();
......@@ -123,10 +125,12 @@ TEST(InterceptorTest, PingPong) {
a->Send(1, msg);
carrier->Wait();
} else {
Carrier* carrier = GlobalMap<int64_t, Carrier>::Create(0, 0);
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
GlobalVal<std::string>::Set(carrier_id);
auto msg_bus = std::make_shared<MessageBus>();
carrier->SetMsgBus(msg_bus);
carrier->Init(1, interceptor_id_to_rank, {1});
carrier->Init(1, interceptor_id_to_rank);
msg_bus->Init(1, {{0, ip0}, {1, ip1}}, ip1);
carrier->SetInterceptor(1,
InterceptorFactory::Create("PingPong", 1, nullptr));
......
......@@ -52,9 +52,10 @@ void LinkNodes(const std::vector<TaskNode*>& nodes) {
}
TEST(AmplifierInterceptor, Amplifier) {
Carrier* carrier = GlobalMap<int64_t, Carrier>::Create(0, 0);
carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}, {4, 0}, {5, 0}},
{0, 1, 2, 3, 4, 5});
std::string carrier_id = "0";
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}, {4, 0}, {5, 0}});
auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");
carrier->SetMsgBus(msg_bus);
......
......@@ -70,8 +70,10 @@ void LinkNodes(const std::vector<TaskNode*>& nodes,
}
TEST(AmplifierInterceptor, Amplifier) {
Carrier* carrier = GlobalMap<int64_t, Carrier>::Create(0, 0);
carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {0, 1, 2, 3});
std::string carrier_id = "0";
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}});
auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init(0, {{0, ""}}, "");
carrier->SetMsgBus(msg_bus);
......
......@@ -1956,7 +1956,11 @@ class Executor(object):
return ctx
def _prepare_fleet_executor(self, program=None, scope=None, fleet_opt=None):
def _prepare_fleet_executor(self,
carrier_id="",
program=None,
scope=None,
fleet_opt=None):
from ..distributed.fleet.proto import fleet_executor_desc_pb2
assert program, "Program for fleet executor should not be None"
assert fleet_opt, "Configurations for fleet executor should not be None"
......@@ -2014,7 +2018,8 @@ class Executor(object):
fleet_exe = core.FleetExecutor(fleet_exe_desc.SerializeToString())
place = core.Place()
place.set_place(self.place)
fleet_exe.init(program.desc, scope, place, tasks, task_id_to_rank)
fleet_exe.init(carrier_id, program.desc, scope, place, tasks,
task_id_to_rank)
return fleet_exe
def _run_using_fleet_executor(self,
......@@ -2023,6 +2028,7 @@ class Executor(object):
feed_var_name="feed",
fetch_var_name="fetch",
fetch_list=None):
# TODO(liyurui): Change cache strategy for multi carriers
cache_key = _get_strong_program_cache_key(program, feed, fetch_list)
cached_ctx = self._get_ctx_cache(cache_key)
cached_scope = self._get_scope_cache(cache_key)
......@@ -2088,7 +2094,10 @@ class Executor(object):
fetch_task.set_program(fetch_program)
cached_ctx = self._prepare_fleet_executor(
program=cached_program, scope=cached_scope, fleet_opt=fleet_opt)
cache_key,
program=cached_program,
scope=cached_scope,
fleet_opt=fleet_opt)
self._add_ctx_cache(cache_key, cached_ctx)
if feed:
# NOTE: don't have to traverse programs in task nodes,
......@@ -2107,7 +2116,7 @@ class Executor(object):
lr_sheduler._var_name)
tensor.set(data, self.place)
cached_ctx.run()
cached_ctx.run(cache_key)
if fetch_list:
arr = cached_scope.find_var(fetch_var_name).get_fetch_list()
tensors = arr._move_to_list()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册