未验证 提交 2f56b6da 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] fix recompute (#53718)

上级 793f3b93
......@@ -125,12 +125,14 @@ void ProcessGroupCustom::BroadcastUniqueCustomID(
std::vector<phi::ccl::CCLRootId>& ccl_ids) { // NOLINT
if (rank_ == 0) {
for (size_t i = 0; i < ccl_ids.size(); i++) {
auto key = "ProcessGroupCustom/ccl_ids/" + std::to_string(i);
auto key = "ProcessGroupCustom/ccl_ids/" + std::to_string(gid_) + "/" +
std::to_string(i);
store_->set(key, ccl_ids[i]);
}
} else {
for (size_t i = 0; i < ccl_ids.size(); i++) {
auto key = "ProcessGroupCustom/ccl_ids/" + std::to_string(i);
auto key = "ProcessGroupCustom/ccl_ids/" + std::to_string(gid_) + "/" +
std::to_string(i);
ccl_ids[i] = store_->get(key);
}
}
......
......@@ -110,7 +110,10 @@ class _HPRecomputeFunction(PyLayer):
cur_device = paddle.get_device()
assert (
'gpu:' in paddle.get_device() or 'xpu:' in paddle.get_device()
'gpu:' in paddle.get_device()
or 'xpu:' in paddle.get_device()
or cur_device.split(':')[0]
in paddle.device.get_all_custom_device_type()
), "Recompute with RNG is not support current device: {}.".format(
cur_device
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册