Commit 8e0d1d8b authored by Yu Yang

Complete update equation.

Parent e13d9c74
@@ -49,7 +49,7 @@ def main():
     :return:
     """
     mu = 0.09
-    e = 0.0001
+    e = 0.00001
     vel_t = mu * vel_t_1 - e * g
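The hunk above lowers the learning rate `e` in a momentum velocity update, `vel_t = mu * vel_t_1 - e * g`. As a point of reference, here is a minimal self-contained sketch of that update rule; `momentum_step`, its argument names, and the parameter-application line are illustrative additions, not part of this diff.

```python
import numpy as np

def momentum_step(theta, vel, grad, mu=0.09, e=0.00001):
    # Velocity update shown in the diff: vel_t = mu * vel_t_1 - e * g.
    vel = mu * vel - e * grad
    # Companion step (assumed, not shown in the hunk): apply the
    # velocity to the parameters.
    theta = theta + vel
    return theta, vel

theta = np.zeros(4)           # parameters
vel = np.zeros_like(theta)    # velocity state
grad = np.array([0.5, -1.0, 0.25, 2.0])
theta, vel = momentum_step(theta, vel, grad)
```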
@@ -25,7 +25,7 @@ class CompleteTrainOneBatch(BaseEvent):
         self.pass_id = pass_id
         self.batch_id = batch_id
         self.cost = cost
-        self.paramters = parameters
+        self.parameters = parameters

 def default_event_handler(event):
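With the `paramters` typo fixed above, a handler can rely on the `parameters` attribute. The sketch below shows how such a handler might read the event's fields; the stub class only mirrors the constructor visible in this hunk so the example runs standalone, and the dispatch call at the end is hypothetical wiring.

```python
class BaseEvent(object):
    pass

class CompleteTrainOneBatch(BaseEvent):
    # Stub mirroring the constructor in the hunk above, so the handler
    # below is runnable on its own; the real class lives in this repo's
    # trainer module.
    def __init__(self, pass_id, batch_id, cost, parameters):
        self.pass_id = pass_id
        self.batch_id = batch_id
        self.cost = cost
        self.parameters = parameters

def my_event_handler(event):
    if isinstance(event, CompleteTrainOneBatch):
        print("pass %d, batch %d, cost %f" %
              (event.pass_id, event.batch_id, event.cost))
        # event.parameters exposes the model parameters after this
        # batch, e.g. for snapshotting.

# Usage sketch: in practice the trainer constructs and dispatches events.
my_event_handler(CompleteTrainOneBatch(0, 3, 0.731, parameters=None))
```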
@@ -44,6 +44,17 @@ class ITrainer(object):

 class LazyParameterPool(v2_parameters.IParameterPool):
+    """
+    LazyParameterPool stores a reference to a GradientMachine. The user can
+    invoke `get_parameter` when needed, but the operation is lazy: the
+    parameter is only fetched from the GPU or the Parameter Server when
+    `get_parameter` is invoked. Setting flag = writable also makes an extra
+    host2device copy after the parameter has been read or modified.
+
+    This class is not exposed to the user and should be treated as a normal
+    IParameterPool. See IParameterPool for usage documentation.
+
+    :type __gradient_machine__: api.GradientMachine
+    """
@@ -130,12 +141,22 @@ class CustomizeUpdateEquation(object):
                 shape)
             g = param.getBuf(api.PARAMETER_GRADIENT).toNumpyArrayInplace(
             ).reshape(shape)
             args = [v, g]
             for arg in self.local_params[conf.name]:
                 args.append(arg)
             self.__callback__(*args)
         else:
-            raise NotImplementedError()
+            v = param.getBuf(api.PARAMETER_VALUE).copyToNumpyArray().reshape(
+                shape)
+            g = param.getBuf(api.PARAMETER_GRADIENT).copyToNumpyArray().reshape(
+                shape)
+            args = [v, g]
+            for arg in self.local_params[conf.name]:
+                args.append(arg)
+            self.__callback__(*args)
+            if api.isUsingGpu():
+                param.getBuf(api.PARAMETER_VALUE).copyFromNumpyArray(v.flatten(
+                ).astype('float32'))
+            # Gradient changes are discarded; only the value is copied back.

 class SGDTrainer(ITrainer):
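The new `else` branch follows a copy-out / update / copy-in pattern: value and gradient are copied to host numpy arrays, the user callback mutates the value in place, and on GPU the modified value is copied back while gradient edits are dropped. A minimal numpy-only sketch of the same pattern, with plain arrays standing in for Paddle's parameter buffers and `host_side_update` a hypothetical name:

```python
import numpy as np

def host_side_update(device_value, device_grad, callback, shape, on_gpu):
    # Copy out: host numpy copies of the (possibly device-resident) buffers,
    # analogous to copyToNumpyArray() in the hunk above.
    v = np.array(device_value, dtype='float32').reshape(shape)
    g = np.array(device_grad, dtype='float32').reshape(shape)
    callback(v, g)    # the user's update equation mutates v in place
    if on_gpu:
        # Copy in: only the value goes back (copyFromNumpyArray in the
        # hunk); gradient changes are discarded.
        device_value[...] = v.flatten()
    return device_value

def sgd(v, g, e=0.00001):
    v -= e * g        # in-place, like the real update callbacks

buf = np.zeros(6, dtype='float32')
grad = np.ones(6, dtype='float32')
host_side_update(buf, grad, sgd, (2, 3), on_gpu=True)
```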