conv_compute.h 5.8 KB
Newer Older
Y
Yan Chunwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <Eigen/Core>
#include <string>
#include <vector>
19 20 21
#include "lite/backends/x86/math/blas.h"
#include "lite/backends/x86/math/im2col.h"
#include "lite/backends/x86/math/vol2col.h"
Y
Yan Chunwei 已提交
22 23 24
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/core/types.h"
25
#include "lite/fluid/eigen.h"
Y
Yan Chunwei 已提交
26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51
#include "lite/operators/conv_op.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {

// Decides whether the convolution input has to be expanded (im2col / vol2col)
// before it can be multiplied with the filter matrix.
//
// Returns false only for the pointwise case, i.e. when every inspected filter
// extent is 1, every stride is 1, every dilation is 1 and every padding is 0.
// Any other configuration returns true.
//
// filter_dim: filter dimensions; the entries starting at index 2 are compared
//             against 1, one entry per element of `strides`.
// strides:    one entry per spatial dimension; its size drives the loop.
// paddings:   ALL entries are inspected. The vector may hold more entries than
//             `strides` (e.g. separate begin/end values per dimension), so it
//             must not be indexed by the stride count alone; otherwise a
//             non-zero trailing padding would be ignored.
// dilations:  one entry per spatial dimension.
inline bool IsExpand(const std::vector<int64_t>& filter_dim,
                     const std::vector<int>& strides,
                     const std::vector<int>& paddings,
                     const std::vector<int>& dilations) {
  bool filter_1 = true, strides_1 = true, dilation_1 = true;
  for (size_t j = 0; j < strides.size(); ++j) {
    filter_1 = filter_1 && (static_cast<int>(filter_dim[j + 2]) == 1);
    strides_1 = strides_1 && (strides[j] == 1);
    dilation_1 = dilation_1 && (dilations[j] == 1);
  }

  // Walk the padding vector on its own length so that layouts with more than
  // one value per spatial dimension are fully covered.
  bool padding_0 = true;
  for (int pad : paddings) {
    padding_0 = padding_0 && (pad == 0);
  }

  return !(filter_1 && strides_1 && padding_0 && dilation_1);
}

// x86 convolution kernel that computes the output with one matrix
// multiplication per (batch item, group).
//
// For every batch item and every group the input slice is either used
// directly as the column matrix (when IsExpand() reports that no expansion is
// needed) or unfolded into a column buffer with im2col (two spatial
// dimensions) / vol2col (three spatial dimensions). The filter, viewed as a
// 2-D matrix, is then multiplied with that column matrix and the product is
// written into the matching slice of the output tensor.
template <typename T>
class Conv2dCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
 public:
  using param_t = operators::ConvParam;

  // Runs the convolution described by the kernel's ConvParam and writes the
  // result into param.output.
  void Run() override {
    auto& context = ctx_->As<X86Context>();
    auto& param = *param_.get_mutable<operators::ConvParam>();

    // Local Tensor copy: the Resize() further down is applied to this object
    // rather than to *param.filter.
    lite::Tensor filter = *param.filter;
    param.output->template mutable_data<T>();
    const int batch_size = static_cast<int>(param.x->dims()[0]);

    // Bound once here; these vectors do not change while the kernel runs, so
    // there is no need to copy them per (batch, group) iteration.
    const auto& paddings = *param.paddings;
    const auto& dilations = *param.dilations;

    std::vector<int64_t> filter_shape_vec(filter.dims().Vectorize());
    std::vector<int64_t> output_shape_vec(param.output->dims().Vectorize());

    // Number of spatial dimensions: the filter rank minus its two leading
    // dimensions.
    size_t data_dim = filter_shape_vec.size() - 2;

    // Column buffer shape:
    //   [x.dims()[1] / groups, filter spatial dims..., output spatial dims...]
    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
    col_shape_vec[0] = param.x->dims()[1] / param.groups;
    for (size_t j = 0; j < data_dim; ++j) {
      col_shape_vec[j + 1] = filter_shape_vec[j + 2];
      col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
    }
    lite::DDim col_shape(col_shape_vec);
    // 2-D view of the column buffer, split after the first (data_dim + 1)
    // dimensions.
    lite::DDim col_matrix_shape = col_shape.Flatten2D(data_dim + 1);

    bool is_expand =
        IsExpand(filter_shape_vec, param.strides, paddings, dilations);

    // `col` holds the unfolded input, `col_matrix` is the 2-D view of the same
    // data that is handed to the matrix multiplication. The buffer is only
    // allocated when an expansion is actually required.
    lite::Tensor col;
    lite::Tensor col_matrix;
    if (is_expand) {
      col.Resize(col_shape);
      col.mutable_data<T>();
      col_matrix.ShareDataWith(col);
      col_matrix.Resize(col_matrix_shape);
    }

    // Shape of a single batch item (the input dims without dimension 0).
    lite::DDim input_shape = param.x->dims().Slice(1, param.x->dims().size());

    // Filter viewed as [dims()[0], everything else].
    lite::DDim filter_matrix_shape(std::vector<int64_t>{
        filter.dims()[0], filter.dims().production() / filter.dims()[0]});
    filter.Resize(filter_matrix_shape);

    // One batch item of the output viewed as [dims()[1], everything else].
    lite::DDim output_matrix_shape(std::vector<int64_t>{
        param.output->dims()[1],
        param.output->dims().production() /
            (param.output->dims()[0] * param.output->dims()[1])});

    // Size of dimension 1 handled by each group on the input / output side.
    int in_step = static_cast<int>(param.x->dims()[1]) / param.groups;
    int out_step = static_cast<int>(param.output->dims()[1]) / param.groups;

    paddle::lite::x86::math::Vol2ColFunctor<lite::TargetType::kX86, T> vol2col;
    paddle::lite::x86::math::Im2ColFunctor<
        paddle::lite::x86::math::ColFormat::kCFO,
        lite::TargetType::kX86,
        T>
        im2col;
    auto blas =
        paddle::lite::x86::math::GetBlas<lite::TargetType::kX86, T>(context);

    for (int i = 0; i < batch_size; i++) {
      lite::Tensor in_batch = param.x->template Slice<T>(i, i + 1);
      in_batch.Resize(input_shape);
      lite::Tensor out_batch = param.output->template Slice<T>(i, i + 1);
      out_batch.Resize(output_matrix_shape);

      for (int g = 0; g < param.groups; g++) {
        lite::Tensor in_slice =
            in_batch.Slice<T>(static_cast<int64_t>(g * in_step),
                              static_cast<int64_t>((g + 1) * in_step));

        if (!is_expand) {
          // No expansion needed: the input slice itself is used as the
          // column matrix.
          col.ShareDataWith(in_slice);
          col_matrix.ShareDataWith(col);
          col_matrix.Resize(col_matrix_shape);
        } else if (data_dim == 2U) {
          // im2col
          // NOTE(review): only paddings[0] and paddings[2] are forwarded
          // (each one twice), which presumably assumes symmetric padding per
          // spatial dimension - confirm against Im2ColFunctor's expected
          // padding layout.
          im2col(context,
                 in_slice,
                 dilations,
                 param.strides,
                 std::vector<int>{
                     paddings[0], paddings[2], paddings[0], paddings[2]},
                 &(col));
        } else if (data_dim == 3U) {
          // vol2col
          vol2col(context,
                  in_slice,
                  dilations,
                  param.strides,
                  paddings,
                  &(col));
        }
        // NOTE(review): when is_expand is true and data_dim is neither 2 nor
        // 3, `col` is not filled before the multiplication below.

        // gemm: out_slice = filter_slice * col_matrix
        lite::Tensor out_slice =
            out_batch.Slice<T>(static_cast<int64_t>(g * out_step),
                               static_cast<int64_t>((g + 1) * out_step));
        lite::Tensor filter_slice =
            filter.Slice<T>(static_cast<int64_t>(g * out_step),
                            static_cast<int64_t>((g + 1) * out_step));
        blas.MatMul(filter_slice,
                    false,
                    col_matrix,
                    false,
                    T(1.0),
                    &(out_slice),
                    T(0.0));
      }
    }
  }

  virtual ~Conv2dCompute() = default;
};

}  // namespace x86
}  // namespace kernels
}  // namespace lite
}  // namespace paddle