Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
0bc369ef
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0bc369ef
编写于
8月 31, 2023
作者:
Z
Zero Rains
提交者:
GitHub
8月 31, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Fluid] Move distributed_fused_lamb_init to phi (#55993)
上级
e358ddac
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
997 addition
and
25 deletion
+997
-25
paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cc
...id/operators/optimizers/distributed_fused_lamb_init_op.cc
+2
-7
paddle/phi/kernels/distributed_fused_lamb_init_kernel.h
paddle/phi/kernels/distributed_fused_lamb_init_kernel.h
+52
-0
paddle/phi/kernels/fusion/cpu/distributed_fused_lamb_init_kernel.cc
.../kernels/fusion/cpu/distributed_fused_lamb_init_kernel.cc
+80
-0
paddle/phi/kernels/fusion/gpu/cast_with_ptr.h
paddle/phi/kernels/fusion/gpu/cast_with_ptr.h
+11
-18
paddle/phi/kernels/fusion/gpu/distributed_fused_lamb_init_kernel.cu
.../kernels/fusion/gpu/distributed_fused_lamb_init_kernel.cu
+804
-0
paddle/phi/ops/compat/distributed_fused_lamb_init_sig.cc
paddle/phi/ops/compat/distributed_fused_lamb_init_sig.cc
+48
-0
未找到文件。
paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cc
浏览文件 @
0bc369ef
...
@@ -12,7 +12,8 @@
...
@@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -116,9 +117,3 @@ namespace ops = paddle::operators;
...
@@ -116,9 +117,3 @@ namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT
(
distributed_fused_lamb_init
,
REGISTER_OP_WITHOUT_GRADIENT
(
distributed_fused_lamb_init
,
ops
::
DistributedFusedLambInitOp
,
ops
::
DistributedFusedLambInitOp
,
ops
::
DistributedFusedLambInitOpMaker
);
ops
::
DistributedFusedLambInitOpMaker
);
PD_REGISTER_STRUCT_KERNEL
(
distributed_fused_lamb_init
,
CPU
,
ALL_LAYOUT
,
ops
::
DistributedFusedLambInitOpKernel
,
float
)
{}
paddle/phi/kernels/distributed_fused_lamb_init_kernel.h
0 → 100644
浏览文件 @
0bc369ef
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
DistributedFusedLambInitOpKernel
(
const
Context
&
dev_ctx
,
const
std
::
vector
<
const
DenseTensor
*>&
param
,
const
std
::
vector
<
const
DenseTensor
*>&
grad
,
float
beta1
,
float
beta2
,
const
std
::
vector
<
int
>&
apply_weight_decay
,
int
alignment
,
int
rank
,
int
nranks
,
DenseTensor
*
fp32_fused_param
,
DenseTensor
*
fp32_fused_grad
,
DenseTensor
*
fp16_fused_param
,
DenseTensor
*
fp16_fused_grad
,
DenseTensor
*
moment1
,
DenseTensor
*
moment2
,
DenseTensor
*
beta1_pow
,
DenseTensor
*
beta2_pow
,
DenseTensor
*
fused_param_offsets
,
DenseTensor
*
fp32_shard_fused_param_offsets
,
DenseTensor
*
fp16_shard_fused_param_offsets
,
DenseTensor
*
param_info
,
DenseTensor
*
param_order
,
std
::
vector
<
DenseTensor
*>
param_out
,
std
::
vector
<
DenseTensor
*>
master_param_out
,
std
::
vector
<
DenseTensor
*>
grad_out
,
DenseTensor
*
global_scale
,
DenseTensor
*
step
);
}
// namespace phi
paddle/phi/kernels/fusion/cpu/distributed_fused_lamb_init_kernel.cc
0 → 100644
浏览文件 @
0bc369ef
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/distributed_fused_lamb_init_kernel.h"
#include "paddle/phi/core/errors.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
namespace
fusion
{
template
<
typename
T
,
typename
Context
>
void
DistributedFusedLambInitOpKernel
(
const
Context
&
dev_ctx
,
const
std
::
vector
<
const
DenseTensor
*>&
param
,
const
std
::
vector
<
const
DenseTensor
*>&
grad
,
float
beta1
,
float
beta2
,
const
std
::
vector
<
int
>&
apply_weight_decay
,
int
alignment
,
int
rank
,
int
nranks
,
DenseTensor
*
fp32_fused_param
,
DenseTensor
*
fp32_fused_grad
,
DenseTensor
*
fp16_fused_param
,
DenseTensor
*
fp16_fused_grad
,
DenseTensor
*
moment1
,
DenseTensor
*
moment2
,
DenseTensor
*
beta1_pow
,
DenseTensor
*
beta2_pow
,
DenseTensor
*
fused_param_offsets
,
DenseTensor
*
fp32_shard_fused_param_offsets
,
DenseTensor
*
fp16_shard_fused_param_offsets
,
DenseTensor
*
param_info
,
DenseTensor
*
param_order
,
std
::
vector
<
DenseTensor
*>
param_out
,
std
::
vector
<
DenseTensor
*>
master_param_out
,
std
::
vector
<
DenseTensor
*>
grad_out
,
DenseTensor
*
global_scale
,
DenseTensor
*
step
)
{
PADDLE_THROW
(
phi
::
errors
::
Unavailable
(
"Do not support expert count op for cpu kernel now."
));
}
}
// namespace fusion
}
// namespace phi
PD_REGISTER_KERNEL
(
distributed_fused_lamb_init
,
CPU
,
ALL_LAYOUT
,
phi
::
fusion
::
DistributedFusedLambInitOpKernel
,
float
)
{
kernel
->
OutputAt
(
0
).
SetDataType
(
phi
::
DataType
::
FLOAT32
);
kernel
->
OutputAt
(
1
).
SetDataType
(
phi
::
DataType
::
FLOAT32
);
kernel
->
OutputAt
(
2
).
SetDataType
(
phi
::
DataType
::
FLOAT16
);
kernel
->
OutputAt
(
3
).
SetDataType
(
phi
::
DataType
::
FLOAT16
);
kernel
->
OutputAt
(
4
).
SetDataType
(
phi
::
DataType
::
FLOAT32
);
kernel
->
OutputAt
(
5
).
SetDataType
(
phi
::
DataType
::
FLOAT32
);
kernel
->
OutputAt
(
6
).
SetDataType
(
phi
::
DataType
::
FLOAT32
);
kernel
->
OutputAt
(
7
).
SetDataType
(
phi
::
DataType
::
FLOAT32
);
kernel
->
OutputAt
(
8
).
SetDataType
(
phi
::
DataType
::
INT32
);
kernel
->
OutputAt
(
9
).
SetDataType
(
phi
::
DataType
::
INT32
);
kernel
->
OutputAt
(
10
).
SetDataType
(
phi
::
DataType
::
INT32
);
kernel
->
OutputAt
(
11
).
SetDataType
(
phi
::
DataType
::
INT32
);
kernel
->
OutputAt
(
12
).
SetDataType
(
phi
::
DataType
::
INT32
);
kernel
->
OutputAt
(
13
).
SetDataType
(
kernel_key
.
dtype
());
kernel
->
OutputAt
(
14
).
SetDataType
(
phi
::
DataType
::
FLOAT32
);
kernel
->
OutputAt
(
15
).
SetDataType
(
kernel_key
.
dtype
());
kernel
->
OutputAt
(
16
).
SetDataType
(
phi
::
DataType
::
FLOAT32
);
kernel
->
OutputAt
(
17
).
SetDataType
(
phi
::
DataType
::
INT64
);
}
paddle/
fluid/operators/optimizers
/cast_with_ptr.h
→
paddle/
phi/kernels/fusion/gpu
/cast_with_ptr.h
浏览文件 @
0bc369ef
...
@@ -14,28 +14,24 @@
...
@@ -14,28 +14,24 @@
#pragma once
#pragma once
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
namespace
paddle
{
namespace
phi
{
namespace
operators
{
namespace
details
{
template
<
typename
InT
,
typename
OutT
>
template
<
typename
InT
,
typename
OutT
>
struct
CastFunctor
{
struct
CastFunctor
{
HOSTDEVICE
OutT
operator
()(
InT
x
)
const
{
return
static_cast
<
OutT
>
(
x
);
}
HOSTDEVICE
OutT
operator
()(
InT
x
)
const
{
return
static_cast
<
OutT
>
(
x
);
}
};
};
template
<
typename
InT
,
typename
OutT
,
int
VecSize
>
template
<
typename
InT
,
typename
OutT
,
int
VecSize
>
static
void
VecCastKernel
(
const
phi
::
GPUContext
&
ctx
,
static
void
VecCastKernel
(
const
phi
::
GPUContext
&
ctx
,
const
InT
*
x
,
const
InT
*
x
,
OutT
*
y
,
OutT
*
y
,
size_t
n
)
{
size_t
n
)
{
auto
config
=
p
latform
::
GetGpuLaunchConfig1D
(
ctx
,
n
,
VecSize
);
auto
config
=
p
hi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
ctx
,
n
,
VecSize
);
auto
block
=
config
.
GetGridSize
();
auto
block
=
config
.
GetGridSize
();
auto
thread
=
config
.
GetBlockSize
();
auto
thread
=
config
.
GetBlockSize
();
auto
main_offset
=
n
/
(
VecSize
*
thread
)
*
VecSize
*
thread
;
auto
main_offset
=
n
/
(
VecSize
*
thread
)
*
VecSize
*
thread
;
...
@@ -50,8 +46,6 @@ static void VecCastKernel(const phi::GPUContext &ctx,
...
@@ -50,8 +46,6 @@ static void VecCastKernel(const phi::GPUContext &ctx,
in_arr
,
out_arr
,
n
,
main_offset
,
VecSize
,
FunctorT
());
in_arr
,
out_arr
,
n
,
main_offset
,
VecSize
,
FunctorT
());
}
}
}
// namespace details
template
<
typename
InT
,
typename
OutT
>
template
<
typename
InT
,
typename
OutT
>
static
void
LaunchCastKernel
(
const
phi
::
GPUContext
&
ctx
,
static
void
LaunchCastKernel
(
const
phi
::
GPUContext
&
ctx
,
const
InT
*
x
,
const
InT
*
x
,
...
@@ -61,20 +55,19 @@ static void LaunchCastKernel(const phi::GPUContext &ctx,
...
@@ -61,20 +55,19 @@ static void LaunchCastKernel(const phi::GPUContext &ctx,
PADDLE_ENFORCE_NE
(
PADDLE_ENFORCE_NE
(
static_cast
<
const
void
*>
(
x
),
static_cast
<
const
void
*>
(
x
),
static_cast
<
void
*>
(
y
),
static_cast
<
void
*>
(
y
),
platform
::
errors
::
InvalidArgument
(
"Inplace cast is not supported yet."
));
errors
::
InvalidArgument
(
"Inplace cast is not supported yet."
));
int
vec_size
=
std
::
min
(
phi
::
GetVectorizedSize
(
x
),
phi
::
GetVectorizedSize
(
y
));
int
vec_size
=
std
::
min
(
phi
::
GetVectorizedSize
(
x
),
phi
::
GetVectorizedSize
(
y
));
switch
(
vec_size
)
{
switch
(
vec_size
)
{
case
4
:
case
4
:
return
details
::
VecCastKernel
<
InT
,
OutT
,
4
>
(
ctx
,
x
,
y
,
n
);
return
VecCastKernel
<
InT
,
OutT
,
4
>
(
ctx
,
x
,
y
,
n
);
case
2
:
case
2
:
return
details
::
VecCastKernel
<
InT
,
OutT
,
2
>
(
ctx
,
x
,
y
,
n
);
return
VecCastKernel
<
InT
,
OutT
,
2
>
(
ctx
,
x
,
y
,
n
);
case
1
:
case
1
:
return
details
::
VecCastKernel
<
InT
,
OutT
,
1
>
(
ctx
,
x
,
y
,
n
);
return
VecCastKernel
<
InT
,
OutT
,
1
>
(
ctx
,
x
,
y
,
n
);
default:
default:
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
"The vectorized size must be 1, 2 or 4."
));
errors
::
InvalidArgument
(
"The vectorized size must be 1, 2 or 4."
));
}
}
}
}
}
// namespace operators
}
// namespace phi
}
// namespace paddle
paddle/
fluid/operators/optimizers/distributed_fused_lamb_init_op
.cu
→
paddle/
phi/kernels/fusion/gpu/distributed_fused_lamb_init_kernel
.cu
浏览文件 @
0bc369ef
// Copyright (c) 202
1
PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 202
3
PaddlePaddle Authors. All Rights Reserved.
//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// you may not use this file except in compliance with the License.
...
@@ -12,24 +12,24 @@
...
@@ -12,24 +12,24 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.h"
#include "paddle/phi/kernels/distributed_fused_lamb_init_kernel.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/optimizers/cast_with_ptr.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/algorithm.h"
#include "paddle/phi/kernels/funcs/algorithm.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/tensor_to_string.h"
#include "paddle/phi/kernels/funcs/tensor_to_string.h"
#include "paddle/phi/kernels/fusion/gpu/cast_with_ptr.h"
namespace
p
addle
{
namespace
p
hi
{
namespace
operators
{
namespace
fusion
{
using
phi
::
funcs
::
FlattenToString
;
using
phi
::
funcs
::
FlattenToString
;
using
phi
::
funcs
::
ToVector
;
using
phi
::
funcs
::
ToVector
;
struct
ParamGradInfo
{
struct
ParamGradInfo
{
phi
::
DenseTensor
*
param_t
{
nullptr
};
DenseTensor
*
param_t
{
nullptr
};
phi
::
DenseTensor
*
grad_t
{
nullptr
};
DenseTensor
*
grad_t
{
nullptr
};
size_t
idx
{
0
};
size_t
idx
{
0
};
size_t
numel
{
0
};
size_t
numel
{
0
};
size_t
numel_with_padding
{
0
};
size_t
numel_with_padding
{
0
};
...
@@ -82,20 +82,17 @@ static void GetParamGradShardInfo(const std::vector<ParamGradInfo> &infos,
...
@@ -82,20 +82,17 @@ static void GetParamGradShardInfo(const std::vector<ParamGradInfo> &infos,
size_t
*
start_numel_offset
,
size_t
*
start_numel_offset
,
size_t
*
end_numel_offset
)
{
size_t
*
end_numel_offset
)
{
VLOG
(
10
)
<<
"NumelOffset: "
VLOG
(
10
)
<<
"NumelOffset: "
<<
string
::
join_strings
(
infos
,
","
,
[](
const
ParamGradInfo
&
info
)
{
<<
paddle
::
string
::
join_strings
(
infos
,
","
,
[](
const
ParamGradInfo
&
info
)
{
return
info
.
numel_offset
;
return
info
.
numel_offset
;
});
});
VLOG
(
10
)
<<
"start_size = "
<<
start_size
<<
" , end_size = "
<<
end_size
;
VLOG
(
10
)
<<
"start_size = "
<<
start_size
<<
" , end_size = "
<<
end_size
;
if
(
infos
.
empty
())
{
if
(
infos
.
empty
())
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
start_size
,
start_size
,
0
,
errors
::
InvalidArgument
(
"start_size should be 0."
));
0
,
platform
::
errors
::
InvalidArgument
(
"start_size should be 0."
));
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
end_size
,
end_size
,
0
,
errors
::
InvalidArgument
(
"end_size should be 0."
));
0
,
platform
::
errors
::
InvalidArgument
(
"end_size should be 0."
));
*
start_idx
=
0
;
*
start_idx
=
0
;
*
end_idx
=
0
;
*
end_idx
=
0
;
*
start_numel_offset
=
0
;
*
start_numel_offset
=
0
;
...
@@ -103,10 +100,10 @@ static void GetParamGradShardInfo(const std::vector<ParamGradInfo> &infos,
...
@@ -103,10 +100,10 @@ static void GetParamGradShardInfo(const std::vector<ParamGradInfo> &infos,
return
;
return
;
}
}
PADDLE_ENFORCE_LT
(
start_size
,
PADDLE_ENFORCE_LT
(
start_size
,
end_size
,
end_size
,
platform
::
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"start_size should be less than end_size."
));
"start_size should be less than end_size."
));
size_t
n
=
infos
.
size
();
size_t
n
=
infos
.
size
();
ParamGradInfoNumelOffsetCompFunctor
comp
;
ParamGradInfoNumelOffsetCompFunctor
comp
;
auto
i
=
static_cast
<
size_t
>
(
auto
i
=
static_cast
<
size_t
>
(
...
@@ -116,7 +113,7 @@ static void GetParamGradShardInfo(const std::vector<ParamGradInfo> &infos,
...
@@ -116,7 +113,7 @@ static void GetParamGradShardInfo(const std::vector<ParamGradInfo> &infos,
PADDLE_ENFORCE_GT
(
PADDLE_ENFORCE_GT
(
i
,
i
,
0
,
0
,
platform
::
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"Cannot find suitable sharding which is between [%d, %d)"
,
"Cannot find suitable sharding which is between [%d, %d)"
,
start_size
,
start_size
,
end_size
));
end_size
));
...
@@ -125,7 +122,7 @@ static void GetParamGradShardInfo(const std::vector<ParamGradInfo> &infos,
...
@@ -125,7 +122,7 @@ static void GetParamGradShardInfo(const std::vector<ParamGradInfo> &infos,
PADDLE_ENFORCE_LT
(
PADDLE_ENFORCE_LT
(
i
,
i
,
n
,
n
,
platform
::
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"Cannot find suitable sharding which is between [%d, %d)"
,
"Cannot find suitable sharding which is between [%d, %d)"
,
start_size
,
start_size
,
end_size
));
end_size
));
...
@@ -136,10 +133,10 @@ static void GetParamGradShardInfo(const std::vector<ParamGradInfo> &infos,
...
@@ -136,10 +133,10 @@ static void GetParamGradShardInfo(const std::vector<ParamGradInfo> &infos,
infos
.
begin
());
infos
.
begin
());
*
end_idx
=
j
-
1
;
*
end_idx
=
j
-
1
;
*
end_numel_offset
=
end_size
-
infos
[
j
-
1
].
numel_offset
;
*
end_numel_offset
=
end_size
-
infos
[
j
-
1
].
numel_offset
;
PADDLE_ENFORCE_GT
(
*
end_numel_offset
,
PADDLE_ENFORCE_GT
(
*
end_numel_offset
,
0
,
0
,
platform
::
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"Internal error when sharding, this may be a bug "
"Internal error when sharding, this may be a bug "
"caused by empty parameter."
));
"caused by empty parameter."
));
VLOG
(
10
)
<<
"Sharding [start_size="
<<
start_size
<<
", end_size="
<<
end_size
VLOG
(
10
)
<<
"Sharding [start_size="
<<
start_size
<<
", end_size="
<<
end_size
<<
"): "
<<
(
*
start_idx
)
<<
":"
<<
(
*
start_numel_offset
)
<<
" -> "
<<
"): "
<<
(
*
start_idx
)
<<
":"
<<
(
*
start_numel_offset
)
<<
" -> "
...
@@ -154,7 +151,7 @@ static size_t FillAlignmentPaddingInfo(std::vector<ParamGradInfo> *infos,
...
@@ -154,7 +151,7 @@ static size_t FillAlignmentPaddingInfo(std::vector<ParamGradInfo> *infos,
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
alignment
%
sizeof_dtype
,
alignment
%
sizeof_dtype
,
0
,
0
,
platform
::
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"The attr(alignment) should be exactly divided by sizeof(T) %d."
,
"The attr(alignment) should be exactly divided by sizeof(T) %d."
,
sizeof_dtype
));
sizeof_dtype
));
alignment
/=
sizeof_dtype
;
alignment
/=
sizeof_dtype
;
...
@@ -182,41 +179,41 @@ static size_t FillAlignmentPaddingInfo(std::vector<ParamGradInfo> *infos,
...
@@ -182,41 +179,41 @@ static size_t FillAlignmentPaddingInfo(std::vector<ParamGradInfo> *infos,
template
<
typename
T
>
template
<
typename
T
>
static
T
*
TensorFillConstant
(
const
phi
::
GPUContext
&
dev_ctx
,
static
T
*
TensorFillConstant
(
const
phi
::
GPUContext
&
dev_ctx
,
phi
::
DenseTensor
*
tensor
,
DenseTensor
*
tensor
,
const
framework
::
DDim
&
dims
,
const
DDim
&
dims
,
T
value
)
{
T
value
)
{
tensor
->
Resize
(
dims
);
tensor
->
Resize
(
dims
);
auto
*
ptr
=
tensor
->
mutable_data
<
T
>
(
dev_ctx
.
GetPlace
()
);
auto
*
ptr
=
dev_ctx
.
template
Alloc
<
T
>(
tensor
);
phi
::
funcs
::
SetConstant
<
phi
::
GPUContext
,
T
>
set_constant
;
phi
::
funcs
::
SetConstant
<
phi
::
GPUContext
,
T
>
set_constant
;
set_constant
(
dev_ctx
,
tensor
,
value
);
set_constant
(
dev_ctx
,
tensor
,
value
);
return
ptr
;
return
ptr
;
}
}
static
phi
::
DenseTensor
CastDataForInitedTensor
(
const
phi
::
GPUContext
&
dev_ctx
,
static
DenseTensor
CastDataForInitedTensor
(
const
phi
::
GPUContext
&
dev_ctx
,
phi
::
DenseTensor
*
origin
,
DenseTensor
*
origin
,
phi
::
DenseTensor
*
fused_out
,
DenseTensor
*
fused_out
,
size_t
numel_offset
)
{
size_t
numel_offset
)
{
PADDLE_ENFORCE_EQ
(
origin
->
IsInitialized
(),
PADDLE_ENFORCE_EQ
(
origin
->
IsInitialized
(),
true
,
true
,
platform
::
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"The tensor to be cast should be initialized."
));
"The tensor to be cast should be initialized."
));
PADDLE_ENFORCE_EQ
(
fused_out
->
dtype
(),
PADDLE_ENFORCE_EQ
(
fused_out
->
dtype
(),
phi
::
DataType
::
FLOAT32
,
phi
::
DataType
::
FLOAT32
,
platform
::
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"The dst tensor to be cast should be FP32 tensor."
));
"The dst tensor to be cast should be FP32 tensor."
));
PADDLE_ENFORCE_EQ
(
origin
->
dtype
(),
PADDLE_ENFORCE_EQ
(
origin
->
dtype
(),
phi
::
DataType
::
FLOAT16
,
phi
::
DataType
::
FLOAT16
,
platform
::
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"The src tensor to be cast should be FP16 tensor."
));
"The src tensor to be cast should be FP16 tensor."
));
auto
*
dst
=
fused_out
->
data
<
float
>
()
+
numel_offset
;
auto
*
dst
=
fused_out
->
data
<
float
>
()
+
numel_offset
;
auto
*
src
=
origin
->
data
<
platform
::
float16
>
();
auto
*
src
=
origin
->
data
<
dtype
::
float16
>
();
auto
numel
=
origin
->
numel
();
auto
numel
=
origin
->
numel
();
LaunchCastKernel
(
dev_ctx
,
src
,
dst
,
numel
);
LaunchCastKernel
(
dev_ctx
,
src
,
dst
,
numel
);
VLOG
(
10
)
<<
"Cast from FP32 -> FP16, range: ["
<<
numel_offset
<<
", "
VLOG
(
10
)
<<
"Cast from FP32 -> FP16, range: ["
<<
numel_offset
<<
", "
<<
numel_offset
+
numel
<<
")"
<<
numel_offset
+
numel
<<
")"
<<
" , total: [0, "
<<
fused_out
->
numel
()
<<
")"
;
<<
" , total: [0, "
<<
fused_out
->
numel
()
<<
")"
;
framework
::
DDim
fused_out_dim
=
fused_out
->
dims
();
DDim
fused_out_dim
=
fused_out
->
dims
();
auto
fused_out_numel
=
fused_out
->
numel
();
auto
fused_out_numel
=
fused_out
->
numel
();
fused_out
->
Resize
({
fused_out_numel
});
fused_out
->
Resize
({
fused_out_numel
});
auto
sliced_tensor
=
fused_out
->
Slice
(
numel_offset
,
numel
+
numel_offset
);
auto
sliced_tensor
=
fused_out
->
Slice
(
numel_offset
,
numel
+
numel_offset
);
...
@@ -224,45 +221,40 @@ static phi::DenseTensor CastDataForInitedTensor(const phi::GPUContext &dev_ctx,
...
@@ -224,45 +221,40 @@ static phi::DenseTensor CastDataForInitedTensor(const phi::GPUContext &dev_ctx,
return
sliced_tensor
;
return
sliced_tensor
;
}
}
static
phi
::
DenseTensor
CopyAndShareBufferForInitedTensor
(
static
DenseTensor
CopyAndShareBufferForInitedTensor
(
phi
::
DenseTensor
*
origin
,
const
phi
::
GPUContext
&
dev_ctx
,
phi
::
DenseTensor
*
fused_out
,
DenseTensor
*
origin
,
size_t
numel_offse
t
,
DenseTensor
*
fused_ou
t
,
gpuStream_t
stream
)
{
size_t
numel_offset
)
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
origin
->
IsInitialized
(),
origin
->
IsInitialized
(),
true
,
true
,
platform
::
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"The tensor to be copied and shared data should be initialized."
));
"The tensor to be copied and shared data should be initialized."
));
auto
dtype
=
fused_out
->
type
();
auto
dtype
=
fused_out
->
type
();
PADDLE_ENFORCE_EQ
(
origin
->
type
(),
PADDLE_ENFORCE_EQ
(
origin
->
type
(),
dtype
,
dtype
,
platform
::
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"The tensor to be copied and shared data should be "
"The tensor to be copied and shared data should be "
"have the same data type."
));
"have the same data type."
));
auto
place
=
fused_out
->
place
();
auto
place
=
fused_out
->
place
();
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
origin
->
place
(),
origin
->
place
(),
place
,
place
,
platform
::
errors
::
InvalidArgument
(
"The tensor to be copied and shared "
errors
::
InvalidArgument
(
"The tensor to be copied and shared "
"data should be have the same place."
));
"data should be have the same place."
));
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
platform
::
is_gpu_place
(
place
)
,
dev_ctx
.
GetPlace
().
GetType
()
==
phi
::
AllocationType
::
GPU
,
true
,
true
,
platform
::
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"The tensor to be copied and shared data should be on GPU place."
));
"The tensor to be copied and shared data should be on GPU place."
));
auto
numel
=
origin
->
numel
();
auto
numel
=
origin
->
numel
();
framework
::
DDim
fused_out_dim
=
fused_out
->
dims
();
DDim
fused_out_dim
=
fused_out
->
dims
();
auto
fused_out_numel
=
fused_out
->
numel
();
auto
fused_out_numel
=
fused_out
->
numel
();
auto
sliced_tensor
=
fused_out
->
Resize
({
fused_out_numel
})
auto
sliced_tensor
=
fused_out
->
Resize
({
fused_out_numel
})
.
Slice
(
numel_offset
,
numel
+
numel_offset
);
.
Slice
(
numel_offset
,
numel
+
numel_offset
);
memory
::
Copy
(
place
,
phi
::
Copy
(
dev_ctx
,
*
origin
,
dev_ctx
.
GetPlace
(),
false
,
&
sliced_tensor
);
sliced_tensor
.
data
(),
place
,
origin
->
data
(),
numel
*
phi
::
SizeOf
(
dtype
),
stream
);
origin
->
ShareBufferWith
(
sliced_tensor
);
origin
->
ShareBufferWith
(
sliced_tensor
);
fused_out
->
Resize
(
fused_out_dim
);
fused_out
->
Resize
(
fused_out_dim
);
VLOG
(
10
)
<<
"Copy and share buffer, range: ["
<<
numel_offset
<<
", "
VLOG
(
10
)
<<
"Copy and share buffer, range: ["
<<
numel_offset
<<
", "
...
@@ -271,17 +263,17 @@ static phi::DenseTensor CopyAndShareBufferForInitedTensor(
...
@@ -271,17 +263,17 @@ static phi::DenseTensor CopyAndShareBufferForInitedTensor(
return
sliced_tensor
;
return
sliced_tensor
;
}
}
static
void
ShareBufferForNonInitedTensor
(
phi
::
DenseTensor
*
origin
,
static
void
ShareBufferForNonInitedTensor
(
DenseTensor
*
origin
,
phi
::
DenseTensor
*
fused_out
,
DenseTensor
*
fused_out
,
size_t
numel_offset
,
size_t
numel_offset
,
const
framework
::
DDim
&
dims
)
{
const
DDim
&
dims
)
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
origin
->
IsInitialized
(),
origin
->
IsInitialized
(),
false
,
false
,
platform
::
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"The tensor to be shared data should not be initialized."
));
"The tensor to be shared data should not be initialized."
));
framework
::
DDim
fused_out_dim
=
fused_out
->
dims
();
DDim
fused_out_dim
=
fused_out
->
dims
();
auto
fused_out_numel
=
fused_out
->
numel
();
auto
fused_out_numel
=
fused_out
->
numel
();
auto
numel
=
phi
::
product
(
dims
);
auto
numel
=
phi
::
product
(
dims
);
*
origin
=
fused_out
->
Resize
({
fused_out_numel
})
*
origin
=
fused_out
->
Resize
({
fused_out_numel
})
...
@@ -294,10 +286,11 @@ static void ShareBufferForNonInitedTensor(phi::DenseTensor *origin,
...
@@ -294,10 +286,11 @@ static void ShareBufferForNonInitedTensor(phi::DenseTensor *origin,
}
}
template
<
typename
T
>
template
<
typename
T
>
static
void
CopyVectorToCPUTensor
(
const
std
::
vector
<
T
>
&
src
,
static
void
CopyVectorToCPUTensor
(
const
phi
::
GPUContext
&
dev_ctx
,
phi
::
DenseTensor
*
dst
)
{
const
std
::
vector
<
T
>
&
src
,
DenseTensor
*
dst
)
{
dst
->
Resize
({
static_cast
<
int64_t
>
(
src
.
size
())});
dst
->
Resize
({
static_cast
<
int64_t
>
(
src
.
size
())});
T
*
dst_ptr
=
d
st
->
mutable_data
<
T
>
(
platform
::
CPUPlace
()
);
T
*
dst_ptr
=
d
ev_ctx
.
template
HostAlloc
<
T
>(
dst
);
const
T
*
src_ptr
=
src
.
data
();
const
T
*
src_ptr
=
src
.
data
();
auto
nbytes
=
src
.
size
()
*
sizeof
(
T
);
auto
nbytes
=
src
.
size
()
*
sizeof
(
T
);
std
::
memcpy
(
dst_ptr
,
src_ptr
,
nbytes
);
std
::
memcpy
(
dst_ptr
,
src_ptr
,
nbytes
);
...
@@ -339,98 +332,115 @@ static T ClipByBound(T x, T low_value, T high_value) {
...
@@ -339,98 +332,115 @@ static T ClipByBound(T x, T low_value, T high_value) {
return
x
;
return
x
;
}
}
template
<
typename
T
>
template
<
typename
T
,
typename
Context
>
class
DistributedFusedLambInitOpKernel
<
T
,
phi
::
GPUContext
>
void
DistributedFusedLambInitOpKernel
(
:
public
framework
::
OpKernel
<
T
>
{
const
Context
&
dev_ctx
,
public:
const
std
::
vector
<
const
DenseTensor
*>
&
param
,
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
std
::
vector
<
const
DenseTensor
*>
&
grad
,
float
beta1
,
float
beta2
,
const
std
::
vector
<
int
>
&
apply_weight_decay
,
int
alignment
,
int
rank
,
int
nranks
,
DenseTensor
*
fp32_fused_param
,
DenseTensor
*
fp32_fused_grad
,
DenseTensor
*
fp16_fused_param
,
DenseTensor
*
fp16_fused_grad
,
DenseTensor
*
moment1
,
DenseTensor
*
moment2
,
DenseTensor
*
beta1_pow
,
DenseTensor
*
beta2_pow
,
DenseTensor
*
fused_param_offsets
,
DenseTensor
*
fp32_shard_fused_param_offsets
,
DenseTensor
*
fp16_shard_fused_param_offsets
,
DenseTensor
*
param_info
,
DenseTensor
*
param_order
,
std
::
vector
<
DenseTensor
*>
param_out
,
std
::
vector
<
DenseTensor
*>
master_param_out
,
std
::
vector
<
DenseTensor
*>
grad_out
,
DenseTensor
*
global_scale
,
DenseTensor
*
step
)
{
VLOG
(
10
)
<<
"starts to run DistributedFusedLambInitOp"
;
VLOG
(
10
)
<<
"starts to run DistributedFusedLambInitOp"
;
auto
&
dev_ctx
=
ctx
.
template
device_context
<
phi
::
GPUContext
>();
auto
place
=
dev_ctx
.
GetPlace
();
auto
place
=
ctx
.
GetPlace
();
auto
stream
=
dev_ctx
.
stream
();
auto
stream
=
dev_ctx
.
stream
();
// Step 1: Check Input(Param) and Output(ParamOut), Input(Grad) and
// Step 1: Check Input(Param) and Output(ParamOut), Input(Grad) and
// Output(GradOut)
// Output(GradOut)
auto
params
=
ctx
.
MultiInput
<
phi
::
DenseTensor
>
(
"Param"
);
auto
grads
=
ctx
.
MultiInput
<
phi
::
DenseTensor
>
(
"Grad"
);
auto
master_params
=
ctx
.
MultiOutput
<
phi
::
DenseTensor
>
(
"MasterParamOut"
);
std
::
vector
<
ParamGradInfo
>
fp32_infos
,
fp16_infos
;
std
::
vector
<
ParamGradInfo
>
fp32_infos
,
fp16_infos
;
{
{
PADDLE_ENFORCE_EQ
(
params
.
size
(),
PADDLE_ENFORCE_EQ
(
grads
.
size
(),
param
.
size
(),
platform
::
errors
::
InvalidArgument
(
grad
.
size
(),
"The parameter number and parameter gradient "
errors
::
InvalidArgument
(
"The parameter number and parameter gradient "
"number should be the same."
));
"number should be the same."
));
auto
params_out
=
ctx
.
MultiOutput
<
phi
::
DenseTensor
>
(
"ParamOut"
);
auto
grads_out
=
ctx
.
MultiOutput
<
phi
::
DenseTensor
>
(
"GradOut"
);
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
params
.
size
(),
param
.
size
(),
params
_out
.
size
(),
param
_out
.
size
(),
platform
::
errors
::
InvalidArgument
(
"Input(Param) and Output(ParamOut) "
errors
::
InvalidArgument
(
"Input(Param) and Output(ParamOut) "
"should have the same number."
));
"should have the same number."
));
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
grads
.
size
(),
grad
.
size
(),
grads
_out
.
size
(),
grad
_out
.
size
(),
platform
::
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"Input(Grad) and Output(GradOut) should have the same number."
));
"Input(Grad) and Output(GradOut) should have the same number."
));
size_t
n
=
params
.
size
();
size_t
n
=
param
.
size
();
VLOG
(
10
)
<<
"parameter number: "
<<
n
;
VLOG
(
10
)
<<
"parameter number: "
<<
n
;
for
(
size_t
i
=
0
;
i
<
n
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
n
;
++
i
)
{
auto
*
p
=
params
[
i
];
auto
*
p
=
param
[
i
];
auto
*
g
=
grads
[
i
];
auto
*
g
=
grad
[
i
];
auto
*
p_out
=
params
_out
[
i
];
auto
*
p_out
=
param
_out
[
i
];
auto
*
g_out
=
grads
_out
[
i
];
auto
*
g_out
=
grad
_out
[
i
];
PADDLE_ENFORCE_NOT_NULL
(
PADDLE_ENFORCE_NOT_NULL
(
p
,
p
,
platform
::
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"The %d-th parameter should not be nullptr."
,
"The %d-th parameter should not be nullptr."
,
i
));
i
));
PADDLE_ENFORCE_EQ
(
p
->
IsInitialized
(),
PADDLE_ENFORCE_EQ
(
p
->
IsInitialized
(),
true
,
true
,
platform
::
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"The %d-th parameter should be initialized."
,
i
));
"The %d-th parameter should be initialized."
,
i
));
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
p
->
place
(),
p
->
place
(),
place
,
place
,
platform
::
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"The %d-th parameter is not initialized on the right place."
,
"The %d-th parameter is not initialized on the right place."
,
i
));
i
));
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
p
,
p
,
p_out
,
p_out
,
platform
::
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"The %d-th Input(Param) and Output(ParamOut) "
"The %d-th Input(Param) and Output(ParamOut) "
"should be the same tensor."
,
"should be the same tensor."
,
i
));
i
));
auto
dtype
=
p
->
dtype
();
auto
dtype
=
p
->
dtype
();
PADDLE_ENFORCE_NOT_NULL
(
PADDLE_ENFORCE_NOT_NULL
(
g
,
g
,
platform
::
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"The %d-th gradient should not be nullptr."
,
"The %d-th gradient should not be nullptr."
,
i
));
i
));
PADDLE_ENFORCE_EQ
(
g
,
PADDLE_ENFORCE_EQ
(
g
,
g_out
,
g_out
,
platform
::
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"The %d-th Input(Grad) and Output(Grad) should "
"The %d-th Input(Grad) and Output(Grad) should "
"be the same tensor."
));
"be the same tensor."
));
auto
numel
=
p
->
numel
();
auto
numel
=
p
->
numel
();
PADDLE_ENFORCE_GT
(
numel
,
PADDLE_ENFORCE_GT
(
numel
,
0
,
0
,
platform
::
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"The %d-th Input(Param) have no elements."
));
"The %d-th Input(Param) have no elements."
));
void
*
g_data
=
nullptr
;
void
*
g_data
=
nullptr
;
if
(
g
->
IsInitialized
())
{
if
(
g
->
IsInitialized
())
{
PADDLE_ENFORCE_EQ
(
g
->
dtype
(),
PADDLE_ENFORCE_EQ
(
g
->
dtype
(),
dtype
,
dtype
,
platform
::
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"The %d-th Input(Param) and Input(Grad) should "
"The %d-th Input(Param) and Input(Grad) should "
"have the same data type %s."
,
"have the same data type %s."
,
i
,
i
,
dtype
));
dtype
));
PADDLE_ENFORCE_EQ
(
g
->
dims
(),
PADDLE_ENFORCE_EQ
(
g
->
dims
(),
p
->
dims
(),
p
->
dims
(),
platform
::
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"The %d-th Input(Param) and Input(Grad) should "
"The %d-th Input(Param) and Input(Grad) should "
"have the same shape."
,
"have the same shape."
,
i
));
i
));
...
@@ -445,8 +455,8 @@ class DistributedFusedLambInitOpKernel<T, phi::GPUContext>
...
@@ -445,8 +455,8 @@ class DistributedFusedLambInitOpKernel<T, phi::GPUContext>
fp16_infos
.
emplace_back
();
fp16_infos
.
emplace_back
();
info
=
&
fp16_infos
.
back
();
info
=
&
fp16_infos
.
back
();
}
else
{
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
"Unsupported data type %s."
,
dtype
));
errors
::
InvalidArgument
(
"Unsupported data type %s."
,
dtype
));
}
}
VLOG
(
10
)
<<
"Found "
<<
dtype
<<
" parameter "
<<
i
<<
" shape=["
VLOG
(
10
)
<<
"Found "
<<
dtype
<<
" parameter "
<<
i
<<
" shape=["
...
@@ -462,108 +472,102 @@ class DistributedFusedLambInitOpKernel<T, phi::GPUContext>
...
@@ -462,108 +472,102 @@ class DistributedFusedLambInitOpKernel<T, phi::GPUContext>
info
->
numel_offset
=
0
;
// not determined yet
info
->
numel_offset
=
0
;
// not determined yet
}
}
}
}
const
auto
&
apply_weight_decay
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"apply_weight_decay"
);
size_t
fp32_wd_end_idx
=
size_t
fp32_wd_end_idx
=
ReorderParamGradInfoList
(
apply_weight_decay
,
&
fp32_infos
);
ReorderParamGradInfoList
(
apply_weight_decay
,
&
fp32_infos
);
size_t
fp16_wd_end_idx
=
size_t
fp16_wd_end_idx
=
ReorderParamGradInfoList
(
apply_weight_decay
,
&
fp16_infos
);
ReorderParamGradInfoList
(
apply_weight_decay
,
&
fp16_infos
);
auto
*
param_order_t
=
ctx
.
Output
<
phi
::
DenseTensor
>
(
"ParamOrder"
);
auto
param_num
=
fp32_infos
.
size
()
+
fp16_infos
.
size
();
auto
param_num
=
fp32_infos
.
size
()
+
fp16_infos
.
size
();
param_order_t
->
Resize
({
static_cast
<
int16_t
>
(
param_num
)});
param_order
->
Resize
({
static_cast
<
int16_t
>
(
param_num
)});
auto
*
param_order
=
param_order_t
->
mutable_data
<
int
>
(
platform
::
CPUPlace
()
);
auto
*
param_order_t
=
dev_ctx
.
template
HostAlloc
<
int
>(
param_order
);
for
(
size_t
i
=
0
;
i
<
fp32_infos
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
fp32_infos
.
size
();
++
i
)
{
param_order
[
i
]
=
static_cast
<
int
>
(
fp32_infos
[
i
].
idx
);
param_order_t
[
i
]
=
static_cast
<
int
>
(
fp32_infos
[
i
].
idx
);
}
}
for
(
size_t
i
=
0
;
i
<
fp16_infos
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
fp16_infos
.
size
();
++
i
)
{
param_order
[
i
+
fp32_infos
.
size
()]
=
static_cast
<
int
>
(
fp16_infos
[
i
].
idx
);
param_order_t
[
i
+
fp32_infos
.
size
()]
=
static_cast
<
int
>
(
fp16_infos
[
i
].
idx
);
}
}
VLOG
(
10
)
<<
"Fill ParamGradInfo ends"
;
VLOG
(
10
)
<<
"Fill ParamGradInfo ends"
;
// Step 2: determine the numel_with_padding and numel_offset
// Step 2: determine the numel_with_padding and numel_offset
auto
rank
=
ctx
.
Attr
<
int
>
(
"rank"
);
auto
nranks
=
ctx
.
Attr
<
int
>
(
"nranks"
);
auto
alignment
=
ctx
.
Attr
<
int
>
(
"alignment"
);
VLOG
(
10
)
<<
"rank = "
<<
rank
<<
", nranks = "
<<
nranks
VLOG
(
10
)
<<
"rank = "
<<
rank
<<
", nranks = "
<<
nranks
<<
" , alignment = "
<<
alignment
;
<<
" , alignment = "
<<
alignment
;
if
(
alignment
<=
0
)
{
if
(
alignment
<=
0
)
{
alignment
=
platform
::
GpuMinChunkSize
();
alignment
=
phi
::
backends
::
gpu
::
GpuMinChunkSize
();
}
}
PADDLE_ENFORCE_GE
(
alignment
,
PADDLE_ENFORCE_GE
(
alignment
,
1
,
1
,
platform
::
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"The attr(alignment) should be larger than 0."
));
"The attr(alignment) should be larger than 0."
));
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
alignment
&
(
alignment
-
1
),
alignment
&
(
alignment
-
1
),
0
,
0
,
platform
::
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"The attr(alignment) should be the power of 2."
));
"The attr(alignment) should be the power of 2."
));
PADDLE_ENFORCE_GE
(
rank
,
PADDLE_ENFORCE_GE
(
rank
,
0
,
0
,
platform
::
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"The attr(rank) should be equal to or larger than 0."
));
"The attr(rank) should be equal to or larger than 0."
));
PADDLE_ENFORCE_LT
(
PADDLE_ENFORCE_LT
(
rank
,
rank
,
nranks
,
nranks
,
platform
::
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"The attr(rank) should be less than the attr(nranks)."
));
"The attr(rank) should be less than the attr(nranks)."
));
// NOTE: We guarantee that both fp32_numel and fp16_numel can be exactly
// NOTE: We guarantee that both fp32_numel and fp16_numel can be exactly
// divided by alignment and nranks.
// divided by alignment and nranks.
auto
fp32_numel
=
FillAlignmentPaddingInfo
(
auto
fp32_numel
=
FillAlignmentPaddingInfo
(
&
fp32_infos
,
alignment
,
nranks
,
phi
::
DataType
::
FLOAT32
);
&
fp32_infos
,
alignment
,
nranks
,
phi
::
DataType
::
FLOAT32
);
VLOG
(
10
)
<<
"FP32 ParamGradInfo: "
<<
string
::
join_strings
(
fp32_infos
,
" "
);
VLOG
(
10
)
<<
"FP32 ParamGradInfo: "
<<
paddle
::
string
::
join_strings
(
fp32_infos
,
" "
);
auto
fp16_numel
=
FillAlignmentPaddingInfo
(
auto
fp16_numel
=
FillAlignmentPaddingInfo
(
&
fp16_infos
,
alignment
,
nranks
,
phi
::
DataType
::
FLOAT16
);
&
fp16_infos
,
alignment
,
nranks
,
phi
::
DataType
::
FLOAT16
);
VLOG
(
10
)
<<
"FP16 ParamGradInfo: "
<<
string
::
join_strings
(
fp16_infos
,
" "
);
VLOG
(
10
)
<<
"FP16 ParamGradInfo: "
<<
paddle
::
string
::
join_strings
(
fp16_infos
,
" "
);
auto
total_numel
=
fp32_numel
+
fp16_numel
;
auto
total_numel
=
fp32_numel
+
fp16_numel
;
PADDLE_ENFORCE_LT
(
PADDLE_ENFORCE_LT
(
total_numel
,
total_numel
,
std
::
numeric_limits
<
int
>::
max
(),
std
::
numeric_limits
<
int
>::
max
(),
platform
::
errors
::
InvalidArgument
(
"Too many parameter number."
));
errors
::
InvalidArgument
(
"Too many parameter number."
));
auto
fp32_numel_each_device
=
fp32_numel
/
nranks
;
auto
fp32_numel_each_device
=
fp32_numel
/
nranks
;
auto
fp16_numel_each_device
=
fp16_numel
/
nranks
;
auto
fp16_numel_each_device
=
fp16_numel
/
nranks
;
auto
numel_each_device
=
fp32_numel_each_device
+
fp16_numel_each_device
;
auto
numel_each_device
=
fp32_numel_each_device
+
fp16_numel_each_device
;
VLOG
(
10
)
<<
"Fill padding ends. total_numel = "
<<
total_numel
VLOG
(
10
)
<<
"Fill padding ends. total_numel = "
<<
total_numel
<<
", fp32_numel = "
<<
fp32_numel
<<
", fp32_numel = "
<<
fp32_numel
<<
", fp16_numel = "
<<
fp16_numel
<<
", fp16_numel = "
<<
fp16_numel
<<
", fp32_numel_each_device = "
<<
fp32_numel_each_device
<<
", fp32_numel_each_device = "
<<
fp32_numel_each_device
<<
", fp16_numel_each_device = "
<<
fp16_numel_each_device
;
<<
", fp16_numel_each_device = "
<<
fp16_numel_each_device
;
// Step 3: allocate output tensor and do initialization
// Step 3: allocate output tensor and do initialization
float
*
fused_fp32_param
=
nullptr
,
*
fused_fp32_grad
=
nullptr
;
float
*
fused_fp32_param
=
nullptr
,
*
fused_fp32_grad
=
nullptr
;
platform
::
float16
*
fused_fp16_param
=
nullptr
,
*
fused_fp16_grad
=
nullptr
;
dtype
::
float16
*
fused_fp16_param
=
nullptr
,
*
fused_fp16_grad
=
nullptr
;
phi
::
DenseTensor
*
fp32_p_t
=
nullptr
,
*
fp16_p
_t
=
nullptr
,
DenseTensor
*
fp32_p_t
=
nullptr
,
*
fp16_p_t
=
nullptr
,
*
fp32_g
_t
=
nullptr
,
*
fp32_g_t
=
nullptr
,
*
fp16_g_t
=
nullptr
;
*
fp16_g_t
=
nullptr
;
std
::
vector
<
phi
::
DenseTensor
*>
fp16_master_params
;
std
::
vector
<
DenseTensor
*>
fp16_master_params
;
if
(
total_numel
>
0
)
{
if
(
total_numel
>
0
)
{
fp32_p_t
=
ctx
.
Output
<
phi
::
DenseTensor
>
(
"FP32FusedParam"
)
;
fp32_p_t
=
fp32_fused_param
;
fused_fp32_param
=
TensorFillConstant
<
float
>
(
fused_fp32_param
=
TensorFillConstant
<
float
>
(
dev_ctx
,
fp32_p_t
,
{
static_cast
<
int64_t
>
(
total_numel
)},
0.0
f
);
dev_ctx
,
fp32_p_t
,
{
static_cast
<
int64_t
>
(
total_numel
)},
0.0
f
);
}
}
if
(
fp32_numel
>
0
)
{
if
(
fp32_numel
>
0
)
{
fp32_g_t
=
ctx
.
Output
<
phi
::
DenseTensor
>
(
"FP32FusedGrad"
)
;
fp32_g_t
=
fp32_fused_grad
;
fused_fp32_grad
=
TensorFillConstant
<
float
>
(
fused_fp32_grad
=
TensorFillConstant
<
float
>
(
dev_ctx
,
fp32_g_t
,
{
static_cast
<
int64_t
>
(
fp32_numel
)},
0.0
f
);
dev_ctx
,
fp32_g_t
,
{
static_cast
<
int64_t
>
(
fp32_numel
)},
0.0
f
);
}
}
if
(
fp16_numel
>
0
)
{
if
(
fp16_numel
>
0
)
{
fp16_p_t
=
ctx
.
Output
<
phi
::
DenseTensor
>
(
"FP16FusedParam"
)
;
fp16_p_t
=
fp16_fused_param
;
fused_fp16_param
=
TensorFillConstant
<
platform
::
float16
>
(
fused_fp16_param
=
dev_ctx
,
TensorFillConstant
<
dtype
::
float16
>
(
dev_ctx
,
fp16_p_t
,
fp16_p_t
,
{
static_cast
<
int64_t
>
(
fp16_numel
)},
{
static_cast
<
int64_t
>
(
fp16_numel
)},
static_cast
<
platform
::
float16
>
(
0
));
static_cast
<
dtype
::
float16
>
(
0
));
fp16_g_t
=
ctx
.
Output
<
phi
::
DenseTensor
>
(
"FP16FusedGrad"
)
;
fp16_g_t
=
fp16_fused_grad
;
fused_fp16_grad
=
TensorFillConstant
<
platform
::
float16
>
(
fused_fp16_grad
=
dev_ctx
,
TensorFillConstant
<
dtype
::
float16
>
(
dev_ctx
,
fp16_g_t
,
fp16_g_t
,
{
static_cast
<
int64_t
>
(
fp16_numel
)},
{
static_cast
<
int64_t
>
(
fp16_numel
)},
static_cast
<
platform
::
float16
>
(
0
));
static_cast
<
dtype
::
float16
>
(
0
));
}
}
VLOG
(
10
)
<<
"Allocate FP32FusedParam/Grad, FP16FusedParam/Grad ends"
;
VLOG
(
10
)
<<
"Allocate FP32FusedParam/Grad, FP16FusedParam/Grad ends"
;
...
@@ -573,16 +577,20 @@ class DistributedFusedLambInitOpKernel<T, phi::GPUContext>
...
@@ -573,16 +577,20 @@ class DistributedFusedLambInitOpKernel<T, phi::GPUContext>
// (3) For FP32FusedGrad/FP16FusedGrad, memcpy if gradient has been inited
// (3) For FP32FusedGrad/FP16FusedGrad, memcpy if gradient has been inited
for
(
const
auto
&
info
:
fp32_infos
)
{
for
(
const
auto
&
info
:
fp32_infos
)
{
auto
sliced_tensor
=
CopyAndShareBufferForInitedTensor
(
auto
sliced_tensor
=
CopyAndShareBufferForInitedTensor
(
info
.
param_t
,
fp32_p_t
,
info
.
numel_offset
,
stream
);
dev_ctx
,
info
.
param_t
,
fp32_p_t
,
info
.
numel_offset
);
master_params
[
info
.
idx
]
->
Resize
(
info
.
param_t
->
dims
());
master_param_out
[
info
.
idx
]
->
Resize
(
info
.
param_t
->
dims
());
master_params
[
info
.
idx
]
->
ShareBufferWith
(
sliced_tensor
);
master_param_out
[
info
.
idx
]
->
ShareBufferWith
(
sliced_tensor
);
PADDLE_ENFORCE_EQ
(
master_params
[
info
.
idx
]
->
mutable_data
<
float
>
(
place
),
float
*
master_param_tmp
=
sliced_tensor
.
data
<
float
>
(),
dev_ctx
.
template
Alloc
<
float
>(
master_param_out
[
info
.
idx
]);
platform
::
errors
::
InvalidArgument
(
float
*
sliced_tensor_tmp
=
reinterpret_cast
<
float
*>
(
sliced_tensor
.
data
());
"Invalid master weight tensor pointer."
));
PADDLE_ENFORCE_EQ
(
master_param_tmp
,
sliced_tensor_tmp
,
errors
::
InvalidArgument
(
"Invalid master weight tensor pointer."
));
if
(
info
.
grad_t
->
IsInitialized
())
{
if
(
info
.
grad_t
->
IsInitialized
())
{
CopyAndShareBufferForInitedTensor
(
CopyAndShareBufferForInitedTensor
(
info
.
grad_t
,
fp32_g_t
,
info
.
numel_offset
,
stream
);
dev_ctx
,
info
.
grad_t
,
fp32_g_t
,
info
.
numel_offset
);
}
else
{
}
else
{
ShareBufferForNonInitedTensor
(
ShareBufferForNonInitedTensor
(
info
.
grad_t
,
fp32_g_t
,
info
.
numel_offset
,
info
.
param_t
->
dims
());
info
.
grad_t
,
fp32_g_t
,
info
.
numel_offset
,
info
.
param_t
->
dims
());
...
@@ -600,19 +608,22 @@ class DistributedFusedLambInitOpKernel<T, phi::GPUContext>
...
@@ -600,19 +608,22 @@ class DistributedFusedLambInitOpKernel<T, phi::GPUContext>
auto
master_weight_offset
=
info
.
numel_offset
+
fp16_numel_offset
;
auto
master_weight_offset
=
info
.
numel_offset
+
fp16_numel_offset
;
auto
sliced_tensor
=
CastDataForInitedTensor
(
auto
sliced_tensor
=
CastDataForInitedTensor
(
dev_ctx
,
info
.
param_t
,
fp32_p_t
,
master_weight_offset
);
dev_ctx
,
info
.
param_t
,
fp32_p_t
,
master_weight_offset
);
master_params
[
info
.
idx
]
->
Resize
(
info
.
param_t
->
dims
());
master_param_out
[
info
.
idx
]
->
Resize
(
info
.
param_t
->
dims
());
master_params
[
info
.
idx
]
->
ShareBufferWith
(
sliced_tensor
);
master_param_out
[
info
.
idx
]
->
ShareBufferWith
(
sliced_tensor
);
CopyAndShareBufferForInitedTensor
(
CopyAndShareBufferForInitedTensor
(
info
.
param_t
,
fp16_p_t
,
info
.
numel_offset
,
stream
);
dev_ctx
,
info
.
param_t
,
fp16_p_t
,
info
.
numel_offset
);
PADDLE_ENFORCE_EQ
(
master_params
[
info
.
idx
]
->
mutable_data
<
float
>
(
place
),
float
*
master_param_tmp
=
sliced_tensor
.
data
<
float
>
(),
dev_ctx
.
template
Alloc
<
float
>(
master_param_out
[
info
.
idx
]);
platform
::
errors
::
InvalidArgument
(
float
*
sliced_tensor_tmp
=
reinterpret_cast
<
float
*>
(
sliced_tensor
.
data
());
"Invalid master weight tensor pointer."
));
PADDLE_ENFORCE_EQ
(
master_param_tmp
,
sliced_tensor_tmp
,
errors
::
InvalidArgument
(
"Invalid master weight tensor pointer."
));
if
(
info
.
grad_t
->
IsInitialized
())
{
if
(
info
.
grad_t
->
IsInitialized
())
{
CopyAndShareBufferForInitedTensor
(
CopyAndShareBufferForInitedTensor
(
info
.
grad_t
,
fp16_g_t
,
info
.
numel_offset
,
stream
);
dev_ctx
,
info
.
grad_t
,
fp16_g_t
,
info
.
numel_offset
);
}
else
{
}
else
{
ShareBufferForNonInitedTensor
(
ShareBufferForNonInitedTensor
(
info
.
grad_t
,
fp16_g_t
,
info
.
numel_offset
,
info
.
param_t
->
dims
());
info
.
grad_t
,
fp16_g_t
,
info
.
numel_offset
,
info
.
param_t
->
dims
());
...
@@ -621,22 +632,12 @@ class DistributedFusedLambInitOpKernel<T, phi::GPUContext>
...
@@ -621,22 +632,12 @@ class DistributedFusedLambInitOpKernel<T, phi::GPUContext>
VLOG
(
10
)
<<
"Copy/share data for Param/Grad ends"
;
VLOG
(
10
)
<<
"Copy/share data for Param/Grad ends"
;
// Step 4: For Moment1, Moment2, Beta1Pow, Beta2Pow, just fill constant
// Step 4: For Moment1, Moment2, Beta1Pow, Beta2Pow, just fill constant
TensorFillConstant
<
float
>
(
dev_ctx
,
TensorFillConstant
<
float
>
(
ctx
.
Output
<
phi
::
DenseTensor
>
(
"Moment1"
),
dev_ctx
,
moment1
,
{
static_cast
<
int64_t
>
(
numel_each_device
)},
0.0
f
);
{
static_cast
<
int64_t
>
(
numel_each_device
)},
TensorFillConstant
<
float
>
(
0.0
f
);
dev_ctx
,
moment2
,
{
static_cast
<
int64_t
>
(
numel_each_device
)},
0.0
f
);
TensorFillConstant
<
float
>
(
dev_ctx
,
TensorFillConstant
<
float
>
(
dev_ctx
,
beta1_pow
,
{
1
},
beta1
);
ctx
.
Output
<
phi
::
DenseTensor
>
(
"Moment2"
),
TensorFillConstant
<
float
>
(
dev_ctx
,
beta2_pow
,
{
1
},
beta2
);
{
static_cast
<
int64_t
>
(
numel_each_device
)},
0.0
f
);
TensorFillConstant
<
float
>
(
dev_ctx
,
ctx
.
Output
<
phi
::
DenseTensor
>
(
"Beta1Pow"
),
{
1
},
ctx
.
Attr
<
float
>
(
"beta1"
));
TensorFillConstant
<
float
>
(
dev_ctx
,
ctx
.
Output
<
phi
::
DenseTensor
>
(
"Beta2Pow"
),
{
1
},
ctx
.
Attr
<
float
>
(
"beta2"
));
VLOG
(
10
)
<<
"Init Moment and BetaPow ends"
;
VLOG
(
10
)
<<
"Init Moment and BetaPow ends"
;
// Step 5: Do sharding
// Step 5: Do sharding
...
@@ -665,34 +666,33 @@ class DistributedFusedLambInitOpKernel<T, phi::GPUContext>
...
@@ -665,34 +666,33 @@ class DistributedFusedLambInitOpKernel<T, phi::GPUContext>
size_t
total_local_param_num
=
fp32_local_param_num
+
fp16_local_param_num
;
size_t
total_local_param_num
=
fp32_local_param_num
+
fp16_local_param_num
;
VLOG
(
10
)
<<
"Found the sharding arguments"
;
VLOG
(
10
)
<<
"Found the sharding arguments"
;
auto
*
param_info_t
=
ctx
.
Output
<
phi
::
DenseTensor
>
(
"ParamInfo"
);
param_info
->
Resize
({
8
});
param_info_t
->
Resize
({
8
});
auto
*
param_info_t
=
dev_ctx
.
template
HostAlloc
<
int
>(
param_info
);
auto
*
param_info
=
param_info_t
->
mutable_data
<
int
>
(
platform
::
CPUPlace
());
param_info_t
[
0
]
=
static_cast
<
int
>
(
fp32_start_idx
);
param_info
[
0
]
=
static_cast
<
int
>
(
fp32_start_idx
);
param_info_t
[
1
]
=
static_cast
<
int
>
(
fp32_local_param_num
);
param_info
[
1
]
=
static_cast
<
int
>
(
fp32_local_param_num
);
param_info_t
[
2
]
=
static_cast
<
int
>
(
fp32_infos
.
size
());
param_info
[
2
]
=
static_cast
<
int
>
(
fp32_infos
.
size
());
param_info_t
[
3
]
=
ClipByBound
<
int
>
(
fp32_wd_end_idx
,
param_info
[
3
]
=
ClipByBound
<
int
>
(
fp32_wd_end_idx
,
fp32_start_idx
,
fp32_start_idx
,
fp32_start_idx
+
fp32_local_param_num
)
-
fp32_start_idx
+
fp32_local_param_num
)
-
static_cast
<
int
>
(
fp32_start_idx
);
static_cast
<
int
>
(
fp32_start_idx
);
param_info
[
4
]
=
static_cast
<
int
>
(
fp16_start_idx
+
fp32_infos
.
size
());
param_info_t
[
4
]
=
static_cast
<
int
>
(
fp16_start_idx
+
fp32_infos
.
size
());
param_info
[
5
]
=
static_cast
<
int
>
(
fp16_local_param_num
);
param_info_t
[
5
]
=
static_cast
<
int
>
(
fp16_local_param_num
);
param_info
[
6
]
=
static_cast
<
int
>
(
fp16_infos
.
size
());
param_info_t
[
6
]
=
static_cast
<
int
>
(
fp16_infos
.
size
());
param_info
[
7
]
=
ClipByBound
<
int
>
(
fp16_wd_end_idx
,
param_info_t
[
7
]
=
ClipByBound
<
int
>
(
fp16_wd_end_idx
,
fp16_start_idx
,
fp16_start_idx
,
fp16_start_idx
+
fp16_local_param_num
)
-
fp16_start_idx
+
fp16_local_param_num
)
-
static_cast
<
int
>
(
fp16_start_idx
);
static_cast
<
int
>
(
fp16_start_idx
);
VLOG
(
10
)
<<
"Start FP32 idx: "
<<
param_info
[
0
];
VLOG
(
10
)
<<
"Start FP32 idx: "
<<
param_info_t
[
0
];
VLOG
(
10
)
<<
"Local FP32 param num: "
<<
param_info
[
1
];
VLOG
(
10
)
<<
"Local FP32 param num: "
<<
param_info_t
[
1
];
VLOG
(
10
)
<<
"Global FP32 param num: "
<<
param_info
[
2
];
VLOG
(
10
)
<<
"Global FP32 param num: "
<<
param_info_t
[
2
];
VLOG
(
10
)
<<
"Start FP16 idx: "
<<
param_info
[
4
];
VLOG
(
10
)
<<
"Start FP16 idx: "
<<
param_info_t
[
4
];
VLOG
(
10
)
<<
"Local FP16 param num: "
<<
param_info
[
5
];
VLOG
(
10
)
<<
"Local FP16 param num: "
<<
param_info_t
[
5
];
VLOG
(
10
)
<<
"Global FP16 param num: "
<<
param_info
[
6
];
VLOG
(
10
)
<<
"Global FP16 param num: "
<<
param_info_t
[
6
];
std
::
vector
<
int
>
numel_offsets
;
std
::
vector
<
int
>
numel_offsets
;
numel_offsets
.
reserve
(
params
.
size
()
+
1
);
numel_offsets
.
reserve
(
param
.
size
()
+
1
);
for
(
const
auto
&
info
:
fp32_infos
)
{
for
(
const
auto
&
info
:
fp32_infos
)
{
numel_offsets
.
push_back
(
info
.
numel_offset
);
numel_offsets
.
push_back
(
info
.
numel_offset
);
}
}
...
@@ -701,8 +701,8 @@ class DistributedFusedLambInitOpKernel<T, phi::GPUContext>
...
@@ -701,8 +701,8 @@ class DistributedFusedLambInitOpKernel<T, phi::GPUContext>
}
}
numel_offsets
.
push_back
(
fp32_numel
+
fp16_numel
);
numel_offsets
.
push_back
(
fp32_numel
+
fp16_numel
);
PADDLE_ENFORCE_EQ
(
numel_offsets
.
size
(),
PADDLE_ENFORCE_EQ
(
numel_offsets
.
size
(),
params
.
size
()
+
1
,
param
.
size
()
+
1
,
platform
::
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"The numel_offsets number must be one larger than "
"The numel_offsets number must be one larger than "
"the parameter number."
));
"the parameter number."
));
VLOG
(
10
)
<<
"Total numel offset: "
<<
FlattenToString
(
numel_offsets
);
VLOG
(
10
)
<<
"Total numel offset: "
<<
FlattenToString
(
numel_offsets
);
...
@@ -723,13 +723,12 @@ class DistributedFusedLambInitOpKernel<T, phi::GPUContext>
...
@@ -723,13 +723,12 @@ class DistributedFusedLambInitOpKernel<T, phi::GPUContext>
end_n
=
std
::
min
(
end_n
,
fp32_end_numel_offset
);
end_n
=
std
::
min
(
end_n
,
fp32_end_numel_offset
);
}
}
PADDLE_ENFORCE_NE
(
valid_start_n
,
PADDLE_ENFORCE_NE
(
valid_start_n
,
end_n
,
end_n
,
platform
::
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"Indices sharding error. This may be a bug."
));
"Indices sharding error. This may be a bug."
));
VLOG
(
10
)
<<
"FP32 Partial numel = ["
<<
valid_start_n
+
fp32_infos
[
i
].
numel
VLOG
(
10
)
<<
"FP32 Partial numel = ["
<<
","
<<
end_n
+
fp32_infos
[
i
].
numel
;
<<
valid_start_n
+
fp32_infos
[
i
].
numel
<<
","
<<
end_n
+
fp32_infos
[
i
].
numel
;
auto
len
=
end_n
-
valid_start_n
;
auto
len
=
end_n
-
valid_start_n
;
fp32_partial_numel_offsets
.
push_back
(
fp32_partial_numel_offsets
.
back
()
+
fp32_partial_numel_offsets
.
push_back
(
fp32_partial_numel_offsets
.
back
()
+
len
);
len
);
...
@@ -750,48 +749,56 @@ class DistributedFusedLambInitOpKernel<T, phi::GPUContext>
...
@@ -750,48 +749,56 @@ class DistributedFusedLambInitOpKernel<T, phi::GPUContext>
end_n
=
std
::
min
(
end_n
,
fp16_end_numel_offset
);
end_n
=
std
::
min
(
end_n
,
fp16_end_numel_offset
);
}
}
PADDLE_ENFORCE_NE
(
valid_start_n
,
PADDLE_ENFORCE_NE
(
valid_start_n
,
end_n
,
end_n
,
platform
::
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"Indices sharding error. This may be a bug."
));
"Indices sharding error. This may be a bug."
));
auto
len
=
end_n
-
valid_start_n
;
auto
len
=
end_n
-
valid_start_n
;
fp16_partial_numel_offsets
.
push_back
(
fp16_partial_numel_offsets
.
back
()
+
fp16_partial_numel_offsets
.
push_back
(
fp16_partial_numel_offsets
.
back
()
+
len
);
len
);
}
}
CopyVectorToCPUTensor
(
numel_offsets
,
CopyVectorToCPUTensor
(
dev_ctx
,
numel_offsets
,
fused_param_offsets
);
ctx
.
Output
<
phi
::
DenseTensor
>
(
"FusedParamOffsets"
));
CopyVectorToCPUTensor
(
CopyVectorToCPUTensor
(
fp32_partial_numel_offsets
,
dev_ctx
,
fp32_partial_numel_offsets
,
fp32_shard_fused_param_offsets
);
ctx
.
Output
<
phi
::
DenseTensor
>
(
"FP32ShardFusedParamOffsets"
));
CopyVectorToCPUTensor
(
CopyVectorToCPUTensor
(
fp16_partial_numel_offsets
,
dev_ctx
,
fp16_partial_numel_offsets
,
fp16_shard_fused_param_offsets
);
ctx
.
Output
<
phi
::
DenseTensor
>
(
"FP16ShardFusedParamOffsets"
));
auto
*
global_scale
=
ctx
.
Output
<
phi
::
DenseTensor
>
(
"GlobalScale"
);
if
(
!
global_scale
->
IsInitialized
())
{
if
(
!
global_scale
->
IsInitialized
())
{
TensorFillConstant
<
float
>
(
dev_ctx
,
global_scale
,
{
1
},
1.0
f
);
TensorFillConstant
<
float
>
(
dev_ctx
,
global_scale
,
{
1
},
1.0
f
);
}
}
VLOG
(
10
)
<<
"Init global scale ends"
;
VLOG
(
10
)
<<
"Init global scale ends"
;
TensorFillConstant
<
int64_t
>
(
dev_ctx
,
TensorFillConstant
<
int64_t
>
(
dev_ctx
,
step
,
{
1
},
static_cast
<
int64_t
>
(
0
));
ctx
.
Output
<
phi
::
DenseTensor
>
(
"Step"
),
{
1
},
static_cast
<
int64_t
>
(
0
));
dev_ctx
.
Wait
();
dev_ctx
.
Wait
();
VLOG
(
10
)
<<
"Wait for H2D copy"
;
VLOG
(
10
)
<<
"Wait for H2D copy"
;
}
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
}
// namespace fusion
namespace
plat
=
paddle
::
platform
;
}
// namespace phi
PD_REGISTER_
STRUCT_
KERNEL
(
distributed_fused_lamb_init
,
PD_REGISTER_KERNEL
(
distributed_fused_lamb_init
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
ops
::
DistributedFusedLambInitOpKernel
,
phi
::
fusion
::
DistributedFusedLambInitOpKernel
,
float
)
{}
float
)
{
kernel
->
OutputAt
(
0
).
SetDataType
(
phi
::
DataType
::
FLOAT32
);
kernel
->
OutputAt
(
1
).
SetDataType
(
phi
::
DataType
::
FLOAT32
);
kernel
->
OutputAt
(
2
).
SetDataType
(
phi
::
DataType
::
FLOAT16
);
kernel
->
OutputAt
(
3
).
SetDataType
(
phi
::
DataType
::
FLOAT16
);
kernel
->
OutputAt
(
4
).
SetDataType
(
phi
::
DataType
::
FLOAT32
);
kernel
->
OutputAt
(
5
).
SetDataType
(
phi
::
DataType
::
FLOAT32
);
kernel
->
OutputAt
(
6
).
SetDataType
(
phi
::
DataType
::
FLOAT32
);
kernel
->
OutputAt
(
7
).
SetDataType
(
phi
::
DataType
::
FLOAT32
);
kernel
->
OutputAt
(
8
).
SetDataType
(
phi
::
DataType
::
INT32
);
kernel
->
OutputAt
(
9
).
SetDataType
(
phi
::
DataType
::
INT32
);
kernel
->
OutputAt
(
10
).
SetDataType
(
phi
::
DataType
::
INT32
);
kernel
->
OutputAt
(
11
).
SetDataType
(
phi
::
DataType
::
INT32
);
kernel
->
OutputAt
(
12
).
SetDataType
(
phi
::
DataType
::
INT32
);
kernel
->
OutputAt
(
13
).
SetDataType
(
kernel_key
.
dtype
());
kernel
->
OutputAt
(
14
).
SetDataType
(
phi
::
DataType
::
FLOAT32
);
kernel
->
OutputAt
(
15
).
SetDataType
(
kernel_key
.
dtype
());
kernel
->
OutputAt
(
16
).
SetDataType
(
phi
::
DataType
::
FLOAT32
);
kernel
->
OutputAt
(
17
).
SetDataType
(
phi
::
DataType
::
INT64
);
}
paddle/
fluid/operators/optimizers/distributed_fused_lamb_init_op.h
→
paddle/
phi/ops/compat/distributed_fused_lamb_init_sig.cc
浏览文件 @
0bc369ef
// Copyright (c) 202
1
PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 202
3
PaddlePaddle Authors. All Rights Reserved.
//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// you may not use this file except in compliance with the License.
...
@@ -12,22 +12,37 @@
...
@@ -12,22 +12,37 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#
pragma once
#
include "paddle/phi/core/compat/op_utils.h"
#include "paddle/fluid/framework/op_registry.h"
namespace
phi
{
#include "paddle/fluid/framework/operator.h"
namespace
paddle
{
KernelSignature
DistributedFusedLambInitOpArgumentMapping
(
namespace
operators
{
const
ArgumentMappingContext
&
ctx
UNUSED
)
{
return
KernelSignature
(
"distributed_fused_lamb_init"
,
{
"Param"
,
"Grad"
},
{
"beta1"
,
"beta2"
,
"apply_weight_decay"
,
"alignment"
,
"rank"
,
"nranks"
},
{
"FP32FusedParam"
,
"FP32FusedGrad"
,
"FP16FusedParam"
,
"FP16FusedGrad"
,
"Moment1"
,
"Moment2"
,
"Beta1Pow"
,
"Beta2Pow"
,
"FusedParamOffsets"
,
"FP32ShardFusedParamOffsets"
,
"FP16ShardFusedParamOffsets"
,
"ParamInfo"
,
"ParamOrder"
,
"ParamOut"
,
"MasterParamOut"
,
"GradOut"
,
"GlobalScale"
,
"Step"
});
}
template
<
typename
T
,
typename
DevCtx
>
}
// namespace phi
class
DistributedFusedLambInitOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"The distributed_fused_lamb_init operator does not support CPU yet."
));
}
};
}
// namespace operators
PD_REGISTER_ARG_MAPPING_FN
(
distributed_fused_lamb_init
,
}
// namespace paddle
phi
::
DistributedFusedLambInitOpArgumentMapping
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录