未验证 提交 a3f3172c 编写于 作者: K kuizhiqing 提交者: GitHub

launch no python script (#44849)

上级 58d8ead2
@@ -57,8 +57,6 @@ class Context(object):
             return True
         legacy_env_list = [
-            'DISTRIBUTED_TRAINER_ENDPOINTS',
-            'PADDLE_ELASTIC_JOB_ID',
             'FLAGS_START_PORT',
         ]
......
@@ -170,7 +170,11 @@ class Controller(ControllerBase):
         raise NotImplementedError
     def _get_entrypoint(self):
-        entrypoint = [sys.executable, "-u", self.ctx.args.training_script]
+        if self.ctx.args.training_script.endswith('.py'):
+            entrypoint = [sys.executable, "-u", self.ctx.args.training_script]
+        else:
+            entrypoint = [self.ctx.args.training_script]
         entrypoint.extend(self.ctx.args.training_script_args)
         return entrypoint
......
@@ -32,8 +32,10 @@ def process_args(ctx):
     argdev = ctx.args.devices
     if argdev:
         for d in argdev.split(','):
-            assert d in ctx.node.device.labels, 'Device not found {}'.format(
-                argdev)
+            if d not in ctx.node.device.labels:
+                ctx.logger.error(
+                    f'Device not found {d} from {argdev} for setting {ctx.node.device.labels}'
+                )
 def collective_compatible(ctx):
@@ -44,7 +46,7 @@ def collective_compatible(ctx):
         ctx.args.nnodes = len(hosts)
         ctx.logger.info(
             'args reset by env PADDLE_TRAINER_ENDPOINTS\n{}'.format(eps))
-    '''
     if 'DISTRIBUTED_TRAINER_ENDPOINTS' in ctx.envs:
         eps = ctx.envs['DISTRIBUTED_TRAINER_ENDPOINTS'].split(',')
         hosts = set([h.split(':')[0] for h in eps])
@@ -52,7 +54,6 @@ def collective_compatible(ctx):
         ctx.args.nnodes = len(hosts)
         ctx.logger.info(
             'args reset by env DISTRIBUTED_TRAINER_ENDPOINTS\n{}'.format(eps))
-    '''
 def rewrite_host_ip(ctx):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册