diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py index 1b2a1f14d911dc89aff98dc20827f76a64b8d67b..d79c2455a21a2bf7a4f0d1279c3f25f9b9ce350c 100755 --- a/dnn/scripts/opr_param_defs.py +++ b/dnn/scripts/opr_param_defs.py @@ -1272,8 +1272,6 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'), (pdef('GeneralNorm') .add_fields('bool', 'affine', 'true') .add_fields('float32', 'eps', '1e-5f') - .add_fields('uint64', 'normalized_dim', '1') - .add_fields('uint64', 'normalized_size', '1') .add_fields('uint64', 'normalized_axis', '0') ) diff --git a/dnn/src/cuda/general_norm/general_norm_cuda.cu b/dnn/src/cuda/general_norm/general_norm_cuda.cu index 8195632ca1732b9d1a8c9e6d9ad3b1005bc9026c..d6be625947d807cafcd2d7b64342b9bed5fe3997 100644 --- a/dnn/src/cuda/general_norm/general_norm_cuda.cu +++ b/dnn/src/cuda/general_norm/general_norm_cuda.cu @@ -11,635 +11,122 @@ namespace megdnn { namespace cuda { namespace general_norm { -constexpr int kCUDANumThreads = 256; -constexpr int vec_size = 4; - -// warp size may be used as array length, or used in host function, -// so we define WARP_SIZE rather than using warpSize -#define WARP_SIZE 32 - -#if defined(__clang__) -#define __ubsan_ignore_float_divide_by_zero__ \ - __attribute__((no_sanitize("float-divide-by-zero"))) -#else -#define __ubsan_ignore_float_divide_by_zero__ -#endif - -struct WelfordStat { - float mean; - float sigma2; - float count; - MEGDNN_HOST MEGDNN_DEVICE WelfordStat() : mean(0.f), sigma2(0.f), count(0.f) {} - MEGDNN_HOST MEGDNN_DEVICE WelfordStat(float mean, float sigma2, float count) - : mean(mean), sigma2(sigma2), count(count) {} -}; - -template -struct WelfordData { - T mean; - T sigma2; - combine_t count; - - MEGDNN_HOST MEGDNN_DEVICE WelfordData() : mean(0), sigma2(0), count(0) {} - - MEGDNN_HOST MEGDNN_DEVICE WelfordData(T mean, T sigma2, combine_t count) - : mean(mean), sigma2(sigma2), count(count) {} -}; - -template -struct WelfordOps { -public: - using WelfordData_T = WelfordData; - inline MEGDNN_DEVICE WelfordData_T reduce(WelfordData_T acc, T data) const { - T delta = data - acc.mean; - T new_mean = static_cast(acc.mean + delta / (acc.count + 1)); - T new_delta = static_cast(data - new_mean); - return { - new_mean, - acc.sigma2 + delta * new_delta, - combine_t(acc.count + 1), - }; - } - inline MEGDNN_DEVICE WelfordData_T - combine(WelfordData_T lhs, WelfordData_T rhs) const { - if (lhs.count != 0 && rhs.count != 0) { - T delta = rhs.mean - lhs.mean; - combine_t new_count = lhs.count + rhs.count; - T nb_over_n = rhs.count / new_count; - return {lhs.mean + delta * nb_over_n, - lhs.sigma2 + rhs.sigma2 + delta * delta * lhs.count * nb_over_n, - new_count}; - } else { - return (lhs.count != 0) ? lhs : rhs; - } - } - inline MEGDNN_DEVICE res_t - project(WelfordData_T acc) const __ubsan_ignore_float_divide_by_zero__ { - const auto mean = static_cast(acc.mean); - const combine_t divisor = static_cast(acc.count); - const auto var = acc.sigma2 / divisor; - res_t results(var, mean); - return results; - } - -#if defined(__CUDACC__) || defined(__HIPCC__) - inline MEGDNN_DEVICE WelfordData_T - warp_shfl_down(WelfordData_T acc, int offset) const { - return {__shfl_down(acc.mean, offset, warpSize), - __shfl_down(acc.sigma2, offset, warpSize), - __shfl_down(acc.count, offset, warpSize)}; - } -#endif - MEGDNN_HOST MEGDNN_DEVICE WelfordOps() {} -}; - -template -struct alignas(sizeof(T) * vec_size) aligned_vector { - T val[vec_size]; -}; - -template -using acc_type = T; - -template -MEGDNN_DEVICE WelfordStat -update_welford_stat_online(const U val, const WelfordStat& curr_sum) { - U delta = static_cast(val - curr_sum.mean); - U new_count = static_cast(curr_sum.count + 1.f); - U new_mean = static_cast(curr_sum.mean + delta * (1.f / new_count)); - return {new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count}; -} - -MEGDNN_DEVICE WelfordStat -combine_welford_stat(const WelfordStat lhs, const WelfordStat rhs) { - using U = decltype(lhs.count); - U delta = lhs.mean - rhs.mean; - U count = rhs.count + lhs.count; - U mean, sigma2; - if (count > decltype(lhs.count){0}) { - auto coef = 1.f / count; - auto nA = rhs.count * coef; - auto nB = lhs.count * coef; - mean = nA * rhs.mean + nB * lhs.mean; - sigma2 = rhs.sigma2 + lhs.sigma2 + delta * delta * rhs.count * nB; - } else { - mean = U(0); - sigma2 = U(0); - } - return {mean, sigma2, count}; -} - -template -MEGDNN_DEVICE WelfordStat -compute_stats(const T* __restrict__ X, const int slice_len, float* buf) { - using vec_t = aligned_vector; - using acc_t = acc_type; - const vec_t* X_vec = reinterpret_cast(X); - const int numx = blockDim.x * blockDim.y; - const int thrx = threadIdx.x + threadIdx.y * blockDim.x; - const int n_vec_to_read = slice_len / vec_size; - WelfordStat w_stat(0.f, 0.f, 0.f); - for (int i = thrx; i < n_vec_to_read; i += numx) { - vec_t data = X_vec[i]; -#pragma unroll - for (int ii = 0; ii < vec_size; ii++) { - w_stat = update_welford_stat_online( - static_cast(data.val[ii]), w_stat); - } - } - // intra-warp reduction -#pragma unroll - for (int offset = (warpSize >> 1); offset > 0; offset >>= 1) { - WelfordStat w_tmp{ - __shfl_down(w_stat.mean, offset, warpSize), - __shfl_down(w_stat.sigma2, offset, warpSize), - __shfl_down(w_stat.count, offset, warpSize)}; - w_stat = combine_welford_stat(w_stat, w_tmp); - } - - // threadIdx.x == 0 has correct values for each warp - // inter-warp reductions - if (blockDim.y > 1) { - float* mean_sigma_buf = buf; - float* count_buf = buf + blockDim.y; - for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { - // upper half of warps write to shared - if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2 * offset) { - const int wrt_y = threadIdx.y - offset; - mean_sigma_buf[2 * wrt_y] = w_stat.mean; - mean_sigma_buf[2 * wrt_y + 1] = w_stat.sigma2; - count_buf[wrt_y] = w_stat.count; +template +__global__ void forward_kernel( + T* X_data, T* weight_data, T* bias_data, T* Y_data, T_ACC* mean_data, + T_ACC* rstd_data, T_ACC eps, int64_t A, int64_t B, int64_t C, cudaStream_t stream) { + for (int64_t a = 0; a < A; ++a) + for (int64_t c = 0; c < C; ++c) { + T_ACC slice_sum = static_cast(0.0f); + for (int64_t b = 0; b < B; b++) { + auto value = X_data[a * B * C + b * C + c]; + slice_sum += value; } - __syncthreads(); + T_ACC slice_mean = static_cast(slice_sum / B); - // lower half merges - if (threadIdx.x == 0 && threadIdx.y < offset) { - WelfordStat w_tmp{ - mean_sigma_buf[2 * threadIdx.y], - mean_sigma_buf[2 * threadIdx.y + 1], count_buf[threadIdx.y]}; - w_stat = combine_welford_stat(w_stat, w_tmp); + T_ACC slice_var = static_cast(0.0f); + for (int64_t b = 0; b < B; b++) { + slice_var += (X_data[a * B * C + b * C + c] - slice_mean) * + (X_data[a * B * C + b * C + c] - slice_mean); } - __syncthreads(); - } - if (threadIdx.x == 0 && threadIdx.y == 0) { - mean_sigma_buf[0] = w_stat.mean; - mean_sigma_buf[1] = w_stat.sigma2 / float(slice_len); - } - __syncthreads(); - return WelfordStat{mean_sigma_buf[0], mean_sigma_buf[1], 0.f}; - - } else { - return WelfordStat{ - __shfl(w_stat.mean, 0, warpSize), - __shfl(w_stat.sigma2, 0, warpSize) / float(slice_len), 0.f}; - } -} - -template -__global__ void vectorized_general_norm_forward_affine_kernel( - const int slice_len, T_ACC eps, const T* __restrict__ X, const T* weight, - const T* bias, T_ACC* mean, T_ACC* rstd, T* Y) { - // if we made smem WelfordStat type, there would be bank conflicts, - // as one thread would have to write 3 consecutive floats - extern __shared__ float s_data[]; - - auto slice_id = blockIdx.x; - const T* slice = X + slice_id * slice_len; - WelfordStat slice_w_stat = compute_stats(slice, slice_len, s_data); - using vec_t = aligned_vector; - const vec_t* X_vec = reinterpret_cast(slice); - vec_t* Y_vec = reinterpret_cast(Y + slice_id * slice_len); - const int numx = blockDim.x * blockDim.y; - const int thrx = threadIdx.x + threadIdx.y * blockDim.x; - const int n_vec_to_read = slice_len / vec_size; - T_ACC rstd_val = static_cast(rsqrt(slice_w_stat.sigma2 + eps)); - - for (int i = thrx; i < n_vec_to_read; i += numx) { - vec_t data = X_vec[i]; - vec_t out; - // computation is performed in T_ACC, X is cast to T_ACC and result is - // implicitly cast to T - -#pragma unroll - for (int ii = 0; ii < vec_size; ii++) { - out.val[ii] = static_cast(weight[i * vec_size + ii]) * - (rstd_val * (static_cast(data.val[ii]) - - slice_w_stat.mean)) + - static_cast(bias[i * vec_size + ii]); - } - Y_vec[i] = out; - } - if (thrx == 0) { - mean[slice_id] = slice_w_stat.mean; - rstd[slice_id] = rstd_val; - } -} - -template -__global__ void vectorized_general_norm_forward_kernel( - const int slice_len, T_ACC eps, const T* __restrict__ X, const T* weight, - const T* bias, T_ACC* mean, T_ACC* rstd, T* Y) { - extern __shared__ float s_data[]; - - auto slice_id = blockIdx.x; - const T* slice = X + slice_id * slice_len; - WelfordStat slice_w_stat = compute_stats(slice, slice_len, s_data); - using vec_t = aligned_vector; - const vec_t* X_vec = reinterpret_cast(slice); - vec_t* Y_vec = reinterpret_cast(Y + slice_id * slice_len); - const int numx = blockDim.x * blockDim.y; - const int thrx = threadIdx.x + threadIdx.y * blockDim.x; - const int n_vec_to_read = slice_len / vec_size; - T_ACC rstd_val = static_cast(rsqrt(slice_w_stat.sigma2 + eps)); - - for (int i = thrx; i < n_vec_to_read; i += numx) { - vec_t data = X_vec[i]; - vec_t out; - -#pragma unroll - for (int ii = 0; ii < vec_size; ii++) { - out.val[ii] = - rstd_val * (static_cast(data.val[ii]) - slice_w_stat.mean); + slice_var = slice_var / B; + + T_ACC slice_std = static_cast(sqrt(slice_var + eps)); + for (int64_t b = 0; b < B; b++) { + Y_data[a * B * C + b * C + c] = + (X_data[a * B * C + b * C + c] - slice_mean) / slice_std; + if (weight_data || bias_data) { + Y_data[a * B * C + b * C + c] = + Y_data[a * B * C + b * C + c] * weight_data[b] + + bias_data[b]; + } + } + mean_data[a * C + c] = static_cast(slice_mean); + rstd_data[a * C + c] = static_cast(1.0 / slice_std); } - Y_vec[i] = out; - } - if (thrx == 0) { - mean[slice_id] = slice_w_stat.mean; - rstd[slice_id] = rstd_val; - } -} - -template -void launch_vectorized_general_norm_forward_kernel( - int64_t slice_len, int64_t slice_num, T_ACC eps, const T* X_data, - const T* weight_data, const T* bias_data, T* Y_data, T_ACC* mean_data, - T_ACC* rstd_data, cudaStream_t stream) { - const int num_threads = 128; - const dim3 threads(WARP_SIZE, num_threads / WARP_SIZE, 1); - const dim3 blocks(slice_num); - int nshared = threads.y > 1 ? threads.y * 3 / 2 * sizeof(T_ACC) : 0; - - if (weight_data == nullptr && bias_data == nullptr) { - vectorized_general_norm_forward_kernel<<>>( - slice_len, eps, X_data, weight_data, bias_data, mean_data, rstd_data, - Y_data); - } else { - vectorized_general_norm_forward_affine_kernel<<< - blocks, threads, nshared, stream>>>( - slice_len, eps, X_data, weight_data, bias_data, mean_data, rstd_data, - Y_data); - } - after_kernel_launch(); -} - -template -__inline__ MEGDNN_DEVICE T welford_warp_reduce(T val, const ReduceOp& op) { -#pragma unroll - for (int offset = (warpSize >> 1); offset > 0; offset >>= 1) { - val = op.combine(val, op.warp_shfl_down(val, offset)); - } - return val; -} - -template -__inline__ MEGDNN_DEVICE T -welford_block_reduce(T val, const ReduceOp& op, const T& identity_element, T* shared) { - const int lid = threadIdx.x % warpSize; - const int wid = threadIdx.x / warpSize; - val = welford_warp_reduce(val, op); - __syncthreads(); - if (lid == 0) { - shared[wid] = val; - } - __syncthreads(); - val = (threadIdx.x < blockDim.x / warpSize) ? shared[lid] : identity_element; - if (wid == 0) { - val = welford_warp_reduce(val, op); - } - return val; -} - -template -__global__ void get_input_mean_and_rstd_kernel( - int64_t slice_len, T_ACC eps, const T* X, T_ACC* mean, T_ACC* rstd) { - using WelfordType = WelfordData; - using WelfordOp = WelfordOps>; - - __shared__ typename std::aligned_storage< - sizeof(WelfordType), alignof(WelfordType)>::type val_shared[WARP_SIZE]; - WelfordType* val_shared_ptr = reinterpret_cast(val_shared); - - const int64_t i = blockIdx.x; - WelfordOp welford_op; - WelfordType val( - static_cast(0), static_cast(0), static_cast(0)); - - for (int64_t j = threadIdx.x; j < slice_len; j += blockDim.x) { - const int64_t index = i * slice_len + j; - val = welford_op.reduce(val, static_cast(X[index])); - } - val = welford_block_reduce( - val, welford_op, - WelfordType( - static_cast(0), static_cast(0), - static_cast(0)), - val_shared_ptr); - - if (threadIdx.x == 0) { - T_ACC slice_mean; - T_ACC slice_sigma2; - thrust::tie(slice_sigma2, slice_mean) = welford_op.project(val); - mean[i] = slice_mean; - rstd[i] = rsqrt(slice_sigma2 + eps); - } } -template -__global__ void general_norm_forward_kernel( - int64_t slice_len, const T* X, const T_ACC* mean, const T_ACC* rstd, - const T* weight, const T* bias, T* Y) { - const int64_t i = blockIdx.x; - for (int64_t j = threadIdx.x; j < slice_len; j += blockDim.x) { - const int64_t index = i * slice_len + j; - const T_ACC weight_v = - weight == nullptr ? T_ACC(1) : static_cast(weight[j]); - const T_ACC bias_v = bias == nullptr ? T_ACC(0) : static_cast(bias[j]); - Y[index] = (static_cast(X[index]) - static_cast(mean[i])) * - static_cast(rstd[i]) * weight_v + - bias_v; - } -} - -template +template void forward( - T* X, T* weight, T* bias, int64_t slice_num, int64_t slice_len, T_ACC eps, T* Y, - T_ACC* mean, T_ACC* rstd, cudaStream_t stream) { - auto can_vectorize = [&](const T* ptr, int alignment) { - uint64_t addr = reinterpret_cast(ptr); - return addr % alignment == 0; - }; - constexpr int num_vec_elems = vec_size; - constexpr int alignment = num_vec_elems * sizeof(T); - if ((std::is_same::value || std::is_same::value || - std::is_same::value) && - slice_len <= static_cast(1ULL << std::numeric_limits::digits) && - slice_len % num_vec_elems == 0 && can_vectorize(X, alignment) && - can_vectorize(Y, alignment)) { - launch_vectorized_general_norm_forward_kernel( - slice_len, slice_num, static_cast(eps), X, weight, bias, Y, mean, - rstd, stream); - after_kernel_launch(); - } else { - get_input_mean_and_rstd_kernel - <<>>(slice_len, eps, X, mean, rstd); - after_kernel_launch(); - general_norm_forward_kernel<<>>( - slice_len, X, mean, rstd, weight, bias, Y); - after_kernel_launch(); - } -} - -template -__inline__ MEGDNN_DEVICE T warp_reduce_sum(T val) { -#pragma unroll - for (int offset = (warpSize >> 1); offset > 0; offset >>= 1) { - val += __shfl_down(val, offset, warpSize); - } - return val; -} - -template -__inline__ MEGDNN_DEVICE T block_reduce_sum(T val, T* shared) { - const int lid = threadIdx.x % warpSize; - const int wid = threadIdx.x / warpSize; - val = warp_reduce_sum(val); - __syncthreads(); - if (lid == 0) { - shared[wid] = val; - } - __syncthreads(); - val = (threadIdx.x < blockDim.x / warpSize) ? shared[lid] : T(0); - if (wid == 0) { - val = warp_reduce_sum(val); - } - return val; -} - -template -__inline__ MEGDNN_DEVICE void general_norm_grad_input_kernel_impl( - const T* __restrict__ dY, const T* __restrict__ X, - const T_ACC* __restrict__ mean, const T_ACC* __restrict__ rstd, - const T* __restrict__ weight, T* dX, const int slice_len, T_ACC* buf) { - const auto slice_id = blockIdx.x; - const T_ACC mean_val = mean[slice_id]; - const T_ACC rstd_val = rstd[slice_id]; - T_ACC stats_x1{0}, stats_x2{0}; - constexpr int unroll = 4; - auto l = unroll * threadIdx.x; - const T* X_i = X + slice_id * slice_len; - const T* dY_i = dY + slice_id * slice_len; - T* dX_i = dX + slice_id * slice_len; - // vectorized reads don't improve perf, so use regular unrolling - - for (; l + unroll - 1 < slice_len; l += blockDim.x * unroll) { -#pragma unroll - for (int k = 0; k < unroll; k++) { - T_ACC weight_val = - (weight != nullptr) ? static_cast(weight[l + k]) : T_ACC(1); - const T_ACC c_h = static_cast(X_i[l + k]); - const T_ACC c_loss = static_cast(dY_i[l + k]); - stats_x1 += c_loss * weight_val; - stats_x2 += c_loss * weight_val * (c_h - mean_val) * rstd_val; + T* X_data, T* weight_data, T* bias_data, T* Y_data, T_ACC* mean_data, + T_ACC* rstd_data, T_ACC eps, int64_t A, int64_t B, int64_t C, cudaStream_t stream) { + printf("Gpu general forward\n"); + forward_kernel + <<<1, 1, 0, stream>>>(X_data, weight_data, bias_data, Y_data, mean_data, + rstd_data, eps, A, B, C, stream); +} + +template +__global__ void backward_kernel( + const T* dY_data, const T* X_data, const T* weight_data, const T_ACC* mean_data, + const T_ACC* rstd_data, T* dX_data, T* dweight_data, T* dbias_data, int64_t A, + int64_t B, int64_t C, cudaStream_t stream) { + if (dweight_data || dbias_data) { + for (int64_t b = 0; b < B; ++b) { + dweight_data[b] = 0; + dbias_data[b] = 0; } - } - for (; l < slice_len; l++) { - T_ACC weight_val = - (weight != nullptr) ? static_cast(weight[l]) : T_ACC(1); - const T_ACC c_h = static_cast(X_i[l]); - const T_ACC c_loss = static_cast(dY_i[l]); - stats_x1 += c_loss * weight_val; - stats_x2 += c_loss * weight_val * (c_h - mean_val) * rstd_val; - } - stats_x1 = block_reduce_sum(stats_x1, buf); - stats_x2 = block_reduce_sum(stats_x2, buf); - if (threadIdx.x == 0) { - buf[0] = stats_x1; - buf[1] = stats_x2; - } - __syncthreads(); - stats_x1 = buf[0]; - stats_x2 = buf[1]; - T_ACC fH = slice_len; - T_ACC term1 = (T_ACC(1) / fH) * rstd_val; + for (int64_t a = 0; a < A; ++a) + for (int64_t c = 0; c < C; ++c) { + for (int64_t b = 0; b < B; ++b) { + dweight_data[b] += (X_data[a * B * C + b * C + c] - + mean_data[a * C + c]) * + rstd_data[a * C + c] * + dY_data[a * B * C + b * C + c]; - for (int l = threadIdx.x; l < slice_len; l += blockDim.x) { - const T_ACC x = X_i[l]; - const T_ACC dy = dY_i[l]; - T_ACC weight_val = - (weight != nullptr) ? static_cast(weight[l]) : T_ACC(1); - T_ACC f_grad_input = fH * weight_val * dy; - f_grad_input -= (x - mean_val) * rstd_val * stats_x2; - f_grad_input -= stats_x1; - f_grad_input *= term1; - dX_i[l] = f_grad_input; + dbias_data[b] += dY_data[a * B * C + b * C + c]; + } + } } -} - -template -__global__ void general_norm_grad_input_kernel( - const T* __restrict__ dY, const T* __restrict__ X, - const T_ACC* __restrict__ mean, const T_ACC* __restrict__ rstd, - const T* __restrict__ weight, T* dX, const int slice_len) { - alignas(sizeof(double)) extern __shared__ char s_data1[]; - T_ACC* buf = reinterpret_cast(&s_data1); - general_norm_grad_input_kernel_impl(dY, X, mean, rstd, weight, dX, slice_len, buf); -} + for (int64_t a = 0; a < A; ++a) + for (int64_t c = 0; c < C; ++c) { + T_ACC ds = static_cast(0.0f); + T_ACC db = static_cast(0.0f); + T_ACC atmp = static_cast(0.0f); + T_ACC btmp = static_cast(0.0f); + T_ACC ctmp = static_cast(0.0f); + + for (int64_t b = 0; b < B; ++b) { + auto value = X_data[a * B * C + b * C + c]; + auto dY_v = dY_data[a * B * C + b * C + c]; + auto weight_v = weight_data ? weight_data[b] : static_cast(1.0f); + db += dY_v * weight_v; + ds += dY_v * value * weight_v; + } -template -__global__ void general_norm_grad_weight_bias_simple_kernel( - int64_t slice_num, int64_t slice_len, const T* dY, const T* X, - const T_ACC* mean, const T_ACC* rstd, T* dweight, T* dbias) { - const int64_t j = blockIdx.x * blockDim.x + threadIdx.x; - if (j < slice_len) { - T_ACC sum1 = 0; - T_ACC sum2 = 0; - for (int64_t i = 0; i < slice_num; ++i) { - const int64_t index = i * slice_len + j; - sum1 += dweight == nullptr ? T_ACC(0) - : static_cast(dY[index]) * - (static_cast(X[index]) - - static_cast(mean[i])) * - static_cast(rstd[i]); - sum2 += dbias == nullptr ? T_ACC(0) : static_cast(dY[index]); - } - if (dweight != nullptr) { - dweight[j] = sum1; - } - if (dbias != nullptr) { - dbias[j] = sum2; - } - } -} + atmp = rstd_data[a * C + c]; + btmp = (db * mean_data[a * C + c] - ds) * atmp * atmp * atmp / B; + ctmp = -btmp * mean_data[a * C + c] - db * atmp / B; -template -__global__ void general_norm_grad_weight_bias_kernel( - int64_t slice_num, int64_t slice_len, const T* dY, const T* X, - const T_ACC* mean, const T_ACC* rstd, T* dweight, T* dbias) { - alignas(sizeof(double)) extern __shared__ char s_data1[]; - T_ACC* s_data_typed = reinterpret_cast(&s_data1); - const int64_t j = blockIdx.x * blockDim.x + threadIdx.x; - constexpr int unroll = 8; - T dYs[unroll]; - T Xs[unroll]; - T_ACC* means = s_data_typed; - T_ACC* rstds = s_data_typed + unroll * blockDim.y; - T_ACC dg_sum = 0; - T_ACC db_sum = 0; - if (j < slice_len) { - int bcounter; - for (bcounter = 0; bcounter < slice_num / (blockDim.y * unroll); bcounter++) { - int offset = (bcounter * blockDim.y + threadIdx.y) * unroll; -#pragma unroll - for (int ii = 0; ii < unroll; ii++) { - if (threadIdx.x == 0) { - means[ii * blockDim.y + threadIdx.y] = mean[offset + ii]; - rstds[ii * blockDim.y + threadIdx.y] = rstd[offset + ii]; - } - dYs[ii] = dY[(offset + ii) * slice_len + j]; - Xs[ii] = X[(offset + ii) * slice_len + j]; - } - __syncthreads(); -#pragma unroll - for (int ii = 0; ii < unroll; ii++) { - dg_sum += dYs[ii] * (Xs[ii] - means[ii * blockDim.y + threadIdx.y]) * - rstds[ii * blockDim.y + threadIdx.y]; - db_sum += dYs[ii]; + for (int64_t b = 0; b < B; b++) { + auto weight_v = weight_data ? weight_data[b] : static_cast(1.0f); + dX_data[a * B * C + b * C + c] = + dY_data[a * B * C + b * C + c] * atmp * weight_v + + X_data[a * B * C + b * C + c] * btmp + ctmp; } - __syncthreads(); } - int offset = (bcounter * blockDim.y + threadIdx.y) * unroll; - for (int ii = 0; ii < 8; ii++) { - T_ACC mean_val, rstd_val; // we don't use smem in the tail to avoid awkward - // synchronizations, perf penalty is negligible - if ((offset + ii) < slice_num) { - mean_val = mean[offset + ii]; - rstd_val = rstd[offset + ii]; - dYs[0] = dY[(offset + ii) * slice_len + j]; - Xs[0] = X[(offset + ii) * slice_len + j]; - dg_sum += dYs[0] * (Xs[0] - mean_val) * rstd_val; - db_sum += dYs[0]; - } - } - s_data_typed[threadIdx.y * blockDim.x + threadIdx.x] = dg_sum; - s_data_typed[blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x] = - db_sum; - __syncthreads(); - for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) { - if (threadIdx.y < offset) { - s_data_typed[threadIdx.y * blockDim.x + threadIdx.x] += - s_data_typed[(threadIdx.y + offset) * blockDim.x + threadIdx.x]; - s_data_typed - [blockDim.x * blockDim.y + threadIdx.y * blockDim.x + - threadIdx.x] += s_data_typed - [blockDim.x * blockDim.y + - (threadIdx.y + offset) * blockDim.x + threadIdx.x]; - } - __syncthreads(); - } - if (threadIdx.y == 0) { - if (dweight) { - dweight[j] = s_data_typed[threadIdx.x]; - } - if (dbias) { - dbias[j] = s_data_typed[threadIdx.x + blockDim.x * blockDim.y]; - } - } - } + + } -template +template void backward( - const T* dY_data, const T* X_data, const T_ACC* mean_data, - const T_ACC* rstd_data, const T* weight_data, int64_t slice_num, - int64_t slice_len, T* dX_data, T* dweight_data, T* dbias_data, - cudaStream_t stream) { - if (dX_data != nullptr) { - const int num_threads = 128; - const dim3 blocks(slice_num); - int nshared = (num_threads / WARP_SIZE) * sizeof(T_ACC); - general_norm_grad_input_kernel<<>>( - dY_data, X_data, mean_data, rstd_data, weight_data, dX_data, slice_len); - after_kernel_launch(); - } - if (dweight_data || dbias_data) { - if (slice_num < 512) { - const int64_t B = (slice_len + kCUDANumThreads - 1) / kCUDANumThreads; - general_norm_grad_weight_bias_simple_kernel - <<>>( - slice_num, slice_len, dY_data, X_data, mean_data, rstd_data, - dweight_data, dbias_data); - after_kernel_launch(); - } else { - dim3 threads{16, 32}; - int blocks = (slice_len + threads.x - 1) / threads.x; - general_norm_grad_weight_bias_kernel - <<>>( - slice_num, slice_len, dY_data, X_data, mean_data, rstd_data, - dweight_data, dbias_data); - after_kernel_launch(); - } - } + const T* dY_data, const T* X_data, const T* weight_data, const T_ACC* mean_data, + const T_ACC* rstd_data, T* dX_data, T* dweight_data, T* dbias_data, int64_t A, + int64_t B, int64_t C, cudaStream_t stream) { + backward_kernel + <<<1, 1, 0, stream>>>(dY_data, X_data, weight_data, mean_data, rstd_data, dX_data, + dweight_data, dbias_data, A, B, C, stream); } #define INST(T, T_ACC) \ template void forward( \ - T*, T*, T*, int64_t, int64_t, T_ACC, T*, T_ACC*, T_ACC*, cudaStream_t); \ + T*, T*, T*, T*,T_ACC*, T_ACC*, T_ACC, int64_t,int64_t,int64_t, cudaStream_t); \ template void backward( \ - const T*, const T*, const T_ACC*, const T_ACC*, const T*, int64_t, \ - int64_t, T*, T*, T*, cudaStream_t); + const T*, const T*, const T*, const T_ACC*, const T_ACC*, \ + T*, T*, T*, int64_t, int64_t, int64_t, cudaStream_t); INST(dt_float32, dt_float32) INST(dt_float16, dt_float32) diff --git a/dnn/src/cuda/general_norm/general_norm_cuda.cuh b/dnn/src/cuda/general_norm/general_norm_cuda.cuh index 1ee34d60cb0ae68440c2a3ee8b5777c5cbc9b4dd..d24d9d629cff80d0e5b83ad4a4432ba0ddc3dc85 100644 --- a/dnn/src/cuda/general_norm/general_norm_cuda.cuh +++ b/dnn/src/cuda/general_norm/general_norm_cuda.cuh @@ -7,14 +7,15 @@ namespace general_norm { template void forward( - T* X, T* gamma, T* beta, int64_t M, int64_t N, T_ACC eps, T* Y, T_ACC* mean, - T_ACC* rstd, cudaStream_t stream); + T* X_data, T* weight_data, T* bias_data, T* Y_data, T_ACC* mean_data, + T_ACC* rstd_data, T_ACC eps, int64_t A, int64_t B, int64_t C, + cudaStream_t stream); template void backward( - const T* dY_data, const T* X_data, const T_ACC* mean_data, - const T_ACC* rstd_data, const T* gamma_data, int64_t M, int64_t N, T* dX_data, - T* dgamma_data, T* dbeta_data, cudaStream_t stream); + const T* dY_data, const T* X_data, const T* gamma_data, const T_ACC* mean_data, + const T_ACC* rstd_data, T* dX_data, T* dgamma_data, + T* dbeta_data, int64_t A, int64_t B, int64_t C, cudaStream_t stream); } // namespace general_norm } // namespace cuda diff --git a/dnn/src/cuda/general_norm/opr_impl.cpp b/dnn/src/cuda/general_norm/opr_impl.cpp index 585cd7fc2b24bcc04df5f3849b4d2038c699d1af..13ab71357a232c86630a214b484de9f87b0ec793 100644 --- a/dnn/src/cuda/general_norm/opr_impl.cpp +++ b/dnn/src/cuda/general_norm/opr_impl.cpp @@ -16,26 +16,23 @@ void GeneralNormForwardImpl::exec( auto p = param(); float eps = p.eps; bool affine = p.affine; - uint64_t slice_length = p.normalized_size; - uint64_t slice_dim = p.normalized_dim; - uint64_t n_slices = 1; - for (size_t i = 0; i < data.layout.ndim - slice_dim; ++i) { - n_slices = n_slices * data.layout.shape[i]; - } + uint64_t axis = p.normalized_axis; + uint64_t A, B, C; + megdnn::reduce::get_ABC(data.layout, A, B, C, axis); auto stream = cuda_stream(handle()); using namespace ::megdnn::cuda::general_norm; -#define cb(DType) \ - if (data.layout.dtype == DType()) { \ - using T = typename DTypeTrait::ctype; \ - using T_ACC = float; \ - forward( \ - data.ptr(), affine ? weight.ptr() : nullptr, \ - affine ? bias.ptr() : nullptr, static_cast(n_slices), \ - static_cast(slice_length), static_cast(eps), \ - dst.ptr(), mean.ptr(), rstd.ptr(), stream); \ - return; \ +#define cb(DType) \ + if (data.layout.dtype == DType()) { \ + using T = typename DTypeTrait::ctype; \ + using T_ACC = float; \ + forward( \ + data.ptr(), affine ? weight.ptr() : nullptr, \ + affine ? bias.ptr() : nullptr, dst.ptr(), mean.ptr(), \ + rstd.ptr(), static_cast(eps), A, B, \ + C, stream); \ + return; \ } MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) #undef cb @@ -52,25 +49,24 @@ void GeneralNormBackwardImpl::exec( ddata.layout, dweight.layout, dbias.layout, workspace.size); auto p = param(); bool affine = p.affine; - uint64_t slice_length = p.normalized_size; - uint64_t slice_dim = p.normalized_dim; - uint64_t n_slices = 1; - for (size_t i = 0; i < data.layout.ndim - slice_dim; ++i) { - n_slices = n_slices * data.layout.shape[i]; - } + uint64_t axis = p.normalized_axis; + uint64_t A, B, C; + megdnn::reduce::get_ABC(data.layout, A, B, C, axis); auto stream = cuda_stream(handle()); using namespace ::megdnn::cuda::general_norm; -#define cb(DType) \ - if (data.layout.dtype == DType()) { \ - using T = typename DTypeTrait::ctype; \ - using T_ACC = float; \ - backward( \ - diff.ptr(), data.ptr(), mean.ptr(), rstd.ptr(), \ - affine ? weight.ptr() : nullptr, n_slices, slice_length, \ - ddata.ptr(), affine ? dweight.ptr() : nullptr, \ - affine ? dbias.ptr() : nullptr, stream); \ - return; \ +#define cb(DType) \ + if (data.layout.dtype == DType()) { \ + using T = typename DTypeTrait::ctype; \ + using T_ACC = float; \ + backward( \ + diff.ptr(), data.ptr(), affine ? weight.ptr() : nullptr, \ + mean.ptr(), rstd.ptr(), \ + ddata.ptr(), \ + affine ? dweight.ptr() : nullptr, \ + affine ? dbias.ptr() : nullptr, A, B, C, \ + stream); \ + return; \ } MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) #undef cb diff --git a/dnn/src/cuda/general_norm/opr_impl.h b/dnn/src/cuda/general_norm/opr_impl.h index a9bc58201b36c0cac40acf1aa977216ea5100d37..91e571ea9231d59e13f04ae5f61648882e65ff76 100644 --- a/dnn/src/cuda/general_norm/opr_impl.h +++ b/dnn/src/cuda/general_norm/opr_impl.h @@ -1,5 +1,6 @@ #pragma once #include "megdnn/oprs.h" +#include "src/common/reduce_helper.h" #include "src/cuda/cudnn_wrapper.h" diff --git a/dnn/src/naive/general_norm/opr_impl.cpp b/dnn/src/naive/general_norm/opr_impl.cpp index 419e9801fd36ec9163f47e7838c814bedcbe6329..7fac59d93fc31eddf319ed5b4b4d2b3412fecf90 100644 --- a/dnn/src/naive/general_norm/opr_impl.cpp +++ b/dnn/src/naive/general_norm/opr_impl.cpp @@ -16,7 +16,7 @@ void forward( _megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias, _megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd, const Param& param) { - printf("general forward\n"); + printf("Cpu general forward\n"); float eps = param.eps; bool affine = param.affine; uint64_t axis = param.normalized_axis; @@ -105,7 +105,7 @@ void backward( btmp = (db * mean.ptr()[a * C + c] - ds) * atmp * atmp * atmp / B; ctmp = -btmp * mean.ptr()[a * C + c] - db * atmp / B; - for (uint64_t b = 0; b < B; b++) { + for (size_t b = 0; b < B; b++) { auto weight_v = affine ? weight.ptr()[b] : static_cast(1.0f); ddata.ptr()[a * B * C + b * C + c] = diff.ptr()[a * B * C + b * C + c] * atmp * weight_v + diff --git a/dnn/test/cuda/general_norm.cpp b/dnn/test/cuda/general_norm.cpp index 346aa0dfe0208b4e025f65ce1588911e76e97d5f..1e83338168025ea4cdf25e1299b889bfc87c1516 100644 --- a/dnn/test/cuda/general_norm.cpp +++ b/dnn/test/cuda/general_norm.cpp @@ -1,23 +1,37 @@ #include "test/cuda/fixture.h" #include "test/common/checker.h" +#include "test/cuda/benchmark.h" namespace megdnn { namespace test { -TEST_F(CUDA, GeneralNorm_FORWARD) { +TEST_F(CUDA, GENERALNORM_FORWARD) { using Param = GeneralNormForward::Param; Param param; param.affine = true; param.eps = 1e-6; - param.normalized_dim = 1; Checker checker(handle_cuda()); checker.set_epsilon(1e-2); auto run = [&](DType d) { for (size_t n_slices : {10, 30}) for (size_t slice_len : {10, 30}) { - param.normalized_size = slice_len; + param.normalized_axis = 0; + checker.set_param(param) + .set_dtype(0, d) + .set_dtype(1, d) + .set_dtype(2, d) + .set_dtype(3, d) + .set_dtype(4, dtype::Float32()) + .set_dtype(5, dtype::Float32()) + .execs({{n_slices, slice_len}, + {n_slices}, + {n_slices}, + {n_slices, slice_len}, + {slice_len}, + {slice_len}}); + param.normalized_axis = 1; checker.set_param(param) .set_dtype(0, d) .set_dtype(1, d) @@ -39,19 +53,76 @@ TEST_F(CUDA, GeneralNorm_FORWARD) { run(dtype::BFloat16()); } -TEST_F(CUDA, GeneralNorm_BACKWARD) { +TEST_F(CUDA, GENERALNORM_SPEED_FP32) { + using Param = GeneralNormForward::Param; + auto benchmarker = Benchmarker(handle_cuda()); + benchmarker.set_dtype(0, dtype::Float32()); + benchmarker.set_dtype(1, dtype::Float32()); + Param param; + param.affine = true; + float eachTime; + float totalTime = 0.f; + +#define ITER 10 + param.normalized_axis = 0; + for (auto i = 0; i < ITER; i++) { + eachTime = benchmarker.set_param(param).exec({{100, 2000}, + {100}, + {100}, + {}, + {}, + {}}); + totalTime += eachTime; + } + totalTime /= ITER; + printf("PGENERALNORM_SPEED_FP32 AVG TIME: %.6fms\n", totalTime); + + totalTime = 0.f; + param.normalized_axis = 1; + for (auto i = 0; i < ITER; i++) { + eachTime = benchmarker.set_param(param).exec({{2000, 100}, + {100}, + {100}, + {}, + {}, + {}}); + totalTime += eachTime; + } + totalTime /= ITER; + printf("PGENERALNORM_SPEED_FP32 AVG TIME: %.6fms\n", totalTime); +#undef ITER +} + +TEST_F(CUDA, GENERALNORM_BACKWARD) { using Param = GeneralNormBackward::Param; Param param; param.affine = true; param.eps = 1e-6; - param.normalized_dim = 1; Checker checker(handle_cuda()); checker.set_epsilon(1e-1); auto run = [&](DType d) { for (size_t n_slices : {10, 30}) for (size_t slice_len : {10, 30}) { - param.normalized_size = slice_len; + param.normalized_axis = 0; + checker.set_param(param) + .set_dtype(0, d) + .set_dtype(1, d) + .set_dtype(2, d) + .set_dtype(3, dtype::Float32()) + .set_dtype(4, dtype::Float32()) + .set_dtype(5, d) + .set_dtype(6, d) + .set_dtype(7, d) + .execs({{n_slices, slice_len}, + {n_slices, slice_len}, + {n_slices}, + {slice_len}, + {slice_len}, + {n_slices, slice_len}, + {n_slices}, + {n_slices}}); + param.normalized_axis = 1; checker.set_param(param) .set_dtype(0, d) .set_dtype(1, d) @@ -68,7 +139,7 @@ TEST_F(CUDA, GeneralNorm_BACKWARD) { {n_slices}, {n_slices, slice_len}, {slice_len}, - {slice_len}}); + {slice_len}}); } }; diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index eceaf87a757834118d2dfffec9d4aaa27a6f8264..aac6d616458fa3ad7ff2a678797fad1f74370435 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -1136,7 +1136,6 @@ def layer_norm( def general_norm( inp: Tensor, - normalized_shape: tuple, normalized_axis: int, affine: bool, weight: Optional[Tensor] = None, @@ -1158,21 +1157,11 @@ def general_norm( See :math:`\beta` in :class:`~.GeneralNorm`. eps: a value added to the denominator for numerical stability. Default: 1e-5 """ - if isinstance(normalized_shape, int): - normalized_shape = [normalized_shape] - - normalized_dim = len(normalized_shape) - assert normalized_dim > 0 - - normalized_size = 1 - for i in range(normalized_dim): - normalized_size = normalized_size * normalized_shape[i] + assert normalized_axis >= 0 and normalized_axis < inp.ndim op = builtin.GeneralNorm( affine=affine, eps=eps, - normalized_dim=normalized_dim, - normalized_size=normalized_size, normalized_axis = normalized_axis, ) if affine: diff --git a/imperative/python/megengine/module/normalization.py b/imperative/python/megengine/module/normalization.py index efcfb37978e5c32f511a883eb534b33ecfd8a29b..e40c2c83ded90f884d8f6902520117aaa5051c46 100644 --- a/imperative/python/megengine/module/normalization.py +++ b/imperative/python/megengine/module/normalization.py @@ -231,7 +231,7 @@ class GeneralNorm(Module): (2, 3, 4, 4) """ - def __init__(self, normalized_shape, normalized_axis, eps=1e-05, affine=True, **kwargs): + def __init__(self, inp_shape, normalized_axis, eps=1e-05, affine=True, **kwargs): super().__init__(**kwargs) if isinstance(normalized_shape, int): normalized_shape = (normalized_shape,) @@ -241,9 +241,9 @@ class GeneralNorm(Module): self.affine = affine if self.affine: self.weight = Parameter( - np.ones(self.normalized_shape, dtype="float32")) + np.ones(inp_shape[normalized_axis], dtype="float32")) self.bias = Parameter( - np.zeros(self.normalized_shape, dtype="float32")) + np.zeros(inp_shape[normalized_axis], dtype="float32")) else: self.weight = None self.bias = None @@ -257,10 +257,10 @@ class GeneralNorm(Module): def forward(self, x): x = F.nn.general_norm( - x, self.normalized_shape, self.normalized_axis, self.affine, self.weight, self.bias, self.eps + x, self.normalized_axis, self.affine, self.weight, self.bias, self.eps ) return x def _module_info_string(self) -> str: - s = "normalized_shape={normalized_shape}, normalized_axis={normalized_axis}, eps={eps}, affine={affine}" + s = "normalized_axis={normalized_axis}, eps={eps}, affine={affine}" return s.format(**self.__dict__) diff --git a/src/opr/impl/dnn/general_norm.cpp b/src/opr/impl/dnn/general_norm.cpp index cfe42b3dc3cfce48a893b2ec5163f1ae8c6705eb..00cfbc9dc4a5907b5efce66738d1a39663820dfc 100644 --- a/src/opr/impl/dnn/general_norm.cpp +++ b/src/opr/impl/dnn/general_norm.cpp @@ -66,7 +66,6 @@ SymbolVarArray GeneralNormForward::make( void GeneralNormForward::get_output_var_shape( const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const { - uint64_t normalized_dim = param().normalized_dim; out_shape[0] = inp_shape[0]; TensorShape unnormalized_shape = inp_shape[0]; unnormalized_shape.ndim -= 1; diff --git a/src/opr/test/dnn/general_norm.cpp b/src/opr/test/dnn/general_norm.cpp index bbaf1f3244bb97f8e83f6b0ac7fd0f51329f35eb..0bc78e459981dbc47d0ef8439d7be02676df4fde 100644 --- a/src/opr/test/dnn/general_norm.cpp +++ b/src/opr/test/dnn/general_norm.cpp @@ -23,8 +23,6 @@ void run_forward(bool is_affine, size_t normalized_size, size_t normalized_axis) Param param; param.eps = 1e-5; param.affine = is_affine; - param.normalized_dim = 1; - param.normalized_size = normalized_size; param.normalized_axis = normalized_axis; auto make_graph = [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {