提交 994975bd 编写于 作者: L LielinJiang

remove nouse code

上级 c3ba953b
...@@ -28,7 +28,7 @@ from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear ...@@ -28,7 +28,7 @@ from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
from model import Model, CrossEntropy, Input from model import Model, CrossEntropy, Input
from metrics import Accuracy from metrics import Accuracy
from distributed import prepare_context, all_gather, Env, get_nranks, get_local_rank, DistributedBatchSampler, to_numpy from distributed import prepare_context, Env, get_nranks, DistributedBatchSampler
from paddle.fluid.io import BatchSampler, DataLoader, MnistDataset from paddle.fluid.io import BatchSampler, DataLoader, MnistDataset
class SimpleImgConvPool(fluid.dygraph.Layer): class SimpleImgConvPool(fluid.dygraph.Layer):
...@@ -112,7 +112,8 @@ class CustromMnistDataset(MnistDataset): ...@@ -112,7 +112,8 @@ class CustromMnistDataset(MnistDataset):
label_filename=None, label_filename=None,
mode='train', mode='train',
download=True): download=True):
super(CustromMnistDataset, self).__init__(image_filename, label_filename, mode, download) super(CustromMnistDataset, self).__init__(image_filename,
label_filename, mode, download)
def __getitem__(self, idx): def __getitem__(self, idx):
...@@ -135,7 +136,6 @@ def main(): ...@@ -135,7 +136,6 @@ def main():
os.mkdir('mnist_checkpoints') os.mkdir('mnist_checkpoints')
with guard: with guard:
train_dataset = CustromMnistDataset(mode='train') train_dataset = CustromMnistDataset(mode='train')
val_dataset = CustromMnistDataset(mode='test') val_dataset = CustromMnistDataset(mode='test')
......
...@@ -350,6 +350,7 @@ class StaticGraphAdapter(object): ...@@ -350,6 +350,7 @@ class StaticGraphAdapter(object):
if self.mode != 'train' and self.model._test_dataloader is not None \ if self.mode != 'train' and self.model._test_dataloader is not None \
and self._nranks > 1: and self._nranks > 1:
total_size = len(self.model._test_dataloader.dataset) total_size = len(self.model._test_dataloader.dataset)
# TODO: fixme if have better way to get batch size
samples = state[0].shape[0] samples = state[0].shape[0]
current_count = self._merge_count.get(self.mode, 0) current_count = self._merge_count.get(self.mode, 0)
if current_count + samples > total_size: if current_count + samples > total_size:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册