提交 52022c80 编写于 作者: Z ZPaC

Enable to train in parameter server mode

上级 4936fe48
set(glog_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2 ${SECURE_CXX_FLAGS}")
set(glog_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2 ${SECURE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
set(glog_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
mindspore_add_pkg(glog
VER 0.4.0
......
......@@ -123,3 +123,7 @@ endif()
if(ENABLE_DEBUGGER)
add_compile_definitions(ENABLE_DEBUGGER)
endif()
if(ENABLE_TESTCASES)
add_compile_definitions(ENABLE_TESTCASES)
endif()
\ No newline at end of file
......@@ -26,14 +26,6 @@ if (ENABLE_CPU)
"cpu/*.cc"
)
list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/push_kernel.cc"
"cpu/ps/pull_kernel.cc"
"cpu/ps/embedding_look_up_ps_kernel.cc"
"cpu/ps/embedding_look_up_proxy_kernel.cc"
"cpu/ps/apply_momentum_ps_kernel.cc"
"cpu/ps/sparse_apply_adam_ps_kernel.cc"
"cpu/ps/sparse_apply_ftrl_ps_kernel.cc")
if (NOT ENABLE_MPI)
list(REMOVE_ITEM CPU_SRC_LIST "cpu/allgather_cpu_kernel.cc")
list(REMOVE_ITEM CPU_SRC_LIST "cpu/reduce_scatter_cpu_kernel.cc")
......@@ -41,6 +33,17 @@ if (ENABLE_CPU)
endif ()
endif ()
if (${CMAKE_SYSTEM_NAME} MATCHES "Windows" OR ENABLE_GE)
list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/apply_momentum_ps_kernel.cc")
list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/embedding_look_up_proxy_kernel.cc")
list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/embedding_look_up_ps_kernel.cc")
list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/pserver_kernel.cc")
list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/pull_kernel.cc")
list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/push_kernel.cc")
list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/sparse_apply_adam_ps_kernel.cc")
list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/sparse_apply_ftrl_ps_kernel.cc")
endif()
if (ENABLE_GPU)
file(GLOB_RECURSE CUDA_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"gpu/*.cu"
......
......@@ -46,7 +46,7 @@ class EmbeddingLookUpCPUKernel : public CPUKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
private:
protected:
void LookUpTable(const std::vector<kernel::AddressPtr> &inputs, size_t dim0, size_t dim1, size_t dim2,
float **output_addr);
void CheckParam(const CNodePtr &kernel_node);
......
......@@ -53,15 +53,15 @@ bool EmbeddingLookUpProxyKernel::Launch(const std::vector<kernel::AddressPtr> &i
size_t output_size = outputs[0]->size;
size_t size = input_size / sizeof(float);
::ps::SArray<float> lookup_ids(size, 0);
::ps::SArray<int> lookup_ids(size, 0);
::ps::SArray<int> lengths{size};
::ps::SArray<float> lookup_result;
::ps::SArray<float> lookup_result(output_size / sizeof(float), 0);
auto ret = memcpy_s(lookup_ids.data(), input_size, indices_addr, input_size);
if (ret != EOK) {
MS_LOG(EXCEPTION) << "Lookup id memcpy failed.";
}
parallel::ps::Worker<float>::GetInstance().DoPSEmbeddingLookup({key_}, lookup_ids, lengths, lookup_result,
parallel::ps::Worker<float>::GetInstance().DoPSEmbeddingLookup({key_}, lookup_ids, lengths, &lookup_result,
parallel::ps::kEmbeddingLookupCmd);
auto ret2 = memcpy_s(output_addr, output_size, lookup_result.data(), output_size);
......
......@@ -50,7 +50,7 @@ void EmbeddingLookUpPSKernel::InitKernel(
split_num_ = pserver_num_;
// input shape should be sharded after computing offset_;
Shard(input_shape_, axis_);
Shard(&input_shape_, axis_);
size_t output_size =
std::accumulate(output_shape_.begin(), output_shape_.end(), sizeof(float), std::multiplies<size_t>());
......
......@@ -34,5 +34,13 @@ MS_REG_CPU_KERNEL_T(Push,
MS_REG_CPU_KERNEL_T(
Push, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
PushKernel, float);
MS_REG_CPU_KERNEL_T(Push,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeUInt64),
PushKernel, float);
} // namespace kernel
} // namespace mindspore
......@@ -43,7 +43,7 @@ class PushKernel : public CPUKernel {
sizes.push_back(SizeToInt(input->size) / sizeof(T));
}
parallel::ps::Worker<T>::GetInstance().Push(keys, addrs, sizes);
memcpy(outputs[0]->addr, &key_, sizeof(size_t));
memcpy_s(outputs[0]->addr, sizeof(size_t), &key_, sizeof(size_t));
return true;
}
......
......@@ -75,7 +75,7 @@ void SparseApplyAdamPSKernel::ReInit(const std::shared_ptr<std::vector<std::shar
void SparseApplyAdamPSKernel::ReInit(const std::vector<AddressPtr> &inputs) {
const auto &indices_addr = inputs[10];
indices_size_ = indices_addr->size;
indices_size_ = indices_addr->size / sizeof(int);
workspace_size_list_[0] = indices_size_ * var_outer_dim_size_ * sizeof(float);
workspace_size_list_[1] = indices_size_ * sizeof(int);
}
......
......@@ -64,7 +64,7 @@ void SparseApplyFtrlPSKernel::ReInit(const std::shared_ptr<std::vector<std::shar
void SparseApplyFtrlPSKernel::ReInit(const std::vector<AddressPtr> &inputs) {
const auto &indices_addr = inputs[4];
indices_size_ = indices_addr->size;
indices_size_ = indices_addr->size / sizeof(int);
workspace_size_list_[0] = indices_size_ * var_outer_dim_size_ * sizeof(float);
workspace_size_list_[1] = indices_size_ * sizeof(int);
}
......
......@@ -71,7 +71,6 @@ bool ReplaceNodeByProxy::Run(const FuncGraphPtr &func_graph) {
AbstractBasePtrList abstract_list;
AnfAlgo::CopyNodeAttr(kAttrPsKey, cnode, proxy_node);
AnfAlgo::CopyNodeAttr("reduce_scatter_flag", cnode, proxy_node);
AnfAlgo::CopyNodeAttr("offset", cnode, proxy_node);
abstract_list.push_back(cnode->abstract());
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
......
......@@ -353,6 +353,10 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
RootGraphExecutorValidate(NOT_NULL(root_graph));
// adjust kernel
AdjustKernel(root_graph);
#if (!_WIN32 && !ENABLE_GE && !ENABLE_TESTCASES)
// Assign parameter keys.
AssignParamKey(root_graph);
#endif
// assign stream
AssignStream(NOT_NULL(root_graph));
// insert profiling point
......@@ -511,6 +515,12 @@ void AscendSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::
}
// load input data from user input
LoadInputData(kernel_graph, inputs);
#if (!_WIN32 && !ENABLE_GE && !ENABLE_TESTCASES)
// Initialize parameter server
if (!ps_init_) {
InitPSParamAndOptim(kernel_graph, inputs);
}
#endif
// convert inputs to model
predictmodel::StepConvertWeight(inputs);
{
......
......@@ -25,9 +25,15 @@
#include "predict/predict.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "runtime/device/cpu/kernel_select_cpu.h"
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/common/pass_manager.h"
#include "backend/optimizer/pass/replace_node_by_proxy.h"
#ifdef ENABLE_DEBUGGER
#include "debug/debugger/debugger.h"
#endif
#if (!_WIN32 && !ENABLE_GE && !ENABLE_TESTCASES)
#include "frontend/parallel/ps/util.h"
#endif
namespace mindspore {
namespace session {
......@@ -49,12 +55,29 @@ ParameterPtr CPUSession::CreateNewParameterFromParameter(const AnfNodePtr &anf,
return new_parameter;
}
void CPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
std::string pass_name = "replace_node_by_proxy";
pass_name.append(std::to_string(graph_sum_));
pm->AddPass(std::make_shared<opt::ReplaceNodeByProxy>(pass_name));
optimizer->AddPassManager(pm);
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();
}
GraphId CPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
auto graph_id = graph_sum_;
auto graph = ConstructKernelGraph(lst, outputs);
MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "Set kernel info";
SetKernelInfo(graph.get());
#if (!_WIN32 && !ENABLE_GE && !ENABLE_TESTCASES)
AssignParamKey(graph);
if (parallel::ps::Util::IsRoleOfWorker()) {
Optimize(graph);
}
#endif
predictmodel::StepConvertGraph(graph);
MS_LOG(INFO) << "Build kernel";
BuildKernel(graph.get());
......@@ -66,6 +89,12 @@ GraphId CPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
void CPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
auto &kernel_graph = graphs_[graph_id];
MS_EXCEPTION_IF_NULL(kernel_graph);
#if (!_WIN32 && !ENABLE_GE && !ENABLE_TESTCASES)
// Initialize parameter server
if (!ps_init_) {
InitPSParamAndOptim(kernel_graph, inputs);
}
#endif
MS_LOG(INFO) << "Bind input output address";
std::vector<tensor::TensorPtr> need_sync_outputs;
runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs, &need_sync_outputs);
......
......@@ -37,6 +37,7 @@ class CPUSession : public SessionBasic {
protected:
ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) override;
void Optimize(const std::shared_ptr<KernelGraph> &kernel_graph);
private:
void SetKernelInfo(const KernelGraph *kernel_graph);
......
......@@ -167,6 +167,10 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
Optimize(graph);
// Select kernel build info
SelectKernel(graph);
#if (!_WIN32 && !ENABLE_GE && !ENABLE_TESTCASES)
// Assign parameter keys.
AssignParamKey(graph);
#endif
// Convert kernel Graph to model
predictmodel::StepConvertGraph(graph);
// Start gpu kernel runtime
......@@ -204,6 +208,10 @@ void GPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten
auto &kernel_graph = graphs_[graph_id];
// Load input data from user input
LoadInputData(kernel_graph, inputs);
// Initialize parameter server
if (!ps_init_) {
InitPSParamAndOptim(kernel_graph, inputs);
}
MS_EXCEPTION_IF_NULL(kernel_graph);
// Convert inputs to model
predictmodel::StepConvertWeight(inputs);
......
......@@ -35,6 +35,11 @@
#include "ir/dtype.h"
#include "ir/anf.h"
#include "ir/func_graph_cloner.h"
#if (!_WIN32 && !ENABLE_GE && !ENABLE_TESTCASES)
#include "frontend/parallel/ps/worker.h"
#include "frontend/parallel/ps/common.h"
#include "frontend/parallel/ps/util.h"
#endif
namespace mindspore {
namespace session {
......@@ -1097,5 +1102,92 @@ KernelGraphPtr SessionBasic::NewKernelGraph() {
graphs_[graph_sum_++] = graph;
return graph;
}
AnfNodePtr SessionBasic::FindPullNode(const AnfNodePtr &push_node, const std::vector<AnfNodePtr> &node_list) {
MS_EXCEPTION_IF_NULL(push_node);
for (auto &node : node_list) {
if (node != nullptr && node->isa<CNode>()) {
for (auto input : node->cast<CNodePtr>()->inputs()) {
if (push_node == AnfAlgo::VisitKernel(input, 0).first) {
if (AnfAlgo::GetCNodeName(node) != kPullOpName) {
MS_LOG(EXCEPTION) << "The edge between Push and Pull node is invalid.";
}
return node;
}
}
}
}
return nullptr;
}
#if (!_WIN32 && !ENABLE_GE && !ENABLE_TESTCASES)
void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) {
if (!parallel::ps::Util::IsRoleOfWorker()) {
MS_LOG(INFO) << "Not parameter server mode.";
return;
}
MS_EXCEPTION_IF_NULL(kernel_graph);
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph->get_return());
for (auto &node : node_list) {
if (node != nullptr && node->isa<CNode>()) {
// Assign key for forward kernel EmbeddingLookup.
// The key will be assigned to embedding table ande Push kernel as well.
if (AnfAlgo::GetCNodeName(node) == kEmbeddingLookupOpName) {
size_t embedding_table_idx = 0;
auto embedding_table = AnfAlgo::GetInputNode(node->cast<CNodePtr>(), embedding_table_idx);
size_t key = parallel::ps::Worker<float>::GetInstance().SetParamKey(embedding_table->fullname_with_scope());
AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), node);
} else if (AnfAlgo::GetCNodeName(node) == kPushOpName) {
auto pull_node = FindPullNode(node, node_list);
if (!pull_node) {
MS_LOG(EXCEPTION) << "Assigning parameter key failed: can't find Pull node of the Push node.";
}
// Second input of Pull node is the trainable parameter.
size_t parameter_index = 1;
auto parameter_node = AnfAlgo::GetInputNode(pull_node->cast<CNodePtr>(), parameter_index);
size_t key = parallel::ps::Worker<float>::GetInstance().SetParamKey(parameter_node->fullname_with_scope());
AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), node);
AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), pull_node);
std::string optimizer_name = AnfAlgo::GetNodeAttr<std::string>(node, kAttrOptimizerType);
parallel::ps::Worker<float>::GetInstance().SetKeyOptimId(key, optimizer_name);
}
}
}
}
void SessionBasic::InitPSParamAndOptim(const KernelGraphPtr &kernel_graph,
const std::vector<tensor::TensorPtr> &inputs_const) {
if (!parallel::ps::Util::IsRoleOfWorker()) {
return;
}
std::vector<tensor::TensorPtr> inputs(inputs_const);
size_t input_ctrl_size = 1;
MS_EXCEPTION_IF_NULL(kernel_graph);
if (kernel_graph->input_ctrl_tensors()) {
input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs);
}
auto input_nodes = kernel_graph->inputs();
if ((inputs.size() + input_ctrl_size) - 1 != input_nodes.size()) {
MS_LOG(EXCEPTION) << "Tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size()
<< ", input_ctrl_size:" << input_ctrl_size;
}
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
for (size_t i = 0; i < inputs.size(); ++i) {
auto tensor = inputs[i];
MS_EXCEPTION_IF_NULL(tensor);
auto input_node = input_nodes[i];
MS_EXCEPTION_IF_NULL(input_node);
if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
auto pk_node = input_node->cast<ParameterPtr>();
mindspore::parallel::ps::Worker<float>::GetInstance().InitPSParamAndOptim(
pk_node->fullname_with_scope(), tensor->data_c(), LongToSize(tensor->data().nbytes()));
}
}
ps_init_ = true;
}
#endif
} // namespace session
} // namespace mindspore
......@@ -51,7 +51,7 @@ using OpRunInfoPtr = std::shared_ptr<OpRunInfo>;
class SessionBasic {
public:
SessionBasic() : context_(nullptr), summary_callback_(nullptr), device_id_(0) {
SessionBasic() : context_(nullptr), summary_callback_(nullptr), device_id_(0), ps_init_(false) {
#ifdef ENABLE_DEBUGGER
debugger_ = nullptr;
#endif
......@@ -104,6 +104,8 @@ class SessionBasic {
virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; }
virtual void SetActive(GraphId, GraphId) {}
virtual void GetSummaryNodes(KernelGraph *graph);
void AssignParamKey(const KernelGraphPtr &kernel_graph);
void InitPSParamAndOptim(const KernelGraphPtr &kernel_graph, const std::vector<tensor::TensorPtr> &inputs_const);
#ifdef ENABLE_DEBUGGER
// set debugger
......@@ -140,6 +142,7 @@ class SessionBasic {
AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph);
void AddParameterToGraphInputs(const std::vector<AnfNodePtr> &parameters, KernelGraph *graph);
void InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr &parameter);
AnfNodePtr FindPullNode(const AnfNodePtr &push_node, const std::vector<AnfNodePtr> &node_list);
std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_;
std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_;
......@@ -148,6 +151,7 @@ class SessionBasic {
CallBackFunc summary_callback_;
static GraphId graph_sum_;
uint32_t device_id_;
bool ps_init_;
#ifdef ENABLE_DEBUGGER
std::shared_ptr<Debugger> debugger_;
#endif
......
file(GLOB_RECURSE _PARALLEL_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
list(REMOVE_ITEM _PARALLEL_SRC_FILES "ps/util.cc" "ps/scheduler.cc" "ps/optimizer_info.cc" "ps/optimizer_info_builder.cc")
if (${CMAKE_SYSTEM_NAME} MATCHES "Windows" OR ENABLE_GE)
list(REMOVE_ITEM _PARALLEL_SRC_FILES "ps/optimizer_info_builder.cc")
list(REMOVE_ITEM _PARALLEL_SRC_FILES "ps/optimizer_info.cc")
list(REMOVE_ITEM _PARALLEL_SRC_FILES "ps/scheduler.cc")
list(REMOVE_ITEM _PARALLEL_SRC_FILES "ps/util.cc")
endif()
if (ENABLE_DUMP_PROTO)
list(REMOVE_ITEM _PARALLEL_SRC_FILES "parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc")
endif ()
......
......@@ -118,11 +118,13 @@ const AddressPtr &MomentumOptimInfo::gradient() { return inputs_[3]; }
const AddressPtr &MomentumOptimInfo::indices() { return inputs_[3]; }
size_t MomentumOptimInfo::grad_index() { return 1; }
SparseAdamOptimInfo::SparseAdamOptimInfo(const AddressPtr &weight, const AddressPtr &m, const AddressPtr &v,
const AddressPtr &beta1_power, const AddressPtr &beta2_power,
const AddressPtr &learning_rate, const AddressPtr &beta1,
const AddressPtr &beta2, const AddressPtr &epsilon, const AddressPtr &grad,
const AddressPtr &indices, size_t grads_offset, size_t indices_offset) {
const AddressPtr &indices) {
inputs_.push_back(weight);
inputs_.push_back(m);
inputs_.push_back(v);
......@@ -134,8 +136,8 @@ SparseAdamOptimInfo::SparseAdamOptimInfo(const AddressPtr &weight, const Address
inputs_.push_back(epsilon);
inputs_.push_back(grad);
inputs_.push_back(indices);
grads_offset_ = grads_offset;
indices_offset_ = indices_offset;
grads_offset_ = 0;
indices_offset_ = 0;
}
void SparseAdamOptimInfo::Update(const Values &values, const Lengths &lens) {
......@@ -159,15 +161,14 @@ size_t SparseAdamOptimInfo::grad_index() { return 6; }
size_t SparseAdamOptimInfo::indices_index() { return 7; }
SparseFtrlOptimInfo::SparseFtrlOptimInfo(const AddressPtr &weight, const AddressPtr &accum, const AddressPtr &linear,
const AddressPtr &grad, const AddressPtr &indices, size_t grads_offset,
size_t indices_offset) {
const AddressPtr &grad, const AddressPtr &indices) {
inputs_.push_back(weight);
inputs_.push_back(accum);
inputs_.push_back(linear);
inputs_.push_back(grad);
inputs_.push_back(indices);
grads_offset_ = grads_offset;
indices_offset_ = indices_offset;
grads_offset_ = 0;
indices_offset_ = 0;
}
const AddressPtr &SparseFtrlOptimInfo::gradient() { return inputs_[3]; }
......
......@@ -81,6 +81,7 @@ class MomentumOptimInfo : public DenseOptimInfo {
const AddressPtr &gradient();
const AddressPtr &indices();
size_t grad_index() override;
};
class SparseAdamOptimInfo : public SparseOptimInfo {
......@@ -88,7 +89,7 @@ class SparseAdamOptimInfo : public SparseOptimInfo {
SparseAdamOptimInfo(const AddressPtr &weight, const AddressPtr &m, const AddressPtr &v, const AddressPtr &beta1_power,
const AddressPtr &beta2_power, const AddressPtr &learning_rate, const AddressPtr &beta1,
const AddressPtr &beta2, const AddressPtr &epsilon, const AddressPtr &grad,
const AddressPtr &indices, size_t grads_offset, size_t indices_offset);
const AddressPtr &indices);
~SparseAdamOptimInfo() override = default;
void Update(const Values &values, const Lengths &lens) override;
......@@ -102,7 +103,7 @@ class SparseAdamOptimInfo : public SparseOptimInfo {
class SparseFtrlOptimInfo : public SparseOptimInfo {
public:
SparseFtrlOptimInfo(const AddressPtr &weight, const AddressPtr &accum, const AddressPtr &linear,
const AddressPtr &grad, const AddressPtr &indices, size_t grads_offset, size_t indices_offset);
const AddressPtr &grad, const AddressPtr &indices);
~SparseFtrlOptimInfo() override = default;
const AddressPtr &gradient();
......
......@@ -48,20 +48,25 @@ OptimizerInfo *MomentumOptimInfoBuilder::BuildInputs(const WeightPtr &weight, co
size_t worker_num) {
AddressPtr weight_addr = std::make_shared<kernel::Address>();
weight_addr->addr = weight->data();
weight_addr->size = weight->size();
weight_addr->size = weight->size() * sizeof(float);
void *data_ptr = values.data();
void *copy_data_ptr = new float[values.size()];
auto ret = memcpy_s(copy_data_ptr, values.size() * sizeof(float), data_ptr, values.size() * sizeof(float));
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
}
AddressPtr accumulate = std::make_shared<kernel::Address>();
accumulate->addr = new float[weight->size()];
accumulate->size = weight->size();
accumulate->size = weight->size() * sizeof(float);
AddressPtr learning_rate = std::make_shared<kernel::Address>();
learning_rate->addr = data_ptr;
learning_rate->size = lens[0];
learning_rate->addr = copy_data_ptr;
learning_rate->size = lens[0] * sizeof(float);
AddressPtr gradient = std::make_shared<kernel::Address>();
gradient->addr = reinterpret_cast<float *>(learning_rate->addr) + lens[0];
gradient->size = lens[1];
gradient->size = lens[1] * sizeof(float);
AddressPtr momentum = std::make_shared<kernel::Address>();
momentum->addr = reinterpret_cast<float *>(gradient->addr) + lens[1];
momentum->size = lens[2];
momentum->size = lens[2] * sizeof(float);
return new MomentumOptimInfo(weight_addr, accumulate, learning_rate, gradient, momentum);
}
......@@ -131,10 +136,10 @@ OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight,
if (ret3 != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret3 << ")";
}
indices->size = lens[7] * sizeof(float);
indices->size = lens[7] * sizeof(int);
return new SparseAdamOptimInfo(weight_addr, m, v, beta1_power, beta2_power, learning_rate, beta1, beta2, epsilon,
grad, indices, total_grad_size, total_indice_size);
grad, indices);
}
OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values,
......@@ -175,9 +180,9 @@ OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight,
if (ret2 != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret2 << ")";
}
indices->size = lens[1] * sizeof(float);
indices->size = lens[1] * sizeof(int);
return new SparseFtrlOptimInfo(weight_addr, accum, linear, grad, indices, total_grad_size, total_indice_size);
return new SparseFtrlOptimInfo(weight_addr, accum, linear, grad, indices);
}
} // namespace ps
} // namespace parallel
......
......@@ -19,7 +19,7 @@
#include <vector>
#include <memory>
#include "backend/kernel_compiler/kernel.h"
#include "backend/kernel_compiler/ps/pserver_kernel.h"
#include "backend/kernel_compiler/cpu/ps/pserver_kernel.h"
#include "frontend/parallel/ps/optimizer_info.h"
namespace mindspore {
......
......@@ -40,12 +40,12 @@
#include "runtime/device/cpu/kernel_select_cpu.h"
#include "utils/context/ms_context.h"
#include "backend/kernel_compiler/kernel.h"
#include "backend/kernel_compiler/ps/pserver_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "backend/kernel_compiler/ps/sparse_apply_adam_ps_kernel.h"
#include "backend/kernel_compiler/ps/sparse_apply_ftrl_ps_kernel.h"
#include "backend/kernel_compiler/ps/apply_momentum_ps_kernel.h"
#include "backend/kernel_compiler/ps/embedding_look_up_ps_kernel.h"
#include "backend/kernel_compiler/cpu/ps/pserver_kernel.h"
#include "backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h"
#include "backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h"
#include "backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h"
#include "backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h"
namespace mindspore {
namespace parallel {
......@@ -118,7 +118,7 @@ class ParameterServer {
std::shared_ptr<session::KernelGraph> kernel_graph_;
std::shared_ptr<session::SessionBasic> sess_;
std::unordered_map<std::string, std::shared_ptr<PServerKernel>> optimizers_;
std::unordered_map<Key, std::shared_ptr<PServerKernel>> optimizers_;
std::unordered_map<Key, InputsShapePtr> optim_inputs_shape_;
std::unordered_map<Key, std::shared_ptr<OptimizerInfo>> optim_infos_;
std::unordered_map<std::string, std::shared_ptr<OptimizerInfoBuilder>> optim_info_builders_;
......@@ -249,10 +249,10 @@ template <typename T>
void ParameterServer<T>::ServerHandler::HandleEmbeddingLookup(const ::ps::KVMeta &req_meta,
const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
const Key &key = req_data.keys[0];
ps_->DoEmbeddingLookup(key, req_data.vals, res);
for (size_t i = 0; i < req_data.vals.size(); i++) {
res->keys->push_back(req_data.vals[i]);
res->keys.push_back(req_data.vals[i]);
}
ps_->DoEmbeddingLookup(key, req_data.vals, res);
}
template <typename T>
......@@ -288,7 +288,7 @@ void ParameterServer<T>::InitOptimInfoBuilders() {
template <typename T>
void ParameterServer<T>::InitWeightKeyToOptims(const Key &key, const int &optim_id) {
if (weight_key_to_optims_.count(key) > 0 || Util::optimizer_name(key) == "") {
if (weight_key_to_optims_.count(key) > 0 || Util::optimizer_name(optim_id) == "") {
return;
}
weight_key_to_optims_[key] = Util::optimizer_name(optim_id);
......@@ -314,22 +314,22 @@ void ParameterServer<T>::InitOptimInputsShape(const Keys &keys, const Values &va
}
if (weight_key_to_optims_.count(key) > 0) {
const std::string &optim_name = weight_key_to_optims_[key];
if (optimizers_.count(optim_name) == 0 && optim_inputs_shape_.count(key) > 0) {
if (optimizers_.count(key) == 0 && optim_inputs_shape_.count(key) > 0) {
if (optim_name == kSparseAdam) {
std::shared_ptr<PServerKernel> optimizer =
std::make_shared<kernel::ps::SparseApplyAdamPSKernel>(rank_id_, pserver_num_);
optimizer->InitKernel(optim_inputs_shape_[key]);
optimizers_[optim_name] = optimizer;
optimizers_[key] = optimizer;
} else if (optim_name == kApplyMomentum) {
std::shared_ptr<PServerKernel> optimizer =
std::make_shared<kernel::ps::ApplyMomentumPSKernel>(rank_id_, pserver_num_);
optimizer->InitKernel(optim_inputs_shape_[key]);
optimizers_[optim_name] = optimizer;
optimizers_[key] = optimizer;
} else if (optim_name == kSparseFtrl) {
std::shared_ptr<PServerKernel> optimizer =
std::make_shared<kernel::ps::SparseApplyFtrlPSKernel>(rank_id_, pserver_num_);
optimizer->InitKernel(optim_inputs_shape_[key]);
optimizers_[optim_name] = optimizer;
optimizers_[key] = optimizer;
}
}
}
......@@ -382,8 +382,7 @@ void ParameterServer<T>::UpdateWeights() {
std::shared_ptr<PServerKernel> optimizer = nullptr;
if (weight_key_to_optims_.count(key) > 0) {
const std::string &optim_name = weight_key_to_optims_[key];
optimizer = optimizers_[optim_name];
optimizer = optimizers_[key];
}
MS_EXCEPTION_IF_NULL(optimizer);
......@@ -391,8 +390,6 @@ void ParameterServer<T>::UpdateWeights() {
if (optim_info == nullptr) {
continue;
}
const WeightPtr &weight = weights_[key];
optim_info->UpdateWeight(weight);
const std::vector<kernel::AddressPtr> &inputs = optim_info->inputs();
const std::vector<kernel::AddressPtr> &workspaces = optim_info->workspaces();
const std::vector<kernel::AddressPtr> &outputs = optim_info->outputs();
......@@ -416,7 +413,7 @@ void ParameterServer<T>::AccumGrad(const Keys &keys, const Values &values, const
// Create or update the optimizer info
if (optim_info == nullptr) {
const std::shared_ptr<OptimizerInfoBuilder> &builder = optim_info_builders_[weight_key_to_optims_[key]];
std::shared_ptr<kernel::ps::PServerKernel> pserver_kernel = optimizers_[weight_key_to_optims_[key]];
std::shared_ptr<kernel::ps::PServerKernel> pserver_kernel = optimizers_[key];
if (pserver_kernel == nullptr) {
MS_LOG(EXCEPTION) << "no optimizer found for key " << key << " optim name " << weight_key_to_optims_[key];
}
......@@ -427,10 +424,8 @@ void ParameterServer<T>::AccumGrad(const Keys &keys, const Values &values, const
optim_infos_[key] = optim_info;
} else {
optim_info->Update(values, lengths);
optim_info->Accumulate(values, lengths);
}
MS_EXCEPTION_IF_NULL(optim_info);
optim_info->Accumulate(values, lengths);
grads_accum_counter_[key] += 1;
if (grads_accum_counter_[key] == worker_num_) {
......@@ -499,7 +494,7 @@ void ParameterServer<T>::DoEmbeddingLookup(Key key, const LookupIds &lookup_ids,
table_lookup_op->Execute(inputs, workspaces, outputs);
res->vals = *addr;
res->lens.push_back(res.vals.size());
res->lens.push_back(res->vals.size());
}
template <typename T>
......
......@@ -48,7 +48,7 @@ class Worker {
void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count);
void InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vector<size_t> shapes, const std::vector<int> &sizes);
void InitPSParamAndOptim(const std::string &param_name, void *param_data, size_t param_size);
void DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &lookup_ids,
void DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
const ::ps::SArray<int> &lens, ::ps::SArray<T> *lookup_result, int cmd);
private:
......@@ -98,7 +98,8 @@ void Worker<T>::Push(const std::vector<size_t> &keys, std::vector<uintptr_t> add
::ps::SArray<T> total_buffer(total_size, 0);
size_t offset = 0;
for (size_t i = 0; i < sizes.size(); i++) {
memcpy(total_buffer.data() + offset / sizeof(T), addrs[i], sizes[i] * sizeof(T));
memcpy_s(total_buffer.data() + offset / sizeof(T), sizes[i] * sizeof(T), reinterpret_cast<void *>(addrs[i]),
sizes[i] * sizeof(T));
offset += sizes[i] * sizeof(T);
}
kv_worker_->PushData(::ps::SArray<::ps::Key>(keys), total_buffer, ::ps::SArray<int>(sizes));
......@@ -108,13 +109,13 @@ template <typename T>
void Worker<T>::Pull(const size_t key, void *dev_addr, const size_t size) {
::ps::SArray<T> variables(size / sizeof(T), 0);
kv_worker_->Wait(kv_worker_->ZPull({key}, &variables));
memcpy(dev_addr, variables.data(), size);
memcpy_s(dev_addr, size, variables.data(), size);
}
template <typename T>
void Worker<T>::DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &lookup_ids,
void Worker<T>::DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
const ::ps::SArray<int> &lens, ::ps::SArray<T> *lookup_result, int cmd) {
kv_worker_->EmbeddingLookup(keys, lookup_ids, lens, &lookup_result, cmd);
kv_worker_->EmbeddingLookup(keys, lookup_ids, lens, lookup_result, cmd);
}
template <typename T>
......
......@@ -22,6 +22,7 @@
#include <utility>
#include <memory>
#include <vector>
#include <unordered_set>
#include "ps/ps.h"
#include "frontend/parallel/ps/util.h"
......@@ -34,24 +35,23 @@ class WorkerProxy : public ::ps::KVWorker<T> {
using Worker = ::ps::KVWorker<T>;
using Callback = std::function<void()>;
using SlicedKVs = std::vector<std::pair<bool, ::ps::KVPairs<T>>>;
using Slicer =
std::function<void(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &ranges, SlicedKVs *sliced)>;
using Slicer = std::function<void(int ts, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &ranges,
SlicedKVs *sliced)>;
using ::ps::SimpleApp::obj_;
explicit WorkerProxy(int app_id, int customer_id, int lookup_customer_id) : Worker(app_id, customer_id) {
using _1 = std::placeholders::_1;
using _2 = std::placeholders::_2;
using _3 = std::placeholders::_3;
using std::placeholders::_1;
using std::placeholders::_2;
using std::placeholders::_3;
using std::placeholders::_4;
lookup_customer_ = std::unique_ptr<::ps::Customer>(
new ::ps::Customer(app_id, lookup_customer_id, std::bind(&WorkerProxy<T>::ProcessLookupResult, this, _1)));
lookup_slicer_ = std::bind(&WorkerProxy<T>::LookupIdSlicer, this, _1, _2, _3);
init_embedding_slicer_ = std::bind(&WorkerProxy<T>::EmbeddingTableInitSlicer, this, _1, _2, _3);
push_slicer_ = std::bind(&WorkerProxy<T>::PushSlicer, this, _1, _2, _3);
broadcast_slicer_ = std::bind(&WorkerProxy<T>::BroadcastSlicer, this, _1, _2, _3);
lookup_slicer_ = std::bind(&WorkerProxy<T>::LookupIdSlicer, this, _1, _2, _3, _4);
broadcast_slicer_ = std::bind(&WorkerProxy<T>::BroadcastSlicer, this, _1, _2, _3, _4);
}
~WorkerProxy() override = default;
void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count);
void EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &lookup_ids,
void EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
const ::ps::SArray<int> &lens, ::ps::SArray<T> *outs, int cmd = 0, const Callback &cb = nullptr,
int priority = 0);
int InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals,
......@@ -61,15 +61,11 @@ class WorkerProxy : public ::ps::KVWorker<T> {
private:
template <typename C>
int AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &lookup_ids, C *vals, int cmd,
int AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids, C *vals, int cmd,
const Callback &cb);
void LookupIdSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
void LookupIdSlicer(int timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced);
void EmbeddingTableInitSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced);
void PushSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced);
void BroadcastSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
void BroadcastSlicer(int timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced);
void ProcessLookupResult(const ::ps::Message &msg);
void Send(::ps::Customer *customer, int timestamp, bool push, bool pull, int cmd, const ::ps::KVPairs<T> &kvs,
......@@ -80,10 +76,9 @@ class WorkerProxy : public ::ps::KVWorker<T> {
std::unordered_map<int, std::vector<::ps::KVPairs<T>>> lookup_results_;
std::mutex mutex_;
Slicer lookup_slicer_;
Slicer init_embedding_slicer_;
Slicer push_slicer_;
Slicer broadcast_slicer_;
std::unordered_map<int, Callback> lookup_callbacks_;
std::unordered_map<int, int> expected_result_count_;
};
template <typename T>
......@@ -108,17 +103,21 @@ void WorkerProxy<T>::AddEmbeddingTable(const ::ps::Key &key, const size_t &row_c
}
template <typename T>
void WorkerProxy<T>::EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &lookup_ids,
void WorkerProxy<T>::EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
const ::ps::SArray<int> &lens, ::ps::SArray<T> *outs, int cmd, const Callback &cb,
int priority) {
int ts = AddLookupCB(keys, lookup_ids, outs, cmd, cb);
::ps::KVPairs<T> kvs;
kvs.keys = keys;
kvs.vals = lookup_ids;
kvs.lens = lens;
kvs.lens = lookup_ids;
kvs.priority = priority;
Send(lookup_customer_.get(), ts, true, true, cmd, kvs, broadcast_slicer_);
expected_result_count_[ts] = 0;
Send(lookup_customer_.get(), ts, true, true, cmd, kvs, lookup_slicer_);
int server_num = ::ps::NumServers();
int expect_rt_count = expected_result_count_[ts];
lookup_customer_->AddResponse(ts, server_num - expect_rt_count);
lookup_customer_->WaitRequest(ts);
expected_result_count_.erase(ts);
}
template <typename T>
......@@ -130,7 +129,7 @@ int WorkerProxy<T>::InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, cons
kvs.vals = vals;
kvs.lens = lens;
kvs.priority = priority;
Send(obj_, ts, true, false, kInitEmbeddingsCmd, kvs, init_embedding_slicer_);
Send(obj_, ts, true, false, kInitEmbeddingsCmd, kvs, broadcast_slicer_);
return ts;
}
......@@ -143,13 +142,13 @@ void WorkerProxy<T>::PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::S
kvs.vals = vals;
kvs.lens = lens;
kvs.priority = priority;
Send(obj_, ts, true, false, cmd, kvs, push_slicer_);
Send(obj_, ts, true, false, cmd, kvs, broadcast_slicer_);
obj_->WaitRequest(ts);
}
template <typename T>
template <typename C>
int WorkerProxy<T>::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &lookup_ids,
int WorkerProxy<T>::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
C *lookup_result, int cmd, const Callback &cb) {
int ts = lookup_customer_->NewRequest(::ps::kServerGroup);
const auto &callback = [this, ts, keys, lookup_ids, lookup_result, cb]() mutable {
......@@ -158,18 +157,28 @@ int WorkerProxy<T>::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps:
mutex_.unlock();
size_t total_len = 0;
const auto &s = kvs[0];
for (size_t i = 0; i < s.lens.size(); i++) {
total_len += s.lens[i];
std::unordered_map<Key, std::shared_ptr<std::pair<T *, int>>> id_addr_map;
for (const auto &s : kvs) {
int offset = 0;
int len = s.vals.size() / s.keys.size();
for (size_t i = 0; i < s.keys.size(); i++) {
const Key &key = s.keys[i];
T *addr = s.vals.data() + offset;
offset += len;
total_len += len;
id_addr_map[key] = std::make_shared<std::pair<T *, int>>(std::make_pair(addr, len));
}
}
lookup_result->resize(total_len, 0);
T *result_addr = lookup_result->data();
for (const auto &s : kvs) {
size_t offset = 0;
for (size_t i = 0; i < s.vals.size(); i++) {
result_addr[offset++] += s.vals[i];
T *result_addr = lookup_result->data();
int offset = 0;
for (size_t i = 0; i < lookup_ids.size(); i++) {
auto &pair = id_addr_map[static_cast<Key>(lookup_ids[i])];
auto ret = memcpy_s(result_addr + offset, pair->second, pair->first, pair->second);
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
}
offset += pair->second;
}
mutex_.lock();
......@@ -182,31 +191,30 @@ int WorkerProxy<T>::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps:
}
template <typename T>
void WorkerProxy<T>::LookupIdSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
void WorkerProxy<T>::LookupIdSlicer(int timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced) {
int *data = send.lens.data();
size_t size = send.lens.size();
std::vector<int> lookup_ids(data, data + size);
std::sort(lookup_ids.begin(), lookup_ids.end());
int *lookup_ids = send.lens.data();
size_t id_size = send.lens.size();
const Key &key = send.keys[0];
const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]);
sliced->resize(ranges.size());
size_t index = 0;
for (size_t i = 0; i < ranges.size(); i++) {
const ::ps::Range &range = ranges[i];
const auto &begin = range.begin();
const auto &end = range.end();
std::unordered_set<int> unique_ids;
auto &kvs = sliced->at(i).second;
auto lookup_id = static_cast<uint64_t>(lookup_ids[index]);
while (lookup_id >= begin && lookup_id <= end) {
kvs.vals.push_back(lookup_id);
if (++index >= lookup_ids.size()) {
break;
for (size_t j = 0; j < id_size; j++) {
auto lookup_id = static_cast<uint64_t>(lookup_ids[j]);
if (lookup_id >= begin && lookup_id <= end) {
unique_ids.insert(lookup_id);
}
lookup_id = static_cast<uint64_t>(lookup_ids[index]);
}
for (const auto &lookup_id : unique_ids) {
kvs.vals.push_back(lookup_id);
}
kvs.keys.push_back(key);
kvs.lens.push_back(kvs.vals.size());
......@@ -215,35 +223,13 @@ void WorkerProxy<T>::LookupIdSlicer(const ::ps::KVPairs<T> &send, const std::vec
sliced->at(i).first = false;
} else {
sliced->at(i).first = true;
expected_result_count_[timestamp] += 1;
}
}
}
template <typename T>
void WorkerProxy<T>::EmbeddingTableInitSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced) {
const Key &key = send.keys[0];
const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]);
sliced->resize(ranges.size());
for (size_t i = 0; i < ranges.size(); i++) {
sliced->at(i).first = true;
sliced->at(i).second = send;
}
}
template <typename T>
void WorkerProxy<T>::PushSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced) {
auto server_num = ::ps::Postoffice::Get()->num_servers();
sliced->resize(server_num);
for (int i = 0; i < server_num; i++) {
sliced->at(i).first = true;
sliced->at(i).second = send;
}
}
template <typename T>
void WorkerProxy<T>::BroadcastSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
void WorkerProxy<T>::BroadcastSlicer(int timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced) {
auto server_num = ::ps::Postoffice::Get()->num_servers();
sliced->resize(server_num);
......@@ -268,7 +254,7 @@ void WorkerProxy<T>::ProcessLookupResult(const ::ps::Message &msg) {
lookup_results_[ts].push_back(kvs);
mutex_.unlock();
}
if (lookup_customer_->NumResponse(ts) == ::ps::Postoffice::Get()->num_servers() - 1) {
if (lookup_customer_->NumResponse(ts) == expected_result_count_[ts] - 1) {
const auto &cb = lookup_callbacks_[ts];
cb();
lookup_callbacks_.erase(ts);
......@@ -279,7 +265,7 @@ template <typename T>
void WorkerProxy<T>::Send(::ps::Customer *customer, int timestamp, bool push, bool pull, int cmd,
const ::ps::KVPairs<T> &kvs, const Slicer &slicer) {
SlicedKVs sliced;
slicer(kvs, ::ps::Postoffice::Get()->GetServerKeyRanges(), &sliced);
slicer(timestamp, kvs, ::ps::Postoffice::Get()->GetServerKeyRanges(), &sliced);
for (size_t i = 0; i < sliced.size(); i++) {
const auto &s = sliced[i];
......
......@@ -146,6 +146,12 @@ if (${CMAKE_SYSTEM_NAME} MATCHES "Windows")
target_link_libraries(_c_dataengine PRIVATE _c_mindrecord ${MINDRECORD_LINK_OBJECT} mindspore::sqlite)
else()
target_link_libraries(_c_dataengine PRIVATE _c_mindrecord)
if (NOT ENABLE_GE)
target_link_libraries(_c_dataengine PRIVATE mindspore::pslite mindspore::protobuf ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a)
if (${ENABLE_IBVERBS} STREQUAL "ON")
target_link_libraries(_c_dataengine PRIVATE ibverbs rdmacm)
endif()
endif()
endif()
if (USE_GLOG)
......
......@@ -40,6 +40,11 @@
#include "vm/transform.h"
#include "parse/python_adapter.h"
#include "frontend/optimizer/py_pass_manager.h"
#if (!_WIN32 && !ENABLE_GE && !ENABLE_TESTCASES)
#include "frontend/parallel/ps/parameter_server.h"
#include "frontend/parallel/ps/scheduler.h"
#include "frontend/parallel/ps/worker.h"
#endif
namespace mindspore {
namespace pipeline {
......@@ -374,6 +379,25 @@ bool ExecuteAction(const ResourcePtr &res) {
return true;
}
#if (!_WIN32 && !ENABLE_GE && !ENABLE_TESTCASES)
bool StartPSWorkerAction(const ResourcePtr &res) {
parallel::ps::Worker<float>::GetInstance().Run();
return true;
}
bool StartPSServerAction(const ResourcePtr &res) {
FuncGraphPtr func_graph = res->func_graph();
auto &ps = parallel::ps::ParameterServer<float>::GetInstance();
ps.Run(func_graph);
return true;
}
bool StartPSSchedulerAction(const ResourcePtr &res) {
parallel::ps::Scheduler::GetInstance().Run();
return true;
}
#endif
// The parallel primitive related valuenode might be partitioned so that its value changes by device,
// that will result in a syncronization error due to different executing order.
// Here we temporarily avoid the problem by skipping valuenode merging used by parallel related primitive,
......@@ -481,7 +505,11 @@ std::vector<ActionItem> VmPipeline() {
actions.emplace_back(std::make_pair("py_opt", OptActionPyStub));
actions.emplace_back(std::make_pair("validate", ValidateAction));
#if (!_WIN32 && !ENABLE_GE && !ENABLE_TESTCASES)
if (parallel::ps::Util::IsRoleOfWorker()) {
actions.emplace_back(std::make_pair("worker", StartPSWorkerAction));
}
#endif
// compile the ANF graph
actions.emplace_back(std::make_pair("task_emit", TaskEmitAction));
......@@ -490,5 +518,21 @@ std::vector<ActionItem> VmPipeline() {
return actions;
}
#if (!_WIN32 && !ENABLE_GE && !ENABLE_TESTCASES)
std::vector<ActionItem> PServerPipeline() {
auto actions = CommonPipeline();
actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));
actions.emplace_back(std::make_pair("validate", ValidateAction));
actions.emplace_back(std::make_pair("pserver", StartPSServerAction));
return actions;
}
std::vector<ActionItem> PSchedulerPipeline() {
std::vector<ActionItem> actions;
actions.emplace_back(std::make_pair("scheduler", StartPSSchedulerAction));
return actions;
}
#endif
} // namespace pipeline
} // namespace mindspore
......@@ -38,9 +38,14 @@ bool VmOptimizeAction(const ResourcePtr &res);
bool PynativeOptimizeAction(const ResourcePtr &res);
bool TaskEmitAction(const ResourcePtr &res);
bool ExecuteAction(const ResourcePtr &res);
bool StartPSWorkerAction(const ResourcePtr &res);
bool StartPSServerAction(const ResourcePtr &res);
bool StartPSSchedulerAction(const ResourcePtr &res);
std::vector<ActionItem> GePipeline();
std::vector<ActionItem> VmPipeline();
std::vector<ActionItem> PServerPipeline();
std::vector<ActionItem> PSchedulerPipeline();
abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph,
const abstract::AbstractBasePtrList &args_spec, bool clear = false);
FuncGraphPtr ProgramSpecialize(const ResourcePtr &res, const FuncGraphPtr &func_graph,
......
......@@ -41,6 +41,11 @@
#include "pipeline/pynative/pynative_execute.h"
#include "frontend/optimizer/py_pass_manager.h"
#if (!_WIN32 && !ENABLE_GE && !ENABLE_TESTCASES)
#include "frontend/parallel/ps/common.h"
#include "frontend/parallel/ps/util.h"
#endif
#if (ENABLE_GE || ENABLE_D)
#include "pipeline/jit/pipeline_ge.h"
#include "transform/graph_ir/convert.h"
......@@ -420,6 +425,26 @@ bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, cons
use_vm = ChangeExportGeirUseVmFlag(use_vm, phase_s);
std::string backend = MsContext::GetInstance()->backend_policy();
#if (!_WIN32 && !ENABLE_GE && !ENABLE_TESTCASES)
if (mindspore::parallel::ps::Util::IsParamServerMode()) {
mindspore::parallel::ps::Util::SetInternalEnvVar();
}
if (parallel::ps::Util::IsRoleOfPServer()) {
resource->results()[kBackend] = compile::CreateBackend();
p_actions = PServerPipeline();
} else if (parallel::ps::Util::IsRoleOfScheduler()) {
p_actions = PSchedulerPipeline();
} else if (use_vm && backend != "ge") {
// Create backend and session
auto backend_ptr = compile::CreateBackend();
// Connect session to debugger
backend_ptr->SetDebugger();
resource->results()[kBackend] = backend_ptr;
p_actions = VmPipeline();
} else {
p_actions = GePipeline();
}
#else
if (use_vm && backend != "ge") {
// Create backend and session
auto backend_ptr = compile::CreateBackend();
......@@ -430,6 +455,7 @@ bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, cons
} else {
p_actions = GePipeline();
}
#endif
std::shared_ptr<Pipeline> pip = std::make_shared<Pipeline>(resource, FilterActions(p_actions, phase_s));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册