From 50d5bf7959e660fff3d49d70fa73e1f3b132c0c2 Mon Sep 17 00:00:00 2001
From: Chen Weihang
Date: Fri, 4 Mar 2022 10:35:30 +0800
Subject: [PATCH] [Phi] Change input vec tensor to pointer type (#40078)

* change input vec tensor to pointer

* update input between

* fix format error

* resolve conflict

* resolve conflict
---
 paddle/infrt/host_context/value.h             |  2 +-
 paddle/phi/api/lib/api_gen_utils.cc           |  6 +--
 paddle/phi/api/lib/api_gen_utils.h            |  2 +-
 paddle/phi/core/kernel_context.h              |  9 ++---
 paddle/phi/core/kernel_registry.h             |  4 +-
 paddle/phi/core/kernel_utils.h                | 40 +++++++++----------
 .../kernels/broadcast_tensors_grad_kernel.h   |  2 +-
 paddle/phi/kernels/broadcast_tensors_kernel.h |  2 +-
 paddle/phi/kernels/concat_kernel.h            |  8 ++--
 .../cpu/broadcast_tensors_grad_kernel.cc      |  4 +-
 paddle/phi/kernels/cpu/concat_kernel.cc       | 31 +++++++-------
 .../gpu/broadcast_tensors_grad_kernel.cu      |  4 +-
 paddle/phi/kernels/gpu/concat_kernel.cu       | 30 +++++++-------
 .../impl/broadcast_tensors_kernel_impl.h      | 10 ++---
 paddle/phi/tests/core/test_custom_kernel.cc   |  2 +-
 .../phi/tests/kernels/test_concat_dev_api.cc  |  2 +-
 python/paddle/utils/code_gen/api_base.py      | 28 +++++++++++--
 17 files changed, 103 insertions(+), 83 deletions(-)

diff --git a/paddle/infrt/host_context/value.h b/paddle/infrt/host_context/value.h
index 7e7d77d3af..0ae482349c 100644
--- a/paddle/infrt/host_context/value.h
+++ b/paddle/infrt/host_context/value.h
@@ -70,7 +70,7 @@ using ValueVariantType =
                      backends::CpuPhiAllocator,
                      backends::CpuPhiContext,
                      ::phi::CPUContext,
-                     std::vector<phi::DenseTensor>,
+                     std::vector<const phi::DenseTensor*>,
                      paddle::experimental::ScalarBase<phi::DenseTensor>,
                      paddle::experimental::ScalarArrayBase<phi::DenseTensor>,
                      std::vector<phi::MetaTensor*>,
diff --git a/paddle/phi/api/lib/api_gen_utils.cc b/paddle/phi/api/lib/api_gen_utils.cc
index f04e74b45f..e1ebe8c646 100644
--- a/paddle/phi/api/lib/api_gen_utils.cc
+++ b/paddle/phi/api/lib/api_gen_utils.cc
@@ -71,11 +71,11 @@ paddle::optional<phi::MetaTensor> MakeMetaTensor(
 }
 
 std::vector<phi::MetaTensor> MakeMetaTensor(
-    const std::vector<phi::DenseTensor>& tensors) {
+    const std::vector<const phi::DenseTensor*>& tensors) {
   std::vector<phi::MetaTensor> meta_tensors;
   meta_tensors.reserve(tensors.size());
-  for (const auto& t : tensors) {
-    meta_tensors.emplace_back(t);
+  for (const auto* t : tensors) {
+    meta_tensors.emplace_back(*t);
   }
   return meta_tensors;
 }
diff --git a/paddle/phi/api/lib/api_gen_utils.h b/paddle/phi/api/lib/api_gen_utils.h
index 109c6e7ab7..01625f651c 100644
--- a/paddle/phi/api/lib/api_gen_utils.h
+++ b/paddle/phi/api/lib/api_gen_utils.h
@@ -51,7 +51,7 @@ paddle::optional<phi::MetaTensor> MakeMetaTensor(
     const paddle::optional<const phi::DenseTensor&>& tensor);
 
 std::vector<phi::MetaTensor> MakeMetaTensor(
-    const std::vector<phi::DenseTensor>& tensors);
+    const std::vector<const phi::DenseTensor*>& tensors);
 
 phi::MetaTensor MakeMetaTensor(const phi::SelectedRows& tensor);
 
diff --git a/paddle/phi/core/kernel_context.h b/paddle/phi/core/kernel_context.h
index 57e2db60c2..213ac47d30 100644
--- a/paddle/phi/core/kernel_context.h
+++ b/paddle/phi/core/kernel_context.h
@@ -82,12 +82,11 @@ class KernelContext {
   }
 
   template <typename TensorType>
-  std::vector<TensorType> MoveInputsBetween(size_t start, size_t end) {
-    std::vector<TensorType> v;
+  std::vector<const TensorType*> InputsBetween(size_t start, size_t end) {
+    std::vector<const TensorType*> v;
     for (size_t i = start; i < end; ++i) {
-      auto t = static_cast<const TensorType*>(inputs_.at(i));
-      v.emplace_back(*t);
-      inputs_[i] = nullptr;
+      auto* t = static_cast<const TensorType*>(inputs_.at(i));
+      v.emplace_back(t);
     }
     return v;
   }
diff --git a/paddle/phi/core/kernel_registry.h b/paddle/phi/core/kernel_registry.h
index 2b04d173af..35e170a3fc 100644
--- a/paddle/phi/core/kernel_registry.h
+++ b/paddle/phi/core/kernel_registry.h
@@ -87,8 +87,8 @@ struct KernelArgsParseFunctor {
                               default_tensor_layout,
                              default_key.dtype(),
                              arg_type);
-      } else if (arg_type ==
-                 std::type_index(typeid(const std::vector<DenseTensor>&))) {
+      } else if (arg_type == std::type_index(typeid(
+                                 const std::vector<const DenseTensor*>&))) {
         args_def->AppendInput(default_key.backend(),
                               default_tensor_layout,
                               default_key.dtype(),
diff --git a/paddle/phi/core/kernel_utils.h b/paddle/phi/core/kernel_utils.h
index b582375155..f7fa27b074 100644
--- a/paddle/phi/core/kernel_utils.h
+++ b/paddle/phi/core/kernel_utils.h
@@ -102,26 +102,26 @@ namespace phi {
     }                                                                      \
   }
 
-#define PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(tensor_type)          \
-  template <typename... Tail>                                                \
-  struct KernelCallHelper<const std::vector<tensor_type>&, Tail...> {        \
-    template <int dev_ctx_idx,                                               \
-              int in_idx,                                                    \
-              int attr_idx,                                                  \
-              int out_idx,                                                   \
-              typename... PreviousArgs>                                      \
-    static void Compute(KernelContext* ctx, PreviousArgs&... pargs) {        \
-      static_assert(attr_idx == 0,                                           \
-                    "Kernel's Input should appear before Attributes.");      \
-      static_assert(out_idx == 0,                                            \
-                    "Kernel's Input should appear before Outputs.");         \
-      const std::pair<int, int> range = ctx->InputRangeAt(in_idx);           \
-      std::vector<tensor_type> arg = std::move(                              \
-          ctx->MoveInputsBetween<tensor_type>(range.first, range.second));   \
-      KernelCallHelper<Tail...>::                                            \
-          template Compute<dev_ctx_idx, in_idx + 1, attr_idx, out_idx>(      \
-              ctx, pargs..., arg);                                           \
-    }                                                                        \
   }
+#define PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(tensor_type)          \
+  template <typename... Tail>                                                \
+  struct KernelCallHelper<const std::vector<const tensor_type*>&, Tail...> { \
+    template <int dev_ctx_idx,                                               \
+              int in_idx,                                                    \
+              int attr_idx,                                                  \
+              int out_idx,                                                   \
+              typename... PreviousArgs>                                      \
+    static void Compute(KernelContext* ctx, PreviousArgs&... pargs) {        \
+      static_assert(attr_idx == 0,                                           \
+                    "Kernel's Input should appear before Attributes.");      \
+      static_assert(out_idx == 0,                                            \
+                    "Kernel's Input should appear before Outputs.");         \
+      const std::pair<int, int> range = ctx->InputRangeAt(in_idx);           \
+      std::vector<const tensor_type*> arg = std::move(                       \
+          ctx->InputsBetween<tensor_type>(range.first, range.second));       \
+      KernelCallHelper<Tail...>::                                            \
+          template Compute<dev_ctx_idx, in_idx + 1, attr_idx, out_idx>(      \
+              ctx, pargs..., arg);                                           \
+    }                                                                        \
   }
 
 #define PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(attr_type)              \
diff --git a/paddle/phi/kernels/broadcast_tensors_grad_kernel.h b/paddle/phi/kernels/broadcast_tensors_grad_kernel.h
index 5ec2e35cc9..5d24f6684a 100644
--- a/paddle/phi/kernels/broadcast_tensors_grad_kernel.h
+++ b/paddle/phi/kernels/broadcast_tensors_grad_kernel.h
@@ -21,7 +21,7 @@ namespace phi {
 
 template <typename T, typename Context>
 void BroadcastTensorsGradKernel(const Context& ctx,
-                                const std::vector<DenseTensor>& dout,
+                                const std::vector<const DenseTensor*>& dout,
                                 std::vector<DenseTensor*> dx);
 
 }  // namespace phi
diff --git a/paddle/phi/kernels/broadcast_tensors_kernel.h b/paddle/phi/kernels/broadcast_tensors_kernel.h
index fb2a6f1136..22b5201b69 100644
--- a/paddle/phi/kernels/broadcast_tensors_kernel.h
+++ b/paddle/phi/kernels/broadcast_tensors_kernel.h
@@ -21,7 +21,7 @@ namespace phi {
 
 template <typename T, typename Context>
 void BroadcastTensorsKernel(const Context& ctx,
-                            const std::vector<DenseTensor>& x,
+                            const std::vector<const DenseTensor*>& x,
                             std::vector<DenseTensor*> out);
 
 }  // namespace phi
diff --git a/paddle/phi/kernels/concat_kernel.h b/paddle/phi/kernels/concat_kernel.h
index f136678814..ed969e963e 100644
--- a/paddle/phi/kernels/concat_kernel.h
+++ b/paddle/phi/kernels/concat_kernel.h
@@ -22,19 +22,19 @@ namespace phi {
 
 template <typename T, typename Context>
 void ConcatKernel(const Context& dev_ctx,
-                  const std::vector<DenseTensor>& x,
+                  const std::vector<const DenseTensor*>& x,
                   const Scalar& axis,
                   DenseTensor* out);
 
 template <typename T, typename Context>
 DenseTensor Concat(const Context& dev_ctx,
-                   const std::vector<DenseTensor>& x,
+                   const std::vector<const DenseTensor*>& x,
                    const Scalar& axis) {
   std::vector<MetaTensor> meta_x;
   meta_x.reserve(x.size());
   std::vector<MetaTensor*> meta_x_ptr;
-  for (const auto& t : x) {
-    meta_x.emplace_back(t);
+  for (const auto* t : x) {
+    meta_x.emplace_back(*t);
     meta_x_ptr.push_back(&meta_x.back());
   }
 
diff --git a/paddle/phi/kernels/cpu/broadcast_tensors_grad_kernel.cc b/paddle/phi/kernels/cpu/broadcast_tensors_grad_kernel.cc
index 7a97f8c218..0869cd6202 100644
--- a/paddle/phi/kernels/cpu/broadcast_tensors_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/broadcast_tensors_grad_kernel.cc
@@ -59,7 +59,7 @@ namespace phi {
 
 template <typename T, typename Context>
 void BroadcastTensorsGradKernel(const Context& ctx,
-                                const std::vector<DenseTensor>& dout,
+                                const std::vector<const DenseTensor*>& dout,
                                 std::vector<DenseTensor*> dx) {
   // Find reduce dimensions
   const auto& in_tensors = dout;
@@ -85,7 +85,7 @@ void BroadcastTensorsGradKernel(const Context& ctx,
   // For each In-Out tensor pair,
   // Prepare and apply broadcast dims array
   for (size_t i = 0; i < num_ins; i++) {
-    const auto* input_tensor = &in_tensors[i];
+    const auto* input_tensor = in_tensors[i];
     auto* output_tensor = out_tensors[i];
 
     const auto& input_dims = input_tensor->dims();
diff --git a/paddle/phi/kernels/cpu/concat_kernel.cc b/paddle/phi/kernels/cpu/concat_kernel.cc
index 5c4202837c..6be825d4ef 100644
--- a/paddle/phi/kernels/cpu/concat_kernel.cc
+++ b/paddle/phi/kernels/cpu/concat_kernel.cc
@@ -29,17 +29,17 @@ namespace phi {
 
 template <typename T, typename Context>
 void ConcatKernel(const Context& dev_ctx,
-                  const std::vector<DenseTensor>& x,
+                  const std::vector<const DenseTensor*>& x,
                   const Scalar& axis_scalar,
                   DenseTensor* out) {
   int64_t axis = axis_scalar.to<int64_t>();
-  axis = phi::funcs::ComputeAxis(axis, x[0].dims().size());
+  axis = phi::funcs::ComputeAxis(axis, x[0]->dims().size());
 
   std::vector<phi::DDim> x_dims;
   x_dims.reserve(x.size());
   for (size_t i = 0; i < x.size(); ++i) {
-    x_dims.push_back(x[i].dims());
+    x_dims.push_back(x[i]->dims());
   }
 
   phi::DDim out_dims = phi::funcs::ComputeAndCheckShape(true, x_dims, axis);
@@ -47,13 +47,13 @@ void ConcatKernel(const Context& dev_ctx,
   out->mutable_data<T>(dev_ctx.GetPlace());
 
   // If axis is 0, the lod of the output is not the same as inputs.
-  if (axis == 0 && x[0].lod().size() > 0) {
-    size_t lod_size_0 = x[0].lod().size();
+  if (axis == 0 && x[0]->lod().size() > 0) {
+    size_t lod_size_0 = x[0]->lod().size();
     size_t lod_size = lod_size_0;
     for (size_t i = 1; i < x.size(); ++i) {
-      if (x[i].lod().size() > 0) {
+      if (x[i]->lod().size() > 0) {
         PADDLE_ENFORCE_EQ(
-            x[i].lod().size(),
+            x[i]->lod().size(),
             lod_size_0,
             phi::errors::Unimplemented(
                 "The lod level of all input LoDTensors should be same. "
@@ -61,7 +61,7 @@ void ConcatKernel(const Context& dev_ctx,
                 "it is not supported currently. The lod level of %dth input "
                 "is %d and first input is %d.",
                 i,
-                x[i].lod().size(),
+                x[i]->lod().size(),
                 lod_size_0));
       } else {
         lod_size = 0;
@@ -71,7 +71,7 @@ void ConcatKernel(const Context& dev_ctx,
     if (lod_size) {
       auto* out_lod = out->mutable_lod();
       for (size_t i = 1; i < x.size(); ++i) {
-        auto in_lod = phi::ConvertToLengthBasedLoD(x[i].lod());
+        auto in_lod = phi::ConvertToLengthBasedLoD(x[i]->lod());
         phi::AppendLoD(out_lod, in_lod);
       }
     }
@@ -80,28 +80,29 @@ void ConcatKernel(const Context& dev_ctx,
   // Sometimes direct copies will be faster, this maybe need deeply analysis.
   if (axis == 0 && x.size() < 10) {
     size_t output_offset = 0;
-    for (auto& in : x) {
-      if (in.numel() == 0UL) {
+    for (const auto* in : x) {
+      if (in->numel() == 0UL) {
         continue;
       }
-      auto in_stride = phi::stride_numel(in.dims());
+      auto in_stride = phi::stride_numel(in->dims());
       auto out_stride = phi::stride_numel(out->dims());
       paddle::operators::StridedNumelCopyWithAxis<T>(
           dev_ctx,
           axis,
           out->data<T>() + output_offset,
           out_stride,
-          in.data<T>(),
+          in->data<T>(),
           in_stride,
           in_stride[axis]);
       output_offset += in_stride[axis];
     }
   } else {
+    // TODO(chenweihang): concat functor support vector input
     std::vector<phi::DenseTensor> inputs;
     inputs.reserve(x.size());
     for (size_t j = 0; j < x.size(); ++j) {
-      if (x[j].numel() > 0) {
-        inputs.emplace_back(x[j]);
+      if (x[j]->numel() > 0) {
+        inputs.emplace_back(*x[j]);
       } else {
         continue;
       }
diff --git a/paddle/phi/kernels/gpu/broadcast_tensors_grad_kernel.cu b/paddle/phi/kernels/gpu/broadcast_tensors_grad_kernel.cu
index 6fb24d7214..275b8411cc 100644
--- a/paddle/phi/kernels/gpu/broadcast_tensors_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/broadcast_tensors_grad_kernel.cu
@@ -27,7 +27,7 @@ namespace phi {
 
 template <typename T, typename Context>
 void BroadcastTensorsGradKernel(const Context& ctx,
-                                const std::vector<DenseTensor>& dout,
+                                const std::vector<const DenseTensor*>& dout,
                                 std::vector<DenseTensor*> dx) {
   // Find reduce dimensions
   const auto& in_tensors = dout;
@@ -54,7 +54,7 @@ void BroadcastTensorsGradKernel(const Context& ctx,
   // For each In-Out tensor pair,
   // Prepare and apply broadcast dims array
   for (size_t i = 0; i < num_ins; i++) {
-    auto* input_tensor = &in_tensors[i];
+    auto* input_tensor = in_tensors[i];
     auto* output_tensor = out_tensors[i];
 
     const DDim& input_dims = input_tensor->dims();
diff --git a/paddle/phi/kernels/gpu/concat_kernel.cu b/paddle/phi/kernels/gpu/concat_kernel.cu
index 2b04b979c2..accb1cc3d7 100644
--- a/paddle/phi/kernels/gpu/concat_kernel.cu
+++ b/paddle/phi/kernels/gpu/concat_kernel.cu
@@ -29,16 +29,16 @@ namespace phi {
 
 template <typename T, typename Context>
 void ConcatKernel(const Context& dev_ctx,
-                  const std::vector<DenseTensor>& x,
+                  const std::vector<const DenseTensor*>& x,
                   const Scalar& axis_scalar,
                   DenseTensor* out) {
   int64_t axis = axis_scalar.to<int64_t>();
-  axis = phi::funcs::ComputeAxis(axis, x[0].dims().size());
+  axis = phi::funcs::ComputeAxis(axis, x[0]->dims().size());
 
   std::vector<phi::DDim> x_dims;
   for (size_t i = 0; i < x.size(); ++i) {
-    x_dims.push_back(x[i].dims());
+    x_dims.push_back(x[i]->dims());
   }
 
   phi::DDim out_dims = phi::funcs::ComputeAndCheckShape(true, x_dims, axis);
@@ -46,13 +46,13 @@ void ConcatKernel(const Context& dev_ctx,
   out->mutable_data<T>(dev_ctx.GetPlace());
 
   // If axis is 0, the lod of the output is not the same as inputs.
-  if (axis == 0 && x[0].lod().size() > 0) {
-    size_t lod_size_0 = x[0].lod().size();
+  if (axis == 0 && x[0]->lod().size() > 0) {
+    size_t lod_size_0 = x[0]->lod().size();
     size_t lod_size = lod_size_0;
     for (size_t i = 1; i < x.size(); ++i) {
-      if (x[i].lod().size() > 0) {
+      if (x[i]->lod().size() > 0) {
         PADDLE_ENFORCE_EQ(
-            x[i].lod().size(),
+            x[i]->lod().size(),
             lod_size_0,
             phi::errors::Unimplemented(
                 "The lod level of all input LoDTensors should be same. "
@@ -60,7 +60,7 @@ void ConcatKernel(const Context& dev_ctx,
                 "it is not supported currently. The lod level of %dth input "
                 "is %d and first input is %d.",
                 i,
-                x[i].lod().size(),
+                x[i]->lod().size(),
                 lod_size_0));
       } else {
         lod_size = 0;
@@ -70,7 +70,7 @@ void ConcatKernel(const Context& dev_ctx,
     if (lod_size) {
       auto* out_lod = out->mutable_lod();
       for (size_t i = 1; i < x.size(); ++i) {
-        auto in_lod = phi::ConvertToLengthBasedLoD(x[i].lod());
+        auto in_lod = phi::ConvertToLengthBasedLoD(x[i]->lod());
         phi::AppendLoD(out_lod, in_lod);
       }
     }
@@ -79,18 +79,18 @@ void ConcatKernel(const Context& dev_ctx,
   // Sometimes direct copies will be faster, this maybe need deeply analysis.
   if (axis == 0 && x.size() < 10) {
     size_t output_offset = 0;
-    for (auto& in : x) {
-      if (in.numel() == 0UL) {
+    for (auto* in : x) {
+      if (in->numel() == 0UL) {
         continue;
       }
-      auto in_stride = phi::stride_numel(in.dims());
+      auto in_stride = phi::stride_numel(in->dims());
       auto out_stride = phi::stride_numel(out->dims());
       paddle::operators::StridedNumelCopyWithAxis<T>(
           dev_ctx,
           axis,
           out->data<T>() + output_offset,
           out_stride,
-          in.data<T>(),
+          in->data<T>(),
           in_stride,
           in_stride[axis]);
       output_offset += in_stride[axis];
@@ -98,8 +98,8 @@ void ConcatKernel(const Context& dev_ctx,
   } else {
     std::vector<phi::DenseTensor> inputs;
     for (size_t j = 0; j < x.size(); ++j) {
-      if (x[j].numel() > 0) {
-        inputs.push_back(x[j]);
+      if (x[j]->numel() > 0) {
+        inputs.push_back(*x[j]);
       } else {
         continue;
       }
diff --git a/paddle/phi/kernels/impl/broadcast_tensors_kernel_impl.h b/paddle/phi/kernels/impl/broadcast_tensors_kernel_impl.h
index eb01b83377..d7167704a4 100644
--- a/paddle/phi/kernels/impl/broadcast_tensors_kernel_impl.h
+++ b/paddle/phi/kernels/impl/broadcast_tensors_kernel_impl.h
@@ -23,10 +23,10 @@
 #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 
-#define SWITCH_OUT_RANK_CASE(n)                                            \
-  case n: {                                                                \
-    ApplyBroadcast<T, Context, n>(ctx, &in_tensors[i], out_tensors[i]);    \
-    break;                                                                 \
+#define SWITCH_OUT_RANK_CASE(n)                                            \
+  case n: {                                                                \
+    ApplyBroadcast<T, Context, n>(ctx, in_tensors[i], out_tensors[i]);     \
+    break;                                                                 \
   }
 
 namespace phi {
@@ -75,7 +75,7 @@ void ApplyBroadcast(const Context& ctx,
 
 template <typename T, typename Context>
 void BroadcastTensorsKernel(const Context& ctx,
-                            const std::vector<DenseTensor>& x,
+                            const std::vector<const DenseTensor*>& x,
                             std::vector<DenseTensor*> out) {
   const auto& in_tensors = x;
   auto out_tensors = out;
diff --git a/paddle/phi/tests/core/test_custom_kernel.cc b/paddle/phi/tests/core/test_custom_kernel.cc
index 69922c055c..a4e89231e1 100644
--- a/paddle/phi/tests/core/test_custom_kernel.cc
+++ b/paddle/phi/tests/core/test_custom_kernel.cc
@@ -43,7 +43,7 @@ template <typename T, typename Context>
 void FakeDot(const Context& dev_ctx,
              const phi::DenseTensor& x,
             const phi::DenseTensor& y,
-             const std::vector<phi::DenseTensor>& fake_input_vec,
+             const std::vector<const phi::DenseTensor*>& fake_input_vec,
             bool fake_attr_bool,
             int fake_attr_int,
            float fake_attr_float,
diff --git a/paddle/phi/tests/kernels/test_concat_dev_api.cc b/paddle/phi/tests/kernels/test_concat_dev_api.cc
index 55dd6dce1a..7f954085f6 100644
--- a/paddle/phi/tests/kernels/test_concat_dev_api.cc
+++ b/paddle/phi/tests/kernels/test_concat_dev_api.cc
@@ -53,7 +53,7 @@ TEST(DEV_API, concat) {
     }
   }
 
-  std::vector<phi::DenseTensor> inputs = {dense_x, dense_y};
+  std::vector<const phi::DenseTensor*> inputs = {&dense_x, &dense_y};
 
   // 2. test API
   phi::CPUContext dev_ctx;
diff --git a/python/paddle/utils/code_gen/api_base.py b/python/paddle/utils/code_gen/api_base.py
index 6c07cdec2e..601248a417 100644
--- a/python/paddle/utils/code_gen/api_base.py
+++ b/python/paddle/utils/code_gen/api_base.py
@@ -458,7 +458,7 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name() + '_'}({self.
             elif self.inputs['input_info'][
                     param] == "const std::vector<Tensor>&":
                 meta_tensor_code = meta_tensor_code + f"""
-{code_indent}  auto {param}_meta_vec = MakeMetaTensor(*{PREFIX_TENSOR_NAME}{param});
+{code_indent}  auto {param}_meta_vec = MakeMetaTensor({PREFIX_TENSOR_NAME}{param});
 {code_indent}  std::vector<phi::MetaTensor*> {param}_metas({param}_meta_vec.size());
 {code_indent}  for (size_t i = 0; i < {param}_meta_vec.size(); ++i) {{
 {code_indent}    {param}_metas[i] = &{param}_meta_vec[i];
@@ -502,7 +502,7 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name() + '_'}({self.
         input_trans_map = {
             'const Tensor&': 'const phi::DenseTensor&',
             'const std::vector<Tensor>&':
-            'const std::vector<phi::DenseTensor>&',
+            'const std::vector<const phi::DenseTensor*>&',
             'const paddle::optional<Tensor>&':
             'paddle::optional<const phi::DenseTensor&>',
             'const paddle::optional<std::vector<Tensor>>&':
@@ -539,9 +539,22 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name() + '_'}({self.
 {code_indent}  }}"""
 
             else:
-                input_tensor_code = input_tensor_code + f"""
+                if self.inputs['input_info'][input_name] == "const Tensor&":
+                    input_tensor_code = input_tensor_code + f"""
 {code_indent}  auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({i}), {trans_flag});"""
 
+                elif self.inputs['input_info'][
+                        input_name] == "const std::vector<Tensor>&":
+                    input_tensor_code = input_tensor_code + f"""
+{code_indent}  auto {PREFIX_TENSOR_NAME}{input_name}_vec = PrepareData({input_name}, kernel.InputAt({i}), {trans_flag});
+{code_indent}  std::vector<const phi::DenseTensor*> {PREFIX_TENSOR_NAME}{input_name}({PREFIX_TENSOR_NAME}{input_name}_vec->size());
+{code_indent}  for (size_t i = 0; i < {PREFIX_TENSOR_NAME}{input_name}.size(); ++i) {{
+{code_indent}    {PREFIX_TENSOR_NAME}{input_name}[i] = &{PREFIX_TENSOR_NAME}{input_name}_vec->at(i);
+{code_indent}  }}"""
+
+                else:
+                    # do nothing
+                    pass
         else:
             if input_name in self.optional_vars:
                 input_tensor_code = input_tensor_code + f"""
@@ -561,7 +574,14 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name() + '_'}({self.
             if param in self.optional_vars:
                 kernel_args = kernel_args + PREFIX_TENSOR_NAME + param + ", "
             else:
-                kernel_args = kernel_args + "*" + PREFIX_TENSOR_NAME + param + ", "
+                if self.inputs['input_info'][param] == "const Tensor&":
+                    kernel_args = kernel_args + "*" + PREFIX_TENSOR_NAME + param + ", "
+                elif self.inputs['input_info'][
+                        input_name] == "const std::vector<Tensor>&":
+                    kernel_args = kernel_args + PREFIX_TENSOR_NAME + param + ", "
+                else:
+                    # do nothing
+                    pass
             kernel_in_type = input_trans_map[input_infos[param]]
             kernel_args_type_list.append(kernel_in_type)
         elif param in attr_names:
--
GitLab
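
Editor's note on the calling convention: the substance of this patch is that multi-input kernel parameters change from const std::vector<DenseTensor>& to const std::vector<const DenseTensor*>&, and KernelContext::MoveInputsBetween becomes the non-owning InputsBetween, so the kernel context keeps ownership of its inputs instead of moving copies out. The following is a minimal, self-contained sketch of that convention in plain C++, not Paddle code; Tensor and ConcatSum below are hypothetical stand-ins for phi::DenseTensor and a multi-input kernel.

// concat_sum_sketch.cc -- illustrates passing inputs as const pointers.
#include <cstddef>
#include <iostream>
#include <numeric>
#include <vector>

struct Tensor {
  std::vector<float> data;
  std::size_t numel() const { return data.size(); }
};

// Before the patch a multi-input kernel received const std::vector<Tensor>&,
// which required materializing a vector of tensor copies. After the patch it
// receives a vector of const pointers, so no copy is made and the caller
// retains ownership.
float ConcatSum(const std::vector<const Tensor*>& inputs) {
  float total = 0.f;
  for (const auto* in : inputs) {    // pointer iteration, as in the updated
    if (in->numel() == 0) {          // ConcatKernel loops in this patch
      continue;
    }
    total = std::accumulate(in->data.begin(), in->data.end(), total);
  }
  return total;
}

int main() {
  Tensor x{{1.f, 2.f, 3.f}};
  Tensor y{{4.f, 5.f}};
  // Call sites now pass addresses, mirroring the test_concat_dev_api.cc change:
  //   std::vector<const phi::DenseTensor*> inputs = {&dense_x, &dense_y};
  std::vector<const Tensor*> inputs = {&x, &y};
  std::cout << ConcatSum(inputs) << std::endl;  // prints 15
  return 0;
}

The same trade-off is visible in kernel_context.h above: because the vector only holds pointers, the removed line inputs_[i] = nullptr; is no longer needed, and the context's stored inputs remain valid after the kernel is invoked.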