Commit 888c3c49 authored by Yancey1989

do not import torch

Parent a08dc90f
@@ -3,13 +3,9 @@ import os
import numpy as np
import math
import random
import torch
import torch.utils.data
from torch.utils.data.distributed import DistributedSampler
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data.sampler import Sampler
import torchvision
import pickle
from tqdm import tqdm
@@ -21,7 +17,6 @@ TRAINER_ID = int(os.getenv("PADDLE_TRAINER_ID", "0"))
epoch = 0
FINISH_EVENT = "FINISH_EVENT"
#def paddle_data_loader(torch_dataset, indices=None, concurrent=1, queue_size=3072, use_uint8_reader=False):
class PaddleDataLoader(object):
def __init__(self, torch_dataset, indices=None, concurrent=16, queue_size=3072):
self.torch_dataset = torch_dataset
@@ -97,25 +92,6 @@ def test(valdir, bs, sz, rect_val=False):
return PaddleDataLoader(val_dataset).reader()
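For context on what this file does: PaddleDataLoader takes a map-style (torch-like) dataset and serves its samples through worker processes, exposing the result as a plain Python reader that Paddle can consume. The sketch below only illustrates that general pattern; the class name SimpleDataLoader, the round-robin index split, and the sentinel handling are assumptions for illustration, not the implementation in this file.

```python
# Minimal sketch (assumed, not the exact code in this commit) of wrapping a
# map-style dataset behind worker processes and exposing a reader generator.
import multiprocessing

FINISH_EVENT = "FINISH_EVENT"


class SimpleDataLoader(object):  # hypothetical stand-in for PaddleDataLoader
    def __init__(self, dataset, indices=None, concurrent=4, queue_size=1024):
        self.dataset = dataset
        self.indices = indices if indices is not None else list(range(len(dataset)))
        self.concurrent = concurrent
        self.queue = multiprocessing.Queue(queue_size)

    def _worker(self, worker_indices):
        # Each worker reads its own slice of indices and pushes samples.
        for idx in worker_indices:
            self.queue.put(self.dataset[idx])
        self.queue.put(FINISH_EVENT)  # sentinel: this worker is done

    def reader(self):
        def _reader():
            workers = []
            for i in range(self.concurrent):
                # Round-robin split of the indices across workers.
                part = self.indices[i::self.concurrent]
                p = multiprocessing.Process(target=self._worker, args=(part,))
                p.daemon = True
                p.start()
                workers.append(p)
            finished = 0
            while finished < self.concurrent:
                sample = self.queue.get()
                if sample == FINISH_EVENT:
                    finished += 1
                else:
                    yield sample
        return _reader
```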
def create_validation_set(valdir, batch_size, target_size, rect_val, distributed):
print("create_validation_set", valdir, batch_size, target_size, rect_val, distributed)
if rect_val:
idx_ar_sorted = sort_ar(valdir)
idx_sorted, _ = zip(*idx_ar_sorted)
idx2ar = map_idx2ar(idx_ar_sorted, batch_size)
ar_tfms = [transforms.Resize(int(target_size * 1.14)), CropArTfm(idx2ar, target_size)]
val_dataset = ValDataset(valdir, transform=ar_tfms)
val_sampler = DistValSampler(idx_sorted, batch_size=batch_size, distributed=distributed)
return val_dataset, val_sampler
val_tfms = [transforms.Resize(int(target_size * 1.14)), transforms.CenterCrop(target_size)]
val_dataset = datasets.ImageFolder(valdir, transforms.Compose(val_tfms))
val_sampler = DistValSampler(list(range(len(val_dataset))), batch_size=batch_size, distributed=distributed)
return val_dataset, val_sampler
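create_validation_set therefore has two paths: with rect_val it sorts images by aspect ratio and crops each batch with CropArTfm, otherwise it falls back to a plain resize plus center crop. A hedged usage sketch follows; the directory name and sizes are illustrative values, not taken from the commit.

```python
# Illustrative only: build the validation pipeline and iterate it by batch.
val_dataset, val_sampler = create_validation_set(
    valdir="data/val", batch_size=128, target_size=224,
    rect_val=True, distributed=True)

# The sampler yields one list of dataset indices per batch (possibly short
# or empty for the last batches of a rank).
for batch_indices in val_sampler:
    batch = [val_dataset[i] for i in batch_indices]
```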
class ValDataset(datasets.ImageFolder):
def __init__(self, root, transform=None, target_transform=None):
super(ValDataset, self).__init__(root, transform, target_transform)
@@ -134,43 +110,6 @@ class ValDataset(datasets.ImageFolder):
return sample, target
class DistValSampler(Sampler):
# DistValSampler distributes batches equally (based on batch size) to every gpu (even if there aren't enough images)
# WARNING: Some batches will contain an empty array to signify there aren't enough images
# Distributed=False - same validation happens on every single gpu
def __init__(self, indices, batch_size, distributed=True):
self.indices = indices
self.batch_size = batch_size
if distributed:
self.world_size = TRAINER_NUMS
self.global_rank = TRAINER_ID
else:
self.global_rank = 0
self.world_size = 1
# expected number of batches per sample. Need this so each distributed gpu validates on same number of batches.
# even if there isn't enough data to go around
self.expected_num_batches = int(math.ceil(len(self.indices) / self.world_size / self.batch_size))
# num_samples = total images / world_size. This is what we distribute to each gpu
self.num_samples = self.expected_num_batches * self.batch_size
def __iter__(self):
offset = self.num_samples * self.global_rank
sampled_indices = self.indices[offset:offset + self.num_samples]
print("DistValSampler: self.world_size: ", self.world_size, " self.global_rank: ", self.global_rank)
for i in range(self.expected_num_batches):
offset = i * self.batch_size
yield sampled_indices[offset:offset + self.batch_size]
def __len__(self):
return self.expected_num_batches
def set_epoch(self, epoch):
return
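To make the batch accounting above concrete, here is a small worked example with illustrative numbers (not from the commit):

```python
import math

# 1000 validation images, 4 trainers (world_size), batch_size 64.
num_images, world_size, batch_size = 1000, 4, 64

# Same formulas as DistValSampler (assuming true division, as in Python 3):
expected_num_batches = int(math.ceil(num_images / world_size / batch_size))  # 4
num_samples = expected_num_batches * batch_size                              # 256

# Rank 3 owns indices [768:1000), i.e. only 232 images, so its fourth batch
# is short, but every rank still yields exactly expected_num_batches batches.
```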
class CropArTfm(object):
def __init__(self, idx2ar, target_size):
self.idx2ar, self.target_size = idx2ar, target_size
......
@@ -102,7 +102,6 @@ def linear_lr_decay(lr_values, epochs, bs_values, total_images):
start_lr, end_lr = lr_values[idx]
linear_lr = end_lr - start_lr
steps = last_steps + math.ceil(total_images * 1.0 / bs_values[idx]) * linear_epoch
linear_lr = end_lr = start_lr
with switch.case(global_step < steps):
decayed_lr = start_lr + linear_lr * ((global_step - last_steps) * 1.0/steps)
last_steps = steps
......
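The hunk above removes a single line inside linear_lr_decay. For reference, a plain-Python sketch of a piecewise-linear learning-rate schedule is shown below. It is not the fluid-based implementation from the diff: it assumes epochs[idx] is the epoch count of phase idx, and it normalizes by that phase's own step count, whereas the code in the hunk divides by the cumulative step count.

```python
import math

# Plain-Python sketch of a piecewise-linear LR schedule, for reference only.
# Assumptions (not taken from the hunk): lr_values is a list of
# (start_lr, end_lr) pairs, epochs[idx] is the number of epochs in phase idx,
# and bs_values[idx] is that phase's batch size.
def linear_lr_value(global_step, lr_values, epochs, bs_values, total_images):
    last_steps = 0
    for idx, (start_lr, end_lr) in enumerate(lr_values):
        steps_per_epoch = int(math.ceil(total_images * 1.0 / bs_values[idx]))
        phase_steps = steps_per_epoch * epochs[idx]
        if global_step < last_steps + phase_steps:
            # Interpolate linearly from start_lr to end_lr within this phase.
            frac = (global_step - last_steps) * 1.0 / phase_steps
            return start_lr + (end_lr - start_lr) * frac
        last_steps += phase_steps
    return lr_values[-1][1]  # past the last phase: keep the final LR
```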