From 6b391617012645c326552c97d4c6022228ce5160 Mon Sep 17 00:00:00 2001 From: Wei Luning Date: Mon, 27 Apr 2020 21:25:28 +0800 Subject: [PATCH] only cast when level is O2 --- mindspore/ccsrc/operator/composite/unpack_call.h | 2 -- mindspore/nn/wrap/cell_wrapper.py | 10 +++++++--- mindspore/train/model.py | 2 +- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/mindspore/ccsrc/operator/composite/unpack_call.h b/mindspore/ccsrc/operator/composite/unpack_call.h index 2f39615c1..8c055a938 100644 --- a/mindspore/ccsrc/operator/composite/unpack_call.h +++ b/mindspore/ccsrc/operator/composite/unpack_call.h @@ -35,7 +35,6 @@ namespace mindspore { // namespace to support composite operators definition namespace prim { - // Expand the tuple and dict parameters generated when parsing the function call, // and generate positional parameters and key-value pairs for function. class UnpackCall : public MetaFuncGraph { @@ -47,7 +46,6 @@ class UnpackCall : public MetaFuncGraph { friend bool operator==(const UnpackCall &lhs, const UnpackCall &rhs) { return lhs.name_ == rhs.name_; } }; using UnpackCallPtr = std::shared_ptr; - } // namespace prim } // namespace mindspore diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py index de0007c2e..60718ec2b 100644 --- a/mindspore/nn/wrap/cell_wrapper.py +++ b/mindspore/nn/wrap/cell_wrapper.py @@ -304,15 +304,19 @@ class WithEvalCell(Cell): >>> eval_net = nn.WithEvalCell(net, loss_fn) """ - def __init__(self, network, loss_fn): + def __init__(self, network, loss_fn, add_cast_fp32=False): super(WithEvalCell, self).__init__(auto_prefix=False) self._network = network self._loss_fn = loss_fn + self.add_cast_fp32 = add_cast_fp32 + def construct(self, data, label): outputs = self._network(data) - label = _mp_cast_helper(mstype.float32, label) - loss = self._loss_fn(F.cast(outputs, mstype.float32), label) + if self.add_cast_fp32: + label = _mp_cast_helper(mstype.float32, label) + outputs = F.cast(outputs, mstype.float32) + loss = self._loss_fn(outputs, label) return loss, outputs, label diff --git a/mindspore/train/model.py b/mindspore/train/model.py index fa4a37817..788c740f6 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -162,7 +162,7 @@ class Model: else: if self._loss_fn is None: raise ValueError("loss_fn can not be None.") - self._eval_network = nn.WithEvalCell(self._network, self._loss_fn) + self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level == "O2") self._eval_indexes = [0, 1, 2] def _build_predict_network(self): -- GitLab