diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py index f925d70765e1c6700ca5ab4b8cf9369743e43947..2c0394aa0b189780c3fa1bf92e0baf215e033f45 100644 --- a/demo/mnist/api_train_v2.py +++ b/demo/mnist/api_train_v2.py @@ -25,8 +25,8 @@ def network_config(): def main(): paddle.init(use_gpu=False, trainer_count=1) - model_config = parse_network_config(network_config) - parameters = paddle.parameters.create(model_config) + topology = parse_network_config(network_config) + parameters = paddle.parameters.create(topology) for param_name in parameters.keys(): array = parameters[param_name] array[:] = numpy.random.uniform(low=-1.0, high=1.0, size=array.shape) @@ -47,7 +47,7 @@ def main(): trainer = paddle.trainer.SGD(update_equation=adam_optimizer) trainer.train(train_data_reader=train_reader, - topology=model_config, + topology=topology, parameters=parameters, event_handler=event_handler, batch_size=32, # batch size should be refactor in Data reader