提交 61551b85 编写于 作者: Z ZPaC

incremental feature for ps

上级 1625a27a
......@@ -62,6 +62,8 @@ void SparseApplyAdamPSKernel::InitKernel(
*/
workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
workspace_size_list_.emplace_back(var_first_dim_size_ * var_outer_dim_size_ * sizeof(float));
}
......
......@@ -52,6 +52,8 @@ void SparseApplyFtrlPSKernel::InitKernel(
lr_power_ = -0.5;
workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
}
void SparseApplyFtrlPSKernel::ReInit(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
......
......@@ -72,13 +72,10 @@ using Values = ::ps::SArray<float>;
using ValuesPtr = std::shared_ptr<Values>;
using Weight = ::ps::SArray<float>;
using Grad = ::ps::SArray<float>;
using LookupIds = ::ps::SArray<float>;
using LookupIds = ::ps::SArray<Key>;
using Lengths = ::ps::SArray<int>;
using WeightPtr = std::shared_ptr<Weight>;
using GradPtr = std::shared_ptr<Grad>;
// using EmbeddingTable = std::unordered_map<int, WeightPtr>;
// using EmbeddingTable = ::ps::SArray<float>;
// using EmbeddingTablePtr = std::shared_ptr<EmbeddingTable>;
using InputsShape = std::vector<std::shared_ptr<std::vector<size_t>>>;
using InputsShapePtr = std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>>;
} // namespace ps
......
......@@ -57,6 +57,8 @@ void DenseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) {
}
}
void DenseOptimInfo::Reset() { memset_s(gradient()->addr, gradient()->size, 0x00, gradient()->size); }
void SparseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) {
// Append grad data to the end
float *accum_grad_data = reinterpret_cast<float *>(gradient()->addr);
......
......@@ -58,6 +58,7 @@ class DenseOptimInfo : public OptimizerInfo {
~DenseOptimInfo() override = default;
void Accumulate(const Values &values, const Lengths &lens) override;
void Reset() override;
};
class SparseOptimInfo : public OptimizerInfo {
......
......@@ -58,6 +58,7 @@ OptimizerInfo *MomentumOptimInfoBuilder::BuildInputs(const WeightPtr &weight, co
AddressPtr accumulate = std::make_shared<kernel::Address>();
accumulate->addr = new float[weight->size()];
accumulate->size = weight->size() * sizeof(float);
memset_s(accumulate->addr, accumulate->size, 0x00, accumulate->size);
AddressPtr learning_rate = std::make_shared<kernel::Address>();
learning_rate->addr = copy_data_ptr;
learning_rate->size = lens[0] * sizeof(float);
......
......@@ -30,7 +30,6 @@
#include <random>
#include "ir/func_graph.h"
#include "backend/session/session_basic.h"
#include "backend/session/kernel_graph.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/session_factory.h"
#include "frontend/parallel/ps/common.h"
......@@ -70,24 +69,32 @@ class ParameterServer {
ps_(new ::ps::KVServer<T>(0)),
handler_(nullptr),
func_graph_(nullptr),
kernel_graph_(nullptr),
sess_(nullptr),
thread_(nullptr) {}
~ParameterServer() = default;
ParameterServer(const ParameterServer &) = delete;
ParameterServer &operator=(const ParameterServer &) = delete;
struct ServerHandler {
class ServerHandler {
public:
explicit ServerHandler(ParameterServer *ps) : ps_(ps) {}
void Init();
void operator()(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVServer<T> *server);
void HandlePushReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data);
private:
void HandlePushReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
void HandlePullReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
void HandleInitWeights(const ::ps::KVPairs<T> &req_data);
void HandleInitWeightToOptimId(const ::ps::KVPairs<T> &req_data);
void HandleInitInputsShape(const ::ps::KVPairs<T> &req_data);
void HandleInitEmbeddings(const ::ps::KVPairs<T> &req_data);
void HandleInitWeights(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
void HandleInitWeightToOptimId(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data,
::ps::KVPairs<T> *res);
void HandleInitInputsShape(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
void HandleInitEmbeddings(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
void HandleEmbeddingLookup(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
ParameterServer *ps_;
typedef void (ServerHandler::*RequestHandler)(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data,
::ps::KVPairs<T> *res);
std::unordered_map<int, RequestHandler> handlers_;
};
bool Init(const FuncGraphPtr &func_graph);
......@@ -103,7 +110,6 @@ class ParameterServer {
WeightPtr weight(const Key &key);
void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, ::ps::KVPairs<T> *res);
int SumOfShapes(const std::vector<int> &shapes) const;
size_t PreComputeCapacity(const Keys &keys, const Lengths &lens);
bool ReadyForUpdateWeights();
bool ReadyForAccumGrads();
void ResetGradAccumCount();
......@@ -115,7 +121,6 @@ class ParameterServer {
std::unique_ptr<::ps::KVServer<T>> ps_;
std::unique_ptr<ServerHandler> handler_;
FuncGraphPtr func_graph_;
std::shared_ptr<session::KernelGraph> kernel_graph_;
std::shared_ptr<session::SessionBasic> sess_;
std::unordered_map<Key, std::shared_ptr<PServerKernel>> optimizers_;
......@@ -126,12 +131,7 @@ class ParameterServer {
std::unordered_map<Key, WeightPtr> weights_;
std::unordered_map<Key, WeightPtr> grads_;
std::unordered_map<Key, size_t> grads_accum_counter_;
// std::unordered_map<Key, EmbeddingTablePtr> embeddings_;
std::unordered_map<Key, std::shared_ptr<PServerKernel>> embedding_lookup_ops_;
std::unordered_map<Key, size_t> embedding_row_lens_;
T learning_rate_;
T momentum_;
std::mutex mutex_;
std::condition_variable apply_grads_cv_;
......@@ -139,7 +139,7 @@ class ParameterServer {
std::unique_ptr<std::thread> thread_;
friend struct ServerHandler;
friend class ServerHandler;
};
class FuncGraph;
......@@ -147,33 +147,29 @@ template <typename T>
void ParameterServer<T>::ServerHandler::operator()(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data,
::ps::KVServer<T> *server) {
::ps::KVPairs<T> res;
if (req_meta.cmd == kInitWeightsCmd) {
MS_LOG(ERROR) << "handle init weights cmd" << std::endl;
HandleInitWeights(req_data);
} else if (req_meta.cmd == kInitWeightToOptimIdCmd) {
MS_LOG(ERROR) << "handle init weight optim id mapping cmd" << std::endl;
HandleInitWeightToOptimId(req_data);
} else if (req_meta.cmd == kInitOptimInputsShapeCmd) {
MS_LOG(ERROR) << "handle init inputs shape cmd" << std::endl;
HandleInitInputsShape(req_data);
} else if (req_meta.cmd == kInitEmbeddingsCmd) {
MS_LOG(ERROR) << "handle init embedding cmd" << std::endl;
HandleInitEmbeddings(req_data);
} else if (req_meta.cmd == kEmbeddingLookupCmd) {
MS_LOG(ERROR) << "handle embedding lookup cmd" << std::endl;
HandleEmbeddingLookup(req_meta, req_data, &res);
if (handlers_.count(req_meta.cmd) > 0) {
auto &handler_ptr = handlers_[req_meta.cmd];
(this->*handler_ptr)(req_meta, req_data, &res);
} else if (req_meta.push) {
MS_LOG(ERROR) << "handle push req cmd" << std::endl;
HandlePushReq(req_meta, req_data);
HandlePushReq(req_meta, req_data, &res);
} else {
MS_LOG(ERROR) << "handle pull req cmd" << std::endl;
HandlePullReq(req_meta, req_data, &res);
}
server->Response(req_meta, res);
}
template <typename T>
void ParameterServer<T>::ServerHandler::HandlePushReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data) {
void ParameterServer<T>::ServerHandler::Init() {
handlers_[kInitWeightsCmd] = &ServerHandler::HandleInitWeights;
handlers_[kInitWeightToOptimIdCmd] = &ServerHandler::HandleInitWeightToOptimId;
handlers_[kInitOptimInputsShapeCmd] = &ServerHandler::HandleInitInputsShape;
handlers_[kInitEmbeddingsCmd] = &ServerHandler::HandleInitEmbeddings;
handlers_[kEmbeddingLookupCmd] = &ServerHandler::HandleEmbeddingLookup;
}
template <typename T>
void ParameterServer<T>::ServerHandler::HandlePushReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data,
::ps::KVPairs<T> *res) {
ps_->AccumGrad(req_data.keys, req_data.vals, req_data.lens);
}
......@@ -186,7 +182,8 @@ void ParameterServer<T>::ServerHandler::HandlePullReq(const ::ps::KVMeta &req_me
}
template <typename T>
void ParameterServer<T>::ServerHandler::HandleInitWeights(const ::ps::KVPairs<T> &req_data) {
void ParameterServer<T>::ServerHandler::HandleInitWeights(const ::ps::KVMeta &req_meta,
const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
size_t key_num = req_data.keys.size();
T *data_ptr = req_data.vals.data();
size_t pos = 0;
......@@ -205,7 +202,9 @@ void ParameterServer<T>::ServerHandler::HandleInitWeights(const ::ps::KVPairs<T>
}
template <typename T>
void ParameterServer<T>::ServerHandler::HandleInitWeightToOptimId(const ::ps::KVPairs<T> &req_data) {
void ParameterServer<T>::ServerHandler::HandleInitWeightToOptimId(const ::ps::KVMeta &req_meta,
const ::ps::KVPairs<T> &req_data,
::ps::KVPairs<T> *res) {
size_t key_num = req_data.keys.size();
for (size_t i = 0; i < key_num; i++) {
Key key = req_data.keys[i];
......@@ -215,12 +214,14 @@ void ParameterServer<T>::ServerHandler::HandleInitWeightToOptimId(const ::ps::KV
}
template <typename T>
void ParameterServer<T>::ServerHandler::HandleInitInputsShape(const ::ps::KVPairs<T> &req_data) {
void ParameterServer<T>::ServerHandler::HandleInitInputsShape(const ::ps::KVMeta &req_meta,
const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
ps_->InitOptimInputsShape(req_data.keys, req_data.vals, req_data.lens);
}
template <typename T>
void ParameterServer<T>::ServerHandler::HandleInitEmbeddings(const ::ps::KVPairs<T> &req_data) {
void ParameterServer<T>::ServerHandler::HandleInitEmbeddings(const ::ps::KVMeta &req_meta,
const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> shapes =
std::make_shared<std::vector<std::shared_ptr<std::vector<size_t>>>>();
std::shared_ptr<std::vector<size_t>> input_shape = std::make_shared<std::vector<size_t>>();
......@@ -249,10 +250,10 @@ template <typename T>
void ParameterServer<T>::ServerHandler::HandleEmbeddingLookup(const ::ps::KVMeta &req_meta,
const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
const Key &key = req_data.keys[0];
for (size_t i = 0; i < req_data.vals.size(); i++) {
res->keys.push_back(req_data.vals[i]);
for (size_t i = 0; i < req_data.keys.size(); i++) {
res->keys.push_back(req_data.keys[i]);
}
ps_->DoEmbeddingLookup(key, req_data.vals, res);
ps_->DoEmbeddingLookup(key, req_data.keys.segment(1, req_data.keys.size()), res);
}
template <typename T>
......@@ -268,6 +269,7 @@ bool ParameterServer<T>::Init(const FuncGraphPtr &func_graph) {
func_graph_ = func_graph;
rank_id_ = ::ps::MyRank();
handler_.reset(new ServerHandler(this));
handler_->Init();
InitOptimInfoBuilders();
......@@ -364,7 +366,13 @@ void ParameterServer<T>::InitEmbeddingTable(
for (auto shape : input_shapes) {
total_dims *= shape;
}
WeightPtr embedding = std::make_shared<Weight>(total_dims, 0.01);
WeightPtr embedding = std::make_shared<Weight>(total_dims, 0);
std::default_random_engine engine;
std::normal_distribution<float> random(0, 0.01);
for (size_t i = 0; i < total_dims; i++) {
(*embedding)[i] = random(engine);
}
weights_[key] = embedding;
grads_accum_counter_[key] = 0;
......@@ -480,8 +488,13 @@ void ParameterServer<T>::DoEmbeddingLookup(Key key, const LookupIds &lookup_ids,
inputs.push_back(indices);
embedding_table->addr = table_ptr->data();
embedding_table->size = table_ptr->size() * sizeof(T);
indices->addr = lookup_ids.data();
indices->size = lookup_ids.size() * sizeof(T);
std::unique_ptr<int[]> tmp_ids(new int[lookup_ids.size()]);
for (size_t i = 0; i < lookup_ids.size(); i++) {
tmp_ids[i] = static_cast<int>(lookup_ids[i]);
}
indices->addr = tmp_ids.get();
indices->size = lookup_ids.size() * sizeof(int);
std::vector<kernel::AddressPtr> workspaces;
std::vector<kernel::AddressPtr> outputs;
......@@ -506,20 +519,6 @@ int ParameterServer<T>::SumOfShapes(const std::vector<int> &shapes) const {
return sum;
}
template <typename T>
size_t ParameterServer<T>::PreComputeCapacity(const Keys &keys, const Lengths &lens) {
size_t capacity = 0;
for (size_t i = 0; i < keys.size(); i++) {
Key key = keys[i];
if (embedding_row_lens_.count(key) > 0) {
capacity += embedding_row_lens_[key] * lens[i];
} else {
MS_LOG(ERROR) << "Invalid embedding lookup id " << key;
}
}
return capacity;
}
template <typename T>
inline bool ParameterServer<T>::ReadyForUpdateWeights() {
return grads_accum_counter_.size() > 0 && grad_accum_count_ == grads_accum_counter_.size();
......
......@@ -155,9 +155,9 @@ void Worker<T>::InitPSOptimInputShapes(const size_t key) {
}
}
}
MS_LOG(ERROR) << "keys:" << keys;
MS_LOG(ERROR) << "shape_len:" << shape_len;
MS_LOG(ERROR) << "all_shape:" << all_shape;
MS_LOG(INFO) << "keys:" << keys;
MS_LOG(INFO) << "shape_len:" << shape_len;
MS_LOG(INFO) << "all_shape:" << all_shape;
if (!init_keys_[key]) {
init_keys_[key] = true;
}
......@@ -191,7 +191,7 @@ size_t Worker<T>::GetParamKey(const std::string &param_name) {
size_t key = kInvalidKey;
if (param_to_key_.find(param_name) != param_to_key_.end()) {
key = param_to_key_[param_name];
MS_LOG(ERROR) << "Get key of parameter " << param_name << " key is " << key;
MS_LOG(INFO) << "Get key of parameter " << param_name << " key is " << key;
}
return key;
}
......@@ -251,6 +251,10 @@ void Worker<T>::InitPSParamAndOptim(const std::string &param_name, void *param_d
template <typename T>
void Worker<T>::AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count) {
bool has_init = IsKeyInit(key);
if (has_init) {
return;
}
kv_worker_->AddEmbeddingTable(key, row_count);
}
......
......@@ -156,30 +156,8 @@ int WorkerProxy<T>::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps:
auto &kvs = lookup_results_[ts];
mutex_.unlock();
size_t total_len = 0;
std::unordered_map<Key, std::shared_ptr<std::pair<T *, int>>> id_addr_map;
for (const auto &s : kvs) {
int offset = 0;
int len = s.vals.size() / s.keys.size();
for (size_t i = 0; i < s.keys.size(); i++) {
const Key &key = s.keys[i];
T *addr = s.vals.data() + offset;
offset += len;
total_len += len;
id_addr_map[key] = std::make_shared<std::pair<T *, int>>(std::make_pair(addr, len));
}
}
T *result_addr = lookup_result->data();
int offset = 0;
for (size_t i = 0; i < lookup_ids.size(); i++) {
auto &pair = id_addr_map[static_cast<Key>(lookup_ids[i])];
auto ret = memcpy_s(result_addr + offset, pair->second, pair->first, pair->second);
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
}
offset += pair->second;
}
auto &s = kvs[0];
*lookup_result = s.vals;
mutex_.lock();
lookup_results_.erase(ts);
......@@ -201,25 +179,16 @@ void WorkerProxy<T>::LookupIdSlicer(int timestamp, const ::ps::KVPairs<T> &send,
sliced->resize(ranges.size());
for (size_t i = 0; i < ranges.size(); i++) {
const ::ps::Range &range = ranges[i];
const auto &begin = range.begin();
const auto &end = range.end();
std::unordered_set<int> unique_ids;
auto &kvs = sliced->at(i).second;
kvs.keys.push_back(key);
kvs.vals.push_back(0.0f);
for (size_t j = 0; j < id_size; j++) {
auto lookup_id = static_cast<uint64_t>(lookup_ids[j]);
if (lookup_id >= begin && lookup_id <= end) {
unique_ids.insert(lookup_id);
}
kvs.keys.push_back(lookup_ids[j]);
kvs.vals.push_back(0.0f);
}
for (const auto &lookup_id : unique_ids) {
kvs.vals.push_back(lookup_id);
}
kvs.keys.push_back(key);
kvs.lens.push_back(kvs.vals.size());
if (kvs.vals.size() == 0) {
if (kvs.keys.size() <= 1) {
sliced->at(i).first = false;
} else {
sliced->at(i).first = true;
......
......@@ -318,7 +318,6 @@ const std::set<std::string> kOptOperatorSet = {
kApplyProximalAdagradOpName,
kApplyProximalGradientDescentOpName,
kApplyRMSPropOpName,
kPushOpName,
kPullOpName,
};
......
......@@ -61,6 +61,7 @@ class Parameter:
self._is_init = False
self._sliced = False
self.is_param_ps = False
self.init_in_server = False
if context.get_context("mode") == context.PYNATIVE_MODE:
self.init_data()
......@@ -71,8 +72,9 @@ class Parameter:
def __parameter__(self):
"""For parse check."""
def set_param_ps(self):
def set_param_ps(self, init_in_server=False):
self.is_param_ps = True
self.init_in_server = init_in_server
@property
def name(self):
......@@ -251,9 +253,15 @@ class Parameter:
raise ValueError("The length of layout must be larger than 3! layout is {}."
.format(layout))
slice_index = int(_get_slice_index(layout[0], layout[1]))
self.default_input = self.init_mode.to_tensor(slice_index, layout[2])
if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, Initializer)):
self.default_input = self.init_mode.to_tensor(0, [1])
else:
self.default_input = self.init_mode.to_tensor(slice_index, layout[2])
else:
self.default_input = self.init_mode.to_tensor()
if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, Initializer)):
self.default_input = self.init_mode.to_tensor(0, [1])
else:
self.default_input = self.init_mode.to_tensor()
self.init_mode = None
if set_sliced:
......
......@@ -113,6 +113,8 @@ def check_parameter_available(func):
Wrapper. If not available, raise Error.
"""
def wrapper(*args, **kargs):
if MS_ROLE in ("MS_PSERVER", "MS_SCHED"):
return func(*args, **kargs)
group = None
if "group" in kargs.keys():
group = kargs.get("group")
......
......@@ -831,7 +831,7 @@ class Cell:
self._backward_hook = HookBackward(fn, self.cls_name + "(" + str(id(self)) + ")")
self.enable_hook = True
def set_param_ps(self, recurse=True):
def set_param_ps(self, recurse=True, init_in_server=False):
"""
Set whether the trainable parameter is updated by parameter server.
......@@ -843,7 +843,7 @@ class Cell:
"""
params = self.trainable_params(recurse)
for param in params:
param.set_param_ps()
param.set_param_ps(init_in_server)
class GraphKernel(Cell):
"""
......
......@@ -85,7 +85,7 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d
return gradient
@_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "IndexedSlices",
@_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "IndexedSlices",
"Tensor", "Tensor", "Tensor", "Bool")
def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
moment1, moment2, ps_parameter):
......@@ -108,7 +108,7 @@ def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2
return success
@_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor",
@_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Tensor", "Tensor", "Bool")
def _run_opt_with_one_number(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
moment1, moment2, ps_parameter):
......@@ -276,7 +276,7 @@ class Adam(Optimizer):
self.beta2 = Tensor(beta2, mstype.float32)
self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power")
self.beta2_power = Parameter(initializer(1, [1], mstype.float32), name="beta2_power")
self.eps = eps
self.eps = Tensor(eps, mstype.float32)
self.moment1 = self.parameters.clone(prefix="moment1", init='zeros')
self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')
......
......@@ -32,7 +32,7 @@ def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment,
if ps_parameter:
op_shape = P.Shape()
_ps_pull = P.Pull()
_ps_push = P.Push("Momentum", [])
_ps_push = P.Push("ApplyMomentum", [])
shapes = (op_shape(learning_rate), op_shape(gradient), op_shape(momentum))
success = F.depend(success, _ps_pull(_ps_push((learning_rate, gradient, momentum), shapes), weight))
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册