diff --git a/paddle/fluid/operators/distributed/collective_server_test.cc b/paddle/fluid/operators/distributed/collective_server_test.cc index 5009058422b81d3187f7792bf7bf56db1d03f4d6..90f2f9fd65bf1b8c1edda6a2ebe0ce5288ddcb5d 100644 --- a/paddle/fluid/operators/distributed/collective_server_test.cc +++ b/paddle/fluid/operators/distributed/collective_server_test.cc @@ -20,6 +20,7 @@ limitations under the License. */ #include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/operators/distributed/collective_client.h" #include "paddle/fluid/operators/distributed/collective_server.h" @@ -57,7 +58,7 @@ std::unique_ptr GenerateVars(platform::Place place) { auto* tensor = slr->mutable_value(); auto* rows = slr->mutable_rows(); - tensor->Resize(framework::make_ddim({20000, 1024})); + tensor->Resize(framework::make_ddim({3, 1024})); tensor->mutable_data(place); paddle::operators::math::set_constant(ctx, tensor, 32.7); @@ -80,6 +81,20 @@ void Gather(const std::vector& vars, std::vector dst; client->Gather(vars, &dst, *dev_ctx, scope); std::cout << "dst:" << distributed::GetSelectedRowsInfo(*dst[0]); + dev_ctx->Wait(); + + ASSERT_EQ(dst[0]->value().dims(), framework::make_ddim({3, 1024})); + ASSERT_EQ(dst[0]->height(), 20000); + ASSERT_EQ(dst[0]->rows().size(), static_cast(3)); + for (int i = 0; i < 3; i++) { + ASSERT_EQ(dst[0]->rows()[i], i); + } + + std::vector vec; + TensorToVector(dst[0]->value(), *dev_ctx, &vec); + for (size_t i = 0; i < 3 * 1024; i++) { + ASSERT_FLOAT_EQ(vec[i], 32.7); + } } TEST(CollectiveServer, GPU) {