From 91f63cd40128dc74e1fe37e0ffaa072af22c10bb Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Sun, 29 Jul 2018 19:59:12 +0800 Subject: [PATCH] fix split_ids_op and add unit test --- paddle/fluid/operators/split_ids_op.h | 12 ++-- .../tests/unittests/test_split_ids_op.py | 60 +++++++++++++++++++ 2 files changed, 68 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/split_ids_op.h b/paddle/fluid/operators/split_ids_op.h index d263426e07..86a3eaa5c4 100644 --- a/paddle/fluid/operators/split_ids_op.h +++ b/paddle/fluid/operators/split_ids_op.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" @@ -68,9 +69,11 @@ class SplitIdsOpKernel : public framework::OpKernel { auto outs = ctx.MultiOutput("Out"); const size_t shard_num = outs.size(); // get rows for outputs - for (auto &id : ids_rows) { - size_t shard_id = static_cast(id) % shard_num; - outs[shard_id]->mutable_rows()->push_back(id); + std::map id_to_index; + for (size_t i = 0; i < ids_rows.size(); ++i) { + id_to_index[ids_rows[i]] = i; + size_t shard_id = static_cast(ids_rows[i]) % shard_num; + outs[shard_id]->mutable_rows()->push_back(ids_rows[i]); } int64_t row_width = ids_dims[1]; @@ -80,7 +83,8 @@ class SplitIdsOpKernel : public framework::OpKernel { {static_cast(out->rows().size()), row_width}); T *output = out->mutable_value()->mutable_data(ddim, place); for (int64_t i = 0; i < ddim[0]; ++i) { - memcpy(output + i * row_width, ids + out->rows()[i] * row_width, + memcpy(output + i * row_width, + ids + id_to_index[out->rows()[i]] * row_width, row_width * sizeof(T)); } } diff --git a/python/paddle/fluid/tests/unittests/test_split_ids_op.py b/python/paddle/fluid/tests/unittests/test_split_ids_op.py index e9f0a06a56..adf3345f1d 100644 --- a/python/paddle/fluid/tests/unittests/test_split_ids_op.py +++ b/python/paddle/fluid/tests/unittests/test_split_ids_op.py @@ -15,6 +15,8 @@ import unittest import numpy as np from op_test import OpTest +import paddle.fluid.core as core +from paddle.fluid.op import Operator class TestSplitIdsOp(OpTest): @@ -31,5 +33,63 @@ class TestSplitIdsOp(OpTest): self.check_output() +class TestSpliteIds(unittest.TestCase): + def get_places(self): + places = [core.CPUPlace()] + return places + + def test_check_output(self): + for place in self.get_places(): + self.check_with_place(place) + + def check_with_place(self, place): + scope = core.Scope() + rows = [0, 5, 7, 4, 9] + height = 20 + row_numel = 2 + + # initialize input variable X + x = scope.var('X').get_selected_rows() + x.set_rows(rows) + x.set_height(height) + np_array = np.ones((len(rows), row_numel)).astype("float32") + for i in range(len(rows)): + for j in range(row_numel): + np_array[i, j] = rows[i] + j + x_tensor = x.get_tensor() + x_tensor.set(np_array, place) + + outs_name = ["out%d" % i for i in xrange(3)] + outs = [ + scope.var(var_name).get_selected_rows() for var_name in outs_name + ] + + # expected output selected rows + expected_out0_rows = [0, 9] + expected_out1_rows = [7, 4] + expected_out2_rows = [5] + + op = Operator("split_ids", Ids="X", Out=outs_name) + + op.run(scope, place) + + self.assertEqual(outs[0].rows(), expected_out0_rows) + self.assertEqual(outs[1].rows(), expected_out1_rows) + self.assertEqual(outs[2].rows(), expected_out2_rows) + + self.assertAlmostEqual(0.0, np.array(outs[0].get_tensor())[0, 0]) + self.assertAlmostEqual(1.0, np.array(outs[0].get_tensor())[0, 1]) + self.assertAlmostEqual(9.0, np.array(outs[0].get_tensor())[1, 0]) + self.assertAlmostEqual(10.0, np.array(outs[0].get_tensor())[1, 1]) + + self.assertAlmostEqual(7.0, np.array(outs[1].get_tensor())[0, 0]) + self.assertAlmostEqual(8.0, np.array(outs[1].get_tensor())[0, 1]) + self.assertAlmostEqual(4.0, np.array(outs[1].get_tensor())[1, 0]) + self.assertAlmostEqual(5.0, np.array(outs[1].get_tensor())[1, 1]) + + self.assertAlmostEqual(5.0, np.array(outs[2].get_tensor())[0, 0]) + self.assertAlmostEqual(6.0, np.array(outs[2].get_tensor())[0, 1]) + + if __name__ == '__main__': unittest.main() -- GitLab