未验证 提交 48669aa8 编写于 作者: X xujiaqi01 提交者: GitHub

fix several sparse table issuses (#20686)

* no longer need to define all embedding layers (no one less) of all slots in each program. make trainer_param repeated in ps.proto.
* add find_distributed_lookup_table_grads instead of hard code GRAD
* support embedding stop gradient. push sparse has error before fix this.* 
* fix fill sparse, skip slots which do not have embedding. each slot's embedding in a sparse table should be used in all training programs before fix this.
* fix pull sparse, skip slots which do not have embedding.
* fix collect feasign label info, skip slots which do not have embedding.
* support when there are multi sparse tables in one or multi training programs, each program can pull/push its own related sparse tables instead of all sparse tables.
* test=develop
上级 fa67e6e8
...@@ -211,6 +211,8 @@ class DownpourWorker : public HogwildWorker { ...@@ -211,6 +211,8 @@ class DownpourWorker : public HogwildWorker {
std::map<uint64_t, std::vector<std::string>> sparse_grad_names_; std::map<uint64_t, std::vector<std::string>> sparse_grad_names_;
std::map<uint64_t, std::vector<std::string>> dense_value_names_; std::map<uint64_t, std::vector<std::string>> dense_value_names_;
std::map<uint64_t, std::vector<std::string>> dense_grad_names_; std::map<uint64_t, std::vector<std::string>> dense_grad_names_;
// actually pushed feasign of each table
std::map<uint64_t, std::vector<uint64_t>> sparse_push_keys_;
// feasign // feasign
std::map<uint64_t, std::vector<uint64_t>> features_; std::map<uint64_t, std::vector<uint64_t>> features_;
......
...@@ -44,6 +44,7 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) { ...@@ -44,6 +44,7 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) {
sparse_grad_names_[table_id][j] = table.sparse_grad_name(j); sparse_grad_names_[table_id][j] = table.sparse_grad_name(j);
} }
label_var_name_[table_id] = table.label_var_name(); label_var_name_[table_id] = table.label_var_name();
sparse_push_keys_[table_id] = std::vector<uint64_t>();
} }
for (int i = 0; i < param_.dense_table_size(); ++i) { for (int i = 0; i < param_.dense_table_size(); ++i) {
...@@ -191,6 +192,14 @@ void DownpourWorker::CollectLabelInfo(size_t table_idx) { ...@@ -191,6 +192,14 @@ void DownpourWorker::CollectLabelInfo(size_t table_idx) {
LoDTensor* tensor = fea_var->GetMutable<LoDTensor>(); LoDTensor* tensor = fea_var->GetMutable<LoDTensor>();
CHECK(tensor != nullptr) << "tensor of var " CHECK(tensor != nullptr) << "tensor of var "
<< sparse_key_names_[table_id][i] << " is null"; << sparse_key_names_[table_id][i] << " is null";
// skip slots which do not have embedding
Variable* emb_var =
thread_scope_->FindVar(sparse_value_names_[table_id][i]);
if (emb_var == nullptr) {
continue;
}
int64_t* ids = tensor->data<int64_t>(); int64_t* ids = tensor->data<int64_t>();
size_t fea_idx = 0; size_t fea_idx = 0;
// tensor->lod()[0].size() == batch_size + 1 // tensor->lod()[0].size() == batch_size + 1
...@@ -237,6 +246,9 @@ void DownpourWorker::FillSparseValue(size_t table_idx) { ...@@ -237,6 +246,9 @@ void DownpourWorker::FillSparseValue(size_t table_idx) {
int64_t* ids = tensor->data<int64_t>(); int64_t* ids = tensor->data<int64_t>();
int len = tensor->numel(); int len = tensor->numel();
Variable* var_emb = thread_scope_->FindVar(emb_slot_name); Variable* var_emb = thread_scope_->FindVar(emb_slot_name);
if (var_emb == nullptr) {
continue;
}
LoDTensor* tensor_emb = var_emb->GetMutable<LoDTensor>(); LoDTensor* tensor_emb = var_emb->GetMutable<LoDTensor>();
float* ptr = tensor_emb->mutable_data<float>({len, table.emb_dim()}, float* ptr = tensor_emb->mutable_data<float>({len, table.emb_dim()},
platform::CPUPlace()); platform::CPUPlace());
...@@ -422,9 +434,9 @@ void DownpourWorker::TrainFilesWithProfiler() { ...@@ -422,9 +434,9 @@ void DownpourWorker::TrainFilesWithProfiler() {
} }
} }
timeline.Start(); timeline.Start();
fleet_ptr_->PullSparseVarsSync(*thread_scope_, tid, fleet_ptr_->PullSparseVarsSync(
sparse_key_names_[tid], &features_[tid], *thread_scope_, tid, sparse_key_names_[tid], &features_[tid],
&feature_values_[tid], table.fea_dim()); &feature_values_[tid], table.fea_dim(), sparse_value_names_[tid]);
timeline.Pause(); timeline.Pause();
pull_sparse_time += timeline.ElapsedSec(); pull_sparse_time += timeline.ElapsedSec();
total_time += timeline.ElapsedSec(); total_time += timeline.ElapsedSec();
...@@ -504,7 +516,7 @@ void DownpourWorker::TrainFilesWithProfiler() { ...@@ -504,7 +516,7 @@ void DownpourWorker::TrainFilesWithProfiler() {
*thread_scope_, tid, features_[tid], feature_labels_[tid], *thread_scope_, tid, features_[tid], feature_labels_[tid],
sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(), sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(),
&feature_grads_[tid], &push_sparse_status_, cur_batch, use_cvm_, &feature_grads_[tid], &push_sparse_status_, cur_batch, use_cvm_,
dump_slot_); dump_slot_, &sparse_push_keys_[tid]);
timeline.Pause(); timeline.Pause();
push_sparse_time += timeline.ElapsedSec(); push_sparse_time += timeline.ElapsedSec();
total_time += timeline.ElapsedSec(); total_time += timeline.ElapsedSec();
...@@ -646,9 +658,9 @@ void DownpourWorker::TrainFiles() { ...@@ -646,9 +658,9 @@ void DownpourWorker::TrainFiles() {
break; break;
} }
} }
fleet_ptr_->PullSparseVarsSync(*thread_scope_, tid, fleet_ptr_->PullSparseVarsSync(
sparse_key_names_[tid], &features_[tid], *thread_scope_, tid, sparse_key_names_[tid], &features_[tid],
&feature_values_[tid], table.fea_dim()); &feature_values_[tid], table.fea_dim(), sparse_value_names_[tid]);
CollectLabelInfo(i); CollectLabelInfo(i);
FillSparseValue(i); FillSparseValue(i);
auto nid_iter = std::find(sparse_value_names_[tid].begin(), auto nid_iter = std::find(sparse_value_names_[tid].begin(),
...@@ -707,7 +719,7 @@ void DownpourWorker::TrainFiles() { ...@@ -707,7 +719,7 @@ void DownpourWorker::TrainFiles() {
*thread_scope_, tid, features_[tid], feature_labels_[tid], *thread_scope_, tid, features_[tid], feature_labels_[tid],
sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(), sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(),
&feature_grads_[tid], &push_sparse_status_, cur_batch, use_cvm_, &feature_grads_[tid], &push_sparse_status_, cur_batch, use_cvm_,
dump_slot_); dump_slot_, &sparse_push_keys_[tid]);
} }
} }
......
...@@ -159,14 +159,16 @@ void FleetWrapper::CreateClient2ClientConnection() { ...@@ -159,14 +159,16 @@ void FleetWrapper::CreateClient2ClientConnection() {
void FleetWrapper::PullSparseVarsSync( void FleetWrapper::PullSparseVarsSync(
const Scope& scope, const uint64_t table_id, const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names, std::vector<uint64_t>* fea_keys, const std::vector<std::string>& var_names, std::vector<uint64_t>* fea_keys,
std::vector<std::vector<float>>* fea_values, int fea_value_dim) { std::vector<std::vector<float>>* fea_values, int fea_value_dim,
const std::vector<std::string>& var_emb_names) {
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
std::vector<::std::future<int32_t>> pull_sparse_status; std::vector<::std::future<int32_t>> pull_sparse_status;
pull_sparse_status.resize(0); pull_sparse_status.resize(0);
fea_keys->clear(); fea_keys->clear();
fea_keys->resize(0); fea_keys->resize(0);
fea_keys->reserve(MAX_FEASIGN_NUM); fea_keys->reserve(MAX_FEASIGN_NUM);
for (auto name : var_names) { for (size_t var_index = 0; var_index < var_names.size(); ++var_index) {
const std::string& name = var_names[var_index];
Variable* var = scope.FindVar(name); Variable* var = scope.FindVar(name);
if (var == nullptr) { if (var == nullptr) {
continue; continue;
...@@ -175,6 +177,14 @@ void FleetWrapper::PullSparseVarsSync( ...@@ -175,6 +177,14 @@ void FleetWrapper::PullSparseVarsSync(
CHECK(tensor != nullptr) << "tensor of var " << name << " is null"; CHECK(tensor != nullptr) << "tensor of var " << name << " is null";
int64_t* ids = tensor->data<int64_t>(); int64_t* ids = tensor->data<int64_t>();
int len = tensor->numel(); int len = tensor->numel();
// skip slots which do not have embedding
const std::string& emb_name = var_emb_names[var_index];
Variable* emb_var = scope.FindVar(emb_name);
if (emb_var == nullptr) {
continue;
}
for (auto i = 0u; i < len; ++i) { for (auto i = 0u; i < len; ++i) {
if (ids[i] == 0u) { if (ids[i] == 0u) {
continue; continue;
...@@ -314,7 +324,8 @@ void FleetWrapper::PushSparseVarsWithLabelAsync( ...@@ -314,7 +324,8 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
const std::vector<std::string>& sparse_grad_names, const int emb_dim, const std::vector<std::string>& sparse_grad_names, const int emb_dim,
std::vector<std::vector<float>>* push_values, std::vector<std::vector<float>>* push_values,
std::vector<::std::future<int32_t>>* push_sparse_status, std::vector<::std::future<int32_t>>* push_sparse_status,
const int batch_size, const bool use_cvm, const bool dump_slot) { const int batch_size, const bool use_cvm, const bool dump_slot,
std::vector<uint64_t>* sparse_push_keys) {
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
int offset = 2; int offset = 2;
int slot_offset = 0; int slot_offset = 0;
...@@ -332,12 +343,15 @@ void FleetWrapper::PushSparseVarsWithLabelAsync( ...@@ -332,12 +343,15 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
} }
CHECK_GE(grad_dim, 0); CHECK_GE(grad_dim, 0);
sparse_push_keys->clear();
sparse_push_keys->reserve(fea_keys.size() + 1);
push_values->resize(fea_keys.size() + 1); push_values->resize(fea_keys.size() + 1);
for (auto& t : *push_values) { for (auto& t : *push_values) {
t.resize(emb_dim + offset + slot_offset); t.resize(emb_dim + offset + slot_offset);
} }
uint64_t fea_idx = 0u; uint64_t fea_idx = 0u;
for (size_t i = 0; i < sparse_key_names.size(); ++i) { for (size_t i = 0;
i < sparse_key_names.size() && i < sparse_grad_names.size(); ++i) {
Variable* var = scope.FindVar(sparse_key_names[i]); Variable* var = scope.FindVar(sparse_key_names[i]);
if (var == nullptr) { if (var == nullptr) {
continue; continue;
...@@ -376,6 +390,7 @@ void FleetWrapper::PushSparseVarsWithLabelAsync( ...@@ -376,6 +390,7 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
g += emb_dim; g += emb_dim;
continue; continue;
} }
sparse_push_keys->push_back(ids[id_idx]);
CHECK(fea_idx < (*push_values).size()); CHECK(fea_idx < (*push_values).size());
CHECK(fea_idx < fea_labels.size()); CHECK(fea_idx < fea_labels.size());
...@@ -396,17 +411,43 @@ void FleetWrapper::PushSparseVarsWithLabelAsync( ...@@ -396,17 +411,43 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
fea_idx++; fea_idx++;
} }
} }
CHECK(fea_idx == fea_keys.size()) << "fea_idx: " << fea_idx // slots whose embedding has been stop gradient or
<< "features size: " << fea_keys.size(); // not involved in forward-backward
uint64_t no_grad_fea_num = 0u;
for (size_t i = sparse_grad_names.size(); i < sparse_key_names.size(); ++i) {
Variable* var = scope.FindVar(sparse_key_names[i]);
if (var == nullptr) {
continue;
}
LoDTensor* tensor = var->GetMutable<LoDTensor>();
if (tensor == nullptr) {
LOG(ERROR) << "tensor of var[" << sparse_key_names[i] << "] is null";
exit(-1);
}
int len = tensor->numel();
int64_t* ids = tensor->data<int64_t>();
for (auto id_idx = 0u; id_idx < len; ++id_idx) {
if (ids[id_idx] == 0) {
continue;
}
++no_grad_fea_num;
}
}
CHECK(fea_idx + no_grad_fea_num == fea_keys.size())
<< "fea_idx: " << fea_idx << " no_grad_fea_num: " << no_grad_fea_num
<< " features size: " << fea_keys.size();
CHECK(fea_idx == sparse_push_keys->size());
if (fea_idx == 0) {
return;
}
std::vector<float*> push_g_vec; std::vector<float*> push_g_vec;
for (auto i = 0u; i < fea_keys.size(); ++i) { for (auto i = 0u; i < sparse_push_keys->size(); ++i) {
push_g_vec.push_back((*push_values)[i].data()); push_g_vec.push_back((*push_values)[i].data());
} }
auto status = pslib_ptr_->_worker_ptr->push_sparse( auto status = pslib_ptr_->_worker_ptr->push_sparse(
table_id, fea_keys.data(), (const float**)push_g_vec.data(), table_id, sparse_push_keys->data(), (const float**)push_g_vec.data(),
fea_keys.size()); sparse_push_keys->size());
push_sparse_status->push_back(std::move(status)); push_sparse_status->push_back(std::move(status));
#endif #endif
} }
......
...@@ -77,7 +77,8 @@ class FleetWrapper { ...@@ -77,7 +77,8 @@ class FleetWrapper {
const std::vector<std::string>& var_names, const std::vector<std::string>& var_names,
std::vector<uint64_t>* fea_keys, std::vector<uint64_t>* fea_keys,
std::vector<std::vector<float>>* fea_values, std::vector<std::vector<float>>* fea_values,
int fea_dim); int fea_dim,
const std::vector<std::string>& var_emb_names);
void PullDenseVarsSync(const Scope& scope, const uint64_t table_id, void PullDenseVarsSync(const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names); const std::vector<std::string>& var_names);
...@@ -115,7 +116,8 @@ class FleetWrapper { ...@@ -115,7 +116,8 @@ class FleetWrapper {
const std::vector<std::string>& sparse_grad_names, const int emb_dim, const std::vector<std::string>& sparse_grad_names, const int emb_dim,
std::vector<std::vector<float>>* push_values, std::vector<std::vector<float>>* push_values,
std::vector<::std::future<int32_t>>* push_sparse_status, std::vector<::std::future<int32_t>>* push_sparse_status,
const int batch_size, const bool use_cvm, const bool dump_slot); const int batch_size, const bool use_cvm, const bool dump_slot,
std::vector<uint64_t>* sparse_push_keys);
// Push sparse variables to server in Async mode // Push sparse variables to server in Async mode
// Param<In>: scope, table_id, fea_keys, sparse_grad_names // Param<In>: scope, table_id, fea_keys, sparse_grad_names
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Defination of device workers."""
__all__ = ['DeviceWorker', 'Hogwild', 'DownpourSGD', 'Section'] __all__ = ['DeviceWorker', 'Hogwild', 'DownpourSGD', 'Section']
...@@ -23,9 +24,7 @@ class DeviceWorker(object): ...@@ -23,9 +24,7 @@ class DeviceWorker(object):
""" """
def __init__(self): def __init__(self):
""" """Init."""
Init.
"""
self._program = None self._program = None
self._infer = None self._infer = None
...@@ -75,9 +74,7 @@ class Hogwild(DeviceWorker): ...@@ -75,9 +74,7 @@ class Hogwild(DeviceWorker):
""" """
def __init__(self): def __init__(self):
""" """Init."""
Init.
"""
super(Hogwild, self).__init__() super(Hogwild, self).__init__()
def _gen_worker_desc(self, trainer_desc): def _gen_worker_desc(self, trainer_desc):
...@@ -140,23 +137,29 @@ class DownpourSGD(DeviceWorker): ...@@ -140,23 +137,29 @@ class DownpourSGD(DeviceWorker):
trainer_desc.device_worker_name = "DownpourWorker" trainer_desc.device_worker_name = "DownpourWorker"
pull_thread = trainer_desc.pull_dense_param pull_thread = trainer_desc.pull_dense_param
pull_thread.device_num = trainer_desc.thread_num pull_thread.device_num = trainer_desc.thread_num
for i in self._fleet_desc.trainer_param.dense_table: if opt_info.get("program_id_to_worker") is None:
raise ValueError("opt_info must have program_id_to_worker")
prog_id_to_worker = opt_info["program_id_to_worker"]
if prog_id_to_worker.get(program_id) is None:
raise ValueError("%s not found in program_id_to_worker" %
program_id)
worker = opt_info["program_id_to_worker"][program_id]
for i in worker.get_desc().dense_table:
if i.table_id in dense_table_set: if i.table_id in dense_table_set:
dense_table = pull_thread.dense_table.add() dense_table = pull_thread.dense_table.add()
dense_table.dense_value_name.extend(i.dense_variable_name) dense_table.dense_value_name.extend(i.dense_variable_name)
dense_table.table_id = \ dense_table.table_id = \
i.table_id i.table_id
sparse_len = len(self._fleet_desc.trainer_param.sparse_table) sparse_len = len(worker.get_desc().sparse_table)
for i in range(sparse_len): for i in range(sparse_len):
sparse_table = downpour.sparse_table.add() sparse_table = downpour.sparse_table.add()
sparse_table.table_id = \ sparse_table.table_id = worker.get_desc().sparse_table[i].table_id
self._fleet_desc.trainer_param.sparse_table[i].table_id sparse_table.sparse_key_name.extend(worker.get_desc().sparse_table[
sparse_table.sparse_key_name.extend( i].slot_key)
self._fleet_desc.trainer_param.sparse_table[i].slot_key) sparse_table.sparse_value_name.extend(worker.get_desc()
sparse_table.sparse_value_name.extend( .sparse_table[i].slot_value)
self._fleet_desc.trainer_param.sparse_table[i].slot_value) sparse_table.sparse_grad_name.extend(worker.get_desc().sparse_table[
sparse_table.sparse_grad_name.extend( i].slot_gradient)
self._fleet_desc.trainer_param.sparse_table[i].slot_gradient)
if opt_info["use_cvm"]: if opt_info["use_cvm"]:
sparse_table.emb_dim = \ sparse_table.emb_dim = \
self._fleet_desc.server_param.downpour_server_param.downpour_table_param[ self._fleet_desc.server_param.downpour_server_param.downpour_table_param[
...@@ -173,28 +176,24 @@ class DownpourSGD(DeviceWorker): ...@@ -173,28 +176,24 @@ class DownpourSGD(DeviceWorker):
for i in opt_info["stat_var_names"]: for i in opt_info["stat_var_names"]:
downpour.stat_var_names.extend([i]) downpour.stat_var_names.extend([i])
for i in self._fleet_desc.trainer_param.dense_table: for i in worker.get_desc().dense_table:
if i.table_id in dense_table_set: if i.table_id in dense_table_set:
dense_table = downpour.dense_table.add() dense_table = downpour.dense_table.add()
dense_table.table_id = i.table_id dense_table.table_id = i.table_id
dense_table.dense_value_name.extend(i.dense_variable_name) dense_table.dense_value_name.extend(i.dense_variable_name)
dense_table.dense_grad_name.extend( dense_table.dense_grad_name.extend(
i.dense_gradient_variable_name) i.dense_gradient_variable_name)
downpour.skip_ops.extend(self._fleet_desc.trainer_param.skip_op) downpour.skip_ops.extend(worker.get_desc().skip_op)
if self._infer: if self._infer:
downpour.push_dense = False downpour.push_dense = False
downpour.push_sparse = False downpour.push_sparse = False
class Section(DeviceWorker): class Section(DeviceWorker):
""" """SectionWorker."""
SectionWorker
"""
def __init__(self): def __init__(self):
""" """Init."""
Init.
"""
super(Section, self).__init__() super(Section, self).__init__()
def _gen_worker_desc(self, trainer_desc): def _gen_worker_desc(self, trainer_desc):
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
"""Defination of PSLib."""
import os import os
import sys import sys
...@@ -25,6 +26,8 @@ from paddle.fluid.incubate.fleet.base.role_maker import MPISymetricRoleMaker ...@@ -25,6 +26,8 @@ from paddle.fluid.incubate.fleet.base.role_maker import MPISymetricRoleMaker
class PSLib(Fleet): class PSLib(Fleet):
"""PSLib class."""
def __init__(self): def __init__(self):
super(PSLib, self).__init__(Mode.PSLIB) super(PSLib, self).__init__(Mode.PSLIB)
self._opt_info = None self._opt_info = None
...@@ -89,7 +92,10 @@ class PSLib(Fleet): ...@@ -89,7 +92,10 @@ class PSLib(Fleet):
# barrier for init model # barrier for init model
self._role_maker._barrier_worker() self._role_maker._barrier_worker()
if self._role_maker.is_first_worker(): if self._role_maker.is_first_worker():
tables = self._dist_desc.trainer_param.dense_table tables = []
for tp in self._dist_desc.trainer_param:
for i in tp.dense_table:
tables.append(i)
for prog, scope in zip(self._main_programs, self._scopes): for prog, scope in zip(self._main_programs, self._scopes):
prog_id = str(id(prog)) prog_id = str(id(prog))
prog_conf = self._opt_info['program_configs'][prog_id] prog_conf = self._opt_info['program_configs'][prog_id]
...@@ -244,7 +250,9 @@ class PSLib(Fleet): ...@@ -244,7 +250,9 @@ class PSLib(Fleet):
3 means save batch model. 3 means save batch model.
Example: Example:
>>> fleet.save_persistables(dirname="/you/path/to/model", mode = 0) .. code-block:: python
fleet.save_persistables(dirname="/you/path/to/model", mode = 0)
""" """
mode = kwargs.get("mode", 0) mode = kwargs.get("mode", 0)
...@@ -260,15 +268,20 @@ class PSLib(Fleet): ...@@ -260,15 +268,20 @@ class PSLib(Fleet):
when using fleet, it will save sparse cache table when using fleet, it will save sparse cache table
Args: Args:
executor(Executor): fluid executor
dirname(str): save path. It can be hdfs/afs path or local path dirname(str): save path. It can be hdfs/afs path or local path
main_program(Program): fluid program, default None main_program(Program): fluid program, default None
kwargs: use define property, current support following kwargs: use define property, current support following
mode(int): define for feature extension in the future, mode(int): define for feature extension in the future,
currently no use, will pass a default value 0 currently no use, will pass a default value 0
Returns:
feasign_num(int): cache feasign num
Example: Example:
.. code-block:: python .. code-block:: python
>>> fleet.save_cache_model(None, dirname="/you/path/to/model", mode = 0)
fleet.save_cache_model(None, dirname="/you/path/to/model", mode = 0)
""" """
mode = kwargs.get("mode", 0) mode = kwargs.get("mode", 0)
...@@ -304,8 +317,12 @@ class PSLib(Fleet): ...@@ -304,8 +317,12 @@ class PSLib(Fleet):
""" """
self._role_maker._barrier_worker() self._role_maker._barrier_worker()
if self._role_maker.is_first_worker(): if self._role_maker.is_first_worker():
for i in self._opt_info["fleet_desc"].trainer_param.sparse_table: tables = []
self._fleet_ptr.shrink_sparse_table(i.table_id) for tp in self._opt_info["fleet_desc"].trainer_param:
for i in tp.sparse_table:
tables.append(i.table_id)
for i in list(set(tables)):
self._fleet_ptr.shrink_sparse_table(i)
self._role_maker._barrier_worker() self._role_maker._barrier_worker()
def shrink_dense_table(self, decay, emb_dim=11, scope=None, table_id=None): def shrink_dense_table(self, decay, emb_dim=11, scope=None, table_id=None):
...@@ -330,19 +347,20 @@ class PSLib(Fleet): ...@@ -330,19 +347,20 @@ class PSLib(Fleet):
scope = fluid.global_scope() scope = fluid.global_scope()
self._role_maker._barrier_worker() self._role_maker._barrier_worker()
if self._role_maker.is_first_worker(): if self._role_maker.is_first_worker():
for i in self._opt_info["fleet_desc"].trainer_param.dense_table: for tp in self._opt_info["fleet_desc"].trainer_param:
if table_id is not None and table_id != i.table_id: for i in tp.dense_table:
continue if table_id is not None and table_id != i.table_id:
var_list = [var for var in i.dense_variable_name] continue
skip = False var_list = [var for var in i.dense_variable_name]
for var in var_list: skip = False
if scope.find_var(var) is None: for var in var_list:
skip = True if scope.find_var(var) is None:
break skip = True
if skip: break
continue if skip:
self._fleet_ptr.shrink_dense_table(i.table_id, scope, var_list, continue
decay, emb_dim) self._fleet_ptr.shrink_dense_table(i.table_id, scope,
var_list, decay, emb_dim)
self._role_maker._barrier_worker() self._role_maker._barrier_worker()
def clear_model(self): def clear_model(self):
...@@ -476,20 +494,21 @@ class PSLib(Fleet): ...@@ -476,20 +494,21 @@ class PSLib(Fleet):
if ret != 0: if ret != 0:
raise RuntimeError("download model proto file failed") raise RuntimeError("download model proto file failed")
model_proto_file = dest model_proto_file = dest
for i in self._opt_info["fleet_desc"].trainer_param.dense_table: for tp in self._opt_info["fleet_desc"].trainer_param:
if table_id is not None and table_id != i.table_id: for i in tp.dense_table:
continue if table_id is not None and table_id != i.table_id:
table_var_names = [var for var in i.dense_variable_name] continue
skip = False table_var_names = [var for var in i.dense_variable_name]
for var in table_var_names: skip = False
if scope.find_var(var) is None: for var in table_var_names:
skip = True if scope.find_var(var) is None:
break skip = True
if skip: break
continue if skip:
self._fleet_ptr.load_from_paddle_model( continue
scope, table_id, var_names, model_path, model_proto_file, self._fleet_ptr.load_from_paddle_model(
table_var_names, load_combine) scope, table_id, var_names, model_path,
model_proto_file, table_var_names, load_combine)
self._role_maker._barrier_worker() self._role_maker._barrier_worker()
def _set_opt_info(self, opt_info): def _set_opt_info(self, opt_info):
......
...@@ -10,13 +10,15 @@ ...@@ -10,13 +10,15 @@
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
"""Defination of Server and Worker."""
from . import ps_pb2 as pslib from . import ps_pb2 as pslib
class Server(object): class Server(object):
""" """
A Server basic class. A Server basic class
it's a base class, does not have implementation
""" """
def __init__(self): def __init__(self):
...@@ -26,6 +28,7 @@ class Server(object): ...@@ -26,6 +28,7 @@ class Server(object):
class Worker(object): class Worker(object):
""" """
A Worker basic class. A Worker basic class.
it's a base class, does not have implementation
""" """
def __init__(self): def __init__(self):
...@@ -169,7 +172,10 @@ class DownpourServer(Server): ...@@ -169,7 +172,10 @@ class DownpourServer(Server):
""" """
Args: Args:
table_id(int): id of sparse params table table_id(int): id of sparse params table
strategy(dict): the dense config dict. param_var(list): param vars
grad_var(list): param grad vars
strategy(dict): the dense config dict
sparse_table_names(list): sparse table names
Returns: Returns:
return None return None
""" """
...@@ -230,7 +236,11 @@ class DownpourServer(Server): ...@@ -230,7 +236,11 @@ class DownpourServer(Server):
""" """
Args: Args:
table_id(int): id of datanorm table table_id(int): id of datanorm table
strategy(dict): the datanorm config dict. learning_rate(float): the learning rate used to update parameters
param_var(list): param vars
grad_var(list): param grad vars
strategy(dict): the datanorm config dict
sparse_table_names(list): sparse table names
Returns: Returns:
return None return None
""" """
...@@ -296,43 +306,60 @@ class DownpourWorker(Worker): ...@@ -296,43 +306,60 @@ class DownpourWorker(Worker):
self.window = window self.window = window
self._worker = pslib.DownpourTrainerParameter() self._worker = pslib.DownpourTrainerParameter()
def add_sparse_table(self, table_id, slot_key_vars, slot_value_vars): def add_sparse_table(self,
table_id,
slot_key_vars,
slot_value_vars,
slot_value_grads=None):
""" """
Args: Args:
table_id(int): id of sparse params table table_id(int): id of sparse params table
slot_key_vars(string): slot key id slot_key_vars(list): slot key id
slot_value_var(string): slot key value after embedding slot_value_vars(list): slot key value after embedding
slot_value_grads(list): grad of all params, default is None
Returns: Returns:
return None return None
""" """
if slot_value_grads is None:
slot_value_grad_names = \
[var.name + "@GRAD" for var in slot_value_vars]
else:
value_to_key = {}
for i in range(len(slot_key_vars)):
value_to_key[slot_value_vars[i].name] = slot_key_vars[i]
slot_value_grad_names = []
all_grad_names = [var.name for var in slot_value_grads]
for var in slot_value_vars:
if var.name + "@GRAD" in all_grad_names:
slot_value_grad_names.append(var.name + "@GRAD")
sorted_slot_value_vars = [i for i in slot_value_vars if \
i.name + "@GRAD" in slot_value_grad_names]
sorted_slot_value_vars += [i for i in slot_value_vars if \
i.name + "@GRAD" not in slot_value_grad_names]
sorted_slot_key_vars = \
[value_to_key[v.name] for v in sorted_slot_value_vars]
target_table = None
for table in self._worker.sparse_table: for table in self._worker.sparse_table:
if table.table_id == table_id: if table.table_id == table_id:
if [var.name for var in slot_key_vars keys = self._worker.sparse_table[table_id].slot_key
] == self._worker.sparse_table[table_id].slot_key: key_names = [var.name for var in sorted_slot_key_vars]
if [var.name for var in slot_value_vars for key_name in key_names:
] == self._worker.sparse_table[table_id].slot_value: if key_name not in keys:
if [ raise ValueError("sparse table %s slot_key error" %
var.name + "@GRAD" for var in slot_value_vars
] == self._worker.sparse_table[table_id].slot_gradient:
return
else:
raise ValueError(
"sparse table %s slot_gradient error" %
table_id)
else:
raise ValueError("sparse table %s slot_value error" %
table_id) table_id)
else: target_table = table
raise ValueError("sparse table %s slot_key error" % break
table_id)
table = target_table
if table is not None:
self._worker.sparse_table.remove(table)
table = self._worker.sparse_table.add() table = self._worker.sparse_table.add()
table.table_id = table_id table.table_id = table_id
table.slot_key.extend([var.name for var in slot_key_vars]) table.slot_key.extend([var.name for var in sorted_slot_key_vars])
table.slot_value.extend([var.name for var in slot_value_vars]) table.slot_value.extend([var.name for var in sorted_slot_value_vars])
table.slot_gradient.extend( table.slot_gradient.extend(slot_value_grad_names)
[var.name + "@GRAD" for var in slot_value_vars])
def add_dense_table(self, table_id, learning_rate, param_vars, grad_vars, def add_dense_table(self, table_id, learning_rate, param_vars, grad_vars,
dense_start_table_id, sparse_table_names): dense_start_table_id, sparse_table_names):
...@@ -341,8 +368,10 @@ class DownpourWorker(Worker): ...@@ -341,8 +368,10 @@ class DownpourWorker(Worker):
table_id(int): id of sparse params table table_id(int): id of sparse params table
learning_rate(float): the learning rate used to update parameters. \ learning_rate(float): the learning rate used to update parameters. \
Can be a float value Can be a float value
param_var(list): all dense param. it is a list. param_vars(list): all dense param. it is a list.
grad_var(list): all dense grad parm it is a list. grad_vars(list): all dense grad parm it is a list.
dense_start_table_id(int): dense table start index
sparse_table_names(list): sparse table names
Returns: Returns:
return None return None
""" """
...@@ -365,21 +394,19 @@ class DownpourWorker(Worker): ...@@ -365,21 +394,19 @@ class DownpourWorker(Worker):
for table in self._worker.dense_table: for table in self._worker.dense_table:
if table.table_id == table_id: if table.table_id == table_id:
desc_dense_param_name = list(self._worker.dense_table[ desc_dense_param_name = list(table.dense_variable_name)
table_id - dense_start_table_id].dense_variable_name)
desc_dense_param_name.sort() desc_dense_param_name.sort()
if dense_param_name == desc_dense_param_name: if dense_param_name == desc_dense_param_name:
desc_dense_grad_name = list(self._worker.dense_table[ desc_dense_grad_name = list(
table_id - dense_start_table_id] table.dense_gradient_variable_name)
.dense_gradient_variable_name)
desc_dense_grad_name.sort() desc_dense_grad_name.sort()
if dense_grad_name == desc_dense_grad_name: if dense_grad_name == desc_dense_grad_name:
return return
else: else:
raise ValueError( raise ValueError(
"dense table %s dense_gradient_variable_name error" "dense table %s dense_gradient_variable_name "
% table_id) "error" % table_id)
else: else:
raise ValueError( raise ValueError(
"dense table %s dense_variable_name error" % table_id) "dense table %s dense_variable_name error" % table_id)
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Optimizer Factory."""
__all__ = ["DistributedAdam"] __all__ = ["DistributedAdam"]
import paddle.fluid as fluid import paddle.fluid as fluid
...@@ -18,11 +19,17 @@ from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table ...@@ -18,11 +19,17 @@ from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table
from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table_inputs from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table_inputs
from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table_outputs from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table_outputs
from google.protobuf import text_format from google.protobuf import text_format
from collections import OrderedDict
from .node import DownpourWorker, DownpourServer from .node import DownpourWorker, DownpourServer
from . import ps_pb2 as pslib from . import ps_pb2 as pslib
class DistributedOptimizerImplBase(object): class DistributedOptimizerImplBase(object):
"""
DistributedOptimizerImplBase
base class of optimizers
"""
def __init__(self, optimizer): def __init__(self, optimizer):
self._optimizer = optimizer self._optimizer = optimizer
self._learning_rate = optimizer._learning_rate self._learning_rate = optimizer._learning_rate
...@@ -33,10 +40,23 @@ class DistributedOptimizerImplBase(object): ...@@ -33,10 +40,23 @@ class DistributedOptimizerImplBase(object):
startup_program=None, startup_program=None,
parameter_list=None, parameter_list=None,
no_grad_set=None): no_grad_set=None):
"""
Args:
losses(Variable): loss variable defined by user
startup_program(Program): startup program that defined by user
parameter_list(str list): parameter names defined by users
no_grad_set(set): a set of variables that is defined by users
so that these variables do not need gradient computation
"""
pass pass
class DistributedAdam(DistributedOptimizerImplBase): class DistributedAdam(DistributedOptimizerImplBase):
"""
DistributedAdam
adam optimizer in distributed training
"""
def __init__(self, optimizer): def __init__(self, optimizer):
# todo(guru4elephant): add more optimizers here as argument # todo(guru4elephant): add more optimizers here as argument
# todo(guru4elephant): make learning_rate as a variable # todo(guru4elephant): make learning_rate as a variable
...@@ -53,10 +73,10 @@ class DistributedAdam(DistributedOptimizerImplBase): ...@@ -53,10 +73,10 @@ class DistributedAdam(DistributedOptimizerImplBase):
Find input variable of distribute lookup table in program. Find input variable of distribute lookup table in program.
We could support multi-distribute table now. We could support multi-distribute table now.
Args: Args:
program(Program): given program, locate distributed lookup table program(Program): given program, locate distributed lookup table
table_name(str): given table names that is found beforehand table_name(str): given table names that is found beforehand
Returns: Returns:
inputs inputs
""" """
local_vars = program.current_block().vars local_vars = program.current_block().vars
inputs_dict = dict() inputs_dict = dict()
...@@ -75,10 +95,10 @@ class DistributedAdam(DistributedOptimizerImplBase): ...@@ -75,10 +95,10 @@ class DistributedAdam(DistributedOptimizerImplBase):
Find output variable of distribute lookup table in program. Find output variable of distribute lookup table in program.
We could support multi-distribute table now. We could support multi-distribute table now.
Args: Args:
program(Program): given program, locate distributed lookup table programs(Program): given program, locate distributed lookup table
table_name(str): given table name that is found beforehand table_name(str): given table name that is found beforehand
Returns: Returns:
outputs outputs
""" """
local_vars = program.current_block().vars local_vars = program.current_block().vars
outputs_dict = dict() outputs_dict = dict()
...@@ -92,6 +112,19 @@ class DistributedAdam(DistributedOptimizerImplBase): ...@@ -92,6 +112,19 @@ class DistributedAdam(DistributedOptimizerImplBase):
[local_vars[name] for name in op.output("Out")]) [local_vars[name] for name in op.output("Out")])
return outputs_dict return outputs_dict
def _find_distributed_lookup_table_grads(self, program, table_names):
local_vars = program.current_block().vars
grads_dict = dict()
for table_name in table_names:
grads_dict[table_name] = []
for op in program.global_block().ops:
if op.type == "lookup_table_grad" and op.input("W")[
0] in table_names:
grads_dict[op.input("W")[0]].extend(
[local_vars[name] for name in op.input("Out@GRAD")])
return grads_dict
def _find_multi_distributed_lookup_table(self, losses): def _find_multi_distributed_lookup_table(self, losses):
""" """
find multi-sparse-table find multi-sparse-table
...@@ -125,17 +158,57 @@ class DistributedAdam(DistributedOptimizerImplBase): ...@@ -125,17 +158,57 @@ class DistributedAdam(DistributedOptimizerImplBase):
Returns: Returns:
[optimize_ops, grads_and_weights] [optimize_ops, grads_and_weights]
""" """
# sparse table names of each program
prog_id_to_sparse_table = OrderedDict()
# inputs_dict and outputs_dict of sparse tables of each program
prog_id_to_inputs_dict = OrderedDict()
prog_id_to_outputs_dict = OrderedDict()
# related to PSParameter
ps_param = pslib.PSParameter()
# related to ServerParameter
server = DownpourServer()
# program to worker (related to DownpourTrainerParameter)
prog_id_to_worker = OrderedDict()
# param_grads of each program
prog_id_to_param_grads = OrderedDict()
# sparse_grads of each program
prog_id_to_sparse_grads = OrderedDict()
sparse_table_names = self._find_multi_distributed_lookup_table(losses) sparse_table_to_index = OrderedDict()
inputs_dict = self._find_distributed_lookup_table_inputs( sparse_table_index = 0
losses[0].block.program, sparse_table_names) for loss in losses:
sparse_table = self._find_multi_distributed_lookup_table([loss])
prog_id = str(id(loss.block.program))
prog_id_to_sparse_table[prog_id] = sparse_table
outputs_dict = self._find_distributed_lookup_table_outputs( # get sparse_table_to_index
losses[0].block.program, sparse_table_names) for tn in sparse_table:
if sparse_table_to_index.get(tn) is None:
sparse_table_to_index[tn] = sparse_table_index
sparse_table_index += 1
# get inputs_dict
inputs_dict = self._find_distributed_lookup_table_inputs(
loss.block.program, sparse_table)
prog_id_to_inputs_dict[prog_id] = inputs_dict
# get outputs_dict
outputs_dict = self._find_distributed_lookup_table_outputs(
loss.block.program, sparse_table)
prog_id_to_outputs_dict[prog_id] = outputs_dict
prog_id_to_worker[prog_id] = DownpourWorker(self._window)
# param_grads of program
params_grads = sorted(
fluid.backward.append_backward(loss, parameter_list,
no_grad_set),
key=lambda x: x[0].name)
prog_id_to_param_grads[prog_id] = params_grads
grads_dict = self._find_distributed_lookup_table_grads(
loss.block.program, sparse_table)
prog_id_to_sparse_grads[prog_id] = grads_dict
ps_param = pslib.PSParameter()
server = DownpourServer()
worker = DownpourWorker(self._window)
# if user specify a fleet_desc.prototxt file, then load the file # if user specify a fleet_desc.prototxt file, then load the file
# instead of creating default fleet_desc.prototxt. # instead of creating default fleet_desc.prototxt.
# user can specify server_param or trainer_param or fs_client_param. # user can specify server_param or trainer_param or fs_client_param.
...@@ -144,37 +217,60 @@ class DistributedAdam(DistributedOptimizerImplBase): ...@@ -144,37 +217,60 @@ class DistributedAdam(DistributedOptimizerImplBase):
with open(fleet_desc_file) as f: with open(fleet_desc_file) as f:
text_format.Merge(f.read(), ps_param) text_format.Merge(f.read(), ps_param)
server.get_desc().CopyFrom(ps_param.server_param) server.get_desc().CopyFrom(ps_param.server_param)
worker.get_desc().CopyFrom(ps_param.trainer_param) if len(ps_param.trainer_param) == 1:
for k in prog_id_to_worker:
prog_id_to_worker[k].get_desc().CopyFrom(
ps_param.trainer_param[0])
else:
if len(ps_param.trainer_param) != len(prog_id_to_worker):
raise ValueError(
"trainer param size != program size, %s vs %s" %
(len(ps_param.trainer_param), len(prog_id_to_worker)))
idx = 0
# prog_id_to_worker is OrderedDict
for k in prog_id_to_worker:
prog_id_to_worker[k].get_desc().CopyFrom(
ps_param.trainer_param[idx])
idx += 1
sparse_table_index = 0 # ServerParameter add all sparse tables
for tn in sparse_table_names: for tn in sparse_table_to_index:
sparse_table_index = sparse_table_to_index[tn]
if strategy.get(tn) is not None: if strategy.get(tn) is not None:
server.add_sparse_table(sparse_table_index, strategy[tn]) server.add_sparse_table(sparse_table_index, strategy[tn])
else: else:
server.add_sparse_table(sparse_table_index, None) server.add_sparse_table(sparse_table_index, None)
worker.add_sparse_table(sparse_table_index, inputs_dict[tn],
outputs_dict[tn])
sparse_table_index += 1
dense_start_table_id = sparse_table_index # each DownpourTrainerParameter add its own sparse tables
dense_table_index = sparse_table_index for loss in losses:
program_configs = {} prog_id = str(id(loss.block.program))
param_grads_list = [] worker = prog_id_to_worker[prog_id]
inputs_dict = prog_id_to_inputs_dict[prog_id]
outputs_dict = prog_id_to_outputs_dict[prog_id]
for tn in prog_id_to_sparse_table[prog_id]:
sparse_table_index = sparse_table_to_index[tn]
grads_dict = prog_id_to_sparse_grads[prog_id]
worker.add_sparse_table(sparse_table_index, inputs_dict[tn],
outputs_dict[tn], grads_dict[tn])
dense_start_table_id = len(sparse_table_to_index)
dense_table_index = len(sparse_table_to_index)
program_configs = {}
# ServerParameter add all dense tables
# each DownpourTrainerParameter add its own dense tables
for loss_index in range(len(losses)): for loss_index in range(len(losses)):
program_id = str(id(losses[loss_index].block.program)) program_id = str(id(losses[loss_index].block.program))
worker = prog_id_to_worker[program_id]
sparse_table_names = prog_id_to_sparse_table[program_id]
sparse_table_index = \
[sparse_table_to_index[i] for i in sparse_table_names]
program_configs[program_id] = { program_configs[program_id] = {
"pull_sparse": "pull_sparse": [t_index for t_index in sparse_table_index],
[t_index for t_index in range(sparse_table_index)], "push_sparse": [t_index for t_index in sparse_table_index]
"push_sparse":
[t_index for t_index in range(sparse_table_index)]
} }
params_grads = sorted( params_grads = prog_id_to_param_grads[program_id]
fluid.backward.append_backward(losses[loss_index],
parameter_list, no_grad_set),
key=lambda x: x[0].name)
param_grads_list.append(params_grads)
params = [] params = []
grads = [] grads = []
data_norm_params = [] data_norm_params = []
...@@ -230,15 +326,22 @@ class DistributedAdam(DistributedOptimizerImplBase): ...@@ -230,15 +326,22 @@ class DistributedAdam(DistributedOptimizerImplBase):
program_configs[program_id]["push_dense"].extend( program_configs[program_id]["push_dense"].extend(
[dense_table_index]) [dense_table_index])
dense_table_index += 1 dense_table_index += 1
# Todo(guru4elephant): figure out how to support more sparse parameters
# currently only support lookup_table
worker_skipped_ops = ["lookup_table", "lookup_table_grad"]
if len(worker.get_desc().skip_op) == 0:
worker.get_desc().skip_op.extend(worker_skipped_ops)
ps_param.server_param.CopyFrom(server.get_desc()) ps_param.server_param.CopyFrom(server.get_desc())
ps_param.trainer_param.CopyFrom(worker.get_desc()) # prog_id_to_worker is OrderedDict
# Todo(guru4elephant): figure out how to support more sparse parameters if len(ps_param.trainer_param) == 0:
# currently only support lookup_table for k in prog_id_to_worker:
worker_skipped_ops = ["lookup_table", "lookup_table_grad"] tp = ps_param.trainer_param.add()
if len(ps_param.trainer_param.skip_op) == 0: tp.CopyFrom(prog_id_to_worker[k].get_desc())
ps_param.trainer_param.skip_op.extend(worker_skipped_ops)
opt_info = {} opt_info = {}
opt_info["program_id_to_worker"] = prog_id_to_worker
opt_info["program_configs"] = program_configs opt_info["program_configs"] = program_configs
opt_info["trainer"] = "DistMultiTrainer" opt_info["trainer"] = "DistMultiTrainer"
opt_info["device_worker"] = "DownpourSGD" opt_info["device_worker"] = "DownpourSGD"
...@@ -263,4 +366,8 @@ class DistributedAdam(DistributedOptimizerImplBase): ...@@ -263,4 +366,8 @@ class DistributedAdam(DistributedOptimizerImplBase):
for loss in losses: for loss in losses:
loss.block.program._fleet_opt = opt_info loss.block.program._fleet_opt = opt_info
return None, param_grads_list[0], opt_info param_grads_list = []
for loss in losses:
prog_id = str(id(loss.block.program))
param_grads_list.append(prog_id_to_param_grads[prog_id])
return None, param_grads_list, opt_info
...@@ -32,7 +32,7 @@ DESCRIPTOR = _descriptor.FileDescriptor( ...@@ -32,7 +32,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
package='paddle', package='paddle',
syntax='proto2', syntax='proto2',
serialized_pb=_b( serialized_pb=_b(
'\n\x08ps.proto\x12\x06paddle\"\x9e\x02\n\x0bPSParameter\x12\x14\n\x0cworker_class\x18\x01 \x01(\t\x12\x14\n\x0cserver_class\x18\x02 \x01(\t\x12\x16\n\x0einstance_class\x18\x03 \x01(\t\x12-\n\x0cworker_param\x18\x65 \x01(\x0b\x32\x17.paddle.WorkerParameter\x12-\n\x0cserver_param\x18\x66 \x01(\x0b\x32\x17.paddle.ServerParameter\x12\x38\n\rtrainer_param\x18\xad\x02 \x01(\x0b\x32 .paddle.DownpourTrainerParameter\x12\x33\n\x0f\x66s_client_param\x18\xf5\x03 \x01(\x0b\x32\x19.paddle.FsClientParameter\"Q\n\x0fWorkerParameter\x12>\n\x15\x64ownpour_worker_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourWorkerParameter\"Q\n\x0fServerParameter\x12>\n\x15\x64ownpour_server_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourServerParameter\"O\n\x17\x44ownpourWorkerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\"\xfd\x01\n\x18\x44ownpourTrainerParameter\x12\x30\n\x0b\x64\x65nse_table\x18\x01 \x03(\x0b\x32\x1b.paddle.DenseTableParameter\x12\x32\n\x0csparse_table\x18\x02 \x03(\x0b\x32\x1c.paddle.SparseTableParameter\x12\x1d\n\x15push_sparse_per_batch\x18\x03 \x01(\x05\x12\x1c\n\x14push_dense_per_batch\x18\x04 \x01(\x05\x12\x0f\n\x07skip_op\x18\x05 \x03(\t\x12-\n\x0eprogram_config\x18\x06 \x03(\x0b\x32\x15.paddle.ProgramConfig\"\x99\x01\n\rProgramConfig\x12\x12\n\nprogram_id\x18\x01 \x02(\t\x12\x1c\n\x14push_sparse_table_id\x18\x02 \x03(\x05\x12\x1b\n\x13push_dense_table_id\x18\x03 \x03(\x05\x12\x1c\n\x14pull_sparse_table_id\x18\x04 \x03(\x05\x12\x1b\n\x13pull_dense_table_id\x18\x05 \x03(\x05\"{\n\x13\x44\x65nseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x1b\n\x13\x64\x65nse_variable_name\x18\x02 \x03(\t\x12$\n\x1c\x64\x65nse_gradient_variable_name\x18\x03 \x03(\t\x12\x0f\n\x07\x66\x65\x61_dim\x18\x04 \x01(\x05\"z\n\x14SparseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x13\n\x0b\x66\x65\x61ture_dim\x18\x02 \x01(\x05\x12\x10\n\x08slot_key\x18\x03 \x03(\t\x12\x12\n\nslot_value\x18\x04 \x03(\t\x12\x15\n\rslot_gradient\x18\x05 \x03(\t\"\x86\x01\n\x17\x44ownpourServerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\x12\x35\n\rservice_param\x18\x02 \x01(\x0b\x32\x1e.paddle.ServerServiceParameter\"\xd7\x01\n\x16ServerServiceParameter\x12*\n\x0cserver_class\x18\x01 \x01(\t:\x14\x44ownpourBrpcPsServer\x12*\n\x0c\x63lient_class\x18\x02 \x01(\t:\x14\x44ownpourBrpcPsClient\x12(\n\rservice_class\x18\x03 \x01(\t:\x11\x44ownpourPsService\x12\x1c\n\x11start_server_port\x18\x04 \x01(\r:\x01\x30\x12\x1d\n\x11server_thread_num\x18\x05 \x01(\r:\x02\x31\x32\"\xc0\x02\n\x0eTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x04\x12\x13\n\x0btable_class\x18\x02 \x01(\t\x12\x17\n\tshard_num\x18\x03 \x01(\x04:\x04\x31\x30\x30\x30\x12\x30\n\x08\x61\x63\x63\x65ssor\x18\x04 \x01(\x0b\x32\x1e.paddle.TableAccessorParameter\x12\x1f\n\x04type\x18\x05 \x01(\x0e\x32\x11.paddle.TableType\x12\x1f\n\x10\x63ompress_in_save\x18\x06 \x01(\x08:\x05\x66\x61lse\x12\'\n\x19\x65nable_sparse_table_cache\x18\x07 \x01(\x08:\x04true\x12(\n\x17sparse_table_cache_rate\x18\x08 \x01(\x01:\x07\x30.00055\x12\'\n\x1bsparse_table_cache_file_num\x18\t \x01(\r:\x02\x31\x36\"\xfc\x02\n\x16TableAccessorParameter\x12\x16\n\x0e\x61\x63\x63\x65ssor_class\x18\x01 \x01(\t\x12\x38\n\x10sparse_sgd_param\x18\x02 \x01(\x0b\x32\x1e.paddle.SparseSGDRuleParameter\x12\x36\n\x0f\x64\x65nse_sgd_param\x18\x03 \x01(\x0b\x32\x1d.paddle.DenseSGDRuleParameter\x12\x13\n\x07\x66\x65\x61_dim\x18\x04 \x01(\r:\x02\x31\x31\x12\x15\n\nembedx_dim\x18\x05 \x01(\r:\x01\x38\x12\x1c\n\x10\x65mbedx_threshold\x18\x06 \x01(\r:\x02\x31\x30\x12G\n\x17\x64ownpour_accessor_param\x18\x07 \x01(\x0b\x32&.paddle.DownpourTableAccessorParameter\x12\x45\n\x19table_accessor_save_param\x18\x08 \x03(\x0b\x32\".paddle.TableAccessorSaveParameter\"\x96\x02\n\x1e\x44ownpourTableAccessorParameter\x12\x19\n\x0cnonclk_coeff\x18\x01 \x01(\x02:\x03\x30.1\x12\x16\n\x0b\x63lick_coeff\x18\x02 \x01(\x02:\x01\x31\x12\x1b\n\x0e\x62\x61se_threshold\x18\x03 \x01(\x02:\x03\x31.5\x12\x1d\n\x0f\x64\x65lta_threshold\x18\x04 \x01(\x02:\x04\x30.25\x12\x1b\n\x0f\x64\x65lta_keep_days\x18\x05 \x01(\x02:\x02\x31\x36\x12#\n\x15show_click_decay_rate\x18\x06 \x01(\x02:\x04\x30.98\x12\x1d\n\x10\x64\x65lete_threshold\x18\x07 \x01(\x02:\x03\x30.8\x12$\n\x18\x64\x65lete_after_unseen_days\x18\x08 \x01(\x02:\x02\x33\x30\"S\n\x1aTableAccessorSaveParameter\x12\r\n\x05param\x18\x01 \x01(\r\x12\x11\n\tconverter\x18\x02 \x01(\t\x12\x13\n\x0b\x64\x65\x63onverter\x18\x03 \x01(\t\"e\n\x10PsRequestMessage\x12\x0e\n\x06\x63md_id\x18\x01 \x02(\r\x12\x10\n\x08table_id\x18\x02 \x01(\r\x12\x0e\n\x06params\x18\x03 \x03(\x0c\x12\x11\n\tclient_id\x18\x04 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x05 \x01(\x0c\"\x85\x01\n\x16SparseSGDRuleParameter\x12\x1b\n\rlearning_rate\x18\x01 \x01(\x01:\x04\x30.05\x12\x18\n\rinitial_g2sum\x18\x02 \x01(\x01:\x01\x33\x12\x1d\n\rinitial_range\x18\x03 \x01(\x01:\x06\x30.0001\x12\x15\n\rweight_bounds\x18\x04 \x03(\x02\"\xe1\x01\n\x15\x44\x65nseSGDRuleParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12&\n\x04\x61\x64\x61m\x18\x02 \x01(\x0b\x32\x18.paddle.AdamSGDParameter\x12(\n\x05naive\x18\x03 \x01(\x0b\x32\x19.paddle.NaiveSGDParameter\x12,\n\x07summary\x18\x04 \x01(\x0b\x32\x1b.paddle.SummarySGDParameter\x12:\n\x0emoving_average\x18\x05 \x01(\x0b\x32\".paddle.MovingAverageRuleParameter\"\xac\x01\n\x10\x41\x64\x61mSGDParameter\x12\x1c\n\rlearning_rate\x18\x01 \x01(\x01:\x05\x35\x65-06\x12 \n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01:\x08\x30.999993\x12\x1e\n\x0e\x61\x64\x61_decay_rate\x18\x03 \x01(\x01:\x06\x30.9999\x12\x1a\n\x0b\x61\x64\x61_epsilon\x18\x04 \x01(\x01:\x05\x31\x65-08\x12\x1c\n\x0emom_decay_rate\x18\x05 \x01(\x01:\x04\x30.99\"J\n\x11NaiveSGDParameter\x12\x1d\n\rlearning_rate\x18\x01 \x01(\x01:\x06\x30.0002\x12\x16\n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01\";\n\x13SummarySGDParameter\x12$\n\x12summary_decay_rate\x18\x01 \x01(\x01:\x08\x30.999999\".\n\x1aMovingAverageRuleParameter\x12\x10\n\x08momentum\x18\x01 \x01(\x01\"I\n\x11PsResponseMessage\x12\x13\n\x08\x65rr_code\x18\x01 \x02(\x05:\x01\x30\x12\x11\n\x07\x65rr_msg\x18\x02 \x02(\t:\x00\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\xd5\x01\n\x11\x46sClientParameter\x12:\n\x07\x66s_type\x18\x01 \x01(\x0e\x32#.paddle.FsClientParameter.FsApiType:\x04HDFS\x12\x0b\n\x03uri\x18\x02 \x01(\t\x12\x0c\n\x04user\x18\x03 \x01(\t\x12\x0e\n\x06passwd\x18\x04 \x01(\t\x12\x13\n\x0b\x62uffer_size\x18\x05 \x01(\x05\x12\x12\n\nhadoop_bin\x18\x33 \x01(\t\x12\x10\n\x08\x61\x66s_conf\x18\x65 \x01(\t\"\x1e\n\tFsApiType\x12\x08\n\x04HDFS\x10\x00\x12\x07\n\x03\x41\x46S\x10\x01*4\n\tTableType\x12\x13\n\x0fPS_SPARSE_TABLE\x10\x00\x12\x12\n\x0ePS_DENSE_TABLE\x10\x01*\x9c\x03\n\x07PsCmdID\x12\x17\n\x13PS_PULL_DENSE_TABLE\x10\x00\x12\x17\n\x13PS_PUSH_DENSE_TABLE\x10\x01\x12\x18\n\x14PS_PULL_SPARSE_TABLE\x10\x02\x12\x18\n\x14PS_PUSH_SPARSE_TABLE\x10\x03\x12\x13\n\x0fPS_SHRINK_TABLE\x10\x04\x12\x15\n\x11PS_SAVE_ONE_TABLE\x10\x05\x12\x15\n\x11PS_SAVE_ALL_TABLE\x10\x06\x12\x15\n\x11PS_LOAD_ONE_TABLE\x10\x07\x12\x15\n\x11PS_LOAD_ALL_TABLE\x10\x08\x12\x16\n\x12PS_CLEAR_ONE_TABLE\x10\t\x12\x16\n\x12PS_CLEAR_ALL_TABLE\x10\n\x12\x17\n\x13PS_PUSH_DENSE_PARAM\x10\x0b\x12\x12\n\x0ePS_STOP_SERVER\x10\x0c\x12\x1b\n\x17PS_SAVE_ONE_CACHE_TABLE\x10\r\x12\x1a\n\x16PS_GET_CACHE_THRESHOLD\x10\x0e\x12\x14\n\x10PS_CACHE_SHUFFLE\x10\x0f\x12\x0e\n\nPS_S2S_MSG\x10\x65\x32K\n\tPsService\x12>\n\x07service\x12\x18.paddle.PsRequestMessage\x1a\x19.paddle.PsResponseMessageB\x03\x80\x01\x01' '\n\x08ps.proto\x12\x06paddle\"\x9e\x02\n\x0bPSParameter\x12\x14\n\x0cworker_class\x18\x01 \x01(\t\x12\x14\n\x0cserver_class\x18\x02 \x01(\t\x12\x16\n\x0einstance_class\x18\x03 \x01(\t\x12-\n\x0cworker_param\x18\x65 \x01(\x0b\x32\x17.paddle.WorkerParameter\x12-\n\x0cserver_param\x18\x66 \x01(\x0b\x32\x17.paddle.ServerParameter\x12\x38\n\rtrainer_param\x18\xad\x02 \x03(\x0b\x32 .paddle.DownpourTrainerParameter\x12\x33\n\x0f\x66s_client_param\x18\xf5\x03 \x01(\x0b\x32\x19.paddle.FsClientParameter\"Q\n\x0fWorkerParameter\x12>\n\x15\x64ownpour_worker_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourWorkerParameter\"Q\n\x0fServerParameter\x12>\n\x15\x64ownpour_server_param\x18\x01 \x01(\x0b\x32\x1f.paddle.DownpourServerParameter\"O\n\x17\x44ownpourWorkerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\"\xfd\x01\n\x18\x44ownpourTrainerParameter\x12\x30\n\x0b\x64\x65nse_table\x18\x01 \x03(\x0b\x32\x1b.paddle.DenseTableParameter\x12\x32\n\x0csparse_table\x18\x02 \x03(\x0b\x32\x1c.paddle.SparseTableParameter\x12\x1d\n\x15push_sparse_per_batch\x18\x03 \x01(\x05\x12\x1c\n\x14push_dense_per_batch\x18\x04 \x01(\x05\x12\x0f\n\x07skip_op\x18\x05 \x03(\t\x12-\n\x0eprogram_config\x18\x06 \x03(\x0b\x32\x15.paddle.ProgramConfig\"\x99\x01\n\rProgramConfig\x12\x12\n\nprogram_id\x18\x01 \x02(\t\x12\x1c\n\x14push_sparse_table_id\x18\x02 \x03(\x05\x12\x1b\n\x13push_dense_table_id\x18\x03 \x03(\x05\x12\x1c\n\x14pull_sparse_table_id\x18\x04 \x03(\x05\x12\x1b\n\x13pull_dense_table_id\x18\x05 \x03(\x05\"{\n\x13\x44\x65nseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x1b\n\x13\x64\x65nse_variable_name\x18\x02 \x03(\t\x12$\n\x1c\x64\x65nse_gradient_variable_name\x18\x03 \x03(\t\x12\x0f\n\x07\x66\x65\x61_dim\x18\x04 \x01(\x05\"z\n\x14SparseTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x05\x12\x13\n\x0b\x66\x65\x61ture_dim\x18\x02 \x01(\x05\x12\x10\n\x08slot_key\x18\x03 \x03(\t\x12\x12\n\nslot_value\x18\x04 \x03(\t\x12\x15\n\rslot_gradient\x18\x05 \x03(\t\"\x86\x01\n\x17\x44ownpourServerParameter\x12\x34\n\x14\x64ownpour_table_param\x18\x01 \x03(\x0b\x32\x16.paddle.TableParameter\x12\x35\n\rservice_param\x18\x02 \x01(\x0b\x32\x1e.paddle.ServerServiceParameter\"\xd7\x01\n\x16ServerServiceParameter\x12*\n\x0cserver_class\x18\x01 \x01(\t:\x14\x44ownpourBrpcPsServer\x12*\n\x0c\x63lient_class\x18\x02 \x01(\t:\x14\x44ownpourBrpcPsClient\x12(\n\rservice_class\x18\x03 \x01(\t:\x11\x44ownpourPsService\x12\x1c\n\x11start_server_port\x18\x04 \x01(\r:\x01\x30\x12\x1d\n\x11server_thread_num\x18\x05 \x01(\r:\x02\x31\x32\"\xc0\x02\n\x0eTableParameter\x12\x10\n\x08table_id\x18\x01 \x01(\x04\x12\x13\n\x0btable_class\x18\x02 \x01(\t\x12\x17\n\tshard_num\x18\x03 \x01(\x04:\x04\x31\x30\x30\x30\x12\x30\n\x08\x61\x63\x63\x65ssor\x18\x04 \x01(\x0b\x32\x1e.paddle.TableAccessorParameter\x12\x1f\n\x04type\x18\x05 \x01(\x0e\x32\x11.paddle.TableType\x12\x1f\n\x10\x63ompress_in_save\x18\x06 \x01(\x08:\x05\x66\x61lse\x12\'\n\x19\x65nable_sparse_table_cache\x18\x07 \x01(\x08:\x04true\x12(\n\x17sparse_table_cache_rate\x18\x08 \x01(\x01:\x07\x30.00055\x12\'\n\x1bsparse_table_cache_file_num\x18\t \x01(\r:\x02\x31\x36\"\xfc\x02\n\x16TableAccessorParameter\x12\x16\n\x0e\x61\x63\x63\x65ssor_class\x18\x01 \x01(\t\x12\x38\n\x10sparse_sgd_param\x18\x02 \x01(\x0b\x32\x1e.paddle.SparseSGDRuleParameter\x12\x36\n\x0f\x64\x65nse_sgd_param\x18\x03 \x01(\x0b\x32\x1d.paddle.DenseSGDRuleParameter\x12\x13\n\x07\x66\x65\x61_dim\x18\x04 \x01(\r:\x02\x31\x31\x12\x15\n\nembedx_dim\x18\x05 \x01(\r:\x01\x38\x12\x1c\n\x10\x65mbedx_threshold\x18\x06 \x01(\r:\x02\x31\x30\x12G\n\x17\x64ownpour_accessor_param\x18\x07 \x01(\x0b\x32&.paddle.DownpourTableAccessorParameter\x12\x45\n\x19table_accessor_save_param\x18\x08 \x03(\x0b\x32\".paddle.TableAccessorSaveParameter\"\x96\x02\n\x1e\x44ownpourTableAccessorParameter\x12\x19\n\x0cnonclk_coeff\x18\x01 \x01(\x02:\x03\x30.1\x12\x16\n\x0b\x63lick_coeff\x18\x02 \x01(\x02:\x01\x31\x12\x1b\n\x0e\x62\x61se_threshold\x18\x03 \x01(\x02:\x03\x31.5\x12\x1d\n\x0f\x64\x65lta_threshold\x18\x04 \x01(\x02:\x04\x30.25\x12\x1b\n\x0f\x64\x65lta_keep_days\x18\x05 \x01(\x02:\x02\x31\x36\x12#\n\x15show_click_decay_rate\x18\x06 \x01(\x02:\x04\x30.98\x12\x1d\n\x10\x64\x65lete_threshold\x18\x07 \x01(\x02:\x03\x30.8\x12$\n\x18\x64\x65lete_after_unseen_days\x18\x08 \x01(\x02:\x02\x33\x30\"S\n\x1aTableAccessorSaveParameter\x12\r\n\x05param\x18\x01 \x01(\r\x12\x11\n\tconverter\x18\x02 \x01(\t\x12\x13\n\x0b\x64\x65\x63onverter\x18\x03 \x01(\t\"e\n\x10PsRequestMessage\x12\x0e\n\x06\x63md_id\x18\x01 \x02(\r\x12\x10\n\x08table_id\x18\x02 \x01(\r\x12\x0e\n\x06params\x18\x03 \x03(\x0c\x12\x11\n\tclient_id\x18\x04 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x05 \x01(\x0c\"\x85\x01\n\x16SparseSGDRuleParameter\x12\x1b\n\rlearning_rate\x18\x01 \x01(\x01:\x04\x30.05\x12\x18\n\rinitial_g2sum\x18\x02 \x01(\x01:\x01\x33\x12\x1d\n\rinitial_range\x18\x03 \x01(\x01:\x06\x30.0001\x12\x15\n\rweight_bounds\x18\x04 \x03(\x02\"\xe1\x01\n\x15\x44\x65nseSGDRuleParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12&\n\x04\x61\x64\x61m\x18\x02 \x01(\x0b\x32\x18.paddle.AdamSGDParameter\x12(\n\x05naive\x18\x03 \x01(\x0b\x32\x19.paddle.NaiveSGDParameter\x12,\n\x07summary\x18\x04 \x01(\x0b\x32\x1b.paddle.SummarySGDParameter\x12:\n\x0emoving_average\x18\x05 \x01(\x0b\x32\".paddle.MovingAverageRuleParameter\"\xac\x01\n\x10\x41\x64\x61mSGDParameter\x12\x1c\n\rlearning_rate\x18\x01 \x01(\x01:\x05\x35\x65-06\x12 \n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01:\x08\x30.999993\x12\x1e\n\x0e\x61\x64\x61_decay_rate\x18\x03 \x01(\x01:\x06\x30.9999\x12\x1a\n\x0b\x61\x64\x61_epsilon\x18\x04 \x01(\x01:\x05\x31\x65-08\x12\x1c\n\x0emom_decay_rate\x18\x05 \x01(\x01:\x04\x30.99\"J\n\x11NaiveSGDParameter\x12\x1d\n\rlearning_rate\x18\x01 \x01(\x01:\x06\x30.0002\x12\x16\n\x0e\x61vg_decay_rate\x18\x02 \x01(\x01\";\n\x13SummarySGDParameter\x12$\n\x12summary_decay_rate\x18\x01 \x01(\x01:\x08\x30.999999\".\n\x1aMovingAverageRuleParameter\x12\x10\n\x08momentum\x18\x01 \x01(\x01\"I\n\x11PsResponseMessage\x12\x13\n\x08\x65rr_code\x18\x01 \x02(\x05:\x01\x30\x12\x11\n\x07\x65rr_msg\x18\x02 \x02(\t:\x00\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\xd5\x01\n\x11\x46sClientParameter\x12:\n\x07\x66s_type\x18\x01 \x01(\x0e\x32#.paddle.FsClientParameter.FsApiType:\x04HDFS\x12\x0b\n\x03uri\x18\x02 \x01(\t\x12\x0c\n\x04user\x18\x03 \x01(\t\x12\x0e\n\x06passwd\x18\x04 \x01(\t\x12\x13\n\x0b\x62uffer_size\x18\x05 \x01(\x05\x12\x12\n\nhadoop_bin\x18\x33 \x01(\t\x12\x10\n\x08\x61\x66s_conf\x18\x65 \x01(\t\"\x1e\n\tFsApiType\x12\x08\n\x04HDFS\x10\x00\x12\x07\n\x03\x41\x46S\x10\x01*4\n\tTableType\x12\x13\n\x0fPS_SPARSE_TABLE\x10\x00\x12\x12\n\x0ePS_DENSE_TABLE\x10\x01*\x9c\x03\n\x07PsCmdID\x12\x17\n\x13PS_PULL_DENSE_TABLE\x10\x00\x12\x17\n\x13PS_PUSH_DENSE_TABLE\x10\x01\x12\x18\n\x14PS_PULL_SPARSE_TABLE\x10\x02\x12\x18\n\x14PS_PUSH_SPARSE_TABLE\x10\x03\x12\x13\n\x0fPS_SHRINK_TABLE\x10\x04\x12\x15\n\x11PS_SAVE_ONE_TABLE\x10\x05\x12\x15\n\x11PS_SAVE_ALL_TABLE\x10\x06\x12\x15\n\x11PS_LOAD_ONE_TABLE\x10\x07\x12\x15\n\x11PS_LOAD_ALL_TABLE\x10\x08\x12\x16\n\x12PS_CLEAR_ONE_TABLE\x10\t\x12\x16\n\x12PS_CLEAR_ALL_TABLE\x10\n\x12\x17\n\x13PS_PUSH_DENSE_PARAM\x10\x0b\x12\x12\n\x0ePS_STOP_SERVER\x10\x0c\x12\x1b\n\x17PS_SAVE_ONE_CACHE_TABLE\x10\r\x12\x1a\n\x16PS_GET_CACHE_THRESHOLD\x10\x0e\x12\x14\n\x10PS_CACHE_SHUFFLE\x10\x0f\x12\x0e\n\nPS_S2S_MSG\x10\x65\x32K\n\tPsService\x12>\n\x07service\x12\x18.paddle.PsRequestMessage\x1a\x19.paddle.PsResponseMessageB\x03\x80\x01\x01'
)) ))
_sym_db.RegisterFileDescriptor(DESCRIPTOR) _sym_db.RegisterFileDescriptor(DESCRIPTOR)
...@@ -290,9 +290,9 @@ _PSPARAMETER = _descriptor.Descriptor( ...@@ -290,9 +290,9 @@ _PSPARAMETER = _descriptor.Descriptor(
number=301, number=301,
type=11, type=11,
cpp_type=10, cpp_type=10,
label=1, label=3,
has_default_value=False, has_default_value=False,
default_value=None, default_value=[],
message_type=None, message_type=None,
enum_type=None, enum_type=None,
containing_type=None, containing_type=None,
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Testcases for Downpour."""
from __future__ import print_function from __future__ import print_function
...@@ -25,15 +26,19 @@ import sys ...@@ -25,15 +26,19 @@ import sys
from op_test import OpTest from op_test import OpTest
from paddle.fluid.trainer_desc import DistMultiTrainer from paddle.fluid.trainer_desc import DistMultiTrainer
from paddle.fluid.device_worker import DownpourSGD from paddle.fluid.device_worker import DownpourSGD
from paddle.fluid.incubate.fleet.parameter_server.pslib.node import DownpourWorker
from google.protobuf import text_format from google.protobuf import text_format
import paddle.fluid.incubate.fleet.parameter_server.pslib.ps_pb2 as pslib import paddle.fluid.incubate.fleet.parameter_server.pslib.ps_pb2 as pslib
class TestListenAndServOp(OpTest): class TestListenAndServOp(OpTest):
"""TestListenAndServOp."""
def setUp(self): def setUp(self):
pass pass
def test_device_work_use_cvm(self): def test_device_work_use_cvm(self):
"""test device work use_cvm."""
if sys.platform == 'win32' or sys.platform == 'sys.platform': if sys.platform == 'win32' or sys.platform == 'sys.platform':
pass pass
else: else:
...@@ -77,6 +82,9 @@ class TestListenAndServOp(OpTest): ...@@ -77,6 +82,9 @@ class TestListenAndServOp(OpTest):
opt_info["scale_datanorm"] = -1 opt_info["scale_datanorm"] = -1
opt_info["dump_slot"] = False opt_info["dump_slot"] = False
opt_info["stat_var_names"] = [] opt_info["stat_var_names"] = []
worker = DownpourWorker(None)
worker.get_desc().CopyFrom(ps_param.trainer_param[0])
opt_info["program_id_to_worker"] = {program_id: worker}
main_program._fleet_opt = opt_info main_program._fleet_opt = opt_info
trainer = DistMultiTrainer() trainer = DistMultiTrainer()
...@@ -90,6 +98,7 @@ class TestListenAndServOp(OpTest): ...@@ -90,6 +98,7 @@ class TestListenAndServOp(OpTest):
os.system(cmd) os.system(cmd)
def test_device_work(self): def test_device_work(self):
"""test devicve worker."""
if sys.platform == 'win32' or sys.platform == 'sys.platform': if sys.platform == 'win32' or sys.platform == 'sys.platform':
pass pass
else: else:
...@@ -133,6 +142,9 @@ class TestListenAndServOp(OpTest): ...@@ -133,6 +142,9 @@ class TestListenAndServOp(OpTest):
opt_info["scale_datanorm"] = -1 opt_info["scale_datanorm"] = -1
opt_info["dump_slot"] = False opt_info["dump_slot"] = False
opt_info["stat_var_names"] = [] opt_info["stat_var_names"] = []
worker = DownpourWorker(None)
worker.get_desc().CopyFrom(ps_param.trainer_param[0])
opt_info["program_id_to_worker"] = {program_id: worker}
main_program._fleet_opt = opt_info main_program._fleet_opt = opt_info
trainer = DistMultiTrainer() trainer = DistMultiTrainer()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册