提交 cf56085c 编写于 作者: S songhonglei413

add op_gather_int8 and testcase

上级 75af5464
......@@ -19,19 +19,13 @@
#include "nnacl/op_base.h"
typedef struct GatherParameter {
OpParameter op_parameter_;
int axis_;
int batchDims_;
} GatherParameter;
#ifdef __cplusplus
extern "C" {
#endif
int Gather(float *input, int outer_size, int inner_size, int limit, int *indices, int indices_element_size,
float *output);
int GatherInt32(const int32_t *input, int outer_size, int inner_size, int limit, int *indices,
int indices_element_size, int32_t *output);
int GatherInt32(const int32_t *input, int outer_size, int inner_size, int limit, int *indices, int indices_element_size,
int32_t *output);
#ifdef __cplusplus
}
#endif
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_GATHER_PARAMETER_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_GATHER_PARAMETER_H_
#include "nnacl/op_base.h"
typedef struct GatherParameter {
OpParameter op_parameter_;
int axis_;
int batchDims_;
} GatherParameter;
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_GATHER_PARAMETER_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/int8/gatherNd_int8.h"
#include <string.h>
#include "nnacl/errorcode.h"
int GatherNdInt8(int8_t *input, int8_t *output, int *in_offset, int area, int count, GatherQuantArg param) {
double alpha = param.alpha_;
int z1 = param.zp_in_;
int z2 = param.zp_out_;
for (int i = 0; i < count; ++i) {
for (int j = 0; j < area; ++j) {
int32_t tmp = round(alpha * (input[in_offset[i] + j] - z1)) + z2;
tmp = tmp > 127 ? 127 : tmp;
tmp = tmp < -128 ? -128 : tmp;
output[area * i + j] = (int8_t)tmp;
}
}
return NNACL_OK;
}
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_GATHERND_INT8_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_GATHERND_INT8_H_
#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"
#ifdef __cplusplus
extern "C" {
#endif
int GatherNdInt8(int8_t *in_data, int8_t *out_data, int *in_offset, int area, int count, GatherQuantArg param);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_GATHERND_INT8_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
#include "nnacl/int8/gather_int8.h"
#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"
#include "nnacl/errorcode.h"
int GatherInt8(int8_t *in_data, int8_t *out_data, int outer_size, int inner_size, int limit, int *indices,
int indices_element_size, GatherQuantArg para) {
double alpha = para.alpha_;
int z1 = para.zp_in_;
int z2 = para.zp_out_;
int i, m, j;
for (m = 0; m < outer_size; ++m) {
const int8_t *inputm = in_data + inner_size * m * limit;
int8_t *outputm = out_data + inner_size * m * indices_element_size;
for (i = 0; i < indices_element_size; ++i) {
if (indices[i] < 0 || indices[i] > limit) {
return NNACL_ERR;
}
for (j = 0; j < inner_size; ++j) {
int32_t tmp = round(alpha * (inputm[indices[i] * inner_size + j] - z1)) + z2;
tmp = tmp > 127 ? 127 : tmp;
tmp = tmp < -128 ? -128 : tmp;
outputm[i * inner_size + j] = (int8_t)tmp;
}
}
}
return NNACL_OK;
}
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_GATHER_INT8_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_GATHER_INT8_H_
#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"
#ifdef __cplusplus
extern "C" {
#endif
int GatherInt8(int8_t *in_data, int8_t *out_data, int outer_size, int inner_size, int limit, int *indices,
int indices_element_size, GatherQuantArg para);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_GATHER_INT8_H_
......@@ -159,6 +159,12 @@ typedef struct ArithSelfQuantArg {
int shift_right_;
} ArithSelfQuantArg;
typedef struct GatherQuantArg {
double alpha_;
int zp_in_;
int zp_out_;
} GatherQuantArg;
typedef struct SplitQuantArg {
QuantArg in_args_;
QuantArg out_args_[20];
......
......@@ -144,7 +144,7 @@
#include "nnacl/transpose.h"
#include "nnacl/split_parameter.h"
#include "nnacl/squeeze.h"
#include "nnacl/fp32/gather.h"
#include "nnacl/gather_parameter.h"
#include "nnacl/fp32/reverse.h"
#include "nnacl/reverse_sequence.h"
#include "nnacl/fp32/unique.h"
......
......@@ -13,9 +13,10 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include "src/runtime/kernel/arm/fp32/gather.h"
#include <vector>
#include "nnacl/gather_parameter.h"
#include "nnacl/fp32/gather.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "src/runtime/runtime_api.h"
......
......@@ -18,7 +18,7 @@
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GATHER_H_
#include <vector>
#include "nnacl/fp32/gather.h"
#include "nnacl/gather_parameter.h"
#include "src/lite_kernel.h"
namespace mindspore::kernel {
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/int8/gatherNd_int8.h"
#include <string.h>
#include <vector>
#include "schema/model_generated.h"
#include "include/errorcode.h"
#include "src/kernel_registry.h"
#include "src/runtime/runtime_api.h"
#include "nnacl/int8/gatherNd_int8.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_GatherNd;
namespace mindspore::kernel {
GatherNdInt8CPUKernel::~GatherNdInt8CPUKernel() {
if (in_offset_ != nullptr) {
free(in_offset_);
in_offset_ = nullptr;
}
}
int GatherNdInt8CPUKernel::Init() {
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int GatherNdInt8CPUKernel::ReSize() {
if (in_offset_ != nullptr) {
free(in_offset_);
in_offset_ = nullptr;
}
auto in_quant_args = in_tensors_.at(0)->GetQuantParams();
auto ind_quant_args = in_tensors_.at(1)->GetQuantParams();
auto out_quant_args = out_tensors_.at(0)->GetQuantParams();
param_.alpha_ = in_quant_args.front().scale / out_quant_args.front().scale;
param_.zp_in_ = in_quant_args.front().zeroPoint;
param_.zp_out_ = out_quant_args.front().zeroPoint;
auto indices_tensor = in_tensors_.at(1);
auto indices_shape = indices_tensor->shape();
int indices_rank = indices_shape.size();
count_ = 1;
for (int i = 0; i < indices_rank - 1; ++i) {
count_ *= indices_shape[i];
}
in_offset_ = reinterpret_cast<int *>(malloc(count_ * sizeof(int)));
if (in_offset_ == nullptr) {
MS_LOG(ERROR) << "GatherNdInt8 Malloc in_offset_ error!";
return RET_ERROR;
}
(void)memset(in_offset_, 0, count_ * sizeof(int));
thread_sz_count_ = MSMIN(thread_count_, count_);
thread_sz_stride_ = UP_DIV(count_, thread_sz_count_);
auto in_shape = in_tensors_.front()->shape();
int in_rank = in_shape.size();
int idx_lastshape = indices_shape[indices_rank - 1];
auto indices_ptr = reinterpret_cast<int8_t *>(indices_tensor->Data());
area_ = 1;
for (int i = idx_lastshape; i < in_rank; ++i) {
area_ *= in_shape[i];
}
std::vector<int> in_stride(in_rank);
in_stride[in_rank - 1] = 1;
for (int i = in_rank - 2; i >= 0; --i) {
in_stride[i] = in_shape[i + 1] * in_stride[i + 1];
}
int idx_stride = idx_lastshape;
for (int j = 0; j < count_; ++j) {
for (int k = 0; k < idx_lastshape; ++k) {
int tmp = static_cast<int>(
round((indices_ptr[j * idx_stride + k] - ind_quant_args.front().zeroPoint) * ind_quant_args.front().scale));
in_offset_[j] += tmp * in_stride[k];
}
}
return RET_OK;
}
int GatherNdInt8CPUKernel::DoGatherNd(int task_id) {
int count = MSMIN(thread_sz_stride_, count_ - task_id * thread_sz_stride_);
if (count <= 0) {
return RET_OK;
}
int offset = task_id * thread_sz_stride_;
auto ret = GatherNdInt8(in_ptr_, out_ptr_ + offset * area_, in_offset_ + offset, area_, count, param_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "GatherNdRun error task_id[" << task_id << "] error_code[" << ret << "]";
return ret;
}
return RET_OK;
}
int GatherNdInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
auto g_kernel = reinterpret_cast<GatherNdInt8CPUKernel *>(cdata);
auto ret = g_kernel->DoGatherNd(task_id);
if (ret != RET_OK) {
MS_LOG(ERROR) << "GatherNdRun error task_id[" << task_id << "] error_code[" << ret << "]";
return ret;
}
return RET_OK;
}
int GatherNdInt8CPUKernel::Run() {
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
in_ptr_ = reinterpret_cast<int8_t *>(in_tensors_.front()->Data());
out_ptr_ = reinterpret_cast<int8_t *>(out_tensors_.front()->Data());
auto ret = LiteBackendParallelLaunch(GatherNdInt8Run, this, thread_sz_count_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "gatherNd error error_code[" << ret << "]";
return ret;
}
return RET_OK;
}
kernel::LiteKernel *CpuGatherNdInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_GatherNd);
auto *kernel = new (std::nothrow) GatherNdInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
delete kernel;
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_GatherNd, CpuGatherNdInt8KernelCreator)
} // namespace mindspore::kernel
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_GATHERND_INT8_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_GATHERND_INT8_H_
#include <vector>
#include "nnacl/quantization/quantize.h"
#include "src/lite_kernel.h"
namespace mindspore::kernel {
class GatherNdInt8CPUKernel : public LiteKernel {
public:
GatherNdInt8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {}
~GatherNdInt8CPUKernel() override;
int Init() override;
int ReSize() override;
int Run() override;
int DoGatherNd(int task_id);
private:
int thread_count_;
int thread_sz_count_;
int thread_sz_stride_;
int count_;
int area_;
int *in_offset_ = nullptr;
int8_t *in_ptr_;
int8_t *out_ptr_;
GatherQuantArg param_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_GATHERND_INT8_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/int8/gather_int8.h"
#include <vector>
#include "nnacl/gather_parameter.h"
#include "nnacl/int8/gather_int8.h"
#include "nnacl/quantization/quantize.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "src/runtime/runtime_api.h"
#include "include/errorcode.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Gather;
namespace mindspore::kernel {
int GatherInt8CPUKernel::Init() {
axis_ = (reinterpret_cast<GatherParameter *>(op_parameter_))->axis_;
batchDims_ = (reinterpret_cast<GatherParameter *>(op_parameter_))->batchDims_;
auto in_quant_args = in_tensors_.at(0)->GetQuantParams();
auto ind_quant_args = in_tensors_.at(1)->GetQuantParams();
auto out_quant_args = out_tensors_.at(0)->GetQuantParams();
param_.alpha_ = in_quant_args.front().scale / out_quant_args.front().scale;
param_.zp_in_ = in_quant_args.front().zeroPoint;
param_.zp_out_ = out_quant_args.front().zeroPoint;
auto indices_ptr = reinterpret_cast<int8_t *>(in_tensors_.at(1)->Data());
if (indices_ != nullptr) {
free(indices_);
indices_ = nullptr;
}
int count = in_tensors_.at(1)->ElementsNum();
indices_ = reinterpret_cast<int *>(malloc(count * sizeof(int)));
if (indices_ == nullptr) {
MS_LOG(ERROR) << "Gather Malloc indices_ error!";
return RET_ERROR;
}
(void)memset(indices_, 0, count * sizeof(int));
for (int i = 0; i < count; ++i) {
indices_[i] =
static_cast<int>(round((indices_ptr[i] - ind_quant_args.front().zeroPoint) * ind_quant_args.front().scale));
}
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int GatherInt8CPUKernel::ReSize() { return RET_OK; }
int GatherInt8CPUKernel::DoGather(int task_id) {
auto input_tensor = in_tensors_.at(0);
auto indices_tensor = in_tensors_.at(1);
auto out_tensor = out_tensors_.at(0);
auto input_ptr = reinterpret_cast<int8_t *>(input_tensor->Data());
auto output_ptr = reinterpret_cast<int8_t *>(out_tensor->Data());
auto in_shape = input_tensor->shape();
int in_rank = in_shape.size();
int indices_element_size = indices_tensor->ElementsNum();
const int limit = in_shape[axis_];
for (int i = 0; i < indices_element_size; ++i) {
if (indices_[i] >= limit) {
MS_LOG(ERROR) << " indice data: " << indices_[i] << " is not in [ 0, " << limit - 1 << " ]";
return RET_ERROR;
}
}
int outer_size = 1;
for (int i = 0; i < axis_; ++i) {
outer_size *= in_shape[i];
}
int inner_size = 1;
for (int i = axis_ + 1; i < in_rank; ++i) {
inner_size *= in_shape[i];
}
int stride = UP_DIV(outer_size, thread_count_);
int count = MSMIN(stride, outer_size - stride * task_id);
auto thread_stride = stride * task_id;
int error_code;
input_ptr += thread_stride * limit;
output_ptr += thread_stride * indices_element_size;
error_code = GatherInt8(input_ptr, output_ptr, count, inner_size, limit, indices_, indices_element_size, param_);
if (error_code != RET_OK) {
return RET_ERROR;
}
return RET_OK;
}
int GatherInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
auto gather_kernel = reinterpret_cast<GatherInt8CPUKernel *>(cdata);
auto error_code = gather_kernel->DoGather(task_id);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "GatherRun error task_id[" << task_id << "] error_code[" << error_code << "]";
return RET_ERROR;
}
return RET_OK;
}
int GatherInt8CPUKernel::Run() {
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
int error_code = LiteBackendParallelLaunch(GatherInt8Run, this, thread_count_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "Gather function error error_code[" << error_code << "]";
return RET_ERROR;
}
return RET_OK;
}
kernel::LiteKernel *CpuGatherInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
MS_ASSERT(desc.type == schema::PrimitiveType_Gather);
if (opParameter == nullptr) {
MS_LOG(ERROR) << "input parameter is nullptr!";
return nullptr;
}
auto *kernel = new (std::nothrow) GatherInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
delete kernel;
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Gather, CpuGatherInt8KernelCreator)
} // namespace mindspore::kernel
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_GATHER_INT8_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_GATHER_INT8_H_
#include <vector>
#include "nnacl/gather_parameter.h"
#include "nnacl/quantization/quantize.h"
#include "src/lite_kernel.h"
namespace mindspore::kernel {
class GatherInt8CPUKernel : public LiteKernel {
public:
GatherInt8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {}
~GatherInt8CPUKernel() {
free(indices_);
indices_ = nullptr;
}
int Init() override;
int ReSize() override;
int Run() override;
int DoGather(int task_id);
private:
int *indices_ = nullptr;
int thread_count_;
int batchDims_;
int axis_;
GatherQuantArg param_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_GATHER_INT8_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include "mindspore/core/utils/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/nnacl/fp32/gatherNd.h"
#include "mindspore/lite/nnacl/int8/gatherNd_int8.h"
#include "mindspore/lite/src/kernel_registry.h"
#include "mindspore/lite/src/lite_kernel.h"
namespace mindspore {
class TestGatherNdInt8 : public mindspore::CommonTest {
public:
TestGatherNdInt8() {}
};
TEST_F(TestGatherNdInt8, GatherNdTest) {
std::vector<int8_t> in_data = {3, 5, 7, 9, 11, 13, 15, 17, 19, 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 1};
std::vector<int8_t> in_data1 = {2, 4, 4, 2, 2, 4, 2, 4, 2};
// std::vector<int8_t> in_data1 = {2, 2, 2, 4};
std::vector<lite::tensor::Tensor *> inputs_tensor;
std::vector<lite::tensor::Tensor *> outputs_tensor;
GatherNdParameter op_param;
op_param.op_parameter_.type_ = schema::PrimitiveType_GatherNd;
op_param.batchDims_ = 1;
std::vector<int> shape = {1, 2, 2, 5};
std::vector<int> out_shape = {1, 3, 5};
lite::tensor::QuantArg input_quant_arg;
input_quant_arg.scale = 0.5;
input_quant_arg.zeroPoint = 1;
lite::tensor::QuantArg input_quant_arg_1;
input_quant_arg_1.scale = 0.5;
input_quant_arg_1.zeroPoint = 2;
lite::tensor::QuantArg output_quant_arg;
output_quant_arg.scale = 1;
output_quant_arg.zeroPoint = 0;
lite::tensor::Tensor input0_tensor;
lite::tensor::Tensor input1_tensor;
inputs_tensor.push_back(&input0_tensor);
inputs_tensor.push_back(&input1_tensor);
input0_tensor.SetData(in_data.data());
input1_tensor.SetData(in_data1.data());
input0_tensor.set_shape(shape);
input1_tensor.set_shape({3, 3});
input0_tensor.AddQuantParam(input_quant_arg);
input1_tensor.AddQuantParam(input_quant_arg_1);
std::vector<int8_t> output(15);
// std::vector<int8_t> corr_out = {1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0};
std::vector<int8_t> corr_out = {6, 7, 8, 9, 0, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5};
lite::tensor::Tensor output0_tensor;
outputs_tensor.push_back(&output0_tensor);
output0_tensor.SetData(output.data());
output0_tensor.set_shape(out_shape);
output0_tensor.AddQuantParam(output_quant_arg);
kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_GatherNd};
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
ASSERT_NE(creator, nullptr);
lite::Context ctx;
ctx.thread_num_ = 3;
kernel::LiteKernel *kernel =
creator(inputs_tensor, outputs_tensor, reinterpret_cast<OpParameter *>(&op_param), &ctx, desc, nullptr);
ASSERT_NE(kernel, nullptr);
auto output_tensor_shape = output0_tensor.shape();
kernel->Run();
printf("==================output data=================\n");
for (int i = 0; i < output0_tensor.ElementsNum(); i++) {
printf("%d, ", output[i]);
}
std::cout << std::endl;
CompareOutputData(output.data(), corr_out.data(), output0_tensor.ElementsNum(), 0.001);
input0_tensor.SetData(nullptr);
input1_tensor.SetData(nullptr);
output0_tensor.SetData(nullptr);
MS_LOG(INFO) << "TestGatherNd accuracy passed";
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include "mindspore/core/utils/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/nnacl/gather_parameter.h"
#include "mindspore/lite/nnacl/int8/gather_int8.h"
#include "mindspore/lite/src/kernel_registry.h"
#include "mindspore/lite/src/lite_kernel.h"
namespace mindspore {
class TestGatherInt8 : public mindspore::CommonTest {
public:
TestGatherInt8() {}
};
TEST_F(TestGatherInt8, GatherTest) {
std::vector<int8_t> in_data = {11, 41, 21, 51, 31, 61, -11, -41, -21, -51, -31, -61};
std::vector<int8_t> in_data1 = {4, 2};
std::vector<lite::tensor::Tensor *> inputs_tensor;
std::vector<lite::tensor::Tensor *> outputs_tensor;
GatherParameter op_param;
op_param.op_parameter_.type_ = schema::PrimitiveType_Gather;
op_param.axis_ = 0;
op_param.batchDims_ = 1;
std::vector<int> shape = {2, 1, 3, 2};
lite::tensor::QuantArg input_quant_arg;
input_quant_arg.scale = 0.1;
input_quant_arg.zeroPoint = 1;
lite::tensor::QuantArg input_quant_arg_1;
input_quant_arg_1.scale = 0.5;
input_quant_arg_1.zeroPoint = 2;
lite::tensor::QuantArg output_quant_arg;
output_quant_arg.scale = 0.1;
output_quant_arg.zeroPoint = 1;
lite::tensor::Tensor input0_tensor;
lite::tensor::Tensor input1_tensor;
inputs_tensor.push_back(&input0_tensor);
inputs_tensor.push_back(&input1_tensor);
input0_tensor.SetData(in_data.data());
input1_tensor.SetData(in_data1.data());
input0_tensor.set_shape(shape);
input1_tensor.set_shape({2});
input0_tensor.AddQuantParam(input_quant_arg);
input1_tensor.AddQuantParam(input_quant_arg_1);
std::vector<int8_t> output(12);
// std::vector<int8_t> corr_out = {-18, -22, -16, -21, -14, -19, -22, -34, -24, -35, -26, -36 };
std::vector<int8_t> corr_out = {-11, -41, -21, -51, -31, -61, 11, 41, 21, 51, 31, 61};
lite::tensor::Tensor output0_tensor;
outputs_tensor.push_back(&output0_tensor);
output0_tensor.SetData(output.data());
output0_tensor.set_shape(shape);
output0_tensor.AddQuantParam(output_quant_arg);
kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Gather};
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
ASSERT_NE(creator, nullptr);
lite::Context ctx;
ctx.thread_num_ = 3;
kernel::LiteKernel *kernel =
creator(inputs_tensor, outputs_tensor, reinterpret_cast<OpParameter *>(&op_param), &ctx, desc, nullptr);
ASSERT_NE(kernel, nullptr);
auto output_tensor_shape = output0_tensor.shape();
kernel->Run();
printf("==================output data=================\n");
for (int i = 0; i < output0_tensor.ElementsNum(); i++) {
printf("%d, ", output[i]);
}
std::cout << std::endl;
CompareOutputData(output.data(), corr_out.data(), output0_tensor.ElementsNum(), 0.001);
input0_tensor.SetData(nullptr);
input1_tensor.SetData(nullptr);
output0_tensor.SetData(nullptr);
MS_LOG(INFO) << "TestGather_int8 accuracy passed";
}
} // namespace mindspore
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册