提交 f6c5b6fd 编写于 作者: Q qiaolongfei

add prefetch for trainer.test

上级 82103508
...@@ -79,6 +79,10 @@ class SGD(object): ...@@ -79,6 +79,10 @@ class SGD(object):
self.__gradient_machine__ = gm self.__gradient_machine__ = gm
self.__gradient_machine__.randParameters() self.__gradient_machine__.randParameters()
parameters.append_gradient_machine(gm) parameters.append_gradient_machine(gm)
self.__parameter_updater__ = None
def use_remote_sparse_updater(self):
return self.__use_sparse_updater__ and not self.__is_local__
def train(self, reader, num_passes=1, event_handler=None, feeding=None): def train(self, reader, num_passes=1, event_handler=None, feeding=None):
""" """
...@@ -103,6 +107,7 @@ class SGD(object): ...@@ -103,6 +107,7 @@ class SGD(object):
else: else:
parameter_updater = self.__optimizer__.create_remote_updater( parameter_updater = self.__optimizer__.create_remote_updater(
num_passes, self.__use_sparse_updater__) num_passes, self.__use_sparse_updater__)
self.__parameter_updater__ = parameter_updater
parameter_updater.init(self.__gradient_machine__) parameter_updater.init(self.__gradient_machine__)
self.__gradient_machine__.start() self.__gradient_machine__.start()
...@@ -122,11 +127,12 @@ class SGD(object): ...@@ -122,11 +127,12 @@ class SGD(object):
v2_event.BeginIteration( v2_event.BeginIteration(
pass_id=pass_id, batch_id=batch_id)) pass_id=pass_id, batch_id=batch_id))
pass_type = parameter_updater.startBatch(len(data_batch)) pass_type = parameter_updater.startBatch(len(data_batch))
if self.__use_sparse_updater__ and not self.__is_local__: in_args = feeder(data_batch)
self.__gradient_machine__.prefetch(feeder(data_batch)) if self.use_remote_sparse_updater():
self.__gradient_machine__.prefetch(in_args)
parameter_updater.getParametersRemote() parameter_updater.getParametersRemote()
self.__gradient_machine__.forwardBackward( self.__gradient_machine__.forwardBackward(in_args, out_args,
feeder(data_batch), out_args, pass_type) pass_type)
self.__gradient_machine__.eval(pass_evaluator) self.__gradient_machine__.eval(pass_evaluator)
self.__gradient_machine__.eval(batch_evaluator) self.__gradient_machine__.eval(batch_evaluator)
for each_param in self.__gradient_machine__.getNonStaticParameters( for each_param in self.__gradient_machine__.getNonStaticParameters(
...@@ -157,8 +163,11 @@ class SGD(object): ...@@ -157,8 +163,11 @@ class SGD(object):
num_samples = 0.0 num_samples = 0.0
for data_batch in reader(): for data_batch in reader():
num_samples += len(data_batch) num_samples += len(data_batch)
self.__gradient_machine__.forward( in_args = feeder(data_batch)
feeder(data_batch), out_args, api.PASS_TEST) if self.use_remote_sparse_updater():
self.__gradient_machine__.prefetch(in_args)
self.__parameter_updater__.getParametersRemote()
self.__gradient_machine__.forward(in_args, out_args, api.PASS_TEST)
total_cost += out_args.sum() total_cost += out_args.sum()
self.__gradient_machine__.eval(evaluator) self.__gradient_machine__.eval(evaluator)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册