未验证 提交 8f2656ef 编写于 作者: W wawltor 提交者: GitHub

fix the gradient bug for the topk v2

fix the gradient bug for the topk v2 
上级 a972c33f
......@@ -335,6 +335,7 @@ __global__ void AssignGrad(T* x_grad, const int64_t* indices, const T* out_grad,
for (size_t j = 0; j < cols; ++j) {
x_grad[i * cols + j] = 0;
}
__syncthreads();
for (size_t j = 0; j < k; ++j) {
size_t idx = indices[i * k + j];
x_grad[i * cols + idx] = out_grad[i * k + j];
......@@ -349,15 +350,16 @@ __global__ void AssignGradWithAxis(const T* grad_out, const int64_t* indices,
int raw_height, int k) {
// raw_height is the length of topk axis
for (int i = blockIdx.x; i < pre; i += gridDim.x) {
const int& base_index = i * post * k;
const int& base_grad = i * post * raw_height;
int base_index = i * post * k;
int base_grad = i * post * raw_height;
for (int j = threadIdx.x; j < raw_height * post; j += blockDim.x) {
grad_in[base_grad + j] = static_cast<T>(0);
}
__syncthreads();
for (int j = threadIdx.x; j < k * post; j += blockDim.x) {
const int64_t idx_ij = indices[base_index + j];
const int64_t in_ij = base_grad + (idx_ij * post) + (j % post);
grad_in[in_ij] = grad_out[idx_ij];
int64_t idx_ij = indices[base_index + j];
int64_t in_ij = base_grad + (idx_ij * post) + (j % post);
grad_in[in_ij] = grad_out[base_index + j];
}
}
}
......
......@@ -64,34 +64,38 @@ class TestTopkOp(OpTest):
class TestTopkOp1(TestTopkOp):
def init_args(self):
self.k = 3
self.axis = 0
self.largest = True
class TestTopkOp2(TestTopkOp):
def init_args(self):
self.k = 3
self.axis = 0
self.largest = False
class TestTopkOp3(TestTopkOp):
class TestTopkOp2(TestTopkOp):
def init_args(self):
self.k = 4
self.axis = 0
self.largest = False
class TestTopkOp4(TestTopkOp):
class TestTopkOp3(OpTest):
def init_args(self):
self.k = 4
self.axis = 0
self.largest = False
self.k = 6
self.axis = 1
self.largest = True
def setUp(self):
self.op_type = "top_k_v2"
self.dtype = np.float64
self.input_data = np.random.rand(16, 100)
self.init_args()
self.inputs = {'X': self.input_data}
self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}
output, indices = numpy_topk(
self.input_data, axis=self.axis, k=self.k, largest=self.largest)
self.outputs = {'Out': output, 'Indices': indices}
class TestTopkOp5(TestTopkOp):
class TestTopkOp4(TestTopkOp):
def init_args(self):
self.k = 3
self.axis = 1
......@@ -109,7 +113,7 @@ class TestTopkOp5(TestTopkOp):
self.outputs = {'Out': output, 'Indices': indices}
class TestTopkOp6(TestTopkOp):
class TestTopkOp5(TestTopkOp):
def init_args(self):
self.k = 3
self.axis = 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册