提交 73b1f396 编写于 作者: D dongdaxiang

add data_generator into paddle.fluid.incubate.data_generator, add op run log...

add data_generator into paddle.fluid.incubate.data_generator, add op run log in hogwild_device_worker and downpour_device_worker
test=develop
上级 73544e8b
...@@ -236,7 +236,9 @@ void DownpourWorker::TrainFilesWithProfiler() { ...@@ -236,7 +236,9 @@ void DownpourWorker::TrainFilesWithProfiler() {
} }
if (!need_skip) { if (!need_skip) {
timeline.Start(); timeline.Start();
VLOG(3) << "Going to run op " << op_name[run_op_idx];
op->Run(*thread_scope_, place_); op->Run(*thread_scope_, place_);
VLOG(3) << "Op " << op_name[run_op_idx] << " Finished";
timeline.Pause(); timeline.Pause();
op_total_time[run_op_idx++] += timeline.ElapsedSec(); op_total_time[run_op_idx++] += timeline.ElapsedSec();
total_time += timeline.ElapsedSec(); total_time += timeline.ElapsedSec();
......
...@@ -97,7 +97,9 @@ void HogwildWorker::TrainFilesWithProfiler() { ...@@ -97,7 +97,9 @@ void HogwildWorker::TrainFilesWithProfiler() {
total_time += timeline.ElapsedSec(); total_time += timeline.ElapsedSec();
for (size_t i = 0; i < ops_.size(); ++i) { for (size_t i = 0; i < ops_.size(); ++i) {
timeline.Start(); timeline.Start();
VLOG(3) << "Going to run op " << op_name[i];
ops_[i]->Run(*thread_scope_, place_); ops_[i]->Run(*thread_scope_, place_);
VLOG(3) << "Op " << op_name[i] << " Finished";
timeline.Pause(); timeline.Pause();
op_total_time[i] += timeline.ElapsedSec(); op_total_time[i] += timeline.ElapsedSec();
total_time += timeline.ElapsedSec(); total_time += timeline.ElapsedSec();
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__all__ = ['MultiSlotDataGenerator']
class DataGenerator(object):
def __init__(self):
self._proto_info = None
self.batch_size_ = 32
def _set_line_limit(self, line_limit):
if not isinstance(line_limit, int):
raise ValueError("line_limit%s must be in int type" %
type(line_limit))
if line_limit < 1:
raise ValueError("line_limit can not less than 1")
self._line_limit = line_limit
def set_batch(self, batch_size):
self.batch_size_ = batch_size
def run_from_memory(self):
'''
This function generator data from memory, it is usually used for
debug and benchmarking
'''
batch_samples = []
line_iter = self.generate_sample(None)
for user_parsed_line in line_iter():
if user_parsed_line == None:
continue
batch_samples.append(user_parsed_line)
if len(batch_samples) == self.batch_size_:
batch_iter = self.generate_batch(batch_samples)
for sample in batch_iter():
sys.stdout.write(self._gen_str(sample))
batch_samples = []
if len(batch_samples) > 0:
batch_iter = self.generate_batch(batch_samples)
for sample in batch_iter():
sys.stdout.write(self._gen_str(sample))
def run_from_stdin(self):
'''
This function reads the data row from stdin, parses it with the
process function, and further parses the return value of the
process function with the _gen_str function. The parsed data will
be wrote to stdout and the corresponding protofile will be
generated.
'''
batch_samples = []
for line in sys.stdin:
line_iter = self.generate_sample(line)
for user_parsed_line in line_iter():
if user_parsed_line == None:
continue
batch_samples.append(user_parsed_line)
if len(batch_samples) == self.batch_size_:
batch_iter = self.generate_batch(batch_samples)
for sample in batch_iter():
sys.stdout.write(self._gen_str(sample))
batch_samples = []
if len(batch_samples) > 0:
batch_iter = self.generate_batch(batch_samples)
for sample in batch_iter():
sys.stdout.write(self._gen_str(sample))
def _gen_str(self, line):
'''
Further processing the output of the process() function rewritten by
user, outputting data that can be directly read by the datafeed,and
updating proto_info infomation.
Args:
line(str): the output of the process() function rewritten by user.
Returns:
Return a string data that can be read directly by the datafeed.
'''
raise NotImplementedError(
"pls use MultiSlotDataGenerator or PairWiseDataGenerator")
def generate_sample(self, line):
'''
This function needs to be overridden by the user to process the
original data row into a list or tuple.
Args:
line(str): the original data row
Returns:
Returns the data processed by the user.
The data format is list or tuple:
[(name, [feasign, ...]), ...]
or ((name, [feasign, ...]), ...)
For example:
[("words", [1926, 08, 17]), ("label", [1])]
or (("words", [1926, 08, 17]), ("label", [1]))
Note:
The type of feasigns must be in int or float. Once the float
element appears in the feasign, the type of that slot will be
processed into a float.
'''
raise NotImplementedError(
"Please rewrite this function to return a list or tuple: " +
"[(name, [feasign, ...]), ...] or ((name, [feasign, ...]), ...)")
def generate_batch(self, samples):
def local_iter():
for sample in samples:
yield sample
return local_iter
class MultiSlotDataGenerator(DataGenerator):
def _gen_str(self, line):
'''
Further processing the output of the process() function rewritten by
user, outputting data that can be directly read by the MultiSlotDataFeed,
and updating proto_info infomation.
The input line will be in this format:
>>> [(name, [feasign, ...]), ...]
>>> or ((name, [feasign, ...]), ...)
The output will be in this format:
>>> [ids_num id1 id2 ...] ...
The proto_info will be in this format:
>>> [(name, type), ...]
For example, if the input is like this:
>>> [("words", [1926, 08, 17]), ("label", [1])]
>>> or (("words", [1926, 08, 17]), ("label", [1]))
the output will be:
>>> 3 1234 2345 3456 1 1
the proto_info will be:
>>> [("words", "uint64"), ("label", "uint64")]
Args:
line(str): the output of the process() function rewritten by user.
Returns:
Return a string data that can be read directly by the MultiSlotDataFeed.
'''
if not isinstance(line, list) and not isinstance(line, tuple):
raise ValueError(
"the output of process() must be in list or tuple type")
output = ""
if self._proto_info is None:
self._proto_info = []
for item in line:
name, elements = item
if not isinstance(name, str):
raise ValueError("name%s must be in str type" % type(name))
if not isinstance(elements, list):
raise ValueError("elements%s must be in list type" %
type(elements))
if not elements:
raise ValueError(
"the elements of each field can not be empty, you need padding it in process()."
)
self._proto_info.append((name, "uint64"))
if output:
output += " "
output += str(len(elements))
for elem in elements:
if isinstance(elem, float):
self._proto_info[-1] = (name, "float")
elif not isinstance(elem, int) and not isinstance(elem,
long):
raise ValueError(
"the type of element%s must be in int or float" %
type(elem))
output += " " + str(elem)
else:
if len(line) != len(self._proto_info):
raise ValueError(
"the complete field set of two given line are inconsistent.")
for index, item in enumerate(line):
name, elements = item
if not isinstance(name, str):
raise ValueError("name%s must be in str type" % type(name))
if not isinstance(elements, list):
raise ValueError("elements%s must be in list type" %
type(elements))
if not elements:
raise ValueError(
"the elements of each field can not be empty, you need padding it in process()."
)
if name != self._proto_info[index][0]:
raise ValueError(
"the field name of two given line are not match: require<%s>, get<%d>."
% (self._proto_info[index][0], name))
if output:
output += " "
output += str(len(elements))
for elem in elements:
if self._proto_info[index][1] != "float":
if isinstance(elem, float):
self._proto_info[index] = (name, "float")
elif not isinstance(elem, int) and not isinstance(elem,
long):
raise ValueError(
"the type of element%s must be in int or float"
% type(elem))
output += " " + str(elem)
return output + "\n"
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
from __init__ import *
class SyntheticData(MultiSlotDataGenerator):
def generate_sample(self, line):
def data_iter():
for i in range(10000):
yield ("words", [1, 2, 3, 4]), ("label", [0])
return data_iter
sd = SyntheticData()
sd.run_from_memory()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册