diff --git a/mindspore/ccsrc/parallel/ps/parameter_server.h b/mindspore/ccsrc/parallel/ps/parameter_server.h
new file mode 100755
index 0000000000000000000000000000000000000000..4d3aa413060238f1621ee5450c2d01584d89052d
--- /dev/null
+++ b/mindspore/ccsrc/parallel/ps/parameter_server.h
@@ -0,0 +1,559 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_PARAMETER_SERVER_H_
+#define MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_PARAMETER_SERVER_H_
+
+#include <cstdlib>
+#include <unordered_map>
+#include <string>
+#include <iostream>
+#include <memory>
+#include <vector>
+#include <mutex>
+#include <condition_variable>
+#include <thread>
+#include "ir/func_graph.h"
+#include "session/session_basic.h"
+#include "session/kernel_graph.h"
+#include "session/anf_runtime_algorithm.h"
+#include "session/session_factory.h"
+#include "parallel/ps/common.h"
+#include "parallel/ps/optimizer_info.h"
+#include "parallel/ps/optimizer_info_builder.h"
+#include "parallel/ps/util.h"
+#include "device/cpu/kernel_select_cpu.h"
+#include "utils/context/ms_context.h"
+#include "kernel/kernel.h"
+#include "kernel/ps/pserver_kernel.h"
+#include "kernel/cpu/cpu_kernel_factory.h"
+#include "kernel/ps/sparse_apply_adam_ps_kernel.h"
+#include "kernel/ps/sparse_apply_ftrl_ps_kernel.h"
+#include "kernel/ps/apply_momentum_ps_kernel.h"
+#include "kernel/ps/embedding_look_up_ps_kernel.h"
+
+namespace mindspore {
+namespace parallel {
+namespace ps {
+using mindspore::kernel::ps::PServerKernel;
+
+template <typename T>
+class ParameterServer {
+ public:
+  static ParameterServer &GetInstance() {
+    static ParameterServer instance;
+    return instance;
+  }
+
+  void Run(const FuncGraphPtr &func_graph);
+
+ private:
+  ParameterServer()
+      : pserver_num_(0),
+        worker_num_(0),
+        rank_id_(0),
+        grad_accum_count_(0),
+        ps_(new ::ps::KVServer<T>(0)),
+        handler_(nullptr),
+        func_graph_(nullptr),
+        kernel_graph_(nullptr),
+        sess_(nullptr),
+        thread_(nullptr) {}
+  ~ParameterServer() = default;
+  ParameterServer(const ParameterServer &) = delete;
+  ParameterServer &operator=(const ParameterServer &) = delete;
+
+  struct ServerHandler {
+    explicit ServerHandler(ParameterServer *ps) : ps_(ps) {}
+    void operator()(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVServer<T> *server);
+    void HandlePushReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data);
+    void HandlePullReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
+    void HandleInitWeights(const ::ps::KVPairs<T> &req_data);
+    void HandleInitWeightToOptimId(const ::ps::KVPairs<T> &req_data);
+    void HandleInitInputsShape(const ::ps::KVPairs<T> &req_data);
+    void HandleInitEmbeddings(const ::ps::KVPairs<T> &req_data);
+    void HandleEmbeddingLookup(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
+    ParameterServer *ps_;
+  };
+
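+  // The methods below are invoked from the ServerHandler callbacks above.
+  // UpdateWeights() runs on a dedicated thread and synchronizes with
+  // AccumGrad() through apply_grads_cv_ and accum_grads_cv_.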
+  bool Init(const FuncGraphPtr &func_graph);
+  void InitOptimInfoBuilders();
+  void InitWeightKeyToOptims(const Key &key, const int &optim_id);
+  void InitOptimInputsShape(const Keys &keys, const Values &values, const Lengths &lengths);
+  void InitWeight(const Key &key, const WeightPtr &weight);
+  void InitGrad(const Key &key, const GradPtr &grad);
+  void InitEmbeddingTable(const Key &key,
+                          const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes);
+  void UpdateWeights();
+  void AccumGrad(const Keys &keys, const Values &values, const Lengths &lengths);
+  WeightPtr weight(const Key &key);
+  void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, ::ps::KVPairs<T> *res);
+  int SumOfShapes(const std::vector<int> &shapes) const;
+  size_t PreComputeCapacity(const Keys &keys, const Lengths &lens);
+  bool ReadyForUpdateWeights();
+  bool ReadyForAccumGrads();
+  void ResetGradAccumCount();
+
+  size_t pserver_num_;
+  size_t worker_num_;
+  size_t rank_id_;
+  size_t grad_accum_count_;
+  std::unique_ptr<::ps::KVServer<T>> ps_;
+  std::unique_ptr<ServerHandler> handler_;
+  FuncGraphPtr func_graph_;
+  std::shared_ptr<session::KernelGraph> kernel_graph_;
+  std::shared_ptr<session::SessionBasic> sess_;
+
+  std::unordered_map<std::string, std::shared_ptr<PServerKernel>> optimizers_;
+  std::unordered_map<Key, InputsShapePtr> optim_inputs_shape_;
+  std::unordered_map<Key, std::shared_ptr<OptimizerInfo>> optim_infos_;
+  std::unordered_map<std::string, std::shared_ptr<OptimizerInfoBuilder>> optim_info_builders_;
+  std::unordered_map<Key, std::string> weight_key_to_optims_;
+  std::unordered_map<Key, WeightPtr> weights_;
+  std::unordered_map<Key, GradPtr> grads_;
+  std::unordered_map<Key, size_t> grads_accum_counter_;
+  // std::unordered_map<Key, WeightPtr> embeddings_;
+  std::unordered_map<Key, std::shared_ptr<PServerKernel>> embedding_lookup_ops_;
+  std::unordered_map<Key, size_t> embedding_row_lens_;
+
+  T learning_rate_;
+  T momentum_;
+
+  std::mutex mutex_;
+  std::condition_variable apply_grads_cv_;
+  std::condition_variable accum_grads_cv_;
+
+  std::unique_ptr<std::thread> thread_;
+
+  friend struct ServerHandler;
+};
+
+class FuncGraph;
+
+template <typename T>
+void ParameterServer<T>::ServerHandler::operator()(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data,
+                                                   ::ps::KVServer<T> *server) {
+  ::ps::KVPairs<T> res;
+  if (req_meta.cmd == kInitWeightsCmd) {
+    MS_LOG(INFO) << "handle init weights cmd";
+    HandleInitWeights(req_data);
+  } else if (req_meta.cmd == kInitWeightToOptimIdCmd) {
+    MS_LOG(INFO) << "handle init weight optim id mapping cmd";
+    HandleInitWeightToOptimId(req_data);
+  } else if (req_meta.cmd == kInitOptimInputsShapeCmd) {
+    MS_LOG(INFO) << "handle init inputs shape cmd";
+    HandleInitInputsShape(req_data);
+  } else if (req_meta.cmd == kInitEmbeddingsCmd) {
+    MS_LOG(INFO) << "handle init embedding cmd";
+    HandleInitEmbeddings(req_data);
+  } else if (req_meta.cmd == kEmbeddingLookupCmd) {
+    MS_LOG(INFO) << "handle embedding lookup cmd";
+    HandleEmbeddingLookup(req_meta, req_data, &res);
+  } else if (req_meta.push) {
+    MS_LOG(INFO) << "handle push req cmd";
+    HandlePushReq(req_meta, req_data);
+  } else {
+    MS_LOG(INFO) << "handle pull req cmd";
+    HandlePullReq(req_meta, req_data, &res);
+  }
+  server->Response(req_meta, res);
+}
+
+template <typename T>
+void ParameterServer<T>::ServerHandler::HandlePushReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data) {
+  ps_->AccumGrad(req_data.keys, req_data.vals, req_data.lens);
+}
+
+template <typename T>
+void ParameterServer<T>::ServerHandler::HandlePullReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data,
+                                                      ::ps::KVPairs<T> *res) {
+  res->keys = req_data.keys;
+  ::ps::Key key = req_data.keys[0];
+  res->vals = *(ps_->weight(key));
+}
+
+template <typename T>
+void ParameterServer<T>::ServerHandler::HandleInitWeights(const ::ps::KVPairs<T> &req_data) {
+  size_t key_num = req_data.keys.size();
+  T *data_ptr = req_data.vals.data();
+  size_t pos = 0;
+  for (size_t i = 0; i < key_num; i++) {
+    Key key = req_data.keys[i];
+    // If no per-key lengths are supplied, the values are split evenly among the keys.
+    size_t data_len = req_data.lens.size() != key_num ? req_data.vals.size() / key_num : req_data.lens[i];
+
+    WeightPtr weight_ptr = std::make_shared<::ps::SArray<T>>();
+    weight_ptr->CopyFrom(data_ptr + pos, data_len);
+    ps_->InitWeight(key, weight_ptr);
+
+    GradPtr grad_ptr = std::make_shared<::ps::SArray<T>>(data_len, 0);
+    ps_->InitGrad(key, grad_ptr);
+    pos += data_len;
+  }
+}
+
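+// Each val carries the optimizer id for the weight key at the same index;
+// Util::optimizer_name translates the id to an optimizer name.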
+template <typename T>
+void ParameterServer<T>::ServerHandler::HandleInitWeightToOptimId(const ::ps::KVPairs<T> &req_data) {
+  size_t key_num = req_data.keys.size();
+  for (size_t i = 0; i < key_num; i++) {
+    Key key = req_data.keys[i];
+    T val = req_data.vals[i];
+    ps_->InitWeightKeyToOptims(key, val);
+  }
+}
+
+template <typename T>
+void ParameterServer<T>::ServerHandler::HandleInitInputsShape(const ::ps::KVPairs<T> &req_data) {
+  ps_->InitOptimInputsShape(req_data.keys, req_data.vals, req_data.lens);
+}
+
+template <typename T>
+void ParameterServer<T>::ServerHandler::HandleInitEmbeddings(const ::ps::KVPairs<T> &req_data) {
+  std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> shapes =
+    std::make_shared<std::vector<std::shared_ptr<std::vector<size_t>>>>();
+  std::shared_ptr<std::vector<size_t>> input_shape = std::make_shared<std::vector<size_t>>();
+  std::shared_ptr<std::vector<size_t>> indices_shape = std::make_shared<std::vector<size_t>>();
+  std::shared_ptr<std::vector<size_t>> output_shape = std::make_shared<std::vector<size_t>>();
+  shapes->push_back(input_shape);
+  shapes->push_back(indices_shape);
+  shapes->push_back(output_shape);
+
+  const Key &key = req_data.keys[0];
+  const Lengths &lens = req_data.lens;
+  size_t index = 0;
+  for (int i = 0; i < lens[0]; i++) {
+    input_shape->push_back(static_cast<size_t>(req_data.vals[index++]));
+  }
+  for (int j = 0; j < lens[1]; j++) {
+    indices_shape->push_back(static_cast<size_t>(req_data.vals[index++]));
+  }
+  for (int k = 0; k < lens[2]; k++) {
+    output_shape->push_back(static_cast<size_t>(req_data.vals[index++]));
+  }
+  ps_->InitEmbeddingTable(key, shapes);
+}
+
+template <typename T>
+void ParameterServer<T>::ServerHandler::HandleEmbeddingLookup(const ::ps::KVMeta &req_meta,
+                                                              const ::ps::KVPairs<T> &req_data,
+                                                              ::ps::KVPairs<T> *res) {
+  const Key &key = req_data.keys[0];
+  ps_->DoEmbeddingLookup(key, req_data.vals, res);
+  for (size_t i = 0; i < req_data.vals.size(); i++) {
+    res->keys.push_back(req_data.vals[i]);
+  }
+}
+
+template <typename T>
+bool ParameterServer<T>::Init(const FuncGraphPtr &func_graph) {
+  const char *server_num = getenv(kEnvPServerNum);
+  const char *worker_num = getenv(kEnvWorkerNum);
+  if (server_num != nullptr) {
+    pserver_num_ = std::strtoul(server_num, nullptr, 10);
+  }
+  if (worker_num != nullptr) {
+    worker_num_ = std::strtoul(worker_num, nullptr, 10);
+  }
+  func_graph_ = func_graph;
+  rank_id_ = ::ps::MyRank();
+  handler_.reset(new ServerHandler(this));
+
+  InitOptimInfoBuilders();
+
+  ps_->set_request_handle(*handler_);
+  thread_.reset(new std::thread(&ParameterServer::UpdateWeights, this));
+  return true;
+}
+
+template <typename T>
+void ParameterServer<T>::InitOptimInfoBuilders() {
+  std::shared_ptr<OptimizerInfoBuilder> momentum_info_builder = std::make_shared<MomentumOptimInfoBuilder>();
+  std::shared_ptr<OptimizerInfoBuilder> sparse_adam_info_builder = std::make_shared<SparseAdamOptimInfoBuilder>();
+  std::shared_ptr<OptimizerInfoBuilder> sparse_ftrl_info_builder = std::make_shared<SparseFtrlOptimInfoBuilder>();
+  optim_info_builders_[kApplyMomentum] = momentum_info_builder;
+  optim_info_builders_[kSparseAdam] = sparse_adam_info_builder;
+  optim_info_builders_[kSparseFtrl] = sparse_ftrl_info_builder;
+}
+
+template <typename T>
+void ParameterServer<T>::InitWeightKeyToOptims(const Key &key, const int &optim_id) {
+  if (weight_key_to_optims_.count(key) > 0 || Util::optimizer_name(optim_id) == "") {
+    return;
+  }
+  weight_key_to_optims_[key] = Util::optimizer_name(optim_id);
+}
+
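+// Request layout: lengths[i] holds the dimension count of the i-th input
+// shape, and values holds all dimensions back to back, so the i-th shape is
+// read as the next lengths[i] entries of values.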
+template <typename T>
+void ParameterServer<T>::InitOptimInputsShape(const Keys &keys, const Values &values, const Lengths &lengths) {
+  InputsShapePtr inputs_shape = std::make_shared<InputsShape>();
+  int val_idx = 0;
+  const Key &key = keys[0];
+
+  if (optim_inputs_shape_.count(key) == 0) {
+    optim_inputs_shape_[key] = inputs_shape;
+  }
+  for (size_t i = 0; i < keys.size(); i++) {
+    auto shape = std::make_shared<std::vector<size_t>>();
+    inputs_shape->push_back(shape);
+
+    int len = lengths[i];
+    for (int j = 0; j < len; j++) {
+      shape->push_back(values[val_idx++]);
+    }
+  }
+  if (weight_key_to_optims_.count(key) > 0) {
+    const std::string &optim_name = weight_key_to_optims_[key];
+    if (optimizers_.count(optim_name) == 0 && optim_inputs_shape_.count(key) > 0) {
+      if (optim_name == kSparseAdam) {
+        std::shared_ptr<PServerKernel> optimizer =
+          std::make_shared<kernel::ps::SparseApplyAdamPSKernel>(rank_id_, pserver_num_);
+        optimizer->InitKernel(optim_inputs_shape_[key]);
+        optimizers_[optim_name] = optimizer;
+      } else if (optim_name == kApplyMomentum) {
+        std::shared_ptr<PServerKernel> optimizer =
+          std::make_shared<kernel::ps::ApplyMomentumPSKernel>(rank_id_, pserver_num_);
+        optimizer->InitKernel(optim_inputs_shape_[key]);
+        optimizers_[optim_name] = optimizer;
+      } else if (optim_name == kSparseFtrl) {
+        std::shared_ptr<PServerKernel> optimizer =
+          std::make_shared<kernel::ps::SparseApplyFtrlPSKernel>(rank_id_, pserver_num_);
+        optimizer->InitKernel(optim_inputs_shape_[key]);
+        optimizers_[optim_name] = optimizer;
+      }
+    }
+  }
+}
+
+template <typename T>
+void ParameterServer<T>::InitWeight(const Key &key, const WeightPtr &weight) {
+  if (weights_.count(key) == 0) {
+    weights_[key] = weight;
+  }
+}
+
+template <typename T>
+void ParameterServer<T>::InitGrad(const Key &key, const GradPtr &grad) {
+  if (grads_.count(key) == 0) {
+    grads_[key] = grad;
+    grads_accum_counter_[key] = 0;
+  }
+}
+
+template <typename T>
+void ParameterServer<T>::InitEmbeddingTable(
+  const Key &key, const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
+  // Init embedding lookup kernel
+  std::shared_ptr<PServerKernel> lookup =
+    std::make_shared<kernel::ps::EmbeddingLookUpPSKernel>(rank_id_, pserver_num_);
+  lookup->InitKernel(shapes);
+  embedding_lookup_ops_[key] = lookup;
+
+  // Init embedding weight
+  const std::vector<size_t> &input_shapes = lookup->input_sizes();
+  size_t total_dims = 1;
+  for (auto shape : input_shapes) {
+    total_dims *= shape;
+  }
+  WeightPtr embedding = std::make_shared<::ps::SArray<T>>(total_dims, 0.01);
+  weights_[key] = embedding;
+
+  grads_accum_counter_[key] = 0;
+}
+
+template <typename T>
+void ParameterServer<T>::UpdateWeights() {
+  while (true) {
+    std::unique_lock<std::mutex> lock(mutex_);
+    apply_grads_cv_.wait(lock, [this] { return this->ReadyForUpdateWeights(); });
+
+    for (auto iter = weights_.begin(); iter != weights_.end(); iter++) {
+      Key key = iter->first;
+      WeightPtr weight_ptr = iter->second;
+
+      std::shared_ptr<PServerKernel> optimizer = nullptr;
+      if (weight_key_to_optims_.count(key) > 0) {
+        const std::string &optim_name = weight_key_to_optims_[key];
+        optimizer = optimizers_[optim_name];
+      }
+      MS_EXCEPTION_IF_NULL(optimizer);
+
+      std::shared_ptr<OptimizerInfo> optim_info = optim_infos_[key];
+      if (optim_info == nullptr) {
+        continue;
+      }
+      const WeightPtr &weight = weights_[key];
+      optim_info->UpdateWeight(weight);
+      const std::vector<AddressPtr> &inputs = optim_info->inputs();
+      const std::vector<AddressPtr> &workspaces = optim_info->workspaces();
+      const std::vector<AddressPtr> &outputs = optim_info->outputs();
+
+      optimizer->Execute(inputs, workspaces, outputs);
+      optim_info->Reset();
+    }
+    ResetGradAccumCount();
+    accum_grads_cv_.notify_all();
+  }
+}
+
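+// A key's gradient is complete once every worker has pushed its share
+// (grads_accum_counter_[key] == worker_num_). When every key with a counter
+// is complete (grad_accum_count_ == grads_accum_counter_.size()),
+// ReadyForUpdateWeights() holds and the update thread is woken.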
+template <typename T>
+void ParameterServer<T>::AccumGrad(const Keys &keys, const Values &values, const Lengths &lengths) {
+  std::unique_lock<std::mutex> lock(mutex_);
+  accum_grads_cv_.wait(lock, [this] { return this->ReadyForAccumGrads(); });
+
+  const Key &key = keys[0];
+  std::shared_ptr<OptimizerInfo> optim_info = optim_infos_[key];
+
+  // Create or update the optimizer info
+  if (optim_info == nullptr) {
+    const std::shared_ptr<OptimizerInfoBuilder> &builder = optim_info_builders_[weight_key_to_optims_[key]];
+    std::shared_ptr<PServerKernel> pserver_kernel = optimizers_[weight_key_to_optims_[key]];
+    if (pserver_kernel == nullptr) {
+      MS_LOG(EXCEPTION) << "no optimizer found for key " << key << " optim name " << weight_key_to_optims_[key];
+    }
+    MS_EXCEPTION_IF_NULL(pserver_kernel);
+    OptimizerInfo *optim =
+      builder->Build(pserver_kernel, weights_[key], keys, values, lengths, optim_inputs_shape_[key], worker_num_);
+    optim_info.reset(optim);
+    optim_infos_[key] = optim_info;
+  } else {
+    optim_info->Update(values, lengths);
+  }
+  MS_EXCEPTION_IF_NULL(optim_info);
+
+  optim_info->Accumulate(values, lengths);
+
+  grads_accum_counter_[key] += 1;
+  if (grads_accum_counter_[key] == worker_num_) {
+    grad_accum_count_++;
+  }
+  if (ReadyForUpdateWeights()) {
+    apply_grads_cv_.notify_one();
+  }
+}
+
+template <typename T>
+WeightPtr ParameterServer<T>::weight(const Key &key) {
+  std::unique_lock<std::mutex> lock(mutex_);
+
+  if (weights_.count(key) == 0) {
+    MS_LOG(ERROR) << "Invalid weight key " << key;
+    return nullptr;
+  }
+  WeightPtr weight_ptr = weights_[key];
+  // Return a copy so the caller never aliases the weight being updated.
+  WeightPtr copy_weight_ptr = std::make_shared<::ps::SArray<T>>(weight_ptr->size(), 0);
+  copy_weight_ptr->CopyFrom(weight_ptr->data(), weight_ptr->size());
+  return copy_weight_ptr;
+}
+
+template <typename T>
+void ParameterServer<T>::DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, ::ps::KVPairs<T> *res) {
+  std::unique_lock<std::mutex> lock(mutex_);
+  if (weights_.count(key) == 0) {
+    MS_LOG(ERROR) << "Invalid embedding table key " << key;
+    return;
+  }
+  if (embedding_lookup_ops_.count(key) == 0) {
+    MS_LOG(ERROR) << "Invalid embedding lookup op key " << key;
+    return;
+  }
+  WeightPtr table_ptr = weights_[key];
+  std::shared_ptr<PServerKernel> table_lookup_op = embedding_lookup_ops_[key];
+
+  // Update shapes of lookup operator
+  std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> shapes =
+    std::make_shared<std::vector<std::shared_ptr<std::vector<size_t>>>>();
+  std::shared_ptr<std::vector<size_t>> indices_shape = std::make_shared<std::vector<size_t>>();
+  indices_shape->emplace_back(lookup_ids.size());
+  shapes->push_back(indices_shape);
+  table_lookup_op->ReInit(shapes);
+
+  const std::vector<size_t> output_shapes = table_lookup_op->output_sizes();
+  std::vector<AddressPtr> inputs;
+  AddressPtr embedding_table = std::make_shared<kernel::Address>();
+  AddressPtr indices = std::make_shared<kernel::Address>();
+  inputs.push_back(embedding_table);
+  inputs.push_back(indices);
+  embedding_table->addr = table_ptr->data();
+  embedding_table->size = table_ptr->size() * sizeof(T);
+  indices->addr = lookup_ids.data();
+  indices->size = lookup_ids.size() * sizeof(T);
+
+  std::vector<AddressPtr> workspaces;
+  std::vector<AddressPtr> outputs;
+  AddressPtr output = std::make_shared<kernel::Address>();
+  std::shared_ptr<::ps::SArray<T>> addr = std::make_shared<::ps::SArray<T>>(output_shapes[0] / sizeof(T), 0);
+
+  output->addr = addr->data();
+  output->size = output_shapes[0];
+  outputs.push_back(output);
+
+  table_lookup_op->Execute(inputs, workspaces, outputs);
+  res->vals = *addr;
+  res->lens.push_back(res->vals.size());
+}
+
+template <typename T>
+int ParameterServer<T>::SumOfShapes(const std::vector<int> &shapes) const {
+  // Despite the name, this returns the element count of a shape, i.e. the
+  // product of its dimensions.
+  int sum = 1;
+  for (auto shape : shapes) {
+    sum *= shape;
+  }
+  return sum;
+}
+
+template <typename T>
+size_t ParameterServer<T>::PreComputeCapacity(const Keys &keys, const Lengths &lens) {
+  size_t capacity = 0;
+  for (size_t i = 0; i < keys.size(); i++) {
+    Key key = keys[i];
+    if (embedding_row_lens_.count(key) > 0) {
+      capacity += embedding_row_lens_[key] * lens[i];
+    } else {
+      MS_LOG(ERROR) << "Invalid embedding lookup id " << key;
+    }
+  }
+  return capacity;
+}
+
+template <typename T>
+inline bool ParameterServer<T>::ReadyForUpdateWeights() {
+  return grads_accum_counter_.size() > 0 && grad_accum_count_ == grads_accum_counter_.size();
+}
+
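+// Pushes are accepted only while some key still has an incomplete gradient;
+// once every key is ready, workers block until UpdateWeights() has run and
+// reset the accumulation counters.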
+template <typename T>
+inline bool ParameterServer<T>::ReadyForAccumGrads() {
+  return grad_accum_count_ < weights_.size();
+}
+
+template <typename T>
+inline void ParameterServer<T>::ResetGradAccumCount() {
+  grad_accum_count_ = 0;
+  for (auto iter = grads_accum_counter_.begin(); iter != grads_accum_counter_.end(); iter++) {
+    grads_accum_counter_[iter->first] = 0;
+  }
+}
+
+template <typename T>
+void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) {
+  ::ps::Start(0);
+  if (!::ps::IsServer()) {
+    std::cout << "This is not the Server" << std::endl;
+    return;
+  }
+  Init(func_graph);
+  thread_->join();
+}
+}  // namespace ps
+}  // namespace parallel
+}  // namespace mindspore
+#endif  // MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_PARAMETER_SERVER_H_