Commit bf2d8599 authored by LielinJiang

add predict

Parent a3e6b21e
@@ -25,7 +25,6 @@ from paddle.fluid.layers import collective
from paddle.fluid.dygraph.parallel import ParallelEnv, ParallelStrategy
from paddle.fluid.io import BatchSampler
_parallel_context_initialized = False
@@ -67,7 +66,8 @@ class DistributedBatchSampler(BatchSampler):
self.nranks = ParallelEnv().nranks
self.local_rank = ParallelEnv().local_rank
self.epoch = 0
-        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.nranks))
+        self.num_samples = int(
+            math.ceil(len(self.dataset) * 1.0 / self.nranks))
self.total_size = self.num_samples * self.nranks
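With this scheme every rank draws exactly the same number of indices even when the dataset size is not divisible by the rank count; the surplus slots in total_size are later filled by reusing indices. A standalone sketch of the arithmetic, with sizes made up for illustration:

import math

# Hypothetical sizes: 10 samples shared across 4 ranks.
dataset_len, nranks = 10, 4
num_samples = int(math.ceil(dataset_len * 1.0 / nranks))  # 3 indices per rank
total_size = num_samples * nranks                         # 12, i.e. 2 padding slots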
def __iter__(self):
@@ -78,9 +78,28 @@ class DistributedBatchSampler(BatchSampler):
if self.shuffle:
np.random.RandomState(self.epoch).shuffle(indices)
self.epoch += 1
-        # subsample
-        indices = indices[self.local_rank * self.num_samples:
-                          (self.local_rank + 1) * self.num_samples]
+        def _get_indices_by_batch_size(indices):
+            subsampled_indices = []
+            last_batch_size = self.total_size % (self.batch_size * self.nranks)
+            assert last_batch_size % self.nranks == 0
+            last_local_batch_size = last_batch_size // self.nranks
+
+            for i in range(self.local_rank * self.batch_size,
+                           len(indices) - last_batch_size,
+                           self.batch_size * self.nranks):
+                subsampled_indices.extend(indices[i:i + self.batch_size])
+
+            indices = indices[len(indices) - last_batch_size:]
+            subsampled_indices.extend(indices[
+                self.local_rank * last_local_batch_size:(
+                    self.local_rank + 1) * last_local_batch_size])
+            return subsampled_indices
+
+        if self.nranks > 1:
+            indices = _get_indices_by_batch_size(indices)
assert len(indices) == self.num_samples
_sample_iter = iter(indices)
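The new _get_indices_by_batch_size replaces the old contiguous split: instead of giving each rank one contiguous block of num_samples indices, it deals indices out batch by batch, so rank k takes the k-th batch from every group of nranks * batch_size indices, and the short tail group is divided evenly. A standalone re-implementation with a worked example (the function and parameter names here are illustrative, not part of the patch):

def subsample(indices, nranks, local_rank, batch_size):
    total = len(indices)
    last = total % (batch_size * nranks)  # size of the incomplete tail group
    assert last % nranks == 0
    last_local = last // nranks
    picked = []
    # Take this rank's batch from every full group of nranks * batch_size indices.
    for i in range(local_rank * batch_size, total - last, batch_size * nranks):
        picked.extend(indices[i:i + batch_size])
    # Split the tail group evenly across ranks.
    tail = indices[total - last:]
    picked.extend(tail[local_rank * last_local:(local_rank + 1) * last_local])
    return picked

# 12 indices, 2 ranks, batch_size 4: rank 0 gets [0, 1, 2, 3, 8, 9],
# rank 1 gets [4, 5, 6, 7, 10, 11]; both get num_samples == 6 indices.
print(subsample(list(range(12)), nranks=2, local_rank=0, batch_size=4))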
@@ -103,7 +122,8 @@ class DistributedBatchSampler(BatchSampler):
def _all_gather(x, nranks, ring_id=0, use_calc_stream=True):
-    return collective._c_allgather(x, nranks, ring_id=ring_id, use_calc_stream=use_calc_stream)
+    return collective._c_allgather(
+        x, nranks, ring_id=ring_id, use_calc_stream=use_calc_stream)
def wait_server_ready(endpoints):
@@ -114,8 +134,7 @@ def wait_server_ready(endpoints):
for ep in endpoints:
ip_port = ep.split(":")
with contextlib.closing(
-                socket.socket(socket.AF_INET,
-                              socket.SOCK_STREAM)) as sock:
+                socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
sock.settimeout(2)
result = sock.connect_ex((ip_port[0], int(ip_port[1])))
if result != 0:
@@ -127,8 +146,8 @@ def wait_server_ready(endpoints):
break
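wait_server_ready blocks until every trainer endpoint accepts TCP connections; connect_ex returns an error code instead of raising, and 0 means success. A condensed standalone sketch of the loop (the retry interval is assumed):

import contextlib
import socket
import time

def wait_server_ready(endpoints):
    # Poll each "ip:port" endpoint until a TCP connect succeeds on all of them.
    while True:
        not_ready = []
        for ep in endpoints:
            ip, port = ep.split(":")
            with contextlib.closing(
                    socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
                sock.settimeout(2)
                if sock.connect_ex((ip, int(port))) != 0:
                    not_ready.append(ep)
        if not not_ready:
            break
        time.sleep(3)  # assumed back-off before re-polling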
-def init_communicator(program, rank, nranks, wait_port,
-                      current_endpoint, endpoints):
+def init_communicator(program, rank, nranks, wait_port, current_endpoint,
+                      endpoints):
if nranks < 2:
return
other_endpoints = endpoints[:]
@@ -178,11 +197,14 @@ def prepare_distributed_context(place=None):
global _parallel_context_initialized
-    if not _parallel_context_initialized and isinstance(place, fluid.CUDAPlace):
+    if not _parallel_context_initialized and isinstance(place,
+                                                        fluid.CUDAPlace):
def _init_context():
communicator_prog = fluid.Program()
-            init_communicator(communicator_prog, strategy.local_rank, strategy.nranks,
-                              True, strategy.current_endpoint, strategy.trainer_endpoints)
+            init_communicator(communicator_prog, strategy.local_rank,
+                              strategy.nranks, True, strategy.current_endpoint,
+                              strategy.trainer_endpoints)
exe = fluid.Executor(place)
exe.run(communicator_prog)
......
@@ -20,6 +20,7 @@ import pickle
import numpy as np
import six
import warnings
+import tqdm
from collections import Iterable
from paddle import fluid
@@ -587,10 +588,8 @@ class DynamicGraphAdapter(object):
samples = outputs[0].shape[0]
current_count = self._merge_count.get(self.mode + '_total', 0)
if current_count + samples >= total_size:
-                    outputs = [
-                        o[:total_size - metric.count[0]] for o in outputs
-                    ]
-                    labels = [l[:total_size - metric.count[0]] for l in labels]
+                    outputs = [o[:total_size - current_count] for o in outputs]
+                    labels = [l[:total_size - current_count] for l in labels]
self._merge_count[self.mode + '_total'] = 0
self._merge_count[self.mode +
'_batch'] = total_size - current_count
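The change here fixes how the final, padded batch is truncated: the number of real samples left is total_size minus the running count kept alongside the sampler (current_count), not the metric's own counter. For example:

# Hypothetical sizes: 100 real samples, merged batches of 32, so the 4th
# merged batch holds only 100 - 96 = 4 real rows; the rest are padding
# duplicates introduced by the distributed sampler.
total_size, current_count, samples = 100, 96, 32
keep = total_size - current_count  # 4: only these rows update the metric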
@@ -612,8 +611,9 @@ class DynamicGraphAdapter(object):
self.mode = 'test'
inputs = [to_variable(x) for x in to_list(inputs)]
outputs = self.model.forward(*inputs)
-        if self._nranks > 2:
+        if self._nranks > 1 and isinstance(self.model._place,
+                                           fluid.CUDAPlace):
outputs = [_all_gather(o, self._nranks) for o in to_list(outputs)]
return [to_numpy(o) for o in to_list(outputs)]
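Two fixes land in test() here: the gather now triggers for any multi-GPU run (the old `> 2` guard skipped the common 2-GPU case), and it is limited to CUDA places, where the NCCL-backed allgather is available. A CPU stand-in for what the gather does logically (shapes assumed for illustration):

import numpy as np

# Every rank contributes its local batch and receives the concatenation,
# e.g. four per-rank outputs of shape (8, 10) become (32, 10) on every rank.
local_outputs = [np.zeros((8, 10)) for _rank in range(4)]
gathered = np.concatenate(local_outputs, axis=0)  # shape (32, 10)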
def parameters(self, *args, **kwargs):
@@ -1012,12 +1012,13 @@ class Model(fluid.dygraph.Layer):
FIXME: add more comments and usage
Args:
eval_data (Dataset|DataLoader): An iterable data loader is used for
-                evaluation at the end of epoch. If None, will not do evaluation.
-                An instance of paddle.fluid.io.Dataset or paddle.fluid.io.Dataloader
-                is recomended.
+                evaluation. An instance of paddle.fluid.io.Dataset or
+                paddle.fluid.io.DataLoader is recommended.
            batch_size (int): Integer number. The batch size of train_data and eval_data.
                When train_data and eval_data are both instances of DataLoader, this
                parameter will be ignored.
+            log_freq (int): The frequency, in number of steps, at which the
+                training logs are printed.
verbose (int): The verbosity mode, should be 0, 1, or 2.
0 = silent, 1 = progress bar, 2 = one line per epoch.
            num_workers (int): the number of subprocesses used to load data, 0 for no subprocess
@@ -1043,10 +1044,8 @@ class Model(fluid.dygraph.Layer):
feed_list=feed_list,
num_workers=num_workers,
return_list=True)
-        elif eval_data is not None:
-            eval_loader = eval_data
        else:
-            eval_loader = None
+            eval_loader = eval_data
self._test_dataloader = eval_loader
metrics_name = self._metrics_name()
@@ -1069,6 +1068,74 @@ class Model(fluid.dygraph.Layer):
self._test_dataloader = None
eval_result = {}
for k in self._metrics_name():
eval_result[k] = logs[k]
return eval_result
+    def predict(self, test_data, batch_size=1, num_workers=0, callbacks=None):
+        """
+        FIXME: add more comments and usage
+        Args:
+            test_data (Dataset|DataLoader): An iterable data loader is used for
+                prediction. An instance of paddle.fluid.io.Dataset or
+                paddle.fluid.io.DataLoader is recommended.
+            batch_size (int): Integer number. The batch size of test_data.
+                When test_data is an instance of DataLoader, this parameter
+                will be ignored.
+            num_workers (int): the number of subprocesses used to load data, 0 for
+                no subprocess used and loading data in the main process. When
+                test_data is an instance of DataLoader, this parameter will be
+                ignored.
+            callbacks (Callback|None): A list of `Callback` instances to apply
+                during prediction. If None, `ProgBarLogger` and `ModelCheckpoint`
+                are automatically inserted.
+        """
+
+        if fluid.in_dygraph_mode():
+            feed_list = None
+        else:
+            feed_list = [x.forward() for x in self._inputs + self._labels]
+
+        if test_data is not None and isinstance(test_data, Dataset):
+            test_sampler = DistributedBatchSampler(
+                test_data, batch_size=batch_size)
+            test_loader = DataLoader(
+                test_data,
+                batch_sampler=test_sampler,
+                places=self._place,
+                feed_list=feed_list,
+                num_workers=num_workers,
+                return_list=True)
+        else:
+            test_loader = test_data
+
+        self._test_dataloader = test_loader
+
+        loader = test_loader
+        if not isinstance(test_loader, Iterable):
+            loader = test_loader()
+
+        outputs = None
+        for data in tqdm.tqdm(loader):
+            if not fluid.in_dygraph_mode():
+                data = data[0]
+
+            outs = self.test(*data)
+
+            if outputs is None:
+                outputs = outs
+            else:
+                outputs = [
+                    np.vstack([x, outs[i]]) for i, x in enumerate(outputs)
+                ]
+
+        self._test_dataloader = None
+        if test_loader is not None and self._adapter._nranks > 1 \
+                and isinstance(test_loader, DataLoader):
+            outputs = [o[:len(test_loader.dataset)] for o in outputs]
+        return outputs
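In short, predict() wraps a Dataset in a DistributedBatchSampler-backed DataLoader, stacks per-batch outputs with np.vstack, and for multi-GPU runs trims the allgather padding so that exactly len(dataset) rows come back. A usage sketch based on the test further below (the prepare() setup is elided and assumed to match that test):

import numpy as np

model = MNIST()                    # prepared as in the test below
test_dataset = TestMnistDataset()

output = model.predict(test_dataset, batch_size=64)
# One numpy array per model output, stacked across batches:
assert output[0].shape[0] == len(test_dataset)
pred_labels = np.argmax(output[0], -1)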
def set_eval_data(self, eval_data):
"""
Args:
......
@@ -139,6 +139,26 @@ class MyCrossEntropy(Loss):
return [loss1, loss2]
+class TestMnistDataset(MnistDataset):
+    def __init__(self):
+        super(TestMnistDataset, self).__init__(mode='test')
+
+    def __getitem__(self, idx):
+        return self.images[idx],
+
+    def __len__(self):
+        return len(self.images)
+
+
+def get_predict_accuracy(pred, gt):
+    pred = np.argmax(pred, -1)
+    gt = np.array(gt)
+
+    correct = pred[:, np.newaxis] == gt
+
+    return np.sum(correct) / correct.shape[0]
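Because pred is reduced by np.argmax to shape (N,) and the MNIST labels are stored as (N, 1) rows, the pred[:, np.newaxis] == gt comparison broadcasts to an (N, 1) boolean mask whose mean is the accuracy. A tiny self-check with fabricated logits, for illustration only:

import numpy as np

pred = np.array([[0.1, 0.9], [0.8, 0.2]])  # argmax -> [1, 0]
gt = [[1], [1]]
assert get_predict_accuracy(pred, gt) == 0.5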
class TestModel(unittest.TestCase):
def fit(self, dynamic, is_mlp=False):
device = set_device('gpu')
@@ -152,6 +172,7 @@ class TestModel(unittest.TestCase):
train_dataset = MnistDataset(mode='train')
val_dataset = MnistDataset(mode='test')
+        test_dataset = TestMnistDataset()
model = MNIST() if not is_mlp else MLP()
optim = fluid.optimizer.Momentum(
@@ -166,7 +187,15 @@ class TestModel(unittest.TestCase):
batch_size=batch_size,
callbacks=cbk)
-        model.evaluate(val_dataset, batch_size=batch_size)
+        eval_result = model.evaluate(val_dataset, batch_size=batch_size)
+        output = model.predict(test_dataset, batch_size=batch_size)
+
+        np.testing.assert_equal(output[0].shape[0], len(test_dataset))
+
+        acc = get_predict_accuracy(output[0], val_dataset.labels)
+
+        np.testing.assert_allclose(acc, eval_result['acc'])
def test_fit_static(self):
self.fit(False)
......