Commit 0f753ee2 authored by mindspore-ci-bot, committed by Gitee

!4592 fix optimizer tuple inputs issue

Merge pull request !4592 from wangqiuliang/resolve-optimizer-tuple-inputs-issue
@@ -223,6 +223,15 @@ class Cell:
         else:
             object.__delattr__(self, name)
 
+    def cast_inputs(self, inputs, dst_type):
+        res = list()
+        for item in inputs:
+            if isinstance(item, tuple):
+                res.append(self.cast_inputs(item, dst_type))
+            else:
+                res.append(cast(item, dst_type))
+        return tuple(res)
+
     def __call__(self, *inputs, **kwargs):
         if context.get_context("mode") == context.GRAPH_MODE:
             if kwargs:
@@ -250,14 +259,10 @@ class Cell:
         cast_inputs = list()
         if hasattr(self, "_mindspore_flags"):
             if self._mindspore_flags.get('fp16'):
-                for item in inputs:
-                    cast_inputs.append(cast(item, mstype.float16))
+                cast_inputs = self.cast_inputs(inputs, mstype.float16)
             if self._mindspore_flags.get('fp32'):
-                for item in inputs:
-                    cast_inputs.append(cast(item, mstype.float32))
-        if cast_inputs:
-            cast_inputs = tuple(cast_inputs)
-        else:
+                cast_inputs = self.cast_inputs(inputs, mstype.float32)
+        if not cast_inputs:
            cast_inputs = inputs
         if self.enable_hook:
             output = self._hook_construct(*cast_inputs, **kwargs)
...
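For context, the change replaces the flat per-item cast with a recursive helper so that nested tuple inputs (the "optimizer tuple inputs issue" in the title) are cast element-wise instead of being passed to cast as a whole. Below is a minimal, standalone sketch of that behavior; it is not the MindSpore implementation — the real code uses the framework's cast operator and mstype dtypes, while here a plain Python stand-in cast and a float example are used purely for illustration.

# Standalone sketch of the recursive cast_inputs logic from this change.
# `cast` here is a stand-in for MindSpore's cast operator; `float` stands in
# for an mstype dtype such as mstype.float16.

def cast(item, dst_type):
    # Stand-in cast: convert one scalar element to the destination type.
    return dst_type(item)

def cast_inputs(inputs, dst_type):
    # Recurse into nested tuples so every leaf element is cast while the
    # tuple structure of the inputs is preserved.
    res = list()
    for item in inputs:
        if isinstance(item, tuple):
            res.append(cast_inputs(item, dst_type))
        else:
            res.append(cast(item, dst_type))
    return tuple(res)

# Example: a nested tuple of ints cast to float, structure preserved.
print(cast_inputs((1, (2, 3), 4), float))  # (1.0, (2.0, 3.0), 4.0)

With the previous loop-and-append version, a tuple element would have been handed to cast directly and the result was rebuilt as a flat tuple; the recursive helper keeps nesting intact, which is what the second hunk relies on when it assigns cast_inputs = self.cast_inputs(inputs, ...).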