diff --git a/mnist.py b/mnist.py index ebfad01ec7f13a45538b8d3afc1b4e7bde46a83c..401b766e9d74a1ad4890db45eba6efcf7732ae41 100644 --- a/mnist.py +++ b/mnist.py @@ -21,7 +21,7 @@ from paddle import fluid from paddle.fluid.optimizer import MomentumOptimizer from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear -from model import Model, shape_hints +from model import Model, shape_hints, CrossEntropy class SimpleImgConvPool(fluid.dygraph.Layer): @@ -132,7 +132,7 @@ if __name__ == '__main__': sgd = MomentumOptimizer(learning_rate=1e-3, momentum=0.9, parameter_list=model.parameters()) # sgd = SGDOptimizer(learning_rate=1e-3) - model.prepare(sgd, 'cross_entropy') + model.prepare(sgd, CrossEntropy()) for e in range(2): for idx, batch in enumerate(train_loader()): diff --git a/model.py b/model.py index 0b3fed861948f603553d4c5bf08a4aa490a05566..c4c460239dfb242fa6f38b0a3fc129c8c0ae4822 100644 --- a/model.py +++ b/model.py @@ -27,11 +27,7 @@ from paddle.fluid.executor import global_scope from paddle.fluid.io import is_belong_to_optimizer from paddle.fluid.dygraph.base import to_variable -__all__ = ['Model', 'shape_hints'] - -LOSS_DTYPE_MAP = { - 'cross_entropy': 'int64' -} +__all__ = ['shape_hints', 'Model', 'Loss', 'CrossEntropy'] def to_list(value): @@ -71,6 +67,45 @@ def shape_hints(**hints): return wrapper +class Loss(object): + def __init__(self, average=True): + super(Loss, self).__init__() + self.average = average + + def infer_shape(self, outputs): + return [o.shape for o in outputs] + + def infer_dtype(self, outputs): + return [o.dtype for o in outputs] + + def forward(self, outputs, labels): + raise NotImplementedError() + + def __call__(self, outputs, labels): + labels = to_list(labels) + if in_dygraph_mode(): + labels = [to_variable(l) for l in labels] + losses = to_list(self.forward(to_list(outputs), labels)) + if not self.average: + return losses + return [fluid.layers.reduce_mean(l) for l in losses] + + +class CrossEntropy(Loss): + def __init__(self): + super(CrossEntropy, self).__init__() + + def infer_shape(self, outputs): + return [o.shape[:-1] + (1, ) for o in outputs] + + def infer_dtype(self, outputs): + return ['int64' for _ in outputs] + + def forward(self, outputs, labels): + return [fluid.layers.cross_entropy(o, l) for o, l in zip( + outputs, labels)] + + class StaticGraphAdapter(object): def __init__(self, model): super(StaticGraphAdapter, self).__init__() @@ -103,13 +138,13 @@ class StaticGraphAdapter(object): self.model.mode = value def train(self, inputs, labels, device='CPU', device_ids=None): - assert self.model._optimizer and self.model._loss_functions, \ + assert self.model._optimizer and self.model._loss_function, \ "model not ready, please call `model.prepare()` first" self.mode = 'train' return self._run(inputs, labels, device, device_ids) def eval(self, inputs, labels, device='CPU', device_ids=None): - assert self.model._loss_functions, \ + assert self.model._loss_function, \ "model not ready, please call `model.prepare()` first" self.mode = 'eval' return self._run(inputs, labels, device, device_ids) @@ -249,22 +284,10 @@ class StaticGraphAdapter(object): losses = [] with fluid.program_guard(prog, self._startup_prog): outputs = to_list(self.model.forward(*inputs)) - losses = [] - label_vars = [] if self.mode != 'test': - loss_weights = self.model._loss_weights - if loss_weights is None: - loss_weights = [1. for _ in self.model._loss_functions] - for o, l, w in zip(outputs, self.model._loss_functions, - loss_weights): - if l is None: - continue - label_var = self._infer_label_var(o, l) - label_vars.append(label_var) - loss_fn = getattr(fluid.layers, l) - loss = loss_fn(o, label_var) - losses.append(fluid.layers.reduce_mean(loss) * w) + label_vars = self._infer_label_vars(outputs) self._label_vars[self.mode] = label_vars + losses = self.model._loss_function(outputs, label_vars) if self.mode == 'train': self._loss_endpoint = fluid.layers.sum(losses) self.model._optimizer.minimize(self._loss_endpoint) @@ -288,18 +311,14 @@ class StaticGraphAdapter(object): input_vars.append(fluid.data(name, shape, ndarray.dtype)) return input_vars - # TODO wrap loss in callable classes - # - same call signaure - # - infer_shape method? or same shape as y_pred (e.g., one hot) - # - split multiple dtype loss functions (e.g., soft label) - def _infer_label_var(self, output, loss): - name = output.name + '.label' - shape = output.shape - # XXX could get ugly very quickly - if loss == 'cross_entropy': - shape = shape[:-1] + (1, ) - dtype = LOSS_DTYPE_MAP.get(loss, output.dtype) - return fluid.data(name, shape, dtype) + def _infer_label_vars(self, outputs): + shapes = self.model._loss_function.infer_shape(outputs) + dtypes = self.model._loss_function.infer_dtype(outputs) + label_vars = [] + for idx, (shape, dtype) in enumerate(zip(shapes, dtypes)): + name = '__label{}'.format(idx) + label_vars.append(fluid.data(name, shape, dtype)) + return label_vars def _compile_and_initialize(self, prog, device='CPU', device_ids=None): if device.lower() == 'cpu': @@ -351,14 +370,14 @@ class DynamicGraphAdapter(object): # TODO multi device in dygraph mode not implemented at present time def train(self, inputs, labels, device='CPU', device_ids=None): - assert self.model._optimizer and self.model._loss_functions, \ + assert self.model._optimizer and self.model._loss_function, \ "model not ready, please call `model.prepare()` first" super(Model, self.model).train() self.mode = 'train' inputs = to_list(inputs) labels = to_list(labels) outputs = self.model.forward(*[to_variable(x) for x in inputs]) - losses = self._loss(outputs, labels) + losses = self.model._loss_function(outputs, labels) final_loss = fluid.layers.sum(losses) final_loss.backward() self.model._optimizer.minimize(final_loss) @@ -367,14 +386,14 @@ class DynamicGraphAdapter(object): [to_numpy(l) for l in losses] def eval(self, inputs, labels, device='CPU', device_ids=None): - assert self.model._loss_functions, \ + assert self.model._loss_function, \ "model not ready, please call `model.prepare()` first" super(Model, self.model).eval() self.mode = 'eval' inputs = to_list(inputs) labels = to_list(labels) outputs = self.model.forward(*[to_variable(x) for x in inputs]) - losses = self._loss(outputs, labels) + losses = self.model._loss_function(outputs, labels) return [to_numpy(o) for o in to_list(outputs)], \ [to_numpy(l) for l in losses] @@ -404,26 +423,12 @@ class DynamicGraphAdapter(object): return self.model._optimizer.set_dict(optim) - def _loss(self, pred, labels): - losses = [] - loss_weights = self.model._loss_weights - if loss_weights is None: - loss_weights = [1. for _ in self.model._loss_functions] - for o, l, w, t in zip(to_list(pred), self.model._loss_functions, - loss_weights, labels): - if l is None: - continue - loss_fn = getattr(fluid.layers, l) - loss = loss_fn(o, to_variable(t)) - losses.append(fluid.layers.reduce_mean(loss) * w) - return losses - class Model(fluid.dygraph.Layer): def __init__(self): super(Model, self).__init__(self.__class__.__name__) self.mode = 'train' - self._loss_functions = [] + self._loss_function = None self._loss_weights = None self._optimizer = None if in_dygraph_mode(): @@ -446,11 +451,11 @@ class Model(fluid.dygraph.Layer): def load(self, *args, **kwargs): return self._adapter.load(*args, **kwargs) - def prepare(self, optimizer, loss_functions, loss_weights=None): + def prepare(self, optimizer, loss_function): self._optimizer = optimizer - self._loss_functions = to_list(loss_functions) - if loss_weights is not None: - self._loss_weights = to_list(loss_weights) + assert isinstance(loss_function, Loss), \ + "'loss_function' must be sub classes of 'Loss'" + self._loss_function = loss_function def parameters(self, *args, **kwargs): return self._adapter.parameters(*args, **kwargs) diff --git a/resnet.py b/resnet.py index 2c2816f9d79cc225e46de3436b65b59c724547f5..be5c1f0287677fb690da7c091e70de30f2a90cb6 100644 --- a/resnet.py +++ b/resnet.py @@ -28,7 +28,7 @@ import paddle.fluid as fluid from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear -from model import Model +from model import Model, CrossEntropy def center_crop_resize(img): @@ -358,7 +358,7 @@ def main(): with guard: model = ResNet() sgd = make_optimizer(parameter_list=model.parameters()) - model.prepare(sgd, 'cross_entropy') + model.prepare(sgd, CrossEntropy()) for e in range(epoch): print("======== train epoch {} ========".format(e))