From e12b6c04a4140995c6832f56c24a38a16e7b579c Mon Sep 17 00:00:00 2001
From: Ruibiao Chen
Date: Tue, 1 Nov 2022 16:48:30 +0800
Subject: [PATCH] Support custom stream for standalone executor (#47411)

* [Auto Parallel] Improve the c++ dist attr

* [Auto Parallel] Modify test_program.py

* Support custom stream for standalone executor

Co-authored-by: Yulong Ao
---
 .../distributed/auto_parallel/dist_attr.cc    |  12 +-
 .../distributed/auto_parallel/dist_attr.h     |   9 ++
 .../new_executor/interpreter/data_transfer.cc |   1 +
 .../interpreter/dependency_builder.cc         |   7 +-
 .../interpreter/interpreter_util.cc           | 116 +++++++++++-------
 .../interpreter/interpreter_util.h            |  24 ++--
 .../new_executor/new_executor_defs.h          |  26 ++--
 .../framework/new_executor/stream_analyzer.cc |  21 ++--
 paddle/fluid/framework/op_desc.cc             |   4 +
 paddle/fluid/framework/op_desc.h              |   1 +
 paddle/fluid/pybind/auto_parallel_py.cc       |   3 +
 .../fluid/tests/unittests/CMakeLists.txt      |   2 +-
 .../CMakeLists.txt                            |   0
 .../test_standalone_controlflow.py            |   0
 .../test_standalone_custom_stream.py          |  83 +++++++++++++
 .../test_standalone_executor.py               |   0
 .../test_standalone_multiply_write.py         |   0
 17 files changed, 224 insertions(+), 85 deletions(-)
 rename python/paddle/fluid/tests/unittests/{interpreter => standalone_executor}/CMakeLists.txt (100%)
 rename python/paddle/fluid/tests/unittests/{interpreter => standalone_executor}/test_standalone_controlflow.py (100%)
 create mode 100644 python/paddle/fluid/tests/unittests/standalone_executor/test_standalone_custom_stream.py
 rename python/paddle/fluid/tests/unittests/{interpreter => standalone_executor}/test_standalone_executor.py (100%)
 rename python/paddle/fluid/tests/unittests/{interpreter => standalone_executor}/test_standalone_multiply_write.py (100%)

diff --git a/paddle/fluid/distributed/auto_parallel/dist_attr.cc b/paddle/fluid/distributed/auto_parallel/dist_attr.cc
index 57a5b40768a..5b97393864d 100644
--- a/paddle/fluid/distributed/auto_parallel/dist_attr.cc
+++ b/paddle/fluid/distributed/auto_parallel/dist_attr.cc
@@ -319,7 +319,7 @@ bool operator==(const TensorDistAttr& lhs, const TensorDistAttr& rhs) {
 }
 
 std::vector<std::string> OperatorDistAttr::fields_{
-    "process_mesh", "impl_type", "impl_idx"};
+    "process_mesh", "impl_type", "impl_idx", "execution_stream"};
 
 OperatorDistAttr::OperatorDistAttr(const OpDesc& op) : op_(&op) {
   VLOG(4) << "[OperatorDistAttr constructor] op type: " << op_->Type();
@@ -376,8 +376,9 @@ void OperatorDistAttr::initialize() {
       output_dist_attrs_[name] = TensorDistAttr(*output);
     }
   }
-  impl_type_ = "default";
+  impl_type_ = kDefault;
   impl_idx_ = 0;
+  execution_stream_ = kDefault;
 }
 
 void OperatorDistAttr::copy_from(const OperatorDistAttr& dist_attr) {
@@ -386,9 +387,8 @@ void OperatorDistAttr::copy_from(const OperatorDistAttr& dist_attr) {
   set_process_mesh(dist_attr.process_mesh());
   set_impl_type(dist_attr.impl_type());
   set_impl_idx(dist_attr.impl_idx());
+  set_execution_stream(dist_attr.execution_stream());
   set_annotated(dist_attr.annotated());
-  impl_type_ = dist_attr.impl_type();
-  impl_idx_ = dist_attr.impl_idx();
 }
 
 void OperatorDistAttr::set_input_dist_attrs(
@@ -666,6 +666,7 @@ std::string OperatorDistAttr::to_string() const {
   }
   str += "impl_type: " + impl_type_ + ", ";
   str += "impl_idx: " + std::to_string(impl_idx_) + ", ";
+  str += "execution_stream: " + execution_stream_ + ", ";
   str += "annotated: [" + str_join(annotated_) + "], ";
   str += "\nprocess_mesh: " + process_mesh_.to_string() + ", ";
   str += "\ninput_dist_attrs: [\n";
@@ -747,6 +748,9 @@ bool operator==(const OperatorDistAttr& lhs, const OperatorDistAttr& rhs) {
   if (lhs.impl_idx() != rhs.impl_idx()) {
     return false;
   }
+  if (lhs.execution_stream() != rhs.execution_stream()) {
+    return false;
+  }
   for (auto const& item : lhs.input_dist_attrs()) {
     if (rhs.input_dist_attrs().count(item.first) != 1) {
       return false;
diff --git a/paddle/fluid/distributed/auto_parallel/dist_attr.h b/paddle/fluid/distributed/auto_parallel/dist_attr.h
index d4aa306e712..61e61e2e53d 100644
--- a/paddle/fluid/distributed/auto_parallel/dist_attr.h
+++ b/paddle/fluid/distributed/auto_parallel/dist_attr.h
@@ -46,6 +46,8 @@ using framework::OpDesc;
 using framework::ProgramDesc;
 using framework::VarDesc;
 
+constexpr const char* kDefault = "default";
+
 class TensorDistAttr {
  public:
   TensorDistAttr() = default;
@@ -205,6 +207,12 @@ class OperatorDistAttr {
 
   void set_impl_idx(const int64_t& impl_idx) { impl_idx_ = impl_idx; }
 
+  const std::string& execution_stream() const { return execution_stream_; }
+
+  void set_execution_stream(const std::string& execution_stream) {
+    execution_stream_ = execution_stream;
+  }
+
   const std::map<std::string, bool>& annotated() const { return annotated_; }
 
   void set_annotated(const std::map<std::string, bool>& annotated);
@@ -262,6 +270,7 @@ class OperatorDistAttr {
   ProcessMesh process_mesh_;
   std::string impl_type_;
   int64_t impl_idx_ = -1;
+  std::string execution_stream_;
   std::map<std::string, bool> annotated_;
 };
 
diff --git a/paddle/fluid/framework/new_executor/interpreter/data_transfer.cc b/paddle/fluid/framework/new_executor/interpreter/data_transfer.cc
index efe10fcd5f3..bf51ebd1d48 100644
--- a/paddle/fluid/framework/new_executor/interpreter/data_transfer.cc
+++ b/paddle/fluid/framework/new_executor/interpreter/data_transfer.cc
@@ -15,6 +15,7 @@
 #include "paddle/fluid/framework/new_executor/interpreter/data_transfer.h"
 
 #include "paddle/fluid/framework/convert_utils.h"
+#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h"
 #include "paddle/phi/core/kernel_context.h"
 #include "paddle/phi/core/kernel_factory.h"
 
diff --git a/paddle/fluid/framework/new_executor/interpreter/dependency_builder.cc b/paddle/fluid/framework/new_executor/interpreter/dependency_builder.cc
index 3b2a2aed7f3..ae7d7e42536 100644
--- a/paddle/fluid/framework/new_executor/interpreter/dependency_builder.cc
+++ b/paddle/fluid/framework/new_executor/interpreter/dependency_builder.cc
@@ -151,9 +151,8 @@ void DependencyBuilder::AddDependencyForCoalesceTensorOp() {
     // 'first_read_fused_out_op'
     size_t target = first_read_fused_out_op;
     for (size_t j = first_read_fused_out_op + 1; j < op_num_; ++j) {
-      if (j == target + 1 &&
-          IsCommunicationOp(instructions_->at(target).OpBase()->Type()) &&
-          IsCommunicationOp(instructions_->at(j).OpBase()->Type())) {
+      if (j == target + 1 && IsCommunicationOp(instructions_->at(target)) &&
+          IsCommunicationOp(instructions_->at(j))) {
         VLOG(4) << "Found consecutive communication ops, "
                 << instructions_->at(target).OpBase()->Type() << " -> "
                 << instructions_->at(j).OpBase()->Type();
@@ -174,7 +173,7 @@ void DependencyBuilder::AddDependencyForCommunicationOp() {
   int dependence_op_idx = -1;
   for (size_t op_idx = 0; op_idx < op_num_; ++op_idx) {
-    if (IsCommunicationOp(instructions_->at(op_idx).OpBase()->Type())) {
+    if (IsCommunicationOp(instructions_->at(op_idx))) {
       if (dependence_op_idx != -1) {
         AddDownstreamOp(dependence_op_idx, op_idx);
       }
diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc
index ae646ed42db..104217fa80f 100644
--- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc
+++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc
@@ -16,6 +16,7 @@
 
 #include <algorithm>
 
+#include "paddle/fluid/distributed/auto_parallel/dist_attr.h"
 #include "paddle/fluid/framework/details/nan_inf_utils.h"
 #include "paddle/fluid/framework/executor_gc_helper.h"
 #include "paddle/fluid/framework/new_executor/interpreter/data_transfer.h"
@@ -125,18 +126,60 @@ void AsyncWorkQueue::AddTask(const OpFuncType& op_func_type,
   }
 }
 
-void LogDeviceMemoryStats(const platform::Place& place) {
-  if (FLAGS_new_executor_log_memory_stats && platform::is_gpu_place(place)) {
-    VLOG(0) << "memory_allocated: "
-            << static_cast<float>(memory::DeviceMemoryStatCurrentValue(
-                   "Allocated", place.device)) /
-                   1024 / 1024
-            << " MB";
-    VLOG(0) << "max_memory_allocated: "
-            << static_cast<float>(memory::DeviceMemoryStatPeakValue(
-                   "Allocated", place.device)) /
-                   1024 / 1024
-            << " MB";
+bool IsCommunicationOp(const Instruction& instr) {
+  const std::set<std::string> special_comm_op_set = {
+      "send",
+      "recv",
+      "send_v2",
+      "recv_v2",
+  };
+  const std::string& op_name = instr.OpBase()->Type();
+  const std::string communication_op_prefix = "c_";
+  if (op_name.find(communication_op_prefix) != std::string::npos ||
+      special_comm_op_set.count(op_name)) {
+    return true;
+  }
+  return false;
+}
+
+bool IsCpuOp(const Instruction& instr) {
+  return platform::is_cpu_place(instr.DeviceContext().GetPlace());
+}
+
+bool IsSupportedHeterPlace(const phi::Place& place) {
+  return platform::is_gpu_place(place) || platform::is_npu_place(place) ||
+         platform::is_xpu_place(place) || platform::is_ipu_place(place) ||
+         platform::is_custom_place(place);
+}
+
+bool IsMemcpyD2H(const Instruction& instr) {
+  return instr.OpBase()->Type() == kMemcpyD2H;
+}
+
+bool IsMemcpyH2D(const Instruction& instr) {
+  return instr.OpBase()->Type() == kMemcpyH2D;
+}
+
+bool IsMemcpyOp(const Instruction& instr) {
+  return IsMemcpyD2H(instr) || IsMemcpyH2D(instr);
+}
+
+void AddFetch(const std::vector<std::string>& fetch_names,
+              framework::BlockDesc* block) {
+  auto* fetch_holder = block->Var(kFetchVarName);
+  fetch_holder->SetType(proto::VarType::FETCH_LIST);
+  fetch_holder->SetPersistable(true);
+
+  int i = 0;
+  for (auto& fetch_name : fetch_names) {
+    // append fetch op
+    auto* op = block->AppendOp();
+    op->SetType("fetch_v2");
+    op->SetInput("X", {fetch_name});
+    op->SetOutput("Out", {kFetchVarName});
+    op->SetAttr("col", {static_cast<int>(i)});
+    op->CheckAttrs();
+    i++;
   }
 }
 
@@ -517,6 +560,12 @@ void BuildOpFuncList(const platform::Place& place,
     op_func_node.input_index = ins_name2id;
     op_func_node.output_index = outs_name2id;
 
+    const OperatorDistAttr* dist_attr = block.Op(i)->DistAttr();
+    if (dist_attr &&
+        dist_attr->execution_stream() != distributed::auto_parallel::kDefault) {
+      op_func_node.execution_stream_ = dist_attr->execution_stream();
+    }
+
     SingleStreamGuard single_stream_guard(ops[i]);
 
     VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope);
@@ -748,38 +797,19 @@ void BuildOpFuncList(const platform::Place& place,
   memory::Release(place);
 }
 
-void AddFetch(const std::vector<std::string>& fetch_names,
-              framework::BlockDesc* block) {
-  auto* fetch_holder = block->Var(kFetchVarName);
-  fetch_holder->SetType(proto::VarType::FETCH_LIST);
-  fetch_holder->SetPersistable(true);
-
-  int i = 0;
-  for (auto& fetch_name : fetch_names) {
-    // append fetch op
-    auto* op = block->AppendOp();
-    op->SetType("fetch_v2");
-    op->SetInput("X", {fetch_name});
-    op->SetOutput("Out", {kFetchVarName});
-    op->SetAttr("col", {static_cast<int>(i)});
-    op->CheckAttrs();
-    i++;
-  }
-}
-
-bool IsCommunicationOp(const std::string& op_name) {
-  const std::set<std::string> special_comm_op_set = {
-      "send",
-      "recv",
-      "send_v2",
-      "recv_v2",
-  };
-  const std::string communication_op_prefix = "c_";
-  if (op_name.find(communication_op_prefix) != std::string::npos ||
-      special_comm_op_set.count(op_name)) {
-    return true;
+void LogDeviceMemoryStats(const platform::Place& place) {
+  if (FLAGS_new_executor_log_memory_stats && platform::is_gpu_place(place)) {
+    VLOG(0) << "memory_allocated: "
+            << static_cast<float>(memory::DeviceMemoryStatCurrentValue(
+                   "Allocated", place.device)) /
+                   1024 / 1024
+            << " MB";
+    VLOG(0) << "max_memory_allocated: "
+            << static_cast<float>(memory::DeviceMemoryStatPeakValue(
+                   "Allocated", place.device)) /
+                   1024 / 1024
+            << " MB";
   }
-  return false;
 }
 
 }  // namespace interpreter
diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h
index 52163c64f7e..b842d3acfde 100644
--- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h
+++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h
@@ -65,11 +65,20 @@ class AsyncWorkQueue {
   std::unique_ptr<WorkQueueGroup> queue_group_;
 };
 
-void LogDeviceMemoryStats(const platform::Place& place);
+bool IsCommunicationOp(const Instruction& instr);
 
-void BuildVariableScope(const framework::BlockDesc& block,
-                        VariableScope* var_scope,
-                        bool use_local_scope = true);
+bool IsCpuOp(const Instruction& instr);
+
+bool IsMemcpyD2H(const Instruction& instr);
+
+bool IsMemcpyH2D(const Instruction& instr);
+
+bool IsMemcpyOp(const Instruction& instr);
+
+bool IsSupportedHeterPlace(const phi::Place& place);
+
+void AddFetch(const std::vector<std::string>& fetch_names,
+              framework::BlockDesc* block);
 
 void BuildOpFuncList(const platform::Place& place,
                      const framework::BlockDesc& block,
@@ -79,10 +88,11 @@ void BuildOpFuncList(const platform::Place& place,
                      const ExecutionConfig& execution_config,
                      bool use_local_scope = true);
 
-void AddFetch(const std::vector<std::string>& fetch_names,
-              framework::BlockDesc* block);
+void BuildVariableScope(const framework::BlockDesc& block,
+                        VariableScope* var_scope,
+                        bool use_local_scope = true);
 
-bool IsCommunicationOp(const std::string& op_name);
+void LogDeviceMemoryStats(const platform::Place& place);
 
 }  // namespace interpreter
 }  // namespace framework
diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.h b/paddle/fluid/framework/new_executor/new_executor_defs.h
index 6f2287a8966..6735e891230 100644
--- a/paddle/fluid/framework/new_executor/new_executor_defs.h
+++ b/paddle/fluid/framework/new_executor/new_executor_defs.h
@@ -34,6 +34,12 @@ using OpKernelComputeFunc = std::function<void(const ExecutionContext&)>;
 
 constexpr int kEmptyVarIndex = 0;
 
+// stream types
+constexpr const char* kCustomStream = "CustromStream";
+constexpr const char* kDefaultStream = "DefaultStream";
+constexpr const char* kD2HStream = "D2HStream";
+constexpr const char* kH2DStream = "H2DStream";
+
 class InterpretercoreInferShapeContext : public InferShapeContext {
  public:
   InterpretercoreInferShapeContext(const OperatorBase& op,
@@ -274,6 +280,7 @@ class RuntimeInferShapeContext;
 struct OpFuncNode {
   // TODO(zhiqiu): Better make it unique_ptr
   std::shared_ptr<OperatorBase> operator_base_;
+  std::string execution_stream_{kDefaultStream};
   std::map<std::string, std::vector<int>> input_index;
   std::map<std::string, std::vector<int>> output_index;
   std::unordered_set<int> no_data_transform_index;
@@ -379,25 +386,6 @@ static constexpr char kMemcpyH2D[] = "memcpy_h2d";
 static constexpr char kMemcpyD2H[] = "memcpy_d2h";
 static constexpr char kFetchVarName[] = "fetch";
 
-static bool IsMemcpyH2D(const Instruction& instr) {
-  return instr.OpBase()->Type() == kMemcpyH2D;
-}
-
-static bool IsMemcpyD2H(const Instruction& instr) {
-  return instr.OpBase()->Type() == kMemcpyD2H;
-}
-
-static bool IsCpuOp(const Instruction& instr) {
-  return platform::is_cpu_place(instr.DeviceContext().GetPlace());
-}
-
-// is supported heterogeneous place
-static bool IsSupportedHeterPlace(const phi::Place& place) {
-  return platform::is_gpu_place(place) || platform::is_npu_place(place) ||
-         platform::is_xpu_place(place) || platform::is_ipu_place(place) ||
-         platform::is_custom_place(place);
-}
-
 // static_ref_ is the numer of last live ops calculated to statically after
 // `build` the Instructions. dynamic_ref_ is the runtime version ref which will
 // be decreased by one dynamiclly after the execution of an op (in last ops
diff --git a/paddle/fluid/framework/new_executor/stream_analyzer.cc b/paddle/fluid/framework/new_executor/stream_analyzer.cc
index 09c54a64805..8ee82699b47 100644
--- a/paddle/fluid/framework/new_executor/stream_analyzer.cc
+++ b/paddle/fluid/framework/new_executor/stream_analyzer.cc
@@ -24,10 +24,6 @@
 namespace paddle {
 namespace framework {
 
-// stream types
-constexpr const char* kD2HStream = "D2HStream";
-constexpr const char* kH2DStream = "H2DStream";
-
 class ContextManager {
  public:
   using DeviceContextMap =
@@ -94,13 +90,14 @@ std::vector<size_t> StreamAnalyzer::GetNeedEventVarIds(
     return false;
   };
 
-  bool is_comm = interpreter::IsCommunicationOp(cur_instr.OpBase()->Type()) ||
-                 interpreter::IsCommunicationOp(next_instr.OpBase()->Type());
+  bool is_memcpy =
+      interpreter::IsMemcpyOp(cur_instr) || interpreter::IsMemcpyOp(next_instr);
+
   std::vector<size_t> need_event_var_ids;
   for (auto& item : next_instr.Inputs()) {
     for (auto var_id : item.second) {
       if (unique_var_ids.count(var_id) > 0) {
-        if (!is_comm) {
+        if (is_memcpy) {
           if (next_instr.NoDataTransformVars().count(var_id)) {
             VLOG(4) << "Skip inserting event at variable " << item.first
                     << " of operator " << next_instr.OpBase()->Type()
@@ -186,12 +183,22 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
     const OpFuncNode& op_func_node) {
   auto& op = op_func_node.operator_base_;
   auto& op_type = op->Type();
+  const std::string& execution_stream = op_func_node.execution_stream_;
   ContextManager& ctx_manager = ContextManager::Instance();
   // only gpu/npu need update. xpu not need, because xpu memcpy op kernel is
   // synchronous.
   if (platform::is_gpu_place(place_) || platform::is_npu_place(place_) ||
       platform::is_custom_place(place_)) {
+    VLOG(7) << "Parse DeviceContext for " << op_type
+            << ", execution stream = " << execution_stream;
+    if (execution_stream != kDefaultStream) {
+      return ctx_manager
+          .Get(std::string(kCustomStream) + "-" + execution_stream, place_)
+          .get()
+          .get();
+    }
+
     if (op_type == interpreter::kMemcpyD2H) {
       return ctx_manager.Get(std::string(kD2HStream), place_).get().get();
     } else if (op_type == interpreter::kMemcpyH2D) {
diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc
index 321230e8606..dcc47058b64 100644
--- a/paddle/fluid/framework/op_desc.cc
+++ b/paddle/fluid/framework/op_desc.cc
@@ -1105,6 +1105,10 @@ void OpDesc::InferVarType(BlockDesc *block) const {
   }
 }
 
+const OperatorDistAttr *OpDesc::DistAttr() const {
+  return dist_attr_ ? dist_attr_.get() : nullptr;
+}
+
 OperatorDistAttr *OpDesc::MutableDistAttr() {
   if (dist_attr_) {
     return dist_attr_.get();
diff --git a/paddle/fluid/framework/op_desc.h b/paddle/fluid/framework/op_desc.h
index 7987a9ded47..6c6f13d7c92 100644
--- a/paddle/fluid/framework/op_desc.h
+++ b/paddle/fluid/framework/op_desc.h
@@ -196,6 +196,7 @@ class OpDesc {
   uint64_t Id() const { return id_; }
   uint64_t OriginalId() const { return original_id_; }
   void SetOriginalId(uint64_t original_id) { original_id_ = original_id; }
 
+  const OperatorDistAttr *DistAttr() const;
   OperatorDistAttr *MutableDistAttr();
   void SetDistAttr(const OperatorDistAttr &dist_attr);
diff --git a/paddle/fluid/pybind/auto_parallel_py.cc b/paddle/fluid/pybind/auto_parallel_py.cc
index 1e0bda0c940..089f5da5abc 100644
--- a/paddle/fluid/pybind/auto_parallel_py.cc
+++ b/paddle/fluid/pybind/auto_parallel_py.cc
@@ -215,6 +215,9 @@ void BindAutoParallel(py::module *m) {
       .def_property("impl_idx",
                     &OperatorDistAttr::impl_idx,
                     &OperatorDistAttr::set_impl_idx)
+      .def_property("execution_stream",
+                    &OperatorDistAttr::execution_stream,
+                    &OperatorDistAttr::set_execution_stream)
       .def_property("annotated",
                     &OperatorDistAttr::annotated,
                     &OperatorDistAttr::set_annotated)
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index a206e8994e5..d1eaebcdc2e 100755
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -893,7 +893,7 @@ add_subdirectory(asp)
 
 add_subdirectory(ir)
 
-add_subdirectory(interpreter)
+add_subdirectory(standalone_executor)
 
 if(WITH_TESTING)
   set_property(TEST test_parallel_executor_mnist
diff --git a/python/paddle/fluid/tests/unittests/interpreter/CMakeLists.txt b/python/paddle/fluid/tests/unittests/standalone_executor/CMakeLists.txt
similarity index 100%
rename from python/paddle/fluid/tests/unittests/interpreter/CMakeLists.txt
rename to python/paddle/fluid/tests/unittests/standalone_executor/CMakeLists.txt
diff --git a/python/paddle/fluid/tests/unittests/interpreter/test_standalone_controlflow.py b/python/paddle/fluid/tests/unittests/standalone_executor/test_standalone_controlflow.py
similarity index 100%
rename from python/paddle/fluid/tests/unittests/interpreter/test_standalone_controlflow.py
rename to python/paddle/fluid/tests/unittests/standalone_executor/test_standalone_controlflow.py
diff --git a/python/paddle/fluid/tests/unittests/standalone_executor/test_standalone_custom_stream.py b/python/paddle/fluid/tests/unittests/standalone_executor/test_standalone_custom_stream.py
new file mode 100644
index 00000000000..3915b2459e0
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/standalone_executor/test_standalone_custom_stream.py
@@ -0,0 +1,83 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import paddle
+from paddle.fluid import core
+from test_standalone_executor import build_program
+
+paddle.enable_static()
+
+
+class TestCustomStream(unittest.TestCase):
+    def setUp(self):
+        self.steps = 3
+
+    ###
+    ### fill_constant(cpu)       gaussian_random
+    ###       |    |                  |      |
+    ###       |    |             matmul_v2(s1)      fill_constant
+    ###       |    |                  |         |      |       |
+    ###       |    |             elementwise_add(s1)   |
+    ###       |    |                  |                |
+    ###       |   elementwise_sub(cpu)                 |      |
+    ###       |    |                  |                |
+    ###       |   tanh(cpu)      elementwise_add(s2)
+    ###       |    |                  |
+    ###  elementwise_sub(s1)     tanh(s2)
+    ###       |                       |
+    ###         elementwise_add(s2)
+    ###                  |
+    ###           reduce_mean(s2)
+    ###
+    def set_custom_stream(self, prog):
+        op_index_for_stream1 = [2, 4, 9]
+        op_index_for_stream2 = [7, 8, 10, 11]
+        ops = prog.global_block().ops
+        for op_index in op_index_for_stream1:
+            ops[op_index].dist_attr.execution_stream = "s1"
+        for op_index in op_index_for_stream2:
+            ops[op_index].dist_attr.execution_stream = "s2"
+
+    def run_program(self, apply_custom_stream=False):
+        paddle.seed(2022)
+        main_program, startup_program, fetch_list = build_program()
+        self.assertEqual(len(startup_program.global_block().ops), 0)
+
+        if apply_custom_stream:
+            self.set_custom_stream(main_program)
+
+        with paddle.static.program_guard(main_program, startup_program):
+            exe = paddle.static.Executor(paddle.CUDAPlace(0))
+            scope = core.Scope()
+            outs = []
+            for i in range(self.steps):
+                outs.append(
+                    exe.run(main_program, scope=scope, fetch_list=fetch_list)
+                )
+        return outs
+
+    def test_result(self):
+        if not core.is_compiled_with_cuda():
+            return
+
+        baselines = self.run_program()
+        outs = self.run_program(apply_custom_stream=True)
+        for bl, out in zip(baselines, outs):
+            self.assertEqual(bl[0], out[0])
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py b/python/paddle/fluid/tests/unittests/standalone_executor/test_standalone_executor.py
similarity index 100%
rename from python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py
rename to python/paddle/fluid/tests/unittests/standalone_executor/test_standalone_executor.py
diff --git a/python/paddle/fluid/tests/unittests/interpreter/test_standalone_multiply_write.py b/python/paddle/fluid/tests/unittests/standalone_executor/test_standalone_multiply_write.py
similarity index 100%
rename from python/paddle/fluid/tests/unittests/interpreter/test_standalone_multiply_write.py
rename to python/paddle/fluid/tests/unittests/standalone_executor/test_standalone_multiply_write.py
--
GitLab
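
Usage sketch (editor's illustration, appended after the patch; not part of the commit): the new attribute is driven from Python as `op.dist_attr.execution_stream`, which this change plumbs from the pybind `OperatorDistAttr` binding through `OpDesc::DistAttr()` down to `StreamAnalyzer::ParseDeviceContext`. A minimal standalone version of what the added test exercises might look like the following; the toy network, the even-index routing rule, and the stream name "s1" are assumptions made for illustration, while the `dist_attr.execution_stream` field, the default-stream fallback, and the GPU-only custom-stream path come from the patch itself.

import numpy as np
import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name="x", shape=[2, 2], dtype="float32")
    y = paddle.matmul(x, x)
    z = paddle.matmul(y, y)
    out = paddle.mean(z)

# Route some ops to a side stream via the dist_attr field added in this patch.
# Ops that are left untouched keep the "default" execution stream.
for i, op in enumerate(main_prog.global_block().ops):
    if i % 2 == 0:  # illustrative choice of which ops leave the default stream
        op.dist_attr.execution_stream = "s1"

# The custom-stream branch is only taken on GPU-like places
# (see StreamAnalyzer::ParseDeviceContext above), so run on CUDAPlace.
exe = paddle.static.Executor(paddle.CUDAPlace(0))
exe.run(startup_prog)
result = exe.run(
    main_prog,
    feed={"x": np.ones((2, 2), dtype="float32")},
    fetch_list=[out],
)
print(result)

As in the added unit test, results with and without the custom stream assignment are expected to match; only the scheduling of the marked ops changes.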