Unverified commit 96dfa087, authored by Ray Liu, committed by GitHub

Merge pull request #1375 from hjchen2/ocr_ctc

Optimize int8/float 5x5 depthwise conv and 2x2 pooling; add aarch64 macros so compilation succeeds on aarch64
......@@ -64,9 +64,9 @@ void OperatorBase<Dtype>::Run() {
for (const auto key : input_keys) {
auto var_vec_in = inputs_.at(key);
for (int i = 0; i < var_vec_in.size(); ++i) {
auto vari = scope_->FindVar(var_vec_in[i]);
auto vari = this->scope_->FindVar(var_vec_in[i]);
if (vari->IsInitialized()) {
Tensor *tensor = vari->template GetMutable<framework::LoDTensor>();
const Tensor *tensor = vari->template Get<framework::LoDTensor>();
if (tensor) DLOG << type_ << " input- " << key << "=" << *tensor;
}
}
......@@ -76,7 +76,7 @@ void OperatorBase<Dtype>::Run() {
for (int i = 0; i < var_vec_out.size(); ++i) {
auto vari = scope_->FindVar(var_vec_out[i]);
if (vari->IsInitialized()) {
Tensor *tensor = vari->template GetMutable<framework::LoDTensor>();
const Tensor *tensor = vari->template Get<framework::LoDTensor>();
if (tensor) DLOG << type_ << " output- " << key << "=" << *tensor;
}
}
......@@ -97,10 +97,10 @@ void OperatorBase<GPU_CL>::Run() {
auto vari = scope_->FindVar(var_vec_in[i]);
if (vari->IsInitialized()) {
if (type_ == "feed") {
Tensor *tensor = vari->template GetMutable<framework::LoDTensor>();
const Tensor *tensor = vari->template Get<framework::LoDTensor>();
if (tensor) DLOG << type_ << " input- " << key << "=" << *tensor;
} else {
CLImage *cl_image = vari->template GetMutable<framework::CLImage>();
const CLImage *cl_image = vari->template Get<framework::CLImage>();
if (cl_image) {
DLOG << type_ << " input- " << key << "=" << *cl_image;
}
......@@ -114,12 +114,12 @@ void OperatorBase<GPU_CL>::Run() {
auto vari = scope_->FindVar(var_vec_out[i]);
if (vari->IsInitialized()) {
if (type_ == "fetch") {
Tensor *tensor = vari->template GetMutable<framework::LoDTensor>();
const Tensor *tensor = vari->template Get<framework::LoDTensor>();
if (tensor) {
DLOG << type_ << " output- " << key << "=" << *tensor;
}
} else {
CLImage *cl_image = vari->template GetMutable<framework::CLImage>();
const CLImage *cl_image = vari->template Get<framework::CLImage>();
if (cl_image) {
DLOG << type_ << " output- " << key << "=" << *cl_image;
}
......
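A recurring change in the logging paths above is swapping the mutable accessor for the read-only one: the tensors are only printed, so fetching them through `Get` instead of `GetMutable` guarantees the debug output cannot allocate or mutate state. The contrast, as a fragment (both calls appear in the hunks above; `vari` is a `framework::Variable *`):

```cpp
const Tensor *ro = vari->template Get<framework::LoDTensor>();   // read-only view; safe for logging
Tensor *rw = vari->template GetMutable<framework::LoDTensor>();  // may create or resize the tensor
```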
......@@ -14,6 +14,7 @@
#include "io/api_paddle_mobile.h"
#include <vector>
#include "common/enforce.h"
#include "framework/tensor.h"
namespace paddle_mobile {
......
......@@ -12,19 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*
* This file contains the implementation of the inference API with the Anakin engine
* embedded; this API can only support Anakin models.
*/
#pragma once
#include <vector>
#include "io/paddle_inference_api.h"
// from paddle_mobile
#include "common/enforce.h"
#include "common/types.h"
#include "io/paddle_inference_api.h"
#include "io/paddle_mobile.h"
namespace paddle_mobile {
......
......@@ -104,6 +104,8 @@ class PaddlePredictor {
// The common configs for all the predictors.
struct Config {
std::string model_dir; // path to the model directory.
std::string prog_file;
std::string param_file;
};
protected:
......@@ -128,9 +130,8 @@ struct PaddleMobileConfig : public PaddlePredictor::Config {
int batch_size = 1;
bool optimize = true;
bool quantification = false;
bool lod_mode = false;
int thread_num = 1;
std::string prog_file;
std::string param_file;
std::string cl_path;
struct PaddleModelMemoryPack memory_pack;
};
......
......@@ -15,6 +15,9 @@ limitations under the License. */
#include "io/paddle_mobile.h"
#include <utility>
#include "common/common.h"
#ifdef _OPENMP
#include <omp.h>
#endif // _OPENMP
#ifdef PADDLE_MOBILE_CL
#include <CL/cl.h>
#include "framework/cl/cl_tensor.h"
......@@ -33,7 +36,7 @@ void PaddleMobile<Device, T>::SetThreadNum(int num) {
template <typename Device, typename T>
PMStatus PaddleMobile<Device, T>::Load(const std::string &dirname,
bool optimize, bool quantification,
int batch_size, bool loddable) {
int batch_size, bool lod_mode) {
if (loader_.get() == nullptr) {
loader_ = std::make_shared<framework::Loader<Device, T>>();
} else {
......@@ -43,7 +46,7 @@ PMStatus PaddleMobile<Device, T>::Load(const std::string &dirname,
if (executor_.get() == nullptr) {
executor_ = std::make_shared<framework::Executor<Device, T>>(
loader_->Load(dirname, optimize, quantification), batch_size, optimize,
loddable);
lod_mode);
} else {
LOG(kLOG_INFO) << "executor inited";
}
......@@ -55,7 +58,7 @@ template <typename Device, typename T>
PMStatus PaddleMobile<Device, T>::Load(const std::string &model_path,
const std::string &para_path,
bool optimize, bool quantification,
int batch_size, bool loddable) {
int batch_size, bool lod_mode) {
if (loader_.get() == nullptr) {
loader_ = std::make_shared<framework::Loader<Device, T>>();
} else {
......@@ -65,7 +68,7 @@ PMStatus PaddleMobile<Device, T>::Load(const std::string &model_path,
if (executor_.get() == nullptr) {
executor_ = std::make_shared<framework::Executor<Device, T>>(
loader_->Load(model_path, para_path, optimize, quantification),
batch_size, optimize, loddable);
batch_size, optimize, lod_mode);
} else {
LOG(kLOG_INFO) << "executor inited";
}
......@@ -73,11 +76,26 @@ PMStatus PaddleMobile<Device, T>::Load(const std::string &model_path,
return PMSuccess;
}
template <typename Device, typename T>
PMStatus PaddleMobile<Device, T>::Load(const PaddleMobileConfig &config) {
if (!config.model_dir.empty()) {
return this->Load(config.model_dir, config.optimize, config.quantification,
config.batch_size, config.lod_mode);
} else if (!config.prog_file.empty() && !config.param_file.empty()) {
return this->Load(config.prog_file, config.param_file, config.optimize,
config.quantification, config.batch_size,
config.lod_mode);
} else {
LOG(kLOG_ERROR) << "Failed to load inference model";
return PMNotInitialized;
}
}
template <typename Device, typename T>
bool PaddleMobile<Device, T>::LoadCombinedMemory(
size_t model_len, const uint8_t *model_buf, size_t combined_params_len,
uint8_t *combined_params_buf, bool optimize, bool quantification,
int batch_size, bool loddable) {
int batch_size, bool lod_mode) {
if (loader_.get() == nullptr) {
loader_ = std::make_shared<framework::Loader<Device, T>>();
} else {
......@@ -88,7 +106,7 @@ bool PaddleMobile<Device, T>::LoadCombinedMemory(
loader_->LoadCombinedMemory(model_len, model_buf, combined_params_len,
combined_params_buf, optimize,
quantification),
batch_size, optimize, loddable);
batch_size, optimize, lod_mode);
} else {
LOG(kLOG_INFO) << "executor inited";
}
......
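The new `Load(const PaddleMobileConfig &config)` overload dispatches on whether the config names a model directory or a combined program/parameter file pair, which is why `prog_file` and `param_file` were hoisted into the base `Config` struct. A minimal usage sketch, assuming the CPU/float instantiation and hypothetical file paths:

```cpp
#include "io/paddle_mobile.h"

int main() {
  paddle_mobile::PaddleMobileConfig config;
  config.prog_file = "model/__model__";  // hypothetical paths; setting
  config.param_file = "model/params";    // model_dir instead selects the
  config.optimize = true;                // first branch of Load(config)
  config.batch_size = 1;
  config.lod_mode = false;

  paddle_mobile::PaddleMobile<paddle_mobile::CPU, float> engine;
  if (engine.Load(config) != paddle_mobile::PMSuccess) {
    // Neither model_dir nor prog_file/param_file was usable; Load()
    // logs "Failed to load inference model" and returns PMNotInitialized.
    return 1;
  }
  return 0;
}
```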
......@@ -18,15 +18,12 @@ limitations under the License. */
#include <string>
#include <utility>
#include <vector>
#ifdef _OPENMP
#include <omp.h>
#endif // _OPENMP
#include "common/types.h"
#include "framework/executor.h"
#include "framework/load_ops.h"
#include "framework/loader.h"
#include "framework/tensor.h"
#include "io/paddle_inference_api.h"
#ifdef PADDLE_MOBILE_CL
#include "framework/cl/cl_engine.h"
#endif
......@@ -46,10 +43,12 @@ class PaddleMobile {
PMStatus Load(const std::string &dirname, const bool optimize = false,
const bool quantification = false, const int batch_size = 1,
const bool lod = false);
const bool lod_mode = false);
PMStatus Load(const std::string &model_path, const std::string &para_path,
const bool optimize = false, const bool quantification = false,
const int batch_size = 1, const bool lod = false);
const int batch_size = 1, const bool lod_mode = false);
PMStatus Load(const PaddleMobileConfig &config);
PMStatus Predict(const framework::Tensor &input);
PMStatus Predict(const framework::LoDTensor &input);
......@@ -75,7 +74,7 @@ class PaddleMobile {
size_t combined_params_len,
uint8_t *combined_params_buf, bool optimize = false,
bool quantification = false, int batch_size = 1,
bool loddable = false);
bool lod_mode = false);
void SetThreadNum(int count);
void Clear();
......
......@@ -24,15 +24,26 @@ template <>
bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
bool conv3x3 = param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Filter()->dims()[2] == 3;
bool conv5x5 = param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Filter()->dims()[2] == 5;
bool depth3x3 = conv3x3 && param->Groups() == param->Input()->dims()[1] &&
param->Input()->dims()[1] == param->Output()->dims()[1];
bool depth5x5 = conv5x5 && param->Groups() == param->Input()->dims()[1] &&
param->Input()->dims()[1] == param->Output()->dims()[1];
if (param->Filter()->type() == typeid(int8_t)) {
#ifndef __aarch64__
if (depth3x3 && param->Strides()[0] < 3 &&
param->Strides()[0] == param->Strides()[1]) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3_INT8;
} else if (depth5x5 && param->Strides()[0] < 2 &&
param->Strides()[0] == param->Strides()[1]) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE5x5_INT8;
} else {
#endif // __aarch64__
param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_INT8;
#ifndef __aarch64__
}
#endif // __aarch64__
} else {
if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
param->Strides()[0] == 1 && param->Paddings()[0] == 1 &&
......@@ -47,6 +58,9 @@ bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
param->Paddings()[0] == param->Paddings()[1]) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT;
#ifndef __aarch64__
} else if (depth5x5 && param->Strides()[0] == param->Strides()[1] &&
param->Strides()[0] == 1) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT;
} else if (conv3x3 && param->Strides()[0] == param->Strides()[1] &&
param->Dilations()[0] == param->Dilations()[1] &&
param->Strides()[0] == 1 && param->Dilations()[0] == 1 &&
......@@ -72,9 +86,14 @@ void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
case ConvParam<CPU>::EXEC_GEMM_INT8:
GemmConv<int8_t, int32_t>(param);
break;
#ifndef __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE3x3_INT8:
DepthwiseConv3x3<int8_t, int32_t>(param);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE5x5_INT8:
DepthwiseConv5x5<int8_t, int32_t>(param);
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
nullptr, false, false);
......@@ -87,9 +106,14 @@ void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
nullptr, false, false);
break;
#ifndef __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
DepthwiseConv5x5<float, float>(param);
break;
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param);
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_GEMM_FLOAT:
GemmConv<float, float>(param);
break;
......
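The `#ifndef __aarch64__` guards are the "aarch64 macros" from the commit title: the specialized depthwise and Winograd kernels are written in 32-bit NEON assembly, so on aarch64 both `Init()` and `Compute()` must collapse to the portable GEMM path. Note the brace pairing in `Init()`: the `else {` opened under the guard is closed by the guarded `}` after the GEMM assignment, so after preprocessing on aarch64 only the unconditional assignment remains. A reduced, self-contained sketch of the pattern (the enum values match the diff; everything else is illustrative):

```cpp
enum ExecMode { EXEC_GEMM_INT8, EXEC_DEPTHWISE5x5_INT8 };

ExecMode SelectInt8Mode(bool depth5x5, int stride) {
  ExecMode mode;
#ifndef __aarch64__
  if (depth5x5 && stride == 1) {
    mode = EXEC_DEPTHWISE5x5_INT8;  // 32-bit NEON implementation
  } else {
#endif  // __aarch64__
    mode = EXEC_GEMM_INT8;  // portable fallback; the only path on aarch64
#ifndef __aarch64__
  }
#endif  // __aarch64__
  return mode;
}
```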
......@@ -15,7 +15,8 @@ limitations under the License. */
#ifdef POOL_OP
#include "operators/kernel/pool_kernel.h"
#include "../central-arm-func/pool_arm_func.h"
#include "operators/kernel/central-arm-func/pool_arm_func.h"
namespace paddle_mobile {
namespace operators {
......@@ -28,7 +29,8 @@ template <>
void PoolKernel<CPU, float>::Compute(const PoolParam<CPU> &param) {
PoolCompute<float>(param);
}
} // namespace operators
} // namespace paddle_mobile
#endif
#endif // POOL_OP
......@@ -18,6 +18,7 @@ limitations under the License. */
#include <vector>
#include "operators/math/conv_func.h"
#include "operators/math/depthwise_conv3x3.h"
#include "operators/math/depthwise_conv5x5.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/pad.h"
......@@ -160,6 +161,7 @@ inline void WinogradConv3x3(const ConvParam<CPU> &param) {
}
}
#ifndef __aarch64__
template <typename Itype, typename Otype>
inline void DepthwiseConv3x3(const ConvParam<CPU> &param) {
const Tensor *input = param.Input();
......@@ -180,14 +182,34 @@ inline void DepthwiseConv3x3(const ConvParam<CPU> &param) {
math::DepthwiseConv3x3S2<Itype, Otype>(in_batch, *filter, paddings,
&out_batch);
} else {
// math::DepthwiseConv3x3<Itype, Otype>(input_pad, *filter,
// &out_batch);
PADDLE_MOBILE_THROW_EXCEPTION(
"Depthwise conv with generic strides has not been implemented.");
GemmConv<Itype, Otype>(param);
}
}
}
template <typename Itype, typename Otype>
inline void DepthwiseConv5x5(const ConvParam<CPU> &param) {
const Tensor *input = param.Input();
const Tensor *filter = param.Filter();
const std::vector<int> &paddings = param.Paddings();
const std::vector<int> &strides = param.Strides();
const int batch_size = input->dims()[0];
Tensor *output = param.Output();
output->mutable_data<Otype>();
if (strides[0] == 1) {
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1);
Tensor out_batch = output->Slice(i, i + 1);
math::DepthwiseConv5x5S1<Itype, Otype>(in_batch, *filter, paddings,
&out_batch);
}
} else {
GemmConv<Itype, Otype>(param);
}
}
#endif // __aarch64__
} // namespace operators
} // namespace paddle_mobile
......
......@@ -59,12 +59,11 @@ inline void ElementwiseAddCompute(const ElementwiseAddParam<CPU> &param) {
const float *input = input_data + offset;
const float bias = bias_data[j];
float *output = output_data + offset;
int remain = elementwise_num;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
int loop = elementwise_num >> 0x4;
remain = elementwise_num & 0xF;
int remain = elementwise_num & 0xF;
float32x4_t rb = vdupq_n_f32(bias);
for (int k = 0; k < loop; ++k) {
float32x4_t rb = vdupq_n_f32(bias);
float32x4_t r0 = vld1q_f32(input);
float32x4_t r1 = vld1q_f32(input + 4);
float32x4_t r2 = vld1q_f32(input + 8);
......@@ -80,10 +79,46 @@ inline void ElementwiseAddCompute(const ElementwiseAddParam<CPU> &param) {
input += 16;
output += 16;
}
#endif
for (int k = 0; k < remain; ++k) {
if (remain >= 8) {
float32x4_t r0 = vld1q_f32(input);
float32x4_t r1 = vld1q_f32(input + 4);
r0 = vaddq_f32(r0, rb);
r1 = vaddq_f32(r1, rb);
vst1q_f32(output, r0);
vst1q_f32(output + 4, r1);
input += 8;
output += 8;
remain -= 8;
}
if (remain >= 4) {
float32x4_t r0 = vld1q_f32(input);
r0 = vaddq_f32(r0, rb);
vst1q_f32(output, r0);
input += 4;
output += 4;
remain -= 4;
}
if (remain > 0) {
float32x4_t r0 = vld1q_f32(input);
r0 = vaddq_f32(r0, rb);
switch (remain) {
case 1:
vst1q_lane_f32(output, r0, 0);
break;
case 2:
vst1_f32(output, vget_low_f32(r0));
break;
case 3:
vst1_f32(output, vget_low_f32(r0));
vst1q_lane_f32(output, r0, 2);
break;
}
}
#else
for (int k = 0; k < elementwise_num; ++k) {
output[k] = input[k] + bias;
}
#endif // __ARM_NEON__
}
}
}
......
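The rewritten tail handling drains the post-loop remainder (up to 15 elements) with progressively narrower NEON stores instead of a scalar loop, finishing the last one to three elements with lane stores; note that the final block still issues a full four-lane `vld1q_f32`, which assumes the input buffer is readable slightly past its end. A scalar model of the block structure, as a sketch with illustrative names:

```cpp
// Scalar model of the vectorized structure above: blocks of 16 (the main
// loop), then one 8-wide and one 4-wide drain step, then individual lanes
// (vst1q_lane_f32 / vst1_f32 in the NEON version).
void AddBiasBlocked(const float *input, float bias, float *output, int n) {
  int k = 0;
  for (; k + 16 <= n; k += 16)
    for (int j = 0; j < 16; ++j) output[k + j] = input[k + j] + bias;
  if (k + 8 <= n) { for (int j = 0; j < 8; ++j) output[k + j] = input[k + j] + bias; k += 8; }
  if (k + 4 <= n) { for (int j = 0; j < 4; ++j) output[k + j] = input[k + j] + bias; k += 4; }
  for (; k < n; ++k) output[k] = input[k] + bias;
}
```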
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#ifdef POOL_OP
#pragma once
#include <string>
......@@ -54,8 +55,24 @@ void PoolCompute(const PoolParam<CPU> &param) {
} else {
math::Pooling<AVG>()(*input, ksize, strides, paddings, output);
}
} else {
// Others
}
} else if (ksize[0] == 2 && ksize[0] == ksize[1]) {
if (pooling_type == "max" && strides[0] == strides[1]) {
if (strides[0] == 1) {
math::Pooling2x2<MAX, 1>()(*input, paddings, output);
} else if (strides[0] == 2) {
math::Pooling2x2<MAX, 2>()(*input, paddings, output);
} else {
math::Pooling<MAX>()(*input, ksize, strides, paddings, output);
}
} else if (pooling_type == "avg" && strides[0] == strides[1]) {
if (strides[0] == 1) {
math::Pooling2x2<AVG, 1>()(*input, paddings, output);
} else if (strides[0] == 2) {
math::Pooling2x2<AVG, 2>()(*input, paddings, output);
} else {
math::Pooling<AVG>()(*input, ksize, strides, paddings, output);
}
}
} else {
if (pooling_type == "max") {
......
......@@ -253,7 +253,6 @@ void DepthwiseConv3x3s1p1(const framework::Tensor *input,
framework::Tensor *output, framework::Tensor *bias,
bool if_bias, bool if_relu) {
#if __ARM_NEON
const float *bias_data = bias->data<float>();
const int batch_size = static_cast<int>(input->dims()[0]);
const int c = static_cast<int>(input->dims()[1]);
const int h = static_cast<int>(input->dims()[2]);
......@@ -267,6 +266,11 @@ void DepthwiseConv3x3s1p1(const framework::Tensor *input,
const int lb = (h - 1) * w;
const int rb = h * w - 1;
const float *bias_data;
if (if_bias) {
bias_data = bias->data<float>();
}
float32x4_t zero = vdupq_n_f32(0.0);
for (int b = 0; b < batch_size; ++b) {
......@@ -1966,7 +1970,6 @@ void DepthwiseConv3x3s2p0(const framework::Tensor *input,
framework::Tensor *output, framework::Tensor *bias,
bool if_bias, bool if_relu) {
#if __ARM_NEON
const int batch_size = static_cast<int>(input->dims()[0]);
const int input_channel = static_cast<int>(input->dims()[1]);
......@@ -1983,7 +1986,12 @@ void DepthwiseConv3x3s2p0(const framework::Tensor *input,
for (int c = 0; c < input_channel; c++) {
const float *filter_data = filter->data<float>() + c * 9;
const float *input_data = input->data<float>() + c * inhxw;
const float *bias_data = bias->data<float>() + c;
const float *bias_data;
float32x4_t biasv;
if (if_bias) {
bias_data = bias->data<float>() + c;
biasv = vld1q_dup_f32(bias_data);
}
float *output_data = output->data<float>() + c * outhxw;
float w00 = filter_data[0];
float w01 = filter_data[1];
......@@ -1994,7 +2002,6 @@ void DepthwiseConv3x3s2p0(const framework::Tensor *input,
float w20 = filter_data[6];
float w21 = filter_data[7];
float w22 = filter_data[8];
float32x4_t biasv = vld1q_dup_f32(bias_data);
for (int i = 0; i < output_height; i += 1) {
for (int m = 0; m < output_width - 2; m += 3) {
float *output_ptr = output_data + i * output_width + m;
......
This diff is collapsed.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "framework/tensor.h"
#include "operators/math/conv_func.h"
namespace paddle_mobile {
namespace operators {
namespace math {
// TODO(hjchen2): needs to be implemented
// template<typename Itype, typename Otype>
// void DepthwiseConv5x5(const framework::Tensor *input,
// const framework::Tensor *filter,
// const std::vector<int> &strides,
// const std::vector<int> &paddings,
// framework::Tensor *output);
template <typename Itype, typename Otype>
void DepthwiseConv5x5S1(const framework::Tensor &input,
const framework::Tensor &filter,
const std::vector<int> &paddings,
framework::Tensor *output);
template <typename Itype, typename Otype>
void DepthwiseConv5x5S2(const framework::Tensor &input,
const framework::Tensor &filter,
const std::vector<int> &paddings,
framework::Tensor *output);
} // namespace math
} // namespace operators
} // namespace paddle_mobile
This diff is collapsed.
......@@ -3150,9 +3150,11 @@ void Gemm::SgemmWithPRelu(int m, int n, int k, const float *A, int lda,
void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *bias) {
#ifndef __aarch64__
if (m == 1 && bias == nullptr) {
return VectorKernel(m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, relu);
}
#endif // __aarch64__
#ifdef _OPENMP
int max_threads = omp_get_max_threads();
#else
......
......@@ -53,7 +53,7 @@ struct PoolingVal<AVG> {
++count;
return *this;
}
inline float Value() { return (count > 0) ? val / count : 0.f; }
inline float Value() { return (count > 0) ? val * (1.f / count) : 0.f; }
};
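The `Value()` change replaces the division by `count` with a multiplication by its reciprocal, likely so the scalar fallback rounds the same way as the NEON path, which applies the average as a multiply by a precomputed reciprocal (`vPoolPostq_f32<AVG>` and the new `vPoolPost_f32<AVG>` below are plain `vmul`s rather than divisions).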
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
......@@ -67,6 +67,16 @@ inline float32x4_t vPoolInitq_f32<AVG>() {
return vdupq_n_f32(0.f);
}
template <PoolingType P = MAX>
inline float32x2_t vPoolInit_f32() {
return vdup_n_f32(-std::numeric_limits<float>::max());
}
template <>
inline float32x2_t vPoolInit_f32<AVG>() {
return vdup_n_f32(0.f);
}
template <PoolingType P = MAX>
inline float32x4_t vPoolPreq_f32(const float32x4_t &x1, const float32x4_t &x2) {
return vmaxq_f32(x1, x2);
......@@ -78,6 +88,28 @@ inline float32x4_t vPoolPreq_f32<AVG>(const float32x4_t &x1,
return vaddq_f32(x1, x2);
}
template <PoolingType P = MAX>
inline float32x2_t vPoolPre_f32(const float32x2_t &x1, const float32x2_t &x2) {
return vmax_f32(x1, x2);
}
template <>
inline float32x2_t vPoolPre_f32<AVG>(const float32x2_t &x1,
const float32x2_t &x2) {
return vadd_f32(x1, x2);
}
template <PoolingType P = MAX>
inline float32x2_t vpPoolPre_f32(const float32x2_t &x1, const float32x2_t &x2) {
return vpmax_f32(x1, x2);
}
template <>
inline float32x2_t vpPoolPre_f32<AVG>(const float32x2_t &x1,
const float32x2_t &x2) {
return vpadd_f32(x1, x2);
}
template <PoolingType P = MAX>
inline float32x4_t vPoolPostq_f32(const float32x4_t &x,
const float32x4_t &post) {
......@@ -89,6 +121,18 @@ inline float32x4_t vPoolPostq_f32<AVG>(const float32x4_t &x,
const float32x4_t &post) {
return vmulq_f32(x, post);
}
template <PoolingType P = MAX>
inline float32x2_t vPoolPost_f32(const float32x2_t &x,
const float32x2_t &post) {
return x;
}
template <>
inline float32x2_t vPoolPost_f32<AVG>(const float32x2_t &x,
const float32x2_t &post) {
return vmul_f32(x, post);
}
#endif // __ARM_NEON__
template <PoolingType P = MAX>
......
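The new `float32x2_t` (d-register) helpers mirror the existing `float32x4_t` ones so the 2x2 pooling kernels can process a final pair of columns without falling back to scalar code; the pairwise `vpmax_f32`/`vpadd_f32` variants perform the horizontal reduction. A minimal sketch, not from the patch, of how they compose for one 2x2 average-pooled output value (assumes the `PoolingType` enum and the helpers above are in scope):

```cpp
#include <arm_neon.h>

// Average of a 2x2 window whose rows start at row0 and row1.
inline float Pool2x2Avg(const float *row0, const float *row1) {
  float32x2_t acc = vPoolInit_f32<AVG>();        // 0.f for AVG
  acc = vPoolPre_f32<AVG>(acc, vld1_f32(row0));  // accumulate top pair
  acc = vPoolPre_f32<AVG>(acc, vld1_f32(row1));  // accumulate bottom pair
  acc = vpPoolPre_f32<AVG>(acc, acc);            // horizontal add (vpadd_f32)
  float32x2_t post = vdup_n_f32(1.f / 4);        // reciprocal of window size
  return vget_lane_f32(vPoolPost_f32<AVG>(acc, post), 0);
}
```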
This diff is collapsed.
This diff is collapsed.
......@@ -56,6 +56,9 @@ inline int32x4_t vRoundq_f32(const float32x4_t &x) {
template <>
inline int32x4_t vRoundq_f32<ROUND_NEAREST_AWAY_ZERO>(const float32x4_t &x) {
#if __aarch64__
return vcvtaq_s32_f32(x);
#else
float32x4_t plus = vdupq_n_f32(0.5);
float32x4_t minus = vdupq_n_f32(-0.5);
float32x4_t zero = vdupq_n_f32(0);
......@@ -64,10 +67,14 @@ inline int32x4_t vRoundq_f32<ROUND_NEAREST_AWAY_ZERO>(const float32x4_t &x) {
temp = vaddq_f32(x, temp);
int32x4_t ret = vcvtq_s32_f32(temp);
return ret;
#endif
}
template <>
inline int32x4_t vRoundq_f32<ROUND_NEAREST_TO_EVEN>(const float32x4_t &x) {
#if __aarch64__
return vcvtnq_s32_f32(x);
#else
float32x4_t point5 = vdupq_n_f32(0.5);
int32x4_t one = vdupq_n_s32(1);
int32x4_t zero = vdupq_n_s32(0);
......@@ -90,6 +97,7 @@ inline int32x4_t vRoundq_f32<ROUND_NEAREST_TO_EVEN>(const float32x4_t &x) {
smask = vsubq_s32(smask, one);
rnd = vaddq_s32(rnd, smask);
return rnd;
#endif
}
#endif // __ARM_NEON__
......
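On aarch64 the two rounding modes map directly onto single instructions: `vcvtaq_s32_f32` rounds to nearest with ties away from zero and `vcvtnq_s32_f32` rounds to nearest with ties to even, replacing the multi-instruction 32-bit emulation above. Scalar reference equivalents, as a hedged sketch (the second assumes the default `FE_TONEAREST` floating-point environment):

```cpp
#include <cmath>

inline int RoundAwayZero(float x) {  // ROUND_NEAREST_AWAY_ZERO
  return static_cast<int>(x + (x >= 0.f ? 0.5f : -0.5f));
}

inline int RoundToEven(float x) {  // ROUND_NEAREST_TO_EVEN
  return static_cast<int>(std::nearbyint(x));  // half-to-even under FE_TONEAREST
}
```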
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.