Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
5dc7ff04
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5dc7ff04
编写于
8月 22, 2023
作者:
R
Ruibin Cheung
提交者:
GitHub
8月 22, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Fluid] NO.4 Migrate c_split to PHI (#56327)
上级
332a73b1
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
319 addition
and
240 deletion
+319
-240
paddle/fluid/operators/collective/c_split_op.cc
paddle/fluid/operators/collective/c_split_op.cc
+0
-10
paddle/fluid/operators/collective/c_split_op.cu
paddle/fluid/operators/collective/c_split_op.cu
+0
-130
paddle/fluid/operators/collective/c_split_op.h
paddle/fluid/operators/collective/c_split_op.h
+1
-4
paddle/fluid/operators/collective/c_split_op_xpu.cc
paddle/fluid/operators/collective/c_split_op_xpu.cc
+0
-96
paddle/phi/kernels/c_split_kernel.h
paddle/phi/kernels/c_split_kernel.h
+31
-0
paddle/phi/kernels/cpu/c_split_kernel.cc
paddle/phi/kernels/cpu/c_split_kernel.cc
+43
-0
paddle/phi/kernels/gpu/c_split_kernel.cu
paddle/phi/kernels/gpu/c_split_kernel.cu
+127
-0
paddle/phi/kernels/xpu/c_split_kernel.cc
paddle/phi/kernels/xpu/c_split_kernel.cc
+87
-0
paddle/phi/ops/compat/c_split_sig.cc
paddle/phi/ops/compat/c_split_sig.cc
+30
-0
未找到文件。
paddle/fluid/operators/collective/c_split_op.cc
浏览文件 @
5dc7ff04
...
...
@@ -120,13 +120,3 @@ REGISTER_OPERATOR(c_split,
ops
::
CSplitOpGradMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
CSplitOpGradMaker
<
paddle
::
imperative
::
OpBase
>
,
ops
::
CSplitOpMaker
);
PD_REGISTER_STRUCT_KERNEL
(
c_split
,
CPU
,
ALL_LAYOUT
,
ops
::
CSplitOpCPUKernel
,
float
,
double
,
int
,
int64_t
,
plat
::
float16
)
{}
paddle/fluid/operators/collective/c_split_op.cu
已删除
100644 → 0
浏览文件 @
332a73b1
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <vector>
#include "paddle/fluid/operators/collective/c_split_op.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
namespace
paddle
{
namespace
operators
{
static
constexpr
int64_t
kNumCUDAThreads
=
512
;
static
constexpr
int64_t
kNumMaxinumNumBlocks
=
4096
;
static
inline
int64_t
NumBlocks
(
const
int64_t
N
)
{
return
std
::
min
((
N
+
kNumCUDAThreads
-
1
)
/
kNumCUDAThreads
,
kNumMaxinumNumBlocks
);
}
template
<
typename
T
>
__global__
void
SplitFromRank
(
const
T
*
input
,
T
*
output
,
const
int64_t
rows
,
const
int64_t
columns
,
const
int
rank
,
const
int
nranks
,
const
int64_t
limit
)
{
CUDA_KERNEL_LOOP_TYPE
(
i
,
limit
,
int64_t
)
{
int64_t
row
=
i
/
columns
;
int64_t
col
=
i
%
columns
;
int64_t
block
=
columns
/
nranks
;
int64_t
start
=
block
*
rank
;
int64_t
end
=
start
+
block
;
if
(
col
>=
start
&&
col
<
end
)
{
int64_t
idx
=
block
*
row
+
col
%
block
;
output
[
idx
]
=
input
[
i
];
}
}
}
template
<
typename
T
,
typename
DeviceContext
>
class
CSplitOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
x
=
ctx
.
Input
<
phi
::
DenseTensor
>
(
"X"
);
auto
out
=
ctx
.
Output
<
phi
::
DenseTensor
>
(
"Out"
);
int
nranks
=
ctx
.
Attr
<
int
>
(
"nranks"
);
int
rank
=
ctx
.
Attr
<
int
>
(
"rank"
);
auto
place
=
ctx
.
GetPlace
();
PADDLE_ENFORCE_GE
(
rank
,
0
,
platform
::
errors
::
PreconditionNotMet
(
"The value of rank (%d) for c_split must be "
"greater than or equal to 0."
,
rank
));
PADDLE_ENFORCE_GE
(
nranks
,
2
,
platform
::
errors
::
PreconditionNotMet
(
"The value of nranks (%d) for c_split must be "
"greater than or equal to 2."
,
nranks
));
PADDLE_ENFORCE_LT
(
rank
,
nranks
,
platform
::
errors
::
PreconditionNotMet
(
"The value of rank (%d) for c_split must be "
"less than that of nranks (%d)."
,
rank
,
nranks
));
auto
&
dev_ctx
=
ctx
.
template
device_context
<
phi
::
GPUContext
>();
auto
dims
=
x
->
dims
();
auto
dims_size
=
dims
.
size
();
// final dim
int64_t
end_size
=
dims
[
dims_size
-
1
];
// remain dim
auto
remain_ddim
=
phi
::
slice_ddim
(
dims
,
0
,
dims_size
-
1
);
int64_t
remain_numel
=
phi
::
product
(
remain_ddim
);
int64_t
limit
=
x
->
numel
();
int64_t
blocks
=
NumBlocks
(
limit
);
int64_t
threads
=
kNumCUDAThreads
;
dims
[
dims_size
-
1
]
/=
nranks
;
out
->
mutable_data
<
T
>
(
dims
,
place
);
SplitFromRank
<
T
><<<
blocks
,
threads
,
0
,
dev_ctx
.
stream
()
>>>
(
x
->
data
<
T
>
(),
out
->
data
<
T
>
(),
remain_numel
,
end_size
,
rank
,
nranks
,
limit
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
PD_REGISTER_STRUCT_KERNEL
(
c_split
,
GPU
,
ALL_LAYOUT
,
ops
::
CSplitOpCUDAKernel
,
float
,
double
,
int
,
int64_t
,
#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
plat
::
bfloat16
,
#endif
plat
::
float16
)
{
}
paddle/fluid/operators/collective/c_split_op.h
浏览文件 @
5dc7ff04
...
...
@@ -28,10 +28,7 @@ namespace operators {
template
<
typename
T
,
typename
DeviceContext
>
class
CSplitOpCPUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
UNUSED
)
const
override
{
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"Do not support c_split for cpu kernel now."
));
}
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
UNUSED
)
const
override
{}
};
}
// namespace operators
...
...
paddle/fluid/operators/collective/c_split_op_xpu.cc
已删除
100644 → 0
浏览文件 @
332a73b1
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <vector>
#include "paddle/fluid/operators/collective/c_split_op.h"
#if defined(PADDLE_WITH_XPU)
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#endif
namespace
paddle
{
namespace
operators
{
template
<
typename
T
,
typename
DeviceContext
>
class
CSplitOpXPUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
using
XPUType
=
typename
XPUTypeTrait
<
T
>::
Type
;
auto
x
=
ctx
.
Input
<
phi
::
DenseTensor
>
(
"X"
);
auto
out
=
ctx
.
Output
<
phi
::
DenseTensor
>
(
"Out"
);
int
nranks
=
ctx
.
Attr
<
int
>
(
"nranks"
);
int
rank
=
ctx
.
Attr
<
int
>
(
"rank"
);
PADDLE_ENFORCE_GE
(
rank
,
0
,
platform
::
errors
::
PreconditionNotMet
(
"The value of rank (%d) for c_split must be "
"greater than or equal to 0."
,
rank
));
PADDLE_ENFORCE_GE
(
nranks
,
2
,
platform
::
errors
::
PreconditionNotMet
(
"The value of nranks (%d) for c_split must be "
"greater than or equal to 2."
,
nranks
));
PADDLE_ENFORCE_LT
(
rank
,
nranks
,
platform
::
errors
::
PreconditionNotMet
(
"The value of rank (%d) for c_split must be "
"less than that of nranks (%d)."
,
rank
,
nranks
));
auto
&
dev_ctx
=
ctx
.
template
device_context
<
phi
::
XPUContext
>();
auto
dims
=
x
->
dims
();
auto
dims_size
=
dims
.
size
();
// final dim
int64_t
end_size
=
dims
[
dims_size
-
1
];
// remain dim
auto
remain_ddim
=
phi
::
slice_ddim
(
dims
,
0
,
dims_size
-
1
);
int64_t
remain_numel
=
phi
::
product
(
remain_ddim
);
dims
[
dims_size
-
1
]
/=
nranks
;
out
->
Resize
(
dims
);
dev_ctx
.
template
Alloc
(
out
,
x
-
>
dtype
());
std
::
vector
<
XPUType
*>
output_list
(
nranks
,
nullptr
);
output_list
.
at
(
rank
)
=
reinterpret_cast
<
XPUType
*>
(
out
->
data
<
T
>
());
std
::
vector
<
int64_t
>
split_list
(
nranks
,
dims
[
dims_size
-
1
]);
int
axis
=
1
;
auto
ret
=
xpu
::
split
(
dev_ctx
.
x_context
(),
reinterpret_cast
<
const
XPUType
*>
(
x
->
data
<
T
>
()),
output_list
,
{
remain_numel
,
end_size
},
split_list
,
axis
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
ret
,
"split"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
PD_REGISTER_STRUCT_KERNEL
(
c_split
,
XPU
,
ALL_LAYOUT
,
ops
::
CSplitOpXPUKernel
,
float
,
int
,
plat
::
float16
)
{}
paddle/phi/kernels/c_split_kernel.h
0 → 100644
浏览文件 @
5dc7ff04
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
CSplitKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
x
,
int
rank
,
int
nranks
,
int
ring_id
,
bool
use_calc_stream
,
bool
use_model_parallel
,
DenseTensor
*
out
);
}
// namespace phi
paddle/phi/kernels/cpu/c_split_kernel.cc
0 → 100644
浏览文件 @
5dc7ff04
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/c_split_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
CSplitKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
x
,
int
rank
,
int
nranks
,
int
ring_id
,
bool
use_calc_stream
,
bool
use_model_parallel
,
DenseTensor
*
out
)
{
PADDLE_THROW
(
phi
::
errors
::
Unavailable
(
"Do not support c_split for cpu kernel now."
));
}
}
// namespace phi
PD_REGISTER_KERNEL
(
c_split
,
CPU
,
ALL_LAYOUT
,
phi
::
CSplitKernel
,
float
,
double
,
int
,
int64_t
,
phi
::
dtype
::
float16
)
{}
paddle/phi/kernels/gpu/c_split_kernel.cu
0 → 100644
浏览文件 @
5dc7ff04
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/c_split_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
static
constexpr
int64_t
kNumCUDAThreads
=
512
;
static
constexpr
int64_t
kNumMaxinumNumBlocks
=
4096
;
static
inline
int64_t
NumBlocks
(
const
int64_t
N
)
{
return
std
::
min
((
N
+
kNumCUDAThreads
-
1
)
/
kNumCUDAThreads
,
kNumMaxinumNumBlocks
);
}
template
<
typename
T
>
__global__
void
SplitFromRank
(
const
T
*
input
,
T
*
output
,
const
int64_t
rows
,
const
int64_t
columns
,
const
int
rank
,
const
int
nranks
,
const
int64_t
limit
)
{
CUDA_KERNEL_LOOP_TYPE
(
i
,
limit
,
int64_t
)
{
int64_t
row
=
i
/
columns
;
int64_t
col
=
i
%
columns
;
int64_t
block
=
columns
/
nranks
;
int64_t
start
=
block
*
rank
;
int64_t
end
=
start
+
block
;
if
(
col
>=
start
&&
col
<
end
)
{
int64_t
idx
=
block
*
row
+
col
%
block
;
output
[
idx
]
=
input
[
i
];
}
}
}
template
<
typename
T
,
typename
Context
>
void
CSplitKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
x
,
int
rank
,
int
nranks
,
int
ring_id
,
bool
use_calc_stream
,
bool
use_model_parallel
,
DenseTensor
*
out
)
{
auto
place
=
ctx
.
GetPlace
();
PADDLE_ENFORCE_GE
(
rank
,
0
,
phi
::
errors
::
PreconditionNotMet
(
"The value of rank (%d) for c_split must be "
"greater than or equal to 0."
,
rank
));
PADDLE_ENFORCE_GE
(
nranks
,
2
,
phi
::
errors
::
PreconditionNotMet
(
"The value of nranks (%d) for c_split must be "
"greater than or equal to 2."
,
nranks
));
PADDLE_ENFORCE_LT
(
rank
,
nranks
,
phi
::
errors
::
PreconditionNotMet
(
"The value of rank (%d) for c_split must be "
"less than that of nranks (%d)."
,
rank
,
nranks
));
auto
dims
=
x
.
dims
();
auto
dims_size
=
dims
.
size
();
// final dim
int64_t
end_size
=
dims
[
dims_size
-
1
];
// remain dim
auto
remain_ddim
=
phi
::
slice_ddim
(
dims
,
0
,
dims_size
-
1
);
int64_t
remain_numel
=
phi
::
product
(
remain_ddim
);
int64_t
limit
=
x
.
numel
();
int64_t
blocks
=
NumBlocks
(
limit
);
int64_t
threads
=
kNumCUDAThreads
;
dims
[
dims_size
-
1
]
/=
nranks
;
out
->
Resize
(
dims
);
ctx
.
template
Alloc
<
T
>(
out
);
SplitFromRank
<
T
><<<
blocks
,
threads
,
0
,
ctx
.
stream
()
>>>
(
x
.
data
<
T
>
(),
out
->
data
<
T
>
(),
remain_numel
,
end_size
,
rank
,
nranks
,
limit
);
}
}
// namespace phi
#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
PD_REGISTER_KERNEL
(
c_split
,
GPU
,
ALL_LAYOUT
,
phi
::
CSplitKernel
,
float
,
double
,
int
,
int64_t
,
phi
::
dtype
::
bfloat16
,
phi
::
dtype
::
float16
)
{}
#else
PD_REGISTER_KERNEL
(
c_split
,
GPU
,
ALL_LAYOUT
,
phi
::
CSplitKernel
,
float
,
double
,
int
,
int64_t
,
phi
::
dtype
::
float16
)
{}
#endif
paddle/phi/kernels/xpu/c_split_kernel.cc
0 → 100644
浏览文件 @
5dc7ff04
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/c_split_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
CSplitKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
int
rank
,
int
nranks
,
int
ring_id
,
bool
use_calc_stream
,
bool
use_model_parallel
,
DenseTensor
*
out
)
{
using
XPUType
=
typename
XPUTypeTrait
<
T
>::
Type
;
PADDLE_ENFORCE_GE
(
rank
,
0
,
phi
::
errors
::
PreconditionNotMet
(
"The value of rank (%d) for c_split must be "
"greater than or equal to 0."
,
rank
));
PADDLE_ENFORCE_GE
(
nranks
,
2
,
phi
::
errors
::
PreconditionNotMet
(
"The value of nranks (%d) for c_split must be "
"greater than or equal to 2."
,
nranks
));
PADDLE_ENFORCE_LT
(
rank
,
nranks
,
phi
::
errors
::
PreconditionNotMet
(
"The value of rank (%d) for c_split must be "
"less than that of nranks (%d)."
,
rank
,
nranks
));
auto
dims
=
x
.
dims
();
auto
dims_size
=
dims
.
size
();
// final dim
int64_t
end_size
=
dims
[
dims_size
-
1
];
// remain dim
auto
remain_ddim
=
phi
::
slice_ddim
(
dims
,
0
,
dims_size
-
1
);
int64_t
remain_numel
=
phi
::
product
(
remain_ddim
);
dims
[
dims_size
-
1
]
/=
nranks
;
out
->
Resize
(
dims
);
dev_ctx
.
template
Alloc
(
out
,
x
.
dtype
());
std
::
vector
<
XPUType
*>
output_list
(
nranks
,
nullptr
);
output_list
.
at
(
rank
)
=
reinterpret_cast
<
XPUType
*>
(
out
->
data
<
T
>
());
std
::
vector
<
int64_t
>
split_list
(
nranks
,
dims
[
dims_size
-
1
]);
int
axis
=
1
;
auto
ret
=
xpu
::
split
(
dev_ctx
.
x_context
(),
reinterpret_cast
<
const
XPUType
*>
(
x
.
data
<
T
>
()),
output_list
,
{
remain_numel
,
end_size
},
split_list
,
axis
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
ret
,
"split"
);
}
}
// namespace phi
PD_REGISTER_KERNEL
(
c_split
,
XPU
,
ALL_LAYOUT
,
phi
::
CSplitKernel
,
float
,
int
,
phi
::
dtype
::
float16
)
{}
paddle/phi/ops/compat/c_split_sig.cc
0 → 100644
浏览文件 @
5dc7ff04
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace
phi
{
KernelSignature
CSplitOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
UNUSED
)
{
return
KernelSignature
(
"c_split"
,
{
"X"
},
{
"rank"
,
"nranks"
,
"ring_id"
,
"use_calc_stream"
,
"use_model_parallel"
},
{
"Out"
});
}
}
// namespace phi
PD_REGISTER_ARG_MAPPING_FN
(
c_split
,
phi
::
CSplitOpArgumentMapping
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录