未验证 提交 c67c3916 作者: yaoxuefeng 提交者: GitHub

Refine the fleet dataset class API (#27133)

上级 c296618c
......@@ -21,6 +21,7 @@ from .parallel import get_rank
from .parallel import get_world_size
from paddle.fluid.dygraph.parallel import prepare_context #DEFINE_ALIAS
from paddle.fluid.dygraph.parallel import ParallelEnv #DEFINE_ALIAS
from paddle.distributed.fleet.dataset import *
from . import collective
from .collective import *
......@@ -30,11 +31,8 @@ __all__ = ["spawn"]
# dygraph parallel apis
__all__ += [
"init_parallel_env",
"get_rank",
"get_world_size",
"prepare_context",
"ParallelEnv",
"init_parallel_env", "get_rank", "get_world_size", "prepare_context",
"ParallelEnv", "InMemoryDataset", "QueueDataset"
]
# collective apis
......
......@@ -23,7 +23,6 @@ from .dataset import *
__all__ = [
"DistributedStrategy",
"UtilBase",
"DatasetFactory",
"UserDefinedRoleMaker",
"PaddleCloudRoleMaker",
"Fleet",
......
......@@ -1726,13 +1726,13 @@ class DatasetLoader(DataLoaderBase):
logging.warn('thread_num {} which is set in Dataset is ignored'.
format(dataset.thread_num))
dataset.set_thread(thread_num)
dataset._set_thread(thread_num)
if isinstance(dataset, paddle.distributed.fleet.dataset.
InMemoryDataset) and dataset.queue_num > thread_num:
logging.warn("queue_num {} which is set in Dataset is ignored".
format(dataset.queue_num))
dataset.set_queue_num(thread_num)
dataset._set_queue_num(thread_num)
self._dataset = dataset
use_slots = [
......
......@@ -208,14 +208,16 @@ class TestDistCTR2x2(FleetDistRunnerBase):
filelist = train_file_list
# config dataset
dataset = paddle.distributed.fleet.DatasetFactory().create_dataset()
dataset.set_batch_size(batch_size)
dataset.set_use_var(self.feeds)
dataset = paddle.distributed.QueueDataset()
pipe_command = 'python ctr_dataset_reader.py'
dataset.set_pipe_command(pipe_command)
dataset.init(
batch_size=batch_size,
use_var=self.feeds,
pipe_command=pipe_command,
thread_num=thread_num)
dataset.set_filelist(filelist)
dataset.set_thread(thread_num)
for epoch_id in range(1):
pass_start = time.time()
......
......@@ -114,14 +114,14 @@ class TestDistGpuPsCTR2x2(TestDistCTR2x2):
filelist.append(train_file_path)
# config dataset
dataset = paddle.fleet.DatasetFactory().create_dataset()
dataset.set_batch_size(batch_size)
dataset.set_use_var(self.feeds)
dataset = paddle.distributed.QueueDataset()
dataset._set_batch_size(batch_size)
dataset._set_use_var(self.feeds)
pipe_command = 'python ctr_dataset_reader.py'
dataset.set_pipe_command(pipe_command)
dataset._set_pipe_command(pipe_command)
dataset.set_filelist(filelist)
dataset.set_thread(thread_num)
dataset._set_thread(thread_num)
for epoch_id in range(1):
pass_start = time.time()
......
......@@ -183,14 +183,14 @@ class TestHeterPsCTR2x2(FleetDistHeterRunnerBase):
print("filelist: {}".format(filelist))
# config dataset
dataset = paddle.distributed.fleet.DatasetFactory().create_dataset()
dataset.set_batch_size(batch_size)
dataset.set_use_var(self.feeds)
dataset = paddle.distributed.QueueDataset()
dataset._set_batch_size(batch_size)
dataset._set_use_var(self.feeds)
pipe_command = 'python ctr_dataset_reader.py'
dataset.set_pipe_command(pipe_command)
dataset._set_pipe_command(pipe_command)
dataset.set_filelist(filelist)
dataset.set_thread(thread_num)
dataset._set_thread(thread_num)
for epoch_id in range(1):
pass_start = time.time()
......
......@@ -97,9 +97,11 @@ class DatasetLoaderTestBase(unittest.TestCase):
def check_batch_number(self, place, randomize_batch_num=False):
main_prog, startup_prog, feeds = self.build_network()
dataset = paddle.distributed.fleet.DatasetFactory().create_dataset(
self.dataset_name)
dataset.set_batch_size(BATCH_SIZE)
if self.dataset_name == "QueueDataset":
dataset = paddle.distributed.QueueDataset()
else:
dataset = paddle.distributed.InMemoryDataset()
dataset._set_batch_size(BATCH_SIZE)
if isinstance(place, fluid.CPUPlace):
file_num = 10
......@@ -128,8 +130,8 @@ class DatasetLoaderTestBase(unittest.TestCase):
fake_reader(batch_num=BATCH_NUM + random_delta_batch_size[i]))
dataset.set_filelist(filelist)
dataset.set_use_var(feeds)
dataset.set_pipe_command("cat")
dataset._set_use_var(feeds)
dataset._set_pipe_command("cat")
if self.dataset_name == 'InMemoryDataset':
dataset.load_into_memory()
......
......@@ -163,10 +163,9 @@ class TestCloudRoleMaker2(unittest.TestCase):
data = "1 1 1 1\n"
f.write(data)
dataset = paddle.distributed.fleet.DatasetFactory().create_dataset(
"InMemoryDataset")
dataset = paddle.distributed.InMemoryDataset()
dataset.set_filelist(["test_fleet_gloo_role_maker_1.txt"])
dataset.set_use_var([show, label])
dataset._set_use_var([show, label])
dataset.load_into_memory()
dataset.get_memory_data_size(fleet)
dataset.get_shuffle_data_size(fleet)
......
......@@ -52,18 +52,17 @@ class TestDatasetWithStat(unittest.TestCase):
name=slot, shape=[1], dtype="int64", lod_level=1)
slots_vars.append(var)
dataset = paddle.distributed.fleet.DatasetFactory().create_dataset(
"InMemoryDataset")
dataset.set_batch_size(32)
dataset.set_thread(3)
dataset = paddle.distributed.InMemoryDataset()
dataset._set_batch_size(32)
dataset._set_thread(3)
dataset.set_filelist([
"test_in_memory_dataset_run_a.txt",
"test_in_memory_dataset_run_b.txt"
])
dataset.set_pipe_command("cat")
dataset.set_use_var(slots_vars)
dataset._set_pipe_command("cat")
dataset._set_use_var(slots_vars)
dataset.load_into_memory()
dataset.set_fea_eval(1, True)
dataset._set_fea_eval(1, True)
dataset.slots_shuffle(["slot1"])
exe = fluid.Executor(fluid.CPUPlace())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册