Commit 7e9d1442 authored by liuqi

Change conv2d neon kernel logic.

Parent 23f6c70c
@@ -4,16 +4,42 @@
#include <arm_neon.h>
#include "mace/kernels/conv_2d.h"
#include "mace/kernels/neon/conv_2d_neon_base.h"
#include "mace/kernels/neon/conv_2d_neon_3x3.h"
namespace mace {
namespace kernels {
static inline void ConstructInputWithPadding(const float* input, const index_t* input_shape,
-                                            const int* padding,
-                                            std::unique_ptr<float>& output,
-                                            index_t* output_shape) {
+                                            const int* paddings,
+                                            Tensor& output_tensor,
+                                            std::vector<index_t>& output_shape) {
  index_t batch    = input_shape[0];
  index_t channels = input_shape[1];
  index_t height   = input_shape[2];
  index_t width    = input_shape[3];

  output_shape[0] = batch;
  output_shape[1] = channels;
  output_shape[2] = paddings[0] + height;
  output_shape[3] = paddings[1] + width;

  const index_t output_width = output_shape[3];
  const int padded_top  = paddings[0] / 2;
  const int padded_left = paddings[1] / 2;

  output_tensor.Resize(output_shape);
  float* output_ptr = output_tensor.mutable_data<float>();
  memset(output_ptr, 0, output_tensor.size() * sizeof(float));

  // Skip the top padding rows, then copy each input row into the middle of
  // the corresponding padded row. Separate loop counters are used so every
  // batch and channel is copied, not just the first one.
  output_ptr += padded_top * output_width;
  for (index_t n = 0; n < batch; ++n) {
    for (index_t c = 0; c < channels; ++c) {
      for (index_t h = 0; h < height; ++h) {
        memcpy(output_ptr + padded_left, input, width * sizeof(float));
        input += width;
        output_ptr += output_width;
      }
      // Skip this channel's bottom padding plus the next channel's top padding.
      output_ptr += paddings[0] * output_width;
    }
  }
}
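
For reference, a rough standalone sketch of the same zero-padding layout, using plain heap buffers and hypothetical names (PadNCHW) instead of mace::Tensor; it is an illustration of the copy pattern above, not MACE code:

#include <cstring>
#include <cstdint>
#include <vector>

typedef int64_t index_t;  // stands in for mace's index_t

// Zero-pads an NCHW tensor by pad_h rows and pad_w columns in total,
// split evenly around each image, and returns the padded buffer.
std::vector<float> PadNCHW(const float* input,
                           index_t batch, index_t channels,
                           index_t height, index_t width,
                           int pad_h, int pad_w) {
  const index_t out_h = height + pad_h;
  const index_t out_w = width + pad_w;
  std::vector<float> output(batch * channels * out_h * out_w, 0.0f);

  const int top = pad_h / 2;
  const int left = pad_w / 2;
  for (index_t n = 0; n < batch; ++n) {
    for (index_t c = 0; c < channels; ++c) {
      // Start of this (n, c) image inside the padded buffer, offset by the
      // top padding rows and the left padding columns.
      float* dst = output.data() + ((n * channels + c) * out_h + top) * out_w + left;
      for (index_t h = 0; h < height; ++h) {
        std::memcpy(dst, input, width * sizeof(float));
        input += width;
        dst += out_w;
      }
    }
  }
  return output;
}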
template<>
@@ -25,18 +51,39 @@ void Conv2dFunctor<DeviceType::NEON, float>::operator()(const float* input, // NCHW
float* output, // NCHW
const index_t* output_shape) {
-  static const bool selector[5][4] = {
-      {true, false, false, false},
-      {false, false, false, false},
-      {true, true, false, false},
-      {false, false, false, false},
-      {true, false, false, false},
+  typedef void (*Conv2dNeonFunction)(const float* input, // NCHW
+                                     const index_t* input_shape,
+                                     const float* filter, // c_out, c_in, kernel_h, kernel_w
+                                     const float* bias, // c_out
+                                     float* output, // NCHW
+                                     const index_t* output_shape);
+  static const Conv2dNeonFunction selector[5][2] = {
+      {nullptr, nullptr},
+      {nullptr, nullptr},
+      {Conv2dNeonK3x3S1, nullptr},
+      {nullptr, nullptr},
+      {nullptr, nullptr}
  };
  // not implemented yet
  if (paddings_[0] != paddings_[1] || paddings_[0] > 5 ||
      strides_[0] != strides_[1] || strides_[0] > 4 ||
      dilations_[0] != 1 || dilations_[1] != 1 ||
-     !selector[paddings_[0]-1, strides_[0]-1]) {
+     selector[paddings_[0]-1][strides_[0]-1] == nullptr) {
Conv2dFunctor<DeviceType::CPU, float>(strides_, paddings_, dilations_)(
input,
input_shape,
@@ -47,12 +94,13 @@ void Conv2dFunctor<DeviceType::NEON, float>::operator()(const float* input, // NCHW
        output_shape
    );
    return;  // the CPU fallback handled the op; do not fall through to the NEON path
  }
-  std::unique_ptr<float> padded_input;
-  index_t padded_input_shape[4];
+  Tensor padded_input;
+  std::vector<index_t> padded_input_shape(4);
  ConstructInputWithPadding(input, input_shape, paddings_, padded_input, padded_input_shape);
-  Conv2dNeon<paddings_[0], paddings_[1], strides_[0], strides_[1]>(
-      padded_input.get(),
-      padded_input_shape,
+  auto conv2d_neon_func = selector[paddings_[0] - 1][strides_[0] - 1];
+  conv2d_neon_func(
+      padded_input.data<float>(),
+      padded_input_shape.data(),
filter,
bias,
output,
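The removed call Conv2dNeon<paddings_[0], paddings_[1], strides_[0], strides_[1]>(...) could not work as written, because template arguments must be compile-time constants while paddings_ and strides_ are runtime values; presumably that is why this commit switches to a runtime table of function pointers indexed by paddings_[0]-1 and strides_[0]-1, with nullptr entries falling back to the CPU functor. A small standalone sketch of that dispatch pattern, with hypothetical names (KernelFn, RunConv) rather than MACE's actual types:

#include <cstdio>

// Hypothetical kernel signature standing in for the Conv2dNeonFunction typedef.
typedef void (*KernelFn)(const float* input, float* output);

void Kernel3x3S1(const float* input, float* output) {
  (void)input; (void)output;
  std::printf("running the specialized NEON kernel\n");
}

void GenericFallback(const float* input, float* output) {
  (void)input; (void)output;
  std::printf("falling back to the generic CPU path\n");
}

// Rows are indexed by padding - 1, columns by stride - 1; nullptr means
// "no specialized kernel for this combination".
static const KernelFn kSelector[5][2] = {
    {nullptr, nullptr},
    {nullptr, nullptr},
    {Kernel3x3S1, nullptr},  // the only specialized entry so far
    {nullptr, nullptr},
    {nullptr, nullptr},
};

void RunConv(int padding, int stride, const float* input, float* output) {
  KernelFn fn = nullptr;
  if (padding >= 1 && padding <= 5 && stride >= 1 && stride <= 2) {
    fn = kSelector[padding - 1][stride - 1];
  }
  if (fn == nullptr) {
    GenericFallback(input, output);
    return;  // must return so the missing kernel is never dereferenced
  }
  fn(input, output);
}

int main() {
  float in[9] = {0}, out[9] = {0};
  RunConv(3, 1, in, out);  // dispatches to Kernel3x3S1
  RunConv(2, 2, in, out);  // no table entry, uses the fallback
  return 0;
}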
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_KERNELS_NEON_CONV_2D_NEON_3X3_H_
#define MACE_KERNELS_NEON_CONV_2D_NEON_3X3_H_
#include <arm_neon.h>
#include "mace/kernels/neon/conv_2d_neon_base.h"
#include "mace/core/common.h"
namespace mace {
namespace kernels {
-template<>
-void Conv2dNeon<3, 3, 1, 1>(const float* input, // NCHW
+void Conv2dNeonK3x3S1(const float* input, // NCHW
const index_t* input_shape,
const float* filter, // c_out, c_in, kernel_h, kernel_w
const float* bias, // c_out
@@ -20,3 +21,5 @@ void Conv2dNeon<3, 3, 1, 1>(const float* input, // NCHW
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_NEON_CONV_2D_NEON_3X3_H_
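
This header only declares Conv2dNeonK3x3S1; its body is not shown in the diff. As a rough, hypothetical sketch of what a 3x3 stride-1 NEON kernel of this kind typically looks like (single channel, plain buffers, no bias, hypothetical name Conv3x3S1SingleChannel; not MACE's implementation):

#include <arm_neon.h>
#include <cstddef>

// Assumes `input` is already zero-padded so every output pixel has a full
// 3x3 neighborhood; `in_width` is the padded row stride (out_width + 2).
void Conv3x3S1SingleChannel(const float* input, std::ptrdiff_t in_width,
                            const float* filter,  // 9 weights, row-major
                            float* output, std::ptrdiff_t out_height,
                            std::ptrdiff_t out_width) {
  for (std::ptrdiff_t h = 0; h < out_height; ++h) {
    const float* row0 = input + h * in_width;
    const float* row1 = row0 + in_width;
    const float* row2 = row1 + in_width;
    float* out_row = output + h * out_width;
    std::ptrdiff_t w = 0;
    // Compute 4 output pixels per iteration with NEON multiply-accumulate.
    for (; w + 3 < out_width; w += 4) {
      float32x4_t sum = vdupq_n_f32(0.0f);
      for (int k = 0; k < 3; ++k) {  // filter column offset
        sum = vmlaq_n_f32(sum, vld1q_f32(row0 + w + k), filter[k]);
        sum = vmlaq_n_f32(sum, vld1q_f32(row1 + w + k), filter[3 + k]);
        sum = vmlaq_n_f32(sum, vld1q_f32(row2 + w + k), filter[6 + k]);
      }
      vst1q_f32(out_row + w, sum);
    }
    // Scalar tail for widths that are not a multiple of 4.
    for (; w < out_width; ++w) {
      float sum = 0.0f;
      for (int kh = 0; kh < 3; ++kh)
        for (int kw = 0; kw < 3; ++kw)
          sum += filter[kh * 3 + kw] * input[(h + kh) * in_width + w + kw];
      out_row[w] = sum;
    }
  }
}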
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_KERNELS_NEON_CONV_2D_NEON_BASE_H_
#define MACE_KERNELS_NEON_CONV_2D_NEON_BASE_H_
#include "mace/core/common.h"
namespace mace {
namespace kernels {
template <index_t kernel_h, index_t kernel_w, index_t stride_h, index_t stride_w>
inline void Conv2dNeon(const float* input, // NCHW
const index_t* input_shape,
const float* filter, // c_out, c_in, kernel_h, kernel_w
const float* bias, // c_out
float* output, // NCHW
const index_t* output_shape);
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_NEON_CONV_2D_NEON_BASE_H_
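
conv_2d_neon_base.h keeps the original compile-time dispatch point: Conv2dNeon is a template over kernel and stride sizes, and each supported shape is a separate specialization. A toy sketch of that pattern (hypothetical name Conv2dToy, not MACE code), which also shows why it requires compile-time constants and therefore cannot be driven by runtime paddings_/strides_ values:

#include <cstdint>
#include <cstdio>

typedef int64_t index_t;  // stands in for mace's index_t

// Primary template: declared but not defined, so an unsupported
// (kernel, stride) combination fails at link time.
template <index_t kernel_h, index_t kernel_w, index_t stride_h, index_t stride_w>
void Conv2dToy(const float* input, float* output);

// Specialization for a 3x3 kernel with stride 1.
template <>
void Conv2dToy<3, 3, 1, 1>(const float* input, float* output) {
  (void)input;
  (void)output;
  std::printf("3x3, stride-1 kernel selected at compile time\n");
}

int main() {
  float in[16] = {0}, out[16] = {0};
  // The template arguments must be constants known at compile time; runtime
  // values such as strides_[0] cannot appear here, which is what the runtime
  // selector table of function pointers works around.
  Conv2dToy<3, 3, 1, 1>(in, out);
  return 0;
}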