提交 4281f380 编写于 作者: Z ZPaC

Delete hard code in pull kernel.

上级 e0a2d2f9
......@@ -33,8 +33,9 @@ class PullKernel : public CPUKernel {
~PullKernel() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &) {
// If the paramter is embedding table, don't Pull from PServer.
if (param_name_.find("embedding") == std::string::npos && param_name_.find("wide_w") == std::string::npos) {
bool init_in_server = mindspore::parallel::ps::Worker<float>::GetInstance().GetParamInitInServer(param_name_);
// If init_in_server, forward kernel should run in server too.
if (!init_in_server) {
parallel::ps::Worker<T>::GetInstance().Pull(key_, inputs[1]->addr, inputs[1]->size);
}
return true;
......
......@@ -325,9 +325,7 @@ void AscendSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::
LoadInputData(kernel_graph, inputs);
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
// Initialize parameter server
if (!ps_init_) {
InitPSParamAndOptim(kernel_graph, inputs);
}
InitPSParamAndOptim(kernel_graph, inputs);
#endif
{
py::gil_scoped_release release;
......
......@@ -89,10 +89,7 @@ void CPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten
auto &kernel_graph = graphs_[graph_id];
MS_EXCEPTION_IF_NULL(kernel_graph);
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
// Initialize parameter server
if (!ps_init_) {
InitPSParamAndOptim(kernel_graph, inputs);
}
InitPSParamAndOptim(kernel_graph, inputs);
#endif
MS_LOG(INFO) << "Bind input output address";
std::vector<tensor::TensorPtr> need_sync_outputs;
......
......@@ -237,9 +237,7 @@ void GPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten
LoadInputData(kernel_graph, inputs);
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
// Initialize parameter server
if (!ps_init_) {
InitPSParamAndOptim(kernel_graph, inputs);
}
InitPSParamAndOptim(kernel_graph, inputs);
#endif
MS_EXCEPTION_IF_NULL(kernel_graph);
{
......
......@@ -1225,7 +1225,6 @@ void SessionBasic::InitPSParamAndOptim(const KernelGraphPtr &kernel_graph,
}
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
std::vector<int> shape_init_in_server = {1};
for (size_t i = 0; i < inputs.size(); ++i) {
auto tensor = inputs[i];
MS_EXCEPTION_IF_NULL(tensor);
......@@ -1233,16 +1232,9 @@ void SessionBasic::InitPSParamAndOptim(const KernelGraphPtr &kernel_graph,
MS_EXCEPTION_IF_NULL(input_node);
if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
auto pk_node = input_node->cast<ParameterPtr>();
bool init_in_server = false;
if (tensor->shape_c() == shape_init_in_server) {
MS_LOG(INFO) << "The param need to be initialized in server " << pk_node->fullname_with_scope();
init_in_server = true;
}
mindspore::parallel::ps::Worker<float>::GetInstance().InitPSParamAndOptim(
pk_node->fullname_with_scope(), tensor->data_c(), LongToSize(tensor->data().nbytes()), init_in_server);
mindspore::parallel::ps::Worker<float>::GetInstance().InitPSParamAndOptim(pk_node->fullname_with_scope(), tensor);
}
}
ps_init_ = true;
}
#endif
} // namespace session
......
......@@ -52,7 +52,7 @@ using OpRunInfoPtr = std::shared_ptr<OpRunInfo>;
class SessionBasic {
public:
SessionBasic() : context_(nullptr), summary_callback_(nullptr), device_id_(0), ps_init_(false) {
SessionBasic() : context_(nullptr), summary_callback_(nullptr), device_id_(0) {
#ifdef ENABLE_DEBUGGER
debugger_ = nullptr;
#endif
......@@ -146,7 +146,6 @@ class SessionBasic {
CallBackFunc summary_callback_;
static GraphId graph_sum_;
uint32_t device_id_;
bool ps_init_;
#ifdef ENABLE_DEBUGGER
std::shared_ptr<Debugger> debugger_;
#endif
......
......@@ -81,13 +81,21 @@ OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight,
AddressPtr m = std::make_shared<kernel::Address>();
m->addr = new float[weight->size()];
m->size = weight->size() * sizeof(float);
int ret = memset_s(m->addr, m->size, 0x00, m->size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
}
AddressPtr v = std::make_shared<kernel::Address>();
v->addr = new float[weight->size()];
v->size = weight->size() * sizeof(float);
ret = memset_s(v->addr, v->size, 0x00, v->size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
}
void *data_ptr = values.data();
void *copy_data_ptr = new float[values.size()];
auto ret = memcpy_s(copy_data_ptr, values.size() * sizeof(float), data_ptr, values.size() * sizeof(float));
ret = memcpy_s(copy_data_ptr, values.size() * sizeof(float), data_ptr, values.size() * sizeof(float));
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
}
......@@ -120,10 +128,10 @@ OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight,
std::accumulate((*grad_shape).begin(), (*grad_shape).end(), sizeof(float), std::multiplies<size_t>());
AddressPtr grad = std::make_shared<kernel::Address>();
grad->addr = new float[total_grad_size * worker_num];
auto ret2 = memcpy_s(grad->addr, lens[6] * sizeof(float), reinterpret_cast<float *>(epsilon->addr) + lens[5],
lens[6] * sizeof(float));
if (ret2 != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret2 << ")";
ret = memcpy_s(grad->addr, lens[6] * sizeof(float), reinterpret_cast<float *>(epsilon->addr) + lens[5],
lens[6] * sizeof(float));
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
}
grad->size = lens[6] * sizeof(float);
......@@ -132,10 +140,10 @@ OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight,
std::accumulate((*indices_shape).begin(), (*indices_shape).end(), sizeof(float), std::multiplies<size_t>());
AddressPtr indices = std::make_shared<kernel::Address>();
indices->addr = new float[total_indice_size * worker_num];
auto ret3 = memcpy_s(indices->addr, lens[7] * sizeof(float),
reinterpret_cast<float *>(epsilon->addr) + lens[5] + lens[6], lens[7] * sizeof(float));
if (ret3 != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret3 << ")";
ret = memcpy_s(indices->addr, lens[7] * sizeof(float), reinterpret_cast<float *>(epsilon->addr) + lens[5] + lens[6],
lens[7] * sizeof(float));
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
}
indices->size = lens[7] * sizeof(int);
......@@ -160,7 +168,7 @@ OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight,
linear->addr = new float[weight->size()];
auto ret = memset_s(linear->addr, weight->size() * sizeof(float), 0x00, weight->size() * sizeof(float));
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
MS_LOG(EXCEPTION) << "memset_s error, errorno(" << ret << ")";
}
linear->size = weight->size() * sizeof(float);
......
......@@ -208,11 +208,6 @@ void ParameterServer<T>::ServerHandler::HandleInitWeights(const ::ps::KVMeta &re
size_t pos = 0;
for (size_t i = 0; i < key_num; i++) {
Key key = req_data.keys[i];
if (init_weights_[key]) {
continue;
} else {
init_weights_[key] = true;
}
size_t data_len = req_data.lens.size() != key_num ? req_data.vals.size() / key_num : req_data.lens[i];
WeightPtr weight_ptr = std::make_shared<::ps::SArray<T>>();
......@@ -261,11 +256,6 @@ void ParameterServer<T>::ServerHandler::HandleInitEmbeddings(const ::ps::KVMeta
const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
std::unique_lock<std::mutex> lock(ps_->mutex());
const Key &key = req_data.keys[0];
if (init_weights_[key]) {
return;
} else {
init_weights_[key] = true;
}
std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> shapes =
std::make_shared<std::vector<std::shared_ptr<std::vector<size_t>>>>();
std::shared_ptr<std::vector<size_t>> input_shape = std::make_shared<std::vector<size_t>>();
......@@ -418,7 +408,7 @@ const CNodePtr ParameterServer<T>::GetCNode(const std::string &name) const {
template <typename T>
void ParameterServer<T>::InitWeight(const Key &key, const WeightPtr &weight) {
MS_LOG(INFO) << "Initializing weight for key " << key;
if (weights_.count(key) == 0) {
if ((weights_.count(key) == 0) || (is_embedding_[key] && weights_.count(key) != 0)) {
weights_[key] = weight;
tokens_[key] = 0;
is_embedding_[key] = false;
......
......@@ -24,6 +24,7 @@
#include <map>
#include "ps/ps.h"
#include "utils/log_adapter.h"
#include "ir/tensor.h"
#include "frontend/parallel/ps/util.h"
#include "frontend/parallel/ps/common.h"
#include "frontend/parallel/ps/worker_proxy.h"
......@@ -43,12 +44,13 @@ class Worker {
void Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs, const std::vector<int> &sizes);
void Pull(const size_t key, void *dev_addr, const size_t size);
size_t SetParamKey(const std::string &param_name);
void SetParamInitInServer(const std::string &param_name, bool init_in_server);
bool GetParamInitInServer(const std::string &param_name);
void SetKeyOptimId(size_t key, const std::string &optimizer_name);
void SetOptimInputShapes(size_t key, const std::vector<int> &shape);
void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count);
void InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vector<size_t> shapes, const std::vector<int> &sizes);
void InitPSParamAndOptim(const std::string &param_name, void *param_data, size_t param_size,
bool init_in_server = false);
void InitPSParamAndOptim(const std::string &param_name, tensor::TensorPtr tensor);
void DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
const ::ps::SArray<int> &lens, ::ps::SArray<T> *lookup_result, int cmd);
void Finalize();
......@@ -74,6 +76,7 @@ class Worker {
std::map<size_t, bool> init_keys_;
std::map<size_t, int> key_to_optimId_;
std::map<size_t, std::vector<std::vector<int>>> key_to_optim_shapes_;
std::map<std::string, bool> param_to_init_in_server_;
};
template <typename T>
......@@ -208,6 +211,20 @@ size_t Worker<T>::SetParamKey(const std::string &param_name) {
return key;
}
template <typename T>
void Worker<T>::SetParamInitInServer(const std::string &param_name, bool init_in_server) {
MS_LOG(INFO) << "Set parameter " << param_name << " init_in_server:" << init_in_server;
param_to_init_in_server_[param_name] = init_in_server;
}
template <typename T>
bool Worker<T>::GetParamInitInServer(const std::string &param_name) {
if (param_to_init_in_server_.count(param_name) == 0) {
return false;
}
return param_to_init_in_server_[param_name];
}
template <typename T>
size_t Worker<T>::GetParamKey(const std::string &param_name) {
size_t key = kInvalidKey;
......@@ -253,13 +270,22 @@ void Worker<T>::InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vecto
template <typename T>
// Initialize parameters and optimizer kernels of Parameter Server.
void Worker<T>::InitPSParamAndOptim(const std::string &param_name, void *param_data, size_t param_size,
bool init_in_server) {
void Worker<T>::InitPSParamAndOptim(const std::string &param_name, tensor::TensorPtr tensor) {
void *param_data = tensor->data_c();
size_t param_size = LongToSize(tensor->data().nbytes());
std::vector<int> param_shape = tensor->shape_c();
size_t param_key = GetParamKey(param_name);
if (param_key == kInvalidKey) {
MS_LOG(INFO) << "Parameter " << param_name << " has no key assigned.";
return;
}
bool init_in_server = false;
std::vector<int> shape_init_in_server = {1};
if (param_shape == shape_init_in_server) {
init_in_server = true;
}
SetParamInitInServer(param_name, init_in_server);
bool init = IsKeyInit(param_key);
if (!init) {
MS_LOG(INFO) << "Init paramter and optimizer in parameter server side for " << param_name
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册