// fc_compute.h
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <cstdint>
#include <cstring>
#include <vector>

#include "lite/backends/x86/jit/helper.h"
#include "lite/backends/x86/jit/kernel_base.h"
#include "lite/backends/x86/jit/kernels.h"
#include "lite/backends/x86/math/blas.h"
#include "lite/backends/x86/parallel.h"
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
#include "lite/operators/fc_op.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {

34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
template <lite::TargetType Target, typename T>
class FCFunctor {
 public:
  void operator()(const lite::X86Context& context,
                  const int M,
                  const int N,
                  const int K,
                  const T* X,
                  const T* W,
                  T* Y,
                  const T* B = nullptr,
                  bool relu = false,
                  bool padding_weights = false) {
    auto blas = lite::x86::math::GetBlas<lite::TargetType::kX86, T>(context);
    T* Y1_data = nullptr;
49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66

    auto compute =
        relu
            ? jit::KernelFuncs<jit::VAddReluTuple<T>, fluid::CPUPlace>::Cache()
                  .At(N)
            : jit::KernelFuncs<jit::VAddTuple<T>, fluid::CPUPlace>::Cache().At(
                  N);
    auto parallel_compute = [&](int64_t begin, int64_t end) {
      for (int64_t i = begin; i < end; i++) {
        T* dst = Y + i * N;
        T* src = Y1_data ? Y1_data + i * (N + 4) : dst;
        compute(B, src, dst, N);
      }
    };

    // Because of the overhead of memcpy, we only do padding for GEMM
    //  when weights is already padded in fc_fuse_pass.
    if (padding_weights) {
67 68
      const int NN = N + 4;
      const int KK = K + 4;
69 70 71

      // NOTE: here need to mutable_data for temporary Tensor X1 and Y1,
      //  the overhead is unmeasured.
72
      lite::Tensor X1;
73
      X1.Resize(std::vector<int64_t>{M * KK});
74
      T* X1_data = X1.mutable_data<T>();
75 76

      lite::Tensor Y1;
77
      Y1.Resize(std::vector<int64_t>{M * NN});
78
      Y1_data = Y1.mutable_data<T>();
79 80 81 82

      auto parallel_memcpy_x = [&](int64_t begin, int64_t end) {
        for (int64_t i = begin; i < end; i++) {
          memcpy(X1_data + i * KK, X + i * K, K * sizeof(T));
83
        }
84 85 86
      };
      lite::x86::RunParallelFor(0, M, parallel_memcpy_x);

87 88 89 90 91 92 93 94
      blas.GEMM(false,
                false,
                M,
                N,
                K,
                static_cast<T>(1.0),
                X1_data,
                KK,
95
                W,
96 97 98 99
                NN,
                static_cast<T>(0.0),
                Y1_data,
                NN);
100 101 102 103

      if (!B) {
        auto parallel_memcpy_y = [&](int64_t begin, int64_t end) {
          for (int64_t i = begin; i < end; i++) {
104
            memcpy(Y + i * N, Y1_data + i * NN, N * sizeof(T));
105 106 107 108
          }
        };
        lite::x86::RunParallelFor(0, M, parallel_memcpy_y);
        return;
109
      }
110 111

      lite::x86::RunParallelFor(0, M, parallel_compute);
112
    } else {
113 114 115
      blas.MatMul(M, N, K, X, W, Y);
      if (!B) {
        return;
Y
Yan Chunwei 已提交
116
      }
117 118

      lite::x86::RunParallelFor(0, M, parallel_compute);
Y
Yan Chunwei 已提交
119 120
    }
  }
121
};
Y
Yan Chunwei 已提交
122 123 124 125 126 127 128 129

// Fully-connected (FC) kernel for x86, float precision:
//   output = input * W (+ bias, optionally fused relu).
template <typename T>
class FcCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
 public:
  using param_t = operators::FcParam;

  void Run() override {
    auto& param = *param_.get_mutable<param_t>();
    auto* input = param.input;
    auto* w = param.w;
    auto* bias = param.bias;  // may be null when the op has no bias
    auto* output = param.output;
    // Only "relu" is recognized as a fusable activation here.
    const bool with_relu = (param.activation_type == "relu");

    const bool padding_weights = param.padding_weights;
    const auto& w_dims = w->dims();
    // When fc_fuse_pass padded the weights, the logical K (w_dims0) and
    // N (w_dims1) are 4 smaller than the stored dimensions.
    auto w_dims0 = padding_weights ? w_dims[0] - 4 : w_dims[0];
    auto w_dims1 = padding_weights ? w_dims[1] - 4 : w_dims[1];

    // Batch size M: total output elements divided by output features N.
    int M = output->dims().production() / w_dims1;

    const T* input_data = input->template data<T>();
    const T* w_data = w->template data<T>();
    T* output_data = output->template mutable_data<T>();

    auto& context = ctx_->As<X86Context>();
    FCFunctor<lite::TargetType::kX86, T> fc;
    fc(context,
       M,
       w_dims1,  // N
       w_dims0,  // K
       input_data,
       w_data,
       output_data,
       bias ? bias->template data<T>() : nullptr,
       with_relu,
       padding_weights);
  }

  virtual ~FcCompute() = default;
};

}  // namespace x86
}  // namespace kernels
}  // namespace lite
}  // namespace paddle