提交 994975bd 编写于 作者: L LielinJiang

remove nouse code

上级 c3ba953b
......@@ -28,7 +28,7 @@ from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
from model import Model, CrossEntropy, Input
from metrics import Accuracy
from distributed import prepare_context, all_gather, Env, get_nranks, get_local_rank, DistributedBatchSampler, to_numpy
from distributed import prepare_context, Env, get_nranks, DistributedBatchSampler
from paddle.fluid.io import BatchSampler, DataLoader, MnistDataset
class SimpleImgConvPool(fluid.dygraph.Layer):
......@@ -112,7 +112,8 @@ class CustromMnistDataset(MnistDataset):
label_filename=None,
mode='train',
download=True):
super(CustromMnistDataset, self).__init__(image_filename, label_filename, mode, download)
super(CustromMnistDataset, self).__init__(image_filename,
label_filename, mode, download)
def __getitem__(self, idx):
......@@ -135,7 +136,6 @@ def main():
os.mkdir('mnist_checkpoints')
with guard:
train_dataset = CustromMnistDataset(mode='train')
val_dataset = CustromMnistDataset(mode='test')
......
......@@ -350,6 +350,7 @@ class StaticGraphAdapter(object):
if self.mode != 'train' and self.model._test_dataloader is not None \
and self._nranks > 1:
total_size = len(self.model._test_dataloader.dataset)
# TODO: fixme if have better way to get batch size
samples = state[0].shape[0]
current_count = self._merge_count.get(self.mode, 0)
if current_count + samples > total_size:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册