未验证 提交 a33d563c 编写于 作者: P parap1uie-s 提交者: GitHub

Allow to specify train_bs and eval_bs separately in hapi.fit() (#48032)

* Fix hAPI bug of not compatible with LayerHook

https://github.com/PaddlePaddle/Paddle/issues/47000

* Fix hAPI bug of not compatible with LayerHook

* Allow to specify train_bs and eval_bs separately in hapi.fit()

* Update model.py

* Update Model.py

* Update test_model.py

* update model.py
上级 982d5ff7
...@@ -1713,7 +1713,7 @@ class Model: ...@@ -1713,7 +1713,7 @@ class Model:
evaluation at the end of epoch. If None, will not do evaluation. evaluation at the end of epoch. If None, will not do evaluation.
An instance of paddle.io.Dataset or paddle.io.Dataloader An instance of paddle.io.Dataset or paddle.io.Dataloader
is recomended. Default: None. is recomended. Default: None.
batch_size (int, optional): The batch size of train_data and eval_data. When batch_size (int|list, optional): The batch size of train_data and eval_data. When
train_data and eval_data are both the instance of Dataloader, this train_data and eval_data are both the instance of Dataloader, this
parameter will be ignored. Default: 1. parameter will be ignored. Default: 1.
epochs (int, optional): The number of epochs to train the model. Default: 1. epochs (int, optional): The number of epochs to train the model. Default: 1.
...@@ -1836,10 +1836,20 @@ class Model: ...@@ -1836,10 +1836,20 @@ class Model:
""" """
assert train_data is not None, "train_data must be given!" assert train_data is not None, "train_data must be given!"
if isinstance(batch_size, (tuple, list)) and all(
[isinstance(x, int) for x in batch_size]
):
assert (
len(batch_size) == 2
), "batch_size length error, expected train_batch_size and eval_batch_size."
train_batch_size, eval_batch_size = batch_size
elif isinstance(batch_size, int):
train_batch_size, eval_batch_size = batch_size, batch_size
if isinstance(train_data, Dataset): if isinstance(train_data, Dataset):
train_sampler = DistributedBatchSampler( train_sampler = DistributedBatchSampler(
train_data, train_data,
batch_size=batch_size, batch_size=train_batch_size,
shuffle=shuffle, shuffle=shuffle,
drop_last=drop_last, drop_last=drop_last,
) )
...@@ -1855,7 +1865,7 @@ class Model: ...@@ -1855,7 +1865,7 @@ class Model:
if eval_data is not None and isinstance(eval_data, Dataset): if eval_data is not None and isinstance(eval_data, Dataset):
eval_sampler = DistributedBatchSampler( eval_sampler = DistributedBatchSampler(
eval_data, batch_size=batch_size eval_data, batch_size=eval_batch_size
) )
eval_loader = DataLoader( eval_loader = DataLoader(
eval_data, eval_data,
......
...@@ -312,6 +312,8 @@ class TestModel(unittest.TestCase): ...@@ -312,6 +312,8 @@ class TestModel(unittest.TestCase):
self.val_dataset, batch_size=64, num_iters=num_iters self.val_dataset, batch_size=64, num_iters=num_iters
) )
model.fit(self.train_dataset, batch_size=(64, 64), shuffle=False)
train_sampler = DistributedBatchSampler( train_sampler = DistributedBatchSampler(
self.train_dataset, self.train_dataset,
batch_size=64, batch_size=64,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册