Skip to content

  • 体验新版
    • 正在加载...
  • 登录
  • PaddlePaddle
  • Paddle
  • 合并请求
  • !19070

P
Paddle
  • 项目概览

PaddlePaddle / Paddle
大约 2 年 前同步成功

通知 2325
Star 20933
Fork 5424
  • 代码
    • 文件
    • 提交
    • 分支
    • Tags
    • 贡献者
    • 分支图
    • Diff
  • Issue 1423
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 543
  • Wiki 0
    • Wiki
  • 分析
    • 仓库
    • DevOps
  • 项目成员
  • Pages
P
Paddle
  • 项目概览
    • 项目概览
    • 详情
    • 发布
  • 仓库
    • 仓库
    • 文件
    • 提交
    • 分支
    • 标签
    • 贡献者
    • 分支图
    • 比较
  • Issue 1,423
    • Issue 1,423
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 543
    • 合并请求 543
  • Pages
  • 分析
    • 分析
    • 仓库分析
    • DevOps
  • Wiki 0
    • Wiki
  • 成员
    • 成员
  • 收起侧边栏
  • 动态
  • 分支图
  • 创建新Issue
  • 提交
  • Issue看板

Enhance the error message when GrapOpMaker is null. !19070

  • Report abuse
!19070 已合并 8月 08, 2019 由 saxon_zh@saxon_zh 创建
#<User:0x00007f0ef64674e8>
  • 概览 4
  • 提交 3
  • 变更 1

Created by: Xreki

If user uses a op which does not have grad_op, and forgets to set stop_gradient to True, the train process will crash.

The error message is:

Traceback (most recent call last):
  File "./train_gpu_paddle.py", line 471, in <module>
    train(n_token, cutoffs)
  File "./train_gpu_paddle.py", line 254, in train
    decr_ratio=FLAGS.decr_ratio)
  File "/home/dingsiyu/project/python/transformer-xl-paddlepaddle/optimization.py", line 151, in optimization
    _, param_grads = optimizer.minimize(loss)
  File "<decorator-gen-20>", line 2, in minimize
  File "/home/dingsiyu/bin/anaconda3/lib/python3.6/site-packages/paddle/fluid/wrapped_decorator.py", line 25, in __impl__
    return wrapped_func(*args, **kwargs)
  File "/home/dingsiyu/bin/anaconda3/lib/python3.6/site-packages/paddle/fluid/dygraph/base.py", line 88, in __impl__
    return func(*args, **kwargs)
  File "/home/dingsiyu/bin/anaconda3/lib/python3.6/site-packages/paddle/fluid/optimizer.py", line 593, in minimize
    no_grad_set=no_grad_set)
  File "/home/dingsiyu/bin/anaconda3/lib/python3.6/site-packages/paddle/fluid/optimizer.py", line 493, in backward
    no_grad_set, callbacks)
  File "/home/dingsiyu/bin/anaconda3/lib/python3.6/site-packages/paddle/fluid/backward.py", line 570, in append_backward
    input_grad_names_set=input_grad_names_set)
  File "/home/dingsiyu/bin/anaconda3/lib/python3.6/site-packages/paddle/fluid/backward.py", line 310, in _append_backward_ops_
    op.desc, cpt.to_text(no_grad_dict[block.idx]), grad_sub_block_list)
paddle.fluid.core_avx.EnforceNotMet: grad_op_maker_ should not be null
Operator GradOpMaker has not been registered. at [/paddle/paddle/fluid/framework/op_info.h:69]
PaddlePaddle Call Stacks: 
0       0x7ff346440948p void paddle::platform::EnforceNotMet::Init<std::string>(std::string, char const*, int) + 360
1       0x7ff346440c97p paddle::platform::EnforceNotMet::EnforceNotMet(std::string const&, char const*, int) + 87
2       0x7ff346441c5cp paddle::framework::OpInfo::GradOpMaker() const + 108
3       0x7ff34643935ep
4       0x7ff346472de6p
5       0x7ff381de7744p _PyCFunction_FastCallDict + 340
6       0x7ff381e7593cp
7       0x7ff381e99a7ap _PyEval_EvalFrameDefault + 762

There is no useful information for users to fix. After this PR, error message is:

Traceback (most recent call last):
  File "./train_gpu_paddle.py", line 471, in <module>
    train(n_token, cutoffs)
  File "./train_gpu_paddle.py", line 254, in train
    decr_ratio=FLAGS.decr_ratio)
  File "/work/Paddle/build_paddle/test/users/transformer-xl-paddlepaddle-lyq/optimization.py", line 151, in optimization
    _, param_grads = optimizer.minimize(loss)
  File "</usr/local/lib/python2.7/dist-packages/decorator.pyc:decorator-gen-20>", line 2, in minimize
  File "/usr/local/lib/python2.7/dist-packages/paddle/fluid/wrapped_decorator.py", line 25, in __impl__
    return wrapped_func(*args, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/paddle/fluid/dygraph/base.py", line 86, in __impl__
    return func(*args, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/paddle/fluid/optimizer.py", line 594, in minimize
    no_grad_set=no_grad_set)
  File "/usr/local/lib/python2.7/dist-packages/paddle/fluid/optimizer.py", line 493, in backward
    no_grad_set, callbacks)
  File "/usr/local/lib/python2.7/dist-packages/paddle/fluid/backward.py", line 699, in append_backward
    input_grad_names_set=input_grad_names_set)
  File "/usr/local/lib/python2.7/dist-packages/paddle/fluid/backward.py", line 432, in _append_backward_ops_
    op.desc, cpt.to_text(no_grad_dict[block.idx]), grad_sub_block_list)
paddle.fluid.core_avx.EnforceNotMet: grad_op_maker_ should not be null 
Operator range's GradOpMaker has not been registered.                                                                                                                                        
Please check whether range_op has grad_op.
If not, please set stop_gradient to True for its input and output variables using var.step_gradient=True. at [/paddle/paddle/fluid/framework/op_info.h:75]
PaddlePaddle Call Stacks: 
0       0x7f6ffa355188p void paddle::platform::EnforceNotMet::Init<std::string>(std::string, char const*, int) + 360
1       0x7f6ffa3554d7p paddle::platform::EnforceNotMet::EnforceNotMet(std::string const&, char const*, int) + 87 
2       0x7f6ffa356ad9p paddle::framework::OpInfo::GradOpMaker() const + 233
3       0x7f6ffa34db9ep
4       0x7f6ffa3881e6p
指派人
分配到
审核者
Request review from
无
里程碑
无
分配里程碑
工时统计
标识: paddlepaddle/Paddle!19070
Source branch: github/fork/Xreki/core_enhance_error_message
渝ICP备2023009037号

京公网安备11010502055752号

网络110报警服务 Powered by GitLab CE v13.7
开源知识
Git 入门 Pro Git 电子书 在线学 Git
Markdown 基础入门 IT 技术知识开源图谱
帮助
使用手册 反馈建议 博客
《GitCode 隐私声明》 《GitCode 服务条款》 关于GitCode
Powered by GitLab CE v13.7