未验证 提交 a040c055 编写于 作者: L Leo Chen 提交者: GitHub

fix layer_norm accuracy (#29434)

上级 24ba9ed4
...@@ -135,7 +135,7 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias, ...@@ -135,7 +135,7 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
} }
__syncthreads(); __syncthreads();
mean_val = mean[blockIdx.x]; mean_val = mean[blockIdx.x];
var_val = static_cast<U>(real_sqrt(var[blockIdx.x]) + epsilon); var_val = static_cast<U>(real_sqrt(var[blockIdx.x] + epsilon));
// Step 2: Calculate y // Step 2: Calculate y
if (scale != nullptr) { if (scale != nullptr) {
......
...@@ -211,7 +211,7 @@ class TestLayerNormOp(unittest.TestCase): ...@@ -211,7 +211,7 @@ class TestLayerNormOp(unittest.TestCase):
for name in ['x', 'scale', 'bias', 'y@GRAD'] for name in ['x', 'scale', 'bias', 'y@GRAD']
}, },
fetch_list=fetch_list) fetch_list=fetch_list)
self.__assert_close(y, out[0], "y", 1e-3) self.__assert_close(y, out[0], "y")
self.__assert_close(mean, out[1], "mean") self.__assert_close(mean, out[1], "mean")
self.__assert_close(variance, out[2], "variance", 1e-3) self.__assert_close(variance, out[2], "variance", 1e-3)
self.__assert_close(x_grad, out[3], "x_grad") self.__assert_close(x_grad, out[3], "x_grad")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册