From b6b18e477acea4a4f55c80492383c009d50565e0 Mon Sep 17 00:00:00 2001 From: zhaozhenlong Date: Mon, 24 Aug 2020 11:25:12 +0800 Subject: [PATCH] softmax activation fp16 --- mindspore/lite/nnacl/fp16/activation_fp16.c | 98 +++++++++++ mindspore/lite/nnacl/fp16/activation_fp16.h | 44 +++++ mindspore/lite/nnacl/fp16/common_func.c | 61 ------- mindspore/lite/nnacl/fp16/common_func.h | 2 - mindspore/lite/nnacl/fp16/softmax_fp16.c | 67 ++++++++ mindspore/lite/nnacl/fp16/softmax_fp16.h | 33 ++++ .../kernel/arm/fp16/activation_fp16.cc | 156 ++++++++++++++++++ .../runtime/kernel/arm/fp16/activation_fp16.h | 52 ++++++ .../runtime/kernel/arm/fp16/softmax_fp16.cc | 156 ++++++++++++++++++ .../runtime/kernel/arm/fp16/softmax_fp16.h | 47 ++++++ 10 files changed, 653 insertions(+), 63 deletions(-) create mode 100644 mindspore/lite/nnacl/fp16/activation_fp16.c create mode 100644 mindspore/lite/nnacl/fp16/activation_fp16.h delete mode 100644 mindspore/lite/nnacl/fp16/common_func.c create mode 100644 mindspore/lite/nnacl/fp16/softmax_fp16.c create mode 100644 mindspore/lite/nnacl/fp16/softmax_fp16.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp16/activation_fp16.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp16/activation_fp16.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp16/softmax_fp16.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp16/softmax_fp16.h diff --git a/mindspore/lite/nnacl/fp16/activation_fp16.c b/mindspore/lite/nnacl/fp16/activation_fp16.c new file mode 100644 index 000000000..ff2b34767 --- /dev/null +++ b/mindspore/lite/nnacl/fp16/activation_fp16.c @@ -0,0 +1,98 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp16/activation_fp16.h" +#include "nnacl/errorcode.h" + +int ReluFp16(const float16_t *src, float16_t *dst, int ele_num) { + int eight_block = UP_DIV(ele_num, C8NUM); + int i; + for (i = 0; i < eight_block - 1; i++) { + int index = i * C8NUM; +#ifdef ENABLE_NEON + float16x8_t relu_src = vld1q_f16(src + index); + float16x8_t zero_src = vdupq_n_f16(0); + relu_src = vmaxq_f16(relu_src, zero_src); + vst1q_f16(dst + index, relu_src); +#else + int j; + for (j = 0; j < C8NUM; j++) { + dst[index + j] = src[index + j] < 0 ? 0 : src[index + j]; + } +#endif + } + for (int j = (eight_block - 1) * C8NUM; j < ele_num; ++j) { + dst[j] = src[j] < 0 ? 0 : src[j]; + } + return NNACL_OK; +} + +int Relu6Fp16(const float16_t *data, float16_t *dst, int ele_num) { + int eight_block = UP_DIV(ele_num, C8NUM); + int i; + for (i = 0; i < eight_block - 1; i++) { + int index = i * C8NUM; +#ifdef ENABLE_NEON + float16x8_t relu6_data = vld1q_f16(data + index); + float16x8_t zero_data = vdupq_n_f16(0); + float16x8_t six_data = vdupq_n_f16(6); + relu6_data = vmaxq_f16(relu6_data, zero_data); + relu6_data = vminq_f16(relu6_data, six_data); + vst1q_f16(dst + index, relu6_data); +#else + int j; + for (j = 0; j < C8NUM; ++j) { + dst[index + j] = data[index + j] < 0 ? 0 : data[index + j]; + dst[index + j] = dst[index + j] > 6 ? 6 : dst[index + j]; + } +#endif + } + for (int j = (eight_block - 1) * C8NUM; j < ele_num; ++j) { + dst[j] = data[j] < 0 ? 0 : data[j]; + dst[j] = dst[j] > 6 ? 6 : dst[j]; + } + return NNACL_OK; +} + +int LReluFp16(const float16_t *src, float16_t *dst, int ele_num, float16_t alpha) { + for (int i = 0; i < ele_num; ++i) { + dst[i] = src[i] > (float16_t)0.0f ? src[i] : (src[i] * alpha); + } + return NNACL_OK; +} + +int SigmoidFp16(const float16_t *src, float16_t *dst, int ele_num) { + for (int i = 0; i < ele_num; ++i) { + dst[i] = (float16_t)1.0f / (float16_t)(1.0f + exp(-src[i])); + } + return NNACL_OK; +} + +int TanhFp16(const float16_t *src, float16_t *dst, int ele_num) { + for (int i = 0; i < ele_num; ++i) { + dst[i] = (float16_t)1.0f - (float16_t)2.0f / (float16_t)(exp(2 * src[i]) + 1); + } + return NNACL_OK; +} + +int HSwishFp16(const float16_t *src, float16_t *dst, int ele_num) { + for (int i = 0; i < ele_num; ++i) { + float16_t in = src[i]; + float16_t relu6 = MSMIN(MSMAX(in + 3, 0), 6); + dst[i] = in * relu6 / (float16_t)6.0f; + } + return NNACL_OK; +} diff --git a/mindspore/lite/nnacl/fp16/activation_fp16.h b/mindspore/lite/nnacl/fp16/activation_fp16.h new file mode 100644 index 000000000..eea4b489f --- /dev/null +++ b/mindspore/lite/nnacl/fp16/activation_fp16.h @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_NNACL_FP16_ACTIVATION_FP16_H_ +#define MINDSPORE_LITE_NNACL_FP16_ACTIVATION_FP16_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include +#include "nnacl/op_base.h" +#include "nnacl/quantization/fixed_point.h" + +typedef struct ActivationParameter { + OpParameter op_parameter_; + int type_; + float alpha_; +} ActivationParameter; + +#ifdef __cplusplus +extern "C" { +#endif +int ReluFp16(const float16_t *src, float16_t *dst, int ele_num); +int Relu6Fp16(const float16_t *data, float16_t *dst, int ele_num); +int LReluFp16(const float16_t *src, float16_t *dst, int ele_num, float16_t alpha); +int SigmoidFp16(const float16_t *src, float16_t *dst, int ele_num); +int TanhFp16(const float16_t *src, float16_t *dst, int ele_num); +int HSwishFp16(const float16_t *src, float16_t *dst, int ele_num); +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_LITE_NNACL_FP16_ACTIVATION_FP16_H_ diff --git a/mindspore/lite/nnacl/fp16/common_func.c b/mindspore/lite/nnacl/fp16/common_func.c deleted file mode 100644 index 84ddcd8e4..000000000 --- a/mindspore/lite/nnacl/fp16/common_func.c +++ /dev/null @@ -1,61 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "nnacl/fp16/common_func.h" - -void ReluFp16(float16_t *data, float16_t *dst, int ele_num) { - int eight_block = UP_DIV(ele_num, C8NUM); - for (int i = 0; i < eight_block - 1; i++) { - int index = i * C8NUM; -#ifdef ENABLE_NEON - float16x8_t relu_data = vld1q_f16(data + index); - float16x8_t zero_data = vdupq_n_f16(0); - relu_data = vmaxq_f16(relu_data, zero_data); - vst1q_f16(dst + index, relu_data); -#else - data[index] = data[index] < 0 ? 0 : data[index]; - data[index + 1] = data[index + 1] < 0 ? 0 : data[index + 1]; - data[index + 2] = data[index + 2] < 0 ? 0 : data[index + 2]; - data[index + 3] = data[index + 3] < 0 ? 0 : data[index + 3]; -#endif - } - for (int j = (eight_block - 1) * C8NUM; j < ele_num; ++j) { - data[j] = data[j] < 0 ? 0 : data[j]; - } -} - -void Relu6Fp16(float16_t *data, float16_t *dst, int ele_num) { - int eight_block = UP_DIV(ele_num, C8NUM); - for (int i = 0; i < eight_block - 1; i++) { - int index = i * C8NUM; -#ifdef ENABLE_NEON - float16x8_t relu6_data = vld1q_f16(data + index); - float16x8_t zero_data = vdupq_n_f16(0); - float16x8_t six_data = vdupq_n_f16(6); - relu6_data = vmaxq_f16(relu6_data, zero_data); - relu6_data = vminq_f16(relu6_data, six_data); - vst1q_f16(dst + index, relu6_data); -#else - for (int j = 0; j < C8NUM; ++j) { - data[index + j] = data[index + j] < 0 ? 0 : data[index + j]; - data[index + j] = data[index + j] > 6 ? 6 : data[index + j]; - } -#endif - } - for (int j = (eight_block - 1) * C8NUM; j < ele_num; ++j) { - data[j] = data[j] < 0 ? 0 : data[j]; - data[j] = data[j] > 6 ? 6 : data[j]; - } -} diff --git a/mindspore/lite/nnacl/fp16/common_func.h b/mindspore/lite/nnacl/fp16/common_func.h index 2faaec4bb..54e356c7b 100644 --- a/mindspore/lite/nnacl/fp16/common_func.h +++ b/mindspore/lite/nnacl/fp16/common_func.h @@ -41,8 +41,6 @@ void DeconvDwFp16Center(float16_t *dst, const float16_t *src, const float16_t *w size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step); #endif -void ReluFp16(float16_t *data, float16_t *dst, int ele_num); -void Relu6Fp16(float16_t *data, float16_t *dst, int ele_num); #ifdef __cplusplus } diff --git a/mindspore/lite/nnacl/fp16/softmax_fp16.c b/mindspore/lite/nnacl/fp16/softmax_fp16.c new file mode 100644 index 000000000..b0df45db6 --- /dev/null +++ b/mindspore/lite/nnacl/fp16/softmax_fp16.c @@ -0,0 +1,67 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp16/softmax_fp16.h" +#include +#include + +// output = exp(input) / reduce_sum(exp(input), axis) +void SoftmaxFp16(const float16_t *input_ptr, float16_t *output_ptr, float16_t *sum_data, SoftmaxParameter *parameter) { + int32_t axis = parameter->axis_; + int n_dim = parameter->n_dim_; + int ele_size = parameter->element_size_; + int *input_shape = parameter->input_shape_; + + float16_t max_data = input_ptr[0]; + for (int i = 0; i < ele_size; i++) { + max_data = max_data > input_ptr[i] ? max_data : input_ptr[i]; + } + + for (int i = 0; i < ele_size; i++) { + output_ptr[i] = exp(input_ptr[i] - max_data); + } + int inner_size = 1, outter_size = 1; + for (int i = 0; i < axis; i++) { + outter_size *= input_shape[i]; + } + for (int i = axis + 1; i < n_dim; i++) { + inner_size *= input_shape[i]; + } + + for (int i = 0; i < outter_size; i++) { + int outter_offset = i * input_shape[axis] * inner_size; + int sum_outter_offset = i * inner_size; + for (int k = 0; k < inner_size; k++) { + int inner_offset = outter_offset + k; + for (int j = 0; j < input_shape[axis]; j++) { + int axis_offset = inner_offset + j * inner_size; + sum_data[k + sum_outter_offset] += output_ptr[axis_offset]; + } + } + } + + for (int i = 0; i < outter_size; i++) { + int outter_offset = i * input_shape[axis] * inner_size; + int sum_outter_offset = i * inner_size; + for (int j = 0; j < input_shape[axis]; j++) { + int axis_offset = outter_offset + j * inner_size; + for (int k = 0; k < inner_size; k++) { + int inner_offset = axis_offset + k; + output_ptr[inner_offset] = output_ptr[inner_offset] / sum_data[k + sum_outter_offset]; + } + } + } +} diff --git a/mindspore/lite/nnacl/fp16/softmax_fp16.h b/mindspore/lite/nnacl/fp16/softmax_fp16.h new file mode 100644 index 000000000..7e4127fe0 --- /dev/null +++ b/mindspore/lite/nnacl/fp16/softmax_fp16.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_NNACL_FP16_SOFTMAX_FP16_H_ +#define MINDSPORE_LITE_NNACL_FP16_SOFTMAX_FP16_H_ + +#include "nnacl/op_base.h" +#include "nnacl/softmax_parameter.h" +#ifdef ENABLE_NEON +#include +#endif +#ifdef __cplusplus +extern "C" { +#endif +void SoftmaxFp16(const float16_t *input_ptr, float16_t *output_ptr, float16_t *sum_data, SoftmaxParameter *parameter); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_LITE_NNACL_FP16_SOFTMAX_FP16_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/activation_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/activation_fp16.cc new file mode 100644 index 000000000..b2fbf6e81 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/activation_fp16.cc @@ -0,0 +1,156 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp16/activation_fp16.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/runtime_api.h" +#include "include/errorcode.h" +#include "src/runtime/kernel/arm/fp16/common_fp16.h" +#include "nnacl/fp16/cast_fp16.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::ActivationType_HSWISH; +using mindspore::schema::ActivationType_LEAKY_RELU; +using mindspore::schema::ActivationType_RELU; +using mindspore::schema::ActivationType_RELU6; +using mindspore::schema::PrimitiveType_Activation; + +namespace mindspore::kernel { +int ActivationFp16CPUKernel::Init() { return RET_OK; } + +int ActivationFp16CPUKernel::ReSize() { return RET_OK; } + +int ActivationFp16CPUKernel::MallocTmpBuffer() { + fp16_input_ = ConvertInputFp32toFp16(in_tensors_.at(0), context_); + if (fp16_input_ == nullptr) { + MS_LOG(ERROR) << "malloc data failed"; + return RET_ERROR; + } + fp16_output_ = MallocOutputFp16(out_tensors_.at(0), context_); + if (fp16_output_ == nullptr) { + MS_LOG(ERROR) << "malloc data failed"; + return RET_ERROR; + } + return RET_OK; +} + +void ActivationFp16CPUKernel::FreeTmpBuffer() { + if (in_tensors_.at(0)->data_type() == kNumberTypeFloat32) { + if (fp16_input_ != nullptr) { + context_->allocator->Free(fp16_input_); + fp16_input_ = nullptr; + } + } + if (out_tensors_.at(0)->data_type() == kNumberTypeFloat32) { + if (fp16_output_ != nullptr) { + context_->allocator->Free(fp16_output_); + fp16_output_ = nullptr; + } + } +} + +int ActivationFp16CPUKernel::DoActivation(int task_id) { + auto length = in_tensors_.at(0)->ElementsNum(); + + int stride = UP_DIV(length, thread_count_); + int count = MSMIN(stride, length - stride * task_id); + + int error_code; + if (type_ == schema::ActivationType_RELU) { + error_code = ReluFp16(fp16_input_ + stride * task_id, fp16_output_ + stride * task_id, count); + } else if (type_ == schema::ActivationType_RELU6) { + error_code = Relu6Fp16(fp16_input_ + stride * task_id, fp16_output_ + stride * task_id, count); + } else if (type_ == schema::ActivationType_LEAKY_RELU) { + error_code = LReluFp16(fp16_input_ + stride * task_id, fp16_output_ + stride * task_id, count, alpha_); + } else if (type_ == schema::ActivationType_SIGMOID) { + error_code = SigmoidFp16(fp16_input_ + stride * task_id, fp16_output_ + stride * task_id, count); + } else if (type_ == schema::ActivationType_TANH) { + error_code = TanhFp16(fp16_input_ + stride * task_id, fp16_output_ + stride * task_id, count); + } else if (type_ == schema::ActivationType_HSWISH) { + error_code = HSwishFp16(fp16_input_ + stride * task_id, fp16_output_ + stride * task_id, count); + } else { + MS_LOG(ERROR) << "Activation fp16 not support type: " << type_; + return RET_ERROR; + } + return error_code; +} + +int ActivationRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto activation_kernel = reinterpret_cast(cdata); + auto error_code = activation_kernel->DoActivation(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "ActivationRun error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int ActivationFp16CPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + return ret; + } + + ret = MallocTmpBuffer(); + if (ret != RET_OK) { + FreeTmpBuffer(); + return ret; + } + + int error_code = LiteBackendParallelLaunch(ActivationRun, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Activation function error error_code[" << error_code << "]"; + FreeTmpBuffer(); + return RET_ERROR; + } + + auto out_tensor = out_tensors_.at(0); + if (out_tensor->data_type() == kNumberTypeFloat32) { + Float16ToFloat32(fp16_output_, reinterpret_cast(out_tensor->Data()), out_tensor->ElementsNum()); + } + FreeTmpBuffer(); + return RET_OK; +} + +kernel::LiteKernel *CpuActivationFp16KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc, + const mindspore::lite::PrimitiveC *primitive) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Activation); + auto *kernel = new (std::nothrow) ActivationFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive); + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Activation, CpuActivationFp16KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/activation_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/activation_fp16.h new file mode 100644 index 000000000..8cdfe18ef --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/activation_fp16.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_ACTIVATION_FP16_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_ACTIVATION_FP16_H_ + +#include +#include "src/lite_kernel.h" +#include "nnacl/fp16/activation_fp16.h" + +namespace mindspore::kernel { +class ActivationFp16CPUKernel : public LiteKernel { + public: + ActivationFp16CPUKernel(OpParameter *param, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx, + const mindspore::lite::PrimitiveC *primitive) + : LiteKernel(param, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) { + type_ = (reinterpret_cast(param))->type_; + alpha_ = (float16_t)((reinterpret_cast(param))->alpha_); + } + ~ActivationFp16CPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + int DoActivation(int task_id); + int MallocTmpBuffer(); + void FreeTmpBuffer(); + + private: + int thread_count_; + int type_; + float16_t alpha_; + float16_t *fp16_input_ = nullptr; + float16_t *fp16_output_ = nullptr; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_ACTIVATION_FP16_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/softmax_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/softmax_fp16.cc new file mode 100644 index 000000000..edfe40b32 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/softmax_fp16.cc @@ -0,0 +1,156 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "src/runtime/kernel/arm/fp16/softmax_fp16.h" +#include "src/runtime/kernel/arm/fp16/common_fp16.h" +#include "nnacl/fp16/softmax_fp16.h" +#include "nnacl/fp16/cast_fp16.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_SoftMax; + +namespace mindspore::kernel { +int SoftmaxFp16CPUKernel::Init() { + auto ret = SoftmaxBaseCPUKernel::Init(); + if (ret != RET_OK) { + return ret; + } + + if (!InferShapeDone()) { + return RET_OK; + } + return ReSize(); +} + +int SoftmaxFp16CPUKernel::ReSize() { + return SoftmaxBaseCPUKernel::ReSize(); +} + +int SoftmaxFp16CPUKernel::MallocTmpBuffer() { + auto n_dim = softmax_param_->n_dim_; + auto axis = softmax_param_->axis_; + if (axis == -1) { + softmax_param_->axis_ += n_dim; + axis = softmax_param_->axis_; + } + auto in_shape = in_tensors_.front()->shape(); + int out_plane_size = 1; + for (int i = 0; i < axis; ++i) { + out_plane_size *= in_shape[i]; + } + int in_plane_size = 1; + for (int i = axis + 1; i < n_dim; i++) { + in_plane_size *= in_shape[i]; + } + + sum_data_ = + reinterpret_cast(context_->allocator->Malloc(out_plane_size * in_plane_size * sizeof(float16_t))); + if (sum_data_ == nullptr) { + MS_LOG(ERROR) << "malloc data for softmax fail!"; + return RET_ERROR; + } + memset(sum_data_, 0, out_plane_size * in_plane_size * sizeof(float16_t)); + + input_fp16_ = ConvertInputFp32toFp16(in_tensors_.at(kInputIndex), context_); + if (input_fp16_ == nullptr) { + MS_LOG(ERROR) << "malloc data failed"; + return RET_ERROR; + } + output_fp16_ = MallocOutputFp16(out_tensors_.at(kOutputIndex), context_); + if (output_fp16_ == nullptr) { + MS_LOG(ERROR) << "malloc data failed"; + return RET_ERROR; + } + return RET_OK; +} + +void SoftmaxFp16CPUKernel::FreeTmpBuffer() { + if (sum_data_ != nullptr) { + context_->allocator->Free(sum_data_); + sum_data_ = nullptr; + } + if (in_tensors_.at(kInputIndex)->data_type() == kNumberTypeFloat32) { + if (input_fp16_ != nullptr) { + context_->allocator->Free(input_fp16_); + input_fp16_ = nullptr; + } + } + + if (out_tensors_.at(kOutputIndex)->data_type() == kNumberTypeFloat32) { + if (output_fp16_ != nullptr) { + context_->allocator->Free(output_fp16_); + output_fp16_ = nullptr; + } + } +} + +int SoftmaxFp16CPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare fail!ret: " << ret; + return RET_ERROR; + } + ret = MallocTmpBuffer(); + if (ret != RET_OK) { + FreeTmpBuffer(); + MS_LOG(ERROR) << "MallocTmpBuffer failed"; + return RET_ERROR; + } + SoftmaxFp16(input_fp16_, output_fp16_, sum_data_, softmax_param_); + auto out_tensor = out_tensors_.at(kOutputIndex); + if (out_tensor->data_type() == kNumberTypeFloat32) { + Float16ToFloat32(output_fp16_, reinterpret_cast(out_tensor->Data()), out_tensor->ElementsNum()); + } + FreeTmpBuffer(); + return RET_OK; +} + +kernel::LiteKernel *CpuSoftmaxFp16KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc, + const mindspore::lite::PrimitiveC *primitive) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_SoftMax); + auto *kernel = new (std::nothrow) SoftmaxFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new SoftmaxFp16CPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_SoftMax, CpuSoftmaxFp16KernelCreator) + +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/softmax_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/softmax_fp16.h new file mode 100644 index 000000000..669a595c2 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/softmax_fp16.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_SOFTMAX_FP16_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_SOFTMAX_FP16_H_ + +#include +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/base/softmax_base.h" + +namespace mindspore::kernel { +class SoftmaxFp16CPUKernel : public SoftmaxBaseCPUKernel { + public: + SoftmaxFp16CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx, + const mindspore::lite::PrimitiveC *primitive) + : SoftmaxBaseCPUKernel(parameter, inputs, outputs, ctx, primitive), sum_data_(nullptr) {} + ~SoftmaxFp16CPUKernel() = default; + + int Init() override; + int ReSize() override; + int Run() override; + int MallocTmpBuffer(); + void FreeTmpBuffer(); + + private: + float16_t *sum_data_ = nullptr; + float16_t *input_fp16_ = nullptr; + float16_t *output_fp16_ = nullptr; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_SOFTMAX_FP16_H_ -- GitLab