diff --git a/paddle/fluid/distributed/auto_parallel/dist_attr.cc b/paddle/fluid/distributed/auto_parallel/dist_attr.cc index 5ba9a700e3e29221014979c239af92ca15a74af1..b7bb47b3b859e9a8cc18a9a16bdaa891f41cb0e4 100644 --- a/paddle/fluid/distributed/auto_parallel/dist_attr.cc +++ b/paddle/fluid/distributed/auto_parallel/dist_attr.cc @@ -120,6 +120,9 @@ void OperatorDistAttr::copy_from(const OperatorDistAttr& dist_attr) { set_is_recompute(dist_attr.is_recompute()); set_execution_stream(dist_attr.execution_stream()); set_stream_priority(dist_attr.stream_priority()); + set_force_record_event(dist_attr.force_record_event()); + set_event_to_record(dist_attr.event_to_record()); + set_events_to_wait(dist_attr.events_to_wait()); set_scheduling_priority(dist_attr.scheduling_priority()); set_annotated(dist_attr.annotated()); } diff --git a/paddle/fluid/distributed/auto_parallel/dist_attr.h b/paddle/fluid/distributed/auto_parallel/dist_attr.h index 3d71c0fd75a4c0ddae1e1e2d8bfe9c8f6a342700..347c7fc05dfa02bc94dea9f101f4072606b937ac 100644 --- a/paddle/fluid/distributed/auto_parallel/dist_attr.h +++ b/paddle/fluid/distributed/auto_parallel/dist_attr.h @@ -141,6 +141,26 @@ class OperatorDistAttr { execution_stream_ = execution_stream; } + void set_event_to_record(const std::string& event_name) { + event_to_record_ = event_name; + } + + void set_force_record_event(bool force_record_event) { + force_record_event_ = force_record_event; + } + + void set_events_to_wait(const std::vector& events_to_wait) { + events_to_wait_ = events_to_wait; + } + + bool force_record_event() const { return force_record_event_; } + + const std::string& event_to_record() const { return event_to_record_; } + + const std::vector& events_to_wait() const { + return events_to_wait_; + } + int stream_priority() const { return stream_priority_; } void set_stream_priority(int stream_priority) { @@ -204,6 +224,11 @@ class OperatorDistAttr { void parse_from_string(const std::string& data); + static std::string unique_name(std::string key) { + static std::atomic id_{0}; + return key + "_" + std::to_string(id_++); + } + private: static std::vector fields_; std::map input_dist_attrs_; @@ -214,6 +239,9 @@ class OperatorDistAttr { int64_t impl_idx_ = 0; bool is_recompute_ = false; std::string execution_stream_ = kDefault; + bool force_record_event_ = false; + std::vector events_to_wait_; + std::string event_to_record_ = unique_name("event"); // event_idx int stream_priority_ = 0; // lower value, higher priority int64_t scheduling_priority_ = 0; // lower value, higher priority std::map annotated_; diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index bbb3a66bac23332754bc21b5efbfc3ffe005fee9..be87549ecbfbc17782776ffc6f08be33f0f5d049 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -656,6 +656,10 @@ void BuildOpFuncList(const platform::Place& place, } op_func_node.stream_priority_ = dist_attr->stream_priority(); op_func_node.scheduling_priority_ = dist_attr->scheduling_priority(); + // set mannual event information + op_func_node.force_record_event_ = dist_attr->force_record_event(); + op_func_node.events_to_wait_ = dist_attr->events_to_wait(); + op_func_node.event_to_record_ = dist_attr->event_to_record(); } else { if (interpreter::IsCommunicationOp(op)) { // NOTE(Ruibiao): Dispatching computation before communication improves diff --git a/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc b/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc index 3fc9dde363039ac8ce4dab7eaf96a0cc1fc02995..27ac1681a4008564787651370ae0134a57797575 100644 --- a/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc +++ b/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc @@ -97,6 +97,15 @@ void StreamAnalyzer::ConstructEvents(std::vector* instructions) { platform::GenerateDeviceEventFlag()); recorder_instr.AddEventToRecord(device_event, platform::kCUDA /*unused*/); + // It means the event will be waited for other interpreter that the + // event name of a operator is not 'default'. + if (recorder_instr.OpFunc()->force_record_event_ == true && + (*program_force_events_to_wait_) + .count(recorder_instr.OpFunc()->event_to_record_) == 0) { + (*program_force_events_to_wait_)[recorder_instr.OpFunc() + ->event_to_record_] = + recorder_instr.EventToRecord(); + } instr2event.emplace(recorder_instr_id, device_event); } @@ -108,6 +117,65 @@ void StreamAnalyzer::ConstructEvents(std::vector* instructions) { } } } + // NOTE(lizhiyu): The mannual event only support the program_interpreter to + // annalyze the streams across the sub_programs. construct mannual events to + // record + for (auto& instruction : *instructions) { + // create extra event to record + auto op_func_node = instruction.OpFunc(); + if (op_func_node->force_record_event_ && + instruction.EventToRecord() == nullptr) { + auto place = instruction.DeviceContext().GetPlace(); + if (platform::is_gpu_place(place)) { + PADDLE_ENFORCE_NE( + op_func_node->event_to_record_, + "default", + phi::errors::InvalidArgument( + "If the attribute 'force_record_event_' of one " + "operator is 'true', the 'event_to_record_' of this " + "operator can not be 'default'. But the " + "'event_name' of the operator %s is 'default'.", + instruction.OpBase()->Type().c_str())); + PADDLE_ENFORCE_EQ( + (*program_force_events_to_wait_) + .find(op_func_node->event_to_record_), + (*program_force_events_to_wait_).end(), + phi::errors::InvalidArgument( + "The program_force_events_to_wait_ had the event " + "that belongs to the operator : %s before the operator create " + "the event, " + "This is is werid.", + instruction.OpBase()->Type().c_str())); + std::shared_ptr device_event = + std::make_shared(place, + platform::GenerateDeviceEventFlag()); + instruction.AddEventToRecord(device_event, platform::kCUDA /*unused*/); + (*program_force_events_to_wait_)[op_func_node->event_to_record_] = + instruction.EventToRecord(); + VLOG(6) << "Create mannual event: " << op_func_node->event_to_record_ + << " for the operator: " << instruction.OpBase()->Type(); + } + } + // add extra mannual events + if (!(op_func_node->events_to_wait_.empty())) { + for (auto event_name : op_func_node->events_to_wait_) { + PADDLE_ENFORCE_NE( + (*program_force_events_to_wait_).find(event_name), + (*program_force_events_to_wait_).end(), + phi::errors::InvalidArgument( + "The program_force_events_to_wait_ don't have the event %s " + "for the operator: %s to wait. The event should had been " + "created by the operator " + "whose event_to_record_ is %s.", + event_name.c_str(), + instruction.OpBase()->Type().c_str(), + event_name.c_str())); + + instruction.AddEventToWait( + (*program_force_events_to_wait_)[event_name].get()); + } + } + } } DeviceContext* StreamAnalyzer::ParseDeviceContext( diff --git a/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h b/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h index f8ba8103620a24e2a7b64f24605045c45368144f..8f2ee33ca4ed5addcccc3342bfb382760016047f 100644 --- a/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h +++ b/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h @@ -86,6 +86,12 @@ class StreamAnalyzer { void ShareEventInfoFrom(const StreamAnalyzer& src); + void SetForceEventsToWaitInfo( + std::unordered_map>* + program_force_events_to_wait) { + program_force_events_to_wait_ = program_force_events_to_wait; + } + std::shared_ptr< std::map>>> GetEventInfo() const; @@ -114,6 +120,8 @@ class StreamAnalyzer { std::shared_ptr< std::map>>> event_info_; + std::unordered_map>* + program_force_events_to_wait_; // not owned }; /// ======================== /// diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.h b/paddle/fluid/framework/new_executor/new_executor_defs.h index edce1554b0a90f3472ad4d650b4882420273d48e..bf0c0880f385d77b1318e9a2edcf05e673f8694e 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.h +++ b/paddle/fluid/framework/new_executor/new_executor_defs.h @@ -167,6 +167,9 @@ struct OpFuncNode { // TODO(zhiqiu): Better make it unique_ptr std::shared_ptr operator_base_{nullptr}; std::string execution_stream_{kDefaultStream}; + bool force_record_event_{false}; + std::vector events_to_wait_; + std::string event_to_record_{"default"}; OpFuncType type_; OpKernelComputeFunc kernel_func_; @@ -212,10 +215,18 @@ class Instruction { events_to_wait_.emplace_back(instr_id, event, waiter_type); } + void AddEventToWait(const EventInter* event_inter) { + events_to_wait_.push_back(*event_inter); + } + const std::vector& EventsToWait() const { return events_to_wait_; } + const std::shared_ptr& EventToRecord() const { + return event_to_record_; + } + void AddNextInstrInDifferentThread(size_t id) { next_instrs_in_different_thread.push_back(id); } diff --git a/paddle/fluid/framework/new_executor/program_interpreter.cc b/paddle/fluid/framework/new_executor/program_interpreter.cc index 43e5301476d9ba54bea0bf422a7275d9f4ac434e..29998e61f011126871fa3c5a5a0c22186b334632 100644 --- a/paddle/fluid/framework/new_executor/program_interpreter.cc +++ b/paddle/fluid/framework/new_executor/program_interpreter.cc @@ -615,6 +615,7 @@ void ProgramInterpreter::Convert( vec_instruction_.reserve(op_nums); for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) { auto& op_func_node = nodes[op_idx]; + stream_analyzer_.SetForceEventsToWaitInfo(force_evnets_to_wait_); auto* dev_ctx_ = stream_analyzer_.ParseDeviceContext(op_func_node); #ifdef PADDLE_WITH_CUDA if (FLAGS_new_executor_use_cuda_graph) { diff --git a/paddle/fluid/framework/new_executor/program_interpreter.h b/paddle/fluid/framework/new_executor/program_interpreter.h index 29ec71059c778d23576a5c1e833d276ef675b4d6..27348d57fcd173d01ace93866a7b4fc5b1623fb8 100644 --- a/paddle/fluid/framework/new_executor/program_interpreter.h +++ b/paddle/fluid/framework/new_executor/program_interpreter.h @@ -81,6 +81,17 @@ class ProgramInterpreter : public InterpreterBaseImpl { hookfuncs_ = hookfuncs; } + std::unordered_map>* + GetForceEventsToWaitInfo() { + return force_evnets_to_wait_; + } + + void SetForceEventsToWaitInfo( + std::unordered_map>* + force_evnets_to_wait) { + force_evnets_to_wait_ = force_evnets_to_wait; + } + bool IsStaticBuild() const { return static_build_; } private: @@ -162,6 +173,9 @@ class ProgramInterpreter : public InterpreterBaseImpl { ExecutionConfig execution_config_; + std::unordered_map>* + force_evnets_to_wait_; + VariableScope var_scope_; Scope* local_scope_{nullptr}; // not owned diff --git a/paddle/fluid/framework/new_executor/standalone_executor.cc b/paddle/fluid/framework/new_executor/standalone_executor.cc index 7d65fe33818a13579f8c46ac75301f0ee9b085b9..303a4f8478127064bc0d43ca87427bd6842a6e5c 100644 --- a/paddle/fluid/framework/new_executor/standalone_executor.cc +++ b/paddle/fluid/framework/new_executor/standalone_executor.cc @@ -41,6 +41,7 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place, Scope* scope) : place_(place), plan_(plan), scope_(scope) { int64_t micro_batch_num = plan_.MicroBatchNum(); + vec_force_events_to_wait_.resize(micro_batch_num); for (int64_t i = 0; i < micro_batch_num; ++i) { micro_batch_scopes_.emplace_back(&scope->NewScope()); } @@ -129,6 +130,14 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place, micro_batch_scopes_[micro_batch_id], execution_config)); interpretercores_.back()->SetCopyProgram(program); + + // Note(lizhiyu): Add mannual event info + auto prog_inter = const_cast( + static_cast( + interpretercores_.back()->Impl())); + prog_inter->SetForceEventsToWaitInfo( + &(vec_force_events_to_wait_[micro_batch_id])); + // NOTE(lizhiyu): Now we only check backward subprogram. After static // build strategy is completely, we should // check all the program in the PP strategy. @@ -181,6 +190,7 @@ paddle::framework::FetchList StandaloneExecutor::Run( VLOG(6) << "Run job (" << job_idx << "), type = " << job_type << ", micro_batch_id =" << job->MicroBatchId(); + // Note(sonder): Share build results don't work for new IR now. if (type_to_first_id.count(job_type) != 0 && !FLAGS_enable_new_ir_in_executor) { diff --git a/paddle/fluid/framework/new_executor/standalone_executor.h b/paddle/fluid/framework/new_executor/standalone_executor.h index 1da628fe27bb7900b06bc71484dc6effcfa8842d..bec52add981bfdcbda93b4ba6d14123c6a329df4 100644 --- a/paddle/fluid/framework/new_executor/standalone_executor.h +++ b/paddle/fluid/framework/new_executor/standalone_executor.h @@ -52,6 +52,9 @@ class StandaloneExecutor { Scope* scope_; std::vector fetch_var_names_; + + std::vector>> + vec_force_events_to_wait_; }; } // namespace framework diff --git a/paddle/fluid/pybind/auto_parallel_py.cc b/paddle/fluid/pybind/auto_parallel_py.cc index bdd467fbefa8bb0d4cc71792b8cd6156278bedef..0619ad319af9a55be1c22e40e48895225821c3ab 100644 --- a/paddle/fluid/pybind/auto_parallel_py.cc +++ b/paddle/fluid/pybind/auto_parallel_py.cc @@ -441,6 +441,16 @@ void BindAutoParallel(py::module *m) { .def_property("scheduling_priority", &OperatorDistAttr::scheduling_priority, &OperatorDistAttr::set_scheduling_priority) + .def_property("force_record_event", + &OperatorDistAttr::force_record_event, + &OperatorDistAttr::set_force_record_event) + .def_property("events_to_wait", + &OperatorDistAttr::events_to_wait, + &OperatorDistAttr::set_events_to_wait, + pybind11::return_value_policy::reference) + .def_property("event_to_record", + &OperatorDistAttr::event_to_record, + &OperatorDistAttr::set_event_to_record) .def_property("annotated", &OperatorDistAttr::annotated, &OperatorDistAttr::set_annotated) diff --git a/python/paddle/distributed/passes/pass_utils.py b/python/paddle/distributed/passes/pass_utils.py index ba502a043ab35323bd8a4cfdfc863123a08af504..89c9c65ba176172eae1efe34ed121cf0ed527cd7 100644 --- a/python/paddle/distributed/passes/pass_utils.py +++ b/python/paddle/distributed/passes/pass_utils.py @@ -525,3 +525,20 @@ def _program_for_fthenb_and_1f1b(program): # It MUST return in this order return [lr_prog, fwd_prog, bwd_prog, opt_prog] + + +def _add_event_dependency(recorder_op_desc, waiter_op_desc): + ''' + Add the extra event dependcy of the two operators. + This function mainly aims for the cross-programs in pipeline parallelism, + especial for the 'send_v2' 'recv_v2' etc. + ''' + if not recorder_op_desc.dist_attr.force_record_event: + recorder_op_desc.dist_attr.force_record_event = True + # NOTE(lizhiyu): Here is the copy of 'waiter_op_desc.dist_attr.events_to_wait' not the reference, + # because the type of 'events_to_wait' is 'const vector&' while the type of + # 'waiter_wait_list' is python list. + waiter_wait_list = waiter_op_desc.dist_attr.events_to_wait + if recorder_op_desc.dist_attr.event_to_record not in waiter_wait_list: + waiter_wait_list.append(recorder_op_desc.dist_attr.event_to_record) + waiter_op_desc.dist_attr.events_to_wait = waiter_wait_list diff --git a/test/standalone_executor/CMakeLists.txt b/test/standalone_executor/CMakeLists.txt index a2b55191d53fe14ed816634e6e6c60189a7c7e4b..1e351d176bb154ba91b5240f03d9675086015aa6 100644 --- a/test/standalone_executor/CMakeLists.txt +++ b/test/standalone_executor/CMakeLists.txt @@ -2,6 +2,7 @@ file( GLOB TEST_INTERP_CASES RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") +list(REMOVE_ITEM TEST_INTERP_CASES "test_standalone_custom_event.py") string(REPLACE ".py" "" TEST_INTERP_CASES "${TEST_INTERP_CASES}") foreach(target ${TEST_INTERP_CASES}) @@ -31,8 +32,8 @@ py_test_modules( # These UTs are to temporarily test static build for standalone_executor, will be removed after static build is enabled by default. set(STATIC_BUILD_TESTS test_standalone_controlflow test_standalone_cuda_graph_multi_stream - test_standalone_custom_stream test_standalone_executor - test_standalone_multiply_write) + test_standalone_custom_stream test_standalone_custom_event + test_standalone_executor test_standalone_multiply_write) foreach(STATIC_BUILD_TEST ${STATIC_BUILD_TESTS}) py_test_modules( diff --git a/test/standalone_executor/test_standalone_custom_event.py b/test/standalone_executor/test_standalone_custom_event.py new file mode 100644 index 0000000000000000000000000000000000000000..3b9fe7a3197b6f5d9afc3bedc52e26621b153799 --- /dev/null +++ b/test/standalone_executor/test_standalone_custom_event.py @@ -0,0 +1,203 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import paddle +from paddle.distributed.passes.pass_utils import ( + _add_event_dependency, + get_skip_gc_vars, + split_program, +) +from paddle.fluid import core +from paddle.fluid.executor import _add_feed_fetch_ops, _StandaloneExecutor + +paddle.enable_static() + + +def build_program(): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + + with paddle.static.program_guard(main_program, startup_program): + # data -> [matmul] -> out ->[add] -> add_out + with paddle.static.device_guard('gpu'): + data = paddle.ones([1024, 2048], dtype='float32', name='data') + weight = paddle.randn([2048, 2048], name='weight') # gpu + matmul_out = paddle.matmul(data, weight, name='matmul_out') # gpus + bias = paddle.ones([1024, 2048], dtype='float32', name='bias') + add_out = paddle.add(matmul_out, bias, name='add_out') + # add_out -> [sub] -> sub_out -> [tanh] -> tanh_out + sub_out = paddle.subtract(add_out, data, name='sub_out') + tanh_out = paddle.tanh(sub_out, name='tanh_out') + bias_1 = paddle.add(bias, sub_out, name='bias_1') + out_before = paddle.tanh(bias_1, name='out_before') + out_last = paddle.subtract(tanh_out, data, name='out_last') + out_last2 = paddle.matmul(out_last, weight, name="matmul_2_out") + + out = paddle.add(out_before, out_last2, name='out') + mean = paddle.mean(out, name='mean_out') + + return main_program, startup_program, [mean] + + +class TestMannulEvent(unittest.TestCase): + """ + fill_constant(def) gaussian_random(def) + | | | | + | | matmul_v2(s1) fill_constant(def) + | | | | | | + | | elementwise_add(s1) | + | | | | + | elementwise_sub(s1) | + | | | | + | tanh(s1) elementwise_add(s1) + | | | + elementwise_sub(s1) tanh(s1) + | | + matmul_v2(s1) | + | | ---split prog---- + elementwise_add(s2) + | + reduce_mean(s2) + """ + + def setUp(self): + self.steps = 3 + self.place_desc = ( + paddle.CUDAPlace(0) + if core.is_compiled_with_cuda() + else paddle.CPUPlace() + ) + self.place = core.Place() + self.place.set_place(self.place_desc) + + def set_custom_stream(self, prog): + op_index_for_stream1 = [2, 4, 5, 6, 7, 8, 9, 10] + op_index_for_stream2 = [11, 12] + ops = prog.global_block().ops + for op_index in op_index_for_stream1: + ops[op_index].dist_attr.execution_stream = "s1" + ops[op_index].dist_attr.stream_priority = 0 + for op_index in op_index_for_stream2: + ops[op_index].dist_attr.execution_stream = "s2" + ops[op_index].dist_attr.stream_priority = -1 + + def split_program(self, prog, apply_mannual_event=False): + # split two subprograms + waiter_recorder_events_map = {11: [8, 10]} + prog_block = prog.global_block() + ops = prog_block.ops + if apply_mannual_event: + for waiter, recorders in waiter_recorder_events_map.items(): + for recorder in recorders: + _add_event_dependency(ops[recorder].desc, ops[waiter].desc) + main_progs, _, _ = split_program(prog, [11]) + return main_progs + + def create_standalone_exe(self, main_progs, startup_progs, fetch_list): + micro_batch_num = 1 + micro_batch_id = 0 + job_list = [] + prog_num = len(main_progs) + fetch_op_num = len(fetch_list) + skip_gc_vars = get_skip_gc_vars(main_progs) + + if prog_num == 1: # single prog + main_progs[0] = _add_feed_fetch_ops( + main_progs[0], + [], + fetch_list, + "feed", + "fetch", + use_fetch_v2=True, + ) + op_num = len(main_progs[0].block(0).ops) + fetch_op_indics = list(range(op_num - fetch_op_num, op_num)) + else: + main_progs[-1] = _add_feed_fetch_ops( + main_progs[-1], + [], + fetch_list, + "feed", + "fetch", + use_fetch_v2=True, + ) + op_num = len(main_progs[-1].block(0).ops) + fetch_op_indics = list(range(op_num - fetch_op_num, op_num)) + + # create jobs + for program_id in range(prog_num): + job = core.Job(f"prog_{program_id}") + job.set_skip_gc_vars(skip_gc_vars[program_id]) + # Set col_attr info for fetch_op to fetch the correct data after running multiple micro batch + if program_id == prog_num - 1: + for i in range(fetch_op_num): + job.set_col_attr_for_fetch_op( + fetch_op_indics[i], + i * micro_batch_num + micro_batch_id, + ) + job_list.append(job) + + type_to_program = {} + for program_id in range(prog_num): + type_to_program[f"prog_{program_id}"] = main_progs[program_id].desc + + plan = core.Plan(job_list, type_to_program) + scope = core.Scope() + main_exe = _StandaloneExecutor(self.place, plan, scope) + return main_exe + + def run_program( + self, + apply_custom_stream=False, + split_prog=False, + apply_mannual_event=False, + ): + paddle.seed(2022) + main_program, startup_program, fetch_list = build_program() + self.assertEqual(len(startup_program.global_block().ops), 0) + + if apply_custom_stream: + self.set_custom_stream(main_program) + main_progs = [main_program] + startup_progs = [startup_program] + if apply_custom_stream and split_prog: + main_progs = self.split_program(main_program, apply_mannual_event) + outs = [] + exe = self.create_standalone_exe(main_progs, startup_progs, fetch_list) + for i in range(self.steps): + outs.append(exe.run(feed_names=[])) + return outs + + def test_result(self): + if not core.is_compiled_with_cuda(): + return + + baselines = self.run_program() + stream_outs = self.run_program(apply_custom_stream=True) + split_outs = self.run_program(apply_custom_stream=True, split_prog=True) + mannual_outs = self.run_program( + apply_custom_stream=True, split_prog=True, apply_mannual_event=True + ) + for bl, out0, out1, out2 in zip( + baselines, stream_outs, split_outs, mannual_outs + ): + self.assertEqual(bl[0], out0[0]) + self.assertEqual(bl[0], out2[0]) + # self.assertNotEqual(bl[0], out1[0]) + + +if __name__ == "__main__": + unittest.main()