提交 42dd5da0 作者: Markus Kliegl

conv shift: fix early return before __syncthreads() (all threads in the block must reach the barrier)

上级 3dc88342
......@@ -62,19 +62,19 @@ __global__ void ConvShiftForward(const T *x, const T *y, T *out, int x_width,
if (tx < num_x) {
int load_i = (i - y_half_width + x_width) % x_width;
sx[tx] = x[k * x_width + load_i];
} else {
return;
}
__syncthreads();
// Compute dot product of sx[tx:tx + y_width] and sy.
T sum = 0;
for (int j = 0; j < y_width; ++j) {
sum += sx[tx + j] * sy[j];
}
if (tx < num_x) {
// Compute dot product of sx[tx:tx + y_width] and sy.
T sum = 0;
for (int j = 0; j < y_width; ++j) {
sum += sx[tx + j] * sy[j];
}
// Save to out[k, i].
out[k * x_width + i] = sum;
// Save to out[k, i].
out[k * x_width + i] = sum;
}
}
// Compute x gradient - initial naive implementation with atomic add.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册