From 4bc085304f9372230c85a2393314daac2a9c9515 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Tue, 14 Sep 2021 12:51:17 +0800 Subject: [PATCH] Intergrate StandaloneExecutor in Static.Executor Interface with FLAGS_USE_STANDALONE_EXECUTOR (#35628) * Intergrate StandaloneExecutor in Static.Executor Interface with FLAGS_USE_STANDALONE_EXECUTOR * Enhance unittest and clean code in StandaloneExecutor * polish unittest --- .../new_executor/interpretercore_util.cc | 7 + .../new_executor/standalone_executor.cc | 13 +- python/paddle/fluid/executor.py | 140 ++++++++++++++- python/paddle/fluid/framework.py | 5 + .../interpreter/test_standalone_executor.py | 159 ++++++++++++++---- 5 files changed, 283 insertions(+), 41 deletions(-) diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.cc b/paddle/fluid/framework/new_executor/interpretercore_util.cc index dafac3b904..56cac06436 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.cc +++ b/paddle/fluid/framework/new_executor/interpretercore_util.cc @@ -117,6 +117,13 @@ void build_variable_scope(const framework::ProgramDesc& pdesc, info.var_ref_count_ = 0; info.vardesc_ = var; var_scope->vec_meta_info_.push_back(info); + } else { + auto var_id = var_scope->name2id[var->Name()]; + if (nullptr == var_scope->vec_meta_info_[var_id].vardesc_) { + VLOG(3) << "update var:" << var->Name() << " desc from nullptr into " + << var; + var_scope->vec_meta_info_[var_id].vardesc_ = var; + } } } } diff --git a/paddle/fluid/framework/new_executor/standalone_executor.cc b/paddle/fluid/framework/new_executor/standalone_executor.cc index 1152cc0cd1..a7579d5461 100644 --- a/paddle/fluid/framework/new_executor/standalone_executor.cc +++ b/paddle/fluid/framework/new_executor/standalone_executor.cc @@ -35,14 +35,13 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place, auto v = outer_scope_->Var(name); if (global_scope_.name2id.find(name) == global_scope_.name2id.end()) { global_scope_.name2id[name] = global_scope_.var_list.size(); - } - - global_scope_.var_list.push_back(v); + global_scope_.var_list.push_back(v); - VariableMetaInfo info; - info.var_ref_count_ = 0; - info.vardesc_ = nullptr; - global_scope_.vec_meta_info_.push_back(info); + VariableMetaInfo info; + info.var_ref_count_ = 0; + info.vardesc_ = nullptr; + global_scope_.vec_meta_info_.push_back(info); + } } } diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 8ca0344962..74b6ec3480 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -136,9 +136,9 @@ def as_numpy(tensor, copy=False): numpy.ndarray """ if isinstance(tensor, core.LoDTensorArray): - return [as_numpy(t) for t in tensor] + return [as_numpy(t, copy) for t in tensor] if isinstance(tensor, list): - return [as_numpy(t) for t in tensor] + return [as_numpy(t, copy) for t in tensor] assert isinstance(tensor, core.LoDTensor) lod = tensor.lod() if len(lod) > 0: @@ -383,6 +383,17 @@ def _to_name_str(var): return _to_str(var) +def _is_enable_standalone_executor(): + """ + Whether to use experimental executor `StandaloneExecutor`. + """ + flag = False + env_val = os.environ.get('FLAGS_USE_STANDALONE_EXECUTOR', None) + if env_val in [1, '1', True, 'True', 'true']: + flag = True + return flag + + def _get_strong_program_cache_key(program, feed, fetch_list): return str(id(program)) + _get_program_cache_key(feed, fetch_list) @@ -472,6 +483,121 @@ handler = FetchHandlerExample(var_dict=var_dict) """) +class _StandaloneExecutor(object): + def __init__(self, place, main_program): + self._place = core.Place() + self._place.set_place(place) + self._main_program = main_program + self._new_exe = self._create_new_executor() + + def run(self, feed, fetch_list, return_numpy=True): + """ + Args: + feed(list|dict): This parameter represents the input Tensors of the model. + If it is single card training, the feed is dict type, and if it is multi-card + training, the parameter feed can be dict or list of Tensors. If the + parameter type is dict, the data in the feed will be split and sent to + multiple devices (CPU/GPU), that is to say, the input data will be evenly + sent to different devices, so you should make sure the number of samples of + the current mini-batch must be greater than the number of places; + if the parameter type is list, those data are copied directly to each device, + so the length of this list should be equal to the number of places. + The default is None. + fetch_list(list): This parameter represents the Tensors that need to be returned + after the model runs. The default is None. + return_numpy(bool): This parameter indicates whether convert the fetched Tensors + (the Tensor specified in the fetch list) to numpy.ndarray. if it is False, + the type of the return value is a list of :code:`LoDTensor`. The default is True. + """ + feed = self._update_feed(feed) + fetch_list = self._check_fetch(fetch_list) + + tensors = self._new_exe.run(feed, fetch_list)._move_to_list() + if return_numpy: + return as_numpy(tensors, copy=True) + else: + return tensors + + def _create_new_executor(self): + # NOTE: It's a trick to set empty start_up program. + startup_program = Program() + outer_scope = global_scope() + new_exe = core.StandaloneExecutor(self._place, startup_program.desc, + self._main_program.desc, outer_scope) + + return new_exe + + def _update_feed(self, feed): + """ + Update the feed dict, remove the feed item which is pruned in program. + + Notes: This is a very low level API. Users should not use this API + directly. + + Args: + feed(list|dict): feed dict or list. + + Returns: + feed:(list|dict) updated feed. + """ + global_block = self._main_program.global_block() + if feed is None: + feed = {} + elif isinstance(feed, dict): + for feed_name in list(feed.keys()): + if not global_block.has_var(feed_name): + feed.pop(feed_name) + warnings.warn( + "The variable %s is not found in program. It is not declared or is pruned." + % feed_name) + else: + raise TypeError("Only support feed with `dict`, but received {}". + format(type(feed).__name__)) + + return feed + + def _check_fetch(self, fetch_list): + if fetch_list is None: + fetch_list = [] + + res = [] + for fetch_var in fetch_list: + if isinstance(fetch_var, Variable): + fetch_var = fetch_var.name + elif not isinstance(fetch_var, str): + raise TypeError( + "Required fetch_var shall be str|Variable, but received {}". + format(type(fetch_var).__name__)) + + res.append(fetch_var) + return res + + +class _ExecutorCache(object): + def __init__(self, place): + # {Program : _StandaloneExecutor} + self._place = place + self._cached_executors = {} + + def run(self, program, feed, fetch_list, return_numpy=True): + new_exe = self._get_exe_from_cache(program) + return new_exe.run(feed, fetch_list, return_numpy) + + def _get_exe_from_cache(self, program): + """ + Return cached _StandaloneExecutor instance. If not found, create associated + _StandaloneExecutor instance with given program and cache it. + """ + assert isinstance( + program, Program), "Required type(Program), but received {}".format( + type(program).__name__) + if program not in self._cached_executors: + new_exe = _StandaloneExecutor(self._place, program) + self._cached_executors[program] = new_exe + + return self._cached_executors[program] + + class Executor(object): """ :api_attr: Static Graph @@ -568,6 +694,10 @@ class Executor(object): self._auto_checkpoint_name = unique_name.generate( "__auto_checkpoint_executor__") + # NOTE: Whether to use experimental executor `StandaloneExecutor`. + self._enable_interpreter_core = _is_enable_standalone_executor() + self._executor_cache = _ExecutorCache(self.place) + def _get_scope_cache(self, program_cache_key): return self.scope_caches.get(program_cache_key, None) @@ -1155,6 +1285,12 @@ class Executor(object): if scope is None: scope = global_scope() + # NOTE: This is an experimental feature. If `export FLAGS_USE_STANDALONE_EXECUTOR=1 `, + # use StandaloneExecutor to run the program. + if self._enable_interpreter_core and not program._is_start_up_program_: + return self._executor_cache.run(program, feed, fetch_list, + return_numpy) + # use_prune can be overrided by putting optimize_ops in fetch_list _origin_fetch_list = fetch_list _origin_program = program diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 04666470ea..4089e4f615 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -4381,6 +4381,8 @@ class Program(object): # compiled program, i.e. Graph self._graph = None + # to tag whether is startup_program + self._is_start_up_program_ = False def _find_var_class_kwargs(self, new_desc): # NOTE: not all variables support shape/dtype/lod_level methods. @@ -5994,6 +5996,7 @@ class ParamBase(core.VarBase): # program is a global instance. _main_program_ = Program() _startup_program_ = Program() +_startup_program_._is_start_up_program_ = True def default_startup_program(): @@ -6142,6 +6145,8 @@ def program_guard(main_program, startup_program=None): if startup_program is not None: check_type(startup_program, 'startup_program', Program, 'paddle.static.program_guard') + # Tag the program __is_start_up as True + startup_program._is_start_up_program_ = True startup_program = switch_startup_program(startup_program) try: yield diff --git a/python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py b/python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py index 1f971ae1b2..da335a88e3 100644 --- a/python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py +++ b/python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import sys import unittest import paddle @@ -92,52 +93,53 @@ class LinearTestCase(unittest.TestCase): self.assertGreaterEqual(cost_info.device_total_memory_bytes(), 0) -class MultiStreamModelTestCase(unittest.TestCase): - def setUp(self): - self.iter_n = 2 - self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda( - ) else paddle.CPUPlace() +def build_program(): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() - def build_program(self): - main_program = paddle.static.Program() - startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + with paddle.static.device_guard('cpu'): + data = paddle.ones([4, 64], dtype='float32', name='data') - with paddle.static.program_guard(main_program, startup_program): - with paddle.static.device_guard('cpu'): - data = paddle.ones([4, 64], dtype='float32', name='data') + # data -> [memcpy_h2d] -> data' -> [matmul] -> out ->[add] -> add_out + with paddle.static.device_guard('gpu'): + weight = paddle.randn([64, 64], name='weight') # gpu + matmul_out = paddle.matmul(data, weight, name='matmul_out') # gpus + bias = paddle.ones([4, 64], dtype='float32', name='bias') + add_out = paddle.add(matmul_out, bias, name='add_out') + + # add_out -> [memcpy_d2h] -> add_out' -> [sub] -> sub_out -> [tanh] -> tanh_out + with paddle.static.device_guard('cpu'): + sub_out = paddle.subtract(add_out, data, name='sub_out') + tanh_out = paddle.tanh(sub_out, name='tanh_out') - # data -> [memcpy_h2d] -> data' -> [matmul] -> out ->[add] -> add_out - with paddle.static.device_guard('gpu'): - weight = paddle.randn([64, 64], name='weight') # gpu - matmul_out = paddle.matmul( - data, weight, name='matmul_out') # gpus - bias = paddle.ones([4, 64], dtype='float32', name='bias') - add_out = paddle.add(matmul_out, bias, name='add_out') + with paddle.static.device_guard('gpu'): + bias_1 = paddle.add(bias, sub_out, name='bias_1') + out_before = paddle.tanh(bias_1, name='out_before') + out_last = paddle.subtract(tanh_out, data, name='out_last') - # add_out -> [memcpy_d2h] -> add_out' -> [sub] -> sub_out -> [tanh] -> tanh_out - with paddle.static.device_guard('cpu'): - sub_out = paddle.subtract(add_out, data, name='sub_out') - tanh_out = paddle.tanh(sub_out, name='tanh_out') + out = paddle.add(out_before, out_last, name='out') + mean = paddle.mean(out, name='mean_out') - with paddle.static.device_guard('gpu'): - bias_1 = paddle.add(bias, sub_out, name='bias_1') - out_before = paddle.tanh(bias_1, name='out_before') - out_last = paddle.subtract(tanh_out, data, name='out_last') + return main_program, startup_program, [mean] - out = paddle.add(out_before, out_last, name='out') - mean = paddle.mean(out, name='mean_out') - return main_program, startup_program, [mean] +class MultiStreamModelTestCase(unittest.TestCase): + def setUp(self): + self.iter_n = 2 + self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda( + ) else paddle.CPUPlace() - def test_multi_stream(self): + def test_result(self): ground_truths = self.run_raw_executor() res = self.run_new_executor() + for gt, out in zip(ground_truths, res): self.assertEqual(gt[0], out[0]) def run_raw_executor(self): paddle.seed(2020) - main_program, startup_program, fetch_list = self.build_program() + main_program, startup_program, fetch_list = build_program() exe = paddle.static.Executor(self.place) exe.run(startup_program) @@ -145,11 +147,12 @@ class MultiStreamModelTestCase(unittest.TestCase): outs = [] for i in range(self.iter_n): outs.append(exe.run(main_program, fetch_list=fetch_list)) + return outs def run_new_executor(self): paddle.seed(2020) - main_program, startup_program, fetch_list = self.build_program() + main_program, startup_program, fetch_list = build_program() fetch_list = [x.name for x in fetch_list] p = core.Place() @@ -163,5 +166,97 @@ class MultiStreamModelTestCase(unittest.TestCase): return outs +class SwitchExecutorInterfaceTestCase(MultiStreamModelTestCase): + def run_new_executor(self): + paddle.seed(2020) + os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1' + main_program, startup_program, fetch_list = build_program() + exe = paddle.static.Executor(self.place) + exe.run(startup_program) + + outs = [] + for i in range(self.iter_n): + outs.append(exe.run(main_program, fetch_list=fetch_list)) + + del os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] + + return outs + + +class SwitchExecutorInterfaceWithFeed(unittest.TestCase): + def setUp(self): + self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda( + ) else paddle.CPUPlace() + self.iter_run = 2 + + def build_program(self, is_double=False): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + a = paddle.static.data(name="a", shape=[2, 2], dtype='float32') + b = paddle.ones([2, 2]) * 2 + t = paddle.static.nn.fc(a, 2) + c = t + b + if is_double: + c = c + c + + return main_program, startup_program, [c] + + def _run(self, feed, use_str=False, is_double=False, add_wrong_fetch=False): + paddle.seed(2020) + + main_program, startup_program, fetch_vars = self.build_program( + is_double) + + exe = paddle.static.Executor(self.place) + exe.run(startup_program) + + if use_str: # test for fetch name + fetch_vars = [x.name for x in fetch_vars] + if add_wrong_fetch: # test for wrong fetch type + fetch_vars.append(1123) + outs = [] + for i in range(self.iter_run): + out = exe.run(main_program, feed=feed, fetch_list=fetch_vars)[0] + + outs.append(out) + + return outs + + def run_raw_executor(self, feed): + # run construct program 1 + out1 = self._run(feed, use_str=False, is_double=False) + # run construct program 2 with same executor + out2 = self._run(feed, use_str=True, is_double=True) + + return [out1, out2] + + def run_new_executor(self, feed): + os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1' + out = self.run_raw_executor(feed) + del os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] + return out + + def test_with_feed(self): + data = np.ones([2, 2], dtype="float32") + feed = {"a": data, 'fake_input': data} + + res = self.run_new_executor(feed) + gt = self.run_raw_executor(feed) + for x, y in zip(gt, res): + self.assertTrue(np.array_equal(x, y)) + + def test_with_error(self): + feed = [{'a': np.ones([2, 2], dtype="float32")}] + + with self.assertRaises(TypeError): + res = self.run_new_executor(feed) + + with self.assertRaises(TypeError): + os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1' + self._run(feed[0], add_wrong_fetch=True) + del os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] + + if __name__ == "__main__": unittest.main() -- GitLab