Unverified commit a33d563c, authored by parap1uie-s, committed by GitHub

Allow to specify train_bs and eval_bs separately in hapi.fit() (#48032)

* Fix hAPI bug of incompatibility with LayerHook

https://github.com/PaddlePaddle/Paddle/issues/47000

* Fix hAPI bug of incompatibility with LayerHook

* Allow to specify train_bs and eval_bs separately in hapi.fit()

* Update model.py

* Update Model.py

* Update test_model.py

* update model.py
Parent 982d5ff7
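A minimal usage sketch of the change (the network, dataset, and batch sizes below are illustrative, not taken from the patch): after this commit, `batch_size` may be a single int applied to both loaders, or a 2-element list/tuple giving the train and eval batch sizes separately.

```python
import numpy as np
import paddle
from paddle.io import Dataset


class RandomDataset(Dataset):
    # Tiny synthetic dataset: 10-d features, scalar regression targets.
    def __init__(self, num_samples=256):
        self.x = np.random.randn(num_samples, 10).astype('float32')
        self.y = np.random.randn(num_samples, 1).astype('float32')

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

    def __len__(self):
        return len(self.x)


model = paddle.Model(paddle.nn.Linear(10, 1))
model.prepare(
    optimizer=paddle.optimizer.Adam(parameters=model.parameters()),
    loss=paddle.nn.MSELoss(),
)

# Previously batch_size had to be one int shared by the train and eval
# loaders; now a 2-element tuple/list sets them separately.
model.fit(
    RandomDataset(),
    eval_data=RandomDataset(64),
    batch_size=(64, 32),  # (train_batch_size, eval_batch_size)
    epochs=1,
    verbose=0,
)
```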
@@ -1713,7 +1713,7 @@ class Model:
                 evaluation at the end of epoch. If None, will not do evaluation.
                 An instance of paddle.io.Dataset or paddle.io.Dataloader
                 is recommended. Default: None.
-            batch_size (int, optional): The batch size of train_data and eval_data. When
+            batch_size (int|list|tuple, optional): The batch size of train_data and eval_data. When
                 train_data and eval_data are both the instance of Dataloader, this
                 parameter will be ignored. Default: 1.
             epochs (int, optional): The number of epochs to train the model. Default: 1.
@@ -1836,10 +1836,20 @@ class Model:
         """
         assert train_data is not None, "train_data must be given!"
+        if isinstance(batch_size, (tuple, list)) and all(
+            [isinstance(x, int) for x in batch_size]
+        ):
+            assert (
+                len(batch_size) == 2
+            ), "batch_size length error, expected train_batch_size and eval_batch_size."
+            train_batch_size, eval_batch_size = batch_size
+        elif isinstance(batch_size, int):
+            train_batch_size, eval_batch_size = batch_size, batch_size
         if isinstance(train_data, Dataset):
             train_sampler = DistributedBatchSampler(
                 train_data,
-                batch_size=batch_size,
+                batch_size=train_batch_size,
                 shuffle=shuffle,
                 drop_last=drop_last,
             )
@@ -1855,7 +1865,7 @@ class Model:
         if eval_data is not None and isinstance(eval_data, Dataset):
             eval_sampler = DistributedBatchSampler(
-                eval_data, batch_size=batch_size
+                eval_data, batch_size=eval_batch_size
             )
             eval_loader = DataLoader(
                 eval_data,
......
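One caveat in the hunk above: if `batch_size` is neither an int nor a list/tuple of ints (say, a float), neither branch assigns `train_batch_size`, and the later `DistributedBatchSampler` call raises an `UnboundLocalError`. A defensive standalone sketch of the same normalization (the helper name is hypothetical, not part of the Paddle API):

```python
def _split_batch_size(batch_size):
    # Hypothetical helper mirroring the logic added in fit(); not Paddle API.
    # Returns (train_batch_size, eval_batch_size).
    if isinstance(batch_size, (tuple, list)) and all(
        isinstance(x, int) for x in batch_size
    ):
        assert (
            len(batch_size) == 2
        ), "batch_size length error, expected train_batch_size and eval_batch_size."
        return tuple(batch_size)
    if isinstance(batch_size, int):
        return batch_size, batch_size
    # The patch has no fallthrough branch; failing fast makes the contract explicit.
    raise TypeError(
        f"batch_size must be an int or a list/tuple of two ints, got {batch_size!r}"
    )


assert _split_batch_size(64) == (64, 64)
assert _split_batch_size((64, 32)) == (64, 32)
```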
@@ -312,6 +312,8 @@ class TestModel(unittest.TestCase):
             self.val_dataset, batch_size=64, num_iters=num_iters
         )
+        model.fit(self.train_dataset, batch_size=(64, 64), shuffle=False)
         train_sampler = DistributedBatchSampler(
             self.train_dataset,
             batch_size=64,
......
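The truncated test presumably asserts that the two loaders pick up their own batch sizes; that contract can also be checked standalone with `DistributedBatchSampler` directly (the dataset below is illustrative):

```python
import numpy as np
from paddle.io import Dataset, DistributedBatchSampler


class RandomDataset(Dataset):
    def __init__(self, num_samples=256):
        self.x = np.random.randn(num_samples, 10).astype('float32')

    def __getitem__(self, idx):
        return self.x[idx]

    def __len__(self):
        return len(self.x)


ds = RandomDataset(256)
train_sampler = DistributedBatchSampler(ds, batch_size=64)
eval_sampler = DistributedBatchSampler(ds, batch_size=32)

# In a single-process run, 256 samples split into 4 train batches of 64
# and 8 eval batches of 32.
assert len(train_sampler) == 4
assert len(eval_sampler) == 8
```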