提交 66ae0a8c 编写于 作者: Y yangyaming

Enhence shrink_rnn_memory_op.

上级 25af35d8
......@@ -46,8 +46,19 @@ class ShrinkRNNMemoryOp : public ArrayOp {
auto *out_var = scope.FindVar(Output("Out"));
PADDLE_ENFORCE(out_var != nullptr, "Output Out must be set");
auto &out_tensor = *out_var->GetMutable<framework::LoDTensor>();
// should consider multiple levels
size_t height = dst_num_rows;
auto lod_level = lod_rank_table.level();
if (x_tensor.lod().size() > lod_level &&
x_tensor.lod()[lod_level].size() < dst_num_rows) {
auto lod_offset = framework::GetSubLoDAndAbsoluteOffset(
x_tensor.lod(), 0, dst_num_rows + 1, lod_level);
height = lod_offset.second.second;
}
if (dst_num_rows != 0) {
out_tensor.ShareDataWith(x_tensor.Slice(0, dst_num_rows));
out_tensor.ShareDataWith(x_tensor.Slice(0, height));
}
}
};
......
......@@ -26,13 +26,13 @@ class TestShrinkRNNMemory(unittest.TestCase):
cpu = core.CPUPlace()
tensor = core.LoDTensor()
tensor.set_lod([[0, 2, 5, 6]])
tensor_np = numpy.random.random(size=(3, 100)).astype('float32')
tensor_np = numpy.random.random(size=(6, 100)).astype('float32')
tensor.set(tensor_np, cpu)
exe = Executor(cpu)
outs = exe.run(feed={'x': tensor}, fetch_list=[mem1, mem2, mem3])
self.assertTrue(numpy.allclose(tensor_np[0:3], outs[0]))
self.assertTrue(numpy.allclose(tensor_np[0:2], outs[1]))
self.assertTrue(numpy.allclose(tensor_np[0:1], outs[2]))
self.assertTrue(numpy.allclose(tensor_np[0:6], outs[0]))
self.assertTrue(numpy.allclose(tensor_np[0:5], outs[1]))
self.assertTrue(numpy.allclose(tensor_np[0:2], outs[2]))
mem3_mean = layers.mean(x=mem3)
append_backward(loss=mem3_mean)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册