Unverified commit f0f2a702 authored by RezaYazdaniAminabadi, committed by GitHub

support dynamic sequence length in transformer kernels (#424)

Co-authored-by: Conglong Li <conglong.li@gmail.com>
Parent 71f7df39
......@@ -29,7 +29,7 @@
for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); j += blockDim.y * gridDim.y)
#define DS_CUDA_NUM_THREADS 512
#define DS_MAXIMUM_NUM_BLOCKS 4096
#define DS_MAXIMUM_NUM_BLOCKS 262144
inline int DS_GET_BLOCKS(const int N)
{
......
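Note: the DS_GET_BLOCKS body is cut off in this hunk, but the raised DS_MAXIMUM_NUM_BLOCKS cap only matters because the helper clamps the grid size. A minimal host-side sketch, assuming the usual round-up-then-clamp form (the exact body is not shown above):

#include <algorithm>

// Sketch only: mirrors DS_CUDA_NUM_THREADS (512) and the new DS_MAXIMUM_NUM_BLOCKS
// (262144); the real DS_GET_BLOCKS definition is elided in the hunk above.
inline int ds_get_blocks_sketch(const int N)
{
    constexpr int kThreads = 512;
    constexpr int kMaxBlocks = 262144;
    // One thread per element, rounded up, clamped to [1, kMaxBlocks]; larger
    // workloads rely on grid-stride loops over the capped block count.
    return std::max(std::min((N + kThreads - 1) / kThreads, kMaxBlocks), 1);
}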
......@@ -29,7 +29,6 @@ void launch_bias_gelu(const T* input,
T* output,
int intermediate_size,
int batch_size,
int sequence_length,
cudaStream_t stream);
template <typename T>
......@@ -37,7 +36,6 @@ void launch_gelu(const T* input,
T* output,
int intermediate_size,
int batch_size,
int sequence_length,
cudaStream_t stream);
template <typename T>
......@@ -46,7 +44,6 @@ void launch_d_gelu(T* d_output,
const T* bias,
int intermediate_size,
int batch_size,
int sequence_length,
cudaStream_t stream);
// Custom fused bias add with layer normalization
......@@ -57,14 +54,12 @@ void launch_bias_residual_layer_norm(T* vals,
const T* beta,
float epsilon,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream,
bool preLayerNorm,
bool training = false,
T* vars = nullptr,
T* means = nullptr,
T* vals_hat = nullptr);
bool training,
T* vars,
T* means);
template <typename T>
void launch_bias_residual_layer_norm(T* vals,
......@@ -73,14 +68,11 @@ void launch_bias_residual_layer_norm(T* vals,
const T* beta,
float epsilon,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream,
bool preLayerNorm,
bool training = false,
T* vars = nullptr,
T* vals_hat = nullptr,
bool save_vals = false);
bool training,
T* vars);
template <typename T>
void launch_layerNorm_backward_fused_add(const T* out_grad1,
......@@ -93,7 +85,6 @@ void launch_layerNorm_backward_fused_add(const T* out_grad1,
T* betta_grad,
T* inp_grad,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream[2]);
template <typename T>
......@@ -106,7 +97,6 @@ void launch_layerNorm_backward_fused_add(const T* out_grad1,
T* betta_grad,
T* inp_grad,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream[2],
bool invertible = false,
......@@ -122,7 +112,6 @@ void launch_layerNorm_backward(const T* out_grad,
T* betta_grad,
T* inp_grad,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream[2]);
......@@ -135,7 +124,6 @@ void launch_layerNorm_backward(const T* out_grad,
T* betta_grad,
T* inp_grad,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream[2],
bool invertible = false,
......@@ -153,7 +141,6 @@ void launch_layerNorm_backward_nreversible(const T* out_grad,
T* betta_grad,
T* inp_grad,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream[2]);
......
......@@ -9,15 +9,13 @@ class Dropout {
public:
struct Config {
float ratio;
uint32_t batch, dim;
uint32_t dim;
bool training;
Config(float r, uint32_t batch, uint32_t dim)
: ratio(r), batch(batch), dim(dim), training(true)
{
}
Config(float r, uint32_t d) : ratio(r), dim(d), training(true) {}
float RATIO() const { return training ? ratio : 0.0; }
inline void SetDim(uint32_t d) { dim = d; }
};
Dropout(const Config& config) : _config(config), _mask(nullptr) {}
......@@ -70,6 +68,8 @@ public:
Config GetConfig() const { return _config; }
inline void SetDimension(uint32_t dim) { _config.SetDim(dim); }
private:
uint8_t* _mask;
Config _config;
......
......@@ -121,11 +121,17 @@ public:
void SetIntermediateBuffers(uint8_t* attn_prob_dropout_mask_ptr,
uint8_t* attn_output_dropout_mask_ptr,
uint8_t* layer_output_dropout_mask_ptr);
uint8_t* layer_output_dropout_mask_ptr,
T* layer_norm_var,
T* layer_norm_mean,
T* attn_layer_norm_var,
T* attn_layer_norm_mean);
inline int GetBatchSize() const { return _batch_size; }
inline int GetNumHeads() const { return _heads; }
inline int GetSeqLength() const { return _seq_length; }
void SetSeqLength(int seq_len, int bsz);
inline int GetHiddenSize() const { return _hidden_size; }
void SetTrainingMode(bool training);
......@@ -150,8 +156,8 @@ private:
// layers
FeedForward<T> _qkv_linear;
FeedForward<T> _attn_out_linear;
Normalize_Layer<T> _norm_layer2;
Normalize_Layer<T> _norm_layer3;
Normalize_Layer<T> _attn_layer_norm;
Normalize_Layer<T> _layer_norm;
Normalize_Layer<T>* _last_normalize;
FeedForward<T> _ff1, _ff2;
Softmax<T> _softmax;
......
......@@ -9,13 +9,8 @@ template <typename T>
class Gelu {
public:
struct Config {
uint32_t batch_size;
uint32_t seq_length;
uint32_t intermediate_size;
Config(uint32_t batch, uint32_t seq, uint32_t inter_size)
: batch_size(batch), seq_length(seq), intermediate_size(inter_size)
{
}
Config(uint32_t inter_size) : intermediate_size(inter_size) {}
};
Gelu(const Config& config) : _config(config) {}
......@@ -28,14 +23,12 @@ public:
T* output,
cudaStream_t stream)
{
launch_bias_gelu<T>(
input_buf, bias, output, _config.intermediate_size, bsz, _config.seq_length, stream);
launch_bias_gelu<T>(input_buf, bias, output, _config.intermediate_size, bsz, stream);
}
void Backward(int bsz, T* d_output, const T* input_buf, const T* bias, cudaStream_t stream)
{
launch_d_gelu<T>(
d_output, input_buf, bias, _config.intermediate_size, bsz, _config.seq_length, stream);
launch_d_gelu<T>(d_output, input_buf, bias, _config.intermediate_size, bsz, stream);
}
private:
......
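Note: with batch and sequence length removed from Gelu::Config, the row count is supplied per call. A hypothetical call site where tokens = batch * seq is recomputed each step; only Config(inter_size) and the Backward signature come from the hunk above, while the forward method name and buffer names are assumptions:

// Hypothetical usage sketch (assumes the Gelu header shown above is included).
void gelu_usage_sketch(int batch_size,
                       int seq_len,
                       int intermediate_size,
                       float* ff1_out,
                       const float* inter_b,
                       float* gelu_out,
                       float* d_gelu_out,
                       cudaStream_t stream)
{
    Gelu<float> gelu(Gelu<float>::Config(intermediate_size));  // no batch/seq baked in

    int tokens = batch_size * seq_len;  // may change every iteration
    gelu.ForwardWithBiasAdd(tokens, ff1_out, inter_b, gelu_out, stream);  // assumed method name
    gelu.Backward(tokens, d_gelu_out, ff1_out, inter_b, stream);
}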
......@@ -16,57 +16,27 @@ public:
uint32_t seqLength;
uint32_t hiddenDim;
float epsilon;
bool training, save_vals;
bool allocateGrad;
bool training;
bool useMean;
Config(uint32_t batch,
uint32_t seq,
uint32_t h,
bool training,
bool save_vals = true,
bool allocateGrad = true,
bool useMean = true)
Config(uint32_t batch, uint32_t seq, uint32_t h, bool training, bool useMean = true)
: batchSize(batch),
seqLength(seq),
hiddenDim(h),
epsilon(1e-12),
training(training),
save_vals(save_vals),
allocateGrad(allocateGrad),
useMean(useMean)
{
}
};
Normalize_Layer(Config config) : config_(config), vars(nullptr), vals_hat(nullptr)
Normalize_Layer(Config config)
: config_(config), vars(nullptr), means(nullptr), vals_hat(nullptr)
{
if (config_.training) {
cudaMalloc((void**)&vars, config_.batchSize * config_.seqLength * sizeof(T));
if (config_.useMean)
cudaMalloc((void**)&means, config_.batchSize * config_.seqLength * sizeof(T));
if (config_.save_vals)
cudaMalloc((void**)&vals_hat,
config_.batchSize * config_.seqLength * config_.hiddenDim * sizeof(T));
if (config_.allocateGrad)
cudaMalloc((void**)&inp_grad,
config_.batchSize * config_.seqLength * config_.hiddenDim * sizeof(T));
}
}
~Normalize_Layer()
{
if (config_.training) {
cudaFree(vars);
if (config_.useMean) cudaFree(means);
if (config_.save_vals) cudaFree(vals_hat);
if (config_.allocateGrad) cudaFree(inp_grad);
}
}
~Normalize_Layer() {}
void ForwardCheckpoint(int bsz,
void ForwardCheckpoint(int bsz, // batch * seq
T* vals,
const T* residual,
const T* gamma,
......@@ -80,14 +50,12 @@ public:
betta,
config_.epsilon,
bsz,
config_.seqLength,
config_.hiddenDim,
stream,
preLayerNorm,
config_.training,
vars,
means,
vals_hat);
means);
}
void Forward(int bsz,
......@@ -104,14 +72,11 @@ public:
betta,
config_.epsilon,
bsz,
config_.seqLength,
config_.hiddenDim,
stream,
preLayerNorm,
config_.training,
vars,
vals_hat,
config_.save_vals);
vars);
}
void Backward(int bsz,
......@@ -120,7 +85,7 @@ public:
T* gamma_grad,
T* betta_grad,
cudaStream_t stream[2],
T* inp_grad_out = nullptr,
T* inp_grad_out,
const T* norm_in = nullptr)
{
launch_layerNorm_backward(out_grad,
......@@ -130,9 +95,8 @@ public:
gamma,
gamma_grad,
betta_grad,
(config_.allocateGrad ? inp_grad : inp_grad_out),
inp_grad_out,
bsz,
config_.seqLength,
config_.hiddenDim,
stream);
}
......@@ -144,21 +108,20 @@ public:
T* gamma_grad,
T* betta_grad,
cudaStream_t stream[2],
T* inp_grad_out = nullptr,
const T* norm_out = nullptr)
T* inp_grad_out,
const T* norm_out)
{
launch_layerNorm_backward(out_grad,
(config_.save_vals ? vals_hat : norm_out),
norm_out,
vars,
gamma,
gamma_grad,
betta_grad,
(config_.allocateGrad ? inp_grad : inp_grad_out),
inp_grad_out,
bsz,
config_.seqLength,
config_.hiddenDim,
stream,
config_.save_vals,
!config_.useMean,
betta);
}
......@@ -169,7 +132,7 @@ public:
T* gamma_grad,
T* betta_grad,
cudaStream_t stream[2],
T* inp_grad_out = nullptr,
T* inp_grad_out,
const T* norm_in = nullptr)
{
launch_layerNorm_backward_fused_add(out_grad1,
......@@ -180,9 +143,8 @@ public:
gamma,
gamma_grad,
betta_grad,
(config_.allocateGrad ? inp_grad : inp_grad_out),
inp_grad_out,
bsz,
config_.seqLength,
config_.hiddenDim,
stream);
}
......@@ -195,33 +157,41 @@ public:
T* gamma_grad,
T* betta_grad,
cudaStream_t stream[2],
T* inp_grad_out = nullptr,
const T* norm_out = nullptr)
T* inp_grad_out,
const T* norm_out)
{
launch_layerNorm_backward_fused_add(out_grad1,
out_grad2,
(config_.save_vals ? vals_hat : norm_out),
norm_out,
vars,
gamma,
gamma_grad,
betta_grad,
(config_.allocateGrad ? inp_grad : inp_grad_out),
inp_grad_out,
bsz,
config_.seqLength,
config_.hiddenDim,
stream,
config_.save_vals,
!config_.useMean,
betta);
}
inline T* GetInputGrad() const { return inp_grad; }
inline bool UseMean() const { return config_.useMean; }
inline void SetVar(T* variance)
{
if (!variance) { throw std::runtime_error("Normalize variance is null."); }
vars = variance;
}
inline void SetMean(T* mean)
{
if (!mean) { throw std::runtime_error("Normalize mean is null."); }
means = mean;
}
private:
Config config_;
T* vars;
T* means;
T* vals_hat;
T* inp_grad;
};
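Note: Normalize_Layer no longer cudaMallocs its own statistics; variance and mean buffers are handed in through the new SetVar/SetMean hooks (the transformer layer forwards them via SetIntermediateBuffers above). An illustrative wiring sketch; the Config shape and SetVar/SetMean come from the hunk, while the workspace pointers and the tail of the ForwardCheckpoint call are assumptions:

void layer_norm_wiring_sketch(int max_batch,
                              int max_seq,
                              int hidden_dim,
                              int cur_tokens,  // current batch * seq
                              float* vals,
                              const float* residual,
                              const float* gamma,
                              const float* betta,
                              cudaStream_t stream)
{
    Normalize_Layer<float> norm(
        Normalize_Layer<float>::Config(max_batch, max_seq, hidden_dim, /*training=*/true));

    // Statistics now live in caller-owned memory, one value per token, sized for the
    // largest batch * seq; shorter sequences simply use a prefix of the buffers.
    float *var = nullptr, *mean = nullptr;
    cudaMalloc((void**)&var, max_batch * max_seq * sizeof(float));
    cudaMalloc((void**)&mean, max_batch * max_seq * sizeof(float));
    norm.SetVar(var);
    norm.SetMean(mean);

    // Trailing arguments (betta, stream, preLayerNorm) are assumed from the launch call above.
    norm.ForwardCheckpoint(cur_tokens, vals, residual, gamma, betta, stream, false);

    cudaFree(var);
    cudaFree(mean);
}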
......@@ -45,13 +45,15 @@ public:
out_grad, soft_out, bsz, config_.heads, config_.seq_length, stream);
}
inline int GetProbDepth() const { return config_.prob_depth; }
inline size_t GetProbDepth() const { return config_.prob_depth; }
inline int GetBatchSize() const { return config_.batchSize; }
inline size_t GetBatchSize() const { return config_.batchSize; }
inline int GetNumHeads() const { return config_.heads; }
inline size_t GetNumHeads() const { return config_.heads; }
inline int GetSeqLength() const { return config_.seq_length; }
inline size_t GetSeqLength() const { return config_.seq_length; }
inline void SetSeqLength(size_t seq_len) { config_.seq_length = seq_len; }
private:
Config config_;
......
......@@ -3,6 +3,7 @@
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include "context.h"
template <typename T>
class StridedBatchGemm {
......@@ -38,6 +39,12 @@ public:
gemm_algos(algos)
{
}
void SetConfig(int mm, int nn, int kk)
{
m = mm;
n = nn;
k = kk;
}
};
StridedBatchGemm(const Config& config) : _config(config) {}
......@@ -163,6 +170,8 @@ public:
inline const T* GetBufferB() const { return q_buf; }
inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); }
private:
Config _config;
const T* q_buf;
......
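Note: together with Softmax::SetSeqLength and DeepSpeedTransformerLayer::SetSeqLength shown earlier, the new StridedBatchGemm::SetConfig lets the attention GEMM shapes track the incoming sequence length. A speculative per-step reconfiguration sketch; the member functions come from the hunks above, the shape formulas are assumptions:

// Assumes the Softmax and StridedBatchGemm headers diffed above are included.
template <typename T>
void reconfigure_for_seq_len(Softmax<T>& softmax,
                             StridedBatchGemm<T>& attn_scores,
                             StridedBatchGemm<T>& attn_context,
                             int seq_len,
                             int head_dim)
{
    softmax.SetSeqLength(seq_len);
    attn_scores.SetConfig(seq_len, seq_len, head_dim);   // S = Q * K^T, assumed shape
    attn_context.SetConfig(head_dim, seq_len, seq_len);  // O = S * V, assumed shape
}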
......@@ -34,7 +34,12 @@ int cublas_gemm_ex(cublasHandle_t handle,
algo);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr, "!!!! kernel execution error.\n");
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
......@@ -74,7 +79,12 @@ int cublas_gemm_ex(cublasHandle_t handle,
algo);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr, "!!!! kernel execution error.\n");
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
......@@ -122,7 +132,12 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
algo);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr, "!!!! kernel execution error.\n");
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
......@@ -170,7 +185,12 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
algo);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr, "!!!! kernel execution error.\n");
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
......
This diff is collapsed.
......@@ -279,13 +279,12 @@ void launch_bias_gelu(const T* input,
T* output,
int intermediate_size,
int batch_size,
int sequence_length,
cudaStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = intermediate_size / iterations / 4;
dim3 block_dims(threads);
dim3 grid_dims(sequence_length * batch_size);
dim3 grid_dims(batch_size);
fused_bias_gelu<<<grid_dims, block_dims, 0, stream>>>(input, bias, output, intermediate_size);
}
......@@ -295,24 +294,26 @@ void launch_gelu(const T* input,
T* output,
int intermediate_size,
int batch_size,
int sequence_length,
cudaStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = intermediate_size / iterations / 4;
dim3 block_dims(threads);
dim3 grid_dims(sequence_length * batch_size);
dim3 grid_dims(batch_size);
gelu_kernel<<<grid_dims, block_dims, 0, stream>>>(input, output, intermediate_size);
}
template void
launch_bias_gelu<float>(const float*, const float*, float*, int, int, int, cudaStream_t);
template void
launch_bias_gelu<__half>(const __half*, const __half*, __half*, int, int, int, cudaStream_t);
template void launch_bias_gelu<float>(const float*, const float*, float*, int, int, cudaStream_t);
template void launch_bias_gelu<__half>(const __half*,
const __half*,
__half*,
int,
int,
cudaStream_t);
template void launch_gelu<float>(const float*, float*, int, int, int, cudaStream_t);
template void launch_gelu<__half>(const __half*, __half*, int, int, int, cudaStream_t);
template void launch_gelu<float>(const float*, float*, int, int, cudaStream_t);
template void launch_gelu<__half>(const __half*, __half*, int, int, cudaStream_t);
template <typename T>
void launch_d_gelu(T* d_output,
......@@ -320,17 +321,15 @@ void launch_d_gelu(T* d_output,
const T* bias,
int intermediate_size,
int batch_size,
int sequence_length,
cudaStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = intermediate_size / iterations / 4;
dim3 block_dims(threads);
dim3 grid_dims(sequence_length * batch_size);
dim3 grid_dims(batch_size);
d_gelu_func<<<grid_dims, block_dims, 0, stream>>>(d_output, input, bias, intermediate_size);
}
template void launch_d_gelu<float>(float*, const float*, const float*, int, int, int, cudaStream_t);
template void
launch_d_gelu<__half>(__half*, const __half*, const __half*, int, int, int, cudaStream_t);
template void launch_d_gelu<float>(float*, const float*, const float*, int, int, cudaStream_t);
template void launch_d_gelu<__half>(__half*, const __half*, const __half*, int, int, cudaStream_t);
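Note: the gelu launchers now index the grid purely by row count (batch_size already carries batch * seq). A standalone host-side sketch of the resulting launch geometry for a common intermediate size:

#include <cstdio>

int main()
{
    // Mirrors the launch math above: one block per token row, threads sized so
    // each thread covers 4 elements per iteration.
    int intermediate_size = 4096;
    int rows = 8 * 120;  // e.g. batch 8 with a dynamic seq_len of 120 -> bsz at the call site

    int iterations = (intermediate_size + 1023) / 1024;  // 4
    int threads = intermediate_size / iterations / 4;     // 256
    printf("grid = %d blocks, block = %d threads, %d iterations x 4 elements/thread\n",
           rows, threads, iterations);
    return 0;
}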
......@@ -14,15 +14,18 @@ __global__ void column_sum_reduce(const T* __restrict__ inp,
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = threadIdx.y * width + idx;
int y_stride = width * TILE_DIM;
float localSum = 0;
// Loop across matrix height
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
localSum += (float)inp[offset];
offset += y_stride;
if (idx < width) {
int offset = threadIdx.y * width + idx;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
localSum += (float)inp[offset];
offset += y_stride;
}
}
tile[threadIdx.x][threadIdx.y] = localSum;
......@@ -40,7 +43,7 @@ __global__ void column_sum_reduce(const T* __restrict__ inp,
if (threadIdx.x == 0) {
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
out[pos] = sum;
if (pos < (rows * width)) out[pos] = sum;
}
}
......@@ -58,10 +61,10 @@ void launch_fuse_transpose_bias_kernel<float>(const float* inp,
int cols,
cudaStream_t stream)
{
assert(rows % TILE_DIM == 0);
assert(cols % TILE_DIM == 0);
// assert(rows % TILE_DIM == 0);
// assert(cols % TILE_DIM == 0);
dim3 grid_dim(cols / TILE_DIM);
dim3 grid_dim((cols - 1) / TILE_DIM + 1);
dim3 block_dim(TILE_DIM, TILE_DIM);
column_sum_reduce<float><<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
......@@ -74,10 +77,10 @@ void launch_fuse_transpose_bias_kernel<__half>(const __half* inp,
int cols,
cudaStream_t stream)
{
assert(rows % TILE_DIM == 0);
assert(cols % TILE_DIM == 0);
// assert(rows % TILE_DIM == 0);
// assert(cols % TILE_DIM == 0);
dim3 grid_dim(cols / TILE_DIM);
dim3 grid_dim((cols - 1) / TILE_DIM + 1);
dim3 block_dim(TILE_DIM, TILE_DIM);
column_sum_reduce<__half><<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
......
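Note: the bias-gradient reduction no longer assumes cols is a multiple of TILE_DIM; the grid is rounded up and out-of-range lanes are masked inside the kernel. A small host-side check of the rounding, assuming TILE_DIM is 32 as elsewhere in these kernels:

#include <cassert>

int main()
{
    const int TILE_DIM = 32;  // assumed value; the macro is defined outside this hunk
    int cols = 1000;          // no longer required to be a multiple of TILE_DIM

    int grid_x = (cols - 1) / TILE_DIM + 1;  // 32 blocks cover 1024 columns
    assert(grid_x * TILE_DIM >= cols);
    // Lanes with idx >= cols (and output slots past the end) are skipped by the
    // new bounds checks inside column_sum_reduce.
    return 0;
}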
This diff is collapsed.
#include <math.h>
#include "custom_cuda_layers.h"
#include "general_kernels.h"
......@@ -282,7 +283,7 @@ __global__ void attn_softmax(__half* vals,
}
template <typename T>
void launch_attn_softmax(T*, const T*, int, int, int, cudaStream_t, bool);
void launch_attn_softmax(T*, const T*, int, int, int, cudaStream_t);
template <>
void launch_attn_softmax<float>(float* vals,
......@@ -294,11 +295,10 @@ void launch_attn_softmax<float>(float* vals,
{
const int threads = 128;
int seq_length4 = sequence_length / 4;
int seq2 = sequence_length * seq_length4;
int block_compute_size =
(seq_length4 < threads ? ((threads / seq_length4) * seq_length4) : seq_length4);
dim3 grid_dim(batch_size, heads * seq2 / block_compute_size);
(seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) : 1);
dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
......@@ -330,8 +330,9 @@ void launch_attn_softmax<float>(float* vals,
else {
const int threads = 256;
block_compute_size =
(seq_length4 < threads ? ((threads / seq_length4) * seq_length4) : seq_length4);
dim3 grid_dim(batch_size, heads * seq2 / block_compute_size);
(seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4))))
: 1);
dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
......@@ -362,11 +363,10 @@ void launch_attn_softmax<__half>(__half* vals,
{
const int threads = 128;
int seq_length4 = sequence_length / 4;
int seq2 = sequence_length * seq_length4;
int block_compute_size =
(seq_length4 < threads ? ((threads / seq_length4) * seq_length4) : seq_length4);
dim3 grid_dim(batch_size, heads * seq2 / block_compute_size);
(seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) : 1);
dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
......@@ -399,8 +399,9 @@ void launch_attn_softmax<__half>(__half* vals,
else {
const int threads = 256;
block_compute_size =
(seq_length4 < threads ? ((threads / seq_length4) * seq_length4) : seq_length4);
dim3 grid_dim(batch_size, heads * seq2 / block_compute_size);
(seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4))))
: 1);
dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
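Note: the softmax launch now takes block_compute_size as the largest power-of-two number of whole rows that fits in a block, instead of requiring the row count to divide the thread count. A host-side sketch of the new geometry for a non-power-of-two sequence length:

#include <cmath>
#include <cstdio>

int main()
{
    int threads = 128;
    int sequence_length = 120;  // dynamic, no longer padded up to 128
    int heads = 16, batch_size = 3;

    int seq_length4 = sequence_length / 4;  // 30 vectorized groups per row
    int block_compute_size =
        (seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4))))
                               : 1);        // floor_pow2(128 / 30) = 4 rows per block
    printf("grid = (%d, %d), rows per block = %d\n",
           batch_size, heads * sequence_length / block_compute_size, block_compute_size);
    return 0;
}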
......@@ -531,55 +532,41 @@ void launch_attn_softmax_backward_v2(T* out_grad,
int seq_length,
cudaStream_t stream)
{
if ((seq_length % WARP_SIZE) != 0 || seq_length > 2048)
throw std::runtime_error("Invalid sequence length found in softmax backward.");
const int warps_per_block = 4;
dim3 grid_dim(batch_size * heads * seq_length / warps_per_block);
dim3 block_dim(WARP_SIZE, warps_per_block);
switch (seq_length) {
case 32:
softmax_backward_kernel_v2<T, 1>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break;
case 64:
softmax_backward_kernel_v2<T, 2>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break;
case 128:
softmax_backward_kernel_v2<T, 4>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break;
case 256:
softmax_backward_kernel_v2<T, 8>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break;
case 384:
softmax_backward_kernel_v2<T, 12>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break;
case 512:
softmax_backward_kernel_v2<T, 16>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break;
case 768:
softmax_backward_kernel_v2<T, 24>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break;
case 1024:
softmax_backward_kernel_v2<T, 32>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break;
case 2048:
softmax_backward_kernel_v2<T, 64>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break;
default:
throw std::runtime_error(
std::string("Special sequence length found in softmax backward, seq_length: ") +
std::to_string(seq_length));
}
if (seq_length <= 32)
softmax_backward_kernel_v2<T, 1>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
else if (seq_length <= 64)
softmax_backward_kernel_v2<T, 2>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
else if (seq_length <= 128)
softmax_backward_kernel_v2<T, 4>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
else if (seq_length <= 256)
softmax_backward_kernel_v2<T, 8>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
else if (seq_length <= 384)
softmax_backward_kernel_v2<T, 12>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
else if (seq_length <= 512)
softmax_backward_kernel_v2<T, 16>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
else if (seq_length <= 768)
softmax_backward_kernel_v2<T, 24>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
else if (seq_length <= 1024)
softmax_backward_kernel_v2<T, 32>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
else if (seq_length <= 2048)
softmax_backward_kernel_v2<T, 64>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
else
throw std::runtime_error(
std::string("Special sequence length found in softmax backward, seq_length: ") +
std::to_string(seq_length));
}
template void launch_attn_softmax_backward_v2<__half>(__half* out_grad,
......
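Note: the backward softmax dispatch now rounds the sequence length up to the next supported tile instead of requiring an exact match. A small sketch of the bucket selection; the ITERATIONS values are copied from the branches above, the helper name is invented, and WARP_SIZE is assumed to be 32:

#include <stdexcept>
#include <string>

// Illustrative helper only: returns the template ITERATIONS value the dispatch above
// would pick, i.e. the rounded-up bucket divided by an assumed WARP_SIZE of 32.
int softmax_backward_iterations_sketch(int seq_length)
{
    const int buckets[] = {32, 64, 128, 256, 384, 512, 768, 1024, 2048};
    for (int b : buckets)
        if (seq_length <= b) return b / 32;  // 1, 2, 4, 8, 12, 16, 24, 32, 64
    throw std::runtime_error(
        std::string("Special sequence length found in softmax backward, seq_length: ") +
        std::to_string(seq_length));
}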
......@@ -187,26 +187,30 @@ class DeepSpeedTransformerFunction(Function):
ff2_inp,
attn_prob_dropout_mask,
attn_output_dropout_mask,
layer_output_dropout_mask) = forward_func(config.layer_id,
input,
input_mask,
attn_qkvw,
attn_qkvb,
attn_ow,
attn_ob,
attn_nw,
attn_nb,
inter_w,
inter_b,
output_w,
output_b,
norm_w,
norm_b,
config.training,
config.pre_layer_norm,
config.attn_dropout_checkpoint,
config.normalize_invertible,
config.gelu_checkpoint)
layer_output_dropout_mask,
attn_layer_norm_var,
attn_layer_norm_mean,
layer_norm_var,
layer_norm_mean) = forward_func(config.layer_id,
input,
input_mask,
attn_qkvw,
attn_qkvb,
attn_ow,
attn_ob,
attn_nw,
attn_nb,
inter_w,
inter_b,
output_w,
output_b,
norm_w,
norm_b,
config.training,
config.pre_layer_norm,
config.attn_dropout_checkpoint,
config.normalize_invertible,
config.gelu_checkpoint)
# For testing only.
if grads is not None:
......@@ -283,6 +287,9 @@ class DeepSpeedTransformerFunction(Function):
if not config.normalize_invertible:
ctx.add_res = add_res
ctx.attn_layer_norm_mean = attn_layer_norm_mean
ctx.layer_norm_mean = layer_norm_mean
ctx.ff1_inp = ff1_inp
if not config.gelu_checkpoint:
ctx.gelu_inp = gelu_inp
......@@ -291,6 +298,8 @@ class DeepSpeedTransformerFunction(Function):
ctx.attn_prob_dropout_mask = attn_prob_dropout_mask
ctx.attn_output_dropout_mask = attn_output_dropout_mask
ctx.layer_output_dropout_mask = layer_output_dropout_mask
ctx.attn_layer_norm_var = attn_layer_norm_var
ctx.layer_norm_var = layer_norm_var
return output
......@@ -367,6 +376,10 @@ class DeepSpeedTransformerFunction(Function):
ctx.attn_prob_dropout_mask,
ctx.attn_output_dropout_mask,
ctx.layer_output_dropout_mask,
ctx.attn_layer_norm_var,
ctx.attn_layer_norm_mean,
ctx.layer_norm_var,
ctx.layer_norm_mean,
(ctx.inp_norm if (ctx.config.pre_layer_norm
and ctx.config.normalize_invertible) else input),
input_mask,
......
......@@ -256,10 +256,10 @@ def run_backward(ds_config, atol=1e-2, verbose=False):
@pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16, atol',
[
(3,1024,128,16,24,True,False, 0.05),
(3,1024,128,16,24,True,True, 0.05),
(3,1024,128,16,24,False,False, 0.1),
(3,1024,128,16,24,False,True, 0.2),
(3,1024,120,16,24,True,False, 0.05),
(3,1024,120,16,24,True,True, 0.05),
(3,1024,56,16,24,False,False, 0.1),
(3,1024,56,16,24,False,True, 0.2),
]) # yapf: disable
def test_backward(batch_size,
hidden_size,
......
This diff is collapsed.