Commit 4281f380 authored by ZPaC

Remove hard-coded parameter name checks from the pull kernel.

Parent e0a2d2f9
@@ -33,8 +33,9 @@ class PullKernel : public CPUKernel {
   ~PullKernel() override = default;
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &) {
-    // If the paramter is embedding table, don't Pull from PServer.
-    if (param_name_.find("embedding") == std::string::npos && param_name_.find("wide_w") == std::string::npos) {
+    bool init_in_server = mindspore::parallel::ps::Worker<float>::GetInstance().GetParamInitInServer(param_name_);
+    // If init_in_server, forward kernel should run in server too.
+    if (!init_in_server) {
       parallel::ps::Worker<T>::GetInstance().Pull(key_, inputs[1]->addr, inputs[1]->size);
     }
     return true;
...
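Review note: Launch() no longer pattern-matches parameter names ("embedding", "wide_w"); it asks the worker whether the parameter was flagged as initialized in the server. Below is a minimal standalone sketch of that lookup, assuming only the name-to-bool registry added in worker.h; WorkerStub and the example parameter names are hypothetical stand-ins for the real Worker<T> singleton.

// A minimal sketch, assuming the Worker keeps a name -> bool map as added in worker.h below.
// WorkerStub and the sample parameter names are hypothetical stand-ins for the real
// mindspore::parallel::ps::Worker<T> singleton.
#include <iostream>
#include <map>
#include <string>

class WorkerStub {
 public:
  static WorkerStub &GetInstance() {
    static WorkerStub instance;
    return instance;
  }
  void SetParamInitInServer(const std::string &param_name, bool init_in_server) {
    param_to_init_in_server_[param_name] = init_in_server;
  }
  bool GetParamInitInServer(const std::string &param_name) {
    auto iter = param_to_init_in_server_.find(param_name);
    return iter != param_to_init_in_server_.end() && iter->second;
  }

 private:
  std::map<std::string, bool> param_to_init_in_server_;
};

int main() {
  // InitPSParamAndOptim would flag a {1}-shaped parameter; here the flag is set by hand.
  WorkerStub::GetInstance().SetParamInitInServer("network.embedding_table", true);
  // The pull kernel pulls only parameters that were NOT flagged as init-in-server.
  std::cout << std::boolalpha
            << WorkerStub::GetInstance().GetParamInitInServer("network.embedding_table") << "\n"  // true -> skip Pull
            << WorkerStub::GetInstance().GetParamInitInServer("network.fc1.weight") << "\n";      // false -> Pull
  return 0;
}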
@@ -325,9 +325,7 @@ void AscendSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::
   LoadInputData(kernel_graph, inputs);
 #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
   // Initialize parameter server
-  if (!ps_init_) {
-    InitPSParamAndOptim(kernel_graph, inputs);
-  }
+  InitPSParamAndOptim(kernel_graph, inputs);
 #endif
   {
     py::gil_scoped_release release;
...
@@ -89,10 +89,7 @@ void CPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten
   auto &kernel_graph = graphs_[graph_id];
   MS_EXCEPTION_IF_NULL(kernel_graph);
 #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
-  // Initialize parameter server
-  if (!ps_init_) {
-    InitPSParamAndOptim(kernel_graph, inputs);
-  }
+  InitPSParamAndOptim(kernel_graph, inputs);
 #endif
   MS_LOG(INFO) << "Bind input output address";
   std::vector<tensor::TensorPtr> need_sync_outputs;
...
@@ -237,9 +237,7 @@ void GPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten
   LoadInputData(kernel_graph, inputs);
 #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
   // Initialize parameter server
-  if (!ps_init_) {
-    InitPSParamAndOptim(kernel_graph, inputs);
-  }
+  InitPSParamAndOptim(kernel_graph, inputs);
 #endif
   MS_EXCEPTION_IF_NULL(kernel_graph);
   {
...
@@ -1225,7 +1225,6 @@ void SessionBasic::InitPSParamAndOptim(const KernelGraphPtr &kernel_graph,
   }
   auto ms_context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(ms_context);
-  std::vector<int> shape_init_in_server = {1};
   for (size_t i = 0; i < inputs.size(); ++i) {
     auto tensor = inputs[i];
     MS_EXCEPTION_IF_NULL(tensor);
@@ -1233,16 +1232,9 @@ void SessionBasic::InitPSParamAndOptim(const KernelGraphPtr &kernel_graph,
     MS_EXCEPTION_IF_NULL(input_node);
     if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
       auto pk_node = input_node->cast<ParameterPtr>();
-      bool init_in_server = false;
-      if (tensor->shape_c() == shape_init_in_server) {
-        MS_LOG(INFO) << "The param need to be initialized in server " << pk_node->fullname_with_scope();
-        init_in_server = true;
-      }
-      mindspore::parallel::ps::Worker<float>::GetInstance().InitPSParamAndOptim(
-        pk_node->fullname_with_scope(), tensor->data_c(), LongToSize(tensor->data().nbytes()), init_in_server);
+      mindspore::parallel::ps::Worker<float>::GetInstance().InitPSParamAndOptim(pk_node->fullname_with_scope(), tensor);
     }
   }
-  ps_init_ = true;
 }
 #endif
 }  // namespace session
...
@@ -52,7 +52,7 @@ using OpRunInfoPtr = std::shared_ptr<OpRunInfo>;
 class SessionBasic {
  public:
-  SessionBasic() : context_(nullptr), summary_callback_(nullptr), device_id_(0), ps_init_(false) {
+  SessionBasic() : context_(nullptr), summary_callback_(nullptr), device_id_(0) {
 #ifdef ENABLE_DEBUGGER
     debugger_ = nullptr;
 #endif
@@ -146,7 +146,6 @@ class SessionBasic {
   CallBackFunc summary_callback_;
   static GraphId graph_sum_;
   uint32_t device_id_;
-  bool ps_init_;
 #ifdef ENABLE_DEBUGGER
   std::shared_ptr<Debugger> debugger_;
 #endif
...
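Review note: with the session-level ps_init_ flag gone, InitPSParamAndOptim runs on every RunGraph call for every backend. This is safe only because initialization is already de-duplicated per parameter key on the worker side (IsKeyInit in worker.h). A standalone sketch of that idempotency follows; the flat functions and the literal key value are illustrative only.

// A minimal sketch, assuming (as worker.h shows via IsKeyInit/init_keys_) that the worker
// de-duplicates initialization per parameter key, so no session-level flag is needed.
#include <cstddef>
#include <iostream>
#include <map>

std::map<size_t, bool> init_keys_;

bool IsKeyInit(size_t key) { return init_keys_.count(key) != 0 && init_keys_[key]; }

void InitPSParamAndOptim(size_t key) {
  if (IsKeyInit(key)) {
    return;  // later RunGraph calls become no-ops for this key
  }
  init_keys_[key] = true;
  std::cout << "initialized key " << key << std::endl;
}

int main() {
  InitPSParamAndOptim(3);  // first RunGraph: initializes
  InitPSParamAndOptim(3);  // subsequent RunGraph: skipped
  return 0;
}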
@@ -81,13 +81,21 @@ OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight,
   AddressPtr m = std::make_shared<kernel::Address>();
   m->addr = new float[weight->size()];
   m->size = weight->size() * sizeof(float);
+  int ret = memset_s(m->addr, m->size, 0x00, m->size);
+  if (ret != 0) {
+    MS_LOG(EXCEPTION) << "memset_s error, errorno(" << ret << ")";
+  }
   AddressPtr v = std::make_shared<kernel::Address>();
   v->addr = new float[weight->size()];
   v->size = weight->size() * sizeof(float);
+  ret = memset_s(v->addr, v->size, 0x00, v->size);
+  if (ret != 0) {
+    MS_LOG(EXCEPTION) << "memset_s error, errorno(" << ret << ")";
+  }
   void *data_ptr = values.data();
   void *copy_data_ptr = new float[values.size()];
-  auto ret = memcpy_s(copy_data_ptr, values.size() * sizeof(float), data_ptr, values.size() * sizeof(float));
+  ret = memcpy_s(copy_data_ptr, values.size() * sizeof(float), data_ptr, values.size() * sizeof(float));
   if (ret != 0) {
     MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
   }
@@ -120,10 +128,10 @@ OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight,
     std::accumulate((*grad_shape).begin(), (*grad_shape).end(), sizeof(float), std::multiplies<size_t>());
   AddressPtr grad = std::make_shared<kernel::Address>();
   grad->addr = new float[total_grad_size * worker_num];
-  auto ret2 = memcpy_s(grad->addr, lens[6] * sizeof(float), reinterpret_cast<float *>(epsilon->addr) + lens[5],
-                       lens[6] * sizeof(float));
-  if (ret2 != 0) {
-    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret2 << ")";
+  ret = memcpy_s(grad->addr, lens[6] * sizeof(float), reinterpret_cast<float *>(epsilon->addr) + lens[5],
+                 lens[6] * sizeof(float));
+  if (ret != 0) {
+    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
   }
   grad->size = lens[6] * sizeof(float);
@@ -132,10 +140,10 @@ OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight,
     std::accumulate((*indices_shape).begin(), (*indices_shape).end(), sizeof(float), std::multiplies<size_t>());
   AddressPtr indices = std::make_shared<kernel::Address>();
   indices->addr = new float[total_indice_size * worker_num];
-  auto ret3 = memcpy_s(indices->addr, lens[7] * sizeof(float),
-                       reinterpret_cast<float *>(epsilon->addr) + lens[5] + lens[6], lens[7] * sizeof(float));
-  if (ret3 != 0) {
-    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret3 << ")";
+  ret = memcpy_s(indices->addr, lens[7] * sizeof(float), reinterpret_cast<float *>(epsilon->addr) + lens[5] + lens[6],
+                 lens[7] * sizeof(float));
+  if (ret != 0) {
+    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
   }
   indices->size = lens[7] * sizeof(int);
@@ -160,7 +168,7 @@ OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight,
   linear->addr = new float[weight->size()];
   auto ret = memset_s(linear->addr, weight->size() * sizeof(float), 0x00, weight->size() * sizeof(float));
   if (ret != 0) {
-    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
+    MS_LOG(EXCEPTION) << "memset_s error, errorno(" << ret << ")";
   }
   linear->size = weight->size() * sizeof(float);
...
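Review note: the Adam moment buffers m and v are now zeroed with memset_s before use, and a single ret variable is reused for all of the secure-function status checks. A compile-anywhere sketch of the zero-then-copy pattern is below; it uses the standard memset/memcpy with an explicit size check as stand-ins for the securec memset_s/memcpy_s wrappers used in the real code.

// A compile-anywhere sketch of the zero-then-copy pattern added for the Adam m/v buffers.
// Standard memset/memcpy stand in for the securec memset_s/memcpy_s, whose checked
// return status is imitated here by an explicit destination-size check.
#include <cstddef>
#include <cstring>
#include <iostream>
#include <vector>

int main() {
  const size_t weight_size = 8;
  std::vector<float> m(weight_size), v(weight_size);
  // Zero both Adam moment buffers before first use instead of leaving them uninitialized.
  std::memset(m.data(), 0x00, m.size() * sizeof(float));
  std::memset(v.data(), 0x00, v.size() * sizeof(float));

  std::vector<float> values = {0.9f, 0.999f, 1e-8f};
  std::vector<float> copy(values.size());
  // memcpy_s also verifies the destination size; mimic that check before copying.
  if (copy.size() * sizeof(float) >= values.size() * sizeof(float)) {
    std::memcpy(copy.data(), values.data(), values.size() * sizeof(float));
  }
  std::cout << m[0] << " " << v[0] << " " << copy[0] << std::endl;  // 0 0 0.9
  return 0;
}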
@@ -208,11 +208,6 @@ void ParameterServer<T>::ServerHandler::HandleInitWeights(const ::ps::KVMeta &re
   size_t pos = 0;
   for (size_t i = 0; i < key_num; i++) {
     Key key = req_data.keys[i];
-    if (init_weights_[key]) {
-      continue;
-    } else {
-      init_weights_[key] = true;
-    }
     size_t data_len = req_data.lens.size() != key_num ? req_data.vals.size() / key_num : req_data.lens[i];
     WeightPtr weight_ptr = std::make_shared<::ps::SArray<T>>();
@@ -261,11 +256,6 @@ void ParameterServer<T>::ServerHandler::HandleInitEmbeddings(const ::ps::KVMeta
                                                              const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
   std::unique_lock<std::mutex> lock(ps_->mutex());
   const Key &key = req_data.keys[0];
-  if (init_weights_[key]) {
-    return;
-  } else {
-    init_weights_[key] = true;
-  }
   std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> shapes =
     std::make_shared<std::vector<std::shared_ptr<std::vector<size_t>>>>();
   std::shared_ptr<std::vector<size_t>> input_shape = std::make_shared<std::vector<size_t>>();
@@ -418,7 +408,7 @@ const CNodePtr ParameterServer<T>::GetCNode(const std::string &name) const {
 template <typename T>
 void ParameterServer<T>::InitWeight(const Key &key, const WeightPtr &weight) {
   MS_LOG(INFO) << "Initializing weight for key " << key;
-  if (weights_.count(key) == 0) {
+  if ((weights_.count(key) == 0) || (is_embedding_[key] && weights_.count(key) != 0)) {
     weights_[key] = weight;
     tokens_[key] = 0;
     is_embedding_[key] = false;
...
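Review note: the per-key dedup that used to short-circuit HandleInitWeights and HandleInitEmbeddings is removed, and InitWeight itself now decides: a weight is written when it has never been stored, or overwritten when it is already stored but marked as an embedding table. A standalone sketch of that condition follows, with simplified stand-ins for Key, WeightPtr and the server's maps.

// A minimal sketch of the relaxed InitWeight condition: write if never stored, or
// overwrite if the stored entry is an embedding table. Key, Weight and the two maps
// are reduced stand-ins for the server's types.
#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

using Key = uint64_t;
using Weight = std::vector<float>;

std::map<Key, Weight> weights_;
std::map<Key, bool> is_embedding_;

void InitWeight(const Key &key, const Weight &weight) {
  if ((weights_.count(key) == 0) || (is_embedding_[key] && weights_.count(key) != 0)) {
    weights_[key] = weight;
  }
}

int main() {
  is_embedding_[1] = true;   // embedding table: a repeated init request replaces the weight
  is_embedding_[2] = false;  // dense parameter: only the first init takes effect
  InitWeight(1, {0.0f});
  InitWeight(1, {1.0f});
  InitWeight(2, {0.0f});
  InitWeight(2, {1.0f});
  std::cout << weights_[1][0] << " " << weights_[2][0] << std::endl;  // prints: 1 0
  return 0;
}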
@@ -24,6 +24,7 @@
 #include <map>
 #include "ps/ps.h"
 #include "utils/log_adapter.h"
+#include "ir/tensor.h"
 #include "frontend/parallel/ps/util.h"
 #include "frontend/parallel/ps/common.h"
 #include "frontend/parallel/ps/worker_proxy.h"
@@ -43,12 +44,13 @@ class Worker {
   void Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs, const std::vector<int> &sizes);
   void Pull(const size_t key, void *dev_addr, const size_t size);
   size_t SetParamKey(const std::string &param_name);
+  void SetParamInitInServer(const std::string &param_name, bool init_in_server);
+  bool GetParamInitInServer(const std::string &param_name);
   void SetKeyOptimId(size_t key, const std::string &optimizer_name);
   void SetOptimInputShapes(size_t key, const std::vector<int> &shape);
   void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count);
   void InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vector<size_t> shapes, const std::vector<int> &sizes);
-  void InitPSParamAndOptim(const std::string &param_name, void *param_data, size_t param_size,
-                           bool init_in_server = false);
+  void InitPSParamAndOptim(const std::string &param_name, tensor::TensorPtr tensor);
   void DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
                            const ::ps::SArray<int> &lens, ::ps::SArray<T> *lookup_result, int cmd);
   void Finalize();
@@ -74,6 +76,7 @@ class Worker {
   std::map<size_t, bool> init_keys_;
   std::map<size_t, int> key_to_optimId_;
   std::map<size_t, std::vector<std::vector<int>>> key_to_optim_shapes_;
+  std::map<std::string, bool> param_to_init_in_server_;
 };

 template <typename T>
@@ -208,6 +211,20 @@ size_t Worker<T>::SetParamKey(const std::string &param_name) {
   return key;
 }

+template <typename T>
+void Worker<T>::SetParamInitInServer(const std::string &param_name, bool init_in_server) {
+  MS_LOG(INFO) << "Set parameter " << param_name << " init_in_server:" << init_in_server;
+  param_to_init_in_server_[param_name] = init_in_server;
+}
+
+template <typename T>
+bool Worker<T>::GetParamInitInServer(const std::string &param_name) {
+  if (param_to_init_in_server_.count(param_name) == 0) {
+    return false;
+  }
+  return param_to_init_in_server_[param_name];
+}
+
 template <typename T>
 size_t Worker<T>::GetParamKey(const std::string &param_name) {
   size_t key = kInvalidKey;
@@ -253,13 +270,22 @@ void Worker<T>::InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vecto
 template <typename T>
 // Initialize parameters and optimizer kernels of Parameter Server.
-void Worker<T>::InitPSParamAndOptim(const std::string &param_name, void *param_data, size_t param_size,
-                                    bool init_in_server) {
+void Worker<T>::InitPSParamAndOptim(const std::string &param_name, tensor::TensorPtr tensor) {
+  void *param_data = tensor->data_c();
+  size_t param_size = LongToSize(tensor->data().nbytes());
+  std::vector<int> param_shape = tensor->shape_c();
   size_t param_key = GetParamKey(param_name);
   if (param_key == kInvalidKey) {
     MS_LOG(INFO) << "Parameter " << param_name << " has no key assigned.";
     return;
   }
+  bool init_in_server = false;
+  std::vector<int> shape_init_in_server = {1};
+  if (param_shape == shape_init_in_server) {
+    init_in_server = true;
+  }
+  SetParamInitInServer(param_name, init_in_server);
   bool init = IsKeyInit(param_key);
   if (!init) {
     MS_LOG(INFO) << "Init paramter and optimizer in parameter server side for " << param_name
...
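Review note: the shape-based detection that used to live in SessionBasic::InitPSParamAndOptim now happens inside Worker<T>::InitPSParamAndOptim, which receives the whole tensor and records the result via SetParamInitInServer. A minimal sketch of that heuristic follows; treating a {1}-shaped host tensor as "initialize in server" is taken from the diff itself, while ShouldInitInServer and the sample shapes are hypothetical.

// A minimal sketch of the shape heuristic now inside Worker<T>::InitPSParamAndOptim:
// a host-side tensor whose shape is exactly {1} is treated as "initialize in server".
// std::vector<int> stands in for the Tensor shape.
#include <iostream>
#include <vector>

bool ShouldInitInServer(const std::vector<int> &param_shape) {
  static const std::vector<int> shape_init_in_server = {1};
  return param_shape == shape_init_in_server;
}

int main() {
  std::cout << std::boolalpha
            << ShouldInitInServer({1}) << "\n"           // true: placeholder, server initializes it
            << ShouldInitInServer({16000, 80}) << "\n";  // false: worker pushes the real data
  return 0;
}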