Unverified commit 78014059, authored by yaoxuefeng, committed by GitHub

【paddle.distributed.fleet】add data_generator in distributed.fleet.dataset (#27345)

Parent aac57159
@@ -31,8 +31,13 @@ __all__ = ["spawn"]
# dygraph parallel apis
__all__ += [
"init_parallel_env",
"get_rank",
"get_world_size",
"prepare_context",
"ParallelEnv",
"InMemoryDataset",
"QueueDataset",
]
# collective apis
...
@@ -18,6 +18,7 @@ from .base.distributed_strategy import DistributedStrategy
from .base.fleet_base import Fleet
from .base.util_factory import UtilBase
from .dataset import *
from .data_generator import MultiSlotDataGenerator, MultiSlotStringDataGenerator
#from . import metrics
__all__ = [
@@ -26,6 +27,8 @@ __all__ = [
"UserDefinedRoleMaker",
"PaddleCloudRoleMaker",
"Fleet",
"MultiSlotDataGenerator",
"MultiSlotStringDataGenerator",
"Role",
]
...
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .data_generator import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
class DataGenerator(object):
"""
DataGenerator is a general base class for users to inherit.
A user who wants to define custom Python processing logic
with paddle.distributed.InMemoryDataset/QueueDataset should
inherit this class.
"""
def __init__(self):
self._proto_info = None
self.batch_size_ = 32
def set_batch(self, batch_size):
'''
Set the batch size of the current DataGenerator.
This is necessary only if a user wants to define generate_batch
Example:
.. code-block:: python
import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator):
def generate_sample(self, line):
def local_iter():
int_words = [int(x) for x in line.split()]
yield ("words", int_words)
return local_iter
def generate_batch(self, samples):
def local_iter():
for s in samples:
words = s[0][1]
yield [("words", words + [words[0]])]
mydata = MyData()
mydata.set_batch(128)
'''
self.batch_size_ = batch_size
def run_from_memory(self):
'''
This function generates data from memory; it is usually used for
debugging and benchmarking.
Example:
.. code-block:: python
import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator):
def generate_sample(self, line):
def local_iter():
yield ("words", [1, 2, 3, 4])
return local_iter
mydata = MyData()
mydata.run_from_memory()
'''
batch_samples = []
line_iter = self.generate_sample(None)
for user_parsed_line in line_iter():
if user_parsed_line is None:
continue
batch_samples.append(user_parsed_line)
if len(batch_samples) == self.batch_size_:
batch_iter = self.generate_batch(batch_samples)
for sample in batch_iter():
sys.stdout.write(self._gen_str(sample))
batch_samples = []
if len(batch_samples) > 0:
batch_iter = self.generate_batch(batch_samples)
for sample in batch_iter():
sys.stdout.write(self._gen_str(sample))
def run_from_stdin(self):
'''
This function reads data rows from stdin, parses them with the
process function, and further parses the return value of the
process function with the _gen_str function. The parsed data will
be written to stdout and the corresponding protofile will be
generated.
Example:
.. code-block:: python
import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator):
def generate_sample(self, line):
def local_iter():
int_words = [int(x) for x in line.split()]
yield ("words", [int_words])
return local_iter
mydata = MyData()
mydata.run_from_stdin()
'''
batch_samples = []
for line in sys.stdin:
line_iter = self.generate_sample(line)
for user_parsed_line in line_iter():
if user_parsed_line is None:
continue
batch_samples.append(user_parsed_line)
if len(batch_samples) == self.batch_size_:
batch_iter = self.generate_batch(batch_samples)
for sample in batch_iter():
sys.stdout.write(self._gen_str(sample))
batch_samples = []
if len(batch_samples) > 0:
batch_iter = self.generate_batch(batch_samples)
for sample in batch_iter():
sys.stdout.write(self._gen_str(sample))
def _gen_str(self, line):
'''
Further process the output of the process() function rewritten by
the user, producing data that can be read directly by the datafeed,
and update the proto_info information.
Args:
line(str): the output of the process() function rewritten by user.
Returns:
Return a string data that can be read directly by the datafeed.
'''
raise NotImplementedError(
"pls use MultiSlotDataGenerator or PairWiseDataGenerator")
def generate_sample(self, line):
'''
This function needs to be overridden by the user to process the
original data row into a list or tuple.
Args:
line(str): the original data row
Returns:
Returns the data processed by the user.
The data format is list or tuple:
[(name, [feasign, ...]), ...]
or ((name, [feasign, ...]), ...)
For example:
[("words", [1926, 08, 17]), ("label", [1])]
or (("words", [1926, 08, 17]), ("label", [1]))
Note:
Feasigns must be of type int or float. Once a float
element appears in a slot, the type of that slot will be
processed as float.
Example:
.. code-block:: python
import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator):
def generate_sample(self, line):
def local_iter():
int_words = [int(x) for x in line.split()]
yield ("words", [int_words])
return local_iter
'''
raise NotImplementedError(
"Please rewrite this function to return a list or tuple: " +
"[(name, [feasign, ...]), ...] or ((name, [feasign, ...]), ...)")
def generate_batch(self, samples):
'''
This function needs to be overridden by the user to process the
samples generated by the generate_sample(self, line) function.
It is usually used for batch-level preprocessing, e.g. padding
each sample to the maximum sample length within the batch.
Args:
samples(list|tuple): samples generated by generate_sample
Returns:
a python generator, the same format as return value of generate_sample
Example:
.. code-block:: python
import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator):
def generate_sample(self, line):
def local_iter():
int_words = [int(x) for x in line.split()]
yield ("words", int_words)
return local_iter
def generate_batch(self, samples):
def local_iter():
for s in samples:
words = s[0][1]
yield [("words", words + [words[0]])]
mydata = MyData()
mydata.set_batch(128)
'''
def local_iter():
for sample in samples:
yield sample
return local_iter
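
# A minimal sketch of the intended override pattern (illustrative only; the
# class name and zero-padding value are assumptions, not part of the API):
# generate_sample parses one raw line, and generate_batch pads every sample
# in a batch to the batch's maximum length.
class _PaddedDataGeneratorExample(DataGenerator):
    def generate_sample(self, line):
        def local_iter():
            # one whitespace-separated line of integer feasigns
            yield [("words", [int(x) for x in line.split()])]

        return local_iter

    def generate_batch(self, samples):
        def local_iter():
            # each sample is [("words", feasigns)]; pad to the batch maximum
            max_len = max(len(s[0][1]) for s in samples)
            for s in samples:
                name, feasigns = s[0]
                yield [(name, feasigns + [0] * (max_len - len(feasigns)))]

        return local_iter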
# TODO: guru4elephant
# add more generalized DataGenerator that can adapt user-defined slot
# for example, [(name, float_list), (name, str_list), (name, int_list)]
class MultiSlotStringDataGenerator(DataGenerator):
def _gen_str(self, line):
'''
Further process the output of the process() function rewritten by
the user, producing data that can be read directly by the MultiSlotDataFeed,
and update the proto_info information.
The input line will be in this format:
>>> [(name, [str(feasign), ...]), ...]
>>> or ((name, [str(feasign), ...]), ...)
The output will be in this format:
>>> [ids_num id1 id2 ...] ...
For example, if the input is like this:
>>> [("words", ["1926", "08", "17"]), ("label", ["1"])]
>>> or (("words", ["1926", "08", "17"]), ("label", ["1"]))
the output will be:
>>> 3 1926 08 17 1 1
Args:
line(str): the output of the process() function rewritten by user.
Returns:
Return a string data that can be read directly by the MultiSlotDataFeed.
'''
if not isinstance(line, list) and not isinstance(line, tuple):
raise ValueError(
"the output of process() must be in list or tuple type. "
"Examples: [('words', ['1926', '08', '17']), ('label', ['1'])]")
output = ""
for index, item in enumerate(line):
name, elements = item
if output:
output += " "
out_str = []
out_str.append(str(len(elements)))
out_str.extend(elements)
output += " ".join(out_str)
return output + "\n"
class MultiSlotDataGenerator(DataGenerator):
def _gen_str(self, line):
'''
Further process the output of the process() function rewritten by
the user, producing data that can be read directly by the MultiSlotDataFeed,
and update the proto_info information.
The input line will be in this format:
>>> [(name, [feasign, ...]), ...]
>>> or ((name, [feasign, ...]), ...)
The output will be in this format:
>>> [ids_num id1 id2 ...] ...
The proto_info will be in this format:
>>> [(name, type), ...]
For example, if the input is like this:
>>> [("words", [1926, 8, 17]), ("label", [1])]
>>> or (("words", [1926, 8, 17]), ("label", [1]))
the output will be:
>>> 3 1926 8 17 1 1
the proto_info will be:
>>> [("words", "uint64"), ("label", "uint64")]
Args:
line(str): the output of the process() function rewritten by user.
Returns:
Return a string data that can be read directly by the MultiSlotDataFeed.
'''
if not isinstance(line, list) and not isinstance(line, tuple):
raise ValueError(
"the output of process() must be in list or tuple type. "
"Example: [('words', [1926, 8, 17]), ('label', [1])]")
output = ""
if self._proto_info is None:
self._proto_info = []
for item in line:
name, elements = item
if not isinstance(name, str):
raise ValueError("name%s must be in str type" % type(name))
if not isinstance(elements, list):
raise ValueError("elements%s must be in list type" %
type(elements))
if not elements:
raise ValueError(
"the elements of each field can not be empty, you need padding it in process()."
)
self._proto_info.append((name, "uint64"))
if output:
output += " "
output += str(len(elements))
for elem in elements:
if isinstance(elem, float):
self._proto_info[-1] = (name, "float")
elif not isinstance(elem, int):
raise ValueError(
"the type of element %s must be int or float" %
type(elem))
output += " " + str(elem)
else:
if len(line) != len(self._proto_info):
raise ValueError(
"the complete field sets of the two given lines are inconsistent.")
for index, item in enumerate(line):
name, elements = item
if not isinstance(name, str):
raise ValueError("name%s must be in str type" % type(name))
if not isinstance(elements, list):
raise ValueError("elements%s must be in list type" %
type(elements))
if not elements:
raise ValueError(
"the elements of each field can not be empty, you need padding it in process()."
)
if name != self._proto_info[index][0]:
raise ValueError(
"the field names of the two given lines do not match: required <%s>, got <%s>."
% (self._proto_info[index][0], name))
if output:
output += " "
output += str(len(elements))
for elem in elements:
if self._proto_info[index][1] != "float":
if isinstance(elem, float):
self._proto_info[index] = (name, "float")
elif not isinstance(elem, int):
raise ValueError(
"the type of element %s must be int or float"
% type(elem))
output += " " + str(elem)
return output + "\n"
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.distributed.fleet as fleet
class SyntheticData(fleet.MultiSlotDataGenerator):
def generate_sample(self, line):
def data_iter():
for i in range(10000):
yield ("words", [1, 2, 3, 4]), ("label", [0])
return data_iter
class SyntheticStringData(fleet.MultiSlotStringDataGenerator):
def generate_sample(self, line):
def data_iter():
for i in range(10000):
yield [("words", ["1", "2", "3", "4"]), ("label", ["0"])]
return data_iter
sd = SyntheticData()
sd.run_from_memory()
sd2 = SyntheticStringData()
sd2.run_from_memory()
@@ -119,7 +119,7 @@ class DatasetBase(object):
def set_filelist(self, filelist):
"""
Set file list in current worker. The filelist is indicated by a list of file names (string).
Examples:
.. code-block:: python
@@ -129,7 +129,7 @@ class DatasetBase(object):
dataset.set_filelist(['a.txt', 'b.txt'])
Args:
filelist(list[str]): list of file names of inputs.
"""
self.dataset.set_filelist(filelist)
self.filelist = filelist
@@ -240,6 +240,8 @@ class DatasetBase(object):
class InMemoryDataset(DatasetBase):
"""
:api_attr: Static Graph

InMemoryDataset will load data into memory
and shuffle the data before training.
@@ -265,6 +267,8 @@ class InMemoryDataset(DatasetBase):
def _init_distributed_settings(self, **kwargs):
"""
:api_attr: Static Graph

should be called only once in user's python scripts to initialize distributed-related settings of the dataset instance
Args:
kwargs: Keyword arguments. Currently, we support following keys in **kwargs:
@@ -323,6 +327,8 @@ class InMemoryDataset(DatasetBase):
def update_settings(self, **kwargs):
"""
:api_attr: Static Graph

should be called in user's python scripts to update settings of the dataset instance
Args:
kwargs: Keyword arguments. Currently, we support following keys in **kwargs,
@@ -400,6 +406,8 @@ class InMemoryDataset(DatasetBase):
def init(self, **kwargs):
"""
:api_attr: Static Graph

should be called only once in user's python scripts to initialize settings of the dataset instance
Args:
kwargs: Keyword arguments. Currently, we support following keys in **kwargs:
@@ -450,11 +458,16 @@ class InMemoryDataset(DatasetBase):
["test_queue_dataset_run_a.txt", "test_queue_dataset_run_b.txt"])
dataset.load_into_memory()

paddle.enable_static()

place = paddle.CUDAPlace(0) if paddle.fluid.core.is_compiled_with_cuda() else paddle.CPUPlace()
exe = paddle.static.Executor(place)
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
exe.run(startup_program)
exe.train_from_dataset(main_program, dataset)

os.remove("./test_queue_dataset_run_a.txt")
os.remove("./test_queue_dataset_run_b.txt")
"""
@@ -639,6 +652,8 @@ class InMemoryDataset(DatasetBase):
def load_into_memory(self):
"""
:api_attr: Static Graph

Load data into memory
Examples:
@@ -655,6 +670,8 @@ class InMemoryDataset(DatasetBase):
def preload_into_memory(self, thread_num=None):
"""
:api_attr: Static Graph

Load data into memory in async mode
Args:
@@ -679,6 +696,8 @@ class InMemoryDataset(DatasetBase):
def wait_preload_done(self):
"""
:api_attr: Static Graph

Wait until preload_into_memory is done
Examples:
@@ -696,6 +715,8 @@ class InMemoryDataset(DatasetBase):
def local_shuffle(self):
"""
:api_attr: Static Graph

Local shuffle
Examples:
@@ -712,6 +733,8 @@ class InMemoryDataset(DatasetBase):
def global_shuffle(self, fleet=None, thread_num=12):
"""
:api_attr: Static Graph

Global shuffle.
Global shuffle can be used only in distributed mode, i.e. multiple
processes on a single machine or multiple machines training together.
@@ -771,9 +794,11 @@ class InMemoryDataset(DatasetBase):
dataset.set_filelist(filelist)
dataset.load_into_memory()
dataset.global_shuffle(fleet)
exe = paddle.static.Executor(paddle.CPUPlace())
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
exe.run(startup_program)
exe.train_from_dataset(main_program, dataset)
dataset.release_memory()
"""
@@ -781,6 +806,8 @@ class InMemoryDataset(DatasetBase):
def get_memory_data_size(self, fleet=None):
"""
:api_attr: Static Graph

Get memory data size; users can call this function to know the number
of instances in all workers after loading into memory.
@@ -817,6 +844,8 @@ class InMemoryDataset(DatasetBase):
def get_shuffle_data_size(self, fleet=None):
"""
:api_attr: Static Graph

Get shuffle data size; users can call this function to know the number
of instances in all workers after local/global shuffle.
@@ -901,6 +930,8 @@ class InMemoryDataset(DatasetBase):
class QueueDataset(DatasetBase):
"""
:api_attr: Static Graph

QueueDataset will process data in a streaming manner.
Examples:
@@ -920,6 +951,8 @@ class QueueDataset(DatasetBase):
def init(self, **kwargs):
"""
:api_attr: Static Graph

should be called only once in user's python scripts to initialize settings of the dataset instance
"""
super(QueueDataset, self).init(**kwargs)
...
@@ -16,6 +16,7 @@
from paddle.fluid.proto import data_feed_pb2
from google.protobuf import text_format
from . import core
from ..utils import deprecated

__all__ = ['DatasetFactory', 'InMemoryDataset', 'QueueDataset']
@@ -335,6 +336,7 @@ class InMemoryDataset(DatasetBase):
dataset = paddle.fluid.DatasetFactory().create_dataset("InMemoryDataset")
"""

@deprecated(since="2.0.0", update_to="paddle.distributed.InMemoryDataset")
def __init__(self):
""" Init. """
super(InMemoryDataset, self).__init__()
@@ -350,12 +352,18 @@ class InMemoryDataset(DatasetBase):
self.merge_by_lineid = False
self.fleet_send_sleep_seconds = None
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._set_feed_type")
def set_feed_type(self, data_feed_type):
"""
Set data_feed_desc
"""
self.proto_desc.name = data_feed_type
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._prepare_to_run")
def _prepare_to_run(self):
"""
Set data_feed_desc before load or shuffle,
@@ -376,16 +384,27 @@ class InMemoryDataset(DatasetBase):
self.dataset.create_channel()
self.dataset.create_readers()
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._dynamic_adjust_before_train"
)
def _dynamic_adjust_before_train(self, thread_num):
if not self.is_user_set_queue_num:
self.dataset.dynamic_adjust_channel_num(thread_num, False)
self.dataset.dynamic_adjust_readers_num(thread_num)
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._dynamic_adjust_after_train"
)
def _dynamic_adjust_after_train(self):
if not self.is_user_set_queue_num:
self.dataset.dynamic_adjust_channel_num(self.thread_num, False)
self.dataset.dynamic_adjust_readers_num(self.thread_num)
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._set_queue_num")
def set_queue_num(self, queue_num):
"""
Set Dataset output queue num; training threads get data from queues
@@ -404,6 +423,9 @@ class InMemoryDataset(DatasetBase):
self.is_user_set_queue_num = True
self.queue_num = queue_num
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._set_parse_ins_id")
def set_parse_ins_id(self, parse_ins_id):
"""
Set whether Dataset needs to parse ins_id
@@ -421,6 +443,9 @@ class InMemoryDataset(DatasetBase):
"""
self.parse_ins_id = parse_ins_id
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._set_parse_content")
def set_parse_content(self, parse_content):
"""
Set whether Dataset needs to parse content
@@ -455,6 +480,9 @@ class InMemoryDataset(DatasetBase):
"""
self.parse_logkey = parse_logkey
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._set_merge_by_sid")
def set_merge_by_sid(self, merge_by_sid):
"""
Set whether Dataset needs to merge by sid. If not, one instance means one PV.
@@ -544,6 +572,10 @@ class InMemoryDataset(DatasetBase):
"""
self.dataset.postprocess_instance()
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._set_fleet_send_batch_size"
)
def set_fleet_send_batch_size(self, fleet_send_batch_size=1024):
"""
Set fleet send batch size, default is 1024
@@ -561,6 +593,10 @@ class InMemoryDataset(DatasetBase):
"""
self.fleet_send_batch_size = fleet_send_batch_size
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._set_fleet_send_sleep_seconds"
)
def set_fleet_send_sleep_seconds(self, fleet_send_sleep_seconds=0):
"""
Set fleet send sleep time, default is 0
@@ -578,6 +614,9 @@ class InMemoryDataset(DatasetBase):
"""
self.fleet_send_sleep_seconds = fleet_send_sleep_seconds
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._set_merge_by_lineid")
def set_merge_by_lineid(self, merge_size=2):
"""
Set merge by line id; instances of the same line id will be merged after
@@ -598,16 +637,27 @@ class InMemoryDataset(DatasetBase):
self.merge_by_lineid = True
self.parse_ins_id = True
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._set_generate_unique_feasigns"
)
def set_generate_unique_feasigns(self, generate_uni_feasigns, shard_num):
self.dataset.set_generate_unique_feasigns(generate_uni_feasigns)
self.gen_uni_feasigns = generate_uni_feasigns
self.local_shard_num = shard_num
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._generate_local_tables_unlock"
)
def generate_local_tables_unlock(self, table_id, fea_dim, read_thread_num,
consume_thread_num, shard_num):
self.dataset.generate_local_tables_unlock(
table_id, fea_dim, read_thread_num, consume_thread_num, shard_num)
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset.load_into_memory")
def load_into_memory(self):
"""
Load data into memory
@@ -624,6 +674,9 @@ class InMemoryDataset(DatasetBase):
self._prepare_to_run()
self.dataset.load_into_memory()
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset.preload_into_memory")
def preload_into_memory(self, thread_num=None):
"""
Load data into memory in async mode
@@ -648,6 +701,9 @@ class InMemoryDataset(DatasetBase):
self.dataset.create_preload_readers()
self.dataset.preload_into_memory()
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset.wait_preload_done")
def wait_preload_done(self):
"""
Wait until preload_into_memory is done
@@ -665,6 +721,9 @@ class InMemoryDataset(DatasetBase):
self.dataset.wait_preload_done()
self.dataset.destroy_preload_readers()
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset.local_shuffle")
def local_shuffle(self):
"""
Local shuffle
@@ -681,6 +740,9 @@ class InMemoryDataset(DatasetBase):
"""
self.dataset.local_shuffle()
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset.global_shuffle")
def global_shuffle(self, fleet=None, thread_num=12):
"""
Global shuffle.
@@ -726,6 +788,9 @@ class InMemoryDataset(DatasetBase):
if fleet is not None:
fleet._role_maker.barrier_worker()
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset.release_memory")
def release_memory(self):
"""
:api_attr: Static Graph
@@ -774,6 +839,9 @@ class InMemoryDataset(DatasetBase):
"""
return self.dataset.get_pv_data_size()
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset.get_memory_data_size")
def get_memory_data_size(self, fleet=None):
"""
Get memory data size; users can call this function to know the number
@@ -810,6 +878,9 @@ class InMemoryDataset(DatasetBase):
return global_data_size[0]
return local_data_size[0]
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset.get_shuffle_data_size")
def get_shuffle_data_size(self, fleet=None):
"""
Get shuffle data size; users can call this function to know the number
@@ -869,6 +940,9 @@ class QueueDataset(DatasetBase):
super(QueueDataset, self).__init__()
self.proto_desc.name = "MultiSlotDataFeed"
@deprecated(
since="2.0.0",
update_to="paddle.distributed.QueueDataset._prepare_to_run")
def _prepare_to_run(self):
"""
Set data_feed_desc/thread num/filelist before run,
...
@@ -19,7 +19,7 @@ import tarfile
import os
import paddle
import paddle.distributed.fleet as fleet
from paddle.fluid.log_helper import get_logger

logger = get_logger(
@@ -59,7 +59,7 @@ def load_lr_input_record(sent):
return res

class DatasetCtrReader(fleet.MultiSlotDataGenerator):
def generate_sample(self, line):
def iter():
fs = line.strip().split('\t')
...
@@ -22,7 +22,7 @@ import random
import warnings
import paddle
import paddle.distributed.fleet as fleet

logging.basicConfig()
logger = logging.getLogger("paddle")
@@ -84,7 +84,7 @@ class CtrReader(object):
return reader

class DatasetCtrReader(fleet.MultiSlotDataGenerator):
def generate_sample(self, line):
def get_rand(low=0.0, high=1.0):
return random.random()
...
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import paddle
import re
import collections
import time
import paddle.distributed.fleet as fleet
class MyDataset(fleet.MultiSlotDataGenerator):
def generate_sample(self, line):
def data_iter():
elements = line.strip().split()
output = [("show", [int(elements[0])]),
("click", [int(elements[1])]),
("slot1", [int(elements[2])])]
yield output
return data_iter
if __name__ == "__main__":
d = MyDataset()
d.run_from_stdin()
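
# Usage sketch (assumed invocation, mirrored by the unit test further below):
# this script is meant to run as a pipe command, reading one raw line from
# stdin and writing one serialized sample to stdout, e.g.
#
#     cat data.txt | python my_data_generator.py
#
# An input line "2 1 2" comes out as "1 2 1 1 1 2": three slots (show, click,
# slot1), each serialized as "<ids_num> <id>".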
@@ -21,13 +21,13 @@ import tarfile
import random
import paddle
import paddle.distributed.fleet as fleet

logging.basicConfig()
logger = logging.getLogger("paddle")
logger.setLevel(logging.INFO)

class DatasetSimnetReader(fleet.MultiSlotDataGenerator):
def generate_sample(self, line):
pass
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import unittest
import paddle.distributed.fleet as fleet
import os
import sys
import platform
class MyMultiSlotDataGenerator(fleet.MultiSlotDataGenerator):
def generate_sample(self, line):
def data_iter():
for i in range(40):
if i == 1:
yield None
yield ("words", [1, 2, 3, 4]), ("label", [0])
return data_iter
class MyMultiSlotStringDataGenerator(fleet.MultiSlotStringDataGenerator):
def generate_sample(self, line):
def data_iter():
for i in range(40):
if i == 1:
yield None
yield ("words", ["1", "2", "3", "4"]), ("label", ["0"])
return data_iter
class MyMultiSlotDataGenerator_error(fleet.MultiSlotDataGenerator):
def generate_sample(self, line):
def data_iter():
for i in range(40):
if i == 1:
yield None
yield "words"
return data_iter
class MyMultiSlotDataGenerator_error_2(fleet.MultiSlotStringDataGenerator):
def generate_sample(self, line):
def data_iter():
for i in range(40):
if i == 1:
yield None
yield "words"
return data_iter
class MyMultiSlotDataGenerator_error_3(fleet.MultiSlotDataGenerator):
def generate_sample(self, line):
def data_iter():
for i in range(40):
if i == 1:
yield None
yield (1, ["1", "2", "3", "4"]), (2, ["0"])
return data_iter
class MyMultiSlotDataGenerator_error_4(fleet.MultiSlotDataGenerator):
def generate_sample(self, line):
def data_iter():
for i in range(40):
if i == 1:
yield None
yield ("words", "1"), ("label", "0")
return data_iter
class MyMultiSlotDataGenerator_error_5(fleet.MultiSlotDataGenerator):
def generate_sample(self, line):
def data_iter():
for i in range(40):
if i == 1:
yield None
yield ("words", []), ("label", [])
return data_iter
class TestMultiSlotDataGenerator(unittest.TestCase):
def test_MultiSlotDataGenerator_basic(self):
my_ms_dg = MyMultiSlotDataGenerator()
my_ms_dg.set_batch(1)
my_ms_dg.run_from_memory()
class TestMultiSlotStringDataGenerator(unittest.TestCase):
def test_MyMultiSlotStringDataGenerator_basic(self):
my_ms_dg = MyMultiSlotStringDataGenerator()
my_ms_dg.set_batch(1)
my_ms_dg.run_from_memory()
class TestMultiSlotStringDataGenerator_2(unittest.TestCase):
def test_MyMultiSlotStringDataGenerator_stdin(self):
plats = platform.platform()
if 'Linux' not in plats:
print("skip pipecommand UT on MacOS/Win")
return
with open("test_queue_dataset_run_a.txt", "w") as f:
data = "2 1 2\n"
data += "2 6 2\n"
data += "2 5 2\n"
data += "2 7 2\n"
f.write(data)
tmp = os.popen(
"cat test_queue_dataset_run_a.txt | python my_data_generator.py"
).readlines()
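# each input line "2 <x> 2" serializes to "1 2 1 <x> 1 2": three slots, each "<ids_num> <id>"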
expected_res = [
'1 2 1 1 1 2\n', '1 2 1 6 1 2\n', '1 2 1 5 1 2\n', '1 2 1 7 1 2\n'
]
self.assertEqual(tmp, expected_res)
os.remove("./test_queue_dataset_run_a.txt")
class TestMultiSlotDataGenerator_error(unittest.TestCase):
def test_MultiSlotDataGenerator_error(self):
with self.assertRaises(ValueError):
my_ms_dg = MyMultiSlotDataGenerator_error()
my_ms_dg.set_batch(1)
my_ms_dg.run_from_memory()
class TestMultiSlotDataGenerator_error_2(unittest.TestCase):
def test_MultiSlotDataGenerator_error(self):
with self.assertRaises(ValueError):
my_ms_dg = MyMultiSlotDataGenerator_error_2()
my_ms_dg.set_batch(1)
my_ms_dg.run_from_memory()
class TestMultiSlotDataGenerator_error_3(unittest.TestCase):
def test_MultiSlotDataGenerator_error(self):
with self.assertRaises(ValueError):
my_ms_dg = MyMultiSlotDataGenerator_error_3()
my_ms_dg.set_batch(1)
my_ms_dg.run_from_memory()
class TestMultiSlotDataGenerator_error_4(unittest.TestCase):
def test_MultiSlotDataGenerator_error(self):
with self.assertRaises(ValueError):
my_ms_dg = MyMultiSlotDataGenerator_error_4()
my_ms_dg.set_batch(1)
my_ms_dg.run_from_memory()
class TestMultiSlotDataGenerator_error_5(unittest.TestCase):
def test_MultiSlotDataGenerator_error(self):
with self.assertRaises(ValueError):
my_ms_dg = MyMultiSlotDataGenerator_error_5()
my_ms_dg.set_batch(1)
my_ms_dg.run_from_memory()
if __name__ == '__main__':
unittest.main()
@@ -105,11 +105,15 @@ class TestDataset(unittest.TestCase):
dataset.load_into_memory()
dataset.local_shuffle()

paddle.enable_static()

exe = paddle.static.Executor(paddle.CPUPlace())
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
exe.run(startup_program)
for i in range(2):
try:
exe.train_from_dataset(main_program, dataset)
except ImportError as e:
pass
except Exception as e:
@@ -181,20 +185,24 @@ class TestDataset(unittest.TestCase):
use_var=slots_vars)
dataset.set_filelist([filename1, filename2])
dataset.load_into_memory()

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
exe.run(startup_program)
if self.use_data_loader:
data_loader = fluid.io.DataLoader.from_dataset(dataset,
fluid.cpu_places(),
self.drop_last)
for i in range(self.epoch_num):
for data in data_loader():
exe.run(main_program, feed=data)
else:
for i in range(self.epoch_num):
try:
exe.train_from_dataset(main_program, dataset)
except Exception as e:
self.assertTrue(False)
...
@@ -150,6 +150,7 @@ packages=['paddle',
'paddle.distributed.fleet.meta_optimizers',
'paddle.distributed.fleet.runtime',
'paddle.distributed.fleet.dataset',
'paddle.distributed.fleet.data_generator',
'paddle.distributed.fleet.metrics',
'paddle.distributed.fleet.proto',
'paddle.distributed.fleet.utils',
...