Skip to content

  • 体验新版
    • 正在加载...
  • 登录
  • PaddlePaddle
  • Paddle
  • Issue
  • #15567

P
Paddle
  • 项目概览

PaddlePaddle / Paddle
大约 2 年 前同步成功

通知 2325
Star 20933
Fork 5424
  • 代码
    • 文件
    • 提交
    • 分支
    • Tags
    • 贡献者
    • 分支图
    • Diff
  • Issue 1423
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 543
  • Wiki 0
    • Wiki
  • 分析
    • 仓库
    • DevOps
  • 项目成员
  • Pages
P
Paddle
  • 项目概览
    • 项目概览
    • 详情
    • 发布
  • 仓库
    • 仓库
    • 文件
    • 提交
    • 分支
    • 标签
    • 贡献者
    • 分支图
    • 比较
  • Issue 1,423
    • Issue 1,423
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 543
    • 合并请求 543
  • Pages
  • 分析
    • 分析
    • 仓库分析
    • DevOps
  • Wiki 0
    • Wiki
  • 成员
    • 成员
  • 收起侧边栏
  • 动态
  • 分支图
  • 创建新Issue
  • 提交
  • Issue看板
已关闭
开放中
Opened 1月 29, 2019 by saxon_zh@saxon_zhGuest

while_op里的梯度回传问题

Created by: Ashleychen

def train_program(self):
    image_lod = fluid.layers.data(name = 'image_lod', dtype = 'float32',
        shape = [784], lod_level = 1)
    label_lod = fluid.layers.data(name = 'label_lod', dtype = 'int',
        shape = [1], lod_level = 1)
    last_label_lod = fluid.layers.data(name = 'last_label_lod', dtype = 'float32',
        shape = [1], lod_level = 1)
    image_lod_rank_table = fluid.layers.control_flow.lod_rank_table(image_lod)
    label_lod_rank_table = fluid.layers.control_flow.lod_rank_table(label_lod)
    last_label_lod_rank_table = fluid.layers.control_flow.lod_rank_table(last_label_lod)
    image_array = lod_tensor_to_array(x = image_lod, table = image_lod_rank_table)
    label_array = lod_tensor_to_array(x = label_lod, table = label_lod_rank_table)
    last_label_array = lod_tensor_to_array(x = last_label_lod, table = last_label_lod_rank_table)
    loss_array = fluid.layers.create_array('float32')
    array_len = fluid.layers.fill_constant(
        shape = [1], dtype = 'int64', value = self.list_size)
    counter = fluid.layers.zeros(shape = [1], dtype = 'int64')
    cond = fluid.layers.less_than(x = counter, y = array_len)
    while_op = fluid.layers.While(cond = cond)
    with while_op.block():
        current_image = fluid.layers.array_read(array = image_array, i = counter)
        current_label = fluid.layers.array_read(array = label_array, i = counter)
        current_label_reshape = fluid.layers.reshape(x = current_label, shape = [-1, 1])
        current_last_label = fluid.layers.array_read(array = last_label_array, i = counter)
        current_last_label_reshape = fluid.layers.reshape(
            x = current_last_label, shape = [-1, 1])
        image_fc = fluid.layers.fc(input = current_image, size = 20)
        last_label_fc = fluid.layers.fc(input = current_last_label_reshape, size = 20)
        loss_input = fluid.layers.elementwise_add(
            x = image_fc, y = last_label_fc)
        current_loss = fluid.layers.softmax_with_cross_entropy(
            logits = loss_input, label = current_label_reshape)
        current_loss_val = fluid.layers.reduce_sum(current_loss, dim = 0)
        #loss_array = fluid.layers.array_write(current_loss_val, i = counter)
        fluid.layers.array_write(current_loss_val, array = loss_array, i = counter)
        fluid.layers.increment(x = counter, value = 1, in_place = True)
        fluid.layers.less_than(x = counter, y = array_len, cond = cond)
    loss_lod = array_to_lod_tensor(x = loss_array, table = image_lod_rank_table)
    print('%s' % loss_lod)
    loss = fluid.layers.reduce_sum(loss_lod, dim = 0)
    print('%s' % loss)
    return loss
def ntm_main():
    # 该模型运行在单个CPU上
    use_cuda = False # set to True if training with GPU
    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    batch_size = 128
    list_size = 50
    train_generator = OmniglotGenerator(
        data_file='data/train.npz',
        nb_classes=5,
        nb_samples_per_class=10,
        batchsize=batch_size,
        max_iter=None,
        xp=np)
    test_generator = OmniglotGenerator(
        data_file='data/test.npz',
        nb_classes=5,
        nb_samples_per_class=10,
        batchsize=batch_size,
        max_iter=10,
        xp=np)
    ntm = Ntm(nb_class = 5, nb_reads = 4, input_size = 28 * 28, cell_size = 200,
        memory_shape = (128, 40), gamma = 0.95, batch_size = 128, list_size = 50)
    loss = ntm.train_program()
    optimizer = fluid.optimizer.Adam(learning_rate=0.001)
    optimizer.minimize(loss)
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())
    main_program = fluid.default_main_program()
    test_program = fluid.default_main_program().clone(for_test=True)
    # create lod-tensor
    for i, (images, labels) in train_generator:
        images_list = []
        last_label_list = []
        label_list = []
        reshape_image_list = []
        reshape_last_label_list = []
        reshape_label_list = []
        for sample_idx in xrange(len(images)):
            # get list of arrys 128 * 784
            images_items = np.split(images[sample_idx], batch_size)
            images_list.append(images_items)
            label_list.append(list(labels[sample_idx]))
        last_label_list.append([0.0] * batch_size)
        for sample_idx in xrange(len(images) - 1):
            last_label_list.append(list(labels[sample_idx]))
        for batch_idx in xrange(batch_size):
            for sample_idx in xrange(len(images)):
                reshape_image_list.append(list(images_list[sample_idx][batch_idx]))
                reshape_last_label_list.append(
                    last_label_list[sample_idx][batch_idx])
                reshape_label_list.append(
                    label_list[sample_idx][batch_idx])
        image_lod = fluid.create_lod_tensor(
            np.array(reshape_image_list),
            [[50] * 128], place)
        label_lod = fluid.create_lod_tensor(
            np.array(reshape_label_list, dtype = 'int'),
            [[50] * 128], place)
        last_label_lod = fluid.create_lod_tensor(
            np.array(reshape_last_label_list, dtype = 'float32'),
            [[50] * 128], place)
        exe.run(
            main_program,
            feed = {
                'image_lod': image_lod,
                'label_lod': label_lod,
                'last_label_lod': last_label_lod},
            fetch_list = [loss])

上面两段是我的训练program的代码,主要是feed进去3个lod tensor,首先我将这3个lod tensor转成array,然后在while_op里面遍历每一个时间步的数据,然后搭model layers,之后得到每一个时间步的loss,将每一个时间步的loss用array write写到array里面。最后跳出while循环,将loss array转成lod tensor,并且用reduce sum求出所有时间步的loss总和返回给optimizer。 在运行的时候报错:

Traceback (most recent call last):
  File "train_omniglot.py", line 213, in <module>
    ntm_main()
  File "train_omniglot.py", line 40, in ntm_main
    optimizer.minimize(loss)
  File "/home/ol/anaconda2/lib/python2.7/site-packages/paddle/fluid/optimizer.py", line 259, in minimize
    [error_clip_callback])
  File "/home/ol/anaconda2/lib/python2.7/site-packages/paddle/fluid/backward.py", line 590, in append_backward
    _append_backward_vars_(root_block, fwd_op_num, grad_to_var, grad_info_map)
  File "/home/ol/anaconda2/lib/python2.7/site-packages/paddle/fluid/backward.py", line 412, in _append_backward_vars_
    _append_backward_vars_(sub_block, 0, grad_to_var, grad_info_map)
  File "/home/ol/anaconda2/lib/python2.7/site-packages/paddle/fluid/backward.py", line 426, in _append_backward_vars_
    op_desc.infer_shape(block.desc)
paddle.fluid.core.EnforceNotMet: Input(Out@GRAD) shouldn't be null. at [/paddle/paddle/fluid/operators/reshape_op.cc:314]
PaddlePaddle Call Stacks:
0       0x7f43d73dee36p paddle::platform::EnforceNotMet::EnforceNotMet(std::__exception_ptr::exception_ptr, char const*, int) + 486
1       0x7f43d8370e41p paddle::operators::Reshape2GradOp::InferShape(paddle::framework::InferShapeContext*) const + 913
2       0x7f43d7494f86p paddle::framework::OpDesc::InferShape(paddle::framework::BlockDesc const&) const + 886
3       0x7f43d7441275p void pybind11::cpp_function::initialize<pybind11::cpp_function::initialize<void, paddle::framework::OpDesc, paddle::framework::BlockDesc const&, pybind11::name, pybind11::is_method, pybind11::sibling>(void (paddle::framework::OpDesc::*)(paddle::framework::BlockDesc const&) const, pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::{lambda(paddle::framework::OpDesc const*, paddle::framework::BlockDesc const&)#1}, void, paddle::framework::OpDesc const*, paddle::framework::BlockDesc const&, pybind11::name, pybind11::is_method, pybind11::sibling>(pybind11::cpp_function::initialize<void, paddle::framework::OpDesc, paddle::framework::BlockDesc const&, pybind11::name, pybind11::is_method, pybind11::sibling>(void (paddle::framework::OpDesc::*)(paddle::framework::BlockDesc const&) const, pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::{lambda(paddle::framework::OpDesc const*, paddle::framework::BlockDesc const&)#1}&&, void (*)(paddle::framework::OpDesc const*, paddle::framework::BlockDesc const&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::{lambda(pybind11::detail::function_call&)#3}::_FUN(pybind11::detail::function_call) + 213
4       0x7f43d73f5544p pybind11::cpp_function::dispatcher(_object*, _object*, _object*) + 2596
5       0x7f449e40feecp PyEval_EvalFrameEx + 33468
6       0x7f449e4114e9p PyEval_EvalCodeEx + 2025
7       0x7f449e40e482p PyEval_EvalFrameEx + 26706
8       0x7f449e4114e9p PyEval_EvalCodeEx + 2025
9       0x7f449e40e482p PyEval_EvalFrameEx + 26706
10      0x7f449e4114e9p PyEval_EvalCodeEx + 2025
11      0x7f449e40e482p PyEval_EvalFrameEx + 26706
12      0x7f449e4114e9p PyEval_EvalCodeEx + 2025
13      0x7f449e40e482p PyEval_EvalFrameEx + 26706
14      0x7f449e40fdacp PyEval_EvalFrameEx + 33148
15      0x7f449e4114e9p PyEval_EvalCodeEx + 2025
16      0x7f449e41170ap PyEval_EvalCode + 26
17      0x7f449e42a93dp
18      0x7f449e42bab8p PyRun_FileExFlags + 120
19      0x7f449e42ccd8p PyRun_SimpleFileExFlags + 232
20      0x7f449e43ed3cp Py_Main + 2988
21      0x7f449d677bd5p __libc_start_main + 245
22      0x7f449e70b87fp

不知道是什么原因,是说我传回去的loss是空的?

指派人
分配到
无
里程碑
无
分配里程碑
工时统计
无
截止日期
无
标识: paddlepaddle/Paddle#15567
渝ICP备2023009037号

京公网安备11010502055752号

网络110报警服务 Powered by GitLab CE v13.7
开源知识
Git 入门 Pro Git 电子书 在线学 Git
Markdown 基础入门 IT 技术知识开源图谱
帮助
使用手册 反馈建议 博客
《GitCode 隐私声明》 《GitCode 服务条款》 关于GitCode
Powered by GitLab CE v13.7