diff --git a/paddle/fluid/eager/tensor_wrapper.h b/paddle/fluid/eager/tensor_wrapper.h
index 47b205e6fc76eceb9251541c92bf392be5f88399..f94ec416d4f3c5324227209238dec2eba747479b 100644
--- a/paddle/fluid/eager/tensor_wrapper.h
+++ b/paddle/fluid/eager/tensor_wrapper.h
@@ -32,6 +32,8 @@
 #ifndef PADDLE_NO_PYTHON
 #include "paddle/fluid/eager/hooks.h"
 #endif
+#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
+#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
 
 namespace egr {
 class TensorWrapper {
@@ -66,6 +68,16 @@ class TensorWrapper {
       intermidiate_tensor_.set_impl(std::make_shared<phi::DenseTensor>(
           std::make_shared<phi::Allocation>(nullptr, 0, tensor.place()),
           dense_tensor->meta()));
+    } else if (phi::distributed::DistTensor::classof(tensor.impl().get())) {
+      // Only Copy Meta
+      phi::distributed::DistTensor* dist_tensor =
+          static_cast<phi::distributed::DistTensor*>(tensor.impl().get());
+      intermidiate_tensor_.set_impl(
+          std::make_shared<phi::distributed::DistTensor>(
+              phi::DenseTensor(std::make_shared<phi::Allocation>(
+                                   nullptr, 0, tensor.place()),
+                               dist_tensor->value().meta()),
+              dist_tensor->dist_attr()));
     } else {
       PADDLE_THROW(paddle::platform::errors::Fatal(
           "Unrecognized tensor type for no_need_buffer feature"));
diff --git a/paddle/phi/api/lib/api_gen_utils.cc b/paddle/phi/api/lib/api_gen_utils.cc
index 51875ed9175ec08656732917dc8186b017261787..c6da10d12dea397f584d8b0dbb47b9025a53f39d 100644
--- a/paddle/phi/api/lib/api_gen_utils.cc
+++ b/paddle/phi/api/lib/api_gen_utils.cc
@@ -192,8 +192,6 @@ std::vector<phi::MetaTensor> MakeMetaTensor(
   return meta_tensors;
 }
 
-/* ------------------ for output ----------------------- */
-
 phi::DenseTensor* SetKernelOutput(Tensor* out) {
   if (out) {
     if (out->impl() == nullptr) {
@@ -546,5 +544,45 @@ phi::distributed::DistTensor* SetKernelDistOutput(Tensor* out) {
   return nullptr;
 }
 
+std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
+    std::vector<Tensor*> out) {
+  std::vector<phi::distributed::DistTensor*> result;
+  for (auto tmp : out) {
+    if (tmp) {
+      // TODO(GhostScreaming): now all dist case are nullptr
+      if (tmp->impl() == nullptr) {
+        phi::DenseTensor dense_t;
+        // TODO(GhostScreaming): polish code, dist_attr is null now
+        phi::distributed::TensorDistAttr dist_attr;
+        auto dist_t =
+            std::make_shared<phi::distributed::DistTensor>(dense_t, dist_attr);
+        tmp->set_impl(dist_t);
+      }
+      result.emplace_back(
+          static_cast<phi::distributed::DistTensor*>(tmp->impl().get()));
+    } else {
+      result.emplace_back(nullptr);
+    }
+  }
+  return result;
+}
+
+std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
+    size_t out_size, std::vector<Tensor>* out) {
+  out->reserve(out_size);
+  std::vector<phi::distributed::DistTensor*> results(out_size);
+  for (size_t i = 0; i < out_size; ++i) {
+    phi::DenseTensor dense_t;
+    // TODO(GhostScreaming): polish code, dist_attr is null now
+    phi::distributed::TensorDistAttr dist_attr;
+    auto dist_t =
+        std::make_shared<phi::distributed::DistTensor>(dense_t, dist_attr);
+    results[i] = dist_t.get();
+    out->emplace_back();
+    out->back().set_impl(dist_t);
+  }
+  return results;
+}
+
 }  // namespace experimental
 }  // namespace paddle
diff --git a/paddle/phi/api/lib/api_gen_utils.h b/paddle/phi/api/lib/api_gen_utils.h
index 605423b431a7c09ee572ab0d8382b2566fb80413..997bb6f8dc805e0912a8287d31c7f16f9c65cb50 100644
--- a/paddle/phi/api/lib/api_gen_utils.h
+++ b/paddle/phi/api/lib/api_gen_utils.h
@@ -140,6 +140,10 @@ void TransStrideLegacy(phi::DeviceContext* dev_ctx,
 /* ------------------ for auto parallel ----------------------- */
 
 phi::distributed::DistTensor* SetKernelDistOutput(Tensor* out);
+std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
+    std::vector<Tensor*> out);
+std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
+    size_t out_size, std::vector<Tensor>* out);
 
 }  // namespace experimental
 }  // namespace paddle
diff --git a/paddle/phi/api/lib/data_transform.cc b/paddle/phi/api/lib/data_transform.cc
index b1559908ebc3ff593334dc9476fa47450461a14e..0e86b84e074fe9b1087188656b593f3b96aa843f 100644
--- a/paddle/phi/api/lib/data_transform.cc
+++ b/paddle/phi/api/lib/data_transform.cc
@@ -632,5 +632,47 @@ std::shared_ptr<phi::distributed::DistTensor> PrepareDataForDistTensor(
   return nullptr;
 }
 
+std::vector<std::shared_ptr<phi::distributed::DistTensor>>
+PrepareDataForDistTensor(const std::vector<Tensor>& input,
+                         const phi::TensorArgDef& target_args_def,
+                         const TransformFlag& transform_flag,
+                         bool is_stride_kernel) {
+  std::vector<std::shared_ptr<phi::distributed::DistTensor>> out;
+  for (auto x : input) {
+    const auto& tensor_in = x.impl();
+    if (tensor_in) {
+      phi::distributed::DistTensor* dist_tensor =
+          static_cast<phi::distributed::DistTensor*>(tensor_in.get());
+      const phi::DenseTensor& dense_tensor = dist_tensor->value();
+      if (!transform_flag.NeedTransform() || !dense_tensor.initialized() ||
+          (!NeedTransformPlace(
+               dense_tensor.place(), target_args_def.backend, transform_flag) &&
+           !NeedTransformDataType(
+               dense_tensor.dtype(), target_args_def.dtype, transform_flag) &&
+           !NeedTransformLayout(dense_tensor.layout(),
+                                target_args_def.layout,
+                                dense_tensor.place(),
+                                transform_flag) &&
+           !NeedTransform2Contiguous(is_stride_kernel,
+                                     dense_tensor.meta().is_contiguous()))) {
+        out.push_back(
+            std::static_pointer_cast<phi::distributed::DistTensor>(tensor_in));
+        continue;
+      }
+      phi::DenseTensor trans_in_tensor = TransformData(
+          dense_tensor, target_args_def, transform_flag, is_stride_kernel);
+      // TODO(GhostScreaming): The global meta in DistTensor is not changed,
+      // but the local meta in DenseTensor maybe changed, such as layout
+      // change(NCHW->NHWC), so the new DistTensor's meta maybe not unified.
+      VLOG(6) << "PrepareDataForDistTensor return transformed dist tensor";
+      out.push_back(std::make_shared<phi::distributed::DistTensor>(
+          trans_in_tensor, dist_tensor->dist_attr()));
+    } else {
+      out.push_back(nullptr);
+    }
+  }
+  return out;
+}
+
 }  // namespace experimental
 }  // namespace paddle
diff --git a/paddle/phi/api/lib/data_transform.h b/paddle/phi/api/lib/data_transform.h
index bc59ac8cfa73e554dfc3aa697e9d99ee6c15b5b9..4247317857c23450387cf4895890b47ee7b0e92b 100644
--- a/paddle/phi/api/lib/data_transform.h
+++ b/paddle/phi/api/lib/data_transform.h
@@ -180,5 +180,11 @@ std::shared_ptr<phi::distributed::DistTensor> PrepareDataForDistTensor(
     const TransformFlag& transform_flag,
     bool is_stride_kernel);
 
+std::vector<std::shared_ptr<phi::distributed::DistTensor>>
+PrepareDataForDistTensor(const std::vector<Tensor>& input,
+                         const phi::TensorArgDef& target_args_def,
+                         const TransformFlag& transform_flag,
+                         bool is_stride_kernel);
+
 }  // namespace experimental
 }  // namespace paddle
diff --git a/paddle/phi/api/lib/kernel_dispatch.h b/paddle/phi/api/lib/kernel_dispatch.h
index 4fd684b0bd6e7b60dfaede6d4e34c8c44f14d22a..7ff9ab3b33f538c4c19fd04108a04cd49146f4f5 100644
--- a/paddle/phi/api/lib/kernel_dispatch.h
+++ b/paddle/phi/api/lib/kernel_dispatch.h
@@ -191,6 +191,16 @@ struct DistTensorTypeParser : ArgsIterator<DistTensorTypeParser> {
     }
   }
 
+  void operator()(const paddle::optional<std::vector<Tensor>>& x) {
+    if (x) {
+      if (!(x.get_ptr()->empty())) {
+        for (auto& t : *(x.get_ptr())) {
+          result &= t.is_dist_tensor();
+        }
+      }
+    }
+  }
+
   // skip other type args, these args don't used in kernel selection
   template <typename T>
   void operator()(const T& x) {
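
The DistTensorTypeParser overload added above decides whether the auto-parallel path is taken for an optional vector of tensors: the dist branch stays selected only if every element reports is_dist_tensor(), while an absent or empty optional leaves the decision unchanged. A minimal Python sketch of that rule follows (not part of the patch; FakeTensor and all_dist are hypothetical stand-ins used purely for illustration):

    # Hypothetical stand-in mirroring DistTensorTypeParser's rule for an
    # optional std::vector<Tensor> argument: every element must be a
    # DistTensor, and a missing/empty optional leaves `result` unchanged.
    def all_dist(tensors, result=True):
        if tensors:  # None or [] changes nothing
            for t in tensors:
                result &= t.is_dist_tensor()
        return result

    class FakeTensor:
        def __init__(self, dist):
            self.dist = dist

        def is_dist_tensor(self):
            return self.dist

    assert all_dist(None) is True
    assert all_dist([FakeTensor(True), FakeTensor(True)])
    assert not all_dist([FakeTensor(True), FakeTensor(False)])
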
diff --git a/paddle/phi/api/yaml/generator/dist_api_gen.py b/paddle/phi/api/yaml/generator/dist_api_gen.py
index a1d07b55772d3d1d72358173f4d804b10676e4be..5c0b642228ba7079b5a5967a5ba58d8c56264321 100644
--- a/paddle/phi/api/yaml/generator/dist_api_gen.py
+++ b/paddle/phi/api/yaml/generator/dist_api_gen.py
@@ -75,14 +75,21 @@ MULTI_SINGLE_OUT_CREATION_TEMPLATE = """
     auto dist_out_{} = SetKernelDistOutput({});
     auto dense_out_{} = const_cast<phi::DenseTensor*>(&dist_out_{}->value());
 """
-
-# TODO(chenweihang): support vector and tuple output later
 VECTOR_OUT_CREATION_TEMPLATE = """
+    auto dist_out = SetKernelDistOutput({}, &api_output);
+    std::vector<phi::DenseTensor*> dense_out(dist_out.size());
+    for (size_t i = 0; i < dist_out.size(); i++) {{
+        dense_out[i] = const_cast<phi::DenseTensor*>(&dist_out[i]->value());
+    }}
 """
 MULTI_VECTOR_OUT_CREATION_TEMPLATE = """
-    auto dist_out_{} = {}({}, {});
-    auto dense_out_{} = const_cast<phi::DenseTensor*>(&dist_out_{}->value());
+    auto dist_out_{out_name} = SetKernelDistOutput({size}, {in_name});
+    std::vector<phi::DenseTensor*> dense_out_{out_name}(dist_out_{out_name}.size());
+    for (size_t i = 0; i < dist_out_{out_name}.size(); i++) {{
+        dense_out_{out_name}[i] = const_cast<phi::DenseTensor*>(&dist_out_{out_name}[i]->value());
+    }}
 """
+# TODO(GhostScreaming): support tuple output later
 TUPLE_OUT_CREATION_TEMPLATE = """
 """
 
@@ -90,13 +97,32 @@ TUPLE_OUT_CREATION_TEMPLATE = """
 # Call InferMeta now, replace by InferSPMD function later
 # TODO(chenweihang): InferSPMD function design
 SINGLE_DIST_META_IN_TEMPLATE = """MakeMetaTensor(*{}.impl()), """
-# TODO(chenweihang): support vector and optional args later
-VECTOR_DIST_META_IN_TEMPLATE = """
+VECTOR_DIST_META_IN_TEMPLATE = """{}_meta_ptr_vec, """
+VECTOR_DIST_META_IN_DECL_TEMPLATE = """
+    std::vector<phi::MetaTensor> {name}_meta_vec;
+    for (auto tmp : {name}) {{
+      {name}_meta_vec.emplace_back(MakeMetaTensor(*tmp.impl()));
+    }}
+    std::vector<phi::MetaTensor*> {name}_meta_ptr_vec({name}_meta_vec.size());
+    for (size_t i=0; i<{name}_meta_ptr_vec.size(); i++) {{
+      {name}_meta_ptr_vec[i] = &{name}_meta_vec[i];
+    }}
 """
+# TODO(GhostScreaming): support optional args later
 OPTIONAL_DIST_VECTOR_META_IN_TEMPLATE = """
 """
 SINGLE_DIST_META_OUT_DECL_TEMPLATE = """
     phi::MetaTensor meta_{}({});"""
+VECTOR_DIST_META_OUT_DECL_TEMPLATE = """
+    std::vector<phi::MetaTensor> {name}_meta_vec;
+    for (auto tmp : {name}) {{
+      {name}_meta_vec.emplace_back(phi::MetaTensor(tmp));
+    }}
+    std::vector<phi::MetaTensor*> {name}_meta_ptr_vec({name}.size());
+    for (size_t i=0; i<{name}_meta_vec.size(); i++) {{
+      {name}_meta_ptr_vec[i] = &{name}_meta_vec[i];
+    }}
+"""
 INFER_SPMD_TEMPLATE = """
     phi::{}({}{});
 """
@@ -120,6 +146,18 @@ SINGLE_PREPARE_DATA_TEMPLATE = """
     auto dist_input_{} = PrepareDataForDistTensor({}, GetKernelInputArgDef(kernel.InputAt({}), kernel_backend), {}, kernel_result.is_stride_kernel);
     auto input_{} = &dist_input_{}->value();
 """
+VECTOR_PREPARE_DATA_TEMPLATE = """
+    auto dist_input_{name}_vec = PrepareDataForDistTensor({name}, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel);
+    std::vector<const phi::DenseTensor*> dense_input_{name}_vec;
+    for (auto tmp : dist_input_{name}_vec) {{
+      dense_input_{name}_vec.emplace_back(&tmp->value());
+    }}
+    std::vector<phi::MetaTensor> dense_input_{name}_meta_vec = MakeMetaTensor(dense_input_{name}_vec);
+    std::vector<const phi::MetaTensor*> dense_input_{name}_meta_ptr_vec(dense_input_{name}_meta_vec.size());
+    for (size_t i=0; i<dense_input_{name}_meta_ptr_vec.size(); i++) {{
+      dense_input_{name}_meta_ptr_vec[i] = &dense_input_{name}_meta_vec[i];
+    }}
+"""
 INFER_META_SINGLE_INPUT_TEMPLATE = """
     auto dist_input_{} = {}.impl();
     auto input_{} = &(static_cast<phi::distributed::DistTensor*>(dist_input_{}.get())->value());
@@ -134,13 +172,19 @@ INFER_META_VECTOR_INPUT_TEMPLATE = """
 
 # 6. Infer Local DenseTensor Meta
 SINGLE_META_IN_TEMPLATE = """MakeMetaTensor(*input_{}), """
-# TODO(chenweihang): support vector and optional args later
-VECTOR_META_IN_TEMPLATE = """
-"""
+# TODO(GhostScreaming): support optional args later
+VECTOR_META_IN_TEMPLATE = """dense_input_{}_meta_ptr_vec, """
 OPTIONAL_VECTOR_META_IN_TEMPLATE = """
 """
 SINGLE_META_OUT_DECL_TEMPLATE = """
     phi::MetaTensor meta_{}({});"""
+VECTOR_META_OUT_DECL_TEMPLATE = """
+    std::vector<phi::MetaTensor> {name}_meta_vec = MakeMetaTensor({name});
+    std::vector<phi::MetaTensor*> {name}_meta_ptr_vec({name}_meta_vec.size());
+    for (size_t i=0; i<{name}_meta_vec.size(); i++) {{
+      {name}_meta_ptr_vec[i] = &{name}_meta_vec[i];
+    }}
+"""
 INFER_META_TEMPLATE = """
     phi::{}({}{});
 """
@@ -158,6 +202,8 @@ KERNEL_CALL_TEMPLATE = """
     auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
     (*kernel_fn)({}, {});
 """
+PREFIX_VECTOR_TENSOR_NAME = "dense_input_"
+SUFFIX_VECTOR_TENSOR_NAME = "_vec"
 
 # 8. Reshard Output
 OUTPUT_RESHARD_TEMPLATE = """
@@ -175,6 +221,15 @@ OUTPUT_RESHARD_TEMPLATE = """
 #     types : [], list of output types
 #     out_size_expr : [], expression for getting size of vector<Tensor>
 
+# TODO(GhostScreaming): Support std::tuple<...> type of input and output later.
+skip_op_lists = [
+    "check_finite_and_unscale",  # std::vector<Tensor>&, const Tensor& -> std::tuple<std::vector<Tensor>&, Tensor>
+    "coalesce_tensor",  # const std::vector<Tensor>&, DataType, bool, bool, bool, float, bool, int, int, const std::vector<int64_t>&, const std::vector<int64_t>& -> std::tuple<std::vector<Tensor>, Tensor>
+    "update_loss_scaling",  # std::vector<Tensor>, const Tensor, ... -> std::tuple<std::vector<Tensor>, Tensor, Tensor, Tensor>
+    "einsum",
+    "einsum_grad",  # const std::vector<Tensor>&, const std::string& -> std::tuple<std::vector<Tensor>, std::vector<Tensor>>
+]
+
 
 class DistForwardAPI(ForwardAPI):
     def __init__(self, api_item_yaml):
@@ -184,10 +239,13 @@ class DistForwardAPI(ForwardAPI):
     def init_dist_api_members(self):
         self.gene_dist_input_func = {
             "const Tensor&": {
-                "dense": self.generate_dense_input,
+                "dense": self.generate_single_dense_input,
             },
             "const paddle::optional<Tensor>&": {
-                "dense": self.generate_dense_input,
+                "dense": self.generate_single_dense_input,
+            },
+            "const std::vector<Tensor>&": {
+                "dense": self.generate_vector_dense_input,
             },
         }
 
@@ -254,6 +312,10 @@ class DistForwardAPI(ForwardAPI):
             self.dense_output_args.append('dense_out')
             if self.outputs['types'][0] == 'Tensor':
                 output_creation_code += SINGLE_OUT_CREATION_TEMPLATE
+            elif self.outputs['types'][0] == 'std::vector<Tensor>':
+                output_creation_code += VECTOR_OUT_CREATION_TEMPLATE.format(
+                    self.outputs['out_size_expr'][0]
+                )
             else:
                 self.vector_output_size_assertion_check()
         elif output_num > 1:
@@ -285,10 +347,10 @@ class DistForwardAPI(ForwardAPI):
                     get_out_code = f"std::get<{i}>(api_output).get_ptr()"
 
                 if out_type == 'std::vector<Tensor>':
-                    self.vector_output_size_assertion_check(i)
+                    self.vector_output_size_assertion_check()
                     # Special case for inplace vector and inplace optional
                     # TODO(chenweihang): support this branch later
-                    if self.is_inplace_output():
+                    if self.is_inplace_output(i):
                         set_out_func = "SetInplaceVectorKernelOutput"
                         if self.is_inplace_and_optional_output(i):
                             set_out_func = (
@@ -297,12 +359,9 @@ class DistForwardAPI(ForwardAPI):
                             get_out_code = f"std::get<{i}>(api_output)"
                     output_creation_code += (
                         MULTI_VECTOR_OUT_CREATION_TEMPLATE.format(
-                            i,
-                            set_out_func,
-                            self.outputs['out_size_expr'][i],
-                            get_out_code,
-                            i,
-                            i,
+                            out_name=i,
+                            size=self.outputs['out_size_expr'][i],
+                            in_name=get_out_code,
                         )
                     )
                 else:
                     output_creation_code += (
@@ -335,6 +394,7 @@ class DistForwardAPI(ForwardAPI):
             if infer_meta['param'] is not None
             else input_names + attr_names
         )
+        input_meta_code = ""
         input_args_code = ""
         for param in infer_meta_params:
             if param in input_names:
                 if self.inputs['input_info'][param] == "const Tensor&":
                     input_args_code += SINGLE_DIST_META_IN_TEMPLATE.format(
                         param
                     )
+                elif (
+                    self.inputs['input_info'][param]
+                    == "const std::vector<Tensor>&"
+                ):
+                    input_args_code += VECTOR_DIST_META_IN_TEMPLATE.format(
+                        param
+                    )
+                    input_meta_code += VECTOR_DIST_META_IN_DECL_TEMPLATE.format(
+                        name=param
+                    )
                 else:
                     raise ValueError(
                         f"{self.api} : Param of infer_spmd error : {self.inputs['input_info'][param]} type is not supported."
@@ -360,8 +430,15 @@ class DistForwardAPI(ForwardAPI):
         output_args_code = ""
         for i, out_name in enumerate(self.dist_output_args):
             if self.outputs['types'][i] == 'std::vector<Tensor>':
-                # TODO(chenweihang): support vector output later
-                pass
+                output_decl_code += VECTOR_DIST_META_OUT_DECL_TEMPLATE.format(
+                    name=out_name
+                )
+                if len(self.dense_output_args) == 1:
+                    output_args_code += f"{out_name}_meta_ptr_vec, "
+                else:
+                    output_args_code += (
+                        f"{out_name} ? {out_name}_meta_ptr_vec : nullptr, "
+                    )
             else:
                 output_decl_code += SINGLE_DIST_META_OUT_DECL_TEMPLATE.format(
                     out_name, out_name
@@ -374,8 +451,12 @@ class DistForwardAPI(ForwardAPI):
                 )
         output_args_code = output_args_code[:-2]
 
-        return output_decl_code + INFER_SPMD_TEMPLATE.format(
-            infer_meta_func_code, input_args_code, output_args_code
+        return (
+            output_decl_code
+            + input_meta_code
+            + INFER_SPMD_TEMPLATE.format(
+                infer_meta_func_code, input_args_code, output_args_code
+            )
         )
 
     def generate_kernel_selection_code(self) -> str:
@@ -386,8 +467,7 @@ class DistForwardAPI(ForwardAPI):
     def generate_reshard_input_code(self) -> str:
         return INPUT_RESHARD_TEMPLATE.format()
 
-    # override BaseAPI's method
-    def generate_dense_input(
+    def generate_single_dense_input(
         self,
         input_name,
     ):
@@ -410,6 +490,26 @@ class DistForwardAPI(ForwardAPI):
 
         return input_tensor_code
 
+    def generate_vector_dense_input(
+        self,
+        input_name,
+    ):
+        input_tensor_code = ""
+        trans_flag = self.gene_trans_flag(input_name)
+        input_names = self.inputs['names']
+        attr_names = self.attrs['names']
+        kernel_param = self.kernel['param']
+        if kernel_param is None:
+            kernel_param = input_names + attr_names
+
+        input_tensor_code += VECTOR_PREPARE_DATA_TEMPLATE.format(
+            name=input_name,
+            index=kernel_param.index(input_name),
+            trans_flag=trans_flag,
+        )
+
+        return input_tensor_code
+
     def generate_prepare_data_code(self) -> str:
         input_names = self.inputs['names']
         attr_names = self.attrs['names']
@@ -420,7 +520,7 @@ class DistForwardAPI(ForwardAPI):
         for i, input_name in enumerate(input_names):
             # set input code
             if input_name in kernel_param:
-                # onlu support dense tensor
+                # only support dense tensor
                 api_tensor_type = self.inputs['input_info'][input_name]
                 phi_tensor_type = 'dense'
                 if api_tensor_type in self.gene_dist_input_func.keys():
@@ -480,6 +580,11 @@ class DistForwardAPI(ForwardAPI):
             if param in input_names:
                 if self.inputs['input_info'][param] == "const Tensor&":
                     input_args_code += SINGLE_META_IN_TEMPLATE.format(param)
+                elif (
+                    self.inputs['input_info'][param]
+                    == "const std::vector<Tensor>&"
+                ):
+                    input_args_code += VECTOR_META_IN_TEMPLATE.format(param)
                 else:
                     raise ValueError(
                         f"{self.api} : Param of infer_meta error : {self.inputs['input_info'][param]} type is not supported."
@@ -498,8 +603,15 @@ class DistForwardAPI(ForwardAPI):
         output_args_code = ""
         for i, out_name in enumerate(self.dense_output_args):
             if self.outputs['types'][i] == 'std::vector<Tensor>':
-                # TODO(chenweihang): support vector output later
-                pass
+                output_decl_code += VECTOR_META_OUT_DECL_TEMPLATE.format(
+                    name=out_name
+                )
+                if len(self.dense_output_args) == 1:
+                    output_args_code += f"{out_name}_meta_ptr_vec, "
+                else:
+                    output_args_code += (
+                        f"{out_name} ? {out_name}_meta_ptr_vec : nullptr, "
+                    )
             else:
                 output_decl_code += SINGLE_META_OUT_DECL_TEMPLATE.format(
                     out_name, out_name
@@ -548,7 +660,11 @@ class DistForwardAPI(ForwardAPI):
                 if input_infos[arg] == "const Tensor&":
                     input_args.append("*" + PREFIX_TENSOR_NAME + arg)
                 elif input_infos[arg] == "const std::vector<Tensor>&":
-                    input_args.append(PREFIX_TENSOR_NAME + arg)
+                    input_args.append(
+                        PREFIX_VECTOR_TENSOR_NAME
+                        + arg
+                        + SUFFIX_VECTOR_TENSOR_NAME
+                    )
                 else:
                     # do nothing
                     pass
@@ -614,12 +730,19 @@ class DistForwardAPI(ForwardAPI):
         )
 
     def check_argument_whether_support_auto_parallel(self):
+        global skip_op_lists
         for name in self.inputs['names']:
-            if self.inputs['input_info'][name] != "const Tensor&":
+            if self.inputs['input_info'][name] not in [
+                "const Tensor&",
+                "const std::vector<Tensor>&",
+            ]:
                 return False
         for out_type in self.outputs['types']:
-            if out_type != "Tensor":
+            if out_type not in ["Tensor", "std::vector<Tensor>"]:
                 return False
+
+        if self.kernel['func'][0] in skip_op_lists:
+            return False
         return True
 
     # override BaseAPI's method
@@ -661,7 +784,6 @@ class DistForwardAPI(ForwardAPI):
             and self.check_argument_whether_support_auto_parallel()
         ):
             dist_branch_code = self.generate_auto_paralel_branch()
-
         return API_IMPL_TEMPLATE.format(
             self.get_return_type(inplace_flag),
             api_func_name,
diff --git a/paddle/phi/api/yaml/generator/dist_bw_api_gen.py b/paddle/phi/api/yaml/generator/dist_bw_api_gen.py
index 95a6f94706eeeec3d63dc5f7ad472080d4b9fc83..487d6e3a257200f0c653b476b99b8dc2653aee87 100644
--- a/paddle/phi/api/yaml/generator/dist_bw_api_gen.py
+++ b/paddle/phi/api/yaml/generator/dist_bw_api_gen.py
@@ -27,6 +27,13 @@ SINGLE_OUT_CREATION_TEMPLATE = """
     auto dist_out = SetKernelDistOutput({});
     auto dense_out = const_cast<phi::DenseTensor*>(&dist_out->value());
 """
+VECTOR_OUT_CREATION_TEMPLATE = """
+    auto dist_out = SetKernelDistOutput({name});
+    std::vector<phi::DenseTensor*> dense_out(dist_out.size());
+    for (size_t i=0; i<dist_out.size(); i++) {{
+        dense_out[i] = const_cast<phi::DenseTensor*>(&dist_out[i]->value());
+    }}
+"""
 INPLACE_OUT_CREATION_TEMPLATE = """
     *{} = {};
 """
@@ -53,6 +60,10 @@ class DistBackwardAPI(DistForwardAPI, BackwardAPI):
                 output_creation_code += SINGLE_OUT_CREATION_TEMPLATE.format(
                     self.outputs['names'][0]
                 )
+            elif self.outputs['types'][0] == 'std::vector<Tensor>':
+                output_creation_code += VECTOR_OUT_CREATION_TEMPLATE.format(
+                    name=self.outputs['names'][0]
+                )
             else:
                 self.vector_output_size_assertion_check()
         elif output_num > 1:
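
To make the generator changes above more concrete, the following short, runnable illustration (not part of the patch) shows how the new VECTOR_OUT_CREATION_TEMPLATE from dist_api_gen.py is filled in; the out_size_expr value "x.size()" is a hypothetical placeholder chosen only for this example:

    # Illustration only: the template literal is copied from dist_api_gen.py above;
    # "x.size()" stands in for whatever out_size_expr the parsed YAML provides.
    VECTOR_OUT_CREATION_TEMPLATE = """
        auto dist_out = SetKernelDistOutput({}, &api_output);
        std::vector<phi::DenseTensor*> dense_out(dist_out.size());
        for (size_t i = 0; i < dist_out.size(); i++) {{
            dense_out[i] = const_cast<phi::DenseTensor*>(&dist_out[i]->value());
        }}
    """

    # The doubled braces {{ }} survive str.format as literal C++ braces.
    print(VECTOR_OUT_CREATION_TEMPLATE.format("x.size()"))

The printed C++ fragment is what the generated API body contains for a std::vector<Tensor> output before the InferMeta, kernel call, and reshard sections that surround it.
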
diff --git a/test/auto_parallel/test_dist_tensor.py b/test/auto_parallel/test_dist_tensor.py
index 0bf2d88db42370c0bcc2befd815778e2ad1826ed..e82a3cc2fe6b2b57c0a8e852f9056d1b30ebd360 100644
--- a/test/auto_parallel/test_dist_tensor.py
+++ b/test/auto_parallel/test_dist_tensor.py
@@ -83,6 +83,61 @@ class TestDistTensorForDygraphAPI(unittest.TestCase):
         dist_out.backward()
         self.check_tensor_eq(local_in.grad, dist_in.grad)
 
+    # input: std::vector<phi::Tensor>, output: phi::Tensor
+    def test_concat_for_dist_tensor(self):
+        x1 = np.random.random(size=[4, 4]).astype("float32")
+        x2 = np.random.random(size=[4, 4]).astype("float32")
+        x3 = np.random.random(size=[4, 4]).astype("float32")
+        local_in1, dist_in1 = self.create_local_and_dist_tensor_pair(x1)
+        local_in2, dist_in2 = self.create_local_and_dist_tensor_pair(x2)
+        local_in3, dist_in3 = self.create_local_and_dist_tensor_pair(x3)
+        local_out = paddle.concat([local_in1, local_in2, local_in3])
+        dist_out = paddle.concat([dist_in1, dist_in2, dist_in3])
+        self.check_tensor_eq(local_out, dist_out)
+        local_out.backward()
+        dist_out.backward()
+        self.check_tensor_eq(local_in1.grad, dist_in1.grad)
+        self.check_tensor_eq(local_in2.grad, dist_in2.grad)
+        self.check_tensor_eq(local_in3.grad, dist_in3.grad)
+
+    # input: std::vector<phi::Tensor>, output: std::vector<phi::Tensor>
+    def test_broadcast_tensors_for_dist_tensor(self):
+        x1 = np.random.random(size=[4, 4]).astype("float32")
+        x2 = np.random.random(size=[4, 4]).astype("float32")
+        local_in1, dist_in1 = self.create_local_and_dist_tensor_pair(x1)
+        local_in2, dist_in2 = self.create_local_and_dist_tensor_pair(x2)
+
+        local_out1, local_out2 = paddle.broadcast_tensors(
+            [local_in1, local_in2]
+        )
+        dist_out1, dist_out2 = paddle.broadcast_tensors([dist_in1, dist_in2])
+        self.check_tensor_eq(local_out1, dist_out1)
+        self.check_tensor_eq(local_out2, dist_out2)
+
+        local_out = local_out1 + local_out2
+        dist_out = dist_out1 + dist_out2
+
+        local_out.backward()
+        dist_out.backward()
+        self.check_tensor_eq(local_in1.grad, dist_in1.grad)
+        self.check_tensor_eq(local_in2.grad, dist_in2.grad)
+
+    # input: phi::Tensor, output: std::vector<phi::Tensor>
+    def test_unbind_api_for_dist_tensor(self):
+        x = np.random.random(size=[2, 8]).astype("float32")
+        local_in, dist_in = self.create_local_and_dist_tensor_pair(x)
+        local_out1, local_out2 = paddle.unbind(local_in, axis=0)
+        dist_out1, dist_out2 = paddle.unbind(dist_in, axis=0)
+        self.check_tensor_eq(local_out1, dist_out1)
+        self.check_tensor_eq(local_out2, dist_out2)
+
+        local_out = local_out1 + local_out2
+        dist_out = dist_out1 + dist_out2
+
+        local_out.backward()
+        dist_out.backward()
+        self.check_tensor_eq(local_in.grad, dist_in.grad)
+
     def test_matmul_api_for_dist_tensor(self):
         x = np.random.random(size=[4, 4]).astype("float32")
         y = np.random.random(size=[4, 4]).astype("float32")
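
For reference, the gate that decides whether an op receives the auto-parallel branch can be read in isolation. The sketch below is a simplified, hypothetical re-statement of check_argument_whether_support_auto_parallel and skip_op_lists from dist_api_gen.py above (it is not code from the patch); it shows why concat (vector input, single output) and unbind (single input, vector output) now qualify while einsum is explicitly skipped:

    # Simplified, self-contained re-statement of the new support check; the real
    # logic lives in DistForwardAPI.check_argument_whether_support_auto_parallel.
    SKIP_OPS = {"check_finite_and_unscale", "coalesce_tensor",
                "update_loss_scaling", "einsum", "einsum_grad"}
    SUPPORTED_INPUTS = {"const Tensor&", "const std::vector<Tensor>&"}
    SUPPORTED_OUTPUTS = {"Tensor", "std::vector<Tensor>"}

    def supports_auto_parallel(kernel_name, input_types, output_types):
        if any(t not in SUPPORTED_INPUTS for t in input_types):
            return False
        if any(t not in SUPPORTED_OUTPUTS for t in output_types):
            return False
        return kernel_name not in SKIP_OPS

    assert supports_auto_parallel("concat", ["const std::vector<Tensor>&"], ["Tensor"])
    assert supports_auto_parallel("unbind", ["const Tensor&"], ["std::vector<Tensor>"])
    assert not supports_auto_parallel("einsum", ["const std::vector<Tensor>&"], ["Tensor"])
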