diff --git a/paddle/fluid/CMakeLists.txt b/paddle/fluid/CMakeLists.txt index 16457b564ffc82a4246776dc283261bed0351ec6..c18332d3b873164a725a25316fc611aa7e7a3092 100644 --- a/paddle/fluid/CMakeLists.txt +++ b/paddle/fluid/CMakeLists.txt @@ -1,5 +1,6 @@ add_subdirectory(memory) add_subdirectory(platform) +add_subdirectory(distributed) add_subdirectory(framework) add_subdirectory(imperative) add_subdirectory(operators) diff --git a/paddle/fluid/distributed/CMakeLists.txt b/paddle/fluid/distributed/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..e99b8b76534369c81e79b274bfdd18fb0e73b394 --- /dev/null +++ b/paddle/fluid/distributed/CMakeLists.txt @@ -0,0 +1,30 @@ +if(NOT WITH_DISTRIBUTE) + return() +endif() + +proto_library(ps_framework_proto SRCS ps.proto) + +set(DISTRIBUTE_COMPILE_FLAGS "-Wno-error=unused-value -Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor -Wno-error=sign-compare -Wno-error=unused-variable -Wno-error=return-type -Wno-error=unused-but-set-variable -Wno-error=type-limits -Wno-error=unknown-pragmas -Wno-error=parentheses -Wno-error=unused-result") + +if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0) + set(DISTRIBUTE_COMPILE_FLAGS + "${DISTRIBUTE_COMPILE_FLAGS} -faligned-new") +endif() + + +add_subdirectory(table) +add_subdirectory(test) + +# open it until CI support brpc +return() + +add_subdirectory(service) + +get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS) + +set_source_files_properties(fleet.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_library(fleet + SRCS fleet.cc + DEPS framework_proto ps_framework_proto ps_service variable_helper scope op_registry fs shell ${RPC_DEPS}) + +target_link_libraries(fleet z) diff --git a/paddle/fluid/distributed/common/registerer.h b/paddle/fluid/distributed/common/registerer.h new file mode 100644 index 0000000000000000000000000000000000000000..a4eab9c4a75e9ecabb183a9f41460a8b0cb516f6 --- /dev/null +++ b/paddle/fluid/distributed/common/registerer.h @@ -0,0 +1,127 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +namespace paddle { +namespace distributed { + +class Any { + public: + Any() : content_(NULL) {} + + template + Any(const ValueType &value) : content_(new Holder(value)) {} + + Any(const Any &other) + : content_(other.content_ ? other.content_->clone() : NULL) {} + + ~Any() { delete content_; } + + template + ValueType *any_cast() { + return content_ ? &static_cast *>(content_)->held_ : NULL; + } + + private: + class PlaceHolder { + public: + virtual ~PlaceHolder() {} + virtual PlaceHolder *clone() const = 0; + }; + + template + class Holder : public PlaceHolder { + public: + explicit Holder(const ValueType &value) : held_(value) {} + virtual PlaceHolder *clone() const { return new Holder(held_); } + + ValueType held_; + }; + + PlaceHolder *content_; +}; + +class ObjectFactory { + public: + ObjectFactory() {} + virtual ~ObjectFactory() {} + virtual Any NewInstance() { return Any(); } + + private: +}; + +typedef std::map FactoryMap; +typedef std::map BaseClassMap; +#ifdef __cplusplus +extern "C" { +#endif + +inline BaseClassMap &global_factory_map() { + static BaseClassMap *base_class = new BaseClassMap(); + return *base_class; +} +#ifdef __cplusplus +} +#endif + +inline BaseClassMap &global_factory_map_cpp() { return global_factory_map(); } + +// typedef pa::Any Any; +// typedef ::FactoryMap FactoryMap; +#define REGISTER_REGISTERER(base_class) \ + class base_class##Registerer { \ + public: \ + static base_class *CreateInstanceByName(const ::std::string &name) { \ + if (global_factory_map_cpp().find(#base_class) == \ + global_factory_map_cpp().end()) { \ + LOG(ERROR) << "Can't Find BaseClass For CreateClass with:" \ + << #base_class; \ + return NULL; \ + } \ + FactoryMap &map = global_factory_map_cpp()[#base_class]; \ + FactoryMap::iterator iter = map.find(name); \ + if (iter == map.end()) { \ + LOG(ERROR) << "Can't Find Class For Create with:" << name; \ + return NULL; \ + } \ + Any object = iter->second->NewInstance(); \ + return *(object.any_cast()); \ + } \ + }; + +#define REGISTER_CLASS(clazz, name) \ + class ObjectFactory##name : public ObjectFactory { \ + public: \ + Any NewInstance() { return Any(new name()); } \ + }; \ + void register_factory_##name() { \ + FactoryMap &map = global_factory_map_cpp()[#clazz]; \ + if (map.find(#name) == map.end()) { \ + map[#name] = new ObjectFactory##name(); \ + } \ + } \ + void register_factory_##name() __attribute__((constructor)); + +#define CREATE_CLASS(base_class, name) \ + base_class##Registerer::CreateInstanceByName(name); + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/common/utils.h b/paddle/fluid/distributed/common/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..f81f84b1e117510443a5698a6ba1574262f640a5 --- /dev/null +++ b/paddle/fluid/distributed/common/utils.h @@ -0,0 +1,87 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/platform/device_context.h" + +namespace paddle { +namespace distributed { + +template +inline paddle::operators::math::BlasT +GetBlas() { + auto cpu_ctx = paddle::platform::CPUDeviceContext(); + return paddle::operators::math::GetBlas(cpu_ctx); +} + +template +inline void SQRT(int n, const T* x, T* z) { + for (int i = 0; i < n; ++i) { + z[i] = sqrt(x[i]); + } +} + +template +inline void ADD(int n, const T* x, const T y, T* z) { + for (int i = 0; i < n; ++i) { + z[i] = x[i] + y; + } +} + +static bool StartWith(const std::string& str, const std::string& substr) { + return str.find(substr) == 0; +} + +static bool EndWith(const std::string& str, const std::string& substr) { + return str.rfind(substr) == (str.length() - substr.length()); +} + +inline std::vector bucket(const int v_size, const int b_size) { + int remainder = v_size % b_size; + int bucket = v_size / b_size; + std::vector ret_vec(b_size, bucket); + for (int i = 0; i < remainder; ++i) { + ret_vec[i] = ret_vec[i] + 1; + } + int cur_bucket = 0; + for (int& j : ret_vec) { + int tmp = j; + j = cur_bucket; + cur_bucket += tmp; + } + ret_vec.push_back(cur_bucket); + return ret_vec; +} + +template +std::string to_string(const std::vector& vec) { + std::stringstream ss; + for (const auto& c : vec) { + ss << c << " "; + } + return ss.str(); +} +} +} diff --git a/paddle/fluid/distributed/communicator_common.h b/paddle/fluid/distributed/communicator_common.h new file mode 100644 index 0000000000000000000000000000000000000000..6a8ce552370bf72d95dd0d52a8e521afb0b0dfd0 --- /dev/null +++ b/paddle/fluid/distributed/communicator_common.h @@ -0,0 +1,95 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include + +namespace paddle { +namespace distributed { + +struct CommContext { + CommContext() = default; + + CommContext(const std::string &name, const std::vector &names, + const std::vector &emap, + const std::vector §ions, + const std::vector &origin_names, int id, + bool merge_add_ = true, bool is_sparse_ = true, + bool is_distributed_ = false, int table_id_ = -1) + : var_name(name), + splited_varnames(names), + epmap(emap), + height_sections(sections), + origin_varnames(origin_names), + trainer_id(id), + merge_add(merge_add_), + is_sparse(is_sparse_), + is_distributed(is_distributed_), + table_id(table_id_) {} + + CommContext(const CommContext &ctx) { + var_name = ctx.var_name; + splited_varnames = ctx.splited_varnames; + epmap = ctx.epmap; + height_sections = ctx.height_sections; + trainer_id = ctx.trainer_id; + merge_add = ctx.merge_add; + is_sparse = ctx.is_sparse; + origin_varnames = ctx.origin_varnames; + is_distributed = ctx.is_distributed; + table_id = ctx.table_id; + } + + std::string print() const { + std::stringstream ss; + + ss << "varname: " << var_name << " trainer_id: " << trainer_id << " "; + ss << " table_id: " << table_id; + + for (size_t i = 0; i < splited_varnames.size(); i++) { + ss << "slice varname: " << splited_varnames[i] << " ep: " << epmap[i] + << " section: " << height_sections[i] << " "; + } + + ss << "origin varnames: "; + for (size_t i = 0; i < origin_varnames.size(); i++) { + ss << origin_varnames[i] << " "; + } + + ss << " aggregation->add: " << merge_add; + ss << " is_sparse: " << is_sparse; + ss << " is_distributed: " << is_distributed << "\n"; + ss << " table_id: " << table_id << "\n"; + + return ss.str(); + } + + std::string var_name; + std::vector splited_varnames; + std::vector epmap; + std::vector height_sections; + std::vector origin_varnames; + int trainer_id; + bool merge_add; + bool is_sparse; + bool is_distributed; + int table_id; +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/fleet.cc b/paddle/fluid/distributed/fleet.cc new file mode 100644 index 0000000000000000000000000000000000000000..92211a72e748eb3ca7555a5a68707ad5a52dc4bf --- /dev/null +++ b/paddle/fluid/distributed/fleet.cc @@ -0,0 +1,585 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/distributed/fleet.h" +#include +#include +#include "paddle/fluid/distributed/service/communicator.h" +#include "paddle/fluid/distributed/table/table.h" +#include "paddle/fluid/framework/channel.h" +#include "paddle/fluid/framework/data_feed.h" +#include "paddle/fluid/framework/io/fs.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/scope.h" + +namespace paddle { +namespace distributed { + +using framework::LoDTensor; +using framework::ProgramDesc; +using framework::VarDesc; +using framework::Variable; + +const uint32_t MAX_FEASIGN_NUM = 1024 * 100 * 100; +std::shared_ptr FleetWrapper::s_instance_ = NULL; +bool FleetWrapper::is_initialized_ = false; + +std::shared_ptr FleetWrapper::pserver_ptr_ = NULL; + +void FleetWrapper::SetClient2ClientConfig(int request_timeout_ms, + int connect_timeout_ms, + int max_retry) { + client2client_request_timeout_ms_ = request_timeout_ms; + client2client_connect_timeout_ms_ = connect_timeout_ms; + client2client_max_retry_ = max_retry; +} + +void FleetWrapper::LoadSparseOnServer(const std::string& path, + const std::string& meta, + uint32_t table_id) { + VLOG(3) << "load sparse table " << table_id << " with " << path << " meta " + << meta; + pserver_ptr_->_server_ptr->table(table_id)->load(path, meta); +} + +void FleetWrapper::InitServer(const std::string& dist_desc, + const std::vector& host_sign_list, + int index) { + if (!is_initialized_) { + VLOG(3) << "Going to init server"; + pserver_ptr_ = std::shared_ptr( + new paddle::distributed::PSCore()); + pserver_ptr_->init_server(dist_desc, &host_sign_list, host_sign_list.size(), + index); + is_initialized_ = true; + } else { + VLOG(3) << "Server can be initialized only once"; + } +} + +// void FleetWrapper::InitWorker( +// const std::string& dist_desc, const std::vector& +// host_sign_list, Scope* scope, const RpcCtxMap& send_ctx, const +// std::unordered_map>& +// dense_varnames, +// const std::map& envs, int node_num, int index) +// { +// if (!is_initialized_) { +// VLOG(3) << "Going to init worker"; + +// Communicator::InitInstance( +// send_ctx, dense_varnames, dist_desc, host_sign_list, scope, envs); + +// pserver_ptr_ = std::shared_ptr( +// new paddle::distributed::PSCore()); +// pserver_ptr_->init_worker(dist_desc, _regions, +// const_cast(host_sign_list.data()), +// node_num, index); +// is_initialized_ = true; +// } else { +// VLOG(3) << "Worker can be initialized only once"; +// } +// } + +void FleetWrapper::InitWorker( + const std::string& dist_desc, + const std::vector& host_sign_list, Scope* scope, + const RpcCtxMap& send_ctx, + const std::unordered_map>& + dense_varnames, + const std::map& envs, int node_num, int index) { + if (!is_initialized_) { + VLOG(3) << "Going to init worker"; + + Communicator::InitInstance( + send_ctx, dense_varnames, dist_desc, host_sign_list, scope, envs); + + pserver_ptr_ = std::shared_ptr( + new paddle::distributed::PSCore()); + pserver_ptr_->init_worker(dist_desc, _regions, &host_sign_list, node_num, + index); + is_initialized_ = true; + } else { + VLOG(3) << "Worker can be initialized only once"; + } +} + +void FleetWrapper::StopServer() { + VLOG(3) << "Going to stop server"; + auto* communicator = Communicator::GetInstance(); + auto status = communicator->_worker_ptr->stop_server(); + status.wait(); +} + +void FleetWrapper::FinalizeWorker() { + VLOG(3) << "Going to finalize worker"; + pserver_ptr_->finalize_worker(); +} + +void FleetWrapper::BarrierWithTable(uint32_t barrier_type) { + VLOG(3) << "Going to Barrier worker"; + auto* communicator = Communicator::GetInstance(); + communicator->BarrierWithTable(barrier_type); +} + +uint64_t FleetWrapper::RunServer(const std::string& ip, uint32_t port) { + VLOG(3) << "Going to run server with ip " << ip << " port " << port; + auto ret = pserver_ptr_->run_server(ip, port); + return ret; +} + +std::vector FleetWrapper::GetClientsInfo() { + VLOG(3) << "Going to get client info"; + return pserver_ptr_->get_client_info(); + return std::vector(); +} + +void FleetWrapper::CreateClient2ClientConnection() { + VLOG(3) << "Going to create client2client connection"; + pserver_ptr_->create_client2client_connection( + client2client_request_timeout_ms_, client2client_connect_timeout_ms_, + client2client_max_retry_); +} + +std::future FleetWrapper::PullSparseVarsAsync( + const Scope& scope, const uint64_t table_id, + const std::vector& var_names, std::vector* fea_keys, + std::vector>* fea_values, int fea_value_dim) { + fea_keys->clear(); + fea_keys->resize(0); + fea_keys->reserve(MAX_FEASIGN_NUM); + for (auto name : var_names) { + Variable* var = scope.FindVar(name); + if (var == nullptr) { + continue; + } + LoDTensor* tensor = var->GetMutable(); + CHECK(tensor != nullptr) << "tensor of var " << name << " is null"; + int64_t* ids = tensor->data(); + size_t len = tensor->numel(); + for (auto i = 0u; i < len; ++i) { + if (ids[i] == 0u) { + continue; + } + fea_keys->push_back(static_cast(ids[i])); + } + } + fea_values->resize(fea_keys->size() + 1); + for (auto& t : *fea_values) { + t.resize(fea_value_dim); + } + std::vector pull_result_ptr; + for (auto& t : *fea_values) { + pull_result_ptr.push_back(t.data()); + } + return pserver_ptr_->_worker_ptr->pull_sparse( + pull_result_ptr.data(), table_id, fea_keys->data(), fea_keys->size()); +} + +void FleetWrapper::PullSparseVarsSync( + const Scope& scope, const uint64_t table_id, + const std::vector& var_names, std::vector* fea_keys, + std::vector>* fea_values, int fea_value_dim, + const std::vector& var_emb_names) { + std::vector> pull_sparse_status; + pull_sparse_status.resize(0); + fea_keys->clear(); + fea_keys->resize(0); + fea_keys->reserve(MAX_FEASIGN_NUM); + for (size_t var_index = 0; var_index < var_names.size(); ++var_index) { + const std::string& name = var_names[var_index]; + Variable* var = scope.FindVar(name); + if (var == nullptr) { + continue; + } + LoDTensor* tensor = var->GetMutable(); + CHECK(tensor != nullptr) << "tensor of var " << name << " is null"; + int64_t* ids = tensor->data(); + size_t len = tensor->numel(); + + // skip slots which do not have embedding + const std::string& emb_name = var_emb_names[var_index]; + Variable* emb_var = scope.FindVar(emb_name); + if (emb_var == nullptr) { + continue; + } + + for (auto i = 0u; i < len; ++i) { + if (ids[i] == 0u) { + continue; + } + fea_keys->push_back(static_cast(ids[i])); + } + } + fea_values->resize(fea_keys->size() + 1); + for (auto& t : *fea_values) { + t.resize(fea_value_dim); + } + std::vector pull_result_ptr; + for (auto& t : *fea_values) { + pull_result_ptr.push_back(t.data()); + } + auto status = pserver_ptr_->_worker_ptr->pull_sparse( + pull_result_ptr.data(), table_id, fea_keys->data(), fea_keys->size()); + pull_sparse_status.push_back(std::move(status)); + for (auto& t : pull_sparse_status) { + t.wait(); + auto status = t.get(); + if (status != 0) { + LOG(ERROR) << "fleet pull sparse failed, status[" << status << "]"; + sleep(sleep_seconds_before_fail_exit_); + exit(-1); + } + } +} + +void FleetWrapper::PullSparseToTensorSync(const uint64_t table_id, int fea_dim, + uint64_t padding_id, + platform::Place place, + std::vector* inputs, + std::vector* outputs) { + std::vector fea_keys; + std::vector pull_result_ptr; + fea_keys.reserve(MAX_FEASIGN_NUM / 100); + pull_result_ptr.reserve(MAX_FEASIGN_NUM / 100); + std::vector init_value(fea_dim, 0); + framework::LoDTensor* output = nullptr; + float* output_data = nullptr; + size_t output_index = -1; + size_t output_len = 0; + for (size_t index = 0; index < inputs->size(); ++index) { + const framework::LoDTensor* tensor = inputs->at(index); + const int64_t* ids = tensor->data(); + size_t len = tensor->numel(); + for (size_t i = 0; i < len; ++i, output_len += fea_dim) { + if (!output || output_len == size_t(output->numel())) { + ++output_index; + CHECK(output_index < outputs->size()); // NOLINT + output = outputs->at(output_index); + output->set_lod(tensor->lod()); + output_data = output->mutable_data(place); + output_len = 0; + CHECK(output->numel() % fea_dim == 0); // NOLINT + CHECK(output_data != nullptr); // NOLINT + } + uint64_t real_id = static_cast(ids[i]); + if (real_id == padding_id) { + memcpy(output_data + output_len, init_value.data(), + sizeof(float) * fea_dim); + continue; + } + fea_keys.push_back(real_id); + pull_result_ptr.push_back(output_data + output_len); + } + } + auto* communicator = Communicator::GetInstance(); + auto status = communicator->_worker_ptr->pull_sparse( + pull_result_ptr.data(), table_id, fea_keys.data(), fea_keys.size()); + status.wait(); + auto ret = status.get(); + if (ret != 0) { + LOG(ERROR) << "fleet pull sparse failed, status[" << ret << "]"; + sleep(sleep_seconds_before_fail_exit_); + } +} + +void FleetWrapper::PullDenseVarsAsync( + const Scope& scope, const uint64_t tid, + const std::vector& var_names, + std::vector>* pull_dense_status, bool in_cpu) { + auto& regions = _regions[tid]; + regions.clear(); + regions.resize(var_names.size()); + for (auto i = 0u; i < var_names.size(); ++i) { + std::string varname = var_names[i]; + if (!in_cpu) { + varname = var_names[i] + "pin"; + } + Variable* var = scope.FindVar(varname); + LoDTensor* tensor = var->GetMutable(); + float* w = tensor->data(); + paddle::distributed::Region reg(w, tensor->numel()); + regions[i] = std::move(reg); + } + auto status = pserver_ptr_->_worker_ptr->pull_dense(regions.data(), + regions.size(), tid); + pull_dense_status->push_back(std::move(status)); +} + +void FleetWrapper::PullDenseVarsSync( + const Scope& scope, const uint64_t tid, + const std::vector& var_names) { + auto& regions = _regions[tid]; + regions.clear(); + regions.reserve(var_names.size()); + for (auto& t : var_names) { + Variable* var = scope.FindVar(t); + LoDTensor* tensor = var->GetMutable(); + float* w = tensor->data(); + paddle::distributed::Region reg(w, tensor->numel()); + regions.emplace_back(std::move(reg)); + } + auto* communicator = Communicator::GetInstance(); + auto status = communicator->_worker_ptr->pull_dense(regions.data(), + regions.size(), tid); + status.wait(); +} + +void FleetWrapper::PushDenseParamSync( + const Scope& scope, const uint64_t table_id, + const std::vector& var_names) { + auto place = platform::CPUPlace(); + std::vector regions; + for (auto& t : var_names) { + Variable* var = scope.FindVar(t); + CHECK(var != nullptr) << "var[" << t << "] not found"; + LoDTensor* tensor = var->GetMutable(); + float* g = tensor->mutable_data(place); + paddle::distributed::Region reg(g, tensor->numel()); + regions.emplace_back(std::move(reg)); + } + auto* communicator = Communicator::GetInstance(); + auto push_status = communicator->_worker_ptr->push_dense_param( + regions.data(), regions.size(), table_id); + push_status.wait(); + auto status = push_status.get(); + CHECK(status == 0) << "push dense param failed, status[" << status << "]"; +} + +void FleetWrapper::PushDenseVarsSync( + Scope* scope, const uint64_t table_id, + const std::vector& var_names) {} + +void FleetWrapper::PushDenseVarsAsync( + const Scope& scope, const uint64_t table_id, + const std::vector& var_names, + std::vector>* push_sparse_status, float scale_datanorm, + int batch_size) { + auto* communicator = Communicator::GetInstance(); + PADDLE_ENFORCE_EQ( + communicator->Check(table_id), true, + platform::errors::InvalidArgument( + "can not find table: %s, please check your config", table_id)); + communicator->Send(var_names, scope); +} + +void FleetWrapper::PushSparseVarsAsync( + const Scope& scope, const uint64_t table_id, + const std::string& grad_varname, + std::vector>* push_sparse_status) { + std::vector varnames; + varnames.push_back(grad_varname); + + auto* communicator = Communicator::GetInstance(); + PADDLE_ENFORCE_EQ( + communicator->Check(table_id), true, + platform::errors::InvalidArgument( + "can not find table: %s, please check your config", table_id)); + communicator->Send(varnames, scope); +} + +void FleetWrapper::PushSparseVarsWithLabelAsync( + const Scope& scope, const uint64_t table_id, + const std::vector& fea_keys, const std::vector& fea_labels, + const std::vector& sparse_key_names, + const std::vector& sparse_grad_names, const int emb_dim, + std::vector>* push_values, + std::vector>* push_sparse_status, const int batch_size, + const bool use_cvm, const bool dump_slot, + std::vector* sparse_push_keys, const bool no_cvm) { + // not support + return; +} + +void FleetWrapper::PushSparseFromTensorWithLabelAsync( + const Scope& scope, const uint64_t table_id, int fea_dim, + uint64_t padding_id, bool scale_sparse, const std::string& accesor, + const std::string& click_name, platform::Place place, + const std::vector& input_names, + std::vector* inputs, + std::vector* outputs) { + // not support + return; +} + +void FleetWrapper::LoadModel(const std::string& path, const int mode) { + auto ret = pserver_ptr_->_worker_ptr->load(path, std::to_string(mode)); + ret.wait(); + if (ret.get() != 0) { + LOG(ERROR) << "load model from path:" << path << " failed"; + sleep(sleep_seconds_before_fail_exit_); + exit(-1); + } +} + +void FleetWrapper::LoadModelOneTable(const uint64_t table_id, + const std::string& path, const int mode) { + auto ret = + pserver_ptr_->_worker_ptr->load(table_id, path, std::to_string(mode)); + ret.wait(); + if (ret.get() != 0) { + LOG(ERROR) << "load model of table id: " << table_id + << ", from path: " << path << " failed"; + } +} + +void FleetWrapper::SaveModel(const std::string& path, const int mode) { + auto* communicator = Communicator::GetInstance(); + auto ret = communicator->_worker_ptr->save(path, std::to_string(mode)); + ret.wait(); + int32_t feasign_cnt = ret.get(); + if (feasign_cnt == -1) { + LOG(ERROR) << "save model failed"; + sleep(sleep_seconds_before_fail_exit_); + exit(-1); + } +} + +void FleetWrapper::SaveModelOneTable(const uint64_t table_id, + const std::string& path, const int mode) { + auto* communicator = Communicator::GetInstance(); + auto ret = + communicator->_worker_ptr->save(table_id, path, std::to_string(mode)); + ret.wait(); + if (ret.get() != 0) { + LOG(ERROR) << "save model of table id: " << table_id + << ", to path: " << path << " failed"; + } +} + +void FleetWrapper::PrintTableStat(const uint64_t table_id) { + auto* communicator = Communicator::GetInstance(); + auto ret = communicator->_worker_ptr->print_table_stat(table_id); + ret.wait(); + int32_t err_code = ret.get(); + if (err_code == -1) { + LOG(ERROR) << "print table stat failed"; + } +} + +void FleetWrapper::ShrinkSparseTable(int table_id) { + auto ret = pserver_ptr_->_worker_ptr->shrink(table_id); + ret.wait(); +} + +void FleetWrapper::ClearModel() { + auto ret = pserver_ptr_->_worker_ptr->clear(); + ret.wait(); +} + +void FleetWrapper::ClearOneTable(const uint64_t table_id) { + auto ret = pserver_ptr_->_worker_ptr->clear(table_id); + ret.wait(); +} + +void FleetWrapper::ShrinkDenseTable(int table_id, Scope* scope, + std::vector var_list, + float decay, int emb_dim) { + std::vector regions; + for (std::string& name : var_list) { + if (name.find("batch_sum") != std::string::npos) { + Variable* var = scope->FindVar(name); + CHECK(var != nullptr) << "var[" << name << "] not found"; + VLOG(0) << "prepare shrink dense batch_sum"; + LoDTensor* tensor = var->GetMutable(); + float* g = tensor->data(); + + // show_batch_sum += N * log(decay) + std::string size_name = name; + size_name.replace(size_name.find("batch_sum"), size_name.length(), + "batch_size"); + Variable* var_size = scope->FindVar(size_name); + CHECK(var_size != nullptr) << "var[" << size_name << "] not found"; + VLOG(3) << "shrink dense batch_sum: " << name << ", " << size_name; + float* g_size = var_size->GetMutable()->data(); + + for (int k = 0; k < tensor->numel(); k += emb_dim) { + g[k] = g[k] + g_size[k] * log(decay); + } + paddle::distributed::Region reg(g, tensor->numel()); + regions.emplace_back(std::move(reg)); + } else { + Variable* var = scope->FindVar(name); + CHECK(var != nullptr) << "var[" << name << "] not found"; + LoDTensor* tensor = var->GetMutable(); + float* g = tensor->data(); + paddle::distributed::Region reg(g, tensor->numel()); + regions.emplace_back(std::move(reg)); + } + } + auto push_status = pserver_ptr_->_worker_ptr->push_dense_param( + regions.data(), regions.size(), table_id); + push_status.wait(); + auto status = push_status.get(); + if (status != 0) { + // PADDLE_THORW(platform::errors::Fatal( + // "push shrink dense param failed, status is [%d].", status)); + sleep(sleep_seconds_before_fail_exit_); + exit(-1); + } +} + +void FleetWrapper::ClientFlush() { + auto ret = pserver_ptr_->_worker_ptr->flush(); + ret.wait(); +} + +int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type, + MsgHandlerFunc handler) { + VLOG(3) << "calling FleetWrapper::RegisterClientToClientMsgHandler"; + VLOG(3) << "pserver_ptr_=" << pserver_ptr_; + VLOG(3) << "_worker_ptr=" << pserver_ptr_->_worker_ptr; + return pserver_ptr_->_worker_ptr->registe_client2client_msg_handler(msg_type, + handler); +} + +std::future FleetWrapper::SendClientToClientMsg( + int msg_type, int to_client_id, const std::string& msg) { + return pserver_ptr_->_worker_ptr->send_client2client_msg(msg_type, + to_client_id, msg); +} + +std::default_random_engine& FleetWrapper::LocalRandomEngine() { + struct engine_wrapper_t { + std::default_random_engine engine; + + engine_wrapper_t() { + struct timespec tp; + clock_gettime(CLOCK_REALTIME, &tp); + double cur_time = tp.tv_sec + tp.tv_nsec * 1e-9; + static std::atomic x(0); + std::seed_seq sseq = {x++, x++, x++, (uint64_t)(cur_time * 1000)}; + engine.seed(sseq); + } + }; + thread_local engine_wrapper_t r; + return r.engine; +} + +size_t FleetWrapper::GetAbsoluteSum(size_t start, size_t end, size_t level, + const framework::LoD& lod) { + if (level >= lod.size() - 1) { + return end - start; + } + size_t ret = 0; + for (size_t i = start; i < end - 1; ++i) { + size_t pos1 = lod[level][i]; + size_t pos2 = lod[level][i + 1]; + ret += GetAbsoluteSum(pos1, pos2, level + 1, lod); + } + return ret; +} + +} // end namespace distributed +} // end namespace paddle diff --git a/paddle/fluid/distributed/fleet.h b/paddle/fluid/distributed/fleet.h new file mode 100644 index 0000000000000000000000000000000000000000..7f106fafbf2e2e3cb8fd4e7769d97314ee2f31e5 --- /dev/null +++ b/paddle/fluid/distributed/fleet.h @@ -0,0 +1,246 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include "paddle/fluid/distributed/communicator_common.h" +#include "paddle/fluid/distributed/service/service.h" +#include "paddle/fluid/framework/archive.h" +#include "paddle/fluid/framework/io/fs.h" +#include "paddle/fluid/framework/io/shell.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/variable_helper.h" +#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN + +namespace paddle { +namespace distributed { + +using framework::LoDTensor; +using framework::Scope; +using framework::SelectedRows; +using framework::Variable; + +using RpcCtxMap = std::unordered_map; + +class FleetWrapper { + public: + virtual ~FleetWrapper() {} + FleetWrapper() { + scale_sparse_gradient_with_batch_size_ = true; + // trainer sleep some time for pserver core dump + sleep_seconds_before_fail_exit_ = 300; + // pserver request server timeout ms + client2client_request_timeout_ms_ = 500000; + // pserver connect server timeout_ms + client2client_connect_timeout_ms_ = 10000; + // pserver request max retry + client2client_max_retry_ = 3; + } + + // set client to client communication config + void SetClient2ClientConfig(int request_timeout_ms, int connect_timeout_ms, + int max_retry); + + // Pull sparse variables from server in sync mode + // Param: scope, table_id, var_names, fea_keys, fea_dim, var_emb_names + // Param: fea_values + void PullSparseVarsSync(const Scope& scope, const uint64_t table_id, + const std::vector& var_names, + std::vector* fea_keys, + std::vector>* fea_values, + int fea_dim, + const std::vector& var_emb_names); + + // Pull sparse variables from server in async mode + // Param: scope, table_id, var_names, fea_keys, fea_dim + // Param: fea_values std::future + std::future PullSparseVarsAsync( + const Scope& scope, const uint64_t table_id, + const std::vector& var_names, + std::vector* fea_keys, + std::vector>* fea_values, int fea_dim); + + // Pull sparse variables from server in sync mode + // pull immediately to tensors + void PullSparseToTensorSync(const uint64_t table_id, int fea_dim, + uint64_t padding_id, platform::Place place, + std::vector* inputs, // NOLINT + std::vector* outputs); // NOLINT + + // pull dense variables from server in sync mod + // Param: scope, table_id, var_names + // Param: void + void PullDenseVarsSync(const Scope& scope, const uint64_t table_id, + const std::vector& var_names); + + // pull dense variables from server in async mod + // Param: scope, table_id, var_names + // Param: pull_dense_status + void PullDenseVarsAsync(const Scope& scope, const uint64_t table_id, + const std::vector& var_names, + std::vector>* pull_dense_status, + bool in_cpu); + + // push dense parameters(not gradients) to server in sync mode + void PushDenseParamSync(const Scope& scope, const uint64_t table_id, + const std::vector& var_names); + + void PushDenseVarsAsync(const Scope& scope, const uint64_t table_id, + const std::vector& var_names, + std::vector>* push_sparse_status, + float scale_datanorm, int batch_size); + + // push dense variables to server in sync mode + void PushDenseVarsSync(Scope* scope, const uint64_t table_id, + const std::vector& var_names); + + void PushSparseVarsAsync( + const Scope& scope, const uint64_t table_id, const std::string& grad, + std::vector>* push_sparse_status); + // This is specially designed for click/show stats in server + // Param: scope, table_id, fea_keys, fea_labels, sparse_key_names, + // sparse_grad_names, batch_size, use_cvm, dump_slot + // Param: push_values, push_sparse_status + void PushSparseVarsWithLabelAsync( + const Scope& scope, const uint64_t table_id, + const std::vector& fea_keys, + const std::vector& fea_labels, + const std::vector& sparse_key_names, + const std::vector& sparse_grad_names, const int emb_dim, + std::vector>* push_values, + std::vector>* push_sparse_status, + const int batch_size, const bool use_cvm, const bool dump_slot, + std::vector* sparse_push_keys, const bool no_cvm); + + // Push sparse variables to server in async mode + void PushSparseFromTensorWithLabelAsync( + const Scope& scope, const uint64_t table_id, int fea_dim, + uint64_t padding_id, bool scale_sparse, const std::string& accesor, + const std::string& click_name, platform::Place place, + const std::vector& input_names, + std::vector* inputs, // NOLINT + std::vector* outputs); // NOLINT + + // Push sparse variables to server in Async mode + // Param: scope, table_id, fea_keys, sparse_grad_names + // Param: push_values, push_sparse_status + + // init server + void LoadSparseOnServer(const std::string& path, const std::string& meta, + uint32_t table_id); + // init server + // void InitServer(const std::string& dist_desc, + // const std::vector& host_sign_list, int index); + void InitServer(const std::string& dist_desc, + const std::vector& host_sign_list, int index); + // init trainer + void InitWorker(const std::string& dist_desc, + const std::vector& host_sign_list, Scope* scope, + const RpcCtxMap& send_ctx, + const std::unordered_map>& + dense_varnames, + const std::map& envs, int node_num, + int index); + + // stop server + void StopServer(); + // finalize worker to make worker can be stop + void FinalizeWorker(); + // run server with ip port + uint64_t RunServer(const std::string& ip, uint32_t port); + // get client info + std::vector GetClientsInfo(); + // create client to client connection + void CreateClient2ClientConnection(); + // flush all push requests + void ClientFlush(); + + // barrier with barrier table + void BarrierWithTable(uint32_t barrier_type); + + void PrintTableStat(const uint64_t table_id); + // mode = 0, load all feature + // mode = 1, load delta feature, which means load diff + void LoadModel(const std::string& path, const int mode); + // mode = 0, load all feature + // mode = 1, load delta feature, which means load diff + void LoadModelOneTable(const uint64_t table_id, const std::string& path, + const int mode); + // mode = 0, save all feature + // mode = 1, save delta feature, which means save diff + void SaveModel(const std::string& path, const int mode); + // mode = 0, save all feature + // mode = 1, save delta feature, which means save diff + void SaveModelOneTable(const uint64_t table_id, const std::string& path, + const int mode); + // clear all models, release their memory + void ClearModel(); + // clear one table + void ClearOneTable(const uint64_t table_id); + // shrink sparse table + void ShrinkSparseTable(int table_id); + // shrink dense table + void ShrinkDenseTable(int table_id, Scope* scope, + std::vector var_list, float decay, + int emb_dim); + + typedef std::function MsgHandlerFunc; + // register client to client communication + int RegisterClientToClientMsgHandler(int msg_type, MsgHandlerFunc handler); + // send client to client message + std::future SendClientToClientMsg(int msg_type, int to_client_id, + const std::string& msg); + + // FleetWrapper singleton + static std::shared_ptr GetInstance() { + if (NULL == s_instance_) { + s_instance_.reset(new paddle::distributed::FleetWrapper()); + } + return s_instance_; + } + // this performs better than rand_r, especially large data + std::default_random_engine& LocalRandomEngine(); + + static std::shared_ptr pserver_ptr_; + + private: + static std::shared_ptr s_instance_; + size_t GetAbsoluteSum(size_t start, size_t end, size_t level, + const framework::LoD& lod); + + protected: + static bool is_initialized_; + std::map> _regions; + bool scale_sparse_gradient_with_batch_size_; + int32_t sleep_seconds_before_fail_exit_; + int client2client_request_timeout_ms_; + int client2client_connect_timeout_ms_; + int client2client_max_retry_; + DISABLE_COPY_AND_ASSIGN(FleetWrapper); +}; + +} // end namespace distributed +} // end namespace paddle diff --git a/paddle/fluid/distributed/ps.proto b/paddle/fluid/distributed/ps.proto new file mode 100644 index 0000000000000000000000000000000000000000..383ff73690bfdbb35ad87fa91c0f511c7b1a3b85 --- /dev/null +++ b/paddle/fluid/distributed/ps.proto @@ -0,0 +1,152 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; +package paddle.distributed; +option cc_generic_services = true; +option cc_enable_arenas = true; + +message FsClientParameter { + enum FsApiType { + HDFS = 0; + AFS = 1; + } + optional FsApiType fs_type = 1 [ default = HDFS ]; + optional string uri = 2; // such as afs://xxx.afs.com:9902 + optional string user = 3; // user_name to access fs + optional string passwd = 4; // password + optional int32 buffer_size = 5; // buffer for read/write + optional string hadoop_bin = 51; + optional string afs_conf = 101; +} + +message PSParameter { + optional string worker_class = 1; + optional string server_class = 2; + optional string instance_class = 3; + optional string init_gflags = 4 [ default = "" ]; + optional WorkerParameter worker_param = 101; + optional ServerParameter server_param = 102; + repeated DownpourTrainerParameter trainer_param = 301; + optional FsClientParameter fs_client_param = 501; +} + +message WorkerParameter { + optional DownpourWorkerParameter downpour_worker_param = 1; +} + +message DownpourWorkerParameter { + repeated TableParameter downpour_table_param = 1; +} + +message DownpourServerParameter { + repeated TableParameter downpour_table_param = 1; + optional ServerServiceParameter service_param = 2; +} + +message ServerParameter { + optional DownpourServerParameter downpour_server_param = 1; +} + +message DownpourTrainerParameter { + repeated DenseTableParameter dense_table = 1; + repeated SparseTableParameter sparse_table = 2; + optional int32 push_sparse_per_batch = 3; + optional int32 push_dense_per_batch = 4; + repeated string skip_op = 5; + repeated ProgramConfig program_config = 6; +} + +message DenseTableParameter { + optional int32 table_id = 1; + repeated string dense_variable_name = 2; + repeated string dense_gradient_variable_name = 3; + optional int32 fea_dim = 4; +} + +message SparseTableParameter { + optional int32 table_id = 1; + optional int32 feature_dim = 2; + repeated string slot_key = 3; + repeated string slot_value = 4; + repeated string slot_gradient = 5; +} + +message ServerServiceParameter { + optional string server_class = 1 [ default = "BrpcPsServer" ]; + optional string client_class = 2 [ default = "BrpcPsClient" ]; + optional string service_class = 3 [ default = "PsService" ]; + optional uint32 start_server_port = 4 + [ default = 0 ]; // will find a avaliable port from it + optional uint32 server_thread_num = 5 [ default = 12 ]; +} + +message ProgramConfig { + required string program_id = 1; + repeated int32 push_sparse_table_id = 2; + repeated int32 push_dense_table_id = 3; + repeated int32 pull_sparse_table_id = 4; + repeated int32 pull_dense_table_id = 5; +} + +enum TableType { + PS_SPARSE_TABLE = 0; + PS_DENSE_TABLE = 1; + PS_OTHER_TABLE = 2; +} + +message TableParameter { + optional uint64 table_id = 1; + optional string table_class = 2; + optional uint64 shard_num = 3 [ default = 1000 ]; + optional TableAccessorParameter accessor = 4; + optional TensorAccessorParameter tensor = 5; + optional CommonAccessorParameter common = 6; + optional TableType type = 7; + optional bool compress_in_save = 8 [ default = false ]; +} + +message TableAccessorParameter { + optional string accessor_class = 1; + optional uint32 fea_dim = 4 [ default = 11 ]; + optional uint32 embedx_dim = 5 [ default = 8 ]; + optional uint32 embedx_threshold = 6 [ default = 10 ]; + repeated TableAccessorSaveParameter table_accessor_save_param = 8; +} + +message TensorAccessorParameter { + optional string tensor_class = 1; + optional uint32 fea_dim = 2; + optional uint32 emb_dim = 3; + optional string param = 4; + optional string grad = 5; + optional string common_block_map = 6; +} + +message CommonAccessorParameter { + optional string name = 1; + optional string table_name = 2; + repeated string attributes = 3; + repeated string params = 4; + repeated uint32 dims = 5; + repeated string initializers = 6; + optional int32 trainer_num = 7; + optional bool sync = 8; +} + +message TableAccessorSaveParameter { + optional uint32 param = 1; + optional string converter = 2; + optional string deconverter = 3; +} diff --git a/paddle/fluid/distributed/service/CMakeLists.txt b/paddle/fluid/distributed/service/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..0c767ad2b3fa6b0462a379795005c2bddf377516 --- /dev/null +++ b/paddle/fluid/distributed/service/CMakeLists.txt @@ -0,0 +1,40 @@ +set(BRPC_SRCS ps_client.cc server.cc) +set_source_files_properties(${BRPC_SRCS}) + +set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog) + +brpc_library(sendrecv_rpc SRCS + ${BRPC_SRCS} + PROTO sendrecv.proto + DEPS ${BRPC_DEPS} ) + +set_property(GLOBAL PROPERTY RPC_DEPS sendrecv_rpc ${BRPC_DEPS} string_helper) + +get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS) + +set_source_files_properties(communicator.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +set_source_files_properties(service.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +set_source_files_properties(brpc_ps_server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +set_source_files_properties(brpc_ps_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + +set_source_files_properties(brpc_utils.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +set_source_files_properties(heter_server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +set_source_files_properties(heter_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + +set_source_files_properties(client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +set_source_files_properties(ps_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +set_source_files_properties(server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + + +cc_library(downpour_server SRCS brpc_ps_server.cc DEPS boost eigen3 table ${RPC_DEPS}) +cc_library(downpour_client SRCS brpc_ps_client.cc DEPS boost eigen3 table ${RPC_DEPS}) + +cc_library(client SRCS ps_client.cc DEPS downpour_client boost ${RPC_DEPS}) +cc_library(server SRCS server.cc DEPS downpour_server boost ${RPC_DEPS}) + +cc_library(communicator SRCS communicator.cc DEPS scope client boost table math_function selected_rows_functor ${RPC_DEPS}) +cc_library(ps_service SRCS service.cc DEPS communicator client server boost ${RPC_DEPS}) + +cc_library(brpc_utils SRCS brpc_utils.cc DEPS ${COMMON_DEPS} ${RPC_DEPS}) +cc_library(heter_server SRCS heter_server.cc DEPS brpc_utils ${COMMON_DEPS} ${RPC_DEPS}) +cc_library(heter_client SRCS heter_client.cc DEPS brpc_utils ${COMMON_DEPS} ${RPC_DEPS}) diff --git a/paddle/fluid/distributed/service/brpc_ps_client.cc b/paddle/fluid/distributed/service/brpc_ps_client.cc new file mode 100644 index 0000000000000000000000000000000000000000..bc9d017532dff0bb5d17fd65fd231aadeccefcb8 --- /dev/null +++ b/paddle/fluid/distributed/service/brpc_ps_client.cc @@ -0,0 +1,879 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "Eigen/Dense" +#include "paddle/fluid/distributed/service/brpc_ps_client.h" +#include "paddle/fluid/distributed/table/table.h" +#include "paddle/fluid/framework/archive.h" + +const static int max_port = 65535; + +DEFINE_int32(pserver_push_dense_merge_limit, 12, + "limit max push_dense local merge requests"); + +DEFINE_int32(pserver_push_sparse_merge_limit, 12, + "limit max push_sparse local merge requests"); + +DEFINE_int32(pserver_pull_dense_limit, 12, + "limit max push_sparse local merge requests"); + +DEFINE_int32(pserver_async_push_dense_interval_ms, 10, + "async push_dense to server interval"); + +DEFINE_int32(pserver_async_push_sparse_interval_ms, 10, + "async push_sparse to server interval"); + +DEFINE_bool(pserver_scale_gradient_by_merge, false, + "scale dense gradient when merged"); + +DEFINE_int32(pserver_communicate_compress_type, 0, + "none:0 snappy:1 gzip:2 zlib:3 lz4:4"); + +DEFINE_int32(pserver_max_async_call_num, 13, + "max task num in async_call_server"); + +DEFINE_int32(pserver_timeout_ms, 500000, "pserver request server timeout_ms"); + +DEFINE_int32(pserver_connect_timeout_ms, 10000, + "pserver connect server timeout_ms"); + +DEFINE_int32(pserver_sparse_merge_thread, 1, "pserver sparse merge thread num"); + +namespace paddle { +namespace distributed { + +inline size_t get_sparse_shard(uint32_t shard_num, uint32_t server_num, + uint64_t key) { + size_t remind = shard_num % server_num; + size_t local_shard_num = + remind == 0 ? shard_num / server_num : shard_num / server_num + 1; + return (key % shard_num) / local_shard_num; +} + +void DownpourPsClientService::service( + ::google::protobuf::RpcController *controller, + const ::paddle::PsRequestMessage *request, + ::paddle::PsResponseMessage *response, ::google::protobuf::Closure *done) { + brpc::ClosureGuard done_guard(done); + int ret = _client->handle_client2client_msg( + request->cmd_id(), request->client_id(), request->data()); + response->set_err_code(0); + response->set_err_msg(""); + if (ret != 0) { + response->set_err_code(-1); + response->set_err_msg("handle_client2client_msg failed"); + } +} + +// 启动client端RpcService 用于数据互发等操作 +int32_t BrpcPsClient::start_client_service() { + if (_service.configure(this, _client_id) != 0) { + LOG(ERROR) + << "service initialize failed, service_name:DownpourPsClientService"; + return -1; + } + _server.AddService(&_service, brpc::SERVER_DOESNT_OWN_SERVICE); + brpc::ServerOptions options; + int start_port = 8500; + options.num_threads = 24; + + if (_server.Start(butil::my_ip_cstr(), brpc::PortRange(start_port, max_port), + &options) != 0) { + LOG(ERROR) << "BrpcPsServer start failed"; + return -1; + } + _env->registe_ps_client(butil::my_ip_cstr(), _server.listen_address().port, + _client_id); + return 0; +} + +int32_t BrpcPsClient::create_client2client_connection( + int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry) { + brpc::ChannelOptions options; + options.protocol = "baidu_std"; + options.timeout_ms = pserver_timeout_ms; + options.connection_type = "pooled"; + options.connect_timeout_ms = pserver_connect_timeout_ms; + options.max_retry = max_retry; + + std::vector client_list = _env->get_ps_clients(); + _client_channels.resize(client_list.size()); + std::ostringstream os; + std::string server_ip_port; + for (size_t i = 0; i < client_list.size(); ++i) { + server_ip_port.assign(client_list[i].ip.c_str()); + server_ip_port.append(":"); + server_ip_port.append(std::to_string(client_list[i].port)); + _client_channels[i].reset(new brpc::Channel()); + if (_client_channels[i]->Init(server_ip_port.c_str(), "", &options) != 0) { + LOG(ERROR) << "psclient connect to client:" << server_ip_port + << " Failed!"; + } + os << server_ip_port << ","; + } + LOG(INFO) << "Client connect success:" << os.str(); + return 0; +} + +int32_t BrpcPsClient::initialize() { + _async_call_num = 0; + + brpc::ChannelOptions options; + options.protocol = "baidu_std"; + options.timeout_ms = FLAGS_pserver_timeout_ms; + options.connection_type = "pooled"; + options.connect_timeout_ms = FLAGS_pserver_connect_timeout_ms; + options.max_retry = 3; + + std::ostringstream os; + std::string server_ip_port; + std::string client_ip(butil::my_ip_cstr()); + + // 获取server列表,并连接 + std::vector server_list = _env->get_ps_servers(); + _server_channels.resize(server_list.size()); + for (size_t i = 0; i < server_list.size(); ++i) { + server_ip_port.assign(server_list[i].ip.c_str()); + server_ip_port.append(":"); + server_ip_port.append(std::to_string(server_list[i].port)); + for (size_t j = 0; j < _server_channels[i].size(); ++j) { + _server_channels[i][j].reset(new brpc::Channel()); + if (_server_channels[i][j]->Init(server_ip_port.c_str(), "", &options) != + 0) { + LOG(ERROR) << "psclient connect to server:" << server_ip_port + << " Failed!"; + return -1; + } + } + os << server_ip_port << ","; + } + // 启动client探听接口, 并相互建立连接 + start_client_service(); + + _running = true; + _flushing = false; + return 0; +} + +int DownpourBrpcClosure::check_response(size_t request_idx, int cmd_id) { + if (_cntls[request_idx]->Failed()) { + LOG(ERROR) << "resquest cmd_id:" << cmd_id << " failed, " + "err:" + << _cntls[request_idx]->ErrorText(); + return -1; + } + if (_responses[request_idx].err_code() != 0) { + LOG(ERROR) << "response ret bad, server_idx:" << request_idx + << "cmd_id:" << cmd_id + << " err_code:" << _responses[request_idx].err_code() + << " err_msg:" << _responses[request_idx].err_msg(); + return -1; + } + return 0; +} + +int DownpourBrpcClosure::check_save_response(size_t request_idx, int cmd_id) { + uint32_t feasign_size = 0; + if (_cntls[request_idx]->Failed()) { + LOG(ERROR) << "resquest cmd_id:" << cmd_id << " failed, " + "err:" + << _cntls[request_idx]->ErrorText(); + return -1; + } + feasign_size = _responses[request_idx].err_code(); + if (feasign_size < 0) { + LOG(ERROR) << "response ret bad, server_idx:" << request_idx + << "cmd_id:" << cmd_id + << " err_code:" << _responses[request_idx].err_code() + << " err_msg:" << _responses[request_idx].err_msg(); + return -1; + } + return feasign_size; +} + +std::string DownpourBrpcClosure::get_response(size_t request_idx, int cmd_id) { + std::string data = _responses[request_idx].data(); + return data; +} + +std::future BrpcPsClient::print_table_stat(uint32_t table_id) { + size_t request_call_num = _server_channels.size(); + DownpourBrpcClosure *closure = new DownpourBrpcClosure( + request_call_num, [request_call_num, table_id](void *done) { + int ret = 0; + uint64_t feasign_size = 0; + uint64_t mf_size = 0; + paddle::framework::BinaryArchive ar; + auto *closure = (DownpourBrpcClosure *)done; + for (size_t i = 0; i < request_call_num; ++i) { + if (closure->check_response(i, PS_PRINT_TABLE_STAT) != 0) { + ret = -1; + break; + } + std::string resp = closure->get_response(i, PS_PRINT_TABLE_STAT); + ar.SetReadBuffer(const_cast(resp.c_str()), resp.length(), + nullptr); + + feasign_size += ar.Get(); + mf_size += ar.Get(); + } + closure->set_promise_value(ret); + std::cout << "table id: " << table_id + << ", feasign size: " << feasign_size + << ", mf size: " << mf_size << std::endl; + }); + auto promise = std::make_shared>(); + closure->add_promise(promise); + std::future fut = promise->get_future(); + for (size_t i = 0; i < request_call_num; ++i) { + closure->request(i)->set_cmd_id(PS_PRINT_TABLE_STAT); + closure->request(i)->set_table_id(table_id); + closure->request(i)->set_client_id(_client_id); + PsService_Stub rpc_stub(get_cmd_channel(i)); + closure->cntl(i)->set_timeout_ms( + 10800000); // cmd msg don't limit timeout for save/load + rpc_stub.service(closure->cntl(i), closure->request(i), + closure->response(i), closure); + } + return fut; +} +std::future BrpcPsClient::send_cmd( + uint32_t table_id, int cmd_id, const std::vector ¶ms) { + size_t request_call_num = _server_channels.size(); + DownpourBrpcClosure *closure = new DownpourBrpcClosure( + request_call_num, [request_call_num, cmd_id](void *done) { + int ret = 0; + auto *closure = (DownpourBrpcClosure *)done; + for (size_t i = 0; i < request_call_num; ++i) { + if (closure->check_response(i, cmd_id) != 0) { + ret = -1; + break; + } + } + closure->set_promise_value(ret); + }); + auto promise = std::make_shared>(); + closure->add_promise(promise); + std::future fut = promise->get_future(); + for (size_t i = 0; i < request_call_num; ++i) { + closure->request(i)->set_cmd_id(cmd_id); + closure->request(i)->set_table_id(table_id); + closure->request(i)->set_client_id(_client_id); + for (const auto ¶m : params) { + closure->request(i)->add_params(param); + } + PsService_Stub rpc_stub(get_cmd_channel(i)); + closure->cntl(i)->set_timeout_ms( + 10800000); // cmd msg don't limit timeout for save/load + rpc_stub.service(closure->cntl(i), closure->request(i), + closure->response(i), closure); + } + return fut; +} + +std::future BrpcPsClient::send_save_cmd( + uint32_t table_id, int cmd_id, const std::vector ¶ms) { + size_t request_call_num = _server_channels.size(); + DownpourBrpcClosure *closure = new DownpourBrpcClosure( + request_call_num, [request_call_num, cmd_id](void *done) { + int ret = 0; + uint32_t feasign_size = 0; + auto *closure = (DownpourBrpcClosure *)done; + for (size_t i = 0; i < request_call_num; ++i) { + if (closure->check_save_response(i, cmd_id) < 0) { + ret = -1; + break; + } + feasign_size += closure->check_save_response(i, cmd_id); + } + if (ret == 0) { + closure->set_promise_value(feasign_size); + } else { + closure->set_promise_value(ret); + } + }); + auto promise = std::make_shared>(); + closure->add_promise(promise); + std::future fut = promise->get_future(); + for (size_t i = 0; i < request_call_num; ++i) { + closure->request(i)->set_cmd_id(cmd_id); + closure->request(i)->set_table_id(table_id); + closure->request(i)->set_client_id(_client_id); + for (const auto ¶m : params) { + closure->request(i)->add_params(param); + } + PsService_Stub rpc_stub(get_cmd_channel(i)); + closure->cntl(i)->set_timeout_ms( + 10800000); // cmd msg don't limit timeout for save/load + rpc_stub.service(closure->cntl(i), closure->request(i), + closure->response(i), closure); + } + return fut; +} + +std::future BrpcPsClient::shrink(uint32_t table_id) { + return send_cmd(table_id, PS_SHRINK_TABLE, {std::string("1")}); +} + +std::future BrpcPsClient::load(const std::string &epoch, + const std::string &mode) { + return send_cmd(-1, PS_LOAD_ALL_TABLE, {epoch, mode}); +} +std::future BrpcPsClient::load(uint32_t table_id, + const std::string &epoch, + const std::string &mode) { + return send_cmd(table_id, PS_LOAD_ONE_TABLE, {epoch, mode}); +} + +std::future BrpcPsClient::save(const std::string &epoch, + const std::string &mode) { + return send_save_cmd(-1, PS_SAVE_ALL_TABLE, {epoch, mode}); +} +std::future BrpcPsClient::save(uint32_t table_id, + const std::string &epoch, + const std::string &mode) { + return send_save_cmd(table_id, PS_SAVE_ONE_TABLE, {epoch, mode}); +} + +std::future BrpcPsClient::clear() { + return send_cmd(-1, PS_CLEAR_ALL_TABLE, {}); +} +std::future BrpcPsClient::clear(uint32_t table_id) { + return send_cmd(table_id, PS_CLEAR_ONE_TABLE, {}); +} + +std::future BrpcPsClient::flush() { + _flushing = true; + std::promise promise; + std::future fut = promise.get_future(); + do { + VLOG(3) << "wait _async_call_num:" << _async_call_num; + usleep(100000); // sleep 100ms wait async end + } while (_async_call_num > 0); + promise.set_value(0); + _flushing = false; + return fut; +} + +void BrpcPsClient::finalize_worker() { + flush(); + _running = false; + _server.Stop(1000); + _server.Join(); +} + +std::future BrpcPsClient::stop_server() { + return send_cmd(-1, PS_STOP_SERVER, {}); +} + +std::future BrpcPsClient::start_profiler() { + return send_cmd(-1, PS_START_PROFILER, {}); +} + +std::future BrpcPsClient::stop_profiler() { + return send_cmd(-1, PS_STOP_PROFILER, {}); +} + +std::future BrpcPsClient::barrier(size_t table_id, + uint32_t barrier_type) { + return send_cmd(table_id, PS_BARRIER, {std::to_string(barrier_type)}); +} + +std::future BrpcPsClient::pull_geo_param(size_t table_id, + std::vector *values, + std::vector *keys, + int pserver_idx) { + auto *accessor = table_accessor(table_id); + DownpourBrpcClosure *closure = + new DownpourBrpcClosure(1, [keys, values, accessor](void *done) { + int ret = 0; + auto *closure = (DownpourBrpcClosure *)done; + uint32_t shard_nums; + if (closure->check_response(0, PS_PULL_GEO_PARAM) != 0) { + ret = -1; + } + auto &res_io_buffer = closure->cntl(0)->response_attachment(); + butil::IOBufBytesIterator io_buffer_itr(res_io_buffer); + io_buffer_itr.copy_and_forward((void *)(&shard_nums), sizeof(uint32_t)); + keys->resize(shard_nums); + values->resize(shard_nums * accessor->update_dim()); + io_buffer_itr.copy_and_forward((void *)(keys->data()), + sizeof(uint64_t) * shard_nums); + io_buffer_itr.copy_and_forward((void *)(values->data()), + shard_nums * accessor->update_size()); + closure->set_promise_value(ret); + }); + auto promise = std::make_shared>(); + closure->add_promise(promise); + std::future fut = promise->get_future(); + + closure->request(0)->set_cmd_id(PS_PULL_GEO_PARAM); + closure->request(0)->set_table_id(table_id); + closure->request(0)->set_client_id(_client_id); + PsService_Stub rpc_stub(get_cmd_channel(pserver_idx)); + closure->cntl(0)->set_log_id(butil::gettimeofday_ms()); + rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0), + closure); + return fut; +} + +std::future BrpcPsClient::push_sparse_param( + size_t table_id, const uint64_t *keys, const float **update_values, + size_t num, void *done) { + auto *accessor = table_accessor(table_id); + // 发送RPC请求 + DownpourBrpcClosure *closure = reinterpret_cast(done); + auto promise = std::make_shared>(); + closure->add_promise(promise); + std::future fut = promise->get_future(); + size_t request_call_num = _server_channels.size(); + std::vector> ids; + std::vector> value_ptrs; + ids.resize(request_call_num); + value_ptrs.resize(request_call_num); + for (size_t i = 0; i < num; ++i) { + size_t pserver_idx = keys[i] % request_call_num; + ids[pserver_idx].push_back(keys[i]); + value_ptrs[pserver_idx].push_back(update_values[i]); + } + for (size_t shard_idx = 0; shard_idx < request_call_num; ++shard_idx) { + auto kvs = ids[shard_idx]; + auto value_ptr = value_ptrs[shard_idx]; + size_t kv_size = kvs.size(); + uint32_t value_size = accessor->update_size(); + // 发送RPC请求 + auto *push_request = closure->request(shard_idx); + push_request->set_cmd_id(PS_PUSH_SPARSE_PARAM); + push_request->set_table_id(table_id); + push_request->set_client_id(_client_id); + push_request->add_params((char *)&kv_size, sizeof(uint32_t)); + auto *push_data = push_request->mutable_data(); + push_data->resize(kv_size * (sizeof(uint64_t) + accessor->update_size())); + char *push_data_ptr = const_cast(push_data->data()); + memcpy(push_data_ptr, kvs.data(), kv_size * sizeof(uint64_t)); + push_data_ptr += kv_size * sizeof(uint64_t); + for (int i = 0; i < kv_size; ++i) { + memcpy(push_data_ptr, value_ptr[i], accessor->update_size()); + push_data_ptr += accessor->update_size(); + } + PsService_Stub rpc_stub(get_sparse_channel(shard_idx)); + closure->cntl(shard_idx)->set_request_compress_type( + (brpc::CompressType)FLAGS_pserver_communicate_compress_type); + rpc_stub.service(closure->cntl(shard_idx), closure->request(shard_idx), + closure->response(shard_idx), closure); + } + return fut; +} + +std::future BrpcPsClient::pull_dense(Region *regions, + size_t region_num, + size_t table_id) { + auto *accessor = table_accessor(table_id); + size_t request_call_num = _server_channels.size(); + uint32_t num_per_shard = + dense_dim_per_shard(accessor->fea_dim(), request_call_num); + // callback 将各shard结果,顺序填入region + DownpourBrpcClosure *closure = new DownpourBrpcClosure( + request_call_num, [request_call_num, num_per_shard, regions, region_num, + accessor](void *done) { + int ret = 0; + size_t region_idx = 0; // 当前填充的region偏移 + size_t region_data_idx = 0; // 当前填充的region内data偏移 + auto *closure = (DownpourBrpcClosure *)done; + size_t shard_data_size = num_per_shard * accessor->select_size(); + for (size_t i = 0; i < request_call_num; ++i) { + if (closure->check_response(i, PS_PULL_DENSE_TABLE) != 0) { + ret = -1; + break; + } + auto &res_io_buffer = closure->cntl(i)->response_attachment(); + + butil::IOBufBytesIterator io_buffer_itr(res_io_buffer); + size_t shard_buffer_remain = res_io_buffer.size(); + if (shard_buffer_remain != shard_data_size) { + LOG(ERROR) << "expect res_size:" << shard_data_size + << ", but size:" << shard_buffer_remain + << ", ignore this response"; + ret = -1; + break; + } + while (shard_buffer_remain > 0 && region_idx < region_num) { + auto ®ion = regions[region_idx]; + if (region.size - region_data_idx >= shard_buffer_remain) { + // region待填充空间 >= 分片buffer数据, 直接拷贝置入 + io_buffer_itr.copy_and_forward( + (void *)(region.data + region_data_idx), shard_buffer_remain); + region_data_idx += shard_buffer_remain; + shard_buffer_remain = 0; + } else if (region.size - region_data_idx == 0) { + // region填满,切换到下一个region + ++region_idx; + region_data_idx = 0; + } else { + // region不足以容纳所有数据,则能放多少 拷贝多少 + io_buffer_itr.copy_and_forward( + (void *)(region.data + region_data_idx), + region.size - region_data_idx); + shard_buffer_remain -= (region.size - region_data_idx); + ++region_idx; + region_data_idx = 0; + } + } + } + closure->set_promise_value(ret); + }); + auto promise = std::make_shared>(); + closure->add_promise(promise); + std::future fut = promise->get_future(); + for (size_t i = 0; i < request_call_num; ++i) { + closure->request(i)->set_cmd_id(PS_PULL_DENSE_TABLE); + closure->request(i)->set_table_id(table_id); + closure->request(i)->set_client_id(_client_id); + closure->request(i)->add_params((char *)&num_per_shard, + sizeof(num_per_shard)); + PsService_Stub rpc_stub(get_dense_channel(i)); + rpc_stub.service(closure->cntl(i), closure->request(i), + closure->response(i), closure); + } + return fut; +} + +std::future BrpcPsClient::push_dense_param(const Region *regions, + size_t region_num, + size_t table_id) { + auto *accessor = table_accessor(table_id); + size_t request_call_num = _server_channels.size(); + // 1.拆分Region数据到shard中,后续多shard并行拷贝数据 + std::vector> regions_partition(request_call_num); + uint32_t num_per_shard = + dense_dim_per_shard(accessor->fea_dim(), request_call_num); + size_t shard_data_size = num_per_shard * accessor->update_size(); + size_t current_region_idx = 0; + size_t current_region_data_idx = 0; + for (size_t i = 0; i < request_call_num; ++i) { + size_t shard_data_remain_size = shard_data_size; + while (shard_data_remain_size > 0 && current_region_idx < region_num) { + const auto ®ion = regions[current_region_idx]; + size_t region_remain_size = region.size - current_region_data_idx; + if (shard_data_remain_size >= region_remain_size) { + regions_partition[i].push_back( + Region(region.data + current_region_data_idx, region_remain_size)); + ++current_region_idx; + current_region_data_idx = 0; + shard_data_remain_size -= region_remain_size; + } else { + regions_partition[i].push_back(Region( + region.data + current_region_data_idx, shard_data_remain_size)); + current_region_data_idx += shard_data_remain_size; + shard_data_remain_size = 0; + } + } + } + + DownpourBrpcClosure *closure = + new DownpourBrpcClosure(request_call_num, [request_call_num](void *done) { + int ret = 0; + auto *closure = (DownpourBrpcClosure *)done; + for (size_t i = 0; i < request_call_num; ++i) { + if (closure->check_response(i, PS_PUSH_DENSE_PARAM) != 0) { + ret = -1; + break; + } + } + closure->set_promise_value(ret); + }); + auto promise = std::make_shared>(); + closure->add_promise(promise); + std::future fut = promise->get_future(); + static const int REGION_ASSIGN_BUFFER_SIZE = 1024 * 10; + static char region_assign_buffer[REGION_ASSIGN_BUFFER_SIZE]; //用于数据补齐 + //开始多shard并行拷贝&请求 + for (size_t i = 0; i < request_call_num; ++i) { + closure->request(i)->set_cmd_id(PS_PUSH_DENSE_PARAM); + closure->request(i)->set_table_id(table_id); + closure->request(i)->set_client_id(_client_id); + auto &request_buffer = closure->cntl(i)->request_attachment(); + request_buffer.append((void *)&num_per_shard, sizeof(uint32_t)); + auto ®ion_list = regions_partition[i]; + size_t fill_remain_size = shard_data_size; + for (auto ®ion : region_list) { + fill_remain_size -= region.size; + request_buffer.append((void *)region.data, region.size); + } + //保证各分片数据对齐 + while (fill_remain_size > 0) { + size_t fill_num = fill_remain_size > REGION_ASSIGN_BUFFER_SIZE + ? REGION_ASSIGN_BUFFER_SIZE + : fill_remain_size; + request_buffer.append((void *)region_assign_buffer, fill_num); + fill_remain_size -= fill_num; + } + PsService_Stub rpc_stub(get_dense_channel(i)); + rpc_stub.service(closure->cntl(i), closure->request(i), + closure->response(i), closure); + } + return fut; +} + +std::future BrpcPsClient::push_sparse_raw_gradient( + size_t table_id, const uint64_t *keys, const float **update_values, + size_t num, void *done) { + auto *accessor = table_accessor(table_id); + //发送RPC请求 + DownpourBrpcClosure *closure = reinterpret_cast(done); + auto promise = std::make_shared>(); + closure->add_promise(promise); + std::future fut = promise->get_future(); + + size_t request_call_num = _server_channels.size(); + std::vector> ids; + std::vector> value_ptrs; + ids.resize(request_call_num); + value_ptrs.resize(request_call_num); + + for (size_t i = 0; i < num; ++i) { + size_t pserver_idx = keys[i] % request_call_num; + ids[pserver_idx].push_back(keys[i]); + value_ptrs[pserver_idx].push_back(update_values[i]); + } + + for (size_t shard_idx = 0; shard_idx < request_call_num; ++shard_idx) { + auto kvs = ids[shard_idx]; + auto value_ptr = value_ptrs[shard_idx]; + + size_t kv_size = kvs.size(); + uint32_t value_size = accessor->update_size(); + + // 发送RPC请求 + auto *push_request = closure->request(shard_idx); + push_request->set_cmd_id(PS_PUSH_SPARSE_TABLE); + push_request->set_table_id(table_id); + push_request->set_client_id(_client_id); + push_request->add_params((char *)&kv_size, sizeof(uint32_t)); + auto *push_data = push_request->mutable_data(); + push_data->resize(kv_size * (sizeof(uint64_t) + accessor->update_size())); + char *push_data_ptr = const_cast(push_data->data()); + memcpy(push_data_ptr, kvs.data(), kv_size * sizeof(uint64_t)); + push_data_ptr += kv_size * sizeof(uint64_t); + + for (int i = 0; i < kv_size; ++i) { + memcpy(push_data_ptr, value_ptr[i], accessor->update_size()); + push_data_ptr += accessor->update_size(); + } + PsService_Stub rpc_stub(get_sparse_channel(shard_idx)); + closure->cntl(shard_idx)->set_request_compress_type( + (brpc::CompressType)FLAGS_pserver_communicate_compress_type); + rpc_stub.service(closure->cntl(shard_idx), closure->request(shard_idx), + closure->response(shard_idx), closure); + } + return fut; +} + +std::future BrpcPsClient::push_dense_raw_gradient( + int table_id, float *total_send_data, size_t total_send_data_size, + void *done) { + size_t request_call_num = _server_channels.size(); + DownpourBrpcClosure *closure = reinterpret_cast(done); + auto promise = std::make_shared>(); + closure->add_promise(promise); + std::future fut = promise->get_future(); + auto *accessor = table_accessor(table_id); + uint32_t num_per_shard = + dense_dim_per_shard(accessor->fea_dim(), request_call_num); + for (size_t i = 0; i < request_call_num; ++i) { + closure->request(i)->set_cmd_id(PS_PUSH_DENSE_TABLE); + closure->request(i)->set_table_id(table_id); + closure->request(i)->set_client_id(_client_id); + auto *push_data = closure->request(i)->mutable_data(); + push_data->clear(); + push_data->resize(sizeof(uint32_t) + num_per_shard * sizeof(float)); + char *push_data_ptr = const_cast(push_data->data()); + memcpy(push_data_ptr, &num_per_shard, sizeof(uint32_t)); + memcpy(push_data_ptr + sizeof(uint32_t), + total_send_data + i * num_per_shard, num_per_shard * sizeof(float)); + VLOG(1) << "push_dense_raw_gradient finish memcpy"; + // closure->cntl(i)->set_request_compress_type( + // (brpc::CompressType)FLAGS_pserver_communicate_compress_type); + PsService_Stub rpc_stub(get_dense_channel(i)); + VLOG(1) << "push_dense_raw_gradient get_dense_channel " << i; + rpc_stub.service(closure->cntl(i), closure->request(i), + closure->response(i), closure); + VLOG(1) << "push_dense_raw_gradient async service " << i; + } + return fut; +} + +std::future BrpcPsClient::pull_sparse(float **select_values, + size_t table_id, + const uint64_t *keys, + size_t num) { + size_t request_call_num = _server_channels.size(); + + auto shard_sorted_kvs = std::make_shared< + std::vector>>>(); + shard_sorted_kvs->resize(request_call_num); + + for (size_t i = 0; i < num; ++i) { + size_t shard_id = keys[i] % request_call_num; + shard_sorted_kvs->at(shard_id).push_back({keys[i], select_values[i]}); + } + + auto *accessor = table_accessor(table_id); + size_t value_size = accessor->select_size(); + + DownpourBrpcClosure *closure = new DownpourBrpcClosure( + request_call_num, [shard_sorted_kvs, value_size](void *done) { + int ret = 0; + auto *closure = (DownpourBrpcClosure *)done; + for (size_t i = 0; i < ids.size(); ++i) { + if (closure->check_response(i, PS_PULL_SPARSE_TABLE) != 0) { + ret = -1; + break; + } + + auto &request_kvs = shard_sorted_kvs->at(i); + auto &res_io_buffer = closure->cntl(i)->response_attachment(); + butil::IOBufBytesIterator io_buffer_itr(res_io_buffer); + uint64_t last_key = UINT64_MAX; + float *last_value_data = NULL; + + for (size_t kv_idx = 0; kv_idx < request_kvs.size(); ++kv_idx) { + auto *kv_pair = &(request_kvs[kv_idx]); + if (kv_pair->first == last_key) { + memcpy((void *)kv_pair->second, (void *)last_value_data, + value_size); + } else { + last_key = kv_pair->first; + last_value_data = kv_pair->second; + if (value_size != + io_buffer_itr.copy_and_forward((void *)(last_value_data), + value_size)) { + LOG(WARNING) << "res data is lack or not in format"; + ret = -1; + break; + } + } + } + } + closure->set_promise_value(ret); + }); + + auto promise = std::make_shared>(); + closure->add_promise(promise); + std::future fut = promise->get_future(); + + for (size_t i = 0; i < request_call_num; ++i) { + auto &sorted_kvs = shard_sorted_kvs->at(i); + std::sort(sorted_kvs.begin(), sorted_kvs.end(), + [](const std::pair &k1, + const std::pair &k2) { + return k1.first < k2.first; + }); + + uint64_t last_key = UINT64_MAX; + uint32_t kv_request_count = 0; + size_t sorted_kv_size = sorted_kvs.size(); + auto &request_buffer = closure->cntl(i)->request_attachment(); + for (size_t kv_idx = 0; kv_idx < sorted_kv_size; ++kv_idx) { + ++kv_request_count; + last_key = sorted_kvs[kv_idx].first; + request_buffer.append((void *)&last_key, sizeof(uint64_t)); + while (kv_idx < sorted_kv_size - 1 && + last_key == sorted_kvs[kv_idx + 1].first) { + ++kv_idx; + } + } + + if (kv_request_count == 0) { + closure->Run(); + } else { + closure->request(i)->set_cmd_id(PS_PULL_SPARSE_TABLE); + closure->request(i)->set_table_id(table_id); + closure->request(i)->set_client_id(_client_id); + closure->request(i)->add_params((char *)&kv_request_count, + sizeof(uint32_t)); + PsService_Stub rpc_stub(get_cmd_channel(i)); + closure->cntl(i)->set_log_id(butil::gettimeofday_ms()); + rpc_stub.service(closure->cntl(i), closure->request(i), + closure->response(i), closure); + } + } + return fut; +} + +std::future BrpcPsClient::send_client2client_msg( + int msg_type, int to_client_id, const std::string &msg) { + auto promise = std::make_shared>(); + std::future fut = promise->get_future(); + if (to_client_id >= _client_channels.size()) { + LOG(FATAL) << "to_client_id is out of range clients, which size is " + << _client_channels.size(); + promise->set_value(-1); + return fut; + } + auto *closure = new DownpourBrpcClosure(1, [msg_type](void *done) { + auto *closure = (DownpourBrpcClosure *)done; + int32_t ret = closure->check_response(0, msg_type + 1000); + closure->set_promise_value(ret); + }); + closure->add_promise(promise); + closure->request(0)->set_cmd_id(msg_type); + closure->request(0)->set_client_id(_client_id); + closure->request(0)->set_data(msg); + PsService_Stub rpc_stub(_client_channels[to_client_id].get()); + rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0), + closure); + return fut; +} + +std::future BrpcPsClient::push_sparse_raw_gradient_partial( + size_t table_id, const uint64_t *keys, const float **update_values, + uint32_t num, void *done, int pserver_idx) { + auto *accessor = table_accessor(table_id); + size_t value_size = accessor->update_size(); + DownpourBrpcClosure *closure = reinterpret_cast(done); + auto promise = std::make_shared>(); + closure->add_promise(promise); + std::future fut = promise->get_future(); + + // 发送RPC请求 + auto *push_request = closure->request(0); + push_request->set_cmd_id(PS_PUSH_SPARSE_TABLE); + push_request->set_table_id(table_id); + push_request->set_client_id(_client_id); + push_request->add_params((char *)&num, sizeof(uint32_t)); + auto *push_data = push_request->mutable_data(); + push_data->resize(num * (sizeof(uint64_t) + value_size)); + char *push_data_ptr = const_cast(push_data->data()); + memcpy(push_data_ptr, keys, num * sizeof(uint64_t)); + push_data_ptr += num * sizeof(uint64_t); + for (int i = 0; i < num; ++i) { + memcpy(push_data_ptr, update_values[i], value_size); + push_data_ptr += value_size; + } + PsService_Stub rpc_stub(get_sparse_channel(pserver_idx)); + closure->cntl(0)->set_request_compress_type( + (brpc::CompressType)FLAGS_pserver_communicate_compress_type); + rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0), + closure); + return fut; +} + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/brpc_ps_client.h b/paddle/fluid/distributed/service/brpc_ps_client.h new file mode 100644 index 0000000000000000000000000000000000000000..c07165151507951a6c9023906ac3f3666e1209b3 --- /dev/null +++ b/paddle/fluid/distributed/service/brpc_ps_client.h @@ -0,0 +1,212 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "brpc/channel.h" +#include "brpc/controller.h" +#include "brpc/server.h" +#include "paddle/fluid/distributed/service/ps_client.h" + +namespace paddle { +namespace distributed { + +class DownpourPsClientService : public PsService { + public: + DownpourPsClientService() {} + virtual ~DownpourPsClientService() {} + + virtual int32_t configure(PSClient *client, size_t rank_id) { + _client = client; + _rank = rank_id; + return 0; + } + virtual void service(::google::protobuf::RpcController *controller, + const ::paddle::PsRequestMessage *request, + ::paddle::PsResponseMessage *response, + ::google::protobuf::Closure *done) override; + + protected: + size_t _rank; + PSClient *_client; +}; + +class DownpourBrpcClosure : public PSClientClosure { + public: + DownpourBrpcClosure(size_t num, PSClientCallBack callback) + : PSClientClosure(callback) { + _waiting_num = num; + + _cntls.resize(num); + _requests.resize(num); + _responses.resize(num); + for (size_t i = 0; i < num; ++i) { + _cntls[i].reset(new brpc::Controller()); + } + } + virtual ~DownpourBrpcClosure() {} + virtual void Run() override { + if (_waiting_num.fetch_sub(1) == 1) { + _callback(this); + delete this; + } + } + PsRequestMessage *request(size_t i) { return &_requests[i]; } + PsResponseMessage *response(size_t i) { return &_responses[i]; } + brpc::Controller *cntl(size_t i) { return _cntls[i].get(); } + int check_response(size_t request_idx, int cmd_id); + int check_save_response(size_t request_idx, int cmd_id); + std::string get_response(size_t request_idx, int cmd_id); + + private: + std::atomic _waiting_num; + std::vector _requests; + std::vector _responses; + std::vector> _cntls; +}; + +template +struct array_deleter { + void operator()(T *&x) const { delete[] x; } +}; + +class BrpcPsClient : public PSClient { + public: + BrpcPsClient() {} + virtual ~BrpcPsClient() { + // _running = false; + // try { + // _async_push_dense_thread.join(); + // _async_push_sparse_thread.join(); + //} catch (...) { + //} + } + virtual int32_t create_client2client_connection( + int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry); + virtual std::future shrink(uint32_t table_id) override; + virtual std::future load(const std::string &epoch, + const std::string &mode) override; + virtual std::future load(uint32_t table_id, const std::string &epoch, + const std::string &mode) override; + + virtual std::future save(const std::string &epoch, + const std::string &mode) override; + + virtual std::future save(uint32_t table_id, const std::string &epoch, + const std::string &mode) override; + + virtual std::future clear() override; + + virtual std::future clear(uint32_t table_id) override; + + virtual std::future stop_server() override; + + virtual std::future start_profiler() override; + virtual std::future stop_profiler() override; + + virtual void finalize_worker() override; + + virtual std::future pull_dense(Region *regions, size_t region_num, + size_t table_id); + + virtual std::future push_dense_param(const Region *regions, + size_t region_num, + size_t table_id); + + virtual std::future pull_sparse(float **select_values, + size_t table_id, + const uint64_t *keys, size_t num); + + virtual std::future print_table_stat(uint32_t table_id); + + virtual std::future barrier(size_t table_id, uint32_t barrier_type); + + virtual std::future pull_geo_param(size_t table_id, + std::vector *values, + std::vector *keys, + int pserver_idx); + + virtual std::future flush(); + + virtual std::future send_client2client_msg( + int msg_type, int to_client_id, const std::string &msg) override; + + private: + virtual int32_t initialize() override; + + inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total, + uint32_t shard_num) { + return dense_dim_total / shard_num + 1; + } + + std::future send_cmd(uint32_t table_id, int cmd_id, + const std::vector ¶m); + + std::future send_save_cmd(uint32_t table_id, int cmd_id, + const std::vector ¶m); + + inline brpc::Channel *get_sparse_channel(size_t server_id) { + return _server_channels[server_id][0].get(); + } + inline brpc::Channel *get_dense_channel(size_t server_id) { + return _server_channels[server_id][1].get(); + } + inline brpc::Channel *get_cmd_channel(size_t server_id) { + return _server_channels[server_id][2].get(); + } + + bool _running = false; + bool _flushing = false; + std::atomic _async_call_num; //异步请求计数 + + std::vector> + _client_channels; // client2client + std::vector, 3>> + _server_channels; // client2server + virtual std::future push_dense_raw_gradient( + int table_id, float *total_send_data, size_t total_send_data_size, + void *done) override; + + virtual std::future push_sparse_raw_gradient( + size_t table_id, const uint64_t *keys, const float **update_values, + size_t num, void *done) override; + + virtual std::future push_sparse_raw_gradient_partial( + size_t table_id, const uint64_t *keys, const float **update_values, + uint32_t num, void *done, int pserver_idx) override; + + virtual std::future push_sparse_param(size_t table_id, + const uint64_t *keys, + const float **update_values, + size_t num, + void *done) override; + + virtual size_t get_server_nums() { return _server_channels.size(); } + + private: + int32_t start_client_service(); + + float _mae = 0; + float _mse = 0; + uint16_t _push_times = 0; + brpc::Server _server; + DownpourPsClientService _service; + std::atomic_uint grad_num_{0}; +}; +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/brpc_ps_server.cc b/paddle/fluid/distributed/service/brpc_ps_server.cc new file mode 100644 index 0000000000000000000000000000000000000000..1386e83447567f9e3acfefe8b992a3dbaa045d39 --- /dev/null +++ b/paddle/fluid/distributed/service/brpc_ps_server.cc @@ -0,0 +1,530 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/service/brpc_ps_server.h" +#include // NOLINT +#include "Eigen/Dense" +#include "butil/endpoint.h" +#include "iomanip" +#include "paddle/fluid/distributed/table/table.h" +#include "paddle/fluid/framework/archive.h" +#include "paddle/fluid/platform/profiler.h" + +namespace paddle { +namespace distributed { + +int32_t BrpcPsServer::initialize() { + auto &service_config = _config.downpour_server_param().service_param(); + if (!service_config.has_service_class()) { + LOG(ERROR) << "miss service_class in ServerServiceParameter"; + return -1; + } + auto *service = CREATE_CLASS(PsBaseService, service_config.service_class()); + if (service == NULL) { + LOG(ERROR) << "service is unregistered, service_name:" + << service_config.service_class(); + return -1; + } + + _service.reset(service); + if (service->configure(this) != 0 || service->initialize() != 0) { + LOG(ERROR) << "service initialize failed, service_name:" + << service_config.service_class(); + return -1; + } + if (_server.AddService(service, brpc::SERVER_DOESNT_OWN_SERVICE) != 0) { + LOG(ERROR) << "service add to brpc failed, service:" + << service_config.service_class(); + return -1; + } + return 0; +} + +uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) { + std::unique_lock lock(mutex_); + + std::string ip_port = ip + ":" + std::to_string(port); + VLOG(3) << "server of rank " << _rank << " starts at " << ip_port; + int num_threads = std::thread::hardware_concurrency(); + brpc::ServerOptions options; + options.num_threads = num_threads; + + if (_server.Start(ip_port.c_str(), &options) != 0) { + LOG(ERROR) << "BrpcPsServer start failed, ip_port=" << ip_port; + return 0; + } + VLOG(0) << "BrpcPsServer::start registe_ps_server"; + _environment->registe_ps_server(ip, port, _rank); + VLOG(0) << "BrpcPsServer::start wait"; + cv_.wait(lock, [&] { return stoped_; }); + + PSHost host; + host.ip = ip; + host.port = port; + host.rank = _rank; + VLOG(0) << "BrpcPsServer::start return host.rank"; + return host.rank; +} + +int32_t BrpcPsServer::port() { return _server.listen_address().port; } + +int32_t PsService::initialize() { + _is_initialize_shard_info = false; + _service_handler_map[PS_STOP_SERVER] = &PsService::stop_server; + _service_handler_map[PS_PULL_DENSE_TABLE] = &PsService::pull_dense; + _service_handler_map[PS_PUSH_DENSE_TABLE] = &PsService::push_dense; + _service_handler_map[PS_PULL_SPARSE_TABLE] = &PsService::pull_sparse; + _service_handler_map[PS_PUSH_SPARSE_TABLE] = &PsService::push_sparse; + _service_handler_map[PS_SAVE_ONE_TABLE] = &PsService::save_one_table; + _service_handler_map[PS_SAVE_ALL_TABLE] = &PsService::save_all_table; + _service_handler_map[PS_SHRINK_TABLE] = &PsService::shrink_table; + _service_handler_map[PS_LOAD_ONE_TABLE] = &PsService::load_one_table; + _service_handler_map[PS_LOAD_ALL_TABLE] = &PsService::load_all_table; + _service_handler_map[PS_CLEAR_ONE_TABLE] = &PsService::clear_one_table; + _service_handler_map[PS_CLEAR_ALL_TABLE] = &PsService::clear_all_table; + _service_handler_map[PS_PUSH_DENSE_PARAM] = &PsService::push_dense_param; + _service_handler_map[PS_PRINT_TABLE_STAT] = &PsService::print_table_stat; + _service_handler_map[PS_PULL_GEO_PARAM] = &PsService::pull_geo_param; + _service_handler_map[PS_PUSH_SPARSE_PARAM] = &PsService::push_sparse_param; + _service_handler_map[PS_BARRIER] = &PsService::barrier; + _service_handler_map[PS_START_PROFILER] = &PsService::start_profiler; + _service_handler_map[PS_STOP_PROFILER] = &PsService::stop_profiler; + + // shard初始化,server启动后才可从env获取到server_list的shard信息 + initialize_shard_info(); + + return 0; +} + +#define CHECK_TABLE_EXIST(table, request, response) \ + if (table == NULL) { \ + std::string err_msg("table not found with table_id:"); \ + err_msg.append(std::to_string(request.table_id())); \ + set_response_code(response, -1, err_msg.c_str()); \ + return -1; \ + } + +int32_t PsService::initialize_shard_info() { + if (!_is_initialize_shard_info) { + std::lock_guard guard(_initialize_shard_mutex); + if (_is_initialize_shard_info) { + return 0; + } + size_t shard_num = _server->environment()->get_ps_servers().size(); + auto &table_map = *(_server->table()); + for (auto itr : table_map) { + itr.second->set_shard(_rank, shard_num); + } + _is_initialize_shard_info = true; + } + return 0; +} + +void PsService::service(google::protobuf::RpcController *cntl_base, + const PsRequestMessage *request, + PsResponseMessage *response, + google::protobuf::Closure *done) { + brpc::ClosureGuard done_guard(done); + std::string log_label("ReceiveCmd-"); + if (!request->has_table_id()) { + set_response_code(*response, -1, "PsRequestMessage.tabel_id is required"); + return; + } + + response->set_err_code(0); + response->set_err_msg(""); + auto *table = _server->table(request->table_id()); + brpc::Controller *cntl = static_cast(cntl_base); + auto itr = _service_handler_map.find(request->cmd_id()); + if (itr == _service_handler_map.end()) { + std::string err_msg( + "undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:"); + err_msg.append(std::to_string(request->cmd_id())); + set_response_code(*response, -1, err_msg.c_str()); + return; + } + serviceHandlerFunc handler_func = itr->second; + int service_ret = (this->*handler_func)(table, *request, *response, cntl); + if (service_ret != 0) { + response->set_err_code(service_ret); + response->set_err_msg("server internal error"); + } +} + +int32_t PsService::pull_dense(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + platform::RecordEvent record_event("PsService->pull_dense"); + CHECK_TABLE_EXIST(table, request, response) + if (request.params_size() < 1) { + set_response_code( + response, -1, + "PsRequestMessage.datas is requeired at least 1 for num of dense"); + return 0; + } + uint32_t num = *(const uint32_t *)request.params(0).c_str(); + if (num < 0) { + set_response_code(response, -1, + "PsRequestMessage.datas[0] is invalid, num must >= 0"); + return 0; + } + + std::vector res_data; + res_data.resize(num * table->value_accesor()->select_size() / sizeof(float)); + table->pull_dense(res_data.data(), num); + + cntl->response_attachment().append((char *)res_data.data(), + res_data.size() * sizeof(float)); + + return 0; +} + +int32_t PsService::push_dense_param(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + platform::RecordEvent record_event("PsService->push_dense_param"); + CHECK_TABLE_EXIST(table, request, response) + thread_local std::string push_buffer; + auto &req_io_buffer = cntl->request_attachment(); + auto req_buffer_size = req_io_buffer.size(); + if (req_buffer_size < 1) { + set_response_code(response, -1, "req attachment is empty"); + return 0; + } + push_buffer.resize(0); + push_buffer.reserve(req_buffer_size); + const char *data = (const char *)cntl->request_attachment().fetch( + const_cast(push_buffer.data()), req_buffer_size); + + uint32_t num = *(const uint32_t *)data; + + const float *values = (const float *)(data + sizeof(uint32_t)); + if (table->push_dense_param(values, num) != 0) { + set_response_code(response, -1, "push_dense_param failed"); + } + return 0; +} + +int32_t PsService::push_dense(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + platform::RecordEvent record_event("PsService->push_dense"); + CHECK_TABLE_EXIST(table, request, response) + auto req_buffer_size = request.data().size(); + if (req_buffer_size < 1) { + // set_response_code(response, 0, "push dense data is empty"); + return 0; + } + + /* + Push Content: + |--num--|---valuesData---| + |--4B---|----------------| + */ + uint32_t num = *(const uint32_t *)(request.data().data()); + const float *values = + (const float *)(request.data().data() + sizeof(uint32_t)); + if (table->push_dense(values, num) != 0) { + set_response_code(response, -1, "push_dense failed"); + } + + return 0; +} + +int32_t PsService::barrier(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + CHECK_TABLE_EXIST(table, request, response) + + if (request.params_size() < 1) { + set_response_code(response, -1, + "PsRequestMessage.params is requeired at " + "least 1 for num of sparse_key"); + return 0; + } + + auto trainer_id = request.client_id(); + auto barrier_type = request.params(0); + table->barrier(trainer_id, barrier_type); + return 0; +} + +int32_t PsService::push_sparse_param(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + platform::RecordEvent record_event("PsService->push_sparse_param"); + CHECK_TABLE_EXIST(table, request, response) + auto &push_data = request.data(); + if (push_data.size() < 1) { + // set_response_code(response, 0, "push sparse data is empty"); + return 0; + } + if (request.params_size() < 1) { + set_response_code(response, -1, + "PsRequestMessage.params is requeired at " + "least 1 for num of sparse_key"); + return 0; + } + uint32_t num = *(uint32_t *)(request.params(0).c_str()); + /* + Push Content: + |---keysData---|---valuesData---| + |---8*{num}B---|----------------| + */ + const uint64_t *keys = (const uint64_t *)push_data.data(); + const float *values = + (const float *)(push_data.data() + sizeof(uint64_t) * num); + if (table->push_sparse_param(keys, values, num) != 0) { + set_response_code(response, -1, "push_sparse_param error"); + } + return 0; +} + +int32_t PsService::pull_geo_param(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + platform::RecordEvent record_event("PsService->pull_geo_param"); + CHECK_TABLE_EXIST(table, request, response) + thread_local std::string push_sparse_request_buffer; + + auto trainer_id = request.client_id(); + + std::vector values; + std::vector ids; + table->pull_geo_param(trainer_id, &values, &ids); + + uint32_t num = ids.size(); + cntl->response_attachment().append((char *)(&num), sizeof(uint32_t)); + cntl->response_attachment().append((char *)ids.data(), + ids.size() * sizeof(uint64_t)); + cntl->response_attachment().append((char *)values.data(), + values.size() * sizeof(float)); + return 0; +} + +int32_t PsService::pull_sparse(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + platform::RecordEvent record_event("PsService->pull_sparse"); + CHECK_TABLE_EXIST(table, request, response) + thread_local std::string push_sparse_request_buffer; + auto &req_io_buffer = cntl->request_attachment(); + auto req_buffer_size = req_io_buffer.size(); + if (req_buffer_size < 1) { + set_response_code(response, -1, "req attachment is empty"); + return 0; + } + if (request.params_size() < 1) { + set_response_code(response, -1, + "PsRequestMessage.params is requeired at " + "least 1 for num of sparse_key"); + return 0; + } + uint32_t num = *(uint32_t *)(request.params(0).c_str()); + push_sparse_request_buffer.resize(0); + push_sparse_request_buffer.reserve(req_buffer_size); + const char *data = (const char *)cntl->request_attachment().fetch( + const_cast(push_sparse_request_buffer.data()), req_buffer_size); + /* + Attachment Content: + |---keysData---| + |---8*{num}B---| + */ + const uint64_t *keys = (const uint64_t *)data; + std::vector res_data; + res_data.resize(num * table->value_accesor()->select_size() / sizeof(float)); + table->pull_sparse(res_data.data(), keys, num); + cntl->response_attachment().append((char *)res_data.data(), + res_data.size() * sizeof(float)); + return 0; +} + +int32_t PsService::push_sparse(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + platform::RecordEvent record_event("PsService->push_sparse"); + CHECK_TABLE_EXIST(table, request, response) + auto &push_data = request.data(); + if (push_data.size() < 1) { + // set_response_code(response, 0, "push sparse data is empty"); + return 0; + } + if (request.params_size() < 1) { + set_response_code(response, -1, + "PsRequestMessage.params is requeired at " + "least 1 for num of sparse_key"); + return 0; + } + uint32_t num = *(uint32_t *)(request.params(0).c_str()); + /* + Push Content: + |---keysData---|---valuesData---| + |---8*{num}B---|----------------| + */ + const uint64_t *keys = (const uint64_t *)push_data.data(); + const float *values = + (const float *)(push_data.data() + sizeof(uint64_t) * num); + if (table->push_sparse(keys, values, num) != 0) { + set_response_code(response, -1, "push_sparse error"); + } + return 0; +} + +int32_t PsService::print_table_stat(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + CHECK_TABLE_EXIST(table, request, response) + std::pair ret = table->print_table_stat(); + paddle::framework::BinaryArchive ar; + ar << ret.first << ret.second; + std::string table_info(ar.Buffer(), ar.Length()); + response.set_data(table_info); + + return 0; +} + +int32_t PsService::load_one_table(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + CHECK_TABLE_EXIST(table, request, response) + if (request.params_size() < 2) { + set_response_code( + response, -1, + "PsRequestMessage.datas is requeired at least 2 for path & load_param"); + return -1; + } + if (table->load(request.params(0), request.params(1)) != 0) { + set_response_code(response, -1, "table load failed"); + return -1; + } + return 0; +} + +int32_t PsService::load_all_table(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + auto &table_map = *(_server->table()); + for (auto &itr : table_map) { + if (load_one_table(itr.second.get(), request, response, cntl) != 0) { + LOG(ERROR) << "load table[" << itr.first << "] failed"; + return -1; + } + } + return 0; +} + +int32_t PsService::save_one_table(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + CHECK_TABLE_EXIST(table, request, response) + if (request.params_size() < 2) { + set_response_code( + response, -1, + "PsRequestMessage.datas is requeired at least 2, path&mode"); + return -1; + } + table->flush(); + + int32_t feasign_size = 0; + feasign_size = table->save(request.params(0), request.params(1)); + if (feasign_size < 0) { + set_response_code(response, -1, "table save failed"); + return -1; + } + return feasign_size; +} + +int32_t PsService::save_all_table(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + auto &table_map = *(_server->table()); + int32_t all_feasign_size = 0; + int32_t feasign_size = 0; + + for (auto &itr : table_map) { + feasign_size = save_one_table(itr.second.get(), request, response, cntl); + if (feasign_size < 0) { + LOG(ERROR) << "save table[" << itr.first << "] failed"; + return -1; + } + } + return 0; +} + +int32_t PsService::shrink_table(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + CHECK_TABLE_EXIST(table, request, response) + table->flush(); + if (table->shrink() != 0) { + set_response_code(response, -1, "table shrink failed"); + } + return 0; +} + +int32_t PsService::clear_one_table(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + CHECK_TABLE_EXIST(table, request, response) + table->flush(); + table->clear(); + return 0; +} + +int32_t PsService::clear_all_table(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + auto &table_map = *(_server->table()); + for (auto &itr : table_map) { + if (clear_one_table(itr.second.get(), request, response, cntl) != 0) { + return -1; + } + } + return 0; +} + +int32_t PsService::stop_server(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + auto *p_server = _server; + std::thread t_stop([p_server]() { + p_server->stop(); + LOG(INFO) << "Server Stoped"; + }); + t_stop.detach(); + return 0; +} + +int32_t PsService::stop_profiler(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + platform::DisableProfiler(platform::EventSortingKey::kDefault, + string::Sprintf("server_%s_profile", _rank)); + return 0; +} + +int32_t PsService::start_profiler(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + platform::EnableProfiler(platform::ProfilerState::kCPU); + return 0; +} + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/brpc_ps_server.h b/paddle/fluid/distributed/service/brpc_ps_server.h new file mode 100644 index 0000000000000000000000000000000000000000..0a053848e1eb3c915b6405fcff33b5710c776943 --- /dev/null +++ b/paddle/fluid/distributed/service/brpc_ps_server.h @@ -0,0 +1,153 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "brpc/channel.h" +#include "brpc/controller.h" +#include "brpc/server.h" + +#include +#include +#include "paddle/fluid/distributed/service/server.h" + +namespace paddle { +namespace distributed { + +class BrpcPsServer : public PSServer { + public: + BrpcPsServer() {} + virtual ~BrpcPsServer() {} + virtual uint64_t start(const std::string &ip, uint32_t port); + virtual int32_t stop() { + std::unique_lock lock(mutex_); + stoped_ = true; + cv_.notify_all(); + + _server.Stop(1000); + _server.Join(); + return 0; + } + virtual int32_t port(); + + private: + virtual int32_t initialize(); + + mutable std::mutex mutex_; + std::condition_variable cv_; + bool stoped_ = false; + brpc::Server _server; + std::shared_ptr _service; + std::vector> _pserver_channels; +}; + +class PsService; + +typedef int32_t (PsService::*serviceHandlerFunc)( + Table *table, const PsRequestMessage &request, PsResponseMessage &response, + brpc::Controller *cntl); + +class PsService : public PsBaseService { + public: + virtual int32_t initialize() override; + + virtual void service(::google::protobuf::RpcController *controller, + const ::paddle::PsRequestMessage *request, + ::paddle::PsResponseMessage *response, + ::google::protobuf::Closure *done) override; + + private: + int32_t initialize_shard_info(); + int32_t pull_dense(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t push_dense(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t push_dense_param(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t push_sparse_param(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl); + int32_t pull_sparse(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t pull_geo_param(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t barrier(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t push_sparse(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t load_one_table(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t load_all_table(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t save_one_table(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t save_all_table(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t shrink_table(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t clear_one_table(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t clear_all_table(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t stop_server(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t start_profiler(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t stop_profiler(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + + int32_t print_table_stat(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + + bool _is_initialize_shard_info; + std::mutex _initialize_shard_mutex; + std::unordered_map _service_handler_map; + std::unordered_map _msg_handler_map; + std::vector _ori_values; +}; + +class DownpourPServerBrpcClosure : public PServerClosure { + public: + DownpourPServerBrpcClosure(size_t num, PServerCallBack callback) + : PServerClosure(callback) { + _waiting_num = num; + _cntls.resize(num); + _requests.resize(num); + _responses.resize(num); + for (size_t i = 0; i < num; ++i) { + _cntls[i].reset(new brpc::Controller()); + } + } + virtual ~DownpourPServerBrpcClosure() {} + + virtual void Run() override { + if (_waiting_num.fetch_sub(1) == 1) { + _callback(this); + delete this; + } + } + PsRequestMessage *request(size_t i) { return &_requests[i]; } + PsResponseMessage *response(size_t i) { return &_responses[i]; } + brpc::Controller *cntl(size_t i) { return _cntls[i].get(); } + int check_response(size_t request_idx, int cmd_id) { return 1; } + int check_save_response(size_t request_idx, int cmd_id) { return 1; } + + private: + std::atomic _waiting_num; + std::vector _requests; + std::vector _responses; + std::vector> _cntls; +}; +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/brpc_utils.cc b/paddle/fluid/distributed/service/brpc_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..abd58bf028c2c19e50d18d8b33ff34e2b92e2d3f --- /dev/null +++ b/paddle/fluid/distributed/service/brpc_utils.cc @@ -0,0 +1,314 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/distributed/service/brpc_utils.h" +#include +#include +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/profiler.h" + +namespace paddle { +namespace framework { +class Scope; +class Variable; +} // namespace framework +namespace platform { +class DeviceContext; +} // namespace platform +} // namespace paddle + +namespace paddle { +namespace distributed { + +framework::proto::VarType::Type VarMessageToVarType( + VariableMessage::Type type) { + switch (type) { + case VariableMessage::FP32: + return framework::proto::VarType::FP32; // NOLINT + case VariableMessage::FP64: + return framework::proto::VarType::FP64; // NOLINT + case VariableMessage::INT32: + return framework::proto::VarType::INT32; // NOLINT + case VariableMessage::INT64: + return framework::proto::VarType::INT64; // NOLINT + case VariableMessage::BOOL: + return framework::proto::VarType::BOOL; // NOLINT + default: + PADDLE_THROW(platform::errors::InvalidArgument( + "VarMessageToVarType:Unsupported type %d", type)); + } +} + +void SerializeToMultiVarMsgAndIOBuf( + const std::string& message_name, + const std::vector& send_var_name_val, + const std::vector& recv_var_name_val, + const platform::DeviceContext& ctx, const framework::Scope* scope, + MultiVarMsg* request, butil::IOBuf* iobuf) { + // 1. message_name + request->set_message_name(message_name); + + // 2. var_names + for (auto& send_var_name : send_var_name_val) { + request->add_send_var_names(send_var_name); + } + for (auto& recv_var_name : recv_var_name_val) { + request->add_recv_var_names(recv_var_name); + } + + // 3. VarMessage + for (auto& send_var_name : send_var_name_val) { + auto* send_var_msg = request->add_var_messages(); + butil::IOBuf temp_iobuf; + send_var_msg->set_varname(send_var_name); + + framework::Variable* var = scope->FindVar(send_var_name); + + if (var->IsType()) { + SerializeLodTensor(var, ctx, send_var_msg, &temp_iobuf); + } else if (var->IsType()) { + SerializeSelectedRows(var, ctx, send_var_msg, &temp_iobuf); + } + iobuf->append(temp_iobuf); + } +} + +void SerializeLodTensor(framework::Variable* var, + const platform::DeviceContext& ctx, VarMsg* var_msg, + butil::IOBuf* iobuf) { + auto* tensor = var->GetMutable(); + var_msg->set_type(::paddle::LOD_TENSOR); + const framework::LoD lod = tensor->lod(); + if (lod.size() > 0) { + var_msg->set_lod_level(lod.size()); + for (auto& each : lod) { + VarMsg::LodData* lod_inner = var_msg->add_lod(); + for (auto& d : each) { + lod_inner->add_lod_data(d); + } + } + } + var_msg->set_data_type(static_cast(tensor->type())); + for (auto& dim : framework::vectorize(tensor->dims())) { + var_msg->add_dims(dim); + } + // IO Buffer + if (platform::is_cpu_place(tensor->place())) { + auto data_len = tensor->numel() * framework::SizeOfType(tensor->type()); + iobuf->append(reinterpret_cast(&data_len), 8); + iobuf->append(reinterpret_cast(tensor->data()), + data_len); + } else { +#ifdef PADDLE_WITH_CUDA + char* temp_ptr = + new char[tensor->numel() * framework::SizeOfType(tensor->type())]; + auto stream = + reinterpret_cast(ctx).stream(); + memory::Copy(platform::CPUPlace(), temp_ptr, + BOOST_GET_CONST(platform::CUDAPlace, tensor->place()), + tensor->data(), + tensor->numel() * framework::SizeOfType(tensor->type()), + stream); + auto data_len = tensor->numel() * framework::SizeOfType(tensor->type()); + iobuf->append(reinterpret_cast(&data_len), 8); + iobuf->append(reinterpret_cast(temp_ptr), data_len); + delete[] temp_ptr; +#endif + } +} + +void SerializeSelectedRows(framework::Variable* var, + const platform::DeviceContext& ctx, VarMsg* var_msg, + butil::IOBuf* iobuf) { + framework::SelectedRows* slr = var->GetMutable(); + auto* tensor = slr->mutable_value(); + auto* rows = slr->mutable_rows(); + + var_msg->set_type(::paddle::SELECTED_ROWS); + var_msg->set_slr_height(slr->height()); + + auto* var_data = var_msg->mutable_data(); + var_data->clear(); + var_data->resize(rows->size() * sizeof(int64_t)); + char* data_ptr = const_cast(var_data->data()); + + if (platform::is_cpu_place(tensor->place())) { + memcpy(data_ptr, &(*rows)[0], rows->size() * sizeof(int64_t)); + } else { +#ifdef PADDLE_WITH_CUDA + auto stream = + reinterpret_cast(ctx).stream(); + memory::Copy(platform::CPUPlace(), data_ptr, + BOOST_GET_CONST(platform::CUDAPlace, tensor->place()), + &(*rows)[0], rows->size() * sizeof(int64_t), stream); +#endif + } + var_msg->set_data_type(static_cast(tensor->type())); + for (auto& dim : framework::vectorize(tensor->dims())) { + var_msg->add_dims(dim); + } + + // IO Buffer + if (platform::is_cpu_place(tensor->place())) { + auto data_len = tensor->numel() * framework::SizeOfType(tensor->type()); + iobuf->append(reinterpret_cast(&data_len), 8); + iobuf->append(reinterpret_cast(tensor->data()), + data_len); + } else { +#ifdef PADDLE_WITH_CUDA + char* temp_ptr = + new char[tensor->numel() * framework::SizeOfType(tensor->type())]; + auto stream = + reinterpret_cast(ctx).stream(); + memory::Copy(platform::CPUPlace(), temp_ptr, + BOOST_GET_CONST(platform::CUDAPlace, tensor->place()), + tensor->data(), + tensor->numel() * framework::SizeOfType(tensor->type()), + stream); + auto data_len = tensor->numel() * framework::SizeOfType(tensor->type()); + iobuf->append(reinterpret_cast(&data_len), 8); + iobuf->append(reinterpret_cast(temp_ptr), data_len); + delete[] temp_ptr; +#endif + } +} + +void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg, + const butil::IOBuf* iobuf, + const platform::DeviceContext& ctx, + framework::Scope* scope) { + butil::IOBufBytesIterator io_buffer_itr(*iobuf); + // size_t shard_buffer_remain = res_io_buffer.size(); + for (int recv_var_index = 0; recv_var_index < multi_msg.send_var_names_size(); + ++recv_var_index) { + const auto& msg = multi_msg.var_messages(recv_var_index); + auto* var = scope->Var(msg.varname()); + if (msg.type() == ::paddle::LOD_TENSOR) { + DeserializeLodTensor(var, msg, io_buffer_itr, ctx); + } else if (msg.type() == ::paddle::SELECTED_ROWS) { + DeserializeSelectedRows(var, msg, io_buffer_itr, ctx); + } + } +} + +void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg, + const butil::IOBuf* iobuf, + const platform::DeviceContext& ctx, + const framework::Scope* scope) { + butil::IOBufBytesIterator io_buffer_itr(*iobuf); + // size_t shard_buffer_remain = res_io_buffer.size(); + for (int recv_var_index = 0; recv_var_index < multi_msg.send_var_names_size(); + ++recv_var_index) { + const auto& msg = multi_msg.var_messages(recv_var_index); + auto* var = scope->FindVar(msg.varname()); + PADDLE_ENFORCE_NE(var, nullptr, + platform::errors::InvalidArgument( + "Not find variable %s in scope.", msg.varname())); + if (msg.type() == ::paddle::LOD_TENSOR) { + DeserializeLodTensor(var, msg, io_buffer_itr, ctx); + } else if (msg.type() == ::paddle::SELECTED_ROWS) { + DeserializeSelectedRows(var, msg, io_buffer_itr, ctx); + } + } +} + +void DeserializeLodTensor(framework::Variable* var, const VarMsg& msg, + butil::IOBufBytesIterator& io_buffer_itr, + const platform::DeviceContext& ctx) { + const auto place = ctx.GetPlace(); + framework::LoDTensor* tensor = var->GetMutable(); + std::vector vec_dim; + for (auto& x : msg.dims()) { + vec_dim.push_back(x); + } + tensor->Resize(framework::make_ddim(vec_dim)); + + framework::LoD lod; + for (int i = 0; i < msg.lod_level(); ++i) { + framework::Vector v; + for (int j = 0; j < msg.lod(i).lod_data_size(); ++j) { + v.push_back(msg.lod(i).lod_data(j)); + } + lod.push_back(v); + } + tensor->set_lod(lod); + + void* tensor_data = + tensor->mutable_data(place, VarMessageToVarType(msg.data_type())); + + // IO Buffer + if (platform::is_cpu_place(place)) { + unsigned long data_len; + io_buffer_itr.copy_and_forward((void*)(&data_len), 8); + io_buffer_itr.copy_and_forward(tensor_data, data_len); + } else if (platform::is_gpu_place(place)) { +#ifdef PADDLE_WITH_CUDA + unsigned long data_len; + char* temp_ptr = + new char[tensor->numel() * framework::SizeOfType(tensor->type())]; + io_buffer_itr.copy_and_forward((void*)(&data_len), 8); + io_buffer_itr.copy_and_forward((void*)temp_ptr, data_len); + auto stream = + reinterpret_cast(ctx).stream(); + memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, place), tensor_data, + platform::CPUPlace(), (void*)temp_ptr, + tensor->numel() * framework::SizeOfType(tensor->type()), + stream); + delete[] temp_ptr; +#endif + } +} + +void DeserializeSelectedRows(framework::Variable* var, const VarMsg& msg, + butil::IOBufBytesIterator& io_buffer_itr, + const platform::DeviceContext& ctx) { + const auto place = ctx.GetPlace(); + auto* slr = var->GetMutable(); + framework::Tensor* tensor = slr->mutable_value(); + slr->set_height(msg.slr_height()); + std::vector tmp_rows(msg.slr_height()); + memcpy(&tmp_rows[0], msg.data().data(), msg.slr_height() * sizeof(int64_t)); + slr->set_rows(tmp_rows); + std::vector vec_dim; + for (auto& x : msg.dims()) { + vec_dim.push_back(x); + } + tensor->Resize(framework::make_ddim(vec_dim)); + void* tensor_data = + tensor->mutable_data(place, VarMessageToVarType(msg.data_type())); + // IO Buffer + if (platform::is_cpu_place(place)) { + unsigned long data_len; + io_buffer_itr.copy_and_forward((void*)(&data_len), 8); + io_buffer_itr.copy_and_forward(tensor_data, data_len); + } else if (platform::is_gpu_place(place)) { +#ifdef PADDLE_WITH_CUDA + char* temp_ptr = + new char[tensor->numel() * framework::SizeOfType(tensor->type())]; + unsigned long data_len; + io_buffer_itr.copy_and_forward((void*)(&data_len), 8); + io_buffer_itr.copy_and_forward(temp_ptr, data_len); + auto stream = + reinterpret_cast(ctx).stream(); + memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, place), tensor_data, + platform::CPUPlace(), temp_ptr, + tensor->numel() * framework::SizeOfType(tensor->type()), + stream); + delete[] temp_ptr; +#endif + } +} + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/brpc_utils.h b/paddle/fluid/distributed/service/brpc_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..aa340c58a7b8b0ed93d1dd67cd747689be9fe094 --- /dev/null +++ b/paddle/fluid/distributed/service/brpc_utils.h @@ -0,0 +1,86 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include + +#include "brpc/channel.h" +#include "paddle/fluid/distributed/service/sendrecv.pb.h" +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/framework/var_type.h" +#include "paddle/fluid/platform/port.h" + +namespace grpc { +class ByteBuffer; +} // namespace grpc +namespace paddle { +namespace framework { +class Scope; +class Variable; +} // namespace framework +namespace platform { +class DeviceContext; +} // namespace platform +} // namespace paddle + +namespace paddle { +namespace distributed { + +using MultiVarMsg = ::paddle::MultiVariableMessage; +using VarMsg = ::paddle::VariableMessage; + +void SerializeToMultiVarMsgAndIOBuf( + const std::string& message_name, + const std::vector& send_var_name_val, + const std::vector& recv_var_name_val, + const platform::DeviceContext& ctx, const framework::Scope* scope, + MultiVarMsg* var_msg, butil::IOBuf* iobuf); + +void SerializeLodTensor(framework::Variable* var, + const platform::DeviceContext& ctx, VarMsg* var_msg, + butil::IOBuf* iobuf); + +void SerializeSelectedRows(framework::Variable* var, + const platform::DeviceContext& ctx, VarMsg* request, + butil::IOBuf* iobuf); + +// Deserialize for Server +void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg, + const butil::IOBuf* iobuf, + const platform::DeviceContext& ctx, + framework::Scope* scope); + +// Deserialize for Client +void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg, + const butil::IOBuf* iobuf, + const platform::DeviceContext& ctx, + const framework::Scope* scope); + +void DeserializeLodTensor(framework::Variable* var, const VarMsg& msg, + butil::IOBufBytesIterator& iobuf, + const platform::DeviceContext& ctx); + +void DeserializeSelectedRows(framework::Variable* var, const VarMsg& msg, + butil::IOBufBytesIterator& iobuf, + const platform::DeviceContext& ctx); + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/communicator.cc b/paddle/fluid/distributed/service/communicator.cc new file mode 100644 index 0000000000000000000000000000000000000000..18776a61a5cee7306ea85114dbeec81287579f34 --- /dev/null +++ b/paddle/fluid/distributed/service/communicator.cc @@ -0,0 +1,1171 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/distributed/service/communicator.h" +#include +#include "paddle/fluid/distributed/table/table.h" + +#include +#include + +#include +#include // NOLINT +#include +#include // NOLINT +#include + +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/framework/threadpool.h" +#include "paddle/fluid/framework/variable_helper.h" +#include "paddle/fluid/platform/profiler.h" +#include "paddle/fluid/string/printf.h" +#include "paddle/fluid/string/split.h" + +namespace paddle { +namespace distributed { + +using framework::LoDTensor; +using framework::SelectedRows; + +inline double GetCurrentUS() { + struct timeval time; + gettimeofday(&time, NULL); + return 1e+6 * time.tv_sec + time.tv_usec; +} + +Communicator::Communicator() {} + +void Communicator::init_gflag(const std::string &gflags) { + VLOG(0) << "Init With Gflags:" << gflags; + std::vector flags = paddle::string::split_string(gflags); + if (flags.size() < 1) { + flags.push_back("-max_body_size=314217728"); + flags.push_back("-bthread_concurrency=40"); + flags.push_back("-socket_max_unwritten_bytes=2048000000"); + flags.push_back("-max_connection_pool_size=1950"); + } + auto it = flags.begin(); + flags.insert(it, "exe default"); + char *flags_ptr[flags.size()]; + for (size_t i = 0; i < flags.size(); ++i) { + flags_ptr[i] = (char *)(flags[i].c_str()); + } + int params_cnt = flags.size(); + char **params_ptr = &(flags_ptr[0]); + ::google::ParseCommandLineFlags(¶ms_cnt, ¶ms_ptr, true); +} + +std::once_flag Communicator::init_flag_; +std::shared_ptr Communicator::communicator_(nullptr); + +void Communicator::InitBrpcClient( + const std::string &dist_desc, + const std::vector &host_sign_list) { + // not used, just for psclient's init + std::map> + _dense_pull_regions; + for (auto &iter : recv_varname_to_ctx_) { + auto tid = iter.first; + auto var_names = iter.second; + + auto ®ions = _dense_pull_regions[tid]; + regions.reserve(var_names.size()); + for (auto &t : var_names) { + Variable *var = recv_scope_->FindVar(t); + LoDTensor *tensor = var->GetMutable(); + float *w = tensor->data(); + paddle::distributed::Region reg(w, tensor->numel()); + regions.emplace_back(std::move(reg)); + } + } + + if (_worker_ptr.get() == nullptr) { + google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param); + init_gflag(_ps_param.init_gflags()); + servers_ = host_sign_list.size(); + _ps_env = paddle::distributed::PaddlePSEnvironment(); + _ps_env.set_ps_servers(&host_sign_list, servers_); + _worker_ptr = std::shared_ptr( + paddle::distributed::PSClientFactory::create(_ps_param)); + _worker_ptr->configure(_ps_param, _dense_pull_regions, _ps_env, + trainer_id_); + } + return; +} + +void Communicator::RpcRecvDense(const std::vector &varnames, + int table_id, Scope *scope) { + platform::RecordEvent record_event("Communicator->RpcRecvDense"); + std::vector regions; + regions.reserve(varnames.size()); + for (auto &t : varnames) { + Variable *var = scope->Var(t); + LoDTensor *tensor = var->GetMutable(); + if (platform::is_gpu_place(tensor->place())) { +#ifdef PADDLE_WITH_CUDA + Variable *temp_var = xpu_temp_scope_->Var(t); + LoDTensor *temp_tensor = temp_var->GetMutable(); + temp_tensor->Resize(tensor->dims()); + float *temp_data = temp_tensor->mutable_data(platform::CPUPlace()); + paddle::distributed::Region reg(temp_data, tensor->numel()); + regions.emplace_back(std::move(reg)); + VLOG(1) << "AsyncCommunicator::RpcRecvDense Var " << t << " table_id " + << table_id << " Temp_data[0] " << temp_data[0] + << " Temp_data[-1] " << temp_data[tensor->numel() - 1]; +#endif + } else { + float *w = tensor->mutable_data(tensor->place()); + paddle::distributed::Region reg(w, tensor->numel()); + regions.emplace_back(std::move(reg)); + } + } + auto status = + _worker_ptr->pull_dense(regions.data(), regions.size(), table_id); + status.wait(); + + for (auto &t : varnames) { + Variable *var = scope->FindVar(t); + LoDTensor *tensor = var->GetMutable(); + VLOG(1) << "AsyncCommunicator::RecvNoBarrier Var " << t << " On gpu? " + << platform::is_gpu_place(tensor->place()); + if (platform::is_gpu_place(tensor->place())) { +#ifdef PADDLE_WITH_CUDA + LoDTensor *temp_tensor = + xpu_temp_scope_->FindVar(t)->GetMutable(); + framework::TensorCopy(*temp_tensor, tensor->place(), tensor); + float *temp_data = temp_tensor->mutable_data(platform::CPUPlace()); + VLOG(1) << "AsyncCommunicator::RpcRecvDense Var " << t << " table_id " + << table_id << " Temp_data[0] " << temp_data[0] + << " Temp_data[-1] " << temp_data[tensor->numel() - 1]; +#endif + } + } + + return; +} + +void Communicator::RpcSendDenseParam(const std::vector &varnames, + int table_id, const Scope &scope) { + platform::RecordEvent record_event("Communicator->RpcSendDenseParam"); + auto place = platform::CPUPlace(); + std::vector regions; + for (auto &t : varnames) { + Variable *var = scope.FindVar(t); + CHECK(var != nullptr) << "var[" << t << "] not found"; + LoDTensor *tensor = var->GetMutable(); + if (platform::is_gpu_place(tensor->place())) { +#ifdef PADDLE_WITH_CUDA + Variable *temp_var = xpu_temp_scope_->Var(t); + LoDTensor *temp_tensor = temp_var->GetMutable(); + temp_tensor->Resize(tensor->dims()); + float *temp_data = temp_tensor->mutable_data(platform::CPUPlace()); + framework::TensorCopy(*tensor, platform::CPUPlace(), temp_tensor); + paddle::distributed::Region reg(temp_data, tensor->numel()); + regions.emplace_back(std::move(reg)); + VLOG(1) << "AsyncCommunicator::RpcSendDenseParam Var " << t + << " table_id " << table_id << " Temp_data[0] " << temp_data[0] + << " Temp_data[-1] " << temp_data[tensor->numel() - 1]; +#endif + } else { + float *w = tensor->mutable_data(place); + paddle::distributed::Region reg(w, tensor->numel()); + regions.emplace_back(std::move(reg)); + VLOG(1) << "AsyncCommunicator::RpcSendDenseParam Var " << t + << " talbe_id " << table_id << " Temp_data[0] " << w[0] + << " Temp_data[-1] " << w[tensor->numel() - 1]; + } + } + auto status = + _worker_ptr->push_dense_param(regions.data(), regions.size(), table_id); + status.wait(); + VLOG(4) << "RPC Send Dense Param " << table_id << " done!"; + return; +} + +void Communicator::RpcSendDense(const CommContext &ctx, const Scope &scope) { + platform::RecordEvent record_event("Communicator->RpcSendDense"); + auto &var_names = ctx.origin_varnames; + auto &table_id = ctx.table_id; + auto dense_data = std::make_shared>(); + size_t request_call_num = _worker_ptr->get_server_nums(); + uint32_t num_per_shard = + dense_dim_per_shard(ctx.height_sections[0], request_call_num); + dense_data->resize(num_per_shard * + request_call_num); // accessor->update_dim() = 1 + float *data = dense_data->data(); + uint32_t pos = 0; + for (size_t i = 0; i < var_names.size(); ++i) { + const LoDTensor tensor = scope.FindVar(var_names[i])->Get(); + size_t count = static_cast(tensor.numel()); + const float *g = tensor.data(); + CHECK(pos + count <= dense_data->size()) + << "invalid dense size, cur pos[" << pos << "]" + << " data_num[" << count << "] size[" << dense_data->size() << "]"; + memcpy(data + pos, g, count * sizeof(float)); + pos += count; + } + + ++_async_call_num; + DownpourBrpcClosure *closure = new DownpourBrpcClosure( + request_call_num, [this, request_call_num](void *done) { + int ret = 0; + auto *closure = (DownpourBrpcClosure *)done; + for (size_t i = 0; i < request_call_num; ++i) { + if (closure->check_response(i, PS_PUSH_DENSE_TABLE) != 0) { + ret = -1; + break; + } + } + closure->set_promise_value(ret); + --_async_call_num; + }); + auto status = _worker_ptr->push_dense_raw_gradient( + table_id, data, dense_data->size(), closure); + status.wait(); + return; +} + +void Communicator::RpcSendSparseParam(const std::string &varname, int table_id, + const Scope &scope) { + platform::RecordEvent record_event("Communicator->RpcSendSparseParam"); + size_t request_call_num = _worker_ptr->get_server_nums(); + std::vector push_g_vec; + + auto *send_var = scope.FindVar(varname); + auto *tensor = send_var->GetMutable(); + auto dim = tensor->dims()[1]; + uint64_t sparse_num = static_cast(tensor->dims()[0]); + std::vector sparse_push_keys(sparse_num); + std::iota(sparse_push_keys.begin(), sparse_push_keys.end(), 0); + push_g_vec.reserve(sparse_num); + + for (auto i = 0; i < static_cast(sparse_push_keys.size()); ++i) { + push_g_vec.push_back(tensor->data() + i * dim); + } + + DownpourBrpcClosure *closure = new DownpourBrpcClosure( + request_call_num, [this, request_call_num](void *done) { + int ret = 0; + auto *closure = (DownpourBrpcClosure *)done; + for (size_t i = 0; i < request_call_num; ++i) { + if (closure->check_response(i, PS_PUSH_SPARSE_PARAM) != 0) { + ret = -1; + break; + } + } + closure->set_promise_value(ret); + }); + auto status = _worker_ptr->push_sparse_param( + table_id, sparse_push_keys.data(), (const float **)push_g_vec.data(), + sparse_push_keys.size(), closure); + status.wait(); + return; +} + +void Communicator::RpcSendSparse(const std::string &var_name, int table_id, + const Scope &scope) { + platform::RecordEvent record_event("Communicator->RpcSendSparse"); + size_t request_call_num = _worker_ptr->get_server_nums(); + std::vector sparse_push_keys; + std::vector push_g_vec; + + auto *send_var = scope.FindVar(var_name); + auto *tensor = send_var->GetMutable(); + auto dim = tensor->value().dims()[1]; + std::transform(tensor->rows().begin(), tensor->rows().end(), + std::back_inserter(sparse_push_keys), + [&](int id) { return static_cast(id); }); + + for (auto i = 0; i < static_cast(sparse_push_keys.size()); ++i) { + push_g_vec.push_back(tensor->mutable_value()->data() + i * dim); + } + + ++_async_call_num; + DownpourBrpcClosure *closure = new DownpourBrpcClosure( + request_call_num, [this, request_call_num](void *done) { + int ret = 0; + auto *closure = (DownpourBrpcClosure *)done; + for (size_t i = 0; i < request_call_num; ++i) { + if (closure->check_response(i, PS_PUSH_SPARSE_TABLE) != 0) { + ret = -1; + break; + } + } + closure->set_promise_value(ret); + --_async_call_num; + }); + auto status = _worker_ptr->push_sparse_raw_gradient( + table_id, sparse_push_keys.data(), (const float **)push_g_vec.data(), + sparse_push_keys.size(), closure); + status.wait(); + return; +} + +void Communicator::RpcRecvSparse(const std::string &varname, int table_id, + Scope *scope) { + platform::RecordEvent record_event("Communicator->RpcRecvSparse"); + auto *send_var = scope->Var(varname); + auto *tensor = send_var->GetMutable(); + auto dim = tensor->dims()[1]; + uint64_t sparse_num = static_cast(tensor->dims()[0]); + + std::vector sparse_push_keys(sparse_num); + std::iota(sparse_push_keys.begin(), sparse_push_keys.end(), 0); + + std::vector push_g_vec; + for (auto i = 0; i < static_cast(sparse_push_keys.size()); ++i) { + push_g_vec.push_back(tensor->data() + i * dim); + } + + auto status = _worker_ptr->pull_sparse((float **)push_g_vec.data(), table_id, + sparse_push_keys.data(), + sparse_push_keys.size()); + status.wait(); + return; +} + +void Communicator::InitParams(const RecvCtxMap &recv_varname_to_ctx) { + if (trainer_id_ == 0) { + for (auto &iter : recv_varname_to_ctx) { + auto &table_id = iter.first; + auto &varnames = iter.second; + RpcSendDenseParam(varnames, table_id, *recv_scope_); + VLOG(1) << "push dense param to table " << table_id + << " from 0' trainer done"; + } + BarrierWithTable(1); + } else { + BarrierWithTable(1); + for (auto &iter : recv_varname_to_ctx) { + auto &table_id = iter.first; + auto &varnames = iter.second; + RpcRecvDense(varnames, table_id, recv_scope_); + VLOG(1) << "pull dense param to table " << table_id + << " from 0' trainer done"; + } + } + BarrierWithTable(1); + return; +} + +void Communicator::RpcProfilerControl() { + if (trainer_id_ == 0) { + if (!do_server_profiler_ && platform::IsProfileEnabled()) { + // send profiler start flag + do_server_profiler_ = true; + auto start_status = _worker_ptr->start_profiler(); + start_status.wait(); + } else if (do_server_profiler_ && !platform::IsProfileEnabled()) { + // send profiler end flag + auto stop_status = _worker_ptr->stop_profiler(); + stop_status.wait(); + do_server_profiler_ = false; + } + } +} + +void AsyncCommunicator::RecvThread() { + if (!independent_recv_) return; + VLOG(3) << "Independent RecvThread Start and Wait"; + + while (running_) { + int grad_num = grad_num_.load(); + if (grad_num > min_send_grad_num_before_recv_) { + RecvByCommunicator(); + grad_num_.store(0); + } else { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + } + VLOG(1) << "communicator stopped, independent recv thread exit"; +} + +void AsyncCommunicator::RecvByCommunicator() { + if (!running_) return; + RecvNoBarrier(); + VLOG(3) << "run recv graph end"; +} + +void AsyncCommunicator::RecvNoBarrier() { + for (auto &iter : recv_varname_to_ctx_) { + auto &table_id = iter.first; + auto &varnames = iter.second; + RpcRecvDense(varnames, table_id, recv_scope_); + } + + for (auto &iter : recv_varname_to_ctx_) { + auto var_names = iter.second; + for (auto &t : var_names) { + Variable *var = recv_scope_->FindVar(t); + LoDTensor *tensor = var->GetMutable(); + VLOG(1) << "AsyncCommunicator::RecvNoBarrier Var " << t << " On gpu? " + << platform::is_gpu_place(tensor->place()); + if (platform::is_gpu_place(tensor->place())) { +#ifdef PADDLE_WITH_CUDA + LoDTensor *temp_tensor = + xpu_temp_scope_->FindVar(t)->GetMutable(); + framework::TensorCopy(*temp_tensor, tensor->place(), tensor); +#endif + } + } + } + + return; +} + +void AsyncCommunicator::SendByCommunicator() { + std::vector> tasks; + tasks.reserve(send_varname_to_ctx_.size()); + + for (auto &iter : send_varname_to_ctx_) { + auto &ctx = iter.second; + + auto send_recv_task = [this, &ctx] { + auto &varnames = ctx.origin_varnames; + auto &table_id = ctx.table_id; + size_t var_nums = varnames.size(); + auto &check_queue = send_varname_to_queue_[varnames[0]]; + std::vector>> vars; + vars.resize(var_nums); + int merged_var_num = 0; + int wait_times = 0; + while (merged_var_num < max_merge_var_num_) { + if (check_queue->Size() == 0) { + VLOG(4) << "wait_times -> " << wait_times; + if (wait_times >= send_wait_times_) { + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + wait_times++; + continue; + } else { + wait_times = 0; + for (size_t i = 0; i < var_nums; i++) { + auto &var_name = varnames[i]; + auto &var_queue = send_varname_to_queue_[var_name]; + vars[i].push_back(var_queue->Pop()); + } + merged_var_num++; + } + } + if (merged_var_num == 0) return; + + for (size_t i = 0; i < var_nums; i++) { + auto &var_name = varnames[i]; + MergeVars(var_name, vars[i], send_scope_.get(), 1); + } + + if (ctx.is_sparse) { + PADDLE_ENFORCE_EQ( + varnames.size(), 1, + platform::errors::InvalidArgument( + "sparse variables can only be merged by one variables")); + RpcSendSparse(varnames[0], table_id, *send_scope_); + } else { + RpcSendDense(ctx, *send_scope_); + if (!independent_recv_ && + recv_varname_to_ctx_.find(table_id) != recv_varname_to_ctx_.end()) { + auto recv_varnames = recv_varname_to_ctx_.at(table_id); + RpcRecvDense(recv_varnames, table_id, recv_scope_); + } + } + if (independent_recv_) { + grad_num_.fetch_add(1, std::memory_order_relaxed); + } + }; + tasks.emplace_back(send_threadpool_->enqueue(std::move(send_recv_task))); + } + for (auto &task : tasks) { + task.wait(); + } + return; +} + +void AsyncCommunicator::MainThread() { + VLOG(3) << "AsyncCommunicator MainThread start and wait"; + + while (waiting_ && running_) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + VLOG(3) << "wait for running"; + } + + while (running_) { + SendByCommunicator(); + RpcProfilerControl(); + } + VLOG(1) << "communicator stopped, send thread exit"; +} + +void HalfAsyncCommunicator::MainThread() { + VLOG(3) << "HalfAsyncCommunicator MainThread start and wait"; + + while (waiting_ && running_) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + VLOG(3) << "wait for running"; + } + + while (running_) { + SendByCommunicator(); + BarrierSend(); + RecvByCommunicator(); + BarrierRecv(); + BarrierWeakUp(); + } + VLOG(1) << "communicator stopped, send thread exit"; +} + +void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx, + const RecvCtxMap &recv_varname_to_ctx, + Scope *recv_scope) { + send_varname_to_ctx_ = std::move(send_varname_to_ctx); + recv_varname_to_ctx_ = std::move(recv_varname_to_ctx); + recv_scope_ = std::move(recv_scope); + send_scope_.reset(new Scope()); + xpu_temp_scope_.reset(new Scope()); + for (auto &iter : send_varname_to_ctx_) { + auto &ctx = iter.second; + auto &varnames = ctx.origin_varnames; + for (auto &var_name : varnames) { + send_varname_to_queue_[var_name] = + std::make_shared>>( + send_queue_size_); + } + } + send_threadpool_.reset(new ::ThreadPool(thread_pool_size_)); +} + +AsyncCommunicator::~AsyncCommunicator() { + running_ = false; + if (main_thread_) main_thread_->join(); + if (recv_thread_) recv_thread_->join(); +} + +void AsyncCommunicator::Start() { + VLOG(1) << "Communicator start"; + if (!communicator_) { + VLOG(0) << "Communicator is not inited, do nothing"; + } else { + VLOG(1) << "start send thread and recv thread"; + waiting_ = true; + running_ = true; + // flushing_ = false; + BarrierTriggerReset(max_merge_var_num_); + // start send and recv thread + main_thread_.reset( + new std::thread(std::bind(&AsyncCommunicator::MainThread, this))); + if (independent_recv_) { + recv_thread_.reset( + new std::thread(std::bind(&AsyncCommunicator::RecvThread, this))); + } + } +} + +void AsyncCommunicator::Stop() { + VLOG(1) << "Communicator stop"; + running_ = false; + if (!communicator_) { + VLOG(0) << "Communicator is not inited, do nothing"; + } else { + if (recv_thread_) { + VLOG(1) << "stop recv thread"; + recv_thread_->join(); + recv_thread_.reset(nullptr); + } + if (main_thread_) { + VLOG(1) << "stop main thread"; + main_thread_->join(); + main_thread_.reset(nullptr); + } + } + VLOG(1) << "Communicator stop done"; +} + +bool AsyncCommunicator::Check(const std::vector &var_tables) { + PADDLE_ENFORCE_EQ( + var_tables.size(), 1, + platform::errors::InvalidArgument("var_tables.size() == 1 is permitted")); + + auto table_name = var_tables[0]; + if (send_varname_to_ctx_.find(table_name) == send_varname_to_ctx_.end()) + return false; + return true; +} + +bool AsyncCommunicator::Check(const int table_id) { + for (auto &iter : send_varname_to_ctx_) { + auto &ctx = iter.second; + if (ctx.table_id == table_id) return true; + } + return false; +} + +void AsyncCommunicator::Send(const std::vector &var_names, + const framework::Scope &scope) { + waiting_ = false; + for (size_t i = 0; i < var_names.size(); i++) { + auto *var = scope.FindVar(var_names[i]); + auto tmp_grad_var = std::make_shared(); + framework::CopyVariable(*var, tmp_grad_var.get()); + send_varname_to_queue_[var_names[i]]->Push(tmp_grad_var); + } +} + +void HalfAsyncCommunicator::Clean() { + for (auto &iter : send_varname_to_queue_) { + auto &var_name = iter.first; + auto &var_queue = iter.second; + + while (var_queue->Size() > 0) { + var_queue->Pop(); + } + + VLOG(3) << "clean var: " << var_name << " done"; + } +} + +void HalfAsyncCommunicator::BarrierTriggerDecrement() { + barrier_trigger_--; + VLOG(3) << "BarrierTriggerDecrement decrement barrier trigger to " + << barrier_trigger_.load(); +} + +void HalfAsyncCommunicator::BarrierTriggerReset(int initial_val) { + barrier_trigger_.store(initial_val); + + VLOG(3) << "BarrierTriggerReset reset barrier trigger to " + << barrier_trigger_.load(); +} + +void HalfAsyncCommunicator::Barrier() { + barrier_counter_++; + + if (!running_) { + VLOG(3) << "Communicator is not running, release barrier"; + return; + } + + { + std::unique_lock lk(barrier_mutex_); + barrier_cond_.wait(lk, [this] { return (barrier_counter_ == 0); }); + } +} + +int HalfAsyncCommunicator::BatchesCounter() { + while (running_) { + if (barrier_counter_.load() >= barrier_trigger_.load() && + barrier_trigger_.load() != 0) { + break; + } else { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + } + + return barrier_counter_.load(); +} + +void HalfAsyncCommunicator::SendByCommunicator() { + int batches = BatchesCounter(); + VLOG(1) << "HalfAsyncCommunicator::BatchesCounter = " << batches; + if (batches <= 0) return; + + std::vector> tasks; + tasks.reserve(send_varname_to_ctx_.size()); + + for (auto &iter : send_varname_to_ctx_) { + auto &ctx = iter.second; + auto send_recv_task = [this, &ctx, batches] { + auto &varnames = ctx.origin_varnames; + auto &table_id = ctx.table_id; + size_t var_nums = varnames.size(); + + std::vector>> vars; + vars.resize(var_nums); + for (size_t i = 0; i < var_nums; i++) { + auto &var_name = varnames[i]; + auto &var_queue = send_varname_to_queue_[var_name]; + for (int j = 0; j < batches; j++) vars[i].push_back(var_queue->Pop()); + MergeVars(var_name, vars[i], send_scope_.get(), 1); + } + + if (ctx.is_sparse) { + PADDLE_ENFORCE_EQ( + varnames.size(), 1, + platform::errors::InvalidArgument( + "sparse variables can only be merged by one variables")); + RpcSendSparse(varnames[0], table_id, *send_scope_); + } else { + RpcSendDense(ctx, *send_scope_); + } + }; + tasks.emplace_back(send_threadpool_->enqueue(std::move(send_recv_task))); + } + for (auto &task : tasks) { + task.wait(); + } + return; +} + +void HalfAsyncCommunicator::BarrierWeakUp() { + barrier_counter_.store(0); + barrier_cond_.notify_all(); +} + +void SyncCommunicator::BarrierSend() { + if (!running_) return; + BarrierWithTable(0); + VLOG(4) << "BarrierSend with SyncCommunicator"; +} + +void SyncCommunicator::BarrierRecv() { + if (!running_) return; + BarrierWithTable(1); + + VLOG(4) << "BarrierRecv with SyncCommunicator"; +} + +void GeoCommunicator::Send(const std::vector &var_names, + const framework::Scope &scope) { + waiting_ = false; + auto before_send = GetCurrentUS(); + auto table_name = var_names[0]; + + size_t splited_var_nums = + send_varname_to_ctx_[table_name].splited_varnames.size(); + + std::unordered_map> ids_table; + + for (size_t j = 0; j < splited_var_nums; j++) { + ids_table.insert(std::pair>( + send_varname_to_ctx_[table_name].splited_varnames[j], + std::unordered_set())); + } + + auto *var = scope.FindVar(table_name); + + PADDLE_ENFORCE_EQ(var->IsType(), true, + platform::errors::InvalidArgument( + "Only need to send Sparse Grad in Geo mode.")); + auto &rows = var->Get().rows(); + + // insert ids which has not been record + for (size_t j = 0; j < rows.size(); j++) { + auto ep_idx = rows[j] % splited_var_nums; + ids_table.at(send_varname_to_ctx_[table_name].splited_varnames[ep_idx]) + .insert(rows[j]); + } + + for (auto &iter : ids_table) { + auto &key = iter.first; + auto &sparse_ids_set = iter.second; + auto sparse_ids_vec = std::make_shared>(); + sparse_ids_vec->assign(sparse_ids_set.begin(), sparse_ids_set.end()); + sparse_id_queues_.at(key)->Push(sparse_ids_vec); + VLOG(3) << "push " << sparse_ids_vec->size() << " ids to " << key + << "'s queue"; + } + + auto after_send = GetCurrentUS(); + VLOG(2) << "run send op finish. use time " << (after_send - before_send); +} + +void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx, + const RecvCtxMap &recv_varname_to_ctx, + Scope *recv_scope) { + send_varname_to_ctx_ = std::move(send_varname_to_ctx); + recv_varname_to_ctx_ = std::move(recv_varname_to_ctx); + recv_scope_ = std::move(recv_scope); + + PADDLE_ENFORCE_GT( + send_varname_to_ctx.size(), 0, + platform::errors::InvalidArgument("send var contexts can not be zero")); + + for (auto &iter : send_varname_to_ctx_) { + auto &ctx = iter.second; + if (!ctx.is_sparse) continue; + auto &varnames = ctx.origin_varnames; + PADDLE_ENFORCE_EQ( + varnames.size(), 1, + platform::errors::InvalidArgument( + "sparse variables can only be merged by one variables")); + for (auto &splited_var : ctx.splited_varnames) { + parallel_task_nums_ += 1; + sparse_id_queues_.insert( + std::pair>>>>( + splited_var, + std::make_shared< + BlockingQueue>>>( + send_queue_size_))); + } + } + + send_threadpool_.reset(new ::ThreadPool(thread_pool_size_)); + + delta_scope_.reset(new Scope()); + old_scope_.reset(new Scope()); + pserver_scope_.reset(new Scope()); +} + +void GeoCommunicator::InitParams(const RecvCtxMap &recv_varname_to_ctx) { + std::vector> tasks; + tasks.reserve(recv_varname_to_ctx_.size()); + + for (auto &iter : recv_varname_to_ctx_) { + auto &table_id = iter.first; + auto &varnames = iter.second; + + auto recv_task = [this, &table_id, &varnames] { + InitDense(varnames, table_id); + }; + tasks.emplace_back(send_threadpool_->enqueue(std::move(recv_task))); + } + + for (auto &task : tasks) { + task.wait(); + } + + for (auto &iter : send_varname_to_ctx_) { + auto &ctx = iter.second; + if (!ctx.is_sparse) return; + auto &varname = ctx.origin_varnames[0]; + auto &table_id = ctx.table_id; + auto param = varname.substr(0, varname.size() - 5); + InitSparse(param, table_id); + } + return; +} + +void GeoCommunicator::InitDense(std::vector &varnames, + int table_id) { + if (trainer_id_ == 0) { + RpcSendDenseParam(varnames, table_id, *recv_scope_); + BarrierWithTable(1); + VLOG(0) << "push dense param to table " << table_id + << " from 0' trainer done"; + } else { + BarrierWithTable(1); + RpcRecvDense(varnames, table_id, recv_scope_); + VLOG(0) << "push dense param to table " << table_id + << " from 0' trainer done"; + } + + // copy to old_scope + for (auto &t : varnames) { + auto *global_var = recv_scope_->FindVar(t); + global_var->GetMutable(); + auto *old_var = old_scope_->Var(t); + old_var->GetMutable(); + framework::CopyVariable(*global_var, old_var); + } + VLOG(1) << "init dense table " << table_id << " done"; +} + +void GeoCommunicator::SendDense(const CommContext &send_ctx) { + platform::RecordEvent record_event("GeoCommunicator->SendDense"); + auto &var_names = send_ctx.origin_varnames; + auto &table_id = send_ctx.table_id; + for (auto &varname : var_names) { + auto param_name = GradToParam(varname); + auto *var_latest = recv_scope_->FindVar(param_name); + auto *var_timestamp = old_scope_->FindVar(param_name); + + PADDLE_ENFORCE_EQ(var_latest->IsInitialized(), true, + platform::errors::Unavailable( + "%s is not initialized, please check", param_name)); + PADDLE_ENFORCE_EQ(var_timestamp->IsInitialized(), true, + platform::errors::Unavailable( + "%s is not initialized, please check", param_name)); + + auto &t_latest = var_latest->Get(); + auto t_timestamp = var_timestamp->GetMutable(); + + auto cpu_ctx = paddle::platform::CPUDeviceContext(); + auto *var_delta = delta_scope_->Var(varname); + auto *t_delta = var_delta->GetMutable(); + t_delta->mutable_data(t_latest.dims(), cpu_ctx.GetPlace()); + + auto blas = + paddle::operators::math::GetBlas( + cpu_ctx); + blas.VSUB(t_latest.numel(), t_latest.data(), + t_timestamp->data(), t_delta->data()); + + float coefficient = 1.0 / static_cast(trainers_); + blas.SCAL(t_latest.numel(), coefficient, t_delta->data()); + + blas.VADD(t_latest.numel(), t_timestamp->data(), + t_delta->data(), t_timestamp->data()); + } + RpcSendDense(send_ctx, *delta_scope_); + VLOG(1) << "Finish Send Dense " << var_names[0] << ", table_id: " << table_id; + return; +} + +void GeoCommunicator::RecvDense(const CommContext &send_ctx) { + platform::RecordEvent record_event("GeoCommunicator->RecvDense"); + auto &table_id = send_ctx.table_id; + auto &varnames = recv_varname_to_ctx_.at(table_id); + // 1. recv from pserver + RpcRecvDense(varnames, table_id, pserver_scope_.get()); + + // 2.1 pserver - old => delta; 2.2 latest + old => latest 2.3 old => pserver + auto cpu_ctx = paddle::platform::CPUDeviceContext(); + for (auto &varname : varnames) { + auto *var_latest = recv_scope_->FindVar(varname); + auto t_latest = var_latest->GetMutable(); + + auto *var_old = old_scope_->FindVar(varname); + auto t_old = var_old->GetMutable(); + + auto *var_pserver = pserver_scope_->FindVar(varname); + auto t_pserver = var_pserver->Get(); + + auto *var_delta = delta_scope_->Var(varname); + auto *t_delta = var_delta->GetMutable(); + t_delta->mutable_data(t_latest->dims(), cpu_ctx.GetPlace()); + + auto blas = + paddle::operators::math::GetBlas( + cpu_ctx); + blas.VSUB(t_latest->numel(), t_pserver.data(), t_old->data(), + t_delta->data()); + blas.VADD(t_latest->numel(), t_latest->data(), + t_delta->data(), t_latest->data()); + blas.VCOPY(t_latest->numel(), t_pserver.data(), + t_old->data()); + } + VLOG(1) << "Finish Recv Dense " << varnames[0] << ", table_id: " << table_id; + return; +} + +void GeoCommunicator::InitSparse(const std::string &var_name, int table_id) { + VLOG(0) << "Init Sparse " << var_name << " : table " << table_id << " begin."; + if (trainer_id_ == 0) { + RpcSendSparseParam(var_name, table_id, *recv_scope_); + BarrierWithTable(1); + VLOG(0) << "push sparse param to table " << table_id + << " from 0' trainer done"; + } else { + BarrierWithTable(1); + RpcRecvSparse(var_name, table_id, recv_scope_); + VLOG(0) << "push dense param to table " << table_id + << " from 0' trainer done"; + } + + VLOG(0) << "Init Sparse " << var_name << " : table " << table_id << " done."; + auto *global_var = recv_scope_->FindVar(var_name); + auto *var = old_scope_->Var(var_name); + framework::CopyVariable(*global_var, var); + return; +} + +std::vector GeoCommunicator::MergeSparseIds( + const std::string &send_varname) { + size_t merge_num = 0, wait_times = 0; + std::unordered_set sparse_ids; + while (merge_num < static_cast(max_merge_var_num_)) { + VLOG(3) << "Merge Number of " << send_varname << " = " << merge_num; + if (sparse_id_queues_.at(send_varname)->Size() > 0) { + wait_times = 0; + std::shared_ptr> pop_ids = + sparse_id_queues_.at(send_varname)->Pop(); + for (size_t j = 0; j < pop_ids->size(); j++) { + sparse_ids.insert(pop_ids->at(j)); + } + merge_num += 1; + VLOG(3) << "sparse_id_queues_(" << send_varname << ") pushed"; + } else if (sparse_id_queues_.at(send_varname)->Size() == 0) { + VLOG(3) << "wait_times -> " << wait_times; + if (wait_times >= static_cast(send_wait_times_)) { + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + wait_times++; + continue; + } + } + std::vector res; + res.assign(sparse_ids.begin(), sparse_ids.end()); + return res; +} + +void GeoCommunicator::SendSparse(const std::string &varname, + std::vector &sparse_ids, int table_id, + int ep_idx) { + platform::RecordEvent record_event("GeoCommunicator->SendSparse"); + std::string param_name = SplitedGradToParam(varname); + VLOG(1) << "In GeoCommunicator::SendSparse(" << varname << " " << param_name + << ", ids.size = " << sparse_ids.size() << ", table_id: " << table_id + << ", ep_idx: " << ep_idx; + + auto *var_latest = recv_scope_->FindVar(param_name); + auto *var_old = old_scope_->FindVar(param_name); + + PADDLE_ENFORCE_EQ(var_latest->IsInitialized(), true, + platform::errors::Unavailable( + "%s is not initialized, please check", param_name)); + PADDLE_ENFORCE_EQ(var_old->IsInitialized(), true, + platform::errors::Unavailable( + "%s is not initialized, please check", param_name)); + + auto &t_latest = var_latest->Get(); + auto *t_old = var_old->GetMutable(); + + auto dims1 = t_latest.dims()[1]; + auto cpu_ctx = paddle::platform::CPUDeviceContext(); + + auto *var_delta = delta_scope_->Var(varname); + auto *t_delta = var_delta->GetMutable(); + auto *var_t_value = t_delta->mutable_value(); + var_t_value->Resize({static_cast(sparse_ids.size()), dims1}); + auto *t_value = var_t_value->mutable_data(cpu_ctx.GetPlace()); + + t_delta->set_rows(sparse_ids); + t_delta->set_height(t_latest.dims()[0]); + + auto blas = + paddle::operators::math::GetBlas( + cpu_ctx); + float coefficient = 1.0 / static_cast(trainers_); + + std::vector push_g_vec; + for (auto j = 0; j < static_cast(sparse_ids.size()); ++j) { + blas.VSUB(dims1, t_latest.data() + sparse_ids[j] * dims1, + t_old->data() + sparse_ids[j] * dims1, + t_value + j * dims1); + blas.SCAL(dims1, coefficient, t_value + j * dims1); + blas.VADD(dims1, t_old->data() + sparse_ids[j] * dims1, + t_value + j * dims1, + t_old->data() + sparse_ids[j] * dims1); + push_g_vec.push_back(t_value + j * dims1); + } + + ++_async_call_num; + DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [this](void *done) { + int ret = 0; + auto *closure = (DownpourBrpcClosure *)done; + if (closure->check_response(0, PS_PUSH_SPARSE_TABLE) != 0) { + ret = -1; + } + closure->set_promise_value(ret); + --_async_call_num; + }); + auto status = _worker_ptr->push_sparse_raw_gradient_partial( + table_id, (const uint64_t *)sparse_ids.data(), + (const float **)push_g_vec.data(), sparse_ids.size(), closure, ep_idx); + status.wait(); + + VLOG(1) << "Finish Send Sparse " << varname + << ", ids.size = " << sparse_ids.size() << ", table_id: " << table_id; + return; +} + +void GeoCommunicator::RecvSparse(const std::string &varname, int table_id, + int ep_idx) { + platform::RecordEvent record_event("GeoCommunicator->RecvSparse"); + // 1. recv from pserver + std::vector keys; + std::vector values; + auto status = _worker_ptr->pull_geo_param(table_id, &values, &keys, ep_idx); + status.wait(); + + std::string param = SplitedGradToParam(varname); + VLOG(1) << "RecvSparse receive var: " << varname << " " << param << ", " + << table_id << "; ids Size: " << keys.size() + << "; values size: " << values.size(); + + auto *var_latest = recv_scope_->FindVar(param); + auto *var_old = old_scope_->FindVar(param); + + auto *t_latest = var_latest->GetMutable(); + auto *t_old = var_old->GetMutable(); + + auto dims1 = t_latest->dims()[1]; + auto numel = keys.size() * dims1; + + std::vector v_delta; + v_delta.resize(numel); + + auto cpu_ctx = paddle::platform::CPUDeviceContext(); + auto blas = + paddle::operators::math::GetBlas( + cpu_ctx); + + for (auto j = 0; j < static_cast(keys.size()); ++j) { + float *latest_data = t_latest->data() + keys[j] * dims1; + float *old_data = t_old->data() + keys[j] * dims1; + // pserver - old => delta + blas.VSUB(dims1, values.data() + j * dims1, old_data, + v_delta.data() + j * dims1); + // latest + delta => latest + blas.VADD(dims1, latest_data, v_delta.data() + j * dims1, latest_data); + // pserver => old + blas.VCOPY(dims1, values.data() + j * dims1, old_data); + } + VLOG(1) << "Finish Recv Sparse " << param << ", table_id: " << table_id; +} + +void GeoCommunicator::MainThread() { + VLOG(3) << "MainThread start and wait"; + + while (waiting_ && running_) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + VLOG(3) << "wait for running"; + } + + while (running_) { + std::vector> tasks; + tasks.reserve(parallel_task_nums_); + + for (auto &iter : send_varname_to_ctx_) { + auto &ctx = iter.second; + auto &varnames = ctx.origin_varnames; + auto &table_id = ctx.table_id; + + if (ctx.is_sparse) { + PADDLE_ENFORCE_EQ( + varnames.size(), 1, + platform::errors::InvalidArgument( + "sparse variables can only be merged by one variables")); + int pserver_num = static_cast(ctx.epmap.size()); + for (int ep_idx = 0; ep_idx < pserver_num; ep_idx++) { + // varname: emb@GRAD, param_name: emb, splited_varname: emb.delta0 + auto send_recv_task = [this, table_id, ep_idx, &ctx] { + auto splited_varname = ctx.splited_varnames[ep_idx]; + auto sparse_ids = MergeSparseIds(splited_varname); + SendSparse(splited_varname, sparse_ids, table_id, ep_idx); + RecvSparse(splited_varname, table_id, ep_idx); + }; + tasks.emplace_back( + send_threadpool_->enqueue(std::move(send_recv_task))); + } + } else { + auto send_recv_task = [this, &ctx] { + SendDense(ctx); + RecvDense(ctx); + }; + tasks.emplace_back( + send_threadpool_->enqueue(std::move(send_recv_task))); + } + } + for (auto &task : tasks) { + task.wait(); + } + } +} + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/communicator.h b/paddle/fluid/distributed/service/communicator.h new file mode 100644 index 0000000000000000000000000000000000000000..a22b006013461c9ca4c10710339ac60550dabec9 --- /dev/null +++ b/paddle/fluid/distributed/service/communicator.h @@ -0,0 +1,561 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "gflags/gflags.h" +#include "paddle/fluid/distributed/communicator_common.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/variable.h" +#include "paddle/fluid/framework/variable_helper.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/math/selected_rows_functor.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/string/split.h" + +#include "paddle/fluid/distributed/ps.pb.h" +#include "paddle/fluid/distributed/service/brpc_ps_client.h" +#include "paddle/fluid/distributed/service/ps_client.h" + +DECLARE_bool(communicator_is_sgd_optimizer); + +namespace paddle { +namespace distributed { + +using Scope = framework::Scope; +using Variable = framework::Variable; + +template +class BlockingQueue { + public: + explicit BlockingQueue(size_t capacity) : capacity_(capacity) { + PADDLE_ENFORCE_GT(capacity_, 0, + platform::errors::InvalidArgument( + "The capacity must be greater than 0.")); + } + + bool Push(const T &elem) { + { + std::unique_lock lock(mutex_); + cv_.wait(lock, [&] { return queue_.size() < capacity_; }); + queue_.push_back(elem); + } + cv_.notify_one(); + return true; + } + + bool Push(T &&elem) { + { + std::unique_lock lock(mutex_); + cv_.wait(lock, [&] { return queue_.size() < capacity_; }); + queue_.emplace_back(std::move(elem)); + } + cv_.notify_one(); + return true; + } + + T Pop() { + std::unique_lock lock(mutex_); + cv_.wait(lock, [=] { return !queue_.empty(); }); + T rc(std::move(queue_.front())); + queue_.pop_front(); + cv_.notify_one(); + return rc; + } + + size_t Cap() const { + std::lock_guard lock(mutex_); + return capacity_; + } + + size_t Size() const { + std::lock_guard lock(mutex_); + return queue_.size(); + } + + private: + const size_t capacity_; + std::deque queue_; + + mutable std::mutex mutex_; + std::condition_variable cv_; +}; + +template +using EigenVector = framework::EigenVector; + +template +inline void MergeVars(const std::string &var_name, + const std::vector> &vars, + Scope *scope, bool merge_add = true) { + PADDLE_ENFORCE_NE(vars.empty(), true, platform::errors::InvalidArgument( + "vector vars are empty.")); + auto cpu_place = platform::CPUPlace(); + auto &var0 = vars[0]; + auto *out_var = scope->Var(var_name); + + if (var0->IsType()) { + auto dims = var0->Get().dims(); + VLOG(3) << "merge " << var_name << " LoDTensor dims " << dims + << "; merge add: " << merge_add; + // init output tensor + auto *out_t = out_var->GetMutable(); + out_t->mutable_data(dims, cpu_place); + // check the input dims + for (auto &var : vars) { + auto &var_t = var->Get(); + PADDLE_ENFORCE_EQ( + var_t.dims(), dims, + platform::errors::InvalidArgument("vars should have the same dims.")); + } + + // set output tensor to 0. + auto cpu_ctx = paddle::platform::CPUDeviceContext(); + paddle::operators::math::SetConstant + constant_functor; + constant_functor(cpu_ctx, out_t, static_cast(0)); + // sum all vars to out + auto result = EigenVector::Flatten(*out_t); + for (auto &var : vars) { + auto &in_t = var->Get(); + auto in = EigenVector::Flatten(in_t); + result.device(*cpu_ctx.eigen_device()) = result + in; + } + if (!merge_add) { + result.device(*cpu_ctx.eigen_device()) = + result / static_cast(vars.size()); + } + } else if (var0->IsType()) { + auto &slr0 = var0->Get(); + auto *out_slr = out_var->GetMutable(); + out_slr->mutable_rows()->clear(); + out_slr->mutable_value()->mutable_data({{}}, cpu_place); + std::vector inputs; + inputs.reserve(vars.size()); + for (auto &var : vars) { + inputs.push_back(&var->Get()); + } + auto dev_ctx = paddle::platform::CPUDeviceContext(); + if (merge_add) { + paddle::operators::math::scatter::MergeAdd< + paddle::platform::CPUDeviceContext, T> + merge_add; + merge_add(dev_ctx, inputs, out_slr); + } else { + paddle::operators::math::scatter::MergeAverage< + paddle::platform::CPUDeviceContext, T> + merge_average; + merge_average(dev_ctx, inputs, out_slr); + } + + VLOG(3) << "merge " << var_name << " SelectedRows height: " << slr0.height() + << " dims: " << slr0.value().dims() << "; merge add: " << merge_add; + } else { + PADDLE_THROW(platform::errors::InvalidArgument("unsupported var type: %s!", + var0->Type())); + } +} + +using RpcCtxMap = std::unordered_map; +using RecvCtxMap = std::unordered_map>; +using SparseValue = std::unordered_map>; + +class Communicator { + public: + Communicator(); + + explicit Communicator(const std::map &envs_) { + VLOG(0) << "Communicator Init Envs"; + for (auto &iter : envs_) { + envs[iter.first] = iter.second; + VLOG(0) << iter.first << ": " << iter.second; + } + barrier_table_id_ = std::stoi(envs.at("barrier_table_id")); + trainer_id_ = std::stoi(envs.at("trainer_id")); + trainers_ = std::stoi(envs.at("trainers")); + } + + virtual void InitBrpcClient(const std::string &dist_desc, + const std::vector &host_sign_list); + // 1. recv dense param + virtual void RpcRecvDense(const std::vector &varnames, + int table_id, Scope *scope); + // 2. send dense param + virtual void RpcSendDenseParam(const std::vector &varnames, + int table_id, const Scope &scope); + // 3. send dense grad + virtual void RpcSendDense(const CommContext &ctx, const Scope &scope); + // 4. send sparse grad + virtual void RpcSendSparse(const std::string &var_name, int table_id, + const Scope &scope); + // 5. send sparse param + virtual void RpcSendSparseParam(const std::string &varname, int table_id, + const Scope &scope); + // 6. recv sparse param + virtual void RpcRecvSparse(const std::string &varname, int table_id, + Scope *scope); + + virtual ~Communicator() {} + virtual void RpcProfilerControl(); + + virtual void InitParams(const RecvCtxMap &recv_varname_to_ctx); + + virtual void Start() = 0; + + virtual void Stop() = 0; + + virtual bool IsRunning() { return running_; } + + virtual void Clean() {} + + virtual bool Check(const int table_id) = 0; + virtual bool Check(const std::vector &var_tables) = 0; + + virtual void Send(const std::vector &var_names, + const framework::Scope &scope) = 0; + + virtual void RecvNoBarrier() {} + + virtual void Barrier() {} + + virtual void BarrierWithTable(uint32_t barrier_type) { + auto rets = _worker_ptr->barrier(barrier_table_id_, barrier_type); + rets.wait(); + } + + virtual void BarrierTriggerDecrement() {} + + virtual void BarrierTriggerReset(int init_counter) {} + + virtual void InitEnvs() = 0; + + virtual void InitImpl(const RpcCtxMap &send_varname_to_ctx, + const RecvCtxMap &recv_varname_to_ctx, + Scope *recv_scope) {} + + static Communicator *GetInstance() { return communicator_.get(); } + + static std::shared_ptr GetInstantcePtr() { + return communicator_; + } + + template + static Communicator *InitInstance( + const RpcCtxMap &send_ctx, const RecvCtxMap &recv_ctx, + const std::string &dist_desc, + const std::vector &host_sign_list, Scope *recv_scope, + const std::map &envs) { + std::call_once(init_flag_, &Communicator::InitWithRpcCtx, send_ctx, + recv_ctx, dist_desc, host_sign_list, recv_scope, + std::ref(envs)); + return communicator_.get(); + } + + // Init is called by InitInstance. + template + static void InitWithRpcCtx(const RpcCtxMap &send_ctx, + const RecvCtxMap &recv_ctx, + const std::string &dist_desc, + const std::vector &host_sign_list, + Scope *recv_scope, + const std::map &envs) { + if (communicator_.get() == nullptr) { + communicator_.reset(new T(std::ref(envs))); + communicator_->InitEnvs(); + communicator_->InitBrpcClient(dist_desc, host_sign_list); + communicator_->InitImpl(send_ctx, recv_ctx, recv_scope); + } + } + + PSClient *GetPsClient() { return _worker_ptr.get(); } + + std::shared_ptr GetPsClientPtr() { + return _worker_ptr; + } + + std::shared_ptr _worker_ptr; // pointer to worker + + protected: + bool running_ = false; + bool waiting_ = true; + bool flushing_ = false; + bool do_server_profiler_ = false; + static std::shared_ptr communicator_; + static std::once_flag init_flag_; + + std::unordered_map envs; + + // 计算每个shard 对 dense的存储量 + inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total, + uint32_t shard_num) { + return dense_dim_total / shard_num + 1; + } + + void init_gflag(const std::string &gflags); + paddle::distributed::PSParameter _ps_param; + paddle::distributed::PaddlePSEnvironment _ps_env; + int servers_ = 0; + int trainers_; + int trainer_id_ = 0; + int barrier_table_id_ = 0; + RpcCtxMap send_varname_to_ctx_; + RecvCtxMap recv_varname_to_ctx_; + + Scope *recv_scope_; // should be global scope + std::unique_ptr xpu_temp_scope_; + std::atomic _async_call_num{0}; +}; + +class AsyncCommunicator : public Communicator { + public: + AsyncCommunicator() : Communicator() {} + + explicit AsyncCommunicator(const std::map &envs) + : Communicator(envs) {} + + ~AsyncCommunicator(); + + void InitEnvs() { + independent_recv_ = static_cast( + std::stoi(envs.at("communicator_independent_recv_thread"))); + min_send_grad_num_before_recv_ = + std::stoi(envs.at("communicator_min_send_grad_num_before_recv")); + thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size")); + max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num")); + send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times")); + send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size")); + need_global_step_ = + static_cast(std::stoi(envs.at("need_global_step"))); + } + + void Start() override; + + void Stop() override; + + void InitImpl(const RpcCtxMap &send_varname_to_ctx, + const RecvCtxMap &recv_varname_to_ctx, + Scope *recv_scope) override; + + virtual void MainThread(); + virtual void RecvThread(); + + virtual bool Check(const int table_id); + virtual bool Check(const std::vector &var_tables); + + void Send(const std::vector &var_names, + const framework::Scope &scope) override; + + virtual void SendByCommunicator(); + + virtual void SendGlobalStep(int batches) {} + + virtual void RecvByCommunicator(); + + virtual void RecvNoBarrier(); + + virtual int BatchesCounter() { return 1; } + + virtual void BarrierSend() {} + + virtual void BarrierRecv() {} + + virtual void BarrierWeakUp() {} + + protected: + std::unordered_map>>> + send_varname_to_queue_; + std::unique_ptr<::ThreadPool> send_threadpool_{nullptr}; + + int min_send_grad_num_before_recv_; + int thread_pool_size_; + int max_merge_var_num_; + int send_wait_times_; + int send_queue_size_; + bool need_global_step_ = false; + bool independent_recv_ = true; + int parallel_task_nums_ = 0; + + std::unique_ptr main_thread_{nullptr}; + std::unique_ptr recv_thread_{nullptr}; + + std::unique_ptr send_scope_; // an independent scope + std::atomic_uint grad_num_{0}; // the num of gradient sent since last recv +}; + +class HalfAsyncCommunicator : public AsyncCommunicator { + public: + HalfAsyncCommunicator() {} + + explicit HalfAsyncCommunicator(const std::map &envs) + : AsyncCommunicator(envs) {} + + void InitEnvs() { + // enfore to recv after send + independent_recv_ = false; + min_send_grad_num_before_recv_ = 0; + thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size")); + max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num")); + send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times")); + send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size")); + need_global_step_ = + static_cast(std::stoi(envs.at("need_global_step"))); + + VLOG(0) << "HalfAsyncCommunicator Initialized"; + } + + void MainThread() override; + + void SendByCommunicator() override; + + void Clean() override; + + void Barrier() override; + + void BarrierTriggerDecrement() override; + + void BarrierTriggerReset(int initial_val) override; + + int BatchesCounter(); + + void BarrierWeakUp(); + + protected: + // mutex for Wait for barrier + std::mutex barrier_mutex_; + std::condition_variable barrier_cond_; + std::atomic barrier_trigger_{0}; + std::atomic barrier_counter_{0}; +}; + +class SyncCommunicator : public HalfAsyncCommunicator { + public: + SyncCommunicator() : HalfAsyncCommunicator() {} + + explicit SyncCommunicator(const std::map &envs) + : HalfAsyncCommunicator(envs) {} + + void InitEnvs() { + // enfore to recv after send + independent_recv_ = false; + min_send_grad_num_before_recv_ = 0; + max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num")); + send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times")); + thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size")); + send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size")); + need_global_step_ = + static_cast(std::stoi(envs.at("need_global_step"))); + + VLOG(0) << "SyncCommunicator Initialized"; + } + + void BarrierSend(); + + void BarrierRecv(); + + private: + std::vector pserver_endpoints_{}; +}; + +class GeoCommunicator : public AsyncCommunicator { + public: + GeoCommunicator() : AsyncCommunicator() {} + + explicit GeoCommunicator(const std::map &envs) + : AsyncCommunicator(envs) {} + + void InitImpl(const RpcCtxMap &send_varname_to_ctx, + const RecvCtxMap &recv_varname_to_ctx, + Scope *recv_scope) override; + + void InitParams(const RecvCtxMap &recv_varname_to_ctx) override; + void InitDense(std::vector &varnames, int table_id); + void InitSparse(const std::string &var_name, int table_id); + + void SendDense(const CommContext &send_ctx); + void RecvDense(const CommContext &send_ctx); + + std::vector MergeSparseIds(const std::string &varname); + void SendSparse(const std::string &varname, std::vector &sparse_ids, + int table_id, int ep_idx); + void RecvSparse(const std::string &varname, int table_id, int ep_idx); + + void MainThread() override; + + void InitEnvs() { + independent_recv_ = false; + min_send_grad_num_before_recv_ = 0; + send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times")); + thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size")); + // id_queue's size + max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num")); + send_queue_size_ = max_merge_var_num_; + VLOG(0) << "GeoCommunicator Initialized"; + } + + void Send(const std::vector &var_names, + const framework::Scope &scope) override; + + void SendByCommunicator() { return; } + + void SendGlobalStep(int batches) override { return; } + + void RecvByCommunicator() override { return; } + + inline std::string GradToParam(const std::string var_name) { + std::string param_name = var_name.substr(0, var_name.size() - 5); + return param_name; + } + + inline std::string SplitedGradToParam(const std::string delta_name) { + // delta_name: emb.delta0 + auto pos = delta_name.find(".block"); + std::string param_name = delta_name.substr(0, pos); + return param_name; + } + + private: + // parameter for delta calc and send + std::shared_ptr delta_scope_; + // parameter for storage the pserver param after last recv + std::shared_ptr old_scope_; + // parameter on pserver + std::shared_ptr pserver_scope_; + + std::unordered_map< + std::string, + std::shared_ptr>>>> + sparse_id_queues_; +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/env.cc b/paddle/fluid/distributed/service/env.cc new file mode 100644 index 0000000000000000000000000000000000000000..25bc2cc366aaacba32c22a5225d344f8618767d9 --- /dev/null +++ b/paddle/fluid/distributed/service/env.cc @@ -0,0 +1,19 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/service/env.h" + +namespace paddle { +namespace distributed {} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/env.h b/paddle/fluid/distributed/service/env.h new file mode 100644 index 0000000000000000000000000000000000000000..42f31717f7fba4203cdbd24d59cfa2d9973d5e8a --- /dev/null +++ b/paddle/fluid/distributed/service/env.h @@ -0,0 +1,284 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace paddle { +namespace distributed { + +struct PSHost { + std::string ip; + uint32_t port; + uint32_t rank; + + PSHost() = default; + PSHost(const std::string ip, uint32_t port, uint32_t rank) + : ip(ip), port(port), rank(rank) {} + + // |---ip---|---port---|--rank--| + // |-32bit--|--20bit---|--12bit-| + // for pslib + uint64_t serialize_to_uint64() { + uint64_t host_label = 0; + host_label = inet_addr(ip.c_str()); + host_label = host_label << 32; + host_label += (port << 12); + host_label += rank; + return host_label; + } + + void parse_from_uint64(uint64_t host_label) { + static uint64_t rank_label_mask = (1L << 12) - 1; + static uint64_t port_label_mask = (1L << 20) - 1; + rank = host_label & rank_label_mask; + port = (host_label >> 12) & port_label_mask; + uint32_t ip_addr = (host_label >> 32); + ip = inet_ntoa(*(in_addr *)&ip_addr); + } + + std::string to_string() { + std::stringstream s; + s << "host: " << ip; + s << " port: " << port; + s << " rank: " << rank; + s << " uint: " << serialize_to_uint64(); + return s.str(); + } + + // for open source parameter server + std::string serialize_to_string() { + std::stringstream s; + s << ip << ":"; + s << port << ":"; + s << rank; + return s.str(); + } + + void parse_from_string(std::string endpoint) { + std::vector endpoint_info; + string_split(endpoint, ':', &endpoint_info); + ip = endpoint_info[0]; + port = std::stoi(endpoint_info[1]); + rank = std::stoi(endpoint_info[2]); + } + + void string_split(const std::string &str, char sep, + std::vector *pieces, bool ignore_null = true) { + pieces->clear(); + if (str.empty()) { + if (!ignore_null) { + pieces->push_back(str); + } + return; + } + size_t pos = 0; + size_t next = str.find(sep, pos); + while (next != std::string::npos) { + pieces->push_back(str.substr(pos, next - pos)); + pos = next + 1; + next = str.find(sep, pos); + } + if (!str.substr(pos).empty()) { + pieces->push_back(str.substr(pos)); + } + } +}; + +class PSEnvironment { + public: + explicit PSEnvironment() {} + virtual ~PSEnvironment() {} + + virtual int32_t set_ps_servers(uint64_t *host_sign_list, int node_num) { + return 0; + } + virtual int32_t set_ps_servers( + const std::vector *host_endpoint_list, int node_num) { + return 0; + } + + virtual int32_t set_ps_clients(uint64_t *host_sign_list, int node_num) { + return 0; + } + + virtual int32_t set_ps_clients(std::string *host_endpoint_list, + int node_num) { + return 0; + } + virtual uint64_t get_local_host_sign() { return 0; } + virtual std::vector get_ps_servers() const { return _ps_server_list; } + virtual int32_t registe_ps_server(const std::string &ip, uint32_t port, + int32_t rank) { + return registe_ps_host(ip, port, rank, _ps_server_list, + _ps_server_sign_set); + } + + virtual std::vector get_ps_clients() const { return _ps_client_list; } + virtual int32_t registe_ps_client(const std::string &ip, uint32_t port, + int32_t rank) { + return registe_ps_host(ip, port, rank, _ps_client_list, + _ps_client_sign_set); + } + + virtual std::vector get_client_info() { + std::vector client_info; + for (auto &i : _ps_client_sign_set) { + client_info.push_back(i); + } + return client_info; + } + + virtual std::vector get_client_info(bool use_string_endpoint) { + if (use_string_endpoint) { + std::vector client_info; + for (auto &i : _ps_client_list) { + client_info.push_back(i.serialize_to_string()); + } + return client_info; + } + return {}; + } + + protected: + //注册一个host + virtual int32_t registe_ps_host(const std::string &ip, uint32_t port, + int32_t rank, std::vector &host_list, + std::unordered_set &sign_set) { + PSHost host; + host.ip = ip; + host.port = port; + host.rank = rank; + if (sign_set.count(rank) > 0) { + LOG(WARNING) << "ps-host :" << host.ip << ":" << host.port + << ", rank:" << host.rank + << " already register, ignore register"; + } else { + host_list.push_back(host); + sign_set.insert(rank); + } + // if (sign_set.count(host.serialize_to_uint64()) > 0) { + // LOG(WARNING) << "ps-host :" << host.ip << ":" << host.port + // << ", rank:" << host.rank + // << " already register, ignore register"; + // } else { + // host_list.push_back(host); + // sign_set.insert(host.serialize_to_uint64()); + // } + return 0; + } + + std::vector _ps_client_list; + std::unordered_set _ps_client_sign_set; // for unique filter + + std::vector _ps_server_list; + std::unordered_set _ps_server_sign_set; // for unique filter +}; + +class PaddlePSEnvironment : public PSEnvironment { + public: + explicit PaddlePSEnvironment() {} + virtual ~PaddlePSEnvironment() {} + + virtual int32_t set_ps_servers(uint64_t *host_sign_list, int node_num) { + _ps_server_list.clear(); + _ps_server_sign_set.clear(); + for (int i = 0; i < node_num; ++i) { + if (host_sign_list[i] > 0) { + PSHost host; + host.parse_from_uint64(host_sign_list[i]); + _ps_server_list.push_back(host); + _ps_server_sign_set.insert(host.serialize_to_uint64()); + } + } + std::sort( + _ps_server_list.begin(), _ps_server_list.end(), + [](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; }); + return 0; + } + + virtual int32_t set_ps_servers(const std::vector *host_sign_list, + int node_num) { + _ps_server_list.clear(); + _ps_server_sign_set.clear(); + for (int i = 0; i < node_num; ++i) { + if (host_sign_list->at(i) != "") { + PSHost host; + host.parse_from_string(host_sign_list->at(i)); + _ps_server_list.push_back(host); + _ps_server_sign_set.insert(host.rank); + } + } + std::sort( + _ps_server_list.begin(), _ps_server_list.end(), + [](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; }); + return 0; + } + + virtual int32_t set_ps_clients(uint64_t *host_sign_list, int node_num) { + _ps_client_list.clear(); + _ps_client_sign_set.clear(); + for (int i = 0; i < node_num; ++i) { + if (host_sign_list[i] > 0) { + PSHost host; + host.parse_from_uint64(host_sign_list[i]); + _ps_client_list.push_back(host); + _ps_client_sign_set.insert(host.serialize_to_uint64()); + } + } + std::sort( + _ps_client_list.begin(), _ps_client_list.end(), + [](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; }); + return 0; + } + + virtual int32_t set_ps_clients(std::vector *host_sign_list, + int node_num) { + _ps_client_list.clear(); + _ps_client_sign_set.clear(); + for (int i = 0; i < node_num; ++i) { + if (host_sign_list->at(i) != "") { + PSHost host; + host.parse_from_string(host_sign_list->at(i)); + _ps_client_list.push_back(host); + _ps_client_sign_set.insert(host.rank); + } + } + std::sort( + _ps_client_list.begin(), _ps_client_list.end(), + [](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; }); + return 0; + } + + virtual uint64_t get_local_host_sign() { + if (_ps_client_list.size() > 0) { + return _ps_client_list[0].serialize_to_uint64(); + } else { + return 0; + } + } +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/heter_client.cc b/paddle/fluid/distributed/service/heter_client.cc new file mode 100644 index 0000000000000000000000000000000000000000..f4d1f27377f0e6f8a01b288fc48007ed66b2005c --- /dev/null +++ b/paddle/fluid/distributed/service/heter_client.cc @@ -0,0 +1,168 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/service/heter_client.h" +#include +#include +#include "paddle/fluid/framework/channel.h" +#include "paddle/fluid/framework/data_feed.h" +#include "paddle/fluid/framework/device_worker.h" +#include "paddle/fluid/framework/io/fs.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/platform/profiler.h" +#include "paddle/fluid/platform/timer.h" + +DECLARE_int32(rpc_deadline); +namespace paddle { +namespace distributed { + +DEFINE_int32(pserver_timeout_ms, 10800000, "pserver request server timeout_ms"); + +std::shared_ptr HeterClient::s_instance_ = NULL; +bool HeterClient::is_initialized_ = false; + +void HeterClient::MainThread() { + while (running_) { + RpcProfilerControl(); + } +} + +void HeterClient::Stop() { + running_ = false; + if (!is_initialized_) { + VLOG(0) << "HeterClient is not inited, do nothing"; + } else { + if (main_thread_) { + auto status = StopHeterWorker(); + status.wait(); + main_thread_->join(); + main_thread_.reset(nullptr); + } + VLOG(1) << "HeterClient Stop Done"; + } +} + +void HeterClient::RpcProfilerControl() { + if (trainer_id_ == 0) { + if (!do_server_profiler_ && platform::IsProfileEnabled()) { + // send profiler start flag + do_server_profiler_ = true; + auto start_status = StartProfiler(); + start_status.wait(); + } else if (do_server_profiler_ && !platform::IsProfileEnabled()) { + // send profiler end flag + auto stop_status = StopProfiler(); + stop_status.wait(); + do_server_profiler_ = false; + } + } +} + +void HeterClient::CreateClient2XpuConnection() { + brpc::ChannelOptions options; + options.protocol = "baidu_std"; + options.connection_type = "single"; + options.timeout_ms = pserver_timeout_ms; + + xpu_channels_.resize(xpu_list_.size()); + for (size_t i = 0; i < xpu_list_.size(); ++i) { + xpu_channels_[i].reset(new brpc::Channel()); + if (xpu_channels_[i]->Init(xpu_list_[i].c_str(), "", &options) != 0) { + VLOG(0) << "HeterServer channel init fail"; + } + } +} + +void HeterClient::SendAndRecvAsync( + const std::vector& ep, const platform::DeviceContext& ctx, + const framework::Scope& scope, const std::string& message_name, + const std::vector& send_var_name, + const std::vector& recv_var_name) { + platform::RecordEvent record_event("HeterClient->SendAndRecvAsync"); + const platform::DeviceContext* p_ctx = &ctx; + const framework::Scope* p_scope = &scope; + const std::string message_name_val = message_name; + const std::vector send_var_name_val = send_var_name; + const std::vector recv_var_name_val = recv_var_name; + + VLOG(3) << "GRPCClient::SendAndRecv Begin, message_name: " + << message_name_val; + // Todo: get correct channel + int num = trainer_id_ % xpu_channels_.size(); + + brpc::Controller cntl; + cntl.set_timeout_ms(pserver_timeout_ms); + distributed::MultiVarMsg request, response; + auto& request_io_buffer = cntl.request_attachment(); + ::paddle::PsService_Stub stub(xpu_channels_[num].get()); + distributed::SerializeToMultiVarMsgAndIOBuf( + message_name_val, send_var_name_val, recv_var_name_val, *p_ctx, p_scope, + &request, &request_io_buffer); + stub.SendAndRecvVariable(&cntl, &request, &response, NULL); + PADDLE_ENFORCE_NE( + cntl.Failed(), true, + platform::errors::Unimplemented( + "HeterClient::SendAndRecv meets brpc error, error message is %s", + cntl.ErrorText())); + VLOG(4) << "call heter_worker success"; + auto& response_io_buffer = cntl.response_attachment(); + distributed::DeserializeFromMultiVarMsgAndIOBuf(response, &response_io_buffer, + ctx, p_scope); +} + +std::future HeterClient::SendCmd( + uint32_t table_id, int cmd_id, const std::vector& params) { + size_t request_call_num = xpu_channels_.size(); + paddle::distributed::DownpourBrpcClosure* closure = + new paddle::distributed::DownpourBrpcClosure( + request_call_num, [request_call_num, cmd_id](void* done) { + int ret = 0; + auto* closure = (paddle::distributed::DownpourBrpcClosure*)done; + for (size_t i = 0; i < request_call_num; ++i) { + if (closure->check_response(i, cmd_id) != 0) { + ret = -1; + break; + } + } + closure->set_promise_value(ret); + }); + auto promise = std::make_shared>(); + closure->add_promise(promise); + std::future fut = promise->get_future(); + for (size_t i = 0; i < request_call_num; ++i) { + closure->request(i)->set_cmd_id(cmd_id); + closure->request(i)->set_table_id(table_id); + closure->request(i)->set_client_id(trainer_id_); + for (const auto& param : params) { + closure->request(i)->add_params(param); + } + ::paddle::PsService_Stub rpc_stub(xpu_channels_[i].get()); + closure->cntl(i)->set_timeout_ms( + pserver_timeout_ms); // cmd msg don't limit timeout for save/load + rpc_stub.service(closure->cntl(i), closure->request(i), + closure->response(i), closure); + } + return fut; +} + +std::future HeterClient::StartProfiler() { + return SendCmd(-1, PS_START_PROFILER, {}); +} + +std::future HeterClient::StopProfiler() { + return SendCmd(-1, PS_STOP_PROFILER, {}); +} + +} // end namespace distributed +} // end namespace paddle diff --git a/paddle/fluid/distributed/service/heter_client.h b/paddle/fluid/distributed/service/heter_client.h new file mode 100644 index 0000000000000000000000000000000000000000..b1c268c3231f9224163dc1e54206a207dc460551 --- /dev/null +++ b/paddle/fluid/distributed/service/heter_client.h @@ -0,0 +1,127 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include "brpc/channel.h" +#include "brpc/controller.h" +#include "brpc/server.h" +#include "paddle/fluid/distributed/service/brpc_ps_client.h" +#include "paddle/fluid/distributed/service/brpc_utils.h" +#include "paddle/fluid/distributed/service/sendrecv.pb.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/variable_helper.h" +#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN + +namespace paddle { +namespace distributed { + +using MultiVarMsg = ::paddle::MultiVariableMessage; +using VarMsg = ::paddle::VariableMessage; + +typedef std::function HeterRpcCallbackFunc; + +class OnHeterRpcDone : public google::protobuf::Closure { + public: + OnHeterRpcDone(HeterRpcCallbackFunc func) : handler_(func) {} + virtual ~OnHeterRpcDone() {} + void Run() { + std::unique_ptr self_guard(this); + handler_(this); + } + + HeterRpcCallbackFunc handler_; + MultiVariableMessage response; + brpc::Controller cntl; +}; + +class HeterClient { + public: + virtual ~HeterClient() {} + + HeterClient() { + running_ = true; + main_thread_.reset( + new std::thread(std::bind(&HeterClient::MainThread, this))); + } + + void CreateClient2XpuConnection(); + + void SendAndRecvAsync(const std::vector& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& message_name, + const std::vector& send_var_name, + const std::vector& recv_var_name); + + // HeterClient singleton + static std::shared_ptr GetInstance( + const std::vector& endpoint, const int& trainer_id) { + if (NULL == s_instance_) { + is_initialized_ = true; + s_instance_.reset(new paddle::distributed::HeterClient()); + std::vector xpu_list = {endpoint}; + s_instance_->SetXpuList(endpoint); + s_instance_->SetTrainerID(trainer_id); + s_instance_->CreateClient2XpuConnection(); + } + return s_instance_; + } + + void Stop(); + + void MainThread(); + + void RpcProfilerControl(); + + std::future SendCmd(uint32_t table_id, int cmd_id, + const std::vector& params); + + std::future StartProfiler(); + std::future StopProfiler(); + std::future StopHeterWorker(); + + std::vector& GetXpuList() { return xpu_list_; } + + void SetXpuList(const std::vector& xpu_list) { + xpu_list_ = xpu_list; + }; + + void SetTrainerID(const int& trainer_id) { trainer_id_ = trainer_id; } + + private: + static std::shared_ptr s_instance_; + + protected: + static bool is_initialized_; + std::unique_ptr main_thread_{nullptr}; + std::vector> xpu_channels_; + DISABLE_COPY_AND_ASSIGN(HeterClient); + std::vector xpu_list_; + + bool running_ = false; + int trainer_id_; + bool do_server_profiler_ = false; +}; + +} // end namespace distributed +} // end namespace paddle diff --git a/paddle/fluid/distributed/service/heter_server.cc b/paddle/fluid/distributed/service/heter_server.cc new file mode 100644 index 0000000000000000000000000000000000000000..d9daf8be1ccb66d6a185d8828d1d5a53417249d2 --- /dev/null +++ b/paddle/fluid/distributed/service/heter_server.cc @@ -0,0 +1,91 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/service/heter_server.h" +#include +#include +#include "paddle/fluid/framework/fleet/heter_wrapper.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/platform/timer.h" + +namespace paddle { +namespace distributed { + +std::shared_ptr HeterServer::s_instance_ = NULL; + +void HeterServer::RegisterServiceHandler(std::string message_name, + HeterServiceHandler func) { + service_.RegisterServiceHandler(message_name, func); +} + +void HeterServer::StartHeterService() { + server_.AddService(&service_, brpc::SERVER_DOESNT_OWN_SERVICE); + brpc::ServerOptions options; + if (server_.Start(endpoint_.c_str(), &options) != 0) { + VLOG(0) << "heter server start fail"; + } else { + VLOG(0) << "heter server start success! listen on " << endpoint_; + } + + { + std::lock_guard lock(this->mutex_ready_); + ready_ = 1; + } + condition_ready_.notify_all(); + + server_.Join(); +} + +void HeterServer::SetEndPoint(std::string& endpoint) { + endpoint_ = endpoint; + service_.SetEndpoint(endpoint); +} + +void HeterServer::SetFanin(int& fan_in) { service_.SetFanin(fan_in); } + +void HeterServer::WaitServerReady() { + std::unique_lock lock(this->mutex_ready_); + condition_ready_.wait(lock, [=] { return this->ready_ == 1; }); +} + +int32_t HeterService::stop_profiler(const PsRequestMessage& request, + PsResponseMessage& response, + brpc::Controller* cntl) { + platform::DisableProfiler( + platform::EventSortingKey::kDefault, + string::Sprintf("heter_worker_%s_profile", endpoint_)); + return 0; +} + +int32_t HeterService::start_profiler(const PsRequestMessage& request, + PsResponseMessage& response, + brpc::Controller* cntl) { + platform::EnableProfiler(platform::ProfilerState::kAll); + return 0; +} + +int32_t HeterService::stop_heter_worker(const PsRequestMessage& request, + PsResponseMessage& response, + brpc::Controller* cntl) { + auto client_id = request.client_id(); + stop_cpu_worker_set_.insert(client_id); + if (stop_cpu_worker_set_.size() == fan_in_) { + is_exit_ = true; + } + return 0; +} + +} // end namespace distributed +} // end namespace paddle diff --git a/paddle/fluid/distributed/service/heter_server.h b/paddle/fluid/distributed/service/heter_server.h new file mode 100644 index 0000000000000000000000000000000000000000..07fff7adc6e94a24fe52efbca21605c4c4e4a44c --- /dev/null +++ b/paddle/fluid/distributed/service/heter_server.h @@ -0,0 +1,243 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include "brpc/channel.h" +#include "brpc/controller.h" +#include "brpc/server.h" +#include "paddle/fluid/distributed/service/brpc_utils.h" +#include "paddle/fluid/distributed/service/sendrecv.pb.h" +#include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/variable_helper.h" +#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN +#include "paddle/fluid/platform/profiler.h" + +namespace paddle { +namespace distributed { + +using MultiVarMsg = ::paddle::MultiVariableMessage; +using VarMsg = ::paddle::VariableMessage; + +class HeterService; +typedef int32_t (HeterService::*serviceHandlerFunc)( + const PsRequestMessage& request, PsResponseMessage& response, + brpc::Controller* cntl); + +typedef std::function HeterRpcCallbackFunc; +typedef std::function + HeterServiceHandler; + +class HeterService : public ::paddle::PsService { + public: + HeterService() { + _service_handler_map[PS_STOP_SERVER] = &HeterService::stop_heter_worker; + _service_handler_map[PS_START_PROFILER] = &HeterService::start_profiler; + _service_handler_map[PS_STOP_PROFILER] = &HeterService::stop_profiler; + } + + virtual ~HeterService() {} + + virtual void service(::google::protobuf::RpcController* controller, + const ::paddle::PsRequestMessage* request, + ::paddle::PsResponseMessage* response, + ::google::protobuf::Closure* done) { + brpc::ClosureGuard done_guard(done); + std::string log_label("ReceiveCmd-"); + + response->set_err_code(0); + response->set_err_msg(""); + brpc::Controller* cntl = static_cast(controller); + auto itr = _service_handler_map.find(request->cmd_id()); + if (itr == _service_handler_map.end()) { + std::string err_msg( + "undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:"); + err_msg.append(std::to_string(request->cmd_id())); + return; + } + serviceHandlerFunc handler_func = itr->second; + int service_ret = (this->*handler_func)(*request, *response, cntl); + if (service_ret != 0) { + response->set_err_code(service_ret); + response->set_err_msg("server internal error"); + } + }; + + void SendAndRecvVariable(::google::protobuf::RpcController* controller, + const MultiVarMsg* request, MultiVarMsg* response, + ::google::protobuf::Closure* done) { + brpc::ClosureGuard done_guard(done); + std::string message_name = request->message_name(); + auto itr = handler_map_.find(message_name); + brpc::Controller* cntl = static_cast(controller); + PADDLE_ENFORCE_NE( + itr, handler_map_.end(), + platform::errors::InvalidArgument( + "HeterService::SendAndRecvVariable Get illegal message_name: %s " + "which is not in HeterService::handler_map_", + message_name)); + itr->second(request, response, cntl); + } + + void RegisterServiceHandler(std::string message_name, + HeterServiceHandler func) { + handler_map_[message_name] = func; + } + + void SetEndpoint(const std::string& end_point) { endpoint_ = end_point; } + void SetFanin(const int& fan_in) { fan_in_ = fan_in; } + bool IsExit() { return is_exit_; } + + private: + int32_t stop_profiler(const PsRequestMessage& request, + PsResponseMessage& response, brpc::Controller* cntl); + + int32_t start_profiler(const PsRequestMessage& request, + PsResponseMessage& response, brpc::Controller* cntl); + + int32_t stop_heter_worker(const PsRequestMessage& request, + PsResponseMessage& response, + brpc::Controller* cntl); + + private: + std::string endpoint_; + std::unordered_map handler_map_; + std::unordered_map _service_handler_map; + std::unordered_set stop_cpu_worker_set_; + int fan_in_; + bool is_exit_ = false; +}; + +class HeterServer { + public: + virtual ~HeterServer() {} + + void Stop() { + server_.Stop(1000); + server_.Join(); + } + + bool IsExit() { return service_.IsExit(); } + + HeterServer() {} + + void RegisterServiceHandler(std::string message_name, + HeterServiceHandler func); + + void StartHeterService(); + + void SetEndPoint(std::string& endpoint); + void SetFanin(int& fan_in); + + // HeterWrapper singleton + static std::shared_ptr GetInstance() { + if (NULL == s_instance_) { + s_instance_.reset(new HeterServer()); + } + return s_instance_; + } + + void WaitServerReady(); + + private: + static std::shared_ptr s_instance_; + std::string endpoint_; + + protected: + brpc::Server server_; + HeterService service_; + DISABLE_COPY_AND_ASSIGN(HeterServer); + std::mutex mutex_ready_; + std::condition_variable condition_ready_; + int ready_; +}; + +class HeterRequestHandler { + public: + HeterRequestHandler() + : dev_ctx_(nullptr), + executor_(nullptr), + scope_(nullptr), + program_(nullptr) {} + + virtual ~HeterRequestHandler() {} + + void SetScope(framework::Scope* scope) { scope_ = scope; } + void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; } + void SetProgram(framework::ProgramDesc* program) { program_ = program; } + void SetExecutor(framework::Executor* executor) { executor_ = executor; } + + void SetGradToPreparedCtx( + std::unordered_map< + std::string, std::shared_ptr>* g) { + message_to_prepared_ctx_ = g; + } + + virtual int Handle(const MultiVarMsg* request, MultiVarMsg* response, + brpc::Controller* cntl) = 0; + + protected: + const platform::DeviceContext* dev_ctx_; + framework::Executor* executor_; + framework::Scope* scope_; + framework::ProgramDesc* program_; + + std::unordered_map>* + message_to_prepared_ctx_; +}; + +class RequestSendAndRecvHandler final : public HeterRequestHandler { + public: + RequestSendAndRecvHandler() {} + virtual ~RequestSendAndRecvHandler() {} + int Handle(const MultiVarMsg* request, MultiVarMsg* response, + brpc::Controller* cntl) override { + platform::RecordEvent record_event("RequestSendAndRecvHandler->Handle"); + auto& local_scope = scope_->NewScope(); + auto message_name = request->message_name(); + auto& request_io_buffer = cntl->request_attachment(); + distributed::DeserializeFromMultiVarMsgAndIOBuf( + *request, &request_io_buffer, *dev_ctx_, &local_scope); + executor_->RunPreparedContext( + (*message_to_prepared_ctx_)[message_name].get(), &local_scope, false); + + auto response_var_nums = request->recv_var_names_size(); + std::vector response_var_names(response_var_nums), + empty_var_names{}; + + for (int var_idx = 0; var_idx < response_var_nums; ++var_idx) { + response_var_names[var_idx] = request->recv_var_names(var_idx); + } + auto& response_io_buffer = cntl->response_attachment(); + distributed::SerializeToMultiVarMsgAndIOBuf( + message_name, response_var_names, empty_var_names, *dev_ctx_, + &local_scope, response, &response_io_buffer); + scope_->DeleteScope(&local_scope); + return 0; + } +}; + +} // end namespace distributed +} // end namespace paddle diff --git a/paddle/fluid/distributed/service/ps_client.cc b/paddle/fluid/distributed/service/ps_client.cc new file mode 100644 index 0000000000000000000000000000000000000000..dd5fb9c24b32cebd36f19822d97bde56171dac6d --- /dev/null +++ b/paddle/fluid/distributed/service/ps_client.cc @@ -0,0 +1,89 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/service/ps_client.h" + +#include + +#include "brpc/server.h" +#include "glog/logging.h" +#include "paddle/fluid/distributed/service/brpc_ps_client.h" +#include "paddle/fluid/distributed/table/table.h" + +namespace paddle { +namespace distributed { +REGISTER_CLASS(PSClient, BrpcPsClient); + +int32_t PSClient::configure( + const PSParameter &config, + const std::map> ®ions, + PSEnvironment &env, size_t client_id) { + _env = &env; + _config = config; + _dense_pull_regions = regions; + _client_id = client_id; + _config.mutable_worker_param() + ->mutable_downpour_worker_param() + ->mutable_downpour_table_param() + ->CopyFrom(_config.server_param() + .downpour_server_param() + .downpour_table_param()); + + const auto &work_param = _config.worker_param().downpour_worker_param(); + + for (size_t i = 0; i < work_param.downpour_table_param_size(); ++i) { + auto *accessor = CREATE_CLASS( + ValueAccessor, + work_param.downpour_table_param(i).accessor().accessor_class()); + accessor->configure(work_param.downpour_table_param(i).accessor()); + accessor->initialize(); + _table_accessors[work_param.downpour_table_param(i).table_id()].reset( + accessor); + } + return initialize(); +} + +PSClient *PSClientFactory::create(const PSParameter &ps_config) { + const auto &config = ps_config.server_param(); + if (!config.has_downpour_server_param()) { + LOG(ERROR) << "miss downpour_server_param in ServerParameter"; + return NULL; + } + + if (!config.downpour_server_param().has_service_param()) { + LOG(ERROR) << "miss service_param in ServerParameter.downpour_server_param"; + return NULL; + } + + if (!config.downpour_server_param().service_param().has_client_class()) { + LOG(ERROR) << "miss client_class in " + "ServerParameter.downpour_server_param.service_param"; + return NULL; + } + + const auto &service_param = config.downpour_server_param().service_param(); + PSClient *client = CREATE_CLASS(PSClient, service_param.client_class()); + if (client == NULL) { + LOG(ERROR) << "client is not registered, server_name:" + << service_param.client_class(); + return NULL; + } + + TableManager::instance().initialize(); + LOG(INFO) << "Create PSClient[" << service_param.client_class() + << "] success"; + return client; +} +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/ps_client.h b/paddle/fluid/distributed/service/ps_client.h new file mode 100644 index 0000000000000000000000000000000000000000..23b00b3c816088e26c8d05f090a8bc815038b0d9 --- /dev/null +++ b/paddle/fluid/distributed/service/ps_client.h @@ -0,0 +1,208 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include "paddle/fluid/distributed/ps.pb.h" +#include "paddle/fluid/distributed/service/env.h" +#include "paddle/fluid/distributed/service/sendrecv.pb.h" +#include "paddle/fluid/distributed/table/accessor.h" + +namespace paddle { +namespace distributed { + +typedef std::function PSClientCallBack; +class PSClientClosure : public google::protobuf::Closure { + public: + PSClientClosure(PSClientCallBack callback) : _callback(callback) {} + virtual ~PSClientClosure() {} + virtual void set_promise_value(int value) { + for (auto &promise : _promises) { + promise->set_value(value); + } + } + + void add_promise(std::shared_ptr> &promise) { + _promises.push_back(promise); + } + + protected: + PSClientCallBack _callback; + std::vector>> _promises; +}; + +class PSClient { + public: + PSClient() {} + virtual ~PSClient() {} + PSClient(PSClient &&) = delete; + PSClient(const PSClient &) = delete; + + virtual int32_t configure( + const PSParameter &config, + const std::map> + ®ions, + PSEnvironment &_env, size_t client_id) final; + + virtual int32_t create_client2client_connection( + int pserver_timeout_ms, int pserver_connect_timeout_ms, + int max_retry) = 0; + + // 触发table数据退场 + virtual std::future shrink(uint32_t table_id) = 0; + + // 全量table进行数据load + virtual std::future load(const std::string &epoch, + const std::string &mode) = 0; + // 指定table数据load + virtual std::future load(uint32_t table_id, const std::string &epoch, + const std::string &mode) = 0; + // 全量table数据save value_accessor根据mode,可能有不同的save条件 + virtual std::future save(const std::string &epoch, + const std::string &mode) = 0; + // 指定table数据save value_accessor根据mode,可能有不同的save条件 + virtual std::future save(uint32_t table_id, const std::string &epoch, + const std::string &mode) = 0; + + //清空table数据 + virtual std::future clear() = 0; + virtual std::future clear(uint32_t table_id) = 0; + + // pull dense的参数部分,并分块填充到本地网络参数中 + // start和num用于拉取部分参数 + // future结束前keys和values缓冲区不能再次使用 + // client将values按照区块拆包后送交多个sender + // sender聚集同一区块的请求,累计多个填充buffer + // server将参数区块中配置的某一维提取返回 + // 返回数据解包后填充到累计的多个buffer中 + virtual std::future pull_dense(Region *regions, size_t region_num, + size_t table_id) = 0; //保留 + + // firstly push dense param for parameter server + // this is neccessary because dense weight initialized in trainer on cold + // start + virtual std::future push_dense_param(const Region *regions, + size_t region_num, + size_t table_id) = 0; + + // 使用keys进行pull请求,结果填充values + // keys和values的个数均为num个,每个value占用select_size空间 + // future结束前keys和values缓冲区不能再次使用 + // 整合多个线程请求的keys,聚集并分散发送到server + // 返回结果后,遍历buffer并对values赋值 + virtual std::future pull_sparse(float **select_values, + size_t table_id, + const uint64_t *keys, + size_t num) = 0; + + virtual std::future print_table_stat(uint32_t table_id) = 0; + + // 确保所有积攒中的请求都发起发送 + virtual std::future flush() = 0; + // server优雅退出 + virtual std::future stop_server() = 0; + + // server profilera + virtual std::future start_profiler() = 0; + virtual std::future stop_profiler() = 0; + + virtual std::future barrier(size_t table_id, + uint32_t barrier_type) = 0; + + virtual std::future pull_geo_param(size_t table_id, + std::vector *values, + std::vector *keys, + int pserver_idx) = 0; + + virtual void finalize_worker() = 0; + // client to client, 消息发送 + virtual std::future send_client2client_msg(int msg_type, + int to_client_id, + const std::string &msg) { + LOG(FATAL) << "Did not implement"; + std::promise promise; + std::future fut = promise.get_future(); + promise.set_value(-1); + return fut; + } + // client2client消息处理,std::function ret (msg_type, from_client_id, msg) + typedef std::function MsgHandlerFunc; + virtual int registe_client2client_msg_handler(int msg_type, + MsgHandlerFunc handler) { + _msg_handler_map[msg_type] = handler; + return 0; + } + virtual int handle_client2client_msg(int msg_type, int from_client_id, + const std::string &msg) { + auto itr = _msg_handler_map.find(msg_type); + if (itr == _msg_handler_map.end()) { + LOG(WARNING) << "unknown client2client_msg type:" << msg_type; + return -1; + } + return itr->second(msg_type, from_client_id, msg); + } + + virtual ValueAccessor *table_accessor(size_t table_id) { + auto itr = _table_accessors.find(table_id); + if (itr == _table_accessors.end()) { + return NULL; + } + return itr->second.get(); + } + + virtual size_t get_server_nums() = 0; + + virtual std::future push_dense_raw_gradient( + int table_id, float *total_send_data, size_t total_send_data_size, + void *done) = 0; + + virtual std::future push_sparse_raw_gradient( + size_t table_id, const uint64_t *keys, const float **update_values, + size_t num, void *done) = 0; + + virtual std::future push_sparse_raw_gradient_partial( + size_t table_id, const uint64_t *keys, const float **update_values, + uint32_t num, void *done, int pserver_idx) = 0; + + virtual std::future push_sparse_param(size_t table_id, + const uint64_t *keys, + const float **update_values, + size_t num, void *done) = 0; + + protected: + virtual int32_t initialize() = 0; + size_t _client_id; + PSParameter _config; + std::map> + _dense_pull_regions; + PSEnvironment *_env; + std::unordered_map> _table_accessors; + std::unordered_map + _msg_handler_map; //处理client2client消息 +}; +REGISTER_REGISTERER(PSClient); + +class PSClientFactory { + public: + static PSClient *create(const PSParameter &config); +}; +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/sendrecv.proto b/paddle/fluid/distributed/service/sendrecv.proto new file mode 100644 index 0000000000000000000000000000000000000000..8f5c8baa2f82427bf62f4429eb622986d761af9e --- /dev/null +++ b/paddle/fluid/distributed/service/sendrecv.proto @@ -0,0 +1,113 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; +package paddle; +option cc_generic_services = true; +option cc_enable_arenas = true; + +enum PsCmdID { + PS_PULL_DENSE_TABLE = 0; + PS_PUSH_DENSE_TABLE = 1; + PS_PULL_SPARSE_TABLE = 2; + PS_PUSH_SPARSE_TABLE = 3; + PS_SHRINK_TABLE = 4; + PS_SAVE_ONE_TABLE = 5; + PS_SAVE_ALL_TABLE = 6; + PS_LOAD_ONE_TABLE = 7; + PS_LOAD_ALL_TABLE = 8; + PS_CLEAR_ONE_TABLE = 9; + PS_CLEAR_ALL_TABLE = 10; + PS_PUSH_DENSE_PARAM = 11; + PS_STOP_SERVER = 12; + PS_SAVE_ONE_CACHE_TABLE = 13; + PS_GET_CACHE_THRESHOLD = 14; + PS_CACHE_SHUFFLE = 15; + PS_COPY_TABLE = 16; + PS_COPY_TABLE_BY_FEASIGN = 17; + PS_PULL_SPARSE_TABLE_WITH_DEPENDENCY = 18; + PS_PUSH_SPARSE_TABLE_WITH_DEPENDENCY = 19; + PS_PRINT_TABLE_STAT = 20; + PS_SAVE_ONE_TABLE_PREFIX = 21; + PS_SAVE_ONE_TABLE_WITH_WHITELIST = 22; + PS_LOAD_ONE_TABLE_WITH_WHITELIST = 23; + PS_PULL_GEO_PARAM = 24; + PS_BARRIER = 25; + PS_PUSH_SPARSE_PARAM = 26; + PS_START_PROFILER = 27; + PS_STOP_PROFILER = 28; +} + +message PsRequestMessage { + required uint32 cmd_id = 1; + optional uint32 table_id = 2; + repeated bytes params = 3; + optional int32 client_id = 4; + optional bytes data = 5; +}; + +message PsResponseMessage { + required int32 err_code = 1 [ default = 0 ]; + required string err_msg = 2 [ default = "" ]; + optional bytes data = 3; +}; + +enum VarType { + LOD_TENSOR = 0; + SELECTED_ROWS = 1; +} + +message VariableMessage { + enum Type { + // Pod Types + BOOL = 0; + INT16 = 1; + INT32 = 2; + INT64 = 3; + FP16 = 4; + FP32 = 5; + FP64 = 6; + } + + message LodData { repeated int64 lod_data = 1; } + optional string varname = 1; + // TODO(Yancey1989): reference framework::proto::VarDesc::VarType + optional VarType type = 2; + // bool persistable is not needed for sending. + // tensor info: + optional Type data_type = 3; + repeated int64 dims = 4; + + // lod details: + optional int64 lod_level = 5; + repeated LodData lod = 6; + // selected_rows height, aka. original dim0 + optional int64 slr_height = 7; + // tensor data + optional bytes data = 8; +} + +// for SendAndRecv RPC method +message MultiVariableMessage { + // message flags + required string message_name = 1; + repeated string send_var_names = 2; + repeated string recv_var_names = 3; + repeated VariableMessage var_messages = 4; +}; + +service PsService { + rpc service(PsRequestMessage) returns (PsResponseMessage); + rpc SendAndRecvVariable(MultiVariableMessage) returns (MultiVariableMessage); +}; \ No newline at end of file diff --git a/paddle/fluid/distributed/service/server.cc b/paddle/fluid/distributed/service/server.cc new file mode 100644 index 0000000000000000000000000000000000000000..1582b8739c1775fc828d7ab29ea9b0e61ffc5bef --- /dev/null +++ b/paddle/fluid/distributed/service/server.cc @@ -0,0 +1,87 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/service/server.h" +#include "glog/logging.h" +#include "paddle/fluid/distributed/service/brpc_ps_server.h" +#include "paddle/fluid/distributed/table/table.h" + +namespace paddle { +namespace distributed { + +REGISTER_CLASS(PSServer, BrpcPsServer); +REGISTER_CLASS(PsBaseService, PsService); + +PSServer *PSServerFactory::create(const PSParameter &ps_config) { + const auto &config = ps_config.server_param(); + + if (!config.has_downpour_server_param()) { + LOG(ERROR) << "miss downpour_server_param in ServerParameter"; + return NULL; + } + + if (!config.downpour_server_param().has_service_param()) { + LOG(ERROR) << "miss service_param in ServerParameter.downpour_server_param"; + return NULL; + } + + if (!config.downpour_server_param().service_param().has_server_class()) { + LOG(ERROR) << "miss server_class in " + "ServerParameter.downpour_server_param.service_param"; + return NULL; + } + + const auto &service_param = config.downpour_server_param().service_param(); + PSServer *server = CREATE_CLASS(PSServer, service_param.server_class()); + if (server == NULL) { + LOG(ERROR) << "server is not registered, server_name:" + << service_param.server_class(); + return NULL; + } + TableManager::instance().initialize(); + return server; +} + +int32_t PSServer::configure(const PSParameter &config, PSEnvironment &env, + size_t server_rank) { + _config = config.server_param(); + _rank = server_rank; + _environment = &env; + _shuffled_ins = + paddle::framework::MakeChannel>(); + const auto &downpour_param = _config.downpour_server_param(); + + uint32_t barrier_table = UINT32_MAX; + + for (size_t i = 0; i < downpour_param.downpour_table_param_size(); ++i) { + auto *table = CREATE_CLASS( + Table, downpour_param.downpour_table_param(i).table_class()); + + if (downpour_param.downpour_table_param(i).table_class() == + "BarrierTable") { + barrier_table = downpour_param.downpour_table_param(i).table_id(); + } + table->initialize(downpour_param.downpour_table_param(i), + config.fs_client_param()); + _table_map[downpour_param.downpour_table_param(i).table_id()].reset(table); + } + + if (barrier_table != UINT32_MAX) { + _table_map[barrier_table]->set_table_map(&_table_map); + } + + return initialize(); +} +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/server.h b/paddle/fluid/distributed/service/server.h new file mode 100644 index 0000000000000000000000000000000000000000..4faa0f9db2c4c510f7e010616811f1a5fd10af43 --- /dev/null +++ b/paddle/fluid/distributed/service/server.h @@ -0,0 +1,150 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include "butil/endpoint.h" +#include "google/protobuf/service.h" +#include "paddle/fluid/distributed/common/registerer.h" +#include "paddle/fluid/distributed/ps.pb.h" +#include "paddle/fluid/distributed/service/env.h" +#include "paddle/fluid/distributed/service/sendrecv.pb.h" +#include "paddle/fluid/framework/channel.h" + +namespace paddle { +namespace distributed { + +class Table; + +class PSServer { + public: + PSServer() {} + virtual ~PSServer() {} + PSServer(PSServer &&) = delete; + PSServer(const PSServer &) = delete; + + virtual int32_t configure(const PSParameter &config, PSEnvironment &env, + size_t server_rank) final; + + // return server_ip + virtual std::string ip() { return butil::my_ip_cstr(); } + // return server_port + virtual int32_t port() = 0; + + virtual uint64_t start(const std::string &ip, uint32_t port) = 0; + virtual int32_t stop() = 0; + + inline size_t rank() const { return _rank; } + + inline PSEnvironment *environment() { return _environment; } + + inline const ServerParameter *config() const { return &_config; } + inline Table *table(size_t table_id) { + auto itr = _table_map.find(table_id); + if (itr != _table_map.end()) { + return itr->second.get(); + } + return NULL; + } + + inline std::unordered_map> *table() { + return &_table_map; + } + + typedef std::function MsgHandlerFunc; + virtual int registe_pserver2pserver_msg_handler(int msg_type, + MsgHandlerFunc handler) { + _msg_handler_map[msg_type] = handler; + return 0; + } + + paddle::framework::Channel> _shuffled_ins; + + protected: + virtual int32_t initialize() = 0; + + protected: + size_t _rank; + ServerParameter _config; + PSEnvironment *_environment; + std::unordered_map> _table_map; + std::unordered_map _msg_handler_map; +}; + +REGISTER_REGISTERER(PSServer); + +typedef std::function PServerCallBack; + +class PServerClosure : public google::protobuf::Closure { + public: + PServerClosure(PServerCallBack callback) : _callback(callback) {} + virtual ~PServerClosure() {} + virtual void set_promise_value(int value) { + for (auto &promise : _promises) { + promise->set_value(value); + } + } + void add_promise(std::shared_ptr> &promise) { + _promises.push_back(promise); + } + + protected: + PServerCallBack _callback; + std::vector>> _promises; +}; + +class PsBaseService : public PsService { + public: + PsBaseService() : _rank(0), _server(NULL), _config(NULL) {} + virtual ~PsBaseService() {} + + virtual int32_t configure(PSServer *server) { + _server = server; + _rank = _server->rank(); + _config = _server->config(); + return 0; + } + virtual void service(::google::protobuf::RpcController *controller, + const ::paddle::PsRequestMessage *request, + ::paddle::PsResponseMessage *response, + ::google::protobuf::Closure *done) override = 0; + + virtual void set_response_code(PsResponseMessage &response, int err_code, + const char *err_msg) { + response.set_err_msg(err_msg); + response.set_err_code(err_code); + LOG(WARNING) << "Resonse err_code:" << err_code << " msg:" << err_msg; + } + + virtual int32_t initialize() = 0; + + protected: + size_t _rank; + PSServer *_server; + const ServerParameter *_config; +}; +REGISTER_REGISTERER(PsBaseService); + +class PSServerFactory { + public: + static PSServer *create(const PSParameter &config); +}; +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/service.cc b/paddle/fluid/distributed/service/service.cc new file mode 100644 index 0000000000000000000000000000000000000000..40a6d2e122718790d1970e7697c05ff862e6f738 --- /dev/null +++ b/paddle/fluid/distributed/service/service.cc @@ -0,0 +1,129 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/fluid/distributed/service/service.h" + +#include +#include +#include +#include +#include "paddle/fluid/distributed/service/communicator.h" +#include "paddle/fluid/string/string_helper.h" + +using namespace std; + +namespace paddle { +namespace distributed { + +paddle::distributed::PSParameter load_from_prototxt( + const std::string& filename) { + paddle::distributed::PSParameter param; + int file_descriptor = open(filename.c_str(), O_RDONLY); + + if (file_descriptor == -1) { + VLOG(3) << "FATAL: fail to parse " << filename; + exit(-1); + } + + google::protobuf::io::FileInputStream fileInput(file_descriptor); + if (!google::protobuf::TextFormat::Parse(&fileInput, ¶m)) { + VLOG(3) << "FATAL: fail to parse " << filename; + exit(-1); + } + + close(file_descriptor); + return param; +} + +void PSCore::init_gflag(const std::string& gflags) { + LOG(INFO) << "Init With Gflags:" << gflags; + std::vector flags = paddle::string::split_string(gflags); + if (flags.size() < 1) { + flags.push_back("-max_body_size=314217728"); + flags.push_back("-bthread_concurrency=40"); + flags.push_back("-socket_max_unwritten_bytes=2048000000"); + flags.push_back("-max_connection_pool_size=1950"); + } + auto it = flags.begin(); + flags.insert(it, "exe default"); + char* flags_ptr[flags.size()]; + for (size_t i = 0; i < flags.size(); ++i) { + flags_ptr[i] = (char*)(flags[i].c_str()); + } + int params_cnt = flags.size(); + char** params_ptr = &(flags_ptr[0]); + ::google::ParseCommandLineFlags(¶ms_cnt, ¶ms_ptr, true); +} + +int PSCore::init_server(const std::string& dist_desc, + const std::vector* host_sign_list, + int node_num, int index) { + google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param); + init_gflag(_ps_param.init_gflags()); + _ps_env = paddle::distributed::PaddlePSEnvironment(); + _ps_env.set_ps_servers(host_sign_list, node_num); + int ret = 0; + _server_ptr = std::shared_ptr( + paddle::distributed::PSServerFactory::create(_ps_param)); + ret = _server_ptr->configure(_ps_param, _ps_env, index); + CHECK(ret == 0) << "failed to configure server"; + return ret; +} + +int PSCore::init_worker( + const std::string& dist_desc, + const std::map>& regions, + const std::vector* host_sign_list, int node_num, int index) { + google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param); + init_gflag(_ps_param.init_gflags()); + _ps_env = paddle::distributed::PaddlePSEnvironment(); + _ps_env.set_ps_servers(host_sign_list, node_num); + int ret = 0; + VLOG(1) << "PSCore::init_worker"; + auto* communicator = Communicator::GetInstance(); + ret = communicator->GetPsClient()->configure(_ps_param, regions, _ps_env, + index); + communicator->Start(); + return ret; +} + +std::vector PSCore::get_client_info() { + return _ps_env.get_client_info(); +} + +int PSCore::create_client2client_connection(int pserver_timeout_ms, + int pserver_connect_timeout_ms, + int max_retry) { + int ret = _worker_ptr->create_client2client_connection( + pserver_timeout_ms, pserver_connect_timeout_ms, max_retry); + return ret; +} + +uint64_t PSCore::run_server(const std::string& ip, uint32_t port) { + return _server_ptr->start(ip, port); +} + +int PSCore::finalize_worker() { + _worker_ptr->finalize_worker(); + return 0; +} + +int PSCore::stop_server() { + auto stop_status = _worker_ptr->stop_server(); + stop_status.wait(); + return 0; +} +paddle::distributed::PSParameter* PSCore::get_param() { return &_ps_param; } +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/service.h b/paddle/fluid/distributed/service/service.h new file mode 100644 index 0000000000000000000000000000000000000000..97cb864e344bf8a152a494a5365ce7a22f5eb4c8 --- /dev/null +++ b/paddle/fluid/distributed/service/service.h @@ -0,0 +1,64 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include + +#include +#include "paddle/fluid/distributed/ps.pb.h" +#include "paddle/fluid/distributed/service/ps_client.h" +#include "paddle/fluid/distributed/service/sendrecv.pb.h" +#include "paddle/fluid/distributed/service/server.h" + +namespace paddle { +namespace distributed { + +class PSCore { + public: + explicit PSCore() {} + virtual ~PSCore() {} + + virtual int init_server(const std::string& dist_desc, + const std::vector* host_sign_list, + int node_num, int index); + virtual int init_worker( + const std::string& dist_desc, + const std::map>& + regions, + const std::vector* host_sign_list, int node_num, int index); + virtual uint64_t run_server(const std::string& ip, uint32_t port); + virtual int stop_server(); + virtual int finalize_worker(); + virtual std::vector get_client_info(); + virtual int create_client2client_connection(int pserver_timeout_ms, + int pserver_connect_timeout_ms, + int max_retry); + std::shared_ptr + _server_ptr; // pointer to server + std::shared_ptr + _worker_ptr; // pointer to worker + virtual paddle::distributed::PSParameter* get_param(); + + private: + void init_gflag(const std::string& gflags); + paddle::distributed::PSParameter _ps_param; + paddle::distributed::PaddlePSEnvironment _ps_env; +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/table/CMakeLists.txt b/paddle/fluid/distributed/table/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..c0f8470b36b01d9bf4515e44a62f59ea8a0519a4 --- /dev/null +++ b/paddle/fluid/distributed/table/CMakeLists.txt @@ -0,0 +1,19 @@ +set_property(GLOBAL PROPERTY TABLE_DEPS string_helper) + +get_property(TABLE_DEPS GLOBAL PROPERTY TABLE_DEPS) + +set_source_files_properties(common_dense_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +set_source_files_properties(common_sparse_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +set_source_files_properties(sparse_geo_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +set_source_files_properties(barrier_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + +cc_library(common_table SRCS common_sparse_table.cc common_dense_table.cc sparse_geo_table.cc barrier_table.cc DEPS ${TABLE_DEPS} device_context string_helper simple_threadpool xxhash generator) + +set_source_files_properties(tensor_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +set_source_files_properties(tensor_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_library(tensor_accessor SRCS tensor_accessor.cc DEPS ${TABLE_DEPS} eigen3 ps_framework_proto device_context) +cc_library(tensor_table SRCS tensor_table.cc DEPS ps_framework_proto proto_desc enforce executor tensor device_context simple_threadpool gflags glog ) + +set_source_files_properties(table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + +cc_library(table SRCS table.cc DEPS common_table tensor_table tensor_accessor ps_framework_proto string_helper device_context gflags glog boost) diff --git a/paddle/fluid/distributed/table/accessor.h b/paddle/fluid/distributed/table/accessor.h new file mode 100644 index 0000000000000000000000000000000000000000..a07a8e10b16f64539a83ebf55bbe4c43dbb7fef2 --- /dev/null +++ b/paddle/fluid/distributed/table/accessor.h @@ -0,0 +1,170 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include +#include "paddle/fluid/distributed/common/registerer.h" +#include "paddle/fluid/distributed/ps.pb.h" + +namespace paddle { +namespace distributed { +struct FsDataConverter { + std::string converter; + std::string deconverter; +}; + +struct Region { + Region() : data(NULL), size(0) {} + Region(char* data, size_t data_num) : data(data), size(data_num) {} + Region(float* data, size_t data_num) + : data(reinterpret_cast(data)), size(data_num << 2) {} + Region(int16_t* data, size_t data_num) + : data(reinterpret_cast(data)), size(data_num << 1) {} + Region(int32_t* data, size_t data_num) + : data(reinterpret_cast(data)), size(data_num << 2) {} + Region(int64_t* data, size_t data_num) + : data(reinterpret_cast(data)), size(data_num << 3) {} + char* data; + size_t size; +}; + +struct DataConverter { + int param; + std::string converter; + std::string deconverter; +}; + +class ValueAccessor { + public: + explicit ValueAccessor(){}; + virtual ~ValueAccessor(){}; + + virtual int configure(const TableAccessorParameter& parameter) { + _config = parameter; + // data_convert结构体初始化 + if (_config.table_accessor_save_param_size() != 0) { + for (int i = 0; i < _config.table_accessor_save_param_size(); ++i) { + int param = _config.table_accessor_save_param(i).param(); + std::string converter = + _config.table_accessor_save_param(i).converter(); + std::string deconverter = + _config.table_accessor_save_param(i).deconverter(); + _data_coverter_map[param] = std::make_shared(); + *(_data_coverter_map[param]) = {param, converter, deconverter}; + } + } + return 0; + } + virtual int initialize() = 0; + + // value维度 + virtual size_t dim() = 0; + // value各个维度的size + virtual size_t dim_size(size_t dim) = 0; + // value各维度相加总size + virtual size_t size() = 0; + + // value中mf动态长度部分总size大小, sparse下生效 + virtual size_t mf_size() { return 0; } + virtual bool need_extend_mf(float* value) { return false; } + virtual bool has_mf(size_t size) { return false; } + // pull value维度 + virtual size_t select_dim() = 0; + // pull value各个维度的size + virtual size_t select_dim_size(size_t dim) = 0; + // pull value各维度相加总size + virtual size_t select_size() = 0; + // push value维度 + virtual size_t update_dim() = 0; + // push value各个维度的size + virtual size_t update_dim_size(size_t dim) = 0; + // push value各维度相加总size + virtual size_t update_size() = 0; + // fea total for dense + virtual size_t fea_dim() { return _config.fea_dim(); } + // converter for save + virtual std::string get_converter(int param) { + auto itr = _data_coverter_map.find(param); + if (itr == _data_coverter_map.end()) { + return ""; + } else { + return (*itr).second->converter; + } + } + // deconverter for load + virtual std::string get_deconverter(int param) { + auto itr = _data_coverter_map.find(param); + if (itr == _data_coverter_map.end()) { + return ""; + } else { + return (*itr).second->deconverter; + } + } + // 判断该value是否进行shrink + virtual bool shrink(float* value) = 0; + + // 判断该value是否在save阶段dump, + // param作为参数用于标识save阶段,如downpour的xbox与batch_model + virtual bool save(float* value, int param) = 0; + // update delta_score and unseen_days after save + virtual void update_stat_after_save(float* value, int param) {} + + // keys不存在时,为values生成随机值 + virtual int32_t create(float** value, size_t num) = 0; + virtual bool create_value(int type, const float* value) { return true; } + // 从values中选取到select_values中 + virtual int32_t select(float** select_values, const float** values, + size_t num) = 0; + // 将update_values聚合到一起 + virtual int32_t merge(float** update_values, + const float** other_update_values, size_t num) = 0; + // 将update_values聚合到一起,通过it.next判定是否进入下一个key + // virtual int32_t merge(float** update_values, iterator it); + // 将update_values更新应用到values中 + virtual int32_t update(float** values, const float** update_values, + size_t num) = 0; + + // used to save model, will filter feature + virtual std::string parse_to_string(const float* value, int param) = 0; + // parse value from string, used to load model + virtual int32_t parse_from_string(const std::string& data, float* value) = 0; + + virtual FsDataConverter converter(int param) { + FsDataConverter data_convert; + data_convert.converter = this->get_converter(param); + data_convert.deconverter = this->get_deconverter(param); + return data_convert; + } + + virtual int set_weight(float** values, const float** update_values, + size_t num) { + return 0; + } + + virtual float get_field(float* value, const std::string& name) { return 0.0; } + + protected: + size_t _value_size; + size_t _select_value_size; + size_t _update_value_size; + TableAccessorParameter _config; + std::unordered_map> + _data_coverter_map; +}; +REGISTER_REGISTERER(ValueAccessor); +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/table/barrier_table.cc b/paddle/fluid/distributed/table/barrier_table.cc new file mode 100644 index 0000000000000000000000000000000000000000..d1e545a133e6163f75c8c2ba756be3ee420e3916 --- /dev/null +++ b/paddle/fluid/distributed/table/barrier_table.cc @@ -0,0 +1,78 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include // NOLINT +#include "paddle/fluid/distributed/common/utils.h" +#include "paddle/fluid/distributed/table/common_table.h" + +namespace paddle { +namespace distributed { + +int32_t BarrierTable::initialize() { + auto trainers = _config.common().trainer_num(); + trigger_.store(trainers); + + for (int x = 0; x < trainers; ++x) { + trainer_all_.insert(x); + } + VLOG(1) << "BarrierTable init trigger: " << trigger_.load(); + return 0; +} + +// 0: send_barrier 1: recv_barrier 2: complete +int32_t BarrierTable::barrier(const uint32_t trainer_id, + const std::string barrier_type) { + std::unique_lock lock(mutex_); + + if (barrier_type == "2") { + trigger_.fetch_sub(1, std::memory_order::memory_order_relaxed); + VLOG(1) << "trigger sub to : " << trigger_.load(); + } else { + trainer_ids_.insert(trainer_id); + VLOG(1) << "barrier type: " << barrier_type + << " add trainer id: " << trainer_id; + } + + if (trainer_ids_.size() < trigger_.load()) { + std::vector diffs(trainer_all_.size()); + auto iter = std::set_difference(trainer_all_.begin(), trainer_all_.end(), + trainer_ids_.begin(), trainer_ids_.end(), + diffs.begin()); + diffs.resize(iter - diffs.begin()); + + auto diff = to_string(diffs); + VLOG(1) << "still need trainers: " << diff; + trainer_wait_.wait(lock, [&] { return trainer_ids_.size() == 0; }); + } else { + VLOG(1) << "barrier table optimize begin"; + for (auto& x : *table_map_) { + auto table = x.second; + table->pour(); + } + VLOG(1) << "barrier table optimize done"; + + trainer_ids_.clear(); + trainer_wait_.notify_all(); + } + return 0; +} + +int32_t BarrierTable::set_table_map( + std::unordered_map>* table_map) { + table_map_ = table_map; + return 0; +} + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/table/common_dense_table.cc b/paddle/fluid/distributed/table/common_dense_table.cc new file mode 100644 index 0000000000000000000000000000000000000000..e3d481f32eb8881505514281544ddd92b0d8f921 --- /dev/null +++ b/paddle/fluid/distributed/table/common_dense_table.cc @@ -0,0 +1,156 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/table/common_dense_table.h" +#include "paddle/fluid/distributed/common/utils.h" + +namespace paddle { +namespace distributed { + +void CommonDenseTable::create_initializer(const std::string& attr, + const std::string& name) { + auto slices = string::split_string(attr, "&"); + + if (slices[0] == "gaussian_random") { + initializers_[name] = new GaussianInitializer(slices); + } else if (slices[0] == "fill_constant") { + initializers_[name] = new FillConstantInitializer(slices); + } else if (slices[0] == "uniform_random") { + initializers_[name] = new UniformInitializer(slices); + } else { + PADDLE_THROW( + platform::errors::InvalidArgument("%s can not be supported", name)); + } +} + +int32_t CommonDenseTable::initialize() { + _shards_task_pool.resize(task_pool_size_); + for (int i = 0; i < _shards_task_pool.size(); ++i) { + _shards_task_pool[i].reset(new ::ThreadPool(1)); + } + + sync = _config.common().sync(); + VLOG(1) << "table " << _config.common().table_name() << " is sync: " << sync; + + initialize_value(); + initialize_optimizer(); + return 0; +} + +int32_t CommonDenseTable::initialize_value() { + auto common = _config.common(); + int size = static_cast(common.params().size()); + values_.resize(size); + for (int x = 0; x < size; ++x) { + auto& varname = common.params()[x]; + auto& dim = common.dims()[x]; + if (varname == "Param") { + param_dim_ = dim; + param_idx_ = x; + } + auto& initializer = common.initializers()[x]; + + create_initializer(initializer, varname); + values_[x].resize(dim); + names_index_[varname] = x; + + for (int y = 0; y < dim; ++y) { + values_[x][y] = initializers_[varname]->GetValue(); + } + } + + pull_reservoir_ = ReservoirValue(param_dim_); + return 0; +} + +int32_t CommonDenseTable::initialize_optimizer() { + auto common = _config.common(); + auto name = common.name(); + auto attrs = common.attributes(); + + if (name == "sgd") { + optimizer_ = std::make_shared(common, &values_); + } else if (name == "adam") { + optimizer_ = std::make_shared(common, &values_); + } else if (name == "sum") { + optimizer_ = std::make_shared(common, &values_); + } else { + VLOG(0) << "init optimizer failed"; + } + VLOG(0) << "init optimizer " << name << " done"; + return 0; +} + +int32_t CommonDenseTable::pull_dense(float* pull_values, size_t num) { + std::copy(values_[param_idx_].begin(), values_[param_idx_].end(), + pull_values); + return 0; +} + +int32_t CommonDenseTable::push_dense_param(const float* values, size_t num) { + PADDLE_ENFORCE_GE( + num, param_dim_, + paddle::platform::errors::InvalidArgument( + "update desne param numel expected %d, but got %d", param_dim_, num)); + std::copy_n(values, param_dim_, values_[param_idx_].begin()); + return 0; +} + +int32_t CommonDenseTable::pour() { + _push_dense(pull_reservoir_.values.data(), pull_reservoir_.values.size()); + pull_reservoir_.reset(); + return 0; +} + +int32_t CommonDenseTable::push_dense(const float* values, size_t num) { + if (sync) { + std::future task = + _shards_task_pool[0]->enqueue([this, &values]() -> int { + pull_reservoir_.add(values, param_dim_); + return 0; + }); + task.wait(); + } else { + _push_dense(values, num); + } + return 0; +} + +int32_t CommonDenseTable::_push_dense(const float* values, size_t num) { + PADDLE_ENFORCE_GE( + num, param_dim_, + paddle::platform::errors::InvalidArgument( + "update desne numel expected %d, but got %d", param_dim_, num)); + + std::vector buckets = bucket(param_dim_, task_pool_size_); + std::vector> tasks(task_pool_size_); + + for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) { + tasks[shard_id] = _shards_task_pool[shard_id]->enqueue( + [this, shard_id, &buckets, &values]() -> int { + auto begin = buckets[shard_id]; + auto end = buckets[shard_id + 1]; + optimizer_->update(values, param_dim_, begin, end); + return 0; + }); + } + + for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) { + tasks[shard_id].wait(); + } + return 0; +} + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/table/common_dense_table.h b/paddle/fluid/distributed/table/common_dense_table.h new file mode 100644 index 0000000000000000000000000000000000000000..eb97f3f26416a905020bcf722aee182dc2510de0 --- /dev/null +++ b/paddle/fluid/distributed/table/common_dense_table.h @@ -0,0 +1,80 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include "Eigen/Dense" +#include "paddle/fluid/distributed/table/accessor.h" +#include "paddle/fluid/distributed/table/common_table.h" +#include "paddle/fluid/distributed/table/depends/dense.h" +#include "paddle/fluid/distributed/table/depends/initializers.h" +#include "paddle/fluid/string/string_helper.h" + +namespace paddle { +namespace distributed { + +class CommonDenseTable : public DenseTable { + public: + explicit CommonDenseTable() {} + virtual ~CommonDenseTable() {} + virtual int32_t initialize() override; + virtual int32_t initialize_shard() override { return 0; } + virtual void create_initializer(const std::string& attr, + const std::string& name); + virtual int32_t initialize_value(); + virtual int32_t initialize_optimizer(); + virtual int32_t pull_dense(float* pull_values, size_t num) override; + virtual int32_t push_dense_param(const float* values, size_t num) override; + virtual int32_t push_dense(const float* values, size_t num) override; + virtual int32_t pour() override; + + int32_t load(const std::string& path, const std::string& param) override { + VLOG(0) << "Dense table may load by " + "paddle.distributed.fleet.init_server"; + return 0; + } + + int32_t save(const std::string& path, const std::string& param) override { + VLOG(0) + << "Dense table may be saved by " + "paddle.distributed.fleet.save_persistables/save_inference_model"; + return 0; + } + + virtual int32_t flush() override { return 0; } + virtual int32_t shrink() override { return 0; } + virtual void clear() override { return; } + + protected: + int32_t _push_dense(const float* values, size_t num); + + private: + const int task_pool_size_ = 1; + bool sync = true; + std::vector> _shards_task_pool; + int param_dim_ = 0; + int param_idx_ = 0; + std::shared_ptr optimizer_; + std::vector> values_; + ReservoirValue pull_reservoir_; + std::unordered_map initializers_; + std::unordered_map names_index_; +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/table/common_sparse_table.cc b/paddle/fluid/distributed/table/common_sparse_table.cc new file mode 100644 index 0000000000000000000000000000000000000000..288f034c4bb3a67c202d2a4033cd43b3b71c66cc --- /dev/null +++ b/paddle/fluid/distributed/table/common_sparse_table.cc @@ -0,0 +1,521 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/table/common_sparse_table.h" +#include +#include +#include "paddle/fluid/distributed/common/utils.h" +#include "paddle/fluid/distributed/table/depends/large_scale_kv.h" +#include "paddle/fluid/framework/generator.h" +#include "paddle/fluid/string/printf.h" +#include "paddle/fluid/string/string_helper.h" + +namespace paddle { +namespace distributed { + +struct Meta { + std::string param; + int shard_id; + std::vector names; + std::vector dims; + uint64_t count; + std::unordered_map dims_map; + + explicit Meta(const std::string& metapath) { + std::ifstream file(metapath); + std::string line; + int num_lines = 0; + while (std::getline(file, line)) { + if (StartWith(line, "#")) { + continue; + } + auto pairs = paddle::string::split_string(line, "="); + PADDLE_ENFORCE_EQ( + pairs.size(), 2, + paddle::platform::errors::InvalidArgument( + "info in %s except k=v, but got %s", metapath, line)); + + if (pairs[0] == "param") { + param = pairs[1]; + } + if (pairs[0] == "shard_id") { + shard_id = std::stoi(pairs[1]); + } + if (pairs[0] == "row_names") { + names = paddle::string::split_string(pairs[1], ","); + } + if (pairs[0] == "row_dims") { + auto dims_strs = + paddle::string::split_string(pairs[1], ","); + for (auto& str : dims_strs) { + dims.push_back(std::stoi(str)); + } + } + if (pairs[0] == "count") { + count = std::stoull(pairs[1]); + } + } + for (int x = 0; x < names.size(); ++x) { + dims_map[names[x]] = dims[x]; + } + } + + Meta(std::string param, int shard_id, std::vector row_names, + std::vector dims, uint64_t count) { + this->param = param; + this->shard_id = shard_id; + this->names = row_names; + this->dims = dims; + this->count = count; + } + + std::string ToString() { + std::stringstream ss; + ss << "param=" << param << "\n"; + ss << "shard_id=" << shard_id << "\n"; + ss << "row_names=" << paddle::string::join_strings(names, ',') << "\n"; + ss << "row_dims=" << paddle::string::join_strings(dims, ',') << "\n"; + ss << "count=" << count << "\n"; + return ss.str(); + } +}; + +void ProcessALine(const std::vector& columns, const Meta& meta, + std::vector>* values) { + PADDLE_ENFORCE_EQ(columns.size(), meta.names.size() + 1, + paddle::platform::errors::InvalidArgument( + "record in txt do not match meta.")); + + values->reserve(columns.size() - 1); + + for (int x = 1; x < columns.size(); ++x) { + auto& column = columns[x]; + auto val_ = paddle::string::split_string(column, ","); + + std::vector val; + std::transform(val_.begin(), val_.end(), std::back_inserter(val), + [](std::string va) { return std::stof(va); }); + PADDLE_ENFORCE_EQ(val.size(), meta.dims[x - 1], + paddle::platform::errors::InvalidArgument( + "record in txt do not match meta.")); + values->push_back(val); + } +} + +int64_t SaveToText(std::ostream* os, std::shared_ptr block, + const std::vector& saved_names, + const int mode) { + for (auto value : block->values_) { + std::vector*> vss = value.second->get(saved_names); + std::stringstream ss; + auto id = value.first; + ss << id << "\t"; + for (int i = 0; i < static_cast(vss.size()); i++) { + auto& vs = vss[i]; + ss << paddle::string::join_strings((*vs), ','); + ss << "\t"; + } + ss << "\n"; + + os->write(ss.str().c_str(), sizeof(char) * ss.str().size()); + } + + return block->values_.size(); +} + +int64_t LoadFromText(const std::string& valuepath, const std::string& metapath, + const int pserver_id, const int pserver_num, + const int local_shard_num, + std::vector>* blocks) { + Meta meta = Meta(metapath); + + int num_lines = 0; + std::ifstream file(valuepath); + std::string line; + + while (std::getline(file, line)) { + auto values = paddle::string::split_string(line, "\t"); + auto id = std::stoull(values[0]); + + if (id % pserver_num != pserver_id) { + VLOG(0) << "will not load " << values[0] << " from " << valuepath + << ", please check id distribution"; + continue; + } + + auto shard_id = id % local_shard_num; + auto block = blocks->at(shard_id); + + std::vector> kvalues; + ProcessALine(values, meta, &kvalues); + block->Init(id, &kvalues, 1); + } + + return 0; +} + +void SaveShard(std::shared_ptr block, const std::string& dirname, + const CommonAccessorParameter& common, const int mode, + const int pserver_id, const int shard_id) { + auto varname = common.table_name(); + std::string var_store = string::Sprintf("%s/%s", dirname, varname); + VLOG(3) << "save " << varname << " in dir: " << var_store << " begin"; + MkDirRecursively(var_store.c_str()); + + std::string shard_var_pre = + string::Sprintf("%s.block%d.%d", varname, pserver_id, shard_id); + std::string meta_ = string::Sprintf("%s/%s.meta", var_store, shard_var_pre); + std::string value_ = string::Sprintf("%s/%s.txt", var_store, shard_var_pre); + + // save values + std::vector params(common.params().begin(), + common.params().end()); + std::unique_ptr value_out(new std::ofstream(value_)); + SaveToText(value_out.get(), block, params, mode); + // save meta + std::stringstream stream; + stream << "param=" << common.table_name() << "\n"; + stream << "server_id=" << pserver_id << "\n"; + stream << "shard_id=" << shard_id << "\n"; + stream << "row_names=" << paddle::string::join_strings(common.params(), ',') + << "\n"; + stream << "row_dims=" << paddle::string::join_strings(common.dims(), ',') + << "\n"; + stream << "count=" << block->values_.size() << "\n"; + std::unique_ptr meta_out(new std::ofstream(meta_)); + meta_out->write(stream.str().c_str(), sizeof(char) * stream.str().size()); + meta_out->close(); + VLOG(3) << "save " << varname << " in dir: " << var_store << " done"; +} + +void CommonSparseTable::create_initializer(const std::string& attr, + const std::string& name) { + auto slices = string::split_string(attr, "&"); + + if (slices[0] == "gaussian_random") { + initializers_[name] = new GaussianInitializer(slices); + } else if (slices[0] == "fill_constant") { + initializers_[name] = new FillConstantInitializer(slices); + } else if (slices[0] == "uniform_random") { + initializers_[name] = new UniformInitializer(slices); + } else { + PADDLE_THROW( + platform::errors::InvalidArgument("%s can not be supported", name)); + } +} + +int32_t CommonSparseTable::initialize() { + _shards_task_pool.resize(task_pool_size_); + for (int i = 0; i < _shards_task_pool.size(); ++i) { + _shards_task_pool[i].reset(new ::ThreadPool(1)); + } + + sync = _config.common().sync(); + VLOG(1) << "table " << _config.common().table_name() << " is sync: " << sync; + + initialize_value(); + initialize_optimizer(); + initialize_recorder(); + return 0; +} + +int32_t CommonSparseTable::initialize_recorder() { return 0; } + +int32_t CommonSparseTable::initialize_value() { + auto common = _config.common(); + int size = static_cast(common.params().size()); + + for (int x = 0; x < size; ++x) { + auto& varname = common.params()[x]; + auto& dim = common.dims()[x]; + if (varname == "Param") { + param_dim_ = dim; + } + auto& initializer = common.initializers()[x]; + create_initializer(initializer, varname); + } + + shard_values_.reserve(task_pool_size_); + for (int x = 0; x < task_pool_size_; ++x) { + auto shard = std::make_shared(common, &initializers_); + shard_values_.emplace_back(shard); + } + return 0; +} + +int32_t CommonSparseTable::initialize_optimizer() { + auto common = _config.common(); + auto name = common.name(); + auto attrs = common.attributes(); + + if (name == "sgd") { + optimizer_ = std::make_shared(common); + } else if (name == "adam") { + optimizer_ = std::make_shared(common); + } else if (name == "sum") { + optimizer_ = std::make_shared(common); + } else { + VLOG(0) << "init optimizer failed"; + } + + VLOG(0) << "init optimizer " << name << " done"; + return 0; +} + +int32_t CommonSparseTable::load(const std::string& path, + const std::string& param) { + rwlock_->WRLock(); + VLOG(0) << "sparse table load with " << path << " with meta " << param; + LoadFromText(path, param, _shard_idx, _shard_num, task_pool_size_, + &shard_values_); + rwlock_->UNLock(); + return 0; +} + +int32_t CommonSparseTable::save(const std::string& dirname, + const std::string& param) { + rwlock_->WRLock(); + int mode = std::stoi(param); + VLOG(0) << "sparse table save: " << dirname << " mode: " << mode; + + auto varname = _config.common().table_name(); + std::string var_store = string::Sprintf("%s/%s", dirname, varname); + MkDirRecursively(var_store.c_str()); + + VLOG(3) << "save " << varname << " in dir: " << var_store << " begin"; + std::vector params(_config.common().params().begin(), + _config.common().params().end()); + std::string shard_var_pre = + string::Sprintf("%s.block%d", varname, _shard_idx); + + std::string value_ = string::Sprintf("%s/%s.txt", var_store, shard_var_pre); + + std::unique_ptr value_out(new std::ofstream(value_)); + + int64_t total_ins = 0; + for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) { + // save values + total_ins += + SaveToText(value_out.get(), shard_values_[shard_id], params, mode); + } + value_out->close(); + + // save meta + std::stringstream stream; + stream << "param=" << _config.common().table_name() << "\n"; + stream << "shard_id=" << _shard_idx << "\n"; + stream << "row_names=" + << paddle::string::join_strings(_config.common().params(), ',') + << "\n"; + stream << "row_dims=" + << paddle::string::join_strings(_config.common().dims(), ',') << "\n"; + stream << "count=" << total_ins << "\n"; + std::string meta_ = string::Sprintf("%s/%s.meta", var_store, shard_var_pre); + std::unique_ptr meta_out(new std::ofstream(meta_)); + meta_out->write(stream.str().c_str(), sizeof(char) * stream.str().size()); + meta_out->close(); + VLOG(3) << "save " << varname << " in dir: " << var_store << " done"; + rwlock_->UNLock(); + return 0; +} + +std::pair CommonSparseTable::print_table_stat() { + int64_t feasign_size = 0; + int64_t mf_size = 0; + + for (auto& value : shard_values_) { + feasign_size += value->values_.size(); + } + + return {feasign_size, mf_size}; +} + +int32_t CommonSparseTable::pour() { + rwlock_->RDLock(); + + std::vector values; + std::vector keys; + + keys.reserve(pull_reservoir_.size()); + values.reserve(pull_reservoir_.size() * param_dim_); + + for (auto& val : pull_reservoir_) { + keys.push_back(val.first); + auto& reservoir = val.second; + reservoir.avg(); + std::copy(reservoir.values.begin(), reservoir.values.end(), + std::back_inserter(values)); + } + _push_sparse(keys.data(), values.data(), pull_reservoir_.size()); + + pull_reservoir_.clear(); + rwlock_->UNLock(); + return 0; +} + +int32_t CommonSparseTable::pull_sparse(float* pull_values, const uint64_t* keys, + size_t num) { + rwlock_->RDLock(); + std::vector value_names; + for (auto name : _config.common().params()) { + value_names.push_back(name); + } + + std::vector> offset_bucket; + offset_bucket.resize(task_pool_size_); + + for (int x = 0; x < num; ++x) { + auto y = keys[x] % task_pool_size_; + offset_bucket[y].push_back(x); + } + + std::vector> tasks(task_pool_size_); + + for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) { + tasks[shard_id] = _shards_task_pool[shard_id]->enqueue( + [this, shard_id, &keys, &offset_bucket, &value_names, + &pull_values]() -> int { + auto& block = shard_values_[shard_id]; + auto& offsets = offset_bucket[shard_id]; + + for (int i = 0; i < offsets.size(); ++i) { + auto offset = offsets[i]; + auto id = keys[offset]; + block->InitFromInitializer(id, value_names); + auto values = block->Get(id, {"Param"}); + auto dim = values[0]->size(); + std::copy(values[0]->begin(), values[0]->end(), + pull_values + dim * offset); + } + return 0; + }); + } + + for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) { + tasks[shard_id].wait(); + } + rwlock_->UNLock(); + return 0; +} + +int32_t CommonSparseTable::_push_sparse(const uint64_t* keys, + const float* values, size_t num) { + rwlock_->RDLock(); + std::vector> offset_bucket; + offset_bucket.resize(task_pool_size_); + + for (int x = 0; x < num; ++x) { + auto y = keys[x] % task_pool_size_; + offset_bucket[y].push_back(x); + } + + std::vector> tasks(task_pool_size_); + + for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) { + tasks[shard_id] = _shards_task_pool[shard_id]->enqueue( + [this, shard_id, &keys, &values, num, &offset_bucket]() -> int { + auto& offsets = offset_bucket[shard_id]; + optimizer_->update(keys, values, num, offsets, + shard_values_[shard_id].get()); + return 0; + }); + } + + for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) { + tasks[shard_id].wait(); + } + rwlock_->UNLock(); + return 0; +} + +int32_t CommonSparseTable::push_sparse(const uint64_t* keys, + const float* values, size_t num) { + if (sync) { + std::future task = + _shards_task_pool[0]->enqueue([this, &keys, &values, num]() -> int { + for (int x = 0; x < num; ++x) { + auto id = keys[x]; + auto has = pull_reservoir_.find(id); + + if (has == pull_reservoir_.end()) { + pull_reservoir_[id] = ReservoirValue(param_dim_); + } + + auto& reservoir = pull_reservoir_[id]; + reservoir.add(values + x * param_dim_, param_dim_); + } + return 0; + }); + task.wait(); + } else { + _push_sparse(keys, values, num); + } + + return 0; +} + +int32_t CommonSparseTable::push_sparse_param(const uint64_t* keys, + const float* values, size_t num) { + rwlock_->RDLock(); + std::vector value_names; + for (auto name : _config.common().params()) { + value_names.push_back(name); + } + + std::vector> offset_bucket; + offset_bucket.resize(task_pool_size_); + + for (int x = 0; x < num; ++x) { + auto y = keys[x] % task_pool_size_; + offset_bucket[y].push_back(x); + } + + std::vector> tasks(task_pool_size_); + + for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) { + tasks[shard_id] = _shards_task_pool[shard_id]->enqueue( + [this, shard_id, &keys, &offset_bucket, &value_names, + &values]() -> int { + auto& block = shard_values_[shard_id]; + auto& offsets = offset_bucket[shard_id]; + + for (int i = 0; i < offsets.size(); ++i) { + auto offset = offsets[i]; + auto id = keys[offset]; + block->InitFromInitializer(id, value_names); + auto values_ = block->Get(id, {"Param"}); + auto dim = values_[0]->size(); + std::copy_n(values + dim * offset, dim, values_[0]->data()); + } + return 0; + }); + } + + for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) { + tasks[shard_id].wait(); + } + rwlock_->UNLock(); + return 0; +} + +int32_t CommonSparseTable::flush() { return 0; } + +int32_t CommonSparseTable::shrink() { + VLOG(0) << "shrink coming soon"; + return 0; +} +void CommonSparseTable::clear() { VLOG(0) << "clear coming soon"; } + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/table/common_sparse_table.h b/paddle/fluid/distributed/table/common_sparse_table.h new file mode 100644 index 0000000000000000000000000000000000000000..6baf60a44c15b0055faf2d486e484edb97365e42 --- /dev/null +++ b/paddle/fluid/distributed/table/common_sparse_table.h @@ -0,0 +1,97 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include // NOLINT +#include +#include +#include +#include +#include "Eigen/Dense" +#include "paddle/fluid/distributed/table/accessor.h" +#include "paddle/fluid/distributed/table/common_table.h" +#include "paddle/fluid/distributed/table/depends/initializers.h" +#include "paddle/fluid/distributed/table/depends/large_scale_kv.h" +#include "paddle/fluid/distributed/table/depends/sparse.h" +#include "paddle/fluid/framework/rw_lock.h" +#include "paddle/fluid/string/string_helper.h" + +namespace paddle { +namespace distributed { + +class CommonSparseTable : public SparseTable { + public: + CommonSparseTable() { rwlock_.reset(new framework::RWLock); } + virtual ~CommonSparseTable() {} + + // unused method begin + virtual int32_t pull_dense(float* pull_values, size_t num) { return 0; } + virtual int32_t push_dense_param(const float* values, size_t num) { + return 0; + } + virtual int32_t push_dense(const float* values, size_t num) { return 0; } + // unused method end + + virtual int32_t initialize(); + virtual int32_t initialize_shard() { return 0; } + virtual void create_initializer(const std::string& attr, + const std::string& name); + virtual int32_t initialize_value(); + virtual int32_t initialize_optimizer(); + virtual int32_t initialize_recorder(); + + int32_t load(const std::string& path, const std::string& param); + + int32_t save(const std::string& path, const std::string& param); + + virtual std::pair print_table_stat(); + virtual int32_t pull_sparse(float* pull_values, const uint64_t* keys, + size_t num); + + virtual int32_t push_sparse(const uint64_t* keys, const float* values, + size_t num); + + // only for sparse geo table + virtual int32_t push_sparse_param(const uint64_t* keys, const float* values, + size_t num); + + virtual int32_t pour(); + virtual int32_t flush(); + virtual int32_t shrink(); + virtual void clear(); + + protected: + virtual int32_t _push_sparse(const uint64_t* keys, const float* values, + size_t num); + + private: + const int task_pool_size_ = 11; + std::vector> _shards_task_pool; + + bool sync = false; + int param_dim_ = 0; + std::shared_ptr optimizer_; + std::unordered_map initializers_; + std::vector> shard_values_; + std::unordered_map> pull_reservoir_; + std::unique_ptr rwlock_{nullptr}; +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/table/common_table.h b/paddle/fluid/distributed/table/common_table.h new file mode 100644 index 0000000000000000000000000000000000000000..d37e6677e634d7da93b661cf02389c2f4abee19a --- /dev/null +++ b/paddle/fluid/distributed/table/common_table.h @@ -0,0 +1,166 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include // NOLINT +#include // NOLINT +#include + +#include "paddle/fluid/distributed/table/table.h" + +#include "paddle/fluid/distributed/common/utils.h" + +namespace paddle { +namespace distributed { + +template +struct ReservoirValue { + std::vector values; + uint32_t counter; + uint32_t dim; + + ReservoirValue() { + dim = 0; + values.resize(dim); + counter = 0; + } + + ReservoirValue(uint32_t dim) { + this->dim = dim; + values.resize(dim); + counter = 0; + } + + void add(const T *value, int numel) { + GetBlas().VADD(numel, values.data(), value, values.data()); + counter++; + } + + void add(T *value, int numel) { + GetBlas().VADD(numel, values.data(), value, values.data()); + counter++; + } + + void avg() { + auto scale = 1 / static_cast(counter); + GetBlas().SCAL(values.size(), scale, values.data()); + } + + void reset() { + values.resize(dim, 0); + counter = 0; + } +}; + +class SparseTable : public Table { + public: + SparseTable() {} + virtual ~SparseTable() {} + + virtual void *get_shard(size_t shard_idx) { return 0; } + + int32_t pull_dense(float *values, size_t num) override { return 0; } + + int32_t push_dense(const float *values, size_t num) override { return 0; } + + static int32_t sparse_local_shard_num(uint32_t shard_num, + uint32_t server_num) { + if (shard_num % server_num == 0) { + return shard_num / server_num; + } + size_t local_shard_num = shard_num / server_num + 1; + return local_shard_num; + } + + static size_t get_sparse_shard(uint32_t shard_num, uint32_t server_num, + uint64_t key) { + return (key % shard_num) / sparse_local_shard_num(shard_num, server_num); + } +}; + +class DenseTable : public Table { + public: + DenseTable() {} + virtual ~DenseTable() {} + + virtual void *get_shard(size_t shard_idx) { return 0; } + int32_t pull_sparse(float *values, const uint64_t *keys, + size_t num) override { + return 0; + } + int32_t push_sparse(const uint64_t *keys, const float *values, + size_t num) override { + return 0; + } + int32_t push_dense_param(const float *values, size_t num) override { + return 0; + } + int32_t shrink() override { return 0; } +}; + +class BarrierTable : public Table { + public: + BarrierTable() {} + virtual ~BarrierTable() {} + + virtual void *get_shard(size_t shard_idx) { return 0; } + + int32_t pull_dense(float *values, size_t num) override { return 0; } + + int32_t push_dense(const float *values, size_t num) override { return 0; } + + int32_t pull_sparse(float *values, const uint64_t *keys, + size_t num) override { + return 0; + } + int32_t push_sparse(const uint64_t *keys, const float *values, + size_t num) override { + return 0; + } + int32_t push_dense_param(const float *values, size_t num) override { + return 0; + } + int32_t shrink() override { return 0; } + virtual void clear(){}; + virtual int32_t flush() { return 0; }; + virtual int32_t load(const std::string &path, const std::string ¶m) { + return 0; + } + virtual int32_t save(const std::string &path, const std::string ¶m) { + return 0; + } + virtual int32_t initialize_shard() { return 0; }; + + virtual int32_t initialize() override; + // only for barrier + // 0: send_barrier 1: recv_barrier 2: complete + virtual int32_t barrier(const uint32_t trainer_id, + const std::string barrier_type) override; + + virtual int32_t set_table_map( + std::unordered_map> *table_map) override; + + private: + std::mutex mutex_; + std::condition_variable trainer_wait_; + std::set trainer_ids_; + std::set trainer_all_; + std::atomic trigger_; + std::atomic exit_; + std::unordered_map> *table_map_; +}; +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/table/depends/dense.h b/paddle/fluid/distributed/table/depends/dense.h new file mode 100644 index 0000000000000000000000000000000000000000..8a71d9b5a8b651853333d8f4ce346471407dc901 --- /dev/null +++ b/paddle/fluid/distributed/table/depends/dense.h @@ -0,0 +1,182 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include // for sqrt in CPU and CUDA +#include +#include +#include +#include +#include + +#include "paddle/fluid/distributed/common/utils.h" + +namespace paddle { +namespace distributed { + +// dense optimzier +// TODO(tangwei12) integrate with sparse optimzer later. +class DenseOptimizer { + public: + DenseOptimizer() {} + explicit DenseOptimizer(const CommonAccessorParameter& accessor, + std::vector>* values) {} + virtual void update(const float* update_values, size_t num, int begin, + int end) = 0; +}; + +// sum calc for dense tensor +class DSUM : public DenseOptimizer { + public: + explicit DSUM(const CommonAccessorParameter& accessor, + std::vector>* values) { + auto& names = accessor.params(); + for (int x = 0; x < static_cast(names.size()); ++x) { + if (names[x] == "Param") { + param = (*values)[x].data(); + } + } + } + + void update(const float* update_values, size_t num, int begin, + int end) override { + auto update_numel = end - begin; + GetBlas().VADD(update_numel, update_values + begin, param + begin, + param + begin); + } + + float* param; +}; + +// sgd optimizer for dense tensor +class DSGD : public DenseOptimizer { + public: + explicit DSGD(const CommonAccessorParameter& accessor, + std::vector>* values) { + auto& names = accessor.params(); + for (int x = 0; x < static_cast(names.size()); ++x) { + if (names[x] == "LearningRate") { + learning_rate = (*values)[x].data(); + } + if (names[x] == "Param") { + param = (*values)[x].data(); + } + } + } + + void update(const float* update_values, size_t num, int begin, + int end) override { + auto update_numel = end - begin; + std::vector grads; + grads.resize(update_numel); + + auto blas = GetBlas(); + blas.VCOPY(update_numel, update_values + begin, grads.data()); + blas.SCAL(update_numel, *learning_rate, grads.data()); + blas.VSUB(update_numel, param + begin, grads.data(), param + begin); + } + + float* learning_rate; + float* param; +}; + +// adam optimizer for dense tensor +class DAdam : public DenseOptimizer { + public: + explicit DAdam(const CommonAccessorParameter& accessor, + std::vector>* values) { + auto& names = accessor.params(); + for (int x = 0; x < static_cast(names.size()); ++x) { + if (names[x] == "LearningRate") { + learning_rate = (*values)[x].data(); + } + if (names[x] == "Param") { + param = (*values)[x].data(); + } + if (names[x] == "Moment1") { + moment1 = (*values)[x].data(); + } + if (names[x] == "Moment2") { + moment2 = (*values)[x].data(); + } + if (names[x] == "Beta1Pow") { + beta1_pow = (*values)[x].data(); + } + if (names[x] == "Beta2Pow") { + beta2_pow = (*values)[x].data(); + } + } + + // add attr later + beta1 = 0.9; + beta2 = 0.999; + epsilon = 1.0e-8; + } + + void update(const float* update_values, size_t num, int begin, + int end) override { + auto update_numel = end - begin; + std::vector grad, grad2, tmp; + grad.resize(update_numel); + grad2.resize(update_numel); + tmp.resize(update_numel); + + auto blas = GetBlas(); + blas.VCOPY(update_numel, update_values + begin, grad.data()); + blas.VCOPY(update_numel, update_values + begin, grad2.data()); + + blas.SCAL(update_numel, 1 - beta1, grad.data()); + blas.VSQUARE(update_numel, grad2.data(), grad2.data()); + blas.SCAL(update_numel, 1 - beta2, grad2.data()); + + blas.SCAL(update_numel, beta1, moment1 + begin); + blas.VADD(update_numel, moment1 + begin, grad.data(), moment1 + begin); + blas.SCAL(update_numel, beta2, moment2 + begin); + blas.VADD(update_numel, moment2 + begin, grad2.data(), moment2 + begin); + + beta1_pow[0] = beta1_pow[0] * beta1; + beta2_pow[0] = beta2_pow[0] * beta2; + + float lr_ = learning_rate[0]; + lr_ *= sqrt(1 - beta2_pow[0]) / (1 - beta1_pow[0]); + + float* tmp_ = tmp.data(); + float eps_ = epsilon * sqrt(1 - beta2_pow[0]); + + SQRT(update_numel, moment2 + begin, tmp_); + ADD(update_numel, tmp_, eps_, tmp_); + + blas.VDIV(update_numel, moment1 + begin, tmp_, tmp_); + blas.SCAL(update_numel, lr_, tmp_); + blas.VSUB(update_numel, param + begin, tmp_, param + begin); + } + + float* learning_rate; + + float* param; + float* moment1; + float* moment2; + + float* beta1_pow; + float* beta2_pow; + + float beta1; + float beta2; + float epsilon; +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/table/depends/geo_recorder.h b/paddle/fluid/distributed/table/depends/geo_recorder.h new file mode 100644 index 0000000000000000000000000000000000000000..ad094f0dfbc48aeab046b80527b0193fad4189cb --- /dev/null +++ b/paddle/fluid/distributed/table/depends/geo_recorder.h @@ -0,0 +1,94 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include // NOLINT +#include +#include +#include +#include +#include +#include + +namespace paddle { +namespace distributed { + +class ConcurrentSet { + public: + ConcurrentSet() : pool_(new ::ThreadPool(1)) {} + ~ConcurrentSet() {} + + std::future Update(const std::vector& rows) { + auto task = [this, rows] { + for (auto row : rows) { + set_.insert(row); + } + }; + return pool_->enqueue(std::move(task)); + } + + std::future GetAndClear(std::vector* result) { + auto task = [this, &result] { + result->clear(); + for (auto& id : set_) { + result->push_back(id); + } + set_.clear(); + }; + return pool_->enqueue(std::move(task)); + } + + private: + std::unordered_set set_; + std::unique_ptr<::ThreadPool> pool_{nullptr}; +}; + +class GeoRecorder { + public: + explicit GeoRecorder(int trainer_num) : trainer_num_(trainer_num) { + trainer_rows_.reserve(trainer_num); + for (auto i = 0; i < trainer_num; ++i) { + trainer_rows_.emplace_back(new ConcurrentSet()); + } + } + + ~GeoRecorder() = default; + + void Update(const std::vector& update_rows) { + VLOG(3) << " row size: " << update_rows.size(); + + std::vector> fs; + for (auto& set : trainer_rows_) { + fs.push_back(set->Update(update_rows)); + } + for (auto& f : fs) { + f.wait(); + } + } + + void GetAndClear(uint32_t trainer_id, std::vector* result) { + VLOG(3) << "GetAndClear for trainer: " << trainer_id; + trainer_rows_.at(trainer_id)->GetAndClear(result).wait(); + } + + private: + const int trainer_num_; + std::vector> trainer_rows_; +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/table/depends/initializers.h b/paddle/fluid/distributed/table/depends/initializers.h new file mode 100644 index 0000000000000000000000000000000000000000..e3d6e052c915863ecfd4ba5af636b4274f17f667 --- /dev/null +++ b/paddle/fluid/distributed/table/depends/initializers.h @@ -0,0 +1,102 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "paddle/fluid/framework/generator.h" + +namespace paddle { +namespace distributed { + +class Initializer { + public: + Initializer() {} + + explicit Initializer(const std::vector &attrs) {} + + virtual float GetValue() = 0; + + virtual ~Initializer() {} + + protected: + std::string name_; + unsigned int seed_; +}; + +class UniformInitializer : public Initializer { + public: + explicit UniformInitializer(const std::vector &attrs) { + name_ = attrs[0]; + seed_ = static_cast(std::stoi(attrs[1])); + min_ = std::stof(attrs[2]); + max_ = std::stof(attrs[3]); + + dist_ = std::uniform_real_distribution(min_, max_); + random_engine_ = framework::GetCPURandomEngine(seed_); + } + + float GetValue() override { return dist_(*random_engine_); } + + private: + float min_; + float max_; + + std::shared_ptr random_engine_; + std::uniform_real_distribution dist_; +}; + +class GaussianInitializer : public Initializer { + public: + explicit GaussianInitializer(const std::vector &attrs) { + name_ = attrs[0]; + seed_ = static_cast(std::stoi(attrs[1])); + mean_ = std::stof(attrs[2]); + std_ = std::stof(attrs[3]); + + random_engine_ = framework::GetCPURandomEngine(seed_); + + dist_ = std::normal_distribution(mean_, std_); + } + + float GetValue() override { return dist_(*random_engine_); } + + private: + float std_; + float mean_; + + std::shared_ptr random_engine_; + std::normal_distribution dist_; +}; + +class FillConstantInitializer : public Initializer { + public: + explicit FillConstantInitializer(const std::vector &attrs) { + name_ = attrs[0]; + value_ = std::stof(attrs[1]); + } + + float GetValue() override { return value_; } + + private: + float value_; +}; +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/table/depends/large_scale_kv.h b/paddle/fluid/distributed/table/depends/large_scale_kv.h new file mode 100644 index 0000000000000000000000000000000000000000..c0c424e7458939c0d6e579b6bd2e4501837d07ea --- /dev/null +++ b/paddle/fluid/distributed/table/depends/large_scale_kv.h @@ -0,0 +1,264 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include // NOLINT +#include +#include +#include // NOLINT +#include +#include +#include +#include + +#include "paddle/fluid/distributed/common/utils.h" +#include "paddle/fluid/distributed/table/depends/initializers.h" +#include "paddle/fluid/framework/generator.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/rw_lock.h" +#include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/threadpool.h" +#include "paddle/fluid/framework/variable.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/platform/port.h" +#include "paddle/fluid/string/printf.h" +#include "paddle/fluid/string/string_helper.h" + +namespace paddle { +namespace distributed { + +enum Mode { training, infer }; + +template +inline bool entry(const int count, const T threshold); + +template <> +inline bool entry(const int count, const std::string threshold) { + return true; +} + +template <> +inline bool entry(const int count, const int threshold) { + return count >= threshold; +} + +template <> +inline bool entry(const int count, const float threshold) { + UniformInitializer uniform = UniformInitializer({"0", "0", "1"}); + return uniform.GetValue() >= threshold; +} + +struct VALUE { + explicit VALUE(const std::vector &names) + : names_(names), count_(0), unseen_days_(0) { + values_.resize(names.size()); + for (int i = 0; i < static_cast(names.size()); i++) { + places[names[i]] = i; + } + } + + void set(std::vector> *values) { + values_ = std::move(*values); + } + + void set(const std::vector &names, + const std::vector> &values) { + for (int i = 0; i < static_cast(names.size()); i++) { + auto idx = places[names[i]]; + auto value = values[i]; + values_[idx].assign(value.begin(), value.end()); + } + } + + std::vector *> get() { + auto pts = std::vector *>(); + pts.reserve(values_.size()); + + for (auto &value : values_) { + pts.push_back(&value); + } + return pts; + } + + int fetch_count() { return ++count_; } + void reset_unseen_days() { unseen_days_ = 0; } + + void set_entry(bool is_entry) { is_entry_ = is_entry; } + + bool get_entry() { return is_entry_; } + + std::vector *> get(const std::vector names) { + auto pts = std::vector *>(); + pts.reserve(values_.size()); + + for (int i = 0; i < static_cast(names.size()); i++) { + pts.push_back(&(values_[places[names[i]]])); + } + return pts; + } + + std::vector names_; + int count_; + bool seen_after_last_save_; + int unseen_days_; + bool is_entry_; + std::vector> values_; + std::unordered_map places; +}; + +class ValueBlock { + public: + explicit ValueBlock( + const CommonAccessorParameter &common, + std::unordered_map *initializers) { + initializers_ = initializers; + int size = static_cast(common.params().size()); + + for (int x = 0; x < size; ++x) { + auto varname = common.params()[x]; + auto dim = common.dims()[x]; + value_names_.push_back(varname); + value_dims_.push_back(dim); + } + + // for Entry + { + // entry will add later + std::string entry_attr = "none"; + + if (entry_attr == "none") { + entry_func_ = + std::bind(entry, std::placeholders::_1, "none"); + } else { + auto slices = string::split_string(entry_attr, "&"); + if (slices[0] == "count_filter") { + int threshold = std::stoi(slices[1]); + entry_func_ = std::bind(entry, std::placeholders::_1, threshold); + } else if (slices[0] == "probability") { + float threshold = std::stof(slices[1]); + entry_func_ = + std::bind(entry, std::placeholders::_1, threshold); + } + } + } + } + + ~ValueBlock() {} + + void Init(const uint64_t &id, std::vector> *values, + int count) { + if (Has(id)) { + PADDLE_THROW(platform::errors::AlreadyExists("id already exist, error")); + } + + if (values->size() != value_names_.size()) { + PADDLE_THROW( + platform::errors::AlreadyExists("values can not match, error")); + } + + auto value = new VALUE(value_names_); + value->set(values); + value->seen_after_last_save_ = true; + value->count_ = count; + values_[id] = value; + } + + std::vector *> Get( + const uint64_t &id, const std::vector &value_names) { + auto ret_values = values_.at(id)->get(value_names); + return ret_values; + } + + std::vector *> Get(const uint64_t &id) { + auto ret_values = values_.at(id)->get(value_names_); + return ret_values; + } + + void InitFromInitializer(const uint64_t &id, + const std::vector &value_names) { + if (Has(id)) { + Update(id); + return; + } + + auto rets = std::vector>(); + rets.resize(value_names_.size()); + + for (int i = 0; i < static_cast(value_names_.size()); i++) { + auto name = value_names_[i]; + auto *init = initializers_->at(name); + + auto dim = value_dims_[i]; + rets[i].resize(dim); + + for (int j = 0; j < static_cast(dim); j++) { + rets[i][j] = init->GetValue(); + } + } + + Init(id, &rets, 0); + Update(id); + } + + bool GetEntry(const uint64_t &id) { + auto value = values_.at(id); + auto entry = value->get_entry(); + return entry; + } + + void Set(const uint64_t &id, const std::vector &value_names, + const std::vector> &values) { + auto value = values_.at(id); + value->set(value_names, values); + } + + void Update(const uint64_t id) { + auto *value = values_.at(id); + value->reset_unseen_days(); + auto count = value->fetch_count(); + + if (!value->get_entry()) { + value->set_entry(entry_func_(count)); + } + } + + private: + bool Has(const uint64_t id) { + auto got = values_.find(id); + if (got == values_.end()) { + return false; + } else { + return true; + } + } + + public: + std::unordered_map values_; + + private: + std::vector value_names_; + std::vector value_dims_; + std::function entry_func_; + std::unordered_map *initializers_; +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/table/depends/sparse.h b/paddle/fluid/distributed/table/depends/sparse.h new file mode 100644 index 0000000000000000000000000000000000000000..5d992a4c4f0f41b6e7d3ba9e22458294bd6e1e73 --- /dev/null +++ b/paddle/fluid/distributed/table/depends/sparse.h @@ -0,0 +1,210 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include // for sqrt in CPU and CUDA +#include +#include +#include +#include +#include + +#include "paddle/fluid/distributed/common/utils.h" +#include "paddle/fluid/distributed/table/depends/large_scale_kv.h" + +namespace paddle { +namespace distributed { + +class SparseOptimizer { + public: + SparseOptimizer() {} + explicit SparseOptimizer(const CommonAccessorParameter& common) {} + virtual void update(const uint64_t* keys, const float* update_values, + size_t num, const std::vector& offsets, + ValueBlock* block) = 0; +}; + +// sum calc for sparse tensor +class SSUM : public SparseOptimizer { + public: + SSUM(){}; + explicit SSUM(const CommonAccessorParameter& common) { + auto& names = common.params(); + for (int x = 0; x < static_cast(names.size()); ++x) { + if (names[x] == "Param") { + param_idx = x; + update_numel = common.dims()[x]; + } + } + } + + void update(const uint64_t* keys, const float* update_values, size_t num, + const std::vector& offsets, + ValueBlock* block) override { + auto blas = GetBlas(); + for (auto x : offsets) { + auto id = keys[x]; + auto values = block->Get(id); + float* param = values[param_idx]->data(); + + std::vector delta; + delta.resize(update_numel); + blas.VCOPY(update_numel, update_values + x * update_numel, delta.data()); + blas.VADD(update_numel, delta.data(), param, param); + } + } + + int param_idx; + int update_numel; +}; + +// sgd optimzer for sparse tensor +class SSGD : public SparseOptimizer { + public: + SSGD(){}; + explicit SSGD(const CommonAccessorParameter& common) { + auto& names = common.params(); + for (int x = 0; x < static_cast(names.size()); ++x) { + if (names[x] == "LearningRate") { + learning_rate_idx = x; + } + if (names[x] == "Param") { + param_idx = x; + update_numel = common.dims()[x]; + } + } + } + + void update(const uint64_t* keys, const float* update_values, size_t num, + const std::vector& offsets, + ValueBlock* block) override { + auto blas = GetBlas(); + for (auto x : offsets) { + auto id = keys[x]; + auto values = block->Get(id); + float* learning_rate = values[learning_rate_idx]->data(); + float* param = values[param_idx]->data(); + + std::vector grads; + grads.resize(update_numel); + blas.VCOPY(update_numel, update_values + x * update_numel, grads.data()); + blas.SCAL(update_numel, learning_rate[0], grads.data()); + blas.VSUB(update_numel, param, grads.data(), param); + } + } + + int learning_rate_idx; + int param_idx; + int update_numel; +}; + +// adam optimzer for sparse tensor +class SAdam : public SparseOptimizer { + public: + SAdam() {} + explicit SAdam(const CommonAccessorParameter& common) { + auto& names = common.params(); + for (int x = 0; x < static_cast(names.size()); ++x) { + if (names[x] == "LearningRate") { + learning_rate_idx = x; + } + if (names[x] == "Param") { + param_idx = x; + update_numel = common.dims()[x]; + } + if (names[x] == "Moment1") { + moment1_idx = x; + } + if (names[x] == "Moment2") { + moment2_idx = x; + } + if (names[x] == "Beta1Pow") { + beta1_pow_idx = x; + } + if (names[x] == "Beta2Pow") { + beta2_pow_idx = x; + } + } + + // add attr later + beta1 = 0.9; + beta2 = 0.999; + epsilon = 1.0e-8; + } + + void update(const uint64_t* keys, const float* update_values, size_t num, + const std::vector& offsets, + ValueBlock* block) override { + auto blas = GetBlas(); + for (auto x : offsets) { + auto id = keys[x]; + auto values = block->Get(id); + float* learning_rate = values[learning_rate_idx]->data(); + float* param = values[param_idx]->data(); + float* moment1 = values[moment1_idx]->data(); + float* moment2 = values[moment2_idx]->data(); + float* beta1_pow = values[beta1_pow_idx]->data(); + float* beta2_pow = values[beta2_pow_idx]->data(); + + beta1_pow[0] = beta1_pow[0] * beta1; + beta2_pow[0] = beta2_pow[0] * beta2; + + float lr_ = learning_rate[0]; + lr_ *= sqrt(1 - beta2_pow[0]) / (1 - beta1_pow[0]); + + std::vector grad, grad2, tmp; + grad.resize(update_numel); + grad2.resize(update_numel); + tmp.resize(update_numel); + + blas.VCOPY(update_numel, update_values + x * update_numel, grad.data()); + blas.VCOPY(update_numel, update_values + x * update_numel, grad2.data()); + + blas.SCAL(update_numel, 1 - beta1, grad.data()); + blas.VSQUARE(update_numel, grad2.data(), grad2.data()); + blas.SCAL(update_numel, 1 - beta2, grad2.data()); + + blas.SCAL(update_numel, beta1, moment1); + blas.VADD(update_numel, moment1, grad.data(), moment1); + blas.SCAL(update_numel, beta2, moment2); + blas.VADD(update_numel, moment2, grad2.data(), moment2); + + float* tmp_ = tmp.data(); + float eps_ = epsilon * sqrt(1 - beta2_pow[0]); + + SQRT(update_numel, moment2, tmp_); + ADD(update_numel, tmp_, eps_, tmp_); + + blas.VDIV(update_numel, moment1, tmp_, tmp_); + blas.SCAL(update_numel, lr_, tmp_); + blas.VSUB(update_numel, param, tmp_, param); + } + } + + int learning_rate_idx; + int param_idx; + int moment1_idx; + int moment2_idx; + int beta1_pow_idx; + int beta2_pow_idx; + float beta1; + float beta2; + float epsilon; + int update_numel; +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/table/sparse_geo_table.cc b/paddle/fluid/distributed/table/sparse_geo_table.cc new file mode 100644 index 0000000000000000000000000000000000000000..9b276e7de5c92d495f9d40535033b0a82186bc82 --- /dev/null +++ b/paddle/fluid/distributed/table/sparse_geo_table.cc @@ -0,0 +1,41 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/table/sparse_geo_table.h" + +namespace paddle { +namespace distributed { + +int32_t SparseGeoTable::pull_geo_param(const uint32_t trainer_id, + std::vector* values, + std::vector* ids) { + geo_recorder->GetAndClear(trainer_id, ids); + auto dim = _config.common().dims()[0]; + values->resize(ids->size() * dim); + CommonSparseTable::pull_sparse(values->data(), ids->data(), ids->size()); + return 0; +} + +int32_t SparseGeoTable::push_sparse(const uint64_t* keys, const float* values, + size_t num) { + std::vector ids; + ids.resize(num); + std::copy_n(keys, num, ids.begin()); + geo_recorder->Update(ids); + CommonSparseTable::push_sparse(keys, values, num); + return 0; +} + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/table/sparse_geo_table.h b/paddle/fluid/distributed/table/sparse_geo_table.h new file mode 100644 index 0000000000000000000000000000000000000000..267d30a30fb7b939255d424434964c00b2af2f7b --- /dev/null +++ b/paddle/fluid/distributed/table/sparse_geo_table.h @@ -0,0 +1,62 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include // NOLINT +#include +#include +#include +#include "Eigen/Dense" +#include "paddle/fluid/distributed/table/accessor.h" +#include "paddle/fluid/distributed/table/common_sparse_table.h" +#include "paddle/fluid/distributed/table/common_table.h" +#include "paddle/fluid/distributed/table/depends/geo_recorder.h" +#include "paddle/fluid/distributed/table/depends/initializers.h" +#include "paddle/fluid/distributed/table/depends/large_scale_kv.h" +#include "paddle/fluid/distributed/table/depends/sparse.h" +#include "paddle/fluid/framework/rw_lock.h" +#include "paddle/fluid/string/string_helper.h" + +namespace paddle { +namespace distributed { + +class SparseGeoTable : public CommonSparseTable { + public: + explicit SparseGeoTable() : CommonSparseTable() { geo_recorder = nullptr; } + virtual ~SparseGeoTable() {} + + int32_t pull_geo_param(const uint32_t trainer_id, std::vector* values, + std::vector* keys); + + virtual int32_t push_sparse(const uint64_t* keys, const float* values, + size_t num) override; + + virtual int32_t initialize_recorder() { + if (!geo_recorder) { + auto trainers = _config.common().trainer_num(); + geo_recorder = std::make_shared(trainers); + } + return 0; + } + + private: + std::shared_ptr geo_recorder; +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/table/table.cc b/paddle/fluid/distributed/table/table.cc new file mode 100644 index 0000000000000000000000000000000000000000..ff241ee1066483117ad02af88d91bdcfe9d4d38e --- /dev/null +++ b/paddle/fluid/distributed/table/table.cc @@ -0,0 +1,79 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/table/table.h" +#include +#include +#include "glog/logging.h" +#include "paddle/fluid/distributed/common/registerer.h" + +#include "paddle/fluid/distributed/table/common_dense_table.h" +#include "paddle/fluid/distributed/table/common_sparse_table.h" +#include "paddle/fluid/distributed/table/sparse_geo_table.h" +#include "paddle/fluid/distributed/table/tensor_accessor.h" +#include "paddle/fluid/distributed/table/tensor_table.h" + +namespace paddle { +namespace distributed { + +REGISTER_CLASS(Table, CommonDenseTable); +REGISTER_CLASS(Table, CommonSparseTable); +REGISTER_CLASS(Table, DenseTensorTable); +REGISTER_CLASS(Table, SparseGeoTable); +REGISTER_CLASS(Table, BarrierTable); + +REGISTER_CLASS(ValueAccessor, CommMergeAccessor); + +int32_t TableManager::initialize() { + static bool initialized = false; + if (initialized) { + return 0; + } + initialized = true; + return 0; +} + +int32_t Table::initialize(const TableParameter &config, + const FsClientParameter &fs_config) { + _config = config; + if (initialize_accessor() != 0) { + LOG(WARNING) << "Table accessor initialize failed"; + return -1; + } + return initialize(); +} + +int32_t Table::initialize_accessor() { + if (!_config.has_accessor() || !_config.accessor().has_accessor_class()) { + LOG(ERROR) << "missing accessor config in table, table_id:" + << _config.table_id(); + return -1; + } + auto *accessor = + CREATE_CLASS(ValueAccessor, + _config.accessor().accessor_class()) if (accessor == NULL) { + LOG(ERROR) << "accessor is unregisteg, table_id:" << _config.table_id() + << ", accessor_name:" << _config.accessor().accessor_class(); + return -1; + } + if (accessor->configure(_config.accessor()) || accessor->initialize() != 0) { + LOG(ERROR) << " accessor initialize failed, table_id:" << _config.table_id() + << ", accessor_name:" << _config.accessor().accessor_class(); + return -1; + } + _value_accesor.reset(accessor); + return 0; +} +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/table/table.h b/paddle/fluid/distributed/table/table.h new file mode 100644 index 0000000000000000000000000000000000000000..70d1211fe81c70c7e579f15e1445a6ba5acecf79 --- /dev/null +++ b/paddle/fluid/distributed/table/table.h @@ -0,0 +1,125 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include // NOLINT +#include +#include +#include + +#include "paddle/fluid/distributed/table/accessor.h" +#include "paddle/fluid/string/string_helper.h" + +namespace paddle { +namespace distributed { +class Table { + public: + Table() {} + virtual ~Table() {} + virtual int32_t initialize(const TableParameter &config, + const FsClientParameter &fs_config) final; + + virtual int32_t pull_dense(float *values, size_t num) = 0; + virtual int32_t push_dense(const float *values, size_t num) = 0; + virtual int32_t push_dense_param(const float *values, size_t num) { + return 0; + } + + virtual int32_t pull_sparse(float *values, const uint64_t *keys, + size_t num) = 0; + virtual int32_t push_sparse(const uint64_t *keys, const float *values, + size_t num) = 0; + virtual int32_t push_sparse_param(const uint64_t *keys, const float *values, + size_t num) { + return 0; + } + + // only for sparse geo table + virtual int32_t pull_geo_param(const uint32_t trainer_id, + std::vector *values, + std::vector *keys) { + return 0; + } + + // only for barrier + virtual int32_t barrier(const uint32_t trainer_id, + const std::string barrier_type) { + return 0; + } + + // only for barrier table + virtual int32_t set_table_map( + std::unordered_map> *table_map) { + return 0; + } + + virtual int32_t pour() { return 0; } + + virtual void clear() = 0; + virtual int32_t flush() = 0; + virtual int32_t shrink() = 0; + + //指定加载路径 + virtual int32_t load(const std::string &path, + const std::string &converter) = 0; + //指定保存路径 + virtual int32_t save(const std::string &path, + const std::string &converter) = 0; + + virtual int32_t set_shard(size_t shard_idx, size_t shard_num) final { + _shard_idx = shard_idx; + _shard_num = shard_num; + return initialize_shard(); + } + + inline std::shared_ptr value_accesor() { + return _value_accesor; + } + + virtual void *get_shard(size_t shard_idx) = 0; + virtual std::pair print_table_stat() { return {0, 0}; } + + protected: + virtual int32_t initialize() = 0; + virtual int32_t initialize_accessor() final; + virtual int32_t initialize_shard() = 0; + virtual std::string table_dir(const std::string &model_dir) { + return paddle::string::format_string("%s/%03d/", model_dir.c_str(), + _config.table_id()); + } + + size_t _shard_idx; // table 分片编号 + size_t _shard_num; // table 分片总数 + TableParameter _config; + std::shared_ptr _value_accesor; +}; +REGISTER_REGISTERER(Table); + +class TableManager { + public: + static TableManager &instance() { + static TableManager manager; + return manager; + } + int32_t initialize(); + + private: + TableManager() {} + ~TableManager() {} +}; +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/table/tensor_accessor.cc b/paddle/fluid/distributed/table/tensor_accessor.cc new file mode 100644 index 0000000000000000000000000000000000000000..b1ece52c133a7169273d1a2f62da4d34a01cb029 --- /dev/null +++ b/paddle/fluid/distributed/table/tensor_accessor.cc @@ -0,0 +1,90 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/table/tensor_accessor.h" +#include "Eigen/Dense" + +namespace paddle { +namespace distributed { + +int CommMergeAccessor::initialize() { return 0; } + +// value 维度 +size_t CommMergeAccessor::dim() { return 0; } + +// value 各个维度的size +size_t CommMergeAccessor::dim_size(size_t dim) { return 0; } + +// value 各维度相加总size +size_t CommMergeAccessor::size() { return 0; } + +// pull value 维度 +size_t CommMergeAccessor::select_dim() { return _config.embedx_dim(); } + +// pull value 各个维度的size +size_t CommMergeAccessor::select_dim_size(size_t dim) { return sizeof(float); } + +// pull value 各维度相加总size +size_t CommMergeAccessor::select_size() { return select_dim() * sizeof(float); } + +// push value 维度 +size_t CommMergeAccessor::update_dim() { return _config.embedx_dim(); } + +// push value 各个维度的size +size_t CommMergeAccessor::update_dim_size(size_t dim) { return sizeof(float); } + +// push value 各维度相加总size +size_t CommMergeAccessor::update_size() { return update_dim() * sizeof(float); } + +// 判断该value 是否进行shrink +bool CommMergeAccessor::shrink(float * /*value*/) { return false; } + +// 判断该value 是否在save阶段dump, +// param作为参数用于标识save阶段,如downpour的xbox与batch_model +bool CommMergeAccessor::save(float * /*value*/, int /*param*/) { return true; } + +// keys不存在时,为values生成随机值 +int32_t CommMergeAccessor::create(float **value, size_t num) { return 0; } + +// 从values中选取到select_values中 +int32_t CommMergeAccessor::select(float **select_values, const float **values, + size_t num) { + return 0; +} + +// 将update_values聚合到一起 +int32_t CommMergeAccessor::merge(float **update_values, + const float **other_update_values, + size_t num) { + Eigen::Map u_mat(update_values[0], 1, num); + Eigen::Map o_mat(other_update_values[0], 1, num); + u_mat += o_mat; + return 0; +} + +// 将update_values聚合到一起,通过it.next判定是否进入下一个key +// int32_t merge(float** update_values, iterator it); +// 将update_values更新应用到values中 +int32_t CommMergeAccessor::update(float **values, const float **update_values, + size_t num) { + return 0; +} + +int CommMergeAccessor::set_weight(float **values, const float **update_values, + size_t num) { + return 0; +} + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/table/tensor_accessor.h b/paddle/fluid/distributed/table/tensor_accessor.h new file mode 100644 index 0000000000000000000000000000000000000000..12fb8a42d985981c602950645d7cdd1316b7a9cb --- /dev/null +++ b/paddle/fluid/distributed/table/tensor_accessor.h @@ -0,0 +1,78 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include +#include "paddle/fluid/distributed/common/registerer.h" +#include "paddle/fluid/distributed/ps.pb.h" +#include "paddle/fluid/distributed/table/accessor.h" + +namespace paddle { +namespace distributed { + +class CommMergeAccessor : public ValueAccessor { + public: + CommMergeAccessor() {} + virtual ~CommMergeAccessor() {} + virtual int initialize(); + // value维度 + virtual size_t dim(); + // value各个维度的size + virtual size_t dim_size(size_t dim); + // value各维度相加总size + virtual size_t size(); + // pull value维度 + virtual size_t select_dim(); + // pull value各个维度的size + virtual size_t select_dim_size(size_t dim); + // pull value各维度相加总size + virtual size_t select_size(); + // push value维度 + virtual size_t update_dim(); + // push value各个维度的size + virtual size_t update_dim_size(size_t dim); + // push value各维度相加总size + virtual size_t update_size(); + // 判断该value是否进行shrink + virtual bool shrink(float * /*value*/); + // 判断该value是否在save阶段dump, + // param作为参数用于标识save阶段,如downpour的xbox与batch_model + virtual bool save(float * /*value*/, int /*param*/); + + // keys不存在时,为values生成随机值 + virtual int32_t create(float **value, size_t num); + // 从values中选取到select_values中 + virtual int32_t select(float **select_values, const float **values, + size_t num); + // 将update_values聚合到一起 + virtual int32_t merge(float **update_values, + const float **other_update_values, size_t num); + // 将update_values聚合到一起,通过it.next判定是否进入下一个key + // virtual int32_t merge(float** update_values, iterator it); + // 将update_values更新应用到values中 + virtual int32_t update(float **values, const float **update_values, + size_t num); + + virtual int set_weight(float **values, const float **update_values, + size_t num); + virtual std::string parse_to_string(const float *value, int param) { + return ""; + } + virtual int parse_from_string(const std::string &str, float *v) { return 0; } +}; +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/table/tensor_table.cc b/paddle/fluid/distributed/table/tensor_table.cc new file mode 100644 index 0000000000000000000000000000000000000000..d8e1be7a9815c4aad21cd24733fd6747f3e0d56b --- /dev/null +++ b/paddle/fluid/distributed/table/tensor_table.cc @@ -0,0 +1,93 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/table/tensor_table.h" +#include "paddle/fluid/distributed/common/utils.h" + +namespace paddle { +namespace distributed { + +int32_t DenseTensorTable::initialize() { + _shards_task_pool.resize(10); + for (int i = 0; i < _shards_task_pool.size(); ++i) { + _shards_task_pool[i].reset(new ::ThreadPool(1)); + } + return 0; +} + +int32_t DenseTensorTable::initialize_tensor(framework::Scope *scope, + framework::ProgramDesc *program, + framework::Executor *executor) { + scope_ = scope; + program_ = program; + executor_ = executor; + + auto tensor_config = _config.tensor(); + if (tensor_config.has_common_block_map()) { + auto block_maps = + paddle::string::split_string(tensor_config.common_block_map(), "#"); + for (auto &block_map : block_maps) { + auto block = paddle::string::split_string(block_map, ":"); + auto block_id = std::stoi(block[0]); + std::vector block_ids{block_id}; + auto block_cmd = block[1]; + auto prepared = executor_->Prepare(*program_, block_ids); + (*prepared_ctx_)[block_cmd] = prepared[0]; + } + } +} + +int32_t DenseTensorTable::pull_dense(float *values, size_t numel) { + PADDLE_ENFORCE_EQ(numel, _data.numel(), + paddle::platform::errors::PreconditionNotMet( + "pull dense error, excepted numel %d, but actually %d.", + _data.numel(), numel)); + + GetBlas().VCOPY(numel, _data.data(), values); + return 0; +} + +int32_t DenseTensorTable::push_dense(const float *values, size_t numel) { + auto varname = _config.tensor().grad(); + auto local_scope = scope_->NewTmpScope(); + auto *var = local_scope->Var(varname); + auto *t = var->GetMutable(); + auto dims = paddle::framework::make_ddim({}); + + auto ctx = paddle::platform::CPUDeviceContext(); + t->mutable_data(_data.dims(), ctx.GetPlace()); + + GetBlas().VCOPY(numel, values, t->data()); + executor_->RunPreparedContext((*prepared_ctx_)["push"].get(), + local_scope.get()); +} + +int32_t DenseTensorTable::push_dense_param(const float *values, size_t numel) { + auto ctx = paddle::platform::CPUDeviceContext(); + if (_data.IsInitialized()) { + PADDLE_ENFORCE_EQ( + numel, _data.numel(), + paddle::platform::errors::PreconditionNotMet( + "pull dense error, excepted numel %d, but actually %d.", + _data.numel(), numel)); + } else { + _data.mutable_data( + framework::make_ddim({static_cast(numel), 1}), ctx.GetPlace()); + } + + GetBlas().VCOPY(numel, values, _data.data()); + return 0; +} +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/table/tensor_table.h b/paddle/fluid/distributed/table/tensor_table.h new file mode 100644 index 0000000000000000000000000000000000000000..9744c931c472053926ce1b772b050be08d6b46f0 --- /dev/null +++ b/paddle/fluid/distributed/table/tensor_table.h @@ -0,0 +1,179 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include + +#include +#include "paddle/fluid/distributed/table/table.h" +#include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace distributed { + +class TensorTable : public Table { + public: + TensorTable() : Table() {} + + virtual ~TensorTable() {} + + virtual int32_t initialize() { return 0; } + + virtual int32_t pull_dense(float *values, size_t num) override { return 0; }; + + virtual int32_t push_dense(const float *values, size_t num) override { + return 0; + }; + + virtual void *get_shard(size_t shard_idx) override { return 0; } + + virtual int32_t pull_sparse(float *values, const uint64_t *keys, + size_t num) override { + return 0; + }; + + virtual int32_t push_sparse(const uint64_t *keys, const float *values, + size_t num) override { + return 0; + }; + + virtual int32_t push_dense_param(const float *values, size_t num) { + return 0; + } + + virtual int32_t shrink() { return 0; } + + virtual void clear() {} + + virtual int32_t flush() { return 0; } + + //指定加载路径 + virtual int32_t load(const std::string &path, const std::string &converter) { + return 0; + } + //指定保存路径 + virtual int32_t save(const std::string &path, const std::string &converter) { + return 0; + } + + protected: + virtual int32_t initialize_shard() { return 0; } + + virtual int32_t initialize_tensor(paddle::framework::Scope *scope, + paddle::framework::ProgramDesc *program, + paddle::framework::Executor *executor) { + return 0; + } + + std::vector> _shards_task_pool; + + framework::Executor *executor_; + framework::Scope *scope_; + framework::ProgramDesc *program_; + std::unordered_map> + *prepared_ctx_; +}; + +class DenseTensorTable : public TensorTable { + public: + DenseTensorTable() : TensorTable() {} + ~DenseTensorTable() {} + virtual int32_t initialize(); + + void *get_shard(size_t shard_idx) { return 0; } + + int32_t pull_sparse(float *values, const uint64_t *keys, size_t num) { + return 0; + } + int32_t push_sparse(const uint64_t *keys, const float *values, size_t num) { + return 0; + } + int32_t shrink() { return 0; } + + int32_t pull_dense(float *values, size_t num) override; + int32_t push_dense_param(const float *values, size_t num) override; + int32_t push_dense(const float *values, size_t num) override; + + virtual void clear() {} + virtual int32_t flush() { return 0; } + + //指定加载路径 + virtual int32_t load(const std::string &path, const std::string &converter) { + return 0; + } + //指定保存路径 + virtual int32_t save(const std::string &path, const std::string &converter) { + return 0; + } + + protected: + virtual int32_t initialize_shard() { return 0; } + + virtual int32_t initialize_tensor(paddle::framework::Scope *scope, + paddle::framework::ProgramDesc *program, + paddle::framework::Executor *executor); + + protected: + framework::Tensor _data; +}; +// +//// common sparse table [0, N) with out large scale +// class SparseTensorTable : public TensorTable { +// void *get_shard(size_t shard_idx) { return 0; } +// +// int32_t pull_sparse(float *values, const uint64_t *keys, size_t num) +// override; +// int32_t push_sparse(const uint64_t *keys, const float *values, size_t num) +// override ; +// int32_t shrink() { return 0; } +// void *get_shard(size_t shard_idx) { return 0; }; +// +// int32_t pull_dense(float *values, size_t num) { return 0; }; +// int32_t push_dense_param(const float *values, size_t num) { return 0; }; +// int32_t push_dense(const float *values, size_t num) { return 0; }; +// +// protected: +// framework::Tensor _data; +//}; + +//// for Large scale kv tensor [0, int64] do not use specific optimizer +// class KvTensorTable : public TensorTable { +// int32_t pull_dense(float *values, size_t num) { return 0; }; +// int32_t push_dense_param(const float *values, size_t num) { return 0; }; +// int32_t push_dense(const float *values, size_t num) { return 0; }; +// +// void *get_shard(size_t shard_idx) override; +// int32_t pull_sparse(float *values, const uint64_t *keys, size_t num) +// override; +// int32_t push_sparse(const uint64_t *keys, const float *values, +// size_t num) override; +// int32_t shrink() override; +// void *get_shard(size_t shard_idx) override; +//}; +// +//// for Geo sparse handle +// class GeoSparseTensorTable : public TensorTable {}; +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/test/CMakeLists.txt b/paddle/fluid/distributed/test/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..405fe7561115e68b45d4f8a84e59e4f12faed18c --- /dev/null +++ b/paddle/fluid/distributed/test/CMakeLists.txt @@ -0,0 +1,31 @@ +if(APPLE) + return() +endif() + +set_source_files_properties(table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_test(table_test SRCS table_test.cc DEPS common_table table tensor_accessor ps_framework_proto ${COMMON_DEPS}) + +set_source_files_properties(dense_table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_test(dense_table_test SRCS dense_table_test.cc DEPS common_table table tensor_accessor ps_framework_proto ${COMMON_DEPS}) + +set_source_files_properties(sparse_table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_test(sparse_table_test SRCS sparse_table_test.cc DEPS common_table table tensor_accessor ps_framework_proto ${COMMON_DEPS}) + +set_source_files_properties(geo_table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_test(geo_table_test SRCS geo_table_test.cc DEPS common_table table tensor_accessor ps_framework_proto ${COMMON_DEPS}) + +set_source_files_properties(barrier_table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_test(barrier_table_test SRCS barrier_table_test.cc DEPS common_table table tensor_accessor ps_framework_proto ${COMMON_DEPS}) + + +# open it until CI support brpc +return() + +set_source_files_properties(brpc_service_dense_sgd_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_test(brpc_service_dense_sgd_test SRCS brpc_service_dense_sgd_test.cc DEPS scope server client communicator ps_service boost table ps_framework_proto ${COMMON_DEPS}) + +set_source_files_properties(brpc_service_sparse_sgd_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_test(brpc_service_sparse_sgd_test SRCS brpc_service_sparse_sgd_test.cc DEPS scope server client communicator ps_service boost table ps_framework_proto ${COMMON_DEPS}) + +set_source_files_properties(brpc_utils_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_test(brpc_utils_test SRCS brpc_utils_test.cc DEPS brpc_utils scope math_function ${COMMON_DEPS} ${RPC_DEPS}) diff --git a/paddle/fluid/distributed/test/barrier_table_test.cc b/paddle/fluid/distributed/test/barrier_table_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..12f6062c41c48f43459289c5fcad7e05acf458b7 --- /dev/null +++ b/paddle/fluid/distributed/test/barrier_table_test.cc @@ -0,0 +1,70 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#include +#include +#include // NOLINT + +#include "google/protobuf/text_format.h" +#include "gtest/gtest.h" +#include "paddle/fluid/distributed/ps.pb.h" +#include "paddle/fluid/distributed/table/common_table.h" +#include "paddle/fluid/distributed/table/table.h" + +namespace paddle { +namespace distributed { + +TEST(BarrierTable, Barrier) { + int emb_dim = 10; + int trainers = 2; + bool sync = true; + + TableParameter table_config; + table_config.set_table_class("BarrierTable"); + FsClientParameter fs_config; + Table *table = new BarrierTable(); + TableAccessorParameter *accessor_config = table_config.mutable_accessor(); + accessor_config->set_accessor_class("CommMergeAccessor"); + CommonAccessorParameter *common_config = table_config.mutable_common(); + common_config->set_table_name("barrier_table"); + common_config->set_trainer_num(trainers); + common_config->set_sync(sync); + + auto ret = table->initialize(table_config, fs_config); + + std::unordered_map> maps = + std::unordered_map>(); + + table->set_table_map(&maps); + + std::shared_ptr<::ThreadPool> pool_ = + std::make_shared<::ThreadPool>(trainers); + std::vector> task_status; + + for (auto x = 0; x < trainers; x++) { + auto task = [table, x] { table->barrier(x, 0); }; + task_status.push_back(pool_->enqueue(std::move(task))); + } + + for (auto &status : task_status) { + status.wait(); + } + + ASSERT_EQ(ret, 0); +} + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/test/brpc_service_dense_sgd_test.cc b/paddle/fluid/distributed/test/brpc_service_dense_sgd_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..3b2f808a2a82d558a6dabc85b57139f99d8ea389 --- /dev/null +++ b/paddle/fluid/distributed/test/brpc_service_dense_sgd_test.cc @@ -0,0 +1,272 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include // NOLINT +#include +#include // NOLINT + +#include "google/protobuf/text_format.h" +#include "gtest/gtest.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/framework/variable.h" + +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/string/printf.h" + +#include "paddle/fluid/distributed/ps.pb.h" +#include "paddle/fluid/distributed/service/brpc_ps_client.h" +#include "paddle/fluid/distributed/service/brpc_ps_server.h" +#include "paddle/fluid/distributed/service/env.h" +#include "paddle/fluid/distributed/service/ps_client.h" +#include "paddle/fluid/distributed/service/sendrecv.pb.h" +#include "paddle/fluid/distributed/service/service.h" + +namespace framework = paddle::framework; +namespace platform = paddle::platform; +namespace operators = paddle::operators; +namespace math = paddle::operators::math; +namespace memory = paddle::memory; +namespace distributed = paddle::distributed; + +void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) { + auto x_var = scope->Var("x"); + x_var->GetMutable(); +} + +void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place, + int64_t rows_numel) { + CreateVarsOnScope(scope, place); + + auto x_var = scope->Var("x")->GetMutable(); + float* x_ptr = + x_var->mutable_data(framework::DDim({1, rows_numel}), *place); + for (int64_t i = 0; i < rows_numel; ++i) x_ptr[i] = 1.0 * (float)i; +} + +void GetDownpourDenseTableProto( + ::paddle::distributed::TableParameter* dense_table_proto) { + dense_table_proto->set_table_id(0); + dense_table_proto->set_table_class("CommonDenseTable"); + dense_table_proto->set_shard_num(256); + dense_table_proto->set_type(::paddle::distributed::PS_DENSE_TABLE); + ::paddle::distributed::TableAccessorParameter* accessor_proto = + dense_table_proto->mutable_accessor(); + ::paddle::distributed::CommonAccessorParameter* common_proto = + dense_table_proto->mutable_common(); + + accessor_proto->set_accessor_class("CommMergeAccessor"); + accessor_proto->set_fea_dim(100); + accessor_proto->set_embedx_dim(1); + + common_proto->set_name("sgd"); + common_proto->set_table_name("MergedDense"); + common_proto->set_trainer_num(1); + common_proto->set_sync(false); + common_proto->add_params("Param"); + common_proto->add_dims(100); + common_proto->add_initializers("fill_constant&1.0"); + common_proto->add_params("LearningRate"); + common_proto->add_dims(1); + common_proto->add_initializers("fill_constant&1.0"); +} + +::paddle::distributed::PSParameter GetServerProto() { + // Generate server proto desc + ::paddle::distributed::PSParameter server_fleet_desc; + ::paddle::distributed::ServerParameter* server_proto = + server_fleet_desc.mutable_server_param(); + ::paddle::distributed::DownpourServerParameter* downpour_server_proto = + server_proto->mutable_downpour_server_param(); + ::paddle::distributed::ServerServiceParameter* server_service_proto = + downpour_server_proto->mutable_service_param(); + server_service_proto->set_service_class("PsService"); + server_service_proto->set_server_class("BrpcPsServer"); + server_service_proto->set_client_class("BrpcPsClient"); + server_service_proto->set_start_server_port(0); + server_service_proto->set_server_thread_num(12); + + ::paddle::distributed::TableParameter* dense_table_proto = + downpour_server_proto->add_downpour_table_param(); + GetDownpourDenseTableProto(dense_table_proto); + return server_fleet_desc; +} + +::paddle::distributed::PSParameter GetWorkerProto() { + ::paddle::distributed::PSParameter worker_fleet_desc; + ::paddle::distributed::WorkerParameter* worker_proto = + worker_fleet_desc.mutable_worker_param(); + + ::paddle::distributed::DownpourWorkerParameter* downpour_worker_proto = + worker_proto->mutable_downpour_worker_param(); + + ::paddle::distributed::TableParameter* worker_dense_table_proto = + downpour_worker_proto->add_downpour_table_param(); + GetDownpourDenseTableProto(worker_dense_table_proto); + + ::paddle::distributed::ServerParameter* server_proto = + worker_fleet_desc.mutable_server_param(); + ::paddle::distributed::DownpourServerParameter* downpour_server_proto = + server_proto->mutable_downpour_server_param(); + ::paddle::distributed::ServerServiceParameter* server_service_proto = + downpour_server_proto->mutable_service_param(); + server_service_proto->set_service_class("PsService"); + server_service_proto->set_server_class("BrpcPsServer"); + server_service_proto->set_client_class("BrpcPsClient"); + server_service_proto->set_start_server_port(0); + server_service_proto->set_server_thread_num(12); + + ::paddle::distributed::TableParameter* server_dense_table_proto = + downpour_server_proto->add_downpour_table_param(); + GetDownpourDenseTableProto(server_dense_table_proto); + + return worker_fleet_desc; +} + +/*-------------------------------------------------------------------------*/ + +std::string ip_ = "127.0.0.1"; +uint32_t port_ = 4214; + +std::vector host_sign_list_; + +std::shared_ptr pserver_ptr_; + +std::shared_ptr worker_ptr_; + +void RunServer() { + ::paddle::distributed::PSParameter server_proto = GetServerProto(); + + auto _ps_env = paddle::distributed::PaddlePSEnvironment(); + LOG(INFO) << "RUN set_ps_servers"; + _ps_env.set_ps_servers(&host_sign_list_, 1); + pserver_ptr_ = std::shared_ptr( + paddle::distributed::PSServerFactory::create(server_proto)); + LOG(INFO) << "RUN configure"; + pserver_ptr_->configure(server_proto, _ps_env, 0); + LOG(INFO) << "RUN start"; + pserver_ptr_->start(ip_, port_); + LOG(INFO) << "End start"; +} + +void RunClient(std::map>& + dense_regions) { + ::paddle::distributed::PSParameter worker_proto = GetWorkerProto(); + paddle::distributed::PaddlePSEnvironment _ps_env; + auto servers_ = host_sign_list_.size(); + _ps_env = paddle::distributed::PaddlePSEnvironment(); + LOG(INFO) << "Run set_ps_servers"; + _ps_env.set_ps_servers(&host_sign_list_, servers_); + LOG(INFO) << "Run Create PSClient"; + worker_ptr_ = std::shared_ptr( + paddle::distributed::PSClientFactory::create(worker_proto)); + LOG(INFO) << "Run configure"; + worker_ptr_->configure(worker_proto, dense_regions, _ps_env, 0); +} + +void RunBrpcPushDense() { + setenv("http_proxy", "", 1); + setenv("https_proxy", "", 1); + auto ph_host = paddle::distributed::PSHost(ip_, port_, 0); + host_sign_list_.push_back(ph_host.serialize_to_string()); + + // Srart Server + std::thread server_thread(RunServer); + sleep(1); + + // Start Client + LOG(INFO) << "Run InitTensorsOnClient"; + framework::Scope client_scope; + platform::CPUPlace place; + InitTensorsOnClient(&client_scope, &place, 100); + std::map> dense_regions; + dense_regions.insert( + std::pair>(0, {})); + auto regions = dense_regions[0]; + framework::Variable* var = client_scope.FindVar("x"); + framework::LoDTensor* tensor = var->GetMutable(); + float* w = tensor->data(); + paddle::distributed::Region reg(w, tensor->numel()); + regions.emplace_back(std::move(reg)); + + LOG(INFO) << "Run RunClient"; + RunClient(dense_regions); + + /*-----------------------Test Server Init----------------------------------*/ + LOG(INFO) << "Run pull_dense_param"; + float* temp = new float[tensor->numel()](); + std::vector temp_region; + paddle::distributed::Region temp_reg(temp, tensor->numel()); + temp_region.emplace_back(std::move(temp_reg)); + auto pull_status = + worker_ptr_->pull_dense(temp_region.data(), temp_region.size(), 0); + pull_status.wait(); + + for (size_t idx = 0; idx < tensor->numel(); ++idx) { + EXPECT_FLOAT_EQ(temp[idx], 1.0); + } + + /*-----------------------Test Push Param----------------------------------*/ + + LOG(INFO) << "Run push_dense_param"; + auto push_status = + worker_ptr_->push_dense_param(regions.data(), regions.size(), 0); + push_status.wait(); + + pull_status = worker_ptr_->pull_dense(regions.data(), regions.size(), 0); + pull_status.wait(); + + for (size_t idx = 0; idx < tensor->numel(); ++idx) { + EXPECT_FLOAT_EQ(w[idx], float(idx)); + } + + /*-----------------------Test Push Grad----------------------------------*/ + + paddle::distributed::DownpourBrpcClosure* closure = + new paddle::distributed::DownpourBrpcClosure(1, [&](void* done) { + int ret = 0; + auto* closure = (paddle::distributed::DownpourBrpcClosure*)done; + for (size_t i = 0; i < 1; ++i) { + if (closure->check_response(i, paddle::PS_PUSH_DENSE_TABLE) != 0) { + ret = -1; + break; + } + } + closure->set_promise_value(ret); + }); + + LOG(INFO) << "Run pull_dense_grad"; + auto push_grad_status = + worker_ptr_->push_dense_raw_gradient(0, temp, tensor->numel(), closure); + push_grad_status.wait(); + + auto pull_update_status = + worker_ptr_->pull_dense(regions.data(), regions.size(), 0); + pull_update_status.wait(); + + for (size_t idx = 0; idx < tensor->numel(); ++idx) { + EXPECT_FLOAT_EQ(w[idx], float(idx) - 1.0); + } + + LOG(INFO) << "Run stop_server"; + worker_ptr_->stop_server(); + LOG(INFO) << "Run finalize_worker"; + worker_ptr_->finalize_worker(); + server_thread.join(); +} + +TEST(RunBrpcPushDense, Run) { RunBrpcPushDense(); } diff --git a/paddle/fluid/distributed/test/brpc_service_sparse_sgd_test.cc b/paddle/fluid/distributed/test/brpc_service_sparse_sgd_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..224b9ba2fc780a217bbe4a007d624d0f7afcedf0 --- /dev/null +++ b/paddle/fluid/distributed/test/brpc_service_sparse_sgd_test.cc @@ -0,0 +1,285 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include // NOLINT +#include +#include // NOLINT + +#include "google/protobuf/text_format.h" +#include "gtest/gtest.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/framework/variable.h" + +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/string/printf.h" + +#include "paddle/fluid/distributed/ps.pb.h" +#include "paddle/fluid/distributed/service/brpc_ps_client.h" +#include "paddle/fluid/distributed/service/brpc_ps_server.h" +#include "paddle/fluid/distributed/service/env.h" +#include "paddle/fluid/distributed/service/ps_client.h" +#include "paddle/fluid/distributed/service/sendrecv.pb.h" +#include "paddle/fluid/distributed/service/service.h" + +namespace framework = paddle::framework; +namespace platform = paddle::platform; +namespace operators = paddle::operators; +namespace math = paddle::operators::math; +namespace memory = paddle::memory; +namespace distributed = paddle::distributed; + +void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) { + auto x_var = scope->Var("x"); + x_var->GetMutable(); +} + +void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place, + int64_t rows_numel) { + CreateVarsOnScope(scope, place); + + auto x_var = scope->Var("x")->GetMutable(); + float* x_ptr = + x_var->mutable_data(framework::DDim({1, rows_numel}), *place); + for (int64_t i = 0; i < rows_numel; ++i) x_ptr[i] = 1.0; +} + +void GetDownpourSparseTableProto( + ::paddle::distributed::TableParameter* sparse_table_proto) { + sparse_table_proto->set_table_id(0); + sparse_table_proto->set_table_class("CommonSparseTable"); + sparse_table_proto->set_shard_num(256); + sparse_table_proto->set_type(::paddle::distributed::PS_SPARSE_TABLE); + ::paddle::distributed::TableAccessorParameter* accessor_proto = + sparse_table_proto->mutable_accessor(); + ::paddle::distributed::CommonAccessorParameter* common_proto = + sparse_table_proto->mutable_common(); + + accessor_proto->set_accessor_class("CommMergeAccessor"); + accessor_proto->set_fea_dim(0); + accessor_proto->set_embedx_dim(10); + + common_proto->set_name("sgd"); + common_proto->set_table_name("MergedDense"); + common_proto->set_trainer_num(1); + common_proto->set_sync(false); + common_proto->add_params("Param"); + common_proto->add_dims(10); + common_proto->add_initializers("uniform_random&0&-1.0&1.0"); + common_proto->add_params("LearningRate"); + common_proto->add_dims(1); + common_proto->add_initializers("fill_constant&1.0"); +} + +::paddle::distributed::PSParameter GetServerProto() { + // Generate server proto desc + ::paddle::distributed::PSParameter server_fleet_desc; + ::paddle::distributed::ServerParameter* server_proto = + server_fleet_desc.mutable_server_param(); + ::paddle::distributed::DownpourServerParameter* downpour_server_proto = + server_proto->mutable_downpour_server_param(); + ::paddle::distributed::ServerServiceParameter* server_service_proto = + downpour_server_proto->mutable_service_param(); + server_service_proto->set_service_class("PsService"); + server_service_proto->set_server_class("BrpcPsServer"); + server_service_proto->set_client_class("BrpcPsClient"); + server_service_proto->set_start_server_port(0); + server_service_proto->set_server_thread_num(12); + + ::paddle::distributed::TableParameter* sparse_table_proto = + downpour_server_proto->add_downpour_table_param(); + GetDownpourSparseTableProto(sparse_table_proto); + return server_fleet_desc; +} + +::paddle::distributed::PSParameter GetWorkerProto() { + ::paddle::distributed::PSParameter worker_fleet_desc; + ::paddle::distributed::WorkerParameter* worker_proto = + worker_fleet_desc.mutable_worker_param(); + + ::paddle::distributed::DownpourWorkerParameter* downpour_worker_proto = + worker_proto->mutable_downpour_worker_param(); + + ::paddle::distributed::TableParameter* worker_sparse_table_proto = + downpour_worker_proto->add_downpour_table_param(); + GetDownpourSparseTableProto(worker_sparse_table_proto); + + ::paddle::distributed::ServerParameter* server_proto = + worker_fleet_desc.mutable_server_param(); + ::paddle::distributed::DownpourServerParameter* downpour_server_proto = + server_proto->mutable_downpour_server_param(); + ::paddle::distributed::ServerServiceParameter* server_service_proto = + downpour_server_proto->mutable_service_param(); + server_service_proto->set_service_class("PsService"); + server_service_proto->set_server_class("BrpcPsServer"); + server_service_proto->set_client_class("BrpcPsClient"); + server_service_proto->set_start_server_port(0); + server_service_proto->set_server_thread_num(12); + + ::paddle::distributed::TableParameter* server_sparse_table_proto = + downpour_server_proto->add_downpour_table_param(); + GetDownpourSparseTableProto(server_sparse_table_proto); + + return worker_fleet_desc; +} + +/*-------------------------------------------------------------------------*/ + +std::string ip_ = "127.0.0.1"; +uint32_t port_ = 4209; + +std::vector host_sign_list_; + +std::shared_ptr pserver_ptr_; + +std::shared_ptr worker_ptr_; + +void RunServer() { + ::paddle::distributed::PSParameter server_proto = GetServerProto(); + + auto _ps_env = paddle::distributed::PaddlePSEnvironment(); + _ps_env.set_ps_servers(&host_sign_list_, 1); + pserver_ptr_ = std::shared_ptr( + paddle::distributed::PSServerFactory::create(server_proto)); + pserver_ptr_->configure(server_proto, _ps_env, 0); + pserver_ptr_->start(ip_, port_); +} + +void RunClient(std::map>& + dense_regions) { + ::paddle::distributed::PSParameter worker_proto = GetWorkerProto(); + paddle::distributed::PaddlePSEnvironment _ps_env; + auto servers_ = host_sign_list_.size(); + _ps_env = paddle::distributed::PaddlePSEnvironment(); + _ps_env.set_ps_servers(&host_sign_list_, servers_); + worker_ptr_ = std::shared_ptr( + paddle::distributed::PSClientFactory::create(worker_proto)); + worker_ptr_->configure(worker_proto, dense_regions, _ps_env, 0); +} + +void RunBrpcPushSparse() { + setenv("http_proxy", "", 1); + setenv("https_proxy", "", 1); + auto ph_host = paddle::distributed::PSHost(ip_, port_, 0); + host_sign_list_.push_back(ph_host.serialize_to_string()); + + // Srart Server + std::thread server_thread(RunServer); + sleep(1); + + // Start Client + framework::Scope client_scope; + platform::CPUPlace place; + InitTensorsOnClient(&client_scope, &place, 100); + std::map> dense_regions; + dense_regions.insert( + std::pair>(0, {})); + auto regions = dense_regions[0]; + framework::Variable* var = client_scope.FindVar("x"); + framework::LoDTensor* tensor = var->GetMutable(); + + RunClient(dense_regions); + std::vector fea_keys(10); + std::vector fea_values(100); + std::vector fea_temp_values(100); + std::vector fea_value_ptr(10); + std::vector fea_temp_value_ptr(10); + + for (size_t idx = 0; idx < fea_keys.size(); ++idx) { + fea_keys[idx] = (uint64_t)idx; + fea_value_ptr[idx] = fea_values.data() + idx * 10; + fea_temp_value_ptr[idx] = fea_temp_values.data() + idx * 10; + } + + /*-----------------------Test Server Init----------------------------------*/ + LOG(INFO) << "Run pull_sparse_param"; + auto pull_status = worker_ptr_->pull_sparse(fea_value_ptr.data(), 0, + fea_keys.data(), fea_keys.size()); + pull_status.wait(); + for (size_t idx = 0; idx < tensor->numel(); ++idx) { + fea_values.data()[idx] *= 2.0; + } + + /*-----------------------Test Push Param----------------------------------*/ + + LOG(INFO) << "Run push_sparse_param"; + paddle::distributed::DownpourBrpcClosure* closure_push_param = + new paddle::distributed::DownpourBrpcClosure(1, [&](void* done) { + int ret = 0; + auto* closure = (paddle::distributed::DownpourBrpcClosure*)done; + for (size_t i = 0; i < 1; ++i) { + if (closure->check_response(i, paddle::PS_PUSH_SPARSE_PARAM) != 0) { + ret = -1; + break; + } + } + closure->set_promise_value(ret); + }); + auto push_status = worker_ptr_->push_sparse_param( + 0, fea_keys.data(), (const float**)fea_value_ptr.data(), fea_keys.size(), + closure_push_param); + push_status.wait(); + + auto pull_param_status = worker_ptr_->pull_sparse( + fea_temp_value_ptr.data(), 0, fea_keys.data(), fea_keys.size()); + pull_param_status.wait(); + + for (size_t idx = 0; idx < tensor->numel(); ++idx) { + EXPECT_FLOAT_EQ(fea_temp_values[idx], fea_values[idx]); + } + + /*-----------------------Test Push Grad----------------------------------*/ + + paddle::distributed::DownpourBrpcClosure* closure_push_grad = + new paddle::distributed::DownpourBrpcClosure(1, [&](void* done) { + int ret = 0; + auto* closure = (paddle::distributed::DownpourBrpcClosure*)done; + for (size_t i = 0; i < 1; ++i) { + if (closure->check_response(i, paddle::PS_PUSH_SPARSE_TABLE) != 0) { + ret = -1; + break; + } + } + closure->set_promise_value(ret); + }); + + LOG(INFO) << "Run pull_sparse_grad"; + std::vector push_g_vec; + for (auto i = 0; i < static_cast(fea_keys.size()); ++i) { + push_g_vec.push_back(tensor->data() + i * 10); + } + auto push_grad_status = worker_ptr_->push_sparse_raw_gradient( + 0, fea_keys.data(), (const float**)push_g_vec.data(), fea_keys.size(), + closure_push_grad); + push_grad_status.wait(); + + auto pull_update_status = worker_ptr_->pull_sparse( + fea_temp_value_ptr.data(), 0, fea_keys.data(), fea_keys.size()); + pull_update_status.wait(); + + for (size_t idx = 0; idx < tensor->numel(); ++idx) { + EXPECT_FLOAT_EQ(fea_temp_values[idx], fea_values[idx] - 1.0); + } + + LOG(INFO) << "Run stop_server"; + worker_ptr_->stop_server(); + LOG(INFO) << "Run finalize_worker"; + worker_ptr_->finalize_worker(); + server_thread.join(); +} + +TEST(RunBrpcPushSparse, Run) { RunBrpcPushSparse(); } diff --git a/paddle/fluid/distributed/test/brpc_utils_test.cc b/paddle/fluid/distributed/test/brpc_utils_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ce33cbe6ea39713589eb8f201d95bb2d99e5ff0c --- /dev/null +++ b/paddle/fluid/distributed/test/brpc_utils_test.cc @@ -0,0 +1,141 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include // NOLINT + +#include "google/protobuf/text_format.h" +#include "gtest/gtest.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/framework/variable.h" + +#include "paddle/fluid/distributed/service/brpc_utils.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/string/printf.h" + +namespace framework = paddle::framework; +namespace platform = paddle::platform; +namespace operators = paddle::operators; +namespace math = paddle::operators::math; +namespace memory = paddle::memory; +namespace distributed = paddle::distributed; + +void CreateVarsOnScope(framework::Scope* scope, platform::Place* place, + const platform::DeviceContext& ctx) { + // var 1 + framework::Variable* var1 = scope->Var("x1"); + auto* tensor1 = var1->GetMutable(); + tensor1->Resize(framework::make_ddim({512, 8, 4, 2})); + framework::LoD lod1; + lod1.push_back(framework::Vector({1, 3, 8})); + tensor1->set_lod(lod1); + tensor1->mutable_data(*place); + math::set_constant(ctx, tensor1, 31.9); + + // var 2 + framework::Variable* var2 = scope->Var("x2"); + auto* tensor2 = var2->GetMutable(); + tensor2->Resize(framework::make_ddim({1000, 64})); + framework::LoD lod2; + lod2.push_back(framework::Vector({1, 1})); + tensor2->set_lod(lod2); + tensor2->mutable_data(*place); + math::set_constant(ctx, tensor2, 100); + + // var 3 + framework::Variable* var3 = scope->Var("x3"); + auto* slr = var3->GetMutable(); + slr->set_height(564); + auto* tensor3 = slr->mutable_value(); + auto* rows = slr->mutable_rows(); + tensor3->Resize(framework::make_ddim({564, 128})); + tensor3->mutable_data(*place); + math::set_constant(ctx, tensor3, 32.7); + for (int i = 0; i < 564; ++i) rows->push_back(i); +} + +void RunMultiVarMsg(platform::Place place) { + framework::Scope scope; + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto& ctx = *pool.Get(place); + CreateVarsOnScope(&scope, &place, ctx); + + ::paddle::MultiVariableMessage multi_msg; + std::string message_name("se_de_test"); + std::vector send_var_name = {"x1", "x2", "x3"}; + std::vector recv_var_name = {}; + LOG(INFO) << "begin SerializeToMultiVarMsg"; + + butil::IOBuf io_buf; + distributed::SerializeToMultiVarMsgAndIOBuf(message_name, send_var_name, + recv_var_name, ctx, &scope, + &multi_msg, &io_buf); + EXPECT_GT(multi_msg.ByteSizeLong(), static_cast(0)); + + // deserialize + framework::Scope scope_recv; + LOG(INFO) << "begin DeserializeFromMultiVarMsg"; + distributed::DeserializeFromMultiVarMsgAndIOBuf(multi_msg, &io_buf, ctx, + &scope_recv); + + // check var1 + framework::Variable* var1 = scope_recv.FindVar("x1"); + auto* tensor1 = var1->GetMutable(); + EXPECT_EQ(tensor1->dims(), framework::make_ddim({512, 8, 4, 2})); + // EXPECT_EQ(tensor1->lod(), framework::Vector({1, 3, 8})); + auto* tensor_data1 = const_cast(tensor1->data()); + int tensor_numel1 = 512 * 8 * 4 * 2; + for (int i = 0; i < tensor_numel1; ++i) + EXPECT_FLOAT_EQ(tensor_data1[i], 31.9); + + // check var2 + framework::Variable* var2 = scope_recv.FindVar("x2"); + auto* tensor2 = var2->GetMutable(); + EXPECT_EQ(tensor2->dims(), framework::make_ddim({1000, 64})); + // EXPECT_EQ(tensor2->lod(), framework::Vector({1, 1})); + auto* tensor_data2 = const_cast(tensor2->data()); + int tensor_numel2 = 1000 * 64; + for (int i = 0; i < tensor_numel2; ++i) EXPECT_EQ(tensor_data2[i], 100); + + // check var3 + framework::Variable* var3 = scope_recv.FindVar("x3"); + auto* slr = var3->GetMutable(); + EXPECT_EQ(slr->rows().size(), 564); + for (int i = 0; i < 564; ++i) { + EXPECT_EQ(slr->rows()[i], i); + } + + auto* tensor3 = slr->mutable_value(); + EXPECT_EQ(tensor3->dims(), framework::make_ddim({564, 128})); + auto* tensor_data3 = const_cast(tensor3->data()); + int tensor_numel3 = 564 * 128; + for (int i = 0; i < tensor_numel3; ++i) + EXPECT_FLOAT_EQ(tensor_data3[i], 32.7); +} + +TEST(MultiVarMsgCPU, Run) { + platform::CPUPlace place; + RunMultiVarMsg(place); +} + +// #ifdef PADDLE_WITH_CUDA +// TEST(MultiVarMsgGPU, Run) { +// platform::CUDAPlace place; +// RunMultiVarMsg(place); +// } +// #endif \ No newline at end of file diff --git a/paddle/fluid/distributed/test/dense_table_test.cc b/paddle/fluid/distributed/test/dense_table_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..75f9df168961fa337d6f405575aeb078a4c8ee6b --- /dev/null +++ b/paddle/fluid/distributed/test/dense_table_test.cc @@ -0,0 +1,195 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#include +#include +#include // NOLINT + +#include "google/protobuf/text_format.h" +#include "gtest/gtest.h" +#include "paddle/fluid/distributed/ps.pb.h" +#include "paddle/fluid/distributed/table/common_dense_table.h" +#include "paddle/fluid/distributed/table/common_sparse_table.h" +#include "paddle/fluid/distributed/table/sparse_geo_table.h" +#include "paddle/fluid/distributed/table/table.h" + +namespace paddle { +namespace distributed { + +// CommonDenseTable + Adam +TEST(CommonDenseTable, Adam) { + int fea_dim = 10; + int trainers = 2; + float beta1 = 0.9; + float beta2 = 0.999; + float epsilon = 1.0e-8; + + TableParameter table_config; + table_config.set_table_class("CommonDenseTable"); + FsClientParameter fs_config; + Table *table = new CommonDenseTable(); + TableAccessorParameter *accessor_config = table_config.mutable_accessor(); + accessor_config->set_accessor_class("CommMergeAccessor"); + CommonAccessorParameter *common_config = table_config.mutable_common(); + // set adam optimize config + common_config->set_name("adam"); + common_config->set_table_name("adam_test_table"); + common_config->set_trainer_num(trainers); + common_config->add_params("Param"); + common_config->add_dims(fea_dim); + common_config->add_initializers("gaussian_random&0&0.0&1.0"); + common_config->add_params("LearningRate"); + common_config->add_dims(1); + common_config->add_initializers("fill_constant&1.0"); + common_config->add_params("Moment1"); + common_config->add_dims(fea_dim); + common_config->add_initializers("fill_constant&0.0"); + common_config->add_params("Moment2"); + common_config->add_dims(fea_dim); + common_config->add_initializers("fill_constant&0.0"); + common_config->add_params("Beta1Pow"); + common_config->add_dims(1); + common_config->add_initializers("fill_constant&1.0"); + common_config->add_params("Beta2Pow"); + common_config->add_dims(1); + common_config->add_initializers("fill_constant&1.0"); + auto ret = table->initialize(table_config, fs_config); + ASSERT_EQ(ret, 0); + + // pull parameters for create and check + std::vector init_values; + init_values.resize(fea_dim); + table->pull_dense(init_values.data(), fea_dim); + + // push gradient + std::vector> trainer_gradient_values; + trainer_gradient_values.resize(trainers); + float start = 10.0; + for (int i = 0; i < trainers; i++) { + for (int k = 0; k < fea_dim; k++) { + trainer_gradient_values[i].push_back(start); + start += 0.1; + } + } + + // for adam + for (int i = 0; i < trainers; i++) { + auto &push_values = trainer_gradient_values[i]; + table->push_dense(push_values.data(), push_values.size()); + } + + std::vector pull_values; + pull_values.resize(fea_dim); + table->pull_dense(pull_values.data(), fea_dim); + + std::vector beta1_pow, beta2_pow, lr, mom1, mom2, param; + beta1_pow.push_back(beta1); + beta2_pow.push_back(beta2); + lr.push_back(1.0); + for (int i = 0; i < fea_dim; i++) { + mom1.push_back(0.0); + mom2.push_back(0.0); + param.push_back(init_values[i]); + } + + for (int i = 0; i < trainers; i++) { + auto lr_ = lr[0] * sqrt(1 - beta2_pow[0]) / (1 - beta1_pow[0]); + for (int j = 0; j < fea_dim; j++) { + mom1[j] = beta1 * mom1[j] + (1 - beta1) * trainer_gradient_values[i][j]; + mom2[j] = beta2 * mom2[j] + + (1 - beta2) * trainer_gradient_values[i][j] * + trainer_gradient_values[i][j]; + param[j] = + param[j] - + lr_ * (mom1[j] / (sqrt(mom2[j]) + epsilon * sqrt(1 - beta2_pow[0]))); + } + beta1_pow[0] *= beta1; + beta2_pow[0] *= beta2; + } + for (int j = 0; j < fea_dim; j++) { + ASSERT_TRUE(abs(param[j] - pull_values[j]) < 1e-6); + } +} + +// CommonDenseTable + Adam +TEST(CommonDenseTable, SGD) { + int fea_dim = 10; + int trainers = 2; + + TableParameter table_config; + table_config.set_table_class("CommonDenseTable"); + FsClientParameter fs_config; + Table *table = new CommonDenseTable(); + TableAccessorParameter *accessor_config = table_config.mutable_accessor(); + accessor_config->set_accessor_class("CommMergeAccessor"); + CommonAccessorParameter *common_config = table_config.mutable_common(); + common_config->set_name("sgd"); + common_config->set_table_name("sgd_test_table"); + common_config->set_trainer_num(trainers); + common_config->add_params("Param"); + common_config->add_dims(fea_dim); + common_config->add_initializers("gaussian_random&0&0.0&1.0"); + common_config->add_params("LearningRate"); + common_config->add_dims(1); + common_config->add_initializers("fill_constant&1.0"); + auto ret = table->initialize(table_config, fs_config); + ASSERT_EQ(ret, 0); + + // pull parameters for create and check + std::vector init_values; + init_values.resize(fea_dim); + table->pull_dense(init_values.data(), fea_dim); + + std::vector total_gradients; + total_gradients.resize(fea_dim); + memset(total_gradients.data(), 0, sizeof(float) * total_gradients.size()); + // push gradient + std::vector> trainer_gradient_values; + trainer_gradient_values.resize(trainers); + float start = 10.0; + for (int i = 0; i < trainers; i++) { + for (int k = 0; k < fea_dim; k++) { + trainer_gradient_values[i].push_back(start); + total_gradients[k] += start; + start += 0.1; + } + } + + std::shared_ptr<::ThreadPool> pool_ = + std::make_shared<::ThreadPool>(trainers); + std::vector> task_status; + for (int i = 0; i < trainers; i++) { + auto &push_values = trainer_gradient_values[i]; + auto task = [table, &push_values] { + table->push_dense(push_values.data(), push_values.size()); + }; + task_status.push_back(pool_->enqueue(std::move(task))); + } + for (auto &status : task_status) { + status.wait(); + } + + std::vector pull_values; + pull_values.resize(fea_dim); + table->pull_dense(pull_values.data(), fea_dim); + for (int j = 0; j < fea_dim; j++) { + auto update_val = init_values[j] - 1.0 * total_gradients[j]; + ASSERT_TRUE(abs(update_val - pull_values[j]) < 1e-5); + } +} + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/test/geo_table_test.cc b/paddle/fluid/distributed/test/geo_table_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5ec1e87dcb693827cbf1aa2024706b8fcd0b7dc9 --- /dev/null +++ b/paddle/fluid/distributed/test/geo_table_test.cc @@ -0,0 +1,119 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#include +#include +#include // NOLINT + +#include "google/protobuf/text_format.h" +#include "gtest/gtest.h" +#include "paddle/fluid/distributed/ps.pb.h" +#include "paddle/fluid/distributed/table/common_dense_table.h" +#include "paddle/fluid/distributed/table/common_sparse_table.h" +#include "paddle/fluid/distributed/table/sparse_geo_table.h" +#include "paddle/fluid/distributed/table/table.h" + +namespace paddle { +namespace distributed { + +// SparseGeoTable + SSUM +TEST(SparseGeoTable, SSUM) { + int emb_dim = 10; + int trainers = 2; + + TableParameter table_config; + table_config.set_table_class("SparseGeoTable"); + FsClientParameter fs_config; + Table *table = new SparseGeoTable(); + TableAccessorParameter *accessor_config = table_config.mutable_accessor(); + accessor_config->set_accessor_class("CommMergeAccessor"); + CommonAccessorParameter *common_config = table_config.mutable_common(); + common_config->set_name("sum"); + common_config->set_table_name("ssum_test_table"); + common_config->set_trainer_num(trainers); + common_config->add_params("Param"); + common_config->add_dims(emb_dim); + common_config->add_initializers("fill_constant&1.0"); + + auto ret = table->initialize(table_config, fs_config); + ASSERT_EQ(ret, 0); + + // test push_sparse_param, and create params + std::vector init_keys = {0, 1, 2, 3, 4}; + std::vector init_values; + for (size_t i = 0; i < init_keys.size() * emb_dim; i++) { + init_values.push_back(0.0); + } + table->push_sparse_param(init_keys.data(), init_values.data(), + init_keys.size()); + std::vector pull_values(init_values.size()); + table->pull_sparse(pull_values.data(), init_keys.data(), init_keys.size()); + for (size_t i = 0; i < init_keys.size() * emb_dim; i++) { + ASSERT_TRUE(abs(pull_values[i] - init_values[i]) < 1e-6); + } + + std::vector> trainer_keys; + std::vector> trainer_values; + trainer_keys.resize(trainers); + trainer_values.resize(trainers); + float start = 0.0; + for (int i = 0; i < trainers; i++) { + trainer_keys[i] = init_keys; + for (size_t j = 0; j < trainer_keys[i].size(); j++) { + auto id = trainer_keys[i][j]; + for (int k = 0; k < emb_dim; k++) { + trainer_values[i].push_back(start); + pull_values[id * emb_dim + k] += start; + start += 0.1; + } + } + } + + std::shared_ptr<::ThreadPool> pool_ = + std::make_shared<::ThreadPool>(trainers); + std::vector> task_status; + for (int i = 0; i < trainers; i++) { + auto &push_keys = trainer_keys[i]; + auto &push_values = trainer_values[i]; + auto task = [table, &push_keys, &push_values] { + table->push_sparse(push_keys.data(), push_values.data(), + push_keys.size()); + }; + task_status.push_back(pool_->enqueue(std::move(task))); + } + for (auto &status : task_status) { + status.wait(); + } + + std::vector> geo_pull_ids; + std::vector> geo_pull_values; + geo_pull_ids.resize(trainers); + geo_pull_values.resize(trainers); + for (int i = 0; i < trainers; i++) { + table->pull_geo_param(i, &geo_pull_values[i], &geo_pull_ids[i]); + ASSERT_EQ(geo_pull_values[i].size(), geo_pull_ids[i].size() * emb_dim); + for (size_t j = 0; j < geo_pull_ids[i].size(); ++j) { + auto id = geo_pull_ids[i][j]; + for (int k = 0; k < emb_dim; k++) { + ASSERT_TRUE(abs(geo_pull_values[i][j * emb_dim + k] - + pull_values[id * emb_dim + k]) < 1e-5); + } + } + } +} + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/test/sparse_table_test.cc b/paddle/fluid/distributed/test/sparse_table_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..6db95c5fac211b94db726ee77c9122a8824c2351 --- /dev/null +++ b/paddle/fluid/distributed/test/sparse_table_test.cc @@ -0,0 +1,213 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#include +#include +#include // NOLINT + +#include "google/protobuf/text_format.h" +#include "gtest/gtest.h" +#include "paddle/fluid/distributed/ps.pb.h" +#include "paddle/fluid/distributed/table/common_dense_table.h" +#include "paddle/fluid/distributed/table/common_sparse_table.h" +#include "paddle/fluid/distributed/table/sparse_geo_table.h" +#include "paddle/fluid/distributed/table/table.h" + +namespace paddle { +namespace distributed { + +// CommonSparseTable + SSGD +TEST(CommonSparseTable, SGD) { + int emb_dim = 10; + int trainers = 2; + + TableParameter table_config; + table_config.set_table_class("CommonSparseTable"); + FsClientParameter fs_config; + Table *table = new CommonSparseTable(); + TableAccessorParameter *accessor_config = table_config.mutable_accessor(); + accessor_config->set_accessor_class("CommMergeAccessor"); + CommonAccessorParameter *common_config = table_config.mutable_common(); + common_config->set_name("sgd"); + common_config->set_table_name("sgd_test_table"); + common_config->set_trainer_num(trainers); + common_config->add_params("Param"); + common_config->add_dims(emb_dim); + common_config->add_initializers("uniform_random&0&-1.0&1.0"); // param + common_config->add_params("LearningRate"); + common_config->add_dims(1); + common_config->add_initializers("fill_constant&1.0"); // learning_rate + auto ret = table->initialize(table_config, fs_config); + ASSERT_EQ(ret, 0); + + // pull parameters for create and check + std::vector init_keys = {0, 1, 2, 3, 4}; + std::vector init_values; + init_values.resize(init_keys.size() * emb_dim); + table->pull_sparse(init_values.data(), init_keys.data(), init_keys.size()); + + // for check + std::vector total_gradients; + total_gradients.resize(init_keys.size() * emb_dim); + memset(total_gradients.data(), 0, sizeof(float) * total_gradients.size()); + + // push gradient + std::vector> trainer_keys; + std::vector> trainer_gradient_values; + trainer_keys.resize(trainers); + trainer_gradient_values.resize(trainers); + float start = 0.0; + for (int i = 0; i < trainers; i++) { + trainer_keys[i] = init_keys; + for (size_t j = 0; j < trainer_keys[i].size(); j++) { + auto id = trainer_keys[i][j]; + for (int k = 0; k < emb_dim; k++) { + trainer_gradient_values[i].push_back(start); + total_gradients[id * emb_dim + k] += start; + start += 0.1; + } + } + } + + std::shared_ptr<::ThreadPool> pool_ = + std::make_shared<::ThreadPool>(trainers); + std::vector> task_status; + for (int i = 0; i < trainers; i++) { + auto &push_keys = trainer_keys[i]; + auto &push_values = trainer_gradient_values[i]; + auto task = [table, &push_keys, &push_values] { + table->push_sparse(push_keys.data(), push_values.data(), + push_keys.size()); + }; + task_status.push_back(pool_->enqueue(std::move(task))); + } + for (auto &status : task_status) { + status.wait(); + } + + std::vector pull_values; + pull_values.resize(init_keys.size() * emb_dim); + table->pull_sparse(pull_values.data(), init_keys.data(), init_keys.size()); + for (size_t i = 0; i < init_values.size(); ++i) { + auto update_val = init_values[i] - 1.0 * total_gradients[i]; + ASSERT_TRUE(abs(update_val - pull_values[i]) < 1e-5); + } +} + +// CommonSparseTable + Adam +TEST(CommonSparseTable, Adam) { + int emb_dim = 10; + int trainers = 2; + float beta1 = 0.9; + float beta2 = 0.999; + float epsilon = 1.0e-8; + + TableParameter table_config; + table_config.set_table_class("CommonSparseTable"); + FsClientParameter fs_config; + Table *table = new CommonSparseTable(); + TableAccessorParameter *accessor_config = table_config.mutable_accessor(); + accessor_config->set_accessor_class("CommMergeAccessor"); + CommonAccessorParameter *common_config = table_config.mutable_common(); + common_config->set_name("adam"); + common_config->set_table_name("adam_test_table"); + common_config->set_trainer_num(trainers); + common_config->add_params("Param"); + common_config->add_dims(emb_dim); + common_config->add_initializers("uniform_random&0&-1.0&1.0"); + common_config->add_params("LearningRate"); + common_config->add_dims(1); + common_config->add_initializers("fill_constant&1.0"); + common_config->add_params("Moment1"); + common_config->add_dims(emb_dim); + common_config->add_initializers("fill_constant&0.0"); + common_config->add_params("Moment2"); + common_config->add_dims(emb_dim); + common_config->add_initializers("fill_constant&0.0"); + common_config->add_params("Beta1Pow"); + common_config->add_dims(1); + common_config->add_initializers("fill_constant&1.0"); + common_config->add_params("Beta2Pow"); + common_config->add_dims(1); + common_config->add_initializers("fill_constant&1.0"); + auto ret = table->initialize(table_config, fs_config); + ASSERT_EQ(ret, 0); + + // pull parameters for create and check + std::vector init_keys = {0, 1, 2, 3, 4}; + std::vector init_values; + init_values.resize(init_keys.size() * emb_dim); + table->pull_sparse(init_values.data(), init_keys.data(), init_keys.size()); + + // push gradient + std::vector> trainer_keys; + std::vector> trainer_gradient_values; + trainer_keys.resize(trainers); + trainer_gradient_values.resize(trainers); + float start = 0.0; + for (int i = 0; i < trainers; i++) { + trainer_keys[i] = init_keys; + for (size_t j = 0; j < trainer_keys[i].size(); j++) { + for (int k = 0; k < emb_dim; k++) { + trainer_gradient_values[i].push_back(start); + start += 0.1; + } + } + } + + for (int i = 0; i < trainers; i++) { + auto &push_keys = trainer_keys[i]; + auto &push_values = trainer_gradient_values[i]; + table->push_sparse(push_keys.data(), push_values.data(), push_keys.size()); + } + + std::vector pull_values; + pull_values.resize(init_keys.size() * emb_dim); + table->pull_sparse(pull_values.data(), init_keys.data(), init_keys.size()); + + for (size_t idx = 0; idx < init_keys.size(); idx += emb_dim) { + std::vector beta1_pow, beta2_pow, lr, mom1, mom2, param; + beta1_pow.push_back(beta1); + beta2_pow.push_back(beta2); + lr.push_back(1.0); + for (int i = 0; i < emb_dim; i++) { + mom1.push_back(0.0); + mom2.push_back(0.0); + param.push_back(init_values[idx + i]); + } + for (int i = 0; i < trainers; i++) { + auto lr_ = lr[0] * sqrt(1 - beta2_pow[0]) / (1 - beta1_pow[0]); + for (int j = 0; j < emb_dim; j++) { + mom1[j] = + beta1 * mom1[j] + (1 - beta1) * trainer_gradient_values[i][idx + j]; + mom2[j] = beta2 * mom2[j] + + (1 - beta2) * trainer_gradient_values[i][idx + j] * + trainer_gradient_values[i][idx + j]; + param[j] = param[j] - + lr_ * (mom1[j] / + (sqrt(mom2[j]) + epsilon * sqrt(1 - beta2_pow[0]))); + } + beta1_pow[0] *= beta1; + beta2_pow[0] *= beta2; + } + for (int i = 0; i < emb_dim; i++) { + ASSERT_TRUE(abs(param[i] - pull_values[idx + i]) < 1e-5); + } + } +} + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/test/table_test.cc b/paddle/fluid/distributed/test/table_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..98d52c268d77be179bf68349a2b8db702b124416 --- /dev/null +++ b/paddle/fluid/distributed/test/table_test.cc @@ -0,0 +1,42 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#include +#include +#include // NOLINT + +#include "google/protobuf/text_format.h" +#include "gtest/gtest.h" +#include "paddle/fluid/distributed/ps.pb.h" +#include "paddle/fluid/distributed/table/common_dense_table.h" +#include "paddle/fluid/distributed/table/common_sparse_table.h" +#include "paddle/fluid/distributed/table/sparse_geo_table.h" +#include "paddle/fluid/distributed/table/table.h" + +namespace paddle { +namespace distributed { + +TEST(Table, Initialize) { + TableParameter table_config; + table_config.set_table_class("SparseGeoTable"); + FsClientParameter fs_config; + // case 1. no accessor + Table *table = new SparseGeoTable(); + auto ret = table->initialize(table_config, fs_config); + ASSERT_EQ(ret, -1); +} +} // namespace distributed +} // // namespace paddle