未验证 提交 8562668e 编写于 作者: K kuizhiqing 提交者: GitHub

fix device id env (#40844)

上级 1d60e819
......@@ -242,7 +242,8 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra
elastic_group.add_argument(
"--force", type=bool, default=False, help="update np force")
return parser.parse_args()
known_args, _ = parser.parse_known_args()
return known_args
def get_cluster_from_args(args, device_mode, devices_per_proc):
......
......@@ -25,12 +25,13 @@ class Context(object):
def __init__(self, enable_plugin=True):
self.args, self.unknown_args = parse_args()
self.envs = fetch_envs()
self.logger = self.get_logger()
self.set_env_in_args()
self.node = Node()
self.status = Status()
self.set_env_in_args()
self.logger = self.get_logger()
# design for event queue, later
self.events = []
......
......@@ -57,7 +57,7 @@ class Device(object):
else:
self._labels = []
def get_selected_flag_key(self):
def get_selected_device_key(self):
if self._dtype == DeviceType.CPU:
return 'FLAGS_selected_cpus'
if self._dtype == DeviceType.GPU:
......@@ -70,19 +70,15 @@ class Device(object):
return 'FLAGS_selected_mlus'
return 'FLAGS_selected_devices'
def get_selected_flag_label(self, idx):
if idx < len(self._labels):
return self._labels[idx]
def get_selected_devices(self, devices=''):
'''
return the device label/id relative to the visible devices
'''
if not devices:
return [str(x) for x in range(0, len(self._labels))]
else:
return '0'
def selected_flags(self, idx=None):
if idx is None:
return {self.get_selected_flag_key(): ','.join(self._labels)}
else:
return {
self.get_selected_flag_key(): self.get_selected_flag_label(idx)
}
devs = [x.strip() for x in devices.split(',')]
return [str(self._labels.index(d)) for d in devs]
@classmethod
def parse_device(self):
......
......@@ -75,6 +75,9 @@ class CollectiveController(Controller):
job_endpoints = [i['endpoints'] for i in peer_list]
self.pod.reset()
selected_dev_key = self.ctx.node.device.get_selected_device_key()
selected_dev_list = self.ctx.node.device.get_selected_devices(
self.ctx.args.devices)
for i in range(self.pod.replicas):
e = {
"PADDLE_MASTER": collective_master,
......@@ -90,9 +93,9 @@ class CollectiveController(Controller):
"PADDLE_RANK_IN_NODE": str(i),
}
if self.pod.replicas == 1:
e.update(self.ctx.node.device.selected_flags())
e.update({selected_dev_key: selected_dev_list})
else:
e.update(self.ctx.node.device.selected_flags(i))
e.update({selected_dev_key: selected_dev_list[i]})
self.add_container(envs=e, log_tag=i)
return True
......
......@@ -210,6 +210,8 @@ class Controller(ControllerBase):
if self.ctx.args.nproc_per_node:
return int(self.ctx.args.nproc_per_node)
elif self.ctx.args.devices:
return len(self.ctx.args.devices.split(','))
else:
return self.ctx.node.device.count
......
......@@ -29,8 +29,9 @@ def process_args(ctx):
#argdev = ctx.args.gpus or ctx.args.xpus or ctx.args.npus
argdev = ctx.args.devices
if argdev:
ctx.node.device.labels = argdev.split(',')
ctx.logger.debug('Device reset by args {}'.format(argdev))
for d in argdev.split(','):
assert d in ctx.node.device.labels, 'Device not found {}'.format(
argdev)
def collective_compatible(ctx):
......
......@@ -64,7 +64,10 @@ class Collective_Test(unittest.TestCase):
if args:
cmd.extend(args.split(" "))
cmd.extend([pyname])
proc = subprocess.Popen(cmd, env)
env = os.environ.copy()
# virtual devices for testing
env.update({'CUDA_VISIBLE_DEVICES': '0,1,2,3,4,5,6,7'})
proc = subprocess.Popen(cmd, env=env)
return proc
def test_collective_1(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册