diff --git a/paddle/fluid/operators/collective/c_split_op.cc b/paddle/fluid/operators/collective/c_split_op.cc index 1bca682dbab8f7041dbcde240da4885da5658729..dd65b99e3b7ee3baf077443628a4e96697a99afc 100644 --- a/paddle/fluid/operators/collective/c_split_op.cc +++ b/paddle/fluid/operators/collective/c_split_op.cc @@ -120,13 +120,3 @@ REGISTER_OPERATOR(c_split, ops::CSplitOpGradMaker, ops::CSplitOpGradMaker, ops::CSplitOpMaker); - -PD_REGISTER_STRUCT_KERNEL(c_split, - CPU, - ALL_LAYOUT, - ops::CSplitOpCPUKernel, - float, - double, - int, - int64_t, - plat::float16) {} diff --git a/paddle/fluid/operators/collective/c_split_op.cu b/paddle/fluid/operators/collective/c_split_op.cu deleted file mode 100644 index 0b3e2aaf781dbe227c646c2c2161d49b954d6829..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/collective/c_split_op.cu +++ /dev/null @@ -1,130 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#include - -#include "paddle/fluid/operators/collective/c_split_op.h" -#include "paddle/fluid/operators/math/concat_and_split.h" -#include "paddle/phi/backends/gpu/gpu_primitives.h" - -namespace paddle { -namespace operators { - -static constexpr int64_t kNumCUDAThreads = 512; -static constexpr int64_t kNumMaxinumNumBlocks = 4096; - -static inline int64_t NumBlocks(const int64_t N) { - return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, - kNumMaxinumNumBlocks); -} - -template -__global__ void SplitFromRank(const T* input, - T* output, - const int64_t rows, - const int64_t columns, - const int rank, - const int nranks, - const int64_t limit) { - CUDA_KERNEL_LOOP_TYPE(i, limit, int64_t) { - int64_t row = i / columns; - int64_t col = i % columns; - - int64_t block = columns / nranks; - int64_t start = block * rank; - int64_t end = start + block; - - if (col >= start && col < end) { - int64_t idx = block * row + col % block; - output[idx] = input[i]; - } - } -} - -template -class CSplitOpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto x = ctx.Input("X"); - auto out = ctx.Output("Out"); - - int nranks = ctx.Attr("nranks"); - int rank = ctx.Attr("rank"); - auto place = ctx.GetPlace(); - - PADDLE_ENFORCE_GE(rank, - 0, - platform::errors::PreconditionNotMet( - "The value of rank (%d) for c_split must be " - "greater than or equal to 0.", - rank)); - PADDLE_ENFORCE_GE(nranks, - 2, - platform::errors::PreconditionNotMet( - "The value of nranks (%d) for c_split must be " - "greater than or equal to 2.", - nranks)); - PADDLE_ENFORCE_LT(rank, - nranks, - platform::errors::PreconditionNotMet( - "The value of rank (%d) for c_split must be " - "less than that of nranks (%d).", - rank, - nranks)); - - auto& dev_ctx = ctx.template device_context(); - auto dims = x->dims(); - auto dims_size = dims.size(); - // final dim - int64_t end_size = dims[dims_size - 1]; - - // remain dim - auto 
remain_ddim = phi::slice_ddim(dims, 0, dims_size - 1); - int64_t remain_numel = phi::product(remain_ddim); - - int64_t limit = x->numel(); - int64_t blocks = NumBlocks(limit); - int64_t threads = kNumCUDAThreads; - - dims[dims_size - 1] /= nranks; - out->mutable_data(dims, place); - - SplitFromRank<<>>(x->data(), - out->data(), - remain_numel, - end_size, - rank, - nranks, - limit); - } -}; -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; - -PD_REGISTER_STRUCT_KERNEL(c_split, - GPU, - ALL_LAYOUT, - ops::CSplitOpCUDAKernel, - float, - double, - int, - int64_t, -#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000 - plat::bfloat16, -#endif - plat::float16) { -} diff --git a/paddle/fluid/operators/collective/c_split_op.h b/paddle/fluid/operators/collective/c_split_op.h index bd120af04375fd958026360efde5a2002314179d..4abcbf058ded83e008594c181c93fc9683eb9f3b 100644 --- a/paddle/fluid/operators/collective/c_split_op.h +++ b/paddle/fluid/operators/collective/c_split_op.h @@ -28,10 +28,7 @@ namespace operators { template class CSplitOpCPUKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& ctx UNUSED) const override { - PADDLE_THROW(platform::errors::Unavailable( - "Do not support c_split for cpu kernel now.")); - } + void Compute(const framework::ExecutionContext& ctx UNUSED) const override {} }; } // namespace operators diff --git a/paddle/fluid/operators/collective/c_split_op_xpu.cc b/paddle/fluid/operators/collective/c_split_op_xpu.cc deleted file mode 100644 index d573a83d708c4f79a3263bba714de4090f2574dc..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/collective/c_split_op_xpu.cc +++ /dev/null @@ -1,96 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include - -#include "paddle/fluid/operators/collective/c_split_op.h" -#if defined(PADDLE_WITH_XPU) -#include "paddle/phi/backends/xpu/enforce_xpu.h" -#endif - -namespace paddle { -namespace operators { - -template -class CSplitOpXPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - using XPUType = typename XPUTypeTrait::Type; - auto x = ctx.Input("X"); - auto out = ctx.Output("Out"); - - int nranks = ctx.Attr("nranks"); - int rank = ctx.Attr("rank"); - - PADDLE_ENFORCE_GE(rank, - 0, - platform::errors::PreconditionNotMet( - "The value of rank (%d) for c_split must be " - "greater than or equal to 0.", - rank)); - PADDLE_ENFORCE_GE(nranks, - 2, - platform::errors::PreconditionNotMet( - "The value of nranks (%d) for c_split must be " - "greater than or equal to 2.", - nranks)); - PADDLE_ENFORCE_LT(rank, - nranks, - platform::errors::PreconditionNotMet( - "The value of rank (%d) for c_split must be " - "less than that of nranks (%d).", - rank, - nranks)); - - auto& dev_ctx = ctx.template device_context(); - auto dims = x->dims(); - auto dims_size = dims.size(); - // final dim - int64_t end_size = dims[dims_size - 1]; - - // remain dim - auto remain_ddim = phi::slice_ddim(dims, 0, dims_size - 1); - int64_t remain_numel = phi::product(remain_ddim); - - dims[dims_size - 1] /= nranks; - out->Resize(dims); - dev_ctx.template Alloc(out, x->dtype()); - - std::vector output_list(nranks, nullptr); - output_list.at(rank) = reinterpret_cast(out->data()); - std::vector split_list(nranks, dims[dims_size - 1]); 
- int axis = 1; - - auto ret = xpu::split(dev_ctx.x_context(), - reinterpret_cast(x->data()), - output_list, - {remain_numel, end_size}, - split_list, - axis); - PADDLE_ENFORCE_XDNN_SUCCESS(ret, "split"); - } -}; -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; - -PD_REGISTER_STRUCT_KERNEL(c_split, - XPU, - ALL_LAYOUT, - ops::CSplitOpXPUKernel, - float, - int, - plat::float16) {} diff --git a/paddle/phi/kernels/c_split_kernel.h b/paddle/phi/kernels/c_split_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..6ec945e123a8144a0d42dc0afbac42092c8feeb2 --- /dev/null +++ b/paddle/phi/kernels/c_split_kernel.h @@ -0,0 +1,31 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void CSplitKernel(const Context& ctx, + const DenseTensor& x, + int rank, + int nranks, + int ring_id, + bool use_calc_stream, + bool use_model_parallel, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/kernels/cpu/c_split_kernel.cc b/paddle/phi/kernels/cpu/c_split_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..c6d22ec8cf908ed046eff76205d6e50b8470968e --- /dev/null +++ b/paddle/phi/kernels/cpu/c_split_kernel.cc @@ -0,0 +1,43 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/c_split_kernel.h" + +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void CSplitKernel(const Context& ctx, + const DenseTensor& x, + int rank, + int nranks, + int ring_id, + bool use_calc_stream, + bool use_model_parallel, + DenseTensor* out) { + PADDLE_THROW( + phi::errors::Unavailable("Do not support c_split for cpu kernel now.")); +} +} // namespace phi + +PD_REGISTER_KERNEL(c_split, + CPU, + ALL_LAYOUT, + phi::CSplitKernel, + float, + double, + int, + int64_t, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/c_split_kernel.cu b/paddle/phi/kernels/gpu/c_split_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..2fda7d3cf37f0dda03d26b2e14387b85475d8e85 --- /dev/null +++ b/paddle/phi/kernels/gpu/c_split_kernel.cu @@ -0,0 +1,127 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/c_split_kernel.h" + +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +static constexpr int64_t kNumCUDAThreads = 512; +static constexpr int64_t kNumMaxinumNumBlocks = 4096; + +static inline int64_t NumBlocks(const int64_t N) { + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + kNumMaxinumNumBlocks); +} + +template +__global__ void SplitFromRank(const T* input, + T* output, + const int64_t rows, + const int64_t columns, + const int rank, + const int nranks, + const int64_t limit) { + CUDA_KERNEL_LOOP_TYPE(i, limit, int64_t) { + int64_t row = i / columns; + int64_t col = i % columns; + + int64_t block = columns / nranks; + int64_t start = block * rank; + int64_t end = start + block; + + if (col >= start && col < end) { + int64_t idx = block * row + col % block; + output[idx] = input[i]; + } + } +} + +template +void CSplitKernel(const Context& ctx, + const DenseTensor& x, + int rank, + int nranks, + int ring_id, + bool use_calc_stream, + bool use_model_parallel, + DenseTensor* out) { + auto place = ctx.GetPlace(); + + PADDLE_ENFORCE_GE(rank, + 0, + phi::errors::PreconditionNotMet( + "The value of rank (%d) for c_split must be " + "greater than or equal to 0.", + rank)); + PADDLE_ENFORCE_GE(nranks, + 2, + phi::errors::PreconditionNotMet( + "The value of nranks (%d) for c_split must be " + "greater than or equal to 2.", + nranks)); + PADDLE_ENFORCE_LT(rank, + nranks, + phi::errors::PreconditionNotMet( + "The value of rank (%d) for c_split must be " + "less than that of nranks (%d).", + rank, + nranks)); + + auto dims = x.dims(); + auto dims_size = dims.size(); + // final dim + int64_t end_size = dims[dims_size - 1]; + + // remain dim + auto remain_ddim = phi::slice_ddim(dims, 0, dims_size - 1); + int64_t remain_numel = phi::product(remain_ddim); + + int64_t limit = x.numel(); + int64_t blocks = 
NumBlocks(limit); + int64_t threads = kNumCUDAThreads; + + dims[dims_size - 1] /= nranks; + out->Resize(dims); + ctx.template Alloc(out); + + SplitFromRank<<>>( + x.data(), out->data(), remain_numel, end_size, rank, nranks, limit); +} + +} // namespace phi + +#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000 +PD_REGISTER_KERNEL(c_split, + GPU, + ALL_LAYOUT, + phi::CSplitKernel, + float, + double, + int, + int64_t, + phi::dtype::bfloat16, + phi::dtype::float16) {} +#else +PD_REGISTER_KERNEL(c_split, + GPU, + ALL_LAYOUT, + phi::CSplitKernel, + float, + double, + int, + int64_t, + phi::dtype::float16) {} +#endif diff --git a/paddle/phi/kernels/xpu/c_split_kernel.cc b/paddle/phi/kernels/xpu/c_split_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..f330323059e2b72a7213fad4383bf25134b73280 --- /dev/null +++ b/paddle/phi/kernels/xpu/c_split_kernel.cc @@ -0,0 +1,87 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/c_split_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void CSplitKernel(const Context& dev_ctx, + const DenseTensor& x, + int rank, + int nranks, + int ring_id, + bool use_calc_stream, + bool use_model_parallel, + DenseTensor* out) { + using XPUType = typename XPUTypeTrait::Type; + + PADDLE_ENFORCE_GE(rank, + 0, + phi::errors::PreconditionNotMet( + "The value of rank (%d) for c_split must be " + "greater than or equal to 0.", + rank)); + PADDLE_ENFORCE_GE(nranks, + 2, + phi::errors::PreconditionNotMet( + "The value of nranks (%d) for c_split must be " + "greater than or equal to 2.", + nranks)); + PADDLE_ENFORCE_LT(rank, + nranks, + phi::errors::PreconditionNotMet( + "The value of rank (%d) for c_split must be " + "less than that of nranks (%d).", + rank, + nranks)); + + auto dims = x.dims(); + auto dims_size = dims.size(); + // final dim + int64_t end_size = dims[dims_size - 1]; + + // remain dim + auto remain_ddim = phi::slice_ddim(dims, 0, dims_size - 1); + int64_t remain_numel = phi::product(remain_ddim); + + dims[dims_size - 1] /= nranks; + out->Resize(dims); + dev_ctx.template Alloc(out, x.dtype()); + + std::vector output_list(nranks, nullptr); + output_list.at(rank) = reinterpret_cast(out->data()); + std::vector split_list(nranks, dims[dims_size - 1]); + int axis = 1; + + auto ret = xpu::split(dev_ctx.x_context(), + reinterpret_cast(x.data()), + output_list, + {remain_numel, end_size}, + split_list, + axis); + PADDLE_ENFORCE_XDNN_SUCCESS(ret, "split"); +} +} // namespace phi + +PD_REGISTER_KERNEL(c_split, + XPU, + ALL_LAYOUT, + phi::CSplitKernel, + float, + int, + phi::dtype::float16) {} diff --git a/paddle/phi/ops/compat/c_split_sig.cc b/paddle/phi/ops/compat/c_split_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..53cd79c755ab583ede0887925f91d788926dfc22 --- /dev/null +++ 
b/paddle/phi/ops/compat/c_split_sig.cc
@@ -0,0 +1,30 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+// Builds the fixed signature tying op c_split to kernel "c_split"; ctx unused.
+KernelSignature CSplitOpArgumentMapping(
+    const ArgumentMappingContext& ctx UNUSED) {
+  return KernelSignature(
+      "c_split",  // kernel name
+      {"X"},      // input names
+      {"rank", "nranks", "ring_id", "use_calc_stream", "use_model_parallel"},
+      {"Out"});   // output names; the list above holds the attribute names
+}
+
+}  // namespace phi
+// Registers the mapping function above for the c_split operator.
+PD_REGISTER_ARG_MAPPING_FN(c_split, phi::CSplitOpArgumentMapping);