未验证 提交 6eed9f49 编写于 作者: L Leo Chen 提交者: GitHub

Refine StandaloneExecutor (#44076)

* do not run the startup program in the constructor of StandaloneExecutor

* clean up the interface of StandaloneExecutor

* remove debug code
上级 81fd2fff
...@@ -583,7 +583,8 @@ int StatisticsEngine::StatNormalizationTime( ...@@ -583,7 +583,8 @@ int StatisticsEngine::StatNormalizationTime(
if (total - normalization_sum != 0) { if (total - normalization_sum != 0) {
LOG(WARNING) << "total: " << total LOG(WARNING) << "total: " << total
<< "is greater than normalization_sum:" << normalization_sum; << "is greater than normalization_sum:" << normalization_sum;
return -1; // TODO(dev): figure out why total != normalization_sum and fix it
// return -1;
} }
return 0; return 0;
} }
......
...@@ -19,34 +19,8 @@ ...@@ -19,34 +19,8 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
StandaloneExecutor::StandaloneExecutor(const platform::Place& place, StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
const ProgramDesc& startup_prog, const ProgramDesc& prog)
const ProgramDesc& main_prog, : place_(place), prog_(prog) {}
Scope* scope)
: place_(place),
startup_prog_(startup_prog),
main_prog_(main_prog),
scope_(scope) {
// NOTE(zhiqiu): for startup_program, run once ?
if (startup_prog.Block(0).AllOps().size() > 0) {
auto core = GetInterpreterCore(scope, startup_prog, {}, {}, false);
VLOG(4) << "StandaloneExecutor: " << this << ", InterpreterCore: " << core;
core->Run({});
}
}
paddle::framework::FetchList StandaloneExecutor::Run(
Scope* scope,
const std::vector<std::string>& feed_names,
const std::vector<framework::LoDTensor>& feed_tensors,
const std::vector<std::string>& fetch_names) {
platform::RecordEvent record_event(
"StandaloneExecutor::run", platform::TracerEventType::UserDefined, 1);
auto core =
GetInterpreterCore(scope, main_prog_, feed_names, fetch_names, true);
return core->Run(feed_names, feed_tensors);
}
paddle::framework::FetchList StandaloneExecutor::Run( paddle::framework::FetchList StandaloneExecutor::Run(
Scope* scope, Scope* scope,
...@@ -55,8 +29,7 @@ paddle::framework::FetchList StandaloneExecutor::Run( ...@@ -55,8 +29,7 @@ paddle::framework::FetchList StandaloneExecutor::Run(
platform::RecordEvent record_event( platform::RecordEvent record_event(
"StandaloneExecutor::run", platform::TracerEventType::UserDefined, 1); "StandaloneExecutor::run", platform::TracerEventType::UserDefined, 1);
auto core = auto core = GetInterpreterCore(scope, prog_, feed_names, fetch_names, false);
GetInterpreterCore(scope, main_prog_, feed_names, fetch_names, false);
VLOG(4) << "StandaloneExecutor: " << this << ", InterpreterCore: " << core; VLOG(4) << "StandaloneExecutor: " << this << ", InterpreterCore: " << core;
return core->Run(feed_names); return core->Run(feed_names);
} }
...@@ -65,7 +38,7 @@ framework::interpreter::CostInfo StandaloneExecutor::DryRun( ...@@ -65,7 +38,7 @@ framework::interpreter::CostInfo StandaloneExecutor::DryRun(
Scope* scope, Scope* scope,
const std::vector<std::string>& feed_names, const std::vector<std::string>& feed_names,
const std::vector<framework::LoDTensor>& feed_tensors) { const std::vector<framework::LoDTensor>& feed_tensors) {
auto core = GetInterpreterCore(scope, main_prog_, feed_names, {}, true); auto core = GetInterpreterCore(scope, prog_, feed_names, {}, true);
return core->DryRun(feed_names, feed_tensors); return core->DryRun(feed_names, feed_tensors);
} }
......
...@@ -31,19 +31,10 @@ class InterpreterCore; ...@@ -31,19 +31,10 @@ class InterpreterCore;
class StandaloneExecutor { class StandaloneExecutor {
public: public:
StandaloneExecutor(const platform::Place& place, StandaloneExecutor(const platform::Place& place, const ProgramDesc& prog);
const ProgramDesc& startup_prog,
const ProgramDesc& main_prog,
Scope* scope);
~StandaloneExecutor() {} ~StandaloneExecutor() {}
paddle::framework::FetchList Run(
Scope* scope,
const std::vector<std::string>& feed_names,
const std::vector<framework::LoDTensor>& feed_tensors,
const std::vector<std::string>& fetch_names);
// NOTE(zhiqiu): feed_names are only used for caching interpretercore. // NOTE(zhiqiu): feed_names are only used for caching interpretercore.
// fetch_names are used for caching interpretercore and inserting fetch ops, // fetch_names are used for caching interpretercore and inserting fetch ops,
// the latter can be moved to python side. // the latter can be moved to python side.
...@@ -65,9 +56,7 @@ class StandaloneExecutor { ...@@ -65,9 +56,7 @@ class StandaloneExecutor {
bool add_fetch_op); bool add_fetch_op);
platform::Place place_; platform::Place place_;
const ProgramDesc& startup_prog_; const ProgramDesc& prog_;
const ProgramDesc& main_prog_;
Scope* scope_; // not owned
std::unordered_map<std::string, std::shared_ptr<InterpreterCore>> std::unordered_map<std::string, std::shared_ptr<InterpreterCore>>
interpretercores_; interpretercores_;
......
...@@ -3057,54 +3057,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -3057,54 +3057,7 @@ All parameter, weight, gradient are variables in Paddle.
}); });
py::class_<framework::StandaloneExecutor>(m, "StandaloneExecutor") py::class_<framework::StandaloneExecutor>(m, "StandaloneExecutor")
.def(py::init<const platform::Place &, .def(py::init<const platform::Place &, const ProgramDesc &>())
const ProgramDesc &,
const ProgramDesc &,
Scope *>())
.def("run",
[](StandaloneExecutor &self,
Scope *scope,
const std::unordered_map<std::string, py::array> &input_dict,
std::vector<std::string> fetch_names) {
std::vector<framework::LoDTensor> feed_tensors;
std::vector<std::string> feed_names;
for (auto &item : input_dict) {
framework::LoDTensor t;
SetTensorFromPyArray<platform::CPUPlace>(
&t, item.second, platform::CPUPlace(), false);
feed_names.push_back(item.first);
feed_tensors.push_back(t);
}
paddle::framework::FetchList ret;
{
pybind11::gil_scoped_release release;
ret = self.Run(scope, feed_names, feed_tensors, fetch_names);
}
return py::cast(std::move(ret));
})
.def("run",
[](StandaloneExecutor &self,
Scope *scope,
const std::unordered_map<std::string, framework::LoDTensor>
&input_dict,
std::vector<std::string> fetch_names) {
std::vector<framework::LoDTensor> feed_tensors;
std::vector<std::string> feed_names;
for (auto &item : input_dict) {
feed_names.push_back(item.first);
feed_tensors.push_back(item.second);
}
paddle::framework::FetchList ret;
{
pybind11::gil_scoped_release release;
ret = self.Run(scope, feed_names, feed_tensors, fetch_names);
}
return py::cast(std::move(ret));
})
.def("run", .def("run",
[](StandaloneExecutor &self, [](StandaloneExecutor &self,
Scope *scope, Scope *scope,
......
...@@ -25,6 +25,7 @@ import six ...@@ -25,6 +25,7 @@ import six
from .data_feeder import convert_dtype from .data_feeder import convert_dtype
from .framework import Program, default_main_program, Variable, Operator from .framework import Program, default_main_program, Variable, Operator
from .framework import convert_np_dtype_to_dtype_ from .framework import convert_np_dtype_to_dtype_
from . import core from . import core
from . import unique_name from . import unique_name
from . import compiler from . import compiler
...@@ -397,15 +398,12 @@ def _is_enable_standalone_executor(): ...@@ -397,15 +398,12 @@ def _is_enable_standalone_executor():
Whether to use experimental executor `StandaloneExecutor`. Whether to use experimental executor `StandaloneExecutor`.
""" """
flag = False flag = False
from ..distributed.fleet import fleet from ..distributed.fleet import fleet
if fleet._role_maker is not None: # use standalone_executor by default if not distributed
warnings.warn("do not use standalone executor in fleet by default") if fleet._role_maker is None and framework._enable_standalone_executor_ is None:
env_val = os.environ.get('FLAGS_USE_STANDALONE_EXECUTOR', None) framework._enable_standalone_executor_ = 1
else:
env_val = os.environ.get('FLAGS_USE_STANDALONE_EXECUTOR', '1')
if env_val in [1, '1', True, 'True', 'true']: if framework._enable_standalone_executor_ in [1, '1', True, 'True', 'true']:
flag = True flag = True
return flag return flag
...@@ -569,10 +567,7 @@ class _StandaloneExecutor(object): ...@@ -569,10 +567,7 @@ class _StandaloneExecutor(object):
return tensors return tensors
def _create_new_executor(self): def _create_new_executor(self):
# NOTE: It's a trick to set empty start_up program. new_exe = core.StandaloneExecutor(self._place, self._main_program.desc)
startup_program = Program()
new_exe = core.StandaloneExecutor(self._place, startup_program.desc,
self._main_program.desc, self._scope)
return new_exe return new_exe
......
...@@ -84,6 +84,8 @@ _already_patch_eager_tensor = False ...@@ -84,6 +84,8 @@ _already_patch_eager_tensor = False
_already_patch_varbase = False _already_patch_varbase = False
_current_cuda_graph_mode = None _current_cuda_graph_mode = None
_global_flags_ = core.globals() _global_flags_ = core.globals()
_enable_standalone_executor_ = (os.environ.get('FLAGS_USE_STANDALONE_EXECUTOR',
None))
# Some explanation of our execution system 2022.03 # Some explanation of our execution system 2022.03
# For now we have 3 kinds of execution system, since we refactored dygraph mode to # For now we have 3 kinds of execution system, since we refactored dygraph mode to
...@@ -259,6 +261,17 @@ ipu_index_attr_name = 'ipu_index' ...@@ -259,6 +261,17 @@ ipu_index_attr_name = 'ipu_index'
ipu_stage_attr_name = 'ipu_stage' ipu_stage_attr_name = 'ipu_stage'
@signature_safe_contextmanager
def _enable_standalone_executor(enable=True):
global _enable_standalone_executor_
original_ = _enable_standalone_executor_
_enable_standalone_executor_ = enable
try:
yield
finally:
_enable_standalone_executor_ = original_
@signature_safe_contextmanager @signature_safe_contextmanager
def ipu_shard_guard(index=-1, stage=-1): def ipu_shard_guard(index=-1, stage=-1):
""" """
......
...@@ -16,7 +16,7 @@ import os ...@@ -16,7 +16,7 @@ import os
import sys import sys
import unittest import unittest
import paddle import paddle
from paddle.fluid import core from paddle.fluid import core, framework
from paddle.fluid.core import StandaloneExecutor from paddle.fluid.core import StandaloneExecutor
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.framework import Program, program_guard from paddle.fluid.framework import Program, program_guard
...@@ -81,17 +81,13 @@ class TestCompatibility(unittest.TestCase): ...@@ -81,17 +81,13 @@ class TestCompatibility(unittest.TestCase):
return ret return ret
def run_raw_executor(self, feed): def run_raw_executor(self, feed):
os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '0' with framework._enable_standalone_executor(False):
out = self._run(feed) out = self._run(feed)
del os.environ['FLAGS_USE_STANDALONE_EXECUTOR']
print("GT:", out)
return out return out
def run_new_executor(self, feed): def run_new_executor(self, feed):
os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1' with framework._enable_standalone_executor(True):
out = self._run(feed) out = self._run(feed)
del os.environ['FLAGS_USE_STANDALONE_EXECUTOR']
print("New:", out)
return out return out
def test_with_feed(self): def test_with_feed(self):
......
...@@ -20,7 +20,7 @@ import shutil ...@@ -20,7 +20,7 @@ import shutil
import unittest import unittest
import paddle import paddle
import json import json
from paddle.fluid import core from paddle.fluid import core, framework
from paddle.fluid.core import StandaloneExecutor from paddle.fluid.core import StandaloneExecutor
from paddle.profiler import profiler from paddle.profiler import profiler
...@@ -29,7 +29,7 @@ import numpy as np ...@@ -29,7 +29,7 @@ import numpy as np
paddle.enable_static() paddle.enable_static()
class LinearTestCase(unittest.TestCase): class TestDryRun(unittest.TestCase):
def setUp(self): def setUp(self):
place = paddle.CUDAPlace( place = paddle.CUDAPlace(
...@@ -48,29 +48,13 @@ class LinearTestCase(unittest.TestCase): ...@@ -48,29 +48,13 @@ class LinearTestCase(unittest.TestCase):
return startup_program, main_program, c return startup_program, main_program, c
def test_interp_base(self):
startup_program, main_program, c = self.build_program()
scope = core.Scope()
standaloneexecutor = StandaloneExecutor(self.place,
startup_program.desc,
main_program.desc, scope)
out = standaloneexecutor.run(
scope, {"a": np.ones([2, 2], dtype="float32") * 2}, [c.name])
for i in range(10):
out = standaloneexecutor.run(
scope, {"a": np.ones([2, 2], dtype="float32") * i}, [c.name])
for i in range(10):
out = standaloneexecutor.run(
scope, {"a": np.ones([2, 2], dtype="float32") * i},
['a', c.name])
def test_dry_run(self): def test_dry_run(self):
scope = core.Scope() scope = core.Scope()
startup_program, main_program, c = self.build_program() startup_program, main_program, c = self.build_program()
standaloneexecutor = StandaloneExecutor(self.place, exe = paddle.static.Executor(self.place)
startup_program.desc, exe.run(startup_program, scope=scope)
main_program.desc, scope)
standaloneexecutor = StandaloneExecutor(self.place, main_program.desc)
# test for cost_info # test for cost_info
cost_info = standaloneexecutor.dry_run( cost_info = standaloneexecutor.dry_run(
scope, {"a": np.ones([2, 2], dtype="float32")}) scope, {"a": np.ones([2, 2], dtype="float32")})
...@@ -124,100 +108,49 @@ class ExecutorStatisticsTestCase(unittest.TestCase): ...@@ -124,100 +108,49 @@ class ExecutorStatisticsTestCase(unittest.TestCase):
self.iter_n = 3 self.iter_n = 3
self.place = paddle.CUDAPlace( self.place = paddle.CUDAPlace(
0) if core.is_compiled_with_cuda() else paddle.CPUPlace() 0) if core.is_compiled_with_cuda() else paddle.CPUPlace()
self.perf_path = './perfstat'
def test_standalone_executor_statistics(self):
if os.getenv("FLAGS_static_executor_perfstat_filepath") is None:
return
paddle.seed(2020)
main_program, startup_program, fetch_list = build_program()
fetch_list = [x.name for x in fetch_list]
p = core.Place()
p.set_place(self.place)
scope = core.Scope()
executor = StandaloneExecutor(p, startup_program.desc,
main_program.desc, scope)
helper_profiler = profiler.Profiler(
targets=[profiler.ProfilerTarget.CPU], scheduler=(1, 2))
helper_profiler.start()
for i in range(self.iter_n):
executor.run(scope, {}, fetch_list)
helper_profiler.step()
helper_profiler.stop()
perfstat_filepath = os.environ[
'FLAGS_static_executor_perfstat_filepath']
self.assertTrue(os.path.exists(perfstat_filepath))
with open(perfstat_filepath, 'r') as load_f:
stat_res = json.load(load_f)
self.assertTrue(len(stat_res) > 0)
os.remove(perfstat_filepath)
shutil.rmtree('./profiler_log')
def test_parallel_executor_statistics(self): def test_parallel_executor_statistics(self):
if os.getenv("FLAGS_static_executor_perfstat_filepath") is None: self.run_with_statistics(executor='ParallelExecutor')
return
paddle.seed(2020)
main_program, startup_program, fetch_list = build_program()
fetch_list = [x.name for x in fetch_list]
main_program = paddle.fluid.compiler.CompiledProgram(main_program)
os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '0'
executor = paddle.static.Executor(self.place)
os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1'
executor.run(startup_program)
helper_profiler = profiler.Profiler( def test_executor_statistics(self):
targets=[profiler.ProfilerTarget.CPU], scheduler=(1, 2)) self.run_with_statistics(executor='Executor')
helper_profiler.start()
for i in range(self.iter_n):
executor.run(main_program, fetch_list=fetch_list)
helper_profiler.step()
helper_profiler.stop()
perfstat_filepath = os.environ[
'FLAGS_static_executor_perfstat_filepath']
self.assertTrue(os.path.exists(perfstat_filepath))
with open(perfstat_filepath, 'r') as load_f:
stat_res = json.load(load_f)
self.assertTrue(len(stat_res) > 0)
os.remove(perfstat_filepath) def test_standalone_executor_statistics(self):
shutil.rmtree('./profiler_log') self.run_with_statistics(executor='StandaloneExecutor')
def test_executor_statistics(self): def run_with_statistics(self, executor=None):
if os.getenv("FLAGS_static_executor_perfstat_filepath") is None: if os.getenv("FLAGS_static_executor_perfstat_filepath") is None:
return return
paddle.seed(2020) paddle.seed(2020)
# note: startup program is empty
main_program, startup_program, fetch_list = build_program() main_program, startup_program, fetch_list = build_program()
fetch_list = [x.name for x in fetch_list]
os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '0' enable = True
executor = paddle.static.Executor(self.place) if executor == 'ParallelExecutor':
os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1' main_program = paddle.fluid.compiler.CompiledProgram(main_program)
executor.run(startup_program) enable = False
elif executor == 'Executor':
enable = False
scope = paddle.static.Scope()
with paddle.static.scope_guard(scope):
with framework._enable_standalone_executor(enable):
exe = paddle.static.Executor(self.place)
helper_profiler = profiler.Profiler( helper_profiler = profiler.Profiler(
targets=[profiler.ProfilerTarget.CPU], scheduler=(1, 2)) targets=[profiler.ProfilerTarget.CPU], scheduler=(1, 2))
helper_profiler.start() helper_profiler.start()
for i in range(self.iter_n): for i in range(self.iter_n):
executor.run(main_program, fetch_list=fetch_list) exe.run(main_program, fetch_list=fetch_list)
helper_profiler.step() helper_profiler.step()
helper_profiler.stop() helper_profiler.stop()
perfstat_filepath = os.environ[ self.assertTrue(os.path.exists(self.perf_path))
'FLAGS_static_executor_perfstat_filepath'] with open(self.perf_path, 'r') as load_f:
self.assertTrue(os.path.exists(perfstat_filepath))
with open(perfstat_filepath, 'r') as load_f:
stat_res = json.load(load_f) stat_res = json.load(load_f)
self.assertTrue(len(stat_res) > 0) self.assertTrue(len(stat_res) > 0)
os.remove(perfstat_filepath) os.remove(self.perf_path)
shutil.rmtree('./profiler_log') shutil.rmtree('./profiler_log')
...@@ -229,59 +162,24 @@ class MultiStreamModelTestCase(unittest.TestCase): ...@@ -229,59 +162,24 @@ class MultiStreamModelTestCase(unittest.TestCase):
0) if core.is_compiled_with_cuda() else paddle.CPUPlace() 0) if core.is_compiled_with_cuda() else paddle.CPUPlace()
def test_result(self): def test_result(self):
ground_truths = self.run_raw_executor() ground_truths = self.run_test(False)
res = self.run_new_executor() res = self.run_test(True)
for gt, out in zip(ground_truths, res): for gt, out in zip(ground_truths, res):
self.assertEqual(gt[0], out[0]) self.assertEqual(gt[0], out[0])
def run_raw_executor(self): def run_test(self, use_new_executor=True):
paddle.seed(2020) paddle.seed(2020)
main_program, startup_program, fetch_list = build_program() main_program, startup_program, fetch_list = build_program()
exe = paddle.static.Executor(self.place) with framework._enable_standalone_executor(use_new_executor):
exe.run(startup_program)
outs = []
for i in range(self.iter_n):
outs.append(exe.run(main_program, fetch_list=fetch_list))
return outs
def run_new_executor(self):
paddle.seed(2020)
main_program, startup_program, fetch_list = build_program()
fetch_list = [x.name for x in fetch_list]
p = core.Place()
p.set_place(self.place)
scope = core.Scope() scope = core.Scope()
inter_core = StandaloneExecutor(p, startup_program.desc,
main_program.desc, scope)
outs = []
for i in range(self.iter_n):
outs.append(
np.array(
inter_core.run(scope, {}, fetch_list)._move_to_list()[0]))
return outs
class SwitchExecutorInterfaceTestCase(MultiStreamModelTestCase):
def run_new_executor(self):
paddle.seed(2020)
os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1'
main_program, startup_program, fetch_list = build_program()
exe = paddle.static.Executor(self.place) exe = paddle.static.Executor(self.place)
exe.run(startup_program)
outs = [] outs = []
for i in range(self.iter_n): for i in range(self.iter_n):
outs.append(exe.run(main_program, fetch_list=fetch_list)) outs.append(
exe.run(main_program, scope=scope, fetch_list=fetch_list))
del os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] print(outs)
return outs return outs
...@@ -337,6 +235,7 @@ class SwitchExecutorInterfaceWithFeed(unittest.TestCase): ...@@ -337,6 +235,7 @@ class SwitchExecutorInterfaceWithFeed(unittest.TestCase):
return outs return outs
def run_raw_executor(self, feed, use_compiled=False): def run_raw_executor(self, feed, use_compiled=False):
with framework._enable_standalone_executor(False):
# run construct program 1 # run construct program 1
out1 = self._run(feed, out1 = self._run(feed,
use_str=False, use_str=False,
...@@ -351,9 +250,8 @@ class SwitchExecutorInterfaceWithFeed(unittest.TestCase): ...@@ -351,9 +250,8 @@ class SwitchExecutorInterfaceWithFeed(unittest.TestCase):
return [out1, out2] return [out1, out2]
def run_new_executor(self, feed, use_compiled=False): def run_new_executor(self, feed, use_compiled=False):
os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1' with framework._enable_standalone_executor():
out = self.run_raw_executor(feed, use_compiled=use_compiled) out = self.run_raw_executor(feed, use_compiled=use_compiled)
del os.environ['FLAGS_USE_STANDALONE_EXECUTOR']
return out return out
def test_with_feed(self): def test_with_feed(self):
...@@ -369,9 +267,8 @@ class SwitchExecutorInterfaceWithFeed(unittest.TestCase): ...@@ -369,9 +267,8 @@ class SwitchExecutorInterfaceWithFeed(unittest.TestCase):
feed = [{'a': np.ones([2, 2], dtype="float32")}] feed = [{'a': np.ones([2, 2], dtype="float32")}]
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1' with framework._enable_standalone_executor():
self._run(feed[0], add_wrong_fetch=True) self._run(feed[0], add_wrong_fetch=True)
del os.environ['FLAGS_USE_STANDALONE_EXECUTOR']
def test_compiled_program(self): def test_compiled_program(self):
data = np.ones([2, 2], dtype="float32") data = np.ones([2, 2], dtype="float32")
...@@ -386,9 +283,7 @@ class SwitchExecutorInterfaceWithFeed(unittest.TestCase): ...@@ -386,9 +283,7 @@ class SwitchExecutorInterfaceWithFeed(unittest.TestCase):
data = np.ones([2, 2], dtype="float32") data = np.ones([2, 2], dtype="float32")
feed = {"a": data} feed = {"a": data}
os.environ['FLAGS_CONVERT_GRAPH_TO_PROGRAM'] = '1'
res = self.run_new_executor(feed, use_compiled=True) res = self.run_new_executor(feed, use_compiled=True)
del os.environ['FLAGS_CONVERT_GRAPH_TO_PROGRAM']
gt = self.run_raw_executor(feed, use_compiled=True) gt = self.run_raw_executor(feed, use_compiled=True)
for x, y in zip(gt, res): for x, y in zip(gt, res):
self.assertTrue(np.array_equal(x, y)) self.assertTrue(np.array_equal(x, y))
...@@ -401,9 +296,8 @@ class SwitchExecutorInterfaceWithFeed(unittest.TestCase): ...@@ -401,9 +296,8 @@ class SwitchExecutorInterfaceWithFeed(unittest.TestCase):
for i in range(10): for i in range(10):
print(i, flush=1) print(i, flush=1)
os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1' with framework._enable_standalone_executor():
out = exe.run(program, feed=None) out = exe.run(program, feed=None)
del os.environ['FLAGS_USE_STANDALONE_EXECUTOR']
class TestException(unittest.TestCase): class TestException(unittest.TestCase):
...@@ -437,14 +331,12 @@ class TestException(unittest.TestCase): ...@@ -437,14 +331,12 @@ class TestException(unittest.TestCase):
for feed in feeds: for feed in feeds:
out = exe.run(main_program, feed=feed, fetch_list=fetch_vars) out = exe.run(main_program, feed=feed, fetch_list=fetch_vars)
print(main_program)
self.fetch_vars = fetch_vars self.fetch_vars = fetch_vars
return out return out
def run_new_executor(self, feed): def run_new_executor(self, feed):
os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1' with framework._enable_standalone_executor():
out = self._run(feed) out = self._run(feed)
del os.environ['FLAGS_USE_STANDALONE_EXECUTOR']
return out return out
def test_exception(self): def test_exception(self):
...@@ -492,15 +384,13 @@ class TestInplaceApiWithDataTransform(unittest.TestCase): ...@@ -492,15 +384,13 @@ class TestInplaceApiWithDataTransform(unittest.TestCase):
with paddle.fluid.device_guard("cpu"): with paddle.fluid.device_guard("cpu"):
x = paddle.increment(x) x = paddle.increment(x)
exe = paddle.static.Executor(paddle.CUDAPlace(0)) exe = paddle.static.Executor(paddle.CUDAPlace(0))
os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1' with framework._enable_standalone_executor():
for i in range(10): for i in range(10):
a, = exe.run(paddle.static.default_main_program(), a, = exe.run(paddle.static.default_main_program(),
fetch_list=[x]) fetch_list=[x])
self.assertEqual(a[0], 1) self.assertEqual(a[0], 1)
del os.environ['FLAGS_USE_STANDALONE_EXECUTOR']
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -246,16 +246,6 @@ class TestMKLDNNDilations(TestConv2DMKLDNNOp): ...@@ -246,16 +246,6 @@ class TestMKLDNNDilations(TestConv2DMKLDNNOp):
self.groups = 3 self.groups = 3
# TODO(chenweihang): To solve the coverage problem, add this unittest,
# remove this unittest after new executor set to default executor
class TestConv2dMKLDNNByNewExecutor(TestConv2DMKLDNNOp):
def test_check_output_by_new_executor(self):
os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1'
self.test_check_output()
del os.environ['FLAGS_USE_STANDALONE_EXECUTOR']
if __name__ == '__main__': if __name__ == '__main__':
from paddle import enable_static from paddle import enable_static
enable_static() enable_static()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册