conv_kernel.cpp 5.3 KB
Newer Older
Z
zhaojiaying01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
朔-望's avatar
朔-望 已提交
14

L
liuruilong 已提交
15 16
#ifdef CONV_OP

朔-望's avatar
朔-望 已提交
17
#include "operators/kernel/conv_kernel.h"
18
#include "operators/kernel/central-arm-func/conv_arm_func.h"
朔-望's avatar
朔-望 已提交
19 20

namespace paddle_mobile {
朔-望's avatar
朔-望 已提交
21 22
namespace operators {

L
liuruilong 已提交
23
template <>
N
nhzlx 已提交
24
bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
25 26
  bool conv3x3 = param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
                 param->Filter()->dims()[2] == 3;
27 28
  bool conv5x5 = param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
                 param->Filter()->dims()[2] == 5;
29 30
  bool depth3x3 = conv3x3 && param->Groups() == param->Input()->dims()[1] &&
                  param->Input()->dims()[1] == param->Output()->dims()[1];
31 32
  bool depth5x5 = conv5x5 && param->Groups() == param->Input()->dims()[1] &&
                  param->Input()->dims()[1] == param->Output()->dims()[1];
H
hjchen2 已提交
33
  if (param->Filter()->type() == typeid(int8_t)) {
34
#ifndef __aarch64__
35
    if (depth3x3 && param->Strides()[0] < 3 &&
36
        param->Strides()[0] == param->Strides()[1]) {
H
hjchen2 已提交
37
      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3_INT8;
38 39 40
    } else if (depth5x5 && param->Strides()[0] < 2 &&
               param->Strides()[0] == param->Strides()[1]) {
      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE5x5_INT8;
H
hjchen2 已提交
41
    } else {
42
#endif  // __aarch64__
H
hjchen2 已提交
43
      param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_INT8;
44
#ifndef __aarch64__
H
hjchen2 已提交
45
    }
46
#endif  // __aarch64__
H
hjchen2 已提交
47
  } else {
48 49 50
    if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
        param->Strides()[0] == 1 && param->Paddings()[0] == 1 &&
        param->Paddings()[0] == param->Paddings()[1]) {
H
hjchen2 已提交
51
      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT;
52 53 54 55 56 57 58 59
    } else if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
               param->Strides()[0] == 2 && param->Paddings()[0] == 0 &&
               param->Paddings()[0] == param->Paddings()[1]) {
      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT;
    } else if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
               param->Strides()[0] == 2 && param->Paddings()[0] == 1 &&
               param->Paddings()[0] == param->Paddings()[1]) {
      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT;
60
#ifndef __aarch64__
61 62
    } else if (depth5x5 && param->Strides()[0] == param->Strides()[1] &&
               param->Strides()[0] == 1) {
63
      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT;
64
    } else if (conv3x3 && param->Strides()[0] == param->Strides()[1] &&
H
hjchen2 已提交
65
               param->Dilations()[0] == param->Dilations()[1] &&
66 67
               param->Strides()[0] == 1 && param->Dilations()[0] == 1 &&
               param->Output()->dims()[1] >= 16 &&
68 69
               param->Input()->dims()[1] >= 16 &&
               param->Input()->dims()[2] <= 140 /* refered from ncnn */) {
H
hjchen2 已提交
70 71
      param->ExecMode() = ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT;
      // transform weight
H
hjchen2 已提交
72 73 74
      param->transformed_filter_ = new framework::Tensor;
      operators::math::winograd_transform_weight<8, 3>(
          *param->Filter(), param->transformed_filter_);
H
hjchen2 已提交
75
#endif
H
hjchen2 已提交
76 77 78 79
    } else {
      param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_FLOAT;
    }
  }
L
liuruilong 已提交
80 81 82
  return true;
}

朔-望's avatar
朔-望 已提交
83
template <>
L
liuruilong 已提交
84
void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
H
hjchen2 已提交
85 86 87
  switch (param.ExecMode()) {
    case ConvParam<CPU>::EXEC_GEMM_INT8:
      GemmConv<int8_t, int32_t>(param);
H
hjchen2 已提交
88
      break;
89
#ifndef __aarch64__
H
hjchen2 已提交
90 91
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3_INT8:
      DepthwiseConv3x3<int8_t, int32_t>(param);
H
hjchen2 已提交
92
      break;
93 94 95 96
    case ConvParam<CPU>::EXEC_DEPTHWISE5x5_INT8:
      DepthwiseConv5x5<int8_t, int32_t>(param);
      break;
#endif  // __aarch64__
H
hjchen2 已提交
97 98
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
      math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
99
                                 nullptr, false, false);
H
hjchen2 已提交
100
      break;
101 102
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
      math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(),
103
                                   param.Output(), nullptr, false, false);
104 105 106
      break;
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
      math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
107
                                 nullptr, false, false);
H
hjchen2 已提交
108
      break;
109 110 111
#ifndef __aarch64__
    case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
      DepthwiseConv5x5<float, float>(param);
112
      break;
H
hjchen2 已提交
113 114 115
    case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
      WinogradConv3x3<8, 3>(param);
      break;
116
#endif  // __aarch64__
H
hjchen2 已提交
117 118 119 120 121 122 123
    case ConvParam<CPU>::EXEC_GEMM_FLOAT:
      GemmConv<float, float>(param);
      break;
    default:
      PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
                                    param.ExecMode());
  }
朔-望's avatar
朔-望 已提交
124 125
}

126
// Explicit instantiation of the CPU/float convolution kernel in this
// translation unit.
template class ConvKernel<CPU, float>;
朔-望's avatar
朔-望 已提交
127

朔-望's avatar
朔-望 已提交
128 129
}  // namespace operators
}  // namespace paddle_mobile
L
liuruilong 已提交
130 131

#endif