diff --git a/python/paddle/hapi/model.py b/python/paddle/hapi/model.py index f4d3a78e250ac5555899a6c0151ee8ecbc69a524..b8de4ee05ab128f2a8791b7aa0116c880d3881e8 100644 --- a/python/paddle/hapi/model.py +++ b/python/paddle/hapi/model.py @@ -298,10 +298,11 @@ class StaticGraphAdapter(object): def mode(self, value): self.model.mode = value - def train_batch(self, inputs, labels=None): + def train_batch(self, inputs, labels=None, update=True): assert self.model._optimizer, \ "model not ready, please call `model.prepare()` first" self.mode = 'train' + assert update is True, "Model does not support `update == False` in static mode by now." return self._run(inputs, labels) def eval_batch(self, inputs, labels=None):