提交 0a8ff2ec 编写于 作者: Q Qiao Longfei

add cpu_merge_add_multi_noduplicated_test test=develop

上级 920a9609
...@@ -360,6 +360,69 @@ TEST(selected_rows_functor, cpu_merge_add_multi) { ...@@ -360,6 +360,69 @@ TEST(selected_rows_functor, cpu_merge_add_multi) {
} }
} }
TEST(selected_rows_functor, cpu_merge_add_multi_noduplicated) {
paddle::platform::CPUPlace cpu_place;
paddle::platform::CPUDeviceContext ctx(cpu_place);
paddle::operators::math::SetConstant<paddle::platform::CPUDeviceContext,
float>
set_const;
int64_t height = 10;
int64_t row_numel = 8;
std::vector<int64_t> rows1{1, 3, 5, 7, 9};
std::unique_ptr<paddle::framework::SelectedRows> selected_rows1{
new paddle::framework::SelectedRows(rows1, height)};
auto* in1_value = selected_rows1->mutable_value();
in1_value->mutable_data<float>(
paddle::framework::make_ddim(
{static_cast<int64_t>(rows1.size()), row_numel}),
cpu_place);
set_const(ctx, in1_value, 1.0);
std::vector<int64_t> rows2{0, 2, 4, 6, 8};
std::unique_ptr<paddle::framework::SelectedRows> selected_rows2{
new paddle::framework::SelectedRows(rows2, height)};
auto* in2_value = selected_rows2->mutable_value();
in2_value->mutable_data<float>(
paddle::framework::make_ddim(
{static_cast<int64_t>(rows2.size()), row_numel}),
cpu_place);
set_const(ctx, in2_value, 2.0);
std::unique_ptr<paddle::framework::SelectedRows> output{
new paddle::framework::SelectedRows()};
output->set_height(height);
paddle::operators::math::scatter::MergeAdd<paddle::platform::CPUDeviceContext,
float>
merge_add_functor;
std::vector<const paddle::framework::SelectedRows*> inputs;
inputs.push_back(selected_rows1.get());
inputs.push_back(selected_rows2.get());
merge_add_functor(ctx, inputs, output.get());
EXPECT_EQ(output->height(), height);
EXPECT_EQ(output->value().dims(),
paddle::framework::make_ddim({10, row_numel}));
std::vector<int64_t> ret_rows{1, 3, 5, 7, 9, 0, 2, 4, 6, 8};
EXPECT_EQ(output->rows(), ret_rows);
auto* out_data = output->value().data<float>();
for (size_t i = 0; i < ret_rows.size(); ++i) {
float data_value = 0;
if (i < 5) {
data_value = 1.0;
} else {
data_value = 2.0;
}
for (size_t j = 0; j < static_cast<size_t>(row_numel); ++j) {
EXPECT_EQ(out_data[i * row_numel + j], data_value);
}
}
}
TEST(selected_rows_functor, cpu_sum_to) { TEST(selected_rows_functor, cpu_sum_to) {
paddle::platform::CPUPlace cpu_place; paddle::platform::CPUPlace cpu_place;
paddle::platform::CPUDeviceContext ctx(cpu_place); paddle::platform::CPUDeviceContext ctx(cpu_place);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册