Unverified commit d99cb2e1, authored by Yulong Ao, committed by GitHub

[Auto Parallel] Add C++ DeviceMesh (#44951)

* [Auto Parallel] Add ProcessMesh, DeviceMesh and DistributedMapper

* [Auto Parallel] Remove unnecessary code

* [Auto Parallel] Comment out unnecessary cmake statements
Parent 839d8bb3
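For orientation before the file-by-file diff, here is a minimal usage sketch of the new C++ DeviceMesh API. It mirrors the device_mesh_test.cc added by this commit; the mesh name, shape, and link type are simply the values used in that test.

#include <iostream>
#include "paddle/fluid/distributed/auto_parallel/device_mesh.h"

int main() {
  using paddle::distributed::auto_parallel::Device;
  using paddle::distributed::auto_parallel::DeviceMesh;
  using paddle::distributed::auto_parallel::Link;

  // A 2 x 3 mesh over global device ids 0..5 with named dimensions.
  DeviceMesh mesh("mesh", {2, 3}, {0, 1, 2, 3, 4, 5}, {"x", "y"});

  // Register each device: (global_id, local_id, machine_id, type).
  for (int64_t i = 0; i < 2; ++i) {
    for (int64_t j = 0; j < 3; ++j) {
      mesh.add_device(Device(i * 3 + j, j, i, "GPU"));
    }
  }

  // Register a directed link: (source_id, target_id, type).
  mesh.add_link(Link(3, 4, "NVL"));

  std::cout << mesh << std::endl;  // operator<< forwards to to_string()

  // Round-trip through the protobuf representation.
  auto proto = mesh.to_proto();
  DeviceMesh restored = DeviceMesh::from_proto(proto);
  std::cout << restored.name() << std::endl;
  return 0;
}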
@@ -47,3 +47,4 @@ add_subdirectory(ps)
add_subdirectory(test)
add_subdirectory(index_dataset)
add_subdirectory(fleet_executor)
add_subdirectory(auto_parallel)
cc_library(
device_mesh
SRCS device_mesh.cc
DEPS auto_parallel_proto)
cc_test(
device_mesh_test
SRCS device_mesh_test.cc
DEPS device_mesh)
# cc_library(
# process_mesh
# SRCS process_mesh.cc
# DEPS auto_parallel_proto)
# cc_test(
# process_mesh_test
# SRCS process_mesh_test.cc
# DEPS process_mesh)
# cc_library(
# dist_attr
# SRCS dist_attr.cc
# DEPS process_mesh auto_parallel_proto proto_desc)
# cc_test(
# dist_attr_test
# SRCS dist_attr_test.cc
# DEPS dist_attr)
# cc_library(
# dist_mapper
# SRCS dist_mapper.cc
# DEPS device_mesh auto_parallel_proto)
# cc_test(
# dist_mapper_test
# SRCS dist_mapper_test.cc
# DEPS dist_mapper)
proto_library(auto_parallel_proto SRCS auto_parallel.proto)
# cc_library(auto_parallel DEPS process_mesh device_mesh dist_attr dist_mapper
# auto_parallel_proto)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
syntax = "proto2";
package paddle.distributed.auto_parallel;
// This proto describes the capability of one device, such as its compute and memory.
message DeviceCapabilityProto {
optional double single_precision_flops = 1;
optional double double_precision_flops = 2;
optional double memory_size_in_bytes = 3;
optional double clock_rate_in_ghz = 4;
}
// This proto represents a device.
message DeviceProto {
// The global id of this device within the cluster.
optional int64 global_id = 1;
// The local id of this device within the machine.
optional int64 local_id = 2;
// The id of the machine that owns this device.
optional int64 machine_id = 3;
// The type of this device.
optional string type = 4;
// The capability of this device.
optional DeviceCapabilityProto capability = 5;
}
// This proto describes the capability of the link between two devices.
message LinkCapabilityProto {
optional int64 bandwidth = 1; // Bytes/s
optional int64 latency = 2;
}
message LinkProto {
// The global id of the source device.
optional int64 source_id = 1;
// The global id of the target device.
optional int64 target_id = 2;
// The type of this link.
optional string type = 3;
// The capability of this link.
optional LinkCapabilityProto capability = 4;
}
// DeviceMesh organizes devices, similar to an n-dimensional array.
message DeviceMeshProto {
// The unique name of this mesh.
optional string name = 1;
// The size of each dimension.
repeated int64 shape = 2;
// These device ids are stored in row-major order.
// There are no duplicate device ids within one device mesh.
repeated int64 device_ids = 3;
// The name of each dimension.
repeated string dim_names = 4;
// The devices of this mesh.
repeated DeviceProto devices = 5;
// The links between devices in this mesh.
repeated LinkProto links = 6;
}
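To make the message layout above concrete, here is a rough sketch of populating a DeviceMeshProto by hand through the generated C++ setters (the same setters the to_proto() methods below rely on); the mesh name and all values are illustrative only.

#include "paddle/fluid/distributed/auto_parallel/auto_parallel.pb.h"

paddle::distributed::auto_parallel::DeviceMeshProto BuildSampleMeshProto() {
  paddle::distributed::auto_parallel::DeviceMeshProto proto;
  proto.set_name("sample_mesh");
  // A 1 x 2 mesh; device ids are row-major and unique within the mesh.
  proto.add_shape(1);
  proto.add_shape(2);
  proto.add_device_ids(0);
  proto.add_device_ids(1);
  proto.add_dim_names("x");
  proto.add_dim_names("y");
  // Describe one device; the others follow the same pattern.
  auto* device = proto.mutable_devices()->Add();
  device->set_global_id(0);
  device->set_local_id(0);
  device->set_machine_id(0);
  device->set_type("GPU");
  // A directed link between two devices of the mesh.
  auto* link = proto.mutable_links()->Add();
  link->set_source_id(0);
  link->set_target_id(1);
  link->set_type("NVL");
  return proto;
}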
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <iterator>
#include "paddle/fluid/distributed/auto_parallel/device_mesh.h"
#include "paddle/fluid/distributed/auto_parallel/utils.h"
namespace paddle {
namespace distributed {
namespace auto_parallel {
std::string DeviceCapability::to_string() const {
std::string str;
str += "{sflops: " + to_string_with_precision(single_precision_flops) + ", ";
str += "dflops: " + to_string_with_precision(double_precision_flops) + ", ";
str += "memory: " + to_string_with_precision(memory_size_in_bytes) + ", ";
str += "rate: " + to_string_with_precision(clock_rate_in_ghz) + "}";
return str;
}
DeviceCapability DeviceCapability::from_proto(
const DeviceCapabilityProto &proto) {
DeviceCapability capability;
capability.single_precision_flops = proto.single_precision_flops();
capability.double_precision_flops = proto.double_precision_flops();
capability.memory_size_in_bytes = proto.memory_size_in_bytes();
capability.clock_rate_in_ghz = proto.clock_rate_in_ghz();
return capability;
}
DeviceCapabilityProto DeviceCapability::to_proto() const {
DeviceCapabilityProto proto;
proto.set_single_precision_flops(single_precision_flops);
proto.set_double_precision_flops(double_precision_flops);
proto.set_memory_size_in_bytes(memory_size_in_bytes);
proto.set_clock_rate_in_ghz(clock_rate_in_ghz);
return proto;
}
std::string Device::to_string() const {
std::string str = "{global_id: " + std::to_string(global_id_) + ", ";
str += "local_id: " + std::to_string(local_id_) + ", ";
str += "machine_id: " + std::to_string(machine_id_) + ", ";
str += "type: " + type_ + ", ";
str += "capability: " + capability_.to_string() + "}";
return str;
}
Device Device::from_proto(const DeviceProto &proto) {
Device device;
device.global_id_ = proto.global_id();
device.local_id_ = proto.local_id();
device.machine_id_ = proto.machine_id();
device.type_ = proto.type();
device.capability_ = DeviceCapability::from_proto(proto.capability());
return device;
}
DeviceProto Device::to_proto() const {
DeviceProto proto;
proto.set_global_id(global_id_);
proto.set_local_id(local_id_);
proto.set_machine_id(machine_id_);
proto.set_type(type_);
proto.mutable_capability()->CopyFrom(capability_.to_proto());
return proto;
}
bool operator==(const Device &lhs, const Device &rhs) {
if (lhs.global_id() != rhs.global_id()) {
return false;
}
if (lhs.local_id() != rhs.local_id()) {
return false;
}
if (lhs.machine_id() != rhs.machine_id()) {
return false;
}
if (lhs.type() != rhs.type()) {
return false;
}
return true;
}
std::string LinkCapability::to_string() const {
std::string str;
str += "{bandwidth: " + to_string_with_precision(bandwidth) + ",";
str += "latency: " + to_string_with_precision(latency) + "}";
return str;
}
LinkCapability LinkCapability::from_proto(const LinkCapabilityProto &proto) {
LinkCapability capability;
capability.bandwidth = proto.bandwidth();
capability.latency = proto.latency();
return capability;
}
LinkCapabilityProto LinkCapability::to_proto() const {
LinkCapabilityProto proto;
proto.set_bandwidth(bandwidth);
proto.set_latency(latency);
return proto;
}
std::string Link::to_string() const {
std::string str = "{source_id:" + std::to_string(source_id_) + ",";
str += "target_id:" + std::to_string(target_id_) + ",";
str += "type:" + type_ + ",";
str += "capability:" + capability_.to_string() + "}";
return str;
}
Link Link::from_proto(const LinkProto &proto) {
Link link;
link.source_id_ = proto.source_id();
link.target_id_ = proto.target_id();
link.type_ = proto.type();
link.capability_ = LinkCapability::from_proto(proto.capability());
return link;
}
LinkProto Link::to_proto() const {
LinkProto proto;
proto.set_source_id(source_id_);
proto.set_target_id(target_id_);
proto.set_type(type_);
proto.mutable_capability()->CopyFrom(capability_.to_proto());
return proto;
}
bool operator==(const Link &lhs, const Link &rhs) {
if (lhs.source_id() != rhs.source_id()) {
return false;
}
if (lhs.target_id() != rhs.target_id()) {
return false;
}
if (lhs.type() != rhs.type()) {
return false;
}
return true;
}
bool Machine::contains(int64_t device_id) const {
  return devices_.count(device_id) == 1;
}
void Machine::add_device(const Device &device) {
if (id() == -1) {
set_id(device.machine_id());
} else {
PADDLE_ENFORCE_EQ(device.machine_id(),
id(),
platform::errors::InvalidArgument(
"The machine id [%d] of the device should be equal "
"to this machine id [%d].",
device.machine_id(),
id_));
}
devices_[device.global_id()] = &device;
}
void Machine::add_link(const Link &link) {
PADDLE_ENFORCE_EQ(contains(link.source_id()),
true,
platform::errors::InvalidArgument(
"The source device id of the added link [%s] "
"cannot be found in the device_ids. Please add the "
"source device before adding this link",
std::to_string(link.source_id())));
links_[link.source_id()][link.target_id()] = &link;
}
std::string Machine::to_string() const {
std::string str = "{devices: [";
for (const auto &device : devices_) {
str += device.second->to_string() + ", ";
}
str.replace(str.size() - 2, 2, "], ");
str += "links: [";
for (const auto &item : links_) {
str += "{";
str += "source_id: " + std::to_string(item.first) + ", neighbors: [";
for (const auto &link : item.second) {
str += link.second->to_string() + ", ";
}
str.replace(str.size() - 2, 2, "]}, ");
}
str.replace(str.size() - 4, 4, "]}");
return str;
}
DeviceMesh::DeviceMesh(const std::string &name,
const std::vector<int64_t> &shape,
const std::vector<int64_t> &device_ids,
const std::vector<std::string> &dim_names) {
name_ = name;
shape_ = shape;
int64_t size = this->size();
PADDLE_ENFORCE_EQ(size,
device_ids.size(),
platform::errors::InvalidArgument(
"The size %d of this device mesh must be "
"equal to the size %d of its device ids.",
size,
device_ids.size()));
PADDLE_ENFORCE_EQ(
has_duplicates(device_ids),
false,
platform::errors::InvalidArgument("The device ids [%s] must be unique.",
str_join(device_ids)));
device_ids_ = device_ids;
PADDLE_ENFORCE_EQ(
shape_.size(),
dim_names.size(),
platform::errors::InvalidArgument(
"The size %d of mesh shape must be equal to the size %d "
"of the dimension names.",
shape_.size(),
dim_names.size()));
PADDLE_ENFORCE_EQ(has_duplicates(dim_names),
false,
platform::errors::InvalidArgument(
"The names [%s] of each dimension must be unique.",
str_join(dim_names)));
dim_names_ = dim_names;
}
int64_t DeviceMesh::size() const {
if (shape_.empty()) return 0;
int64_t size = 1;
for (const int64_t dim_size : shape_) size *= dim_size;
return size;
}
bool DeviceMesh::contains(int64_t device_id) const {
  return std::find(std::begin(device_ids_), std::end(device_ids_), device_id) !=
         std::end(device_ids_);
}
void DeviceMesh::add_device(const Device &device) {
PADDLE_ENFORCE_EQ(
contains(device.global_id()),
true,
platform::errors::InvalidArgument(
"The added device id [%s] cannot be found in the device_ids.",
std::to_string(device.global_id())));
// Operator [] will create a new object if it cannot find one.
// So we add the default constructor for Device and Machine
// to make sure the new object can be created.
devices_[device.global_id()] = device;
machines_[device.machine_id()].add_device(devices_[device.global_id()]);
}
void DeviceMesh::add_link(const Link &link) {
PADDLE_ENFORCE_EQ(
contains(link.source_id()),
true,
platform::errors::InvalidArgument("The source id of the added link [%s] "
"cannot be found in the device_ids.",
std::to_string(link.source_id())));
PADDLE_ENFORCE_EQ(
contains(link.target_id()),
true,
platform::errors::InvalidArgument("The source id of the added link [%s] "
"cannot be found in the device_ids.",
std::to_string(link.target_id())));
// Operator [] will create a new object if it cannot find one.
// So we add the default constructor for Link and Machine
// to make sure the new object can be created.
links_[link.source_id()][link.target_id()] = link;
const Device &source_device = devices_[link.source_id()];
machines_[source_device.machine_id()].add_link(
links_[link.source_id()][link.target_id()]);
}
std::string DeviceMesh::to_string() const {
std::string mesh_str = "{name: " + name_ + ", ";
mesh_str += "shape: [" + str_join(shape_) + "], ";
mesh_str += "device_ids: [" + str_join(device_ids_) + "], ";
mesh_str += "dim_names: [" + str_join(dim_names_) + "], ";
mesh_str += "\ndevices: [\n";
for (const auto &device : devices_) {
mesh_str += " " + device.second.to_string() + ",\n";
}
mesh_str.replace(mesh_str.size() - 2, 2, "],");
mesh_str += "\nlinks: [\n";
for (const auto &item : links_) {
mesh_str += " {";
mesh_str += "source_id: " + std::to_string(item.first) + ", neighbors: [";
for (const auto &link : item.second) {
mesh_str += link.second.to_string() + ", ";
}
mesh_str.replace(mesh_str.size() - 2, 2, "]},\n");
}
mesh_str.replace(mesh_str.size() - 4, 4, "]}");
return mesh_str;
}
DeviceMesh DeviceMesh::from_proto(const DeviceMeshProto &proto) {
DeviceMesh mesh;
mesh.name_ = proto.name();
mesh.shape_.resize(proto.shape_size());
for (int64_t i = 0; i < proto.shape_size(); ++i) {
mesh.shape_[i] = proto.shape(i);
}
mesh.device_ids_.resize(proto.device_ids_size());
for (int64_t i = 0; i < proto.device_ids_size(); ++i) {
mesh.device_ids_[i] = proto.device_ids(i);
}
mesh.dim_names_.resize(proto.dim_names_size());
for (int64_t i = 0; i < proto.dim_names_size(); ++i) {
mesh.dim_names_[i] = proto.dim_names(i);
}
for (int64_t i = 0; i < proto.devices_size(); ++i) {
mesh.add_device(Device::from_proto(proto.devices(i)));
}
for (int64_t i = 0; i < proto.links_size(); ++i) {
mesh.add_link(Link::from_proto(proto.links(i)));
}
return mesh;
}
DeviceMeshProto DeviceMesh::to_proto() const {
DeviceMeshProto proto;
proto.set_name(name_);
for (const auto &i : shape_) {
proto.add_shape(i);
}
for (const auto &i : device_ids_) {
proto.add_device_ids(i);
}
for (const auto &i : dim_names_) {
proto.add_dim_names(i);
}
for (const auto &device : devices_) {
proto.mutable_devices()->Add()->CopyFrom(device.second.to_proto());
}
for (const auto &neighbors : links_) {
for (const auto &link : neighbors.second) {
proto.mutable_links()->Add()->CopyFrom(link.second.to_proto());
}
}
return proto;
}
bool operator==(const DeviceMesh &lhs, const DeviceMesh &rhs) {
// Use the unique name to do the fast comparison
if (lhs.name() != rhs.name()) {
return false;
}
return true;
}
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <map>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/auto_parallel/auto_parallel.pb.h"
#include "paddle/fluid/distributed/auto_parallel/utils.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace distributed {
namespace auto_parallel {
struct DeviceCapability {
double single_precision_flops = 0.0;
double double_precision_flops = 0.0;
double memory_size_in_bytes = 0.0;
double clock_rate_in_ghz = 0.0;
// DeviceCapability from_string(const std::string& str);
std::string to_string() const;
static DeviceCapability from_proto(const DeviceCapabilityProto& proto);
DeviceCapabilityProto to_proto() const;
};
inline std::ostream& operator<<(std::ostream& os, const DeviceCapability& obj) {
os << obj.to_string();
return os;
}
class Device {
public:
Device() = default;
Device(int64_t global_id,
int64_t local_id,
int64_t machine_id,
const std::string& type)
: global_id_(global_id),
local_id_(local_id),
machine_id_(machine_id),
type_(type) {}
int64_t global_id() const { return global_id_; }
int64_t local_id() const { return local_id_; }
int64_t machine_id() const { return machine_id_; }
const std::string& type() const { return type_; }
const DeviceCapability& capability() const { return capability_; }
void set_capability(const DeviceCapability& capability) {
capability_ = capability;
}
// Device from_string(const std::string& mesh_str);
std::string to_string() const;
static Device from_proto(const DeviceProto& proto);
DeviceProto to_proto() const;
private:
int64_t global_id_;
int64_t local_id_;
int64_t machine_id_;
std::string type_;
DeviceCapability capability_;
};
inline std::ostream& operator<<(std::ostream& os, const Device& obj) {
os << obj.to_string();
return os;
}
bool operator==(const Device& lhs, const Device& rhs);
inline bool operator!=(const Device& lhs, const Device& rhs) {
return !operator==(lhs, rhs);
}
struct LinkCapability {
double bandwidth = 0.0; // Bytes/s
double latency = 0.0;
// LinkCapability from_string(const std::string& str);
std::string to_string() const;
static LinkCapability from_proto(const LinkCapabilityProto& proto);
LinkCapabilityProto to_proto() const;
};
inline std::ostream& operator<<(std::ostream& os, const LinkCapability& obj) {
os << obj.to_string();
return os;
}
class Link {
public:
Link() = default;
Link(int64_t source_id, int64_t target_id, const std::string& type)
: source_id_(source_id), target_id_(target_id), type_(type) {}
int64_t source_id() const { return source_id_; }
int64_t target_id() const { return target_id_; }
const std::string& type() const { return type_; }
const LinkCapability& capability() const { return capability_; }
void set_capability(const LinkCapability& capability) {
capability_ = capability;
}
// Link from_string(const std::string& str);
std::string to_string() const;
static Link from_proto(const LinkProto& proto);
LinkProto to_proto() const;
private:
int64_t source_id_;
int64_t target_id_;
std::string type_;
LinkCapability capability_;
};
inline std::ostream& operator<<(std::ostream& os, const Link& obj) {
os << obj.to_string();
return os;
}
bool operator==(const Link& lhs, const Link& rhs);
inline bool operator!=(const Link& lhs, const Link& rhs) {
return !operator==(lhs, rhs);
}
class Machine {
public:
Machine() = default;
explicit Machine(int64_t id) : id_(id) {}
int64_t id() const { return id_; }
void set_id(int64_t id) { id_ = id; }
bool contains(int64_t device_id) const;
void add_device(const Device& device);
void add_link(const Link& link);
// Machine from_string(const std::string& str);
std::string to_string() const;
private:
int64_t id_ = -1;
std::unordered_map<int64_t, const Device*> devices_;
std::unordered_map<int64_t, std::unordered_map<int64_t, const Link*>> links_;
};
class DeviceMesh {
public:
DeviceMesh() = default;
DeviceMesh(const std::string& name,
const std::vector<int64_t>& shape,
const std::vector<int64_t>& device_ids,
const std::vector<std::string>& dim_names);
const std::string& name() const { return name_; }
void set_name(const std::string& name) { name_ = name; }
const std::vector<int64_t>& shape() const { return shape_; }
const std::vector<int64_t>& device_ids() const { return device_ids_; }
const std::vector<std::string>& dim_names() const { return dim_names_; }
std::string device_type() const {
if (empty()) return std::string();
return std::begin(devices_)->second.type();
}
const std::unordered_map<int64_t, Device>& devices() const {
return devices_;
}
const std::unordered_map<int64_t, std::unordered_map<int64_t, Link>>& links()
const {
return links_;
}
const Device& device(int64_t global_id) const {
return devices_.at(global_id);
}
const Link& link(int64_t source_id, int64_t target_id) const {
return links_.at(source_id).at(target_id);
}
int64_t size() const;
int64_t ndim() const { return shape_.size(); }
int64_t dim_size(int64_t dim) const {
int64_t cdim = canonical_dim(dim, shape_.size());
return shape_[cdim];
}
int64_t dim_size(const std::string& dim_name) const {
for (std::size_t i = 0; i < dim_names_.size(); ++i) {
if (dim_names_[i] == dim_name) {
return shape_[i];
}
}
PADDLE_THROW(platform::errors::InvalidArgument(
"Cannot find the dimension of %s in this device mesh.", dim_name));
}
bool empty() const { return (shape_.empty() || device_ids_.empty()); }
bool contains(int64_t device_id) const;
void add_device(const Device& device);
void add_link(const Link& link);
// DeviceMesh from_string(const std::string& mesh_str);
std::string to_string() const;
static DeviceMesh from_proto(const DeviceMeshProto& proto);
DeviceMeshProto to_proto() const;
private:
std::string name_;
std::vector<int64_t> shape_;
std::vector<int64_t> device_ids_;
std::vector<std::string> dim_names_;
std::unordered_map<int64_t, Device> devices_;
std::unordered_map<int64_t, std::unordered_map<int64_t, Link>> links_;
std::unordered_map<int64_t, Machine> machines_;
};
inline std::ostream& operator<<(std::ostream& os, const DeviceMesh& obj) {
os << obj.to_string();
return os;
}
bool operator==(const DeviceMesh& lhs, const DeviceMesh& rhs);
inline bool operator!=(const DeviceMesh& lhs, const DeviceMesh& rhs) {
return !operator==(lhs, rhs);
}
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/auto_parallel/device_mesh.h"
#include <iostream>
#include <sstream>
#include "gtest/gtest.h"
namespace paddle {
namespace distributed {
namespace auto_parallel {
TEST(DeviceMesh, Ctor) {
std::vector<int64_t> shape = {2, 3};
std::vector<int64_t> device_ids = {0, 1, 2, 3, 4, 5};
std::vector<std::string> dim_names = {"x", "y"};
std::string device_type = "GPU";
int64_t size = shape[0] * shape[1];
DeviceMesh device_mesh("mesh", shape, device_ids, dim_names);
for (int64_t i = 0; i < shape[0]; ++i) {
for (int64_t j = 0; j < shape[1]; ++j) {
int64_t global_id = i * shape[1] + j;
int64_t local_id = j;
int64_t machine_id = i;
device_mesh.add_device(
Device(global_id, local_id, machine_id, device_type));
}
}
for (int64_t i = 0; i < size; ++i) {
for (int64_t j = 0; j < size; ++j) {
device_mesh.add_link(Link(i, j, "NVL"));
}
}
EXPECT_EQ(device_mesh.name(), "mesh");
EXPECT_EQ(device_mesh.shape(), shape);
EXPECT_EQ(device_mesh.device_ids(), device_ids);
EXPECT_EQ(device_mesh.dim_names()[0], "x");
EXPECT_EQ(device_mesh.dim_names()[1], "y");
EXPECT_EQ(device_mesh.device_type(), device_type);
EXPECT_EQ(device_mesh.size(), size);
EXPECT_EQ(device_mesh.ndim(), static_cast<int64_t>(shape.size()));
EXPECT_EQ(device_mesh.dim_size(0), shape[0]);
EXPECT_EQ(device_mesh.dim_size(-1), shape[1]);
EXPECT_EQ(device_mesh.dim_size("x"), shape[0]);
EXPECT_EQ(device_mesh.dim_size("y"), shape[1]);
EXPECT_EQ(device_mesh.empty(), false);
EXPECT_EQ(device_mesh.contains(0), true);
EXPECT_EQ(device_mesh.contains(6), false);
EXPECT_EQ(device_mesh.device(3).global_id(), 3);
EXPECT_EQ(device_mesh.device(3).local_id(), 0);
EXPECT_EQ(device_mesh.device(3).machine_id(), 1);
EXPECT_EQ(device_mesh.device(3).type(), "GPU");
EXPECT_EQ(device_mesh.link(3, 4).source_id(), 3);
EXPECT_EQ(device_mesh.link(3, 4).target_id(), 4);
EXPECT_EQ(device_mesh.link(3, 4).type(), "NVL");
for (int64_t i = 0; i < shape[0]; ++i) {
for (int64_t j = 0; j < shape[1]; ++j) {
int64_t global_id = i * shape[1] + j;
int64_t local_id = j;
int64_t machine_id = i;
auto device = device_mesh.devices().at(global_id);
EXPECT_EQ(device, Device(global_id, local_id, machine_id, device_type));
}
}
for (int64_t i = 0; i < size; ++i) {
for (int64_t j = 0; j < size; ++j) {
EXPECT_EQ(device_mesh.links().at(i).at(j), Link(i, j, "NVL"));
}
}
std::stringstream sstream;
sstream << device_mesh;
EXPECT_EQ(sstream.str(), device_mesh.to_string());
auto proto = device_mesh.to_proto();
DeviceMesh new_device_mesh = DeviceMesh::from_proto(proto);
EXPECT_EQ(device_mesh, new_device_mesh);
}
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace distributed {
namespace auto_parallel {
// struct Indent {
// Indent(int &level) : level(level) { ++level; }
// ~Indent() { --level; }
// int &level;
// };
// inline std::string str_indent(std::string& str, cur_indent) {
// string spaces(cur_indent, " ");
// return str + std::string(cur_indent, " ");
// }
template <class T>
bool has_duplicates(const std::vector<T>& vec) {
std::unordered_map<T, int> map;
for (const auto& i : vec) {
++map[i];
if (map[i] > 1) return true;
}
return false;
}
inline int64_t canonical_dim(int dim, int ndim) {
PADDLE_ENFORCE_EQ(
dim >= -ndim && dim < ndim,
true,
platform::errors::InvalidArgument(
"Dimension %d is outside of [-%d, %d).", dim, ndim, ndim));
if (dim < 0) {
return dim + ndim;
}
return dim;
}
// Refer to https://stackoverflow.com/a/5289170
template <typename Range, typename Value = typename Range::value_type>
std::string str_join(Range const& elements,
const std::string& delimiter = ",") {
std::ostringstream os;
auto b = std::begin(elements), e = std::end(elements);
if (b != e) {
std::copy(b, prev(e), std::ostream_iterator<Value>(os, delimiter.c_str()));
b = prev(e);
}
if (b != e) {
os << *b;
}
return os.str();
}
// Refer to https://stackoverflow.com/a/46931770
inline std::vector<std::string> str_split(std::string const& input,
const std::string& delimiter = ",") {
size_t pos_start = 0, pos_end, delim_len = delimiter.length();
std::string token;
std::vector<std::string> output;
while ((pos_end = input.find(delimiter, pos_start)) != std::string::npos) {
token = input.substr(pos_start, pos_end - pos_start);
pos_start = pos_end + delim_len;
output.push_back(token);
}
output.push_back(input.substr(pos_start));
return output;
}
// Refer to https://stackoverflow.com/a/29200671/2358969
template <typename T>
std::string to_string_with_precision(const T a_value, const int n = 2) {
std::ostringstream out;
out.precision(n);
out << std::fixed << a_value;
return out.str();
}
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
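A short sketch of how the helpers in utils.h behave, following the definitions above; the main() wrapper and printed values are illustrative only.

#include <iostream>
#include <vector>
#include "paddle/fluid/distributed/auto_parallel/utils.h"

int main() {
  using namespace paddle::distributed::auto_parallel;

  std::vector<int64_t> ids = {0, 1, 2, 3};
  std::cout << has_duplicates(ids) << "\n";                 // 0 (false): all ids unique
  std::cout << canonical_dim(-1, 2) << "\n";                // 1: negative dims wrap around
  std::cout << str_join(ids, ", ") << "\n";                 // "0, 1, 2, 3"
  std::cout << str_split("a,b,c").size() << "\n";           // 3 tokens with the default ","
  std::cout << to_string_with_precision(3.14159) << "\n";   // "3.14" (2 decimal places)
  return 0;
}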