未验证 提交 be1152a4 编写于 作者: zhenhailiu's avatar zhenhailiu 提交者: GitHub

[dist attr 迁移到 phi]Dist attr (#53848)

* merge code from forsish

* polish

* paddle/fluid/pybind/auto_parallel_py.cc

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish
上级 4af0f140
proto_library(auto_parallel_proto SRCS auto_parallel.proto)
cc_library(
device_mesh
SRCS device_mesh.cc
DEPS auto_parallel_proto phi_enforce)
cc_library(
process_mesh
SRCS process_mesh.cc
DEPS auto_parallel_proto phi_enforce)
cc_library(
dist_attr
op_dist_attr
SRCS dist_attr.cc
DEPS process_mesh auto_parallel_proto proto_desc phi_enforce)
cc_library(
dist_mapper
SRCS dist_mapper.cc
DEPS device_mesh auto_parallel_proto phi_enforce)
DEPS dist_attr process_mesh dist_mapper auto_parallel_proto proto_desc
phi_enforce)
cc_library(auto_parallel DEPS device_mesh process_mesh dist_attr dist_mapper)
add_subdirectory(test)
......@@ -26,10 +26,9 @@ namespace paddle {
namespace distributed {
namespace auto_parallel {
std::vector<std::string> TensorDistAttr::fields_{
"process_mesh", "dims_mapping", "batch_dim", "dynamic_dims"};
using phi::distributed::auto_parallel::str_join;
static inline std::vector<int64_t> get_tensor_shape(const VarDesc* tensor) {
std::vector<int64_t> get_tensor_shape(const VarDesc* tensor) {
if (tensor == nullptr) return std::vector<int64_t>();
switch (tensor->GetType()) {
case framework::proto::VarType::READER:
......@@ -43,251 +42,6 @@ static inline std::vector<int64_t> get_tensor_shape(const VarDesc* tensor) {
}
}
TensorDistAttr::TensorDistAttr(const VarDesc& tensor) {
VLOG(4) << "[TensorDistAttr constructor] tensor name: " << tensor.Name();
std::vector<int64_t> tensor_shape = get_tensor_shape(&tensor);
set_default_dims_mapping(tensor_shape);
set_default_dynamic_dims(tensor_shape);
}
TensorDistAttr::TensorDistAttr(const TensorDistAttr& dist_attr) {
copy_from(dist_attr);
}
TensorDistAttr& TensorDistAttr::operator=(const TensorDistAttr& dist_attr) {
if (this == &dist_attr) return *this;
TensorDistAttr tmp(dist_attr);
std::swap(this->process_mesh_, tmp.process_mesh_);
std::swap(this->dims_mapping_, tmp.dims_mapping_);
std::swap(this->batch_dim_, tmp.batch_dim_);
std::swap(this->dynamic_dims_, tmp.dynamic_dims_);
std::swap(this->annotated_, tmp.annotated_);
return *this;
}
void TensorDistAttr::copy_from(const TensorDistAttr& dist_attr) {
set_process_mesh(dist_attr.process_mesh());
set_dims_mapping(dist_attr.dims_mapping());
set_batch_dim(dist_attr.batch_dim());
set_dynamic_dims(dist_attr.dynamic_dims());
set_annotated(dist_attr.annotated());
}
void TensorDistAttr::set_process_mesh(const ProcessMesh& process_mesh) {
process_mesh_ = process_mesh;
}
void TensorDistAttr::set_dims_mapping(
const std::vector<int64_t>& dims_mapping) {
dims_mapping_ = dims_mapping;
}
void TensorDistAttr::set_batch_dim(int64_t batch_dim) {
batch_dim_ = batch_dim;
}
void TensorDistAttr::set_dynamic_dims(const std::vector<bool>& dynamic_dims) {
dynamic_dims_ = dynamic_dims;
}
void TensorDistAttr::set_annotated(
const std::map<std::string, bool>& annotated) {
annotated_ = annotated;
}
void TensorDistAttr::set_default_dims_mapping(
const std::vector<int64_t>& tensor_shape) {
if (tensor_shape.size() != 0) {
dims_mapping_ = std::vector<int64_t>(tensor_shape.size(), -1);
}
}
void TensorDistAttr::set_default_dynamic_dims(
const std::vector<int64_t>& tensor_shape) {
if (tensor_shape.size() != 0) {
dynamic_dims_ = std::vector<bool>(tensor_shape.size(), false);
}
}
void TensorDistAttr::mark_annotated(const std::string& name) {
auto result = std::find(std::begin(fields_), std::end(fields_), name);
if (result != std::end(fields_)) {
annotated_[name] = true;
}
}
bool TensorDistAttr::verify_process_mesh(
const ProcessMesh& process_mesh) const {
VLOG(4) << "[TensorDistAttr verify_process_mesh] "
<< process_mesh.to_string();
if (!process_mesh_.empty()) {
for (int64_t dim_mapping : dims_mapping_) {
if (dim_mapping >= process_mesh_.ndim()) {
return false;
}
}
}
return true;
}
bool TensorDistAttr::verify_dims_mapping(
const std::vector<int64_t>& dims_mapping,
const std::vector<int64_t>& tensor_shape) const {
VLOG(4) << "[TensorDistAttr verify_dims_mapping] " << str_join(dims_mapping);
if (dims_mapping.size() != tensor_shape.size()) {
return false;
}
std::unordered_map<int64_t, int64_t> map;
if (!process_mesh_.empty()) {
for (int64_t i : dims_mapping) {
if (i < -1 || i >= process_mesh_.ndim()) {
return false;
}
++map[i];
if (i != -1 && map[i] > 1) {
return false;
}
}
} else {
for (int64_t i : dims_mapping) {
++map[i];
if (i != -1 && map[i] > 1) {
return false;
}
}
}
return true;
}
bool TensorDistAttr::verify_batch_dim(
int64_t dim, const std::vector<int64_t>& tensor_shape) const {
VLOG(4) << "[TensorDistAttr verify_batch_dim] " << dim;
int64_t ndim = tensor_shape.size();
if (ndim > 0) {
if (dim < 0) {
dim = dim + ndim;
}
if (dim < 0 || dim >= ndim) {
return false;
}
}
return true;
}
bool TensorDistAttr::verify_dynamic_dims(
const std::vector<bool>& dynamic_dims,
const std::vector<int64_t>& tensor_shape) const {
VLOG(4) << "[TensorDistAttr verify_dynamic_dims] " << str_join(dynamic_dims);
if (dynamic_dims.size() > 0 && dynamic_dims.size() != tensor_shape.size()) {
return false;
}
return true;
}
bool TensorDistAttr::verify_annotated(
const std::map<std::string, bool>& annotated) const {
VLOG(4) << "[TensorDistAttr verify_annotated] " << str_join(annotated);
for (const auto& item : annotated) {
auto result = std::find(std::begin(fields_), std::end(fields_), item.first);
if (result == std::end(fields_)) {
return false;
}
}
return true;
}
bool TensorDistAttr::verify(const VarDesc* tensor) const {
auto tensor_shape = get_tensor_shape(tensor);
if (!verify_process_mesh(process_mesh_)) {
return false;
}
if (!verify_dims_mapping(dims_mapping_, tensor_shape)) {
return false;
}
if (!verify_batch_dim(batch_dim_, tensor_shape)) {
return false;
}
if (!verify_dynamic_dims(dynamic_dims_, tensor_shape)) {
return false;
}
if (!verify_annotated(annotated_)) {
return false;
}
return true;
}
std::string TensorDistAttr::to_string() const {
std::string dist_str;
dist_str += "{process_mesh: " + process_mesh_.to_string() + ", ";
dist_str += "dims_mappings: [" + str_join(dims_mapping_) + "], ";
dist_str += "batch_dim: " + std::to_string(batch_dim_) + ", ";
dist_str += "dynamic_dims: [" + str_join(dynamic_dims_) + "], ";
dist_str += "annotated: [" + str_join(annotated_) + "]}";
return dist_str;
}
void TensorDistAttr::from_proto(const TensorDistAttrProto& proto) {
process_mesh_ = ProcessMesh::from_proto(proto.process_mesh());
dims_mapping_.resize(proto.dims_mapping_size());
for (int64_t i = 0; i < proto.dims_mapping_size(); ++i) {
dims_mapping_[i] = proto.dims_mapping(i);
}
batch_dim_ = proto.batch_dim();
dynamic_dims_.resize(proto.dynamic_dims_size());
for (int64_t i = 0; i < proto.dynamic_dims_size(); ++i) {
dynamic_dims_[i] = proto.dynamic_dims(i);
}
}
TensorDistAttrProto TensorDistAttr::to_proto() const {
TensorDistAttrProto proto;
proto.mutable_process_mesh()->CopyFrom(process_mesh_.to_proto());
for (const auto& i : dims_mapping_) {
proto.add_dims_mapping(i);
}
proto.set_batch_dim(batch_dim_);
for (const auto& i : dynamic_dims_) {
proto.add_dynamic_dims(i);
}
return proto;
}
std::string TensorDistAttr::serialize_to_string() {
std::string data;
auto proto = to_proto();
proto.SerializeToString(&data);
PADDLE_ENFORCE_EQ(to_proto().SerializeToString(&data),
true,
platform::errors::InvalidArgument(
"Failed to serialize tensor dist attr to string."));
return data;
}
void TensorDistAttr::parse_from_string(const std::string& data) {
TensorDistAttrProto proto;
PADDLE_ENFORCE_EQ(proto.ParseFromString(data),
true,
platform::errors::InvalidArgument(
"Failed to parse tensor dist attr from string."));
from_proto(proto);
}
bool operator==(const TensorDistAttr& lhs, const TensorDistAttr& rhs) {
if (lhs.process_mesh() != rhs.process_mesh()) {
return false;
}
if (lhs.dims_mapping() != rhs.dims_mapping()) {
return false;
}
if (lhs.batch_dim() != rhs.batch_dim()) {
return false;
}
if (lhs.dynamic_dims() != rhs.dynamic_dims()) {
return false;
}
return true;
}
std::vector<std::string> OperatorDistAttr::fields_{"process_mesh",
"impl_type",
"impl_idx",
......@@ -335,7 +89,7 @@ void OperatorDistAttr::initialize(const OpDesc* op) {
if (input == nullptr || op->Type() == "create_py_reader") {
input_dist_attrs_[name] = TensorDistAttr();
} else {
input_dist_attrs_[name] = TensorDistAttr(*input);
input_dist_attrs_[name] = TensorDistAttr(get_tensor_shape(input));
}
}
for (std::string name : op->OutputArgumentNames()) {
......@@ -344,7 +98,7 @@ void OperatorDistAttr::initialize(const OpDesc* op) {
if (output == nullptr) {
output_dist_attrs_[name] = TensorDistAttr();
} else {
output_dist_attrs_[name] = TensorDistAttr(*output);
output_dist_attrs_[name] = TensorDistAttr(get_tensor_shape(output));
}
}
op_type_ = op->Type();
......@@ -465,7 +219,8 @@ bool OperatorDistAttr::verify_input_dist_attr(const std::string& name,
const VarDesc* tensor) const {
VLOG(4) << "[OperatorDistAttr verify_input_dist_attr] " << name << " "
<< dist_attr.to_string();
if (!dist_attr.verify(tensor)) {
auto tensor_shape = get_tensor_shape(tensor);
if (!dist_attr.verify(tensor_shape)) {
return false;
}
if (tensor != nullptr) {
......@@ -484,7 +239,8 @@ bool OperatorDistAttr::verify_output_dist_attr(const std::string& name,
const VarDesc* tensor) const {
VLOG(4) << "[OperatorDistAttr verify_output_dist_attr] " << name << " "
<< dist_attr.to_string();
if (!dist_attr.verify(tensor)) {
auto tensor_shape = get_tensor_shape(tensor);
if (!dist_attr.verify(tensor_shape)) {
return false;
}
if (tensor != nullptr) {
......
......@@ -21,10 +21,11 @@ limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/distributed/auto_parallel/auto_parallel.pb.h"
#include "paddle/fluid/distributed/auto_parallel/process_mesh.h"
#include "paddle/fluid/distributed/auto_parallel/utils.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/core/distributed/auto_parallel/auto_parallel.pb.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
namespace paddle {
......@@ -46,97 +47,13 @@ using framework::OpDesc;
using framework::ProgramDesc;
using framework::VarDesc;
constexpr const char* kDefault = "default";
class TensorDistAttr {
public:
TensorDistAttr() = default;
explicit TensorDistAttr(const VarDesc& tensor);
TensorDistAttr(const TensorDistAttr& tensor);
TensorDistAttr& operator=(const TensorDistAttr& dist_attr);
void copy_from(const TensorDistAttr& dist_attr);
const ProcessMesh& process_mesh() const { return process_mesh_; }
void set_process_mesh(const ProcessMesh& process_mesh);
const std::vector<int64_t>& dims_mapping() const { return dims_mapping_; }
void set_dims_mapping(const std::vector<int64_t>& dims_mapping);
void set_default_dims_mapping(const std::vector<int64_t>& tensor_shape);
int64_t batch_dim() const { return batch_dim_; }
void set_batch_dim(int64_t batch_dim);
const std::vector<bool>& dynamic_dims() const { return dynamic_dims_; }
void set_dynamic_dims(const std::vector<bool>& dynamic_dims);
void set_default_dynamic_dims(const std::vector<int64_t>& tensor_shape);
const std::map<std::string, bool>& annotated() const { return annotated_; }
void set_annotated(const std::map<std::string, bool>& annotated);
bool is_annotated(const std::string& name) const {
return annotated_.count(name) == 1 && annotated_.at(name) == true;
}
void mark_annotated(const std::string& name);
void clear_annotated() { annotated_.clear(); }
using phi::distributed::auto_parallel::OperatorDistAttrProto;
using phi::distributed::auto_parallel::ProcessMesh;
using phi::distributed::auto_parallel::TensorDistAttr;
bool verify_process_mesh(const ProcessMesh& process_mesh) const;
bool verify_dims_mapping(const std::vector<int64_t>& dims_mapping,
const std::vector<int64_t>& tensor_shape) const;
bool verify_batch_dim(int64_t dim,
const std::vector<int64_t>& tensor_shape) const;
bool verify_dynamic_dims(const std::vector<bool>& dynamic_dims,
const std::vector<int64_t>& tensor_shape) const;
bool verify_annotated(const std::map<std::string, bool>& annotated) const;
bool verify(const VarDesc* tensor = nullptr) const;
// TensorDistAttr from_string(const std::string& dist_str);
std::string to_string() const;
void from_proto(const TensorDistAttrProto& proto);
TensorDistAttrProto to_proto() const;
std::string serialize_to_string();
void parse_from_string(const std::string& data);
private:
static std::vector<std::string> fields_;
ProcessMesh process_mesh_;
std::vector<int64_t> dims_mapping_;
int64_t batch_dim_{0};
std::vector<bool> dynamic_dims_;
std::map<std::string, bool> annotated_;
};
inline std::ostream& operator<<(std::ostream& os, const TensorDistAttr& obj) {
os << obj.to_string();
return os;
}
bool operator==(const TensorDistAttr& lhs, const TensorDistAttr& rhs);
constexpr const char* kDefault = "default";
inline bool operator!=(const TensorDistAttr& lhs, const TensorDistAttr& rhs) {
return !operator==(lhs, rhs);
}
std::vector<int64_t> get_tensor_shape(const VarDesc* tensor);
class OperatorDistAttr {
public:
......
......@@ -11,7 +11,7 @@ cc_test(
cc_test(
dist_attr_test
SRCS dist_attr_test.cc
DEPS dist_attr)
DEPS dist_attr proto_desc)
cc_test(
dist_mapper_test
......
......@@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/auto_parallel/device_mesh.h"
#include "paddle/phi/core/distributed/auto_parallel/device_mesh.h"
#include <iostream>
#include <sstream>
#include "gtest/gtest.h"
namespace paddle {
namespace phi {
namespace distributed {
namespace auto_parallel {
......@@ -90,4 +90,4 @@ TEST(DeviceMesh, Ctor) {
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
} // namespace phi
......@@ -17,29 +17,37 @@ limitations under the License. */
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/auto_parallel/dist_attr.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
namespace paddle {
namespace phi {
namespace distributed {
namespace auto_parallel {
using paddle::framework::BlockDesc;
using paddle::framework::OpDesc;
using paddle::framework::ProgramDesc;
using paddle::framework::VarDesc;
using paddle::distributed::auto_parallel::get_tensor_shape;
using paddle::distributed::auto_parallel::OperatorDistAttr;
TEST(DistAttr, ctor) {
ProgramDesc program;
auto* global_block = program.MutableBlock(0);
auto* x = global_block->Var("X");
x->SetType(framework::proto::VarType::LOD_TENSOR);
x->SetType(paddle::framework::proto::VarType::LOD_TENSOR);
x->SetLoDLevel(0);
x->SetDataType(framework::proto::VarType::FP32);
x->SetDataType(paddle::framework::proto::VarType::FP32);
x->SetShape({1000, 784});
auto* y = global_block->Var("Y");
y->SetType(framework::proto::VarType::LOD_TENSOR);
y->SetType(paddle::framework::proto::VarType::LOD_TENSOR);
y->SetLoDLevel(0);
y->SetDataType(framework::proto::VarType::FP32);
y->SetDataType(paddle::framework::proto::VarType::FP32);
y->SetShape({784, 100});
auto* op = global_block->AppendOp();
......@@ -48,10 +56,15 @@ TEST(DistAttr, ctor) {
op->SetInput("Y", {y->Name()});
auto* out = global_block->Var("Out");
out->SetType(framework::proto::VarType::LOD_TENSOR);
out->SetType(paddle::framework::proto::VarType::LOD_TENSOR);
out->SetShape({1000, 100});
op->SetOutput("Out", {out->Name()});
auto get_dist_attr = [](const VarDesc* var_desc) {
auto shape = get_tensor_shape(var_desc);
return TensorDistAttr(shape);
};
std::vector<int64_t> shape = {2, 4};
std::vector<int64_t> process_ids = {0, 1, 2, 3, 4, 5, 6, 7};
std::vector<std::string> dim_names = {"x", "y"};
......@@ -62,7 +75,9 @@ TEST(DistAttr, ctor) {
std::vector<std::string> dim_names2 = {"a", "b"};
ProcessMesh process_mesh2(shape2, process_ids2, dim_names2);
TensorDistAttr x_dist_attr(*x), y_dist_attr(*y), out_dist_attr(*out);
auto x_dist_attr = get_dist_attr(x);
auto y_dist_attr = get_dist_attr(y);
auto out_dist_attr = get_dist_attr(out);
x_dist_attr.set_process_mesh(process_mesh);
x_dist_attr.set_dims_mapping(std::vector<int64_t>({0, -1}));
x_dist_attr.set_batch_dim(0);
......@@ -75,7 +90,7 @@ TEST(DistAttr, ctor) {
EXPECT_EQ(x_dist_attr.dynamic_dims(), std::vector<bool>({true, false}));
EXPECT_EQ(x_dist_attr.is_annotated("process_mesh"), true);
EXPECT_EQ(x_dist_attr.is_annotated("dims_mapping"), true);
EXPECT_EQ(x_dist_attr.verify(x), true);
EXPECT_EQ(x_dist_attr.verify(get_tensor_shape(x)), true);
x_dist_attr.clear_annotated();
EXPECT_EQ(x_dist_attr.annotated().empty(), true);
......@@ -83,7 +98,7 @@ TEST(DistAttr, ctor) {
x_sstream << x_dist_attr;
EXPECT_EQ(x_sstream.str(), x_dist_attr.to_string());
auto x_proto = x_dist_attr.to_proto();
TensorDistAttr new_x_dist_attr(*x);
TensorDistAttr new_x_dist_attr = get_dist_attr(x);
new_x_dist_attr.from_proto(x_proto);
EXPECT_EQ(x_dist_attr, new_x_dist_attr);
......@@ -95,11 +110,11 @@ TEST(DistAttr, ctor) {
x_dist_attr.mark_annotated("dynamic_dims");
EXPECT_EQ(y_dist_attr.process_mesh(), process_mesh);
EXPECT_EQ(y_dist_attr.dims_mapping(), std::vector<int64_t>({-1, 0}));
EXPECT_EQ(y_dist_attr.batch_dim(), 1);
EXPECT_EQ(y_dist_attr.batch_dim(), -1);
EXPECT_EQ(y_dist_attr.dynamic_dims(), std::vector<bool>({false, true}));
EXPECT_EQ(x_dist_attr.is_annotated("batch_dim"), true);
EXPECT_EQ(x_dist_attr.is_annotated("dynamic_dims"), true);
EXPECT_EQ(x_dist_attr.verify(y), true);
EXPECT_EQ(x_dist_attr.verify(get_tensor_shape(y)), true);
out_dist_attr.set_process_mesh(process_mesh);
out_dist_attr.set_dims_mapping(std::vector<int64_t>({0, 1}));
......@@ -109,11 +124,11 @@ TEST(DistAttr, ctor) {
EXPECT_EQ(out_dist_attr.dims_mapping(), std::vector<int64_t>({0, 1}));
EXPECT_EQ(out_dist_attr.batch_dim(), 1);
EXPECT_EQ(out_dist_attr.dynamic_dims(), std::vector<bool>({false, false}));
EXPECT_EQ(out_dist_attr.verify(out), true);
EXPECT_EQ(out_dist_attr.verify(get_tensor_shape(out)), true);
OperatorDistAttr mul_dist_attr(*op);
EXPECT_EQ(mul_dist_attr.impl_type(), kDefault);
EXPECT_EQ(mul_dist_attr.impl_idx(), -1);
EXPECT_EQ(mul_dist_attr.impl_idx(), 0);
EXPECT_EQ(mul_dist_attr.is_recompute(), false);
EXPECT_EQ(mul_dist_attr.is_annotated("process_mesh"), false);
EXPECT_EQ(mul_dist_attr.is_annotated("impl_type"), false);
......@@ -157,4 +172,4 @@ TEST(DistAttr, ctor) {
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
} // namespace phi
......@@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/auto_parallel/dist_mapper.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_mapper.h"
#include <map>
#include <sstream>
#include "gtest/gtest.h"
namespace paddle {
namespace phi {
namespace distributed {
namespace auto_parallel {
......@@ -69,4 +69,4 @@ TEST(DistributedMapper, Ctor) {
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
} // namespace phi
......@@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/auto_parallel/process_mesh.h"
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
#include <iostream>
#include <sstream>
#include "gtest/gtest.h"
namespace paddle {
namespace phi {
namespace distributed {
namespace auto_parallel {
......@@ -50,4 +50,4 @@ TEST(ProcessMesh, Ctor) {
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
} // namespace phi
......@@ -333,7 +333,7 @@ endif()
cc_library(
data_layout_transform
SRCS data_layout_transform.cc
DEPS tensor math_function)
DEPS tensor math_function phi_data_layout_transform)
cc_test(
data_layout_transform_test
SRCS data_layout_transform_test.cc
......@@ -348,7 +348,8 @@ cc_library(
selected_rows_utils
data_device_transform
data_type_transform
data_layout_transform)
data_layout_transform
phi_data_transform)
cc_library(
attribute
......@@ -541,7 +542,7 @@ cc_library(
glog
version
xxhash
dist_attr
op_dist_attr
scalar
op_version_proto
op_version_registry)
......
......@@ -441,7 +441,8 @@ TensorDistAttr *VarDesc::MutableDistAttr() {
if (dist_attr_) {
return dist_attr_.get();
} else {
dist_attr_.reset(new TensorDistAttr(*this));
auto shape = paddle::distributed::auto_parallel::get_tensor_shape(this);
dist_attr_.reset(new TensorDistAttr(shape));
return dist_attr_.get();
}
need_updated_ = true;
......
......@@ -28,7 +28,7 @@ limitations under the License. */
namespace paddle {
namespace framework {
using paddle::distributed::auto_parallel::TensorDistAttr;
using phi::distributed::auto_parallel::TensorDistAttr;
// convert between std::vector and protobuf repeated.
template <typename T>
......
......@@ -501,6 +501,11 @@ if(WITH_PYTHON)
SRCS ${PYBIND_SRCS}
DEPS ${PYBIND_DEPS} ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS})
# cc_test do not respect deps, whole archive to link symbols that may need by test
if(WITH_TESTING)
#set_target_properties(${SHARD_LIB_NAME} PROPERTIES LINK_FLAGS "-Wl,--whole-archive")
endif()
# TODO(zhiqiu): some symbols not exported even setting the following
# property. Need to find a better way.
......
......@@ -15,13 +15,13 @@
#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include "paddle/fluid/distributed/auto_parallel/device_mesh.h"
#include "paddle/fluid/distributed/auto_parallel/dist_attr.h"
#include "paddle/fluid/distributed/auto_parallel/dist_mapper.h"
#include "paddle/fluid/distributed/auto_parallel/process_mesh.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/pybind/auto_parallel_py.h"
#include "paddle/phi/core/distributed/auto_parallel/device_mesh.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_mapper.h"
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
#include "paddle/utils/optional.h"
namespace py = pybind11;
......@@ -29,19 +29,19 @@ namespace py = pybind11;
namespace paddle {
namespace pybind {
using paddle::distributed::auto_parallel::Device;
using paddle::distributed::auto_parallel::DeviceCapability;
using paddle::distributed::auto_parallel::DeviceMesh;
using paddle::distributed::auto_parallel::DistributedMapper;
using paddle::distributed::auto_parallel::kDefault;
using paddle::distributed::auto_parallel::Link;
using paddle::distributed::auto_parallel::LinkCapability;
using paddle::distributed::auto_parallel::Machine;
using paddle::distributed::auto_parallel::OperatorDistAttr;
using paddle::distributed::auto_parallel::ProcessMesh;
using paddle::distributed::auto_parallel::TensorDistAttr;
using paddle::framework::OpDesc;
using paddle::framework::VarDesc;
using phi::distributed::auto_parallel::Device;
using phi::distributed::auto_parallel::DeviceCapability;
using phi::distributed::auto_parallel::DeviceMesh;
using phi::distributed::auto_parallel::DistributedMapper;
using phi::distributed::auto_parallel::kDefault;
using phi::distributed::auto_parallel::Link;
using phi::distributed::auto_parallel::LinkCapability;
using phi::distributed::auto_parallel::Machine;
using phi::distributed::auto_parallel::ProcessMesh;
using phi::distributed::auto_parallel::TensorDistAttr;
static inline const ProcessMesh *get_tensor_process_mesh(
const TensorDistAttr &self) {
......@@ -227,7 +227,11 @@ void BindAutoParallel(py::module *m) {
py::class_<TensorDistAttr>(*m, "TensorDistAttr")
.def(py::init<>())
.def(py::init<const VarDesc &>())
.def(py::init([](const VarDesc &var_desc) {
auto shape =
paddle::distributed::auto_parallel::get_tensor_shape(&var_desc);
return std::make_unique<TensorDistAttr>(shape);
}))
.def(py::init<const TensorDistAttr &>())
.def_property(
"process_mesh", &get_tensor_process_mesh, &set_tensor_process_mesh)
......@@ -246,9 +250,14 @@ void BindAutoParallel(py::module *m) {
.def("is_annotated", &TensorDistAttr::is_annotated)
.def("mark_annotated", &TensorDistAttr::mark_annotated)
.def("clear_annotated", &TensorDistAttr::clear_annotated)
.def("verify",
&TensorDistAttr::verify,
py::arg("tensor") = static_cast<VarDesc *>(nullptr))
.def(
"verify",
[](TensorDistAttr &self, const VarDesc *tensor) {
auto shape =
paddle::distributed::auto_parallel::get_tensor_shape(tensor);
return self.verify(shape);
},
py::arg("tensor") = static_cast<VarDesc *>(nullptr))
.def("reset", &reset_tensor_dist_attr)
.def("serialize_to_string",
[](TensorDistAttr &self) {
......@@ -369,6 +378,14 @@ void BindAutoParallel(py::module *m) {
},
py::arg("memo"))
.def("__str__", &OperatorDistAttr::to_string);
// TODO(liuzhenhai): DistributedMapper is not used for now, but
// dist_mapper_test need the symbols forch DistributedMapper to be linked,
// remove it latter
m->def("touch_dist_mapper", []() {
DistributedMapper mapper;
return mapper.to_string();
});
}
} // namespace pybind
......
......@@ -39,7 +39,9 @@ set(PHI_DEPS
string_tensor
api_scalar
api_int_array
extended_tensor)
extended_tensor
dist_attr
dist_mapper)
get_property(phi_kernels GLOBAL PROPERTY PHI_KERNELS)
set(PHI_DEPS ${PHI_DEPS} ${phi_kernels})
......
add_subdirectory(check)
add_subdirectory(store)
add_subdirectory(auto_parallel)
set(COMM_CONTEXT_MANAGER_DEPS tcp_store)
......
proto_library(auto_parallel_proto SRCS auto_parallel.proto)
cc_library(
device_mesh
SRCS device_mesh.cc
DEPS auto_parallel_proto phi_enforce)
cc_library(
process_mesh
SRCS process_mesh.cc
DEPS auto_parallel_proto phi_enforce)
cc_library(
dist_attr
SRCS dist_attr.cc
DEPS process_mesh auto_parallel_proto proto_desc phi_enforce)
cc_library(
dist_mapper
SRCS dist_mapper.cc
DEPS device_mesh auto_parallel_proto phi_enforce)
cc_library(auto_parallel DEPS device_mesh process_mesh dist_attr dist_mapper)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless optional by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
......@@ -14,7 +14,7 @@ limitations under the License. */
syntax = "proto2";
package paddle.distributed.auto_parallel;
package phi.distributed.auto_parallel;
// ProcessMesh is used to organize processes and like n-dimension array.
message ProcessMeshProto {
......
......@@ -15,10 +15,10 @@ limitations under the License. */
#include <algorithm>
#include <iterator>
#include "paddle/fluid/distributed/auto_parallel/device_mesh.h"
#include "paddle/fluid/distributed/auto_parallel/utils.h"
#include "paddle/phi/core/distributed/auto_parallel/device_mesh.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
namespace paddle {
namespace phi {
namespace distributed {
namespace auto_parallel {
......@@ -169,7 +169,7 @@ void Machine::add_device(const Device &device) {
} else {
PADDLE_ENFORCE_EQ(device.machine_id(),
id(),
platform::errors::InvalidArgument(
errors::InvalidArgument(
"The machine id [%d] of the device should be equal "
"to this machine id [%d].",
device.machine_id(),
......@@ -181,7 +181,7 @@ void Machine::add_device(const Device &device) {
void Machine::add_link(const Link &link) {
PADDLE_ENFORCE_EQ(contains(link.source_id()),
true,
platform::errors::InvalidArgument(
errors::InvalidArgument(
"The source device id of the added link [%s] "
"cannot be found in the device_ids. Please add the "
"source device before adding this link",
......@@ -217,31 +217,31 @@ DeviceMesh::DeviceMesh(const std::string &name,
shape_ = shape;
int64_t size = this->size();
PADDLE_ENFORCE_EQ(size,
device_ids.size(),
platform::errors::InvalidArgument(
"The size %d of this device mesh must be "
"equal to the size %d of its device ids.",
size,
device_ids.size()));
PADDLE_ENFORCE_EQ(
size,
device_ids.size(),
errors::InvalidArgument("The size %d of this device mesh must be "
"equal to the size %d of its device ids.",
size,
device_ids.size()));
PADDLE_ENFORCE_EQ(
has_duplicates(device_ids),
false,
platform::errors::InvalidArgument("The device ids [%s] must be unique.",
str_join(device_ids)));
errors::InvalidArgument("The device ids [%s] must be unique.",
str_join(device_ids)));
device_ids_ = device_ids;
PADDLE_ENFORCE_EQ(
shape_.size(),
dim_names.size(),
platform::errors::InvalidArgument(
errors::InvalidArgument(
"The size %d of mesh shape must be equal to the size %d "
"of the dimension names.",
shape_.size(),
dim_names.size()));
PADDLE_ENFORCE_EQ(has_duplicates(dim_names),
false,
platform::errors::InvalidArgument(
errors::InvalidArgument(
"The names [%s] of each dimension must be unique.",
str_join(dim_names)));
dim_names_ = dim_names;
......@@ -268,7 +268,7 @@ void DeviceMesh::add_device(const Device &device) {
PADDLE_ENFORCE_EQ(
contains(device.global_id()),
true,
platform::errors::InvalidArgument(
errors::InvalidArgument(
"The added device id [%s] cannot be found in the device_ids.",
std::to_string(device.global_id())));
// Operator [] will create a new object if it cannot find one.
......@@ -282,15 +282,15 @@ void DeviceMesh::add_link(const Link &link) {
PADDLE_ENFORCE_EQ(
contains(link.source_id()),
true,
platform::errors::InvalidArgument("The source id of the added link [%s] "
"cannot be found in the device_ids.",
std::to_string(link.source_id())));
errors::InvalidArgument("The source id of the added link [%s] "
"cannot be found in the device_ids.",
std::to_string(link.source_id())));
PADDLE_ENFORCE_EQ(
contains(link.target_id()),
true,
platform::errors::InvalidArgument("The source id of the added link [%s] "
"cannot be found in the device_ids.",
std::to_string(link.target_id())));
errors::InvalidArgument("The source id of the added link [%s] "
"cannot be found in the device_ids.",
std::to_string(link.target_id())));
// Operator [] will create a new object if it cannot find one.
// So we add the default constructor for Device and Machine
// to make sure the new object can be created.
......@@ -395,4 +395,4 @@ bool operator==(const DeviceMesh &lhs, const DeviceMesh &rhs) {
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
} // namespace phi
......@@ -23,11 +23,11 @@ limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/distributed/auto_parallel/auto_parallel.pb.h"
#include "paddle/fluid/distributed/auto_parallel/utils.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/core/distributed/auto_parallel/auto_parallel.pb.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
#include "paddle/phi/core/enforce.h"
namespace paddle {
namespace phi {
namespace distributed {
namespace auto_parallel {
struct DeviceCapability {
......@@ -259,7 +259,7 @@ class DeviceMesh {
return shape_[i];
}
}
PADDLE_THROW(platform::errors::InvalidArgument(
PADDLE_THROW(errors::InvalidArgument(
"Cannot find the dimension of %s in this device mesh.", dim_name));
}
......@@ -298,4 +298,4 @@ inline bool operator!=(const DeviceMesh& lhs, const DeviceMesh& rhs) {
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include <algorithm>
#include <iostream>
#include <iterator>
#include "glog/logging.h"
namespace phi {
namespace distributed {
namespace auto_parallel {
std::vector<std::string> TensorDistAttr::fields_{
"process_mesh", "dims_mapping", "batch_dim", "dynamic_dims"};
TensorDistAttr::TensorDistAttr(const std::vector<int64_t>& tensor_shape) {
set_default_dims_mapping(tensor_shape);
set_default_dynamic_dims(tensor_shape);
}
TensorDistAttr::TensorDistAttr(const TensorDistAttr& dist_attr) {
copy_from(dist_attr);
}
TensorDistAttr& TensorDistAttr::operator=(const TensorDistAttr& dist_attr) {
if (this == &dist_attr) return *this;
TensorDistAttr tmp(dist_attr);
std::swap(this->process_mesh_, tmp.process_mesh_);
std::swap(this->dims_mapping_, tmp.dims_mapping_);
std::swap(this->batch_dim_, tmp.batch_dim_);
std::swap(this->dynamic_dims_, tmp.dynamic_dims_);
std::swap(this->annotated_, tmp.annotated_);
return *this;
}
void TensorDistAttr::copy_from(const TensorDistAttr& dist_attr) {
set_process_mesh(dist_attr.process_mesh());
set_dims_mapping(dist_attr.dims_mapping());
set_batch_dim(dist_attr.batch_dim());
set_dynamic_dims(dist_attr.dynamic_dims());
set_annotated(dist_attr.annotated());
}
void TensorDistAttr::set_process_mesh(const ProcessMesh& process_mesh) {
process_mesh_ = process_mesh;
}
void TensorDistAttr::set_dims_mapping(
const std::vector<int64_t>& dims_mapping) {
dims_mapping_ = dims_mapping;
}
void TensorDistAttr::set_batch_dim(int64_t batch_dim) {
batch_dim_ = batch_dim;
}
void TensorDistAttr::set_dynamic_dims(const std::vector<bool>& dynamic_dims) {
dynamic_dims_ = dynamic_dims;
}
void TensorDistAttr::set_annotated(
const std::map<std::string, bool>& annotated) {
annotated_ = annotated;
}
void TensorDistAttr::set_default_dims_mapping(
const std::vector<int64_t>& tensor_shape) {
if (tensor_shape.size() != 0) {
dims_mapping_ = std::vector<int64_t>(tensor_shape.size(), -1);
}
}
void TensorDistAttr::set_default_dynamic_dims(
const std::vector<int64_t>& tensor_shape) {
if (tensor_shape.size() != 0) {
dynamic_dims_ = std::vector<bool>(tensor_shape.size(), false);
}
}
void TensorDistAttr::mark_annotated(const std::string& name) {
auto result = std::find(std::begin(fields_), std::end(fields_), name);
if (result != std::end(fields_)) {
annotated_[name] = true;
}
}
bool TensorDistAttr::verify_process_mesh(
const ProcessMesh& process_mesh) const {
VLOG(4) << "[TensorDistAttr verify_process_mesh] "
<< process_mesh.to_string();
if (!process_mesh_.empty()) {
for (int64_t dim_mapping : dims_mapping_) {
if (dim_mapping >= process_mesh_.ndim()) {
return false;
}
}
}
return true;
}
bool TensorDistAttr::verify_dims_mapping(
const std::vector<int64_t>& dims_mapping,
const std::vector<int64_t>& tensor_shape) const {
VLOG(4) << "[TensorDistAttr verify_dims_mapping] " << str_join(dims_mapping);
if (dims_mapping.size() != tensor_shape.size()) {
return false;
}
std::unordered_map<int64_t, int64_t> map;
if (!process_mesh_.empty()) {
for (int64_t i : dims_mapping) {
if (i < -1 || i >= process_mesh_.ndim()) {
return false;
}
++map[i];
if (i != -1 && map[i] > 1) {
return false;
}
}
} else {
for (int64_t i : dims_mapping) {
++map[i];
if (i != -1 && map[i] > 1) {
return false;
}
}
}
return true;
}
bool TensorDistAttr::verify_batch_dim(
int64_t dim, const std::vector<int64_t>& tensor_shape) const {
VLOG(4) << "[TensorDistAttr verify_batch_dim] " << dim;
int64_t ndim = tensor_shape.size();
if (ndim > 0) {
if (dim < 0) {
dim = dim + ndim;
}
if (dim < 0 || dim >= ndim) {
return false;
}
}
return true;
}
bool TensorDistAttr::verify_dynamic_dims(
const std::vector<bool>& dynamic_dims,
const std::vector<int64_t>& tensor_shape) const {
VLOG(4) << "[TensorDistAttr verify_dynamic_dims] " << str_join(dynamic_dims);
if (dynamic_dims.size() > 0 && dynamic_dims.size() != tensor_shape.size()) {
return false;
}
return true;
}
bool TensorDistAttr::verify_annotated(
const std::map<std::string, bool>& annotated) const {
VLOG(4) << "[TensorDistAttr verify_annotated] " << str_join(annotated);
for (const auto& item : annotated) {
auto result = std::find(std::begin(fields_), std::end(fields_), item.first);
if (result == std::end(fields_)) {
return false;
}
}
return true;
}
bool TensorDistAttr::verify(const std::vector<int64_t>& tensor_shape) const {
if (!verify_process_mesh(process_mesh_)) {
return false;
}
if (!verify_dims_mapping(dims_mapping_, tensor_shape)) {
return false;
}
if (!verify_batch_dim(batch_dim_, tensor_shape)) {
return false;
}
if (!verify_dynamic_dims(dynamic_dims_, tensor_shape)) {
return false;
}
if (!verify_annotated(annotated_)) {
return false;
}
return true;
}
std::string TensorDistAttr::to_string() const {
std::string dist_str;
dist_str += "{process_mesh: " + process_mesh_.to_string() + ", ";
dist_str += "dims_mappings: [" + str_join(dims_mapping_) + "], ";
dist_str += "batch_dim: " + std::to_string(batch_dim_) + ", ";
dist_str += "dynamic_dims: [" + str_join(dynamic_dims_) + "], ";
dist_str += "annotated: [" + str_join(annotated_) + "]}";
return dist_str;
}
void TensorDistAttr::from_proto(const TensorDistAttrProto& proto) {
process_mesh_ = ProcessMesh::from_proto(proto.process_mesh());
dims_mapping_.resize(proto.dims_mapping_size());
for (int64_t i = 0; i < proto.dims_mapping_size(); ++i) {
dims_mapping_[i] = proto.dims_mapping(i);
}
batch_dim_ = proto.batch_dim();
dynamic_dims_.resize(proto.dynamic_dims_size());
for (int64_t i = 0; i < proto.dynamic_dims_size(); ++i) {
dynamic_dims_[i] = proto.dynamic_dims(i);
}
}
TensorDistAttrProto TensorDistAttr::to_proto() const {
TensorDistAttrProto proto;
proto.mutable_process_mesh()->CopyFrom(process_mesh_.to_proto());
for (const auto& i : dims_mapping_) {
proto.add_dims_mapping(i);
}
proto.set_batch_dim(batch_dim_);
for (const auto& i : dynamic_dims_) {
proto.add_dynamic_dims(i);
}
return proto;
}
std::string TensorDistAttr::serialize_to_string() {
std::string data;
auto proto = to_proto();
proto.SerializeToString(&data);
PADDLE_ENFORCE_EQ(to_proto().SerializeToString(&data),
true,
errors::InvalidArgument(
"Failed to serialize tensor dist attr to string."));
return data;
}
void TensorDistAttr::parse_from_string(const std::string& data) {
TensorDistAttrProto proto;
PADDLE_ENFORCE_EQ(
proto.ParseFromString(data),
true,
errors::InvalidArgument(
"Failed to parse tensor dist attr from string: %s.", data));
from_proto(proto);
}
bool operator==(const TensorDistAttr& lhs, const TensorDistAttr& rhs) {
if (lhs.process_mesh() != rhs.process_mesh()) {
return false;
}
if (lhs.dims_mapping() != rhs.dims_mapping()) {
return false;
}
if (lhs.batch_dim() != rhs.batch_dim()) {
return false;
}
if (lhs.dynamic_dims() != rhs.dynamic_dims()) {
return false;
}
return true;
}
} // namespace auto_parallel
} // namespace distributed
} // namespace phi
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>
#include "paddle/phi/core/distributed/auto_parallel/auto_parallel.pb.h"
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
#include "paddle/phi/core/enforce.h"
namespace phi {
namespace distributed {
namespace auto_parallel {
constexpr const char* kDefault = "default";
class TensorDistAttr {
public:
TensorDistAttr() = default;
explicit TensorDistAttr(const std::vector<int64_t>& tensor_shape);
TensorDistAttr(const TensorDistAttr& tensor);
TensorDistAttr& operator=(const TensorDistAttr& dist_attr);
void copy_from(const TensorDistAttr& dist_attr);
const ProcessMesh& process_mesh() const { return process_mesh_; }
void set_process_mesh(const ProcessMesh& process_mesh);
const std::vector<int64_t>& dims_mapping() const { return dims_mapping_; }
void set_dims_mapping(const std::vector<int64_t>& dims_mapping);
void set_default_dims_mapping(const std::vector<int64_t>& tensor_shape);
int64_t batch_dim() const { return batch_dim_; }
void set_batch_dim(int64_t batch_dim);
const std::vector<bool>& dynamic_dims() const { return dynamic_dims_; }
void set_dynamic_dims(const std::vector<bool>& dynamic_dims);
void set_default_dynamic_dims(const std::vector<int64_t>& tensor_shape);
const std::map<std::string, bool>& annotated() const { return annotated_; }
void set_annotated(const std::map<std::string, bool>& annotated);
bool is_annotated(const std::string& name) const {
return annotated_.count(name) == 1 && annotated_.at(name) == true;
}
void mark_annotated(const std::string& name);
void clear_annotated() { annotated_.clear(); }
bool verify_process_mesh(const ProcessMesh& process_mesh) const;
bool verify_dims_mapping(const std::vector<int64_t>& dims_mapping,
const std::vector<int64_t>& tensor_shape) const;
bool verify_batch_dim(int64_t dim,
const std::vector<int64_t>& tensor_shape) const;
bool verify_dynamic_dims(const std::vector<bool>& dynamic_dims,
const std::vector<int64_t>& tensor_shape) const;
bool verify_annotated(const std::map<std::string, bool>& annotated) const;
bool verify(const std::vector<int64_t>& tensor_shape) const;
// TensorDistAttr from_string(const std::string& dist_str);
std::string to_string() const;
void from_proto(const TensorDistAttrProto& proto);
TensorDistAttrProto to_proto() const;
std::string serialize_to_string();
void parse_from_string(const std::string& data);
private:
static std::vector<std::string> fields_;
ProcessMesh process_mesh_;
std::vector<int64_t> dims_mapping_;
int64_t batch_dim_{0};
std::vector<bool> dynamic_dims_;
std::map<std::string, bool> annotated_;
};
inline std::ostream& operator<<(std::ostream& os, const TensorDistAttr& obj) {
os << obj.to_string();
return os;
}
bool operator==(const TensorDistAttr& lhs, const TensorDistAttr& rhs);
inline bool operator!=(const TensorDistAttr& lhs, const TensorDistAttr& rhs) {
return !operator==(lhs, rhs);
}
} // namespace auto_parallel
} // namespace distributed
} // namespace phi
......@@ -14,10 +14,10 @@ limitations under the License. */
#include <algorithm>
#include "paddle/fluid/distributed/auto_parallel/dist_mapper.h"
#include "paddle/fluid/distributed/auto_parallel/utils.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_mapper.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
namespace paddle {
namespace phi {
namespace distributed {
namespace auto_parallel {
......@@ -32,7 +32,7 @@ void DistributedMapper::set_process_id_to_device_ids(
PADDLE_ENFORCE_GE(
item.first,
0,
platform::errors::InvalidArgument(
errors::InvalidArgument(
"The process id %d must be greater than or equal to 0.",
item.first));
std::string device_mesh_name = item.second.first;
......@@ -40,14 +40,14 @@ void DistributedMapper::set_process_id_to_device_ids(
PADDLE_ENFORCE_EQ(
device_meshes_.count(device_mesh_name),
1,
platform::errors::InvalidArgument(
errors::InvalidArgument(
"Cannot find the device mesh %d in device_mesh ids [%s].",
device_mesh_name,
str_join(device_mesh_names)));
PADDLE_ENFORCE_EQ(
has_duplicates(device_ids),
false,
platform::errors::InvalidArgument(
errors::InvalidArgument(
"The mapped device ids [%s] of process_mesh %d must be unique.",
str_join(device_ids),
item.first));
......@@ -60,7 +60,7 @@ void DistributedMapper::set_process_id_to_device_ids(
PADDLE_ENFORCE_EQ(
found,
true,
platform::errors::InvalidArgument(
errors::InvalidArgument(
"The device id %d cannot be find in the device mesh [%s].",
device_id,
str_join(cur_device_ids)));
......@@ -143,4 +143,4 @@ bool operator==(const DistributedMapper& lhs, const DistributedMapper& rhs) {
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
} // namespace phi
......@@ -15,11 +15,11 @@ limitations under the License. */
#include <utility>
#include "paddle/fluid/distributed/auto_parallel/auto_parallel.pb.h"
#include "paddle/fluid/distributed/auto_parallel/device_mesh.h"
#include "paddle/fluid/distributed/auto_parallel/process_mesh.h"
#include "paddle/phi/core/distributed/auto_parallel/auto_parallel.pb.h"
#include "paddle/phi/core/distributed/auto_parallel/device_mesh.h"
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
namespace paddle {
namespace phi {
namespace distributed {
namespace auto_parallel {
......@@ -70,4 +70,4 @@ inline std::ostream& operator<<(std::ostream& os,
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
} // namespace phi
......@@ -15,10 +15,10 @@ limitations under the License. */
#include <algorithm>
#include <iterator>
#include "paddle/fluid/distributed/auto_parallel/process_mesh.h"
#include "paddle/fluid/distributed/auto_parallel/utils.h"
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
namespace paddle {
namespace phi {
namespace distributed {
namespace auto_parallel {
......@@ -30,27 +30,27 @@ ProcessMesh::ProcessMesh(const std::vector<int64_t> &shape,
PADDLE_ENFORCE_EQ(
size,
process_ids.size(),
platform::errors::InvalidArgument("The size of this process mesh must be "
"equal to the size of its process ids.",
size,
process_ids.size()));
errors::InvalidArgument("The size of this process mesh must be "
"equal to the size of its process ids.",
size,
process_ids.size()));
PADDLE_ENFORCE_EQ(
has_duplicates(process_ids),
false,
platform::errors::InvalidArgument("The process ids [%s] must be unique.",
str_join(process_ids_)));
errors::InvalidArgument("The process ids [%s] must be unique.",
str_join(process_ids_)));
process_ids_ = process_ids;
PADDLE_ENFORCE_EQ(shape_.size(),
dim_names.size(),
platform::errors::InvalidArgument(
errors::InvalidArgument(
"The size of mesh shape must be equal to the size "
"of the dimension names.",
shape_.size(),
dim_names_.size()));
PADDLE_ENFORCE_EQ(has_duplicates(dim_names),
false,
platform::errors::InvalidArgument(
errors::InvalidArgument(
"The names [%s] of each dimension must be unique.",
str_join(dim_names)));
dim_names_ = dim_names;
......@@ -131,4 +131,4 @@ bool operator==(const ProcessMesh &lhs, const ProcessMesh &rhs) {
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
} // namespace phi
......@@ -20,12 +20,12 @@ limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/distributed/auto_parallel/auto_parallel.pb.h"
#include "paddle/fluid/distributed/auto_parallel/device_mesh.h"
#include "paddle/fluid/distributed/auto_parallel/utils.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/core/distributed/auto_parallel/auto_parallel.pb.h"
#include "paddle/phi/core/distributed/auto_parallel/device_mesh.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
#include "paddle/phi/core/enforce.h"
namespace paddle {
namespace phi {
namespace distributed {
namespace auto_parallel {
......@@ -58,7 +58,7 @@ class ProcessMesh {
return shape_[i];
}
}
PADDLE_THROW(platform::errors::InvalidArgument(
PADDLE_THROW(errors::InvalidArgument(
"Cannot find the dimension of %s in this process mesh.", dim_name));
}
......@@ -90,4 +90,4 @@ inline bool operator!=(const ProcessMesh& lhs, const ProcessMesh& rhs) {
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
} // namespace phi
......@@ -19,9 +19,9 @@ limitations under the License. */
#include <unordered_map>
#include <vector>
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/core/enforce.h"
namespace paddle {
namespace phi {
namespace distributed {
namespace auto_parallel {
......@@ -50,7 +50,7 @@ inline int64_t canonical_dim(int dim, int ndim) {
PADDLE_ENFORCE_EQ(
dim >= -ndim && dim < ndim,
true,
platform::errors::InvalidArgument(
errors::InvalidArgument(
"Dimension %d is outside of [-%d, %d).", dim, ndim, ndim));
if (dim < 0) {
return dim + ndim;
......@@ -111,4 +111,4 @@ std::string to_string_with_precision(const T a_value, const int n = 2) {
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
} // namespace phi
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册