提交 37a67272 编写于 作者: Q qingqing01

Merge branch 'master' of https://github.com/PaddlePaddle/hapi into cyclegan

...@@ -242,6 +242,12 @@ class ProgBarLogger(Callback): ...@@ -242,6 +242,12 @@ class ProgBarLogger(Callback):
samples = logs.get('batch_size', 1) samples = logs.get('batch_size', 1)
self.evaled_samples += samples self.evaled_samples += samples
if self.eval_step % self.log_freq == 0 and self.verbose and ParallelEnv(
).local_rank == 0:
# if steps is not None, last step will update in on_epoch_end
if self.eval_steps and self.eval_step < self.eval_steps:
self._updates(logs, 'eval')
def on_eval_end(self, logs=None): def on_eval_end(self, logs=None):
logs = logs or {} logs = logs or {}
if self.verbose and ParallelEnv().local_rank == 0: if self.verbose and ParallelEnv().local_rank == 0:
......
...@@ -25,7 +25,6 @@ from paddle.fluid.layers import collective ...@@ -25,7 +25,6 @@ from paddle.fluid.layers import collective
from paddle.fluid.dygraph.parallel import ParallelEnv, ParallelStrategy from paddle.fluid.dygraph.parallel import ParallelEnv, ParallelStrategy
from paddle.fluid.io import BatchSampler from paddle.fluid.io import BatchSampler
_parallel_context_initialized = False _parallel_context_initialized = False
...@@ -67,7 +66,8 @@ class DistributedBatchSampler(BatchSampler): ...@@ -67,7 +66,8 @@ class DistributedBatchSampler(BatchSampler):
self.nranks = ParallelEnv().nranks self.nranks = ParallelEnv().nranks
self.local_rank = ParallelEnv().local_rank self.local_rank = ParallelEnv().local_rank
self.epoch = 0 self.epoch = 0
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.nranks)) self.num_samples = int(
math.ceil(len(self.dataset) * 1.0 / self.nranks))
self.total_size = self.num_samples * self.nranks self.total_size = self.num_samples * self.nranks
def __iter__(self): def __iter__(self):
...@@ -78,9 +78,28 @@ class DistributedBatchSampler(BatchSampler): ...@@ -78,9 +78,28 @@ class DistributedBatchSampler(BatchSampler):
if self.shuffle: if self.shuffle:
np.random.RandomState(self.epoch).shuffle(indices) np.random.RandomState(self.epoch).shuffle(indices)
self.epoch += 1 self.epoch += 1
# subsample # subsample
indices = indices[self.local_rank * self.num_samples: def _get_indices_by_batch_size(indices):
(self.local_rank + 1) * self.num_samples] subsampled_indices = []
last_batch_size = self.total_size % (self.batch_size * self.nranks)
assert last_batch_size % self.nranks == 0
last_local_batch_size = last_batch_size // self.nranks
for i in range(self.local_rank * self.batch_size,
len(indices) - last_batch_size,
self.batch_size * self.nranks):
subsampled_indices.extend(indices[i:i + self.batch_size])
indices = indices[len(indices) - last_batch_size:]
subsampled_indices.extend(indices[
self.local_rank * last_local_batch_size:(
self.local_rank + 1) * last_local_batch_size])
return subsampled_indices
if self.nranks > 1:
indices = _get_indices_by_batch_size(indices)
assert len(indices) == self.num_samples assert len(indices) == self.num_samples
_sample_iter = iter(indices) _sample_iter = iter(indices)
...@@ -103,7 +122,8 @@ class DistributedBatchSampler(BatchSampler): ...@@ -103,7 +122,8 @@ class DistributedBatchSampler(BatchSampler):
def _all_gather(x, nranks, ring_id=0, use_calc_stream=True): def _all_gather(x, nranks, ring_id=0, use_calc_stream=True):
return collective._c_allgather(x, nranks, ring_id=ring_id, use_calc_stream=use_calc_stream) return collective._c_allgather(
x, nranks, ring_id=ring_id, use_calc_stream=use_calc_stream)
def wait_server_ready(endpoints): def wait_server_ready(endpoints):
...@@ -114,8 +134,7 @@ def wait_server_ready(endpoints): ...@@ -114,8 +134,7 @@ def wait_server_ready(endpoints):
for ep in endpoints: for ep in endpoints:
ip_port = ep.split(":") ip_port = ep.split(":")
with contextlib.closing( with contextlib.closing(
socket.socket(socket.AF_INET, socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
socket.SOCK_STREAM)) as sock:
sock.settimeout(2) sock.settimeout(2)
result = sock.connect_ex((ip_port[0], int(ip_port[1]))) result = sock.connect_ex((ip_port[0], int(ip_port[1])))
if result != 0: if result != 0:
...@@ -127,8 +146,8 @@ def wait_server_ready(endpoints): ...@@ -127,8 +146,8 @@ def wait_server_ready(endpoints):
break break
def init_communicator(program, rank, nranks, wait_port, def init_communicator(program, rank, nranks, wait_port, current_endpoint,
current_endpoint, endpoints): endpoints):
if nranks < 2: if nranks < 2:
return return
other_endpoints = endpoints[:] other_endpoints = endpoints[:]
...@@ -166,7 +185,7 @@ def prepare_distributed_context(place=None): ...@@ -166,7 +185,7 @@ def prepare_distributed_context(place=None):
if place is None: if place is None:
place = fluid.CUDAPlace(ParallelEnv().dev_id) if ParallelEnv().nranks > 1 \ place = fluid.CUDAPlace(ParallelEnv().dev_id) if ParallelEnv().nranks > 1 \
else fluid.CUDAPlace(0) else fluid.CUDAPlace(0)
strategy = ParallelStrategy() strategy = ParallelStrategy()
strategy.nranks = ParallelEnv().nranks strategy.nranks = ParallelEnv().nranks
strategy.local_rank = ParallelEnv().local_rank strategy.local_rank = ParallelEnv().local_rank
...@@ -178,11 +197,14 @@ def prepare_distributed_context(place=None): ...@@ -178,11 +197,14 @@ def prepare_distributed_context(place=None):
global _parallel_context_initialized global _parallel_context_initialized
if not _parallel_context_initialized and isinstance(place, fluid.CUDAPlace): if not _parallel_context_initialized and isinstance(place,
fluid.CUDAPlace):
def _init_context(): def _init_context():
communicator_prog = fluid.Program() communicator_prog = fluid.Program()
init_communicator(communicator_prog, strategy.local_rank, strategy.nranks, init_communicator(communicator_prog, strategy.local_rank,
True, strategy.current_endpoint, strategy.trainer_endpoints) strategy.nranks, True, strategy.current_endpoint,
strategy.trainer_endpoints)
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(communicator_prog) exe.run(communicator_prog)
...@@ -197,4 +219,4 @@ def prepare_distributed_context(place=None): ...@@ -197,4 +219,4 @@ def prepare_distributed_context(place=None):
assert ("Only support CUDAPlace for now.") assert ("Only support CUDAPlace for now.")
_parallel_context_initialized = True _parallel_context_initialized = True
return strategy return strategy
\ No newline at end of file
...@@ -20,6 +20,7 @@ import pickle ...@@ -20,6 +20,7 @@ import pickle
import numpy as np import numpy as np
import six import six
import warnings import warnings
import tqdm
from collections import Iterable from collections import Iterable
from paddle import fluid from paddle import fluid
...@@ -28,6 +29,7 @@ from paddle.fluid.executor import global_scope ...@@ -28,6 +29,7 @@ from paddle.fluid.executor import global_scope
from paddle.fluid.io import is_belong_to_optimizer from paddle.fluid.io import is_belong_to_optimizer
from paddle.fluid.dygraph.base import to_variable from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph.parallel import ParallelEnv from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.layers.utils import flatten
from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy
from paddle.fluid.incubate.fleet.base import role_maker from paddle.fluid.incubate.fleet.base import role_maker
from paddle.fluid.io import DataLoader, Dataset from paddle.fluid.io import DataLoader, Dataset
...@@ -413,13 +415,7 @@ class StaticGraphAdapter(object): ...@@ -413,13 +415,7 @@ class StaticGraphAdapter(object):
losses = [] losses = []
metrics = [] metrics = []
with fluid.program_guard(prog, self._startup_prog): with fluid.program_guard(prog, self._startup_prog):
if isinstance(self.model._inputs, dict): ins = self.model._inputs
ins = [
self.model._inputs[n]
for n in extract_args(self.model.forward) if n != 'self'
]
else:
ins = self.model._inputs
lbls = self.model._labels if self.model._labels else [] lbls = self.model._labels if self.model._labels else []
inputs = [k.forward() for k in to_list(ins)] inputs = [k.forward() for k in to_list(ins)]
labels = [k.forward() for k in to_list(lbls)] labels = [k.forward() for k in to_list(lbls)]
...@@ -587,10 +583,8 @@ class DynamicGraphAdapter(object): ...@@ -587,10 +583,8 @@ class DynamicGraphAdapter(object):
samples = outputs[0].shape[0] samples = outputs[0].shape[0]
current_count = self._merge_count.get(self.mode + '_total', 0) current_count = self._merge_count.get(self.mode + '_total', 0)
if current_count + samples >= total_size: if current_count + samples >= total_size:
outputs = [ outputs = [o[:total_size - current_count] for o in outputs]
o[:total_size - metric.count[0]] for o in outputs labels = [l[:total_size - current_count] for l in labels]
]
labels = [l[:total_size - metric.count[0]] for l in labels]
self._merge_count[self.mode + '_total'] = 0 self._merge_count[self.mode + '_total'] = 0
self._merge_count[self.mode + self._merge_count[self.mode +
'_batch'] = total_size - current_count '_batch'] = total_size - current_count
...@@ -612,8 +606,9 @@ class DynamicGraphAdapter(object): ...@@ -612,8 +606,9 @@ class DynamicGraphAdapter(object):
self.mode = 'test' self.mode = 'test'
inputs = [to_variable(x) for x in to_list(inputs)] inputs = [to_variable(x) for x in to_list(inputs)]
outputs = self.model.forward(*inputs) outputs = self.model.forward(*inputs)
if self._nranks > 2: if self._nranks > 1 and isinstance(self.model._place, fluid.CUDAPlace):
outputs = [_all_gather(o, self._nranks) for o in to_list(outputs)] outputs = [_all_gather(o, self._nranks) for o in to_list(outputs)]
return [to_numpy(o) for o in to_list(outputs)] return [to_numpy(o) for o in to_list(outputs)]
def parameters(self, *args, **kwargs): def parameters(self, *args, **kwargs):
...@@ -696,7 +691,6 @@ class Model(fluid.dygraph.Layer): ...@@ -696,7 +691,6 @@ class Model(fluid.dygraph.Layer):
self._loss_weights = None self._loss_weights = None
self._optimizer = None self._optimizer = None
self._device = None self._device = None
self._device_ids = None
self._optimizer = None self._optimizer = None
self._test_dataloader = None self._test_dataloader = None
...@@ -794,8 +788,7 @@ class Model(fluid.dygraph.Layer): ...@@ -794,8 +788,7 @@ class Model(fluid.dygraph.Layer):
metrics=None, metrics=None,
inputs=None, inputs=None,
labels=None, labels=None,
device=None, device=None):
device_ids=None):
""" """
FIXME: add comments FIXME: add comments
Args: Args:
...@@ -818,17 +811,6 @@ class Model(fluid.dygraph.Layer): ...@@ -818,17 +811,6 @@ class Model(fluid.dygraph.Layer):
device (str|None): specify device type, 'CPU' or 'GPU'. device (str|None): specify device type, 'CPU' or 'GPU'.
If None, automatically select device according to If None, automatically select device according to
installation package version. installation package version.
device_ids (list[int]|None): specify device index. If None,
the available device will be obtained from the environment
variable when the model is executed: If the GPU is used, the
currently available device ID is obtained from the environment
variable FLAGS_selected_gpus or CUDA_VISIBLE_DEVICES when the
model is executed; CPU, when the model is executed,
the currently available CPU number is obtained from the
environment variable CPU_NUM. For example, export CPU_NUM=4,
if the environment variable is not set, the executor will add
the variable to the environment variable and set its value to 1.
The default is None.
""" """
if isinstance(device, fluid.CUDAPlace) or \ if isinstance(device, fluid.CUDAPlace) or \
...@@ -878,8 +860,10 @@ class Model(fluid.dygraph.Layer): ...@@ -878,8 +860,10 @@ class Model(fluid.dygraph.Layer):
metric.__class__.__name__) metric.__class__.__name__)
self._metrics = to_list(metrics) self._metrics = to_list(metrics)
self._inputs = inputs self._inputs = to_list(inputs) if not isinstance(inputs, dict) else [
self._labels = labels inputs[n] for n in extract_args(self.forward) if n != 'self'
]
self._labels = to_list(labels)
if not in_dygraph_mode(): if not in_dygraph_mode():
self._adapter.prepare() self._adapter.prepare()
...@@ -916,7 +900,7 @@ class Model(fluid.dygraph.Layer): ...@@ -916,7 +900,7 @@ class Model(fluid.dygraph.Layer):
eval_freq (int): The frequency, in number of epochs, an evalutation eval_freq (int): The frequency, in number of epochs, an evalutation
is performed. is performed.
log_freq (int): The frequency, in number of steps, the training logs log_freq (int): The frequency, in number of steps, the training logs
is printed. are printed.
save_dir(str|None): The directory to save checkpoint during training. save_dir(str|None): The directory to save checkpoint during training.
If None, will not save checkpoint. If None, will not save checkpoint.
save_freq (int): The frequency, in number of epochs, to save checkpoint. save_freq (int): The frequency, in number of epochs, to save checkpoint.
...@@ -989,71 +973,256 @@ class Model(fluid.dygraph.Layer): ...@@ -989,71 +973,256 @@ class Model(fluid.dygraph.Layer):
verbose=verbose, verbose=verbose,
metrics=self._metrics_name(), ) metrics=self._metrics_name(), )
def _run_one_epoch(data_loader, callbacks, mode):
size = len(data_loader) if hasattr(data_loader,
'__len__') else None
logs = {
'steps': size,
'metrics_name': metrics_name,
}
for step, data in enumerate(data_loader):
if not fluid.in_dygraph_mode():
data = data[0]
batch_size = data[0].shape()[0]
else:
batch_size = data[0].shape[0]
cbks.on_batch_begin(mode, step, logs)
if mode == 'train':
outs = self.train(*data)
else:
outs = self.eval(*data)
# losses
loss = outs[0] if self._metrics else outs
metrics = [[l[0] for l in loss]]
# metrics
for metric in self._metrics:
res = metric.accumulate()
metrics.extend(to_list(res))
assert len(metrics_name) == len(metrics)
for k, v in zip(metrics_name, metrics):
logs[k] = v
logs['step'] = step
if mode == 'train' or self._adapter._merge_count.get(
mode + '_batch', 0) <= 0:
logs['batch_size'] = batch_size * ParallelEnv().nranks
else:
logs['batch_size'] = self._adapter._merge_count[mode +
'_batch']
cbks.on_batch_end(mode, step, logs)
self._reset_metrics()
return logs
cbks.on_begin('train') cbks.on_begin('train')
for epoch in range(epochs): for epoch in range(epochs):
cbks.on_epoch_begin(epoch)
# FIXME: adapt to DataLoader # FIXME: adapt to DataLoader
loader = train_loader loader = train_loader
if not isinstance(train_loader, Iterable): if not isinstance(train_loader, Iterable):
loader = train_loader() loader = train_loader()
logs = _run_one_epoch(loader, cbks, 'train') logs = self._run_one_epoch(
cbks.on_epoch_end(epoch, logs) loader, cbks, 'train', metrics_name, epoch=epoch)
if do_eval and epoch % eval_freq == 0: if do_eval and epoch % eval_freq == 0:
cbks.on_begin('eval', logs)
# FIXME: adapt to DataLoader # FIXME: adapt to DataLoader
loader = eval_loader loader = eval_loader
if not isinstance(eval_loader, Iterable): if not isinstance(eval_loader, Iterable):
loader = eval_loader() loader = eval_loader()
logs = _run_one_epoch(loader, cbks, 'eval')
eval_steps = len(loader) if hasattr(loader,
'__len__') else None
cbks.on_begin('eval', {
'steps': eval_steps,
'metrics_name': metrics_name
})
logs = self._run_one_epoch(loader, cbks, 'eval', metrics_name)
cbks.on_end('eval', logs) cbks.on_end('eval', logs)
cbks.on_end('train', logs) cbks.on_end('train', logs)
self._test_dataloader = None
def evaluate(
self,
eval_data,
batch_size=1,
log_freq=10,
verbose=2,
num_workers=0,
callbacks=None, ):
"""
FIXME: add more comments and usage
Args:
eval_data (Dataset|DataLoader): An iterable data loader is used for
evaluation. An instance of paddle.fluid.io.Dataset or
paddle.fluid.io.Dataloader is recomended.
batch_size (int): Integer number. The batch size of train_data and eval_data.
When train_data and eval_data are both the instance of Dataloader, this
parameter will be ignored.
log_freq (int): The frequency, in number of steps, the eval logs
are printed.
verbose (int): The verbosity mode, should be 0, 1, or 2.
0 = silent, 1 = progress bar, 2 = one line per epoch.
num_workers (int): The number of subprocess to load data, 0 for no subprocess
used and loading data in main process. When train_data and eval_data are
both the instance of Dataloader, this parameter will be ignored.
callbacks (Callback|None): A list of `Callback` instances to apply
during training. If None, `ProgBarLogger` and `ModelCheckpoint`
are automatically inserted.
"""
if fluid.in_dygraph_mode():
feed_list = None
else:
feed_list = [x.forward() for x in self._inputs + self._labels]
if eval_data is not None and isinstance(eval_data, Dataset):
eval_sampler = DistributedBatchSampler(
eval_data, batch_size=batch_size)
eval_loader = DataLoader(
eval_data,
batch_sampler=eval_sampler,
places=self._place,
feed_list=feed_list,
num_workers=num_workers,
return_list=True)
else:
eval_loader = eval_data
self._test_dataloader = eval_loader
metrics_name = self._metrics_name()
cbks = config_callbacks(
callbacks,
model=self,
log_freq=log_freq,
verbose=verbose,
metrics=self._metrics_name(), )
loader = eval_loader
if not isinstance(eval_loader, Iterable):
loader = eval_loader()
eval_steps = len(loader) if hasattr(loader, '__len__') else None
cbks.on_begin('eval',
{'steps': eval_steps,
'metrics_name': metrics_name})
logs = self._run_one_epoch(loader, cbks, 'eval', metrics_name)
cbks.on_end('eval', logs)
self._test_dataloader = None
eval_result = {}
for k in self._metrics_name():
eval_result[k] = logs[k]
return eval_result
def predict(self, test_data, batch_size=1, num_workers=0):
"""
FIXME: add more comments and usage
Args:
test_data (Dataset|DataLoader): An iterable data loader is used for
predict. An instance of paddle.fluid.io.Dataset or paddle.fluid.io.Dataloader
is recomended.
batch_size (int): Integer number. The batch size of train_data and eval_data.
When train_data and eval_data are both the instance of Dataloader, this
parameter will be ignored.
num_workers (int): the number of subprocess to load data, 0 for no subprocess
used and loading data in main process. When train_data and eval_data are
both the instance of Dataloader, this parameter will be ignored.
"""
if fluid.in_dygraph_mode():
feed_list = None
else:
feed_list = [x.forward() for x in self._inputs + self._labels]
if test_data is not None and isinstance(test_data, Dataset):
test_sampler = DistributedBatchSampler(
test_data, batch_size=batch_size)
test_loader = DataLoader(
test_data,
batch_sampler=test_sampler,
places=self._place,
feed_list=feed_list,
num_workers=num_workers,
return_list=True)
else:
test_loader = test_data
self._test_dataloader = test_loader
loader = test_loader
if not isinstance(test_loader, Iterable):
loader = test_loader()
outputs = None
for data in tqdm.tqdm(loader):
if not fluid.in_dygraph_mode():
data = data[0]
outs = self.test(*data)
if outputs is None:
outputs = outs
else:
outputs = [
np.vstack([x, outs[i]]) for i, x in enumerate(outputs)
]
self._test_dataloader = None
if test_loader is not None and self._adapter._nranks > 1 \
and isinstance(test_loader, DataLoader):
outputs = [o[:len(test_loader.dataset)] for o in outputs]
return outputs
def set_eval_data(self, eval_data):
"""
Args:
eval_data (Dataset|DataLoader|None): An iterable data loader is used for
eval. An instance of paddle.fluid.io.Dataset or
paddle.fluid.io.Dataloader is recomended.
"""
assert isinstance(
eval_data,
DataLoader), "eval_data must be a instance of Dataloader!"
self._test_dataloader = eval_data
def _run_one_epoch(self,
data_loader,
callbacks,
mode,
metrics_name,
epoch=None):
size = len(data_loader) if hasattr(data_loader, '__len__') else None
logs = {
'steps': size,
'metrics_name': metrics_name,
}
if mode == 'train':
assert epoch is not None, 'when mode is train, epoch must be given'
callbacks.on_epoch_begin(epoch)
for step, data in enumerate(data_loader):
# data might come from different types of data_loader and have
# different format, as following:
# 1. DataLoader in static graph:
# [[input1, input2, ..., label1, lable2, ...]]
# 2. DataLoader in dygraph
# [input1, input2, ..., label1, lable2, ...]
# 3. custumed iterator yield concated inputs and labels:
# [input1, input2, ..., label1, lable2, ...]
# 4. custumed iterator yield seperated inputs and labels:
# ([input1, input2, ...], [label1, lable2, ...])
# To handle all of these, flatten (nested) list to list.
data = flatten(data)
# LoDTensor.shape is callable, where LoDTensor comes from
# DataLoader in static graph
batch_size = data[0].shape()[0] if callable(data[
0].shape) else data[0].shape[0]
callbacks.on_batch_begin(mode, step, logs)
if mode == 'train':
outs = self.train(data[:len(self._inputs)],
data[len(self._inputs):])
else:
outs = self.eval(data[:len(self._inputs)],
data[len(self._inputs):])
# losses
loss = outs[0] if self._metrics else outs
metrics = [[l[0] for l in loss]]
# metrics
for metric in self._metrics:
res = metric.accumulate()
metrics.extend(to_list(res))
assert len(metrics_name) == len(metrics)
for k, v in zip(metrics_name, metrics):
logs[k] = v
logs['step'] = step
if mode == 'train' or self._adapter._merge_count.get(
mode + '_batch', 0) <= 0:
logs['batch_size'] = batch_size * ParallelEnv().nranks
else:
logs['batch_size'] = self._adapter._merge_count[mode +
'_batch']
callbacks.on_batch_end(mode, step, logs)
self._reset_metrics()
if mode == 'train':
assert epoch is not None, 'when mode is train, epoch must be given'
callbacks.on_epoch_end(epoch)
return logs
def _reset_metrics(self): def _reset_metrics(self):
for metric in self._metrics: for metric in self._metrics:
......
...@@ -107,7 +107,7 @@ class ProgressBar(object): ...@@ -107,7 +107,7 @@ class ProgressBar(object):
eta = time_per_unit * (self._num - current_num) eta = time_per_unit * (self._num - current_num)
if eta > 3600: if eta > 3600:
eta_format = '%d:%02d:%02d' % (eta // 3600, (eta % 3600) // eta_format = '%d:%02d:%02d' % (eta // 3600, (eta % 3600) //
60, eta % 60) 60, eta % 60)
elif eta > 60: elif eta > 60:
eta_format = '%d:%02d' % (eta // 60, eta % 60) eta_format = '%d:%02d' % (eta // 60, eta % 60)
else: else:
...@@ -148,7 +148,7 @@ class ProgressBar(object): ...@@ -148,7 +148,7 @@ class ProgressBar(object):
else: else:
info += ' %.4e' % v info += ' %.4e' % v
elif isinstance(v, np.ndarray) and \ elif isinstance(v, np.ndarray) and \
isinstance(v.size, 1) and \ v.size == 1 and \
isinstance(v.dtype, (np.float32, np.float64)): isinstance(v.dtype, (np.float32, np.float64)):
if abs(v[0]) > 1e-3: if abs(v[0]) > 1e-3:
info += ' %.4f' % v[0] info += ' %.4f' % v[0]
......
...@@ -139,6 +139,26 @@ class MyCrossEntropy(Loss): ...@@ -139,6 +139,26 @@ class MyCrossEntropy(Loss):
return [loss1, loss2] return [loss1, loss2]
class TestMnistDataset(MnistDataset):
def __init__(self):
super(TestMnistDataset, self).__init__(mode='test')
def __getitem__(self, idx):
return self.images[idx],
def __len__(self):
return len(self.images)
def get_predict_accuracy(pred, gt):
pred = np.argmax(pred, -1)
gt = np.array(gt)
correct = pred[:, np.newaxis] == gt
return np.sum(correct) / correct.shape[0]
class TestModel(unittest.TestCase): class TestModel(unittest.TestCase):
def fit(self, dynamic, is_mlp=False): def fit(self, dynamic, is_mlp=False):
device = set_device('gpu') device = set_device('gpu')
...@@ -152,6 +172,7 @@ class TestModel(unittest.TestCase): ...@@ -152,6 +172,7 @@ class TestModel(unittest.TestCase):
train_dataset = MnistDataset(mode='train') train_dataset = MnistDataset(mode='train')
val_dataset = MnistDataset(mode='test') val_dataset = MnistDataset(mode='test')
test_dataset = TestMnistDataset()
model = MNIST() if not is_mlp else MLP() model = MNIST() if not is_mlp else MLP()
optim = fluid.optimizer.Momentum( optim = fluid.optimizer.Momentum(
...@@ -159,12 +180,23 @@ class TestModel(unittest.TestCase): ...@@ -159,12 +180,23 @@ class TestModel(unittest.TestCase):
loss = CrossEntropy() if not is_mlp else MyCrossEntropy() loss = CrossEntropy() if not is_mlp else MyCrossEntropy()
model.prepare(optim, loss, Accuracy(), inputs, labels, device=device) model.prepare(optim, loss, Accuracy(), inputs, labels, device=device)
cbk = ProgBarLogger(50) cbk = ProgBarLogger(50)
model.fit(train_dataset, model.fit(train_dataset,
val_dataset, val_dataset,
epochs=2, epochs=2,
batch_size=batch_size, batch_size=batch_size,
callbacks=cbk) callbacks=cbk)
eval_result = model.evaluate(val_dataset, batch_size=batch_size)
output = model.predict(test_dataset, batch_size=batch_size)
np.testing.assert_equal(output[0].shape[0], len(test_dataset))
acc = get_predict_accuracy(output[0], val_dataset.labels)
np.testing.assert_allclose(acc, eval_result['acc'])
def test_fit_static(self): def test_fit_static(self):
self.fit(False) self.fit(False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册