diff --git a/python/paddle/hapi/model.py b/python/paddle/hapi/model.py
index b7813932d86f662fa19da0d802c537392f5d2f35..e64aa47e2d1f7c87795a48686ae0821f7217a9f4 100644
--- a/python/paddle/hapi/model.py
+++ b/python/paddle/hapi/model.py
@@ -1713,7 +1713,7 @@ class Model:
                 evaluation at the end of epoch. If None, will not do
                 evaluation. An instance of paddle.io.Dataset or
                 paddle.io.Dataloader is recomended. Default: None.
-            batch_size (int, optional): The batch size of train_data and eval_data. When
+            batch_size (int|list, optional): The batch size of train_data and eval_data. When
                 train_data and eval_data are both the instance of Dataloader, this
                 parameter will be ignored. Default: 1.
             epochs (int, optional): The number of epochs to train the model. Default: 1.
@@ -1836,10 +1836,20 @@ class Model:
         """
         assert train_data is not None, "train_data must be given!"
 
+        if isinstance(batch_size, (tuple, list)) and all(
+            [isinstance(x, int) for x in batch_size]
+        ):
+            assert (
+                len(batch_size) == 2
+            ), "batch_size length error, expected train_batch_size and eval_batch_size."
+            train_batch_size, eval_batch_size = batch_size
+        elif isinstance(batch_size, int):
+            train_batch_size, eval_batch_size = batch_size, batch_size
+
         if isinstance(train_data, Dataset):
             train_sampler = DistributedBatchSampler(
                 train_data,
-                batch_size=batch_size,
+                batch_size=train_batch_size,
                 shuffle=shuffle,
                 drop_last=drop_last,
             )
@@ -1855,7 +1865,7 @@ class Model:
 
         if eval_data is not None and isinstance(eval_data, Dataset):
             eval_sampler = DistributedBatchSampler(
-                eval_data, batch_size=batch_size
+                eval_data, batch_size=eval_batch_size
             )
             eval_loader = DataLoader(
                 eval_data,
diff --git a/python/paddle/tests/test_model.py b/python/paddle/tests/test_model.py
index 76a41a56caf94d1397900eb938cfdb05a65d7bcd..c20761b7cd2a43fc3dc7e40fbd08e61952a5cae1 100644
--- a/python/paddle/tests/test_model.py
+++ b/python/paddle/tests/test_model.py
@@ -312,6 +312,8 @@ class TestModel(unittest.TestCase):
             self.val_dataset, batch_size=64, num_iters=num_iters
         )
 
+        model.fit(self.train_dataset, batch_size=(64, 64), shuffle=False)
+
         train_sampler = DistributedBatchSampler(
             self.train_dataset,
             batch_size=64,
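
A minimal usage sketch of the behavior this patch enables (illustrative only; the LeNet/MNIST/optimizer setup below is an assumption for the sake of a runnable example and is not part of the diff). After the change, batch_size may still be a single int applied to both loaders, or a 2-element tuple/list giving the train and eval batch sizes separately.

# Minimal sketch of the new batch_size handling (assumption: LeNet and MNIST
# from paddle.vision are used purely for illustration, not taken from this patch).
import paddle
from paddle.vision.models import LeNet
from paddle.vision.datasets import MNIST
from paddle.vision.transforms import ToTensor

train_dataset = MNIST(mode='train', transform=ToTensor())
eval_dataset = MNIST(mode='test', transform=ToTensor())

model = paddle.Model(LeNet())
model.prepare(
    optimizer=paddle.optimizer.Adam(parameters=model.parameters()),
    loss=paddle.nn.CrossEntropyLoss(),
    metrics=paddle.metric.Accuracy(),
)

# Unchanged behavior: a single int sets both the train and eval batch size.
model.fit(train_dataset, eval_dataset, batch_size=64, epochs=1)

# New behavior from this patch: a 2-element tuple/list sets them separately,
# i.e. train_batch_size=64 and eval_batch_size=128 below.
model.fit(train_dataset, eval_dataset, batch_size=(64, 128), epochs=1)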