Created by: wanghuancoder
PR types
Others
PR changes
Others
Describe
一、目的 检查用户定义的占位Var,应该都有相应的feed data,如果没有则报错。
二、背景 对于如下demo,由于没有feed Var y的数据,Paddle会崩溃。
x = fluid.data(name='x', shape=[-1, 13], dtype='float32')
y = fluid.data(name='y', shape=[-1, 1], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_loss = fluid.layers.mean(cost)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
sgd_optimizer.minimize(avg_loss)
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
main_program = fluid.default_main_program()
startup_program = fluid.default_startup_program()
exe.run(startup_program)
a = np.random.random((32, 13)).astype("float32")
b = np.random.random((32, 1)).astype("float32")
avg_loss_value, = exe.run(main_program,
#feed={ "x" : a, "y" : b}, #没有feed y
feed={ "x" : a},
fetch_list=[avg_loss])
崩溃信息如下,因为没有检查feed data,导致Paddle后知后觉。进而报错位置不准确,影响用户排查问题。
三、解决方法
在executor.py run时,prune剪枝逻辑后,启动Executor、PE、inference Run之前。检查所有通过fluid.data()定义的Var,是否都有相应的feed data。
修改后报错信息如下:
有2个问题:
问题1:
如果有一个占位Var z。是个孤点,如果没有feed data,原则上也应该报错。
处理思路:
1、如果是孤点,prune会删除掉的,因此报错检查放在prune后面。
2、可能存在大量用户,定义了孤点、没有打开prune。因此,报错检查暂时在prune开启情况下检查。后续,应该择机全面检查。
问题2:
怎么找到占位Var?
处理思路:
1、通过代码阅读,通过fluid.data()定义的Var,有很多特定属性:persistable(False)、type(LOD_TENSOR)、need_check_feed(True)、_stop_gradient(True)、is_data(True)、belong_to_optimizer(False)。通过全文检索,满足这样属性的只有占位Var。
2、通过fluid.layers.data()定义的占位Var,是没有need_check_feed(True)属性的,因为目前主推fluid.data(),不建议用户使用fluid.layers.data()了,所以,没有对fluid.layers.data()的占位符进行检查。
3、这里按理说只需要persistable和_stop_gradient就够了(如下图),但为了防止“错杀”尽可能将判别属性写全。
四、测试