未验证 提交 97faf90e 编写于 作者: S shangliang Xu 提交者: GitHub

add num_iters in fit/evalate (#33986)

* add num_iters in fit/evalate, test=develop
上级 6a36977d
......@@ -1520,8 +1520,7 @@ class Model(object):
if not in_dygraph_mode():
self._adapter.prepare()
def fit(
self,
def fit(self,
train_data=None,
eval_data=None,
batch_size=1,
......@@ -1535,7 +1534,8 @@ class Model(object):
shuffle=True,
num_workers=0,
callbacks=None,
accumulate_grad_batches=1, ):
accumulate_grad_batches=1,
num_iters=None):
"""
Trains the model for a fixed number of epochs. If `eval_data` is set,
evaluation will be done at the end of each epoch.
......@@ -1581,6 +1581,9 @@ class Model(object):
accumulate_grad_batches (int): The number of batches to accumulate gradident
during training process before optimizer updates. It can mimic large batch
size. Default: 1.
num_iters (int|None): Integer number. The number of iterations to train
the model. If None, follow `epochs` to train the model, otherwise, train
the model `num_iters` times. Default: None.
Returns:
None
......@@ -1705,6 +1708,11 @@ class Model(object):
self._accumulate = accumulate_grad_batches
steps = self._len_data_loader(train_loader)
self.num_iters = num_iters
if num_iters is not None and isinstance(num_iters, int):
assert num_iters > 0, "num_iters must be greater than 0!"
epochs = (num_iters // steps) + 1
steps = min(num_iters, steps)
cbks = config_callbacks(
callbacks,
model=self,
......@@ -1742,14 +1750,14 @@ class Model(object):
cbks.on_end('train', logs)
self._test_dataloader = None
def evaluate(
self,
eval_data,
batch_size=1,
log_freq=10,
verbose=2,
num_workers=0,
callbacks=None, ):
def evaluate(self,
eval_data,
batch_size=1,
log_freq=10,
verbose=2,
num_workers=0,
callbacks=None,
num_iters=None):
"""
Evaluate the loss and metrics of the model on input dataset.
......@@ -1771,6 +1779,9 @@ class Model(object):
callbacks (Callback|None): A list of `Callback` instances to apply
during training. If None, `ProgBarLogger` and `ModelCheckpoint`
are automatically inserted. Default: None.
num_iters (int|None): Integer number. The number of iterations to
evaluate the model. If None, evaluate on whole input dataset,
otherwise, evaluate `num_iters` times. Default: None.
Returns:
dict: Result of metric. The key is the names of Metric,
value is a scalar or numpy.array.
......@@ -1820,6 +1831,11 @@ class Model(object):
metrics=self._metrics_name(), )
eval_steps = self._len_data_loader(eval_loader)
self.num_iters = num_iters
if num_iters is not None and isinstance(num_iters, int):
assert num_iters > 0, "num_iters must be greater than 0!"
eval_steps = min(num_iters, eval_steps)
self.num_iters = eval_steps
cbks.on_begin('eval',
{'steps': eval_steps,
'metrics': self._metrics_name()})
......@@ -2076,6 +2092,10 @@ class Model(object):
logs['batch_size'] = self._adapter._merge_count[mode + '_batch']
callbacks.on_batch_end(mode, step, logs)
if hasattr(self, 'num_iters') and self.num_iters is not None:
self.num_iters -= 1
if self.num_iters == 0:
break
self._reset_metrics()
if mode == 'predict':
......@@ -2091,7 +2111,7 @@ class Model(object):
one input, input_size can be tuple or InputSpec. if model have multiple
input, input_size must be a list which contain every input's shape.
Default: None.
dtypes (str, optional): if dtypes is None, 'float32' will be used, Default: None.
dtype (str, optional): if dtype is None, 'float32' will be used, Default: None.
Returns:
Dict: a summary of the network including total params and total trainable params.
......
......@@ -184,6 +184,12 @@ class TestModel(unittest.TestCase):
def test_fit_static_with_rank(self):
self.fit(False, 2, 0)
def test_fit_dynamic_with_num_iters(self):
self.fit(True, num_iters=1)
def test_fit_static_with_num_iters(self):
self.fit(False, num_iters=1)
def test_evaluate_dygraph(self):
self.evaluate(True)
......@@ -199,7 +205,7 @@ class TestModel(unittest.TestCase):
def test_prepare_context(self):
prepare_distributed_context()
def fit(self, dynamic, num_replicas=None, rank=None):
def fit(self, dynamic, num_replicas=None, rank=None, num_iters=None):
fluid.enable_dygraph(self.device) if dynamic else None
seed = 333
paddle.seed(seed)
......@@ -218,6 +224,14 @@ class TestModel(unittest.TestCase):
result = model.evaluate(self.val_dataset, batch_size=64)
np.testing.assert_allclose(result['acc'], self.acc1)
model.fit(self.train_dataset,
batch_size=64,
shuffle=False,
num_iters=num_iters)
result = model.evaluate(
self.val_dataset, batch_size=64, num_iters=num_iters)
train_sampler = DistributedBatchSampler(
self.train_dataset,
batch_size=64,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册