未验证 提交 6eed9f49 编写于 作者: L Leo Chen 提交者: GitHub

Refine StandaloneExecutor (#44076)

* not run startup program in constructor of StandaloneExecutor

* clear interface of standalone executor

* clean debug code
上级 81fd2fff
......@@ -583,7 +583,8 @@ int StatisticsEngine::StatNormalizationTime(
if (total - normalization_sum != 0) {
LOG(WARNING) << "total: " << total
<< "is greater than normalization_sum:" << normalization_sum;
return -1;
// TODO(dev): figure out why total != normalization_sum and fix it
// return -1;
}
return 0;
}
......
......@@ -19,34 +19,8 @@
namespace paddle {
namespace framework {
StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
const ProgramDesc& startup_prog,
const ProgramDesc& main_prog,
Scope* scope)
: place_(place),
startup_prog_(startup_prog),
main_prog_(main_prog),
scope_(scope) {
// NOTE(zhiqiu): for startup_program, run once ?
if (startup_prog.Block(0).AllOps().size() > 0) {
auto core = GetInterpreterCore(scope, startup_prog, {}, {}, false);
VLOG(4) << "StandaloneExecutor: " << this << ", InterpreterCore: " << core;
core->Run({});
}
}
paddle::framework::FetchList StandaloneExecutor::Run(
Scope* scope,
const std::vector<std::string>& feed_names,
const std::vector<framework::LoDTensor>& feed_tensors,
const std::vector<std::string>& fetch_names) {
platform::RecordEvent record_event(
"StandaloneExecutor::run", platform::TracerEventType::UserDefined, 1);
auto core =
GetInterpreterCore(scope, main_prog_, feed_names, fetch_names, true);
return core->Run(feed_names, feed_tensors);
}
const ProgramDesc& prog)
: place_(place), prog_(prog) {}
paddle::framework::FetchList StandaloneExecutor::Run(
Scope* scope,
......@@ -55,8 +29,7 @@ paddle::framework::FetchList StandaloneExecutor::Run(
platform::RecordEvent record_event(
"StandaloneExecutor::run", platform::TracerEventType::UserDefined, 1);
auto core =
GetInterpreterCore(scope, main_prog_, feed_names, fetch_names, false);
auto core = GetInterpreterCore(scope, prog_, feed_names, fetch_names, false);
VLOG(4) << "StandaloneExecutor: " << this << ", InterpreterCore: " << core;
return core->Run(feed_names);
}
......@@ -65,7 +38,7 @@ framework::interpreter::CostInfo StandaloneExecutor::DryRun(
Scope* scope,
const std::vector<std::string>& feed_names,
const std::vector<framework::LoDTensor>& feed_tensors) {
auto core = GetInterpreterCore(scope, main_prog_, feed_names, {}, true);
auto core = GetInterpreterCore(scope, prog_, feed_names, {}, true);
return core->DryRun(feed_names, feed_tensors);
}
......
......@@ -31,19 +31,10 @@ class InterpreterCore;
class StandaloneExecutor {
public:
StandaloneExecutor(const platform::Place& place,
const ProgramDesc& startup_prog,
const ProgramDesc& main_prog,
Scope* scope);
StandaloneExecutor(const platform::Place& place, const ProgramDesc& prog);
~StandaloneExecutor() {}
paddle::framework::FetchList Run(
Scope* scope,
const std::vector<std::string>& feed_names,
const std::vector<framework::LoDTensor>& feed_tensors,
const std::vector<std::string>& fetch_names);
// NOTE(zhiqiu): feed_names are only used for caching interpretercore.
// fetch_names are used for caching interpretercore and inserting fetch ops,
// the latter can be moved to python side.
......@@ -65,9 +56,7 @@ class StandaloneExecutor {
bool add_fetch_op);
platform::Place place_;
const ProgramDesc& startup_prog_;
const ProgramDesc& main_prog_;
Scope* scope_; // not owned
const ProgramDesc& prog_;
std::unordered_map<std::string, std::shared_ptr<InterpreterCore>>
interpretercores_;
......
......@@ -3057,54 +3057,7 @@ All parameter, weight, gradient are variables in Paddle.
});
py::class_<framework::StandaloneExecutor>(m, "StandaloneExecutor")
.def(py::init<const platform::Place &,
const ProgramDesc &,
const ProgramDesc &,
Scope *>())
.def("run",
[](StandaloneExecutor &self,
Scope *scope,
const std::unordered_map<std::string, py::array> &input_dict,
std::vector<std::string> fetch_names) {
std::vector<framework::LoDTensor> feed_tensors;
std::vector<std::string> feed_names;
for (auto &item : input_dict) {
framework::LoDTensor t;
SetTensorFromPyArray<platform::CPUPlace>(
&t, item.second, platform::CPUPlace(), false);
feed_names.push_back(item.first);
feed_tensors.push_back(t);
}
paddle::framework::FetchList ret;
{
pybind11::gil_scoped_release release;
ret = self.Run(scope, feed_names, feed_tensors, fetch_names);
}
return py::cast(std::move(ret));
})
.def("run",
[](StandaloneExecutor &self,
Scope *scope,
const std::unordered_map<std::string, framework::LoDTensor>
&input_dict,
std::vector<std::string> fetch_names) {
std::vector<framework::LoDTensor> feed_tensors;
std::vector<std::string> feed_names;
for (auto &item : input_dict) {
feed_names.push_back(item.first);
feed_tensors.push_back(item.second);
}
paddle::framework::FetchList ret;
{
pybind11::gil_scoped_release release;
ret = self.Run(scope, feed_names, feed_tensors, fetch_names);
}
return py::cast(std::move(ret));
})
.def(py::init<const platform::Place &, const ProgramDesc &>())
.def("run",
[](StandaloneExecutor &self,
Scope *scope,
......
......@@ -25,6 +25,7 @@ import six
from .data_feeder import convert_dtype
from .framework import Program, default_main_program, Variable, Operator
from .framework import convert_np_dtype_to_dtype_
from . import core
from . import unique_name
from . import compiler
......@@ -397,15 +398,12 @@ def _is_enable_standalone_executor():
Whether to use experimental executor `StandaloneExecutor`.
"""
flag = False
from ..distributed.fleet import fleet
if fleet._role_maker is not None:
warnings.warn("do not use standalone executor in fleet by default")
env_val = os.environ.get('FLAGS_USE_STANDALONE_EXECUTOR', None)
else:
env_val = os.environ.get('FLAGS_USE_STANDALONE_EXECUTOR', '1')
# use standalone_executor by default if not distributed
if fleet._role_maker is None and framework._enable_standalone_executor_ is None:
framework._enable_standalone_executor_ = 1
if env_val in [1, '1', True, 'True', 'true']:
if framework._enable_standalone_executor_ in [1, '1', True, 'True', 'true']:
flag = True
return flag
......@@ -569,10 +567,7 @@ class _StandaloneExecutor(object):
return tensors
def _create_new_executor(self):
# NOTE: It's a trick to set empty start_up program.
startup_program = Program()
new_exe = core.StandaloneExecutor(self._place, startup_program.desc,
self._main_program.desc, self._scope)
new_exe = core.StandaloneExecutor(self._place, self._main_program.desc)
return new_exe
......
......@@ -84,6 +84,8 @@ _already_patch_eager_tensor = False
_already_patch_varbase = False
_current_cuda_graph_mode = None
_global_flags_ = core.globals()
_enable_standalone_executor_ = (os.environ.get('FLAGS_USE_STANDALONE_EXECUTOR',
None))
# Some explanation of our execution system 2022.03
# For now we have 3 kinds of execution system, since we refactored dygraph mode to
......@@ -259,6 +261,17 @@ ipu_index_attr_name = 'ipu_index'
ipu_stage_attr_name = 'ipu_stage'
@signature_safe_contextmanager
def _enable_standalone_executor(enable=True):
global _enable_standalone_executor_
original_ = _enable_standalone_executor_
_enable_standalone_executor_ = enable
try:
yield
finally:
_enable_standalone_executor_ = original_
@signature_safe_contextmanager
def ipu_shard_guard(index=-1, stage=-1):
"""
......
......@@ -16,7 +16,7 @@ import os
import sys
import unittest
import paddle
from paddle.fluid import core
from paddle.fluid import core, framework
from paddle.fluid.core import StandaloneExecutor
import paddle.fluid as fluid
from paddle.fluid.framework import Program, program_guard
......@@ -81,17 +81,13 @@ class TestCompatibility(unittest.TestCase):
return ret
def run_raw_executor(self, feed):
os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '0'
out = self._run(feed)
del os.environ['FLAGS_USE_STANDALONE_EXECUTOR']
print("GT:", out)
with framework._enable_standalone_executor(False):
out = self._run(feed)
return out
def run_new_executor(self, feed):
os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1'
out = self._run(feed)
del os.environ['FLAGS_USE_STANDALONE_EXECUTOR']
print("New:", out)
with framework._enable_standalone_executor(True):
out = self._run(feed)
return out
def test_with_feed(self):
......
......@@ -20,7 +20,7 @@ import shutil
import unittest
import paddle
import json
from paddle.fluid import core
from paddle.fluid import core, framework
from paddle.fluid.core import StandaloneExecutor
from paddle.profiler import profiler
......@@ -29,7 +29,7 @@ import numpy as np
paddle.enable_static()
class LinearTestCase(unittest.TestCase):
class TestDryRun(unittest.TestCase):
def setUp(self):
place = paddle.CUDAPlace(
......@@ -48,29 +48,13 @@ class LinearTestCase(unittest.TestCase):
return startup_program, main_program, c
def test_interp_base(self):
startup_program, main_program, c = self.build_program()
scope = core.Scope()
standaloneexecutor = StandaloneExecutor(self.place,
startup_program.desc,
main_program.desc, scope)
out = standaloneexecutor.run(
scope, {"a": np.ones([2, 2], dtype="float32") * 2}, [c.name])
for i in range(10):
out = standaloneexecutor.run(
scope, {"a": np.ones([2, 2], dtype="float32") * i}, [c.name])
for i in range(10):
out = standaloneexecutor.run(
scope, {"a": np.ones([2, 2], dtype="float32") * i},
['a', c.name])
def test_dry_run(self):
scope = core.Scope()
startup_program, main_program, c = self.build_program()
standaloneexecutor = StandaloneExecutor(self.place,
startup_program.desc,
main_program.desc, scope)
exe = paddle.static.Executor(self.place)
exe.run(startup_program, scope=scope)
standaloneexecutor = StandaloneExecutor(self.place, main_program.desc)
# test for cost_info
cost_info = standaloneexecutor.dry_run(
scope, {"a": np.ones([2, 2], dtype="float32")})
......@@ -124,100 +108,49 @@ class ExecutorStatisticsTestCase(unittest.TestCase):
self.iter_n = 3
self.place = paddle.CUDAPlace(
0) if core.is_compiled_with_cuda() else paddle.CPUPlace()
def test_standalone_executor_statistics(self):
if os.getenv("FLAGS_static_executor_perfstat_filepath") is None:
return
paddle.seed(2020)
main_program, startup_program, fetch_list = build_program()
fetch_list = [x.name for x in fetch_list]
p = core.Place()
p.set_place(self.place)
scope = core.Scope()
executor = StandaloneExecutor(p, startup_program.desc,
main_program.desc, scope)
helper_profiler = profiler.Profiler(
targets=[profiler.ProfilerTarget.CPU], scheduler=(1, 2))
helper_profiler.start()
for i in range(self.iter_n):
executor.run(scope, {}, fetch_list)
helper_profiler.step()
helper_profiler.stop()
perfstat_filepath = os.environ[
'FLAGS_static_executor_perfstat_filepath']
self.assertTrue(os.path.exists(perfstat_filepath))
with open(perfstat_filepath, 'r') as load_f:
stat_res = json.load(load_f)
self.assertTrue(len(stat_res) > 0)
os.remove(perfstat_filepath)
shutil.rmtree('./profiler_log')
self.perf_path = './perfstat'
def test_parallel_executor_statistics(self):
if os.getenv("FLAGS_static_executor_perfstat_filepath") is None:
return
self.run_with_statistics(executor='ParallelExecutor')
paddle.seed(2020)
main_program, startup_program, fetch_list = build_program()
fetch_list = [x.name for x in fetch_list]
main_program = paddle.fluid.compiler.CompiledProgram(main_program)
os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '0'
executor = paddle.static.Executor(self.place)
os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1'
executor.run(startup_program)
helper_profiler = profiler.Profiler(
targets=[profiler.ProfilerTarget.CPU], scheduler=(1, 2))
helper_profiler.start()
for i in range(self.iter_n):
executor.run(main_program, fetch_list=fetch_list)
helper_profiler.step()
helper_profiler.stop()
perfstat_filepath = os.environ[
'FLAGS_static_executor_perfstat_filepath']
self.assertTrue(os.path.exists(perfstat_filepath))
with open(perfstat_filepath, 'r') as load_f:
stat_res = json.load(load_f)
self.assertTrue(len(stat_res) > 0)
def test_executor_statistics(self):
self.run_with_statistics(executor='Executor')
os.remove(perfstat_filepath)
shutil.rmtree('./profiler_log')
def test_standalone_executor_statistics(self):
self.run_with_statistics(executor='StandaloneExecutor')
def test_executor_statistics(self):
def run_with_statistics(self, executor=None):
if os.getenv("FLAGS_static_executor_perfstat_filepath") is None:
return
paddle.seed(2020)
# note: startup program is empty
main_program, startup_program, fetch_list = build_program()
fetch_list = [x.name for x in fetch_list]
os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '0'
executor = paddle.static.Executor(self.place)
os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1'
executor.run(startup_program)
helper_profiler = profiler.Profiler(
targets=[profiler.ProfilerTarget.CPU], scheduler=(1, 2))
helper_profiler.start()
for i in range(self.iter_n):
executor.run(main_program, fetch_list=fetch_list)
helper_profiler.step()
helper_profiler.stop()
perfstat_filepath = os.environ[
'FLAGS_static_executor_perfstat_filepath']
self.assertTrue(os.path.exists(perfstat_filepath))
with open(perfstat_filepath, 'r') as load_f:
enable = True
if executor == 'ParallelExecutor':
main_program = paddle.fluid.compiler.CompiledProgram(main_program)
enable = False
elif executor == 'Executor':
enable = False
scope = paddle.static.Scope()
with paddle.static.scope_guard(scope):
with framework._enable_standalone_executor(enable):
exe = paddle.static.Executor(self.place)
helper_profiler = profiler.Profiler(
targets=[profiler.ProfilerTarget.CPU], scheduler=(1, 2))
helper_profiler.start()
for i in range(self.iter_n):
exe.run(main_program, fetch_list=fetch_list)
helper_profiler.step()
helper_profiler.stop()
self.assertTrue(os.path.exists(self.perf_path))
with open(self.perf_path, 'r') as load_f:
stat_res = json.load(load_f)
self.assertTrue(len(stat_res) > 0)
os.remove(perfstat_filepath)
os.remove(self.perf_path)
shutil.rmtree('./profiler_log')
......@@ -229,59 +162,24 @@ class MultiStreamModelTestCase(unittest.TestCase):
0) if core.is_compiled_with_cuda() else paddle.CPUPlace()
def test_result(self):
ground_truths = self.run_raw_executor()
res = self.run_new_executor()
ground_truths = self.run_test(False)
res = self.run_test(True)
for gt, out in zip(ground_truths, res):
self.assertEqual(gt[0], out[0])
def run_raw_executor(self):
def run_test(self, use_new_executor=True):
paddle.seed(2020)
main_program, startup_program, fetch_list = build_program()
exe = paddle.static.Executor(self.place)
exe.run(startup_program)
outs = []
for i in range(self.iter_n):
outs.append(exe.run(main_program, fetch_list=fetch_list))
return outs
def run_new_executor(self):
paddle.seed(2020)
main_program, startup_program, fetch_list = build_program()
fetch_list = [x.name for x in fetch_list]
p = core.Place()
p.set_place(self.place)
scope = core.Scope()
inter_core = StandaloneExecutor(p, startup_program.desc,
main_program.desc, scope)
outs = []
for i in range(self.iter_n):
outs.append(
np.array(
inter_core.run(scope, {}, fetch_list)._move_to_list()[0]))
return outs
class SwitchExecutorInterfaceTestCase(MultiStreamModelTestCase):
def run_new_executor(self):
paddle.seed(2020)
os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1'
main_program, startup_program, fetch_list = build_program()
exe = paddle.static.Executor(self.place)
exe.run(startup_program)
outs = []
for i in range(self.iter_n):
outs.append(exe.run(main_program, fetch_list=fetch_list))
del os.environ['FLAGS_USE_STANDALONE_EXECUTOR']
with framework._enable_standalone_executor(use_new_executor):
scope = core.Scope()
exe = paddle.static.Executor(self.place)
outs = []
for i in range(self.iter_n):
outs.append(
exe.run(main_program, scope=scope, fetch_list=fetch_list))
print(outs)
return outs
......@@ -337,23 +235,23 @@ class SwitchExecutorInterfaceWithFeed(unittest.TestCase):
return outs
def run_raw_executor(self, feed, use_compiled=False):
# run construct program 1
out1 = self._run(feed,
use_str=False,
is_double=False,
use_compiled=use_compiled)
# run construct program 2 with same executor
out2 = self._run(feed,
use_str=True,
is_double=True,
use_compiled=use_compiled)
return [out1, out2]
with framework._enable_standalone_executor(False):
# run construct program 1
out1 = self._run(feed,
use_str=False,
is_double=False,
use_compiled=use_compiled)
# run construct program 2 with same executor
out2 = self._run(feed,
use_str=True,
is_double=True,
use_compiled=use_compiled)
return [out1, out2]
def run_new_executor(self, feed, use_compiled=False):
os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1'
out = self.run_raw_executor(feed, use_compiled=use_compiled)
del os.environ['FLAGS_USE_STANDALONE_EXECUTOR']
with framework._enable_standalone_executor():
out = self.run_raw_executor(feed, use_compiled=use_compiled)
return out
def test_with_feed(self):
......@@ -369,9 +267,8 @@ class SwitchExecutorInterfaceWithFeed(unittest.TestCase):
feed = [{'a': np.ones([2, 2], dtype="float32")}]
with self.assertRaises(TypeError):
os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1'
self._run(feed[0], add_wrong_fetch=True)
del os.environ['FLAGS_USE_STANDALONE_EXECUTOR']
with framework._enable_standalone_executor():
self._run(feed[0], add_wrong_fetch=True)
def test_compiled_program(self):
data = np.ones([2, 2], dtype="float32")
......@@ -386,9 +283,7 @@ class SwitchExecutorInterfaceWithFeed(unittest.TestCase):
data = np.ones([2, 2], dtype="float32")
feed = {"a": data}
os.environ['FLAGS_CONVERT_GRAPH_TO_PROGRAM'] = '1'
res = self.run_new_executor(feed, use_compiled=True)
del os.environ['FLAGS_CONVERT_GRAPH_TO_PROGRAM']
gt = self.run_raw_executor(feed, use_compiled=True)
for x, y in zip(gt, res):
self.assertTrue(np.array_equal(x, y))
......@@ -401,9 +296,8 @@ class SwitchExecutorInterfaceWithFeed(unittest.TestCase):
for i in range(10):
print(i, flush=1)
os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1'
out = exe.run(program, feed=None)
del os.environ['FLAGS_USE_STANDALONE_EXECUTOR']
with framework._enable_standalone_executor():
out = exe.run(program, feed=None)
class TestException(unittest.TestCase):
......@@ -437,14 +331,12 @@ class TestException(unittest.TestCase):
for feed in feeds:
out = exe.run(main_program, feed=feed, fetch_list=fetch_vars)
print(main_program)
self.fetch_vars = fetch_vars
return out
def run_new_executor(self, feed):
os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1'
out = self._run(feed)
del os.environ['FLAGS_USE_STANDALONE_EXECUTOR']
with framework._enable_standalone_executor():
out = self._run(feed)
return out
def test_exception(self):
......@@ -492,14 +384,12 @@ class TestInplaceApiWithDataTransform(unittest.TestCase):
with paddle.fluid.device_guard("cpu"):
x = paddle.increment(x)
exe = paddle.static.Executor(paddle.CUDAPlace(0))
os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1'
for i in range(10):
a, = exe.run(paddle.static.default_main_program(),
fetch_list=[x])
self.assertEqual(a[0], 1)
with framework._enable_standalone_executor():
del os.environ['FLAGS_USE_STANDALONE_EXECUTOR']
for i in range(10):
a, = exe.run(paddle.static.default_main_program(),
fetch_list=[x])
self.assertEqual(a[0], 1)
if __name__ == "__main__":
......
......@@ -246,16 +246,6 @@ class TestMKLDNNDilations(TestConv2DMKLDNNOp):
self.groups = 3
# TODO(chenweihang): To solve the coverage problem, add this unittest,
# remove this unittest after new executor set to default executor
class TestConv2dMKLDNNByNewExecutor(TestConv2DMKLDNNOp):
def test_check_output_by_new_executor(self):
os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1'
self.test_check_output()
del os.environ['FLAGS_USE_STANDALONE_EXECUTOR']
if __name__ == '__main__':
from paddle import enable_static
enable_static()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册