提交 8e0d1d8b 编写于 作者: Y Yu Yang

Complete update equation.

上级 e13d9c74
...@@ -49,7 +49,7 @@ def main(): ...@@ -49,7 +49,7 @@ def main():
:return: :return:
""" """
mu = 0.09 mu = 0.09
e = 0.0001 e = 0.00001
vel_t = mu * vel_t_1 - e * g vel_t = mu * vel_t_1 - e * g
......
...@@ -25,7 +25,7 @@ class CompleteTrainOneBatch(BaseEvent): ...@@ -25,7 +25,7 @@ class CompleteTrainOneBatch(BaseEvent):
self.pass_id = pass_id self.pass_id = pass_id
self.batch_id = batch_id self.batch_id = batch_id
self.cost = cost self.cost = cost
self.paramters = parameters self.parameters = parameters
def default_event_handler(event): def default_event_handler(event):
...@@ -44,6 +44,17 @@ class ITrainer(object): ...@@ -44,6 +44,17 @@ class ITrainer(object):
class LazyParameterPool(v2_parameters.IParameterPool): class LazyParameterPool(v2_parameters.IParameterPool):
""" """
Lazy Parameter Pool stores a reference to GradientMachine. User could invoke
`get_parameter` if needed, but the operation is lazy. It means the parameter
will only fetched from GPU or Parameter Server if `get_parameter` is
invoked. Also, set flag = writable will make a extra host2device copy after
reading/modifying parameter.
This class is not exposed to User. User should treat this class as a normal
IParameterPool.
See IParameterPool for usage documentation.
:type __gradient_machine__: api.GradientMachine :type __gradient_machine__: api.GradientMachine
""" """
...@@ -130,12 +141,22 @@ class CustomizeUpdateEquation(object): ...@@ -130,12 +141,22 @@ class CustomizeUpdateEquation(object):
shape) shape)
g = param.getBuf(api.PARAMETER_GRADIENT).toNumpyArrayInplace( g = param.getBuf(api.PARAMETER_GRADIENT).toNumpyArrayInplace(
).reshape(shape) ).reshape(shape)
args = [v, g]
for arg in self.local_params[conf.name]:
args.append(arg)
self.__callback__(*args)
else: else:
raise NotImplementedError() v = param.getBuf(api.PARAMETER_VALUE).copyToNumpyArray().reshape(
shape)
g = param.getBuf(api.PARAMETER_GRADIENT).copyToNumpyArray().reshape(
shape)
args = [v, g]
for arg in self.local_params[conf.name]:
args.append(arg)
self.__callback__(*args)
if api.isUsingGpu():
param.getBuf(api.PARAMETER_VALUE).copyFromNumpyArray(v.flatten(
).astype('float32'))
# discard gradient changed.
class SGDTrainer(ITrainer): class SGDTrainer(ITrainer):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册