Commit 09d82a5d authored by Yichen Zhang

test wrap DistTensorSpec in dygraph mode

Parent 198bc1f5
......@@ -5,3 +5,4 @@ cc_library(
phi_enforce)
add_subdirectory(test)
cc_library(
  dist_tensor_spec
  SRCS dist_tensor_spec.cc
  DEPS dist_attr)
......@@ -12,9 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
#include "paddle/fluid/distributed/auto_parallel/process_mesh.h"
namespace paddle {
namespace distributed {
......@@ -30,6 +29,25 @@ DistTensorSpec::DistTensorSpec(const std::vector<int64_t>& shape,
DistTensorSpec::~DistTensorSpec() {}
DistTensorSpec::DistTensorSpec(const Tensor& tensor) {
  shape_ = tensor.shape();
  // Hardcode a 1-D process mesh ("mp", 4 ranks) and a dims_mapping of {-1, 0}
  // purely to test wrapping a dygraph Tensor; only the shape is taken from
  // the tensor itself.
  std::vector<int64_t> pm_shape, pm_ids;
  pm_shape = {4};
  pm_ids = {0, 1, 2, 3};
  std::vector<std::string> dim_name = {"mp"};
  ProcessMesh pm(pm_shape, pm_ids, dim_name);
  std::vector<int64_t> dims_mapping = {-1, 0};
  TensorDistAttr dist_attr;
  dist_attr.set_process_mesh(pm);
  dist_attr.set_dims_mapping(dims_mapping);
  dist_attr_.copy_from(dist_attr);
  std::cout << dist_attr_;  // debug print of the wrapped dist_attr
}
const std::vector<int64_t>& DistTensorSpec::get_dims_mapping() {
  return dist_attr_.dims_mapping();
}
......
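For context, a minimal usage sketch of the Tensor-based constructor added above. This is not part of the diff: it assumes an already-constructed paddle::Tensor `t` and uses the fully qualified namespace that the API generator below refers to.

#include <cstdint>
#include <iostream>
#include <vector>

#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
#include "paddle/phi/api/include/tensor.h"

using paddle::distributed::auto_parallel::DistTensorSpec;

// Wrap a dygraph Tensor and print its dims_mapping. With this test commit the
// dist_attr is hardcoded to a 1-D "mp" mesh of 4 ranks and dims_mapping {-1, 0}.
void InspectSpec(const paddle::Tensor& t) {
  DistTensorSpec spec(t);
  for (int64_t d : spec.get_dims_mapping()) {
    std::cout << d << " ";  // expected for this commit: -1 0
  }
  std::cout << std::endl;
}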
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/api/include/tensor.h"
namespace paddle {
namespace distributed {
......@@ -29,6 +30,8 @@ class DistTensorSpec {
  DistTensorSpec(const std::vector<int64_t>& shape,
                 const TensorDistAttr& dist_attr);
  explicit DistTensorSpec(const Tensor& tensor);
  ~DistTensorSpec();
  // get dims_mapping from dist_attr_
......
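The header also keeps the shape-plus-dist_attr constructor, which is the path the commented-out completion helper at the bottom of this commit uses from Python. A rough C++ sketch of that path follows; the shape and mesh values are illustrative only, mirroring the test constructor above, and the namespace follows the fully qualified name used by the code generator.

#include <vector>

#include "paddle/fluid/distributed/auto_parallel/dist_attr.h"
#include "paddle/fluid/distributed/auto_parallel/process_mesh.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"

using paddle::distributed::auto_parallel::DistTensorSpec;
using paddle::distributed::auto_parallel::ProcessMesh;
using paddle::distributed::auto_parallel::TensorDistAttr;

DistTensorSpec MakeSampleSpec() {
  // 1-D mesh of 4 ranks named "mp", same values as the hardcoded test attr.
  ProcessMesh mesh({4}, {0, 1, 2, 3}, {"mp"});

  TensorDistAttr dist_attr;
  dist_attr.set_process_mesh(mesh);
  // First dim replicated (-1), second dim sharded over mesh dim 0.
  dist_attr.set_dims_mapping({-1, 0});

  // shape + dist_attr constructor declared in dist_tensor_spec.h;
  // the {8, 16} shape is purely illustrative.
  return DistTensorSpec({8, 16}, dist_attr);
}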
......@@ -395,7 +395,8 @@ cc_library(
phi_data_transform
api_custom_impl
api_tensor_utils
phi_profiler)
phi_profiler
dist_tensor_spec)
cc_library(
phi_bw_function_api
SRCS ${bw_api_source_file} ${fused_bw_api_source_file}
......
......@@ -1278,6 +1278,17 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
}}
"""
    def gen_dist_tensor_code(self):
        # define the DistTensorSpec vector for input and output tensors
        api_code = " \nstd::vector<paddle::distributed::auto_parallel::DistTensorSpec> input_specs;\n"

        # get DistTensorSpec for each input tensor
        for tensor_name in self.inputs['names']:
            api_code += f" input_specs.emplace_back(paddle::distributed::auto_parallel::DistTensorSpec({tensor_name}));\n"
        api_code += "\n"

        return api_code
    def gene_base_api_code(self, inplace_flag=False):
        api_func_name = self.get_api_func_name()
        if inplace_flag and api_func_name[-1] != '_':
......@@ -1286,6 +1297,8 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
PADDLE_API {self.get_return_type(inplace_flag)} {api_func_name}({self.get_define_args(inplace_flag)}) {{
{self.gene_kernel_select()}
"""
        if api_func_name == 'matmul':
            api_code += self.gen_dist_tensor_code()

        if len(self.kernel['func']) > 1:
            kernel_dispatch_code = ''
......
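To make the generator hook concrete, this is roughly the fragment that gen_dist_tensor_code splices into the generated matmul API body, right after the code emitted by gene_kernel_select(). The input names x and y are assumed from matmul's usual signature and are not spelled out in this diff.

// Emitted into the generated matmul(...) body for this test commit:
// one DistTensorSpec per input Tensor, collected into input_specs.
std::vector<paddle::distributed::auto_parallel::DistTensorSpec> input_specs;
input_specs.emplace_back(paddle::distributed::auto_parallel::DistTensorSpec(x));
input_specs.emplace_back(paddle::distributed::auto_parallel::DistTensorSpec(y));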
......@@ -379,6 +379,8 @@ def source_include(header_file_path):
#include "paddle/phi/api/profiler/event_tracing.h"
#include "paddle/phi/api/profiler/supplement_tracing.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
DECLARE_bool(conv2d_disable_cudnn);
DECLARE_int32(low_precision_op_list);
"""
......
......@@ -2355,3 +2355,52 @@ def is_dep_skip_op(op):
        return True
    return False
# def wrap_data_for_completion(
#     dist_op: DistributedOperator,
#     input_names: list,
#     output_names: list,
#     attr_names: list
# ):
#     """
#     Get data used in inferring distributed attributes, including:
#       1. DistTensorSpec for each input and output tensor of this dist_op.
#       2. Operator attributes of this dist_op, e.g. transpose_x in matmul op.
#
#     Args:
#         dist_op: the DistributedOperator
#         input_names: list, name of the dist_op's input tensors
#         output_names: list, name of the dist_op's output tensors
#         attr_names: list, attribute name of the dist_op's corresponding serial op
#
#     Returns:
#         input_specs: list, DistTensorSpec for each input tensor of the dist_op
#         output_specs: list, DistTensorSpec for each output tensor of the dist_op
#         attrs: dict, attribute map of the dist op
#     """
#
#     input_specs = []
#     output_specs = []
#     attrs = {}
#
#     serial_op = dist_op.serial_op
#
#     # Construct each input tensor's DistTensorSpec with shape and dist_attr
#     for name in input_names:
#         tensor_dist_attr = dist_op.dist_attr.get_input_dist_attr(name)
#         var = serial_op.block._var_recursive(name)
#         tensor_shape = var.shape
#         dist_spec = DistTensorSpec(tensor_shape, tensor_dist_attr)
#         input_specs.append(dist_spec)
#
#     # Construct each output tensor's DistTensorSpec with shape and dist_attr
#     for name in output_names:
#         tensor_dist_attr = dist_op.dist_attr.get_output_dist_attr(name)
#         var = serial_op.block._var_recursive(name)
#         tensor_shape = var.shape
#         dist_spec = DistTensorSpec(tensor_shape, tensor_dist_attr)
#         output_specs.append(dist_spec)
#
#     for attr_name in attr_names:
#         attrs[attr_name] = serial_op.desc.attr(attr_name)