提交 b7940c29 编写于 作者: X xjqbest 提交者: dongdaxiang

fix bugs in gen_worker_desc and set_filelist; add some docs

上级 68d7bf3d
......@@ -221,6 +221,11 @@ void DatasetImpl<T>::DestroyReaders() {
}
std::vector<std::shared_ptr<paddle::framework::DataFeed>>().swap(readers_);
VLOG(3) << "readers size: " << readers_.size();
// if memory_data_ is not empty, which means it's not InMemory mode,
// so the next epoch should read all data again
if (memory_data_.size() != 0) {
file_idx_ = 0;
}
}
template <typename T>
......
......@@ -295,8 +295,6 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
int offset = 2;
uint64_t fea_idx = 0u;
for (size_t i = 0; i < sparse_key_names.size(); ++i) {
LOG(WARNING) << "sparse key names[" << i << "]: " << sparse_key_names[i];
LOG(WARNING) << "sparse grad names[" << i << "]: " << sparse_grad_names[i];
Variable* g_var = scope.FindVar(sparse_grad_names[i]);
CHECK(g_var != nullptr) << "var[" << sparse_grad_names[i] << "] not found";
LoDTensor* g_tensor = g_var->GetMutable<LoDTensor>();
......@@ -313,7 +311,6 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
exit(-1);
}
int len = tensor->numel();
LOG(WARNING) << " tensor len: " << len;
int64_t* ids = tensor->data<int64_t>();
push_values->resize(fea_keys.size() + 1);
for (auto& t : *push_values) {
......@@ -325,16 +322,12 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
g += emb_dim;
continue;
}
LOG(WARNING) << "going to memcpy";
CHECK(fea_idx < (*push_values).size());
CHECK(fea_idx < fea_labels.size());
memcpy((*push_values)[fea_idx].data() + offset, g,
sizeof(float) * emb_dim);
LOG(WARNING) << "show";
(*push_values)[fea_idx][0] = 1.0f;
LOG(WARNING) << "click";
(*push_values)[fea_idx][1] = static_cast<float>(fea_labels[fea_idx]);
LOG(WARNING) << "offset";
g += emb_dim;
fea_idx++;
}
......
......@@ -19,10 +19,25 @@ __all__ = ['DatasetFactory']
class DatasetFactory(object):
"""
DatasetFactory is a factory which create dataset by its name,
you can create "QueueDataset" or "InMemoryDataset",
the default is "QueueDataset".
Example:
dataset = paddle.fluid.DatasetFactory().create_dataset("InMemoryDataset")
"""
def __init__(self):
    """
    Initialize the factory. No state is kept here; datasets are
    created on demand by create_dataset().
    """
    pass
def create_dataset(self, datafeed_class="QueueDataset"):
"""
Create "QueueDataset" or "InMemoryDataset",
the default is "QueueDataset".
"""
try:
dataset = globals()[datafeed_class]()
return dataset
......@@ -32,7 +47,13 @@ class DatasetFactory(object):
class DatasetBase(object):
"""
Base dataset class
"""
def __init__(self):
"""
Init
"""
# define class name here
# to decide whether we need create in memory instance
self.proto_desc = data_feed_pb2.DataFeedDesc()
......@@ -45,6 +66,12 @@ class DatasetBase(object):
Set pipe command of current dataset
A pipe command is a UNIX pipeline command that can be used only
Example:
>>> dataset.set_pipe_command("python my_script.py")
Args:
pipe_command: pipe command
"""
self.proto_desc.pipe_command = pipe_command
......@@ -53,8 +80,7 @@ class DatasetBase(object):
Set batch size. Will be effective during training
Example:
>>> data_feed = fluid.DataFeedDesc('data.proto')
>>> data_feed.set_batch_size(128)
>>> dataset.set_batch_size(128)
Args:
batch_size: batch size
......@@ -63,13 +89,40 @@ class DatasetBase(object):
self.proto_desc.batch_size = batch_size
def set_thread(self, thread_num):
    """
    Set the thread number, i.e. the number of readers.

    Example:
        >>> dataset.set_thread(12)

    Args:
        thread_num: number of reader threads
    """
    # forward to the underlying C++ dataset, and cache on the Python side
    self.dataset.set_thread_num(thread_num)
    self.thread_num = thread_num
def set_filelist(self, filelist):
    """
    Set the list of data files to be read on the current worker.

    Example:
        >>> dataset.set_filelist(['a.txt', 'b.txt'])

    Args:
        filelist: list of data file paths
    """
    self.dataset.set_filelist(filelist)
def set_use_var(self, var_list):
"""
Set Variables which you will use.
Example:
>>> dataset.set_use_var([data, label])
Args:
var_list: variable list
"""
multi_slot = self.proto_desc.multi_slot_desc
for var in var_list:
slot_var = multi_slot.slots.add()
......@@ -87,9 +140,23 @@ class DatasetBase(object):
)
def set_hdfs_config(self, fs_name, fs_ugi):
    """
    Set the HDFS configuration: fs name and ugi.

    Example:
        >>> dataset.set_hdfs_config("my_fs_name", "my_fs_ugi")

    Args:
        fs_name: HDFS fs name
        fs_ugi: HDFS ugi (user credentials)
    """
    self.dataset.set_hdfs_config(fs_name, fs_ugi)
def _prepare_to_run(self):
    """
    Push the current data_feed_desc down to the underlying dataset
    before load or shuffle. Internal helper; users do not need to
    call this directly.
    """
    self.dataset.set_data_feed_desc(self.desc())
def desc(self):
......@@ -97,8 +164,7 @@ class DatasetBase(object):
Returns a protobuf message for this DataFeedDesc
Example:
>>> data_feed = fluid.DataFeedDesc('data.proto')
>>> print(data_feed.desc())
>>> print(dataset.desc())
Returns:
A string message
......@@ -107,18 +173,50 @@ class DatasetBase(object):
class InMemoryDataset(DatasetBase):
"""
InMemoryDataset, it will load data into memory
and shuffle data before training
Example:
dataset = paddle.fluid.DatasetFactory().create_dataset("InMemoryDataset")
"""
def __init__(self):
    """
    Initialize the dataset and select the in-memory data feed
    ("MultiSlotInMemoryDataFeed") in the protobuf description.
    """
    super(InMemoryDataset, self).__init__()
    self.proto_desc.name = "MultiSlotInMemoryDataFeed"
def load_into_memory(self):
    """
    Load the configured file list into memory.

    Example:
        >>> dataset.load_into_memory()
    """
    # sync the protobuf description to the underlying dataset before reading
    self._prepare_to_run()
    self.dataset.load_into_memory()
def local_shuffle(self):
    """
    Shuffle the loaded data locally (within the current worker only).

    Example:
        >>> dataset.local_shuffle()
    """
    # NOTE(review): unlike load_into_memory, this does not call
    # _prepare_to_run first — presumably the desc was already synced
    # by a prior load; verify against callers.
    self.dataset.local_shuffle()
def global_shuffle(self, fleet=None):
"""
Global shuffle.
If you run distributed, you should pass fleet instead of None.
Example:
>>> dataset.global_shuffle(fleet)
Args:
fleet: fleet singleton. Default None.
"""
trainer_num = 1
if fleet is not None:
fleet.fleet_instance.role_maker_.barrier_worker()
......@@ -130,12 +228,27 @@ class InMemoryDataset(DatasetBase):
class QueueDataset(DatasetBase):
"""
QueueDataset, it will process data in a streaming manner.
Example:
dataset = paddle.fluid.DatasetFactory().create_dataset("QueueDataset")
"""
def __init__(self):
    """
    Initialize the dataset and select the streaming data feed
    ("MultiSlotDataFeed") in the protobuf description.
    """
    super(QueueDataset, self).__init__()
    self.proto_desc.name = "MultiSlotDataFeed"
def local_shuffle(self):
    """
    Local shuffle is a no-op for QueueDataset: data is consumed as a
    stream and never held in memory, so there is nothing to shuffle.
    Kept for interface compatibility with InMemoryDataset.
    """
    pass
def global_shuffle(self, fleet=None):
    """
    Global shuffle is a no-op for QueueDataset (streaming mode).

    Args:
        fleet: unused; kept for interface compatibility with
            InMemoryDataset.global_shuffle. Default None.
    """
    pass
......@@ -43,31 +43,6 @@ class DownpourSGD(DeviceWorker):
super(DownpourSGD, self).__init__()
def gen_worker_desc(self, trainer_desc):
trainer_desc.device_worker_name = "DownpourWorker"
pull_thread = trainer_desc.pull_dense_param
pull_thread.device_num = trainer_desc.thread_num
dense_table = pull_thread.dense_table.add()
dense_table.dense_value_name.extend(
self.fleet_desc_.trainer_param.dense_table[0].dense_variable_name)
dense_table.table_id = \
self.fleet_desc_.trainer_param.dense_table[0].table_id
downpour = trainer_desc.downpour_param
sparse_table = downpour.sparse_table.add()
sparse_table.table_id = \
self.fleet_desc_.trainer_param.sparse_table[0].table_id
sparse_table.sparse_key_name.extend(
self.fleet_desc_.trainer_param.sparse_table[0].slot_key)
sparse_table.sparse_value_name.extend(
self.fleet_desc_.trainer_param.sparse_table[0].slot_value)
sparse_table.sparse_grad_name.extend(
self.fleet_desc_.trainer_param.sparse_table[0].slot_gradient)
sparse_table.emb_dim = \
self.fleet_desc_.server_param.downpour_server_param.downpour_table_param[
0].accessor.fea_dim - 2
sparse_table.fea_dim = sparse_table.emb_dim + 2
# TODO(guru4elephant): hard code here, need to improve
sparse_table.label_var_name = "click"
dense_table_set = set()
program_id = str(id(self.program_))
if self.program_ == None:
......@@ -75,6 +50,7 @@ class DownpourSGD(DeviceWorker):
sys.exit(-1)
opt_info = self.program_._fleet_opt
program_configs = opt_info["program_configs"]
downpour = trainer_desc.downpour_param
for pid in program_configs:
if pid == program_id:
......@@ -92,6 +68,32 @@ class DownpourSGD(DeviceWorker):
dense_table_set.add(i)
break
trainer_desc.device_worker_name = "DownpourWorker"
pull_thread = trainer_desc.pull_dense_param
pull_thread.device_num = trainer_desc.thread_num
for i in self.fleet_desc_.trainer_param.dense_table:
if i.table_id in dense_table_set:
dense_table = pull_thread.dense_table.add()
dense_table.dense_value_name.extend(
i.dense_variable_name)
dense_table.table_id = \
i.table_id
sparse_table = downpour.sparse_table.add()
sparse_table.table_id = \
self.fleet_desc_.trainer_param.sparse_table[0].table_id
sparse_table.sparse_key_name.extend(
self.fleet_desc_.trainer_param.sparse_table[0].slot_key)
sparse_table.sparse_value_name.extend(
self.fleet_desc_.trainer_param.sparse_table[0].slot_value)
sparse_table.sparse_grad_name.extend(
self.fleet_desc_.trainer_param.sparse_table[0].slot_gradient)
sparse_table.emb_dim = \
self.fleet_desc_.server_param.downpour_server_param.downpour_table_param[
0].accessor.fea_dim - 2
sparse_table.fea_dim = sparse_table.emb_dim + 2
# TODO(guru4elephant): hard code here, need to improve
sparse_table.label_var_name = "click"
for i in self.fleet_desc_.trainer_param.dense_table:
if i.table_id in dense_table_set:
dense_table = downpour.dense_table.add()
......
......@@ -658,7 +658,8 @@ class Executor(object):
trainer.gen_trainer_desc()
dataset._prepare_to_run()
if debug:
with open("train_desc.prototxt", "w") as fout:
#with open("train_desc.prototxt", "w") as fout:
with open(str(id(program)) + "_train_desc.prototxt", "w") as fout:
fout.write(trainer._desc())
if program._fleet_opt:
with open("fleet_desc.prototxt", "w") as fout:
......
......@@ -146,7 +146,7 @@ class Fleet(object):
self.role_maker_.barrier_all()
self.role_maker_.barrier_worker()
if self.role_maker_.is_first_worker():
tables = self._dist_desc.trainer_param.dense_table._values
tables = self._dist_desc.trainer_param.dense_table
for prog in programs:
prog_id = str(id(prog))
prog_conf = self._opt_info['program_configs'][prog_id]
......@@ -156,8 +156,7 @@ class Fleet(object):
continue
for table_id in prog_conf[key]:
prog_tables[int(table_id)] = 0
for i in range(0, len(tables)):
table = tables[i]
for table in tables:
if int(table.table_id) not in prog_tables:
continue
var_name_list = []
......@@ -185,6 +184,12 @@ class Fleet(object):
"""
return self.role_maker_.server_num()
def get_worker_index(self):
    """
    Return the MPI rank (index) of the current worker.

    Returns:
        The worker index as reported by the role maker.
    """
    # fixed: removed stray C-style semicolon from the return statement
    return self.role_maker_.worker_index()
def is_worker(self):
"""
return whether current node is a worker
......@@ -306,3 +311,4 @@ init_pserver_model = fleet_instance.init_pserver_model
save_pserver_model = fleet_instance.save_pserver_model
worker_num = fleet_instance.get_worker_num
server_num = fleet_instance.get_server_num
worker_index = fleet_instance.get_worker_index
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册