Unverified commit a67d3bb7 authored by TaoTao Li, committed by GitHub

[Auto Parallel] Add auto parallel tuner options in launch (#52053)

* add auto parallel tuner options in launch

* add ut for launch in auto_parallel tuner

* fix code format

* fix ci-coverage
Parent 205a4d9a
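
A minimal usage sketch for the new option (not part of this commit): the JSON keys mirror the ones used in the unit test added by this PR (tuner_save_path, tuner_load_path, tuner_run_mode); everything else, including the train.py script name, is an illustrative assumption.

# sketch: write an auto-parallel tuner config and pass it to launch
import json
import subprocess

config = {
    # paths and mode taken from the unit test added in this PR;
    # tuner_run_mode may be 'tuner_only', 'run_only' or 'tuner_and_run'
    "tuner_save_path": "parallel_strategy.pkl",
    "tuner_load_path": "parallel_strategy.pkl",
    "tuner_run_mode": "tuner_and_run",
}
with open("auto_parallel_config.json", "w") as f:
    json.dump(config, f)

# --auto_parallel_config is the new flag; it can also be supplied through the
# PADDLE_AUTO_PARALLEL_CONFIG environment variable (see env_args_mapping below).
subprocess.check_call(
    [
        "python", "-m", "paddle.distributed.launch",
        "--devices", "0,1",
        "--auto_parallel_config", "auto_parallel_config.json",
        "train.py",
    ]
)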
@@ -1076,6 +1076,10 @@ class PaddleCloudRoleMaker(RoleMakerBase):
self._non_distributed = True
self._worker_endpoints = self._worker_endpoints.split(",")
self._trainers_num = len(self._worker_endpoints)
auto_tuner = os.getenv("PADDLE_AUTO_PARALLEL_CONFIG", None)
if auto_tuner is not None:
trainers_num = os.getenv("PADDLE_TRAINERS_NUM", None)
self._trainers_num = int(trainers_num)
self._nodes_num = len({x.split(':')[0] for x in self._worker_endpoints})
self._local_rank = os.getenv("PADDLE_RANK_IN_NODE")
self._local_device_ids = os.getenv("PADDLE_LOCAL_DEVICE_IDS")
......
@@ -37,6 +37,7 @@ env_args_mapping = {
'PADDLE_WITH_GLOO': 'with_gloo',
'PADDLE_START_PORT': 'start_port',
'PADDLE_IPS': 'ips',
"PADDLE_AUTO_PARALLEL_CONFIG": 'auto_parallel_config',
}
@@ -128,6 +129,13 @@ def parse_args():
"--start_port", type=int, default=6070, help="fix port start with"
)
base_group.add_argument(
"--auto_parallel_config",
type=str,
default=None,
help="auto parallel config file absolute path, the file should be json format",
)
base_group.add_argument(
"training_script",
type=str,
......
@@ -13,12 +13,17 @@
# limitations under the License.
import json
import os
from ..context.device import DeviceType
from .controller import ControleMode, Controller
class CollectiveController(Controller):
def __init__(self, ctx):
self._tuner_run_mode = None # 'tuner_only', 'run_only', 'tuner_and_run'
super().__init__(ctx)
@classmethod
def enable(cls, ctx):
# collective is the default mode
@@ -30,6 +35,9 @@ class CollectiveController(Controller):
return False
def build_pod(self):
skip_run = self._build_pod_with_tuner()
if skip_run:
return
if (
self.ctx.args.master is None
and self.ctx.args.start_port
@@ -39,6 +47,46 @@ class CollectiveController(Controller):
else:
return self._build_pod_with_master()
def _build_pod_with_tuner(self):
auto_parallel_config = self.ctx.args.auto_parallel_config
if auto_parallel_config is not None:
if not os.path.exists(auto_parallel_config):
self.ctx.logger.warning("auto_parallel_conf not exists!")
if not auto_parallel_config.endswith(".json"):
self.ctx.logger.warning(
"auto_parallel_config should be a json format file!"
)
with open(auto_parallel_config, 'r') as robj:
auto_parallel_data = json.loads(robj.read())
self._tuner_run_mode = auto_parallel_data.get(
"tuner_run_mode", 'tuner_and_run'
)
self.ctx.logger.info(f"tuner_run_mode is: {self._tuner_run_mode}")
endpoint = f"127.0.0.1:{self.ctx.node.get_free_port()}"
pod_replicas = self.pod_replicas()
if self._tuner_run_mode in ['tuner_only', 'tuner_and_run']:
e = {
"PADDLE_AUTO_PARALLEL_CONFIG": self.ctx.args.auto_parallel_config,
"PADDLE_TRAINERS_NUM": "1",
"PADDLE_TRAINER_ENDPOINTS": endpoint,
"PADDLE_TRAINER_ID": "0",
"PADDLE_CURRENT_ENDPOINT": endpoint,
"FLAGS_selected_gpus": "0",
"PADDLE_AUTO_PARALLEL_STAGE": "tuner",
"PADDLE_GLOBAL_SIZE": "{}".format(
pod_replicas * int(self.ctx.args.nnodes)
),
"PADDLE_LOCAL_SIZE": f"{pod_replicas}",
}
log_file = "tuner.log"
self.add_container(envs=e, log_file=log_file, is_init=True)
if self._tuner_run_mode == 'tuner_only':
return True
return False
def _build_pod_with_args(self):
self.pod.replicas = self.pod_replicas()
@@ -78,6 +126,13 @@ class CollectiveController(Controller):
"PADDLE_TRAINERS_NUM": f"{len(job_endpoints)}",
"PADDLE_RANK_IN_NODE": str(i),
}
if self._tuner_run_mode is not None:
e.update(
{
"PADDLE_AUTO_PARALLEL_CONFIG": self.ctx.args.auto_parallel_config,
"PADDLE_AUTO_PARALLEL_STAGE": "run",
}
)
if len(selected_dev_list) > 0:
if self.ctx.node.device.dtype == DeviceType.CUSTOM_DEVICE:
e.update(self.ctx.node.device.get_custom_device_envs())
@@ -144,7 +199,7 @@ class CollectiveController(Controller):
job_endpoints = [i['endpoints'] for i in peer_list]
self.pod.reset()
# self.pod.reset()
selected_dev_key = self.ctx.node.device.get_selected_device_key()
selected_dev_list = self.ctx.node.device.get_selected_devices(
self.ctx.args.devices
@@ -164,6 +219,13 @@ class CollectiveController(Controller):
"PADDLE_TRAINERS_NUM": f"{global_size}",
"PADDLE_RANK_IN_NODE": str(i),
}
if self._tuner_run_mode is not None:
e.update(
{
"PADDLE_AUTO_PARALLEL_CONFIG": self.ctx.args.auto_parallel_config,
"PADDLE_AUTO_PARALLEL_STAGE": "run",
}
)
if len(selected_dev_list) > 0:
if self.ctx.node.device.dtype == DeviceType.CUSTOM_DEVICE:
e.update(self.ctx.node.device.get_custom_device_envs())
......
@@ -55,10 +55,15 @@ class ControllerBase:
def deploy_pod(self):
assert len(self.pod.containers) > 0, "No container in the pod"
assert (
len(self.pod.containers) + len(self.pod.init_containers) > 0
), "No container in the pod"
self.ctx.logger.info(f"Run {self.pod}")
self.ctx.logger.debug(self.pod.containers[0])
if len(self.pod.init_containers) > 0:
self.ctx.logger.debug(self.pod.init_containers[0])
if len(self.pod.containers) > 0:
self.ctx.logger.debug(self.pod.containers[0])
self.ctx.status.run()
self.pod.deploy()
......
@@ -109,7 +109,6 @@ class Pod(PodSepc):
for i in self._init_containers:
i.start()
i.wait(self._init_timeout)
for c in self._containers:
c.start()
@@ -173,7 +172,10 @@ class Pod(PodSepc):
def logs(self, idx=None):
if idx is None:
self._containers[0].logs()
if len(self._containers) > 0:
self._containers[0].logs()
if len(self._init_containers) > 0:
self._init_containers[0].logs()
else:
self._containers[idx].logs()
@@ -196,11 +198,11 @@ class Pod(PodSepc):
'''
end = time.time() + timeout
while timeout < 0 or time.time() < end:
for c in self._containers:
for c in self._init_containers + self._containers:
if c.status in any_list:
return c.status
s = [c.status for c in self._containers]
s = [c.status for c in self._init_containers + self._containers]
if len(set(s)) == 1 and s[0] in all_list:
return s[0]
......
@@ -25,11 +25,12 @@ pyname = 'train.py'
colpyfile = '''# train.py for unitest
import os
env = os.environ.copy()
assert "PADDLE_MASTER" in env
if "PADDLE_AUTO_PARALLEL_CONFIG" not in env:
assert "PADDLE_MASTER" in env
assert "PADDLE_GLOBAL_RANK" in env
assert "PADDLE_LOCAL_RANK" in env
assert "PADDLE_GLOBAL_SIZE" in env
assert "PADDLE_LOCAL_SIZE" in env
assert "PADDLE_GLOBAL_RANK" in env
assert "PADDLE_LOCAL_RANK" in env
'''
pspyfile = '''# train.py for unitest
@@ -114,6 +115,26 @@ class Collective_Test(unittest.TestCase):
self.assertTrue(len(c2) == 3)
log_dir.cleanup()
def test_collective_4(self):
log_dir = tempfile.TemporaryDirectory()
config_dir = tempfile.TemporaryDirectory()
config_path = os.path.join(config_dir.name, 'auto_parallel_config.json')
with open(config_path, 'w') as wobj:
wobj.write(
'{\"tuner_save_path\":\"parallel_strategy.pkl\",\"tuner_load_path\":\"parallel_strategy.pkl\",\"tuner_run_mode\":\"tuner_and_run\"}'
)
port = random.randrange(6000, 8000)
args = "--job_id test4 --devices 0,1 --log_dir {} --auto_parallel_config {}"
p1 = self.pdrun(args.format(log_dir.name + "/1", config_path))
p1.wait()
self.assertTrue(p1.poll() == 0)
c1 = get_files(log_dir.name + "/1", 'test4')
print(c1)
self.assertTrue(len(c1) == 4)
log_dir.cleanup()
config_dir.cleanup()
class PS_Test(unittest.TestCase):
def setUp(self):
......