Unverified commit 0bc369ef, authored by Zero Rains, committed by GitHub

[Fluid] Move distributed_fused_lamb_init to phi (#55993)

Parent: e358ddac
@@ -12,7 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
 namespace paddle {
 namespace operators {
@@ -116,9 +117,3 @@ namespace ops = paddle::operators;
 REGISTER_OP_WITHOUT_GRADIENT(distributed_fused_lamb_init,
                              ops::DistributedFusedLambInitOp,
                              ops::DistributedFusedLambInitOpMaker);
-PD_REGISTER_STRUCT_KERNEL(distributed_fused_lamb_init,
-                          CPU,
-                          ALL_LAYOUT,
-                          ops::DistributedFusedLambInitOpKernel,
-                          float) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void DistributedFusedLambInitOpKernel(
const Context& dev_ctx,
const std::vector<const DenseTensor*>& param,
const std::vector<const DenseTensor*>& grad,
float beta1,
float beta2,
const std::vector<int>& apply_weight_decay,
int alignment,
int rank,
int nranks,
DenseTensor* fp32_fused_param,
DenseTensor* fp32_fused_grad,
DenseTensor* fp16_fused_param,
DenseTensor* fp16_fused_grad,
DenseTensor* moment1,
DenseTensor* moment2,
DenseTensor* beta1_pow,
DenseTensor* beta2_pow,
DenseTensor* fused_param_offsets,
DenseTensor* fp32_shard_fused_param_offsets,
DenseTensor* fp16_shard_fused_param_offsets,
DenseTensor* param_info,
DenseTensor* param_order,
std::vector<DenseTensor*> param_out,
std::vector<DenseTensor*> master_param_out,
std::vector<DenseTensor*> grad_out,
DenseTensor* global_scale,
DenseTensor* step);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/distributed_fused_lamb_init_kernel.h"
#include "paddle/phi/core/errors.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
namespace fusion {
template <typename T, typename Context>
void DistributedFusedLambInitOpKernel(
const Context& dev_ctx,
const std::vector<const DenseTensor*>& param,
const std::vector<const DenseTensor*>& grad,
float beta1,
float beta2,
const std::vector<int>& apply_weight_decay,
int alignment,
int rank,
int nranks,
DenseTensor* fp32_fused_param,
DenseTensor* fp32_fused_grad,
DenseTensor* fp16_fused_param,
DenseTensor* fp16_fused_grad,
DenseTensor* moment1,
DenseTensor* moment2,
DenseTensor* beta1_pow,
DenseTensor* beta2_pow,
DenseTensor* fused_param_offsets,
DenseTensor* fp32_shard_fused_param_offsets,
DenseTensor* fp16_shard_fused_param_offsets,
DenseTensor* param_info,
DenseTensor* param_order,
std::vector<DenseTensor*> param_out,
std::vector<DenseTensor*> master_param_out,
std::vector<DenseTensor*> grad_out,
DenseTensor* global_scale,
DenseTensor* step) {
  PADDLE_THROW(phi::errors::Unavailable(
      "The distributed_fused_lamb_init op does not support CPU kernel now."));
}
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(distributed_fused_lamb_init,
CPU,
ALL_LAYOUT,
phi::fusion::DistributedFusedLambInitOpKernel,
float) {
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT16);
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT16);
kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(6).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(7).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(8).SetDataType(phi::DataType::INT32);
kernel->OutputAt(9).SetDataType(phi::DataType::INT32);
kernel->OutputAt(10).SetDataType(phi::DataType::INT32);
kernel->OutputAt(11).SetDataType(phi::DataType::INT32);
kernel->OutputAt(12).SetDataType(phi::DataType::INT32);
kernel->OutputAt(13).SetDataType(kernel_key.dtype());
kernel->OutputAt(14).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(15).SetDataType(kernel_key.dtype());
kernel->OutputAt(16).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(17).SetDataType(phi::DataType::INT64);
}
@@ -14,28 +14,24 @@
 #pragma once
-#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
-#include "paddle/fluid/platform/device_context.h"
-#include "paddle/fluid/platform/enforce.h"
 #include "paddle/phi/api/include/tensor.h"
+#include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/core/ddim.h"
+#include "paddle/phi/core/enforce.h"
 #include "paddle/phi/kernels/funcs/elementwise_base.h"
-namespace paddle {
-namespace operators {
-namespace details {
+namespace phi {
 template <typename InT, typename OutT>
 struct CastFunctor {
   HOSTDEVICE OutT operator()(InT x) const { return static_cast<OutT>(x); }
 };
 template <typename InT, typename OutT, int VecSize>
 static void VecCastKernel(const phi::GPUContext &ctx,
                           const InT *x,
                           OutT *y,
                           size_t n) {
-  auto config = platform::GetGpuLaunchConfig1D(ctx, n, VecSize);
+  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, n, VecSize);
   auto block = config.GetGridSize();
   auto thread = config.GetBlockSize();
   auto main_offset = n / (VecSize * thread) * VecSize * thread;
@@ -50,8 +46,6 @@ static void VecCastKernel(const phi::GPUContext &ctx,
       in_arr, out_arr, n, main_offset, VecSize, FunctorT());
 }
-}  // namespace details
 template <typename InT, typename OutT>
 static void LaunchCastKernel(const phi::GPUContext &ctx,
                              const InT *x,
@@ -61,20 +55,19 @@ static void LaunchCastKernel(const phi::GPUContext &ctx,
   PADDLE_ENFORCE_NE(
       static_cast<const void *>(x),
       static_cast<void *>(y),
-      platform::errors::InvalidArgument("Inplace cast is not supported yet."));
+      errors::InvalidArgument("Inplace cast is not supported yet."));
   int vec_size = std::min(phi::GetVectorizedSize(x), phi::GetVectorizedSize(y));
   switch (vec_size) {
     case 4:
-      return details::VecCastKernel<InT, OutT, 4>(ctx, x, y, n);
+      return VecCastKernel<InT, OutT, 4>(ctx, x, y, n);
     case 2:
-      return details::VecCastKernel<InT, OutT, 2>(ctx, x, y, n);
+      return VecCastKernel<InT, OutT, 2>(ctx, x, y, n);
     case 1:
-      return details::VecCastKernel<InT, OutT, 1>(ctx, x, y, n);
+      return VecCastKernel<InT, OutT, 1>(ctx, x, y, n);
    default:
-      PADDLE_THROW(platform::errors::InvalidArgument(
-          "The vectorized size must be 1, 2 or 4."));
+      PADDLE_THROW(
+          errors::InvalidArgument("The vectorized size must be 1, 2 or 4."));
  }
 }
-}  // namespace operators
-}  // namespace paddle
+}  // namespace phi
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -12,24 +12,24 @@ ...@@ -12,24 +12,24 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.h" #include "paddle/phi/kernels/distributed_fused_lamb_init_kernel.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/optimizers/cast_with_ptr.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/algorithm.h" #include "paddle/phi/kernels/funcs/algorithm.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/tensor_to_string.h" #include "paddle/phi/kernels/funcs/tensor_to_string.h"
#include "paddle/phi/kernels/fusion/gpu/cast_with_ptr.h"
namespace paddle { namespace phi {
namespace operators { namespace fusion {
using phi::funcs::FlattenToString; using phi::funcs::FlattenToString;
using phi::funcs::ToVector; using phi::funcs::ToVector;
struct ParamGradInfo { struct ParamGradInfo {
phi::DenseTensor *param_t{nullptr}; DenseTensor *param_t{nullptr};
phi::DenseTensor *grad_t{nullptr}; DenseTensor *grad_t{nullptr};
size_t idx{0}; size_t idx{0};
size_t numel{0}; size_t numel{0};
size_t numel_with_padding{0}; size_t numel_with_padding{0};
...@@ -82,20 +82,17 @@ static void GetParamGradShardInfo(const std::vector<ParamGradInfo> &infos, ...@@ -82,20 +82,17 @@ static void GetParamGradShardInfo(const std::vector<ParamGradInfo> &infos,
size_t *start_numel_offset, size_t *start_numel_offset,
size_t *end_numel_offset) { size_t *end_numel_offset) {
VLOG(10) << "NumelOffset: " VLOG(10) << "NumelOffset: "
<< string::join_strings(infos, ",", [](const ParamGradInfo &info) { << paddle::string::join_strings(
return info.numel_offset; infos, ",", [](const ParamGradInfo &info) {
}); return info.numel_offset;
});
VLOG(10) << "start_size = " << start_size << " , end_size = " << end_size; VLOG(10) << "start_size = " << start_size << " , end_size = " << end_size;
if (infos.empty()) { if (infos.empty()) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
start_size, start_size, 0, errors::InvalidArgument("start_size should be 0."));
0,
platform::errors::InvalidArgument("start_size should be 0."));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
end_size, end_size, 0, errors::InvalidArgument("end_size should be 0."));
0,
platform::errors::InvalidArgument("end_size should be 0."));
*start_idx = 0; *start_idx = 0;
*end_idx = 0; *end_idx = 0;
*start_numel_offset = 0; *start_numel_offset = 0;
...@@ -103,10 +100,10 @@ static void GetParamGradShardInfo(const std::vector<ParamGradInfo> &infos, ...@@ -103,10 +100,10 @@ static void GetParamGradShardInfo(const std::vector<ParamGradInfo> &infos,
return; return;
} }
PADDLE_ENFORCE_LT(start_size, PADDLE_ENFORCE_LT(
end_size, start_size,
platform::errors::InvalidArgument( end_size,
"start_size should be less than end_size.")); errors::InvalidArgument("start_size should be less than end_size."));
size_t n = infos.size(); size_t n = infos.size();
ParamGradInfoNumelOffsetCompFunctor comp; ParamGradInfoNumelOffsetCompFunctor comp;
auto i = static_cast<size_t>( auto i = static_cast<size_t>(
...@@ -116,7 +113,7 @@ static void GetParamGradShardInfo(const std::vector<ParamGradInfo> &infos, ...@@ -116,7 +113,7 @@ static void GetParamGradShardInfo(const std::vector<ParamGradInfo> &infos,
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
i, i,
0, 0,
platform::errors::InvalidArgument( errors::InvalidArgument(
"Cannot find suitable sharding which is between [%d, %d)", "Cannot find suitable sharding which is between [%d, %d)",
start_size, start_size,
end_size)); end_size));
...@@ -125,7 +122,7 @@ static void GetParamGradShardInfo(const std::vector<ParamGradInfo> &infos, ...@@ -125,7 +122,7 @@ static void GetParamGradShardInfo(const std::vector<ParamGradInfo> &infos,
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
i, i,
n, n,
platform::errors::InvalidArgument( errors::InvalidArgument(
"Cannot find suitable sharding which is between [%d, %d)", "Cannot find suitable sharding which is between [%d, %d)",
start_size, start_size,
end_size)); end_size));
...@@ -136,11 +133,11 @@ static void GetParamGradShardInfo(const std::vector<ParamGradInfo> &infos, ...@@ -136,11 +133,11 @@ static void GetParamGradShardInfo(const std::vector<ParamGradInfo> &infos,
infos.begin()); infos.begin());
*end_idx = j - 1; *end_idx = j - 1;
*end_numel_offset = end_size - infos[j - 1].numel_offset; *end_numel_offset = end_size - infos[j - 1].numel_offset;
PADDLE_ENFORCE_GT(*end_numel_offset, PADDLE_ENFORCE_GT(
0, *end_numel_offset,
platform::errors::InvalidArgument( 0,
"Internal error when sharding, this may be a bug " errors::InvalidArgument("Internal error when sharding, this may be a bug "
"caused by empty parameter.")); "caused by empty parameter."));
VLOG(10) << "Sharding [start_size=" << start_size << ", end_size=" << end_size VLOG(10) << "Sharding [start_size=" << start_size << ", end_size=" << end_size
<< "): " << (*start_idx) << ":" << (*start_numel_offset) << " -> " << "): " << (*start_idx) << ":" << (*start_numel_offset) << " -> "
<< (*end_idx) << ":" << (*end_numel_offset); << (*end_idx) << ":" << (*end_numel_offset);
...@@ -154,7 +151,7 @@ static size_t FillAlignmentPaddingInfo(std::vector<ParamGradInfo> *infos, ...@@ -154,7 +151,7 @@ static size_t FillAlignmentPaddingInfo(std::vector<ParamGradInfo> *infos,
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
alignment % sizeof_dtype, alignment % sizeof_dtype,
0, 0,
platform::errors::InvalidArgument( errors::InvalidArgument(
"The attr(alignment) should be exactly divided by sizeof(T) %d.", "The attr(alignment) should be exactly divided by sizeof(T) %d.",
sizeof_dtype)); sizeof_dtype));
alignment /= sizeof_dtype; alignment /= sizeof_dtype;
...@@ -182,41 +179,41 @@ static size_t FillAlignmentPaddingInfo(std::vector<ParamGradInfo> *infos, ...@@ -182,41 +179,41 @@ static size_t FillAlignmentPaddingInfo(std::vector<ParamGradInfo> *infos,
template <typename T> template <typename T>
static T *TensorFillConstant(const phi::GPUContext &dev_ctx, static T *TensorFillConstant(const phi::GPUContext &dev_ctx,
phi::DenseTensor *tensor, DenseTensor *tensor,
const framework::DDim &dims, const DDim &dims,
T value) { T value) {
tensor->Resize(dims); tensor->Resize(dims);
auto *ptr = tensor->mutable_data<T>(dev_ctx.GetPlace()); auto *ptr = dev_ctx.template Alloc<T>(tensor);
phi::funcs::SetConstant<phi::GPUContext, T> set_constant; phi::funcs::SetConstant<phi::GPUContext, T> set_constant;
set_constant(dev_ctx, tensor, value); set_constant(dev_ctx, tensor, value);
return ptr; return ptr;
} }
static phi::DenseTensor CastDataForInitedTensor(const phi::GPUContext &dev_ctx, static DenseTensor CastDataForInitedTensor(const phi::GPUContext &dev_ctx,
phi::DenseTensor *origin, DenseTensor *origin,
phi::DenseTensor *fused_out, DenseTensor *fused_out,
size_t numel_offset) { size_t numel_offset) {
PADDLE_ENFORCE_EQ(origin->IsInitialized(), PADDLE_ENFORCE_EQ(
true, origin->IsInitialized(),
platform::errors::InvalidArgument( true,
"The tensor to be cast should be initialized.")); errors::InvalidArgument("The tensor to be cast should be initialized."));
PADDLE_ENFORCE_EQ(fused_out->dtype(), PADDLE_ENFORCE_EQ(fused_out->dtype(),
phi::DataType::FLOAT32, phi::DataType::FLOAT32,
platform::errors::InvalidArgument( errors::InvalidArgument(
"The dst tensor to be cast should be FP32 tensor.")); "The dst tensor to be cast should be FP32 tensor."));
PADDLE_ENFORCE_EQ(origin->dtype(), PADDLE_ENFORCE_EQ(origin->dtype(),
phi::DataType::FLOAT16, phi::DataType::FLOAT16,
platform::errors::InvalidArgument( errors::InvalidArgument(
"The src tensor to be cast should be FP16 tensor.")); "The src tensor to be cast should be FP16 tensor."));
auto *dst = fused_out->data<float>() + numel_offset; auto *dst = fused_out->data<float>() + numel_offset;
auto *src = origin->data<platform::float16>(); auto *src = origin->data<dtype::float16>();
auto numel = origin->numel(); auto numel = origin->numel();
LaunchCastKernel(dev_ctx, src, dst, numel); LaunchCastKernel(dev_ctx, src, dst, numel);
VLOG(10) << "Cast from FP32 -> FP16, range: [" << numel_offset << ", " VLOG(10) << "Cast from FP32 -> FP16, range: [" << numel_offset << ", "
<< numel_offset + numel << ")" << numel_offset + numel << ")"
<< " , total: [0, " << fused_out->numel() << ")"; << " , total: [0, " << fused_out->numel() << ")";
framework::DDim fused_out_dim = fused_out->dims(); DDim fused_out_dim = fused_out->dims();
auto fused_out_numel = fused_out->numel(); auto fused_out_numel = fused_out->numel();
fused_out->Resize({fused_out_numel}); fused_out->Resize({fused_out_numel});
auto sliced_tensor = fused_out->Slice(numel_offset, numel + numel_offset); auto sliced_tensor = fused_out->Slice(numel_offset, numel + numel_offset);
...@@ -224,45 +221,40 @@ static phi::DenseTensor CastDataForInitedTensor(const phi::GPUContext &dev_ctx, ...@@ -224,45 +221,40 @@ static phi::DenseTensor CastDataForInitedTensor(const phi::GPUContext &dev_ctx,
return sliced_tensor; return sliced_tensor;
} }
static phi::DenseTensor CopyAndShareBufferForInitedTensor( static DenseTensor CopyAndShareBufferForInitedTensor(
phi::DenseTensor *origin, const phi::GPUContext &dev_ctx,
phi::DenseTensor *fused_out, DenseTensor *origin,
size_t numel_offset, DenseTensor *fused_out,
gpuStream_t stream) { size_t numel_offset) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
origin->IsInitialized(), origin->IsInitialized(),
true, true,
platform::errors::InvalidArgument( errors::InvalidArgument(
"The tensor to be copied and shared data should be initialized.")); "The tensor to be copied and shared data should be initialized."));
auto dtype = fused_out->type(); auto dtype = fused_out->type();
PADDLE_ENFORCE_EQ(origin->type(), PADDLE_ENFORCE_EQ(origin->type(),
dtype, dtype,
platform::errors::InvalidArgument( errors::InvalidArgument(
"The tensor to be copied and shared data should be " "The tensor to be copied and shared data should be "
"have the same data type.")); "have the same data type."));
auto place = fused_out->place(); auto place = fused_out->place();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
origin->place(), origin->place(),
place, place,
platform::errors::InvalidArgument("The tensor to be copied and shared " errors::InvalidArgument("The tensor to be copied and shared "
"data should be have the same place.")); "data should be have the same place."));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
platform::is_gpu_place(place), dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU,
true, true,
platform::errors::InvalidArgument( errors::InvalidArgument(
"The tensor to be copied and shared data should be on GPU place.")); "The tensor to be copied and shared data should be on GPU place."));
auto numel = origin->numel(); auto numel = origin->numel();
framework::DDim fused_out_dim = fused_out->dims(); DDim fused_out_dim = fused_out->dims();
auto fused_out_numel = fused_out->numel(); auto fused_out_numel = fused_out->numel();
auto sliced_tensor = fused_out->Resize({fused_out_numel}) auto sliced_tensor = fused_out->Resize({fused_out_numel})
.Slice(numel_offset, numel + numel_offset); .Slice(numel_offset, numel + numel_offset);
memory::Copy(place, phi::Copy(dev_ctx, *origin, dev_ctx.GetPlace(), false, &sliced_tensor);
sliced_tensor.data(),
place,
origin->data(),
numel * phi::SizeOf(dtype),
stream);
origin->ShareBufferWith(sliced_tensor); origin->ShareBufferWith(sliced_tensor);
fused_out->Resize(fused_out_dim); fused_out->Resize(fused_out_dim);
VLOG(10) << "Copy and share buffer, range: [" << numel_offset << ", " VLOG(10) << "Copy and share buffer, range: [" << numel_offset << ", "
...@@ -271,17 +263,17 @@ static phi::DenseTensor CopyAndShareBufferForInitedTensor( ...@@ -271,17 +263,17 @@ static phi::DenseTensor CopyAndShareBufferForInitedTensor(
return sliced_tensor; return sliced_tensor;
} }
static void ShareBufferForNonInitedTensor(phi::DenseTensor *origin, static void ShareBufferForNonInitedTensor(DenseTensor *origin,
phi::DenseTensor *fused_out, DenseTensor *fused_out,
size_t numel_offset, size_t numel_offset,
const framework::DDim &dims) { const DDim &dims) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
origin->IsInitialized(), origin->IsInitialized(),
false, false,
platform::errors::InvalidArgument( errors::InvalidArgument(
"The tensor to be shared data should not be initialized.")); "The tensor to be shared data should not be initialized."));
framework::DDim fused_out_dim = fused_out->dims(); DDim fused_out_dim = fused_out->dims();
auto fused_out_numel = fused_out->numel(); auto fused_out_numel = fused_out->numel();
auto numel = phi::product(dims); auto numel = phi::product(dims);
*origin = fused_out->Resize({fused_out_numel}) *origin = fused_out->Resize({fused_out_numel})
...@@ -294,10 +286,11 @@ static void ShareBufferForNonInitedTensor(phi::DenseTensor *origin, ...@@ -294,10 +286,11 @@ static void ShareBufferForNonInitedTensor(phi::DenseTensor *origin,
} }
template <typename T> template <typename T>
static void CopyVectorToCPUTensor(const std::vector<T> &src, static void CopyVectorToCPUTensor(const phi::GPUContext &dev_ctx,
phi::DenseTensor *dst) { const std::vector<T> &src,
DenseTensor *dst) {
dst->Resize({static_cast<int64_t>(src.size())}); dst->Resize({static_cast<int64_t>(src.size())});
T *dst_ptr = dst->mutable_data<T>(platform::CPUPlace()); T *dst_ptr = dev_ctx.template HostAlloc<T>(dst);
const T *src_ptr = src.data(); const T *src_ptr = src.data();
auto nbytes = src.size() * sizeof(T); auto nbytes = src.size() * sizeof(T);
std::memcpy(dst_ptr, src_ptr, nbytes); std::memcpy(dst_ptr, src_ptr, nbytes);
...@@ -339,459 +332,473 @@ static T ClipByBound(T x, T low_value, T high_value) { ...@@ -339,459 +332,473 @@ static T ClipByBound(T x, T low_value, T high_value) {
return x; return x;
} }
template <typename T> template <typename T, typename Context>
class DistributedFusedLambInitOpKernel<T, phi::GPUContext> void DistributedFusedLambInitOpKernel(
: public framework::OpKernel<T> { const Context &dev_ctx,
public: const std::vector<const DenseTensor *> &param,
void Compute(const framework::ExecutionContext &ctx) const override { const std::vector<const DenseTensor *> &grad,
VLOG(10) << "starts to run DistributedFusedLambInitOp"; float beta1,
auto &dev_ctx = ctx.template device_context<phi::GPUContext>(); float beta2,
auto place = ctx.GetPlace(); const std::vector<int> &apply_weight_decay,
auto stream = dev_ctx.stream(); int alignment,
int rank,
// Step 1: Check Input(Param) and Output(ParamOut), Input(Grad) and int nranks,
// Output(GradOut) DenseTensor *fp32_fused_param,
auto params = ctx.MultiInput<phi::DenseTensor>("Param"); DenseTensor *fp32_fused_grad,
auto grads = ctx.MultiInput<phi::DenseTensor>("Grad"); DenseTensor *fp16_fused_param,
auto master_params = ctx.MultiOutput<phi::DenseTensor>("MasterParamOut"); DenseTensor *fp16_fused_grad,
std::vector<ParamGradInfo> fp32_infos, fp16_infos; DenseTensor *moment1,
{ DenseTensor *moment2,
PADDLE_ENFORCE_EQ(params.size(), DenseTensor *beta1_pow,
grads.size(), DenseTensor *beta2_pow,
platform::errors::InvalidArgument( DenseTensor *fused_param_offsets,
"The parameter number and parameter gradient " DenseTensor *fp32_shard_fused_param_offsets,
"number should be the same.")); DenseTensor *fp16_shard_fused_param_offsets,
DenseTensor *param_info,
auto params_out = ctx.MultiOutput<phi::DenseTensor>("ParamOut"); DenseTensor *param_order,
auto grads_out = ctx.MultiOutput<phi::DenseTensor>("GradOut"); std::vector<DenseTensor *> param_out,
std::vector<DenseTensor *> master_param_out,
std::vector<DenseTensor *> grad_out,
DenseTensor *global_scale,
DenseTensor *step) {
VLOG(10) << "starts to run DistributedFusedLambInitOp";
auto place = dev_ctx.GetPlace();
auto stream = dev_ctx.stream();
// Step 1: Check Input(Param) and Output(ParamOut), Input(Grad) and
// Output(GradOut)
std::vector<ParamGradInfo> fp32_infos, fp16_infos;
{
PADDLE_ENFORCE_EQ(
param.size(),
grad.size(),
errors::InvalidArgument("The parameter number and parameter gradient "
"number should be the same."));
PADDLE_ENFORCE_EQ(
param.size(),
param_out.size(),
errors::InvalidArgument("Input(Param) and Output(ParamOut) "
"should have the same number."));
PADDLE_ENFORCE_EQ(
grad.size(),
grad_out.size(),
errors::InvalidArgument(
"Input(Grad) and Output(GradOut) should have the same number."));
size_t n = param.size();
VLOG(10) << "parameter number: " << n;
for (size_t i = 0; i < n; ++i) {
auto *p = param[i];
auto *g = grad[i];
auto *p_out = param_out[i];
auto *g_out = grad_out[i];
PADDLE_ENFORCE_NOT_NULL(
p,
errors::InvalidArgument("The %d-th parameter should not be nullptr.",
i));
PADDLE_ENFORCE_EQ(p->IsInitialized(),
true,
errors::InvalidArgument(
"The %d-th parameter should be initialized.", i));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
params.size(), p->place(),
params_out.size(), place,
platform::errors::InvalidArgument("Input(Param) and Output(ParamOut) " errors::InvalidArgument(
"should have the same number.")); "The %d-th parameter is not initialized on the right place.", i));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
grads.size(), p,
grads_out.size(), p_out,
platform::errors::InvalidArgument( errors::InvalidArgument("The %d-th Input(Param) and Output(ParamOut) "
"Input(Grad) and Output(GradOut) should have the same number.")); "should be the same tensor.",
size_t n = params.size(); i));
VLOG(10) << "parameter number: " << n;
for (size_t i = 0; i < n; ++i) { auto dtype = p->dtype();
auto *p = params[i]; PADDLE_ENFORCE_NOT_NULL(
auto *g = grads[i]; g,
auto *p_out = params_out[i]; errors::InvalidArgument("The %d-th gradient should not be nullptr.",
auto *g_out = grads_out[i]; i));
PADDLE_ENFORCE_EQ(g,
PADDLE_ENFORCE_NOT_NULL( g_out,
p, errors::InvalidArgument(
platform::errors::InvalidArgument( "The %d-th Input(Grad) and Output(Grad) should "
"The %d-th parameter should not be nullptr.", i)); "be the same tensor."));
PADDLE_ENFORCE_EQ(p->IsInitialized(), auto numel = p->numel();
true, PADDLE_ENFORCE_GT(
platform::errors::InvalidArgument( numel,
"The %d-th parameter should be initialized.", i)); 0,
PADDLE_ENFORCE_EQ( errors::InvalidArgument("The %d-th Input(Param) have no elements."));
p->place(),
place, void *g_data = nullptr;
platform::errors::InvalidArgument( if (g->IsInitialized()) {
"The %d-th parameter is not initialized on the right place.", PADDLE_ENFORCE_EQ(g->dtype(),
i)); dtype,
PADDLE_ENFORCE_EQ(p, errors::InvalidArgument(
p_out, "The %d-th Input(Param) and Input(Grad) should "
platform::errors::InvalidArgument( "have the same data type %s.",
"The %d-th Input(Param) and Output(ParamOut) " i,
"should be the same tensor.", dtype));
PADDLE_ENFORCE_EQ(g->dims(),
p->dims(),
errors::InvalidArgument(
"The %d-th Input(Param) and Input(Grad) should "
"have the same shape.",
i)); i));
g_data = g_out->data();
}
auto dtype = p->dtype(); ParamGradInfo *info;
PADDLE_ENFORCE_NOT_NULL( if (dtype == phi::DataType::FLOAT32) {
g, fp32_infos.emplace_back();
platform::errors::InvalidArgument( info = &fp32_infos.back();
"The %d-th gradient should not be nullptr.", i)); } else if (dtype == phi::DataType::FLOAT16) {
PADDLE_ENFORCE_EQ(g, fp16_infos.emplace_back();
g_out, info = &fp16_infos.back();
platform::errors::InvalidArgument( } else {
"The %d-th Input(Grad) and Output(Grad) should " PADDLE_THROW(
"be the same tensor.")); errors::InvalidArgument("Unsupported data type %s.", dtype));
auto numel = p->numel();
PADDLE_ENFORCE_GT(numel,
0,
platform::errors::InvalidArgument(
"The %d-th Input(Param) have no elements."));
void *g_data = nullptr;
if (g->IsInitialized()) {
PADDLE_ENFORCE_EQ(g->dtype(),
dtype,
platform::errors::InvalidArgument(
"The %d-th Input(Param) and Input(Grad) should "
"have the same data type %s.",
i,
dtype));
PADDLE_ENFORCE_EQ(g->dims(),
p->dims(),
platform::errors::InvalidArgument(
"The %d-th Input(Param) and Input(Grad) should "
"have the same shape.",
i));
g_data = g_out->data();
}
ParamGradInfo *info;
if (dtype == phi::DataType::FLOAT32) {
fp32_infos.emplace_back();
info = &fp32_infos.back();
} else if (dtype == phi::DataType::FLOAT16) {
fp16_infos.emplace_back();
info = &fp16_infos.back();
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported data type %s.", dtype));
}
VLOG(10) << "Found " << dtype << " parameter " << i << " shape=["
<< p_out->dims() << "] numel=" << numel
<< " grad.IsInitialized()="
<< (g_out->IsInitialized() ? "true" : "false");
info->param_t = p_out;
info->grad_t = g_out;
info->idx = i;
info->numel = numel;
info->numel_with_padding = 0; // not determined yet
info->numel_offset = 0; // not determined yet
} }
VLOG(10) << "Found " << dtype << " parameter " << i << " shape=["
<< p_out->dims() << "] numel=" << numel
<< " grad.IsInitialized()="
<< (g_out->IsInitialized() ? "true" : "false");
info->param_t = p_out;
info->grad_t = g_out;
info->idx = i;
info->numel = numel;
info->numel_with_padding = 0; // not determined yet
info->numel_offset = 0; // not determined yet
} }
const auto &apply_weight_decay = }
ctx.Attr<std::vector<int>>("apply_weight_decay");
size_t fp32_wd_end_idx =
ReorderParamGradInfoList(apply_weight_decay, &fp32_infos);
size_t fp16_wd_end_idx =
ReorderParamGradInfoList(apply_weight_decay, &fp16_infos);
auto *param_order_t = ctx.Output<phi::DenseTensor>("ParamOrder");
auto param_num = fp32_infos.size() + fp16_infos.size();
param_order_t->Resize({static_cast<int16_t>(param_num)});
auto *param_order = param_order_t->mutable_data<int>(platform::CPUPlace());
for (size_t i = 0; i < fp32_infos.size(); ++i) {
param_order[i] = static_cast<int>(fp32_infos[i].idx);
}
for (size_t i = 0; i < fp16_infos.size(); ++i) {
param_order[i + fp32_infos.size()] = static_cast<int>(fp16_infos[i].idx);
}
VLOG(10) << "Fill ParamGradInfo ends"; size_t fp32_wd_end_idx =
ReorderParamGradInfoList(apply_weight_decay, &fp32_infos);
size_t fp16_wd_end_idx =
ReorderParamGradInfoList(apply_weight_decay, &fp16_infos);
// Step 2: determine the numel_with_padding and numel_offset auto param_num = fp32_infos.size() + fp16_infos.size();
auto rank = ctx.Attr<int>("rank"); param_order->Resize({static_cast<int16_t>(param_num)});
auto nranks = ctx.Attr<int>("nranks"); auto *param_order_t = dev_ctx.template HostAlloc<int>(param_order);
auto alignment = ctx.Attr<int>("alignment"); for (size_t i = 0; i < fp32_infos.size(); ++i) {
VLOG(10) << "rank = " << rank << ", nranks = " << nranks param_order_t[i] = static_cast<int>(fp32_infos[i].idx);
<< " , alignment = " << alignment; }
if (alignment <= 0) { for (size_t i = 0; i < fp16_infos.size(); ++i) {
alignment = platform::GpuMinChunkSize(); param_order_t[i + fp32_infos.size()] = static_cast<int>(fp16_infos[i].idx);
} }
PADDLE_ENFORCE_GE(alignment,
1,
platform::errors::InvalidArgument(
"The attr(alignment) should be larger than 0."));
PADDLE_ENFORCE_EQ(alignment & (alignment - 1),
0,
platform::errors::InvalidArgument(
"The attr(alignment) should be the power of 2."));
PADDLE_ENFORCE_GE(
rank,
0,
platform::errors::InvalidArgument(
"The attr(rank) should be equal to or larger than 0."));
PADDLE_ENFORCE_LT(
rank,
nranks,
platform::errors::InvalidArgument(
"The attr(rank) should be less than the attr(nranks)."));
// NOTE: We guarantee that both fp32_numel and fp16_numel can be exactly
// divided by alignment and nranks.
auto fp32_numel = FillAlignmentPaddingInfo(
&fp32_infos, alignment, nranks, phi::DataType::FLOAT32);
VLOG(10) << "FP32 ParamGradInfo: " << string::join_strings(fp32_infos, " ");
auto fp16_numel = FillAlignmentPaddingInfo(
&fp16_infos, alignment, nranks, phi::DataType::FLOAT16);
VLOG(10) << "FP16 ParamGradInfo: " << string::join_strings(fp16_infos, " ");
auto total_numel = fp32_numel + fp16_numel;
PADDLE_ENFORCE_LT(
total_numel,
std::numeric_limits<int>::max(),
platform::errors::InvalidArgument("Too many parameter number."));
auto fp32_numel_each_device = fp32_numel / nranks;
auto fp16_numel_each_device = fp16_numel / nranks;
auto numel_each_device = fp32_numel_each_device + fp16_numel_each_device;
VLOG(10) << "Fill padding ends. total_numel = " << total_numel
<< ", fp32_numel = " << fp32_numel
<< ", fp16_numel = " << fp16_numel
<< ", fp32_numel_each_device = " << fp32_numel_each_device
<< ", fp16_numel_each_device = " << fp16_numel_each_device;
// Step 3: allocate output tensor and do initialization
float *fused_fp32_param = nullptr, *fused_fp32_grad = nullptr;
platform::float16 *fused_fp16_param = nullptr, *fused_fp16_grad = nullptr;
phi::DenseTensor *fp32_p_t = nullptr, *fp16_p_t = nullptr,
*fp32_g_t = nullptr, *fp16_g_t = nullptr;
std::vector<phi::DenseTensor *> fp16_master_params;
if (total_numel > 0) {
fp32_p_t = ctx.Output<phi::DenseTensor>("FP32FusedParam");
fused_fp32_param = TensorFillConstant<float>(
dev_ctx, fp32_p_t, {static_cast<int64_t>(total_numel)}, 0.0f);
}
if (fp32_numel > 0) { VLOG(10) << "Fill ParamGradInfo ends";
fp32_g_t = ctx.Output<phi::DenseTensor>("FP32FusedGrad");
fused_fp32_grad = TensorFillConstant<float>(
dev_ctx, fp32_g_t, {static_cast<int64_t>(fp32_numel)}, 0.0f);
}
if (fp16_numel > 0) { // Step 2: determine the numel_with_padding and numel_offset
fp16_p_t = ctx.Output<phi::DenseTensor>("FP16FusedParam"); VLOG(10) << "rank = " << rank << ", nranks = " << nranks
fused_fp16_param = TensorFillConstant<platform::float16>( << " , alignment = " << alignment;
dev_ctx, if (alignment <= 0) {
fp16_p_t, alignment = phi::backends::gpu::GpuMinChunkSize();
{static_cast<int64_t>(fp16_numel)}, }
static_cast<platform::float16>(0)); PADDLE_ENFORCE_GE(
alignment,
fp16_g_t = ctx.Output<phi::DenseTensor>("FP16FusedGrad"); 1,
fused_fp16_grad = TensorFillConstant<platform::float16>( errors::InvalidArgument("The attr(alignment) should be larger than 0."));
dev_ctx, PADDLE_ENFORCE_EQ(
fp16_g_t, alignment & (alignment - 1),
{static_cast<int64_t>(fp16_numel)}, 0,
static_cast<platform::float16>(0)); errors::InvalidArgument("The attr(alignment) should be the power of 2."));
} PADDLE_ENFORCE_GE(rank,
VLOG(10) << "Allocate FP32FusedParam/Grad, FP16FusedParam/Grad ends"; 0,
errors::InvalidArgument(
// (1) For FP32FusedParam, memcpy for fp32 param and then share data, cast "The attr(rank) should be equal to or larger than 0."));
// for fp16 master weight PADDLE_ENFORCE_LT(
// (2) For FP16FusedParam, memcpy and then share data rank,
// (3) For FP32FusedGrad/FP16FusedGrad, memcpy if gradient has been inited nranks,
for (const auto &info : fp32_infos) { errors::InvalidArgument(
auto sliced_tensor = CopyAndShareBufferForInitedTensor( "The attr(rank) should be less than the attr(nranks)."));
info.param_t, fp32_p_t, info.numel_offset, stream); // NOTE: We guarantee that both fp32_numel and fp16_numel can be exactly
master_params[info.idx]->Resize(info.param_t->dims()); // divided by alignment and nranks.
master_params[info.idx]->ShareBufferWith(sliced_tensor); auto fp32_numel = FillAlignmentPaddingInfo(
PADDLE_ENFORCE_EQ(master_params[info.idx]->mutable_data<float>(place), &fp32_infos, alignment, nranks, phi::DataType::FLOAT32);
sliced_tensor.data<float>(), VLOG(10) << "FP32 ParamGradInfo: "
platform::errors::InvalidArgument( << paddle::string::join_strings(fp32_infos, " ");
"Invalid master weight tensor pointer.")); auto fp16_numel = FillAlignmentPaddingInfo(
if (info.grad_t->IsInitialized()) { &fp16_infos, alignment, nranks, phi::DataType::FLOAT16);
CopyAndShareBufferForInitedTensor( VLOG(10) << "FP16 ParamGradInfo: "
info.grad_t, fp32_g_t, info.numel_offset, stream); << paddle::string::join_strings(fp16_infos, " ");
} else { auto total_numel = fp32_numel + fp16_numel;
ShareBufferForNonInitedTensor( PADDLE_ENFORCE_LT(total_numel,
info.grad_t, fp32_g_t, info.numel_offset, info.param_t->dims()); std::numeric_limits<int>::max(),
} errors::InvalidArgument("Too many parameter number."));
}
auto fp32_numel_each_device = fp32_numel / nranks;
auto fp16_numel_each_device = fp16_numel / nranks;
auto numel_each_device = fp32_numel_each_device + fp16_numel_each_device;
VLOG(10) << "Fill padding ends. total_numel = " << total_numel
<< ", fp32_numel = " << fp32_numel << ", fp16_numel = " << fp16_numel
<< ", fp32_numel_each_device = " << fp32_numel_each_device
<< ", fp16_numel_each_device = " << fp16_numel_each_device;
// Step 3: allocate output tensor and do initialization
float *fused_fp32_param = nullptr, *fused_fp32_grad = nullptr;
dtype::float16 *fused_fp16_param = nullptr, *fused_fp16_grad = nullptr;
DenseTensor *fp32_p_t = nullptr, *fp16_p_t = nullptr, *fp32_g_t = nullptr,
*fp16_g_t = nullptr;
std::vector<DenseTensor *> fp16_master_params;
if (total_numel > 0) {
fp32_p_t = fp32_fused_param;
fused_fp32_param = TensorFillConstant<float>(
dev_ctx, fp32_p_t, {static_cast<int64_t>(total_numel)}, 0.0f);
}
if (fp32_numel > 0) {
fp32_g_t = fp32_fused_grad;
fused_fp32_grad = TensorFillConstant<float>(
dev_ctx, fp32_g_t, {static_cast<int64_t>(fp32_numel)}, 0.0f);
}
size_t fp16_numel_offset = 0; if (fp16_numel > 0) {
if (fp32_numel > 0) { fp16_p_t = fp16_fused_param;
auto last_fp32_info = fp32_infos.back(); fused_fp16_param =
fp16_numel_offset = TensorFillConstant<dtype::float16>(dev_ctx,
last_fp32_info.numel_offset + last_fp32_info.numel_with_padding; fp16_p_t,
{static_cast<int64_t>(fp16_numel)},
static_cast<dtype::float16>(0));
fp16_g_t = fp16_fused_grad;
fused_fp16_grad =
TensorFillConstant<dtype::float16>(dev_ctx,
fp16_g_t,
{static_cast<int64_t>(fp16_numel)},
static_cast<dtype::float16>(0));
}
VLOG(10) << "Allocate FP32FusedParam/Grad, FP16FusedParam/Grad ends";
// (1) For FP32FusedParam, memcpy for fp32 param and then share data, cast
// for fp16 master weight
// (2) For FP16FusedParam, memcpy and then share data
// (3) For FP32FusedGrad/FP16FusedGrad, memcpy if gradient has been inited
for (const auto &info : fp32_infos) {
auto sliced_tensor = CopyAndShareBufferForInitedTensor(
dev_ctx, info.param_t, fp32_p_t, info.numel_offset);
master_param_out[info.idx]->Resize(info.param_t->dims());
master_param_out[info.idx]->ShareBufferWith(sliced_tensor);
float *master_param_tmp =
dev_ctx.template Alloc<float>(master_param_out[info.idx]);
float *sliced_tensor_tmp = reinterpret_cast<float *>(sliced_tensor.data());
PADDLE_ENFORCE_EQ(
master_param_tmp,
sliced_tensor_tmp,
errors::InvalidArgument("Invalid master weight tensor pointer."));
if (info.grad_t->IsInitialized()) {
CopyAndShareBufferForInitedTensor(
dev_ctx, info.grad_t, fp32_g_t, info.numel_offset);
} else {
ShareBufferForNonInitedTensor(
info.grad_t, fp32_g_t, info.numel_offset, info.param_t->dims());
} }
}
size_t fp16_numel_offset = 0;
if (fp32_numel > 0) {
auto last_fp32_info = fp32_infos.back();
fp16_numel_offset =
last_fp32_info.numel_offset + last_fp32_info.numel_with_padding;
}
for (const auto &info : fp16_infos) { for (const auto &info : fp16_infos) {
auto master_weight_offset = info.numel_offset + fp16_numel_offset; auto master_weight_offset = info.numel_offset + fp16_numel_offset;
auto sliced_tensor = CastDataForInitedTensor( auto sliced_tensor = CastDataForInitedTensor(
dev_ctx, info.param_t, fp32_p_t, master_weight_offset); dev_ctx, info.param_t, fp32_p_t, master_weight_offset);
master_params[info.idx]->Resize(info.param_t->dims()); master_param_out[info.idx]->Resize(info.param_t->dims());
master_params[info.idx]->ShareBufferWith(sliced_tensor); master_param_out[info.idx]->ShareBufferWith(sliced_tensor);
CopyAndShareBufferForInitedTensor(
dev_ctx, info.param_t, fp16_p_t, info.numel_offset);
float *master_param_tmp =
dev_ctx.template Alloc<float>(master_param_out[info.idx]);
float *sliced_tensor_tmp = reinterpret_cast<float *>(sliced_tensor.data());
PADDLE_ENFORCE_EQ(
master_param_tmp,
sliced_tensor_tmp,
errors::InvalidArgument("Invalid master weight tensor pointer."));
if (info.grad_t->IsInitialized()) {
CopyAndShareBufferForInitedTensor( CopyAndShareBufferForInitedTensor(
info.param_t, fp16_p_t, info.numel_offset, stream); dev_ctx, info.grad_t, fp16_g_t, info.numel_offset);
PADDLE_ENFORCE_EQ(master_params[info.idx]->mutable_data<float>(place), } else {
sliced_tensor.data<float>(), ShareBufferForNonInitedTensor(
platform::errors::InvalidArgument( info.grad_t, fp16_g_t, info.numel_offset, info.param_t->dims());
"Invalid master weight tensor pointer."));
if (info.grad_t->IsInitialized()) {
CopyAndShareBufferForInitedTensor(
info.grad_t, fp16_g_t, info.numel_offset, stream);
} else {
ShareBufferForNonInitedTensor(
info.grad_t, fp16_g_t, info.numel_offset, info.param_t->dims());
}
} }
VLOG(10) << "Copy/share data for Param/Grad ends"; }
VLOG(10) << "Copy/share data for Param/Grad ends";
// Step 4: For Moment1, Moment2, Beta1Pow, Beta2Pow, just fill constant
TensorFillConstant<float>(dev_ctx, // Step 4: For Moment1, Moment2, Beta1Pow, Beta2Pow, just fill constant
ctx.Output<phi::DenseTensor>("Moment1"), TensorFillConstant<float>(
{static_cast<int64_t>(numel_each_device)}, dev_ctx, moment1, {static_cast<int64_t>(numel_each_device)}, 0.0f);
0.0f); TensorFillConstant<float>(
TensorFillConstant<float>(dev_ctx, dev_ctx, moment2, {static_cast<int64_t>(numel_each_device)}, 0.0f);
ctx.Output<phi::DenseTensor>("Moment2"), TensorFillConstant<float>(dev_ctx, beta1_pow, {1}, beta1);
{static_cast<int64_t>(numel_each_device)}, TensorFillConstant<float>(dev_ctx, beta2_pow, {1}, beta2);
0.0f); VLOG(10) << "Init Moment and BetaPow ends";
TensorFillConstant<float>(dev_ctx,
ctx.Output<phi::DenseTensor>("Beta1Pow"), // Step 5: Do sharding
{1}, size_t fp32_start_idx, fp32_end_idx, fp32_start_numel_offset,
ctx.Attr<float>("beta1")); fp32_end_numel_offset;
TensorFillConstant<float>(dev_ctx, GetParamGradShardInfo(fp32_infos,
ctx.Output<phi::DenseTensor>("Beta2Pow"), rank * fp32_numel_each_device,
{1}, (rank + 1) * fp32_numel_each_device,
ctx.Attr<float>("beta2")); &fp32_start_idx,
VLOG(10) << "Init Moment and BetaPow ends"; &fp32_end_idx,
&fp32_start_numel_offset,
// Step 5: Do sharding &fp32_end_numel_offset);
size_t fp32_start_idx, fp32_end_idx, fp32_start_numel_offset, size_t fp16_start_idx, fp16_end_idx, fp16_start_numel_offset,
fp32_end_numel_offset; fp16_end_numel_offset;
GetParamGradShardInfo(fp32_infos, GetParamGradShardInfo(fp16_infos,
rank * fp32_numel_each_device, rank * fp16_numel_each_device,
(rank + 1) * fp32_numel_each_device, (rank + 1) * fp16_numel_each_device,
&fp32_start_idx, &fp16_start_idx,
&fp32_end_idx, &fp16_end_idx,
&fp32_start_numel_offset, &fp16_start_numel_offset,
&fp32_end_numel_offset); &fp16_end_numel_offset);
size_t fp16_start_idx, fp16_end_idx, fp16_start_numel_offset, size_t fp32_local_param_num =
fp16_end_numel_offset; fp32_numel_each_device > 0 ? fp32_end_idx - fp32_start_idx + 1 : 0;
GetParamGradShardInfo(fp16_infos, size_t fp16_local_param_num =
rank * fp16_numel_each_device, fp16_numel_each_device > 0 ? fp16_end_idx - fp16_start_idx + 1 : 0;
(rank + 1) * fp16_numel_each_device, size_t total_local_param_num = fp32_local_param_num + fp16_local_param_num;
&fp16_start_idx, VLOG(10) << "Found the sharding arguments";
&fp16_end_idx,
&fp16_start_numel_offset, param_info->Resize({8});
&fp16_end_numel_offset); auto *param_info_t = dev_ctx.template HostAlloc<int>(param_info);
size_t fp32_local_param_num = param_info_t[0] = static_cast<int>(fp32_start_idx);
fp32_numel_each_device > 0 ? fp32_end_idx - fp32_start_idx + 1 : 0; param_info_t[1] = static_cast<int>(fp32_local_param_num);
size_t fp16_local_param_num = param_info_t[2] = static_cast<int>(fp32_infos.size());
fp16_numel_each_device > 0 ? fp16_end_idx - fp16_start_idx + 1 : 0; param_info_t[3] = ClipByBound<int>(fp32_wd_end_idx,
size_t total_local_param_num = fp32_local_param_num + fp16_local_param_num;
VLOG(10) << "Found the sharding arguments";
auto *param_info_t = ctx.Output<phi::DenseTensor>("ParamInfo");
param_info_t->Resize({8});
auto *param_info = param_info_t->mutable_data<int>(platform::CPUPlace());
param_info[0] = static_cast<int>(fp32_start_idx);
param_info[1] = static_cast<int>(fp32_local_param_num);
param_info[2] = static_cast<int>(fp32_infos.size());
param_info[3] = ClipByBound<int>(fp32_wd_end_idx,
fp32_start_idx, fp32_start_idx,
fp32_start_idx + fp32_local_param_num) - fp32_start_idx + fp32_local_param_num) -
static_cast<int>(fp32_start_idx); static_cast<int>(fp32_start_idx);
param_info[4] = static_cast<int>(fp16_start_idx + fp32_infos.size()); param_info_t[4] = static_cast<int>(fp16_start_idx + fp32_infos.size());
param_info[5] = static_cast<int>(fp16_local_param_num); param_info_t[5] = static_cast<int>(fp16_local_param_num);
param_info[6] = static_cast<int>(fp16_infos.size()); param_info_t[6] = static_cast<int>(fp16_infos.size());
param_info[7] = ClipByBound<int>(fp16_wd_end_idx, param_info_t[7] = ClipByBound<int>(fp16_wd_end_idx,
fp16_start_idx, fp16_start_idx,
fp16_start_idx + fp16_local_param_num) - fp16_start_idx + fp16_local_param_num) -
static_cast<int>(fp16_start_idx); static_cast<int>(fp16_start_idx);
VLOG(10) << "Start FP32 idx: " << param_info[0]; VLOG(10) << "Start FP32 idx: " << param_info_t[0];
VLOG(10) << "Local FP32 param num: " << param_info[1]; VLOG(10) << "Local FP32 param num: " << param_info_t[1];
VLOG(10) << "Global FP32 param num: " << param_info[2]; VLOG(10) << "Global FP32 param num: " << param_info_t[2];
VLOG(10) << "Start FP16 idx: " << param_info[4]; VLOG(10) << "Start FP16 idx: " << param_info_t[4];
VLOG(10) << "Local FP16 param num: " << param_info[5]; VLOG(10) << "Local FP16 param num: " << param_info_t[5];
VLOG(10) << "Global FP16 param num: " << param_info[6]; VLOG(10) << "Global FP16 param num: " << param_info_t[6];
std::vector<int> numel_offsets; std::vector<int> numel_offsets;
numel_offsets.reserve(params.size() + 1); numel_offsets.reserve(param.size() + 1);
for (const auto &info : fp32_infos) { for (const auto &info : fp32_infos) {
numel_offsets.push_back(info.numel_offset); numel_offsets.push_back(info.numel_offset);
} }
for (const auto &info : fp16_infos) { for (const auto &info : fp16_infos) {
numel_offsets.push_back(info.numel_offset + fp16_numel_offset); numel_offsets.push_back(info.numel_offset + fp16_numel_offset);
}
numel_offsets.push_back(fp32_numel + fp16_numel);
PADDLE_ENFORCE_EQ(numel_offsets.size(),
param.size() + 1,
errors::InvalidArgument(
"The numel_offsets number must be one larger than "
"the parameter number."));
VLOG(10) << "Total numel offset: " << FlattenToString(numel_offsets);
std::vector<int> fp32_partial_numel_offsets;
fp32_partial_numel_offsets.reserve(fp32_local_param_num + 1);
fp32_partial_numel_offsets.push_back(0);
// Fill the partial_numel_offsets
for (size_t i = fp32_start_idx; i < fp32_start_idx + fp32_local_param_num;
++i) {
size_t valid_start_n = 0;
if (i == fp32_start_idx) {
valid_start_n = fp32_start_numel_offset;
} }
numel_offsets.push_back(fp32_numel + fp16_numel);
PADDLE_ENFORCE_EQ(numel_offsets.size(),
params.size() + 1,
platform::errors::InvalidArgument(
"The numel_offsets number must be one larger than "
"the parameter number."));
VLOG(10) << "Total numel offset: " << FlattenToString(numel_offsets);
std::vector<int> fp32_partial_numel_offsets;
fp32_partial_numel_offsets.reserve(fp32_local_param_num + 1);
fp32_partial_numel_offsets.push_back(0);
// Fill the partial_numel_offsets
for (size_t i = fp32_start_idx; i < fp32_start_idx + fp32_local_param_num;
++i) {
size_t valid_start_n = 0;
if (i == fp32_start_idx) {
valid_start_n = fp32_start_numel_offset;
}
size_t end_n = fp32_infos[i].numel_with_padding;
if (i + 1 == fp32_start_idx + fp32_local_param_num) {
end_n = std::min(end_n, fp32_end_numel_offset);
}
PADDLE_ENFORCE_NE(valid_start_n, size_t end_n = fp32_infos[i].numel_with_padding;
end_n, if (i + 1 == fp32_start_idx + fp32_local_param_num) {
platform::errors::InvalidArgument( end_n = std::min(end_n, fp32_end_numel_offset);
"Indices sharding error. This may be a bug."));
VLOG(10) << "FP32 Partial numel = ["
<< valid_start_n + fp32_infos[i].numel << ","
<< end_n + fp32_infos[i].numel;
auto len = end_n - valid_start_n;
fp32_partial_numel_offsets.push_back(fp32_partial_numel_offsets.back() +
len);
} }
std::vector<int> fp16_partial_numel_offsets; PADDLE_ENFORCE_NE(
fp16_partial_numel_offsets.reserve(fp16_local_param_num + 1); valid_start_n,
fp16_partial_numel_offsets.push_back(0); end_n,
for (size_t i = fp16_start_idx; i < fp16_start_idx + fp16_local_param_num; errors::InvalidArgument("Indices sharding error. This may be a bug."));
++i) { VLOG(10) << "FP32 Partial numel = [" << valid_start_n + fp32_infos[i].numel
size_t valid_start_n = 0; << "," << end_n + fp32_infos[i].numel;
if (i == fp16_start_idx) { auto len = end_n - valid_start_n;
valid_start_n = fp16_start_numel_offset; fp32_partial_numel_offsets.push_back(fp32_partial_numel_offsets.back() +
} len);
}
size_t end_n = fp16_infos[i].numel_with_padding;
if (i + 1 == fp16_start_idx + fp16_local_param_num) {
end_n = std::min(end_n, fp16_end_numel_offset);
}
PADDLE_ENFORCE_NE(valid_start_n, std::vector<int> fp16_partial_numel_offsets;
end_n, fp16_partial_numel_offsets.reserve(fp16_local_param_num + 1);
platform::errors::InvalidArgument( fp16_partial_numel_offsets.push_back(0);
"Indices sharding error. This may be a bug.")); for (size_t i = fp16_start_idx; i < fp16_start_idx + fp16_local_param_num;
auto len = end_n - valid_start_n; ++i) {
fp16_partial_numel_offsets.push_back(fp16_partial_numel_offsets.back() + size_t valid_start_n = 0;
len); if (i == fp16_start_idx) {
valid_start_n = fp16_start_numel_offset;
} }
CopyVectorToCPUTensor(numel_offsets, size_t end_n = fp16_infos[i].numel_with_padding;
ctx.Output<phi::DenseTensor>("FusedParamOffsets")); if (i + 1 == fp16_start_idx + fp16_local_param_num) {
CopyVectorToCPUTensor( end_n = std::min(end_n, fp16_end_numel_offset);
fp32_partial_numel_offsets,
ctx.Output<phi::DenseTensor>("FP32ShardFusedParamOffsets"));
CopyVectorToCPUTensor(
fp16_partial_numel_offsets,
ctx.Output<phi::DenseTensor>("FP16ShardFusedParamOffsets"));
auto *global_scale = ctx.Output<phi::DenseTensor>("GlobalScale");
if (!global_scale->IsInitialized()) {
TensorFillConstant<float>(dev_ctx, global_scale, {1}, 1.0f);
} }
VLOG(10) << "Init global scale ends";
TensorFillConstant<int64_t>(dev_ctx, PADDLE_ENFORCE_NE(
ctx.Output<phi::DenseTensor>("Step"), valid_start_n,
{1}, end_n,
static_cast<int64_t>(0)); errors::InvalidArgument("Indices sharding error. This may be a bug."));
auto len = end_n - valid_start_n;
fp16_partial_numel_offsets.push_back(fp16_partial_numel_offsets.back() +
len);
}
CopyVectorToCPUTensor(dev_ctx, numel_offsets, fused_param_offsets);
CopyVectorToCPUTensor(
dev_ctx, fp32_partial_numel_offsets, fp32_shard_fused_param_offsets);
CopyVectorToCPUTensor(
dev_ctx, fp16_partial_numel_offsets, fp16_shard_fused_param_offsets);
dev_ctx.Wait(); if (!global_scale->IsInitialized()) {
VLOG(10) << "Wait for H2D copy"; TensorFillConstant<float>(dev_ctx, global_scale, {1}, 1.0f);
} }
}; VLOG(10) << "Init global scale ends";
} // namespace operators TensorFillConstant<int64_t>(dev_ctx, step, {1}, static_cast<int64_t>(0));
} // namespace paddle
namespace ops = paddle::operators; dev_ctx.Wait();
namespace plat = paddle::platform; VLOG(10) << "Wait for H2D copy";
}
PD_REGISTER_STRUCT_KERNEL(distributed_fused_lamb_init, } // namespace fusion
GPU, } // namespace phi
ALL_LAYOUT,
ops::DistributedFusedLambInitOpKernel, PD_REGISTER_KERNEL(distributed_fused_lamb_init,
float) {} GPU,
ALL_LAYOUT,
phi::fusion::DistributedFusedLambInitOpKernel,
float) {
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT16);
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT16);
kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(6).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(7).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(8).SetDataType(phi::DataType::INT32);
kernel->OutputAt(9).SetDataType(phi::DataType::INT32);
kernel->OutputAt(10).SetDataType(phi::DataType::INT32);
kernel->OutputAt(11).SetDataType(phi::DataType::INT32);
kernel->OutputAt(12).SetDataType(phi::DataType::INT32);
kernel->OutputAt(13).SetDataType(kernel_key.dtype());
kernel->OutputAt(14).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(15).SetDataType(kernel_key.dtype());
kernel->OutputAt(16).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(17).SetDataType(phi::DataType::INT64);
}
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -12,22 +12,37 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#pragma once
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/operator.h"
-namespace paddle {
-namespace operators {
-template <typename T, typename DevCtx>
-class DistributedFusedLambInitOpKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    PADDLE_THROW(platform::errors::Unimplemented(
-        "The distributed_fused_lamb_init operator does not support CPU yet."));
-  }
-};
-}  // namespace operators
-}  // namespace paddle
+#include "paddle/phi/core/compat/op_utils.h"
+namespace phi {
+KernelSignature DistributedFusedLambInitOpArgumentMapping(
+    const ArgumentMappingContext& ctx UNUSED) {
+  return KernelSignature(
+      "distributed_fused_lamb_init",
+      {"Param", "Grad"},
+      {"beta1", "beta2", "apply_weight_decay", "alignment", "rank", "nranks"},
+      {"FP32FusedParam",
+       "FP32FusedGrad",
+       "FP16FusedParam",
+       "FP16FusedGrad",
+       "Moment1",
+       "Moment2",
+       "Beta1Pow",
+       "Beta2Pow",
+       "FusedParamOffsets",
+       "FP32ShardFusedParamOffsets",
+       "FP16ShardFusedParamOffsets",
+       "ParamInfo",
+       "ParamOrder",
+       "ParamOut",
+       "MasterParamOut",
+       "GradOut",
+       "GlobalScale",
+       "Step"});
+}
+}  // namespace phi
+PD_REGISTER_ARG_MAPPING_FN(distributed_fused_lamb_init,
+                           phi::DistributedFusedLambInitOpArgumentMapping);