未验证 提交 d32a0102 编写于 作者: S sneaxiy 提交者: GitHub

Add MultiTensorApply to calculate L2-Norm in DistributedFusedLamb optimizer (#39900)

* add multi tensor apply l2 norm

* add multi_tensor_apply code

* make sizeof(TensorMeta) smalller

* move code to distributed_fused_lamb_op.cu

* remove useless FLAGS
上级 639675de
...@@ -284,6 +284,16 @@ static void CopyVectorToTensor(const std::vector<T> &src, ...@@ -284,6 +284,16 @@ static void CopyVectorToTensor(const std::vector<T> &src,
memory::Copy(place, dst_ptr, platform::CPUPlace(), src_ptr, nbytes, stream); memory::Copy(place, dst_ptr, platform::CPUPlace(), src_ptr, nbytes, stream);
} }
template <typename T>
static void CopyVectorToCPUTensor(const std::vector<T> &src,
framework::Tensor *dst) {
dst->Resize({static_cast<int64_t>(src.size())});
T *dst_ptr = dst->mutable_data<T>(platform::CPUPlace());
const T *src_ptr = src.data();
auto nbytes = src.size() * sizeof(T);
std::memcpy(dst_ptr, src_ptr, nbytes);
}
template <typename T> template <typename T>
class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T> class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> { : public framework::OpKernel<T> {
...@@ -677,14 +687,14 @@ class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T> ...@@ -677,14 +687,14 @@ class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T>
lengths.back()); lengths.back());
} }
CopyVectorToTensor( CopyVectorToCPUTensor(numel_offsets,
ctx.Output<framework::Tensor>("FusedParamOffsets"));
CopyVectorToCPUTensor(
fp32_partial_numel_offsets, fp32_partial_numel_offsets,
ctx.Output<framework::Tensor>("FP32ShardFusedParamOffsets"), place, ctx.Output<framework::Tensor>("FP32ShardFusedParamOffsets"));
stream); CopyVectorToCPUTensor(
CopyVectorToTensor(
fp16_partial_numel_offsets, fp16_partial_numel_offsets,
ctx.Output<framework::Tensor>("FP16ShardFusedParamOffsets"), place, ctx.Output<framework::Tensor>("FP16ShardFusedParamOffsets"));
stream);
// Fill the weight decay tensor // Fill the weight decay tensor
PADDLE_ENFORCE_EQ(lengths.size(), shard_weight_decay.size(), PADDLE_ENFORCE_EQ(lengths.size(), shard_weight_decay.size(),
......
...@@ -33,12 +33,7 @@ class DistributedFusedLambOp : public framework::OperatorWithKernel { ...@@ -33,12 +33,7 @@ class DistributedFusedLambOp : public framework::OperatorWithKernel {
framework::OpKernelType GetKernelTypeForVar( framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const framework::Tensor &tensor, const std::string &var_name, const framework::Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override { const framework::OpKernelType &expected_kernel_type) const override {
if (var_name == "ParamInfo") {
return expected_kernel_type; return expected_kernel_type;
} else {
return framework::OperatorWithKernel::GetKernelTypeForVar(
var_name, tensor, expected_kernel_type);
}
} }
}; };
......
...@@ -14,8 +14,10 @@ ...@@ -14,8 +14,10 @@
#include <cmath> #include <cmath>
#include "paddle/fluid/memory/buffer.h" #include "paddle/fluid/memory/buffer.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/optimizers/cast_with_ptr.h" #include "paddle/fluid/operators/optimizers/cast_with_ptr.h"
#include "paddle/fluid/operators/optimizers/distributed_fused_lamb_op.h" #include "paddle/fluid/operators/optimizers/distributed_fused_lamb_op.h"
#include "paddle/fluid/operators/optimizers/multi_tensor_apply.h"
#include "paddle/fluid/operators/tensor_to_string.h" #include "paddle/fluid/operators/tensor_to_string.h"
#include "paddle/fluid/platform/aligned_vector.h" #include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/collective_helper.h"
...@@ -40,6 +42,163 @@ namespace operators { ...@@ -40,6 +42,163 @@ namespace operators {
template <typename T> template <typename T>
using MasterT = typename details::MPTypeTrait<T>::Type; using MasterT = typename details::MPTypeTrait<T>::Type;
template <typename T>
static void FillZeroWithPtr(T *x, size_t n, gpuStream_t stream) {
static_assert(!std::is_same<T, void>::value, "T cannot be void.");
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(hipMemsetAsync(x, 0, n * sizeof(T), stream));
#else
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(x, 0, n * sizeof(T), stream));
#endif
}
template <typename T, int BlockDim, int VecSize>
struct L2NormFunctor {
DEVICE void operator()(int tensor_id, int chunk_id, int offset, int size,
const T *x, MasterT<T> *y, int max_chunk_num) const {
using MT = MasterT<T>;
const T *ptr = x + offset;
using BlockReduce = cub::BlockReduce<MT, BlockDim>;
__shared__ typename BlockReduce::TempStorage storage;
MT square_sum = static_cast<MT>(0);
int i;
for (i = threadIdx.x * VecSize; i + VecSize <= size;
i += (BlockDim * VecSize)) {
platform::AlignedVector<T, VecSize> tmp_vec;
platform::Load(ptr + i, &tmp_vec);
#pragma unroll
for (int j = 0; j < VecSize; ++j) {
auto tmp = static_cast<MT>(tmp_vec[j]);
square_sum += (tmp * tmp);
}
}
for (; i < size; ++i) {
auto tmp = static_cast<MT>(ptr[i]);
square_sum += (tmp * tmp);
}
square_sum = BlockReduce(storage).Reduce(square_sum, cub::Sum());
if (threadIdx.x == 0) {
y[tensor_id * max_chunk_num + chunk_id] = square_sum;
}
}
};
template <typename InT, typename OutT, int BlockDim, bool NeedSqrt>
static __global__ void MultiTensorL2NormReduceAgainCUDAKernel(
const InT *x, OutT *y, int max_chunk_num) {
int tensor_id = blockIdx.x;
x += (tensor_id * max_chunk_num);
using BlockReduce = cub::BlockReduce<InT, BlockDim>;
__shared__ typename BlockReduce::TempStorage storage;
InT sum = static_cast<InT>(0);
for (int i = threadIdx.x; i < max_chunk_num; i += BlockDim) {
sum += x[i];
}
sum = BlockReduce(storage).Reduce(sum, cub::Sum());
if (threadIdx.x == 0) {
if (NeedSqrt) {
y[blockIdx.x] = static_cast<OutT>(sqrtf(sum));
} else {
y[blockIdx.x] = static_cast<OutT>(sum);
}
}
}
template <typename T>
static int GetChunkedVecSize(const T *ptr, int chunk_size) {
static_assert(!std::is_same<T, void>::value, "T cannot be void.");
constexpr int max_load_bits = 128;
int valid_vec_size = max_load_bits / CHAR_BIT / sizeof(T);
auto address = reinterpret_cast<uintptr_t>(ptr);
constexpr int vec8 = alignof(platform::AlignedVector<T, 8>);
constexpr int vec4 = alignof(platform::AlignedVector<T, 4>);
constexpr int vec2 = alignof(platform::AlignedVector<T, 2>);
if (address % vec8 == 0 && chunk_size % vec8 == 0) {
return std::min(8, valid_vec_size);
} else if (address % vec4 == 0 && chunk_size % vec4 == 0) {
return std::min(4, valid_vec_size);
} else if (address % vec2 == 0 && chunk_size % vec2 == 0) {
return std::min(2, valid_vec_size);
} else {
return 1;
}
}
#define PD_VEC_MULTI_TENSOR_APPLY_CASE(__vec_size, ...) \
case __vec_size: { \
constexpr int kVecSize = __vec_size; \
__VA_ARGS__; \
break; \
}
#define PD_VEC_MULTI_TENSOR_APPLY(__vec_size, ...) \
do { \
switch (__vec_size) { \
PD_VEC_MULTI_TENSOR_APPLY_CASE(8, __VA_ARGS__); \
PD_VEC_MULTI_TENSOR_APPLY_CASE(4, __VA_ARGS__); \
PD_VEC_MULTI_TENSOR_APPLY_CASE(2, __VA_ARGS__); \
PD_VEC_MULTI_TENSOR_APPLY_CASE(1, __VA_ARGS__); \
} \
} while (0)
// TODO(zengjinle): which chunk_size is better?
template <typename InT, typename OutT, bool NeedSqrt = false,
int MaxTensorNumPerLaunch = 50, int MaxChunkNumPerLaunch = 680,
int BlockDim = 512>
static void MultiTensorL2Norm(const platform::CUDAPlace &place,
gpuStream_t stream, const InT *x,
const int *offsets, int n, OutT *y,
int chunk_size = 65536) {
if (n <= 0) return;
constexpr int kNumTensor = MaxTensorNumPerLaunch;
constexpr int kNumChunk = MaxChunkNumPerLaunch;
constexpr int kBlockDim = BlockDim;
int max_chunk_num = -1;
int vec_size = 8;
int total_chunk_num = 0;
for (int i = 0; i < n; ++i) {
vec_size = std::min(
vec_size, GetChunkedVecSize(x + offsets[i] - offsets[0], chunk_size));
int length = offsets[i + 1] - offsets[i];
auto tmp_chunk_num = (length + chunk_size - 1) / chunk_size;
max_chunk_num = std::max(max_chunk_num, tmp_chunk_num);
total_chunk_num += tmp_chunk_num;
}
VLOG(1) << "MultiTensorL2Norm max_chunk_num = " << max_chunk_num
<< " , total_chunk_num = " << total_chunk_num
<< " , tensor_num = " << n;
using MT = MasterT<InT>;
memory::Buffer tmp_out(place);
auto *tmp_out_ptr = tmp_out.Alloc<MT>(n * max_chunk_num);
FillZeroWithPtr(tmp_out_ptr, n * max_chunk_num, stream);
#define PD_LAUNCH_MULTI_TENSOR_APPLY_KERNEL \
do { \
using FunctorT = L2NormFunctor<InT, kBlockDim, kVecSize>; \
VLOG(10) << __func__ << " " << typeid(InT).name() \
<< " VecSize = " << kVecSize; \
MultiTensorApply<FunctorT, kBlockDim, kNumTensor, kNumChunk>( \
FunctorT(), stream, offsets, n, chunk_size, x, tmp_out_ptr, \
max_chunk_num); \
} while (0)
PD_VEC_MULTI_TENSOR_APPLY(vec_size, PD_LAUNCH_MULTI_TENSOR_APPLY_KERNEL);
#undef PD_LAUNCH_MULTI_TENSOR_APPLY_KERNEL
MultiTensorL2NormReduceAgainCUDAKernel<MT, OutT, kBlockDim,
NeedSqrt><<<n, kBlockDim, 0, stream>>>(
tmp_out_ptr, y, max_chunk_num);
}
template <int LogLevel> template <int LogLevel>
static void LogParamAndTrustRatioDivSquareNorm( static void LogParamAndTrustRatioDivSquareNorm(
const framework::ExecutionContext &ctx, const float *param_square_norm, const framework::ExecutionContext &ctx, const float *param_square_norm,
...@@ -620,76 +779,6 @@ static void CubDeviceReduce(InputIteratorT d_in, OutputIteratorT d_out, ...@@ -620,76 +779,6 @@ static void CubDeviceReduce(InputIteratorT d_in, OutputIteratorT d_out,
num_items, reduction_op, init, stream)); num_items, reduction_op, init, stream));
} }
template <typename InputIteratorT, typename OutputIteratorT,
typename OffsetIteratorT, typename ReductionOp, typename T>
static void CubDeviceSegmentedReduce(InputIteratorT d_in, OutputIteratorT d_out,
int num_segments,
OffsetIteratorT d_begin_offsets,
OffsetIteratorT d_end_offsets,
ReductionOp reduction_op, T initial_value,
gpuStream_t stream,
memory::Buffer *buffer) {
void *d_temp_storage = nullptr;
size_t temp_storage_bytes = 0;
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceSegmentedReduce::Reduce(
d_temp_storage, temp_storage_bytes, d_in, d_out, num_segments,
d_begin_offsets, d_end_offsets, reduction_op, initial_value, stream));
d_temp_storage = buffer->Alloc<void>(temp_storage_bytes);
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceSegmentedReduce::Reduce(
d_temp_storage, temp_storage_bytes, d_in, d_out, num_segments,
d_begin_offsets, d_end_offsets, reduction_op, initial_value, stream));
}
template <typename T>
struct AddConstantFunctor {
explicit AddConstantFunctor(T bias) : bias_(bias) {}
T operator()(T x) const { return x + bias_; }
private:
T bias_;
};
template <typename T>
struct OffsetWithBiasFunctor {
OffsetWithBiasFunctor(const T *offset, T bias)
: offset_(offset), bias_(bias) {}
HOSTDEVICE T operator()(T idx) const { return offset_[idx] - bias_; }
HOSTDEVICE constexpr bool operator==(const OffsetWithBiasFunctor<T> &) const {
return true;
}
private:
const T *offset_;
const T bias_;
};
template <typename T, typename OffsetT>
static void CubDeviceSegmentedSquareNorm(const T *x, MasterT<T> *y, int n,
const OffsetT *offset,
OffsetT init_offset,
gpuStream_t stream,
memory::Buffer *buffer) {
if (n <= 0) return;
cub::TransformInputIterator<MasterT<T>, SquareFunctor<T>, const T *> iter(
x, SquareFunctor<T>());
if (init_offset == static_cast<OffsetT>(0)) {
CubDeviceSegmentedReduce(iter, y, n, offset, offset + 1, cub::Sum(),
static_cast<MasterT<T>>(0), stream, buffer);
} else {
cub::CountingInputIterator<OffsetT> cnt_iter(0);
OffsetWithBiasFunctor<OffsetT> functor(offset, init_offset);
cub::TransformInputIterator<OffsetT, OffsetWithBiasFunctor<OffsetT>,
cub::CountingInputIterator<OffsetT>>
offset_iter(cnt_iter, functor);
CubDeviceSegmentedReduce(iter, y, n, offset_iter, offset_iter + 1,
cub::Sum(), static_cast<MasterT<T>>(0), stream,
buffer);
}
}
template <typename T> template <typename T>
static void GetSquareGradNormImpl(const T *grad, int n, float *square_norm, static void GetSquareGradNormImpl(const T *grad, int n, float *square_norm,
gpuStream_t stream, gpuStream_t stream,
...@@ -862,16 +951,6 @@ static void CheckHasNanInfGrad(const float *fp32_grad, int fp32_numel, ...@@ -862,16 +951,6 @@ static void CheckHasNanInfGrad(const float *fp32_grad, int fp32_numel,
} }
} }
template <typename T>
static void FillZeroWithPtr(T *x, size_t n, gpuStream_t stream) {
static_assert(!std::is_same<T, void>::value, "T cannot be void.");
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(hipMemsetAsync(x, 0, n * sizeof(T), stream));
#else
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(x, 0, n * sizeof(T), stream));
#endif
}
template <typename T> template <typename T>
class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T> class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> { : public framework::OpKernel<T> {
...@@ -1191,13 +1270,16 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T> ...@@ -1191,13 +1270,16 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
fp16_partial_fused_offsets_t->data<int>(); fp16_partial_fused_offsets_t->data<int>();
VLOG(1) << "FusedParamOffsets: " VLOG(1) << "FusedParamOffsets: "
<< FlattenToString(fused_offsets, fused_offsets_t->numel(), place); << FlattenToString(fused_offsets, fused_offsets_t->numel(),
fused_offsets_t->place());
VLOG(1) << "FP32ShardFusedParamOffsets: " VLOG(1) << "FP32ShardFusedParamOffsets: "
<< FlattenToString(fp32_partial_fused_offsets, << FlattenToString(fp32_partial_fused_offsets,
fp32_partial_fused_offsets_t->numel(), place); fp32_partial_fused_offsets_t->numel(),
fp32_partial_fused_offsets_t->place());
VLOG(1) << "FP16ShardFusedParamOffsets: " VLOG(1) << "FP16ShardFusedParamOffsets: "
<< FlattenToString(fp16_partial_fused_offsets, << FlattenToString(fp16_partial_fused_offsets,
fp16_partial_fused_offsets_t->numel(), place); fp16_partial_fused_offsets_t->numel(),
fp16_partial_fused_offsets_t->place());
if (num_devices > 1) { if (num_devices > 1) {
if (use_master_param_norm) { if (use_master_param_norm) {
...@@ -1207,32 +1289,26 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T> ...@@ -1207,32 +1289,26 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
FillZeroWithPtr(trust_ratio_div_square_norm, param_num, stream); FillZeroWithPtr(trust_ratio_div_square_norm, param_num, stream);
} }
} }
CubDeviceSegmentedSquareNorm(fp32_param, param_square_norm, MultiTensorL2Norm(place, stream, fp32_param, fused_offsets,
fp32_global_param_num, fused_offsets, 0, fp32_global_param_num, param_square_norm);
stream, &cub_tmp_buffer);
if (use_master_param_norm) { if (use_master_param_norm) {
CubDeviceSegmentedSquareNorm( MultiTensorL2Norm(place, stream, master_param + fp16_offset,
master_param + fp16_offset, param_square_norm + fp16_local_start_idx, fp16_partial_fused_offsets, fp16_local_param_num,
fp16_local_param_num, fp16_partial_fused_offsets, 0, stream, param_square_norm + fp16_local_start_idx);
&cub_tmp_buffer);
} else { } else {
// NOTE: extra computation is performed. We can improve this performance // NOTE: extra computation is performed. We can improve this performance
// if needed in the future. // if needed in the future.
CubDeviceSegmentedSquareNorm( MultiTensorL2Norm(
fp16_param, param_square_norm + fp32_global_param_num, place, stream, fp16_param, fused_offsets + fp32_global_param_num,
fp16_global_param_num, fused_offsets + fp32_global_param_num, fp16_global_param_num, param_square_norm + fp32_global_param_num);
static_cast<int>(fp32_numel), stream, &cub_tmp_buffer);
} }
CubDeviceSegmentedSquareNorm( MultiTensorL2Norm(place, stream, trust_ratio_div,
trust_ratio_div, trust_ratio_div_square_norm + fp32_local_start_idx, fp32_partial_fused_offsets, fp32_local_param_num,
fp32_local_param_num, fp32_partial_fused_offsets, 0, stream, trust_ratio_div_square_norm + fp32_local_start_idx);
&cub_tmp_buffer); MultiTensorL2Norm(place, stream, trust_ratio_div + fp32_numel_each_device,
CubDeviceSegmentedSquareNorm( fp16_partial_fused_offsets, fp16_local_param_num,
trust_ratio_div + fp32_numel_each_device, trust_ratio_div_square_norm + fp16_local_start_idx);
trust_ratio_div_square_norm + fp16_local_start_idx,
fp16_local_param_num, fp16_partial_fused_offsets, 0, stream,
&cub_tmp_buffer);
VLOG(1) << "TrustRatioDiv L2-Norm before allreduce: " VLOG(1) << "TrustRatioDiv L2-Norm before allreduce: "
<< FlattenToString(trust_ratio_div_square_norm, param_num, place); << FlattenToString(trust_ratio_div_square_norm, param_num, place);
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstdint>
#include "math.h" // NOLINT
namespace paddle {
namespace operators {
template <int MaxTensorNumPerLaunch, int MaxChunkNumPerLaunch>
struct TensorMetaList {
static constexpr int kTensorNum = MaxTensorNumPerLaunch;
static constexpr int kChunkNum = MaxChunkNumPerLaunch;
static_assert(kTensorNum > 0 && kTensorNum < 256,
"kTensorNum must be inside (0, 256).");
static_assert(kChunkNum > 0 && kChunkNum < 65536,
"kChunkNum must be inside (0, 65536).");
/**
* The tensor numel offset of each tensor.
* The offsets[0] would be always 0 in the first launch,
* and then offsets[0] >= 0 in the following other launches.
* The numel of the i-th tensor would be offsets[i + 1] - offsets[i].
*/
int offsets[kTensorNum + 1];
/**
* The tensor id of each chunk. The tensor_ids[0] is always 0.
* Note that tensor_ids would be always in the ascending order.
* The actual tensor id is start_tensor_id + tensor_ids[i].
*
* The reason why we assume that the actual tensor id is
* start_tensor_id + tensor_ids[i] is to make tensor_ids to be
* a uint8_t array instead of an int array, making sizeof(TensorMetaList)
* smaller, so that kChunkNum can be larger.
*/
uint8_t tensor_ids[kChunkNum];
/**
* The chunk id of the chunk inside each tensor. It would be
* something like chunk_ids = [0, 1, 2, 0, 0, 1, 2, 3], meaning
* that there are 3 tensors and each tensor contains 3, 1 and 4
* chunks. Note that chunk_ids[0] is always 0 and the actual
* chunk id of the first tensor is always start_chunk_id + chunk_ids[i].
*
* The reason why we assume that the actual chunk id of the first
* tensor is always start_chunk_id + chunk_ids[i] is to make
* chunk_ids to be a uint16_t array instead of an int array, making
* sizeof(TensorMetaList) smaller, so that kChunkNum can be larger.
*/
uint16_t chunk_ids[kChunkNum];
/**
* The tensor_ids offset.
*/
int start_tensor_id;
/**
* The chunk_ids offset.
*/
int start_chunk_id;
};
template <typename Functor, int MaxTensorNumPerLaunch, int MaxChunkNumPerLaunch,
typename... Args>
static __global__ void MultiTensorApplyCUDAKernel(
Functor functor,
TensorMetaList<MaxTensorNumPerLaunch, MaxChunkNumPerLaunch> meta,
int chunk_size, Args... args) {
const int block_id = blockIdx.x;
const int tensor_id = meta.tensor_ids[block_id];
const int chunk_id = static_cast<int>(meta.chunk_ids[block_id]) +
(tensor_id == 0) * meta.start_chunk_id;
const int prev_offset = meta.offsets[tensor_id];
const int next_offset = meta.offsets[tensor_id + 1];
const int ptr_offset = prev_offset + chunk_id * chunk_size;
const int size = min(next_offset - ptr_offset, chunk_size);
functor(tensor_id + meta.start_tensor_id, chunk_id, ptr_offset, size,
args...);
}
template <typename Functor, int BlockDim, int MaxTensorNumPerLaunch,
int MaxChunkNumPerLaunch, typename... Args>
static void MultiTensorApply(Functor functor, gpuStream_t stream,
const int *offsets, int n, int chunk_size,
Args... args) {
if (n == 0) return;
constexpr auto NumTensor = MaxTensorNumPerLaunch;
constexpr auto NumChunk = MaxChunkNumPerLaunch;
TensorMetaList<NumTensor, NumChunk> metas;
int tensor_id = 0;
int chunk_id = 0;
int numel_offset = 0;
metas.start_tensor_id = 0;
metas.start_chunk_id = 0;
for (int i = 0; i < n; ++i) {
auto length = offsets[i + 1] - offsets[i];
if (tensor_id == 0) {
metas.start_tensor_id = i;
metas.offsets[0] = numel_offset;
}
metas.offsets[tensor_id + 1] = metas.offsets[tensor_id] + length;
++tensor_id;
numel_offset += length;
auto chunk_num = (length + chunk_size - 1) / chunk_size;
int last_launch_chunk_id = 0;
for (int j = 0; j < chunk_num; ++j) {
metas.chunk_ids[chunk_id] = j - last_launch_chunk_id;
metas.tensor_ids[chunk_id] = tensor_id - 1;
++chunk_id;
bool tensor_full = (tensor_id == NumTensor && j + 1 == chunk_num);
bool block_full = (chunk_id == NumChunk);
bool last_chunk = (i + 1 == n && j + 1 == chunk_num);
if (tensor_full || block_full || last_chunk) {
MultiTensorApplyCUDAKernel<Functor, NumTensor,
NumChunk><<<chunk_id, BlockDim, 0, stream>>>(
functor, metas, chunk_size, args...);
chunk_id = 0;
if (j + 1 == chunk_num) { // chunk for the current tensor is full
metas.start_chunk_id = 0;
tensor_id = 0;
} else {
metas.offsets[0] = metas.offsets[tensor_id - 1];
metas.offsets[1] = metas.offsets[tensor_id];
metas.start_tensor_id = i;
metas.start_chunk_id = j + 1;
last_launch_chunk_id = j + 1;
tensor_id = 1;
}
}
}
}
}
} // namespace operators
} // namespace paddle
...@@ -178,11 +178,13 @@ class DistributedFusedLamb(Optimizer): ...@@ -178,11 +178,13 @@ class DistributedFusedLamb(Optimizer):
param_info = self._create_persistable_var('param_info', dtype='int32') param_info = self._create_persistable_var('param_info', dtype='int32')
param_info.is_distributed = True param_info.is_distributed = True
fused_offsets = self._create_persistable_var('fused_offsets') fused_offsets = self._create_persistable_var(
'fused_offsets', dtype='int32')
fp32_partial_fused_offsets = self._create_persistable_var( fp32_partial_fused_offsets = self._create_persistable_var(
'fp32_partial_fused_offsets', dtype='int32') 'fp32_partial_fused_offsets', dtype='int32')
fp32_partial_fused_offsets.is_distributed = True fp32_partial_fused_offsets.is_distributed = True
fp16_partial_fused_offsets = self._create_persistable_var( fp16_partial_fused_offsets = self._create_persistable_var(
'fp16_partial_fused_offsets', dtype='int32') 'fp16_partial_fused_offsets', dtype='int32')
fp16_partial_fused_offsets.is_distributed = True fp16_partial_fused_offsets.is_distributed = True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册