Unverified commit b03b4a3c, authored by Yulong Ao, committed by GitHub

[Auto Parallel] Improve the c++ dist attr (#47358)

* [Auto Parallel] Improve the c++ dist attr

* [Auto Parallel] Modify test_program.py

* [Auto Parallel] Add the missing import
Parent 3b219e5e
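
For orientation, a minimal Python sketch of the user-facing API after this change (not part of the diff; the program and variable names are illustrative, and the calls mirror those exercised in test_dist_attr_v2.py and test_serialization.py further below):

import paddle
import paddle.static as static
from paddle.distributed.auto_parallel.process_mesh_v2 import ProcessMesh

paddle.enable_static()

main_prog, start_prog = static.Program(), static.Program()
with static.program_guard(main_prog, start_prog):
    x = static.data(name="x", shape=[2, 3], dtype="float32")
    y = static.data(name="y", shape=[3, 4], dtype="float32")
    out = paddle.matmul(x, y)

op = main_prog.current_block().ops[0]
op_dist_attr = op.dist_attr  # the C++ OperatorDistAttr bound to this OpDesc
op_dist_attr.process_mesh = ProcessMesh([[0, 1, 2], [3, 4, 5]])

# New helpers added in this commit: read/write one input's dims_mapping directly.
op_dist_attr.set_input_dims_mapping(x.name, [0, -1])
assert op_dist_attr.get_input_dims_mapping(x.name) == [0, -1]

# input_dist_attr was renamed to get_input_dist_attr in the bindings; it still
# returns the underlying TensorDistAttr by reference.
x_dist_attr = op_dist_attr.get_input_dist_attr(x.name)
assert x_dist_attr.dims_mapping == [0, -1]

The same pattern applies to outputs through get_output_dist_attr and set_output_dims_mapping.
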
......@@ -29,36 +29,57 @@ namespace auto_parallel {
std::vector<std::string> TensorDistAttr::fields_{
"process_mesh", "dims_mapping", "batch_dim", "dynamic_dims"};
TensorDistAttr::TensorDistAttr(const VarDesc& tensor)
: tensor_(&tensor), batch_dim_(0) {
TensorDistAttr::TensorDistAttr(const VarDesc& tensor) : tensor_(&tensor) {
VLOG(4) << "[TensorDistAttr constructor] tensor name: " << tensor_->Name();
if (tensor_->GetType() == framework::proto::VarType::READER) return;
if (tensor_->GetType() == framework::proto::VarType::LOD_TENSOR_ARRAY) return;
if (tensor_->GetType() == framework::proto::VarType::STEP_SCOPES) return;
tensor_shape_ = tensor_->GetShape();
VLOG(4) << "[TensorDistAttr constructor] tensor shape: "
<< str_join(tensor_shape_);
set_default_dims_mapping();
std::vector<int64_t> tensor_shape = tensor_->GetShape();
for (std::size_t i = 0; i < tensor_shape.size(); ++i) {
for (std::size_t i = 0; i < tensor_shape_.size(); ++i) {
dynamic_dims_.push_back(false);
}
}
TensorDistAttr::TensorDistAttr(const TensorDistAttr& dist_attr) {
if (tensor_ == nullptr) {
tensor_ = dist_attr.tensor();
tensor_ = dist_attr.tensor_;
tensor_shape_ = dist_attr.tensor_shape_;
}
set_process_mesh(dist_attr.process_mesh());
set_dims_mapping(dist_attr.dims_mapping());
set_batch_dim(dist_attr.batch_dim());
set_dynamic_dims(dist_attr.dynamic_dims());
set_annotated(dist_attr.annotated());
if (tensor_ != nullptr) {
VLOG(4) << "[TensorDistAttr copy constructor] tensor name: "
<< tensor_->Name() << ", tensor shape: " << str_join(tensor_shape_);
} else {
VLOG(4) << "[TensorDistAttr copy constructor] tensor name: None"
<< ", tensro shape: " << str_join(tensor_shape_);
}
copy_from(dist_attr);
}
TensorDistAttr& TensorDistAttr::operator=(const TensorDistAttr& dist_attr) {
if (tensor_ == nullptr) {
tensor_ = dist_attr.tensor();
tensor_ = dist_attr.tensor_;
tensor_shape_ = dist_attr.tensor_shape_;
}
if (tensor_ != nullptr) {
VLOG(4) << "[TensorDistAttr assign constructor] tensor name: "
<< tensor_->Name() << ", tensor shape: " << str_join(tensor_shape_);
} else {
VLOG(4) << "[TensorDistAttr assign constructor] tensor name: None"
<< ", tensro shape: " << str_join(tensor_shape_);
}
copy_from(dist_attr);
return *this;
}
void TensorDistAttr::copy_from(const TensorDistAttr& dist_attr) {
set_process_mesh(dist_attr.process_mesh());
set_dims_mapping(dist_attr.dims_mapping());
set_batch_dim(dist_attr.batch_dim());
set_dynamic_dims(dist_attr.dynamic_dims());
set_annotated(dist_attr.annotated());
return *this;
}
void TensorDistAttr::set_process_mesh(const ProcessMesh& process_mesh) {
......@@ -84,9 +105,9 @@ void TensorDistAttr::set_batch_dim(int64_t batch_dim) {
true,
platform::errors::InvalidArgument(
"Wrong batch_dim %d in this distributed attribute.", batch_dim));
if (tensor_ != nullptr) {
std::vector<int64_t> tensor_shape = tensor_->GetShape();
int64_t canonical_batch_dim = canonical_dim(batch_dim, tensor_shape.size());
if (tensor_ != nullptr && tensor_shape_.size() > 0) {
int64_t canonical_batch_dim =
canonical_dim(batch_dim, tensor_shape_.size());
batch_dim_ = canonical_batch_dim;
} else {
batch_dim_ = batch_dim;
......@@ -113,8 +134,7 @@ void TensorDistAttr::set_annotated(
void TensorDistAttr::set_default_dims_mapping() {
if (tensor_ != nullptr) {
std::vector<int64_t> tensor_shape = tensor_->GetShape();
dims_mapping_ = std::vector<int64_t>(tensor_shape.size(), -1);
dims_mapping_ = std::vector<int64_t>(tensor_shape_.size(), -1);
}
}
......@@ -127,6 +147,8 @@ void TensorDistAttr::annotate(const std::string& name) {
bool TensorDistAttr::verify_process_mesh(
const ProcessMesh& process_mesh) const {
VLOG(4) << "[TensorDistAttr verify_process_mesh] "
<< process_mesh.to_string();
if (!process_mesh_.empty()) {
for (int64_t dim_mapping : dims_mapping_) {
if (dim_mapping < -1 || dim_mapping >= process_mesh_.ndim()) {
......@@ -139,11 +161,9 @@ bool TensorDistAttr::verify_process_mesh(
bool TensorDistAttr::verify_dims_mapping(
const std::vector<int64_t>& dims_mapping) const {
if (tensor_ != nullptr) {
std::vector<int64_t> tensor_shape = tensor_->GetShape();
if (dims_mapping.size() != tensor_shape.size()) {
return false;
}
VLOG(4) << "[TensorDistAttr verify_dims_mapping] " << str_join(dims_mapping);
if (dims_mapping.size() != tensor_shape_.size()) {
return false;
}
std::unordered_map<int64_t, int64_t> map;
if (!process_mesh_.empty()) {
......@@ -168,9 +188,9 @@ bool TensorDistAttr::verify_dims_mapping(
}
bool TensorDistAttr::verify_batch_dim(int64_t dim) const {
if (tensor_ != nullptr) {
std::vector<int64_t> tensor_shape = tensor_->GetShape();
int64_t ndim = tensor_shape.size();
VLOG(4) << "[TensorDistAttr verify_batch_dim] " << dim;
int64_t ndim = tensor_shape_.size();
if (tensor_ != nullptr && ndim > 0) {
if (dim < 0) {
dim = dim + ndim;
}
......@@ -183,17 +203,16 @@ bool TensorDistAttr::verify_batch_dim(int64_t dim) const {
bool TensorDistAttr::verify_dynamic_dims(
const std::vector<bool>& dynamic_dims) const {
if (tensor_ != nullptr) {
std::vector<int64_t> tensor_shape = tensor_->GetShape();
if (dynamic_dims.size() != tensor_shape.size()) {
return false;
}
VLOG(4) << "[TensorDistAttr verify_dynamic_dims] " << str_join(dynamic_dims);
if (dynamic_dims.size() != tensor_shape_.size()) {
return false;
}
return true;
}
bool TensorDistAttr::verify_annotated(
const std::map<std::string, bool>& annotated) const {
VLOG(4) << "[TensorDistAttr verify_annotated] " << str_join(annotated);
for (const auto& item : annotated) {
auto result = std::find(std::begin(fields_), std::end(fields_), item.first);
if (result == std::end(fields_)) {
......@@ -204,9 +223,6 @@ bool TensorDistAttr::verify_annotated(
}
bool TensorDistAttr::verify() const {
if (tensor_ == nullptr) {
return false;
}
if (!verify_process_mesh(process_mesh_)) {
return false;
}
......@@ -240,19 +256,17 @@ std::string TensorDistAttr::to_string() const {
return dist_str;
}
TensorDistAttr TensorDistAttr::from_proto(const TensorDistAttrProto& proto) {
TensorDistAttr dist_attr;
dist_attr.process_mesh_ = ProcessMesh::from_proto(proto.process_mesh());
dist_attr.dims_mapping_.resize(proto.dims_mapping_size());
void TensorDistAttr::from_proto(const TensorDistAttrProto& proto) {
process_mesh_ = ProcessMesh::from_proto(proto.process_mesh());
dims_mapping_.resize(proto.dims_mapping_size());
for (int64_t i = 0; i < proto.dims_mapping_size(); ++i) {
dist_attr.dims_mapping_[i] = proto.dims_mapping(i);
dims_mapping_[i] = proto.dims_mapping(i);
}
dist_attr.batch_dim_ = proto.batch_dim();
dist_attr.dynamic_dims_.resize(proto.dynamic_dims_size());
batch_dim_ = proto.batch_dim();
dynamic_dims_.resize(proto.dynamic_dims_size());
for (int64_t i = 0; i < proto.dynamic_dims_size(); ++i) {
dist_attr.dynamic_dims_[i] = proto.dynamic_dims(i);
dynamic_dims_[i] = proto.dynamic_dims(i);
}
return dist_attr;
}
TensorDistAttrProto TensorDistAttr::to_proto() const {
......@@ -268,6 +282,26 @@ TensorDistAttrProto TensorDistAttr::to_proto() const {
return proto;
}
std::string TensorDistAttr::serialize_to_string() {
std::string data;
auto proto = to_proto();
proto.SerializeToString(&data);
PADDLE_ENFORCE_EQ(to_proto().SerializeToString(&data),
true,
platform::errors::InvalidArgument(
"Failed to serialize tensor dist attr to string."));
return data;
}
void TensorDistAttr::parse_from_string(const std::string& data) {
TensorDistAttrProto proto;
PADDLE_ENFORCE_EQ(proto.ParseFromString(data),
true,
platform::errors::InvalidArgument(
"Failed to parse tensor dist attr from string."));
from_proto(proto);
}
bool operator==(const TensorDistAttr& lhs, const TensorDistAttr& rhs) {
if (lhs.process_mesh() != rhs.process_mesh()) {
return false;
......@@ -288,52 +322,103 @@ std::vector<std::string> OperatorDistAttr::fields_{
"process_mesh", "impl_type", "impl_idx"};
OperatorDistAttr::OperatorDistAttr(const OpDesc& op) : op_(&op) {
VLOG(4) << "[OperatorDistAttr constructor] op type: " << op_->Type();
initialize();
}
OperatorDistAttr::OperatorDistAttr(const OperatorDistAttr& dist_attr) {
if (op_ == nullptr) {
op_ = dist_attr.op();
}
if (op_ != nullptr) {
VLOG(4) << "[OperatorDistAttr copy constructor] op type: " << op_->Type();
} else {
VLOG(4) << "[OperatorDistAttr copy constructor] op type: None";
}
initialize();
copy_from(dist_attr);
}
OperatorDistAttr& OperatorDistAttr::operator=(
const OperatorDistAttr& dist_attr) {
if (op_ == nullptr) {
op_ = dist_attr.op();
}
if (op_ != nullptr) {
VLOG(4) << "[OperatorDistAttr assign constructor] op type: " << op_->Type();
} else {
VLOG(4) << "[OperatorDistAttr assign constructor] op type: None";
}
initialize();
copy_from(dist_attr);
return *this;
}
void OperatorDistAttr::initialize() {
if (op_ == nullptr) return;
for (std::string name : op_->InputArgumentNames()) {
VarDesc* input = op_->Block()->FindVarRecursive(name);
VLOG(4) << "[OperatorDistAttr create input dist attr] " << name;
inputs_[name] = input;
input_dist_attrs_[name] = TensorDistAttr(*input);
if (input == nullptr || op_->Type() == "create_py_reader") {
input_dist_attrs_[name] = TensorDistAttr();
} else {
input_dist_attrs_[name] = TensorDistAttr(*input);
}
}
for (std::string name : op_->OutputArgumentNames()) {
VarDesc* output = op_->Block()->FindVarRecursive(name);
VLOG(4) << "[OperatorDistAttr create output dist attr] " << name;
outputs_[name] = output;
output_dist_attrs_[name] = TensorDistAttr(*output);
if (output == nullptr) {
output_dist_attrs_[name] = TensorDistAttr();
} else {
output_dist_attrs_[name] = TensorDistAttr(*output);
}
}
impl_type_ = "default";
impl_idx_ = 0;
}
OperatorDistAttr::OperatorDistAttr(const OperatorDistAttr& dist_attr) {
if (op_ == nullptr) {
op_ = dist_attr.op();
}
for (const auto& item : dist_attr.input_dist_attrs()) {
set_input_dist_attr(item.first, item.second);
}
for (const auto& item : dist_attr.output_dist_attrs()) {
set_output_dist_attr(item.first, item.second);
}
void OperatorDistAttr::copy_from(const OperatorDistAttr& dist_attr) {
set_input_dist_attrs(dist_attr.input_dist_attrs());
set_output_dist_attrs(dist_attr.output_dist_attrs());
set_process_mesh(dist_attr.process_mesh());
set_impl_type(dist_attr.impl_type());
set_impl_idx(dist_attr.impl_idx());
set_annotated(dist_attr.annotated());
impl_type_ = dist_attr.impl_type();
impl_idx_ = dist_attr.impl_idx();
}
OperatorDistAttr& OperatorDistAttr::operator=(
const OperatorDistAttr& dist_attr) {
void OperatorDistAttr::set_input_dist_attrs(
const std::map<std::string, TensorDistAttr>& dist_attrs) {
if (op_ == nullptr) {
op_ = dist_attr.op();
}
for (const auto& item : dist_attr.input_dist_attrs()) {
set_input_dist_attr(item.first, item.second);
for (const auto& item : dist_attrs) {
set_input_dist_attr(item.first, item.second);
}
} else {
for (const auto& item : input_dist_attrs_) {
if (dist_attrs.count(item.first) == 1) {
set_input_dist_attr(item.first, dist_attrs.at(item.first));
}
}
}
for (const auto& item : dist_attr.output_dist_attrs()) {
set_output_dist_attr(item.first, item.second);
}
void OperatorDistAttr::set_output_dist_attrs(
const std::map<std::string, TensorDistAttr>& dist_attrs) {
if (op_ == nullptr) {
for (const auto& item : dist_attrs) {
set_output_dist_attr(item.first, item.second);
}
} else {
for (const auto& item : output_dist_attrs_) {
if (dist_attrs.count(item.first) == 1) {
set_output_dist_attr(item.first, dist_attrs.at(item.first));
}
}
}
set_process_mesh(dist_attr.process_mesh());
set_impl_type(dist_attr.impl_type());
set_impl_idx(dist_attr.impl_idx());
set_annotated(dist_attr.annotated());
return *this;
}
void OperatorDistAttr::set_input_dist_attr(const std::string& name,
......@@ -341,8 +426,10 @@ void OperatorDistAttr::set_input_dist_attr(const std::string& name,
PADDLE_ENFORCE_EQ(
verify_input_dist_attr(name, dist_attr),
true,
platform::errors::InvalidArgument(
"Wrong dist_attr %s for %s.", dist_attr.to_string(), name));
platform::errors::InvalidArgument("Wrong dist_attr %s for %s. %s",
dist_attr.to_string(),
name,
to_string()));
input_dist_attrs_[name] = dist_attr;
// Make sure the process mesh of the input is the same as that of the op
input_dist_attrs_[name].set_process_mesh(process_mesh_);
......@@ -394,8 +481,30 @@ void OperatorDistAttr::set_annotated(
annotated_ = annotated;
}
const std::vector<int64_t>& OperatorDistAttr::input_dims_mapping(
const std::string& name) const {
return input_dist_attr(name).dims_mapping();
}
void OperatorDistAttr::set_input_dims_mapping(
const std::string& name, const std::vector<int64_t>& dims_mapping) {
input_dist_attr(name).set_dims_mapping(dims_mapping);
}
const std::vector<int64_t>& OperatorDistAttr::output_dims_mapping(
const std::string& name) {
return output_dist_attr(name).dims_mapping();
}
void OperatorDistAttr::set_output_dims_mapping(
const std::string& name, const std::vector<int64_t>& dims_mapping) {
output_dist_attr(name).set_dims_mapping(dims_mapping);
}
bool OperatorDistAttr::verify_input_dist_attr(
const std::string& name, const TensorDistAttr& dist_attr) const {
VLOG(4) << "[OperatorDistAttr verify_input_dist_attr] " << name << " "
<< dist_attr.to_string();
if (!dist_attr.verify()) {
return false;
}
......@@ -414,6 +523,8 @@ bool OperatorDistAttr::verify_input_dist_attr(
bool OperatorDistAttr::verify_output_dist_attr(
const std::string& name, const TensorDistAttr& dist_attr) const {
VLOG(4) << "[OperatorDistAttr verify_output_dist_attr] " << name << " "
<< dist_attr.to_string();
if (!dist_attr.verify()) {
return false;
}
......@@ -432,6 +543,8 @@ bool OperatorDistAttr::verify_output_dist_attr(
bool OperatorDistAttr::verify_process_mesh(
const ProcessMesh& process_mesh) const {
VLOG(4) << "[OperatorDistAttr verify_process_mesh] "
<< process_mesh.to_string();
if (process_mesh != process_mesh_) {
return false;
}
......@@ -450,6 +563,7 @@ bool OperatorDistAttr::verify_process_mesh(
bool OperatorDistAttr::verify_annotated(
const std::map<std::string, bool>& annotated) const {
VLOG(4) << "[OperatorDistAttr verify_annotated] " << str_join(annotated);
for (const auto& item : annotated) {
auto result = std::find(std::begin(fields_), std::end(fields_), item.first);
if (result == std::end(fields_)) {
......@@ -457,11 +571,15 @@ bool OperatorDistAttr::verify_annotated(
}
}
for (auto& item : input_dist_attrs_) {
VLOG(4) << "[OperatorDistAttr verify_annotated input] "
<< str_join(item.second.annotated());
if (!item.second.verify_annotated(item.second.annotated())) {
return false;
}
}
for (auto& item : output_dist_attrs_) {
VLOG(4) << "[OperatorDistAttr verify_annotated output] "
<< str_join(item.second.annotated());
if (!item.second.verify_annotated(item.second.annotated())) {
return false;
}
......@@ -501,6 +619,44 @@ bool OperatorDistAttr::verify() const {
return true;
}
void OperatorDistAttr::rename_input(const std::string& old_name,
const std::string& new_name) {
for (auto& item : input_dist_attrs_) {
if (item.first == old_name) {
VarDesc* new_input = op_->Block()->FindVarRecursive(new_name);
inputs_[new_name] = new_input;
if (new_input == nullptr) {
input_dist_attrs_[new_name] = TensorDistAttr();
} else {
input_dist_attrs_[new_name] = TensorDistAttr(*new_input);
input_dist_attrs_[new_name].copy_from(input_dist_attrs_[old_name]);
}
inputs_.erase(old_name);
input_dist_attrs_.erase(old_name);
break;
}
}
}
void OperatorDistAttr::rename_output(const std::string& old_name,
const std::string& new_name) {
for (auto& item : output_dist_attrs_) {
if (item.first == old_name) {
VarDesc* new_output = op_->Block()->FindVarRecursive(new_name);
outputs_[new_name] = new_output;
if (new_output == nullptr) {
output_dist_attrs_[new_name] = TensorDistAttr();
} else {
output_dist_attrs_[new_name] = TensorDistAttr(*new_output);
output_dist_attrs_[new_name].copy_from(output_dist_attrs_[old_name]);
}
outputs_.erase(old_name);
output_dist_attrs_.erase(old_name);
break;
}
}
}
std::string OperatorDistAttr::to_string() const {
std::string str;
if (op_ != nullptr) {
......@@ -525,23 +681,22 @@ std::string OperatorDistAttr::to_string() const {
return str;
}
OperatorDistAttr OperatorDistAttr::from_proto(
const OperatorDistAttrProto& proto) {
OperatorDistAttr dist_attr;
void OperatorDistAttr::from_proto(const OperatorDistAttrProto& proto) {
for (int64_t i = 0; i < proto.input_dist_attrs_size(); ++i) {
dist_attr.input_dist_attrs_[proto.input_dist_attrs(i).name()] =
TensorDistAttr::from_proto(
proto.input_dist_attrs(i).tensor_dist_attr());
TensorDistAttr dist_attr;
std::string name = proto.input_dist_attrs(i).name();
dist_attr.from_proto(proto.input_dist_attrs(i).tensor_dist_attr());
input_dist_attrs_[name] = dist_attr;
}
for (int64_t i = 0; i < proto.output_dist_attrs_size(); ++i) {
dist_attr.output_dist_attrs_[proto.output_dist_attrs(i).name()] =
TensorDistAttr::from_proto(
proto.output_dist_attrs(i).tensor_dist_attr());
TensorDistAttr dist_attr;
std::string name = proto.output_dist_attrs(i).name();
dist_attr.from_proto(proto.output_dist_attrs(i).tensor_dist_attr());
output_dist_attrs_[name] = dist_attr;
}
dist_attr.process_mesh_ = ProcessMesh::from_proto(proto.process_mesh());
dist_attr.impl_type_ = proto.impl_type();
dist_attr.impl_idx_ = proto.impl_idx();
return dist_attr;
process_mesh_ = ProcessMesh::from_proto(proto.process_mesh());
impl_type_ = proto.impl_type();
impl_idx_ = proto.impl_idx();
}
OperatorDistAttrProto OperatorDistAttr::to_proto() const {
......@@ -562,6 +717,26 @@ OperatorDistAttrProto OperatorDistAttr::to_proto() const {
return proto;
}
std::string OperatorDistAttr::serialize_to_string() {
std::string data;
auto proto = to_proto();
proto.SerializeToString(&data);
PADDLE_ENFORCE_EQ(to_proto().SerializeToString(&data),
true,
platform::errors::InvalidArgument(
"Failed to serialize op dist attr to string."));
return data;
}
void OperatorDistAttr::parse_from_string(const std::string& data) {
OperatorDistAttrProto proto;
PADDLE_ENFORCE_EQ(proto.ParseFromString(data),
true,
platform::errors::InvalidArgument(
"Failed to parse op dist attr from string."));
from_proto(proto);
}
bool operator==(const OperatorDistAttr& lhs, const OperatorDistAttr& rhs) {
if (lhs.process_mesh() != rhs.process_mesh()) {
return false;
......
......@@ -56,6 +56,8 @@ class TensorDistAttr {
TensorDistAttr& operator=(const TensorDistAttr& dist_attr);
void copy_from(const TensorDistAttr& dist_attr);
const VarDesc* tensor() const { return tensor_; }
const ProcessMesh& process_mesh() const { return process_mesh_; }
......@@ -101,16 +103,21 @@ class TensorDistAttr {
// TensorDistAttr from_string(const std::string& dist_str);
std::string to_string() const;
static TensorDistAttr from_proto(const TensorDistAttrProto& proto);
void from_proto(const TensorDistAttrProto& proto);
TensorDistAttrProto to_proto() const;
std::string serialize_to_string();
void parse_from_string(const std::string& data);
private:
static std::vector<std::string> fields_;
const VarDesc* tensor_{nullptr};
std::vector<int64_t> tensor_shape_;
ProcessMesh process_mesh_;
std::vector<int64_t> dims_mapping_;
int64_t batch_dim_;
int64_t batch_dim_{0};
std::vector<bool> dynamic_dims_;
std::map<std::string, bool> annotated_;
};
......@@ -136,6 +143,10 @@ class OperatorDistAttr {
OperatorDistAttr& operator=(const OperatorDistAttr& dist_attr);
void initialize();
void copy_from(const OperatorDistAttr& dist_attr);
const OpDesc* op() const { return op_; }
const VarDesc& input(const std::string& name) const {
......@@ -150,10 +161,16 @@ class OperatorDistAttr {
return input_dist_attrs_;
}
void set_input_dist_attrs(
const std::map<std::string, TensorDistAttr>& dist_attrs);
const std::map<std::string, TensorDistAttr>& output_dist_attrs() const {
return output_dist_attrs_;
}
void set_output_dist_attrs(
const std::map<std::string, TensorDistAttr>& dist_attrs);
const TensorDistAttr& input_dist_attr(const std::string& name) const {
return input_dist_attrs_.at(name);
}
......@@ -198,6 +215,16 @@ class OperatorDistAttr {
void annotate(const std::string& name);
const std::vector<int64_t>& input_dims_mapping(const std::string& name) const;
void set_input_dims_mapping(const std::string& name,
const std::vector<int64_t>& dims_mapping);
const std::vector<int64_t>& output_dims_mapping(const std::string& name);
void set_output_dims_mapping(const std::string& name,
const std::vector<int64_t>& dims_mapping);
bool verify_input_dist_attr(const std::string& name,
const TensorDistAttr& dist_attr) const;
......@@ -210,13 +237,21 @@ class OperatorDistAttr {
bool verify() const;
void rename_input(const std::string& old_name, const std::string& new_name);
void rename_output(const std::string& old_name, const std::string& new_name);
// OperatorDistAttr from_string(const std::string& dist_str);
std::string to_string() const;
static OperatorDistAttr from_proto(const OperatorDistAttrProto& proto);
void from_proto(const OperatorDistAttrProto& proto);
OperatorDistAttrProto to_proto() const;
std::string serialize_to_string();
void parse_from_string(const std::string& data);
private:
static std::vector<std::string> fields_;
const OpDesc* op_{nullptr};
......
......@@ -81,10 +81,9 @@ TEST(DistAttr, ctor) {
x_sstream << x_dist_attr;
EXPECT_EQ(x_sstream.str(), x_dist_attr.to_string());
auto x_proto = x_dist_attr.to_proto();
TensorDistAttr new_x_dist_attr = TensorDistAttr::from_proto(x_proto);
TensorDistAttr new_x_dist_attr(*x);
new_x_dist_attr.from_proto(x_proto);
EXPECT_EQ(x_dist_attr, new_x_dist_attr);
// new_x_dist_attr is not valid since it does not bind to a var_desc
EXPECT_EQ(new_x_dist_attr.verify(), false);
y_dist_attr.set_process_mesh(process_mesh);
y_dist_attr.set_dims_mapping(std::vector<int64_t>({-1, 0}));
......@@ -139,10 +138,9 @@ TEST(DistAttr, ctor) {
mul_sstream << mul_dist_attr;
EXPECT_EQ(mul_sstream.str(), mul_dist_attr.to_string());
auto mul_proto = mul_dist_attr.to_proto();
OperatorDistAttr new_mul_dist_attr = OperatorDistAttr::from_proto(mul_proto);
OperatorDistAttr new_mul_dist_attr(*op);
new_mul_dist_attr.from_proto(mul_proto);
EXPECT_EQ(mul_dist_attr, new_mul_dist_attr);
// new_mul_dist_attr is not valid since it does not bind to an op_desc
EXPECT_EQ(new_mul_dist_attr.verify(), false);
}
} // namespace auto_parallel
......
......@@ -82,7 +82,7 @@ inline std::string str_join(std::map<std::string, bool> const& elements,
for (const auto& item : elements) {
str += item.first + ": " + std::to_string(item.second) + ",";
}
return str.substr(0, str.size() - 2);
return str.substr(0, str.size() - 1);
}
// Refer to https://stackoverflow.com/a/46931770
......
......@@ -119,8 +119,30 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc) {
}
default:
PADDLE_THROW(platform::errors::Unavailable("Unsupport attribute type %d.",
attr_desc.type()));
PADDLE_THROW(platform::errors::Unavailable(
"Unsupported attribute type %d.", attr_desc.type()));
}
return paddle::blank();
}
Attribute GetAttrValue(const proto::VarDesc::Attr& attr_desc) {
switch (attr_desc.type()) {
case proto::AttrType::INT: {
return attr_desc.i();
}
case proto::AttrType::STRING: {
return attr_desc.s();
}
case proto::AttrType::INTS: {
std::vector<int> val(attr_desc.ints_size());
for (int i = 0; i < attr_desc.ints_size(); ++i) {
val[i] = attr_desc.ints(i);
}
return val;
}
default:
PADDLE_THROW(platform::errors::Unavailable(
"Unsupported attribute type %d.", attr_desc.type()));
}
return paddle::blank();
}
......
......@@ -37,6 +37,8 @@ paddle::any GetAttrValue(const Attribute& attr);
Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc);
Attribute GetAttrValue(const proto::VarDesc::Attr& attr_desc);
template <typename T>
struct ExtractAttribute {
explicit ExtractAttribute(const std::string& attr_name)
......
......@@ -885,6 +885,10 @@ void OpDesc::RenameOutput(const std::string &old_name,
std::replace(op_vars.begin(), op_vars.end(), old_name, new_name);
}
if (dist_attr_) {
dist_attr_->rename_output(old_name, new_name);
}
need_update_ = true;
}
......@@ -900,6 +904,10 @@ void OpDesc::RenameInput(const std::string &old_name,
std::replace(op_vars.begin(), op_vars.end(), old_name, new_name);
}
if (dist_attr_) {
dist_attr_->rename_input(old_name, new_name);
}
need_update_ = true;
}
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/var_desc.h"
#include "glog/logging.h"
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -28,6 +29,16 @@ VarDesc::VarDesc(const VarDesc &other)
if (other.dist_attr_) {
dist_attr_.reset(new TensorDistAttr(*other.dist_attr_));
}
need_updated_ = true;
}
VarDesc::VarDesc(const proto::VarDesc &desc) : desc_(desc) {
// Restore attrs_ for auto parallel
for (const proto::VarDesc::Attr &attr : desc_.attrs()) {
std::string attr_name = attr.name();
attrs_[attr_name] = GetAttrValue(attr);
}
need_updated_ = true;
}
proto::VarType::Type VarDesc::GetType() const { return desc_.type().type(); }
......@@ -348,14 +359,15 @@ void VarDesc::SetAttr(const std::string &name, const Attribute &v) {
bool valid = attr_type == proto::AttrType::INT ||
attr_type == proto::AttrType::STRING ||
attr_type == proto::AttrType::INTS;
PADDLE_ENFORCE_EQ(
valid,
true,
platform::errors::InvalidArgument("The value for attr (%s) must be "
"one of list or int or string.",
name));
PADDLE_ENFORCE_EQ(valid,
true,
platform::errors::InvalidArgument(
"The value for attr (%s) must be "
"one of int, string, list of int for now.",
name));
this->attrs_[name] = v;
need_updated_ = true;
}
Attribute VarDesc::GetAttr(const std::string &name) const {
......@@ -367,6 +379,63 @@ Attribute VarDesc::GetAttr(const std::string &name) const {
return it->second;
}
struct SetVarAttrDescVisitor {
explicit SetVarAttrDescVisitor(proto::VarDesc::Attr *attr) : attr_(attr) {}
mutable proto::VarDesc::Attr *attr_;
template <typename T>
void operator()(T &&v) {
using U = std::decay_t<decltype(v)>;
if (std::is_same<U, int>::value) {
set_attr_value(v);
} else if (std::is_same<U, std::string>::value) {
set_attr_value(v);
} else if (std::is_same<U, std::vector<int>>::value) {
set_attr_value(v);
} else {
PADDLE_THROW(platform::errors::Unavailable(
"Unsupported calling method of SetAttrDescVisitor object."));
}
}
// This template is only here to make the compilation pass
template <typename U>
void set_attr_value(U v);
void set_attr_value(int v) { attr_->set_i(v); }
void set_attr_value(const std::string &v) { attr_->set_s(v); }
void set_attr_value(const std::vector<int> &v) {
VectorToRepeated(v, attr_->mutable_ints());
}
};
// Only need to flush the attrs for auto parallel for now
void VarDesc::Flush() {
VLOG(4) << "Flush "
<< " " << Name() << " " << need_updated_;
if (need_updated_) {
this->desc_.mutable_attrs()->Clear();
std::vector<std::pair<std::string, Attribute>> sorted_attrs{attrs_.begin(),
attrs_.end()};
std::sort(
sorted_attrs.begin(),
sorted_attrs.end(),
[](std::pair<std::string, Attribute> a,
std::pair<std::string, Attribute> b) { return a.first < b.first; });
for (auto &attr : sorted_attrs) {
auto *attr_desc = desc_.add_attrs();
attr_desc->set_name(attr.first);
attr_desc->set_type(
static_cast<proto::AttrType>(attr.second.index() - 1));
SetVarAttrDescVisitor visitor(attr_desc);
paddle::visit(visitor, attr.second);
}
need_updated_ = false;
}
}
TensorDistAttr *VarDesc::MutableDistAttr() {
// If dist_attr_ is nullptr, construct a new one and return.
if (dist_attr_) {
......@@ -375,12 +444,14 @@ TensorDistAttr *VarDesc::MutableDistAttr() {
dist_attr_.reset(new TensorDistAttr(*this));
return dist_attr_.get();
}
need_updated_ = true;
}
void VarDesc::SetDistAttr(const TensorDistAttr &dist_attr) {
// Make sure this dist attr be created
MutableDistAttr();
*dist_attr_ = dist_attr;
need_updated_ = true;
}
bool operator==(const VarDesc &left, const VarDesc &right) {
......
......@@ -71,9 +71,7 @@ class VarDesc {
need_updated_ = true;
}
explicit VarDesc(const proto::VarDesc &desc) : desc_(desc) {
// need_updated_ = true;
}
explicit VarDesc(const proto::VarDesc &desc);
// Explicitly implement the copy constructor for auto parallel
VarDesc(const VarDesc &other);
......@@ -90,7 +88,7 @@ class VarDesc {
}
proto::VarDesc *Proto() {
need_updated_ = true;
Flush(); // Only flush attrs for auto parallel
return &desc_;
}
......@@ -194,6 +192,8 @@ class VarDesc {
bool NeedUpdate() const { return need_updated_; }
void SetNeedUpdate(bool need) { need_updated_ = need; }
void Flush();
// The following methods are only used for auto parallel.
uint64_t Id() const { return id_; }
uint64_t OriginalId() const { return original_id_; }
......
......@@ -165,6 +165,12 @@ void BindAutoParallel(py::module *m) {
&DeviceMesh::dim_size))
.def(py::self == py::self)
.def(py::self != py::self)
.def(
"__deepcopy__",
[](const TensorDistAttr &self, py::dict) {
return TensorDistAttr(self);
},
py::arg("memo"))
.def("__str__", &DeviceMesh::to_string);
py::class_<TensorDistAttr>(*m, "TensorDistAttr")
......@@ -182,9 +188,17 @@ void BindAutoParallel(py::module *m) {
.def_property("dynamic_dims",
&TensorDistAttr::dynamic_dims,
&TensorDistAttr::set_dynamic_dims)
.def_property("annotated",
&TensorDistAttr::annotated,
&TensorDistAttr::set_annotated)
.def("is_annotated", &TensorDistAttr::is_annotated)
.def("annotate", &TensorDistAttr::annotate)
.def("verify", &TensorDistAttr::verify)
.def("serialize_to_string",
[](TensorDistAttr &self) {
return py::bytes(self.serialize_to_string());
})
.def("parse_from_string", &TensorDistAttr::parse_from_string)
.def(py::self == py::self)
.def(py::self != py::self)
.def("__str__", &TensorDistAttr::to_string);
......@@ -201,20 +215,23 @@ void BindAutoParallel(py::module *m) {
.def_property("impl_idx",
&OperatorDistAttr::impl_idx,
&OperatorDistAttr::set_impl_idx)
.def_property("annotated",
&OperatorDistAttr::annotated,
&OperatorDistAttr::set_annotated)
.def_property("inputs_dist_attrs",
&OperatorDistAttr::input_dist_attrs,
&OperatorDistAttr::set_input_dist_attrs)
.def_property("outputs_dist_attrs",
&OperatorDistAttr::output_dist_attrs,
&OperatorDistAttr::set_output_dist_attrs)
.def("input", &OperatorDistAttr::input)
.def("output", &OperatorDistAttr::output)
.def("input_dist_attrs",
&OperatorDistAttr::input_dist_attrs,
py::return_value_policy::reference)
.def("output_dist_attrs",
&OperatorDistAttr::output_dist_attrs,
py::return_value_policy::reference)
.def("input_dist_attr",
.def("get_input_dist_attr",
static_cast<TensorDistAttr &(
OperatorDistAttr::*)(const std::string &)>(
&OperatorDistAttr::input_dist_attr),
py::return_value_policy::reference)
.def("output_dist_attr",
.def("get_output_dist_attr",
static_cast<TensorDistAttr &(
OperatorDistAttr::*)(const std::string &)>(
&OperatorDistAttr::output_dist_attr),
......@@ -223,9 +240,25 @@ void BindAutoParallel(py::module *m) {
.def("set_output_dist_attr", &OperatorDistAttr::set_output_dist_attr)
.def("is_annotated", &OperatorDistAttr::is_annotated)
.def("annotate", &OperatorDistAttr::annotate)
.def("get_input_dims_mapping", &OperatorDistAttr::input_dims_mapping)
.def("set_input_dims_mapping", &OperatorDistAttr::set_input_dims_mapping)
.def("get_output_dims_mapping", &OperatorDistAttr::output_dims_mapping)
.def("set_output_dims_mapping",
&OperatorDistAttr::set_output_dims_mapping)
.def("verify", &OperatorDistAttr::verify)
.def("serialize_to_string",
[](OperatorDistAttr &self) {
return py::bytes(self.serialize_to_string());
})
.def("parse_from_string", &OperatorDistAttr::parse_from_string)
.def(py::self == py::self)
.def(py::self != py::self)
.def(
"__deepcopy__",
[](const OperatorDistAttr &self, py::dict) {
return OperatorDistAttr(self);
},
py::arg("memo"))
.def("__str__", &OperatorDistAttr::to_string);
}
......
......@@ -21,8 +21,10 @@ from paddle.distributed.passes import PassContext
from .dist_tensor import DistributedTensor
from .dist_op import DistributedOperator
from .process_mesh import ProcessMesh
from .utils import _copy_dist_attr_to_cpp
from .utils import is_loss_grad_op
# There always exists a default context for the user, and the user can set it to another one.
_g_default_distributed_context = None
......@@ -76,6 +78,7 @@ class DistributedContext:
self._serial_optimizer = None
self._serial_feed_vars = {}
self._serial_fetch_vars = {}
self._lr_optimizer = None # record the optimizer holding lr_scheduler
# Data members related to the program
self._dist_tensors_for_program = {}
......@@ -392,7 +395,7 @@ class DistributedContext:
if dist:
self._restore_dist_info(dist_mode)
def initialize(self, with_graph=True):
def initialize(self, with_graph=True, with_cpp=False):
if not self._is_initialized:
if not self._serial_main_program:
if self._original_serial_main_program:
......@@ -425,6 +428,10 @@ class DistributedContext:
self._ops_ids = list(self._dist_ops_for_program.keys())
self._is_initialized = True
# TODO: This will be removed in the future
if with_cpp:
_copy_dist_attr_to_cpp(self)
if with_graph:
set_flags({"FLAGS_convert_all_blocks": True})
self._serial_graph = framework.IrGraph(
......@@ -597,7 +604,11 @@ class DistributedContext:
tensor
)
if default_dist_tensor and default_ctx is not self:
self.add_dist_tensor_for_program(default_dist_tensor)
dist_tensor = DistributedTensor(tensor)
dist_tensor.dist_attr = copy.deepcopy(
default_dist_tensor.dist_attr
)
self.add_dist_tensor_for_program(dist_tensor)
current_dist_tensor = self.get_dist_tensor_for_program(tensor)
if current_dist_tensor is None:
dist_tensor = DistributedTensor(tensor)
......@@ -606,7 +617,9 @@ class DistributedContext:
# Copy the distributed operators in the default context
default_dist_op = default_ctx.get_dist_op_for_program(op)
if default_dist_op and default_ctx is not self:
self.add_dist_op_for_program(default_dist_op)
dist_op = DistributedOperator(op)
dist_op.dist_attr = copy.deepcopy(default_dist_op.dist_attr)
self.add_dist_op_for_program(dist_op)
current_dist_op = self.get_dist_op_for_program(op)
if current_dist_op is None:
dist_op = DistributedOperator(op)
......
......@@ -1907,3 +1907,120 @@ def validate_opt(optimizer):
optimizer._parameter_list = None
optimizer._param_groups = None
return optimizer
def _copy_tensor_dist_attr_to_cpp(cpp_dist_attr, py_dist_attr):
py_process_mesh = py_dist_attr.process_mesh
if py_process_mesh is not None:
cpp_dist_attr.process_mesh = core.ProcessMesh(
py_process_mesh.shape,
py_process_mesh.process_ids,
["d" + str(i) for i in range(len(py_process_mesh.shape))],
)
cpp_dist_attr.dims_mapping = py_dist_attr.dims_mapping
cpp_dist_attr.annotated = py_dist_attr._is_annotated
def _copy_tensor_dist_attr_from_cpp(cpp_dist_attr, py_dist_attr):
from .process_mesh import ProcessMesh
cpp_process_mesh = cpp_dist_attr.process_mesh
if not cpp_process_mesh.empty():
py_dist_attr.process_mesh = ProcessMesh(
shape=cpp_process_mesh.shape,
process_ids=cpp_process_mesh.process_ids,
)
py_dist_attr.dims_mapping = cpp_dist_attr.dims_mapping
py_dist_attr._is_annotated = cpp_dist_attr.annotated
def _copy_op_dist_attr_to_cpp(cpp_dist_attr, py_dist_attr):
py_process_mesh = py_dist_attr.process_mesh
if py_process_mesh is not None:
cpp_dist_attr.process_mesh = core.ProcessMesh(
py_process_mesh.shape,
py_process_mesh.process_ids,
["d" + str(i) for i in range(len(py_process_mesh.shape))],
)
cpp_dist_attr.impl_type = py_dist_attr.impl_type
cpp_dist_attr.impl_idx = py_dist_attr.impl_idx
cpp_dist_attr.annotated = py_dist_attr._is_annotated
for name, py_tensor_dist_attr in py_dist_attr.inputs_dist_attrs.items():
cpp_tensor_dist_attr = cpp_dist_attr.get_input_dist_attr(name)
_copy_tensor_dist_attr_to_cpp(cpp_tensor_dist_attr, py_tensor_dist_attr)
for name, py_tensor_dist_attr in py_dist_attr.outputs_dist_attrs.items():
cpp_tensor_dist_attr = cpp_dist_attr.get_output_dist_attr(name)
_copy_tensor_dist_attr_to_cpp(cpp_tensor_dist_attr, py_tensor_dist_attr)
def _copy_op_dist_attr_from_cpp(cpp_dist_attr, py_dist_attr):
from .process_mesh import ProcessMesh
cpp_process_mesh = cpp_dist_attr.process_mesh
if not cpp_process_mesh.empty():
py_dist_attr.process_mesh = ProcessMesh(
shape=cpp_process_mesh.shape,
process_ids=cpp_process_mesh.process_ids,
)
py_dist_attr.impl_type = cpp_dist_attr.impl_type
py_dist_attr.impl_idx = cpp_dist_attr.impl_idx
py_dist_attr._is_annotated = cpp_dist_attr.annotated
py_dist_attr.op_type = cpp_dist_attr.op.type()
for name, cpp_tensor_dist_attr in cpp_dist_attr.inputs_dist_attrs.items():
py_tensor_dist_attr = py_dist_attr.get_input_dist_attr(name)
_copy_tensor_dist_attr_from_cpp(
cpp_tensor_dist_attr, py_tensor_dist_attr
)
for name, cpp_tensor_dist_attr in cpp_dist_attr.outputs_dist_attrs.items():
py_tensor_dist_attr = py_dist_attr.get_output_dist_attr(name)
_copy_tensor_dist_attr_from_cpp(
cpp_tensor_dist_attr, py_tensor_dist_attr
)
def _copy_dist_attr_to_cpp(dist_context):
for dist_tensor in dist_context._dist_tensors_for_program.values():
_copy_tensor_dist_attr_to_cpp(
dist_tensor.serial_tensor.dist_attr, dist_tensor.dist_attr
)
for dist_op in dist_context._dist_ops_for_program.values():
_copy_op_dist_attr_to_cpp(
dist_op.serial_op.dist_attr, dist_op.dist_attr
)
def _copy_dist_attr_from_cpp(dist_context):
for dist_tensor in dist_context._dist_tensors_for_program.values():
_copy_tensor_dist_attr_from_cpp(
dist_tensor.serial_tensor.dist_attr, dist_tensor.dist_attr
)
for dist_op in dist_context._dist_ops_for_program.values():
_copy_op_dist_attr_from_cpp(
dist_op.serial_op.dist_attr, dist_op.dist_attr
)
def _copy_dist_attr_to_cpp_for_graph(dist_context):
for node in dist_context.serial_ordered_nodes:
if node.is_var() and node.var() is not None:
py_dist_attr = dist_context.get_tensor_dist_attr_for_graph(node)
cpp_dist_attr = node.var().dist_attr
_copy_tensor_dist_attr_to_cpp(cpp_dist_attr, py_dist_attr)
if node.is_op() and node.op() is not None:
py_dist_attr = dist_context.get_op_dist_attr_for_graph(node)
cpp_dist_attr = node.op().dist_attr
_copy_op_dist_attr_to_cpp(cpp_dist_attr, py_dist_attr)
def _copy_dist_attr_from_cpp_for_graph(dist_context):
for node in dist_context.serial_ordered_nodes:
if node.is_var() and node.var() is not None:
py_dist_attr = dist_context.get_tensor_dist_attr_for_graph(node)
cpp_dist_attr = node.var().dist_attr
_copy_tensor_dist_attr_from_cpp(cpp_dist_attr, py_dist_attr)
if node.is_op() and node.op() is not None:
py_dist_attr = dist_context.get_op_dist_attr_for_graph(node)
cpp_dist_attr = node.op().dist_attr
_copy_op_dist_attr_from_cpp(cpp_dist_attr, py_dist_attr)
......@@ -2612,7 +2612,7 @@ class Variable(metaclass=VariableMetaClass):
"""Get the names of all attributes defined."""
return self.desc.attr_names()
def _get_attr(self, name):
def attr(self, name):
"""
Get the attribute by name.
......
......@@ -103,6 +103,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_cluster_v2 MODULES test_cluster_v2)
py_test_modules(test_process_mesh_v2 MODULES test_process_mesh_v2)
py_test_modules(test_dist_attr_v2 MODULES test_dist_attr_v2)
py_test_modules(test_serialization MODULES test_serialization)
py_test_modules(test_lr_grad_clip MODULES test_lr_grad_clip)
py_test_modules(test_dist_matmul MODULES test_dist_matmul)
py_test_modules(test_process_mesh MODULES test_process_mesh)
......
......@@ -13,15 +13,188 @@
# limitations under the License
import unittest
import copy
import paddle
import numpy as np
import paddle.nn as nn
import paddle.static as static
import paddle.nn.functional as F
from paddle.distributed import fleet
from paddle.distributed.fleet import auto
from paddle.distributed.auto_parallel.dist_context import (
DistributedContext,
set_default_distributed_context,
)
from paddle.distributed.auto_parallel.utils import (
_copy_dist_attr_to_cpp,
_copy_dist_attr_from_cpp,
_copy_dist_attr_to_cpp_for_graph,
_copy_dist_attr_from_cpp_for_graph,
)
from paddle.fluid.core import TensorDistAttr
from paddle.fluid.core import OperatorDistAttr
from paddle.distributed.auto_parallel.process_mesh_v2 import ProcessMesh
paddle.enable_static()
batch_size = 4
epoch_num = 10
hidden_size = 1024
sequence_len = 512
_g_process_mesh = auto.ProcessMesh(mesh=[[0, 1], [2, 3]], dim_names=['x', 'y'])
class MLPLayer(nn.Layer):
def __init__(
self,
hidden_size=1024,
intermediate_size=4 * 1024,
dropout_ratio=0.1,
initializer_range=0.02,
):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
param_initializer = nn.initializer.Normal(
mean=0.0, std=initializer_range
)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
self.linear0 = nn.Linear(
d_model,
dim_feedforward,
weight_attr=paddle.ParamAttr(initializer=param_initializer),
bias_attr=None,
)
self.linear1 = nn.Linear(
dim_feedforward,
d_model,
weight_attr=paddle.ParamAttr(initializer=param_initializer),
bias_attr=None,
)
def forward(self, input):
out = self.norm(input)
auto.shard_tensor(
self.linear0.weight,
process_mesh=_g_process_mesh[0],
shard_spec=[None, 'y'],
)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
auto.shard_tensor(
self.linear1.weight,
process_mesh=_g_process_mesh[1],
shard_spec=['y', None],
)
out = self.linear1(out)
return out
def get_random_inputs_and_labels(input_shape, label_shape):
input = np.random.random(size=input_shape).astype('float32')
label = np.random.random(size=label_shape).astype('float32')
return input, label
def batch_generator_creator():
def __reader__():
for _ in range(batch_size):
batch_input, batch_label = get_random_inputs_and_labels(
[batch_size, sequence_len, hidden_size],
[batch_size, sequence_len, 1],
)
yield batch_input, batch_label
return __reader__
def get_program():
dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True
# fleet.init(is_collective=True, strategy=dist_strategy)
train_program = static.Program()
start_program = static.Program()
with static.program_guard(train_program, start_program):
# input
input = static.data(
name="input",
shape=[batch_size, sequence_len, hidden_size],
dtype='float32',
)
label = static.data(
name="label", shape=[batch_size, sequence_len, 1], dtype='float32'
)
data_holder = [input, label]
# dataloader
dataloader = paddle.io.DataLoader.from_generator(
feed_list=data_holder, capacity=4 * batch_size, iterable=False
)
dataloader.set_batch_generator(
batch_generator_creator(), places=paddle.static.cuda_places()
)
# data dist_attr
auto.shard_tensor(
input, process_mesh=_g_process_mesh[0], shard_spec=['y', None, None]
)
auto.shard_tensor(
label, process_mesh=_g_process_mesh[0], shard_spec=['y', None, None]
)
mlp_start = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02,
)
pred = mlp_start(input)
mlp_mid = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02,
)
pred = mlp_mid(pred)
mlp_end = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02,
)
pred = mlp_end(pred)
error_cost = paddle.nn.functional.square_error_cost(pred, label)
loss = paddle.mean(error_cost)
optimizer = paddle.optimizer.Adam(
learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None,
)
feed_vars = {"inputs": [input], "labels": [label]}
fetch_vars = {"loss": [loss]}
return (
train_program,
start_program,
dataloader,
loss,
optimizer,
feed_vars,
fetch_vars,
)
class TestDistAttr(unittest.TestCase):
def test_tensor_dist_attr_ctor(self):
......@@ -102,23 +275,25 @@ class TestDistAttr(unittest.TestCase):
op_dist_attr.set_output_dist_attr(output.name, output_dist_attr)
self.assertEqual(op_dist_attr.process_mesh, process_mesh)
self.assertEqual(
op_dist_attr.input_dist_attr(input.name).process_mesh, process_mesh
op_dist_attr.get_input_dist_attr(input.name).process_mesh,
process_mesh,
)
self.assertEqual(
op_dist_attr.input_dist_attr(input1.name).process_mesh, process_mesh
op_dist_attr.get_input_dist_attr(input1.name).process_mesh,
process_mesh,
)
self.assertEqual(
op_dist_attr.output_dist_attr(output.name).process_mesh,
op_dist_attr.get_output_dist_attr(output.name).process_mesh,
process_mesh,
)
self.assertEqual(
op_dist_attr.input_dist_attr(input.name).dims_mapping, [0, -1]
op_dist_attr.get_input_dist_attr(input.name).dims_mapping, [0, -1]
)
self.assertEqual(
op_dist_attr.input_dist_attr(input1.name).dims_mapping, [-1, 1]
op_dist_attr.get_input_dist_attr(input1.name).dims_mapping, [-1, 1]
)
self.assertEqual(
op_dist_attr.output_dist_attr(output.name).dims_mapping, [0, 1]
op_dist_attr.get_output_dist_attr(output.name).dims_mapping, [0, 1]
)
self.assertTrue(op_dist_attr.verify())
self.assertTrue(str(op_dist_attr), str(op_dist_attr))
......@@ -126,13 +301,13 @@ class TestDistAttr(unittest.TestCase):
op_dist_attr = OperatorDistAttr(op.desc)
op_dist_attr.process_mesh = process_mesh
# Set the distributed attribute of input directly
input_dist_attr = op_dist_attr.input_dist_attr(input.name)
input_dist_attr = op_dist_attr.get_input_dist_attr(input.name)
input_dist_attr.dims_mapping = [-1, 0]
# Set the distributed attribute of input1 directly
input1_dist_attr = op_dist_attr.input_dist_attr(input1.name)
input1_dist_attr = op_dist_attr.get_input_dist_attr(input1.name)
input1_dist_attr.dims_mapping = [0, -1]
# Set the distributed attribute of output directly
output_dist_attr = op_dist_attr.output_dist_attr(output.name)
output_dist_attr = op_dist_attr.get_output_dist_attr(output.name)
output_dist_attr.dims_mapping = [-1, -1]
self.assertEqual(op_dist_attr.process_mesh, process_mesh)
self.assertEqual(input_dist_attr.process_mesh, process_mesh)
......@@ -171,22 +346,25 @@ class TestDistAttr(unittest.TestCase):
self.assertEqual(op.desc.dist_attr.process_mesh, process_mesh)
self.assertEqual(
op.dist_attr.input_dist_attr(input.name).process_mesh, process_mesh
op.dist_attr.get_input_dist_attr(input.name).process_mesh,
process_mesh,
)
self.assertEqual(
op.dist_attr.input_dist_attr(input1.name).process_mesh, process_mesh
op.dist_attr.get_input_dist_attr(input1.name).process_mesh,
process_mesh,
)
self.assertEqual(
op.dist_attr.input_dist_attr(input.name).dims_mapping, [0, -1]
op.dist_attr.get_input_dist_attr(input.name).dims_mapping, [0, -1]
)
self.assertEqual(
op.dist_attr.input_dist_attr(input.name).dims_mapping, [0, -1]
op.dist_attr.get_input_dist_attr(input.name).dims_mapping, [0, -1]
)
self.assertEqual(
op.desc.dist_attr.input_dist_attr(input1.name).dims_mapping, [-1, 1]
op.desc.dist_attr.get_input_dist_attr(input1.name).dims_mapping,
[-1, 1],
)
self.assertEqual(
op.dist_attr.output_dist_attr(output.name).dims_mapping, [0, 1]
op.dist_attr.get_output_dist_attr(output.name).dims_mapping, [0, 1]
)
self.assertTrue(op.desc.dist_attr.verify())
self.assertTrue(str(op_dist_attr), str(op_dist_attr))
......@@ -195,5 +373,80 @@ class TestDistAttr(unittest.TestCase):
self.assertEqual(op.desc.dist_attr, OperatorDistAttr(op.desc))
class TestDistAttrConversion(unittest.TestCase):
def test_dist_attr_conversion_for_program(self):
set_default_distributed_context(DistributedContext())
(
train_program,
start_program,
dataloader,
loss,
optimizer,
feed_vars,
fetch_vars,
) = get_program()
dist_context = DistributedContext(
train_program, start_program, optimizer, loss, feed_vars, fetch_vars
)
dist_context.initialize()
original_dist_tensors = copy.deepcopy(
dist_context._dist_tensors_for_program
)
original_dist_ops = copy.deepcopy(dist_context._dist_ops_for_program)
_copy_dist_attr_to_cpp(dist_context)
_copy_dist_attr_from_cpp(dist_context)
for dist_tensor in dist_context._dist_tensors_for_program.values():
original_dist_tensor = original_dist_tensors[
dist_tensor.serial_tensor.desc.original_id()
]
self.assertEqual(
dist_tensor.dist_attr, original_dist_tensor.dist_attr
)
for dist_op in dist_context._dist_ops_for_program.values():
original_dist_op = original_dist_ops[
dist_op.serial_op.desc.original_id()
]
self.assertEqual(dist_op.dist_attr, original_dist_op.dist_attr)
def test_dist_attr_conversion_for_graph(self):
set_default_distributed_context(DistributedContext())
(
train_program,
start_program,
dataloader,
loss,
optimizer,
feed_vars,
fetch_vars,
) = get_program()
dist_context = DistributedContext(
train_program, start_program, optimizer, loss, feed_vars, fetch_vars
)
dist_context.initialize()
original_dist_tensors = copy.deepcopy(
dist_context._dist_tensors_for_graph
)
original_dist_ops = copy.deepcopy(dist_context._dist_ops_for_graph)
_copy_dist_attr_to_cpp_for_graph(dist_context)
_copy_dist_attr_from_cpp_for_graph(dist_context)
for (
node_id,
dist_tensor,
) in dist_context._dist_tensors_for_graph.items():
original_dist_tensor = original_dist_tensors[node_id]
self.assertEqual(
dist_tensor.dist_attr, original_dist_tensor.dist_attr
)
for node_id, dist_op in dist_context._dist_ops_for_graph.items():
original_dist_op = original_dist_ops[node_id]
self.assertEqual(dist_op.dist_attr, original_dist_op.dist_attr)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import unittest
import paddle
import numpy as np
import paddle.nn as nn
import paddle.static as static
import paddle.nn.functional as F
from paddle.fluid.framework import Program
import paddle.distributed.fleet as fleet
from paddle.distributed.fleet import auto
from paddle.distributed.auto_parallel.dist_context import (
DistributedContext,
set_default_distributed_context,
)
from paddle.fluid.core import TensorDistAttr
from paddle.distributed.auto_parallel.process_mesh_v2 import ProcessMesh
paddle.enable_static()
batch_size = 4
epoch_num = 10
hidden_size = 1024
sequence_len = 512
_g_process_mesh = auto.ProcessMesh(mesh=[[0, 1], [2, 3]], dim_names=['x', 'y'])
class MLPLayer(nn.Layer):
def __init__(
self,
hidden_size=1024,
intermediate_size=4 * 1024,
dropout_ratio=0.1,
initializer_range=0.02,
):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
param_initializer = nn.initializer.Normal(
mean=0.0, std=initializer_range
)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
self.linear0 = nn.Linear(
d_model,
dim_feedforward,
weight_attr=paddle.ParamAttr(initializer=param_initializer),
bias_attr=None,
)
self.linear1 = nn.Linear(
dim_feedforward,
d_model,
weight_attr=paddle.ParamAttr(initializer=param_initializer),
bias_attr=None,
)
def forward(self, input):
out = self.norm(input)
auto.shard_tensor(
self.linear0.weight,
process_mesh=_g_process_mesh[0],
shard_spec=[None, 'y'],
)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
auto.shard_tensor(
self.linear1.weight,
process_mesh=_g_process_mesh[1],
shard_spec=['y', None],
)
out = auto.shard_op(self.linear1, process_mesh=_g_process_mesh)(out)
return out
def get_random_inputs_and_labels(input_shape, label_shape):
input = np.random.random(size=input_shape).astype('float32')
label = np.random.random(size=label_shape).astype('float32')
return input, label
def batch_generator_creator():
def __reader__():
for _ in range(batch_size):
batch_input, batch_label = get_random_inputs_and_labels(
[batch_size, sequence_len, hidden_size],
[batch_size, sequence_len, 1],
)
yield batch_input, batch_label
return __reader__
def get_program():
dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True
# fleet.init(is_collective=True, strategy=dist_strategy)
train_program = static.Program()
start_program = static.Program()
with static.program_guard(train_program, start_program):
# input
input = static.data(
name="input",
shape=[batch_size, sequence_len, hidden_size],
dtype='float32',
)
label = static.data(
name="label", shape=[batch_size, sequence_len, 1], dtype='float32'
)
data_holder = [input, label]
# dataloader
dataloader = paddle.io.DataLoader.from_generator(
feed_list=data_holder, capacity=4 * batch_size, iterable=False
)
dataloader.set_batch_generator(
batch_generator_creator(), places=paddle.static.cuda_places()
)
# data dist_attr
auto.shard_tensor(
input, process_mesh=_g_process_mesh[0], shard_spec=['y', None, None]
)
auto.shard_tensor(
label, process_mesh=_g_process_mesh[0], shard_spec=['y', None, None]
)
mlp_start = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02,
)
pred = mlp_start(input)
mlp_mid = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02,
)
pred = mlp_mid(pred)
mlp_end = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02,
)
pred = mlp_end(pred)
error_cost = paddle.nn.functional.square_error_cost(pred, label)
loss = paddle.mean(error_cost)
optimizer = paddle.optimizer.Adam(
learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None,
)
feed_vars = {"inputs": [input], "labels": [label]}
fetch_vars = {"loss": [loss]}
return (
train_program,
start_program,
dataloader,
loss,
optimizer,
feed_vars,
fetch_vars,
)
class TestDistAttrSerialization(unittest.TestCase):
def test_serialization_tensor(self):
train_program = static.Program()
start_program = static.Program()
with static.program_guard(train_program, start_program):
input = static.data(name="input", shape=[2, 3], dtype='float32')
dist_attr = input.dist_attr
dist_attr.process_mesh = ProcessMesh([[0, 1, 2], [3, 4, 5]])
dist_attr.dims_mapping = [0, -1]
dist_attr.batch_dim = 1
dist_attr.dynamic_dims = [1, 1]
dist_attr_data = dist_attr.serialize_to_string()
def test_serialization_operator(self):
train_program = static.Program()
start_program = static.Program()
with static.program_guard(train_program, start_program):
input = static.data(name="input", shape=[2, 3], dtype='float32')
input1 = static.data(name="input1", shape=[3, 4], dtype='float32')
output = paddle.matmul(input, input1)
op = train_program.current_block().ops[0]
process_mesh = ProcessMesh([[0, 1, 2], [3, 4, 5]])
op_dist_attr = op.dist_attr
op_dist_attr.process_mesh = process_mesh
# Set the distributed attribute of input
input_dist_attr = TensorDistAttr(input.desc)
input_dist_attr.dims_mapping = [0, -1]
op_dist_attr.set_input_dist_attr(input.name, input_dist_attr)
# Set the distributed attribute of input1
input1_dist_attr = TensorDistAttr(input1.desc)
input1_dist_attr.dims_mapping = [-1, 1]
op_dist_attr.set_input_dist_attr(input1.name, input1_dist_attr)
# Set the distributed attribute of output
output_dist_attr = TensorDistAttr(output.desc)
output_dist_attr.dims_mapping = [0, 1]
op_dist_attr.set_output_dist_attr(output.name, output_dist_attr)
def test_serialization_program(self):
set_default_distributed_context(DistributedContext())
(
train_program,
start_program,
dataloader,
loss,
optimizer,
feed_vars,
fetch_vars,
) = get_program()
dist_context = DistributedContext(
train_program, start_program, optimizer, loss, feed_vars, fetch_vars
)
dist_context.initialize(with_cpp=True)
# The distributed context will clone the original train program into serial_main_program
original_program = dist_context.serial_main_program
for block in original_program.blocks:
for tensor in block.vars.values():
dist_attr_data = tensor.dist_attr.serialize_to_string()
tensor._set_attr("dist_attr", dist_attr_data)
for op in block.ops:
dist_attr_data = op.dist_attr.serialize_to_string()
op._set_attr("dist_attr", dist_attr_data)
program_data = original_program.desc.serialize_to_string()
program = Program.parse_from_string(program_data)
for block in program.blocks:
for tensor in block.vars.values():
dist_attr_data = tensor.attr("dist_attr")
tensor._remove_attr("dist_attr")
tensor.dist_attr.parse_from_string(dist_attr_data)
for op in block.ops:
dist_attr_data = op.attr("dist_attr")
op._remove_attr("dist_attr")
op.dist_attr.parse_from_string(dist_attr_data)
self.assertEqual(len(original_program.blocks), len(program.blocks))
for original_block, block in zip(
original_program.blocks, program.blocks
):
self.assertEqual(
len(original_block.vars.values()), len(block.vars.values())
)
for original_tensor in original_block.vars.values():
self.assertEqual(
original_tensor.dist_attr,
block.vars[original_tensor.name].dist_attr,
)
self.assertEqual(len(original_block.ops), len(block.ops))
for original_op, op in zip(original_block.ops, block.ops):
self.assertEqual(original_op.dist_attr, op.dist_attr)
if __name__ == "__main__":
unittest.main()
......@@ -19,6 +19,8 @@ import paddle
import paddle.fluid.layers as layers
import paddle.fluid as fluid
paddle.enable_static()
main_program = default_main_program()
......@@ -228,15 +230,13 @@ class TestProgramProto(unittest.TestCase):
b = program.desc.serialize_to_string()
self.assertFalse(a == b)
# it seems the attrs of framework::VarDesc is not write to proto,
# except for persistable/need_check_feed/is_parameter/stop_gradient
def test_update_var_attr(self):
program = build_program()
a = program.desc.serialize_to_string()
program.current_block().var("x").desc._set_attr("a", 1)
self.assertFalse(program.desc.need_update())
self.assertTrue(program.desc.need_update())
b = program.desc.serialize_to_string()
self.assertTrue(a == b) # not affected
self.assertFalse(a == b)
class TestProgramHash(unittest.TestCase):
......
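
A closing usage sketch (not part of the diff; program and variable names are illustrative) of the TensorDistAttr serialization round trip that the new serialize_to_string / parse_from_string bindings and test_serialization.py above exercise:

import paddle
import paddle.static as static
from paddle.fluid.core import TensorDistAttr
from paddle.distributed.auto_parallel.process_mesh_v2 import ProcessMesh

paddle.enable_static()

main_prog, start_prog = static.Program(), static.Program()
with static.program_guard(main_prog, start_prog):
    x = static.data(name="x", shape=[2, 3], dtype="float32")

dist_attr = x.dist_attr  # the C++ TensorDistAttr bound to x's VarDesc
dist_attr.process_mesh = ProcessMesh([[0, 1, 2], [3, 4, 5]])
dist_attr.dims_mapping = [0, -1]
dist_attr.batch_dim = 1
dist_attr.dynamic_dims = [1, 1]

# serialize_to_string returns the TensorDistAttrProto bytes (py::bytes);
# parse_from_string restores those fields onto an existing dist attr.
data = dist_attr.serialize_to_string()
restored = TensorDistAttr(x.desc)
restored.parse_from_string(data)
assert restored == dist_attr

OperatorDistAttr gains the same pair of methods, which test_serialization_program above uses to stash each dist attr as a VarDesc/OpDesc attribute before serializing the whole program.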