Unverified · Commit 5dc7ff04 authored by Ruibin Cheung, committed by GitHub

[Fluid] NO.4 Migrate c_split to PHI (#56327)

Parent 332a73b1
@@ -120,13 +120,3 @@ REGISTER_OPERATOR(c_split,
ops::CSplitOpGradMaker<paddle::framework::OpDesc>,
ops::CSplitOpGradMaker<paddle::imperative::OpBase>,
ops::CSplitOpMaker);
PD_REGISTER_STRUCT_KERNEL(c_split,
CPU,
ALL_LAYOUT,
ops::CSplitOpCPUKernel,
float,
double,
int,
int64_t,
plat::float16) {}
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <vector>
#include "paddle/fluid/operators/collective/c_split_op.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
namespace paddle {
namespace operators {
static constexpr int64_t kNumCUDAThreads = 512;
static constexpr int64_t kNumMaxinumNumBlocks = 4096;
static inline int64_t NumBlocks(const int64_t N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaxinumNumBlocks);
}
template <typename T>
__global__ void SplitFromRank(const T* input,
T* output,
const int64_t rows,
const int64_t columns,
const int rank,
const int nranks,
const int64_t limit) {
CUDA_KERNEL_LOOP_TYPE(i, limit, int64_t) {
int64_t row = i / columns;
int64_t col = i % columns;
int64_t block = columns / nranks;
int64_t start = block * rank;
int64_t end = start + block;
if (col >= start && col < end) {
int64_t idx = block * row + col % block;
output[idx] = input[i];
}
}
}
template <typename T, typename DeviceContext>
class CSplitOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto x = ctx.Input<phi::DenseTensor>("X");
auto out = ctx.Output<phi::DenseTensor>("Out");
int nranks = ctx.Attr<int>("nranks");
int rank = ctx.Attr<int>("rank");
auto place = ctx.GetPlace();
PADDLE_ENFORCE_GE(rank,
0,
platform::errors::PreconditionNotMet(
"The value of rank (%d) for c_split must be "
"greater than or equal to 0.",
rank));
PADDLE_ENFORCE_GE(nranks,
2,
platform::errors::PreconditionNotMet(
"The value of nranks (%d) for c_split must be "
"greater than or equal to 2.",
nranks));
PADDLE_ENFORCE_LT(rank,
nranks,
platform::errors::PreconditionNotMet(
"The value of rank (%d) for c_split must be "
"less than that of nranks (%d).",
rank,
nranks));
auto& dev_ctx = ctx.template device_context<phi::GPUContext>();
auto dims = x->dims();
auto dims_size = dims.size();
// final dim
int64_t end_size = dims[dims_size - 1];
// remain dim
auto remain_ddim = phi::slice_ddim(dims, 0, dims_size - 1);
int64_t remain_numel = phi::product(remain_ddim);
int64_t limit = x->numel();
int64_t blocks = NumBlocks(limit);
int64_t threads = kNumCUDAThreads;
dims[dims_size - 1] /= nranks;
out->mutable_data<T>(dims, place);
SplitFromRank<T><<<blocks, threads, 0, dev_ctx.stream()>>>(x->data<T>(),
out->data<T>(),
remain_numel,
end_size,
rank,
nranks,
limit);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
PD_REGISTER_STRUCT_KERNEL(c_split,
GPU,
ALL_LAYOUT,
ops::CSplitOpCUDAKernel,
float,
double,
int,
int64_t,
#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
plat::bfloat16,
#endif
plat::float16) {
}
@@ -28,10 +28,7 @@ namespace operators {
template <typename T, typename DeviceContext>
class CSplitOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx UNUSED) const override {
PADDLE_THROW(platform::errors::Unavailable(
"Do not support c_split for cpu kernel now."));
}
void Compute(const framework::ExecutionContext& ctx UNUSED) const override {}
};
} // namespace operators
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <vector>
#include "paddle/fluid/operators/collective/c_split_op.h"
#if defined(PADDLE_WITH_XPU)
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#endif
namespace paddle {
namespace operators {
template <typename T, typename DeviceContext>
class CSplitOpXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using XPUType = typename XPUTypeTrait<T>::Type;
auto x = ctx.Input<phi::DenseTensor>("X");
auto out = ctx.Output<phi::DenseTensor>("Out");
int nranks = ctx.Attr<int>("nranks");
int rank = ctx.Attr<int>("rank");
PADDLE_ENFORCE_GE(rank,
0,
platform::errors::PreconditionNotMet(
"The value of rank (%d) for c_split must be "
"greater than or equal to 0.",
rank));
PADDLE_ENFORCE_GE(nranks,
2,
platform::errors::PreconditionNotMet(
"The value of nranks (%d) for c_split must be "
"greater than or equal to 2.",
nranks));
PADDLE_ENFORCE_LT(rank,
nranks,
platform::errors::PreconditionNotMet(
"The value of rank (%d) for c_split must be "
"less than that of nranks (%d).",
rank,
nranks));
auto& dev_ctx = ctx.template device_context<phi::XPUContext>();
auto dims = x->dims();
auto dims_size = dims.size();
// final dim
int64_t end_size = dims[dims_size - 1];
// remain dim
auto remain_ddim = phi::slice_ddim(dims, 0, dims_size - 1);
int64_t remain_numel = phi::product(remain_ddim);
dims[dims_size - 1] /= nranks;
out->Resize(dims);
dev_ctx.template Alloc(out, x->dtype());
std::vector<XPUType*> output_list(nranks, nullptr);
output_list.at(rank) = reinterpret_cast<XPUType*>(out->data<T>());
std::vector<int64_t> split_list(nranks, dims[dims_size - 1]);
int axis = 1;
auto ret = xpu::split(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x->data<T>()),
output_list,
{remain_numel, end_size},
split_list,
axis);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "split");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
PD_REGISTER_STRUCT_KERNEL(c_split,
XPU,
ALL_LAYOUT,
ops::CSplitOpXPUKernel,
float,
int,
plat::float16) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void CSplitKernel(const Context& ctx,
const DenseTensor& x,
int rank,
int nranks,
int ring_id,
bool use_calc_stream,
bool use_model_parallel,
DenseTensor* out);
} // namespace phi
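As a reading aid for the declaration above, here is a minimal host-side sketch of what c_split computes: the last dimension of X is divided into nranks equal contiguous blocks and only the block owned by rank is kept. CSplitReference is a hypothetical helper written in plain C++ for illustration; it is not part of this commit or of the PHI API.

// Hypothetical reference (not in this commit): keep the rank-th of nranks
// equal blocks along the last dimension of x, viewed as [rows, columns].
#include <cstdint>
#include <vector>

std::vector<float> CSplitReference(const std::vector<float>& x,
                                   int64_t rows,     // product of all leading dims
                                   int64_t columns,  // size of the last dim
                                   int rank,
                                   int nranks) {
  const int64_t block = columns / nranks;  // assumes columns % nranks == 0
  std::vector<float> out(static_cast<size_t>(rows * block));
  for (int64_t r = 0; r < rows; ++r) {
    for (int64_t c = 0; c < block; ++c) {
      // Column rank*block + c of the full input becomes column c of the output.
      out[r * block + c] = x[r * columns + rank * block + c];
    }
  }
  return out;
}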
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/c_split_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void CSplitKernel(const Context& ctx,
const DenseTensor& x,
int rank,
int nranks,
int ring_id,
bool use_calc_stream,
bool use_model_parallel,
DenseTensor* out) {
PADDLE_THROW(
phi::errors::Unavailable("Do not support c_split for cpu kernel now."));
}
} // namespace phi
PD_REGISTER_KERNEL(c_split,
CPU,
ALL_LAYOUT,
phi::CSplitKernel,
float,
double,
int,
int64_t,
phi::dtype::float16) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/c_split_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
static constexpr int64_t kNumCUDAThreads = 512;
static constexpr int64_t kNumMaxinumNumBlocks = 4096;
static inline int64_t NumBlocks(const int64_t N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaxinumNumBlocks);
}
template <typename T>
__global__ void SplitFromRank(const T* input,
T* output,
const int64_t rows,
const int64_t columns,
const int rank,
const int nranks,
const int64_t limit) {
CUDA_KERNEL_LOOP_TYPE(i, limit, int64_t) {
int64_t row = i / columns;
int64_t col = i % columns;
int64_t block = columns / nranks;
int64_t start = block * rank;
int64_t end = start + block;
if (col >= start && col < end) {
int64_t idx = block * row + col % block;
output[idx] = input[i];
}
}
}
template <typename T, typename Context>
void CSplitKernel(const Context& ctx,
const DenseTensor& x,
int rank,
int nranks,
int ring_id,
bool use_calc_stream,
bool use_model_parallel,
DenseTensor* out) {
auto place = ctx.GetPlace();
PADDLE_ENFORCE_GE(rank,
0,
phi::errors::PreconditionNotMet(
"The value of rank (%d) for c_split must be "
"greater than or equal to 0.",
rank));
PADDLE_ENFORCE_GE(nranks,
2,
phi::errors::PreconditionNotMet(
"The value of nranks (%d) for c_split must be "
"greater than or equal to 2.",
nranks));
PADDLE_ENFORCE_LT(rank,
nranks,
phi::errors::PreconditionNotMet(
"The value of rank (%d) for c_split must be "
"less than that of nranks (%d).",
rank,
nranks));
auto dims = x.dims();
auto dims_size = dims.size();
// final dim
int64_t end_size = dims[dims_size - 1];
// remain dim
auto remain_ddim = phi::slice_ddim(dims, 0, dims_size - 1);
int64_t remain_numel = phi::product(remain_ddim);
int64_t limit = x.numel();
int64_t blocks = NumBlocks(limit);
int64_t threads = kNumCUDAThreads;
dims[dims_size - 1] /= nranks;
out->Resize(dims);
ctx.template Alloc<T>(out);
SplitFromRank<T><<<blocks, threads, 0, ctx.stream()>>>(
x.data<T>(), out->data<T>(), remain_numel, end_size, rank, nranks, limit);
}
} // namespace phi
#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
PD_REGISTER_KERNEL(c_split,
GPU,
ALL_LAYOUT,
phi::CSplitKernel,
float,
double,
int,
int64_t,
phi::dtype::bfloat16,
phi::dtype::float16) {}
#else
PD_REGISTER_KERNEL(c_split,
GPU,
ALL_LAYOUT,
phi::CSplitKernel,
float,
double,
int,
int64_t,
phi::dtype::float16) {}
#endif
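A note on the launch configuration shared by the fluid and PHI GPU kernels above: NumBlocks caps the grid at 4096 blocks of 512 threads, and CUDA_KERNEL_LOOP_TYPE iterates with a grid stride, so inputs larger than blocks * threads are still fully covered. Below is a standalone sketch of the same arithmetic in plain C++, for illustration only; the constants mirror kNumCUDAThreads and kNumMaxinumNumBlocks from the file.

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  constexpr int64_t kThreads = 512;     // kNumCUDAThreads
  constexpr int64_t kMaxBlocks = 4096;  // kNumMaxinumNumBlocks
  const int64_t sizes[] = {1000, 3000000, 100000000};
  for (int64_t n : sizes) {
    const int64_t blocks = std::min((n + kThreads - 1) / kThreads, kMaxBlocks);
    // Each grid-stride pass visits blocks * kThreads elements; the loop inside
    // SplitFromRank repeats until all n elements have been processed.
    std::printf("N=%lld -> blocks=%lld (threads=%lld per block)\n",
                static_cast<long long>(n), static_cast<long long>(blocks),
                static_cast<long long>(kThreads));
  }
  return 0;
}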
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/c_split_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void CSplitKernel(const Context& dev_ctx,
const DenseTensor& x,
int rank,
int nranks,
int ring_id,
bool use_calc_stream,
bool use_model_parallel,
DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;
PADDLE_ENFORCE_GE(rank,
0,
phi::errors::PreconditionNotMet(
"The value of rank (%d) for c_split must be "
"greater than or equal to 0.",
rank));
PADDLE_ENFORCE_GE(nranks,
2,
phi::errors::PreconditionNotMet(
"The value of nranks (%d) for c_split must be "
"greater than or equal to 2.",
nranks));
PADDLE_ENFORCE_LT(rank,
nranks,
phi::errors::PreconditionNotMet(
"The value of rank (%d) for c_split must be "
"less than that of nranks (%d).",
rank,
nranks));
auto dims = x.dims();
auto dims_size = dims.size();
// final dim
int64_t end_size = dims[dims_size - 1];
// remain dim
auto remain_ddim = phi::slice_ddim(dims, 0, dims_size - 1);
int64_t remain_numel = phi::product(remain_ddim);
dims[dims_size - 1] /= nranks;
out->Resize(dims);
dev_ctx.template Alloc(out, x.dtype());
std::vector<XPUType*> output_list(nranks, nullptr);
output_list.at(rank) = reinterpret_cast<XPUType*>(out->data<T>());
std::vector<int64_t> split_list(nranks, dims[dims_size - 1]);
int axis = 1;
auto ret = xpu::split(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
output_list,
{remain_numel, end_size},
split_list,
axis);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "split");
}
} // namespace phi
PD_REGISTER_KERNEL(c_split,
XPU,
ALL_LAYOUT,
phi::CSplitKernel,
float,
int,
phi::dtype::float16) {}
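Both the GPU and XPU paths flatten the input to a 2-D view before splitting: remain_numel (the product of every dimension except the last) becomes the row count and end_size (the last dimension) the column count, so splitting axis 1 of the view is exactly a split of the original last dimension. The sketch below illustrates that flattening; FlattenToMatrix is a hypothetical standard-C++ helper, not code from this commit, and assumes a non-empty dims vector.

#include <cstdint>
#include <functional>
#include <numeric>
#include <utility>
#include <vector>

// Collapse [d0, d1, ..., dk] into a 2-D view [rows, columns] where
// rows = d0 * ... * d(k-1) (remain_numel) and columns = dk (end_size).
std::pair<int64_t, int64_t> FlattenToMatrix(const std::vector<int64_t>& dims) {
  const int64_t columns = dims.back();
  const int64_t rows = std::accumulate(dims.begin(), dims.end() - 1,
                                       int64_t{1}, std::multiplies<int64_t>());
  return {rows, columns};
}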
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature CSplitOpArgumentMapping(
const ArgumentMappingContext& ctx UNUSED) {
return KernelSignature(
"c_split",
{"X"},
{"rank", "nranks", "ring_id", "use_calc_stream", "use_model_parallel"},
{"Out"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(c_split, phi::CSplitOpArgumentMapping);