提交 c3e88184 编写于 作者: L liuqi

Add Conv2D NEON API.

上级 12c719a1
......@@ -108,6 +108,15 @@ class Conv2dFunctor {
const int* dilations_; // [dilation_h, dilation_w]
};
// Declares the explicit specialization of the functor's call operator for
// the NEON device type with float data. Only the declaration lives here; the
// body is supplied out of line.
//
//   input         source tensor (NCHW)
//   input_shape   shape of `input`
//   filter        weights (c_out, c_in, kernel_h, kernel_w)
//   filter_shape  shape of `filter`
//   bias          bias values (c_out)
//   output        destination tensor (NCHW)
//   output_shape  shape of `output`
template <>
void Conv2dFunctor<DeviceType::NEON, float>::operator()(
    const float* input,           // NCHW
    const index_t* input_shape,
    const float* filter,          // c_out, c_in, kernel_h, kernel_w
    const index_t* filter_shape,
    const float* bias,            // c_out
    float* output,                // NCHW
    const index_t* output_shape);
} // namespace kernels
} // namespace mace
......
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include <arm_neon.h>
#include "mace/kernels/conv_2d.h"
#include "mace/kernels/neon/conv_2d_neon_base.h"
namespace mace {
namespace kernels {
// Placeholder for the helper that is meant to produce a padded copy of
// `input` (judging by its name and its call site in this file).
//
// NOTE(review): the body is still empty, so neither `output` nor
// `output_shape` is written by this function yet.
static inline void ConstructInputWithPadding(
    const float* input,
    const index_t* input_shape,
    const int* padding,
    std::unique_ptr<float>& output,
    index_t* output_shape) {
  // Intentionally left without statements: not implemented yet.
}
// NEON float specialization of the 2-D convolution functor.
//
// Looks up a NEON kernel for the current configuration. When none is
// registered, the work is delegated to the generic CPU functor and this
// function returns. Otherwise the input is handed to
// ConstructInputWithPadding and the selected Conv2dNeon kernel is invoked on
// the resulting buffer.
//
//   input         source tensor (NCHW)
//   input_shape   shape of `input`
//   filter        weights (c_out, c_in, kernel_h, kernel_w)
//   filter_shape  shape of `filter`
//   bias          bias values (c_out)
//   output        destination tensor (NCHW)
//   output_shape  shape of `output`
template<>
void Conv2dFunctor<DeviceType::NEON, float>::operator()(const float* input, // NCHW
                                                        const index_t* input_shape,
                                                        const float* filter, // c_out, c_in, kernel_h, kernel_w
                                                        const index_t* filter_shape,
                                                        const float* bias, // c_out
                                                        float* output, // NCHW
                                                        const index_t* output_shape) {
  using Conv2dNeonFunction = void (*)(const float* input,
                                      const index_t* input_shape,
                                      const float* filter,
                                      const float* bias,
                                      float* output,
                                      const index_t* output_shape);

  // selector[size - 1][stride - 1] holds the NEON kernel for that
  // combination, or nullptr when no kernel is available. Storing the function
  // pointers (instead of plain booleans) keeps the support matrix and the
  // dispatch in one place, and it is required because template arguments
  // must be compile-time constants and cannot be taken from paddings_ /
  // strides_ directly.
  //
  // NOTE(review): only Conv2dNeon<3, 3, 1, 1> is defined in this change; the
  // other non-null entries need matching specializations to link.
  static const Conv2dNeonFunction selector[5][4] = {
      {Conv2dNeon<1, 1, 1, 1>, nullptr, nullptr, nullptr},
      {nullptr, nullptr, nullptr, nullptr},
      {Conv2dNeon<3, 3, 1, 1>, Conv2dNeon<3, 3, 2, 2>, nullptr, nullptr},
      {nullptr, nullptr, nullptr, nullptr},
      {Conv2dNeon<5, 5, 1, 1>, nullptr, nullptr, nullptr},
  };

  // NOTE(review): the table is indexed by paddings_[0], whereas Conv2dNeon's
  // template parameters are kernel_h / kernel_w. Confirm whether the kernel
  // size from filter_shape was intended here instead.
  const int size = paddings_[0];
  const int stride = strides_[0];

  // Both bounds are checked before indexing: a value of 0 would otherwise
  // read selector[-1].
  const bool in_table =
      paddings_[0] == paddings_[1] && size >= 1 && size <= 5 &&
      strides_[0] == strides_[1] && stride >= 1 && stride <= 4 &&
      dilations_[0] == 1 && dilations_[1] == 1;

  const Conv2dNeonFunction conv2d_neon =
      in_table ? selector[size - 1][stride - 1] : nullptr;

  // Not implemented for NEON yet: use the reference CPU implementation and
  // stop, so the NEON path below does not run on top of its result.
  if (conv2d_neon == nullptr) {
    Conv2dFunctor<DeviceType::CPU, float>(strides_, paddings_, dilations_)(
        input,
        input_shape,
        filter,
        filter_shape,
        bias,
        output,
        output_shape
    );
    return;
  }

  // Buffer and shape handed to the padding helper and then to the kernel.
  std::unique_ptr<float> padded_input;
  index_t padded_input_shape[4];
  ConstructInputWithPadding(input, input_shape, paddings_, padded_input, padded_input_shape);

  conv2d_neon(
      padded_input.get(),
      padded_input_shape,
      filter,
      bias,
      output,
      output_shape
  );
}
} // namespace kernels
} // namespace mace
\ No newline at end of file
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include <arm_neon.h>
#include "mace/kernels/neon/conv_2d_neon_base.h"
namespace mace {
namespace kernels {
// Explicit specialization of Conv2dNeon for the template arguments
// <3, 3, 1, 1>.
//
// NOTE(review): the body is empty at this point, so `output` is not written.
template <>
void Conv2dNeon<3, 3, 1, 1>(
    const float* input,           // NCHW
    const index_t* input_shape,
    const float* filter,          // c_out, c_in, kernel_h, kernel_w
    const float* bias,            // c_out
    float* output,                // NCHW
    const index_t* output_shape) {
  // Kernel body not implemented yet.
}
} // namespace kernels
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_KERNELS_NEON_CONV_2D_NEON_BASE_H_
#define MACE_KERNELS_NEON_CONV_2D_NEON_BASE_H_
#include "mace/core/common.h"
namespace mace {
namespace kernels {
// Primary template for the NEON convolution kernels, parameterized at
// compile time by kernel height/width and stride height/width. Only declared
// here; concrete kernels are provided as explicit specializations.
//
//   input         source tensor (NCHW)
//   input_shape   shape of `input`
//   filter        weights (c_out, c_in, kernel_h, kernel_w)
//   bias          bias values (c_out)
//   output        destination tensor (NCHW)
//   output_shape  shape of `output`
template <index_t kernel_h,
          index_t kernel_w,
          index_t stride_h,
          index_t stride_w>
inline void Conv2dNeon(
    const float* input,           // NCHW
    const index_t* input_shape,
    const float* filter,          // c_out, c_in, kernel_h, kernel_w
    const float* bias,            // c_out
    float* output,                // NCHW
    const index_t* output_shape);
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_NEON_CONV_2D_NEON_BASE_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册