提交 ddfe95a9 编写于 作者: S ScXfjiang

Add thread-pool version of the TopK kernel forward pass

Former-commit-id: ce92c2d972aa74760351052e9d89cf6e3c40282c
上级 69c9ee85
#include "oneflow/core/kernel/top_k_kernel.h"
#include "oneflow/core/common/balanced_splitter.h"
#include "oneflow/core/thread/thread_manager.h"
namespace oneflow {
namespace {
// Computes the top-k indices for every instance whose id lies in `range`.
//
// For each instance i in [range.begin(), range.end()):
//   * the instance's values are read from in[i * instance_size ...),
//   * the slice fw_buf[i * instance_size ...) is used as scratch space holding
//     the candidate indices 0..instance_size-1,
//   * the k selected indices are written to out[i * k ...).
//
// Indices are ordered by descending value; equal values are ordered by
// ascending index so the result is deterministic. When `sorted` is false only
// the membership of the first k indices is guaranteed, not their order.
//
// Because every instance touches only its own slices of fw_buf and out,
// disjoint ranges do not write to overlapping memory.
//
// Preconditions visible from the indexing below: 0 <= k <= instance_size, and
// fw_buf / in hold at least range.end() * instance_size elements.
template<typename T>
void ForwardPartDataContent(const T* in, const Range range, const int32_t instance_size,
                            const int32_t k, const bool sorted, int32_t* fw_buf, int32_t* out) {
  // 64-bit loop index so that i * instance_size and i * k cannot overflow
  // 32-bit arithmetic on large inputs.
  FOR_RANGE(int64_t, i, range.begin(), range.end()) {
    const int64_t offset = i * instance_size;
    int32_t* indices = fw_buf + offset;
    const T* values = in + offset;
    std::iota(indices, indices + instance_size, 0);
    // Strict weak ordering: larger value first, ties broken by smaller index.
    auto comp = [values](const int32_t lhs, const int32_t rhs) {
      const T l = values[lhs];
      const T r = values[rhs];
      if (l == r) {
        return lhs < rhs;
      } else {
        return l > r;
      }
    };
    // Partition so that the first k indices refer to the k largest values.
    std::nth_element(indices, indices + k, indices + instance_size, comp);
    // nth_element leaves the first k entries unordered; sort them on request.
    if (k > 1 && sorted) { std::sort(indices, indices + k, comp); }
    std::copy(indices, indices + k, out + i * k);
  }
}
} // namespace
template<typename T>
void TopKKernel<T>::ForwardDataContent(const KernelCtx& ctx,
std::function<Blob*(const std::string&)> BnInOp2Blob) const {
......@@ -17,22 +46,18 @@ void TopKKernel<T>::ForwardDataContent(const KernelCtx& ctx,
int32_t* out = out_blob->mut_dptr<int32_t>();
const auto& conf = this->op_conf().top_k_conf();
const int32_t k = conf.k();
FOR_RANGE(int32_t, i, 0, instance_num) {
std::iota(fw_buf, fw_buf + instance_size, 0);
const int32_t offset = i * instance_size;
auto comp = [&](const int32_t lhs, const int32_t rhs) {
const T l = in[offset + lhs];
const T r = in[offset + rhs];
if (l == r) {
return lhs < rhs;
} else {
return l > r;
std::nth_element(fw_buf, fw_buf + k, fw_buf + instance_size, comp);
if (k > 1 && conf.sorted()) { std::sort(fw_buf, fw_buf + k, comp); }
std::copy(fw_buf, fw_buf + k, out + i * k);
const int32_t part_num =
std::min(instance_num, Global<ThreadMgr>::Get()->compute_thread_pool()->thread_num());
const BalancedSplitter bs(instance_num, part_num);
BlockingCounter bc(part_num);
FOR_RANGE(int32_t, part_id, 0, part_num) {
Range range = bs.At(part_id);
Global<ThreadMgr>::Get()->compute_thread_pool()->AddWork([=, &bc]() {
ForwardPartDataContent(in, range, instance_size, k, conf.sorted(), fw_buf, out);
......@@ -18,7 +18,7 @@ void TopKOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlob
CHECK_LE(conf.k(), in->shape().dim_vec().back());
// fw_buf
BlobDesc* fw_buf = GetBlobDesc4BnInOp("fw_buf");
fw_buf->mut_shape() = Shape({in->shape().dim_vec().back()});
fw_buf->mut_shape() = Shape({in->shape()});
// out
BlobDesc* out = GetBlobDesc4BnInOp("out");
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
想要评论请 注册