diff --git a/python/paddle/fluid/dygraph/amp/auto_cast.py b/python/paddle/fluid/dygraph/amp/auto_cast.py index 12ddf1a0f8ee47cce9f084f0056329270a28b0ed..8547501e1b33a0585c16ede7b166d0ef9c03ad58 100644 --- a/python/paddle/fluid/dygraph/amp/auto_cast.py +++ b/python/paddle/fluid/dygraph/amp/auto_cast.py @@ -395,9 +395,10 @@ class StateDictHook(object): for key in state_dict: param = state_dict[key] with paddle.fluid.dygraph.guard(): - param_applied = paddle.cast(param, self._save_dtype) - param_applied.name = param.name - state_dict[key] = param_applied + if paddle.is_floating_point(param): + param_applied = paddle.cast(param, self._save_dtype) + param_applied.name = param.name + state_dict[key] = param_applied @dygraph_only