未验证 提交 f5497fd0 编写于 作者: L lzydev 提交者: GitHub

Add attributes to support to analyse the stream across interpreters (#56814)

* fix static_build for pp

* add mannual_event to support streams across progs

* revert static_build.sh

* fix coverage-ci

* modify the method to name events

* change code according to review
上级 04332fa4
...@@ -120,6 +120,9 @@ void OperatorDistAttr::copy_from(const OperatorDistAttr& dist_attr) { ...@@ -120,6 +120,9 @@ void OperatorDistAttr::copy_from(const OperatorDistAttr& dist_attr) {
set_is_recompute(dist_attr.is_recompute()); set_is_recompute(dist_attr.is_recompute());
set_execution_stream(dist_attr.execution_stream()); set_execution_stream(dist_attr.execution_stream());
set_stream_priority(dist_attr.stream_priority()); set_stream_priority(dist_attr.stream_priority());
set_force_record_event(dist_attr.force_record_event());
set_event_to_record(dist_attr.event_to_record());
set_events_to_wait(dist_attr.events_to_wait());
set_scheduling_priority(dist_attr.scheduling_priority()); set_scheduling_priority(dist_attr.scheduling_priority());
set_annotated(dist_attr.annotated()); set_annotated(dist_attr.annotated());
} }
......
...@@ -141,6 +141,26 @@ class OperatorDistAttr { ...@@ -141,6 +141,26 @@ class OperatorDistAttr {
execution_stream_ = execution_stream; execution_stream_ = execution_stream;
} }
void set_event_to_record(const std::string& event_name) {
event_to_record_ = event_name;
}
void set_force_record_event(bool force_record_event) {
force_record_event_ = force_record_event;
}
void set_events_to_wait(const std::vector<std::string>& events_to_wait) {
events_to_wait_ = events_to_wait;
}
bool force_record_event() const { return force_record_event_; }
const std::string& event_to_record() const { return event_to_record_; }
const std::vector<std::string>& events_to_wait() const {
return events_to_wait_;
}
int stream_priority() const { return stream_priority_; } int stream_priority() const { return stream_priority_; }
void set_stream_priority(int stream_priority) { void set_stream_priority(int stream_priority) {
...@@ -204,6 +224,11 @@ class OperatorDistAttr { ...@@ -204,6 +224,11 @@ class OperatorDistAttr {
void parse_from_string(const std::string& data); void parse_from_string(const std::string& data);
static std::string unique_name(std::string key) {
static std::atomic<int> id_{0};
return key + "_" + std::to_string(id_++);
}
private: private:
static std::vector<std::string> fields_; static std::vector<std::string> fields_;
std::map<std::string, TensorDistAttr> input_dist_attrs_; std::map<std::string, TensorDistAttr> input_dist_attrs_;
...@@ -214,6 +239,9 @@ class OperatorDistAttr { ...@@ -214,6 +239,9 @@ class OperatorDistAttr {
int64_t impl_idx_ = 0; int64_t impl_idx_ = 0;
bool is_recompute_ = false; bool is_recompute_ = false;
std::string execution_stream_ = kDefault; std::string execution_stream_ = kDefault;
bool force_record_event_ = false;
std::vector<std::string> events_to_wait_;
std::string event_to_record_ = unique_name("event"); // event_idx
int stream_priority_ = 0; // lower value, higher priority int stream_priority_ = 0; // lower value, higher priority
int64_t scheduling_priority_ = 0; // lower value, higher priority int64_t scheduling_priority_ = 0; // lower value, higher priority
std::map<std::string, bool> annotated_; std::map<std::string, bool> annotated_;
......
...@@ -656,6 +656,10 @@ void BuildOpFuncList(const platform::Place& place, ...@@ -656,6 +656,10 @@ void BuildOpFuncList(const platform::Place& place,
} }
op_func_node.stream_priority_ = dist_attr->stream_priority(); op_func_node.stream_priority_ = dist_attr->stream_priority();
op_func_node.scheduling_priority_ = dist_attr->scheduling_priority(); op_func_node.scheduling_priority_ = dist_attr->scheduling_priority();
// set mannual event information
op_func_node.force_record_event_ = dist_attr->force_record_event();
op_func_node.events_to_wait_ = dist_attr->events_to_wait();
op_func_node.event_to_record_ = dist_attr->event_to_record();
} else { } else {
if (interpreter::IsCommunicationOp(op)) { if (interpreter::IsCommunicationOp(op)) {
// NOTE(Ruibiao): Dispatching computation before communication improves // NOTE(Ruibiao): Dispatching computation before communication improves
......
...@@ -97,6 +97,15 @@ void StreamAnalyzer::ConstructEvents(std::vector<Instruction>* instructions) { ...@@ -97,6 +97,15 @@ void StreamAnalyzer::ConstructEvents(std::vector<Instruction>* instructions) {
platform::GenerateDeviceEventFlag()); platform::GenerateDeviceEventFlag());
recorder_instr.AddEventToRecord(device_event, recorder_instr.AddEventToRecord(device_event,
platform::kCUDA /*unused*/); platform::kCUDA /*unused*/);
// It means the event will be waited for other interpreter that the
// event name of a operator is not 'default'.
if (recorder_instr.OpFunc()->force_record_event_ == true &&
(*program_force_events_to_wait_)
.count(recorder_instr.OpFunc()->event_to_record_) == 0) {
(*program_force_events_to_wait_)[recorder_instr.OpFunc()
->event_to_record_] =
recorder_instr.EventToRecord();
}
instr2event.emplace(recorder_instr_id, device_event); instr2event.emplace(recorder_instr_id, device_event);
} }
...@@ -108,6 +117,65 @@ void StreamAnalyzer::ConstructEvents(std::vector<Instruction>* instructions) { ...@@ -108,6 +117,65 @@ void StreamAnalyzer::ConstructEvents(std::vector<Instruction>* instructions) {
} }
} }
} }
// NOTE(lizhiyu): The mannual event only support the program_interpreter to
// annalyze the streams across the sub_programs. construct mannual events to
// record
for (auto& instruction : *instructions) {
// create extra event to record
auto op_func_node = instruction.OpFunc();
if (op_func_node->force_record_event_ &&
instruction.EventToRecord() == nullptr) {
auto place = instruction.DeviceContext().GetPlace();
if (platform::is_gpu_place(place)) {
PADDLE_ENFORCE_NE(
op_func_node->event_to_record_,
"default",
phi::errors::InvalidArgument(
"If the attribute 'force_record_event_' of one "
"operator is 'true', the 'event_to_record_' of this "
"operator can not be 'default'. But the "
"'event_name' of the operator %s is 'default'.",
instruction.OpBase()->Type().c_str()));
PADDLE_ENFORCE_EQ(
(*program_force_events_to_wait_)
.find(op_func_node->event_to_record_),
(*program_force_events_to_wait_).end(),
phi::errors::InvalidArgument(
"The program_force_events_to_wait_ had the event "
"that belongs to the operator : %s before the operator create "
"the event, "
"This is is werid.",
instruction.OpBase()->Type().c_str()));
std::shared_ptr<DeviceEvent> device_event =
std::make_shared<DeviceEvent>(place,
platform::GenerateDeviceEventFlag());
instruction.AddEventToRecord(device_event, platform::kCUDA /*unused*/);
(*program_force_events_to_wait_)[op_func_node->event_to_record_] =
instruction.EventToRecord();
VLOG(6) << "Create mannual event: " << op_func_node->event_to_record_
<< " for the operator: " << instruction.OpBase()->Type();
}
}
// add extra mannual events
if (!(op_func_node->events_to_wait_.empty())) {
for (auto event_name : op_func_node->events_to_wait_) {
PADDLE_ENFORCE_NE(
(*program_force_events_to_wait_).find(event_name),
(*program_force_events_to_wait_).end(),
phi::errors::InvalidArgument(
"The program_force_events_to_wait_ don't have the event %s "
"for the operator: %s to wait. The event should had been "
"created by the operator "
"whose event_to_record_ is %s.",
event_name.c_str(),
instruction.OpBase()->Type().c_str(),
event_name.c_str()));
instruction.AddEventToWait(
(*program_force_events_to_wait_)[event_name].get());
}
}
}
} }
DeviceContext* StreamAnalyzer::ParseDeviceContext( DeviceContext* StreamAnalyzer::ParseDeviceContext(
......
...@@ -86,6 +86,12 @@ class StreamAnalyzer { ...@@ -86,6 +86,12 @@ class StreamAnalyzer {
void ShareEventInfoFrom(const StreamAnalyzer& src); void ShareEventInfoFrom(const StreamAnalyzer& src);
void SetForceEventsToWaitInfo(
std::unordered_map<std::string, std::shared_ptr<EventInter>>*
program_force_events_to_wait) {
program_force_events_to_wait_ = program_force_events_to_wait;
}
std::shared_ptr< std::shared_ptr<
std::map<const DeviceContext*, std::map<size_t, std::set<size_t>>>> std::map<const DeviceContext*, std::map<size_t, std::set<size_t>>>>
GetEventInfo() const; GetEventInfo() const;
...@@ -114,6 +120,8 @@ class StreamAnalyzer { ...@@ -114,6 +120,8 @@ class StreamAnalyzer {
std::shared_ptr< std::shared_ptr<
std::map<const DeviceContext*, std::map<size_t, std::set<size_t>>>> std::map<const DeviceContext*, std::map<size_t, std::set<size_t>>>>
event_info_; event_info_;
std::unordered_map<std::string, std::shared_ptr<EventInter>>*
program_force_events_to_wait_; // not owned
}; };
/// ======================== /// /// ======================== ///
......
...@@ -167,6 +167,9 @@ struct OpFuncNode { ...@@ -167,6 +167,9 @@ struct OpFuncNode {
// TODO(zhiqiu): Better make it unique_ptr // TODO(zhiqiu): Better make it unique_ptr
std::shared_ptr<OperatorBase> operator_base_{nullptr}; std::shared_ptr<OperatorBase> operator_base_{nullptr};
std::string execution_stream_{kDefaultStream}; std::string execution_stream_{kDefaultStream};
bool force_record_event_{false};
std::vector<std::string> events_to_wait_;
std::string event_to_record_{"default"};
OpFuncType type_; OpFuncType type_;
OpKernelComputeFunc kernel_func_; OpKernelComputeFunc kernel_func_;
...@@ -212,10 +215,18 @@ class Instruction { ...@@ -212,10 +215,18 @@ class Instruction {
events_to_wait_.emplace_back(instr_id, event, waiter_type); events_to_wait_.emplace_back(instr_id, event, waiter_type);
} }
void AddEventToWait(const EventInter* event_inter) {
events_to_wait_.push_back(*event_inter);
}
const std::vector<EventInter>& EventsToWait() const { const std::vector<EventInter>& EventsToWait() const {
return events_to_wait_; return events_to_wait_;
} }
const std::shared_ptr<EventInter>& EventToRecord() const {
return event_to_record_;
}
void AddNextInstrInDifferentThread(size_t id) { void AddNextInstrInDifferentThread(size_t id) {
next_instrs_in_different_thread.push_back(id); next_instrs_in_different_thread.push_back(id);
} }
......
...@@ -615,6 +615,7 @@ void ProgramInterpreter::Convert( ...@@ -615,6 +615,7 @@ void ProgramInterpreter::Convert(
vec_instruction_.reserve(op_nums); vec_instruction_.reserve(op_nums);
for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) { for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) {
auto& op_func_node = nodes[op_idx]; auto& op_func_node = nodes[op_idx];
stream_analyzer_.SetForceEventsToWaitInfo(force_evnets_to_wait_);
auto* dev_ctx_ = stream_analyzer_.ParseDeviceContext(op_func_node); auto* dev_ctx_ = stream_analyzer_.ParseDeviceContext(op_func_node);
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (FLAGS_new_executor_use_cuda_graph) { if (FLAGS_new_executor_use_cuda_graph) {
......
...@@ -81,6 +81,17 @@ class ProgramInterpreter : public InterpreterBaseImpl { ...@@ -81,6 +81,17 @@ class ProgramInterpreter : public InterpreterBaseImpl {
hookfuncs_ = hookfuncs; hookfuncs_ = hookfuncs;
} }
std::unordered_map<std::string, std::shared_ptr<EventInter>>*
GetForceEventsToWaitInfo() {
return force_evnets_to_wait_;
}
void SetForceEventsToWaitInfo(
std::unordered_map<std::string, std::shared_ptr<EventInter>>*
force_evnets_to_wait) {
force_evnets_to_wait_ = force_evnets_to_wait;
}
bool IsStaticBuild() const { return static_build_; } bool IsStaticBuild() const { return static_build_; }
private: private:
...@@ -162,6 +173,9 @@ class ProgramInterpreter : public InterpreterBaseImpl { ...@@ -162,6 +173,9 @@ class ProgramInterpreter : public InterpreterBaseImpl {
ExecutionConfig execution_config_; ExecutionConfig execution_config_;
std::unordered_map<std::string, std::shared_ptr<EventInter>>*
force_evnets_to_wait_;
VariableScope var_scope_; VariableScope var_scope_;
Scope* local_scope_{nullptr}; // not owned Scope* local_scope_{nullptr}; // not owned
......
...@@ -41,6 +41,7 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place, ...@@ -41,6 +41,7 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
Scope* scope) Scope* scope)
: place_(place), plan_(plan), scope_(scope) { : place_(place), plan_(plan), scope_(scope) {
int64_t micro_batch_num = plan_.MicroBatchNum(); int64_t micro_batch_num = plan_.MicroBatchNum();
vec_force_events_to_wait_.resize(micro_batch_num);
for (int64_t i = 0; i < micro_batch_num; ++i) { for (int64_t i = 0; i < micro_batch_num; ++i) {
micro_batch_scopes_.emplace_back(&scope->NewScope()); micro_batch_scopes_.emplace_back(&scope->NewScope());
} }
...@@ -129,6 +130,14 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place, ...@@ -129,6 +130,14 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
micro_batch_scopes_[micro_batch_id], micro_batch_scopes_[micro_batch_id],
execution_config)); execution_config));
interpretercores_.back()->SetCopyProgram(program); interpretercores_.back()->SetCopyProgram(program);
// Note(lizhiyu): Add mannual event info
auto prog_inter = const_cast<ProgramInterpreter*>(
static_cast<const ProgramInterpreter*>(
interpretercores_.back()->Impl()));
prog_inter->SetForceEventsToWaitInfo(
&(vec_force_events_to_wait_[micro_batch_id]));
// NOTE(lizhiyu): Now we only check backward subprogram. After static // NOTE(lizhiyu): Now we only check backward subprogram. After static
// build strategy is completely, we should // build strategy is completely, we should
// check all the program in the PP strategy. // check all the program in the PP strategy.
...@@ -181,6 +190,7 @@ paddle::framework::FetchList StandaloneExecutor::Run( ...@@ -181,6 +190,7 @@ paddle::framework::FetchList StandaloneExecutor::Run(
VLOG(6) << "Run job (" << job_idx << "), type = " << job_type VLOG(6) << "Run job (" << job_idx << "), type = " << job_type
<< ", micro_batch_id =" << job->MicroBatchId(); << ", micro_batch_id =" << job->MicroBatchId();
// Note(sonder): Share build results don't work for new IR now. // Note(sonder): Share build results don't work for new IR now.
if (type_to_first_id.count(job_type) != 0 && if (type_to_first_id.count(job_type) != 0 &&
!FLAGS_enable_new_ir_in_executor) { !FLAGS_enable_new_ir_in_executor) {
......
...@@ -52,6 +52,9 @@ class StandaloneExecutor { ...@@ -52,6 +52,9 @@ class StandaloneExecutor {
Scope* scope_; Scope* scope_;
std::vector<std::string> fetch_var_names_; std::vector<std::string> fetch_var_names_;
std::vector<std::unordered_map<std::string, std::shared_ptr<EventInter>>>
vec_force_events_to_wait_;
}; };
} // namespace framework } // namespace framework
......
...@@ -441,6 +441,16 @@ void BindAutoParallel(py::module *m) { ...@@ -441,6 +441,16 @@ void BindAutoParallel(py::module *m) {
.def_property("scheduling_priority", .def_property("scheduling_priority",
&OperatorDistAttr::scheduling_priority, &OperatorDistAttr::scheduling_priority,
&OperatorDistAttr::set_scheduling_priority) &OperatorDistAttr::set_scheduling_priority)
.def_property("force_record_event",
&OperatorDistAttr::force_record_event,
&OperatorDistAttr::set_force_record_event)
.def_property("events_to_wait",
&OperatorDistAttr::events_to_wait,
&OperatorDistAttr::set_events_to_wait,
pybind11::return_value_policy::reference)
.def_property("event_to_record",
&OperatorDistAttr::event_to_record,
&OperatorDistAttr::set_event_to_record)
.def_property("annotated", .def_property("annotated",
&OperatorDistAttr::annotated, &OperatorDistAttr::annotated,
&OperatorDistAttr::set_annotated) &OperatorDistAttr::set_annotated)
......
...@@ -525,3 +525,20 @@ def _program_for_fthenb_and_1f1b(program): ...@@ -525,3 +525,20 @@ def _program_for_fthenb_and_1f1b(program):
# It MUST return in this order # It MUST return in this order
return [lr_prog, fwd_prog, bwd_prog, opt_prog] return [lr_prog, fwd_prog, bwd_prog, opt_prog]
def _add_event_dependency(recorder_op_desc, waiter_op_desc):
'''
Add the extra event dependcy of the two operators.
This function mainly aims for the cross-programs in pipeline parallelism,
especial for the 'send_v2' 'recv_v2' etc.
'''
if not recorder_op_desc.dist_attr.force_record_event:
recorder_op_desc.dist_attr.force_record_event = True
# NOTE(lizhiyu): Here is the copy of 'waiter_op_desc.dist_attr.events_to_wait' not the reference,
# because the type of 'events_to_wait' is 'const vector<string>&' while the type of
# 'waiter_wait_list' is python list.
waiter_wait_list = waiter_op_desc.dist_attr.events_to_wait
if recorder_op_desc.dist_attr.event_to_record not in waiter_wait_list:
waiter_wait_list.append(recorder_op_desc.dist_attr.event_to_record)
waiter_op_desc.dist_attr.events_to_wait = waiter_wait_list
...@@ -2,6 +2,7 @@ file( ...@@ -2,6 +2,7 @@ file(
GLOB TEST_INTERP_CASES GLOB TEST_INTERP_CASES
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
"test_*.py") "test_*.py")
list(REMOVE_ITEM TEST_INTERP_CASES "test_standalone_custom_event.py")
string(REPLACE ".py" "" TEST_INTERP_CASES "${TEST_INTERP_CASES}") string(REPLACE ".py" "" TEST_INTERP_CASES "${TEST_INTERP_CASES}")
foreach(target ${TEST_INTERP_CASES}) foreach(target ${TEST_INTERP_CASES})
...@@ -31,8 +32,8 @@ py_test_modules( ...@@ -31,8 +32,8 @@ py_test_modules(
# These UTs are to temporarily test static build for standalone_executor, will be removed after static build is enabled by default. # These UTs are to temporarily test static build for standalone_executor, will be removed after static build is enabled by default.
set(STATIC_BUILD_TESTS set(STATIC_BUILD_TESTS
test_standalone_controlflow test_standalone_cuda_graph_multi_stream test_standalone_controlflow test_standalone_cuda_graph_multi_stream
test_standalone_custom_stream test_standalone_executor test_standalone_custom_stream test_standalone_custom_event
test_standalone_multiply_write) test_standalone_executor test_standalone_multiply_write)
foreach(STATIC_BUILD_TEST ${STATIC_BUILD_TESTS}) foreach(STATIC_BUILD_TEST ${STATIC_BUILD_TESTS})
py_test_modules( py_test_modules(
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
from paddle.distributed.passes.pass_utils import (
_add_event_dependency,
get_skip_gc_vars,
split_program,
)
from paddle.fluid import core
from paddle.fluid.executor import _add_feed_fetch_ops, _StandaloneExecutor
paddle.enable_static()
def build_program():
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
# data -> [matmul] -> out ->[add] -> add_out
with paddle.static.device_guard('gpu'):
data = paddle.ones([1024, 2048], dtype='float32', name='data')
weight = paddle.randn([2048, 2048], name='weight') # gpu
matmul_out = paddle.matmul(data, weight, name='matmul_out') # gpus
bias = paddle.ones([1024, 2048], dtype='float32', name='bias')
add_out = paddle.add(matmul_out, bias, name='add_out')
# add_out -> [sub] -> sub_out -> [tanh] -> tanh_out
sub_out = paddle.subtract(add_out, data, name='sub_out')
tanh_out = paddle.tanh(sub_out, name='tanh_out')
bias_1 = paddle.add(bias, sub_out, name='bias_1')
out_before = paddle.tanh(bias_1, name='out_before')
out_last = paddle.subtract(tanh_out, data, name='out_last')
out_last2 = paddle.matmul(out_last, weight, name="matmul_2_out")
out = paddle.add(out_before, out_last2, name='out')
mean = paddle.mean(out, name='mean_out')
return main_program, startup_program, [mean]
class TestMannulEvent(unittest.TestCase):
"""
fill_constant(def) gaussian_random(def)
| | | |
| | matmul_v2(s1) fill_constant(def)
| | | | | |
| | elementwise_add(s1) |
| | | |
| elementwise_sub(s1) |
| | | |
| tanh(s1) elementwise_add(s1)
| | |
elementwise_sub(s1) tanh(s1)
| |
matmul_v2(s1) |
| | ---split prog----
elementwise_add(s2)
|
reduce_mean(s2)
"""
def setUp(self):
self.steps = 3
self.place_desc = (
paddle.CUDAPlace(0)
if core.is_compiled_with_cuda()
else paddle.CPUPlace()
)
self.place = core.Place()
self.place.set_place(self.place_desc)
def set_custom_stream(self, prog):
op_index_for_stream1 = [2, 4, 5, 6, 7, 8, 9, 10]
op_index_for_stream2 = [11, 12]
ops = prog.global_block().ops
for op_index in op_index_for_stream1:
ops[op_index].dist_attr.execution_stream = "s1"
ops[op_index].dist_attr.stream_priority = 0
for op_index in op_index_for_stream2:
ops[op_index].dist_attr.execution_stream = "s2"
ops[op_index].dist_attr.stream_priority = -1
def split_program(self, prog, apply_mannual_event=False):
# split two subprograms
waiter_recorder_events_map = {11: [8, 10]}
prog_block = prog.global_block()
ops = prog_block.ops
if apply_mannual_event:
for waiter, recorders in waiter_recorder_events_map.items():
for recorder in recorders:
_add_event_dependency(ops[recorder].desc, ops[waiter].desc)
main_progs, _, _ = split_program(prog, [11])
return main_progs
def create_standalone_exe(self, main_progs, startup_progs, fetch_list):
micro_batch_num = 1
micro_batch_id = 0
job_list = []
prog_num = len(main_progs)
fetch_op_num = len(fetch_list)
skip_gc_vars = get_skip_gc_vars(main_progs)
if prog_num == 1: # single prog
main_progs[0] = _add_feed_fetch_ops(
main_progs[0],
[],
fetch_list,
"feed",
"fetch",
use_fetch_v2=True,
)
op_num = len(main_progs[0].block(0).ops)
fetch_op_indics = list(range(op_num - fetch_op_num, op_num))
else:
main_progs[-1] = _add_feed_fetch_ops(
main_progs[-1],
[],
fetch_list,
"feed",
"fetch",
use_fetch_v2=True,
)
op_num = len(main_progs[-1].block(0).ops)
fetch_op_indics = list(range(op_num - fetch_op_num, op_num))
# create jobs
for program_id in range(prog_num):
job = core.Job(f"prog_{program_id}")
job.set_skip_gc_vars(skip_gc_vars[program_id])
# Set col_attr info for fetch_op to fetch the correct data after running multiple micro batch
if program_id == prog_num - 1:
for i in range(fetch_op_num):
job.set_col_attr_for_fetch_op(
fetch_op_indics[i],
i * micro_batch_num + micro_batch_id,
)
job_list.append(job)
type_to_program = {}
for program_id in range(prog_num):
type_to_program[f"prog_{program_id}"] = main_progs[program_id].desc
plan = core.Plan(job_list, type_to_program)
scope = core.Scope()
main_exe = _StandaloneExecutor(self.place, plan, scope)
return main_exe
def run_program(
self,
apply_custom_stream=False,
split_prog=False,
apply_mannual_event=False,
):
paddle.seed(2022)
main_program, startup_program, fetch_list = build_program()
self.assertEqual(len(startup_program.global_block().ops), 0)
if apply_custom_stream:
self.set_custom_stream(main_program)
main_progs = [main_program]
startup_progs = [startup_program]
if apply_custom_stream and split_prog:
main_progs = self.split_program(main_program, apply_mannual_event)
outs = []
exe = self.create_standalone_exe(main_progs, startup_progs, fetch_list)
for i in range(self.steps):
outs.append(exe.run(feed_names=[]))
return outs
def test_result(self):
if not core.is_compiled_with_cuda():
return
baselines = self.run_program()
stream_outs = self.run_program(apply_custom_stream=True)
split_outs = self.run_program(apply_custom_stream=True, split_prog=True)
mannual_outs = self.run_program(
apply_custom_stream=True, split_prog=True, apply_mannual_event=True
)
for bl, out0, out1, out2 in zip(
baselines, stream_outs, split_outs, mannual_outs
):
self.assertEqual(bl[0], out0[0])
self.assertEqual(bl[0], out2[0])
# self.assertNotEqual(bl[0], out1[0])
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册