未验证 提交 5c6ae96b 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] Fix device occupancy (#56556)

上级 8b17207c
......@@ -946,9 +946,7 @@ static void RegisterOperatorKernel(
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto device_types = phi::DeviceManager::GetAllCustomDeviceTypes();
for (const auto& dev_type : device_types) {
for (size_t dev_id = 0;
dev_id < phi::DeviceManager::GetDeviceCount(dev_type);
dev_id++) {
for (auto& dev_id : phi::DeviceManager::GetSelectedDeviceList(dev_type)) {
RegisterOperatorKernelWithPlace(name,
op_kernel_func,
proto::VarType::RAW,
......
......@@ -206,9 +206,8 @@ class AllocatorFacadePrivate {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto device_types = phi::DeviceManager::GetAllCustomDeviceTypes();
for (const auto& dev_type : device_types) {
for (size_t dev_id = 0;
dev_id < phi::DeviceManager::GetDeviceCount(dev_type);
++dev_id) {
for (auto& dev_id :
phi::DeviceManager::GetSelectedDeviceList(dev_type)) {
InitNaiveBestFitCustomDeviceAllocator(
platform::CustomPlace(dev_type, dev_id));
}
......@@ -272,9 +271,8 @@ class AllocatorFacadePrivate {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto device_types = phi::DeviceManager::GetAllCustomDeviceTypes();
for (const auto& dev_type : device_types) {
for (size_t dev_id = 0;
dev_id < phi::DeviceManager::GetDeviceCount(dev_type);
++dev_id) {
for (auto& dev_id :
phi::DeviceManager::GetSelectedDeviceList(dev_type)) {
InitAutoGrowthCustomDeviceAllocator(
platform::CustomPlace(dev_type, dev_id), allow_free_idle_chunk);
}
......@@ -1212,9 +1210,7 @@ class AllocatorFacadePrivate {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto device_types = phi::DeviceManager::GetAllCustomDeviceTypes();
for (const auto& dev_type : device_types) {
for (size_t dev_id = 0;
dev_id < phi::DeviceManager::GetDeviceCount(dev_type);
dev_id++) {
for (auto& dev_id : phi::DeviceManager::GetSelectedDeviceList(dev_type)) {
platform::CustomPlace p(dev_type, dev_id);
system_allocators_[p] = std::make_shared<NaiveBestFitAllocator>(p);
}
......@@ -1248,9 +1244,7 @@ class AllocatorFacadePrivate {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto device_types = phi::DeviceManager::GetAllCustomDeviceTypes();
for (const auto& dev_type : device_types) {
for (size_t dev_id = 0;
dev_id < phi::DeviceManager::GetDeviceCount(dev_type);
dev_id++) {
for (auto& dev_id : phi::DeviceManager::GetSelectedDeviceList(dev_type)) {
places.emplace_back(platform::CustomPlace(dev_type, dev_id));
}
}
......
......@@ -27,9 +27,10 @@ CustomDeviceStreamResourcePool::CustomDeviceStreamResourcePool(
platform::errors::PreconditionNotMet(
"Required device shall be CustomPlace, but received %d. ", place));
int dev_cnt = phi::DeviceManager::GetDeviceCount(place.GetDeviceType());
pool_.reserve(dev_cnt);
for (int dev_idx = 0; dev_idx < dev_cnt; ++dev_idx) {
auto selected_devices =
phi::DeviceManager::GetSelectedDeviceList(place.GetDeviceType());
pool_.reserve(selected_devices.size());
for (auto& dev_idx : selected_devices) {
auto creator = [place, dev_idx, this] {
auto place_ = phi::CustomPlace(place.GetDeviceType(), dev_idx);
phi::DeviceManager::SetDevice(place_);
......@@ -82,12 +83,11 @@ CustomDeviceStreamResourcePool& CustomDeviceStreamResourcePool::Instance(
if (pool.find(place.GetDeviceType()) == pool.end()) {
pool.insert({place.GetDeviceType(),
std::vector<CustomDeviceStreamResourcePool*>()});
for (size_t i = 0;
i < phi::DeviceManager::GetDeviceCount(place.GetDeviceType());
++i) {
for (auto& dev_id :
phi::DeviceManager::GetSelectedDeviceList(place.GetDeviceType())) {
pool[place.GetDeviceType()].emplace_back(
new CustomDeviceStreamResourcePool(
paddle::platform::CustomPlace(place.GetDeviceType(), i)));
paddle::platform::CustomPlace(place.GetDeviceType(), dev_id)));
}
}
PADDLE_ENFORCE_LT(
......@@ -125,9 +125,10 @@ CustomDeviceEventResourcePool::CustomDeviceEventResourcePool(
platform::errors::PreconditionNotMet(
"Required device shall be CustomPlace, but received %d. ", place));
int dev_cnt = phi::DeviceManager::GetDeviceCount(place.GetDeviceType());
pool_.reserve(dev_cnt);
for (int dev_idx = 0; dev_idx < dev_cnt; ++dev_idx) {
auto selected_devices =
phi::DeviceManager::GetSelectedDeviceList(place.GetDeviceType());
pool_.reserve(selected_devices.size());
for (auto& dev_idx : selected_devices) {
auto creator = [place, dev_idx, this] {
auto place_ = phi::CustomPlace(place.GetDeviceType(), dev_idx);
phi::DeviceManager::SetDevice(place_);
......@@ -180,12 +181,11 @@ CustomDeviceEventResourcePool& CustomDeviceEventResourcePool::Instance(
if (pool.find(place.GetDeviceType()) == pool.end()) {
pool.insert(
{place.GetDeviceType(), std::vector<CustomDeviceEventResourcePool*>()});
for (size_t i = 0;
i < phi::DeviceManager::GetDeviceCount(place.GetDeviceType());
++i) {
for (auto& dev_id :
phi::DeviceManager::GetSelectedDeviceList(place.GetDeviceType())) {
pool[place.GetDeviceType()].emplace_back(
new CustomDeviceEventResourcePool(
paddle::platform::CustomPlace(place.GetDeviceType(), i)));
paddle::platform::CustomPlace(place.GetDeviceType(), dev_id)));
}
}
PADDLE_ENFORCE_LT(
......
......@@ -116,9 +116,8 @@ void SynchronizeAllDevice() {
auto dev_types = phi::DeviceManager::GetAllCustomDeviceTypes();
for (const auto &dev_type : dev_types) {
int pre_device_id = phi::DeviceManager::GetDevice(dev_type);
auto dev_cnt = phi::DeviceManager::GetDeviceCount(dev_type);
for (size_t i = 0; i < dev_cnt; i++) {
auto place = paddle::platform::CustomPlace(dev_type, i);
for (auto &dev_id : phi::DeviceManager::GetSelectedDeviceList(dev_type)) {
auto place = paddle::platform::CustomPlace(dev_type, dev_id);
phi::DeviceManager::SetDevice(place);
phi::DeviceManager::SynchronizeDevice(place);
}
......
......@@ -465,21 +465,24 @@ std::vector<size_t> DeviceManager::GetDeviceList(
std::vector<size_t> DeviceManager::GetSelectedDeviceList(
const std::string& device_type) {
std::vector<size_t> devices;
static std::unordered_map<std::string, std::vector<size_t>> device_list_map;
if (device_list_map.find(device_type) == device_list_map.end()) {
std::vector<size_t>& device_list = device_list_map[device_type];
std::string FLAGS = "FLAGS_selected_" + device_type + "s";
auto FLAGS_selected_devices = getenv(FLAGS.c_str());
if (FLAGS_selected_devices) {
auto devices_str = paddle::string::Split(FLAGS_selected_devices, ',');
for (auto id : devices_str) {
devices.push_back(atoi(id.c_str()));
device_list.push_back(atoi(id.c_str()));
}
} else {
int count = DeviceManager::GetDeviceCount(device_type);
for (int i = 0; i < count; ++i) {
devices.push_back(i);
device_list.push_back(i);
}
}
return devices;
}
return device_list_map[device_type];
}
void DeviceManager::CCLDestroyComm(const std::string& device_type,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册