/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef CONV_OP

#include "operators/kernel/conv_kernel.h"
#include "operators/kernel/central-arm-func/conv_arm_func.h"

namespace paddle_mobile {
namespace operators {

// Picks, once at init time, which convolution implementation Compute() will
// dispatch to, based on filter shape/type, strides, paddings and dilations.
// The choice is recorded in param->ExecMode(). Always returns true.
template <>
bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
  // Filter is square 3x3 (assumes filter dims are [out_c, in_c/groups, h, w]
  // — TODO confirm layout against ConvParam).
  bool conv3x3 = param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
                 param->Filter()->dims()[2] == 3;
  // Filter is square 5x5.
  bool conv5x5 = param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
                 param->Filter()->dims()[2] == 5;
  // Depthwise convolution: one group per input channel and the channel count
  // is preserved from input to output.
  bool depth3x3 = conv3x3 && param->Groups() == param->Input()->dims()[1] &&
                  param->Input()->dims()[1] == param->Output()->dims()[1];
  bool depth5x5 = conv5x5 && param->Groups() == param->Input()->dims()[1] &&
                  param->Input()->dims()[1] == param->Output()->dims()[1];
  if (param->Filter()->type() == typeid(int8_t)) {
    // int8 filter: dedicated depthwise 3x3 kernel for equal strides of 1 or 2,
    // otherwise fall back to the generic GEMM-based kernel.
    if (depth3x3 && param->Strides()[0] < 3 &&
        param->Strides()[0] == param->Strides()[1]) {
      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3_INT8;
    } else {
      param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_INT8;
    }
  } else {
    // Float filter: pick the most specialized kernel that matches.
    if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
        param->Strides()[0] == 1 && param->Paddings()[0] == 1 &&
        param->Paddings()[0] == param->Paddings()[1]) {
      // depthwise 3x3, stride 1, padding 1
      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT;
    } else if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
               param->Strides()[0] == 2 && param->Paddings()[0] == 0 &&
               param->Paddings()[0] == param->Paddings()[1]) {
      // depthwise 3x3, stride 2, padding 0
      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT;
    } else if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
               param->Strides()[0] == 2 && param->Paddings()[0] == 1 &&
               param->Paddings()[0] == param->Paddings()[1]) {
      // depthwise 3x3, stride 2, padding 1
      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT;
    } else if (depth5x5 && param->Strides()[0] == param->Strides()[1] &&
               param->Strides()[0] == 1) {
      // depthwise 5x5, stride 1 (any padding)
      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE5x5S1_FLOAT;
#ifndef __aarch64__
      // Winograd fast convolution for unit-stride, undilated 3x3 filters with
      // enough channels to amortize the transforms; the spatial bound of 140
      // is taken from ncnn. This branch is compiled out on aarch64.
    } else if (conv3x3 && param->Strides()[0] == param->Strides()[1] &&
               param->Dilations()[0] == param->Dilations()[1] &&
               param->Strides()[0] == 1 && param->Dilations()[0] == 1 &&
               param->Output()->dims()[1] >= 16 &&
               param->Input()->dims()[1] >= 16 &&
               param->Input()->dims()[2] <= 140 /* refered from ncnn */) {
      param->ExecMode() = ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT;
      // transform weight
      // NOTE(review): raw `new` with no visible delete here; if Init() can
      // run more than once this leaks the previous transformed_filter_ —
      // confirm ownership/cleanup in ConvParam.
      param->transformed_filter_ = new framework::Tensor;
      operators::math::winograd_transform_weight<8, 3>(
          *param->Filter(), param->transformed_filter_);
#endif
    } else {
      // Everything else: generic im2col + GEMM convolution.
      param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_FLOAT;
    }
  }
  return true;
}

template <>
L
liuruilong 已提交
77
void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
H
hjchen2 已提交
78 79 80
  switch (param.ExecMode()) {
    case ConvParam<CPU>::EXEC_GEMM_INT8:
      GemmConv<int8_t, int32_t>(param);
H
hjchen2 已提交
81 82 83
      break;
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3_INT8:
      DepthwiseConv3x3<int8_t, int32_t>(param);
H
hjchen2 已提交
84 85 86 87 88
      break;
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
      math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
                                 nullptr, false);
      break;
89 90 91 92 93 94 95
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
      math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(),
                                   param.Output(), nullptr, false);
      break;
    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
      math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
                                 nullptr, false);
H
hjchen2 已提交
96
      break;
97 98 99 100
    case ConvParam<CPU>::EXEC_DEPTHWISE5x5S1_FLOAT:
      math::DepthwiseConv5x5S1<float, float>(*param.Input(), *param.Filter(),
                                             param.Paddings(), param.Output());
      break;
H
hjchen2 已提交
101 102 103 104 105 106 107 108 109 110
    case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
      WinogradConv3x3<8, 3>(param);
      break;
    case ConvParam<CPU>::EXEC_GEMM_FLOAT:
      GemmConv<float, float>(param);
      break;
    default:
      PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
                                    param.ExecMode());
  }
朔-望's avatar
朔-望 已提交
111 112
}

template class ConvKernel<CPU, float>;

}  // namespace operators
}  // namespace paddle_mobile

#endif