提交 44349a17 编写于 作者: S ScXfjiang

refine


Former-commit-id: 5b1d4a918d46205f0465352cf2d17f7cf4e3bc4d
上级 b1aa065e
......@@ -5,7 +5,9 @@ namespace oneflow {
void TopKOp::InitFromOpConf() {
CHECK(op_conf().has_top_k_conf());
EnrollInputBn("in", false);
EnrollFwBufBn("fw_buf");
if (device_type() == DeviceType::kCPU && op_conf().top_k_conf().k() > 1) {
EnrollFwBufBn("fw_buf");
}
EnrollOutputBn("out", false);
}
......@@ -14,14 +16,18 @@ void TopKOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlob
const BlobDesc* in = GetBlobDesc4BnInOp("in");
CHECK_LE(in->shape().elem_cnt(), GetMaxVal<int32_t>());
const TopKOpConf& conf = op_conf().top_k_conf();
// GPU version top_k op only support "k == 1" for now
CHECK_GE(conf.k(), 1);
CHECK_LE(conf.k(), in->shape().dim_vec().back());
// fw_buf
if (conf.k() > 1) {
BlobDesc* fw_buf = GetBlobDesc4BnInOp("fw_buf");
fw_buf->mut_shape() = Shape({in->shape()});
fw_buf->set_data_type(DataType::kInt32);
if (device_type() == DeviceType::kGPU) {
// GPU version top_k op only support "k == 1" for now
CHECK_EQ(conf.k(), 1);
} else if (device_type() == DeviceType::kCPU) {
if (conf.k() > 1) {
// fw_buf
BlobDesc* fw_buf = GetBlobDesc4BnInOp("fw_buf");
fw_buf->mut_shape() = Shape({in->shape()});
fw_buf->set_data_type(DataType::kInt32);
}
}
// out
BlobDesc* out = GetBlobDesc4BnInOp("out");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册