提交 b1aa065e 编写于 作者: S ScXfjiang

more refine


Former-commit-id: 9b8869f3075cbf0b4d77113fe26a18fccbfe4644
上级 35a06bf6
......@@ -11,15 +11,7 @@ void ForwardPartDataContentTopOne(const T* in, const Range& range, const int32_t
int32_t* out) {
FOR_RANGE(int32_t, i, range.begin(), range.end()) {
const T* values = in + i * instance_size;
T max_val = values[0];
int32_t max_idx = 0;
FOR_RANGE(int32_t, j, 0, instance_size) {
if (values[j] > max_val) {
max_val = values[j];
max_idx = j;
}
}
out[i] = max_idx;
out[i] = std::distance(values, std::max_element(values, values + instance_size));
}
}
......
......@@ -14,7 +14,7 @@ __global__ void ForwardGpuTopOne(const T* in, const int32_t instance_num,
const T* values = in + i * instance_size;
T max_val = values[0];
int32_t max_idx = 0;
FOR_RANGE(int32_t, j, 0, instance_size) {
FOR_RANGE(int32_t, j, 1, instance_size) {
if (values[j] > max_val) {
max_val = values[j];
max_idx = j;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册