From 20d9220c9145e937da9f0fb9b99323a06b6983bc Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Tue, 28 Feb 2017 10:27:46 +0800 Subject: [PATCH] change the parameter topology of trainer to cost --- demo/mnist/api_train_v2.py | 2 +- python/paddle/v2/trainer.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py index a23ddfaca0..8a612cbc66 100644 --- a/demo/mnist/api_train_v2.py +++ b/demo/mnist/api_train_v2.py @@ -41,7 +41,7 @@ def main(): trainer.train( train_data_reader=train_reader, - topology=cost, + cost=cost, parameters=parameters, event_handler=event_handler, batch_size=32, # batch size should be refactor in Data reader diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index 3bf2128e16..be33b91080 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -62,7 +62,7 @@ class SGD(ITrainer): def train(self, train_data_reader, - topology, + cost, parameters, num_passes=1, test_data_reader=None, @@ -73,7 +73,7 @@ class SGD(ITrainer): Training method. Will train num_passes of input data. :param train_data_reader: - :param topology: cost layers, use one or more Layers to represent it. + :param cost: cost layers, to be optimized. :param parameters: The parameter pools. :param num_passes: The total train passes. :param test_data_reader: @@ -86,7 +86,7 @@ class SGD(ITrainer): if event_handler is None: event_handler = default_event_handler - topology = Topology(topology) + topology = Topology(cost) __check_train_args__(**locals()) -- GitLab