Commit e1c8f248 authored by Wei Luning

Fix the output not being a tuple when evaluating

Parent c9fba7f0
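This commit touches three files: the `Cell_wrapper` module (`WithEvalCell`), the `Model` class, and the AMP unit tests. The contract being enforced is that an eval step always returns the 3-tuple `(loss, outputs, label)` that `Model` unpacks via `_eval_indexes = [0, 1, 2]`. Previously, when an optimizer was present, eval reused `self._train_network.network`, whose output is a bare tensor rather than that tuple. A minimal sketch of the wrapper behavior this commit standardizes on (the class name is illustrative, not the MindSpore source):

import mindspore.nn as nn
from mindspore.common import dtype as mstype
from mindspore.ops import functional as F

class EvalCellSketch(nn.Cell):
    """Wraps a bare network so every eval step yields (loss, outputs, label)."""
    def __init__(self, network, loss_fn):
        super(EvalCellSketch, self).__init__()
        self._network = network
        self._loss_fn = loss_fn

    def construct(self, data, label):
        outputs = self._network(data)
        # Compute the loss in float32 so AMP (float16) outputs compare cleanly.
        loss = self._loss_fn(F.cast(outputs, mstype.float32), label)
        return loss, outputs, label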
@@ -14,15 +14,23 @@
 # ============================================================================
 """Cell_wrapper."""
 import copy
 import numpy as np
+from mindspore.parallel._utils import (_get_device_num, _get_mirror_mean,
+                                       _get_parallel_mode)
 from mindspore.train.parallel_utils import ParallelMode
-from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
-from ...ops import composite as C, functional as F, operations as P
-from ...common import Tensor, dtype as mstype
-from ..cell import Cell
+from ...common import Tensor
+from ...common import dtype as mstype
 from ...common.initializer import initializer
 from ...common.parameter import Parameter, ParameterTuple
+from ...ops import composite as C
+from ...ops import functional as F
+from ...ops import operations as P
+from ...ops.composite.base import _mp_cast_helper
 from ...ops.operations.comm_ops import _VirtualDataset
+from ..cell import Cell
 from .grad_reducer import DistributedGradReducer
@@ -310,8 +318,8 @@ class WithEvalCell(Cell):
     def construct(self, data, label):
         outputs = self._network(data)
-        loss = self._loss_fn(outputs, label)
+        label = _mp_cast_helper(mstype.float32, label)
+        loss = self._loss_fn(F.cast(outputs, mstype.float32), label)
         return loss, outputs, label
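The two added lines matter under mixed precision: with `amp_level="O2"` the wrapped network emits float16 outputs, so the label is brought to float32 via `_mp_cast_helper` and the outputs via `F.cast` before the loss is computed. A hedged sketch of the effect (`net` and `loss_fn` are placeholders):

outputs = net(data)                                      # float16 under O2
loss = loss_fn(F.cast(outputs, mstype.float32), label)   # loss in float32

The hunks that follow are in the `Model` training/eval wrapper.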
@@ -24,7 +24,7 @@ from .. import context
 from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
     _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check, _callback_wrapper
 from ..nn.metrics import Loss
-from ..nn.wrap import WithLossCell, DataWrapper, WithEvalCell
+from .. import nn
 from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
 from .parallel_utils import ParallelMode
 from ..common import dtype as mstype
@@ -130,7 +130,7 @@ class Model:
                                                  self._loss_fn,
                                                  level=self._amp_level)
         elif self._loss_fn:
-            network = WithLossCell(network, self._loss_fn)
+            network = nn.WithLossCell(network, self._loss_fn)
         # If need to check if loss_fn is not None, but optimizer is None
         return network
@@ -150,10 +150,7 @@ class Model:
         else:
             if self._loss_fn is None:
                 raise ValueError("loss_fn can not be None.")
-            if self._optimizer:
-                self._eval_network = self._train_network.network
-            else:
-                self._eval_network = WithEvalCell(self._network, self._loss_fn)
+            self._eval_network = nn.WithEvalCell(self._network, self._loss_fn)
             self._eval_indexes = [0, 1, 2]

     def _clear_metrics(self):
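Dropping the `if self._optimizer` branch is the core of the fix: `self._train_network.network` is the loss-wrapped training network, whose output is a single loss tensor, not the 3-tuple that `self._eval_indexes = [0, 1, 2]` assumes. Always constructing `nn.WithEvalCell` keeps the eval output uniform. A hedged sketch of how the tuple feeds the metrics (the wiring is illustrative, not copied from `Model`):

eval_out = eval_network(data, label)           # (loss, outputs, label)
loss_metric.update(eval_out[0])                # index 0: the loss
acc_metric.update(eval_out[1], eval_out[2])    # indexes 1, 2: outputs, label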
@@ -263,7 +260,7 @@ class Model:
         dataset_helper = DatasetHelper(train_dataset)
         # remove later to deal with loop sink
         if need_wrap:
-            self._train_network = DataWrapper(self._train_network, *(dataset_helper.types_shapes()),
+            self._train_network = nn.DataWrapper(self._train_network, *(dataset_helper.types_shapes()),
                                               train_dataset.__ME_INITED__)
             cb_params.train_network = self._train_network
         self._train_network.set_train()
@@ -429,7 +426,7 @@ class Model:
         # remove later to deal with loop sink
         if need_wrap:
-            self._eval_network = DataWrapper(self._eval_network, *(dataset_helper.types_shapes()),
+            self._eval_network = nn.DataWrapper(self._eval_network, *(dataset_helper.types_shapes()),
                                              valid_dataset.__ME_INITED__)
         self._eval_network.set_train(mode=False)
         self._eval_network.phase = 'eval'
@@ -14,12 +14,15 @@
 # ============================================================================
 """ auto mixed precision """
 import numpy as np
+import pytest
 from mindspore import amp
 from mindspore import nn
 from mindspore import Tensor
 from mindspore.common import dtype as mstype
 import mindspore.context as context
 from mindspore.model_zoo.resnet import resnet50
+from mindspore.train import Model
+from ....dataset_mock import MindData


 def setup_module(module):
@@ -85,3 +88,52 @@ def test_amp_o0_loss():
     optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
     train_network = amp.build_train_network(net, optimizer, loss)
     output = train_network(inputs, label)
+
+
+class MindDataSet(MindData):
+    def __init__(self, dataset_types, dataset_shapes):
+        super(MindDataSet, self).__init__(size=2, batch_size=32,
+                                          np_types=dataset_types,
+                                          output_shapes=dataset_shapes,
+                                          input_indexs=(0, 1))
+
+    def __next__(self):
+        if self._size < self._iter_num:
+            raise StopIteration
+        self._iter_num += 1
+        next = []
+        for shape, type in zip(self._output_shapes, self._np_types):
+            next.append(Tensor(np.ones(shape).astype(type)))
+        return tuple(next)
+
+
+def test_compile_model_train_O0():
+    dataset_types = (np.float32, np.float32)
+    dataset_shapes = ((16, 16), (16, 16))
+    dataset = MindDataSet(dataset_types, dataset_shapes)
+
+    net = NetNoLoss(16, 16)
+    loss = nn.MSELoss()
+    optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
+
+    model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"acc"}, amp_level="O0")
+    model.train(2, dataset, dataset_sink_mode=False)
+    with pytest.raises(ValueError):
+        # not actual run, the metrics step will fail, check if compile ok.
+        model.eval(dataset)
+
+
+def test_compile_model_train_O2():
+    dataset_types = (np.float32, np.float32)
+    dataset_shapes = ((16, 16), (16, 16))
+    dataset = MindDataSet(dataset_types, dataset_shapes)
+
+    net = NetNoLoss(16, 16)
+    loss = nn.MSELoss()
+    optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
+
+    model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"acc"}, amp_level="O2")
+    model.train(2, dataset, dataset_sink_mode=False)
+    with pytest.raises(ValueError):
+        # not actual run, the metrics step will fail, check if compile ok.
+        model.eval(dataset)
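For reference, the `MindDataSet` mock replays all-ones tensors for a fixed number of iterations, and the `pytest.raises(ValueError)` guards reflect that, as the in-test comments note, the `"acc"` metric is not expected to digest these regression-shaped outputs; the tests only verify that eval compiles against the new tuple output. Assuming the `MindData` base class in `dataset_mock` makes the mock iterable, direct iteration would look like:

dataset = MindDataSet((np.float32, np.float32), ((16, 16), (16, 16)))
for data, label in dataset:                    # ends via StopIteration
    assert data.asnumpy().shape == (16, 16)
    assert label.asnumpy().shape == (16, 16)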