From 80cc4f0d87aa2291ce45649709f3880c0d84d962 Mon Sep 17 00:00:00 2001 From: Yulong Ao Date: Mon, 8 Aug 2022 14:17:35 +0800 Subject: [PATCH] [Auto Parallel] Add the C++ ProcessMesh and DistributedMapper (#44963) --- .../distributed/auto_parallel/CMakeLists.txt | 44 ++---- .../auto_parallel/auto_parallel.proto | 32 ++++ .../distributed/auto_parallel/dist_mapper.cc | 146 ++++++++++++++++++ .../distributed/auto_parallel/dist_mapper.h | 73 +++++++++ .../auto_parallel/dist_mapper_test.cc | 72 +++++++++ .../distributed/auto_parallel/process_mesh.cc | 134 ++++++++++++++++ .../distributed/auto_parallel/process_mesh.h | 94 +++++++++++ .../auto_parallel/process_mesh_test.cc | 53 +++++++ 8 files changed, 620 insertions(+), 28 deletions(-) create mode 100644 paddle/fluid/distributed/auto_parallel/dist_mapper.cc create mode 100644 paddle/fluid/distributed/auto_parallel/dist_mapper.h create mode 100644 paddle/fluid/distributed/auto_parallel/dist_mapper_test.cc create mode 100644 paddle/fluid/distributed/auto_parallel/process_mesh.cc create mode 100644 paddle/fluid/distributed/auto_parallel/process_mesh.h create mode 100644 paddle/fluid/distributed/auto_parallel/process_mesh_test.cc diff --git a/paddle/fluid/distributed/auto_parallel/CMakeLists.txt b/paddle/fluid/distributed/auto_parallel/CMakeLists.txt index 192871c73c..49f4547631 100644 --- a/paddle/fluid/distributed/auto_parallel/CMakeLists.txt +++ b/paddle/fluid/distributed/auto_parallel/CMakeLists.txt @@ -7,34 +7,22 @@ cc_test( SRCS device_mesh_test.cc DEPS device_mesh) -# cc_library( -# process_mesh -# SRCS process_mesh.cc -# DEPS auto_parallel_proto) -# cc_test( -# process_mesh_test -# SRCS process_mesh_test.cc -# DEPS process_mesh) - -# cc_library( -# dist_attr -# SRCS dist_attr.cc -# DEPS process_mesh auto_parallel_proto proto_desc) -# cc_test( -# dist_attr_test -# SRCS dist_attr_test.cc -# DEPS dist_attr) +cc_library( + process_mesh + SRCS process_mesh.cc + DEPS auto_parallel_proto) +cc_test( + process_mesh_test + SRCS process_mesh_test.cc + DEPS process_mesh) -# cc_library( -# dist_mapper -# SRCS dist_mapper.cc -# DEPS device_mesh auto_parallel_proto) -# cc_test( -# dist_mapper_test -# SRCS dist_mapper_test.cc -# DEPS dist_mapper) +cc_library( + dist_mapper + SRCS dist_mapper.cc + DEPS device_mesh auto_parallel_proto) +cc_test( + dist_mapper_test + SRCS dist_mapper_test.cc + DEPS dist_mapper) proto_library(auto_parallel_proto SRCS auto_parallel.proto) - -# cc_library(auto_parallel DEPS process_mesh device_mesh dist_attr dist_mapper -# auto_parallel_proto) diff --git a/paddle/fluid/distributed/auto_parallel/auto_parallel.proto b/paddle/fluid/distributed/auto_parallel/auto_parallel.proto index 5625737c44..1413e80a8a 100644 --- a/paddle/fluid/distributed/auto_parallel/auto_parallel.proto +++ b/paddle/fluid/distributed/auto_parallel/auto_parallel.proto @@ -16,6 +16,20 @@ syntax = "proto2"; package paddle.distributed.auto_parallel; +// ProcessMesh is used to organize processes and like n-dimension array. +message ProcessMeshProto { + // The size of each dimension. + repeated int64 shape = 1; + + // These process ids are stored by a row-major way. + // There are no duplicate process ids within one process mesh. + repeated int64 process_ids = 2; + + // The name of each dimension. + repeated string dim_names = 3; + +} + // This proto describes the capability of one device such as the computation and memory. message DeviceCapabilityProto { optional double single_precision_flops = 1; @@ -86,3 +100,21 @@ message DeviceMeshProto { // The links are between devices. repeated LinkProto links = 6; } + +// Record the mapping between the logical processes and the physical devices. +message DistributedMapperProto { + // The device meshes used by this distributed computation, + // which may be shared by different multiple device meshes. + repeated DeviceMeshProto device_meshes = 1; + + message MapperEntryProto { + optional int64 process_id = 1; + optional string device_mesh_name = 2; + repeated int64 device_ids = 3; + } + + // The mapping from process ids to device ids. + // It is also possible for one process to use multiple devices. + // It is possible for one device shared by multiple processes. + repeated MapperEntryProto process_id_to_device_ids = 2; +} diff --git a/paddle/fluid/distributed/auto_parallel/dist_mapper.cc b/paddle/fluid/distributed/auto_parallel/dist_mapper.cc new file mode 100644 index 0000000000..d099560452 --- /dev/null +++ b/paddle/fluid/distributed/auto_parallel/dist_mapper.cc @@ -0,0 +1,146 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#include "paddle/fluid/distributed/auto_parallel/dist_mapper.h" +#include "paddle/fluid/distributed/auto_parallel/utils.h" + +namespace paddle { +namespace distributed { +namespace auto_parallel { + +void DistributedMapper::set_process_id_to_device_ids( + const std::map>>& + process_id_to_device_ids) { + std::vector device_mesh_names; + for (const auto& item : device_meshes_) { + device_mesh_names.push_back(item.first); + } + for (const auto& item : process_id_to_device_ids) { + PADDLE_ENFORCE_GE( + item.first, + 0, + platform::errors::InvalidArgument( + "The process id %d must be greater than or equal to 0.", + item.first)); + std::string device_mesh_name = item.second.first; + const std::vector& device_ids = item.second.second; + PADDLE_ENFORCE_EQ( + device_meshes_.count(device_mesh_name), + 1, + platform::errors::InvalidArgument( + "Cannot find the device mesh %d in device_mesh ids [%s].", + device_mesh_name, + str_join(device_mesh_names))); + PADDLE_ENFORCE_EQ( + has_duplicates(device_ids), + false, + platform::errors::InvalidArgument( + "The mapped device ids [%s] of process_mesh %d must be unique.", + str_join(device_ids), + item.first)); + const DeviceMesh& device_mesh = device_meshes_[device_mesh_name]; + const std::vector cur_device_ids = device_mesh.device_ids(); + for (int64_t device_id : device_ids) { + bool found = + std::find(cur_device_ids.begin(), cur_device_ids.end(), device_id) != + cur_device_ids.end(); + PADDLE_ENFORCE_EQ( + found, + true, + platform::errors::InvalidArgument( + "The device id %d cannot be find in the device mesh [%s].", + device_id, + str_join(cur_device_ids))); + } + } + process_id_to_device_ids_ = process_id_to_device_ids; +} + +DistributedMapper DistributedMapper::from_proto( + const DistributedMapperProto& proto) { + DistributedMapper dist_mapper; + for (int64_t i = 0; i < proto.device_meshes_size(); ++i) { + dist_mapper.device_meshes_[proto.device_meshes(i).name()] = + DeviceMesh::from_proto(proto.device_meshes(i)); + } + for (int64_t i = 0; i < proto.process_id_to_device_ids_size(); ++i) { + int64_t process_id = proto.process_id_to_device_ids(i).process_id(); + std::string device_mesh_name = + proto.process_id_to_device_ids(i).device_mesh_name(); + std::vector device_ids; + int64_t num_devices = proto.process_id_to_device_ids(i).device_ids_size(); + for (int64_t j = 0; j < num_devices; ++j) { + device_ids.push_back(proto.process_id_to_device_ids(i).device_ids(j)); + } + dist_mapper.process_id_to_device_ids_[process_id].first = device_mesh_name; + dist_mapper.process_id_to_device_ids_[process_id].second = device_ids; + } + return dist_mapper; +} + +DistributedMapperProto DistributedMapper::to_proto() const { + DistributedMapperProto proto; + for (const auto& item : device_meshes_) { + proto.mutable_device_meshes()->Add()->CopyFrom(item.second.to_proto()); + } + for (const auto& outer : process_id_to_device_ids_) { + auto proto_item = proto.mutable_process_id_to_device_ids()->Add(); + proto_item->set_process_id(outer.first); + proto_item->set_device_mesh_name(outer.second.first); + for (const auto& inner : outer.second.second) { + proto_item->add_device_ids(inner); + } + } + return proto; +} + +std::string DistributedMapper::to_string() const { + std::string mapper_str = "{device_meshes: ["; + for (const auto& item : device_meshes_) { + mapper_str += item.second.to_string() + ", "; + } + mapper_str.replace(mapper_str.size() - 2, 2, "]"); + + mapper_str += "\nprocess_id_to_device_ids: ["; + for (const auto& item : process_id_to_device_ids_) { + mapper_str += "{"; + mapper_str += + "process_id: " + std::to_string(item.first) + ", device_ids: ["; + for (const auto& device_id : item.second.second) { + mapper_str += + "{" + item.second.first + ", " + std::to_string(device_id) + "}, "; + } + mapper_str.replace(mapper_str.size() - 2, 2, "]"); + mapper_str += "}, "; + } + mapper_str.replace(mapper_str.size() - 2, 2, "]"); + mapper_str += "}"; + return mapper_str; +} + +bool operator==(const DistributedMapper& lhs, const DistributedMapper& rhs) { + if (lhs.device_meshes() != rhs.device_meshes()) { + return false; + } + if (lhs.process_id_to_device_ids() != rhs.process_id_to_device_ids()) { + return false; + } + return true; +} + +} // namespace auto_parallel +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/auto_parallel/dist_mapper.h b/paddle/fluid/distributed/auto_parallel/dist_mapper.h new file mode 100644 index 0000000000..bd7f9790ad --- /dev/null +++ b/paddle/fluid/distributed/auto_parallel/dist_mapper.h @@ -0,0 +1,73 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#pragma once + +#include + +#include "paddle/fluid/distributed/auto_parallel/auto_parallel.pb.h" +#include "paddle/fluid/distributed/auto_parallel/device_mesh.h" +#include "paddle/fluid/distributed/auto_parallel/process_mesh.h" + +namespace paddle { +namespace distributed { +namespace auto_parallel { + +class DistributedMapper { + public: + DistributedMapper() = default; + + const std::map& device_meshes() const { + return device_meshes_; + } + + const DeviceMesh& device_mesh(const std::string& name) const { + return device_meshes_.at(name); + } + + void add_device_mesh(const DeviceMesh& device_mesh) { + device_meshes_[device_mesh.name()] = device_mesh; + } + + const std::map>>& + process_id_to_device_ids() const { + return process_id_to_device_ids_; + } + + void set_process_id_to_device_ids( + const std::map>>& + process_id_to_device_ids); + + // DistributedMapper from_string(const std::string& mapper_str); + std::string to_string() const; + + static DistributedMapper from_proto(const DistributedMapperProto& proto); + DistributedMapperProto to_proto() const; + + private: + std::map device_meshes_; + std::map>> + process_id_to_device_ids_; +}; + +bool operator==(const DistributedMapper& lhs, const DistributedMapper& rhs); + +inline std::ostream& operator<<(std::ostream& os, + const DistributedMapper& obj) { + os << obj.to_string(); + return os; +} + +} // namespace auto_parallel +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/auto_parallel/dist_mapper_test.cc b/paddle/fluid/distributed/auto_parallel/dist_mapper_test.cc new file mode 100644 index 0000000000..d427b9cbb0 --- /dev/null +++ b/paddle/fluid/distributed/auto_parallel/dist_mapper_test.cc @@ -0,0 +1,72 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/distributed/auto_parallel/dist_mapper.h" +#include +#include +#include "gtest/gtest.h" + +namespace paddle { +namespace distributed { +namespace auto_parallel { + +TEST(DistributedMapper, Ctor) { + std::vector shape = {2, 3}; + std::vector device_ids = {0, 1, 2, 3, 4, 5}; + std::vector dim_names = {"x", "y"}; + std::string device_type = "GPU"; + int64_t size = shape[0] * shape[1]; + + DeviceMesh device_mesh("device_mesh", shape, device_ids, dim_names); + for (int64_t i = 0; i < shape[0]; ++i) { + for (int64_t j = 0; j < shape[1]; ++j) { + int64_t global_id = i * shape[1] + j; + int64_t local_id = j; + int64_t machine_id = i; + device_mesh.add_device( + Device(global_id, local_id, machine_id, device_type)); + } + } + for (int64_t i = 0; i < size; ++i) { + for (int64_t j = 0; j < size; ++j) { + device_mesh.add_link(Link(i, j, "NVL")); + } + } + + DistributedMapper dist_mapper; + dist_mapper.add_device_mesh(device_mesh); + std::map>> + process_id_to_device_ids; + process_id_to_device_ids[0] = {"device_mesh", {5}}; + process_id_to_device_ids[1] = {"device_mesh", {4}}; + process_id_to_device_ids[2] = {"device_mesh", {3}}; + process_id_to_device_ids[3] = {"device_mesh", {2}}; + process_id_to_device_ids[4] = {"device_mesh", {1}}; + process_id_to_device_ids[5] = {"device_mesh", {0}}; + dist_mapper.set_process_id_to_device_ids(process_id_to_device_ids); + + EXPECT_EQ(dist_mapper.device_meshes().at("device_mesh"), device_mesh); + EXPECT_EQ(dist_mapper.device_mesh("device_mesh"), device_mesh); + EXPECT_EQ(dist_mapper.process_id_to_device_ids(), process_id_to_device_ids); + std::stringstream sstream; + sstream << dist_mapper; + EXPECT_EQ(sstream.str(), dist_mapper.to_string()); + auto proto = dist_mapper.to_proto(); + DistributedMapper new_dist_mapper = DistributedMapper::from_proto(proto); + EXPECT_EQ(dist_mapper, new_dist_mapper); +} + +} // namespace auto_parallel +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/auto_parallel/process_mesh.cc b/paddle/fluid/distributed/auto_parallel/process_mesh.cc new file mode 100644 index 0000000000..dda2873768 --- /dev/null +++ b/paddle/fluid/distributed/auto_parallel/process_mesh.cc @@ -0,0 +1,134 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "paddle/fluid/distributed/auto_parallel/process_mesh.h" +#include "paddle/fluid/distributed/auto_parallel/utils.h" + +namespace paddle { +namespace distributed { +namespace auto_parallel { + +ProcessMesh::ProcessMesh(const std::vector &shape, + const std::vector &process_ids, + const std::vector &dim_names) { + shape_ = shape; + int64_t size = this->size(); + PADDLE_ENFORCE_EQ( + size, + process_ids.size(), + platform::errors::InvalidArgument("The size of this process mesh must be " + "equal to the size of its process ids.", + size, + process_ids.size())); + PADDLE_ENFORCE_EQ( + has_duplicates(process_ids), + false, + platform::errors::InvalidArgument("The process ids [%s] must be unique.", + str_join(process_ids_))); + process_ids_ = process_ids; + + PADDLE_ENFORCE_EQ(shape_.size(), + dim_names.size(), + platform::errors::InvalidArgument( + "The size of mesh shape must be equal to the size " + "of the dimension names.", + shape_.size(), + dim_names_.size())); + PADDLE_ENFORCE_EQ(has_duplicates(dim_names), + false, + platform::errors::InvalidArgument( + "The names [%s] of each dimension must be unique.", + str_join(dim_names))); + dim_names_ = dim_names; +} + +int64_t ProcessMesh::size() const { + if (shape_.empty()) return 0; + int64_t size = 1; + for (const int64_t dim_size : shape_) size *= dim_size; + return size; +} + +bool ProcessMesh::contains(int64_t process_id) const { + auto result = + std::find(std::begin(process_ids_), std::end(process_ids_), process_id); + if (result != std::end(process_ids_)) { + return true; + } else { + return false; + } +} + +std::string ProcessMesh::to_string() const { + std::string mesh_str = "{shape: [" + str_join(shape_) + "], "; + mesh_str += "process_ids: [" + str_join(process_ids_) + "], "; + mesh_str += "dim_names: [" + str_join(dim_names_) + "]}"; + return mesh_str; +} + +ProcessMesh ProcessMesh::from_proto(const ProcessMeshProto &proto) { + ProcessMesh mesh; + + mesh.shape_.resize(proto.shape_size()); + for (int64_t i = 0; i < proto.shape_size(); ++i) { + mesh.shape_[i] = proto.shape(i); + } + + mesh.process_ids_.resize(proto.process_ids_size()); + for (int64_t i = 0; i < proto.process_ids_size(); ++i) { + mesh.process_ids_[i] = proto.process_ids(i); + } + + mesh.dim_names_.resize(proto.dim_names_size()); + for (int64_t i = 0; i < proto.dim_names_size(); ++i) { + mesh.dim_names_[i] = proto.dim_names(i); + } + + return mesh; +} + +ProcessMeshProto ProcessMesh::to_proto() const { + ProcessMeshProto proto; + + for (const auto &i : shape_) { + proto.add_shape(i); + } + + for (const auto &i : process_ids_) { + proto.add_process_ids(i); + } + + for (const auto &i : dim_names_) { + proto.add_dim_names(i); + } + + return proto; +} + +bool operator==(const ProcessMesh &lhs, const ProcessMesh &rhs) { + if (lhs.shape() != rhs.shape()) { + return false; + } + if (lhs.process_ids() != rhs.process_ids()) { + return false; + } + return true; +} + +} // namespace auto_parallel +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/auto_parallel/process_mesh.h b/paddle/fluid/distributed/auto_parallel/process_mesh.h new file mode 100644 index 0000000000..2652a8f606 --- /dev/null +++ b/paddle/fluid/distributed/auto_parallel/process_mesh.h @@ -0,0 +1,94 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "paddle/fluid/distributed/auto_parallel/auto_parallel.pb.h" +#include "paddle/fluid/distributed/auto_parallel/device_mesh.h" +#include "paddle/fluid/distributed/auto_parallel/utils.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace distributed { +namespace auto_parallel { + +class ProcessMesh { + public: + ProcessMesh() = default; + + ProcessMesh(const std::vector& shape, + const std::vector& process_ids, + const std::vector& dim_names); + + const std::vector& shape() const { return shape_; } + + const std::vector& process_ids() const { return process_ids_; } + + const std::vector& dim_names() const { return dim_names_; } + + int64_t size() const; + + int64_t ndim() const { return shape_.size(); } + + int64_t dim_size(int64_t dim) const { + int64_t cdim = canonical_dim(dim, shape_.size()); + return shape_[cdim]; + } + + int64_t dim_size(const std::string& dim_name) const { + for (std::size_t i = 0; i < dim_names_.size(); ++i) { + if (dim_names_[i] == dim_name) { + return shape_[i]; + } + } + PADDLE_THROW(platform::errors::InvalidArgument( + "Cannot find the dimension of %s in this process mesh.", dim_name)); + } + + bool empty() const { return (shape_.empty() || process_ids_.empty()); } + bool contains(int64_t process_id) const; + + // ProcessMesh from_string(const std::string& mesh_str); + std::string to_string() const; + + static ProcessMesh from_proto(const ProcessMeshProto& proto); + ProcessMeshProto to_proto() const; + + private: + std::vector shape_; + std::vector process_ids_; + std::vector dim_names_; +}; + +inline std::ostream& operator<<(std::ostream& os, const ProcessMesh& obj) { + os << obj.to_string(); + return os; +} + +bool operator==(const ProcessMesh& lhs, const ProcessMesh& rhs); + +inline bool operator!=(const ProcessMesh& lhs, const ProcessMesh& rhs) { + return !operator==(lhs, rhs); +} + +} // namespace auto_parallel +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/auto_parallel/process_mesh_test.cc b/paddle/fluid/distributed/auto_parallel/process_mesh_test.cc new file mode 100644 index 0000000000..9dbcc5ea2d --- /dev/null +++ b/paddle/fluid/distributed/auto_parallel/process_mesh_test.cc @@ -0,0 +1,53 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/distributed/auto_parallel/process_mesh.h" +#include +#include +#include "gtest/gtest.h" + +namespace paddle { +namespace distributed { +namespace auto_parallel { + +TEST(ProcessMesh, Ctor) { + std::vector shape = {2, 3}; + std::vector process_ids = {0, 1, 2, 3, 4, 5}; + std::vector dim_names = {"x", "y"}; + int64_t size = shape[0] * shape[1]; + ProcessMesh process_mesh(shape, process_ids, dim_names); + EXPECT_EQ(process_mesh.shape(), shape); + EXPECT_EQ(process_mesh.process_ids(), process_ids); + EXPECT_EQ(process_mesh.dim_names()[0], "x"); + EXPECT_EQ(process_mesh.dim_names()[1], "y"); + EXPECT_EQ(process_mesh.size(), size); + EXPECT_EQ(process_mesh.ndim(), static_cast(shape.size())); + EXPECT_EQ(process_mesh.dim_size(0), shape[0]); + EXPECT_EQ(process_mesh.dim_size(-1), shape[1]); + EXPECT_EQ(process_mesh.dim_size("x"), shape[0]); + EXPECT_EQ(process_mesh.dim_size("y"), shape[1]); + EXPECT_EQ(process_mesh.empty(), false); + EXPECT_EQ(process_mesh.contains(0), true); + EXPECT_EQ(process_mesh.contains(6), false); + std::stringstream sstream; + sstream << process_mesh; + EXPECT_EQ(sstream.str(), process_mesh.to_string()); + auto proto = process_mesh.to_proto(); + ProcessMesh new_process_mesh = ProcessMesh::from_proto(proto); + EXPECT_EQ(process_mesh, new_process_mesh); +} + +} // namespace auto_parallel +} // namespace distributed +} // namespace paddle -- GitLab