未验证 提交 f4dec5cd 编写于 作者: G gongweibao 提交者: GitHub

Check collective server's data. (#15449)

上级 58727e8e
...@@ -20,6 +20,7 @@ limitations under the License. */ ...@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/distributed/collective_client.h" #include "paddle/fluid/operators/distributed/collective_client.h"
#include "paddle/fluid/operators/distributed/collective_server.h" #include "paddle/fluid/operators/distributed/collective_server.h"
...@@ -57,7 +58,7 @@ std::unique_ptr<framework::Scope> GenerateVars(platform::Place place) { ...@@ -57,7 +58,7 @@ std::unique_ptr<framework::Scope> GenerateVars(platform::Place place) {
auto* tensor = slr->mutable_value(); auto* tensor = slr->mutable_value();
auto* rows = slr->mutable_rows(); auto* rows = slr->mutable_rows();
tensor->Resize(framework::make_ddim({20000, 1024})); tensor->Resize(framework::make_ddim({3, 1024}));
tensor->mutable_data<float>(place); tensor->mutable_data<float>(place);
paddle::operators::math::set_constant(ctx, tensor, 32.7); paddle::operators::math::set_constant(ctx, tensor, 32.7);
...@@ -80,6 +81,20 @@ void Gather(const std::vector<distributed::RemoteVar>& vars, ...@@ -80,6 +81,20 @@ void Gather(const std::vector<distributed::RemoteVar>& vars,
std::vector<const framework::SelectedRows*> dst; std::vector<const framework::SelectedRows*> dst;
client->Gather(vars, &dst, *dev_ctx, scope); client->Gather(vars, &dst, *dev_ctx, scope);
std::cout << "dst:" << distributed::GetSelectedRowsInfo(*dst[0]); std::cout << "dst:" << distributed::GetSelectedRowsInfo(*dst[0]);
dev_ctx->Wait();
ASSERT_EQ(dst[0]->value().dims(), framework::make_ddim({3, 1024}));
ASSERT_EQ(dst[0]->height(), 20000);
ASSERT_EQ(dst[0]->rows().size(), static_cast<size_t>(3));
for (int i = 0; i < 3; i++) {
ASSERT_EQ(dst[0]->rows()[i], i);
}
std::vector<float> vec;
TensorToVector(dst[0]->value(), *dev_ctx, &vec);
for (size_t i = 0; i < 3 * 1024; i++) {
ASSERT_FLOAT_EQ(vec[i], 32.7);
}
} }
TEST(CollectiveServer, GPU) { TEST(CollectiveServer, GPU) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册