提交 6ce741c0 编写于 作者: S ScXfjiang

update check logic


Former-commit-id: b780c94966aec7048bac921ce1844a624dc704fe
上级 b9a92606
......@@ -16,12 +16,12 @@ void TopKOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlob
const BlobDesc* in = GetBlobDesc4BnInOp("in");
CHECK_LE(in->shape().elem_cnt(), GetMaxVal<int32_t>());
const TopKOpConf& conf = op_conf().top_k_conf();
CHECK_GE(conf.k(), 1);
CHECK_LE(conf.k(), in->shape().dim_vec().back());
if (device_type() == DeviceType::kGPU) {
// GPU version top_k op only support "k == 1" for now
CHECK_EQ(conf.k(), 1);
} else if (device_type() == DeviceType::kCPU) {
CHECK_GE(conf.k(), 1);
CHECK_LE(conf.k(), in->shape().dim_vec().back());
if (conf.k() > 1) {
// fw_buf
BlobDesc* fw_buf = GetBlobDesc4BnInOp("fw_buf");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册