From 4281f380757827042cde8f4cf95abd4d5ff21d0a Mon Sep 17 00:00:00 2001
From: ZPaC
Date: Tue, 28 Jul 2020 16:48:28 +0800
Subject: [PATCH] Delete hard code in pull kernel.

Instead of matching hard-coded parameter names ("embedding", "wide_w") in
PullKernel, the worker now records whether each parameter is initialized on
the server side (shape {1}) when InitPSParamAndOptim runs, and the pull
kernel queries that flag through GetParamInitInServer.
---
 .../kernel_compiler/cpu/ps/pull_kernel.h      |  5 +--
 .../ccsrc/backend/session/ascend_session.cc   |  4 +--
 .../ccsrc/backend/session/cpu_session.cc      |  5 +--
 .../ccsrc/backend/session/gpu_session.cc      |  4 +--
 .../ccsrc/backend/session/session_basic.cc    | 10 +-----
 .../ccsrc/backend/session/session_basic.h     |  3 +-
 .../parallel/ps/optimizer_info_builder.cc     | 28 +++++++++------
 .../frontend/parallel/ps/parameter_server.h   | 12 +------
 mindspore/ccsrc/frontend/parallel/ps/worker.h | 34 ++++++++++++++++---
 9 files changed, 57 insertions(+), 48 deletions(-)

diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pull_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pull_kernel.h
index 3f0cf67d6..350b503d8 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pull_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pull_kernel.h
@@ -33,8 +33,9 @@ class PullKernel : public CPUKernel {
   ~PullKernel() override = default;
 
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &) {
-    // If the paramter is embedding table, don't Pull from PServer.
-    if (param_name_.find("embedding") == std::string::npos && param_name_.find("wide_w") == std::string::npos) {
+    bool init_in_server = mindspore::parallel::ps::Worker::GetInstance().GetParamInitInServer(param_name_);
+    // If init_in_server, forward kernel should run in server too.
+    if (!init_in_server) {
       parallel::ps::Worker::GetInstance().Pull(key_, inputs[1]->addr, inputs[1]->size);
     }
     return true;
diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc
index 2ea5711fa..cf953ef4a 100644
--- a/mindspore/ccsrc/backend/session/ascend_session.cc
+++ b/mindspore/ccsrc/backend/session/ascend_session.cc
@@ -325,9 +325,7 @@ void AscendSession::RunGraph(const GraphId &graph_id, const std::vector need_sync_outputs;
diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc
index d608f69ed..c22bf5ff7 100644
--- a/mindspore/ccsrc/backend/session/gpu_session.cc
+++ b/mindspore/ccsrc/backend/session/gpu_session.cc
@@ -237,9 +237,7 @@ void GPUSession::RunGraph(const GraphId &graph_id, const std::vector shape_init_in_server = {1};
   for (size_t i = 0; i < inputs.size(); ++i) {
     auto tensor = inputs[i];
     MS_EXCEPTION_IF_NULL(tensor);
@@ -1233,16 +1232,9 @@ void SessionBasic::InitPSParamAndOptim(const KernelGraphPtr &kernel_graph,
     MS_EXCEPTION_IF_NULL(input_node);
     if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
       auto pk_node = input_node->cast<ParameterPtr>();
-      bool init_in_server = false;
-      if (tensor->shape_c() == shape_init_in_server) {
-        MS_LOG(INFO) << "The param need to be initialized in server " << pk_node->fullname_with_scope();
-        init_in_server = true;
-      }
-      mindspore::parallel::ps::Worker::GetInstance().InitPSParamAndOptim(
-        pk_node->fullname_with_scope(), tensor->data_c(), LongToSize(tensor->data().nbytes()), init_in_server);
+      mindspore::parallel::ps::Worker::GetInstance().InitPSParamAndOptim(pk_node->fullname_with_scope(), tensor);
     }
   }
-  ps_init_ = true;
 }
 #endif
 }  // namespace session
diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h
index 427c49bdb..5f3a0cab9 100755
--- a/mindspore/ccsrc/backend/session/session_basic.h
+++ b/mindspore/ccsrc/backend/session/session_basic.h
@@ -52,7 +52,7 @@ using OpRunInfoPtr = std::shared_ptr<OpRunInfo>;
 
 class SessionBasic {
  public:
-  SessionBasic() : context_(nullptr), summary_callback_(nullptr), device_id_(0), ps_init_(false) {
+  SessionBasic() : context_(nullptr), summary_callback_(nullptr), device_id_(0) {
 #ifdef ENABLE_DEBUGGER
     debugger_ = nullptr;
 #endif
@@ -146,7 +146,6 @@ class SessionBasic {
   CallBackFunc summary_callback_;
   static GraphId graph_sum_;
   uint32_t device_id_;
-  bool ps_init_;
 #ifdef ENABLE_DEBUGGER
   std::shared_ptr<Debugger> debugger_;
 #endif
diff --git a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc
index 23ad87c41..60f8d10f1 100644
--- a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc
+++ b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc
@@ -81,13 +81,21 @@ OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight,
   AddressPtr m = std::make_shared();
   m->addr = new float[weight->size()];
   m->size = weight->size() * sizeof(float);
+  int ret = memset_s(m->addr, m->size, 0x00, m->size);
+  if (ret != 0) {
+    MS_LOG(EXCEPTION) << "memset_s error, errorno(" << ret << ")";
+  }
   AddressPtr v = std::make_shared();
   v->addr = new float[weight->size()];
   v->size = weight->size() * sizeof(float);
+  ret = memset_s(v->addr, v->size, 0x00, v->size);
+  if (ret != 0) {
+    MS_LOG(EXCEPTION) << "memset_s error, errorno(" << ret << ")";
+  }
 
   void *data_ptr = values.data();
   void *copy_data_ptr = new float[values.size()];
-  auto ret = memcpy_s(copy_data_ptr, values.size() * sizeof(float), data_ptr, values.size() * sizeof(float));
+  ret = memcpy_s(copy_data_ptr, values.size() * sizeof(float), data_ptr, values.size() * sizeof(float));
   if (ret != 0) {
     MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
   }
@@ -120,10 +128,10 @@ OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight,
       std::accumulate((*grad_shape).begin(), (*grad_shape).end(), sizeof(float), std::multiplies<size_t>());
   AddressPtr grad = std::make_shared();
   grad->addr = new float[total_grad_size * worker_num];
-  auto ret2 = memcpy_s(grad->addr, lens[6] * sizeof(float), reinterpret_cast<float *>(epsilon->addr) + lens[5],
-                       lens[6] * sizeof(float));
-  if (ret2 != 0) {
-    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret2 << ")";
+  ret = memcpy_s(grad->addr, lens[6] * sizeof(float), reinterpret_cast<float *>(epsilon->addr) + lens[5],
+                 lens[6] * sizeof(float));
+  if (ret != 0) {
+    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
   }
   grad->size = lens[6] * sizeof(float);
 
@@ -132,10 +140,10 @@ OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight,
       std::accumulate((*indices_shape).begin(), (*indices_shape).end(), sizeof(float), std::multiplies<size_t>());
   AddressPtr indices = std::make_shared();
   indices->addr = new float[total_indice_size * worker_num];
-  auto ret3 = memcpy_s(indices->addr, lens[7] * sizeof(float),
-                       reinterpret_cast<float *>(epsilon->addr) + lens[5] + lens[6], lens[7] * sizeof(float));
-  if (ret3 != 0) {
-    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret3 << ")";
+  ret = memcpy_s(indices->addr, lens[7] * sizeof(float), reinterpret_cast<float *>(epsilon->addr) + lens[5] + lens[6],
+                 lens[7] * sizeof(float));
+  if (ret != 0) {
+    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
   }
   indices->size = lens[7] * sizeof(int);
 
@@ -160,7 +168,7 @@ OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight,
   linear->addr = new float[weight->size()];
   auto ret = memset_s(linear->addr, weight->size() * sizeof(float), 0x00, weight->size() * sizeof(float));
   if (ret != 0) {
-    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
+    MS_LOG(EXCEPTION) << "memset_s error, errorno(" << ret << ")";
   }
   linear->size = weight->size() * sizeof(float);
 
diff --git a/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h b/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h
index 092e907da..674490db3 100644
--- a/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h
+++ b/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h
@@ -208,11 +208,6 @@ void ParameterServer<T>::ServerHandler::HandleInitWeights(const ::ps::KVMeta &re
   size_t pos = 0;
   for (size_t i = 0; i < key_num; i++) {
     Key key = req_data.keys[i];
-    if (init_weights_[key]) {
-      continue;
-    } else {
-      init_weights_[key] = true;
-    }
     size_t data_len = req_data.lens.size() != key_num ? req_data.vals.size() / key_num : req_data.lens[i];
 
     WeightPtr weight_ptr = std::make_shared<::ps::SArray<float>>();
@@ -261,11 +256,6 @@ void ParameterServer<T>::ServerHandler::HandleInitEmbeddings(const ::ps::KVMeta
                                                              const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
   std::unique_lock<std::mutex> lock(ps_->mutex());
   const Key &key = req_data.keys[0];
-  if (init_weights_[key]) {
-    return;
-  } else {
-    init_weights_[key] = true;
-  }
   std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> shapes =
     std::make_shared<std::vector<std::shared_ptr<std::vector<size_t>>>>();
   std::shared_ptr<std::vector<size_t>> input_shape = std::make_shared<std::vector<size_t>>();
@@ -418,7 +408,7 @@ const CNodePtr ParameterServer<T>::GetCNode(const std::string &name) const {
 template <typename T>
 void ParameterServer<T>::InitWeight(const Key &key, const WeightPtr &weight) {
   MS_LOG(INFO) << "Initializing weight for key " << key;
-  if (weights_.count(key) == 0) {
+  if ((weights_.count(key) == 0) || (is_embedding_[key] && weights_.count(key) != 0)) {
     weights_[key] = weight;
     tokens_[key] = 0;
     is_embedding_[key] = false;
diff --git a/mindspore/ccsrc/frontend/parallel/ps/worker.h b/mindspore/ccsrc/frontend/parallel/ps/worker.h
index ec220d5ef..4908534cb 100644
--- a/mindspore/ccsrc/frontend/parallel/ps/worker.h
+++ b/mindspore/ccsrc/frontend/parallel/ps/worker.h
@@ -24,6 +24,7 @@
 #include
 #include "ps/ps.h"
 #include "utils/log_adapter.h"
+#include "ir/tensor.h"
 #include "frontend/parallel/ps/util.h"
 #include "frontend/parallel/ps/common.h"
 #include "frontend/parallel/ps/worker_proxy.h"
@@ -43,12 +44,13 @@ class Worker {
   void Push(const std::vector &keys, std::vector addrs, const std::vector &sizes);
   void Pull(const size_t key, void *dev_addr, const size_t size);
   size_t SetParamKey(const std::string &param_name);
+  void SetParamInitInServer(const std::string &param_name, bool init_in_server);
+  bool GetParamInitInServer(const std::string &param_name);
   void SetKeyOptimId(size_t key, const std::string &optimizer_name);
   void SetOptimInputShapes(size_t key, const std::vector &shape);
   void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count);
   void InitPSEmbeddingTable(const std::vector &keys, std::vector shapes, const std::vector &sizes);
-  void InitPSParamAndOptim(const std::string &param_name, void *param_data, size_t param_size,
-                           bool init_in_server = false);
+  void InitPSParamAndOptim(const std::string &param_name, tensor::TensorPtr tensor);
   void DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids,
                            const ::ps::SArray &lens, ::ps::SArray *lookup_result, int cmd);
   void Finalize();
@@ -74,6 +76,7 @@ class Worker {
   std::map init_keys_;
   std::map key_to_optimId_;
   std::map>> key_to_optim_shapes_;
+  std::map<std::string, bool> param_to_init_in_server_;
 };
 
 template <typename T>
@@ -208,6 +211,20 @@ size_t Worker<T>::SetParamKey(const std::string &param_name) {
   return key;
 }
 
+template <typename T>
+void Worker<T>::SetParamInitInServer(const std::string &param_name, bool init_in_server) {
+  MS_LOG(INFO) << "Set parameter " << param_name << " init_in_server:" << init_in_server;
+  param_to_init_in_server_[param_name] = init_in_server;
+}
+
+template <typename T>
+bool Worker<T>::GetParamInitInServer(const std::string &param_name) {
+  if (param_to_init_in_server_.count(param_name) == 0) {
+    return false;
+  }
+  return param_to_init_in_server_[param_name];
+}
+
 template <typename T>
 size_t Worker<T>::GetParamKey(const std::string &param_name) {
   size_t key = kInvalidKey;
@@ -253,13 +270,22 @@ void Worker<T>::InitPSEmbeddingTable(const std::vector &keys, std::vecto
 
 template <typename T>
 // Initialize parameters and optimizer kernels of Parameter Server.
-void Worker<T>::InitPSParamAndOptim(const std::string &param_name, void *param_data, size_t param_size,
-                                    bool init_in_server) {
+void Worker<T>::InitPSParamAndOptim(const std::string &param_name, tensor::TensorPtr tensor) {
+  void *param_data = tensor->data_c();
+  size_t param_size = LongToSize(tensor->data().nbytes());
+  std::vector<int> param_shape = tensor->shape_c();
+
   size_t param_key = GetParamKey(param_name);
   if (param_key == kInvalidKey) {
     MS_LOG(INFO) << "Parameter " << param_name << " has no key assigned.";
     return;
   }
+  bool init_in_server = false;
+  std::vector<int> shape_init_in_server = {1};
+  if (param_shape == shape_init_in_server) {
+    init_in_server = true;
+  }
+  SetParamInitInServer(param_name, init_in_server);
   bool init = IsKeyInit(param_key);
   if (!init) {
     MS_LOG(INFO) << "Init paramter and optimizer in parameter server side for " << param_name
--
GitLab