未验证 提交 2f56b6da 作者: ronnywang 提交者: GitHub

[CustomDevice] fix recompute (#53718)

上级 793f3b93
@@ -125,12 +125,14 @@ void ProcessGroupCustom::BroadcastUniqueCustomID(
std::vector<phi::ccl::CCLRootId>& ccl_ids) { // NOLINT std::vector<phi::ccl::CCLRootId>& ccl_ids) { // NOLINT
if (rank_ == 0) { if (rank_ == 0) {
for (size_t i = 0; i < ccl_ids.size(); i++) { for (size_t i = 0; i < ccl_ids.size(); i++) {
auto key = "ProcessGroupCustom/ccl_ids/" + std::to_string(i); auto key = "ProcessGroupCustom/ccl_ids/" + std::to_string(gid_) + "/" +
std::to_string(i);
store_->set(key, ccl_ids[i]); store_->set(key, ccl_ids[i]);
} }
} else { } else {
for (size_t i = 0; i < ccl_ids.size(); i++) { for (size_t i = 0; i < ccl_ids.size(); i++) {
auto key = "ProcessGroupCustom/ccl_ids/" + std::to_string(i); auto key = "ProcessGroupCustom/ccl_ids/" + std::to_string(gid_) + "/" +
std::to_string(i);
ccl_ids[i] = store_->get(key); ccl_ids[i] = store_->get(key);
} }
} }
......
@@ -110,7 +110,10 @@ class _HPRecomputeFunction(PyLayer):
cur_device = paddle.get_device() cur_device = paddle.get_device()
assert ( assert (
'gpu:' in paddle.get_device() or 'xpu:' in paddle.get_device() 'gpu:' in paddle.get_device()
or 'xpu:' in paddle.get_device()
or cur_device.split(':')[0]
in paddle.device.get_all_custom_device_type()
), "Recompute with RNG is not support current device: {}.".format( ), "Recompute with RNG is not support current device: {}.".format(
cur_device cur_device
) )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册