Unverified commit e7e75955, authored by Connor Holmes, committed by GitHub

Stable Diffusion Enhancements (#2491)

Co-authored-by: cmikeh2 <connorholmes@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Reza Yazdani <reyazda@microsoft.com>
Parent 6f77da1b
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#include <cassert>
#include "memory_access_utils.h"
#include "spatial_cuda_layers.h"
/*
Fused bias add variants
*/
namespace badd_opt {
constexpr int threads = 256;
constexpr int steps = 2;
constexpr int granularity = 16;
constexpr int vals_per_h = granularity / sizeof(__half);
constexpr int vals_per_h2 = granularity / sizeof(__half2);
constexpr int vals_per_block = threads * steps * vals_per_h;
constexpr int stride = vals_per_h * threads;
} // namespace badd_opt
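// For __half data with 16-byte (granularity) accesses: vals_per_h = 8 and vals_per_h2 = 4,
// so each 256-thread block covers 256 * 2 * 8 = 4096 values and each of the two steps
// advances by a stride of 256 * 8 = 2048 values.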
__global__ void opt_bias_add(__half* result,
const __half* activation,
const __half* bias,
int seq_len,
int channels)
{
const int id = blockIdx.x * badd_opt::vals_per_block + threadIdx.x * badd_opt::vals_per_h;
const int stride = badd_opt::vals_per_h * badd_opt::threads;
for (int i = 0; i < badd_opt::steps; i++) {
if (id + i * badd_opt::stride < seq_len * channels) {
__half2 act_buffer[badd_opt::vals_per_h2];
__half2 bias_buffer[badd_opt::vals_per_h2];
mem_access::load_global<badd_opt::granularity>(act_buffer,
activation + id + i * stride);
mem_access::load_global<badd_opt::granularity>(bias_buffer,
bias + ((id + i * stride) % channels));
for (int j = 0; j < badd_opt::vals_per_h2; j++) { act_buffer[j] += bias_buffer[j]; }
mem_access::store_global<badd_opt::granularity>(result + id + i * stride, act_buffer);
}
}
}
__global__ void opt_bias_add_add(__half* result,
const __half* activation,
const __half* bias,
const __half* other,
int seq_len,
int channels)
{
const int id = blockIdx.x * badd_opt::vals_per_block + threadIdx.x * badd_opt::vals_per_h;
const int stride = badd_opt::vals_per_h * badd_opt::threads;
for (int i = 0; i < badd_opt::steps; i++) {
if (id + i * badd_opt::stride < seq_len * channels) {
__half2 act_buffer[badd_opt::vals_per_h2];
__half2 bias_buffer[badd_opt::vals_per_h2];
__half2 other_buffer[badd_opt::vals_per_h2];
mem_access::load_global<badd_opt::granularity>(act_buffer,
activation + id + i * stride);
mem_access::load_global<badd_opt::granularity>(bias_buffer,
bias + ((id + i * stride) % channels));
mem_access::load_global<badd_opt::granularity>(other_buffer, other + id + i * stride);
for (int j = 0; j < badd_opt::vals_per_h2; j++) {
act_buffer[j] += bias_buffer[j] + other_buffer[j];
}
mem_access::store_global<badd_opt::granularity>(result + id + i * stride, act_buffer);
}
}
}
__global__ void opt_bias_add_bias_add(__half* result,
const __half* activation,
const __half* bias,
const __half* other,
const __half* other_bias,
int seq_len,
int channels)
{
const int id = blockIdx.x * badd_opt::vals_per_block + threadIdx.x * badd_opt::vals_per_h;
const int stride = badd_opt::vals_per_h * badd_opt::threads;
for (int i = 0; i < badd_opt::steps; i++) {
if (id + i * badd_opt::stride < seq_len * channels) {
__half2 act_buffer[badd_opt::vals_per_h2];
__half2 bias_buffer[badd_opt::vals_per_h2];
__half2 other_buffer[badd_opt::vals_per_h2];
__half2 other_bias_buffer[badd_opt::vals_per_h2];
mem_access::load_global<badd_opt::granularity>(act_buffer,
activation + id + i * stride);
mem_access::load_global<badd_opt::granularity>(bias_buffer,
bias + ((id + i * stride) % channels));
mem_access::load_global<badd_opt::granularity>(other_buffer, other + id + i * stride);
mem_access::load_global<badd_opt::granularity>(
other_bias_buffer, other_bias + ((id + i * stride) % channels));
for (int j = 0; j < badd_opt::vals_per_h2; j++) {
act_buffer[j] =
(act_buffer[j] + bias_buffer[j]) + (other_buffer[j] + other_bias_buffer[j]);
}
mem_access::store_global<badd_opt::granularity>(result + id + i * stride, act_buffer);
}
}
}
void launch_opt_bias_add(__half* result,
const __half* activation,
const __half* bias,
const __half* other,
const __half* other_bias,
int batch_size,
int seq_len,
int channels,
cudaStream_t stream)
{
// Should evaluate `true` for reasonable hidden sizes
assert(channels % badd_opt::vals_per_h == 0);
const int effective_seq_len = batch_size * seq_len;
const int vals = effective_seq_len * channels;
dim3 block(badd_opt::threads);
dim3 grid((vals + badd_opt::vals_per_block - 1) / badd_opt::vals_per_block);
if (!other) {
// We shouldn't have a secondary bias if there's no secondary tensor to add
assert(!other_bias);
opt_bias_add<<<grid, block, 0, stream>>>(
result, activation, bias, effective_seq_len, channels);
} else if (!other_bias) {
opt_bias_add_add<<<grid, block, 0, stream>>>(
result, activation, bias, other, effective_seq_len, channels);
} else {
opt_bias_add_bias_add<<<grid, block, 0, stream>>>(
result, activation, bias, other, other_bias, effective_seq_len, channels);
}
}
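/*
Hedged usage sketch (not part of the original commit): one way host code could drive
launch_opt_bias_add for a channels-last __half activation. The helper name
`example_nhwc_bias_add` and the fixed sizes below are illustrative assumptions only.
*/
void example_nhwc_bias_add(cudaStream_t stream)
{
    const int batch_size = 2, seq_len = 64 * 64, channels = 320;  // channels % 8 == 0
    const size_t act_bytes = (size_t)batch_size * seq_len * channels * sizeof(__half);
    __half *activation, *bias, *result;
    cudaMalloc((void**)&activation, act_bytes);
    cudaMalloc((void**)&bias, channels * sizeof(__half));
    cudaMalloc((void**)&result, act_bytes);
    // result[n, s, c] = activation[n, s, c] + bias[c]; no secondary tensor or bias.
    launch_opt_bias_add(
        result, activation, bias, nullptr, nullptr, batch_size, seq_len, channels, stream);
    cudaFree(activation);
    cudaFree(bias);
    cudaFree(result);
}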
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
#include <cstdio>
#include <vector>
#include "spatial_cuda_layers.h"
ChannelsLastProblem dimension_problem(at::Tensor& input)
{
ChannelsLastProblem dims;
if (input.dim() == 4) {
// In some sense this is unsafe (and a reflection of the assumptions made inside
// the C10 options checker). Basically, there's no great way to be sure that
// a tensor is in channels last because a 1x1 image will appear to be in channels
// last even when it isn't.
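// For example, a contiguous [2, 8, 1, 1] tensor also passes the channels-last
// contiguity check, since size-1 spatial dims make the two layouts indistinguishable.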
assert(input.is_contiguous(at::MemoryFormat::ChannelsLast));
dims.batch_size = input.size(0);
dims.seq_len = input.size(2) * input.size(3);
dims.channels = input.size(1);
} else {
assert(input.is_contiguous());
dims.batch_size = input.size(0);
dims.seq_len = input.size(1);
dims.channels = input.size(2);
}
return dims;
}
at::Tensor seq_unroll_bias_add(at::Tensor& input, at::Tensor& bias)
{
assert(input.dtype() == at::kHalf);
// TODO(cmikeh2): Should probably refactor this into a more portable
// description, since it does generalize for channels-last
ChannelsLastProblem problem = dimension_problem(input);
auto output = at::empty_like(input);
launch_opt_bias_add((__half*)output.data_ptr(),
(const __half*)input.data_ptr(),
(const __half*)bias.data_ptr(),
nullptr,
nullptr,
problem.batch_size,
problem.seq_len,
problem.channels,
at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor seq_bias_add_add(at::Tensor& input, at::Tensor& bias, at::Tensor& other)
{
assert(input.dtype() == at::kHalf);
// TODO(cmikeh2): Should probably refactor this into a more portable
// description, since it does generalize for channels-last
ChannelsLastProblem problem = dimension_problem(input);
auto output = at::empty_like(input);
launch_opt_bias_add((__half*)output.data_ptr(),
(const __half*)input.data_ptr(),
(const __half*)bias.data_ptr(),
(const __half*)other.data_ptr(),
nullptr,
problem.batch_size,
problem.seq_len,
problem.channels,
at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor seq_bias_add_bias_add(at::Tensor& input,
at::Tensor& bias,
at::Tensor& other,
at::Tensor& other_bias)
{
assert(input.dtype() == at::kHalf);
// TODO(cmikeh2): Should probably refactor this into a more portable
// description, since it does generalize for channels-last
ChannelsLastProblem problem = dimension_problem(input);
auto output = at::empty_like(input);
launch_opt_bias_add((__half*)output.data_ptr(),
(const __half*)input.data_ptr(),
(const __half*)bias.data_ptr(),
(const __half*)other.data_ptr(),
(const __half*)other_bias.data_ptr(),
problem.batch_size,
problem.seq_len,
problem.channels,
at::cuda::getCurrentCUDAStream());
return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("nhwc_bias_add", &seq_unroll_bias_add);
m.def("nhwc_bias_add_add", &seq_bias_add_add);
m.def("nhwc_bias_add_bias_add", &seq_bias_add_bias_add);
}
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#pragma once
#if __CUDA_ARCH__ >= 530
#define HALF_PRECISION_AVAILABLE 1
#endif
#include <cooperative_groups.h>
#include <cuda.h>
#include <cuda_fp16.h>
/*********** Group Norm Kernels, Structs, and Helpers ************/
typedef struct {
int64_t batch_size;
int64_t seq_len;
int64_t channels;
} ChannelsLastProblem;
void launch_opt_bias_add(__half* result,
const __half* activation,
const __half* bias,
const __half* other,
const __half* other_bias,
int batch_size,
int seq_len,
int channels,
cudaStream_t stream);
@@ -593,3 +593,106 @@ template void pad_head_seq(float* padded_output,
int head_size,
int padded_head_size,
cudaStream_t stream);
// TODO(cmikeh2): evaluate different GeLU performance
__device__ __forceinline__ float old_gelu(float val)
{
// 1 / sqrt(2)
constexpr float rsqrt_2 = 0.707106769084930419922;
return val * 0.5f * (1.0f + erff(val * rsqrt_2));
}
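// i.e. the exact (erf-based) GeLU, gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))),
// rather than a tanh approximation.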
namespace fused_geglu {
constexpr int threads = 256;
constexpr int steps = 2;
constexpr int granularity = 16;
} // namespace fused_geglu
template <typename T>
__global__ void fused_bias_geglu(T* output,
const T* activation,
const T* bias,
int base_channels,
int total_elems)
{
constexpr int T_per_access = fused_geglu::granularity / sizeof(T);
constexpr int T_per_step = T_per_access * fused_geglu::threads;
constexpr int T_per_block = T_per_step * fused_geglu::steps;
const int id = blockIdx.x * T_per_block + threadIdx.x * T_per_access;
#pragma unroll
for (int i = 0; i < fused_geglu::steps; i++) {
T activation_buffer_1[T_per_access];
T activation_buffer_2[T_per_access];
T bias_buffer_1[T_per_access];
T bias_buffer_2[T_per_access];
const int iter_id = id + T_per_step * i;
if (iter_id < total_elems) {
const int channel_id = iter_id % base_channels;
const int seq_id = iter_id / base_channels;
const int seq_offset = seq_id * base_channels * 2;
mem_access::load_global<fused_geglu::granularity>(activation_buffer_1,
activation + seq_offset + channel_id);
mem_access::load_global<fused_geglu::granularity>(
activation_buffer_2, activation + seq_offset + channel_id + base_channels);
mem_access::load_global<fused_geglu::granularity>(bias_buffer_1, bias + channel_id);
mem_access::load_global<fused_geglu::granularity>(bias_buffer_2,
bias + channel_id + base_channels);
// Since the GeLU is going to happen at float, might as well
// convert
#pragma unroll
for (int v = 0; v < T_per_access; v++) {
T hidden_state = activation_buffer_1[v] + bias_buffer_1[v];
T pre_gate = activation_buffer_2[v] + bias_buffer_2[v];
float gate_f = old_gelu(conversion::to<float>(pre_gate));
T gate = conversion::to<T>(gate_f);
activation_buffer_1[v] = hidden_state * gate;
}
mem_access::store_global<fused_geglu::granularity>(output + iter_id,
activation_buffer_1);
}
}
}
template <typename T>
void launch_fused_bias_geglu(T* output,
const T* activation,
const T* bias,
int rows,
int elems_per_row,
cudaStream_t stream)
{
/*
Fused bias GEGLU is a variant of the gated activation functions.
The input here is a matrix of [batch, seq_len, 2 * intermediate_dim]
where the second half of the channels act as GeLU gates for the first
half.
*/
// Re-derive the same per-thread / per-block tiling constants used in the kernel
constexpr int T_per_access = fused_geglu::granularity / sizeof(T);
constexpr int T_per_step = T_per_access * fused_geglu::threads;
constexpr int T_per_block = T_per_step * fused_geglu::steps;
const int base_channels = elems_per_row / 2;
const int total_elems = base_channels * rows;
dim3 block(fused_geglu::threads);
dim3 grid((total_elems + T_per_block - 1) / T_per_block);
fused_bias_geglu<<<grid, block, 0, stream>>>(
output, activation, bias, base_channels, total_elems);
}
template void launch_fused_bias_geglu(__half*,
const __half*,
const __half*,
int,
int,
cudaStream_t);
template void launch_fused_bias_geglu(float*, const float*, const float*, int, int, cudaStream_t);
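/*
Hedged reference sketch (not part of the original commit): the elementwise math the
fused kernel computes, written as plain host C++ for clarity/validation. With
d = base_channels, out[r][c] = (x[r][c] + b[c]) * gelu(x[r][c + d] + b[c + d]).
`ref_bias_geglu` is a hypothetical helper; it assumes erff from <math.h>.
*/
void ref_bias_geglu(float* output,
                    const float* activation,
                    const float* bias,
                    int rows,
                    int base_channels)
{
    for (int r = 0; r < rows; r++) {
        for (int c = 0; c < base_channels; c++) {
            const float hidden = activation[(size_t)r * 2 * base_channels + c] + bias[c];
            const float gate_in =
                activation[(size_t)r * 2 * base_channels + base_channels + c] +
                bias[base_channels + c];
            // Same exact (erf-based) GeLU as old_gelu above.
            const float gate = 0.5f * gate_in * (1.0f + erff(gate_in * 0.70710678f));
            output[(size_t)r * base_channels + c] = hidden * gate;
        }
    }
}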
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#include "conversion_utils.h"
#include "inference_cuda_layers.h"
#include "memory_access_utils.h"
#include "reduction_utils.h"
namespace cg = cooperative_groups;
using rop = reduce::ROpType;
namespace ln {
constexpr int granularity = 16;
constexpr int max_threads = 512;
constexpr int max_warps = max_threads / hw_warp_size;
constexpr int internal_unroll = 4;
} // namespace ln
/*
Primary layer norm implementation. Assumes elems_per_row % 8
is equal to 0.
Args:
output: buffer for output data
vals: buffer for input data
gamma: gain for normalization
beta: bias for normalization
epsilon: numeric stability
elems_per_row: number of elements each block will normalize
*/
template <typename T, int UNROLL>
__global__ void fused_ln(T* output,
const T* vals,
const T* gamma,
const T* beta,
float epsilon,
int elems_per_row)
{
constexpr int T_per_load = ln::granularity / sizeof(T);
cg::thread_block tb = cg::this_thread_block();
cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);
// X-dimension of the block
const int block_offset = tb.group_index().x * elems_per_row;
const int thread_offset = tb.thread_index().x * T_per_load;
const int base_offset = block_offset + thread_offset;
const int stride = tb.size() * T_per_load;
// TODO(cmikeh2): refactor to reduction utility library
float sum = reduce::init<rop::Add, float>();
const T* input_base = vals + base_offset;
T local_buffer[UNROLL * ln::internal_unroll * T_per_load];
#pragma unroll
for (int i = 0; i < UNROLL; i++) {
T* iteration_buffer = local_buffer + i * ln::internal_unroll * T_per_load;
#pragma unroll
for (int j = 0; j < ln::internal_unroll; j++) {
const int iteration = i * ln::internal_unroll + j;
mem_access::load_global<ln::granularity>(
iteration_buffer + j * T_per_load,
input_base + iteration * stride,
thread_offset + iteration * stride < elems_per_row);
}
#pragma unroll
for (int j = 0; j < ln::internal_unroll * T_per_load; j++) {
float up_cast = conversion::to<float>(iteration_buffer[j]);
sum = reduce::element<rop::Add>(sum, up_cast);
}
}
reduce::block<rop::Add, ln::max_warps>(tb, warp, sum);
const float mean = sum / elems_per_row;
float mean_diff = reduce::init<rop::Add, float>();
#pragma unroll
for (int i = 0; i < UNROLL * ln::internal_unroll; i++) {
#pragma unroll
for (int j = 0; j < T_per_load; j++) {
// Using a 0 value here skews the variance, have to if-guard
if (thread_offset + i * stride < elems_per_row) {
float diff = (conversion::to<float>(local_buffer[i * T_per_load + j]) - mean);
mean_diff = reduce::element<rop::Add>(mean_diff, diff * diff);
}
}
}
reduce::block<rop::Add, ln::max_warps>(tb, warp, mean_diff);
const float variance = mean_diff / elems_per_row;
const float denom = __frsqrt_rn(variance + epsilon);
const T mean_compute = conversion::to<T>(mean);
const T denom_compute = conversion::to<T>(denom);
T* block_output = output + block_offset;
#pragma unroll
for (int i = 0; i < UNROLL * ln::internal_unroll; i++) {
T* iteration_buffer = local_buffer + i * T_per_load;
const int iter_idx = i * stride + thread_offset;
const bool do_loads = iter_idx < elems_per_row;
T gamma_local[T_per_load], beta_local[T_per_load];
mem_access::load_global<ln::granularity>(gamma_local, gamma + iter_idx, do_loads);
mem_access::load_global<ln::granularity>(beta_local, beta + iter_idx, do_loads);
#pragma unroll
for (int j = 0; j < T_per_load; j++) {
iteration_buffer[j] = (iteration_buffer[j] - mean_compute) * denom_compute;
iteration_buffer[j] = iteration_buffer[j] * gamma_local[j] + beta_local[j];
}
if (do_loads) {
mem_access::store_global<ln::granularity>(block_output + iter_idx, iteration_buffer);
}
}
}
#define LAUNCH_FUSED_LN(unroll_factor) \
fused_ln<T, unroll_factor> \
<<<grid, block, 0, stream>>>(output, vals, gamma, beta, epsilon, elems_per_row);
template <typename T>
void launch_fused_ln(T* output,
const T* vals,
const T* gamma,
const T* beta,
float epsilon,
int rows,
int elems_per_row,
cudaStream_t stream)
{
// 8 for __half, 4 for float
constexpr int T_per_load = ln::granularity / sizeof(T);
// 32 for __half, 16 for float
constexpr int T_per_thread_unroll = T_per_load * ln::internal_unroll;
// 1024 for __half, 512 for float
constexpr int T_per_warp_unroll = T_per_thread_unroll * hw_warp_size;
int32_t unroll = 1;
while (T_per_warp_unroll * ln::max_warps * unroll < elems_per_row) { unroll++; }
const int sched_warps =
(elems_per_row + unroll * T_per_warp_unroll - 1) / (unroll * T_per_warp_unroll);
const int warps = (unroll > 1) ? ln::max_warps : sched_warps;
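// Worked example (assuming hw_warp_size == 32) for __half with elems_per_row = 5120:
// T_per_warp_unroll = 1024 and 1024 * 16 >= 5120, so unroll stays 1;
// sched_warps = ceil(5120 / 1024) = 5, giving a block of 5 * 32 = 160 threads.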
dim3 grid(rows);
dim3 block(warps * hw_warp_size);
// This should match the max_unroll constexpr
if (unroll == 1) {
LAUNCH_FUSED_LN(1);
} else if (unroll == 2) {
LAUNCH_FUSED_LN(2);
} else if (unroll == 3) {
LAUNCH_FUSED_LN(3);
} else if (unroll == 4) {
LAUNCH_FUSED_LN(4);
}
}
template void launch_fused_ln(__half*,
const __half*,
const __half*,
const __half*,
float,
int,
int,
cudaStream_t);
template void
launch_fused_ln(float*, const float*, const float*, const float*, float, int, int, cudaStream_t);
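/*
Hedged reference sketch (not part of the original commit): the per-row math fused_ln
implements, as plain host C++. `ref_layer_norm` is a hypothetical helper; it assumes
sqrtf from <math.h>.
*/
void ref_layer_norm(float* output,
                    const float* vals,
                    const float* gamma,
                    const float* beta,
                    float epsilon,
                    int rows,
                    int elems_per_row)
{
    for (int r = 0; r < rows; r++) {
        const float* row = vals + (size_t)r * elems_per_row;
        float mean = 0.f, var = 0.f;
        for (int i = 0; i < elems_per_row; i++) { mean += row[i]; }
        mean /= elems_per_row;
        for (int i = 0; i < elems_per_row; i++) { var += (row[i] - mean) * (row[i] - mean); }
        var /= elems_per_row;
        const float inv_std = 1.f / sqrtf(var + epsilon);
        for (int i = 0; i < elems_per_row; i++) {
            output[(size_t)r * elems_per_row + i] = (row[i] - mean) * inv_std * gamma[i] + beta[i];
        }
    }
}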
/*
Fused residual + bias + layer norm implementation. Assumes elems_per_row % 8
is equal to 0.
TODO(cmikeh2): Goal is to deprecate this implementation. The bias + residual
need to be fused into compute-bound producer operations.
Args:
output: buffer for output data
res_output: output of residual addition
vals: buffer for input data
residual: residual data
bias: bias of input data
gamma: gain for normalization
beta: bias for normalization
epsilon: numeric stability
elems_per_row: number of elements each block will normalize
Template arg:
StoreResidual: controls whether the residual calculation is stored
or not. When set to false, the input `res_output` is unused.
*/
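// In short: each row is normalized after computing vals + residual + bias, and when
// StoreResidual is true that pre-normalization sum is also written to res_output.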
template <typename T, int UNROLL, bool StoreResidual>
__global__ void fused_residual_ln(T* output,
T* res_output,
const T* vals,
const T* residual,
const T* bias,
const T* gamma,
const T* beta,
float epsilon,
int elems_per_row)
{
constexpr int T_per_load = ln::granularity / sizeof(T);
cg::thread_block tb = cg::this_thread_block();
cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);
// X-dimension of the block
const int block_offset = tb.group_index().x * elems_per_row;
const int thread_offset = tb.thread_index().x * T_per_load;
const int base_offset = block_offset + thread_offset;
const int stride = tb.size() * T_per_load;
float sum = reduce::init<rop::Add, float>();
const T* input_base = vals + base_offset;
const T* residual_base = residual + base_offset;
const T* bias_base = bias + thread_offset;
T local_buffer[UNROLL * ln::internal_unroll * T_per_load];
// Unlike a vanilla layernorm, since we're fusing the two adds as well,
// an inner unroll seems to be less valuable. If anything, a double unroll
// makes the most sense if we run into performance issues.
#pragma unroll
for (int i = 0; i < UNROLL * ln::internal_unroll; i++) {
T* iteration_buffer = local_buffer + i * T_per_load;
T residual_buffer[T_per_load];
T bias_buffer[T_per_load];
mem_access::load_global<ln::granularity>(
iteration_buffer, input_base + i * stride, thread_offset + i * stride < elems_per_row);
mem_access::load_global<ln::granularity>(residual_buffer,
residual_base + i * stride,
thread_offset + i * stride < elems_per_row);
mem_access::load_global<ln::granularity>(
bias_buffer, bias_base + i * stride, thread_offset + i * stride < elems_per_row);
#pragma unroll
for (int j = 0; j < T_per_load; j++) {
float vals_up_cast = conversion::to<float>(iteration_buffer[j]);
float res_up_cast = conversion::to<float>(residual_buffer[j]);
float bias_up_cast = conversion::to<float>(bias_buffer[j]);
vals_up_cast += res_up_cast + bias_up_cast;
sum = reduce::element<rop::Add>(sum, vals_up_cast);
iteration_buffer[j] = conversion::to<T>(vals_up_cast);
}
if (StoreResidual && (thread_offset + i * stride < elems_per_row)) {
mem_access::store_global<ln::granularity>(res_output + base_offset + i * stride,
iteration_buffer);
}
}
reduce::block<rop::Add, ln::max_warps>(tb, warp, sum);
const float mean = sum / elems_per_row;
float mean_diff = reduce::init<rop::Add, float>();
#pragma unroll
for (int i = 0; i < UNROLL * ln::internal_unroll; i++) {
#pragma unroll
for (int j = 0; j < T_per_load; j++) {
// Using a 0 value here skews the variance, have to if-guard
if (thread_offset + i * stride < elems_per_row) {
float diff = (conversion::to<float>(local_buffer[i * T_per_load + j]) - mean);
mean_diff = reduce::element<rop::Add>(mean_diff, diff * diff);
}
}
}
reduce::block<rop::Add, ln::max_warps>(tb, warp, mean_diff);
const float variance = mean_diff / elems_per_row;
const float denom = __frsqrt_rn(variance + epsilon);
const T mean_compute = conversion::to<T>(mean);
const T denom_compute = conversion::to<T>(denom);
T* block_output = output + block_offset;
#pragma unroll
for (int i = 0; i < UNROLL * ln::internal_unroll; i++) {
T* iteration_buffer = local_buffer + i * T_per_load;
const int iter_idx = i * stride + thread_offset;
const bool do_loads = iter_idx < elems_per_row;
T gamma_local[T_per_load], beta_local[T_per_load];
mem_access::load_global<ln::granularity>(gamma_local, gamma + iter_idx, do_loads);
mem_access::load_global<ln::granularity>(beta_local, beta + iter_idx, do_loads);
#pragma unroll
for (int j = 0; j < T_per_load; j++) {
iteration_buffer[j] = (iteration_buffer[j] - mean_compute) * denom_compute;
iteration_buffer[j] = iteration_buffer[j] * gamma_local[j] + beta_local[j];
}
if (do_loads) {
mem_access::store_global<ln::granularity>(block_output + iter_idx, iteration_buffer);
}
}
}
// TODO(cmikeh2): There's a bunch of redundancy here that needs to be removed/simplified.
#define LAUNCH_FUSED_RES_LN(unroll_factor) \
fused_residual_ln<T, unroll_factor, false><<<grid, block, 0, stream>>>( \
output, nullptr, vals, residual, bias, gamma, beta, epsilon, elems_per_row);
template <typename T>
void launch_fused_residual_ln(T* output,
const T* vals,
const T* residual,
const T* bias,
const T* gamma,
const T* beta,
float epsilon,
int rows,
int elems_per_row,
cudaStream_t stream)
{
// 8 for __half, 4 for float
constexpr int T_per_load = ln::granularity / sizeof(T);
// 32 for __half, 16 for float
constexpr int T_per_thread_unroll = T_per_load * ln::internal_unroll;
// 1024 for __half, 512 for float
constexpr int T_per_warp_unroll = T_per_thread_unroll * hw_warp_size;
int32_t unroll = 1;
while (T_per_warp_unroll * ln::max_warps * unroll < elems_per_row) { unroll++; }
const int warps =
(elems_per_row + unroll * T_per_warp_unroll - 1) / (unroll * T_per_warp_unroll);
dim3 grid(rows);
dim3 block(warps * hw_warp_size);
// This should match the max_unroll constexpr
if (unroll == 1) {
LAUNCH_FUSED_RES_LN(1);
} else if (unroll == 2) {
LAUNCH_FUSED_RES_LN(2);
} else if (unroll == 3) {
LAUNCH_FUSED_RES_LN(3);
} else if (unroll == 4) {
LAUNCH_FUSED_RES_LN(4);
}
}
#define LAUNCH_FUSED_RES_LN_STORE(unroll_factor) \
fused_residual_ln<T, unroll_factor, true><<<grid, block, 0, stream>>>( \
norm_output, res_output, vals, residual, bias, gamma, beta, epsilon, elems_per_row);
template <typename T>
void launch_fused_residual_ln_store(T* norm_output,
T* res_output,
const T* vals,
const T* residual,
const T* bias,
const T* gamma,
const T* beta,
float epsilon,
int rows,
int elems_per_row,
cudaStream_t stream)
{
// 8 for __half, 4 for float
constexpr int T_per_load = ln::granularity / sizeof(T);
// 32 for __half, 16 for float
constexpr int T_per_thread_unroll = T_per_load * ln::internal_unroll;
// 1024 for __half, 512 for float
constexpr int T_per_warp_unroll = T_per_thread_unroll * hw_warp_size;
int32_t unroll = 1;
while (T_per_warp_unroll * ln::max_warps * unroll < elems_per_row) { unroll++; }
const int warps =
(elems_per_row + unroll * T_per_warp_unroll - 1) / (unroll * T_per_warp_unroll);
dim3 grid(rows);
dim3 block(warps * hw_warp_size);
// This should match the max_unroll constexpr
if (unroll == 1) {
LAUNCH_FUSED_RES_LN_STORE(1);
} else if (unroll == 2) {
LAUNCH_FUSED_RES_LN_STORE(2);
} else if (unroll == 3) {
LAUNCH_FUSED_RES_LN_STORE(3);
} else if (unroll == 4) {
LAUNCH_FUSED_RES_LN_STORE(4);
}
}
// No-store specializations
template void launch_fused_residual_ln(__half*,
const __half*,
const __half*,
const __half*,
const __half*,
const __half*,
float,
int,
int,
cudaStream_t);
template void launch_fused_residual_ln(float*,
const float*,
const float*,
const float*,
const float*,
const float*,
float,
int,
int,
cudaStream_t);
// Store specializations
template void launch_fused_residual_ln_store(__half*,
__half*,
const __half*,
const __half*,
const __half*,
const __half*,
const __half*,
float,
int,
int,
cudaStream_t);
template void launch_fused_residual_ln_store(float*,
float*,
const float*,
const float*,
const float*,
const float*,
const float*,
float,
int,
int,
cudaStream_t);
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#include <limits>
#include "inference_cuda_layers.h"
#ifndef __HIP_PLATFORM_HCC__
#include <cuda_profiler_api.h>
#endif
#include <cstdio>
#include <cstdlib>
#include <ctime>
#define NORM_REG (MAX_REGISTERS)
namespace cg = cooperative_groups;
__global__ void fused_bias_residual_layer_norm(float* output,
const float* vals,
const float* gamma,
const float* beta,
float epsilon,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
float inp_reg[NORM_REG];
int k = 0;
float sum = 0;
int input_id = id;
while (input_id < row_stride) {
inp_reg[k] = vals[input_id + row * row_stride];
sum += inp_reg[k++];
input_id += iteration_stride;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride);
sum = 0.f;
for (int f = 0; f < k; f++) {
inp_reg[f] -= mean;
sum += inp_reg[f] * inp_reg[f];
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride);
sum += epsilon;
sum = __frsqrt_rn(sum);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * sum;
inp_reg[f] = inp_reg[f] * gamma[out_id] + beta[out_id];
output[out_id + row * row_stride] = inp_reg[f];
}
}
__global__ void fused_bias_residual_layer_norm(__half* output,
const __half* vals,
const __half* gamma,
const __half* beta,
float epsilon,
int row_stride)
{
#ifdef HALF_PRECISION_AVAILABLE
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
__half2 inp_reg[NORM_REG];
const __half2* vals_cast = reinterpret_cast<const __half2*>(vals);
__half2* out_cast = reinterpret_cast<__half2*>(output);
int k = 0;
int input_id = id;
while (input_id < row_stride) {
inp_reg[k++] = vals_cast[input_id + row * row_stride];
input_id += iteration_stride;
}
float sum = 0;
for (int f = k - 1; f >= 0; f--) {
float2 inp_f = __half22float2(inp_reg[f]);
sum += inp_f.x + inp_f.y;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride << 1);
sum = 0.f;
for (int f = 0; f < k; f++) {
float2 inp_f = __half22float2(inp_reg[f]);
inp_f.x -= mean;
inp_f.y -= mean;
inp_reg[f] = __float22half2_rn(inp_f);
sum += inp_f.x * inp_f.x;
sum += inp_f.y * inp_f.y;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride << 1);
sum += epsilon;
sum = __frsqrt_rn(sum);
__half2 variance_h = __float2half2_rn(sum);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * variance_h;
inp_reg[f] = inp_reg[f] * gamma_cast[out_id] + beta_cast[out_id];
out_cast[out_id + row * row_stride] = inp_reg[f];
}
#endif
}
template <typename T>
void launch_layer_norm(T* out,
T* vals,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
cudaStream_t stream);
template <>
void launch_layer_norm<float>(float* out,
float* vals,
const float* gamma,
const float* beta,
float epsilon,
int batch_size,
int hidden_dim,
cudaStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
out, vals, gamma, beta, epsilon, hidden_dim);
}
template <>
void launch_layer_norm<__half>(__half* out,
__half* vals,
const __half* gamma,
const __half* beta,
float epsilon,
int batch_size,
int hidden_dim,
cudaStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
out, vals, gamma, beta, epsilon, hidden_dim / 2);
}
__global__ void fused_residual_layer_norm(float* norm,
float* res_add,
float* vals,
float* residual,
const float* bias,
const float* gamma,
const float* beta,
float epsilon,
int row_stride,
bool preLN,
bool mlp_after_attn)
{
int iteration_stride = blockDim.x;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
float inp_reg[NORM_REG];
int k = 0;
int input_id = id;
float sum = 0;
while (input_id < row_stride) {
inp_reg[k] = vals[input_id + row * row_stride];
float res_f = (residual[input_id + row * row_stride]);
float bias_f = (bias[input_id]);
if (mlp_after_attn) inp_reg[k] += res_f + bias_f;
// if (preLN) res_add[input_id + row * row_stride] = inp_reg[k];
sum += inp_reg[k++];
input_id += iteration_stride;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride);
sum = 0.f;
for (int f = 0; f < k; f++) {
inp_reg[f] -= mean;
sum += inp_reg[f] * inp_reg[f];
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride);
sum += epsilon;
sum = __frsqrt_rn(sum);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * sum;
inp_reg[f] = inp_reg[f] * gamma[out_id] + beta[out_id];
norm[out_id + row * row_stride] = inp_reg[f];
}
}
__global__ void fused_residual_layer_norm(__half* norm,
__half* res_add,
__half* vals,
__half* residual,
const __half* bias,
const __half* gamma,
const __half* beta,
float epsilon,
int row_stride,
bool preLN,
bool mlp_after_attn)
{
#ifdef HALF_PRECISION_AVAILABLE
int iteration_stride = blockDim.x;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
__half2 inp_reg[NORM_REG];
__half2* vals_cast = reinterpret_cast<__half2*>(vals);
__half2* norm_cast = reinterpret_cast<__half2*>(norm);
__half2* res_add_cast = reinterpret_cast<__half2*>(res_add);
__half2* residual_cast = reinterpret_cast<__half2*>(residual);
const __half2* bias_cast = reinterpret_cast<const __half2*>(bias);
int k = 0;
int input_id = id;
float sum = 0;
while (input_id < row_stride) {
inp_reg[k] = vals_cast[input_id + row * row_stride];
float2 inp_f = __half22float2(inp_reg[k]);
float2 res_f = __half22float2(residual_cast[input_id + row * row_stride]);
float2 bias_f = __half22float2(bias_cast[input_id]);
if (mlp_after_attn) {
inp_f.x += res_f.x + bias_f.x;
inp_f.y += res_f.y + bias_f.y;
}
inp_reg[k] = __float22half2_rn(inp_f);
// if (preLN) res_add_cast[input_id + row * row_stride] = __float22half2_rn(res_f);
// //inp_reg[k];
sum += inp_f.x + inp_f.y;
input_id += iteration_stride;
k++;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride << 1);
sum = 0.f;
for (int f = 0; f < k; f++) {
float2 inp_f = __half22float2(inp_reg[f]);
inp_f.x -= mean;
inp_f.y -= mean;
inp_reg[f] = __float22half2_rn(inp_f);
sum += inp_f.x * inp_f.x;
sum += inp_f.y * inp_f.y;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride << 1);
sum += epsilon;
sum = __frsqrt_rn(sum);
__half2 variance_h = __float2half2_rn(sum);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * variance_h;
inp_reg[f] = inp_reg[f] * gamma_cast[out_id] + beta_cast[out_id];
norm_cast[out_id + row * row_stride] = inp_reg[f];
}
#endif
}
template <typename T>
void launch_residual_layer_norm(T* norm,
T* res_add,
T* vals,
T* residual,
const T* bias,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
bool preLN,
bool mlp_after_attn,
cudaStream_t stream);
template <>
void launch_residual_layer_norm<float>(float* norm,
float* res_add,
float* vals,
float* residual,
const float* bias,
const float* gamma,
const float* beta,
float epsilon,
int batch_size,
int hidden_dim,
bool preLN,
bool mlp_after_attn,
cudaStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
fused_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(norm,
res_add,
vals,
residual,
bias,
gamma,
beta,
epsilon,
hidden_dim,
preLN,
mlp_after_attn);
}
template <>
void launch_residual_layer_norm<__half>(__half* norm,
__half* res_add,
__half* vals,
__half* residual,
const __half* bias,
const __half* gamma,
const __half* beta,
float epsilon,
int batch_size,
int hidden_dim,
bool preLN,
bool mlp_after_attn,
cudaStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
fused_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(norm,
res_add,
vals,
residual,
bias,
gamma,
beta,
epsilon,
hidden_dim / 2,
preLN,
mlp_after_attn);
}
@@ -101,11 +101,11 @@ at::Tensor ds_softmax(at::Tensor& attn_scores,
}
template <typename T>
void allocate_workspace(size_t hidden_dim,
size_t batch_size,
size_t prompt_length,
unsigned num_layers,
void allocate_workspace(unsigned hidden_dim,
unsigned num_heads,
unsigned prompt_length,
unsigned batch_size,
unsigned num_layers,
unsigned mp_size = 1,
bool external_cache = false,
unsigned rank = 0,
@@ -545,6 +545,41 @@ at::Tensor ds_bias_gelu(at::Tensor& input, at::Tensor& bias)
return input_cont;
}
at::Tensor ds_bias_geglu(at::Tensor& activation, at::Tensor& bias)
{
/*
Used in the feed-forward (GEGLU) block of Stable Diffusion
*/
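// Shapes: activation is [batch, seq_len, channels], bias is [channels], and the
// returned tensor is [batch, seq_len, channels / 2] (the gate half is consumed).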
const int batch_size = activation.size(0);
const int seq_len = activation.size(1);
const int channels = activation.size(2);
const int rows = batch_size * seq_len;
// Dimensionality is cut in half
const int out_channels = channels / 2;
auto output = at::empty({batch_size, seq_len, out_channels}, activation.options());
if (activation.options().dtype() == torch::kFloat32) {
launch_fused_bias_geglu((float*)output.data_ptr(),
(const float*)activation.data_ptr(),
(const float*)bias.data_ptr(),
rows,
channels,
Context::Instance().GetCurrentStream());
} else {
launch_fused_bias_geglu((__half*)output.data_ptr(),
(const __half*)activation.data_ptr(),
(const __half*)bias.data_ptr(),
rows,
channels,
Context::Instance().GetCurrentStream());
}
return output;
}
template <typename T>
at::Tensor ds_bias_relu(at::Tensor& input, at::Tensor& bias)
{
@@ -594,38 +629,132 @@ at::Tensor ds_bias_residual(at::Tensor& input, at::Tensor& residual, at::Tensor&
return input_cont;
}
template <typename T>
at::Tensor ds_layernorm(at::Tensor& input_cont, at::Tensor& gamma, at::Tensor& betta, float epsilon)
at::Tensor ds_layer_norm(at::Tensor& input, at::Tensor& gamma, at::Tensor& beta, float epsilon)
{
int bsz = input_cont.size(0) * input_cont.size(1);
auto inp_norm = at::empty_like(input_cont);
launch_layer_norm((T*)inp_norm.data_ptr(),
(T*)input_cont.data_ptr(),
(T*)gamma.data_ptr(),
(T*)betta.data_ptr(),
epsilon,
bsz,
input_cont.size(2),
Context::Instance().GetCurrentStream());
return inp_norm;
const int rows = input.size(0) * input.size(1);
const int elems_per_row = input.size(2);
auto output = at::empty_like(input);
if (input.options().dtype() == torch::kFloat16) {
launch_fused_ln((__half*)output.data_ptr(),
(const __half*)input.data_ptr(),
(const __half*)gamma.data_ptr(),
(const __half*)beta.data_ptr(),
epsilon,
rows,
elems_per_row,
Context::Instance().GetCurrentStream());
} else {
launch_fused_ln((float*)output.data_ptr(),
(const float*)input.data_ptr(),
(const float*)gamma.data_ptr(),
(const float*)beta.data_ptr(),
epsilon,
rows,
elems_per_row,
Context::Instance().GetCurrentStream());
}
return output;
}
template <typename T>
void ds_layernorm_internal(T* workspace,
at::Tensor& input,
at::Tensor& gamma,
at::Tensor& betta,
float epsilon)
void ds_layer_norm_internal(T* workspace,
at::Tensor& input,
at::Tensor& gamma,
at::Tensor& beta,
float epsilon)
{
int bsz = input.size(0) * input.size(1);
launch_layer_norm(workspace,
(T*)input.data_ptr(),
(T*)gamma.data_ptr(),
(T*)betta.data_ptr(),
epsilon,
bsz,
input.size(2),
Context::Instance().GetCurrentStream());
launch_fused_ln(workspace,
(const T*)input.data_ptr(),
(const T*)gamma.data_ptr(),
(const T*)beta.data_ptr(),
epsilon,
bsz,
input.size(2),
Context::Instance().GetCurrentStream());
}
/* Currently only used in unit testing */
at::Tensor ds_layer_norm_residual(at::Tensor& input,
at::Tensor& bias,
at::Tensor& residual,
at::Tensor& gamma,
at::Tensor& beta,
float epsilon)
{
const int rows = input.size(0) * input.size(1);
const int elems_per_row = input.size(2);
auto output = at::empty_like(input);
if (input.options().dtype() == torch::kFloat16) {
launch_fused_residual_ln((__half*)output.data_ptr(),
(const __half*)input.data_ptr(),
(const __half*)residual.data_ptr(),
(const __half*)bias.data_ptr(),
(const __half*)gamma.data_ptr(),
(const __half*)beta.data_ptr(),
epsilon,
rows,
elems_per_row,
Context::Instance().GetCurrentStream());
} else {
launch_fused_residual_ln((float*)output.data_ptr(),
(const float*)input.data_ptr(),
(const float*)residual.data_ptr(),
(const float*)bias.data_ptr(),
(const float*)gamma.data_ptr(),
(const float*)beta.data_ptr(),
epsilon,
rows,
elems_per_row,
Context::Instance().GetCurrentStream());
}
return output;
}
/* Currently only used in unit testing */
std::vector<at::Tensor> ds_layer_norm_residual_store(at::Tensor& input,
at::Tensor& bias,
at::Tensor& residual,
at::Tensor& gamma,
at::Tensor& beta,
float epsilon)
{
const int rows = input.size(0) * input.size(1);
const int elems_per_row = input.size(2);
auto norm_output = at::empty_like(input);
auto res_output = at::empty_like(input);
if (input.options().dtype() == torch::kFloat16) {
launch_fused_residual_ln_store((__half*)norm_output.data_ptr(),
(__half*)res_output.data_ptr(),
(const __half*)input.data_ptr(),
(const __half*)residual.data_ptr(),
(const __half*)bias.data_ptr(),
(const __half*)gamma.data_ptr(),
(const __half*)beta.data_ptr(),
epsilon,
rows,
elems_per_row,
Context::Instance().GetCurrentStream());
} else {
launch_fused_residual_ln_store((float*)norm_output.data_ptr(),
(float*)res_output.data_ptr(),
(const float*)input.data_ptr(),
(const float*)residual.data_ptr(),
(const float*)bias.data_ptr(),
(const float*)gamma.data_ptr(),
(const float*)beta.data_ptr(),
epsilon,
rows,
elems_per_row,
Context::Instance().GetCurrentStream());
}
return {norm_output, res_output};
}
template <typename T>
@@ -682,7 +811,7 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output,
int bsz = input.size(0) * input.size(1);
T* workspace = (T*)Context::Instance().GetWorkSpace();
workspace += (3 * bsz * input.size(2));
ds_layernorm_internal<T>(workspace, input, gamma, beta, epsilon);
ds_layer_norm_internal<T>(workspace, input, gamma, beta, epsilon);
if (q_int8) {
quantized_gemm<T>(output.data_ptr(), workspace, weight, q_scale, q_scale.size(0), bsz);
@@ -816,7 +945,7 @@ at::Tensor ds_qkv_gemm_int8(at::Tensor& input,
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
auto inp_norm = ds_layernorm<T>(input_cont, gamma, beta, epsilon);
auto inp_norm = ds_layer_norm(input_cont, gamma, beta, epsilon);
quantized_gemm<T>(output, inp_norm, weight, q_scale, groups, 0);
if (add_bias)
@@ -834,10 +963,8 @@ at::Tensor ds_linear_layer(at::Tensor& input,
at::Tensor& weight,
at::Tensor& bias,
bool add_bias,
bool external_cache,
bool do_flash_attn,
int num_heads,
unsigned num_layers)
int num_heads)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
@@ -1151,19 +1278,21 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
T* inp_norm =
(T*)Context::Instance().GetWorkSpace() + torch::numel(input) + torch::numel(output);
T* intermediate = inp_norm + torch::numel(input);
launch_residual_layer_norm((T*)inp_norm,
(T*)nullptr,
(T*)input.data_ptr(),
(T*)residual.data_ptr(),
(T*)input_bias.data_ptr(),
(T*)gamma.data_ptr(),
(T*)beta.data_ptr(),
epsilon,
bsz,
input.size(2),
preLayerNorm,
mlp_after_attn,
Context::Instance().GetCurrentStream());
if (mlp_after_attn) {
launch_fused_residual_ln((T*)inp_norm,
(const T*)input.data_ptr(),
(const T*)residual.data_ptr(),
(const T*)input_bias.data_ptr(),
(const T*)gamma.data_ptr(),
(const T*)beta.data_ptr(),
epsilon,
bsz,
input.size(2),
Context::Instance().GetCurrentStream());
} else {
ds_layer_norm_internal(inp_norm, input, gamma, beta, epsilon);
}
if (q_int8) {
quantized_gemm<T>(intermediate, inp_norm, weight, q_scale, q_scale.size(0), bsz);
@@ -1532,6 +1661,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
"DeepSpeed attention with int8 (CUDA)");
m.def("bias_gelu_fp32", &ds_bias_gelu<float>, "DeepSpeed Gelu with fp32 (CUDA)");
m.def("bias_gelu_fp16", &ds_bias_gelu<__half>, "DeepSpeed Gelu with fp16 (CUDA)");
m.def("bias_geglu", &ds_bias_geglu, "DeepSpeed Bias GEGLU (CUDA)");
m.def("bias_add_fp32", &ds_bias_add<float>, "DeepSpeed Bias Add with fp32 (CUDA)");
m.def("bias_add_fp16", &ds_bias_add<__half>, "DeepSpeed Gelu with fp16 (CUDA)");
m.def("bias_relu_fp32", &ds_bias_relu<float>, "DeepSpeed ReLU with fp32 (CUDA)");
@@ -1542,8 +1672,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("bias_residual_fp16",
&ds_bias_residual<__half>,
"DeepSpeed residual-bias add with fp16 (CUDA)");
m.def("layer_norm_fp32", &ds_layernorm<float>, "DeepSpeed layer-norm with fp32 (CUDA)");
m.def("layer_norm_fp16", &ds_layernorm<__half>, "DeepSpeed layer-norm with fp16 (CUDA)");
m.def("layer_norm", &ds_layer_norm, "DeepSpeed layer norm (CUDA)");
m.def(
"_layer_norm_residual", &ds_layer_norm_residual, "DeepSpeed layer norm + residual (CUDA)");
m.def("layer_norm_residual_store",
&ds_layer_norm_residual_store,
"DeepSpeed layer norm + store residual (CUDA)");
m.def("qkv_gemm_fp32", &ds_qkv_gemm<float>, "DeepSpeed qkv gemm with fp32 (CUDA)");
m.def("qkv_gemm_fp16", &ds_qkv_gemm<__half>, "DeepSpeed qkv gemm with fp16 (CUDA)");
m.def("qkv_gemm_int8", &ds_qkv_gemm_int8<__half>, "DeepSpeed qkv gemm with int8 (CUDA)");
@@ -103,7 +103,7 @@ public:
// Flash attention requires padded heads and we'll conservatively allocate
// for that here. Flash attention is only enabled for head size <= 128 right now
const int head_size = hidden_dim / num_heads;
const int padded_head_size = head_size < 32 ? 32 : (head_size < 64 ? 64 : 128);
const int padded_head_size = head_size <= 32 ? 32 : (head_size <= 64 ? 64 : 128);
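// With <=, boundary head sizes map to themselves: head_size == 32 pads to 32 and
// head_size == 64 pads to 64 (rather than to the next bucket).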
const int effective_head_size = (head_size > 128) ? head_size : padded_head_size;
size_t activation_size = 16 * (num_heads * effective_head_size) * batch_size;
@@ -130,10 +130,15 @@ public:
temp_size *= _max_seq_len * elem_size;
if (rank == 0 && !_workspace)
printf(
"Free memory : %lu (Bytes) Total memory: %lu (Bytes) Setting maximum total "
"tokens (input + output) to %lu \n",
_free_memory_size,
total_size,
"------------------------------------------------------\n"
"Free memory : %f (GigaBytes) \n"
"Total memory: %f (GigaBytes) \n"
"Requested memory: %f (GigaBytes) \n"
"Setting maximum total tokens (input + output) to %lu \n"
"------------------------------------------------------\n",
(float)_free_memory_size / GIGABYTE,
(float)total_size / GIGABYTE,
(float)workSpaceSize / GIGABYTE,
_max_seq_len);
if (!_workspace) {
assert(_workspace == nullptr);
@@ -46,6 +46,14 @@ void launch_bias_gelu(T* input,
int batch_size,
cudaStream_t stream);
template <typename T>
void launch_fused_bias_geglu(T* output,
const T* activation,
const T* bias,
int rows,
int elems_per_row,
cudaStream_t stream);
// Fused bias add with relu activation
template <typename T>
void launch_bias_relu(T* input,
@@ -70,29 +78,40 @@ void launch_bias_residual(T* input,
cudaStream_t stream);
template <typename T>
void launch_layer_norm(T* out,
T* vals,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
cudaStream_t stream);
void launch_fused_ln(T* output,
const T* vals,
const T* gamma,
const T* beta,
float epsilon,
int rows,
int elems_per_row,
cudaStream_t stream);
template <typename T>
void launch_fused_residual_ln(T* output,
const T* vals,
const T* residual,
const T* bias,
const T* gamma,
const T* beta,
float epsilon,
int rows,
int elems_per_row,
cudaStream_t stream);
template <typename T>
void launch_residual_layer_norm(T* norm,
T* res_add,
T* vals,
T* residual,
const T* bias,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
bool preLN,
bool mlp_after_attn,
cudaStream_t stream);
void launch_fused_residual_ln_store(T* norm_output,
T* res_output,
const T* vals,
const T* residual,
const T* bias,
const T* gamma,
const T* beta,
float epsilon,
int rows,
int elems_per_row,
cudaStream_t stream);
template <typename T>
void launch_dequantize(T* output,
const int8_t* input,
@@ -379,7 +379,8 @@ class InferenceEngine(Module):
self.checkpoint_engine) if checkpoint_dir is not None else None
generic_injection(self.module,
fp16=(self.dtype == torch.half) or (self.dtype == torch.int8))
fp16=(self.dtype == torch.half) or (self.dtype == torch.int8),
enable_cuda_graph=self.enable_cuda_graph)
if isinstance(self.module, torch.nn.Module):
replace_transformer_layer(
......
@@ -2,23 +2,21 @@
Copyright 2022 The Microsoft DeepSpeed Team
'''
import torch
import diffusers
class DSUNet(torch.nn.Module):
def __init__(self, unet):
def __init__(self, unet, enable_cuda_graph=True):
super().__init__()
self.unet = unet
# SD pipeline accesses this attribute
self.in_channels = unet.in_channels
self._traced_unet = None
self._trace_enabled = False
self.device = self.unet.device
self.dtype = self.unet.dtype
self.fwd_count = 0
self.unet.requires_grad_(requires_grad=False)
self.unet.to(memory_format=torch.channels_last)
self.cuda_graph_created = False
self.enable_cuda_graph = enable_cuda_graph
def _graph_replay(self, *inputs, **kwargs):
for i in range(len(inputs)):
@@ -31,12 +29,15 @@ class DSUNet(torch.nn.Module):
return self.static_output
def forward(self, *inputs, **kwargs):
if self.cuda_graph_created:
outputs = self._graph_replay(*inputs, **kwargs)
if self.enable_cuda_graph:
if self.cuda_graph_created:
outputs = self._graph_replay(*inputs, **kwargs)
else:
self._create_cuda_graph(*inputs, **kwargs)
outputs = self._graph_replay(*inputs, **kwargs)
return outputs
else:
self._create_cuda_graph(*inputs, **kwargs)
outputs = self._graph_replay(*inputs, **kwargs)
return outputs
return self._forward(*inputs, **kwargs)
def _create_cuda_graph(self, *inputs, **kwargs):
# warmup to create the workspace and cublas handle
@@ -58,25 +59,4 @@ class DSUNet(torch.nn.Module):
self.cuda_graph_created = True
def _forward(self, sample, timestamp, encoder_hidden_states, return_dict=True):
if self._trace_enabled:
if self._traced_unet is None:
print("Unet: start tracing with Nvfuser")
# force return tuple instead of dict
self._traced_unet = torch.jit.trace(
lambda _sample,
_timestamp,
_encoder_hidden_states: self.unet(_sample,
_timestamp,
_encoder_hidden_states,
return_dict=False),
(sample,
timestamp,
encoder_hidden_states))
return self.unet(sample, timestamp, encoder_hidden_states)
else:
# convert return type to UNet2DConditionOutput
out_sample, *_ = self._traced_unet(sample, timestamp, encoder_hidden_states)
return diffusers.models.unet_2d_condition.UNet2DConditionOutput(
out_sample)
else:
return self.unet(sample, timestamp, encoder_hidden_states, return_dict)
return self.unet(sample, timestamp, encoder_hidden_states, return_dict)
'''
Copyright 2022 The Microsoft DeepSpeed Team
'''
import torch
class DSVAE(torch.nn.Module):
def __init__(self, vae, enable_cuda_graph=True):
super().__init__()
self.vae = vae
self.device = self.vae.device
self.dtype = self.vae.dtype
self.vae.requires_grad_(requires_grad=False)
self.decoder_cuda_graph_created = False
self.encoder_cuda_graph_created = False
self.all_cuda_graph_created = False
self.enable_cuda_graph = enable_cuda_graph
def _graph_replay_decoder(self, *inputs, **kwargs):
for i in range(len(inputs)):
if torch.is_tensor(inputs[i]):
self.static_decoder_inputs[i].copy_(inputs[i])
for k in kwargs:
if torch.is_tensor(kwargs[k]):
self.static_decoder_kwargs[k].copy_(kwargs[k])
self._decoder_cuda_graph.replay()
return self.static_decoder_output
def _decode(self, x, return_dict=True):
return self.vae.decode(x, return_dict=return_dict)
def _create_cuda_graph_decoder(self, *inputs, **kwargs):
# warmup to create the workspace and cublas handle
cuda_stream = torch.cuda.Stream()
cuda_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(cuda_stream):
for i in range(3):
ret = self._decode(*inputs, **kwargs)
torch.cuda.current_stream().wait_stream(cuda_stream)
# create cuda_graph and assign static_inputs and static_outputs
self._decoder_cuda_graph = torch.cuda.CUDAGraph()
self.static_decoder_inputs = inputs
self.static_decoder_kwargs = kwargs
with torch.cuda.graph(self._decoder_cuda_graph):
self.static_decoder_output = self._decode(*self.static_decoder_inputs,
**self.static_decoder_kwargs)
self.decoder_cuda_graph_created = True
def decode(self, *inputs, **kwargs):
if self.enable_cuda_graph:
if self.decoder_cuda_graph_created:
outputs = self._graph_replay_decoder(*inputs, **kwargs)
else:
self._create_cuda_graph_decoder(*inputs, **kwargs)
outputs = self._graph_replay_decoder(*inputs, **kwargs)
return outputs
else:
return self._decode(*inputs, **kwargs)
def _graph_replay_encoder(self, *inputs, **kwargs):
for i in range(len(inputs)):
if torch.is_tensor(inputs[i]):
self.static_encoder_inputs[i].copy_(inputs[i])
for k in kwargs:
if torch.is_tensor(kwargs[k]):
self.static_encoder_kwargs[k].copy_(kwargs[k])
self._encoder_cuda_graph.replay()
return self.static_encoder_output
def _encode(self, x, return_dict=True):
return self.vae.encode(x, return_dict=return_dict)
def _create_cuda_graph_encoder(self, *inputs, **kwargs):
# warmup to create the workspace and cublas handle
cuda_stream = torch.cuda.Stream()
cuda_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(cuda_stream):
for i in range(3):
ret = self._encode(*inputs, **kwargs)
torch.cuda.current_stream().wait_stream(cuda_stream)
# create cuda_graph and assign static_inputs and static_outputs
self._encoder_cuda_graph = torch.cuda.CUDAGraph()
self.static_encoder_inputs = inputs
self.static_encoder_kwargs = kwargs
with torch.cuda.graph(self._encoder_cuda_graph):
self.static_encoder_output = self._encode(*self.static_encoder_inputs,
**self.static_encoder_kwargs)
self.encoder_cuda_graph_created = True
def encode(self, *inputs, **kwargs):
if self.enable_cuda_graph:
if self.encoder_cuda_graph_created:
outputs = self._graph_replay_encoder(*inputs, **kwargs)
else:
self._create_cuda_graph_encoder(*inputs, **kwargs)
outputs = self._graph_replay_encoder(*inputs, **kwargs)
return outputs
else:
return self._encode(*inputs, **kwargs)
def _graph_replay_all(self, *inputs, **kwargs):
for i in range(len(inputs)):
if torch.is_tensor(inputs[i]):
self.static_inputs[i].copy_(inputs[i])
for k in kwargs:
if torch.is_tensor(kwargs[k]):
self.static_kwargs[k].copy_(kwargs[k])
self._all_cuda_graph.replay()
return self.static_output
def forward(self, *inputs, **kwargs):
if self.enable_cuda_graph:
if self.cuda_graph_created:
outputs = self._graph_replay_all(*inputs, **kwargs)
else:
self._create_cuda_graph(*inputs, **kwargs)
outputs = self._graph_replay_all(*inputs, **kwargs)
return outputs
else:
return self._forward(*inputs, **kwargs)
def _create_cuda_graph(self, *inputs, **kwargs):
# warmup to create the workspace and cublas handle
cuda_stream = torch.cuda.Stream()
cuda_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(cuda_stream):
for i in range(3):
ret = self._forward(*inputs, **kwargs)
torch.cuda.current_stream().wait_stream(cuda_stream)
# create cuda_graph and assign static_inputs and static_outputs
self._all_cuda_graph = torch.cuda.CUDAGraph()
self.static_inputs = inputs
self.static_kwargs = kwargs
with torch.cuda.graph(self._all_cuda_graph):
self.static_output = self._forward(*self.static_inputs, **self.static_kwargs)
self.all_cuda_graph_created = True
def _forward(self, sample, timestamp, encoder_hidden_states, return_dict=True):
return self.vae(sample, timestamp, encoder_hidden_states, return_dict)
@@ -5,13 +5,19 @@ import torch
class DSClipEncoder(torch.nn.Module):
def __init__(self, enc):
def __init__(self, enc, enable_cuda_graph=False):
super().__init__()
enc.text_model._build_causal_attention_mask = self._build_causal_attention_mask
self.enc = enc
self.device = self.enc.device
self.dtype = self.enc.dtype
self.cuda_graph_created = False
self.cuda_graph_created = [False, False]
self.static_inputs = [None, None]
self.static_kwargs = [None, None]
self.static_output = [None, None]
self._cuda_graphs = [None, None]
self.iter = 0
self.enable_cuda_graph = enable_cuda_graph
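# Two graph slots (toggled via `iter`) are kept so that the two text-encoder calls a
# Stable Diffusion pipeline typically makes (prompt and unconditional/negative prompt)
# each capture and replay their own CUDA graph with separate static buffers.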
def _build_causal_attention_mask(self, bsz, seq_len, dtype):
mask = torch.empty(bsz,
@@ -27,20 +33,24 @@ class DSClipEncoder(torch.nn.Module):
def _graph_replay(self, *inputs, **kwargs):
for i in range(len(inputs)):
if torch.is_tensor(inputs[i]):
self.static_inputs[i].copy_(inputs[i])
self.static_inputs[self.iter][i].copy_(inputs[i])
for k in kwargs:
if torch.is_tensor(kwargs[k]):
self.static_kwargs[k].copy_(kwargs[k])
self._cuda_graphs.replay()
return self.static_output
self.static_kwargs[self.iter][k].copy_(kwargs[k])
self._cuda_graphs[self.iter].replay()
return self.static_output[self.iter]
def forward(self, *inputs, **kwargs):
if self.cuda_graph_created:
outputs = self._graph_replay(*inputs, **kwargs)
if self.enable_cuda_graph:
if self.cuda_graph_created[self.iter]:
outputs = self._graph_replay(*inputs, **kwargs)
else:
self._create_cuda_graph(*inputs, **kwargs)
outputs = self._graph_replay(*inputs, **kwargs)
self.iter = (self.iter + 1) % 2
return outputs
else:
self._create_cuda_graph(*inputs, **kwargs)
outputs = self._graph_replay(*inputs, **kwargs)
return outputs
return self.enc(*inputs, **kwargs)
def _create_cuda_graph(self, *inputs, **kwargs):
# warmup to create the workspace and cublas handle
@@ -52,15 +62,16 @@ class DSClipEncoder(torch.nn.Module):
torch.cuda.current_stream().wait_stream(cuda_stream)
# create cuda_graph and assign static_inputs and static_outputs
self._cuda_graphs = torch.cuda.CUDAGraph()
self.static_inputs = inputs
self.static_kwargs = kwargs
self._cuda_graphs[self.iter] = torch.cuda.CUDAGraph()
self.static_inputs[self.iter] = inputs
self.static_kwargs[self.iter] = kwargs
with torch.cuda.graph(self._cuda_graphs):
self.static_output = self._forward(*self.static_inputs, **self.static_kwargs)
with torch.cuda.graph(self._cuda_graphs[self.iter]):
self.static_output[self.iter] = self._forward(
*self.static_inputs[self.iter],
**self.static_kwargs[self.iter])
self.cuda_graph_created = True
self.cuda_graph_created[self.iter] = True
def _forward(self, *inputs, **kwargs):
return self.enc(*inputs, **kwargs)
......@@ -104,10 +104,10 @@ class DeepSpeedTransformerInference(nn.Module):
# Allocate memory only on first layer forward
if self.config.layer_id == 0:
self.allocate_workspace(self.config.hidden_size,
input.size()[0],
self.config.heads,
input.size()[1],
input.size()[0],
DeepSpeedTransformerInference.layer_id,
self.config.heads,
self.config.mp_size,
self.config.bigscience_bloom,
dist.get_rank() if dist.is_initialized() else 0,
......@@ -151,12 +151,10 @@ class DeepSpeedTransformerInference(nn.Module):
output = self.mlp(attention_output, input, inp_norm, self.attention.attn_ob)
if not self.config.pre_layer_norm:
ds_layernorm = inference_cuda_module.layer_norm_fp16 if self.config.fp16 or self.config.q_int8 else \
inference_cuda_module.layer_norm_fp32
output = ds_layernorm(output,
self.norm_w,
self.norm_b,
self.config.epsilon)
output = inference_cuda_module.layer_norm(output,
self.norm_w,
self.norm_b,
self.config.epsilon)
output = output.to(input_type)
if get_present:
......
from .replace_module import replace_transformer_layer, revert_transformer_layer, ReplaceWithTensorSlicing
from .replace_module import replace_transformer_layer, revert_transformer_layer, ReplaceWithTensorSlicing, GroupQuantizer
from .module_quantize import quantize_transformer_layer
from .replace_policy import DSPolicy, HFBertLayerPolicy
from .layers import LinearAllreduce, LinearLayer, EmbeddingLayer, Normalize
......@@ -195,8 +195,8 @@ def _module_match(module):
return None
def generic_injection(module, fp16=False):
def replace_attn(child, policy, layer_id):
def generic_injection(module, fp16=False, enable_cuda_graph=True):
def replace_attn(child, policy):
policy_attn = policy.attention(child)
if policy_attn is None:
return child
......@@ -212,7 +212,7 @@ def generic_injection(module, fp16=False):
triangular_masking=False,
max_out_tokens=4096,
)
attn_module = transformer_inference.DeepSpeedAttention(config)
attn_module = transformer_inference.DeepSpeedDiffusersAttention(config)
def transpose(data):
data.reshape(-1).copy_(data.transpose(-1, -2).contiguous().reshape(-1))
......@@ -233,13 +233,24 @@ def generic_injection(module, fp16=False):
attn_module.attn_ob.data.copy_(attn_ob.data.to(torch.cuda.current_device()))
return attn_module
def replace_attn_block(child, policy):
config = transformer_inference.Diffusers2DTransformerConfig()
return transformer_inference.DeepSpeedDiffusersTransformerBlock(child, config)
if isinstance(module, torch.nn.Module):
pass
else:
if fp16 is False:
raise ValueError("Generic injection only supported with FP16")
try:
import diffusers
cross_attention = diffusers.models.attention.CrossAttention
new_policies = {cross_attention: replace_attn}
attention_block = diffusers.models.attention.BasicTransformerBlock
new_policies = {
cross_attention: replace_attn,
attention_block: replace_attn_block,
}
except ImportError:
new_policies = {}
......@@ -247,30 +258,29 @@ def generic_injection(module, fp16=False):
# module.text_encoder,
# training=False,
# replace_with_kernel_inject=True,
# triangular_masking=True)
#from .encoder import DSClipEncoder
#cg_encoder = DSClipEncoder(module.text_encoder)
#setattr(module, 'text_encoder', cg_encoder)
# triangular_masking=True,
# max_out_tokens=8192)
from ..model_implementations.transformers.clip_encoder import DSClipEncoder
cg_encoder = DSClipEncoder(module.text_encoder,
enable_cuda_graph=enable_cuda_graph)
setattr(module, 'text_encoder', cg_encoder)
for name in module.__dict__.keys():
sub_module = getattr(module, name)
policy = _module_match(sub_module)
if policy is not None:
def _replace_module(module, policy, layer_id=0):
def _replace_module(module, policy):
for name, child in module.named_children():
_replace_module(child, policy)
if child.__class__ in new_policies:
replaced_module = new_policies[child.__class__](child,
policy,
layer_id)
policy)
setattr(module, name, replaced_module)
layer_id += 1
else:
layer_id = _replace_module(child, policy, layer_id=layer_id)
return layer_id
_replace_module(sub_module, policy)
new_module = policy.apply(sub_module)
new_module = policy.apply(sub_module,
enable_cuda_graph=enable_cuda_graph)
print(f"**** found and replaced {name} w. {type(new_module)}")
setattr(module, name, new_module)
......
......@@ -39,9 +39,10 @@ class UNetPolicy(DSPolicy):
def match(self, module):
return isinstance(module, self._orig_layer_class)
def apply(self, module):
from .unet import DSUNet
return DSUNet(module)
def apply(self, module, enable_cuda_graph=True):
# TODO(cmikeh2): Enable cuda graph should be an inference configuration
from ..model_implementations.diffusers.unet import DSUNet
return DSUNet(module, enable_cuda_graph=enable_cuda_graph)
def attention(self, client_module):
qw = client_module.to_q.weight
......@@ -67,6 +68,24 @@ class UNetPolicy(DSPolicy):
client_module.heads
class VAEPolicy(DSPolicy):
def __init__(self):
super().__init__()
try:
import diffusers
self._orig_layer_class = diffusers.models.vae.AutoencoderKL
except ImportError:
self._orig_layer_class = None
def match(self, module):
return isinstance(module, self._orig_layer_class)
def apply(self, module, enable_cuda_graph=True):
# TODO(cmikeh2): Enable cuda graph should be an inference configuration
from ..model_implementations.diffusers.vae import DSVAE
return DSVAE(module, enable_cuda_graph=enable_cuda_graph)
class TransformerPolicy(DSPolicy):
# a static class variable containing the HuggingFace model configuration.
# see e.g., transformers.models.opt.configuration_opt.OPTConfig
......@@ -185,7 +204,7 @@ class HFBertLayerPolicy(TransformerPolicy):
class HFCLIPLayerPolicy(TransformerPolicy):
def __init__(self, client_module, inference=False):
super().__init__(inference, pre_attn_norm=True, scale_attention=False)
super().__init__(inference, pre_attn_norm=True, scale_attention=True)
self.client_module = client_module
self.cuda_graph_supported = True
......@@ -608,4 +627,4 @@ replace_policies = [
]
# non-transformer-based policies
generic_policies = [UNetPolicy]
generic_policies = [UNetPolicy, VAEPolicy]
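For context, UNetPolicy and VAEPolicy feed the generic injection path when a whole Stable Diffusion pipeline is handed to deepspeed.init_inference. A hedged end-to-end sketch; the model id is illustrative and the keyword names mirror the options surfaced in this PR rather than a definitive API reference:

import torch
import diffusers
import deepspeed

pipe = diffusers.StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16).to("cuda")

# Kernel injection swaps the UNet, VAE and CLIP text encoder per the policies
# above; enable_cuda_graph turns on the capture/replay wrappers shown earlier.
pipe = deepspeed.init_inference(pipe,
                                dtype=torch.half,
                                replace_with_kernel_inject=True,
                                enable_cuda_graph=True)

image = pipe("a photo of an astronaut riding a horse on mars").images[0]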
......@@ -2,4 +2,6 @@ from .transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
from .inference.config import DeepSpeedInferenceConfig
from ...model_implementations.transformers.ds_transformer import DeepSpeedTransformerInference
from .inference.moe_inference import DeepSpeedMoEInferenceConfig, DeepSpeedMoEInference
from .inference.attention import DeepSpeedAttention
from .inference.diffusers_attention import DeepSpeedDiffusersAttention
from .inference.diffusers_transformer_block import DeepSpeedDiffusersTransformerBlock
from .inference.diffusers_2d_transformer import Diffusers2DTransformerConfig
from .config import DeepSpeedInferenceConfig
from ....model_implementations.transformers.ds_transformer import DeepSpeedTransformerInference
from .moe_inference import DeepSpeedMoEInferenceConfig, DeepSpeedMoEInference
from .attention import DeepSpeedAttention
from .diffusers_attention import DeepSpeedDiffusersAttention
from .diffusers_transformer_block import DeepSpeedDiffusersTransformerBlock
from .diffusers_2d_transformer import Diffusers2DTransformerConfig
'''
Copyright 2022 The Microsoft DeepSpeed Team
'''
from typing import Optional
import torch
from ... import op_builder
spatial_cuda_module = None
def nhwc_bias_add(activation: torch.Tensor,
bias: torch.Tensor,
other: Optional[torch.Tensor] = None,
other_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
global spatial_cuda_module
if spatial_cuda_module is None:
spatial_cuda_module = op_builder.SpatialInferenceBuilder().load()
if other is None:
return spatial_cuda_module.nhwc_bias_add(activation, bias)
elif other_bias is None:
return spatial_cuda_module.nhwc_bias_add_add(activation, bias, other)
else:
return spatial_cuda_module.nhwc_bias_add_bias_add(activation,
bias,
other,
other_bias)
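A minimal usage sketch of the wrapper above; shapes are illustrative, and the kernels expect fp16 tensors laid out channels-last (NHWC), mirroring the unit tests further below:

import torch
from deepspeed.ops.transformer.inference.bias_add import nhwc_bias_add

activation = torch.randn((2, 320, 64, 64), dtype=torch.float16,
                         device="cuda").to(memory_format=torch.channels_last)
residual = torch.randn_like(activation)  # preserves the channels_last layout
bias = torch.randn((320, ), dtype=torch.float16, device="cuda")
other_bias = torch.randn((320, ), dtype=torch.float16, device="cuda")

out = nhwc_bias_add(activation, bias)                          # bias only
out = nhwc_bias_add(activation, bias, other=residual)          # + residual
out = nhwc_bias_add(activation, bias, other=residual,
                    other_bias=other_bias)                     # + biased residual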
'''
Copyright 2022 The Microsoft DeepSpeed Team
'''
class Diffusers2DTransformerConfig():
def __init__(self, int8_quantization=False):
self.int8_quantization = int8_quantization
'''
Copyright 2020 The Microsoft DeepSpeed Team
Copyright 2022 The Microsoft DeepSpeed Team
'''
import math
import torch
......@@ -27,7 +27,7 @@ def load_triton_flash_attn():
from .triton_ops import triton_flash_attn
class DeepSpeedAttentionFunction(Function):
class DeepSpeedDiffusersAttentionFunction(Function):
@staticmethod
def forward(ctx,
input,
......@@ -44,6 +44,7 @@ class DeepSpeedAttentionFunction(Function):
hidden_size_per_partition,
attn_ow,
attn_ob,
do_out_bias,
score_context_func,
linear_func,
triton_flash_attn_kernel):
......@@ -82,7 +83,7 @@ class DeepSpeedAttentionFunction(Function):
config.window_size,
no_masking,
config.layer_id,
DeepSpeedAttention.layer_id,
DeepSpeedDiffusersAttention.layer_id,
torch.empty(1))
return context_layer
......@@ -97,10 +98,8 @@ class DeepSpeedAttentionFunction(Function):
attn_qkvw,
attn_qkvb if attn_qkvb is not None else attn_qkvw,
attn_qkvb is not None,
True,
do_flash_attn,
config.heads,
DeepSpeedAttention.layer_id)
config.heads)
if do_flash_attn:
context_layer = triton_flash_attn_kernel(qkv_out[0],
qkv_out[1],
......@@ -134,11 +133,9 @@ class DeepSpeedAttentionFunction(Function):
output = linear_func(context_layer,
attn_ow,
attn_ob,
attn_ob is not None,
True,
do_out_bias,
False,
config.heads,
DeepSpeedAttention.layer_id)
config.heads)
return output
output = selfAttention_fp(input, context, input_mask)
......@@ -151,7 +148,7 @@ class DeepSpeedAttentionFunction(Function):
Please switch to Training mode for running backward!')
class DeepSpeedAttention(nn.Module):
class DeepSpeedDiffusersAttention(nn.Module):
"""Initialize the DeepSpeed Transformer Layer.
Arguments:
layer_id: The layer index starting from 0, e.g. if model has 24 transformer layers,
......@@ -164,11 +161,11 @@ class DeepSpeedAttention(nn.Module):
self,
config,
):
super(DeepSpeedAttention, self).__init__()
super(DeepSpeedDiffusersAttention, self).__init__()
self.config = config
self.config.layer_id = DeepSpeedAttention.layer_id
DeepSpeedAttention.layer_id += 1
self.config.layer_id = DeepSpeedDiffusersAttention.layer_id
DeepSpeedDiffusersAttention.layer_id += 1
device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu'
qkv_size_per_partition = (self.config.hidden_size // self.config.mp_size) * 3
......@@ -179,7 +176,7 @@ class DeepSpeedAttention(nn.Module):
builder = op_builder.InferenceBuilder()
inference_cuda_module = builder.load()
if DeepSpeedAttention.layer_id == 1:
if DeepSpeedDiffusersAttention.layer_id == 1:
log_dist(f"DeepSpeed-Attention config: {self.config.__dict__}", [0])
self.attn_qkvw = nn.Parameter(torch.empty(self.config.hidden_size,
......@@ -217,6 +214,8 @@ class DeepSpeedAttention(nn.Module):
dtype=data_type_fp,
device=device),
requires_grad=False)
self.do_out_bias = True
if triton_flash_attn is None:
load_triton_flash_attn()
self.triton_flash_attn_kernel = triton_flash_attn()
......@@ -235,81 +234,38 @@ class DeepSpeedAttention(nn.Module):
inference_cuda_module.softmax_context_fp16
self.linear_func = inference_cuda_module.linear_layer_fp16 if config.fp16 else \
inference_cuda_module.linear_layer_fp32
self.cuda_graph_created = False
self.enable_cuda_graph = False
self.allocate_workspace = inference_cuda_module.allocate_workspace_fp32 if (not config.fp16) else \
inference_cuda_module.allocate_workspace_fp16
self.iter = 0
def _graph_replay(self, *inputs, **kwargs):
for i in range(len(inputs)):
if torch.is_tensor(inputs[i]):
self.static_inputs[i].copy_(inputs[i])
for k in kwargs:
if torch.is_tensor(kwargs[k]):
self.static_kwargs[k].copy_(kwargs[k])
self._cuda_graphs.replay()
return self.static_output
def _create_cuda_graph(self, *inputs, **kwargs):
# warmup to create the workspace and cublas handle
cuda_stream = torch.cuda.Stream()
cuda_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(cuda_stream):
for i in range(3):
ret = self._forward(*inputs, **kwargs)
torch.cuda.current_stream().wait_stream(cuda_stream)
# create cuda_graph and assign static_inputs and static_outputs
self._cuda_graphs = torch.cuda.CUDAGraph()
self.static_inputs = inputs
self.static_kwargs = kwargs
with torch.cuda.graph(self._cuda_graphs):
self.static_output = self._forward(*self.static_inputs, **self.static_kwargs)
self.cuda_graph_created = True
self.allocate_workspace = inference_cuda_module.allocate_workspace_fp32 if not (config.fp16) else \
inference_cuda_module.allocate_workspace_fp16
def forward(self, *inputs, **kwargs):
if self.enable_cuda_graph:
if self.cuda_graph_created:
outputs = self._graph_replay(*inputs, **kwargs)
else:
self._create_cuda_graph(*inputs, **kwargs)
outputs = self._graph_replay(*inputs, **kwargs)
else:
outputs = self._forward(*inputs, **kwargs)
return outputs
def _forward(self, input, context=None, input_mask=None):
# Allocate memory only on first layer forward
if self.config.layer_id == 0 and self.iter == 0:
self.iter += 1
def forward(self, input, context=None, input_mask=None):
if self.config.layer_id == 0:
self.allocate_workspace(self.config.hidden_size,
input.size()[0],
input.size()[1],
DeepSpeedAttention.layer_id,
self.config.heads,
input.size()[1],
input.size()[0],
DeepSpeedDiffusersAttention.layer_id,
self.config.mp_size,
self.config.bigscience_bloom,
False,
0,
self.config.max_out_tokens)
output = DeepSpeedAttentionFunction.apply(input,
context,
input_mask,
self.config,
self.attn_qkvw,
self.attn_qw,
self.attn_kw,
self.attn_vw,
self.attn_qkvb,
self.num_attention_heads_per_partition,
self.norm_factor,
self.hidden_size_per_partition,
self.attn_ow,
self.attn_ob,
self.score_context_func,
self.linear_func,
self.triton_flash_attn_kernel)
output = DeepSpeedDiffusersAttentionFunction.apply(
input,
context,
input_mask,
self.config,
self.attn_qkvw,
self.attn_qw,
self.attn_kw,
self.attn_vw,
self.attn_qkvb,
self.num_attention_heads_per_partition,
self.norm_factor,
self.hidden_size_per_partition,
self.attn_ow,
self.attn_ob,
self.do_out_bias,
self.score_context_func,
self.linear_func,
self.triton_flash_attn_kernel)
return output
'''
Copyright 2022 The Microsoft DeepSpeed Team
'''
import torch
import torch.nn as nn
from ... import op_builder
from ....module_inject import GroupQuantizer
from .diffusers_attention import DeepSpeedDiffusersAttention
from .bias_add import nhwc_bias_add
from .diffusers_2d_transformer import Diffusers2DTransformerConfig
# Ops will be loaded on demand
transformer_cuda_module = None
spatial_cuda_module = None
def load_transformer_module():
global transformer_cuda_module
if transformer_cuda_module is None:
transformer_cuda_module = op_builder.InferenceBuilder().load()
return transformer_cuda_module
def load_spatial_module():
global spatial_cuda_module
if spatial_cuda_module is None:
spatial_cuda_module = op_builder.SpatialInferenceBuilder().load()
return spatial_cuda_module
class DeepSpeedDiffusersTransformerBlock(nn.Module):
def __init__(self,
equivalent_module: nn.Module,
config: Diffusers2DTransformerConfig):
super(DeepSpeedDiffusersTransformerBlock, self).__init__()
self.quantizer = GroupQuantizer(q_int8=config.int8_quantization)
# Ensure ops are built by the time we start running
self.config = config
self.ff1_w = self.quantizer.quantize(
nn.Parameter(equivalent_module.ff.net[0].proj.weight.data,
requires_grad=False))
self.ff1_b = nn.Parameter(equivalent_module.ff.net[0].proj.bias.data,
requires_grad=False)
self.ff2_w = self.quantizer.quantize(
nn.Parameter(equivalent_module.ff.net[2].weight.data,
requires_grad=False))
self.ff2_b = nn.Parameter(equivalent_module.ff.net[2].bias.data,
requires_grad=False)
self.norm1_g = nn.Parameter(equivalent_module.norm1.weight.data,
requires_grad=False)
self.norm1_b = nn.Parameter(equivalent_module.norm1.bias.data,
requires_grad=False)
self.norm1_eps = equivalent_module.norm1.eps
self.norm2_g = nn.Parameter(equivalent_module.norm2.weight.data,
requires_grad=False)
self.norm2_b = nn.Parameter(equivalent_module.norm2.bias.data,
requires_grad=False)
self.norm2_eps = equivalent_module.norm2.eps
self.norm3_g = nn.Parameter(equivalent_module.norm3.weight.data,
requires_grad=False)
self.norm3_b = nn.Parameter(equivalent_module.norm3.bias.data,
requires_grad=False)
self.norm3_eps = equivalent_module.norm3.eps
self.attn_1 = equivalent_module.attn1
self.attn_2 = equivalent_module.attn2
# Pull the bias in if we can
if isinstance(self.attn_1, DeepSpeedDiffusersAttention):
self.attn_1.do_out_bias = False
self.attn_1_bias = self.attn_1.attn_ob
else:
self.attn_1_bias = nn.Parameter(torch.zeros_like(self.norm2_g),
requires_grad=False)
# Pull the bias in if we can
if isinstance(self.attn_2, DeepSpeedDiffusersAttention):
self.attn_2.do_out_bias = False
self.attn_2_bias = self.attn_2.attn_ob
else:
self.attn_2_bias = nn.Parameter(torch.zeros_like(self.norm3_g),
requires_grad=False)
self.transformer_cuda_module = load_transformer_module()
load_spatial_module()
def forward(self, hidden_states, context=None, timestep=None):
out_norm_1 = self.transformer_cuda_module.layer_norm(hidden_states,
self.norm1_g,
self.norm1_b,
self.norm1_eps)
out_attn_1 = self.attn_1(out_norm_1)
out_norm_2, out_attn_1 = self.transformer_cuda_module.layer_norm_residual_store(out_attn_1,
self.attn_1_bias,
hidden_states,
self.norm2_g,
self.norm2_b,
self.norm2_eps)
out_attn_2 = self.attn_2(out_norm_2, context=context)
out_norm_3, out_attn_2 = self.transformer_cuda_module.layer_norm_residual_store(out_attn_2,
self.attn_2_bias,
out_attn_1,
self.norm3_g,
self.norm3_b,
self.norm3_eps)
out_ff1 = nn.functional.linear(out_norm_3, self.ff1_w)
out_geglu = self.transformer_cuda_module.bias_geglu(out_ff1, self.ff1_b)
out_ff2 = nn.functional.linear(out_geglu, self.ff2_w)
return nhwc_bias_add(out_ff2, self.ff2_b, other=out_attn_2)
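The fused block is built directly from a diffusers BasicTransformerBlock, exactly as replace_attn_block does in the injection code above. A manual, illustrative variant of that swap, assuming diffusers is installed and using a placeholder model id (in practice deepspeed.init_inference performs the replacement):

import torch
import diffusers
from deepspeed.ops.transformer import (DeepSpeedDiffusersTransformerBlock,
                                       Diffusers2DTransformerConfig)

unet = diffusers.UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet",
    torch_dtype=torch.float16).to("cuda")

# Replace every BasicTransformerBlock with the fused DeepSpeed block.
config = Diffusers2DTransformerConfig(int8_quantization=False)
for parent in unet.modules():
    for name, child in list(parent.named_children()):
        if isinstance(child, diffusers.models.attention.BasicTransformerBlock):
            setattr(parent, name,
                    DeepSpeedDiffusersTransformerBlock(child, config))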
......@@ -308,9 +308,7 @@ class DeepSpeedSelfAttentionFunction(Function):
attn_qkvb,
attn_qkvb is not None,
False,
False,
num_attention_heads_per_partition,
DeepSpeedSelfAttention.num_layers)
num_attention_heads_per_partition)
else:
qkv_func = inference_cuda_module.qkv_gemm_fp16 if config.fp16 else \
inference_cuda_module.qkv_gemm_fp32
......
......@@ -12,6 +12,7 @@ from .utils import UtilsBuilder
from .async_io import AsyncIOBuilder
from .transformer_inference import InferenceBuilder
from .quantizer import QuantizerBuilder
from .spatial_inference import SpatialInferenceBuilder
from .builder import get_default_compute_capabilities, OpBuilder
# TODO: infer this list instead of hard coded
......@@ -27,6 +28,7 @@ __op_builders__ = [
AsyncIOBuilder(),
UtilsBuilder(),
QuantizerBuilder(),
InferenceBuilder()
InferenceBuilder(),
SpatialInferenceBuilder(),
]
ALL_OPS = {op.name: op for op in __op_builders__}
'''
Copyright 2022 The Microsoft DeepSpeed Team
'''
from .builder import CUDAOpBuilder, installed_cuda_version
class SpatialInferenceBuilder(CUDAOpBuilder):
BUILD_VAR = "DS_BUILD_SPATIAL_INFERENCE"
NAME = "spatial_inference"
def __init__(self, name=None):
name = self.NAME if name is None else name
super().__init__(name=name)
def absolute_name(self):
return f'deepspeed.ops.spatial.{self.NAME}_op'
def is_compatible(self, verbose=True):
try:
import torch
except ImportError:
self.warning(
"Please install torch if trying to pre-compile inference kernels")
return False
cuda_okay = True
if not self.is_rocm_pytorch() and torch.cuda.is_available():
sys_cuda_major, _ = installed_cuda_version()
torch_cuda_major = int(torch.version.cuda.split('.')[0])
cuda_capability = torch.cuda.get_device_properties(0).major
if cuda_capability >= 8:
if torch_cuda_major < 11 or sys_cuda_major < 11:
self.warning(
"On Ampere and higher architectures please use CUDA 11+")
cuda_okay = False
return super().is_compatible(verbose) and cuda_okay
def sources(self):
return [
'csrc/spatial/csrc/opt_bias_add.cu',
'csrc/spatial/csrc/pt_binding.cpp',
]
def include_paths(self):
return ['csrc/spatial/includes', 'csrc/includes']
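With the builder registered in the op_builder __init__ above, the spatial kernels can be JIT-compiled on first use, which is what load_spatial_module() and the nhwc_bias_add wrapper do. A minimal sketch, assuming the SpatialInferenceBuilder export added in this PR:

from deepspeed.ops.op_builder import SpatialInferenceBuilder

# First load JIT-compiles csrc/spatial/csrc/*.cu; later loads reuse the cached build.
spatial_module = SpatialInferenceBuilder().load()
# spatial_module then exposes nhwc_bias_add, nhwc_bias_add_add and
# nhwc_bias_add_bias_add, matching the Python wrapper shown earlier.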
......@@ -37,7 +37,7 @@ class InferenceBuilder(CUDAOpBuilder):
'csrc/transformer/inference/csrc/pt_binding.cpp',
'csrc/transformer/inference/csrc/gelu.cu',
'csrc/transformer/inference/csrc/relu.cu',
'csrc/transformer/inference/csrc/normalize.cu',
'csrc/transformer/inference/csrc/layer_norm.cu',
'csrc/transformer/inference/csrc/softmax.cu',
'csrc/transformer/inference/csrc/dequantize.cu',
'csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu',
......
'''
Copyright 2022 The Microsoft DeepSpeed Team
'''
import pytest
import torch
from deepspeed.ops.transformer.inference.bias_add import nhwc_bias_add
def allclose(x, y):
assert x.dtype == y.dtype
rtol, atol = {torch.float32: (5e-3, 5e-4), torch.float16: (3e-2, 2e-3), torch.int8: (1, 1)}[x.dtype]
return torch.allclose(x, y, rtol=rtol, atol=atol)
def ref_bias_add(activations, bias):
return activations + bias.reshape(1, -1, 1, 1)
channels_list = [
192,
384,
320,
576,
640,
768,
960,
1152,
1280,
1536,
1600,
1920,
2240,
2560
]
@pytest.mark.inference
@pytest.mark.parametrize("batch", [1, 2, 10])
@pytest.mark.parametrize("image_size", [16, 32, 64])
@pytest.mark.parametrize("channels", channels_list)
def test_bias_add(batch, image_size, channels):
activations = torch.randn((batch,
channels,
image_size,
image_size),
dtype=torch.float16,
device="cuda").to(memory_format=torch.channels_last)
bias = torch.randn((channels), dtype=torch.float16, device="cuda")
ref_vals = ref_bias_add(activations.clone().detach(), bias)
ds_vals = nhwc_bias_add(activations, bias)
assert allclose(ds_vals, ref_vals)
def ref_bias_add_add(activations, bias, other):
return (activations + bias.reshape(1, -1, 1, 1)) + other
@pytest.mark.inference
@pytest.mark.parametrize("batch", [1, 2, 10])
@pytest.mark.parametrize("image_size", [16, 32, 64])
@pytest.mark.parametrize("channels", channels_list)
def test_bias_add_add(batch, image_size, channels):
activations = torch.randn((batch,
channels,
image_size,
image_size),
dtype=torch.float16,
device="cuda").to(memory_format=torch.channels_last)
other = torch.randn((batch,
channels,
image_size,
image_size),
dtype=torch.float16,
device="cuda").to(memory_format=torch.channels_last)
bias = torch.randn((channels), dtype=torch.float16, device="cuda")
ref_vals = ref_bias_add_add(activations.clone().detach(), bias, other)
ds_vals = nhwc_bias_add(activations, bias, other=other)
assert allclose(ds_vals, ref_vals)
def ref_bias_add_bias_add(activations, bias, other, other_bias):
return (activations + bias.reshape(1, -1, 1, 1)) + (other + other_bias.reshape(1, -1, 1, 1))
@pytest.mark.inference
@pytest.mark.parametrize("batch", [1, 2, 10])
@pytest.mark.parametrize("image_size", [16, 32, 64])
@pytest.mark.parametrize("channels", channels_list)
def test_bias_add_bias_add(batch, image_size, channels):
activations = torch.randn((batch,
channels,
image_size,
image_size),
dtype=torch.float16,
device="cuda").to(memory_format=torch.channels_last)
other = torch.randn((batch,
channels,
image_size,
image_size),
dtype=torch.float16,
device="cuda").to(memory_format=torch.channels_last)
bias = torch.randn((channels), dtype=torch.float16, device="cuda")
other_bias = torch.randn((channels), dtype=torch.float16, device="cuda")
ref_vals = ref_bias_add_bias_add(activations.clone().detach(),
bias,
other,
other_bias)
ds_vals = nhwc_bias_add(activations, bias, other=other, other_bias=other_bias)
assert allclose(ds_vals, ref_vals)
"""
Copyright 2022 The Microsoft DeepSpeed Team
"""
import pytest
import torch
import deepspeed
from deepspeed.ops.op_builder import InferenceBuilder
if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
pytest.skip("Inference ops are not available on this system",
allow_module_level=True)
inference_module = None
torch_minor_version = None
def allclose(x, y):
assert x.dtype == y.dtype
rtol, atol = {torch.float32: (5e-3, 5e-4), torch.float16: (3e-2, 2e-3), torch.int8: (0, 0)}[x.dtype]
return torch.allclose(x, y, rtol=rtol, atol=atol)
def run_bias_geglu_reference(activations, bias):
# Expected behavior is that of casting to float32 internally
# Explicitly using the default GeLU
activations = activations + bias.reshape(1, 1, -1)
hidden_states, gate = activations.chunk(2, dim=-1)
return hidden_states * torch.nn.functional.gelu(gate.to(torch.float32)).to(
activations.dtype)
def run_bias_geglu_ds(activation, bias):
global inference_module
if inference_module is None:
inference_module = InferenceBuilder().load()
return inference_module.bias_geglu(activation, bias)
@pytest.mark.inference
@pytest.mark.parametrize("batch", [1, 2])
@pytest.mark.parametrize("sequence", [1, 128, 255])
@pytest.mark.parametrize("channels", [512, 1232, 4096])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_bias_geglu(batch, sequence, channels, dtype):
activation = torch.randn((batch, sequence, channels * 2), dtype=dtype, device='cuda')
bias = torch.randn((channels * 2), dtype=dtype, device='cuda')
ds_out = run_bias_geglu_ds(activation, bias)
ref_out = run_bias_geglu_reference(activation, bias)
assert (allclose(ds_out, ref_out))
"""
Copyright 2022 The Microsoft DeepSpeed Team
"""
import deepspeed
import torch
import pytest
from deepspeed.ops.op_builder import InferenceBuilder
if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
pytest.skip("Inference ops are not available on this system",
allow_module_level=True)
inference_module = None
def allclose(x, y):
assert x.dtype == y.dtype
rtol, atol = {torch.float32: (5e-4, 5e-5), torch.float16: (3e-2, 2e-3)}[x.dtype]
return torch.allclose(x, y, rtol=rtol, atol=atol)
def ref_implementation(vals, gamma, beta, epsilon, channels, dtype):
vals_f = vals.to(torch.float32)
gamma_f = gamma.to(torch.float32)
beta_f = beta.to(torch.float32)
return torch.nn.functional.layer_norm(vals_f, (channels, ),
weight=gamma_f,
bias=beta_f,
eps=epsilon).to(dtype)
def ds_implementation(vals, gamma, beta, epsilon):
global inference_module
if inference_module is None:
inference_module = InferenceBuilder().load()
return inference_module.layer_norm(vals, gamma, beta, epsilon)
@pytest.mark.inference
@pytest.mark.parametrize("batch", [1, 32])
@pytest.mark.parametrize("seq_len", [1, 128])
@pytest.mark.parametrize("channels", [384, 512, 768, 1024, 2048, 8192, 14432])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_layer_norm(batch, seq_len, channels, dtype):
vals = torch.randn((batch,
seq_len,
channels),
dtype=dtype,
device=torch.cuda.current_device())
gamma = torch.randn((channels), dtype=dtype, device=torch.cuda.current_device())
beta = torch.rand((channels), dtype=dtype, device=torch.cuda.current_device())
epsilon = 1e-5
ref_output = ref_implementation(vals, gamma, beta, epsilon, channels, dtype)
new_output = ds_implementation(vals, gamma, beta, epsilon)
assert allclose(new_output, ref_output)
def residual_ref_implementation(vals, bias, res, gamma, beta, epsilon, channels, dtype):
vals_f = vals.to(torch.float32)
bias_f = bias.to(torch.float32).reshape(1, 1, -1)
res_f = res.to(torch.float32)
gamma_f = gamma.to(torch.float32)
beta_f = beta.to(torch.float32)
return torch.nn.functional.layer_norm(vals_f + bias_f + res_f, (channels, ),
weight=gamma_f,
bias=beta_f,
eps=epsilon).to(dtype)
def residual_ds_implementation(vals, bias, res, gamma, beta, epsilon):
global inference_module
if inference_module is None:
inference_module = InferenceBuilder().load()
return inference_module._layer_norm_residual(vals, bias, res, gamma, beta, epsilon)
@pytest.mark.inference
@pytest.mark.parametrize("batch", [1, 32])
@pytest.mark.parametrize("seq_len", [1, 128])
@pytest.mark.parametrize("channels", [384, 512, 768, 1024, 2048, 8192, 14432])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_layer_norm_residual(batch, seq_len, channels, dtype):
vals = torch.randn((batch,
seq_len,
channels),
dtype=dtype,
device=torch.cuda.current_device())
residual = torch.randn((batch,
seq_len,
channels),
dtype=dtype,
device=torch.cuda.current_device())
bias = torch.randn((channels), dtype=dtype, device=torch.cuda.current_device())
gamma = torch.randn((channels), dtype=dtype, device=torch.cuda.current_device())
beta = torch.rand((channels), dtype=dtype, device=torch.cuda.current_device())
epsilon = 1e-5
new_output = residual_ds_implementation(vals, bias, residual, gamma, beta, epsilon)
ref_output = residual_ref_implementation(vals,
bias,
residual,
gamma,
beta,
epsilon,
channels,
dtype)
assert allclose(new_output, ref_output)
def residual_store_ref_implementation(vals, bias, res, gamma, beta, epsilon, channels, dtype):
vals_f = vals.to(torch.float32)
bias_f = bias.to(torch.float32).reshape(1, 1, -1)
res_f = res.to(torch.float32)
gamma_f = gamma.to(torch.float32)
beta_f = beta.to(torch.float32)
res_output = vals_f + bias_f + res_f
norm_output = torch.nn.functional.layer_norm(res_output, (channels, ),
weight=gamma_f,
bias=beta_f,
eps=epsilon).to(dtype)
return norm_output, res_output.to(dtype)
def residual_store_ds_implementation(vals, bias, res, gamma, beta, epsilon):
global inference_module
if inference_module is None:
inference_module = InferenceBuilder().load()
return inference_module.layer_norm_residual_store(vals,
bias,
res,
gamma,
beta,
epsilon)
@pytest.mark.inference
@pytest.mark.parametrize("batch", [1, 32])
@pytest.mark.parametrize("seq_len", [1, 128])
@pytest.mark.parametrize("channels", [384, 512, 768, 1024, 2048, 8192, 14432])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_layer_norm_residual_store(batch, seq_len, channels, dtype):
vals = torch.randn((batch,
seq_len,
channels),
dtype=dtype,
device=torch.cuda.current_device())
residual = torch.randn((batch,
seq_len,
channels),
dtype=dtype,
device=torch.cuda.current_device())
bias = torch.randn((channels), dtype=dtype, device=torch.cuda.current_device())
gamma = torch.randn((channels), dtype=dtype, device=torch.cuda.current_device())
beta = torch.rand((channels), dtype=dtype, device=torch.cuda.current_device())
epsilon = 1e-5
# Need to run the reference first since there's an in-place component to ours
ref_norm_output, norm_res_output = residual_store_ref_implementation(vals,
bias,
residual,
gamma,
beta,
epsilon,
channels,
dtype)
ds_norm_output, ds_res_output = residual_store_ds_implementation(vals, bias, residual, gamma, beta, epsilon)
assert allclose(ds_norm_output, ref_norm_output)
assert allclose(ds_res_output, norm_res_output)