未验证 提交 5c6ae96b 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] Fix device occupancy (#56556)

上级 8b17207c
...@@ -946,9 +946,7 @@ static void RegisterOperatorKernel( ...@@ -946,9 +946,7 @@ static void RegisterOperatorKernel(
#ifdef PADDLE_WITH_CUSTOM_DEVICE #ifdef PADDLE_WITH_CUSTOM_DEVICE
auto device_types = phi::DeviceManager::GetAllCustomDeviceTypes(); auto device_types = phi::DeviceManager::GetAllCustomDeviceTypes();
for (const auto& dev_type : device_types) { for (const auto& dev_type : device_types) {
for (size_t dev_id = 0; for (auto& dev_id : phi::DeviceManager::GetSelectedDeviceList(dev_type)) {
dev_id < phi::DeviceManager::GetDeviceCount(dev_type);
dev_id++) {
RegisterOperatorKernelWithPlace(name, RegisterOperatorKernelWithPlace(name,
op_kernel_func, op_kernel_func,
proto::VarType::RAW, proto::VarType::RAW,
......
...@@ -206,9 +206,8 @@ class AllocatorFacadePrivate { ...@@ -206,9 +206,8 @@ class AllocatorFacadePrivate {
#ifdef PADDLE_WITH_CUSTOM_DEVICE #ifdef PADDLE_WITH_CUSTOM_DEVICE
auto device_types = phi::DeviceManager::GetAllCustomDeviceTypes(); auto device_types = phi::DeviceManager::GetAllCustomDeviceTypes();
for (const auto& dev_type : device_types) { for (const auto& dev_type : device_types) {
for (size_t dev_id = 0; for (auto& dev_id :
dev_id < phi::DeviceManager::GetDeviceCount(dev_type); phi::DeviceManager::GetSelectedDeviceList(dev_type)) {
++dev_id) {
InitNaiveBestFitCustomDeviceAllocator( InitNaiveBestFitCustomDeviceAllocator(
platform::CustomPlace(dev_type, dev_id)); platform::CustomPlace(dev_type, dev_id));
} }
...@@ -272,9 +271,8 @@ class AllocatorFacadePrivate { ...@@ -272,9 +271,8 @@ class AllocatorFacadePrivate {
#ifdef PADDLE_WITH_CUSTOM_DEVICE #ifdef PADDLE_WITH_CUSTOM_DEVICE
auto device_types = phi::DeviceManager::GetAllCustomDeviceTypes(); auto device_types = phi::DeviceManager::GetAllCustomDeviceTypes();
for (const auto& dev_type : device_types) { for (const auto& dev_type : device_types) {
for (size_t dev_id = 0; for (auto& dev_id :
dev_id < phi::DeviceManager::GetDeviceCount(dev_type); phi::DeviceManager::GetSelectedDeviceList(dev_type)) {
++dev_id) {
InitAutoGrowthCustomDeviceAllocator( InitAutoGrowthCustomDeviceAllocator(
platform::CustomPlace(dev_type, dev_id), allow_free_idle_chunk); platform::CustomPlace(dev_type, dev_id), allow_free_idle_chunk);
} }
...@@ -1212,9 +1210,7 @@ class AllocatorFacadePrivate { ...@@ -1212,9 +1210,7 @@ class AllocatorFacadePrivate {
#ifdef PADDLE_WITH_CUSTOM_DEVICE #ifdef PADDLE_WITH_CUSTOM_DEVICE
auto device_types = phi::DeviceManager::GetAllCustomDeviceTypes(); auto device_types = phi::DeviceManager::GetAllCustomDeviceTypes();
for (const auto& dev_type : device_types) { for (const auto& dev_type : device_types) {
for (size_t dev_id = 0; for (auto& dev_id : phi::DeviceManager::GetSelectedDeviceList(dev_type)) {
dev_id < phi::DeviceManager::GetDeviceCount(dev_type);
dev_id++) {
platform::CustomPlace p(dev_type, dev_id); platform::CustomPlace p(dev_type, dev_id);
system_allocators_[p] = std::make_shared<NaiveBestFitAllocator>(p); system_allocators_[p] = std::make_shared<NaiveBestFitAllocator>(p);
} }
...@@ -1248,9 +1244,7 @@ class AllocatorFacadePrivate { ...@@ -1248,9 +1244,7 @@ class AllocatorFacadePrivate {
#ifdef PADDLE_WITH_CUSTOM_DEVICE #ifdef PADDLE_WITH_CUSTOM_DEVICE
auto device_types = phi::DeviceManager::GetAllCustomDeviceTypes(); auto device_types = phi::DeviceManager::GetAllCustomDeviceTypes();
for (const auto& dev_type : device_types) { for (const auto& dev_type : device_types) {
for (size_t dev_id = 0; for (auto& dev_id : phi::DeviceManager::GetSelectedDeviceList(dev_type)) {
dev_id < phi::DeviceManager::GetDeviceCount(dev_type);
dev_id++) {
places.emplace_back(platform::CustomPlace(dev_type, dev_id)); places.emplace_back(platform::CustomPlace(dev_type, dev_id));
} }
} }
......
...@@ -27,9 +27,10 @@ CustomDeviceStreamResourcePool::CustomDeviceStreamResourcePool( ...@@ -27,9 +27,10 @@ CustomDeviceStreamResourcePool::CustomDeviceStreamResourcePool(
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Required device shall be CustomPlace, but received %d. ", place)); "Required device shall be CustomPlace, but received %d. ", place));
int dev_cnt = phi::DeviceManager::GetDeviceCount(place.GetDeviceType()); auto selected_devices =
pool_.reserve(dev_cnt); phi::DeviceManager::GetSelectedDeviceList(place.GetDeviceType());
for (int dev_idx = 0; dev_idx < dev_cnt; ++dev_idx) { pool_.reserve(selected_devices.size());
for (auto& dev_idx : selected_devices) {
auto creator = [place, dev_idx, this] { auto creator = [place, dev_idx, this] {
auto place_ = phi::CustomPlace(place.GetDeviceType(), dev_idx); auto place_ = phi::CustomPlace(place.GetDeviceType(), dev_idx);
phi::DeviceManager::SetDevice(place_); phi::DeviceManager::SetDevice(place_);
...@@ -82,12 +83,11 @@ CustomDeviceStreamResourcePool& CustomDeviceStreamResourcePool::Instance( ...@@ -82,12 +83,11 @@ CustomDeviceStreamResourcePool& CustomDeviceStreamResourcePool::Instance(
if (pool.find(place.GetDeviceType()) == pool.end()) { if (pool.find(place.GetDeviceType()) == pool.end()) {
pool.insert({place.GetDeviceType(), pool.insert({place.GetDeviceType(),
std::vector<CustomDeviceStreamResourcePool*>()}); std::vector<CustomDeviceStreamResourcePool*>()});
for (size_t i = 0; for (auto& dev_id :
i < phi::DeviceManager::GetDeviceCount(place.GetDeviceType()); phi::DeviceManager::GetSelectedDeviceList(place.GetDeviceType())) {
++i) {
pool[place.GetDeviceType()].emplace_back( pool[place.GetDeviceType()].emplace_back(
new CustomDeviceStreamResourcePool( new CustomDeviceStreamResourcePool(
paddle::platform::CustomPlace(place.GetDeviceType(), i))); paddle::platform::CustomPlace(place.GetDeviceType(), dev_id)));
} }
} }
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
...@@ -125,9 +125,10 @@ CustomDeviceEventResourcePool::CustomDeviceEventResourcePool( ...@@ -125,9 +125,10 @@ CustomDeviceEventResourcePool::CustomDeviceEventResourcePool(
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Required device shall be CustomPlace, but received %d. ", place)); "Required device shall be CustomPlace, but received %d. ", place));
int dev_cnt = phi::DeviceManager::GetDeviceCount(place.GetDeviceType()); auto selected_devices =
pool_.reserve(dev_cnt); phi::DeviceManager::GetSelectedDeviceList(place.GetDeviceType());
for (int dev_idx = 0; dev_idx < dev_cnt; ++dev_idx) { pool_.reserve(selected_devices.size());
for (auto& dev_idx : selected_devices) {
auto creator = [place, dev_idx, this] { auto creator = [place, dev_idx, this] {
auto place_ = phi::CustomPlace(place.GetDeviceType(), dev_idx); auto place_ = phi::CustomPlace(place.GetDeviceType(), dev_idx);
phi::DeviceManager::SetDevice(place_); phi::DeviceManager::SetDevice(place_);
...@@ -180,12 +181,11 @@ CustomDeviceEventResourcePool& CustomDeviceEventResourcePool::Instance( ...@@ -180,12 +181,11 @@ CustomDeviceEventResourcePool& CustomDeviceEventResourcePool::Instance(
if (pool.find(place.GetDeviceType()) == pool.end()) { if (pool.find(place.GetDeviceType()) == pool.end()) {
pool.insert( pool.insert(
{place.GetDeviceType(), std::vector<CustomDeviceEventResourcePool*>()}); {place.GetDeviceType(), std::vector<CustomDeviceEventResourcePool*>()});
for (size_t i = 0; for (auto& dev_id :
i < phi::DeviceManager::GetDeviceCount(place.GetDeviceType()); phi::DeviceManager::GetSelectedDeviceList(place.GetDeviceType())) {
++i) {
pool[place.GetDeviceType()].emplace_back( pool[place.GetDeviceType()].emplace_back(
new CustomDeviceEventResourcePool( new CustomDeviceEventResourcePool(
paddle::platform::CustomPlace(place.GetDeviceType(), i))); paddle::platform::CustomPlace(place.GetDeviceType(), dev_id)));
} }
} }
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
......
...@@ -116,9 +116,8 @@ void SynchronizeAllDevice() { ...@@ -116,9 +116,8 @@ void SynchronizeAllDevice() {
auto dev_types = phi::DeviceManager::GetAllCustomDeviceTypes(); auto dev_types = phi::DeviceManager::GetAllCustomDeviceTypes();
for (const auto &dev_type : dev_types) { for (const auto &dev_type : dev_types) {
int pre_device_id = phi::DeviceManager::GetDevice(dev_type); int pre_device_id = phi::DeviceManager::GetDevice(dev_type);
auto dev_cnt = phi::DeviceManager::GetDeviceCount(dev_type); for (auto &dev_id : phi::DeviceManager::GetSelectedDeviceList(dev_type)) {
for (size_t i = 0; i < dev_cnt; i++) { auto place = paddle::platform::CustomPlace(dev_type, dev_id);
auto place = paddle::platform::CustomPlace(dev_type, i);
phi::DeviceManager::SetDevice(place); phi::DeviceManager::SetDevice(place);
phi::DeviceManager::SynchronizeDevice(place); phi::DeviceManager::SynchronizeDevice(place);
} }
......
...@@ -465,21 +465,24 @@ std::vector<size_t> DeviceManager::GetDeviceList( ...@@ -465,21 +465,24 @@ std::vector<size_t> DeviceManager::GetDeviceList(
std::vector<size_t> DeviceManager::GetSelectedDeviceList( std::vector<size_t> DeviceManager::GetSelectedDeviceList(
const std::string& device_type) { const std::string& device_type) {
std::vector<size_t> devices; static std::unordered_map<std::string, std::vector<size_t>> device_list_map;
if (device_list_map.find(device_type) == device_list_map.end()) {
std::vector<size_t>& device_list = device_list_map[device_type];
std::string FLAGS = "FLAGS_selected_" + device_type + "s"; std::string FLAGS = "FLAGS_selected_" + device_type + "s";
auto FLAGS_selected_devices = getenv(FLAGS.c_str()); auto FLAGS_selected_devices = getenv(FLAGS.c_str());
if (FLAGS_selected_devices) { if (FLAGS_selected_devices) {
auto devices_str = paddle::string::Split(FLAGS_selected_devices, ','); auto devices_str = paddle::string::Split(FLAGS_selected_devices, ',');
for (auto id : devices_str) { for (auto id : devices_str) {
devices.push_back(atoi(id.c_str())); device_list.push_back(atoi(id.c_str()));
} }
} else { } else {
int count = DeviceManager::GetDeviceCount(device_type); int count = DeviceManager::GetDeviceCount(device_type);
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
devices.push_back(i); device_list.push_back(i);
} }
} }
return devices; }
return device_list_map[device_type];
} }
void DeviceManager::CCLDestroyComm(const std::string& device_type, void DeviceManager::CCLDestroyComm(const std::string& device_type,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册