Commit 0bc369ef (unverified)

[Fluid] Move distributed_fused_lamb_init to phi (#55993)

Authored on Aug 31, 2023 by Zero Rains; committed via GitHub on Aug 31, 2023.
Parent: e358ddac
Showing 6 changed files, with 997 additions and 25 deletions (+997, −25):
paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cc   (+2, −7)
paddle/phi/kernels/distributed_fused_lamb_init_kernel.h               (+52, −0)
paddle/phi/kernels/fusion/cpu/distributed_fused_lamb_init_kernel.cc   (+80, −0)
paddle/phi/kernels/fusion/gpu/cast_with_ptr.h                         (+11, −18)
paddle/phi/kernels/fusion/gpu/distributed_fused_lamb_init_kernel.cu   (+804, −0)
paddle/phi/ops/compat/distributed_fused_lamb_init_sig.cc              (+48, −0)
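Taken together, the file list shows the usual fluid-to-phi migration pattern: the operator definition stays in fluid, the computation moves into a phi kernel under paddle/phi/kernels, and a compat signature file maps the legacy operator's argument names onto the new kernel. A rough sketch of the three registration points involved, using a hypothetical op name "my_op" rather than the real macro invocations reproduced in the diffs below:

    // Sketch of the fluid-to-phi migration pattern this commit follows.
    // "my_op", MyOp, MyOpMaker, MyOpKernel and MyOpArgumentMapping are
    // hypothetical placeholders.

    // 1. fluid keeps only the operator definition (inference rules + maker):
    REGISTER_OP_WITHOUT_GRADIENT(my_op, ops::MyOp, ops::MyOpMaker);

    // 2. the computation is registered as a phi kernel, per backend:
    PD_REGISTER_KERNEL(my_op, CPU, ALL_LAYOUT, phi::fusion::MyOpKernel, float) {}

    // 3. a compat arg-mapping function bridges the legacy argument names to
    //    the phi kernel signature:
    PD_REGISTER_ARG_MAPPING_FN(my_op, phi::MyOpArgumentMapping);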
paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cc

@@ -12,7 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
 
 namespace paddle {
 namespace operators {

@@ -116,9 +117,3 @@ namespace ops = paddle::operators;
 REGISTER_OP_WITHOUT_GRADIENT(distributed_fused_lamb_init,
                              ops::DistributedFusedLambInitOp,
                              ops::DistributedFusedLambInitOpMaker);
-
-PD_REGISTER_STRUCT_KERNEL(distributed_fused_lamb_init,
-                          CPU,
-                          ALL_LAYOUT,
-                          ops::DistributedFusedLambInitOpKernel,
-                          float) {}
paddle/phi/kernels/distributed_fused_lamb_init_kernel.h  (new file, mode 100644)

// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void DistributedFusedLambInitOpKernel(
    const Context& dev_ctx,
    const std::vector<const DenseTensor*>& param,
    const std::vector<const DenseTensor*>& grad,
    float beta1,
    float beta2,
    const std::vector<int>& apply_weight_decay,
    int alignment,
    int rank,
    int nranks,
    DenseTensor* fp32_fused_param,
    DenseTensor* fp32_fused_grad,
    DenseTensor* fp16_fused_param,
    DenseTensor* fp16_fused_grad,
    DenseTensor* moment1,
    DenseTensor* moment2,
    DenseTensor* beta1_pow,
    DenseTensor* beta2_pow,
    DenseTensor* fused_param_offsets,
    DenseTensor* fp32_shard_fused_param_offsets,
    DenseTensor* fp16_shard_fused_param_offsets,
    DenseTensor* param_info,
    DenseTensor* param_order,
    std::vector<DenseTensor*> param_out,
    std::vector<DenseTensor*> master_param_out,
    std::vector<DenseTensor*> grad_out,
    DenseTensor* global_scale,
    DenseTensor* step);

}  // namespace phi
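The declaration above follows phi's usual kernel-parameter convention: device context first, then tensor inputs, then plain attributes, then output tensor pointers last. A minimal hypothetical declaration in the same shape (illustrative only, not part of this commit):

    namespace phi {
    // Hypothetical kernel following the same parameter-ordering convention.
    template <typename T, typename Context>
    void MyOpKernel(const Context& dev_ctx,   // device context first
                    const DenseTensor& x,     // tensor input(s)
                    float scale,              // plain attribute(s)
                    DenseTensor* out);        // output pointer(s) last
    }  // namespace phi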
paddle/phi/kernels/fusion/cpu/distributed_fused_lamb_init_kernel.cc  (new file, mode 100644)

// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/distributed_fused_lamb_init_kernel.h"

#include "paddle/phi/core/errors.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {
namespace fusion {

template <typename T, typename Context>
void DistributedFusedLambInitOpKernel(
    const Context& dev_ctx,
    const std::vector<const DenseTensor*>& param,
    const std::vector<const DenseTensor*>& grad,
    float beta1,
    float beta2,
    const std::vector<int>& apply_weight_decay,
    int alignment,
    int rank,
    int nranks,
    DenseTensor* fp32_fused_param,
    DenseTensor* fp32_fused_grad,
    DenseTensor* fp16_fused_param,
    DenseTensor* fp16_fused_grad,
    DenseTensor* moment1,
    DenseTensor* moment2,
    DenseTensor* beta1_pow,
    DenseTensor* beta2_pow,
    DenseTensor* fused_param_offsets,
    DenseTensor* fp32_shard_fused_param_offsets,
    DenseTensor* fp16_shard_fused_param_offsets,
    DenseTensor* param_info,
    DenseTensor* param_order,
    std::vector<DenseTensor*> param_out,
    std::vector<DenseTensor*> master_param_out,
    std::vector<DenseTensor*> grad_out,
    DenseTensor* global_scale,
    DenseTensor* step) {
  PADDLE_THROW(phi::errors::Unavailable(
      "Do not support expert count op for cpu kernel now."));
}

}  // namespace fusion
}  // namespace phi

PD_REGISTER_KERNEL(distributed_fused_lamb_init,
                   CPU,
                   ALL_LAYOUT,
                   phi::fusion::DistributedFusedLambInitOpKernel,
                   float) {
  kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
  kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
  kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT16);
  kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT16);
  kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
  kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32);
  kernel->OutputAt(6).SetDataType(phi::DataType::FLOAT32);
  kernel->OutputAt(7).SetDataType(phi::DataType::FLOAT32);
  kernel->OutputAt(8).SetDataType(phi::DataType::INT32);
  kernel->OutputAt(9).SetDataType(phi::DataType::INT32);
  kernel->OutputAt(10).SetDataType(phi::DataType::INT32);
  kernel->OutputAt(11).SetDataType(phi::DataType::INT32);
  kernel->OutputAt(12).SetDataType(phi::DataType::INT32);
  kernel->OutputAt(13).SetDataType(kernel_key.dtype());
  kernel->OutputAt(14).SetDataType(phi::DataType::FLOAT32);
  kernel->OutputAt(15).SetDataType(kernel_key.dtype());
  kernel->OutputAt(16).SetDataType(phi::DataType::FLOAT32);
  kernel->OutputAt(17).SetDataType(phi::DataType::INT64);
}
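A note on the registration block above: the braces after PD_REGISTER_KERNEL form a registration-time hook in which phi makes `kernel` and `kernel_key` available, so each output's dtype can be pinned independently of the kernel's template type; `kernel_key.dtype()` resolves to the dtype of the selected instantiation (here float, the only one registered). A simplified sketch of what the macro is assumed to expand the block into:

    // Assumed, simplified expansion (hypothetical function name): the macro
    // wraps the brace block in a hook run when the kernel is registered.
    void DistributedFusedLambInitArgsDef(const phi::KernelKey& kernel_key,
                                         phi::Kernel* kernel) {
      kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
      // ... the remaining OutputAt(i) calls exactly as listed above ...
    }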
paddle/fluid/operators/optimizers/cast_with_ptr.h → paddle/phi/kernels/fusion/gpu/cast_with_ptr.h

@@ -14,28 +14,24 @@
 #pragma once
 
-#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
-#include "paddle/fluid/platform/device_context.h"
-#include "paddle/fluid/platform/enforce.h"
-#include "paddle/phi/api/include/tensor.h"
+#include "paddle/phi/backends/gpu/gpu_launch_config.h"
+#include "paddle/phi/core/ddim.h"
+#include "paddle/phi/core/enforce.h"
 #include "paddle/phi/kernels/funcs/elementwise_base.h"
 
-namespace paddle {
-namespace operators {
-namespace details {
+namespace phi {
 
 template <typename InT, typename OutT>
 struct CastFunctor {
   HOSTDEVICE OutT operator()(InT x) const { return static_cast<OutT>(x); }
 };
 
 template <typename InT, typename OutT, int VecSize>
 static void VecCastKernel(const phi::GPUContext &ctx,
                           const InT *x,
                           OutT *y,
                           size_t n) {
-  auto config = platform::GetGpuLaunchConfig1D(ctx, n, VecSize);
+  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, n, VecSize);
   auto block = config.GetGridSize();
   auto thread = config.GetBlockSize();
   auto main_offset = n / (VecSize * thread) * VecSize * thread;

@@ -50,8 +46,6 @@ static void VecCastKernel(const phi::GPUContext &ctx,
       in_arr, out_arr, n, main_offset, VecSize, FunctorT());
 }
 
-}  // namespace details
-
 template <typename InT, typename OutT>
 static void LaunchCastKernel(const phi::GPUContext &ctx,
                              const InT *x,

@@ -61,20 +55,19 @@ static void LaunchCastKernel(const phi::GPUContext &ctx,
   PADDLE_ENFORCE_NE(
       static_cast<const void *>(x),
       static_cast<void *>(y),
-      platform::errors::InvalidArgument("Inplace cast is not supported yet."));
+      errors::InvalidArgument("Inplace cast is not supported yet."));
   int vec_size =
       std::min(phi::GetVectorizedSize(x), phi::GetVectorizedSize(y));
   switch (vec_size) {
     case 4:
-      return details::VecCastKernel<InT, OutT, 4>(ctx, x, y, n);
+      return VecCastKernel<InT, OutT, 4>(ctx, x, y, n);
     case 2:
-      return details::VecCastKernel<InT, OutT, 2>(ctx, x, y, n);
+      return VecCastKernel<InT, OutT, 2>(ctx, x, y, n);
     case 1:
-      return details::VecCastKernel<InT, OutT, 1>(ctx, x, y, n);
+      return VecCastKernel<InT, OutT, 1>(ctx, x, y, n);
     default:
-      PADDLE_THROW(platform::errors::InvalidArgument(
-          "The vectorized size must be 1, 2 or 4."));
+      PADDLE_THROW(
+          errors::InvalidArgument("The vectorized size must be 1, 2 or 4."));
   }
 }
 
-}  // namespace operators
-}  // namespace paddle
+}  // namespace phi
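The dispatch in LaunchCastKernel above picks the widest vector width (4, 2, or 1 elements) that both pointers support; phi::GetVectorizedSize derives that width from the pointer's address alignment. A self-contained sketch of the alignment-based selection, with a hypothetical helper that mirrors (but is not) the phi implementation:

    #include <cstdint>
    #include <cstdio>

    // Hypothetical stand-in for phi::GetVectorizedSize: return the widest
    // vector width (4, 2, or 1 elements) the pointer's alignment allows.
    template <typename T>
    int GetVecSize(const T* p) {
      auto addr = reinterpret_cast<std::uintptr_t>(p);
      if (addr % (4 * sizeof(T)) == 0) return 4;
      if (addr % (2 * sizeof(T)) == 0) return 2;
      return 1;
    }

    int main() {
      alignas(16) float buf[8];
      std::printf("%d\n", GetVecSize(buf));      // 4: 16-byte aligned
      std::printf("%d\n", GetVecSize(buf + 1));  // 1: only 4-byte aligned
      std::printf("%d\n", GetVecSize(buf + 2));  // 2: 8-byte aligned
      return 0;
    }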
paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cu → paddle/phi/kernels/fusion/gpu/distributed_fused_lamb_init_kernel.cu

(This diff is collapsed in the source view and is not reproduced here.)
paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.h → paddle/phi/ops/compat/distributed_fused_lamb_init_sig.cc

@@ -1,4 +1,4 @@
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.

@@ -12,22 +12,37 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#pragma once
-
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/operator.h"
+#include "paddle/phi/core/compat/op_utils.h"
 
-namespace paddle {
-namespace operators {
+namespace phi {
 
-template <typename T, typename DevCtx>
-class DistributedFusedLambInitOpKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    PADDLE_THROW(platform::errors::Unimplemented(
-        "The distributed_fused_lamb_init operator does not support CPU yet."));
-  }
-};
+KernelSignature DistributedFusedLambInitOpArgumentMapping(
+    const ArgumentMappingContext& ctx UNUSED) {
+  return KernelSignature(
+      "distributed_fused_lamb_init",
+      {"Param", "Grad"},
+      {"beta1", "beta2", "apply_weight_decay", "alignment", "rank", "nranks"},
+      {"FP32FusedParam",
+       "FP32FusedGrad",
+       "FP16FusedParam",
+       "FP16FusedGrad",
+       "Moment1",
+       "Moment2",
+       "Beta1Pow",
+       "Beta2Pow",
+       "FusedParamOffsets",
+       "FP32ShardFusedParamOffsets",
+       "FP16ShardFusedParamOffsets",
+       "ParamInfo",
+       "ParamOrder",
+       "ParamOut",
+       "MasterParamOut",
+       "GradOut",
+       "GlobalScale",
+       "Step"});
+}
 
-}  // namespace operators
-}  // namespace paddle
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(distributed_fused_lamb_init,
+                           phi::DistributedFusedLambInitOpArgumentMapping);
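The three brace-enclosed lists passed to KernelSignature above are, in order, the legacy op's input names, attribute names, and output names, and they must line up with the parameter order of DistributedFusedLambInitOpKernel declared earlier in this diff. The same pattern for a hypothetical, simpler op:

    namespace phi {
    // Hypothetical mapping: one input "X", one attribute "scale", one
    // output "Out", matching the phi kernel's parameter order.
    KernelSignature MyOpArgumentMapping(const ArgumentMappingContext& ctx) {
      return KernelSignature("my_op", {"X"}, {"scale"}, {"Out"});
    }
    }  // namespace phi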