Unverified commit 7e7b0a8d authored by Samyam Rajbhandari, committed by GitHub

Add files via upload

Lamb CUDA Kernels
Parent c04ae78a
/* Copyright 2019 The Microsoft DeepSpeed Team */
#include <torch/extension.h>
// CUDA forward declaration
void fused_lamb_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g,
float lr, float beta1, float beta2, float max_coeff, float min_coeff, float eps, float grad_scale, int step, int mode, int bias_correction, float decay,
at::Tensor & w_l2_i, at::Tensor & u_l2_i, at::Tensor & lamb_coeff_val );
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
// C++ interface
at::Tensor lamb(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g,
float lr, float beta1, float beta2, float max_coeff, float min_coeff, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
CHECK_INPUT(p);
if (p_copy.numel() > 0) CHECK_INPUT(p_copy);
CHECK_INPUT(m);
CHECK_INPUT(v);
CHECK_INPUT(g);
int64_t num_elem = p.numel();
AT_ASSERTM(m.numel() == num_elem, "number of elements in m and p tensors should be equal");
AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal");
AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
AT_ASSERTM(p_copy.numel() == num_elem || p_copy.numel() == 0, "number of elements in p_copy and p tensors should be equal, or p_copy should be empty");
//intermediate for weight L2 reduction
//make sure that the number of threads per block is at least 512 during the kernel launch, otherwise the behaviour is undefined
at::Tensor w_l2_i = at::empty({512}, p.options().dtype(p.type().scalarType()==at::ScalarType::Half ? at::ScalarType::Float : p.type().scalarType()));
//intermediate for update L2 reduction
//make sure that the number of threads per block is at least 512 during the kernel launch, otherwise the behaviour is undefined
at::Tensor u_l2_i = at::empty({512}, p.options().dtype(p.type().scalarType()==at::ScalarType::Half ? at::ScalarType::Float : p.type().scalarType()));
at::Tensor lamb_coeff_val = at::empty({1}, p.options().dtype(p.type().scalarType()==at::ScalarType::Half ? at::ScalarType::Float : p.type().scalarType()));
fused_lamb_cuda(p, p_copy, m, v, g, lr, beta1, beta2, max_coeff, min_coeff, eps, grad_scale, step, mode, bias_correction, decay, w_l2_i, u_l2_i, lamb_coeff_val);
return lamb_coeff_val;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("lamb", &lamb, "Fused CUDA implementation of the LAMB optimizer (Adam-style update with layer-wise trust ratio).");
}
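/*
 * Call sketch (illustrative only; the argument values and tensor names are
 * assumptions, the signature is the one bound above). The same entry point is
 * exposed to Python as `lamb` through the pybind11 module.
 *
 *   // p, m, v, g: contiguous CUDA tensors with the same numel();
 *   // p_copy: either empty or the same numel() as p (half-precision mirror
 *   //         of the updated parameters for mixed-precision training)
 *   at::Tensor lamb_coeff = lamb(p, p_copy, m, v, g,
 *                                lr, beta1, beta2, max_coeff, min_coeff, eps,
 *                                grad_scale, step, mode, bias_correction, decay);
 *   // lamb_coeff is a one-element tensor holding the trust ratio that was applied.
 */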
/* Copyright 2019 The Microsoft DeepSpeed Team */
#include "ATen/ATen.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/detail/IndexUtils.cuh"
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>
#include <cmath>
#include "ATen/TensorUtils.h"
//#include "ATen/Type.h"
#include "ATen/AccumulateType.h"
#include <THC/THCGeneral.h>
#include <iostream>
//#include <helper_functions.h>
#include <cuda_runtime_api.h>
#include <cooperative_groups.h>
#include <stdio.h>
namespace cg = cooperative_groups;
// Utility class used to avoid linker errors with extern
// unsized shared memory arrays with templated type
namespace {
// This is the un-specialized struct. Note that we prevent instantiation of this
// struct by putting an undefined symbol in the function body so it won't compile.
template <typename T>
struct SharedMemory
{
// Ensure that we won't compile any un-specialized types
__device__ inline operator T *()
{
extern __device__ void error(void);
error();
return NULL;
}
};
template <>
struct SharedMemory <float>
{
__device__ inline operator float *()
{
extern __shared__ float s_float[];
return s_float;
}
};
template <>
struct SharedMemory <double>
{
__device__ inline operator double *()
{
extern __shared__ double s_double[];
return s_double;
}
};
}
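// Usage note (descriptive of the kernels below): SharedMemory<T>() returns a
// pointer to the dynamically sized shared-memory buffer requested at launch.
// The launches below pass 2 * threadsPerBlock * sizeof(T) bytes, so the buffer
// is split into two halves, one per reduction operand:
//   T* s_a = SharedMemory<T>();
//   T* s_b = SharedMemory<T>() + cg::this_thread_block().size();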
#include "type_shim.h"
typedef enum{
ADAM_MODE_0 =0, // eps under square root
ADAM_MODE_1 =1 // eps outside square root
} adamMode_t;
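// Overview of the fused LAMB step implemented by the three kernels below:
//   part1: per element, update the Adam moments m and v, form
//          update = m/denom + decay*p, and accumulate per-block partial sums of
//          ||p||^2 and ||update||^2 into w_l2_i and u_l2_i (one slot per block).
//   part2: a single block reduces those per-block partials down to scalars in
//          w_l2_i[0] and u_l2_i[0].
//   part3: compute the trust ratio lamb_coeff = ||p|| / ||update|| (clipped to
//          [min_coeff, max_coeff]) and apply p -= step_size * lamb_coeff * update.
// The placement of eps in denom is selected by adamMode_t above.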
//s_a and s_b are in shared memory
//g_a and g_b are in global memory
template <typename T, int blockSize>
__device__ void
reduce_block_in_shared_memory(T *s_a, T *s_b, T* g_a, T* g_b)
{
// Handle to thread block group
cg::thread_block cta = cg::this_thread_block();
// perform block reduction in shared memory,
unsigned int tid = cta.thread_rank();
T a_sum = s_a[tid];
T b_sum = s_b[tid];
cg::sync(cta);
// do reduction in shared mem
if ((blockSize >= 512) && (tid < 256))
{
s_a[tid] = a_sum = a_sum + s_a[tid + 256];
s_b[tid] = b_sum = b_sum + s_b[tid + 256];
}
cg::sync(cta);
if ((blockSize >= 256) && (tid < 128))
{
s_a[tid] = a_sum = a_sum + s_a[tid + 128];
s_b[tid] = b_sum = b_sum + s_b[tid + 128];
}
cg::sync(cta);
if ((blockSize >= 128) && (tid < 64))
{
s_a[tid] = a_sum = a_sum + s_a[tid + 64];
s_b[tid] = b_sum = b_sum + s_b[tid + 64];
}
cg::sync(cta);
#if (__CUDA_ARCH__ >= 300 )
if ( tid < 32 )
{
cg::coalesced_group active = cg::coalesced_threads();
// Fetch final intermediate sum from 2nd warp
if (blockSize >= 64)
{
a_sum = a_sum + s_a[tid + 32];
b_sum = b_sum + s_b[tid + 32];
}
// Reduce final warp using shuffle
for (int offset = warpSize/2; offset > 0; offset /= 2)
{
a_sum += active.shfl_down(a_sum, offset);
b_sum += active.shfl_down(b_sum, offset);
}
}
#else
if ((blockSize >= 64) && (tid < 32))
{
s_a[tid] = a_sum = a_sum + s_a[tid + 32];
s_b[tid] = b_sum = b_sum + s_b[tid + 32];
}
cg::sync(cta);
if ((blockSize >= 32) && (tid < 16))
{
s_a[tid] = a_sum = a_sum + s_a[tid + 16];
s_b[tid] = b_sum = b_sum + s_b[tid + 16];
}
cg::sync(cta);
if ((blockSize >= 16) && (tid < 8))
{
s_a[tid] = a_sum = a_sum + s_a[tid + 8];
s_b[tid] = b_sum = b_sum + s_b[tid + 8];
}
cg::sync(cta);
if ((blockSize >= 8) && (tid < 4))
{
s_a[tid] = a_sum = a_sum + s_a[tid + 4];
s_b[tid] = b_sum = b_sum + s_b[tid + 4];
}
cg::sync(cta);
if ((blockSize >= 4) && (tid < 2))
{
s_a[tid] = a_sum = a_sum + s_a[tid + 2];
s_b[tid] = b_sum = b_sum + s_b[tid + 2];
}
cg::sync(cta);
if ((blockSize >= 2) && (tid < 1))
{
s_a[tid] = a_sum = a_sum + s_a[tid + 1];
s_b[tid] = b_sum = b_sum + s_b[tid + 1];
}
cg::sync(cta);
#endif
// write result for this block to global mem
if (tid == 0){
g_a[blockIdx.x] = (T)a_sum;
g_b[blockIdx.x] = (T)b_sum;
}
}
template <typename T, int blockSize>
__device__ void reduce_two_vectors_in_register(T a, T b, T* g_a, T* g_b){
const int threadIdInBlock = cg::this_thread_block().thread_rank();
T *s_a = SharedMemory<T>();
T *s_b = SharedMemory<T>() + cg::this_thread_block().size();
s_a[threadIdInBlock] = a;
s_b[threadIdInBlock] = b;
reduce_block_in_shared_memory<T,blockSize>(s_a, s_b ,g_a, g_b);
}
template <typename T, typename GRAD_T, int blockSize>
__global__ void lamb_cuda_kernel_part1(
T* __restrict__ p,
GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
T* __restrict__ m,
T* __restrict__ v,
const GRAD_T * __restrict__ g,
const float b1,
const float b2,
const float eps,
const float grad_scale,
const float step_size,
const size_t tsize,
adamMode_t mode,
const float decay,
T* __restrict__ w_l2_i,
T* __restrict__ u_l2_i)
{
//Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = cg::this_thread_block().thread_rank();
const int i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;
T reg_w = 0;
T reg_u = 0;
for (int j = i; j < tsize; j+=totThreads) {
T scaled_grad = g[j]/grad_scale;
T pj = p[j];
m[j] = b1*m[j] + (1-b1)*scaled_grad;
v[j] = b2*v[j] + (1-b2)*scaled_grad*scaled_grad;
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(v[j] + eps);
else // Mode 1
denom = sqrtf(v[j]) + eps;
T update = (m[j]/denom) + (decay*p[j]);
reg_u += update * update;
reg_w += pj * pj;
}
reduce_two_vectors_in_register<T,blockSize>(reg_w, reg_u, w_l2_i, u_l2_i);
}
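// part1 leaves w_l2_i[blockIdx.x] and u_l2_i[blockIdx.x] holding this block's
// partial sums of p*p and update*update; part2 folds them into element 0.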
template <typename T, typename GRAD_T, int blockSize>
__global__ void lamb_cuda_kernel_part2(
const size_t tsize,
T* __restrict__ g_a,
T* __restrict__ g_b)
{
T *s_a = SharedMemory<T>() ;
T *s_b = SharedMemory<T>() + cg::this_thread_block().size();
const int threadIdInBlock = cg::this_thread_block().thread_rank();
s_a[threadIdInBlock] = g_a[threadIdInBlock];
s_b[threadIdInBlock] = g_b[threadIdInBlock];
if (threadIdInBlock >= tsize){
s_a[threadIdInBlock] = 0.0;
s_b[threadIdInBlock] = 0.0;
}
reduce_block_in_shared_memory<T,blockSize>(s_a, s_b, g_a, g_b);
}
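// part2 is launched as a single block of threadsPerBlock threads (see the host
// code below) with tsize == num_blocks. It relies on num_blocks being capped at
// 512 and on w_l2_i/u_l2_i having at least threadsPerBlock entries, so the
// unconditional loads above stay in bounds before out-of-range slots are zeroed.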
template <typename T, typename GRAD_T>
__global__ void lamb_cuda_kernel_part3(
T* __restrict__ p,
GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
T* __restrict__ m,
T* __restrict__ v,
const GRAD_T * __restrict__ g,
const float b1,
const float b2,
const float max_coeff,
const float min_coeff,
const float eps,
const float grad_scale,
const float step_size,
const size_t tsize,
adamMode_t mode,
const float decay,
T* __restrict__ w_l2_i,
T* __restrict__ u_l2_i,
T* __restrict__ lamb_coeff_val)
{
//Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = cg::this_thread_block().thread_rank();
const int i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;
T reg_w = sqrtf(w_l2_i[0]);
T reg_u = sqrtf(u_l2_i[0]);
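// LAMB trust ratio: reg_w and reg_u are the L2 norms of the parameters and of
// the Adam update (including weight decay) reduced by part1/part2. The
// coefficient is ||p|| / ||update||, clipped to [min_coeff, max_coeff], and
// falls back to 1.0 when either norm is zero.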
float lamb_coeff = 1.0;
if (reg_w != 0 && reg_u != 0){
lamb_coeff = reg_w/reg_u;
if (lamb_coeff > max_coeff){
lamb_coeff = max_coeff;
}
if (lamb_coeff < min_coeff){
lamb_coeff = min_coeff;
}
}
if(blockId == 0 && threadIdInBlock == 0)
{
lamb_coeff_val[0] = lamb_coeff;
//printf("Cuda Lamb Coeff is %.6f \n",lamb_coeff);
}
for (int j = i; j < tsize; j+=totThreads) {
T pj = (float)p[j];
T mj = m[j];
T vj = v[j];
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(vj + eps);
else // Mode 1
denom = sqrtf(vj) + eps;
T update = (mj/denom) + (decay*pj);
pj = pj - (step_size * lamb_coeff * update);
p[j] = pj;
if (p_copy != NULL) p_copy[j] = (GRAD_T) pj;
}
}
void fused_lamb_cuda(
at::Tensor & p,
at::Tensor & p_copy,
at::Tensor & m,
at::Tensor & v,
at::Tensor & g,
float lr,
float beta1,
float beta2,
float max_coeff,
float min_coeff,
float eps,
float grad_scale,
int step,
int mode,
int bias_correction,
float decay,
at::Tensor & w_l2_i,
at::Tensor & u_l2_i,
at::Tensor & lamb_coeff)
{
// using namespace at;
//Get tensor size
int tsize = p.numel();
//Determine #threads and #blocks
const int threadsPerBlock = 512;
int num_blocks = (tsize+threadsPerBlock-1)/threadsPerBlock;
if (num_blocks > 512) num_blocks=512;
int smemsize = 0;
if (p.type().scalarType() == at::ScalarType::Double)
smemsize = 2 * threadsPerBlock * sizeof(double);
else
smemsize = 2 * threadsPerBlock * sizeof(float);
const dim3 blocks(num_blocks);
const dim3 threads(threadsPerBlock);
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
//Constants
float step_size = 0;
if (bias_correction == 1) {
const float bias_correction1 = 1 - std::pow(beta1, step);
const float bias_correction2 = 1 - std::pow(beta2, step);
step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
}
else {
step_size = lr;
}
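// i.e. with bias correction, step_size = lr * sqrt(1 - beta2^step) / (1 - beta1^step)
// (the standard Adam bias correction); otherwise step_size = lr.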
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (g.type().scalarType() == at::ScalarType::Half) {
//all other values should be fp32 for half gradients
AT_ASSERTM(p.type().scalarType() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type
using namespace at; // prevents "toString is undefined" errors
AT_DISPATCH_FLOATING_TYPES_AND_HALF(g.scalar_type(), "lamb_cuda_kernel", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
lamb_cuda_kernel_part1<accscalar_t, scalar_t, threadsPerBlock><<<blocks,threadsPerBlock, smemsize, stream>>>(
p.data<accscalar_t>(),
p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
m.data<accscalar_t>(),
v.data<accscalar_t>(),
g.data<scalar_t>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay,
w_l2_i.data<accscalar_t>(),
u_l2_i.data<accscalar_t>());
lamb_cuda_kernel_part2<accscalar_t, scalar_t, threadsPerBlock><<<1,threadsPerBlock, smemsize, stream>>>(
num_blocks,
w_l2_i.data<accscalar_t>(),
u_l2_i.data<accscalar_t>());
lamb_cuda_kernel_part3<accscalar_t, scalar_t><<<blocks,threadsPerBlock, smemsize, stream>>>(
p.data<accscalar_t>(),
p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
m.data<accscalar_t>(),
v.data<accscalar_t>(),
g.data<scalar_t>(),
beta1,
beta2,
max_coeff,
min_coeff,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay,
w_l2_i.data<accscalar_t>(),
u_l2_i.data<accscalar_t>(),
lamb_coeff.data<accscalar_t>());
}));
} else {
using namespace at;
AT_DISPATCH_FLOATING_TYPES(g.scalar_type(), "lamb_cuda_kernel", ([&] {
lamb_cuda_kernel_part1<scalar_t, scalar_t, threadsPerBlock><<<blocks,threadsPerBlock, smemsize, stream>>>(
p.data<scalar_t>(),
NULL, //don't output p_copy for fp32, it's wasted write
m.data<scalar_t>(),
v.data<scalar_t>(),
g.data<scalar_t>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay,
w_l2_i.data<scalar_t>(),
u_l2_i.data<scalar_t>());
lamb_cuda_kernel_part2<scalar_t, scalar_t, threadsPerBlock><<<1,threadsPerBlock, smemsize, stream>>>(
num_blocks,
w_l2_i.data<scalar_t>(),
u_l2_i.data<scalar_t>());
lamb_cuda_kernel_part3<scalar_t, scalar_t><<<blocks,threadsPerBlock, smemsize, stream>>>(
p.data<scalar_t>(),
NULL, //don't output p_copy for fp32, it's wasted write
m.data<scalar_t>(),
v.data<scalar_t>(),
g.data<scalar_t>(),
beta1,
beta2,
max_coeff,
min_coeff,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay,
w_l2_i.data<scalar_t>(),
u_l2_i.data<scalar_t>(),
lamb_coeff.data<scalar_t>());
}));
}
THCudaCheck(cudaGetLastError());
}
//template __device__ void reduce_two_vectors_in_register<float,512>(float a, float b, float* g_a, float* g_b, cg::grid_group &cgg);
/* Taken from NVIDIA/apex commit 855808f3fc268e9715d613f3c2e56469d8c986d8 */
#include <ATen/ATen.h>
// Forward/backward compatibility hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
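// Usage sketch for the dispatch macros above (illustrative; `my_kernel` and
// `grad` are placeholder names, not symbols from this file). The macro switches
// on the runtime ScalarType and exposes the matching C++ type as
// scalar_t_<LEVEL> inside the body:
//   DISPATCH_DOUBLE_FLOAT_AND_HALF(grad.scalar_type(), 0, "my_kernel",
//       my_kernel<scalar_t_0><<<blocks, threads>>>(grad.data<scalar_t_0>(), n););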
template<typename T>
__device__ __forceinline__ T reduce_block_into_lanes
(T *x,
T val,
int lanes=1,
bool share_result=false) // lanes is intended to be <= 32.
{
int tid = threadIdx.x + threadIdx.y*blockDim.x;
int blockSize = blockDim.x*blockDim.y; // blockSize is intended to be a multiple of 32.
if(blockSize >= 64)
{
x[tid] = val;
__syncthreads();
}
#pragma unroll
for(int i = (blockSize >> 1); i >= 64; i >>= 1)
{
if(tid < i)
x[tid] = x[tid] + x[tid+i];
__syncthreads();
}
T final;
if(tid < 32)
{
if(blockSize >= 64)
final = x[tid] + x[tid+32];
else
final = val;
// __SYNCWARP();
#pragma unroll
for(int i = 16; i >= lanes; i >>= 1)
final = final + __shfl_down_sync(0xffffffff, final, i);
}
if(share_result)
{
if(tid < lanes)
x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
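// reduce_block_into_lanes: block-wide reduction helper; with lanes == 1 the sum
// of `val` across the block ends up in lane 0 of the first warp, and larger
// `lanes` leave that many partial sums (written back to shared memory when
// share_result is true). It does not appear to be referenced by the LAMB
// kernels above and is kept as part of the type_shim header taken from apex.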