未验证 提交 2dfcdf21 编写于 作者: L Leo Chen 提交者: GitHub

[new-exec] skip compiled program with places > 1 (#37457)

* skip compiled program with places > 1

* fix corner case and add ut
上级 33653195
...@@ -1330,13 +1330,26 @@ class Executor(object): ...@@ -1330,13 +1330,26 @@ class Executor(object):
if scope is None: if scope is None:
scope = global_scope() scope = global_scope()
def _can_use_interpreter_core(program, place):
compiled = isinstance(program, compiler.CompiledProgram)
# NOTE(zhiqiu): only single card compiled program is supported
if compiled:
if program._is_data_parallel and len(
program._get_places(place, program._places)) == 1:
return True
else:
return False
else:
assert isinstance(program, Program)
return True
# NOTE: This is an experimental feature. If `export FLAGS_USE_STANDALONE_EXECUTOR=1 `, # NOTE: This is an experimental feature. If `export FLAGS_USE_STANDALONE_EXECUTOR=1 `,
# use StandaloneExecutor to run the program. # use StandaloneExecutor to run the program.
if self._enable_interpreter_core: if self._enable_interpreter_core and _can_use_interpreter_core(
inner_program_ = program._program if isinstance( program, self.place):
inner_program = program._program if isinstance(
program, compiler.CompiledProgram) else program program, compiler.CompiledProgram) else program
assert isinstance(inner_program_, framework.Program) if not inner_program._is_start_up_program_:
if not inner_program_._is_start_up_program_:
if feed is None: if feed is None:
feed = {} feed = {}
elif isinstance(feed, (list, tuple)): elif isinstance(feed, (list, tuple)):
...@@ -1348,7 +1361,7 @@ class Executor(object): ...@@ -1348,7 +1361,7 @@ class Executor(object):
% (type(feed))) % (type(feed)))
feed = self._update_feed(program, feed) feed = self._update_feed(program, feed)
program = self._add_feed_fetch_ops( program = self._add_feed_fetch_ops(
program=inner_program_, program=inner_program,
feed=feed, feed=feed,
fetch_list=fetch_list, fetch_list=fetch_list,
feed_var_name=feed_var_name, feed_var_name=feed_var_name,
......
...@@ -195,7 +195,12 @@ class SwitchExecutorInterfaceWithFeed(unittest.TestCase): ...@@ -195,7 +195,12 @@ class SwitchExecutorInterfaceWithFeed(unittest.TestCase):
return main_program, startup_program, [c] return main_program, startup_program, [c]
def _run(self, feed, use_str=False, is_double=False, add_wrong_fetch=False): def _run(self,
feed,
use_str=False,
is_double=False,
add_wrong_fetch=False,
use_compiled=False):
paddle.seed(2020) paddle.seed(2020)
main_program, startup_program, fetch_vars = self.build_program( main_program, startup_program, fetch_vars = self.build_program(
...@@ -204,6 +209,11 @@ class SwitchExecutorInterfaceWithFeed(unittest.TestCase): ...@@ -204,6 +209,11 @@ class SwitchExecutorInterfaceWithFeed(unittest.TestCase):
exe = paddle.static.Executor(self.place) exe = paddle.static.Executor(self.place)
exe.run(startup_program) exe.run(startup_program)
if use_compiled:
main_program = paddle.static.CompiledProgram(
main_program).with_data_parallel(
fetch_vars[0].name, places=[self.place])
if use_str: # test for fetch name if use_str: # test for fetch name
fetch_vars = [x.name for x in fetch_vars] fetch_vars = [x.name for x in fetch_vars]
if add_wrong_fetch: # test for wrong fetch type if add_wrong_fetch: # test for wrong fetch type
...@@ -216,17 +226,19 @@ class SwitchExecutorInterfaceWithFeed(unittest.TestCase): ...@@ -216,17 +226,19 @@ class SwitchExecutorInterfaceWithFeed(unittest.TestCase):
return outs return outs
def run_raw_executor(self, feed): def run_raw_executor(self, feed, use_compiled=False):
# run construct program 1 # run construct program 1
out1 = self._run(feed, use_str=False, is_double=False) out1 = self._run(
feed, use_str=False, is_double=False, use_compiled=use_compiled)
# run construct program 2 with same executor # run construct program 2 with same executor
out2 = self._run(feed, use_str=True, is_double=True) out2 = self._run(
feed, use_str=True, is_double=True, use_compiled=use_compiled)
return [out1, out2] return [out1, out2]
def run_new_executor(self, feed): def run_new_executor(self, feed, use_compiled=False):
os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1' os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1'
out = self.run_raw_executor(feed) out = self.run_raw_executor(feed, use_compiled=use_compiled)
del os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] del os.environ['FLAGS_USE_STANDALONE_EXECUTOR']
return out return out
...@@ -247,6 +259,15 @@ class SwitchExecutorInterfaceWithFeed(unittest.TestCase): ...@@ -247,6 +259,15 @@ class SwitchExecutorInterfaceWithFeed(unittest.TestCase):
self._run(feed[0], add_wrong_fetch=True) self._run(feed[0], add_wrong_fetch=True)
del os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] del os.environ['FLAGS_USE_STANDALONE_EXECUTOR']
def test_compiled_program(self):
data = np.ones([2, 2], dtype="float32")
feed = {"a": data, 'fake_input': data}
res = self.run_new_executor(feed, use_compiled=True)
gt = self.run_raw_executor(feed, use_compiled=True)
for x, y in zip(gt, res):
self.assertTrue(np.array_equal(x, y))
class TestException(unittest.TestCase): class TestException(unittest.TestCase):
def setUp(self): def setUp(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册