Unverified commit d2fedeac, authored by Ghost Screaming, committed by GitHub

[Auto Parallel]: Support std::vector<phi::Tensor> input and output for DistTensor. (#56602)

* [WIP] Support std::vector<phi::Tensor> input and output for DistTensor.
Concat forward and backward are verified.

* Polish code for new dist tensor implementation.

* Fix a bug in the DistTensor upgrade. Add support functions for std::vector<Tensor> -> std::vector<Tensor>.

* Add support for DistTensor with std::vector<phi::Tensor> as operator input or output.
The following test cases pass:
1. concat: std::vector<phi::Tensor> -> phi::Tensor
2. unbind: phi::Tensor -> std::vector<phi::Tensor>
3. broadcast_tensors: std::vector<phi::Tensor> -> std::vector<phi::Tensor>

* Polish code. Remove useless comments.

* Add update_loss_scaling to skip_op_lists.

* Polish code.
Parent: c2f0e9c4
@@ -32,6 +32,8 @@
 #ifndef PADDLE_NO_PYTHON
 #include "paddle/fluid/eager/hooks.h"
 #endif
+#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
+#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
 
 namespace egr {
 class TensorWrapper {
@@ -66,6 +68,16 @@ class TensorWrapper {
       intermidiate_tensor_.set_impl(std::make_shared<phi::DenseTensor>(
           std::make_shared<phi::Allocation>(nullptr, 0, tensor.place()),
           dense_tensor->meta()));
+    } else if (phi::distributed::DistTensor::classof(tensor.impl().get())) {
+      // Only Copy Meta
+      phi::distributed::DistTensor* dist_tensor =
+          static_cast<phi::distributed::DistTensor*>(tensor.impl().get());
+      intermidiate_tensor_.set_impl(
+          std::make_shared<phi::distributed::DistTensor>(
+              phi::DenseTensor(std::make_shared<phi::Allocation>(
+                                   nullptr, 0, tensor.place()),
+                               dist_tensor->value().meta()),
+              dist_tensor->dist_attr()));
     } else {
       PADDLE_THROW(paddle::platform::errors::Fatal(
           "Unrecognized tensor type for no_need_buffer feature"));
@@ -192,8 +192,6 @@ std::vector<phi::MetaTensor> MakeMetaTensor(
   return meta_tensors;
 }
 
-/* ------------------ for output ----------------------- */
-
 phi::DenseTensor* SetKernelOutput(Tensor* out) {
   if (out) {
     if (out->impl() == nullptr) {
@@ -546,5 +544,45 @@ phi::distributed::DistTensor* SetKernelDistOutput(Tensor* out) {
   return nullptr;
 }
 
+std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
+    std::vector<Tensor*> out) {
+  std::vector<phi::distributed::DistTensor*> result;
+  for (auto tmp : out) {
+    if (tmp) {
+      // TODO(GhostScreaming): now all dist case are nullptr
+      if (tmp->impl() == nullptr) {
+        phi::DenseTensor dense_t;
+        // TODO(GhostScreaming): polish code, dist_attr is null now
+        phi::distributed::TensorDistAttr dist_attr;
+        auto dist_t =
+            std::make_shared<phi::distributed::DistTensor>(dense_t, dist_attr);
+        tmp->set_impl(dist_t);
+      }
+      result.emplace_back(
+          static_cast<phi::distributed::DistTensor*>(tmp->impl().get()));
+    } else {
+      result.emplace_back(nullptr);
+    }
+  }
+  return result;
+}
+
+std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
+    size_t out_size, std::vector<Tensor>* out) {
+  out->reserve(out_size);
+  std::vector<phi::distributed::DistTensor*> results(out_size);
+  for (size_t i = 0; i < out_size; ++i) {
+    phi::DenseTensor dense_t;
+    // TODO(GhostScreaming): polish code, dist_attr is null now
+    phi::distributed::TensorDistAttr dist_attr;
+    auto dist_t =
+        std::make_shared<phi::distributed::DistTensor>(dense_t, dist_attr);
+    results[i] = dist_t.get();
+    out->emplace_back();
+    out->back().set_impl(dist_t);
+  }
+  return results;
+}
+
 }  // namespace experimental
 }  // namespace paddle
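Editor's note: for context, the size-based overload above is what the generated forward code calls for std::vector<Tensor> outputs. Below is a hedged, hand-assembled sketch of what VECTOR_OUT_CREATION_TEMPLATE (in dist_api_gen.py, further down in this diff) expands to; it is not literal generator output, and `out_size` stands in for the op's out_size_expr.

```cpp
// Sketch of the generated output-creation step for a std::vector<Tensor>
// output. Assembled from VECTOR_OUT_CREATION_TEMPLATE; illustrative only.
std::vector<Tensor> api_output;  // the API's return value
auto dist_out = SetKernelDistOutput(out_size, &api_output);
std::vector<phi::DenseTensor*> dense_out(dist_out.size());
for (size_t i = 0; i < dist_out.size(); i++) {
  // The dense kernel writes into the local DenseTensor held by each DistTensor.
  dense_out[i] = const_cast<phi::DenseTensor*>(&dist_out[i]->value());
}
```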
@@ -140,6 +140,10 @@ void TransStrideLegacy(phi::DeviceContext* dev_ctx,
 /* ------------------ for auto parallel ----------------------- */
 
 phi::distributed::DistTensor* SetKernelDistOutput(Tensor* out);
+std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
+    std::vector<Tensor*> out);
+std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
+    size_t out_size, std::vector<Tensor>* out);
 
 }  // namespace experimental
 }  // namespace paddle
@@ -632,5 +632,47 @@ std::shared_ptr<phi::distributed::DistTensor> PrepareDataForDistTensor(
   return nullptr;
 }
 
+std::vector<std::shared_ptr<phi::distributed::DistTensor>>
+PrepareDataForDistTensor(const std::vector<Tensor>& input,
+                         const phi::TensorArgDef& target_args_def,
+                         const TransformFlag& transform_flag,
+                         bool is_stride_kernel) {
+  std::vector<std::shared_ptr<phi::distributed::DistTensor>> out;
+  for (auto x : input) {
+    const auto& tensor_in = x.impl();
+    if (tensor_in) {
+      phi::distributed::DistTensor* dist_tensor =
+          static_cast<phi::distributed::DistTensor*>(tensor_in.get());
+      const phi::DenseTensor& dense_tensor = dist_tensor->value();
+      if (!transform_flag.NeedTransform() || !dense_tensor.initialized() ||
+          (!NeedTransformPlace(
+               dense_tensor.place(), target_args_def.backend, transform_flag) &&
+           !NeedTransformDataType(
+               dense_tensor.dtype(), target_args_def.dtype, transform_flag) &&
+           !NeedTransformLayout(dense_tensor.layout(),
+                                target_args_def.layout,
+                                dense_tensor.place(),
+                                transform_flag) &&
+           !NeedTransform2Contiguous(is_stride_kernel,
+                                     dense_tensor.meta().is_contiguous()))) {
+        out.push_back(
+            std::static_pointer_cast<phi::distributed::DistTensor>(tensor_in));
+        continue;
+      }
+      phi::DenseTensor trans_in_tensor = TransformData(
+          dense_tensor, target_args_def, transform_flag, is_stride_kernel);
+      // TODO(GhostScreaming): The global meta in DistTensor is not changed,
+      // but the local meta in DenseTensor maybe changed, such as layout
+      // change(NCHW->NHWC), so the new DistTensor's meta maybe not unified.
+      VLOG(6) << "PrepareDataForDistTensor return transformed dist tensor";
+      out.push_back(std::make_shared<phi::distributed::DistTensor>(
+          trans_in_tensor, dist_tensor->dist_attr()));
+    } else {
+      out.push_back(nullptr);
+    }
+  }
+  return out;
+}
+
 }  // namespace experimental
 }  // namespace paddle
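Editor's note: on the input side, the generated code pairs this vector overload with the MakeMetaTensor helpers. A hedged sketch of what VECTOR_PREPARE_DATA_TEMPLATE (in dist_api_gen.py below) expands to for a std::vector<Tensor> input named `x` at kernel input slot 0 follows; `transform_flag` stands in for the generated trans flag, and this is illustrative rather than literal generator output.

```cpp
// Sketch of the generated prepare-data step for a vector input named "x".
// Assembled from VECTOR_PREPARE_DATA_TEMPLATE; illustrative only.
auto dist_input_x_vec = PrepareDataForDistTensor(
    x,
    GetKernelInputArgDef(kernel.InputAt(0), kernel_backend),
    transform_flag,
    kernel_result.is_stride_kernel);
// Collect the local DenseTensors for the dense kernel call.
std::vector<const phi::DenseTensor*> dense_input_x_vec;
for (auto tmp : dist_input_x_vec) {
  dense_input_x_vec.emplace_back(&tmp->value());
}
// Build MetaTensor views so the local InferMeta can run on the dense inputs.
std::vector<phi::MetaTensor> dense_input_x_meta_vec =
    MakeMetaTensor(dense_input_x_vec);
std::vector<const phi::MetaTensor*> dense_input_x_meta_ptr_vec(
    dense_input_x_meta_vec.size());
for (size_t i = 0; i < dense_input_x_meta_vec.size(); i++) {
  dense_input_x_meta_ptr_vec[i] = &dense_input_x_meta_vec[i];
}
```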
@@ -180,5 +180,11 @@ std::shared_ptr<phi::distributed::DistTensor> PrepareDataForDistTensor(
     const TransformFlag& transform_flag,
     bool is_stride_kernel);
 
+std::vector<std::shared_ptr<phi::distributed::DistTensor>>
+PrepareDataForDistTensor(const std::vector<Tensor>& input,
+                         const phi::TensorArgDef& target_args_def,
+                         const TransformFlag& transform_flag,
+                         bool is_stride_kernel);
+
 }  // namespace experimental
 }  // namespace paddle
@@ -191,6 +191,16 @@ struct DistTensorTypeParser : ArgsIterator<DistTensorTypeParser> {
     }
   }
 
+  void operator()(const paddle::optional<std::vector<Tensor>>& x) {
+    if (x) {
+      if (!(x.get_ptr()->empty())) {
+        for (auto& t : *(x.get_ptr())) {
+          result &= t.is_dist_tensor();
+        }
+      }
+    }
+  }
+
   // skip other type args, these args don't used in kernel selection
   template <typename T>
   void operator()(const T& x) {
@@ -75,14 +75,21 @@ MULTI_SINGLE_OUT_CREATION_TEMPLATE = """
     auto dist_out_{} = SetKernelDistOutput({});
     auto dense_out_{} = const_cast<phi::DenseTensor*>(&dist_out_{}->value());
 """
-# TODO(chenweihang): support vector and tuple output later
 VECTOR_OUT_CREATION_TEMPLATE = """
+    auto dist_out = SetKernelDistOutput({}, &api_output);
+    std::vector<phi::DenseTensor*> dense_out(dist_out.size());
+    for (size_t i = 0; i < dist_out.size(); i++) {{
+      dense_out[i] = const_cast<phi::DenseTensor*>(&dist_out[i]->value());
+    }}
 """
 MULTI_VECTOR_OUT_CREATION_TEMPLATE = """
-    auto dist_out_{} = {}({}, {});
-    auto dense_out_{} = const_cast<phi::DenseTensor*>(&dist_out_{}->value());
+    auto dist_out_{out_name} = SetKernelDistOutput({size}, {in_name});
+    std::vector<phi::DenseTensor*> dense_out_{out_name}(dist_out_{out_name}.size());
+    for (size_t i = 0; i < dist_out_{out_name}.size(); i++) {{
+      dense_out_{out_name}[i] = const_cast<phi::DenseTensor*>(&dist_out_{out_name}[i]->value());
+    }}
 """
+# TODO(GhostScreaming): support tuple output later
 TUPLE_OUT_CREATION_TEMPLATE = """
 """
@@ -90,13 +97,32 @@ TUPLE_OUT_CREATION_TEMPLATE = """
 # Call InferMeta now, replace by InferSPMD function later
 # TODO(chenweihang): InferSPMD function design
 SINGLE_DIST_META_IN_TEMPLATE = """MakeMetaTensor(*{}.impl()), """
-# TODO(chenweihang): support vector and optional args later
-VECTOR_DIST_META_IN_TEMPLATE = """
+VECTOR_DIST_META_IN_TEMPLATE = """{}_meta_ptr_vec, """
+VECTOR_DIST_META_IN_DECL_TEMPLATE = """
+    std::vector<phi::MetaTensor> {name}_meta_vec;
+    for (auto tmp : {name}) {{
+      {name}_meta_vec.emplace_back(MakeMetaTensor(*tmp.impl()));
+    }}
+    std::vector<const phi::MetaTensor*> {name}_meta_ptr_vec({name}_meta_vec.size());
+    for (size_t i=0; i<{name}_meta_ptr_vec.size(); i++) {{
+      {name}_meta_ptr_vec[i] = &{name}_meta_vec[i];
+    }}
 """
+# TODO(GhostScreaming): support optional args later
 OPTIONAL_DIST_VECTOR_META_IN_TEMPLATE = """
 """
 SINGLE_DIST_META_OUT_DECL_TEMPLATE = """
     phi::MetaTensor meta_{}({});"""
+VECTOR_DIST_META_OUT_DECL_TEMPLATE = """
+    std::vector<phi::MetaTensor> {name}_meta_vec;
+    for (auto tmp : {name}) {{
+      {name}_meta_vec.emplace_back(phi::MetaTensor(tmp));
+    }}
+    std::vector<phi::MetaTensor*> {name}_meta_ptr_vec({name}.size());
+    for (size_t i=0; i<{name}_meta_vec.size(); i++) {{
+      {name}_meta_ptr_vec[i] = &{name}_meta_vec[i];
+    }}
+"""
 INFER_SPMD_TEMPLATE = """
     phi::{}({}{});
 """
@@ -120,6 +146,18 @@ SINGLE_PREPARE_DATA_TEMPLATE = """
     auto dist_input_{} = PrepareDataForDistTensor({}, GetKernelInputArgDef(kernel.InputAt({}), kernel_backend), {}, kernel_result.is_stride_kernel);
     auto input_{} = &dist_input_{}->value();
 """
+VECTOR_PREPARE_DATA_TEMPLATE = """
+    auto dist_input_{name}_vec = PrepareDataForDistTensor({name}, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel);
+    std::vector<const phi::DenseTensor*> dense_input_{name}_vec;
+    for (auto tmp : dist_input_{name}_vec) {{
+      dense_input_{name}_vec.emplace_back(&tmp->value());
+    }}
+    std::vector<phi::MetaTensor> dense_input_{name}_meta_vec = MakeMetaTensor(dense_input_{name}_vec);
+    std::vector<const phi::MetaTensor*> dense_input_{name}_meta_ptr_vec(dense_input_{name}_meta_vec.size());
+    for (size_t i=0; i<dense_input_{name}_meta_vec.size(); i++) {{
+      dense_input_{name}_meta_ptr_vec[i] = &dense_input_{name}_meta_vec[i];
+    }}
+"""
 INFER_META_SINGLE_INPUT_TEMPLATE = """
     auto dist_input_{} = {}.impl();
     auto input_{} = &(static_cast<phi::distributed::DistTensor*>(dist_input_{}.get())->value());
@@ -134,13 +172,19 @@ INFER_META_VECTOR_INPUT_TEMPLATE = """
 # 6. Infer Local DenseTensor Meta
 SINGLE_META_IN_TEMPLATE = """MakeMetaTensor(*input_{}), """
-# TODO(chenweihang): support vector and optional args later
-VECTOR_META_IN_TEMPLATE = """
-"""
+# TODO(GhostScreaming): support optional args later
+VECTOR_META_IN_TEMPLATE = """dense_input_{}_meta_ptr_vec, """
 OPTIONAL_VECTOR_META_IN_TEMPLATE = """
 """
 SINGLE_META_OUT_DECL_TEMPLATE = """
     phi::MetaTensor meta_{}({});"""
+VECTOR_META_OUT_DECL_TEMPLATE = """
+    std::vector<phi::MetaTensor> {name}_meta_vec = MakeMetaTensor({name});
+    std::vector<phi::MetaTensor*> {name}_meta_ptr_vec({name}_meta_vec.size());
+    for (size_t i=0; i<{name}_meta_vec.size(); i++) {{
+      {name}_meta_ptr_vec[i] = &{name}_meta_vec[i];
+    }}
+"""
 INFER_META_TEMPLATE = """
     phi::{}({}{});
 """
@@ -158,6 +202,8 @@ KERNEL_CALL_TEMPLATE = """
       auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
       (*kernel_fn)({}, {});
 """
+PREFIX_VECTOR_TENSOR_NAME = "dense_input_"
+SUFFIX_VECTOR_TENSOR_NAME = "_vec"
 
 # 8. Reshard Output
 OUTPUT_RESHARD_TEMPLATE = """
@@ -175,6 +221,15 @@ OUTPUT_RESHARD_TEMPLATE = """
 #     types : [], list of output types
 #     out_size_expr : [], expression for getting size of vector<Tensor>
 
+# TODO(GhostScreaming): Support std::tuple<...> type of input and output later.
+skip_op_lists = [
+    "check_finite_and_unscale",  # std::vector<Tensor>&, const Tensor& -> std::tuple<std::vector<Tensor>&, Tensor>
+    "coalesce_tensor",  # const std::vector<Tensor>&, DataType, bool, bool, bool, float, bool, int, int, const std::vector<int64_t>&, const std::vector<int64_t>& -> std::tuple<std::vector<Tensor>, Tensor>
+    "update_loss_scaling",  # std::vector<Tensor>, const Tensor, ... -> std::tuple<std::vector<Tensor>, Tensor, Tensor, Tensor>
+    "einsum",
+    "einsum_grad",  # const std::vector<Tensor>&, const std::string& -> std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>>
+]
+
 
 class DistForwardAPI(ForwardAPI):
     def __init__(self, api_item_yaml):
@@ -184,10 +239,13 @@ class DistForwardAPI(ForwardAPI):
     def init_dist_api_members(self):
         self.gene_dist_input_func = {
             "const Tensor&": {
-                "dense": self.generate_dense_input,
+                "dense": self.generate_single_dense_input,
             },
             "const paddle::optional<Tensor>&": {
-                "dense": self.generate_dense_input,
+                "dense": self.generate_single_dense_input,
+            },
+            "const std::vector<Tensor>&": {
+                "dense": self.generate_vector_dense_input,
             },
         }
@@ -254,6 +312,10 @@
             self.dense_output_args.append('dense_out')
             if self.outputs['types'][0] == 'Tensor':
                 output_creation_code += SINGLE_OUT_CREATION_TEMPLATE
+            elif self.outputs['types'][0] == 'std::vector<Tensor>':
+                output_creation_code += VECTOR_OUT_CREATION_TEMPLATE.format(
+                    self.outputs['out_size_expr'][0]
+                )
             else:
                 self.vector_output_size_assertion_check()
         elif output_num > 1:
@@ -285,10 +347,10 @@
                 get_out_code = f"std::get<{i}>(api_output).get_ptr()"

                 if out_type == 'std::vector<Tensor>':
-                    self.vector_output_size_assertion_check(i)
+                    self.vector_output_size_assertion_check()
                     # Special case for inplace vector and inplace optional<vector>
                     # TODO(chenweihang): support this branch later
-                    if self.is_inplace_output():
+                    if self.is_inplace_output(i):
                         set_out_func = "SetInplaceVectorKernelOutput"
                         if self.is_inplace_and_optional_output(i):
                             set_out_func = (
@@ -297,12 +359,9 @@
                         get_out_code = f"std::get<{i}>(api_output)"
                     output_creation_code += (
                         MULTI_VECTOR_OUT_CREATION_TEMPLATE.format(
-                            i,
-                            set_out_func,
-                            self.outputs['out_size_expr'][i],
-                            get_out_code,
-                            i,
-                            i,
+                            out_name=i,
+                            size=self.outputs['out_size_expr'][i],
+                            in_name=get_out_code,
                         )
                     )
                 else:
@@ -335,6 +394,7 @@
             if infer_meta['param'] is not None
             else input_names + attr_names
         )
+        input_meta_code = ""
         input_args_code = ""
         for param in infer_meta_params:
             if param in input_names:
@@ -342,6 +402,16 @@
                     input_args_code += SINGLE_DIST_META_IN_TEMPLATE.format(
                         param
                     )
+                elif (
+                    self.inputs['input_info'][param]
+                    == "const std::vector<Tensor>&"
+                ):
+                    input_args_code += VECTOR_DIST_META_IN_TEMPLATE.format(
+                        param
+                    )
+                    input_meta_code += VECTOR_DIST_META_IN_DECL_TEMPLATE.format(
+                        name=param
+                    )
                 else:
                     raise ValueError(
                         f"{self.api} : Param of infer_spmd error : {self.inputs['input_info'][param]} type is not supported."
@@ -360,8 +430,15 @@
         output_args_code = ""
         for i, out_name in enumerate(self.dist_output_args):
             if self.outputs['types'][i] == 'std::vector<Tensor>':
-                # TODO(chenweihang): support vector output later
-                pass
+                output_decl_code += VECTOR_DIST_META_OUT_DECL_TEMPLATE.format(
+                    name=out_name
+                )
+                if len(self.dense_output_args) == 1:
+                    output_args_code += f"{out_name}_meta_ptr_vec, "
+                else:
+                    output_args_code += (
+                        f"{out_name} ? {out_name}_meta_ptr_vec : nullptr, "
+                    )
             else:
                 output_decl_code += SINGLE_DIST_META_OUT_DECL_TEMPLATE.format(
                     out_name, out_name
@@ -374,8 +451,12 @@
         )
         output_args_code = output_args_code[:-2]

-        return output_decl_code + INFER_SPMD_TEMPLATE.format(
-            infer_meta_func_code, input_args_code, output_args_code
+        return (
+            output_decl_code
+            + input_meta_code
+            + INFER_SPMD_TEMPLATE.format(
+                infer_meta_func_code, input_args_code, output_args_code
+            )
         )

     def generate_kernel_selection_code(self) -> str:
@@ -386,8 +467,7 @@
     def generate_reshard_input_code(self) -> str:
         return INPUT_RESHARD_TEMPLATE.format()

-    # override BaseAPI's method
-    def generate_dense_input(
+    def generate_single_dense_input(
         self,
         input_name,
     ):
@@ -410,6 +490,26 @@
         return input_tensor_code

+    def generate_vector_dense_input(
+        self,
+        input_name,
+    ):
+        input_tensor_code = ""
+        trans_flag = self.gene_trans_flag(input_name)
+        input_names = self.inputs['names']
+        attr_names = self.attrs['names']
+        kernel_param = self.kernel['param']
+        if kernel_param is None:
+            kernel_param = input_names + attr_names
+
+        input_tensor_code += VECTOR_PREPARE_DATA_TEMPLATE.format(
+            name=input_name,
+            index=kernel_param.index(input_name),
+            trans_flag=trans_flag,
+        )
+
+        return input_tensor_code
+
     def generate_prepare_data_code(self) -> str:
         input_names = self.inputs['names']
         attr_names = self.attrs['names']
@@ -420,7 +520,7 @@
         for i, input_name in enumerate(input_names):
             # set input code
             if input_name in kernel_param:
-                # onlu support dense tensor
+                # only support dense tensor
                 api_tensor_type = self.inputs['input_info'][input_name]
                 phi_tensor_type = 'dense'
                 if api_tensor_type in self.gene_dist_input_func.keys():
@@ -480,6 +580,11 @@
             if param in input_names:
                 if self.inputs['input_info'][param] == "const Tensor&":
                     input_args_code += SINGLE_META_IN_TEMPLATE.format(param)
+                elif (
+                    self.inputs['input_info'][param]
+                    == "const std::vector<Tensor>&"
+                ):
+                    input_args_code += VECTOR_META_IN_TEMPLATE.format(param)
                 else:
                     raise ValueError(
                         f"{self.api} : Param of infer_meta error : {self.inputs['input_info'][param]} type is not supported."
@@ -498,8 +603,15 @@
         output_args_code = ""
         for i, out_name in enumerate(self.dense_output_args):
             if self.outputs['types'][i] == 'std::vector<Tensor>':
-                # TODO(chenweihang): support vector output later
-                pass
+                output_decl_code += VECTOR_META_OUT_DECL_TEMPLATE.format(
+                    name=out_name
+                )
+                if len(self.dense_output_args) == 1:
+                    output_args_code += f"{out_name}_meta_ptr_vec, "
+                else:
+                    output_args_code += (
+                        f"{out_name} ? {out_name}_meta_ptr_vec : nullptr, "
+                    )
             else:
                 output_decl_code += SINGLE_META_OUT_DECL_TEMPLATE.format(
                     out_name, out_name
@@ -548,7 +660,11 @@
                 if input_infos[arg] == "const Tensor&":
                     input_args.append("*" + PREFIX_TENSOR_NAME + arg)
                 elif input_infos[arg] == "const std::vector<Tensor>&":
-                    input_args.append(PREFIX_TENSOR_NAME + arg)
+                    input_args.append(
+                        PREFIX_VECTOR_TENSOR_NAME
+                        + arg
+                        + SUFFIX_VECTOR_TENSOR_NAME
+                    )
                 else:
                     # do nothing
                     pass
@@ -614,12 +730,19 @@
         )

     def check_argument_whether_support_auto_parallel(self):
+        global skip_op_lists
        for name in self.inputs['names']:
-            if self.inputs['input_info'][name] != "const Tensor&":
+            if self.inputs['input_info'][name] not in [
+                "const Tensor&",
+                "const std::vector<Tensor>&",
+            ]:
                 return False
         for out_type in self.outputs['types']:
-            if out_type != "Tensor":
+            if out_type not in ["Tensor", "std::vector<Tensor>"]:
                 return False
+        if self.kernel['func'][0] in skip_op_lists:
+            return False
         return True

     # override BaseAPI's method
@@ -661,7 +784,6 @@
             and self.check_argument_whether_support_auto_parallel()
         ):
             dist_branch_code = self.generate_auto_paralel_branch()
-
         return API_IMPL_TEMPLATE.format(
             self.get_return_type(inplace_flag),
             api_func_name,
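Editor's note: for orientation, a hedged, hand-assembled sketch of how the dense-branch templates above fit together for a vector-output op is shown here. Variable names follow the template naming (dense_input_x_vec, dense_out); `SomeInferMeta`, the kernel signature, and `dev_ctx` are placeholders for illustration, not anything generated verbatim by this commit.

```cpp
// Hand-assembled from VECTOR_META_OUT_DECL_TEMPLATE, INFER_META_TEMPLATE and
// KERNEL_CALL_TEMPLATE; illustrative only.
// Declare MetaTensor views over the dense output slots.
std::vector<phi::MetaTensor> dense_out_meta_vec = MakeMetaTensor(dense_out);
std::vector<phi::MetaTensor*> dense_out_meta_ptr_vec(dense_out_meta_vec.size());
for (size_t i = 0; i < dense_out_meta_vec.size(); i++) {
  dense_out_meta_ptr_vec[i] = &dense_out_meta_vec[i];
}
// Placeholder local InferMeta call; the real function and its attribute
// arguments depend on the op being generated.
phi::SomeInferMeta(dense_input_x_meta_ptr_vec, dense_out_meta_ptr_vec);

// The variadic kernel then consumes the dense inputs/outputs directly; the
// vector input arg is named PREFIX_VECTOR_TENSOR_NAME + arg +
// SUFFIX_VECTOR_TENSOR_NAME, i.e. dense_input_x_vec.
(*kernel_fn)(*dev_ctx, dense_input_x_vec, dense_out);
```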
@@ -27,6 +27,13 @@ SINGLE_OUT_CREATION_TEMPLATE = """
     auto dist_out = SetKernelDistOutput({});
     auto dense_out = const_cast<phi::DenseTensor*>(&dist_out->value());
 """
+VECTOR_OUT_CREATION_TEMPLATE = """
+    auto dist_out = SetKernelDistOutput({name});
+    std::vector<phi::DenseTensor*> dense_out(dist_out.size());
+    for (size_t i=0; i<dist_out.size(); i++) {{
+      dense_out[i] = const_cast<phi::DenseTensor*>(&dist_out[i]->value());
+    }}
+"""
 INPLACE_OUT_CREATION_TEMPLATE = """
     *{} = {};
 """
@@ -53,6 +60,10 @@ class DistBackwardAPI(DistForwardAPI, BackwardAPI):
                 output_creation_code += SINGLE_OUT_CREATION_TEMPLATE.format(
                     self.outputs['names'][0]
                 )
+            elif self.outputs['types'][0] == 'std::vector<Tensor>':
+                output_creation_code += VECTOR_OUT_CREATION_TEMPLATE.format(
+                    name=self.outputs['names'][0]
+                )
             else:
                 self.vector_output_size_assertion_check()
         elif output_num > 1:
@@ -83,6 +83,61 @@ class TestDistTensorForDygraphAPI(unittest.TestCase):
         dist_out.backward()
         self.check_tensor_eq(local_in.grad, dist_in.grad)

+    # input: std::vector<phi::Tensor>, output: phi::Tensor
+    def test_concat_for_dist_tensor(self):
+        x1 = np.random.random(size=[4, 4]).astype("float32")
+        x2 = np.random.random(size=[4, 4]).astype("float32")
+        x3 = np.random.random(size=[4, 4]).astype("float32")
+        local_in1, dist_in1 = self.create_local_and_dist_tensor_pair(x1)
+        local_in2, dist_in2 = self.create_local_and_dist_tensor_pair(x2)
+        local_in3, dist_in3 = self.create_local_and_dist_tensor_pair(x3)
+        local_out = paddle.concat([local_in1, local_in2, local_in3])
+        dist_out = paddle.concat([dist_in1, dist_in2, dist_in3])
+        self.check_tensor_eq(local_out, dist_out)
+        local_out.backward()
+        dist_out.backward()
+        self.check_tensor_eq(local_in1.grad, dist_in1.grad)
+        self.check_tensor_eq(local_in2.grad, dist_in2.grad)
+        self.check_tensor_eq(local_in3.grad, dist_in3.grad)
+
+    # input: std::vector<phi::Tensor>, output: std::vector<phi::Tensor>
+    def test_broadcast_tensors_for_dist_tensor(self):
+        x1 = np.random.random(size=[4, 4]).astype("float32")
+        x2 = np.random.random(size=[4, 4]).astype("float32")
+        local_in1, dist_in1 = self.create_local_and_dist_tensor_pair(x1)
+        local_in2, dist_in2 = self.create_local_and_dist_tensor_pair(x2)
+        local_out1, local_out2 = paddle.broadcast_tensors(
+            [local_in1, local_in2]
+        )
+        dist_out1, dist_out2 = paddle.broadcast_tensors([dist_in1, dist_in2])
+        self.check_tensor_eq(local_out1, dist_out1)
+        self.check_tensor_eq(local_out2, dist_out2)
+
+        local_out = local_out1 + local_out2
+        dist_out = dist_out1 + dist_out2
+
+        local_out.backward()
+        dist_out.backward()
+        self.check_tensor_eq(local_in1.grad, dist_in1.grad)
+        self.check_tensor_eq(local_in2.grad, dist_in2.grad)
+
+    # input: phi::Tensor, output: std::vector<phi::Tensor>
+    def test_unbind_api_for_dist_tensor(self):
+        x = np.random.random(size=[2, 8]).astype("float32")
+        local_in, dist_in = self.create_local_and_dist_tensor_pair(x)
+        local_out1, local_out2 = paddle.unbind(local_in, axis=0)
+        dist_out1, dist_out2 = paddle.unbind(dist_in, axis=0)
+        self.check_tensor_eq(local_out1, dist_out1)
+        self.check_tensor_eq(local_out2, dist_out2)
+
+        local_out = local_out1 + local_out2
+        dist_out = dist_out1 + dist_out2
+
+        local_out.backward()
+        dist_out.backward()
+        self.check_tensor_eq(local_in.grad, dist_in.grad)
+
     def test_matmul_api_for_dist_tensor(self):
         x = np.random.random(size=[4, 4]).astype("float32")
         y = np.random.random(size=[4, 4]).astype("float32")