diff --git a/oneflow/python/nn/optimizer/optimizer.py b/oneflow/python/nn/optimizer/optimizer.py
index d7049632499984564e2645c4f50e4d3af89abe0d..35983448af8e7a0e5adb0759579ff99bab31b490 100644
--- a/oneflow/python/nn/optimizer/optimizer.py
+++ b/oneflow/python/nn/optimizer/optimizer.py
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
+import warnings
 from typing import Dict, Callable, Union, Any, Iterator
 from types import GeneratorType
 
 
@@ -71,10 +72,20 @@ class Optimizer(object):
         raise NotImplementedError()
 
     def zero_grad(self, set_to_none: bool = False):
+        all_grad_is_none = True
         for param_group in self._param_groups:
             for param in param_group.parameters:
-                if set_to_none:
-                    param.grad = None
-                else:
-                    param.grad.fill_(0)
-                    # param.grad.zeros_()
+                if param.grad is not None:
+                    all_grad_is_none = False
+                    if set_to_none:
+                        param.grad = None
+                    else:
+                        param.grad.fill_(0)
+                        # param.grad.zeros_()
+        if all_grad_is_none:
+            # TODO: delete this after implementing Tensor.data
+            warnings.warn(
+                "\nParameters in optimizer do not have gradient.\n"
+                "Please check `loss.backward()` is called or not,\n"
+                "or try to declare optimizer after calling `module.to()`"
+            )
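
For context, below is a minimal, framework-free sketch of the behavior this patch introduces. FakeTensor, FakeParam, and the standalone zero_grad function are hypothetical stand-ins written for illustration only, not OneFlow APIs; they just mirror the guard-and-warn logic added to Optimizer.zero_grad above.

# Hypothetical stand-ins (not OneFlow code) showing why the
# `param.grad is not None` guard and the warning are needed.
import warnings


class FakeTensor:
    def __init__(self, values):
        self.values = list(values)

    def fill_(self, value):
        # In-place fill, mirroring Tensor.fill_ in the patch.
        self.values = [value] * len(self.values)


class FakeParam:
    def __init__(self, grad=None):
        # grad stays None until a backward pass populates it.
        self.grad = grad


def zero_grad(params, set_to_none=False):
    # Same control flow as the patched Optimizer.zero_grad.
    all_grad_is_none = True
    for param in params:
        if param.grad is not None:
            all_grad_is_none = False
            if set_to_none:
                param.grad = None
            else:
                param.grad.fill_(0)
    if all_grad_is_none:
        warnings.warn("Parameters in optimizer do not have gradient.")


# Before the patch, the first call would raise
# AttributeError: 'NoneType' object has no attribute 'fill_'.
zero_grad([FakeParam()])                        # warns instead of crashing
zero_grad([FakeParam(FakeTensor([1.0, 2.0]))])  # clears the grad normally

The `module.to()` hint in the warning text presumably exists because, per the TODO, `Tensor.data` was not yet implemented, so moving a module would create fresh parameter objects; an optimizer constructed beforehand would then hold references to the stale parameters, whose grads remain None forever.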