提交 739ff181 编写于 作者: Y Yu Yang

V2.testing complete

上级 37d54cb7
......@@ -21,28 +21,29 @@ def main():
adam_optimizer = paddle.optimizer.Adam(learning_rate=0.01)
trainer = paddle.trainer.SGD(topology=cost,
parameters=parameters,
update_equation=adam_optimizer)
def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 100 == 0:
print "Pass %d, Batch %d, Cost %f, %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics)
if event.batch_id % 1000 == 0:
result = trainer.test(reader=paddle.reader.batched(
paddle.dataset.mnist.test_creator(), batch_size=256))
print "Pass %d, Batch %d, Cost %f, %s, Testing metrics %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics,
result.metrics)
else:
pass
trainer = paddle.trainer.SGD(update_equation=adam_optimizer)
trainer.train(
reader=paddle.reader.batched(
paddle.reader.shuffle(paddle.dataset.mnist.train_creator(),
buf_size=8192), batch_size=32),
topology=cost,
parameters=parameters,
event_handler=event_handler,
data_types=[ # data_types will be removed, It should be in
# network topology
('pixel', images.type),
('label', label.type)],
reader_dict={'pixel': 0, 'label': 1}
)
paddle.reader.shuffle(
paddle.dataset.mnist.train_creator(), buf_size=8192),
batch_size=32),
event_handler=event_handler)
if __name__ == '__main__':
......
......@@ -11,7 +11,10 @@ There are:
TODO(yuyang18): Complete it!
"""
import py_paddle.swig_paddle as api
__all__ = ['EndIteration', 'BeginIteration', 'BeginPass', 'EndPass']
__all__ = [
'EndIteration', 'BeginIteration', 'BeginPass', 'EndPass', 'TestResult'
]
class WithMetric(object):
......@@ -30,6 +33,11 @@ class WithMetric(object):
return retv
class TestResult(WithMetric):
def __init__(self, evaluator):
super(TestResult, self).__init__(evaluator)
class BeginPass(object):
"""
Event On One Pass Training Start.
......
......@@ -23,6 +23,13 @@ def default_event_handler(event):
pass
def __bfs_travel_topology__(callback, *topologies):
for each_layer in topologies:
callback(each_layer)
__bfs_travel_topology__(callback,
*each_layer.__parent_layers__.values())
class ITrainer(object):
"""
The interface of Trainer. The only exposed method is `train`.
......@@ -43,26 +50,49 @@ class ITrainer(object):
class SGD(ITrainer):
def __init__(self, update_equation):
def __init__(self, topology, parameters, update_equation):
"""
Simple SGD Trainer.
:param update_equation: The optimizer object.
:type update_equation: v2_optimizer.Optimizer
"""
if not isinstance(parameters, v2_parameters.Parameters):
raise TypeError('parameters should be parameters')
if not isinstance(update_equation, v2_optimizer.Optimizer):
raise ValueError("update equation parameter must be "
raise TypeError("update equation parameter must be "
"paddle.v2.optimizer.Optimizer")
self.__optimizer__ = update_equation
self.__topology__ = topology
self.__parameters__ = parameters
self.__topology_in_proto__ = v2_layer.parse_network(topology)
data_types = dict()
def __travel__(l):
if hasattr(l, 'type'):
data_types[l.name] = l.type
if not isinstance(topology, collections.Sequence):
topology = [topology]
__bfs_travel_topology__(__travel__, *topology)
self.__data_types__ = [
(iname, data_types[iname])
for iname in self.__topology_in_proto__.input_layer_names
]
if not isinstance(self.__topology_in_proto__, ModelConfig):
raise TypeError('topology should be a model config')
def train(self,
reader,
topology,
parameters,
num_passes=1,
event_handler=None,
data_types=None,
reader_dict=None):
gm = api.GradientMachine.createFromConfigProto(
self.__topology_in_proto__, api.CREATE_MODE_NORMAL,
self.__optimizer__.enable_types())
assert isinstance(gm, api.GradientMachine)
parameters.append_gradient_machine(gm)
self.__gradient_machine__ = gm
self.__gradient_machine__.randParameters()
def train(self, reader, num_passes=1, event_handler=None, reader_dict=None):
"""
Training method. Will train num_passes of input data.
......@@ -79,26 +109,21 @@ class SGD(ITrainer):
if event_handler is None:
event_handler = default_event_handler
topology = v2_layer.parse_network(topology)
if reader_dict is None:
reader_dict = self.default_reader_dict()
__check_train_args__(**locals())
gm = api.GradientMachine.createFromConfigProto(
topology, api.CREATE_MODE_NORMAL, self.__optimizer__.enable_types())
assert isinstance(gm, api.GradientMachine)
parameters.append_gradient_machine(gm)
gm.randParameters()
updater = self.__optimizer__.create_local_updater()
updater.init(gm)
updater.init(self.__gradient_machine__)
gm.start()
batch_evaluator = gm.makeEvaluator()
self.__gradient_machine__.start()
batch_evaluator = self.__gradient_machine__.makeEvaluator()
assert isinstance(batch_evaluator, api.Evaluator)
pass_evaluator = gm.makeEvaluator()
pass_evaluator = self.__gradient_machine__.makeEvaluator()
assert isinstance(pass_evaluator, api.Evaluator)
out_args = api.Arguments.createArguments(0)
feeder = DataFeeder(data_types, reader_dict)
feeder = DataFeeder(self.__data_types__, reader_dict)
for pass_id in xrange(num_passes):
event_handler(v2_event.BeginPass(pass_id))
......@@ -106,16 +131,18 @@ class SGD(ITrainer):
updater.startPass()
for batch_id, data_batch in enumerate(reader()):
pass_type = updater.startBatch(len(data_batch))
gm.forwardBackward(feeder(data_batch), out_args, pass_type)
self.__gradient_machine__.forwardBackward(
feeder(data_batch), out_args, pass_type)
batch_evaluator.start()
event_handler(
v2_event.BeginIteration(
pass_id=pass_id, batch_id=batch_id))
pass_type = updater.startBatch(len(data_batch))
gm.forwardBackward(feeder(data_batch), out_args, pass_type)
gm.eval(pass_evaluator)
gm.eval(batch_evaluator)
for each_param in gm.getParameters():
self.__gradient_machine__.forwardBackward(
feeder(data_batch), out_args, pass_type)
self.__gradient_machine__.eval(pass_evaluator)
self.__gradient_machine__.eval(batch_evaluator)
for each_param in self.__gradient_machine__.getParameters():
updater.update(each_param)
# Get cost. We use numpy to calculate total cost for this batch.
cost_vec = out_args.getSlotValue(0)
......@@ -133,10 +160,32 @@ class SGD(ITrainer):
updater.finishPass()
pass_evaluator.finish()
event_handler(v2_event.EndPass(pass_id, evaluator=pass_evaluator))
gm.finish()
self.__gradient_machine__.finish()
def default_reader_dict(self):
reader_dict = dict()
for i, tp in enumerate(self.__data_types__):
reader_dict[tp[0]] = i
return reader_dict
def test(self, reader, reader_dict=None):
if reader_dict is None:
reader_dict = self.default_reader_dict()
feeder = DataFeeder(self.__data_types__, reader_dict)
evaluator = self.__gradient_machine__.makeEvaluator()
out_args = api.Arguments.createArguments(0)
evaluator.start()
for data_batch in reader():
self.__gradient_machine__.forward(
feeder(data_batch), out_args, api.PASS_TEST)
self.__gradient_machine__.eval(evaluator)
evaluator.finish()
return v2_event.TestResult(evaluator=evaluator)
def __check_train_args__(reader, topology, parameters, event_handler, **kwargs):
def __check_train_args__(reader, event_handler, **kwargs):
"""
Check train function's argument types
"""
......@@ -144,11 +193,5 @@ def __check_train_args__(reader, topology, parameters, event_handler, **kwargs):
raise TypeError('train_data_reader should be a function, '
'which can return a iterator')
if not isinstance(topology, ModelConfig):
raise TypeError('topology should be a model config')
if not isinstance(parameters, v2_parameters.Parameters):
raise TypeError('parameters should be a parameter pool')
if not callable(event_handler):
raise TypeError('event handler should be a function')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册