From fa7ba041fc6621720241563bb66a9de776cf0a0f Mon Sep 17 00:00:00 2001 From: Ruibiao Chen Date: Fri, 2 Jun 2023 13:59:24 +0800 Subject: [PATCH] Run multiple programs in standalone executor (#54268) --- .../new_executor/standalone_executor.cc | 60 +++++- .../new_executor/standalone_executor.h | 12 +- paddle/fluid/pybind/pybind.cc | 3 +- python/paddle/fluid/executor.py | 46 +---- .../new_executor/standalone_executor_test.cc | 5 +- .../test_standalone_executor_multi_program.py | 184 ++++++++++++++++++ 6 files changed, 254 insertions(+), 56 deletions(-) create mode 100644 test/standalone_executor/test_standalone_executor_multi_program.py diff --git a/paddle/fluid/framework/new_executor/standalone_executor.cc b/paddle/fluid/framework/new_executor/standalone_executor.cc index 7cd57447f28..6745fec7ca3 100644 --- a/paddle/fluid/framework/new_executor/standalone_executor.cc +++ b/paddle/fluid/framework/new_executor/standalone_executor.cc @@ -19,8 +19,8 @@ namespace paddle { namespace framework { StandaloneExecutor::StandaloneExecutor(const platform::Place& place, - const ProgramDesc& prog) - : place_(place), prog_(prog) {} + const std::vector& programs) + : place_(place), programs_(programs) {} paddle::framework::FetchList StandaloneExecutor::Run( Scope* scope, @@ -28,18 +28,60 @@ paddle::framework::FetchList StandaloneExecutor::Run( const std::vector& fetch_names) { platform::RecordEvent record_event( "StandaloneExecutor::run", platform::TracerEventType::UserDefined, 1); - auto core = GetInterpreterCore(scope, prog_, feed_names, fetch_names); - VLOG(4) << "StandaloneExecutor: " << this << ", InterpreterCore: " << core; - return core->Run(feed_names); + // TODO(Ruibiao): Unified single and multiple program run + if (programs_.size() == 1) { // run single program + VLOG(6) << "Run single program"; + auto core = GetInterpreterCore(scope, + programs_.at(0), + feed_names, + fetch_names, + 0, + interpreter::ExecutionConfig()); + VLOG(4) << "StandaloneExecutor: " << this << ", InterpreterCore: " << core; + return core->Run(feed_names); + } else { // run multiple programs + VLOG(6) << "Run multiple program, programs_.size() " << programs_.size(); + FetchList merged_fetch_list; + for (size_t program_idx = 0; program_idx < programs_.size(); + ++program_idx) { + const ProgramDesc& program = programs_[program_idx]; + + interpreter::ExecutionConfig execution_config; + execution_config.create_local_scope = false; + // TODO(Ruibiao): hack skip gc for all vars, improve it later + std::set skip_gc_vars; + for (VarDesc* var : program.Block(0).AllVars()) { + execution_config.skip_gc_vars.insert(var->Name()); + } + + // TODO(Ruibiao): ONLY support feeds data in the first program for now + const std::vector& real_feed_names = + (program_idx == 0 ? feed_names : std::vector()); + auto core = GetInterpreterCore(scope, + program, + real_feed_names, + fetch_names, + program_idx, + execution_config); + const FetchList& fetch_list = core->Run(real_feed_names); + std::move(fetch_list.begin(), + fetch_list.end(), + std::back_inserter(merged_fetch_list)); + } + return merged_fetch_list; + } } std::shared_ptr StandaloneExecutor::GetInterpreterCore( Scope* scope, - const ProgramDesc& prog, + const ProgramDesc& program, const std::vector& feed_names, - const std::vector& fetch_names) { + const std::vector& fetch_names, + size_t program_idx, + interpreter::ExecutionConfig execution_config) { std::ostringstream oss; + oss << "prog_idx:" << program_idx << ","; oss << "feed:"; for (auto& feedname : feed_names) { oss << feedname << ","; @@ -55,8 +97,8 @@ std::shared_ptr StandaloneExecutor::GetInterpreterCore( if (iter == interpretercores_.end()) { VLOG(3) << "create interpreter_core for " << oss.str() << " on place " << place_; - std::shared_ptr core = - std::make_shared(place_, prog.Block(0), scope); + std::shared_ptr core = std::make_shared( + place_, program.Block(0), scope, execution_config); interpretercores_.emplace(oss.str(), core); return core; } else { diff --git a/paddle/fluid/framework/new_executor/standalone_executor.h b/paddle/fluid/framework/new_executor/standalone_executor.h index dc89cb07f55..76de149e952 100644 --- a/paddle/fluid/framework/new_executor/standalone_executor.h +++ b/paddle/fluid/framework/new_executor/standalone_executor.h @@ -31,7 +31,8 @@ class InterpreterCore; class StandaloneExecutor { public: - StandaloneExecutor(const platform::Place& place, const ProgramDesc& prog); + StandaloneExecutor(const platform::Place& place, + const std::vector& programs); ~StandaloneExecutor() {} @@ -47,10 +48,13 @@ class StandaloneExecutor { Scope* scope, const ProgramDesc& prog, const std::vector& feed_names, - const std::vector& fetch_names); + const std::vector& fetch_names, + size_t program_idx, + interpreter::ExecutionConfig execution_config); - platform::Place place_; - const ProgramDesc& prog_; + const platform::Place place_; + const std::vector programs_; + std::vector microbatch_scopes_; std::unordered_map> interpretercores_; diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 31bf9d1fc06..0f1d9fc6782 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -1842,7 +1842,8 @@ All parameter, weight, gradient are variables in Paddle. }); py::class_(m, "StandaloneExecutor") - .def(py::init()) + .def( + py::init &>()) .def("run", [](StandaloneExecutor &self, Scope *scope, diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 5eec58e3d8f..6efc93f22fd 100755 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -632,10 +632,10 @@ handler = FetchHandlerExample(var_dict=var_dict) class _StandaloneExecutor: - def __init__(self, place, main_program, scope): + def __init__(self, place, programs, scope): self._place = core.Place() self._place.set_place(place) - self._main_program = main_program + self._programs = programs self._scope = scope self._new_exe = self._create_new_executor() @@ -660,46 +660,12 @@ class _StandaloneExecutor: return tensors def _create_new_executor(self): - new_exe = core.StandaloneExecutor(self._place, self._main_program.desc) + new_exe = core.StandaloneExecutor( + self._place, [program.desc for program in self._programs] + ) return new_exe - def _update_feed(self, feed): - """ - Update the feed dict, remove the feed item which is pruned in program. - - Notes: This is a very low level API. Users should not use this API - directly. - - Args: - feed(list|dict): feed dict or list. - - Returns: - feed:(list|dict) updated feed. - """ - if feed is None: - feed = {} - elif isinstance(feed, (list, tuple)): - assert len(feed) == 1, "Not compiled with data parallel" - feed = feed[0] - - if not isinstance(feed, dict): - raise TypeError( - "feed requires dict as its Parameter. But you passed in %s" - % (type(feed)) - ) - - global_block = self._main_program.global_block() - for feed_name in list(feed.keys()): - if not global_block.has_var(feed_name): - feed.pop(feed_name) - warnings.warn( - "The variable %s is not found in program. It is not declared or is pruned." - % feed_name - ) - - return feed - def _check_fetch(self, fetch_list): if fetch_list is None: fetch_list = [] @@ -886,7 +852,7 @@ class _ExecutorCache: ) new_program = program.clone() - new_exe = _StandaloneExecutor(place, new_program, scope) + new_exe = _StandaloneExecutor(place, [new_program], scope) return new_program, new_exe diff --git a/test/cpp/new_executor/standalone_executor_test.cc b/test/cpp/new_executor/standalone_executor_test.cc index 771f1d0ad59..a9026ba2468 100644 --- a/test/cpp/new_executor/standalone_executor_test.cc +++ b/test/cpp/new_executor/standalone_executor_test.cc @@ -146,9 +146,10 @@ TEST(StandaloneExecutor, run) { ProgramDesc main_prog = GetLmMainProgram(); Scope scope; - StandaloneExecutor startup_exec(place, startup_prog); + StandaloneExecutor startup_exec(place, + std::vector{startup_prog}); startup_exec.Run(&scope, {}, {}); - StandaloneExecutor exec(place, main_prog); + StandaloneExecutor exec(place, std::vector{main_prog}); exec.Run(&scope, {}, {}); auto start = std::chrono::steady_clock::now(); diff --git a/test/standalone_executor/test_standalone_executor_multi_program.py b/test/standalone_executor/test_standalone_executor_multi_program.py new file mode 100644 index 00000000000..5eeb64f5fac --- /dev/null +++ b/test/standalone_executor/test_standalone_executor_multi_program.py @@ -0,0 +1,184 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import numpy as np + +import paddle +from paddle.distributed.passes.pass_utils import split_program +from paddle.fluid import core +from paddle.fluid.executor import ( + _add_feed_fetch_ops, + _as_lodtensor, + _StandaloneExecutor, + check_feed_shape_type, +) +from paddle.nn import TransformerEncoderLayer + +paddle.enable_static() + + +class TestMulitProgramRun(unittest.TestCase): + def setUp(self): + self.place_desc = ( + paddle.CUDAPlace(0) + if core.is_compiled_with_cuda() + else paddle.CPUPlace() + ) + self.place = core.Place() + self.place.set_place(self.place_desc) + + def build_program(self): + batch_size = 2 + src_len = 4 + d_model = 128 + n_head = 2 + + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + + with paddle.static.program_guard(main_program, startup_program): + enc_input = paddle.static.data( + name="enc_input", + shape=[batch_size, src_len, d_model], + dtype="float32", + ) + attn_mask = paddle.static.data( + name="attn_mask", + shape=[batch_size, n_head, src_len, src_len], + dtype="float32", + ) + encoder_layer = TransformerEncoderLayer( + d_model, n_head, dim_feedforward=512 + ) + attn_mask = paddle.nn.layer.transformer._convert_attention_mask( + attn_mask, enc_input.dtype + ) + + enc_output = encoder_layer(enc_input, attn_mask) + + split_op_indics = [len(main_program.block(0).ops)] + + enc_output = encoder_layer(enc_output, attn_mask) + + np.random.seed(2022) + feed = { + enc_input.name: np.random.rand( + batch_size, src_len, d_model + ).astype(np.float32), + attn_mask.name: np.random.rand( + batch_size, n_head, src_len, src_len + ).astype(np.float32), + } + fetch_list = [enc_output.name] + + return ( + startup_program, + main_program, + split_op_indics, + feed, + fetch_list, + ) + + def feed_data(self, program, feed, feed_var_name, scope): + # feed var to framework + global_block = program.global_block() + for op in global_block.ops: + if op.desc.type() == 'feed': + feed_target_name = op.desc.output('Out')[0] + cur_feed = feed[feed_target_name] + var = global_block.var(feed_target_name) + if var.dtype != core.VarDesc.VarType.STRINGS: + if not isinstance(cur_feed, core.LoDTensor): + cur_feed = _as_lodtensor( + cur_feed, self.place_desc, var.dtype + ) + check_feed_shape_type(var, cur_feed) + idx = op.desc.attr('col') + core.set_feed_variable(scope, cur_feed, feed_var_name, idx) + else: + break + + def run_program( + self, + startup_program, + main_program, + feed, + fetch_list, + scope, + run_step, + split_op_indics=None, + ): + paddle.seed(2022) + + startup_exe = _StandaloneExecutor(self.place, [startup_program], scope) + startup_exe.run(scope, [], []) + + programs = [main_program] + if split_op_indics is not None: + programs, _, _ = split_program(main_program, split_op_indics) + # hack add feed ops in the first program and fetch ops in the last program + programs[0] = _add_feed_fetch_ops( + programs[0], feed, [], "feed", "fetch" + ) + programs[-1] = _add_feed_fetch_ops( + programs[-1], [], fetch_list, "feed", "fetch" + ) + else: + programs[0] = _add_feed_fetch_ops( + programs[0], feed, fetch_list, "feed", "fetch" + ) + + self.feed_data(programs[0], feed, "feed", scope) + + main_exe = _StandaloneExecutor(self.place, programs, scope) + + res = [] + for i in range(run_step): + res += main_exe.run(scope, list(feed.keys()), fetch_list) + return res + + def test_multi_program_run(self): + ( + startup_program, + main_program, + split_op_indics, + feed, + fetch_list, + ) = self.build_program() + run_step = 3 + res = self.run_program( + startup_program, + main_program, + feed, + fetch_list, + paddle.static.Scope(), + run_step, + ) + splited_res = self.run_program( + startup_program, + main_program, + feed, + fetch_list, + paddle.static.Scope(), + run_step, + split_op_indics, + ) + np.testing.assert_array_equal(res, splited_res) + + +if __name__ == "__main__": + unittest.main() -- GitLab