Commit f947c153 authored by yangyaming

Consider multiple levels of LoD.

Parent 66ae0a8c
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/framework/lod_rank_table.h"
+#include "paddle/framework/lod_tensor.h"
 #include "paddle/operators/array_operator.h"
 #include "paddle/operators/math/math_function.h"
@@ -49,12 +50,24 @@ class ShrinkRNNMemoryOp : public ArrayOp {
     // should consider multiple levels
     size_t height = dst_num_rows;
-    auto lod_level = lod_rank_table.level();
+    auto lod_level = rank_table.level();
     if (x_tensor.lod().size() > lod_level &&
-        x_tensor.lod()[lod_level].size() < dst_num_rows) {
+        x_tensor.lod()[lod_level].size() > static_cast<size_t>(dst_num_rows)) {
       auto lod_offset = framework::GetSubLoDAndAbsoluteOffset(
-          x_tensor.lod(), 0, dst_num_rows + 1, lod_level);
+          x_tensor.lod(), 0, dst_num_rows, lod_level);
       height = lod_offset.second.second;
+      auto out_lod = out_tensor.mutable_lod();
+      auto x_lod = x_tensor.lod();
+      out_lod->reserve(lod_level + lod_offset.first.size());
+      for (size_t i = 0; i < lod_level; ++i) {
+        out_lod->emplace_back(x_lod.at(i));
+      }
+      framework::LoD remain;
+      framework::AppendLoD(&remain, lod_offset.first);
+      for (size_t j = 0; j < remain.size(); ++j) {
+        out_lod->emplace_back(remain[j]);
+      }
     }
     if (dst_num_rows != 0) {
......
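For context on what the added lines compute: the operator keeps only the first `dst_num_rows` sequences at the rank table's LoD level, copies the LoD levels above that level into the output unchanged, appends the sliced sub-LoD from that level downward, and takes the resulting absolute row range as the new height. The sketch below mimics that bookkeeping with plain Python lists and an offset-based LoD; the helper `shrink_lod` is hypothetical and is not the Paddle C++ or Python API.

```python
# Standalone sketch (plain Python lists, hypothetical helper): shrink a
# multi-level, offset-based LoD to the first `keep` sequences at `level`.
def shrink_lod(lod, level, keep):
    # Levels above `level` are copied unchanged, mirroring the C++ loop
    # that emplaces x_lod.at(i) for i < lod_level.
    new_lod = [lvl[:] for lvl in lod[:level]]
    begin, end = 0, keep  # sequence index range kept at `level`
    for lvl in lod[level:]:
        # Slice this level's offsets and rebase them to start at 0.
        new_lod.append([off - lvl[begin] for off in lvl[begin:end + 1]])
        # Map the kept range onto the next (finer) level / the tensor rows.
        begin, end = lvl[begin], lvl[end]
    height = end - begin  # absolute number of rows kept, cf. lod_offset.second
    return new_lod, height


# Two-level LoD over a 6-row tensor: 2 groups -> 3 sequences -> 6 rows.
lod = [[0, 2, 3], [0, 2, 5, 6]]
print(shrink_lod(lod, level=0, keep=1))  # ([[0, 2], [0, 2, 5]], 5)
```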
@@ -29,7 +29,9 @@ class TestShrinkRNNMemory(unittest.TestCase):
         tensor_np = numpy.random.random(size=(6, 100)).astype('float32')
         tensor.set(tensor_np, cpu)
         exe = Executor(cpu)
-        outs = exe.run(feed={'x': tensor}, fetch_list=[mem1, mem2, mem3])
+        outs = exe.run(feed={'x': tensor},
+                       fetch_list=[mem1, mem2, mem3],
+                       return_numpy=False)
         self.assertTrue(numpy.allclose(tensor_np[0:6], outs[0]))
         self.assertTrue(numpy.allclose(tensor_np[0:5], outs[1]))
         self.assertTrue(numpy.allclose(tensor_np[0:2], outs[2]))
......
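On the test side, `return_numpy=False` makes `Executor.run` return the fetched variables as LoDTensor objects instead of eagerly converting them to ndarrays, so the LoD produced by the shrink stays attached to the outputs. Below is a minimal sketch of working with such a tensor, assuming the fluid-era bindings of this commit (the `paddle.v2.fluid.core` import path, offset-based `set_lod`/`lod`, and conversion with `numpy.array` are all assumptions about that era's API).

```python
# Minimal sketch, assuming the fluid-era API of this commit (import path,
# offset-based set_lod/lod, and numpy.array conversion are assumptions).
import numpy
import paddle.v2.fluid.core as core

place = core.CPUPlace()
t = core.LoDTensor()
t.set(numpy.random.random(size=(6, 100)).astype('float32'), place)
t.set_lod([[0, 2, 5, 6]])      # one LoD level: 3 sequences over 6 rows

print(t.lod())                 # [[0, 2, 5, 6]]
print(numpy.array(t).shape)    # (6, 100); explicit conversion to ndarray
```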