Commit 888c3c49 authored by Yancey1989

do not import torch

Parent a08dc90f
@@ -3,13 +3,9 @@ import os
 import numpy as np
 import math
 import random
-import torch
-import torch.utils.data
-from torch.utils.data.distributed import DistributedSampler
 import torchvision.transforms as transforms
 import torchvision.datasets as datasets
-from torch.utils.data.sampler import Sampler
 import torchvision
 import pickle
 from tqdm import tqdm
@@ -21,7 +17,6 @@ TRAINER_ID = int(os.getenv("PADDLE_TRAINER_ID", "0"))
 epoch = 0
 FINISH_EVENT = "FINISH_EVENT"
-#def paddle_data_loader(torch_dataset, indices=None, concurrent=1, queue_size=3072, use_uint8_reader=False):
 class PaddleDataLoader(object):
     def __init__(self, torch_dataset, indices=None, concurrent=16, queue_size=3072):
         self.torch_dataset = torch_dataset
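
The hunks above drop every torch import but keep torchvision, which is used only as a dataset/transform library. That works because PaddleDataLoader needs nothing from the wrapped dataset beyond __len__ and __getitem__, and since no ToTensor transform is applied anywhere in this file, each sample comes back as a PIL image that can be turned into a numpy array before it is handed to Paddle. Below is a minimal single-process sketch of that idea; the real class presumably fans work out to `concurrent` worker processes through a queue of size `queue_size`, and the helper name `simple_reader` is ours, not from the repo.

import numpy as np
import torchvision.datasets as datasets
import torchvision.transforms as transforms

def simple_reader(dataset, indices=None):
    """Single-process stand-in for PaddleDataLoader.reader() (illustrative only).

    Assumes the wrapped torchvision dataset yields (PIL image, int label) pairs,
    which holds here because no ToTensor transform is used.
    """
    def reader():
        idx_list = indices if indices is not None else range(len(dataset))
        for i in idx_list:
            img, label = dataset[i]
            # Convert the PIL image to an HWC uint8 numpy array; no torch tensors needed.
            yield np.array(img), label
    return reader

# Usage sketch (the path is a placeholder):
# val_dataset = datasets.ImageFolder(
#     "/path/to/val",
#     transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224)]))
# for img, label in simple_reader(val_dataset)():
#     print(img.shape, label)
#     break
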
@@ -97,25 +92,6 @@ def test(valdir, bs, sz, rect_val=False):
     return PaddleDataLoader(val_dataset).reader()

-def create_validation_set(valdir, batch_size, target_size, rect_val, distributed):
-    print("create_validation_set", valdir, batch_size, target_size, rect_val, distributed)
-    if rect_val:
-        idx_ar_sorted = sort_ar(valdir)
-        idx_sorted, _ = zip(*idx_ar_sorted)
-        idx2ar = map_idx2ar(idx_ar_sorted, batch_size)
-
-        ar_tfms = [transforms.Resize(int(target_size * 1.14)), CropArTfm(idx2ar, target_size)]
-        val_dataset = ValDataset(valdir, transform=ar_tfms)
-        val_sampler = DistValSampler(idx_sorted, batch_size=batch_size, distributed=distributed)
-        return val_dataset, val_sampler
-
-    val_tfms = [transforms.Resize(int(target_size * 1.14)), transforms.CenterCrop(target_size)]
-    val_dataset = datasets.ImageFolder(valdir, transforms.Compose(val_tfms))
-    val_sampler = DistValSampler(list(range(len(val_dataset))), batch_size=batch_size, distributed=distributed)
-    return val_dataset, val_sampler
-
 class ValDataset(datasets.ImageFolder):
     def __init__(self, root, transform=None, target_transform=None):
         super(ValDataset, self).__init__(root, transform, target_transform)
@@ -134,43 +110,6 @@ class ValDataset(datasets.ImageFolder):
         return sample, target

-class DistValSampler(Sampler):
-    # DistValSampler distributes batches equally (based on batch size) to every gpu (even if there aren't enough images)
-    # WARNING: some batches will contain an empty array to signify there aren't enough images
-    # distributed=False - the same validation happens on every single gpu
-    def __init__(self, indices, batch_size, distributed=True):
-        self.indices = indices
-        self.batch_size = batch_size
-        if distributed:
-            self.world_size = TRAINER_NUMS
-            self.global_rank = TRAINER_ID
-        else:
-            self.global_rank = 0
-            self.world_size = 1
-        # Expected number of batches per rank. Needed so each distributed gpu validates on the
-        # same number of batches, even if there isn't enough data to go around.
-        self.expected_num_batches = int(math.ceil(len(self.indices) / self.world_size / self.batch_size))
-        # num_samples = total images / world_size. This is what we distribute to each gpu.
-        self.num_samples = self.expected_num_batches * self.batch_size
-
-    def __iter__(self):
-        offset = self.num_samples * self.global_rank
-        sampled_indices = self.indices[offset:offset + self.num_samples]
-        print("DistValSampler: self.world_size: ", self.world_size, " self.global_rank: ", self.global_rank)
-        for i in range(self.expected_num_batches):
-            offset = i * self.batch_size
-            yield sampled_indices[offset:offset + self.batch_size]
-
-    def __len__(self):
-        return self.expected_num_batches
-
-    def set_epoch(self, epoch):
-        return
-
 class CropArTfm(object):
     def __init__(self, idx2ar, target_size):
         self.idx2ar, self.target_size = idx2ar, target_size
...
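
The DistValSampler removed above is worth a note, since whatever replaces it has to preserve the same property: every rank yields exactly ceil(len(indices) / world_size / batch_size) batches, so all ranks stay in lock-step during distributed validation even when the images do not divide evenly, and the trailing batches on the last rank may be short or empty. A plain-generator sketch of that slicing logic (the function name and the example numbers below are ours):

import math

def dist_val_batches(indices, batch_size, rank, world_size):
    # Every rank produces the same number of batches, so collective ops over
    # the validation metrics never wait on a rank that ran out of data.
    num_batches = int(math.ceil(len(indices) / float(world_size) / batch_size))
    num_samples = num_batches * batch_size
    shard = indices[rank * num_samples:(rank + 1) * num_samples]
    for b in range(num_batches):
        # Batches near the end may be shorter than batch_size, or empty.
        yield shard[b * batch_size:(b + 1) * batch_size]

# Example: 10 images, world_size=2, batch_size=4 -> 2 batches per rank.
# rank 0: [0, 1, 2, 3], [4, 5, 6, 7]
# rank 1: [8, 9], []
for rank in (0, 1):
    print(rank, list(dist_val_batches(list(range(10)), 4, rank, 2)))
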
@@ -102,7 +102,6 @@ def linear_lr_decay(lr_values, epochs, bs_values, total_images):
             start_lr, end_lr = lr_values[idx]
             linear_lr = end_lr - start_lr
             steps = last_steps + math.ceil(total_images * 1.0 / bs_values[idx]) * linear_epoch
-            linear_lr = end_lr = start_lr
             with switch.case(global_step < steps):
                 decayed_lr = start_lr + linear_lr * ((global_step - last_steps) * 1.0 / steps)
             last_steps = steps
...
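
The one-line deletion in linear_lr_decay is a bug fix: `linear_lr = end_lr = start_lr` is a chained assignment, so it threw away the slope `end_lr - start_lr` computed just above and overwrote `end_lr` as well, which kept the schedule from ever ramping toward `end_lr`. The intent is a piecewise-linear interpolation between the (start_lr, end_lr) pair of each training phase. A plain-Python sketch of that schedule, using our own (start_lr, end_lr, num_steps) parametrization instead of the switch/case graph built above:

def linear_lr_at_step(global_step, segments):
    """segments: list of (start_lr, end_lr, num_steps) tuples, one per phase,
    e.g. num_steps = ceil(total_images / batch_size) * epochs_in_phase."""
    last_steps = 0
    for start_lr, end_lr, num_steps in segments:
        steps = last_steps + num_steps
        if global_step < steps:
            # Linear ramp from start_lr to end_lr within this phase.
            progress = (global_step - last_steps) / float(num_steps)
            return start_lr + (end_lr - start_lr) * progress
        last_steps = steps
    # Past the final phase: hold its end_lr.
    return segments[-1][1]

# Example: warm up 0.1 -> 0.4 over 100 steps, then decay 0.4 -> 0.0 over 300 steps.
schedule = [(0.1, 0.4, 100), (0.4, 0.0, 300)]
print(linear_lr_at_step(0, schedule), linear_lr_at_step(50, schedule), linear_lr_at_step(250, schedule))
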