提交 739ff181 编写于 作者: Y Yu Yang

V2.testing complete

上级 37d54cb7
...@@ -21,28 +21,29 @@ def main(): ...@@ -21,28 +21,29 @@ def main():
adam_optimizer = paddle.optimizer.Adam(learning_rate=0.01) adam_optimizer = paddle.optimizer.Adam(learning_rate=0.01)
trainer = paddle.trainer.SGD(topology=cost,
parameters=parameters,
update_equation=adam_optimizer)
def event_handler(event): def event_handler(event):
if isinstance(event, paddle.event.EndIteration): if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 100 == 0: if event.batch_id % 1000 == 0:
print "Pass %d, Batch %d, Cost %f, %s" % ( result = trainer.test(reader=paddle.reader.batched(
event.pass_id, event.batch_id, event.cost, event.metrics) paddle.dataset.mnist.test_creator(), batch_size=256))
print "Pass %d, Batch %d, Cost %f, %s, Testing metrics %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics,
result.metrics)
else: else:
pass pass
trainer = paddle.trainer.SGD(update_equation=adam_optimizer)
trainer.train( trainer.train(
reader=paddle.reader.batched( reader=paddle.reader.batched(
paddle.reader.shuffle(paddle.dataset.mnist.train_creator(), paddle.reader.shuffle(
buf_size=8192), batch_size=32), paddle.dataset.mnist.train_creator(), buf_size=8192),
topology=cost, batch_size=32),
parameters=parameters, event_handler=event_handler)
event_handler=event_handler,
data_types=[ # data_types will be removed, It should be in
# network topology
('pixel', images.type),
('label', label.type)],
reader_dict={'pixel': 0, 'label': 1}
)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -11,7 +11,10 @@ There are: ...@@ -11,7 +11,10 @@ There are:
TODO(yuyang18): Complete it! TODO(yuyang18): Complete it!
""" """
import py_paddle.swig_paddle as api import py_paddle.swig_paddle as api
__all__ = ['EndIteration', 'BeginIteration', 'BeginPass', 'EndPass']
__all__ = [
'EndIteration', 'BeginIteration', 'BeginPass', 'EndPass', 'TestResult'
]
class WithMetric(object): class WithMetric(object):
...@@ -30,6 +33,11 @@ class WithMetric(object): ...@@ -30,6 +33,11 @@ class WithMetric(object):
return retv return retv
class TestResult(WithMetric):
    """
    Result object produced by an evaluation run.

    ``SGD.test`` builds one of these from the evaluator it has just finished
    and returns it to the caller. The class adds no state of its own; the
    evaluator is forwarded to the ``WithMetric`` base class.
    NOTE(review): the demo reads ``result.metrics`` on the returned object;
    that attribute is presumably provided by ``WithMetric`` -- confirm there.

    :param evaluator: The evaluator holding the accumulated test metrics.
    """

    def __init__(self, evaluator):
        # Pure pass-through: all behaviour lives in the WithMetric base.
        super(TestResult, self).__init__(evaluator)
class BeginPass(object): class BeginPass(object):
""" """
Event On One Pass Training Start. Event On One Pass Training Start.
......
...@@ -23,6 +23,13 @@ def default_event_handler(event): ...@@ -23,6 +23,13 @@ def default_event_handler(event):
pass pass
def __bfs_travel_topology__(callback, *topologies):
    """
    Visit every layer reachable from ``topologies`` in breadth-first order.

    Starting from the given layers, the traversal follows each layer's
    ``__parent_layers__`` mapping and invokes ``callback`` exactly once per
    distinct layer object. The earlier recursive form was a depth-first walk
    that re-visited a shared ancestor once per path leading to it, which is
    exponential on networks with many shared inputs and could also hit the
    interpreter's recursion limit on deep networks.

    :param callback: Callable invoked as ``callback(layer)`` for each layer.
    :param topologies: Zero or more starting layers. Each layer must expose a
                       ``__parent_layers__`` attribute with a ``values()``
                       method yielding its parent layers.
    """
    # Local import keeps this block self-contained; deque gives O(1) pops
    # from the left, which a plain list does not.
    from collections import deque

    pending = deque(topologies)
    # Layers are tracked by identity so that unhashable layer objects work.
    # Every visited layer stays referenced by the graph while we traverse,
    # so its id() cannot be recycled for a different object.
    seen_ids = set()
    while pending:
        layer = pending.popleft()
        layer_id = id(layer)
        if layer_id in seen_ids:
            continue
        seen_ids.add(layer_id)
        callback(layer)
        pending.extend(layer.__parent_layers__.values())
class ITrainer(object): class ITrainer(object):
""" """
The interface of Trainer. The only exposed method is `train`. The interface of Trainer. The only exposed method is `train`.
...@@ -43,26 +50,49 @@ class ITrainer(object): ...@@ -43,26 +50,49 @@ class ITrainer(object):
class SGD(ITrainer): class SGD(ITrainer):
def __init__(self, update_equation): def __init__(self, topology, parameters, update_equation):
""" """
Simple SGD Trainer. Simple SGD Trainer.
:param update_equation: The optimizer object. :param update_equation: The optimizer object.
:type update_equation: v2_optimizer.Optimizer :type update_equation: v2_optimizer.Optimizer
""" """
if not isinstance(parameters, v2_parameters.Parameters):
raise TypeError('parameters should be parameters')
if not isinstance(update_equation, v2_optimizer.Optimizer): if not isinstance(update_equation, v2_optimizer.Optimizer):
raise ValueError("update equation parameter must be " raise TypeError("update equation parameter must be "
"paddle.v2.optimizer.Optimizer") "paddle.v2.optimizer.Optimizer")
self.__optimizer__ = update_equation self.__optimizer__ = update_equation
self.__topology__ = topology
self.__parameters__ = parameters
self.__topology_in_proto__ = v2_layer.parse_network(topology)
data_types = dict()
def __travel__(l):
if hasattr(l, 'type'):
data_types[l.name] = l.type
if not isinstance(topology, collections.Sequence):
topology = [topology]
__bfs_travel_topology__(__travel__, *topology)
self.__data_types__ = [
(iname, data_types[iname])
for iname in self.__topology_in_proto__.input_layer_names
]
if not isinstance(self.__topology_in_proto__, ModelConfig):
raise TypeError('topology should be a model config')
def train(self, gm = api.GradientMachine.createFromConfigProto(
reader, self.__topology_in_proto__, api.CREATE_MODE_NORMAL,
topology, self.__optimizer__.enable_types())
parameters, assert isinstance(gm, api.GradientMachine)
num_passes=1, parameters.append_gradient_machine(gm)
event_handler=None, self.__gradient_machine__ = gm
data_types=None, self.__gradient_machine__.randParameters()
reader_dict=None):
def train(self, reader, num_passes=1, event_handler=None, reader_dict=None):
""" """
Training method. Will train num_passes of input data. Training method. Will train num_passes of input data.
...@@ -79,26 +109,21 @@ class SGD(ITrainer): ...@@ -79,26 +109,21 @@ class SGD(ITrainer):
if event_handler is None: if event_handler is None:
event_handler = default_event_handler event_handler = default_event_handler
topology = v2_layer.parse_network(topology) if reader_dict is None:
reader_dict = self.default_reader_dict()
__check_train_args__(**locals()) __check_train_args__(**locals())
gm = api.GradientMachine.createFromConfigProto(
topology, api.CREATE_MODE_NORMAL, self.__optimizer__.enable_types())
assert isinstance(gm, api.GradientMachine)
parameters.append_gradient_machine(gm)
gm.randParameters()
updater = self.__optimizer__.create_local_updater() updater = self.__optimizer__.create_local_updater()
updater.init(gm) updater.init(self.__gradient_machine__)
gm.start() self.__gradient_machine__.start()
batch_evaluator = gm.makeEvaluator() batch_evaluator = self.__gradient_machine__.makeEvaluator()
assert isinstance(batch_evaluator, api.Evaluator) assert isinstance(batch_evaluator, api.Evaluator)
pass_evaluator = gm.makeEvaluator() pass_evaluator = self.__gradient_machine__.makeEvaluator()
assert isinstance(pass_evaluator, api.Evaluator) assert isinstance(pass_evaluator, api.Evaluator)
out_args = api.Arguments.createArguments(0) out_args = api.Arguments.createArguments(0)
feeder = DataFeeder(data_types, reader_dict) feeder = DataFeeder(self.__data_types__, reader_dict)
for pass_id in xrange(num_passes): for pass_id in xrange(num_passes):
event_handler(v2_event.BeginPass(pass_id)) event_handler(v2_event.BeginPass(pass_id))
...@@ -106,16 +131,18 @@ class SGD(ITrainer): ...@@ -106,16 +131,18 @@ class SGD(ITrainer):
updater.startPass() updater.startPass()
for batch_id, data_batch in enumerate(reader()): for batch_id, data_batch in enumerate(reader()):
pass_type = updater.startBatch(len(data_batch)) pass_type = updater.startBatch(len(data_batch))
gm.forwardBackward(feeder(data_batch), out_args, pass_type) self.__gradient_machine__.forwardBackward(
feeder(data_batch), out_args, pass_type)
batch_evaluator.start() batch_evaluator.start()
event_handler( event_handler(
v2_event.BeginIteration( v2_event.BeginIteration(
pass_id=pass_id, batch_id=batch_id)) pass_id=pass_id, batch_id=batch_id))
pass_type = updater.startBatch(len(data_batch)) pass_type = updater.startBatch(len(data_batch))
gm.forwardBackward(feeder(data_batch), out_args, pass_type) self.__gradient_machine__.forwardBackward(
gm.eval(pass_evaluator) feeder(data_batch), out_args, pass_type)
gm.eval(batch_evaluator) self.__gradient_machine__.eval(pass_evaluator)
for each_param in gm.getParameters(): self.__gradient_machine__.eval(batch_evaluator)
for each_param in self.__gradient_machine__.getParameters():
updater.update(each_param) updater.update(each_param)
# Get cost. We use numpy to calculate total cost for this batch. # Get cost. We use numpy to calculate total cost for this batch.
cost_vec = out_args.getSlotValue(0) cost_vec = out_args.getSlotValue(0)
...@@ -133,10 +160,32 @@ class SGD(ITrainer): ...@@ -133,10 +160,32 @@ class SGD(ITrainer):
updater.finishPass() updater.finishPass()
pass_evaluator.finish() pass_evaluator.finish()
event_handler(v2_event.EndPass(pass_id, evaluator=pass_evaluator)) event_handler(v2_event.EndPass(pass_id, evaluator=pass_evaluator))
gm.finish() self.__gradient_machine__.finish()
def default_reader_dict(self):
    """
    Build the default mapping from input-layer name to reader column index.

    Each entry of ``self.__data_types__`` is a ``(name, type)`` pair; the
    name is mapped to the position of its pair in that list.

    :return: ``dict`` of ``{input_name: position}``.
    """
    return {
        entry[0]: position
        for position, entry in enumerate(self.__data_types__)
    }
def test(self, reader, reader_dict=None):
if reader_dict is None:
reader_dict = self.default_reader_dict()
feeder = DataFeeder(self.__data_types__, reader_dict)
evaluator = self.__gradient_machine__.makeEvaluator()
out_args = api.Arguments.createArguments(0)
evaluator.start()
for data_batch in reader():
self.__gradient_machine__.forward(
feeder(data_batch), out_args, api.PASS_TEST)
self.__gradient_machine__.eval(evaluator)
def __check_train_args__(reader, topology, parameters, event_handler, **kwargs): evaluator.finish()
return v2_event.TestResult(evaluator=evaluator)
def __check_train_args__(reader, event_handler, **kwargs):
""" """
Check train function's argument types Check train function's argument types
""" """
...@@ -144,11 +193,5 @@ def __check_train_args__(reader, topology, parameters, event_handler, **kwargs): ...@@ -144,11 +193,5 @@ def __check_train_args__(reader, topology, parameters, event_handler, **kwargs):
raise TypeError('train_data_reader should be a function, ' raise TypeError('train_data_reader should be a function, '
'which can return a iterator') 'which can return a iterator')
if not isinstance(topology, ModelConfig):
raise TypeError('topology should be a model config')
if not isinstance(parameters, v2_parameters.Parameters):
raise TypeError('parameters should be a parameter pool')
if not callable(event_handler): if not callable(event_handler):
raise TypeError('event handler should be a function') raise TypeError('event handler should be a function')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册