Commit c3ea2a6b authored by Yichen Zhang

Define Python API and wrap function in static mode for DistTensorSpec

Parent 09d82a5d
......@@ -5,4 +5,4 @@ cc_library(
phi_enforce)
add_subdirectory(test)
add_subdirectory(spmd_rules)
......@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
#include "paddle/fluid/distributed/auto_parallel/process_mesh.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
namespace paddle {
namespace distributed {
......@@ -27,28 +27,41 @@ DistTensorSpec::DistTensorSpec(const std::vector<int64_t>& shape,
dist_attr_.copy_from(dist_attr);
}
DistTensorSpec::DistTensorSpec(const DistTensorSpec& spec) {
std::vector<int64_t> spec_shape = spec.get_shape();
shape_.assign(spec_shape.begin(), spec_shape.end());
dist_attr_.copy_from(spec.get_dist_attr());
}
DistTensorSpec::~DistTensorSpec() {}
DistTensorSpec::DistTensorSpec(const Tensor& tensor) {
shape_ = tensor.shape();
std::vector<int64_t> pm_shape, pm_ids;
pm_shape = {4};
pm_ids = {0, 1, 2, 3};
std::vector<std::string> dim_name = {"mp"};
// std::vector<int64_t> pm_shape, pm_ids;
// pm_shape = {4};
// pm_ids = {0, 1, 2, 3};
// std::vector<std::string> dim_name = {"mp"};
ProcessMesh pm(pm_shape, pm_ids, dim_name);
std::vector<int64_t> dims_mapping = {-1, 0};
TensorDistAttr dist_attr;
dist_attr.set_process_mesh(pm);
dist_attr.set_dims_mapping(dims_mapping);
// ProcessMesh pm(pm_shape, pm_ids, dim_name);
// std::vector<int64_t> dims_mapping = {-1, 0};
// TensorDistAttr dist_attr;
// dist_attr.set_process_mesh(pm);
// dist_attr.set_dims_mapping(dims_mapping);
dist_attr_.copy_from(dist_attr);
// dist_attr_.copy_from(dist_attr);
std::cout << dist_attr_;
// std::cout << dist_attr_;
}
const std::vector<int64_t>& DistTensorSpec::get_dims_mapping() {
DistTensorSpec& DistTensorSpec::operator=(const DistTensorSpec& spec) {
std::vector<int64_t> spec_shape = spec.get_shape();
shape_ = spec_shape;
dist_attr_.copy_from(spec.get_dist_attr());
return *this;
}
const std::vector<int64_t>& DistTensorSpec::get_dims_mapping() const {
return dist_attr_.dims_mapping();
}
......@@ -57,7 +70,7 @@ void DistTensorSpec::set_dims_mapping(
dist_attr_.set_dims_mapping(dims_mapping);
}
const ProcessMesh& DistTensorSpec::get_process_mesh() {
const ProcessMesh& DistTensorSpec::get_process_mesh() const {
return dist_attr_.process_mesh();
}
......@@ -65,7 +78,22 @@ void DistTensorSpec::set_process_mesh(const ProcessMesh& process_mesh) {
dist_attr_.set_process_mesh(process_mesh);
}
const std::vector<int64_t>& DistTensorSpec::get_shape() { return shape_; }
const std::vector<int64_t>& DistTensorSpec::get_shape() const { return shape_; }
const TensorDistAttr& DistTensorSpec::get_dist_attr() const {
return dist_attr_;
}
void DistTensorSpec::set_dist_attr(const TensorDistAttr& dist_attr) {
dist_attr_ = dist_attr;
}
std::string DistTensorSpec::to_string() const {
using phi::distributed::auto_parallel::str_join;
std::string spec_str = "{tensor_shape:[" + str_join(shape_) + "], ";
spec_str += "dist_attr:" + dist_attr_.to_string() + "}";
return spec_str;
}
} // namespace auto_parallel
} // namespace distributed
......
......@@ -14,39 +14,55 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
namespace paddle {
namespace distributed {
namespace auto_parallel {
using phi::distributed::auto_parallel::ProcessMesh;
using phi::distributed::auto_parallel::TensorDistAttr;
/**
* A unified data class for inferring distributed attributes
* in both dygraph mode and static mode
*/
class DistTensorSpec {
public:
DistTensorSpec() = default;
DistTensorSpec(const std::vector<int64_t>& shape,
const TensorDistAttr& dist_attr);
DistTensorSpec(const DistTensorSpec& spec);
// temp function, only for test in dygraph mode
explicit DistTensorSpec(const Tensor& tensor);
~DistTensorSpec();
DistTensorSpec& operator=(const DistTensorSpec& spec);
// get dims_mapping from dist_attr_
const std::vector<int64_t>& get_dims_mapping();
const std::vector<int64_t>& get_dims_mapping() const;
// set dims_mapping in dist_attr_
void set_dims_mapping(const std::vector<int64_t>& dims_mapping);
// get process_mesh from dist_attr_
const ProcessMesh& get_process_mesh();
const ProcessMesh& get_process_mesh() const;
// set process_mesh in dist_attr_
void set_process_mesh(const ProcessMesh& process_mesh);
const std::vector<int64_t>& get_shape();
const TensorDistAttr& get_dist_attr() const;
void set_dist_attr(const TensorDistAttr& dist_attr);
const std::vector<int64_t>& get_shape() const;
std::string to_string() const;
private:
std::vector<int64_t> shape_;
......
......@@ -15,6 +15,7 @@
#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/pybind/auto_parallel_py.h"
......@@ -29,6 +30,7 @@ namespace py = pybind11;
namespace paddle {
namespace pybind {
using paddle::distributed::auto_parallel::DistTensorSpec;
using paddle::distributed::auto_parallel::OperatorDistAttr;
using paddle::framework::OpDesc;
using paddle::framework::VarDesc;
......@@ -276,6 +278,25 @@ void BindAutoParallel(py::module *m) {
py::arg("memo"))
.def("__str__", &TensorDistAttr::to_string);
py::class_<DistTensorSpec>(*m, "DistTensorSpec")
.def(py::init<>())
.def(py::init<const DistTensorSpec &>())
.def(py::init<const std::vector<int64_t> &, const TensorDistAttr &>())
.def("get_dims_mapping", &DistTensorSpec::get_dims_mapping)
.def("set_dims_mapping", &DistTensorSpec::set_dims_mapping)
.def("get_process_mesh", &DistTensorSpec::get_process_mesh)
.def("set_process_mesh", &DistTensorSpec::set_process_mesh)
.def_property_readonly("shape", &DistTensorSpec::get_shape)
.def("__str__", &DistTensorSpec::to_string)
.def("__copy__",
[](const DistTensorSpec &self) { return DistTensorSpec(self); })
.def(
"__deepcopy__",
[](const DistTensorSpec &self, py::dict) {
return DistTensorSpec(self);
},
py::arg("memo"));
py::class_<OperatorDistAttr>(*m, "OperatorDistAttr")
.def(py::init<>())
.def(py::init<const OpDesc &>())
......
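To illustrate the bindings registered above, here is a minimal Python sketch. It assumes a Paddle build that contains this commit; the ProcessMesh/TensorDistAttr constructors and properties follow the existing core bindings, and the shapes, mesh, and dims_mapping values are made up for illustration.

```python
from paddle.fluid import core

# Hypothetical 2x2 mesh and a dist_attr that shards tensor dim 1 over mesh axis 0.
mesh = core.ProcessMesh([2, 2], [0, 1, 2, 3], ["x", "y"])
dist_attr = core.TensorDistAttr()
dist_attr.process_mesh = mesh
dist_attr.dims_mapping = [-1, 0]

# Construct a DistTensorSpec from a tensor shape plus the dist_attr,
# then exercise the accessors exposed by the new bindings.
spec = core.DistTensorSpec([8, 16], dist_attr)
print(spec.shape)               # read-only property -> [8, 16]
print(spec.get_dims_mapping())  # -> [-1, 0]
spec.set_dims_mapping([0, -1])
print(spec)                     # __str__ delegates to DistTensorSpec::to_string()
```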
......@@ -1280,7 +1280,7 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
def gen_dist_tensor_code(self):
# define the DistTensorSpec vector for input and output tensors
api_code = " \nstd::vector<paddle::distributed::auto_parallel::DistTensorSpec> input_specs;\n"
api_code = " \n std::vector<paddle::distributed::auto_parallel::DistTensorSpec> input_specs;\n"
# get DistTensorSpec for each input tensor
for tensor_name in self.inputs['names']:
......@@ -1297,8 +1297,8 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
PADDLE_API {self.get_return_type(inplace_flag)} {api_func_name}({self.get_define_args(inplace_flag)}) {{
{self.gene_kernel_select()}
"""
if api_func_name == 'matmul':
api_code += self.gen_dist_tensor_code()
# if api_func_name == 'matmul':
# api_code += self.gen_dist_tensor_code()
if len(self.kernel['func']) > 1:
kernel_dispatch_code = ''
......
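The generator hook above (currently disabled for matmul) builds a C++ snippet that declares a DistTensorSpec per input of the generated API. Below is a simplified, hypothetical Python sketch of that string-building pattern; only the vector declaration is taken from the diff, and the per-input emplace line is illustrative only.

```python
def gen_dist_tensor_code(input_names):
    # Mirrors the vector declaration emitted by gen_dist_tensor_code in the diff.
    ns = "paddle::distributed::auto_parallel"
    code = f" \n  std::vector<{ns}::DistTensorSpec> input_specs;\n"
    for name in input_names:
        # Hypothetical: wrap each input Tensor via the temporary
        # DistTensorSpec(const Tensor&) constructor declared in the header above.
        code += f"  input_specs.emplace_back({ns}::DistTensorSpec({name}));\n"
    return code

print(gen_dist_tensor_code(["x", "y"]))
```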
......@@ -20,4 +20,5 @@ cc_library(
SRCS dist_mapper.cc
DEPS device_mesh auto_parallel_proto phi_enforce)
cc_library(auto_parallel DEPS device_mesh process_mesh dist_attr dist_mapper)
cc_library(auto_parallel DEPS device_mesh process_mesh dist_attr dist_mapper
dist_tensor_spec)
......@@ -12,5 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License
from paddle.fluid.core import DistTensorSpec # noqa: F401
from paddle.fluid.core import OperatorDistAttr # noqa: F401
from paddle.fluid.core import TensorDistAttr # noqa: F401
......@@ -105,6 +105,18 @@ def _update_dims_mapping_for_matmul(dist_op):
changed = False
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
# test DistTensorSpec
# input_name_list = []
# output_name_list = []
# input_name_list.append(op_desc.input('X')[0])
# input_name_list.append(op_desc.input('Y')[0])
# output_name_list.append(op_desc.output('Out')[0])
# attr_name_list = ['trans_x', 'trans_y']
# input_specs, output_specs, attrs = wrap_data_for_completion(
# dist_op, input_name_list, output_name_list, attr_name_list
# )
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
out_name = op_desc.output('Out')[0]
......
......@@ -26,7 +26,7 @@ from paddle.framework import core
from paddle.framework.io_utils import is_belong_to_optimizer, is_parameter
from paddle.static import Variable
from .dist_attribute import OperatorDistAttr, TensorDistAttr
from .dist_attribute import DistTensorSpec, OperatorDistAttr, TensorDistAttr
from .process_group import get_all_process_groups
from .process_mesh import ProcessMesh
......@@ -2357,50 +2357,64 @@ def is_dep_skip_op(op):
return False
# def wrap_data_for_completion(
# dist_op: DistributedOperator,
# input_names: list,
# output_names: list,
# attr_names: list
# ):
# """
# Get data used in inferring distributed attributes, including:
# 1. DistTensorSpec for each input and output tensor of this dist_op.
# 2. Operator attributes of this dist_op, e.g. transpose_x in matmul op.
#
# Args:
# dist_op: the DistributedOperator
# input_names: list, name of the dist_op's input tensors
# output_names: list, name of the dist_op's output tensors
# attr_names: list, attribute name of the dist_op's corresponding serial op
#
# Returns:
# input_specs: list, DistTensorSpec for each input tensor of the dist_op
# output_specs: list, DistTensorSpec for each output tensor of the dist_op
# attrs: dict, attribute map of the dist op
# """
#
# input_specs = []
# output_specs = []
# attrs = {}
#
# serial_op = dist_op.serial_op
#
# # Construct each input tensor's DistTensorSpec with shape and dist_attr
# for name in input_names:
# tensor_dist_attr = dist_op.dist_attr.get_input_dist_attr(name)
# var = serial_op.block._var_recursive(name)
# tensor_shape = var.shape
# dist_spec = DistTensorSpec(tensor_shape, tensor_dist_attr)
# input_specs.append(dist_spec)
#
# # Construct each output tensor's DistTensorSpec with shape and dist_attr
# for name in output_names:
# tensor_dist_attr = dist_op.dist_attr.get_output_dist_attr(name)
# var = serial_op.block._var_recursive(name)
# tensor_shape = var.shape
# dist_spec = DistTensorSpec(tensor_shape, tensor_dist_attr)
# output_specs.append(dist_spec)
#
# for attr_name in attr_names:
# attrs[attr_name] = serial_op.desc.attr(attr_name)
def wrap_data_for_completion(
dist_op, input_names: list, output_names: list, attr_names: list
):
"""
Get data used in inferring distributed attributes, including:
1. DistTensorSpec for each input and output tensor of this dist_op.
2. Operator attributes of this dist_op, e.g. transpose_x in matmul op.
Args:
dist_op: the DistributedOperator
input_names: list, names of the dist_op's input tensors
output_names: list, names of the dist_op's output tensors
attr_names: list, attribute names of the dist_op's corresponding serial op
Returns:
input_specs: list, DistTensorSpec for each input tensor of the dist_op
output_specs: list, DistTensorSpec for each output tensor of the dist_op
attrs: dict, attribute map of the dist op
Usage:
op_desc = dist_op.serial_op.desc
input_name_list = []
output_name_list = []
input_name_list.append(op_desc.input('X')[0]) # 'X' is the arg name for op
input_name_list.append(op_desc.input('Y')[0])
output_name_list.append(op_desc.output('Out')[0])
attr_name_list = ['trans_x', 'trans_y']
input_specs, output_specs, attrs = wrap_data_for_completion(
dist_op,
input_name_list,
output_name_list,
attr_name_list)
"""
input_specs = []
output_specs = []
attrs = {}
serial_op = dist_op.serial_op
# Construct each input tensor's DistTensorSpec with shape and dist_attr
for name in input_names:
tensor_dist_attr = dist_op.dist_attr.get_input_dist_attr(name)
var = serial_op.block._var_recursive(name)
tensor_shape = var.shape
dist_spec = DistTensorSpec(tensor_shape, tensor_dist_attr)
input_specs.append(dist_spec)
# Construct each output tensor's DistTensorSpec with shape and dist_attr
for name in output_names:
tensor_dist_attr = dist_op.dist_attr.get_output_dist_attr(name)
var = serial_op.block._var_recursive(name)
tensor_shape = var.shape
dist_spec = DistTensorSpec(tensor_shape, tensor_dist_attr)
output_specs.append(dist_spec)
for attr_name in attr_names:
attrs[attr_name] = serial_op.desc.attr(attr_name)
return input_specs, output_specs, attrs
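The docstring's Usage block shows how to build the name lists; the sketch below shows consuming the returned values inside a completion rule, following the commented-out test block in _update_dims_mapping_for_matmul above. It assumes dist_op wraps a matmul_v2-style serial op with trans_x/trans_y attributes; nothing here is wired into the rule in this commit.

```python
def collect_matmul_specs(dist_op):
    op_desc = dist_op.serial_op.desc
    input_names = [op_desc.input('X')[0], op_desc.input('Y')[0]]
    output_names = [op_desc.output('Out')[0]]
    attr_names = ['trans_x', 'trans_y']

    input_specs, output_specs, attrs = wrap_data_for_completion(
        dist_op, input_names, output_names, attr_names
    )

    # Each spec carries the tensor shape and dims_mapping an SPMD rule needs.
    x_spec, y_spec = input_specs
    print(x_spec.shape, x_spec.get_dims_mapping(), attrs['trans_x'])
    return input_specs, output_specs, attrs
```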