提交 49ac67bc 编写于 作者: S sandyhouse

update

上级 75bee264
...@@ -1454,7 +1454,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -1454,7 +1454,7 @@ All parameter, weight, gradient are variables in Paddle.
"number on your machine is %d", "number on your machine is %d",
dev_id, platform::GetCUDADeviceCount(), dev_id, platform::GetCUDADeviceCount(),
platform::GetCUDADeviceCount()); platform::GetCUDADeviceCount());
std::exit(-1); // std::exit(-1);
} }
} }
......
...@@ -73,9 +73,30 @@ class CollectiveHelper(object): ...@@ -73,9 +73,30 @@ class CollectiveHelper(object):
other_endpoints.remove(current_endpoint) other_endpoints.remove(current_endpoint)
block = program.global_block() block = program.global_block()
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
if rank == 0 and wait_port: if not wait_port:
wait_server_ready(other_endpoints) temp_var = block.create_var(
nccl_id_var = block.create_var( name=unique_name.generate('temp_var'),
dtype=core.VarDesc.VarType.INT32,
persistable=False,
stop_gradient=True)
block.append_op(
type='fill_constant',
inputs={},
outputs={'Out': [temp_var]},
attrs={
'shape': [1],
'dtype': temp_var.dtype,
'value': 1,
'force_cpu': False,
OP_ROLE_KEY: OpRole.Forward
})
block.append_op(
type='c_allreduce_sum',
inputs={'X': [temp_var]},
outputs={'Out': [temp_var]},
attrs={'ring_id': 3,
OP_ROLE_KEY: OpRole.Forward})
comm_id_var = block.create_var(
name=unique_name.generate('nccl_id'), name=unique_name.generate('nccl_id'),
persistable=True, persistable=True,
type=core.VarDesc.VarType.RAW) type=core.VarDesc.VarType.RAW)
...@@ -100,9 +121,7 @@ class CollectiveHelper(object): ...@@ -100,9 +121,7 @@ class CollectiveHelper(object):
OP_ROLE_KEY: OpRole.Forward OP_ROLE_KEY: OpRole.Forward
}) })
elif core.is_compiled_with_npu(): elif core.is_compiled_with_npu():
endpoint_to_index_map = { endpoint_to_index_map = {e: idx for idx, e in enumerate(endpoints)}
e: idx for idx, e in enumerate(endpoints)
}
block.append_op( block.append_op(
type='c_comm_init_hcom', type='c_comm_init_hcom',
inputs={}, inputs={},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册