From 7324639e551a06ddcc1fe6ee8729ee4ab98ec283 Mon Sep 17 00:00:00 2001 From: ZPaC Date: Sat, 8 Aug 2020 14:51:43 +0800 Subject: [PATCH] Fix ps traing precision error. --- mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc | 9 +++++++++ mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h | 1 + 2 files changed, 10 insertions(+) diff --git a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc index 5f25b79c2..5801b241e 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc @@ -126,6 +126,15 @@ MomentumOptimInfo::MomentumOptimInfo(const AddressPtr &weight, const AddressPtr inputs_.push_back(momentum); } +void MomentumOptimInfo::Update(const Values &values, const Lengths &lens) { + size_t lr_offset = 0; + float *lr = values.data() + lr_offset; + auto ret = memcpy_s(inputs_[2]->addr, sizeof(float), lr, sizeof(float)); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + } +} + const AddressPtr &MomentumOptimInfo::gradient() { return inputs_[3]; } const AddressPtr &MomentumOptimInfo::indices() { return inputs_[3]; } diff --git a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h index dc567e023..f59d8ad6c 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h +++ b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h @@ -82,6 +82,7 @@ class MomentumOptimInfo : public DenseOptimInfo { const AddressPtr &gradient, const AddressPtr &momentum); ~MomentumOptimInfo() override = default; + void Update(const Values &values, const Lengths &lens) override; const AddressPtr &gradient(); const AddressPtr &indices(); size_t grad_index() override; -- GitLab