提交 f6dc9287 编写于 作者: Z zhongligeng

fix quantdtypecast

上级 fa96dfd1
...@@ -137,10 +137,8 @@ Primitive *Primitive::CreatePrimitive(schema::Primitive *primitive) { ...@@ -137,10 +137,8 @@ Primitive *Primitive::CreatePrimitive(schema::Primitive *primitive) {
return new lite::SpaceToDepth(const_cast<schema::Primitive *>(primitive)); return new lite::SpaceToDepth(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_SpaceToBatch: case schema::PrimitiveType_SpaceToBatch:
return new lite::SpaceToBatch(const_cast<schema::Primitive *>(primitive)); return new lite::SpaceToBatch(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_OnnxInt8Dequantize: case schema::PrimitiveType_QuantDTypeCast:
return new lite::Dequantize(const_cast<schema::Primitive *>(primitive)); return new lite::QuantDTypeCast(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_OnnxInt8Quantize:
return new lite::Quantize(const_cast<schema::Primitive *>(primitive));
default: default:
break; break;
} }
......
...@@ -691,17 +691,10 @@ class SpaceToDepth : public Primitive { ...@@ -691,17 +691,10 @@ class SpaceToDepth : public Primitive {
int InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) override; int InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) override;
}; };
class Dequantize : public Primitive { class QuantDTypeCast : public Primitive {
public: public:
explicit Dequantize(schema::Primitive *primitive) : Primitive(primitive) {} explicit QuantDTypeCast(schema::Primitive *primitive) : Primitive(primitive) {}
const schema::OnnxInt8Dequantize *GetAttribute() const { return this->primitive->value_as_OnnxInt8Dequantize(); } const schema::QuantDTypeCast *GetAttribute() const { return this->primitive->value_as_QuantDTypeCast(); }
int InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) override;
};
class Quantize : public Primitive {
public:
explicit Quantize(schema::Primitive *primitive) : Primitive(primitive) {}
const schema::OnnxInt8Quantize *GetAttribute() const { return this->primitive->value_as_OnnxInt8Quantize(); }
int InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) override; int InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) override;
}; };
} // namespace lite } // namespace lite
......
...@@ -20,15 +20,16 @@ ...@@ -20,15 +20,16 @@
#include "src/ir/tensor.h" #include "src/ir/tensor.h"
namespace mindspore::lite { namespace mindspore::lite {
int Dequantize::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) { int QuantDTypeCast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr); MS_ASSERT(this->primitive != nullptr);
auto input = inputs_.front(); auto input = inputs_.front();
MS_ASSERT(input != nullptr); MS_ASSERT(input != nullptr);
auto output = outputs_.front(); auto output = outputs_.front();
MS_ASSERT(output != nullptr); MS_ASSERT(output != nullptr);
output->set_shape(input->shape()); output->set_shape(input->shape());
output->set_data_type(kNumberTypeFloat32); auto param = primitive->value_as_QuantDTypeCast();
MS_ASSERT(input->data_type() == param->srcT);
output->set_data_type(static_cast<TypeId>(param->dstT()));
return RET_OK; return RET_OK;
} }
} // namespace mindspore::lite } // namespace mindspore::lite
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/ops.h"
#include "include/errorcode.h"
#include "utils/log_adapter.h"
#include "src/ir/tensor.h"
namespace mindspore::lite {
int Quantize::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(kNumberTypeInt8);
return RET_OK;
}
} // namespace mindspore::lite
...@@ -65,8 +65,7 @@ ...@@ -65,8 +65,7 @@
#include "src/runtime/kernel/arm/base/prior_box.h" #include "src/runtime/kernel/arm/base/prior_box.h"
#include "src/runtime/kernel/arm/opclib/fp32/space_to_depth.h" #include "src/runtime/kernel/arm/opclib/fp32/space_to_depth.h"
#include "src/runtime/kernel/arm/opclib/fp32/space_to_batch.h" #include "src/runtime/kernel/arm/opclib/fp32/space_to_batch.h"
#include "src/runtime/kernel/arm/opclib/int8/dequantize.h" #include "src/runtime/kernel/arm/opclib/int8/quant_dtype_cast.h"
#include "src/runtime/kernel/arm/opclib/fp32/quantize.h"
namespace mindspore::kernel { namespace mindspore::kernel {
OpParameter *PopulateFillParameter(const lite::Primitive *primitive) { OpParameter *PopulateFillParameter(const lite::Primitive *primitive) {
...@@ -1032,24 +1031,17 @@ OpParameter *PopulateFlattenParameter(const lite::Primitive *primitive) { ...@@ -1032,24 +1031,17 @@ OpParameter *PopulateFlattenParameter(const lite::Primitive *primitive) {
return reinterpret_cast<OpParameter *>(flatten_param); return reinterpret_cast<OpParameter *>(flatten_param);
} }
OpParameter *PopulateDequantizeParameter(const lite::Primitive *primitive) { OpParameter *PopulateQuantDTypeCastParameter(const lite::Primitive *primitive) {
DequantizeParameter *dequantize_parameter = new (std::nothrow) DequantizeParameter(); QuantDTypeCastParameter *parameter = new (std::nothrow) QuantDTypeCastParameter();
if (dequantize_parameter == nullptr) { if (parameter == nullptr) {
MS_LOG(ERROR) << "new DequantizeParameter fail!"; MS_LOG(ERROR) << "new QuantDTypeCastParameter fail!";
return nullptr;
}
dequantize_parameter->op_parameter_.type_ = primitive->Type();
return reinterpret_cast<OpParameter *>(dequantize_parameter);
}
OpParameter *PopulateQuantizeParameter(const lite::Primitive *primitive) {
QuantizeParameter *quantize_parameter = new (std::nothrow) QuantizeParameter();
if (quantize_parameter == nullptr) {
MS_LOG(ERROR) << "new QuantizeParameter fail!";
return nullptr; return nullptr;
} }
quantize_parameter->op_parameter_.type_ = primitive->Type(); parameter->op_parameter_.type_ = primitive->Type();
return reinterpret_cast<OpParameter *>(quantize_parameter); auto quant_dtype_cast_param = primitive->Value()->value_as_QuantDTypeCast();
parameter->srcT = quant_dtype_cast_param->srcT();
parameter->dstT = quant_dtype_cast_param->dstT();
return reinterpret_cast<OpParameter *>(parameter);
} }
OpParameter *PopulateStridedSliceParameter(const lite::Primitive *primitive) { OpParameter *PopulateStridedSliceParameter(const lite::Primitive *primitive) {
...@@ -1209,8 +1201,7 @@ PopulateParameterRegistry::PopulateParameterRegistry() { ...@@ -1209,8 +1201,7 @@ PopulateParameterRegistry::PopulateParameterRegistry() {
populate_parameter_funcs_[schema::PrimitiveType_Square] = PopulateSqueezeParameter; populate_parameter_funcs_[schema::PrimitiveType_Square] = PopulateSqueezeParameter;
populate_parameter_funcs_[schema::PrimitiveType_Split] = PopulateSplitParameter; populate_parameter_funcs_[schema::PrimitiveType_Split] = PopulateSplitParameter;
populate_parameter_funcs_[schema::PrimitiveType_PriorBox] = PopulatePriorBoxParameter; populate_parameter_funcs_[schema::PrimitiveType_PriorBox] = PopulatePriorBoxParameter;
populate_parameter_funcs_[schema::PrimitiveType_OnnxInt8Dequantize] = PopulateDequantizeParameter; populate_parameter_funcs_[schema::PrimitiveType_QuantDTypeCast] = PopulateQuantDTypeCastParameter;
populate_parameter_funcs_[schema::PrimitiveType_OnnxInt8Quantize] = PopulateQuantizeParameter;
} }
PopulateParameterRegistry *PopulateParameterRegistry::GetInstance() { PopulateParameterRegistry *PopulateParameterRegistry::GetInstance() {
......
...@@ -13,11 +13,11 @@ ...@@ -13,11 +13,11 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "src/runtime/kernel/arm/fp32/quantize.h" #include "src/runtime/kernel/arm/base/quant_dtype_cast.h"
#include <vector> #include <vector>
#include "src/runtime/kernel/arm/opclib/fp32/quantize.h" #include "src/runtime/kernel/arm/opclib/int8/quant_dtype_cast.h"
#include "src/kernel_registry.h"
#include "src/runtime/runtime_api.h" #include "src/runtime/runtime_api.h"
#include "src/kernel_registry.h"
#include "schema/model_generated.h" #include "schema/model_generated.h"
#include "include/errorcode.h" #include "include/errorcode.h"
...@@ -25,15 +25,15 @@ using mindspore::kernel::KERNEL_ARCH::kCPU; ...@@ -25,15 +25,15 @@ using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar; using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR; using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK; using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_OnnxInt8Quantize; using mindspore::schema::PrimitiveType_QuantDTypeCast;
namespace mindspore::kernel { namespace mindspore::kernel {
namespace { namespace {
constexpr int kQuantizeInputNum = 1; constexpr int kQuantDTypeCastInputNum = 1;
constexpr int kQuantizeOutputNum = 1; constexpr int kQuantDTypeCastOutputNum = 1;
} // namespace } // namespace
int QuantizeCPUKernel::Init() { int QuantDTypeCastCPUKernel::Init() {
if (inputs_.size() != 1) { if (inputs_.size() != 1) {
MS_LOG(ERROR) << "inputs number should be 1, but " << inputs_.size() << " is given."; MS_LOG(ERROR) << "inputs number should be 1, but " << inputs_.size() << " is given.";
return RET_ERROR; return RET_ERROR;
...@@ -43,6 +43,25 @@ int QuantizeCPUKernel::Init() { ...@@ -43,6 +43,25 @@ int QuantizeCPUKernel::Init() {
return RET_ERROR; return RET_ERROR;
} }
auto in_tensor = inputs_.front(); auto in_tensor = inputs_.front();
auto out_tensor = outputs_.front();
auto param = reinterpret_cast<QuantDTypeCastParameter *>(opParameter);
if (param->srcT == kNumberTypeFloat32 && param->dstT == kNumberTypeInt8) {
if (in_tensor->data_type() != kNumberTypeFloat32 || out_tensor->data_type() != kNumberTypeInt8) {
MS_LOG(ERROR) << "param data type and tensor data type do not match.";
return RET_ERROR;
}
inverse_ = false;
} else if (param->srcT == kNumberTypeInt8 && param->dstT == kNumberTypeFloat32) {
if (in_tensor->data_type() != kNumberTypeInt8 || out_tensor->data_type() != kNumberTypeFloat32) {
MS_LOG(ERROR) << "param data type and tensor data type do not match.";
return RET_ERROR;
}
inverse_ = true;
} else {
MS_LOG(ERROR) << "param data type not supported.";
return RET_ERROR;
}
num_unit_ = static_cast<int>(in_tensor->DataSize()); num_unit_ = static_cast<int>(in_tensor->DataSize());
thread_n_num_ = MSMIN(thread_num_, num_unit_); thread_n_num_ = MSMIN(thread_num_, num_unit_);
thread_n_stride_ = UP_DIV(num_unit_, thread_n_num_); thread_n_stride_ = UP_DIV(num_unit_, thread_n_num_);
...@@ -50,39 +69,50 @@ int QuantizeCPUKernel::Init() { ...@@ -50,39 +69,50 @@ int QuantizeCPUKernel::Init() {
return RET_OK; return RET_OK;
} }
int QuantizeCPUKernel::ReSize() { return RET_OK; } int QuantDTypeCastCPUKernel::ReSize() { return RET_OK; }
int QuantizeCPUKernel::Quantize(int task_id) { int QuantDTypeCastCPUKernel::QuantDTypeCast(int task_id) {
int num_unit_thread = MSMIN(thread_n_stride_, num_unit_ - task_id * thread_n_stride_); int num_unit_thread = MSMIN(thread_n_stride_, num_unit_ - task_id * thread_n_stride_);
if (num_unit_thread <= 0) { if (num_unit_thread <= 0) {
return RET_OK; return RET_OK;
} }
int thread_offset = task_id * thread_n_stride_; int thread_offset = task_id * thread_n_stride_;
auto quant_arg = inputs_.front()->GetQuantParams().front(); auto quant_arg = inputs_.front()->GetQuantParams().front();
int ret = QuantizeToInt8(input_ptr_ + thread_offset, output_ptr_ + thread_offset, quant_arg.scale, int ret;
quant_arg.zeroPoint, num_unit_thread); if (inverse_) {
ret = DequantizeInt8(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, quant_arg.scale, quant_arg.zeroPoint,
num_unit_thread);
} else {
ret = QuantizeToInt8(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale,
quant_arg.zeroPoint, num_unit_thread);
}
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "Quantize error task_id[" << task_id << "] error_code[" << ret << "]"; MS_LOG(ERROR) << "QuantDTypeCast error task_id[" << task_id << "] error_code[" << ret << "]";
return RET_ERROR; return RET_ERROR;
} }
return RET_OK; return RET_OK;
} }
int QuantizeRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { int QuantDTypeCastRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
auto g_kernel = reinterpret_cast<QuantizeCPUKernel *>(cdata); auto g_kernel = reinterpret_cast<QuantDTypeCastCPUKernel *>(cdata);
auto ret = g_kernel->Quantize(task_id); auto ret = g_kernel->QuantDTypeCast(task_id);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "QuantizeRun error task_id[" << task_id << "] error_code[" << ret << "]"; MS_LOG(ERROR) << "QuantDTypeCastRun error task_id[" << task_id << "] error_code[" << ret << "]";
return RET_ERROR; return RET_ERROR;
} }
return RET_OK; return RET_OK;
} }
int QuantizeCPUKernel::Run() { int QuantDTypeCastCPUKernel::Run() {
input_ptr_ = reinterpret_cast<float *>(inputs_[0]->Data()); if (inverse_) {
output_ptr_ = reinterpret_cast<int8_t *>(outputs_[0]->Data()); int8_ptr_ = reinterpret_cast<int8_t *>(inputs_[0]->Data());
int ret = LiteBackendParallelLaunch(QuantizeRun, this, thread_n_num_); float32_ptr_ = reinterpret_cast<float *>(outputs_[0]->Data());
} else {
float32_ptr_ = reinterpret_cast<float *>(inputs_[0]->Data());
int8_ptr_ = reinterpret_cast<int8_t *>(outputs_[0]->Data());
}
int ret = LiteBackendParallelLaunch(QuantDTypeCastRun, this, thread_n_num_);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "Scale error error_code[" << ret << "]"; MS_LOG(ERROR) << "Scale error error_code[" << ret << "]";
return RET_ERROR; return RET_ERROR;
...@@ -91,17 +121,17 @@ int QuantizeCPUKernel::Run() { ...@@ -91,17 +121,17 @@ int QuantizeCPUKernel::Run() {
return RET_OK; return RET_OK;
} }
kernel::LiteKernel *CpuQuantizeFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, kernel::LiteKernel *CpuQuantDTypeCastFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx, OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc) { const kernel::KernelKey &desc) {
if (opParameter == nullptr) { if (opParameter == nullptr) {
MS_LOG(ERROR) << "Input opParameter is nullptr!"; MS_LOG(ERROR) << "Input opParameter is nullptr!";
return nullptr; return nullptr;
} }
auto *kernel = new (std::nothrow) QuantizeCPUKernel(opParameter, inputs, outputs, ctx); auto *kernel = new (std::nothrow) QuantDTypeCastCPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) { if (kernel == nullptr) {
MS_LOG(ERROR) << "new QuantizeCPUKernel fail!"; MS_LOG(ERROR) << "new QuantDTypeCastCPUKernel fail!";
return nullptr; return nullptr;
} }
auto ret = kernel->Init(); auto ret = kernel->Init();
...@@ -114,5 +144,5 @@ kernel::LiteKernel *CpuQuantizeFp32KernelCreator(const std::vector<lite::tensor: ...@@ -114,5 +144,5 @@ kernel::LiteKernel *CpuQuantizeFp32KernelCreator(const std::vector<lite::tensor:
return kernel; return kernel;
} }
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_OnnxInt8Quantize, CpuQuantizeFp32KernelCreator) REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_QuantDTypeCast, CpuQuantDTypeCastFp32KernelCreator)
} // namespace mindspore::kernel } // namespace mindspore::kernel
...@@ -14,33 +14,34 @@ ...@@ -14,33 +14,34 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DEQUANTIZE_H_ #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_QUANTDTYPECAST_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DEQUANTIZE_H_ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_QUANTDTYPECAST_H_
#include <vector> #include <vector>
#include "src/lite_kernel.h" #include "src/lite_kernel.h"
namespace mindspore::kernel { namespace mindspore::kernel {
class DequantizeCPUKernel : public LiteKernel { class QuantDTypeCastCPUKernel : public LiteKernel {
public: public:
DequantizeCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, QuantDTypeCastCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx) const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
: LiteKernel(parameter, inputs, outputs), thread_num_(ctx->threadNum) {} : LiteKernel(parameter, inputs, outputs), thread_num_(ctx->threadNum) {}
~DequantizeCPUKernel() = default; ~QuantDTypeCastCPUKernel() = default;
int Init() override; int Init() override;
int ReSize() override; int ReSize() override;
int Run() override; int Run() override;
int Dequantize(int task_id); int QuantDTypeCast(int task_id);
private: private:
int thread_num_; int thread_num_;
int thread_n_num_; int thread_n_num_;
int thread_n_stride_; int thread_n_stride_;
int num_unit_; int num_unit_;
int8_t *input_ptr_; int8_t *int8_ptr_;
float *output_ptr_; float *float32_ptr_;
bool inverse_;
}; };
} // namespace mindspore::kernel } // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DEQUANTIZE_H_ #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_QUANTDTYPECAST_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_QUANTIZE_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_QUANTIZE_H_
#include <vector>
#include "src/lite_kernel.h"
namespace mindspore::kernel {
class QuantizeCPUKernel : public LiteKernel {
public:
QuantizeCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
: LiteKernel(parameter, inputs, outputs), thread_num_(ctx->threadNum) {}
~QuantizeCPUKernel() = default;
int Init() override;
int ReSize() override;
int Run() override;
int Quantize(int task_id);
private:
int thread_num_;
int thread_n_num_;
int thread_n_stride_;
int num_unit_;
float *input_ptr_;
int8_t *output_ptr_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_QUANTIZE_H_
...@@ -16,11 +16,11 @@ ...@@ -16,11 +16,11 @@
#include "src/runtime/kernel/arm/fp32/space_to_depth.h" #include "src/runtime/kernel/arm/fp32/space_to_depth.h"
#include <vector> #include <vector>
#include "schema/ops_generated.h"
#include "schema/model_generated.h" #include "schema/model_generated.h"
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include "src/runtime/kernel/arm/opclib/fp32/space_to_depth.h" #include "src/runtime/kernel/arm/opclib/fp32/space_to_depth.h"
#include "include/errorcode.h" #include "include/errorcode.h"
#include "src/runtime/runtime_api.h"
using mindspore::lite::KernelRegistrar; using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR; using mindspore::lite::RET_ERROR;
...@@ -41,21 +41,48 @@ int SpaceToDepthCPUKernel::Init() { ...@@ -41,21 +41,48 @@ int SpaceToDepthCPUKernel::Init() {
MS_LOG(ERROR) << "Input block_size should > 0!"; MS_LOG(ERROR) << "Input block_size should > 0!";
return RET_PARAM_INVALID; return RET_PARAM_INVALID;
} }
num_unit_ = static_cast<int>(inputs_[0]->shape().at(kNHWC_H));
thread_h_num_ = MSMIN(thread_num_, num_unit_);
thread_h_stride_ = UP_DIV(num_unit_, thread_h_num_);
return RET_OK; return RET_OK;
} }
int SpaceToDepthCPUKernel::Run() { int SpaceToDepthCPUKernel::SpaceToDepth(int task_id) {
auto input = inputs_[0]; int num_unit_thread = MSMIN(thread_h_stride_, num_unit_ - task_id * thread_h_stride_);
auto output = outputs_[0]; if (num_unit_thread <= 0) {
const float *input_data = static_cast<const float *>(input->Data()); return RET_OK;
float *output_data = static_cast<float *>(output->Data()); }
auto in_shape = input->shape(); int thread_offset = task_id * thread_h_stride_;
auto out_shape = output->shape(); auto in_shape = inputs_[0]->shape();
auto out_shape = outputs_[0]->shape();
SpaceToDepthParameter *param = reinterpret_cast<SpaceToDepthParameter *>(opParameter); SpaceToDepthParameter *param = reinterpret_cast<SpaceToDepthParameter *>(opParameter);
if (input->GetFormat() == schema::Format_NHWC) { auto ret = SpaceToDepthForNHWC(input_ptr_, output_ptr_, in_shape.data(), out_shape.data(), in_shape.size(),
auto ret = SpaceToDepthForNHWC(input_data, output_data, in_shape.data(), out_shape.data(), in_shape.size(), param->block_size_, thread_offset, thread_offset + num_unit_thread);
param->block_size_); if (ret != RET_OK) {
return ret; MS_LOG(ERROR) << "SpaceToDepth error task_id[" << task_id << "] error_code[" << ret << "]";
return RET_ERROR;
}
return RET_OK;
}
int SpaceToDepthRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
auto g_kernel = reinterpret_cast<SpaceToDepthCPUKernel *>(cdata);
auto ret = g_kernel->SpaceToDepth(task_id);
if (ret != RET_OK) {
MS_LOG(ERROR) << "SpaceToDepthRun error task_id[" << task_id << "] error_code[" << ret << "]";
return RET_ERROR;
}
return RET_OK;
}
int SpaceToDepthCPUKernel::Run() {
if (inputs_[0]->GetFormat() == schema::Format_NHWC) {
int ret = LiteBackendParallelLaunch(SpaceToDepthRun, this, thread_h_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "SpaceToDepth error error_code[" << ret << "]";
return ret;
}
} else { } else {
MS_LOG(ERROR) << "Only support NHWC now!"; MS_LOG(ERROR) << "Only support NHWC now!";
return RET_ERROR; return RET_ERROR;
...@@ -69,7 +96,7 @@ kernel::LiteKernel *CpuSpaceToDepthFp32KernelCreator(const std::vector<lite::ten ...@@ -69,7 +96,7 @@ kernel::LiteKernel *CpuSpaceToDepthFp32KernelCreator(const std::vector<lite::ten
MS_LOG(ERROR) << "Input opParameter is nullptr!"; MS_LOG(ERROR) << "Input opParameter is nullptr!";
return nullptr; return nullptr;
} }
auto *kernel = new (std::nothrow) SpaceToDepthCPUKernel(opParameter, inputs, outputs); auto *kernel = new (std::nothrow) SpaceToDepthCPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) { if (kernel == nullptr) {
MS_LOG(ERROR) << "new SpaceToDepthCPUKernel fail!"; MS_LOG(ERROR) << "new SpaceToDepthCPUKernel fail!";
return nullptr; return nullptr;
......
...@@ -24,13 +24,22 @@ namespace mindspore::kernel { ...@@ -24,13 +24,22 @@ namespace mindspore::kernel {
class SpaceToDepthCPUKernel : public LiteKernel { class SpaceToDepthCPUKernel : public LiteKernel {
public: public:
SpaceToDepthCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, SpaceToDepthCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs) const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
: LiteKernel(parameter, inputs, outputs) {} : LiteKernel(parameter, inputs, outputs), thread_num_(ctx->threadNum) {}
~SpaceToDepthCPUKernel() = default; ~SpaceToDepthCPUKernel() = default;
int SpaceToDepth(int task_id);
int Init() override; int Init() override;
int ReSize() override { return 0; }; int ReSize() override { return 0; };
int Run() override; int Run() override;
private:
int thread_num_;
int thread_h_stride_;
int thread_h_num_;
int num_unit_;
float *input_ptr_;
float *output_ptr_;
}; };
} // namespace mindspore::kernel } // namespace mindspore::kernel
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/int8/dequantize.h"
#include <vector>
#include "src/runtime/kernel/arm/opclib/int8/dequantize.h"
#include "src/runtime/runtime_api.h"
#include "src/kernel_registry.h"
#include "schema/model_generated.h"
#include "include/errorcode.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_OnnxInt8Dequantize;
namespace mindspore::kernel {
namespace {
constexpr int kDequantizeInputNum = 1;
constexpr int kDequantizeOutputNum = 1;
} // namespace
int DequantizeCPUKernel::Init() {
if (inputs_.size() != 1) {
MS_LOG(ERROR) << "inputs number should be 1, but " << inputs_.size() << " is given.";
return RET_ERROR;
}
if (outputs_.size() != 1) {
MS_LOG(ERROR) << "outputs number should be 1, but " << inputs_.size() << " is given.";
return RET_ERROR;
}
auto in_tensor = inputs_.front();
num_unit_ = static_cast<int>(in_tensor->DataSize());
thread_n_num_ = MSMIN(thread_num_, num_unit_);
thread_n_stride_ = UP_DIV(num_unit_, thread_n_num_);
return RET_OK;
}
int DequantizeCPUKernel::ReSize() { return RET_OK; }
int DequantizeCPUKernel::Dequantize(int task_id) {
int num_unit_thread = MSMIN(thread_n_stride_, num_unit_ - task_id * thread_n_stride_);
if (num_unit_thread <= 0) {
return RET_OK;
}
int thread_offset = task_id * thread_n_stride_;
auto quant_arg = inputs_.front()->GetQuantParams().front();
int ret = DequantizeInt8(input_ptr_ + thread_offset, output_ptr_ + thread_offset, quant_arg.scale,
quant_arg.zeroPoint, num_unit_thread);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Dequantize error task_id[" << task_id << "] error_code[" << ret << "]";
return RET_ERROR;
}
return RET_OK;
}
int DequantizeRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
auto g_kernel = reinterpret_cast<DequantizeCPUKernel *>(cdata);
auto ret = g_kernel->Dequantize(task_id);
if (ret != RET_OK) {
MS_LOG(ERROR) << "DequantizeRun error task_id[" << task_id << "] error_code[" << ret << "]";
return RET_ERROR;
}
return RET_OK;
}
int DequantizeCPUKernel::Run() {
input_ptr_ = reinterpret_cast<int8_t *>(inputs_[0]->Data());
output_ptr_ = reinterpret_cast<float *>(outputs_[0]->Data());
int ret = LiteBackendParallelLaunch(DequantizeRun, this, thread_n_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Scale error error_code[" << ret << "]";
return RET_ERROR;
}
return RET_OK;
}
kernel::LiteKernel *CpuDequantizeFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc) {
if (opParameter == nullptr) {
MS_LOG(ERROR) << "Input opParameter is nullptr!";
return nullptr;
}
auto *kernel = new (std::nothrow) DequantizeCPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new DequantizeCPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed! name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
delete kernel;
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_OnnxInt8Dequantize, CpuDequantizeFp32KernelCreator)
} // namespace mindspore::kernel
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_QUANTIZE_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_QUANTIZE_H_
#include "src/runtime/kernel/arm/opclib/op_base.h"
struct QuantizeParameter {
OpParameter op_parameter_;
};
int QuantizeToInt8(float *real_values, int8_t *quant_values, float scale, int32_t zp, int size);
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_QUANTIZE_H_
...@@ -19,13 +19,16 @@ ...@@ -19,13 +19,16 @@
#include "src/runtime/kernel/arm/opclib/op_base.h" #include "src/runtime/kernel/arm/opclib/op_base.h"
int SpaceToDepthForNHWC(const float *input, float *output, int *in_shape, int *out_shape, int shape_size, int SpaceToDepthForNHWC(const float *input, float *output, int *in_shape, int *out_shape, int shape_size,
int block_size) { int block_size, int h_start, int h_end) {
if (input == nullptr || output == nullptr) { if (input == nullptr || output == nullptr) {
return OPCLIB_NULL_PTR; return OPCLIB_NULL_PTR;
} }
if (shape_size != C4NUM) { if (shape_size != C4NUM) {
return OPCLIB_PARAM_INVALID; return OPCLIB_PARAM_INVALID;
} }
if (h_start < 0 || h_start >= h_end || h_end > out_shape[1]) {
return OPCLIB_PARAM_INVALID;
}
int in_strides[C4NUM]; int in_strides[C4NUM];
ComputeStrides(in_shape, in_strides, shape_size); ComputeStrides(in_shape, in_strides, shape_size);
int out_strides[C4NUM]; int out_strides[C4NUM];
...@@ -33,7 +36,7 @@ int SpaceToDepthForNHWC(const float *input, float *output, int *in_shape, int *o ...@@ -33,7 +36,7 @@ int SpaceToDepthForNHWC(const float *input, float *output, int *in_shape, int *o
for (int i = 0; i < out_shape[0]; ++i) { for (int i = 0; i < out_shape[0]; ++i) {
size_t in_offset_n = i * in_strides[0]; size_t in_offset_n = i * in_strides[0];
size_t out_offset_n = i * out_strides[0]; size_t out_offset_n = i * out_strides[0];
for (int j = 0; j < out_shape[1]; ++j) { for (int j = h_start; j < h_end; ++j) {
size_t in_offset_h = in_offset_n + j * block_size * in_strides[1]; size_t in_offset_h = in_offset_n + j * block_size * in_strides[1];
size_t out_offset_h = out_offset_n + j * out_strides[1]; size_t out_offset_h = out_offset_n + j * out_strides[1];
for (int k = 0; k < out_shape[2]; ++k) { for (int k = 0; k < out_shape[2]; ++k) {
......
...@@ -23,5 +23,5 @@ struct SpaceToDepthParameter { ...@@ -23,5 +23,5 @@ struct SpaceToDepthParameter {
}; };
int SpaceToDepthForNHWC(const float *input, float *output, int *in_shape, int *out_shape, int shape_size, int SpaceToDepthForNHWC(const float *input, float *output, int *in_shape, int *out_shape, int shape_size,
int block_size); int block_size, int h_start, int h_end);
#endif // MINDSPORE_LITE_SRC_BACKEND_ARM_OPCLIB_FP32_SPACE_TO_DEPTH_H_ #endif // MINDSPORE_LITE_SRC_BACKEND_ARM_OPCLIB_FP32_SPACE_TO_DEPTH_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/opclib/int8/dequantize.h"
#include "src/runtime/kernel/arm/opclib/errorcode.h"
int DequantizeInt8(int8_t *quant_values, float *real_values, float scale, int32_t zp, int size) {
if (quant_values == nullptr || real_values == nullptr) {
return OPCLIB_PARAM_INVALID;
}
for (int i = 0; i < size; ++i) {
real_values[i] = (quant_values[i] + zp) * scale;
}
return OPCLIB_OK;
}
...@@ -14,9 +14,20 @@ ...@@ -14,9 +14,20 @@
* limitations under the License. * limitations under the License.
*/ */
#include "src/runtime/kernel/arm/opclib/fp32/quantize.h" #include "src/runtime/kernel/arm/opclib/int8/quant_dtype_cast.h"
#include "src/runtime/kernel/arm/opclib/errorcode.h" #include "src/runtime/kernel/arm/opclib/errorcode.h"
int DequantizeInt8(int8_t *quant_values, float *real_values, float scale, int32_t zp, int size) {
if (quant_values == nullptr || real_values == nullptr) {
return OPCLIB_PARAM_INVALID;
}
for (int i = 0; i < size; ++i) {
real_values[i] = (quant_values[i] + zp) * scale;
}
return OPCLIB_OK;
}
int QuantizeToInt8(float *real_values, int8_t *quant_values, float scale, int32_t zp, int size) { int QuantizeToInt8(float *real_values, int8_t *quant_values, float scale, int32_t zp, int size) {
if (quant_values == nullptr || real_values == nullptr) { if (quant_values == nullptr || real_values == nullptr) {
return OPCLIB_PARAM_INVALID; return OPCLIB_PARAM_INVALID;
......
...@@ -14,15 +14,18 @@ ...@@ -14,15 +14,18 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_DEQUANTIZE_H_ #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_QUANTDTYPECAST_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_DEQUANTIZE_H_ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_QUANTDTYPECAST_H_
#include "src/runtime/kernel/arm/opclib/op_base.h" #include "src/runtime/kernel/arm/opclib/op_base.h"
struct DequantizeParameter { struct QuantDTypeCastParameter {
OpParameter op_parameter_; OpParameter op_parameter_;
int32_t srcT;
int32_t dstT;
}; };
int DequantizeInt8(int8_t *quant_values, float *real_values, float scale, int32_t zp, int size); int DequantizeInt8(int8_t *quant_values, float *real_values, float scale, int32_t zp, int size);
int QuantizeToInt8(float *real_values, int8_t *quant_values, float scale, int32_t zp, int size);
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_DEQUANTIZE_H_ #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_QUANTDTYPECAST_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include "utils/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/src/runtime/kernel/arm/fp32/quantize.h"
#include "mindspore/lite/src/runtime/kernel/arm/opclib/fp32/quantize.h"
#include "mindspore/lite/src/kernel_registry.h"
#include "mindspore/lite/src/lite_kernel.h"
namespace mindspore {
class QuantizeTestFp32 : public mindspore::Common {
public:
QuantizeTestFp32() {}
};
TEST_F(QuantizeTestFp32, QuantizeTest1) {
const lite::tensor::QuantArg quant_arg = {0.3515625, -57};
QuantizeParameter param;
param.op_parameter_.type_ = schema::PrimitiveType_OnnxInt8Quantize;
std::vector<float> input = {1, 2, 5, 6, 10, -20, 3, 8, 18, 10, 3, 4, 11, 16, 15, 25};
std::vector<int> in_shape = {1, 4, 4, 1};
lite::tensor::Tensor input_tensor;
input_tensor.SetData(input.data());
input_tensor.set_shape(in_shape);
input_tensor.SetFormat(schema::Format_NHWC);
input_tensor.set_data_type(kNumberTypeFloat32);
input_tensor.AddQuantParam(quant_arg);
std::vector<lite::tensor::Tensor *> inputs_tensor;
inputs_tensor.emplace_back(&input_tensor);
const int out_size = 16;
int8_t expect_out[16] = {-54, -51, -43, -40, -29, -114, -48, -34, -6, -29, -48, -46, -26, -11, -14, 14};
std::vector<int8_t> output(16);
std::vector<int> out_shape = {1, 4, 4, 1};
lite::tensor::Tensor output_tensor;
output_tensor.SetData(output.data());
output_tensor.set_shape(out_shape);
output_tensor.SetFormat(schema::Format_NHWC);
output_tensor.set_data_type(kNumberTypeInt8);
std::vector<lite::tensor::Tensor *> outputs_tensor;
outputs_tensor.emplace_back(&output_tensor);
lite::Context ctx;
ctx.threadNum = 3;
kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_OnnxInt8Quantize};
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
ASSERT_NE(creator, nullptr);
kernel::LiteKernel *kernel =
creator(inputs_tensor, outputs_tensor, reinterpret_cast<OpParameter *>(&param), &ctx, desc);
ASSERT_NE(kernel, nullptr);
kernel->Run();
for (int i = 0; i < out_size; ++i) {
std::cout << output[i] << " ";
}
std::cout << "\n";
CompareOutputData(output.data(), expect_out, out_size, 0.000001);
}
} // namespace mindspore
...@@ -37,7 +37,9 @@ TEST_F(SpaceToDepthTestFp32, SpaceToDepthTest1) { ...@@ -37,7 +37,9 @@ TEST_F(SpaceToDepthTestFp32, SpaceToDepthTest1) {
float output[16]; float output[16];
int in_shape[4] = {1, 4, 4, 1}; int in_shape[4] = {1, 4, 4, 1};
int out_shape[4] = {1, 2, 2, 4}; int out_shape[4] = {1, 2, 2, 4};
SpaceToDepthForNHWC((const float *)input, output, in_shape, out_shape, 4, 2); int h_start = 0;
int h_end = 2;
SpaceToDepthForNHWC((const float *)input, output, in_shape, out_shape, 4, 2, h_start, h_end);
for (int i = 0; i < out_size; ++i) { for (int i = 0; i < out_size; ++i) {
std::cout << output[i] << " "; std::cout << output[i] << " ";
} }
...@@ -69,10 +71,11 @@ TEST_F(SpaceToDepthTestFp32, SpaceToDepthTest2) { ...@@ -69,10 +71,11 @@ TEST_F(SpaceToDepthTestFp32, SpaceToDepthTest2) {
outputs_tensor.emplace_back(&output_tensor); outputs_tensor.emplace_back(&output_tensor);
SpaceToDepthParameter op_param; SpaceToDepthParameter op_param;
op_param.op_parameter_.type_ = schema::PrimitiveType_SpaceToBatch; op_param.op_parameter_.type_ = schema::PrimitiveType_SpaceToDepth;
op_param.block_size_ = 2; op_param.block_size_ = 2;
lite::Context ctx; lite::Context ctx;
ctx.threadNum = 3;
kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_SpaceToDepth}; kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_SpaceToDepth};
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
ASSERT_NE(creator, nullptr); ASSERT_NE(creator, nullptr);
......
...@@ -17,28 +17,26 @@ ...@@ -17,28 +17,26 @@
#include <memory> #include <memory>
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "common/common_test.h" #include "common/common_test.h"
#include "mindspore/lite/src/runtime/kernel/arm/int8/dequantize.h" #include "mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.h"
#include "mindspore/lite/src/runtime/kernel/arm/opclib/int8/dequantize.h" #include "mindspore/lite/src/runtime/kernel/arm/opclib/int8/quant_dtype_cast.h"
#include "mindspore/lite/src/kernel_registry.h" #include "mindspore/lite/src/kernel_registry.h"
#include "mindspore/lite/src/lite_kernel.h" #include "mindspore/lite/src/lite_kernel.h"
namespace mindspore { namespace mindspore {
class DequantizeTestFp32 : public mindspore::Common { class QuantDTypeCastTestFp32 : public mindspore::Common {
public: public:
DequantizeTestFp32() {} QuantDTypeCastTestFp32() {}
}; };
TEST_F(DequantizeTestFp32, DequantizeTest1) { TEST_F(QuantDTypeCastTestFp32, QuantDTypeCastTest1) {
const lite::tensor::QuantArg quant_arg{0.21176, 5}; const lite::tensor::QuantArg quant_arg{0.21176, 5};
// quant_arg.scale = 100.0; QuantDTypeCastParameter param;
// quant_arg.zeroPoint = 20; param.srcT = kNumberTypeInt8;
DequantizeParameter param; param.dstT = kNumberTypeFloat32;
param.op_parameter_.type_ = schema::PrimitiveType_OnnxInt8Dequantize; param.op_parameter_.type_ = schema::PrimitiveType_QuantDTypeCast;
std::vector<int8_t> input = {10, 14, 29, 33, 52, 99, 19, 43, 90, 52, 19, 24, 57, 127, 76, 123}; std::vector<int8_t> input = {10, 14, 29, 33, 52, 99, 19, 43, 90, 52, 19, 24, 57, 127, 76, 123};
// int8_t input0[] = {1, 2, 10};
// int32_t a = input0[0] + 2;
std::vector<int> in_shape = {1, 4, 4, 1}; std::vector<int> in_shape = {1, 4, 4, 1};
lite::tensor::Tensor input_tensor; lite::tensor::Tensor input_tensor;
input_tensor.SetData(input.data()); input_tensor.SetData(input.data());
...@@ -59,13 +57,13 @@ TEST_F(DequantizeTestFp32, DequantizeTest1) { ...@@ -59,13 +57,13 @@ TEST_F(DequantizeTestFp32, DequantizeTest1) {
output_tensor.SetData(output.data()); output_tensor.SetData(output.data());
output_tensor.set_shape(out_shape); output_tensor.set_shape(out_shape);
output_tensor.set_data_type(kNumberTypeFloat32); output_tensor.set_data_type(kNumberTypeFloat32);
output_tensor.SetFormat(schema::Format_NHWC); // output_tensor.SetFormat(schema::Format_NHWC);
std::vector<lite::tensor::Tensor *> outputs_tensor; std::vector<lite::tensor::Tensor *> outputs_tensor;
outputs_tensor.emplace_back(&output_tensor); outputs_tensor.emplace_back(&output_tensor);
lite::Context ctx; lite::Context ctx;
ctx.threadNum = 3; ctx.threadNum = 3;
kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_OnnxInt8Dequantize}; kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_QuantDTypeCast};
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
ASSERT_NE(creator, nullptr); ASSERT_NE(creator, nullptr);
kernel::LiteKernel *kernel = kernel::LiteKernel *kernel =
...@@ -80,4 +78,49 @@ TEST_F(DequantizeTestFp32, DequantizeTest1) { ...@@ -80,4 +78,49 @@ TEST_F(DequantizeTestFp32, DequantizeTest1) {
CompareOutputData(output.data(), expect_out, out_size, 0.000001); CompareOutputData(output.data(), expect_out, out_size, 0.000001);
} }
TEST_F(QuantDTypeCastTestFp32, QuantDTypeCastTest2) {
const lite::tensor::QuantArg quant_arg = {0.3515625, -57};
QuantDTypeCastParameter param;
param.op_parameter_.type_ = schema::PrimitiveType_QuantDTypeCast;
param.dstT = kNumberTypeInt8;
param.srcT = kNumberTypeFloat32;
std::vector<float> input = {1, 2, 5, 6, 10, -20, 3, 8, 18, 10, 3, 4, 11, 16, 15, 25};
std::vector<int> in_shape = {1, 4, 4, 1};
lite::tensor::Tensor input_tensor;
input_tensor.SetData(input.data());
input_tensor.set_shape(in_shape);
// input_tensor.SetFormat(schema::Format_NHWC);
input_tensor.set_data_type(kNumberTypeFloat32);
input_tensor.AddQuantParam(quant_arg);
std::vector<lite::tensor::Tensor *> inputs_tensor;
inputs_tensor.emplace_back(&input_tensor);
const int out_size = 16;
int8_t expect_out[16] = {-54, -51, -43, -40, -29, -114, -48, -34, -6, -29, -48, -46, -26, -11, -14, 14};
std::vector<int8_t> output(16);
std::vector<int> out_shape = {1, 4, 4, 1};
lite::tensor::Tensor output_tensor;
output_tensor.SetData(output.data());
output_tensor.set_shape(out_shape);
output_tensor.SetFormat(schema::Format_NHWC);
output_tensor.set_data_type(kNumberTypeInt8);
std::vector<lite::tensor::Tensor *> outputs_tensor;
outputs_tensor.emplace_back(&output_tensor);
lite::Context ctx;
ctx.threadNum = 3;
kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_QuantDTypeCast};
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
ASSERT_NE(creator, nullptr);
kernel::LiteKernel *kernel =
creator(inputs_tensor, outputs_tensor, reinterpret_cast<OpParameter *>(&param), &ctx, desc);
ASSERT_NE(kernel, nullptr);
kernel->Run();
for (int i = 0; i < out_size; ++i) {
std::cout << output[i] << " ";
}
std::cout << "\n";
CompareOutputData(output.data(), expect_out, out_size, 0.000001);
}
} // namespace mindspore } // namespace mindspore
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册