From d32a01028b83a91ac0be421b0f0bad06dc993798 Mon Sep 17 00:00:00 2001 From: sneaxiy <32832641+sneaxiy@users.noreply.github.com> Date: Fri, 25 Feb 2022 18:57:54 +0800 Subject: [PATCH] Add MultiTensorApply to calculate L2-Norm in DistributedFusedLamb optimizer (#39900) * add multi tensor apply l2 norm * add multi_tensor_apply code * make sizeof(TensorMeta) smalller * move code to distributed_fused_lamb_op.cu * remove useless FLAGS --- .../distributed_fused_lamb_init_op.cu | 22 +- .../optimizers/distributed_fused_lamb_op.cc | 7 +- .../optimizers/distributed_fused_lamb_op.cu | 282 +++++++++++------- .../operators/optimizers/multi_tensor_apply.h | 156 ++++++++++ .../optimizer/distributed_fused_lamb.py | 4 +- 5 files changed, 355 insertions(+), 116 deletions(-) create mode 100644 paddle/fluid/operators/optimizers/multi_tensor_apply.h diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cu b/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cu index 3bb605d7f5..3445e9b658 100644 --- a/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cu +++ b/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cu @@ -284,6 +284,16 @@ static void CopyVectorToTensor(const std::vector &src, memory::Copy(place, dst_ptr, platform::CPUPlace(), src_ptr, nbytes, stream); } +template +static void CopyVectorToCPUTensor(const std::vector &src, + framework::Tensor *dst) { + dst->Resize({static_cast(src.size())}); + T *dst_ptr = dst->mutable_data(platform::CPUPlace()); + const T *src_ptr = src.data(); + auto nbytes = src.size() * sizeof(T); + std::memcpy(dst_ptr, src_ptr, nbytes); +} + template class DistributedFusedLambInitOpKernel : public framework::OpKernel { @@ -677,14 +687,14 @@ class DistributedFusedLambInitOpKernel lengths.back()); } - CopyVectorToTensor( + CopyVectorToCPUTensor(numel_offsets, + ctx.Output("FusedParamOffsets")); + CopyVectorToCPUTensor( fp32_partial_numel_offsets, - ctx.Output("FP32ShardFusedParamOffsets"), place, - stream); - CopyVectorToTensor( + ctx.Output("FP32ShardFusedParamOffsets")); + CopyVectorToCPUTensor( fp16_partial_numel_offsets, - ctx.Output("FP16ShardFusedParamOffsets"), place, - stream); + ctx.Output("FP16ShardFusedParamOffsets")); // Fill the weight decay tensor PADDLE_ENFORCE_EQ(lengths.size(), shard_weight_decay.size(), diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc index 748f8206ad..e5b27446eb 100644 --- a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc +++ b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc @@ -33,12 +33,7 @@ class DistributedFusedLambOp : public framework::OperatorWithKernel { framework::OpKernelType GetKernelTypeForVar( const std::string &var_name, const framework::Tensor &tensor, const framework::OpKernelType &expected_kernel_type) const override { - if (var_name == "ParamInfo") { - return expected_kernel_type; - } else { - return framework::OperatorWithKernel::GetKernelTypeForVar( - var_name, tensor, expected_kernel_type); - } + return expected_kernel_type; } }; diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu index aeecea8a8e..3f90140f77 100644 --- a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu +++ b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu @@ -14,8 +14,10 @@ #include #include "paddle/fluid/memory/buffer.h" +#include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/operators/optimizers/cast_with_ptr.h" #include "paddle/fluid/operators/optimizers/distributed_fused_lamb_op.h" +#include "paddle/fluid/operators/optimizers/multi_tensor_apply.h" #include "paddle/fluid/operators/tensor_to_string.h" #include "paddle/fluid/platform/aligned_vector.h" #include "paddle/fluid/platform/collective_helper.h" @@ -40,6 +42,163 @@ namespace operators { template using MasterT = typename details::MPTypeTrait::Type; +template +static void FillZeroWithPtr(T *x, size_t n, gpuStream_t stream) { + static_assert(!std::is_same::value, "T cannot be void."); +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_GPU_SUCCESS(hipMemsetAsync(x, 0, n * sizeof(T), stream)); +#else + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(x, 0, n * sizeof(T), stream)); +#endif +} + +template +struct L2NormFunctor { + DEVICE void operator()(int tensor_id, int chunk_id, int offset, int size, + const T *x, MasterT *y, int max_chunk_num) const { + using MT = MasterT; + const T *ptr = x + offset; + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage storage; + + MT square_sum = static_cast(0); + int i; + for (i = threadIdx.x * VecSize; i + VecSize <= size; + i += (BlockDim * VecSize)) { + platform::AlignedVector tmp_vec; + platform::Load(ptr + i, &tmp_vec); +#pragma unroll + for (int j = 0; j < VecSize; ++j) { + auto tmp = static_cast(tmp_vec[j]); + square_sum += (tmp * tmp); + } + } + + for (; i < size; ++i) { + auto tmp = static_cast(ptr[i]); + square_sum += (tmp * tmp); + } + + square_sum = BlockReduce(storage).Reduce(square_sum, cub::Sum()); + if (threadIdx.x == 0) { + y[tensor_id * max_chunk_num + chunk_id] = square_sum; + } + } +}; + +template +static __global__ void MultiTensorL2NormReduceAgainCUDAKernel( + const InT *x, OutT *y, int max_chunk_num) { + int tensor_id = blockIdx.x; + x += (tensor_id * max_chunk_num); + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage storage; + InT sum = static_cast(0); + for (int i = threadIdx.x; i < max_chunk_num; i += BlockDim) { + sum += x[i]; + } + sum = BlockReduce(storage).Reduce(sum, cub::Sum()); + if (threadIdx.x == 0) { + if (NeedSqrt) { + y[blockIdx.x] = static_cast(sqrtf(sum)); + } else { + y[blockIdx.x] = static_cast(sum); + } + } +} + +template +static int GetChunkedVecSize(const T *ptr, int chunk_size) { + static_assert(!std::is_same::value, "T cannot be void."); + + constexpr int max_load_bits = 128; + int valid_vec_size = max_load_bits / CHAR_BIT / sizeof(T); + auto address = reinterpret_cast(ptr); + constexpr int vec8 = alignof(platform::AlignedVector); + constexpr int vec4 = alignof(platform::AlignedVector); + constexpr int vec2 = alignof(platform::AlignedVector); + if (address % vec8 == 0 && chunk_size % vec8 == 0) { + return std::min(8, valid_vec_size); + } else if (address % vec4 == 0 && chunk_size % vec4 == 0) { + return std::min(4, valid_vec_size); + } else if (address % vec2 == 0 && chunk_size % vec2 == 0) { + return std::min(2, valid_vec_size); + } else { + return 1; + } +} + +#define PD_VEC_MULTI_TENSOR_APPLY_CASE(__vec_size, ...) \ + case __vec_size: { \ + constexpr int kVecSize = __vec_size; \ + __VA_ARGS__; \ + break; \ + } + +#define PD_VEC_MULTI_TENSOR_APPLY(__vec_size, ...) \ + do { \ + switch (__vec_size) { \ + PD_VEC_MULTI_TENSOR_APPLY_CASE(8, __VA_ARGS__); \ + PD_VEC_MULTI_TENSOR_APPLY_CASE(4, __VA_ARGS__); \ + PD_VEC_MULTI_TENSOR_APPLY_CASE(2, __VA_ARGS__); \ + PD_VEC_MULTI_TENSOR_APPLY_CASE(1, __VA_ARGS__); \ + } \ + } while (0) + +// TODO(zengjinle): which chunk_size is better? +template +static void MultiTensorL2Norm(const platform::CUDAPlace &place, + gpuStream_t stream, const InT *x, + const int *offsets, int n, OutT *y, + int chunk_size = 65536) { + if (n <= 0) return; + + constexpr int kNumTensor = MaxTensorNumPerLaunch; + constexpr int kNumChunk = MaxChunkNumPerLaunch; + constexpr int kBlockDim = BlockDim; + + int max_chunk_num = -1; + int vec_size = 8; + int total_chunk_num = 0; + for (int i = 0; i < n; ++i) { + vec_size = std::min( + vec_size, GetChunkedVecSize(x + offsets[i] - offsets[0], chunk_size)); + int length = offsets[i + 1] - offsets[i]; + auto tmp_chunk_num = (length + chunk_size - 1) / chunk_size; + max_chunk_num = std::max(max_chunk_num, tmp_chunk_num); + total_chunk_num += tmp_chunk_num; + } + + VLOG(1) << "MultiTensorL2Norm max_chunk_num = " << max_chunk_num + << " , total_chunk_num = " << total_chunk_num + << " , tensor_num = " << n; + + using MT = MasterT; + memory::Buffer tmp_out(place); + auto *tmp_out_ptr = tmp_out.Alloc(n * max_chunk_num); + FillZeroWithPtr(tmp_out_ptr, n * max_chunk_num, stream); + +#define PD_LAUNCH_MULTI_TENSOR_APPLY_KERNEL \ + do { \ + using FunctorT = L2NormFunctor; \ + VLOG(10) << __func__ << " " << typeid(InT).name() \ + << " VecSize = " << kVecSize; \ + MultiTensorApply( \ + FunctorT(), stream, offsets, n, chunk_size, x, tmp_out_ptr, \ + max_chunk_num); \ + } while (0) + + PD_VEC_MULTI_TENSOR_APPLY(vec_size, PD_LAUNCH_MULTI_TENSOR_APPLY_KERNEL); +#undef PD_LAUNCH_MULTI_TENSOR_APPLY_KERNEL + + MultiTensorL2NormReduceAgainCUDAKernel<<>>( + tmp_out_ptr, y, max_chunk_num); +} + template static void LogParamAndTrustRatioDivSquareNorm( const framework::ExecutionContext &ctx, const float *param_square_norm, @@ -620,76 +779,6 @@ static void CubDeviceReduce(InputIteratorT d_in, OutputIteratorT d_out, num_items, reduction_op, init, stream)); } -template -static void CubDeviceSegmentedReduce(InputIteratorT d_in, OutputIteratorT d_out, - int num_segments, - OffsetIteratorT d_begin_offsets, - OffsetIteratorT d_end_offsets, - ReductionOp reduction_op, T initial_value, - gpuStream_t stream, - memory::Buffer *buffer) { - void *d_temp_storage = nullptr; - size_t temp_storage_bytes = 0; - PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceSegmentedReduce::Reduce( - d_temp_storage, temp_storage_bytes, d_in, d_out, num_segments, - d_begin_offsets, d_end_offsets, reduction_op, initial_value, stream)); - d_temp_storage = buffer->Alloc(temp_storage_bytes); - PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceSegmentedReduce::Reduce( - d_temp_storage, temp_storage_bytes, d_in, d_out, num_segments, - d_begin_offsets, d_end_offsets, reduction_op, initial_value, stream)); -} - -template -struct AddConstantFunctor { - explicit AddConstantFunctor(T bias) : bias_(bias) {} - - T operator()(T x) const { return x + bias_; } - - private: - T bias_; -}; - -template -struct OffsetWithBiasFunctor { - OffsetWithBiasFunctor(const T *offset, T bias) - : offset_(offset), bias_(bias) {} - - HOSTDEVICE T operator()(T idx) const { return offset_[idx] - bias_; } - - HOSTDEVICE constexpr bool operator==(const OffsetWithBiasFunctor &) const { - return true; - } - - private: - const T *offset_; - const T bias_; -}; - -template -static void CubDeviceSegmentedSquareNorm(const T *x, MasterT *y, int n, - const OffsetT *offset, - OffsetT init_offset, - gpuStream_t stream, - memory::Buffer *buffer) { - if (n <= 0) return; - cub::TransformInputIterator, SquareFunctor, const T *> iter( - x, SquareFunctor()); - if (init_offset == static_cast(0)) { - CubDeviceSegmentedReduce(iter, y, n, offset, offset + 1, cub::Sum(), - static_cast>(0), stream, buffer); - } else { - cub::CountingInputIterator cnt_iter(0); - OffsetWithBiasFunctor functor(offset, init_offset); - cub::TransformInputIterator, - cub::CountingInputIterator> - offset_iter(cnt_iter, functor); - CubDeviceSegmentedReduce(iter, y, n, offset_iter, offset_iter + 1, - cub::Sum(), static_cast>(0), stream, - buffer); - } -} - template static void GetSquareGradNormImpl(const T *grad, int n, float *square_norm, gpuStream_t stream, @@ -862,16 +951,6 @@ static void CheckHasNanInfGrad(const float *fp32_grad, int fp32_numel, } } -template -static void FillZeroWithPtr(T *x, size_t n, gpuStream_t stream) { - static_assert(!std::is_same::value, "T cannot be void."); -#ifdef PADDLE_WITH_HIP - PADDLE_ENFORCE_GPU_SUCCESS(hipMemsetAsync(x, 0, n * sizeof(T), stream)); -#else - PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(x, 0, n * sizeof(T), stream)); -#endif -} - template class DistributedFusedLambOpKernel : public framework::OpKernel { @@ -1191,13 +1270,16 @@ class DistributedFusedLambOpKernel fp16_partial_fused_offsets_t->data(); VLOG(1) << "FusedParamOffsets: " - << FlattenToString(fused_offsets, fused_offsets_t->numel(), place); + << FlattenToString(fused_offsets, fused_offsets_t->numel(), + fused_offsets_t->place()); VLOG(1) << "FP32ShardFusedParamOffsets: " << FlattenToString(fp32_partial_fused_offsets, - fp32_partial_fused_offsets_t->numel(), place); + fp32_partial_fused_offsets_t->numel(), + fp32_partial_fused_offsets_t->place()); VLOG(1) << "FP16ShardFusedParamOffsets: " << FlattenToString(fp16_partial_fused_offsets, - fp16_partial_fused_offsets_t->numel(), place); + fp16_partial_fused_offsets_t->numel(), + fp16_partial_fused_offsets_t->place()); if (num_devices > 1) { if (use_master_param_norm) { @@ -1207,32 +1289,26 @@ class DistributedFusedLambOpKernel FillZeroWithPtr(trust_ratio_div_square_norm, param_num, stream); } } - CubDeviceSegmentedSquareNorm(fp32_param, param_square_norm, - fp32_global_param_num, fused_offsets, 0, - stream, &cub_tmp_buffer); + MultiTensorL2Norm(place, stream, fp32_param, fused_offsets, + fp32_global_param_num, param_square_norm); if (use_master_param_norm) { - CubDeviceSegmentedSquareNorm( - master_param + fp16_offset, param_square_norm + fp16_local_start_idx, - fp16_local_param_num, fp16_partial_fused_offsets, 0, stream, - &cub_tmp_buffer); + MultiTensorL2Norm(place, stream, master_param + fp16_offset, + fp16_partial_fused_offsets, fp16_local_param_num, + param_square_norm + fp16_local_start_idx); } else { // NOTE: extra computation is performed. We can improve this performance // if needed in the future. - CubDeviceSegmentedSquareNorm( - fp16_param, param_square_norm + fp32_global_param_num, - fp16_global_param_num, fused_offsets + fp32_global_param_num, - static_cast(fp32_numel), stream, &cub_tmp_buffer); + MultiTensorL2Norm( + place, stream, fp16_param, fused_offsets + fp32_global_param_num, + fp16_global_param_num, param_square_norm + fp32_global_param_num); } - CubDeviceSegmentedSquareNorm( - trust_ratio_div, trust_ratio_div_square_norm + fp32_local_start_idx, - fp32_local_param_num, fp32_partial_fused_offsets, 0, stream, - &cub_tmp_buffer); - CubDeviceSegmentedSquareNorm( - trust_ratio_div + fp32_numel_each_device, - trust_ratio_div_square_norm + fp16_local_start_idx, - fp16_local_param_num, fp16_partial_fused_offsets, 0, stream, - &cub_tmp_buffer); + MultiTensorL2Norm(place, stream, trust_ratio_div, + fp32_partial_fused_offsets, fp32_local_param_num, + trust_ratio_div_square_norm + fp32_local_start_idx); + MultiTensorL2Norm(place, stream, trust_ratio_div + fp32_numel_each_device, + fp16_partial_fused_offsets, fp16_local_param_num, + trust_ratio_div_square_norm + fp16_local_start_idx); VLOG(1) << "TrustRatioDiv L2-Norm before allreduce: " << FlattenToString(trust_ratio_div_square_norm, param_num, place); diff --git a/paddle/fluid/operators/optimizers/multi_tensor_apply.h b/paddle/fluid/operators/optimizers/multi_tensor_apply.h new file mode 100644 index 0000000000..5d8d03c733 --- /dev/null +++ b/paddle/fluid/operators/optimizers/multi_tensor_apply.h @@ -0,0 +1,156 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "math.h" // NOLINT + +namespace paddle { +namespace operators { + +template +struct TensorMetaList { + static constexpr int kTensorNum = MaxTensorNumPerLaunch; + static constexpr int kChunkNum = MaxChunkNumPerLaunch; + + static_assert(kTensorNum > 0 && kTensorNum < 256, + "kTensorNum must be inside (0, 256)."); + static_assert(kChunkNum > 0 && kChunkNum < 65536, + "kChunkNum must be inside (0, 65536)."); + + /** + * The tensor numel offset of each tensor. + * The offsets[0] would be always 0 in the first launch, + * and then offsets[0] >= 0 in the following other launches. + * The numel of the i-th tensor would be offsets[i + 1] - offsets[i]. + */ + int offsets[kTensorNum + 1]; + + /** + * The tensor id of each chunk. The tensor_ids[0] is always 0. + * Note that tensor_ids would be always in the ascending order. + * The actual tensor id is start_tensor_id + tensor_ids[i]. + * + * The reason why we assume that the actual tensor id is + * start_tensor_id + tensor_ids[i] is to make tensor_ids to be + * a uint8_t array instead of an int array, making sizeof(TensorMetaList) + * smaller, so that kChunkNum can be larger. + */ + uint8_t tensor_ids[kChunkNum]; + + /** + * The chunk id of the chunk inside each tensor. It would be + * something like chunk_ids = [0, 1, 2, 0, 0, 1, 2, 3], meaning + * that there are 3 tensors and each tensor contains 3, 1 and 4 + * chunks. Note that chunk_ids[0] is always 0 and the actual + * chunk id of the first tensor is always start_chunk_id + chunk_ids[i]. + * + * The reason why we assume that the actual chunk id of the first + * tensor is always start_chunk_id + chunk_ids[i] is to make + * chunk_ids to be a uint16_t array instead of an int array, making + * sizeof(TensorMetaList) smaller, so that kChunkNum can be larger. + */ + uint16_t chunk_ids[kChunkNum]; + + /** + * The tensor_ids offset. + */ + int start_tensor_id; + + /** + * The chunk_ids offset. + */ + int start_chunk_id; +}; + +template +static __global__ void MultiTensorApplyCUDAKernel( + Functor functor, + TensorMetaList meta, + int chunk_size, Args... args) { + const int block_id = blockIdx.x; + const int tensor_id = meta.tensor_ids[block_id]; + const int chunk_id = static_cast(meta.chunk_ids[block_id]) + + (tensor_id == 0) * meta.start_chunk_id; + const int prev_offset = meta.offsets[tensor_id]; + const int next_offset = meta.offsets[tensor_id + 1]; + const int ptr_offset = prev_offset + chunk_id * chunk_size; + const int size = min(next_offset - ptr_offset, chunk_size); + + functor(tensor_id + meta.start_tensor_id, chunk_id, ptr_offset, size, + args...); +} + +template +static void MultiTensorApply(Functor functor, gpuStream_t stream, + const int *offsets, int n, int chunk_size, + Args... args) { + if (n == 0) return; + + constexpr auto NumTensor = MaxTensorNumPerLaunch; + constexpr auto NumChunk = MaxChunkNumPerLaunch; + TensorMetaList metas; + + int tensor_id = 0; + int chunk_id = 0; + int numel_offset = 0; + metas.start_tensor_id = 0; + metas.start_chunk_id = 0; + for (int i = 0; i < n; ++i) { + auto length = offsets[i + 1] - offsets[i]; + if (tensor_id == 0) { + metas.start_tensor_id = i; + metas.offsets[0] = numel_offset; + } + metas.offsets[tensor_id + 1] = metas.offsets[tensor_id] + length; + ++tensor_id; + numel_offset += length; + + auto chunk_num = (length + chunk_size - 1) / chunk_size; + int last_launch_chunk_id = 0; + for (int j = 0; j < chunk_num; ++j) { + metas.chunk_ids[chunk_id] = j - last_launch_chunk_id; + metas.tensor_ids[chunk_id] = tensor_id - 1; + ++chunk_id; + + bool tensor_full = (tensor_id == NumTensor && j + 1 == chunk_num); + bool block_full = (chunk_id == NumChunk); + bool last_chunk = (i + 1 == n && j + 1 == chunk_num); + + if (tensor_full || block_full || last_chunk) { + MultiTensorApplyCUDAKernel<<>>( + functor, metas, chunk_size, args...); + chunk_id = 0; + if (j + 1 == chunk_num) { // chunk for the current tensor is full + metas.start_chunk_id = 0; + tensor_id = 0; + } else { + metas.offsets[0] = metas.offsets[tensor_id - 1]; + metas.offsets[1] = metas.offsets[tensor_id]; + metas.start_tensor_id = i; + metas.start_chunk_id = j + 1; + last_launch_chunk_id = j + 1; + tensor_id = 1; + } + } + } + } +} + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/incubate/optimizer/distributed_fused_lamb.py b/python/paddle/incubate/optimizer/distributed_fused_lamb.py index 74c481fb64..e7c3cfbb7b 100644 --- a/python/paddle/incubate/optimizer/distributed_fused_lamb.py +++ b/python/paddle/incubate/optimizer/distributed_fused_lamb.py @@ -178,11 +178,13 @@ class DistributedFusedLamb(Optimizer): param_info = self._create_persistable_var('param_info', dtype='int32') param_info.is_distributed = True - fused_offsets = self._create_persistable_var('fused_offsets') + fused_offsets = self._create_persistable_var( + 'fused_offsets', dtype='int32') fp32_partial_fused_offsets = self._create_persistable_var( 'fp32_partial_fused_offsets', dtype='int32') fp32_partial_fused_offsets.is_distributed = True + fp16_partial_fused_offsets = self._create_persistable_var( 'fp16_partial_fused_offsets', dtype='int32') fp16_partial_fused_offsets.is_distributed = True -- GitLab