提交 5bfbd60f 编写于 作者: S ScXfjiang

correctness


Former-commit-id: d1eba08d1903fceb50bf9e0f77bf4548b2ddedb4
上级 31c8f14d
......@@ -7,8 +7,27 @@ namespace oneflow {
namespace {
template<typename T>
void ForwardPartDataContent(const T* in, const Range range, const int32_t instance_size,
const int32_t k, const bool sorted, int32_t* fw_buf, int32_t* out) {
void ForwardPartDataContentTopOne(const T* in, const Range range, const int32_t instance_size,
int32_t* out) {
FOR_RANGE(int32_t, i, range.begin(), range.end()) {
const int32_t offset = i * instance_size;
const T* values = in + offset;
T max_val = GetMinVal<T>();
int32_t max_idx = -1;
FOR_RANGE(int32_t, j, 0, instance_size) {
if (values[j] > max_val) {
max_val = values[j];
max_idx = j;
}
}
out[i] = max_idx;
}
}
template<typename T>
void ForwardPartDataContentTopK(const T* in, const Range range, const int32_t instance_size,
const int32_t k, const bool sorted, int32_t* fw_buf, int32_t* out) {
CHECK_NOTNULL(fw_buf);
FOR_RANGE(int32_t, i, range.begin(), range.end()) {
const int32_t offset = i * instance_size;
int32_t* indices = fw_buf + offset;
......@@ -24,7 +43,7 @@ void ForwardPartDataContent(const T* in, const Range range, const int32_t instan
}
};
std::nth_element(indices, indices + k, indices + instance_size, comp);
if (k > 1 && sorted) { std::sort(indices, indices + k, comp); }
if (sorted) { std::sort(indices, indices + k, comp); }
std::copy(indices, indices + k, out + i * k);
}
}
......@@ -43,7 +62,11 @@ struct TopKKernelUtil<DeviceType::kCPU, T> {
FOR_RANGE(int32_t, part_id, 0, part_num) {
Range range = bs.At(part_id);
Global<ThreadMgr>::Get()->compute_thread_pool()->AddWork([=, &bc]() {
ForwardPartDataContent(in, range, instance_size, k, sorted, fw_buf, out);
if (k == 1) {
ForwardPartDataContentTopOne(in, range, instance_size, out);
} else {
ForwardPartDataContentTopK(in, range, instance_size, k, sorted, fw_buf, out);
}
bc.Decrease();
});
}
......@@ -62,7 +85,7 @@ void TopKKernel<device_type, T>::ForwardDataContent(
const int32_t instance_size = static_cast<int32_t>(in_blob->shape().dim_vec().back());
const int32_t instance_num = static_cast<int32_t>(in_blob->shape().elem_cnt() / instance_size);
const T* in = in_blob->dptr<T>();
int32_t* fw_buf = fw_buf_blob->mut_dptr<int32_t>();
int32_t* fw_buf = fw_buf_blob ? fw_buf_blob->mut_dptr<int32_t>() : nullptr;
int32_t* out = out_blob->mut_dptr<int32_t>();
const auto& conf = this->op_conf().top_k_conf();
TopKKernelUtil<device_type, T>::Forward(ctx.device_ctx, in, instance_num, instance_size, conf.k(),
......
......@@ -10,7 +10,7 @@ namespace oneflow {
template<typename T>
__global__ void ForwardGpu(const T* in, const int32_t instance_num, const int32_t instance_size,
const int32_t k, const bool sorted, int32_t* fw_buf, int32_t* out) {
int32_t* out) {
CUDA_1D_KERNEL_LOOP(i, instance_num) {
T max_val = in[i * instance_size];
int32_t max_idx = 0;
......@@ -33,7 +33,7 @@ struct TopKKernelUtil<DeviceType::kGPU, T> {
// GPU version top_k op only support "k == 1" for now
CHECK_EQ(k, 1);
ForwardGpu<<<BlocksNum4ThreadsNum(instance_num), kCudaThreadsNumPerBlock, 0,
ctx->cuda_stream()>>>(in, instance_num, instance_size, k, sorted, fw_buf, out);
ctx->cuda_stream()>>>(in, instance_num, instance_size, out);
}
};
......
......@@ -18,9 +18,11 @@ void TopKOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlob
CHECK_GE(conf.k(), 1);
CHECK_LE(conf.k(), in->shape().dim_vec().back());
// fw_buf
BlobDesc* fw_buf = GetBlobDesc4BnInOp("fw_buf");
fw_buf->mut_shape() = Shape({in->shape()});
fw_buf->set_data_type(DataType::kInt32);
if (conf.k() > 1) {
BlobDesc* fw_buf = GetBlobDesc4BnInOp("fw_buf");
fw_buf->mut_shape() = Shape({in->shape()});
fw_buf->set_data_type(DataType::kInt32);
}
// out
BlobDesc* out = GetBlobDesc4BnInOp("out");
*out = *in;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册