From a33d563c5e96624615f6c485f8b74d69a316205c Mon Sep 17 00:00:00 2001
From: parap1uie-s
Date: Fri, 18 Nov 2022 11:48:55 +0800
Subject: [PATCH] Allow to specify train_bs and eval_bs separately in
 hapi.fit() (#48032)

* Fix hAPI bug of not compatible with LayerHook

https://github.com/PaddlePaddle/Paddle/issues/47000

* Fix hAPI bug of not compatible with LayerHook

* Allow to specify train_bs and eval_bs separately in hapi.fit()

* Update model.py

* Update Model.py

* Update test_model.py

* update model.py
---
 python/paddle/hapi/model.py       | 16 +++++++++++++---
 python/paddle/tests/test_model.py |  2 ++
 2 files changed, 15 insertions(+), 3 deletions(-)

diff --git a/python/paddle/hapi/model.py b/python/paddle/hapi/model.py
index b7813932d8..e64aa47e2d 100644
--- a/python/paddle/hapi/model.py
+++ b/python/paddle/hapi/model.py
@@ -1713,7 +1713,7 @@ class Model:
                 evaluation at the end of epoch. If None, will not do evaluation.
                 An instance of paddle.io.Dataset or paddle.io.Dataloader
                 is recomended. Default: None.
-            batch_size (int, optional): The batch size of train_data and eval_data. When
+            batch_size (int|list, optional): The batch size of train_data and eval_data. When
                 train_data and eval_data are both the instance of Dataloader, this
                 parameter will be ignored. Default: 1.
             epochs (int, optional): The number of epochs to train the model. Default: 1.
@@ -1836,10 +1836,20 @@ class Model:
         """
         assert train_data is not None, "train_data must be given!"
 
+        if isinstance(batch_size, (tuple, list)) and all(
+            [isinstance(x, int) for x in batch_size]
+        ):
+            assert (
+                len(batch_size) == 2
+            ), "batch_size length error, expected train_batch_size and eval_batch_size."
+            train_batch_size, eval_batch_size = batch_size
+        elif isinstance(batch_size, int):
+            train_batch_size, eval_batch_size = batch_size, batch_size
+
         if isinstance(train_data, Dataset):
             train_sampler = DistributedBatchSampler(
                 train_data,
-                batch_size=batch_size,
+                batch_size=train_batch_size,
                 shuffle=shuffle,
                 drop_last=drop_last,
             )
@@ -1855,7 +1865,7 @@ class Model:
 
         if eval_data is not None and isinstance(eval_data, Dataset):
             eval_sampler = DistributedBatchSampler(
-                eval_data, batch_size=batch_size
+                eval_data, batch_size=eval_batch_size
             )
             eval_loader = DataLoader(
                 eval_data,
diff --git a/python/paddle/tests/test_model.py b/python/paddle/tests/test_model.py
index 76a41a56ca..c20761b7cd 100644
--- a/python/paddle/tests/test_model.py
+++ b/python/paddle/tests/test_model.py
@@ -312,6 +312,8 @@ class TestModel(unittest.TestCase):
             self.val_dataset, batch_size=64, num_iters=num_iters
         )
 
+        model.fit(self.train_dataset, batch_size=(64, 64), shuffle=False)
+
         train_sampler = DistributedBatchSampler(
             self.train_dataset,
             batch_size=64,
--
GitLab
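
Usage note: a minimal sketch of the behavior this patch enables, assuming it is applied on top of Paddle's high-level API. With the change, batch_size in Model.fit() accepts either a single int (used for both the train and eval loaders, as before) or a 2-element tuple/list of (train_batch_size, eval_batch_size). The LeNet/MNIST setup below is only illustrative scaffolding; the relevant part is the batch_size=(64, 128) argument.

    import paddle
    from paddle.vision.datasets import MNIST
    from paddle.vision.models import LeNet
    from paddle.vision.transforms import ToTensor

    # Illustrative data/model setup; any Dataset-based hapi workflow applies.
    train_dataset = MNIST(mode='train', transform=ToTensor())
    eval_dataset = MNIST(mode='test', transform=ToTensor())

    model = paddle.Model(LeNet())
    model.prepare(
        paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters()),
        paddle.nn.CrossEntropyLoss(),
        paddle.metric.Accuracy(),
    )

    # After this patch: train with batch size 64, evaluate with batch size 128.
    # Passing a plain int (e.g. batch_size=64) keeps the old shared behavior.
    model.fit(train_dataset, eval_dataset, batch_size=(64, 128), epochs=1)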