From e1c8f248e0fc1d86b5877bb4e3ef9229fe2a1909 Mon Sep 17 00:00:00 2001 From: Wei Luning Date: Tue, 14 Apr 2020 13:18:33 +0800 Subject: [PATCH] Fix the output is not tuple, when eval --- mindspore/nn/wrap/cell_wrapper.py | 20 ++++++++---- mindspore/train/model.py | 13 +++----- tests/ut/python/train/test_amp.py | 52 +++++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 14 deletions(-) diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py index 53a535781..64c382557 100644 --- a/mindspore/nn/wrap/cell_wrapper.py +++ b/mindspore/nn/wrap/cell_wrapper.py @@ -14,15 +14,23 @@ # ============================================================================ """Cell_wrapper.""" import copy + import numpy as np + +from mindspore.parallel._utils import (_get_device_num, _get_mirror_mean, + _get_parallel_mode) from mindspore.train.parallel_utils import ParallelMode -from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean -from ...ops import composite as C, functional as F, operations as P -from ...common import Tensor, dtype as mstype -from ..cell import Cell + +from ...common import Tensor +from ...common import dtype as mstype from ...common.initializer import initializer from ...common.parameter import Parameter, ParameterTuple +from ...ops import composite as C +from ...ops import functional as F +from ...ops import operations as P +from ...ops.composite.base import _mp_cast_helper from ...ops.operations.comm_ops import _VirtualDataset +from ..cell import Cell from .grad_reducer import DistributedGradReducer @@ -310,8 +318,8 @@ class WithEvalCell(Cell): def construct(self, data, label): outputs = self._network(data) - loss = self._loss_fn(outputs, label) - + label = _mp_cast_helper(mstype.float32, label) + loss = self._loss_fn(F.cast(outputs, mstype.float32), label) return loss, outputs, label diff --git a/mindspore/train/model.py b/mindspore/train/model.py index 7604b8ac3..46e4f421f 100755 --- 
a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -24,7 +24,7 @@ from .. import context from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check, _callback_wrapper from ..nn.metrics import Loss -from ..nn.wrap import WithLossCell, DataWrapper, WithEvalCell +from .. import nn from ..nn.wrap.cell_wrapper import _VirtualDatasetCell from .parallel_utils import ParallelMode from ..common import dtype as mstype @@ -130,7 +130,7 @@ class Model: self._loss_fn, level=self._amp_level) elif self._loss_fn: - network = WithLossCell(network, self._loss_fn) + network = nn.WithLossCell(network, self._loss_fn) # If need to check if loss_fn is not None, but optimizer is None return network @@ -150,10 +150,7 @@ class Model: else: if self._loss_fn is None: raise ValueError("loss_fn can not be None.") - if self._optimizer: - self._eval_network = self._train_network.network - else: - self._eval_network = WithEvalCell(self._network, self._loss_fn) + self._eval_network = nn.WithEvalCell(self._network, self._loss_fn) self._eval_indexes = [0, 1, 2] def _clear_metrics(self): @@ -263,7 +260,7 @@ class Model: dataset_helper = DatasetHelper(train_dataset) # remove later to deal with loop sink if need_wrap: - self._train_network = DataWrapper(self._train_network, *(dataset_helper.types_shapes()), + self._train_network = nn.DataWrapper(self._train_network, *(dataset_helper.types_shapes()), train_dataset.__ME_INITED__) cb_params.train_network = self._train_network self._train_network.set_train() @@ -429,7 +426,7 @@ class Model: # remove later to deal with loop sink if need_wrap: - self._eval_network = DataWrapper(self._eval_network, *(dataset_helper.types_shapes()), + self._eval_network = nn.DataWrapper(self._eval_network, *(dataset_helper.types_shapes()), valid_dataset.__ME_INITED__) self._eval_network.set_train(mode=False) self._eval_network.phase = 'eval' diff --git 
a/tests/ut/python/train/test_amp.py b/tests/ut/python/train/test_amp.py index 1a26c2177..2afb1e00b 100644 --- a/tests/ut/python/train/test_amp.py +++ b/tests/ut/python/train/test_amp.py @@ -14,12 +14,15 @@ # ============================================================================ """ auto mixed precision """ import numpy as np +import pytest from mindspore import amp from mindspore import nn from mindspore import Tensor from mindspore.common import dtype as mstype import mindspore.context as context from mindspore.model_zoo.resnet import resnet50 +from mindspore.train import Model +from ....dataset_mock import MindData def setup_module(module): @@ -85,3 +88,52 @@ def test_amp_o0_loss(): optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) train_network = amp.build_train_network(net, optimizer, loss) output = train_network(inputs, label) + + +class MindDataSet(MindData): + def __init__(self, dataset_types, dataset_shapes): + super(MindDataSet, self).__init__(size=2, batch_size=32, + np_types=dataset_types, + output_shapes=dataset_shapes, + input_indexs=(0, 1)) + def __next__(self): + if self._size < self._iter_num: + raise StopIteration + self._iter_num += 1 + next = [] + for shape, type in zip(self._output_shapes, self._np_types): + next.append(Tensor(np.ones(shape).astype(type))) + return tuple(next) + + +def test_compile_model_train_O0(): + dataset_types = (np.float32, np.float32) + dataset_shapes = ((16, 16), (16, 16)) + + dataset = MindDataSet(dataset_types, dataset_shapes) + + net = NetNoLoss(16, 16) + loss = nn.MSELoss() + optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) + + model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"acc"}, amp_level="O0") + model.train(2, dataset, dataset_sink_mode=False) + with pytest.raises(ValueError): + # not actual run, the metrics step will fail, check if compile ok. 
+ model.eval(dataset) + +def test_compile_model_train_O2(): + dataset_types = (np.float32, np.float32) + dataset_shapes = ((16, 16), (16, 16)) + + dataset = MindDataSet(dataset_types, dataset_shapes) + + net = NetNoLoss(16, 16) + loss = nn.MSELoss() + optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) + + model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"acc"}, amp_level="O2") + model.train(2, dataset, dataset_sink_mode=False) + with pytest.raises(ValueError): + # Not an actual run: the metrics step will fail; this only checks that compilation succeeds. + model.eval(dataset) -- GitLab