Commit f353d34b authored by LielinJiang

multiple gpus

Parent 90134440
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import time
import math
import socket
import contextlib
from contextlib import closing
from six import string_types
import numpy as np
from collections import OrderedDict
from paddle import fluid
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
from paddle.fluid import framework
from paddle.fluid.layers import collective
from paddle.fluid.dygraph import to_variable, no_grad, layers
from paddle.fluid.framework import Variable
from paddle.fluid.executor import global_scope
from paddle.fluid.dygraph.parallel import Env, DataParallel, ParallelStrategy
from paddle.fluid.layers.collective import _c_allreduce, _c_allgather, _c_broadcast, _c_sync_comm_stream, _c_sync_calc_stream
from paddle.fluid.io import BatchSampler, DataLoader
class DistributedBatchSampler(BatchSampler):
"""Sampler that restricts data loading to a subset of the dataset.
In such case, each process can pass a DistributedBatchSampler instance
as a DataLoader sampler, and load a subset of the original dataset that
is exclusive to it.
.. note::
Dataset is assumed to be of constant size.
Args:
data_source: this could be a `fluid.io.Dataset` implement
or other python object which implemented
`__len__` for BatchSampler to get sample
number of data source.
batch_size(int): sample indice number in a mini-batch indices.
shuffle(bool): whther to shuffle indices order before genrate
batch indices. Default False.
drop_last(bool): whether drop the last incomplete batch dataset size
is not divisible by the batch size. Default False
"""
def __init__(self, dataset, batch_size, shuffle=False, drop_last=False):
self.dataset = dataset
self.sample_iter = None
assert isinstance(batch_size, int) and batch_size > 0, \
"batch_size should be a positive integer"
self.batch_size = batch_size
assert isinstance(shuffle, bool), \
"shuffle should be a boolean value"
self.shuffle = shuffle
assert isinstance(drop_last, bool), \
"drop_last should be a boolean value"
self.drop_last = drop_last
self.nranks = get_nranks()
self.local_rank = get_local_rank()
self.epoch = 0
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.nranks))
self.total_size = self.num_samples * self.nranks
def __iter__(self):
_sample_iter = self.sample_iter
if _sample_iter is None:
num_samples = len(self.dataset)
indices = np.arange(num_samples).tolist()
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
if self.shuffle:
np.random.RandomState(self.epoch).shuffle(indices)
self.epoch += 1
# subsample
indices = indices[self.local_rank * self.num_samples: (self.local_rank + 1) * self.num_samples]
assert len(indices) == self.num_samples
_sample_iter = iter(indices)
batch_indices = []
for idx in _sample_iter:
batch_indices.append(idx)
if len(batch_indices) == self.batch_size:
yield batch_indices
batch_indices = []
if not self.drop_last and len(batch_indices) > 0:
yield batch_indices
def __len__(self):
num_samples = self.num_samples
num_samples += int(not self.drop_last) * (self.batch_size - 1)
return num_samples // self.batch_size
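# Usage sketch (illustrative; `val_dataset` and `place` are assumed to be defined
# elsewhere, as in the example scripts later in this commit): each process builds
# its own DistributedBatchSampler and hands it to DataLoader through
# `batch_sampler`, so every rank iterates over a disjoint shard of the dataset.
#
#     sampler = DistributedBatchSampler(val_dataset, batch_size=64, shuffle=False)
#     loader = DataLoader(val_dataset, batch_sampler=sampler, places=place,
#                         return_list=True)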
@contextlib.contextmanager
def null_guard():
yield
def to_numpy(var):
assert isinstance(var, (Variable, fluid.core.VarBase)), "not a variable"
if isinstance(var, fluid.core.VarBase):
return var.numpy()
t = global_scope().find_var(var.name).get_tensor()
return np.array(t)
def all_gather(input):
place = fluid.CUDAPlace(Env().dev_id) \
if Env().nranks > 1 else fluid.CUDAPlace(0)
guard = null_guard() if fluid.in_dygraph_mode() else fluid.dygraph.guard(place)
with guard:
input = to_variable(input)
output = _all_gather(input, Env().nranks)
return to_numpy(output)
def _all_reduce(x, out=None, reduce_type="sum", sync_mode=True):
out = _c_allreduce(x, out, reduce_type)
if sync_mode:
return _c_sync_calc_stream(out)
return out
def _all_gather(x, nranks, ring_id=0, use_calc_stream=True):
return _c_allgather(x, nranks, ring_id=ring_id, use_calc_stream=use_calc_stream)
def _broadcast(x, root=0, ring_id=0, use_calc_stream=True):
return _c_broadcast(x, root, ring_id, use_calc_stream)
def _sync_comm_stream(x, ring_id):
return _c_sync_comm_stream(x, ring_id)
def barrier():
pass
def get_local_rank():
return Env().local_rank
def get_nranks():
return Env().nranks
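# Note: in multi-GPU dygraph evaluation, `all_gather` can merge per-rank numpy
# results into one array on every process. A minimal sketch (`local_pred` is a
# hypothetical per-rank result; the communicator must already be initialized):
#
#     local_pred = model_out.numpy()       # shape [local_batch, ...]
#     merged = all_gather(local_pred)      # shape [local_batch * nranks, ...]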
def wait_server_ready(endpoints):
assert not isinstance(endpoints, string_types)
while True:
all_ok = True
not_ready_endpoints = []
for ep in endpoints:
ip_port = ep.split(":")
with closing(
socket.socket(socket.AF_INET,
socket.SOCK_STREAM)) as sock:
sock.settimeout(2)
result = sock.connect_ex((ip_port[0], int(ip_port[1])))
if result != 0:
all_ok = False
not_ready_endpoints.append(ep)
if not all_ok:
sys.stderr.write("server not ready, wait 3 sec to retry...\n")
sys.stderr.write("not ready endpoints:" + str(
not_ready_endpoints) + "\n")
sys.stderr.flush()
time.sleep(3)
else:
break
def initCommunicator(program, rank, nranks, wait_port,
current_endpoint, endpoints):
if nranks < 2:
return
other_endpoints = endpoints[:]
other_endpoints.remove(current_endpoint)
if rank == 0 and wait_port:
wait_server_ready(other_endpoints)
block = program.global_block()
nccl_id_var = block.create_var(
name=nameGen.generate('nccl_id'),
persistable=True,
type=core.VarDesc.VarType.RAW)
block.append_op(
type='c_gen_nccl_id',
inputs={},
outputs={'Out': nccl_id_var},
attrs={
'rank': rank,
'endpoint': current_endpoint,
'other_endpoints': other_endpoints
})
block.append_op(
type='c_comm_init',
inputs={'X': nccl_id_var},
outputs={},
attrs={
'nranks': nranks,
'rank': rank,
'ring_id': 0,
})
def prepare_context(place):
strategy = ParallelStrategy()
strategy.nranks = Env().nranks
strategy.local_rank = Env().local_rank
strategy.trainer_endpoints = Env().trainer_endpoints
strategy.current_endpoint = Env().current_endpoint
if strategy.nranks < 2:
return
if isinstance(place, core.CUDAPlace):
communicator_prog = framework.Program()
initCommunicator(communicator_prog, strategy.local_rank, strategy.nranks, True,
strategy.current_endpoint, strategy.trainer_endpoints)
exe = fluid.Executor(place)
exe.run(communicator_prog)
else:
# TODO(Yancey1989): add Gloo Parallel Context to support CPU parallel computation
assert ("Only support CUDAPlace for now.")
return strategy
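# Usage sketch mirroring the imagenet scripts below: when launched with
# `python -m paddle.distributed.launch`, each process calls prepare_context(place)
# once to set up the NCCL communicator before building the model.
#
#     place = fluid.CUDAPlace(Env().dev_id) if Env().nranks > 1 else fluid.CUDAPlace(0)
#     if Env().nranks > 1:
#         prepare_context(place)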
class DistributedDataParallel(DataParallel):
def __init__(self, layers, strategy=None):
if strategy is None:
strategy = ParallelStrategy()
strategy.nranks = Env().nranks
strategy.local_rank = Env().local_rank
strategy.trainer_endpoints = Env().trainer_endpoints
strategy.current_endpoint = Env().current_endpoint
super(DistributedDataParallel, self).__init__(layers, strategy)
@no_grad
def apply_collective_grads(self):
"""
AllReduce the Parameters' gradient.
"""
if not self._is_data_parallel_mode():
return
grad_var_set = set()
grad_vars = []
for param in self._layers.parameters():
# NOTE(zcd): The grad_ivar may not have been generated yet.
if param.trainable and param._grad_ivar():
g_var = param._grad_ivar()
grad_vars.append(g_var)
assert g_var not in grad_var_set
grad_var_set.add(g_var)
# FIXME(zcd): the type of the var should be LoDTensor, i.e
# the gradients should be dense, otherwise, the following
# logic should be updated.
# 128 MB as a group
mega_bytes = 128 * 1024 * 1024
group_idx = 0
memory_counter = 0
grad_var_groups = OrderedDict()
dtype = grad_vars[0].dtype
for g_var in grad_vars:
# Note: the dtype of the same group should be the same.
bytes = np.prod(g_var.shape) * core.size_of_dtype(g_var.dtype)
if memory_counter < mega_bytes and dtype == g_var.dtype:
memory_counter += bytes
else:
memory_counter = bytes
group_idx += 1
grad_var_groups.setdefault(group_idx, []).append(g_var)
coalesced_grads_and_vars = self._coalesce_tensors(grad_var_groups)
for coalesced_grad, _, _ in coalesced_grads_and_vars:
collective._c_allreduce(coalesced_grad, coalesced_grad, use_calc_stream=True)
self._split_tensors(coalesced_grads_and_vars)
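# Training-step sketch with DistributedDataParallel (hedged; `model`, `optimizer`
# and `loss` are assumptions, and prepare_context must have been called first):
#
#     model = DistributedDataParallel(model)
#     loss = model(inputs)                  # forward on the wrapped layers
#     loss = model.scale_loss(loss)         # scale by 1/nranks (DataParallel API)
#     loss.backward()
#     model.apply_collective_grads()        # allreduce gradients across ranks
#     optimizer.minimize(loss)
#     model.clear_gradients()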
from __future__ import division
from __future__ import print_function
import os
import sys
sys.path.append('../')
import argparse
import contextlib
import time
import numpy as np
import paddle.fluid as fluid
from model import CrossEntropy, Input
from nets import ResNet
from distributed import prepare_context, all_gather, Env, get_local_rank, get_nranks, DistributedBatchSampler
from utils import ImageNetDataset
from metrics import Accuracy
from models.resnet import resnet50
from paddle.fluid.io import BatchSampler, DataLoader
def run(model, loader, mode='train'):
total_loss = 0
total_time = 0.0 #AverageMeter()
local_rank = get_local_rank()
start = time.time()
start_time = time.time()
for idx, batch in enumerate(loader()):
if not fluid.in_dygraph_mode():
batch = batch[0]
losses, metrics = getattr(model, mode)(
batch[0], batch[1])
if idx > 1: # skip first two steps
total_time += time.time() - start
total_loss += np.sum(losses)
if idx % 10 == 0 and local_rank == 0:
print("{:04d}: loss {:0.3f} top1: {:0.3f}% top5: {:0.3f}% time: {:0.3f} samples: {}".format(
idx, total_loss / (idx + 1), metrics[0][0] * 100, metrics[0][1] * 100, total_time / max(1, (idx - 1)), model._metrics[0].count[0]))
start = time.time()
eval_time = time.time() - start_time
for metric in model._metrics:
res = metric.accumulate()
if local_rank == 0:
print("[EVAL END]: top1: {:0.3f}%, top5: {:0.3f} total samples: {} total time: {:.3f}".format(res[0] * 100, res[1] * 100, model._metrics[0].count[0], eval_time))
metric.reset()
def main():
@contextlib.contextmanager
def null_guard():
yield
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
if get_nranks() > 1 else fluid.CUDAPlace(0)
guard = fluid.dygraph.guard(place) if FLAGS.dynamic else null_guard()
if get_nranks() > 1:
prepare_context(place)
if get_nranks() > 1 and not os.path.exists('resnet_checkpoints'):
os.mkdir('resnet_checkpoints')
with guard:
# model = ResNet()
model = resnet50(pretrained=True)
inputs = [Input([None, 3, 224, 224], 'float32', name='image')]
labels = [Input([None, 1], 'int64', name='label')]
if fluid.in_dygraph_mode():
feed_list = None
else:
feed_list = [x.forward() for x in inputs + labels]
val_dataset = ImageNetDataset(os.path.join(FLAGS.data, 'val'), mode='val')
if get_nranks() > 1:
distributed_sampler = DistributedBatchSampler(val_dataset, batch_size=FLAGS.batch_size)
val_loader = DataLoader(val_dataset, batch_sampler=distributed_sampler, places=place,
feed_list=feed_list, num_workers=4, return_list=True)
else:
val_loader = DataLoader(val_dataset, batch_size=FLAGS.batch_size, places=place,
feed_list=feed_list, num_workers=4, return_list=True)
model.prepare(None, CrossEntropy(), Accuracy(topk=(1, 5)), inputs, labels, val_dataset)
# model.save('resnet_checkpoints/{:03d}'.format(000))
if FLAGS.resume is not None:
model.load(FLAGS.resume)
run(model, val_loader, mode='eval')
if __name__ == '__main__':
parser = argparse.ArgumentParser("Resnet Training on ImageNet")
parser.add_argument('data', metavar='DIR', help='path to dataset '
'(should have subdirectories named "train" and "val")')
parser.add_argument(
"-d", "--dynamic", action='store_true', help="enable dygraph mode")
parser.add_argument(
"-e", "--epoch", default=90, type=int, help="number of epoch")
parser.add_argument(
'--lr', '--learning-rate', default=0.1, type=float, metavar='LR',
help='initial learning rate')
parser.add_argument(
"-b", "--batch_size", default=4, type=int, help="batch size")
parser.add_argument(
"-n", "--num_devices", default=1, type=int, help="number of devices")
parser.add_argument(
"-r", "--resume", default=None, type=str,
help="checkpoint path to resume")
FLAGS = parser.parse_args()
assert FLAGS.data, "error: must provide data path"
main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import argparse
import contextlib
import os
import sys
sys.path.append('../')
import time
import math
import numpy as np
import paddle.fluid as fluid
from model import CrossEntropy, Input
from utils import AverageMeter, accuracy, ImageNetDataset
from distributed import prepare_context, all_gather, Env, get_nranks, get_local_rank, DistributedBatchSampler
from models import resnet50
from metrics import Accuracy
from paddle.fluid.io import BatchSampler, DataLoader
def make_optimizer(parameter_list=None):
total_images = 1281167
base_lr = FLAGS.lr
momentum = 0.9
weight_decay = 1e-4
step_per_epoch = int(math.floor(float(total_images) / FLAGS.batch_size))
boundaries = [step_per_epoch * e for e in [30, 60, 90]]
values = [base_lr * (0.1**i) for i in range(len(boundaries) + 1)]
learning_rate = fluid.layers.piecewise_decay(
boundaries=boundaries, values=values)
learning_rate = fluid.layers.linear_lr_warmup(
learning_rate=learning_rate,
warmup_steps=5 * step_per_epoch,
start_lr=0.,
end_lr=base_lr)
optimizer = fluid.optimizer.Momentum(
learning_rate=learning_rate,
momentum=momentum,
regularization=fluid.regularizer.L2Decay(weight_decay),
parameter_list=parameter_list)
return optimizer
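# With the script defaults (lr=0.1, batch_size=256), step_per_epoch is
# floor(1281167 / 256) = 5004; the learning rate warms up linearly from 0 to 0.1
# over the first 5 epochs, then drops by 10x at epochs 30, 60 and 90.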
def run(model, loader, mode='train'):
total_loss = 0
total_time = 0.0
local_rank = get_local_rank()
start = time.time()
start_time = time.time()
for idx, batch in enumerate(loader()):
if not fluid.in_dygraph_mode():
batch = batch[0]
losses, metrics = getattr(model, mode)(
batch[0], batch[1])
if idx > 1: # skip first two steps
total_time += time.time() - start
total_loss += np.sum(losses)
if idx % 10 == 0 and local_rank == 0:
print("{:04d}: loss {:0.3f} top1: {:0.3f}% top5: {:0.3f}% time: {:0.3f} samples: {}".format(
idx, total_loss / (idx + 1), metrics[0][0] * 100, metrics[0][1] * 100, total_time / max(1, (idx - 1)), model._metrics[0].count[0]))
start = time.time()
eval_time = time.time() - start_time
for metric in model._metrics:
res = metric.accumulate()
if local_rank == 0 and mode == 'eval':
print("[EVAL END]: top1: {:0.3f}%, top5: {:0.3f} total samples: {} total time: {:.3f}".format(res[0] * 100, res[1] * 100, model._metrics[0].count[0], eval_time))
metric.reset()
def main():
@contextlib.contextmanager
def null_guard():
yield
epoch = FLAGS.epoch
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
if fluid.dygraph.parallel.Env().nranks > 1 else fluid.CUDAPlace(0)
guard = fluid.dygraph.guard(place) if FLAGS.dynamic else null_guard()
if fluid.dygraph.parallel.Env().nranks > 1:
prepare_context(place)
if not os.path.exists('resnet_checkpoints'):
os.mkdir('resnet_checkpoints')
with guard:
model = resnet50()
optim = make_optimizer(parameter_list=model.parameters())
inputs = [Input([None, 3, 224, 224], 'float32', name='image')]
labels = [Input([None, 1], 'int64', name='label')]
if fluid.in_dygraph_mode():
feed_list = None
else:
feed_list = [x.forward() for x in inputs + labels]
train_dataset = ImageNetDataset(os.path.join(FLAGS.data, 'train'), mode='train')
val_dataset = ImageNetDataset(os.path.join(FLAGS.data, 'val'), mode='val')
if get_nranks() > 1:
train_sampler = DistributedBatchSampler(train_dataset, batch_size=FLAGS.batch_size, shuffle=True)
train_loader = DataLoader(train_dataset, batch_sampler=train_sampler, places=place,
feed_list=feed_list, num_workers=0, return_list=True)
val_sampler = DistributedBatchSampler(val_dataset, batch_size=FLAGS.batch_size)
val_loader = DataLoader(val_dataset, batch_sampler=val_sampler, places=place,
feed_list=feed_list, num_workers=0, return_list=True)
else:
train_loader = DataLoader(train_dataset, batch_size=FLAGS.batch_size, places=place,
feed_list=feed_list, num_workers=0, return_list=True)
val_loader = DataLoader(val_dataset, batch_size=FLAGS.batch_size, places=place,
feed_list=feed_list, num_workers=0, return_list=True)
model.prepare(optim, CrossEntropy(), Accuracy(topk=(1, 5)), inputs, labels, val_dataset)
if FLAGS.resume is not None:
model.load(FLAGS.resume)
for e in range(epoch):
if get_local_rank() == 0:
print("======== train epoch {} ========".format(e))
run(model, train_loader)
model.save('resnet_checkpoints/{:02d}'.format(e))
if get_local_rank() == 0:
print("======== eval epoch {} ========".format(e))
run(model, val_loader, mode='eval')
if __name__ == '__main__':
parser = argparse.ArgumentParser("Resnet Training on ImageNet")
parser.add_argument('data', metavar='DIR', help='path to dataset '
'(should have subdirectories named "train" and "val")')
parser.add_argument(
"-d", "--dynamic", action='store_true', help="enable dygraph mode")
parser.add_argument(
"-e", "--epoch", default=120, type=int, help="number of epoch")
parser.add_argument(
'--lr', '--learning-rate', default=0.1, type=float, metavar='LR',
help='initial learning rate')
parser.add_argument(
"-b", "--batch_size", default=256, type=int, help="batch size")
parser.add_argument(
"-n", "--num_devices", default=1, type=int, help="number of devices")
parser.add_argument(
"-r", "--resume", default=None, type=str,
help="checkpoint path to resume")
FLAGS = parser.parse_args()
assert FLAGS.data, "error: must provide data path"
main()
import os
import cv2
import math
import random
import numpy as np
class AverageMeter:
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def accuracy(pred, label, topk=(1, )):
maxk = max(topk)
pred = np.argsort(pred)[:, ::-1][:, :maxk]
correct = (pred == np.repeat(label, maxk, 1))
batch_size = label.shape[0]
res = []
for k in topk:
correct_k = correct[:, :k].sum()
res.append(100.0 * correct_k / batch_size)
return res
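# Worked example (illustrative): two samples, four classes.
#     pred  = np.array([[0.1, 0.7, 0.1, 0.1],
#                       [0.3, 0.2, 0.4, 0.1]])
#     label = np.array([[1], [0]])
# The top-1 predictions are classes [1, 2], so only the first sample is correct
# at k=1, while class 0 falls inside the second sample's top-2; hence
# accuracy(pred, label, topk=(1, 2)) returns [50.0, 100.0].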
def center_crop_resize(img):
h, w = img.shape[:2]
c = int(224 / 256 * min((h, w)))
i = (h + 1 - c) // 2
j = (w + 1 - c) // 2
img = img[i: i + c, j: j + c, :]
return cv2.resize(img, (224, 224), 0, 0, cv2.INTER_LINEAR)
def random_crop_resize(img):
height, width = img.shape[:2]
area = height * width
for attempt in range(10):
target_area = random.uniform(0.08, 1.) * area
log_ratio = (math.log(3 / 4), math.log(4 / 3))
aspect_ratio = math.exp(random.uniform(*log_ratio))
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
if w <= width and h <= height:
i = random.randint(0, height - h)
j = random.randint(0, width - w)
img = img[i: i + h, j: j + w, :]
return cv2.resize(img, (224, 224), 0, 0, cv2.INTER_LINEAR)
return center_crop_resize(img)
def random_flip(img):
return img[:, ::-1, :]
def normalize_permute(img):
# transpose and convert to RGB from BGR
img = img.astype(np.float32).transpose((2, 0, 1))[::-1, ...]
mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
std = np.array([58.395, 57.120, 57.375], dtype=np.float32)
invstd = 1. / std
for v, m, s in zip(img, mean, invstd):
v.__isub__(m).__imul__(s)
return img
def compose(functions):
def process(sample):
img, label = sample
for fn in functions:
img = fn(img)
return img, label
return process
def image_folder(path):
valid_ext = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.webp')
classes = [d for d in os.listdir(path) if
os.path.isdir(os.path.join(path, d))]
classes.sort()
class_map = {cls: idx for idx, cls in enumerate(classes)}
samples = []
for dir in sorted(class_map.keys()):
d = os.path.join(path, dir)
for root, _, fnames in sorted(os.walk(d)):
for fname in sorted(fnames):
p = os.path.join(root, fname)
if os.path.splitext(p)[1].lower() in valid_ext:
samples.append((p, [class_map[dir]]))
return samples
class ImageNetDataset:
def __init__(self, path, mode='train'):
self.samples = image_folder(path)
self.mode = mode
if self.mode == 'train':
self.transform = compose([cv2.imread, random_crop_resize, random_flip,
normalize_permute])
else:
self.transform = compose([cv2.imread, center_crop_resize, normalize_permute])
def __getitem__(self, idx):
return self.transform(self.samples[idx])
def __len__(self):
return len(self.samples)
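# Usage sketch (the path is a placeholder): ImageNetDataset expects an
# ImageNet-style layout with one subdirectory per class under the given path.
#
#     val_ds = ImageNetDataset('/path/to/imagenet/val', mode='val')
#     img, label = val_ds[0]   # img: float32 CHW array of shape (3, 224, 224)
#     print(len(val_ds))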
@@ -137,6 +137,9 @@ class StaticGraphAdapter(object):
self._progs = {}
self._compiled_progs = {}
self._nranks = distributed.Env().nranks
self._local_rank = distributed.Env().local_rank
@property
def mode(self):
return self.model.mode
@@ -353,6 +356,12 @@ class StaticGraphAdapter(object):
metric_states = restore_flatten_list(rets[num_loss:], metric_splits)
metrics = []
for metric, state in zip(self.model._metrics, metric_states):
# cut off padding size
if self.model._dataset is not None and self._nranks > 1:
total_size = len(self.model._dataset)
samples = state[0].shape[0]
if metric.count[0] + samples > total_size:
state = [s[:total_size - metric.count[0], ...] for s in state]
metrics.append(metric.update(*state))
return (losses, metrics) if len(metrics) > 0 else losses
@@ -393,12 +402,12 @@ class StaticGraphAdapter(object):
lbls = self.model._labels if self.model._labels else []
inputs = [k.forward() for k in to_list(ins)]
labels = [k.forward() for k in to_list(lbls)]
self._label_vars[mode] = labels
outputs = to_list(self.model.forward(*inputs))
if mode != 'test':
if self.model._loss_function:
losses = self.model._loss_function(outputs, labels)
for metric in self.model._metrics:
metrics.append(to_list(metric.add_metric_op(outputs, labels)))
if mode == 'train' and self.model._optimizer:
self._loss_endpoint = fluid.layers.sum(losses)
if self._nranks > 1:
@@ -410,19 +419,24 @@ class StaticGraphAdapter(object):
self.model._optimizer = fleet.distributed_optimizer(self.model._optimizer, strategy=dist_strategy)
self.model._optimizer.minimize(self._loss_endpoint)
if self.mode != 'train':
if self._nranks > 1 and mode != 'train' and self.model._dataset is not None:
outputs = [distributed._all_gather(o, self._nranks) for o in outputs]
if self.mode != 'test':
label_vars = [distributed._all_gather(l, self._nranks) for l in label_vars]
if mode != 'test':
labels = [distributed._all_gather(l, self._nranks) for l in labels]
if mode != 'test':
for metric in self.model._metrics:
metrics.append(to_list(metric.add_metric_op(outputs, labels)))
if mode != 'train': # clone again to put it in test mode
prog = prog.clone(for_test=True)
self._input_vars[mode] = inputs
self._label_vars[mode] = labels
self._progs[mode] = prog
self._endpoints[mode] = {"output": outputs, "loss": losses, "metric": metrics}
def _compile_and_initialize(self, prog, mode):
compiled_prog = self._compiled_progs.get(mode, None)
if compiled_prog is not None:
@@ -457,10 +471,6 @@ class StaticGraphAdapter(object):
startup_prog = self._startup_prog._prune(uninitialized)
self._executor.run(startup_prog)
if self.mode == 'train' and self._lazy_load_optimizer:
self._load_optimizer(self._lazy_load_optimizer)
self._lazy_load_optimizer = None
if self._nranks < 2:
compiled_prog = fluid.CompiledProgram(prog)
else:
@@ -518,7 +528,7 @@ class DynamicGraphAdapter(object):
self.model.clear_gradients()
metrics = []
for metric in self.model._metrics:
metric_outs = metric.add_metric_op(outputs, to_list(labels))
metric_outs = metric.add_metric_op(to_list(outputs), to_list(labels))
m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
metrics.append(m)
return ([to_numpy(l) for l in losses], metrics) \
@@ -539,7 +549,15 @@ class DynamicGraphAdapter(object):
labels = [distributed._all_gather(l, self._nranks) for l in labels]
metrics = []
for metric in self.model._metrics:
metric_outs = metric.add_metric_op(outputs, labels)
# cut off padding value.
if self.model._dataset is not None and self._nranks > 1:
total_size = len(self.model._dataset)
samples = outputs[0].shape[0]
if metric.count[0] + samples > total_size:
outputs = [o[:total_size - metric.count[0]] for o in outputs]
labels = [l[:total_size - metric.count[0]] for l in labels]
metric_outs = metric.add_metric_op(to_list(outputs), labels)
m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
metrics.append(m)
@@ -637,6 +655,8 @@ class Model(fluid.dygraph.Layer):
self._device = None
self._device_ids = None
self._optimizer = None
self._dataset = None
self._distributed_sampler = None
if in_dygraph_mode():
self._adapter = DynamicGraphAdapter(self)
else:
@@ -664,6 +684,7 @@ class Model(fluid.dygraph.Layer):
metrics=None,
inputs=None,
labels=None,
dataset=None,
device=None,
device_ids=None):
"""
@@ -722,6 +743,7 @@ class Model(fluid.dygraph.Layer):
self._inputs = inputs
self._labels = labels
self._device = device
self._dataset = dataset
if device is None:
self._device = 'GPU' if fluid.is_compiled_with_cuda() else 'CPU'
self._device_ids = device_ids
from .resnet import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import os.path as osp
import shutil
import requests
import tqdm
import hashlib
import time
import logging
logger = logging.getLogger(__name__)
__all__ = [
'get_weights_path'
]
WEIGHTS_HOME = osp.expanduser("~/.cache/paddle/hapi/weights")
DOWNLOAD_RETRY_LIMIT = 3
def get_weights_path(url):
"""Get weights path from WEIGHT_HOME, if not exists,
download it from url.
"""
path, _ = get_path(url, WEIGHTS_HOME)
return path
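# Usage sketch: resolve (and download on first use) the resnet50 weights; the URL
# below is the one registered in the resnet model file of this commit.
#
#     path = get_weights_path(
#         'https://paddle-hapi.bj.bcebos.com/models/resnet50.pdparams')
#     # -> ~/.cache/paddle/hapi/weights/resnet50.pdparams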
def map_path(url, root_dir):
# parse path after download under root_dir
fname = osp.split(url)[-1]
fpath = fname
return osp.join(root_dir, fpath)
def get_path(url, root_dir, md5sum=None, check_exist=True):
""" Download from given url to root_dir.
if file or directory specified by url is exists under
root_dir, return the path directly, otherwise download
from url and decompress it, return the path.
url (str): download url
root_dir (str): root dir for downloading, it should be
WEIGHTS_HOME or DATASET_HOME
md5sum (str): md5 sum of download package
"""
# parse path after download to decompress under root_dir
fullpath = map_path(url, root_dir)
exist_flag = False
if osp.exists(fullpath) and check_exist:
exist_flag = True
logger.info("Found {}".format(fullpath))
else:
if int(os.getenv("PADDLE_TRAINER_ID", "0")) == 0:
fullpath = _download(url, root_dir, md5sum)
else:
while not os.path.exists(fullpath):
time.sleep(1)
return fullpath, exist_flag
def _download(url, path, md5sum=None):
"""
Download from url, save to path.
url (str): download url
path (str): download to given path
"""
if not osp.exists(path):
os.makedirs(path)
fname = osp.split(url)[-1]
fullname = osp.join(path, fname)
retry_cnt = 0
while not (osp.exists(fullname) and _md5check(fullname, md5sum)):
if retry_cnt < DOWNLOAD_RETRY_LIMIT:
retry_cnt += 1
else:
raise RuntimeError("Download from {} failed. "
"Retry limit reached".format(url))
logger.info("Downloading {} from {}".format(fname, url))
req = requests.get(url, stream=True)
if req.status_code != 200:
raise RuntimeError("Downloading from {} failed with code "
"{}!".format(url, req.status_code))
# To protect against interrupted downloads, download to
# tmp_fullname first, then move tmp_fullname to fullname
# after the download finishes
tmp_fullname = fullname + "_tmp"
total_size = req.headers.get('content-length')
with open(tmp_fullname, 'wb') as f:
if total_size:
for chunk in tqdm.tqdm(
req.iter_content(chunk_size=1024),
total=(int(total_size) + 1023) // 1024,
unit='KB'):
f.write(chunk)
else:
for chunk in req.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
shutil.move(tmp_fullname, fullname)
return fullname
def _md5check(fullname, md5sum=None):
if md5sum is None:
return True
logger.info("File {} md5 checking...".format(fullname))
md5 = hashlib.md5()
with open(fullname, 'rb') as f:
for chunk in iter(lambda: f.read(4096), b""):
md5.update(chunk)
calc_md5sum = md5.hexdigest()
if calc_md5sum != md5sum:
logger.info("File {} md5 check failed, {}(calc) != "
"{}(base)".format(fullname, calc_md5sum, md5sum))
return False
return True
from __future__ import division
from __future__ import print_function
import math
import paddle.fluid as fluid
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from paddle.fluid.dygraph.container import Sequential
from model import Model
from .download import get_weights_path
__all__ = ['ResNet', 'resnet50', 'resnet101', 'resnet152']
model_urls = {
'resnet50': 'https://paddle-hapi.bj.bcebos.com/models/resnet50.pdparams'
}
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
filter_size,
stride=1,
groups=1,
act=None):
super(ConvBNLayer, self).__init__()
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
bias_attr=False)
self._batch_norm = BatchNorm(num_filters, act=act)
def forward(self, inputs):
x = self._conv(inputs)
x = self._batch_norm(x)
return x
class BasicBlock(fluid.dygraph.Layer):
expansion = 1
def __init__(self,
num_channels,
num_filters,
stride,
shortcut=True):
super(BasicBlock, self).__init__()
# two 3x3 conv-bn layers; the first downsamples when stride != 1,
# mirroring the structure of BottleneckBlock below
self.conv0 = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu')
self.conv1 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
act=None)
if not shortcut:
# projection shortcut to match the shape of the residual branch
self.short = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=1,
stride=stride)
self.shortcut = shortcut
self._num_channels_out = num_filters
def forward(self, inputs):
x = self.conv0(inputs)
conv1 = self.conv1(x)
short = inputs if self.shortcut else self.short(inputs)
x = fluid.layers.elementwise_add(x=short, y=conv1)
layer_helper = LayerHelper(self.full_name(), act='relu')
return layer_helper.append_activation(x)
class BottleneckBlock(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
stride,
shortcut=True):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=1,
act='relu')
self.conv1 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu')
self.conv2 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters * 4,
filter_size=1,
act=None)
if not shortcut:
self.short = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters * 4,
filter_size=1,
stride=stride)
self.shortcut = shortcut
self._num_channels_out = num_filters * 4
def forward(self, inputs):
x = self.conv0(inputs)
conv1 = self.conv1(x)
conv2 = self.conv2(conv1)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
x = fluid.layers.elementwise_add(x=short, y=conv2)
layer_helper = LayerHelper(self.full_name(), act='relu')
return layer_helper.append_activation(x)
# return fluid.layers.relu(x)
class ResNet(Model):
def __init__(self, Block, depth=50, num_classes=1000):
super(ResNet, self).__init__()
layer_config = {
50: [3, 4, 6, 3],
101: [3, 4, 23, 3],
152: [3, 8, 36, 3],
}
assert depth in layer_config.keys(), \
"supported depths are {} but input depth is {}".format(
layer_config.keys(), depth)
layers = layer_config[depth]
num_in = [64, 256, 512, 1024]
num_out = [64, 128, 256, 512]
self.conv = ConvBNLayer(
num_channels=3,
num_filters=64,
filter_size=7,
stride=2,
act='relu')
self.pool = Pool2D(
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')
self.layers = []
for idx, num_blocks in enumerate(layers):
blocks = []
shortcut = False
for b in range(num_blocks):
block = Block(
num_channels=num_in[idx] if b == 0 else num_out[idx] * 4,
num_filters=num_out[idx],
stride=2 if b == 0 and idx != 0 else 1,
shortcut=shortcut)
blocks.append(block)
shortcut = True
layer = self.add_sublayer(
"layer_{}".format(idx),
Sequential(*blocks))
self.layers.append(layer)
self.global_pool = Pool2D(
pool_size=7, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(2048 * 1.0)
self.fc_input_dim = num_out[-1] * 4 * 1 * 1
self.fc = Linear(self.fc_input_dim,
num_classes,
act='softmax',
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(
-stdv, stdv)))
def forward(self, inputs):
x = self.conv(inputs)
x = self.pool(x)
for layer in self.layers:
x = layer(x)
x = self.global_pool(x)
x = fluid.layers.reshape(x, shape=[-1, self.fc_input_dim])
x = self.fc(x)
return x
def _resnet(arch, Block, depth, pretrained):
model = ResNet(Block, depth)
if pretrained:
assert arch in model_urls, "{} model does not have a pretrained model now, you should set pretrained=False".format(arch)
weight_path = get_weights_path(model_urls[arch])
assert weight_path.endswith('.pdparams'), "suffix of weight must be .pdparams"
model.load(weight_path[:-9])
return model
def resnet50(pretrained=False):
return _resnet('resnet50', BottleneckBlock, 50, pretrained)
def resnet101(pretrained=False):
return _resnet('resnet101', BottleneckBlock, 101, pretrained)
def resnet152(pretrained=False):
return _resnet('resnet152', BottleneckBlock, 152, pretrained)
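# Usage sketch (hedged; mirrors the evaluation script earlier in this commit):
#
#     with fluid.dygraph.guard(fluid.CUDAPlace(0)):
#         model = resnet50(pretrained=True)  # downloads and loads ImageNet weights
#         # model.prepare(...) and evaluation then proceed as in the imagenet scripts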