Unverified commit c67c3916 authored by yaoxuefeng, committed by GitHub

refine fleet dataset class api (#27133)

Parent: c296618c
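For context, this commit replaces the old DatasetFactory-based API with dataset classes that are instantiated directly and configured through a single init() call. A minimal before/after sketch assembled from the hunks below (batch_size, feeds, pipe_command, thread_num, and filelist are placeholders assumed to be defined by the caller):

# Old API, removed by this commit: datasets were created through DatasetFactory
# and configured with individual set_*() calls.
dataset = paddle.distributed.fleet.DatasetFactory().create_dataset()
dataset.set_batch_size(batch_size)
dataset.set_use_var(feeds)
dataset.set_pipe_command(pipe_command)
dataset.set_thread(thread_num)
dataset.set_filelist(filelist)

# New API: instantiate QueueDataset (or InMemoryDataset) directly and pass the
# configuration to init(); set_filelist() is still called separately.
dataset = paddle.distributed.QueueDataset()
dataset.init(
    batch_size=batch_size,
    use_var=feeds,
    pipe_command=pipe_command,
    thread_num=thread_num)
dataset.set_filelist(filelist)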
@@ -21,6 +21,7 @@ from .parallel import get_rank
 from .parallel import get_world_size
 from paddle.fluid.dygraph.parallel import prepare_context  #DEFINE_ALIAS
 from paddle.fluid.dygraph.parallel import ParallelEnv  #DEFINE_ALIAS
+from paddle.distributed.fleet.dataset import *
 from . import collective
 from .collective import *
@@ -30,11 +31,8 @@ __all__ = ["spawn"]
 # dygraph parallel apis
 __all__ += [
-    "init_parallel_env",
-    "get_rank",
-    "get_world_size",
-    "prepare_context",
-    "ParallelEnv",
+    "init_parallel_env", "get_rank", "get_world_size", "prepare_context",
+    "ParallelEnv", "InMemoryDataset", "QueueDataset"
 ]
 # collective apis
...
@@ -23,7 +23,6 @@ from .dataset import *
 __all__ = [
     "DistributedStrategy",
     "UtilBase",
-    "DatasetFactory",
     "UserDefinedRoleMaker",
     "PaddleCloudRoleMaker",
     "Fleet",
...
@@ -1726,13 +1726,13 @@ class DatasetLoader(DataLoaderBase):
             logging.warn('thread_num {} which is set in Dataset is ignored'.
                          format(dataset.thread_num))
-        dataset.set_thread(thread_num)
+        dataset._set_thread(thread_num)
         if isinstance(dataset, paddle.distributed.fleet.dataset.
                       InMemoryDataset) and dataset.queue_num > thread_num:
             logging.warn("queue_num {} which is set in Dataset is ignored".
                          format(dataset.queue_num))
-            dataset.set_queue_num(thread_num)
+            dataset._set_queue_num(thread_num)
         self._dataset = dataset
         use_slots = [
...
@@ -208,14 +208,16 @@ class TestDistCTR2x2(FleetDistRunnerBase):
         filelist = train_file_list
         # config dataset
-        dataset = paddle.distributed.fleet.DatasetFactory().create_dataset()
-        dataset.set_batch_size(batch_size)
-        dataset.set_use_var(self.feeds)
+        dataset = paddle.distributed.QueueDataset()
         pipe_command = 'python ctr_dataset_reader.py'
-        dataset.set_pipe_command(pipe_command)
+        dataset.init(
+            batch_size=batch_size,
+            use_var=self.feeds,
+            pipe_command=pipe_command,
+            thread_num=thread_num)
         dataset.set_filelist(filelist)
-        dataset.set_thread(thread_num)
         for epoch_id in range(1):
             pass_start = time.time()
...
@@ -114,14 +114,14 @@ class TestDistGpuPsCTR2x2(TestDistCTR2x2):
         filelist.append(train_file_path)
         # config dataset
-        dataset = paddle.fleet.DatasetFactory().create_dataset()
-        dataset.set_batch_size(batch_size)
-        dataset.set_use_var(self.feeds)
+        dataset = paddle.distributed.QueueDataset()
+        dataset._set_batch_size(batch_size)
+        dataset._set_use_var(self.feeds)
         pipe_command = 'python ctr_dataset_reader.py'
-        dataset.set_pipe_command(pipe_command)
+        dataset._set_pipe_command(pipe_command)
         dataset.set_filelist(filelist)
-        dataset.set_thread(thread_num)
+        dataset._set_thread(thread_num)
         for epoch_id in range(1):
             pass_start = time.time()
...
@@ -183,14 +183,14 @@ class TestHeterPsCTR2x2(FleetDistHeterRunnerBase):
         print("filelist: {}".format(filelist))
         # config dataset
-        dataset = paddle.distributed.fleet.DatasetFactory().create_dataset()
-        dataset.set_batch_size(batch_size)
-        dataset.set_use_var(self.feeds)
+        dataset = paddle.distributed.QueueDataset()
+        dataset._set_batch_size(batch_size)
+        dataset._set_use_var(self.feeds)
         pipe_command = 'python ctr_dataset_reader.py'
-        dataset.set_pipe_command(pipe_command)
+        dataset._set_pipe_command(pipe_command)
         dataset.set_filelist(filelist)
-        dataset.set_thread(thread_num)
+        dataset._set_thread(thread_num)
         for epoch_id in range(1):
             pass_start = time.time()
...
@@ -97,9 +97,11 @@ class DatasetLoaderTestBase(unittest.TestCase):
     def check_batch_number(self, place, randomize_batch_num=False):
         main_prog, startup_prog, feeds = self.build_network()
-        dataset = paddle.distributed.fleet.DatasetFactory().create_dataset(
-            self.dataset_name)
-        dataset.set_batch_size(BATCH_SIZE)
+        if self.dataset_name == "QueueDataset":
+            dataset = paddle.distributed.QueueDataset()
+        else:
+            dataset = paddle.distributed.InMemoryDataset()
+        dataset._set_batch_size(BATCH_SIZE)
         if isinstance(place, fluid.CPUPlace):
             file_num = 10
@@ -128,8 +130,8 @@ class DatasetLoaderTestBase(unittest.TestCase):
                 fake_reader(batch_num=BATCH_NUM + random_delta_batch_size[i]))
         dataset.set_filelist(filelist)
-        dataset.set_use_var(feeds)
-        dataset.set_pipe_command("cat")
+        dataset._set_use_var(feeds)
+        dataset._set_pipe_command("cat")
         if self.dataset_name == 'InMemoryDataset':
             dataset.load_into_memory()
...
@@ -163,10 +163,9 @@ class TestCloudRoleMaker2(unittest.TestCase):
             data = "1 1 1 1\n"
             f.write(data)
-        dataset = paddle.distributed.fleet.DatasetFactory().create_dataset(
-            "InMemoryDataset")
+        dataset = paddle.distributed.InMemoryDataset()
         dataset.set_filelist(["test_fleet_gloo_role_maker_1.txt"])
-        dataset.set_use_var([show, label])
+        dataset._set_use_var([show, label])
         dataset.load_into_memory()
         dataset.get_memory_data_size(fleet)
         dataset.get_shuffle_data_size(fleet)
...
@@ -52,18 +52,17 @@ class TestDatasetWithStat(unittest.TestCase):
                 name=slot, shape=[1], dtype="int64", lod_level=1)
             slots_vars.append(var)
-        dataset = paddle.distributed.fleet.DatasetFactory().create_dataset(
-            "InMemoryDataset")
-        dataset.set_batch_size(32)
-        dataset.set_thread(3)
+        dataset = paddle.distributed.InMemoryDataset()
+        dataset._set_batch_size(32)
+        dataset._set_thread(3)
         dataset.set_filelist([
             "test_in_memory_dataset_run_a.txt",
             "test_in_memory_dataset_run_b.txt"
         ])
-        dataset.set_pipe_command("cat")
-        dataset.set_use_var(slots_vars)
+        dataset._set_pipe_command("cat")
+        dataset._set_use_var(slots_vars)
         dataset.load_into_memory()
-        dataset.set_fea_eval(1, True)
+        dataset._set_fea_eval(1, True)
         dataset.slots_shuffle(["slot1"])
         exe = fluid.Executor(fluid.CPUPlace())
...