未验证 提交 2c77b575 编写于 作者: Y Yulong Ao 提交者: GitHub

[Auto Parallel] Add the c++ dist attrs (#44989)

* [Auto Parallel] Add the c++ dist attrs

* [Auto Parallel] Remove some codes to be less than 1000 lines
上级 2832ab22
......@@ -16,6 +16,15 @@ cc_test(
SRCS process_mesh_test.cc
DEPS process_mesh)
cc_library(
dist_attr
SRCS dist_attr.cc
DEPS process_mesh auto_parallel_proto proto_desc)
cc_test(
dist_attr_test
SRCS dist_attr_test.cc
DEPS dist_attr)
cc_library(
dist_mapper
SRCS dist_mapper.cc
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <iostream>
#include <iterator>
#include "paddle/fluid/distributed/auto_parallel/dist_attr.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
namespace paddle {
namespace distributed {
namespace auto_parallel {
std::vector<std::string> TensorDistAttr::fields_{
"process_mesh", "dims_mapping", "batch_dim", "dynamic_dims"};
TensorDistAttr::TensorDistAttr(const VarDesc& tensor)
: tensor_(&tensor), batch_dim_(0) {
set_default_dims_mapping();
std::vector<int64_t> tensor_shape = tensor_->GetShape();
for (std::size_t i = 0; i < tensor_shape.size(); ++i) {
dynamic_dims_.push_back(false);
}
}
TensorDistAttr::TensorDistAttr(const TensorDistAttr& dist_attr) {
if (tensor_ == nullptr) {
tensor_ = dist_attr.tensor();
}
set_process_mesh(dist_attr.process_mesh());
set_dims_mapping(dist_attr.dims_mapping());
set_batch_dim(dist_attr.batch_dim());
set_dynamic_dims(dist_attr.dynamic_dims());
set_annotated(dist_attr.annotated());
}
TensorDistAttr& TensorDistAttr::operator=(const TensorDistAttr& dist_attr) {
if (tensor_ == nullptr) {
tensor_ = dist_attr.tensor();
}
set_process_mesh(dist_attr.process_mesh());
set_dims_mapping(dist_attr.dims_mapping());
set_batch_dim(dist_attr.batch_dim());
set_dynamic_dims(dist_attr.dynamic_dims());
set_annotated(dist_attr.annotated());
return *this;
}
void TensorDistAttr::set_process_mesh(const ProcessMesh& process_mesh) {
PADDLE_ENFORCE_EQ(verify_process_mesh(process_mesh),
true,
platform::errors::InvalidArgument(
"Wrong process mesh %s.", process_mesh.to_string()));
process_mesh_ = process_mesh;
}
void TensorDistAttr::set_dims_mapping(
const std::vector<int64_t>& dims_mapping) {
PADDLE_ENFORCE_EQ(verify_dims_mapping(dims_mapping),
true,
platform::errors::InvalidArgument("Wrong dims_mapping %s.",
str_join(dims_mapping)));
dims_mapping_ = dims_mapping;
}
void TensorDistAttr::set_batch_dim(int64_t batch_dim) {
PADDLE_ENFORCE_EQ(
verify_batch_dim(batch_dim),
true,
platform::errors::InvalidArgument(
"Wrong batch_dim %d in this distributed attribute.", batch_dim));
if (tensor_ != nullptr) {
std::vector<int64_t> tensor_shape = tensor_->GetShape();
int64_t canonical_batch_dim = canonical_dim(batch_dim, tensor_shape.size());
batch_dim_ = canonical_batch_dim;
} else {
batch_dim_ = batch_dim;
}
}
void TensorDistAttr::set_dynamic_dims(const std::vector<bool>& dynamic_dims) {
PADDLE_ENFORCE_EQ(
verify_dynamic_dims(dynamic_dims),
true,
platform::errors::InvalidArgument("The dynamic_dims [%s] is wrong.",
str_join(dynamic_dims)));
dynamic_dims_ = dynamic_dims;
}
void TensorDistAttr::set_annotated(
const std::map<std::string, bool>& annotated) {
PADDLE_ENFORCE_EQ(verify_annotated(annotated),
true,
platform::errors::InvalidArgument(
"The annotated [%s] is wrong.", str_join(annotated)));
annotated_ = annotated;
}
void TensorDistAttr::set_default_dims_mapping() {
if (tensor_ != nullptr) {
std::vector<int64_t> tensor_shape = tensor_->GetShape();
dims_mapping_ = std::vector<int64_t>(tensor_shape.size(), -1);
}
}
void TensorDistAttr::annotate(const std::string& name) {
auto result = std::find(std::begin(fields_), std::end(fields_), name);
if (result != std::end(fields_)) {
annotated_[name] = true;
}
}
bool TensorDistAttr::verify_process_mesh(
const ProcessMesh& process_mesh) const {
if (!process_mesh_.empty()) {
for (int64_t dim_mapping : dims_mapping_) {
if (dim_mapping < -1 || dim_mapping >= process_mesh_.ndim()) {
return false;
}
}
}
return true;
}
bool TensorDistAttr::verify_dims_mapping(
const std::vector<int64_t>& dims_mapping) const {
if (tensor_ != nullptr) {
std::vector<int64_t> tensor_shape = tensor_->GetShape();
if (dims_mapping.size() != tensor_shape.size()) {
return false;
}
}
std::unordered_map<int64_t, int64_t> map;
if (!process_mesh_.empty()) {
for (int64_t i : dims_mapping) {
if (i < -1 || i >= process_mesh_.ndim()) {
return false;
}
++map[i];
if (i != -1 && map[i] > 1) {
return false;
}
}
} else {
for (int64_t i : dims_mapping) {
++map[i];
if (i != -1 && map[i] > 1) {
return false;
}
}
}
return true;
}
bool TensorDistAttr::verify_batch_dim(int64_t dim) const {
if (tensor_ != nullptr) {
std::vector<int64_t> tensor_shape = tensor_->GetShape();
int64_t ndim = tensor_shape.size();
if (dim < 0) {
dim = dim + ndim;
}
if (dim < 0 || dim >= ndim) {
return false;
}
}
return true;
}
bool TensorDistAttr::verify_dynamic_dims(
const std::vector<bool>& dynamic_dims) const {
if (tensor_ != nullptr) {
std::vector<int64_t> tensor_shape = tensor_->GetShape();
if (dynamic_dims.size() != tensor_shape.size()) {
return false;
}
}
return true;
}
bool TensorDistAttr::verify_annotated(
const std::map<std::string, bool>& annotated) const {
for (const auto& item : annotated) {
auto result = std::find(std::begin(fields_), std::end(fields_), item.first);
if (result == std::end(fields_)) {
return false;
}
}
return true;
}
bool TensorDistAttr::verify() const {
if (tensor_ == nullptr) {
return false;
}
if (!verify_process_mesh(process_mesh_)) {
return false;
}
if (!verify_dims_mapping(dims_mapping_)) {
return false;
}
if (!verify_batch_dim(batch_dim_)) {
return false;
}
if (!verify_dynamic_dims(dynamic_dims_)) {
return false;
}
if (!verify_annotated(annotated_)) {
return false;
}
return true;
}
std::string TensorDistAttr::to_string() const {
std::string dist_str;
if (tensor_ != nullptr) {
dist_str = "{tensor_name: " + tensor_->Name() + ", ";
} else {
dist_str = "{tensor_name: None, ";
}
dist_str += "process_mesh: " + process_mesh_.to_string() + ", ";
dist_str += "dims_mappings: [" + str_join(dims_mapping_) + "], ";
dist_str += "batch_dim: " + std::to_string(batch_dim_) + ", ";
dist_str += "dynamic_dims: [" + str_join(dynamic_dims_) + "], ";
dist_str += "annotated: [" + str_join(annotated_) + "]}";
return dist_str;
}
bool operator==(const TensorDistAttr& lhs, const TensorDistAttr& rhs) {
if (lhs.process_mesh() != rhs.process_mesh()) {
return false;
}
if (lhs.dims_mapping() != rhs.dims_mapping()) {
return false;
}
if (lhs.batch_dim() != rhs.batch_dim()) {
return false;
}
if (lhs.dynamic_dims() != rhs.dynamic_dims()) {
return false;
}
return true;
}
std::vector<std::string> OperatorDistAttr::fields_{
"process_mesh", "impl_type", "impl_idx"};
OperatorDistAttr::OperatorDistAttr(const OpDesc& op) : op_(&op) {
for (std::string name : op_->InputArgumentNames()) {
VarDesc* input = op_->Block()->FindVarRecursive(name);
inputs_[name] = input;
input_dist_attrs_[name] = TensorDistAttr(*input);
}
for (std::string name : op_->OutputArgumentNames()) {
VarDesc* output = op_->Block()->FindVarRecursive(name);
outputs_[name] = output;
output_dist_attrs_[name] = TensorDistAttr(*output);
}
impl_type_ = "default";
impl_idx_ = 0;
}
OperatorDistAttr::OperatorDistAttr(const OperatorDistAttr& dist_attr) {
if (op_ == nullptr) {
op_ = dist_attr.op();
}
for (const auto& item : dist_attr.input_dist_attrs()) {
set_input_dist_attr(item.first, item.second);
}
for (const auto& item : dist_attr.output_dist_attrs()) {
set_output_dist_attr(item.first, item.second);
}
set_process_mesh(dist_attr.process_mesh());
set_impl_type(dist_attr.impl_type());
set_impl_idx(dist_attr.impl_idx());
set_annotated(dist_attr.annotated());
}
OperatorDistAttr& OperatorDistAttr::operator=(
const OperatorDistAttr& dist_attr) {
if (op_ == nullptr) {
op_ = dist_attr.op();
}
for (const auto& item : dist_attr.input_dist_attrs()) {
set_input_dist_attr(item.first, item.second);
}
for (const auto& item : dist_attr.output_dist_attrs()) {
set_output_dist_attr(item.first, item.second);
}
set_process_mesh(dist_attr.process_mesh());
set_impl_type(dist_attr.impl_type());
set_impl_idx(dist_attr.impl_idx());
set_annotated(dist_attr.annotated());
return *this;
}
void OperatorDistAttr::set_input_dist_attr(const std::string& name,
const TensorDistAttr& dist_attr) {
PADDLE_ENFORCE_EQ(
verify_input_dist_attr(name, dist_attr),
true,
platform::errors::InvalidArgument(
"Wrong dist_attr %s for %s.", dist_attr.to_string(), name));
input_dist_attrs_[name] = dist_attr;
// Make sure the process mesh of input be same as that of the op
input_dist_attrs_[name].set_process_mesh(process_mesh_);
}
void OperatorDistAttr::set_output_dist_attr(const std::string& name,
const TensorDistAttr& dist_attr) {
PADDLE_ENFORCE_EQ(
verify_output_dist_attr(name, dist_attr),
true,
platform::errors::InvalidArgument(
"Wrong dist_attr %s for %s.", dist_attr.to_string(), name));
output_dist_attrs_[name] = dist_attr;
// Make sure the process mesh of output be same as that of the op
output_dist_attrs_[name].set_process_mesh(process_mesh_);
}
void OperatorDistAttr::set_process_mesh(const ProcessMesh& process_mesh) {
for (auto& item : input_dist_attrs_) {
item.second.set_process_mesh(process_mesh);
}
for (auto& item : output_dist_attrs_) {
item.second.set_process_mesh(process_mesh);
}
process_mesh_ = process_mesh;
}
void OperatorDistAttr::annotate(const std::string& name) {
auto result = std::find(std::begin(fields_), std::end(fields_), name);
if (result != std::end(fields_)) {
annotated_[name] = true;
}
if (name == "process_mesh") {
for (auto& item : input_dist_attrs_) {
item.second.annotate(name);
}
for (auto& item : output_dist_attrs_) {
item.second.annotate(name);
}
}
}
void OperatorDistAttr::set_annotated(
const std::map<std::string, bool>& annotated) {
PADDLE_ENFORCE_EQ(verify_annotated(annotated),
true,
platform::errors::InvalidArgument(
"The annotated [%s] is wrong.", str_join(annotated)));
annotated_ = annotated;
}
bool OperatorDistAttr::verify_input_dist_attr(
const std::string& name, const TensorDistAttr& dist_attr) const {
if (!dist_attr.verify()) {
return false;
}
if (op_ != nullptr) {
if (dist_attr.tensor() != nullptr) {
if (name != dist_attr.tensor()->Name()) {
return false;
}
}
if (input_dist_attrs_.count(name) == 0) {
return false;
}
}
return true;
}
bool OperatorDistAttr::verify_output_dist_attr(
const std::string& name, const TensorDistAttr& dist_attr) const {
if (!dist_attr.verify()) {
return false;
}
if (op_ != nullptr) {
if (dist_attr.tensor() != nullptr) {
if (name != dist_attr.tensor()->Name()) {
return false;
}
}
if (output_dist_attrs_.count(name) == 0) {
return false;
}
}
return true;
}
bool OperatorDistAttr::verify_process_mesh(
const ProcessMesh& process_mesh) const {
if (process_mesh != process_mesh_) {
return false;
}
for (auto& item : input_dist_attrs_) {
if (item.second.process_mesh() != process_mesh) {
return false;
}
}
for (auto& item : output_dist_attrs_) {
if (item.second.process_mesh() != process_mesh) {
return false;
}
}
return true;
}
bool OperatorDistAttr::verify_annotated(
const std::map<std::string, bool>& annotated) const {
for (const auto& item : annotated) {
auto result = std::find(std::begin(fields_), std::end(fields_), item.first);
if (result == std::end(fields_)) {
return false;
}
}
for (auto& item : input_dist_attrs_) {
if (!item.second.verify_annotated(item.second.annotated())) {
return false;
}
}
for (auto& item : output_dist_attrs_) {
if (!item.second.verify_annotated(item.second.annotated())) {
return false;
}
}
return true;
}
bool OperatorDistAttr::verify() const {
if (op_ == nullptr) {
return false;
}
if (!verify_process_mesh(process_mesh_)) {
return false;
}
for (auto const& item : input_dist_attrs_) {
auto input_names = op_->InputArgumentNames();
auto found =
std::find(std::begin(input_names), std::end(input_names), item.first);
if (found == std::end(input_names)) {
return false;
}
if (!verify_input_dist_attr(item.first, item.second)) {
return false;
}
}
for (auto const& item : output_dist_attrs_) {
auto output_names = op_->OutputArgumentNames();
auto found =
std::find(std::begin(output_names), std::end(output_names), item.first);
if (found == std::end(output_names)) {
return false;
}
if (!verify_output_dist_attr(item.first, item.second)) {
return false;
}
}
return true;
}
std::string OperatorDistAttr::to_string() const {
std::string str;
if (op_ != nullptr) {
str += "{op_type: " + op_->Type() + ", ";
} else {
str += "{op_type: None, ";
}
str += "impl_type: " + impl_type_ + ", ";
str += "impl_idx: " + std::to_string(impl_idx_) + ", ";
str += "annotated: [" + str_join(annotated_) + "], ";
str += "\nprocess_mesh: " + process_mesh_.to_string() + ", ";
str += "\ninput_dist_attrs: [\n";
for (auto const& item : input_dist_attrs_) {
str += " " + item.second.to_string() + ",\n";
}
str.replace(str.size() - 2, 2, "]");
str += "\noutput_dist_attrs: [\n";
for (auto const& item : output_dist_attrs_) {
str += " " + item.second.to_string() + ",\n";
}
str.replace(str.size() - 2, 2, "]}");
return str;
}
bool operator==(const OperatorDistAttr& lhs, const OperatorDistAttr& rhs) {
if (lhs.process_mesh() != rhs.process_mesh()) {
return false;
}
if (lhs.impl_type() != rhs.impl_type()) {
return false;
}
if (lhs.impl_idx() != rhs.impl_idx()) {
return false;
}
for (auto const& item : lhs.input_dist_attrs()) {
if (rhs.input_dist_attrs().count(item.first) != 1) {
return false;
}
if (rhs.input_dist_attrs().at(item.first) !=
lhs.input_dist_attrs().at(item.first)) {
return false;
}
}
for (auto const& item : lhs.output_dist_attrs()) {
if (rhs.output_dist_attrs().count(item.first) != 1) {
return false;
}
if (rhs.output_dist_attrs().at(item.first) !=
lhs.output_dist_attrs().at(item.first)) {
return false;
}
}
return true;
}
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/distributed/auto_parallel/auto_parallel.pb.h"
#include "paddle/fluid/distributed/auto_parallel/process_mesh.h"
#include "paddle/fluid/distributed/auto_parallel/utils.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
// Forward Declaration
namespace framework {
class BlockDesc;
class OpDesc;
class ProgramDesc;
class VarDesc;
} // namespace framework
namespace distributed {
namespace auto_parallel {
using framework::BlockDesc;
using framework::OpDesc;
using framework::ProgramDesc;
using framework::VarDesc;
class TensorDistAttr {
public:
TensorDistAttr() = default;
explicit TensorDistAttr(const VarDesc& tensor);
TensorDistAttr(const TensorDistAttr& tensor);
TensorDistAttr& operator=(const TensorDistAttr& dist_attr);
const VarDesc* tensor() const { return tensor_; }
const ProcessMesh& process_mesh() const { return process_mesh_; }
void set_process_mesh(const ProcessMesh& process_mesh);
const std::vector<int64_t>& dims_mapping() const { return dims_mapping_; }
void set_dims_mapping(const std::vector<int64_t>& dims_mapping);
int64_t batch_dim() const { return batch_dim_; }
void set_batch_dim(int64_t batch_dim);
const std::vector<bool>& dynamic_dims() const { return dynamic_dims_; }
void set_dynamic_dims(const std::vector<bool>& dynamic_dims);
const std::map<std::string, bool>& annotated() const { return annotated_; }
void set_annotated(const std::map<std::string, bool>& annotated);
void set_default_dims_mapping();
bool is_annotated(const std::string& name) const {
return annotated_.count(name) == 1;
}
void annotate(const std::string& name);
bool verify_process_mesh(const ProcessMesh& process_mesh) const;
bool verify_dims_mapping(const std::vector<int64_t>& dims_mapping) const;
bool verify_batch_dim(int64_t dim) const;
bool verify_dynamic_dims(const std::vector<bool>& dynamic_dims) const;
bool verify_annotated(const std::map<std::string, bool>& annotated) const;
bool verify() const;
// TensorDistAttr from_string(const std::string& dist_str);
std::string to_string() const;
private:
static std::vector<std::string> fields_;
const VarDesc* tensor_{nullptr};
ProcessMesh process_mesh_;
std::vector<int64_t> dims_mapping_;
int64_t batch_dim_;
std::vector<bool> dynamic_dims_;
std::map<std::string, bool> annotated_;
};
inline std::ostream& operator<<(std::ostream& os, const TensorDistAttr& obj) {
os << obj.to_string();
return os;
}
bool operator==(const TensorDistAttr& lhs, const TensorDistAttr& rhs);
inline bool operator!=(const TensorDistAttr& lhs, const TensorDistAttr& rhs) {
return !operator==(lhs, rhs);
}
class OperatorDistAttr {
public:
OperatorDistAttr() = default;
explicit OperatorDistAttr(const OpDesc& op);
OperatorDistAttr(const OperatorDistAttr& dist_attr);
OperatorDistAttr& operator=(const OperatorDistAttr& dist_attr);
const OpDesc* op() const { return op_; }
const VarDesc& input(const std::string& name) const {
return *inputs_.at(name);
}
const VarDesc& output(const std::string& name) const {
return *outputs_.at(name);
}
const std::map<std::string, TensorDistAttr>& input_dist_attrs() const {
return input_dist_attrs_;
}
const std::map<std::string, TensorDistAttr>& output_dist_attrs() const {
return output_dist_attrs_;
}
const TensorDistAttr& input_dist_attr(const std::string& name) const {
return input_dist_attrs_.at(name);
}
TensorDistAttr& input_dist_attr(const std::string& name) {
return input_dist_attrs_.at(name);
}
void set_input_dist_attr(const std::string& name,
const TensorDistAttr& dist_attr);
const TensorDistAttr& output_dist_attr(const std::string& name) const {
return output_dist_attrs_.at(name);
}
TensorDistAttr& output_dist_attr(const std::string& name) {
return output_dist_attrs_.at(name);
}
void set_output_dist_attr(const std::string& name,
const TensorDistAttr& dist_attr);
const ProcessMesh& process_mesh() const { return process_mesh_; }
void set_process_mesh(const ProcessMesh& process_mesh);
const std::string& impl_type() const { return impl_type_; }
void set_impl_type(const std::string& impl_type) { impl_type_ = impl_type; }
int64_t impl_idx() const { return impl_idx_; }
void set_impl_idx(const int64_t& impl_idx) { impl_idx_ = impl_idx; }
const std::map<std::string, bool>& annotated() const { return annotated_; }
void set_annotated(const std::map<std::string, bool>& annotated);
bool is_annotated(const std::string& name) const {
return annotated_.count(name) == 1;
}
void annotate(const std::string& name);
bool verify_input_dist_attr(const std::string& name,
const TensorDistAttr& dist_attr) const;
bool verify_output_dist_attr(const std::string& name,
const TensorDistAttr& dist_attr) const;
bool verify_process_mesh(const ProcessMesh& process_mesh) const;
bool verify_annotated(const std::map<std::string, bool>& annotated) const;
bool verify() const;
// OperatorDistAttr from_string(const std::string& dist_str);
std::string to_string() const;
private:
static std::vector<std::string> fields_;
const OpDesc* op_{nullptr};
std::map<std::string, VarDesc*> inputs_;
std::map<std::string, VarDesc*> outputs_;
std::map<std::string, TensorDistAttr> input_dist_attrs_;
std::map<std::string, TensorDistAttr> output_dist_attrs_;
ProcessMesh process_mesh_;
std::string impl_type_;
int64_t impl_idx_ = -1;
std::map<std::string, bool> annotated_;
};
inline std::ostream& operator<<(std::ostream& os, const OperatorDistAttr& obj) {
os << obj.to_string();
return os;
}
bool operator==(const OperatorDistAttr& lhs, const OperatorDistAttr& rhs);
inline bool operator!=(const OperatorDistAttr& lhs,
const OperatorDistAttr& rhs) {
return !operator==(lhs, rhs);
}
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <sstream>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/auto_parallel/dist_attr.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
namespace paddle {
namespace distributed {
namespace auto_parallel {
TEST(DistAttr, ctor) {
ProgramDesc program;
auto* global_block = program.MutableBlock(0);
auto* x = global_block->Var("X");
x->SetType(framework::proto::VarType::LOD_TENSOR);
x->SetLoDLevel(0);
x->SetDataType(framework::proto::VarType::FP32);
x->SetShape({1000, 784});
auto* y = global_block->Var("Y");
y->SetType(framework::proto::VarType::LOD_TENSOR);
y->SetLoDLevel(0);
y->SetDataType(framework::proto::VarType::FP32);
y->SetShape({784, 100});
auto* op = global_block->AppendOp();
op->SetType("mul");
op->SetInput("X", {x->Name()});
op->SetInput("Y", {y->Name()});
auto* out = global_block->Var("Out");
out->SetType(framework::proto::VarType::LOD_TENSOR);
out->SetShape({1000, 100});
op->SetOutput("Out", {out->Name()});
std::vector<int64_t> shape = {2, 4};
std::vector<int64_t> process_ids = {0, 1, 2, 3, 4, 5, 6, 7};
std::vector<std::string> dim_names = {"x", "y"};
ProcessMesh process_mesh(shape, process_ids, dim_names);
std::vector<int64_t> shape2 = {2, 2};
std::vector<int64_t> process_ids2 = {0, 1, 2, 3};
std::vector<std::string> dim_names2 = {"a", "b"};
ProcessMesh process_mesh2(shape2, process_ids2, dim_names2);
TensorDistAttr x_dist_attr(*x), y_dist_attr(*y), out_dist_attr(*out);
x_dist_attr.set_process_mesh(process_mesh);
x_dist_attr.set_dims_mapping(std::vector<int64_t>({0, -1}));
x_dist_attr.set_batch_dim(0);
x_dist_attr.set_dynamic_dims(std::vector<bool>({true, false}));
x_dist_attr.annotate("process_mesh");
x_dist_attr.annotate("dims_mapping");
EXPECT_EQ(x_dist_attr.process_mesh(), process_mesh);
EXPECT_EQ(x_dist_attr.dims_mapping(), std::vector<int64_t>({0, -1}));
EXPECT_EQ(x_dist_attr.batch_dim(), 0);
EXPECT_EQ(x_dist_attr.dynamic_dims(), std::vector<bool>({true, false}));
EXPECT_EQ(x_dist_attr.is_annotated("process_mesh"), true);
EXPECT_EQ(x_dist_attr.is_annotated("dims_mapping"), true);
EXPECT_EQ(x_dist_attr.verify(), true);
std::stringstream x_sstream;
x_sstream << x_dist_attr;
EXPECT_EQ(x_sstream.str(), x_dist_attr.to_string());
EXPECT_EQ(x_dist_attr, x_dist_attr);
y_dist_attr.set_process_mesh(process_mesh);
y_dist_attr.set_dims_mapping(std::vector<int64_t>({-1, 0}));
y_dist_attr.set_batch_dim(-1);
y_dist_attr.set_dynamic_dims(std::vector<bool>({false, true}));
x_dist_attr.annotate("batch_dim");
x_dist_attr.annotate("dynamic_dims");
EXPECT_EQ(y_dist_attr.process_mesh(), process_mesh);
EXPECT_EQ(y_dist_attr.dims_mapping(), std::vector<int64_t>({-1, 0}));
EXPECT_EQ(y_dist_attr.batch_dim(), 1);
EXPECT_EQ(y_dist_attr.dynamic_dims(), std::vector<bool>({false, true}));
EXPECT_EQ(x_dist_attr.is_annotated("batch_dim"), true);
EXPECT_EQ(x_dist_attr.is_annotated("dynamic_dims"), true);
EXPECT_EQ(x_dist_attr.verify(), true);
out_dist_attr.set_process_mesh(process_mesh);
out_dist_attr.set_dims_mapping(std::vector<int64_t>({0, 1}));
out_dist_attr.set_batch_dim(1);
out_dist_attr.set_dynamic_dims(std::vector<bool>({false, false}));
EXPECT_EQ(out_dist_attr.process_mesh(), process_mesh);
EXPECT_EQ(out_dist_attr.dims_mapping(), std::vector<int64_t>({0, 1}));
EXPECT_EQ(out_dist_attr.batch_dim(), 1);
EXPECT_EQ(out_dist_attr.dynamic_dims(), std::vector<bool>({false, false}));
EXPECT_EQ(out_dist_attr.verify(), true);
OperatorDistAttr mul_dist_attr(*op);
mul_dist_attr.set_input_dist_attr(x->Name(), x_dist_attr);
mul_dist_attr.set_input_dist_attr(y->Name(), y_dist_attr);
mul_dist_attr.set_output_dist_attr(out->Name(), out_dist_attr);
mul_dist_attr.set_process_mesh(process_mesh2);
mul_dist_attr.set_impl_type("dist_mul");
mul_dist_attr.set_impl_idx(0);
mul_dist_attr.annotate("process_mesh");
mul_dist_attr.annotate("impl_type");
mul_dist_attr.annotate("impl_idx");
EXPECT_NE(mul_dist_attr.input_dist_attr(x->Name()), x_dist_attr);
EXPECT_NE(mul_dist_attr.input_dist_attr(y->Name()), y_dist_attr);
EXPECT_NE(mul_dist_attr.output_dist_attr(out->Name()), out_dist_attr);
EXPECT_EQ(mul_dist_attr.process_mesh(), process_mesh2);
EXPECT_EQ(mul_dist_attr.input_dist_attr(x->Name()).process_mesh(),
process_mesh2);
EXPECT_EQ(mul_dist_attr.input_dist_attr(y->Name()).process_mesh(),
process_mesh2);
EXPECT_EQ(mul_dist_attr.impl_type(), "dist_mul");
EXPECT_EQ(mul_dist_attr.impl_idx(), 0);
EXPECT_EQ(mul_dist_attr.is_annotated("process_mesh"), true);
EXPECT_EQ(mul_dist_attr.is_annotated("impl_type"), true);
EXPECT_EQ(mul_dist_attr.is_annotated("impl_idx"), true);
EXPECT_EQ(mul_dist_attr.verify(), true);
std::stringstream mul_sstream;
mul_sstream << mul_dist_attr;
EXPECT_EQ(mul_sstream.str(), mul_dist_attr.to_string());
EXPECT_EQ(mul_dist_attr, mul_dist_attr);
}
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
......@@ -76,6 +76,15 @@ std::string str_join(Range const& elements,
return os.str();
}
inline std::string str_join(std::map<std::string, bool> const& elements,
const std::string& delimiter = ",") {
std::string str;
for (const auto& item : elements) {
str += item.first + ": " + std::to_string(item.second) + ",";
}
return str.substr(0, str.size() - 2);
}
// Refer to https://stackoverflow.com/a/46931770
inline std::vector<std::string> str_split(std::string const& input,
const std::string& delimiter = ",") {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册