Unverified · Commit a67d3bb7 authored by TaoTao Li, committed by GitHub

[Auto Parallel] Add auto parallel tuner options in launch (#52053)

* add auto parallel tuner options in launch

* add ut for launch in auto_parallel tuner

fix code format

* fix ci-coverage
Parent 205a4d9a
@@ -1076,6 +1076,10 @@ class PaddleCloudRoleMaker(RoleMakerBase):
             self._non_distributed = True
         self._worker_endpoints = self._worker_endpoints.split(",")
         self._trainers_num = len(self._worker_endpoints)
+        auto_tuner = os.getenv("PADDLE_AUTO_PARALLEL_CONFIG", None)
+        if auto_tuner is not None:
+            trainers_num = os.getenv("PADDLE_TRAINERS_NUM", None)
+            self._trainers_num = int(trainers_num)
         self._nodes_num = len({x.split(':')[0] for x in self._worker_endpoints})
         self._local_rank = os.getenv("PADDLE_RANK_IN_NODE")
         self._local_device_ids = os.getenv("PADDLE_LOCAL_DEVICE_IDS")
......
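A minimal standalone sketch of the override above (illustration only; the helper name is hypothetical and this is not the PaddleCloudRoleMaker code): when the launcher exports PADDLE_AUTO_PARALLEL_CONFIG, the trainer count is read from PADDLE_TRAINERS_NUM instead of being derived from the endpoint list.

import os

def resolve_trainers_num(worker_endpoints: str) -> int:
    # Hypothetical helper mirroring the role-maker change above.
    endpoints = worker_endpoints.split(",")
    trainers_num = len(endpoints)
    if os.getenv("PADDLE_AUTO_PARALLEL_CONFIG") is not None:
        # Under the auto parallel tuner, the launcher-provided count takes precedence.
        trainers_num = int(os.getenv("PADDLE_TRAINERS_NUM", str(trainers_num)))
    return trainers_num

os.environ["PADDLE_AUTO_PARALLEL_CONFIG"] = "/tmp/auto_parallel_config.json"
os.environ["PADDLE_TRAINERS_NUM"] = "8"
print(resolve_trainers_num("127.0.0.1:6070"))  # prints 8 rather than 1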
@@ -37,6 +37,7 @@ env_args_mapping = {
     'PADDLE_WITH_GLOO': 'with_gloo',
     'PADDLE_START_PORT': 'start_port',
     'PADDLE_IPS': 'ips',
+    "PADDLE_AUTO_PARALLEL_CONFIG": 'auto_parallel_config',
 }
@@ -128,6 +129,13 @@ def parse_args():
         "--start_port", type=int, default=6070, help="fix port start with"
     )
+    base_group.add_argument(
+        "--auto_parallel_config",
+        type=str,
+        default=None,
+        help="auto parallel config file absolute path, the file should be json format",
+    )
     base_group.add_argument(
         "training_script",
         type=str,
......
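A hedged usage sketch for the new option (the config and script file names are assumptions, not part of this PR): write a JSON config with the keys exercised by the unit test further below and point --auto_parallel_config at it; exporting PADDLE_AUTO_PARALLEL_CONFIG to the same path has the same effect through the env_args_mapping entry above.

import json
import subprocess

# tuner_run_mode may be 'tuner_only', 'run_only', or 'tuner_and_run' (the default).
config = {
    "tuner_save_path": "parallel_strategy.pkl",
    "tuner_load_path": "parallel_strategy.pkl",
    "tuner_run_mode": "tuner_and_run",
}
with open("auto_parallel_config.json", "w") as f:
    json.dump(config, f)

# Roughly: python -m paddle.distributed.launch --devices 0,1 \
#          --auto_parallel_config auto_parallel_config.json train.py
subprocess.run(
    [
        "python", "-m", "paddle.distributed.launch",
        "--devices", "0,1",
        "--auto_parallel_config", "auto_parallel_config.json",
        "train.py",
    ],
    check=True,
)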
@@ -13,12 +13,17 @@
 # limitations under the License.
 import json
+import os
 from ..context.device import DeviceType
 from .controller import ControleMode, Controller

 class CollectiveController(Controller):
+    def __init__(self, ctx):
+        self._tuner_run_mode = None  # 'tuner_only', 'run_only', 'tuner_and_run'
+        super().__init__(ctx)
+
     @classmethod
     def enable(cls, ctx):
         # collective is the default mode
@@ -30,6 +35,9 @@ class CollectiveController(Controller):
             return False

     def build_pod(self):
+        skip_run = self._build_pod_with_tuner()
+        if skip_run:
+            return
         if (
             self.ctx.args.master is None
             and self.ctx.args.start_port
@@ -39,6 +47,46 @@ class CollectiveController(Controller):
         else:
             return self._build_pod_with_master()

+    def _build_pod_with_tuner(self):
+        auto_parallel_config = self.ctx.args.auto_parallel_config
+        if auto_parallel_config is not None:
+            if not os.path.exists(auto_parallel_config):
+                self.ctx.logger.warning("auto_parallel_conf not exists!")
+            if not auto_parallel_config.endswith(".json"):
+                self.ctx.logger.warning(
+                    "auto_parallel_config should be a json format file!"
+                )
+
+            with open(auto_parallel_config, 'r') as robj:
+                auto_parallel_data = json.loads(robj.read())
+                self._tuner_run_mode = auto_parallel_data.get(
+                    "tuner_run_mode", 'tuner_and_run'
+                )
+
+            self.ctx.logger.info(f"tuner_run_mode is: {self._tuner_run_mode}")
+            endpoint = f"127.0.0.1:{self.ctx.node.get_free_port()}"
+            pod_replicas = self.pod_replicas()
+            if self._tuner_run_mode in ['tuner_only', 'tuner_and_run']:
+                e = {
+                    "PADDLE_AUTO_PARALLEL_CONFIG": self.ctx.args.auto_parallel_config,
+                    "PADDLE_TRAINERS_NUM": "1",
+                    "PADDLE_TRAINER_ENDPOINTS": endpoint,
+                    "PADDLE_TRAINER_ID": "0",
+                    "PADDLE_CURRENT_ENDPOINT": endpoint,
+                    "FLAGS_selected_gpus": "0",
+                    "PADDLE_AUTO_PARALLEL_STAGE": "tuner",
+                    "PADDLE_GLOBAL_SIZE": "{}".format(
+                        pod_replicas * int(self.ctx.args.nnodes)
+                    ),
+                    "PADDLE_LOCAL_SIZE": f"{pod_replicas}",
+                }
+                log_file = "tuner.log"
+                self.add_container(envs=e, log_file=log_file, is_init=True)
+
+                if self._tuner_run_mode == 'tuner_only':
+                    return True
+        return False
+
     def _build_pod_with_args(self):
         self.pod.replicas = self.pod_replicas()
@@ -78,6 +126,13 @@ class CollectiveController(Controller):
                 "PADDLE_TRAINERS_NUM": f"{len(job_endpoints)}",
                 "PADDLE_RANK_IN_NODE": str(i),
             }
+            if self._tuner_run_mode is not None:
+                e.update(
+                    {
+                        "PADDLE_AUTO_PARALLEL_CONFIG": self.ctx.args.auto_parallel_config,
+                        "PADDLE_AUTO_PARALLEL_STAGE": "run",
+                    }
+                )
             if len(selected_dev_list) > 0:
                 if self.ctx.node.device.dtype == DeviceType.CUSTOM_DEVICE:
                     e.update(self.ctx.node.device.get_custom_device_envs())
@@ -144,7 +199,7 @@ class CollectiveController(Controller):
         job_endpoints = [i['endpoints'] for i in peer_list]

-        self.pod.reset()
+        # self.pod.reset()
         selected_dev_key = self.ctx.node.device.get_selected_device_key()
         selected_dev_list = self.ctx.node.device.get_selected_devices(
             self.ctx.args.devices
@@ -164,6 +219,13 @@ class CollectiveController(Controller):
                 "PADDLE_TRAINERS_NUM": f"{global_size}",
                 "PADDLE_RANK_IN_NODE": str(i),
             }
+            if self._tuner_run_mode is not None:
+                e.update(
+                    {
+                        "PADDLE_AUTO_PARALLEL_CONFIG": self.ctx.args.auto_parallel_config,
+                        "PADDLE_AUTO_PARALLEL_STAGE": "run",
+                    }
+                )
             if len(selected_dev_list) > 0:
                 if self.ctx.node.device.dtype == DeviceType.CUSTOM_DEVICE:
                     e.update(self.ctx.node.device.get_custom_device_envs())
......
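Given the environment set up by _build_pod_with_tuner above (PADDLE_AUTO_PARALLEL_STAGE is "tuner" for the single init container and "run" for the regular worker containers), a training script could branch on the stage. The following is a hypothetical script-side sketch, not code from this PR.

import os

stage = os.getenv("PADDLE_AUTO_PARALLEL_STAGE", "run")
config_path = os.getenv("PADDLE_AUTO_PARALLEL_CONFIG")  # set from --auto_parallel_config

if stage == "tuner":
    # Tuner stage: a single process, but PADDLE_GLOBAL_SIZE / PADDLE_LOCAL_SIZE
    # still describe the whole job (pod_replicas * nnodes and pod_replicas).
    global_size = int(os.getenv("PADDLE_GLOBAL_SIZE", "1"))
    local_size = int(os.getenv("PADDLE_LOCAL_SIZE", "1"))
    print(f"tuning parallel strategy for {global_size} ranks "
          f"({local_size} per node), config={config_path}")
else:
    print("running training with the tuned strategy")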
@@ -55,10 +55,15 @@ class ControllerBase:
     def deploy_pod(self):

-        assert len(self.pod.containers) > 0, "No container in the pod"
+        assert (
+            len(self.pod.containers) + len(self.pod.init_containers) > 0
+        ), "No container in the pod"

         self.ctx.logger.info(f"Run {self.pod}")
-        self.ctx.logger.debug(self.pod.containers[0])
+        if len(self.pod.init_containers) > 0:
+            self.ctx.logger.debug(self.pod.init_containers[0])
+        if len(self.pod.containers) > 0:
+            self.ctx.logger.debug(self.pod.containers[0])

         self.ctx.status.run()
         self.pod.deploy()
......
@@ -109,7 +109,6 @@ class Pod(PodSepc):
         for i in self._init_containers:
             i.start()
             i.wait(self._init_timeout)
-
         for c in self._containers:
             c.start()
@@ -173,7 +172,10 @@ class Pod(PodSepc):
     def logs(self, idx=None):
         if idx is None:
-            self._containers[0].logs()
+            if len(self._containers) > 0:
+                self._containers[0].logs()
+            if len(self._init_containers) > 0:
+                self._init_containers[0].logs()
         else:
             self._containers[idx].logs()
@@ -196,11 +198,11 @@ class Pod(PodSepc):
         '''
         end = time.time() + timeout
         while timeout < 0 or time.time() < end:
-            for c in self._containers:
+            for c in self._init_containers + self._containers:
                 if c.status in any_list:
                     return c.status

-            s = [c.status for c in self._containers]
+            s = [c.status for c in self._init_containers + self._containers]
             if len(set(s)) == 1 and s[0] in all_list:
                 return s[0]
......
@@ -25,11 +25,12 @@ pyname = 'train.py'
 colpyfile = '''# train.py for unitest
 import os
 env = os.environ.copy()
-assert "PADDLE_MASTER" in env
+if "PADDLE_AUTO_PARALLEL_CONFIG" not in env:
+    assert "PADDLE_MASTER" in env
+    assert "PADDLE_GLOBAL_RANK" in env
+    assert "PADDLE_LOCAL_RANK" in env
 assert "PADDLE_GLOBAL_SIZE" in env
 assert "PADDLE_LOCAL_SIZE" in env
-assert "PADDLE_GLOBAL_RANK" in env
-assert "PADDLE_LOCAL_RANK" in env
 '''

 pspyfile = '''# train.py for unitest
@@ -114,6 +115,26 @@ class Collective_Test(unittest.TestCase):
         self.assertTrue(len(c2) == 3)
         log_dir.cleanup()

+    def test_collective_4(self):
+        log_dir = tempfile.TemporaryDirectory()
+        config_dir = tempfile.TemporaryDirectory()
+        config_path = os.path.join(config_dir.name, 'auto_parallel_config.json')
+        with open(config_path, 'w') as wobj:
+            wobj.write(
+                '{\"tuner_save_path\":\"parallel_strategy.pkl\",\"tuner_load_path\":\"parallel_strategy.pkl\",\"tuner_run_mode\":\"tuner_and_run\"}'
+            )
+        port = random.randrange(6000, 8000)
+        args = "--job_id test4 --devices 0,1 --log_dir {} --auto_parallel_config {}"
+        p1 = self.pdrun(args.format(log_dir.name + "/1", config_path))
+        p1.wait()
+        self.assertTrue(p1.poll() == 0)
+
+        c1 = get_files(log_dir.name + "/1", 'test4')
+        print(c1)
+        self.assertTrue(len(c1) == 4)
+        log_dir.cleanup()
+        config_dir.cleanup()
+

 class PS_Test(unittest.TestCase):
     def setUp(self):
......