提交 91f13e48 编写于 作者: J jacquesqiao 提交者: GitHub

Merge pull request #1465 from reyoung/feature/tester

Paddle.V2.Trainer.test method complete.
...@@ -20,26 +20,29 @@ def main(): ...@@ -20,26 +20,29 @@ def main():
adam_optimizer = paddle.optimizer.Adam(learning_rate=0.01) adam_optimizer = paddle.optimizer.Adam(learning_rate=0.01)
trainer = paddle.trainer.SGD(cost=cost,
parameters=parameters,
update_equation=adam_optimizer)
def event_handler(event): def event_handler(event):
if isinstance(event, paddle.event.EndIteration): if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 100 == 0: if event.batch_id % 1000 == 0:
print "Pass %d, Batch %d, Cost %f, %s" % ( result = trainer.test(reader=paddle.reader.batched(
event.pass_id, event.batch_id, event.cost, event.metrics) paddle.dataset.mnist.test(), batch_size=256))
print "Pass %d, Batch %d, Cost %f, %s, Testing metrics %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics,
result.metrics)
else: else:
pass pass
trainer = paddle.trainer.SGD(update_equation=adam_optimizer)
trainer.train( trainer.train(
reader=paddle.reader.batched( reader=paddle.reader.batched(
paddle.reader.shuffle( paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=8192), paddle.dataset.mnist.train(), buf_size=8192),
batch_size=32), batch_size=32),
cost=cost, event_handler=event_handler)
parameters=parameters,
event_handler=event_handler,
reader_dict={images.name: 0,
label.name: 1})
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -9,9 +9,9 @@ __all__ = ['train', 'test'] ...@@ -9,9 +9,9 @@ __all__ = ['train', 'test']
URL_PREFIX = 'http://yann.lecun.com/exdb/mnist/' URL_PREFIX = 'http://yann.lecun.com/exdb/mnist/'
TEST_IMAGE_URL = URL_PREFIX + 't10k-images-idx3-ubyte.gz' TEST_IMAGE_URL = URL_PREFIX + 't10k-images-idx3-ubyte.gz'
TEST_IMAGE_MD5 = '25e3cc63507ef6e98d5dc541e8672bb6' TEST_IMAGE_MD5 = '9fb629c4189551a2d022fa330f9573f3'
TEST_LABEL_URL = URL_PREFIX + 't10k-labels-idx1-ubyte.gz' TEST_LABEL_URL = URL_PREFIX + 't10k-labels-idx1-ubyte.gz'
TEST_LABEL_MD5 = '4e9511fe019b2189026bd0421ba7b688' TEST_LABEL_MD5 = 'ec29112dd5afa0611ce80d1b7f02629c'
TRAIN_IMAGE_URL = URL_PREFIX + 'train-images-idx3-ubyte.gz' TRAIN_IMAGE_URL = URL_PREFIX + 'train-images-idx3-ubyte.gz'
TRAIN_IMAGE_MD5 = 'f68b3c2dcbeaaa9fbdd348bbdeb94873' TRAIN_IMAGE_MD5 = 'f68b3c2dcbeaaa9fbdd348bbdeb94873'
TRAIN_LABEL_URL = URL_PREFIX + 'train-labels-idx1-ubyte.gz' TRAIN_LABEL_URL = URL_PREFIX + 'train-labels-idx1-ubyte.gz'
......
...@@ -11,7 +11,10 @@ There are: ...@@ -11,7 +11,10 @@ There are:
TODO(yuyang18): Complete it! TODO(yuyang18): Complete it!
""" """
import py_paddle.swig_paddle as api import py_paddle.swig_paddle as api
__all__ = ['EndIteration', 'BeginIteration', 'BeginPass', 'EndPass']
__all__ = [
'EndIteration', 'BeginIteration', 'BeginPass', 'EndPass', 'TestResult'
]
class WithMetric(object): class WithMetric(object):
...@@ -30,6 +33,11 @@ class WithMetric(object): ...@@ -30,6 +33,11 @@ class WithMetric(object):
return retv return retv
class TestResult(WithMetric):
def __init__(self, evaluator):
super(TestResult, self).__init__(evaluator)
class BeginPass(object): class BeginPass(object):
""" """
Event On One Pass Training Start. Event On One Pass Training Start.
......
...@@ -21,6 +21,14 @@ import layer as v2_layer ...@@ -21,6 +21,14 @@ import layer as v2_layer
__all__ = ['Topology'] __all__ = ['Topology']
def __bfs_travel__(callback, *layers):
for each_layer in layers:
__break__ = callback(each_layer)
if __break__:
return
__bfs_travel__(callback, *each_layer.__parent_layers__.values())
class Topology(object): class Topology(object):
""" """
Topology is used to store the information about all layers Topology is used to store the information about all layers
...@@ -46,21 +54,17 @@ class Topology(object): ...@@ -46,21 +54,17 @@ class Topology(object):
:param name: :param name:
:return: :return:
""" """
result_layer = [] result_layer = [None]
def find_layer_by_name(layer, layer_name): def __impl__(l):
if len(result_layer) == 1: if l.name == name:
return result_layer[0] = l
elif layer.name == layer_name: return True # break
result_layer.append(layer) return False
else:
for parent_layer in layer.__parent_layers__.values():
find_layer_by_name(parent_layer, layer_name)
for layer in self.layers: __bfs_travel__(__impl__, *self.layers)
find_layer_by_name(layer, name) if result_layer[0] is None:
raise ValueError("No such layer %s" % name)
assert len(result_layer) == 1
return result_layer[0] return result_layer[0]
def data_layers(self): def data_layers(self):
...@@ -68,17 +72,13 @@ class Topology(object): ...@@ -68,17 +72,13 @@ class Topology(object):
get all data layer get all data layer
:return: :return:
""" """
data_layers = set() data_layers = dict()
def find_data_layer(layer):
if isinstance(layer, v2_layer.DataLayerV2):
data_layers.add(layer)
for parent_layer in layer.__parent_layers__.values():
find_data_layer(parent_layer)
for layer in self.layers: def __impl__(l):
find_data_layer(layer) if isinstance(l, v2_layer.DataLayerV2):
data_layers[l.name] = l
__bfs_travel__(__impl__, *self.layers)
return data_layers return data_layers
def data_type(self): def data_type(self):
...@@ -86,8 +86,9 @@ class Topology(object): ...@@ -86,8 +86,9 @@ class Topology(object):
get data_type from proto, such as: get data_type from proto, such as:
[('image', dense_vector(768)), ('label', integer_value(10))] [('image', dense_vector(768)), ('label', integer_value(10))]
""" """
return [(data_layer.name, data_layer.type) data_layers = self.data_layers()
for data_layer in self.data_layers()] return [(nm, data_layers[nm].type)
for nm in self.proto().input_layer_names]
def __check_layer_type__(layer): def __check_layer_type__(layer):
......
...@@ -42,25 +42,35 @@ class ITrainer(object): ...@@ -42,25 +42,35 @@ class ITrainer(object):
class SGD(ITrainer): class SGD(ITrainer):
def __init__(self, update_equation): def __init__(self, cost, parameters, update_equation):
""" """
Simple SGD Trainer. Simple SGD Trainer.
:param update_equation: The optimizer object. :param update_equation: The optimizer object.
:type update_equation: v2_optimizer.Optimizer :type update_equation: v2_optimizer.Optimizer
""" """
if not isinstance(parameters, v2_parameters.Parameters):
raise TypeError('parameters should be parameters')
if not isinstance(update_equation, v2_optimizer.Optimizer): if not isinstance(update_equation, v2_optimizer.Optimizer):
raise ValueError("update equation parameter must be " raise TypeError("update equation parameter must be "
"paddle.v2.optimizer.Optimizer") "paddle.v2.optimizer.Optimizer")
topology = Topology(cost)
self.__optimizer__ = update_equation self.__optimizer__ = update_equation
self.__topology__ = topology
self.__parameters__ = parameters
self.__topology_in_proto__ = topology.proto()
self.__data_types__ = topology.data_type()
gm = api.GradientMachine.createFromConfigProto(
self.__topology_in_proto__, api.CREATE_MODE_NORMAL,
self.__optimizer__.enable_types())
assert isinstance(gm, api.GradientMachine)
parameters.append_gradient_machine(gm)
self.__gradient_machine__ = gm
self.__gradient_machine__.randParameters()
def train(self, def train(self, reader, num_passes=1, event_handler=None, reader_dict=None):
reader,
cost,
parameters,
num_passes=1,
event_handler=None,
reader_dict=None):
""" """
Training method. Will train num_passes of input data. Training method. Will train num_passes of input data.
...@@ -76,27 +86,22 @@ class SGD(ITrainer): ...@@ -76,27 +86,22 @@ class SGD(ITrainer):
if event_handler is None: if event_handler is None:
event_handler = default_event_handler event_handler = default_event_handler
topology = Topology(cost) if reader_dict is None:
reader_dict = self.default_reader_dict()
__check_train_args__(**locals()) __check_train_args__(**locals())
gm = api.GradientMachine.createFromConfigProto(
topology.proto(), api.CREATE_MODE_NORMAL,
self.__optimizer__.enable_types())
assert isinstance(gm, api.GradientMachine)
parameters.append_gradient_machine(gm)
gm.randParameters()
updater = self.__optimizer__.create_local_updater() updater = self.__optimizer__.create_local_updater()
updater.init(gm) updater.init(self.__gradient_machine__)
gm.start() self.__gradient_machine__.start()
batch_evaluator = gm.makeEvaluator() batch_evaluator = self.__gradient_machine__.makeEvaluator()
assert isinstance(batch_evaluator, api.Evaluator) assert isinstance(batch_evaluator, api.Evaluator)
pass_evaluator = gm.makeEvaluator() pass_evaluator = self.__gradient_machine__.makeEvaluator()
assert isinstance(pass_evaluator, api.Evaluator) assert isinstance(pass_evaluator, api.Evaluator)
out_args = api.Arguments.createArguments(0) out_args = api.Arguments.createArguments(0)
feeder = DataFeeder(topology.data_type(), reader_dict) feeder = DataFeeder(self.__data_types__, reader_dict)
for pass_id in xrange(num_passes): for pass_id in xrange(num_passes):
event_handler(v2_event.BeginPass(pass_id)) event_handler(v2_event.BeginPass(pass_id))
...@@ -104,16 +109,18 @@ class SGD(ITrainer): ...@@ -104,16 +109,18 @@ class SGD(ITrainer):
updater.startPass() updater.startPass()
for batch_id, data_batch in enumerate(reader()): for batch_id, data_batch in enumerate(reader()):
pass_type = updater.startBatch(len(data_batch)) pass_type = updater.startBatch(len(data_batch))
gm.forwardBackward(feeder(data_batch), out_args, pass_type) self.__gradient_machine__.forwardBackward(
feeder(data_batch), out_args, pass_type)
batch_evaluator.start() batch_evaluator.start()
event_handler( event_handler(
v2_event.BeginIteration( v2_event.BeginIteration(
pass_id=pass_id, batch_id=batch_id)) pass_id=pass_id, batch_id=batch_id))
pass_type = updater.startBatch(len(data_batch)) pass_type = updater.startBatch(len(data_batch))
gm.forwardBackward(feeder(data_batch), out_args, pass_type) self.__gradient_machine__.forwardBackward(
gm.eval(pass_evaluator) feeder(data_batch), out_args, pass_type)
gm.eval(batch_evaluator) self.__gradient_machine__.eval(pass_evaluator)
for each_param in gm.getParameters(): self.__gradient_machine__.eval(batch_evaluator)
for each_param in self.__gradient_machine__.getParameters():
updater.update(each_param) updater.update(each_param)
# Get cost. We use numpy to calculate total cost for this batch. # Get cost. We use numpy to calculate total cost for this batch.
cost_vec = out_args.getSlotValue(0) cost_vec = out_args.getSlotValue(0)
...@@ -131,22 +138,37 @@ class SGD(ITrainer): ...@@ -131,22 +138,37 @@ class SGD(ITrainer):
updater.finishPass() updater.finishPass()
pass_evaluator.finish() pass_evaluator.finish()
event_handler(v2_event.EndPass(pass_id, evaluator=pass_evaluator)) event_handler(v2_event.EndPass(pass_id, evaluator=pass_evaluator))
gm.finish() self.__gradient_machine__.finish()
def default_reader_dict(self):
reader_dict = dict()
for i, tp in enumerate(self.__data_types__):
reader_dict[tp[0]] = i
return reader_dict
def test(self, reader, reader_dict=None):
if reader_dict is None:
reader_dict = self.default_reader_dict()
feeder = DataFeeder(self.__data_types__, reader_dict)
evaluator = self.__gradient_machine__.makeEvaluator()
out_args = api.Arguments.createArguments(0)
evaluator.start()
for data_batch in reader():
self.__gradient_machine__.forward(
feeder(data_batch), out_args, api.PASS_TEST)
self.__gradient_machine__.eval(evaluator)
evaluator.finish()
return v2_event.TestResult(evaluator=evaluator)
def __check_train_args__(reader, topology, parameters, event_handler, **kwargs): def __check_train_args__(reader, event_handler, **kwargs):
""" """
Check train function's argument types Check train function's argument types
""" """
if not callable(reader) or not isinstance(reader(), collections.Iterator): if not callable(reader) or not isinstance(reader(), collections.Iterator):
raise TypeError('train_data_reader should be a function, ' raise TypeError('train_data_reader should be a function, '
'which can return a iterator') 'which can return a iterator')
if not isinstance(topology, Topology):
raise TypeError('topology should be a model config')
if not isinstance(parameters, v2_parameters.Parameters):
raise TypeError('parameters should be a parameter pool')
if not callable(event_handler): if not callable(event_handler):
raise TypeError('event handler should be a function') raise TypeError('event handler should be a function')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册