diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt
index d5336af8f05ef7fce1d5b1a2153cb8928772e232..d7d1093b9b3bf2f9f605c7c45c6d5f8a4e52bb6a 100644
--- a/paddle/fluid/operators/math/CMakeLists.txt
+++ b/paddle/fluid/operators/math/CMakeLists.txt
@@ -6,9 +6,9 @@ endif()
 
 # please add new math_library in alphabetical order
 if (WITH_ASCEND_CL)
-math_library(concat_and_split DEPS npu_op_runner)
+math_library(concat_and_split DEPS concat_and_split_functor npu_op_runner)
 else()
-math_library(concat_and_split)
+math_library(concat_and_split DEPS concat_and_split_functor)
 endif()
 math_library(context_project DEPS im2col math_function)
 math_library(cross_entropy)
diff --git a/paddle/fluid/operators/math/concat_and_split.cc b/paddle/fluid/operators/math/concat_and_split.cc
index 8ec89f1b60acebdb0d1da8b6a07113b1f4c23ef0..46126ac59c892787d2f63956983404843e518ae7 100644
--- a/paddle/fluid/operators/math/concat_and_split.cc
+++ b/paddle/fluid/operators/math/concat_and_split.cc
@@ -14,7 +14,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/math/concat_and_split.h"
 
-#include "paddle/phi/kernels/cpu/concat_and_split.h"
+#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
 
 #ifdef PADDLE_WITH_ASCEND_CL
 #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
 #endif
@@ -46,9 +46,8 @@ class ConcatFunctor<platform::CPUDeviceContext, T> {
   void operator()(const platform::CPUDeviceContext& context,
                   const std::vector<framework::Tensor>& input, int axis,
                   framework::Tensor* output) {
-    std::vector<phi::DenseTensor> pt_input{input.begin(), input.end()};
-    phi::ConcatImpl<T, platform::CPUDeviceContext>(context, pt_input, axis,
-                                                   output);
+    phi::funcs::ConcatFunctor<phi::CPUContext, T> functor;
+    functor(context, input, axis, output);
   }
 };
 
@@ -63,11 +62,8 @@ class SplitFunctor<platform::CPUDeviceContext, T> {
                   const framework::Tensor& input,
                   const std::vector<const framework::Tensor*>& ref_inputs,
                   const int axis, std::vector<framework::Tensor*>* outputs) {
-    std::vector<const phi::DenseTensor*> pt_ref_inputs{ref_inputs.begin(),
-                                                       ref_inputs.end()};
-    std::vector<phi::DenseTensor*> pt_outputs{outputs->begin(), outputs->end()};
-    phi::SplitImpl<T, platform::CPUDeviceContext>(context, input, pt_ref_inputs,
-                                                  axis, &pt_outputs);
+    phi::funcs::SplitFunctor<phi::CPUContext, T> functor;
+    functor(context, input, ref_inputs, axis, outputs);
   }
 };
diff --git a/paddle/fluid/operators/math/concat_and_split.cu b/paddle/fluid/operators/math/concat_and_split.cu
index 51f94afcfc1b99755d5f9dca8460a56fc76cf543..e51631385eb75a63083e0cbbd2a8632d689be8f1 100644
--- a/paddle/fluid/operators/math/concat_and_split.cu
+++ b/paddle/fluid/operators/math/concat_and_split.cu
@@ -14,7 +14,7 @@ limitations under the License.
 */
 
 #include "paddle/fluid/operators/math/concat_and_split.h"
-#include "paddle/phi/kernels/gpu/concat_and_split.h"
+#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
 
 namespace paddle {
 namespace operators {
 namespace math {
@@ -29,10 +29,8 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
   void operator()(const platform::CUDADeviceContext& context,
                   const std::vector<framework::Tensor>& input, int axis,
                   framework::Tensor* output) {
-    std::vector<phi::DenseTensor> pt_input{input.begin(), input.end()};
-
-    phi::ConcatImpl<T, platform::CUDADeviceContext>(context, pt_input, axis,
-                                                    output);
+    phi::funcs::ConcatFunctor<phi::GPUContext, T> functor;
+    functor(context, input, axis, output);
   }
 };
 
@@ -43,16 +41,12 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
 template <typename T>
 class SplitFunctor<platform::CUDADeviceContext, T> {
  public:
-  SplitFunctor();
   void operator()(const platform::CUDADeviceContext& context,
                   const framework::Tensor& input,
                   const std::vector<const framework::Tensor*>& ref_inputs,
                   int axis, std::vector<framework::Tensor*>* outputs) {
-    std::vector<const phi::DenseTensor*> pt_ref_inputs{ref_inputs.begin(),
-                                                       ref_inputs.end()};
-    std::vector<phi::DenseTensor*> pt_outputs{outputs->begin(), outputs->end()};
-    phi::SplitImpl<T, platform::CUDADeviceContext>(
-        context, input, pt_ref_inputs, axis, &pt_outputs);
+    phi::funcs::SplitFunctor<phi::GPUContext, T> functor;
+    functor(context, input, ref_inputs, axis, outputs);
   }
 };
diff --git a/paddle/fluid/operators/math/concat_and_split.h b/paddle/fluid/operators/math/concat_and_split.h
index 65d2ca79e60c2ec90d879ce9818c398adc93c73c..b5b0aae23ac875c7afeb4148309138aae49e5b4a 100644
--- a/paddle/fluid/operators/math/concat_and_split.h
+++ b/paddle/fluid/operators/math/concat_and_split.h
@@ -64,17 +64,3 @@ class SplitFunctor {
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
-
-#define FOR_ALL_TYPES(macro)                  \
-  macro(int);                                 \
-  macro(float);                               \
-  macro(double);                              \
-  macro(bool);                                \
-  macro(int64_t);                             \
-  macro(int16_t);                             \
-  macro(uint8_t);                             \
-  macro(int8_t);                              \
-  macro(::paddle::platform::float16);         \
-  macro(::paddle::platform::bfloat16);        \
-  macro(::paddle::platform::complex<float>);  \
-  macro(::paddle::platform::complex<double>);
diff --git a/paddle/fluid/operators/unbind_op.cc b/paddle/fluid/operators/unbind_op.cc
index 3fce0f8f47d32a602d56e88b43ddb9bf3d4b15f8..f2fc08308c6b32868adc8057c9bc2a92c4247c60 100644
--- a/paddle/fluid/operators/unbind_op.cc
+++ b/paddle/fluid/operators/unbind_op.cc
@@ -14,6 +14,9 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/unbind_op.h"
 #include <memory>
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/unary.h"
 
 namespace paddle {
 namespace operators {
@@ -79,11 +82,3 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(unbind, ops::UnbindOp, ops::UnbindOpMaker,
                   ops::UnbindGradMaker<paddle::framework::OpDesc>,
                   ops::UnbindGradMaker<paddle::imperative::OpBase>);
-namespace plat = paddle::platform;
-REGISTER_OP_CPU_KERNEL(
-    unbind, ops::UnbindOpKernel<plat::CPUDeviceContext, double>,
-    ops::UnbindOpKernel<plat::CPUDeviceContext, float>,
-    ops::UnbindOpKernel<plat::CPUDeviceContext, int64_t>,
-    ops::UnbindOpKernel<plat::CPUDeviceContext, int>,
-    ops::UnbindOpKernel<plat::CPUDeviceContext, plat::float16>,
-    ops::UnbindOpKernel<plat::CPUDeviceContext, plat::bfloat16>);
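Editor's note (not part of the patch): after this change, the fluid-side math::ConcatFunctor and math::SplitFunctor are thin wrappers that delegate to the shared phi functors above. A minimal usage sketch of the new functor called directly; the context `ctx`, the element type, and the tensor shapes are assumptions for illustration:

#include <vector>
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"

// Concatenate two [2, 3] float tensors along axis 0 into a pre-allocated
// [4, 3] output. `ctx` must be an initialized phi::CPUContext and `out`
// must already have its memory allocated.
void ConcatExample(const phi::CPUContext& ctx,
                   const phi::DenseTensor& a,
                   const phi::DenseTensor& b,
                   phi::DenseTensor* out) {
  std::vector<phi::DenseTensor> inputs{a, b};
  phi::funcs::ConcatFunctor<phi::CPUContext, float> functor;
  functor(ctx, inputs, /*axis=*/0, out);
}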
diff --git a/paddle/fluid/operators/unbind_op.cu.cc b/paddle/fluid/operators/unbind_op.cu.cc
deleted file mode 100644
index cec7058d3cf52eff55eb88afaa217204a72e4566..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/unbind_op.cu.cc
+++ /dev/null
@@ -1,24 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/unbind_op.h"
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
-REGISTER_OP_CUDA_KERNEL(
-    unbind, ops::UnbindOpKernel<plat::CUDADeviceContext, double>,
-    ops::UnbindOpKernel<plat::CUDADeviceContext, float>,
-    ops::UnbindOpKernel<plat::CUDADeviceContext, int64_t>,
-    ops::UnbindOpKernel<plat::CUDADeviceContext, int>,
-    ops::UnbindOpKernel<plat::CUDADeviceContext, plat::float16>,
-    ops::UnbindOpKernel<plat::CUDADeviceContext, plat::bfloat16>);
diff --git a/paddle/fluid/operators/unbind_op.h b/paddle/fluid/operators/unbind_op.h
index 69808e3f9fe9ed4a92152fc89532a7470bf85f6f..6e35f262de420744b5299fbf1ab540e34c711d92 100644
--- a/paddle/fluid/operators/unbind_op.h
+++ b/paddle/fluid/operators/unbind_op.h
@@ -34,27 +34,6 @@ static inline framework::DDim UnbindOutsDims(const framework::DDim in_dims,
   }
   return phi::make_ddim(out_dims);
 }
-template <typename DeviceContext, typename T>
-class UnbindOpKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* in = ctx.Input<framework::Tensor>("X");
-    auto outs = ctx.MultiOutput<framework::Tensor>("Out");
-    int axis = ctx.Attr<int>("axis");
-
-    auto in_dims = in->dims();
-    axis = axis < 0 ? in_dims.size() + axis : axis;
-    std::vector<const framework::Tensor*> shape_refer;
-    for (size_t j = 0; j < outs.size(); ++j) {
-      outs[j]->mutable_data<T>(ctx.GetPlace());
-      shape_refer.emplace_back(outs[j]);
-    }
-
-    auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    math::SplitFunctor<DeviceContext, T> functor;
-    functor(dev_ctx, *in, shape_refer, axis, &outs);
-  }
-};
 
 template <typename T>
 class UnbindGradMaker : public framework::SingleGradOpMaker<T> {
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index 1fbd6c2b6c2f5f5b3a86917c9ff35031da9b6b93..ca71d6a56d8e785ab18e047e6ae552f5994cc0f0 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -485,6 +485,25 @@ void SplitInferMeta(const MetaTensor& x,
   }
 }
 
+void UnbindInferMeta(const MetaTensor& x,
+                     int axis,
+                     std::vector<MetaTensor>* outs) {
+  auto in_dims = x.dims();
+  std::vector<int> out_dim;
+  axis = axis < 0 ? in_dims.size() + axis : axis;
+  for (int i = 0; i < in_dims.size(); ++i) {
+    if (i != axis) out_dim.push_back(in_dims[i]);
+  }
+  auto out_dims = phi::make_ddim(out_dim);
+
+  for (size_t i = 0; i < outs->size(); ++i) {
+    (*outs)[i].set_dtype(x.dtype());
+    (*outs)[i].set_dims(out_dims);
+    (*outs)[i].set_layout(x.layout());
+    (*outs)[i].share_lod(x);
+  }
+}
+
 void TraceInferMeta(
     const MetaTensor& x, int offset, int axis1, int axis2, MetaTensor* out) {
   int dim1 = axis1;
diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h
index c6d5d250d98aa7b490806bc20a38944589e19b9d..7d15f497ead146d2e081146d75984cb84b121cdb 100644
--- a/paddle/phi/infermeta/unary.h
+++ b/paddle/phi/infermeta/unary.h
@@ -90,6 +90,9 @@ void SplitInferMeta(const MetaTensor& x_meta,
                     std::vector<MetaTensor>* out,
                     MetaConfig config = MetaConfig());
 
+void UnbindInferMeta(const MetaTensor& x,
+                     int axis,
+                     std::vector<MetaTensor>* outs);
 void TraceInferMeta(
     const MetaTensor& x, int offset, int axis1, int axis2, MetaTensor* out);
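Editor's note (not part of the patch): UnbindInferMeta simply normalizes a negative axis and drops that axis from the input dims; every output shares the resulting shape plus the input's dtype, layout, and LoD. A standalone, runnable mirror of the dim arithmetic (plain std::vector instead of DDim, values assumed for illustration):

#include <cstdio>
#include <vector>

int main() {
  std::vector<int> in_dims = {2, 3, 4};
  int axis = -2;  // normalized below to 3 + (-2) = 1
  axis = axis < 0 ? static_cast<int>(in_dims.size()) + axis : axis;
  std::vector<int> out_dim;
  for (int i = 0; i < static_cast<int>(in_dims.size()); ++i) {
    if (i != axis) out_dim.push_back(in_dims[i]);
  }
  // unbinding axis 1 of [2, 3, 4] yields 3 outputs, each of shape [2, 4]
  std::printf("outputs: %d, each [%d, %d]\n", in_dims[axis], out_dim[0],
              out_dim[1]);
  return 0;
}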
diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt
index f819eb3de3ef7891ab3b21242d4a7bbcf7210cb9..ef085e71f5dcc295a417f0c6aa83fc7cdfc20a8d 100644
--- a/paddle/phi/kernels/CMakeLists.txt
+++ b/paddle/phi/kernels/CMakeLists.txt
@@ -10,7 +10,7 @@ add_subdirectory(funcs)
 set_property(GLOBAL PROPERTY PTEN_KERNELS "")
 
 set(COMMON_KERNEL_DEPS dense_tensor sparse_coo_tensor sparse_csr_tensor kernel_context kernel_factory arg_map_context convert_utils lod_utils)
-set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function im2col)
+set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function im2col concat_and_split_functor)
 # remove this dep after removing fluid deps on tensor creation
 set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} pten_api_utils)
 set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} infermeta)
diff --git a/paddle/phi/kernels/cpu/concat_and_split.h b/paddle/phi/kernels/cpu/concat_and_split.h
deleted file mode 100644
index 88cfc5db8f2e852ee26f2300afb5a93cf06274c1..0000000000000000000000000000000000000000
--- a/paddle/phi/kernels/cpu/concat_and_split.h
+++ /dev/null
@@ -1,138 +0,0 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-
-#include "paddle/phi/backends/cpu/cpu_context.h"
-#include "paddle/phi/core/dense_tensor.h"
-
-namespace phi {
-
-/*
- * \brief Concatenate the input tensors along the dimension axis.
- * TODO(zcd): maybe it needs to be more detailed.
- * Examples:
- *     Input[0] = [[1,2],[3,4]]
- *     Input[1] = [[5,6]]
- *     axis = 0
- *
- *     Output = [[1,2],
- *               [3,4],
- *               [5,6]]
- */
-
-template <typename T, typename Context>
-void ConcatImpl(const Context& context,
-                const std::vector<DenseTensor>& input,
-                int axis,
-                DenseTensor* output) {
-  // TODO(zcd): Add input data validity checking
-  size_t num = input.size();
-
-  int64_t rows = 1;
-  auto dim_0 = input[0].dims();
-  for (int i = 0; i < axis; ++i) {
-    rows *= dim_0[i];
-  }
-  int64_t out_rows = rows, out_cols = 0;
-
-  std::vector<int64_t> input_cols(input.size());
-  for (size_t i = 0; i < num; ++i) {
-    int64_t t_cols = input[i].numel() / rows;
-    out_cols += t_cols;
-    input_cols[i] = t_cols;
-  }
-  auto cpu_place = context.GetPlace();
-
-  // computation
-  auto output_data = output->data<T>();
-  int64_t col_idx = 0;
-  for (size_t j = 0; j < num; ++j) {
-    int64_t col_len = input_cols[j];
-    auto input_data = input[j].data<T>();
-    for (int64_t k = 0; k < out_rows; ++k) {
-      paddle::memory::Copy(cpu_place,
-                           output_data + k * out_cols + col_idx,
-                           cpu_place,
-                           input_data + k * col_len,
-                           sizeof(T) * col_len);
-    }
-    col_idx += col_len;
-  }
-}
-
-/*
- * \brief Split the input tensors along the dimension axis into outputs.
- * TODO(zcd): maybe it needs to be more detailed.
- * Examples:
- *     Input = [[1,2],
- *              [3,4],
- *              [5,6]]
- *     axis = 0
- *
- *     Output[0] = [[1,2],[3,4]]
- *     Output[1] = [[5,6]]
- */
-template <typename T, typename Context>
-void SplitImpl(const Context& context,
-               const DenseTensor& input,
-               const std::vector<const DenseTensor*>& ref_inputs,
-               const int axis,
-               std::vector<DenseTensor*>* outputs) {
-  // NOTE(zhiqiu): split a tensor of shape [0,3,4] at axis=1, result in 3
-  // tensors of shape [0,1,4]
-  if (input.numel() == 0) {
-    return;
-  }
-
-  // TODO(zcd): Add input data validity checking
-  size_t num = outputs->size();
-
-  int input_rows = 1;
-  auto dim_0 = ref_inputs[0]->dims();
-  for (int i = 0; i < axis; ++i) {
-    input_rows *= dim_0[i];
-  }
-
-  int input_cols = 0;
-
-  std::vector<int64_t> output_cols(outputs->size());
-  for (size_t i = 0; i < num; ++i) {
-    int t_cols = ref_inputs[i]->numel() / input_rows;
-    input_cols += t_cols;
-    output_cols[i] = t_cols;
-  }
-  auto cpu_place = context.GetPlace();
-
-  // computation
-  for (int k = 0; k < input_rows; ++k) {
-    const T* src_ptr = input.data<T>() + k * input_cols;
-    int col_idx = 0;
-    for (size_t j = 0; j < num; ++j) {
-      int col_len = output_cols[j];
-      auto* out_tensor = outputs->at(j);
-      if (out_tensor != nullptr) {
-        T* dst_ptr = out_tensor->data<T>() + k * col_len;
-        paddle::memory::Copy(cpu_place,
-                             dst_ptr,
-                             cpu_place,
-                             src_ptr + col_idx,
-                             sizeof(T) * col_len);
-      }
-      col_idx += col_len;
-    }
-  }
-}
-
-}  // namespace phi
diff --git a/paddle/phi/kernels/cpu/concat_kernel.cc b/paddle/phi/kernels/cpu/concat_kernel.cc
index 3b74951a5041cd303c85c6a57766f5a06412f71b..18bb8837b105d91e3e13a0a7519b08c9c47202c4 100644
--- a/paddle/phi/kernels/cpu/concat_kernel.cc
+++ b/paddle/phi/kernels/cpu/concat_kernel.cc
@@ -22,7 +22,7 @@
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/lod_utils.h"
-#include "paddle/phi/kernels/cpu/concat_and_split.h"
+#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
 #include "paddle/phi/kernels/funcs/concat_funcs.h"
 
 namespace phi {
@@ -104,7 +104,8 @@ void ConcatKernel(const Context& dev_ctx,
         continue;
       }
     }
-    ConcatImpl<T, Context>(dev_ctx, inputs, axis, out);
+    phi::funcs::ConcatFunctor<Context, T> functor;
+    functor(dev_ctx, inputs, axis, out);
   }
 }
diff --git a/paddle/phi/kernels/cpu/split_kernel.cc b/paddle/phi/kernels/cpu/split_kernel.cc
index 259bf9e388c2c1a88400d13086bf9df23df21044..7b2166eaf11f90b653f5f3c57a278b24c2aa1af4 100644
--- a/paddle/phi/kernels/cpu/split_kernel.cc
+++ b/paddle/phi/kernels/cpu/split_kernel.cc
@@ -19,7 +19,7 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/infermeta/unary.h"
 
-#include "paddle/phi/kernels/cpu/concat_and_split.h"
+#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
 
 namespace phi {
 
 template <typename T, typename Context>
@@ -54,7 +54,8 @@ void SplitKernel(const Context& dev_ctx,
     paddle::operators::StridedMemcpyWithAxis0<T>(
         dev_ctx, x, shape_refer, &outs);
   } else {
-    SplitImpl<T, Context>(dev_ctx, x, shape_refer, axis, &outs);
+    phi::funcs::SplitFunctor<Context, T> functor;
+    functor(dev_ctx, x, shape_refer, axis, &outs);
   }
 }
diff --git a/paddle/phi/kernels/cpu/unbind_kernel.cc b/paddle/phi/kernels/cpu/unbind_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..655f8c8aafbf201dc07db0fa1af79605c2a76763
--- /dev/null
+++ b/paddle/phi/kernels/cpu/unbind_kernel.cc
@@ -0,0 +1,28 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/unbind_kernel.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/unbind_kernel_impl.h"
+
+PD_REGISTER_KERNEL(unbind,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::UnbindKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16,
+                   int,
+                   int64_t) {}
diff --git a/paddle/phi/kernels/funcs/CMakeLists.txt b/paddle/phi/kernels/funcs/CMakeLists.txt
index ba0c848df434ed403c29a5754043784066f7ef2a..aa4fac169200753639c48f5e9b5fa8c3bbfbd33c 100644
--- a/paddle/phi/kernels/funcs/CMakeLists.txt
+++ b/paddle/phi/kernels/funcs/CMakeLists.txt
@@ -3,3 +3,4 @@ add_subdirectory(blas)
 add_subdirectory(lapack)
 
 math_library(math_function DEPS blas dense_tensor tensor)
+math_library(concat_and_split_functor DEPS dense_tensor)
diff --git a/paddle/phi/kernels/funcs/concat_and_split_functor.cc b/paddle/phi/kernels/funcs/concat_and_split_functor.cc
new file mode 100644
index 0000000000000000000000000000000000000000..c8405703a5c16ae9eae583638d1c89c22a736531
--- /dev/null
+++ b/paddle/phi/kernels/funcs/concat_and_split_functor.cc
@@ -0,0 +1,146 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#include <algorithm>
+#include <limits>
+#include <vector>
+
+#include "paddle/fluid/framework/convert_utils.h"
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/enforce.h"
+#include "paddle/phi/core/utils/data_type.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
+
+namespace phi {
+namespace funcs {
+
+/*
+ * All tensors' dimension should be the same and the values of
+ * each dimension must be the same, except the axis dimension.
+ */
+template <typename T>
+struct ConcatFunctor<phi::CPUContext, T> {
+  void operator()(const phi::CPUContext& context,
+                  const std::vector<phi::DenseTensor>& input,
+                  int axis,
+                  phi::DenseTensor* output) {
+    // TODO(zcd): Add input data validity checking
+    size_t num = input.size();
+
+    int64_t rows = 1;
+    auto dim_0 = input[0].dims();
+    for (int i = 0; i < axis; ++i) {
+      rows *= dim_0[i];
+    }
+    int64_t out_rows = rows, out_cols = 0;
+
+    std::vector<int64_t> input_cols(input.size());
+    for (size_t i = 0; i < num; ++i) {
+      int64_t t_cols = input[i].numel() / rows;
+      out_cols += t_cols;
+      input_cols[i] = t_cols;
+    }
+    auto cpu_place = context.GetPlace();
+
+    // computation
+    auto output_data = output->data<T>();
+    int64_t col_idx = 0;
+    for (size_t j = 0; j < num; ++j) {
+      int64_t col_len = input_cols[j];
+      auto input_data = input[j].data<T>();
+      for (int64_t k = 0; k < out_rows; ++k) {
+        paddle::memory::Copy(cpu_place,
+                             output_data + k * out_cols + col_idx,
+                             cpu_place,
+                             input_data + k * col_len,
+                             sizeof(T) * col_len);
+      }
+      col_idx += col_len;
+    }
+  }
+};
+
+/*
+ * All tensors' dimension should be the same and the values of
+ * each dimension must be the same, except the axis dimension.
+ */
+template <typename T>
+struct SplitFunctor<phi::CPUContext, T> {
+ public:
+  void operator()(const phi::CPUContext& context,
+                  const phi::DenseTensor& input,
+                  const std::vector<const phi::DenseTensor*>& ref_inputs,
+                  int axis,
+                  std::vector<phi::DenseTensor*>* outputs) {
+    // NOTE(zhiqiu): split a tensor of shape [0,3,4] at axis=1, result in 3
+    // tensors of shape [0,1,4]
+    if (input.numel() == 0) {
+      return;
+    }
+
+    // TODO(zcd): Add input data validity checking
+    size_t num = outputs->size();
+
+    int input_rows = 1;
+    auto dim_0 = ref_inputs[0]->dims();
+    for (int i = 0; i < axis; ++i) {
+      input_rows *= dim_0[i];
+    }
+
+    int input_cols = 0;
+
+    std::vector<int64_t> output_cols(outputs->size());
+    for (size_t i = 0; i < num; ++i) {
+      int t_cols = ref_inputs[i]->numel() / input_rows;
+      input_cols += t_cols;
+      output_cols[i] = t_cols;
+    }
+    auto cpu_place = context.GetPlace();
+
+    // computation
+    for (int k = 0; k < input_rows; ++k) {
+      const T* src_ptr = input.data<T>() + k * input_cols;
+      int col_idx = 0;
+      for (size_t j = 0; j < num; ++j) {
+        int col_len = output_cols[j];
+        auto* out_tensor = outputs->at(j);
+        if (out_tensor != nullptr) {
+          T* dst_ptr = out_tensor->data<T>() + k * col_len;
+          paddle::memory::Copy(cpu_place,
+                               dst_ptr,
+                               cpu_place,
+                               src_ptr + col_idx,
+                               sizeof(T) * col_len);
+        }
+        col_idx += col_len;
+      }
+    }
+  }
+};
+
+#define DEFINE_FUNCTOR(type)                           \
+  template class ConcatFunctor<phi::CPUContext, type>; \
+  template class SplitFunctor<phi::CPUContext, type>;
+
+FOR_ALL_TYPES(DEFINE_FUNCTOR);
+
+}  // namespace funcs
+}  // namespace phi
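Editor's note (not part of the patch): the CPU functor views every tensor as a `rows x cols` matrix, where `rows` is the product of the dims before `axis`, then copies one contiguous column-slice per input into each output row. A standalone sketch of that arithmetic with raw buffers (shapes assumed for illustration, no DenseTensor involved):

#include <cstring>
#include <vector>

// Concat a [2 x 3] and a [2 x 4] buffer along axis 1 into a [2 x 7] buffer,
// mirroring the functor's per-row copy loop.
void ConcatRows(const float* a, const float* b, float* out) {
  const int rows = 2, a_cols = 3, b_cols = 4, out_cols = a_cols + b_cols;
  for (int k = 0; k < rows; ++k) {
    std::memcpy(out + k * out_cols, a + k * a_cols, sizeof(float) * a_cols);
    std::memcpy(out + k * out_cols + a_cols, b + k * b_cols,
                sizeof(float) * b_cols);
  }
}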
diff --git a/paddle/phi/kernels/funcs/concat_and_split_functor.cu b/paddle/phi/kernels/funcs/concat_and_split_functor.cu
new file mode 100644
index 0000000000000000000000000000000000000000..2abfdb606e7e6c410f6f9deb45aed536bea88207
--- /dev/null
+++ b/paddle/phi/kernels/funcs/concat_and_split_functor.cu
@@ -0,0 +1,584 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#include <algorithm>
+#include <limits>
+#include <vector>
+
+#include "paddle/fluid/framework/convert_utils.h"
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
+#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/enforce.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/utils/data_type.h"
+#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
+
+namespace phi {
+namespace funcs {
+
+template <typename T>
+__global__ void ConcatKernel_(const T** inputs,
+                              const int64_t* input_cols,
+                              int col_size,
+                              const int64_t output_rows,
+                              const int64_t output_cols,
+                              T* output) {
+  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
+  int curr_segment = 0;
+  int curr_offset = input_cols[0];
+  for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
+    int curr_col_offset = input_cols[curr_segment + 1];
+    while (curr_col_offset <= tid_x) {
+      curr_offset = curr_col_offset;
+      ++curr_segment;
+      curr_col_offset = input_cols[curr_segment + 1];
+    }
+
+    int local_col = tid_x - curr_offset;
+    int segment_width = curr_col_offset - curr_offset;
+
+    const T* input_ptr = inputs[curr_segment];
+    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
+    for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
+      output[tid_y * output_cols + tid_x] =
+          input_ptr[tid_y * segment_width + local_col];
+  }
+}
+
+template <typename T>
+__device__ void ConcatKernelDetail(const T** inputs_data,
+                                   const int fixed_in_col,
+                                   const int out_rows,
+                                   const int out_cols,
+                                   T* output_data) {
+  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
+  for (; tid_x < out_cols; tid_x += blockDim.x * gridDim.x) {
+    int split = tid_x * 1.0 / fixed_in_col;
+    int in_offset = tid_x - split * fixed_in_col;
+    const T* input_ptr = inputs_data[split];
+    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
+    for (; tid_y < out_rows; tid_y += blockDim.y * gridDim.y) {
+      output_data[tid_y * out_cols + tid_x] =
+          input_ptr[tid_y * fixed_in_col + in_offset];
+    }
+  }
+}
+
+template <typename T>
+__global__ void ConcatKernel_(const T* input_addr0,
+                              const T* input_addr1,
+                              const int64_t fixed_in_col,
+                              const int64_t out_rows,
+                              const int64_t out_cols,
+                              T* output_data) {
+  const T* inputs_data[2];
+  inputs_data[0] = input_addr0;
+  inputs_data[1] = input_addr1;
+  ConcatKernelDetail(
+      inputs_data, fixed_in_col, out_rows, out_cols, output_data);
+}
+
+template <typename T>
+__global__ void ConcatKernel_(const T* input_addr0,
+                              const T* input_addr1,
+                              const T* input_addr2,
+                              const int64_t fixed_in_col,
+                              const int64_t out_rows,
+                              const int64_t out_cols,
+                              T* output_data) {
+  const T* inputs_data[3];
+  inputs_data[0] = input_addr0;
+  inputs_data[1] = input_addr1;
+  inputs_data[2] = input_addr2;
+  ConcatKernelDetail(
+      inputs_data, fixed_in_col, out_rows, out_cols, output_data);
+}
+
+template <typename T>
+__global__ void ConcatKernel_(const T* input_addr0,
+                              const T* input_addr1,
+                              const T* input_addr2,
+                              const T* input_addr3,
+                              const int64_t fixed_in_col,
+                              const int64_t out_rows,
+                              const int64_t out_cols,
+                              T* output_data) {
+  const T* inputs_data[4];
+  inputs_data[0] = input_addr0;
+  inputs_data[1] = input_addr1;
+  inputs_data[2] = input_addr2;
+  inputs_data[3] = input_addr3;
+  ConcatKernelDetail(
+      inputs_data, fixed_in_col, out_rows, out_cols, output_data);
+}
+
+template <typename T>
+__global__ void ConcatKernel_(const T** inputs_data,
+                              const int in_num,
+                              const int64_t fixed_in_col,
+                              const int64_t out_rows,
+                              const int64_t out_cols,
+                              T* output_data) {
+  ConcatKernelDetail(
+      inputs_data, fixed_in_col, out_rows, out_cols, output_data);
+}
+
+template <typename T>
+__global__ void SplitKernel_(const T* input_data,
+                             const int64_t in_row,
+                             const int64_t in_col,
+                             const int64_t* out_cols,
+                             int out_cols_size,
+                             T** outputs_data) {
+  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
+  int curr_segment = 0;
+  int curr_offset = out_cols[0];
+  for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
+    int curr_col_offset = out_cols[curr_segment + 1];
+    while (curr_col_offset <= tid_x) {
+      curr_offset = curr_col_offset;
+      ++curr_segment;
+      curr_col_offset = out_cols[curr_segment + 1];
+    }
+
+    int local_col = tid_x - curr_offset;
+    int segment_width = curr_col_offset - curr_offset;
+    T* output_ptr = outputs_data[curr_segment];
+    if (output_ptr != nullptr) {
+      int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
+      for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
+        output_ptr[tid_y * segment_width + local_col] =
+            input_data[tid_y * in_col + tid_x];
+    }
+  }
+}
+
+template <typename T>
+__device__ void SplitKernelDetail(const T* input_data,
+                                  const int in_row,
+                                  const int in_col,
+                                  const int fixed_out_col,
+                                  T** outputs_data) {
+  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
+  for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
+    int split = tid_x / fixed_out_col;
+    int in_offset = tid_x - split * fixed_out_col;
+    T* output_ptr = outputs_data[split];
+    if (output_ptr != nullptr) {
+      int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
+      for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
+        output_ptr[tid_y * fixed_out_col + in_offset] =
+            input_data[tid_y * in_col + tid_x];
+    }
+  }
+}
+
+template <typename T>
+__global__ void SplitKernel_(const T* input_data,
+                             const int64_t in_row,
+                             const int64_t in_col,
+                             const int64_t fixed_out_col,
+                             T** outputs_data) {
+  SplitKernelDetail(input_data, in_row, in_col, fixed_out_col, outputs_data);
+}
+
+template <typename T>
+__global__ void SplitKernel_(const T* input_data,
+                             const int64_t in_row,
+                             const int64_t in_col,
+                             const int64_t fixed_out_col,
+                             T* outputs_addr0,
+                             T* outputs_addr1) {
+  T* outputs_data[2];
+  outputs_data[0] = outputs_addr0;
+  outputs_data[1] = outputs_addr1;
+  SplitKernelDetail(input_data, in_row, in_col, fixed_out_col, outputs_data);
+}
+
+template <typename T>
+__global__ void SplitKernel_(const T* input_data,
+                             const int64_t in_row,
+                             const int64_t in_col,
+                             const int64_t fixed_out_col,
+                             T* outputs_addr0,
+                             T* outputs_addr1,
+                             T* outputs_addr2) {
+  T* outputs_data[3];
+  outputs_data[0] = outputs_addr0;
+  outputs_data[1] = outputs_addr1;
+  outputs_data[2] = outputs_addr2;
+  SplitKernelDetail(input_data, in_row, in_col, fixed_out_col, outputs_data);
+}
+
+template <typename T>
+__global__ void SplitKernel_(const T* input_data,
+                             const int64_t in_row,
+                             const int64_t in_col,
+                             const int64_t fixed_out_col,
+                             T* outputs_addr0,
+                             T* outputs_addr1,
+                             T* outputs_addr2,
+                             T* outputs_addr3) {
+  T* outputs_data[4];
+  outputs_data[0] = outputs_addr0;
+  outputs_data[1] = outputs_addr1;
+  outputs_data[2] = outputs_addr2;
+  outputs_data[3] = outputs_addr3;
+  SplitKernelDetail(input_data, in_row, in_col, fixed_out_col, outputs_data);
+}
+
+static inline void GetBlockDims(const phi::GPUContext& context,
+                                int64_t num_rows,
+                                int64_t num_cols,
+                                dim3* block_dims,
+                                dim3* grid_dims) {
+  // Set the thread block and grid according to CurrentDeviceId
+  const int kThreadsPerBlock = 1024;
+  int block_cols = kThreadsPerBlock;
+  if (num_cols < kThreadsPerBlock) {  // block_cols is aligned by 32.
+    block_cols = ((num_cols + 31) >> 5) << 5;
+  }
+  int block_rows = kThreadsPerBlock / block_cols;
+  *block_dims = dim3(block_cols, block_rows, 1);
+
+  int max_threads = context.GetMaxPhysicalThreadCount();
+  int64_t max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
+
+  int grid_cols =
+      std::min((num_cols + block_cols - 1) / block_cols, max_blocks);
+  int grid_rows = std::min(max_blocks / grid_cols,
+                           std::max(num_rows / block_rows, (int64_t)1));
+  *grid_dims = dim3(grid_cols, grid_rows, 1);
+}
+
+/*
+ * All tensors' dimension should be the same and the values of
+ * each dimension must be the same, except the axis dimension.
+ */
+
+template <typename T>
+struct ConcatFunctor<phi::GPUContext, T> {
+  void operator()(const phi::GPUContext& context,
+                  const std::vector<phi::DenseTensor>& input,
+                  int axis,
+                  phi::DenseTensor* output) {
+    // TODO(zcd): Add input data validity checking
+    int in_num = input.size();
+    int64_t in_row = 1;
+    auto dim_0 = input[0].dims();
+    for (int i = 0; i < axis; ++i) {
+      in_row *= dim_0[i];
+    }
+    int64_t in_col = input[0].numel() / in_row;
+    int64_t out_row = in_row, out_col = 0;
+
+    int inputs_col_num = in_num + 1;
+    std::vector<const T*> inputs_data_vec(in_num);
+    std::vector<int64_t> inputs_col_vec(inputs_col_num);
+    const T** inputs_data = inputs_data_vec.data();
+    int64_t* inputs_col = inputs_col_vec.data();
+
+// There are some differences between hip runtime and NV runtime.
+// In NV, when the pageable memory data less than 64K is transferred from
+// hosttodevice, it will be automatically asynchronous.
+// However, only pinned memory in hip can copy asynchronously
+// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#concurrent-execution-host-device
+// 3.2.6.1. Concurrent Execution between Host and Device
+// Memory copies from host to device of a memory block of 64 KB or less
+#ifdef PADDLE_WITH_HIP
+    paddle::memory::AllocationPtr data_alloc, col_alloc;
+    // TODO(chentianyu03): try to find a method to remove the Alloc function
+    data_alloc = paddle::memory::Alloc(paddle::platform::CUDAPinnedPlace(),
+                                       in_num * sizeof(T*));
+    inputs_data = reinterpret_cast<const T**>(data_alloc->ptr());
+    // TODO(chentianyu03): try to find a method to remove the Alloc function
+    col_alloc = paddle::memory::Alloc(paddle::platform::CUDAPinnedPlace(),
+                                      inputs_col_num * sizeof(int));
+    inputs_col = reinterpret_cast<int64_t*>(col_alloc->ptr());
+#endif
+
+    inputs_col[0] = 0;
+    bool has_same_shape = true;
+    for (int i = 0; i < in_num; ++i) {
+      int64_t t_cols = input[i].numel() / in_row;
+      if (has_same_shape) {
+        if (t_cols != in_col) has_same_shape = false;
+      }
+      out_col += t_cols;
+      inputs_col[i + 1] = out_col;
+      inputs_data[i] = input[i].data<T>();
+    }
+
+    dim3 block_dims;
+    dim3 grid_dims;
+    GetBlockDims(context, out_row, out_col, &block_dims, &grid_dims);
+
+    paddle::memory::allocation::AllocationPtr tmp_dev_ins_data;
+    const T** dev_ins_data = nullptr;
+    if (!has_same_shape || in_num < 2 || in_num > 4) {
+      tmp_dev_ins_data = paddle::memory::Alloc(context, in_num * sizeof(T*));
+      auto* restored = paddle::platform::RestoreHostMemIfCapturingCUDAGraph(
+          inputs_data, in_num);
+      paddle::memory::Copy(context.GetPlace(),
+                           tmp_dev_ins_data->ptr(),
+                           paddle::platform::CPUPlace(),
+                           restored,
+                           in_num * sizeof(T*),
+                           context.stream());
+      dev_ins_data = reinterpret_cast<const T**>(tmp_dev_ins_data->ptr());
+    }
+
+    if (has_same_shape) {
+      if (in_num == 2) {
+        ConcatKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
+            inputs_data[0],
+            inputs_data[1],
+            in_col,
+            out_row,
+            out_col,
+            output->data<T>());
+      } else if (in_num == 3) {
+        ConcatKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
+            inputs_data[0],
+            inputs_data[1],
+            inputs_data[2],
+            in_col,
+            out_row,
+            out_col,
+            output->data<T>());
+      } else if (in_num == 4) {
+        ConcatKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
+            inputs_data[0],
+            inputs_data[1],
+            inputs_data[2],
+            inputs_data[3],
+            in_col,
+            out_row,
+            out_col,
+            output->data<T>());
+      } else {
+        ConcatKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
+            dev_ins_data, in_num, in_col, out_row, out_col, output->data<T>());
+      }
+    } else {
+      auto tmp_dev_ins_col_data =
+          paddle::memory::Alloc(context, inputs_col_num * sizeof(int64_t));
+
+      auto* restored = paddle::platform::RestoreHostMemIfCapturingCUDAGraph(
+          inputs_col, inputs_col_num);
+      paddle::memory::Copy(context.GetPlace(),
+                           tmp_dev_ins_col_data->ptr(),
+                           paddle::platform::CPUPlace(),
+                           restored,
+                           inputs_col_num * sizeof(int64_t),
+                           context.stream());
+      int64_t* dev_ins_col_data =
+          static_cast<int64_t*>(tmp_dev_ins_col_data->ptr());
+
+      ConcatKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
+          dev_ins_data,
+          dev_ins_col_data,
+          static_cast<int>(inputs_col_num),
+          out_row,
+          out_col,
+          output->data<T>());
+    }
+
+#ifdef PADDLE_WITH_HIP
+    // Prevent the pinned memory value from being covered and release the
+    // memory after the launch kernel of the stream is executed (reapply
+    // pinned memory next time)
+    auto* data_alloc_released = data_alloc.release();
+    auto* col_alloc_released = col_alloc.release();
+    context.AddStreamCallback([data_alloc_released, col_alloc_released] {
+      paddle::memory::allocation::Allocator::AllocationDeleter(
+          data_alloc_released);
+      paddle::memory::allocation::Allocator::AllocationDeleter(
+          col_alloc_released);
+    });
+#endif
+  }
+};
+
+template <typename T>
+class SplitFunctor<phi::GPUContext, T> {
+ public:
+  void operator()(const phi::GPUContext& context,
+                  const phi::DenseTensor& input,
+                  const std::vector<const phi::DenseTensor*>& ref_inputs,
+                  int axis,
+                  std::vector<phi::DenseTensor*>* outputs) {
+    // NOTE(zhiqiu): split a tensor of shape [0,3,4] at axis=1, result in 3
+    // tensors of shape [0,1,4]
+    if (input.numel() == 0) {
+      return;
+    }
+
+    // TODO(zcd): Add input data validity checking
+    int o_num = outputs->size();
+    int64_t out_row = 1;
+    auto dim_0 = ref_inputs[0]->dims();
+    for (int i = 0; i < axis; ++i) {
+      out_row *= dim_0[i];
+    }
+
+    int64_t out0_col = ref_inputs[0]->numel() / out_row;
+    int64_t in_col = 0, in_row = out_row;
+    bool has_same_shape = true;
+
+    int outputs_cols_num = o_num + 1;
+    std::vector<T*> outputs_data_vec(o_num);
+    std::vector<int64_t> outputs_cols_vec(outputs_cols_num);
+    T** outputs_data = outputs_data_vec.data();
+    int64_t* outputs_cols = outputs_cols_vec.data();
+
+// There are some differences between hip runtime and NV runtime.
+// In NV, when the pageable memory data less than 64K is transferred from
+// hosttodevice, it will be automatically asynchronous.
+// However, only pinned memory in hip can copy asynchronously
+// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#concurrent-execution-host-device
+// 3.2.6.1. Concurrent Execution between Host and Device
+// Memory copies from host to device of a memory block of 64 KB or less
+#ifdef PADDLE_WITH_HIP
+    paddle::memory::AllocationPtr data_alloc, cols_alloc;
+    // TODO(chentianyu03): try to find a method to remove the Alloc function
+    data_alloc = paddle::memory::Alloc(paddle::platform::CUDAPinnedPlace(),
+                                       o_num * sizeof(T*));
+    outputs_data = reinterpret_cast<T**>(data_alloc->ptr());
+    // TODO(chentianyu03): try to find a method to remove the Alloc function
+    cols_alloc = paddle::memory::Alloc(paddle::platform::CUDAPinnedPlace(),
+                                       (outputs_cols_num) * sizeof(int64_t));
+    outputs_cols = reinterpret_cast<int64_t*>(cols_alloc->ptr());
+#endif
+
+    outputs_cols[0] = 0;
+    for (int i = 0; i < o_num; ++i) {
+      int64_t t_col = ref_inputs.at(i)->numel() / out_row;
+      if (has_same_shape) {
+        if (t_col != out0_col) has_same_shape = false;
+      }
+      in_col += t_col;
+      outputs_cols[i + 1] = in_col;
+      if (outputs->at(i) != nullptr) {
+        outputs_data[i] = outputs->at(i)->data<T>();
+      } else {
+        outputs_data[i] = nullptr;
+      }
+    }
+
+    dim3 block_dims;
+    dim3 grid_dims;
+    GetBlockDims(context, out_row, in_col, &block_dims, &grid_dims);
+
+    paddle::memory::allocation::AllocationPtr tmp_dev_outs_data;
+    T** dev_out_gpu_data = nullptr;
+    if (!has_same_shape || o_num < 2 || o_num > 4) {
+      // TODO(chentianyu03): try to find a method to remove the Alloc function
+      tmp_dev_outs_data = paddle::memory::Alloc(context, o_num * sizeof(T*));
+      auto* restored = paddle::platform::RestoreHostMemIfCapturingCUDAGraph(
+          outputs_data, o_num);
+      paddle::memory::Copy(context.GetPlace(),
+                           tmp_dev_outs_data->ptr(),
+                           paddle::platform::CPUPlace(),
+                           restored,
+                           o_num * sizeof(T*),
+                           context.stream());
+      dev_out_gpu_data = reinterpret_cast<T**>(tmp_dev_outs_data->ptr());
+    }
+
+    if (has_same_shape) {
+      if (o_num == 2) {
+        SplitKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
+            input.data<T>(),
+            in_row,
+            in_col,
+            out0_col,
+            outputs_data[0],
+            outputs_data[1]);
+      } else if (o_num == 3) {
+        SplitKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
+            input.data<T>(),
+            in_row,
+            in_col,
+            out0_col,
+            outputs_data[0],
+            outputs_data[1],
+            outputs_data[2]);
+      } else if (o_num == 4) {
+        SplitKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
+            input.data<T>(),
+            in_row,
+            in_col,
+            out0_col,
+            outputs_data[0],
+            outputs_data[1],
+            outputs_data[2],
+            outputs_data[3]);
+      } else {
+        SplitKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
+            input.data<T>(), in_row, in_col, out0_col, dev_out_gpu_data);
+      }
+    } else {
+      auto tmp_dev_ins_col_data =
+          // TODO(chentianyu03): try to find a method to remove the Alloc
+          // function
+          paddle::memory::Alloc(context, outputs_cols_num * sizeof(int64_t));
+      auto* restored = paddle::platform::RestoreHostMemIfCapturingCUDAGraph(
+          outputs_cols, outputs_cols_num);
+      paddle::memory::Copy(context.GetPlace(),
+                           tmp_dev_ins_col_data->ptr(),
+                           paddle::platform::CPUPlace(),
+                           restored,
+                           outputs_cols_num * sizeof(int64_t),
+                           context.stream());
+      int64_t* dev_outs_col_data =
+          reinterpret_cast<int64_t*>(tmp_dev_ins_col_data->ptr());
+
+      SplitKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
+          input.data<T>(),
+          in_row,
+          in_col,
+          dev_outs_col_data,
+          static_cast<int>(outputs_cols_num),
+          dev_out_gpu_data);
+    }
+#ifdef PADDLE_WITH_HIP
+    // Prevent the pinned memory value from being covered and release the
+    // memory after the launch kernel of the stream is executed (reapply
+    // pinned memory next time)
+    auto* data_alloc_released = data_alloc.release();
+    auto* cols_alloc_released = cols_alloc.release();
+    context.AddStreamCallback([data_alloc_released, cols_alloc_released] {
+      paddle::memory::allocation::Allocator::AllocationDeleter(
+          data_alloc_released);
+      paddle::memory::allocation::Allocator::AllocationDeleter(
+          cols_alloc_released);
+    });
+#endif
+  }
+};
+
+#define DEFINE_FUNCTOR(type)                           \
+  template class ConcatFunctor<phi::GPUContext, type>; \
+  template class SplitFunctor<phi::GPUContext, type>
+
+FOR_ALL_TYPES(DEFINE_FUNCTOR);
+
+}  // namespace funcs
+}  // namespace phi
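Editor's note (not part of the patch): GetBlockDims rounds the column count up to a multiple of 32 (one warp) and gives the remaining threads of the 1024-thread block to rows. A host-side replica of that rounding logic, with illustrative values in the trailing comment:

#include <cstdint>

// Mirrors the block-shape arithmetic in GetBlockDims (sketch only; the real
// function also derives grid dims from GetMaxPhysicalThreadCount()).
void BlockShape(int64_t num_cols, int* block_cols, int* block_rows) {
  const int kThreadsPerBlock = 1024;
  *block_cols = kThreadsPerBlock;
  if (num_cols < kThreadsPerBlock) {
    // round up to the next multiple of 32 so each block spans whole warps
    *block_cols = static_cast<int>(((num_cols + 31) >> 5) << 5);
  }
  *block_rows = kThreadsPerBlock / *block_cols;
}
// e.g. num_cols = 100  ->  block_cols = 128, block_rows = 8, i.e. dim3(128, 8, 1).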
diff --git a/paddle/phi/kernels/funcs/concat_and_split_functor.h b/paddle/phi/kernels/funcs/concat_and_split_functor.h
new file mode 100644
index 0000000000000000000000000000000000000000..3af4d878d3cab03eb80a6ba878cc4fa5a62103c9
--- /dev/null
+++ b/paddle/phi/kernels/funcs/concat_and_split_functor.h
@@ -0,0 +1,90 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <algorithm>
+#include <memory>
+#include <vector>
+
+#include "paddle/fluid/framework/convert_utils.h"
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/enforce.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/utils/data_type.h"
+
+namespace phi {
+namespace funcs {
+
+/*
+ * \brief Concatenate the input tensors along the dimension axis.
+ * TODO(zcd): maybe it needs to be more detailed.
+ * Examples:
+ *     Input[0] = [[1,2],[3,4]]
+ *     Input[1] = [[5,6]]
+ *     axis = 0
+ *
+ *     Output = [[1,2],
+ *               [3,4],
+ *               [5,6]]
+ */
+template <typename Context, typename T>
+struct ConcatFunctor {
+  void operator()(const Context& context,
+                  const std::vector<phi::DenseTensor>& input,
+                  int axis,
+                  phi::DenseTensor* output);
+};
+
+/*
+ * \brief Split the input tensors along the dimension axis into outputs.
+ * TODO(zcd): maybe it needs to be more detailed.
+ * Examples:
+ *     Input = [[1,2],
+ *              [3,4],
+ *              [5,6]]
+ *     axis = 0
+ *
+ *     Output[0] = [[1,2],[3,4]]
+ *     Output[1] = [[5,6]]
+ */
+template <typename Context, typename T>
+class SplitFunctor {
+ public:
+  void operator()(const Context& context,
+                  const phi::DenseTensor& input,
+                  const std::vector<const phi::DenseTensor*>& ref_inputs,
+                  int axis,
+                  std::vector<phi::DenseTensor*>* outputs);
+};
+
+}  // namespace funcs
+}  // namespace phi
+
+#define FOR_ALL_TYPES(macro)          \
+  macro(int);                         \
+  macro(float);                       \
+  macro(double);                      \
+  macro(bool);                        \
+  macro(int64_t);                     \
+  macro(int16_t);                     \
+  macro(uint8_t);                     \
+  macro(int8_t);                      \
+  macro(phi::dtype::float16);         \
+  macro(phi::dtype::bfloat16);        \
+  macro(phi::dtype::complex<float>);  \
+  macro(phi::dtype::complex<double>);
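Editor's note (not part of the patch): FOR_ALL_TYPES applies a macro to every supported element type, so the DEFINE_FUNCTOR invocations in the .cc/.cu translation units emit one pair of explicit instantiations per type. Given the macros as reconstructed above, the float case in concat_and_split_functor.cc expands to:

// Expansion of DEFINE_FUNCTOR(float) on the CPU side (illustration):
template class ConcatFunctor<phi::CPUContext, float>;
template class SplitFunctor<phi::CPUContext, float>;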
diff --git a/paddle/phi/kernels/gpu/concat_and_split.h b/paddle/phi/kernels/gpu/concat_and_split.h
deleted file mode 100644
index ced48ece979f06fbf2bd3f9fd8b7e07cc2954fbf..0000000000000000000000000000000000000000
--- a/paddle/phi/kernels/gpu/concat_and_split.h
+++ /dev/null
@@ -1,567 +0,0 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-#include <algorithm>
-#include <vector>
-#include "gflags/gflags.h"
-#include "paddle/fluid/memory/malloc.h"
-#include "paddle/fluid/operators/math/concat_and_split.h"
-#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
-#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
-
-#include "paddle/phi/backends/gpu/gpu_context.h"
-
-namespace phi {
-
-template <typename T>
-__global__ void ConcatKernel_(const T** inputs,
-                              const int64_t* input_cols,
-                              int col_size,
-                              const int64_t output_rows,
-                              const int64_t output_cols,
-                              T* output) {
-  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
-  int curr_segment = 0;
-  int curr_offset = input_cols[0];
-  for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
-    int curr_col_offset = input_cols[curr_segment + 1];
-    while (curr_col_offset <= tid_x) {
-      curr_offset = curr_col_offset;
-      ++curr_segment;
-      curr_col_offset = input_cols[curr_segment + 1];
-    }
-
-    int local_col = tid_x - curr_offset;
-    int segment_width = curr_col_offset - curr_offset;
-
-    const T* input_ptr = inputs[curr_segment];
-    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
-    for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
-      output[tid_y * output_cols + tid_x] =
-          input_ptr[tid_y * segment_width + local_col];
-  }
-}
-
-template <typename T>
-__device__ void ConcatKernelDetail(const T** inputs_data,
-                                   const int fixed_in_col,
-                                   const int out_rows,
-                                   const int out_cols,
-                                   T* output_data) {
-  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
-  for (; tid_x < out_cols; tid_x += blockDim.x * gridDim.x) {
-    int split = tid_x * 1.0 / fixed_in_col;
-    int in_offset = tid_x - split * fixed_in_col;
-    const T* input_ptr = inputs_data[split];
-    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
-    for (; tid_y < out_rows; tid_y += blockDim.y * gridDim.y) {
-      output_data[tid_y * out_cols + tid_x] =
-          input_ptr[tid_y * fixed_in_col + in_offset];
-    }
-  }
-}
-
-template <typename T>
-__global__ void ConcatKernel_(const T* input_addr0,
-                              const T* input_addr1,
-                              const int64_t fixed_in_col,
-                              const int64_t out_rows,
-                              const int64_t out_cols,
-                              T* output_data) {
-  const T* inputs_data[2];
-  inputs_data[0] = input_addr0;
-  inputs_data[1] = input_addr1;
-  ConcatKernelDetail(
-      inputs_data, fixed_in_col, out_rows, out_cols, output_data);
-}
-
-template <typename T>
-__global__ void ConcatKernel_(const T* input_addr0,
-                              const T* input_addr1,
-                              const T* input_addr2,
-                              const int64_t fixed_in_col,
-                              const int64_t out_rows,
-                              const int64_t out_cols,
-                              T* output_data) {
-  const T* inputs_data[3];
-  inputs_data[0] = input_addr0;
-  inputs_data[1] = input_addr1;
-  inputs_data[2] = input_addr2;
-  ConcatKernelDetail(
-      inputs_data, fixed_in_col, out_rows, out_cols, output_data);
-}
-
-template <typename T>
-__global__ void ConcatKernel_(const T* input_addr0,
-                              const T* input_addr1,
-                              const T* input_addr2,
-                              const T* input_addr3,
-                              const int64_t fixed_in_col,
-                              const int64_t out_rows,
-                              const int64_t out_cols,
-                              T* output_data) {
-  const T* inputs_data[4];
-  inputs_data[0] = input_addr0;
-  inputs_data[1] = input_addr1;
-  inputs_data[2] = input_addr2;
-  inputs_data[3] = input_addr3;
-  ConcatKernelDetail(
-      inputs_data, fixed_in_col, out_rows, out_cols, output_data);
-}
-
-template <typename T>
-__global__ void ConcatKernel_(const T** inputs_data,
-                              const int in_num,
-                              const int64_t fixed_in_col,
-                              const int64_t out_rows,
-                              const int64_t out_cols,
-                              T* output_data) {
-  ConcatKernelDetail(
-      inputs_data, fixed_in_col, out_rows, out_cols, output_data);
-}
-
-template <typename T>
-__global__ void SplitKernel_(const T* input_data,
-                             const int64_t in_row,
-                             const int64_t in_col,
-                             const int64_t* out_cols,
-                             int out_cols_size,
-                             T** outputs_data) {
-  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
-  int curr_segment = 0;
-  int curr_offset = out_cols[0];
-  for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
-    int curr_col_offset = out_cols[curr_segment + 1];
-    while (curr_col_offset <= tid_x) {
-      curr_offset = curr_col_offset;
-      ++curr_segment;
-      curr_col_offset = out_cols[curr_segment + 1];
-    }
-
-    int local_col = tid_x - curr_offset;
-    int segment_width = curr_col_offset - curr_offset;
-    T* output_ptr = outputs_data[curr_segment];
-    if (output_ptr != nullptr) {
-      int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
-      for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
-        output_ptr[tid_y * segment_width + local_col] =
-            input_data[tid_y * in_col + tid_x];
-    }
-  }
-}
-
-template <typename T>
-__device__ void SplitKernelDetail(const T* input_data,
-                                  const int in_row,
-                                  const int in_col,
-                                  const int fixed_out_col,
-                                  T** outputs_data) {
-  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
-  for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
-    int split = tid_x / fixed_out_col;
-    int in_offset = tid_x - split * fixed_out_col;
-    T* output_ptr = outputs_data[split];
-    if (output_ptr != nullptr) {
-      int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
-      for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
-        output_ptr[tid_y * fixed_out_col + in_offset] =
-            input_data[tid_y * in_col + tid_x];
-    }
-  }
-}
-
-template <typename T>
-__global__ void SplitKernel_(const T* input_data,
-                             const int64_t in_row,
-                             const int64_t in_col,
-                             const int64_t fixed_out_col,
-                             T** outputs_data) {
-  SplitKernelDetail(input_data, in_row, in_col, fixed_out_col, outputs_data);
-}
-
-template <typename T>
-__global__ void SplitKernel_(const T* input_data,
-                             const int64_t in_row,
-                             const int64_t in_col,
-                             const int64_t fixed_out_col,
-                             T* outputs_addr0,
-                             T* outputs_addr1) {
-  T* outputs_data[2];
-  outputs_data[0] = outputs_addr0;
-  outputs_data[1] = outputs_addr1;
-  SplitKernelDetail(input_data, in_row, in_col, fixed_out_col, outputs_data);
-}
-
-template <typename T>
-__global__ void SplitKernel_(const T* input_data,
-                             const int64_t in_row,
-                             const int64_t in_col,
-                             const int64_t fixed_out_col,
-                             T* outputs_addr0,
-                             T* outputs_addr1,
-                             T* outputs_addr2) {
-  T* outputs_data[3];
-  outputs_data[0] = outputs_addr0;
-  outputs_data[1] = outputs_addr1;
-  outputs_data[2] = outputs_addr2;
-  SplitKernelDetail(input_data, in_row, in_col, fixed_out_col, outputs_data);
-}
-
-template <typename T>
-__global__ void SplitKernel_(const T* input_data,
-                             const int64_t in_row,
-                             const int64_t in_col,
-                             const int64_t fixed_out_col,
-                             T* outputs_addr0,
-                             T* outputs_addr1,
-                             T* outputs_addr2,
-                             T* outputs_addr3) {
-  T* outputs_data[4];
-  outputs_data[0] = outputs_addr0;
-  outputs_data[1] = outputs_addr1;
-  outputs_data[2] = outputs_addr2;
-  outputs_data[3] = outputs_addr3;
-  SplitKernelDetail(input_data, in_row, in_col, fixed_out_col, outputs_data);
-}
-
-static inline void GetBlockDims(const phi::GPUContext& context,
-                                int64_t num_rows,
-                                int64_t num_cols,
-                                dim3* block_dims,
-                                dim3* grid_dims) {
-  // Set the thread block and grid according to CurrentDeviceId
-  const int kThreadsPerBlock = 1024;
-  int block_cols = kThreadsPerBlock;
-  if (num_cols < kThreadsPerBlock) {  // block_cols is aligned by 32.
-    block_cols = ((num_cols + 31) >> 5) << 5;
-  }
-  int block_rows = kThreadsPerBlock / block_cols;
-  *block_dims = dim3(block_cols, block_rows, 1);
-
-  int max_threads = context.GetMaxPhysicalThreadCount();
-  int64_t max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
-
-  int grid_cols =
-      std::min((num_cols + block_cols - 1) / block_cols, max_blocks);
-  int grid_rows = std::min(max_blocks / grid_cols,
-                           std::max(num_rows / block_rows, (int64_t)1));
-  *grid_dims = dim3(grid_cols, grid_rows, 1);
-}
-
-/*
- * All tensors' dimension should be the same and the values of
- * each dimension must be the same, except the axis dimension.
- */
-template <typename T, typename Context>
-void ConcatImpl(const Context& context,
-                const std::vector<phi::DenseTensor>& input,
-                int axis,
-                phi::DenseTensor* output) {
-  // TODO(zcd): Add input data validity checking
-  int in_num = input.size();
-  int64_t in_row = 1;
-  auto dim_0 = input[0].dims();
-  for (int i = 0; i < axis; ++i) {
-    in_row *= dim_0[i];
-  }
-  int64_t in_col = input[0].numel() / in_row;
-  int64_t out_row = in_row, out_col = 0;
-
-  int inputs_col_num = in_num + 1;
-  std::vector<const T*> inputs_data_vec(in_num);
-  std::vector<int64_t> inputs_col_vec(inputs_col_num);
-  const T** inputs_data = inputs_data_vec.data();
-  int64_t* inputs_col = inputs_col_vec.data();
-
-// There are some differences between hip runtime and NV runtime.
-// In NV, when the pageable memory data less than 64K is transferred from
-// hosttodevice, it will be automatically asynchronous.
-// However, only pinned memory in hip can copy asynchronously
-// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#concurrent-execution-host-device
-// 3.2.6.1. Concurrent Execution between Host and Device
-// Memory copies from host to device of a memory block of 64 KB or less
-#ifdef PADDLE_WITH_HIP
-  paddle::memory::AllocationPtr data_alloc, col_alloc;
-  // TODO(chentianyu03): try to find a method to remove the Alloc function
-  data_alloc = paddle::memory::Alloc(paddle::platform::CUDAPinnedPlace(),
-                                     in_num * sizeof(T*));
-  inputs_data = reinterpret_cast<const T**>(data_alloc->ptr());
-  // TODO(chentianyu03): try to find a method to remove the Alloc function
-  col_alloc = paddle::memory::Alloc(paddle::platform::CUDAPinnedPlace(),
-                                    inputs_col_num * sizeof(int));
-  inputs_col = reinterpret_cast<int64_t*>(col_alloc->ptr());
-#endif
-
-  inputs_col[0] = 0;
-  bool has_same_shape = true;
-  for (int i = 0; i < in_num; ++i) {
-    int64_t t_cols = input[i].numel() / in_row;
-    if (has_same_shape) {
-      if (t_cols != in_col) has_same_shape = false;
-    }
-    out_col += t_cols;
-    inputs_col[i + 1] = out_col;
-    inputs_data[i] = input[i].data<T>();
-  }
-
-  dim3 block_dims;
-  dim3 grid_dims;
-  GetBlockDims(context, out_row, out_col, &block_dims, &grid_dims);
-
-  paddle::memory::allocation::AllocationPtr tmp_dev_ins_data;
-  const T** dev_ins_data = nullptr;
-  if (!has_same_shape || in_num < 2 || in_num > 4) {
-    tmp_dev_ins_data = paddle::memory::Alloc(context, in_num * sizeof(T*));
-    auto* restored = paddle::platform::RestoreHostMemIfCapturingCUDAGraph(
-        inputs_data, in_num);
-    paddle::memory::Copy(context.GetPlace(),
-                         tmp_dev_ins_data->ptr(),
-                         phi::CPUPlace(),
-                         restored,
-                         in_num * sizeof(T*),
-                         context.stream());
-    dev_ins_data = reinterpret_cast<const T**>(tmp_dev_ins_data->ptr());
-  }
-
-  if (has_same_shape) {
-    if (in_num == 2) {
-      ConcatKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
-          inputs_data[0],
-          inputs_data[1],
-          in_col,
-          out_row,
-          out_col,
-          output->data<T>());
-    } else if (in_num == 3) {
-      ConcatKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
-          inputs_data[0],
-          inputs_data[1],
-          inputs_data[2],
-          in_col,
-          out_row,
-          out_col,
-          output->data<T>());
-    } else if (in_num == 4) {
-      ConcatKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
-          inputs_data[0],
-          inputs_data[1],
-          inputs_data[2],
-          inputs_data[3],
-          in_col,
-          out_row,
-          out_col,
-          output->data<T>());
-    } else {
-      ConcatKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
-          dev_ins_data, in_num, in_col, out_row, out_col, output->data<T>());
-    }
-  } else {
-    auto tmp_dev_ins_col_data =
-        paddle::memory::Alloc(context, inputs_col_num * sizeof(int64_t));
-
-    auto* restored = paddle::platform::RestoreHostMemIfCapturingCUDAGraph(
-        inputs_col, inputs_col_num);
-    paddle::memory::Copy(context.GetPlace(),
-                         tmp_dev_ins_col_data->ptr(),
-                         phi::CPUPlace(),
-                         restored,
-                         inputs_col_num * sizeof(int64_t),
-                         context.stream());
-    int64_t* dev_ins_col_data =
-        static_cast<int64_t*>(tmp_dev_ins_col_data->ptr());
-
-    ConcatKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
-        dev_ins_data,
-        dev_ins_col_data,
-        static_cast<int>(inputs_col_num),
-        out_row,
-        out_col,
-        output->data<T>());
-  }
-
-#ifdef PADDLE_WITH_HIP
-  // Prevent the pinned memory value from being covered and release the memory
-  // after the launch kernel of the stream is executed (reapply pinned memory
-  // next time)
-  auto* data_alloc_released = data_alloc.release();
-  auto* col_alloc_released = col_alloc.release();
-  context.AddStreamCallback([data_alloc_released, col_alloc_released] {
-    paddle::memory::allocation::Allocator::AllocationDeleter(
-        data_alloc_released);
-    paddle::memory::allocation::Allocator::AllocationDeleter(
-        col_alloc_released);
-  });
-#endif
-}
-
-/*
- * All tensors' dimension should be the same and the values of
- * each dimension must be the same, except the axis dimension.
- */
-template <typename T, typename Context>
-void SplitImpl(const Context& context,
-               const phi::DenseTensor& input,
-               const std::vector<const phi::DenseTensor*>& ref_inputs,
-               int axis,
-               std::vector<phi::DenseTensor*>* outputs) {
-  // NOTE(zhiqiu): split a tensor of shape [0,3,4] at axis=1, result in 3
-  // tensors of shape [0,1,4]
-  if (input.numel() == 0) {
-    return;
-  }
-
-  // TODO(zcd): Add input data validity checking
-  int o_num = outputs->size();
-  int64_t out_row = 1;
-  auto dim_0 = ref_inputs[0]->dims();
-  for (int i = 0; i < axis; ++i) {
-    out_row *= dim_0[i];
-  }
-
-  int64_t out0_col = ref_inputs[0]->numel() / out_row;
-  int64_t in_col = 0, in_row = out_row;
-  bool has_same_shape = true;
-
-  int outputs_cols_num = o_num + 1;
-  std::vector<T*> outputs_data_vec(o_num);
-  std::vector<int64_t> outputs_cols_vec(outputs_cols_num);
-  T** outputs_data = outputs_data_vec.data();
-  int64_t* outputs_cols = outputs_cols_vec.data();
-
-// There are some differences between hip runtime and NV runtime.
-// In NV, when the pageable memory data less than 64K is transferred from
-// hosttodevice, it will be automatically asynchronous.
-// However, only pinned memory in hip can copy asynchronously
-// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#concurrent-execution-host-device
-// 3.2.6.1. Concurrent Execution between Host and Device
-// Memory copies from host to device of a memory block of 64 KB or less
-#ifdef PADDLE_WITH_HIP
-  paddle::memory::AllocationPtr data_alloc, cols_alloc;
-  // TODO(chentianyu03): try to find a method to remove the Alloc function
-  data_alloc = paddle::memory::Alloc(paddle::platform::CUDAPinnedPlace(),
-                                     o_num * sizeof(T*));
-  outputs_data = reinterpret_cast<T**>(data_alloc->ptr());
-  // TODO(chentianyu03): try to find a method to remove the Alloc function
-  cols_alloc = paddle::memory::Alloc(paddle::platform::CUDAPinnedPlace(),
-                                     (outputs_cols_num) * sizeof(int64_t));
-  outputs_cols = reinterpret_cast<int64_t*>(cols_alloc->ptr());
-#endif
-
-  outputs_cols[0] = 0;
-  for (int i = 0; i < o_num; ++i) {
-    int64_t t_col = ref_inputs.at(i)->numel() / out_row;
-    if (has_same_shape) {
-      if (t_col != out0_col) has_same_shape = false;
-    }
-    in_col += t_col;
-    outputs_cols[i + 1] = in_col;
-    if (outputs->at(i) != nullptr) {
-      outputs_data[i] = outputs->at(i)->data<T>();
-    } else {
-      outputs_data[i] = nullptr;
-    }
-  }
-
-  dim3 block_dims;
-  dim3 grid_dims;
-  GetBlockDims(context, out_row, in_col, &block_dims, &grid_dims);
-
-  paddle::memory::allocation::AllocationPtr tmp_dev_outs_data;
-  T** dev_out_gpu_data = nullptr;
-  if (!has_same_shape || o_num < 2 || o_num > 4) {
-    // TODO(chentianyu03): try to find a method to remove the Alloc function
-    tmp_dev_outs_data = paddle::memory::Alloc(context, o_num * sizeof(T*));
-    auto* restored = paddle::platform::RestoreHostMemIfCapturingCUDAGraph(
-        outputs_data, o_num);
-    paddle::memory::Copy(context.GetPlace(),
-                         tmp_dev_outs_data->ptr(),
-                         phi::CPUPlace(),
-                         restored,
-                         o_num * sizeof(T*),
-                         context.stream());
-    dev_out_gpu_data = reinterpret_cast<T**>(tmp_dev_outs_data->ptr());
-  }
-
-  if (has_same_shape) {
-    if (o_num == 2) {
-      SplitKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
-          input.data<T>(),
-          in_row,
-          in_col,
-          out0_col,
-          outputs_data[0],
-          outputs_data[1]);
-    } else if (o_num == 3) {
-      SplitKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
-          input.data<T>(),
-          in_row,
-          in_col,
-          out0_col,
-          outputs_data[0],
-          outputs_data[1],
-          outputs_data[2]);
-    } else if (o_num == 4) {
-      SplitKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
-          input.data<T>(),
-          in_row,
-          in_col,
-          out0_col,
-          outputs_data[0],
-          outputs_data[1],
-          outputs_data[2],
-          outputs_data[3]);
-    } else {
-      SplitKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
-          input.data<T>(), in_row, in_col, out0_col,
diff --git a/paddle/phi/kernels/gpu/concat_kernel.cu b/paddle/phi/kernels/gpu/concat_kernel.cu
index b787b80c7e4ed9c10fafb139648e17fd91ca7529..2b04b979c20aa71cc723610d013cd12fb5537a29 100644
--- a/paddle/phi/kernels/gpu/concat_kernel.cu
+++ b/paddle/phi/kernels/gpu/concat_kernel.cu
@@ -22,8 +22,8 @@
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/lod_utils.h"
+#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
 #include "paddle/phi/kernels/funcs/concat_funcs.h"
-#include "paddle/phi/kernels/gpu/concat_and_split.h"
 
 namespace phi {
 
@@ -104,7 +104,8 @@ void ConcatKernel(const Context& dev_ctx,
         continue;
       }
     }
-    ConcatImpl<T, Context>(dev_ctx, inputs, axis, out);
+    phi::funcs::ConcatFunctor<Context, T> functor;
+    functor(dev_ctx, inputs, axis, out);
   }
 }
 
diff --git a/paddle/phi/kernels/gpu/split_kernel.cu b/paddle/phi/kernels/gpu/split_kernel.cu
index 5222fce03ace6fe30fce4aa9908794e348b79ad3..a698b9e716140b59b10a5799647e0a1aa7a8261d 100644
--- a/paddle/phi/kernels/gpu/split_kernel.cu
+++ b/paddle/phi/kernels/gpu/split_kernel.cu
@@ -18,7 +18,7 @@
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/kernel_registry.h"
 
-#include "paddle/phi/kernels/gpu/concat_and_split.h"
+#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
 namespace phi {
 
 template <typename T, typename Context>
@@ -53,7 +53,8 @@ void SplitKernel(const Context& dev_ctx,
     paddle::operators::StridedMemcpyWithAxis0<T>(
         dev_ctx, x, shape_refer, &outs);
   } else {
-    SplitImpl<T, Context>(dev_ctx, x, shape_refer, axis, &outs);
+    phi::funcs::SplitFunctor<Context, T> functor;
+    functor(dev_ctx, x, shape_refer, axis, &outs);
   }
 }
 
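Both the removed *Impl helpers and the phi::funcs functors that replace them view an N-D tensor as a 2-D [rows x cols] matrix: the dimensions before axis collapse into rows, the remaining dimensions into columns, and each output owns a contiguous column range. A minimal sketch of that bookkeeping, assuming plain std::vector dims instead of Paddle's DDim; RowsBeforeAxis and Numel are illustrative names, not Paddle APIs.

// Illustration only: the row/column arithmetic behind SplitImpl's
// out_row / in_col / outputs_cols bookkeeping.
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Dimensions before `axis` collapse into rows; the rest become columns.
int64_t RowsBeforeAxis(const std::vector<int64_t>& dims, int axis) {
  return std::accumulate(dims.begin(), dims.begin() + axis, int64_t{1},
                         std::multiplies<int64_t>());
}

int64_t Numel(const std::vector<int64_t>& dims) {
  return std::accumulate(dims.begin(), dims.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

int main() {
  // Split a [2, 3, 4] tensor at axis = 1 into three [2, 1, 4] outputs.
  std::vector<int64_t> in_dims = {2, 3, 4};
  int axis = 1;
  std::vector<std::vector<int64_t>> out_dims(3, {2, 1, 4});

  int64_t out_row = RowsBeforeAxis(in_dims, axis);  // 2 rows

  // Prefix sums of per-output column widths, like outputs_cols[] above.
  std::vector<int64_t> outputs_cols{0};
  for (const auto& d : out_dims)
    outputs_cols.push_back(outputs_cols.back() + Numel(d) / out_row);

  std::cout << "out_row=" << out_row
            << " in_col=" << outputs_cols.back() << '\n';  // out_row=2 in_col=12
  for (int64_t c : outputs_cols) std::cout << c << ' ';    // 0 4 8 12
  std::cout << '\n';
}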
diff --git a/paddle/phi/kernels/gpu/unbind_kernel.cu b/paddle/phi/kernels/gpu/unbind_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..1efc3a1094da253c27fec5108b536837d868425e
--- /dev/null
+++ b/paddle/phi/kernels/gpu/unbind_kernel.cu
@@ -0,0 +1,28 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/unbind_kernel_impl.h"
+#include "paddle/phi/kernels/unbind_kernel.h"
+
+PD_REGISTER_KERNEL(unbind,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::UnbindKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16,
+                   int,
+                   int64_t) {}
diff --git a/paddle/phi/kernels/impl/unbind_kernel_impl.h b/paddle/phi/kernels/impl/unbind_kernel_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..8a1342559bd908bf197e4949ce66f9b3e504b499
--- /dev/null
+++ b/paddle/phi/kernels/impl/unbind_kernel_impl.h
@@ -0,0 +1,40 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
+#include "paddle/phi/kernels/unbind_kernel.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void UnbindKernel(const Context& ctx,
+                  const DenseTensor& x,
+                  int axis,
+                  std::vector<DenseTensor*> outs) {
+  auto x_dims = x.dims();
+  axis = axis < 0 ? x_dims.size() + axis : axis;
+
+  std::vector<const DenseTensor*> shape_refer;
+  for (size_t j = 0; j < outs.size(); ++j) {
+    ctx.template Alloc<T>(outs[j]);
+    shape_refer.emplace_back(outs[j]);
+  }
+
+  phi::funcs::SplitFunctor<Context, T> functor;
+  functor(ctx, x, shape_refer, axis, &outs);
+}
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/unbind_kernel.h b/paddle/phi/kernels/unbind_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..30ee9a15d084e7d33ab6b392592be0c9b8f3789a
--- /dev/null
+++ b/paddle/phi/kernels/unbind_kernel.h
@@ -0,0 +1,31 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+/*
+ * All tensors must have the same number of dimensions, and the size of
+ * each dimension must match, except along the axis dimension.
+ */
+template <typename T, typename Context>
+void UnbindKernel(const Context& ctx,
+                  const DenseTensor& x,
+                  int axis,
+                  std::vector<DenseTensor*> outs);
+
+}  // namespace phi
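The new UnbindKernel reduces unbind to a split along axis: each output tensor already has the unbound dimension removed, shape_refer is built from the outputs themselves, and SplitFunctor copies one contiguous slice per output. A standalone illustration of those semantics on a row-major host buffer (plain C++, not Paddle code):

// Illustration only: unbinding a row-major [2, 3] matrix along axis 0
// yields its two rows; the copy loop stands in for SplitFunctor.
#include <iostream>
#include <vector>

int main() {
  std::vector<float> x = {1, 2, 3, 4, 5, 6};  // shape [2, 3], row-major
  const int rows = 2, cols = 3;

  // One output per slice along axis 0; each output has shape [3], i.e. the
  // unbound axis is dropped, where a 2-way split would keep shape [1, 3].
  std::vector<std::vector<float>> outs(rows, std::vector<float>(cols));
  for (int i = 0; i < rows; ++i)
    for (int j = 0; j < cols; ++j)
      outs[i][j] = x[i * cols + j];

  for (const auto& o : outs) {
    for (float v : o) std::cout << v << ' ';
    std::cout << '\n';  // prints "1 2 3" then "4 5 6"
  }
}

Because the dropped axis does not change an output's element count, the numel-based column arithmetic in SplitFunctor works unchanged, which is why the kernel needs no shape adjustment beyond normalizing a negative axis.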