From 2c77b5753ddfb85513c36fc0d0200b239d4d874e Mon Sep 17 00:00:00 2001 From: Yulong Ao Date: Tue, 9 Aug 2022 14:12:46 +0800 Subject: [PATCH] [Auto Parallel] Add the c++ dist attrs (#44989) * [Auto Parallel] Add the c++ dist attrs * [Auto Parallel] Remove some codes to be less than 1000 lines --- .../distributed/auto_parallel/CMakeLists.txt | 9 + .../distributed/auto_parallel/dist_attr.cc | 533 ++++++++++++++++++ .../distributed/auto_parallel/dist_attr.h | 239 ++++++++ .../auto_parallel/dist_attr_test.cc | 142 +++++ .../fluid/distributed/auto_parallel/utils.h | 9 + 5 files changed, 932 insertions(+) create mode 100644 paddle/fluid/distributed/auto_parallel/dist_attr.cc create mode 100644 paddle/fluid/distributed/auto_parallel/dist_attr.h create mode 100644 paddle/fluid/distributed/auto_parallel/dist_attr_test.cc diff --git a/paddle/fluid/distributed/auto_parallel/CMakeLists.txt b/paddle/fluid/distributed/auto_parallel/CMakeLists.txt index 49f45476316..976e76f8931 100644 --- a/paddle/fluid/distributed/auto_parallel/CMakeLists.txt +++ b/paddle/fluid/distributed/auto_parallel/CMakeLists.txt @@ -16,6 +16,15 @@ cc_test( SRCS process_mesh_test.cc DEPS process_mesh) +cc_library( + dist_attr + SRCS dist_attr.cc + DEPS process_mesh auto_parallel_proto proto_desc) +cc_test( + dist_attr_test + SRCS dist_attr_test.cc + DEPS dist_attr) + cc_library( dist_mapper SRCS dist_mapper.cc diff --git a/paddle/fluid/distributed/auto_parallel/dist_attr.cc b/paddle/fluid/distributed/auto_parallel/dist_attr.cc new file mode 100644 index 00000000000..9f9609962fc --- /dev/null +++ b/paddle/fluid/distributed/auto_parallel/dist_attr.cc @@ -0,0 +1,533 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include + +#include "paddle/fluid/distributed/auto_parallel/dist_attr.h" +#include "paddle/fluid/framework/block_desc.h" +#include "paddle/fluid/framework/op_desc.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/var_desc.h" + +namespace paddle { +namespace distributed { +namespace auto_parallel { + +std::vector TensorDistAttr::fields_{ + "process_mesh", "dims_mapping", "batch_dim", "dynamic_dims"}; + +TensorDistAttr::TensorDistAttr(const VarDesc& tensor) + : tensor_(&tensor), batch_dim_(0) { + set_default_dims_mapping(); + std::vector tensor_shape = tensor_->GetShape(); + for (std::size_t i = 0; i < tensor_shape.size(); ++i) { + dynamic_dims_.push_back(false); + } +} + +TensorDistAttr::TensorDistAttr(const TensorDistAttr& dist_attr) { + if (tensor_ == nullptr) { + tensor_ = dist_attr.tensor(); + } + set_process_mesh(dist_attr.process_mesh()); + set_dims_mapping(dist_attr.dims_mapping()); + set_batch_dim(dist_attr.batch_dim()); + set_dynamic_dims(dist_attr.dynamic_dims()); + set_annotated(dist_attr.annotated()); +} + +TensorDistAttr& TensorDistAttr::operator=(const TensorDistAttr& dist_attr) { + if (tensor_ == nullptr) { + tensor_ = dist_attr.tensor(); + } + set_process_mesh(dist_attr.process_mesh()); + set_dims_mapping(dist_attr.dims_mapping()); + set_batch_dim(dist_attr.batch_dim()); + set_dynamic_dims(dist_attr.dynamic_dims()); + set_annotated(dist_attr.annotated()); + return *this; +} + +void TensorDistAttr::set_process_mesh(const ProcessMesh& process_mesh) { + 
PADDLE_ENFORCE_EQ(verify_process_mesh(process_mesh), + true, + platform::errors::InvalidArgument( + "Wrong process mesh %s.", process_mesh.to_string())); + process_mesh_ = process_mesh; +} + +void TensorDistAttr::set_dims_mapping( + const std::vector& dims_mapping) { + PADDLE_ENFORCE_EQ(verify_dims_mapping(dims_mapping), + true, + platform::errors::InvalidArgument("Wrong dims_mapping %s.", + str_join(dims_mapping))); + dims_mapping_ = dims_mapping; +} + +void TensorDistAttr::set_batch_dim(int64_t batch_dim) { + PADDLE_ENFORCE_EQ( + verify_batch_dim(batch_dim), + true, + platform::errors::InvalidArgument( + "Wrong batch_dim %d in this distributed attribute.", batch_dim)); + if (tensor_ != nullptr) { + std::vector tensor_shape = tensor_->GetShape(); + int64_t canonical_batch_dim = canonical_dim(batch_dim, tensor_shape.size()); + batch_dim_ = canonical_batch_dim; + } else { + batch_dim_ = batch_dim; + } +} + +void TensorDistAttr::set_dynamic_dims(const std::vector& dynamic_dims) { + PADDLE_ENFORCE_EQ( + verify_dynamic_dims(dynamic_dims), + true, + platform::errors::InvalidArgument("The dynamic_dims [%s] is wrong.", + str_join(dynamic_dims))); + dynamic_dims_ = dynamic_dims; +} + +void TensorDistAttr::set_annotated( + const std::map& annotated) { + PADDLE_ENFORCE_EQ(verify_annotated(annotated), + true, + platform::errors::InvalidArgument( + "The annotated [%s] is wrong.", str_join(annotated))); + annotated_ = annotated; +} + +void TensorDistAttr::set_default_dims_mapping() { + if (tensor_ != nullptr) { + std::vector tensor_shape = tensor_->GetShape(); + dims_mapping_ = std::vector(tensor_shape.size(), -1); + } +} + +void TensorDistAttr::annotate(const std::string& name) { + auto result = std::find(std::begin(fields_), std::end(fields_), name); + if (result != std::end(fields_)) { + annotated_[name] = true; + } +} + +bool TensorDistAttr::verify_process_mesh( + const ProcessMesh& process_mesh) const { + if (!process_mesh_.empty()) { + for (int64_t dim_mapping : 
dims_mapping_) { + if (dim_mapping < -1 || dim_mapping >= process_mesh_.ndim()) { + return false; + } + } + } + return true; +} + +bool TensorDistAttr::verify_dims_mapping( + const std::vector& dims_mapping) const { + if (tensor_ != nullptr) { + std::vector tensor_shape = tensor_->GetShape(); + if (dims_mapping.size() != tensor_shape.size()) { + return false; + } + } + std::unordered_map map; + if (!process_mesh_.empty()) { + for (int64_t i : dims_mapping) { + if (i < -1 || i >= process_mesh_.ndim()) { + return false; + } + ++map[i]; + if (i != -1 && map[i] > 1) { + return false; + } + } + } else { + for (int64_t i : dims_mapping) { + ++map[i]; + if (i != -1 && map[i] > 1) { + return false; + } + } + } + return true; +} + +bool TensorDistAttr::verify_batch_dim(int64_t dim) const { + if (tensor_ != nullptr) { + std::vector tensor_shape = tensor_->GetShape(); + int64_t ndim = tensor_shape.size(); + if (dim < 0) { + dim = dim + ndim; + } + if (dim < 0 || dim >= ndim) { + return false; + } + } + return true; +} + +bool TensorDistAttr::verify_dynamic_dims( + const std::vector& dynamic_dims) const { + if (tensor_ != nullptr) { + std::vector tensor_shape = tensor_->GetShape(); + if (dynamic_dims.size() != tensor_shape.size()) { + return false; + } + } + return true; +} + +bool TensorDistAttr::verify_annotated( + const std::map& annotated) const { + for (const auto& item : annotated) { + auto result = std::find(std::begin(fields_), std::end(fields_), item.first); + if (result == std::end(fields_)) { + return false; + } + } + return true; +} + +bool TensorDistAttr::verify() const { + if (tensor_ == nullptr) { + return false; + } + if (!verify_process_mesh(process_mesh_)) { + return false; + } + if (!verify_dims_mapping(dims_mapping_)) { + return false; + } + if (!verify_batch_dim(batch_dim_)) { + return false; + } + if (!verify_dynamic_dims(dynamic_dims_)) { + return false; + } + if (!verify_annotated(annotated_)) { + return false; + } + return true; +} + +std::string 
TensorDistAttr::to_string() const { + std::string dist_str; + if (tensor_ != nullptr) { + dist_str = "{tensor_name: " + tensor_->Name() + ", "; + } else { + dist_str = "{tensor_name: None, "; + } + dist_str += "process_mesh: " + process_mesh_.to_string() + ", "; + dist_str += "dims_mappings: [" + str_join(dims_mapping_) + "], "; + dist_str += "batch_dim: " + std::to_string(batch_dim_) + ", "; + dist_str += "dynamic_dims: [" + str_join(dynamic_dims_) + "], "; + dist_str += "annotated: [" + str_join(annotated_) + "]}"; + return dist_str; +} + +bool operator==(const TensorDistAttr& lhs, const TensorDistAttr& rhs) { + if (lhs.process_mesh() != rhs.process_mesh()) { + return false; + } + if (lhs.dims_mapping() != rhs.dims_mapping()) { + return false; + } + if (lhs.batch_dim() != rhs.batch_dim()) { + return false; + } + if (lhs.dynamic_dims() != rhs.dynamic_dims()) { + return false; + } + return true; +} + +std::vector OperatorDistAttr::fields_{ + "process_mesh", "impl_type", "impl_idx"}; + +OperatorDistAttr::OperatorDistAttr(const OpDesc& op) : op_(&op) { + for (std::string name : op_->InputArgumentNames()) { + VarDesc* input = op_->Block()->FindVarRecursive(name); + inputs_[name] = input; + input_dist_attrs_[name] = TensorDistAttr(*input); + } + for (std::string name : op_->OutputArgumentNames()) { + VarDesc* output = op_->Block()->FindVarRecursive(name); + outputs_[name] = output; + output_dist_attrs_[name] = TensorDistAttr(*output); + } + impl_type_ = "default"; + impl_idx_ = 0; +} + +OperatorDistAttr::OperatorDistAttr(const OperatorDistAttr& dist_attr) { + if (op_ == nullptr) { + op_ = dist_attr.op(); + } + for (const auto& item : dist_attr.input_dist_attrs()) { + set_input_dist_attr(item.first, item.second); + } + for (const auto& item : dist_attr.output_dist_attrs()) { + set_output_dist_attr(item.first, item.second); + } + set_process_mesh(dist_attr.process_mesh()); + set_impl_type(dist_attr.impl_type()); + set_impl_idx(dist_attr.impl_idx()); + 
set_annotated(dist_attr.annotated()); +} + +OperatorDistAttr& OperatorDistAttr::operator=( + const OperatorDistAttr& dist_attr) { + if (op_ == nullptr) { + op_ = dist_attr.op(); + } + for (const auto& item : dist_attr.input_dist_attrs()) { + set_input_dist_attr(item.first, item.second); + } + for (const auto& item : dist_attr.output_dist_attrs()) { + set_output_dist_attr(item.first, item.second); + } + set_process_mesh(dist_attr.process_mesh()); + set_impl_type(dist_attr.impl_type()); + set_impl_idx(dist_attr.impl_idx()); + set_annotated(dist_attr.annotated()); + return *this; +} + +void OperatorDistAttr::set_input_dist_attr(const std::string& name, + const TensorDistAttr& dist_attr) { + PADDLE_ENFORCE_EQ( + verify_input_dist_attr(name, dist_attr), + true, + platform::errors::InvalidArgument( + "Wrong dist_attr %s for %s.", dist_attr.to_string(), name)); + input_dist_attrs_[name] = dist_attr; + // Make sure the process mesh of input be same as that of the op + input_dist_attrs_[name].set_process_mesh(process_mesh_); +} + +void OperatorDistAttr::set_output_dist_attr(const std::string& name, + const TensorDistAttr& dist_attr) { + PADDLE_ENFORCE_EQ( + verify_output_dist_attr(name, dist_attr), + true, + platform::errors::InvalidArgument( + "Wrong dist_attr %s for %s.", dist_attr.to_string(), name)); + output_dist_attrs_[name] = dist_attr; + // Make sure the process mesh of output be same as that of the op + output_dist_attrs_[name].set_process_mesh(process_mesh_); +} + +void OperatorDistAttr::set_process_mesh(const ProcessMesh& process_mesh) { + for (auto& item : input_dist_attrs_) { + item.second.set_process_mesh(process_mesh); + } + for (auto& item : output_dist_attrs_) { + item.second.set_process_mesh(process_mesh); + } + process_mesh_ = process_mesh; +} + +void OperatorDistAttr::annotate(const std::string& name) { + auto result = std::find(std::begin(fields_), std::end(fields_), name); + if (result != std::end(fields_)) { + annotated_[name] = true; + } + if (name 
== "process_mesh") { + for (auto& item : input_dist_attrs_) { + item.second.annotate(name); + } + for (auto& item : output_dist_attrs_) { + item.second.annotate(name); + } + } +} + +void OperatorDistAttr::set_annotated( + const std::map& annotated) { + PADDLE_ENFORCE_EQ(verify_annotated(annotated), + true, + platform::errors::InvalidArgument( + "The annotated [%s] is wrong.", str_join(annotated))); + annotated_ = annotated; +} + +bool OperatorDistAttr::verify_input_dist_attr( + const std::string& name, const TensorDistAttr& dist_attr) const { + if (!dist_attr.verify()) { + return false; + } + if (op_ != nullptr) { + if (dist_attr.tensor() != nullptr) { + if (name != dist_attr.tensor()->Name()) { + return false; + } + } + if (input_dist_attrs_.count(name) == 0) { + return false; + } + } + return true; +} + +bool OperatorDistAttr::verify_output_dist_attr( + const std::string& name, const TensorDistAttr& dist_attr) const { + if (!dist_attr.verify()) { + return false; + } + if (op_ != nullptr) { + if (dist_attr.tensor() != nullptr) { + if (name != dist_attr.tensor()->Name()) { + return false; + } + } + if (output_dist_attrs_.count(name) == 0) { + return false; + } + } + return true; +} + +bool OperatorDistAttr::verify_process_mesh( + const ProcessMesh& process_mesh) const { + if (process_mesh != process_mesh_) { + return false; + } + for (auto& item : input_dist_attrs_) { + if (item.second.process_mesh() != process_mesh) { + return false; + } + } + for (auto& item : output_dist_attrs_) { + if (item.second.process_mesh() != process_mesh) { + return false; + } + } + return true; +} + +bool OperatorDistAttr::verify_annotated( + const std::map& annotated) const { + for (const auto& item : annotated) { + auto result = std::find(std::begin(fields_), std::end(fields_), item.first); + if (result == std::end(fields_)) { + return false; + } + } + for (auto& item : input_dist_attrs_) { + if (!item.second.verify_annotated(item.second.annotated())) { + return false; + } + } + for 
(auto& item : output_dist_attrs_) { + if (!item.second.verify_annotated(item.second.annotated())) { + return false; + } + } + return true; +} + +bool OperatorDistAttr::verify() const { + if (op_ == nullptr) { + return false; + } + if (!verify_process_mesh(process_mesh_)) { + return false; + } + for (auto const& item : input_dist_attrs_) { + auto input_names = op_->InputArgumentNames(); + auto found = + std::find(std::begin(input_names), std::end(input_names), item.first); + if (found == std::end(input_names)) { + return false; + } + if (!verify_input_dist_attr(item.first, item.second)) { + return false; + } + } + for (auto const& item : output_dist_attrs_) { + auto output_names = op_->OutputArgumentNames(); + auto found = + std::find(std::begin(output_names), std::end(output_names), item.first); + if (found == std::end(output_names)) { + return false; + } + if (!verify_output_dist_attr(item.first, item.second)) { + return false; + } + } + return true; +} + +std::string OperatorDistAttr::to_string() const { + std::string str; + if (op_ != nullptr) { + str += "{op_type: " + op_->Type() + ", "; + } else { + str += "{op_type: None, "; + } + str += "impl_type: " + impl_type_ + ", "; + str += "impl_idx: " + std::to_string(impl_idx_) + ", "; + str += "annotated: [" + str_join(annotated_) + "], "; + str += "\nprocess_mesh: " + process_mesh_.to_string() + ", "; + str += "\ninput_dist_attrs: [\n"; + for (auto const& item : input_dist_attrs_) { + str += " " + item.second.to_string() + ",\n"; + } + str.replace(str.size() - 2, 2, "]"); + str += "\noutput_dist_attrs: [\n"; + for (auto const& item : output_dist_attrs_) { + str += " " + item.second.to_string() + ",\n"; + } + str.replace(str.size() - 2, 2, "]}"); + return str; +} + +bool operator==(const OperatorDistAttr& lhs, const OperatorDistAttr& rhs) { + if (lhs.process_mesh() != rhs.process_mesh()) { + return false; + } + if (lhs.impl_type() != rhs.impl_type()) { + return false; + } + if (lhs.impl_idx() != rhs.impl_idx()) { + 
return false; + } + for (auto const& item : lhs.input_dist_attrs()) { + if (rhs.input_dist_attrs().count(item.first) != 1) { + return false; + } + if (rhs.input_dist_attrs().at(item.first) != + lhs.input_dist_attrs().at(item.first)) { + return false; + } + } + for (auto const& item : lhs.output_dist_attrs()) { + if (rhs.output_dist_attrs().count(item.first) != 1) { + return false; + } + if (rhs.output_dist_attrs().at(item.first) != + lhs.output_dist_attrs().at(item.first)) { + return false; + } + } + return true; +} + +} // namespace auto_parallel +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/auto_parallel/dist_attr.h b/paddle/fluid/distributed/auto_parallel/dist_attr.h new file mode 100644 index 00000000000..ae089ef94b9 --- /dev/null +++ b/paddle/fluid/distributed/auto_parallel/dist_attr.h @@ -0,0 +1,239 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "paddle/fluid/distributed/auto_parallel/auto_parallel.pb.h" +#include "paddle/fluid/distributed/auto_parallel/process_mesh.h" +#include "paddle/fluid/distributed/auto_parallel/utils.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { + +// Forward Declaration +namespace framework { + +class BlockDesc; +class OpDesc; +class ProgramDesc; +class VarDesc; + +} // namespace framework + +namespace distributed { +namespace auto_parallel { + +using framework::BlockDesc; +using framework::OpDesc; +using framework::ProgramDesc; +using framework::VarDesc; + +class TensorDistAttr { + public: + TensorDistAttr() = default; + + explicit TensorDistAttr(const VarDesc& tensor); + + TensorDistAttr(const TensorDistAttr& tensor); + + TensorDistAttr& operator=(const TensorDistAttr& dist_attr); + + const VarDesc* tensor() const { return tensor_; } + + const ProcessMesh& process_mesh() const { return process_mesh_; } + + void set_process_mesh(const ProcessMesh& process_mesh); + + const std::vector& dims_mapping() const { return dims_mapping_; } + + void set_dims_mapping(const std::vector& dims_mapping); + + int64_t batch_dim() const { return batch_dim_; } + + void set_batch_dim(int64_t batch_dim); + + const std::vector& dynamic_dims() const { return dynamic_dims_; } + + void set_dynamic_dims(const std::vector& dynamic_dims); + + const std::map& annotated() const { return annotated_; } + + void set_annotated(const std::map& annotated); + + void set_default_dims_mapping(); + + bool is_annotated(const std::string& name) const { + return annotated_.count(name) == 1; + } + + void annotate(const std::string& name); + + bool verify_process_mesh(const ProcessMesh& process_mesh) const; + + bool verify_dims_mapping(const std::vector& dims_mapping) const; + + bool verify_batch_dim(int64_t dim) const; + + bool verify_dynamic_dims(const std::vector& dynamic_dims) const; + + bool 
verify_annotated(const std::map& annotated) const; + + bool verify() const; + + // TensorDistAttr from_string(const std::string& dist_str); + std::string to_string() const; + + private: + static std::vector fields_; + const VarDesc* tensor_{nullptr}; + ProcessMesh process_mesh_; + std::vector dims_mapping_; + int64_t batch_dim_; + std::vector dynamic_dims_; + std::map annotated_; +}; + +inline std::ostream& operator<<(std::ostream& os, const TensorDistAttr& obj) { + os << obj.to_string(); + return os; +} + +bool operator==(const TensorDistAttr& lhs, const TensorDistAttr& rhs); + +inline bool operator!=(const TensorDistAttr& lhs, const TensorDistAttr& rhs) { + return !operator==(lhs, rhs); +} + +class OperatorDistAttr { + public: + OperatorDistAttr() = default; + + explicit OperatorDistAttr(const OpDesc& op); + + OperatorDistAttr(const OperatorDistAttr& dist_attr); + + OperatorDistAttr& operator=(const OperatorDistAttr& dist_attr); + + const OpDesc* op() const { return op_; } + + const VarDesc& input(const std::string& name) const { + return *inputs_.at(name); + } + + const VarDesc& output(const std::string& name) const { + return *outputs_.at(name); + } + + const std::map& input_dist_attrs() const { + return input_dist_attrs_; + } + + const std::map& output_dist_attrs() const { + return output_dist_attrs_; + } + + const TensorDistAttr& input_dist_attr(const std::string& name) const { + return input_dist_attrs_.at(name); + } + + TensorDistAttr& input_dist_attr(const std::string& name) { + return input_dist_attrs_.at(name); + } + + void set_input_dist_attr(const std::string& name, + const TensorDistAttr& dist_attr); + + const TensorDistAttr& output_dist_attr(const std::string& name) const { + return output_dist_attrs_.at(name); + } + + TensorDistAttr& output_dist_attr(const std::string& name) { + return output_dist_attrs_.at(name); + } + + void set_output_dist_attr(const std::string& name, + const TensorDistAttr& dist_attr); + + const ProcessMesh& process_mesh() 
const { return process_mesh_; } + + void set_process_mesh(const ProcessMesh& process_mesh); + + const std::string& impl_type() const { return impl_type_; } + + void set_impl_type(const std::string& impl_type) { impl_type_ = impl_type; } + + int64_t impl_idx() const { return impl_idx_; } + + void set_impl_idx(const int64_t& impl_idx) { impl_idx_ = impl_idx; } + + const std::map& annotated() const { return annotated_; } + + void set_annotated(const std::map& annotated); + + bool is_annotated(const std::string& name) const { + return annotated_.count(name) == 1; + } + + void annotate(const std::string& name); + + bool verify_input_dist_attr(const std::string& name, + const TensorDistAttr& dist_attr) const; + + bool verify_output_dist_attr(const std::string& name, + const TensorDistAttr& dist_attr) const; + + bool verify_process_mesh(const ProcessMesh& process_mesh) const; + + bool verify_annotated(const std::map& annotated) const; + + bool verify() const; + + // OperatorDistAttr from_string(const std::string& dist_str); + std::string to_string() const; + + private: + static std::vector fields_; + const OpDesc* op_{nullptr}; + std::map inputs_; + std::map outputs_; + std::map input_dist_attrs_; + std::map output_dist_attrs_; + ProcessMesh process_mesh_; + std::string impl_type_; + int64_t impl_idx_ = -1; + std::map annotated_; +}; + +inline std::ostream& operator<<(std::ostream& os, const OperatorDistAttr& obj) { + os << obj.to_string(); + return os; +} + +bool operator==(const OperatorDistAttr& lhs, const OperatorDistAttr& rhs); + +inline bool operator!=(const OperatorDistAttr& lhs, + const OperatorDistAttr& rhs) { + return !operator==(lhs, rhs); +} + +} // namespace auto_parallel +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/auto_parallel/dist_attr_test.cc b/paddle/fluid/distributed/auto_parallel/dist_attr_test.cc new file mode 100644 index 00000000000..1b9ac4271b4 --- /dev/null +++ 
b/paddle/fluid/distributed/auto_parallel/dist_attr_test.cc @@ -0,0 +1,142 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "glog/logging.h" +#include "gtest/gtest.h" + +#include "paddle/fluid/distributed/auto_parallel/dist_attr.h" +#include "paddle/fluid/framework/block_desc.h" +#include "paddle/fluid/framework/op_desc.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/var_desc.h" + +namespace paddle { +namespace distributed { +namespace auto_parallel { + +TEST(DistAttr, ctor) { + ProgramDesc program; + auto* global_block = program.MutableBlock(0); + auto* x = global_block->Var("X"); + x->SetType(framework::proto::VarType::LOD_TENSOR); + x->SetLoDLevel(0); + x->SetDataType(framework::proto::VarType::FP32); + x->SetShape({1000, 784}); + + auto* y = global_block->Var("Y"); + y->SetType(framework::proto::VarType::LOD_TENSOR); + y->SetLoDLevel(0); + y->SetDataType(framework::proto::VarType::FP32); + y->SetShape({784, 100}); + + auto* op = global_block->AppendOp(); + op->SetType("mul"); + op->SetInput("X", {x->Name()}); + op->SetInput("Y", {y->Name()}); + + auto* out = global_block->Var("Out"); + out->SetType(framework::proto::VarType::LOD_TENSOR); + out->SetShape({1000, 100}); + op->SetOutput("Out", {out->Name()}); + + std::vector shape = {2, 4}; + std::vector process_ids = {0, 1, 2, 3, 4, 5, 6, 7}; + std::vector dim_names = {"x", "y"}; + ProcessMesh 
process_mesh(shape, process_ids, dim_names); + + std::vector shape2 = {2, 2}; + std::vector process_ids2 = {0, 1, 2, 3}; + std::vector dim_names2 = {"a", "b"}; + ProcessMesh process_mesh2(shape2, process_ids2, dim_names2); + + TensorDistAttr x_dist_attr(*x), y_dist_attr(*y), out_dist_attr(*out); + x_dist_attr.set_process_mesh(process_mesh); + x_dist_attr.set_dims_mapping(std::vector({0, -1})); + x_dist_attr.set_batch_dim(0); + x_dist_attr.set_dynamic_dims(std::vector({true, false})); + x_dist_attr.annotate("process_mesh"); + x_dist_attr.annotate("dims_mapping"); + EXPECT_EQ(x_dist_attr.process_mesh(), process_mesh); + EXPECT_EQ(x_dist_attr.dims_mapping(), std::vector({0, -1})); + EXPECT_EQ(x_dist_attr.batch_dim(), 0); + EXPECT_EQ(x_dist_attr.dynamic_dims(), std::vector({true, false})); + EXPECT_EQ(x_dist_attr.is_annotated("process_mesh"), true); + EXPECT_EQ(x_dist_attr.is_annotated("dims_mapping"), true); + EXPECT_EQ(x_dist_attr.verify(), true); + + std::stringstream x_sstream; + x_sstream << x_dist_attr; + EXPECT_EQ(x_sstream.str(), x_dist_attr.to_string()); + EXPECT_EQ(x_dist_attr, x_dist_attr); + + y_dist_attr.set_process_mesh(process_mesh); + y_dist_attr.set_dims_mapping(std::vector({-1, 0})); + y_dist_attr.set_batch_dim(-1); + y_dist_attr.set_dynamic_dims(std::vector({false, true})); + x_dist_attr.annotate("batch_dim"); + x_dist_attr.annotate("dynamic_dims"); + EXPECT_EQ(y_dist_attr.process_mesh(), process_mesh); + EXPECT_EQ(y_dist_attr.dims_mapping(), std::vector({-1, 0})); + EXPECT_EQ(y_dist_attr.batch_dim(), 1); + EXPECT_EQ(y_dist_attr.dynamic_dims(), std::vector({false, true})); + EXPECT_EQ(x_dist_attr.is_annotated("batch_dim"), true); + EXPECT_EQ(x_dist_attr.is_annotated("dynamic_dims"), true); + EXPECT_EQ(x_dist_attr.verify(), true); + + out_dist_attr.set_process_mesh(process_mesh); + out_dist_attr.set_dims_mapping(std::vector({0, 1})); + out_dist_attr.set_batch_dim(1); + out_dist_attr.set_dynamic_dims(std::vector({false, false})); + 
EXPECT_EQ(out_dist_attr.process_mesh(), process_mesh); + EXPECT_EQ(out_dist_attr.dims_mapping(), std::vector({0, 1})); + EXPECT_EQ(out_dist_attr.batch_dim(), 1); + EXPECT_EQ(out_dist_attr.dynamic_dims(), std::vector({false, false})); + EXPECT_EQ(out_dist_attr.verify(), true); + + OperatorDistAttr mul_dist_attr(*op); + mul_dist_attr.set_input_dist_attr(x->Name(), x_dist_attr); + mul_dist_attr.set_input_dist_attr(y->Name(), y_dist_attr); + mul_dist_attr.set_output_dist_attr(out->Name(), out_dist_attr); + mul_dist_attr.set_process_mesh(process_mesh2); + mul_dist_attr.set_impl_type("dist_mul"); + mul_dist_attr.set_impl_idx(0); + mul_dist_attr.annotate("process_mesh"); + mul_dist_attr.annotate("impl_type"); + mul_dist_attr.annotate("impl_idx"); + EXPECT_NE(mul_dist_attr.input_dist_attr(x->Name()), x_dist_attr); + EXPECT_NE(mul_dist_attr.input_dist_attr(y->Name()), y_dist_attr); + EXPECT_NE(mul_dist_attr.output_dist_attr(out->Name()), out_dist_attr); + EXPECT_EQ(mul_dist_attr.process_mesh(), process_mesh2); + EXPECT_EQ(mul_dist_attr.input_dist_attr(x->Name()).process_mesh(), + process_mesh2); + EXPECT_EQ(mul_dist_attr.input_dist_attr(y->Name()).process_mesh(), + process_mesh2); + EXPECT_EQ(mul_dist_attr.impl_type(), "dist_mul"); + EXPECT_EQ(mul_dist_attr.impl_idx(), 0); + EXPECT_EQ(mul_dist_attr.is_annotated("process_mesh"), true); + EXPECT_EQ(mul_dist_attr.is_annotated("impl_type"), true); + EXPECT_EQ(mul_dist_attr.is_annotated("impl_idx"), true); + EXPECT_EQ(mul_dist_attr.verify(), true); + + std::stringstream mul_sstream; + mul_sstream << mul_dist_attr; + EXPECT_EQ(mul_sstream.str(), mul_dist_attr.to_string()); + EXPECT_EQ(mul_dist_attr, mul_dist_attr); +} + +} // namespace auto_parallel +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/auto_parallel/utils.h b/paddle/fluid/distributed/auto_parallel/utils.h index 106cbea5e2d..de4162730b1 100644 --- a/paddle/fluid/distributed/auto_parallel/utils.h +++ 
// paddle/fluid/distributed/auto_parallel/utils.h
// Overload placed after the generic range-based str_join and before str_split.

// Renders a name -> flag map (e.g. the `annotated` maps of the dist attrs) as
// "key: value" pairs joined by `delimiter`.
//
// - Entries appear in the map's iteration order (sorted by key).
// - Each flag is formatted with std::to_string, so `true` prints as "1" and
//   `false` as "0".
// - An empty map yields an empty string.
//
// The separator is written *between* entries instead of being appended after
// every entry and trimmed afterwards; this honours `delimiter` and never cuts
// into the last value, whatever the delimiter's length.
inline std::string str_join(std::map<std::string, bool> const& elements,
                            const std::string& delimiter = ",") {
  std::string str;
  bool is_first = true;
  for (const auto& item : elements) {
    if (!is_first) {
      str += delimiter;
    }
    is_first = false;
    str += item.first + ": " + std::to_string(item.second);
  }
  return str;
}