提交 c28bbdf8 编写于 作者: D dongdaxiang

add dataset_generator.py

dataset_generator.py is a framework for generating data with python
the generated data with a fixed format will be feeded into c++ reader
test=develop
上级 be757096
......@@ -60,6 +60,7 @@ class DataFeed {
// Otherwise, Init() function will init finish_set_filelist_ flag.
virtual bool SetFileList(const std::vector<std::string>& files);
virtual bool Start() = 0;
// The trainer calls the Next() function, and the DataFeed will load a new
// batch to the feed_vec. The return value of this function is the batch
// size of the current batch.
......
......@@ -28,4 +28,5 @@ message DataFeedDesc {
optional int32 batch_size = 2 [ default = 32 ];
optional MultiSlotDesc multi_slot_desc = 3;
optional string pipe_command = 4;
optional int32 thread_num = 5;
}
......@@ -284,6 +284,7 @@ void ExecutorThreadWorker::TrainFilesWithTimer() {
for (int i = 0; i < fetch_var_num; ++i) {
print_fetch_var(thread_scope_, fetch_var_names_[i]);
}
fprintf(stderr, "IO percent: %f\n", read_time / total_time);
}
}
timeline.Start();
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__all__ = ['MultiSlotDataset']
class DatasetGenerator(object):
def __init__(self):
self._proto_info = None
self._hadoop_host = None
self._batch_size = 32
self._hadoop_ugi = None
self._hadoop_path = None
def _set_proto_filename(self, proto_filename):
if not isinstance(proto_filename, str):
raise ValueError("proto_filename%s must be in str type" %
type(proto_filename))
if not proto_filename:
raise ValueError("proto_filename can not be empty")
self._proto_filename = proto_filename
def generate_sample(self, line):
'''
This function needs to be overridden by the user to process the
original data row into a list or tuple
Args:
line(str): the original data row
Returns:
Returns the data processed by the user.
The data format is list or tuple:
[(name, [feasign, ...]), ...]
or ((name, [feasign, ...]), ...)
For example:
[("words", [1926, 08, 17])], ("label", [1])]
or (("words", [1926, 08, 17]), ("label", [1]))
Note:
The type of feasigns must be in int or float. Once the float
element appears in the feasign, the type of that slot will be
processed into a float.
'''
raise NotImplementedError(
"please rewrite this function to return a list" +
"[(name, [int, int ...]), ...]")
def set_batch(self, batch):
self.batch = batch
def generate_batch(self, samples):
'''
This function can be overridden by the user to process batch
data, a user can define how to generate batch with this function
Args:
samples(list of results from generate_samples)
Returns:
Returns the processed batch by the user
[[(name, [int, ...]), ...],
[(name, [int, ...]), ...],
[(name, [int, ...])]]
Default:
Do nothing about current batch
'''
def batch_iter():
for sample in samples:
yield sample
return batch_iter
def _gen_str(self, line):
raise NotImplementedError(
"Please inherit this class and implement _gen_str")
def _upload_proto_file(self):
if self.proto_output_path == None:
raise ValueError("If you are running data generation on hadoop, "
"please set proto output path first")
if self._hadoop_host == None or self._hadoop_ugi == None or \
self._hadoop_path == None:
raise ValueError(
"If you are running data generation on hadoop, "
"please set hadoop_host, hadoop_path, hadoop_ugi first")
cmd = "$HADOOP_HOME/bin/hadoop fs" \
+ " -Dhadoop.job.ugi=" + self.hadoop_ugi \
+ " -Dfs.default.name=" + self.hadoop_host \
+ " -put " + self._proto_filename + " " + self._proto_output_path
os.system(cmd)
def set_hadoop_config(self,
hadoop_host=None,
hadoop_ugi=None,
proto_path=None):
'''
This function set hadoop configuration for map-reduce based data
generation.
Args:
hadoop_host(str): The host name of the hadoop. It should be
in this format: "hdfs://${HOST}:${PORT}".
hadoop_ugi(str): The ugi of the hadoop. It should be in this
format: "${USERNAME},${PASSWORD}".
proto_path(str): The hadoop path you want to upload the
protofile to.
'''
self.hadoop_host = hadoop_host
self.hadoop_ugi = hadoop_ugi
self.proto_output_path = proto_path
def run_from_memory(self, is_local=True, proto_filename='data_feed.proto'):
'''
This function generates data from memory, user needs to
define how to generate samples by define generate_sample
and generate_batch
'''
self._set_proto_filename(proto_filename)
batch_data = []
line_iter = self.generate_sample(None)
for user_parsed_line in line_iter():
if user_parsed_line == None:
continue
batch_data.append(user_parsed_line)
if len(batch_data) == self._batch_size:
batched_iter = self.generate_batch(batch_data)
for batched_line in batched_iter():
sys.stdout.write(self._gen_str(batched_line))
batch_data = []
if len(batch_data) > 0:
batched_iter = self.generate_batch(batch_data)
for batched_line in batched_iter():
sys.stdout.write(self._gen_str(batched_line))
if self.proto_info is not None:
with open(self._proto_filename, "w") as f:
f.write(self._get_proto_desc(self._proto_info))
if is_local == False:
self._upload_proto_file()
def run_from_stdin(self, is_local=True, proto_filename='data_feed.proto'):
'''
This function reads the data row from stdin, parses it with the
process function, and further parses the return value of the
process function with the _gen_str function. The parsed data will
be wrote to stdout and the corresponding protofile will be
generated. If local is set to False, the protofile will be
uploaded to hadoop.
Args:
is_local(bool): Whether user wants to run this function from local
proto_filename(str): The name of protofile. The default value
is "data_feed.proto". It is not
recommended to modify it.
'''
self._set_proto_filename(proto_filename)
batch_data = []
for line in sys.stdin:
line_iter = self.generate_sample(line)
for user_parsed_line in line_iter():
if user_parsed_line == None:
continue
batch_data.append(user_parsed_line)
if len(batch_data) == self._batch_size:
batched_iter = self.generate_batch(batch_data)
for batched_line in batched_iter():
sys.stdout.write(self._gen_str(batched_line))
batch_data = []
if len(batch_data) > 0:
batched_iter = self.generate_batch(batch_data)
for batched_line in batched_iter():
sys.stdout.write(self._gen_str(batched_line))
if self._proto_info is not None:
with open(self._proto_filename, "w") as f:
f.write(self._get_proto_desc(self._proto_info))
if is_local == False:
self._upload_proto_file()
class MultiSlotDataset(DatasetGenerator):
def _get_proto_desc(self, proto_info):
proto_str = "name: \"MultiSlotDataFeed\"\n" \
+ "batch_size: 32\nmulti_slot_desc {\n"
for elem in proto_info:
proto_str += " slots {\n" \
+ " name: \"%s\"\n" % elem[0]\
+ " type: \"%s\"\n" % elem[1]\
+ " is_dense: false\n" \
+ " is_used: false\n" \
+ " }\n"
proto_str += "}"
return proto_str
def generate_batch(self, samples):
super(MultiSlotDataset, self).generate_batch(samples)
def batch_iter():
for sample in samples:
yield sample
return batch_iter
def _gen_str(self, line):
if not isinstance(line, list) and not isinstance(line, tuple):
raise ValueError(
"the output of process() must be in list or tuple type")
output = ""
if self._proto_info is None:
self._proto_info = []
for item in line:
name, elements = item
if not isinstance(name, str):
raise ValueError("name%s must be in str type" % type(name))
if not isinstance(elements, list):
raise ValueError("elements%s must be in list type" %
type(elements))
if not elements:
raise ValueError(
"the elements of each field can not be empty, you need padding it in process()."
)
self._proto_info.append((name, "uint64"))
if output:
output += " "
output += str(len(elements))
for elem in elements:
if isinstance(elem, float):
self._proto_info[-1] = (name, "float")
elif not isinstance(elem, int) and not isinstance(elem,
long):
raise ValueError(
"the type of element%s must be in int or float" %
type(elem))
output += " " + str(elem)
else:
if len(line) != len(self._proto_info):
raise ValueError(
"the complete field set of two given line are inconsistent.")
for index, item in enumerate(line):
name, elements = item
if not isinstance(name, str):
raise ValueError("name%s must be in str type" % type(name))
if not isinstance(elements, list):
raise ValueError("elements%s must be in list type" %
type(elements))
if not elements:
raise ValueError(
"the elements of each field can not be empty, you need padding it in process()."
)
if name != self._proto_info[index][0]:
raise ValueError(
"the field name of two given line are not match: require<%s>, get<%d>."
% (self._proto_info[index][0], name))
if output:
output += " "
output += str(len(elements))
for elem in elements:
if self._proto_info[index][1] != "float":
if isinstance(elem, float):
self._proto_info[index] = (name, "float")
elif not isinstance(elem, int) and not isinstance(elem,
long):
raise ValueError(
"the type of element%s must be in int or float"
% type(elem))
output += " " + str(elem)
return output + "\n"
......@@ -139,6 +139,10 @@ class DataFeedDesc(object):
self.proto_desc.multi_slot_desc.slots[self.__name_to_index[
name]].is_used = True
def global_shuffle(self):
self.data.global_shuffle()
pass
def desc(self):
"""
Returns a protobuf message for this DataFeedDesc
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid.proto import data_feed_pb2
from google.protobuf import text_format
from . import core
__all__ = ['DatasetFactory']
class DatasetFactory(object):
def __init__(self):
pass
def create_dataset(self, datafeed_class):
datafeed_class = datafeed_class.capitalize()
try:
dataset = globals()[datafeed_class]()
except:
raise ValueError("datafeed class %s does not exist" %
datafeed_class)
class DatasetBase(object):
def __init__(self):
# define class name here
# to decide whether we need create in memory instance
self.proto_desc = data_feed_pb2.DataFeedDesc()
self.proto_desc.pipe_command = "cat"
def set_pipe_command(self, pipe_command):
"""
Set pipe command of current dataset
A pipe command is a UNIX pipeline command that can be used only
"""
self.proto_desc.pipe_command = pipe_command
def set_batch_size(self, batch_size):
"""
Set batch size. Will be effective during training
Example:
>>> data_feed = fluid.DataFeedDesc('data.proto')
>>> data_feed.set_batch_size(128)
Args:
batch_size: batch size
"""
self.proto_desc.batch_size = batch_size
def set_use_var(self, var_list):
multi_slot = self.proto_desc.multi_slot_desc()
for var in var_list:
slot_var = multi_slot.add()
slot_var.is_used = True
slot_var.name = var.name
if var.lod_level == 0:
slot_var.is_dense = True
if var.dtype == core.VarType.FP32:
slot_var.type = "float32"
elif var.dtype == core.VarType.INT64:
slot_var.type = "uint64"
else:
raise ValueError(
"Currently, fluid.dataset only supports dtype=float32 and dtype=int64"
)
def desc(self):
"""
Returns a protobuf message for this DataFeedDesc
Example:
>>> data_feed = fluid.DataFeedDesc('data.proto')
>>> print(data_feed.desc())
Returns:
A string message
"""
return text_format.MessageToString(self.proto_desc)
class InMemoryDataset(DatasetBase):
def __init__(self):
super(InMemoryDataset.__init__())
self.proto_desc.name = "InMemoryDataFeed"
def local_shuffle(self):
pass
def global_shuffle(self):
pass
class QueueDataset(DatasetBase):
def __init__(self):
super(QueueDataset.__init__())
self.proto_desc.name = "MultiSlotDataFeed"
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册