Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
78014059
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
78014059
编写于
9月 28, 2020
作者:
Y
yaoxuefeng
提交者:
GitHub
9月 28, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
【paddle.distributed.fleet】add data_generator in distributed.fleet.dataset (#27345)
上级
aac57159
变更
14
显示空白变更内容
内联
并排
Showing
14 changed file
with
782 addition
and
25 deletion
+782
-25
python/paddle/distributed/__init__.py
python/paddle/distributed/__init__.py
+7
-2
python/paddle/distributed/fleet/__init__.py
python/paddle/distributed/fleet/__init__.py
+3
-0
python/paddle/distributed/fleet/data_generator/__init__.py
python/paddle/distributed/fleet/data_generator/__init__.py
+14
-0
python/paddle/distributed/fleet/data_generator/data_generator.py
...paddle/distributed/fleet/data_generator/data_generator.py
+366
-0
python/paddle/distributed/fleet/data_generator/test_data_generator.py
...e/distributed/fleet/data_generator/test_data_generator.py
+39
-0
python/paddle/distributed/fleet/dataset/dataset.py
python/paddle/distributed/fleet/dataset/dataset.py
+43
-10
python/paddle/fluid/dataset.py
python/paddle/fluid/dataset.py
+74
-0
python/paddle/fluid/incubate/fleet/tests/ctr_dataset_reader.py
...n/paddle/fluid/incubate/fleet/tests/ctr_dataset_reader.py
+2
-2
python/paddle/fluid/tests/unittests/ctr_dataset_reader.py
python/paddle/fluid/tests/unittests/ctr_dataset_reader.py
+2
-2
python/paddle/fluid/tests/unittests/my_data_generator.py
python/paddle/fluid/tests/unittests/my_data_generator.py
+38
-0
python/paddle/fluid/tests/unittests/simnet_dataset_reader.py
python/paddle/fluid/tests/unittests/simnet_dataset_reader.py
+2
-2
python/paddle/fluid/tests/unittests/test_data_generator.py
python/paddle/fluid/tests/unittests/test_data_generator.py
+176
-0
python/paddle/fluid/tests/unittests/test_dataset.py
python/paddle/fluid/tests/unittests/test_dataset.py
+15
-7
python/setup.py.in
python/setup.py.in
+1
-0
未找到文件。
python/paddle/distributed/__init__.py
浏览文件 @
78014059
...
...
@@ -31,8 +31,13 @@ __all__ = ["spawn"]
# dygraph parallel apis
__all__
+=
[
"init_parallel_env"
,
"get_rank"
,
"get_world_size"
,
"prepare_context"
,
"ParallelEnv"
,
"InMemoryDataset"
,
"QueueDataset"
"init_parallel_env"
,
"get_rank"
,
"get_world_size"
,
"prepare_context"
,
"ParallelEnv"
,
"InMemoryDataset"
,
"QueueDataset"
,
]
# collective apis
...
...
python/paddle/distributed/fleet/__init__.py
浏览文件 @
78014059
...
...
@@ -18,6 +18,7 @@ from .base.distributed_strategy import DistributedStrategy
from
.base.fleet_base
import
Fleet
from
.base.util_factory
import
UtilBase
from
.dataset
import
*
from
.data_generator
import
MultiSlotDataGenerator
,
MultiSlotStringDataGenerator
#from . import metrics
__all__
=
[
...
...
@@ -26,6 +27,8 @@ __all__ = [
"UserDefinedRoleMaker"
,
"PaddleCloudRoleMaker"
,
"Fleet"
,
"MultiSlotDataGenerator"
,
"MultiSlotStringDataGenerator"
,
"Role"
,
]
...
...
python/paddle/distributed/fleet/data_generator/__init__.py
0 → 100644
浏览文件 @
78014059
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
from
.data_generator
import
*
python/paddle/distributed/fleet/data_generator/data_generator.py
0 → 100644
浏览文件 @
78014059
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
class
DataGenerator
(
object
):
"""
DataGenerator is a general Base class for user to inherit
A user who wants to define his/her own python processing logic
with paddle.distributed.InMemoryDataset/QueueDataset should
inherit this class.
"""
def
__init__
(
self
):
self
.
_proto_info
=
None
self
.
batch_size_
=
32
def
set_batch
(
self
,
batch_size
):
'''
Set batch size of current DataGenerator
This is necessary only if a user wants to define generator_batch
Example:
.. code-block:: python
import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator):
def generate_sample(self, line):
def local_iter():
int_words = [int(x) for x in line.split()]
yield ("words", int_words)
return local_iter
def generate_batch(self, samples):
def local_iter():
for s in samples:
yield ("words", s[1].extend([s[1][0]]))
mydata = MyData()
mydata.set_batch(128)
'''
self
.
batch_size_
=
batch_size
def
run_from_memory
(
self
):
'''
This function generator data from memory, it is usually used for
debug and benchmarking
Example:
.. code-block:: python
import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator):
def generate_sample(self, line):
def local_iter():
yield ("words", [1, 2, 3, 4])
return local_iter
mydata = MyData()
mydata.run_from_memory()
'''
batch_samples
=
[]
line_iter
=
self
.
generate_sample
(
None
)
for
user_parsed_line
in
line_iter
():
if
user_parsed_line
==
None
:
continue
batch_samples
.
append
(
user_parsed_line
)
if
len
(
batch_samples
)
==
self
.
batch_size_
:
batch_iter
=
self
.
generate_batch
(
batch_samples
)
for
sample
in
batch_iter
():
sys
.
stdout
.
write
(
self
.
_gen_str
(
sample
))
batch_samples
=
[]
if
len
(
batch_samples
)
>
0
:
batch_iter
=
self
.
generate_batch
(
batch_samples
)
for
sample
in
batch_iter
():
sys
.
stdout
.
write
(
self
.
_gen_str
(
sample
))
def
run_from_stdin
(
self
):
'''
This function reads the data row from stdin, parses it with the
process function, and further parses the return value of the
process function with the _gen_str function. The parsed data will
be wrote to stdout and the corresponding protofile will be
generated.
Example:
.. code-block:: python
import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator):
def generate_sample(self, line):
def local_iter():
int_words = [int(x) for x in line.split()]
yield ("words", [int_words])
return local_iter
mydata = MyData()
mydata.run_from_stdin()
'''
batch_samples
=
[]
for
line
in
sys
.
stdin
:
line_iter
=
self
.
generate_sample
(
line
)
for
user_parsed_line
in
line_iter
():
if
user_parsed_line
==
None
:
continue
batch_samples
.
append
(
user_parsed_line
)
if
len
(
batch_samples
)
==
self
.
batch_size_
:
batch_iter
=
self
.
generate_batch
(
batch_samples
)
for
sample
in
batch_iter
():
sys
.
stdout
.
write
(
self
.
_gen_str
(
sample
))
batch_samples
=
[]
if
len
(
batch_samples
)
>
0
:
batch_iter
=
self
.
generate_batch
(
batch_samples
)
for
sample
in
batch_iter
():
sys
.
stdout
.
write
(
self
.
_gen_str
(
sample
))
def
_gen_str
(
self
,
line
):
'''
Further processing the output of the process() function rewritten by
user, outputting data that can be directly read by the datafeed,and
updating proto_info information.
Args:
line(str): the output of the process() function rewritten by user.
Returns:
Return a string data that can be read directly by the datafeed.
'''
raise
NotImplementedError
(
"pls use MultiSlotDataGenerator or PairWiseDataGenerator"
)
def
generate_sample
(
self
,
line
):
'''
This function needs to be overridden by the user to process the
original data row into a list or tuple.
Args:
line(str): the original data row
Returns:
Returns the data processed by the user.
The data format is list or tuple:
[(name, [feasign, ...]), ...]
or ((name, [feasign, ...]), ...)
For example:
[("words", [1926, 08, 17]), ("label", [1])]
or (("words", [1926, 08, 17]), ("label", [1]))
Note:
The type of feasigns must be in int or float. Once the float
element appears in the feasign, the type of that slot will be
processed into a float.
Example:
.. code-block:: python
import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator):
def generate_sample(self, line):
def local_iter():
int_words = [int(x) for x in line.split()]
yield ("words", [int_words])
return local_iter
'''
raise
NotImplementedError
(
"Please rewrite this function to return a list or tuple: "
+
"[(name, [feasign, ...]), ...] or ((name, [feasign, ...]), ...)"
)
def
generate_batch
(
self
,
samples
):
'''
This function needs to be overridden by the user to process the
generated samples from generate_sample(self, str) function
It is usually used as batch processing when a user wants to
do preprocessing on a batch of samples, e.g. padding according to
the max length of a sample in the batch
Args:
samples(list tuple): generated sample from generate_sample
Returns:
a python generator, the same format as return value of generate_sample
Example:
.. code-block:: python
import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator):
def generate_sample(self, line):
def local_iter():
int_words = [int(x) for x in line.split()]
yield ("words", int_words)
return local_iter
def generate_batch(self, samples):
def local_iter():
for s in samples:
yield ("words", s[1].extend([s[1][0]]))
mydata = MyData()
mydata.set_batch(128)
'''
def
local_iter
():
for
sample
in
samples
:
yield
sample
return
local_iter
# TODO: guru4elephant
# add more generalized DataGenerator that can adapt user-defined slot
# for example, [(name, float_list), (name, str_list), (name, int_list)]
class
MultiSlotStringDataGenerator
(
DataGenerator
):
def
_gen_str
(
self
,
line
):
'''
Further processing the output of the process() function rewritten by
user, outputting data that can be directly read by the MultiSlotDataFeed,
and updating proto_info information.
The input line will be in this format:
>>> [(name, [str(feasign), ...]), ...]
>>> or ((name, [str(feasign), ...]), ...)
The output will be in this format:
>>> [ids_num id1 id2 ...] ...
For example, if the input is like this:
>>> [("words", ["1926", "08", "17"]), ("label", ["1"])]
>>> or (("words", ["1926", "08", "17"]), ("label", ["1"]))
the output will be:
>>> 3 1234 2345 3456 1 1
Args:
line(str): the output of the process() function rewritten by user.
Returns:
Return a string data that can be read directly by the MultiSlotDataFeed.
'''
if
not
isinstance
(
line
,
list
)
and
not
isinstance
(
line
,
tuple
):
raise
ValueError
(
"the output of process() must be in list or tuple type"
"Examples: [('words', ['1926', '08', '17']), ('label', ['1'])]"
)
output
=
""
for
index
,
item
in
enumerate
(
line
):
name
,
elements
=
item
if
output
:
output
+=
" "
out_str
=
[]
out_str
.
append
(
str
(
len
(
elements
)))
out_str
.
extend
(
elements
)
output
+=
" "
.
join
(
out_str
)
return
output
+
"
\n
"
class
MultiSlotDataGenerator
(
DataGenerator
):
def
_gen_str
(
self
,
line
):
'''
Further processing the output of the process() function rewritten by
user, outputting data that can be directly read by the MultiSlotDataFeed,
and updating proto_info information.
The input line will be in this format:
>>> [(name, [feasign, ...]), ...]
>>> or ((name, [feasign, ...]), ...)
The output will be in this format:
>>> [ids_num id1 id2 ...] ...
The proto_info will be in this format:
>>> [(name, type), ...]
For example, if the input is like this:
>>> [("words", [1926, 08, 17]), ("label", [1])]
>>> or (("words", [1926, 08, 17]), ("label", [1]))
the output will be:
>>> 3 1234 2345 3456 1 1
the proto_info will be:
>>> [("words", "uint64"), ("label", "uint64")]
Args:
line(str): the output of the process() function rewritten by user.
Returns:
Return a string data that can be read directly by the MultiSlotDataFeed.
'''
if
not
isinstance
(
line
,
list
)
and
not
isinstance
(
line
,
tuple
):
raise
ValueError
(
"the output of process() must be in list or tuple type"
"Example: [('words', [1926, 08, 17]), ('label', [1])]"
)
output
=
""
if
self
.
_proto_info
is
None
:
self
.
_proto_info
=
[]
for
item
in
line
:
name
,
elements
=
item
if
not
isinstance
(
name
,
str
):
raise
ValueError
(
"name%s must be in str type"
%
type
(
name
))
if
not
isinstance
(
elements
,
list
):
raise
ValueError
(
"elements%s must be in list type"
%
type
(
elements
))
if
not
elements
:
raise
ValueError
(
"the elements of each field can not be empty, you need padding it in process()."
)
self
.
_proto_info
.
append
((
name
,
"uint64"
))
if
output
:
output
+=
" "
output
+=
str
(
len
(
elements
))
for
elem
in
elements
:
if
isinstance
(
elem
,
float
):
self
.
_proto_info
[
-
1
]
=
(
name
,
"float"
)
elif
not
isinstance
(
elem
,
int
)
and
not
isinstance
(
elem
,
long
):
raise
ValueError
(
"the type of element%s must be in int or float"
%
type
(
elem
))
output
+=
" "
+
str
(
elem
)
else
:
if
len
(
line
)
!=
len
(
self
.
_proto_info
):
raise
ValueError
(
"the complete field set of two given line are inconsistent."
)
for
index
,
item
in
enumerate
(
line
):
name
,
elements
=
item
if
not
isinstance
(
name
,
str
):
raise
ValueError
(
"name%s must be in str type"
%
type
(
name
))
if
not
isinstance
(
elements
,
list
):
raise
ValueError
(
"elements%s must be in list type"
%
type
(
elements
))
if
not
elements
:
raise
ValueError
(
"the elements of each field can not be empty, you need padding it in process()."
)
if
name
!=
self
.
_proto_info
[
index
][
0
]:
raise
ValueError
(
"the field name of two given line are not match: require<%s>, get<%s>."
%
(
self
.
_proto_info
[
index
][
0
],
name
))
if
output
:
output
+=
" "
output
+=
str
(
len
(
elements
))
for
elem
in
elements
:
if
self
.
_proto_info
[
index
][
1
]
!=
"float"
:
if
isinstance
(
elem
,
float
):
self
.
_proto_info
[
index
]
=
(
name
,
"float"
)
elif
not
isinstance
(
elem
,
int
)
and
not
isinstance
(
elem
,
long
):
raise
ValueError
(
"the type of element%s must be in int or float"
%
type
(
elem
))
output
+=
" "
+
str
(
elem
)
return
output
+
"
\n
"
python/paddle/distributed/fleet/data_generator/test_data_generator.py
0 → 100644
浏览文件 @
78014059
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
import
paddle
import
paddle.distributed.fleet
as
fleet
class
SyntheticData
(
fleet
.
MultiSlotDataGenerator
):
def
generate_sample
(
self
,
line
):
def
data_iter
():
for
i
in
range
(
10000
):
yield
(
"words"
,
[
1
,
2
,
3
,
4
]),
(
"label"
,
[
0
])
return
data_iter
class
SyntheticStringData
(
fleet
.
MultiSlotStringDataGenerator
):
def
generate_sample
(
self
,
line
):
def
data_iter
():
for
i
in
range
(
10000
):
yield
[(
"words"
,
[
"1"
,
"2"
,
"3"
,
"4"
]),
(
"label"
,
[
"0"
])]
return
data_iter
sd
=
SyntheticData
()
sd
.
run_from_memory
()
sd2
=
SyntheticStringData
()
sd2
.
run_from_memory
()
python/paddle/distributed/fleet/dataset/dataset.py
浏览文件 @
78014059
...
...
@@ -119,7 +119,7 @@ class DatasetBase(object):
def
set_filelist
(
self
,
filelist
):
"""
Set file list in current worker.
Set file list in current worker.
The filelist is indicated by a list of file names (string).
Examples:
.. code-block:: python
...
...
@@ -129,7 +129,7 @@ class DatasetBase(object):
dataset.set_filelist(['a.txt', 'b.txt'])
Args:
filelist(list
): file list
filelist(list
[str]): list of file names of inputs.
"""
self
.
dataset
.
set_filelist
(
filelist
)
self
.
filelist
=
filelist
...
...
@@ -240,6 +240,8 @@ class DatasetBase(object):
class
InMemoryDataset
(
DatasetBase
):
"""
:api_attr: Static Graph
InMemoryDataset, it will load data into memory
and shuffle data before training.
...
...
@@ -265,6 +267,8 @@ class InMemoryDataset(DatasetBase):
def
_init_distributed_settings
(
self
,
**
kwargs
):
"""
:api_attr: Static Graph
should be called only once in user's python scripts to initialize distributed-related setings of dataset instance
Args:
kwargs: Keyword arguments. Currently, we support following keys in **kwargs:
...
...
@@ -323,6 +327,8 @@ class InMemoryDataset(DatasetBase):
def
update_settings
(
self
,
**
kwargs
):
"""
:api_attr: Static Graph
should be called in user's python scripts to update setings of dataset instance
Args:
kwargs: Keyword arguments. Currently, we support following keys in **kwargs,
...
...
@@ -400,6 +406,8 @@ class InMemoryDataset(DatasetBase):
def
init
(
self
,
**
kwargs
):
"""
:api_attr: Static Graph
should be called only once in user's python scripts to initialize setings of dataset instance
Args:
kwargs: Keyword arguments. Currently, we support following keys in **kwargs:
...
...
@@ -450,11 +458,16 @@ class InMemoryDataset(DatasetBase):
["test_queue_dataset_run_a.txt", "test_queue_dataset_run_b.txt"])
dataset.load_into_memory()
exe = fluid.Executor(fluid.CPUPlace() if not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0))
exe.run(fluid.default_startup_program())
exe.train_from_dataset(fluid.default_main_program(),
dataset)
paddle.enable_static()
place = paddle.CUDAPlace(0) if paddle.fluid.core.is_compiled_with_cuda() else paddle.CPUPlace()
exe = paddle.static.Executor(place)
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
exe.run(startup_program)
exe.train_from_dataset(main_program, dataset)
os.remove("./test_queue_dataset_run_a.txt")
os.remove("./test_queue_dataset_run_b.txt")
"""
...
...
@@ -639,6 +652,8 @@ class InMemoryDataset(DatasetBase):
def
load_into_memory
(
self
):
"""
:api_attr: Static Graph
Load data into memory
Examples:
...
...
@@ -655,6 +670,8 @@ class InMemoryDataset(DatasetBase):
def
preload_into_memory
(
self
,
thread_num
=
None
):
"""
:api_attr: Static Graph
Load data into memory in async mode
Args:
...
...
@@ -679,6 +696,8 @@ class InMemoryDataset(DatasetBase):
def
wait_preload_done
(
self
):
"""
:api_attr: Static Graph
Wait preload_into_memory done
Examples:
...
...
@@ -696,6 +715,8 @@ class InMemoryDataset(DatasetBase):
def
local_shuffle
(
self
):
"""
:api_attr: Static Graph
Local shuffle
Examples:
...
...
@@ -712,6 +733,8 @@ class InMemoryDataset(DatasetBase):
def
global_shuffle
(
self
,
fleet
=
None
,
thread_num
=
12
):
"""
:api_attr: Static Graph
Global shuffle.
Global shuffle can be used only in distributed mode. i.e. multiple
processes on single machine or multiple machines training together.
...
...
@@ -771,9 +794,11 @@ class InMemoryDataset(DatasetBase):
dataset.set_filelist(filelist)
dataset.load_into_memory()
dataset.global_shuffle(fleet)
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
exe.train_from_dataset(fluid.default_main_program(), dataset)
exe = paddle.static.Executor(paddle.CPUPlace())
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
exe.run(startup_program)
exe.train_from_dataset(main_program, dataset)
dataset.release_memory()
"""
...
...
@@ -781,6 +806,8 @@ class InMemoryDataset(DatasetBase):
def
get_memory_data_size
(
self
,
fleet
=
None
):
"""
:api_attr: Static Graph
Get memory data size, user can call this function to know the num
of ins in all workers after load into memory.
...
...
@@ -817,6 +844,8 @@ class InMemoryDataset(DatasetBase):
def
get_shuffle_data_size
(
self
,
fleet
=
None
):
"""
:api_attr: Static Graph
Get shuffle data size, user can call this function to know the num
of ins in all workers after local/global shuffle.
...
...
@@ -901,6 +930,8 @@ class InMemoryDataset(DatasetBase):
class
QueueDataset
(
DatasetBase
):
"""
:api_attr: Static Graph
QueueDataset, it will process data streamly.
Examples:
...
...
@@ -920,6 +951,8 @@ class QueueDataset(DatasetBase):
def
init
(
self
,
**
kwargs
):
"""
:api_attr: Static Graph
should be called only once in user's python scripts to initialize setings of dataset instance
"""
super
(
QueueDataset
,
self
).
init
(
**
kwargs
)
...
...
python/paddle/fluid/dataset.py
浏览文件 @
78014059
...
...
@@ -16,6 +16,7 @@
from
paddle.fluid.proto
import
data_feed_pb2
from
google.protobuf
import
text_format
from
.
import
core
from
..utils
import
deprecated
__all__
=
[
'DatasetFactory'
,
'InMemoryDataset'
,
'QueueDataset'
]
...
...
@@ -335,6 +336,7 @@ class InMemoryDataset(DatasetBase):
dataset = paddle.fluid.DatasetFactory().create_dataset("InMemoryDataset")
"""
@
deprecated
(
since
=
"2.0.0"
,
update_to
=
"paddle.distributed.InMemoryDataset"
)
def
__init__
(
self
):
""" Init. """
super
(
InMemoryDataset
,
self
).
__init__
()
...
...
@@ -350,12 +352,18 @@ class InMemoryDataset(DatasetBase):
self
.
merge_by_lineid
=
False
self
.
fleet_send_sleep_seconds
=
None
@
deprecated
(
since
=
"2.0.0"
,
update_to
=
"paddle.distributed.InMemoryDataset._set_feed_type"
)
def
set_feed_type
(
self
,
data_feed_type
):
"""
Set data_feed_desc
"""
self
.
proto_desc
.
name
=
data_feed_type
@
deprecated
(
since
=
"2.0.0"
,
update_to
=
"paddle.distributed.InMemoryDataset._prepare_to_run"
)
def
_prepare_to_run
(
self
):
"""
Set data_feed_desc before load or shuffle,
...
...
@@ -376,16 +384,27 @@ class InMemoryDataset(DatasetBase):
self
.
dataset
.
create_channel
()
self
.
dataset
.
create_readers
()
@
deprecated
(
since
=
"2.0.0"
,
update_to
=
"paddle.distributed.InMemoryDataset._dynamic_adjust_before_train"
)
def
_dynamic_adjust_before_train
(
self
,
thread_num
):
if
not
self
.
is_user_set_queue_num
:
self
.
dataset
.
dynamic_adjust_channel_num
(
thread_num
,
False
)
self
.
dataset
.
dynamic_adjust_readers_num
(
thread_num
)
@
deprecated
(
since
=
"2.0.0"
,
update_to
=
"paddle.distributed.InMemoryDataset._dynamic_adjust_after_train"
)
def
_dynamic_adjust_after_train
(
self
):
if
not
self
.
is_user_set_queue_num
:
self
.
dataset
.
dynamic_adjust_channel_num
(
self
.
thread_num
,
False
)
self
.
dataset
.
dynamic_adjust_readers_num
(
self
.
thread_num
)
@
deprecated
(
since
=
"2.0.0"
,
update_to
=
"paddle.distributed.InMemoryDataset._set_queue_num"
)
def
set_queue_num
(
self
,
queue_num
):
"""
Set Dataset output queue num, training threads get data from queues
...
...
@@ -404,6 +423,9 @@ class InMemoryDataset(DatasetBase):
self
.
is_user_set_queue_num
=
True
self
.
queue_num
=
queue_num
@
deprecated
(
since
=
"2.0.0"
,
update_to
=
"paddle.distributed.InMemoryDataset._set_parse_ins_id"
)
def
set_parse_ins_id
(
self
,
parse_ins_id
):
"""
Set id Dataset need to parse insid
...
...
@@ -421,6 +443,9 @@ class InMemoryDataset(DatasetBase):
"""
self
.
parse_ins_id
=
parse_ins_id
@
deprecated
(
since
=
"2.0.0"
,
update_to
=
"paddle.distributed.InMemoryDataset._set_parse_content"
)
def
set_parse_content
(
self
,
parse_content
):
"""
Set if Dataset need to parse content
...
...
@@ -455,6 +480,9 @@ class InMemoryDataset(DatasetBase):
"""
self
.
parse_logkey
=
parse_logkey
@
deprecated
(
since
=
"2.0.0"
,
update_to
=
"paddle.distributed.InMemoryDataset._set_merge_by_sid"
)
def
set_merge_by_sid
(
self
,
merge_by_sid
):
"""
Set if Dataset need to merge sid. If not, one ins means one Pv.
...
...
@@ -544,6 +572,10 @@ class InMemoryDataset(DatasetBase):
"""
self
.
dataset
.
postprocess_instance
()
@
deprecated
(
since
=
"2.0.0"
,
update_to
=
"paddle.distributed.InMemoryDataset._set_fleet_send_batch_size"
)
def
set_fleet_send_batch_size
(
self
,
fleet_send_batch_size
=
1024
):
"""
Set fleet send batch size, default is 1024
...
...
@@ -561,6 +593,10 @@ class InMemoryDataset(DatasetBase):
"""
self
.
fleet_send_batch_size
=
fleet_send_batch_size
@
deprecated
(
since
=
"2.0.0"
,
update_to
=
"paddle.distributed.InMemoryDataset._set_fleet_send_sleep_seconds"
)
def
set_fleet_send_sleep_seconds
(
self
,
fleet_send_sleep_seconds
=
0
):
"""
Set fleet send sleep time, default is 0
...
...
@@ -578,6 +614,9 @@ class InMemoryDataset(DatasetBase):
"""
self
.
fleet_send_sleep_seconds
=
fleet_send_sleep_seconds
@
deprecated
(
since
=
"2.0.0"
,
update_to
=
"paddle.distributed.InMemoryDataset._set_merge_by_lineid"
)
def
set_merge_by_lineid
(
self
,
merge_size
=
2
):
"""
Set merge by line id, instances of same line id will be merged after
...
...
@@ -598,16 +637,27 @@ class InMemoryDataset(DatasetBase):
self
.
merge_by_lineid
=
True
self
.
parse_ins_id
=
True
@
deprecated
(
since
=
"2.0.0"
,
update_to
=
"paddle.distributed.InMemoryDataset._set_generate_unique_feasigns"
)
def
set_generate_unique_feasigns
(
self
,
generate_uni_feasigns
,
shard_num
):
self
.
dataset
.
set_generate_unique_feasigns
(
generate_uni_feasigns
)
self
.
gen_uni_feasigns
=
generate_uni_feasigns
self
.
local_shard_num
=
shard_num
@
deprecated
(
since
=
"2.0.0"
,
update_to
=
"paddle.distributed.InMemoryDataset._generate_local_tables_unlock"
)
def
generate_local_tables_unlock
(
self
,
table_id
,
fea_dim
,
read_thread_num
,
consume_thread_num
,
shard_num
):
self
.
dataset
.
generate_local_tables_unlock
(
table_id
,
fea_dim
,
read_thread_num
,
consume_thread_num
,
shard_num
)
@
deprecated
(
since
=
"2.0.0"
,
update_to
=
"paddle.distributed.InMemoryDataset.load_into_memory"
)
def
load_into_memory
(
self
):
"""
Load data into memory
...
...
@@ -624,6 +674,9 @@ class InMemoryDataset(DatasetBase):
self
.
_prepare_to_run
()
self
.
dataset
.
load_into_memory
()
@
deprecated
(
since
=
"2.0.0"
,
update_to
=
"paddle.distributed.InMemoryDataset.preload_into_memory"
)
def
preload_into_memory
(
self
,
thread_num
=
None
):
"""
Load data into memory in async mode
...
...
@@ -648,6 +701,9 @@ class InMemoryDataset(DatasetBase):
self
.
dataset
.
create_preload_readers
()
self
.
dataset
.
preload_into_memory
()
@
deprecated
(
since
=
"2.0.0"
,
update_to
=
"paddle.distributed.InMemoryDataset.wait_preload_done"
)
def
wait_preload_done
(
self
):
"""
Wait preload_into_memory done
...
...
@@ -665,6 +721,9 @@ class InMemoryDataset(DatasetBase):
self
.
dataset
.
wait_preload_done
()
self
.
dataset
.
destroy_preload_readers
()
@
deprecated
(
since
=
"2.0.0"
,
update_to
=
"paddle.distributed.InMemoryDataset.local_shuffle"
)
def
local_shuffle
(
self
):
"""
Local shuffle
...
...
@@ -681,6 +740,9 @@ class InMemoryDataset(DatasetBase):
"""
self
.
dataset
.
local_shuffle
()
@
deprecated
(
since
=
"2.0.0"
,
update_to
=
"paddle.distributed.InMemoryDataset.global_shuffle"
)
def
global_shuffle
(
self
,
fleet
=
None
,
thread_num
=
12
):
"""
Global shuffle.
...
...
@@ -726,6 +788,9 @@ class InMemoryDataset(DatasetBase):
if
fleet
is
not
None
:
fleet
.
_role_maker
.
barrier_worker
()
@
deprecated
(
since
=
"2.0.0"
,
update_to
=
"paddle.distributed.InMemoryDataset.release_memory"
)
def
release_memory
(
self
):
"""
:api_attr: Static Graph
...
...
@@ -774,6 +839,9 @@ class InMemoryDataset(DatasetBase):
"""
return
self
.
dataset
.
get_pv_data_size
()
@
deprecated
(
since
=
"2.0.0"
,
update_to
=
"paddle.distributed.InMemoryDataset.get_memory_data_size"
)
def
get_memory_data_size
(
self
,
fleet
=
None
):
"""
Get memory data size, user can call this function to know the num
...
...
@@ -810,6 +878,9 @@ class InMemoryDataset(DatasetBase):
return
global_data_size
[
0
]
return
local_data_size
[
0
]
@
deprecated
(
since
=
"2.0.0"
,
update_to
=
"paddle.distributed.InMemoryDataset.get_shuffle_data_size"
)
def
get_shuffle_data_size
(
self
,
fleet
=
None
):
"""
Get shuffle data size, user can call this function to know the num
...
...
@@ -869,6 +940,9 @@ class QueueDataset(DatasetBase):
super
(
QueueDataset
,
self
).
__init__
()
self
.
proto_desc
.
name
=
"MultiSlotDataFeed"
@
deprecated
(
since
=
"2.0.0"
,
update_to
=
"paddle.distributed.QueueDataset._prepare_to_run"
)
def
_prepare_to_run
(
self
):
"""
Set data_feed_desc/thread num/filelist before run,
...
...
python/paddle/fluid/incubate/fleet/tests/ctr_dataset_reader.py
浏览文件 @
78014059
...
...
@@ -19,7 +19,7 @@ import tarfile
import
os
import
paddle
import
paddle.
fluid.incubate.data_generator
as
data_generator
import
paddle.
distributed.fleet
as
fleet
from
paddle.fluid.log_helper
import
get_logger
logger
=
get_logger
(
...
...
@@ -59,7 +59,7 @@ def load_lr_input_record(sent):
return
res
class
DatasetCtrReader
(
data_generator
.
MultiSlotDataGenerator
):
class
DatasetCtrReader
(
fleet
.
MultiSlotDataGenerator
):
def
generate_sample
(
self
,
line
):
def
iter
():
fs
=
line
.
strip
().
split
(
'
\t
'
)
...
...
python/paddle/fluid/tests/unittests/ctr_dataset_reader.py
浏览文件 @
78014059
...
...
@@ -22,7 +22,7 @@ import random
import
warnings
import
paddle
import
paddle.
fluid.incubate.data_generator
as
data_generator
import
paddle.
distributed.fleet
as
fleet
logging
.
basicConfig
()
logger
=
logging
.
getLogger
(
"paddle"
)
...
...
@@ -84,7 +84,7 @@ class CtrReader(object):
return
reader
class
DatasetCtrReader
(
data_generator
.
MultiSlotDataGenerator
):
class
DatasetCtrReader
(
fleet
.
MultiSlotDataGenerator
):
def
generate_sample
(
self
,
line
):
def
get_rand
(
low
=
0.0
,
high
=
1.0
):
return
random
.
random
()
...
...
python/paddle/fluid/tests/unittests/my_data_generator.py
0 → 100644
浏览文件 @
78014059
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
sys
import
os
import
paddle
import
re
import
collections
import
time
import
paddle.distributed.fleet
as
fleet
class
MyDataset
(
fleet
.
MultiSlotDataGenerator
):
def
generate_sample
(
self
,
line
):
def
data_iter
():
elements
=
line
.
strip
().
split
()[
0
:]
output
=
[(
"show"
,
[
int
(
elements
[
0
])]),
(
"click"
,
[
int
(
elements
[
1
])]),
(
"slot1"
,
[
int
(
elements
[
2
])])]
yield
output
return
data_iter
if
__name__
==
"__main__"
:
d
=
MyDataset
()
d
.
run_from_stdin
()
python/paddle/fluid/tests/unittests/simnet_dataset_reader.py
浏览文件 @
78014059
...
...
@@ -21,13 +21,13 @@ import tarfile
import
random
import
paddle
import
paddle.
fluid.incubate.data_generator
as
data_generator
import
paddle.
distributed.fleet
as
fleet
logging
.
basicConfig
()
logger
=
logging
.
getLogger
(
"paddle"
)
logger
.
setLevel
(
logging
.
INFO
)
class
DatasetSimnetReader
(
data_generator
.
MultiSlotDataGenerator
):
class
DatasetSimnetReader
(
fleet
.
MultiSlotDataGenerator
):
def
generate_sample
(
self
,
line
):
pass
python/paddle/fluid/tests/unittests/test_data_generator.py
0 → 100644
浏览文件 @
78014059
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
import
paddle
import
unittest
import
paddle.distributed.fleet
as
fleet
import
os
import
sys
import
platform
class
MyMultiSlotDataGenerator
(
fleet
.
MultiSlotDataGenerator
):
def
generate_sample
(
self
,
line
):
def
data_iter
():
for
i
in
range
(
40
):
if
i
==
1
:
yield
None
yield
(
"words"
,
[
1
,
2
,
3
,
4
]),
(
"label"
,
[
0
])
return
data_iter
class
MyMultiSlotStringDataGenerator
(
fleet
.
MultiSlotStringDataGenerator
):
def
generate_sample
(
self
,
line
):
def
data_iter
():
for
i
in
range
(
40
):
if
i
==
1
:
yield
None
yield
(
"words"
,
[
"1"
,
"2"
,
"3"
,
"4"
]),
(
"label"
,
[
"0"
])
return
data_iter
class
MyMultiSlotDataGenerator_error
(
fleet
.
MultiSlotDataGenerator
):
def
generate_sample
(
self
,
line
):
def
data_iter
():
for
i
in
range
(
40
):
if
i
==
1
:
yield
None
yield
"words"
return
data_iter
class
MyMultiSlotDataGenerator_error_2
(
fleet
.
MultiSlotStringDataGenerator
):
def
generate_sample
(
self
,
line
):
def
data_iter
():
for
i
in
range
(
40
):
if
i
==
1
:
yield
None
yield
"words"
return
data_iter
class
MyMultiSlotDataGenerator_error_3
(
fleet
.
MultiSlotDataGenerator
):
def
generate_sample
(
self
,
line
):
def
data_iter
():
for
i
in
range
(
40
):
if
i
==
1
:
yield
None
yield
(
1
,
[
"1"
,
"2"
,
"3"
,
"4"
]),
(
2
,
[
"0"
])
return
data_iter
class
MyMultiSlotDataGenerator_error_4
(
fleet
.
MultiSlotDataGenerator
):
def
generate_sample
(
self
,
line
):
def
data_iter
():
for
i
in
range
(
40
):
if
i
==
1
:
yield
None
yield
(
"words"
,
"1"
),
(
"label"
,
"0"
)
return
data_iter
class
MyMultiSlotDataGenerator_error_5
(
fleet
.
MultiSlotDataGenerator
):
def
generate_sample
(
self
,
line
):
def
data_iter
():
for
i
in
range
(
40
):
if
i
==
1
:
yield
None
yield
(
"words"
,
[]),
(
"label"
,
[])
return
data_iter
class
TestMultiSlotDataGenerator
(
unittest
.
TestCase
):
def
test_MultiSlotDataGenerator_basic
(
self
):
my_ms_dg
=
MyMultiSlotDataGenerator
()
my_ms_dg
.
set_batch
(
1
)
my_ms_dg
.
run_from_memory
()
class
TestMultiSlotStringDataGenerator
(
unittest
.
TestCase
):
def
test_MyMultiSlotStringDataGenerator_basic
(
self
):
my_ms_dg
=
MyMultiSlotStringDataGenerator
()
my_ms_dg
.
set_batch
(
1
)
my_ms_dg
.
run_from_memory
()
class
TestMultiSlotStringDataGenerator_2
(
unittest
.
TestCase
):
def
test_MyMultiSlotStringDataGenerator_stdin
(
self
):
plats
=
platform
.
platform
()
if
'Linux'
not
in
plats
:
print
(
"skip pipecommand UT on MacOS/Win"
)
return
with
open
(
"test_queue_dataset_run_a.txt"
,
"w"
)
as
f
:
data
=
"2 1 2
\n
"
data
+=
"2 6 2
\n
"
data
+=
"2 5 2
\n
"
data
+=
"2 7 2
\n
"
f
.
write
(
data
)
tmp
=
os
.
popen
(
"cat test_queue_dataset_run_a.txt | python my_data_generator.py"
).
readlines
()
expected_res
=
[
'1 2 1 1 1 2
\n
'
,
'1 2 1 6 1 2
\n
'
,
'1 2 1 5 1 2
\n
'
,
'1 2 1 7 1 2
\n
'
]
self
.
assertEqual
(
tmp
,
expected_res
)
os
.
remove
(
"./test_queue_dataset_run_a.txt"
)
class
TestMultiSlotDataGenerator_error
(
unittest
.
TestCase
):
def
test_MultiSlotDataGenerator_error
(
self
):
with
self
.
assertRaises
(
ValueError
):
my_ms_dg
=
MyMultiSlotDataGenerator_error
()
my_ms_dg
.
set_batch
(
1
)
my_ms_dg
.
run_from_memory
()
class
TestMultiSlotDataGenerator_error_2
(
unittest
.
TestCase
):
def
test_MultiSlotDataGenerator_error
(
self
):
with
self
.
assertRaises
(
ValueError
):
my_ms_dg
=
MyMultiSlotDataGenerator_error_2
()
my_ms_dg
.
set_batch
(
1
)
my_ms_dg
.
run_from_memory
()
class
TestMultiSlotDataGenerator_error_3
(
unittest
.
TestCase
):
def
test_MultiSlotDataGenerator_error
(
self
):
with
self
.
assertRaises
(
ValueError
):
my_ms_dg
=
MyMultiSlotDataGenerator_error_3
()
my_ms_dg
.
set_batch
(
1
)
my_ms_dg
.
run_from_memory
()
class
TestMultiSlotDataGenerator_error_4
(
unittest
.
TestCase
):
def
test_MultiSlotDataGenerator_error
(
self
):
with
self
.
assertRaises
(
ValueError
):
my_ms_dg
=
MyMultiSlotDataGenerator_error_4
()
my_ms_dg
.
set_batch
(
1
)
my_ms_dg
.
run_from_memory
()
class
TestMultiSlotDataGenerator_error_5
(
unittest
.
TestCase
):
def
test_MultiSlotDataGenerator_error
(
self
):
with
self
.
assertRaises
(
ValueError
):
my_ms_dg
=
MyMultiSlotDataGenerator_error_5
()
my_ms_dg
.
set_batch
(
1
)
my_ms_dg
.
run_from_memory
()
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_dataset.py
浏览文件 @
78014059
...
...
@@ -105,11 +105,15 @@ class TestDataset(unittest.TestCase):
dataset
.
load_into_memory
()
dataset
.
local_shuffle
()
exe
=
fluid
.
Executor
(
fluid
.
CPUPlace
())
exe
.
run
(
fluid
.
default_startup_program
())
paddle
.
enable_static
()
exe
=
paddle
.
static
.
Executor
(
paddle
.
CPUPlace
())
startup_program
=
paddle
.
static
.
Program
()
main_program
=
paddle
.
static
.
Program
()
exe
.
run
(
startup_program
)
for
i
in
range
(
2
):
try
:
exe
.
train_from_dataset
(
fluid
.
default_main_program
()
,
dataset
)
exe
.
train_from_dataset
(
main_program
,
dataset
)
except
ImportError
as
e
:
pass
except
Exception
as
e
:
...
...
@@ -181,20 +185,24 @@ class TestDataset(unittest.TestCase):
use_var
=
slots_vars
)
dataset
.
set_filelist
([
filename1
,
filename2
])
dataset
.
load_into_memory
()
paddle
.
enable_static
()
exe
=
paddle
.
static
.
Executor
(
paddle
.
CPUPlace
())
startup_program
=
paddle
.
static
.
Program
()
main_program
=
paddle
.
static
.
Program
()
exe
=
fluid
.
Executor
(
fluid
.
CPUPlace
())
exe
.
run
(
fluid
.
default_startup_program
()
)
exe
.
run
(
startup_program
)
if
self
.
use_data_loader
:
data_loader
=
fluid
.
io
.
DataLoader
.
from_dataset
(
dataset
,
fluid
.
cpu_places
(),
self
.
drop_last
)
for
i
in
range
(
self
.
epoch_num
):
for
data
in
data_loader
():
exe
.
run
(
fluid
.
default_main_program
()
,
feed
=
data
)
exe
.
run
(
main_program
,
feed
=
data
)
else
:
for
i
in
range
(
self
.
epoch_num
):
try
:
exe
.
train_from_dataset
(
fluid
.
default_main_program
(),
dataset
)
exe
.
train_from_dataset
(
main_program
,
dataset
)
except
Exception
as
e
:
self
.
assertTrue
(
False
)
...
...
python/setup.py.in
浏览文件 @
78014059
...
...
@@ -150,6 +150,7 @@ packages=['paddle',
'paddle.distributed.fleet.meta_optimizers',
'paddle.distributed.fleet.runtime',
'paddle.distributed.fleet.dataset',
'paddle.distributed.fleet.data_generator',
'paddle.distributed.fleet.metrics',
'paddle.distributed.fleet.proto',
'paddle.distributed.fleet.utils',
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录