提交 b95b80bc 编写于 作者: D dongdaxiang

add doc string for executor and update API.spec

test=develop
上级 d52586a9
......@@ -16,6 +16,8 @@ paddle.fluid.cuda_pinned_places (ArgSpec(args=['device_count'], varargs=None, ke
paddle.fluid.Executor.__init__ (ArgSpec(args=['self', 'place'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.Executor.close (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', 'f5369953dd0c443961cf79f7a00e1a03'))
paddle.fluid.Executor.run (ArgSpec(args=['self', 'program', 'feed', 'fetch_list', 'feed_var_name', 'fetch_var_name', 'scope', 'return_numpy', 'use_program_cache'], varargs=None, keywords=None, defaults=(None, None, None, 'feed', 'fetch', None, True, False)), ('document', 'f482e93b38b4018796969a2e1dde479d'))
paddle.fluid.Executor.infer_from_dataset (ArgSpec(args=['self', 'program', 'dataset', 'fetch_list', 'scope', 'thread', 'opt_info'], varargs=None, keywords=None, defaults=(None, None, None, None, 0, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.Executor.train_from_dataset (ArgSpec(args=['self', 'program', 'dataset', 'scope', 'thread', 'debug', 'fetch_list', 'fetch_info', 'print_period'], varargs=None, keywords=None, defaults=(None, None, None, 0, False, None, None, 100)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.global_scope (ArgSpec(args=[], varargs=None, keywords=None, defaults=None), ('document', 'e148d3ab1ed8edf3e928212a375959c0'))
paddle.fluid.scope_guard (ArgSpec(args=['scope'], varargs=None, keywords=None, defaults=None), ('document', 'b94d1f6bcc29c4fb58fc0058561250c2'))
paddle.fluid.DistributeTranspiler.__init__ (ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
......@@ -43,7 +45,7 @@ paddle.fluid.AsyncExecutor.get_instance (ArgSpec(args=['self'], varargs=None, ke
paddle.fluid.AsyncExecutor.init_model (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '504f39be2007404a17e5cabea1256c7d'))
paddle.fluid.AsyncExecutor.init_server (ArgSpec(args=['self', 'dist_desc'], varargs=None, keywords=None, defaults=None), ('document', 'c403ab46c5d3ef25c0f7e94ae75dcb68'))
paddle.fluid.AsyncExecutor.init_worker (ArgSpec(args=['self', 'dist_desc', 'startup_program'], varargs=None, keywords=None, defaults=None), ('document', 'dcf08f4bf2f3282acf11391f5d39c536'))
paddle.fluid.AsyncExecutor.run (ArgSpec(args=['self', 'program', 'data_feed', 'filelist', 'thread_num', 'fetch', 'mode', 'debug'], varargs=None, keywords=None, defaults=('', False)), ('document', '848fc53484e8326f6325feea87fe955c'))
paddle.fluid.AsyncExecutor.run (ArgSpec(args=['self', 'program', 'data_feed', 'filelist', 'thread_num', 'fetch', 'debug'], varargs=None, keywords=None, defaults=(False,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.AsyncExecutor.save_model (ArgSpec(args=['self', 'save_path'], varargs=None, keywords=None, defaults=None), ('document', 'c8ac0dfcb3b187aba25d03af7fea56b2'))
paddle.fluid.AsyncExecutor.stop (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '5f23d043607bb5d55e466ec3f578e093'))
paddle.fluid.CompiledProgram.__init__ (ArgSpec(args=['self', 'program_or_graph'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
......@@ -495,7 +497,7 @@ paddle.fluid.LoDTensor.__init__ 1. __init__(self: paddle.fluid.core.LoDTensor, a
paddle.fluid.LoDTensor.has_valid_recursive_sequence_lengths has_valid_recursive_sequence_lengths(self: paddle.fluid.core.LoDTensor) -> bool
paddle.fluid.LoDTensor.lod lod(self: paddle.fluid.core.LoDTensor) -> List[List[int]]
paddle.fluid.LoDTensor.recursive_sequence_lengths recursive_sequence_lengths(self: paddle.fluid.core.LoDTensor) -> List[List[int]]
paddle.fluid.LoDTensor.set 1. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float32], arg1: paddle::platform::CPUPlace) -> None 2. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int32], arg1: paddle::platform::CPUPlace) -> None 3. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float64], arg1: paddle::platform::CPUPlace) -> None 4. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int64], arg1: paddle::platform::CPUPlace) -> None 5. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[bool], arg1: paddle::platform::CPUPlace) -> None 6. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint16], arg1: paddle::platform::CPUPlace) -> None 7. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint8], arg1: paddle::platform::CPUPlace) -> None 8. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int8], arg1: paddle::platform::CPUPlace) -> None 9. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float32], arg1: paddle::platform::CUDAPlace) -> None 10. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int32], arg1: paddle::platform::CUDAPlace) -> None 11. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float64], arg1: paddle::platform::CUDAPlace) -> None 12. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int64], arg1: paddle::platform::CUDAPlace) -> None 13. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[bool], arg1: paddle::platform::CUDAPlace) -> None 14. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint16], arg1: paddle::platform::CUDAPlace) -> None 15. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint8], arg1: paddle::platform::CUDAPlace) -> None 16. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int8], arg1: paddle::platform::CUDAPlace) -> None 17. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float32], arg1: paddle::platform::CUDAPinnedPlace) -> None 18. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int32], arg1: paddle::platform::CUDAPinnedPlace) -> None 19. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float64], arg1: paddle::platform::CUDAPinnedPlace) -> None 20. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int64], arg1: paddle::platform::CUDAPinnedPlace) -> None 21. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[bool], arg1: paddle::platform::CUDAPinnedPlace) -> None 22. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint16], arg1: paddle::platform::CUDAPinnedPlace) -> None 23. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint8], arg1: paddle::platform::CUDAPinnedPlace) -> None 24. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int8], arg1: paddle::platform::CUDAPinnedPlace) -> None
paddle.fluid.LoDTensor.set 1. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float32], arg1: paddle::platform::CPUPlace) -> None 2. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int32], arg1: paddle::platform::CPUPlace) -> None 3. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[float64], arg1: paddle::platform::CPUPlace) -> None 4. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int64], arg1: paddle::platform::CPUPlace) -> None 5. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[bool], arg1: paddle::platform::CPUPlace) -> None 6. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint16], arg1: paddle::platform::CPUPlace) -> None 7. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[uint8], arg1: paddle::platform::CPUPlace) -> None 8. set(self: paddle.fluid.core.Tensor, arg0: numpy.ndarray[int8], arg1: paddle::platform::CPUPlace) -> None
paddle.fluid.LoDTensor.set_lod set_lod(self: paddle.fluid.core.LoDTensor, lod: List[List[int]]) -> None
paddle.fluid.LoDTensor.set_recursive_sequence_lengths set_recursive_sequence_lengths(self: paddle.fluid.core.LoDTensor, recursive_sequence_lengths: List[List[int]]) -> None
paddle.fluid.LoDTensor.shape shape(self: paddle.fluid.core.Tensor) -> List[int]
......
......@@ -164,6 +164,8 @@ class DownpourWorker : public HogwildWorker {
void CollectLabelInfo(size_t table_id);
private:
bool need_to_push_dense_;
bool need_to_push_sparse_;
DownpourWorkerParameter param_;
// just save the value in param_ for easy access
std::map<uint64_t, std::string> label_var_name_;
......
......@@ -58,6 +58,9 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) {
skip_ops_[i] = param_.skip_ops(i);
}
need_to_push_sparse_ = param_.push_sparse();
need_to_push_dense_ = param_.push_dense();
fleet_ptr_ = FleetWrapper::GetInstance();
fetch_config_ = desc.fetch_config();
}
......@@ -239,8 +242,9 @@ void DownpourWorker::TrainFilesWithProfiler() {
}
}
for (size_t i = 0; i < param_.program_config(0).push_sparse_table_id_size();
++i) {
if (need_to_push_sparse_) {
for (size_t i = 0;
i < param_.program_config(0).push_sparse_table_id_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).push_sparse_table_id(i));
TableParameter table;
......@@ -259,10 +263,12 @@ void DownpourWorker::TrainFilesWithProfiler() {
push_sparse_time += timeline.ElapsedSec();
total_time += timeline.ElapsedSec();
}
}
if (need_to_push_dense_) {
timeline.Start();
for (size_t i = 0; i < param_.program_config(0).push_dense_table_id_size();
++i) {
for (size_t i = 0;
i < param_.program_config(0).push_dense_table_id_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).push_dense_table_id(i));
fleet_ptr_->PushDenseVarsAsync(
......@@ -288,7 +294,9 @@ void DownpourWorker::TrainFilesWithProfiler() {
if (tmp_push_dense_wait_times == -1) {
push_dense_status_.resize(0);
}
}
if (need_to_push_sparse_) {
if (push_sparse_status_.size() >= push_sparse_wait_times) {
for (auto& t : push_sparse_status_) {
t.wait();
......@@ -299,17 +307,17 @@ void DownpourWorker::TrainFilesWithProfiler() {
if (tmp_push_sparse_wait_times == -1) {
push_sparse_status_.resize(0);
}
VLOG(3) << "going to increase thread version";
VLOG(3) << "going to increase thread version";
VLOG(3) << "push dense table id size: "
<< param_.program_config(0).push_dense_table_id_size();
for (size_t i = 0; i < param_.program_config(0).push_dense_table_id_size();
++i) {
for (size_t i = 0;
i < param_.program_config(0).push_dense_table_id_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(
param_.program_config(0).push_dense_table_id(i));
pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid);
}
}
PrintFetchVars();
thread_scope_->DropKids();
......
......@@ -46,6 +46,8 @@ message DownpourWorkerParameter {
repeated TableParameter dense_table = 2;
repeated string skip_ops = 3;
repeated ProgramConfig program_config = 4;
bool push_sparse = 5 [ default = true ];
bool push_dense = 6 [ default = true ];
}
message FetchConfig {
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__all__ = ['MultiSlotDataset']
class DatasetGenerator(object):
def __init__(self):
self._proto_info = None
self._hadoop_host = None
self._batch_size = 32
self._hadoop_ugi = None
self._hadoop_path = None
def _set_proto_filename(self, proto_filename):
if not isinstance(proto_filename, str):
raise ValueError("proto_filename%s must be in str type" %
type(proto_filename))
if not proto_filename:
raise ValueError("proto_filename can not be empty")
self._proto_filename = proto_filename
def generate_sample(self, line):
'''
This function needs to be overridden by the user to process the
original data row into a list or tuple
Args:
line(str): the original data row
Returns:
Returns the data processed by the user.
The data format is list or tuple:
[(name, [feasign, ...]), ...]
or ((name, [feasign, ...]), ...)
For example:
[("words", [1926, 08, 17])], ("label", [1])]
or (("words", [1926, 08, 17]), ("label", [1]))
Note:
The type of feasigns must be in int or float. Once the float
element appears in the feasign, the type of that slot will be
processed into a float.
'''
raise NotImplementedError(
"please rewrite this function to return a list" +
"[(name, [int, int ...]), ...]")
def set_batch(self, batch):
self.batch = batch
def generate_batch(self, samples):
'''
This function can be overridden by the user to process batch
data, a user can define how to generate batch with this function
Args:
samples(list of results from generate_samples)
Returns:
Returns the processed batch by the user
[[(name, [int, ...]), ...],
[(name, [int, ...]), ...],
[(name, [int, ...])]]
Default:
Do nothing about current batch
'''
def batch_iter():
for sample in samples:
yield sample
return batch_iter
def _gen_str(self, line):
raise NotImplementedError(
"Please inherit this class and implement _gen_str")
def _upload_proto_file(self):
if self.proto_output_path == None:
raise ValueError("If you are running data generation on hadoop, "
"please set proto output path first")
if self._hadoop_host == None or self._hadoop_ugi == None or \
self._hadoop_path == None:
raise ValueError(
"If you are running data generation on hadoop, "
"please set hadoop_host, hadoop_path, hadoop_ugi first")
cmd = "$HADOOP_HOME/bin/hadoop fs" \
+ " -Dhadoop.job.ugi=" + self.hadoop_ugi \
+ " -Dfs.default.name=" + self.hadoop_host \
+ " -put " + self._proto_filename + " " + self._proto_output_path
os.system(cmd)
def set_hadoop_config(self,
hadoop_host=None,
hadoop_ugi=None,
proto_path=None):
'''
This function set hadoop configuration for map-reduce based data
generation.
Args:
hadoop_host(str): The host name of the hadoop. It should be
in this format: "hdfs://${HOST}:${PORT}".
hadoop_ugi(str): The ugi of the hadoop. It should be in this
format: "${USERNAME},${PASSWORD}".
proto_path(str): The hadoop path you want to upload the
protofile to.
'''
self.hadoop_host = hadoop_host
self.hadoop_ugi = hadoop_ugi
self.proto_output_path = proto_path
def run_from_memory(self, is_local=True, proto_filename='data_feed.proto'):
'''
This function generates data from memory, user needs to
define how to generate samples by define generate_sample
and generate_batch
'''
self._set_proto_filename(proto_filename)
batch_data = []
line_iter = self.generate_sample(None)
for user_parsed_line in line_iter():
if user_parsed_line == None:
continue
batch_data.append(user_parsed_line)
if len(batch_data) == self._batch_size:
batched_iter = self.generate_batch(batch_data)
for batched_line in batched_iter():
sys.stdout.write(self._gen_str(batched_line))
batch_data = []
if len(batch_data) > 0:
batched_iter = self.generate_batch(batch_data)
for batched_line in batched_iter():
sys.stdout.write(self._gen_str(batched_line))
if self.proto_info is not None:
with open(self._proto_filename, "w") as f:
f.write(self._get_proto_desc(self._proto_info))
if is_local == False:
self._upload_proto_file()
def run_from_stdin(self, is_local=True, proto_filename='data_feed.proto'):
'''
This function reads the data row from stdin, parses it with the
process function, and further parses the return value of the
process function with the _gen_str function. The parsed data will
be wrote to stdout and the corresponding protofile will be
generated. If local is set to False, the protofile will be
uploaded to hadoop.
Args:
is_local(bool): Whether user wants to run this function from local
proto_filename(str): The name of protofile. The default value
is "data_feed.proto". It is not
recommended to modify it.
'''
self._set_proto_filename(proto_filename)
batch_data = []
for line in sys.stdin:
line_iter = self.generate_sample(line)
for user_parsed_line in line_iter():
if user_parsed_line == None:
continue
batch_data.append(user_parsed_line)
if len(batch_data) == self._batch_size:
batched_iter = self.generate_batch(batch_data)
for batched_line in batched_iter():
sys.stdout.write(self._gen_str(batched_line))
batch_data = []
if len(batch_data) > 0:
batched_iter = self.generate_batch(batch_data)
for batched_line in batched_iter():
sys.stdout.write(self._gen_str(batched_line))
if self._proto_info is not None:
with open(self._proto_filename, "w") as f:
f.write(self._get_proto_desc(self._proto_info))
if is_local == False:
self._upload_proto_file()
class MultiSlotDataset(DatasetGenerator):
def _get_proto_desc(self, proto_info):
proto_str = "name: \"MultiSlotDataFeed\"\n" \
+ "batch_size: 32\nmulti_slot_desc {\n"
for elem in proto_info:
proto_str += " slots {\n" \
+ " name: \"%s\"\n" % elem[0]\
+ " type: \"%s\"\n" % elem[1]\
+ " is_dense: false\n" \
+ " is_used: false\n" \
+ " }\n"
proto_str += "}"
return proto_str
def generate_batch(self, samples):
super(MultiSlotDataset, self).generate_batch(samples)
def batch_iter():
for sample in samples:
yield sample
return batch_iter
def _gen_str(self, line):
if not isinstance(line, list) and not isinstance(line, tuple):
raise ValueError(
"the output of process() must be in list or tuple type")
output = ""
if self._proto_info is None:
self._proto_info = []
for item in line:
name, elements = item
if not isinstance(name, str):
raise ValueError("name%s must be in str type" % type(name))
if not isinstance(elements, list):
raise ValueError("elements%s must be in list type" %
type(elements))
if not elements:
raise ValueError(
"the elements of each field can not be empty, you need padding it in process()."
)
self._proto_info.append((name, "uint64"))
if output:
output += " "
output += str(len(elements))
for elem in elements:
if isinstance(elem, float):
self._proto_info[-1] = (name, "float")
elif not isinstance(elem, int) and not isinstance(elem,
long):
raise ValueError(
"the type of element%s must be in int or float" %
type(elem))
output += " " + str(elem)
else:
if len(line) != len(self._proto_info):
raise ValueError(
"the complete field set of two given line are inconsistent.")
for index, item in enumerate(line):
name, elements = item
if not isinstance(name, str):
raise ValueError("name%s must be in str type" % type(name))
if not isinstance(elements, list):
raise ValueError("elements%s must be in list type" %
type(elements))
if not elements:
raise ValueError(
"the elements of each field can not be empty, you need padding it in process()."
)
if name != self._proto_info[index][0]:
raise ValueError(
"the field name of two given line are not match: require<%s>, get<%d>."
% (self._proto_info[index][0], name))
if output:
output += " "
output += str(len(elements))
for elem in elements:
if self._proto_info[index][1] != "float":
if isinstance(elem, float):
self._proto_info[index] = (name, "float")
elif not isinstance(elem, int) and not isinstance(elem,
long):
raise ValueError(
"the type of element%s must be in int or float"
% type(elem))
output += " " + str(elem)
return output + "\n"
......@@ -46,10 +46,13 @@ from . import regularizer
from . import average
from . import metrics
from . import transpiler
from . import incubate
from . import distribute_lookup_table
from .param_attr import ParamAttr, WeightNormParamAttr
from .data_feeder import DataFeeder
from .core import LoDTensor, LoDTensorArray, CPUPlace, CUDAPlace, CUDAPinnedPlace, Scope, _Scope
from .incubate import fleet
from .incubate import data_generator
from .transpiler import DistributeTranspiler, \
memory_optimize, release_memory, DistributeTranspilerConfig
from .lod_tensor import create_lod_tensor, create_random_int_lodtensor
......
......@@ -25,6 +25,10 @@ class DeviceWorker(object):
Init.
"""
self.program_ = None
self.infer_ = None
def set_infer(self, infer=False):
self.infer_ = infer
def set_fleet_desc(self, fleet_desc):
"""
......@@ -125,8 +129,7 @@ class DownpourSGD(DeviceWorker):
for i in self.fleet_desc_.trainer_param.dense_table:
if i.table_id in dense_table_set:
dense_table = pull_thread.dense_table.add()
dense_table.dense_value_name.extend(
i.dense_variable_name)
dense_table.dense_value_name.extend(i.dense_variable_name)
dense_table.table_id = \
i.table_id
sparse_table = downpour.sparse_table.add()
......@@ -149,11 +152,13 @@ class DownpourSGD(DeviceWorker):
if i.table_id in dense_table_set:
dense_table = downpour.dense_table.add()
dense_table.table_id = i.table_id
dense_table.dense_value_name.extend(
i.dense_variable_name)
dense_table.dense_value_name.extend(i.dense_variable_name)
dense_table.dense_grad_name.extend(
i.dense_gradient_variable_name)
downpour.skip_ops.extend(self.fleet_desc_.trainer_param.skip_op)
if self.infer_:
downpour.push_dense = False
downpour.push_sparse = False
class DeviceWorkerFactory(object):
......
......@@ -612,16 +612,14 @@ class Executor(object):
def _run_inference(self, exe, feed):
return exe.run(feed)
def infer_from_dataset(self,
program=None,
dataset=None,
fetch_list=None,
scope=None,
thread=0,
opt_info=None):
pass
def _dump_debug_info(self, program=None, trainer=None):
with open(str(id(program)) + "_train_desc.prototxt", "w") as fout:
fout.write(trainer._desc())
if program._fleet_opt:
with open("fleet_desc.prototxt", "w") as fout:
fout.write(str(program._fleet_opt["fleet_desc"]))
def train_from_dataset(self,
def _prepare_trainer(self,
program=None,
dataset=None,
scope=None,
......@@ -648,23 +646,148 @@ class Executor(object):
if thread <= 0:
if dataset.thread_num <= 0:
raise RuntimeError(
"You should set thread num first, either in Dataset or in Executor.train_from_dataset"
)
"You should set thread num first, either in Dataset"
"or in Executor.train_from_dataset")
else:
trainer.set_thread(dataset.thread_num)
else:
trainer.set_thread(thread)
trainer.set_debug(debug)
trainer.set_fetch_var_and_info(fetch_list, fetch_info, print_period)
return trainer
def infer_from_dataset(self,
program=None,
dataset=None,
fetch_list=None,
scope=None,
thread=0,
opt_info=None):
"""
The document of infer_from_dataset is almost the same as
train_from_dataset, except that in distributed training,
push gradients will be disabled in infer_from_dataset.
infer_from_dataset() can be used for evaluation in multi-thread
very easily.
Args:
program(Program|CompiledProgram): the program that needs to be run,
if not provided, then default_main_program (not compiled) will be used.
dataset(paddle.fluid.Dataset): dataset created outside this function,
a user should provide a well-defined dataset before calling this function.
Please check the document of Dataset if needed.
scope(Scope): the scope used to run this program, you can switch it to different scope
for each run. default is global_scope
thread(int): number of thread a user wants to run in this function. The actual number
of thread will be min(Dataset.thread_num, thread)
debug(bool): whether a user wants to run train_from_dataset
fetch_list(Variable List): fetch variable list, each variable
will be printed during training
fetch_info(String List): print information for each variable
print_period(int): the number of mini-batches for each print
Example:
.. code-block:: python
import paddle.fluid as fluid
place = fluid.CPUPlace()
exe = fluid.Executor(place)
x = fluid.layers.data(name="x", type="int64")
y = fluid.layers.data(name="y", type="int64")
dataset = fluid.DatasetFactory().create_dataset()
dataset.set_use_var([x, y])
filelist = ["dataA.txt", "dataB.txt"]
dataset.set_filelist(filelist)
exe.run(fluid.default_startup_program())
exe.infer_from_dataset(program=fluid.default_main_program(),
dataset=dataset)
"""
trainer = self._prepare_trainer(
program=program,
dataset=dataset,
scope=scope,
thread=thread,
debug=debug,
fetch_list=fetch_list,
fetch_info=fetch_info,
print_period=print_period)
trainer.gen_trainer_desc()
trainer.set_infer(True)
dataset._prepare_to_run()
if debug:
#with open("train_desc.prototxt", "w") as fout:
with open(str(id(program)) + "_train_desc.prototxt", "w") as fout:
fout.write(trainer._desc())
if program._fleet_opt:
with open("fleet_desc.prototxt", "w") as fout:
fout.write(str(program._fleet_opt["fleet_desc"]))
self._dump_debug_info(program=program, trainer=trainer)
self._default_executor.run_from_dataset(program.desc, scope,
dataset.dataset,
trainer._desc())
def train_from_dataset(self,
program=None,
dataset=None,
scope=None,
thread=0,
debug=False,
fetch_list=None,
fetch_info=None,
print_period=100):
"""
Train from a pre-defined Dataset. Dataset is defined in paddle.fluid.dataset.
Given a program, either a program or compiled program, train_from_dataset will
consume all data samples in dataset. Input scope can be given by users. By default,
scope is global_scope(). The total number of thread run in training is `thread`.
Thread number used in training will be minimum value of threadnum in Dataset and
the value of thread in this interface. Debug can be set so that executor will display
Run-Time for all operators and the throughputs of current training task.
Note: train_from_dataset will destroy all resources created within executor for each run.
Args:
program(Program|CompiledProgram): the program that needs to be run,
if not provided, then default_main_program (not compiled) will be used.
dataset(paddle.fluid.Dataset): dataset created outside this function,
a user should provide a well-defined dataset before calling this function.
Please check the document of Dataset if needed.
scope(Scope): the scope used to run this program, you can switch it to different scope
for each run. default is global_scope
thread(int): number of thread a user wants to run in this function. The actual number
of thread will be min(Dataset.thread_num, thread)
debug(bool): whether a user wants to run train_from_dataset
fetch_list(Variable List): fetch variable list, each variable
will be printed during training
fetch_info(String List): print information for each variable
print_period(int): the number of mini-batches for each print
Example:
.. code-block:: python
import paddle.fluid as fluid
place = fluid.CPUPlace()
exe = fluid.Executor(place)
x = fluid.layers.data(name="x", type="int64")
y = fluid.layers.data(name="y", type="int64")
dataset = fluid.DatasetFactory().create_dataset()
dataset.set_use_var([x, y])
dataset.set_thread(2)
filelist = ["dataA.txt", "dataB.txt"]
dataset.set_filelist(filelist)
exe.run(fluid.default_startup_program())
exe.train_from_dataset(program=fluid.default_main_program(),
dataset=dataset)
"""
trainer = self._prepare_trainer(
program=program,
dataset=dataset,
scope=scope,
thread=thread,
debug=debug,
fetch_list=fetch_list,
fetch_info=fetch_info,
print_period=print_period)
trainer.gen_trainer_desc()
dataset._prepare_to_run()
if debug:
self._dump_debug_info(program=program, trainer=trainer)
self._default_executor.run_from_dataset(program.desc, scope,
dataset.dataset,
trainer._desc())
......@@ -35,6 +35,7 @@ class TrainerDesc(object):
self.fleet_desc_ = None
self.device_worker_ = None
self.program_ = None
self.infer_ = False
def set_fetch_var_and_info(self, fetch_vars, fetch_info, print_period):
for i, v in enumerate(fetch_vars):
......@@ -52,6 +53,9 @@ class TrainerDesc(object):
def set_device_worker(self, device_worker):
self.device_worker_ = device_worker
def set_infer(self, infer):
self.infer_ = infer
def set_fleet_desc(self, fleet_desc):
self.fleet_desc_ = fleet_desc
......@@ -77,6 +81,7 @@ class MultiTrainer(TrainerDesc):
def gen_trainer_desc(self):
super(MultiTrainer, self).gen_trainer_desc()
self.proto_desc.class_name = "MultiTrainer"
self.device_worker_.set_infer(self.infer_)
self.device_worker_.gen_worker_desc(self.proto_desc)
......@@ -94,5 +99,6 @@ class DistMultiTrainer(TrainerDesc):
self.proto_desc.class_name = "DistMultiTrainer"
if self.program_ == None:
print("None program")
self.device_worker_.set_infer(self.infer_)
self.device_worker_.set_program(self.program_)
self.device_worker_.gen_worker_desc(self.proto_desc)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册