Commit c3ea2a6b authored by Yichen Zhang

Define the Python API and wrap function in static mode for DistTensorSpec

Parent 09d82a5d
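For orientation, a minimal sketch of the Python-side API this commit introduces. The import path comes from the dist_attribute change below; the default construction of TensorDistAttr is an illustrative assumption, not part of this diff.

```python
# Minimal sketch, assuming a default-constructed TensorDistAttr is acceptable here;
# only the DistTensorSpec calls below are introduced by this commit.
from paddle.fluid.core import DistTensorSpec, TensorDistAttr

dist_attr = TensorDistAttr()                # assumed: default dist attr is usable
spec = DistTensorSpec([8, 16], dist_attr)   # shape + TensorDistAttr constructor
spec.set_dims_mapping([-1, 0])              # forwarded to the underlying TensorDistAttr
print(spec.shape, spec.get_dims_mapping())
```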
@@ -5,4 +5,4 @@ cc_library(
   phi_enforce)
 add_subdirectory(test)
+add_subdirectory(spmd_rules)
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
-#include "paddle/fluid/distributed/auto_parallel/process_mesh.h"
+#include "paddle/phi/core/distributed/auto_parallel/utils.h"

 namespace paddle {
 namespace distributed {
@@ -27,28 +27,41 @@ DistTensorSpec::DistTensorSpec(const std::vector<int64_t>& shape,
   dist_attr_.copy_from(dist_attr);
 }

+DistTensorSpec::DistTensorSpec(const DistTensorSpec& spec) {
+  std::vector<int64_t> spec_shape = spec.get_shape();
+  shape_.assign(spec_shape.begin(), spec_shape.end());
+  dist_attr_.copy_from(spec.get_dist_attr());
+}
+
 DistTensorSpec::~DistTensorSpec() {}

 DistTensorSpec::DistTensorSpec(const Tensor& tensor) {
   shape_ = tensor.shape();
-  std::vector<int64_t> pm_shape, pm_ids;
-  pm_shape = {4};
-  pm_ids = {0, 1, 2, 3};
-  std::vector<std::string> dim_name = {"mp"};
-  ProcessMesh pm(pm_shape, pm_ids, dim_name);
-  std::vector<int64_t> dims_mapping = {-1, 0};
-  TensorDistAttr dist_attr;
-  dist_attr.set_process_mesh(pm);
-  dist_attr.set_dims_mapping(dims_mapping);
-  dist_attr_.copy_from(dist_attr);
-  std::cout << dist_attr_;
+  // std::vector<int64_t> pm_shape, pm_ids;
+  // pm_shape = {4};
+  // pm_ids = {0, 1, 2, 3};
+  // std::vector<std::string> dim_name = {"mp"};
+  // ProcessMesh pm(pm_shape, pm_ids, dim_name);
+  // std::vector<int64_t> dims_mapping = {-1, 0};
+  // TensorDistAttr dist_attr;
+  // dist_attr.set_process_mesh(pm);
+  // dist_attr.set_dims_mapping(dims_mapping);
+  // dist_attr_.copy_from(dist_attr);
+  // std::cout << dist_attr_;
 }

-const std::vector<int64_t>& DistTensorSpec::get_dims_mapping() {
+DistTensorSpec& DistTensorSpec::operator=(const DistTensorSpec& spec) {
+  std::vector<int64_t> spec_shape = spec.get_shape();
+  shape_ = spec_shape;
+  dist_attr_.copy_from(spec.get_dist_attr());
+  return *this;
+}
+
+const std::vector<int64_t>& DistTensorSpec::get_dims_mapping() const {
   return dist_attr_.dims_mapping();
 }
@@ -57,7 +70,7 @@ void DistTensorSpec::set_dims_mapping(
   dist_attr_.set_dims_mapping(dims_mapping);
 }

-const ProcessMesh& DistTensorSpec::get_process_mesh() {
+const ProcessMesh& DistTensorSpec::get_process_mesh() const {
   return dist_attr_.process_mesh();
 }
@@ -65,7 +78,22 @@ void DistTensorSpec::set_process_mesh(const ProcessMesh& process_mesh) {
   dist_attr_.set_process_mesh(process_mesh);
 }

-const std::vector<int64_t>& DistTensorSpec::get_shape() { return shape_; }
+const std::vector<int64_t>& DistTensorSpec::get_shape() const { return shape_; }
+
+const TensorDistAttr& DistTensorSpec::get_dist_attr() const {
+  return dist_attr_;
+}
+
+void DistTensorSpec::set_dist_attr(const TensorDistAttr& dist_attr) {
+  dist_attr_ = dist_attr;
+}
+
+std::string DistTensorSpec::to_string() const {
+  using phi::distributed::auto_parallel::str_join;
+  std::string spec_str = "{tensor_shape:[" + str_join(shape_) + "], ";
+  spec_str += "dist_attr:" + dist_attr_.to_string() + "}";
+  return spec_str;
+}

 }  // namespace auto_parallel
 }  // namespace distributed
......
@@ -14,39 +14,55 @@ limitations under the License. */

 #pragma once

-#include "paddle/fluid/distributed/auto_parallel/dist_attr.h"
 #include "paddle/phi/api/include/tensor.h"
+#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"

 namespace paddle {
 namespace distributed {
 namespace auto_parallel {

+using phi::distributed::auto_parallel::ProcessMesh;
+using phi::distributed::auto_parallel::TensorDistAttr;
+
 /**
  * A unified data class for inferring distributed attributes
  * in both dygraph mode and static mode
  */
 class DistTensorSpec {
  public:
+  DistTensorSpec() = default;
+
   DistTensorSpec(const std::vector<int64_t>& shape,
                  const TensorDistAttr& dist_attr);
+
+  DistTensorSpec(const DistTensorSpec& spec);
+
+  // temp function, only for test in dygraph mode
   explicit DistTensorSpec(const Tensor& tensor);

   ~DistTensorSpec();

+  DistTensorSpec& operator=(const DistTensorSpec& spec);
+
   // get dims_mapping from dist_attr_
-  const std::vector<int64_t>& get_dims_mapping();
+  const std::vector<int64_t>& get_dims_mapping() const;

   // set dims_mapping in dist_attr_
   void set_dims_mapping(const std::vector<int64_t>& dims_mapping);

   // get process_mesh from dist_attr_
-  const ProcessMesh& get_process_mesh();
+  const ProcessMesh& get_process_mesh() const;

   // set process_mesh in dist_attr_
   void set_process_mesh(const ProcessMesh& process_mesh);

-  const std::vector<int64_t>& get_shape();
+  const TensorDistAttr& get_dist_attr() const;
+
+  void set_dist_attr(const TensorDistAttr& dist_attr);
+
+  const std::vector<int64_t>& get_shape() const;
+
+  std::string to_string() const;

  private:
   std::vector<int64_t> shape_;
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include <pybind11/operators.h> #include <pybind11/operators.h>
#include <pybind11/stl.h> #include <pybind11/stl.h>
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/var_desc.h" #include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/pybind/auto_parallel_py.h" #include "paddle/fluid/pybind/auto_parallel_py.h"
...@@ -29,6 +30,7 @@ namespace py = pybind11; ...@@ -29,6 +30,7 @@ namespace py = pybind11;
namespace paddle { namespace paddle {
namespace pybind { namespace pybind {
using paddle::distributed::auto_parallel::DistTensorSpec;
using paddle::distributed::auto_parallel::OperatorDistAttr; using paddle::distributed::auto_parallel::OperatorDistAttr;
using paddle::framework::OpDesc; using paddle::framework::OpDesc;
using paddle::framework::VarDesc; using paddle::framework::VarDesc;
...@@ -276,6 +278,25 @@ void BindAutoParallel(py::module *m) { ...@@ -276,6 +278,25 @@ void BindAutoParallel(py::module *m) {
py::arg("memo")) py::arg("memo"))
.def("__str__", &TensorDistAttr::to_string); .def("__str__", &TensorDistAttr::to_string);
py::class_<DistTensorSpec>(*m, "DistTensorSpec")
.def(py::init<>())
.def(py::init<const DistTensorSpec &>())
.def(py::init<const std::vector<int64_t> &, const TensorDistAttr &>())
.def("get_dims_mapping", &DistTensorSpec::get_dims_mapping)
.def("set_dims_mapping", &DistTensorSpec::set_dims_mapping)
.def("get_process_mesh", &DistTensorSpec::get_process_mesh)
.def("set_process_mesh", &DistTensorSpec::set_process_mesh)
.def_property_readonly("shape", &DistTensorSpec::get_shape)
.def("__str__", &DistTensorSpec::to_string)
.def("__copy__",
[](const DistTensorSpec &self) { return DistTensorSpec(self); })
.def(
"__deepcopy__",
[](const DistTensorSpec &self, py::dict) {
return DistTensorSpec(self);
},
py::arg("memo"));
py::class_<OperatorDistAttr>(*m, "OperatorDistAttr") py::class_<OperatorDistAttr>(*m, "OperatorDistAttr")
.def(py::init<>()) .def(py::init<>())
.def(py::init<const OpDesc &>()) .def(py::init<const OpDesc &>())
......
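The bindings above make DistTensorSpec constructible, copyable, and printable from Python. A hedged sketch of exercising them follows; the ProcessMesh construction mirrors the commented-out dygraph test in dist_tensor_spec.cc, and the ProcessMesh/TensorDistAttr bindings and property-style setters are assumed to exist in paddle.fluid.core rather than being part of this diff.

```python
# Hedged sketch; ProcessMesh/TensorDistAttr bindings and their setters are assumptions.
import copy
from paddle.fluid.core import DistTensorSpec, ProcessMesh, TensorDistAttr

mesh = ProcessMesh([4], [0, 1, 2, 3], ["mp"])   # shape, process ids, dim names
dist_attr = TensorDistAttr()
dist_attr.process_mesh = mesh                   # assumed property-style setter
dist_attr.dims_mapping = [-1, 0]                # assumed property-style setter

spec = DistTensorSpec([8, 16], dist_attr)
spec2 = copy.deepcopy(spec)                     # uses the __deepcopy__ lambda bound above
spec2.set_dims_mapping([0, -1])                 # independent copy: spec is unchanged
print(str(spec))                                # "{tensor_shape:[8, 16], dist_attr:...}"
print(spec.get_process_mesh(), spec2.get_dims_mapping())
```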
@@ -1280,7 +1280,7 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
     def gen_dist_tensor_code(self):
         # define the DistTensorSpec vector for input and output tensors
-        api_code = " \nstd::vector<paddle::distributed::auto_parallel::DistTensorSpec> input_specs;\n"
+        api_code = " \n std::vector<paddle::distributed::auto_parallel::DistTensorSpec> input_specs;\n"

         # get DistTensorSpec for each input tensor
         for tensor_name in self.inputs['names']:

@@ -1297,8 +1297,8 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
 PADDLE_API {self.get_return_type(inplace_flag)} {api_func_name}({self.get_define_args(inplace_flag)}) {{
 {self.gene_kernel_select()}
 """
-        if api_func_name == 'matmul':
-            api_code += self.gen_dist_tensor_code()
+        # if api_func_name == 'matmul':
+        #     api_code += self.gen_dist_tensor_code()
         if len(self.kernel['func']) > 1:
             kernel_dispatch_code = ''
......
@@ -20,4 +20,5 @@ cc_library(
   SRCS dist_mapper.cc
   DEPS device_mesh auto_parallel_proto phi_enforce)

-cc_library(auto_parallel DEPS device_mesh process_mesh dist_attr dist_mapper)
+cc_library(auto_parallel DEPS device_mesh process_mesh dist_attr dist_mapper
+           dist_tensor_spec)
@@ -12,5 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License

+from paddle.fluid.core import DistTensorSpec  # noqa: F401
 from paddle.fluid.core import OperatorDistAttr  # noqa: F401
 from paddle.fluid.core import TensorDistAttr  # noqa: F401
@@ -105,6 +105,18 @@ def _update_dims_mapping_for_matmul(dist_op):
     changed = False
     op_desc = dist_op.serial_op.desc
     op_dist_attr = dist_op.dist_attr
+
+    # test DistTensorSpec
+    # input_name_list = []
+    # output_name_list = []
+    # input_name_list.append(op_desc.input('X')[0])
+    # input_name_list.append(op_desc.input('Y')[0])
+    # output_name_list.append(op_desc.output('Out')[0])
+    # attr_name_list = ['trans_x', 'trans_y']
+    # input_specs, output_specs, attrs = wrap_data_for_completion(
+    #     dist_op, input_name_list, output_name_list, attr_name_list
+    # )
+
     x_name = op_desc.input('X')[0]
     y_name = op_desc.input('Y')[0]
     out_name = op_desc.output('Out')[0]
......
@@ -26,7 +26,7 @@ from paddle.framework import core
 from paddle.framework.io_utils import is_belong_to_optimizer, is_parameter
 from paddle.static import Variable

-from .dist_attribute import OperatorDistAttr, TensorDistAttr
+from .dist_attribute import DistTensorSpec, OperatorDistAttr, TensorDistAttr
 from .process_group import get_all_process_groups
 from .process_mesh import ProcessMesh
@@ -2357,50 +2357,64 @@ def is_dep_skip_op(op):
     return False

-# def wrap_data_for_completion(
-#     dist_op: DistributedOperator,
-#     input_names: list,
-#     output_names: list,
-#     attr_names: list
-# ):
-#     """
-#     Get data used in inferring distributed attributes, including:
-#       1. DistTensorSpec for each input and output tensor of this dist_op.
-#       2. Operator attributes of this dist_op, e.g. transpose_x in matmul op.
-#
-#     Args:
-#       dist_op: the DistributedOperator
-#       input_names: list, name of the dist_op's input tensors
-#       output_names: list, name of the dist_op's output tensors
-#       attr_names: list, attribute name of the dist_op's corresponding serial op
-#
-#     Returns:
-#       input_specs: list, DistTensorSpec for each input tensor of the dist_op
-#       output_specs: list, DistTensorSpec for each output tensor of the dist_op
-#       attrs: dict, attribute map of the dist op
-#     """
-#
-#     input_specs = []
-#     output_specs = []
-#     attrs = {}
-#
-#     serial_op = dist_op.serial_op
-#
-#     # Construct each input tensor's DistTensorSpec with shape and dist_attr
-#     for name in input_names:
-#         tensor_dist_attr = dist_op.dist_attr.get_input_dist_attr(name)
-#         var = serial_op.block._var_recursive(name)
-#         tensor_shape = var.shape
-#         dist_spec = DistTensorSpec(tensor_shape, tensor_dist_attr)
-#         input_specs.append(dist_spec)
-#
-#     # Construct each output tensor's DistTensorSpec with shape and dist_attr
-#     for name in output_names:
-#         tensor_dist_attr = dist_op.dist_attr.get_output_dist_attr(name)
-#         var = serial_op.block._var_recursive(name)
-#         tensor_shape = var.shape
-#         dist_spec = DistTensorSpec(tensor_shape, tensor_dist_attr)
-#         output_specs.append(dist_spec)
-#
-#     for attr_name in attr_names:
-#         attrs[attr_name] = serial_op.desc.attr(attr_name)
+def wrap_data_for_completion(
+    dist_op, input_names: list, output_names: list, attr_names: list
+):
+    """
+    Get data used in inferring distributed attributes, including:
+      1. DistTensorSpec for each input and output tensor of this dist_op.
+      2. Operator attributes of this dist_op, e.g. transpose_x in matmul op.
+
+    Args:
+      dist_op: the DistributedOperator
+      input_names: list, name of the dist_op's input tensors
+      output_names: list, name of the dist_op's output tensors
+      attr_names: list, attribute name of the dist_op's corresponding serial op
+
+    Returns:
+      input_specs: list, DistTensorSpec for each input tensor of the dist_op
+      output_specs: list, DistTensorSpec for each output tensor of the dist_op
+      attrs: dict, attribute map of the dist op
+
+    Usage:
+      op_desc = dist_op.serial_op.desc
+      input_name_list = []
+      output_name_list = []
+      input_name_list.append(op_desc.input('X')[0])  # 'X' is the arg name for op
+      input_name_list.append(op_desc.input('Y')[0])
+      output_name_list.append(op_desc.output('Out')[0])
+      attr_name_list = ['trans_x', 'trans_y']
+      input_specs, output_specs, attrs = wrap_data_for_completion(
+          dist_op,
+          input_name_list,
+          output_name_list,
+          attr_name_list)
+
+    """
+
+    input_specs = []
+    output_specs = []
+    attrs = {}
+
+    serial_op = dist_op.serial_op
+
+    # Construct each input tensor's DistTensorSpec with shape and dist_attr
+    for name in input_names:
+        tensor_dist_attr = dist_op.dist_attr.get_input_dist_attr(name)
+        var = serial_op.block._var_recursive(name)
+        tensor_shape = var.shape
+        dist_spec = DistTensorSpec(tensor_shape, tensor_dist_attr)
+        input_specs.append(dist_spec)
+
+    # Construct each output tensor's DistTensorSpec with shape and dist_attr
+    for name in output_names:
+        tensor_dist_attr = dist_op.dist_attr.get_output_dist_attr(name)
+        var = serial_op.block._var_recursive(name)
+        tensor_shape = var.shape
+        dist_spec = DistTensorSpec(tensor_shape, tensor_dist_attr)
+        output_specs.append(dist_spec)
+
+    for attr_name in attr_names:
+        attrs[attr_name] = serial_op.desc.attr(attr_name)
+
+    return input_specs, output_specs, attrs