Commit 42a7b771 authored by Yichen Zhang

test wrap DistTensorSpec in dygraph mode

Parent c92992d5
@@ -20,4 +20,7 @@ cc_library(
   SRCS dist_mapper.cc
   DEPS device_mesh auto_parallel_proto phi_enforce)
-cc_library(auto_parallel DEPS device_mesh process_mesh dist_attr dist_mapper)
+add_subdirectory(spmd_rules)
+
+cc_library(auto_parallel DEPS device_mesh process_mesh dist_attr dist_mapper
+                         dist_tensor_spec)

(new CMakeLists.txt added under the spmd_rules/ subdirectory)
+cc_library(
+  dist_tensor_spec
+  SRCS dist_tensor_spec.cc
+  DEPS dist_attr)
@@ -12,9 +12,8 @@
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#pragma once
 #include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
+#include "paddle/fluid/distributed/auto_parallel/process_mesh.h"

 namespace paddle {
 namespace distributed {
@@ -30,6 +29,25 @@ DistTensorSpec::DistTensorSpec(const std::vector<int64_t>& shape,
 DistTensorSpec::~DistTensorSpec() {}

+DistTensorSpec::DistTensorSpec(const Tensor& tensor) {
+  shape_ = tensor.shape();
+
+  std::vector<int64_t> pm_shape, pm_ids;
+  pm_shape = {4};
+  pm_ids = {0, 1, 2, 3};
+  std::vector<std::string> dim_name = {"mp"};
+  ProcessMesh pm(pm_shape, pm_ids, dim_name);
+
+  std::vector<int64_t> dims_mapping = {-1, 0};
+
+  TensorDistAttr dist_attr;
+  dist_attr.set_process_mesh(pm);
+  dist_attr.set_dims_mapping(dims_mapping);
+
+  dist_attr_.copy_from(dist_attr);
+  std::cout << dist_attr_;
+}
+
 const std::vector<int64_t>& DistTensorSpec::get_dims_mapping() {
   return dist_attr_.dims_mapping();
 }
......
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once

 #include "paddle/fluid/distributed/auto_parallel/dist_attr.h"
+#include "paddle/phi/api/include/tensor.h"

 namespace paddle {
 namespace distributed {
@@ -29,6 +30,8 @@ class DistTensorSpec {
   DistTensorSpec(const std::vector<int64_t>& shape,
                  const TensorDistAttr& dist_attr);

+  explicit DistTensorSpec(const Tensor& tensor);
+
   ~DistTensorSpec();

   // get dims_mapping from dist_attr_
......
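Below is a minimal usage sketch of the `explicit DistTensorSpec(const Tensor& tensor)` constructor declared above; it is not part of the commit. The `inspect_dist_spec` helper and the use of `paddle::experimental::Tensor` as the concrete tensor type are assumptions for illustration. As the .cc diff shows, the constructor currently stores the tensor's shape together with a hard-coded placeholder dist_attr: a 1-D "mp" mesh of four ranks and dims_mapping {-1, 0}, i.e. the first tensor dimension replicated and the second sharded along the mesh.

// Sketch only: wrap a dygraph tensor and read back its dims_mapping.
#include <iostream>

#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
#include "paddle/phi/api/include/tensor.h"

using paddle::distributed::auto_parallel::DistTensorSpec;

void inspect_dist_spec(const paddle::experimental::Tensor& t) {
  // shape_ is taken from t.shape(); dist_attr_ is the hard-coded placeholder
  // attribute built in the constructor shown in the diff above.
  DistTensorSpec spec(t);

  // dims_mapping[i] == -1 means tensor dim i is replicated; a value >= 0 is
  // the process-mesh dimension that shards tensor dim i.
  for (int64_t d : spec.get_dims_mapping()) {
    std::cout << d << " ";
  }
  std::cout << std::endl;
}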
@@ -395,7 +395,8 @@ cc_library(
     phi_data_transform
     api_custom_impl
     api_tensor_utils
-    phi_profiler)
+    phi_profiler
+    dist_tensor_spec)

 cc_library(
   phi_bw_function_api
   SRCS ${bw_api_source_file} ${fused_bw_api_source_file}
......
@@ -1278,6 +1278,17 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
 }}
 """

+    def gen_dist_tensor_code(self):
+        # define the DistTensorSpec vector for input and output tensors
+        api_code = " \nstd::vector<paddle::distributed::auto_parallel::DistTensorSpec> input_specs;\n"
+
+        # get DistTensorSpec for each input tensor
+        for tensor_name in self.inputs['names']:
+            api_code += f" input_specs.emplace_back(paddle::distributed::auto_parallel::DistTensorSpec({tensor_name}));\n"
+        api_code += "\n"
+
+        return api_code
+
     def gene_base_api_code(self, inplace_flag=False):
         api_func_name = self.get_api_func_name()
         if inplace_flag and api_func_name[-1] != '_':
@@ -1286,6 +1297,8 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
 PADDLE_API {self.get_return_type(inplace_flag)} {api_func_name}({self.get_define_args(inplace_flag)}) {{
 {self.gene_kernel_select()}
 """

+        if api_func_name == 'matmul':
+            api_code += self.gen_dist_tensor_code()
         if len(self.kernel['func']) > 1:
             kernel_dispatch_code = ''
......
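To make the generator hook above concrete: for matmul, whose dense inputs are named x and y in the API yaml (an assumption here, not taken from this diff), gen_dist_tensor_code() would splice roughly the following C++ into the generated PADDLE_API matmul body, right after kernel selection:

// Approximate shape of the generated snippet; only matmul is wired up so far.
std::vector<paddle::distributed::auto_parallel::DistTensorSpec> input_specs;
input_specs.emplace_back(paddle::distributed::auto_parallel::DistTensorSpec(x));
input_specs.emplace_back(paddle::distributed::auto_parallel::DistTensorSpec(y));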
@@ -379,6 +379,8 @@ def source_include(header_file_path):
 #include "paddle/phi/api/profiler/event_tracing.h"
 #include "paddle/phi/api/profiler/supplement_tracing.h"

+#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
+
 DECLARE_bool(conv2d_disable_cudnn);
 DECLARE_int32(low_precision_op_list);
 """
......
@@ -2355,3 +2355,52 @@ def is_dep_skip_op(op):
             return True

     return False
+
+
+# def wrap_data_for_completion(
+#     dist_op: DistributedOperator,
+#     input_names: list,
+#     output_names: list,
+#     attr_names: list
+# ):
+#     """
+#     Get data used in inferring distributed attributes, including:
+#       1. DistTensorSpec for each input and output tensor of this dist_op.
+#       2. Operator attributes of this dist_op, e.g. transpose_x in matmul op.
+#
+#     Args:
+#         dist_op: the DistributedOperator
+#         input_names: list, name of the dist_op's input tensors
+#         output_names: list, name of the dist_op's output tensors
+#         attr_names: list, attribute name of the dist_op's corresponding serial op
+#
+#     Returns:
+#         input_specs: list, DistTensorSpec for each input tensor of the dist_op
+#         output_specs: list, DistTensorSpec for each output tensor of the dist_op
+#         attrs: dict, attribute map of the dist op
+#     """
+#
+#     input_specs = []
+#     output_specs = []
+#     attrs = {}
+#
+#     serial_op = dist_op.serial_op
+#
+#     # Construct each input tensor's DistTensorSpec with shape and dist_attr
+#     for name in input_names:
+#         tensor_dist_attr = dist_op.dist_attr.get_input_dist_attr(name)
+#         var = serial_op.block._var_recursive(name)
+#         tensor_shape = var.shape
+#         dist_spec = DistTensorSpec(tensor_shape, tensor_dist_attr)
+#         input_specs.append(dist_spec)
+#
+#     # Construct each output tensor's DistTensorSpec with shape and dist_attr
+#     for name in output_names:
+#         tensor_dist_attr = dist_op.dist_attr.get_output_dist_attr(name)
+#         var = serial_op.block._var_recursive(name)
+#         tensor_shape = var.shape
+#         dist_spec = DistTensorSpec(tensor_shape, tensor_dist_attr)
+#         output_specs.append(dist_spec)
+#
+#     for attr_name in attr_names:
+#         attrs[attr_name] = serial_op.desc.attr(attr_name)