// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cstring>
#include <vector>
#include "lite/backends/x86/jit/helper.h"
#include "lite/backends/x86/jit/kernel_base.h"
#include "lite/backends/x86/jit/kernels.h"
#include "lite/backends/x86/math/blas.h"
#include "lite/backends/x86/parallel.h"
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
#include "lite/operators/fc_op.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {

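// Computes the output shape of the fc op: the first in_num_col_dims
// dimensions are copied from the input, and the last dimension is the
// output width from the weights (minus the 4 columns of padding added by
// fc_fuse_pass when padding_weights is true). For example, in_dims =
// {2, 3, 8}, w_dims = {8, 16} and in_num_col_dims = 2 yield out_dims =
// {2, 3, 16}.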
inline void FCOutputSize(const lite::DDim& in_dims,
                         const lite::DDim& w_dims,
                         std::vector<int64_t>& out_dims,  // NOLINT
                         int in_num_col_dims,
                         bool padding_weights) {
  auto w_dims1 = padding_weights ? w_dims[1] - 4 : w_dims[1];

  out_dims.reserve(static_cast<size_t>(in_num_col_dims + 1));
  for (int i = 0; i < in_num_col_dims; ++i) {
    out_dims.push_back(in_dims[i]);
  }
  out_dims.push_back(w_dims1);
}

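// FCFunctor computes Y[M, N] = X[M, K] * W[K, N], adds the bias B
// (broadcast over rows) when given, and optionally applies relu. When
// padding_weights is true, W is expected to be a (K + 4) x (N + 4) buffer
// pre-padded by fc_fuse_pass, and X and Y are staged through padded
// temporaries around a single GEMM call.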
template <lite::TargetType Target, typename T>
class FCFunctor {
 public:
  void operator()(const lite::X86Context& context,
                  const int M,
                  const int N,
                  const int K,
                  const T* X,
                  const T* W,
                  T* Y,
                  const T* B = nullptr,
                  bool relu = false,
                  bool padding_weights = false) {
    auto blas = lite::x86::math::GetBlas<lite::TargetType::kX86, T>(context);
    T* Y1_data = nullptr;
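    // Y1_data is only set on the padding_weights path below, where it points
    // at a temporary output buffer whose rows are padded to N + 4 elements.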

    auto compute =
        relu
            ? jit::KernelFuncs<jit::VAddReluTuple<T>, fluid::CPUPlace>::Cache()
                  .At(N)
            : jit::KernelFuncs<jit::VAddTuple<T>, fluid::CPUPlace>::Cache().At(
                  N);
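    // Adds the bias vector B (length N) to each output row, optionally fused
    // with relu; reads from the padded buffer Y1 when it exists and writes
    // compact rows into Y.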
    auto parallel_compute = [&](int64_t begin, int64_t end) {
      for (int64_t i = begin; i < end; i++) {
        T* dst = Y + i * N;
        T* src = Y1_data ? Y1_data + i * (N + 4) : dst;
        compute(B, src, dst, N);
      }
    };

    // Because of the overhead of memcpy, we only take the padded GEMM path
    // when the weights have already been padded by fc_fuse_pass.
    if (padding_weights) {
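      // The weights are assumed to be pre-padded from K x N to
      // (K + 4) x (N + 4) by fc_fuse_pass; pad each input row to match
      // before running a single GEMM over the padded buffers.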
      const int NN = N + 4;
      const int KK = K + 4;

      // NOTE: mutable_data() must be called here to allocate the temporary
      // tensors X1 and Y1; the overhead of the allocation is unmeasured.
      Tensor X1;
      X1.Resize(std::vector<int64_t>({M * KK}));
      T* X1_data = X1.mutable_data<T>();

      Tensor Y1;
      Y1.Resize(std::vector<int64_t>({M * NN}));
      Y1_data = Y1.mutable_data<T>();

      auto parallel_memcpy_x = [&](int64_t begin, int64_t end) {
        for (int64_t i = begin; i < end; i++) {
          memcpy(X1_data + i * KK, X + i * K, K * sizeof(T));
        }
      };
      lite::x86::RunParallelFor(0, M, parallel_memcpy_x);

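      // Y1 = X1 * W: one GEMM over the padded buffers, using the padded
      // leading dimensions KK and NN.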
      blas.GEMM(false,
                false,
                M,
                N,
                K,
                static_cast<T>(1.0),
                X1_data,
                KK,
                W,
                NN,
                static_cast<T>(0.0),
                Y1_data,
                NN);

      if (!B) {
        auto parallel_memcpy_y = [&](int64_t begin, int64_t end) {
          for (int64_t i = begin; i < end; i++) {
            memcpy(Y + i * N, Y1_data + i * NN, N * sizeof(T));
          }
        };
        lite::x86::RunParallelFor(0, M, parallel_memcpy_y);
        return;
      }

      lite::x86::RunParallelFor(0, M, parallel_compute);
    } else {
      blas.MatMul(M, N, K, X, W, Y);
      if (!B) {
        return;
      }

      lite::x86::RunParallelFor(0, M, parallel_compute);
    }
  }
};

// FcCompute: the x86 float kernel for the fc op. It computes the output
// shape, flattens the input to an [M, K] matrix according to
// in_num_col_dims, and delegates the matmul + bias + activation to
// FCFunctor.
template <typename T>
class FcCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
 public:
  using param_t = operators::FcParam;

  void Run() override {
    auto& param = *param_.get_mutable<param_t>();
    auto* input = param.input;
    auto* w = param.w;
    auto* bias = param.bias;
    auto* output = param.output;
    int in_num_col_dims = param.in_num_col_dims;
    bool with_relu = (param.activation_type == "relu");

    auto w_dims = w->dims();
    bool padding_weights = param.padding_weights;

    std::vector<int64_t> output_dims;
    FCOutputSize(
        input->dims(), w_dims, output_dims, in_num_col_dims, padding_weights);
    output->Resize(output_dims);
    output->set_lod(input->lod());

    // M is the number of rows after flattening the input to a 2-D matrix:
    // the product of the first in_num_col_dims input dimensions.
    auto out_dims = output->dims();
    auto w_dims0 = padding_weights ? w_dims[0] - 4 : w_dims[0];
    auto w_dims1 = padding_weights ? w_dims[1] - 4 : w_dims[1];
    int M = out_dims.production() / w_dims1;

    const T* input_data = input->data<T>();
    const T* w_data = w->data<T>();
    T* output_data = output->mutable_data<T>();

    auto& context = ctx_->As<X86Context>();
    FCFunctor<lite::TargetType::kX86, T> fc;
    fc(context,
       M,
       w_dims1,
       w_dims0,
       input_data,
       w_data,
       output_data,
       bias ? bias->data<T>() : nullptr,
       with_relu,
       padding_weights);
  }

  virtual ~FcCompute() = default;
};

}  // namespace x86
}  // namespace kernels
}  // namespace lite
}  // namespace paddle
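
// For reference, a kernel of this shape is typically registered in the
// accompanying fc_compute.cc. A minimal sketch, assuming the usual
// REGISTER_LITE_KERNEL pattern (the exact Bind names below are assumptions,
// not copied from that file):
//
//   REGISTER_LITE_KERNEL(
//       fc, kX86, kFloat, kNCHW,
//       paddle::lite::kernels::x86::FcCompute<float>, def)
//       .BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))})
//       .BindInput("W", {LiteType::GetTensorTy(TARGET(kX86))})
//       .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kX86))})
//       .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
//       .Finalize();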