提交 8f1984a8 编写于 作者: W Wei Luning

only cast when level is O2

上级 3dd369ce
...@@ -35,7 +35,6 @@ ...@@ -35,7 +35,6 @@
namespace mindspore { namespace mindspore {
// namespace to support composite operators definition // namespace to support composite operators definition
namespace prim { namespace prim {
// Expand the tuple and dict parameters generated when parsing the function call, // Expand the tuple and dict parameters generated when parsing the function call,
// and generate positional parameters and key-value pairs for function. // and generate positional parameters and key-value pairs for function.
class UnpackCall : public MetaFuncGraph { class UnpackCall : public MetaFuncGraph {
...@@ -47,7 +46,6 @@ class UnpackCall : public MetaFuncGraph { ...@@ -47,7 +46,6 @@ class UnpackCall : public MetaFuncGraph {
friend bool operator==(const UnpackCall &lhs, const UnpackCall &rhs) { return lhs.name_ == rhs.name_; } friend bool operator==(const UnpackCall &lhs, const UnpackCall &rhs) { return lhs.name_ == rhs.name_; }
}; };
using UnpackCallPtr = std::shared_ptr<UnpackCall>; using UnpackCallPtr = std::shared_ptr<UnpackCall>;
} // namespace prim } // namespace prim
} // namespace mindspore } // namespace mindspore
......
...@@ -300,6 +300,10 @@ void ExecutorPy::SaveCompiledGraphToPb(const std::string &phase_s) { ...@@ -300,6 +300,10 @@ void ExecutorPy::SaveCompiledGraphToPb(const std::string &phase_s) {
// save the graph to file in protobuf format // save the graph to file in protobuf format
FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph(); FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph();
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
if (phase_s.empty()) {
MS_LOG(ERROR) << "`phase` is empty '" << phase_s << "'!";
return;
}
std::string name_prefix = phase_s.substr(0, phase_s.find(".")); std::string name_prefix = phase_s.substr(0, phase_s.find("."));
std::string pb_filename = std::string("ms_output_") + name_prefix + ".pb"; std::string pb_filename = std::string("ms_output_") + name_prefix + ".pb";
std::string filename = GetFilePathName(pb_filename); std::string filename = GetFilePathName(pb_filename);
......
...@@ -304,15 +304,19 @@ class WithEvalCell(Cell): ...@@ -304,15 +304,19 @@ class WithEvalCell(Cell):
>>> eval_net = nn.WithEvalCell(net, loss_fn) >>> eval_net = nn.WithEvalCell(net, loss_fn)
""" """
def __init__(self, network, loss_fn): def __init__(self, network, loss_fn, add_cast_fp32=False):
super(WithEvalCell, self).__init__(auto_prefix=False) super(WithEvalCell, self).__init__(auto_prefix=False)
self._network = network self._network = network
self._loss_fn = loss_fn self._loss_fn = loss_fn
self.add_cast_fp32 = add_cast_fp32
def construct(self, data, label): def construct(self, data, label):
outputs = self._network(data) outputs = self._network(data)
label = _mp_cast_helper(mstype.float32, label) if self.add_cast_fp32:
loss = self._loss_fn(F.cast(outputs, mstype.float32), label) label = _mp_cast_helper(mstype.float32, label)
outputs = F.cast(outputs, mstype.float32)
loss = self._loss_fn(outputs, label)
return loss, outputs, label return loss, outputs, label
......
...@@ -162,7 +162,7 @@ class Model: ...@@ -162,7 +162,7 @@ class Model:
else: else:
if self._loss_fn is None: if self._loss_fn is None:
raise ValueError("loss_fn can not be None.") raise ValueError("loss_fn can not be None.")
self._eval_network = nn.WithEvalCell(self._network, self._loss_fn) self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level == "O2")
self._eval_indexes = [0, 1, 2] self._eval_indexes = [0, 1, 2]
def _build_predict_network(self): def _build_predict_network(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册