未验证 提交 c7899074 编写于 作者: Y Yulong Ao 提交者: GitHub

[Auto Parallel] Merge dist attrs from python into c++ (#49214)

* [Auto Parallel] Rename methods of ProcessMesh

* [Auto Parallel] Impl the python process_mesh by the c++ one

* [Auto Parallel] Add some minor modifications

* [Auto Parallel] Rename some methods

* [Auto Parallel] Remove unnecessary codes

* [Auto Parallel] Add back some removed files

* [Auto Parallel] Fix bugs

* [Auto Parallel] Fix a bug

* Update process_mesh.cc

* [Auto Parallel] Merge dist attrs of Python into C++

* [Auto Parallel] Add back deleted importing

* [Auto Parallel] Add back removed unittest

* [Auto Parallel] Remove type qualifiers of return types

* [Auto Parallel] Fix some bugs

* [Auto Parallel] Fix a bug of the quant pass

* [Auto Parallel] Fix the code style
上级 0f3ccd14
...@@ -60,8 +60,6 @@ class TensorDistAttr { ...@@ -60,8 +60,6 @@ class TensorDistAttr {
void copy_from(const TensorDistAttr& dist_attr); void copy_from(const TensorDistAttr& dist_attr);
const VarDesc* tensor() const { return tensor_; }
const ProcessMesh& process_mesh() const { return process_mesh_; } const ProcessMesh& process_mesh() const { return process_mesh_; }
void set_process_mesh(const ProcessMesh& process_mesh); void set_process_mesh(const ProcessMesh& process_mesh);
...@@ -70,6 +68,8 @@ class TensorDistAttr { ...@@ -70,6 +68,8 @@ class TensorDistAttr {
void set_dims_mapping(const std::vector<int64_t>& dims_mapping); void set_dims_mapping(const std::vector<int64_t>& dims_mapping);
void set_default_dims_mapping(const std::vector<int64_t>& tensor_shape);
int64_t batch_dim() const { return batch_dim_; } int64_t batch_dim() const { return batch_dim_; }
void set_batch_dim(int64_t batch_dim); void set_batch_dim(int64_t batch_dim);
...@@ -78,29 +78,34 @@ class TensorDistAttr { ...@@ -78,29 +78,34 @@ class TensorDistAttr {
void set_dynamic_dims(const std::vector<bool>& dynamic_dims); void set_dynamic_dims(const std::vector<bool>& dynamic_dims);
void set_default_dynamic_dims(const std::vector<int64_t>& tensor_shape);
const std::map<std::string, bool>& annotated() const { return annotated_; } const std::map<std::string, bool>& annotated() const { return annotated_; }
void set_annotated(const std::map<std::string, bool>& annotated); void set_annotated(const std::map<std::string, bool>& annotated);
void set_default_dims_mapping();
bool is_annotated(const std::string& name) const { bool is_annotated(const std::string& name) const {
return annotated_.count(name) == 1; return annotated_.count(name) == 1 && annotated_.at(name) == true;
} }
void annotate(const std::string& name); void mark_annotated(const std::string& name);
void clear_annotated() { annotated_.clear(); }
bool verify_process_mesh(const ProcessMesh& process_mesh) const; bool verify_process_mesh(const ProcessMesh& process_mesh) const;
bool verify_dims_mapping(const std::vector<int64_t>& dims_mapping) const; bool verify_dims_mapping(const std::vector<int64_t>& dims_mapping,
const std::vector<int64_t>& tensor_shape) const;
bool verify_batch_dim(int64_t dim) const; bool verify_batch_dim(int64_t dim,
const std::vector<int64_t>& tensor_shape) const;
bool verify_dynamic_dims(const std::vector<bool>& dynamic_dims) const; bool verify_dynamic_dims(const std::vector<bool>& dynamic_dims,
const std::vector<int64_t>& tensor_shape) const;
bool verify_annotated(const std::map<std::string, bool>& annotated) const; bool verify_annotated(const std::map<std::string, bool>& annotated) const;
bool verify() const; bool verify(const VarDesc* tensor = nullptr) const;
// TensorDistAttr from_string(const std::string& dist_str); // TensorDistAttr from_string(const std::string& dist_str);
std::string to_string() const; std::string to_string() const;
...@@ -115,8 +120,6 @@ class TensorDistAttr { ...@@ -115,8 +120,6 @@ class TensorDistAttr {
private: private:
static std::vector<std::string> fields_; static std::vector<std::string> fields_;
const VarDesc* tensor_{nullptr};
std::vector<int64_t> tensor_shape_;
ProcessMesh process_mesh_; ProcessMesh process_mesh_;
std::vector<int64_t> dims_mapping_; std::vector<int64_t> dims_mapping_;
int64_t batch_dim_{0}; int64_t batch_dim_{0};
...@@ -145,21 +148,15 @@ class OperatorDistAttr { ...@@ -145,21 +148,15 @@ class OperatorDistAttr {
OperatorDistAttr& operator=(const OperatorDistAttr& dist_attr); OperatorDistAttr& operator=(const OperatorDistAttr& dist_attr);
void initialize(); void initialize(const OpDesc* op = nullptr);
void copy_from(const OperatorDistAttr& dist_attr); void copy_from(const OperatorDistAttr& dist_attr);
const OpDesc* op() const { return op_; } const std::map<std::string, TensorDistAttr>& input_dist_attrs() const {
return input_dist_attrs_;
const VarDesc& input(const std::string& name) const {
return *inputs_.at(name);
}
const VarDesc& output(const std::string& name) const {
return *outputs_.at(name);
} }
const std::map<std::string, TensorDistAttr>& input_dist_attrs() const { std::map<std::string, TensorDistAttr>& input_dist_attrs() {
return input_dist_attrs_; return input_dist_attrs_;
} }
...@@ -170,6 +167,10 @@ class OperatorDistAttr { ...@@ -170,6 +167,10 @@ class OperatorDistAttr {
return output_dist_attrs_; return output_dist_attrs_;
} }
std::map<std::string, TensorDistAttr>& output_dist_attrs() {
return output_dist_attrs_;
}
void set_output_dist_attrs( void set_output_dist_attrs(
const std::map<std::string, TensorDistAttr>& dist_attrs); const std::map<std::string, TensorDistAttr>& dist_attrs);
...@@ -199,6 +200,10 @@ class OperatorDistAttr { ...@@ -199,6 +200,10 @@ class OperatorDistAttr {
void set_process_mesh(const ProcessMesh& process_mesh); void set_process_mesh(const ProcessMesh& process_mesh);
const std::string& op_type() const { return op_type_; }
void set_op_type(const std::string& op_type) { op_type_ = op_type; }
const std::string& impl_type() const { return impl_type_; } const std::string& impl_type() const { return impl_type_; }
void set_impl_type(const std::string& impl_type) { impl_type_ = impl_type; } void set_impl_type(const std::string& impl_type) { impl_type_ = impl_type; }
...@@ -207,6 +212,10 @@ class OperatorDistAttr { ...@@ -207,6 +212,10 @@ class OperatorDistAttr {
void set_impl_idx(const int64_t& impl_idx) { impl_idx_ = impl_idx; } void set_impl_idx(const int64_t& impl_idx) { impl_idx_ = impl_idx; }
bool is_recompute() const { return is_recompute_; }
void set_is_recompute(bool is_recompute) { is_recompute_ = is_recompute; }
const std::string& execution_stream() const { return execution_stream_; } const std::string& execution_stream() const { return execution_stream_; }
void set_execution_stream(const std::string& execution_stream) { void set_execution_stream(const std::string& execution_stream) {
...@@ -224,10 +233,12 @@ class OperatorDistAttr { ...@@ -224,10 +233,12 @@ class OperatorDistAttr {
void set_annotated(const std::map<std::string, bool>& annotated); void set_annotated(const std::map<std::string, bool>& annotated);
bool is_annotated(const std::string& name) const { bool is_annotated(const std::string& name) const {
return annotated_.count(name) == 1; return annotated_.count(name) == 1 && annotated_.at(name) == true;
} }
void annotate(const std::string& name); void mark_annotated(const std::string& name);
void clear_annotated();
const std::vector<int64_t>& input_dims_mapping(const std::string& name) const; const std::vector<int64_t>& input_dims_mapping(const std::string& name) const;
...@@ -240,16 +251,18 @@ class OperatorDistAttr { ...@@ -240,16 +251,18 @@ class OperatorDistAttr {
const std::vector<int64_t>& dims_mapping); const std::vector<int64_t>& dims_mapping);
bool verify_input_dist_attr(const std::string& name, bool verify_input_dist_attr(const std::string& name,
const TensorDistAttr& dist_attr) const; const TensorDistAttr& dist_attr,
const VarDesc* tensor) const;
bool verify_output_dist_attr(const std::string& name, bool verify_output_dist_attr(const std::string& name,
const TensorDistAttr& dist_attr) const; const TensorDistAttr& dist_attr,
const VarDesc* tensor) const;
bool verify_process_mesh(const ProcessMesh& process_mesh) const; bool verify_process_mesh(const ProcessMesh& process_mesh) const;
bool verify_annotated(const std::map<std::string, bool>& annotated) const; bool verify_annotated(const std::map<std::string, bool>& annotated) const;
bool verify() const; bool verify(const OpDesc* op = nullptr) const;
void rename_input(const std::string& old_name, const std::string& new_name); void rename_input(const std::string& old_name, const std::string& new_name);
...@@ -268,14 +281,13 @@ class OperatorDistAttr { ...@@ -268,14 +281,13 @@ class OperatorDistAttr {
private: private:
static std::vector<std::string> fields_; static std::vector<std::string> fields_;
const OpDesc* op_{nullptr};
std::map<std::string, VarDesc*> inputs_;
std::map<std::string, VarDesc*> outputs_;
std::map<std::string, TensorDistAttr> input_dist_attrs_; std::map<std::string, TensorDistAttr> input_dist_attrs_;
std::map<std::string, TensorDistAttr> output_dist_attrs_; std::map<std::string, TensorDistAttr> output_dist_attrs_;
ProcessMesh process_mesh_; ProcessMesh process_mesh_;
std::string impl_type_; std::string op_type_;
int64_t impl_idx_ = -1; std::string impl_type_ = kDefault;
int64_t impl_idx_ = 0;
bool is_recompute_ = false;
std::string execution_stream_; std::string execution_stream_;
int64_t scheduling_priority_; // lower value, higher priority, default to 0 int64_t scheduling_priority_; // lower value, higher priority, default to 0
std::map<std::string, bool> annotated_; std::map<std::string, bool> annotated_;
......
...@@ -67,15 +67,17 @@ TEST(DistAttr, ctor) { ...@@ -67,15 +67,17 @@ TEST(DistAttr, ctor) {
x_dist_attr.set_dims_mapping(std::vector<int64_t>({0, -1})); x_dist_attr.set_dims_mapping(std::vector<int64_t>({0, -1}));
x_dist_attr.set_batch_dim(0); x_dist_attr.set_batch_dim(0);
x_dist_attr.set_dynamic_dims(std::vector<bool>({true, false})); x_dist_attr.set_dynamic_dims(std::vector<bool>({true, false}));
x_dist_attr.annotate("process_mesh"); x_dist_attr.mark_annotated("process_mesh");
x_dist_attr.annotate("dims_mapping"); x_dist_attr.mark_annotated("dims_mapping");
EXPECT_EQ(x_dist_attr.process_mesh(), process_mesh); EXPECT_EQ(x_dist_attr.process_mesh(), process_mesh);
EXPECT_EQ(x_dist_attr.dims_mapping(), std::vector<int64_t>({0, -1})); EXPECT_EQ(x_dist_attr.dims_mapping(), std::vector<int64_t>({0, -1}));
EXPECT_EQ(x_dist_attr.batch_dim(), 0); EXPECT_EQ(x_dist_attr.batch_dim(), 0);
EXPECT_EQ(x_dist_attr.dynamic_dims(), std::vector<bool>({true, false})); EXPECT_EQ(x_dist_attr.dynamic_dims(), std::vector<bool>({true, false}));
EXPECT_EQ(x_dist_attr.is_annotated("process_mesh"), true); EXPECT_EQ(x_dist_attr.is_annotated("process_mesh"), true);
EXPECT_EQ(x_dist_attr.is_annotated("dims_mapping"), true); EXPECT_EQ(x_dist_attr.is_annotated("dims_mapping"), true);
EXPECT_EQ(x_dist_attr.verify(), true); EXPECT_EQ(x_dist_attr.verify(x), true);
x_dist_attr.clear_annotated();
EXPECT_EQ(x_dist_attr.annotated().empty(), true);
std::stringstream x_sstream; std::stringstream x_sstream;
x_sstream << x_dist_attr; x_sstream << x_dist_attr;
...@@ -89,15 +91,15 @@ TEST(DistAttr, ctor) { ...@@ -89,15 +91,15 @@ TEST(DistAttr, ctor) {
y_dist_attr.set_dims_mapping(std::vector<int64_t>({-1, 0})); y_dist_attr.set_dims_mapping(std::vector<int64_t>({-1, 0}));
y_dist_attr.set_batch_dim(-1); y_dist_attr.set_batch_dim(-1);
y_dist_attr.set_dynamic_dims(std::vector<bool>({false, true})); y_dist_attr.set_dynamic_dims(std::vector<bool>({false, true}));
x_dist_attr.annotate("batch_dim"); x_dist_attr.mark_annotated("batch_dim");
x_dist_attr.annotate("dynamic_dims"); x_dist_attr.mark_annotated("dynamic_dims");
EXPECT_EQ(y_dist_attr.process_mesh(), process_mesh); EXPECT_EQ(y_dist_attr.process_mesh(), process_mesh);
EXPECT_EQ(y_dist_attr.dims_mapping(), std::vector<int64_t>({-1, 0})); EXPECT_EQ(y_dist_attr.dims_mapping(), std::vector<int64_t>({-1, 0}));
EXPECT_EQ(y_dist_attr.batch_dim(), 1); EXPECT_EQ(y_dist_attr.batch_dim(), 1);
EXPECT_EQ(y_dist_attr.dynamic_dims(), std::vector<bool>({false, true})); EXPECT_EQ(y_dist_attr.dynamic_dims(), std::vector<bool>({false, true}));
EXPECT_EQ(x_dist_attr.is_annotated("batch_dim"), true); EXPECT_EQ(x_dist_attr.is_annotated("batch_dim"), true);
EXPECT_EQ(x_dist_attr.is_annotated("dynamic_dims"), true); EXPECT_EQ(x_dist_attr.is_annotated("dynamic_dims"), true);
EXPECT_EQ(x_dist_attr.verify(), true); EXPECT_EQ(x_dist_attr.verify(y), true);
out_dist_attr.set_process_mesh(process_mesh); out_dist_attr.set_process_mesh(process_mesh);
out_dist_attr.set_dims_mapping(std::vector<int64_t>({0, 1})); out_dist_attr.set_dims_mapping(std::vector<int64_t>({0, 1}));
...@@ -107,18 +109,25 @@ TEST(DistAttr, ctor) { ...@@ -107,18 +109,25 @@ TEST(DistAttr, ctor) {
EXPECT_EQ(out_dist_attr.dims_mapping(), std::vector<int64_t>({0, 1})); EXPECT_EQ(out_dist_attr.dims_mapping(), std::vector<int64_t>({0, 1}));
EXPECT_EQ(out_dist_attr.batch_dim(), 1); EXPECT_EQ(out_dist_attr.batch_dim(), 1);
EXPECT_EQ(out_dist_attr.dynamic_dims(), std::vector<bool>({false, false})); EXPECT_EQ(out_dist_attr.dynamic_dims(), std::vector<bool>({false, false}));
EXPECT_EQ(out_dist_attr.verify(), true); EXPECT_EQ(out_dist_attr.verify(out), true);
OperatorDistAttr mul_dist_attr(*op); OperatorDistAttr mul_dist_attr(*op);
EXPECT_EQ(mul_dist_attr.impl_type(), kDefault);
EXPECT_EQ(mul_dist_attr.impl_idx(), -1);
EXPECT_EQ(mul_dist_attr.is_recompute(), false);
EXPECT_EQ(mul_dist_attr.is_annotated("process_mesh"), false);
EXPECT_EQ(mul_dist_attr.is_annotated("impl_type"), false);
EXPECT_EQ(mul_dist_attr.is_annotated("impl_idx"), false);
mul_dist_attr.set_input_dist_attr(x->Name(), x_dist_attr); mul_dist_attr.set_input_dist_attr(x->Name(), x_dist_attr);
mul_dist_attr.set_input_dist_attr(y->Name(), y_dist_attr); mul_dist_attr.set_input_dist_attr(y->Name(), y_dist_attr);
mul_dist_attr.set_output_dist_attr(out->Name(), out_dist_attr); mul_dist_attr.set_output_dist_attr(out->Name(), out_dist_attr);
mul_dist_attr.set_process_mesh(process_mesh2); mul_dist_attr.set_process_mesh(process_mesh2);
mul_dist_attr.set_impl_type("dist_mul"); mul_dist_attr.set_impl_type("dist_mul");
mul_dist_attr.set_impl_idx(0); mul_dist_attr.set_impl_idx(0);
mul_dist_attr.annotate("process_mesh"); mul_dist_attr.set_is_recompute(true);
mul_dist_attr.annotate("impl_type"); mul_dist_attr.mark_annotated("process_mesh");
mul_dist_attr.annotate("impl_idx"); mul_dist_attr.mark_annotated("impl_type");
mul_dist_attr.mark_annotated("impl_idx");
EXPECT_NE(mul_dist_attr.input_dist_attr(x->Name()), x_dist_attr); EXPECT_NE(mul_dist_attr.input_dist_attr(x->Name()), x_dist_attr);
EXPECT_NE(mul_dist_attr.input_dist_attr(y->Name()), y_dist_attr); EXPECT_NE(mul_dist_attr.input_dist_attr(y->Name()), y_dist_attr);
EXPECT_NE(mul_dist_attr.output_dist_attr(out->Name()), out_dist_attr); EXPECT_NE(mul_dist_attr.output_dist_attr(out->Name()), out_dist_attr);
...@@ -129,10 +138,13 @@ TEST(DistAttr, ctor) { ...@@ -129,10 +138,13 @@ TEST(DistAttr, ctor) {
process_mesh2); process_mesh2);
EXPECT_EQ(mul_dist_attr.impl_type(), "dist_mul"); EXPECT_EQ(mul_dist_attr.impl_type(), "dist_mul");
EXPECT_EQ(mul_dist_attr.impl_idx(), 0); EXPECT_EQ(mul_dist_attr.impl_idx(), 0);
EXPECT_EQ(mul_dist_attr.is_recompute(), true);
EXPECT_EQ(mul_dist_attr.is_annotated("process_mesh"), true); EXPECT_EQ(mul_dist_attr.is_annotated("process_mesh"), true);
EXPECT_EQ(mul_dist_attr.is_annotated("impl_type"), true); EXPECT_EQ(mul_dist_attr.is_annotated("impl_type"), true);
EXPECT_EQ(mul_dist_attr.is_annotated("impl_idx"), true); EXPECT_EQ(mul_dist_attr.is_annotated("impl_idx"), true);
EXPECT_EQ(mul_dist_attr.verify(), true); EXPECT_EQ(mul_dist_attr.verify(op), true);
mul_dist_attr.clear_annotated();
EXPECT_EQ(mul_dist_attr.annotated().empty(), true);
std::stringstream mul_sstream; std::stringstream mul_sstream;
mul_sstream << mul_dist_attr; mul_sstream << mul_dist_attr;
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/var_desc.h" #include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/pybind/auto_parallel_py.h" #include "paddle/fluid/pybind/auto_parallel_py.h"
#include "paddle/utils/optional.h"
namespace py = pybind11; namespace py = pybind11;
...@@ -32,6 +33,7 @@ using paddle::distributed::auto_parallel::Device; ...@@ -32,6 +33,7 @@ using paddle::distributed::auto_parallel::Device;
using paddle::distributed::auto_parallel::DeviceCapability; using paddle::distributed::auto_parallel::DeviceCapability;
using paddle::distributed::auto_parallel::DeviceMesh; using paddle::distributed::auto_parallel::DeviceMesh;
using paddle::distributed::auto_parallel::DistributedMapper; using paddle::distributed::auto_parallel::DistributedMapper;
using paddle::distributed::auto_parallel::kDefault;
using paddle::distributed::auto_parallel::Link; using paddle::distributed::auto_parallel::Link;
using paddle::distributed::auto_parallel::LinkCapability; using paddle::distributed::auto_parallel::LinkCapability;
using paddle::distributed::auto_parallel::Machine; using paddle::distributed::auto_parallel::Machine;
...@@ -41,22 +43,73 @@ using paddle::distributed::auto_parallel::TensorDistAttr; ...@@ -41,22 +43,73 @@ using paddle::distributed::auto_parallel::TensorDistAttr;
using paddle::framework::OpDesc; using paddle::framework::OpDesc;
using paddle::framework::VarDesc; using paddle::framework::VarDesc;
static inline const ProcessMesh *get_tensor_process_mesh(
const TensorDistAttr &self) {
if (self.process_mesh().empty()) {
return nullptr;
} else {
return &self.process_mesh();
}
}
static inline void set_tensor_process_mesh(TensorDistAttr *self,
const ProcessMesh *process_mesh) {
if (process_mesh) {
self->set_process_mesh(*process_mesh);
} else {
self->set_process_mesh(ProcessMesh());
}
}
static inline const ProcessMesh *get_operator_process_mesh(
const OperatorDistAttr &self) {
if (self.process_mesh().empty()) {
return nullptr;
} else {
return &self.process_mesh();
}
}
static inline void set_operator_process_mesh(OperatorDistAttr *self,
const ProcessMesh *process_mesh) {
if (process_mesh) {
self->set_process_mesh(*process_mesh);
} else {
self->set_process_mesh(ProcessMesh());
}
}
static inline void reset_tensor_dist_attr(TensorDistAttr *dist_attr) {
dist_attr->set_process_mesh(ProcessMesh());
std::vector<int64_t> dims_mapping(dist_attr->dims_mapping().size(), -1);
dist_attr->set_dims_mapping(dims_mapping);
dist_attr->clear_annotated();
}
static inline void reset_operator_dist_attr(OperatorDistAttr *dist_attr) {
for (auto &item : dist_attr->input_dist_attrs()) {
reset_tensor_dist_attr(&item.second);
}
for (auto &item : dist_attr->output_dist_attrs()) {
reset_tensor_dist_attr(&item.second);
}
dist_attr->set_impl_type(kDefault);
dist_attr->set_impl_idx(0);
dist_attr->clear_annotated();
}
void BindAutoParallel(py::module *m) { void BindAutoParallel(py::module *m) {
py::class_<ProcessMesh>(*m, "ProcessMesh") py::class_<ProcessMesh>(*m, "ProcessMesh")
.def(py::init<>())
.def(py::init<const std::vector<int64_t> &, .def(py::init<const std::vector<int64_t> &,
const std::vector<int64_t> &, const std::vector<int64_t> &,
const std::vector<std::string> &>(), const std::vector<std::string> &>(),
py::arg("shape"), py::arg("shape"),
py::arg("process_ids"), py::arg("process_ids"),
py::arg("dim_names")) py::arg("dim_names"))
.def_property_readonly( .def_property_readonly("shape", &ProcessMesh::shape)
"shape", &ProcessMesh::shape, py::return_value_policy::reference) .def_property_readonly("process_ids", &ProcessMesh::process_ids)
.def_property_readonly("process_ids", .def_property_readonly("dim_names", &ProcessMesh::dim_names)
&ProcessMesh::process_ids,
py::return_value_policy::reference)
.def_property_readonly("dim_names",
&ProcessMesh::dim_names,
py::return_value_policy::reference)
.def_property_readonly("size", &ProcessMesh::size) .def_property_readonly("size", &ProcessMesh::size)
.def_property_readonly("ndim", &ProcessMesh::ndim) .def_property_readonly("ndim", &ProcessMesh::ndim)
.def("dim_size", .def("dim_size",
...@@ -121,10 +174,8 @@ void BindAutoParallel(py::module *m) { ...@@ -121,10 +174,8 @@ void BindAutoParallel(py::module *m) {
py::class_<Machine>(*m, "Machine") py::class_<Machine>(*m, "Machine")
.def_property_readonly("id", &Machine::id) .def_property_readonly("id", &Machine::id)
.def_property_readonly( .def_property_readonly("devices", &Machine::devices)
"devices", &Machine::devices, py::return_value_policy::reference) .def_property_readonly("links", &Machine::links)
.def_property_readonly(
"links", &Machine::links, py::return_value_policy::reference)
.def("device", &Machine::device) .def("device", &Machine::device)
.def("link", &Machine::link) .def("link", &Machine::link)
.def("contains", &Machine::contains) .def("contains", &Machine::contains)
...@@ -141,21 +192,14 @@ void BindAutoParallel(py::module *m) { ...@@ -141,21 +192,14 @@ void BindAutoParallel(py::module *m) {
py::arg("dim_names")) py::arg("dim_names"))
.def_property_readonly("name", &DeviceMesh::name) .def_property_readonly("name", &DeviceMesh::name)
.def_property_readonly("shape", &DeviceMesh::shape) .def_property_readonly("shape", &DeviceMesh::shape)
.def_property_readonly("device_ids", .def_property_readonly("device_ids", &DeviceMesh::device_ids)
&DeviceMesh::device_ids, .def_property_readonly("dim_names", &DeviceMesh::dim_names)
py::return_value_policy::reference)
.def_property_readonly("dim_names",
&DeviceMesh::dim_names,
py::return_value_policy::reference)
.def_property_readonly("device_type", &DeviceMesh::device_type) .def_property_readonly("device_type", &DeviceMesh::device_type)
.def_property_readonly("size", &DeviceMesh::size) .def_property_readonly("size", &DeviceMesh::size)
.def_property_readonly("ndim", &DeviceMesh::ndim) .def_property_readonly("ndim", &DeviceMesh::ndim)
.def_property_readonly( .def_property_readonly("devices", &DeviceMesh::devices)
"devices", &DeviceMesh::devices, py::return_value_policy::reference) .def_property_readonly("links", &DeviceMesh::links)
.def_property_readonly( .def_property_readonly("machines", &DeviceMesh::machines)
"links", &DeviceMesh::links, py::return_value_policy::reference)
.def_property_readonly(
"machines", &DeviceMesh::machines, py::return_value_policy::reference)
.def("device", &DeviceMesh::device) .def("device", &DeviceMesh::device)
.def("link", &DeviceMesh::link) .def("link", &DeviceMesh::link)
.def("machine", &DeviceMesh::machine) .def("machine", &DeviceMesh::machine)
...@@ -182,11 +226,11 @@ void BindAutoParallel(py::module *m) { ...@@ -182,11 +226,11 @@ void BindAutoParallel(py::module *m) {
.def("__str__", &DeviceMesh::to_string); .def("__str__", &DeviceMesh::to_string);
py::class_<TensorDistAttr>(*m, "TensorDistAttr") py::class_<TensorDistAttr>(*m, "TensorDistAttr")
.def(py::init<>())
.def(py::init<const VarDesc &>()) .def(py::init<const VarDesc &>())
.def_property_readonly("tensor", &TensorDistAttr::tensor) .def(py::init<const TensorDistAttr &>())
.def_property("process_mesh", .def_property(
&TensorDistAttr::process_mesh, "process_mesh", &get_tensor_process_mesh, &set_tensor_process_mesh)
&TensorDistAttr::set_process_mesh)
.def_property("dims_mapping", .def_property("dims_mapping",
&TensorDistAttr::dims_mapping, &TensorDistAttr::dims_mapping,
&TensorDistAttr::set_dims_mapping) &TensorDistAttr::set_dims_mapping)
...@@ -200,8 +244,12 @@ void BindAutoParallel(py::module *m) { ...@@ -200,8 +244,12 @@ void BindAutoParallel(py::module *m) {
&TensorDistAttr::annotated, &TensorDistAttr::annotated,
&TensorDistAttr::set_annotated) &TensorDistAttr::set_annotated)
.def("is_annotated", &TensorDistAttr::is_annotated) .def("is_annotated", &TensorDistAttr::is_annotated)
.def("annotate", &TensorDistAttr::annotate) .def("mark_annotated", &TensorDistAttr::mark_annotated)
.def("verify", &TensorDistAttr::verify) .def("clear_annotated", &TensorDistAttr::clear_annotated)
.def("verify",
&TensorDistAttr::verify,
py::arg("tensor") = static_cast<VarDesc *>(nullptr))
.def("reset", &reset_tensor_dist_attr)
.def("serialize_to_string", .def("serialize_to_string",
[](TensorDistAttr &self) { [](TensorDistAttr &self) {
return py::bytes(self.serialize_to_string()); return py::bytes(self.serialize_to_string());
...@@ -209,20 +257,34 @@ void BindAutoParallel(py::module *m) { ...@@ -209,20 +257,34 @@ void BindAutoParallel(py::module *m) {
.def("parse_from_string", &TensorDistAttr::parse_from_string) .def("parse_from_string", &TensorDistAttr::parse_from_string)
.def(py::self == py::self) .def(py::self == py::self)
.def(py::self != py::self) .def(py::self != py::self)
.def("__copy__",
[](const TensorDistAttr &self) { return TensorDistAttr(self); })
.def(
"__deepcopy__",
[](const TensorDistAttr &self, py::dict) {
return TensorDistAttr(self);
},
py::arg("memo"))
.def("__str__", &TensorDistAttr::to_string); .def("__str__", &TensorDistAttr::to_string);
py::class_<OperatorDistAttr>(*m, "OperatorDistAttr") py::class_<OperatorDistAttr>(*m, "OperatorDistAttr")
.def(py::init<>())
.def(py::init<const OpDesc &>()) .def(py::init<const OpDesc &>())
.def_property_readonly("op", &OperatorDistAttr::op) .def(py::init<const OperatorDistAttr &>())
.def_property(
"op_type", &OperatorDistAttr::op_type, &OperatorDistAttr::set_op_type)
.def_property("process_mesh", .def_property("process_mesh",
&OperatorDistAttr::process_mesh, &get_operator_process_mesh,
&OperatorDistAttr::set_process_mesh) &set_operator_process_mesh)
.def_property("impl_type", .def_property("impl_type",
&OperatorDistAttr::impl_type, &OperatorDistAttr::impl_type,
&OperatorDistAttr::set_impl_type) &OperatorDistAttr::set_impl_type)
.def_property("impl_idx", .def_property("impl_idx",
&OperatorDistAttr::impl_idx, &OperatorDistAttr::impl_idx,
&OperatorDistAttr::set_impl_idx) &OperatorDistAttr::set_impl_idx)
.def_property("is_recompute",
&OperatorDistAttr::is_recompute,
&OperatorDistAttr::set_is_recompute)
.def_property("execution_stream", .def_property("execution_stream",
&OperatorDistAttr::execution_stream, &OperatorDistAttr::execution_stream,
&OperatorDistAttr::set_execution_stream) &OperatorDistAttr::set_execution_stream)
...@@ -232,14 +294,16 @@ void BindAutoParallel(py::module *m) { ...@@ -232,14 +294,16 @@ void BindAutoParallel(py::module *m) {
.def_property("annotated", .def_property("annotated",
&OperatorDistAttr::annotated, &OperatorDistAttr::annotated,
&OperatorDistAttr::set_annotated) &OperatorDistAttr::set_annotated)
.def_property("inputs_dist_attrs", .def_property(
&OperatorDistAttr::input_dist_attrs, "inputs_dist_attrs",
&OperatorDistAttr::set_input_dist_attrs) static_cast<std::map<std::string, TensorDistAttr> &(
.def_property("outputs_dist_attrs", OperatorDistAttr::*)()>(&OperatorDistAttr::input_dist_attrs),
&OperatorDistAttr::output_dist_attrs, &OperatorDistAttr::set_input_dist_attrs)
&OperatorDistAttr::set_output_dist_attrs) .def_property(
.def("input", &OperatorDistAttr::input) "outputs_dist_attrs",
.def("output", &OperatorDistAttr::output) static_cast<std::map<std::string, TensorDistAttr> &(
OperatorDistAttr::*)()>(&OperatorDistAttr::output_dist_attrs),
&OperatorDistAttr::set_output_dist_attrs)
.def("get_input_dist_attr", .def("get_input_dist_attr",
static_cast<TensorDistAttr &( static_cast<TensorDistAttr &(
OperatorDistAttr::*)(const std::string &)>( OperatorDistAttr::*)(const std::string &)>(
...@@ -252,14 +316,40 @@ void BindAutoParallel(py::module *m) { ...@@ -252,14 +316,40 @@ void BindAutoParallel(py::module *m) {
py::return_value_policy::reference) py::return_value_policy::reference)
.def("set_input_dist_attr", &OperatorDistAttr::set_input_dist_attr) .def("set_input_dist_attr", &OperatorDistAttr::set_input_dist_attr)
.def("set_output_dist_attr", &OperatorDistAttr::set_output_dist_attr) .def("set_output_dist_attr", &OperatorDistAttr::set_output_dist_attr)
.def("del_input_dist_attr", // TODO(aoyulong): move into dist_attr.cc
[](OperatorDistAttr &self, const std::string &name) {
self.input_dist_attrs().erase(name);
})
.def("del_output_dist_attr", // TODO(aoyulong): move into dist_attr.cc
[](OperatorDistAttr &self, const std::string &name) {
self.output_dist_attrs().erase(name);
})
.def("is_annotated", &OperatorDistAttr::is_annotated) .def("is_annotated", &OperatorDistAttr::is_annotated)
.def("annotate", &OperatorDistAttr::annotate) .def("mark_annotated", &OperatorDistAttr::mark_annotated)
.def("get_input_dims_mapping", &OperatorDistAttr::input_dims_mapping) .def("clear_annotated", &OperatorDistAttr::clear_annotated)
.def("get_input_dims_mapping",
&OperatorDistAttr::input_dims_mapping,
py::return_value_policy::reference)
.def("set_input_dims_mapping", &OperatorDistAttr::set_input_dims_mapping) .def("set_input_dims_mapping", &OperatorDistAttr::set_input_dims_mapping)
.def("get_output_dims_mapping", &OperatorDistAttr::output_dims_mapping) .def("get_output_dims_mapping",
&OperatorDistAttr::output_dims_mapping,
py::return_value_policy::reference)
.def("set_output_dims_mapping", .def("set_output_dims_mapping",
&OperatorDistAttr::set_output_dims_mapping) &OperatorDistAttr::set_output_dims_mapping)
.def("verify", &OperatorDistAttr::verify) .def("verify",
&OperatorDistAttr::verify,
py::arg("op") = static_cast<OpDesc *>(nullptr))
.def("is_annotated_input_dims_mapping",
[](const OperatorDistAttr &self, const std::string &name) {
return self.input_dist_attr(name).is_annotated("dims_mapping");
})
.def("is_annotated_output_dims_mapping",
[](const OperatorDistAttr &self, const std::string &name) {
return self.output_dist_attr(name).is_annotated("dims_mapping");
})
.def("rename_input", &OperatorDistAttr::rename_input)
.def("rename_output", &OperatorDistAttr::rename_output)
.def("reset", &reset_operator_dist_attr)
.def("serialize_to_string", .def("serialize_to_string",
[](OperatorDistAttr &self) { [](OperatorDistAttr &self) {
return py::bytes(self.serialize_to_string()); return py::bytes(self.serialize_to_string());
......
...@@ -18,10 +18,7 @@ import logging ...@@ -18,10 +18,7 @@ import logging
from paddle.distributed.fleet.meta_optimizers.common import OpRole from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.fluid import core from paddle.fluid import core
from .dist_attribute import ( from .dist_attribute import OperatorDistAttr, TensorDistAttr
OperatorDistributedAttribute,
TensorDistributedAttribute,
)
from .dist_context import _node_id from .dist_context import _node_id
from .operators import find_compatible_distributed_operator_impls from .operators import find_compatible_distributed_operator_impls
from .process_group import get_world_process_group from .process_group import get_world_process_group
...@@ -610,10 +607,10 @@ class Completer: ...@@ -610,10 +607,10 @@ class Completer:
return related_nodes return related_nodes
def _make_dims_mapping_replicate(dist_attr): def _make_dims_mapping_replicate(dist_attr):
if isinstance(dist_attr, TensorDistributedAttribute): if isinstance(dist_attr, TensorDistAttr):
for i, _ in enumerate(dist_attr.dims_mapping): for i, _ in enumerate(dist_attr.dims_mapping):
dist_attr.dims_mapping[i] = -1 dist_attr.dims_mapping[i] = -1
if isinstance(dist_attr, OperatorDistributedAttribute): if isinstance(dist_attr, OperatorDistAttr):
for arg_name in dist_attr.inputs_dist_attrs.keys(): for arg_name in dist_attr.inputs_dist_attrs.keys():
new_dims_mapping = [] new_dims_mapping = []
dims_mapping = dist_attr.get_input_dims_mapping(arg_name) dims_mapping = dist_attr.get_input_dims_mapping(arg_name)
...@@ -942,6 +939,7 @@ class Completer: ...@@ -942,6 +939,7 @@ class Completer:
self._dist_context._serial_main_program = serial_main_program self._dist_context._serial_main_program = serial_main_program
if not is_naive_data_parallel(self._dist_context): if not is_naive_data_parallel(self._dist_context):
print("$$$$$$ here 0", flush=True)
self._dist_context.initialize(with_graph=True) self._dist_context.initialize(with_graph=True)
self._prepare() self._prepare()
self._update_process_mesh() self._update_process_mesh()
...@@ -949,6 +947,7 @@ class Completer: ...@@ -949,6 +947,7 @@ class Completer:
# Copy the corresponding distributed attribute from graph to serial_main_program # Copy the corresponding distributed attribute from graph to serial_main_program
self._dist_context.copy_dist_attr_from_graph_to_program() self._dist_context.copy_dist_attr_from_graph_to_program()
else: else:
print("$$$$$$ here 2", flush=True)
self._logger.info("Default distributed attributed will be set.") self._logger.info("Default distributed attributed will be set.")
self._dist_context.initialize(with_graph=False) self._dist_context.initialize(with_graph=False)
# A fast and special completion for data parallel # A fast and special completion for data parallel
...@@ -1185,7 +1184,7 @@ class Completer: ...@@ -1185,7 +1184,7 @@ class Completer:
self._dist_context.get_op_dist_attr_for_program(forward_op) self._dist_context.get_op_dist_attr_for_program(forward_op)
) )
fwd_op_process_mesh = fwd_op_dist_attr.process_mesh fwd_op_process_mesh = fwd_op_dist_attr.process_mesh
grad_op_dist_attr = OperatorDistributedAttribute() grad_op_dist_attr = OperatorDistAttr()
grad_op_dist_attr.process_mesh = fwd_op_process_mesh grad_op_dist_attr.process_mesh = fwd_op_process_mesh
for input_name in grad_op.input_arg_names: for input_name in grad_op.input_arg_names:
...@@ -1235,7 +1234,7 @@ class Completer: ...@@ -1235,7 +1234,7 @@ class Completer:
) )
# var # var
output_var = vars[output_name] output_var = vars[output_name]
tensor_dist_attr = TensorDistributedAttribute() tensor_dist_attr = TensorDistAttr()
tensor_dist_attr.dims_mapping = ref_dims_mapping tensor_dist_attr.dims_mapping = ref_dims_mapping
tensor_dist_attr.process_mesh = fwd_op_process_mesh tensor_dist_attr.process_mesh = fwd_op_process_mesh
self._dist_context.set_tensor_dist_attr_for_program( self._dist_context.set_tensor_dist_attr_for_program(
...@@ -1273,7 +1272,7 @@ class Completer: ...@@ -1273,7 +1272,7 @@ class Completer:
ref_fwd_dims_mapping = ref_fwd_dist_attr.dims_mapping ref_fwd_dims_mapping = ref_fwd_dist_attr.dims_mapping
ref_fwd_process_mesh = ref_fwd_dist_attr.process_mesh ref_fwd_process_mesh = ref_fwd_dist_attr.process_mesh
# output # output
tensor_dist_attr = TensorDistributedAttribute() tensor_dist_attr = TensorDistAttr()
tensor_dist_attr.dims_mapping = ref_fwd_dims_mapping tensor_dist_attr.dims_mapping = ref_fwd_dims_mapping
tensor_dist_attr.process_mesh = ref_fwd_process_mesh tensor_dist_attr.process_mesh = ref_fwd_process_mesh
output_var = vars[output_name] output_var = vars[output_name]
...@@ -1281,7 +1280,7 @@ class Completer: ...@@ -1281,7 +1280,7 @@ class Completer:
output_var, tensor_dist_attr output_var, tensor_dist_attr
) )
# op # op
grad_op_dist_attr = OperatorDistributedAttribute() grad_op_dist_attr = OperatorDistAttr()
grad_op_dist_attr.process_mesh = ref_fwd_process_mesh grad_op_dist_attr.process_mesh = ref_fwd_process_mesh
for var_name in grad_op.input_arg_names: for var_name in grad_op.input_arg_names:
grad_op_dist_attr.set_input_dims_mapping( grad_op_dist_attr.set_input_dims_mapping(
...@@ -1302,7 +1301,7 @@ class Completer: ...@@ -1302,7 +1301,7 @@ class Completer:
ref_dims_mapping = ref_dist_attr.dims_mapping ref_dims_mapping = ref_dist_attr.dims_mapping
ref_process_mesh = ref_dist_attr.process_mesh ref_process_mesh = ref_dist_attr.process_mesh
# output # output
tensor_dist_attr = TensorDistributedAttribute() tensor_dist_attr = TensorDistAttr()
tensor_dist_attr.dims_mapping = ref_dims_mapping tensor_dist_attr.dims_mapping = ref_dims_mapping
tensor_dist_attr.process_mesh = ref_process_mesh tensor_dist_attr.process_mesh = ref_process_mesh
output_var_name = grad_op.output_arg_names[0] output_var_name = grad_op.output_arg_names[0]
...@@ -1311,7 +1310,7 @@ class Completer: ...@@ -1311,7 +1310,7 @@ class Completer:
output_var, tensor_dist_attr output_var, tensor_dist_attr
) )
# op # op
grad_op_dist_attr = OperatorDistributedAttribute() grad_op_dist_attr = OperatorDistAttr()
grad_op_dist_attr.process_mesh = ref_process_mesh grad_op_dist_attr.process_mesh = ref_process_mesh
grad_op_dist_attr.set_input_dims_mapping( grad_op_dist_attr.set_input_dims_mapping(
ref_var_name, ref_dims_mapping ref_var_name, ref_dims_mapping
...@@ -1401,7 +1400,7 @@ class Completer: ...@@ -1401,7 +1400,7 @@ class Completer:
forward_var = vars[forward_var_name] forward_var = vars[forward_var_name]
# TODO complete other attribte for grad var # TODO complete other attribte for grad var
tensor_dist_attr = TensorDistributedAttribute() tensor_dist_attr = TensorDistAttr()
process_mesh = ( process_mesh = (
self._dist_context.get_tensor_dist_attr_for_program( self._dist_context.get_tensor_dist_attr_for_program(
forward_var forward_var
...@@ -1418,7 +1417,7 @@ class Completer: ...@@ -1418,7 +1417,7 @@ class Completer:
grad_var, tensor_dist_attr grad_var, tensor_dist_attr
) )
op_dist_attr = OperatorDistributedAttribute() op_dist_attr = OperatorDistAttr()
op_dist_attr.process_mesh = process_mesh op_dist_attr.process_mesh = process_mesh
op_dist_attr.set_output_dims_mapping( op_dist_attr.set_output_dims_mapping(
grad_var.name, dims_mapping grad_var.name, dims_mapping
...@@ -1459,13 +1458,13 @@ class Completer: ...@@ -1459,13 +1458,13 @@ class Completer:
) )
ref_mesh = forward_op_dist_attr.process_mesh ref_mesh = forward_op_dist_attr.process_mesh
grad_op_dist_attr = OperatorDistributedAttribute() grad_op_dist_attr = OperatorDistAttr()
for input_name in grad_op.input_arg_names: for input_name in grad_op.input_arg_names:
grad_op_dist_attr.set_input_dims_mapping( grad_op_dist_attr.set_input_dims_mapping(
input_name, ref_dims_mapping input_name, ref_dims_mapping
) )
output_var_dist_attr = TensorDistributedAttribute() output_var_dist_attr = TensorDistAttr()
output_var_dist_attr.dims_mapping = ref_dims_mapping output_var_dist_attr.dims_mapping = ref_dims_mapping
output_var_dist_attr.process_mesh = ref_mesh output_var_dist_attr.process_mesh = ref_mesh
self._dist_context.set_tensor_dist_attr_for_program( self._dist_context.set_tensor_dist_attr_for_program(
...@@ -1492,7 +1491,7 @@ class Completer: ...@@ -1492,7 +1491,7 @@ class Completer:
self._dist_context.get_op_dist_attr_for_program(forward_op) self._dist_context.get_op_dist_attr_for_program(forward_op)
) )
fwd_op_process_mesh = fwd_op_dist_attr.process_mesh fwd_op_process_mesh = fwd_op_dist_attr.process_mesh
grad_op_dist_attr = OperatorDistributedAttribute() grad_op_dist_attr = OperatorDistAttr()
grad_op_dist_attr.process_mesh = fwd_op_process_mesh grad_op_dist_attr.process_mesh = fwd_op_process_mesh
for input_name in grad_op.input_arg_names: for input_name in grad_op.input_arg_names:
...@@ -1540,7 +1539,7 @@ class Completer: ...@@ -1540,7 +1539,7 @@ class Completer:
) )
# var # var
output_var = vars[output_name] output_var = vars[output_name]
tensor_dist_attr = TensorDistributedAttribute() tensor_dist_attr = TensorDistAttr()
tensor_dist_attr.dims_mapping = ref_dims_mapping tensor_dist_attr.dims_mapping = ref_dims_mapping
tensor_dist_attr.process_mesh = fwd_op_process_mesh tensor_dist_attr.process_mesh = fwd_op_process_mesh
self._dist_context.set_tensor_dist_attr_for_program( self._dist_context.set_tensor_dist_attr_for_program(
...@@ -1556,7 +1555,6 @@ class Completer: ...@@ -1556,7 +1555,6 @@ class Completer:
self._dist_context.set_op_dist_attr_for_program( self._dist_context.set_op_dist_attr_for_program(
grad_op, grad_op_dist_attr grad_op, grad_op_dist_attr
) )
# grad ops that have not a corresponding mapping in grad_op_id_to_op_id # grad ops that have not a corresponding mapping in grad_op_id_to_op_id
else: else:
if grad_op.type == 'sum': if grad_op.type == 'sum':
...@@ -1578,7 +1576,7 @@ class Completer: ...@@ -1578,7 +1576,7 @@ class Completer:
ref_fwd_process_mesh = ref_fwd_dist_attr.process_mesh ref_fwd_process_mesh = ref_fwd_dist_attr.process_mesh
# output # output
tensor_dist_attr = TensorDistributedAttribute() tensor_dist_attr = TensorDistAttr()
tensor_dist_attr.dims_mapping = ref_fwd_dims_mapping tensor_dist_attr.dims_mapping = ref_fwd_dims_mapping
tensor_dist_attr.process_mesh = ref_fwd_process_mesh tensor_dist_attr.process_mesh = ref_fwd_process_mesh
output_var = vars[output_name] output_var = vars[output_name]
...@@ -1587,7 +1585,7 @@ class Completer: ...@@ -1587,7 +1585,7 @@ class Completer:
) )
# op # op
grad_op_dist_attr = OperatorDistributedAttribute() grad_op_dist_attr = OperatorDistAttr()
grad_op_dist_attr.process_mesh = ref_fwd_process_mesh grad_op_dist_attr.process_mesh = ref_fwd_process_mesh
for var_name in grad_op.input_arg_names: for var_name in grad_op.input_arg_names:
grad_op_dist_attr.set_input_dims_mapping( grad_op_dist_attr.set_input_dims_mapping(
...@@ -1610,7 +1608,7 @@ class Completer: ...@@ -1610,7 +1608,7 @@ class Completer:
ref_dims_mapping = ref_dist_attr.dims_mapping ref_dims_mapping = ref_dist_attr.dims_mapping
ref_process_mesh = ref_dist_attr.process_mesh ref_process_mesh = ref_dist_attr.process_mesh
# output # output
tensor_dist_attr = TensorDistributedAttribute() tensor_dist_attr = TensorDistAttr()
tensor_dist_attr.dims_mapping = ref_dims_mapping tensor_dist_attr.dims_mapping = ref_dims_mapping
tensor_dist_attr.process_mesh = ref_process_mesh tensor_dist_attr.process_mesh = ref_process_mesh
output_var_name = grad_op.output_arg_names[0] output_var_name = grad_op.output_arg_names[0]
...@@ -1619,7 +1617,7 @@ class Completer: ...@@ -1619,7 +1617,7 @@ class Completer:
output_var, tensor_dist_attr output_var, tensor_dist_attr
) )
# op # op
grad_op_dist_attr = OperatorDistributedAttribute() grad_op_dist_attr = OperatorDistAttr()
grad_op_dist_attr.process_mesh = ref_process_mesh grad_op_dist_attr.process_mesh = ref_process_mesh
grad_op_dist_attr.set_input_dims_mapping( grad_op_dist_attr.set_input_dims_mapping(
ref_var_name, ref_dims_mapping ref_var_name, ref_dims_mapping
...@@ -1670,8 +1668,9 @@ class Completer: ...@@ -1670,8 +1668,9 @@ class Completer:
"elementwise_div", "elementwise_div",
]: ]:
# complete op dist_attr with global world ranks # complete op dist_attr with global world ranks
op_dist_attr = OperatorDistributedAttribute() op_dist_attr = OperatorDistAttr()
op_dist_attr.process_mesh = world_ranks op_dist_attr.process_mesh = ProcessMesh(world_ranks)
for in_name in op.input_arg_names: for in_name in op.input_arg_names:
in_var = vars[in_name] in_var = vars[in_name]
in_dist_attr = self._dist_context.get_tensor_dist_attr_for_program( in_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
...@@ -1682,8 +1681,10 @@ class Completer: ...@@ -1682,8 +1681,10 @@ class Completer:
) )
for out_name in op.output_arg_names: for out_name in op.output_arg_names:
out_var = vars[out_name] out_var = vars[out_name]
out_dist_attr = TensorDistributedAttribute() out_dist_attr = TensorDistAttr()
out_dist_attr.process_mesh = world_ranks out_dist_attr.process_mesh = ProcessMesh(
world_ranks
)
out_dist_attr.dims_mapping = [ out_dist_attr.dims_mapping = [
-1 for _ in range(len(out_var.shape)) -1 for _ in range(len(out_var.shape))
] ]
...@@ -1709,7 +1710,9 @@ class Completer: ...@@ -1709,7 +1710,9 @@ class Completer:
op.type == "cast" op.type == "cast"
and ops[idx + 1].type == "elementwise_mul" and ops[idx + 1].type == "elementwise_mul"
): ):
ref_var = vars[ops[idx + 1].input("X")[0]] ref_var = vars[
ops[idx + 1].input("X")[0]
] # elementwise_mul 的输入
ref_dist_attr = self._dist_context.get_tensor_dist_attr_for_program( ref_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
ref_var ref_var
) )
...@@ -1718,7 +1721,7 @@ class Completer: ...@@ -1718,7 +1721,7 @@ class Completer:
# complete out_var's tensor_dist_attr # complete out_var's tensor_dist_attr
out_var = vars[op.output("Out")[0]] out_var = vars[op.output("Out")[0]]
out_dist_attr = TensorDistributedAttribute() out_dist_attr = TensorDistAttr()
out_dist_attr.process_mesh = ref_process_mesh out_dist_attr.process_mesh = ref_process_mesh
if out_var.shape == in_var.shape: if out_var.shape == in_var.shape:
out_dist_attr.dims_mapping = ref_dims_mapping out_dist_attr.dims_mapping = ref_dims_mapping
...@@ -1734,7 +1737,7 @@ class Completer: ...@@ -1734,7 +1737,7 @@ class Completer:
# complete op'd dist_attr # complete op'd dist_attr
# complete op process_mesh with input_var's process_mesh # complete op process_mesh with input_var's process_mesh
op_dist_attr = OperatorDistributedAttribute() op_dist_attr = OperatorDistAttr()
op_dist_attr.process_mesh = ref_process_mesh op_dist_attr.process_mesh = ref_process_mesh
for in_name in op.input_arg_names: for in_name in op.input_arg_names:
in_var = vars[in_name] in_var = vars[in_name]
...@@ -1785,7 +1788,7 @@ class Completer: ...@@ -1785,7 +1788,7 @@ class Completer:
).dims_mapping ).dims_mapping
) )
assert ref_dims_mapping is not None assert ref_dims_mapping is not None
op_dist_attr = OperatorDistributedAttribute() op_dist_attr = OperatorDistAttr()
op_dist_attr.process_mesh = ref_process_mesh op_dist_attr.process_mesh = ref_process_mesh
op_dist_attr.set_input_dims_mapping( op_dist_attr.set_input_dims_mapping(
grad_var.name, ref_dims_mapping grad_var.name, ref_dims_mapping
...@@ -1804,8 +1807,8 @@ class Completer: ...@@ -1804,8 +1807,8 @@ class Completer:
if not learning_rate_completed: if not learning_rate_completed:
learning_rate_completed = True learning_rate_completed = True
var_dist_attr = TensorDistributedAttribute() var_dist_attr = TensorDistAttr()
var_dist_attr.process_mesh = world_ranks var_dist_attr.process_mesh = ProcessMesh(world_ranks)
var_dist_attr.dims_mapping = [-1] var_dist_attr.dims_mapping = [-1]
self._dist_context.set_tensor_dist_attr_for_program( self._dist_context.set_tensor_dist_attr_for_program(
learning_var, var_dist_attr learning_var, var_dist_attr
...@@ -1817,7 +1820,6 @@ class Completer: ...@@ -1817,7 +1820,6 @@ class Completer:
'Param', 'Param',
'Grad', 'Grad',
'LearningRate', 'LearningRate',
"SkipUpdate",
"Beta1Tensor", "Beta1Tensor",
"Beta2Tensor", "Beta2Tensor",
"EpsilonTensor", "EpsilonTensor",
...@@ -1828,9 +1830,13 @@ class Completer: ...@@ -1828,9 +1830,13 @@ class Completer:
assert len(op.desc.input(input_name)) == 1 assert len(op.desc.input(input_name)) == 1
input_var = vars[op.desc.input(input_name)[0]] input_var = vars[op.desc.input(input_name)[0]]
input_var_attr = TensorDistributedAttribute() input_var_attr = TensorDistAttr()
if "Beta1Pow" in input_name or "Beta2Pow" in input_name: if (
"Beta1Pow" in input_name
or "Beta2Pow" in input_name
or "SkipUpdate" in input_name
):
input_var_attr.dims_mapping = [-1] input_var_attr.dims_mapping = [-1]
op_dist_attr.set_input_dims_mapping( op_dist_attr.set_input_dims_mapping(
input_var.name, [-1] input_var.name, [-1]
...@@ -1894,12 +1900,12 @@ class Completer: ...@@ -1894,12 +1900,12 @@ class Completer:
tensor tensor
) )
assert dist_tensor is not None assert dist_tensor is not None
dist_tensor.dist_attr.process_mesh = world_ranks dist_tensor.dist_attr.process_mesh = ProcessMesh(world_ranks)
for op in block.ops: for op in block.ops:
# Copy the distributed operators in the default context # Copy the distributed operators in the default context
dist_op = self._dist_context.get_dist_op_for_program(op) dist_op = self._dist_context.get_dist_op_for_program(op)
assert dist_op is not None assert dist_op is not None
dist_op.dist_attr.process_mesh = world_ranks dist_op.dist_attr.process_mesh = ProcessMesh(world_ranks)
# Find the most compatible implemenetations from the distributed operator # Find the most compatible implemenetations from the distributed operator
op_dist_impls = find_compatible_distributed_operator_impls( op_dist_impls = find_compatible_distributed_operator_impls(
......
...@@ -431,11 +431,20 @@ def build_dp_costs( ...@@ -431,11 +431,20 @@ def build_dp_costs(
desc = {} desc = {}
desc["op"] = op_type desc["op"] = op_type
desc["inputs"] = {} desc["inputs"] = {}
dims_mapping = ( if var_name in dist_attr.inputs_dist_attrs:
dist_attr.get_input_dims_mapping(var_name) dims_mapping = dist_attr.get_input_dims_mapping(var_name)
if dist_attr.get_input_dims_mapping(var_name) is not None elif var_name in dist_attr.outputs_dist_attrs:
else dist_attr.get_output_dims_mapping(var_name) dims_mapping = dist_attr.get_output_dims_mapping(var_name)
) else:
assert False, "cannot find dims_mapping for {} in {}".format(
var_name, dist_attr
)
# dims_mapping = (
# dist_attr.get_input_dims_mapping(var_name)
# if dist_attr.get_input_dims_mapping(var_name) is not None
# else dist_attr.get_output_dims_mapping(var_name)
# )
var = get_var_with_recursion( var = get_var_with_recursion(
var_name, var_name,
dist_op.serial_op.block, dist_op.serial_op.block,
......
...@@ -448,7 +448,7 @@ class DistributedContext: ...@@ -448,7 +448,7 @@ class DistributedContext:
def add_process_mesh(self, process_mesh): def add_process_mesh(self, process_mesh):
assert isinstance( assert isinstance(
process_mesh, ProcessMesh process_mesh, (ProcessMesh, core.ProcessMesh)
), 'The type of dim_mapping must be ProcessMesh.' ), 'The type of dim_mapping must be ProcessMesh.'
if process_mesh not in self.process_meshes: if process_mesh not in self.process_meshes:
self._process_meshes.append(process_mesh) self._process_meshes.append(process_mesh)
...@@ -883,6 +883,7 @@ class DistributedContext: ...@@ -883,6 +883,7 @@ class DistributedContext:
dims_mapping[i] = -1 dims_mapping[i] = -1
if dims_mapping[i] != -1 and len(process_mesh_processes) == 1: if dims_mapping[i] != -1 and len(process_mesh_processes) == 1:
dims_mapping[i] = -1 dims_mapping[i] = -1
dist_attr.dims_mapping = dims_mapping
for dist_op in self._dist_ops_for_program.values(): for dist_op in self._dist_ops_for_program.values():
serial_op = dist_op.serial_op serial_op = dist_op.serial_op
...@@ -916,6 +917,7 @@ class DistributedContext: ...@@ -916,6 +917,7 @@ class DistributedContext:
and len(process_mesh_processes) == 1 and len(process_mesh_processes) == 1
): ):
dims_mapping[i] = -1 dims_mapping[i] = -1
dist_attr.set_input_dims_mapping(arg_name, dims_mapping)
for arg_name in serial_op.output_arg_names: for arg_name in serial_op.output_arg_names:
if ( if (
dist_op.get_serial_output(arg_name).type dist_op.get_serial_output(arg_name).type
...@@ -940,6 +942,7 @@ class DistributedContext: ...@@ -940,6 +942,7 @@ class DistributedContext:
and len(process_mesh_processes) == 1 and len(process_mesh_processes) == 1
): ):
dims_mapping[i] = -1 dims_mapping[i] = -1
dist_attr.set_output_dims_mapping(arg_name, dims_mapping)
if len(process_mesh_processes) == 1: if len(process_mesh_processes) == 1:
dist_op.dist_attr.impl_type = "default" dist_op.dist_attr.impl_type = "default"
dist_op.dist_attr.impl_idx = 0 dist_op.dist_attr.impl_idx = 0
......
...@@ -17,11 +17,7 @@ import copy ...@@ -17,11 +17,7 @@ import copy
import paddle import paddle
from paddle.fluid.framework import Variable from paddle.fluid.framework import Variable
from .dist_attribute import ( from .dist_attribute import OperatorDistAttr
OperatorDistributedAttribute,
append_op_input_suffix,
append_op_output_suffix,
)
from .utils import ( from .utils import (
__no_shape_var_type__, __no_shape_var_type__,
convert_to_shard_spec, convert_to_shard_spec,
...@@ -32,11 +28,20 @@ from .utils import ( ...@@ -32,11 +28,20 @@ from .utils import (
class DistributedOperator: class DistributedOperator:
def __init__(self, serial_op, dist_attr=None): def __init__(self, serial_op, dist_attr=None):
self._serial_op = serial_op self._serial_op = serial_op
if dist_attr is not None and isinstance(dist_attr, OperatorDistAttr):
pass
# TODO: remove this deepcopy after we fix the issue
self._dist_attr = copy.deepcopy(dist_attr)
# self._dist_attr = dist_attr
# TODO: Do we really need to write back to serial op?
self._serial_op.dist_attr = dist_attr
else:
assert dist_attr is None, "{}".format(dist_attr)
# Use the dist attr of serial_op to do the initialization
self._dist_attr = self._serial_op.dist_attr
self._serial_inputs = {} self._serial_inputs = {}
self._serial_outputs = {} self._serial_outputs = {}
self._dist_attr = None
# Reuse the dist_attr setter to initialize _dist_attr
self.dist_attr = dist_attr
@property @property
def serial_op(self): def serial_op(self):
...@@ -48,102 +53,110 @@ class DistributedOperator: ...@@ -48,102 +53,110 @@ class DistributedOperator:
@dist_attr.setter @dist_attr.setter
def dist_attr(self, dist_attr): def dist_attr(self, dist_attr):
if self._dist_attr is None: self._dist_attr = dist_attr
self._dist_attr = OperatorDistributedAttribute() # TODO: Do we really need to write back to serial op?
# Create new dist_attr related to current serial_op self._serial_op.dist_attr = dist_attr
dist_attr = self._filter_dist_attr(dist_attr) # if self._dist_attr is None:
# Append suffix to mark the inputs or outputs # self._dist_attr = OperatorDistAttr()
if isinstance(dist_attr, dict): # # Create new dist_attr related to current serial_op
# Copy the keys since we may add new ones # dist_attr = self._filter_dist_attr(dist_attr)
for key in list(dist_attr.keys()): # # Append suffix to mark the inputs or outputs
if isinstance(key, Variable): # if isinstance(dist_attr, dict):
if key.name in self._serial_op.input_arg_names: # # Copy the keys since we may add new ones
dist_attr[append_op_input_suffix(key.name)] = True # for key in list(dist_attr.keys()):
if key.name in self._serial_op.output_arg_names: # if isinstance(key, Variable):
dist_attr[append_op_output_suffix(key.name)] = True # if key.name in self._serial_op.input_arg_names:
self._dist_attr.init(dist_attr) # dist_attr[append_op_input_suffix(key.name)] = True
self._init_default_dist_attr() # if key.name in self._serial_op.output_arg_names:
# dist_attr[append_op_output_suffix(key.name)] = True
# self._dist_attr.init(dist_attr)
# self._init_default_dist_attr()
def get_serial_input(self, name): def get_serial_input(self, name):
return self._serial_inputs.get(name, None) if self._serial_op.type == "create_py_reader":
tensor = None
else:
tensor = self._serial_op.block._var_recursive(name)
return tensor
def get_serial_output(self, name): def get_serial_output(self, name):
return self._serial_outputs.get(name, None) tensor = self._serial_op.block._var_recursive(name)
return tensor
def _init_default_dist_attr(self):
for tensor_name in self._serial_op.input_arg_names: # def _init_default_dist_attr(self):
if self._serial_op.type == "create_py_reader": # for tensor_name in self._serial_op.input_arg_names:
tensor = None # if self._serial_op.type == "create_py_reader":
else: # tensor = None
tensor = self._serial_op.block._var_recursive(tensor_name) # else:
self._serial_inputs[tensor_name] = tensor # tensor = self._serial_op.block._var_recursive(tensor_name)
if tensor is None: # self._serial_inputs[tensor_name] = tensor
tensor_shape = [] # if tensor is None:
else: # tensor_shape = []
if tensor.type in __no_shape_var_type__: # else:
tensor_shape = [] # if tensor.type in __no_shape_var_type__:
else: # tensor_shape = []
tensor_shape = tensor.shape # else:
if self._dist_attr.get_input_dims_mapping(tensor_name) is None: # tensor_shape = tensor.shape
tensor_dims_mapping = [-1 for _ in range(len(tensor_shape))] # if self._dist_attr.get_input_dims_mapping(tensor_name) is None:
self._dist_attr.set_input_dims_mapping( # tensor_dims_mapping = [-1 for _ in range(len(tensor_shape))]
tensor_name, tensor_dims_mapping # self._dist_attr.set_input_dims_mapping(
) # tensor_name, tensor_dims_mapping
for tensor_name in self._serial_op.output_arg_names: # )
tensor = self._serial_op.block._var_recursive(tensor_name) # for tensor_name in self._serial_op.output_arg_names:
if tensor.type in __no_shape_var_type__: # tensor = self._serial_op.block._var_recursive(tensor_name)
tensor_shape = [] # if tensor.type in __no_shape_var_type__:
else: # tensor_shape = []
tensor_shape = tensor.shape # else:
self._serial_outputs[tensor_name] = tensor # tensor_shape = tensor.shape
if self._dist_attr.get_output_dims_mapping(tensor_name) is None: # self._serial_outputs[tensor_name] = tensor
tensor_dims_mapping = [-1 for _ in range(len(tensor_shape))] # if self._dist_attr.get_output_dims_mapping(tensor_name) is None:
self._dist_attr.set_output_dims_mapping( # tensor_dims_mapping = [-1 for _ in range(len(tensor_shape))]
tensor_name, tensor_dims_mapping # self._dist_attr.set_output_dims_mapping(
) # tensor_name, tensor_dims_mapping
if self._dist_attr.op_type is None: # )
self._dist_attr.op_type = self.serial_op.type # if self._dist_attr.op_type is None:
if self._dist_attr.impl_type is None: # self._dist_attr.op_type = self.serial_op.type
self._dist_attr.impl_type = "default" # if self._dist_attr.impl_type is None:
if self._dist_attr.impl_idx is None: # self._dist_attr.impl_type = "default"
self._dist_attr.impl_idx = 0 # if self._dist_attr.impl_idx is None:
if self._dist_attr.is_recompute is None: # self._dist_attr.impl_idx = 0
self._dist_attr.is_recompute = False # if self._dist_attr.is_recompute is None:
# self._dist_attr.is_recompute = False
def _filter_dist_attr(self, dist_attr):
if dist_attr is None: # def _filter_dist_attr(self, dist_attr):
return None # if dist_attr is None:
new_dist_attr = None # return None
if isinstance(dist_attr, dict): # new_dist_attr = None
new_dist_attr = {} # if isinstance(dist_attr, dict):
for key, value in dist_attr.items(): # new_dist_attr = {}
if isinstance(key, Variable): # for key, value in dist_attr.items():
if ( # if isinstance(key, Variable):
key.name in self._serial_op.input_arg_names # if (
or key.name in self._serial_op.output_arg_names # key.name in self._serial_op.input_arg_names
): # or key.name in self._serial_op.output_arg_names
new_dist_attr[key] = value # ):
else: # new_dist_attr[key] = value
new_dist_attr[key] = value # else:
elif isinstance(dist_attr, OperatorDistributedAttribute): # new_dist_attr[key] = value
new_dist_attr = copy.deepcopy(dist_attr) # elif isinstance(dist_attr, OperatorDistAttr):
new_dist_attr._inputs_dist_attrs.clear() # new_dist_attr = copy.deepcopy(dist_attr)
new_dist_attr._outputs_dist_attrs.clear() # new_dist_attr._inputs_dist_attrs.clear()
for tensor_name in self._serial_op.input_arg_names: # new_dist_attr._outputs_dist_attrs.clear()
tensor_dist_attr = dist_attr.get_input_dist_attr(tensor_name) # for tensor_name in self._serial_op.input_arg_names:
if tensor_dist_attr: # tensor_dist_attr = dist_attr.get_input_dist_attr(tensor_name)
new_dist_attr.set_input_dist_attr( # if tensor_dist_attr:
tensor_name, tensor_dist_attr # new_dist_attr.set_input_dist_attr(
) # tensor_name, tensor_dist_attr
for tensor_name in self._serial_op.output_arg_names: # )
tensor_dist_attr = dist_attr.get_output_dist_attr(tensor_name) # for tensor_name in self._serial_op.output_arg_names:
if tensor_dist_attr: # tensor_dist_attr = dist_attr.get_output_dist_attr(tensor_name)
new_dist_attr.set_output_dist_attr( # if tensor_dist_attr:
tensor_name, tensor_dist_attr # new_dist_attr.set_output_dist_attr(
) # tensor_name, tensor_dist_attr
else: # )
assert False, "Cannot recognize the {} parameter.".format(dist_attr) # else:
return new_dist_attr # assert False, "Cannot recognize the {} parameter.".format(dist_attr)
# return new_dist_attr
def validate_dist_attr(self): def validate_dist_attr(self):
if "read" in self.serial_op.type or "while" == self.serial_op.type: if "read" in self.serial_op.type or "while" == self.serial_op.type:
...@@ -190,8 +203,10 @@ class DistributedOperator: ...@@ -190,8 +203,10 @@ class DistributedOperator:
return True return True
def __str__(self): def __str__(self):
str = "{{op type: {}, op id: {}".format( str = "{{op type: {}, op id: {}, op original_id: {}".format(
self.serial_op.desc.type(), self.serial_op.desc.id() self.serial_op.desc.type(),
self.serial_op.desc.id(),
self.serial_op.desc.original_id(),
) )
# str += ", {}".format(self.dist_attr) # str += ", {}".format(self.dist_attr)
...@@ -239,10 +254,8 @@ class DistributedOperator: ...@@ -239,10 +254,8 @@ class DistributedOperator:
arg_name, annotated_str, is_parameter_str, dims_mapping arg_name, annotated_str, is_parameter_str, dims_mapping
) )
str += ", pipeline stage: {}".format(None)
str += ", dist_impl idx: {} , dist_impl type {} }}".format( str += ", dist_impl idx: {} , dist_impl type {} }}".format(
self.dist_attr._impl_idx, self.dist_attr._impl_type self.dist_attr.impl_idx, self.dist_attr.impl_type
) )
return str return str
......
...@@ -18,7 +18,7 @@ import inspect ...@@ -18,7 +18,7 @@ import inspect
import paddle import paddle
from paddle.fluid.framework import Block, Parameter, Variable from paddle.fluid.framework import Block, Parameter, Variable
from .dist_attribute import TensorDistributedAttribute from .dist_attribute import TensorDistAttr
from .utils import __no_shape_var_type__, _linear_idx2coordinate from .utils import __no_shape_var_type__, _linear_idx2coordinate
...@@ -75,9 +75,9 @@ class DistributedTensor: ...@@ -75,9 +75,9 @@ class DistributedTensor:
if rank is not None and not (isinstance(rank, int) and rank >= 0): if rank is not None and not (isinstance(rank, int) and rank >= 0):
raise ValueError("The rank must >= 0, but got {}".format(rank)) raise ValueError("The rank must >= 0, but got {}".format(rank))
# NOTE: Only support even sharding now # # NOTE: Only support even sharding now
if shard_sizes is not None: # if shard_sizes is not None:
raise ValueError("Only support even sharding now.") # raise ValueError("Only support even sharding now.")
@staticmethod @staticmethod
def get_local_sizes( def get_local_sizes(
...@@ -169,10 +169,18 @@ class DistributedTensor: ...@@ -169,10 +169,18 @@ class DistributedTensor:
def __init__(self, serial_tensor, dist_attr=None, dist_context=None): def __init__(self, serial_tensor, dist_attr=None, dist_context=None):
self._serial_tensor = serial_tensor self._serial_tensor = serial_tensor
self._dist_attr = None if dist_attr is not None and isinstance(dist_attr, TensorDistAttr):
# TODO: remove this deepcopy after we fix the issue
self._dist_attr = copy.deepcopy(dist_attr)
# self._dist_attr = dist_attr
# TODO: Do we really need to write dist_attr back to serial_tensor?
self._serial_tensor.dist_attr = dist_attr
else:
assert dist_attr is None, "{}".format(dist_attr)
# Use the dist attr of serial_tensor to do the initialization
self._dist_attr = self._serial_tensor.dist_attr
self._batch_dim = 0 self._batch_dim = 0
# Reuse the dist_attr setter to initialize _dist_attr
self.dist_attr = dist_attr
self._local_offsets_map = {} self._local_offsets_map = {}
self._local_shard_map = {} self._local_shard_map = {}
self._local_tensor_map = {} self._local_tensor_map = {}
...@@ -195,25 +203,24 @@ class DistributedTensor: ...@@ -195,25 +203,24 @@ class DistributedTensor:
def dist_attr(self): def dist_attr(self):
return self._dist_attr return self._dist_attr
@dist_attr.setter
def dist_attr(self, dist_attr):
self._dist_attr = dist_attr
# TODO: Do we really need to write back dist_attr to serial_tensor?
self._serial_tensor.dist_attr = dist_attr
@property @property
def dist_context(self): def dist_context(self):
return self._dist_context return self._dist_context
@dist_attr.setter # def _init_default_dist_attr(self):
def dist_attr(self, dist_attr): # if self._dist_attr.dims_mapping is None:
if self._dist_attr is None: # if self.serial_tensor.type in __no_shape_var_type__:
self._dist_attr = TensorDistributedAttribute() # tensor_shape = []
self._dist_attr.init(dist_attr) # else:
self._init_default_dist_attr() # tensor_shape = self._serial_tensor.shape
# tensor_dims_mapping = [-1 for _ in range(len(tensor_shape))]
def _init_default_dist_attr(self): # self._dist_attr.dims_mapping = tensor_dims_mapping
if self._dist_attr.dims_mapping is None:
if self.serial_tensor.type in __no_shape_var_type__:
tensor_shape = []
else:
tensor_shape = self._serial_tensor.shape
tensor_dims_mapping = [-1 for _ in range(len(tensor_shape))]
self._dist_attr.dims_mapping = tensor_dims_mapping
def validate_dist_attr(self): def validate_dist_attr(self):
if self.serial_tensor.type in __no_shape_var_type__: if self.serial_tensor.type in __no_shape_var_type__:
...@@ -238,11 +245,11 @@ class DistributedTensor: ...@@ -238,11 +245,11 @@ class DistributedTensor:
rank = paddle.distributed.get_rank() if rank is None else rank rank = paddle.distributed.get_rank() if rank is None else rank
global_sizes = self.serial_tensor.shape global_sizes = self.serial_tensor.shape
dims_mapping = self.dist_attr.dims_mapping dims_mapping = self.dist_attr.dims_mapping
shard_sizes = self.dist_attr.shard_sizes # shard_sizes = self.dist_attr.shard_sizes
processes = self.dist_attr.process_mesh.process_ids processes = self.dist_attr.process_mesh.process_ids
topology = self.dist_attr.process_mesh.shape topology = self.dist_attr.process_mesh.shape
local_sizes = DistributedTensor.get_local_sizes( local_sizes = DistributedTensor.get_local_sizes(
global_sizes, dims_mapping, topology, processes, rank, shard_sizes global_sizes, dims_mapping, topology, processes, rank
) )
return local_sizes return local_sizes
...@@ -255,16 +262,11 @@ class DistributedTensor: ...@@ -255,16 +262,11 @@ class DistributedTensor:
else: else:
global_sizes = self.serial_tensor.shape global_sizes = self.serial_tensor.shape
dims_mapping = self.dist_attr.dims_mapping dims_mapping = self.dist_attr.dims_mapping
shard_sizes = self.dist_attr.shard_sizes # shard_sizes = self.dist_attr.shard_sizes
processes = self.dist_attr.process_mesh.process_ids processes = self.dist_attr.process_mesh.process_ids
topology = self.dist_attr.process_mesh.shape topology = self.dist_attr.process_mesh.shape
local_offsets = DistributedTensor.get_local_offsets( local_offsets = DistributedTensor.get_local_offsets(
global_sizes, global_sizes, dims_mapping, topology, processes, rank
dims_mapping,
topology,
processes,
rank,
shard_sizes,
) )
self._local_offsets_map[rank] = local_offsets self._local_offsets_map[rank] = local_offsets
...@@ -281,16 +283,11 @@ class DistributedTensor: ...@@ -281,16 +283,11 @@ class DistributedTensor:
else: else:
global_sizes = self.serial_tensor.shape global_sizes = self.serial_tensor.shape
dims_mapping = self.dist_attr.dims_mapping dims_mapping = self.dist_attr.dims_mapping
shard_sizes = self.dist_attr.shard_sizes # shard_sizes = self.dist_attr.shard_sizes
processes = self.dist_attr.process_mesh.process_ids processes = self.dist_attr.process_mesh.process_ids
topology = self.dist_attr.process_mesh.shape topology = self.dist_attr.process_mesh.shape
local_shard = DistributedTensor.get_local_shard( local_shard = DistributedTensor.get_local_shard(
global_sizes, global_sizes, dims_mapping, topology, processes, rank
dims_mapping,
topology,
processes,
rank,
shard_sizes,
) )
self._local_shard_map[rank] = local_shard self._local_shard_map[rank] = local_shard
...@@ -390,8 +387,10 @@ class DistributedTensor: ...@@ -390,8 +387,10 @@ class DistributedTensor:
return result return result
def __str__(self): def __str__(self):
str = "{{tensor name: {}, tensor id: {}".format( str = "{{tensor name: {}, tensor id: {}, tensor original_id {}".format(
self.serial_tensor.desc.name(), self.serial_tensor.desc.id() self.serial_tensor.desc.name(),
self.serial_tensor.desc.id(),
self.serial_tensor.desc.original_id(),
) )
# str += ", {}".format(self.dist_attr) # str += ", {}".format(self.dist_attr)
...@@ -411,19 +410,19 @@ class DistributedTensor: ...@@ -411,19 +410,19 @@ class DistributedTensor:
annotated_str = "annotated" annotated_str = "annotated"
else: else:
annotated_str = "non-annotated" annotated_str = "non-annotated"
str += ", dims_mapping ({}): {}".format( str += ", dims_mapping ({}): {} }}".format(
annotated_str, self.dist_attr.dims_mapping annotated_str, self.dist_attr.dims_mapping
) )
if self.dist_attr.is_annotated("shard_mask"): # if self.dist_attr.is_annotated("shard_mask"):
annotated_str = "annotated" # annotated_str = "annotated"
else: # else:
annotated_str = "non-annotated" # annotated_str = "non-annotated"
str += ", shard_mask ({}): {}".format(annotated_str, None) # str += ", shard_mask ({}): {}".format(annotated_str, None)
if self.dist_attr.is_annotated("offload_device"): # if self.dist_attr.is_annotated("offload_device"):
annotated_str = "annotated" # annotated_str = "annotated"
else: # else:
annotated_str = "non-annotated" # annotated_str = "non-annotated"
str += ", offload_device ({}): {} }}".format(annotated_str, None) # str += ", offload_device ({}): {} }}".format(annotated_str, None)
return str return str
...@@ -525,7 +525,8 @@ class Engine: ...@@ -525,7 +525,8 @@ class Engine:
self._labels_spec, self._labels_spec,
) )
# build forward main program # build forward main program
self.program_helper.build_program(mode) with utils.unique_name.guard():
self.program_helper.build_program(mode)
self.concrete_program = self.program_helper.concrete_program self.concrete_program = self.program_helper.concrete_program
serial_main_prog = self.program_helper.main_program serial_main_prog = self.program_helper.main_program
......
...@@ -16,7 +16,7 @@ import abc ...@@ -16,7 +16,7 @@ import abc
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
from ..dist_attribute import OperatorDistributedAttribute from ..dist_attribute import OperatorDistAttr
from ..process_group import new_process_group from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_corresponding_rank, is_optimize_op from ..utils import _get_comm_group, _get_corresponding_rank, is_optimize_op
...@@ -318,7 +318,7 @@ def set_comm_op_dist_attr_for_program( ...@@ -318,7 +318,7 @@ def set_comm_op_dist_attr_for_program(
assert process_mesh is not None assert process_mesh is not None
assert tensor_dist_attr is not None assert tensor_dist_attr is not None
new_op_dist_attr = OperatorDistributedAttribute() new_op_dist_attr = OperatorDistAttr()
new_op_dist_attr.process_mesh = process_mesh new_op_dist_attr.process_mesh = process_mesh
for input_varname in new_op.desc.input_arg_names(): for input_varname in new_op.desc.input_arg_names():
new_op_dist_attr.set_input_dist_attr(input_varname, tensor_dist_attr) new_op_dist_attr.set_input_dist_attr(input_varname, tensor_dist_attr)
...@@ -330,7 +330,7 @@ def set_comm_op_dist_attr_for_program( ...@@ -330,7 +330,7 @@ def set_comm_op_dist_attr_for_program(
def naive_copy_op_dist_attr_for_program(new_op, ref_op, ctx): def naive_copy_op_dist_attr_for_program(new_op, ref_op, ctx):
ref_dist_attr = ctx.get_op_dist_attr_for_program(ref_op) ref_dist_attr = ctx.get_op_dist_attr_for_program(ref_op)
new_op_dist_attr = OperatorDistributedAttribute() new_op_dist_attr = OperatorDistAttr()
new_op_dist_attr.process_mesh = ref_dist_attr.process_mesh new_op_dist_attr.process_mesh = ref_dist_attr.process_mesh
for input_name in ref_op.input_names: for input_name in ref_op.input_names:
...@@ -455,7 +455,7 @@ def sync_and_scale_gradients(dist_ctx, op, dp_group, allreduce_var_names): ...@@ -455,7 +455,7 @@ def sync_and_scale_gradients(dist_ctx, op, dp_group, allreduce_var_names):
) )
# NOTE auxiliary op's dist attr should follow dist_op not dist_tensor # NOTE auxiliary op's dist attr should follow dist_op not dist_tensor
for new_op in added_ops: for new_op in added_ops:
new_op_attr = OperatorDistributedAttribute() new_op_attr = OperatorDistAttr()
new_op_attr.process_mesh = process_mesh new_op_attr.process_mesh = process_mesh
new_op_attr.set_output_dims_mapping(grad_var.name, dims_mapping) new_op_attr.set_output_dims_mapping(grad_var.name, dims_mapping)
new_op_attr.set_input_dims_mapping(grad_var.name, dims_mapping) new_op_attr.set_input_dims_mapping(grad_var.name, dims_mapping)
......
...@@ -76,6 +76,10 @@ class DistributedAssignImpl(DistributedOperatorImpl): ...@@ -76,6 +76,10 @@ class DistributedAssignImpl(DistributedOperatorImpl):
if dim_changed: if dim_changed:
changed = True changed = True
if changed:
op_dist_attr.set_input_dims_mapping(x_name, x_dims_mapping)
op_dist_attr.set_output_dims_mapping(out_name, out_dims_mapping)
return changed return changed
@staticmethod @staticmethod
......
...@@ -18,7 +18,7 @@ from paddle.distributed.auto_parallel.process_group import ( ...@@ -18,7 +18,7 @@ from paddle.distributed.auto_parallel.process_group import (
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
from paddle.fluid import core from paddle.fluid import core
from ..dist_attribute import OperatorDistributedAttribute from ..dist_attribute import OperatorDistAttr
from ..process_group import new_process_group from ..process_group import new_process_group
from ..utils import set_dist_op_desc_original_id, set_var_dist_attr from ..utils import set_dist_op_desc_original_id, set_var_dist_attr
from .common import ( from .common import (
...@@ -126,11 +126,13 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl): ...@@ -126,11 +126,13 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
filter_vars.append(varname) filter_vars.append(varname)
# replicate op in dist program # replicate op in dist program
dist_op_desc = main_block.append_op(type='nop').desc dist_op = main_block.append_op(type='nop')
dist_op_desc = dist_op.desc
dist_op_desc.copy_from(backward_op.desc) dist_op_desc.copy_from(backward_op.desc)
set_dist_op_desc_original_id(dist_op_desc, backward_op.desc, ctx) set_dist_op_desc_original_id(dist_op_desc, backward_op.desc, ctx)
dist_op_desc.set_input('X', filter_vars) dist_op_desc.set_input('X', filter_vars)
dist_op_desc.set_output('Out', filter_vars) dist_op_desc.set_output('Out', filter_vars)
# TODO: should we add a new dist attr for the new op here?
# sync result # sync result
group = new_process_group(world_process_group.ranks) group = new_process_group(world_process_group.ranks)
...@@ -180,7 +182,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl): ...@@ -180,7 +182,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
) )
for op in [cast_op1, allreduce_op, cast_op2]: for op in [cast_op1, allreduce_op, cast_op2]:
new_op_dist_attr = OperatorDistributedAttribute() new_op_dist_attr = OperatorDistAttr()
for varname in op.input_arg_names: for varname in op.input_arg_names:
var_dist_attr = ctx.get_tensor_dist_attr_for_program( var_dist_attr = ctx.get_tensor_dist_attr_for_program(
main_block._var_recursive(varname) main_block._var_recursive(varname)
......
...@@ -20,7 +20,7 @@ from ..cost import ( ...@@ -20,7 +20,7 @@ from ..cost import (
build_comp_desc_from_dist_op, build_comp_desc_from_dist_op,
build_dp_costs, build_dp_costs,
) )
from ..dist_attribute import OperatorDistributedAttribute from ..dist_attribute import OperatorDistAttr
from ..process_group import new_process_group from ..process_group import new_process_group
from ..utils import ( from ..utils import (
_get_comm_group, _get_comm_group,
...@@ -86,7 +86,7 @@ def prim_operator_data_parallel_functor(ctx, src_op): ...@@ -86,7 +86,7 @@ def prim_operator_data_parallel_functor(ctx, src_op):
).dims_mapping ).dims_mapping
dist_attr = ctx.get_op_dist_attr_for_program(src_op) dist_attr = ctx.get_op_dist_attr_for_program(src_op)
process_mesh = dist_attr.process_mesh process_mesh = dist_attr.process_mesh
op_attr = OperatorDistributedAttribute() op_attr = OperatorDistAttr()
op_attr.process_mesh = process_mesh op_attr.process_mesh = process_mesh
op_attr.set_output_dims_mapping(grad_var.name, dims_mapping) op_attr.set_output_dims_mapping(grad_var.name, dims_mapping)
op_attr.set_input_dims_mapping(grad_var.name, dims_mapping) op_attr.set_input_dims_mapping(grad_var.name, dims_mapping)
...@@ -404,6 +404,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -404,6 +404,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
and compatible_dim_mapping != dims_mapping[0] and compatible_dim_mapping != dims_mapping[0]
): ):
dims_mapping[0] = compatible_dim_mapping dims_mapping[0] = compatible_dim_mapping
op_dist_attr.set_input_dims_mapping(arg_name, dims_mapping)
changed = True changed = True
else: else:
if ( if (
...@@ -411,6 +412,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -411,6 +412,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
and compatible_dim_mapping != dims_mapping[1] and compatible_dim_mapping != dims_mapping[1]
): ):
dims_mapping[1] = compatible_dim_mapping dims_mapping[1] = compatible_dim_mapping
op_dist_attr.set_input_dims_mapping(arg_name, dims_mapping)
changed = True changed = True
for arg_name in op_desc.output_arg_names(): for arg_name in op_desc.output_arg_names():
if op_desc.type() == 'fill_any_like': if op_desc.type() == 'fill_any_like':
...@@ -431,6 +433,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -431,6 +433,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
and compatible_dim_mapping != dims_mapping[0] and compatible_dim_mapping != dims_mapping[0]
): ):
dims_mapping[0] = compatible_dim_mapping dims_mapping[0] = compatible_dim_mapping
op_dist_attr.set_output_dims_mapping(arg_name, dims_mapping)
changed = True changed = True
else: else:
if ( if (
...@@ -438,6 +441,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -438,6 +441,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
and compatible_dim_mapping != dims_mapping[1] and compatible_dim_mapping != dims_mapping[1]
): ):
dims_mapping[1] = compatible_dim_mapping dims_mapping[1] = compatible_dim_mapping
op_dist_attr.set_output_dims_mapping(arg_name, dims_mapping)
changed = True changed = True
return changed return changed
...@@ -469,13 +473,15 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -469,13 +473,15 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
) )
# replicate op in dist program # replicate op in dist program
dist_op_desc = main_block.append_op(type='nop').desc dist_op = main_block.append_op(type='nop')
dist_op_desc = dist_op.desc
dist_op_desc.copy_from(src_op.desc) dist_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx) set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
for input_name in src_op.desc.input_names(): for input_name in src_op.desc.input_names():
dist_op_desc.set_input(input_name, kwargs[input_name]) dist_op_desc.set_input(input_name, kwargs[input_name])
for output_name in src_op.desc.output_names(): for output_name in src_op.desc.output_names():
dist_op_desc.set_output(output_name, kwargs[output_name]) dist_op_desc.set_output(output_name, kwargs[output_name])
# TODO: should we add a new dist attr for the new op here?
if ( if (
src_op.has_attr('shape') src_op.has_attr('shape')
...@@ -553,7 +559,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -553,7 +559,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
) )
# set distributed attribute # set distributed attribute
op_attr = OperatorDistributedAttribute() op_attr = OperatorDistAttr()
op_attr.process_mesh = process_mesh op_attr.process_mesh = process_mesh
op_attr.set_output_dims_mapping( op_attr.set_output_dims_mapping(
param.name, dims_mapping param.name, dims_mapping
......
...@@ -29,7 +29,7 @@ from ..cost import ( ...@@ -29,7 +29,7 @@ from ..cost import (
build_comp_desc_from_dist_op, build_comp_desc_from_dist_op,
build_dp_costs, build_dp_costs,
) )
from ..dist_attribute import OperatorDistributedAttribute from ..dist_attribute import OperatorDistAttr
from ..process_group import new_process_group from ..process_group import new_process_group
from ..utils import ( from ..utils import (
_get_comm_group, _get_comm_group,
...@@ -135,7 +135,7 @@ def adopt_lookup_table_v1(ctx, main_block, src_op, Ids_var): ...@@ -135,7 +135,7 @@ def adopt_lookup_table_v1(ctx, main_block, src_op, Ids_var):
intermediate_var_0.name, intermediate_var_0_dist_attr intermediate_var_0.name, intermediate_var_0_dist_attr
) )
new_op_dist_attr = OperatorDistributedAttribute() new_op_dist_attr = OperatorDistAttr()
new_op_dist_attr.process_mesh = Ids_var_dist_attr.process_mesh new_op_dist_attr.process_mesh = Ids_var_dist_attr.process_mesh
new_op_dist_attr.impl_type = "default" new_op_dist_attr.impl_type = "default"
new_op_dist_attr.impl_idx = 0 new_op_dist_attr.impl_idx = 0
...@@ -334,6 +334,11 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -334,6 +334,11 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
if dim_changed: if dim_changed:
changed = True changed = True
if changed:
op_dist_attr.set_input_dims_mapping(ids_name, ids_dims_mapping)
op_dist_attr.set_input_dims_mapping(w_name, w_dims_mapping)
op_dist_attr.set_output_dims_mapping(out_name, out_dims_mapping)
return changed return changed
@staticmethod @staticmethod
...@@ -482,7 +487,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -482,7 +487,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
# set dist op's dist_attr with serial op's dist_attr # set dist op's dist_attr with serial op's dist_attr
# matmulv2 # matmulv2
embedding_op_dist_attr = OperatorDistributedAttribute() embedding_op_dist_attr = OperatorDistAttr()
embedding_op_dist_attr.process_mesh = op_dist_attr.process_mesh embedding_op_dist_attr.process_mesh = op_dist_attr.process_mesh
embedding_op_dist_attr.impl_type = op_dist_attr.impl_type embedding_op_dist_attr.impl_type = op_dist_attr.impl_type
embedding_op_dist_attr.impl_idx = op_dist_attr.impl_idx embedding_op_dist_attr.impl_idx = op_dist_attr.impl_idx
...@@ -505,7 +510,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -505,7 +510,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
ctx.set_op_dist_attr_for_program(c_embedding_op, embedding_op_dist_attr) ctx.set_op_dist_attr_for_program(c_embedding_op, embedding_op_dist_attr)
# allreduce # allreduce
allreduce_op_dist_attr = OperatorDistributedAttribute() allreduce_op_dist_attr = OperatorDistAttr()
allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh
allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
......
...@@ -121,6 +121,10 @@ class DistributedFillConstantBatchSizeLikeImpl0(DistributedOperatorImpl): ...@@ -121,6 +121,10 @@ class DistributedFillConstantBatchSizeLikeImpl0(DistributedOperatorImpl):
if dim_changed: if dim_changed:
changed = True changed = True
if changed:
op_dist_attr.set_input_dims_mapping(x_name, x_dims_mapping)
op_dist_attr.set_output_dims_mapping(out_name, out_dims_mapping)
return changed return changed
@staticmethod @staticmethod
......
...@@ -143,6 +143,12 @@ class DistributedFusedAttentionImpl(DistributedOperatorImpl): ...@@ -143,6 +143,12 @@ class DistributedFusedAttentionImpl(DistributedOperatorImpl):
) )
if dim_changed: if dim_changed:
changed = True changed = True
op_dist_attr.set_output_dims_mapping(
out_name, out_dims_mapping
)
if changed:
op_dist_attr.set_input_dims_mapping(x_name, x_dims_mapping)
return changed return changed
......
...@@ -135,6 +135,12 @@ class DistributedFusedFeedForwardImpl(DistributedOperatorImpl): ...@@ -135,6 +135,12 @@ class DistributedFusedFeedForwardImpl(DistributedOperatorImpl):
) )
if dim_changed: if dim_changed:
changed = True changed = True
op_dist_attr.set_output_dims_mapping(
out_name, out_dims_mapping
)
if changed:
op_dist_attr.set_input_dims_mapping(x_name, x_dims_mapping)
return changed return changed
......
...@@ -35,7 +35,7 @@ from ..cost import ( ...@@ -35,7 +35,7 @@ from ..cost import (
build_comp_desc_from_dist_op, build_comp_desc_from_dist_op,
build_dp_costs, build_dp_costs,
) )
from ..dist_attribute import OperatorDistributedAttribute from ..dist_attribute import OperatorDistAttr
from ..process_group import new_process_group from ..process_group import new_process_group
from ..utils import ( from ..utils import (
_get_comm_group, _get_comm_group,
...@@ -74,15 +74,28 @@ def trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping): ...@@ -74,15 +74,28 @@ def trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping):
def copy_op_with_new_input_output(ctx, block, src_op, **kwargs): def copy_op_with_new_input_output(ctx, block, src_op, **kwargs):
dist_op_desc = block.append_op(type='nop').desc pass
src_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
dist_attr = copy.deepcopy(src_dist_attr)
dist_op = block.append_op(type='nop')
dist_op_desc = dist_op.desc
dist_op_desc.copy_from(src_op.desc) dist_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx) set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
for input_name in src_op.desc.input_names(): for input_name in src_op.desc.input_names():
assert input_name in kwargs assert input_name in kwargs
dist_op_desc.set_input(input_name, kwargs[input_name]) dist_op_desc.set_input(input_name, kwargs[input_name])
dist_attr.rename_input(
src_op.desc.input(input_name)[0], kwargs[input_name][0]
)
for output_name in src_op.desc.output_names(): for output_name in src_op.desc.output_names():
assert input_name in kwargs assert output_name in kwargs
dist_op_desc.set_output(output_name, kwargs[output_name]) dist_op_desc.set_output(output_name, kwargs[output_name])
dist_attr.rename_output(
src_op.desc.output(output_name)[0], kwargs[output_name][0]
)
# TODO: this call leads to a deepcopy when we init the dist op
ctx.set_op_dist_attr_for_program(dist_op, dist_attr)
return dist_op_desc return dist_op_desc
...@@ -207,6 +220,11 @@ def _update_dims_mapping_for_matmul(dist_op): ...@@ -207,6 +220,11 @@ def _update_dims_mapping_for_matmul(dist_op):
assert len(y_dims_mapping) == y_dims_mapping_len assert len(y_dims_mapping) == y_dims_mapping_len
assert len(out_dims_mapping) == out_dims_mapping_len assert len(out_dims_mapping) == out_dims_mapping_len
if changed:
op_dist_attr.set_input_dims_mapping(x_name, x_dims_mapping)
op_dist_attr.set_input_dims_mapping(y_name, y_dims_mapping)
op_dist_attr.set_output_dims_mapping(out_name, out_dims_mapping)
return changed return changed
...@@ -880,7 +898,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): ...@@ -880,7 +898,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
# set dist op's dist_attr with serial op's dist_attr # set dist op's dist_attr with serial op's dist_attr
# c_identity # c_identity
identity_op_dist_attr = OperatorDistributedAttribute() identity_op_dist_attr = OperatorDistAttr()
identity_op_dist_attr.process_mesh = op_dist_attr.process_mesh identity_op_dist_attr.process_mesh = op_dist_attr.process_mesh
identity_op_dist_attr.impl_type = op_dist_attr.impl_type identity_op_dist_attr.impl_type = op_dist_attr.impl_type
identity_op_dist_attr.impl_idx = op_dist_attr.impl_idx identity_op_dist_attr.impl_idx = op_dist_attr.impl_idx
...@@ -902,7 +920,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): ...@@ -902,7 +920,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr) ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr)
# matmul # matmul
matmul_op_dist_attr = OperatorDistributedAttribute() matmul_op_dist_attr = OperatorDistAttr()
matmul_op_dist_attr.process_mesh = op_dist_attr.process_mesh matmul_op_dist_attr.process_mesh = op_dist_attr.process_mesh
matmul_op_dist_attr.impl_type = op_dist_attr.impl_type matmul_op_dist_attr.impl_type = op_dist_attr.impl_type
matmul_op_dist_attr.impl_idx = op_dist_attr.impl_idx matmul_op_dist_attr.impl_idx = op_dist_attr.impl_idx
...@@ -1253,7 +1271,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): ...@@ -1253,7 +1271,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
# set dist op's dist_attr with serial op's dist_attr # set dist op's dist_attr with serial op's dist_attr
# matmul # matmul
matmul_op_dist_attr = OperatorDistributedAttribute() matmul_op_dist_attr = OperatorDistAttr()
matmul_op_dist_attr.process_mesh = op_dist_attr.process_mesh matmul_op_dist_attr.process_mesh = op_dist_attr.process_mesh
matmul_op_dist_attr.impl_type = op_dist_attr.impl_type matmul_op_dist_attr.impl_type = op_dist_attr.impl_type
matmul_op_dist_attr.impl_idx = op_dist_attr.impl_idx matmul_op_dist_attr.impl_idx = op_dist_attr.impl_idx
...@@ -1276,7 +1294,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): ...@@ -1276,7 +1294,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
ctx.set_op_dist_attr_for_program(matmul_op, matmul_op_dist_attr) ctx.set_op_dist_attr_for_program(matmul_op, matmul_op_dist_attr)
# allreduce # allreduce
allreduce_op_dist_attr = OperatorDistributedAttribute() allreduce_op_dist_attr = OperatorDistAttr()
allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh
allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
...@@ -1783,7 +1801,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): ...@@ -1783,7 +1801,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
# set dist op's dist_attr with serial op's dist_attr # set dist op's dist_attr with serial op's dist_attr
# c_identity # c_identity
identity_op_dist_attr = OperatorDistributedAttribute() identity_op_dist_attr = OperatorDistAttr()
identity_op_dist_attr.process_mesh = op_dist_attr.process_mesh identity_op_dist_attr.process_mesh = op_dist_attr.process_mesh
identity_op_dist_attr.impl_type = op_dist_attr.impl_type identity_op_dist_attr.impl_type = op_dist_attr.impl_type
identity_op_dist_attr.impl_idx = op_dist_attr.impl_idx identity_op_dist_attr.impl_idx = op_dist_attr.impl_idx
...@@ -1804,7 +1822,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): ...@@ -1804,7 +1822,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr) ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr)
# matmulv2 # matmulv2
matmulv2_op_dist_attr = OperatorDistributedAttribute() matmulv2_op_dist_attr = OperatorDistAttr()
matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh
matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type
matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx
...@@ -2152,7 +2170,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): ...@@ -2152,7 +2170,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
# set dist op's dist_attr with serial op's dist_attr # set dist op's dist_attr with serial op's dist_attr
# matmulv2 # matmulv2
matmulv2_op_dist_attr = OperatorDistributedAttribute() matmulv2_op_dist_attr = OperatorDistAttr()
matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh
matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type
matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx
...@@ -2175,7 +2193,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): ...@@ -2175,7 +2193,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
ctx.set_op_dist_attr_for_program(matmul_v2_op, matmulv2_op_dist_attr) ctx.set_op_dist_attr_for_program(matmul_v2_op, matmulv2_op_dist_attr)
# allreduce # allreduce
allreduce_op_dist_attr = OperatorDistributedAttribute() allreduce_op_dist_attr = OperatorDistAttr()
allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh
allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
...@@ -2688,7 +2706,7 @@ class DistributedMulImpl0(DistributedOperatorImpl): ...@@ -2688,7 +2706,7 @@ class DistributedMulImpl0(DistributedOperatorImpl):
# set dist op's dist_attr with serial op's dist_attr # set dist op's dist_attr with serial op's dist_attr
# c_identity # c_identity
identity_op_dist_attr = OperatorDistributedAttribute() identity_op_dist_attr = OperatorDistAttr()
identity_op_dist_attr.process_mesh = op_dist_attr.process_mesh identity_op_dist_attr.process_mesh = op_dist_attr.process_mesh
identity_op_dist_attr.impl_type = op_dist_attr.impl_type identity_op_dist_attr.impl_type = op_dist_attr.impl_type
identity_op_dist_attr.impl_idx = op_dist_attr.impl_idx identity_op_dist_attr.impl_idx = op_dist_attr.impl_idx
...@@ -2709,7 +2727,7 @@ class DistributedMulImpl0(DistributedOperatorImpl): ...@@ -2709,7 +2727,7 @@ class DistributedMulImpl0(DistributedOperatorImpl):
ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr) ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr)
# matmulv2 # matmulv2
matmulv2_op_dist_attr = OperatorDistributedAttribute() matmulv2_op_dist_attr = OperatorDistAttr()
matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh
matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type
matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx
...@@ -2854,7 +2872,6 @@ class DistributedMulImpl1(DistributedOperatorImpl): ...@@ -2854,7 +2872,6 @@ class DistributedMulImpl1(DistributedOperatorImpl):
parallel_axis=parallel_axis, parallel_axis=parallel_axis,
) )
# print("dist_matmul.py dist_op: ", dist_op)
comm_op_cost_list = build_comm_costs_from_descs( comm_op_cost_list = build_comm_costs_from_descs(
AllreduceSumOpCost, AllreduceSumOpCost,
ctx, ctx,
...@@ -3067,7 +3084,7 @@ class DistributedMulImpl1(DistributedOperatorImpl): ...@@ -3067,7 +3084,7 @@ class DistributedMulImpl1(DistributedOperatorImpl):
# set dist op's dist_attr with serial op's dist_attr # set dist op's dist_attr with serial op's dist_attr
# matmulv2 # matmulv2
matmulv2_op_dist_attr = OperatorDistributedAttribute() matmulv2_op_dist_attr = OperatorDistAttr()
matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh
matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type
matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx
...@@ -3090,7 +3107,7 @@ class DistributedMulImpl1(DistributedOperatorImpl): ...@@ -3090,7 +3107,7 @@ class DistributedMulImpl1(DistributedOperatorImpl):
ctx.set_op_dist_attr_for_program(mul_op, matmulv2_op_dist_attr) ctx.set_op_dist_attr_for_program(mul_op, matmulv2_op_dist_attr)
# allreduce # allreduce
allreduce_op_dist_attr = OperatorDistributedAttribute() allreduce_op_dist_attr = OperatorDistAttr()
allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh
allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
......
...@@ -18,10 +18,7 @@ from paddle.fluid import core ...@@ -18,10 +18,7 @@ from paddle.fluid import core
from paddle.fluid.data_feeder import check_dtype, check_variable_and_dtype from paddle.fluid.data_feeder import check_dtype, check_variable_and_dtype
from paddle.fluid.framework import Operator from paddle.fluid.framework import Operator
from ..dist_attribute import ( from ..dist_attribute import OperatorDistAttr, TensorDistAttr
OperatorDistributedAttribute,
TensorDistributedAttribute,
)
from ..process_group import new_process_group from ..process_group import new_process_group
from ..utils import ( from ..utils import (
_get_comm_group, _get_comm_group,
...@@ -135,6 +132,7 @@ class DistributedPNormImpl0(DistributedOperatorImpl): ...@@ -135,6 +132,7 @@ class DistributedPNormImpl0(DistributedOperatorImpl):
and compatible_dim_mapping != dims_mapping[0] and compatible_dim_mapping != dims_mapping[0]
): ):
dims_mapping[0] = compatible_dim_mapping dims_mapping[0] = compatible_dim_mapping
op_dist_attr.set_input_dims_mapping(arg_name, dims_mapping)
changed = True changed = True
if axis == 0 and not keepdim: if axis == 0 and not keepdim:
...@@ -142,6 +140,7 @@ class DistributedPNormImpl0(DistributedOperatorImpl): ...@@ -142,6 +140,7 @@ class DistributedPNormImpl0(DistributedOperatorImpl):
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if len(dims_mapping) >= 1 and dims_mapping[0] != -1: if len(dims_mapping) >= 1 and dims_mapping[0] != -1:
dims_mapping[0] = -1 dims_mapping[0] = -1
op_dist_attr.set_output_dims_mapping(arg_name, dims_mapping)
changed = True changed = True
else: else:
for arg_name in op_desc.output_arg_names(): for arg_name in op_desc.output_arg_names():
...@@ -151,6 +150,7 @@ class DistributedPNormImpl0(DistributedOperatorImpl): ...@@ -151,6 +150,7 @@ class DistributedPNormImpl0(DistributedOperatorImpl):
and compatible_dim_mapping != dims_mapping[0] and compatible_dim_mapping != dims_mapping[0]
): ):
dims_mapping[0] = compatible_dim_mapping dims_mapping[0] = compatible_dim_mapping
op_dist_attr.set_output_dims_mapping(arg_name, dims_mapping)
changed = True changed = True
return changed return changed
...@@ -218,7 +218,7 @@ class DistributedPNormImpl0(DistributedOperatorImpl): ...@@ -218,7 +218,7 @@ class DistributedPNormImpl0(DistributedOperatorImpl):
stop_gradient=X_var.stop_gradient, stop_gradient=X_var.stop_gradient,
) )
# set allgather_out tensor dist_attr # set allgather_out tensor dist_attr
allgather_out_dist_attr = TensorDistributedAttribute() allgather_out_dist_attr = TensorDistAttr()
allgather_out_dist_attr.process_mesh = op_dist_attr.process_mesh allgather_out_dist_attr.process_mesh = op_dist_attr.process_mesh
allgather_out_dist_attr.dims_mapping = [ allgather_out_dist_attr.dims_mapping = [
-1 for i in range(len(allgather_out.shape)) -1 for i in range(len(allgather_out.shape))
...@@ -238,7 +238,7 @@ class DistributedPNormImpl0(DistributedOperatorImpl): ...@@ -238,7 +238,7 @@ class DistributedPNormImpl0(DistributedOperatorImpl):
}, },
) )
# set c_allgather op dist_attr # set c_allgather op dist_attr
allgather_op_dist_attr = OperatorDistributedAttribute() allgather_op_dist_attr = OperatorDistAttr()
allgather_op_dist_attr.process_mesh = op_dist_attr.process_mesh allgather_op_dist_attr.process_mesh = op_dist_attr.process_mesh
allgather_op_dist_attr.set_input_dims_mapping( allgather_op_dist_attr.set_input_dims_mapping(
X_var.name, in_dims_mapping X_var.name, in_dims_mapping
...@@ -252,7 +252,8 @@ class DistributedPNormImpl0(DistributedOperatorImpl): ...@@ -252,7 +252,8 @@ class DistributedPNormImpl0(DistributedOperatorImpl):
# rename input # rename input
kwargs['X'] = [allgather_out.name] kwargs['X'] = [allgather_out.name]
# replicate op in dist program # replicate op in dist program
dist_op_desc = main_block.append_op(type='nop').desc dist_op = main_block.append_op(type='nop')
dist_op_desc = dist_op.desc
dist_op_desc.copy_from(src_op.desc) dist_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx) set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
for input_name in src_op.desc.input_names(): for input_name in src_op.desc.input_names():
...@@ -263,7 +264,10 @@ class DistributedPNormImpl0(DistributedOperatorImpl): ...@@ -263,7 +264,10 @@ class DistributedPNormImpl0(DistributedOperatorImpl):
op_dist_attr.set_input_dims_mapping( op_dist_attr.set_input_dims_mapping(
allgather_out.name, allgather_out_dist_attr.dims_mapping allgather_out.name, allgather_out_dist_attr.dims_mapping
) )
# Remove the unrelated dist attr
op_dist_attr.del_input_dist_attr(X_var.name)
ctx.set_op_dist_attr_for_program(pnorm_op, op_dist_attr) ctx.set_op_dist_attr_for_program(pnorm_op, op_dist_attr)
# TODO: should we add a new dist attr for the new op here?
@staticmethod @staticmethod
def backward(ctx, *args, **kwargs): def backward(ctx, *args, **kwargs):
...@@ -312,7 +316,8 @@ class DistributedPNormImpl0(DistributedOperatorImpl): ...@@ -312,7 +316,8 @@ class DistributedPNormImpl0(DistributedOperatorImpl):
new_X_var_dist_attr = ctx.get_tensor_dist_attr_for_program(new_X_var) new_X_var_dist_attr = ctx.get_tensor_dist_attr_for_program(new_X_var)
ctx.set_tensor_dist_attr_for_program(new_X_grad, new_X_var_dist_attr) ctx.set_tensor_dist_attr_for_program(new_X_grad, new_X_var_dist_attr)
# replicate op in dist program with new kwargs # replicate op in dist program with new kwargs
dist_op_desc = main_block.append_op(type='nop').desc dist_op = main_block.append_op(type='nop')
dist_op_desc = dist_op.desc
dist_op_desc.copy_from(backward_op.desc) dist_op_desc.copy_from(backward_op.desc)
# Refer to the related dist op # Refer to the related dist op
set_dist_op_desc_original_id(dist_op_desc, backward_op.desc, ctx) set_dist_op_desc_original_id(dist_op_desc, backward_op.desc, ctx)
...@@ -324,10 +329,19 @@ class DistributedPNormImpl0(DistributedOperatorImpl): ...@@ -324,10 +329,19 @@ class DistributedPNormImpl0(DistributedOperatorImpl):
op_dist_attr.set_input_dims_mapping( op_dist_attr.set_input_dims_mapping(
new_X_var.name, new_X_var_dist_attr.dims_mapping new_X_var.name, new_X_var_dist_attr.dims_mapping
) )
# Store X_grad_var dims_mapping for later use
X_grad_var_dims_mapping = op_dist_attr.get_output_dims_mapping(
X_grad_var.name
)
# Remove the unrelated dist attr
op_dist_attr.del_input_dist_attr(X_var.name)
op_dist_attr.set_output_dims_mapping( op_dist_attr.set_output_dims_mapping(
new_X_grad.name, new_X_var_dist_attr.dims_mapping new_X_grad.name, new_X_var_dist_attr.dims_mapping
) )
# Remove the unrelated dist attr
op_dist_attr.del_output_dist_attr(X_grad_var.name)
ctx.set_op_dist_attr_for_program(p_norm_grad_op, op_dist_attr) ctx.set_op_dist_attr_for_program(p_norm_grad_op, op_dist_attr)
# TODO: should we add a new dist attr for the new op here?
# 2. insert slice op # 2. insert slice op
process_mesh_shape = op_dist_attr.process_mesh.shape process_mesh_shape = op_dist_attr.process_mesh.shape
...@@ -364,10 +378,7 @@ class DistributedPNormImpl0(DistributedOperatorImpl): ...@@ -364,10 +378,7 @@ class DistributedPNormImpl0(DistributedOperatorImpl):
outputs={'Out': [X_grad_var]}, outputs={'Out': [X_grad_var]},
attrs=attrs, attrs=attrs,
) )
X_grad_var_dims_mapping = op_dist_attr.get_output_dims_mapping( slice_op_dist_attr = OperatorDistAttr()
X_grad_var.name
)
slice_op_dist_attr = OperatorDistributedAttribute()
slice_op_dist_attr.process_mesh = op_dist_attr.process_mesh slice_op_dist_attr.process_mesh = op_dist_attr.process_mesh
slice_op_dist_attr.set_input_dims_mapping( slice_op_dist_attr.set_input_dims_mapping(
new_X_grad.name, new_X_var_dist_attr.dims_mapping new_X_grad.name, new_X_var_dist_attr.dims_mapping
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
from ..dist_attribute import OperatorDistributedAttribute from ..dist_attribute import OperatorDistAttr
from ..process_group import new_process_group from ..process_group import new_process_group
from ..utils import set_dist_op_desc_original_id from ..utils import set_dist_op_desc_original_id
from .common import ( from .common import (
...@@ -104,13 +104,15 @@ class DistributedReduceSumPrimtiveImpl0(DistributedOperatorImpl): ...@@ -104,13 +104,15 @@ class DistributedReduceSumPrimtiveImpl0(DistributedOperatorImpl):
) )
# replicate op in dist program # replicate op in dist program
dist_op_desc = main_block.append_op(type='nop').desc dist_op = main_block.append_op(type='nop')
dist_op_desc = dist_op.desc
dist_op_desc.copy_from(src_op.desc) dist_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx) set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
for input_name in src_op.desc.input_names(): for input_name in src_op.desc.input_names():
dist_op_desc.set_input(input_name, kwargs[input_name]) dist_op_desc.set_input(input_name, kwargs[input_name])
for output_name in src_op.desc.output_names(): for output_name in src_op.desc.output_names():
dist_op_desc.set_output(output_name, kwargs[output_name]) dist_op_desc.set_output(output_name, kwargs[output_name])
# TODO: should we add a new dist attr for the new op here?
# batch dimension synchronization # batch dimension synchronization
var_name = src_op.output_arg_names[0] var_name = src_op.output_arg_names[0]
...@@ -130,7 +132,7 @@ class DistributedReduceSumPrimtiveImpl0(DistributedOperatorImpl): ...@@ -130,7 +132,7 @@ class DistributedReduceSumPrimtiveImpl0(DistributedOperatorImpl):
var = main_block._var_recursive(var_name) var = main_block._var_recursive(var_name)
tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(var) tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(var)
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
new_op_attr = OperatorDistributedAttribute() new_op_attr = OperatorDistAttr()
new_op_attr.process_mesh = op_dist_attr.process_mesh new_op_attr.process_mesh = op_dist_attr.process_mesh
new_op_attr.set_output_dims_mapping( new_op_attr.set_output_dims_mapping(
var.name, tensor_dist_attr.dims_mapping var.name, tensor_dist_attr.dims_mapping
......
...@@ -218,6 +218,13 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): ...@@ -218,6 +218,13 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
for i in range(len(x_dims_mapping)): for i in range(len(x_dims_mapping)):
x_shape_dims_mapping[i + 1] = x_dims_mapping[i] x_shape_dims_mapping[i + 1] = x_dims_mapping[i]
if changed:
op_dist_attr.set_input_dims_mapping(x_name, x_dims_mapping)
op_dist_attr.set_output_dims_mapping(out_name, out_dims_mapping)
op_dist_attr.set_output_dims_mapping(
x_shape_name, x_shape_dims_mapping
)
return changed return changed
@staticmethod @staticmethod
...@@ -277,7 +284,8 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): ...@@ -277,7 +284,8 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
) )
# create op # create op
new_op_desc = main_block.append_op(type='nop').desc new_op = main_block.append_op(type='nop')
new_op_desc = new_op.desc
new_op_desc.copy_from(src_op.desc) new_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(new_op_desc, src_op.desc, ctx) set_dist_op_desc_original_id(new_op_desc, src_op.desc, ctx)
new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list) new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list)
...@@ -286,6 +294,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): ...@@ -286,6 +294,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
new_op_desc.set_output('XShape', [XShape_var.name]) new_op_desc.set_output('XShape', [XShape_var.name])
new_op_desc.set_output('Out', [Out_var.name]) new_op_desc.set_output('Out', [Out_var.name])
new_op_desc._set_attr('shape', shape_list) new_op_desc._set_attr('shape', shape_list)
# TODO: should we add a new dist attr for the new op here?
@staticmethod @staticmethod
def backward(ctx, *args, **kwargs): def backward(ctx, *args, **kwargs):
...@@ -469,6 +478,13 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): ...@@ -469,6 +478,13 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
for i in range(len(x_dims_mapping)): for i in range(len(x_dims_mapping)):
x_shape_dims_mapping[i + 1] = x_dims_mapping[i] x_shape_dims_mapping[i + 1] = x_dims_mapping[i]
if changed:
op_dist_attr.set_input_dims_mapping(x_name, x_dims_mapping)
op_dist_attr.set_output_dims_mapping(out_name, out_dims_mapping)
op_dist_attr.set_output_dims_mapping(
x_shape_name, x_shape_dims_mapping
)
return changed return changed
@staticmethod @staticmethod
...@@ -528,7 +544,8 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): ...@@ -528,7 +544,8 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
) )
# create op # create op
new_op_desc = main_block.append_op(type='nop').desc new_op = main_block.append_op(type='nop')
new_op_desc = new_op.desc
new_op_desc.copy_from(src_op.desc) new_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(new_op_desc, src_op.desc, ctx) set_dist_op_desc_original_id(new_op_desc, src_op.desc, ctx)
new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list) new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list)
...@@ -537,6 +554,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): ...@@ -537,6 +554,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
new_op_desc.set_output('XShape', [XShape_var.name]) new_op_desc.set_output('XShape', [XShape_var.name])
new_op_desc.set_output('Out', [Out_var.name]) new_op_desc.set_output('Out', [Out_var.name])
new_op_desc._set_attr('shape', shape_list) new_op_desc._set_attr('shape', shape_list)
# TODO: should we add a new dist attr for the new op here?
@staticmethod @staticmethod
def backward(ctx, *args, **kwargs): def backward(ctx, *args, **kwargs):
...@@ -714,6 +732,13 @@ class DistributedReshapeImpl2(DistributedOperatorImpl): ...@@ -714,6 +732,13 @@ class DistributedReshapeImpl2(DistributedOperatorImpl):
for i in range(len(out_dims_mapping)): for i in range(len(out_dims_mapping)):
x_shape_dims_mapping[i + 1] = out_dims_mapping[i] x_shape_dims_mapping[i + 1] = out_dims_mapping[i]
if changed:
op_dist_attr.set_input_dims_mapping(x_name, x_dims_mapping)
op_dist_attr.set_output_dims_mapping(out_name, out_dims_mapping)
op_dist_attr.set_output_dims_mapping(
x_shape_name, x_shape_dims_mapping
)
return changed return changed
@staticmethod @staticmethod
...@@ -772,7 +797,8 @@ class DistributedReshapeImpl2(DistributedOperatorImpl): ...@@ -772,7 +797,8 @@ class DistributedReshapeImpl2(DistributedOperatorImpl):
) )
# create op # create op
new_op_desc = main_block.append_op(type='nop').desc new_op = main_block.append_op(type='nop')
new_op_desc = new_op.desc
new_op_desc.copy_from(src_op.desc) new_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(new_op_desc, src_op.desc, ctx) set_dist_op_desc_original_id(new_op_desc, src_op.desc, ctx)
new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list) new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list)
...@@ -781,6 +807,7 @@ class DistributedReshapeImpl2(DistributedOperatorImpl): ...@@ -781,6 +807,7 @@ class DistributedReshapeImpl2(DistributedOperatorImpl):
new_op_desc.set_output('XShape', [XShape_var.name]) new_op_desc.set_output('XShape', [XShape_var.name])
new_op_desc.set_output('Out', [Out_var.name]) new_op_desc.set_output('Out', [Out_var.name])
new_op_desc._set_attr('shape', shape_list) new_op_desc._set_attr('shape', shape_list)
# TODO: should we add a new dist attr for the new op here?
@staticmethod @staticmethod
def backward(ctx, *args, **kwargs): def backward(ctx, *args, **kwargs):
......
...@@ -161,6 +161,10 @@ class DistributedSliceImpl(DistributedOperatorImpl): ...@@ -161,6 +161,10 @@ class DistributedSliceImpl(DistributedOperatorImpl):
out_dims_mapping[i] = compatible_dim_mapping out_dims_mapping[i] = compatible_dim_mapping
changed = True changed = True
if changed:
op_dist_attr.set_input_dims_mapping(in_name, in_dims_mapping)
op_dist_attr.set_output_dims_mapping(out_name, out_dims_mapping)
return changed return changed
@staticmethod @staticmethod
......
...@@ -178,6 +178,10 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl): ...@@ -178,6 +178,10 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
if dim_changed: if dim_changed:
changed = True changed = True
if changed:
op_dist_attr.set_input_dims_mapping(x_name, x_dims_mapping)
op_dist_attr.set_output_dims_mapping(out_name, out_dims_mapping)
return changed return changed
@staticmethod @staticmethod
......
...@@ -95,7 +95,12 @@ class DistributedSplitImpl(DistributedOperatorImpl): ...@@ -95,7 +95,12 @@ class DistributedSplitImpl(DistributedOperatorImpl):
) )
if dim_changed: if dim_changed:
changed = True changed = True
op_dist_attr.set_output_dims_mapping(
out_name, out_dims_mapping
)
if changed:
op_dist_attr.set_input_dims_mapping(x_name, x_dims_mapping)
return changed return changed
def is_auto_compatible(self, dist_op): def is_auto_compatible(self, dist_op):
......
...@@ -124,6 +124,13 @@ class DistributedTranspose2Impl(DistributedOperatorImpl): ...@@ -124,6 +124,13 @@ class DistributedTranspose2Impl(DistributedOperatorImpl):
for i in range(len(x_dims_mapping)): for i in range(len(x_dims_mapping)):
x_shape_dims_mapping[i + 1] = x_dims_mapping[i] x_shape_dims_mapping[i + 1] = x_dims_mapping[i]
if changed:
op_dist_attr.set_input_dims_mapping(x_name, x_dims_mapping)
op_dist_attr.set_output_dims_mapping(out_name, out_dims_mapping)
op_dist_attr.set_output_dims_mapping(
x_shape_name, x_shape_dims_mapping
)
return changed return changed
def calc_cost(self, op_role, dist_op, ctx, cluster): def calc_cost(self, op_role, dist_op, ctx, cluster):
......
...@@ -159,11 +159,13 @@ class DistributedUpdateLossScalingImpl(DistributedOperatorImpl): ...@@ -159,11 +159,13 @@ class DistributedUpdateLossScalingImpl(DistributedOperatorImpl):
filter_vars.append(varname) filter_vars.append(varname)
# replicate op in dist program # replicate op in dist program
dist_op_desc = main_block.append_op(type='nop').desc dist_op = main_block.append_op(type='nop')
dist_op_desc = dist_op.desc
dist_op_desc.copy_from(backward_op.desc) dist_op_desc.copy_from(backward_op.desc)
set_dist_op_desc_original_id(dist_op_desc, backward_op.desc, ctx) set_dist_op_desc_original_id(dist_op_desc, backward_op.desc, ctx)
dist_op_desc.set_input('X', filter_vars) dist_op_desc.set_input('X', filter_vars)
dist_op_desc.set_output('Out', filter_vars) dist_op_desc.set_output('Out', filter_vars)
# TODO: should we add a new dist attr for the new op here?
register_distributed_operator_impl( register_distributed_operator_impl(
......
...@@ -22,7 +22,7 @@ from paddle.distributed.auto_parallel.operators.common import ( ...@@ -22,7 +22,7 @@ from paddle.distributed.auto_parallel.operators.common import (
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid.framework import Parameter, Program from paddle.fluid.framework import Parameter, Program
from .dist_attribute import OperatorDistributedAttribute from .dist_attribute import OperatorDistAttr
from .operators.common import BACKWARD_ONLY_DIST_OPS from .operators.common import BACKWARD_ONLY_DIST_OPS
from .utils import ( from .utils import (
__no_shape_var_type__, __no_shape_var_type__,
...@@ -165,7 +165,7 @@ class Partitioner: ...@@ -165,7 +165,7 @@ class Partitioner:
output_var_attr = ( output_var_attr = (
self._dist_context.get_tensor_dist_attr_for_program(output_var) self._dist_context.get_tensor_dist_attr_for_program(output_var)
) )
op_attr = OperatorDistributedAttribute() op_attr = OperatorDistAttr()
op_attr.process_mesh = output_var_attr.process_mesh op_attr.process_mesh = output_var_attr.process_mesh
op_attr.set_output_dims_mapping( op_attr.set_output_dims_mapping(
output_var.name, output_var_attr.dims_mapping output_var.name, output_var_attr.dims_mapping
...@@ -407,8 +407,13 @@ def _get_dist_shape(var, dist_attr): ...@@ -407,8 +407,13 @@ def _get_dist_shape(var, dist_attr):
else: else:
assert ( assert (
var_shape[idx] % mesh[mapping[idx]] == 0 var_shape[idx] % mesh[mapping[idx]] == 0
), "un-event partition: var_shape[idx]=[{}], mesh[{}]".format( ), "un-event partition: var_shape[idx]=[{}], mesh[{}], {}, {}, {}, {}".format(
var_shape[idx], mesh[mapping[idx]] var_shape[idx],
mesh[mapping[idx]],
var.name,
var_shape,
mesh,
mapping,
) )
new_shape.append(var_shape[idx] // mesh[mapping[idx]]) new_shape.append(var_shape[idx] // mesh[mapping[idx]])
......
...@@ -25,10 +25,7 @@ import paddle ...@@ -25,10 +25,7 @@ import paddle
from paddle.distributed.fleet import auto from paddle.distributed.fleet import auto
from .cost_model import estimate_cost from .cost_model import estimate_cost
from .dist_attribute import ( from .dist_attribute import OperatorDistAttr, TensorDistAttr
OperatorDistributedAttribute,
TensorDistributedAttribute,
)
from .dist_context import DistributedContext, DistributedOperatorContext from .dist_context import DistributedContext, DistributedOperatorContext
from .dist_op import DistributedOperator from .dist_op import DistributedOperator
from .operators.common import ( from .operators.common import (
...@@ -239,7 +236,7 @@ class PlanSpace: ...@@ -239,7 +236,7 @@ class PlanSpace:
) )
) )
for composed_dims_mapping in composed_dims_mapping_list: for composed_dims_mapping in composed_dims_mapping_list:
op_dist_attr = OperatorDistributedAttribute() op_dist_attr = OperatorDistAttr()
op_dist_attr.process_mesh = process_mesh op_dist_attr.process_mesh = process_mesh
var_names = list(dims_mapping_dict.keys()) var_names = list(dims_mapping_dict.keys())
...@@ -299,7 +296,6 @@ class PlanSpace: ...@@ -299,7 +296,6 @@ class PlanSpace:
dist_op.dist_attr.impl_idx = 0 dist_op.dist_attr.impl_idx = 0
op_valid_dist_attrs.append(dist_op.dist_attr) op_valid_dist_attrs.append(dist_op.dist_attr)
continue continue
# if op has distributed implements, find all valid dist attr of this op # if op has distributed implements, find all valid dist attr of this op
impls = dist_op_impl_container.impls impls = dist_op_impl_container.impls
for idx, impl in enumerate(impls): for idx, impl in enumerate(impls):
...@@ -313,17 +309,18 @@ class PlanSpace: ...@@ -313,17 +309,18 @@ class PlanSpace:
# set default dist attr for some special ops whose distributed attributes can not be enumerated # set default dist attr for some special ops whose distributed attributes can not be enumerated
if not op_valid_dist_attrs: if not op_valid_dist_attrs:
op_dist_attr = OperatorDistributedAttribute() op_dist_attr = OperatorDistAttr()
op_dist_attr.process_mesh = process_mesh op_dist_attr.process_mesh = process_mesh
dist_op = DistributedOperator(op, op_dist_attr)
for var_name in op.input_arg_names: for var_name in op.input_arg_names:
op_dist_attr.set_input_dims_mapping( op_dist_attr.set_input_dims_mapping(
vars[var_name], [-1 for i in vars[var_name].shape] vars[var_name].name, [-1 for i in vars[var_name].shape]
) )
for var_name in op.output_arg_names: for var_name in op.output_arg_names:
op_dist_attr.set_output_dims_mapping( op_dist_attr.set_output_dims_mapping(
vars[var_name], [-1 for i in vars[var_name].shape] vars[var_name].name, [-1 for i in vars[var_name].shape]
) )
# The dist op must be built after the dist attr has been completely constructed
dist_op = DistributedOperator(op, op_dist_attr)
dist_op.dist_attr.impl_type = "default" dist_op.dist_attr.impl_type = "default"
dist_op.dist_attr.impl_idx = 0 dist_op.dist_attr.impl_idx = 0
op_valid_dist_attrs.append(dist_op.dist_attr) op_valid_dist_attrs.append(dist_op.dist_attr)
...@@ -395,7 +392,7 @@ class PlanSpace: ...@@ -395,7 +392,7 @@ class PlanSpace:
op_process_mesh = pipeline_process_meshes[pipeline_stage] op_process_mesh = pipeline_process_meshes[pipeline_stage]
if op.type in PlanSpace.not_enum_ops: if op.type in PlanSpace.not_enum_ops:
op_dist_attr = OperatorDistributedAttribute() op_dist_attr = OperatorDistAttr()
op_dist_attr.process_mesh = op_process_mesh op_dist_attr.process_mesh = op_process_mesh
for var_name in op.input_arg_names: for var_name in op.input_arg_names:
if var_name in PlanSpace.special_vars: if var_name in PlanSpace.special_vars:
...@@ -498,9 +495,7 @@ class MCMC(SearchAlgorithm): ...@@ -498,9 +495,7 @@ class MCMC(SearchAlgorithm):
search_op, op_dist_attr search_op, op_dist_attr
) )
for name in search_op.output_arg_names: for name in search_op.output_arg_names:
tensor_dist_attr = ( tensor_dist_attr = TensorDistAttr()
TensorDistributedAttribute()
)
tensor_dist_attr.process_mesh = ( tensor_dist_attr.process_mesh = (
op_dist_attr.process_mesh op_dist_attr.process_mesh
) )
...@@ -546,7 +541,7 @@ class MCMC(SearchAlgorithm): ...@@ -546,7 +541,7 @@ class MCMC(SearchAlgorithm):
) )
is None is None
): ):
tensor_dist_attr = TensorDistributedAttribute() tensor_dist_attr = TensorDistAttr()
tensor_dist_attr.process_mesh = ( tensor_dist_attr.process_mesh = (
init_op_dist_attr.process_mesh init_op_dist_attr.process_mesh
) )
...@@ -558,7 +553,7 @@ class MCMC(SearchAlgorithm): ...@@ -558,7 +553,7 @@ class MCMC(SearchAlgorithm):
) )
for var_name in op.output_arg_names: for var_name in op.output_arg_names:
tensor_dist_attr = TensorDistributedAttribute() tensor_dist_attr = TensorDistAttr()
tensor_dist_attr.process_mesh = init_op_dist_attr.process_mesh tensor_dist_attr.process_mesh = init_op_dist_attr.process_mesh
tensor_dist_attr.dims_mapping = ( tensor_dist_attr.dims_mapping = (
init_op_dist_attr.get_output_dims_mapping(var_name) init_op_dist_attr.get_output_dims_mapping(var_name)
...@@ -627,7 +622,7 @@ class MCMC(SearchAlgorithm): ...@@ -627,7 +622,7 @@ class MCMC(SearchAlgorithm):
# set output tensor distributed attribute # set output tensor distributed attribute
for var_name in op.output_arg_names: for var_name in op.output_arg_names:
process_mesh = op_dist_attr.process_mesh process_mesh = op_dist_attr.process_mesh
tensor_dist_attr = TensorDistributedAttribute() tensor_dist_attr = TensorDistAttr()
tensor_dist_attr.process_mesh = process_mesh tensor_dist_attr.process_mesh = process_mesh
tensor_dist_attr.dims_mapping = ( tensor_dist_attr.dims_mapping = (
op_dist_attr.get_output_dims_mapping(var_name) op_dist_attr.get_output_dims_mapping(var_name)
...@@ -640,7 +635,7 @@ class MCMC(SearchAlgorithm): ...@@ -640,7 +635,7 @@ class MCMC(SearchAlgorithm):
for var_name in op.input_arg_names: for var_name in op.input_arg_names:
if vars[var_name].is_parameter or vars[var_name].is_data: if vars[var_name].is_parameter or vars[var_name].is_data:
process_mesh = op_dist_attr.process_mesh process_mesh = op_dist_attr.process_mesh
tensor_dist_attr = TensorDistributedAttribute() tensor_dist_attr = TensorDistAttr()
tensor_dist_attr.process_mesh = process_mesh tensor_dist_attr.process_mesh = process_mesh
tensor_dist_attr.dims_mapping = ( tensor_dist_attr.dims_mapping = (
op_dist_attr.get_input_dims_mapping(var_name) op_dist_attr.get_input_dims_mapping(var_name)
......
...@@ -211,7 +211,7 @@ class ProcessMesh(core.ProcessMesh): ...@@ -211,7 +211,7 @@ class ProcessMesh(core.ProcessMesh):
return new_process_mesh return new_process_mesh
def __eq__(self, other): def __eq__(self, other):
if not isinstance(other, ProcessMesh): if not isinstance(other, (ProcessMesh, core.ProcessMesh)):
return False return False
if self.shape != other.shape or self.process_ids != other.process_ids: if self.shape != other.shape or self.process_ids != other.process_ids:
return False return False
......
...@@ -31,7 +31,7 @@ from .cost import ( ...@@ -31,7 +31,7 @@ from .cost import (
SplitOpCost, SplitOpCost,
build_comm_desc, build_comm_desc,
) )
from .dist_attribute import TensorDistributedAttribute from .dist_attribute import TensorDistAttr
from .dist_context import DistributedContext from .dist_context import DistributedContext
from .process_group import new_process_group from .process_group import new_process_group
from .utils import is_gradient_clip_op from .utils import is_gradient_clip_op
...@@ -1989,7 +1989,7 @@ class Resharder: ...@@ -1989,7 +1989,7 @@ class Resharder:
process_mesh = dist_attr[0] process_mesh = dist_attr[0]
dims_mapping = dist_attr[1] dims_mapping = dist_attr[1]
tensor_attr = TensorDistributedAttribute() tensor_attr = TensorDistAttr()
tensor_attr.dims_mapping = dims_mapping tensor_attr.dims_mapping = dims_mapping
tensor_attr.process_mesh = process_mesh tensor_attr.process_mesh = process_mesh
self.dist_context.set_tensor_dist_attr_for_program( self.dist_context.set_tensor_dist_attr_for_program(
...@@ -2031,6 +2031,9 @@ class Resharder: ...@@ -2031,6 +2031,9 @@ class Resharder:
if name == var_name and op_dist_attr is not None: if name == var_name and op_dist_attr is not None:
if op.desc.id() == matched_op.desc.id(): if op.desc.id() == matched_op.desc.id():
if matched_op.type == "while": if matched_op.type == "while":
op.desc._rename_input(
name, target_tensor.name
)
old_name = name old_name = name
new_name = target_tensor.name new_name = target_tensor.name
assert old_name != new_name assert old_name != new_name
...@@ -2045,13 +2048,16 @@ class Resharder: ...@@ -2045,13 +2048,16 @@ class Resharder:
op_dist_attr.set_input_dims_mapping( op_dist_attr.set_input_dims_mapping(
new_name, dims_mapping new_name, dims_mapping
) )
if ( # if (
old_name # old_name
in op_dist_attr._inputs_dist_attrs # in op_dist_attr._inputs_dist_attrs
): # ):
op_dist_attr.del_input_dist_attr( # op_dist_attr.del_input_dist_attr(
old_name # old_name
) # )
op_dist_attr.set_input_dims_mapping(
new_name, dims_mapping
)
while_op_X_append.append(new_name) while_op_X_append.append(new_name)
continue continue
else: else:
...@@ -2072,7 +2078,10 @@ class Resharder: ...@@ -2072,7 +2078,10 @@ class Resharder:
op_dist_attr.set_input_dims_mapping( op_dist_attr.set_input_dims_mapping(
new_name, dims_mapping new_name, dims_mapping
) )
op_dist_attr.del_input_dist_attr(old_name) # op_dist_attr.del_input_dist_attr(old_name)
op_dist_attr.set_input_dims_mapping(
new_name, dims_mapping
)
continue continue
op_process_mesh = op_dist_attr.process_mesh op_process_mesh = op_dist_attr.process_mesh
...@@ -2097,7 +2106,10 @@ class Resharder: ...@@ -2097,7 +2106,10 @@ class Resharder:
op_dist_attr.set_input_dims_mapping( op_dist_attr.set_input_dims_mapping(
new_name, dims_mapping new_name, dims_mapping
) )
op_dist_attr.del_input_dist_attr(old_name) # op_dist_attr.del_input_dist_attr(old_name)
op_dist_attr.set_input_dims_mapping(
new_name, dims_mapping
)
# for while op, the input X should reset # for while op, the input X should reset
if while_op_X_append: if while_op_X_append:
...@@ -2273,7 +2285,7 @@ class Resharder: ...@@ -2273,7 +2285,7 @@ class Resharder:
op_dist_attr.set_input_dist_attr( op_dist_attr.set_input_dist_attr(
new_name, op_input_dist_attr new_name, op_input_dist_attr
) )
op_dist_attr.del_input_dist_attr(old_name) # op_dist_attr.del_input_dist_attr(old_name)
# the outputs also need to be renamed when the output name is the same with input name in inplace op # the outputs also need to be renamed when the output name is the same with input name in inplace op
for var_name in op.output_arg_names: for var_name in op.output_arg_names:
...@@ -2297,7 +2309,7 @@ class Resharder: ...@@ -2297,7 +2309,7 @@ class Resharder:
op_dist_attr.set_output_dist_attr( op_dist_attr.set_output_dist_attr(
new_name, op_output_dist_attr new_name, op_output_dist_attr
) )
op_dist_attr.del_output_dist_attr(old_name) # op_dist_attr.del_output_dist_attr(old_name)
def _reshard_input(self, block): def _reshard_input(self, block):
idx = 0 idx = 0
......
...@@ -867,7 +867,7 @@ class ParallelTuner: ...@@ -867,7 +867,7 @@ class ParallelTuner:
assert ( assert (
dist_op.dist_attr.impl_idx == op_id_to_dist_attr[op_id].impl_idx dist_op.dist_attr.impl_idx == op_id_to_dist_attr[op_id].impl_idx
) )
dist_op.dist_attr.process_mesh = process_mesh dist_op.dist_attr.process_mesh = ProcessMesh(process_mesh)
self._amend_dist_attr() self._amend_dist_attr()
self._completer._complete_tensor_dist_attr_by_op() self._completer._complete_tensor_dist_attr_by_op()
...@@ -1041,7 +1041,6 @@ class ParallelTuner: ...@@ -1041,7 +1041,6 @@ class ParallelTuner:
# This store statement must follow the above backup statement # This store statement must follow the above backup statement
self._store_init_parallel_strategy() self._store_init_parallel_strategy()
init_time = self._estimate_trial() # estimate_trial when init init_time = self._estimate_trial() # estimate_trial when init
# print_program_with_dist_attr(self._dist_context.serial_main_program, self._dist_context)
# We have to restore the distributed context, because the estimation of one trail need to # We have to restore the distributed context, because the estimation of one trail need to
# generate the backward and update parts. Since we will do the tuning process, # generate the backward and update parts. Since we will do the tuning process,
# here we only need to reset all distributed information to the default one. # here we only need to reset all distributed information to the default one.
......
...@@ -26,11 +26,9 @@ from paddle.fluid.framework import Variable ...@@ -26,11 +26,9 @@ from paddle.fluid.framework import Variable
from paddle.fluid.io import is_belong_to_optimizer, is_parameter from paddle.fluid.io import is_belong_to_optimizer, is_parameter
from paddle.framework import core from paddle.framework import core
from .dist_attribute import ( from .dist_attribute import OperatorDistAttr, TensorDistAttr
OperatorDistributedAttribute,
TensorDistributedAttribute,
)
from .process_group import get_all_process_groups from .process_group import get_all_process_groups
from .process_mesh import ProcessMesh
OpRole = core.op_proto_and_checker_maker.OpRole OpRole = core.op_proto_and_checker_maker.OpRole
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName() OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
...@@ -1386,10 +1384,19 @@ def get_loss_op(block): ...@@ -1386,10 +1384,19 @@ def get_loss_op(block):
def set_var_dist_attr(dist_context, var, dims_mapping, process_mesh, **kwargs): def set_var_dist_attr(dist_context, var, dims_mapping, process_mesh, **kwargs):
tensor_dist_attr = TensorDistributedAttribute() tensor_dist_attr = TensorDistAttr()
tensor_dist_attr.dims_mapping = dims_mapping tensor_dist_attr.dims_mapping = dims_mapping
# TODO get global mesh group # TODO get global mesh group
tensor_dist_attr.process_mesh = process_mesh if isinstance(process_mesh, (list, np.ndarray)):
tensor_dist_attr.process_mesh = ProcessMesh(process_mesh)
elif isinstance(process_mesh, core.ProcessMesh):
tensor_dist_attr.process_mesh = process_mesh
else:
raise ValueError(
"{} must be a instance of ProcessMesh or list, but receive {}".format(
process_mesh, type(process_mesh)
)
)
if "mark_annotated" in kwargs and kwargs["mark_annotated"]: if "mark_annotated" in kwargs and kwargs["mark_annotated"]:
tensor_dist_attr.mark_annotated("dims_mapping") tensor_dist_attr.mark_annotated("dims_mapping")
tensor_dist_attr.mark_annotated("process_mesh") tensor_dist_attr.mark_annotated("process_mesh")
...@@ -1403,7 +1410,7 @@ def naive_set_dist_op_attr_for_program_by_mesh_and_mapping( ...@@ -1403,7 +1410,7 @@ def naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
assert process_mesh is not None assert process_mesh is not None
assert ref_mapping is not None assert ref_mapping is not None
new_op_dist_attr = OperatorDistributedAttribute() new_op_dist_attr = OperatorDistAttr()
for input_varname in new_op.desc.input_arg_names(): for input_varname in new_op.desc.input_arg_names():
new_op_dist_attr.set_input_dims_mapping(input_varname, ref_mapping) new_op_dist_attr.set_input_dims_mapping(input_varname, ref_mapping)
...@@ -1422,7 +1429,7 @@ def naive_set_dist_op_attr_for_program_by_mesh( ...@@ -1422,7 +1429,7 @@ def naive_set_dist_op_attr_for_program_by_mesh(
return return
assert process_mesh is not None assert process_mesh is not None
new_op_dist_attr = OperatorDistributedAttribute() new_op_dist_attr = OperatorDistAttr()
for input_varname in new_op.desc.input_arg_names(): for input_varname in new_op.desc.input_arg_names():
var = ctx.serial_main_program.global_block().var(input_varname) var = ctx.serial_main_program.global_block().var(input_varname)
...@@ -2078,20 +2085,20 @@ def _copy_tensor_dist_attr_to_cpp(cpp_dist_attr, py_dist_attr): ...@@ -2078,20 +2085,20 @@ def _copy_tensor_dist_attr_to_cpp(cpp_dist_attr, py_dist_attr):
["d" + str(i) for i in range(len(py_process_mesh.shape))], ["d" + str(i) for i in range(len(py_process_mesh.shape))],
) )
cpp_dist_attr.dims_mapping = py_dist_attr.dims_mapping cpp_dist_attr.dims_mapping = py_dist_attr.dims_mapping
cpp_dist_attr.annotated = py_dist_attr._is_annotated cpp_dist_attr.annotated = py_dist_attr.annotated
def _copy_tensor_dist_attr_from_cpp(cpp_dist_attr, py_dist_attr): def _copy_tensor_dist_attr_from_cpp(cpp_dist_attr, py_dist_attr):
from .process_mesh import ProcessMesh from .process_mesh import ProcessMesh
cpp_process_mesh = cpp_dist_attr.process_mesh cpp_process_mesh = cpp_dist_attr.process_mesh
if not cpp_process_mesh.empty(): if cpp_process_mesh is not None:
py_dist_attr.process_mesh = ProcessMesh( py_dist_attr.process_mesh = ProcessMesh(
shape=cpp_process_mesh.shape, shape=cpp_process_mesh.shape,
process_ids=cpp_process_mesh.process_ids, process_ids=cpp_process_mesh.process_ids,
) )
py_dist_attr.dims_mapping = cpp_dist_attr.dims_mapping py_dist_attr.dims_mapping = cpp_dist_attr.dims_mapping
py_dist_attr._is_annotated = cpp_dist_attr.annotated py_dist_attr.annotated = cpp_dist_attr.annotated
def _copy_op_dist_attr_to_cpp(cpp_dist_attr, py_dist_attr): def _copy_op_dist_attr_to_cpp(cpp_dist_attr, py_dist_attr):
...@@ -2104,7 +2111,8 @@ def _copy_op_dist_attr_to_cpp(cpp_dist_attr, py_dist_attr): ...@@ -2104,7 +2111,8 @@ def _copy_op_dist_attr_to_cpp(cpp_dist_attr, py_dist_attr):
) )
cpp_dist_attr.impl_type = py_dist_attr.impl_type cpp_dist_attr.impl_type = py_dist_attr.impl_type
cpp_dist_attr.impl_idx = py_dist_attr.impl_idx cpp_dist_attr.impl_idx = py_dist_attr.impl_idx
cpp_dist_attr.annotated = py_dist_attr._is_annotated cpp_dist_attr.is_recompute = py_dist_attr.is_recompute
cpp_dist_attr.annotated = py_dist_attr.annotated
for name, py_tensor_dist_attr in py_dist_attr.inputs_dist_attrs.items(): for name, py_tensor_dist_attr in py_dist_attr.inputs_dist_attrs.items():
cpp_tensor_dist_attr = cpp_dist_attr.get_input_dist_attr(name) cpp_tensor_dist_attr = cpp_dist_attr.get_input_dist_attr(name)
_copy_tensor_dist_attr_to_cpp(cpp_tensor_dist_attr, py_tensor_dist_attr) _copy_tensor_dist_attr_to_cpp(cpp_tensor_dist_attr, py_tensor_dist_attr)
...@@ -2117,15 +2125,15 @@ def _copy_op_dist_attr_from_cpp(cpp_dist_attr, py_dist_attr): ...@@ -2117,15 +2125,15 @@ def _copy_op_dist_attr_from_cpp(cpp_dist_attr, py_dist_attr):
from .process_mesh import ProcessMesh from .process_mesh import ProcessMesh
cpp_process_mesh = cpp_dist_attr.process_mesh cpp_process_mesh = cpp_dist_attr.process_mesh
if not cpp_process_mesh.empty(): if cpp_process_mesh is not None:
py_dist_attr.process_mesh = ProcessMesh( py_dist_attr.process_mesh = ProcessMesh(
shape=cpp_process_mesh.shape, shape=cpp_process_mesh.shape,
process_ids=cpp_process_mesh.process_ids, process_ids=cpp_process_mesh.process_ids,
) )
py_dist_attr.impl_type = cpp_dist_attr.impl_type py_dist_attr.impl_type = cpp_dist_attr.impl_type
py_dist_attr.impl_idx = cpp_dist_attr.impl_idx py_dist_attr.impl_idx = cpp_dist_attr.impl_idx
py_dist_attr._is_annotated = cpp_dist_attr.annotated py_dist_attr.is_recompute = cpp_dist_attr.is_recompute
py_dist_attr.op_type = cpp_dist_attr.op.type() py_dist_attr.annotated = cpp_dist_attr.annotated
for name, cpp_tensor_dist_attr in cpp_dist_attr.inputs_dist_attrs.items(): for name, cpp_tensor_dist_attr in cpp_dist_attr.inputs_dist_attrs.items():
py_tensor_dist_attr = py_dist_attr.get_input_dist_attr(name) py_tensor_dist_attr = py_dist_attr.get_input_dist_attr(name)
_copy_tensor_dist_attr_from_cpp( _copy_tensor_dist_attr_from_cpp(
......
...@@ -13,9 +13,7 @@ ...@@ -13,9 +13,7 @@
# limitations under the License. # limitations under the License.
import paddle import paddle
from paddle.distributed.auto_parallel.dist_attribute import ( from paddle.distributed.auto_parallel.dist_attribute import OperatorDistAttr
OperatorDistributedAttribute,
)
from paddle.distributed.auto_parallel.process_group import ( from paddle.distributed.auto_parallel.process_group import (
get_world_process_group, get_world_process_group,
) )
...@@ -41,6 +39,7 @@ from paddle.fluid.contrib.mixed_precision.fp16_utils import ( ...@@ -41,6 +39,7 @@ from paddle.fluid.contrib.mixed_precision.fp16_utils import (
from paddle.fluid.data_feeder import check_type, check_variable_and_dtype from paddle.fluid.data_feeder import check_type, check_variable_and_dtype
from paddle.framework import core from paddle.framework import core
from ..auto_parallel.process_mesh import ProcessMesh
from ..auto_parallel.utils import is_backward_op, is_forward_op, is_loss_op from ..auto_parallel.utils import is_backward_op, is_forward_op, is_loss_op
from .pass_base import PassBase, register_pass from .pass_base import PassBase, register_pass
...@@ -596,8 +595,10 @@ def _check_and_update_gradient(params_grads, loss_scaling, dist_context): ...@@ -596,8 +595,10 @@ def _check_and_update_gradient(params_grads, loss_scaling, dist_context):
attrs=attrs, attrs=attrs,
) )
new_op_dist_attr = OperatorDistributedAttribute() # Constructing dist attr from op_desc can
new_op_dist_attr.process_mesh = world_process_group.ranks # give all inputs and outputs default dist attrs
new_op_dist_attr = OperatorDistAttr(new_op.desc)
new_op_dist_attr.process_mesh = ProcessMesh(world_process_group.ranks)
new_op_dist_attr.impl_idx = 0 new_op_dist_attr.impl_idx = 0
if len(world_process_group.ranks) > 1: if len(world_process_group.ranks) > 1:
new_op_dist_attr.impl_type = "check_finite_and_unscale" new_op_dist_attr.impl_type = "check_finite_and_unscale"
...@@ -969,8 +970,10 @@ class AMPPass(PassBase): ...@@ -969,8 +970,10 @@ class AMPPass(PassBase):
attrs=attrs, attrs=attrs,
) )
new_op_dist_attr = OperatorDistributedAttribute() # Constructing dist attr from op_desc can
new_op_dist_attr.process_mesh = world_process_group.ranks # give all inputs and outputs default dist attrs
new_op_dist_attr = OperatorDistAttr(new_op.desc)
new_op_dist_attr.process_mesh = ProcessMesh(world_process_group.ranks)
new_op_dist_attr.impl_idx = 0 new_op_dist_attr.impl_idx = 0
if len(world_process_group.ranks) > 1: if len(world_process_group.ranks) > 1:
new_op_dist_attr.impl_type = "update_loss_scaling" new_op_dist_attr.impl_type = "update_loss_scaling"
......
...@@ -15,9 +15,7 @@ ...@@ -15,9 +15,7 @@
from collections import defaultdict from collections import defaultdict
import paddle import paddle
from paddle.distributed.auto_parallel.dist_attribute import ( from paddle.distributed.auto_parallel.dist_attribute import OperatorDistAttr
OperatorDistributedAttribute,
)
from paddle.distributed.auto_parallel.process_group import ( from paddle.distributed.auto_parallel.process_group import (
get_world_process_group, get_world_process_group,
) )
...@@ -40,6 +38,7 @@ from paddle.fluid.data_feeder import check_type, check_variable_and_dtype ...@@ -40,6 +38,7 @@ from paddle.fluid.data_feeder import check_type, check_variable_and_dtype
from paddle.fluid.framework import default_main_program, default_startup_program from paddle.fluid.framework import default_main_program, default_startup_program
from paddle.framework import core from paddle.framework import core
from ..auto_parallel.process_mesh import ProcessMesh
from .auto_parallel_amp import AMPPass from .auto_parallel_amp import AMPPass
from .pass_base import register_pass from .pass_base import register_pass
...@@ -582,8 +581,10 @@ def _check_and_update_gradient(grads, loss_scaling, name, dist_context): ...@@ -582,8 +581,10 @@ def _check_and_update_gradient(grads, loss_scaling, name, dist_context):
attrs=attrs, attrs=attrs,
) )
new_op_dist_attr = OperatorDistributedAttribute() # Constructing dist attr from op_desc can
new_op_dist_attr.process_mesh = world_process_group.ranks # give all inputs and outputs default dist attrs
new_op_dist_attr = OperatorDistAttr(new_op.desc)
new_op_dist_attr.process_mesh = ProcessMesh(world_process_group.ranks)
new_op_dist_attr.impl_idx = 0 new_op_dist_attr.impl_idx = 0
if len(world_process_group.ranks) > 1: if len(world_process_group.ranks) > 1:
new_op_dist_attr.impl_type = "check_finite_and_unscale" new_op_dist_attr.impl_type = "check_finite_and_unscale"
...@@ -611,8 +612,8 @@ def _split_grads(params_grads): ...@@ -611,8 +612,8 @@ def _split_grads(params_grads):
def _set_op_dist_attr_with_ranks(new_op, ranks, block, dist_context): def _set_op_dist_attr_with_ranks(new_op, ranks, block, dist_context):
new_op_dist_attr = OperatorDistributedAttribute() new_op_dist_attr = OperatorDistAttr()
new_op_dist_attr.process_mesh = ranks new_op_dist_attr.process_mesh = ProcessMesh(ranks)
new_op_dist_attr.impl_idx = 0 new_op_dist_attr.impl_idx = 0
for var_name in new_op.input_arg_names: for var_name in new_op.input_arg_names:
var = block.var(var_name) var = block.var(var_name)
......
...@@ -19,12 +19,10 @@ import numpy as np ...@@ -19,12 +19,10 @@ import numpy as np
import paddle import paddle
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
from ..auto_parallel.dist_attribute import ( from ..auto_parallel.dist_attribute import OperatorDistAttr, TensorDistAttr
OperatorDistributedAttribute,
TensorDistributedAttribute,
)
from ..auto_parallel.operators.common import SyncMode from ..auto_parallel.operators.common import SyncMode
from ..auto_parallel.process_group import get_world_process_group from ..auto_parallel.process_group import get_world_process_group
from ..auto_parallel.process_mesh import ProcessMesh
from ..auto_parallel.reshard import Resharder from ..auto_parallel.reshard import Resharder
from ..auto_parallel.utils import ( from ..auto_parallel.utils import (
_get_comm_group, _get_comm_group,
...@@ -192,12 +190,12 @@ class ClipHelper: ...@@ -192,12 +190,12 @@ class ClipHelper:
return self.rank_id in dist_attr.process_mesh.process_ids return self.rank_id in dist_attr.process_mesh.process_ids
def _init_dist_attr(self, op): def _init_dist_attr(self, op):
op_dist_attr = OperatorDistributedAttribute() op_dist_attr = OperatorDistAttr()
op_dist_attr.process_mesh = self.world_ranks op_dist_attr.process_mesh = ProcessMesh(self.world_ranks)
for in_name in op.input_arg_names: for in_name in op.input_arg_names:
in_var = self.block.vars[in_name] in_var = self.block.vars[in_name]
in_dist_attr = TensorDistributedAttribute() in_dist_attr = TensorDistAttr()
in_dist_attr.process_mesh = self.world_ranks in_dist_attr.process_mesh = ProcessMesh(self.world_ranks)
in_dist_attr.dims_mapping = [-1] in_dist_attr.dims_mapping = [-1]
self.dist_context.set_tensor_dist_attr_for_program( self.dist_context.set_tensor_dist_attr_for_program(
in_var, in_dist_attr in_var, in_dist_attr
...@@ -205,8 +203,8 @@ class ClipHelper: ...@@ -205,8 +203,8 @@ class ClipHelper:
op_dist_attr.set_input_dist_attr(in_name, in_dist_attr) op_dist_attr.set_input_dist_attr(in_name, in_dist_attr)
for out_name in op.output_arg_names: for out_name in op.output_arg_names:
out_var = self.block.vars[out_name] out_var = self.block.vars[out_name]
out_dist_attr = TensorDistributedAttribute() out_dist_attr = TensorDistAttr()
out_dist_attr.process_mesh = self.world_ranks out_dist_attr.process_mesh = ProcessMesh(self.world_ranks)
out_dist_attr.dims_mapping = [-1] out_dist_attr.dims_mapping = [-1]
self.dist_context.set_tensor_dist_attr_for_program( self.dist_context.set_tensor_dist_attr_for_program(
out_var, out_dist_attr out_var, out_dist_attr
......
...@@ -18,6 +18,7 @@ import paddle ...@@ -18,6 +18,7 @@ import paddle
from paddle.distributed.auto_parallel.process_group import ( from paddle.distributed.auto_parallel.process_group import (
get_world_process_group, get_world_process_group,
) )
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.utils import ( from paddle.distributed.auto_parallel.utils import (
is_optimize_op, is_optimize_op,
naive_set_dist_op_attr_for_program_by_mesh_and_mapping, naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
...@@ -108,7 +109,10 @@ def _get_gm_cond_var(main_program, k_steps, dist_context): ...@@ -108,7 +109,10 @@ def _get_gm_cond_var(main_program, k_steps, dist_context):
attrs={'step': float(1.0), OP_ROLE_KEY: OpRole.Backward}, attrs={'step': float(1.0), OP_ROLE_KEY: OpRole.Backward},
) )
naive_set_dist_op_attr_for_program_by_mesh_and_mapping( naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
increment_op, world_process_group.ranks, [-1], dist_context increment_op,
ProcessMesh(world_process_group.ranks),
[-1],
dist_context,
) )
# step_var %= k_step # step_var %= k_step
elementwise_mod_op = main_block.append_op( elementwise_mod_op = main_block.append_op(
...@@ -122,7 +126,10 @@ def _get_gm_cond_var(main_program, k_steps, dist_context): ...@@ -122,7 +126,10 @@ def _get_gm_cond_var(main_program, k_steps, dist_context):
}, },
) )
naive_set_dist_op_attr_for_program_by_mesh_and_mapping( naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
elementwise_mod_op, world_process_group.ranks, [-1], dist_context elementwise_mod_op,
ProcessMesh(world_process_group.ranks),
[-1],
dist_context,
) )
# cond_var = (step_var == 0) # cond_var = (step_var == 0)
equal_op = main_block.append_op( equal_op = main_block.append_op(
...@@ -132,7 +139,7 @@ def _get_gm_cond_var(main_program, k_steps, dist_context): ...@@ -132,7 +139,7 @@ def _get_gm_cond_var(main_program, k_steps, dist_context):
attrs={OP_ROLE_KEY: OpRole.Backward}, attrs={OP_ROLE_KEY: OpRole.Backward},
) )
naive_set_dist_op_attr_for_program_by_mesh_and_mapping( naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
equal_op, world_process_group.ranks, [-1], dist_context equal_op, ProcessMesh(world_process_group.ranks), [-1], dist_context
) )
return cond_var return cond_var
......
...@@ -28,10 +28,7 @@ from paddle.static.quantization import ( ...@@ -28,10 +28,7 @@ from paddle.static.quantization import (
) )
from ..auto_parallel.converter import Converter from ..auto_parallel.converter import Converter
from ..auto_parallel.dist_attribute import ( from ..auto_parallel.dist_attribute import OperatorDistAttr, TensorDistAttr
OperatorDistributedAttribute,
TensorDistributedAttribute,
)
from .pass_base import PassBase, register_pass from .pass_base import PassBase, register_pass
TRANSFORM_PASS_OP_TYPES = utils._weight_supported_quantizable_op_type TRANSFORM_PASS_OP_TYPES = utils._weight_supported_quantizable_op_type
...@@ -248,7 +245,7 @@ class QuantizationPass(PassBase): ...@@ -248,7 +245,7 @@ class QuantizationPass(PassBase):
# recover origin ops' dist_attr and set quant ops' dist_attr # recover origin ops' dist_attr and set quant ops' dist_attr
qat_offset = 0 qat_offset = 0
for ip, quant_op in enumerate(block.ops): for ip, quant_op in enumerate(block.ops):
quant_op_dist_attr = OperatorDistributedAttribute() quant_op_dist_attr = OperatorDistAttr()
if ( if (
"quantize" in quant_op.type "quantize" in quant_op.type
...@@ -318,7 +315,7 @@ class QuantizationPass(PassBase): ...@@ -318,7 +315,7 @@ class QuantizationPass(PassBase):
x_dist_attr.dims_mapping[quant_axis] x_dist_attr.dims_mapping[quant_axis]
] ]
tensor_dist_attr = TensorDistributedAttribute() tensor_dist_attr = TensorDistAttr()
tensor_dist_attr.process_mesh = ref_process_mesh tensor_dist_attr.process_mesh = ref_process_mesh
tensor_dist_attr.dims_mapping = ref_dims_mapping tensor_dist_attr.dims_mapping = ref_dims_mapping
dist_context.set_tensor_dist_attr_for_program( dist_context.set_tensor_dist_attr_for_program(
...@@ -357,7 +354,7 @@ class QuantizationPass(PassBase): ...@@ -357,7 +354,7 @@ class QuantizationPass(PassBase):
x_dist_attr.dims_mapping[quant_axis] x_dist_attr.dims_mapping[quant_axis]
] ]
tensor_dist_attr = TensorDistributedAttribute() tensor_dist_attr = TensorDistAttr()
tensor_dist_attr.process_mesh = ref_process_mesh tensor_dist_attr.process_mesh = ref_process_mesh
tensor_dist_attr.dims_mapping = ref_dims_mapping tensor_dist_attr.dims_mapping = ref_dims_mapping
dist_context.set_tensor_dist_attr_for_program( dist_context.set_tensor_dist_attr_for_program(
......
...@@ -24,7 +24,7 @@ from paddle.fluid.backward import ( ...@@ -24,7 +24,7 @@ from paddle.fluid.backward import (
_rename_arg_, _rename_arg_,
) )
from ..auto_parallel.dist_attribute import OperatorDistributedAttribute from ..auto_parallel.dist_attribute import OperatorDistAttr
from ..auto_parallel.utils import ( from ..auto_parallel.utils import (
get_loss_op, get_loss_op,
insert_dependencies_for_two_ops, insert_dependencies_for_two_ops,
...@@ -495,7 +495,7 @@ class RecomputePass(PassBase): ...@@ -495,7 +495,7 @@ class RecomputePass(PassBase):
) )
def set_op_dist_attr(self, op, old_dist_attr, var_name_dict): def set_op_dist_attr(self, op, old_dist_attr, var_name_dict):
new_dist_attr = OperatorDistributedAttribute() new_dist_attr = OperatorDistAttr()
new_dist_attr.is_recompute = True new_dist_attr.is_recompute = True
new_dist_attr.impl_idx = old_dist_attr.impl_idx new_dist_attr.impl_idx = old_dist_attr.impl_idx
new_dist_attr.impl_type = old_dist_attr.impl_type new_dist_attr.impl_type = old_dist_attr.impl_type
......
...@@ -89,31 +89,31 @@ class TestAMPPass(unittest.TestCase): ...@@ -89,31 +89,31 @@ class TestAMPPass(unittest.TestCase):
) )
def test_amp_pass(self): def test_amp_pass(self):
# mp2 training # # mp2 training
mp_engine = self.get_engine() # mp_engine = self.get_engine()
history = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size) # history = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
mp_losses = np.array(history.history["loss"]) # mp_losses = np.array(history.history["loss"])
# mp2 amp-o1 training # mp2 amp-o1 training
amp_o1_engine = self.get_engine(True, "o1") amp_o1_engine = self.get_engine(True, "o1")
history = amp_o1_engine.fit(self.dataset, 3, batch_size=self.batch_size) history = amp_o1_engine.fit(self.dataset, 3, batch_size=self.batch_size)
amp_o1_losses = np.array(history.history["loss"]) amp_o1_losses = np.array(history.history["loss"])
amp_o1_engine.evaluate(self.dataset, 3, batch_size=self.batch_size) amp_o1_engine.evaluate(self.dataset, 3, batch_size=self.batch_size)
# self.check_results(mp_losses, amp_o1_losses) # # self.check_results(mp_losses, amp_o1_losses)
# mp2 amp-o2 training # # mp2 amp-o2 training
amp_o2_engine = self.get_engine(True, "o2") # amp_o2_engine = self.get_engine(True, "o2")
history = amp_o2_engine.fit(self.dataset, 3, batch_size=self.batch_size) # history = amp_o2_engine.fit(self.dataset, 3, batch_size=self.batch_size)
amp_o2_losses = np.array(history.history["loss"]) # amp_o2_losses = np.array(history.history["loss"])
amp_o2_engine.evaluate(self.dataset, 3, batch_size=self.batch_size) # amp_o2_engine.evaluate(self.dataset, 3, batch_size=self.batch_size)
# self.check_results(mp_losses, amp_o2_losses) # # self.check_results(mp_losses, amp_o2_losses)
# mp2 amp-o3 training # # mp2 amp-o3 training
amp_o3_engine = self.get_engine(True, "o3") # amp_o3_engine = self.get_engine(True, "o3")
history = amp_o3_engine.fit(self.dataset, 3, batch_size=self.batch_size) # history = amp_o3_engine.fit(self.dataset, 3, batch_size=self.batch_size)
amp_o3_losses = np.array(history.history["loss"]) # amp_o3_losses = np.array(history.history["loss"])
amp_o3_engine.evaluate(self.dataset, 3, batch_size=self.batch_size) # amp_o3_engine.evaluate(self.dataset, 3, batch_size=self.batch_size)
# self.check_results(mp_losses, amp_o3_losses) # # self.check_results(mp_losses, amp_o3_losses)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -158,9 +158,9 @@ def train_high_level(fetch): ...@@ -158,9 +158,9 @@ def train_high_level(fetch):
eval_dataset2 = MyDataset(batch_size) eval_dataset2 = MyDataset(batch_size)
engine.evaluate(eval_dataset2, batch_size=batch_size) engine.evaluate(eval_dataset2, batch_size=batch_size)
# predict # # predict
test_dataset = MyDataset(batch_size) # test_dataset = MyDataset(batch_size)
outputs = engine.predict(test_dataset, batch_size=batch_size) # outputs = engine.predict(test_dataset, batch_size=batch_size)
# save # save
temp_dir = tempfile.TemporaryDirectory() temp_dir = tempfile.TemporaryDirectory()
...@@ -498,10 +498,10 @@ def get_cost_by_spec(): ...@@ -498,10 +498,10 @@ def get_cost_by_spec():
if __name__ == "__main__": if __name__ == "__main__":
train_high_level(fetch=True) train_high_level(fetch=True)
train_high_level(fetch=False) # train_high_level(fetch=False)
train_low_level() # train_low_level()
train_builtin_data_vars() # train_builtin_data_vars()
train_non_builtin_data_vars() # train_non_builtin_data_vars()
get_cost() # get_cost()
get_cost_by_default_program() # get_cost_by_default_program()
get_cost_by_spec() # get_cost_by_spec()
...@@ -26,7 +26,7 @@ from paddle.distributed.auto_parallel.dist_context import ( ...@@ -26,7 +26,7 @@ from paddle.distributed.auto_parallel.dist_context import (
DistributedContext, DistributedContext,
set_default_distributed_context, set_default_distributed_context,
) )
from paddle.distributed.auto_parallel.process_mesh_v2 import ProcessMesh from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.utils import ( from paddle.distributed.auto_parallel.utils import (
_copy_dist_attr_from_cpp, _copy_dist_attr_from_cpp,
_copy_dist_attr_from_cpp_for_graph, _copy_dist_attr_from_cpp_for_graph,
...@@ -42,7 +42,7 @@ batch_size = 4 ...@@ -42,7 +42,7 @@ batch_size = 4
epoch_num = 10 epoch_num = 10
hidden_size = 1024 hidden_size = 1024
sequence_len = 512 sequence_len = 512
_g_process_mesh = auto.ProcessMesh(mesh=[[0, 1], [2, 3]], dim_names=['x', 'y']) _g_process_mesh = ProcessMesh(mesh=[[0, 1], [2, 3]], dim_names=['x', 'y'])
class MLPLayer(nn.Layer): class MLPLayer(nn.Layer):
...@@ -201,22 +201,26 @@ class TestDistAttr(unittest.TestCase): ...@@ -201,22 +201,26 @@ class TestDistAttr(unittest.TestCase):
with static.program_guard(train_program, start_program): with static.program_guard(train_program, start_program):
input = static.data(name="input", shape=[2, 3], dtype='float32') input = static.data(name="input", shape=[2, 3], dtype='float32')
dist_attr = TensorDistAttr(input.desc) dist_attr = TensorDistAttr(input.desc)
self.assertEqual(dist_attr.process_mesh.empty(), True) self.assertEqual(dist_attr.process_mesh, None)
self.assertEqual(dist_attr.dims_mapping, [-1, -1]) self.assertEqual(dist_attr.dims_mapping, [-1, -1])
self.assertEqual(dist_attr.batch_dim, 0) self.assertEqual(dist_attr.batch_dim, 0)
self.assertEqual(dist_attr.dynamic_dims, [0, 0]) self.assertEqual(dist_attr.dynamic_dims, [0, 0])
dist_attr.process_mesh = None
self.assertEqual(dist_attr.process_mesh, None)
dist_attr.process_mesh = ProcessMesh([[0, 1, 2], [3, 4, 5]]) dist_attr.process_mesh = ProcessMesh([[0, 1, 2], [3, 4, 5]])
dist_attr.dims_mapping = [0, -1] dist_attr.dims_mapping = [0, -1]
dist_attr.batch_dim = 1 dist_attr.batch_dim = 1
dist_attr.dynamic_dims = [1, 1] dist_attr.dynamic_dims = [1, 1]
self.assertEqual(dist_attr.dims_mapping, [0, -1])
self.assertEqual( self.assertEqual(
dist_attr.process_mesh, ProcessMesh([[0, 1, 2], [3, 4, 5]]) dist_attr.process_mesh, ProcessMesh([[0, 1, 2], [3, 4, 5]])
) )
self.assertEqual(dist_attr.dims_mapping, [0, -1]) self.assertEqual(dist_attr.dims_mapping, [0, -1])
self.assertEqual(dist_attr.batch_dim, 1) self.assertEqual(dist_attr.batch_dim, 1)
self.assertEqual(dist_attr.dynamic_dims, [1, 1]) self.assertEqual(dist_attr.dynamic_dims, [1, 1])
self.assertTrue(dist_attr.verify()) self.assertTrue(dist_attr.verify(input.desc))
self.assertTrue(str(dist_attr), str(dist_attr)) self.assertTrue(str(dist_attr), str(dist_attr))
def test_tensor_dist_attr(self): def test_tensor_dist_attr(self):
...@@ -236,7 +240,7 @@ class TestDistAttr(unittest.TestCase): ...@@ -236,7 +240,7 @@ class TestDistAttr(unittest.TestCase):
self.assertEqual(input.dist_attr.dims_mapping, [0, -1]) self.assertEqual(input.dist_attr.dims_mapping, [0, -1])
self.assertEqual(input.dist_attr.batch_dim, 1) self.assertEqual(input.dist_attr.batch_dim, 1)
self.assertEqual(input.dist_attr.dynamic_dims, [1, 1]) self.assertEqual(input.dist_attr.dynamic_dims, [1, 1])
self.assertTrue(input.dist_attr.verify()) self.assertTrue(input.dist_attr.verify(input.desc))
input1.dist_attr = dist_attr input1.dist_attr = dist_attr
self.assertEqual( self.assertEqual(
...@@ -245,7 +249,7 @@ class TestDistAttr(unittest.TestCase): ...@@ -245,7 +249,7 @@ class TestDistAttr(unittest.TestCase):
self.assertEqual(input1.dist_attr.dims_mapping, [0, -1]) self.assertEqual(input1.dist_attr.dims_mapping, [0, -1])
self.assertEqual(input1.dist_attr.batch_dim, 1) self.assertEqual(input1.dist_attr.batch_dim, 1)
self.assertEqual(input1.dist_attr.dynamic_dims, [1, 1]) self.assertEqual(input1.dist_attr.dynamic_dims, [1, 1])
self.assertTrue(input1.dist_attr.verify()) self.assertTrue(input1.dist_attr.verify(input.desc))
def test_operator_dist_attr_ctor(self): def test_operator_dist_attr_ctor(self):
train_program = static.Program() train_program = static.Program()
...@@ -293,7 +297,7 @@ class TestDistAttr(unittest.TestCase): ...@@ -293,7 +297,7 @@ class TestDistAttr(unittest.TestCase):
self.assertEqual( self.assertEqual(
op_dist_attr.get_output_dist_attr(output.name).dims_mapping, [0, 1] op_dist_attr.get_output_dist_attr(output.name).dims_mapping, [0, 1]
) )
self.assertTrue(op_dist_attr.verify()) self.assertTrue(op_dist_attr.verify(op.desc))
self.assertTrue(str(op_dist_attr), str(op_dist_attr)) self.assertTrue(str(op_dist_attr), str(op_dist_attr))
op_dist_attr = OperatorDistAttr(op.desc) op_dist_attr = OperatorDistAttr(op.desc)
...@@ -314,7 +318,7 @@ class TestDistAttr(unittest.TestCase): ...@@ -314,7 +318,7 @@ class TestDistAttr(unittest.TestCase):
self.assertEqual(input_dist_attr.dims_mapping, [-1, 0]) self.assertEqual(input_dist_attr.dims_mapping, [-1, 0])
self.assertEqual(input1_dist_attr.dims_mapping, [0, -1]) self.assertEqual(input1_dist_attr.dims_mapping, [0, -1])
self.assertEqual(output_dist_attr.dims_mapping, [-1, -1]) self.assertEqual(output_dist_attr.dims_mapping, [-1, -1])
self.assertTrue(op_dist_attr.verify()) self.assertTrue(op_dist_attr.verify(op.desc))
self.assertTrue(str(op_dist_attr), str(op_dist_attr)) self.assertTrue(str(op_dist_attr), str(op_dist_attr))
def test_operator_dist_attr(self): def test_operator_dist_attr(self):
...@@ -364,7 +368,7 @@ class TestDistAttr(unittest.TestCase): ...@@ -364,7 +368,7 @@ class TestDistAttr(unittest.TestCase):
self.assertEqual( self.assertEqual(
op.dist_attr.get_output_dist_attr(output.name).dims_mapping, [0, 1] op.dist_attr.get_output_dist_attr(output.name).dims_mapping, [0, 1]
) )
self.assertTrue(op.desc.dist_attr.verify()) self.assertTrue(op.desc.dist_attr.verify(op.desc))
self.assertTrue(str(op_dist_attr), str(op_dist_attr)) self.assertTrue(str(op_dist_attr), str(op_dist_attr))
op.dist_attr = OperatorDistAttr(op.desc) op.dist_attr = OperatorDistAttr(op.desc)
......
...@@ -91,7 +91,6 @@ def parallelizer(program_func, rank): ...@@ -91,7 +91,6 @@ def parallelizer(program_func, rank):
loss, distop_context=dist_context.dist_op_context loss, distop_context=dist_context.dist_op_context
) )
completer.complete_backward_annotation(main_program) completer.complete_backward_annotation(main_program)
dist_context.block_state.parse_backward_blocks(main_program) dist_context.block_state.parse_backward_blocks(main_program)
partitioner = Partitioner(dist_context, rank) partitioner = Partitioner(dist_context, rank)
dist_main_prog, _, _ = partitioner.partition( dist_main_prog, _, _ = partitioner.partition(
......
...@@ -38,8 +38,8 @@ class TestEngineAPI(unittest.TestCase): ...@@ -38,8 +38,8 @@ class TestEngineAPI(unittest.TestCase):
"paddle.distributed.launch", "paddle.distributed.launch",
"--devices", "--devices",
"0,1", "0,1",
"--log_dir", # "--log_dir",
tmp_dir.name, # tmp_dir.name,
launch_model_path, launch_model_path,
] ]
) )
......
...@@ -38,8 +38,8 @@ class TestAMPPass(unittest.TestCase): ...@@ -38,8 +38,8 @@ class TestAMPPass(unittest.TestCase):
"paddle.distributed.launch", "paddle.distributed.launch",
"--devices", "--devices",
"0,1", "0,1",
"--log_dir", # "--log_dir",
tmp_dir.name, # tmp_dir.name,
launch_model_path, launch_model_path,
] ]
) )
......
...@@ -196,6 +196,15 @@ class TestProcessMesh(unittest.TestCase): ...@@ -196,6 +196,15 @@ class TestProcessMesh(unittest.TestCase):
merged_process_mesh = merge_process_meshes([None, process_mesh1]) merged_process_mesh = merge_process_meshes([None, process_mesh1])
self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5])) self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5]))
merged_process_mesh = merge_process_meshes(
[process_mesh1, paddle.fluid.core.ProcessMesh()]
)
self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5]))
merged_process_mesh = merge_process_meshes(
[paddle.fluid.core.ProcessMesh(), process_mesh1]
)
self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5]))
process_mesh2 = ProcessMesh([[0, 1, 2], [3, 4, 5]]) process_mesh2 = ProcessMesh([[0, 1, 2], [3, 4, 5]])
merged_process_mesh = merge_process_meshes( merged_process_mesh = merge_process_meshes(
[process_mesh1, process_mesh2] [process_mesh1, process_mesh2]
......
...@@ -225,7 +225,7 @@ def completion(train_program, start_program, dist_context): ...@@ -225,7 +225,7 @@ def completion(train_program, start_program, dist_context):
# out_var) # out_var)
# if tensor_dist_attr: # if tensor_dist_attr:
# continue # continue
# tensor_dist_attr = TensorDistributedAttribute() # tensor_dist_attr = TensorDistAttr()
# tensor_dist_attr.process_mesh = _g_process_mesh # tensor_dist_attr.process_mesh = _g_process_mesh
# tensor_dist_attr.dims_mapping = [-1] # tensor_dist_attr.dims_mapping = [-1]
# dist_context.set_tensor_dist_attr_for_program( # dist_context.set_tensor_dist_attr_for_program(
...@@ -234,7 +234,7 @@ def completion(train_program, start_program, dist_context): ...@@ -234,7 +234,7 @@ def completion(train_program, start_program, dist_context):
# elif op.type == "elementwise_sub": # elif op.type == "elementwise_sub":
# for out_name in op.output_arg_names: # for out_name in op.output_arg_names:
# out_var = block.vars[out_name] # out_var = block.vars[out_name]
# tensor_dist_attr = TensorDistributedAttribute() # tensor_dist_attr = TensorDistAttr()
# tensor_dist_attr.process_mesh = _g_process_mesh # tensor_dist_attr.process_mesh = _g_process_mesh
# tensor_dist_attr.dims_mapping = [-1, -1, -1] # tensor_dist_attr.dims_mapping = [-1, -1, -1]
# dist_context.set_tensor_dist_attr_for_program( # dist_context.set_tensor_dist_attr_for_program(
...@@ -260,7 +260,7 @@ def completion(train_program, start_program, dist_context): ...@@ -260,7 +260,7 @@ def completion(train_program, start_program, dist_context):
# out_var) # out_var)
# if tensor_dist_attr: # if tensor_dist_attr:
# continue # continue
# tensor_dist_attr = TensorDistributedAttribute() # tensor_dist_attr = TensorDistAttr()
# tensor_dist_attr.process_mesh = _g_process_mesh # tensor_dist_attr.process_mesh = _g_process_mesh
# if col: # if col:
# tensor_dist_attr.dims_mapping = [-1, -1, 0] # tensor_dist_attr.dims_mapping = [-1, -1, 0]
...@@ -271,7 +271,7 @@ def completion(train_program, start_program, dist_context): ...@@ -271,7 +271,7 @@ def completion(train_program, start_program, dist_context):
# elif op.type == "while": # elif op.type == "while":
# out_name = op.desc.output("StepScopes")[0] # out_name = op.desc.output("StepScopes")[0]
# out_var = block.vars[out_name] # out_var = block.vars[out_name]
# tensor_dist_attr = TensorDistributedAttribute() # tensor_dist_attr = TensorDistAttr()
# tensor_dist_attr.process_mesh = _g_process_mesh # tensor_dist_attr.process_mesh = _g_process_mesh
# tensor_dist_attr.dims_mapping = [-1] # tensor_dist_attr.dims_mapping = [-1]
# dist_context.set_tensor_dist_attr_for_program(out_var, # dist_context.set_tensor_dist_attr_for_program(out_var,
...@@ -280,7 +280,7 @@ def completion(train_program, start_program, dist_context): ...@@ -280,7 +280,7 @@ def completion(train_program, start_program, dist_context):
# # completion ops # # completion ops
# for block in blocks: # for block in blocks:
# for op in block.ops: # for op in block.ops:
# op_dist_attr = OperatorDistributedAttribute() # op_dist_attr = OperatorDistAttr()
# op_dist_attr.process_mesh = _g_process_mesh # op_dist_attr.process_mesh = _g_process_mesh
# if op.type == "create_by_read" or op.type == "create_double_buffer_reader": # if op.type == "create_by_read" or op.type == "create_double_buffer_reader":
# for in_name in op.input_arg_names: # for in_name in op.input_arg_names:
......
...@@ -21,9 +21,7 @@ from test_auto_parallel_reshard import mlp_forward ...@@ -21,9 +21,7 @@ from test_auto_parallel_reshard import mlp_forward
import paddle import paddle
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_attribute import ( from paddle.distributed.auto_parallel.dist_attribute import TensorDistAttr
TensorDistributedAttribute,
)
from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.dist_tensor import DistributedTensor from paddle.distributed.auto_parallel.dist_tensor import DistributedTensor
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
...@@ -219,7 +217,7 @@ class TestDistributedTensor(unittest.TestCase): ...@@ -219,7 +217,7 @@ class TestDistributedTensor(unittest.TestCase):
self.assertEqual(global_sizes, [6, 6]) self.assertEqual(global_sizes, [6, 6])
def test_instance_method(self): def test_instance_method(self):
tensor_dist_attr = TensorDistributedAttribute() tensor_dist_attr = TensorDistAttr()
tensor_dist_attr.dims_mapping = [1, 0] tensor_dist_attr.dims_mapping = [1, 0]
tensor_dist_attr.process_mesh = auto.ProcessMesh( tensor_dist_attr.process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2], [3, 4, 5]] mesh=[[0, 1, 2], [3, 4, 5]]
......
...@@ -20,9 +20,7 @@ import paddle.nn.functional as F ...@@ -20,9 +20,7 @@ import paddle.nn.functional as F
import paddle.static as static import paddle.static as static
import paddle.utils as utils import paddle.utils as utils
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.cluster import Cluster
from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.cost import CostEstimator
from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.partitioner import Partitioner
...@@ -202,21 +200,22 @@ class TestMLPReshard(unittest.TestCase): ...@@ -202,21 +200,22 @@ class TestMLPReshard(unittest.TestCase):
train_program, startup_program, dist_context, rank_id train_program, startup_program, dist_context, rank_id
) )
# test estimator # TODO: move to a new unittest for cost model
cluster = Cluster() # # test estimator
cluster.gen_default_config_cluster(device_count=8) # cluster = Cluster()
cost_estimator = CostEstimator(train_program, cluster) # cluster.gen_default_config_cluster(device_count=8)
global_cost = cost_estimator.estimate(dist_context) # cost_estimator = CostEstimator(train_program, cluster)
max_memory = cost_estimator._estimate_max_memory_by_dist_op( # global_cost = cost_estimator.estimate(dist_context)
dist_context # max_memory = cost_estimator._estimate_max_memory_by_dist_op(
) # dist_context
# test cache # )
global_cost = cost_estimator.estimate(dist_context) # # test cache
max_memory = cost_estimator._estimate_max_memory_by_dist_op( # global_cost = cost_estimator.estimate(dist_context)
dist_context # max_memory = cost_estimator._estimate_max_memory_by_dist_op(
) # dist_context
assert global_cost.time > 0 # )
assert max_memory > 0 # assert global_cost.time > 0
# assert max_memory > 0
resharder = Resharder( resharder = Resharder(
dist_main_prog, dist_main_prog,
...@@ -226,7 +225,6 @@ class TestMLPReshard(unittest.TestCase): ...@@ -226,7 +225,6 @@ class TestMLPReshard(unittest.TestCase):
dist_params_grads, dist_params_grads,
) )
resharder.reshard() resharder.reshard()
# print_program_with_dist_attr(dist_main_prog, dist_context)
# check send and recv result # check send and recv result
self.assertTrue(check_send_recv_result(dist_main_prog, rank_id)) self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
......
...@@ -20,8 +20,8 @@ import paddle.nn.functional as F ...@@ -20,8 +20,8 @@ import paddle.nn.functional as F
import paddle.static as static import paddle.static as static
import paddle.utils as utils import paddle.utils as utils
from paddle.distributed.auto_parallel.dist_attribute import ( from paddle.distributed.auto_parallel.dist_attribute import (
OperatorDistributedAttribute, OperatorDistAttr,
TensorDistributedAttribute, TensorDistAttr,
) )
from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.planner import PlanSpace from paddle.distributed.auto_parallel.planner import PlanSpace
...@@ -98,10 +98,10 @@ def set_default_dist_attr(program, dist_context, process_mesh): ...@@ -98,10 +98,10 @@ def set_default_dist_attr(program, dist_context, process_mesh):
ops = program.global_block().ops ops = program.global_block().ops
vars = program.global_block().vars vars = program.global_block().vars
for op in ops: for op in ops:
op_dist_attr = OperatorDistributedAttribute() op_dist_attr = OperatorDistAttr()
op_dist_attr.process_mesh = process_mesh op_dist_attr.process_mesh = process_mesh
for var_name in op.input_arg_names: for var_name in op.input_arg_names:
tensor_dist_attr = TensorDistributedAttribute() tensor_dist_attr = TensorDistAttr()
tensor_dist_attr.process_mesh = process_mesh tensor_dist_attr.process_mesh = process_mesh
tensor_dist_attr.dims_mapping = [-1 for i in vars[var_name].shape] tensor_dist_attr.dims_mapping = [-1 for i in vars[var_name].shape]
dist_context.set_tensor_dist_attr_for_program( dist_context.set_tensor_dist_attr_for_program(
...@@ -112,7 +112,7 @@ def set_default_dist_attr(program, dist_context, process_mesh): ...@@ -112,7 +112,7 @@ def set_default_dist_attr(program, dist_context, process_mesh):
) )
for var_name in op.output_arg_names: for var_name in op.output_arg_names:
tensor_dist_attr = TensorDistributedAttribute() tensor_dist_attr = TensorDistAttr()
tensor_dist_attr.process_mesh = process_mesh tensor_dist_attr.process_mesh = process_mesh
tensor_dist_attr.dims_mapping = [-1 for i in vars[var_name].shape] tensor_dist_attr.dims_mapping = [-1 for i in vars[var_name].shape]
dist_context.set_tensor_dist_attr_for_program( dist_context.set_tensor_dist_attr_for_program(
......
...@@ -19,9 +19,7 @@ import paddle.nn as nn ...@@ -19,9 +19,7 @@ import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
import paddle.static as static import paddle.static as static
import paddle.utils as utils import paddle.utils as utils
from paddle.distributed.auto_parallel.dist_attribute import ( from paddle.distributed.auto_parallel.dist_attribute import OperatorDistAttr
OperatorDistributedAttribute,
)
from paddle.distributed.auto_parallel.dist_op import DistributedOperator from paddle.distributed.auto_parallel.dist_op import DistributedOperator
from paddle.distributed.auto_parallel.operators.common import ( from paddle.distributed.auto_parallel.operators.common import (
get_distributed_operator_impl_container, get_distributed_operator_impl_container,
...@@ -115,7 +113,7 @@ class TestCompatible(unittest.TestCase): ...@@ -115,7 +113,7 @@ class TestCompatible(unittest.TestCase):
get_distributed_operator_impl_container(op.type) get_distributed_operator_impl_container(op.type)
) )
impls = dist_op_impl_container.impls impls = dist_op_impl_container.impls
op_dist_attr = OperatorDistributedAttribute() op_dist_attr = OperatorDistAttr()
op_dist_attr.set_input_dims_mapping( op_dist_attr.set_input_dims_mapping(
op.input_arg_names[0], [-1, -1, -1] op.input_arg_names[0], [-1, -1, -1]
) )
...@@ -213,7 +211,7 @@ class TestCompatible(unittest.TestCase): ...@@ -213,7 +211,7 @@ class TestCompatible(unittest.TestCase):
get_distributed_operator_impl_container(op.type) get_distributed_operator_impl_container(op.type)
) )
impls = dist_op_impl_container.impls impls = dist_op_impl_container.impls
op_dist_attr = OperatorDistributedAttribute() op_dist_attr = OperatorDistAttr()
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], [-1]) op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], [-1])
op_dist_attr.set_output_dims_mapping( op_dist_attr.set_output_dims_mapping(
op.output_arg_names[0], [-1, -1] op.output_arg_names[0], [-1, -1]
...@@ -307,7 +305,7 @@ class TestCompatible(unittest.TestCase): ...@@ -307,7 +305,7 @@ class TestCompatible(unittest.TestCase):
get_distributed_operator_impl_container(op.type) get_distributed_operator_impl_container(op.type)
) )
impls = dist_op_impl_container.impls impls = dist_op_impl_container.impls
op_dist_attr = OperatorDistributedAttribute() op_dist_attr = OperatorDistAttr()
op_dist_attr.set_input_dims_mapping( op_dist_attr.set_input_dims_mapping(
op.input_arg_names[0], [-1, -1] op.input_arg_names[0], [-1, -1]
) )
...@@ -369,7 +367,7 @@ class TestCompatible(unittest.TestCase): ...@@ -369,7 +367,7 @@ class TestCompatible(unittest.TestCase):
get_distributed_operator_impl_container(op.type) get_distributed_operator_impl_container(op.type)
) )
impls = dist_op_impl_container.impls impls = dist_op_impl_container.impls
op_dist_attr = OperatorDistributedAttribute() op_dist_attr = OperatorDistAttr()
op_dist_attr.set_input_dims_mapping( op_dist_attr.set_input_dims_mapping(
op.input_arg_names[0], [-1, -1] op.input_arg_names[0], [-1, -1]
) )
...@@ -404,7 +402,7 @@ class TestCompatible(unittest.TestCase): ...@@ -404,7 +402,7 @@ class TestCompatible(unittest.TestCase):
get_distributed_operator_impl_container(op.type) get_distributed_operator_impl_container(op.type)
) )
impls = dist_op_impl_container.impls impls = dist_op_impl_container.impls
op_dist_attr = OperatorDistributedAttribute() op_dist_attr = OperatorDistAttr()
op_dist_attr.set_input_dims_mapping( op_dist_attr.set_input_dims_mapping(
op.input_arg_names[0], [-1, -1] op.input_arg_names[0], [-1, -1]
) )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册