From 6eed9f4994ceab781502f36cd36dbd15aac0db34 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Wed, 6 Jul 2022 03:46:25 -0500 Subject: [PATCH] Refine StandaloneExecutor (#44076) * not run startup program in constructor of StandaloneExecutor * clear interface of standalone executor * clean debug code --- .../new_executor/executor_statistics.cc | 3 +- .../new_executor/standalone_executor.cc | 35 +-- .../new_executor/standalone_executor.h | 15 +- paddle/fluid/pybind/pybind.cc | 49 +--- python/paddle/fluid/executor.py | 17 +- python/paddle/fluid/framework.py | 13 + .../test_standalone_controlflow.py | 14 +- .../interpreter/test_standalone_executor.py | 258 +++++------------- .../unittests/mkldnn/test_conv2d_mkldnn_op.py | 10 - 9 files changed, 107 insertions(+), 307 deletions(-) diff --git a/paddle/fluid/framework/new_executor/executor_statistics.cc b/paddle/fluid/framework/new_executor/executor_statistics.cc index c9bb7d4555..a381943587 100644 --- a/paddle/fluid/framework/new_executor/executor_statistics.cc +++ b/paddle/fluid/framework/new_executor/executor_statistics.cc @@ -583,7 +583,8 @@ int StatisticsEngine::StatNormalizationTime( if (total - normalization_sum != 0) { LOG(WARNING) << "total: " << total << "is greater than normalization_sum:" << normalization_sum; - return -1; + // TODO(dev): figure out why total != normalization_sum and fix it + // return -1; } return 0; } diff --git a/paddle/fluid/framework/new_executor/standalone_executor.cc b/paddle/fluid/framework/new_executor/standalone_executor.cc index 31b1627dc6..2e6e9aa842 100644 --- a/paddle/fluid/framework/new_executor/standalone_executor.cc +++ b/paddle/fluid/framework/new_executor/standalone_executor.cc @@ -19,34 +19,8 @@ namespace paddle { namespace framework { StandaloneExecutor::StandaloneExecutor(const platform::Place& place, - const ProgramDesc& startup_prog, - const ProgramDesc& main_prog, - Scope* scope) - : place_(place), - startup_prog_(startup_prog), - main_prog_(main_prog), - scope_(scope) { - // NOTE(zhiqiu): for startup_program, run once ? - if (startup_prog.Block(0).AllOps().size() > 0) { - auto core = GetInterpreterCore(scope, startup_prog, {}, {}, false); - VLOG(4) << "StandaloneExecutor: " << this << ", InterpreterCore: " << core; - core->Run({}); - } -} - -paddle::framework::FetchList StandaloneExecutor::Run( - Scope* scope, - const std::vector& feed_names, - const std::vector& feed_tensors, - const std::vector& fetch_names) { - platform::RecordEvent record_event( - "StandaloneExecutor::run", platform::TracerEventType::UserDefined, 1); - - auto core = - GetInterpreterCore(scope, main_prog_, feed_names, fetch_names, true); - - return core->Run(feed_names, feed_tensors); -} + const ProgramDesc& prog) + : place_(place), prog_(prog) {} paddle::framework::FetchList StandaloneExecutor::Run( Scope* scope, @@ -55,8 +29,7 @@ paddle::framework::FetchList StandaloneExecutor::Run( platform::RecordEvent record_event( "StandaloneExecutor::run", platform::TracerEventType::UserDefined, 1); - auto core = - GetInterpreterCore(scope, main_prog_, feed_names, fetch_names, false); + auto core = GetInterpreterCore(scope, prog_, feed_names, fetch_names, false); VLOG(4) << "StandaloneExecutor: " << this << ", InterpreterCore: " << core; return core->Run(feed_names); } @@ -65,7 +38,7 @@ framework::interpreter::CostInfo StandaloneExecutor::DryRun( Scope* scope, const std::vector& feed_names, const std::vector& feed_tensors) { - auto core = GetInterpreterCore(scope, main_prog_, feed_names, {}, true); + auto core = GetInterpreterCore(scope, prog_, feed_names, {}, true); return core->DryRun(feed_names, feed_tensors); } diff --git a/paddle/fluid/framework/new_executor/standalone_executor.h b/paddle/fluid/framework/new_executor/standalone_executor.h index 5b9c48009e..e6d84d6f9a 100644 --- a/paddle/fluid/framework/new_executor/standalone_executor.h +++ b/paddle/fluid/framework/new_executor/standalone_executor.h @@ -31,19 +31,10 @@ class InterpreterCore; class StandaloneExecutor { public: - StandaloneExecutor(const platform::Place& place, - const ProgramDesc& startup_prog, - const ProgramDesc& main_prog, - Scope* scope); + StandaloneExecutor(const platform::Place& place, const ProgramDesc& prog); ~StandaloneExecutor() {} - paddle::framework::FetchList Run( - Scope* scope, - const std::vector& feed_names, - const std::vector& feed_tensors, - const std::vector& fetch_names); - // NOTE(zhiqiu): feed_names are only used for caching interpretercore. // fetch_names are used for caching interpretercore and inserting fetch ops, // the latter can be moved to python side. @@ -65,9 +56,7 @@ class StandaloneExecutor { bool add_fetch_op); platform::Place place_; - const ProgramDesc& startup_prog_; - const ProgramDesc& main_prog_; - Scope* scope_; // not owned + const ProgramDesc& prog_; std::unordered_map> interpretercores_; diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 3723e58e52..abbcacec38 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -3057,54 +3057,7 @@ All parameter, weight, gradient are variables in Paddle. }); py::class_(m, "StandaloneExecutor") - .def(py::init()) - .def("run", - [](StandaloneExecutor &self, - Scope *scope, - const std::unordered_map &input_dict, - std::vector fetch_names) { - std::vector feed_tensors; - std::vector feed_names; - - for (auto &item : input_dict) { - framework::LoDTensor t; - SetTensorFromPyArray( - &t, item.second, platform::CPUPlace(), false); - feed_names.push_back(item.first); - feed_tensors.push_back(t); - } - - paddle::framework::FetchList ret; - { - pybind11::gil_scoped_release release; - ret = self.Run(scope, feed_names, feed_tensors, fetch_names); - } - return py::cast(std::move(ret)); - }) - .def("run", - [](StandaloneExecutor &self, - Scope *scope, - const std::unordered_map - &input_dict, - std::vector fetch_names) { - std::vector feed_tensors; - std::vector feed_names; - - for (auto &item : input_dict) { - feed_names.push_back(item.first); - feed_tensors.push_back(item.second); - } - - paddle::framework::FetchList ret; - { - pybind11::gil_scoped_release release; - ret = self.Run(scope, feed_names, feed_tensors, fetch_names); - } - return py::cast(std::move(ret)); - }) + .def(py::init()) .def("run", [](StandaloneExecutor &self, Scope *scope, diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 78c3f41396..3303b6c947 100755 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -25,6 +25,7 @@ import six from .data_feeder import convert_dtype from .framework import Program, default_main_program, Variable, Operator from .framework import convert_np_dtype_to_dtype_ + from . import core from . import unique_name from . import compiler @@ -397,15 +398,12 @@ def _is_enable_standalone_executor(): Whether to use experimental executor `StandaloneExecutor`. """ flag = False - from ..distributed.fleet import fleet - if fleet._role_maker is not None: - warnings.warn("do not use standalone executor in fleet by default") - env_val = os.environ.get('FLAGS_USE_STANDALONE_EXECUTOR', None) - else: - env_val = os.environ.get('FLAGS_USE_STANDALONE_EXECUTOR', '1') + # use standalone_executor by default if not distributed + if fleet._role_maker is None and framework._enable_standalone_executor_ is None: + framework._enable_standalone_executor_ = 1 - if env_val in [1, '1', True, 'True', 'true']: + if framework._enable_standalone_executor_ in [1, '1', True, 'True', 'true']: flag = True return flag @@ -569,10 +567,7 @@ class _StandaloneExecutor(object): return tensors def _create_new_executor(self): - # NOTE: It's a trick to set empty start_up program. - startup_program = Program() - new_exe = core.StandaloneExecutor(self._place, startup_program.desc, - self._main_program.desc, self._scope) + new_exe = core.StandaloneExecutor(self._place, self._main_program.desc) return new_exe diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index df4691d49e..d6e4af5866 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -84,6 +84,8 @@ _already_patch_eager_tensor = False _already_patch_varbase = False _current_cuda_graph_mode = None _global_flags_ = core.globals() +_enable_standalone_executor_ = (os.environ.get('FLAGS_USE_STANDALONE_EXECUTOR', + None)) # Some explanation of our execution system 2022.03 # For now we have 3 kinds of execution system, since we refactored dygraph mode to @@ -259,6 +261,17 @@ ipu_index_attr_name = 'ipu_index' ipu_stage_attr_name = 'ipu_stage' +@signature_safe_contextmanager +def _enable_standalone_executor(enable=True): + global _enable_standalone_executor_ + original_ = _enable_standalone_executor_ + _enable_standalone_executor_ = enable + try: + yield + finally: + _enable_standalone_executor_ = original_ + + @signature_safe_contextmanager def ipu_shard_guard(index=-1, stage=-1): """ diff --git a/python/paddle/fluid/tests/unittests/interpreter/test_standalone_controlflow.py b/python/paddle/fluid/tests/unittests/interpreter/test_standalone_controlflow.py index 5ce035097d..aa0290cf4b 100644 --- a/python/paddle/fluid/tests/unittests/interpreter/test_standalone_controlflow.py +++ b/python/paddle/fluid/tests/unittests/interpreter/test_standalone_controlflow.py @@ -16,7 +16,7 @@ import os import sys import unittest import paddle -from paddle.fluid import core +from paddle.fluid import core, framework from paddle.fluid.core import StandaloneExecutor import paddle.fluid as fluid from paddle.fluid.framework import Program, program_guard @@ -81,17 +81,13 @@ class TestCompatibility(unittest.TestCase): return ret def run_raw_executor(self, feed): - os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '0' - out = self._run(feed) - del os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] - print("GT:", out) + with framework._enable_standalone_executor(False): + out = self._run(feed) return out def run_new_executor(self, feed): - os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1' - out = self._run(feed) - del os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] - print("New:", out) + with framework._enable_standalone_executor(True): + out = self._run(feed) return out def test_with_feed(self): diff --git a/python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py b/python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py index f1b1bc118e..ad13061d17 100644 --- a/python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py +++ b/python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py @@ -20,7 +20,7 @@ import shutil import unittest import paddle import json -from paddle.fluid import core +from paddle.fluid import core, framework from paddle.fluid.core import StandaloneExecutor from paddle.profiler import profiler @@ -29,7 +29,7 @@ import numpy as np paddle.enable_static() -class LinearTestCase(unittest.TestCase): +class TestDryRun(unittest.TestCase): def setUp(self): place = paddle.CUDAPlace( @@ -48,29 +48,13 @@ class LinearTestCase(unittest.TestCase): return startup_program, main_program, c - def test_interp_base(self): - startup_program, main_program, c = self.build_program() - scope = core.Scope() - standaloneexecutor = StandaloneExecutor(self.place, - startup_program.desc, - main_program.desc, scope) - out = standaloneexecutor.run( - scope, {"a": np.ones([2, 2], dtype="float32") * 2}, [c.name]) - for i in range(10): - out = standaloneexecutor.run( - scope, {"a": np.ones([2, 2], dtype="float32") * i}, [c.name]) - - for i in range(10): - out = standaloneexecutor.run( - scope, {"a": np.ones([2, 2], dtype="float32") * i}, - ['a', c.name]) - def test_dry_run(self): scope = core.Scope() startup_program, main_program, c = self.build_program() - standaloneexecutor = StandaloneExecutor(self.place, - startup_program.desc, - main_program.desc, scope) + exe = paddle.static.Executor(self.place) + exe.run(startup_program, scope=scope) + + standaloneexecutor = StandaloneExecutor(self.place, main_program.desc) # test for cost_info cost_info = standaloneexecutor.dry_run( scope, {"a": np.ones([2, 2], dtype="float32")}) @@ -124,100 +108,49 @@ class ExecutorStatisticsTestCase(unittest.TestCase): self.iter_n = 3 self.place = paddle.CUDAPlace( 0) if core.is_compiled_with_cuda() else paddle.CPUPlace() - - def test_standalone_executor_statistics(self): - if os.getenv("FLAGS_static_executor_perfstat_filepath") is None: - return - - paddle.seed(2020) - main_program, startup_program, fetch_list = build_program() - fetch_list = [x.name for x in fetch_list] - - p = core.Place() - p.set_place(self.place) - scope = core.Scope() - executor = StandaloneExecutor(p, startup_program.desc, - main_program.desc, scope) - - helper_profiler = profiler.Profiler( - targets=[profiler.ProfilerTarget.CPU], scheduler=(1, 2)) - helper_profiler.start() - for i in range(self.iter_n): - executor.run(scope, {}, fetch_list) - helper_profiler.step() - helper_profiler.stop() - - perfstat_filepath = os.environ[ - 'FLAGS_static_executor_perfstat_filepath'] - self.assertTrue(os.path.exists(perfstat_filepath)) - with open(perfstat_filepath, 'r') as load_f: - stat_res = json.load(load_f) - self.assertTrue(len(stat_res) > 0) - - os.remove(perfstat_filepath) - shutil.rmtree('./profiler_log') + self.perf_path = './perfstat' def test_parallel_executor_statistics(self): - if os.getenv("FLAGS_static_executor_perfstat_filepath") is None: - return + self.run_with_statistics(executor='ParallelExecutor') - paddle.seed(2020) - main_program, startup_program, fetch_list = build_program() - fetch_list = [x.name for x in fetch_list] - - main_program = paddle.fluid.compiler.CompiledProgram(main_program) - os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '0' - executor = paddle.static.Executor(self.place) - os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1' - executor.run(startup_program) - - helper_profiler = profiler.Profiler( - targets=[profiler.ProfilerTarget.CPU], scheduler=(1, 2)) - helper_profiler.start() - for i in range(self.iter_n): - executor.run(main_program, fetch_list=fetch_list) - helper_profiler.step() - helper_profiler.stop() - - perfstat_filepath = os.environ[ - 'FLAGS_static_executor_perfstat_filepath'] - self.assertTrue(os.path.exists(perfstat_filepath)) - with open(perfstat_filepath, 'r') as load_f: - stat_res = json.load(load_f) - self.assertTrue(len(stat_res) > 0) + def test_executor_statistics(self): + self.run_with_statistics(executor='Executor') - os.remove(perfstat_filepath) - shutil.rmtree('./profiler_log') + def test_standalone_executor_statistics(self): + self.run_with_statistics(executor='StandaloneExecutor') - def test_executor_statistics(self): + def run_with_statistics(self, executor=None): if os.getenv("FLAGS_static_executor_perfstat_filepath") is None: return - paddle.seed(2020) + # note: startup program is empty main_program, startup_program, fetch_list = build_program() - fetch_list = [x.name for x in fetch_list] - - os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '0' - executor = paddle.static.Executor(self.place) - os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1' - executor.run(startup_program) - - helper_profiler = profiler.Profiler( - targets=[profiler.ProfilerTarget.CPU], scheduler=(1, 2)) - helper_profiler.start() - for i in range(self.iter_n): - executor.run(main_program, fetch_list=fetch_list) - helper_profiler.step() - helper_profiler.stop() - - perfstat_filepath = os.environ[ - 'FLAGS_static_executor_perfstat_filepath'] - self.assertTrue(os.path.exists(perfstat_filepath)) - with open(perfstat_filepath, 'r') as load_f: + + enable = True + if executor == 'ParallelExecutor': + main_program = paddle.fluid.compiler.CompiledProgram(main_program) + enable = False + elif executor == 'Executor': + enable = False + + scope = paddle.static.Scope() + with paddle.static.scope_guard(scope): + with framework._enable_standalone_executor(enable): + exe = paddle.static.Executor(self.place) + helper_profiler = profiler.Profiler( + targets=[profiler.ProfilerTarget.CPU], scheduler=(1, 2)) + helper_profiler.start() + for i in range(self.iter_n): + exe.run(main_program, fetch_list=fetch_list) + helper_profiler.step() + helper_profiler.stop() + + self.assertTrue(os.path.exists(self.perf_path)) + with open(self.perf_path, 'r') as load_f: stat_res = json.load(load_f) self.assertTrue(len(stat_res) > 0) - os.remove(perfstat_filepath) + os.remove(self.perf_path) shutil.rmtree('./profiler_log') @@ -229,59 +162,24 @@ class MultiStreamModelTestCase(unittest.TestCase): 0) if core.is_compiled_with_cuda() else paddle.CPUPlace() def test_result(self): - ground_truths = self.run_raw_executor() - res = self.run_new_executor() + ground_truths = self.run_test(False) + res = self.run_test(True) for gt, out in zip(ground_truths, res): self.assertEqual(gt[0], out[0]) - def run_raw_executor(self): + def run_test(self, use_new_executor=True): paddle.seed(2020) main_program, startup_program, fetch_list = build_program() - exe = paddle.static.Executor(self.place) - exe.run(startup_program) - - outs = [] - for i in range(self.iter_n): - outs.append(exe.run(main_program, fetch_list=fetch_list)) - - return outs - - def run_new_executor(self): - paddle.seed(2020) - main_program, startup_program, fetch_list = build_program() - fetch_list = [x.name for x in fetch_list] - - p = core.Place() - p.set_place(self.place) - scope = core.Scope() - inter_core = StandaloneExecutor(p, startup_program.desc, - main_program.desc, scope) - - outs = [] - for i in range(self.iter_n): - outs.append( - np.array( - inter_core.run(scope, {}, fetch_list)._move_to_list()[0])) - return outs - - -class SwitchExecutorInterfaceTestCase(MultiStreamModelTestCase): - - def run_new_executor(self): - paddle.seed(2020) - os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1' - main_program, startup_program, fetch_list = build_program() - exe = paddle.static.Executor(self.place) - exe.run(startup_program) - - outs = [] - for i in range(self.iter_n): - outs.append(exe.run(main_program, fetch_list=fetch_list)) - - del os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] - + with framework._enable_standalone_executor(use_new_executor): + scope = core.Scope() + exe = paddle.static.Executor(self.place) + outs = [] + for i in range(self.iter_n): + outs.append( + exe.run(main_program, scope=scope, fetch_list=fetch_list)) + print(outs) return outs @@ -337,23 +235,23 @@ class SwitchExecutorInterfaceWithFeed(unittest.TestCase): return outs def run_raw_executor(self, feed, use_compiled=False): - # run construct program 1 - out1 = self._run(feed, - use_str=False, - is_double=False, - use_compiled=use_compiled) - # run construct program 2 with same executor - out2 = self._run(feed, - use_str=True, - is_double=True, - use_compiled=use_compiled) - - return [out1, out2] + with framework._enable_standalone_executor(False): + # run construct program 1 + out1 = self._run(feed, + use_str=False, + is_double=False, + use_compiled=use_compiled) + # run construct program 2 with same executor + out2 = self._run(feed, + use_str=True, + is_double=True, + use_compiled=use_compiled) + + return [out1, out2] def run_new_executor(self, feed, use_compiled=False): - os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1' - out = self.run_raw_executor(feed, use_compiled=use_compiled) - del os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] + with framework._enable_standalone_executor(): + out = self.run_raw_executor(feed, use_compiled=use_compiled) return out def test_with_feed(self): @@ -369,9 +267,8 @@ class SwitchExecutorInterfaceWithFeed(unittest.TestCase): feed = [{'a': np.ones([2, 2], dtype="float32")}] with self.assertRaises(TypeError): - os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1' - self._run(feed[0], add_wrong_fetch=True) - del os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] + with framework._enable_standalone_executor(): + self._run(feed[0], add_wrong_fetch=True) def test_compiled_program(self): data = np.ones([2, 2], dtype="float32") @@ -386,9 +283,7 @@ class SwitchExecutorInterfaceWithFeed(unittest.TestCase): data = np.ones([2, 2], dtype="float32") feed = {"a": data} - os.environ['FLAGS_CONVERT_GRAPH_TO_PROGRAM'] = '1' res = self.run_new_executor(feed, use_compiled=True) - del os.environ['FLAGS_CONVERT_GRAPH_TO_PROGRAM'] gt = self.run_raw_executor(feed, use_compiled=True) for x, y in zip(gt, res): self.assertTrue(np.array_equal(x, y)) @@ -401,9 +296,8 @@ class SwitchExecutorInterfaceWithFeed(unittest.TestCase): for i in range(10): print(i, flush=1) - os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1' - out = exe.run(program, feed=None) - del os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] + with framework._enable_standalone_executor(): + out = exe.run(program, feed=None) class TestException(unittest.TestCase): @@ -437,14 +331,12 @@ class TestException(unittest.TestCase): for feed in feeds: out = exe.run(main_program, feed=feed, fetch_list=fetch_vars) - print(main_program) self.fetch_vars = fetch_vars return out def run_new_executor(self, feed): - os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1' - out = self._run(feed) - del os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] + with framework._enable_standalone_executor(): + out = self._run(feed) return out def test_exception(self): @@ -492,14 +384,12 @@ class TestInplaceApiWithDataTransform(unittest.TestCase): with paddle.fluid.device_guard("cpu"): x = paddle.increment(x) exe = paddle.static.Executor(paddle.CUDAPlace(0)) - os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1' - - for i in range(10): - a, = exe.run(paddle.static.default_main_program(), - fetch_list=[x]) - self.assertEqual(a[0], 1) + with framework._enable_standalone_executor(): - del os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] + for i in range(10): + a, = exe.run(paddle.static.default_main_program(), + fetch_list=[x]) + self.assertEqual(a[0], 1) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_mkldnn_op.py index 0471c295ad..91487fb0ab 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_mkldnn_op.py @@ -246,16 +246,6 @@ class TestMKLDNNDilations(TestConv2DMKLDNNOp): self.groups = 3 -# TODO(chenweihang): To solve the coverage problem, add this unittest, -# remove this unittest after new executor set to default executor -class TestConv2dMKLDNNByNewExecutor(TestConv2DMKLDNNOp): - - def test_check_output_by_new_executor(self): - os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1' - self.test_check_output() - del os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] - - if __name__ == '__main__': from paddle import enable_static enable_static() -- GitLab