未验证 提交 0839bba3 编写于 作者: R Ruibiao Chen 提交者: GitHub

Support priority scheduling for standalone executor (#49275)

* Support priority scheduling for standalone executor

* Add CPU test
上级 0a837cb2
...@@ -318,8 +318,11 @@ bool operator==(const TensorDistAttr& lhs, const TensorDistAttr& rhs) { ...@@ -318,8 +318,11 @@ bool operator==(const TensorDistAttr& lhs, const TensorDistAttr& rhs) {
return true; return true;
} }
std::vector<std::string> OperatorDistAttr::fields_{ std::vector<std::string> OperatorDistAttr::fields_{"process_mesh",
"process_mesh", "impl_type", "impl_idx", "execution_stream"}; "impl_type",
"impl_idx",
"execution_stream",
"scheduling_priority"};
OperatorDistAttr::OperatorDistAttr(const OpDesc& op) : op_(&op) { OperatorDistAttr::OperatorDistAttr(const OpDesc& op) : op_(&op) {
VLOG(4) << "[OperatorDistAttr constructor] op type: " << op_->Type(); VLOG(4) << "[OperatorDistAttr constructor] op type: " << op_->Type();
...@@ -379,6 +382,7 @@ void OperatorDistAttr::initialize() { ...@@ -379,6 +382,7 @@ void OperatorDistAttr::initialize() {
impl_type_ = kDefault; impl_type_ = kDefault;
impl_idx_ = 0; impl_idx_ = 0;
execution_stream_ = kDefault; execution_stream_ = kDefault;
scheduling_priority_ = 0;
} }
void OperatorDistAttr::copy_from(const OperatorDistAttr& dist_attr) { void OperatorDistAttr::copy_from(const OperatorDistAttr& dist_attr) {
...@@ -388,6 +392,7 @@ void OperatorDistAttr::copy_from(const OperatorDistAttr& dist_attr) { ...@@ -388,6 +392,7 @@ void OperatorDistAttr::copy_from(const OperatorDistAttr& dist_attr) {
set_impl_type(dist_attr.impl_type()); set_impl_type(dist_attr.impl_type());
set_impl_idx(dist_attr.impl_idx()); set_impl_idx(dist_attr.impl_idx());
set_execution_stream(dist_attr.execution_stream()); set_execution_stream(dist_attr.execution_stream());
set_scheduling_priority(dist_attr.scheduling_priority());
set_annotated(dist_attr.annotated()); set_annotated(dist_attr.annotated());
} }
...@@ -667,6 +672,7 @@ std::string OperatorDistAttr::to_string() const { ...@@ -667,6 +672,7 @@ std::string OperatorDistAttr::to_string() const {
str += "impl_type: " + impl_type_ + ", "; str += "impl_type: " + impl_type_ + ", ";
str += "impl_idx: " + std::to_string(impl_idx_) + ", "; str += "impl_idx: " + std::to_string(impl_idx_) + ", ";
str += "execution_stream: " + execution_stream_ + ", "; str += "execution_stream: " + execution_stream_ + ", ";
str += "scheduling_priority: " + std::to_string(scheduling_priority_) + ", ";
str += "annotated: [" + str_join(annotated_) + "], "; str += "annotated: [" + str_join(annotated_) + "], ";
str += "\nprocess_mesh: " + process_mesh_.to_string() + ", "; str += "\nprocess_mesh: " + process_mesh_.to_string() + ", ";
str += "\ninput_dist_attrs: [\n"; str += "\ninput_dist_attrs: [\n";
...@@ -751,6 +757,9 @@ bool operator==(const OperatorDistAttr& lhs, const OperatorDistAttr& rhs) { ...@@ -751,6 +757,9 @@ bool operator==(const OperatorDistAttr& lhs, const OperatorDistAttr& rhs) {
if (lhs.execution_stream() != rhs.execution_stream()) { if (lhs.execution_stream() != rhs.execution_stream()) {
return false; return false;
} }
if (lhs.scheduling_priority() != rhs.scheduling_priority()) {
return false;
}
for (auto const& item : lhs.input_dist_attrs()) { for (auto const& item : lhs.input_dist_attrs()) {
if (rhs.input_dist_attrs().count(item.first) != 1) { if (rhs.input_dist_attrs().count(item.first) != 1) {
return false; return false;
......
...@@ -213,6 +213,12 @@ class OperatorDistAttr { ...@@ -213,6 +213,12 @@ class OperatorDistAttr {
execution_stream_ = execution_stream; execution_stream_ = execution_stream;
} }
int64_t scheduling_priority() const { return scheduling_priority_; }
void set_scheduling_priority(int64_t scheduling_priority) {
scheduling_priority_ = scheduling_priority;
}
const std::map<std::string, bool>& annotated() const { return annotated_; } const std::map<std::string, bool>& annotated() const { return annotated_; }
void set_annotated(const std::map<std::string, bool>& annotated); void set_annotated(const std::map<std::string, bool>& annotated);
...@@ -271,6 +277,7 @@ class OperatorDistAttr { ...@@ -271,6 +277,7 @@ class OperatorDistAttr {
std::string impl_type_; std::string impl_type_;
int64_t impl_idx_ = -1; int64_t impl_idx_ = -1;
std::string execution_stream_; std::string execution_stream_;
int64_t scheduling_priority_; // lower value, higher priority, default to 0
std::map<std::string, bool> annotated_; std::map<std::string, bool> annotated_;
}; };
......
...@@ -33,11 +33,6 @@ ...@@ -33,11 +33,6 @@
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
#endif #endif
PADDLE_DEFINE_EXPORTED_bool(
new_executor_serial_run,
false,
"Enable serial execution for standalone executor, used for debug.");
PADDLE_DEFINE_EXPORTED_bool( PADDLE_DEFINE_EXPORTED_bool(
new_executor_log_memory_stats, new_executor_log_memory_stats,
false, false,
...@@ -118,11 +113,7 @@ void AsyncWorkQueue::AddTask(const OpFuncType& op_func_type, ...@@ -118,11 +113,7 @@ void AsyncWorkQueue::AddTask(const OpFuncType& op_func_type,
std::function<void()> fn) { std::function<void()> fn) {
// queue_idx=0 : kCpuSync or kGpuSync // queue_idx=0 : kCpuSync or kGpuSync
// queue_idx=1 : kGPUAsync // queue_idx=1 : kGPUAsync
// when serial_run, always make queue_idx=1, so only one thread is used queue_group_->AddTask(op_func_type == OpFuncType::kGpuAsync, std::move(fn));
size_t queue_idx =
(op_func_type == OpFuncType::kGpuAsync || FLAGS_new_executor_serial_run);
VLOG(8) << "Add task: " << queue_idx;
queue_group_->AddTask(queue_idx, std::move(fn));
} }
bool IsCommunicationOp(const std::string& op_name) { bool IsCommunicationOp(const std::string& op_name) {
...@@ -585,6 +576,17 @@ void BuildOpFuncList(const platform::Place& place, ...@@ -585,6 +576,17 @@ void BuildOpFuncList(const platform::Place& place,
op_func_node.execution_stream_ = dist_attr->execution_stream(); op_func_node.execution_stream_ = dist_attr->execution_stream();
} }
if (dist_attr) {
op_func_node.priority_ = dist_attr->scheduling_priority();
} else if (interpreter::IsCommunicationOp(op_type)) {
// NOTE(Ruibiao): Dispatching computation before communication improves
// multi-stream overlap when the time cost of communication is less than
// that of the computation (e.g., ResNet50_bs128_pure_fp16 N4C32 training).
op_func_node.priority_ = 1;
}
VLOG(6) << "scheduling priority of " << op_type << " : "
<< op_func_node.priority_;
SingleStreamGuard single_stream_guard(ops[i]); SingleStreamGuard single_stream_guard(ops[i]);
VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope); VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope);
......
...@@ -33,6 +33,10 @@ ...@@ -33,6 +33,10 @@
#endif #endif
#include "paddle/phi/backends/device_manager.h" #include "paddle/phi/backends/device_manager.h"
PADDLE_DEFINE_EXPORTED_bool(
new_executor_serial_run,
false,
"Enable serial execution for standalone executor, used for debug.");
PADDLE_DEFINE_EXPORTED_bool(new_executor_use_inplace, PADDLE_DEFINE_EXPORTED_bool(new_executor_use_inplace,
false, false,
"Use inplace in new executor"); "Use inplace in new executor");
...@@ -128,6 +132,15 @@ InterpreterCore::InterpreterCore(const platform::Place& place, ...@@ -128,6 +132,15 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
local_scope_ = local_scope; local_scope_ = local_scope;
} }
var_scope_.SetLocalScope(local_scope_); var_scope_.SetLocalScope(local_scope_);
instruction_prority_less = [this](size_t lhs, size_t rhs) {
Priority lhs_prority = vec_instruction_[lhs].GetPriority();
Priority rhs_prority = vec_instruction_[rhs].GetPriority();
if (lhs_prority == rhs_prority) {
return lhs > rhs;
}
return lhs_prority > rhs_prority;
};
} }
InterpreterCore::~InterpreterCore() { InterpreterCore::~InterpreterCore() {
...@@ -516,25 +529,31 @@ void InterpreterCore::BuildOperatorDependences() { ...@@ -516,25 +529,31 @@ void InterpreterCore::BuildOperatorDependences() {
Instruction& cur_instr = vec_instruction_[instr_id]; Instruction& cur_instr = vec_instruction_[instr_id];
const std::set<size_t>& next_instr_ids = downstream_map[instr_id]; const std::set<size_t>& next_instr_ids = downstream_map[instr_id];
if (cur_instr.KernelType() == OpFuncType::kGpuAsync) { if (FLAGS_new_executor_serial_run) {
for (size_t next_instr_id : next_instr_ids) { for (size_t next_instr_id : next_instr_ids) {
if (vec_instruction_[next_instr_id].KernelType() == cur_instr.AddNextInstrInSameThread(next_instr_id);
OpFuncType::kGpuAsync) {
cur_instr.AddNextInstrInSameThread(next_instr_id);
} else {
cur_instr.AddNextInstrInDifferentThread(next_instr_id);
}
} }
} else { } else {
bool has_instr_in_same_thread = false; if (cur_instr.KernelType() == OpFuncType::kGpuAsync) {
for (size_t next_instr_id : next_instr_ids) { for (size_t next_instr_id : next_instr_ids) {
if (!has_instr_in_same_thread && if (vec_instruction_[next_instr_id].KernelType() ==
vec_instruction_[next_instr_id].KernelType() != OpFuncType::kGpuAsync) {
OpFuncType::kGpuAsync) { cur_instr.AddNextInstrInSameThread(next_instr_id);
cur_instr.AddNextInstrInSameThread(next_instr_id); } else {
has_instr_in_same_thread = true; cur_instr.AddNextInstrInDifferentThread(next_instr_id);
} else { }
cur_instr.AddNextInstrInDifferentThread(next_instr_id); }
} else {
bool has_instr_in_same_thread = false;
for (size_t next_instr_id : next_instr_ids) {
if (!has_instr_in_same_thread &&
vec_instruction_[next_instr_id].KernelType() !=
OpFuncType::kGpuAsync) {
cur_instr.AddNextInstrInSameThread(next_instr_id);
has_instr_in_same_thread = true;
} else {
cur_instr.AddNextInstrInDifferentThread(next_instr_id);
}
} }
} }
} }
...@@ -567,12 +586,7 @@ void InterpreterCore::Convert( ...@@ -567,12 +586,7 @@ void InterpreterCore::Convert(
for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) { for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) {
auto& op_func_node = nodes[op_idx]; auto& op_func_node = nodes[op_idx];
auto* dev_ctx_ = stream_analyzer_.ParseDeviceContext(op_func_node); auto* dev_ctx_ = stream_analyzer_.ParseDeviceContext(op_func_node);
Priority priority = vec_instruction_.emplace_back(op_idx, std::move(op_func_node), *dev_ctx_);
interpreter::IsCommunicationOp(op_func_node.operator_base_->Type())
? Priority::kLowest
: Priority::kNormal;
vec_instruction_.emplace_back(
op_idx, std::move(op_func_node), *dev_ctx_, priority);
} }
BuildOperatorDependences(); BuildOperatorDependences();
...@@ -938,8 +952,12 @@ void InterpreterCore::ExecuteInstructionList( ...@@ -938,8 +952,12 @@ void InterpreterCore::ExecuteInstructionList(
if (dependecy_count_[i] == 0) { if (dependecy_count_[i] == 0) {
// NOTE(zhiqiu): hot fix for jit input var // NOTE(zhiqiu): hot fix for jit input var
RecordMemcpyD2H(vec_instr.at(i)); RecordMemcpyD2H(vec_instr.at(i));
async_work_queue_->AddTask(vec_instr.at(i).KernelType(), if (FLAGS_new_executor_serial_run) {
[this, i] { RunInstructionAsync(i); }); RunInstructionAsync(i);
} else {
async_work_queue_->AddTask(vec_instr.at(i).KernelType(),
[this, i] { RunInstructionAsync(i); });
}
} }
} }
...@@ -965,8 +983,8 @@ void InterpreterCore::ExecuteInstructionList( ...@@ -965,8 +983,8 @@ void InterpreterCore::ExecuteInstructionList(
} }
} }
void InterpreterCore::RunNextInstructions( void InterpreterCore::RunNextInstructions(const Instruction& instr,
const Instruction& instr, std::deque<size_t>* reserved_next_ops) { SchedulingQueue* reserved_next_ops) {
platform::RecordEvent record( platform::RecordEvent record(
"RunNextInstructions", platform::TracerEventType::UserDefined, 10); "RunNextInstructions", platform::TracerEventType::UserDefined, 10);
...@@ -986,21 +1004,21 @@ void InterpreterCore::RunNextInstructions( ...@@ -986,21 +1004,21 @@ void InterpreterCore::RunNextInstructions(
for (size_t next_instr_id : instr.NextInstrsInSameThread()) { for (size_t next_instr_id : instr.NextInstrsInSameThread()) {
if (IsReady(next_instr_id)) { if (IsReady(next_instr_id)) {
if (vec_instruction_[next_instr_id].GetPriority() == Priority::kLowest) { reserved_next_ops->push(next_instr_id);
reserved_next_ops->push_back(next_instr_id);
} else {
reserved_next_ops->push_front(next_instr_id);
}
} }
} }
} }
void InterpreterCore::RunInstructionAsync(size_t instr_id) { void InterpreterCore::RunInstructionAsync(size_t instr_id) {
std::deque<size_t> ready_ops; // NOTE(Ruibiao): Due to the uncertain order in multi-threading asynchronous
ready_ops.push_back(instr_id); // scheduling, the priority order involved cross-thread scheduling is not
// guaranteed. Only Ops scheduled by the same AddTask call have the guarantee
// of priority order.
SchedulingQueue ready_ops(instruction_prority_less);
ready_ops.push(instr_id);
while (!ready_ops.empty()) { while (!ready_ops.empty()) {
instr_id = ready_ops.front(); instr_id = ready_ops.top();
ready_ops.pop_front(); ready_ops.pop();
auto& instr_node = vec_instruction_.at(instr_id); auto& instr_node = vec_instruction_.at(instr_id);
RunInstruction(instr_node); RunInstruction(instr_node);
...@@ -1330,24 +1348,24 @@ void InterpreterCore::AnalyseExecuteOrderForTrace() { ...@@ -1330,24 +1348,24 @@ void InterpreterCore::AnalyseExecuteOrderForTrace() {
}; };
std::vector<size_t> trace_order; std::vector<size_t> trace_order;
std::deque<size_t> ready_ops; SchedulingQueue ready_ops(instruction_prority_less);
for (size_t instr_id = 0; instr_id < dependecy_count_.size(); ++instr_id) { for (size_t instr_id = 0; instr_id < dependecy_count_.size(); ++instr_id) {
if (dependecy_count_[instr_id] == 0) { if (dependecy_count_[instr_id] == 0) {
ready_ops.push_back(instr_id); ready_ops.push(instr_id);
} }
} }
while (!ready_ops.empty()) { while (!ready_ops.empty()) {
auto now_id = ready_ops.front(); size_t now_id = ready_ops.top();
ready_ops.pop_front(); ready_ops.pop();
trace_order.push_back(now_id); trace_order.push_back(now_id);
auto next_op_set = op_downstream_map[now_id]; auto next_op_set = op_downstream_map[now_id];
for (size_t next_op_id : next_op_set) { for (size_t next_op_id : next_op_set) {
if (IsReady(next_op_id)) { if (IsReady(next_op_id)) {
ready_ops.push_back(next_op_id); ready_ops.push(next_op_id);
} }
} }
} }
......
...@@ -78,6 +78,10 @@ class InterpreterCore { ...@@ -78,6 +78,10 @@ class InterpreterCore {
const platform::Place& GetPlace() const { return place_; } const platform::Place& GetPlace() const { return place_; }
private: private:
using InstructionPriorityLess = std::function<bool(size_t, size_t)>;
using SchedulingQueue =
std::priority_queue<size_t, std::vector<size_t>, InstructionPriorityLess>;
// build graph // build graph
void Convert(std::vector<paddle::framework::OpFuncNode>* op_func_nodes); void Convert(std::vector<paddle::framework::OpFuncNode>* op_func_nodes);
void BuildOperatorDependences(); void BuildOperatorDependences();
...@@ -97,7 +101,7 @@ class InterpreterCore { ...@@ -97,7 +101,7 @@ class InterpreterCore {
void RunInstructionAsync(size_t instr_id); void RunInstructionAsync(size_t instr_id);
void RunInstruction(const Instruction& instr_node); void RunInstruction(const Instruction& instr_node);
void RunNextInstructions(const Instruction& instr_id, void RunNextInstructions(const Instruction& instr_id,
std::deque<size_t>* reserved_next_ops); SchedulingQueue* reserved_next_ops);
void RunOperator(const Instruction& instr_node); void RunOperator(const Instruction& instr_node);
// Trace // Trace
void TraceInstructionList(const std::vector<Instruction>& vec_instr); void TraceInstructionList(const std::vector<Instruction>& vec_instr);
...@@ -170,6 +174,8 @@ class InterpreterCore { ...@@ -170,6 +174,8 @@ class InterpreterCore {
// used for Trace // used for Trace
int64_t sync_op_num_{-1}; int64_t sync_op_num_{-1};
std::vector<size_t> trace_execute_order_; std::vector<size_t> trace_execute_order_;
InstructionPriorityLess instruction_prority_less;
}; };
std::shared_ptr<InterpreterCore> CreateInterpreterCore( std::shared_ptr<InterpreterCore> CreateInterpreterCore(
......
...@@ -670,13 +670,11 @@ void VariableScope::CheckExist(const std::string& name) const { ...@@ -670,13 +670,11 @@ void VariableScope::CheckExist(const std::string& name) const {
Instruction::Instruction(size_t id, Instruction::Instruction(size_t id,
OpFuncNode&& op_func_node, OpFuncNode&& op_func_node,
const platform::DeviceContext& dev_ctx, const platform::DeviceContext& dev_ctx)
const Priority priority)
: is_artificial_(op_func_node.operator_base_->Type() == "depend"), : is_artificial_(op_func_node.operator_base_->Type() == "depend"),
id_(id), id_(id),
op_func_node_(op_func_node), op_func_node_(op_func_node),
dev_ctx_(dev_ctx), dev_ctx_(dev_ctx) {
priority_(priority) {
PADDLE_ENFORCE_GE(id, PADDLE_ENFORCE_GE(id,
0, 0,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
......
...@@ -32,6 +32,8 @@ namespace framework { ...@@ -32,6 +32,8 @@ namespace framework {
using OpKernelComputeFunc = std::function<void(const ExecutionContext&)>; using OpKernelComputeFunc = std::function<void(const ExecutionContext&)>;
using Priority = int64_t;
constexpr const char* kCoalesceTensor = "coalesce_tensor"; constexpr const char* kCoalesceTensor = "coalesce_tensor";
// stream types // stream types
...@@ -42,8 +44,6 @@ constexpr const char* kH2DStream = "H2DStream"; ...@@ -42,8 +44,6 @@ constexpr const char* kH2DStream = "H2DStream";
constexpr int kEmptyVarIndex = 0; constexpr int kEmptyVarIndex = 0;
enum class Priority { kLowest, kNormal };
class InterpretercoreInferShapeContext : public InferShapeContext { class InterpretercoreInferShapeContext : public InferShapeContext {
public: public:
InterpretercoreInferShapeContext(const OperatorBase& op, InterpretercoreInferShapeContext(const OperatorBase& op,
...@@ -263,29 +263,30 @@ enum class OpFuncType { ...@@ -263,29 +263,30 @@ enum class OpFuncType {
class RuntimeInferShapeContext; class RuntimeInferShapeContext;
struct OpFuncNode { struct OpFuncNode {
// TODO(zhiqiu): Better make it unique_ptr // fit for phi kernel
std::shared_ptr<OperatorBase> operator_base_; phi::Kernel* phi_kernel_{nullptr}; // not owned
std::string execution_stream_{kDefaultStream}; platform::DeviceContext* dev_ctx_; // not owned
std::map<std::string, std::vector<int>> input_index;
std::map<std::string, std::vector<int>> output_index;
std::map<int, int> inplace_back_map; std::map<int, int> inplace_back_map;
OpKernelComputeFunc kernel_func_; std::map<std::string, std::vector<int>> input_index;
platform::DeviceContext* dev_ctx_; // not owned std::map<std::string, std::vector<int>> output_index;
// fit for phi kernel // TODO(zhiqiu): Better make it unique_ptr
phi::Kernel* phi_kernel_{nullptr}; // not owned std::shared_ptr<OperatorBase> operator_base_;
std::string execution_stream_{kDefaultStream};
OpFuncType type_; OpFuncType type_;
OpKernelComputeFunc kernel_func_;
Priority priority_{0}; // lower value, higher priority
}; };
class Instruction { class Instruction {
public: public:
Instruction(size_t id, Instruction(size_t id,
OpFuncNode&& op_func_node, OpFuncNode&& op_func_node,
const platform::DeviceContext& dev_ctx, const platform::DeviceContext& dev_ctx);
const Priority priority);
bool IsArtificial() const { return is_artificial_; } bool IsArtificial() const { return is_artificial_; }
...@@ -368,7 +369,7 @@ class Instruction { ...@@ -368,7 +369,7 @@ class Instruction {
void ClearInplace(); void ClearInplace();
Priority GetPriority() const { return priority_; } Priority GetPriority() const { return op_func_node_.priority_; }
private: private:
bool is_artificial_; // Instruction is artificial means that it is only used bool is_artificial_; // Instruction is artificial means that it is only used
...@@ -384,7 +385,6 @@ class Instruction { ...@@ -384,7 +385,6 @@ class Instruction {
OpFuncNode op_func_node_; OpFuncNode op_func_node_;
const platform::DeviceContext& dev_ctx_; // not owned const platform::DeviceContext& dev_ctx_; // not owned
const Priority priority_;
std::shared_ptr<RuntimeContext> runtime_ctx_; std::shared_ptr<RuntimeContext> runtime_ctx_;
std::shared_ptr<InterpretercoreInferShapeContext> infershape_ctx_; std::shared_ptr<InterpretercoreInferShapeContext> infershape_ctx_;
......
...@@ -226,6 +226,9 @@ void BindAutoParallel(py::module *m) { ...@@ -226,6 +226,9 @@ void BindAutoParallel(py::module *m) {
.def_property("execution_stream", .def_property("execution_stream",
&OperatorDistAttr::execution_stream, &OperatorDistAttr::execution_stream,
&OperatorDistAttr::set_execution_stream) &OperatorDistAttr::set_execution_stream)
.def_property("scheduling_priority",
&OperatorDistAttr::scheduling_priority,
&OperatorDistAttr::set_scheduling_priority)
.def_property("annotated", .def_property("annotated",
&OperatorDistAttr::annotated, &OperatorDistAttr::annotated,
&OperatorDistAttr::set_annotated) &OperatorDistAttr::set_annotated)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
from paddle import static
paddle.enable_static()
class TestOpPriority(unittest.TestCase):
    """Checks that per-Op ``scheduling_priority`` is honoured by the executor."""

    def test_op_priority(self):
        """A higher-priority increment must run before a lower-priority one.

        x and y share the same data, which is initialized to 0. The shared
        data is read and written by two concurrent Ops, increment(x) and
        increment(y). Under plain sequential scheduling the result of
        increment(x) would be 1 and that of increment(y) would be 2.
        However, increment(y) is given a higher priority (lower value) than
        increment(x), so the result of increment(y) is expected to be 1.
        """
        program = static.Program()
        with static.program_guard(program):
            x = paddle.zeros(shape=[1], dtype='int32')
            block = program.global_block()

            # Make y an alias of x so both increments touch the same buffer.
            y = block.create_var(dtype='int32')
            block.append_op(
                type='share_data', inputs={'X': x.name}, outputs={'Out': y.name}
            )

            # Lower value means higher priority: increment(y) goes first.
            paddle.increment(x)
            block.ops[-1].dist_attr.scheduling_priority = 1
            paddle.increment(y)
            block.ops[-1].dist_attr.scheduling_priority = -1

            # Note that priority order across threads is not guaranteed in
            # the standalone executor. As fetch(y) is scheduled in a
            # different thread from increment(x), they are not scheduled in
            # priority order. To make sure that fetch(y) is scheduled before
            # increment(x), serial_run is enabled here as a workaround.
            #
            # The flag is process-wide, so remember its previous value and
            # restore it when the test finishes (pass or fail); otherwise it
            # would leak into tests that run later in the same process.
            serial_run_flag = 'FLAGS_new_executor_serial_run'
            previous_flags = paddle.get_flags([serial_run_flag])
            self.addCleanup(paddle.set_flags, previous_flags)
            paddle.set_flags({serial_run_flag: 1})

            exe = static.Executor()

            # Currently, priority scheduling is not supported in the first
            # step, which builds the Op list by running kernels. Remove the
            # first run here once static-build without kernel running is
            # supported.
            result = exe.run(program, fetch_list=[y])
            result = exe.run(program, fetch_list=[y])
            self.assertEqual(result[0], 1)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册