提交 f57d706a 编写于 作者: Yu Yang

Use double-precision accumulators for the LayerNorm mean/variance block reduction

上级 f94fdeaa
......@@ -67,27 +67,27 @@ template <typename T, int BlockDim>
__global__ void LayerNormForward(const T *x, const T *scale, const T *bias,
T *y, T *mean, T *var, float epsilon,
int feature_size) {
using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>;
using BlockReduce = cub::BlockReduce<PairForLayerNorm<double>, BlockDim>;
__shared__ typename BlockReduce::TempStorage temp_storage;
int beg_idx = blockIdx.x * feature_size + threadIdx.x;
int end_idx = (blockIdx.x + 1) * feature_size;
// Step 1: Reduce to calculate mean and var
T mean_val = static_cast<T>(0);
T var_val = static_cast<T>(0);
double mean_val = 0;
double var_val = 0;
for (int i = beg_idx; i < end_idx; i += BlockDim) {
T tmp = x[i];
mean_val += tmp;
var_val += (tmp * tmp);
}
auto pair = BlockReduce(temp_storage)
.Reduce(PairForLayerNorm<T>(mean_val, var_val),
PairForLayerNormAddFunctor<T>());
.Reduce(PairForLayerNorm<double>(mean_val, var_val),
PairForLayerNormAddFunctor<double>());
if (threadIdx.x == 0) {
auto tmp = pair.first_ / feature_size;
mean[blockIdx.x] = tmp;
var[blockIdx.x] = pair.second_ / feature_size - tmp * tmp;
mean[blockIdx.x] = static_cast<T>(tmp);
var[blockIdx.x] = static_cast<T>(pair.second_ / feature_size - tmp * tmp);
}
__syncthreads();
mean_val = mean[blockIdx.x];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册