Commit 03b7e92f authored by qingqing01

Put inputs and labels setting in prepare

Parent 44291186
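In short, this commit moves the static-graph `Input` declarations out of the network constructor and into `Model.prepare()`. A minimal sketch of the new call site, assembled from the diff below (the `from mnist import MNIST` path is an assumption for illustration; everything else appears verbatim in this commit):

```python
from paddle.fluid.optimizer import Momentum
from model import Model, CrossEntropy, Input
from mnist import MNIST  # hypothetical import path; MNIST is defined below

model = MNIST()  # the constructor no longer takes inputs/targets
optim = Momentum(learning_rate=1e-3, momentum=.9,
                 parameter_list=model.parameters())

# Input/label specs now travel with the optimizer and loss into prepare():
inputs = [Input([None, 1, 28, 28], 'float32', name='image')]
labels = [Input([None, 1], 'int64', name='label')]
model.prepare(optim, CrossEntropy(), inputs, labels)
```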
@@ -76,8 +76,8 @@ class SimpleImgConvPool(fluid.dygraph.Layer):
 
 class MNIST(Model):
-    def __init__(self, inputs=None, targets=None):
-        super(MNIST, self).__init__(inputs, targets)
+    def __init__(self):
+        super(MNIST, self).__init__()
         self._simple_img_conv_pool_1 = SimpleImgConvPool(
             1, 20, 5, 2, 2, act="relu")
@@ -140,15 +140,15 @@ def main():
     device_ids = list(range(FLAGS.num_devices))
 
     with guard:
-        inputs = [Input([None, 1, 28, 28], 'float32', name='image')]
-        labels = [Input([None, 1], 'int64', name='label')]
-        model = MNIST(inputs, labels)
-        #model = MNIST()
+        model = MNIST()
         optim = Momentum(
             learning_rate=FLAGS.lr,
             momentum=.9,
             parameter_list=model.parameters())
-        model.prepare(optim, CrossEntropy())
+        inputs = [Input([None, 1, 28, 28], 'float32', name='image')]
+        #inputs = {'inputs': Input([None, 1, 28, 28], 'float32', name='image')}
+        labels = [Input([None, 1], 'int64', name='label')]
+        model.prepare(optim, CrossEntropy(), inputs, labels)
 
         if FLAGS.resume is not None:
             model.load(FLAGS.resume)
@@ -199,7 +199,7 @@ if __name__ == '__main__':
     parser.add_argument(
         "-b", "--batch_size", default=128, type=int, help="batch size")
     parser.add_argument(
-        "-n", "--num_devices", default=4, type=int, help="number of devices")
+        "-n", "--num_devices", default=1, type=int, help="number of devices")
     parser.add_argument(
         "-r",
         "--resume",
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import argparse
import contextlib
import os
import numpy as np
import paddle
from paddle import fluid
from paddle.fluid.optimizer import Momentum
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
from model import Model, CrossEntropy, Input
class SimpleImgConvPool(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
filter_size,
pool_size,
pool_stride,
pool_padding=0,
pool_type='max',
global_pooling=False,
conv_stride=1,
conv_padding=0,
conv_dilation=1,
conv_groups=None,
act=None,
use_cudnn=False,
param_attr=None,
bias_attr=None):
super(SimpleImgConvPool, self).__init__('SimpleConv')
        self._conv2d = Conv2D(
            num_channels=num_channels,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=conv_stride,
            padding=conv_padding,
            dilation=conv_dilation,
            groups=conv_groups,
            param_attr=param_attr,
            bias_attr=bias_attr,
            act=act,
            use_cudnn=use_cudnn)
self._pool2d = Pool2D(
pool_size=pool_size,
pool_type=pool_type,
pool_stride=pool_stride,
pool_padding=pool_padding,
global_pooling=global_pooling,
use_cudnn=use_cudnn)
def forward(self, inputs):
x = self._conv2d(inputs)
x = self._pool2d(x)
return x
class MNIST(Model):
def __init__(self, inputs=None):
super(MNIST, self).__init__(inputs)
self._simple_img_conv_pool_1 = SimpleImgConvPool(
1, 20, 5, 2, 2, act="relu")
self._simple_img_conv_pool_2 = SimpleImgConvPool(
20, 50, 5, 2, 2, act="relu")
        pool_2_shape = 50 * 4 * 4
        SIZE = 10
        scale = (2.0 / (pool_2_shape**2 * SIZE))**0.5
        self._fc = Linear(
            pool_2_shape,
            SIZE,
            param_attr=fluid.param_attr.ParamAttr(
                initializer=fluid.initializer.NormalInitializer(
                    loc=0.0, scale=scale)),
            act="softmax")
def forward(self, inputs, label):
x = self._simple_img_conv_pool_1(inputs)
x = self._simple_img_conv_pool_2(x)
x = fluid.layers.flatten(x, axis=1)
x = self._fc(x)
loss = fluid.layers.cross_entropy(x, label)
loss = fluid.layers.mean(loss)
self.set_loss(loss)
return x, loss
def accuracy(pred, label, topk=(1, )):
maxk = max(topk)
pred = np.argsort(pred)[:, ::-1][:, :maxk]
correct = (pred == np.repeat(label, maxk, 1))
batch_size = label.shape[0]
res = []
for k in topk:
correct_k = correct[:, :k].sum()
res.append(100.0 * correct_k / batch_size)
return res
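A quick worked check of the `accuracy` helper above (pure NumPy, assuming the function is in scope). `np.argsort` sorts ascending, so the `[:, ::-1][:, :maxk]` slice yields the top-k class indices per row:

```python
import numpy as np

# Three samples, four classes; labels shaped (N, 1) as in the training loop.
pred = np.array([[0.1, 0.2, 0.6, 0.1],    # top-2: classes 2, 1
                 [0.8, 0.1, 0.05, 0.05],  # top-2: classes 0, 1
                 [0.3, 0.4, 0.2, 0.1]])   # top-2: classes 1, 0
label = np.array([[2], [0], [0]])

# Samples 0 and 1 are correct at k=1; sample 2 only at k=2.
print(accuracy(pred, label, topk=(1, 2)))  # -> [66.66..., 100.0]
```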
def main():
@contextlib.contextmanager
def null_guard():
yield
guard = fluid.dygraph.guard() if FLAGS.dynamic else null_guard()
if not os.path.exists('mnist_checkpoints'):
os.mkdir('mnist_checkpoints')
train_loader = fluid.io.xmap_readers(
lambda b: [np.array([x[0] for x in b]).reshape(-1, 1, 28, 28),
np.array([x[1] for x in b]).reshape(-1, 1)],
paddle.batch(fluid.io.shuffle(paddle.dataset.mnist.train(), 6e4),
batch_size=FLAGS.batch_size, drop_last=True), 1, 1)
val_loader = fluid.io.xmap_readers(
lambda b: [np.array([x[0] for x in b]).reshape(-1, 1, 28, 28),
np.array([x[1] for x in b]).reshape(-1, 1)],
paddle.batch(paddle.dataset.mnist.test(),
batch_size=FLAGS.batch_size, drop_last=True), 1, 1)
device_ids = list(range(FLAGS.num_devices))
add_loss = True
with guard:
inputs = [
Input(
[None, 1, 28, 28], 'float32', name='image'),
Input(
[None, 1], 'int64', name='label'),
]
model = MNIST(inputs)
optim = Momentum(
learning_rate=FLAGS.lr,
momentum=.9,
parameter_list=model.parameters())
model.prepare(optim)
if FLAGS.resume is not None:
model.load(FLAGS.resume)
for e in range(FLAGS.epoch):
train_loss = 0.0
train_acc = 0.0
val_loss = 0.0
val_acc = 0.0
print("======== train epoch {} ========".format(e))
for idx, batch in enumerate(train_loader()):
outputs, losses = model.train(
batch, device='gpu', device_ids=device_ids)
acc = accuracy(outputs[0], batch[1])[0]
train_loss += np.sum(losses)
train_acc += acc
if idx % 10 == 0:
print("{:04d}: loss {:0.3f} top1: {:0.3f}%".format(
idx, train_loss / (idx + 1), train_acc / (idx + 1)))
print("======== eval epoch {} ========".format(e))
for idx, batch in enumerate(val_loader()):
outputs, losses = model.eval(
batch, device='gpu', device_ids=device_ids)
acc = accuracy(outputs[0], batch[1])[0]
val_loss += np.sum(losses)
val_acc += acc
if idx % 10 == 0:
print("{:04d}: loss {:0.3f} top1: {:0.3f}%".format(
idx, val_loss / (idx + 1), val_acc / (idx + 1)))
model.save('mnist_checkpoints/{:02d}'.format(e))
if __name__ == '__main__':
parser = argparse.ArgumentParser("CNN training on MNIST")
parser.add_argument(
"-d", "--dynamic", action='store_true', help="enable dygraph mode")
parser.add_argument(
"-e", "--epoch", default=100, type=int, help="number of epoch")
parser.add_argument(
'--lr',
'--learning-rate',
default=1e-3,
type=float,
metavar='LR',
help='initial learning rate')
parser.add_argument(
"-b", "--batch_size", default=128, type=int, help="batch size")
parser.add_argument(
"-n", "--num_devices", default=4, type=int, help="number of devices")
parser.add_argument(
"-r",
"--resume",
default=None,
type=str,
help="checkpoint path to resume")
FLAGS = parser.parse_args()
main()
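With the flags above, a typical run would be `python mnist.py -d -e 10` (dygraph mode) or `python mnist.py -n 4` (static graph across four devices); these invocations are inferred from the argparse definitions rather than from any documented usage.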
@@ -30,16 +30,6 @@ from paddle.fluid.dygraph.base import to_variable
 __all__ = ['Model', 'Loss', 'CrossEntropy', 'Input']
 
 
-class Input(fluid.dygraph.Layer):
-    def __init__(self, shape=None, dtype=None, name=None):
-        self.shape = shape
-        self.dtype = dtype
-        self.name = name
-
-    def forward(self):
-        return fluid.data(self.name, shape=self.shape, dtype=self.dtype)
-
-
 def to_list(value):
     if value is None:
         return value
@@ -56,6 +46,23 @@ def to_numpy(var):
     return np.array(t)
 
 
+def extract_args(func):
+    if hasattr(inspect, 'getfullargspec'):
+        return inspect.getfullargspec(func)[0]
+    else:
+        return inspect.getargspec(func)[0]
+
+
+class Input(fluid.dygraph.Layer):
+    def __init__(self, shape=None, dtype=None, name=None):
+        self.shape = shape
+        self.dtype = dtype
+        self.name = name
+
+    def forward(self):
+        return fluid.data(self.name, shape=self.shape, dtype=self.dtype)
+
+
 class Loss(object):
     def __init__(self, average=True):
         super(Loss, self).__init__()
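`extract_args` exists so that dict-form inputs can be matched to the positional arguments of `forward` (see the `_run` hunk below, which filters out `'self'`). A standalone sketch of what it returns, with a hypothetical `ToyModel` standing in for a real `Model` subclass:

```python
import inspect

def extract_args(func):
    # getfullargspec exists on Python 3; getargspec is the Python 2 fallback.
    if hasattr(inspect, 'getfullargspec'):
        return inspect.getfullargspec(func)[0]
    else:
        return inspect.getargspec(func)[0]

class ToyModel(object):  # hypothetical stand-in, for illustration only
    def forward(self, image, label):
        return image, label

print(extract_args(ToyModel.forward))  # ['self', 'image', 'label']
```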
@@ -231,25 +238,28 @@ class StaticGraphAdapter(object):
         t.set(ndarray, place)
 
     def _run(self, inputs, labels=None, device='CPU', device_ids=None):
-        inputs = to_list(inputs)
-        if labels is not None:
-            labels = to_list(labels)
-        assert len(inputs) == len(self.model._inputs), "number of inputs" \
-            + " does not match number of arguments of `forward` method"
-
         if self._progs.get(self.mode, None) is None:
-            if self.model._inputs is None:
-                raise ValueError("The inputs of Model must be not None.")
-            self._input_vars = [
-                k.forward() for k in to_list(self.model._inputs)
-            ]
-            self._make_program(self._input_vars)
+            if isinstance(self.model._inputs, dict):
+                ins = [self.model._inputs[n] \
+                    for n in extract_args(self.model.forward) if n != 'self']
+            else:
+                ins = self.model._inputs
+            self._input_vars[self.mode] = [k.forward() for k in to_list(ins)]
+            self._make_program(self._input_vars[self.mode])
 
         compiled_prog = self._compile_and_initialize(self._progs[self.mode],
                                                      device, device_ids)
 
+        inputs = to_list(inputs)
+        if labels is not None:
+            labels = to_list(labels)
+        assert len(inputs) == len(self._input_vars[self.mode]), "number of inputs" \
+            + " does not match number of arguments of `forward` method"
+
         feed = {}
-        input_names = [v.name for v in self._input_vars]
+        input_names = [v.name for v in self._input_vars[self.mode]]
         for idx, n in enumerate(input_names):
             # train and test may take different arguments
             if inputs[idx] is not None:
@@ -261,7 +271,7 @@ class StaticGraphAdapter(object):
         endpoints = self._endpoints[self.mode]
         fetch_list = endpoints['output']
         if 'loss' in endpoints:
-            fetch_list += endpoints['loss']
+            fetch_list = endpoints['output'] + endpoints['loss']
         num_output = len(endpoints['output'])
         out = self._executor.run(compiled_prog,
                                  feed=feed,
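The `fetch_list` change is subtle but real: `+=` on a list mutates it in place, so the old code appended the loss vars into `endpoints['output']` itself, corrupting it for subsequent calls. A minimal repro of the aliasing, assuming plain list values as in this adapter:

```python
endpoints = {'output': ['pred'], 'loss': ['ce']}
fetch_list = endpoints['output']           # fetch_list aliases the same list
fetch_list += endpoints['loss']            # in-place extend through the alias
print(endpoints['output'])                 # ['pred', 'ce'] -- corrupted

endpoints = {'output': ['pred'], 'loss': ['ce']}
fetch_list = endpoints['output'] + endpoints['loss']  # new list, no aliasing
print(endpoints['output'])                 # ['pred'] -- intact
```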
@@ -296,19 +306,10 @@ class StaticGraphAdapter(object):
         }
 
     def _get_loss(self, outputs):
-        if self.model._loss_function and self.model._loss:
-            raise ValueError(
-                "Do not set loss by model.set_loss() and "
-                "loss_function in model.prepare() at the same time.")
-        if self.model._loss_function is not None:
-            if self.model._labels is None:
-                raise ValueError("The labels of Model must be not None.")
-            label_vars = [k.forward() for k in to_list(self.model._labels)]
-            self._label_vars[self.mode] = label_vars
-            losses = self.model._loss_function(outputs, label_vars)
-        else:
-            assert self.model._loss
-            losses = to_list(self.model._loss)
+        assert self.model._loss_function
+        label_vars = [k.forward() for k in to_list(self.model._labels)]
+        self._label_vars[self.mode] = label_vars
+        losses = self.model._loss_function(outputs, label_vars)
         return losses
 
     def _compile_and_initialize(self, prog, device='CPU', device_ids=None):
@@ -415,14 +416,8 @@ class DynamicGraphAdapter(object):
         return [to_numpy(o) for o in to_list(outputs)]
 
     def _get_loss(self, outputs, labels):
-        if self.model._loss_function and self.model._loss:
-            raise ValueError(
-                "Do not set loss by model.set_loss() and "
-                "loss_function in model.prepare() at the same time.")
-        if self.model._loss_function is not None:
-            return self.model._loss_function(outputs, labels)
-        else:
-            return to_list(self.model._loss)
+        assert self.model._loss_function
+        return self.model._loss_function(outputs, labels)
 
     def parameters(self, *args, **kwargs):
         return super(Model, self.model).parameters(*args, **kwargs)
@@ -447,23 +442,13 @@ class DynamicGraphAdapter(object):
 
 class Model(fluid.dygraph.Layer):
     """
     FIXME: add more comments and usage
-
-    Args:
-        inputs (Input|list of Input|None): inputs, entry points of network,
-            could be a Input layer of lits of Input layers, or None.
-            For static graph, inputs must be set. For dynamic graph, it could
-            be None.
-        labels (Input|list of Input|None): labels, entry points of network,
-            could be a Input layer of lits of Input layers, or None.
-            For static graph, if set loss_function in Model.prepare(), it
-            must be set. Otherwise, it could be None.
     """
 
-    def __init__(self, inputs=None, labels=None):
+    def __init__(self):
         super(Model, self).__init__(self.__class__.__name__)
         self.mode = 'train'
-        self._inputs = to_list(inputs)
-        self._labels = to_list(labels)
+        self._inputs = None
+        self._labels = None
         self._loss_function = None
         self._loss_weights = None
         self._loss = None
@@ -488,22 +473,33 @@ class Model(fluid.dygraph.Layer):
     def load(self, *args, **kwargs):
         return self._adapter.load(*args, **kwargs)
 
-    def prepare(self, optimizer, loss_function=None):
+    def prepare(self, optimizer, loss_function=None, inputs=None, labels=None):
+        """
+        FIXME: add comments
+        Args:
+            inputs (Input|list|dict|None): inputs, entry points of the network;
+                could be an Input layer, a list of Input layers, a dict
+                (name: Input), or None. For static graph, inputs must be set;
+                for dynamic graph, it could be None.
+            labels (Input|list|dict|None): labels, entry points of the network;
+                could be an Input layer, a list of Input layers, or None.
+                For static graph, if loss_function is set in Model.prepare(),
+                labels must be set too; otherwise, it could be None.
+        """
         self._optimizer = optimizer
         if loss_function:
             if not isinstance(loss_function, Loss):
                 raise TypeError(
                     "'loss_function' must be sub classes of 'Loss'")
             self._loss_function = loss_function
+        if not in_dygraph_mode():
+            if not isinstance(inputs, (list, dict, Input)):
+                raise TypeError(
+                    "'inputs' must be list or dict in static graph mode")
+            if loss_function and not isinstance(labels, (list, Input)):
+                raise TypeError("'labels' must be list in static graph mode")
+        self._inputs = inputs
+        self._labels = labels
 
     def parameters(self, *args, **kwargs):
         return self._adapter.parameters(*args, **kwargs)
 
-    def set_loss(self, loss):
-        if loss and self._loss_function:
-            raise ValueError(
-                "Do not set loss by model.set_loss() and "
-                "loss_function in model.prepare() at the same time.")
-        if not isinstance(loss, (Variable, fluid.core.VarBase)):
-            raise TypeError("loss type should be a Variable or VarBase.")
-        self._loss = loss
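Putting the pieces together, a minimal sketch of the reworked contract after this commit. `MyModel` is hypothetical; `Model`, `CrossEntropy`, and `Input` come from this repo's model.py, and the dict-input form is matched to `forward`'s argument names via `extract_args` as shown in the `_run` hunk above:

```python
from paddle import fluid
from paddle.fluid.dygraph.nn import Linear
from paddle.fluid.optimizer import Momentum
from model import Model, CrossEntropy, Input

class MyModel(Model):  # hypothetical subclass, for illustration only
    def __init__(self):
        super(MyModel, self).__init__()
        self._fc = Linear(784, 10, act="softmax")

    def forward(self, image):
        return self._fc(fluid.layers.flatten(image, axis=1))

model = MyModel()
optim = Momentum(learning_rate=1e-3, momentum=.9,
                 parameter_list=model.parameters())

# Static graph: inputs are mandatory; a dict keyed by the argument names of
# forward() is also accepted and ordered via extract_args().
inputs = {'image': Input([None, 1, 28, 28], 'float32', name='image')}
labels = [Input([None, 1], 'int64', name='label')]
model.prepare(optim, CrossEntropy(), inputs, labels)

# Dygraph: the isinstance checks are skipped, so both may stay None:
# model.prepare(optim, CrossEntropy())
```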