提交 37d54cb7 编写于 作者: Y Yu Yang

Merge branch 'feature/EvaluatorToEvent' into feature/clean_mnist_v2

......@@ -18,24 +18,18 @@ def main():
cost = paddle.layer.classification_cost(input=inference, label=label)
parameters = paddle.parameters.create(cost)
for param_name in parameters.keys():
array = parameters.get(param_name)
array[:] = numpy.random.uniform(low=-1.0, high=1.0, size=array.shape)
parameters.set(parameter_name=param_name, value=array)
adam_optimizer = paddle.optimizer.Adam(learning_rate=0.01)
def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
para = parameters.get('___fc_2__.w0')
print "Pass %d, Batch %d, Cost %f, Weight Mean Of Fc 2 is %f" % (
event.pass_id, event.batch_id, event.cost, para.mean())
if event.batch_id % 100 == 0:
print "Pass %d, Batch %d, Cost %f, %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics)
else:
pass
trainer = paddle.trainer.SGD(update_equation=adam_optimizer)
trainer.train(
reader=paddle.reader.batched(
paddle.reader.shuffle(paddle.dataset.mnist.train_creator(),
......
......@@ -3,8 +3,6 @@ All training events.
There are:
* BeginTraining
* EndTraining
* BeginIteration
* EndIteration
* BeginPass
......@@ -12,15 +10,62 @@ There are:
TODO(yuyang18): Complete it!
"""
__all__ = ['EndIteration']
import py_paddle.swig_paddle as api
__all__ = ['EndIteration', 'BeginIteration', 'BeginPass', 'EndPass']
class EndIteration(object):
class WithMetric(object):
def __init__(self, evaluator):
if not isinstance(evaluator, api.Evaluator):
raise TypeError("Evaluator should be api.Evaluator type")
self.__evaluator__ = evaluator
@property
def metrics(self):
names = self.__evaluator__.getNames()
retv = dict()
for each_name in names:
val = self.__evaluator__.getValue(each_name)
retv[each_name] = val
return retv
class BeginPass(object):
"""
Event On One Pass Training Start.
"""
def __init__(self, pass_id):
self.pass_id = pass_id
class EndPass(WithMetric):
"""
Event On One Pass Training Complete.
"""
def __init__(self, pass_id, evaluator):
self.pass_id = pass_id
WithMetric.__init__(self, evaluator)
class BeginIteration(object):
"""
Event On One Batch Training Start.
"""
def __init__(self, pass_id, batch_id):
self.pass_id = pass_id
self.batch_id = batch_id
class EndIteration(WithMetric):
"""
Event On One Batch Training Complete.
"""
def __init__(self, pass_id, batch_id, cost):
def __init__(self, pass_id, batch_id, cost, evaluator):
self.pass_id = pass_id
self.batch_id = batch_id
self.cost = cost
WithMetric.__init__(self, evaluator)
......@@ -87,20 +87,34 @@ class SGD(ITrainer):
topology, api.CREATE_MODE_NORMAL, self.__optimizer__.enable_types())
assert isinstance(gm, api.GradientMachine)
parameters.append_gradient_machine(gm)
gm.randParameters()
updater = self.__optimizer__.create_local_updater()
updater.init(gm)
gm.start()
batch_evaluator = gm.makeEvaluator()
assert isinstance(batch_evaluator, api.Evaluator)
pass_evaluator = gm.makeEvaluator()
assert isinstance(pass_evaluator, api.Evaluator)
out_args = api.Arguments.createArguments(0)
feeder = DataFeeder(data_types, reader_dict)
for pass_id in xrange(num_passes):
event_handler(v2_event.BeginPass(pass_id))
pass_evaluator.start()
updater.startPass()
for batch_id, data_batch in enumerate(reader()):
pass_type = updater.startBatch(len(data_batch))
gm.forwardBackward(feeder(data_batch), out_args, pass_type)
batch_evaluator.start()
event_handler(
v2_event.BeginIteration(
pass_id=pass_id, batch_id=batch_id))
pass_type = updater.startBatch(len(data_batch))
gm.forwardBackward(feeder(data_batch), out_args, pass_type)
gm.eval(pass_evaluator)
gm.eval(batch_evaluator)
for each_param in gm.getParameters():
updater.update(each_param)
# Get cost. We use numpy to calculate total cost for this batch.
......@@ -108,11 +122,17 @@ class SGD(ITrainer):
cost_vec = cost_vec.copyToNumpyMat()
cost = cost_vec.sum() / len(data_batch)
updater.finishBatch(cost)
batch_evaluator.finish()
event_handler(
v2_event.EndIteration(
pass_id=pass_id, batch_id=batch_id, cost=cost))
pass_id=pass_id,
batch_id=batch_id,
cost=cost,
evaluator=batch_evaluator))
updater.finishPass()
pass_evaluator.finish()
event_handler(v2_event.EndPass(pass_id, evaluator=pass_evaluator))
gm.finish()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册