Unverified commit 57843b2c, authored by tangwei12, committed by GitHub

2.0 ps core 1 (#29883)

* add ps table (#29463)

* add ps table

Change-Id: I468a04bd071d21ff52654926fcf4d5f3da19e178

* add service (#29560)

* add service, remove ut on mac

* fix heter_profiler & add heter stop method

* fix code style
Parent 98d2d072
add_subdirectory(memory)
add_subdirectory(platform)
add_subdirectory(distributed)
add_subdirectory(framework)
add_subdirectory(imperative)
add_subdirectory(operators)
......
if(NOT WITH_DISTRIBUTE)
return()
endif()
proto_library(ps_framework_proto SRCS ps.proto)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-error=unused-value -Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor -Wno-error=sign-compare -Wno-error=unused-variable -Wno-error=return-type -Wno-error=unused-but-set-variable -Wno-error=type-limits -Wno-error=unknown-pragmas -Wno-error=parentheses -Wno-error=unused-result")
if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
set(DISTRIBUTE_COMPILE_FLAGS
"${DISTRIBUTE_COMPILE_FLAGS} -faligned-new")
endif()
add_subdirectory(table)
add_subdirectory(test)
# open it until CI supports brpc
return()
add_subdirectory(service)
get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)
set_source_files_properties(fleet.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(fleet
SRCS fleet.cc
DEPS framework_proto ps_framework_proto ps_service variable_helper scope op_registry fs shell ${RPC_DEPS})
target_link_libraries(fleet z)
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include <iostream>
#include <map>
#include <string>
#include <vector>
namespace paddle {
namespace distributed {
class Any {
public:
Any() : content_(NULL) {}
template <typename ValueType>
Any(const ValueType &value) : content_(new Holder<ValueType>(value)) {}
Any(const Any &other)
: content_(other.content_ ? other.content_->clone() : NULL) {}
~Any() { delete content_; }
template <typename ValueType>
ValueType *any_cast() {
return content_ ? &static_cast<Holder<ValueType> *>(content_)->held_ : NULL;
}
private:
class PlaceHolder {
public:
virtual ~PlaceHolder() {}
virtual PlaceHolder *clone() const = 0;
};
template <typename ValueType>
class Holder : public PlaceHolder {
public:
explicit Holder(const ValueType &value) : held_(value) {}
virtual PlaceHolder *clone() const { return new Holder(held_); }
ValueType held_;
};
PlaceHolder *content_;
};
class ObjectFactory {
public:
ObjectFactory() {}
virtual ~ObjectFactory() {}
virtual Any NewInstance() { return Any(); }
private:
};
typedef std::map<std::string, ObjectFactory *> FactoryMap;
typedef std::map<std::string, FactoryMap> BaseClassMap;
#ifdef __cplusplus
extern "C" {
#endif
inline BaseClassMap &global_factory_map() {
static BaseClassMap *base_class = new BaseClassMap();
return *base_class;
}
#ifdef __cplusplus
}
#endif
inline BaseClassMap &global_factory_map_cpp() { return global_factory_map(); }
// typedef pa::Any Any;
// typedef ::FactoryMap FactoryMap;
#define REGISTER_REGISTERER(base_class) \
class base_class##Registerer { \
public: \
static base_class *CreateInstanceByName(const ::std::string &name) { \
if (global_factory_map_cpp().find(#base_class) == \
global_factory_map_cpp().end()) { \
LOG(ERROR) << "Can't Find BaseClass For CreateClass with:" \
<< #base_class; \
return NULL; \
} \
FactoryMap &map = global_factory_map_cpp()[#base_class]; \
FactoryMap::iterator iter = map.find(name); \
if (iter == map.end()) { \
LOG(ERROR) << "Can't Find Class For Create with:" << name; \
return NULL; \
} \
Any object = iter->second->NewInstance(); \
return *(object.any_cast<base_class *>()); \
} \
};
#define REGISTER_CLASS(clazz, name) \
class ObjectFactory##name : public ObjectFactory { \
public: \
Any NewInstance() { return Any(new name()); } \
}; \
void register_factory_##name() { \
FactoryMap &map = global_factory_map_cpp()[#clazz]; \
if (map.find(#name) == map.end()) { \
map[#name] = new ObjectFactory##name(); \
} \
} \
void register_factory_##name() __attribute__((constructor));
#define CREATE_CLASS(base_class, name) \
base_class##Registerer::CreateInstanceByName(name);
} // namespace distributed
} // namespace paddle
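For readers of this diff: the REGISTER_REGISTERER / REGISTER_CLASS / CREATE_CLASS macros above implement a string-keyed factory map, keyed first by base-class name and then by subclass name. A minimal usage sketch follows; the Table / MemoryTable types here are hypothetical and only illustrate how the macros expand (the real tables registered by this PR live under the table/ subdirectory).
class Table {
 public:
  virtual ~Table() {}
};
class MemoryTable : public Table {};
REGISTER_REGISTERER(Table);          // defines TableRegisterer
REGISTER_CLASS(Table, MemoryTable);  // registers the "MemoryTable" factory at program load
// somewhere inside a function:
Table *table = CREATE_CLASS(Table, "MemoryTable");  // returns NULL if the name was never registered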
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace distributed {
template <typename T>
inline paddle::operators::math::BlasT<paddle::platform::CPUDeviceContext, T>
GetBlas() {
auto cpu_ctx = paddle::platform::CPUDeviceContext();
return paddle::operators::math::GetBlas<paddle::platform::CPUDeviceContext,
T>(cpu_ctx);
}
template <typename T>
inline void SQRT(int n, const T* x, T* z) {
for (int i = 0; i < n; ++i) {
z[i] = sqrt(x[i]);
}
}
template <typename T>
inline void ADD(int n, const T* x, const T y, T* z) {
for (int i = 0; i < n; ++i) {
z[i] = x[i] + y;
}
}
static bool StartWith(const std::string& str, const std::string& substr) {
return str.find(substr) == 0;
}
static bool EndWith(const std::string& str, const std::string& substr) {
return str.rfind(substr) == (str.length() - substr.length());
}
inline std::vector<int> bucket(const int v_size, const int b_size) {
int remainder = v_size % b_size;
int bucket = v_size / b_size;
std::vector<int> ret_vec(b_size, bucket);
for (int i = 0; i < remainder; ++i) {
ret_vec[i] = ret_vec[i] + 1;
}
int cur_bucket = 0;
for (int& j : ret_vec) {
int tmp = j;
j = cur_bucket;
cur_bucket += tmp;
}
ret_vec.push_back(cur_bucket);
return ret_vec;
}
template <typename T>
std::string to_string(const std::vector<T>& vec) {
std::stringstream ss;
for (const auto& c : vec) {
ss << c << " ";
}
return ss.str();
}
}  // namespace distributed
}  // namespace paddle
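A quick worked example for bucket() above, which returns b_size + 1 cumulative offsets rather than per-bucket counts (worked by hand from the code in this diff, not taken from a test):
// bucket(10, 3):
//   remainder = 1, base size = 3   ->  sizes {4, 3, 3}
//   prefix-sum in place            ->  {0, 4, 7}
//   append the running total       ->  {0, 4, 7, 10}
std::vector<int> offsets = paddle::distributed::bucket(10, 3);
// offsets == {0, 4, 7, 10}; bucket i owns the half-open range [offsets[i], offsets[i + 1])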
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
namespace paddle {
namespace distributed {
struct CommContext {
CommContext() = default;
CommContext(const std::string &name, const std::vector<std::string> &names,
const std::vector<std::string> &emap,
const std::vector<int64_t> &sections,
const std::vector<std::string> &origin_names, int id,
bool merge_add_ = true, bool is_sparse_ = true,
bool is_distributed_ = false, int table_id_ = -1)
: var_name(name),
splited_varnames(names),
epmap(emap),
height_sections(sections),
origin_varnames(origin_names),
trainer_id(id),
merge_add(merge_add_),
is_sparse(is_sparse_),
is_distributed(is_distributed_),
table_id(table_id_) {}
CommContext(const CommContext &ctx) {
var_name = ctx.var_name;
splited_varnames = ctx.splited_varnames;
epmap = ctx.epmap;
height_sections = ctx.height_sections;
trainer_id = ctx.trainer_id;
merge_add = ctx.merge_add;
is_sparse = ctx.is_sparse;
origin_varnames = ctx.origin_varnames;
is_distributed = ctx.is_distributed;
table_id = ctx.table_id;
}
std::string print() const {
std::stringstream ss;
ss << "varname: " << var_name << " trainer_id: " << trainer_id << " ";
ss << " table_id: " << table_id;
for (size_t i = 0; i < splited_varnames.size(); i++) {
ss << "slice varname: " << splited_varnames[i] << " ep: " << epmap[i]
<< " section: " << height_sections[i] << " ";
}
ss << "origin varnames: ";
for (size_t i = 0; i < origin_varnames.size(); i++) {
ss << origin_varnames[i] << " ";
}
ss << " aggregation->add: " << merge_add;
ss << " is_sparse: " << is_sparse;
ss << " is_distributed: " << is_distributed << "\n";
ss << " table_id: " << table_id << "\n";
return ss.str();
}
std::string var_name;
std::vector<std::string> splited_varnames;
std::vector<std::string> epmap;
std::vector<int64_t> height_sections;
std::vector<std::string> origin_varnames;
int trainer_id;
bool merge_add;
bool is_sparse;
bool is_distributed;
int table_id;
};
} // namespace distributed
} // namespace paddle
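For reference, a CommContext describing a dense variable split across two parameter servers could be built as below; every value is illustrative and not taken from this PR.
paddle::distributed::CommContext ctx(
    "fc_0.w_0",                              // var_name
    {"fc_0.w_0.block0", "fc_0.w_0.block1"},  // splited_varnames
    {"127.0.0.1:6000", "127.0.0.1:6001"},    // epmap
    {512, 512},                              // height_sections
    {"fc_0.w_0"},                            // origin_varnames
    /*id=*/0,                                // trainer_id
    /*merge_add_=*/true,
    /*is_sparse_=*/false,
    /*is_distributed_=*/false,
    /*table_id_=*/1);
VLOG(3) << ctx.print();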
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/fleet.h"
#include <algorithm>
#include <utility>
#include "paddle/fluid/distributed/service/communicator.h"
#include "paddle/fluid/distributed/table/table.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
namespace distributed {
using framework::LoDTensor;
using framework::ProgramDesc;
using framework::VarDesc;
using framework::Variable;
const uint32_t MAX_FEASIGN_NUM = 1024 * 100 * 100;
std::shared_ptr<FleetWrapper> FleetWrapper::s_instance_ = NULL;
bool FleetWrapper::is_initialized_ = false;
std::shared_ptr<paddle::distributed::PSCore> FleetWrapper::pserver_ptr_ = NULL;
void FleetWrapper::SetClient2ClientConfig(int request_timeout_ms,
int connect_timeout_ms,
int max_retry) {
client2client_request_timeout_ms_ = request_timeout_ms;
client2client_connect_timeout_ms_ = connect_timeout_ms;
client2client_max_retry_ = max_retry;
}
void FleetWrapper::LoadSparseOnServer(const std::string& path,
const std::string& meta,
uint32_t table_id) {
VLOG(3) << "load sparse table " << table_id << " with " << path << " meta "
<< meta;
pserver_ptr_->_server_ptr->table(table_id)->load(path, meta);
}
void FleetWrapper::InitServer(const std::string& dist_desc,
const std::vector<std::string>& host_sign_list,
int index) {
if (!is_initialized_) {
VLOG(3) << "Going to init server";
pserver_ptr_ = std::shared_ptr<paddle::distributed::PSCore>(
new paddle::distributed::PSCore());
pserver_ptr_->init_server(dist_desc, &host_sign_list, host_sign_list.size(),
index);
is_initialized_ = true;
} else {
VLOG(3) << "Server can be initialized only once";
}
}
// void FleetWrapper::InitWorker(
// const std::string& dist_desc, const std::vector<uint64_t>&
// host_sign_list, Scope* scope, const RpcCtxMap& send_ctx, const
// std::unordered_map<uint64_t, std::vector<std::string>>&
// dense_varnames,
// const std::map<std::string, std::string>& envs, int node_num, int index)
// {
// if (!is_initialized_) {
// VLOG(3) << "Going to init worker";
// Communicator::InitInstance<AsyncCommunicator>(
// send_ctx, dense_varnames, dist_desc, host_sign_list, scope, envs);
// pserver_ptr_ = std::shared_ptr<paddle::distributed::PSCore>(
// new paddle::distributed::PSCore());
// pserver_ptr_->init_worker(dist_desc, _regions,
// const_cast<uint64_t*>(host_sign_list.data()),
// node_num, index);
// is_initialized_ = true;
// } else {
// VLOG(3) << "Worker can be initialized only once";
// }
// }
void FleetWrapper::InitWorker(
const std::string& dist_desc,
const std::vector<std::string>& host_sign_list, Scope* scope,
const RpcCtxMap& send_ctx,
const std::unordered_map<uint64_t, std::vector<std::string>>&
dense_varnames,
const std::map<std::string, std::string>& envs, int node_num, int index) {
if (!is_initialized_) {
VLOG(3) << "Going to init worker";
Communicator::InitInstance<AsyncCommunicator>(
send_ctx, dense_varnames, dist_desc, host_sign_list, scope, envs);
pserver_ptr_ = std::shared_ptr<paddle::distributed::PSCore>(
new paddle::distributed::PSCore());
pserver_ptr_->init_worker(dist_desc, _regions, &host_sign_list, node_num,
index);
is_initialized_ = true;
} else {
VLOG(3) << "Worker can be initialized only once";
}
}
void FleetWrapper::StopServer() {
VLOG(3) << "Going to stop server";
auto* communicator = Communicator::GetInstance();
auto status = communicator->_worker_ptr->stop_server();
status.wait();
}
void FleetWrapper::FinalizeWorker() {
VLOG(3) << "Going to finalize worker";
pserver_ptr_->finalize_worker();
}
void FleetWrapper::BarrierWithTable(uint32_t barrier_type) {
VLOG(3) << "Going to Barrier worker";
auto* communicator = Communicator::GetInstance();
communicator->BarrierWithTable(barrier_type);
}
uint64_t FleetWrapper::RunServer(const std::string& ip, uint32_t port) {
VLOG(3) << "Going to run server with ip " << ip << " port " << port;
auto ret = pserver_ptr_->run_server(ip, port);
return ret;
}
std::vector<uint64_t> FleetWrapper::GetClientsInfo() {
VLOG(3) << "Going to get client info";
return pserver_ptr_->get_client_info();
}
void FleetWrapper::CreateClient2ClientConnection() {
VLOG(3) << "Going to create client2client connection";
pserver_ptr_->create_client2client_connection(
client2client_request_timeout_ms_, client2client_connect_timeout_ms_,
client2client_max_retry_);
}
std::future<int32_t> FleetWrapper::PullSparseVarsAsync(
const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names, std::vector<uint64_t>* fea_keys,
std::vector<std::vector<float>>* fea_values, int fea_value_dim) {
fea_keys->clear();
fea_keys->resize(0);
fea_keys->reserve(MAX_FEASIGN_NUM);
for (auto name : var_names) {
Variable* var = scope.FindVar(name);
if (var == nullptr) {
continue;
}
LoDTensor* tensor = var->GetMutable<LoDTensor>();
CHECK(tensor != nullptr) << "tensor of var " << name << " is null";
int64_t* ids = tensor->data<int64_t>();
size_t len = tensor->numel();
for (auto i = 0u; i < len; ++i) {
if (ids[i] == 0u) {
continue;
}
fea_keys->push_back(static_cast<uint64_t>(ids[i]));
}
}
fea_values->resize(fea_keys->size() + 1);
for (auto& t : *fea_values) {
t.resize(fea_value_dim);
}
std::vector<float*> pull_result_ptr;
for (auto& t : *fea_values) {
pull_result_ptr.push_back(t.data());
}
return pserver_ptr_->_worker_ptr->pull_sparse(
pull_result_ptr.data(), table_id, fea_keys->data(), fea_keys->size());
}
void FleetWrapper::PullSparseVarsSync(
const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names, std::vector<uint64_t>* fea_keys,
std::vector<std::vector<float>>* fea_values, int fea_value_dim,
const std::vector<std::string>& var_emb_names) {
std::vector<std::future<int32_t>> pull_sparse_status;
pull_sparse_status.resize(0);
fea_keys->clear();
fea_keys->resize(0);
fea_keys->reserve(MAX_FEASIGN_NUM);
for (size_t var_index = 0; var_index < var_names.size(); ++var_index) {
const std::string& name = var_names[var_index];
Variable* var = scope.FindVar(name);
if (var == nullptr) {
continue;
}
LoDTensor* tensor = var->GetMutable<LoDTensor>();
CHECK(tensor != nullptr) << "tensor of var " << name << " is null";
int64_t* ids = tensor->data<int64_t>();
size_t len = tensor->numel();
// skip slots which do not have embedding
const std::string& emb_name = var_emb_names[var_index];
Variable* emb_var = scope.FindVar(emb_name);
if (emb_var == nullptr) {
continue;
}
for (auto i = 0u; i < len; ++i) {
if (ids[i] == 0u) {
continue;
}
fea_keys->push_back(static_cast<uint64_t>(ids[i]));
}
}
fea_values->resize(fea_keys->size() + 1);
for (auto& t : *fea_values) {
t.resize(fea_value_dim);
}
std::vector<float*> pull_result_ptr;
for (auto& t : *fea_values) {
pull_result_ptr.push_back(t.data());
}
auto status = pserver_ptr_->_worker_ptr->pull_sparse(
pull_result_ptr.data(), table_id, fea_keys->data(), fea_keys->size());
pull_sparse_status.push_back(std::move(status));
for (auto& t : pull_sparse_status) {
t.wait();
auto status = t.get();
if (status != 0) {
LOG(ERROR) << "fleet pull sparse failed, status[" << status << "]";
sleep(sleep_seconds_before_fail_exit_);
exit(-1);
}
}
}
void FleetWrapper::PullSparseToTensorSync(const uint64_t table_id, int fea_dim,
uint64_t padding_id,
platform::Place place,
std::vector<const LoDTensor*>* inputs,
std::vector<LoDTensor*>* outputs) {
std::vector<uint64_t> fea_keys;
std::vector<float*> pull_result_ptr;
fea_keys.reserve(MAX_FEASIGN_NUM / 100);
pull_result_ptr.reserve(MAX_FEASIGN_NUM / 100);
std::vector<float> init_value(fea_dim, 0);
framework::LoDTensor* output = nullptr;
float* output_data = nullptr;
size_t output_index = -1;
size_t output_len = 0;
for (size_t index = 0; index < inputs->size(); ++index) {
const framework::LoDTensor* tensor = inputs->at(index);
const int64_t* ids = tensor->data<int64_t>();
size_t len = tensor->numel();
for (size_t i = 0; i < len; ++i, output_len += fea_dim) {
if (!output || output_len == size_t(output->numel())) {
++output_index;
CHECK(output_index < outputs->size()); // NOLINT
output = outputs->at(output_index);
output->set_lod(tensor->lod());
output_data = output->mutable_data<float>(place);
output_len = 0;
CHECK(output->numel() % fea_dim == 0); // NOLINT
CHECK(output_data != nullptr); // NOLINT
}
uint64_t real_id = static_cast<uint64_t>(ids[i]);
if (real_id == padding_id) {
memcpy(output_data + output_len, init_value.data(),
sizeof(float) * fea_dim);
continue;
}
fea_keys.push_back(real_id);
pull_result_ptr.push_back(output_data + output_len);
}
}
auto* communicator = Communicator::GetInstance();
auto status = communicator->_worker_ptr->pull_sparse(
pull_result_ptr.data(), table_id, fea_keys.data(), fea_keys.size());
status.wait();
auto ret = status.get();
if (ret != 0) {
LOG(ERROR) << "fleet pull sparse failed, status[" << ret << "]";
sleep(sleep_seconds_before_fail_exit_);
}
}
void FleetWrapper::PullDenseVarsAsync(
const Scope& scope, const uint64_t tid,
const std::vector<std::string>& var_names,
std::vector<std::future<int32_t>>* pull_dense_status, bool in_cpu) {
auto& regions = _regions[tid];
regions.clear();
regions.resize(var_names.size());
for (auto i = 0u; i < var_names.size(); ++i) {
std::string varname = var_names[i];
if (!in_cpu) {
varname = var_names[i] + "pin";
}
Variable* var = scope.FindVar(varname);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
float* w = tensor->data<float>();
paddle::distributed::Region reg(w, tensor->numel());
regions[i] = std::move(reg);
}
auto status = pserver_ptr_->_worker_ptr->pull_dense(regions.data(),
regions.size(), tid);
pull_dense_status->push_back(std::move(status));
}
void FleetWrapper::PullDenseVarsSync(
const Scope& scope, const uint64_t tid,
const std::vector<std::string>& var_names) {
auto& regions = _regions[tid];
regions.clear();
regions.reserve(var_names.size());
for (auto& t : var_names) {
Variable* var = scope.FindVar(t);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
float* w = tensor->data<float>();
paddle::distributed::Region reg(w, tensor->numel());
regions.emplace_back(std::move(reg));
}
auto* communicator = Communicator::GetInstance();
auto status = communicator->_worker_ptr->pull_dense(regions.data(),
regions.size(), tid);
status.wait();
}
void FleetWrapper::PushDenseParamSync(
const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names) {
auto place = platform::CPUPlace();
std::vector<paddle::distributed::Region> regions;
for (auto& t : var_names) {
Variable* var = scope.FindVar(t);
CHECK(var != nullptr) << "var[" << t << "] not found";
LoDTensor* tensor = var->GetMutable<LoDTensor>();
float* g = tensor->mutable_data<float>(place);
paddle::distributed::Region reg(g, tensor->numel());
regions.emplace_back(std::move(reg));
}
auto* communicator = Communicator::GetInstance();
auto push_status = communicator->_worker_ptr->push_dense_param(
regions.data(), regions.size(), table_id);
push_status.wait();
auto status = push_status.get();
CHECK(status == 0) << "push dense param failed, status[" << status << "]";
}
void FleetWrapper::PushDenseVarsSync(
Scope* scope, const uint64_t table_id,
const std::vector<std::string>& var_names) {}
void FleetWrapper::PushDenseVarsAsync(
const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names,
std::vector<std::future<int32_t>>* push_sparse_status, float scale_datanorm,
int batch_size) {
auto* communicator = Communicator::GetInstance();
PADDLE_ENFORCE_EQ(
communicator->Check(table_id), true,
platform::errors::InvalidArgument(
"can not find table: %s, please check your config", table_id));
communicator->Send(var_names, scope);
}
void FleetWrapper::PushSparseVarsAsync(
const Scope& scope, const uint64_t table_id,
const std::string& grad_varname,
std::vector<std::future<int32_t>>* push_sparse_status) {
std::vector<std::string> varnames;
varnames.push_back(grad_varname);
auto* communicator = Communicator::GetInstance();
PADDLE_ENFORCE_EQ(
communicator->Check(table_id), true,
platform::errors::InvalidArgument(
"can not find table: %s, please check your config", table_id));
communicator->Send(varnames, scope);
}
void FleetWrapper::PushSparseVarsWithLabelAsync(
const Scope& scope, const uint64_t table_id,
const std::vector<uint64_t>& fea_keys, const std::vector<float>& fea_labels,
const std::vector<std::string>& sparse_key_names,
const std::vector<std::string>& sparse_grad_names, const int emb_dim,
std::vector<std::vector<float>>* push_values,
std::vector<std::future<int32_t>>* push_sparse_status, const int batch_size,
const bool use_cvm, const bool dump_slot,
std::vector<uint64_t>* sparse_push_keys, const bool no_cvm) {
// not support
return;
}
void FleetWrapper::PushSparseFromTensorWithLabelAsync(
const Scope& scope, const uint64_t table_id, int fea_dim,
uint64_t padding_id, bool scale_sparse, const std::string& accesor,
const std::string& click_name, platform::Place place,
const std::vector<std::string>& input_names,
std::vector<const LoDTensor*>* inputs,
std::vector<const LoDTensor*>* outputs) {
// not support
return;
}
void FleetWrapper::LoadModel(const std::string& path, const int mode) {
auto ret = pserver_ptr_->_worker_ptr->load(path, std::to_string(mode));
ret.wait();
if (ret.get() != 0) {
LOG(ERROR) << "load model from path:" << path << " failed";
sleep(sleep_seconds_before_fail_exit_);
exit(-1);
}
}
void FleetWrapper::LoadModelOneTable(const uint64_t table_id,
const std::string& path, const int mode) {
auto ret =
pserver_ptr_->_worker_ptr->load(table_id, path, std::to_string(mode));
ret.wait();
if (ret.get() != 0) {
LOG(ERROR) << "load model of table id: " << table_id
<< ", from path: " << path << " failed";
}
}
void FleetWrapper::SaveModel(const std::string& path, const int mode) {
auto* communicator = Communicator::GetInstance();
auto ret = communicator->_worker_ptr->save(path, std::to_string(mode));
ret.wait();
int32_t feasign_cnt = ret.get();
if (feasign_cnt == -1) {
LOG(ERROR) << "save model failed";
sleep(sleep_seconds_before_fail_exit_);
exit(-1);
}
}
void FleetWrapper::SaveModelOneTable(const uint64_t table_id,
const std::string& path, const int mode) {
auto* communicator = Communicator::GetInstance();
auto ret =
communicator->_worker_ptr->save(table_id, path, std::to_string(mode));
ret.wait();
if (ret.get() != 0) {
LOG(ERROR) << "save model of table id: " << table_id
<< ", to path: " << path << " failed";
}
}
void FleetWrapper::PrintTableStat(const uint64_t table_id) {
auto* communicator = Communicator::GetInstance();
auto ret = communicator->_worker_ptr->print_table_stat(table_id);
ret.wait();
int32_t err_code = ret.get();
if (err_code == -1) {
LOG(ERROR) << "print table stat failed";
}
}
void FleetWrapper::ShrinkSparseTable(int table_id) {
auto ret = pserver_ptr_->_worker_ptr->shrink(table_id);
ret.wait();
}
void FleetWrapper::ClearModel() {
auto ret = pserver_ptr_->_worker_ptr->clear();
ret.wait();
}
void FleetWrapper::ClearOneTable(const uint64_t table_id) {
auto ret = pserver_ptr_->_worker_ptr->clear(table_id);
ret.wait();
}
void FleetWrapper::ShrinkDenseTable(int table_id, Scope* scope,
std::vector<std::string> var_list,
float decay, int emb_dim) {
std::vector<paddle::distributed::Region> regions;
for (std::string& name : var_list) {
if (name.find("batch_sum") != std::string::npos) {
Variable* var = scope->FindVar(name);
CHECK(var != nullptr) << "var[" << name << "] not found";
VLOG(0) << "prepare shrink dense batch_sum";
LoDTensor* tensor = var->GetMutable<LoDTensor>();
float* g = tensor->data<float>();
// show_batch_sum += N * log(decay)
std::string size_name = name;
size_name.replace(size_name.find("batch_sum"), size_name.length(),
"batch_size");
Variable* var_size = scope->FindVar(size_name);
CHECK(var_size != nullptr) << "var[" << size_name << "] not found";
VLOG(3) << "shrink dense batch_sum: " << name << ", " << size_name;
float* g_size = var_size->GetMutable<LoDTensor>()->data<float>();
for (int k = 0; k < tensor->numel(); k += emb_dim) {
g[k] = g[k] + g_size[k] * log(decay);
}
paddle::distributed::Region reg(g, tensor->numel());
regions.emplace_back(std::move(reg));
} else {
Variable* var = scope->FindVar(name);
CHECK(var != nullptr) << "var[" << name << "] not found";
LoDTensor* tensor = var->GetMutable<LoDTensor>();
float* g = tensor->data<float>();
paddle::distributed::Region reg(g, tensor->numel());
regions.emplace_back(std::move(reg));
}
}
auto push_status = pserver_ptr_->_worker_ptr->push_dense_param(
regions.data(), regions.size(), table_id);
push_status.wait();
auto status = push_status.get();
if (status != 0) {
// PADDLE_THROW(platform::errors::Fatal(
// "push shrink dense param failed, status is [%d].", status));
sleep(sleep_seconds_before_fail_exit_);
exit(-1);
}
}
void FleetWrapper::ClientFlush() {
auto ret = pserver_ptr_->_worker_ptr->flush();
ret.wait();
}
int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type,
MsgHandlerFunc handler) {
VLOG(3) << "calling FleetWrapper::RegisterClientToClientMsgHandler";
VLOG(3) << "pserver_ptr_=" << pserver_ptr_;
VLOG(3) << "_worker_ptr=" << pserver_ptr_->_worker_ptr;
return pserver_ptr_->_worker_ptr->registe_client2client_msg_handler(msg_type,
handler);
}
std::future<int32_t> FleetWrapper::SendClientToClientMsg(
int msg_type, int to_client_id, const std::string& msg) {
return pserver_ptr_->_worker_ptr->send_client2client_msg(msg_type,
to_client_id, msg);
}
std::default_random_engine& FleetWrapper::LocalRandomEngine() {
struct engine_wrapper_t {
std::default_random_engine engine;
engine_wrapper_t() {
struct timespec tp;
clock_gettime(CLOCK_REALTIME, &tp);
double cur_time = tp.tv_sec + tp.tv_nsec * 1e-9;
static std::atomic<uint64_t> x(0);
std::seed_seq sseq = {x++, x++, x++, (uint64_t)(cur_time * 1000)};
engine.seed(sseq);
}
};
thread_local engine_wrapper_t r;
return r.engine;
}
size_t FleetWrapper::GetAbsoluteSum(size_t start, size_t end, size_t level,
const framework::LoD& lod) {
if (level >= lod.size() - 1) {
return end - start;
}
size_t ret = 0;
for (size_t i = start; i < end - 1; ++i) {
size_t pos1 = lod[level][i];
size_t pos2 = lod[level][i + 1];
ret += GetAbsoluteSum(pos1, pos2, level + 1, lod);
}
return ret;
}
} // end namespace distributed
} // end namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <atomic>
#include <ctime>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <unordered_map>
#include <vector>
#include <ThreadPool.h>
#include "paddle/fluid/distributed/communicator_common.h"
#include "paddle/fluid/distributed/service/service.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/framework/io/shell.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace paddle {
namespace distributed {
using framework::LoDTensor;
using framework::Scope;
using framework::SelectedRows;
using framework::Variable;
using RpcCtxMap = std::unordered_map<std::string, CommContext>;
class FleetWrapper {
public:
virtual ~FleetWrapper() {}
FleetWrapper() {
scale_sparse_gradient_with_batch_size_ = true;
// trainer sleep some time for pserver core dump
sleep_seconds_before_fail_exit_ = 300;
// pserver request server timeout ms
client2client_request_timeout_ms_ = 500000;
// pserver connect server timeout_ms
client2client_connect_timeout_ms_ = 10000;
// pserver request max retry
client2client_max_retry_ = 3;
}
// set client to client communication config
void SetClient2ClientConfig(int request_timeout_ms, int connect_timeout_ms,
int max_retry);
// Pull sparse variables from server in sync mode
// Param<in>: scope, table_id, var_names, fea_keys, fea_dim, var_emb_names
// Param<out>: fea_values
void PullSparseVarsSync(const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names,
std::vector<uint64_t>* fea_keys,
std::vector<std::vector<float>>* fea_values,
int fea_dim,
const std::vector<std::string>& var_emb_names);
// Pull sparse variables from server in async mode
// Param<in>: scope, table_id, var_names, fea_keys, fea_dim
// Param<out>: fea_values std::future
std::future<int32_t> PullSparseVarsAsync(
const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names,
std::vector<uint64_t>* fea_keys,
std::vector<std::vector<float>>* fea_values, int fea_dim);
// Pull sparse variables from server in sync mode
// pull immediately to tensors
void PullSparseToTensorSync(const uint64_t table_id, int fea_dim,
uint64_t padding_id, platform::Place place,
std::vector<const LoDTensor*>* inputs, // NOLINT
std::vector<LoDTensor*>* outputs); // NOLINT
// pull dense variables from server in sync mod
// Param<in>: scope, table_id, var_names
// Param<out>: void
void PullDenseVarsSync(const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names);
// pull dense variables from server in async mod
// Param<in>: scope, table_id, var_names
// Param<out>: pull_dense_status
void PullDenseVarsAsync(const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names,
std::vector<std::future<int32_t>>* pull_dense_status,
bool in_cpu);
// push dense parameters(not gradients) to server in sync mode
void PushDenseParamSync(const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names);
void PushDenseVarsAsync(const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names,
std::vector<std::future<int32_t>>* push_sparse_status,
float scale_datanorm, int batch_size);
// push dense variables to server in sync mode
void PushDenseVarsSync(Scope* scope, const uint64_t table_id,
const std::vector<std::string>& var_names);
void PushSparseVarsAsync(
const Scope& scope, const uint64_t table_id, const std::string& grad,
std::vector<std::future<int32_t>>* push_sparse_status);
// This is specially designed for click/show stats in server
// Param<in>: scope, table_id, fea_keys, fea_labels, sparse_key_names,
// sparse_grad_names, batch_size, use_cvm, dump_slot
// Param<out>: push_values, push_sparse_status
void PushSparseVarsWithLabelAsync(
const Scope& scope, const uint64_t table_id,
const std::vector<uint64_t>& fea_keys,
const std::vector<float>& fea_labels,
const std::vector<std::string>& sparse_key_names,
const std::vector<std::string>& sparse_grad_names, const int emb_dim,
std::vector<std::vector<float>>* push_values,
std::vector<std::future<int32_t>>* push_sparse_status,
const int batch_size, const bool use_cvm, const bool dump_slot,
std::vector<uint64_t>* sparse_push_keys, const bool no_cvm);
// Push sparse variables to server in async mode
void PushSparseFromTensorWithLabelAsync(
const Scope& scope, const uint64_t table_id, int fea_dim,
uint64_t padding_id, bool scale_sparse, const std::string& accesor,
const std::string& click_name, platform::Place place,
const std::vector<std::string>& input_names,
std::vector<const LoDTensor*>* inputs, // NOLINT
std::vector<const LoDTensor*>* outputs); // NOLINT
// Push sparse variables to server in Async mode
// Param<In>: scope, table_id, fea_keys, sparse_grad_names
// Param<Out>: push_values, push_sparse_status
// init server
void LoadSparseOnServer(const std::string& path, const std::string& meta,
uint32_t table_id);
// init server
// void InitServer(const std::string& dist_desc,
// const std::vector<uint64_t>& host_sign_list, int index);
void InitServer(const std::string& dist_desc,
const std::vector<std::string>& host_sign_list, int index);
// init trainer
void InitWorker(const std::string& dist_desc,
const std::vector<std::string>& host_sign_list, Scope* scope,
const RpcCtxMap& send_ctx,
const std::unordered_map<uint64_t, std::vector<std::string>>&
dense_varnames,
const std::map<std::string, std::string>& envs, int node_num,
int index);
// stop server
void StopServer();
// finalize worker to make worker can be stop
void FinalizeWorker();
// run server with ip port
uint64_t RunServer(const std::string& ip, uint32_t port);
// get client info
std::vector<uint64_t> GetClientsInfo();
// create client to client connection
void CreateClient2ClientConnection();
// flush all push requests
void ClientFlush();
// barrier with barrier table
void BarrierWithTable(uint32_t barrier_type);
void PrintTableStat(const uint64_t table_id);
// mode = 0, load all feature
// mode = 1, load delta feature, which means load diff
void LoadModel(const std::string& path, const int mode);
// mode = 0, load all feature
// mode = 1, load delta feature, which means load diff
void LoadModelOneTable(const uint64_t table_id, const std::string& path,
const int mode);
// mode = 0, save all feature
// mode = 1, save delta feature, which means save diff
void SaveModel(const std::string& path, const int mode);
// mode = 0, save all feature
// mode = 1, save delta feature, which means save diff
void SaveModelOneTable(const uint64_t table_id, const std::string& path,
const int mode);
// clear all models, release their memory
void ClearModel();
// clear one table
void ClearOneTable(const uint64_t table_id);
// shrink sparse table
void ShrinkSparseTable(int table_id);
// shrink dense table
void ShrinkDenseTable(int table_id, Scope* scope,
std::vector<std::string> var_list, float decay,
int emb_dim);
typedef std::function<int32_t(int, int, const std::string&)> MsgHandlerFunc;
// register client to client communication
int RegisterClientToClientMsgHandler(int msg_type, MsgHandlerFunc handler);
// send client to client message
std::future<int32_t> SendClientToClientMsg(int msg_type, int to_client_id,
const std::string& msg);
// FleetWrapper singleton
static std::shared_ptr<FleetWrapper> GetInstance() {
if (NULL == s_instance_) {
s_instance_.reset(new paddle::distributed::FleetWrapper());
}
return s_instance_;
}
// this performs better than rand_r, especially large data
std::default_random_engine& LocalRandomEngine();
static std::shared_ptr<paddle::distributed::PSCore> pserver_ptr_;
private:
static std::shared_ptr<FleetWrapper> s_instance_;
size_t GetAbsoluteSum(size_t start, size_t end, size_t level,
const framework::LoD& lod);
protected:
static bool is_initialized_;
std::map<uint64_t, std::vector<paddle::distributed::Region>> _regions;
bool scale_sparse_gradient_with_batch_size_;
int32_t sleep_seconds_before_fail_exit_;
int client2client_request_timeout_ms_;
int client2client_connect_timeout_ms_;
int client2client_max_retry_;
DISABLE_COPY_AND_ASSIGN(FleetWrapper);
};
} // end namespace distributed
} // end namespace paddle
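A rough sketch of how a server process might drive FleetWrapper, using only the methods declared above; dist_desc and host_sign_list are placeholders that in practice come from the fleet Python launcher.
auto fleet = paddle::distributed::FleetWrapper::GetInstance();
fleet->SetClient2ClientConfig(/*request_timeout_ms=*/500000,
                              /*connect_timeout_ms=*/10000, /*max_retry=*/3);
std::string dist_desc = "...";                      // serialized PSParameter (placeholder)
std::vector<std::string> host_sign_list = {"..."};  // endpoint sign list (placeholder)
fleet->InitServer(dist_desc, host_sign_list, /*index=*/0);
uint64_t ret = fleet->RunServer("127.0.0.1", 6000);  // serve on ip:port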
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package paddle.distributed;
option cc_generic_services = true;
option cc_enable_arenas = true;
message FsClientParameter {
enum FsApiType {
HDFS = 0;
AFS = 1;
}
optional FsApiType fs_type = 1 [ default = HDFS ];
optional string uri = 2; // such as afs://xxx.afs.com:9902
optional string user = 3; // user_name to access fs
optional string passwd = 4; // password
optional int32 buffer_size = 5; // buffer for read/write
optional string hadoop_bin = 51;
optional string afs_conf = 101;
}
message PSParameter {
optional string worker_class = 1;
optional string server_class = 2;
optional string instance_class = 3;
optional string init_gflags = 4 [ default = "" ];
optional WorkerParameter worker_param = 101;
optional ServerParameter server_param = 102;
repeated DownpourTrainerParameter trainer_param = 301;
optional FsClientParameter fs_client_param = 501;
}
message WorkerParameter {
optional DownpourWorkerParameter downpour_worker_param = 1;
}
message DownpourWorkerParameter {
repeated TableParameter downpour_table_param = 1;
}
message DownpourServerParameter {
repeated TableParameter downpour_table_param = 1;
optional ServerServiceParameter service_param = 2;
}
message ServerParameter {
optional DownpourServerParameter downpour_server_param = 1;
}
message DownpourTrainerParameter {
repeated DenseTableParameter dense_table = 1;
repeated SparseTableParameter sparse_table = 2;
optional int32 push_sparse_per_batch = 3;
optional int32 push_dense_per_batch = 4;
repeated string skip_op = 5;
repeated ProgramConfig program_config = 6;
}
message DenseTableParameter {
optional int32 table_id = 1;
repeated string dense_variable_name = 2;
repeated string dense_gradient_variable_name = 3;
optional int32 fea_dim = 4;
}
message SparseTableParameter {
optional int32 table_id = 1;
optional int32 feature_dim = 2;
repeated string slot_key = 3;
repeated string slot_value = 4;
repeated string slot_gradient = 5;
}
message ServerServiceParameter {
optional string server_class = 1 [ default = "BrpcPsServer" ];
optional string client_class = 2 [ default = "BrpcPsClient" ];
optional string service_class = 3 [ default = "PsService" ];
optional uint32 start_server_port = 4
[ default = 0 ]; // will find an available port from it
optional uint32 server_thread_num = 5 [ default = 12 ];
}
message ProgramConfig {
required string program_id = 1;
repeated int32 push_sparse_table_id = 2;
repeated int32 push_dense_table_id = 3;
repeated int32 pull_sparse_table_id = 4;
repeated int32 pull_dense_table_id = 5;
}
enum TableType {
PS_SPARSE_TABLE = 0;
PS_DENSE_TABLE = 1;
PS_OTHER_TABLE = 2;
}
message TableParameter {
optional uint64 table_id = 1;
optional string table_class = 2;
optional uint64 shard_num = 3 [ default = 1000 ];
optional TableAccessorParameter accessor = 4;
optional TensorAccessorParameter tensor = 5;
optional CommonAccessorParameter common = 6;
optional TableType type = 7;
optional bool compress_in_save = 8 [ default = false ];
}
message TableAccessorParameter {
optional string accessor_class = 1;
optional uint32 fea_dim = 4 [ default = 11 ];
optional uint32 embedx_dim = 5 [ default = 8 ];
optional uint32 embedx_threshold = 6 [ default = 10 ];
repeated TableAccessorSaveParameter table_accessor_save_param = 8;
}
message TensorAccessorParameter {
optional string tensor_class = 1;
optional uint32 fea_dim = 2;
optional uint32 emb_dim = 3;
optional string param = 4;
optional string grad = 5;
optional string common_block_map = 6;
}
message CommonAccessorParameter {
optional string name = 1;
optional string table_name = 2;
repeated string attributes = 3;
repeated string params = 4;
repeated uint32 dims = 5;
repeated string initializers = 6;
optional int32 trainer_num = 7;
optional bool sync = 8;
}
message TableAccessorSaveParameter {
optional uint32 param = 1;
optional string converter = 2;
optional string deconverter = 3;
}
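The proto above is compiled by proto_library(ps_framework_proto SRCS ps.proto), so the config can also be assembled from C++ through the generated classes. A sketch assuming standard protobuf C++ codegen; the class/accessor names passed as strings are illustrative.
paddle::distributed::PSParameter ps_param;
ps_param.set_worker_class("DownpourWorker");  // illustrative
ps_param.set_server_class("DownpourServer");  // illustrative
auto *table = ps_param.mutable_server_param()
                  ->mutable_downpour_server_param()
                  ->add_downpour_table_param();
table->set_table_id(1);
table->set_table_class("CommonSparseTable");  // illustrative
table->set_shard_num(256);
table->set_type(paddle::distributed::PS_SPARSE_TABLE);
table->mutable_accessor()->set_accessor_class("CommMergeAccessor");  // illustrative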
set(BRPC_SRCS ps_client.cc server.cc)
set_source_files_properties(${BRPC_SRCS})
set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog)
brpc_library(sendrecv_rpc SRCS
${BRPC_SRCS}
PROTO sendrecv.proto
DEPS ${BRPC_DEPS} )
set_property(GLOBAL PROPERTY RPC_DEPS sendrecv_rpc ${BRPC_DEPS} string_helper)
get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)
set_source_files_properties(communicator.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(service.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(brpc_ps_server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(brpc_ps_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(brpc_utils.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(heter_server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(heter_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(ps_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(downpour_server SRCS brpc_ps_server.cc DEPS boost eigen3 table ${RPC_DEPS})
cc_library(downpour_client SRCS brpc_ps_client.cc DEPS boost eigen3 table ${RPC_DEPS})
cc_library(client SRCS ps_client.cc DEPS downpour_client boost ${RPC_DEPS})
cc_library(server SRCS server.cc DEPS downpour_server boost ${RPC_DEPS})
cc_library(communicator SRCS communicator.cc DEPS scope client boost table math_function selected_rows_functor ${RPC_DEPS})
cc_library(ps_service SRCS service.cc DEPS communicator client server boost ${RPC_DEPS})
cc_library(brpc_utils SRCS brpc_utils.cc DEPS ${COMMON_DEPS} ${RPC_DEPS})
cc_library(heter_server SRCS heter_server.cc DEPS brpc_utils ${COMMON_DEPS} ${RPC_DEPS})
cc_library(heter_client SRCS heter_client.cc DEPS brpc_utils ${COMMON_DEPS} ${RPC_DEPS})
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "Eigen/Dense"
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/table/table.h"
#include "paddle/fluid/framework/archive.h"
const static int max_port = 65535;
DEFINE_int32(pserver_push_dense_merge_limit, 12,
"limit max push_dense local merge requests");
DEFINE_int32(pserver_push_sparse_merge_limit, 12,
"limit max push_sparse local merge requests");
DEFINE_int32(pserver_pull_dense_limit, 12,
"limit max push_sparse local merge requests");
DEFINE_int32(pserver_async_push_dense_interval_ms, 10,
"async push_dense to server interval");
DEFINE_int32(pserver_async_push_sparse_interval_ms, 10,
"async push_sparse to server interval");
DEFINE_bool(pserver_scale_gradient_by_merge, false,
"scale dense gradient when merged");
DEFINE_int32(pserver_communicate_compress_type, 0,
"none:0 snappy:1 gzip:2 zlib:3 lz4:4");
DEFINE_int32(pserver_max_async_call_num, 13,
"max task num in async_call_server");
DEFINE_int32(pserver_timeout_ms, 500000, "pserver request server timeout_ms");
DEFINE_int32(pserver_connect_timeout_ms, 10000,
"pserver connect server timeout_ms");
DEFINE_int32(pserver_sparse_merge_thread, 1, "pserver sparse merge thread num");
namespace paddle {
namespace distributed {
inline size_t get_sparse_shard(uint32_t shard_num, uint32_t server_num,
uint64_t key) {
size_t remind = shard_num % server_num;
size_t local_shard_num =
remind == 0 ? shard_num / server_num : shard_num / server_num + 1;
return (key % shard_num) / local_shard_num;
}
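// Worked example for get_sparse_shard (editorial comment, values illustrative):
//   shard_num = 1000, server_num = 3
//   remind = 1000 % 3 = 1  ->  local_shard_num = 1000 / 3 + 1 = 334
//   key = 1234  ->  (1234 % 1000) / 334 = 234 / 334 = 0  ->  shard on server 0
//   key = 1999  ->  (1999 % 1000) / 334 = 999 / 334 = 2  ->  shard on server 2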
void DownpourPsClientService::service(
::google::protobuf::RpcController *controller,
const ::paddle::PsRequestMessage *request,
::paddle::PsResponseMessage *response, ::google::protobuf::Closure *done) {
brpc::ClosureGuard done_guard(done);
int ret = _client->handle_client2client_msg(
request->cmd_id(), request->client_id(), request->data());
response->set_err_code(0);
response->set_err_msg("");
if (ret != 0) {
response->set_err_code(-1);
response->set_err_msg("handle_client2client_msg failed");
}
}
// Start the client-side RpcService, used for client-to-client data exchange and related operations
int32_t BrpcPsClient::start_client_service() {
if (_service.configure(this, _client_id) != 0) {
LOG(ERROR)
<< "service initialize failed, service_name:DownpourPsClientService";
return -1;
}
_server.AddService(&_service, brpc::SERVER_DOESNT_OWN_SERVICE);
brpc::ServerOptions options;
int start_port = 8500;
options.num_threads = 24;
if (_server.Start(butil::my_ip_cstr(), brpc::PortRange(start_port, max_port),
&options) != 0) {
LOG(ERROR) << "BrpcPsServer start failed";
return -1;
}
_env->registe_ps_client(butil::my_ip_cstr(), _server.listen_address().port,
_client_id);
return 0;
}
int32_t BrpcPsClient::create_client2client_connection(
int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry) {
brpc::ChannelOptions options;
options.protocol = "baidu_std";
options.timeout_ms = pserver_timeout_ms;
options.connection_type = "pooled";
options.connect_timeout_ms = pserver_connect_timeout_ms;
options.max_retry = max_retry;
std::vector<PSHost> client_list = _env->get_ps_clients();
_client_channels.resize(client_list.size());
std::ostringstream os;
std::string server_ip_port;
for (size_t i = 0; i < client_list.size(); ++i) {
server_ip_port.assign(client_list[i].ip.c_str());
server_ip_port.append(":");
server_ip_port.append(std::to_string(client_list[i].port));
_client_channels[i].reset(new brpc::Channel());
if (_client_channels[i]->Init(server_ip_port.c_str(), "", &options) != 0) {
LOG(ERROR) << "psclient connect to client:" << server_ip_port
<< " Failed!";
}
os << server_ip_port << ",";
}
LOG(INFO) << "Client connect success:" << os.str();
return 0;
}
int32_t BrpcPsClient::initialize() {
_async_call_num = 0;
brpc::ChannelOptions options;
options.protocol = "baidu_std";
options.timeout_ms = FLAGS_pserver_timeout_ms;
options.connection_type = "pooled";
options.connect_timeout_ms = FLAGS_pserver_connect_timeout_ms;
options.max_retry = 3;
std::ostringstream os;
std::string server_ip_port;
std::string client_ip(butil::my_ip_cstr());
// Fetch the server list and connect to each server
std::vector<PSHost> server_list = _env->get_ps_servers();
_server_channels.resize(server_list.size());
for (size_t i = 0; i < server_list.size(); ++i) {
server_ip_port.assign(server_list[i].ip.c_str());
server_ip_port.append(":");
server_ip_port.append(std::to_string(server_list[i].port));
for (size_t j = 0; j < _server_channels[i].size(); ++j) {
_server_channels[i][j].reset(new brpc::Channel());
if (_server_channels[i][j]->Init(server_ip_port.c_str(), "", &options) !=
0) {
LOG(ERROR) << "psclient connect to server:" << server_ip_port
<< " Failed!";
return -1;
}
}
os << server_ip_port << ",";
}
// Start the client listening service and build connections between clients
start_client_service();
_running = true;
_flushing = false;
return 0;
}
int DownpourBrpcClosure::check_response(size_t request_idx, int cmd_id) {
if (_cntls[request_idx]->Failed()) {
LOG(ERROR) << "resquest cmd_id:" << cmd_id << " failed, "
"err:"
<< _cntls[request_idx]->ErrorText();
return -1;
}
if (_responses[request_idx].err_code() != 0) {
LOG(ERROR) << "response ret bad, server_idx:" << request_idx
<< "cmd_id:" << cmd_id
<< " err_code:" << _responses[request_idx].err_code()
<< " err_msg:" << _responses[request_idx].err_msg();
return -1;
}
return 0;
}
int DownpourBrpcClosure::check_save_response(size_t request_idx, int cmd_id) {
int32_t feasign_size = 0;
if (_cntls[request_idx]->Failed()) {
LOG(ERROR) << "resquest cmd_id:" << cmd_id << " failed, "
"err:"
<< _cntls[request_idx]->ErrorText();
return -1;
}
feasign_size = _responses[request_idx].err_code();
if (feasign_size < 0) {
LOG(ERROR) << "response ret bad, server_idx:" << request_idx
<< "cmd_id:" << cmd_id
<< " err_code:" << _responses[request_idx].err_code()
<< " err_msg:" << _responses[request_idx].err_msg();
return -1;
}
return feasign_size;
}
std::string DownpourBrpcClosure::get_response(size_t request_idx, int cmd_id) {
std::string data = _responses[request_idx].data();
return data;
}
std::future<int32_t> BrpcPsClient::print_table_stat(uint32_t table_id) {
size_t request_call_num = _server_channels.size();
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [request_call_num, table_id](void *done) {
int ret = 0;
uint64_t feasign_size = 0;
uint64_t mf_size = 0;
paddle::framework::BinaryArchive ar;
auto *closure = (DownpourBrpcClosure *)done;
for (size_t i = 0; i < request_call_num; ++i) {
if (closure->check_response(i, PS_PRINT_TABLE_STAT) != 0) {
ret = -1;
break;
}
std::string resp = closure->get_response(i, PS_PRINT_TABLE_STAT);
ar.SetReadBuffer(const_cast<char *>(resp.c_str()), resp.length(),
nullptr);
feasign_size += ar.Get<uint64_t>();
mf_size += ar.Get<uint64_t>();
}
closure->set_promise_value(ret);
std::cout << "table id: " << table_id
<< ", feasign size: " << feasign_size
<< ", mf size: " << mf_size << std::endl;
});
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
for (size_t i = 0; i < request_call_num; ++i) {
closure->request(i)->set_cmd_id(PS_PRINT_TABLE_STAT);
closure->request(i)->set_table_id(table_id);
closure->request(i)->set_client_id(_client_id);
PsService_Stub rpc_stub(get_cmd_channel(i));
closure->cntl(i)->set_timeout_ms(
10800000); // cmd msg don't limit timeout for save/load
rpc_stub.service(closure->cntl(i), closure->request(i),
closure->response(i), closure);
}
return fut;
}
std::future<int32_t> BrpcPsClient::send_cmd(
uint32_t table_id, int cmd_id, const std::vector<std::string> &params) {
size_t request_call_num = _server_channels.size();
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [request_call_num, cmd_id](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
for (size_t i = 0; i < request_call_num; ++i) {
if (closure->check_response(i, cmd_id) != 0) {
ret = -1;
break;
}
}
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
for (size_t i = 0; i < request_call_num; ++i) {
closure->request(i)->set_cmd_id(cmd_id);
closure->request(i)->set_table_id(table_id);
closure->request(i)->set_client_id(_client_id);
for (const auto &param : params) {
closure->request(i)->add_params(param);
}
PsService_Stub rpc_stub(get_cmd_channel(i));
closure->cntl(i)->set_timeout_ms(
10800000); // cmd msg don't limit timeout for save/load
rpc_stub.service(closure->cntl(i), closure->request(i),
closure->response(i), closure);
}
return fut;
}
std::future<int32_t> BrpcPsClient::send_save_cmd(
uint32_t table_id, int cmd_id, const std::vector<std::string> &params) {
size_t request_call_num = _server_channels.size();
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [request_call_num, cmd_id](void *done) {
int ret = 0;
uint32_t feasign_size = 0;
auto *closure = (DownpourBrpcClosure *)done;
for (size_t i = 0; i < request_call_num; ++i) {
if (closure->check_save_response(i, cmd_id) < 0) {
ret = -1;
break;
}
feasign_size += closure->check_save_response(i, cmd_id);
}
if (ret == 0) {
closure->set_promise_value(feasign_size);
} else {
closure->set_promise_value(ret);
}
});
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
for (size_t i = 0; i < request_call_num; ++i) {
closure->request(i)->set_cmd_id(cmd_id);
closure->request(i)->set_table_id(table_id);
closure->request(i)->set_client_id(_client_id);
for (const auto &param : params) {
closure->request(i)->add_params(param);
}
PsService_Stub rpc_stub(get_cmd_channel(i));
closure->cntl(i)->set_timeout_ms(
10800000); // cmd msg don't limit timeout for save/load
rpc_stub.service(closure->cntl(i), closure->request(i),
closure->response(i), closure);
}
return fut;
}
std::future<int32_t> BrpcPsClient::shrink(uint32_t table_id) {
return send_cmd(table_id, PS_SHRINK_TABLE, {std::string("1")});
}
std::future<int32_t> BrpcPsClient::load(const std::string &epoch,
const std::string &mode) {
return send_cmd(-1, PS_LOAD_ALL_TABLE, {epoch, mode});
}
std::future<int32_t> BrpcPsClient::load(uint32_t table_id,
const std::string &epoch,
const std::string &mode) {
return send_cmd(table_id, PS_LOAD_ONE_TABLE, {epoch, mode});
}
std::future<int32_t> BrpcPsClient::save(const std::string &epoch,
const std::string &mode) {
return send_save_cmd(-1, PS_SAVE_ALL_TABLE, {epoch, mode});
}
std::future<int32_t> BrpcPsClient::save(uint32_t table_id,
const std::string &epoch,
const std::string &mode) {
return send_save_cmd(table_id, PS_SAVE_ONE_TABLE, {epoch, mode});
}
std::future<int32_t> BrpcPsClient::clear() {
return send_cmd(-1, PS_CLEAR_ALL_TABLE, {});
}
std::future<int32_t> BrpcPsClient::clear(uint32_t table_id) {
return send_cmd(table_id, PS_CLEAR_ONE_TABLE, {});
}
std::future<int32_t> BrpcPsClient::flush() {
_flushing = true;
std::promise<int> promise;
std::future<int32_t> fut = promise.get_future();
do {
VLOG(3) << "wait _async_call_num:" << _async_call_num;
usleep(100000); // sleep 100ms wait async end
} while (_async_call_num > 0);
promise.set_value(0);
_flushing = false;
return fut;
}
void BrpcPsClient::finalize_worker() {
flush();
_running = false;
_server.Stop(1000);
_server.Join();
}
std::future<int32_t> BrpcPsClient::stop_server() {
return send_cmd(-1, PS_STOP_SERVER, {});
}
std::future<int32_t> BrpcPsClient::start_profiler() {
return send_cmd(-1, PS_START_PROFILER, {});
}
std::future<int32_t> BrpcPsClient::stop_profiler() {
return send_cmd(-1, PS_STOP_PROFILER, {});
}
std::future<int32_t> BrpcPsClient::barrier(size_t table_id,
uint32_t barrier_type) {
return send_cmd(table_id, PS_BARRIER, {std::to_string(barrier_type)});
}
std::future<int32_t> BrpcPsClient::pull_geo_param(size_t table_id,
std::vector<float> *values,
std::vector<uint64_t> *keys,
int pserver_idx) {
auto *accessor = table_accessor(table_id);
DownpourBrpcClosure *closure =
new DownpourBrpcClosure(1, [keys, values, accessor](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
uint32_t shard_nums;
if (closure->check_response(0, PS_PULL_GEO_PARAM) != 0) {
ret = -1;
}
auto &res_io_buffer = closure->cntl(0)->response_attachment();
butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
io_buffer_itr.copy_and_forward((void *)(&shard_nums), sizeof(uint32_t));
keys->resize(shard_nums);
values->resize(shard_nums * accessor->update_dim());
io_buffer_itr.copy_and_forward((void *)(keys->data()),
sizeof(uint64_t) * shard_nums);
io_buffer_itr.copy_and_forward((void *)(values->data()),
shard_nums * accessor->update_size());
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
closure->request(0)->set_cmd_id(PS_PULL_GEO_PARAM);
closure->request(0)->set_table_id(table_id);
closure->request(0)->set_client_id(_client_id);
PsService_Stub rpc_stub(get_cmd_channel(pserver_idx));
closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0),
closure);
return fut;
}
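// push_sparse_param: keys are sharded by `key % server_num`; for each shard
// the request carries the key count in params[0] and a data blob laid out as
// [kv_size * uint64 keys][kv_size * update_size() bytes of values].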
std::future<int32_t> BrpcPsClient::push_sparse_param(
size_t table_id, const uint64_t *keys, const float **update_values,
size_t num, void *done) {
auto *accessor = table_accessor(table_id);
  // send the RPC requests
DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
size_t request_call_num = _server_channels.size();
std::vector<std::vector<uint64_t>> ids;
std::vector<std::vector<const float *>> value_ptrs;
ids.resize(request_call_num);
value_ptrs.resize(request_call_num);
for (size_t i = 0; i < num; ++i) {
size_t pserver_idx = keys[i] % request_call_num;
ids[pserver_idx].push_back(keys[i]);
value_ptrs[pserver_idx].push_back(update_values[i]);
}
  for (size_t shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
    auto kvs = ids[shard_idx];
    auto value_ptr = value_ptrs[shard_idx];
    size_t kv_size = kvs.size();
    uint32_t value_size = accessor->update_size();
    // build and send the RPC request for this shard
    auto *push_request = closure->request(shard_idx);
    push_request->set_cmd_id(PS_PUSH_SPARSE_PARAM);
    push_request->set_table_id(table_id);
    push_request->set_client_id(_client_id);
    push_request->add_params((char *)&kv_size, sizeof(uint32_t));
    auto *push_data = push_request->mutable_data();
    push_data->resize(kv_size * (sizeof(uint64_t) + value_size));
    char *push_data_ptr = const_cast<char *>(push_data->data());
    memcpy(push_data_ptr, kvs.data(), kv_size * sizeof(uint64_t));
    push_data_ptr += kv_size * sizeof(uint64_t);
    for (size_t i = 0; i < kv_size; ++i) {
      memcpy(push_data_ptr, value_ptr[i], value_size);
      push_data_ptr += value_size;
    }
PsService_Stub rpc_stub(get_sparse_channel(shard_idx));
closure->cntl(shard_idx)->set_request_compress_type(
(brpc::CompressType)FLAGS_pserver_communicate_compress_type);
rpc_stub.service(closure->cntl(shard_idx), closure->request(shard_idx),
closure->response(shard_idx), closure);
}
return fut;
}
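// pull_dense: every server holds dense_dim_per_shard(fea_dim, server_num)
// values of the dense table; the callback below stitches the per-shard
// responses back into the caller-provided regions, in shard order.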
std::future<int32_t> BrpcPsClient::pull_dense(Region *regions,
size_t region_num,
size_t table_id) {
auto *accessor = table_accessor(table_id);
size_t request_call_num = _server_channels.size();
uint32_t num_per_shard =
dense_dim_per_shard(accessor->fea_dim(), request_call_num);
  // callback: fill each shard's result into the regions, in order
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [request_call_num, num_per_shard, regions, region_num,
accessor](void *done) {
int ret = 0;
        size_t region_idx = 0;       // index of the region currently being filled
        size_t region_data_idx = 0;  // data offset inside the current region
auto *closure = (DownpourBrpcClosure *)done;
size_t shard_data_size = num_per_shard * accessor->select_size();
for (size_t i = 0; i < request_call_num; ++i) {
if (closure->check_response(i, PS_PULL_DENSE_TABLE) != 0) {
ret = -1;
break;
}
auto &res_io_buffer = closure->cntl(i)->response_attachment();
butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
size_t shard_buffer_remain = res_io_buffer.size();
if (shard_buffer_remain != shard_data_size) {
LOG(ERROR) << "expect res_size:" << shard_data_size
<< ", but size:" << shard_buffer_remain
<< ", ignore this response";
ret = -1;
break;
}
while (shard_buffer_remain > 0 && region_idx < region_num) {
auto &region = regions[region_idx];
if (region.size - region_data_idx >= shard_buffer_remain) {
              // the region's remaining space can hold the whole shard buffer: copy it all in
io_buffer_itr.copy_and_forward(
(void *)(region.data + region_data_idx), shard_buffer_remain);
region_data_idx += shard_buffer_remain;
shard_buffer_remain = 0;
} else if (region.size - region_data_idx == 0) {
              // this region is full, move on to the next region
++region_idx;
region_data_idx = 0;
} else {
              // the region cannot hold all remaining data: copy as much as fits
io_buffer_itr.copy_and_forward(
(void *)(region.data + region_data_idx),
region.size - region_data_idx);
shard_buffer_remain -= (region.size - region_data_idx);
++region_idx;
region_data_idx = 0;
}
}
}
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
for (size_t i = 0; i < request_call_num; ++i) {
closure->request(i)->set_cmd_id(PS_PULL_DENSE_TABLE);
closure->request(i)->set_table_id(table_id);
closure->request(i)->set_client_id(_client_id);
closure->request(i)->add_params((char *)&num_per_shard,
sizeof(num_per_shard));
PsService_Stub rpc_stub(get_dense_channel(i));
rpc_stub.service(closure->cntl(i), closure->request(i),
closure->response(i), closure);
}
return fut;
}
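// push_dense_param: the caller's regions are first partitioned into per-shard
// slices of shard_data_size bytes; slices that come up short are zero-padded
// so every server receives exactly num_per_shard values.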
std::future<int32_t> BrpcPsClient::push_dense_param(const Region *regions,
size_t region_num,
size_t table_id) {
auto *accessor = table_accessor(table_id);
size_t request_call_num = _server_channels.size();
  // 1. split the Region data into per-shard slices, so the copies and requests
  // below can be issued per shard in parallel
std::vector<std::vector<Region>> regions_partition(request_call_num);
uint32_t num_per_shard =
dense_dim_per_shard(accessor->fea_dim(), request_call_num);
size_t shard_data_size = num_per_shard * accessor->update_size();
size_t current_region_idx = 0;
size_t current_region_data_idx = 0;
for (size_t i = 0; i < request_call_num; ++i) {
size_t shard_data_remain_size = shard_data_size;
while (shard_data_remain_size > 0 && current_region_idx < region_num) {
const auto &region = regions[current_region_idx];
size_t region_remain_size = region.size - current_region_data_idx;
if (shard_data_remain_size >= region_remain_size) {
regions_partition[i].push_back(
Region(region.data + current_region_data_idx, region_remain_size));
++current_region_idx;
current_region_data_idx = 0;
shard_data_remain_size -= region_remain_size;
} else {
regions_partition[i].push_back(Region(
region.data + current_region_data_idx, shard_data_remain_size));
current_region_data_idx += shard_data_remain_size;
shard_data_remain_size = 0;
}
}
}
DownpourBrpcClosure *closure =
new DownpourBrpcClosure(request_call_num, [request_call_num](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
for (size_t i = 0; i < request_call_num; ++i) {
if (closure->check_response(i, PS_PUSH_DENSE_PARAM) != 0) {
ret = -1;
break;
}
}
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
static const int REGION_ASSIGN_BUFFER_SIZE = 1024 * 10;
  static char region_assign_buffer[REGION_ASSIGN_BUFFER_SIZE];  // zero buffer used for padding
  // copy data and issue the request for every shard
for (size_t i = 0; i < request_call_num; ++i) {
closure->request(i)->set_cmd_id(PS_PUSH_DENSE_PARAM);
closure->request(i)->set_table_id(table_id);
closure->request(i)->set_client_id(_client_id);
auto &request_buffer = closure->cntl(i)->request_attachment();
request_buffer.append((void *)&num_per_shard, sizeof(uint32_t));
auto &region_list = regions_partition[i];
size_t fill_remain_size = shard_data_size;
for (auto &region : region_list) {
fill_remain_size -= region.size;
request_buffer.append((void *)region.data, region.size);
}
    // pad so that every shard receives exactly shard_data_size bytes
while (fill_remain_size > 0) {
size_t fill_num = fill_remain_size > REGION_ASSIGN_BUFFER_SIZE
? REGION_ASSIGN_BUFFER_SIZE
: fill_remain_size;
request_buffer.append((void *)region_assign_buffer, fill_num);
fill_remain_size -= fill_num;
}
PsService_Stub rpc_stub(get_dense_channel(i));
rpc_stub.service(closure->cntl(i), closure->request(i),
closure->response(i), closure);
}
return fut;
}
std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient(
size_t table_id, const uint64_t *keys, const float **update_values,
size_t num, void *done) {
auto *accessor = table_accessor(table_id);
  // send the RPC requests
DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
size_t request_call_num = _server_channels.size();
std::vector<std::vector<uint64_t>> ids;
std::vector<std::vector<const float *>> value_ptrs;
ids.resize(request_call_num);
value_ptrs.resize(request_call_num);
for (size_t i = 0; i < num; ++i) {
size_t pserver_idx = keys[i] % request_call_num;
ids[pserver_idx].push_back(keys[i]);
value_ptrs[pserver_idx].push_back(update_values[i]);
}
  for (size_t shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
    auto kvs = ids[shard_idx];
    auto value_ptr = value_ptrs[shard_idx];
    size_t kv_size = kvs.size();
    uint32_t value_size = accessor->update_size();
    // build and send the RPC request for this shard
    auto *push_request = closure->request(shard_idx);
    push_request->set_cmd_id(PS_PUSH_SPARSE_TABLE);
    push_request->set_table_id(table_id);
    push_request->set_client_id(_client_id);
    push_request->add_params((char *)&kv_size, sizeof(uint32_t));
    auto *push_data = push_request->mutable_data();
    push_data->resize(kv_size * (sizeof(uint64_t) + value_size));
    char *push_data_ptr = const_cast<char *>(push_data->data());
    memcpy(push_data_ptr, kvs.data(), kv_size * sizeof(uint64_t));
    push_data_ptr += kv_size * sizeof(uint64_t);
    for (size_t i = 0; i < kv_size; ++i) {
      memcpy(push_data_ptr, value_ptr[i], value_size);
      push_data_ptr += value_size;
    }
PsService_Stub rpc_stub(get_sparse_channel(shard_idx));
closure->cntl(shard_idx)->set_request_compress_type(
(brpc::CompressType)FLAGS_pserver_communicate_compress_type);
rpc_stub.service(closure->cntl(shard_idx), closure->request(shard_idx),
closure->response(shard_idx), closure);
}
return fut;
}
std::future<int32_t> BrpcPsClient::push_dense_raw_gradient(
int table_id, float *total_send_data, size_t total_send_data_size,
void *done) {
size_t request_call_num = _server_channels.size();
DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
auto *accessor = table_accessor(table_id);
uint32_t num_per_shard =
dense_dim_per_shard(accessor->fea_dim(), request_call_num);
for (size_t i = 0; i < request_call_num; ++i) {
closure->request(i)->set_cmd_id(PS_PUSH_DENSE_TABLE);
closure->request(i)->set_table_id(table_id);
closure->request(i)->set_client_id(_client_id);
auto *push_data = closure->request(i)->mutable_data();
push_data->clear();
push_data->resize(sizeof(uint32_t) + num_per_shard * sizeof(float));
char *push_data_ptr = const_cast<char *>(push_data->data());
memcpy(push_data_ptr, &num_per_shard, sizeof(uint32_t));
memcpy(push_data_ptr + sizeof(uint32_t),
total_send_data + i * num_per_shard, num_per_shard * sizeof(float));
VLOG(1) << "push_dense_raw_gradient finish memcpy";
// closure->cntl(i)->set_request_compress_type(
// (brpc::CompressType)FLAGS_pserver_communicate_compress_type);
PsService_Stub rpc_stub(get_dense_channel(i));
VLOG(1) << "push_dense_raw_gradient get_dense_channel " << i;
rpc_stub.service(closure->cntl(i), closure->request(i),
closure->response(i), closure);
VLOG(1) << "push_dense_raw_gradient async service " << i;
}
return fut;
}
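// pull_sparse: keys are sharded by `key % server_num`, sorted inside each
// shard and de-duplicated before the request is sent; the callback copies the
// value returned for a key into every duplicate slot that requested it.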
std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values,
size_t table_id,
const uint64_t *keys,
size_t num) {
size_t request_call_num = _server_channels.size();
auto shard_sorted_kvs = std::make_shared<
std::vector<std::vector<std::pair<uint64_t, float *>>>>();
shard_sorted_kvs->resize(request_call_num);
for (size_t i = 0; i < num; ++i) {
size_t shard_id = keys[i] % request_call_num;
shard_sorted_kvs->at(shard_id).push_back({keys[i], select_values[i]});
}
auto *accessor = table_accessor(table_id);
size_t value_size = accessor->select_size();
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [shard_sorted_kvs, value_size](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
        for (size_t i = 0; i < shard_sorted_kvs->size(); ++i) {
if (closure->check_response(i, PS_PULL_SPARSE_TABLE) != 0) {
ret = -1;
break;
}
auto &request_kvs = shard_sorted_kvs->at(i);
auto &res_io_buffer = closure->cntl(i)->response_attachment();
butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
uint64_t last_key = UINT64_MAX;
float *last_value_data = NULL;
for (size_t kv_idx = 0; kv_idx < request_kvs.size(); ++kv_idx) {
auto *kv_pair = &(request_kvs[kv_idx]);
if (kv_pair->first == last_key) {
memcpy((void *)kv_pair->second, (void *)last_value_data,
value_size);
} else {
last_key = kv_pair->first;
last_value_data = kv_pair->second;
if (value_size !=
io_buffer_itr.copy_and_forward((void *)(last_value_data),
value_size)) {
LOG(WARNING) << "res data is lack or not in format";
ret = -1;
break;
}
}
}
}
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
for (size_t i = 0; i < request_call_num; ++i) {
auto &sorted_kvs = shard_sorted_kvs->at(i);
std::sort(sorted_kvs.begin(), sorted_kvs.end(),
[](const std::pair<uint64_t, float *> &k1,
const std::pair<uint64_t, float *> &k2) {
return k1.first < k2.first;
});
uint64_t last_key = UINT64_MAX;
uint32_t kv_request_count = 0;
size_t sorted_kv_size = sorted_kvs.size();
auto &request_buffer = closure->cntl(i)->request_attachment();
for (size_t kv_idx = 0; kv_idx < sorted_kv_size; ++kv_idx) {
++kv_request_count;
last_key = sorted_kvs[kv_idx].first;
request_buffer.append((void *)&last_key, sizeof(uint64_t));
while (kv_idx < sorted_kv_size - 1 &&
last_key == sorted_kvs[kv_idx + 1].first) {
++kv_idx;
}
}
if (kv_request_count == 0) {
closure->Run();
} else {
closure->request(i)->set_cmd_id(PS_PULL_SPARSE_TABLE);
closure->request(i)->set_table_id(table_id);
closure->request(i)->set_client_id(_client_id);
closure->request(i)->add_params((char *)&kv_request_count,
sizeof(uint32_t));
PsService_Stub rpc_stub(get_cmd_channel(i));
closure->cntl(i)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(i), closure->request(i),
closure->response(i), closure);
}
}
return fut;
}
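// client-to-client message: the peer replies with cmd_id == msg_type + 1000,
// which is what check_response() verifies in the callback below.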
std::future<int32_t> BrpcPsClient::send_client2client_msg(
int msg_type, int to_client_id, const std::string &msg) {
auto promise = std::make_shared<std::promise<int32_t>>();
std::future<int> fut = promise->get_future();
  if (to_client_id < 0 ||
      static_cast<size_t>(to_client_id) >= _client_channels.size()) {
    LOG(ERROR) << "to_client_id is out of range of clients, which size is "
               << _client_channels.size();
    promise->set_value(-1);
    return fut;
  }
auto *closure = new DownpourBrpcClosure(1, [msg_type](void *done) {
auto *closure = (DownpourBrpcClosure *)done;
int32_t ret = closure->check_response(0, msg_type + 1000);
closure->set_promise_value(ret);
});
closure->add_promise(promise);
closure->request(0)->set_cmd_id(msg_type);
closure->request(0)->set_client_id(_client_id);
closure->request(0)->set_data(msg);
PsService_Stub rpc_stub(_client_channels[to_client_id].get());
rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0),
closure);
return fut;
}
std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient_partial(
size_t table_id, const uint64_t *keys, const float **update_values,
uint32_t num, void *done, int pserver_idx) {
auto *accessor = table_accessor(table_id);
size_t value_size = accessor->update_size();
DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
  // build and send the RPC request
auto *push_request = closure->request(0);
push_request->set_cmd_id(PS_PUSH_SPARSE_TABLE);
push_request->set_table_id(table_id);
push_request->set_client_id(_client_id);
push_request->add_params((char *)&num, sizeof(uint32_t));
auto *push_data = push_request->mutable_data();
push_data->resize(num * (sizeof(uint64_t) + value_size));
char *push_data_ptr = const_cast<char *>(push_data->data());
memcpy(push_data_ptr, keys, num * sizeof(uint64_t));
push_data_ptr += num * sizeof(uint64_t);
  for (uint32_t i = 0; i < num; ++i) {
memcpy(push_data_ptr, update_values[i], value_size);
push_data_ptr += value_size;
}
PsService_Stub rpc_stub(get_sparse_channel(pserver_idx));
closure->cntl(0)->set_request_compress_type(
(brpc::CompressType)FLAGS_pserver_communicate_compress_type);
rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0),
closure);
return fut;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/service/ps_client.h"
namespace paddle {
namespace distributed {
class DownpourPsClientService : public PsService {
public:
DownpourPsClientService() {}
virtual ~DownpourPsClientService() {}
virtual int32_t configure(PSClient *client, size_t rank_id) {
_client = client;
_rank = rank_id;
return 0;
}
virtual void service(::google::protobuf::RpcController *controller,
const ::paddle::PsRequestMessage *request,
::paddle::PsResponseMessage *response,
::google::protobuf::Closure *done) override;
protected:
size_t _rank;
PSClient *_client;
};
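// DownpourBrpcClosure aggregates `num` sub-requests into one logical call:
// Run() is invoked once per finished sub-request, and when the last one
// completes it fires the user callback and deletes itself.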
class DownpourBrpcClosure : public PSClientClosure {
public:
DownpourBrpcClosure(size_t num, PSClientCallBack callback)
: PSClientClosure(callback) {
_waiting_num = num;
_cntls.resize(num);
_requests.resize(num);
_responses.resize(num);
for (size_t i = 0; i < num; ++i) {
_cntls[i].reset(new brpc::Controller());
}
}
virtual ~DownpourBrpcClosure() {}
virtual void Run() override {
if (_waiting_num.fetch_sub(1) == 1) {
_callback(this);
delete this;
}
}
PsRequestMessage *request(size_t i) { return &_requests[i]; }
PsResponseMessage *response(size_t i) { return &_responses[i]; }
brpc::Controller *cntl(size_t i) { return _cntls[i].get(); }
int check_response(size_t request_idx, int cmd_id);
int check_save_response(size_t request_idx, int cmd_id);
std::string get_response(size_t request_idx, int cmd_id);
private:
std::atomic<int32_t> _waiting_num;
std::vector<PsRequestMessage> _requests;
std::vector<PsResponseMessage> _responses;
std::vector<std::shared_ptr<brpc::Controller>> _cntls;
};
template <class T>
struct array_deleter {
void operator()(T *&x) const { delete[] x; }
};
class BrpcPsClient : public PSClient {
public:
BrpcPsClient() {}
virtual ~BrpcPsClient() {
// _running = false;
// try {
// _async_push_dense_thread.join();
// _async_push_sparse_thread.join();
//} catch (...) {
//}
}
virtual int32_t create_client2client_connection(
int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry);
virtual std::future<int32_t> shrink(uint32_t table_id) override;
virtual std::future<int32_t> load(const std::string &epoch,
const std::string &mode) override;
virtual std::future<int32_t> load(uint32_t table_id, const std::string &epoch,
const std::string &mode) override;
virtual std::future<int32_t> save(const std::string &epoch,
const std::string &mode) override;
virtual std::future<int32_t> save(uint32_t table_id, const std::string &epoch,
const std::string &mode) override;
virtual std::future<int32_t> clear() override;
virtual std::future<int32_t> clear(uint32_t table_id) override;
virtual std::future<int32_t> stop_server() override;
virtual std::future<int32_t> start_profiler() override;
virtual std::future<int32_t> stop_profiler() override;
virtual void finalize_worker() override;
virtual std::future<int32_t> pull_dense(Region *regions, size_t region_num,
size_t table_id);
virtual std::future<int32_t> push_dense_param(const Region *regions,
size_t region_num,
size_t table_id);
virtual std::future<int32_t> pull_sparse(float **select_values,
size_t table_id,
const uint64_t *keys, size_t num);
virtual std::future<int32_t> print_table_stat(uint32_t table_id);
virtual std::future<int32_t> barrier(size_t table_id, uint32_t barrier_type);
virtual std::future<int32_t> pull_geo_param(size_t table_id,
std::vector<float> *values,
std::vector<uint64_t> *keys,
int pserver_idx);
virtual std::future<int32_t> flush();
virtual std::future<int32_t> send_client2client_msg(
int msg_type, int to_client_id, const std::string &msg) override;
private:
virtual int32_t initialize() override;
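  // each shard stores dense_dim_total / shard_num + 1 values; the extra slot
  // over-allocates slightly so the last shard never comes up short, and the
  // client pads its pushes to this size (see push_dense_param)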
inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total,
uint32_t shard_num) {
return dense_dim_total / shard_num + 1;
}
std::future<int32_t> send_cmd(uint32_t table_id, int cmd_id,
const std::vector<std::string> &param);
std::future<int32_t> send_save_cmd(uint32_t table_id, int cmd_id,
const std::vector<std::string> &param);
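  // each pserver exposes three brpc channels: [0] sparse requests,
  // [1] dense requests, [2] control/command requests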
inline brpc::Channel *get_sparse_channel(size_t server_id) {
return _server_channels[server_id][0].get();
}
inline brpc::Channel *get_dense_channel(size_t server_id) {
return _server_channels[server_id][1].get();
}
inline brpc::Channel *get_cmd_channel(size_t server_id) {
return _server_channels[server_id][2].get();
}
bool _running = false;
bool _flushing = false;
  std::atomic<uint32_t> _async_call_num;  // number of in-flight async requests
std::vector<std::shared_ptr<brpc::Channel>>
_client_channels; // client2client
std::vector<std::array<std::shared_ptr<brpc::Channel>, 3>>
_server_channels; // client2server
virtual std::future<int32_t> push_dense_raw_gradient(
int table_id, float *total_send_data, size_t total_send_data_size,
void *done) override;
virtual std::future<int32_t> push_sparse_raw_gradient(
size_t table_id, const uint64_t *keys, const float **update_values,
size_t num, void *done) override;
virtual std::future<int32_t> push_sparse_raw_gradient_partial(
size_t table_id, const uint64_t *keys, const float **update_values,
uint32_t num, void *done, int pserver_idx) override;
virtual std::future<int32_t> push_sparse_param(size_t table_id,
const uint64_t *keys,
const float **update_values,
size_t num,
void *done) override;
virtual size_t get_server_nums() { return _server_channels.size(); }
private:
int32_t start_client_service();
float _mae = 0;
float _mse = 0;
uint16_t _push_times = 0;
brpc::Server _server;
DownpourPsClientService _service;
std::atomic_uint grad_num_{0};
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/service/brpc_ps_server.h"
#include <thread> // NOLINT
#include "Eigen/Dense"
#include "butil/endpoint.h"
#include "iomanip"
#include "paddle/fluid/distributed/table/table.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace distributed {
int32_t BrpcPsServer::initialize() {
auto &service_config = _config.downpour_server_param().service_param();
if (!service_config.has_service_class()) {
LOG(ERROR) << "miss service_class in ServerServiceParameter";
return -1;
}
auto *service = CREATE_CLASS(PsBaseService, service_config.service_class());
if (service == NULL) {
LOG(ERROR) << "service is unregistered, service_name:"
<< service_config.service_class();
return -1;
}
_service.reset(service);
if (service->configure(this) != 0 || service->initialize() != 0) {
LOG(ERROR) << "service initialize failed, service_name:"
<< service_config.service_class();
return -1;
}
if (_server.AddService(service, brpc::SERVER_DOESNT_OWN_SERVICE) != 0) {
LOG(ERROR) << "service add to brpc failed, service:"
<< service_config.service_class();
return -1;
}
return 0;
}
uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) {
std::unique_lock<std::mutex> lock(mutex_);
std::string ip_port = ip + ":" + std::to_string(port);
VLOG(3) << "server of rank " << _rank << " starts at " << ip_port;
int num_threads = std::thread::hardware_concurrency();
brpc::ServerOptions options;
options.num_threads = num_threads;
if (_server.Start(ip_port.c_str(), &options) != 0) {
LOG(ERROR) << "BrpcPsServer start failed, ip_port=" << ip_port;
return 0;
}
VLOG(0) << "BrpcPsServer::start registe_ps_server";
_environment->registe_ps_server(ip, port, _rank);
VLOG(0) << "BrpcPsServer::start wait";
cv_.wait(lock, [&] { return stoped_; });
PSHost host;
host.ip = ip;
host.port = port;
host.rank = _rank;
VLOG(0) << "BrpcPsServer::start return host.rank";
return host.rank;
}
int32_t BrpcPsServer::port() { return _server.listen_address().port; }
int32_t PsService::initialize() {
_is_initialize_shard_info = false;
_service_handler_map[PS_STOP_SERVER] = &PsService::stop_server;
_service_handler_map[PS_PULL_DENSE_TABLE] = &PsService::pull_dense;
_service_handler_map[PS_PUSH_DENSE_TABLE] = &PsService::push_dense;
_service_handler_map[PS_PULL_SPARSE_TABLE] = &PsService::pull_sparse;
_service_handler_map[PS_PUSH_SPARSE_TABLE] = &PsService::push_sparse;
_service_handler_map[PS_SAVE_ONE_TABLE] = &PsService::save_one_table;
_service_handler_map[PS_SAVE_ALL_TABLE] = &PsService::save_all_table;
_service_handler_map[PS_SHRINK_TABLE] = &PsService::shrink_table;
_service_handler_map[PS_LOAD_ONE_TABLE] = &PsService::load_one_table;
_service_handler_map[PS_LOAD_ALL_TABLE] = &PsService::load_all_table;
_service_handler_map[PS_CLEAR_ONE_TABLE] = &PsService::clear_one_table;
_service_handler_map[PS_CLEAR_ALL_TABLE] = &PsService::clear_all_table;
_service_handler_map[PS_PUSH_DENSE_PARAM] = &PsService::push_dense_param;
_service_handler_map[PS_PRINT_TABLE_STAT] = &PsService::print_table_stat;
_service_handler_map[PS_PULL_GEO_PARAM] = &PsService::pull_geo_param;
_service_handler_map[PS_PUSH_SPARSE_PARAM] = &PsService::push_sparse_param;
_service_handler_map[PS_BARRIER] = &PsService::barrier;
_service_handler_map[PS_START_PROFILER] = &PsService::start_profiler;
_service_handler_map[PS_STOP_PROFILER] = &PsService::stop_profiler;
  // shard initialization: the shard info of server_list can only be fetched
  // from the environment after the server has started
initialize_shard_info();
return 0;
}
#define CHECK_TABLE_EXIST(table, request, response) \
if (table == NULL) { \
std::string err_msg("table not found with table_id:"); \
err_msg.append(std::to_string(request.table_id())); \
set_response_code(response, -1, err_msg.c_str()); \
return -1; \
}
int32_t PsService::initialize_shard_info() {
if (!_is_initialize_shard_info) {
std::lock_guard<std::mutex> guard(_initialize_shard_mutex);
if (_is_initialize_shard_info) {
return 0;
}
size_t shard_num = _server->environment()->get_ps_servers().size();
auto &table_map = *(_server->table());
    for (auto &itr : table_map) {
itr.second->set_shard(_rank, shard_num);
}
_is_initialize_shard_info = true;
}
return 0;
}
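// service() is the single brpc entry point: it validates the request, looks up
// the handler registered for request.cmd_id() in _service_handler_map and
// dispatches to it; unknown cmd_ids and handler failures are reported through
// the response error code.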
void PsService::service(google::protobuf::RpcController *cntl_base,
const PsRequestMessage *request,
PsResponseMessage *response,
google::protobuf::Closure *done) {
brpc::ClosureGuard done_guard(done);
std::string log_label("ReceiveCmd-");
if (!request->has_table_id()) {
set_response_code(*response, -1, "PsRequestMessage.tabel_id is required");
return;
}
response->set_err_code(0);
response->set_err_msg("");
auto *table = _server->table(request->table_id());
brpc::Controller *cntl = static_cast<brpc::Controller *>(cntl_base);
auto itr = _service_handler_map.find(request->cmd_id());
if (itr == _service_handler_map.end()) {
std::string err_msg(
"undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:");
err_msg.append(std::to_string(request->cmd_id()));
set_response_code(*response, -1, err_msg.c_str());
return;
}
serviceHandlerFunc handler_func = itr->second;
int service_ret = (this->*handler_func)(table, *request, *response, cntl);
if (service_ret != 0) {
response->set_err_code(service_ret);
response->set_err_msg("server internal error");
}
}
int32_t PsService::pull_dense(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::RecordEvent record_event("PsService->pull_dense");
CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 1) {
    set_response_code(
        response, -1,
        "PsRequestMessage.params is required, at least 1 for num of dense");
    return 0;
  }
  uint32_t num = *(const uint32_t *)request.params(0).c_str();
std::vector<float> res_data;
res_data.resize(num * table->value_accesor()->select_size() / sizeof(float));
table->pull_dense(res_data.data(), num);
cntl->response_attachment().append((char *)res_data.data(),
res_data.size() * sizeof(float));
return 0;
}
int32_t PsService::push_dense_param(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::RecordEvent record_event("PsService->push_dense_param");
CHECK_TABLE_EXIST(table, request, response)
thread_local std::string push_buffer;
auto &req_io_buffer = cntl->request_attachment();
auto req_buffer_size = req_io_buffer.size();
if (req_buffer_size < 1) {
set_response_code(response, -1, "req attachment is empty");
return 0;
}
  push_buffer.resize(req_buffer_size);
const char *data = (const char *)cntl->request_attachment().fetch(
const_cast<char *>(push_buffer.data()), req_buffer_size);
uint32_t num = *(const uint32_t *)data;
const float *values = (const float *)(data + sizeof(uint32_t));
if (table->push_dense_param(values, num) != 0) {
set_response_code(response, -1, "push_dense_param failed");
}
return 0;
}
int32_t PsService::push_dense(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::RecordEvent record_event("PsService->push_dense");
CHECK_TABLE_EXIST(table, request, response)
auto req_buffer_size = request.data().size();
if (req_buffer_size < 1) {
// set_response_code(response, 0, "push dense data is empty");
return 0;
}
/*
Push Content:
|--num--|---valuesData---|
|--4B---|----------------|
*/
uint32_t num = *(const uint32_t *)(request.data().data());
const float *values =
(const float *)(request.data().data() + sizeof(uint32_t));
if (table->push_dense(values, num) != 0) {
set_response_code(response, -1, "push_dense failed");
}
return 0;
}
int32_t PsService::barrier(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 1) {
    set_response_code(response, -1,
                      "PsRequestMessage.params is required, at "
                      "least 1 for barrier_type");
return 0;
}
auto trainer_id = request.client_id();
auto barrier_type = request.params(0);
table->barrier(trainer_id, barrier_type);
return 0;
}
int32_t PsService::push_sparse_param(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::RecordEvent record_event("PsService->push_sparse_param");
CHECK_TABLE_EXIST(table, request, response)
auto &push_data = request.data();
if (push_data.size() < 1) {
// set_response_code(response, 0, "push sparse data is empty");
return 0;
}
if (request.params_size() < 1) {
    set_response_code(response, -1,
                      "PsRequestMessage.params is required, at "
                      "least 1 for num of sparse_key");
return 0;
}
uint32_t num = *(uint32_t *)(request.params(0).c_str());
/*
Push Content:
|---keysData---|---valuesData---|
|---8*{num}B---|----------------|
*/
const uint64_t *keys = (const uint64_t *)push_data.data();
const float *values =
(const float *)(push_data.data() + sizeof(uint64_t) * num);
if (table->push_sparse_param(keys, values, num) != 0) {
set_response_code(response, -1, "push_sparse_param error");
}
return 0;
}
int32_t PsService::pull_geo_param(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::RecordEvent record_event("PsService->pull_geo_param");
CHECK_TABLE_EXIST(table, request, response)
auto trainer_id = request.client_id();
std::vector<float> values;
std::vector<uint64_t> ids;
table->pull_geo_param(trainer_id, &values, &ids);
uint32_t num = ids.size();
cntl->response_attachment().append((char *)(&num), sizeof(uint32_t));
cntl->response_attachment().append((char *)ids.data(),
ids.size() * sizeof(uint64_t));
cntl->response_attachment().append((char *)values.data(),
values.size() * sizeof(float));
return 0;
}
int32_t PsService::pull_sparse(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::RecordEvent record_event("PsService->pull_sparse");
CHECK_TABLE_EXIST(table, request, response)
thread_local std::string push_sparse_request_buffer;
auto &req_io_buffer = cntl->request_attachment();
auto req_buffer_size = req_io_buffer.size();
if (req_buffer_size < 1) {
set_response_code(response, -1, "req attachment is empty");
return 0;
}
if (request.params_size() < 1) {
    set_response_code(response, -1,
                      "PsRequestMessage.params is required, at "
                      "least 1 for num of sparse_key");
return 0;
}
uint32_t num = *(uint32_t *)(request.params(0).c_str());
  push_sparse_request_buffer.resize(req_buffer_size);
const char *data = (const char *)cntl->request_attachment().fetch(
const_cast<char *>(push_sparse_request_buffer.data()), req_buffer_size);
/*
Attachment Content:
|---keysData---|
|---8*{num}B---|
*/
const uint64_t *keys = (const uint64_t *)data;
std::vector<float> res_data;
res_data.resize(num * table->value_accesor()->select_size() / sizeof(float));
table->pull_sparse(res_data.data(), keys, num);
cntl->response_attachment().append((char *)res_data.data(),
res_data.size() * sizeof(float));
return 0;
}
int32_t PsService::push_sparse(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::RecordEvent record_event("PsService->push_sparse");
CHECK_TABLE_EXIST(table, request, response)
auto &push_data = request.data();
if (push_data.size() < 1) {
// set_response_code(response, 0, "push sparse data is empty");
return 0;
}
if (request.params_size() < 1) {
    set_response_code(response, -1,
                      "PsRequestMessage.params is required, at "
                      "least 1 for num of sparse_key");
return 0;
}
uint32_t num = *(uint32_t *)(request.params(0).c_str());
/*
Push Content:
|---keysData---|---valuesData---|
|---8*{num}B---|----------------|
*/
const uint64_t *keys = (const uint64_t *)push_data.data();
const float *values =
(const float *)(push_data.data() + sizeof(uint64_t) * num);
if (table->push_sparse(keys, values, num) != 0) {
set_response_code(response, -1, "push_sparse error");
}
return 0;
}
int32_t PsService::print_table_stat(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
std::pair<int64_t, int64_t> ret = table->print_table_stat();
paddle::framework::BinaryArchive ar;
ar << ret.first << ret.second;
std::string table_info(ar.Buffer(), ar.Length());
response.set_data(table_info);
return 0;
}
int32_t PsService::load_one_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 2) {
    set_response_code(
        response, -1,
        "PsRequestMessage.params is required, at least 2 for path & load_param");
return -1;
}
if (table->load(request.params(0), request.params(1)) != 0) {
set_response_code(response, -1, "table load failed");
return -1;
}
return 0;
}
int32_t PsService::load_all_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
auto &table_map = *(_server->table());
for (auto &itr : table_map) {
if (load_one_table(itr.second.get(), request, response, cntl) != 0) {
LOG(ERROR) << "load table[" << itr.first << "] failed";
return -1;
}
}
return 0;
}
int32_t PsService::save_one_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 2) {
    set_response_code(
        response, -1,
        "PsRequestMessage.params is required, at least 2 for path & mode");
return -1;
}
table->flush();
int32_t feasign_size = 0;
feasign_size = table->save(request.params(0), request.params(1));
if (feasign_size < 0) {
set_response_code(response, -1, "table save failed");
return -1;
}
return feasign_size;
}
int32_t PsService::save_all_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
auto &table_map = *(_server->table());
  int32_t feasign_size = 0;
for (auto &itr : table_map) {
feasign_size = save_one_table(itr.second.get(), request, response, cntl);
if (feasign_size < 0) {
LOG(ERROR) << "save table[" << itr.first << "] failed";
return -1;
}
}
return 0;
}
int32_t PsService::shrink_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
table->flush();
if (table->shrink() != 0) {
set_response_code(response, -1, "table shrink failed");
}
return 0;
}
int32_t PsService::clear_one_table(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
table->flush();
table->clear();
return 0;
}
int32_t PsService::clear_all_table(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
auto &table_map = *(_server->table());
for (auto &itr : table_map) {
if (clear_one_table(itr.second.get(), request, response, cntl) != 0) {
return -1;
}
}
return 0;
}
int32_t PsService::stop_server(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
auto *p_server = _server;
std::thread t_stop([p_server]() {
p_server->stop();
LOG(INFO) << "Server Stoped";
});
t_stop.detach();
return 0;
}
int32_t PsService::stop_profiler(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::DisableProfiler(platform::EventSortingKey::kDefault,
string::Sprintf("server_%s_profile", _rank));
return 0;
}
int32_t PsService::start_profiler(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::EnableProfiler(platform::ProfilerState::kCPU);
return 0;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include <memory>
#include <vector>
#include "paddle/fluid/distributed/service/server.h"
namespace paddle {
namespace distributed {
class BrpcPsServer : public PSServer {
public:
BrpcPsServer() {}
virtual ~BrpcPsServer() {}
virtual uint64_t start(const std::string &ip, uint32_t port);
virtual int32_t stop() {
std::unique_lock<std::mutex> lock(mutex_);
stoped_ = true;
cv_.notify_all();
_server.Stop(1000);
_server.Join();
return 0;
}
virtual int32_t port();
private:
virtual int32_t initialize();
mutable std::mutex mutex_;
std::condition_variable cv_;
bool stoped_ = false;
brpc::Server _server;
std::shared_ptr<PsBaseService> _service;
std::vector<std::shared_ptr<brpc::Channel>> _pserver_channels;
};
class PsService;
typedef int32_t (PsService::*serviceHandlerFunc)(
Table *table, const PsRequestMessage &request, PsResponseMessage &response,
brpc::Controller *cntl);
class PsService : public PsBaseService {
public:
virtual int32_t initialize() override;
virtual void service(::google::protobuf::RpcController *controller,
const ::paddle::PsRequestMessage *request,
::paddle::PsResponseMessage *response,
::google::protobuf::Closure *done) override;
private:
int32_t initialize_shard_info();
int32_t pull_dense(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t push_dense(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t push_dense_param(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t push_sparse_param(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t pull_sparse(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t pull_geo_param(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t barrier(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t push_sparse(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t load_one_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t load_all_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t save_one_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t save_all_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t shrink_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t clear_one_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t clear_all_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t stop_server(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t start_profiler(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t stop_profiler(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t print_table_stat(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
bool _is_initialize_shard_info;
std::mutex _initialize_shard_mutex;
std::unordered_map<int32_t, serviceHandlerFunc> _service_handler_map;
std::unordered_map<int32_t, serviceHandlerFunc> _msg_handler_map;
std::vector<float> _ori_values;
};
class DownpourPServerBrpcClosure : public PServerClosure {
public:
DownpourPServerBrpcClosure(size_t num, PServerCallBack callback)
: PServerClosure(callback) {
_waiting_num = num;
_cntls.resize(num);
_requests.resize(num);
_responses.resize(num);
for (size_t i = 0; i < num; ++i) {
_cntls[i].reset(new brpc::Controller());
}
}
virtual ~DownpourPServerBrpcClosure() {}
virtual void Run() override {
if (_waiting_num.fetch_sub(1) == 1) {
_callback(this);
delete this;
}
}
PsRequestMessage *request(size_t i) { return &_requests[i]; }
PsResponseMessage *response(size_t i) { return &_responses[i]; }
brpc::Controller *cntl(size_t i) { return _cntls[i].get(); }
int check_response(size_t request_idx, int cmd_id) { return 1; }
int check_save_response(size_t request_idx, int cmd_id) { return 1; }
private:
std::atomic<int32_t> _waiting_num;
std::vector<PsRequestMessage> _requests;
std::vector<PsResponseMessage> _responses;
std::vector<std::shared_ptr<brpc::Controller>> _cntls;
};
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/service/brpc_utils.h"
#include <limits>
#include <memory>
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace framework {
class Scope;
class Variable;
} // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace distributed {
framework::proto::VarType::Type VarMessageToVarType(
VariableMessage::Type type) {
switch (type) {
case VariableMessage::FP32:
return framework::proto::VarType::FP32; // NOLINT
case VariableMessage::FP64:
return framework::proto::VarType::FP64; // NOLINT
case VariableMessage::INT32:
return framework::proto::VarType::INT32; // NOLINT
case VariableMessage::INT64:
return framework::proto::VarType::INT64; // NOLINT
case VariableMessage::BOOL:
return framework::proto::VarType::BOOL; // NOLINT
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"VarMessageToVarType:Unsupported type %d", type));
}
}
void SerializeToMultiVarMsgAndIOBuf(
const std::string& message_name,
const std::vector<std::string>& send_var_name_val,
const std::vector<std::string>& recv_var_name_val,
const platform::DeviceContext& ctx, const framework::Scope* scope,
MultiVarMsg* request, butil::IOBuf* iobuf) {
// 1. message_name
request->set_message_name(message_name);
// 2. var_names
for (auto& send_var_name : send_var_name_val) {
request->add_send_var_names(send_var_name);
}
for (auto& recv_var_name : recv_var_name_val) {
request->add_recv_var_names(recv_var_name);
}
// 3. VarMessage
for (auto& send_var_name : send_var_name_val) {
auto* send_var_msg = request->add_var_messages();
butil::IOBuf temp_iobuf;
send_var_msg->set_varname(send_var_name);
framework::Variable* var = scope->FindVar(send_var_name);
if (var->IsType<framework::LoDTensor>()) {
SerializeLodTensor(var, ctx, send_var_msg, &temp_iobuf);
} else if (var->IsType<framework::SelectedRows>()) {
SerializeSelectedRows(var, ctx, send_var_msg, &temp_iobuf);
}
iobuf->append(temp_iobuf);
}
}
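// Tensor payloads are not stored in the protobuf message: each variable's
// metadata (type, lod, dims) goes into its VarMsg, while the raw data is
// appended to the IOBuf as an 8-byte length followed by the bytes themselves.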
void SerializeLodTensor(framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* var_msg,
butil::IOBuf* iobuf) {
auto* tensor = var->GetMutable<framework::LoDTensor>();
var_msg->set_type(::paddle::LOD_TENSOR);
const framework::LoD lod = tensor->lod();
if (lod.size() > 0) {
var_msg->set_lod_level(lod.size());
for (auto& each : lod) {
VarMsg::LodData* lod_inner = var_msg->add_lod();
for (auto& d : each) {
lod_inner->add_lod_data(d);
}
}
}
var_msg->set_data_type(static_cast<VarMsg::Type>(tensor->type()));
for (auto& dim : framework::vectorize(tensor->dims())) {
var_msg->add_dims(dim);
}
// IO Buffer
if (platform::is_cpu_place(tensor->place())) {
auto data_len = tensor->numel() * framework::SizeOfType(tensor->type());
iobuf->append(reinterpret_cast<const char*>(&data_len), 8);
iobuf->append(reinterpret_cast<const char*>(tensor->data<void>()),
data_len);
} else {
#ifdef PADDLE_WITH_CUDA
char* temp_ptr =
new char[tensor->numel() * framework::SizeOfType(tensor->type())];
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
memory::Copy(platform::CPUPlace(), temp_ptr,
BOOST_GET_CONST(platform::CUDAPlace, tensor->place()),
tensor->data<void>(),
tensor->numel() * framework::SizeOfType(tensor->type()),
stream);
auto data_len = tensor->numel() * framework::SizeOfType(tensor->type());
iobuf->append(reinterpret_cast<const char*>(&data_len), 8);
iobuf->append(reinterpret_cast<const char*>(temp_ptr), data_len);
delete[] temp_ptr;
#endif
}
}
void SerializeSelectedRows(framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* var_msg,
butil::IOBuf* iobuf) {
framework::SelectedRows* slr = var->GetMutable<framework::SelectedRows>();
auto* tensor = slr->mutable_value();
auto* rows = slr->mutable_rows();
var_msg->set_type(::paddle::SELECTED_ROWS);
var_msg->set_slr_height(slr->height());
auto* var_data = var_msg->mutable_data();
var_data->clear();
var_data->resize(rows->size() * sizeof(int64_t));
char* data_ptr = const_cast<char*>(var_data->data());
if (platform::is_cpu_place(tensor->place())) {
memcpy(data_ptr, &(*rows)[0], rows->size() * sizeof(int64_t));
} else {
#ifdef PADDLE_WITH_CUDA
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
memory::Copy(platform::CPUPlace(), data_ptr,
BOOST_GET_CONST(platform::CUDAPlace, tensor->place()),
&(*rows)[0], rows->size() * sizeof(int64_t), stream);
#endif
}
var_msg->set_data_type(static_cast<VarMsg::Type>(tensor->type()));
for (auto& dim : framework::vectorize(tensor->dims())) {
var_msg->add_dims(dim);
}
// IO Buffer
if (platform::is_cpu_place(tensor->place())) {
auto data_len = tensor->numel() * framework::SizeOfType(tensor->type());
iobuf->append(reinterpret_cast<const char*>(&data_len), 8);
iobuf->append(reinterpret_cast<const char*>(tensor->data<void>()),
data_len);
} else {
#ifdef PADDLE_WITH_CUDA
char* temp_ptr =
new char[tensor->numel() * framework::SizeOfType(tensor->type())];
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
memory::Copy(platform::CPUPlace(), temp_ptr,
BOOST_GET_CONST(platform::CUDAPlace, tensor->place()),
tensor->data<void>(),
tensor->numel() * framework::SizeOfType(tensor->type()),
stream);
auto data_len = tensor->numel() * framework::SizeOfType(tensor->type());
iobuf->append(reinterpret_cast<const char*>(&data_len), 8);
iobuf->append(reinterpret_cast<const char*>(temp_ptr), data_len);
delete[] temp_ptr;
#endif
}
}
void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
const butil::IOBuf* iobuf,
const platform::DeviceContext& ctx,
framework::Scope* scope) {
butil::IOBufBytesIterator io_buffer_itr(*iobuf);
// size_t shard_buffer_remain = res_io_buffer.size();
for (int recv_var_index = 0; recv_var_index < multi_msg.send_var_names_size();
++recv_var_index) {
const auto& msg = multi_msg.var_messages(recv_var_index);
auto* var = scope->Var(msg.varname());
if (msg.type() == ::paddle::LOD_TENSOR) {
DeserializeLodTensor(var, msg, io_buffer_itr, ctx);
} else if (msg.type() == ::paddle::SELECTED_ROWS) {
DeserializeSelectedRows(var, msg, io_buffer_itr, ctx);
}
}
}
void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
const butil::IOBuf* iobuf,
const platform::DeviceContext& ctx,
const framework::Scope* scope) {
butil::IOBufBytesIterator io_buffer_itr(*iobuf);
// size_t shard_buffer_remain = res_io_buffer.size();
for (int recv_var_index = 0; recv_var_index < multi_msg.send_var_names_size();
++recv_var_index) {
const auto& msg = multi_msg.var_messages(recv_var_index);
auto* var = scope->FindVar(msg.varname());
PADDLE_ENFORCE_NE(var, nullptr,
platform::errors::InvalidArgument(
"Not find variable %s in scope.", msg.varname()));
if (msg.type() == ::paddle::LOD_TENSOR) {
DeserializeLodTensor(var, msg, io_buffer_itr, ctx);
} else if (msg.type() == ::paddle::SELECTED_ROWS) {
DeserializeSelectedRows(var, msg, io_buffer_itr, ctx);
}
}
}
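// The deserializers mirror the layout above: read the 8-byte length prefix,
// then copy that many bytes into the tensor; for GPU places the bytes are
// staged in a temporary host buffer and copied onto the device stream.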
void DeserializeLodTensor(framework::Variable* var, const VarMsg& msg,
butil::IOBufBytesIterator& io_buffer_itr,
const platform::DeviceContext& ctx) {
const auto place = ctx.GetPlace();
framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>();
std::vector<int> vec_dim;
for (auto& x : msg.dims()) {
vec_dim.push_back(x);
}
tensor->Resize(framework::make_ddim(vec_dim));
framework::LoD lod;
for (int i = 0; i < msg.lod_level(); ++i) {
framework::Vector<size_t> v;
for (int j = 0; j < msg.lod(i).lod_data_size(); ++j) {
v.push_back(msg.lod(i).lod_data(j));
}
lod.push_back(v);
}
tensor->set_lod(lod);
void* tensor_data =
tensor->mutable_data(place, VarMessageToVarType(msg.data_type()));
// IO Buffer
if (platform::is_cpu_place(place)) {
    uint64_t data_len;  // 8-byte length prefix written by the serializer
io_buffer_itr.copy_and_forward((void*)(&data_len), 8);
io_buffer_itr.copy_and_forward(tensor_data, data_len);
} else if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
    uint64_t data_len;
char* temp_ptr =
new char[tensor->numel() * framework::SizeOfType(tensor->type())];
io_buffer_itr.copy_and_forward((void*)(&data_len), 8);
io_buffer_itr.copy_and_forward((void*)temp_ptr, data_len);
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, place), tensor_data,
platform::CPUPlace(), (void*)temp_ptr,
tensor->numel() * framework::SizeOfType(tensor->type()),
stream);
delete[] temp_ptr;
#endif
}
}
void DeserializeSelectedRows(framework::Variable* var, const VarMsg& msg,
butil::IOBufBytesIterator& io_buffer_itr,
const platform::DeviceContext& ctx) {
const auto place = ctx.GetPlace();
auto* slr = var->GetMutable<framework::SelectedRows>();
framework::Tensor* tensor = slr->mutable_value();
slr->set_height(msg.slr_height());
  std::vector<int64_t> tmp_rows(msg.data().size() / sizeof(int64_t));
  memcpy(tmp_rows.data(), msg.data().data(), msg.data().size());
slr->set_rows(tmp_rows);
std::vector<int> vec_dim;
for (auto& x : msg.dims()) {
vec_dim.push_back(x);
}
tensor->Resize(framework::make_ddim(vec_dim));
void* tensor_data =
tensor->mutable_data(place, VarMessageToVarType(msg.data_type()));
// IO Buffer
if (platform::is_cpu_place(place)) {
    uint64_t data_len;  // 8-byte length prefix written by the serializer
io_buffer_itr.copy_and_forward((void*)(&data_len), 8);
io_buffer_itr.copy_and_forward(tensor_data, data_len);
} else if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
char* temp_ptr =
new char[tensor->numel() * framework::SizeOfType(tensor->type())];
    uint64_t data_len;
io_buffer_itr.copy_and_forward((void*)(&data_len), 8);
io_buffer_itr.copy_and_forward(temp_ptr, data_len);
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, place), tensor_data,
platform::CPUPlace(), temp_ptr,
tensor->numel() * framework::SizeOfType(tensor->type()),
stream);
delete[] temp_ptr;
#endif
}
}
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iostream>
#include <string>
#include <vector>
#include "brpc/channel.h"
#include "paddle/fluid/distributed/service/sendrecv.pb.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/port.h"
namespace grpc {
class ByteBuffer;
} // namespace grpc
namespace paddle {
namespace framework {
class Scope;
class Variable;
} // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace distributed {
using MultiVarMsg = ::paddle::MultiVariableMessage;
using VarMsg = ::paddle::VariableMessage;
void SerializeToMultiVarMsgAndIOBuf(
const std::string& message_name,
const std::vector<std::string>& send_var_name_val,
const std::vector<std::string>& recv_var_name_val,
const platform::DeviceContext& ctx, const framework::Scope* scope,
MultiVarMsg* var_msg, butil::IOBuf* iobuf);
void SerializeLodTensor(framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* var_msg,
butil::IOBuf* iobuf);
void SerializeSelectedRows(framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* request,
butil::IOBuf* iobuf);
// Deserialize for Server
void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
const butil::IOBuf* iobuf,
const platform::DeviceContext& ctx,
framework::Scope* scope);
// Deserialize for Client
void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
const butil::IOBuf* iobuf,
const platform::DeviceContext& ctx,
const framework::Scope* scope);
void DeserializeLodTensor(framework::Variable* var, const VarMsg& msg,
butil::IOBufBytesIterator& iobuf,
const platform::DeviceContext& ctx);
void DeserializeSelectedRows(framework::Variable* var, const VarMsg& msg,
butil::IOBufBytesIterator& iobuf,
const platform::DeviceContext& ctx);
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/service/communicator.h"
#include <google/protobuf/text_format.h>
#include "paddle/fluid/distributed/table/table.h"
#include <gflags/gflags.h>
#include <paddle/fluid/framework/program_desc.h>
#include <algorithm>
#include <chrono> // NOLINT
#include <map>
#include <thread> // NOLINT
#include <unordered_set>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/split.h"
namespace paddle {
namespace distributed {
using framework::LoDTensor;
using framework::SelectedRows;
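// Returns the current wall-clock time in microseconds; used for coarse timing
// of send/recv in the communicators below.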
inline double GetCurrentUS() {
struct timeval time;
gettimeofday(&time, NULL);
return 1e+6 * time.tv_sec + time.tv_usec;
}
Communicator::Communicator() {}
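// Parse brpc/gflags options from a space-separated string. When no flags are
// given, a small set of default brpc flags is applied; a dummy program name is
// prepended because ParseCommandLineFlags expects an argv[0].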
void Communicator::init_gflag(const std::string &gflags) {
VLOG(0) << "Init With Gflags:" << gflags;
std::vector<std::string> flags = paddle::string::split_string(gflags);
if (flags.size() < 1) {
flags.push_back("-max_body_size=314217728");
flags.push_back("-bthread_concurrency=40");
flags.push_back("-socket_max_unwritten_bytes=2048000000");
flags.push_back("-max_connection_pool_size=1950");
}
auto it = flags.begin();
flags.insert(it, "exe default");
char *flags_ptr[flags.size()];
for (size_t i = 0; i < flags.size(); ++i) {
flags_ptr[i] = (char *)(flags[i].c_str());
}
int params_cnt = flags.size();
char **params_ptr = &(flags_ptr[0]);
::google::ParseCommandLineFlags(&params_cnt, &params_ptr, true);
}
std::once_flag Communicator::init_flag_;
std::shared_ptr<Communicator> Communicator::communicator_(nullptr);
void Communicator::InitBrpcClient(
const std::string &dist_desc,
const std::vector<std::string> &host_sign_list) {
// not used, just for psclient's init
std::map<uint64_t, std::vector<paddle::distributed::Region>>
_dense_pull_regions;
for (auto &iter : recv_varname_to_ctx_) {
auto tid = iter.first;
auto var_names = iter.second;
auto &regions = _dense_pull_regions[tid];
regions.reserve(var_names.size());
for (auto &t : var_names) {
Variable *var = recv_scope_->FindVar(t);
LoDTensor *tensor = var->GetMutable<LoDTensor>();
float *w = tensor->data<float>();
paddle::distributed::Region reg(w, tensor->numel());
regions.emplace_back(std::move(reg));
}
}
if (_worker_ptr.get() == nullptr) {
google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param);
init_gflag(_ps_param.init_gflags());
servers_ = host_sign_list.size();
_ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(&host_sign_list, servers_);
_worker_ptr = std::shared_ptr<paddle::distributed::PSClient>(
paddle::distributed::PSClientFactory::create(_ps_param));
_worker_ptr->configure(_ps_param, _dense_pull_regions, _ps_env,
trainer_id_);
}
return;
}
void Communicator::RpcRecvDense(const std::vector<std::string> &varnames,
int table_id, Scope *scope) {
platform::RecordEvent record_event("Communicator->RpcRecvDense");
std::vector<paddle::distributed::Region> regions;
regions.reserve(varnames.size());
for (auto &t : varnames) {
Variable *var = scope->Var(t);
LoDTensor *tensor = var->GetMutable<LoDTensor>();
if (platform::is_gpu_place(tensor->place())) {
#ifdef PADDLE_WITH_CUDA
Variable *temp_var = xpu_temp_scope_->Var(t);
LoDTensor *temp_tensor = temp_var->GetMutable<LoDTensor>();
temp_tensor->Resize(tensor->dims());
float *temp_data = temp_tensor->mutable_data<float>(platform::CPUPlace());
paddle::distributed::Region reg(temp_data, tensor->numel());
regions.emplace_back(std::move(reg));
VLOG(1) << "AsyncCommunicator::RpcRecvDense Var " << t << " table_id "
<< table_id << " Temp_data[0] " << temp_data[0]
<< " Temp_data[-1] " << temp_data[tensor->numel() - 1];
#endif
} else {
float *w = tensor->mutable_data<float>(tensor->place());
paddle::distributed::Region reg(w, tensor->numel());
regions.emplace_back(std::move(reg));
}
}
auto status =
_worker_ptr->pull_dense(regions.data(), regions.size(), table_id);
status.wait();
for (auto &t : varnames) {
Variable *var = scope->FindVar(t);
LoDTensor *tensor = var->GetMutable<LoDTensor>();
VLOG(1) << "AsyncCommunicator::RecvNoBarrier Var " << t << " On gpu? "
<< platform::is_gpu_place(tensor->place());
if (platform::is_gpu_place(tensor->place())) {
#ifdef PADDLE_WITH_CUDA
LoDTensor *temp_tensor =
xpu_temp_scope_->FindVar(t)->GetMutable<LoDTensor>();
framework::TensorCopy(*temp_tensor, tensor->place(), tensor);
float *temp_data = temp_tensor->mutable_data<float>(platform::CPUPlace());
VLOG(1) << "AsyncCommunicator::RpcRecvDense Var " << t << " table_id "
<< table_id << " Temp_data[0] " << temp_data[0]
<< " Temp_data[-1] " << temp_data[tensor->numel() - 1];
#endif
}
}
return;
}
void Communicator::RpcSendDenseParam(const std::vector<std::string> &varnames,
int table_id, const Scope &scope) {
platform::RecordEvent record_event("Communicator->RpcSendDenseParam");
auto place = platform::CPUPlace();
std::vector<paddle::distributed::Region> regions;
for (auto &t : varnames) {
Variable *var = scope.FindVar(t);
CHECK(var != nullptr) << "var[" << t << "] not found";
LoDTensor *tensor = var->GetMutable<LoDTensor>();
if (platform::is_gpu_place(tensor->place())) {
#ifdef PADDLE_WITH_CUDA
Variable *temp_var = xpu_temp_scope_->Var(t);
LoDTensor *temp_tensor = temp_var->GetMutable<LoDTensor>();
temp_tensor->Resize(tensor->dims());
float *temp_data = temp_tensor->mutable_data<float>(platform::CPUPlace());
framework::TensorCopy(*tensor, platform::CPUPlace(), temp_tensor);
paddle::distributed::Region reg(temp_data, tensor->numel());
regions.emplace_back(std::move(reg));
VLOG(1) << "AsyncCommunicator::RpcSendDenseParam Var " << t
<< " table_id " << table_id << " Temp_data[0] " << temp_data[0]
<< " Temp_data[-1] " << temp_data[tensor->numel() - 1];
#endif
} else {
float *w = tensor->mutable_data<float>(place);
paddle::distributed::Region reg(w, tensor->numel());
regions.emplace_back(std::move(reg));
VLOG(1) << "AsyncCommunicator::RpcSendDenseParam Var " << t
<< " talbe_id " << table_id << " Temp_data[0] " << w[0]
<< " Temp_data[-1] " << w[tensor->numel() - 1];
}
}
auto status =
_worker_ptr->push_dense_param(regions.data(), regions.size(), table_id);
status.wait();
VLOG(4) << "RPC Send Dense Param " << table_id << " done!";
return;
}
void Communicator::RpcSendDense(const CommContext &ctx, const Scope &scope) {
platform::RecordEvent record_event("Communicator->RpcSendDense");
auto &var_names = ctx.origin_varnames;
auto &table_id = ctx.table_id;
auto dense_data = std::make_shared<std::vector<float>>();
size_t request_call_num = _worker_ptr->get_server_nums();
uint32_t num_per_shard =
dense_dim_per_shard(ctx.height_sections[0], request_call_num);
dense_data->resize(num_per_shard *
request_call_num); // accessor->update_dim() = 1
float *data = dense_data->data();
uint32_t pos = 0;
for (size_t i = 0; i < var_names.size(); ++i) {
const LoDTensor tensor = scope.FindVar(var_names[i])->Get<LoDTensor>();
size_t count = static_cast<size_t>(tensor.numel());
const float *g = tensor.data<float>();
CHECK(pos + count <= dense_data->size())
<< "invalid dense size, cur pos[" << pos << "]"
<< " data_num[" << count << "] size[" << dense_data->size() << "]";
memcpy(data + pos, g, count * sizeof(float));
pos += count;
}
++_async_call_num;
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [this, request_call_num](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
for (size_t i = 0; i < request_call_num; ++i) {
if (closure->check_response(i, PS_PUSH_DENSE_TABLE) != 0) {
ret = -1;
break;
}
}
closure->set_promise_value(ret);
--_async_call_num;
});
auto status = _worker_ptr->push_dense_raw_gradient(
table_id, data, dense_data->size(), closure);
status.wait();
return;
}
void Communicator::RpcSendSparseParam(const std::string &varname, int table_id,
const Scope &scope) {
platform::RecordEvent record_event("Communicator->RpcSendSparseParam");
size_t request_call_num = _worker_ptr->get_server_nums();
std::vector<float *> push_g_vec;
auto *send_var = scope.FindVar(varname);
auto *tensor = send_var->GetMutable<framework::LoDTensor>();
auto dim = tensor->dims()[1];
uint64_t sparse_num = static_cast<uint64_t>(tensor->dims()[0]);
std::vector<uint64_t> sparse_push_keys(sparse_num);
std::iota(sparse_push_keys.begin(), sparse_push_keys.end(), 0);
push_g_vec.reserve(sparse_num);
for (auto i = 0; i < static_cast<int>(sparse_push_keys.size()); ++i) {
push_g_vec.push_back(tensor->data<float>() + i * dim);
}
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [this, request_call_num](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
for (size_t i = 0; i < request_call_num; ++i) {
if (closure->check_response(i, PS_PUSH_SPARSE_PARAM) != 0) {
ret = -1;
break;
}
}
closure->set_promise_value(ret);
});
auto status = _worker_ptr->push_sparse_param(
table_id, sparse_push_keys.data(), (const float **)push_g_vec.data(),
sparse_push_keys.size(), closure);
status.wait();
return;
}
void Communicator::RpcSendSparse(const std::string &var_name, int table_id,
const Scope &scope) {
platform::RecordEvent record_event("Communicator->RpcSendSparse");
size_t request_call_num = _worker_ptr->get_server_nums();
std::vector<uint64_t> sparse_push_keys;
std::vector<float *> push_g_vec;
auto *send_var = scope.FindVar(var_name);
auto *tensor = send_var->GetMutable<SelectedRows>();
auto dim = tensor->value().dims()[1];
std::transform(tensor->rows().begin(), tensor->rows().end(),
std::back_inserter(sparse_push_keys),
[&](int id) { return static_cast<uint64_t>(id); });
for (auto i = 0; i < static_cast<int>(sparse_push_keys.size()); ++i) {
push_g_vec.push_back(tensor->mutable_value()->data<float>() + i * dim);
}
++_async_call_num;
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [this, request_call_num](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
for (size_t i = 0; i < request_call_num; ++i) {
if (closure->check_response(i, PS_PUSH_SPARSE_TABLE) != 0) {
ret = -1;
break;
}
}
closure->set_promise_value(ret);
--_async_call_num;
});
auto status = _worker_ptr->push_sparse_raw_gradient(
table_id, sparse_push_keys.data(), (const float **)push_g_vec.data(),
sparse_push_keys.size(), closure);
status.wait();
return;
}
void Communicator::RpcRecvSparse(const std::string &varname, int table_id,
Scope *scope) {
platform::RecordEvent record_event("Communicator->RpcRecvSparse");
auto *send_var = scope->Var(varname);
auto *tensor = send_var->GetMutable<framework::LoDTensor>();
auto dim = tensor->dims()[1];
uint64_t sparse_num = static_cast<uint64_t>(tensor->dims()[0]);
std::vector<uint64_t> sparse_push_keys(sparse_num);
std::iota(sparse_push_keys.begin(), sparse_push_keys.end(), 0);
std::vector<float *> push_g_vec;
for (auto i = 0; i < static_cast<int>(sparse_push_keys.size()); ++i) {
push_g_vec.push_back(tensor->data<float>() + i * dim);
}
auto status = _worker_ptr->pull_sparse((float **)push_g_vec.data(), table_id,
sparse_push_keys.data(),
sparse_push_keys.size());
status.wait();
return;
}
void Communicator::InitParams(const RecvCtxMap &recv_varname_to_ctx) {
if (trainer_id_ == 0) {
for (auto &iter : recv_varname_to_ctx) {
auto &table_id = iter.first;
auto &varnames = iter.second;
RpcSendDenseParam(varnames, table_id, *recv_scope_);
VLOG(1) << "push dense param to table " << table_id
<< " from 0' trainer done";
}
BarrierWithTable(1);
} else {
BarrierWithTable(1);
for (auto &iter : recv_varname_to_ctx) {
auto &table_id = iter.first;
auto &varnames = iter.second;
RpcRecvDense(varnames, table_id, recv_scope_);
VLOG(1) << "pull dense param to table " << table_id
<< " from 0' trainer done";
}
}
BarrierWithTable(1);
return;
}
void Communicator::RpcProfilerControl() {
if (trainer_id_ == 0) {
if (!do_server_profiler_ && platform::IsProfileEnabled()) {
// send profiler start flag
do_server_profiler_ = true;
auto start_status = _worker_ptr->start_profiler();
start_status.wait();
} else if (do_server_profiler_ && !platform::IsProfileEnabled()) {
// send profiler end flag
auto stop_status = _worker_ptr->stop_profiler();
stop_status.wait();
do_server_profiler_ = false;
}
}
}
void AsyncCommunicator::RecvThread() {
if (!independent_recv_) return;
VLOG(3) << "Independent RecvThread Start and Wait";
while (running_) {
int grad_num = grad_num_.load();
if (grad_num > min_send_grad_num_before_recv_) {
RecvByCommunicator();
grad_num_.store(0);
} else {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
}
VLOG(1) << "communicator stopped, independent recv thread exit";
}
void AsyncCommunicator::RecvByCommunicator() {
if (!running_) return;
RecvNoBarrier();
VLOG(3) << "run recv graph end";
}
void AsyncCommunicator::RecvNoBarrier() {
for (auto &iter : recv_varname_to_ctx_) {
auto &table_id = iter.first;
auto &varnames = iter.second;
RpcRecvDense(varnames, table_id, recv_scope_);
}
for (auto &iter : recv_varname_to_ctx_) {
auto var_names = iter.second;
for (auto &t : var_names) {
Variable *var = recv_scope_->FindVar(t);
LoDTensor *tensor = var->GetMutable<LoDTensor>();
VLOG(1) << "AsyncCommunicator::RecvNoBarrier Var " << t << " On gpu? "
<< platform::is_gpu_place(tensor->place());
if (platform::is_gpu_place(tensor->place())) {
#ifdef PADDLE_WITH_CUDA
LoDTensor *temp_tensor =
xpu_temp_scope_->FindVar(t)->GetMutable<LoDTensor>();
framework::TensorCopy(*temp_tensor, tensor->place(), tensor);
#endif
}
}
}
return;
}
void AsyncCommunicator::SendByCommunicator() {
std::vector<std::future<void>> tasks;
tasks.reserve(send_varname_to_ctx_.size());
for (auto &iter : send_varname_to_ctx_) {
auto &ctx = iter.second;
auto send_recv_task = [this, &ctx] {
auto &varnames = ctx.origin_varnames;
auto &table_id = ctx.table_id;
size_t var_nums = varnames.size();
auto &check_queue = send_varname_to_queue_[varnames[0]];
std::vector<std::vector<std::shared_ptr<Variable>>> vars;
vars.resize(var_nums);
int merged_var_num = 0;
int wait_times = 0;
while (merged_var_num < max_merge_var_num_) {
if (check_queue->Size() == 0) {
VLOG(4) << "wait_times -> " << wait_times;
if (wait_times >= send_wait_times_) {
break;
}
std::this_thread::sleep_for(std::chrono::milliseconds(10));
wait_times++;
continue;
} else {
wait_times = 0;
for (size_t i = 0; i < var_nums; i++) {
auto &var_name = varnames[i];
auto &var_queue = send_varname_to_queue_[var_name];
vars[i].push_back(var_queue->Pop());
}
merged_var_num++;
}
}
if (merged_var_num == 0) return;
for (size_t i = 0; i < var_nums; i++) {
auto &var_name = varnames[i];
MergeVars<float>(var_name, vars[i], send_scope_.get(), 1);
}
if (ctx.is_sparse) {
PADDLE_ENFORCE_EQ(
varnames.size(), 1,
platform::errors::InvalidArgument(
"sparse variables can only be merged by one variables"));
RpcSendSparse(varnames[0], table_id, *send_scope_);
} else {
RpcSendDense(ctx, *send_scope_);
if (!independent_recv_ &&
recv_varname_to_ctx_.find(table_id) != recv_varname_to_ctx_.end()) {
auto recv_varnames = recv_varname_to_ctx_.at(table_id);
RpcRecvDense(recv_varnames, table_id, recv_scope_);
}
}
if (independent_recv_) {
grad_num_.fetch_add(1, std::memory_order_relaxed);
}
};
tasks.emplace_back(send_threadpool_->enqueue(std::move(send_recv_task)));
}
for (auto &task : tasks) {
task.wait();
}
return;
}
void AsyncCommunicator::MainThread() {
VLOG(3) << "AsyncCommunicator MainThread start and wait";
while (waiting_ && running_) {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
VLOG(3) << "wait for running";
}
while (running_) {
SendByCommunicator();
RpcProfilerControl();
}
VLOG(1) << "communicator stopped, send thread exit";
}
void HalfAsyncCommunicator::MainThread() {
VLOG(3) << "HalfAsyncCommunicator MainThread start and wait";
while (waiting_ && running_) {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
VLOG(3) << "wait for running";
}
while (running_) {
SendByCommunicator();
BarrierSend();
RecvByCommunicator();
BarrierRecv();
BarrierWeakUp();
}
VLOG(1) << "communicator stopped, send thread exit";
}
void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
const RecvCtxMap &recv_varname_to_ctx,
Scope *recv_scope) {
send_varname_to_ctx_ = std::move(send_varname_to_ctx);
recv_varname_to_ctx_ = std::move(recv_varname_to_ctx);
recv_scope_ = std::move(recv_scope);
send_scope_.reset(new Scope());
xpu_temp_scope_.reset(new Scope());
for (auto &iter : send_varname_to_ctx_) {
auto &ctx = iter.second;
auto &varnames = ctx.origin_varnames;
for (auto &var_name : varnames) {
send_varname_to_queue_[var_name] =
std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(
send_queue_size_);
}
}
send_threadpool_.reset(new ::ThreadPool(thread_pool_size_));
}
AsyncCommunicator::~AsyncCommunicator() {
running_ = false;
if (main_thread_) main_thread_->join();
if (recv_thread_) recv_thread_->join();
}
void AsyncCommunicator::Start() {
VLOG(1) << "Communicator start";
if (!communicator_) {
VLOG(0) << "Communicator is not inited, do nothing";
} else {
VLOG(1) << "start send thread and recv thread";
waiting_ = true;
running_ = true;
// flushing_ = false;
BarrierTriggerReset(max_merge_var_num_);
// start send and recv thread
main_thread_.reset(
new std::thread(std::bind(&AsyncCommunicator::MainThread, this)));
if (independent_recv_) {
recv_thread_.reset(
new std::thread(std::bind(&AsyncCommunicator::RecvThread, this)));
}
}
}
void AsyncCommunicator::Stop() {
VLOG(1) << "Communicator stop";
running_ = false;
if (!communicator_) {
VLOG(0) << "Communicator is not inited, do nothing";
} else {
if (recv_thread_) {
VLOG(1) << "stop recv thread";
recv_thread_->join();
recv_thread_.reset(nullptr);
}
if (main_thread_) {
VLOG(1) << "stop main thread";
main_thread_->join();
main_thread_.reset(nullptr);
}
}
VLOG(1) << "Communicator stop done";
}
bool AsyncCommunicator::Check(const std::vector<std::string> &var_tables) {
PADDLE_ENFORCE_EQ(
var_tables.size(), 1,
platform::errors::InvalidArgument("var_tables.size() == 1 is permitted"));
auto table_name = var_tables[0];
if (send_varname_to_ctx_.find(table_name) == send_varname_to_ctx_.end())
return false;
return true;
}
bool AsyncCommunicator::Check(const int table_id) {
for (auto &iter : send_varname_to_ctx_) {
auto &ctx = iter.second;
if (ctx.table_id == table_id) return true;
}
return false;
}
void AsyncCommunicator::Send(const std::vector<std::string> &var_names,
const framework::Scope &scope) {
waiting_ = false;
for (size_t i = 0; i < var_names.size(); i++) {
auto *var = scope.FindVar(var_names[i]);
auto tmp_grad_var = std::make_shared<Variable>();
framework::CopyVariable(*var, tmp_grad_var.get());
send_varname_to_queue_[var_names[i]]->Push(tmp_grad_var);
}
}
void HalfAsyncCommunicator::Clean() {
for (auto &iter : send_varname_to_queue_) {
auto &var_name = iter.first;
auto &var_queue = iter.second;
while (var_queue->Size() > 0) {
var_queue->Pop();
}
VLOG(3) << "clean var: " << var_name << " done";
}
}
void HalfAsyncCommunicator::BarrierTriggerDecrement() {
barrier_trigger_--;
VLOG(3) << "BarrierTriggerDecrement decrement barrier trigger to "
<< barrier_trigger_.load();
}
void HalfAsyncCommunicator::BarrierTriggerReset(int initial_val) {
barrier_trigger_.store(initial_val);
VLOG(3) << "BarrierTriggerReset reset barrier trigger to "
<< barrier_trigger_.load();
}
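// Each trainer calls Barrier() once per step: the call bumps barrier_counter_
// and blocks until BarrierWeakUp() resets the counter to zero.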
void HalfAsyncCommunicator::Barrier() {
barrier_counter_++;
if (!running_) {
VLOG(3) << "Communicator is not running, release barrier";
return;
}
{
std::unique_lock<std::mutex> lk(barrier_mutex_);
barrier_cond_.wait(lk, [this] { return (barrier_counter_ == 0); });
}
}
int HalfAsyncCommunicator::BatchesCounter() {
while (running_) {
if (barrier_counter_.load() >= barrier_trigger_.load() &&
barrier_trigger_.load() != 0) {
break;
} else {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
}
return barrier_counter_.load();
}
void HalfAsyncCommunicator::SendByCommunicator() {
int batches = BatchesCounter();
VLOG(1) << "HalfAsyncCommunicator::BatchesCounter = " << batches;
if (batches <= 0) return;
std::vector<std::future<void>> tasks;
tasks.reserve(send_varname_to_ctx_.size());
for (auto &iter : send_varname_to_ctx_) {
auto &ctx = iter.second;
auto send_recv_task = [this, &ctx, batches] {
auto &varnames = ctx.origin_varnames;
auto &table_id = ctx.table_id;
size_t var_nums = varnames.size();
std::vector<std::vector<std::shared_ptr<Variable>>> vars;
vars.resize(var_nums);
for (size_t i = 0; i < var_nums; i++) {
auto &var_name = varnames[i];
auto &var_queue = send_varname_to_queue_[var_name];
for (int j = 0; j < batches; j++) vars[i].push_back(var_queue->Pop());
MergeVars<float>(var_name, vars[i], send_scope_.get(), 1);
}
if (ctx.is_sparse) {
PADDLE_ENFORCE_EQ(
varnames.size(), 1,
platform::errors::InvalidArgument(
"sparse variables can only be merged by one variables"));
RpcSendSparse(varnames[0], table_id, *send_scope_);
} else {
RpcSendDense(ctx, *send_scope_);
}
};
tasks.emplace_back(send_threadpool_->enqueue(std::move(send_recv_task)));
}
for (auto &task : tasks) {
task.wait();
}
return;
}
void HalfAsyncCommunicator::BarrierWeakUp() {
barrier_counter_.store(0);
barrier_cond_.notify_all();
}
void SyncCommunicator::BarrierSend() {
if (!running_) return;
BarrierWithTable(0);
VLOG(4) << "BarrierSend with SyncCommunicator";
}
void SyncCommunicator::BarrierRecv() {
if (!running_) return;
BarrierWithTable(1);
VLOG(4) << "BarrierRecv with SyncCommunicator";
}
void GeoCommunicator::Send(const std::vector<std::string> &var_names,
const framework::Scope &scope) {
waiting_ = false;
auto before_send = GetCurrentUS();
auto table_name = var_names[0];
size_t splited_var_nums =
send_varname_to_ctx_[table_name].splited_varnames.size();
std::unordered_map<std::string, std::unordered_set<int64_t>> ids_table;
for (size_t j = 0; j < splited_var_nums; j++) {
ids_table.insert(std::pair<std::string, std::unordered_set<int64_t>>(
send_varname_to_ctx_[table_name].splited_varnames[j],
std::unordered_set<int64_t>()));
}
auto *var = scope.FindVar(table_name);
PADDLE_ENFORCE_EQ(var->IsType<framework::SelectedRows>(), true,
platform::errors::InvalidArgument(
"Only need to send Sparse Grad in Geo mode."));
auto &rows = var->Get<framework::SelectedRows>().rows();
  // insert ids which have not been recorded yet
for (size_t j = 0; j < rows.size(); j++) {
auto ep_idx = rows[j] % splited_var_nums;
ids_table.at(send_varname_to_ctx_[table_name].splited_varnames[ep_idx])
.insert(rows[j]);
}
for (auto &iter : ids_table) {
auto &key = iter.first;
auto &sparse_ids_set = iter.second;
auto sparse_ids_vec = std::make_shared<std::vector<int64_t>>();
sparse_ids_vec->assign(sparse_ids_set.begin(), sparse_ids_set.end());
sparse_id_queues_.at(key)->Push(sparse_ids_vec);
VLOG(3) << "push " << sparse_ids_vec->size() << " ids to " << key
<< "'s queue";
}
auto after_send = GetCurrentUS();
VLOG(2) << "run send op finish. use time " << (after_send - before_send);
}
void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
const RecvCtxMap &recv_varname_to_ctx,
Scope *recv_scope) {
send_varname_to_ctx_ = std::move(send_varname_to_ctx);
recv_varname_to_ctx_ = std::move(recv_varname_to_ctx);
recv_scope_ = std::move(recv_scope);
PADDLE_ENFORCE_GT(
send_varname_to_ctx.size(), 0,
platform::errors::InvalidArgument("send var contexts can not be zero"));
for (auto &iter : send_varname_to_ctx_) {
auto &ctx = iter.second;
if (!ctx.is_sparse) continue;
auto &varnames = ctx.origin_varnames;
PADDLE_ENFORCE_EQ(
varnames.size(), 1,
platform::errors::InvalidArgument(
"sparse variables can only be merged by one variables"));
for (auto &splited_var : ctx.splited_varnames) {
parallel_task_nums_ += 1;
sparse_id_queues_.insert(
std::pair<std::string, std::shared_ptr<BlockingQueue<
std::shared_ptr<std::vector<int64_t>>>>>(
splited_var,
std::make_shared<
BlockingQueue<std::shared_ptr<std::vector<int64_t>>>>(
send_queue_size_)));
}
}
send_threadpool_.reset(new ::ThreadPool(thread_pool_size_));
delta_scope_.reset(new Scope());
old_scope_.reset(new Scope());
pserver_scope_.reset(new Scope());
}
void GeoCommunicator::InitParams(const RecvCtxMap &recv_varname_to_ctx) {
std::vector<std::future<void>> tasks;
tasks.reserve(recv_varname_to_ctx_.size());
for (auto &iter : recv_varname_to_ctx_) {
auto &table_id = iter.first;
auto &varnames = iter.second;
auto recv_task = [this, &table_id, &varnames] {
InitDense(varnames, table_id);
};
tasks.emplace_back(send_threadpool_->enqueue(std::move(recv_task)));
}
for (auto &task : tasks) {
task.wait();
}
for (auto &iter : send_varname_to_ctx_) {
auto &ctx = iter.second;
if (!ctx.is_sparse) return;
auto &varname = ctx.origin_varnames[0];
auto &table_id = ctx.table_id;
auto param = varname.substr(0, varname.size() - 5);
InitSparse(param, table_id);
}
return;
}
void GeoCommunicator::InitDense(std::vector<std::string> &varnames,
int table_id) {
if (trainer_id_ == 0) {
RpcSendDenseParam(varnames, table_id, *recv_scope_);
BarrierWithTable(1);
VLOG(0) << "push dense param to table " << table_id
<< " from 0' trainer done";
} else {
BarrierWithTable(1);
RpcRecvDense(varnames, table_id, recv_scope_);
VLOG(0) << "push dense param to table " << table_id
<< " from 0' trainer done";
}
// copy to old_scope
for (auto &t : varnames) {
auto *global_var = recv_scope_->FindVar(t);
global_var->GetMutable<framework::LoDTensor>();
auto *old_var = old_scope_->Var(t);
old_var->GetMutable<framework::LoDTensor>();
framework::CopyVariable(*global_var, old_var);
}
VLOG(1) << "init dense table " << table_id << " done";
}
void GeoCommunicator::SendDense(const CommContext &send_ctx) {
platform::RecordEvent record_event("GeoCommunicator->SendDense");
auto &var_names = send_ctx.origin_varnames;
auto &table_id = send_ctx.table_id;
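  // Geo dense update: delta = (latest - old) / trainers is pushed to the
  // pserver, and the local "old" copy is advanced by the same delta.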
for (auto &varname : var_names) {
auto param_name = GradToParam(varname);
auto *var_latest = recv_scope_->FindVar(param_name);
auto *var_timestamp = old_scope_->FindVar(param_name);
PADDLE_ENFORCE_EQ(var_latest->IsInitialized(), true,
platform::errors::Unavailable(
"%s is not initialized, please check", param_name));
PADDLE_ENFORCE_EQ(var_timestamp->IsInitialized(), true,
platform::errors::Unavailable(
"%s is not initialized, please check", param_name));
auto &t_latest = var_latest->Get<framework::LoDTensor>();
auto t_timestamp = var_timestamp->GetMutable<framework::LoDTensor>();
auto cpu_ctx = paddle::platform::CPUDeviceContext();
auto *var_delta = delta_scope_->Var(varname);
auto *t_delta = var_delta->GetMutable<framework::LoDTensor>();
t_delta->mutable_data<float>(t_latest.dims(), cpu_ctx.GetPlace());
auto blas =
paddle::operators::math::GetBlas<platform::CPUDeviceContext, float>(
cpu_ctx);
blas.VSUB(t_latest.numel(), t_latest.data<float>(),
t_timestamp->data<float>(), t_delta->data<float>());
float coefficient = 1.0 / static_cast<float>(trainers_);
blas.SCAL(t_latest.numel(), coefficient, t_delta->data<float>());
blas.VADD(t_latest.numel(), t_timestamp->data<float>(),
t_delta->data<float>(), t_timestamp->data<float>());
}
RpcSendDense(send_ctx, *delta_scope_);
VLOG(1) << "Finish Send Dense " << var_names[0] << ", table_id: " << table_id;
return;
}
void GeoCommunicator::RecvDense(const CommContext &send_ctx) {
platform::RecordEvent record_event("GeoCommunicator->RecvDense");
auto &table_id = send_ctx.table_id;
auto &varnames = recv_varname_to_ctx_.at(table_id);
// 1. recv from pserver
RpcRecvDense(varnames, table_id, pserver_scope_.get());
// 2.1 pserver - old => delta; 2.2 latest + old => latest 2.3 old => pserver
auto cpu_ctx = paddle::platform::CPUDeviceContext();
for (auto &varname : varnames) {
auto *var_latest = recv_scope_->FindVar(varname);
auto t_latest = var_latest->GetMutable<framework::LoDTensor>();
auto *var_old = old_scope_->FindVar(varname);
auto t_old = var_old->GetMutable<framework::LoDTensor>();
auto *var_pserver = pserver_scope_->FindVar(varname);
auto t_pserver = var_pserver->Get<framework::LoDTensor>();
auto *var_delta = delta_scope_->Var(varname);
auto *t_delta = var_delta->GetMutable<framework::LoDTensor>();
t_delta->mutable_data<float>(t_latest->dims(), cpu_ctx.GetPlace());
auto blas =
paddle::operators::math::GetBlas<platform::CPUDeviceContext, float>(
cpu_ctx);
blas.VSUB(t_latest->numel(), t_pserver.data<float>(), t_old->data<float>(),
t_delta->data<float>());
blas.VADD(t_latest->numel(), t_latest->data<float>(),
t_delta->data<float>(), t_latest->data<float>());
blas.VCOPY(t_latest->numel(), t_pserver.data<float>(),
t_old->data<float>());
}
VLOG(1) << "Finish Recv Dense " << varnames[0] << ", table_id: " << table_id;
return;
}
void GeoCommunicator::InitSparse(const std::string &var_name, int table_id) {
VLOG(0) << "Init Sparse " << var_name << " : table " << table_id << " begin.";
if (trainer_id_ == 0) {
RpcSendSparseParam(var_name, table_id, *recv_scope_);
BarrierWithTable(1);
VLOG(0) << "push sparse param to table " << table_id
<< " from 0' trainer done";
} else {
BarrierWithTable(1);
RpcRecvSparse(var_name, table_id, recv_scope_);
VLOG(0) << "push dense param to table " << table_id
<< " from 0' trainer done";
}
VLOG(0) << "Init Sparse " << var_name << " : table " << table_id << " done.";
auto *global_var = recv_scope_->FindVar(var_name);
auto *var = old_scope_->Var(var_name);
framework::CopyVariable(*global_var, var);
return;
}
std::vector<int64_t> GeoCommunicator::MergeSparseIds(
const std::string &send_varname) {
size_t merge_num = 0, wait_times = 0;
std::unordered_set<int64_t> sparse_ids;
while (merge_num < static_cast<size_t>(max_merge_var_num_)) {
VLOG(3) << "Merge Number of " << send_varname << " = " << merge_num;
if (sparse_id_queues_.at(send_varname)->Size() > 0) {
wait_times = 0;
std::shared_ptr<std::vector<int64_t>> pop_ids =
sparse_id_queues_.at(send_varname)->Pop();
for (size_t j = 0; j < pop_ids->size(); j++) {
sparse_ids.insert(pop_ids->at(j));
}
merge_num += 1;
VLOG(3) << "sparse_id_queues_(" << send_varname << ") pushed";
} else if (sparse_id_queues_.at(send_varname)->Size() == 0) {
VLOG(3) << "wait_times -> " << wait_times;
if (wait_times >= static_cast<size_t>(send_wait_times_)) {
break;
}
std::this_thread::sleep_for(std::chrono::milliseconds(10));
wait_times++;
continue;
}
}
std::vector<int64_t> res;
res.assign(sparse_ids.begin(), sparse_ids.end());
return res;
}
void GeoCommunicator::SendSparse(const std::string &varname,
std::vector<int64_t> &sparse_ids, int table_id,
int ep_idx) {
platform::RecordEvent record_event("GeoCommunicator->SendSparse");
std::string param_name = SplitedGradToParam(varname);
VLOG(1) << "In GeoCommunicator::SendSparse(" << varname << " " << param_name
<< ", ids.size = " << sparse_ids.size() << ", table_id: " << table_id
<< ", ep_idx: " << ep_idx;
auto *var_latest = recv_scope_->FindVar(param_name);
auto *var_old = old_scope_->FindVar(param_name);
PADDLE_ENFORCE_EQ(var_latest->IsInitialized(), true,
platform::errors::Unavailable(
"%s is not initialized, please check", param_name));
PADDLE_ENFORCE_EQ(var_old->IsInitialized(), true,
platform::errors::Unavailable(
"%s is not initialized, please check", param_name));
auto &t_latest = var_latest->Get<framework::LoDTensor>();
auto *t_old = var_old->GetMutable<framework::LoDTensor>();
auto dims1 = t_latest.dims()[1];
auto cpu_ctx = paddle::platform::CPUDeviceContext();
auto *var_delta = delta_scope_->Var(varname);
auto *t_delta = var_delta->GetMutable<framework::SelectedRows>();
auto *var_t_value = t_delta->mutable_value();
var_t_value->Resize({static_cast<int64_t>(sparse_ids.size()), dims1});
auto *t_value = var_t_value->mutable_data<float>(cpu_ctx.GetPlace());
t_delta->set_rows(sparse_ids);
t_delta->set_height(t_latest.dims()[0]);
auto blas =
paddle::operators::math::GetBlas<platform::CPUDeviceContext, float>(
cpu_ctx);
float coefficient = 1.0 / static_cast<float>(trainers_);
std::vector<float *> push_g_vec;
for (auto j = 0; j < static_cast<int>(sparse_ids.size()); ++j) {
blas.VSUB(dims1, t_latest.data<float>() + sparse_ids[j] * dims1,
t_old->data<float>() + sparse_ids[j] * dims1,
t_value + j * dims1);
blas.SCAL(dims1, coefficient, t_value + j * dims1);
blas.VADD(dims1, t_old->data<float>() + sparse_ids[j] * dims1,
t_value + j * dims1,
t_old->data<float>() + sparse_ids[j] * dims1);
push_g_vec.push_back(t_value + j * dims1);
}
++_async_call_num;
DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [this](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
if (closure->check_response(0, PS_PUSH_SPARSE_TABLE) != 0) {
ret = -1;
}
closure->set_promise_value(ret);
--_async_call_num;
});
auto status = _worker_ptr->push_sparse_raw_gradient_partial(
table_id, (const uint64_t *)sparse_ids.data(),
(const float **)push_g_vec.data(), sparse_ids.size(), closure, ep_idx);
status.wait();
VLOG(1) << "Finish Send Sparse " << varname
<< ", ids.size = " << sparse_ids.size() << ", table_id: " << table_id;
return;
}
void GeoCommunicator::RecvSparse(const std::string &varname, int table_id,
int ep_idx) {
platform::RecordEvent record_event("GeoCommunicator->RecvSparse");
// 1. recv from pserver
std::vector<uint64_t> keys;
std::vector<float> values;
auto status = _worker_ptr->pull_geo_param(table_id, &values, &keys, ep_idx);
status.wait();
std::string param = SplitedGradToParam(varname);
VLOG(1) << "RecvSparse receive var: " << varname << " " << param << ", "
<< table_id << "; ids Size: " << keys.size()
<< "; values size: " << values.size();
auto *var_latest = recv_scope_->FindVar(param);
auto *var_old = old_scope_->FindVar(param);
auto *t_latest = var_latest->GetMutable<framework::LoDTensor>();
auto *t_old = var_old->GetMutable<framework::LoDTensor>();
auto dims1 = t_latest->dims()[1];
auto numel = keys.size() * dims1;
std::vector<float> v_delta;
v_delta.resize(numel);
auto cpu_ctx = paddle::platform::CPUDeviceContext();
auto blas =
paddle::operators::math::GetBlas<platform::CPUDeviceContext, float>(
cpu_ctx);
for (auto j = 0; j < static_cast<int>(keys.size()); ++j) {
float *latest_data = t_latest->data<float>() + keys[j] * dims1;
float *old_data = t_old->data<float>() + keys[j] * dims1;
// pserver - old => delta
blas.VSUB(dims1, values.data() + j * dims1, old_data,
v_delta.data() + j * dims1);
// latest + delta => latest
blas.VADD(dims1, latest_data, v_delta.data() + j * dims1, latest_data);
// pserver => old
blas.VCOPY(dims1, values.data() + j * dims1, old_data);
}
VLOG(1) << "Finish Recv Sparse " << param << ", table_id: " << table_id;
}
void GeoCommunicator::MainThread() {
VLOG(3) << "MainThread start and wait";
while (waiting_ && running_) {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
VLOG(3) << "wait for running";
}
while (running_) {
std::vector<std::future<void>> tasks;
tasks.reserve(parallel_task_nums_);
for (auto &iter : send_varname_to_ctx_) {
auto &ctx = iter.second;
auto &varnames = ctx.origin_varnames;
auto &table_id = ctx.table_id;
if (ctx.is_sparse) {
PADDLE_ENFORCE_EQ(
varnames.size(), 1,
platform::errors::InvalidArgument(
"sparse variables can only be merged by one variables"));
int pserver_num = static_cast<int>(ctx.epmap.size());
for (int ep_idx = 0; ep_idx < pserver_num; ep_idx++) {
// varname: emb@GRAD, param_name: emb, splited_varname: emb.delta0
auto send_recv_task = [this, table_id, ep_idx, &ctx] {
auto splited_varname = ctx.splited_varnames[ep_idx];
auto sparse_ids = MergeSparseIds(splited_varname);
SendSparse(splited_varname, sparse_ids, table_id, ep_idx);
RecvSparse(splited_varname, table_id, ep_idx);
};
tasks.emplace_back(
send_threadpool_->enqueue(std::move(send_recv_task)));
}
} else {
auto send_recv_task = [this, &ctx] {
SendDense(ctx);
RecvDense(ctx);
};
tasks.emplace_back(
send_threadpool_->enqueue(std::move(send_recv_task)));
}
}
for (auto &task : tasks) {
task.wait();
}
}
}
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <ThreadPool.h>
#include <stdint.h>
#include <atomic>
#include <deque>
#include <map>
#include <memory>
#include <numeric>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "gflags/gflags.h"
#include "paddle/fluid/distributed/communicator_common.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/split.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/service/ps_client.h"
DECLARE_bool(communicator_is_sgd_optimizer);
namespace paddle {
namespace distributed {
using Scope = framework::Scope;
using Variable = framework::Variable;
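// A bounded, thread-safe FIFO queue: Push blocks while the queue is full and
// Pop blocks while it is empty.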
template <typename T>
class BlockingQueue {
public:
explicit BlockingQueue(size_t capacity) : capacity_(capacity) {
PADDLE_ENFORCE_GT(capacity_, 0,
platform::errors::InvalidArgument(
"The capacity must be greater than 0."));
}
bool Push(const T &elem) {
{
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [&] { return queue_.size() < capacity_; });
queue_.push_back(elem);
}
cv_.notify_one();
return true;
}
bool Push(T &&elem) {
{
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [&] { return queue_.size() < capacity_; });
queue_.emplace_back(std::move(elem));
}
cv_.notify_one();
return true;
}
T Pop() {
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [=] { return !queue_.empty(); });
T rc(std::move(queue_.front()));
queue_.pop_front();
cv_.notify_one();
return rc;
}
size_t Cap() const {
std::lock_guard<std::mutex> lock(mutex_);
return capacity_;
}
size_t Size() const {
std::lock_guard<std::mutex> lock(mutex_);
return queue_.size();
}
private:
const size_t capacity_;
std::deque<T> queue_;
mutable std::mutex mutex_;
std::condition_variable cv_;
};
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
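// Merge a batch of same-named variables into `scope`: LoDTensors are summed
// element-wise (or averaged when merge_add is false); SelectedRows are
// combined via MergeAdd / MergeAverage.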
template <typename T>
inline void MergeVars(const std::string &var_name,
const std::vector<std::shared_ptr<Variable>> &vars,
Scope *scope, bool merge_add = true) {
PADDLE_ENFORCE_NE(vars.empty(), true, platform::errors::InvalidArgument(
"vector vars are empty."));
auto cpu_place = platform::CPUPlace();
auto &var0 = vars[0];
auto *out_var = scope->Var(var_name);
if (var0->IsType<framework::LoDTensor>()) {
auto dims = var0->Get<framework::LoDTensor>().dims();
VLOG(3) << "merge " << var_name << " LoDTensor dims " << dims
<< "; merge add: " << merge_add;
// init output tensor
auto *out_t = out_var->GetMutable<framework::LoDTensor>();
out_t->mutable_data<T>(dims, cpu_place);
// check the input dims
for (auto &var : vars) {
auto &var_t = var->Get<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(
var_t.dims(), dims,
platform::errors::InvalidArgument("vars should have the same dims."));
}
// set output tensor to 0.
auto cpu_ctx = paddle::platform::CPUDeviceContext();
paddle::operators::math::SetConstant<paddle::platform::CPUDeviceContext, T>
constant_functor;
constant_functor(cpu_ctx, out_t, static_cast<T>(0));
// sum all vars to out
auto result = EigenVector<T>::Flatten(*out_t);
for (auto &var : vars) {
auto &in_t = var->Get<framework::LoDTensor>();
auto in = EigenVector<T>::Flatten(in_t);
result.device(*cpu_ctx.eigen_device()) = result + in;
}
if (!merge_add) {
result.device(*cpu_ctx.eigen_device()) =
result / static_cast<T>(vars.size());
}
} else if (var0->IsType<framework::SelectedRows>()) {
auto &slr0 = var0->Get<framework::SelectedRows>();
auto *out_slr = out_var->GetMutable<framework::SelectedRows>();
out_slr->mutable_rows()->clear();
out_slr->mutable_value()->mutable_data<T>({{}}, cpu_place);
std::vector<const paddle::framework::SelectedRows *> inputs;
inputs.reserve(vars.size());
for (auto &var : vars) {
inputs.push_back(&var->Get<framework::SelectedRows>());
}
auto dev_ctx = paddle::platform::CPUDeviceContext();
if (merge_add) {
paddle::operators::math::scatter::MergeAdd<
paddle::platform::CPUDeviceContext, T>
merge_add;
merge_add(dev_ctx, inputs, out_slr);
} else {
paddle::operators::math::scatter::MergeAverage<
paddle::platform::CPUDeviceContext, T>
merge_average;
merge_average(dev_ctx, inputs, out_slr);
}
VLOG(3) << "merge " << var_name << " SelectedRows height: " << slr0.height()
<< " dims: " << slr0.value().dims() << "; merge add: " << merge_add;
} else {
PADDLE_THROW(platform::errors::InvalidArgument("unsupported var type: %s!",
var0->Type()));
}
}
using RpcCtxMap = std::unordered_map<std::string, CommContext>;
using RecvCtxMap = std::unordered_map<uint64_t, std::vector<std::string>>;
using SparseValue = std::unordered_map<int64_t, std::vector<float>>;
class Communicator {
public:
Communicator();
explicit Communicator(const std::map<std::string, std::string> &envs_) {
VLOG(0) << "Communicator Init Envs";
for (auto &iter : envs_) {
envs[iter.first] = iter.second;
VLOG(0) << iter.first << ": " << iter.second;
}
barrier_table_id_ = std::stoi(envs.at("barrier_table_id"));
trainer_id_ = std::stoi(envs.at("trainer_id"));
trainers_ = std::stoi(envs.at("trainers"));
}
virtual void InitBrpcClient(const std::string &dist_desc,
const std::vector<std::string> &host_sign_list);
// 1. recv dense param
virtual void RpcRecvDense(const std::vector<std::string> &varnames,
int table_id, Scope *scope);
// 2. send dense param
virtual void RpcSendDenseParam(const std::vector<std::string> &varnames,
int table_id, const Scope &scope);
// 3. send dense grad
virtual void RpcSendDense(const CommContext &ctx, const Scope &scope);
// 4. send sparse grad
virtual void RpcSendSparse(const std::string &var_name, int table_id,
const Scope &scope);
// 5. send sparse param
virtual void RpcSendSparseParam(const std::string &varname, int table_id,
const Scope &scope);
// 6. recv sparse param
virtual void RpcRecvSparse(const std::string &varname, int table_id,
Scope *scope);
virtual ~Communicator() {}
virtual void RpcProfilerControl();
virtual void InitParams(const RecvCtxMap &recv_varname_to_ctx);
virtual void Start() = 0;
virtual void Stop() = 0;
virtual bool IsRunning() { return running_; }
virtual void Clean() {}
virtual bool Check(const int table_id) = 0;
virtual bool Check(const std::vector<std::string> &var_tables) = 0;
virtual void Send(const std::vector<std::string> &var_names,
const framework::Scope &scope) = 0;
virtual void RecvNoBarrier() {}
virtual void Barrier() {}
virtual void BarrierWithTable(uint32_t barrier_type) {
auto rets = _worker_ptr->barrier(barrier_table_id_, barrier_type);
rets.wait();
}
virtual void BarrierTriggerDecrement() {}
virtual void BarrierTriggerReset(int init_counter) {}
virtual void InitEnvs() = 0;
virtual void InitImpl(const RpcCtxMap &send_varname_to_ctx,
const RecvCtxMap &recv_varname_to_ctx,
Scope *recv_scope) {}
static Communicator *GetInstance() { return communicator_.get(); }
static std::shared_ptr<Communicator> GetInstantcePtr() {
return communicator_;
}
template <typename T>
static Communicator *InitInstance(
const RpcCtxMap &send_ctx, const RecvCtxMap &recv_ctx,
const std::string &dist_desc,
const std::vector<std::string> &host_sign_list, Scope *recv_scope,
const std::map<std::string, std::string> &envs) {
std::call_once(init_flag_, &Communicator::InitWithRpcCtx<T>, send_ctx,
recv_ctx, dist_desc, host_sign_list, recv_scope,
std::ref(envs));
return communicator_.get();
}
// Init is called by InitInstance.
template <typename T>
static void InitWithRpcCtx(const RpcCtxMap &send_ctx,
const RecvCtxMap &recv_ctx,
const std::string &dist_desc,
const std::vector<std::string> &host_sign_list,
Scope *recv_scope,
const std::map<std::string, std::string> &envs) {
if (communicator_.get() == nullptr) {
communicator_.reset(new T(std::ref(envs)));
communicator_->InitEnvs();
communicator_->InitBrpcClient(dist_desc, host_sign_list);
communicator_->InitImpl(send_ctx, recv_ctx, recv_scope);
}
}
PSClient *GetPsClient() { return _worker_ptr.get(); }
std::shared_ptr<paddle::distributed::PSClient> GetPsClientPtr() {
return _worker_ptr;
}
std::shared_ptr<PSClient> _worker_ptr; // pointer to worker
protected:
bool running_ = false;
bool waiting_ = true;
bool flushing_ = false;
bool do_server_profiler_ = false;
static std::shared_ptr<Communicator> communicator_;
static std::once_flag init_flag_;
std::unordered_map<std::string, std::string> envs;
  // Compute how many dense parameters each shard stores;
  // the "+ 1" rounds up so the remainder always fits.
inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total,
uint32_t shard_num) {
return dense_dim_total / shard_num + 1;
}
void init_gflag(const std::string &gflags);
paddle::distributed::PSParameter _ps_param;
paddle::distributed::PaddlePSEnvironment _ps_env;
int servers_ = 0;
int trainers_;
int trainer_id_ = 0;
int barrier_table_id_ = 0;
RpcCtxMap send_varname_to_ctx_;
RecvCtxMap recv_varname_to_ctx_;
Scope *recv_scope_; // should be global scope
std::unique_ptr<Scope> xpu_temp_scope_;
std::atomic<uint32_t> _async_call_num{0};
};
class AsyncCommunicator : public Communicator {
public:
AsyncCommunicator() : Communicator() {}
explicit AsyncCommunicator(const std::map<std::string, std::string> &envs)
: Communicator(envs) {}
~AsyncCommunicator();
void InitEnvs() {
independent_recv_ = static_cast<bool>(
std::stoi(envs.at("communicator_independent_recv_thread")));
min_send_grad_num_before_recv_ =
std::stoi(envs.at("communicator_min_send_grad_num_before_recv"));
thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size"));
max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num"));
send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times"));
send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size"));
need_global_step_ =
static_cast<bool>(std::stoi(envs.at("need_global_step")));
}
void Start() override;
void Stop() override;
void InitImpl(const RpcCtxMap &send_varname_to_ctx,
const RecvCtxMap &recv_varname_to_ctx,
Scope *recv_scope) override;
virtual void MainThread();
virtual void RecvThread();
virtual bool Check(const int table_id);
virtual bool Check(const std::vector<std::string> &var_tables);
void Send(const std::vector<std::string> &var_names,
const framework::Scope &scope) override;
virtual void SendByCommunicator();
virtual void SendGlobalStep(int batches) {}
virtual void RecvByCommunicator();
virtual void RecvNoBarrier();
virtual int BatchesCounter() { return 1; }
virtual void BarrierSend() {}
virtual void BarrierRecv() {}
virtual void BarrierWeakUp() {}
protected:
std::unordered_map<std::string,
std::shared_ptr<BlockingQueue<std::shared_ptr<Variable>>>>
send_varname_to_queue_;
std::unique_ptr<::ThreadPool> send_threadpool_{nullptr};
int min_send_grad_num_before_recv_;
int thread_pool_size_;
int max_merge_var_num_;
int send_wait_times_;
int send_queue_size_;
bool need_global_step_ = false;
bool independent_recv_ = true;
int parallel_task_nums_ = 0;
std::unique_ptr<std::thread> main_thread_{nullptr};
std::unique_ptr<std::thread> recv_thread_{nullptr};
std::unique_ptr<Scope> send_scope_; // an independent scope
std::atomic_uint grad_num_{0}; // the num of gradient sent since last recv
};
class HalfAsyncCommunicator : public AsyncCommunicator {
public:
HalfAsyncCommunicator() {}
explicit HalfAsyncCommunicator(const std::map<std::string, std::string> &envs)
: AsyncCommunicator(envs) {}
void InitEnvs() {
    // enforce recv after send
independent_recv_ = false;
min_send_grad_num_before_recv_ = 0;
thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size"));
max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num"));
send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times"));
send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size"));
need_global_step_ =
static_cast<bool>(std::stoi(envs.at("need_global_step")));
VLOG(0) << "HalfAsyncCommunicator Initialized";
}
void MainThread() override;
void SendByCommunicator() override;
void Clean() override;
void Barrier() override;
void BarrierTriggerDecrement() override;
void BarrierTriggerReset(int initial_val) override;
int BatchesCounter();
void BarrierWeakUp();
protected:
  // mutex and condition variable used to wait on the barrier
std::mutex barrier_mutex_;
std::condition_variable barrier_cond_;
std::atomic<int64_t> barrier_trigger_{0};
std::atomic<int64_t> barrier_counter_{0};
};
class SyncCommunicator : public HalfAsyncCommunicator {
public:
SyncCommunicator() : HalfAsyncCommunicator() {}
explicit SyncCommunicator(const std::map<std::string, std::string> &envs)
: HalfAsyncCommunicator(envs) {}
void InitEnvs() {
    // enforce recv after send
independent_recv_ = false;
min_send_grad_num_before_recv_ = 0;
max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num"));
send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times"));
thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size"));
send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size"));
need_global_step_ =
static_cast<bool>(std::stoi(envs.at("need_global_step")));
VLOG(0) << "SyncCommunicator Initialized";
}
void BarrierSend();
void BarrierRecv();
private:
std::vector<std::string> pserver_endpoints_{};
};
class GeoCommunicator : public AsyncCommunicator {
public:
GeoCommunicator() : AsyncCommunicator() {}
explicit GeoCommunicator(const std::map<std::string, std::string> &envs)
: AsyncCommunicator(envs) {}
void InitImpl(const RpcCtxMap &send_varname_to_ctx,
const RecvCtxMap &recv_varname_to_ctx,
Scope *recv_scope) override;
void InitParams(const RecvCtxMap &recv_varname_to_ctx) override;
void InitDense(std::vector<std::string> &varnames, int table_id);
void InitSparse(const std::string &var_name, int table_id);
void SendDense(const CommContext &send_ctx);
void RecvDense(const CommContext &send_ctx);
std::vector<int64_t> MergeSparseIds(const std::string &varname);
void SendSparse(const std::string &varname, std::vector<int64_t> &sparse_ids,
int table_id, int ep_idx);
void RecvSparse(const std::string &varname, int table_id, int ep_idx);
void MainThread() override;
void InitEnvs() {
independent_recv_ = false;
min_send_grad_num_before_recv_ = 0;
send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times"));
thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size"));
// id_queue's size
max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num"));
send_queue_size_ = max_merge_var_num_;
VLOG(0) << "GeoCommunicator Initialized";
}
void Send(const std::vector<std::string> &var_names,
const framework::Scope &scope) override;
void SendByCommunicator() { return; }
void SendGlobalStep(int batches) override { return; }
void RecvByCommunicator() override { return; }
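  // e.g. "emb@GRAD" -> "emb": strips the 5-character "@GRAD" suffix.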
inline std::string GradToParam(const std::string var_name) {
std::string param_name = var_name.substr(0, var_name.size() - 5);
return param_name;
}
inline std::string SplitedGradToParam(const std::string delta_name) {
// delta_name: emb.delta0
auto pos = delta_name.find(".block");
std::string param_name = delta_name.substr(0, pos);
return param_name;
}
private:
// parameter for delta calc and send
std::shared_ptr<Scope> delta_scope_;
// parameter for storage the pserver param after last recv
std::shared_ptr<Scope> old_scope_;
// parameter on pserver
std::shared_ptr<Scope> pserver_scope_;
std::unordered_map<
std::string,
std::shared_ptr<BlockingQueue<std::shared_ptr<std::vector<int64_t>>>>>
sparse_id_queues_;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/service/env.h"
namespace paddle {
namespace distributed {} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <arpa/inet.h>
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <netinet/in.h>
#include <stdio.h>
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
namespace paddle {
namespace distributed {
struct PSHost {
std::string ip;
uint32_t port;
uint32_t rank;
PSHost() = default;
PSHost(const std::string ip, uint32_t port, uint32_t rank)
: ip(ip), port(port), rank(rank) {}
// |---ip---|---port---|--rank--|
// |-32bit--|--20bit---|--12bit-|
// for pslib
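  // Note: the packing below assumes port < 2^20 and rank < 2^12; larger values
  // would overlap the neighbouring fields.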
uint64_t serialize_to_uint64() {
uint64_t host_label = 0;
host_label = inet_addr(ip.c_str());
host_label = host_label << 32;
host_label += (port << 12);
host_label += rank;
return host_label;
}
void parse_from_uint64(uint64_t host_label) {
static uint64_t rank_label_mask = (1L << 12) - 1;
static uint64_t port_label_mask = (1L << 20) - 1;
rank = host_label & rank_label_mask;
port = (host_label >> 12) & port_label_mask;
uint32_t ip_addr = (host_label >> 32);
ip = inet_ntoa(*(in_addr *)&ip_addr);
}
std::string to_string() {
std::stringstream s;
s << "host: " << ip;
s << " port: " << port;
s << " rank: " << rank;
s << " uint: " << serialize_to_uint64();
return s.str();
}
// for open source parameter server
std::string serialize_to_string() {
std::stringstream s;
s << ip << ":";
s << port << ":";
s << rank;
return s.str();
}
void parse_from_string(std::string endpoint) {
std::vector<std::string> endpoint_info;
string_split(endpoint, ':', &endpoint_info);
ip = endpoint_info[0];
port = std::stoi(endpoint_info[1]);
rank = std::stoi(endpoint_info[2]);
}
void string_split(const std::string &str, char sep,
std::vector<std::string> *pieces, bool ignore_null = true) {
pieces->clear();
if (str.empty()) {
if (!ignore_null) {
pieces->push_back(str);
}
return;
}
size_t pos = 0;
size_t next = str.find(sep, pos);
while (next != std::string::npos) {
pieces->push_back(str.substr(pos, next - pos));
pos = next + 1;
next = str.find(sep, pos);
}
if (!str.substr(pos).empty()) {
pieces->push_back(str.substr(pos));
}
}
};
class PSEnvironment {
public:
explicit PSEnvironment() {}
virtual ~PSEnvironment() {}
virtual int32_t set_ps_servers(uint64_t *host_sign_list, int node_num) {
return 0;
}
virtual int32_t set_ps_servers(
const std::vector<std::string> *host_endpoint_list, int node_num) {
return 0;
}
virtual int32_t set_ps_clients(uint64_t *host_sign_list, int node_num) {
return 0;
}
virtual int32_t set_ps_clients(std::string *host_endpoint_list,
int node_num) {
return 0;
}
virtual uint64_t get_local_host_sign() { return 0; }
virtual std::vector<PSHost> get_ps_servers() const { return _ps_server_list; }
virtual int32_t registe_ps_server(const std::string &ip, uint32_t port,
int32_t rank) {
return registe_ps_host(ip, port, rank, _ps_server_list,
_ps_server_sign_set);
}
virtual std::vector<PSHost> get_ps_clients() const { return _ps_client_list; }
virtual int32_t registe_ps_client(const std::string &ip, uint32_t port,
int32_t rank) {
return registe_ps_host(ip, port, rank, _ps_client_list,
_ps_client_sign_set);
}
virtual std::vector<uint64_t> get_client_info() {
std::vector<uint64_t> client_info;
for (auto &i : _ps_client_sign_set) {
client_info.push_back(i);
}
return client_info;
}
virtual std::vector<std::string> get_client_info(bool use_string_endpoint) {
if (use_string_endpoint) {
std::vector<std::string> client_info;
for (auto &i : _ps_client_list) {
client_info.push_back(i.serialize_to_string());
}
return client_info;
}
return {};
}
protected:
// register a host
virtual int32_t registe_ps_host(const std::string &ip, uint32_t port,
int32_t rank, std::vector<PSHost> &host_list,
std::unordered_set<uint64_t> &sign_set) {
PSHost host;
host.ip = ip;
host.port = port;
host.rank = rank;
if (sign_set.count(rank) > 0) {
LOG(WARNING) << "ps-host :" << host.ip << ":" << host.port
<< ", rank:" << host.rank
<< " already register, ignore register";
} else {
host_list.push_back(host);
sign_set.insert(rank);
}
// if (sign_set.count(host.serialize_to_uint64()) > 0) {
// LOG(WARNING) << "ps-host :" << host.ip << ":" << host.port
// << ", rank:" << host.rank
// << " already register, ignore register";
// } else {
// host_list.push_back(host);
// sign_set.insert(host.serialize_to_uint64());
// }
return 0;
}
std::vector<PSHost> _ps_client_list;
std::unordered_set<uint64_t> _ps_client_sign_set; // for unique filter
std::vector<PSHost> _ps_server_list;
std::unordered_set<uint64_t> _ps_server_sign_set; // for unique filter
};
class PaddlePSEnvironment : public PSEnvironment {
public:
explicit PaddlePSEnvironment() {}
virtual ~PaddlePSEnvironment() {}
virtual int32_t set_ps_servers(uint64_t *host_sign_list, int node_num) {
_ps_server_list.clear();
_ps_server_sign_set.clear();
for (int i = 0; i < node_num; ++i) {
if (host_sign_list[i] > 0) {
PSHost host;
host.parse_from_uint64(host_sign_list[i]);
_ps_server_list.push_back(host);
_ps_server_sign_set.insert(host.serialize_to_uint64());
}
}
std::sort(
_ps_server_list.begin(), _ps_server_list.end(),
[](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; });
return 0;
}
virtual int32_t set_ps_servers(const std::vector<std::string> *host_sign_list,
int node_num) {
_ps_server_list.clear();
_ps_server_sign_set.clear();
for (int i = 0; i < node_num; ++i) {
if (host_sign_list->at(i) != "") {
PSHost host;
host.parse_from_string(host_sign_list->at(i));
_ps_server_list.push_back(host);
_ps_server_sign_set.insert(host.rank);
}
}
std::sort(
_ps_server_list.begin(), _ps_server_list.end(),
[](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; });
return 0;
}
virtual int32_t set_ps_clients(uint64_t *host_sign_list, int node_num) {
_ps_client_list.clear();
_ps_client_sign_set.clear();
for (int i = 0; i < node_num; ++i) {
if (host_sign_list[i] > 0) {
PSHost host;
host.parse_from_uint64(host_sign_list[i]);
_ps_client_list.push_back(host);
_ps_client_sign_set.insert(host.serialize_to_uint64());
}
}
std::sort(
_ps_client_list.begin(), _ps_client_list.end(),
[](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; });
return 0;
}
virtual int32_t set_ps_clients(std::vector<std::string> *host_sign_list,
int node_num) {
_ps_client_list.clear();
_ps_client_sign_set.clear();
for (int i = 0; i < node_num; ++i) {
if (host_sign_list->at(i) != "") {
PSHost host;
host.parse_from_string(host_sign_list->at(i));
_ps_client_list.push_back(host);
_ps_client_sign_set.insert(host.rank);
}
}
std::sort(
_ps_client_list.begin(), _ps_client_list.end(),
[](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; });
return 0;
}
virtual uint64_t get_local_host_sign() {
if (_ps_client_list.size() > 0) {
return _ps_client_list[0].serialize_to_uint64();
} else {
return 0;
}
}
};
} // namespace distributed
} // namespace paddle
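The 32/20/12-bit packing documented above in PSHost can be exercised directly, since this header is self-contained. A minimal sketch (file name and endpoint values are illustrative, not part of the commit), assuming the Paddle include path is available:
// pshost_roundtrip_example.cc (illustrative only)
#include <cassert>
#include <iostream>
#include "paddle/fluid/distributed/service/env.h"
int main() {
  paddle::distributed::PSHost host("127.0.0.1", 8000, 2);
  // |---ip:32---|---port:20---|---rank:12---|
  uint64_t sign = host.serialize_to_uint64();
  paddle::distributed::PSHost parsed;
  parsed.parse_from_uint64(sign);
  assert(parsed.port == 8000 && parsed.rank == 2);
  std::cout << parsed.to_string() << std::endl;
  // the open source path round-trips through "ip:port:rank" strings instead
  parsed.parse_from_string(host.serialize_to_string());  // "127.0.0.1:8000:2"
  return 0;
}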
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/service/heter_client.h"
#include <algorithm>
#include <utility>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/timer.h"
DECLARE_int32(rpc_deadline);
namespace paddle {
namespace distributed {
DEFINE_int32(pserver_timeout_ms, 10800000, "pserver request server timeout_ms");
std::shared_ptr<HeterClient> HeterClient::s_instance_ = NULL;
bool HeterClient::is_initialized_ = false;
void HeterClient::MainThread() {
while (running_) {
RpcProfilerControl();
}
}
void HeterClient::Stop() {
running_ = false;
if (!is_initialized_) {
VLOG(0) << "HeterClient is not inited, do nothing";
} else {
if (main_thread_) {
auto status = StopHeterWorker();
status.wait();
main_thread_->join();
main_thread_.reset(nullptr);
}
VLOG(1) << "HeterClient Stop Done";
}
}
void HeterClient::RpcProfilerControl() {
if (trainer_id_ == 0) {
if (!do_server_profiler_ && platform::IsProfileEnabled()) {
// send profiler start flag
do_server_profiler_ = true;
auto start_status = StartProfiler();
start_status.wait();
} else if (do_server_profiler_ && !platform::IsProfileEnabled()) {
// send profiler end flag
auto stop_status = StopProfiler();
stop_status.wait();
do_server_profiler_ = false;
}
}
}
void HeterClient::CreateClient2XpuConnection() {
brpc::ChannelOptions options;
options.protocol = "baidu_std";
options.connection_type = "single";
options.timeout_ms = FLAGS_pserver_timeout_ms;
xpu_channels_.resize(xpu_list_.size());
for (size_t i = 0; i < xpu_list_.size(); ++i) {
xpu_channels_[i].reset(new brpc::Channel());
if (xpu_channels_[i]->Init(xpu_list_[i].c_str(), "", &options) != 0) {
VLOG(0) << "HeterServer channel init fail";
}
}
}
void HeterClient::SendAndRecvAsync(
const std::vector<std::string>& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& message_name,
const std::vector<std::string>& send_var_name,
const std::vector<std::string>& recv_var_name) {
platform::RecordEvent record_event("HeterClient->SendAndRecvAsync");
const platform::DeviceContext* p_ctx = &ctx;
const framework::Scope* p_scope = &scope;
const std::string message_name_val = message_name;
const std::vector<std::string> send_var_name_val = send_var_name;
const std::vector<std::string> recv_var_name_val = recv_var_name;
VLOG(3) << "GRPCClient::SendAndRecv Begin, message_name: "
<< message_name_val;
// Todo: get correct channel
int num = trainer_id_ % xpu_channels_.size();
brpc::Controller cntl;
cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
distributed::MultiVarMsg request, response;
auto& request_io_buffer = cntl.request_attachment();
::paddle::PsService_Stub stub(xpu_channels_[num].get());
distributed::SerializeToMultiVarMsgAndIOBuf(
message_name_val, send_var_name_val, recv_var_name_val, *p_ctx, p_scope,
&request, &request_io_buffer);
stub.SendAndRecvVariable(&cntl, &request, &response, NULL);
PADDLE_ENFORCE_NE(
cntl.Failed(), true,
platform::errors::Unimplemented(
"HeterClient::SendAndRecv meets brpc error, error message is %s",
cntl.ErrorText()));
VLOG(4) << "call heter_worker success";
auto& response_io_buffer = cntl.response_attachment();
distributed::DeserializeFromMultiVarMsgAndIOBuf(response, &response_io_buffer,
ctx, p_scope);
}
std::future<int32_t> HeterClient::SendCmd(
uint32_t table_id, int cmd_id, const std::vector<std::string>& params) {
size_t request_call_num = xpu_channels_.size();
paddle::distributed::DownpourBrpcClosure* closure =
new paddle::distributed::DownpourBrpcClosure(
request_call_num, [request_call_num, cmd_id](void* done) {
int ret = 0;
auto* closure = (paddle::distributed::DownpourBrpcClosure*)done;
for (size_t i = 0; i < request_call_num; ++i) {
if (closure->check_response(i, cmd_id) != 0) {
ret = -1;
break;
}
}
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
for (size_t i = 0; i < request_call_num; ++i) {
closure->request(i)->set_cmd_id(cmd_id);
closure->request(i)->set_table_id(table_id);
closure->request(i)->set_client_id(trainer_id_);
for (const auto& param : params) {
closure->request(i)->add_params(param);
}
::paddle::PsService_Stub rpc_stub(xpu_channels_[i].get());
closure->cntl(i)->set_timeout_ms(
FLAGS_pserver_timeout_ms); // cmd msgs use the long timeout for save/load
rpc_stub.service(closure->cntl(i), closure->request(i),
closure->response(i), closure);
}
return fut;
}
std::future<int32_t> HeterClient::StartProfiler() {
return SendCmd(-1, PS_START_PROFILER, {});
}
std::future<int32_t> HeterClient::StopProfiler() {
return SendCmd(-1, PS_STOP_PROFILER, {});
}
} // end namespace distributed
} // end namespace paddle
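SendCmd above fans one request out to every xpu channel and resolves a single future once the shared closure has seen all responses. The generic pattern, stripped of brpc and Paddle types, looks like this (names here are illustrative, not Paddle APIs):
// closure_future_pattern.cc (illustrative analogue of the SendCmd closure)
#include <cstdint>
#include <future>
#include <iostream>
#include <memory>
#include <vector>
class CountdownClosure {
 public:
  explicit CountdownClosure(size_t expected) : expected_(expected) {}
  void add_promise(std::shared_ptr<std::promise<int32_t>> p) {
    promises_.push_back(p);
  }
  // called once per finished sub-request; resolves the futures after the last one
  void one_done(int err) {
    if (err != 0) failed_ = true;
    if (--expected_ == 0) {
      for (auto& p : promises_) p->set_value(failed_ ? -1 : 0);
    }
  }
 private:
  size_t expected_;
  bool failed_ = false;
  std::vector<std::shared_ptr<std::promise<int32_t>>> promises_;
};
int main() {
  auto closure = std::make_shared<CountdownClosure>(2);
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int32_t> fut = promise->get_future();
  closure->one_done(0);  // first worker replied
  closure->one_done(0);  // second worker replied -> future becomes ready
  std::cout << "SendCmd-style result: " << fut.get() << std::endl;  // 0
  return 0;
}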
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <atomic>
#include <ctime>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <unordered_map>
#include <vector>
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/service/brpc_utils.h"
#include "paddle/fluid/distributed/service/sendrecv.pb.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace paddle {
namespace distributed {
using MultiVarMsg = ::paddle::MultiVariableMessage;
using VarMsg = ::paddle::VariableMessage;
typedef std::function<void(void*)> HeterRpcCallbackFunc;
class OnHeterRpcDone : public google::protobuf::Closure {
public:
OnHeterRpcDone(HeterRpcCallbackFunc func) : handler_(func) {}
virtual ~OnHeterRpcDone() {}
void Run() {
std::unique_ptr<OnHeterRpcDone> self_guard(this);
handler_(this);
}
HeterRpcCallbackFunc handler_;
MultiVariableMessage response;
brpc::Controller cntl;
};
class HeterClient {
public:
virtual ~HeterClient() {}
HeterClient() {
running_ = true;
main_thread_.reset(
new std::thread(std::bind(&HeterClient::MainThread, this)));
}
void CreateClient2XpuConnection();
void SendAndRecvAsync(const std::vector<std::string>& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& message_name,
const std::vector<std::string>& send_var_name,
const std::vector<std::string>& recv_var_name);
// HeterClient singleton
static std::shared_ptr<HeterClient> GetInstance(
const std::vector<std::string>& endpoint, const int& trainer_id) {
if (NULL == s_instance_) {
is_initialized_ = true;
s_instance_.reset(new paddle::distributed::HeterClient());
std::vector<std::string> xpu_list = {endpoint};
s_instance_->SetXpuList(endpoint);
s_instance_->SetTrainerID(trainer_id);
s_instance_->CreateClient2XpuConnection();
}
return s_instance_;
}
void Stop();
void MainThread();
void RpcProfilerControl();
std::future<int32_t> SendCmd(uint32_t table_id, int cmd_id,
const std::vector<std::string>& params);
std::future<int32_t> StartProfiler();
std::future<int32_t> StopProfiler();
std::future<int32_t> StopHeterWorker();
std::vector<std::string>& GetXpuList() { return xpu_list_; }
void SetXpuList(const std::vector<std::string>& xpu_list) {
xpu_list_ = xpu_list;
};
void SetTrainerID(const int& trainer_id) { trainer_id_ = trainer_id; }
private:
static std::shared_ptr<HeterClient> s_instance_;
protected:
static bool is_initialized_;
std::unique_ptr<std::thread> main_thread_{nullptr};
std::vector<std::shared_ptr<brpc::Channel>> xpu_channels_;
DISABLE_COPY_AND_ASSIGN(HeterClient);
std::vector<std::string> xpu_list_;
bool running_ = false;
int trainer_id_;
bool do_server_profiler_ = false;
};
} // end namespace distributed
} // end namespace paddle
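A rough sketch of how a trainer obtains and shuts down the singleton declared above; the endpoint is assumed to point at a running heter worker, and the SendAndRecvAsync call is left as a comment because it needs a populated scope and device context:
// heter_client_usage_example.cc (illustrative only)
#include <string>
#include <vector>
#include "paddle/fluid/distributed/service/heter_client.h"
int main() {
  std::vector<std::string> endpoints = {"127.0.0.1:8500"};
  int trainer_id = 0;
  auto client =
      paddle::distributed::HeterClient::GetInstance(endpoints, trainer_id);
  // per micro-batch: client->SendAndRecvAsync(endpoints, ctx, scope,
  //                                           message_name, send_vars, recv_vars);
  client->Stop();  // notifies the heter workers and joins the main thread
  return 0;
}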
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/service/heter_server.h"
#include <algorithm>
#include <utility>
#include "paddle/fluid/framework/fleet/heter_wrapper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/timer.h"
namespace paddle {
namespace distributed {
std::shared_ptr<HeterServer> HeterServer::s_instance_ = NULL;
void HeterServer::RegisterServiceHandler(std::string message_name,
HeterServiceHandler func) {
service_.RegisterServiceHandler(message_name, func);
}
void HeterServer::StartHeterService() {
server_.AddService(&service_, brpc::SERVER_DOESNT_OWN_SERVICE);
brpc::ServerOptions options;
if (server_.Start(endpoint_.c_str(), &options) != 0) {
VLOG(0) << "heter server start fail";
} else {
VLOG(0) << "heter server start success! listen on " << endpoint_;
}
{
std::lock_guard<std::mutex> lock(this->mutex_ready_);
ready_ = 1;
}
condition_ready_.notify_all();
server_.Join();
}
void HeterServer::SetEndPoint(std::string& endpoint) {
endpoint_ = endpoint;
service_.SetEndpoint(endpoint);
}
void HeterServer::SetFanin(int& fan_in) { service_.SetFanin(fan_in); }
void HeterServer::WaitServerReady() {
std::unique_lock<std::mutex> lock(this->mutex_ready_);
condition_ready_.wait(lock, [=] { return this->ready_ == 1; });
}
int32_t HeterService::stop_profiler(const PsRequestMessage& request,
PsResponseMessage& response,
brpc::Controller* cntl) {
platform::DisableProfiler(
platform::EventSortingKey::kDefault,
string::Sprintf("heter_worker_%s_profile", endpoint_));
return 0;
}
int32_t HeterService::start_profiler(const PsRequestMessage& request,
PsResponseMessage& response,
brpc::Controller* cntl) {
platform::EnableProfiler(platform::ProfilerState::kAll);
return 0;
}
int32_t HeterService::stop_heter_worker(const PsRequestMessage& request,
PsResponseMessage& response,
brpc::Controller* cntl) {
auto client_id = request.client_id();
stop_cpu_worker_set_.insert(client_id);
if (stop_cpu_worker_set_.size() == fan_in_) {
is_exit_ = true;
}
return 0;
}
} // end namespace distributed
} // end namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <atomic>
#include <ctime>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <unordered_map>
#include <vector>
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/service/brpc_utils.h"
#include "paddle/fluid/distributed/service/sendrecv.pb.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace distributed {
using MultiVarMsg = ::paddle::MultiVariableMessage;
using VarMsg = ::paddle::VariableMessage;
class HeterService;
typedef int32_t (HeterService::*serviceHandlerFunc)(
const PsRequestMessage& request, PsResponseMessage& response,
brpc::Controller* cntl);
typedef std::function<void(void*)> HeterRpcCallbackFunc;
typedef std::function<int(const MultiVarMsg*, MultiVarMsg*, brpc::Controller*)>
HeterServiceHandler;
class HeterService : public ::paddle::PsService {
public:
HeterService() {
_service_handler_map[PS_STOP_SERVER] = &HeterService::stop_heter_worker;
_service_handler_map[PS_START_PROFILER] = &HeterService::start_profiler;
_service_handler_map[PS_STOP_PROFILER] = &HeterService::stop_profiler;
}
virtual ~HeterService() {}
virtual void service(::google::protobuf::RpcController* controller,
const ::paddle::PsRequestMessage* request,
::paddle::PsResponseMessage* response,
::google::protobuf::Closure* done) {
brpc::ClosureGuard done_guard(done);
std::string log_label("ReceiveCmd-");
response->set_err_code(0);
response->set_err_msg("");
brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
auto itr = _service_handler_map.find(request->cmd_id());
if (itr == _service_handler_map.end()) {
std::string err_msg(
"undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:");
err_msg.append(std::to_string(request->cmd_id()));
return;
}
serviceHandlerFunc handler_func = itr->second;
int service_ret = (this->*handler_func)(*request, *response, cntl);
if (service_ret != 0) {
response->set_err_code(service_ret);
response->set_err_msg("server internal error");
}
};
void SendAndRecvVariable(::google::protobuf::RpcController* controller,
const MultiVarMsg* request, MultiVarMsg* response,
::google::protobuf::Closure* done) {
brpc::ClosureGuard done_guard(done);
std::string message_name = request->message_name();
auto itr = handler_map_.find(message_name);
brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
PADDLE_ENFORCE_NE(
itr, handler_map_.end(),
platform::errors::InvalidArgument(
"HeterService::SendAndRecvVariable Get illegal message_name: %s "
"which is not in HeterService::handler_map_",
message_name));
itr->second(request, response, cntl);
}
void RegisterServiceHandler(std::string message_name,
HeterServiceHandler func) {
handler_map_[message_name] = func;
}
void SetEndpoint(const std::string& end_point) { endpoint_ = end_point; }
void SetFanin(const int& fan_in) { fan_in_ = fan_in; }
bool IsExit() { return is_exit_; }
private:
int32_t stop_profiler(const PsRequestMessage& request,
PsResponseMessage& response, brpc::Controller* cntl);
int32_t start_profiler(const PsRequestMessage& request,
PsResponseMessage& response, brpc::Controller* cntl);
int32_t stop_heter_worker(const PsRequestMessage& request,
PsResponseMessage& response,
brpc::Controller* cntl);
private:
std::string endpoint_;
std::unordered_map<std::string, HeterServiceHandler> handler_map_;
std::unordered_map<int32_t, serviceHandlerFunc> _service_handler_map;
std::unordered_set<int> stop_cpu_worker_set_;
int fan_in_;
bool is_exit_ = false;
};
class HeterServer {
public:
virtual ~HeterServer() {}
void Stop() {
server_.Stop(1000);
server_.Join();
}
bool IsExit() { return service_.IsExit(); }
HeterServer() {}
void RegisterServiceHandler(std::string message_name,
HeterServiceHandler func);
void StartHeterService();
void SetEndPoint(std::string& endpoint);
void SetFanin(int& fan_in);
// HeterWrapper singleton
static std::shared_ptr<HeterServer> GetInstance() {
if (NULL == s_instance_) {
s_instance_.reset(new HeterServer());
}
return s_instance_;
}
void WaitServerReady();
private:
static std::shared_ptr<HeterServer> s_instance_;
std::string endpoint_;
protected:
brpc::Server server_;
HeterService service_;
DISABLE_COPY_AND_ASSIGN(HeterServer);
std::mutex mutex_ready_;
std::condition_variable condition_ready_;
int ready_;
};
class HeterRequestHandler {
public:
HeterRequestHandler()
: dev_ctx_(nullptr),
executor_(nullptr),
scope_(nullptr),
program_(nullptr) {}
virtual ~HeterRequestHandler() {}
void SetScope(framework::Scope* scope) { scope_ = scope; }
void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }
void SetProgram(framework::ProgramDesc* program) { program_ = program; }
void SetExecutor(framework::Executor* executor) { executor_ = executor; }
void SetGradToPreparedCtx(
std::unordered_map<
std::string, std::shared_ptr<framework::ExecutorPrepareContext>>* g) {
message_to_prepared_ctx_ = g;
}
virtual int Handle(const MultiVarMsg* request, MultiVarMsg* response,
brpc::Controller* cntl) = 0;
protected:
const platform::DeviceContext* dev_ctx_;
framework::Executor* executor_;
framework::Scope* scope_;
framework::ProgramDesc* program_;
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>*
message_to_prepared_ctx_;
};
class RequestSendAndRecvHandler final : public HeterRequestHandler {
public:
RequestSendAndRecvHandler() {}
virtual ~RequestSendAndRecvHandler() {}
int Handle(const MultiVarMsg* request, MultiVarMsg* response,
brpc::Controller* cntl) override {
platform::RecordEvent record_event("RequestSendAndRecvHandler->Handle");
auto& local_scope = scope_->NewScope();
auto message_name = request->message_name();
auto& request_io_buffer = cntl->request_attachment();
distributed::DeserializeFromMultiVarMsgAndIOBuf(
*request, &request_io_buffer, *dev_ctx_, &local_scope);
executor_->RunPreparedContext(
(*message_to_prepared_ctx_)[message_name].get(), &local_scope, false);
auto response_var_nums = request->recv_var_names_size();
std::vector<std::string> response_var_names(response_var_nums),
empty_var_names{};
for (int var_idx = 0; var_idx < response_var_nums; ++var_idx) {
response_var_names[var_idx] = request->recv_var_names(var_idx);
}
auto& response_io_buffer = cntl->response_attachment();
distributed::SerializeToMultiVarMsgAndIOBuf(
message_name, response_var_names, empty_var_names, *dev_ctx_,
&local_scope, response, &response_io_buffer);
scope_->DeleteScope(&local_scope);
return 0;
}
};
} // end namespace distributed
} // end namespace paddle
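To tie the pieces above together, a heter worker would roughly register a handler, start the brpc service on a background thread and wait for readiness. The sketch below assumes the brpc-enabled build; the message name and endpoint are placeholders, and the handler only echoes the message name instead of running a prepared program (a real worker installs RequestSendAndRecvHandler::Handle):
// heter_server_setup_example.cc (illustrative only)
#include <string>
#include <thread>
#include "paddle/fluid/distributed/service/heter_server.h"
int main() {
  using paddle::distributed::HeterServer;
  using paddle::distributed::MultiVarMsg;
  auto server = HeterServer::GetInstance();
  std::string endpoint = "127.0.0.1:8500";
  int fan_in = 1;
  server->SetEndPoint(endpoint);
  server->SetFanin(fan_in);
  // trivial handler: echo the message name back to the caller
  server->RegisterServiceHandler(
      "forward_backward",
      [](const MultiVarMsg* request, MultiVarMsg* response,
         brpc::Controller*) -> int {
        response->set_message_name(request->message_name());
        return 0;
      });
  std::thread service_thread([&] { server->StartHeterService(); });
  server->WaitServerReady();
  // ... trainers call HeterClient::SendAndRecvAsync against this endpoint ...
  server->Stop();
  service_thread.join();
  return 0;
}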
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/service/ps_client.h"
#include <map>
#include "brpc/server.h"
#include "glog/logging.h"
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/table/table.h"
namespace paddle {
namespace distributed {
REGISTER_CLASS(PSClient, BrpcPsClient);
int32_t PSClient::configure(
const PSParameter &config,
const std::map<uint64_t, std::vector<paddle::distributed::Region>> &regions,
PSEnvironment &env, size_t client_id) {
_env = &env;
_config = config;
_dense_pull_regions = regions;
_client_id = client_id;
_config.mutable_worker_param()
->mutable_downpour_worker_param()
->mutable_downpour_table_param()
->CopyFrom(_config.server_param()
.downpour_server_param()
.downpour_table_param());
const auto &work_param = _config.worker_param().downpour_worker_param();
for (size_t i = 0; i < work_param.downpour_table_param_size(); ++i) {
auto *accessor = CREATE_CLASS(
ValueAccessor,
work_param.downpour_table_param(i).accessor().accessor_class());
accessor->configure(work_param.downpour_table_param(i).accessor());
accessor->initialize();
_table_accessors[work_param.downpour_table_param(i).table_id()].reset(
accessor);
}
return initialize();
}
PSClient *PSClientFactory::create(const PSParameter &ps_config) {
const auto &config = ps_config.server_param();
if (!config.has_downpour_server_param()) {
LOG(ERROR) << "miss downpour_server_param in ServerParameter";
return NULL;
}
if (!config.downpour_server_param().has_service_param()) {
LOG(ERROR) << "miss service_param in ServerParameter.downpour_server_param";
return NULL;
}
if (!config.downpour_server_param().service_param().has_client_class()) {
LOG(ERROR) << "miss client_class in "
"ServerParameter.downpour_server_param.service_param";
return NULL;
}
const auto &service_param = config.downpour_server_param().service_param();
PSClient *client = CREATE_CLASS(PSClient, service_param.client_class());
if (client == NULL) {
LOG(ERROR) << "client is not registered, server_name:"
<< service_param.client_class();
return NULL;
}
TableManager::instance().initialize();
LOG(INFO) << "Create PSClient[" << service_param.client_class()
<< "] success";
return client;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <future>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/service/env.h"
#include "paddle/fluid/distributed/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/table/accessor.h"
namespace paddle {
namespace distributed {
typedef std::function<void(void *)> PSClientCallBack;
class PSClientClosure : public google::protobuf::Closure {
public:
PSClientClosure(PSClientCallBack callback) : _callback(callback) {}
virtual ~PSClientClosure() {}
virtual void set_promise_value(int value) {
for (auto &promise : _promises) {
promise->set_value(value);
}
}
void add_promise(std::shared_ptr<std::promise<int32_t>> &promise) {
_promises.push_back(promise);
}
protected:
PSClientCallBack _callback;
std::vector<std::shared_ptr<std::promise<int32_t>>> _promises;
};
class PSClient {
public:
PSClient() {}
virtual ~PSClient() {}
PSClient(PSClient &&) = delete;
PSClient(const PSClient &) = delete;
virtual int32_t configure(
const PSParameter &config,
const std::map<uint64_t, std::vector<paddle::distributed::Region>>
&regions,
PSEnvironment &_env, size_t client_id) final;
virtual int32_t create_client2client_connection(
int pserver_timeout_ms, int pserver_connect_timeout_ms,
int max_retry) = 0;
// trigger data eviction (shrink) for the table
virtual std::future<int32_t> shrink(uint32_t table_id) = 0;
// load data for all tables
virtual std::future<int32_t> load(const std::string &epoch,
const std::string &mode) = 0;
// load data for the specified table
virtual std::future<int32_t> load(uint32_t table_id, const std::string &epoch,
const std::string &mode) = 0;
// save data of all tables; depending on mode the value_accessor may apply
// different save conditions
virtual std::future<int32_t> save(const std::string &epoch,
const std::string &mode) = 0;
// save data of the specified table; depending on mode the value_accessor may
// apply different save conditions
virtual std::future<int32_t> save(uint32_t table_id, const std::string &epoch,
const std::string &mode) = 0;
// clear table data
virtual std::future<int32_t> clear() = 0;
virtual std::future<int32_t> clear(uint32_t table_id) = 0;
// pull the dense part of the parameters and fill them block-wise into the
// local network parameters
// start and num are used to pull a subset of the parameters
// the keys/values buffers must not be reused before the future completes
// the client splits values into blocks and hands them to multiple senders
// each sender aggregates requests for the same block, accumulating multiple
// fill buffers
// the server extracts and returns the configured dimension of the parameter
// block; the returned data is unpacked and filled into the accumulated buffers
virtual std::future<int32_t> pull_dense(Region *regions, size_t region_num,
size_t table_id) = 0; // reserved
// first push of the dense params to the parameter server
// this is necessary because dense weights are initialized in the trainer on
// cold start
virtual std::future<int32_t> push_dense_param(const Region *regions,
size_t region_num,
size_t table_id) = 0;
// issue a pull request with keys and fill the results into values
// keys and values both hold num entries; each value occupies select_size bytes
// the keys/values buffers must not be reused before the future completes
// keys requested by multiple threads are merged, grouped and scattered to the
// servers; after the results return, the buffers are traversed and the values
// are filled in
virtual std::future<int32_t> pull_sparse(float **select_values,
size_t table_id,
const uint64_t *keys,
size_t num) = 0;
virtual std::future<int32_t> print_table_stat(uint32_t table_id) = 0;
// make sure all pending (batched) requests are actually sent
virtual std::future<int32_t> flush() = 0;
// graceful server shutdown
virtual std::future<int32_t> stop_server() = 0;
// server profiler
virtual std::future<int32_t> start_profiler() = 0;
virtual std::future<int32_t> stop_profiler() = 0;
virtual std::future<int32_t> barrier(size_t table_id,
uint32_t barrier_type) = 0;
virtual std::future<int32_t> pull_geo_param(size_t table_id,
std::vector<float> *values,
std::vector<uint64_t> *keys,
int pserver_idx) = 0;
virtual void finalize_worker() = 0;
// client-to-client message sending
virtual std::future<int32_t> send_client2client_msg(int msg_type,
int to_client_id,
const std::string &msg) {
LOG(FATAL) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
promise.set_value(-1);
return fut;
}
// client2client message handling: std::function<int32_t(int, int, const std::string&)>
// -> ret (msg_type, from_client_id, msg)
typedef std::function<int32_t(int, int, const std::string &)> MsgHandlerFunc;
virtual int registe_client2client_msg_handler(int msg_type,
MsgHandlerFunc handler) {
_msg_handler_map[msg_type] = handler;
return 0;
}
virtual int handle_client2client_msg(int msg_type, int from_client_id,
const std::string &msg) {
auto itr = _msg_handler_map.find(msg_type);
if (itr == _msg_handler_map.end()) {
LOG(WARNING) << "unknown client2client_msg type:" << msg_type;
return -1;
}
return itr->second(msg_type, from_client_id, msg);
}
virtual ValueAccessor *table_accessor(size_t table_id) {
auto itr = _table_accessors.find(table_id);
if (itr == _table_accessors.end()) {
return NULL;
}
return itr->second.get();
}
virtual size_t get_server_nums() = 0;
virtual std::future<int32_t> push_dense_raw_gradient(
int table_id, float *total_send_data, size_t total_send_data_size,
void *done) = 0;
virtual std::future<int32_t> push_sparse_raw_gradient(
size_t table_id, const uint64_t *keys, const float **update_values,
size_t num, void *done) = 0;
virtual std::future<int32_t> push_sparse_raw_gradient_partial(
size_t table_id, const uint64_t *keys, const float **update_values,
uint32_t num, void *done, int pserver_idx) = 0;
virtual std::future<int32_t> push_sparse_param(size_t table_id,
const uint64_t *keys,
const float **update_values,
size_t num, void *done) = 0;
protected:
virtual int32_t initialize() = 0;
size_t _client_id;
PSParameter _config;
std::map<uint64_t, std::vector<paddle::distributed::Region>>
_dense_pull_regions;
PSEnvironment *_env;
std::unordered_map<uint32_t, std::shared_ptr<ValueAccessor>> _table_accessors;
std::unordered_map<int32_t, MsgHandlerFunc>
_msg_handler_map; // handles client2client messages
};
REGISTER_REGISTERER(PSClient);
class PSClientFactory {
public:
static PSClient *create(const PSParameter &config);
};
} // namespace distributed
} // namespace paddle
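For completeness, a sketch of how a PSClient instance is obtained through the factory declared above. The prototxt literal only fills the fields that PSClientFactory::create checks; a real PSParameter carries table, accessor and worker sections as well, so the parse may require more fields than shown here and the failure branch covers that case:
// ps_client_factory_example.cc (illustrative only)
#include <google/protobuf/text_format.h>
#include <iostream>
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/service/ps_client.h"
int main() {
  paddle::distributed::PSParameter param;
  const char* conf =
      "server_param {"
      "  downpour_server_param {"
      "    service_param { client_class: 'BrpcPsClient'"
      "                    server_class: 'BrpcPsServer' }"
      "  }"
      "}";
  if (!google::protobuf::TextFormat::ParseFromString(conf, &param)) {
    std::cerr << "incomplete PSParameter config" << std::endl;
    return -1;
  }
  paddle::distributed::PSClient* client =
      paddle::distributed::PSClientFactory::create(param);
  // configure(...) with regions, a PSEnvironment and a client id would follow
  return client == nullptr ? -1 : 0;
}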
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package paddle;
option cc_generic_services = true;
option cc_enable_arenas = true;
enum PsCmdID {
PS_PULL_DENSE_TABLE = 0;
PS_PUSH_DENSE_TABLE = 1;
PS_PULL_SPARSE_TABLE = 2;
PS_PUSH_SPARSE_TABLE = 3;
PS_SHRINK_TABLE = 4;
PS_SAVE_ONE_TABLE = 5;
PS_SAVE_ALL_TABLE = 6;
PS_LOAD_ONE_TABLE = 7;
PS_LOAD_ALL_TABLE = 8;
PS_CLEAR_ONE_TABLE = 9;
PS_CLEAR_ALL_TABLE = 10;
PS_PUSH_DENSE_PARAM = 11;
PS_STOP_SERVER = 12;
PS_SAVE_ONE_CACHE_TABLE = 13;
PS_GET_CACHE_THRESHOLD = 14;
PS_CACHE_SHUFFLE = 15;
PS_COPY_TABLE = 16;
PS_COPY_TABLE_BY_FEASIGN = 17;
PS_PULL_SPARSE_TABLE_WITH_DEPENDENCY = 18;
PS_PUSH_SPARSE_TABLE_WITH_DEPENDENCY = 19;
PS_PRINT_TABLE_STAT = 20;
PS_SAVE_ONE_TABLE_PREFIX = 21;
PS_SAVE_ONE_TABLE_WITH_WHITELIST = 22;
PS_LOAD_ONE_TABLE_WITH_WHITELIST = 23;
PS_PULL_GEO_PARAM = 24;
PS_BARRIER = 25;
PS_PUSH_SPARSE_PARAM = 26;
PS_START_PROFILER = 27;
PS_STOP_PROFILER = 28;
}
message PsRequestMessage {
required uint32 cmd_id = 1;
optional uint32 table_id = 2;
repeated bytes params = 3;
optional int32 client_id = 4;
optional bytes data = 5;
};
message PsResponseMessage {
required int32 err_code = 1 [ default = 0 ];
required string err_msg = 2 [ default = "" ];
optional bytes data = 3;
};
enum VarType {
LOD_TENSOR = 0;
SELECTED_ROWS = 1;
}
message VariableMessage {
enum Type {
// Pod Types
BOOL = 0;
INT16 = 1;
INT32 = 2;
INT64 = 3;
FP16 = 4;
FP32 = 5;
FP64 = 6;
}
message LodData { repeated int64 lod_data = 1; }
optional string varname = 1;
// TODO(Yancey1989): reference framework::proto::VarDesc::VarType
optional VarType type = 2;
// bool persistable is not needed for sending.
// tensor info:
optional Type data_type = 3;
repeated int64 dims = 4;
// lod details:
optional int64 lod_level = 5;
repeated LodData lod = 6;
// selected_rows height, aka. original dim0
optional int64 slr_height = 7;
// tensor data
optional bytes data = 8;
}
// for SendAndRecv RPC method
message MultiVariableMessage {
// message flags
required string message_name = 1;
repeated string send_var_names = 2;
repeated string recv_var_names = 3;
repeated VariableMessage var_messages = 4;
};
service PsService {
rpc service(PsRequestMessage) returns (PsResponseMessage);
rpc SendAndRecvVariable(MultiVariableMessage) returns (MultiVariableMessage);
};
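The cmd-style half of this service is driven by PsRequestMessage; for example, a save request identifies the target table and passes strings such as a path and mode through params. A small C++ sketch against the generated sendrecv.pb.h (the params shown are illustrative):
// ps_request_example.cc (illustrative only)
#include <iostream>
#include "paddle/fluid/distributed/service/sendrecv.pb.h"
int main() {
  paddle::PsRequestMessage request;
  request.set_cmd_id(paddle::PS_SAVE_ONE_TABLE);  // 5
  request.set_table_id(0);
  request.set_client_id(0);
  request.add_params("/path/to/model");  // e.g. save directory / epoch
  request.add_params("0");               // e.g. save mode
  std::cout << request.DebugString();
  return 0;
}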
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/service/server.h"
#include "glog/logging.h"
#include "paddle/fluid/distributed/service/brpc_ps_server.h"
#include "paddle/fluid/distributed/table/table.h"
namespace paddle {
namespace distributed {
REGISTER_CLASS(PSServer, BrpcPsServer);
REGISTER_CLASS(PsBaseService, PsService);
PSServer *PSServerFactory::create(const PSParameter &ps_config) {
const auto &config = ps_config.server_param();
if (!config.has_downpour_server_param()) {
LOG(ERROR) << "miss downpour_server_param in ServerParameter";
return NULL;
}
if (!config.downpour_server_param().has_service_param()) {
LOG(ERROR) << "miss service_param in ServerParameter.downpour_server_param";
return NULL;
}
if (!config.downpour_server_param().service_param().has_server_class()) {
LOG(ERROR) << "miss server_class in "
"ServerParameter.downpour_server_param.service_param";
return NULL;
}
const auto &service_param = config.downpour_server_param().service_param();
PSServer *server = CREATE_CLASS(PSServer, service_param.server_class());
if (server == NULL) {
LOG(ERROR) << "server is not registered, server_name:"
<< service_param.server_class();
return NULL;
}
TableManager::instance().initialize();
return server;
}
int32_t PSServer::configure(const PSParameter &config, PSEnvironment &env,
size_t server_rank) {
_config = config.server_param();
_rank = server_rank;
_environment = &env;
_shuffled_ins =
paddle::framework::MakeChannel<std::pair<uint64_t, std::string>>();
const auto &downpour_param = _config.downpour_server_param();
uint32_t barrier_table = UINT32_MAX;
for (size_t i = 0; i < downpour_param.downpour_table_param_size(); ++i) {
auto *table = CREATE_CLASS(
Table, downpour_param.downpour_table_param(i).table_class());
if (downpour_param.downpour_table_param(i).table_class() ==
"BarrierTable") {
barrier_table = downpour_param.downpour_table_param(i).table_id();
}
table->initialize(downpour_param.downpour_table_param(i),
config.fs_client_param());
_table_map[downpour_param.downpour_table_param(i).table_id()].reset(table);
}
if (barrier_table != UINT32_MAX) {
_table_map[barrier_table]->set_table_map(&_table_map);
}
return initialize();
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <future>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "butil/endpoint.h"
#include "google/protobuf/service.h"
#include "paddle/fluid/distributed/common/registerer.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/service/env.h"
#include "paddle/fluid/distributed/service/sendrecv.pb.h"
#include "paddle/fluid/framework/channel.h"
namespace paddle {
namespace distributed {
class Table;
class PSServer {
public:
PSServer() {}
virtual ~PSServer() {}
PSServer(PSServer &&) = delete;
PSServer(const PSServer &) = delete;
virtual int32_t configure(const PSParameter &config, PSEnvironment &env,
size_t server_rank) final;
// return server_ip
virtual std::string ip() { return butil::my_ip_cstr(); }
// return server_port
virtual int32_t port() = 0;
virtual uint64_t start(const std::string &ip, uint32_t port) = 0;
virtual int32_t stop() = 0;
inline size_t rank() const { return _rank; }
inline PSEnvironment *environment() { return _environment; }
inline const ServerParameter *config() const { return &_config; }
inline Table *table(size_t table_id) {
auto itr = _table_map.find(table_id);
if (itr != _table_map.end()) {
return itr->second.get();
}
return NULL;
}
inline std::unordered_map<uint32_t, std::shared_ptr<Table>> *table() {
return &_table_map;
}
typedef std::function<int32_t(int, int, const std::string &)> MsgHandlerFunc;
virtual int registe_pserver2pserver_msg_handler(int msg_type,
MsgHandlerFunc handler) {
_msg_handler_map[msg_type] = handler;
return 0;
}
paddle::framework::Channel<std::pair<uint64_t, std::string>> _shuffled_ins;
protected:
virtual int32_t initialize() = 0;
protected:
size_t _rank;
ServerParameter _config;
PSEnvironment *_environment;
std::unordered_map<uint32_t, std::shared_ptr<Table>> _table_map;
std::unordered_map<int32_t, MsgHandlerFunc> _msg_handler_map;
};
REGISTER_REGISTERER(PSServer);
typedef std::function<void(void *)> PServerCallBack;
class PServerClosure : public google::protobuf::Closure {
public:
PServerClosure(PServerCallBack callback) : _callback(callback) {}
virtual ~PServerClosure() {}
virtual void set_promise_value(int value) {
for (auto &promise : _promises) {
promise->set_value(value);
}
}
void add_promise(std::shared_ptr<std::promise<int32_t>> &promise) {
_promises.push_back(promise);
}
protected:
PServerCallBack _callback;
std::vector<std::shared_ptr<std::promise<int32_t>>> _promises;
};
class PsBaseService : public PsService {
public:
PsBaseService() : _rank(0), _server(NULL), _config(NULL) {}
virtual ~PsBaseService() {}
virtual int32_t configure(PSServer *server) {
_server = server;
_rank = _server->rank();
_config = _server->config();
return 0;
}
virtual void service(::google::protobuf::RpcController *controller,
const ::paddle::PsRequestMessage *request,
::paddle::PsResponseMessage *response,
::google::protobuf::Closure *done) override = 0;
virtual void set_response_code(PsResponseMessage &response, int err_code,
const char *err_msg) {
response.set_err_msg(err_msg);
response.set_err_code(err_code);
LOG(WARNING) << "Resonse err_code:" << err_code << " msg:" << err_msg;
}
virtual int32_t initialize() = 0;
protected:
size_t _rank;
PSServer *_server;
const ServerParameter *_config;
};
REGISTER_REGISTERER(PsBaseService);
class PSServerFactory {
public:
static PSServer *create(const PSParameter &config);
};
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/service/service.h"
#include <fcntl.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/text_format.h>
#include <iostream>
#include "paddle/fluid/distributed/service/communicator.h"
#include "paddle/fluid/string/string_helper.h"
using namespace std;
namespace paddle {
namespace distributed {
paddle::distributed::PSParameter load_from_prototxt(
const std::string& filename) {
paddle::distributed::PSParameter param;
int file_descriptor = open(filename.c_str(), O_RDONLY);
if (file_descriptor == -1) {
VLOG(3) << "FATAL: fail to parse " << filename;
exit(-1);
}
google::protobuf::io::FileInputStream fileInput(file_descriptor);
if (!google::protobuf::TextFormat::Parse(&fileInput, &param)) {
VLOG(3) << "FATAL: fail to parse " << filename;
exit(-1);
}
close(file_descriptor);
return param;
}
void PSCore::init_gflag(const std::string& gflags) {
LOG(INFO) << "Init With Gflags:" << gflags;
std::vector<std::string> flags = paddle::string::split_string(gflags);
if (flags.size() < 1) {
flags.push_back("-max_body_size=314217728");
flags.push_back("-bthread_concurrency=40");
flags.push_back("-socket_max_unwritten_bytes=2048000000");
flags.push_back("-max_connection_pool_size=1950");
}
auto it = flags.begin();
flags.insert(it, "exe default");
char* flags_ptr[flags.size()];
for (size_t i = 0; i < flags.size(); ++i) {
flags_ptr[i] = (char*)(flags[i].c_str());
}
int params_cnt = flags.size();
char** params_ptr = &(flags_ptr[0]);
::google::ParseCommandLineFlags(&params_cnt, &params_ptr, true);
}
int PSCore::init_server(const std::string& dist_desc,
const std::vector<std::string>* host_sign_list,
int node_num, int index) {
google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param);
init_gflag(_ps_param.init_gflags());
_ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(host_sign_list, node_num);
int ret = 0;
_server_ptr = std::shared_ptr<paddle::distributed::PSServer>(
paddle::distributed::PSServerFactory::create(_ps_param));
ret = _server_ptr->configure(_ps_param, _ps_env, index);
CHECK(ret == 0) << "failed to configure server";
return ret;
}
int PSCore::init_worker(
const std::string& dist_desc,
const std::map<uint64_t, std::vector<paddle::distributed::Region>>& regions,
const std::vector<std::string>* host_sign_list, int node_num, int index) {
google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param);
init_gflag(_ps_param.init_gflags());
_ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(host_sign_list, node_num);
int ret = 0;
VLOG(1) << "PSCore::init_worker";
auto* communicator = Communicator::GetInstance();
ret = communicator->GetPsClient()->configure(_ps_param, regions, _ps_env,
index);
communicator->Start();
return ret;
}
std::vector<uint64_t> PSCore::get_client_info() {
return _ps_env.get_client_info();
}
int PSCore::create_client2client_connection(int pserver_timeout_ms,
int pserver_connect_timeout_ms,
int max_retry) {
int ret = _worker_ptr->create_client2client_connection(
pserver_timeout_ms, pserver_connect_timeout_ms, max_retry);
return ret;
}
uint64_t PSCore::run_server(const std::string& ip, uint32_t port) {
return _server_ptr->start(ip, port);
}
int PSCore::finalize_worker() {
_worker_ptr->finalize_worker();
return 0;
}
int PSCore::stop_server() {
auto stop_status = _worker_ptr->stop_server();
stop_status.wait();
return 0;
}
paddle::distributed::PSParameter* PSCore::get_param() { return &_ps_param; }
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <map>
#include <memory>
#include <string>
#include <vector>
#include <glog/logging.h>
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/service/ps_client.h"
#include "paddle/fluid/distributed/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/service/server.h"
namespace paddle {
namespace distributed {
class PSCore {
public:
explicit PSCore() {}
virtual ~PSCore() {}
virtual int init_server(const std::string& dist_desc,
const std::vector<std::string>* host_sign_list,
int node_num, int index);
virtual int init_worker(
const std::string& dist_desc,
const std::map<uint64_t, std::vector<paddle::distributed::Region>>&
regions,
const std::vector<std::string>* host_sign_list, int node_num, int index);
virtual uint64_t run_server(const std::string& ip, uint32_t port);
virtual int stop_server();
virtual int finalize_worker();
virtual std::vector<uint64_t> get_client_info();
virtual int create_client2client_connection(int pserver_timeout_ms,
int pserver_connect_timeout_ms,
int max_retry);
std::shared_ptr<paddle::distributed::PSServer>
_server_ptr; // pointer to server
std::shared_ptr<paddle::distributed::PSClient>
_worker_ptr; // pointer to worker
virtual paddle::distributed::PSParameter* get_param();
private:
void init_gflag(const std::string& gflags);
paddle::distributed::PSParameter _ps_param;
paddle::distributed::PaddlePSEnvironment _ps_env;
};
} // namespace distributed
} // namespace paddle
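A rough server-side startup sequence for PSCore as declared above. It assumes a valid PSParameter text-proto file on disk and a brpc-enabled build; shutdown is normally triggered from the worker side through PSClient::stop_server():
// pscore_server_example.cc (illustrative only)
#include <fstream>
#include <sstream>
#include <string>
#include <vector>
#include "paddle/fluid/distributed/service/service.h"
int main() {
  // ps_config.prototxt is an assumed PSParameter text-proto config
  std::ifstream fin("ps_config.prototxt");
  std::stringstream buffer;
  buffer << fin.rdbuf();
  std::string dist_desc = buffer.str();
  // endpoints use the "ip:port:rank" form accepted by PSHost::parse_from_string
  std::vector<std::string> host_sign_list = {"127.0.0.1:8600:0"};
  paddle::distributed::PSCore core;
  if (core.init_server(dist_desc, &host_sign_list, /*node_num=*/1,
                       /*index=*/0) != 0) {
    return -1;
  }
  core.run_server("127.0.0.1", 8600);
  return 0;
}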
set_property(GLOBAL PROPERTY TABLE_DEPS string_helper)
get_property(TABLE_DEPS GLOBAL PROPERTY TABLE_DEPS)
set_source_files_properties(common_dense_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(common_sparse_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(sparse_geo_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(barrier_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(common_table SRCS common_sparse_table.cc common_dense_table.cc sparse_geo_table.cc barrier_table.cc DEPS ${TABLE_DEPS} device_context string_helper simple_threadpool xxhash generator)
set_source_files_properties(tensor_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(tensor_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(tensor_accessor SRCS tensor_accessor.cc DEPS ${TABLE_DEPS} eigen3 ps_framework_proto device_context)
cc_library(tensor_table SRCS tensor_table.cc DEPS ps_framework_proto proto_desc enforce executor tensor device_context simple_threadpool gflags glog )
set_source_files_properties(table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(table SRCS table.cc DEPS common_table tensor_table tensor_accessor ps_framework_proto string_helper device_context gflags glog boost)
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdint.h>
#include <stdio.h>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/common/registerer.h"
#include "paddle/fluid/distributed/ps.pb.h"
namespace paddle {
namespace distributed {
struct FsDataConverter {
std::string converter;
std::string deconverter;
};
struct Region {
Region() : data(NULL), size(0) {}
Region(char* data, size_t data_num) : data(data), size(data_num) {}
Region(float* data, size_t data_num)
: data(reinterpret_cast<char*>(data)), size(data_num << 2) {}
Region(int16_t* data, size_t data_num)
: data(reinterpret_cast<char*>(data)), size(data_num << 1) {}
Region(int32_t* data, size_t data_num)
: data(reinterpret_cast<char*>(data)), size(data_num << 2) {}
Region(int64_t* data, size_t data_num)
: data(reinterpret_cast<char*>(data)), size(data_num << 3) {}
char* data;
size_t size;
};
struct DataConverter {
int param;
std::string converter;
std::string deconverter;
};
class ValueAccessor {
public:
explicit ValueAccessor(){};
virtual ~ValueAccessor(){};
virtual int configure(const TableAccessorParameter& parameter) {
_config = parameter;
// initialize the data converter structs
if (_config.table_accessor_save_param_size() != 0) {
for (int i = 0; i < _config.table_accessor_save_param_size(); ++i) {
int param = _config.table_accessor_save_param(i).param();
std::string converter =
_config.table_accessor_save_param(i).converter();
std::string deconverter =
_config.table_accessor_save_param(i).deconverter();
_data_coverter_map[param] = std::make_shared<DataConverter>();
*(_data_coverter_map[param]) = {param, converter, deconverter};
}
}
return 0;
}
virtual int initialize() = 0;
// dimension of value
virtual size_t dim() = 0;
// size of each dimension of value
virtual size_t dim_size(size_t dim) = 0;
// total size of value, summed over all dimensions
virtual size_t size() = 0;
// total size of the variable-length mf part of value; only used for sparse tables
virtual size_t mf_size() { return 0; }
virtual bool need_extend_mf(float* value) { return false; }
virtual bool has_mf(size_t size) { return false; }
// dimension of the pull value
virtual size_t select_dim() = 0;
// size of each dimension of the pull value
virtual size_t select_dim_size(size_t dim) = 0;
// total size of the pull value, summed over all dimensions
virtual size_t select_size() = 0;
// dimension of the push value
virtual size_t update_dim() = 0;
// size of each dimension of the push value
virtual size_t update_dim_size(size_t dim) = 0;
// total size of the push value, summed over all dimensions
virtual size_t update_size() = 0;
// fea total for dense
virtual size_t fea_dim() { return _config.fea_dim(); }
// converter for save
virtual std::string get_converter(int param) {
auto itr = _data_coverter_map.find(param);
if (itr == _data_coverter_map.end()) {
return "";
} else {
return (*itr).second->converter;
}
}
// deconverter for load
virtual std::string get_deconverter(int param) {
auto itr = _data_coverter_map.find(param);
if (itr == _data_coverter_map.end()) {
return "";
} else {
return (*itr).second->deconverter;
}
}
// decide whether this value should be shrunk
virtual bool shrink(float* value) = 0;
// decide whether this value should be dumped in the save phase;
// param identifies the save phase, e.g. downpour's xbox vs. batch_model
virtual bool save(float* value, int param) = 0;
// update delta_score and unseen_days after save
virtual void update_stat_after_save(float* value, int param) {}
// when a key does not exist, generate random values for it
virtual int32_t create(float** value, size_t num) = 0;
virtual bool create_value(int type, const float* value) { return true; }
// select entries from values into select_values
virtual int32_t select(float** select_values, const float** values,
size_t num) = 0;
// merge update_values together
virtual int32_t merge(float** update_values,
const float** other_update_values, size_t num) = 0;
// merge update_values together, using it.next to decide whether to move on
// to the next key
// virtual int32_t merge(float** update_values, iterator it);
// apply update_values to values
virtual int32_t update(float** values, const float** update_values,
size_t num) = 0;
// used to save model, will filter feature
virtual std::string parse_to_string(const float* value, int param) = 0;
// parse value from string, used to load model
virtual int32_t parse_from_string(const std::string& data, float* value) = 0;
virtual FsDataConverter converter(int param) {
FsDataConverter data_convert;
data_convert.converter = this->get_converter(param);
data_convert.deconverter = this->get_deconverter(param);
return data_convert;
}
virtual int set_weight(float** values, const float** update_values,
size_t num) {
return 0;
}
virtual float get_field(float* value, const std::string& name) { return 0.0; }
protected:
size_t _value_size;
size_t _select_value_size;
size_t _update_value_size;
TableAccessorParameter _config;
std::unordered_map<int, std::shared_ptr<struct DataConverter>>
_data_coverter_map;
};
REGISTER_REGISTERER(ValueAccessor);
} // namespace distributed
} // namespace paddle
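Region above is just a typed view over a raw buffer, with the constructors converting element counts into byte sizes (num << 2 for float). A tiny illustration of how a dense buffer would be wrapped before being handed to pull_dense / push_dense_param:
// region_example.cc (illustrative only)
#include <iostream>
#include <vector>
#include "paddle/fluid/distributed/table/accessor.h"
int main() {
  std::vector<float> dense_param(256, 0.0f);
  // the float overload stores data as char* and size in bytes: 256 * 4 = 1024
  paddle::distributed::Region region(dense_param.data(), dense_param.size());
  std::cout << "bytes: " << region.size << std::endl;
  return 0;
}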
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <chrono> // NOLINT
#include "paddle/fluid/distributed/common/utils.h"
#include "paddle/fluid/distributed/table/common_table.h"
namespace paddle {
namespace distributed {
int32_t BarrierTable::initialize() {
auto trainers = _config.common().trainer_num();
trigger_.store(trainers);
for (int x = 0; x < trainers; ++x) {
trainer_all_.insert(x);
}
VLOG(1) << "BarrierTable init trigger: " << trigger_.load();
return 0;
}
// 0: send_barrier 1: recv_barrier 2: complete
int32_t BarrierTable::barrier(const uint32_t trainer_id,
const std::string barrier_type) {
std::unique_lock<std::mutex> lock(mutex_);
if (barrier_type == "2") {
trigger_.fetch_sub(1, std::memory_order::memory_order_relaxed);
VLOG(1) << "trigger sub to : " << trigger_.load();
} else {
trainer_ids_.insert(trainer_id);
VLOG(1) << "barrier type: " << barrier_type
<< " add trainer id: " << trainer_id;
}
if (trainer_ids_.size() < trigger_.load()) {
std::vector<uint32_t> diffs(trainer_all_.size());
auto iter = std::set_difference(trainer_all_.begin(), trainer_all_.end(),
trainer_ids_.begin(), trainer_ids_.end(),
diffs.begin());
diffs.resize(iter - diffs.begin());
auto diff = to_string<uint32_t>(diffs);
VLOG(1) << "still need trainers: " << diff;
trainer_wait_.wait(lock, [&] { return trainer_ids_.size() == 0; });
} else {
VLOG(1) << "barrier table optimize begin";
for (auto& x : *table_map_) {
auto table = x.second;
table->pour();
}
VLOG(1) << "barrier table optimize done";
trainer_ids_.clear();
trainer_wait_.notify_all();
}
return 0;
}
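// Illustrative flow (a sketch; trainer ids and counts are hypothetical): with
// trainer_num = 2, each trainer calls barrier(id, "0"/"1") after pushing its
// updates. The first arrival blocks on trainer_wait_; the last arrival runs
// pour() on every table registered via set_table_map and then wakes the
// waiters. A trainer that has finished calls barrier(id, "2"), which lowers
// trigger_ so later rounds only wait for the remaining trainers.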
int32_t BarrierTable::set_table_map(
std::unordered_map<uint32_t, std::shared_ptr<Table>>* table_map) {
table_map_ = table_map;
return 0;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/table/common_dense_table.h"
#include "paddle/fluid/distributed/common/utils.h"
namespace paddle {
namespace distributed {
void CommonDenseTable::create_initializer(const std::string& attr,
const std::string& name) {
auto slices = string::split_string<std::string>(attr, "&");
if (slices[0] == "gaussian_random") {
initializers_[name] = new GaussianInitializer(slices);
} else if (slices[0] == "fill_constant") {
initializers_[name] = new FillConstantInitializer(slices);
} else if (slices[0] == "uniform_random") {
initializers_[name] = new UniformInitializer(slices);
} else {
PADDLE_THROW(
platform::errors::InvalidArgument("%s can not be supported", name));
}
}
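// Example attribute strings this dispatcher accepts, assuming the "&"-separated
// layouts parsed by the initializer constructors (name&seed&arg1&arg2 for the
// random initializers, name&value for fill_constant); the numbers are only
// illustrative:
//   "gaussian_random&0&0.0&0.01"
//   "uniform_random&0&-1.0&1.0"
//   "fill_constant&0.0"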
int32_t CommonDenseTable::initialize() {
_shards_task_pool.resize(task_pool_size_);
for (int i = 0; i < _shards_task_pool.size(); ++i) {
_shards_task_pool[i].reset(new ::ThreadPool(1));
}
sync = _config.common().sync();
VLOG(1) << "table " << _config.common().table_name() << " is sync: " << sync;
initialize_value();
initialize_optimizer();
return 0;
}
int32_t CommonDenseTable::initialize_value() {
auto common = _config.common();
int size = static_cast<int>(common.params().size());
values_.resize(size);
for (int x = 0; x < size; ++x) {
auto& varname = common.params()[x];
auto& dim = common.dims()[x];
if (varname == "Param") {
param_dim_ = dim;
param_idx_ = x;
}
auto& initializer = common.initializers()[x];
create_initializer(initializer, varname);
values_[x].resize(dim);
names_index_[varname] = x;
for (int y = 0; y < dim; ++y) {
values_[x][y] = initializers_[varname]->GetValue();
}
}
pull_reservoir_ = ReservoirValue<float>(param_dim_);
return 0;
}
int32_t CommonDenseTable::initialize_optimizer() {
auto common = _config.common();
auto name = common.name();
auto attrs = common.attributes();
if (name == "sgd") {
optimizer_ = std::make_shared<DSGD>(common, &values_);
} else if (name == "adam") {
optimizer_ = std::make_shared<DAdam>(common, &values_);
} else if (name == "sum") {
optimizer_ = std::make_shared<DSUM>(common, &values_);
} else {
VLOG(0) << "init optimizer failed";
}
VLOG(0) << "init optimizer " << name << " done";
return 0;
}
int32_t CommonDenseTable::pull_dense(float* pull_values, size_t num) {
std::copy(values_[param_idx_].begin(), values_[param_idx_].end(),
pull_values);
return 0;
}
int32_t CommonDenseTable::push_dense_param(const float* values, size_t num) {
PADDLE_ENFORCE_GE(
num, param_dim_,
paddle::platform::errors::InvalidArgument(
"update desne param numel expected %d, but got %d", param_dim_, num));
std::copy_n(values, param_dim_, values_[param_idx_].begin());
return 0;
}
int32_t CommonDenseTable::pour() {
_push_dense(pull_reservoir_.values.data(), pull_reservoir_.values.size());
pull_reservoir_.reset();
return 0;
}
int32_t CommonDenseTable::push_dense(const float* values, size_t num) {
if (sync) {
std::future<int> task =
_shards_task_pool[0]->enqueue([this, &values]() -> int {
pull_reservoir_.add(values, param_dim_);
return 0;
});
task.wait();
} else {
_push_dense(values, num);
}
return 0;
}
int32_t CommonDenseTable::_push_dense(const float* values, size_t num) {
PADDLE_ENFORCE_GE(
num, param_dim_,
paddle::platform::errors::InvalidArgument(
"update desne numel expected %d, but got %d", param_dim_, num));
std::vector<int> buckets = bucket(param_dim_, task_pool_size_);
std::vector<std::future<int>> tasks(task_pool_size_);
for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) {
tasks[shard_id] = _shards_task_pool[shard_id]->enqueue(
[this, shard_id, &buckets, &values]() -> int {
auto begin = buckets[shard_id];
auto end = buckets[shard_id + 1];
optimizer_->update(values, param_dim_, begin, end);
return 0;
});
}
for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
tasks[shard_id].wait();
}
return 0;
}
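// Note: bucket() is expected to return task_pool_size_ + 1 monotonically
// increasing boundaries over [0, param_dim_], so shard i applies the optimizer
// to the half-open slice [buckets[i], buckets[i + 1]) of the parameter, each
// slice on its own single-threaded pool so no locking is needed.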
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ThreadPool.h>
#include <assert.h>
#include <pthread.h>
#include <string>
#include "Eigen/Dense"
#include "paddle/fluid/distributed/table/accessor.h"
#include "paddle/fluid/distributed/table/common_table.h"
#include "paddle/fluid/distributed/table/depends/dense.h"
#include "paddle/fluid/distributed/table/depends/initializers.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace distributed {
class CommonDenseTable : public DenseTable {
public:
explicit CommonDenseTable() {}
virtual ~CommonDenseTable() {}
virtual int32_t initialize() override;
virtual int32_t initialize_shard() override { return 0; }
virtual void create_initializer(const std::string& attr,
const std::string& name);
virtual int32_t initialize_value();
virtual int32_t initialize_optimizer();
virtual int32_t pull_dense(float* pull_values, size_t num) override;
virtual int32_t push_dense_param(const float* values, size_t num) override;
virtual int32_t push_dense(const float* values, size_t num) override;
virtual int32_t pour() override;
int32_t load(const std::string& path, const std::string& param) override {
VLOG(0) << "Dense table may load by "
"paddle.distributed.fleet.init_server";
return 0;
}
int32_t save(const std::string& path, const std::string& param) override {
VLOG(0)
<< "Dense table may be saved by "
"paddle.distributed.fleet.save_persistables/save_inference_model";
return 0;
}
virtual int32_t flush() override { return 0; }
virtual int32_t shrink() override { return 0; }
virtual void clear() override { return; }
protected:
int32_t _push_dense(const float* values, size_t num);
private:
const int task_pool_size_ = 1;
bool sync = true;
std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
int param_dim_ = 0;
int param_idx_ = 0;
std::shared_ptr<DenseOptimizer> optimizer_;
std::vector<std::vector<float>> values_;
ReservoirValue<float> pull_reservoir_;
std::unordered_map<std::string, Initializer*> initializers_;
std::unordered_map<std::string, int> names_index_;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/table/common_sparse_table.h"
#include <algorithm>
#include <sstream>
#include "paddle/fluid/distributed/common/utils.h"
#include "paddle/fluid/distributed/table/depends/large_scale_kv.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace distributed {
struct Meta {
std::string param;
int shard_id;
std::vector<std::string> names;
std::vector<int> dims;
uint64_t count;
std::unordered_map<std::string, int> dims_map;
explicit Meta(const std::string& metapath) {
std::ifstream file(metapath);
std::string line;
int num_lines = 0;
while (std::getline(file, line)) {
if (StartWith(line, "#")) {
continue;
}
auto pairs = paddle::string::split_string<std::string>(line, "=");
PADDLE_ENFORCE_EQ(
pairs.size(), 2,
paddle::platform::errors::InvalidArgument(
"info in %s except k=v, but got %s", metapath, line));
if (pairs[0] == "param") {
param = pairs[1];
}
if (pairs[0] == "shard_id") {
shard_id = std::stoi(pairs[1]);
}
if (pairs[0] == "row_names") {
names = paddle::string::split_string<std::string>(pairs[1], ",");
}
if (pairs[0] == "row_dims") {
auto dims_strs =
paddle::string::split_string<std::string>(pairs[1], ",");
for (auto& str : dims_strs) {
dims.push_back(std::stoi(str));
}
}
if (pairs[0] == "count") {
count = std::stoull(pairs[1]);
}
}
for (int x = 0; x < names.size(); ++x) {
dims_map[names[x]] = dims[x];
}
}
Meta(std::string param, int shard_id, std::vector<std::string> row_names,
std::vector<int> dims, uint64_t count) {
this->param = param;
this->shard_id = shard_id;
this->names = row_names;
this->dims = dims;
this->count = count;
}
std::string ToString() {
std::stringstream ss;
ss << "param=" << param << "\n";
ss << "shard_id=" << shard_id << "\n";
ss << "row_names=" << paddle::string::join_strings(names, ',') << "\n";
ss << "row_dims=" << paddle::string::join_strings(dims, ',') << "\n";
ss << "count=" << count << "\n";
return ss.str();
}
};
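// A meta file accepted by the parser above looks like (values illustrative):
//   param=embedding
//   shard_id=0
//   row_names=Param,Moment1,Moment2
//   row_dims=8,8,8
//   count=1024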
void ProcessALine(const std::vector<std::string>& columns, const Meta& meta,
std::vector<std::vector<float>>* values) {
PADDLE_ENFORCE_EQ(columns.size(), meta.names.size() + 1,
paddle::platform::errors::InvalidArgument(
"record in txt do not match meta."));
values->reserve(columns.size() - 1);
for (int x = 1; x < columns.size(); ++x) {
auto& column = columns[x];
auto val_ = paddle::string::split_string<std::string>(column, ",");
std::vector<float> val;
std::transform(val_.begin(), val_.end(), std::back_inserter(val),
[](std::string va) { return std::stof(va); });
PADDLE_ENFORCE_EQ(val.size(), meta.dims[x - 1],
paddle::platform::errors::InvalidArgument(
"record in txt do not match meta."));
values->push_back(val);
}
}
int64_t SaveToText(std::ostream* os, std::shared_ptr<ValueBlock> block,
const std::vector<std::string>& saved_names,
const int mode) {
for (auto value : block->values_) {
std::vector<std::vector<float>*> vss = value.second->get(saved_names);
std::stringstream ss;
auto id = value.first;
ss << id << "\t";
for (int i = 0; i < static_cast<int>(vss.size()); i++) {
auto& vs = vss[i];
ss << paddle::string::join_strings((*vs), ',');
ss << "\t";
}
ss << "\n";
os->write(ss.str().c_str(), sizeof(char) * ss.str().size());
}
return block->values_.size();
}
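// Each record written above is "<id>\t<col>\t<col>...\n", one comma-joined
// value vector per saved name, e.g. (illustrative): "1001\t0.01,0.02\t0.0,0.0\t\n".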
int64_t LoadFromText(const std::string& valuepath, const std::string& metapath,
const int pserver_id, const int pserver_num,
const int local_shard_num,
std::vector<std::shared_ptr<ValueBlock>>* blocks) {
Meta meta = Meta(metapath);
int num_lines = 0;
std::ifstream file(valuepath);
std::string line;
while (std::getline(file, line)) {
auto values = paddle::string::split_string<std::string>(line, "\t");
auto id = std::stoull(values[0]);
if (id % pserver_num != pserver_id) {
VLOG(0) << "will not load " << values[0] << " from " << valuepath
<< ", please check id distribution";
continue;
}
auto shard_id = id % local_shard_num;
auto block = blocks->at(shard_id);
std::vector<std::vector<float>> kvalues;
ProcessALine(values, meta, &kvalues);
block->Init(id, &kvalues, 1);
}
return 0;
}
void SaveShard(std::shared_ptr<ValueBlock> block, const std::string& dirname,
const CommonAccessorParameter& common, const int mode,
const int pserver_id, const int shard_id) {
auto varname = common.table_name();
std::string var_store = string::Sprintf("%s/%s", dirname, varname);
VLOG(3) << "save " << varname << " in dir: " << var_store << " begin";
MkDirRecursively(var_store.c_str());
std::string shard_var_pre =
string::Sprintf("%s.block%d.%d", varname, pserver_id, shard_id);
std::string meta_ = string::Sprintf("%s/%s.meta", var_store, shard_var_pre);
std::string value_ = string::Sprintf("%s/%s.txt", var_store, shard_var_pre);
// save values
std::vector<std::string> params(common.params().begin(),
common.params().end());
std::unique_ptr<std::ofstream> value_out(new std::ofstream(value_));
SaveToText(value_out.get(), block, params, mode);
// save meta
std::stringstream stream;
stream << "param=" << common.table_name() << "\n";
stream << "server_id=" << pserver_id << "\n";
stream << "shard_id=" << shard_id << "\n";
stream << "row_names=" << paddle::string::join_strings(common.params(), ',')
<< "\n";
stream << "row_dims=" << paddle::string::join_strings(common.dims(), ',')
<< "\n";
stream << "count=" << block->values_.size() << "\n";
std::unique_ptr<std::ofstream> meta_out(new std::ofstream(meta_));
meta_out->write(stream.str().c_str(), sizeof(char) * stream.str().size());
meta_out->close();
VLOG(3) << "save " << varname << " in dir: " << var_store << " done";
}
void CommonSparseTable::create_initializer(const std::string& attr,
const std::string& name) {
auto slices = string::split_string<std::string>(attr, "&");
if (slices[0] == "gaussian_random") {
initializers_[name] = new GaussianInitializer(slices);
} else if (slices[0] == "fill_constant") {
initializers_[name] = new FillConstantInitializer(slices);
} else if (slices[0] == "uniform_random") {
initializers_[name] = new UniformInitializer(slices);
} else {
PADDLE_THROW(
platform::errors::InvalidArgument("%s can not be supported", name));
}
}
int32_t CommonSparseTable::initialize() {
_shards_task_pool.resize(task_pool_size_);
for (int i = 0; i < _shards_task_pool.size(); ++i) {
_shards_task_pool[i].reset(new ::ThreadPool(1));
}
sync = _config.common().sync();
VLOG(1) << "table " << _config.common().table_name() << " is sync: " << sync;
initialize_value();
initialize_optimizer();
initialize_recorder();
return 0;
}
int32_t CommonSparseTable::initialize_recorder() { return 0; }
int32_t CommonSparseTable::initialize_value() {
auto common = _config.common();
int size = static_cast<int>(common.params().size());
for (int x = 0; x < size; ++x) {
auto& varname = common.params()[x];
auto& dim = common.dims()[x];
if (varname == "Param") {
param_dim_ = dim;
}
auto& initializer = common.initializers()[x];
create_initializer(initializer, varname);
}
shard_values_.reserve(task_pool_size_);
for (int x = 0; x < task_pool_size_; ++x) {
auto shard = std::make_shared<ValueBlock>(common, &initializers_);
shard_values_.emplace_back(shard);
}
return 0;
}
int32_t CommonSparseTable::initialize_optimizer() {
auto common = _config.common();
auto name = common.name();
auto attrs = common.attributes();
if (name == "sgd") {
optimizer_ = std::make_shared<SSGD>(common);
} else if (name == "adam") {
optimizer_ = std::make_shared<SAdam>(common);
} else if (name == "sum") {
optimizer_ = std::make_shared<SSUM>(common);
} else {
VLOG(0) << "init optimizer failed";
}
VLOG(0) << "init optimizer " << name << " done";
return 0;
}
int32_t CommonSparseTable::load(const std::string& path,
const std::string& param) {
rwlock_->WRLock();
VLOG(0) << "sparse table load with " << path << " with meta " << param;
LoadFromText(path, param, _shard_idx, _shard_num, task_pool_size_,
&shard_values_);
rwlock_->UNLock();
return 0;
}
int32_t CommonSparseTable::save(const std::string& dirname,
const std::string& param) {
rwlock_->WRLock();
int mode = std::stoi(param);
VLOG(0) << "sparse table save: " << dirname << " mode: " << mode;
auto varname = _config.common().table_name();
std::string var_store = string::Sprintf("%s/%s", dirname, varname);
MkDirRecursively(var_store.c_str());
VLOG(3) << "save " << varname << " in dir: " << var_store << " begin";
std::vector<std::string> params(_config.common().params().begin(),
_config.common().params().end());
std::string shard_var_pre =
string::Sprintf("%s.block%d", varname, _shard_idx);
std::string value_ = string::Sprintf("%s/%s.txt", var_store, shard_var_pre);
std::unique_ptr<std::ofstream> value_out(new std::ofstream(value_));
int64_t total_ins = 0;
for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) {
// save values
total_ins +=
SaveToText(value_out.get(), shard_values_[shard_id], params, mode);
}
value_out->close();
// save meta
std::stringstream stream;
stream << "param=" << _config.common().table_name() << "\n";
stream << "shard_id=" << _shard_idx << "\n";
stream << "row_names="
<< paddle::string::join_strings(_config.common().params(), ',')
<< "\n";
stream << "row_dims="
<< paddle::string::join_strings(_config.common().dims(), ',') << "\n";
stream << "count=" << total_ins << "\n";
std::string meta_ = string::Sprintf("%s/%s.meta", var_store, shard_var_pre);
std::unique_ptr<std::ofstream> meta_out(new std::ofstream(meta_));
meta_out->write(stream.str().c_str(), sizeof(char) * stream.str().size());
meta_out->close();
VLOG(3) << "save " << varname << " in dir: " << var_store << " done";
rwlock_->UNLock();
return 0;
}
std::pair<int64_t, int64_t> CommonSparseTable::print_table_stat() {
int64_t feasign_size = 0;
int64_t mf_size = 0;
for (auto& value : shard_values_) {
feasign_size += value->values_.size();
}
return {feasign_size, mf_size};
}
int32_t CommonSparseTable::pour() {
rwlock_->RDLock();
std::vector<float> values;
std::vector<uint64_t> keys;
keys.reserve(pull_reservoir_.size());
values.reserve(pull_reservoir_.size() * param_dim_);
for (auto& val : pull_reservoir_) {
keys.push_back(val.first);
auto& reservoir = val.second;
reservoir.avg();
std::copy(reservoir.values.begin(), reservoir.values.end(),
std::back_inserter(values));
}
_push_sparse(keys.data(), values.data(), pull_reservoir_.size());
pull_reservoir_.clear();
rwlock_->UNLock();
return 0;
}
int32_t CommonSparseTable::pull_sparse(float* pull_values, const uint64_t* keys,
size_t num) {
rwlock_->RDLock();
std::vector<std::string> value_names;
for (auto name : _config.common().params()) {
value_names.push_back(name);
}
std::vector<std::vector<uint64_t>> offset_bucket;
offset_bucket.resize(task_pool_size_);
for (int x = 0; x < num; ++x) {
auto y = keys[x] % task_pool_size_;
offset_bucket[y].push_back(x);
}
std::vector<std::future<int>> tasks(task_pool_size_);
for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) {
tasks[shard_id] = _shards_task_pool[shard_id]->enqueue(
[this, shard_id, &keys, &offset_bucket, &value_names,
&pull_values]() -> int {
auto& block = shard_values_[shard_id];
auto& offsets = offset_bucket[shard_id];
for (int i = 0; i < offsets.size(); ++i) {
auto offset = offsets[i];
auto id = keys[offset];
block->InitFromInitializer(id, value_names);
auto values = block->Get(id, {"Param"});
auto dim = values[0]->size();
std::copy(values[0]->begin(), values[0]->end(),
pull_values + dim * offset);
}
return 0;
});
}
for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
tasks[shard_id].wait();
}
rwlock_->UNLock();
return 0;
}
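// Keys are routed to a shard by keys[x] % task_pool_size_; each shard's
// single-threaded pool initializes missing ids from the initializers and
// copies the "Param" row into pull_values at dim * offset, so the output keeps
// the caller's key order.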
int32_t CommonSparseTable::_push_sparse(const uint64_t* keys,
const float* values, size_t num) {
rwlock_->RDLock();
std::vector<std::vector<uint64_t>> offset_bucket;
offset_bucket.resize(task_pool_size_);
for (int x = 0; x < num; ++x) {
auto y = keys[x] % task_pool_size_;
offset_bucket[y].push_back(x);
}
std::vector<std::future<int>> tasks(task_pool_size_);
for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) {
tasks[shard_id] = _shards_task_pool[shard_id]->enqueue(
[this, shard_id, &keys, &values, num, &offset_bucket]() -> int {
auto& offsets = offset_bucket[shard_id];
optimizer_->update(keys, values, num, offsets,
shard_values_[shard_id].get());
return 0;
});
}
for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
tasks[shard_id].wait();
}
rwlock_->UNLock();
return 0;
}
int32_t CommonSparseTable::push_sparse(const uint64_t* keys,
const float* values, size_t num) {
if (sync) {
std::future<int> task =
_shards_task_pool[0]->enqueue([this, &keys, &values, num]() -> int {
for (int x = 0; x < num; ++x) {
auto id = keys[x];
auto has = pull_reservoir_.find(id);
if (has == pull_reservoir_.end()) {
pull_reservoir_[id] = ReservoirValue<float>(param_dim_);
}
auto& reservoir = pull_reservoir_[id];
reservoir.add(values + x * param_dim_, param_dim_);
}
return 0;
});
task.wait();
} else {
_push_sparse(keys, values, num);
}
return 0;
}
int32_t CommonSparseTable::push_sparse_param(const uint64_t* keys,
const float* values, size_t num) {
rwlock_->RDLock();
std::vector<std::string> value_names;
for (auto name : _config.common().params()) {
value_names.push_back(name);
}
std::vector<std::vector<uint64_t>> offset_bucket;
offset_bucket.resize(task_pool_size_);
for (int x = 0; x < num; ++x) {
auto y = keys[x] % task_pool_size_;
offset_bucket[y].push_back(x);
}
std::vector<std::future<int>> tasks(task_pool_size_);
for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) {
tasks[shard_id] = _shards_task_pool[shard_id]->enqueue(
[this, shard_id, &keys, &offset_bucket, &value_names,
&values]() -> int {
auto& block = shard_values_[shard_id];
auto& offsets = offset_bucket[shard_id];
for (int i = 0; i < offsets.size(); ++i) {
auto offset = offsets[i];
auto id = keys[offset];
block->InitFromInitializer(id, value_names);
auto values_ = block->Get(id, {"Param"});
auto dim = values_[0]->size();
std::copy_n(values + dim * offset, dim, values_[0]->data());
}
return 0;
});
}
for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
tasks[shard_id].wait();
}
rwlock_->UNLock();
return 0;
}
int32_t CommonSparseTable::flush() { return 0; }
int32_t CommonSparseTable::shrink() {
VLOG(0) << "shrink coming soon";
return 0;
}
void CommonSparseTable::clear() { VLOG(0) << "clear coming soon"; }
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ThreadPool.h>
#include <assert.h>
#include <pthread.h>
#include <memory>
#include <mutex> // NOLINT
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "Eigen/Dense"
#include "paddle/fluid/distributed/table/accessor.h"
#include "paddle/fluid/distributed/table/common_table.h"
#include "paddle/fluid/distributed/table/depends/initializers.h"
#include "paddle/fluid/distributed/table/depends/large_scale_kv.h"
#include "paddle/fluid/distributed/table/depends/sparse.h"
#include "paddle/fluid/framework/rw_lock.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace distributed {
class CommonSparseTable : public SparseTable {
public:
CommonSparseTable() { rwlock_.reset(new framework::RWLock); }
virtual ~CommonSparseTable() {}
// unused method begin
virtual int32_t pull_dense(float* pull_values, size_t num) { return 0; }
virtual int32_t push_dense_param(const float* values, size_t num) {
return 0;
}
virtual int32_t push_dense(const float* values, size_t num) { return 0; }
// unused method end
virtual int32_t initialize();
virtual int32_t initialize_shard() { return 0; }
virtual void create_initializer(const std::string& attr,
const std::string& name);
virtual int32_t initialize_value();
virtual int32_t initialize_optimizer();
virtual int32_t initialize_recorder();
int32_t load(const std::string& path, const std::string& param);
int32_t save(const std::string& path, const std::string& param);
virtual std::pair<int64_t, int64_t> print_table_stat();
virtual int32_t pull_sparse(float* pull_values, const uint64_t* keys,
size_t num);
virtual int32_t push_sparse(const uint64_t* keys, const float* values,
size_t num);
// only for sparse geo table
virtual int32_t push_sparse_param(const uint64_t* keys, const float* values,
size_t num);
virtual int32_t pour();
virtual int32_t flush();
virtual int32_t shrink();
virtual void clear();
protected:
virtual int32_t _push_sparse(const uint64_t* keys, const float* values,
size_t num);
private:
const int task_pool_size_ = 11;
std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
bool sync = false;
int param_dim_ = 0;
std::shared_ptr<SparseOptimizer> optimizer_;
std::unordered_map<std::string, Initializer*> initializers_;
std::vector<std::shared_ptr<ValueBlock>> shard_values_;
std::unordered_map<uint64_t, ReservoirValue<float>> pull_reservoir_;
std::unique_ptr<framework::RWLock> rwlock_{nullptr};
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <condition_variable> // NOLINT
#include <mutex> // NOLINT
#include <set>
#include "paddle/fluid/distributed/table/table.h"
#include "paddle/fluid/distributed/common/utils.h"
namespace paddle {
namespace distributed {
template <typename T>
struct ReservoirValue {
std::vector<T> values;
uint32_t counter;
uint32_t dim;
ReservoirValue() {
dim = 0;
values.resize(dim);
counter = 0;
}
ReservoirValue(uint32_t dim) {
this->dim = dim;
values.resize(dim);
counter = 0;
}
void add(const T *value, int numel) {
GetBlas<T>().VADD(numel, values.data(), value, values.data());
counter++;
}
void add(T *value, int numel) {
GetBlas<T>().VADD(numel, values.data(), value, values.data());
counter++;
}
void avg() {
auto scale = 1 / static_cast<T>(counter);
GetBlas<T>().SCAL(values.size(), scale, values.data());
}
void reset() {
// clear the accumulated sum and restart the counter
values.assign(dim, 0);
counter = 0;
}
};
class SparseTable : public Table {
public:
SparseTable() {}
virtual ~SparseTable() {}
virtual void *get_shard(size_t shard_idx) { return 0; }
int32_t pull_dense(float *values, size_t num) override { return 0; }
int32_t push_dense(const float *values, size_t num) override { return 0; }
static int32_t sparse_local_shard_num(uint32_t shard_num,
uint32_t server_num) {
if (shard_num % server_num == 0) {
return shard_num / server_num;
}
size_t local_shard_num = shard_num / server_num + 1;
return local_shard_num;
}
static size_t get_sparse_shard(uint32_t shard_num, uint32_t server_num,
uint64_t key) {
return (key % shard_num) / sparse_local_shard_num(shard_num, server_num);
}
};
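// Worked example (numbers are illustrative): with shard_num = 1000 and
// server_num = 3, sparse_local_shard_num returns 1000 / 3 + 1 = 334, so
// get_sparse_shard maps keys with key % 1000 in [0, 333] to server 0,
// [334, 667] to server 1, and [668, 999] to server 2.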
class DenseTable : public Table {
public:
DenseTable() {}
virtual ~DenseTable() {}
virtual void *get_shard(size_t shard_idx) { return 0; }
int32_t pull_sparse(float *values, const uint64_t *keys,
size_t num) override {
return 0;
}
int32_t push_sparse(const uint64_t *keys, const float *values,
size_t num) override {
return 0;
}
int32_t push_dense_param(const float *values, size_t num) override {
return 0;
}
int32_t shrink() override { return 0; }
};
class BarrierTable : public Table {
public:
BarrierTable() {}
virtual ~BarrierTable() {}
virtual void *get_shard(size_t shard_idx) { return 0; }
int32_t pull_dense(float *values, size_t num) override { return 0; }
int32_t push_dense(const float *values, size_t num) override { return 0; }
int32_t pull_sparse(float *values, const uint64_t *keys,
size_t num) override {
return 0;
}
int32_t push_sparse(const uint64_t *keys, const float *values,
size_t num) override {
return 0;
}
int32_t push_dense_param(const float *values, size_t num) override {
return 0;
}
int32_t shrink() override { return 0; }
virtual void clear(){};
virtual int32_t flush() { return 0; };
virtual int32_t load(const std::string &path, const std::string &param) {
return 0;
}
virtual int32_t save(const std::string &path, const std::string &param) {
return 0;
}
virtual int32_t initialize_shard() { return 0; };
virtual int32_t initialize() override;
// only for barrier
// 0: send_barrier 1: recv_barrier 2: complete
virtual int32_t barrier(const uint32_t trainer_id,
const std::string barrier_type) override;
virtual int32_t set_table_map(
std::unordered_map<uint32_t, std::shared_ptr<Table>> *table_map) override;
private:
std::mutex mutex_;
std::condition_variable trainer_wait_;
std::set<uint64_t> trainer_ids_;
std::set<uint64_t> trainer_all_;
std::atomic<int> trigger_;
std::atomic<bool> exit_;
std::unordered_map<uint32_t, std::shared_ptr<Table>> *table_map_;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <gflags/gflags.h>
#include <math.h> // for sqrt in CPU and CUDA
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/distributed/common/utils.h"
namespace paddle {
namespace distributed {
// dense optimizer
// TODO(tangwei12) integrate with sparse optimizer later.
class DenseOptimizer {
public:
DenseOptimizer() {}
explicit DenseOptimizer(const CommonAccessorParameter& accessor,
std::vector<std::vector<float>>* values) {}
virtual void update(const float* update_values, size_t num, int begin,
int end) = 0;
};
// sum calc for dense tensor
class DSUM : public DenseOptimizer {
public:
explicit DSUM(const CommonAccessorParameter& accessor,
std::vector<std::vector<float>>* values) {
auto& names = accessor.params();
for (int x = 0; x < static_cast<int>(names.size()); ++x) {
if (names[x] == "Param") {
param = (*values)[x].data();
}
}
}
void update(const float* update_values, size_t num, int begin,
int end) override {
auto update_numel = end - begin;
GetBlas<float>().VADD(update_numel, update_values + begin, param + begin,
param + begin);
}
float* param;
};
// sgd optimizer for dense tensor
class DSGD : public DenseOptimizer {
public:
explicit DSGD(const CommonAccessorParameter& accessor,
std::vector<std::vector<float>>* values) {
auto& names = accessor.params();
for (int x = 0; x < static_cast<int>(names.size()); ++x) {
if (names[x] == "LearningRate") {
learning_rate = (*values)[x].data();
}
if (names[x] == "Param") {
param = (*values)[x].data();
}
}
}
void update(const float* update_values, size_t num, int begin,
int end) override {
auto update_numel = end - begin;
std::vector<float> grads;
grads.resize(update_numel);
auto blas = GetBlas<float>();
blas.VCOPY(update_numel, update_values + begin, grads.data());
blas.SCAL(update_numel, *learning_rate, grads.data());
blas.VSUB(update_numel, param + begin, grads.data(), param + begin);
}
float* learning_rate;
float* param;
};
// adam optimizer for dense tensor
class DAdam : public DenseOptimizer {
public:
explicit DAdam(const CommonAccessorParameter& accessor,
std::vector<std::vector<float>>* values) {
auto& names = accessor.params();
for (int x = 0; x < static_cast<int>(names.size()); ++x) {
if (names[x] == "LearningRate") {
learning_rate = (*values)[x].data();
}
if (names[x] == "Param") {
param = (*values)[x].data();
}
if (names[x] == "Moment1") {
moment1 = (*values)[x].data();
}
if (names[x] == "Moment2") {
moment2 = (*values)[x].data();
}
if (names[x] == "Beta1Pow") {
beta1_pow = (*values)[x].data();
}
if (names[x] == "Beta2Pow") {
beta2_pow = (*values)[x].data();
}
}
// add attr later
beta1 = 0.9;
beta2 = 0.999;
epsilon = 1.0e-8;
}
void update(const float* update_values, size_t num, int begin,
int end) override {
auto update_numel = end - begin;
std::vector<float> grad, grad2, tmp;
grad.resize(update_numel);
grad2.resize(update_numel);
tmp.resize(update_numel);
auto blas = GetBlas<float>();
blas.VCOPY(update_numel, update_values + begin, grad.data());
blas.VCOPY(update_numel, update_values + begin, grad2.data());
blas.SCAL(update_numel, 1 - beta1, grad.data());
blas.VSQUARE(update_numel, grad2.data(), grad2.data());
blas.SCAL(update_numel, 1 - beta2, grad2.data());
blas.SCAL(update_numel, beta1, moment1 + begin);
blas.VADD(update_numel, moment1 + begin, grad.data(), moment1 + begin);
blas.SCAL(update_numel, beta2, moment2 + begin);
blas.VADD(update_numel, moment2 + begin, grad2.data(), moment2 + begin);
beta1_pow[0] = beta1_pow[0] * beta1;
beta2_pow[0] = beta2_pow[0] * beta2;
float lr_ = learning_rate[0];
lr_ *= sqrt(1 - beta2_pow[0]) / (1 - beta1_pow[0]);
float* tmp_ = tmp.data();
float eps_ = epsilon * sqrt(1 - beta2_pow[0]);
SQRT<float>(update_numel, moment2 + begin, tmp_);
ADD<float>(update_numel, tmp_, eps_, tmp_);
blas.VDIV(update_numel, moment1 + begin, tmp_, tmp_);
blas.SCAL(update_numel, lr_, tmp_);
blas.VSUB(update_numel, param + begin, tmp_, param + begin);
}
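// The BLAS calls above implement the standard Adam rule, with beta1_pow and
// beta2_pow holding beta1^t and beta2^t:
//   m_t = beta1 * m_{t-1} + (1 - beta1) * g
//   v_t = beta2 * v_{t-1} + (1 - beta2) * g^2
//   lr_t = lr * sqrt(1 - beta2^t) / (1 - beta1^t)
//   param -= lr_t * m_t / (sqrt(v_t) + epsilon * sqrt(1 - beta2^t))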
float* learning_rate;
float* param;
float* moment1;
float* moment2;
float* beta1_pow;
float* beta2_pow;
float beta1;
float beta2;
float epsilon;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ThreadPool.h>
#include <functional>
#include <future> // NOLINT
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
namespace paddle {
namespace distributed {
class ConcurrentSet {
public:
ConcurrentSet() : pool_(new ::ThreadPool(1)) {}
~ConcurrentSet() {}
std::future<void> Update(const std::vector<uint64_t>& rows) {
auto task = [this, rows] {
for (auto row : rows) {
set_.insert(row);
}
};
return pool_->enqueue(std::move(task));
}
std::future<void> GetAndClear(std::vector<uint64_t>* result) {
auto task = [this, &result] {
result->clear();
for (auto& id : set_) {
result->push_back(id);
}
set_.clear();
};
return pool_->enqueue(std::move(task));
}
private:
std::unordered_set<uint64_t> set_;
std::unique_ptr<::ThreadPool> pool_{nullptr};
};
class GeoRecorder {
public:
explicit GeoRecorder(int trainer_num) : trainer_num_(trainer_num) {
trainer_rows_.reserve(trainer_num);
for (auto i = 0; i < trainer_num; ++i) {
trainer_rows_.emplace_back(new ConcurrentSet());
}
}
~GeoRecorder() = default;
void Update(const std::vector<uint64_t>& update_rows) {
VLOG(3) << " row size: " << update_rows.size();
std::vector<std::future<void>> fs;
for (auto& set : trainer_rows_) {
fs.push_back(set->Update(update_rows));
}
for (auto& f : fs) {
f.wait();
}
}
void GetAndClear(uint32_t trainer_id, std::vector<uint64_t>* result) {
VLOG(3) << "GetAndClear for trainer: " << trainer_id;
trainer_rows_.at(trainer_id)->GetAndClear(result).wait();
}
private:
const int trainer_num_;
std::vector<std::unique_ptr<ConcurrentSet>> trainer_rows_;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <gflags/gflags.h>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/generator.h"
namespace paddle {
namespace distributed {
class Initializer {
public:
Initializer() {}
explicit Initializer(const std::vector<std::string> &attrs) {}
virtual float GetValue() = 0;
virtual ~Initializer() {}
protected:
std::string name_;
unsigned int seed_;
};
class UniformInitializer : public Initializer {
public:
explicit UniformInitializer(const std::vector<std::string> &attrs) {
name_ = attrs[0];
seed_ = static_cast<unsigned int>(std::stoi(attrs[1]));
min_ = std::stof(attrs[2]);
max_ = std::stof(attrs[3]);
dist_ = std::uniform_real_distribution<float>(min_, max_);
random_engine_ = framework::GetCPURandomEngine(seed_);
}
float GetValue() override { return dist_(*random_engine_); }
private:
float min_;
float max_;
std::shared_ptr<std::mt19937_64> random_engine_;
std::uniform_real_distribution<float> dist_;
};
class GaussianInitializer : public Initializer {
public:
explicit GaussianInitializer(const std::vector<std::string> &attrs) {
name_ = attrs[0];
seed_ = static_cast<unsigned int>(std::stoi(attrs[1]));
mean_ = std::stof(attrs[2]);
std_ = std::stof(attrs[3]);
random_engine_ = framework::GetCPURandomEngine(seed_);
dist_ = std::normal_distribution<float>(mean_, std_);
}
float GetValue() override { return dist_(*random_engine_); }
private:
float std_;
float mean_;
std::shared_ptr<std::mt19937_64> random_engine_;
std::normal_distribution<float> dist_;
};
class FillConstantInitializer : public Initializer {
public:
explicit FillConstantInitializer(const std::vector<std::string> &attrs) {
name_ = attrs[0];
value_ = std::stof(attrs[1]);
}
float GetValue() override { return value_; }
private:
float value_;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ThreadPool.h>
#include <gflags/gflags.h>
#include <functional>
#include <future> // NOLINT
#include <memory>
#include <string>
#include <thread> // NOLINT
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/distributed/common/utils.h"
#include "paddle/fluid/distributed/table/depends/initializers.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/rw_lock.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace distributed {
enum Mode { training, infer };
template <typename T>
inline bool entry(const int count, const T threshold);
template <>
inline bool entry<std::string>(const int count, const std::string threshold) {
return true;
}
template <>
inline bool entry<int>(const int count, const int threshold) {
return count >= threshold;
}
template <>
inline bool entry<float>(const int count, const float threshold) {
UniformInitializer uniform = UniformInitializer({"uniform_random", "0", "0", "1"});
return uniform.GetValue() >= threshold;
}
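// Entry policy sketch: the std::string specialization ("none") always admits a
// feature, the int specialization admits it once its access count reaches the
// threshold (count_filter), and the float specialization admits it with
// probability 1 - threshold via a uniform draw in [0, 1) (probability filter).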
struct VALUE {
explicit VALUE(const std::vector<std::string> &names)
: names_(names), count_(0), unseen_days_(0) {
values_.resize(names.size());
for (int i = 0; i < static_cast<int>(names.size()); i++) {
places[names[i]] = i;
}
}
void set(std::vector<std::vector<float>> *values) {
values_ = std::move(*values);
}
void set(const std::vector<std::string> &names,
const std::vector<std::vector<float>> &values) {
for (int i = 0; i < static_cast<int>(names.size()); i++) {
auto idx = places[names[i]];
auto value = values[i];
values_[idx].assign(value.begin(), value.end());
}
}
std::vector<std::vector<float> *> get() {
auto pts = std::vector<std::vector<float> *>();
pts.reserve(values_.size());
for (auto &value : values_) {
pts.push_back(&value);
}
return pts;
}
int fetch_count() { return ++count_; }
void reset_unseen_days() { unseen_days_ = 0; }
void set_entry(bool is_entry) { is_entry_ = is_entry; }
bool get_entry() { return is_entry_; }
std::vector<std::vector<float> *> get(const std::vector<std::string> names) {
auto pts = std::vector<std::vector<float> *>();
pts.reserve(values_.size());
for (int i = 0; i < static_cast<int>(names.size()); i++) {
pts.push_back(&(values_[places[names[i]]]));
}
return pts;
}
std::vector<std::string> names_;
int count_;
bool seen_after_last_save_;
int unseen_days_;
bool is_entry_;
std::vector<std::vector<float>> values_;
std::unordered_map<std::string, int> places;
};
class ValueBlock {
public:
explicit ValueBlock(
const CommonAccessorParameter &common,
std::unordered_map<std::string, Initializer *> *initializers) {
initializers_ = initializers;
int size = static_cast<int>(common.params().size());
for (int x = 0; x < size; ++x) {
auto varname = common.params()[x];
auto dim = common.dims()[x];
value_names_.push_back(varname);
value_dims_.push_back(dim);
}
// for Entry
{
// entry will be added later
std::string entry_attr = "none";
if (entry_attr == "none") {
entry_func_ =
std::bind(entry<std::string>, std::placeholders::_1, "none");
} else {
auto slices = string::split_string<std::string>(entry_attr, "&");
if (slices[0] == "count_filter") {
int threshold = std::stoi(slices[1]);
entry_func_ = std::bind(entry<int>, std::placeholders::_1, threshold);
} else if (slices[0] == "probability") {
float threshold = std::stof(slices[1]);
entry_func_ =
std::bind(entry<float>, std::placeholders::_1, threshold);
}
}
}
}
~ValueBlock() {}
void Init(const uint64_t &id, std::vector<std::vector<float>> *values,
int count) {
if (Has(id)) {
PADDLE_THROW(platform::errors::AlreadyExists("id already exist, error"));
}
if (values->size() != value_names_.size()) {
PADDLE_THROW(
platform::errors::AlreadyExists("values can not match, error"));
}
auto value = new VALUE(value_names_);
value->set(values);
value->seen_after_last_save_ = true;
value->count_ = count;
values_[id] = value;
}
std::vector<std::vector<float> *> Get(
const uint64_t &id, const std::vector<std::string> &value_names) {
auto ret_values = values_.at(id)->get(value_names);
return ret_values;
}
std::vector<std::vector<float> *> Get(const uint64_t &id) {
auto ret_values = values_.at(id)->get(value_names_);
return ret_values;
}
void InitFromInitializer(const uint64_t &id,
const std::vector<std::string> &value_names) {
if (Has(id)) {
Update(id);
return;
}
auto rets = std::vector<std::vector<float>>();
rets.resize(value_names_.size());
for (int i = 0; i < static_cast<int>(value_names_.size()); i++) {
auto name = value_names_[i];
auto *init = initializers_->at(name);
auto dim = value_dims_[i];
rets[i].resize(dim);
for (int j = 0; j < static_cast<int>(dim); j++) {
rets[i][j] = init->GetValue();
}
}
Init(id, &rets, 0);
Update(id);
}
bool GetEntry(const uint64_t &id) {
auto value = values_.at(id);
auto entry = value->get_entry();
return entry;
}
void Set(const uint64_t &id, const std::vector<std::string> &value_names,
const std::vector<std::vector<float>> &values) {
auto value = values_.at(id);
value->set(value_names, values);
}
void Update(const uint64_t id) {
auto *value = values_.at(id);
value->reset_unseen_days();
auto count = value->fetch_count();
if (!value->get_entry()) {
value->set_entry(entry_func_(count));
}
}
private:
bool Has(const uint64_t id) {
auto got = values_.find(id);
if (got == values_.end()) {
return false;
} else {
return true;
}
}
public:
std::unordered_map<uint64_t, VALUE *> values_;
private:
std::vector<std::string> value_names_;
std::vector<int> value_dims_;
std::function<bool(uint64_t)> entry_func_;
std::unordered_map<std::string, Initializer *> *initializers_;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <gflags/gflags.h>
#include <math.h> // for sqrt in CPU and CUDA
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/distributed/common/utils.h"
#include "paddle/fluid/distributed/table/depends/large_scale_kv.h"
namespace paddle {
namespace distributed {
class SparseOptimizer {
public:
SparseOptimizer() {}
explicit SparseOptimizer(const CommonAccessorParameter& common) {}
virtual void update(const uint64_t* keys, const float* update_values,
size_t num, const std::vector<uint64_t>& offsets,
ValueBlock* block) = 0;
};
// sum calc for sparse tensor
class SSUM : public SparseOptimizer {
public:
SSUM(){};
explicit SSUM(const CommonAccessorParameter& common) {
auto& names = common.params();
for (int x = 0; x < static_cast<int>(names.size()); ++x) {
if (names[x] == "Param") {
param_idx = x;
update_numel = common.dims()[x];
}
}
}
void update(const uint64_t* keys, const float* update_values, size_t num,
const std::vector<uint64_t>& offsets,
ValueBlock* block) override {
auto blas = GetBlas<float>();
for (auto x : offsets) {
auto id = keys[x];
auto values = block->Get(id);
float* param = values[param_idx]->data();
std::vector<float> delta;
delta.resize(update_numel);
blas.VCOPY(update_numel, update_values + x * update_numel, delta.data());
blas.VADD(update_numel, delta.data(), param, param);
}
}
int param_idx;
int update_numel;
};
// sgd optimizer for sparse tensor
class SSGD : public SparseOptimizer {
public:
SSGD(){};
explicit SSGD(const CommonAccessorParameter& common) {
auto& names = common.params();
for (int x = 0; x < static_cast<int>(names.size()); ++x) {
if (names[x] == "LearningRate") {
learning_rate_idx = x;
}
if (names[x] == "Param") {
param_idx = x;
update_numel = common.dims()[x];
}
}
}
void update(const uint64_t* keys, const float* update_values, size_t num,
const std::vector<uint64_t>& offsets,
ValueBlock* block) override {
auto blas = GetBlas<float>();
for (auto x : offsets) {
auto id = keys[x];
auto values = block->Get(id);
float* learning_rate = values[learning_rate_idx]->data();
float* param = values[param_idx]->data();
std::vector<float> grads;
grads.resize(update_numel);
blas.VCOPY(update_numel, update_values + x * update_numel, grads.data());
blas.SCAL(update_numel, learning_rate[0], grads.data());
blas.VSUB(update_numel, param, grads.data(), param);
}
}
int learning_rate_idx;
int param_idx;
int update_numel;
};
// adam optimizer for sparse tensor
class SAdam : public SparseOptimizer {
public:
SAdam() {}
explicit SAdam(const CommonAccessorParameter& common) {
auto& names = common.params();
for (int x = 0; x < static_cast<int>(names.size()); ++x) {
if (names[x] == "LearningRate") {
learning_rate_idx = x;
}
if (names[x] == "Param") {
param_idx = x;
update_numel = common.dims()[x];
}
if (names[x] == "Moment1") {
moment1_idx = x;
}
if (names[x] == "Moment2") {
moment2_idx = x;
}
if (names[x] == "Beta1Pow") {
beta1_pow_idx = x;
}
if (names[x] == "Beta2Pow") {
beta2_pow_idx = x;
}
}
// add attr later
beta1 = 0.9;
beta2 = 0.999;
epsilon = 1.0e-8;
}
void update(const uint64_t* keys, const float* update_values, size_t num,
const std::vector<uint64_t>& offsets,
ValueBlock* block) override {
auto blas = GetBlas<float>();
for (auto x : offsets) {
auto id = keys[x];
auto values = block->Get(id);
float* learning_rate = values[learning_rate_idx]->data();
float* param = values[param_idx]->data();
float* moment1 = values[moment1_idx]->data();
float* moment2 = values[moment2_idx]->data();
float* beta1_pow = values[beta1_pow_idx]->data();
float* beta2_pow = values[beta2_pow_idx]->data();
beta1_pow[0] = beta1_pow[0] * beta1;
beta2_pow[0] = beta2_pow[0] * beta2;
float lr_ = learning_rate[0];
lr_ *= sqrt(1 - beta2_pow[0]) / (1 - beta1_pow[0]);
std::vector<float> grad, grad2, tmp;
grad.resize(update_numel);
grad2.resize(update_numel);
tmp.resize(update_numel);
blas.VCOPY(update_numel, update_values + x * update_numel, grad.data());
blas.VCOPY(update_numel, update_values + x * update_numel, grad2.data());
blas.SCAL(update_numel, 1 - beta1, grad.data());
blas.VSQUARE(update_numel, grad2.data(), grad2.data());
blas.SCAL(update_numel, 1 - beta2, grad2.data());
blas.SCAL(update_numel, beta1, moment1);
blas.VADD(update_numel, moment1, grad.data(), moment1);
blas.SCAL(update_numel, beta2, moment2);
blas.VADD(update_numel, moment2, grad2.data(), moment2);
float* tmp_ = tmp.data();
float eps_ = epsilon * sqrt(1 - beta2_pow[0]);
SQRT<float>(update_numel, moment2, tmp_);
ADD<float>(update_numel, tmp_, eps_, tmp_);
blas.VDIV(update_numel, moment1, tmp_, tmp_);
blas.SCAL(update_numel, lr_, tmp_);
blas.VSUB(update_numel, param, tmp_, param);
}
}
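// Same Adam rule as DAdam::update above, applied per key: only the rows
// selected by `offsets` are updated in place inside the shard's ValueBlock.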
int learning_rate_idx;
int param_idx;
int moment1_idx;
int moment2_idx;
int beta1_pow_idx;
int beta2_pow_idx;
float beta1;
float beta2;
float epsilon;
int update_numel;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/table/sparse_geo_table.h"
namespace paddle {
namespace distributed {
int32_t SparseGeoTable::pull_geo_param(const uint32_t trainer_id,
std::vector<float>* values,
std::vector<uint64_t>* ids) {
geo_recorder->GetAndClear(trainer_id, ids);
auto dim = _config.common().dims()[0];
values->resize(ids->size() * dim);
CommonSparseTable::pull_sparse(values->data(), ids->data(), ids->size());
return 0;
}
int32_t SparseGeoTable::push_sparse(const uint64_t* keys, const float* values,
size_t num) {
std::vector<uint64_t> ids;
ids.resize(num);
std::copy_n(keys, num, ids.begin());
geo_recorder->Update(ids);
CommonSparseTable::push_sparse(keys, values, num);
return 0;
}
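// Geo-SGD flow sketch: every push_sparse records the touched keys into
// geo_recorder for all trainers, and pull_geo_param(trainer_id, ...) later
// drains that trainer's key set and pulls the current values for exactly those
// keys, so each trainer only receives the rows changed since its last pull.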
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <assert.h>
#include <pthread.h>
#include <memory>
#include <mutex> // NOLINT
#include <string>
#include <utility>
#include <vector>
#include "Eigen/Dense"
#include "paddle/fluid/distributed/table/accessor.h"
#include "paddle/fluid/distributed/table/common_sparse_table.h"
#include "paddle/fluid/distributed/table/common_table.h"
#include "paddle/fluid/distributed/table/depends/geo_recorder.h"
#include "paddle/fluid/distributed/table/depends/initializers.h"
#include "paddle/fluid/distributed/table/depends/large_scale_kv.h"
#include "paddle/fluid/distributed/table/depends/sparse.h"
#include "paddle/fluid/framework/rw_lock.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace distributed {
class SparseGeoTable : public CommonSparseTable {
public:
explicit SparseGeoTable() : CommonSparseTable() { geo_recorder = nullptr; }
virtual ~SparseGeoTable() {}
int32_t pull_geo_param(const uint32_t trainer_id, std::vector<float>* values,
std::vector<uint64_t>* keys);
virtual int32_t push_sparse(const uint64_t* keys, const float* values,
size_t num) override;
virtual int32_t initialize_recorder() {
if (!geo_recorder) {
auto trainers = _config.common().trainer_num();
geo_recorder = std::make_shared<GeoRecorder>(trainers);
}
return 0;
}
private:
std::shared_ptr<GeoRecorder> geo_recorder;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/table/table.h"
#include <boost/preprocessor/repetition/repeat_from_to.hpp>
#include <boost/preprocessor/seq/elem.hpp>
#include "glog/logging.h"
#include "paddle/fluid/distributed/common/registerer.h"
#include "paddle/fluid/distributed/table/common_dense_table.h"
#include "paddle/fluid/distributed/table/common_sparse_table.h"
#include "paddle/fluid/distributed/table/sparse_geo_table.h"
#include "paddle/fluid/distributed/table/tensor_accessor.h"
#include "paddle/fluid/distributed/table/tensor_table.h"
namespace paddle {
namespace distributed {
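// Register the concrete table / accessor implementations so they can be
// instantiated by class name (the table_class / accessor_class strings in the
// proto config) via CREATE_CLASS.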
REGISTER_CLASS(Table, CommonDenseTable);
REGISTER_CLASS(Table, CommonSparseTable);
REGISTER_CLASS(Table, DenseTensorTable);
REGISTER_CLASS(Table, SparseGeoTable);
REGISTER_CLASS(Table, BarrierTable);
REGISTER_CLASS(ValueAccessor, CommMergeAccessor);
int32_t TableManager::initialize() {
static bool initialized = false;
if (initialized) {
return 0;
}
initialized = true;
return 0;
}
int32_t Table::initialize(const TableParameter &config,
const FsClientParameter &fs_config) {
_config = config;
if (initialize_accessor() != 0) {
LOG(WARNING) << "Table accessor initialize failed";
return -1;
}
return initialize();
}
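// Create the ValueAccessor named in the config, configure and initialize it,
// and keep it in _value_accesor.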
int32_t Table::initialize_accessor() {
if (!_config.has_accessor() || !_config.accessor().has_accessor_class()) {
LOG(ERROR) << "missing accessor config in table, table_id:"
<< _config.table_id();
return -1;
}
  auto *accessor = CREATE_CLASS(ValueAccessor,
                                _config.accessor().accessor_class())
  if (accessor == NULL) {
    LOG(ERROR) << "accessor is unregistered, table_id:" << _config.table_id()
<< ", accessor_name:" << _config.accessor().accessor_class();
return -1;
}
if (accessor->configure(_config.accessor()) || accessor->initialize() != 0) {
LOG(ERROR) << " accessor initialize failed, table_id:" << _config.table_id()
<< ", accessor_name:" << _config.accessor().accessor_class();
return -1;
}
_value_accesor.reset(accessor);
return 0;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <assert.h>
#include <atomic>
#include <future> // NOLINT
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/distributed/table/accessor.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace distributed {
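// Abstract base class of all parameter-server tables: dense and sparse
// pull/push, geo pull, barrier, load/save and sharding hooks. Concrete tables
// (dense, sparse, geo, barrier, tensor) override the pure virtual methods.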
class Table {
public:
Table() {}
virtual ~Table() {}
virtual int32_t initialize(const TableParameter &config,
const FsClientParameter &fs_config) final;
virtual int32_t pull_dense(float *values, size_t num) = 0;
virtual int32_t push_dense(const float *values, size_t num) = 0;
virtual int32_t push_dense_param(const float *values, size_t num) {
return 0;
}
virtual int32_t pull_sparse(float *values, const uint64_t *keys,
size_t num) = 0;
virtual int32_t push_sparse(const uint64_t *keys, const float *values,
size_t num) = 0;
virtual int32_t push_sparse_param(const uint64_t *keys, const float *values,
size_t num) {
return 0;
}
// only for sparse geo table
virtual int32_t pull_geo_param(const uint32_t trainer_id,
std::vector<float> *values,
std::vector<uint64_t> *keys) {
return 0;
}
// only for barrier
virtual int32_t barrier(const uint32_t trainer_id,
const std::string barrier_type) {
return 0;
}
// only for barrier table
virtual int32_t set_table_map(
std::unordered_map<uint32_t, std::shared_ptr<Table>> *table_map) {
return 0;
}
virtual int32_t pour() { return 0; }
virtual void clear() = 0;
virtual int32_t flush() = 0;
virtual int32_t shrink() = 0;
  // load the table from the given path
virtual int32_t load(const std::string &path,
const std::string &converter) = 0;
  // save the table to the given path
virtual int32_t save(const std::string &path,
const std::string &converter) = 0;
virtual int32_t set_shard(size_t shard_idx, size_t shard_num) final {
_shard_idx = shard_idx;
_shard_num = shard_num;
return initialize_shard();
}
inline std::shared_ptr<ValueAccessor> value_accesor() {
return _value_accesor;
}
virtual void *get_shard(size_t shard_idx) = 0;
virtual std::pair<int64_t, int64_t> print_table_stat() { return {0, 0}; }
protected:
virtual int32_t initialize() = 0;
virtual int32_t initialize_accessor() final;
virtual int32_t initialize_shard() = 0;
virtual std::string table_dir(const std::string &model_dir) {
return paddle::string::format_string("%s/%03d/", model_dir.c_str(),
_config.table_id());
}
  size_t _shard_idx;  // index of this table shard
  size_t _shard_num;  // total number of table shards
TableParameter _config;
std::shared_ptr<ValueAccessor> _value_accesor;
};
REGISTER_REGISTERER(Table);
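// Process-wide singleton; initialize() currently only guards against being
// run twice.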
class TableManager {
public:
static TableManager &instance() {
static TableManager manager;
return manager;
}
int32_t initialize();
private:
TableManager() {}
~TableManager() {}
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/table/tensor_accessor.h"
#include "Eigen/Dense"
namespace paddle {
namespace distributed {
int CommMergeAccessor::initialize() { return 0; }
// value dimension
size_t CommMergeAccessor::dim() { return 0; }
// size of each value dimension
size_t CommMergeAccessor::dim_size(size_t dim) { return 0; }
// total size over all value dimensions
size_t CommMergeAccessor::size() { return 0; }
// pull value dimension
size_t CommMergeAccessor::select_dim() { return _config.embedx_dim(); }
// size of each pull value dimension
size_t CommMergeAccessor::select_dim_size(size_t dim) { return sizeof(float); }
// total size over all pull value dimensions
size_t CommMergeAccessor::select_size() { return select_dim() * sizeof(float); }
// push value dimension
size_t CommMergeAccessor::update_dim() { return _config.embedx_dim(); }
// size of each push value dimension
size_t CommMergeAccessor::update_dim_size(size_t dim) { return sizeof(float); }
// total size over all push value dimensions
size_t CommMergeAccessor::update_size() { return update_dim() * sizeof(float); }
// whether this value should be shrunk
bool CommMergeAccessor::shrink(float * /*value*/) { return false; }
// whether this value should be dumped in the save stage;
// param identifies the save stage, e.g. downpour's xbox vs. batch_model
bool CommMergeAccessor::save(float * /*value*/, int /*param*/) { return true; }
// generate random values for keys that do not exist yet
int32_t CommMergeAccessor::create(float **value, size_t num) { return 0; }
// select from values into select_values
int32_t CommMergeAccessor::select(float **select_values, const float **values,
size_t num) {
return 0;
}
// aggregate update_values together
int32_t CommMergeAccessor::merge(float **update_values,
const float **other_update_values,
size_t num) {
Eigen::Map<Eigen::MatrixXf> u_mat(update_values[0], 1, num);
Eigen::Map<const Eigen::MatrixXf> o_mat(other_update_values[0], 1, num);
u_mat += o_mat;
return 0;
}
// aggregate update_values together; it.next decides whether to move to the next key
// int32_t merge(float** update_values, iterator it);
// apply update_values onto values
int32_t CommMergeAccessor::update(float **values, const float **update_values,
size_t num) {
return 0;
}
int CommMergeAccessor::set_weight(float **values, const float **update_values,
size_t num) {
return 0;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdint.h>
#include <stdio.h>
#include <string>
#include <vector>
#include "paddle/fluid/distributed/common/registerer.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/table/accessor.h"
namespace paddle {
namespace distributed {
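// Minimal accessor used when the server only needs to merge pushed values:
// the pull/push widths come from embedx_dim, merge() adds updates
// element-wise, and create/select/update are no-ops.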
class CommMergeAccessor : public ValueAccessor {
public:
CommMergeAccessor() {}
virtual ~CommMergeAccessor() {}
virtual int initialize();
  // value dimension
  virtual size_t dim();
  // size of each value dimension
  virtual size_t dim_size(size_t dim);
  // total size over all value dimensions
  virtual size_t size();
  // pull value dimension
  virtual size_t select_dim();
  // size of each pull value dimension
  virtual size_t select_dim_size(size_t dim);
  // total size over all pull value dimensions
  virtual size_t select_size();
  // push value dimension
  virtual size_t update_dim();
  // size of each push value dimension
  virtual size_t update_dim_size(size_t dim);
  // total size over all push value dimensions
  virtual size_t update_size();
  // whether this value should be shrunk
  virtual bool shrink(float * /*value*/);
  // whether this value should be dumped in the save stage;
  // param identifies the save stage, e.g. downpour's xbox vs. batch_model
  virtual bool save(float * /*value*/, int /*param*/);
  // generate random values for keys that do not exist yet
  virtual int32_t create(float **value, size_t num);
  // select from values into select_values
virtual int32_t select(float **select_values, const float **values,
size_t num);
  // aggregate update_values together
virtual int32_t merge(float **update_values,
const float **other_update_values, size_t num);
  // aggregate update_values together; it.next decides whether to move to the next key
// virtual int32_t merge(float** update_values, iterator it);
  // apply update_values onto values
virtual int32_t update(float **values, const float **update_values,
size_t num);
virtual int set_weight(float **values, const float **update_values,
size_t num);
virtual std::string parse_to_string(const float *value, int param) {
return "";
}
virtual int parse_from_string(const std::string &str, float *v) { return 0; }
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/table/tensor_table.h"
#include "paddle/fluid/distributed/common/utils.h"
namespace paddle {
namespace distributed {
int32_t DenseTensorTable::initialize() {
_shards_task_pool.resize(10);
for (int i = 0; i < _shards_task_pool.size(); ++i) {
_shards_task_pool[i].reset(new ::ThreadPool(1));
}
return 0;
}
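// tensor.common_block_map() encodes "block_id:command" pairs joined by '#';
// prepare one executor context per command so the blocks can be run later.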
int32_t DenseTensorTable::initialize_tensor(framework::Scope *scope,
framework::ProgramDesc *program,
framework::Executor *executor) {
scope_ = scope;
program_ = program;
executor_ = executor;
auto tensor_config = _config.tensor();
if (tensor_config.has_common_block_map()) {
auto block_maps =
paddle::string::split_string(tensor_config.common_block_map(), "#");
for (auto &block_map : block_maps) {
auto block = paddle::string::split_string(block_map, ":");
auto block_id = std::stoi(block[0]);
std::vector<int> block_ids{block_id};
auto block_cmd = block[1];
auto prepared = executor_->Prepare(*program_, block_ids);
(*prepared_ctx_)[block_cmd] = prepared[0];
}
}
  return 0;
}
int32_t DenseTensorTable::pull_dense(float *values, size_t numel) {
PADDLE_ENFORCE_EQ(numel, _data.numel(),
paddle::platform::errors::PreconditionNotMet(
"pull dense error, excepted numel %d, but actually %d.",
_data.numel(), numel));
GetBlas<float>().VCOPY(numel, _data.data<float>(), values);
return 0;
}
int32_t DenseTensorTable::push_dense(const float *values, size_t numel) {
auto varname = _config.tensor().grad();
auto local_scope = scope_->NewTmpScope();
auto *var = local_scope->Var(varname);
auto *t = var->GetMutable<framework::LoDTensor>();
auto dims = paddle::framework::make_ddim({});
auto ctx = paddle::platform::CPUDeviceContext();
t->mutable_data<float>(_data.dims(), ctx.GetPlace());
GetBlas<float>().VCOPY(numel, values, t->data<float>());
executor_->RunPreparedContext((*prepared_ctx_)["push"].get(),
local_scope.get());
  return 0;
}
int32_t DenseTensorTable::push_dense_param(const float *values, size_t numel) {
auto ctx = paddle::platform::CPUDeviceContext();
if (_data.IsInitialized()) {
PADDLE_ENFORCE_EQ(
numel, _data.numel(),
paddle::platform::errors::PreconditionNotMet(
"pull dense error, excepted numel %d, but actually %d.",
_data.numel(), numel));
} else {
_data.mutable_data<float>(
framework::make_ddim({static_cast<int64_t>(numel), 1}), ctx.GetPlace());
}
GetBlas<float>().VCOPY(numel, values, _data.data<float>());
return 0;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include <ThreadPool.h>
#include "paddle/fluid/distributed/table/table.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace distributed {
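// Table backed by framework tensors and an Executor-run program; the base
// class only provides stubs, the dense behaviour lives in DenseTensorTable
// below.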
class TensorTable : public Table {
public:
TensorTable() : Table() {}
virtual ~TensorTable() {}
virtual int32_t initialize() { return 0; }
virtual int32_t pull_dense(float *values, size_t num) override { return 0; };
virtual int32_t push_dense(const float *values, size_t num) override {
return 0;
};
virtual void *get_shard(size_t shard_idx) override { return 0; }
virtual int32_t pull_sparse(float *values, const uint64_t *keys,
size_t num) override {
return 0;
};
virtual int32_t push_sparse(const uint64_t *keys, const float *values,
size_t num) override {
return 0;
};
virtual int32_t push_dense_param(const float *values, size_t num) {
return 0;
}
virtual int32_t shrink() { return 0; }
virtual void clear() {}
virtual int32_t flush() { return 0; }
  // load the table from the given path
virtual int32_t load(const std::string &path, const std::string &converter) {
return 0;
}
  // save the table to the given path
virtual int32_t save(const std::string &path, const std::string &converter) {
return 0;
}
protected:
virtual int32_t initialize_shard() { return 0; }
virtual int32_t initialize_tensor(paddle::framework::Scope *scope,
paddle::framework::ProgramDesc *program,
paddle::framework::Executor *executor) {
return 0;
}
std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
framework::Executor *executor_;
framework::Scope *scope_;
framework::ProgramDesc *program_;
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>
*prepared_ctx_;
};
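// Dense parameter stored in a framework::Tensor: pull_dense copies the data
// out, push_dense_param copies new values in (allocating on first use), and
// push_dense feeds the gradient to the prepared "push" block of the program.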
class DenseTensorTable : public TensorTable {
public:
DenseTensorTable() : TensorTable() {}
~DenseTensorTable() {}
virtual int32_t initialize();
void *get_shard(size_t shard_idx) { return 0; }
int32_t pull_sparse(float *values, const uint64_t *keys, size_t num) {
return 0;
}
int32_t push_sparse(const uint64_t *keys, const float *values, size_t num) {
return 0;
}
int32_t shrink() { return 0; }
int32_t pull_dense(float *values, size_t num) override;
int32_t push_dense_param(const float *values, size_t num) override;
int32_t push_dense(const float *values, size_t num) override;
virtual void clear() {}
virtual int32_t flush() { return 0; }
  // load the table from the given path
virtual int32_t load(const std::string &path, const std::string &converter) {
return 0;
}
  // save the table to the given path
virtual int32_t save(const std::string &path, const std::string &converter) {
return 0;
}
protected:
virtual int32_t initialize_shard() { return 0; }
virtual int32_t initialize_tensor(paddle::framework::Scope *scope,
paddle::framework::ProgramDesc *program,
paddle::framework::Executor *executor);
protected:
framework::Tensor _data;
};
//
//// common sparse table [0, N) with out large scale
// class SparseTensorTable : public TensorTable {
// void *get_shard(size_t shard_idx) { return 0; }
//
// int32_t pull_sparse(float *values, const uint64_t *keys, size_t num)
// override;
// int32_t push_sparse(const uint64_t *keys, const float *values, size_t num)
// override ;
// int32_t shrink() { return 0; }
// void *get_shard(size_t shard_idx) { return 0; };
//
// int32_t pull_dense(float *values, size_t num) { return 0; };
// int32_t push_dense_param(const float *values, size_t num) { return 0; };
// int32_t push_dense(const float *values, size_t num) { return 0; };
//
// protected:
// framework::Tensor _data;
//};
//// for Large scale kv tensor [0, int64] do not use specific optimizer
// class KvTensorTable : public TensorTable {
// int32_t pull_dense(float *values, size_t num) { return 0; };
// int32_t push_dense_param(const float *values, size_t num) { return 0; };
// int32_t push_dense(const float *values, size_t num) { return 0; };
//
// void *get_shard(size_t shard_idx) override;
// int32_t pull_sparse(float *values, const uint64_t *keys, size_t num)
// override;
// int32_t push_sparse(const uint64_t *keys, const float *values,
// size_t num) override;
// int32_t shrink() override;
// void *get_shard(size_t shard_idx) override;
//};
//
//// for Geo sparse handle
// class GeoSparseTensorTable : public TensorTable {};
} // namespace distributed
} // namespace paddle
if(APPLE)
return()
endif()
set_source_files_properties(table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(table_test SRCS table_test.cc DEPS common_table table tensor_accessor ps_framework_proto ${COMMON_DEPS})
set_source_files_properties(dense_table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(dense_table_test SRCS dense_table_test.cc DEPS common_table table tensor_accessor ps_framework_proto ${COMMON_DEPS})
set_source_files_properties(sparse_table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(sparse_table_test SRCS sparse_table_test.cc DEPS common_table table tensor_accessor ps_framework_proto ${COMMON_DEPS})
set_source_files_properties(geo_table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(geo_table_test SRCS geo_table_test.cc DEPS common_table table tensor_accessor ps_framework_proto ${COMMON_DEPS})
set_source_files_properties(barrier_table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(barrier_table_test SRCS barrier_table_test.cc DEPS common_table table tensor_accessor ps_framework_proto ${COMMON_DEPS})
# enable the following tests once CI supports brpc
return()
set_source_files_properties(brpc_service_dense_sgd_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(brpc_service_dense_sgd_test SRCS brpc_service_dense_sgd_test.cc DEPS scope server client communicator ps_service boost table ps_framework_proto ${COMMON_DEPS})
set_source_files_properties(brpc_service_sparse_sgd_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(brpc_service_sparse_sgd_test SRCS brpc_service_sparse_sgd_test.cc DEPS scope server client communicator ps_service boost table ps_framework_proto ${COMMON_DEPS})
set_source_files_properties(brpc_utils_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(brpc_utils_test SRCS brpc_utils_test.cc DEPS brpc_utils scope math_function ${COMMON_DEPS} ${RPC_DEPS})
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <ThreadPool.h>
#include <unistd.h>
#include <string>
#include <thread> // NOLINT
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/table/common_table.h"
#include "paddle/fluid/distributed/table/table.h"
namespace paddle {
namespace distributed {
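// Creates a BarrierTable for two trainers and lets each trainer hit the
// barrier from a thread pool; the test checks that initialization succeeds
// and that both barrier calls return.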
TEST(BarrierTable, Barrier) {
int emb_dim = 10;
int trainers = 2;
bool sync = true;
TableParameter table_config;
table_config.set_table_class("BarrierTable");
FsClientParameter fs_config;
Table *table = new BarrierTable();
TableAccessorParameter *accessor_config = table_config.mutable_accessor();
accessor_config->set_accessor_class("CommMergeAccessor");
CommonAccessorParameter *common_config = table_config.mutable_common();
common_config->set_table_name("barrier_table");
common_config->set_trainer_num(trainers);
common_config->set_sync(sync);
auto ret = table->initialize(table_config, fs_config);
std::unordered_map<uint32_t, std::shared_ptr<Table>> maps =
std::unordered_map<uint32_t, std::shared_ptr<Table>>();
table->set_table_map(&maps);
std::shared_ptr<::ThreadPool> pool_ =
std::make_shared<::ThreadPool>(trainers);
std::vector<std::future<void>> task_status;
for (auto x = 0; x < trainers; x++) {
    auto task = [table, x] { table->barrier(x, "0"); };
task_status.push_back(pool_->enqueue(std::move(task)));
}
for (auto &status : task_status) {
status.wait();
}
ASSERT_EQ(ret, 0);
}
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <unistd.h>
#include <condition_variable> // NOLINT
#include <string>
#include <thread> // NOLINT
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/service/brpc_ps_server.h"
#include "paddle/fluid/distributed/service/env.h"
#include "paddle/fluid/distributed/service/ps_client.h"
#include "paddle/fluid/distributed/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/service/service.h"
namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace operators = paddle::operators;
namespace math = paddle::operators::math;
namespace memory = paddle::memory;
namespace distributed = paddle::distributed;
void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) {
auto x_var = scope->Var("x");
x_var->GetMutable<framework::LoDTensor>();
}
void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place,
int64_t rows_numel) {
CreateVarsOnScope(scope, place);
auto x_var = scope->Var("x")->GetMutable<framework::LoDTensor>();
float* x_ptr =
x_var->mutable_data<float>(framework::DDim({1, rows_numel}), *place);
for (int64_t i = 0; i < rows_numel; ++i) x_ptr[i] = 1.0 * (float)i;
}
void GetDownpourDenseTableProto(
::paddle::distributed::TableParameter* dense_table_proto) {
dense_table_proto->set_table_id(0);
dense_table_proto->set_table_class("CommonDenseTable");
dense_table_proto->set_shard_num(256);
dense_table_proto->set_type(::paddle::distributed::PS_DENSE_TABLE);
::paddle::distributed::TableAccessorParameter* accessor_proto =
dense_table_proto->mutable_accessor();
::paddle::distributed::CommonAccessorParameter* common_proto =
dense_table_proto->mutable_common();
accessor_proto->set_accessor_class("CommMergeAccessor");
accessor_proto->set_fea_dim(100);
accessor_proto->set_embedx_dim(1);
common_proto->set_name("sgd");
common_proto->set_table_name("MergedDense");
common_proto->set_trainer_num(1);
common_proto->set_sync(false);
common_proto->add_params("Param");
common_proto->add_dims(100);
common_proto->add_initializers("fill_constant&1.0");
common_proto->add_params("LearningRate");
common_proto->add_dims(1);
common_proto->add_initializers("fill_constant&1.0");
}
::paddle::distributed::PSParameter GetServerProto() {
// Generate server proto desc
::paddle::distributed::PSParameter server_fleet_desc;
::paddle::distributed::ServerParameter* server_proto =
server_fleet_desc.mutable_server_param();
::paddle::distributed::DownpourServerParameter* downpour_server_proto =
server_proto->mutable_downpour_server_param();
::paddle::distributed::ServerServiceParameter* server_service_proto =
downpour_server_proto->mutable_service_param();
server_service_proto->set_service_class("PsService");
server_service_proto->set_server_class("BrpcPsServer");
server_service_proto->set_client_class("BrpcPsClient");
server_service_proto->set_start_server_port(0);
server_service_proto->set_server_thread_num(12);
::paddle::distributed::TableParameter* dense_table_proto =
downpour_server_proto->add_downpour_table_param();
GetDownpourDenseTableProto(dense_table_proto);
return server_fleet_desc;
}
::paddle::distributed::PSParameter GetWorkerProto() {
::paddle::distributed::PSParameter worker_fleet_desc;
::paddle::distributed::WorkerParameter* worker_proto =
worker_fleet_desc.mutable_worker_param();
::paddle::distributed::DownpourWorkerParameter* downpour_worker_proto =
worker_proto->mutable_downpour_worker_param();
::paddle::distributed::TableParameter* worker_dense_table_proto =
downpour_worker_proto->add_downpour_table_param();
GetDownpourDenseTableProto(worker_dense_table_proto);
::paddle::distributed::ServerParameter* server_proto =
worker_fleet_desc.mutable_server_param();
::paddle::distributed::DownpourServerParameter* downpour_server_proto =
server_proto->mutable_downpour_server_param();
::paddle::distributed::ServerServiceParameter* server_service_proto =
downpour_server_proto->mutable_service_param();
server_service_proto->set_service_class("PsService");
server_service_proto->set_server_class("BrpcPsServer");
server_service_proto->set_client_class("BrpcPsClient");
server_service_proto->set_start_server_port(0);
server_service_proto->set_server_thread_num(12);
::paddle::distributed::TableParameter* server_dense_table_proto =
downpour_server_proto->add_downpour_table_param();
GetDownpourDenseTableProto(server_dense_table_proto);
return worker_fleet_desc;
}
/*-------------------------------------------------------------------------*/
std::string ip_ = "127.0.0.1";
uint32_t port_ = 4214;
std::vector<std::string> host_sign_list_;
std::shared_ptr<paddle::distributed::PSServer> pserver_ptr_;
std::shared_ptr<paddle::distributed::PSClient> worker_ptr_;
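// Bring up a BrpcPsServer configured with the dense table proto above and
// register it in the PS environment through host_sign_list_.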
void RunServer() {
::paddle::distributed::PSParameter server_proto = GetServerProto();
auto _ps_env = paddle::distributed::PaddlePSEnvironment();
LOG(INFO) << "RUN set_ps_servers";
_ps_env.set_ps_servers(&host_sign_list_, 1);
pserver_ptr_ = std::shared_ptr<paddle::distributed::PSServer>(
paddle::distributed::PSServerFactory::create(server_proto));
LOG(INFO) << "RUN configure";
pserver_ptr_->configure(server_proto, _ps_env, 0);
LOG(INFO) << "RUN start";
pserver_ptr_->start(ip_, port_);
LOG(INFO) << "End start";
}
void RunClient(std::map<uint64_t, std::vector<paddle::distributed::Region>>&
dense_regions) {
::paddle::distributed::PSParameter worker_proto = GetWorkerProto();
paddle::distributed::PaddlePSEnvironment _ps_env;
auto servers_ = host_sign_list_.size();
_ps_env = paddle::distributed::PaddlePSEnvironment();
LOG(INFO) << "Run set_ps_servers";
_ps_env.set_ps_servers(&host_sign_list_, servers_);
LOG(INFO) << "Run Create PSClient";
worker_ptr_ = std::shared_ptr<paddle::distributed::PSClient>(
paddle::distributed::PSClientFactory::create(worker_proto));
LOG(INFO) << "Run configure";
worker_ptr_->configure(worker_proto, dense_regions, _ps_env, 0);
}
void RunBrpcPushDense() {
setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1);
auto ph_host = paddle::distributed::PSHost(ip_, port_, 0);
host_sign_list_.push_back(ph_host.serialize_to_string());
  // Start Server
std::thread server_thread(RunServer);
sleep(1);
// Start Client
LOG(INFO) << "Run InitTensorsOnClient";
framework::Scope client_scope;
platform::CPUPlace place;
InitTensorsOnClient(&client_scope, &place, 100);
std::map<uint64_t, std::vector<paddle::distributed::Region>> dense_regions;
dense_regions.insert(
std::pair<uint64_t, std::vector<paddle::distributed::Region>>(0, {}));
auto regions = dense_regions[0];
framework::Variable* var = client_scope.FindVar("x");
framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>();
float* w = tensor->data<float>();
paddle::distributed::Region reg(w, tensor->numel());
regions.emplace_back(std::move(reg));
LOG(INFO) << "Run RunClient";
RunClient(dense_regions);
/*-----------------------Test Server Init----------------------------------*/
LOG(INFO) << "Run pull_dense_param";
float* temp = new float[tensor->numel()]();
std::vector<paddle::distributed::Region> temp_region;
paddle::distributed::Region temp_reg(temp, tensor->numel());
temp_region.emplace_back(std::move(temp_reg));
auto pull_status =
worker_ptr_->pull_dense(temp_region.data(), temp_region.size(), 0);
pull_status.wait();
for (size_t idx = 0; idx < tensor->numel(); ++idx) {
EXPECT_FLOAT_EQ(temp[idx], 1.0);
}
/*-----------------------Test Push Param----------------------------------*/
LOG(INFO) << "Run push_dense_param";
auto push_status =
worker_ptr_->push_dense_param(regions.data(), regions.size(), 0);
push_status.wait();
pull_status = worker_ptr_->pull_dense(regions.data(), regions.size(), 0);
pull_status.wait();
for (size_t idx = 0; idx < tensor->numel(); ++idx) {
EXPECT_FLOAT_EQ(w[idx], float(idx));
}
/*-----------------------Test Push Grad----------------------------------*/
paddle::distributed::DownpourBrpcClosure* closure =
new paddle::distributed::DownpourBrpcClosure(1, [&](void* done) {
int ret = 0;
auto* closure = (paddle::distributed::DownpourBrpcClosure*)done;
for (size_t i = 0; i < 1; ++i) {
if (closure->check_response(i, paddle::PS_PUSH_DENSE_TABLE) != 0) {
ret = -1;
break;
}
}
closure->set_promise_value(ret);
});
LOG(INFO) << "Run pull_dense_grad";
auto push_grad_status =
worker_ptr_->push_dense_raw_gradient(0, temp, tensor->numel(), closure);
push_grad_status.wait();
auto pull_update_status =
worker_ptr_->pull_dense(regions.data(), regions.size(), 0);
pull_update_status.wait();
for (size_t idx = 0; idx < tensor->numel(); ++idx) {
EXPECT_FLOAT_EQ(w[idx], float(idx) - 1.0);
}
LOG(INFO) << "Run stop_server";
worker_ptr_->stop_server();
LOG(INFO) << "Run finalize_worker";
worker_ptr_->finalize_worker();
server_thread.join();
}
TEST(RunBrpcPushDense, Run) { RunBrpcPushDense(); }
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <unistd.h>
#include <condition_variable> // NOLINT
#include <string>
#include <thread> // NOLINT
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/service/brpc_ps_server.h"
#include "paddle/fluid/distributed/service/env.h"
#include "paddle/fluid/distributed/service/ps_client.h"
#include "paddle/fluid/distributed/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/service/service.h"
namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace operators = paddle::operators;
namespace math = paddle::operators::math;
namespace memory = paddle::memory;
namespace distributed = paddle::distributed;
void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) {
auto x_var = scope->Var("x");
x_var->GetMutable<framework::LoDTensor>();
}
void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place,
int64_t rows_numel) {
CreateVarsOnScope(scope, place);
auto x_var = scope->Var("x")->GetMutable<framework::LoDTensor>();
float* x_ptr =
x_var->mutable_data<float>(framework::DDim({1, rows_numel}), *place);
for (int64_t i = 0; i < rows_numel; ++i) x_ptr[i] = 1.0;
}
void GetDownpourSparseTableProto(
::paddle::distributed::TableParameter* sparse_table_proto) {
sparse_table_proto->set_table_id(0);
sparse_table_proto->set_table_class("CommonSparseTable");
sparse_table_proto->set_shard_num(256);
sparse_table_proto->set_type(::paddle::distributed::PS_SPARSE_TABLE);
::paddle::distributed::TableAccessorParameter* accessor_proto =
sparse_table_proto->mutable_accessor();
::paddle::distributed::CommonAccessorParameter* common_proto =
sparse_table_proto->mutable_common();
accessor_proto->set_accessor_class("CommMergeAccessor");
accessor_proto->set_fea_dim(0);
accessor_proto->set_embedx_dim(10);
common_proto->set_name("sgd");
common_proto->set_table_name("MergedDense");
common_proto->set_trainer_num(1);
common_proto->set_sync(false);
common_proto->add_params("Param");
common_proto->add_dims(10);
common_proto->add_initializers("uniform_random&0&-1.0&1.0");
common_proto->add_params("LearningRate");
common_proto->add_dims(1);
common_proto->add_initializers("fill_constant&1.0");
}
::paddle::distributed::PSParameter GetServerProto() {
// Generate server proto desc
::paddle::distributed::PSParameter server_fleet_desc;
::paddle::distributed::ServerParameter* server_proto =
server_fleet_desc.mutable_server_param();
::paddle::distributed::DownpourServerParameter* downpour_server_proto =
server_proto->mutable_downpour_server_param();
::paddle::distributed::ServerServiceParameter* server_service_proto =
downpour_server_proto->mutable_service_param();
server_service_proto->set_service_class("PsService");
server_service_proto->set_server_class("BrpcPsServer");
server_service_proto->set_client_class("BrpcPsClient");
server_service_proto->set_start_server_port(0);
server_service_proto->set_server_thread_num(12);
::paddle::distributed::TableParameter* sparse_table_proto =
downpour_server_proto->add_downpour_table_param();
GetDownpourSparseTableProto(sparse_table_proto);
return server_fleet_desc;
}
::paddle::distributed::PSParameter GetWorkerProto() {
::paddle::distributed::PSParameter worker_fleet_desc;
::paddle::distributed::WorkerParameter* worker_proto =
worker_fleet_desc.mutable_worker_param();
::paddle::distributed::DownpourWorkerParameter* downpour_worker_proto =
worker_proto->mutable_downpour_worker_param();
::paddle::distributed::TableParameter* worker_sparse_table_proto =
downpour_worker_proto->add_downpour_table_param();
GetDownpourSparseTableProto(worker_sparse_table_proto);
::paddle::distributed::ServerParameter* server_proto =
worker_fleet_desc.mutable_server_param();
::paddle::distributed::DownpourServerParameter* downpour_server_proto =
server_proto->mutable_downpour_server_param();
::paddle::distributed::ServerServiceParameter* server_service_proto =
downpour_server_proto->mutable_service_param();
server_service_proto->set_service_class("PsService");
server_service_proto->set_server_class("BrpcPsServer");
server_service_proto->set_client_class("BrpcPsClient");
server_service_proto->set_start_server_port(0);
server_service_proto->set_server_thread_num(12);
::paddle::distributed::TableParameter* server_sparse_table_proto =
downpour_server_proto->add_downpour_table_param();
GetDownpourSparseTableProto(server_sparse_table_proto);
return worker_fleet_desc;
}
/*-------------------------------------------------------------------------*/
std::string ip_ = "127.0.0.1";
uint32_t port_ = 4209;
std::vector<std::string> host_sign_list_;
std::shared_ptr<paddle::distributed::PSServer> pserver_ptr_;
std::shared_ptr<paddle::distributed::PSClient> worker_ptr_;
void RunServer() {
::paddle::distributed::PSParameter server_proto = GetServerProto();
auto _ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(&host_sign_list_, 1);
pserver_ptr_ = std::shared_ptr<paddle::distributed::PSServer>(
paddle::distributed::PSServerFactory::create(server_proto));
pserver_ptr_->configure(server_proto, _ps_env, 0);
pserver_ptr_->start(ip_, port_);
}
void RunClient(std::map<uint64_t, std::vector<paddle::distributed::Region>>&
dense_regions) {
::paddle::distributed::PSParameter worker_proto = GetWorkerProto();
paddle::distributed::PaddlePSEnvironment _ps_env;
auto servers_ = host_sign_list_.size();
_ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(&host_sign_list_, servers_);
worker_ptr_ = std::shared_ptr<paddle::distributed::PSClient>(
paddle::distributed::PSClientFactory::create(worker_proto));
worker_ptr_->configure(worker_proto, dense_regions, _ps_env, 0);
}
void RunBrpcPushSparse() {
setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1);
auto ph_host = paddle::distributed::PSHost(ip_, port_, 0);
host_sign_list_.push_back(ph_host.serialize_to_string());
  // Start Server
std::thread server_thread(RunServer);
sleep(1);
// Start Client
framework::Scope client_scope;
platform::CPUPlace place;
InitTensorsOnClient(&client_scope, &place, 100);
std::map<uint64_t, std::vector<paddle::distributed::Region>> dense_regions;
dense_regions.insert(
std::pair<uint64_t, std::vector<paddle::distributed::Region>>(0, {}));
auto regions = dense_regions[0];
framework::Variable* var = client_scope.FindVar("x");
framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>();
RunClient(dense_regions);
std::vector<uint64_t> fea_keys(10);
std::vector<float> fea_values(100);
std::vector<float> fea_temp_values(100);
std::vector<float*> fea_value_ptr(10);
std::vector<float*> fea_temp_value_ptr(10);
for (size_t idx = 0; idx < fea_keys.size(); ++idx) {
fea_keys[idx] = (uint64_t)idx;
fea_value_ptr[idx] = fea_values.data() + idx * 10;
fea_temp_value_ptr[idx] = fea_temp_values.data() + idx * 10;
}
/*-----------------------Test Server Init----------------------------------*/
LOG(INFO) << "Run pull_sparse_param";
auto pull_status = worker_ptr_->pull_sparse(fea_value_ptr.data(), 0,
fea_keys.data(), fea_keys.size());
pull_status.wait();
for (size_t idx = 0; idx < tensor->numel(); ++idx) {
fea_values.data()[idx] *= 2.0;
}
/*-----------------------Test Push Param----------------------------------*/
LOG(INFO) << "Run push_sparse_param";
paddle::distributed::DownpourBrpcClosure* closure_push_param =
new paddle::distributed::DownpourBrpcClosure(1, [&](void* done) {
int ret = 0;
auto* closure = (paddle::distributed::DownpourBrpcClosure*)done;
for (size_t i = 0; i < 1; ++i) {
if (closure->check_response(i, paddle::PS_PUSH_SPARSE_PARAM) != 0) {
ret = -1;
break;
}
}
closure->set_promise_value(ret);
});
auto push_status = worker_ptr_->push_sparse_param(
0, fea_keys.data(), (const float**)fea_value_ptr.data(), fea_keys.size(),
closure_push_param);
push_status.wait();
auto pull_param_status = worker_ptr_->pull_sparse(
fea_temp_value_ptr.data(), 0, fea_keys.data(), fea_keys.size());
pull_param_status.wait();
for (size_t idx = 0; idx < tensor->numel(); ++idx) {
EXPECT_FLOAT_EQ(fea_temp_values[idx], fea_values[idx]);
}
/*-----------------------Test Push Grad----------------------------------*/
paddle::distributed::DownpourBrpcClosure* closure_push_grad =
new paddle::distributed::DownpourBrpcClosure(1, [&](void* done) {
int ret = 0;
auto* closure = (paddle::distributed::DownpourBrpcClosure*)done;
for (size_t i = 0; i < 1; ++i) {
if (closure->check_response(i, paddle::PS_PUSH_SPARSE_TABLE) != 0) {
ret = -1;
break;
}
}
closure->set_promise_value(ret);
});
LOG(INFO) << "Run pull_sparse_grad";
std::vector<float*> push_g_vec;
for (auto i = 0; i < static_cast<int>(fea_keys.size()); ++i) {
push_g_vec.push_back(tensor->data<float>() + i * 10);
}
auto push_grad_status = worker_ptr_->push_sparse_raw_gradient(
0, fea_keys.data(), (const float**)push_g_vec.data(), fea_keys.size(),
closure_push_grad);
push_grad_status.wait();
auto pull_update_status = worker_ptr_->pull_sparse(
fea_temp_value_ptr.data(), 0, fea_keys.data(), fea_keys.size());
pull_update_status.wait();
for (size_t idx = 0; idx < tensor->numel(); ++idx) {
EXPECT_FLOAT_EQ(fea_temp_values[idx], fea_values[idx] - 1.0);
}
LOG(INFO) << "Run stop_server";
worker_ptr_->stop_server();
LOG(INFO) << "Run finalize_worker";
worker_ptr_->finalize_worker();
server_thread.join();
}
TEST(RunBrpcPushSparse, Run) { RunBrpcPushSparse(); }
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <unistd.h>
#include <string>
#include <thread> // NOLINT
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/distributed/service/brpc_utils.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h"
namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace operators = paddle::operators;
namespace math = paddle::operators::math;
namespace memory = paddle::memory;
namespace distributed = paddle::distributed;
void CreateVarsOnScope(framework::Scope* scope, platform::Place* place,
const platform::DeviceContext& ctx) {
// var 1
framework::Variable* var1 = scope->Var("x1");
auto* tensor1 = var1->GetMutable<framework::LoDTensor>();
tensor1->Resize(framework::make_ddim({512, 8, 4, 2}));
framework::LoD lod1;
lod1.push_back(framework::Vector<size_t>({1, 3, 8}));
tensor1->set_lod(lod1);
tensor1->mutable_data<float>(*place);
math::set_constant(ctx, tensor1, 31.9);
// var 2
framework::Variable* var2 = scope->Var("x2");
auto* tensor2 = var2->GetMutable<framework::LoDTensor>();
tensor2->Resize(framework::make_ddim({1000, 64}));
framework::LoD lod2;
lod2.push_back(framework::Vector<size_t>({1, 1}));
tensor2->set_lod(lod2);
tensor2->mutable_data<int>(*place);
math::set_constant(ctx, tensor2, 100);
// var 3
framework::Variable* var3 = scope->Var("x3");
auto* slr = var3->GetMutable<framework::SelectedRows>();
slr->set_height(564);
auto* tensor3 = slr->mutable_value();
auto* rows = slr->mutable_rows();
tensor3->Resize(framework::make_ddim({564, 128}));
tensor3->mutable_data<float>(*place);
math::set_constant(ctx, tensor3, 32.7);
for (int i = 0; i < 564; ++i) rows->push_back(i);
}
void RunMultiVarMsg(platform::Place place) {
framework::Scope scope;
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
CreateVarsOnScope(&scope, &place, ctx);
::paddle::MultiVariableMessage multi_msg;
std::string message_name("se_de_test");
std::vector<std::string> send_var_name = {"x1", "x2", "x3"};
std::vector<std::string> recv_var_name = {};
LOG(INFO) << "begin SerializeToMultiVarMsg";
butil::IOBuf io_buf;
distributed::SerializeToMultiVarMsgAndIOBuf(message_name, send_var_name,
recv_var_name, ctx, &scope,
&multi_msg, &io_buf);
EXPECT_GT(multi_msg.ByteSizeLong(), static_cast<size_t>(0));
// deserialize
framework::Scope scope_recv;
LOG(INFO) << "begin DeserializeFromMultiVarMsg";
distributed::DeserializeFromMultiVarMsgAndIOBuf(multi_msg, &io_buf, ctx,
&scope_recv);
// check var1
framework::Variable* var1 = scope_recv.FindVar("x1");
auto* tensor1 = var1->GetMutable<framework::LoDTensor>();
EXPECT_EQ(tensor1->dims(), framework::make_ddim({512, 8, 4, 2}));
// EXPECT_EQ(tensor1->lod(), framework::Vector<size_t>({1, 3, 8}));
auto* tensor_data1 = const_cast<float*>(tensor1->data<float>());
int tensor_numel1 = 512 * 8 * 4 * 2;
for (int i = 0; i < tensor_numel1; ++i)
EXPECT_FLOAT_EQ(tensor_data1[i], 31.9);
// check var2
framework::Variable* var2 = scope_recv.FindVar("x2");
auto* tensor2 = var2->GetMutable<framework::LoDTensor>();
EXPECT_EQ(tensor2->dims(), framework::make_ddim({1000, 64}));
// EXPECT_EQ(tensor2->lod(), framework::Vector<size_t>({1, 1}));
auto* tensor_data2 = const_cast<int*>(tensor2->data<int>());
int tensor_numel2 = 1000 * 64;
for (int i = 0; i < tensor_numel2; ++i) EXPECT_EQ(tensor_data2[i], 100);
// check var3
framework::Variable* var3 = scope_recv.FindVar("x3");
auto* slr = var3->GetMutable<framework::SelectedRows>();
EXPECT_EQ(slr->rows().size(), 564);
for (int i = 0; i < 564; ++i) {
EXPECT_EQ(slr->rows()[i], i);
}
auto* tensor3 = slr->mutable_value();
EXPECT_EQ(tensor3->dims(), framework::make_ddim({564, 128}));
auto* tensor_data3 = const_cast<float*>(tensor3->data<float>());
int tensor_numel3 = 564 * 128;
for (int i = 0; i < tensor_numel3; ++i)
EXPECT_FLOAT_EQ(tensor_data3[i], 32.7);
}
TEST(MultiVarMsgCPU, Run) {
platform::CPUPlace place;
RunMultiVarMsg(place);
}
// #ifdef PADDLE_WITH_CUDA
// TEST(MultiVarMsgGPU, Run) {
// platform::CUDAPlace place;
// RunMultiVarMsg(place);
// }
// #endif
\ No newline at end of file
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <ThreadPool.h>
#include <unistd.h>
#include <cmath>
#include <string>
#include <thread> // NOLINT
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/table/common_dense_table.h"
#include "paddle/fluid/distributed/table/common_sparse_table.h"
#include "paddle/fluid/distributed/table/sparse_geo_table.h"
#include "paddle/fluid/distributed/table/table.h"
namespace paddle {
namespace distributed {
// CommonDenseTable + Adam
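// Pushes one gradient batch per trainer, replays the Adam update on the host
// with the same beta1/beta2/epsilon, and checks that the pulled parameters
// match the reference within 1e-6.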
TEST(CommonDenseTable, Adam) {
int fea_dim = 10;
int trainers = 2;
float beta1 = 0.9;
float beta2 = 0.999;
float epsilon = 1.0e-8;
TableParameter table_config;
table_config.set_table_class("CommonDenseTable");
FsClientParameter fs_config;
Table *table = new CommonDenseTable();
TableAccessorParameter *accessor_config = table_config.mutable_accessor();
accessor_config->set_accessor_class("CommMergeAccessor");
CommonAccessorParameter *common_config = table_config.mutable_common();
// set adam optimize config
common_config->set_name("adam");
common_config->set_table_name("adam_test_table");
common_config->set_trainer_num(trainers);
common_config->add_params("Param");
common_config->add_dims(fea_dim);
common_config->add_initializers("gaussian_random&0&0.0&1.0");
common_config->add_params("LearningRate");
common_config->add_dims(1);
common_config->add_initializers("fill_constant&1.0");
common_config->add_params("Moment1");
common_config->add_dims(fea_dim);
common_config->add_initializers("fill_constant&0.0");
common_config->add_params("Moment2");
common_config->add_dims(fea_dim);
common_config->add_initializers("fill_constant&0.0");
common_config->add_params("Beta1Pow");
common_config->add_dims(1);
common_config->add_initializers("fill_constant&1.0");
common_config->add_params("Beta2Pow");
common_config->add_dims(1);
common_config->add_initializers("fill_constant&1.0");
auto ret = table->initialize(table_config, fs_config);
ASSERT_EQ(ret, 0);
// pull parameters for create and check
std::vector<float> init_values;
init_values.resize(fea_dim);
table->pull_dense(init_values.data(), fea_dim);
// push gradient
std::vector<std::vector<float>> trainer_gradient_values;
trainer_gradient_values.resize(trainers);
float start = 10.0;
for (int i = 0; i < trainers; i++) {
for (int k = 0; k < fea_dim; k++) {
trainer_gradient_values[i].push_back(start);
start += 0.1;
}
}
// for adam
for (int i = 0; i < trainers; i++) {
auto &push_values = trainer_gradient_values[i];
table->push_dense(push_values.data(), push_values.size());
}
std::vector<float> pull_values;
pull_values.resize(fea_dim);
table->pull_dense(pull_values.data(), fea_dim);
std::vector<float> beta1_pow, beta2_pow, lr, mom1, mom2, param;
beta1_pow.push_back(beta1);
beta2_pow.push_back(beta2);
lr.push_back(1.0);
for (int i = 0; i < fea_dim; i++) {
mom1.push_back(0.0);
mom2.push_back(0.0);
param.push_back(init_values[i]);
}
for (int i = 0; i < trainers; i++) {
auto lr_ = lr[0] * sqrt(1 - beta2_pow[0]) / (1 - beta1_pow[0]);
for (int j = 0; j < fea_dim; j++) {
mom1[j] = beta1 * mom1[j] + (1 - beta1) * trainer_gradient_values[i][j];
mom2[j] = beta2 * mom2[j] +
(1 - beta2) * trainer_gradient_values[i][j] *
trainer_gradient_values[i][j];
param[j] =
param[j] -
lr_ * (mom1[j] / (sqrt(mom2[j]) + epsilon * sqrt(1 - beta2_pow[0])));
}
beta1_pow[0] *= beta1;
beta2_pow[0] *= beta2;
}
for (int j = 0; j < fea_dim; j++) {
    ASSERT_TRUE(std::abs(param[j] - pull_values[j]) < 1e-6);
}
}
// CommonDenseTable + SGD
TEST(CommonDenseTable, SGD) {
int fea_dim = 10;
int trainers = 2;
TableParameter table_config;
table_config.set_table_class("CommonDenseTable");
FsClientParameter fs_config;
Table *table = new CommonDenseTable();
TableAccessorParameter *accessor_config = table_config.mutable_accessor();
accessor_config->set_accessor_class("CommMergeAccessor");
CommonAccessorParameter *common_config = table_config.mutable_common();
common_config->set_name("sgd");
common_config->set_table_name("sgd_test_table");
common_config->set_trainer_num(trainers);
common_config->add_params("Param");
common_config->add_dims(fea_dim);
common_config->add_initializers("gaussian_random&0&0.0&1.0");
common_config->add_params("LearningRate");
common_config->add_dims(1);
common_config->add_initializers("fill_constant&1.0");
auto ret = table->initialize(table_config, fs_config);
ASSERT_EQ(ret, 0);
// pull parameters for create and check
std::vector<float> init_values;
init_values.resize(fea_dim);
table->pull_dense(init_values.data(), fea_dim);
std::vector<float> total_gradients;
total_gradients.resize(fea_dim);
memset(total_gradients.data(), 0, sizeof(float) * total_gradients.size());
// push gradient
std::vector<std::vector<float>> trainer_gradient_values;
trainer_gradient_values.resize(trainers);
float start = 10.0;
for (int i = 0; i < trainers; i++) {
for (int k = 0; k < fea_dim; k++) {
trainer_gradient_values[i].push_back(start);
total_gradients[k] += start;
start += 0.1;
}
}
std::shared_ptr<::ThreadPool> pool_ =
std::make_shared<::ThreadPool>(trainers);
std::vector<std::future<void>> task_status;
for (int i = 0; i < trainers; i++) {
auto &push_values = trainer_gradient_values[i];
auto task = [table, &push_values] {
table->push_dense(push_values.data(), push_values.size());
};
task_status.push_back(pool_->enqueue(std::move(task)));
}
for (auto &status : task_status) {
status.wait();
}
std::vector<float> pull_values;
pull_values.resize(fea_dim);
table->pull_dense(pull_values.data(), fea_dim);
for (int j = 0; j < fea_dim; j++) {
auto update_val = init_values[j] - 1.0 * total_gradients[j];
    ASSERT_TRUE(std::abs(update_val - pull_values[j]) < 1e-5);
}
}
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <ThreadPool.h>
#include <unistd.h>
#include <cmath>
#include <string>
#include <thread> // NOLINT
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/table/common_dense_table.h"
#include "paddle/fluid/distributed/table/common_sparse_table.h"
#include "paddle/fluid/distributed/table/sparse_geo_table.h"
#include "paddle/fluid/distributed/table/table.h"
namespace paddle {
namespace distributed {
// SparseGeoTable + SSUM
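// Initializes the geo table to zeros via push_sparse_param, pushes increments
// from two trainers concurrently, and checks that pull_geo_param returns the
// accumulated values for every recorded id.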
TEST(SparseGeoTable, SSUM) {
int emb_dim = 10;
int trainers = 2;
TableParameter table_config;
table_config.set_table_class("SparseGeoTable");
FsClientParameter fs_config;
Table *table = new SparseGeoTable();
TableAccessorParameter *accessor_config = table_config.mutable_accessor();
accessor_config->set_accessor_class("CommMergeAccessor");
CommonAccessorParameter *common_config = table_config.mutable_common();
common_config->set_name("sum");
common_config->set_table_name("ssum_test_table");
common_config->set_trainer_num(trainers);
common_config->add_params("Param");
common_config->add_dims(emb_dim);
common_config->add_initializers("fill_constant&1.0");
auto ret = table->initialize(table_config, fs_config);
ASSERT_EQ(ret, 0);
// test push_sparse_param, and create params
std::vector<uint64_t> init_keys = {0, 1, 2, 3, 4};
std::vector<float> init_values;
for (size_t i = 0; i < init_keys.size() * emb_dim; i++) {
init_values.push_back(0.0);
}
table->push_sparse_param(init_keys.data(), init_values.data(),
init_keys.size());
std::vector<float> pull_values(init_values.size());
table->pull_sparse(pull_values.data(), init_keys.data(), init_keys.size());
for (size_t i = 0; i < init_keys.size() * emb_dim; i++) {
    ASSERT_TRUE(std::abs(pull_values[i] - init_values[i]) < 1e-6);
}
std::vector<std::vector<uint64_t>> trainer_keys;
std::vector<std::vector<float>> trainer_values;
trainer_keys.resize(trainers);
trainer_values.resize(trainers);
float start = 0.0;
for (int i = 0; i < trainers; i++) {
trainer_keys[i] = init_keys;
for (size_t j = 0; j < trainer_keys[i].size(); j++) {
auto id = trainer_keys[i][j];
for (int k = 0; k < emb_dim; k++) {
trainer_values[i].push_back(start);
pull_values[id * emb_dim + k] += start;
start += 0.1;
}
}
}
std::shared_ptr<::ThreadPool> pool_ =
std::make_shared<::ThreadPool>(trainers);
std::vector<std::future<void>> task_status;
for (int i = 0; i < trainers; i++) {
auto &push_keys = trainer_keys[i];
auto &push_values = trainer_values[i];
auto task = [table, &push_keys, &push_values] {
table->push_sparse(push_keys.data(), push_values.data(),
push_keys.size());
};
task_status.push_back(pool_->enqueue(std::move(task)));
}
for (auto &status : task_status) {
status.wait();
}
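  // pull_geo_param is expected to hand each trainer the merged SUM delta for
  // the ids it touched; compare against the expected values in pull_values.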
std::vector<std::vector<uint64_t>> geo_pull_ids;
std::vector<std::vector<float>> geo_pull_values;
geo_pull_ids.resize(trainers);
geo_pull_values.resize(trainers);
for (int i = 0; i < trainers; i++) {
table->pull_geo_param(i, &geo_pull_values[i], &geo_pull_ids[i]);
ASSERT_EQ(geo_pull_values[i].size(), geo_pull_ids[i].size() * emb_dim);
for (size_t j = 0; j < geo_pull_ids[i].size(); ++j) {
auto id = geo_pull_ids[i][j];
for (int k = 0; k < emb_dim; k++) {
ASSERT_TRUE(abs(geo_pull_values[i][j * emb_dim + k] -
pull_values[id * emb_dim + k]) < 1e-5);
}
}
}
}
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <ThreadPool.h>
#include <unistd.h>
#include <string>
#include <thread> // NOLINT
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/table/common_dense_table.h"
#include "paddle/fluid/distributed/table/common_sparse_table.h"
#include "paddle/fluid/distributed/table/sparse_geo_table.h"
#include "paddle/fluid/distributed/table/table.h"
namespace paddle {
namespace distributed {
// CommonSparseTable + SSGD
TEST(CommonSparseTable, SGD) {
int emb_dim = 10;
int trainers = 2;
TableParameter table_config;
table_config.set_table_class("CommonSparseTable");
FsClientParameter fs_config;
Table *table = new CommonSparseTable();
TableAccessorParameter *accessor_config = table_config.mutable_accessor();
accessor_config->set_accessor_class("CommMergeAccessor");
CommonAccessorParameter *common_config = table_config.mutable_common();
common_config->set_name("sgd");
common_config->set_table_name("sgd_test_table");
common_config->set_trainer_num(trainers);
common_config->add_params("Param");
common_config->add_dims(emb_dim);
common_config->add_initializers("uniform_random&0&-1.0&1.0"); // param
common_config->add_params("LearningRate");
common_config->add_dims(1);
common_config->add_initializers("fill_constant&1.0"); // learning_rate
auto ret = table->initialize(table_config, fs_config);
ASSERT_EQ(ret, 0);
// pull parameters for create and check
std::vector<uint64_t> init_keys = {0, 1, 2, 3, 4};
std::vector<float> init_values;
init_values.resize(init_keys.size() * emb_dim);
table->pull_sparse(init_values.data(), init_keys.data(), init_keys.size());
// for check
std::vector<float> total_gradients;
total_gradients.resize(init_keys.size() * emb_dim);
memset(total_gradients.data(), 0, sizeof(float) * total_gradients.size());
// push gradient
std::vector<std::vector<uint64_t>> trainer_keys;
std::vector<std::vector<float>> trainer_gradient_values;
trainer_keys.resize(trainers);
trainer_gradient_values.resize(trainers);
float start = 0.0;
for (int i = 0; i < trainers; i++) {
trainer_keys[i] = init_keys;
for (size_t j = 0; j < trainer_keys[i].size(); j++) {
auto id = trainer_keys[i][j];
for (int k = 0; k < emb_dim; k++) {
trainer_gradient_values[i].push_back(start);
total_gradients[id * emb_dim + k] += start;
start += 0.1;
}
}
}
std::shared_ptr<::ThreadPool> pool_ =
std::make_shared<::ThreadPool>(trainers);
std::vector<std::future<void>> task_status;
for (int i = 0; i < trainers; i++) {
auto &push_keys = trainer_keys[i];
auto &push_values = trainer_gradient_values[i];
auto task = [table, &push_keys, &push_values] {
table->push_sparse(push_keys.data(), push_values.data(),
push_keys.size());
};
task_status.push_back(pool_->enqueue(std::move(task)));
}
for (auto &status : task_status) {
status.wait();
}
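  // With SUM-merged gradients and learning rate 1.0 (fill_constant&1.0 above),
  // SGD should leave param = init - total_gradient in every slot.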
std::vector<float> pull_values;
pull_values.resize(init_keys.size() * emb_dim);
table->pull_sparse(pull_values.data(), init_keys.data(), init_keys.size());
for (size_t i = 0; i < init_values.size(); ++i) {
auto update_val = init_values[i] - 1.0 * total_gradients[i];
ASSERT_TRUE(abs(update_val - pull_values[i]) < 1e-5);
}
}
// CommonSparseTable + Adam
TEST(CommonSparseTable, Adam) {
int emb_dim = 10;
int trainers = 2;
float beta1 = 0.9;
float beta2 = 0.999;
float epsilon = 1.0e-8;
TableParameter table_config;
table_config.set_table_class("CommonSparseTable");
FsClientParameter fs_config;
Table *table = new CommonSparseTable();
TableAccessorParameter *accessor_config = table_config.mutable_accessor();
accessor_config->set_accessor_class("CommMergeAccessor");
CommonAccessorParameter *common_config = table_config.mutable_common();
common_config->set_name("adam");
common_config->set_table_name("adam_test_table");
common_config->set_trainer_num(trainers);
common_config->add_params("Param");
common_config->add_dims(emb_dim);
common_config->add_initializers("uniform_random&0&-1.0&1.0");
common_config->add_params("LearningRate");
common_config->add_dims(1);
common_config->add_initializers("fill_constant&1.0");
common_config->add_params("Moment1");
common_config->add_dims(emb_dim);
common_config->add_initializers("fill_constant&0.0");
common_config->add_params("Moment2");
common_config->add_dims(emb_dim);
common_config->add_initializers("fill_constant&0.0");
common_config->add_params("Beta1Pow");
common_config->add_dims(1);
common_config->add_initializers("fill_constant&1.0");
common_config->add_params("Beta2Pow");
common_config->add_dims(1);
common_config->add_initializers("fill_constant&1.0");
auto ret = table->initialize(table_config, fs_config);
ASSERT_EQ(ret, 0);
// pull parameters for create and check
std::vector<uint64_t> init_keys = {0, 1, 2, 3, 4};
std::vector<float> init_values;
init_values.resize(init_keys.size() * emb_dim);
table->pull_sparse(init_values.data(), init_keys.data(), init_keys.size());
// push gradient
std::vector<std::vector<uint64_t>> trainer_keys;
std::vector<std::vector<float>> trainer_gradient_values;
trainer_keys.resize(trainers);
trainer_gradient_values.resize(trainers);
float start = 0.0;
for (int i = 0; i < trainers; i++) {
trainer_keys[i] = init_keys;
for (size_t j = 0; j < trainer_keys[i].size(); j++) {
for (int k = 0; k < emb_dim; k++) {
trainer_gradient_values[i].push_back(start);
start += 0.1;
}
}
}
for (int i = 0; i < trainers; i++) {
auto &push_keys = trainer_keys[i];
auto &push_values = trainer_gradient_values[i];
table->push_sparse(push_keys.data(), push_values.data(), push_keys.size());
}
std::vector<float> pull_values;
pull_values.resize(init_keys.size() * emb_dim);
table->pull_sparse(pull_values.data(), init_keys.data(), init_keys.size());
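  // Reference computation: apply the standard Adam update once per trainer's
  // push (the table presumably treats each push as a separate Adam step),
  // including bias correction via beta1_pow / beta2_pow.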
  for (size_t idx = 0; idx < init_keys.size() * emb_dim; idx += emb_dim) {
std::vector<float> beta1_pow, beta2_pow, lr, mom1, mom2, param;
beta1_pow.push_back(beta1);
beta2_pow.push_back(beta2);
lr.push_back(1.0);
for (int i = 0; i < emb_dim; i++) {
mom1.push_back(0.0);
mom2.push_back(0.0);
param.push_back(init_values[idx + i]);
}
for (int i = 0; i < trainers; i++) {
auto lr_ = lr[0] * sqrt(1 - beta2_pow[0]) / (1 - beta1_pow[0]);
for (int j = 0; j < emb_dim; j++) {
mom1[j] =
beta1 * mom1[j] + (1 - beta1) * trainer_gradient_values[i][idx + j];
mom2[j] = beta2 * mom2[j] +
(1 - beta2) * trainer_gradient_values[i][idx + j] *
trainer_gradient_values[i][idx + j];
param[j] = param[j] -
lr_ * (mom1[j] /
(sqrt(mom2[j]) + epsilon * sqrt(1 - beta2_pow[0])));
}
beta1_pow[0] *= beta1;
beta2_pow[0] *= beta2;
}
for (int i = 0; i < emb_dim; i++) {
ASSERT_TRUE(abs(param[i] - pull_values[idx + i]) < 1e-5);
}
}
}
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <ThreadPool.h>
#include <unistd.h>
#include <string>
#include <thread> // NOLINT
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/table/common_dense_table.h"
#include "paddle/fluid/distributed/table/common_sparse_table.h"
#include "paddle/fluid/distributed/table/sparse_geo_table.h"
#include "paddle/fluid/distributed/table/table.h"
namespace paddle {
namespace distributed {
TEST(Table, Initialize) {
TableParameter table_config;
table_config.set_table_class("SparseGeoTable");
FsClientParameter fs_config;
// case 1. no accessor
Table *table = new SparseGeoTable();
auto ret = table->initialize(table_config, fs_config);
ASSERT_EQ(ret, -1);
}
} // namespace distributed
} // namespace paddle