提交 10015ad9 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4371 Add fp16 pooling

Merge pull request !4371 from fuzhiye/tmp
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp16/cast_fp16.h"
#include <vector>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "src/runtime/kernel/arm/nnacl/fp16/cast_fp16.h"
#include "src/runtime/kernel/arm/nnacl/op_base.h"
#include "src/runtime/runtime_api.h"
#include "include/errorcode.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Cast;
namespace mindspore::kernel {
namespace {
int CastRun(int thread_id, LiteParallelGroupEnv *penv, void *cdata) {
if (cdata == nullptr) {
MS_LOG(ERROR) << "input cdata is nullptr!";
return RET_ERROR;
}
return reinterpret_cast<CastFp16CPUKernel *>(cdata)->DoCast(thread_id);
}
} // namespace
int CastFp16CPUKernel::Init() {
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int CastFp16CPUKernel::ReSize() {
data_num_ = in_tensors_[0]->ElementsNum();
if (data_num_ == 0) {
return RET_OK;
}
op_parameter_->thread_num_ = MSMIN(op_parameter_->thread_num_, data_num_);
stride_ = UP_DIV(data_num_, op_parameter_->thread_num_);
return RET_OK;
}
int CastFp16CPUKernel::DoCast(int thread_id) {
auto input = in_tensors_.at(0);
int data_num = MSMIN(stride_, data_num_ - thread_id * stride_);
if (data_num <= 0) {
return RET_OK;
}
auto offset = thread_id * stride_;
auto output_data = out_tensors_.at(0)->Data();
switch (input->data_type()) {
case kNumberTypeFloat32:
Float32ToFloat16(reinterpret_cast<float *>(input->Data()) + offset,
reinterpret_cast<float16_t *>(output_data) + offset, data_num);
break;
case kNumberTypeFloat16:
Float16ToFloat32(reinterpret_cast<float16_t *>(input->Data()) + offset,
reinterpret_cast<float *>(output_data) + offset, data_num);
break;
default:
MS_LOG(ERROR) << "Unsupport input data type " << input->data_type();
return RET_ERROR;
}
return RET_OK;
}
int CastFp16CPUKernel::Run() {
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
if (data_num_ == 0) {
return RET_OK;
}
return LiteBackendParallelLaunch(CastRun, this, op_parameter_->thread_num_);
}
kernel::LiteKernel *CpuCastFp16KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc, const lite::Primitive *primitive) {
if (opParameter == nullptr) {
MS_LOG(ERROR) << "Input opParameter is nullptr!";
return nullptr;
}
if (ctx == nullptr) {
MS_LOG(ERROR) << "Input context is nullptr!";
return nullptr;
}
if (ctx->thread_num_ == 0) {
MS_LOG(ERROR) << "context thread num is 0!";
return nullptr;
}
auto *kernel = new (std::nothrow) CastFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new CastFp16CPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Cast, CpuCastFp16KernelCreator)
} // namespace mindspore::kernel
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CAST_FP16_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CAST_FP16_H_
#include <vector>
#include "src/lite_kernel.h"
namespace mindspore::kernel {
class CastFp16CPUKernel : public LiteKernel {
public:
CastFp16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
const lite::Primitive *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
~CastFp16CPUKernel() = default;
int Init() override;
int ReSize() override;
int Run() override;
int DoCast(int thread_id);
private:
uint32_t stride_;
uint32_t data_num_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CAST_FP16_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp16/pooling_fp16.h"
#include <vector>
#include "src/runtime/kernel/arm/nnacl/fp16/pooling_fp16.h"
#include "src/kernel_registry.h"
#include "src/runtime/runtime_api.h"
#include "include/errorcode.h"
#include "src/runtime/kernel/arm/nnacl/op_base.h"
#include "nnacl/fp16/cast_fp16.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Pooling;
namespace mindspore::kernel {
int PoolingFp16CPUKernel::InitBuffer() {
int in_batch = pooling_param_->input_batch_;
int in_h = pooling_param_->input_h_;
int in_w = pooling_param_->input_w_;
int in_channel = pooling_param_->input_channel_;
fp16_input_ = reinterpret_cast<float16_t *>(malloc(in_batch * in_h * in_w * in_channel * sizeof(float16_t)));
if (fp16_input_ == nullptr) {
MS_LOG(ERROR) << "malloc fp16_input_ failed.";
return RET_ERROR;
}
int out_batch = pooling_param_->output_batch_;
int out_h = pooling_param_->output_h_;
int out_w = pooling_param_->output_w_;
int out_channel = pooling_param_->output_channel_;
fp16_output_ = reinterpret_cast<float16_t *>(malloc(out_batch * out_h * out_w * out_channel * sizeof(float16_t)));
if (fp16_output_ == nullptr) {
MS_LOG(ERROR) << "fp16_out malloc failed.";
return RET_ERROR;
}
return RET_OK;
}
int PoolingFp16CPUKernel::Init() {
if (context_->infer_shape_interrupt_ && !context_->running_) {
set_need_reinit();
return RET_OK;
}
auto ret = PoolingBaseCPUKernel::Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "PoolingBase Init failed.";
return ret;
}
ret = InitBuffer();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init Buffer failed.";
return ret;
}
return RET_OK;
}
int PoolingFp16CPUKernel::ReSize() {
auto ret = Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Pooling resize init failed.";
return RET_ERROR;
}
return RET_OK;
}
int PoolingFp16CPUKernel::RunImpl(int task_id) {
if (pooling_param_->max_pooling_) {
MaxPoolingFp16(fp16_input_, fp16_output_, pooling_param_, task_id);
} else {
AvgPoolingFp16(fp16_input_, fp16_output_, pooling_param_, task_id);
}
return RET_OK;
}
int PoolingFp16Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
auto pooling = reinterpret_cast<PoolingFp16CPUKernel *>(cdata);
auto error_code = pooling->RunImpl(task_id);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "Pooling Run error task_id[" << task_id << "] error_code[" << error_code << "]";
return RET_ERROR;
}
return RET_OK;
}
int PoolingFp16CPUKernel::Run() {
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
auto ele_num = in_tensors_.front()->ElementsNum();
auto input_ptr = reinterpret_cast<float *>(in_tensors_.at(kInputIndex)->Data());
Float32ToFloat16(input_ptr, fp16_input_, ele_num);
int error_code = LiteBackendParallelLaunch(PoolingFp16Impl, this, thread_count_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "pooling error error_code[" << error_code << "]";
return RET_ERROR;
}
auto out_ele_num = out_tensors_.front()->ElementsNum();
auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->Data());
Float16ToFloat32(fp16_output_, output_ptr, out_ele_num);
return RET_OK;
}
kernel::LiteKernel *CpuPoolingFp16KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const Context *ctx,
const kernel::KernelKey &desc, const lite::Primitive *primitive) {
if (opParameter == nullptr) {
MS_LOG(ERROR) << "Input opParameter is nullptr!";
return nullptr;
}
MS_ASSERT(desc.type == schema::PrimitiveType_Pooling);
auto *kernel = new (std::nothrow) PoolingFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new PoolingCPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Pooling, CpuPoolingFp16KernelCreator)
} // namespace mindspore::kernel
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_POOLING_FP16_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_POOLING_FP16_H_
#include <arm_neon.h>
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/base/pooling_base.h"
namespace mindspore::kernel {
class PoolingFp16CPUKernel : public PoolingBaseCPUKernel {
public:
PoolingFp16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
const lite::Primitive *primitive)
: PoolingBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
~PoolingFp16CPUKernel() override {
if (fp16_input_ != nullptr) {
free(fp16_input_);
}
if (fp16_output_ != nullptr) {
free(fp16_output_);
}
};
int Init() override;
int InitBuffer();
int ReSize() override;
int Run() override;
int RunImpl(int task_id);
private:
float16_t *fp16_input_ = nullptr;
float16_t *fp16_output_ = nullptr;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_POOLING_FP16_H_
......@@ -21,15 +21,8 @@
#include "src/runtime/kernel/arm/base/pooling_base.h"
#include "src/lite_kernel.h"
#include "ir/anf.h"
#include "include/context.h"
namespace mindspore::kernel {
using mindspore::lite::Context;
using mindspore::schema::PadMode;
using mindspore::schema::PoolMode;
using mindspore::schema::QuantType;
using mindspore::schema::RoundMode;
class PoolingCPUKernel : public PoolingBaseCPUKernel {
public:
PoolingCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp16/cast_fp16.h"
void Float32ToFloat16(const float *input, float16_t *output, int number) {
for (int i = 0; i < number; ++i) {
output[i] = (float16_t)input[i];
}
}
void Float16ToFloat32(const float16_t *input, float *output, int number) {
for (int i = 0; i < number; ++i) {
output[i] = (float)input[i];
}
}
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_CAST_FP16_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_CAST_FP16_H_
#include <arm_neon.h>
#include "nnacl/op_base.h"
#include "nnacl/fp32/cast.h"
void Float32ToFloat16(const float *input, float16_t *output, int number);
void Float16ToFloat32(const float16_t *input, float *output, int number);
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_CAST_FP16_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp16/pooling_fp16.h"
#include <float.h>
void AvgPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, PoolingParameter *pooling_param, int task_id) {
int stride_w = pooling_param->stride_w_;
int stride_h = pooling_param->stride_h_;
int pad_w = pooling_param->pad_l_;
int pad_h = pooling_param->pad_u_;
int win_w = pooling_param->window_w_;
int win_h = pooling_param->window_h_;
int channel = pooling_param->input_channel_;
int c8 = channel / C8NUM;
int c8_res = channel % C8NUM;
int c4 = c8_res / C4NUM;
int in_w = pooling_param->input_w_;
int in_h = pooling_param->input_h_;
int output_w = pooling_param->output_w_;
int output_h = pooling_param->output_h_;
int output_batch = pooling_param->output_batch_;
int out_plane = output_w * output_h;
int out_tile_count = UP_DIV(out_plane, TILE_NUM);
int thread_num = pooling_param->thread_num_;
// input channel is equal to output channel
for (int batch = 0; batch < output_batch; batch++) {
int in_batch_offset = batch * in_h * in_w * channel;
int out_batch_offset = batch * output_h * output_w * channel;
for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) {
int cal_start_index = thread_id * TILE_NUM;
int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index);
for (int i = 0; i < real_cal_num; i++) {
int index = cal_start_index + i;
int out_w_index = index % output_w;
int out_h_index = index / output_w;
int in_w_index = out_w_index * stride_w - pad_w;
int in_h_index = out_h_index * stride_h - pad_h;
int out_plane_offset = out_batch_offset + index * channel;
for (int j = 0; j < c8; j++) {
int in_channel_offset = in_batch_offset + j * C8NUM;
int out_channel_offset = out_plane_offset + j * C8NUM;
#ifdef ENABLE_NEON
float16x8_t tmp_avg = vdupq_n_f16(0);
#else
float16_t tmp_avg[8]{0};
#endif
int real_count = 0;
for (int h = 0; h < win_h; h++) {
for (int w = 0; w < win_w; w++) {
if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
(in_w_index + w) >= in_w) {
continue;
} else {
int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
#ifdef ENABLE_NEON
tmp_avg = vaddq_f16(tmp_avg, vld1q_f16(input_ptr + in_offset));
#else
for (int t = 0; t < 8; t++) {
tmp_avg[t] += *(input_ptr + in_offset + t);
}
#endif
++real_count;
}
} // win_w loop
} // win_h loop
#ifdef ENABLE_NEON
vst1q_f16(output_ptr + out_channel_offset, tmp_avg / vdupq_n_f16(real_count));
#else
for (int t = 0; t < C8NUM; ++t) {
*(output_ptr + out_channel_offset + t) = tmp_avg[t] / (float16_t)real_count;
}
#endif
} // c8 loop
int c4_offset = c8 * C8NUM;
for (int l = 0; l < c4; ++l) {
int in_channel_offset = in_batch_offset + c4_offset + l * C4NUM;
int out_channel_offset = out_plane_offset + c4_offset + l * C4NUM;
#ifdef ENABLE_NEON
float16x4_t tmp_avg = vdup_n_f16(0);
#else
float16_t tmp_avg[4]{0};
#endif
int real_count = 0;
for (int h = 0; h < win_h; h++) {
for (int w = 0; w < win_w; w++) {
if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
(in_w_index + w) >= in_w) {
continue;
} else {
int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
#ifdef ENABLE_NEON
tmp_avg = vadd_f16(tmp_avg, vld1_f16(input_ptr + in_offset));
#else
for (int j = 0; j < C4NUM; ++j) {
tmp_avg[j] += *(input_ptr + in_offset);
}
#endif
++real_count;
}
} // win_w loop
} // win_h loop
#ifdef ENABLE_NEON
vst1_f16(output_ptr + out_channel_offset, tmp_avg / vdup_n_f16(real_count));
#else
for (int t = 0; t < C4NUM; ++t) {
*(output_ptr + out_channel_offset + t) = tmp_avg[t] / (float16_t)real_count;
}
#endif
} // c4 loop
int channel_s = c8 * C8NUM + c4 * C4NUM;
for (int k = channel_s; k < channel; k++) {
int in_channel_offset = in_batch_offset + k;
int out_channel_offset = out_plane_offset + k;
float16_t tmp_avg = 0;
int real_count = 0;
for (int h = 0; h < win_h; h++) {
for (int w = 0; w < win_w; w++) {
if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
(in_w_index + w) >= in_w) {
continue;
} else {
int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
tmp_avg += *(input_ptr + in_offset);
++real_count;
}
} // win_w loop
} // win_h loop
*(output_ptr + out_channel_offset) = tmp_avg / (float16_t)real_count;
} // channel_res loop
} // real_cal_num loop
} // out_plane loop
} // out_batch loop
}
void MaxPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, PoolingParameter *pooling_param, int task_id) {
int stride_w = pooling_param->stride_w_;
int stride_h = pooling_param->stride_h_;
int pad_w = pooling_param->pad_l_;
int pad_h = pooling_param->pad_u_;
int win_w = pooling_param->window_w_;
int win_h = pooling_param->window_h_;
int channel = pooling_param->input_channel_;
int in_w = pooling_param->input_w_;
int in_h = pooling_param->input_h_;
int output_w = pooling_param->output_w_;
int output_h = pooling_param->output_h_;
int output_batch = pooling_param->output_batch_;
int out_plane = output_w * output_h;
int out_tile_count = UP_DIV(out_plane, TILE_NUM);
int thread_num = pooling_param->thread_num_;
int c8 = channel / C8NUM;
int c8_res = channel % C8NUM;
int c4 = c8_res / C4NUM;
// input channel is equal to output channel
for (int batch = 0; batch < output_batch; batch++) {
int in_batch_offset = batch * in_h * in_w * channel;
int out_batch_offset = batch * output_h * output_w * channel;
for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) {
int cal_start_index = thread_id * TILE_NUM;
int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index);
for (int i = 0; i < real_cal_num; i++) {
int index = cal_start_index + i;
int out_w_index = index % output_w;
int out_h_index = index / output_w;
int in_w_index = out_w_index * stride_w - pad_w;
int in_h_index = out_h_index * stride_h - pad_h;
int out_plane_offset = out_batch_offset + index * channel;
for (int j = 0; j < c8; j++) {
int in_channel_offset = in_batch_offset + j * C8NUM;
int out_channel_offset = out_plane_offset + j * C8NUM;
#ifdef ENABLE_NEON
float16x8_t tmp_max = vdupq_n_f16(-FLT_MAX);
#else
float16_t tmp_max[8]{-FLT_MAX};
#endif
for (int h = 0; h < win_h; h++) {
for (int w = 0; w < win_w; w++) {
if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
(in_w_index + w) >= in_w) {
continue;
} else {
int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
#ifdef ENABLE_NEON
tmp_max = vmaxq_f16(tmp_max, vld1q_f16(input_ptr + in_offset));
#else
for (int k = 0; k < C8NUM; k++) {
tmp_max[k] = fmax(tmp_max[k], *(input_ptr + in_offset + k));
}
#endif
}
} // win_w loop
} // win_h loop
#ifdef ENABLE_NEON
vst1q_f16(output_ptr + out_channel_offset, tmp_max);
#else
for (int l = 0; l < C8NUM; ++l) {
*(output_ptr + out_channel_offset + l) = tmp_max[l];
}
#endif
} // c8 loop
int c4_offset = c8 * C8NUM;
for (int j = 0; j < c4; j++) {
int in_channel_offset = in_batch_offset + c4_offset + j * C4NUM;
int out_channel_offset = out_plane_offset + c4_offset + j * C4NUM;
#ifdef ENABLE_NEON
float16x4_t tmp_max = vdup_n_f16(-FLT_MAX);
#else
float16_t tmp_max[4]{-FLT_MAX};
#endif
for (int h = 0; h < win_h; h++) {
for (int w = 0; w < win_w; w++) {
if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
(in_w_index + w) >= in_w) {
continue;
} else {
int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
#ifdef ENABLE_NEON
tmp_max = vmax_f16(tmp_max, vld1_f16(input_ptr + in_offset));
#else
for (int k = 0; k < C4NUM; k++) {
tmp_max[k] = fmax(tmp_max[k], *(input_ptr + in_offset + k));
}
#endif
}
} // win_w loop
} // win_h loop
#ifdef ENABLE_NEON
vst1_f16(output_ptr + out_channel_offset, tmp_max);
#else
for (int l = 0; l < C4NUM; ++l) {
*(output_ptr + out_channel_offset + l) = tmp_max[l];
}
#endif
} // c4 loop
int channel_s = c8 * C8NUM + c4 * C4NUM;
for (int k = channel_s; k < channel; k++) {
int in_channel_offset = in_batch_offset + k;
int out_channel_offset = out_plane_offset + k;
float16_t tmp_max = -FLT_MAX;
for (int h = 0; h < win_h; h++) {
for (int w = 0; w < win_w; w++) {
if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
(in_w_index + w) >= in_w) {
continue;
} else {
int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
tmp_max = fmax(tmp_max, *(input_ptr + in_offset));
}
} // win_w loop
} // win_h loop
*(output_ptr + out_channel_offset) = tmp_max;
} // channel_res loop
} // real_cal_num loop
} // out_plane loop
} // out_batch loop
}
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_POOLING_FP16_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_POOLING_FP16_H_
#include <arm_neon.h>
#include "nnacl/pooling_parameter.h"
void AvgPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, PoolingParameter *pooling_param, int task_id);
void MaxPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, PoolingParameter *pooling_param, int task_id);
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_POOLING_FP16_H_
......@@ -45,17 +45,3 @@ void Float32ToInt32(const float *input, int32_t *output, int number) {
output[i] = (int32_t)input[i];
}
}
#ifdef ENABLE_FP16
void Float32ToFloat16(const float *input, float16_t *output, int number) {
for (int i = 0; i < number; ++i) {
output[i] = (float16_t)input[i];
}
}
void Float16ToFloat32(const float16_t *input, float *output, int number) {
for (int i = 0; i < number; ++i) {
output[i] = (float)input[i];
}
}
#endif
......@@ -33,9 +33,5 @@ void Uint8ToInt8(const uint8_t *input, int8_t *output, int number);
void Int8ToUint8(const int8_t *input, uint8_t *output, int number);
void Int32ToFloat32(const int32_t *input, float *output, int number);
void Float32ToInt32(const float *input, int32_t *output, int number);
#ifdef ENABLE_FP16
void Float32ToFloat16(const float *input, float16_t *output, int number);
void Float16ToFloat32(const float16_t *input, float *output, int number);
#endif
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_CAST_H_
......@@ -21,35 +21,9 @@
#include <arm_neon.h>
#endif
#include "nnacl/op_base.h"
#include "nnacl/pooling_parameter.h"
#include "nnacl/quantization/quantize.h"
typedef struct PoolingParameter {
OpParameter op_parameter_;
QuantArg **quant_args_;
bool global_;
bool max_pooling_;
bool avg_pooling_;
bool round_ceil_;
bool round_floor_;
int window_w_;
int window_h_;
int input_w_;
int input_h_;
int input_batch_;
int input_channel_;
int output_w_;
int output_h_;
int output_batch_;
int output_channel_;
int pad_u_;
int pad_d_;
int pad_l_;
int pad_r_;
int stride_w_;
int stride_h_;
int thread_num_;
} PoolingParameter;
void AvgPooling(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id);
void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id);
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_POOLING_PARAMETER_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_POOLING_PARAMETER_H_
#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"
typedef struct PoolingParameter {
OpParameter op_parameter_;
QuantArg **quant_args_;
bool global_;
bool max_pooling_;
bool avg_pooling_;
bool round_ceil_;
bool round_floor_;
int window_w_;
int window_h_;
int input_w_;
int input_h_;
int input_batch_;
int input_channel_;
int output_w_;
int output_h_;
int output_batch_;
int output_channel_;
int pad_u_;
int pad_d_;
int pad_l_;
int pad_r_;
int stride_w_;
int stride_h_;
int thread_num_;
} PoolingParameter;
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_POOLING_PARAMETER_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册