diff --git a/mnist.py b/mnist.py
index 617fd18586b332c05e293b798d7b7dca13322080..ebfad01ec7f13a45538b8d3afc1b4e7bde46a83c 100644
--- a/mnist.py
+++ b/mnist.py
@@ -136,9 +136,12 @@ if __name__ == '__main__':

     for e in range(2):
         for idx, batch in enumerate(train_loader()):
-            out = model.train(batch[0], batch[1], device='gpu',
-                              device_ids=[0, 1, 2, 3])
+            out, loss = model.train(batch[0], batch[1], device='gpu',
+                                    device_ids=[0, 1, 2, 3])
+            print("=============== output =========")
             print(out)
+            print("=============== loss ===========")
+            print(loss)
             if idx > 10:
                 model.save("test.{}".format(e))
                 break
diff --git a/model.py b/model.py
index a0f9a2b8c9b92061685d8bdde01712a0ee8ee399..8c7bdbdefedbf2c3e1830498818a58e175e4c84a 100644
--- a/model.py
+++ b/model.py
@@ -40,6 +40,14 @@ def to_list(value):
     return [value]


+def to_numpy(var):
+    assert isinstance(var, (Variable, fluid.core.VarBase)), "not a variable"
+    if isinstance(var, fluid.core.VarBase):
+        return var.numpy()
+    t = global_scope().find_var(var.name).get_tensor()
+    return np.array(t)
+
+
 def extract_args(func):
     if hasattr(inspect, 'getfullargspec'):
         return inspect.getfullargspec(func)[0]
@@ -115,15 +123,10 @@ class StaticGraphAdapter(object):

     def save(self, path):
         def _save(state, path):
-            def to_numpy(var):
-                if not isinstance(var, Variable):
-                    return var
-                t = global_scope().find_var(var.name).get_tensor()
-                return np.array(t)
-
             if not state:
                 return
-            state = {k: to_numpy(v) for k, v in state.items()}
+            state = {k: to_numpy(v) if isinstance(v, Variable) else v
+                     for k, v in state.items()}
             with open(path, 'wb') as f:
                 pickle.dump(state, f)

@@ -226,18 +229,24 @@ class StaticGraphAdapter(object):
         for idx, v in enumerate(self._label_vars):
             feed[v.name] = labels[idx]

-        outputs = self._executor.run(
+        endpoints = self._endpoints[self.mode]
+        fetch_list = endpoints['output'] + endpoints['loss']
+        num_output = len(endpoints['output'])
+        out = self._executor.run(
             compiled_prog, feed=feed,
-            fetch_list=self._endpoints[self.mode])
-        return outputs
+            fetch_list=fetch_list)
+        if self.mode == 'test':
+            return out[:num_output]
+        else:
+            return out[:num_output], out[num_output:]

     def _make_program(self, inputs):
         prog = self._main_prog.clone(self.mode != 'train')
         with fluid.program_guard(prog, self._startup_prog):
             outputs = to_list(self.model.forward(*inputs))
+            losses = []
             label_vars = []
             if self.mode != 'test':
-                losses = []
                 loss_weights = self.model._loss_weights
                 if loss_weights is None:
                     loss_weights = [1. for _ in self.model._loss_functions]
@@ -250,13 +259,15 @@ class StaticGraphAdapter(object):
                     loss_fn = getattr(fluid.layers, l)
                     loss = loss_fn(o, label_var)
                     losses.append(fluid.layers.reduce_mean(loss) * w)
-                outputs = losses
             if self.mode == 'train':
                 self._label_vars = label_vars
                 self._loss_endpoint = fluid.layers.sum(losses)
                 self.model._optimizer.minimize(self._loss_endpoint)
         self._progs[self.mode] = prog
-        self._endpoints[self.mode] = outputs
+        self._endpoints[self.mode] = {
+            "output": outputs,
+            "loss": losses
+        }

     def _infer_input_vars(self, inputs):
         input_vars = []
@@ -346,7 +357,8 @@ class DynamicGraphAdapter(object):
         final_loss.backward()
         self.model._optimizer.minimize(final_loss)
         self.model.clear_gradients()
-        return losses
+        return [to_numpy(o) for o in to_list(outputs)], \
+            [to_numpy(l) for l in losses]

     def eval(self, inputs, labels, device='CPU', device_ids=None):
         assert self.model._loss_functions, \
@@ -356,13 +368,16 @@ class DynamicGraphAdapter(object):
         inputs = to_list(inputs)
         labels = to_list(labels)
         outputs = self.model.forward(*[to_variable(x) for x in inputs])
-        return self._loss(outputs, labels)
+        losses = self._loss(outputs, labels)
+        return [to_numpy(o) for o in to_list(outputs)], \
+            [to_numpy(l) for l in losses]

     def test(self, inputs, device='CPU', device_ids=None):
         super(Model, self.model).train()
         self.mode = 'test'
-        inputs = to_list(inputs)
-        return self.model.forward(*[to_variable(x) for x in inputs])
+        inputs = [to_variable(x) for x in to_list(inputs)]
+        outputs = self.model.forward(*inputs)
+        return [to_numpy(o) for o in to_list(outputs)]

     def parameters(self, *args, **kwargs):
         return super(Model, self.model).parameters(*args, **kwargs)
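
For illustration, a minimal usage sketch of the new return convention (not part of the patch): Model.train and Model.eval now return a pair of numpy-array lists, (outputs, losses), while Model.test returns outputs only. It assumes Model.eval and Model.test forward to the adapter methods above the same way Model.train does in mnist.py; val_loader is a hypothetical reader analogous to train_loader.

    for idx, batch in enumerate(val_loader()):   # hypothetical eval reader
        outputs, losses = model.eval(batch[0], batch[1], device='gpu',
                                     device_ids=[0, 1, 2, 3])
        # outputs / losses are lists of numpy arrays, one per endpoint
        print(outputs[0].shape, losses)

    preds = model.test(batch[0], device='gpu', device_ids=[0, 1, 2, 3])
    # test mode computes no loss, only the forward outputs
    print(preds[0].shape)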