未验证 提交 b8793f70 编写于 作者: L LiYuRio 提交者: GitHub

[Fleet Executor] Add feed, fetch and check correctness (#37824)

上级 70dea138
......@@ -21,12 +21,11 @@ message RankInfo {
}
message FleetExecutorDesc {
optional string strategy = 1 [ default = "Origin" ];
optional int64 cur_rank = 2 [ default = 0 ]; // Rank id of current processor
repeated RankInfo cluster_info = 3;
optional int32 dp_degree = 4 [ default = 1 ];
optional int32 mp_degree = 5 [ default = 1 ];
optional int32 pp_degree = 6 [ default = 1 ];
optional int64 num_micro_batches = 7 [ default = 1 ];
optional int64 num_slots = 8 [ default = 1 ];
optional int64 cur_rank = 1 [ default = 0 ]; // Rank id of current processor
repeated RankInfo cluster_info = 2;
optional int32 dp_degree = 3 [ default = 1 ];
optional int32 mp_degree = 4 [ default = 1 ];
optional int32 pp_degree = 5 [ default = 1 ];
optional int64 num_micro_batches = 6 [ default = 1 ];
optional int64 num_slots = 7 [ default = 1 ];
}
......@@ -100,12 +100,7 @@ std::vector<OpRole> RuntimeGraph::functionality_order = {
RuntimeGraph::RuntimeGraph(const ProgramDesc& program,
const FleetExecutorDesc& exe_desc)
: exe_desc_(exe_desc) {
if (exe_desc.strategy() == "1F1B") {
SplitProgramBasedFunctionality(program);
AssignTaskToIntercepter();
FakeDependence();
FakeRuntimeInfo();
} else if (exe_desc.strategy() == "Origin") {
if (exe_desc.pp_degree() == 1) {
int64_t cur_rank = exe_desc_.cur_rank();
int64_t max_run_times = exe_desc_.num_micro_batches();
int64_t max_slot_nums = exe_desc_.num_slots();
......@@ -117,8 +112,10 @@ RuntimeGraph::RuntimeGraph(const ProgramDesc& program,
intercepter_id_to_rank_.insert({task_id, cur_rank});
intercepter_id_to_node_.insert({task_id, task_nodes_[0].get()});
} else {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"Strategy %s is None of 1F1B or Origin.", exe_desc.strategy()));
SplitProgramBasedFunctionality(program);
AssignTaskToIntercepter();
FakeDependence();
FakeRuntimeInfo();
}
}
......
......@@ -682,8 +682,6 @@ class Executor(object):
self._enable_interpreter_core = _is_enable_standalone_executor()
self._executor_cache = _ExecutorCache(self.place)
self._fleet_executor_cache = None
def _get_scope_cache(self, program_cache_key):
return self.scope_caches.get(program_cache_key, None)
......@@ -1274,9 +1272,7 @@ class Executor(object):
if isinstance(program, Program) and program._pipeline_opt:
if "fleet_opt" in program._pipeline_opt:
return self._run_using_fleet_executor(
program,
fetch_list=fetch_list,
use_program_cache=use_program_cache)
program=program, feed=feed, fetch_list=fetch_list)
if "startup_program" in program._pipeline_opt:
program = program._pipeline_opt["startup_program"]
else:
......@@ -1950,64 +1946,72 @@ class Executor(object):
return ctx
def _prepare_fleet_executor(self, program=None, scope=None, fleet_opt=None):
from ..distributed.fleet.proto import fleet_executor_desc_pb2
from google.protobuf import text_format
assert program, "Program for fleet executor should not be None"
assert fleet_opt, "Configurations for fleet executor should not be None"
trainer_endpoints_str = os.getenv("PADDLE_TRAINER_ENDPOINTS", "")
trainer_endpoints = trainer_endpoints_str.split(',')
fleet_exe_desc = fleet_executor_desc_pb2.FleetExecutorDesc()
fleet_exe_desc.cur_rank = os.getenv("PADDLE_TRAINER_ID", 0)
nrank = len(trainer_endpoints)
for rank, endpoint in enumerate(trainer_endpoints):
rank_info = fleet_executor_desc_pb2.RankInfo()
rank_info.rank = rank
rank_info.ip_port = endpoint
fleet_exe_desc.cluster_info.append(rank_info)
if "dist_strategy" in fleet_opt:
fleet_exe_desc.dp_degree = fleet_opt["dist_strategy"]["dp_degree"]
fleet_exe_desc.mp_degree = fleet_opt["dist_strategy"]["mp_degree"]
fleet_exe_desc.pp_degree = fleet_opt["dist_strategy"]["pp_degree"]
if "num_micro_batches" in fleet_opt:
fleet_exe_desc.num_micro_batches = fleet_opt["num_micro_batches"]
num_of_gpu = fleet_exe_desc.dp_degree * fleet_exe_desc.mp_degree * fleet_exe_desc.pp_degree
assert nrank == num_of_gpu, "The number of rank is not equal to the number of gpu."
fleet_exe = core.FleetExecutor(fleet_exe_desc.SerializeToString())
place = core.Place()
place.set_place(self.place)
fleet_exe.init(program.desc, scope, place)
return fleet_exe
def _run_using_fleet_executor(self,
program=None,
dataset=None,
scope=None,
thread=0,
is_infer=False,
debug=False,
fetch_list=None,
fetch_info=None,
print_period=100,
fetch_handler=None,
use_program_cache=False):
if self._fleet_executor_cache is None:
from ..distributed.fleet.proto import fleet_executor_desc_pb2
from google.protobuf import text_format
cur_rank = os.getenv("PADDLE_TRAINER_ID")
trainer_endpoints_str = os.getenv("PADDLE_TRAINER_ENDPOINTS")
fleet_exe_desc = fleet_executor_desc_pb2.FleetExecutorDesc()
nrank = 1
if cur_rank and trainer_endpoints_str:
fleet_exe_desc.cur_rank = int(cur_rank)
trainer_endpoints = trainer_endpoints_str.split(',')
for rank, endpoint in enumerate(trainer_endpoints):
rank_info = fleet_executor_desc_pb2.RankInfo()
rank_info.rank = rank
rank_info.ip_port = endpoint
fleet_exe_desc.cluster_info.append(rank_info)
nrank = len(trainer_endpoints)
else:
fleet_exe_desc.cur_rank = 0
rank_info = fleet_executor_desc_pb2.RankInfo()
rank_info.rank = 0
rank_info.ip_port = ''
fleet_exe_desc.cluster_info.append(rank_info)
logging.warning(
"Fleet Executor will run on single device only.")
feed=None,
feed_var_name="feed",
fetch_var_name="fetch",
fetch_list=None):
cache_key = _get_strong_program_cache_key(program, feed, fetch_list)
cached_ctx = self._get_ctx_cache(cache_key)
cached_scope = self._get_scope_cache(cache_key)
cached_program = self._get_program_cache(cache_key)
if cached_scope is None:
cached_scope = global_scope()
self._add_scope_cache(cache_key, cached_scope)
if cached_program is None:
real_feed = [] if feed is None else feed
real_program = program
if "section_program" in program._pipeline_opt:
real_program = program._pipeline_opt["section_program"]
cached_program = self._add_feed_fetch_ops(
program=real_program,
feed=real_feed,
fetch_list=fetch_list,
feed_var_name=feed_var_name,
fetch_var_name=fetch_var_name)
self._add_program_cache(cache_key, cached_program)
if cached_ctx is None:
fleet_opt = program._pipeline_opt["fleet_opt"]
if "dist_strategy" in fleet_opt:
fleet_exe_desc.dp_degree = fleet_opt["dist_strategy"][
"dp_degree"]
fleet_exe_desc.mp_degree = fleet_opt["dist_strategy"][
"mp_degree"]
fleet_exe_desc.pp_degree = fleet_opt["dist_strategy"][
"pp_degree"]
if "num_micro_batches" in fleet_opt:
fleet_exe_desc.num_micro_batches = fleet_opt[
"num_micro_batches"]
num_of_gpu = fleet_exe_desc.dp_degree * fleet_exe_desc.mp_degree * fleet_exe_desc.pp_degree
assert nrank == num_of_gpu, "The number of rank is not equal to the number of gpu."
fleet_exe = core.FleetExecutor(fleet_exe_desc.SerializeToString())
place = core.Place()
place.set_place(self.place)
if scope is None:
scope = global_scope()
fleet_exe.init(program._pipeline_opt["section_program"].desc, scope,
place)
self._fleet_executor_cache = fleet_exe
self._fleet_executor_cache.run()
cached_ctx = self._prepare_fleet_executor(
program=cached_program, scope=cached_scope, fleet_opt=fleet_opt)
self._add_ctx_cache(cache_key, cached_ctx)
if feed:
self._feed_data(cached_program, feed, feed_var_name, cached_scope)
cached_ctx.run()
if fetch_list:
arr = cached_scope.find_var(fetch_var_name).get_fetch_list()
tensors = arr._move_to_list()
return as_numpy(tensors)
return None
def _run_pipeline(self,
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
......@@ -20,20 +21,53 @@ paddle.enable_static()
class TestFleetExecutor(unittest.TestCase):
def run_fleet_executor(self, place):
def fake_fleet_opt(self):
# TODO: Fake for coverage will be removed in the future
import paddle.distributed.fleet as fleet
strategy = fleet.DistributedStrategy()
strategy.sharding_configs = {
"dp_degree": 1,
"mp_degree": 1,
"pp_degree": 1
}
strategy.pipeline_configs = {"accumulate_steps": 1}
fleet_opt = {
"dist_strategy": strategy.sharding_configs,
"num_micro_batches": strategy.pipeline_configs["accumulate_steps"]
}
return fleet_opt
def run_fleet_executor(self, place, x_data, y_data):
exe = paddle.static.Executor(place)
empty_program = paddle.static.Program()
with fluid.program_guard(empty_program, empty_program):
x = fluid.layers.data(name='x', shape=[1], dtype=paddle.float32)
x = fluid.layers.data(
name='x', shape=x_data.shape, dtype=x_data.dtype)
y = fluid.layers.data(
name='y', shape=y_data.shape, dtype=y_data.dtype)
z = x + y
a = 2 * x + 3 * y
# TODO: section_program will be removed in the future
empty_program._pipeline_opt = {
"fleet_opt": {},
"fleet_opt": self.fake_fleet_opt(),
"section_program": empty_program
}
exe.run(empty_program, feed={'x': [1]})
res = exe.run(empty_program,
feed={'x': x_data,
'y': y_data},
fetch_list=[z.name, a.name])
return res
def test_executor_on_single_device(self):
if fluid.is_compiled_with_cuda():
self.run_fleet_executor(fluid.CUDAPlace(0))
shape = (10000, 3462)
x_data = np.random.rand(*shape)
y_data = np.random.rand(*shape)
z_data = x_data + y_data
a_data = 2 * x_data + 3 * y_data
res = self.run_fleet_executor(fluid.CUDAPlace(0), x_data, y_data)
self.assertTrue(np.allclose(res[0], z_data))
self.assertTrue(np.allclose(res[1], a_data))
if __name__ == "__main__":
......
......@@ -49,7 +49,8 @@ class TestFleetExecutor(unittest.TestCase):
"num_micro_batches": strategy.pipeline_configs["accumulate_steps"]
}
if fluid.is_compiled_with_cuda():
self.run_fleet_executor(fluid.CUDAPlace(0), fleet_opt)
# TODO: Distribute test case is not supported for executor can not stop
pass
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册