未验证 提交 dca56f47 编写于 作者: Z Zhong Hui 提交者: GitHub

fix invalid read of pnorm gradient function

fix invalid read of pnorm gradient function and delete the unused code
上级 2c9d0f3c
......@@ -99,39 +99,25 @@ __global__ void PnormGradient(const T* x, const T* x_norm, const T* y_grad,
const float porder, const int pre,
const int axis_n, const int post, const T eps,
T* x_grad) {
typedef cub::BlockReduce<T, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage_sum;
// dx = (x/pnorm_broadcast).pow(p-1) * norm_dy.broadcast * sign(x)
int num = pre * post;
auto porder_grad = static_cast<T>(porder - 1.0f);
for (int i = blockIdx.x; i < num; i += gridDim.x) {
T sum = 0.0;
__shared__ T row_sum;
__shared__ T row_sqrt_norm;
__shared__ T row_norm;
__shared__ T pnorm_i;
__shared__ T yout_i;
auto base = (i / post) * post * axis_n + (i % post);
for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
int index = base + j * post;
sum += x[index] * y_grad[index];
}
T reduce_result = BlockReduce(temp_storage_sum).Sum(sum);
if (threadIdx.x == 0) {
row_sum = reduce_result;
row_sqrt_norm = x_norm[i];
row_norm = row_sqrt_norm * row_sqrt_norm;
pnorm_i = x_norm[i];
yout_i = y_grad[i];
}
__syncthreads();
const T pnorm_i = x_norm[i];
const T yout_i = y_grad[i];
__syncthreads();
for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
int index = base + j * post;
const T x_ij = inline_abs(x[index]);
const T dy_ij = y_grad[index];
x_grad[index] = inline_pow(x_ij, porder_grad) /
(inline_pow(pnorm_i, porder_grad) + eps) * yout_i *
inline_sign(x[index]);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册