Unverified commit be1152a4, authored by zhenhailiu and committed by GitHub

[Migrate dist attr to phi] Dist attr (#53848)

* merge code from forsish

* polish

* paddle/fluid/pybind/auto_parallel_py.cc

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish
Parent 4af0f140
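The change below moves TensorDistAttr into phi: it is now constructed from a plain shape vector instead of a fluid VarDesc, and verify() takes the shape as well; fluid-side callers first extract the shape with get_tensor_shape. A minimal hedged sketch of the new-style usage (illustrative values, not code from this commit):

#include <cstdint>
#include <vector>
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"

using phi::distributed::auto_parallel::TensorDistAttr;

// Before this change the attribute was built as TensorDistAttr(var_desc);
// now only the tensor shape is needed, so phi no longer depends on VarDesc.
void construct_example() {
  std::vector<int64_t> shape = {1000, 784};
  TensorDistAttr dist_attr(shape);      // dims_mapping defaults to {-1, -1}
  dist_attr.set_dims_mapping({0, -1});  // shard dim 0 over mesh axis 0
  bool ok = dist_attr.verify(shape);    // verify now takes a shape, not a VarDesc*
  (void)ok;
}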
-proto_library(auto_parallel_proto SRCS auto_parallel.proto)
-cc_library(
-  device_mesh
-  SRCS device_mesh.cc
-  DEPS auto_parallel_proto phi_enforce)
 cc_library(
-  process_mesh
-  SRCS process_mesh.cc
-  DEPS auto_parallel_proto phi_enforce)
-cc_library(
-  dist_attr
+  op_dist_attr
   SRCS dist_attr.cc
-  DEPS process_mesh auto_parallel_proto proto_desc phi_enforce)
+  DEPS dist_attr process_mesh dist_mapper auto_parallel_proto proto_desc
+       phi_enforce)
-cc_library(
-  dist_mapper
-  SRCS dist_mapper.cc
-  DEPS device_mesh auto_parallel_proto phi_enforce)
-cc_library(auto_parallel DEPS device_mesh process_mesh dist_attr dist_mapper)
+add_subdirectory(test)
@@ -26,10 +26,9 @@ namespace paddle {
 namespace distributed {
 namespace auto_parallel {
-std::vector<std::string> TensorDistAttr::fields_{
-    "process_mesh", "dims_mapping", "batch_dim", "dynamic_dims"};
+using phi::distributed::auto_parallel::str_join;
-static inline std::vector<int64_t> get_tensor_shape(const VarDesc* tensor) {
+std::vector<int64_t> get_tensor_shape(const VarDesc* tensor) {
   if (tensor == nullptr) return std::vector<int64_t>();
   switch (tensor->GetType()) {
     case framework::proto::VarType::READER:
@@ -43,251 +42,6 @@ static inline std::vector<int64_t> get_tensor_shape(const VarDesc* tensor) {
   }
 }
TensorDistAttr::TensorDistAttr(const VarDesc& tensor) {
VLOG(4) << "[TensorDistAttr constructor] tensor name: " << tensor.Name();
std::vector<int64_t> tensor_shape = get_tensor_shape(&tensor);
set_default_dims_mapping(tensor_shape);
set_default_dynamic_dims(tensor_shape);
}
TensorDistAttr::TensorDistAttr(const TensorDistAttr& dist_attr) {
copy_from(dist_attr);
}
TensorDistAttr& TensorDistAttr::operator=(const TensorDistAttr& dist_attr) {
if (this == &dist_attr) return *this;
TensorDistAttr tmp(dist_attr);
std::swap(this->process_mesh_, tmp.process_mesh_);
std::swap(this->dims_mapping_, tmp.dims_mapping_);
std::swap(this->batch_dim_, tmp.batch_dim_);
std::swap(this->dynamic_dims_, tmp.dynamic_dims_);
std::swap(this->annotated_, tmp.annotated_);
return *this;
}
void TensorDistAttr::copy_from(const TensorDistAttr& dist_attr) {
set_process_mesh(dist_attr.process_mesh());
set_dims_mapping(dist_attr.dims_mapping());
set_batch_dim(dist_attr.batch_dim());
set_dynamic_dims(dist_attr.dynamic_dims());
set_annotated(dist_attr.annotated());
}
void TensorDistAttr::set_process_mesh(const ProcessMesh& process_mesh) {
process_mesh_ = process_mesh;
}
void TensorDistAttr::set_dims_mapping(
const std::vector<int64_t>& dims_mapping) {
dims_mapping_ = dims_mapping;
}
void TensorDistAttr::set_batch_dim(int64_t batch_dim) {
batch_dim_ = batch_dim;
}
void TensorDistAttr::set_dynamic_dims(const std::vector<bool>& dynamic_dims) {
dynamic_dims_ = dynamic_dims;
}
void TensorDistAttr::set_annotated(
const std::map<std::string, bool>& annotated) {
annotated_ = annotated;
}
void TensorDistAttr::set_default_dims_mapping(
const std::vector<int64_t>& tensor_shape) {
if (tensor_shape.size() != 0) {
dims_mapping_ = std::vector<int64_t>(tensor_shape.size(), -1);
}
}
void TensorDistAttr::set_default_dynamic_dims(
const std::vector<int64_t>& tensor_shape) {
if (tensor_shape.size() != 0) {
dynamic_dims_ = std::vector<bool>(tensor_shape.size(), false);
}
}
void TensorDistAttr::mark_annotated(const std::string& name) {
auto result = std::find(std::begin(fields_), std::end(fields_), name);
if (result != std::end(fields_)) {
annotated_[name] = true;
}
}
bool TensorDistAttr::verify_process_mesh(
const ProcessMesh& process_mesh) const {
VLOG(4) << "[TensorDistAttr verify_process_mesh] "
<< process_mesh.to_string();
if (!process_mesh_.empty()) {
for (int64_t dim_mapping : dims_mapping_) {
if (dim_mapping >= process_mesh_.ndim()) {
return false;
}
}
}
return true;
}
bool TensorDistAttr::verify_dims_mapping(
const std::vector<int64_t>& dims_mapping,
const std::vector<int64_t>& tensor_shape) const {
VLOG(4) << "[TensorDistAttr verify_dims_mapping] " << str_join(dims_mapping);
if (dims_mapping.size() != tensor_shape.size()) {
return false;
}
std::unordered_map<int64_t, int64_t> map;
if (!process_mesh_.empty()) {
for (int64_t i : dims_mapping) {
if (i < -1 || i >= process_mesh_.ndim()) {
return false;
}
++map[i];
if (i != -1 && map[i] > 1) {
return false;
}
}
} else {
for (int64_t i : dims_mapping) {
++map[i];
if (i != -1 && map[i] > 1) {
return false;
}
}
}
return true;
}
bool TensorDistAttr::verify_batch_dim(
int64_t dim, const std::vector<int64_t>& tensor_shape) const {
VLOG(4) << "[TensorDistAttr verify_batch_dim] " << dim;
int64_t ndim = tensor_shape.size();
if (ndim > 0) {
if (dim < 0) {
dim = dim + ndim;
}
if (dim < 0 || dim >= ndim) {
return false;
}
}
return true;
}
bool TensorDistAttr::verify_dynamic_dims(
const std::vector<bool>& dynamic_dims,
const std::vector<int64_t>& tensor_shape) const {
VLOG(4) << "[TensorDistAttr verify_dynamic_dims] " << str_join(dynamic_dims);
if (dynamic_dims.size() > 0 && dynamic_dims.size() != tensor_shape.size()) {
return false;
}
return true;
}
bool TensorDistAttr::verify_annotated(
const std::map<std::string, bool>& annotated) const {
VLOG(4) << "[TensorDistAttr verify_annotated] " << str_join(annotated);
for (const auto& item : annotated) {
auto result = std::find(std::begin(fields_), std::end(fields_), item.first);
if (result == std::end(fields_)) {
return false;
}
}
return true;
}
bool TensorDistAttr::verify(const VarDesc* tensor) const {
auto tensor_shape = get_tensor_shape(tensor);
if (!verify_process_mesh(process_mesh_)) {
return false;
}
if (!verify_dims_mapping(dims_mapping_, tensor_shape)) {
return false;
}
if (!verify_batch_dim(batch_dim_, tensor_shape)) {
return false;
}
if (!verify_dynamic_dims(dynamic_dims_, tensor_shape)) {
return false;
}
if (!verify_annotated(annotated_)) {
return false;
}
return true;
}
std::string TensorDistAttr::to_string() const {
std::string dist_str;
dist_str += "{process_mesh: " + process_mesh_.to_string() + ", ";
dist_str += "dims_mappings: [" + str_join(dims_mapping_) + "], ";
dist_str += "batch_dim: " + std::to_string(batch_dim_) + ", ";
dist_str += "dynamic_dims: [" + str_join(dynamic_dims_) + "], ";
dist_str += "annotated: [" + str_join(annotated_) + "]}";
return dist_str;
}
void TensorDistAttr::from_proto(const TensorDistAttrProto& proto) {
process_mesh_ = ProcessMesh::from_proto(proto.process_mesh());
dims_mapping_.resize(proto.dims_mapping_size());
for (int64_t i = 0; i < proto.dims_mapping_size(); ++i) {
dims_mapping_[i] = proto.dims_mapping(i);
}
batch_dim_ = proto.batch_dim();
dynamic_dims_.resize(proto.dynamic_dims_size());
for (int64_t i = 0; i < proto.dynamic_dims_size(); ++i) {
dynamic_dims_[i] = proto.dynamic_dims(i);
}
}
TensorDistAttrProto TensorDistAttr::to_proto() const {
TensorDistAttrProto proto;
proto.mutable_process_mesh()->CopyFrom(process_mesh_.to_proto());
for (const auto& i : dims_mapping_) {
proto.add_dims_mapping(i);
}
proto.set_batch_dim(batch_dim_);
for (const auto& i : dynamic_dims_) {
proto.add_dynamic_dims(i);
}
return proto;
}
std::string TensorDistAttr::serialize_to_string() {
std::string data;
auto proto = to_proto();
proto.SerializeToString(&data);
PADDLE_ENFORCE_EQ(to_proto().SerializeToString(&data),
true,
platform::errors::InvalidArgument(
"Failed to serialize tensor dist attr to string."));
return data;
}
void TensorDistAttr::parse_from_string(const std::string& data) {
TensorDistAttrProto proto;
PADDLE_ENFORCE_EQ(proto.ParseFromString(data),
true,
platform::errors::InvalidArgument(
"Failed to parse tensor dist attr from string."));
from_proto(proto);
}
bool operator==(const TensorDistAttr& lhs, const TensorDistAttr& rhs) {
if (lhs.process_mesh() != rhs.process_mesh()) {
return false;
}
if (lhs.dims_mapping() != rhs.dims_mapping()) {
return false;
}
if (lhs.batch_dim() != rhs.batch_dim()) {
return false;
}
if (lhs.dynamic_dims() != rhs.dynamic_dims()) {
return false;
}
return true;
}
 std::vector<std::string> OperatorDistAttr::fields_{"process_mesh",
                                                    "impl_type",
                                                    "impl_idx",
@@ -335,7 +89,7 @@ void OperatorDistAttr::initialize(const OpDesc* op) {
     if (input == nullptr || op->Type() == "create_py_reader") {
       input_dist_attrs_[name] = TensorDistAttr();
     } else {
-      input_dist_attrs_[name] = TensorDistAttr(*input);
+      input_dist_attrs_[name] = TensorDistAttr(get_tensor_shape(input));
     }
   }
   for (std::string name : op->OutputArgumentNames()) {
@@ -344,7 +98,7 @@ void OperatorDistAttr::initialize(const OpDesc* op) {
     if (output == nullptr) {
       output_dist_attrs_[name] = TensorDistAttr();
     } else {
-      output_dist_attrs_[name] = TensorDistAttr(*output);
+      output_dist_attrs_[name] = TensorDistAttr(get_tensor_shape(output));
     }
   }
   op_type_ = op->Type();
@@ -465,7 +219,8 @@ bool OperatorDistAttr::verify_input_dist_attr(const std::string& name,
                                               const VarDesc* tensor) const {
   VLOG(4) << "[OperatorDistAttr verify_input_dist_attr] " << name << " "
           << dist_attr.to_string();
-  if (!dist_attr.verify(tensor)) {
+  auto tensor_shape = get_tensor_shape(tensor);
+  if (!dist_attr.verify(tensor_shape)) {
     return false;
   }
   if (tensor != nullptr) {
@@ -484,7 +239,8 @@ bool OperatorDistAttr::verify_output_dist_attr(const std::string& name,
                                                const VarDesc* tensor) const {
   VLOG(4) << "[OperatorDistAttr verify_output_dist_attr] " << name << " "
           << dist_attr.to_string();
-  if (!dist_attr.verify(tensor)) {
+  auto tensor_shape = get_tensor_shape(tensor);
+  if (!dist_attr.verify(tensor_shape)) {
     return false;
   }
   if (tensor != nullptr) {
...
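For fluid code that still holds a VarDesc*, the adaptation shown in this hunk is mechanical: reduce the VarDesc to a shape, then hand the shape to the phi-side attribute. A small hedged sketch of that call pattern (assuming the headers touched in this commit):

#include "paddle/fluid/distributed/auto_parallel/dist_attr.h"

// Sketch of the pattern used by verify_input_dist_attr / verify_output_dist_attr.
bool verify_with_var_desc(
    const paddle::framework::VarDesc* tensor,
    const phi::distributed::auto_parallel::TensorDistAttr& dist_attr) {
  auto tensor_shape =
      paddle::distributed::auto_parallel::get_tensor_shape(tensor);
  return dist_attr.verify(tensor_shape);
}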
@@ -21,10 +21,11 @@ limitations under the License. */
 #include <string>
 #include <vector>
-#include "paddle/fluid/distributed/auto_parallel/auto_parallel.pb.h"
-#include "paddle/fluid/distributed/auto_parallel/process_mesh.h"
-#include "paddle/fluid/distributed/auto_parallel/utils.h"
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/phi/core/distributed/auto_parallel/auto_parallel.pb.h"
+#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
+#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
+#include "paddle/phi/core/distributed/auto_parallel/utils.h"
 namespace paddle {
@@ -46,97 +47,13 @@ using framework::OpDesc;
 using framework::ProgramDesc;
 using framework::VarDesc;
+using phi::distributed::auto_parallel::OperatorDistAttrProto;
+using phi::distributed::auto_parallel::ProcessMesh;
+using phi::distributed::auto_parallel::TensorDistAttr;
+
+constexpr const char* kDefault = "default";
+
+std::vector<int64_t> get_tensor_shape(const VarDesc* tensor);
+
-constexpr const char* kDefault = "default";
-
-class TensorDistAttr {
public:
TensorDistAttr() = default;
explicit TensorDistAttr(const VarDesc& tensor);
TensorDistAttr(const TensorDistAttr& tensor);
TensorDistAttr& operator=(const TensorDistAttr& dist_attr);
void copy_from(const TensorDistAttr& dist_attr);
const ProcessMesh& process_mesh() const { return process_mesh_; }
void set_process_mesh(const ProcessMesh& process_mesh);
const std::vector<int64_t>& dims_mapping() const { return dims_mapping_; }
void set_dims_mapping(const std::vector<int64_t>& dims_mapping);
void set_default_dims_mapping(const std::vector<int64_t>& tensor_shape);
int64_t batch_dim() const { return batch_dim_; }
void set_batch_dim(int64_t batch_dim);
const std::vector<bool>& dynamic_dims() const { return dynamic_dims_; }
void set_dynamic_dims(const std::vector<bool>& dynamic_dims);
void set_default_dynamic_dims(const std::vector<int64_t>& tensor_shape);
const std::map<std::string, bool>& annotated() const { return annotated_; }
void set_annotated(const std::map<std::string, bool>& annotated);
bool is_annotated(const std::string& name) const {
return annotated_.count(name) == 1 && annotated_.at(name) == true;
}
void mark_annotated(const std::string& name);
void clear_annotated() { annotated_.clear(); }
  bool verify_process_mesh(const ProcessMesh& process_mesh) const;
bool verify_dims_mapping(const std::vector<int64_t>& dims_mapping,
const std::vector<int64_t>& tensor_shape) const;
bool verify_batch_dim(int64_t dim,
const std::vector<int64_t>& tensor_shape) const;
bool verify_dynamic_dims(const std::vector<bool>& dynamic_dims,
const std::vector<int64_t>& tensor_shape) const;
bool verify_annotated(const std::map<std::string, bool>& annotated) const;
bool verify(const VarDesc* tensor = nullptr) const;
// TensorDistAttr from_string(const std::string& dist_str);
std::string to_string() const;
void from_proto(const TensorDistAttrProto& proto);
TensorDistAttrProto to_proto() const;
std::string serialize_to_string();
void parse_from_string(const std::string& data);
private:
static std::vector<std::string> fields_;
ProcessMesh process_mesh_;
std::vector<int64_t> dims_mapping_;
int64_t batch_dim_{0};
std::vector<bool> dynamic_dims_;
std::map<std::string, bool> annotated_;
};
inline std::ostream& operator<<(std::ostream& os, const TensorDistAttr& obj) {
os << obj.to_string();
return os;
}
bool operator==(const TensorDistAttr& lhs, const TensorDistAttr& rhs);
inline bool operator!=(const TensorDistAttr& lhs, const TensorDistAttr& rhs) {
return !operator==(lhs, rhs);
}
class OperatorDistAttr {
 public:
...
@@ -11,7 +11,7 @@ cc_test(
 cc_test(
   dist_attr_test
   SRCS dist_attr_test.cc
-  DEPS dist_attr)
+  DEPS dist_attr proto_desc)
 cc_test(
   dist_mapper_test
...
@@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "paddle/fluid/distributed/auto_parallel/device_mesh.h"
+#include "paddle/phi/core/distributed/auto_parallel/device_mesh.h"
 #include <iostream>
 #include <sstream>
 #include "gtest/gtest.h"
-namespace paddle {
+namespace phi {
 namespace distributed {
 namespace auto_parallel {
@@ -90,4 +90,4 @@ TEST(DeviceMesh, Ctor) {
 }  // namespace auto_parallel
 }  // namespace distributed
-}  // namespace paddle
+}  // namespace phi
@@ -17,29 +17,37 @@ limitations under the License. */
 #include "glog/logging.h"
 #include "gtest/gtest.h"
-#include "paddle/fluid/distributed/auto_parallel/dist_attr.h"
 #include "paddle/fluid/framework/block_desc.h"
 #include "paddle/fluid/framework/op_desc.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/var_desc.h"
+#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
-namespace paddle {
+namespace phi {
 namespace distributed {
 namespace auto_parallel {
+using paddle::framework::BlockDesc;
+using paddle::framework::OpDesc;
+using paddle::framework::ProgramDesc;
+using paddle::framework::VarDesc;
+using paddle::distributed::auto_parallel::get_tensor_shape;
+using paddle::distributed::auto_parallel::OperatorDistAttr;
 TEST(DistAttr, ctor) {
   ProgramDesc program;
   auto* global_block = program.MutableBlock(0);
   auto* x = global_block->Var("X");
-  x->SetType(framework::proto::VarType::LOD_TENSOR);
+  x->SetType(paddle::framework::proto::VarType::LOD_TENSOR);
   x->SetLoDLevel(0);
-  x->SetDataType(framework::proto::VarType::FP32);
+  x->SetDataType(paddle::framework::proto::VarType::FP32);
   x->SetShape({1000, 784});
   auto* y = global_block->Var("Y");
-  y->SetType(framework::proto::VarType::LOD_TENSOR);
+  y->SetType(paddle::framework::proto::VarType::LOD_TENSOR);
   y->SetLoDLevel(0);
-  y->SetDataType(framework::proto::VarType::FP32);
+  y->SetDataType(paddle::framework::proto::VarType::FP32);
   y->SetShape({784, 100});
   auto* op = global_block->AppendOp();
@@ -48,10 +56,15 @@ TEST(DistAttr, ctor) {
   op->SetInput("Y", {y->Name()});
   auto* out = global_block->Var("Out");
-  out->SetType(framework::proto::VarType::LOD_TENSOR);
+  out->SetType(paddle::framework::proto::VarType::LOD_TENSOR);
   out->SetShape({1000, 100});
   op->SetOutput("Out", {out->Name()});
+  auto get_dist_attr = [](const VarDesc* var_desc) {
+    auto shape = get_tensor_shape(var_desc);
+    return TensorDistAttr(shape);
+  };
   std::vector<int64_t> shape = {2, 4};
   std::vector<int64_t> process_ids = {0, 1, 2, 3, 4, 5, 6, 7};
   std::vector<std::string> dim_names = {"x", "y"};
@@ -62,7 +75,9 @@ TEST(DistAttr, ctor) {
   std::vector<std::string> dim_names2 = {"a", "b"};
   ProcessMesh process_mesh2(shape2, process_ids2, dim_names2);
-  TensorDistAttr x_dist_attr(*x), y_dist_attr(*y), out_dist_attr(*out);
+  auto x_dist_attr = get_dist_attr(x);
+  auto y_dist_attr = get_dist_attr(y);
+  auto out_dist_attr = get_dist_attr(out);
   x_dist_attr.set_process_mesh(process_mesh);
   x_dist_attr.set_dims_mapping(std::vector<int64_t>({0, -1}));
   x_dist_attr.set_batch_dim(0);
@@ -75,7 +90,7 @@ TEST(DistAttr, ctor) {
   EXPECT_EQ(x_dist_attr.dynamic_dims(), std::vector<bool>({true, false}));
   EXPECT_EQ(x_dist_attr.is_annotated("process_mesh"), true);
   EXPECT_EQ(x_dist_attr.is_annotated("dims_mapping"), true);
-  EXPECT_EQ(x_dist_attr.verify(x), true);
+  EXPECT_EQ(x_dist_attr.verify(get_tensor_shape(x)), true);
   x_dist_attr.clear_annotated();
   EXPECT_EQ(x_dist_attr.annotated().empty(), true);
@@ -83,7 +98,7 @@ TEST(DistAttr, ctor) {
   x_sstream << x_dist_attr;
   EXPECT_EQ(x_sstream.str(), x_dist_attr.to_string());
   auto x_proto = x_dist_attr.to_proto();
-  TensorDistAttr new_x_dist_attr(*x);
+  TensorDistAttr new_x_dist_attr = get_dist_attr(x);
   new_x_dist_attr.from_proto(x_proto);
   EXPECT_EQ(x_dist_attr, new_x_dist_attr);
@@ -95,11 +110,11 @@ TEST(DistAttr, ctor) {
   x_dist_attr.mark_annotated("dynamic_dims");
   EXPECT_EQ(y_dist_attr.process_mesh(), process_mesh);
   EXPECT_EQ(y_dist_attr.dims_mapping(), std::vector<int64_t>({-1, 0}));
-  EXPECT_EQ(y_dist_attr.batch_dim(), 1);
+  EXPECT_EQ(y_dist_attr.batch_dim(), -1);
   EXPECT_EQ(y_dist_attr.dynamic_dims(), std::vector<bool>({false, true}));
   EXPECT_EQ(x_dist_attr.is_annotated("batch_dim"), true);
   EXPECT_EQ(x_dist_attr.is_annotated("dynamic_dims"), true);
-  EXPECT_EQ(x_dist_attr.verify(y), true);
+  EXPECT_EQ(x_dist_attr.verify(get_tensor_shape(y)), true);
   out_dist_attr.set_process_mesh(process_mesh);
   out_dist_attr.set_dims_mapping(std::vector<int64_t>({0, 1}));
@@ -109,11 +124,11 @@ TEST(DistAttr, ctor) {
   EXPECT_EQ(out_dist_attr.dims_mapping(), std::vector<int64_t>({0, 1}));
   EXPECT_EQ(out_dist_attr.batch_dim(), 1);
   EXPECT_EQ(out_dist_attr.dynamic_dims(), std::vector<bool>({false, false}));
-  EXPECT_EQ(out_dist_attr.verify(out), true);
+  EXPECT_EQ(out_dist_attr.verify(get_tensor_shape(out)), true);
   OperatorDistAttr mul_dist_attr(*op);
   EXPECT_EQ(mul_dist_attr.impl_type(), kDefault);
-  EXPECT_EQ(mul_dist_attr.impl_idx(), -1);
+  EXPECT_EQ(mul_dist_attr.impl_idx(), 0);
   EXPECT_EQ(mul_dist_attr.is_recompute(), false);
   EXPECT_EQ(mul_dist_attr.is_annotated("process_mesh"), false);
   EXPECT_EQ(mul_dist_attr.is_annotated("impl_type"), false);
@@ -157,4 +172,4 @@ TEST(DistAttr, ctor) {
 }  // namespace auto_parallel
 }  // namespace distributed
-}  // namespace paddle
+}  // namespace phi
@@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "paddle/fluid/distributed/auto_parallel/dist_mapper.h"
+#include "paddle/phi/core/distributed/auto_parallel/dist_mapper.h"
 #include <map>
 #include <sstream>
 #include "gtest/gtest.h"
-namespace paddle {
+namespace phi {
 namespace distributed {
 namespace auto_parallel {
@@ -69,4 +69,4 @@ TEST(DistributedMapper, Ctor) {
 }  // namespace auto_parallel
 }  // namespace distributed
-}  // namespace paddle
+}  // namespace phi
@@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "paddle/fluid/distributed/auto_parallel/process_mesh.h"
+#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
 #include <iostream>
 #include <sstream>
 #include "gtest/gtest.h"
-namespace paddle {
+namespace phi {
 namespace distributed {
 namespace auto_parallel {
@@ -50,4 +50,4 @@ TEST(ProcessMesh, Ctor) {
 }  // namespace auto_parallel
 }  // namespace distributed
-}  // namespace paddle
+}  // namespace phi
@@ -333,7 +333,7 @@ endif()
 cc_library(
   data_layout_transform
   SRCS data_layout_transform.cc
-  DEPS tensor math_function)
+  DEPS tensor math_function phi_data_layout_transform)
 cc_test(
   data_layout_transform_test
   SRCS data_layout_transform_test.cc
@@ -348,7 +348,8 @@ cc_library(
   selected_rows_utils
   data_device_transform
   data_type_transform
-  data_layout_transform)
+  data_layout_transform
+  phi_data_transform)
 cc_library(
   attribute
@@ -541,7 +542,7 @@ cc_library(
   glog
   version
   xxhash
-  dist_attr
+  op_dist_attr
   scalar
   op_version_proto
   op_version_registry)
...
@@ -441,7 +441,8 @@ TensorDistAttr *VarDesc::MutableDistAttr() {
   if (dist_attr_) {
     return dist_attr_.get();
   } else {
-    dist_attr_.reset(new TensorDistAttr(*this));
+    auto shape = paddle::distributed::auto_parallel::get_tensor_shape(this);
+    dist_attr_.reset(new TensorDistAttr(shape));
     return dist_attr_.get();
   }
   need_updated_ = true;
...
@@ -28,7 +28,7 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
-using paddle::distributed::auto_parallel::TensorDistAttr;
+using phi::distributed::auto_parallel::TensorDistAttr;
 // convert between std::vector and protobuf repeated.
 template <typename T>
...
@@ -501,6 +501,11 @@ if(WITH_PYTHON)
     SRCS ${PYBIND_SRCS}
     DEPS ${PYBIND_DEPS} ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS})
+  # cc_test does not respect deps; whole-archive linking is needed for symbols the tests may need.
+  if(WITH_TESTING)
+    #set_target_properties(${SHARD_LIB_NAME} PROPERTIES LINK_FLAGS "-Wl,--whole-archive")
+  endif()
   # TODO(zhiqiu): some symbols not exported even setting the following
   # property. Need to find a better way.
...
@@ -15,13 +15,13 @@
 #include <pybind11/operators.h>
 #include <pybind11/stl.h>
-#include "paddle/fluid/distributed/auto_parallel/device_mesh.h"
-#include "paddle/fluid/distributed/auto_parallel/dist_attr.h"
-#include "paddle/fluid/distributed/auto_parallel/dist_mapper.h"
-#include "paddle/fluid/distributed/auto_parallel/process_mesh.h"
 #include "paddle/fluid/framework/op_desc.h"
 #include "paddle/fluid/framework/var_desc.h"
 #include "paddle/fluid/pybind/auto_parallel_py.h"
+#include "paddle/phi/core/distributed/auto_parallel/device_mesh.h"
+#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
+#include "paddle/phi/core/distributed/auto_parallel/dist_mapper.h"
+#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
 #include "paddle/utils/optional.h"
 namespace py = pybind11;
@@ -29,19 +29,19 @@ namespace py = pybind11;
 namespace paddle {
 namespace pybind {
-using paddle::distributed::auto_parallel::Device;
-using paddle::distributed::auto_parallel::DeviceCapability;
-using paddle::distributed::auto_parallel::DeviceMesh;
-using paddle::distributed::auto_parallel::DistributedMapper;
-using paddle::distributed::auto_parallel::kDefault;
-using paddle::distributed::auto_parallel::Link;
-using paddle::distributed::auto_parallel::LinkCapability;
-using paddle::distributed::auto_parallel::Machine;
 using paddle::distributed::auto_parallel::OperatorDistAttr;
-using paddle::distributed::auto_parallel::ProcessMesh;
-using paddle::distributed::auto_parallel::TensorDistAttr;
 using paddle::framework::OpDesc;
 using paddle::framework::VarDesc;
+using phi::distributed::auto_parallel::Device;
+using phi::distributed::auto_parallel::DeviceCapability;
+using phi::distributed::auto_parallel::DeviceMesh;
+using phi::distributed::auto_parallel::DistributedMapper;
+using phi::distributed::auto_parallel::kDefault;
+using phi::distributed::auto_parallel::Link;
+using phi::distributed::auto_parallel::LinkCapability;
+using phi::distributed::auto_parallel::Machine;
+using phi::distributed::auto_parallel::ProcessMesh;
+using phi::distributed::auto_parallel::TensorDistAttr;
 static inline const ProcessMesh *get_tensor_process_mesh(
     const TensorDistAttr &self) {
@@ -227,7 +227,11 @@ void BindAutoParallel(py::module *m) {
   py::class_<TensorDistAttr>(*m, "TensorDistAttr")
       .def(py::init<>())
-      .def(py::init<const VarDesc &>())
+      .def(py::init([](const VarDesc &var_desc) {
+        auto shape =
+            paddle::distributed::auto_parallel::get_tensor_shape(&var_desc);
+        return std::make_unique<TensorDistAttr>(shape);
+      }))
       .def(py::init<const TensorDistAttr &>())
       .def_property(
           "process_mesh", &get_tensor_process_mesh, &set_tensor_process_mesh)
@@ -246,9 +250,14 @@ void BindAutoParallel(py::module *m) {
       .def("is_annotated", &TensorDistAttr::is_annotated)
      .def("mark_annotated", &TensorDistAttr::mark_annotated)
      .def("clear_annotated", &TensorDistAttr::clear_annotated)
-      .def("verify",
-           &TensorDistAttr::verify,
-           py::arg("tensor") = static_cast<VarDesc *>(nullptr))
+      .def(
+          "verify",
+          [](TensorDistAttr &self, const VarDesc *tensor) {
+            auto shape =
+                paddle::distributed::auto_parallel::get_tensor_shape(tensor);
+            return self.verify(shape);
+          },
+          py::arg("tensor") = static_cast<VarDesc *>(nullptr))
      .def("reset", &reset_tensor_dist_attr)
      .def("serialize_to_string",
           [](TensorDistAttr &self) {
@@ -369,6 +378,14 @@ void BindAutoParallel(py::module *m) {
          },
          py::arg("memo"))
      .def("__str__", &OperatorDistAttr::to_string);
+  // TODO(liuzhenhai): DistributedMapper is not used for now, but
+  // dist_mapper_test needs the symbols for DistributedMapper to be linked;
+  // remove it later.
+  m->def("touch_dist_mapper", []() {
+    DistributedMapper mapper;
+    return mapper.to_string();
+  });
 }
 }  // namespace pybind
...
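The binding changes above replace direct constructor and method bindings with small lambdas that adapt a VarDesc* to a shape before calling into phi. As a general illustration of this pybind11 pattern (a standalone toy module, not the actual Paddle code), a constructor can be exposed through a factory lambda like this:

#include <cstdint>
#include <memory>
#include <vector>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

namespace py = pybind11;

// Hypothetical stand-in class, used only to illustrate the py::init lambda
// pattern adopted in auto_parallel_py.cc.
struct Attr {
  explicit Attr(const std::vector<int64_t>& shape) : ndim(shape.size()) {}
  size_t ndim;
};

PYBIND11_MODULE(example, m) {
  py::class_<Attr>(m, "Attr")
      // Factory lambda: adapt whatever Python passes into the C++ ctor args.
      .def(py::init([](const std::vector<int64_t>& shape) {
        return std::make_unique<Attr>(shape);
      }));
}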
@@ -39,7 +39,9 @@ set(PHI_DEPS
     string_tensor
     api_scalar
     api_int_array
-    extended_tensor)
+    extended_tensor
+    dist_attr
+    dist_mapper)
 get_property(phi_kernels GLOBAL PROPERTY PHI_KERNELS)
 set(PHI_DEPS ${PHI_DEPS} ${phi_kernels})
...
 add_subdirectory(check)
 add_subdirectory(store)
+add_subdirectory(auto_parallel)
 set(COMM_CONTEXT_MANAGER_DEPS tcp_store)
...
proto_library(auto_parallel_proto SRCS auto_parallel.proto)
cc_library(
device_mesh
SRCS device_mesh.cc
DEPS auto_parallel_proto phi_enforce)
cc_library(
process_mesh
SRCS process_mesh.cc
DEPS auto_parallel_proto phi_enforce)
cc_library(
dist_attr
SRCS dist_attr.cc
DEPS process_mesh auto_parallel_proto proto_desc phi_enforce)
cc_library(
dist_mapper
SRCS dist_mapper.cc
DEPS device_mesh auto_parallel_proto phi_enforce)
cc_library(auto_parallel DEPS device_mesh process_mesh dist_attr dist_mapper)
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
 http://www.apache.org/licenses/LICENSE-2.0
-Unless optional by applicable law or agreed to in writing, software
+Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
@@ -14,7 +14,7 @@ limitations under the License. */
 syntax = "proto2";
-package paddle.distributed.auto_parallel;
+package phi.distributed.auto_parallel;
 // ProcessMesh is used to organize processes and like n-dimension array.
 message ProcessMeshProto {
...
@@ -15,10 +15,10 @@ limitations under the License. */
 #include <algorithm>
 #include <iterator>
-#include "paddle/fluid/distributed/auto_parallel/device_mesh.h"
-#include "paddle/fluid/distributed/auto_parallel/utils.h"
+#include "paddle/phi/core/distributed/auto_parallel/device_mesh.h"
+#include "paddle/phi/core/distributed/auto_parallel/utils.h"
-namespace paddle {
+namespace phi {
 namespace distributed {
 namespace auto_parallel {
@@ -169,7 +169,7 @@ void Machine::add_device(const Device &device) {
   } else {
     PADDLE_ENFORCE_EQ(device.machine_id(),
                       id(),
-                      platform::errors::InvalidArgument(
+                      errors::InvalidArgument(
                           "The machine id [%d] of the device should be equal "
                           "to this machine id [%d].",
                           device.machine_id(),
@@ -181,7 +181,7 @@ void Machine::add_link(const Link &link) {
 void Machine::add_link(const Link &link) {
   PADDLE_ENFORCE_EQ(contains(link.source_id()),
                     true,
-                    platform::errors::InvalidArgument(
+                    errors::InvalidArgument(
                         "The source device id of the added link [%s] "
                         "cannot be found in the device_ids. Please add the "
                         "source device before adding this link",
@@ -217,31 +217,31 @@ DeviceMesh::DeviceMesh(const std::string &name,
   shape_ = shape;
   int64_t size = this->size();
   PADDLE_ENFORCE_EQ(
       size,
       device_ids.size(),
-      platform::errors::InvalidArgument("The size %d of this device mesh must be "
+      errors::InvalidArgument("The size %d of this device mesh must be "
                               "equal to the size %d of its device ids.",
                               size,
                               device_ids.size()));
   PADDLE_ENFORCE_EQ(
       has_duplicates(device_ids),
       false,
-      platform::errors::InvalidArgument("The device ids [%s] must be unique.",
+      errors::InvalidArgument("The device ids [%s] must be unique.",
                               str_join(device_ids)));
   device_ids_ = device_ids;
   PADDLE_ENFORCE_EQ(
       shape_.size(),
       dim_names.size(),
-      platform::errors::InvalidArgument(
+      errors::InvalidArgument(
           "The size %d of mesh shape must be equal to the size %d "
           "of the dimension names.",
           shape_.size(),
           dim_names.size()));
   PADDLE_ENFORCE_EQ(has_duplicates(dim_names),
                     false,
-                    platform::errors::InvalidArgument(
+                    errors::InvalidArgument(
                         "The names [%s] of each dimension must be unique.",
                         str_join(dim_names)));
   dim_names_ = dim_names;
@@ -268,7 +268,7 @@ void DeviceMesh::add_device(const Device &device) {
   PADDLE_ENFORCE_EQ(
       contains(device.global_id()),
       true,
-      platform::errors::InvalidArgument(
+      errors::InvalidArgument(
           "The added device id [%s] cannot be found in the device_ids.",
           std::to_string(device.global_id())));
   // Operator [] will create a new object if it cannot find one.
@@ -282,15 +282,15 @@ void DeviceMesh::add_link(const Link &link) {
   PADDLE_ENFORCE_EQ(
       contains(link.source_id()),
       true,
-      platform::errors::InvalidArgument("The source id of the added link [%s] "
+      errors::InvalidArgument("The source id of the added link [%s] "
                               "cannot be found in the device_ids.",
                               std::to_string(link.source_id())));
   PADDLE_ENFORCE_EQ(
       contains(link.target_id()),
       true,
-      platform::errors::InvalidArgument("The source id of the added link [%s] "
+      errors::InvalidArgument("The source id of the added link [%s] "
                               "cannot be found in the device_ids.",
                               std::to_string(link.target_id())));
   // Operator [] will create a new object if it cannot find one.
   // So we add the default constructor for Device and Machine
   // to make sure the new object can be created.
@@ -395,4 +395,4 @@ bool operator==(const DeviceMesh &lhs, const DeviceMesh &rhs) {
 }  // namespace auto_parallel
 }  // namespace distributed
-}  // namespace paddle
+}  // namespace phi
@@ -23,11 +23,11 @@ limitations under the License. */
 #include <string>
 #include <vector>
-#include "paddle/fluid/distributed/auto_parallel/auto_parallel.pb.h"
-#include "paddle/fluid/distributed/auto_parallel/utils.h"
-#include "paddle/fluid/platform/enforce.h"
+#include "paddle/phi/core/distributed/auto_parallel/auto_parallel.pb.h"
+#include "paddle/phi/core/distributed/auto_parallel/utils.h"
+#include "paddle/phi/core/enforce.h"
-namespace paddle {
+namespace phi {
 namespace distributed {
 namespace auto_parallel {
 struct DeviceCapability {
@@ -259,7 +259,7 @@ class DeviceMesh {
         return shape_[i];
       }
     }
-    PADDLE_THROW(platform::errors::InvalidArgument(
+    PADDLE_THROW(errors::InvalidArgument(
         "Cannot find the dimension of %s in this device mesh.", dim_name));
   }
@@ -298,4 +298,4 @@ inline bool operator!=(const DeviceMesh& lhs, const DeviceMesh& rhs) {
 }  // namespace auto_parallel
 }  // namespace distributed
-}  // namespace paddle
+}  // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include <algorithm>
#include <iostream>
#include <iterator>
#include "glog/logging.h"
namespace phi {
namespace distributed {
namespace auto_parallel {
std::vector<std::string> TensorDistAttr::fields_{
"process_mesh", "dims_mapping", "batch_dim", "dynamic_dims"};
TensorDistAttr::TensorDistAttr(const std::vector<int64_t>& tensor_shape) {
set_default_dims_mapping(tensor_shape);
set_default_dynamic_dims(tensor_shape);
}
TensorDistAttr::TensorDistAttr(const TensorDistAttr& dist_attr) {
copy_from(dist_attr);
}
TensorDistAttr& TensorDistAttr::operator=(const TensorDistAttr& dist_attr) {
if (this == &dist_attr) return *this;
TensorDistAttr tmp(dist_attr);
std::swap(this->process_mesh_, tmp.process_mesh_);
std::swap(this->dims_mapping_, tmp.dims_mapping_);
std::swap(this->batch_dim_, tmp.batch_dim_);
std::swap(this->dynamic_dims_, tmp.dynamic_dims_);
std::swap(this->annotated_, tmp.annotated_);
return *this;
}
void TensorDistAttr::copy_from(const TensorDistAttr& dist_attr) {
set_process_mesh(dist_attr.process_mesh());
set_dims_mapping(dist_attr.dims_mapping());
set_batch_dim(dist_attr.batch_dim());
set_dynamic_dims(dist_attr.dynamic_dims());
set_annotated(dist_attr.annotated());
}
void TensorDistAttr::set_process_mesh(const ProcessMesh& process_mesh) {
process_mesh_ = process_mesh;
}
void TensorDistAttr::set_dims_mapping(
const std::vector<int64_t>& dims_mapping) {
dims_mapping_ = dims_mapping;
}
void TensorDistAttr::set_batch_dim(int64_t batch_dim) {
batch_dim_ = batch_dim;
}
void TensorDistAttr::set_dynamic_dims(const std::vector<bool>& dynamic_dims) {
dynamic_dims_ = dynamic_dims;
}
void TensorDistAttr::set_annotated(
const std::map<std::string, bool>& annotated) {
annotated_ = annotated;
}
void TensorDistAttr::set_default_dims_mapping(
const std::vector<int64_t>& tensor_shape) {
if (tensor_shape.size() != 0) {
dims_mapping_ = std::vector<int64_t>(tensor_shape.size(), -1);
}
}
void TensorDistAttr::set_default_dynamic_dims(
const std::vector<int64_t>& tensor_shape) {
if (tensor_shape.size() != 0) {
dynamic_dims_ = std::vector<bool>(tensor_shape.size(), false);
}
}
void TensorDistAttr::mark_annotated(const std::string& name) {
auto result = std::find(std::begin(fields_), std::end(fields_), name);
if (result != std::end(fields_)) {
annotated_[name] = true;
}
}
bool TensorDistAttr::verify_process_mesh(
const ProcessMesh& process_mesh) const {
VLOG(4) << "[TensorDistAttr verify_process_mesh] "
<< process_mesh.to_string();
if (!process_mesh_.empty()) {
for (int64_t dim_mapping : dims_mapping_) {
if (dim_mapping >= process_mesh_.ndim()) {
return false;
}
}
}
return true;
}
bool TensorDistAttr::verify_dims_mapping(
const std::vector<int64_t>& dims_mapping,
const std::vector<int64_t>& tensor_shape) const {
VLOG(4) << "[TensorDistAttr verify_dims_mapping] " << str_join(dims_mapping);
if (dims_mapping.size() != tensor_shape.size()) {
return false;
}
std::unordered_map<int64_t, int64_t> map;
if (!process_mesh_.empty()) {
for (int64_t i : dims_mapping) {
if (i < -1 || i >= process_mesh_.ndim()) {
return false;
}
++map[i];
if (i != -1 && map[i] > 1) {
return false;
}
}
} else {
for (int64_t i : dims_mapping) {
++map[i];
if (i != -1 && map[i] > 1) {
return false;
}
}
}
return true;
}
bool TensorDistAttr::verify_batch_dim(
int64_t dim, const std::vector<int64_t>& tensor_shape) const {
VLOG(4) << "[TensorDistAttr verify_batch_dim] " << dim;
int64_t ndim = tensor_shape.size();
if (ndim > 0) {
if (dim < 0) {
dim = dim + ndim;
}
if (dim < 0 || dim >= ndim) {
return false;
}
}
return true;
}
bool TensorDistAttr::verify_dynamic_dims(
const std::vector<bool>& dynamic_dims,
const std::vector<int64_t>& tensor_shape) const {
VLOG(4) << "[TensorDistAttr verify_dynamic_dims] " << str_join(dynamic_dims);
if (dynamic_dims.size() > 0 && dynamic_dims.size() != tensor_shape.size()) {
return false;
}
return true;
}
bool TensorDistAttr::verify_annotated(
const std::map<std::string, bool>& annotated) const {
VLOG(4) << "[TensorDistAttr verify_annotated] " << str_join(annotated);
for (const auto& item : annotated) {
auto result = std::find(std::begin(fields_), std::end(fields_), item.first);
if (result == std::end(fields_)) {
return false;
}
}
return true;
}
bool TensorDistAttr::verify(const std::vector<int64_t>& tensor_shape) const {
if (!verify_process_mesh(process_mesh_)) {
return false;
}
if (!verify_dims_mapping(dims_mapping_, tensor_shape)) {
return false;
}
if (!verify_batch_dim(batch_dim_, tensor_shape)) {
return false;
}
if (!verify_dynamic_dims(dynamic_dims_, tensor_shape)) {
return false;
}
if (!verify_annotated(annotated_)) {
return false;
}
return true;
}
std::string TensorDistAttr::to_string() const {
std::string dist_str;
dist_str += "{process_mesh: " + process_mesh_.to_string() + ", ";
dist_str += "dims_mappings: [" + str_join(dims_mapping_) + "], ";
dist_str += "batch_dim: " + std::to_string(batch_dim_) + ", ";
dist_str += "dynamic_dims: [" + str_join(dynamic_dims_) + "], ";
dist_str += "annotated: [" + str_join(annotated_) + "]}";
return dist_str;
}
void TensorDistAttr::from_proto(const TensorDistAttrProto& proto) {
process_mesh_ = ProcessMesh::from_proto(proto.process_mesh());
dims_mapping_.resize(proto.dims_mapping_size());
for (int64_t i = 0; i < proto.dims_mapping_size(); ++i) {
dims_mapping_[i] = proto.dims_mapping(i);
}
batch_dim_ = proto.batch_dim();
dynamic_dims_.resize(proto.dynamic_dims_size());
for (int64_t i = 0; i < proto.dynamic_dims_size(); ++i) {
dynamic_dims_[i] = proto.dynamic_dims(i);
}
}
TensorDistAttrProto TensorDistAttr::to_proto() const {
TensorDistAttrProto proto;
proto.mutable_process_mesh()->CopyFrom(process_mesh_.to_proto());
for (const auto& i : dims_mapping_) {
proto.add_dims_mapping(i);
}
proto.set_batch_dim(batch_dim_);
for (const auto& i : dynamic_dims_) {
proto.add_dynamic_dims(i);
}
return proto;
}
std::string TensorDistAttr::serialize_to_string() {
  std::string data;
  auto proto = to_proto();
  PADDLE_ENFORCE_EQ(proto.SerializeToString(&data),
                    true,
                    errors::InvalidArgument(
                        "Failed to serialize tensor dist attr to string."));
  return data;
}
void TensorDistAttr::parse_from_string(const std::string& data) {
TensorDistAttrProto proto;
PADDLE_ENFORCE_EQ(
proto.ParseFromString(data),
true,
errors::InvalidArgument(
"Failed to parse tensor dist attr from string: %s.", data));
from_proto(proto);
}
bool operator==(const TensorDistAttr& lhs, const TensorDistAttr& rhs) {
if (lhs.process_mesh() != rhs.process_mesh()) {
return false;
}
if (lhs.dims_mapping() != rhs.dims_mapping()) {
return false;
}
if (lhs.batch_dim() != rhs.batch_dim()) {
return false;
}
if (lhs.dynamic_dims() != rhs.dynamic_dims()) {
return false;
}
return true;
}
} // namespace auto_parallel
} // namespace distributed
} // namespace phi
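To make the verification rules above concrete: every dims_mapping entry must lie in [-1, mesh.ndim()), no mesh axis may shard more than one tensor dimension, and a negative batch_dim is normalized against the tensor rank. A hedged sketch (values are illustrative; the mesh mirrors the one used in dist_attr_test.cc):

#include <vector>
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"

using phi::distributed::auto_parallel::ProcessMesh;
using phi::distributed::auto_parallel::TensorDistAttr;

// Hedged sketch of the verification rules implemented above.
bool verify_example() {
  std::vector<int64_t> tensor_shape = {1000, 784};
  ProcessMesh mesh({2, 4}, {0, 1, 2, 3, 4, 5, 6, 7}, {"x", "y"});  // ndim() == 2

  TensorDistAttr attr(tensor_shape);   // dims_mapping defaults to {-1, -1}
  attr.set_process_mesh(mesh);
  attr.set_dims_mapping({0, -1});      // valid: mesh axis 0 used once, -1 = replicated
  attr.set_batch_dim(-2);              // valid: normalized to 0 for a rank-2 tensor
  // attr.set_dims_mapping({0, 0});    // would fail: mesh axis 0 used twice
  // attr.set_dims_mapping({2, -1});   // would fail: 2 >= mesh.ndim()
  return attr.verify(tensor_shape);
}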
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>
#include "paddle/phi/core/distributed/auto_parallel/auto_parallel.pb.h"
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
#include "paddle/phi/core/enforce.h"
namespace phi {
namespace distributed {
namespace auto_parallel {
constexpr const char* kDefault = "default";
class TensorDistAttr {
public:
TensorDistAttr() = default;
explicit TensorDistAttr(const std::vector<int64_t>& tensor_shape);
TensorDistAttr(const TensorDistAttr& tensor);
TensorDistAttr& operator=(const TensorDistAttr& dist_attr);
void copy_from(const TensorDistAttr& dist_attr);
const ProcessMesh& process_mesh() const { return process_mesh_; }
void set_process_mesh(const ProcessMesh& process_mesh);
const std::vector<int64_t>& dims_mapping() const { return dims_mapping_; }
void set_dims_mapping(const std::vector<int64_t>& dims_mapping);
void set_default_dims_mapping(const std::vector<int64_t>& tensor_shape);
int64_t batch_dim() const { return batch_dim_; }
void set_batch_dim(int64_t batch_dim);
const std::vector<bool>& dynamic_dims() const { return dynamic_dims_; }
void set_dynamic_dims(const std::vector<bool>& dynamic_dims);
void set_default_dynamic_dims(const std::vector<int64_t>& tensor_shape);
const std::map<std::string, bool>& annotated() const { return annotated_; }
void set_annotated(const std::map<std::string, bool>& annotated);
bool is_annotated(const std::string& name) const {
return annotated_.count(name) == 1 && annotated_.at(name) == true;
}
void mark_annotated(const std::string& name);
void clear_annotated() { annotated_.clear(); }
bool verify_process_mesh(const ProcessMesh& process_mesh) const;
bool verify_dims_mapping(const std::vector<int64_t>& dims_mapping,
const std::vector<int64_t>& tensor_shape) const;
bool verify_batch_dim(int64_t dim,
const std::vector<int64_t>& tensor_shape) const;
bool verify_dynamic_dims(const std::vector<bool>& dynamic_dims,
const std::vector<int64_t>& tensor_shape) const;
bool verify_annotated(const std::map<std::string, bool>& annotated) const;
bool verify(const std::vector<int64_t>& tensor_shape) const;
// TensorDistAttr from_string(const std::string& dist_str);
std::string to_string() const;
void from_proto(const TensorDistAttrProto& proto);
TensorDistAttrProto to_proto() const;
std::string serialize_to_string();
void parse_from_string(const std::string& data);
private:
static std::vector<std::string> fields_;
ProcessMesh process_mesh_;
std::vector<int64_t> dims_mapping_;
int64_t batch_dim_{0};
std::vector<bool> dynamic_dims_;
std::map<std::string, bool> annotated_;
};
inline std::ostream& operator<<(std::ostream& os, const TensorDistAttr& obj) {
os << obj.to_string();
return os;
}
bool operator==(const TensorDistAttr& lhs, const TensorDistAttr& rhs);
inline bool operator!=(const TensorDistAttr& lhs, const TensorDistAttr& rhs) {
return !operator==(lhs, rhs);
}
} // namespace auto_parallel
} // namespace distributed
} // namespace phi
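The proto round trip declared above gives a simple way to persist and restore a TensorDistAttr. A small hedged sketch of how it and the annotation API fit together (illustrative values, not code from this PR):

#include <string>
#include <vector>
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"

using phi::distributed::auto_parallel::TensorDistAttr;

// Hedged sketch: serialize_to_string/parse_from_string round trip.
bool round_trip_example() {
  std::vector<int64_t> shape = {8, 16};
  TensorDistAttr attr(shape);
  attr.set_dims_mapping({-1, 0});
  attr.mark_annotated("dims_mapping");  // only names listed in fields_ are accepted

  std::string blob = attr.serialize_to_string();  // proto-encoded bytes
  TensorDistAttr restored;
  restored.parse_from_string(blob);
  // operator== compares process_mesh, dims_mapping, batch_dim and dynamic_dims,
  // so the restored attribute should compare equal to the original.
  return restored == attr && attr.is_annotated("dims_mapping");
}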
@@ -14,10 +14,10 @@ limitations under the License. */
 #include <algorithm>
-#include "paddle/fluid/distributed/auto_parallel/dist_mapper.h"
-#include "paddle/fluid/distributed/auto_parallel/utils.h"
+#include "paddle/phi/core/distributed/auto_parallel/dist_mapper.h"
+#include "paddle/phi/core/distributed/auto_parallel/utils.h"
-namespace paddle {
+namespace phi {
 namespace distributed {
 namespace auto_parallel {
@@ -32,7 +32,7 @@ void DistributedMapper::set_process_id_to_device_ids(
     PADDLE_ENFORCE_GE(
         item.first,
         0,
-        platform::errors::InvalidArgument(
+        errors::InvalidArgument(
             "The process id %d must be greater than or equal to 0.",
             item.first));
     std::string device_mesh_name = item.second.first;
@@ -40,14 +40,14 @@ void DistributedMapper::set_process_id_to_device_ids(
     PADDLE_ENFORCE_EQ(
         device_meshes_.count(device_mesh_name),
         1,
-        platform::errors::InvalidArgument(
+        errors::InvalidArgument(
             "Cannot find the device mesh %d in device_mesh ids [%s].",
             device_mesh_name,
             str_join(device_mesh_names)));
     PADDLE_ENFORCE_EQ(
         has_duplicates(device_ids),
         false,
-        platform::errors::InvalidArgument(
+        errors::InvalidArgument(
             "The mapped device ids [%s] of process_mesh %d must be unique.",
             str_join(device_ids),
             item.first));
@@ -60,7 +60,7 @@ void DistributedMapper::set_process_id_to_device_ids(
       PADDLE_ENFORCE_EQ(
           found,
           true,
-          platform::errors::InvalidArgument(
+          errors::InvalidArgument(
              "The device id %d cannot be find in the device mesh [%s].",
              device_id,
              str_join(cur_device_ids)));
@@ -143,4 +143,4 @@ bool operator==(const DistributedMapper& lhs, const DistributedMapper& rhs) {
 }  // namespace auto_parallel
 }  // namespace distributed
-}  // namespace paddle
+}  // namespace phi
@@ -15,11 +15,11 @@ limitations under the License. */
 #include <utility>
-#include "paddle/fluid/distributed/auto_parallel/auto_parallel.pb.h"
-#include "paddle/fluid/distributed/auto_parallel/device_mesh.h"
-#include "paddle/fluid/distributed/auto_parallel/process_mesh.h"
+#include "paddle/phi/core/distributed/auto_parallel/auto_parallel.pb.h"
+#include "paddle/phi/core/distributed/auto_parallel/device_mesh.h"
+#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
-namespace paddle {
+namespace phi {
 namespace distributed {
 namespace auto_parallel {
@@ -70,4 +70,4 @@ inline std::ostream& operator<<(std::ostream& os,
 }  // namespace auto_parallel
 }  // namespace distributed
-}  // namespace paddle
+}  // namespace phi
@@ -15,10 +15,10 @@ limitations under the License. */
 #include <algorithm>
 #include <iterator>
-#include "paddle/fluid/distributed/auto_parallel/process_mesh.h"
-#include "paddle/fluid/distributed/auto_parallel/utils.h"
+#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
+#include "paddle/phi/core/distributed/auto_parallel/utils.h"
-namespace paddle {
+namespace phi {
 namespace distributed {
 namespace auto_parallel {
@@ -30,27 +30,27 @@ ProcessMesh::ProcessMesh(const std::vector<int64_t> &shape,
   PADDLE_ENFORCE_EQ(
       size,
       process_ids.size(),
-      platform::errors::InvalidArgument("The size of this process mesh must be "
+      errors::InvalidArgument("The size of this process mesh must be "
                               "equal to the size of its process ids.",
                               size,
                               process_ids.size()));
   PADDLE_ENFORCE_EQ(
       has_duplicates(process_ids),
       false,
-      platform::errors::InvalidArgument("The process ids [%s] must be unique.",
+      errors::InvalidArgument("The process ids [%s] must be unique.",
                               str_join(process_ids_)));
   process_ids_ = process_ids;
   PADDLE_ENFORCE_EQ(shape_.size(),
                     dim_names.size(),
-                    platform::errors::InvalidArgument(
+                    errors::InvalidArgument(
                         "The size of mesh shape must be equal to the size "
                         "of the dimension names.",
                         shape_.size(),
                         dim_names_.size()));
   PADDLE_ENFORCE_EQ(has_duplicates(dim_names),
                     false,
-                    platform::errors::InvalidArgument(
+                    errors::InvalidArgument(
                         "The names [%s] of each dimension must be unique.",
                         str_join(dim_names)));
   dim_names_ = dim_names;
@@ -131,4 +131,4 @@ bool operator==(const ProcessMesh &lhs, const ProcessMesh &rhs) {
 }  // namespace auto_parallel
 }  // namespace distributed
-}  // namespace paddle
+}  // namespace phi
@@ -20,12 +20,12 @@ limitations under the License. */
 #include <string>
 #include <vector>
-#include "paddle/fluid/distributed/auto_parallel/auto_parallel.pb.h"
-#include "paddle/fluid/distributed/auto_parallel/device_mesh.h"
-#include "paddle/fluid/distributed/auto_parallel/utils.h"
-#include "paddle/fluid/platform/enforce.h"
+#include "paddle/phi/core/distributed/auto_parallel/auto_parallel.pb.h"
+#include "paddle/phi/core/distributed/auto_parallel/device_mesh.h"
+#include "paddle/phi/core/distributed/auto_parallel/utils.h"
+#include "paddle/phi/core/enforce.h"
-namespace paddle {
+namespace phi {
 namespace distributed {
 namespace auto_parallel {
@@ -58,7 +58,7 @@ class ProcessMesh {
         return shape_[i];
       }
     }
-    PADDLE_THROW(platform::errors::InvalidArgument(
+    PADDLE_THROW(errors::InvalidArgument(
         "Cannot find the dimension of %s in this process mesh.", dim_name));
   }
@@ -90,4 +90,4 @@ inline bool operator!=(const ProcessMesh& lhs, const ProcessMesh& rhs) {
 }  // namespace auto_parallel
 }  // namespace distributed
-}  // namespace paddle
+}  // namespace phi
@@ -19,9 +19,9 @@ limitations under the License. */
 #include <unordered_map>
 #include <vector>
-#include "paddle/fluid/platform/enforce.h"
+#include "paddle/phi/core/enforce.h"
-namespace paddle {
+namespace phi {
 namespace distributed {
 namespace auto_parallel {
@@ -50,7 +50,7 @@ inline int64_t canonical_dim(int dim, int ndim) {
   PADDLE_ENFORCE_EQ(
       dim >= -ndim && dim < ndim,
       true,
-      platform::errors::InvalidArgument(
+      errors::InvalidArgument(
           "Dimension %d is outside of [-%d, %d).", dim, ndim, ndim));
   if (dim < 0) {
     return dim + ndim;
@@ -111,4 +111,4 @@ std::string to_string_with_precision(const T a_value, const int n = 2) {
 }  // namespace auto_parallel
 }  // namespace distributed
-}  // namespace paddle
+}  // namespace phi