提交 5969769b 编写于 作者: S sunsuodong

topk_int8

上级 c17ed236
......@@ -35,13 +35,15 @@ int TopK::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
auto topk_prim = this->primitive->value_as_TopK();
MS_ASSERT(topk_prim != nullptr);
output0->set_shape(input->shape());
auto out_shape = input->shape();
out_shape[out_shape.size() - 1] = topk_prim->k();
output0->set_shape(out_shape);
output0->set_data_type(input->data_type());
// output0->shape().back() = topk_prim->k();
output0->SetFormat(input->GetFormat());
output1->set_shape(input->shape());
output1->set_data_type(input->data_type());
// output1->shape().back() = topk_prim->k();
output1->set_shape(out_shape);
output1->set_data_type(kNumberTypeInt32);
output1->SetFormat(input->GetFormat());
return RET_OK;
......
......@@ -34,7 +34,7 @@
#include "src/runtime/kernel/arm/opclib/matmul.h"
#include "src/runtime/kernel/arm/opclib/fp32/softmax.h"
#include "src/runtime/kernel/arm/opclib/tile.h"
#include "src/runtime/kernel/arm/opclib/topk.h"
#include "src/runtime/kernel/arm/opclib/fp32/topk.h"
#include "src/runtime/kernel/arm/opclib/fp32/reduce.h"
#include "src/runtime/kernel/arm/opclib/fp32/activation.h"
#include "src/runtime/kernel/arm/opclib/fp32/arithmetic.h"
......
......@@ -25,11 +25,18 @@ using mindspore::schema::PrimitiveType_TopK;
namespace mindspore::kernel {
int TopKCPUKernel::Init() {
TopkParameter *parameter = reinterpret_cast<TopkParameter *>(opParameter);
lite::tensor::Tensor *input = inputs_.at(0);
topk_parameter_->last_dim_size_ = input->shape()[input->shape().size() - 1];
topk_parameter_->loop_num_ = 1;
parameter->last_dim_size_ = input->shape()[input->shape().size() - 1];
parameter->loop_num_ = 1;
for (int i = 0; i < input->shape().size() - 1; ++i) {
topk_parameter_->loop_num_ *= input->shape()[i];
parameter->loop_num_ *= input->shape()[i];
}
parameter->topk_node_list_ = malloc(sizeof(TopkNode) * parameter->last_dim_size_);
if (parameter->topk_node_list_ == nullptr) {
MS_LOG(ERROR) << "malloc fail.";
return RET_ERROR;
}
return RET_OK;
}
......@@ -39,14 +46,9 @@ int TopKCPUKernel::ReSize() { return RET_OK; }
int TopKCPUKernel::Run() {
  // Fetch the raw tensor buffers and delegate to the opclib Topk routine.
  // The scratch node list is pre-allocated in Init(), so no per-run malloc is
  // needed (the stale malloc/free path from the old revision is removed).
  auto input_data = reinterpret_cast<float *>(inputs_.at(0)->Data());
  auto output_data = reinterpret_cast<float *>(outputs_.at(0)->Data());
  // Output 1 carries the source indices and is int32, not float.
  auto output_index = reinterpret_cast<int32_t *>(outputs_.at(1)->Data());
  Topk(input_data, output_data, output_index, reinterpret_cast<TopkParameter *>(opParameter));
  return RET_OK;
}
......@@ -54,7 +56,6 @@ kernel::LiteKernel *CpuTopKFp32KernelCreator(const std::vector<lite::tensor::Ten
const std::vector<lite::tensor::Tensor *> &outputs, OpParameter *parameter,
const lite::Context *ctx, const KernelKey &desc) {
MS_ASSERT(parameter != nullptr);
MS_ASSERT(desc.type == PrimitiveType_Tile);
auto *kernel = new (std::nothrow) TopKCPUKernel(parameter, inputs, outputs);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new TopKCPUKernel fail!";
......@@ -73,4 +74,3 @@ kernel::LiteKernel *CpuTopKFp32KernelCreator(const std::vector<lite::tensor::Ten
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_TopK, CpuTopKFp32KernelCreator)
} // namespace mindspore::kernel
......@@ -18,26 +18,25 @@
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/opclib/topk.h"
#include "src/runtime/kernel/arm/opclib/fp32/topk.h"
namespace mindspore::kernel {
class TopKCPUKernel : public LiteKernel {
public:
explicit TopKCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs)
: LiteKernel(parameter, inputs, outputs) {
topk_parameter_ = reinterpret_cast<TopkParameter *>(parameter);
: LiteKernel(parameter, inputs, outputs) {}
~TopKCPUKernel() override {
TopkParameter *parameter = reinterpret_cast<TopkParameter *>(opParameter);
free(parameter->topk_node_list_);
}
~TopKCPUKernel() override {}
int Init() override;
int ReSize() override;
int Run() override;
private:
TopkParameter *topk_parameter_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_TOPK_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/int8/topk_int8.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_TopK;
namespace mindspore::kernel {
int TopKInt8CPUKernel::Init() {
  // Cache the last-dimension size and the number of outer slices, then
  // allocate the scratch (value, index) list used by TopkInt8().
  // Freed in the kernel destructor.
  TopkParameter *parameter = reinterpret_cast<TopkParameter *>(opParameter);
  lite::tensor::Tensor *input = inputs_.at(0);
  auto shape = input->shape();  // hoisted: avoids repeated shape() calls
  parameter->last_dim_size_ = shape[shape.size() - 1];
  parameter->loop_num_ = 1;
  // size_t index avoids the signed/unsigned comparison of `int i < size() - 1`.
  for (size_t i = 0; i + 1 < shape.size(); ++i) {
    parameter->loop_num_ *= shape[i];
  }
  parameter->topk_node_list_ = malloc(sizeof(TopkNodeInt8) * parameter->last_dim_size_);
  if (parameter->topk_node_list_ == nullptr) {
    MS_LOG(ERROR) << "malloc fail.";
    return RET_ERROR;
  }
  return RET_OK;
}
// No-op: always succeeds.  NOTE(review): shape-derived fields are only
// computed in Init(); confirm a resize is never expected to refresh them.
int TopKInt8CPUKernel::ReSize() { return RET_OK; }
int TopKInt8CPUKernel::Run() {
int8_t *input_data = reinterpret_cast<int8_t *>(inputs_.at(0)->Data());
int8_t *output_data = reinterpret_cast<int8_t *>(outputs_.at(0)->Data());
int32_t *output_index = reinterpret_cast<int32_t *>(outputs_.at(1)->Data());
TopkInt8(input_data, output_data, output_index, reinterpret_cast<TopkParameter *>(opParameter));
return RET_OK;
}
// Factory registered for the int8 TopK primitive: builds the kernel and runs
// Init(); returns nullptr (and cleans up) on any failure.
kernel::LiteKernel *CpuTopKInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                             const std::vector<lite::tensor::Tensor *> &outputs, OpParameter *parameter,
                                             const lite::Context *ctx, const KernelKey &desc) {
  MS_ASSERT(parameter != nullptr);
  auto *kernel = new (std::nothrow) TopKInt8CPUKernel(parameter, inputs, outputs);
  if (kernel == nullptr) {
    MS_LOG(ERROR) << "new TopKInt8CPUKernel fail!";
    return nullptr;
  }
  if (kernel->Init() != RET_OK) {
    MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_
                  << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_));
    delete kernel;
    return nullptr;
  }
  return kernel;
}
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_TopK, CpuTopKInt8KernelCreator)
} // namespace mindspore::kernel
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_TOPK_INT8_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_TOPK_INT8_H_
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/opclib/int8/topk_int8.h"
namespace mindspore::kernel {
// CPU int8 TopK kernel.  Init() caches shape info and allocates the sort
// scratch buffer into the shared TopkParameter; the destructor releases it.
class TopKInt8CPUKernel : public LiteKernel {
 public:
  explicit TopKInt8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
                             const std::vector<lite::tensor::Tensor *> &outputs)
      : LiteKernel(parameter, inputs, outputs) {}
  ~TopKInt8CPUKernel() override {
    // Init() mallocs topk_node_list_; free(nullptr) is a no-op, so this is
    // safe even if Init() failed.
    TopkParameter *parameter = reinterpret_cast<TopkParameter *>(opParameter);
    free(parameter->topk_node_list_);
    parameter->topk_node_list_ = nullptr;  // defend against dangling reuse
  }

  int Init() override;
  int ReSize() override;
  int Run() override;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_TOPK_INT8_H_
......@@ -14,25 +14,25 @@
* limitations under the License.
*/
#include "src/runtime/kernel/arm/opclib/topk.h"
#include "src/runtime/kernel/arm/opclib/fp32/topk.h"
// qsort comparator: larger elements first.  Returns only the sign of the
// comparison: casting the float difference to int (as the old code did)
// truncates sub-1.0 differences to 0 and misorders near-equal values.
int DescendCmp(const void *a, const void *b) {
  const float lhs = ((const TopkNode *)a)->element;
  const float rhs = ((const TopkNode *)b)->element;
  return (rhs > lhs) - (rhs < lhs);
}
// qsort comparator: smaller elements first.  Sign-only result avoids the
// float-to-int truncation bug of returning the raw difference.
int AscendCmp(const void *a, const void *b) {
  const float lhs = ((const TopkNode *)a)->element;
  const float rhs = ((const TopkNode *)b)->element;
  return (lhs > rhs) - (lhs < rhs);
}
void Topk(float *input_data, float *output_data, float *output_index, TopkParameter *parameter) {
void Topk(float *input_data, float *output_data, int32_t *output_index, TopkParameter *parameter) {
int last_dim_size = parameter->last_dim_size_;
int loop_num = parameter->loop_num_;
int k = parameter->k_;
Node *top_map = parameter->topk_node_list_;
TopkNode *top_map = (TopkNode *)parameter->topk_node_list_;
float *cur_input_data = input_data;
float *cur_output_data = output_data;
float *cur_output_index = output_index;
int32_t *cur_output_index = output_index;
for (int i = 0; i < loop_num; i++) {
for (int j = 0; j < last_dim_size; j++) {
top_map[j].element = *(cur_input_data + j);
......
......@@ -19,9 +19,9 @@
#include "src/runtime/kernel/arm/opclib/op_base.h"
// One (value, original-index) candidate used by the fp32 top-k qsort.
// The stale `struct Node` merge residue (duplicate header and `float index;`
// field) is removed; the index is an int32 element position.
struct TopkNode {
  float element;
  int32_t index;
};
struct TopkParameter {
......@@ -30,10 +30,10 @@ struct TopkParameter {
int loop_num_;
int k_;
bool sorted_;
Node *topk_node_list_;
void *topk_node_list_;
};
// For each of parameter->loop_num_ slices of length parameter->last_dim_size_,
// writes parameter->k_ elements and their int32 source indices (largest first
// when parameter->sorted_ is set).  The stale overload taking a float*
// output_index is removed: only the int32_t* version is defined.
void Topk(float *input_data, float *output_data, int32_t *output_index, TopkParameter *parameter);

#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_TOPK_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/opclib/int8/topk_int8.h"
int DescendCmpInt8(const void *a, const void *b) {
return ((const TopkNodeInt8 *)b)->element - ((const TopkNodeInt8 *)a)->element;
}
int AscendCmpInt8(const void *a, const void *b) {
return ((const TopkNodeInt8 *)a)->element - ((const TopkNodeInt8 *)b)->element;
}
void TopkInt8(int8_t *input_data, int8_t *output_data, int32_t *output_index, TopkParameter *parameter) {
int last_dim_size = parameter->last_dim_size_;
int loop_num = parameter->loop_num_;
int k = parameter->k_;
TopkNodeInt8 *top_map = (TopkNodeInt8 *)parameter->topk_node_list_;
int8_t *cur_input_data = input_data;
int8_t *cur_output_data = output_data;
int32_t *cur_output_index = output_index;
for (int i = 0; i < loop_num; i++) {
for (int j = 0; j < last_dim_size; j++) {
top_map[j].element = *(cur_input_data + j);
top_map[j].index = j;
}
if (parameter->sorted_) {
qsort(top_map, last_dim_size, sizeof(top_map[0]), DescendCmpInt8);
} else {
qsort(top_map, last_dim_size, sizeof(top_map[0]), AscendCmpInt8);
}
for (int m = 0; m < k; m++) {
cur_output_data[m] = top_map[m].element;
cur_output_index[m] = top_map[m].index;
}
cur_input_data += last_dim_size;
cur_output_data += k;
cur_output_index += k;
}
}
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_TOPK_INT8_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_TOPK_INT8_H_
#include "src/runtime/kernel/arm/opclib/op_base.h"
#include "src/runtime/kernel/arm/opclib/fp32/topk.h"
// One (value, original-index) candidate used by the int8 top-k qsort.
struct TopkNodeInt8 {
  int8_t element;   // the int8 input value
  int32_t index;    // position of the value within the last dimension
};

// For each of parameter->loop_num_ slices of length parameter->last_dim_size_,
// writes parameter->k_ values and their int32 source indices (largest first
// when parameter->sorted_ is set).
void TopkInt8(int8_t *input_data, int8_t *output_data, int32_t *output_index, TopkParameter *parameter);
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_TOPK_INT8_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include "common/common_test.h"
#include "mindspore/lite/src/runtime/kernel/arm/opclib/fp32/topk.h"
#include "mindspore/lite/src/kernel_registry.h"
namespace mindspore {
// gtest fixture for the fp32 TopK CPU kernel; no shared state is needed.
class TestTopKFp32 : public mindspore::Common {
 public:
  TestTopKFp32() {}
};
// End-to-end fp32 TopK: input 2x2x3, k = 2 -> value output 2x2x2 plus an
// int32 index output of the same shape.
TEST_F(TestTopKFp32, TopK) {
  lite::tensor::Tensor in_tensor(kNumberTypeFloat32, {2, 2, 3});
  lite::tensor::Tensor out_tensor0(kNumberTypeFloat32, {2, 2, 2});
  lite::tensor::Tensor out_tensor1(kNumberTypeInt32, {2, 2, 2});
  float input_data[] = {1, 2, 3, 6, 5, 4, 9, 8, 7, 10, 12, 11};
  float output_data0[8] = {0};
  int32_t output_data1[8] = {0};
  in_tensor.SetData(input_data);
  out_tensor0.SetData(output_data0);
  out_tensor1.SetData(output_data1);
  std::vector<lite::tensor::Tensor *> inputs = {&in_tensor};
  std::vector<lite::tensor::Tensor *> outputs = {&out_tensor0, &out_tensor1};

  // {OpParameter, last_dim_size_, loop_num_, k_, sorted_}; the first two are
  // recomputed by Init(), topk_node_list_ is zero-initialized.
  TopkParameter parameter = {{}, 3, 4, 2, true};

  kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_TopK};
  auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
  ASSERT_NE(creator, nullptr);
  auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(&parameter), nullptr, desc);
  ASSERT_NE(kernel, nullptr);

  auto ret = kernel->Run();
  EXPECT_EQ(0, ret);

  float expect0[] = {3, 2, 6, 5, 9, 8, 12, 11};
  int32_t expect1[] = {2, 1, 0, 1, 0, 1, 1, 2};
  for (int i = 0; i < 8; ++i) {
    EXPECT_EQ(output_data0[i], expect0[i]);
    EXPECT_EQ(output_data1[i], expect1[i]);
  }

  // Was leaked before: deleting the kernel also frees parameter.topk_node_list_
  // allocated in Init().
  delete kernel;
  in_tensor.SetData(nullptr);
  out_tensor0.SetData(nullptr);
  out_tensor1.SetData(nullptr);
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include "common/common_test.h"
#include "mindspore/lite/src/runtime/kernel/arm/opclib/fp32/topk.h"
#include "mindspore/lite/src/kernel_registry.h"
namespace mindspore {
class TestTopKInt8 : public mindspore::Common {
public:
TestTopKInt8() {}
};
// End-to-end int8 TopK: input 2x2x3, k = 2 -> int8 value output 2x2x2 plus an
// int32 index output of the same shape.
TEST_F(TestTopKInt8, TopK) {
  lite::tensor::Tensor in_tensor(kNumberTypeInt8, {2, 2, 3});
  lite::tensor::Tensor out_tensor0(kNumberTypeInt8, {2, 2, 2});
  lite::tensor::Tensor out_tensor1(kNumberTypeInt32, {2, 2, 2});
  int8_t input_data[] = {1, 2, 3, 6, 5, 4, 9, 8, 7, 10, 12, 11};
  int8_t output_data0[8] = {0};
  int32_t output_data1[8] = {0};
  in_tensor.SetData(input_data);
  out_tensor0.SetData(output_data0);
  out_tensor1.SetData(output_data1);
  std::vector<lite::tensor::Tensor *> inputs = {&in_tensor};
  std::vector<lite::tensor::Tensor *> outputs = {&out_tensor0, &out_tensor1};

  // {OpParameter, last_dim_size_, loop_num_, k_, sorted_}; the first two are
  // recomputed by Init(), topk_node_list_ is zero-initialized.
  TopkParameter parameter = {{}, 3, 4, 2, true};

  kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_TopK};
  auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
  ASSERT_NE(creator, nullptr);
  auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(&parameter), nullptr, desc);
  ASSERT_NE(kernel, nullptr);

  auto ret = kernel->Run();
  EXPECT_EQ(0, ret);

  int8_t expect0[] = {3, 2, 6, 5, 9, 8, 12, 11};
  int32_t expect1[] = {2, 1, 0, 1, 0, 1, 1, 2};
  for (int i = 0; i < 8; ++i) {
    EXPECT_EQ(output_data0[i], expect0[i]);
    EXPECT_EQ(output_data1[i], expect1[i]);
  }

  // Was leaked before: deleting the kernel also frees parameter.topk_node_list_
  // allocated in Init().
  delete kernel;
  in_tensor.SetData(nullptr);
  out_tensor0.SetData(nullptr);
  out_tensor1.SetData(nullptr);
}
} // namespace mindspore
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册