未验证 提交 8f2656ef 作者: wawltor 提交者: GitHub

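Fix the gradient bug in top_k_v2

Fix the gradient bug in the top_k_v2 CUDA kernels: add a `__syncthreads()` barrier between zero-filling the gradient buffer (`x_grad` / `grad_in`) and scattering the top-k gradients into it in `AssignGrad` and `AssignGradWithAxis`, and make `AssignGradWithAxis` read `grad_out[base_index + j]` instead of `grad_out[idx_ij]`.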
上级 a972c33f
...@@ -335,6 +335,7 @@ __global__ void AssignGrad(T* x_grad, const int64_t* indices, const T* out_grad, ...@@ -335,6 +335,7 @@ __global__ void AssignGrad(T* x_grad, const int64_t* indices, const T* out_grad,
for (size_t j = 0; j < cols; ++j) { for (size_t j = 0; j < cols; ++j) {
x_grad[i * cols + j] = 0; x_grad[i * cols + j] = 0;
} }
__syncthreads();
for (size_t j = 0; j < k; ++j) { for (size_t j = 0; j < k; ++j) {
size_t idx = indices[i * k + j]; size_t idx = indices[i * k + j];
x_grad[i * cols + idx] = out_grad[i * k + j]; x_grad[i * cols + idx] = out_grad[i * k + j];
...@@ -349,15 +350,16 @@ __global__ void AssignGradWithAxis(const T* grad_out, const int64_t* indices, ...@@ -349,15 +350,16 @@ __global__ void AssignGradWithAxis(const T* grad_out, const int64_t* indices,
int raw_height, int k) { int raw_height, int k) {
// raw_height is the length of topk axis // raw_height is the length of topk axis
for (int i = blockIdx.x; i < pre; i += gridDim.x) { for (int i = blockIdx.x; i < pre; i += gridDim.x) {
const int& base_index = i * post * k; int base_index = i * post * k;
const int& base_grad = i * post * raw_height; int base_grad = i * post * raw_height;
for (int j = threadIdx.x; j < raw_height * post; j += blockDim.x) { for (int j = threadIdx.x; j < raw_height * post; j += blockDim.x) {
grad_in[base_grad + j] = static_cast<T>(0); grad_in[base_grad + j] = static_cast<T>(0);
} }
__syncthreads();
for (int j = threadIdx.x; j < k * post; j += blockDim.x) { for (int j = threadIdx.x; j < k * post; j += blockDim.x) {
const int64_t idx_ij = indices[base_index + j]; int64_t idx_ij = indices[base_index + j];
const int64_t in_ij = base_grad + (idx_ij * post) + (j % post); int64_t in_ij = base_grad + (idx_ij * post) + (j % post);
grad_in[in_ij] = grad_out[idx_ij]; grad_in[in_ij] = grad_out[base_index + j];
} }
} }
} }
......
...@@ -64,34 +64,38 @@ class TestTopkOp(OpTest): ...@@ -64,34 +64,38 @@ class TestTopkOp(OpTest):
class TestTopkOp1(TestTopkOp): class TestTopkOp1(TestTopkOp):
def init_args(self):
self.k = 3
self.axis = 0
self.largest = True
class TestTopkOp2(TestTopkOp):
def init_args(self): def init_args(self):
self.k = 3 self.k = 3
self.axis = 0 self.axis = 0
self.largest = False self.largest = False
class TestTopkOp3(TestTopkOp): class TestTopkOp2(TestTopkOp):
def init_args(self): def init_args(self):
self.k = 4 self.k = 4
self.axis = 0 self.axis = 0
self.largest = False self.largest = False
class TestTopkOp4(TestTopkOp): class TestTopkOp3(OpTest):
def init_args(self): def init_args(self):
self.k = 4 self.k = 6
self.axis = 0 self.axis = 1
self.largest = False self.largest = True
def setUp(self):
self.op_type = "top_k_v2"
self.dtype = np.float64
self.input_data = np.random.rand(16, 100)
self.init_args()
self.inputs = {'X': self.input_data}
self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}
output, indices = numpy_topk(
self.input_data, axis=self.axis, k=self.k, largest=self.largest)
self.outputs = {'Out': output, 'Indices': indices}
class TestTopkOp5(TestTopkOp):
class TestTopkOp4(TestTopkOp):
def init_args(self): def init_args(self):
self.k = 3 self.k = 3
self.axis = 1 self.axis = 1
...@@ -109,7 +113,7 @@ class TestTopkOp5(TestTopkOp): ...@@ -109,7 +113,7 @@ class TestTopkOp5(TestTopkOp):
self.outputs = {'Out': output, 'Indices': indices} self.outputs = {'Out': output, 'Indices': indices}
class TestTopkOp6(TestTopkOp): class TestTopkOp5(TestTopkOp):
def init_args(self): def init_args(self):
self.k = 3 self.k = 3
self.axis = 1 self.axis = 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册