Unverified commit a8ee07c8, authored by niuliling123, committed by GitHub

[cherry-pick] Add AutoTune to reader.py for DataLoader (#42004)

Add AutoTune to reader.py for DataLoader
Parent 4ef0a0b7
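
The change wires an AuToTune helper into DataLoader construction: when enabled, it replays a small Subset of the dataset (at most batch_size * TUNING_STEPS samples) through a copy of the loader's batch sampler while timing the average per-batch read cost, and keeps the cheapest num_workers it finds. A minimal usage sketch, assuming a Linux host (tuning is skipped on macOS and Windows); the dataset and numbers here are illustrative, not part of the diff:

    import numpy as np
    from paddle.io import Dataset, DataLoader
    from paddle.fluid.reader import set_autotune_config

    class RandomDataset(Dataset):
        def __init__(self, num_samples):
            self.num_samples = num_samples

        def __getitem__(self, idx):
            return np.random.random([10]).astype('float32')

        def __len__(self):
            return self.num_samples

    # Cap the tuning run at 50 batches per candidate num_workers value.
    set_autotune_config(True, tuning_steps=50)

    # num_workers=0 is only the starting configuration; the constructor
    # overwrites it with the tuned value when autotuning is active.
    loader = DataLoader(RandomDataset(1000), batch_size=32, num_workers=0)
    print(loader.num_workers)  # the tuned worker count
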
@@ -18,11 +18,13 @@ import six
import numpy as np
import threading
import paddle
import time
from .framework import Program, Variable, program_guard, default_main_program, default_startup_program, _non_static_mode, cpu_places, _current_expected_place, _in_eager_without_dygraph_check
from .executor import global_scope
from .data_feeder import DataFeeder, BatchedTensorProvider
from .multiprocess_utils import multiprocess_queue_set, CleanupFuncRegistrar, _cleanup_mmap, _cleanup, _set_SIGCHLD_handler
from .dataloader import BatchSampler, Dataset, IterableDataset, Subset
from .dataloader.dataloader_iter import _DataLoaderIterSingleProcess, _DataLoaderIterMultiProcess, _DatasetKind, default_collate_fn
from .dataloader.batch_sampler import _InfiniteIterableSampler
from .layers.io import monkey_patch_reader_methods, _copy_reader_var_, double_buffer
@@ -36,10 +38,8 @@ import warnings
import os
import multiprocessing
import signal
# NOTE: queue has a different name in python2 and python3
import queue
# NOTE: [ avoid hanging & failed quickly ] These values are used in getting data from another process
QUEUE_GET_TIMEOUT = 60
@@ -49,6 +49,16 @@ data_loader_unique_name_generator = UniqueNameGenerator()
KEEP_DATA_LOADER_ORDER = True
USE_PINNED_MEMORY = None
# AutoTune Flags
USE_AUTOTUNE = False
TUNING_STEPS = 500


def set_autotune_config(use_autotune, tuning_steps=500):
    global USE_AUTOTUNE
    USE_AUTOTUNE = use_autotune
    global TUNING_STEPS
    TUNING_STEPS = tuning_steps
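
# Illustrative usage (numbers are examples, not defaults):
#     set_autotune_config(True, tuning_steps=50)
#     loader = paddle.io.DataLoader(dataset, batch_size=32, num_workers=0)
# With use_autotune=True, DataLoader.__init__ runs AuToTune (defined below)
# and overwrites the user-provided num_workers with the tuned value;
# tuning_steps caps how many batches are timed per candidate worker count.
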
def keep_data_loader_order(*args):
@@ -143,6 +153,122 @@ class DataLoaderBase(object):
        return arr
class AuToTune(object):
    def __init__(self, loader):
        self.loader = loader
        # Search at most half of the available CPU cores.
        self.max_num_worker = multiprocessing.cpu_count() // 2

    def __call__(self):
        # use default loader
        if (not USE_AUTOTUNE) or (not self.need_autotune()):
            return self.loader.num_workers

        # get autotune loader
        auto_tune_loader = self.get_autotune_loader()
        if auto_tune_loader is None:
            return self.loader.num_workers

        # pick the best num_workers
        auto_tune_start = time.time()
        logging.debug("========= DataLoader Auto Tune =========")
        logging.debug("User config for DataLoader: " +
                      str(self.loader.num_workers))
        best_num_workers = 0
        min_cost = float("inf")
        logging.debug("Tuning Range for num_workers: 0 ~ " +
                      str(self.max_num_worker))
        num_workers = 0
        while num_workers < self.max_num_worker:
            auto_tune_loader.num_workers = num_workers
            avg_cost = self.evaluate_reader_cost(auto_tune_loader)
            # A candidate must beat the current best cost by at least 25%
            # to become the new best.
            if min_cost * 0.75 > avg_cost:
                min_cost = avg_cost
                best_num_workers = num_workers
            else:
                # Fine-grained probe just above the current best before
                # giving up on larger worker counts.
                update_num = self.is_best(auto_tune_loader, best_num_workers,
                                          min_cost, self.max_num_worker)
                if update_num == best_num_workers:
                    break
                else:
                    best_num_workers = update_num
            logging.debug("num_workers: " + str(num_workers) + " avg_cost: " +
                          str(avg_cost))
            num_workers += 2
        logging.info("auto_tune dataLoader best_num_workers: " +
                     str(best_num_workers))
        logging.debug("AutoTuning Cost for DataLoader: " +
                      str(time.time() - auto_tune_start) + ' seconds')

        # tune the default loader's num_workers
        return best_num_workers

    def need_autotune(self):
        # Multiprocess workers are not tuned on macOS or Windows.
        return sys.platform not in ('darwin', 'win32')

    def get_sub_dataset(self, dataset, batch_size):
        num_samples = min(batch_size * TUNING_STEPS, len(dataset))
        sub_dataset = Subset(dataset, indices=list(range(num_samples)))
        return sub_dataset

    def get_autotune_loader(self):
        loader = self.loader
        batch_size = self.loader.batch_sampler.batch_size
        if isinstance(self.loader.batch_sampler,
                      paddle.io.DistributedBatchSampler):
            dataset = self.loader.batch_sampler.dataset
            sub_dataset = self.get_sub_dataset(dataset, batch_size)
            loader.batch_sampler = paddle.io.DistributedBatchSampler(
                dataset=sub_dataset,
                batch_size=batch_size,
                num_replicas=self.loader.batch_sampler.nranks,
                rank=self.loader.batch_sampler.local_rank,
                shuffle=self.loader.batch_sampler.shuffle,
                drop_last=self.loader.batch_sampler.drop_last)
        elif isinstance(self.loader.batch_sampler, paddle.io.BatchSampler):
            dataset = self.loader.batch_sampler.sampler.data_source
            sub_dataset = self.get_sub_dataset(dataset, batch_size)
            loader.batch_sampler = paddle.io.BatchSampler(
                dataset=sub_dataset,
                batch_size=batch_size,
                drop_last=self.loader.batch_sampler.drop_last)
        else:
            loader = None
        return loader

    def evaluate_reader_cost(self, reader):
        costs = []
        avg_cost = 0
        start = time.time()
        for i, data in enumerate(reader):
            costs.append(time.time() - start)
            start = time.time()
        # Drop the first two batches as warmup when enough were read.
        if len(costs) > 2:
            avg_cost = sum(costs[2:]) / len(costs[2:])
        else:
            avg_cost = sum(costs[0:]) / len(costs[0:])
        return avg_cost

    def is_best(self, reader, best_workers, best_time, num_work_boundary):
        step = 0
        num_workers = best_workers + 1
        boundary = 1
        # Probe up to 5 settings just above the current best; the
        # acceptance bar tightens by a factor of 0.80 on each miss.
        while num_workers < num_work_boundary and step < 5:
            self.loader.num_workers = num_workers
            cost = self.evaluate_reader_cost(reader)
            logging.debug("for back num_workers: " + str(num_workers) +
                          " avg_cost: " + str(cost))
            step += 1
            if cost < best_time * 0.70 * boundary:
                return num_workers
            else:
                num_workers += 1
                boundary *= 0.80
        return best_workers
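
# NOTE: a hypothetical walk-through of the search above (all numbers
# invented for illustration). Suppose reads average 1.00s/batch at
# num_workers=0 and 0.60s at num_workers=2: 0.60 < 0.75 * 1.00, so 2
# becomes the best. If num_workers=4 then averages 0.55s, that fails the
# 25% improvement bar (0.55 >= 0.75 * 0.60), so is_best() probes 3, 4, ...
# for at most 5 steps, accepting a candidate only below 0.70 * 0.60s, a
# bar that tightens by a factor of 0.80 after each miss; if none
# qualifies, the search returns 2 and the outer loop stops.
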
class DataLoader(object):
    """
    DataLoader provides an iterator which iterates given dataset
@@ -409,6 +535,7 @@ class DataLoader(object):
        self._persistent_workers = persistent_workers
        self._iterator = None
        self.num_workers = AuToTune(self).__call__()
    def __len__(self):
        if self.dataset_kind == _DatasetKind.ITER:
...
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import unittest
import numpy as np

import paddle
import paddle.nn as nn
from paddle.io import Dataset, DataLoader, BatchSampler, SequenceSampler
from paddle.fluid.reader import set_autotune_config
import sys


class RandomDataset(Dataset):
    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __getitem__(self, idx):
        image = np.random.random([10]).astype('float32')
        label = np.random.randint(0, 10 - 1, (1, )).astype('int64')
        return image, label

    def __len__(self):
        return self.num_samples


class SimpleNet(nn.Layer):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc = nn.Linear(10, 10)

    def forward(self, image):
        return self.fc(image)


class TestAutoTune(unittest.TestCase):
    def setUp(self):
        self.batch_size = 1
        self.dataset = RandomDataset(10)

    def test_dataloader_use_autotune(self):
        set_autotune_config(True, 1)
        loader = DataLoader(
            self.dataset, batch_size=self.batch_size, num_workers=0)

    def test_dataloader_disable_autotune(self):
        set_autotune_config(False)
        loader = DataLoader(
            self.dataset, batch_size=self.batch_size, num_workers=2)
        if (sys.platform == 'darwin' or sys.platform == 'win32'):
            self.assertEqual(loader.num_workers, 0)
        else:
            self.assertEqual(loader.num_workers, 2)

    def test_distributer_batch_sampler_autotune(self):
        set_autotune_config(True, 1)
        batch_sampler = paddle.io.DistributedBatchSampler(
            self.dataset, batch_size=self.batch_size)
        loader = DataLoader(
            self.dataset, batch_sampler=batch_sampler, num_workers=2)


if __name__ == '__main__':
    unittest.main()