Unverified commit b03b4a3c, authored by Yulong Ao, committed by GitHub

[Auto Parallel] Improve the c++ dist attr (#47358)

* [Auto Parallel] Improve the c++ dist attr

* [Auto Parallel] Modify test_program.py

* [Auto Parallel] Add the missing import
Parent 3b219e5e
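In short, the C++ side drops the static `from_proto` factories in favor of in-place deserialization and adds string (de)serialization helpers on both `TensorDistAttr` and `OperatorDistAttr`. A minimal sketch of the new call pattern (illustration only, not part of the patch; it assumes a `VarDesc x` and a `TensorDistAttrProto proto` already exist):

    // Hedged sketch of the new API shape.
    TensorDistAttr dist_attr(x);             // bind the dist attr to a VarDesc
    dist_attr.from_proto(proto);             // fill the fields in place (previously a static factory)
    std::string data = dist_attr.serialize_to_string();  // proto bytes as a string
    dist_attr.parse_from_string(data);       // inverse of serialize_to_string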
@@ -29,36 +29,57 @@ namespace auto_parallel {
std::vector<std::string> TensorDistAttr::fields_{
    "process_mesh", "dims_mapping", "batch_dim", "dynamic_dims"};

TensorDistAttr::TensorDistAttr(const VarDesc& tensor) : tensor_(&tensor) {
  VLOG(4) << "[TensorDistAttr constructor] tensor name: " << tensor_->Name();
  if (tensor_->GetType() == framework::proto::VarType::READER) return;
  if (tensor_->GetType() == framework::proto::VarType::LOD_TENSOR_ARRAY) return;
  if (tensor_->GetType() == framework::proto::VarType::STEP_SCOPES) return;
  tensor_shape_ = tensor_->GetShape();
  VLOG(4) << "[TensorDistAttr constructor] tensor shape: "
          << str_join(tensor_shape_);
  set_default_dims_mapping();
  for (std::size_t i = 0; i < tensor_shape_.size(); ++i) {
    dynamic_dims_.push_back(false);
  }
}
TensorDistAttr::TensorDistAttr(const TensorDistAttr& dist_attr) {
  if (tensor_ == nullptr) {
    tensor_ = dist_attr.tensor_;
    tensor_shape_ = dist_attr.tensor_shape_;
  }
  if (tensor_ != nullptr) {
    VLOG(4) << "[TensorDistAttr copy constructor] tensor name: "
            << tensor_->Name() << ", tensor shape: " << str_join(tensor_shape_);
  } else {
    VLOG(4) << "[TensorDistAttr copy constructor] tensor name: None"
            << ", tensor shape: " << str_join(tensor_shape_);
  }
  copy_from(dist_attr);
}

TensorDistAttr& TensorDistAttr::operator=(const TensorDistAttr& dist_attr) {
  if (tensor_ == nullptr) {
    tensor_ = dist_attr.tensor_;
    tensor_shape_ = dist_attr.tensor_shape_;
  }
  if (tensor_ != nullptr) {
    VLOG(4) << "[TensorDistAttr assign operator] tensor name: "
            << tensor_->Name() << ", tensor shape: " << str_join(tensor_shape_);
  } else {
    VLOG(4) << "[TensorDistAttr assign operator] tensor name: None"
            << ", tensor shape: " << str_join(tensor_shape_);
  }
  copy_from(dist_attr);
  return *this;
}
void TensorDistAttr::copy_from(const TensorDistAttr& dist_attr) {
  set_process_mesh(dist_attr.process_mesh());
  set_dims_mapping(dist_attr.dims_mapping());
  set_batch_dim(dist_attr.batch_dim());
  set_dynamic_dims(dist_attr.dynamic_dims());
  set_annotated(dist_attr.annotated());
}

void TensorDistAttr::set_process_mesh(const ProcessMesh& process_mesh) {
@@ -84,9 +105,9 @@ void TensorDistAttr::set_batch_dim(int64_t batch_dim) {
      true,
      platform::errors::InvalidArgument(
          "Wrong batch_dim %d in this distributed attribute.", batch_dim));
  if (tensor_ != nullptr && tensor_shape_.size() > 0) {
    int64_t canonical_batch_dim =
        canonical_dim(batch_dim, tensor_shape_.size());
    batch_dim_ = canonical_batch_dim;
  } else {
    batch_dim_ = batch_dim;
@@ -113,8 +134,7 @@ void TensorDistAttr::set_annotated(
void TensorDistAttr::set_default_dims_mapping() {
  if (tensor_ != nullptr) {
    dims_mapping_ = std::vector<int64_t>(tensor_shape_.size(), -1);
  }
}
@@ -127,6 +147,8 @@ void TensorDistAttr::annotate(const std::string& name) {
bool TensorDistAttr::verify_process_mesh(
    const ProcessMesh& process_mesh) const {
  VLOG(4) << "[TensorDistAttr verify_process_mesh] "
          << process_mesh.to_string();
  if (!process_mesh_.empty()) {
    for (int64_t dim_mapping : dims_mapping_) {
      if (dim_mapping < -1 || dim_mapping >= process_mesh_.ndim()) {
@@ -139,11 +161,9 @@ bool TensorDistAttr::verify_process_mesh(
bool TensorDistAttr::verify_dims_mapping(
    const std::vector<int64_t>& dims_mapping) const {
  VLOG(4) << "[TensorDistAttr verify_dims_mapping] " << str_join(dims_mapping);
  if (dims_mapping.size() != tensor_shape_.size()) {
    return false;
  }
  std::unordered_map<int64_t, int64_t> map;
  if (!process_mesh_.empty()) {
@@ -168,9 +188,9 @@ bool TensorDistAttr::verify_dims_mapping(
}

bool TensorDistAttr::verify_batch_dim(int64_t dim) const {
  VLOG(4) << "[TensorDistAttr verify_batch_dim] " << dim;
  int64_t ndim = tensor_shape_.size();
  if (tensor_ != nullptr && ndim > 0) {
    if (dim < 0) {
      dim = dim + ndim;
    }
@@ -183,17 +203,16 @@ bool TensorDistAttr::verify_batch_dim(int64_t dim) const {
bool TensorDistAttr::verify_dynamic_dims(
    const std::vector<bool>& dynamic_dims) const {
  VLOG(4) << "[TensorDistAttr verify_dynamic_dims] " << str_join(dynamic_dims);
  if (dynamic_dims.size() != tensor_shape_.size()) {
    return false;
  }
  return true;
}

bool TensorDistAttr::verify_annotated(
    const std::map<std::string, bool>& annotated) const {
  VLOG(4) << "[TensorDistAttr verify_annotated] " << str_join(annotated);
  for (const auto& item : annotated) {
    auto result = std::find(std::begin(fields_), std::end(fields_), item.first);
    if (result == std::end(fields_)) {
@@ -204,9 +223,6 @@ bool TensorDistAttr::verify_annotated(
}

bool TensorDistAttr::verify() const {
  if (!verify_process_mesh(process_mesh_)) {
    return false;
  }
@@ -240,19 +256,17 @@ std::string TensorDistAttr::to_string() const {
  return dist_str;
}
void TensorDistAttr::from_proto(const TensorDistAttrProto& proto) {
  process_mesh_ = ProcessMesh::from_proto(proto.process_mesh());
  dims_mapping_.resize(proto.dims_mapping_size());
  for (int64_t i = 0; i < proto.dims_mapping_size(); ++i) {
    dims_mapping_[i] = proto.dims_mapping(i);
  }
  batch_dim_ = proto.batch_dim();
  dynamic_dims_.resize(proto.dynamic_dims_size());
  for (int64_t i = 0; i < proto.dynamic_dims_size(); ++i) {
    dynamic_dims_[i] = proto.dynamic_dims(i);
  }
}

TensorDistAttrProto TensorDistAttr::to_proto() const {
@@ -268,6 +282,26 @@ TensorDistAttrProto TensorDistAttr::to_proto() const {
  return proto;
}
std::string TensorDistAttr::serialize_to_string() {
  std::string data;
  auto proto = to_proto();
  PADDLE_ENFORCE_EQ(proto.SerializeToString(&data),
                    true,
                    platform::errors::InvalidArgument(
                        "Failed to serialize tensor dist attr to string."));
  return data;
}
void TensorDistAttr::parse_from_string(const std::string& data) {
TensorDistAttrProto proto;
PADDLE_ENFORCE_EQ(proto.ParseFromString(data),
true,
platform::errors::InvalidArgument(
"Failed to parse tensor dist attr from string."));
from_proto(proto);
}
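For reference, the pair above is meant to round-trip a dist attr through its proto string form; a minimal usage sketch, mirroring how the updated dist_attr_test.cc exercises `from_proto` (illustration only, assuming a `TensorDistAttr x_dist_attr` already bound to a `VarDesc* x`):

    std::string data = x_dist_attr.serialize_to_string();  // to_proto() + SerializeToString
    TensorDistAttr restored(*x);                           // bind to the same VarDesc
    restored.parse_from_string(data);                      // ParseFromString + from_proto
    // restored should now compare equal to x_dist_attr via operator==.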
bool operator==(const TensorDistAttr& lhs, const TensorDistAttr& rhs) {
  if (lhs.process_mesh() != rhs.process_mesh()) {
    return false;
@@ -288,52 +322,103 @@ std::vector<std::string> OperatorDistAttr::fields_{
    "process_mesh", "impl_type", "impl_idx"};

OperatorDistAttr::OperatorDistAttr(const OpDesc& op) : op_(&op) {
VLOG(4) << "[OperatorDistAttr constructor] op type: " << op_->Type();
initialize();
}
OperatorDistAttr::OperatorDistAttr(const OperatorDistAttr& dist_attr) {
if (op_ == nullptr) {
op_ = dist_attr.op();
}
if (op_ != nullptr) {
VLOG(4) << "[OperatorDistAttr copy constructor] op type: " << op_->Type();
} else {
VLOG(4) << "[OperatorDistAttr copy constructor] op type: None";
}
initialize();
copy_from(dist_attr);
}
OperatorDistAttr& OperatorDistAttr::operator=(
const OperatorDistAttr& dist_attr) {
if (op_ == nullptr) {
op_ = dist_attr.op();
}
if (op_ != nullptr) {
VLOG(4) << "[OperatorDistAttr assign constructor] op type: " << op_->Type();
} else {
VLOG(4) << "[OperatorDistAttr assign constructor] op type: None";
}
initialize();
copy_from(dist_attr);
return *this;
}
void OperatorDistAttr::initialize() {
if (op_ == nullptr) return;
  for (std::string name : op_->InputArgumentNames()) {
    VarDesc* input = op_->Block()->FindVarRecursive(name);
    VLOG(4) << "[OperatorDistAttr create input dist attr] " << name;
    inputs_[name] = input;
    if (input == nullptr || op_->Type() == "create_py_reader") {
      input_dist_attrs_[name] = TensorDistAttr();
    } else {
      input_dist_attrs_[name] = TensorDistAttr(*input);
    }
  }
  for (std::string name : op_->OutputArgumentNames()) {
    VarDesc* output = op_->Block()->FindVarRecursive(name);
    VLOG(4) << "[OperatorDistAttr create output dist attr] " << name;
    outputs_[name] = output;
    if (output == nullptr) {
      output_dist_attrs_[name] = TensorDistAttr();
    } else {
      output_dist_attrs_[name] = TensorDistAttr(*output);
    }
  }
  impl_type_ = "default";
  impl_idx_ = 0;
}
void OperatorDistAttr::copy_from(const OperatorDistAttr& dist_attr) {
  set_input_dist_attrs(dist_attr.input_dist_attrs());
  set_output_dist_attrs(dist_attr.output_dist_attrs());
  set_process_mesh(dist_attr.process_mesh());
  set_impl_type(dist_attr.impl_type());
  set_impl_idx(dist_attr.impl_idx());
  set_annotated(dist_attr.annotated());
}
void OperatorDistAttr::set_input_dist_attrs(
    const std::map<std::string, TensorDistAttr>& dist_attrs) {
  if (op_ == nullptr) {
    for (const auto& item : dist_attrs) {
      set_input_dist_attr(item.first, item.second);
    }
  } else {
    for (const auto& item : input_dist_attrs_) {
      if (dist_attrs.count(item.first) == 1) {
        set_input_dist_attr(item.first, dist_attrs.at(item.first));
      }
    }
  }
}

void OperatorDistAttr::set_output_dist_attrs(
    const std::map<std::string, TensorDistAttr>& dist_attrs) {
  if (op_ == nullptr) {
    for (const auto& item : dist_attrs) {
      set_output_dist_attr(item.first, item.second);
    }
  } else {
    for (const auto& item : output_dist_attrs_) {
      if (dist_attrs.count(item.first) == 1) {
        set_output_dist_attr(item.first, dist_attrs.at(item.first));
      }
    }
  }
}
void OperatorDistAttr::set_input_dist_attr(const std::string& name,
@@ -341,8 +426,10 @@ void OperatorDistAttr::set_input_dist_attr(const std::string& name,
  PADDLE_ENFORCE_EQ(
      verify_input_dist_attr(name, dist_attr),
      true,
      platform::errors::InvalidArgument("Wrong dist_attr %s for %s. %s",
                                        dist_attr.to_string(),
                                        name,
                                        to_string()));
  input_dist_attrs_[name] = dist_attr;
  // Make sure the process mesh of input be same as that of the op
  input_dist_attrs_[name].set_process_mesh(process_mesh_);
@@ -394,8 +481,30 @@ void OperatorDistAttr::set_annotated(
  annotated_ = annotated;
}
const std::vector<int64_t>& OperatorDistAttr::input_dims_mapping(
const std::string& name) const {
return input_dist_attr(name).dims_mapping();
}
void OperatorDistAttr::set_input_dims_mapping(
const std::string& name, const std::vector<int64_t>& dims_mapping) {
input_dist_attr(name).set_dims_mapping(dims_mapping);
}
const std::vector<int64_t>& OperatorDistAttr::output_dims_mapping(
const std::string& name) {
return output_dist_attr(name).dims_mapping();
}
void OperatorDistAttr::set_output_dims_mapping(
const std::string& name, const std::vector<int64_t>& dims_mapping) {
output_dist_attr(name).set_dims_mapping(dims_mapping);
}
bool OperatorDistAttr::verify_input_dist_attr(
    const std::string& name, const TensorDistAttr& dist_attr) const {
  VLOG(4) << "[OperatorDistAttr verify_input_dist_attr] " << name << " "
          << dist_attr.to_string();
  if (!dist_attr.verify()) {
    return false;
  }
@@ -414,6 +523,8 @@ bool OperatorDistAttr::verify_input_dist_attr(
bool OperatorDistAttr::verify_output_dist_attr(
    const std::string& name, const TensorDistAttr& dist_attr) const {
  VLOG(4) << "[OperatorDistAttr verify_output_dist_attr] " << name << " "
          << dist_attr.to_string();
  if (!dist_attr.verify()) {
    return false;
  }
@@ -432,6 +543,8 @@ bool OperatorDistAttr::verify_output_dist_attr(
bool OperatorDistAttr::verify_process_mesh(
    const ProcessMesh& process_mesh) const {
  VLOG(4) << "[OperatorDistAttr verify_process_mesh] "
          << process_mesh.to_string();
  if (process_mesh != process_mesh_) {
    return false;
  }
@@ -450,6 +563,7 @@ bool OperatorDistAttr::verify_process_mesh(
bool OperatorDistAttr::verify_annotated(
    const std::map<std::string, bool>& annotated) const {
  VLOG(4) << "[OperatorDistAttr verify_annotated] " << str_join(annotated);
  for (const auto& item : annotated) {
    auto result = std::find(std::begin(fields_), std::end(fields_), item.first);
    if (result == std::end(fields_)) {
@@ -457,11 +571,15 @@ bool OperatorDistAttr::verify_annotated(
    }
  }
  for (auto& item : input_dist_attrs_) {
    VLOG(4) << "[OperatorDistAttr verify_annotated input] "
            << str_join(item.second.annotated());
    if (!item.second.verify_annotated(item.second.annotated())) {
      return false;
    }
  }
  for (auto& item : output_dist_attrs_) {
    VLOG(4) << "[OperatorDistAttr verify_annotated output] "
            << str_join(item.second.annotated());
    if (!item.second.verify_annotated(item.second.annotated())) {
      return false;
    }
  }
@@ -501,6 +619,44 @@ bool OperatorDistAttr::verify() const {
  return true;
}
void OperatorDistAttr::rename_input(const std::string& old_name,
const std::string& new_name) {
for (auto& item : input_dist_attrs_) {
if (item.first == old_name) {
VarDesc* new_input = op_->Block()->FindVarRecursive(new_name);
inputs_[new_name] = new_input;
if (new_input == nullptr) {
input_dist_attrs_[new_name] = TensorDistAttr();
} else {
input_dist_attrs_[new_name] = TensorDistAttr(*new_input);
input_dist_attrs_[new_name].copy_from(input_dist_attrs_[old_name]);
}
inputs_.erase(old_name);
input_dist_attrs_.erase(old_name);
break;
}
}
}
void OperatorDistAttr::rename_output(const std::string& old_name,
const std::string& new_name) {
for (auto& item : output_dist_attrs_) {
if (item.first == old_name) {
VarDesc* new_output = op_->Block()->FindVarRecursive(new_name);
outputs_[new_name] = new_output;
if (new_output == nullptr) {
output_dist_attrs_[new_name] = TensorDistAttr();
} else {
output_dist_attrs_[new_name] = TensorDistAttr(*new_output);
output_dist_attrs_[new_name].copy_from(output_dist_attrs_[old_name]);
}
outputs_.erase(old_name);
output_dist_attrs_.erase(old_name);
break;
}
}
}
std::string OperatorDistAttr::to_string() const {
  std::string str;
  if (op_ != nullptr) {
@@ -525,23 +681,22 @@ std::string OperatorDistAttr::to_string() const {
  return str;
}

void OperatorDistAttr::from_proto(const OperatorDistAttrProto& proto) {
  for (int64_t i = 0; i < proto.input_dist_attrs_size(); ++i) {
    TensorDistAttr dist_attr;
    std::string name = proto.input_dist_attrs(i).name();
    dist_attr.from_proto(proto.input_dist_attrs(i).tensor_dist_attr());
    input_dist_attrs_[name] = dist_attr;
  }
  for (int64_t i = 0; i < proto.output_dist_attrs_size(); ++i) {
    TensorDistAttr dist_attr;
    std::string name = proto.output_dist_attrs(i).name();
    dist_attr.from_proto(proto.output_dist_attrs(i).tensor_dist_attr());
    output_dist_attrs_[name] = dist_attr;
  }
  process_mesh_ = ProcessMesh::from_proto(proto.process_mesh());
  impl_type_ = proto.impl_type();
  impl_idx_ = proto.impl_idx();
}

OperatorDistAttrProto OperatorDistAttr::to_proto() const {
@@ -562,6 +717,26 @@ OperatorDistAttrProto OperatorDistAttr::to_proto() const {
  return proto;
}
std::string OperatorDistAttr::serialize_to_string() {
  std::string data;
  auto proto = to_proto();
  PADDLE_ENFORCE_EQ(proto.SerializeToString(&data),
                    true,
                    platform::errors::InvalidArgument(
                        "Failed to serialize op dist attr to string."));
  return data;
}
void OperatorDistAttr::parse_from_string(const std::string& data) {
OperatorDistAttrProto proto;
PADDLE_ENFORCE_EQ(proto.ParseFromString(data),
true,
platform::errors::InvalidArgument(
"Failed to parse op dist attr from string."));
from_proto(proto);
}
bool operator==(const OperatorDistAttr& lhs, const OperatorDistAttr& rhs) {
  if (lhs.process_mesh() != rhs.process_mesh()) {
    return false;
......
@@ -56,6 +56,8 @@ class TensorDistAttr {
  TensorDistAttr& operator=(const TensorDistAttr& dist_attr);

  void copy_from(const TensorDistAttr& dist_attr);

  const VarDesc* tensor() const { return tensor_; }

  const ProcessMesh& process_mesh() const { return process_mesh_; }
@@ -101,16 +103,21 @@ class TensorDistAttr {
  // TensorDistAttr from_string(const std::string& dist_str);
  std::string to_string() const;

  void from_proto(const TensorDistAttrProto& proto);

  TensorDistAttrProto to_proto() const;

  std::string serialize_to_string();

  void parse_from_string(const std::string& data);

 private:
  static std::vector<std::string> fields_;
  const VarDesc* tensor_{nullptr};
  std::vector<int64_t> tensor_shape_;
  ProcessMesh process_mesh_;
  std::vector<int64_t> dims_mapping_;
  int64_t batch_dim_{0};
  std::vector<bool> dynamic_dims_;
  std::map<std::string, bool> annotated_;
};
@@ -136,6 +143,10 @@ class OperatorDistAttr {
  OperatorDistAttr& operator=(const OperatorDistAttr& dist_attr);

  void initialize();

  void copy_from(const OperatorDistAttr& dist_attr);

  const OpDesc* op() const { return op_; }

  const VarDesc& input(const std::string& name) const {
@@ -150,10 +161,16 @@ class OperatorDistAttr {
    return input_dist_attrs_;
  }

  void set_input_dist_attrs(
      const std::map<std::string, TensorDistAttr>& dist_attrs);

  const std::map<std::string, TensorDistAttr>& output_dist_attrs() const {
    return output_dist_attrs_;
  }

  void set_output_dist_attrs(
      const std::map<std::string, TensorDistAttr>& dist_attrs);

  const TensorDistAttr& input_dist_attr(const std::string& name) const {
    return input_dist_attrs_.at(name);
  }
@@ -198,6 +215,16 @@ class OperatorDistAttr {
  void annotate(const std::string& name);

  const std::vector<int64_t>& input_dims_mapping(const std::string& name) const;

  void set_input_dims_mapping(const std::string& name,
                              const std::vector<int64_t>& dims_mapping);

  const std::vector<int64_t>& output_dims_mapping(const std::string& name);

  void set_output_dims_mapping(const std::string& name,
                               const std::vector<int64_t>& dims_mapping);

  bool verify_input_dist_attr(const std::string& name,
                              const TensorDistAttr& dist_attr) const;
@@ -210,13 +237,21 @@ class OperatorDistAttr {
  bool verify() const;

  void rename_input(const std::string& old_name, const std::string& new_name);

  void rename_output(const std::string& old_name, const std::string& new_name);

  // OperatorDistAttr from_string(const std::string& dist_str);
  std::string to_string() const;

  void from_proto(const OperatorDistAttrProto& proto);

  OperatorDistAttrProto to_proto() const;

  std::string serialize_to_string();

  void parse_from_string(const std::string& data);

 private:
  static std::vector<std::string> fields_;
  const OpDesc* op_{nullptr};
......
@@ -81,10 +81,9 @@ TEST(DistAttr, ctor) {
  x_sstream << x_dist_attr;
  EXPECT_EQ(x_sstream.str(), x_dist_attr.to_string());
  auto x_proto = x_dist_attr.to_proto();
  TensorDistAttr new_x_dist_attr(*x);
  new_x_dist_attr.from_proto(x_proto);
  EXPECT_EQ(x_dist_attr, new_x_dist_attr);

  y_dist_attr.set_process_mesh(process_mesh);
  y_dist_attr.set_dims_mapping(std::vector<int64_t>({-1, 0}));
@@ -139,10 +138,9 @@ TEST(DistAttr, ctor) {
  mul_sstream << mul_dist_attr;
  EXPECT_EQ(mul_sstream.str(), mul_dist_attr.to_string());
  auto mul_proto = mul_dist_attr.to_proto();
  OperatorDistAttr new_mul_dist_attr(*op);
  new_mul_dist_attr.from_proto(mul_proto);
  EXPECT_EQ(mul_dist_attr, new_mul_dist_attr);
}

}  // namespace auto_parallel
......
@@ -82,7 +82,7 @@ inline std::string str_join(std::map<std::string, bool> const& elements,
  for (const auto& item : elements) {
    str += item.first + ": " + std::to_string(item.second) + ",";
  }
  return str.substr(0, str.size() - 1);
}

// Refer to https://stackoverflow.com/a/46931770
......
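The one-character change above fixes an off-by-one: each entry appends a single trailing ',', so only one character needs to be trimmed; trimming two also dropped the last character of the final value. A standalone sketch of the same logic (illustration only, not the Paddle implementation):

    #include <cassert>
    #include <map>
    #include <string>

    // Mirrors the join-and-trim pattern used by str_join for map<string, bool>.
    std::string join_map(const std::map<std::string, bool>& elements) {
      std::string str;
      for (const auto& item : elements) {
        str += item.first + ": " + std::to_string(item.second) + ",";
      }
      // One trailing ',' per entry, so trim exactly one character.
      return str.substr(0, str.size() - 1);
    }

    int main() {
      // With the old "size() - 2" the result would have been "process_mesh: " (value lost).
      assert(join_map({{"process_mesh", true}}) == "process_mesh: 1");
      return 0;
    }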
@@ -119,8 +119,30 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc) {
    }
    default:
      PADDLE_THROW(platform::errors::Unavailable(
          "Unsupported attribute type %d.", attr_desc.type()));
  }
  return paddle::blank();
}

Attribute GetAttrValue(const proto::VarDesc::Attr& attr_desc) {
  switch (attr_desc.type()) {
    case proto::AttrType::INT: {
      return attr_desc.i();
    }
    case proto::AttrType::STRING: {
      return attr_desc.s();
    }
    case proto::AttrType::INTS: {
      std::vector<int> val(attr_desc.ints_size());
      for (int i = 0; i < attr_desc.ints_size(); ++i) {
        val[i] = attr_desc.ints(i);
      }
      return val;
    }
    default:
      PADDLE_THROW(platform::errors::Unavailable(
          "Unsupported attribute type %d.", attr_desc.type()));
  }
  return paddle::blank();
}
......
@@ -37,6 +37,8 @@ paddle::any GetAttrValue(const Attribute& attr);
Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc);

Attribute GetAttrValue(const proto::VarDesc::Attr& attr_desc);

template <typename T>
struct ExtractAttribute {
  explicit ExtractAttribute(const std::string& attr_name)
......
@@ -885,6 +885,10 @@ void OpDesc::RenameOutput(const std::string &old_name,
    std::replace(op_vars.begin(), op_vars.end(), old_name, new_name);
  }

  if (dist_attr_) {
    dist_attr_->rename_output(old_name, new_name);
  }

  need_update_ = true;
}
@@ -900,6 +904,10 @@ void OpDesc::RenameInput(const std::string &old_name,
    std::replace(op_vars.begin(), op_vars.end(), old_name, new_name);
  }

  if (dist_attr_) {
    dist_attr_->rename_input(old_name, new_name);
  }

  need_update_ = true;
}
......
@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/var_desc.h"

#include "glog/logging.h"
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/platform/enforce.h"
@@ -28,6 +29,16 @@ VarDesc::VarDesc(const VarDesc &other)
  if (other.dist_attr_) {
    dist_attr_.reset(new TensorDistAttr(*other.dist_attr_));
  }
need_updated_ = true;
}
VarDesc::VarDesc(const proto::VarDesc &desc) : desc_(desc) {
// Restore attrs_ for auto parallel
for (const proto::VarDesc::Attr &attr : desc_.attrs()) {
std::string attr_name = attr.name();
attrs_[attr_name] = GetAttrValue(attr);
}
need_updated_ = true;
}

proto::VarType::Type VarDesc::GetType() const { return desc_.type().type(); }
@@ -348,14 +359,15 @@ void VarDesc::SetAttr(const std::string &name, const Attribute &v) {
  bool valid = attr_type == proto::AttrType::INT ||
               attr_type == proto::AttrType::STRING ||
               attr_type == proto::AttrType::INTS;
  PADDLE_ENFORCE_EQ(valid,
                    true,
                    platform::errors::InvalidArgument(
                        "The value for attr (%s) must be "
                        "one of int, string, list of int for now.",
                        name));

  this->attrs_[name] = v;
  need_updated_ = true;
}

Attribute VarDesc::GetAttr(const std::string &name) const {
@@ -367,6 +379,63 @@ Attribute VarDesc::GetAttr(const std::string &name) const {
  return it->second;
}
struct SetVarAttrDescVisitor {
explicit SetVarAttrDescVisitor(proto::VarDesc::Attr *attr) : attr_(attr) {}
mutable proto::VarDesc::Attr *attr_;
template <typename T>
void operator()(T &&v) {
using U = std::decay_t<decltype(v)>;
if (std::is_same<U, int>::value) {
set_attr_value(v);
} else if (std::is_same<U, std::string>::value) {
set_attr_value(v);
} else if (std::is_same<U, std::vector<int>>::value) {
set_attr_value(v);
} else {
PADDLE_THROW(platform::errors::Unavailable(
"Unsupported calling method of SetAttrDescVisitor object."));
}
}
// This template is only declared so that the code compiles; it is never called.
template <typename U>
void set_attr_value(U v);
void set_attr_value(int v) { attr_->set_i(v); }
void set_attr_value(const std::string &v) { attr_->set_s(v); }
void set_attr_value(const std::vector<int> &v) {
VectorToRepeated(v, attr_->mutable_ints());
}
};
// Only need to flush the attrs for auto parallel for now
void VarDesc::Flush() {
VLOG(4) << "Flush "
<< " " << Name() << " " << need_updated_;
if (need_updated_) {
this->desc_.mutable_attrs()->Clear();
std::vector<std::pair<std::string, Attribute>> sorted_attrs{attrs_.begin(),
attrs_.end()};
std::sort(
sorted_attrs.begin(),
sorted_attrs.end(),
[](std::pair<std::string, Attribute> a,
std::pair<std::string, Attribute> b) { return a.first < b.first; });
for (auto &attr : sorted_attrs) {
auto *attr_desc = desc_.add_attrs();
attr_desc->set_name(attr.first);
attr_desc->set_type(
static_cast<proto::AttrType>(attr.second.index() - 1));
SetVarAttrDescVisitor visitor(attr_desc);
paddle::visit(visitor, attr.second);
}
need_updated_ = false;
}
}
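Together with the new `GetAttrValue(const proto::VarDesc::Attr&)` overload earlier in this diff, `Flush` is what lets a `VarDesc` attribute survive a round trip through the proto form. A hedged sketch of the intended flow (illustration only; the attribute name is made up):

    VarDesc var("x");
    var.SetAttr("quant_bits", 8);         // stored in attrs_, marks need_updated_
    proto::VarDesc* proto = var.Proto();  // Proto() now calls Flush(), serializing attrs_
    VarDesc restored(*proto);             // the new constructor re-parses attrs via GetAttrValue
    // restored.GetAttr("quant_bits") is expected to hold the int 8 again.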
TensorDistAttr *VarDesc::MutableDistAttr() {
  // If dist_attr_ is nullptr, construct a new one and return.
@@ -375,12 +444,14 @@ TensorDistAttr *VarDesc::MutableDistAttr() {
    dist_attr_.reset(new TensorDistAttr(*this));
    return dist_attr_.get();
  }
} }
need_updated_ = true;
} }
void VarDesc::SetDistAttr(const TensorDistAttr &dist_attr) {
  // Make sure the dist attr has been created
  MutableDistAttr();
  *dist_attr_ = dist_attr;
  need_updated_ = true;
}

bool operator==(const VarDesc &left, const VarDesc &right) {
......
@@ -71,9 +71,7 @@ class VarDesc {
    need_updated_ = true;
  }

  explicit VarDesc(const proto::VarDesc &desc);

  // Explicitly implement the copy constructor for auto parallel
  VarDesc(const VarDesc &other);
@@ -90,7 +88,7 @@ class VarDesc {
  }

  proto::VarDesc *Proto() {
    Flush();  // Only flush attrs for auto parallel
    return &desc_;
  }
@@ -194,6 +192,8 @@ class VarDesc {
  bool NeedUpdate() const { return need_updated_; }
  void SetNeedUpdate(bool need) { need_updated_ = need; }

  void Flush();

  // The following methods are only used for auto parallel.
  uint64_t Id() const { return id_; }
  uint64_t OriginalId() const { return original_id_; }
......
@@ -165,6 +165,12 @@ void BindAutoParallel(py::module *m) {
                        &DeviceMesh::dim_size))
      .def(py::self == py::self)
      .def(py::self != py::self)
      .def(
          "__deepcopy__",
          [](const TensorDistAttr &self, py::dict) {
            return TensorDistAttr(self);
          },
          py::arg("memo"))
      .def("__str__", &DeviceMesh::to_string);

  py::class_<TensorDistAttr>(*m, "TensorDistAttr")
@@ -182,9 +188,17 @@ void BindAutoParallel(py::module *m) {
      .def_property("dynamic_dims",
                    &TensorDistAttr::dynamic_dims,
                    &TensorDistAttr::set_dynamic_dims)
      .def_property("annotated",
                    &TensorDistAttr::annotated,
                    &TensorDistAttr::set_annotated)
      .def("is_annotated", &TensorDistAttr::is_annotated)
      .def("annotate", &TensorDistAttr::annotate)
      .def("verify", &TensorDistAttr::verify)
      .def("serialize_to_string",
           [](TensorDistAttr &self) {
             return py::bytes(self.serialize_to_string());
           })
      .def("parse_from_string", &TensorDistAttr::parse_from_string)
      .def(py::self == py::self)
      .def(py::self != py::self)
      .def("__str__", &TensorDistAttr::to_string);
@@ -201,20 +215,23 @@ void BindAutoParallel(py::module *m) {
      .def_property("impl_idx",
                    &OperatorDistAttr::impl_idx,
                    &OperatorDistAttr::set_impl_idx)
      .def_property("annotated",
                    &OperatorDistAttr::annotated,
                    &OperatorDistAttr::set_annotated)
      .def_property("inputs_dist_attrs",
                    &OperatorDistAttr::input_dist_attrs,
                    &OperatorDistAttr::set_input_dist_attrs)
      .def_property("outputs_dist_attrs",
                    &OperatorDistAttr::output_dist_attrs,
                    &OperatorDistAttr::set_output_dist_attrs)
      .def("input", &OperatorDistAttr::input)
      .def("output", &OperatorDistAttr::output)
      .def("get_input_dist_attr",
           static_cast<TensorDistAttr &(
               OperatorDistAttr::*)(const std::string &)>(
               &OperatorDistAttr::input_dist_attr),
           py::return_value_policy::reference)
      .def("get_output_dist_attr",
           static_cast<TensorDistAttr &(
               OperatorDistAttr::*)(const std::string &)>(
               &OperatorDistAttr::output_dist_attr),
           py::return_value_policy::reference)
@@ -223,9 +240,25 @@ void BindAutoParallel(py::module *m) {
      .def("set_output_dist_attr", &OperatorDistAttr::set_output_dist_attr)
      .def("is_annotated", &OperatorDistAttr::is_annotated)
      .def("annotate", &OperatorDistAttr::annotate)
      .def("get_input_dims_mapping", &OperatorDistAttr::input_dims_mapping)
      .def("set_input_dims_mapping", &OperatorDistAttr::set_input_dims_mapping)
      .def("get_output_dims_mapping", &OperatorDistAttr::output_dims_mapping)
      .def("set_output_dims_mapping",
           &OperatorDistAttr::set_output_dims_mapping)
      .def("verify", &OperatorDistAttr::verify)
      .def("serialize_to_string",
           [](OperatorDistAttr &self) {
             return py::bytes(self.serialize_to_string());
           })
      .def("parse_from_string", &OperatorDistAttr::parse_from_string)
      .def(py::self == py::self)
      .def(py::self != py::self)
      .def(
          "__deepcopy__",
          [](const OperatorDistAttr &self, py::dict) {
            return OperatorDistAttr(self);
          },
          py::arg("memo"))
      .def("__str__", &OperatorDistAttr::to_string);
}
......
@@ -21,8 +21,10 @@ from paddle.distributed.passes import PassContext
from .dist_tensor import DistributedTensor
from .dist_op import DistributedOperator
from .process_mesh import ProcessMesh
from .utils import _copy_dist_attr_to_cpp
from .utils import is_loss_grad_op

# There is always a default context for the user, and the user can set it to another one.
_g_default_distributed_context = None
@@ -76,6 +78,7 @@ class DistributedContext:
        self._serial_optimizer = None
        self._serial_feed_vars = {}
        self._serial_fetch_vars = {}
        self._lr_optimizer = None  # record the optimizer holding lr_scheduler

        # Data members related to the program
        self._dist_tensors_for_program = {}
@@ -392,7 +395,7 @@ class DistributedContext:
            if dist:
                self._restore_dist_info(dist_mode)

    def initialize(self, with_graph=True, with_cpp=False):
        if not self._is_initialized:
            if not self._serial_main_program:
                if self._original_serial_main_program:
@@ -425,6 +428,10 @@ class DistributedContext:
            self._ops_ids = list(self._dist_ops_for_program.keys())
            self._is_initialized = True

            # TODO: This will be removed in the future
            if with_cpp:
                _copy_dist_attr_to_cpp(self)

            if with_graph:
                set_flags({"FLAGS_convert_all_blocks": True})
                self._serial_graph = framework.IrGraph(
@@ -597,7 +604,11 @@ class DistributedContext:
                    tensor
                )
                if default_dist_tensor and default_ctx is not self:
                    dist_tensor = DistributedTensor(tensor)
                    dist_tensor.dist_attr = copy.deepcopy(
                        default_dist_tensor.dist_attr
                    )
                    self.add_dist_tensor_for_program(dist_tensor)
                current_dist_tensor = self.get_dist_tensor_for_program(tensor)
                if current_dist_tensor is None:
                    dist_tensor = DistributedTensor(tensor)
@@ -606,7 +617,9 @@ class DistributedContext:
                # Copy the distributed operators in the default context
                default_dist_op = default_ctx.get_dist_op_for_program(op)
                if default_dist_op and default_ctx is not self:
                    dist_op = DistributedOperator(op)
                    dist_op.dist_attr = copy.deepcopy(default_dist_op.dist_attr)
                    self.add_dist_op_for_program(dist_op)
                current_dist_op = self.get_dist_op_for_program(op)
                if current_dist_op is None:
                    dist_op = DistributedOperator(op)
......
@@ -1907,3 +1907,120 @@ def validate_opt(optimizer):
        optimizer._parameter_list = None
        optimizer._param_groups = None
    return optimizer
def _copy_tensor_dist_attr_to_cpp(cpp_dist_attr, py_dist_attr):
py_process_mesh = py_dist_attr.process_mesh
if py_process_mesh is not None:
cpp_dist_attr.process_mesh = core.ProcessMesh(
py_process_mesh.shape,
py_process_mesh.process_ids,
["d" + str(i) for i in range(len(py_process_mesh.shape))],
)
cpp_dist_attr.dims_mapping = py_dist_attr.dims_mapping
cpp_dist_attr.annotated = py_dist_attr._is_annotated
def _copy_tensor_dist_attr_from_cpp(cpp_dist_attr, py_dist_attr):
from .process_mesh import ProcessMesh
cpp_process_mesh = cpp_dist_attr.process_mesh
if not cpp_process_mesh.empty():
py_dist_attr.process_mesh = ProcessMesh(
shape=cpp_process_mesh.shape,
process_ids=cpp_process_mesh.process_ids,
)
py_dist_attr.dims_mapping = cpp_dist_attr.dims_mapping
py_dist_attr._is_annotated = cpp_dist_attr.annotated
def _copy_op_dist_attr_to_cpp(cpp_dist_attr, py_dist_attr):
py_process_mesh = py_dist_attr.process_mesh
if py_process_mesh is not None:
cpp_dist_attr.process_mesh = core.ProcessMesh(
py_process_mesh.shape,
py_process_mesh.process_ids,
["d" + str(i) for i in range(len(py_process_mesh.shape))],
)
cpp_dist_attr.impl_type = py_dist_attr.impl_type
cpp_dist_attr.impl_idx = py_dist_attr.impl_idx
cpp_dist_attr.annotated = py_dist_attr._is_annotated
for name, py_tensor_dist_attr in py_dist_attr.inputs_dist_attrs.items():
cpp_tensor_dist_attr = cpp_dist_attr.get_input_dist_attr(name)
_copy_tensor_dist_attr_to_cpp(cpp_tensor_dist_attr, py_tensor_dist_attr)
for name, py_tensor_dist_attr in py_dist_attr.outputs_dist_attrs.items():
cpp_tensor_dist_attr = cpp_dist_attr.get_output_dist_attr(name)
_copy_tensor_dist_attr_to_cpp(cpp_tensor_dist_attr, py_tensor_dist_attr)
def _copy_op_dist_attr_from_cpp(cpp_dist_attr, py_dist_attr):
from .process_mesh import ProcessMesh
cpp_process_mesh = cpp_dist_attr.process_mesh
if not cpp_process_mesh.empty():
py_dist_attr.process_mesh = ProcessMesh(
shape=cpp_process_mesh.shape,
process_ids=cpp_process_mesh.process_ids,
)
py_dist_attr.impl_type = cpp_dist_attr.impl_type
py_dist_attr.impl_idx = cpp_dist_attr.impl_idx
py_dist_attr._is_annotated = cpp_dist_attr.annotated
py_dist_attr.op_type = cpp_dist_attr.op.type()
for name, cpp_tensor_dist_attr in cpp_dist_attr.inputs_dist_attrs.items():
py_tensor_dist_attr = py_dist_attr.get_input_dist_attr(name)
_copy_tensor_dist_attr_from_cpp(
cpp_tensor_dist_attr, py_tensor_dist_attr
)
for name, cpp_tensor_dist_attr in cpp_dist_attr.outputs_dist_attrs.items():
py_tensor_dist_attr = py_dist_attr.get_output_dist_attr(name)
_copy_tensor_dist_attr_from_cpp(
cpp_tensor_dist_attr, py_tensor_dist_attr
)
def _copy_dist_attr_to_cpp(dist_context):
for dist_tensor in dist_context._dist_tensors_for_program.values():
_copy_tensor_dist_attr_to_cpp(
dist_tensor.serial_tensor.dist_attr, dist_tensor.dist_attr
)
for dist_op in dist_context._dist_ops_for_program.values():
_copy_op_dist_attr_to_cpp(
dist_op.serial_op.dist_attr, dist_op.dist_attr
)
def _copy_dist_attr_from_cpp(dist_context):
for dist_tensor in dist_context._dist_tensors_for_program.values():
_copy_tensor_dist_attr_from_cpp(
dist_tensor.serial_tensor.dist_attr, dist_tensor.dist_attr
)
for dist_op in dist_context._dist_ops_for_program.values():
_copy_op_dist_attr_from_cpp(
dist_op.serial_op.dist_attr, dist_op.dist_attr
)
def _copy_dist_attr_to_cpp_for_graph(dist_context):
for node in dist_context.serial_ordered_nodes:
if node.is_var() and node.var() is not None:
py_dist_attr = dist_context.get_tensor_dist_attr_for_graph(node)
cpp_dist_attr = node.var().dist_attr
_copy_tensor_dist_attr_to_cpp(cpp_dist_attr, py_dist_attr)
if node.is_op() and node.op() is not None:
py_dist_attr = dist_context.get_op_dist_attr_for_graph(node)
cpp_dist_attr = node.op().dist_attr
_copy_op_dist_attr_to_cpp(cpp_dist_attr, py_dist_attr)
def _copy_dist_attr_from_cpp_for_graph(dist_context):
for node in dist_context.serial_ordered_nodes:
if node.is_var() and node.var() is not None:
py_dist_attr = dist_context.get_tensor_dist_attr_for_graph(node)
cpp_dist_attr = node.var().dist_attr
_copy_tensor_dist_attr_from_cpp(cpp_dist_attr, py_dist_attr)
if node.is_op() and node.op() is not None:
py_dist_attr = dist_context.get_op_dist_attr_for_graph(node)
cpp_dist_attr = node.op().dist_attr
_copy_op_dist_attr_from_cpp(cpp_dist_attr, py_dist_attr)
@@ -2612,7 +2612,7 @@ class Variable(metaclass=VariableMetaClass):
        """Get the names of all attributes defined."""
        return self.desc.attr_names()

    def attr(self, name):
        """
        Get the attribute by name.
......
@@ -103,6 +103,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
  py_test_modules(test_cluster_v2 MODULES test_cluster_v2)
  py_test_modules(test_process_mesh_v2 MODULES test_process_mesh_v2)
  py_test_modules(test_dist_attr_v2 MODULES test_dist_attr_v2)
  py_test_modules(test_serialization MODULES test_serialization)
  py_test_modules(test_lr_grad_clip MODULES test_lr_grad_clip)
  py_test_modules(test_dist_matmul MODULES test_dist_matmul)
  py_test_modules(test_process_mesh MODULES test_process_mesh)
......
@@ -13,15 +13,188 @@
# limitations under the License

import unittest
import copy
import paddle
import numpy as np
import paddle.nn as nn
import paddle.static as static
import paddle.nn.functional as F
from paddle.distributed import fleet
from paddle.distributed.fleet import auto
from paddle.distributed.auto_parallel.dist_context import (
DistributedContext,
set_default_distributed_context,
)
from paddle.distributed.auto_parallel.utils import (
_copy_dist_attr_to_cpp,
_copy_dist_attr_from_cpp,
_copy_dist_attr_to_cpp_for_graph,
_copy_dist_attr_from_cpp_for_graph,
)
from paddle.fluid.core import TensorDistAttr
from paddle.fluid.core import OperatorDistAttr
from paddle.distributed.auto_parallel.process_mesh_v2 import ProcessMesh

paddle.enable_static()
batch_size = 4
epoch_num = 10
hidden_size = 1024
sequence_len = 512
_g_process_mesh = auto.ProcessMesh(mesh=[[0, 1], [2, 3]], dim_names=['x', 'y'])
class MLPLayer(nn.Layer):
def __init__(
self,
hidden_size=1024,
intermediate_size=4 * 1024,
dropout_ratio=0.1,
initializer_range=0.02,
):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
param_initializer = nn.initializer.Normal(
mean=0.0, std=initializer_range
)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
self.linear0 = nn.Linear(
d_model,
dim_feedforward,
weight_attr=paddle.ParamAttr(initializer=param_initializer),
bias_attr=None,
)
self.linear1 = nn.Linear(
dim_feedforward,
d_model,
weight_attr=paddle.ParamAttr(initializer=param_initializer),
bias_attr=None,
)
def forward(self, input):
out = self.norm(input)
auto.shard_tensor(
self.linear0.weight,
process_mesh=_g_process_mesh[0],
shard_spec=[None, 'y'],
)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
auto.shard_tensor(
self.linear1.weight,
process_mesh=_g_process_mesh[1],
shard_spec=['y', None],
)
out = self.linear1(out)
return out
def get_random_inputs_and_labels(input_shape, label_shape):
input = np.random.random(size=input_shape).astype('float32')
label = np.random.random(size=label_shape).astype('float32')
return input, label
def batch_generator_creator():
def __reader__():
for _ in range(batch_size):
batch_input, batch_label = get_random_inputs_and_labels(
[batch_size, sequence_len, hidden_size],
[batch_size, sequence_len, 1],
)
yield batch_input, batch_label
return __reader__
def get_program():
dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True
# fleet.init(is_collective=True, strategy=dist_strategy)
train_program = static.Program()
start_program = static.Program()
with static.program_guard(train_program, start_program):
# input
input = static.data(
name="input",
shape=[batch_size, sequence_len, hidden_size],
dtype='float32',
)
label = static.data(
name="label", shape=[batch_size, sequence_len, 1], dtype='float32'
)
data_holder = [input, label]
# dataloader
dataloader = paddle.io.DataLoader.from_generator(
feed_list=data_holder, capacity=4 * batch_size, iterable=False
)
dataloader.set_batch_generator(
batch_generator_creator(), places=paddle.static.cuda_places()
)
# data dist_attr
auto.shard_tensor(
input, process_mesh=_g_process_mesh[0], shard_spec=['y', None, None]
)
auto.shard_tensor(
label, process_mesh=_g_process_mesh[0], shard_spec=['y', None, None]
)
mlp_start = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02,
)
pred = mlp_start(input)
mlp_mid = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02,
)
pred = mlp_mid(pred)
mlp_end = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02,
)
pred = mlp_end(pred)
error_cost = paddle.nn.functional.square_error_cost(pred, label)
loss = paddle.mean(error_cost)
optimizer = paddle.optimizer.Adam(
learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None,
)
feed_vars = {"inputs": [input], "labels": [label]}
fetch_vars = {"loss": [loss]}
return (
train_program,
start_program,
dataloader,
loss,
optimizer,
feed_vars,
fetch_vars,
)
class TestDistAttr(unittest.TestCase):
def test_tensor_dist_attr_ctor(self):
...@@ -102,23 +275,25 @@ class TestDistAttr(unittest.TestCase):
op_dist_attr.set_output_dist_attr(output.name, output_dist_attr)
self.assertEqual(op_dist_attr.process_mesh, process_mesh)
self.assertEqual(
- op_dist_attr.input_dist_attr(input.name).process_mesh, process_mesh
+ op_dist_attr.get_input_dist_attr(input.name).process_mesh,
+ process_mesh,
)
self.assertEqual(
- op_dist_attr.input_dist_attr(input1.name).process_mesh, process_mesh
+ op_dist_attr.get_input_dist_attr(input1.name).process_mesh,
+ process_mesh,
)
self.assertEqual(
- op_dist_attr.output_dist_attr(output.name).process_mesh,
+ op_dist_attr.get_output_dist_attr(output.name).process_mesh,
process_mesh,
)
self.assertEqual(
- op_dist_attr.input_dist_attr(input.name).dims_mapping, [0, -1]
+ op_dist_attr.get_input_dist_attr(input.name).dims_mapping, [0, -1]
)
self.assertEqual(
- op_dist_attr.input_dist_attr(input1.name).dims_mapping, [-1, 1]
+ op_dist_attr.get_input_dist_attr(input1.name).dims_mapping, [-1, 1]
)
self.assertEqual(
- op_dist_attr.output_dist_attr(output.name).dims_mapping, [0, 1]
+ op_dist_attr.get_output_dist_attr(output.name).dims_mapping, [0, 1]
)
self.assertTrue(op_dist_attr.verify())
self.assertTrue(str(op_dist_attr), str(op_dist_attr))
...@@ -126,13 +301,13 @@ class TestDistAttr(unittest.TestCase):
op_dist_attr = OperatorDistAttr(op.desc)
op_dist_attr.process_mesh = process_mesh
# Set the distributed attribute of input directly
- input_dist_attr = op_dist_attr.input_dist_attr(input.name)
+ input_dist_attr = op_dist_attr.get_input_dist_attr(input.name)
input_dist_attr.dims_mapping = [-1, 0]
# Set the distributed attribute of input1 directly
- input1_dist_attr = op_dist_attr.input_dist_attr(input1.name)
+ input1_dist_attr = op_dist_attr.get_input_dist_attr(input1.name)
input1_dist_attr.dims_mapping = [0, -1]
# Set the distributed attribute of output directly
- output_dist_attr = op_dist_attr.output_dist_attr(output.name)
+ output_dist_attr = op_dist_attr.get_output_dist_attr(output.name)
output_dist_attr.dims_mapping = [-1, -1]
self.assertEqual(op_dist_attr.process_mesh, process_mesh)
self.assertEqual(input_dist_attr.process_mesh, process_mesh)
...@@ -171,22 +346,25 @@ class TestDistAttr(unittest.TestCase):
self.assertEqual(op.desc.dist_attr.process_mesh, process_mesh)
self.assertEqual(
- op.dist_attr.input_dist_attr(input.name).process_mesh, process_mesh
+ op.dist_attr.get_input_dist_attr(input.name).process_mesh,
+ process_mesh,
)
self.assertEqual(
- op.dist_attr.input_dist_attr(input1.name).process_mesh, process_mesh
+ op.dist_attr.get_input_dist_attr(input1.name).process_mesh,
+ process_mesh,
)
self.assertEqual(
- op.dist_attr.input_dist_attr(input.name).dims_mapping, [0, -1]
+ op.dist_attr.get_input_dist_attr(input.name).dims_mapping, [0, -1]
)
self.assertEqual(
- op.dist_attr.input_dist_attr(input.name).dims_mapping, [0, -1]
+ op.dist_attr.get_input_dist_attr(input.name).dims_mapping, [0, -1]
)
self.assertEqual(
- op.desc.dist_attr.input_dist_attr(input1.name).dims_mapping, [-1, 1]
+ op.desc.dist_attr.get_input_dist_attr(input1.name).dims_mapping,
+ [-1, 1],
)
self.assertEqual(
- op.dist_attr.output_dist_attr(output.name).dims_mapping, [0, 1]
+ op.dist_attr.get_output_dist_attr(output.name).dims_mapping, [0, 1]
)
self.assertTrue(op.desc.dist_attr.verify())
self.assertTrue(str(op_dist_attr), str(op_dist_attr))
...@@ -195,5 +373,80 @@ class TestDistAttr(unittest.TestCase):
self.assertEqual(op.desc.dist_attr, OperatorDistAttr(op.desc))
class TestDistAttrConversion(unittest.TestCase):
def test_dist_attr_conversion_for_program(self):
set_default_distributed_context(DistributedContext())
(
train_program,
start_program,
dataloader,
loss,
optimizer,
feed_vars,
fetch_vars,
) = get_program()
dist_context = DistributedContext(
train_program, start_program, optimizer, loss, feed_vars, fetch_vars
)
dist_context.initialize()
original_dist_tensors = copy.deepcopy(
dist_context._dist_tensors_for_program
)
original_dist_ops = copy.deepcopy(dist_context._dist_ops_for_program)
_copy_dist_attr_to_cpp(dist_context)
_copy_dist_attr_from_cpp(dist_context)
for dist_tensor in dist_context._dist_tensors_for_program.values():
original_dist_tensor = original_dist_tensors[
dist_tensor.serial_tensor.desc.original_id()
]
self.assertEqual(
dist_tensor.dist_attr, original_dist_tensor.dist_attr
)
for dist_op in dist_context._dist_ops_for_program.values():
original_dist_op = original_dist_ops[
dist_op.serial_op.desc.original_id()
]
self.assertEqual(dist_op.dist_attr, original_dist_op.dist_attr)
def test_dist_attr_conversion_for_graph(self):
set_default_distributed_context(DistributedContext())
(
train_program,
start_program,
dataloader,
loss,
optimizer,
feed_vars,
fetch_vars,
) = get_program()
dist_context = DistributedContext(
train_program, start_program, optimizer, loss, feed_vars, fetch_vars
)
dist_context.initialize()
original_dist_tensors = copy.deepcopy(
dist_context._dist_tensors_for_graph
)
original_dist_ops = copy.deepcopy(dist_context._dist_ops_for_graph)
_copy_dist_attr_to_cpp_for_graph(dist_context)
_copy_dist_attr_from_cpp_for_graph(dist_context)
for (
node_id,
dist_tensor,
) in dist_context._dist_tensors_for_graph.items():
original_dist_tensor = original_dist_tensors[node_id]
self.assertEqual(
dist_tensor.dist_attr, original_dist_tensor.dist_attr
)
for node_id, dist_op in dist_context._dist_ops_for_graph.items():
original_dist_op = original_dist_ops[node_id]
self.assertEqual(dist_op.dist_attr, original_dist_op.dist_attr)
if __name__ == "__main__":
unittest.main()
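Illustrative sketch (not from the patch): the renamed accessors get_input_dist_attr/get_output_dist_attr replace the former input_dist_attr/output_dist_attr calls exercised above. It assumes the same paddle.fluid.core bindings imported in this test file; the program, tensor names, and values are made up for the example and simply mirror the assertions in test_operator_dist_attr_ctor.

import paddle
import paddle.static as static
from paddle.fluid.core import TensorDistAttr, OperatorDistAttr
from paddle.distributed.auto_parallel.process_mesh_v2 import ProcessMesh

paddle.enable_static()

prog = static.Program()
with static.program_guard(prog):
    x = static.data(name="x", shape=[2, 3], dtype="float32")
    y = static.data(name="y", shape=[3, 4], dtype="float32")
    out = paddle.matmul(x, y)

op = prog.current_block().ops[0]              # the matmul op
op_dist_attr = OperatorDistAttr(op.desc)
op_dist_attr.process_mesh = ProcessMesh([[0, 1, 2], [3, 4, 5]])

# Attach per-tensor dist attrs, then read them back with the get_* accessors.
x_attr = TensorDistAttr(x.desc)
x_attr.dims_mapping = [0, -1]
op_dist_attr.set_input_dist_attr(x.name, x_attr)

y_attr = TensorDistAttr(y.desc)
y_attr.dims_mapping = [-1, 1]
op_dist_attr.set_input_dist_attr(y.name, y_attr)

out_attr = TensorDistAttr(out.desc)
out_attr.dims_mapping = [0, 1]
op_dist_attr.set_output_dist_attr(out.name, out_attr)

assert op_dist_attr.get_input_dist_attr(x.name).dims_mapping == [0, -1]
assert op_dist_attr.get_output_dist_attr(out.name).dims_mapping == [0, 1]
assert op_dist_attr.verify()                  # same check the test relies on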
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import unittest
import paddle
import numpy as np
import paddle.nn as nn
import paddle.static as static
import paddle.nn.functional as F
from paddle.fluid.framework import Program
import paddle.distributed.fleet as fleet
from paddle.distributed.fleet import auto
from paddle.distributed.auto_parallel.dist_context import (
DistributedContext,
set_default_distributed_context,
)
from paddle.fluid.core import TensorDistAttr
from paddle.distributed.auto_parallel.process_mesh_v2 import ProcessMesh
paddle.enable_static()
batch_size = 4
epoch_num = 10
hidden_size = 1024
sequence_len = 512
_g_process_mesh = auto.ProcessMesh(mesh=[[0, 1], [2, 3]], dim_names=['x', 'y'])
class MLPLayer(nn.Layer):
def __init__(
self,
hidden_size=1024,
intermediate_size=4 * 1024,
dropout_ratio=0.1,
initializer_range=0.02,
):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
param_initializer = nn.initializer.Normal(
mean=0.0, std=initializer_range
)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
self.linear0 = nn.Linear(
d_model,
dim_feedforward,
weight_attr=paddle.ParamAttr(initializer=param_initializer),
bias_attr=None,
)
self.linear1 = nn.Linear(
dim_feedforward,
d_model,
weight_attr=paddle.ParamAttr(initializer=param_initializer),
bias_attr=None,
)
def forward(self, input):
out = self.norm(input)
auto.shard_tensor(
self.linear0.weight,
process_mesh=_g_process_mesh[0],
shard_spec=[None, 'y'],
)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
auto.shard_tensor(
self.linear1.weight,
process_mesh=_g_process_mesh[1],
shard_spec=['y', None],
)
out = auto.shard_op(self.linear1, process_mesh=_g_process_mesh)(out)
return out
def get_random_inputs_and_labels(input_shape, label_shape):
input = np.random.random(size=input_shape).astype('float32')
label = np.random.random(size=label_shape).astype('float32')
return input, label
def batch_generator_creator():
def __reader__():
for _ in range(batch_size):
batch_input, batch_label = get_random_inputs_and_labels(
[batch_size, sequence_len, hidden_size],
[batch_size, sequence_len, 1],
)
yield batch_input, batch_label
return __reader__
def get_program():
dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True
# fleet.init(is_collective=True, strategy=dist_strategy)
train_program = static.Program()
start_program = static.Program()
with static.program_guard(train_program, start_program):
# input
input = static.data(
name="input",
shape=[batch_size, sequence_len, hidden_size],
dtype='float32',
)
label = static.data(
name="label", shape=[batch_size, sequence_len, 1], dtype='float32'
)
data_holder = [input, label]
# dataloader
dataloader = paddle.io.DataLoader.from_generator(
feed_list=data_holder, capacity=4 * batch_size, iterable=False
)
dataloader.set_batch_generator(
batch_generator_creator(), places=paddle.static.cuda_places()
)
# data dist_attr
auto.shard_tensor(
input, process_mesh=_g_process_mesh[0], shard_spec=['y', None, None]
)
auto.shard_tensor(
label, process_mesh=_g_process_mesh[0], shard_spec=['y', None, None]
)
mlp_start = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02,
)
pred = mlp_start(input)
mlp_mid = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02,
)
pred = mlp_mid(pred)
mlp_end = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02,
)
pred = mlp_end(pred)
error_cost = paddle.nn.functional.square_error_cost(pred, label)
loss = paddle.mean(error_cost)
optimizer = paddle.optimizer.Adam(
learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None,
)
feed_vars = {"inputs": [input], "labels": [label]}
fetch_vars = {"loss": [loss]}
return (
train_program,
start_program,
dataloader,
loss,
optimizer,
feed_vars,
fetch_vars,
)
class TestDistAttrSerialization(unittest.TestCase):
def test_serialization_tensor(self):
train_program = static.Program()
start_program = static.Program()
with static.program_guard(train_program, start_program):
input = static.data(name="input", shape=[2, 3], dtype='float32')
dist_attr = input.dist_attr
dist_attr.process_mesh = ProcessMesh([[0, 1, 2], [3, 4, 5]])
dist_attr.dims_mapping = [0, -1]
dist_attr.batch_dim = 1
dist_attr.dynamic_dims = [1, 1]
dist_attr_data = dist_attr.serialize_to_string()
def test_serialization_operator(self):
train_program = static.Program()
start_program = static.Program()
with static.program_guard(train_program, start_program):
input = static.data(name="input", shape=[2, 3], dtype='float32')
input1 = static.data(name="input1", shape=[3, 4], dtype='float32')
output = paddle.matmul(input, input1)
op = train_program.current_block().ops[0]
process_mesh = ProcessMesh([[0, 1, 2], [3, 4, 5]])
op_dist_attr = op.dist_attr
op_dist_attr.process_mesh = process_mesh
# Set the distributed attribute of input
input_dist_attr = TensorDistAttr(input.desc)
input_dist_attr.dims_mapping = [0, -1]
op_dist_attr.set_input_dist_attr(input.name, input_dist_attr)
# Set the distributed attribute of input1
input1_dist_attr = TensorDistAttr(input1.desc)
input1_dist_attr.dims_mapping = [-1, 1]
op_dist_attr.set_input_dist_attr(input1.name, input1_dist_attr)
# Set the distributed attribute of output
output_dist_attr = TensorDistAttr(output.desc)
output_dist_attr.dims_mapping = [0, 1]
op_dist_attr.set_output_dist_attr(output.name, output_dist_attr)
def test_serialization_program(self):
set_default_distributed_context(DistributedContext())
(
train_program,
start_program,
dataloader,
loss,
optimizer,
feed_vars,
fetch_vars,
) = get_program()
dist_context = DistributedContext(
train_program, start_program, optimizer, loss, feed_vars, fetch_vars
)
dist_context.initialize(with_cpp=True)
# The distributed context clones the original train program into serial_main_program
original_program = dist_context.serial_main_program
for block in original_program.blocks:
for tensor in block.vars.values():
dist_attr_data = tensor.dist_attr.serialize_to_string()
tensor._set_attr("dist_attr", dist_attr_data)
for op in block.ops:
dist_attr_data = op.dist_attr.serialize_to_string()
op._set_attr("dist_attr", dist_attr_data)
program_data = original_program.desc.serialize_to_string()
program = Program.parse_from_string(program_data)
for block in program.blocks:
for tensor in block.vars.values():
dist_attr_data = tensor.attr("dist_attr")
tensor._remove_attr("dist_attr")
tensor.dist_attr.parse_from_string(dist_attr_data)
for op in block.ops:
dist_attr_data = op.attr("dist_attr")
op._remove_attr("dist_attr")
op.dist_attr.parse_from_string(dist_attr_data)
self.assertEqual(len(original_program.blocks), len(program.blocks))
for original_block, block in zip(
original_program.blocks, program.blocks
):
self.assertEqual(
len(original_block.vars.values()), len(block.vars.values())
)
for original_tensor in original_block.vars.values():
self.assertEqual(
original_tensor.dist_attr,
block.vars[original_tensor.name].dist_attr,
)
self.assertEqual(len(original_block.ops), len(block.ops))
for original_op, op in zip(original_block.ops, block.ops):
self.assertEqual(original_op.dist_attr, op.dist_attr)
if __name__ == "__main__":
unittest.main()
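Illustrative sketch (not from the patch): the per-tensor counterpart of the program-level round trip in test_serialization_program, assuming the serialize_to_string()/parse_from_string() bindings used above; the program and names are made up for the example.

import paddle
import paddle.static as static
from paddle.fluid.core import TensorDistAttr
from paddle.distributed.auto_parallel.process_mesh_v2 import ProcessMesh

paddle.enable_static()

prog = static.Program()
with static.program_guard(prog):
    x = static.data(name="x", shape=[2, 3], dtype="float32")

dist_attr = x.dist_attr
dist_attr.process_mesh = ProcessMesh([[0, 1, 2], [3, 4, 5]])
dist_attr.dims_mapping = [0, -1]
dist_attr.batch_dim = 1
dist_attr.dynamic_dims = [1, 1]

data = dist_attr.serialize_to_string()   # protobuf-encoded string

restored = TensorDistAttr(x.desc)        # starts from the defaults
restored.parse_from_string(data)         # fills in the serialized fields (in place, as in the test)
assert restored.dims_mapping == [0, -1]
assert restored.batch_dim == 1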
...@@ -19,6 +19,8 @@ import paddle
import paddle.fluid.layers as layers
import paddle.fluid as fluid
paddle.enable_static()
main_program = default_main_program()
...@@ -228,15 +230,13 @@ class TestProgramProto(unittest.TestCase):
b = program.desc.serialize_to_string()
self.assertFalse(a == b)
- # it seems the attrs of framework::VarDesc is not write to proto,
- # except for persistable/need_check_feed/is_parameter/stop_gradient
def test_update_var_attr(self):
program = build_program()
a = program.desc.serialize_to_string()
program.current_block().var("x").desc._set_attr("a", 1)
- self.assertFalse(program.desc.need_update())
+ self.assertTrue(program.desc.need_update())
b = program.desc.serialize_to_string()
- self.assertTrue(a == b)  # not affected
+ self.assertFalse(a == b)
class TestProgramHash(unittest.TestCase):
......
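Illustrative sketch (not from the patch): the two changed assertions above encode the new behavior that a custom attribute set on a VarDesc now marks the ProgramDesc as needing update and is written to the serialized proto. This sketch mirrors the updated test with a minimal hand-built program; the program and names are made up for the example.

import paddle
import paddle.static as static

paddle.enable_static()

prog = static.Program()
with static.program_guard(prog):
    x = static.data(name="x", shape=[2, 3], dtype="float32")

before = prog.desc.serialize_to_string()
prog.current_block().var("x").desc._set_attr("a", 1)  # custom VarDesc attribute
assert prog.desc.need_update()        # the desc is now flagged for re-serialization
after = prog.desc.serialize_to_string()
assert before != after                # the attribute shows up in the serialized proto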