未验证 提交 b0cf1fe3 编写于 作者: Q Qiao Longfei 提交者: GitHub

Merge pull request #12430 from jacquesqiao/add-test-for-split-ids-op

Add test for split ids op
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
......@@ -67,10 +68,15 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
const auto &ids_rows = ids_selected_rows->rows();
auto outs = ctx.MultiOutput<framework::SelectedRows>("Out");
const size_t shard_num = outs.size();
for (auto &out : outs) {
out->mutable_rows()->clear();
}
// get rows for outputs
for (auto &id : ids_rows) {
size_t shard_id = static_cast<size_t>(id) % shard_num;
outs[shard_id]->mutable_rows()->push_back(id);
std::unordered_map<int64_t, size_t> id_to_index;
for (size_t i = 0; i < ids_rows.size(); ++i) {
id_to_index[ids_rows[i]] = i;
size_t shard_id = static_cast<size_t>(ids_rows[i]) % shard_num;
outs[shard_id]->mutable_rows()->push_back(ids_rows[i]);
}
int64_t row_width = ids_dims[1];
......@@ -80,7 +86,8 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
{static_cast<int64_t>(out->rows().size()), row_width});
T *output = out->mutable_value()->mutable_data<T>(ddim, place);
for (int64_t i = 0; i < ddim[0]; ++i) {
memcpy(output + i * row_width, ids + out->rows()[i] * row_width,
memcpy(output + i * row_width,
ids + id_to_index[out->rows()[i]] * row_width,
row_width * sizeof(T));
}
}
......
......@@ -15,6 +15,8 @@
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
from paddle.fluid.op import Operator
class TestSplitIdsOp(OpTest):
......@@ -31,5 +33,55 @@ class TestSplitIdsOp(OpTest):
self.check_output()
class TestSpliteIds(unittest.TestCase):
def get_places(self):
places = [core.CPUPlace()]
return places
def test_check_output(self):
for place in self.get_places():
self.check_with_place(place)
def check_with_place(self, place):
scope = core.Scope()
rows = [0, 5, 7, 4, 9]
height = 20
row_numel = 2
# initialize input variable X
x = scope.var('X').get_selected_rows()
x.set_rows(rows)
x.set_height(height)
np_array = np.ones((len(rows), row_numel)).astype("float32")
for i in range(len(rows)):
for j in range(row_numel):
np_array[i, j] = rows[i] + j
x_tensor = x.get_tensor()
x_tensor.set(np_array, place)
outs_name = ["out%d" % i for i in xrange(3)]
outs = [
scope.var(var_name).get_selected_rows() for var_name in outs_name
]
# expected output selected rows
expected_out_rows = [[0, 9], [7, 4], [5]]
op = Operator("split_ids", Ids="X", Out=outs_name)
for _ in range(3):
op.run(scope, place)
for i in range(len(outs)):
expected_rows = expected_out_rows[i]
self.assertEqual(outs[i].rows(), expected_rows)
for j in range(len(expected_rows)):
row = expected_rows[j]
self.assertAlmostEqual(
float(row), np.array(outs[i].get_tensor())[j, 0])
self.assertAlmostEqual(
float(row + 1), np.array(outs[i].get_tensor())[j, 1])
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册