未验证 提交 a3f3172c 编写于 作者: K kuizhiqing 提交者: GitHub

launch no python script (#44849)

上级 58d8ead2
......@@ -57,8 +57,6 @@ class Context(object):
return True
legacy_env_list = [
'DISTRIBUTED_TRAINER_ENDPOINTS',
'PADDLE_ELASTIC_JOB_ID',
'FLAGS_START_PORT',
]
......
......@@ -170,7 +170,11 @@ class Controller(ControllerBase):
raise NotImplementedError
def _get_entrypoint(self):
if self.ctx.args.training_script.endswith('.py'):
entrypoint = [sys.executable, "-u", self.ctx.args.training_script]
else:
entrypoint = [self.ctx.args.training_script]
entrypoint.extend(self.ctx.args.training_script_args)
return entrypoint
......
......@@ -32,8 +32,10 @@ def process_args(ctx):
argdev = ctx.args.devices
if argdev:
for d in argdev.split(','):
assert d in ctx.node.device.labels, 'Device not found {}'.format(
argdev)
if d not in ctx.node.device.labels:
ctx.logger.error(
f'Device not found {d} from {argdev} for setting {ctx.node.device.labels}'
)
def collective_compatible(ctx):
......@@ -44,7 +46,7 @@ def collective_compatible(ctx):
ctx.args.nnodes = len(hosts)
ctx.logger.info(
'args reset by env PADDLE_TRAINER_ENDPOINTS\n{}'.format(eps))
'''
if 'DISTRIBUTED_TRAINER_ENDPOINTS' in ctx.envs:
eps = ctx.envs['DISTRIBUTED_TRAINER_ENDPOINTS'].split(',')
hosts = set([h.split(':')[0] for h in eps])
......@@ -52,7 +54,6 @@ def collective_compatible(ctx):
ctx.args.nnodes = len(hosts)
ctx.logger.info(
'args reset by env DISTRIBUTED_TRAINER_ENDPOINTS\n{}'.format(eps))
'''
def rewrite_host_ip(ctx):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册