Commit 91f63cd4, authored by qiaolongfei

fix split_ids_op and add unit test

Parent 02c31458
paddle/fluid/operators/split_ids_op.h
@@ -14,6 +14,7 @@ limitations under the License. */

 #pragma once

+#include <map>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/math/selected_rows_functor.h"
@@ -68,9 +69,11 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
     auto outs = ctx.MultiOutput<framework::SelectedRows>("Out");
     const size_t shard_num = outs.size();

     // get rows for outputs
-    for (auto &id : ids_rows) {
-      size_t shard_id = static_cast<size_t>(id) % shard_num;
-      outs[shard_id]->mutable_rows()->push_back(id);
+    std::map<int64_t, size_t> id_to_index;
+    for (size_t i = 0; i < ids_rows.size(); ++i) {
+      id_to_index[ids_rows[i]] = i;
+      size_t shard_id = static_cast<size_t>(ids_rows[i]) % shard_num;
+      outs[shard_id]->mutable_rows()->push_back(ids_rows[i]);
     }

     int64_t row_width = ids_dims[1];
@@ -80,7 +83,8 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
           {static_cast<int64_t>(out->rows().size()), row_width});
       T *output = out->mutable_value()->mutable_data<T>(ddim, place);
       for (int64_t i = 0; i < ddim[0]; ++i) {
-        memcpy(output + i * row_width, ids + out->rows()[i] * row_width,
+        memcpy(output + i * row_width,
+               ids + id_to_index[out->rows()[i]] * row_width,
                row_width * sizeof(T));
       }
     }
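The kernel change above is the actual bug fix: each output's rows() stores the raw id values, but the old memcpy used that id value (out->rows()[i]) as the row offset into the input value buffer, so an id whose value exceeds the number of input rows would read the wrong (or out-of-range) data. The new id_to_index map records the original position of every id in ids_rows, and the copy gathers from that position instead. Below is a minimal NumPy sketch of the fixed logic, for illustration only; it is not Paddle code, and the names split_ids_reference, ids_rows, and ids_values are invented for this example.

import numpy as np

def split_ids_reference(ids_rows, ids_values, shard_num):
    # Illustrative reference of the fixed kernel, not Paddle code.
    # Map every id to its row position in the input (mirrors id_to_index).
    id_to_index = {id_: i for i, id_ in enumerate(ids_rows)}
    # Shard the ids by id % shard_num (mirrors the shard_id loop).
    out_rows = [[] for _ in range(shard_num)]
    for id_ in ids_rows:
        out_rows[id_ % shard_num].append(id_)
    # Gather each shard's values by the id's original index, not by the id itself.
    out_values = []
    for rows in out_rows:
        if rows:
            out_values.append(
                np.stack([ids_values[id_to_index[id_]] for id_ in rows]))
        else:
            out_values.append(
                np.empty((0, ids_values.shape[1]), dtype=ids_values.dtype))
    return out_rows, out_values

# Example input matching the unit test below: ids [0, 5, 7, 4, 9], value row [id, id + 1].
rows = [0, 5, 7, 4, 9]
values = np.array([[r, r + 1] for r in rows], dtype="float32")
out_rows, out_values = split_ids_reference(rows, values, shard_num=3)
# out_rows == [[0, 9], [7, 4], [5]]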
python/paddle/fluid/tests/unittests/test_split_ids_op.py
@@ -15,6 +15,8 @@
 import unittest
 import numpy as np
 from op_test import OpTest
+import paddle.fluid.core as core
+from paddle.fluid.op import Operator


 class TestSplitIdsOp(OpTest):
@@ -31,5 +33,63 @@ class TestSplitIdsOp(OpTest):
         self.check_output()


+class TestSpliteIds(unittest.TestCase):
+    def get_places(self):
+        places = [core.CPUPlace()]
+        return places
+
+    def test_check_output(self):
+        for place in self.get_places():
+            self.check_with_place(place)
+
+    def check_with_place(self, place):
+        scope = core.Scope()
+        rows = [0, 5, 7, 4, 9]
+        height = 20
+        row_numel = 2
+
+        # initialize input variable X
+        x = scope.var('X').get_selected_rows()
+        x.set_rows(rows)
+        x.set_height(height)
+        np_array = np.ones((len(rows), row_numel)).astype("float32")
+        for i in range(len(rows)):
+            for j in range(row_numel):
+                np_array[i, j] = rows[i] + j
+        x_tensor = x.get_tensor()
+        x_tensor.set(np_array, place)
+
+        outs_name = ["out%d" % i for i in xrange(3)]
+        outs = [
+            scope.var(var_name).get_selected_rows() for var_name in outs_name
+        ]
+
+        # expected output selected rows
+        expected_out0_rows = [0, 9]
+        expected_out1_rows = [7, 4]
+        expected_out2_rows = [5]
+
+        op = Operator("split_ids", Ids="X", Out=outs_name)
+
+        op.run(scope, place)
+
+        self.assertEqual(outs[0].rows(), expected_out0_rows)
+        self.assertEqual(outs[1].rows(), expected_out1_rows)
+        self.assertEqual(outs[2].rows(), expected_out2_rows)
+
+        self.assertAlmostEqual(0.0, np.array(outs[0].get_tensor())[0, 0])
+        self.assertAlmostEqual(1.0, np.array(outs[0].get_tensor())[0, 1])
+        self.assertAlmostEqual(9.0, np.array(outs[0].get_tensor())[1, 0])
+        self.assertAlmostEqual(10.0, np.array(outs[0].get_tensor())[1, 1])
+
+        self.assertAlmostEqual(7.0, np.array(outs[1].get_tensor())[0, 0])
+        self.assertAlmostEqual(8.0, np.array(outs[1].get_tensor())[0, 1])
+        self.assertAlmostEqual(4.0, np.array(outs[1].get_tensor())[1, 0])
+        self.assertAlmostEqual(5.0, np.array(outs[1].get_tensor())[1, 1])
+
+        self.assertAlmostEqual(5.0, np.array(outs[2].get_tensor())[0, 0])
+        self.assertAlmostEqual(6.0, np.array(outs[2].get_tensor())[0, 1])
+
+
 if __name__ == '__main__':
     unittest.main()
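The assertions in TestSpliteIds follow from two facts in the test setup: shard_id = id % 3 decides which output each id lands in, and the input row for an id is [id + 0, id + 1] (since np_array[i, j] = rows[i] + j). A short snippet, plain Python and not part of the test file, that reproduces the values being asserted:

# Illustration only: reproduce the rows/values the test asserts on.
rows = [0, 5, 7, 4, 9]
shard_num = 3
shards = [[] for _ in range(shard_num)]
for id_ in rows:
    shards[id_ % shard_num].append(id_)
print(shards)  # [[0, 9], [7, 4], [5]] -> expected_out0/1/2_rows
for s, ids in enumerate(shards):
    for r, id_ in enumerate(ids):
        # the value row copied for this id is [id + 0, id + 1]
        print("out%d[%d] = %s" % (s, r, [float(id_), float(id_ + 1)]))
# e.g. out0[1] = [9.0, 10.0], matching the assertAlmostEqual checks above.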