提交 89b4b039 编写于 作者: L liuqi

Fix pooling test bug and some typos in conv2d test.

上级 328839bd
...@@ -22,6 +22,13 @@ extern void Conv2dNeonK3x3S1(const float *input, ...@@ -22,6 +22,13 @@ extern void Conv2dNeonK3x3S1(const float *input,
float *output, float *output,
const index_t *output_shape); const index_t *output_shape);
extern void Conv2dNeonK3x3S2(const float *input,
const index_t *input_shape,
const float *filter,
const float *bias,
float *output,
const index_t *output_shape);
extern void Conv2dNeonK5x5S1(const float *input, extern void Conv2dNeonK5x5S1(const float *input,
const index_t *input_shape, const index_t *input_shape,
const float *filter, const float *filter,
...@@ -30,27 +37,25 @@ extern void Conv2dNeonK5x5S1(const float *input, ...@@ -30,27 +37,25 @@ extern void Conv2dNeonK5x5S1(const float *input,
const index_t *output_shape); const index_t *output_shape);
template <> template <>
void Conv2dFunctor<DeviceType::NEON, void Conv2dFunctor<DeviceType::NEON, float>::operator()(const float *input,
float>:: const index_t *input_shape,
operator()(const float *input, // NCHW const float *filter,
const index_t *input_shape, const index_t *filter_shape,
const float *filter, // c_out, c_in, kernel_h, kernel_w const float *bias,
const index_t *filter_shape, float *output,
const float *bias, // c_out const index_t *output_shape) {
float *output, // NCHW
const index_t *output_shape) {
typedef void (*Conv2dNeonFunction)( typedef void (*Conv2dNeonFunction)(
const float *input, // NCHW const float *input,
const index_t *input_shape, const index_t *input_shape,
const float *filter, // c_out, c_in, kernel_h, kernel_w const float *filter,
const float *bias, // c_out const float *bias,
float *output, // NCHW float *output,
const index_t *output_shape); const index_t *output_shape);
// Selection matrix: kernel_size x stride_size // Selection matrix: kernel_size x stride_size
static const Conv2dNeonFunction selector[5][2] = { static const Conv2dNeonFunction selector[5][2] = {
{Conv2dNeonK1x1S1, nullptr}, {Conv2dNeonK1x1S1, nullptr},
{nullptr, nullptr}, {nullptr, nullptr},
{Conv2dNeonK3x3S1, nullptr}, {Conv2dNeonK3x3S1, Conv2dNeonK3x3S2},
{nullptr, nullptr}, {nullptr, nullptr},
{Conv2dNeonK5x5S1, nullptr}}; {Conv2dNeonK5x5S1, nullptr}};
// not implement yet // not implement yet
...@@ -59,7 +64,10 @@ operator()(const float *input, // NCHW ...@@ -59,7 +64,10 @@ operator()(const float *input, // NCHW
if (kernel_h != kernel_w || kernel_h > 5 || strides_[0] != strides_[1] || if (kernel_h != kernel_w || kernel_h > 5 || strides_[0] != strides_[1] ||
strides_[0] > 2 || dilations_[0] != 1 || dilations_[1] != 1 || strides_[0] > 2 || dilations_[0] != 1 || dilations_[1] != 1 ||
selector[kernel_h - 1][strides_[0] - 1] == nullptr) { selector[kernel_h - 1][strides_[0] - 1] == nullptr) {
LOG(WARNING) << "NEON conv2d kernel not implementated, using slow vesion"; LOG(WARNING) << "NEON conv2d kernel with "
<< "filter" << kernel_h << "x" << kernel_w << ","
<< " stride " << strides_[0] << "x" << strides_[1]
<< " is not implemented yet, using slow version";
Conv2dFunctor<DeviceType::CPU, float>(strides_, paddings_, dilations_)( Conv2dFunctor<DeviceType::CPU, float>(strides_, paddings_, dilations_)(
input, input_shape, filter, filter_shape, bias, output, output_shape); input, input_shape, filter, filter_shape, bias, output, output_shape);
return; return;
......
...@@ -155,9 +155,9 @@ TEST_F(PoolingOpTest, MAX_k2x2s2x2) { ...@@ -155,9 +155,9 @@ TEST_F(PoolingOpTest, MAX_k2x2s2x2) {
net.RunOp(DeviceType::NEON); net.RunOp(DeviceType::NEON);
// Check // Check
Tensor expected = CreateTensor<float>({1, 1, 2, 3}, {6, 8, 9, 16, 18, 19}); auto expected = CreateTensor<float>({1, 1, 2, 3}, {6, 8, 9, 16, 18, 19});
ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 0.001); ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
} }
TEST_F(PoolingOpTest, MAX_k3x3s2x2) { TEST_F(PoolingOpTest, MAX_k3x3s2x2) {
...@@ -183,7 +183,7 @@ TEST_F(PoolingOpTest, MAX_k3x3s2x2) { ...@@ -183,7 +183,7 @@ TEST_F(PoolingOpTest, MAX_k3x3s2x2) {
net.RunOp(DeviceType::NEON); net.RunOp(DeviceType::NEON);
// Check // Check
Tensor expected = CreateTensor<float>({1, 1, 2, 3}, {11, 13, 14, 16, 18, 19}); auto expected = CreateTensor<float>({1, 1, 2, 3}, {11, 13, 14, 16, 18, 19});
ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 0.001); ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册