提交 a5e07475 编写于 作者: S ScXfjiang

Change top_k from a CPU-only kernel to a device-generic kernel (CPU and GPU)


Former-commit-id: 2f625c9bb2e0199b734b09dc2330f10ce85bd80a
上级 ddfe95a9
......@@ -32,8 +32,27 @@ void ForwardPartDataContent(const T* in, const Range range, const int32_t instan
} // namespace
template<typename T>
void TopKKernel<T>::ForwardDataContent(const KernelCtx& ctx,
std::function<Blob*(const std::string&)> BnInOp2Blob) const {
struct TopKKernelUtil<DeviceType::kCPU, T> {
static void Forward(const T* in, const int32_t instance_num, const int32_t instance_size,
const int32_t k, const bool sorted, int32_t* fw_buf, int32_t* out) {
const int32_t part_num =
std::min(instance_num, Global<ThreadMgr>::Get()->compute_thread_pool()->thread_num());
const BalancedSplitter bs(instance_num, part_num);
BlockingCounter bc(part_num);
FOR_RANGE(int32_t, part_id, 0, part_num) {
Range range = bs.At(part_id);
Global<ThreadMgr>::Get()->compute_thread_pool()->AddWork([=, &bc]() {
ForwardPartDataContent(in, range, instance_size, k, sorted, fw_buf, out);
bc.Decrease();
});
}
bc.WaitUntilCntEqualZero();
}
};
template<DeviceType device_type, typename T>
void TopKKernel<device_type, T>::ForwardDataContent(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const Blob* in_blob = BnInOp2Blob("in");
Blob* fw_buf_blob = BnInOp2Blob("fw_buf");
Blob* out_blob = BnInOp2Blob("out");
......@@ -45,21 +64,15 @@ void TopKKernel<T>::ForwardDataContent(const KernelCtx& ctx,
int32_t* fw_buf = fw_buf_blob->mut_dptr<int32_t>();
int32_t* out = out_blob->mut_dptr<int32_t>();
const auto& conf = this->op_conf().top_k_conf();
const int32_t k = conf.k();
const int32_t part_num =
std::min(instance_num, Global<ThreadMgr>::Get()->compute_thread_pool()->thread_num());
const BalancedSplitter bs(instance_num, part_num);
BlockingCounter bc(part_num);
FOR_RANGE(int32_t, part_id, 0, part_num) {
Range range = bs.At(part_id);
Global<ThreadMgr>::Get()->compute_thread_pool()->AddWork([=, &bc]() {
ForwardPartDataContent(in, range, instance_size, k, conf.sorted(), fw_buf, out);
bc.Decrease();
});
}
bc.WaitUntilCntEqualZero();
TopKKernelUtil<device_type, T>::Forward(in, instance_num, instance_size, conf.k(), conf.sorted(),
fw_buf, out);
}
ADD_CPU_DEFAULT_KERNEL_CREATOR(OperatorConf::kTopKConf, TopKKernel, FLOATING_DATA_TYPE_SEQ);
// Explicitly instantiate the CPU TopKKernelUtil specialization for the types listed in
// FLOATING_DATA_TYPE_SEQ. The helper macro only uses the C++ type (`type_cpp`);
// `type_proto` is accepted but unused. The macro is undefined again right after use so
// it does not leak past this point.
// NOTE(review): presumably OF_PP_FOR_EACH_TUPLE expands the macro once per
// (type_cpp, type_proto) tuple in the sequence -- confirm against its definition.
#define INSTANTIATE_TOP_K_KERNEL_UTIL(type_cpp, type_proto) \
template struct TopKKernelUtil<DeviceType::kCPU, type_cpp>;
OF_PP_FOR_EACH_TUPLE(INSTANTIATE_TOP_K_KERNEL_UTIL, FLOATING_DATA_TYPE_SEQ)
#undef INSTANTIATE_TOP_K_KERNEL_UTIL
ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kTopKConf, TopKKernel, FLOATING_DATA_TYPE_SEQ);
} // namespace oneflow
#include "oneflow/core/kernel/top_k_kernel.h"
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/kernel/kernel_util.h"
#include "oneflow/core/kernel/kernel_util.cuh"
namespace oneflow {
// GPU specialization of TopKKernelUtil.
// Placeholder only: Forward ignores all of its arguments and just invokes UNIMPLEMENTED().
// NOTE(review): UNIMPLEMENTED() presumably aborts at runtime, so running this kernel on a
// GPU would fail until a real implementation is provided -- confirm.
template<typename T>
struct TopKKernelUtil<DeviceType::kGPU, T> {
// Same signature as the generic TopKKernelUtil::Forward declaration; no argument is
// read or written here.
static void Forward(const T* in, const int32_t instance_num, const int32_t instance_size,
const int32_t k, const bool sorted, int32_t* fw_buf, int32_t* out) {
// TODO: implement top-k on the GPU.
UNIMPLEMENTED();
}
};
// Explicitly instantiate the GPU TopKKernelUtil specialization for the types listed in
// FLOATING_DATA_TYPE_SEQ. The helper macro only uses the C++ type (`type_cpp`);
// `type_proto` is accepted but unused. The macro is undefined again right after use.
// NOTE(review): presumably OF_PP_FOR_EACH_TUPLE expands the macro once per
// (type_cpp, type_proto) tuple in the sequence -- confirm against its definition.
#define INSTANTIATE_TOP_K_KERNEL_UTIL(type_cpp, type_proto) \
template struct TopKKernelUtil<DeviceType::kGPU, type_cpp>;
OF_PP_FOR_EACH_TUPLE(INSTANTIATE_TOP_K_KERNEL_UTIL, FLOATING_DATA_TYPE_SEQ)
#undef INSTANTIATE_TOP_K_KERNEL_UTIL
} // namespace oneflow
\ No newline at end of file
......@@ -6,8 +6,8 @@
namespace oneflow {
template<typename T>
class TopKKernel final : public KernelIf<DeviceType::kCPU> {
template<DeviceType device_type, typename T>
class TopKKernel final : public KernelIf<device_type> {
public:
OF_DISALLOW_COPY_AND_MOVE(TopKKernel);
TopKKernel() = default;
......@@ -18,6 +18,12 @@ class TopKKernel final : public KernelIf<DeviceType::kCPU> {
std::function<Blob*(const std::string&)>) const override;
};
// Device-parameterized helper that carries the actual top-k forward computation.
// Forward is only declared here, not defined; the definitions are expected to come from
// per-device specializations in the corresponding source files.
template<DeviceType device_type, typename T>
struct TopKKernelUtil {
// Top-k forward pass over `instance_num` instances.
//   in            - input values of element type T.
//   instance_num  - number of instances.
//   instance_size - presumably the number of elements per instance (confirm).
//   k             - how many entries to select per instance.
//   sorted        - flag taken from the op configuration; presumably controls whether the
//                   selected entries are emitted in sorted order (confirm).
//   fw_buf        - int32 work buffer supplied by the caller.
//   out           - int32 output buffer.
static void Forward(const T* in, const int32_t instance_num, const int32_t instance_size,
const int32_t k, const bool sorted, int32_t* fw_buf, int32_t* out);
};
} // namespace oneflow
#endif // ONEFLOW_CORE_KERNEL_TOP_K_KERNEL_H_
......@@ -33,6 +33,6 @@ void TopKOp::VirtualGenKernelConf(
kernel_conf->set_data_type(GetBlobDesc4BnInOp("in")->data_type());
}
REGISTER_CPU_OP(OperatorConf::kTopKConf, TopKOp);
REGISTER_OP(OperatorConf::kTopKConf, TopKOp);
} // namespace oneflow
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册