未验证 提交 8562668e 编写于 作者: K kuizhiqing 提交者: GitHub

fix device id env (#40844)

上级 1d60e819
...@@ -242,7 +242,8 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra ...@@ -242,7 +242,8 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra
elastic_group.add_argument( elastic_group.add_argument(
"--force", type=bool, default=False, help="update np force") "--force", type=bool, default=False, help="update np force")
return parser.parse_args() known_args, _ = parser.parse_known_args()
return known_args
def get_cluster_from_args(args, device_mode, devices_per_proc): def get_cluster_from_args(args, device_mode, devices_per_proc):
......
...@@ -25,12 +25,13 @@ class Context(object): ...@@ -25,12 +25,13 @@ class Context(object):
def __init__(self, enable_plugin=True): def __init__(self, enable_plugin=True):
self.args, self.unknown_args = parse_args() self.args, self.unknown_args = parse_args()
self.envs = fetch_envs() self.envs = fetch_envs()
self.logger = self.get_logger()
self.set_env_in_args()
self.node = Node() self.node = Node()
self.status = Status() self.status = Status()
self.set_env_in_args() self.logger = self.get_logger()
# design for event queue, later # design for event queue, later
self.events = [] self.events = []
......
...@@ -57,7 +57,7 @@ class Device(object): ...@@ -57,7 +57,7 @@ class Device(object):
else: else:
self._labels = [] self._labels = []
def get_selected_flag_key(self): def get_selected_device_key(self):
if self._dtype == DeviceType.CPU: if self._dtype == DeviceType.CPU:
return 'FLAGS_selected_cpus' return 'FLAGS_selected_cpus'
if self._dtype == DeviceType.GPU: if self._dtype == DeviceType.GPU:
...@@ -70,19 +70,15 @@ class Device(object): ...@@ -70,19 +70,15 @@ class Device(object):
return 'FLAGS_selected_mlus' return 'FLAGS_selected_mlus'
return 'FLAGS_selected_devices' return 'FLAGS_selected_devices'
def get_selected_flag_label(self, idx): def get_selected_devices(self, devices=''):
if idx < len(self._labels): '''
return self._labels[idx] return the device label/id relative to the visible devices
'''
if not devices:
return [str(x) for x in range(0, len(self._labels))]
else: else:
return '0' devs = [x.strip() for x in devices.split(',')]
return [str(self._labels.index(d)) for d in devs]
def selected_flags(self, idx=None):
if idx is None:
return {self.get_selected_flag_key(): ','.join(self._labels)}
else:
return {
self.get_selected_flag_key(): self.get_selected_flag_label(idx)
}
@classmethod @classmethod
def parse_device(self): def parse_device(self):
......
...@@ -75,6 +75,9 @@ class CollectiveController(Controller): ...@@ -75,6 +75,9 @@ class CollectiveController(Controller):
job_endpoints = [i['endpoints'] for i in peer_list] job_endpoints = [i['endpoints'] for i in peer_list]
self.pod.reset() self.pod.reset()
selected_dev_key = self.ctx.node.device.get_selected_device_key()
selected_dev_list = self.ctx.node.device.get_selected_devices(
self.ctx.args.devices)
for i in range(self.pod.replicas): for i in range(self.pod.replicas):
e = { e = {
"PADDLE_MASTER": collective_master, "PADDLE_MASTER": collective_master,
...@@ -90,9 +93,9 @@ class CollectiveController(Controller): ...@@ -90,9 +93,9 @@ class CollectiveController(Controller):
"PADDLE_RANK_IN_NODE": str(i), "PADDLE_RANK_IN_NODE": str(i),
} }
if self.pod.replicas == 1: if self.pod.replicas == 1:
e.update(self.ctx.node.device.selected_flags()) e.update({selected_dev_key: selected_dev_list})
else: else:
e.update(self.ctx.node.device.selected_flags(i)) e.update({selected_dev_key: selected_dev_list[i]})
self.add_container(envs=e, log_tag=i) self.add_container(envs=e, log_tag=i)
return True return True
......
...@@ -210,6 +210,8 @@ class Controller(ControllerBase): ...@@ -210,6 +210,8 @@ class Controller(ControllerBase):
if self.ctx.args.nproc_per_node: if self.ctx.args.nproc_per_node:
return int(self.ctx.args.nproc_per_node) return int(self.ctx.args.nproc_per_node)
elif self.ctx.args.devices:
return len(self.ctx.args.devices.split(','))
else: else:
return self.ctx.node.device.count return self.ctx.node.device.count
......
...@@ -29,8 +29,9 @@ def process_args(ctx): ...@@ -29,8 +29,9 @@ def process_args(ctx):
#argdev = ctx.args.gpus or ctx.args.xpus or ctx.args.npus #argdev = ctx.args.gpus or ctx.args.xpus or ctx.args.npus
argdev = ctx.args.devices argdev = ctx.args.devices
if argdev: if argdev:
ctx.node.device.labels = argdev.split(',') for d in argdev.split(','):
ctx.logger.debug('Device reset by args {}'.format(argdev)) assert d in ctx.node.device.labels, 'Device not found {}'.format(
argdev)
def collective_compatible(ctx): def collective_compatible(ctx):
......
...@@ -64,7 +64,10 @@ class Collective_Test(unittest.TestCase): ...@@ -64,7 +64,10 @@ class Collective_Test(unittest.TestCase):
if args: if args:
cmd.extend(args.split(" ")) cmd.extend(args.split(" "))
cmd.extend([pyname]) cmd.extend([pyname])
proc = subprocess.Popen(cmd, env) env = os.environ.copy()
# virtual devices for testing
env.update({'CUDA_VISIBLE_DEVICES': '0,1,2,3,4,5,6,7'})
proc = subprocess.Popen(cmd, env=env)
return proc return proc
def test_collective_1(self): def test_collective_1(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册