提交 f947c153 编写于 作者: Y yangyaming

Consider multiple levels of LoD.

上级 66ae0a8c
...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/framework/lod_rank_table.h" #include "paddle/framework/lod_rank_table.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/operators/array_operator.h" #include "paddle/operators/array_operator.h"
#include "paddle/operators/math/math_function.h" #include "paddle/operators/math/math_function.h"
...@@ -49,12 +50,24 @@ class ShrinkRNNMemoryOp : public ArrayOp { ...@@ -49,12 +50,24 @@ class ShrinkRNNMemoryOp : public ArrayOp {
// should consider multiple levels // should consider multiple levels
size_t height = dst_num_rows; size_t height = dst_num_rows;
auto lod_level = lod_rank_table.level(); auto lod_level = rank_table.level();
if (x_tensor.lod().size() > lod_level && if (x_tensor.lod().size() > lod_level &&
x_tensor.lod()[lod_level].size() < dst_num_rows) { x_tensor.lod()[lod_level].size() > static_cast<size_t>(dst_num_rows)) {
auto lod_offset = framework::GetSubLoDAndAbsoluteOffset( auto lod_offset = framework::GetSubLoDAndAbsoluteOffset(
x_tensor.lod(), 0, dst_num_rows + 1, lod_level); x_tensor.lod(), 0, dst_num_rows, lod_level);
height = lod_offset.second.second; height = lod_offset.second.second;
auto out_lod = out_tensor.mutable_lod();
auto x_lod = x_tensor.lod();
out_lod->reserve(lod_level + lod_offset.first.size());
for (size_t i = 0; i < lod_level; ++i) {
out_lod->emplace_back(x_lod.at(i));
}
framework::LoD remain;
framework::AppendLoD(&remain, lod_offset.first);
for (size_t j = 0; j < remain.size(); ++j) {
out_lod->emplace_back(remain[j]);
}
} }
if (dst_num_rows != 0) { if (dst_num_rows != 0) {
......
...@@ -29,7 +29,9 @@ class TestShrinkRNNMemory(unittest.TestCase): ...@@ -29,7 +29,9 @@ class TestShrinkRNNMemory(unittest.TestCase):
tensor_np = numpy.random.random(size=(6, 100)).astype('float32') tensor_np = numpy.random.random(size=(6, 100)).astype('float32')
tensor.set(tensor_np, cpu) tensor.set(tensor_np, cpu)
exe = Executor(cpu) exe = Executor(cpu)
outs = exe.run(feed={'x': tensor}, fetch_list=[mem1, mem2, mem3]) outs = exe.run(feed={'x': tensor},
fetch_list=[mem1, mem2, mem3],
return_numpy=False)
self.assertTrue(numpy.allclose(tensor_np[0:6], outs[0])) self.assertTrue(numpy.allclose(tensor_np[0:6], outs[0]))
self.assertTrue(numpy.allclose(tensor_np[0:5], outs[1])) self.assertTrue(numpy.allclose(tensor_np[0:5], outs[1]))
self.assertTrue(numpy.allclose(tensor_np[0:2], outs[2])) self.assertTrue(numpy.allclose(tensor_np[0:2], outs[2]))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册