Commit 9dc23eeb authored by mindspore-ci-bot, committed by Gitee

!3602 Delete hard code in pull node

Merge pull request !3602 from ZPaC/r0.6-delete-hard-code-in-pull-node
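
Summary: this change removes the hard-coded parameter-name matching from the pull kernel. Instead of testing whether a parameter name contains "embedding" or "wide_w", the worker now records an explicit init_in_server flag per parameter when it initializes parameters and optimizers on the parameter server (using the shape-{1} placeholder convention), and the pull kernel queries that flag. As a consequence, the per-session ps_init_ guard and the server-side init_weights_ deduplication become unnecessary and are deleted.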
@@ -33,8 +33,9 @@ class PullKernel : public CPUKernel {
   ~PullKernel() override = default;
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &) {
-    // If the paramter is embedding table, don't Pull from PServer.
-    if (param_name_.find("embedding") == std::string::npos && param_name_.find("wide_w") == std::string::npos) {
+    bool init_in_server = mindspore::parallel::ps::Worker<float>::GetInstance().GetParamInitInServer(param_name_);
+    // If init_in_server, forward kernel should run in server too.
+    if (!init_in_server) {
       parallel::ps::Worker<T>::GetInstance().Pull(key_, inputs[1]->addr, inputs[1]->size);
     }
     return true;
......
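The hunk above replaces substring matching on parameter names with a lookup of a flag the worker recorded earlier. A minimal, self-contained sketch of that pattern, using a hypothetical WorkerSketch class rather than MindSpore's real Worker<T>:

// Sketch only: WorkerSketch stands in for mindspore::parallel::ps::Worker.
#include <iostream>
#include <map>
#include <string>

class WorkerSketch {
 public:
  static WorkerSketch &GetInstance() {
    static WorkerSketch instance;
    return instance;
  }
  void SetParamInitInServer(const std::string &name, bool flag) { flags_[name] = flag; }
  bool GetParamInitInServer(const std::string &name) {
    auto iter = flags_.find(name);
    return iter != flags_.end() && iter->second;
  }

 private:
  std::map<std::string, bool> flags_;
};

int main() {
  WorkerSketch::GetInstance().SetParamInitInServer("embedding_table", true);
  // The pull kernel now asks the worker instead of inspecting the name:
  if (!WorkerSketch::GetInstance().GetParamInitInServer("embedding_table")) {
    std::cout << "pull weight from the parameter server" << std::endl;
  } else {
    std::cout << "skip pull: the server initializes and owns this weight" << std::endl;
  }
  return 0;
}

The flag survives renamed parameters, whereas the old substring test silently broke for any network whose embedding parameter was not named "embedding" or "wide_w".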
@@ -517,9 +517,7 @@ void AscendSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::
   LoadInputData(kernel_graph, inputs);
 #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
   // Initialize parameter server
-  if (!ps_init_) {
-    InitPSParamAndOptim(kernel_graph, inputs);
-  }
+  InitPSParamAndOptim(kernel_graph, inputs);
 #endif
   // convert inputs to model
   predictmodel::StepConvertWeight(inputs);
......
@@ -91,10 +91,7 @@ void CPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten
   auto &kernel_graph = graphs_[graph_id];
   MS_EXCEPTION_IF_NULL(kernel_graph);
 #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
-  // Initialize parameter server
-  if (!ps_init_) {
-    InitPSParamAndOptim(kernel_graph, inputs);
-  }
+  InitPSParamAndOptim(kernel_graph, inputs);
 #endif
   MS_LOG(INFO) << "Bind input output address";
   std::vector<tensor::TensorPtr> need_sync_outputs;
......
@@ -233,9 +233,7 @@ void GPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten
   LoadInputData(kernel_graph, inputs);
 #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
   // Initialize parameter server
-  if (!ps_init_) {
-    InitPSParamAndOptim(kernel_graph, inputs);
-  }
+  InitPSParamAndOptim(kernel_graph, inputs);
 #endif
   MS_EXCEPTION_IF_NULL(kernel_graph);
   // Convert inputs to model
......
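All three backends (Ascend, CPU, GPU) make the same change: the ps_init_ guard goes away and InitPSParamAndOptim runs on every RunGraph. That is safe because deduplication now happens per parameter key inside the worker (the IsKeyInit check visible in the last hunk below). A minimal sketch of that idempotent-init idea, with assumed names:

// Sketch only, not MindSpore's Worker: repeated init calls become no-ops.
#include <cstddef>
#include <iostream>
#include <set>

class WorkerInitSketch {
 public:
  // Returns true only the first time a key is seen.
  bool InitOnce(size_t key) {
    if (init_keys_.count(key) > 0) {
      return false;  // Already initialized; calling again each RunGraph is harmless.
    }
    init_keys_.insert(key);
    return true;
  }

 private:
  std::set<size_t> init_keys_;
};

int main() {
  WorkerInitSketch worker;
  std::cout << worker.InitOnce(3) << std::endl;  // 1: first call does the work
  std::cout << worker.InitOnce(3) << std::endl;  // 0: later calls are no-ops
  return 0;
}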
@@ -1196,7 +1196,6 @@ void SessionBasic::InitPSParamAndOptim(const KernelGraphPtr &kernel_graph,
   }
   auto ms_context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(ms_context);
-  std::vector<int> shape_init_in_server = {1};
   for (size_t i = 0; i < inputs.size(); ++i) {
     auto tensor = inputs[i];
     MS_EXCEPTION_IF_NULL(tensor);
@@ -1204,16 +1203,9 @@ void SessionBasic::InitPSParamAndOptim(const KernelGraphPtr &kernel_graph,
     MS_EXCEPTION_IF_NULL(input_node);
     if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
       auto pk_node = input_node->cast<ParameterPtr>();
-      bool init_in_server = false;
-      if (tensor->shape_c() == shape_init_in_server) {
-        MS_LOG(INFO) << "The parameter needs to be initialized in server " << pk_node->fullname_with_scope();
-        init_in_server = true;
-      }
-      mindspore::parallel::ps::Worker<float>::GetInstance().InitPSParamAndOptim(
-        pk_node->fullname_with_scope(), tensor->data_c(), LongToSize(tensor->data().nbytes()), init_in_server);
+      mindspore::parallel::ps::Worker<float>::GetInstance().InitPSParamAndOptim(pk_node->fullname_with_scope(), tensor);
     }
   }
-  ps_init_ = true;
 }
 #endif
 }  // namespace session
......
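The session-side shape check disappears because the worker now receives the whole tensor and derives everything itself. A sketch of the new call shape, with an assumed minimal tensor type standing in for tensor::TensorPtr:

// Sketch only: TensorSketch is an assumed stand-in for mindspore's tensor::Tensor.
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

struct TensorSketch {
  std::vector<int> shape;
  std::vector<float> values;
  void *data_c() { return values.data(); }
  size_t nbytes() const { return values.size() * sizeof(float); }
};

// Worker side: pointer, byte size and shape are derived from the tensor,
// instead of being pre-computed (plus a flag) by every session backend.
void InitPSParamAndOptim(const std::string &param_name, TensorSketch *tensor) {
  void *param_data = tensor->data_c();
  size_t param_size = tensor->nbytes();
  const std::vector<int> &param_shape = tensor->shape;
  std::cout << param_name << ": " << param_size << " bytes, rank "
            << param_shape.size() << std::endl;
  (void)param_data;  // Would be pushed to the parameter server here.
}

int main() {
  TensorSketch weight{{2, 3}, {1, 2, 3, 4, 5, 6}};
  InitPSParamAndOptim("fc1.weight", &weight);
  return 0;
}

Passing the tensor keeps the init_in_server decision in one place instead of duplicating it across the Ascend, CPU, and GPU sessions.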
@@ -51,7 +51,7 @@ using OpRunInfoPtr = std::shared_ptr<OpRunInfo>;
 class SessionBasic {
  public:
-  SessionBasic() : context_(nullptr), summary_callback_(nullptr), device_id_(0), ps_init_(false) {
+  SessionBasic() : context_(nullptr), summary_callback_(nullptr), device_id_(0) {
 #ifdef ENABLE_DEBUGGER
     debugger_ = nullptr;
 #endif
@@ -152,7 +152,6 @@ class SessionBasic {
   CallBackFunc summary_callback_;
   static GraphId graph_sum_;
   uint32_t device_id_;
-  bool ps_init_;
 #ifdef ENABLE_DEBUGGER
   std::shared_ptr<Debugger> debugger_;
 #endif
......
@@ -24,7 +24,6 @@
 #include <memory>
 #include <vector>
 #include <mutex>
-#include <list>
 #include <condition_variable>
 #include <thread>
 #include <cmath>
@@ -209,11 +208,6 @@ void ParameterServer<T>::ServerHandler::HandleInitWeights(const ::ps::KVMeta &re
   size_t pos = 0;
   for (size_t i = 0; i < key_num; i++) {
     Key key = req_data.keys[i];
-    if (init_weights_[key]) {
-      continue;
-    } else {
-      init_weights_[key] = true;
-    }
     size_t data_len = req_data.lens.size() != key_num ? req_data.vals.size() / key_num : req_data.lens[i];
     WeightPtr weight_ptr = std::make_shared<::ps::SArray<T>>();
@@ -262,11 +256,6 @@ void ParameterServer<T>::ServerHandler::HandleInitEmbeddings(const ::ps::KVMeta
                                                              const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
   std::unique_lock<std::mutex> lock(ps_->mutex());
   const Key &key = req_data.keys[0];
-  if (init_weights_[key]) {
-    return;
-  } else {
-    init_weights_[key] = true;
-  }
   std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> shapes =
     std::make_shared<std::vector<std::shared_ptr<std::vector<size_t>>>>();
   std::shared_ptr<std::vector<size_t>> input_shape = std::make_shared<std::vector<size_t>>();
@@ -419,7 +408,7 @@ const CNodePtr ParameterServer<T>::GetCNode(const std::string &name) const {
 template <typename T>
 void ParameterServer<T>::InitWeight(const Key &key, const WeightPtr &weight) {
   MS_LOG(INFO) << "Initializing weight for key " << key;
-  if (weights_.count(key) == 0) {
+  if ((weights_.count(key) == 0) || (is_embedding_[key] && weights_.count(key) != 0)) {
     weights_[key] = weight;
     tokens_[key] = 0;
     is_embedding_[key] = false;
......
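With the handler-level init_weights_ dedup gone, InitWeight itself decides when a write is allowed: unseen keys are always written, and embedding-table keys may be overwritten, since their init requests now reach the server again. A standalone sketch of that condition with plain containers (semantics assumed from the hunk above):

// Sketch only: plain maps instead of the server's weight storage.
#include <iostream>
#include <unordered_map>
#include <vector>

using Weight = std::vector<float>;

std::unordered_map<size_t, Weight> weights_;
std::unordered_map<size_t, bool> is_embedding_;

void InitWeight(size_t key, const Weight &weight) {
  // Write if the key is new, or if it belongs to an embedding table,
  // which the server is allowed to re-initialize.
  if ((weights_.count(key) == 0) || (is_embedding_[key] && weights_.count(key) != 0)) {
    weights_[key] = weight;
    is_embedding_[key] = false;  // Embedding init marks this true again later.
  }
}

int main() {
  InitWeight(7, {1.0f});
  InitWeight(7, {2.0f});  // Ignored: ordinary weight, already initialized.
  std::cout << weights_[7][0] << std::endl;  // prints 1
  is_embedding_[7] = true;
  InitWeight(7, {3.0f});  // Accepted: embedding tables may be re-initialized.
  std::cout << weights_[7][0] << std::endl;  // prints 3
  return 0;
}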
@@ -24,6 +24,7 @@
 #include <map>
 #include "ps/ps.h"
 #include "utils/log_adapter.h"
+#include "ir/tensor.h"
 #include "frontend/parallel/ps/util.h"
 #include "frontend/parallel/ps/common.h"
 #include "frontend/parallel/ps/worker_proxy.h"
@@ -43,12 +44,13 @@ class Worker {
   void Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs, const std::vector<int> &sizes);
   void Pull(const size_t key, void *dev_addr, const size_t size);
   size_t SetParamKey(const std::string &param_name);
+  void SetParamInitInServer(const std::string &param_name, bool init_in_server);
+  bool GetParamInitInServer(const std::string &param_name);
   void SetKeyOptimId(size_t key, const std::string &optimizer_name);
   void SetOptimInputShapes(size_t key, const std::vector<int> &shape);
   void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count);
   void InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vector<size_t> shapes, const std::vector<int> &sizes);
-  void InitPSParamAndOptim(const std::string &param_name, void *param_data, size_t param_size,
-                           bool init_in_server = false);
+  void InitPSParamAndOptim(const std::string &param_name, tensor::TensorPtr tensor);
   void DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
                            const ::ps::SArray<int> &lens, ::ps::SArray<T> *lookup_result, int cmd);
   void Finalize();
@@ -74,6 +76,7 @@ class Worker {
   std::map<size_t, bool> init_keys_;
   std::map<size_t, int> key_to_optimId_;
   std::map<size_t, std::vector<std::vector<int>>> key_to_optim_shapes_;
+  std::map<std::string, bool> param_to_init_in_server_;
 };

 template <typename T>
@@ -208,6 +211,20 @@ size_t Worker<T>::SetParamKey(const std::string &param_name) {
   return key;
 }
+
+template <typename T>
+void Worker<T>::SetParamInitInServer(const std::string &param_name, bool init_in_server) {
+  MS_LOG(INFO) << "Set parameter " << param_name << " init_in_server:" << init_in_server;
+  param_to_init_in_server_[param_name] = init_in_server;
+}
+
+template <typename T>
+bool Worker<T>::GetParamInitInServer(const std::string &param_name) {
+  if (param_to_init_in_server_.count(param_name) == 0) {
+    return false;
+  }
+  return param_to_init_in_server_[param_name];
+}

 template <typename T>
 size_t Worker<T>::GetParamKey(const std::string &param_name) {
   size_t key = kInvalidKey;
@@ -253,13 +270,22 @@ void Worker<T>::InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vecto
 template <typename T>
 // Initialize parameters and optimizer kernels of Parameter Server.
-void Worker<T>::InitPSParamAndOptim(const std::string &param_name, void *param_data, size_t param_size,
-                                    bool init_in_server) {
+void Worker<T>::InitPSParamAndOptim(const std::string &param_name, tensor::TensorPtr tensor) {
+  void *param_data = tensor->data_c();
+  size_t param_size = LongToSize(tensor->data().nbytes());
+  std::vector<int> param_shape = tensor->shape_c();
   size_t param_key = GetParamKey(param_name);
   if (param_key == kInvalidKey) {
     MS_LOG(INFO) << "Parameter " << param_name << " has no key assigned.";
     return;
   }
+  bool init_in_server = false;
+  std::vector<int> shape_init_in_server = {1};
+  if (param_shape == shape_init_in_server) {
+    init_in_server = true;
+  }
+  SetParamInitInServer(param_name, init_in_server);
   bool init = IsKeyInit(param_key);
   if (!init) {
     MS_LOG(INFO) << "Init paramter and optimizer in parameter server side for " << param_name
......
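The only remaining magic value is the shape convention itself: a parameter whose host-side shape is exactly {1} is treated as a server-initialized placeholder (typically a large embedding table whose real shape lives on the server). A one-function sketch of that convention, with an assumed helper name mirroring the check in Worker<T>::InitPSParamAndOptim:

// Sketch only: ShouldInitInServer is a hypothetical helper, not MindSpore API.
#include <iostream>
#include <vector>

bool ShouldInitInServer(const std::vector<int> &param_shape) {
  const std::vector<int> shape_init_in_server = {1};
  return param_shape == shape_init_in_server;
}

int main() {
  std::cout << ShouldInitInServer({1}) << std::endl;     // 1: server-side init
  std::cout << ShouldInitInServer({8, 16}) << std::endl; // 0: ordinary weight
  std::cout << ShouldInitInServer({1, 1}) << std::endl;  // 0: must be exactly {1}
  return 0;
}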