Commit 82103508 authored by: Q qiaolongfei

add getParametersRemote for ParameterUpdater in api

Parent 64bfd814
@@ -859,6 +859,13 @@ public:
    */
   void update(Parameter* param);
 
+  /**
+   * @brief only get required sparse rows by default.
+   * @param fullSize: get full matrix parameter if *fullSize* set
+   * @param apply: get PARAMETER_APPLY on pserver if *apply* set
+   */
+  void getParametersRemote(bool fullSize = false, bool apply = false);
+
   /**
    * @brief restore the average parameter.
    * @note It is only used in AverageOptimizer. Restore will get the current
...
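For orientation, here is a minimal sketch of how the two optional flags documented above could be exercised from the Python side once a remote updater has been created through the optimizer. The individual call sites and the "when to use" comments are illustrative assumptions, not part of this commit; trainer.py below only ever issues the bare getParametersRemote() call.

    # Hypothetical illustration of the new API's flags (Python side, via the swig binding).
    # Default: fetch only the sparse parameter rows required for the current batch.
    parameter_updater.getParametersRemote()
    # fullSize=True: fetch the full matrix parameter instead of just the required rows,
    # e.g. when the complete parameter value is needed on the trainer.
    parameter_updater.getParametersRemote(True, False)
    # apply=True: fetch PARAMETER_APPLY from the pserver (the applied value kept by
    # optimizers such as AverageOptimizer).
    parameter_updater.getParametersRemote(False, True)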
@@ -72,6 +72,10 @@ void ParameterUpdater::update(Parameter *param) {
   m->updater->update(paddleParam);
 }
 
+void ParameterUpdater::getParametersRemote(bool fullSize, bool apply) {
+  m->updater->getParametersRemote(fullSize, apply);
+}
+
 void ParameterUpdater::restore() { m->updater->restore(); }
 
 void ParameterUpdater::apply() { m->updater->apply(); }
...
@@ -78,10 +78,12 @@ class Topology(object):
         check if any parameter require to use sparse_update
         :return:
         """
+        use_sparse = False
         for parameter in self.__model_config__.parameters:
             if parameter.sparse_update or parameter.sparse_remote_update:
-                return True
-        return False
+                use_sparse = True
+                break
+        return use_sparse
 
     def proto(self):
         return self.__model_config__
...
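As a side note, the rewritten loop is behaviorally identical to the original early-return version; an equivalent formulation using any(), shown purely for illustration and not part of this commit, would be:

    # Equivalent check sketched with any(); illustration only.
    use_sparse = any(parameter.sparse_update or parameter.sparse_remote_update
                     for parameter in self.__model_config__.parameters)
    return use_sparse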
@@ -65,7 +65,6 @@ class SGD(object):
         self.__use_sparse_updater__ = self.__topology__.use_sparse_updater()
         # In local mode, disable sparse_remote_update.
         if is_local:
-            self.__use_sparse_updater__ = False
             for param in self.__topology_in_proto__.parameters:
                 if param.sparse_remote_update:
                     param.sparse_remote_update = False
...
@@ -100,11 +99,11 @@ class SGD(object):
         __check_train_args__(**locals())
 
         if self.__is_local__:
-            updater = self.__optimizer__.create_local_updater()
+            parameter_updater = self.__optimizer__.create_local_updater()
         else:
-            updater = self.__optimizer__.create_remote_updater(
+            parameter_updater = self.__optimizer__.create_remote_updater(
                 num_passes, self.__use_sparse_updater__)
-        updater.init(self.__gradient_machine__)
+        parameter_updater.init(self.__gradient_machine__)
 
         self.__gradient_machine__.start()
         batch_evaluator = self.__gradient_machine__.makeEvaluator()
...
@@ -116,26 +115,26 @@ class SGD(object):
         for pass_id in xrange(num_passes):
             event_handler(v2_event.BeginPass(pass_id))
             pass_evaluator.start()
-            updater.startPass()
+            parameter_updater.startPass()
             for batch_id, data_batch in enumerate(reader()):
                 batch_evaluator.start()
                 event_handler(
                     v2_event.BeginIteration(
                         pass_id=pass_id, batch_id=batch_id))
-                pass_type = updater.startBatch(len(data_batch))
-                if self.__use_sparse_updater__:
+                pass_type = parameter_updater.startBatch(len(data_batch))
+                if self.__use_sparse_updater__ and not self.__is_local__:
                     self.__gradient_machine__.prefetch(feeder(data_batch))
-                    updater.getParametersRemote()
+                    parameter_updater.getParametersRemote()
                 self.__gradient_machine__.forwardBackward(
                     feeder(data_batch), out_args, pass_type)
                 self.__gradient_machine__.eval(pass_evaluator)
                 self.__gradient_machine__.eval(batch_evaluator)
 
                 for each_param in self.__gradient_machine__.getNonStaticParameters(
                 ):
-                    updater.update(each_param)
+                    parameter_updater.update(each_param)
                 cost_sum = out_args.sum()
                 cost = cost_sum / len(data_batch)
-                updater.finishBatch(cost)
+                parameter_updater.finishBatch(cost)
                 batch_evaluator.finish()
                 event_handler(
                     v2_event.EndIteration(
@@ -144,7 +143,7 @@ class SGD(object):
                         cost=cost,
                         evaluator=batch_evaluator))
-            updater.finishPass()
+            parameter_updater.finishPass()
             pass_evaluator.finish()
             event_handler(v2_event.EndPass(pass_id, evaluator=pass_evaluator))
         self.__gradient_machine__.finish()
...
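To summarize the trainer change: local mode no longer clears __use_sparse_updater__; instead the remote fetch itself is skipped when training locally, so sparse updates can still be used by the local gradient machine. A condensed, hypothetical sketch of the resulting per-batch control flow follows; the standalone function and its argument names are assumptions for illustration, the real logic lives inside SGD.train above.

    def prepare_batch(gradient_machine, parameter_updater, in_args,
                      is_local, use_sparse_updater):
        # Only remote sparse training needs to pull parameter rows first.
        if use_sparse_updater and not is_local:
            # Pre-fetch the rows this batch will touch from the pserver ...
            gradient_machine.prefetch(in_args)
            # ... and copy them into the local parameters before forwardBackward().
            parameter_updater.getParametersRemote()
        # Local training (dense or sparse) proceeds without any remote fetch.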