diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/div_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/div_int8.cc new file mode 100644 index 0000000000000000000000000000000000000000..e5db6555d128edf2b0d0c4065f76520fb8f3f10b --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/div_int8.cc @@ -0,0 +1,194 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/div_int8.h" +#include +#include +#include "src/runtime/kernel/arm/nnacl/arithmetic_common.h" +#include "src/runtime/runtime_api.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Div; + +namespace mindspore::kernel { + +int DivInt8CPUKernel::Init() { + lite::tensor::Tensor *input0 = in_tensors_.at(0); + lite::tensor::Tensor *input1 = in_tensors_.at(1); + lite::tensor::Tensor *output = out_tensors_.at(0); + MS_ASSERT(input0); + MS_ASSERT(input1); + MS_ASSERT(output); + + broadcast_ = input0->ElementsNum() != input1->ElementsNum(); + + param_.in0_args_.scale_ = input0->GetQuantParams().front().scale; + param_.in0_args_.zp_ = -input0->GetQuantParams().front().zeroPoint; + param_.in1_args_.scale_ = input1->GetQuantParams().front().scale; + param_.in1_args_.zp_ = -input1->GetQuantParams().front().zeroPoint; + param_.out_args_.scale_ = output->GetQuantParams().front().scale; + param_.out_args_.zp_ = output->GetQuantParams().front().zeroPoint; + + const double real_multiplier = param_.in0_args_.scale_ / (param_.in1_args_.scale_ * param_.out_args_.scale_); + + QuantizeMultiplier(real_multiplier, ¶m_.output_multiplier_, ¶m_.output_shift_); + + param_.output_activation_min_ = std::numeric_limits::min(); + param_.output_activation_max_ = std::numeric_limits::max(); + + if (!InferShapeDone()) { + return RET_OK; + } + return ReSize(); +} + +int DivInt8CPUKernel::ReSize() { + if (broadcast_) { + if (tile0_data_ != nullptr) { + if (context_ != nullptr && context_->allocator != nullptr) { + context_->allocator->Free(tile0_data_); + } else { + free(tile0_data_); + } + } + if (tile1_data_ != nullptr) { + if (context_ != nullptr && context_->allocator != nullptr) { + context_->allocator->Free(tile1_data_); + } else { + free(tile1_data_); + } + } + + if (context_ != nullptr && context_->allocator != nullptr) { + tile0_data_ = static_cast(context_->allocator->Malloc(out_tensors_.at(0)->Size())); + tile1_data_ = static_cast(context_->allocator->Malloc(out_tensors_.at(0)->Size())); + } else { + tile0_data_ = static_cast(malloc(sizeof(int8_t) * out_tensors_.at(0)->Size())); + tile1_data_ = static_cast(malloc(sizeof(int8_t) * out_tensors_.at(0)->Size())); + } + + if (tile0_data_ == nullptr || tile1_data_ == nullptr) { + if (tile0_data_ != nullptr) { + if (context_ != nullptr && context_->allocator != nullptr) { + context_->allocator->Free(tile0_data_); + } else { + free(tile0_data_); + } + } + if (tile1_data_ != nullptr) { + if (context_ != nullptr && context_->allocator != nullptr) { + context_->allocator->Free(tile1_data_); + } else { + free(tile1_data_); + } + } + MS_LOG(ERROR) << "malloc memroy fail!"; + return RET_ERROR; + } + } + return RET_OK; +} + +int DivInt8CPUKernel::DoExecute(int task_id) { + auto input0_data_ = static_cast(in_tensors_.at(0)->Data()); + auto input1_data_ = static_cast(in_tensors_.at(1)->Data()); + auto output_data_ = static_cast(out_tensors_.at(0)->Data()); + auto element_num = out_tensors_[0]->ElementsNum(); + + MS_ASSERT(op_parameter_->thread_num_ != 0); + int stride = UP_DIV(element_num, op_parameter_->thread_num_); + int count = MSMIN(stride, element_num - stride * task_id); + + auto ret = RET_OK; + if (broadcast_) { + ret = DivInt8(tile0_data_ + task_id * count, tile1_data_ + task_id * count, output_data_ + task_id * count, count, + ¶m_); + } else { + ret = DivInt8(input0_data_ + task_id * count, input1_data_ + task_id * count, output_data_ + task_id * count, count, + ¶m_); + } + + if (ret != RET_OK) { + MS_LOG(ERROR) << "Divint8 function error error_code[" << ret << "]"; + } + return ret; +} + +int DivInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto div_kernel = reinterpret_cast(cdata); + auto ret = div_kernel->DoExecute(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "DivInt8 DoExecute error task_id[" << task_id << "] error_code[" << ret << "]"; + } + return ret; +} + +int DivInt8CPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + return RET_ERROR; + } + + if (broadcast_) { + ArithmeticParameter tile_para = {0}; + tile_para.ndim_ = out_tensors_.at(0)->shape().size(); + for (size_t i = 0; i < tile_para.ndim_; i++) { + tile_para.in_shape0_[i] = in_tensors_.at(0)->DimensionSize(i); + tile_para.in_shape1_[i] = in_tensors_.at(1)->DimensionSize(i); + tile_para.out_shape_[i] = out_tensors_.at(0)->DimensionSize(i); + } + TileDimensionsUint8(static_cast(in_tensors_.at(0)->Data()), + static_cast(in_tensors_.at(1)->Data()), reinterpret_cast(tile0_data_), + reinterpret_cast(tile1_data_), &tile_para); + } + ret = LiteBackendParallelLaunch(DivInt8Run, this, op_parameter_->thread_num_); + + if (ret != RET_OK) { + MS_LOG(ERROR) << "DivInt8Run function error error_code[" << ret << "]"; + } + return ret; +} + +kernel::LiteKernel *CpuDivInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, OpParameter *parameter, + const lite::Context *ctx, const KernelKey &desc, + const lite::Primitive *primitive) { + if (parameter == nullptr || ctx == nullptr) { + MS_LOG(ERROR) << "parameter or ctx is nullptr"; + return nullptr; + } + MS_ASSERT(desc.type == PrimitiveType_Div); + auto *kernel = new (std::nothrow) DivInt8CPUKernel(parameter, inputs, outputs, ctx, primitive); + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ + << ", type: " << schema::EnumNamePrimitiveType(static_cast(parameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Div, CpuDivInt8KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/div_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/div_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..2385b87d2109ad1324537b1841517fdf600ee4dd --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/div_int8.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DIV_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DIV_INT8_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/nnacl/int8/div_int8.h" +#include "src/runtime/runtime_api.h" + +namespace mindspore::kernel { +class DivInt8CPUKernel : public LiteKernel { + public: + explicit DivInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx, + const lite::Primitive *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} + ~DivInt8CPUKernel() override {} + + int Init() override; + int ReSize() override; + int Run() override; + int DoExecute(int task_id); + + private: + DivQuantArg param_; + int8_t *tile0_data_ = nullptr; + int8_t *tile1_data_ = nullptr; + bool broadcast_ = false; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DIV_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/div_int8.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/div_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..72e3705960a4475d14fa74ffb369de83ff37e803 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/div_int8.c @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/div_int8.h" +#include "nnacl/quantization/fixed_point.h" +#include "nnacl/errorcode.h" + +int DivInt8(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, DivQuantArg *para) { + int index = 0; + for (; index < real_dst_count; ++index) { + const int32_t input0_val = para->in0_args_.zp_ + input0_data[index]; + const int32_t input1_val = para->in1_args_.zp_ + input1_data[index]; + if (input1_val == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + + int recip_shift; + const int32_t input1_inv = (input1_val > 0) ? ComputerReciproal(input1_val, 31, &recip_shift) + : -ComputerReciproal(-input1_val, 31, &recip_shift); + const int leading_bits = CountLeadingSignBits(input0_val); + const int32_t raw_data = + SaturatingRoundingDoublingHighMul(input0_val * (1 << (unsigned int)leading_bits), input1_inv); + const int total_shift = para->output_shift_ - recip_shift - leading_bits; + const int32_t raw_output = + RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(raw_data, para->output_multiplier_), -total_shift) + + para->out_args_.zp_; + output_data[index] = (int8_t)MSMAX(para->output_activation_min_, MSMIN(raw_output, para->output_activation_max_)); + } + return NNACL_OK; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/div_int8.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/div_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..5ebbaa3e43ead7b77b870b9916b48b8c9fc11490 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/div_int8.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_DIV_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_DIV_INT8_H_ + +#include "nnacl/op_base.h" +#include "nnacl/quantization/quantize.h" + +#ifdef __cplusplus +extern "C" { +#endif +int DivInt8(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, DivQuantArg *para); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_DIV_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/fixed_point.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/fixed_point.h index 0c444b85a3bec48a897bdd99a512aed913bb66ed..67b453fd3c79d778ebb3f45d95b2f803d38c321c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/fixed_point.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/fixed_point.h @@ -64,6 +64,114 @@ inline int MultiplyByQuantizedMultiplier(int32_t value, int32_t multiplier, int3 return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(value * (1 << left_shift), multiplier), -right_shift); } +inline int FractionsBits(int kIntegerBits) { + int totalBits = 8 * sizeof(int32_t) - 1; + return totalBits - kIntegerBits; +} + +inline int FixedPoint_One(int kIntegerBits, int kFractionsBits) { + return (kIntegerBits == 0 ? INT32_MAX : ((1) << (uint32_t)(kIntegerBits == 0 ? 0 : kFractionsBits))); +} + +inline int RoundingHalfSum(int a, int b) { + int64_t a64 = a; + int64_t b64 = b; + int64_t sum = a64 + b64; + int64_t sign = sum > 0 ? 1 : -1; + return (int32_t)((sum + sign) / 2); +} + +inline int32_t BitAnd(int32_t a, int32_t b) { return (uint32_t)a & (uint32_t)b; } + +inline int32_t BitOr(int32_t a, int32_t b) { return (uint32_t)a | (uint32_t)b; } + +inline int32_t BitXor(int32_t a, int32_t b) { return (uint32_t)a ^ (uint32_t)b; } + +inline int32_t BitNot(int32_t a) { return ~(uint32_t)a; } + +inline int SelectUsingMask(int mask, int bound, int val) { + return BitXor(BitAnd(mask, bound), BitAnd(BitNot(mask), val)); +} + +inline int32_t MaskNonZero(int32_t a) { + int32_t zreo = 0; + return a ? BitNot(zreo) : zreo; +} + +inline int SaturatingRoundingMultiplyByPOT(int32_t x, int Exponent) { + int ExponentSign = (Exponent > 0 ? 1 : Exponent < 0 ? -1 : 0); + if (ExponentSign == 0) { + return x; + } else if (ExponentSign == 1) { + const int min = INT32_MIN; + const int max = INT32_MAX; + const int thresold = ((1 << (uint32_t)(31 - Exponent)) - 1); + const int postive_mask = MaskNonZero(x > thresold); + const int negative_mask = MaskNonZero(x < -thresold); + int result = x << Exponent; + result = SelectUsingMask(postive_mask, max, result); + result = SelectUsingMask(negative_mask, min, result); + return result; + } else if (ExponentSign == -1) { + return RoundingDivideByPOT(x, -Exponent); + } else { + return 0; + } +} + +inline int32_t Rescale(int x, int kIntegerBitsSrc, int kIntegerBitsDst) { + int kExponent = kIntegerBitsSrc - kIntegerBitsDst; + int result = SaturatingRoundingMultiplyByPOT(x, kExponent); + return result; +} + +static inline int32_t one_over_one_plus_x_for_x_in_0_1(int32_t a) { + int one = FixedPoint_One(0, FractionsBits(0)); + int half_denominator = RoundingHalfSum(a, one); + const int constant_48_over_17 = 1515870810; + const int constant_neg_32_over_17 = -1010580540; + int x = constant_48_over_17 + SaturatingRoundingDoublingHighMul(half_denominator, constant_neg_32_over_17); + for (int i = 0; i < 3; i++) { + int half_denominator_times_x = SaturatingRoundingDoublingHighMul(half_denominator, x); + int one_minus_half_denominator_times_x = FixedPoint_One(2, FractionsBits(2)) - half_denominator_times_x; + x = x + Rescale(SaturatingRoundingDoublingHighMul(x, one_minus_half_denominator_times_x), 2 + 2, 2); + } + return Rescale(x, 2 - 1, 0); +} + +inline int CountLeadingZeroBits(uint32_t x) { +#if defined(__GUNC__) + return x ? __builtin_clz(x) : 8 * sizeof(uint32_t); +#else + if (x == 0) { + return 8 * sizeof(uint32_t); + } + const int32_t leading_positive = (int32_t)(1) << (8 * sizeof(uint32_t) - 1); + int leading_zeros = 0; + while (x < leading_positive) { + x <<= 1; + leading_zeros++; + } + return leading_zeros; +#endif +} + +inline int CountLeadingSignBits(int32_t x) { +#if defined(__GUNC__) && !defined(__clang__) + return x ? __builtin_clrsb(x) : 8 * sizeof(int32_t); +#else + return x >= 0 ? CountLeadingZeroBits((uint32_t)x) - 1 : x != INT32_MIN ? CountLeadingZeroBits(2 * (uint32_t)(-x)) : 0; +#endif +} + +static inline int32_t ComputerReciproal(int32_t x, int x_digits, int *recip_shift) { + int leading_zreos_plus_one = CountLeadingZeroBits((uint32_t)x); + *recip_shift = x_digits - leading_zreos_plus_one; + const int32_t shifted_minus_one = (int32_t)(((uint32_t)x << leading_zreos_plus_one) - ((uint32_t)(1) << 31)); + const int32_t shifted_scaled = one_over_one_plus_x_for_x_in_0_1(shifted_minus_one); + return shifted_scaled; +} + #ifdef __cplusplus } #endif diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/quantize.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/quantize.h index 7c0d3f4bf96b7f2eaf033c2287b7502a4a0bd32b..f6a561a8db2bfcc00c409337f6e2ff4055af1fd5 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/quantize.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/quantize.h @@ -197,6 +197,16 @@ typedef struct ArithmeticQuantArg { QuantArg in1_args_; QuantArg out_args_; } ArithmeticQuantArg; + +typedef struct DivQuantArg { + QuantArg in0_args_; + QuantArg in1_args_; + QuantArg out_args_; + int output_activation_min_; + int output_activation_max_; + int output_multiplier_; + int output_shift_; +} DivQuantArg; #ifdef __cplusplus extern "C" { #endif diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/div_int8_test.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/div_int8_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..48206673fc5b68adbe85df5295c33b697acb47c4 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/div_int8_test.cc @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "common/common_test.h" +#include "mindspore/lite/src/runtime/kernel/arm/int8/div_int8.h" +#include "mindspore/lite/src/kernel_registry.h" +#include "mindspore/lite/include/context.h" + +namespace mindspore { +class TestDivInt8 : public mindspore::CommonTest { + public: + TestDivInt8() {} +}; + +TEST_F(TestDivInt8, DivInt8) { + lite::tensor::Tensor in_tensor0(kNumberTypeInt8, {1, 1, 2, 5}); + lite::tensor::Tensor in_tensor1(kNumberTypeInt8, {1, 1, 2, 5}); + lite::tensor::Tensor out_tensor(kNumberTypeInt8, {1, 1, 2, 5}); + + int8_t input_data0[] = {105, 35, -27, 0, -63, 99, 16, 45, 67, -49}; + int8_t input_data1[] = {126, -38, -115, 106, -98, 119, 103, 81, -114, 68}; + int8_t output_data[10] = {0}; + in_tensor0.SetData(input_data0); + in_tensor1.SetData(input_data1); + out_tensor.SetData(output_data); + + const lite::tensor::QuantArg quant_in0 = {0.00784314f, 0}; // -1.0--1.0 -> 0--255 + const lite::tensor::QuantArg quant_in1 = {0.00784314f, 0}; + const lite::tensor::QuantArg quant_out = {0.00784314f, 0}; + in_tensor0.AddQuantParam(quant_in0); + in_tensor1.AddQuantParam(quant_in1); + out_tensor.AddQuantParam(quant_out); + + std::vector inputs = {&in_tensor0, &in_tensor1}; + std::vector outputs = {&out_tensor}; + + OpParameter parameter = {}; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Div}; + + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + + auto ctx = std::make_shared(); + auto kernel = creator(inputs, outputs, reinterpret_cast(¶meter), ctx.get(), desc, nullptr); + ASSERT_NE(kernel, nullptr); + + auto ret = kernel->Run(); + EXPECT_EQ(0, ret); + + int8_t expect0[10] = {106, -117, 30, 0, 82, 106, 20, 71, -75, -92}; + for (int i = 0; i < 10; ++i) { + EXPECT_EQ(output_data[i], expect0[i]); + } + + in_tensor0.SetData(nullptr); + in_tensor1.SetData(nullptr); + out_tensor.SetData(nullptr); +} +} // namespace mindspore