diff --git a/mnist.py b/mnist.py index 430b0a6d4a5615fb2a5887343c9b7cd2623662ad..1c1eeb79c8eaf1cde7d7e119dffadc2321c96ca0 100644 --- a/mnist.py +++ b/mnist.py @@ -76,7 +76,7 @@ class SimpleImgConvPool(fluid.dygraph.Layer): class MNIST(Model): - def __init__(self, inputs, targets=None): + def __init__(self, inputs=None, targets=None): super(MNIST, self).__init__(inputs, targets) self._simple_img_conv_pool_1 = SimpleImgConvPool( 1, 20, 5, 2, 2, act="relu") @@ -87,12 +87,13 @@ class MNIST(Model): pool_2_shape = 50 * 4 * 4 SIZE = 10 scale = (2.0 / (pool_2_shape**2 * SIZE))**0.5 - self._fc = Linear(800, - 10, - param_attr=fluid.param_attr.ParamAttr( - initializer=fluid.initializer.NormalInitializer( - loc=0.0, scale=scale)), - act="softmax") + self._fc = Linear( + 800, + 10, + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.NormalInitializer( + loc=0.0, scale=scale)), + act="softmax") def forward(self, inputs): x = self._simple_img_conv_pool_1(inputs) @@ -139,11 +140,14 @@ def main(): device_ids = list(range(FLAGS.num_devices)) with guard: - inputs=[Input([None, 1, 28, 28], 'float32', name='image')] - labels=[Input([None, 1], 'int64', name='label')] + inputs = [Input([None, 1, 28, 28], 'float32', name='image')] + labels = [Input([None, 1], 'int64', name='label')] model = MNIST(inputs, labels) - optim = Momentum(learning_rate=FLAGS.lr, momentum=.9, - parameter_list=model.parameters()) + #model = MNIST() + optim = Momentum( + learning_rate=FLAGS.lr, + momentum=.9, + parameter_list=model.parameters()) model.prepare(optim, CrossEntropy()) if FLAGS.resume is not None: model.load(FLAGS.resume) @@ -155,8 +159,8 @@ def main(): val_acc = 0.0 print("======== train epoch {} ========".format(e)) for idx, batch in enumerate(train_loader()): - outputs, losses = model.train(batch[0], batch[1], device='gpu', - device_ids=device_ids) + outputs, losses = model.train( + batch[0], batch[1], device='gpu', device_ids=device_ids) acc = accuracy(outputs[0], batch[1])[0] 
train_loss += np.sum(losses) @@ -167,8 +171,8 @@ def main(): print("======== eval epoch {} ========".format(e)) for idx, batch in enumerate(val_loader()): - outputs, losses = model.eval(batch[0], batch[1], device='gpu', - device_ids=device_ids) + outputs, losses = model.eval( + batch[0], batch[1], device='gpu', device_ids=device_ids) acc = accuracy(outputs[0], batch[1])[0] val_loss += np.sum(losses) @@ -186,14 +190,21 @@ if __name__ == '__main__': parser.add_argument( "-e", "--epoch", default=100, type=int, help="number of epoch") parser.add_argument( - '--lr', '--learning-rate', default=1e-3, type=float, metavar='LR', + '--lr', + '--learning-rate', + default=1e-3, + type=float, + metavar='LR', help='initial learning rate') parser.add_argument( "-b", "--batch_size", default=128, type=int, help="batch size") parser.add_argument( "-n", "--num_devices", default=4, type=int, help="number of devices") parser.add_argument( - "-r", "--resume", default=None, type=str, + "-r", + "--resume", + default=None, + type=str, help="checkpoint path to resume") FLAGS = parser.parse_args() main() diff --git a/mnist2.py b/mnist2.py index c0bc4a34441a6b8579cf1f0c86b8f7dc513cb857..b30aaeed79fcd8aa9047e46bb39fceac3c361bee 100644 --- a/mnist2.py +++ b/mnist2.py @@ -76,7 +76,7 @@ class SimpleImgConvPool(fluid.dygraph.Layer): class MNIST(Model): - def __init__(self, inputs): + def __init__(self, inputs=None): super(MNIST, self).__init__(inputs) self._simple_img_conv_pool_1 = SimpleImgConvPool( 1, 20, 5, 2, 2, act="relu") @@ -146,8 +146,9 @@ def main(): with guard: inputs = [ Input( - [None, 1, 28, 28], 'float32', name='image'), Input( - [None, 1], 'int64', name='label') + [None, 1, 28, 28], 'float32', name='image'), + Input( + [None, 1], 'int64', name='label'), ] model = MNIST(inputs) optim = Momentum( diff --git a/model.py b/model.py index 4d28d6ec0e6de425bae58820a3ae671804afa4cf..c932757fbe72974d79c8f99bce9576be70a02983 100644 --- a/model.py +++ b/model.py @@ -41,6 +41,8 @@ class 
Input(fluid.dygraph.Layer): def to_list(value): + if value is None: + return value if isinstance(value, (list, tuple)): return value return [value] @@ -443,11 +445,25 @@ class DynamicGraphAdapter(object): class Model(fluid.dygraph.Layer): + """ + FIXME: add more comments and usage + + Args: + inputs (Input|list of Input|None): inputs, entry points of network, + could be an Input layer or list of Input layers, or None. + For static graph, inputs must be set. For dynamic graph, it could + be None. + labels (Input|list of Input|None): labels, entry points of network, + could be an Input layer or list of Input layers, or None. + For static graph, if loss_function is set in Model.prepare(), it + must be set. Otherwise, it could be None. + """ + def __init__(self, inputs=None, labels=None): super(Model, self).__init__(self.__class__.__name__) self.mode = 'train' - self._inputs = inputs - self._labels = labels + self._inputs = to_list(inputs) + self._labels = to_list(labels) self._loss_function = None self._loss_weights = None self._loss = None