Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
230c6ce1
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
230c6ce1
编写于
8月 03, 2023
作者:
Y
yangguohao
提交者:
GitHub
8月 03, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
FLUID: move limit_by_capacity to PHI (#55948)
上级
81ccd99e
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
171 addition
and
8 deletion
+171
-8
paddle/fluid/operators/limit_by_capacity_op.cc
paddle/fluid/operators/limit_by_capacity_op.cc
+7
-8
paddle/phi/kernels/cpu/limit_by_capacity_kernel.cc
paddle/phi/kernels/cpu/limit_by_capacity_kernel.cc
+41
-0
paddle/phi/kernels/gpu/limit_by_capacity_kernel.cu
paddle/phi/kernels/gpu/limit_by_capacity_kernel.cu
+67
-0
paddle/phi/kernels/limit_by_capacity_kernel.h
paddle/phi/kernels/limit_by_capacity_kernel.h
+28
-0
paddle/phi/ops/compat/limit_by_capacity_sig.cc
paddle/phi/ops/compat/limit_by_capacity_sig.cc
+28
-0
未找到文件。
paddle/fluid/operators/limit_by_capacity_op.cc
浏览文件 @
230c6ce1
...
@@ -12,7 +12,13 @@
...
@@ -12,7 +12,13 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/operators/limit_by_capacity_op.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_GLOO)
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -80,10 +86,3 @@ namespace plat = paddle::platform;
...
@@ -80,10 +86,3 @@ namespace plat = paddle::platform;
REGISTER_OP_WITHOUT_GRADIENT
(
limit_by_capacity
,
REGISTER_OP_WITHOUT_GRADIENT
(
limit_by_capacity
,
ops
::
LimitByCapacityOp
,
ops
::
LimitByCapacityOp
,
ops
::
LimitByCapacityOpMaker
);
ops
::
LimitByCapacityOpMaker
);
PD_REGISTER_STRUCT_KERNEL
(
limit_by_capacity
,
CPU
,
ALL_LAYOUT
,
ops
::
LimitByCapacityOpCPUKernel
,
int
,
int64_t
)
{}
paddle/phi/kernels/cpu/limit_by_capacity_kernel.cc
0 → 100644
浏览文件 @
230c6ce1
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/limit_by_capacity_kernel.h"
#include "paddle/phi/core/errors.h"
#include "paddle/phi/core/kernel_registry.h"
#if defined(PADDLE_WITH_GLOO)
#include "paddle/phi/core/distributed/gloo_comm_context.h"
#endif
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
LimitByCapacityKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
expert_count
,
const
DenseTensor
&
capacity
,
int
n_worker
,
DenseTensor
*
Out
)
{
PADDLE_THROW
(
phi
::
errors
::
Unimplemented
(
"limit_by_capacity is not supported on CPU."
));
}
}
// namespace phi
PD_REGISTER_KERNEL
(
limit_by_capacity
,
CPU
,
ALL_LAYOUT
,
phi
::
LimitByCapacityKernel
,
int
,
int64_t
)
{}
paddle/
fluid/operators/limit_by_capacity_op
.cu
→
paddle/
phi/kernels/gpu/limit_by_capacity_kernel
.cu
浏览文件 @
230c6ce1
// Copyright (c) 202
2
PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 202
3
PaddlePaddle Authors. All Rights Reserved.
//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// you may not use this file except in compliance with the License.
...
@@ -11,22 +11,14 @@
...
@@ -11,22 +11,14 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
//
// The file has been adapted from the two files:
// https://github.com/laekov/fastmoe/blob/master/cuda/balancing.cu
// https://github.com/laekov/fastmoe/blob/master/cuda/balancing.cuh
// Git commit hash: 295a615aacce7e54a37e7935274ba15e901c78e4
// We retain the following license from the original files:
// Copyright 2021, Jiaao He. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License").
#include "paddle/fluid/operators/limit_by_capacity_op.h"
#include "paddle/phi/kernels/limit_by_capacity_kernel.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
namespace
paddle
{
namespace
phi
{
namespace
operators
{
template
<
typename
T
>
template
<
typename
T
>
__global__
void
limit_by_capacity_impl
(
__global__
void
limit_by_capacity_impl
(
...
@@ -47,39 +39,29 @@ __global__ void limit_by_capacity_impl(
...
@@ -47,39 +39,29 @@ __global__ void limit_by_capacity_impl(
}
}
}
}
template
<
typename
T
,
typename
DeviceContext
>
template
<
typename
T
,
typename
Context
>
class
LimitByCapacityOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
void
LimitByCapacityKernel
(
const
Context
&
dev_ctx
,
public:
const
DenseTensor
&
expert_count
,
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
const
DenseTensor
&
capacity
,
auto
expert_count
=
context
.
Input
<
phi
::
DenseTensor
>
(
"expert_count"
);
int
n_worker
,
auto
capacity
=
context
.
Input
<
phi
::
DenseTensor
>
(
"capacity"
);
DenseTensor
*
Out
)
{
auto
n_worker
=
context
.
Attr
<
int
>
(
"n_worker"
);
auto
expert_count_ptr
=
&
expert_count
;
auto
out
=
context
.
Output
<
phi
::
DenseTensor
>
(
"Out"
);
auto
n_expert
=
expert_count_ptr
->
numel
()
/
n_worker
;
auto
n_expert
=
expert_count
->
numel
()
/
n_worker
;
const
auto
place
=
context
.
GetPlace
();
const
auto
&
dev_ctx
=
context
.
template
device_context
<
phi
::
GPUContext
>();
dim3
grid_dim
(
256
);
dim3
grid_dim
(
256
);
dim3
block_dim
(
1024
);
dim3
block_dim
(
1024
);
auto
out_data
=
out
->
mutable_data
<
T
>
(
place
);
auto
out_data
=
dev_ctx
.
template
Alloc
<
T
>(
Out
);
const
T
*
ec_data
=
expert_count
->
data
<
T
>
();
const
T
*
ec_data
=
expert_count_ptr
->
data
<
T
>
();
phi
::
DenseTensor
capacity_copy
;
phi
::
DenseTensor
capacity_copy
;
framework
::
TensorCopy
(
*
capacity
,
place
,
dev_ctx
,
&
capacity_copy
);
phi
::
Copy
(
dev_ctx
,
capacity
,
dev_ctx
.
GetPlace
(),
false
,
&
capacity_copy
);
T
*
cap_data
=
capacity_copy
.
mutable_data
<
T
>
(
place
);
T
*
cap_data
=
dev_ctx
.
template
Alloc
<
T
>(
&
capacity_copy
);
limit_by_capacity_impl
<
T
><<<
grid_dim
,
block_dim
,
0
,
dev_ctx
.
stream
()
>>>
(
limit_by_capacity_impl
<
T
><<<
grid_dim
,
block_dim
,
0
,
dev_ctx
.
stream
()
>>>
(
ec_data
,
cap_data
,
out_data
,
n_expert
,
n_worker
);
ec_data
,
cap_data
,
out_data
,
n_expert
,
n_worker
);
}
}
};
}
// namespace operators
}
// namespace phi
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
PD_REGISTER_KERNEL
(
PD_REGISTER_STRUCT_KERNEL
(
limit_by_capacity
,
limit_by_capacity
,
GPU
,
ALL_LAYOUT
,
phi
::
LimitByCapacityKernel
,
int64_t
)
{}
GPU
,
ALL_LAYOUT
,
ops
::
LimitByCapacityOpCUDAKernel
,
int64_t
)
{}
paddle/
fluid/operators/limit_by_capacity_op
.h
→
paddle/
phi/kernels/limit_by_capacity_kernel
.h
浏览文件 @
230c6ce1
// Copyright (c) 202
2
PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 202
3
PaddlePaddle Authors. All Rights Reserved.
//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// you may not use this file except in compliance with the License.
...
@@ -13,25 +13,16 @@
...
@@ -13,25 +13,16 @@
// limitations under the License.
// limitations under the License.
#pragma once
#pragma once
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_GLOO)
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif
namespace
paddle
{
namespace
phi
{
namespace
operators
{
template
<
typename
T
,
typename
DeviceContext
>
template
<
typename
T
,
typename
Context
>
class
LimitByCapacityOpCPUKernel
:
public
framework
::
OpKernel
<
T
>
{
void
LimitByCapacityKernel
(
const
Context
&
dev_ctx
,
public:
const
DenseTensor
&
expert_count
,
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
DenseTensor
&
capacity
,
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
int
n_worker
,
"Do not support limit by capacity op for cpu kernel now."
));
DenseTensor
*
Out
);
}
};
}
// namespace operators
}
// namespace phi
}
// namespace paddle
paddle/phi/ops/compat/limit_by_capacity_sig.cc
0 → 100644
浏览文件 @
230c6ce1
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace
phi
{
KernelSignature
LimitByCapacityOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
UNUSED
)
{
return
KernelSignature
(
"limit_by_capacity"
,
{
"expert_count"
,
"capacity"
},
{
"n_worker"
},
{
"Out"
});
}
}
// namespace phi
PD_REGISTER_ARG_MAPPING_FN
(
limit_by_capacity
,
phi
::
LimitByCapacityOpArgumentMapping
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录