Commit e13d9c74 authored by Yu Yang

Expose Parameter to train event handler

* Users can now get and set parameters from within the train event handler.
* Add support for custom update equations written as plain Python callables.
Parent 176d44ef
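A hedged sketch of what this enables for users, built from the names this diff introduces (the import path of CompleteTrainOneBatch, the parameter name "fc.w", and the reporting interval are illustrative guesses, not part of the commit):

import numpy
import paddle.v2.parameters as v2_parameters
from paddle.v2.trainer import CompleteTrainOneBatch  # import path assumed

def event_handler(event):
    # event.parameters is the LazyParameterPool attached by this commit
    if isinstance(event, CompleteTrainOneBatch) and event.batch_id % 100 == 0:
        w = event.parameters.get_parameter(
            "fc.w", v2_parameters.ParameterFlag.READ_ONLY)
        print "pass %d, batch %d, cost %f, |w|=%f" % (
            event.pass_id, event.batch_id, event.cost,
            numpy.linalg.norm(w))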
@@ -39,11 +52,24 @@ def main():
array = pool.get_parameter(param_name)
array[:] = numpy.random.uniform(low=-1.0, high=1.0, size=array.shape)
trainer = paddle_v2.trainer.SGDTrainer(
update_equation=paddle_v2.optimizer.Adam(
learning_rate=1e-4,
model_average=ModelAverage(average_window=0.5),
regularization=L2Regularization(rate=0.5)))
def nag(v, g, vel_t_1):
"""
NAG Optimizer. A optimizer which Paddle CPP is not implemented.
https://arxiv.org/pdf/1212.0901v2.pdf eq.6 eq.7
:param v: parameter value
:param g: parameter gradient
:param vel_t_1: t-1 velocity
:return:
"""
mu = 0.09
e = 0.0001
vel_t = mu * vel_t_1 - e * g
v[:] = v + (mu**2) * vel_t - (1 + mu) * e * g
vel_t_1[:] = vel_t
trainer = paddle_v2.trainer.SGDTrainer(update_equation=nag)
trainer.train(train_data_reader=train_reader,
topology=model_config,
......
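For reference, the nag callback above transcribes eq. 6 and eq. 7 of the cited paper, with mu the momentum coefficient and e the learning rate:

    v_t     = mu * v_{t-1} - e * g
    theta_t = theta_{t-1} + mu**2 * v_t - (1 + mu) * e * g

In the code, the argument v plays the role of theta (the parameter) and vel_t_1 holds v_{t-1}; the trainer allocates vel_t_1 and persists it between batches (see CustomizeUpdateEquation below).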
import collections
from paddle.proto.ModelConfig_pb2 import ModelConfig
import paddle.v2.parameters
import paddle.v2.optimizer
from paddle.proto.ParameterConfig_pb2 import ParameterConfig
from . import parameters as v2_parameters
import numpy
import py_paddle.swig_paddle as api
from py_paddle import DataProviderConverter
@@ -20,10 +21,11 @@ class CompleteTrainOneBatch(BaseEvent):
Event fired when training on one batch completes.
"""
def __init__(self, pass_id, batch_id, cost):
def __init__(self, pass_id, batch_id, cost, parameters):
self.pass_id = pass_id
self.batch_id = batch_id
self.cost = cost
self.parameters = parameters
def default_event_handler(event):
@@ -40,6 +42,102 @@ class ITrainer(object):
raise NotImplementedError()
class LazyParameterPool(v2_parameters.IParameterPool):
"""
:type __gradient_machine__: api.GradientMachine
"""
def get_parameter(self, name, flag=v2_parameters.ParameterFlag.READ_WRITE):
param = filter(lambda x: x.getName() == name,
self.__gradient_machine__.getParameters())
if len(param) == 0:
raise ValueError("Cannot found parameter with name %s" % name)
elif len(param) > 1:
raise RuntimeError("Unexpected branch")
else:
conf = param[0].getConfig().toProto()
param = param[0].getBuf(api.PARAMETER_VALUE)
assert isinstance(param, api.Vector)
assert isinstance(conf, ParameterConfig)
shape = map(int, conf.dims)
if api.isUsingGpu():
arr = param.copyToNumpyArray().reshape(shape)
if flag & v2_parameters.ParameterFlag.WRITE_ONLY:
self.need_copy = True
self.arrays[name] = arr
else:
arr = param.toNumpyArrayInplace().reshape(shape)
return arr
def get_names(self):
return [
param.getName()
for param in self.__gradient_machine__.getParameters()
]
def __init__(self, gradient_machine):
self.__gradient_machine__ = gradient_machine
self.need_copy = False
self.arrays = dict()
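A note on the write-back contract implied above, with a hedged example (the parameter name "embedding" and the clipping range are made up): the array returned by get_parameter must be modified in place. On CPU it aliases the parameter buffer directly, so writes take effect immediately; on GPU it is a copy, and requesting write access sets need_copy so the trainer can copy the array back after the event handler returns.

def event_handler(event):
    if isinstance(event, CompleteTrainOneBatch):
        arr = event.parameters.get_parameter(
            "embedding", v2_parameters.ParameterFlag.READ_WRITE)
        # In-place assignment is required: rebinding 'arr' would silently
        # drop the update on the CPU path, and on the GPU path the pool
        # would copy back the untouched cached array instead.
        arr[:] = numpy.clip(arr, -5.0, 5.0)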
class CustomizeUpdateEquation(object):
def __init__(self, callback):
self.__callback__ = callback
if self.__callback__.func_code.co_argcount < 2:
raise ValueError(
"The update equation must take at least 2 arguments: "
"the first is the parameter value, the second is its gradient")
self.local_params_count = self.__callback__.func_code.co_argcount - 2
self.local_params = dict()
def enable_types(self):
return [api.PARAMETER_VALUE, api.PARAMETER_GRADIENT]
def init(self, gradient_machine):
assert isinstance(gradient_machine, api.GradientMachine)
for param in gradient_machine.getParameters():
conf = param.getConfig().toProto()
shape = map(int, conf.dims)
self.local_params[conf.name] = []
for _ in xrange(self.local_params_count):
self.local_params[conf.name].append(
numpy.zeros(
shape=shape, dtype='float32'))
def create_local_updater(self):
return self
def startPass(self):
pass
def finishPass(self):
pass
def startBatch(self, batch_size):
return api.PASS_TRAIN
def finishBatch(self, cost):
pass
def update(self, param):
conf = param.getConfig().toProto()
shape = map(int, conf.dims)
if not api.isUsingGpu():
v = param.getBuf(api.PARAMETER_VALUE).toNumpyArrayInplace().reshape(
shape)
g = param.getBuf(api.PARAMETER_GRADIENT).toNumpyArrayInplace(
).reshape(shape)
args = [v, g]
for arg in self.local_params[conf.name]:
args.append(arg)
self.__callback__(*args)
else:
raise NotImplementedError()
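CustomizeUpdateEquation turns any Python callable into a local updater: the first two arguments receive the parameter value and gradient, and every additional argument becomes a per-parameter state slot that init() pre-allocates as a float32 zeros array of the parameter's shape. A minimal momentum rule sketched under that contract (the coefficients are illustrative, not tuned):

def momentum(v, g, vel):
    # 'vel' is allocated by init() above: one zeros array per parameter,
    # persisted across batches, so it accumulates the running velocity.
    vel[:] = 0.9 * vel - 0.001 * g
    v[:] = v + vel

trainer = paddle_v2.trainer.SGDTrainer(update_equation=momentum)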
class SGDTrainer(ITrainer):
def __init__(self, update_equation):
"""
@@ -47,8 +145,8 @@ class SGDTrainer(ITrainer):
:param update_equation: an Optimizer instance or a Python callable. Maybe we should provide a DSL for update equations?
"""
if not isinstance(update_equation, paddle.v2.optimizer.Optimizer):
raise ValueError()
if callable(update_equation):
update_equation = CustomizeUpdateEquation(update_equation)
self.__optimizer__ = update_equation
@@ -87,7 +185,6 @@ class SGDTrainer(ITrainer):
__copy_parameter_from_pool__(gm, parameters)
updater = self.__optimizer__.create_local_updater()
assert isinstance(updater, api.ParameterUpdater)
updater.init(gm)
gm.start()
@@ -115,10 +212,16 @@ class SGDTrainer(ITrainer):
cost_vec = cost_vec.copyToNumpyMat()
cost = cost_vec.sum() / len(data_batch)
updater.finishBatch(cost)
pool = LazyParameterPool(gradient_machine=gm)
event_handler(
CompleteTrainOneBatch(
pass_id=pass_id, batch_id=batch_id, cost=cost))
pass_id=pass_id,
batch_id=batch_id,
cost=cost,
parameters=pool))
if pool.need_copy:
__copy_parameter_from_lazy_pool__(gm, pool)
updater.finishPass()
gm.finish()
@@ -153,20 +256,30 @@ def __generator_to_batch__(generator, batch_size):
yield ret_val
def __copy_parameter_from_lazy_pool__(gm, pool):
assert isinstance(pool, LazyParameterPool)
for each_param_name in pool.arrays.keys():
param = filter(lambda x: x.getName() == each_param_name,
gm.getParameters())
assert len(param) == 1
param = param[0]
param.getBuf(api.PARAMETER_VALUE).copyFromNumpyArray(pool.arrays[
each_param_name].flatten().astype('float32'))
def __copy_parameter_from_pool__(gm, pool):
"""
:param gm:
:type gm: api.GradientMachine
:param pool:
:type pool: paddle.v2.parameters.IParameterPool
:type pool: v2_parameters.IParameterPool
:return:
"""
assert isinstance(pool, paddle.v2.parameters.IParameterPool)
assert isinstance(pool, v2_parameters.IParameterPool)
for each_param in gm.getParameters():
name = each_param.getName()
param = pool.get_parameter(name,
paddle.v2.parameters.ParameterFlag.READ_ONLY)
param = pool.get_parameter(name, v2_parameters.ParameterFlag.READ_ONLY)
each_param.getBuf(api.PARAMETER_VALUE).copyFromNumpyArray(param.flatten(
).astype('float32'))
@@ -190,7 +303,7 @@ def __check_train_args__(train_data_reader, topology, parameters,
if not isinstance(topology, ModelConfig):
raise ValueError('topology should be a model config')
if not isinstance(parameters, paddle.v2.parameters.IParameterPool):
if not isinstance(parameters, v2_parameters.IParameterPool):
raise ValueError('parameters should be a parameter pool')
if not callable(event_handler):
......