未验证 提交 7e9902aa 编写于 作者: Y Yinggang Wang 提交者: GitHub

Add warning when no param update (#4896)

* style(Optim): add warning when no param update

* style(Optim): add TODO
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
上级 fb164b58
......@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import warnings
from typing import Dict, Callable, Union, Any, Iterator
from types import GeneratorType
......@@ -71,10 +72,20 @@ class Optimizer(object):
raise NotImplementedError()
def zero_grad(self, set_to_none: bool = False):
    """Reset the gradients of all parameters managed by this optimizer.

    Args:
        set_to_none: if True, detach the gradient by assigning ``None``
            instead of zero-filling it in place (saves memory, mirrors
            PyTorch's ``zero_grad(set_to_none=...)`` option).

    Emits a ``UserWarning`` when *every* parameter's gradient is ``None``,
    which usually means ``loss.backward()`` was never called, or the
    optimizer was built over stale parameters (e.g. it was created before
    ``module.to(...)`` replaced them).
    """
    all_grad_is_none = True
    for param_group in self._param_groups:
        for param in param_group.parameters:
            # Parameters that never received a gradient are skipped:
            # zero-filling a missing grad would raise AttributeError.
            if param.grad is not None:
                all_grad_is_none = False
                if set_to_none:
                    param.grad = None
                else:
                    # In-place fill keeps the gradient buffer allocated.
                    param.grad.fill_(0)
    if all_grad_is_none:
        # TODO: delete this after implementing Tensor.data
        warnings.warn(
            "\nParameters in optimizer do not have gradient.\n"
            "Please check `loss.backward()` is called or not,\n"
            "or try to declare optimizer after calling `module.to()`"
        )
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册