未验证 提交 78014059 编写于 作者: Y yaoxuefeng 提交者: GitHub

【paddle.distributed.fleet】add data_generator in distributed.fleet.dataset (#27345)

上级 aac57159
......@@ -31,8 +31,13 @@ __all__ = ["spawn"]
# dygraph parallel apis
__all__ += [
"init_parallel_env", "get_rank", "get_world_size", "prepare_context",
"ParallelEnv", "InMemoryDataset", "QueueDataset"
"init_parallel_env",
"get_rank",
"get_world_size",
"prepare_context",
"ParallelEnv",
"InMemoryDataset",
"QueueDataset",
]
# collective apis
......
......@@ -18,6 +18,7 @@ from .base.distributed_strategy import DistributedStrategy
from .base.fleet_base import Fleet
from .base.util_factory import UtilBase
from .dataset import *
from .data_generator import MultiSlotDataGenerator, MultiSlotStringDataGenerator
#from . import metrics
__all__ = [
......@@ -26,6 +27,8 @@ __all__ = [
"UserDefinedRoleMaker",
"PaddleCloudRoleMaker",
"Fleet",
"MultiSlotDataGenerator",
"MultiSlotStringDataGenerator",
"Role",
]
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
from .data_generator import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
class DataGenerator(object):
"""
DataGenerator is a general Base class for user to inherit
A user who wants to define his/her own python processing logic
with paddle.distributed.InMemoryDataset/QueueDataset should
inherit this class.
"""
def __init__(self):
self._proto_info = None
self.batch_size_ = 32
def set_batch(self, batch_size):
'''
Set batch size of current DataGenerator
This is necessary only if a user wants to define generator_batch
Example:
.. code-block:: python
import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator):
def generate_sample(self, line):
def local_iter():
int_words = [int(x) for x in line.split()]
yield ("words", int_words)
return local_iter
def generate_batch(self, samples):
def local_iter():
for s in samples:
yield ("words", s[1].extend([s[1][0]]))
mydata = MyData()
mydata.set_batch(128)
'''
self.batch_size_ = batch_size
def run_from_memory(self):
'''
This function generator data from memory, it is usually used for
debug and benchmarking
Example:
.. code-block:: python
import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator):
def generate_sample(self, line):
def local_iter():
yield ("words", [1, 2, 3, 4])
return local_iter
mydata = MyData()
mydata.run_from_memory()
'''
batch_samples = []
line_iter = self.generate_sample(None)
for user_parsed_line in line_iter():
if user_parsed_line == None:
continue
batch_samples.append(user_parsed_line)
if len(batch_samples) == self.batch_size_:
batch_iter = self.generate_batch(batch_samples)
for sample in batch_iter():
sys.stdout.write(self._gen_str(sample))
batch_samples = []
if len(batch_samples) > 0:
batch_iter = self.generate_batch(batch_samples)
for sample in batch_iter():
sys.stdout.write(self._gen_str(sample))
def run_from_stdin(self):
'''
This function reads the data row from stdin, parses it with the
process function, and further parses the return value of the
process function with the _gen_str function. The parsed data will
be wrote to stdout and the corresponding protofile will be
generated.
Example:
.. code-block:: python
import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator):
def generate_sample(self, line):
def local_iter():
int_words = [int(x) for x in line.split()]
yield ("words", [int_words])
return local_iter
mydata = MyData()
mydata.run_from_stdin()
'''
batch_samples = []
for line in sys.stdin:
line_iter = self.generate_sample(line)
for user_parsed_line in line_iter():
if user_parsed_line == None:
continue
batch_samples.append(user_parsed_line)
if len(batch_samples) == self.batch_size_:
batch_iter = self.generate_batch(batch_samples)
for sample in batch_iter():
sys.stdout.write(self._gen_str(sample))
batch_samples = []
if len(batch_samples) > 0:
batch_iter = self.generate_batch(batch_samples)
for sample in batch_iter():
sys.stdout.write(self._gen_str(sample))
def _gen_str(self, line):
'''
Further processing the output of the process() function rewritten by
user, outputting data that can be directly read by the datafeed,and
updating proto_info information.
Args:
line(str): the output of the process() function rewritten by user.
Returns:
Return a string data that can be read directly by the datafeed.
'''
raise NotImplementedError(
"pls use MultiSlotDataGenerator or PairWiseDataGenerator")
def generate_sample(self, line):
'''
This function needs to be overridden by the user to process the
original data row into a list or tuple.
Args:
line(str): the original data row
Returns:
Returns the data processed by the user.
The data format is list or tuple:
[(name, [feasign, ...]), ...]
or ((name, [feasign, ...]), ...)
For example:
[("words", [1926, 08, 17]), ("label", [1])]
or (("words", [1926, 08, 17]), ("label", [1]))
Note:
The type of feasigns must be in int or float. Once the float
element appears in the feasign, the type of that slot will be
processed into a float.
Example:
.. code-block:: python
import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator):
def generate_sample(self, line):
def local_iter():
int_words = [int(x) for x in line.split()]
yield ("words", [int_words])
return local_iter
'''
raise NotImplementedError(
"Please rewrite this function to return a list or tuple: " +
"[(name, [feasign, ...]), ...] or ((name, [feasign, ...]), ...)")
def generate_batch(self, samples):
'''
This function needs to be overridden by the user to process the
generated samples from generate_sample(self, str) function
It is usually used as batch processing when a user wants to
do preprocessing on a batch of samples, e.g. padding according to
the max length of a sample in the batch
Args:
samples(list tuple): generated sample from generate_sample
Returns:
a python generator, the same format as return value of generate_sample
Example:
.. code-block:: python
import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator):
def generate_sample(self, line):
def local_iter():
int_words = [int(x) for x in line.split()]
yield ("words", int_words)
return local_iter
def generate_batch(self, samples):
def local_iter():
for s in samples:
yield ("words", s[1].extend([s[1][0]]))
mydata = MyData()
mydata.set_batch(128)
'''
def local_iter():
for sample in samples:
yield sample
return local_iter
# TODO: guru4elephant
# add more generalized DataGenerator that can adapt user-defined slot
# for example, [(name, float_list), (name, str_list), (name, int_list)]
class MultiSlotStringDataGenerator(DataGenerator):
def _gen_str(self, line):
'''
Further processing the output of the process() function rewritten by
user, outputting data that can be directly read by the MultiSlotDataFeed,
and updating proto_info information.
The input line will be in this format:
>>> [(name, [str(feasign), ...]), ...]
>>> or ((name, [str(feasign), ...]), ...)
The output will be in this format:
>>> [ids_num id1 id2 ...] ...
For example, if the input is like this:
>>> [("words", ["1926", "08", "17"]), ("label", ["1"])]
>>> or (("words", ["1926", "08", "17"]), ("label", ["1"]))
the output will be:
>>> 3 1234 2345 3456 1 1
Args:
line(str): the output of the process() function rewritten by user.
Returns:
Return a string data that can be read directly by the MultiSlotDataFeed.
'''
if not isinstance(line, list) and not isinstance(line, tuple):
raise ValueError(
"the output of process() must be in list or tuple type"
"Examples: [('words', ['1926', '08', '17']), ('label', ['1'])]")
output = ""
for index, item in enumerate(line):
name, elements = item
if output:
output += " "
out_str = []
out_str.append(str(len(elements)))
out_str.extend(elements)
output += " ".join(out_str)
return output + "\n"
class MultiSlotDataGenerator(DataGenerator):
def _gen_str(self, line):
'''
Further processing the output of the process() function rewritten by
user, outputting data that can be directly read by the MultiSlotDataFeed,
and updating proto_info information.
The input line will be in this format:
>>> [(name, [feasign, ...]), ...]
>>> or ((name, [feasign, ...]), ...)
The output will be in this format:
>>> [ids_num id1 id2 ...] ...
The proto_info will be in this format:
>>> [(name, type), ...]
For example, if the input is like this:
>>> [("words", [1926, 08, 17]), ("label", [1])]
>>> or (("words", [1926, 08, 17]), ("label", [1]))
the output will be:
>>> 3 1234 2345 3456 1 1
the proto_info will be:
>>> [("words", "uint64"), ("label", "uint64")]
Args:
line(str): the output of the process() function rewritten by user.
Returns:
Return a string data that can be read directly by the MultiSlotDataFeed.
'''
if not isinstance(line, list) and not isinstance(line, tuple):
raise ValueError(
"the output of process() must be in list or tuple type"
"Example: [('words', [1926, 08, 17]), ('label', [1])]")
output = ""
if self._proto_info is None:
self._proto_info = []
for item in line:
name, elements = item
if not isinstance(name, str):
raise ValueError("name%s must be in str type" % type(name))
if not isinstance(elements, list):
raise ValueError("elements%s must be in list type" %
type(elements))
if not elements:
raise ValueError(
"the elements of each field can not be empty, you need padding it in process()."
)
self._proto_info.append((name, "uint64"))
if output:
output += " "
output += str(len(elements))
for elem in elements:
if isinstance(elem, float):
self._proto_info[-1] = (name, "float")
elif not isinstance(elem, int) and not isinstance(elem,
long):
raise ValueError(
"the type of element%s must be in int or float" %
type(elem))
output += " " + str(elem)
else:
if len(line) != len(self._proto_info):
raise ValueError(
"the complete field set of two given line are inconsistent.")
for index, item in enumerate(line):
name, elements = item
if not isinstance(name, str):
raise ValueError("name%s must be in str type" % type(name))
if not isinstance(elements, list):
raise ValueError("elements%s must be in list type" %
type(elements))
if not elements:
raise ValueError(
"the elements of each field can not be empty, you need padding it in process()."
)
if name != self._proto_info[index][0]:
raise ValueError(
"the field name of two given line are not match: require<%s>, get<%s>."
% (self._proto_info[index][0], name))
if output:
output += " "
output += str(len(elements))
for elem in elements:
if self._proto_info[index][1] != "float":
if isinstance(elem, float):
self._proto_info[index] = (name, "float")
elif not isinstance(elem, int) and not isinstance(elem,
long):
raise ValueError(
"the type of element%s must be in int or float"
% type(elem))
output += " " + str(elem)
return output + "\n"
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
import paddle
import paddle.distributed.fleet as fleet
class SyntheticData(fleet.MultiSlotDataGenerator):
def generate_sample(self, line):
def data_iter():
for i in range(10000):
yield ("words", [1, 2, 3, 4]), ("label", [0])
return data_iter
class SyntheticStringData(fleet.MultiSlotStringDataGenerator):
def generate_sample(self, line):
def data_iter():
for i in range(10000):
yield [("words", ["1", "2", "3", "4"]), ("label", ["0"])]
return data_iter
sd = SyntheticData()
sd.run_from_memory()
sd2 = SyntheticStringData()
sd2.run_from_memory()
......@@ -119,7 +119,7 @@ class DatasetBase(object):
def set_filelist(self, filelist):
"""
Set file list in current worker.
Set file list in current worker. The filelist is indicated by a list of file names (string).
Examples:
.. code-block:: python
......@@ -129,7 +129,7 @@ class DatasetBase(object):
dataset.set_filelist(['a.txt', 'b.txt'])
Args:
filelist(list): file list
filelist(list[str]): list of file names of inputs.
"""
self.dataset.set_filelist(filelist)
self.filelist = filelist
......@@ -240,6 +240,8 @@ class DatasetBase(object):
class InMemoryDataset(DatasetBase):
"""
:api_attr: Static Graph
InMemoryDataset, it will load data into memory
and shuffle data before training.
......@@ -265,6 +267,8 @@ class InMemoryDataset(DatasetBase):
def _init_distributed_settings(self, **kwargs):
"""
:api_attr: Static Graph
should be called only once in user's python scripts to initialize distributed-related setings of dataset instance
Args:
kwargs: Keyword arguments. Currently, we support following keys in **kwargs:
......@@ -323,6 +327,8 @@ class InMemoryDataset(DatasetBase):
def update_settings(self, **kwargs):
"""
:api_attr: Static Graph
should be called in user's python scripts to update setings of dataset instance
Args:
kwargs: Keyword arguments. Currently, we support following keys in **kwargs,
......@@ -400,6 +406,8 @@ class InMemoryDataset(DatasetBase):
def init(self, **kwargs):
"""
:api_attr: Static Graph
should be called only once in user's python scripts to initialize setings of dataset instance
Args:
kwargs: Keyword arguments. Currently, we support following keys in **kwargs:
......@@ -450,11 +458,16 @@ class InMemoryDataset(DatasetBase):
["test_queue_dataset_run_a.txt", "test_queue_dataset_run_b.txt"])
dataset.load_into_memory()
exe = fluid.Executor(fluid.CPUPlace() if not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0))
exe.run(fluid.default_startup_program())
exe.train_from_dataset(fluid.default_main_program(),
dataset)
paddle.enable_static()
place = paddle.CUDAPlace(0) if paddle.fluid.core.is_compiled_with_cuda() else paddle.CPUPlace()
exe = paddle.static.Executor(place)
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
exe.run(startup_program)
exe.train_from_dataset(main_program, dataset)
os.remove("./test_queue_dataset_run_a.txt")
os.remove("./test_queue_dataset_run_b.txt")
"""
......@@ -639,6 +652,8 @@ class InMemoryDataset(DatasetBase):
def load_into_memory(self):
"""
:api_attr: Static Graph
Load data into memory
Examples:
......@@ -655,6 +670,8 @@ class InMemoryDataset(DatasetBase):
def preload_into_memory(self, thread_num=None):
"""
:api_attr: Static Graph
Load data into memory in async mode
Args:
......@@ -679,6 +696,8 @@ class InMemoryDataset(DatasetBase):
def wait_preload_done(self):
"""
:api_attr: Static Graph
Wait preload_into_memory done
Examples:
......@@ -696,6 +715,8 @@ class InMemoryDataset(DatasetBase):
def local_shuffle(self):
"""
:api_attr: Static Graph
Local shuffle
Examples:
......@@ -712,6 +733,8 @@ class InMemoryDataset(DatasetBase):
def global_shuffle(self, fleet=None, thread_num=12):
"""
:api_attr: Static Graph
Global shuffle.
Global shuffle can be used only in distributed mode. i.e. multiple
processes on single machine or multiple machines training together.
......@@ -771,9 +794,11 @@ class InMemoryDataset(DatasetBase):
dataset.set_filelist(filelist)
dataset.load_into_memory()
dataset.global_shuffle(fleet)
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
exe.train_from_dataset(fluid.default_main_program(), dataset)
exe = paddle.static.Executor(paddle.CPUPlace())
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
exe.run(startup_program)
exe.train_from_dataset(main_program, dataset)
dataset.release_memory()
"""
......@@ -781,6 +806,8 @@ class InMemoryDataset(DatasetBase):
def get_memory_data_size(self, fleet=None):
"""
:api_attr: Static Graph
Get memory data size, user can call this function to know the num
of ins in all workers after load into memory.
......@@ -817,6 +844,8 @@ class InMemoryDataset(DatasetBase):
def get_shuffle_data_size(self, fleet=None):
"""
:api_attr: Static Graph
Get shuffle data size, user can call this function to know the num
of ins in all workers after local/global shuffle.
......@@ -901,6 +930,8 @@ class InMemoryDataset(DatasetBase):
class QueueDataset(DatasetBase):
"""
:api_attr: Static Graph
QueueDataset, it will process data streamly.
Examples:
......@@ -920,6 +951,8 @@ class QueueDataset(DatasetBase):
def init(self, **kwargs):
"""
:api_attr: Static Graph
should be called only once in user's python scripts to initialize setings of dataset instance
"""
super(QueueDataset, self).init(**kwargs)
......
......@@ -16,6 +16,7 @@
from paddle.fluid.proto import data_feed_pb2
from google.protobuf import text_format
from . import core
from ..utils import deprecated
__all__ = ['DatasetFactory', 'InMemoryDataset', 'QueueDataset']
......@@ -335,6 +336,7 @@ class InMemoryDataset(DatasetBase):
dataset = paddle.fluid.DatasetFactory().create_dataset("InMemoryDataset")
"""
@deprecated(since="2.0.0", update_to="paddle.distributed.InMemoryDataset")
def __init__(self):
""" Init. """
super(InMemoryDataset, self).__init__()
......@@ -350,12 +352,18 @@ class InMemoryDataset(DatasetBase):
self.merge_by_lineid = False
self.fleet_send_sleep_seconds = None
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._set_feed_type")
def set_feed_type(self, data_feed_type):
"""
Set data_feed_desc
"""
self.proto_desc.name = data_feed_type
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._prepare_to_run")
def _prepare_to_run(self):
"""
Set data_feed_desc before load or shuffle,
......@@ -376,16 +384,27 @@ class InMemoryDataset(DatasetBase):
self.dataset.create_channel()
self.dataset.create_readers()
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._dynamic_adjust_before_train"
)
def _dynamic_adjust_before_train(self, thread_num):
if not self.is_user_set_queue_num:
self.dataset.dynamic_adjust_channel_num(thread_num, False)
self.dataset.dynamic_adjust_readers_num(thread_num)
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._dynamic_adjust_after_train"
)
def _dynamic_adjust_after_train(self):
if not self.is_user_set_queue_num:
self.dataset.dynamic_adjust_channel_num(self.thread_num, False)
self.dataset.dynamic_adjust_readers_num(self.thread_num)
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._set_queue_num")
def set_queue_num(self, queue_num):
"""
Set Dataset output queue num, training threads get data from queues
......@@ -404,6 +423,9 @@ class InMemoryDataset(DatasetBase):
self.is_user_set_queue_num = True
self.queue_num = queue_num
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._set_parse_ins_id")
def set_parse_ins_id(self, parse_ins_id):
"""
Set id Dataset need to parse insid
......@@ -421,6 +443,9 @@ class InMemoryDataset(DatasetBase):
"""
self.parse_ins_id = parse_ins_id
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._set_parse_content")
def set_parse_content(self, parse_content):
"""
Set if Dataset need to parse content
......@@ -455,6 +480,9 @@ class InMemoryDataset(DatasetBase):
"""
self.parse_logkey = parse_logkey
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._set_merge_by_sid")
def set_merge_by_sid(self, merge_by_sid):
"""
Set if Dataset need to merge sid. If not, one ins means one Pv.
......@@ -544,6 +572,10 @@ class InMemoryDataset(DatasetBase):
"""
self.dataset.postprocess_instance()
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._set_fleet_send_batch_size"
)
def set_fleet_send_batch_size(self, fleet_send_batch_size=1024):
"""
Set fleet send batch size, default is 1024
......@@ -561,6 +593,10 @@ class InMemoryDataset(DatasetBase):
"""
self.fleet_send_batch_size = fleet_send_batch_size
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._set_fleet_send_sleep_seconds"
)
def set_fleet_send_sleep_seconds(self, fleet_send_sleep_seconds=0):
"""
Set fleet send sleep time, default is 0
......@@ -578,6 +614,9 @@ class InMemoryDataset(DatasetBase):
"""
self.fleet_send_sleep_seconds = fleet_send_sleep_seconds
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._set_merge_by_lineid")
def set_merge_by_lineid(self, merge_size=2):
"""
Set merge by line id, instances of same line id will be merged after
......@@ -598,16 +637,27 @@ class InMemoryDataset(DatasetBase):
self.merge_by_lineid = True
self.parse_ins_id = True
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._set_generate_unique_feasigns"
)
def set_generate_unique_feasigns(self, generate_uni_feasigns, shard_num):
self.dataset.set_generate_unique_feasigns(generate_uni_feasigns)
self.gen_uni_feasigns = generate_uni_feasigns
self.local_shard_num = shard_num
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._generate_local_tables_unlock"
)
def generate_local_tables_unlock(self, table_id, fea_dim, read_thread_num,
consume_thread_num, shard_num):
self.dataset.generate_local_tables_unlock(
table_id, fea_dim, read_thread_num, consume_thread_num, shard_num)
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset.load_into_memory")
def load_into_memory(self):
"""
Load data into memory
......@@ -624,6 +674,9 @@ class InMemoryDataset(DatasetBase):
self._prepare_to_run()
self.dataset.load_into_memory()
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset.preload_into_memory")
def preload_into_memory(self, thread_num=None):
"""
Load data into memory in async mode
......@@ -648,6 +701,9 @@ class InMemoryDataset(DatasetBase):
self.dataset.create_preload_readers()
self.dataset.preload_into_memory()
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset.wait_preload_done")
def wait_preload_done(self):
"""
Wait preload_into_memory done
......@@ -665,6 +721,9 @@ class InMemoryDataset(DatasetBase):
self.dataset.wait_preload_done()
self.dataset.destroy_preload_readers()
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset.local_shuffle")
def local_shuffle(self):
"""
Local shuffle
......@@ -681,6 +740,9 @@ class InMemoryDataset(DatasetBase):
"""
self.dataset.local_shuffle()
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset.global_shuffle")
def global_shuffle(self, fleet=None, thread_num=12):
"""
Global shuffle.
......@@ -726,6 +788,9 @@ class InMemoryDataset(DatasetBase):
if fleet is not None:
fleet._role_maker.barrier_worker()
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset.release_memory")
def release_memory(self):
"""
:api_attr: Static Graph
......@@ -774,6 +839,9 @@ class InMemoryDataset(DatasetBase):
"""
return self.dataset.get_pv_data_size()
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset.get_memory_data_size")
def get_memory_data_size(self, fleet=None):
"""
Get memory data size, user can call this function to know the num
......@@ -810,6 +878,9 @@ class InMemoryDataset(DatasetBase):
return global_data_size[0]
return local_data_size[0]
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset.get_shuffle_data_size")
def get_shuffle_data_size(self, fleet=None):
"""
Get shuffle data size, user can call this function to know the num
......@@ -869,6 +940,9 @@ class QueueDataset(DatasetBase):
super(QueueDataset, self).__init__()
self.proto_desc.name = "MultiSlotDataFeed"
@deprecated(
since="2.0.0",
update_to="paddle.distributed.QueueDataset._prepare_to_run")
def _prepare_to_run(self):
"""
Set data_feed_desc/thread num/filelist before run,
......
......@@ -19,7 +19,7 @@ import tarfile
import os
import paddle
import paddle.fluid.incubate.data_generator as data_generator
import paddle.distributed.fleet as fleet
from paddle.fluid.log_helper import get_logger
logger = get_logger(
......@@ -59,7 +59,7 @@ def load_lr_input_record(sent):
return res
class DatasetCtrReader(data_generator.MultiSlotDataGenerator):
class DatasetCtrReader(fleet.MultiSlotDataGenerator):
def generate_sample(self, line):
def iter():
fs = line.strip().split('\t')
......
......@@ -22,7 +22,7 @@ import random
import warnings
import paddle
import paddle.fluid.incubate.data_generator as data_generator
import paddle.distributed.fleet as fleet
logging.basicConfig()
logger = logging.getLogger("paddle")
......@@ -84,7 +84,7 @@ class CtrReader(object):
return reader
class DatasetCtrReader(data_generator.MultiSlotDataGenerator):
class DatasetCtrReader(fleet.MultiSlotDataGenerator):
def generate_sample(self, line):
def get_rand(low=0.0, high=1.0):
return random.random()
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import paddle
import re
import collections
import time
import paddle.distributed.fleet as fleet
class MyDataset(fleet.MultiSlotDataGenerator):
def generate_sample(self, line):
def data_iter():
elements = line.strip().split()[0:]
output = [("show", [int(elements[0])]),
("click", [int(elements[1])]),
("slot1", [int(elements[2])])]
yield output
return data_iter
if __name__ == "__main__":
d = MyDataset()
d.run_from_stdin()
......@@ -21,13 +21,13 @@ import tarfile
import random
import paddle
import paddle.fluid.incubate.data_generator as data_generator
import paddle.distributed.fleet as fleet
logging.basicConfig()
logger = logging.getLogger("paddle")
logger.setLevel(logging.INFO)
class DatasetSimnetReader(data_generator.MultiSlotDataGenerator):
class DatasetSimnetReader(fleet.MultiSlotDataGenerator):
def generate_sample(self, line):
pass
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
import paddle
import unittest
import paddle.distributed.fleet as fleet
import os
import sys
import platform
class MyMultiSlotDataGenerator(fleet.MultiSlotDataGenerator):
def generate_sample(self, line):
def data_iter():
for i in range(40):
if i == 1:
yield None
yield ("words", [1, 2, 3, 4]), ("label", [0])
return data_iter
class MyMultiSlotStringDataGenerator(fleet.MultiSlotStringDataGenerator):
def generate_sample(self, line):
def data_iter():
for i in range(40):
if i == 1:
yield None
yield ("words", ["1", "2", "3", "4"]), ("label", ["0"])
return data_iter
class MyMultiSlotDataGenerator_error(fleet.MultiSlotDataGenerator):
def generate_sample(self, line):
def data_iter():
for i in range(40):
if i == 1:
yield None
yield "words"
return data_iter
class MyMultiSlotDataGenerator_error_2(fleet.MultiSlotStringDataGenerator):
def generate_sample(self, line):
def data_iter():
for i in range(40):
if i == 1:
yield None
yield "words"
return data_iter
class MyMultiSlotDataGenerator_error_3(fleet.MultiSlotDataGenerator):
def generate_sample(self, line):
def data_iter():
for i in range(40):
if i == 1:
yield None
yield (1, ["1", "2", "3", "4"]), (2, ["0"])
return data_iter
class MyMultiSlotDataGenerator_error_4(fleet.MultiSlotDataGenerator):
def generate_sample(self, line):
def data_iter():
for i in range(40):
if i == 1:
yield None
yield ("words", "1"), ("label", "0")
return data_iter
class MyMultiSlotDataGenerator_error_5(fleet.MultiSlotDataGenerator):
def generate_sample(self, line):
def data_iter():
for i in range(40):
if i == 1:
yield None
yield ("words", []), ("label", [])
return data_iter
class TestMultiSlotDataGenerator(unittest.TestCase):
def test_MultiSlotDataGenerator_basic(self):
my_ms_dg = MyMultiSlotDataGenerator()
my_ms_dg.set_batch(1)
my_ms_dg.run_from_memory()
class TestMultiSlotStringDataGenerator(unittest.TestCase):
def test_MyMultiSlotStringDataGenerator_basic(self):
my_ms_dg = MyMultiSlotStringDataGenerator()
my_ms_dg.set_batch(1)
my_ms_dg.run_from_memory()
class TestMultiSlotStringDataGenerator_2(unittest.TestCase):
def test_MyMultiSlotStringDataGenerator_stdin(self):
plats = platform.platform()
if 'Linux' not in plats:
print("skip pipecommand UT on MacOS/Win")
return
with open("test_queue_dataset_run_a.txt", "w") as f:
data = "2 1 2\n"
data += "2 6 2\n"
data += "2 5 2\n"
data += "2 7 2\n"
f.write(data)
tmp = os.popen(
"cat test_queue_dataset_run_a.txt | python my_data_generator.py"
).readlines()
expected_res = [
'1 2 1 1 1 2\n', '1 2 1 6 1 2\n', '1 2 1 5 1 2\n', '1 2 1 7 1 2\n'
]
self.assertEqual(tmp, expected_res)
os.remove("./test_queue_dataset_run_a.txt")
class TestMultiSlotDataGenerator_error(unittest.TestCase):
def test_MultiSlotDataGenerator_error(self):
with self.assertRaises(ValueError):
my_ms_dg = MyMultiSlotDataGenerator_error()
my_ms_dg.set_batch(1)
my_ms_dg.run_from_memory()
class TestMultiSlotDataGenerator_error_2(unittest.TestCase):
def test_MultiSlotDataGenerator_error(self):
with self.assertRaises(ValueError):
my_ms_dg = MyMultiSlotDataGenerator_error_2()
my_ms_dg.set_batch(1)
my_ms_dg.run_from_memory()
class TestMultiSlotDataGenerator_error_3(unittest.TestCase):
def test_MultiSlotDataGenerator_error(self):
with self.assertRaises(ValueError):
my_ms_dg = MyMultiSlotDataGenerator_error_3()
my_ms_dg.set_batch(1)
my_ms_dg.run_from_memory()
class TestMultiSlotDataGenerator_error_4(unittest.TestCase):
def test_MultiSlotDataGenerator_error(self):
with self.assertRaises(ValueError):
my_ms_dg = MyMultiSlotDataGenerator_error_4()
my_ms_dg.set_batch(1)
my_ms_dg.run_from_memory()
class TestMultiSlotDataGenerator_error_5(unittest.TestCase):
def test_MultiSlotDataGenerator_error(self):
with self.assertRaises(ValueError):
my_ms_dg = MyMultiSlotDataGenerator_error_5()
my_ms_dg.set_batch(1)
my_ms_dg.run_from_memory()
if __name__ == '__main__':
unittest.main()
......@@ -105,11 +105,15 @@ class TestDataset(unittest.TestCase):
dataset.load_into_memory()
dataset.local_shuffle()
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
exe.run(startup_program)
for i in range(2):
try:
exe.train_from_dataset(fluid.default_main_program(), dataset)
exe.train_from_dataset(main_program, dataset)
except ImportError as e:
pass
except Exception as e:
......@@ -181,20 +185,24 @@ class TestDataset(unittest.TestCase):
use_var=slots_vars)
dataset.set_filelist([filename1, filename2])
dataset.load_into_memory()
paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
exe.run(startup_program)
if self.use_data_loader:
data_loader = fluid.io.DataLoader.from_dataset(dataset,
fluid.cpu_places(),
self.drop_last)
for i in range(self.epoch_num):
for data in data_loader():
exe.run(fluid.default_main_program(), feed=data)
exe.run(main_program, feed=data)
else:
for i in range(self.epoch_num):
try:
exe.train_from_dataset(fluid.default_main_program(),
dataset)
exe.train_from_dataset(main_program, dataset)
except Exception as e:
self.assertTrue(False)
......
......@@ -150,6 +150,7 @@ packages=['paddle',
'paddle.distributed.fleet.meta_optimizers',
'paddle.distributed.fleet.runtime',
'paddle.distributed.fleet.dataset',
'paddle.distributed.fleet.data_generator',
'paddle.distributed.fleet.metrics',
'paddle.distributed.fleet.proto',
'paddle.distributed.fleet.utils',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册