diff --git a/mindspore/ccsrc/operator/composite/unpack_call.h b/mindspore/ccsrc/operator/composite/unpack_call.h index 2f39615c1a66b70c7721799a0fcf7cf39313bbff..8c055a938649c506c6f4c0decaa8ff9c88691ad6 100644 --- a/mindspore/ccsrc/operator/composite/unpack_call.h +++ b/mindspore/ccsrc/operator/composite/unpack_call.h @@ -35,7 +35,6 @@ namespace mindspore { // namespace to support composite operators definition namespace prim { - // Expand the tuple and dict parameters generated when parsing the function call, // and generate positional parameters and key-value pairs for function. class UnpackCall : public MetaFuncGraph { @@ -47,7 +46,6 @@ class UnpackCall : public MetaFuncGraph { friend bool operator==(const UnpackCall &lhs, const UnpackCall &rhs) { return lhs.name_ == rhs.name_; } }; using UnpackCallPtr = std::shared_ptr; - } // namespace prim } // namespace mindspore diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py index de0007c2ebe16fcb8be74869356493b0b45c07e2..60718ec2b112ecde40f8fa6306643d9bbd1b5818 100644 --- a/mindspore/nn/wrap/cell_wrapper.py +++ b/mindspore/nn/wrap/cell_wrapper.py @@ -304,15 +304,19 @@ class WithEvalCell(Cell): >>> eval_net = nn.WithEvalCell(net, loss_fn) """ - def __init__(self, network, loss_fn): + def __init__(self, network, loss_fn, add_cast_fp32=False): super(WithEvalCell, self).__init__(auto_prefix=False) self._network = network self._loss_fn = loss_fn + self.add_cast_fp32 = add_cast_fp32 + def construct(self, data, label): outputs = self._network(data) - label = _mp_cast_helper(mstype.float32, label) - loss = self._loss_fn(F.cast(outputs, mstype.float32), label) + if self.add_cast_fp32: + label = _mp_cast_helper(mstype.float32, label) + outputs = F.cast(outputs, mstype.float32) + loss = self._loss_fn(outputs, label) return loss, outputs, label diff --git a/mindspore/train/model.py b/mindspore/train/model.py index fa4a37817119edae32e88abe6f59e2b06d2e25c0..788c740f6077b54ee277195a924572d580147717 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -162,7 +162,7 @@ class Model: else: if self._loss_fn is None: raise ValueError("loss_fn can not be None.") - self._eval_network = nn.WithEvalCell(self._network, self._loss_fn) + self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level == "O2") self._eval_indexes = [0, 1, 2] def _build_predict_network(self):