Unverified commit bafe287a authored by LiYuRio, committed by GitHub

cherry-pick fleet executor from 2.4 (#52896)

* cherry-pick fleet executor from 2.4

* fix test case
Parent a2aa0087
...@@ -18,6 +18,7 @@
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/jit/serializer.h"
#include "paddle/phi/core/errors.h" #include "paddle/phi/core/errors.h"
namespace paddle { namespace paddle {
...@@ -45,6 +46,65 @@ void ComputeInterceptor::PrepareDeps() { ...@@ -45,6 +46,65 @@ void ComputeInterceptor::PrepareDeps() {
} }
} }
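// Deserialize the tensors carried in a DATA_WITH_VARS message into the
// micro-batch scope selected by msg.scope_idx().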
void ComputeInterceptor::DecodeMsgVars(const InterceptorMessage& msg) {
int64_t scope_id = msg.scope_idx();
PADDLE_ENFORCE_LT(scope_id,
microbatch_scopes_.size(),
platform::errors::InvalidArgument(
"Step out of range. There are %ld "
"microbatch_scopes, but recevice scope index %ld",
microbatch_scopes_.size(),
scope_id));
auto* scope = microbatch_scopes_[scope_id];
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
for (const auto& var_iter : msg.vars_list()) {
const std::string& name = var_iter.name();
auto& dev_ctx = *pool.Get(place_);
std::istringstream ss(var_iter.stensor());
auto* var = scope->Var(name);
auto* tensor = var->GetMutable<phi::DenseTensor>();
framework::DeserializeFromStream(ss, tensor, dev_ctx);
VLOG(3) << "Set vars " << name << " with value in scope " << scope_id
<< " with dims " << tensor->dims() << " with dtype "
<< tensor->dtype();
}
}
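// Serialize every variable listed in the task node's vars_to_dtype map from
// the current micro-batch scope into a DATA_WITH_VARS message.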
InterceptorMessage ComputeInterceptor::PrepareVarsMsg() {
PADDLE_ENFORCE_LT(cur_scope_id_,
microbatch_scopes_.size(),
platform::errors::InvalidArgument(
"Step out of range. There are %ld "
"microbatch_scopes, but recevice scope index %ld",
microbatch_scopes_.size(),
cur_scope_id_));
auto* scope = microbatch_scopes_[cur_scope_id_];
InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_WITH_VARS);
ready_msg.set_scope_idx(cur_scope_id_);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
for (auto iter : node_->vars_to_dtype()) {
VarList* vars = ready_msg.add_vars_list();
const auto& var_name = iter.first;
vars->set_name(var_name);
std::ostringstream ss;
auto& dev_ctx = *pool.Get(place_);
auto* var = scope->FindVar(var_name);
PADDLE_ENFORCE(
var,
platform::errors::NotFound(
"Variable %s not exists in scope %ld", var_name, cur_scope_id_));
const auto& tensor = var->Get<phi::DenseTensor>();
framework::SerializeToStream(ss, tensor, dev_ctx);
vars->set_stensor(ss.str());
VLOG(3) << "Prepare vars msg " << var_name << " with dimension "
<< tensor.dims() << " dtype " << tensor.dtype();
}
return ready_msg;
}
void ComputeInterceptor::IncreaseReady(int64_t up_id, int64_t scope_id) {
auto it = in_readys_.find(up_id);
PADDLE_ENFORCE_NE(it,
...@@ -105,6 +165,16 @@ bool ComputeInterceptor::IsInputReady() {
flag = flag && (ready_size_map.at(i) != 0);
}
if (flag) {
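// Block scope i until every earlier scope tracked in
// scope_id_to_finish_flag_ has finished, so scopes run in order.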
for (auto iter : scope_id_to_finish_flag_) {
if (iter.first == i) {
break;
} else if (!iter.second) {
VLOG(3) << "The previous scope is not ready, waiting for the "
"previous scope "
<< iter.first;
return false;
}
}
cur_scope_id_ = i;
return true;
} else {
...@@ -214,11 +284,20 @@ void ComputeInterceptor::RunOps() {
void ComputeInterceptor::Run() {
while (IsInputReady() && CanWriteOutput()) {
VLOG(3) << "id=" << GetInterceptorId()
<< " ComputeInterceptor running in scope " << cur_scope_id_;
RunOps();
if (!scope_id_to_finish_flag_.empty()) {
PADDLE_ENFORCE_NE(
scope_id_to_finish_flag_.find(cur_scope_id_),
scope_id_to_finish_flag_.end(),
platform::errors::NotFound(
"Can not find scope %ld in scope_id_to_finish", cur_scope_id_));
scope_id_to_finish_flag_.erase(cur_scope_id_);
}
// send to downstream and increase buff used
SendDataReadyToDownStream();
// reply to upstream and decrease ready data
...@@ -239,6 +318,20 @@ void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
<< msg.scope_idx() << " ";
DecreaseBuff(msg.src_id());
Run();
} else if (msg.message_type() == DATA_WITH_VARS) {
VLOG(3) << "Compute interceptor " << interceptor_id_
<< " receive data_with_vars " << msg.src_id() << " "
<< msg.scope_idx() << " ";
DecodeMsgVars(msg);
IncreaseReady(msg.src_id(), msg.scope_idx());
Run();
} else if (msg.message_type() == START_LOOP) {
VLOG(3) << "Compute interceptor " << interceptor_id_
<< " receive start_loop " << msg.src_id() << " " << msg.scope_idx()
<< " ";
IncreaseReady(msg.src_id(), msg.scope_idx());
scope_id_to_finish_flag_.emplace(msg.scope_idx(), false);
Run();
}
}
......
...@@ -47,9 +47,12 @@ class ComputeInterceptor : public Interceptor {
private:
void PrepareDeps();
InterceptorMessage PrepareVarsMsg();
void DecodeMsgVars(const InterceptorMessage& msg);
bool IsInputReady();
bool CanWriteOutput();
std::map<int64_t, bool> scope_id_to_finish_flag_;
};
} // namespace distributed
......
...@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/cond_interceptor.h"
#include <algorithm>
#include "paddle/fluid/distributed/fleet_executor/task_node.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/executor_gc_helper.h" #include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
...@@ -38,6 +39,8 @@ void CondInterceptor::PrepareDeps() { ...@@ -38,6 +39,8 @@ void CondInterceptor::PrepareDeps() {
for (const auto& up : upstream) { for (const auto& up : upstream) {
if (id_to_dep_type.at(up.first) == DependType::NORMAL) { if (id_to_dep_type.at(up.first) == DependType::NORMAL) {
normal_in_id_.insert(up.first); normal_in_id_.insert(up.first);
} else if (id_to_dep_type.at(up.first) == DependType::LOOP) {
loop_id_ = up.first;
}
}
...@@ -90,6 +93,13 @@ void CondInterceptor::SendDataReady(int64_t down_id) {
Send(down_id, ready_msg);
}
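// Ask a downstream interceptor to run another loop iteration in the current scope.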
void CondInterceptor::SendStartLoop(int64_t down_id) {
InterceptorMessage ready_msg;
ready_msg.set_message_type(START_LOOP);
ready_msg.set_scope_idx(cur_scope_id_);
Send(down_id, ready_msg);
}
void CondInterceptor::ReplyDataIsUseless(int64_t up_id) {
InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_USELESS);
...@@ -104,18 +114,36 @@ void CondInterceptor::Compute() {
if (cond) {
VLOG(3) << "Loop again in scope " << cur_scope_id_;
for (auto& down_id : normal_out_id_) {
SendStartLoop(down_id);
}
++num_of_scopes_;
} else {
VLOG(3) << "Finish loop in scope " << cur_scope_id_;
SendDataReady(stop_loop_id_);
}
}
void CondInterceptor::Run(const InterceptorMessage& msg) {
if (msg.message_type() == DATA_IS_READY ||
msg.message_type() == DATA_WITH_VARS) {
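// A message on the loop edge means one scope finished its iteration; wait
// until every looping scope reports back, then restart them in ascending
// scope order.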
if (msg.src_id() == loop_id_) {
--num_of_scopes_;
VLOG(3) << "Receving loop again message from " << msg.src_id()
<< " waiting other " << num_of_scopes_ << " scopes ready";
ready_scope_id_.emplace_back(msg.scope_idx());
if (num_of_scopes_ == 0) {
std::sort(ready_scope_id_.begin(), ready_scope_id_.end());
for (auto scope_id : ready_scope_id_) {
VLOG(3) << "Start a new loop in scope " << scope_id;
cur_scope_id_ = scope_id;
Compute();
}
ready_scope_id_.clear();
}
} else {
cur_scope_id_ = msg.scope_idx();
Compute();
}
} else if (msg.message_type() == DATA_IS_USELESS) {
if (node_->id_to_dep_type().at(msg.src_id()) == DependType::STOP_LOOP) {
for (auto& up_id : normal_in_id_) {
......
...@@ -14,6 +14,7 @@
#pragma once
#include <iomanip>
#include <queue>
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
...@@ -37,6 +38,7 @@ class CondInterceptor final : public Interceptor {
void Compute();
bool GetCondResult();
void SendDataReady(int64_t down_id);
void SendStartLoop(int64_t down_id);
void ReplyDataIsUseless(int64_t up_id);
int64_t cur_scope_id_;
...@@ -44,6 +46,9 @@ class CondInterceptor final : public Interceptor {
std::set<int64_t> normal_in_id_;
std::set<int64_t> normal_out_id_;
int64_t stop_loop_id_;
int64_t loop_id_;
int64_t num_of_scopes_{0};
std::vector<int64_t> ready_scope_id_;
};
} // namespace distributed
......
...@@ -24,6 +24,13 @@ enum MessageType {
ERR = 4; // current Interceptor encounters error
RESET = 5; // reset the status
START = 6;
DATA_WITH_VARS = 7;
START_LOOP = 8;
}
message VarList {
required string name = 1;
required string stensor = 2;
}
message InterceptorMessage {
...@@ -32,6 +39,7 @@ message InterceptorMessage {
optional MessageType message_type = 3 [ default = RESET ];
optional bool ctrl_message = 4 [ default = false ];
optional int64 scope_idx = 5 [ default = 0 ];
repeated VarList vars_list = 6;
}
message InterceptorResponse { optional bool rst = 1 [ default = false ]; }
......
...@@ -45,6 +45,16 @@ TaskNode::TaskNode(paddle::framework::ProgramDesc* program, int64_t rank)
<< ". And the TaskNode's max_run_time and max_slot_num will be set to 1.";
}
void TaskNode::SetVarsToDtype(
const std::map<std::string, std::string>& vars_to_dtype) {
vars_to_dtype_ = vars_to_dtype;
}
void TaskNode::SetVarsToShape(
const std::map<std::string, std::vector<int64_t>>& vars_to_shape) {
vars_to_shape_ = vars_to_shape;
}
void TaskNode::SetProgram(paddle::framework::ProgramDesc* program) {
program_ = program;
}
......
...@@ -116,6 +116,15 @@ class TaskNode final {
int64_t buff_size = 1,
DependType type = DependType::NORMAL);
std::string DebugString() const;
void SetVarsToDtype(const std::map<std::string, std::string>& vars_to_dtype);
const std::map<std::string, std::vector<int64_t>>& vars_to_shape() const {
return vars_to_shape_;
}
const std::map<std::string, std::string>& vars_to_dtype() const {
return vars_to_dtype_;
}
void SetVarsToShape(
const std::map<std::string, std::vector<int64_t>>& vars_to_shape);
private:
DISABLE_COPY_AND_ASSIGN(TaskNode);
...@@ -148,6 +157,8 @@ class TaskNode final {
int64_t send_down_per_steps_{1};
std::string type_;
std::map<std::string, std::string> vars_to_dtype_;
std::map<std::string, std::vector<int64_t>> vars_to_shape_;
};
} // namespace distributed
......
...@@ -184,6 +184,8 @@ void BindFleetExecutor(py::module* m) {
.def("set_run_at_offset", &TaskNode::SetRunAtOffset)
.def("set_type", &TaskNode::SetType)
.def("set_cond_var_name", &TaskNode::SetCondVarName)
.def("set_vars_to_shape", &TaskNode::SetVarsToShape)
.def("set_vars_to_dtype", &TaskNode::SetVarsToDtype)
.def("role", &TaskNode::role) .def("role", &TaskNode::role)
.def("init", [](TaskNode& self) { self.Init(); }) .def("init", [](TaskNode& self) { self.Init(); })
.def("set_program", &TaskNode::SetProgram); .def("set_program", &TaskNode::SetProgram);
......
...@@ -33,6 +33,8 @@ class TaskNode:
program=None,
lazy_initialize=False,
cond_var_name=None,
vars_to_dtype=None,
vars_to_shape=None,
):
"""
:param rank (int): Current rank of the task node.
...@@ -58,6 +60,8 @@ class TaskNode:
self.program = program
self.lazy_initialize = lazy_initialize
self.cond_var_name = cond_var_name
self.vars_to_dtype = vars_to_dtype
self.vars_to_shape = vars_to_shape
self.run_pre_steps = None
self.run_at_offset = None
self.node = None
...@@ -101,6 +105,10 @@ class TaskNode:
self.node.set_run_at_offset(self.run_at_offset)
if self.cond_var_name:
self.node.set_cond_var_name(self.cond_var_name)
if self.vars_to_shape:
self.node.set_vars_to_shape(self.vars_to_shape)
if self.vars_to_dtype:
self.node.set_vars_to_dtype(self.vars_to_dtype)
for up in self.upstreams:
self.node.add_upstream_task(up[0], up[1], up[2])
for down in self.downstreams:
......
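For reference, `vars_to_dtype` and `vars_to_shape` are plain dicts that `lazy_initialize` forwards to the new `set_vars_to_dtype` / `set_vars_to_shape` bindings. A minimal construction sketch follows; the import path is an assumption, and other required constructor arguments (e.g. `rank`) are elided here just as they are in the unit test at the end of this diff:

```python
import paddle

# Assumed import path for the TaskNode helper touched by this patch.
from paddle.distributed.fleet.fleet_executor_utils import TaskNode

# Mirrors the Compute task with task_id=3 in the unit test below; the
# remaining constructor arguments follow that test.
task = TaskNode(
    node_type="Compute",
    task_id=3,
    program=paddle.static.Program(),
    lazy_initialize=True,
    # Variables whose values are shipped between interceptors via
    # DATA_WITH_VARS messages, keyed by name.
    vars_to_dtype={'x': 'float32', 'tmp_1': 'int64'},
    vars_to_shape={'x': (1,), 'tmp_1': (1,)},
)
```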
...@@ -963,6 +963,7 @@ class Executor:
self.ctx_caches = dict()
self.trainer_caches = dict()
self.scope_caches = dict()
self.micro_scope_cache = dict()
self.var_caches = dict()
self.pruned_program_caches = dict()
p = core.Place()
...@@ -1032,6 +1033,12 @@
def _add_scope_cache(self, scope_cache_key, scope):
self.scope_caches[scope_cache_key] = scope
def _add_micro_scopes_cache(self, program_cache_key, micro_scopes: list):
self.micro_scope_cache[program_cache_key] = micro_scopes
def _get_micro_scopes_cache(self, program_cache_key):
return self.micro_scope_cache.get(program_cache_key, None)
# just for testing, will be removed later
@lru_cache()
def _log_force_set_program_cache(self, use_program_cache):
...@@ -1467,6 +1474,7 @@
feed=feed,
fetch_list=fetch_list,
with_standalone_executor=self._fleet_executor_with_standalone,
return_numpy=return_numpy,
)
if "startup_program" in program._pipeline_opt:
program = program._pipeline_opt["startup_program"]
...@@ -2340,13 +2348,25 @@
fetch_var_name="fetch",
fetch_list=None,
with_standalone_executor=False,
return_numpy=True,
):
cache_key = _get_strong_program_cache_key(program, feed, fetch_list)
cached_program = self._get_program_cache(cache_key)
cached_scope = self._get_scope_cache(cache_key)
micro_cached_scopes = self._get_micro_scopes_cache(cache_key)
fleet_opt = program._pipeline_opt["fleet_opt"]
if cached_scope is None:
cached_scope = global_scope()
self._add_scope_cache(cache_key, cached_scope)
if micro_cached_scopes is None:
micro_cached_scopes = []
if (
"inference_generation" in fleet_opt
and fleet_opt["inference_generation"]
):
for _ in range(int(fleet_opt["num_micro_batches"])):
micro_cached_scopes.append(cached_scope.new_scope())
self._add_micro_scopes_cache(cache_key, micro_cached_scopes)
if cached_program is None:
assert (
program._pipeline_opt
...@@ -2424,7 +2444,7 @@ class Executor:
program=cached_program,
scope=cached_scope,
fleet_opt=fleet_opt,
micro_scope_list=micro_cached_scopes,
with_standalone_executor=with_standalone_executor,
)
...@@ -2448,17 +2468,33 @@
tensor.set(data, self.place)
self._fleet_executor.run(cache_key)
if "fetch_var" in fleet_opt:
# If we speed up generation during evaluation, multiple queries are
# generated at the same time. Each query runs in its own scope so the
# results do not get mixed up, which means the final results live in
# multiple scopes and each scope has to be fetched.
result_list = []
for scope in micro_cached_scopes:
scope_result_list = []
for varname in fleet_opt["fetch_var"]:
tensor = None
try:
tensor = core.get_variable_tensor(scope, varname)
if return_numpy:
tensor = as_numpy(tensor)
except:
var = scope.find_var(varname)
tensor = var.get_lod_tensor_array()
if return_numpy:
tensor = as_numpy(tensor)
else:
tensor = [t for t in tensor]
if tensor:
scope_result_list.append(tensor)
if scope_result_list:
result_list.append(scope_result_list)
return result_list
if fetch_list:
......
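With `fetch_var` set, `_run_using_fleet_executor` now returns one entry per micro-batch scope, each entry a list of the values fetched from that scope. A minimal consumption sketch, assuming `exe` and `main_program` are built with the fleet-executor pipeline options the way the unit test below does (names here are illustrative):

```python
# results[scope_idx][var_idx] holds the value fetched from one micro-batch
# scope; with return_numpy=True each value is a numpy array.
results = exe.run(main_program, return_numpy=True)

for scope_idx, scope_results in enumerate(results):
    for var_idx, value in enumerate(scope_results):
        print(f"scope {scope_idx}, fetch var {var_idx}: shape {value.shape}")
```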
...@@ -154,6 +154,8 @@ class TestFleetExecutor(unittest.TestCase):
node_type="Compute",
task_id=3,
program=paddle.static.Program(),
vars_to_dtype={'x': 'float32', 'tmp_1': 'int64'},
vars_to_shape={'x': (1,), 'tmp_1': (1,)},
lazy_initialize=True,
)
task_e = TaskNode(
...@@ -205,7 +207,7 @@ class TestFleetExecutor(unittest.TestCase):
exe = paddle.static.Executor(place)
loader.start()
res = exe.run(main_program)
ref_res = np.full([1, 1], 10, dtype="float32")
for data in res:
np.testing.assert_allclose(data, ref_res, rtol=1e-05)
ref_res = ref_res + 1
......