提交 8943cdbc 编写于 作者: Z ZPaC

Remove hard-coded parameter names.

上级 26de81a2
...@@ -1194,6 +1194,7 @@ void SessionBasic::InitPSParamAndOptim(const KernelGraphPtr &kernel_graph, ...@@ -1194,6 +1194,7 @@ void SessionBasic::InitPSParamAndOptim(const KernelGraphPtr &kernel_graph,
} }
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
std::vector<int> shape_init_in_server = {1};
for (size_t i = 0; i < inputs.size(); ++i) { for (size_t i = 0; i < inputs.size(); ++i) {
auto tensor = inputs[i]; auto tensor = inputs[i];
MS_EXCEPTION_IF_NULL(tensor); MS_EXCEPTION_IF_NULL(tensor);
...@@ -1201,8 +1202,13 @@ void SessionBasic::InitPSParamAndOptim(const KernelGraphPtr &kernel_graph, ...@@ -1201,8 +1202,13 @@ void SessionBasic::InitPSParamAndOptim(const KernelGraphPtr &kernel_graph,
MS_EXCEPTION_IF_NULL(input_node); MS_EXCEPTION_IF_NULL(input_node);
if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) { if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
auto pk_node = input_node->cast<ParameterPtr>(); auto pk_node = input_node->cast<ParameterPtr>();
bool init_in_server = false;
if (tensor->shape_c() == shape_init_in_server) {
MS_LOG(INFO) << "The param need to be initialized in server " << pk_node->fullname_with_scope();
init_in_server = true;
}
mindspore::parallel::ps::Worker<float>::GetInstance().InitPSParamAndOptim( mindspore::parallel::ps::Worker<float>::GetInstance().InitPSParamAndOptim(
pk_node->fullname_with_scope(), tensor->data_c(), LongToSize(tensor->data().nbytes())); pk_node->fullname_with_scope(), tensor->data_c(), LongToSize(tensor->data().nbytes()), init_in_server);
} }
} }
ps_init_ = true; ps_init_ = true;
......
...@@ -530,6 +530,10 @@ inline bool ParameterServer<T>::ReadyForUpdateWeights() { ...@@ -530,6 +530,10 @@ inline bool ParameterServer<T>::ReadyForUpdateWeights() {
template <typename T> template <typename T>
inline bool ParameterServer<T>::ReadyForAccumGrads() { inline bool ParameterServer<T>::ReadyForAccumGrads() {
if (weights_.empty()) {
MS_LOG(EXCEPTION) << "The weights in server is empty. Many reasons could cause this: 1.The Worker didn't send "
"kInitWeightsCmd command. 2.The Server failed to initialize weights.";
}
return grad_accum_count_ < weights_.size(); return grad_accum_count_ < weights_.size();
} }
......
...@@ -47,7 +47,8 @@ class Worker { ...@@ -47,7 +47,8 @@ class Worker {
void SetOptimInputShapes(size_t key, const std::vector<int> &shape); void SetOptimInputShapes(size_t key, const std::vector<int> &shape);
void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count); void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count);
void InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vector<size_t> shapes, const std::vector<int> &sizes); void InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vector<size_t> shapes, const std::vector<int> &sizes);
void InitPSParamAndOptim(const std::string &param_name, void *param_data, size_t param_size); void InitPSParamAndOptim(const std::string &param_name, void *param_data, size_t param_size,
bool init_in_server = false);
void DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids, void DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
const ::ps::SArray<int> &lens, ::ps::SArray<T> *lookup_result, int cmd); const ::ps::SArray<int> &lens, ::ps::SArray<T> *lookup_result, int cmd);
void Finalize(); void Finalize();
...@@ -237,7 +238,8 @@ void Worker<T>::InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vecto ...@@ -237,7 +238,8 @@ void Worker<T>::InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vecto
template <typename T> template <typename T>
// Initialize parameters and optimizer kernels of Parameter Server. // Initialize parameters and optimizer kernels of Parameter Server.
void Worker<T>::InitPSParamAndOptim(const std::string &param_name, void *param_data, size_t param_size) { void Worker<T>::InitPSParamAndOptim(const std::string &param_name, void *param_data, size_t param_size,
bool init_in_server) {
size_t param_key = GetParamKey(param_name); size_t param_key = GetParamKey(param_name);
if (param_key == kInvalidKey) { if (param_key == kInvalidKey) {
MS_LOG(INFO) << "Parameter " << param_name << " has no key assigned."; MS_LOG(INFO) << "Parameter " << param_name << " has no key assigned.";
...@@ -245,9 +247,9 @@ void Worker<T>::InitPSParamAndOptim(const std::string &param_name, void *param_d ...@@ -245,9 +247,9 @@ void Worker<T>::InitPSParamAndOptim(const std::string &param_name, void *param_d
} }
bool init = IsKeyInit(param_key); bool init = IsKeyInit(param_key);
if (!init) { if (!init) {
MS_LOG(INFO) << "Init paramter and optimizer in parameter server side for " << param_name; MS_LOG(INFO) << "Init paramter and optimizer in parameter server side for " << param_name
// No need to push embedding table data to Parameter Server. << ", whether init in server: " << init_in_server;
if (param_name.find("embedding_table") == std::string::npos && param_name.find("wide_w") == std::string::npos) { if (!init_in_server) {
InitPSParamData({param_key}, param_data, param_size); InitPSParamData({param_key}, param_data, param_size);
} }
InitPSOptimId(param_key); InitPSOptimId(param_key);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册